#include <algorithm>
#include "Riostream.h"
#include "TRandom.h"
#include "TMath.h"
#include "TObjString.h"
#include "TMVA/MethodBDT.h"
#include "TMVA/Tools.h"
#include "TMVA/Timer.h"
#include "TMVA/Ranking.h"
#include "TMVA/SdivSqrtSplusB.h"
#include "TMVA/BinarySearchTree.h"
#include "TMVA/SeparationBase.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/CrossEntropy.h"
#include "TMVA/MisClassificationError.h"
using std::vector;
ClassImp(TMVA::MethodBDT)
//_______________________________________________________________________
// Constructor for training: initialises defaults, parses the option
// string, loads the training events (if a training tree is attached)
// and books the monitoring histograms/ntuple in the method directory.
TMVA::MethodBDT::MethodBDT( const TString& jobName, const TString& methodTitle, DataSet& theData,
const TString& theOption, TDirectory* theTargetDir )
: TMVA::MethodBase( jobName, methodTitle, theData, theOption, theTargetDir )
{
// order matters: defaults first, then option declaration/parsing
InitBDT();
DeclareOptions();
ParseOptions();
ProcessOptions();
if (HasTrainingTree()) {
fLogger << kVERBOSE << "Method has been called " << Endl;
// fills fEventSample (and fValidationSample when automatic pruning is on)
this->InitEventSample();
}
else {
fLogger << kWARNING << "No training Tree given: you will not be allowed to call ::Train etc." << Endl;
}
// book monitoring output in the method's target directory
// NOTE(review): these histograms are booked with fNTrees bins, so fNTrees
// must already have its final value here (set in InitBDT/ProcessOptions).
BaseDir()->cd();
fBoostWeightHist = new TH1F("BoostWeight","Ada Boost weights",100,1,100);
fBoostWeightVsTree = new TH1F("BoostWeightVsTree","Ada Boost weights",fNTrees,0,fNTrees);
fErrFractHist = new TH1F("ErrFractHist","error fraction vs tree number",fNTrees,0,fNTrees);
fNodesBeforePruningVsTree = new TH1I("NodesBeforePruning","nodes before pruning",fNTrees,0,fNTrees);
fNodesAfterPruningVsTree = new TH1I("NodesAfterPruning","nodes after pruning",fNTrees,0,fNTrees);
// per-tree monitoring ntuple: tree index, boost weight and error fraction
fMonitorNtuple= new TTree("MonitorNtuple","BDT variables");
fMonitorNtuple->Branch("iTree",&fITree,"iTree/I");
fMonitorNtuple->Branch("boostWeight",&fBoostWeight,"boostWeight/D");
fMonitorNtuple->Branch("errorFraction",&fErrorFraction,"errorFraction/D");
}
//_______________________________________________________________________
// Constructor for applying a trained method from a weight file: only the
// defaults and option declarations are set up here; the forest itself is
// restored later via ReadWeightsFromStream().
TMVA::MethodBDT::MethodBDT( DataSet& theData,
const TString& theWeightFile,
TDirectory* theTargetDir )
: TMVA::MethodBase( theData, theWeightFile, theTargetDir )
{
InitBDT();
DeclareOptions();
}
//_______________________________________________________________________
// Declare all configurable options of the BDT method together with their
// allowed values. The referenced members receive the parsed values in
// ParseOptions(); string options are mapped to objects in ProcessOptions().
void TMVA::MethodBDT::DeclareOptions()
{
DeclareOptionRef(fNTrees, "NTrees", "Number of trees in the forest");
DeclareOptionRef(fBoostType, "BoostType", "Boosting type for the trees in the forest");
AddPreDefVal(TString("AdaBoost"));
AddPreDefVal(TString("Bagging"));
DeclareOptionRef(fUseYesNoLeaf=kTRUE, "UseYesNoLeaf",
"Use Sig or Bkg node type or the ratio S/B as classification in the leaf node");
DeclareOptionRef(fUseWeightedTrees=kTRUE, "UseWeightedTrees",
"Use weighted trees or simple average in classification from the forest");
// node-splitting separation criterion (resolved to an object in ProcessOptions)
DeclareOptionRef(fSepTypeS="GiniIndex", "SeparationType", "Separation criterion for node splitting");
AddPreDefVal(TString("MisClassificationError"));
AddPreDefVal(TString("GiniIndex"));
AddPreDefVal(TString("CrossEntropy"));
AddPreDefVal(TString("SDivSqrtSPlusB"));
DeclareOptionRef(fNodeMinEvents, "nEventsMin", "Minimum number of events in a leaf node (default: max(20, N_train/(Nvar^2)/10) ) ");
DeclareOptionRef(fNCuts, "nCuts", "Number of steps during node cut optimisation");
// negative PruneStrength triggers the automatic scan in PruneTree()
DeclareOptionRef(fPruneStrength, "PruneStrength", "Pruning strength (negative value == automatic adjustment)");
DeclareOptionRef(fPruneMethodS, "PruneMethod", "Pruning method: NoPruning (switched off), ExpectedError or CostComplexity");
AddPreDefVal(TString("NoPruning"));
AddPreDefVal(TString("ExpectedError"));
AddPreDefVal(TString("CostComplexity"));
AddPreDefVal(TString("CostComplexity2"));
}
//_______________________________________________________________________
// Translate the parsed option strings into internal state: the separation
// criterion object, the pruning-method enum and the automatic-pruning flag.
void TMVA::MethodBDT::ProcessOptions()
{
   MethodBase::ProcessOptions();

   // map the separation-type string onto the corresponding criterion object
   fSepTypeS.ToLower();
   if (fSepTypeS == "misclassificationerror") {
      fSepType = new MisClassificationError();
   }
   else if (fSepTypeS == "giniindex") {
      fSepType = new GiniIndex();
   }
   else if (fSepTypeS == "crossentropy") {
      fSepType = new CrossEntropy();
   }
   else if (fSepTypeS == "sdivsqrtsplusb") {
      fSepType = new SdivSqrtSplusB();
   }
   else {
      fLogger << kINFO << GetOptions() << Endl;
      fLogger << kFATAL << "<ProcessOptions> unknown Separation Index option called" << Endl;
   }

   // map the pruning-method string onto the DecisionTree enum
   fPruneMethodS.ToLower();
   if (fPruneMethodS == "expectederror" ) {
      fPruneMethod = DecisionTree::kExpectedErrorPruning;
   }
   else if (fPruneMethodS == "costcomplexity" ) {
      fPruneMethod = DecisionTree::kCostComplexityPruning;
   }
   else if (fPruneMethodS == "costcomplexity2" ) {
      fPruneMethod = DecisionTree::kMCC;
   }
   else if (fPruneMethodS == "nopruning" ) {
      fPruneMethod = DecisionTree::kNoPruning;
   }
   else {
      fLogger << kINFO << GetOptions() << Endl;
      fLogger << kFATAL << "<ProcessOptions> unknown PruneMethod option called" << Endl;
   }

   // a negative prune strength requests the automatic per-tree scan
   fAutomatic = (fPruneStrength < 0);
}
//_______________________________________________________________________
// Set the default values of all configurable members (before option parsing).
void TMVA::MethodBDT::InitBDT( void )
{
SetMethodName( "BDT" );
SetMethodType( Types::kBDT );
SetTestvarName();
fNTrees = 200;
fBoostType = "AdaBoost";
// leaf-size default scales with the training statistics and variable count
fNodeMinEvents = TMath::Max( 20, int( this->Data().GetNEvtTrain() / this->GetNvar()/ this->GetNvar() / 10) );
fNCuts = 20;
fPruneMethodS = "CostComplexity";
fPruneMethod = DecisionTree::kCostComplexityPruning;
fPruneStrength = 5;
// initial step size for the automatic prune-strength scan in PruneTree()
fDeltaPruneStrength=0.1;
// MVA output > 0 is classified as signal
SetSignalReferenceCut( 0 );
}
//_______________________________________________________________________
// Destructor: the method owns the copied training/validation events and
// every tree of the forest, so release them all here.
TMVA::MethodBDT::~MethodBDT( void )
{
   const UInt_t nTrain = fEventSample.size();
   for (UInt_t iev = 0; iev < nTrain; ++iev) delete fEventSample[iev];
   const UInt_t nValid = fValidationSample.size();
   for (UInt_t iev = 0; iev < nValid; ++iev) delete fValidationSample[iev];
   const UInt_t nTrees = fForest.size();
   for (UInt_t itree = 0; itree < nTrees; ++itree) delete fForest[itree];
}
void TMVA::MethodBDT::InitEventSample( void )
{
if (!HasTrainingTree()) fLogger << kFATAL << "<Init> Data().TrainingTree() is zero pointer" << Endl;
Int_t nevents = Data().GetNEvtTrain();
Int_t ievt=0;
for (; ievt<nevents; ievt++) {
ReadTrainingEvent(ievt);
Event* event = new Event( GetEvent() );
if (ievt%2 == 0 || !fAutomatic ) fEventSample .push_back( event );
else fValidationSample.push_back( event );
}
fLogger << kINFO << "<InitEventSample> Internally I use " << fEventSample.size()
<< " for Training and " << fValidationSample.size()
<< " for Validation " << Endl;
}
//_______________________________________________________________________
// Train the forest: grow fNTrees decision trees on the (re-weighted)
// event sample, boost after each tree, and prune the trees either before
// or (default) after boosting.
void TMVA::MethodBDT::Train( void )
{
// if kTRUE each tree is pruned before its boost weight is computed;
// default is to prune the whole forest after the boosting loop
Bool_t pruneBeforeBoost = kFALSE;
if (!CheckSanity()) fLogger << kFATAL << "<Train> sanity check failed" << Endl;
// variable normalisation is forbidden for BDTs (splits are invariant
// under monotonic transformations, and the option would mislead users)
if (IsNormalised()) fLogger << kFATAL << "\"Normalise\" option cannot be used with BDT; "
<< "please remove the option from the configuration string, or "
<< "use \"!Normalise\""
<< Endl;
fLogger << kINFO << "Training "<< fNTrees << " Decision Trees ... patience please" << Endl;
Timer timer( fNTrees, GetName() );
Int_t nNodesBeforePruningCount = 0;
Int_t nNodesAfterPruningCount = 0;
Int_t nNodesBeforePruning = 0;
Int_t nNodesAfterPruning = 0;
// separation criterion used for the leaf-quality estimate while pruning
// NOTE(review): never deleted here -- presumably the DecisionTree takes
// ownership; confirm, otherwise this leaks once per Train() call.
SeparationBase *qualitySepType = new GiniIndex();
for (int itree=0; itree<fNTrees; itree++) {
timer.DrawProgressBar( itree );
// grow a new tree on the current (boosted) event weights
fForest.push_back( new DecisionTree( fSepType, fNodeMinEvents, fNCuts, qualitySepType ));
nNodesBeforePruning = fForest.back()->BuildTree(fEventSample);
// ---- debug block 1 (only tree #1, fgDebugLevel==1): dump the tree and
// trace the cost-complexity while pruning a copy node by node.
// NOTE(review): the tree copy "d" is never deleted (debug-only leak).
if (itree==1 && fgDebugLevel==1) {
DecisionTree *d = new DecisionTree(*(fForest[itree]));
TH1D *h=new TH1D("h","CostComplexity",d->GetNNodes(),0,d->GetNNodes());
ofstream out1("theOriginal.txt");
ofstream out2("theCopy.txt");
fForest[itree]->Print(out1);
out2 << "************* pruned T " << 1 << " ****************" <<endl;
d->Print(out2);
Int_t count=1;
h->SetBinContent(count++,d->GetCostComplexity(fPruneStrength));
// prune away the weakest quality-gain node until only the root and
// its two daughters remain (3 nodes)
while (d->GetNNodes() > 3) {
d->FillQualityMap();
d->FillQualityGainMap();
multimap<Double_t, DecisionTreeNode* > qgm = d->GetQualityGainMap();
multimap<Double_t, DecisionTreeNode* >::iterator it=qgm.begin();
d->PruneNode(it->second);
out2 << "************* pruned T " << count << " ****************" <<endl;
d->Print(out2);
h->SetBinContent(count++,d->GetCostComplexity(fPruneStrength));
}
h->Write();
}
// ---- debug block 2 (same trigger): repeatedly remove the "weakest
// link" node and record the corresponding link strengths.
if (itree==1 && fgDebugLevel==1) {
DecisionTree *d = new DecisionTree(*(fForest[itree]));
TH1D *h=new TH1D("h2","Weakestlink",d->GetNNodes(),0,d->GetNNodes());
ofstream out2("theCopy2.txt");
out2 << "************* pruned T " << 1 << " ****************" <<endl;
d->Print(out2);
Int_t count=1;
while (d->GetNNodes() > 3) {
DecisionTreeNode *n = d->GetWeakestLink();
multimap<Double_t, DecisionTreeNode* > ls = d->GetLinkStrengthMap();
multimap<Double_t, DecisionTreeNode* >::iterator it=ls.begin();
fLogger << kINFO << "Nodes before " << d->CountNodes() << Endl;
h->SetBinContent(count++,it->first);
fLogger << kINFO << "Prune Node sequence: " << n->GetSequence() << ", depth:" << n->GetDepth() << Endl;
d->PruneNode(n);
fLogger << kINFO << "Nodes after " << d->CountNodes() << Endl;
for (it=ls.begin();it!=ls.end();it++) cout << it->first << " / ";
fLogger << kINFO << Endl;
out2 << "************* pruned T " << count << " ****************" <<endl;
d->Print(out2);
}
h->Write();
}
nNodesBeforePruningCount +=nNodesBeforePruning;
fNodesBeforePruningVsTree->SetBinContent(itree+1,nNodesBeforePruning);
// optional pruning BEFORE the boost weight of this tree is computed
if (pruneBeforeBoost && fPruneMethod != DecisionTree::kNoPruning) {
fForest.back()->SetPruneMethod(fPruneMethod);
fForest.back()->SetPruneStrength(fPruneStrength);
fForest.back()->PruneTree();
nNodesAfterPruning = fForest.back()->GetNNodes();
nNodesAfterPruningCount += nNodesAfterPruning;
fNodesAfterPruningVsTree->SetBinContent(itree+1,nNodesAfterPruning);
}
// re-weight the events according to this tree's performance and store
// the tree weight (log of the boost weight for AdaBoost)
fBoostWeights.push_back( this->Boost(fEventSample, fForest.back(), itree) );
fITree = itree;
fMonitorNtuple->Fill();
}
fLogger << kINFO << "<Train> elapsed time: " << timer.GetElapsedTime()
<< " " << Endl;
// default path: prune the whole forest AFTER boosting, optionally with a
// per-tree automatic determination of the prune strength
if (!pruneBeforeBoost && fPruneMethod != DecisionTree::kNoPruning) {
fLogger << kINFO << "Pruning "<< fNTrees << " Decision Trees ... patience please" << Endl;
Timer timer2( fNTrees, GetName() );
TH1D *alpha = new TH1D("alpha","PruneStrengths",fNTrees,0,fNTrees);
alpha->SetXTitle("#tree");
alpha->SetYTitle("PruneStrength");
for (int itree=0; itree<fNTrees; itree++) {
timer2.DrawProgressBar( itree );
fForest[itree]->SetPruneMethod(fPruneMethod);
if (fAutomatic) {
// scan for the optimal strength using the validation sample
fPruneStrength = this->PruneTree(fForest[itree], itree);
}
else{
fForest[itree]->SetPruneStrength(fPruneStrength);
fForest[itree]->PruneTree();
}
nNodesAfterPruning = fForest[itree]->GetNNodes();
nNodesAfterPruningCount += nNodesAfterPruning;
fNodesAfterPruningVsTree->SetBinContent(itree+1,nNodesAfterPruning);
alpha->SetBinContent(itree+1,fPruneStrength);
}
alpha->Write();
fLogger << kINFO << "<Train_Prune> elapsed time: " << timer2.GetElapsedTime()
<< " " << Endl;
}
// summary of the average forest size
if (fPruneMethod == DecisionTree::kNoPruning) {
fLogger << kINFO << "<Train> average number of nodes (w/o pruning) : "
<< nNodesBeforePruningCount/fNTrees << Endl;
}
else {
fLogger << kINFO << "<Train> average number of nodes before/after pruning : "
<< nNodesBeforePruningCount/fNTrees << " / "
<< nNodesAfterPruningCount/fNTrees
<< Endl;
}
}
//_______________________________________________________________________
// Automatically determine the optimal prune strength for tree "itree":
// prune copies of the tree with increasing strength, measure the quality
// of each copy on the validation sample, pick the strength with the best
// quality, then prune the actual tree with it. Returns the chosen strength.
Double_t TMVA::MethodBDT::PruneTree( DecisionTree *dt, Int_t itree)
{
   Double_t alpha = 0;                    // currently tested prune strength
   Double_t delta = fDeltaPruneStrength;  // scan step size

   DecisionTree* dcopy;
   vector<Double_t> q;                    // quality of each scan point (in scan order)
   multimap<Double_t,Double_t> quality;   // quality -> prune strength (sorted by quality)
   Int_t nnodes = dt->GetNNodes();

   Bool_t forceStop = kFALSE;
   Int_t troubleCount = 0, previousNnodes = nnodes;

   // scan until the pruned copy is reduced to root + two daughters (3 nodes)
   while (nnodes > 3 && !forceStop) {
      dcopy = new DecisionTree(*dt);
      dcopy->SetPruneStrength(alpha += delta);
      dcopy->PruneTree();
      q.push_back(this->TestTreeQuality(dcopy));
      quality.insert(pair<const Double_t,Double_t>(q.back(),alpha));
      nnodes = dcopy->GetNNodes();
      if (previousNnodes == nnodes) troubleCount++;
      else {
         troubleCount = 0; // reset counter
         // tree shrinks too fast -> refine the step size
         if (nnodes < previousNnodes / 2 ) fDeltaPruneStrength /= 2.;
      }
      previousNnodes = nnodes;
      // the scan is stuck: the node count has not changed for 20 steps
      if (troubleCount > 20) {
         if (itree == 0 && fPruneStrength <= 0) {
            // first attempt on the first tree: enlarge the step size
            fDeltaPruneStrength *= 5;
            fLogger << kINFO << "<PruneTree> trouble determining optimal prune strength"
                    << " for Tree " << itree
                    << " --> first try to increase the step size"
                    << " currently PruneStrength= " << alpha
                    << " stepsize " << fDeltaPruneStrength << " " << Endl;
            troubleCount = 0;
            fPruneStrength = 1;
         }
         else if (itree == 0 && fPruneStrength <= 2) {
            // second attempt: enlarge the step size once more
            fDeltaPruneStrength *= 5;
            fLogger << kINFO << "<PruneTree> trouble determining optimal prune strength"
                    << " for Tree " << itree
                    << " --> try to increase the step size even more.. "
                    << " if that still didn't work, TRY IT BY HAND"
                    << " currently PruneStrength= " << alpha
                    << " stepsize " << fDeltaPruneStrength << " " << Endl;
            troubleCount = 0;
            fPruneStrength = 3;
         }
         else {
            // give up and reuse the strength found for the previous tree
            forceStop = kTRUE;
            fLogger << kINFO << "<PruneTree> trouble determining optimal prune strength"
                    << " for Tree " << itree << " at tested prune strength: " << alpha
                    << " --> abort forced, use same strength as for previous tree:"
                    << fPruneStrength << Endl;
         }
      }
      if (fgDebugLevel==1) fLogger << kINFO << "Pruned with (" << alpha
                                   << ") give quality: " << q.back()
                                   << " and #nodes: " << nnodes
                                   << Endl;
      delete dcopy;
   }

   if (!forceStop && !quality.empty()) {
      // BUG FIX: the original code incremented quality.rend(), which is
      // undefined behaviour. The intended element -- the scan point with the
      // highest validation quality -- is quality.rbegin().
      fPruneStrength = quality.rbegin()->second;
      // adjust the step size so the next tree is scanned in ~20 steps
      fDeltaPruneStrength *= Double_t(q.size())/20.;
   }

   // monitoring histogram: quality vs tested prune strength for this tree
   // BUG FIX: buffer enlarged (was char[10], overflowing for itree >= 100000)
   char buffer[32];
   sprintf (buffer,"quad%d",itree);
   TH1D *qual=new TH1D(buffer,"Quality of tree prune steps",q.size(),0.,alpha);
   qual->SetXTitle("PruneStrength");
   qual->SetYTitle("TreeQuality (Purity)");
   for (UInt_t i=0; i< q.size(); i++) {
      qual->SetBinContent(i+1,q[i]);
   }
   qual->Write();

   // finally prune the real tree with the chosen strength
   dt->SetPruneStrength(fPruneStrength);
   dt->PruneTree();

   return fPruneStrength;
}
//_______________________________________________________________________
// Measure the quality of a (pruned) tree as the weighted fraction of
// correctly classified events in the validation sample.
Double_t TMVA::MethodBDT::TestTreeQuality( DecisionTree *dt )
{
   Double_t ncorrect=0, nfalse=0;
   for (UInt_t ievt=0; ievt<fValidationSample.size(); ievt++) {
      // tree output > 0.5 means "signal-like"
      Bool_t isSignalType= (dt->CheckEvent(*(fValidationSample[ievt])) > 0.5 ) ? 1 : 0;
      if (isSignalType == (fValidationSample[ievt]->IsSignal()) ) {
         ncorrect += fValidationSample[ievt]->GetWeight();
      }
      else{
         nfalse += fValidationSample[ievt]->GetWeight();
      }
   }
   // BUG FIX: guard against division by zero for an empty (or zero-weight)
   // validation sample, which occurs when automatic pruning is off.
   const Double_t total = ncorrect + nfalse;
   return (total > 0) ? ncorrect/total : 0;
}
//_______________________________________________________________________
// Dispatch to the boosting algorithm selected via the "BoostType" option
// and return the weight assigned to the freshly trained tree.
Double_t TMVA::MethodBDT::Boost( vector<TMVA::Event*> eventSample, DecisionTree *dt, Int_t iTree )
{
   if (fBoostType=="AdaBoost") return this->AdaBoost(eventSample, dt);
   if (fBoostType=="Bagging")  return this->Bagging(eventSample, iTree);

   // unreachable for valid configurations (option values are pre-defined)
   fLogger << kINFO << GetOptions() << Endl;
   fLogger << kFATAL << "<Boost> unknown boost option called" << Endl;
   return -1;
}
//_______________________________________________________________________
// AdaBoost re-weighting: classify every event with the new tree, compute
// the weighted misclassification rate "err", multiply the weight of each
// misclassified event by boostWeight = ((1-err)/err)^beta, renormalise the
// sample to its original weight sum, and return log(boostWeight) as the
// weight of this tree in the forest vote.
Double_t TMVA::MethodBDT::AdaBoost( vector<TMVA::Event*> eventSample, DecisionTree *dt )
{
   Double_t adaBoostBeta=1.;   // exponent of the boost weight (1 == standard AdaBoost)

   // first pass: classify each event and accumulate the weight sums
   Double_t err=0, sumw=0, sumwfalse=0;
   vector<Bool_t> correctSelected;                // per-event classification result
   correctSelected.reserve(eventSample.size());
   for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
      Bool_t isSignalType = (dt->CheckEvent(*(*e),fUseYesNoLeaf) > 0.5 );
      Double_t w = (*e)->GetWeight();
      sumw += w;
      if (isSignalType == (*e)->IsSignal()) {
         correctSelected.push_back(kTRUE);
      }
      else {
         // CLEANUP: removed the "count" accumulator of the original -- it was
         // incremented here but never read anywhere.
         sumwfalse+= w;
         correctSelected.push_back(kFALSE);
      }
   }
   err = sumwfalse/sumw;   // weighted misclassification rate

   Double_t newSumw=0;
   Int_t i=0;
   Double_t boostWeight;
   if (err>0) {
      if (err > 0.5) {
         // a single tree should always do better than random guessing
         fLogger << kWARNING << " The error rate in the BDT boosting is > 0.5. "
                 << " That should not happen, please check your code (i.e... the BDT code) " << Endl;
      }
      if (adaBoostBeta == 1) {
         boostWeight = (1-err)/err;
      }
      else {
         boostWeight = TMath::Power((1.0 - err)/err, adaBoostBeta);
      }
   }
   else {
      boostWeight = 1000; // perfect tree: cap the boost weight
   }

   // second pass: up-weight the misclassified events
   for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
      if (!correctSelected[i]) {
         (*e)->SetWeight( (*e)->GetWeight() * boostWeight);
      }
      newSumw+=(*e)->GetWeight();
      i++;
   }
   // renormalise so the total weight of the sample is unchanged
   for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
      (*e)->SetWeight( (*e)->GetWeight() * sumw / newSumw );
   }

   // monitoring
   fBoostWeightHist->Fill(boostWeight);
   fBoostWeightVsTree->SetBinContent(fForest.size(),boostWeight);
   fErrFractHist->SetBinContent(fForest.size(),err);
   fBoostWeight = boostWeight;
   fErrorFraction = err;

   return TMath::Log(boostWeight);
}
//_______________________________________________________________________
// Bagging-style "boost": give every event a new uniform random weight
// (seeded with the tree index for reproducibility) and renormalise so the
// weight sum equals the number of events. All trees get equal weight (1).
Double_t TMVA::MethodBDT::Bagging( vector<TMVA::Event*> eventSample, Int_t iTree )
{
   Double_t newSumw=0;
   // BUG FIX: the TRandom instance used to be heap-allocated with "new" and
   // never deleted (one leak per tree); a stack instance behaves identically.
   TRandom trandom( iTree );
   for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
      Double_t newWeight = trandom.Rndm();
      (*e)->SetWeight(newWeight);
      newSumw += newWeight;
   }
   // renormalise: total weight == number of events
   for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
      (*e)->SetWeight( (*e)->GetWeight() * eventSample.size() / newSumw );
   }
   return 1.;
}
//_______________________________________________________________________
// Persist the forest to a weight stream: first the number of trees, then
// for each tree its boost weight followed by the full tree structure.
void TMVA::MethodBDT::WriteWeightsToStream( ostream& o) const
{
   const UInt_t nTrees = fForest.size();
   o << "NTrees= " << nTrees <<endl;
   for (UInt_t iTree = 0; iTree < nTrees; iTree++) {
      o << "Tree " << iTree << " boostWeight " << fBoostWeights[iTree] << endl;
      fForest[iTree]->Print(o);
   }
}
//_______________________________________________________________________
// Restore the forest from a weight stream written by WriteWeightsToStream:
// reads the tree count, then per tree its boost weight and structure. Any
// previously held forest is released first.
void TMVA::MethodBDT::ReadWeightsFromStream( istream& istr )
{
   TString var, dummy;
   istr >> dummy >> fNTrees;
   fLogger << kINFO << "Read " << fNTrees << " Decision trees" << Endl;

   // drop any forest read earlier
   for (UInt_t i=0;i<fForest.size();i++) delete fForest[i];
   fForest.clear();
   fBoostWeights.clear();

   Int_t iTree;
   Double_t boostWeight;
   for (int i=0;i<fNTrees;i++) {
      istr >> dummy >> iTree >> dummy >> boostWeight;
      if (iTree != i) {
         // BUG FIX: on a mismatch at the very first tree fForest is still
         // empty, so the unconditional fForest.back() of the original was
         // undefined behaviour.
         if (!fForest.empty()) fForest.back()->Print( cout );
         fLogger << kFATAL << "Error while reading weight file; mismatch Itree="
                 << iTree << " i=" << i
                 << " dummy " << dummy
                 << " boostweight " << boostWeight
                 << Endl;
      }
      fForest.push_back( new DecisionTree() );
      fForest.back()->Read(istr);
      fBoostWeights.push_back(boostWeight);
   }
}
//_______________________________________________________________________
// Return the MVA response for the current event: the (optionally
// boost-weight weighted) average of the individual tree responses.
Double_t TMVA::MethodBDT::GetMvaValue()
{
   Double_t myMVA = 0;
   Double_t norm = 0;
   for (UInt_t itree=0; itree<fForest.size(); itree++) {
      if (fUseWeightedTrees) {
         // weighted majority vote
         myMVA += fBoostWeights[itree] * fForest[itree]->CheckEvent(GetEvent(),fUseYesNoLeaf);
         norm += fBoostWeights[itree];
      }
      else {
         // simple average
         myMVA += fForest[itree]->CheckEvent(GetEvent(),fUseYesNoLeaf);
         norm += 1;
      }
   }
   // BUG FIX: an empty forest (or all-zero boost weights) used to divide by
   // zero; return a neutral response instead.
   return (norm > 0) ? myMVA/norm : 0;
}
//_______________________________________________________________________
// Write all training-monitoring histograms and the per-tree ntuple into
// the method's output directory (booked in the training constructor).
void TMVA::MethodBDT::WriteMonitoringHistosToFile( void ) const
{
fLogger << kINFO << "Write monitoring histograms to file: " << BaseDir()->GetPath() << Endl;
fBoostWeightHist->Write();
fBoostWeightVsTree->Write();
fErrFractHist->Write();
fNodesBeforePruningVsTree->Write();
fNodesAfterPruningVsTree->Write();
fMonitorNtuple->Write();
}
//_______________________________________________________________________
// Compute the relative importance of each input variable by summing the
// per-tree importance over the whole forest and normalising to unit sum.
vector< Double_t > TMVA::MethodBDT::GetVariableImportance()
{
   // BUG FIX: resize() keeps existing contents, so a second call used to
   // accumulate on top of the previously normalised result; zero-fill instead.
   fVariableImportance.assign( GetNvar(), 0.0 );
   Double_t sum=0;
   for (int itree = 0; itree < fNTrees; itree++) {
      vector<Double_t> relativeImportance(fForest[itree]->GetVariableImportance());
      for (UInt_t i=0; i< relativeImportance.size(); i++) {
         fVariableImportance[i] += relativeImportance[i];
      }
   }
   for (UInt_t i=0; i< fVariableImportance.size(); i++) sum += fVariableImportance[i];
   // robustness: avoid 0/0 when no tree reports any importance
   if (sum > 0) {
      for (UInt_t i=0; i< fVariableImportance.size(); i++) fVariableImportance[i] /= sum;
   }
   return fVariableImportance;
}
//_______________________________________________________________________
// Relative importance of a single input variable; fatal for an
// out-of-range index.
Double_t TMVA::MethodBDT::GetVariableImportance( UInt_t ivar )
{
   vector<Double_t> importance = this->GetVariableImportance();
   if (ivar >= (UInt_t)importance.size()) {
      fLogger << kFATAL << "<GetVariableImportance> ivar = " << ivar << " is out of range " << Endl;
      return -1;
   }
   return importance[ivar];
}
//_______________________________________________________________________
// Build and return the variable-importance ranking of this method.
const TMVA::Ranking* TMVA::MethodBDT::CreateRanking()
{
fRanking = new Ranking( GetName(), "Variable Importance" );
vector< Double_t> importance(this->GetVariableImportance());
for (Int_t ivar=0; ivar<GetNvar(); ivar++) {
// NOTE(review): "*new Rank(...)" leaks unless AddRank copies and the
// Ranking owns its entries -- confirm against the Ranking implementation.
fRanking->AddRank( *new Rank( GetInputExp(ivar), importance[ivar] ) );
}
return fRanking;
}
//_______________________________________________________________________
// Print a description of the BDT method and tuning advice to the logger.
// FIX: corrected several typos in the user-facing help text
// ("descriminant", "overtain", "minumal", "paramter", "regularistion",
// unbalanced parenthesis in the nEventsMin formula).
void TMVA::MethodBDT::GetHelpMessage() const
{
   fLogger << Endl;
   fLogger << Tools::Color("bold") << "--- Short description:" << Tools::Color("reset") << Endl;
   fLogger << Endl;
   fLogger << "Boosted Decision Trees are a collection of individual decision" << Endl;
   fLogger << "trees which form a multivariate classifier by (weighted) majority " << Endl;
   fLogger << "vote of the individual trees. Consecutive decision trees are " << Endl;
   fLogger << "trained using the original training data set with re-weighted " << Endl;
   fLogger << "events. By default, the AdaBoost method is employed, which gives " << Endl;
   fLogger << "events that were misclassified in the previous tree a larger " << Endl;
   fLogger << "weight in the training of the following tree." << Endl;
   fLogger << Endl;
   fLogger << "Decision trees are a sequence of binary splits of the data sample" << Endl;
   fLogger << "using a single discriminant variable at a time. A test event " << Endl;
   fLogger << "ending up after the sequence of left-right splits in a final " << Endl;
   fLogger << "(\"leaf\") node is classified as either signal or background" << Endl;
   fLogger << "depending on the majority type of training events in that node." << Endl;
   fLogger << Endl;
   fLogger << Tools::Color("bold") << "--- Performance optimisation:" << Tools::Color("reset") << Endl;
   fLogger << Endl;
   fLogger << "By the nature of the binary splits performed on the individual" << Endl;
   fLogger << "variables, decision trees do not deal well with linear correlations" << Endl;
   fLogger << "between variables (they need to approximate the linear split in" << Endl;
   fLogger << "the two dimensional space by a sequence of splits on the two " << Endl;
   fLogger << "variables individually). Hence decorrelation could be useful " << Endl;
   fLogger << "to optimise the BDT performance." << Endl;
   fLogger << Endl;
   fLogger << Tools::Color("bold") << "--- Performance tuning via configuration options:" << Tools::Color("reset") << Endl;
   fLogger << Endl;
   fLogger << "The two most important parameters in the configuration are the " << Endl;
   fLogger << "minimal number of events requested by a leaf node (option " << Endl;
   fLogger << "\"nEventsMin\"). If this number is too large, detailed features " << Endl;
   fLogger << "in the parameter space cannot be modeled. If it is too small, " << Endl;
   fLogger << "the risk to overtrain rises." << Endl;
   fLogger << " (Imagine the decision tree is split until the leaf node contains" << Endl;
   fLogger << " only a single event. In such a case, no training event is " << Endl;
   fLogger << " misclassified, while the situation will look very different" << Endl;
   fLogger << " for the test sample.)" << Endl;
   fLogger << Endl;
   fLogger << "The default minimal number is currently set to " << Endl;
   fLogger << " max(20, N_training_events / N_variables^2 / 10) " << Endl;
   fLogger << "and can be changed by the user." << Endl;
   fLogger << Endl;
   fLogger << "The other crucial parameter, the pruning strength (\"PruneStrength\")," << Endl;
   fLogger << "is also related to overtraining. It is a regularisation parameter " << Endl;
   fLogger << "that is used when determining after the training which splits " << Endl;
   fLogger << "are considered statistically insignificant and are removed. The" << Endl;
   fLogger << "user is advised to carefully watch the BDT screen output for" << Endl;
   fLogger << "the comparison between efficiencies obtained on the training and" << Endl;
   fLogger << "the independent test sample. They should be equal within statistical" << Endl;
   fLogger << "errors." << Endl;
}
//_______________________________________________________________________
// Write the method-specific part of the standalone C++ response class:
// the forest data members, the GetMvaValue__() evaluation function, the
// Initialize() method that re-creates every tree of the forest via
// MakeClassInstantiateNode(), and the Clear() cleanup method.
void TMVA::MethodBDT::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
// data members of the generated class (closes the class declaration below)
fout << " std::vector<"<<className<<"_DecisionTreeNode*> fForest; // i.e. root nodes of decision trees" << endl;
fout << " std::vector<double> fBoostWeights; // the weights applied in the individual boosts" << endl;
fout << "};" << endl;
// generated evaluation: walk each tree from the root down to a leaf and
// average the leaf responses (mirrors GetMvaValue() above)
fout << "double " << className << "::GetMvaValue__( const std::vector<double>& inputValues ) const{" << endl;
fout << " double myMVA = 0;" << endl;
fout << " double norm = 0;" << endl;
fout << "for (unsigned int itree=0; itree<fForest.size(); itree++){" << endl;
fout << " "<<className<<"_DecisionTreeNode *current = fForest[itree];" << endl;
fout << " while (current->GetNodeType() == 0){ //intermediate node" << endl;
fout << " if (current->GoesRight(inputValues)) current=("<<className<<"_DecisionTreeNode*)current->GetRight();" << endl;
fout << " else current=("<<className<<"_DecisionTreeNode*)current->GetLeft();" << endl;
fout << " }" << endl;
// the generated code is specialised at generation time to the options
// (weighted trees and yes/no leaves) used for this training
if (fUseWeightedTrees) {
if (fUseYesNoLeaf) fout << " myMVA += fBoostWeights[itree] * current->GetNodeType();" << endl;
else fout << " myMVA += fBoostWeights[itree] * current->GetPurity();" << endl;
fout << " norm += fBoostWeights[itree];" << endl;
}
else {
if (fUseYesNoLeaf) fout << " myMVA += current->GetNodeType();" << endl;
else fout << " myMVA += current->GetPurity();" << endl;
fout << " norm += 1.;" << endl;
}
fout << " }" << endl;
fout << " return myMVA /= norm;" << endl;
fout << "};" << endl;
// generated Initialize(): re-create every tree with its boost weight
fout << "void " << className << "::Initialize(){" << endl;
fout << " " << endl;
for (int itree=0; itree<fNTrees; itree++) {
fout << " // itree = " << itree << endl;
fout << " fBoostWeights.push_back("<<fBoostWeights[itree]<<");" << endl;
fout << " fForest.push_back( " << endl;
this->MakeClassInstantiateNode((DecisionTreeNode*)fForest[itree]->GetRoot(), fout, className);
fout <<" );" << endl;
}
fout << " return;" << endl;
fout << "};" << endl;
fout << " " << endl;
// generated Clear(): delete the root nodes (they recursively delete daughters)
fout << "// Clean up" << endl;
fout << "inline void " << className << "::Clear() " << endl;
fout << "{" << endl;
fout << " for (unsigned int itree=0; itree<fForest.size(); itree++) { " << endl;
fout << " delete fForest[itree]; " << endl;
fout << " }" << endl;
fout << "}" << endl;
}
//_______________________________________________________________________
// Write the header part of the standalone response class: a minimal
// DecisionTreeNode class (binary tree node with cut, type and purity) and
// the NN macro used by MakeClassInstantiateNode() when emitting the trees.
void TMVA::MethodBDT::MakeClassSpecificHeader( std::ostream& fout, const TString& className ) const
{
// shorthand used by the generated Initialize() when building the trees
fout << "#define NN new "<<className<<"_DecisionTreeNode" << endl;
fout << "class "<<className<<"_DecisionTreeNode{" << endl;
fout << " " << endl;
fout << "public:" << endl;
fout << " " << endl;
fout << " // constructor of an essentially \"empty\" node floating in space" << endl;
fout << " "<<className<<"_DecisionTreeNode ( " << endl;
fout << " "<<className<<"_DecisionTreeNode* left," << endl;
fout << " "<<className<<"_DecisionTreeNode* right," << endl;
fout << " double cutValue, bool cutType, int selector," << endl;
fout << " int nodeType, double purity):" << endl;
fout << " fLeft(left)," << endl;
fout << " fRight(right)," << endl;
fout << " fCutValue(cutValue)," << endl;
fout << " fCutType(cutType)," << endl;
fout << " fSelector(selector)," << endl;
fout << " fNodeType(nodeType)," << endl;
fout << " fPurity(purity) {}" << endl;
fout << " virtual ~"<<className<<"_DecisionTreeNode(); " << endl;
fout << " // test event if it decends the tree at this node to the right" << endl;
fout << " virtual bool GoesRight( const std::vector<double>& inputValues ) const;" << endl;
fout << " "<<className<<"_DecisionTreeNode* GetRight( void ) {return fRight; };" << endl;
fout << " // test event if it decends the tree at this node to the left " << endl;
fout << " virtual bool GoesLeft ( const std::vector<double>& inputValues ) const;" << endl;
fout << " "<<className<<"_DecisionTreeNode* GetLeft( void ) {return fLeft; }; " << endl;
fout << " //return S/(S+B) (purity) at this node (from training)" << endl;
fout << " double GetPurity( void ) const {return fPurity;} " << endl;
fout << " //return the node type" << endl;
fout << " int GetNodeType( void ) const {return fNodeType;}" << endl;
fout << "private:" << endl;
fout << " "<<className<<"_DecisionTreeNode* fLeft; // pointer to the left daughter node" << endl;
fout << " "<<className<<"_DecisionTreeNode* fRight; // pointer to the right daughter node" << endl;
fout << " double fCutValue;// cut value appplied on this node to discriminate bkg against sig" << endl;
fout << " bool fCutType; // true: if event variable > cutValue ==> signal , false otherwise" << endl;
fout << " int fSelector;// index of variable used in node selection (decision tree) " << endl;
fout << " int fNodeType;// Type of node: -1 == Bkg-leaf, 1 == Signal-leaf, 0 = internal " << endl;
fout << " double fPurity; // Purity of node from training"<< endl;
fout << "}; " << endl;
// generated destructor: recursively deletes the daughter nodes
fout << "//_______________________________________________________________________" << endl;
fout << " "<<className<<"_DecisionTreeNode::~"<<className<<"_DecisionTreeNode(){ " << endl;
fout << " if (fLeft != NULL) delete fLeft;" << endl;
fout << " if (fRight != NULL) delete fRight;" << endl;
fout << "}; " << endl;
// generated GoesRight/GoesLeft: apply the node's cut to the input vector
fout << "//_______________________________________________________________________" << endl;
fout << "bool "<<className<<"_DecisionTreeNode::GoesRight(const std::vector<double>& inputValues) const{" << endl;
fout << " // test event if it decends the tree at this node to the right" << endl;
fout << " bool result = (inputValues[fSelector] > fCutValue );" << endl;
fout << " if (fCutType == true) return result; //the cuts are selecting Signal ;" << endl;
fout << " else return !result;" << endl;
fout << "}" << endl;
fout << "//_______________________________________________________________________" << endl;
fout << "bool "<<className<<"_DecisionTreeNode::GoesLeft(const std::vector<double>& inputValues) const{" << endl;
fout << " // test event if it decends the tree at this node to the left" << endl;
fout << " if (!this->GoesRight(inputValues)) return true;" << endl;
fout << " else return false;" << endl;
fout << "}" << endl;
}
//_______________________________________________________________________
// Recursively emit source code that re-creates the (sub)tree rooted at
// node n inside the standalone class, using the NN constructor macro
// defined in the generated header. Null daughters are emitted as "0".
void TMVA::MethodBDT::MakeClassInstantiateNode( DecisionTreeNode *n, std::ostream& fout, const TString& className ) const
{
   if (n == NULL) {
      fLogger << kFATAL << "MakeClassInstantiateNode: started with undefined node" <<Endl;
      return ;
   }
   fout << "NN("<<endl;
   // left daughter (or null pointer)
   if (n->GetLeft() == NULL) fout << "0";
   else this->MakeClassInstantiateNode( (DecisionTreeNode*)n->GetLeft() , fout, className);
   fout << ", " <<endl;
   // right daughter (or null pointer)
   if (n->GetRight() == NULL) fout << "0";
   else this->MakeClassInstantiateNode( (DecisionTreeNode*)n->GetRight(), fout, className );
   // this node's cut and classification data
   fout << ", " << endl
        << setprecision(6)
        << n->GetCutValue() << ", "
        << n->GetCutType() << ", "
        << n->GetSelector() << ", "
        << n->GetNodeType() << ", "
        << n->GetPurity() << ") ";
}
// Last update: Thu Jan 17 08:58:58 2008
// (The two lines above/below were residue from an auto-generated HTML
// documentation page; kept as a comment so the translation unit compiles.)
// This page has been automatically generated. If you have any comments or
// suggestions about the page layout send a mail to ROOT support, or contact
// the developers with any questions or problems regarding ROOT.