#include <algorithm>
#include "Riostream.h"
#include "TRandom.h"
#include "TRandom2.h"
#include "TMath.h"
#include "TObjString.h"
#include "TMVA/MethodBDT.h"
#include "TMVA/Tools.h"
#include "TMVA/Timer.h"
#include "TMVA/Ranking.h"
#include "TMVA/SdivSqrtSplusB.h"
#include "TMVA/BinarySearchTree.h"
#include "TMVA/SeparationBase.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/CrossEntropy.h"
#include "TMVA/MisClassificationError.h"
using std::vector;
ClassImp(TMVA::MethodBDT)
TMVA::MethodBDT::MethodBDT( const TString& jobName, const TString& methodTitle, DataSet& theData,
const TString& theOption, TDirectory* theTargetDir )
: TMVA::MethodBase( jobName, methodTitle, theData, theOption, theTargetDir )
{
InitBDT();
DeclareOptions();
ParseOptions();
ProcessOptions();
if (HasTrainingTree()) {
fLogger << kVERBOSE << "Method has been called " << Endl;
this->InitEventSample();
}
else {
fLogger << kWARNING << "No training Tree given: you will not be allowed to call ::Train etc." << Endl;
}
BaseDir()->cd();
fBoostWeightHist = new TH1F("BoostWeight","Ada Boost weights",100,1,100);
fBoostWeightVsTree = new TH1F("BoostWeightVsTree","Ada Boost weights vs tree number",fNTrees,0,fNTrees);
fErrFractHist = new TH1F("ErrFractHist","error fraction vs tree number",fNTrees,0,fNTrees);
fNodesBeforePruningVsTree = new TH1I("NodesBeforePruning","nodes before pruning",fNTrees,0,fNTrees);
fNodesAfterPruningVsTree = new TH1I("NodesAfterPruning","nodes after pruning",fNTrees,0,fNTrees);
fMonitorNtuple= new TTree("MonitorNtuple","BDT variables");
fMonitorNtuple->Branch("iTree",&fITree,"iTree/I");
fMonitorNtuple->Branch("boostWeight",&fBoostWeight,"boostWeight/D");
fMonitorNtuple->Branch("errorFraction",&fErrorFraction,"errorFraction/D");
}
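//_______________________________________________________________________
// constructor used when instantiating the method from an existing weight
// file; no training tree is required in this case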
TMVA::MethodBDT::MethodBDT( DataSet& theData,
const TString& theWeightFile,
TDirectory* theTargetDir )
: TMVA::MethodBase( theData, theWeightFile, theTargetDir )
{
InitBDT();
DeclareOptions();
}
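//_______________________________________________________________________
// declare the configuration options the user can set via the booking
// option string; ProcessOptions() below translates the string values into
// the corresponding enums and separation-criterion objects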
void TMVA::MethodBDT::DeclareOptions()
{
DeclareOptionRef(fNTrees, "NTrees", "Number of trees in the forest");
DeclareOptionRef(fBoostType, "BoostType", "Boosting type for the trees in the forest");
AddPreDefVal(TString("AdaBoost"));
AddPreDefVal(TString("Bagging"));
DeclareOptionRef(fUseYesNoLeaf=kTRUE, "UseYesNoLeaf",
"Use Sig or Bkg node type or the ratio S/B as classification in the leaf node");
DeclareOptionRef(fUseWeightedTrees=kTRUE, "UseWeightedTrees",
"Use weighted trees or simple average in classification from the forest");
DeclareOptionRef(fSepTypeS="GiniIndex", "SeparationType", "Separation criterion for node splitting");
AddPreDefVal(TString("MisClassificationError"));
AddPreDefVal(TString("GiniIndex"));
AddPreDefVal(TString("CrossEntropy"));
AddPreDefVal(TString("SDivSqrtSPlusB"));
DeclareOptionRef(fNodeMinEvents, "nEventsMin", "Minimum number of events in a leaf node (default: max(20, N_train/(Nvar^2)/10) ) ");
DeclareOptionRef(fNCuts, "nCuts", "Number of steps during node cut optimisation");
DeclareOptionRef(fPruneStrength, "PruneStrength", "Pruning strength");
DeclareOptionRef(fPruneMethodS, "PruneMethod", "Pruning method: NoPruning (switched off), ExpectedError or CostComplexity");
AddPreDefVal(TString("NoPruning"));
AddPreDefVal(TString("ExpectedError"));
AddPreDefVal(TString("CostComplexity"));
AddPreDefVal(TString("CostComplexity2"));
DeclareOptionRef(fNoNegWeightsInTraining,"NoNegWeightsInTraining","Ignore negative event weights in the training process" );
DeclareOptionRef(fRandomisedTrees,"UseRandomisedTrees","Choose at each node splitting a random set of variables");
DeclareOptionRef(fUseNvars,"UseNvars","the number of variables used if randomised Tree option is chosen");
}
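//_______________________________________________________________________
// translate the option strings into internal settings. A typical booking
// string looks as follows (a sketch only; it assumes the usual TMVA
// Factory interface, Factory::BookMethod, for booking the classifier):
//
//    factory->BookMethod( TMVA::Types::kBDT, "BDT",
//                         "NTrees=200:BoostType=AdaBoost:SeparationType=GiniIndex:"
//                         "nCuts=20:PruneMethod=CostComplexity:PruneStrength=-1" );
//
// a negative PruneStrength switches on the automatic, per-tree
// optimisation of the prune strength (fAutomatic, see below)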
void TMVA::MethodBDT::ProcessOptions()
{
MethodBase::ProcessOptions();
fSepTypeS.ToLower();
if (fSepTypeS == "misclassificationerror") fSepType = new MisClassificationError();
else if (fSepTypeS == "giniindex") fSepType = new GiniIndex();
else if (fSepTypeS == "crossentropy") fSepType = new CrossEntropy();
else if (fSepTypeS == "sdivsqrtsplusb") fSepType = new SdivSqrtSplusB();
else {
fLogger << kINFO << GetOptions() << Endl;
fLogger << kFATAL << "<ProcessOptions> unknown Separation Index option called" << Endl;
}
fPruneMethodS.ToLower();
if (fPruneMethodS == "expectederror" ) fPruneMethod = DecisionTree::kExpectedErrorPruning;
else if (fPruneMethodS == "costcomplexity" ) fPruneMethod = DecisionTree::kCostComplexityPruning;
else if (fPruneMethodS == "costcomplexity2" ) fPruneMethod = DecisionTree::kMCC;
else if (fPruneMethodS == "nopruning" ) fPruneMethod = DecisionTree::kNoPruning;
else {
fLogger << kINFO << GetOptions() << Endl;
fLogger << kFATAL << "<ProcessOptions> unknown PruneMethod option called" << Endl;
}
if (fPruneStrength < 0) fAutomatic = kTRUE;
else fAutomatic = kFALSE;
if (this->Data().HasNegativeEventWeights()){
fLogger << kINFO << " You are using a Monte Carlo that has also negative weights. That should in principle be fine as long as on average you end up with something positive. For this you have to make sure that the minimal number of (unweighted) events demanded for a tree node (currently you use: nEventsMin="<<fNodeMinEvents<<", you can set this via the BDT option string when booking the classifier) is large enough to allow for reasonable averaging!!! "
<< " If this does not help.. maybe you want to try the option: NoNegWeightsInTraining which ignores events with negative weight in the training. " << Endl
<<Endl << "Note: You'll get a WARNING message during the training if that should ever happen" << Endl;
}
if (fRandomisedTrees){
fLogger << kINFO << " Randomised trees use *bagging* as *boost* method and no pruning" << Endl;
fPruneMethod = DecisionTree::kNoPruning;
fBoostType = "Bagging";
}
}
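//_______________________________________________________________________
// set the default values for all options; note that the default for
// nEventsMin is data dependent: max( 20, N_train / Nvar^2 / 10 )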
void TMVA::MethodBDT::InitBDT( void )
{
SetMethodName( "BDT" );
SetMethodType( Types::kBDT );
SetTestvarName();
fNTrees = 200;
fBoostType = "AdaBoost";
fNodeMinEvents = TMath::Max( 20, int( this->Data().GetNEvtTrain() / this->GetNvar() / this->GetNvar() / 10 ) );
fNCuts = 20;
fPruneMethodS = "CostComplexity";
fPruneMethod = DecisionTree::kCostComplexityPruning;
fPruneStrength = 5;
fDeltaPruneStrength=0.1;
fNoNegWeightsInTraining=kFALSE;
fRandomisedTrees= kFALSE;
fUseNvars = GetNvar();
SetSignalReferenceCut( 0 );
}
TMVA::MethodBDT::~MethodBDT( void )
{
for (UInt_t i=0; i<fEventSample.size(); i++) delete fEventSample[i];
for (UInt_t i=0; i<fValidationSample.size(); i++) delete fValidationSample[i];
for (UInt_t i=0; i<fForest.size(); i++) delete fForest[i];
}
void TMVA::MethodBDT::InitEventSample( void )
{
if (!HasTrainingTree()) fLogger << kFATAL << "<Init> Data().TrainingTree() is zero pointer" << Endl;
Int_t nevents = Data().GetNEvtTrain();
Int_t ievt=0;
Bool_t first=kTRUE;
for (; ievt<nevents; ievt++) {
   ReadTrainingEvent(ievt);
   Event* event = new Event( GetEvent() );
   // if requested, skip events with negative weights (warn once, and
   // delete the skipped event to avoid a memory leak)
   if (fNoNegWeightsInTraining && event->GetWeight() < 0) {
      if (first) {
         first = kFALSE;
         fLogger << kINFO << "Events with negative event weights are ignored during the BDT training (option NoNegWeightsInTraining=" << fNoNegWeightsInTraining << ")" << Endl;
      }
      delete event;
      continue;
   }
   // with automatic pruning, every second event is put aside for validation
   if (ievt%2 == 0 || !fAutomatic) fEventSample.push_back( event );
   else                            fValidationSample.push_back( event );
}
fLogger << kINFO << "<InitEventSample> Internally I use " << fEventSample.size()
        << " events for training and " << fValidationSample.size()
        << " events for validation" << Endl;
}
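//_______________________________________________________________________
// train the forest: in each of the fNTrees iterations a new decision tree
// is grown on the (re-weighted) training sample, optionally pruned (before
// or after boosting, steered by pruneBeforeBoost), and the event weights
// are updated via Boost()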
void TMVA::MethodBDT::Train( void )
{
Bool_t pruneBeforeBoost = kFALSE; // default: boost first, then prune (see the two branches below)
if (!CheckSanity()) fLogger << kFATAL << "<Train> sanity check failed" << Endl;
if (IsNormalised()) fLogger << kFATAL << "\"Normalise\" option cannot be used with BDT; "
<< "please remove the option from the configuration string, or "
<< "use \"!Normalise\""
<< Endl;
fLogger << kINFO << "Training "<< fNTrees << " Decision Trees ... patience please" << Endl;
Timer timer( fNTrees, GetName() );
Int_t nNodesBeforePruningCount = 0;
Int_t nNodesAfterPruningCount = 0;
Int_t nNodesBeforePruning = 0;
Int_t nNodesAfterPruning = 0;
SeparationBase *qualitySepType = new GiniIndex();
TH1D *alpha = new TH1D("alpha","PruneStrengths",fNTrees,0,fNTrees);
alpha->SetXTitle("#tree");
alpha->SetYTitle("PruneStrength");
for (int itree=0; itree<fNTrees; itree++) {
timer.DrawProgressBar( itree );
fForest.push_back( new DecisionTree( fSepType, fNodeMinEvents, fNCuts, qualitySepType,
fRandomisedTrees, fUseNvars, 123+itree));
nNodesBeforePruning = fForest.back()->BuildTree(fEventSample);
// debug only: for the second tree, dump the full cost-complexity pruning
// sequence of a copy of the tree to text files, and monitor the
// cost-complexity measure while pruning down to root + 2 daughter nodes
if (itree==1 && fgDebugLevel==1) {
   DecisionTree *d = new DecisionTree(*(fForest[itree]));
   TH1D *h=new TH1D("h","CostComplexity",d->GetNNodes(),0,d->GetNNodes());
   ofstream out1("theOriginal.txt");
   ofstream out2("theCopy.txt");
   fForest[itree]->Print(out1);
   out2 << "************* pruned T " << 1 << " ****************" << endl;
   d->Print(out2);
   Int_t count=1;
   h->SetBinContent(count++,d->GetCostComplexity(fPruneStrength));
   while (d->GetNNodes() > 3) {
      d->FillQualityMap();
      d->FillQualityGainMap();
      // prune the node with the smallest quality gain
      multimap<Double_t, DecisionTreeNode* >& qgm = d->GetQualityGainMap();
      multimap<Double_t, DecisionTreeNode* >::iterator it=qgm.begin();
      d->PruneNode(it->second);
      out2 << "************* pruned T " << count << " ****************" << endl;
      d->Print(out2);
      h->SetBinContent(count++,d->GetCostComplexity(fPruneStrength));
   }
   h->Write();
   delete d;
}
// debug only: for the second tree, repeatedly prune the "weakest link"
// (the node with the smallest link strength) of a copy of the tree and
// monitor the link strengths while doing so
if (itree==1 && fgDebugLevel==1) {
   DecisionTree *d = new DecisionTree(*(fForest[itree]));
   TH1D *h=new TH1D("h2","Weakestlink",d->GetNNodes(),0,d->GetNNodes());
   ofstream out2("theCopy2.txt");
   out2 << "************* pruned T " << 1 << " ****************" << endl;
   d->Print(out2);
   Int_t count=1;
   while (d->GetNNodes() > 3) {
      DecisionTreeNode *n = d->GetWeakestLink();
      multimap<Double_t, DecisionTreeNode* >& lsm = d->GetLinkStrengthMap();
      multimap<Double_t, DecisionTreeNode* >::iterator it=lsm.begin();
      fLogger << kINFO << "Nodes before " << d->CountNodes() << Endl;
      h->SetBinContent(count++,it->first);
      fLogger << kINFO << "Prune Node sequence: " << n->GetSequence() << ", depth:" << n->GetDepth() << Endl;
      d->PruneNode(n);
      fLogger << kINFO << "Nodes after " << d->CountNodes() << Endl;
      for (it=lsm.begin();it!=lsm.end();it++) cout << it->first << " / ";
      fLogger << kINFO << Endl;
      out2 << "************* pruned T " << count << " ****************" << endl;
      d->Print(out2);
   }
   h->Write();
   delete d;
}
nNodesBeforePruningCount +=nNodesBeforePruning;
fNodesBeforePruningVsTree->SetBinContent(itree+1,nNodesBeforePruning);
// option A: prune before boosting (the boost weight is then computed on the pruned tree)
if (pruneBeforeBoost && fPruneMethod != DecisionTree::kNoPruning) {
fForest.back()->SetPruneMethod(fPruneMethod);
fForest.back()->SetPruneStrength(fPruneStrength);
fForest.back()->PruneTree();
if (fUseYesNoLeaf){
fForest.back()->CleanTree();
}
nNodesAfterPruning = fForest.back()->GetNNodes();
nNodesAfterPruningCount += nNodesAfterPruning;
fNodesAfterPruningVsTree->SetBinContent(itree+1,nNodesAfterPruning);
fBoostWeights.push_back( this->Boost(fEventSample, fForest.back(), itree) );
}
// option B (default): boost first, then prune; with fAutomatic the optimal
// prune strength is determined on the validation sample (see PruneTree)
else if (!pruneBeforeBoost && fPruneMethod != DecisionTree::kNoPruning) {
fBoostWeights.push_back( this->Boost(fEventSample, fForest.back(), itree) );
fForest.back()->SetPruneMethod(fPruneMethod);
if (fAutomatic) {
fPruneStrength = this->PruneTree(fForest.back(), itree);
}
else{
fForest.back()->SetPruneStrength(fPruneStrength);
fForest.back()->PruneTree();
}
if (fUseYesNoLeaf){
fForest.back()->CleanTree();
}
nNodesAfterPruning = fForest.back()->GetNNodes();
nNodesAfterPruningCount += nNodesAfterPruning;
fNodesAfterPruningVsTree->SetBinContent(itree+1,nNodesAfterPruning);
alpha->SetBinContent(itree+1,fPruneStrength);
}
else {
if (fUseYesNoLeaf){
fForest.back()->CleanTree();
}
fBoostWeights.push_back( this->Boost(fEventSample, fForest.back(), itree) );
}
fITree = itree;
fMonitorNtuple->Fill();
}
alpha->Write();
fLogger << kINFO << "<Train> elapsed time: " << timer.GetElapsedTime()
<< " " << Endl;
if (fPruneMethod == DecisionTree::kNoPruning) {
fLogger << kINFO << "<Train> average number of nodes (w/o pruning) : "
<< nNodesBeforePruningCount/fNTrees << Endl;
}
else {
fLogger << kINFO << "<Train> average number of nodes before/after pruning : "
<< nNodesBeforePruningCount/fNTrees << " / "
<< nNodesAfterPruningCount/fNTrees
<< Endl;
}
}
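//_______________________________________________________________________
// automatic determination of the prune strength for one tree: starting at
// alpha=0, the strength is increased in steps of fDeltaPruneStrength; each
// step prunes a fresh copy of the tree and measures its quality on the
// validation sample via TestTreeQuality(). The strength with the best
// validation quality is then applied to the actual tree and returned.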
Double_t TMVA::MethodBDT::PruneTree( DecisionTree *dt, Int_t itree)
{
Double_t alpha = 0;
Double_t delta = fDeltaPruneStrength;
DecisionTree* dcopy = 0;
vector<Double_t> q;
multimap<Double_t,Double_t> quality;
Int_t nnodes=dt->GetNNodes();
Bool_t forceStop = kFALSE;
Int_t troubleCount=0, previousNnodes=nnodes;
while (nnodes > 3 && !forceStop) {
dcopy = new DecisionTree(*dt);
alpha += delta;
dcopy->SetPruneStrength(alpha);
dcopy->PruneTree();
q.push_back(this->TestTreeQuality(dcopy));
quality.insert(pair<const Double_t,Double_t>(q.back(),alpha));
nnodes=dcopy->GetNNodes();
if (previousNnodes == nnodes) troubleCount++;
else {
troubleCount=0;
if (nnodes < previousNnodes / 2 ) fDeltaPruneStrength /= 2.;
}
previousNnodes = nnodes;
if (troubleCount > 20) {
if (itree == 0 && fPruneStrength <=0) {
fDeltaPruneStrength *= 5;
fLogger << kINFO << "<PruneTree> trouble determining optimal prune strength"
<< " for Tree " << itree
<< " --> first try to increase the step size"
<< " currently Prunestrenght= " << alpha
<< " stepsize " << fDeltaPruneStrength << " " << Endl;
troubleCount = 0;
fPruneStrength = 1;
}
else if (itree == 0 && fPruneStrength <=2) {
fDeltaPruneStrength *= 5;
fLogger << kINFO << "<PruneTree> trouble determining optimal prune strength"
<< " for Tree " << itree
<< " --> try to increase the step size even more.. "
<< " if that stitill didn't work, TRY IT BY HAND"
<< " currently Prunestrenght= " << alpha
<< " stepsize " << fDeltaPruneStrength << " " << Endl;
troubleCount = 0;
fPruneStrength = 3;
}
else{
forceStop=kTRUE;
fLogger << kINFO << "<PruneTree> trouble determining optimal prune strength"
<< " for Tree " << itree << " at tested prune strength: " << alpha
<< " --> abort forced, use same strength as for previous tree:"
<< fPruneStrength << Endl;
}
}
if (fgDebugLevel==1) fLogger << kINFO << "Pruneed with ("<<alpha
<< ") give quality: " << q.back()
<< " and #nodes: " << nnodes
<< Endl;
delete dcopy;
}
if (!forceStop) {
   // the multimap is sorted by quality, so the entry with the best
   // (highest) validation quality is the last one
   multimap<Double_t,Double_t>::reverse_iterator it = quality.rbegin();
   fPruneStrength = it->second;
   // rescale the step size such that the next scan takes roughly 20 steps
   fDeltaPruneStrength *= Double_t(q.size())/20.;
}
// book a monitoring histogram: validation quality vs prune strength
char buffer[20];
snprintf (buffer, sizeof(buffer), "quad%d", itree);
TH1D *qual=new TH1D(buffer,"Quality of tree prune steps",q.size(),0.,alpha);
qual->SetXTitle("PruneStrength");
qual->SetYTitle("TreeQuality (Purity)");
for (UInt_t i=0; i< q.size(); i++) {
qual->SetBinContent(i+1,q[i]);
}
qual->Write();
dt->SetPruneStrength(fPruneStrength);
dt->PruneTree();
return fPruneStrength;
}
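//_______________________________________________________________________
// tree quality = weighted fraction of correctly classified validation
// events:  quality = sum_correct(w) / ( sum_correct(w) + sum_wrong(w) )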
Double_t TMVA::MethodBDT::TestTreeQuality( DecisionTree *dt )
{
Double_t ncorrect=0, nfalse=0;
for (UInt_t ievt=0; ievt<fValidationSample.size(); ievt++) {
Bool_t isSignalType = (dt->CheckEvent(*(fValidationSample[ievt])) > 0.5);
if (isSignalType == (fValidationSample[ievt]->IsSignal()) ) {
ncorrect += fValidationSample[ievt]->GetWeight();
}
else{
nfalse += fValidationSample[ievt]->GetWeight();
}
}
return ncorrect / (ncorrect + nfalse);
}
Double_t TMVA::MethodBDT::Boost( vector<TMVA::Event*> eventSample, DecisionTree *dt, Int_t iTree )
{
if (fBoostType=="AdaBoost") return this->AdaBoost(eventSample, dt);
else if (fBoostType=="Bagging") return this->Bagging(eventSample, iTree);
else {
fLogger << kINFO << GetOptions() << Endl;
fLogger << kFATAL << "<Boost> unknown boost option called" << Endl;
}
return -1;
}
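//_______________________________________________________________________
// standard AdaBoost weight update (the exponent adaBoostBeta is fixed to
// 1 in this implementation):
//
//    err         = sum_misclassified(w) / sum_all(w)
//    boostWeight = ( (1-err)/err )^adaBoostBeta
//
// misclassified events get their boost weight multiplied by boostWeight
// (divided, for events with negative weight); afterwards all weights are
// rescaled such that the total sum of weights is unchanged. The tree
// enters the final vote with weight ln(boostWeight).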
Double_t TMVA::MethodBDT::AdaBoost( vector<TMVA::Event*> eventSample, DecisionTree *dt )
{
Double_t adaBoostBeta=1.; // exponent of the boost weight; fixed to 1 (standard AdaBoost)
Double_t err=0, sumw=0, sumwfalse=0;
for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
Bool_t isSignalType = (dt->CheckEvent(*(*e),fUseYesNoLeaf) > 0.5 );
Double_t w = (*e)->GetWeight();
sumw += w;
if (!(isSignalType == (*e)->IsSignal())) {
sumwfalse+= w;
}
}
err = sumwfalse/sumw;
Double_t newSumw=0;
Double_t boostWeight=1.;
if (err > 0.5) {
   fLogger << kWARNING << "The error rate in the BDT boosting is > 0.5,"
           << " which should not happen. Please check your code (i.e. the BDT code);"
           << " the error rate is set to 0.5 in order to continue." << Endl;
   err = 0.5;
} else if (err < 0) {
   fLogger << kWARNING << "The error rate in the BDT boosting is < 0. That can happen"
           << " due to improper treatment of negative weights in a Monte Carlo (if you have"
           << " an idea on how to do it in a better way, please let me know: Helge.Voss@cern.ch)."
           << " For the time being it is set to its absolute value in order to continue." << Endl;
   err = TMath::Abs(err);
}
// boostWeight = ((1-err)/err)^adaBoostBeta; with adaBoostBeta = 1 this is standard AdaBoost
if (adaBoostBeta == 1) {
   boostWeight = (1-err)/err;
}
else {
   boostWeight = TMath::Power((1.0 - err)/err, adaBoostBeta);
}
// re-weight the misclassified events: the boost weight of positive-weight
// events is increased; for negative-weight events it is decreased instead
// (a heuristic treatment of negative weights)
for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
   if (!( (dt->CheckEvent(*(*e),fUseYesNoLeaf) > 0.5 ) == (*e)->IsSignal())) {
      if ( (*e)->GetWeight() > 0 ){
         (*e)->SetBoostWeight( (*e)->GetBoostWeight() * boostWeight);
      } else {
         (*e)->SetBoostWeight( (*e)->GetBoostWeight() / boostWeight);
      }
   }
   newSumw+=(*e)->GetWeight();
}
// renormalise the boost weights such that the total sum of event weights stays constant
for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
   (*e)->SetBoostWeight( (*e)->GetBoostWeight() * sumw / newSumw );
}
fBoostWeightHist->Fill(boostWeight);
fBoostWeightVsTree->SetBinContent(fForest.size(),boostWeight);
fErrFractHist->SetBinContent(fForest.size(),err);
fBoostWeight = boostWeight;
fErrorFraction = err;
return TMath::Log(boostWeight);
}
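//_______________________________________________________________________
// bagging: each event receives a new boost weight drawn from a Poisson
// distribution with mean 1, emulating a resampling of the training set
// with replacement; all trees get the same vote weight (return value 1)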
Double_t TMVA::MethodBDT::Bagging( vector<TMVA::Event*> eventSample, Int_t iTree )
{
   Double_t newSumw=0;
   Double_t newWeight;
   TRandom2 *trandom = new TRandom2(iTree); // seeded with the tree index for reproducibility
   for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
      newWeight = trandom->PoissonD(1);
      (*e)->SetBoostWeight(newWeight);
      newSumw+=(*e)->GetBoostWeight();
   }
   // renormalise such that the sum of weights equals the number of events
   for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
      (*e)->SetBoostWeight( (*e)->GetBoostWeight() * eventSample.size() / newSumw );
   }
   delete trandom;
   return 1.;
}
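//_______________________________________________________________________
// persistency: write the number of trees, then for each tree its boost
// weight followed by the tree itself; read back in ReadWeightsFromStream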
void TMVA::MethodBDT::WriteWeightsToStream( ostream& o) const
{
o << "NTrees= " << fForest.size() <<endl;
for (UInt_t i=0; i< fForest.size(); i++) {
o << "Tree " << i << " boostWeight " << fBoostWeights[i] << endl;
(fForest[i])->Print(o);
}
}
void TMVA::MethodBDT::ReadWeightsFromStream( istream& istr )
{
TString var, dummy;
istr >> dummy >> fNTrees;
fLogger << kINFO << "Read " << fNTrees << " Decision trees" << Endl;
for (UInt_t i=0;i<fForest.size();i++) delete fForest[i];
fForest.clear();
fBoostWeights.clear();
Int_t iTree;
Double_t boostWeight;
for (int i=0;i<fNTrees;i++) {
istr >> dummy >> iTree >> dummy >> boostWeight;
if (iTree != i) {
   if (!fForest.empty()) fForest.back()->Print( cout );
   fLogger << kFATAL << "Error while reading weight file; mismatch iTree="
           << iTree << " i=" << i
           << " dummy " << dummy
           << " boostweight " << boostWeight
           << Endl;
}
fForest.push_back( new DecisionTree() );
fForest.back()->Read(istr);
fBoostWeights.push_back(boostWeight);
}
}
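//_______________________________________________________________________
// the BDT response is the (optionally boost-weighted) average of the
// individual tree responses:
//
//    MVA(x) = sum_i alpha_i T_i(x) / sum_i alpha_i   (UseWeightedTrees)
//    MVA(x) = sum_i T_i(x) / NTrees                  (otherwise)
//
// where T_i(x) is either the leaf node type (+-1) or the leaf purity,
// depending on the UseYesNoLeaf option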
Double_t TMVA::MethodBDT::GetMvaValue()
{
Double_t myMVA = 0;
Double_t norm = 0;
for (UInt_t itree=0; itree<fForest.size(); itree++) {
if (fUseWeightedTrees) {
myMVA += fBoostWeights[itree] * fForest[itree]->CheckEvent(GetEvent(),fUseYesNoLeaf);
norm += fBoostWeights[itree];
}
else {
myMVA += fForest[itree]->CheckEvent(GetEvent(),fUseYesNoLeaf);
norm += 1;
}
}
return myMVA / norm;
}
void TMVA::MethodBDT::WriteMonitoringHistosToFile( void ) const
{
fLogger << kINFO << "Write monitoring histograms to file: " << BaseDir()->GetPath() << Endl;
fBoostWeightHist->Write();
fBoostWeightVsTree->Write();
fErrFractHist->Write();
fNodesBeforePruningVsTree->Write();
fNodesAfterPruningVsTree->Write();
fMonitorNtuple->Write();
}
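//_______________________________________________________________________
// variable importance: the per-variable importance returned by each
// DecisionTree is summed over the forest and normalised to unit sum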
vector< Double_t > TMVA::MethodBDT::GetVariableImportance()
{
fVariableImportance.assign(GetNvar(), 0.0); // reset, so repeated calls do not accumulate
Double_t sum=0;
for (int itree = 0; itree < fNTrees; itree++) {
vector<Double_t> relativeImportance(fForest[itree]->GetVariableImportance());
for (UInt_t i=0; i< relativeImportance.size(); i++) {
fVariableImportance[i] += relativeImportance[i];
}
}
for (UInt_t i=0; i< fVariableImportance.size(); i++) sum += fVariableImportance[i];
for (UInt_t i=0; i< fVariableImportance.size(); i++) fVariableImportance[i] /= sum;
return fVariableImportance;
}
Double_t TMVA::MethodBDT::GetVariableImportance( UInt_t ivar )
{
vector<Double_t> relativeImportance = this->GetVariableImportance();
if (ivar < (UInt_t)relativeImportance.size()) return relativeImportance[ivar];
else fLogger << kFATAL << "<GetVariableImportance> ivar = " << ivar << " is out of range " << Endl;
return -1;
}
const TMVA::Ranking* TMVA::MethodBDT::CreateRanking()
{
fRanking = new Ranking( GetName(), "Variable Importance" );
vector< Double_t> importance(this->GetVariableImportance());
for (Int_t ivar=0; ivar<GetNvar(); ivar++) {
fRanking->AddRank( *new Rank( GetInputExp(ivar), importance[ivar] ) );
}
return fRanking;
}
void TMVA::MethodBDT::GetHelpMessage() const
{
fLogger << Endl;
fLogger << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
fLogger << Endl;
fLogger << "Boosted Decision Trees are a collection of individual decision" << Endl;
fLogger << "trees which form a multivariate classifier by (weighted) majority " << Endl;
fLogger << "vote of the individual trees. Consecutive decision trees are " << Endl;
fLogger << "trained using the original training data set with re-weighted " << Endl;
fLogger << "events. By default, the AdaBoost method is employed, which gives " << Endl;
fLogger << "events that were misclassified in the previous tree a larger " << Endl;
fLogger << "weight in the training of the following tree." << Endl;
fLogger << Endl;
fLogger << "Decision trees are a sequence of binary splits of the data sample" << Endl;
fLogger << "using a single descriminant variable at a time. A test event " << Endl;
fLogger << "ending up after the sequence of left-right splits in a final " << Endl;
fLogger << "(\"leaf\") node is classified as either signal or background" << Endl;
fLogger << "depending on the majority type of training events in that node." << Endl;
fLogger << Endl;
fLogger << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
fLogger << Endl;
fLogger << "By the nature of the binary splits performed on the individual" << Endl;
fLogger << "variables, decision trees do not deal well with linear correlations" << Endl;
fLogger << "between variables (they need to approximate the linear split in" << Endl;
fLogger << "the two dimensional space by a sequence of splits on the two " << Endl;
fLogger << "variables individually). Hence decorrelation could be useful " << Endl;
fLogger << "to optimise the BDT performance." << Endl;
fLogger << Endl;
fLogger << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
fLogger << Endl;
fLogger << "The two most important parameters in the configuration are the " << Endl;
fLogger << "minimal number of events requested by a leaf node (option " << Endl;
fLogger << "\"nEventsMin\"). If this number is too large, detailed features " << Endl;
fLogger << "in the parameter space cannot be modeled. If it is too small, " << Endl;
fLogger << "the risk to overtain rises." << Endl;
fLogger << " (Imagine the decision tree is split until the leaf node contains" << Endl;
fLogger << " only a single event. In such a case, no training event is " << Endl;
fLogger << " misclassified, while the situation will look very different" << Endl;
fLogger << " for the test sample.)" << Endl;
fLogger << Endl;
fLogger << "The default minumal number is currently set to " << Endl;
fLogger << " max(20, (N_training_events / N_variables^2 / 10) " << Endl;
fLogger << "and can be changed by the user." << Endl;
fLogger << Endl;
fLogger << "The other crucial paramter, the pruning strength (\"PruneStrength\")," << Endl;
fLogger << "is also related to overtraining. It is a regularistion parameter " << Endl;
fLogger << "that is used when determining after the training which splits " << Endl;
fLogger << "are considered statistically insignificant and are removed. The" << Endl;
fLogger << "user is advised to carefully watch the BDT screen output for" << Endl;
fLogger << "the comparison between efficiencies obtained on the training and" << Endl;
fLogger << "the independent test sample. They should be equal within statistical" << Endl;
fLogger << "errors." << Endl;
}
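//_______________________________________________________________________
// write the forest as a standalone C++ class: the generated GetMvaValue__
// walks each tree from the root to a leaf and forms the same (weighted)
// average as GetMvaValue() above; the node class itself is written by
// MakeClassSpecificHeader()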
void TMVA::MethodBDT::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
fout << " std::vector<BDT_DecisionTreeNode*> fForest; // i.e. root nodes of decision trees" << endl;
fout << " std::vector<double> fBoostWeights; // the weights applied in the individual boosts" << endl;
fout << "};" << endl << endl;
fout << "double " << className << "::GetMvaValue__( const std::vector<double>& inputValues ) const" << endl;
fout << "{" << endl;
fout << " double myMVA = 0;" << endl;
fout << " double norm = 0;" << endl;
fout << " for (unsigned int itree=0; itree<fForest.size(); itree++){" << endl;
fout << " BDT_DecisionTreeNode *current = fForest[itree];" << endl;
fout << " while (current->GetNodeType() == 0) { //intermediate node" << endl;
fout << " if (current->GoesRight(inputValues)) current=(BDT_DecisionTreeNode*)current->GetRight();" << endl;
fout << " else current=(BDT_DecisionTreeNode*)current->GetLeft();" << endl;
fout << " }" << endl;
if (fUseWeightedTrees) {
if (fUseYesNoLeaf) fout << " myMVA += fBoostWeights[itree] * current->GetNodeType();" << endl;
else fout << " myMVA += fBoostWeights[itree] * current->GetPurity();" << endl;
fout << " norm += fBoostWeights[itree];" << endl;
}
else {
if (fUseYesNoLeaf) fout << " myMVA += current->GetNodeType();" << endl;
else fout << " myMVA += current->GetPurity();" << endl;
fout << " norm += 1.;" << endl;
}
fout << " }" << endl;
fout << " return myMVA /= norm;" << endl;
fout << "};" << endl << endl;
fout << "void " << className << "::Initialize()" << endl;
fout << "{" << endl;
for (int itree=0; itree<fNTrees; itree++) {
fout << " // itree = " << itree << endl;
fout << " fBoostWeights.push_back(" << fBoostWeights[itree] << ");" << endl;
fout << " fForest.push_back( " << endl;
this->MakeClassInstantiateNode((DecisionTreeNode*)fForest[itree]->GetRoot(), fout, className);
fout <<" );" << endl;
}
fout << " return;" << endl;
fout << "};" << endl;
fout << " " << endl;
fout << "// Clean up" << endl;
fout << "inline void " << className << "::Clear() " << endl;
fout << "{" << endl;
fout << " for (unsigned int itree=0; itree<fForest.size(); itree++) { " << endl;
fout << " delete fForest[itree]; " << endl;
fout << " }" << endl;
fout << "}" << endl;
}
void TMVA::MethodBDT::MakeClassSpecificHeader( std::ostream& fout, const TString& ) const
{
fout << "#ifndef NN" << endl;
fout << "#define NN new BDT_DecisionTreeNode" << endl;
fout << "#endif" << endl;
fout << " " << endl;
fout << "#ifndef BDT_DecisionTreeNode__def" << endl;
fout << "#define BDT_DecisionTreeNode__def" << endl;
fout << " " << endl;
fout << "class BDT_DecisionTreeNode {" << endl;
fout << " " << endl;
fout << "public:" << endl;
fout << " " << endl;
fout << " // constructor of an essentially \"empty\" node floating in space" << endl;
fout << " BDT_DecisionTreeNode ( BDT_DecisionTreeNode* left," << endl;
fout << " BDT_DecisionTreeNode* right," << endl;
fout << " double cutValue, bool cutType, int selector," << endl;
fout << " int nodeType, double purity ) :" << endl;
fout << " fLeft ( left )," << endl;
fout << " fRight ( right )," << endl;
fout << " fCutValue( cutValue )," << endl;
fout << " fCutType ( cutType )," << endl;
fout << " fSelector( selector )," << endl;
fout << " fNodeType( nodeType )," << endl;
fout << " fPurity ( purity ) {}" << endl << endl;
fout << " virtual ~BDT_DecisionTreeNode();" << endl << endl;
fout << " // test event if it decends the tree at this node to the right" << endl;
fout << " virtual bool GoesRight( const std::vector<double>& inputValues ) const;" << endl;
fout << " BDT_DecisionTreeNode* GetRight( void ) {return fRight; };" << endl << endl;
fout << " // test event if it decends the tree at this node to the left " << endl;
fout << " virtual bool GoesLeft ( const std::vector<double>& inputValues ) const;" << endl;
fout << " BDT_DecisionTreeNode* GetLeft( void ) { return fLeft; }; " << endl << endl;
fout << " // return S/(S+B) (purity) at this node (from training)" << endl << endl;
fout << " double GetPurity( void ) const { return fPurity; } " << endl;
fout << " // return the node type" << endl;
fout << " int GetNodeType( void ) const { return fNodeType; }" << endl << endl;
fout << "private:" << endl << endl;
fout << " BDT_DecisionTreeNode* fLeft; // pointer to the left daughter node" << endl;
fout << " BDT_DecisionTreeNode* fRight; // pointer to the right daughter node" << endl;
fout << " double fCutValue; // cut value appplied on this node to discriminate bkg against sig" << endl;
fout << " bool fCutType; // true: if event variable > cutValue ==> signal , false otherwise" << endl;
fout << " int fSelector; // index of variable used in node selection (decision tree) " << endl;
fout << " int fNodeType; // Type of node: -1 == Bkg-leaf, 1 == Signal-leaf, 0 = internal " << endl;
fout << " double fPurity; // Purity of node from training"<< endl;
fout << "}; " << endl;
fout << " " << endl;
fout << "//_______________________________________________________________________" << endl;
fout << "BDT_DecisionTreeNode::~BDT_DecisionTreeNode()" << endl;
fout << "{" << endl;
fout << " if (fLeft != NULL) delete fLeft;" << endl;
fout << " if (fRight != NULL) delete fRight;" << endl;
fout << "}; " << endl;
fout << " " << endl;
fout << "//_______________________________________________________________________" << endl;
fout << "bool BDT_DecisionTreeNode::GoesRight( const std::vector<double>& inputValues ) const" << endl;
fout << "{" << endl;
fout << " // test event if it decends the tree at this node to the right" << endl;
fout << " bool result = (inputValues[fSelector] > fCutValue );" << endl;
fout << " if (fCutType == true) return result; //the cuts are selecting Signal ;" << endl;
fout << " else return !result;" << endl;
fout << "}" << endl;
fout << " " << endl;
fout << "//_______________________________________________________________________" << endl;
fout << "bool BDT_DecisionTreeNode::GoesLeft( const std::vector<double>& inputValues ) const" << endl;
fout << "{" << endl;
fout << " // test event if it decends the tree at this node to the left" << endl;
fout << " if (!this->GoesRight(inputValues)) return true;" << endl;
fout << " else return false;" << endl;
fout << "}" << endl;
fout << " " << endl;
fout << "#endif" << endl;
fout << " " << endl;
}
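//_______________________________________________________________________
// recursively write one "NN( left, right, cutValue, cutType, selector,
// nodeType, purity )" constructor call per node; null daughters are
// written as 0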
void TMVA::MethodBDT::MakeClassInstantiateNode( DecisionTreeNode *n, std::ostream& fout, const TString& className ) const
{
if (n == NULL) {
fLogger << kFATAL << "MakeClassInstantiateNode: started with undefined node" <<Endl;
return ;
}
fout << "NN("<<endl;
if (n->GetLeft() != NULL){
this->MakeClassInstantiateNode( (DecisionTreeNode*)n->GetLeft() , fout, className);
}
else {
fout << "0";
}
fout << ", " <<endl;
if (n->GetRight() != NULL){
this->MakeClassInstantiateNode( (DecisionTreeNode*)n->GetRight(), fout, className );
}
else {
fout << "0";
}
fout << ", " << endl
<< setprecision(6)
<< n->GetCutValue() << ", "
<< n->GetCutType() << ", "
<< n->GetSelector() << ", "
<< n->GetNodeType() << ", "
<< n->GetPurity() << ") ";
}