#include "TMVA/MethodBDT.h"
#include "TMVA/Tools.h"
#include "TMVA/Timer.h"
#include "Riostream.h"
#include "TRandom.h"
#include <algorithm>
#include "TObjString.h"
#include "TMVA/Ranking.h"
using std::vector;
ClassImp(TMVA::MethodBDT)
;
//_______________________________________________________________________
// Standard constructor for training: sets up defaults, parses the option
// string, prepares the internal training/validation event samples and books
// the monitoring histograms and ntuple that follow the boosting procedure.
TMVA::MethodBDT::MethodBDT( TString jobName, TString methodTitle, DataSet& theData,
                            TString theOption, TDirectory* theTargetDir )
   : TMVA::MethodBase( jobName, methodTitle, theData, theOption, theTargetDir )
{
   // order matters: defaults must exist (InitBDT) before the options are
   // declared, parsed and processed
   InitBDT();
   DeclareOptions();
   ParseOptions();
   ProcessOptions();
   if (HasTrainingTree()) {
      fLogger << kVERBOSE << "method has been called " << Endl;
      // split the training tree into internal training/validation samples
      this->InitEventSample();
   }
   else {
      fLogger << kWARNING << "no training Tree given: you will not be allowed to call ::Train etc." << Endl;
   }
   // monitoring histograms: boost-weight distribution, its evolution with the
   // tree index, the error fraction per tree and the tree sizes before/after
   // pruning (one bin per tree)
   fBoostWeightHist = new TH1F("fBoostWeight","Ada Boost weights",100,1,100);
   fBoostWeightVsTree = new TH1F("fBoostWeightVsTree","Ada Boost weights",fNTrees,0,fNTrees);
   fErrFractHist = new TH1F("fErrFractHist","error fraction vs tree number",fNTrees,0,fNTrees);
   fNodesBeforePruningVsTree = new TH1I("fNodesBeforePruning","nodes before pruning",fNTrees,0,fNTrees);
   fNodesAfterPruningVsTree = new TH1I("fNodesAfterPruning","nodes after pruning",fNTrees,0,fNTrees);
   // per-tree monitoring ntuple: tree index, boost weight and error fraction
   // (branches point at the data members filled during Train()/AdaBoost())
   fMonitorNtuple= new TTree("fMonitorNtuple","BDT variables");
   fMonitorNtuple->Branch("iTree",&fITree,"iTree/I");
   fMonitorNtuple->Branch("boostWeight",&fBoostWeight,"boostWeight/D");
   fMonitorNtuple->Branch("errorFraction",&fErrorFraction,"errorFraction/D");
}
//_______________________________________________________________________
// Constructor used when instantiating the method from a weight file: only
// defaults and option declarations are set up here; the forest itself is
// restored later via ReadWeightsFromStream().
TMVA::MethodBDT::MethodBDT( DataSet& theData,
                            TString theWeightFile,
                            TDirectory* theTargetDir )
   : TMVA::MethodBase( theData, theWeightFile, theTargetDir )
{
   InitBDT();
   DeclareOptions();
}
//_______________________________________________________________________
// Declare all user-configurable BDT options. Note that each AddPreDefVal()
// call registers an allowed value for the option declared immediately
// before it, so the grouping below is significant.
void TMVA::MethodBDT::DeclareOptions()
{
   DeclareOptionRef(fNTrees, "NTrees", "number of trees in the forest");
   // boosting algorithm (see Boost())
   DeclareOptionRef(fBoostType, "BoostType", "boosting type for the trees in the forest");
   AddPreDefVal(TString("AdaBoost"));
   AddPreDefVal(TString("Bagging"));
   DeclareOptionRef(fUseYesNoLeaf=kTRUE, "UseYesNoLeaf", "use Sig or Bkg node type or the ratio S/B as classification in the leaf node");
   DeclareOptionRef(fUseWeightedTrees=kTRUE, "UseWeightedTrees", "use weighted trees or simple average in classification from the forest");
   // node-splitting separation criterion (mapped to an object in ProcessOptions())
   DeclareOptionRef(fSepTypeS="GiniIndex", "SeparationType", "separation criterion for node splitting");
   AddPreDefVal(TString("MisClassificationError"));
   AddPreDefVal(TString("GiniIndex"));
   AddPreDefVal(TString("CrossEntropy"));
   AddPreDefVal(TString("SDivSqrtSPlusB"));
   DeclareOptionRef(fNodeMinEvents, "nEventsMin", "minimum number of events in a leaf node");
   DeclareOptionRef(fNCuts, "nCuts", "number of steps during node cut optimisation");
   // negative strength triggers the automatic per-tree optimisation (PruneTree())
   DeclareOptionRef(fPruneStrength, "PruneStrength", "a parameter to adjust the amount of pruning. Should be large enouth such that overtraining is avoided, or negative == automatic (takes time)");
   DeclareOptionRef(fPruneMethodS, "PruneMethod", "Pruning method: Expected Error or Cost Complexity");
   AddPreDefVal(TString("ExpectedError"));
   AddPreDefVal(TString("CostComplexity"));
   AddPreDefVal(TString("CostComplexity2"));
}
//_______________________________________________________________________
// Translate the user option strings into the concrete separation-index
// object and pruning-method enum used during training.
void TMVA::MethodBDT::ProcessOptions()
{
   MethodBase::ProcessOptions();

   // separation criterion: option matching is case-insensitive
   fSepTypeS.ToLower();
   if      (fSepTypeS == "misclassificationerror") fSepType = new TMVA::MisClassificationError();
   else if (fSepTypeS == "giniindex")              fSepType = new TMVA::GiniIndex();
   else if (fSepTypeS == "crossentropy")           fSepType = new TMVA::CrossEntropy();
   else if (fSepTypeS == "sdivsqrtsplusb")         fSepType = new TMVA::SdivSqrtSplusB();
   else {
      // unknown criterion: show the option string, then abort via kFATAL
      fLogger << kINFO << GetOptions() << Endl;
      fLogger << kFATAL << "<ProcessOptions> unknown Separation Index option called" << Endl;
   }

   // pruning method
   fPruneMethodS.ToLower();
   if      (fPruneMethodS == "expectederror")   fPruneMethod = TMVA::DecisionTree::kExpectedErrorPruning;
   else if (fPruneMethodS == "costcomplexity")  fPruneMethod = TMVA::DecisionTree::kCostComplexityPruning;
   else if (fPruneMethodS == "costcomplexity2") fPruneMethod = TMVA::DecisionTree::kMCC;
   else {
      fLogger << kINFO << GetOptions() << Endl;
      fLogger << kFATAL << "<ProcessOptions> unknown PruneMethod option called" << Endl;
   }

   // a negative prune strength requests the automatic per-tree optimisation
   fAutomatic = (fPruneStrength < 0);
}
//_______________________________________________________________________
// Set method identification and the built-in defaults for every BDT option.
void TMVA::MethodBDT::InitBDT( void )
{
   SetMethodName( "BDT" );
   SetMethodType( TMVA::Types::kBDT );
   SetTestvarName();

   // forest / boosting defaults
   fNTrees    = 200;
   fBoostType = "AdaBoost";

   // node-splitting defaults
   fNodeMinEvents = 10;
   fNCuts         = 20;

   // pruning defaults
   fPruneMethod        = TMVA::DecisionTree::kMCC;
   fPruneStrength      = 5;
   fDeltaPruneStrength = 0.1;
}
//_______________________________________________________________________
// Destructor: release the event samples and the trained forest; every
// entry in these containers was allocated with new by this class.
TMVA::MethodBDT::~MethodBDT( void )
{
   vector<TMVA::Event*>::iterator ev;
   for (ev = fEventSample.begin();      ev != fEventSample.end();      ++ev) delete *ev;
   for (ev = fValidationSample.begin(); ev != fValidationSample.end(); ++ev) delete *ev;
   for (vector<TMVA::DecisionTree*>::iterator tr = fForest.begin(); tr != fForest.end(); ++tr) delete *tr;
}
void TMVA::MethodBDT::InitEventSample( void )
{
if (!HasTrainingTree()) fLogger << kFATAL << "<Init> Data().TrainingTree() is zero pointer" << Endl;
Int_t nevents = Data().GetNEvtTrain();
Int_t ievt=0;
for (; ievt<nevents; ievt++){
ReadTrainingEvent(ievt);
if (ievt%2 == 0 || !fAutomatic ) {
fEventSample.push_back(new TMVA::Event(Data().Event()));
}else{
fValidationSample.push_back(new TMVA::Event(Data().Event()));
}
}
fLogger << kINFO << "<InitEventSample> : internally I use " << fEventSample.size()
<< " for Training and " << fValidationSample.size()
<< " for Validation " << Endl;
}
//_______________________________________________________________________
// Train the forest: grow fNTrees decision trees on the (re-weighted)
// event sample, boosting the sample after each tree, then prune every
// tree in a second pass. Monitoring histograms/ntuple are filled along
// the way.
void TMVA::MethodBDT::Train( void )
{
   if (!CheckSanity()) fLogger << kFATAL << "<Train> sanity check failed" << Endl;
   fLogger << kINFO << "will train "<< fNTrees << " Decision Trees ... patience please" << Endl;
   TMVA::Timer timer( fNTrees, GetName() );
   Int_t nNodesBeforePruningCount = 0;
   Int_t nNodesAfterPruningCount = 0;
   Int_t nNodesBeforePruning = 0;
   Int_t nNodesAfterPruning = 0;
   // the Gini index is used to rate tree quality, independently of the
   // separation criterion chosen for node splitting
   TMVA::SeparationBase *qualitySepType = new TMVA::GiniIndex();
   // ---- pass 1: grow and boost the trees ---------------------------------
   for (int itree=0; itree<fNTrees; itree++){
      timer.DrawProgressBar( itree );
      fForest.push_back(new TMVA::DecisionTree(fSepType,
                                               fNodeMinEvents,fNCuts, qualitySepType));
      nNodesBeforePruning = fForest.back()->BuildTree(fEventSample);
      // debug-only (fgDebugLevel==1, second tree only): prune a COPY of the
      // tree node-by-node via the quality-gain map, dumping each step to
      // text files and recording the cost complexity in a histogram
      if (itree==1 && fgDebugLevel==1){
         DecisionTree *d = new DecisionTree(*(fForest[itree]));
         TH1D *h=new TH1D("h","CostComplexity",d->GetNNodes(),0,d->GetNNodes());
         ofstream out1("theOriginal.txt");
         ofstream out2("theCopy.txt");
         fForest[itree]->Print(out1);
         out2 << "************* pruned T " << 1 << " ****************" <<endl;
         d->Print(out2);
         Int_t count=1;
         h->SetBinContent(count++,d->GetCostComplexity(fPruneStrength));
         // keep pruning until only the root and its two daughters remain
         while (d->GetNNodes() > 3) {
            d->FillQualityMap();
            d->FillQualityGainMap();
            // prune the node with the smallest quality gain (first map entry)
            multimap<Double_t, TMVA::DecisionTreeNode* > qgm = d->GetQualityGainMap();
            multimap<Double_t, TMVA::DecisionTreeNode* >::iterator it=qgm.begin();
            d->PruneNode(it->second);
            out2 << "************* pruned T " << count << " ****************" <<endl;
            d->Print(out2);
            h->SetBinContent(count++,d->GetCostComplexity(fPruneStrength));
         }
         h->Write();
      }
      // debug-only (fgDebugLevel==1, second tree only): same study but
      // pruning along the "weakest link", recording the link strengths
      if (itree==1 && fgDebugLevel==1){
         DecisionTree *d = new DecisionTree(*(fForest[itree]));
         TH1D *h=new TH1D("h2","Weakestlink",d->GetNNodes(),0,d->GetNNodes());
         ofstream out2("theCopy2.txt");
         out2 << "************* pruned T " << 1 << " ****************" <<endl;
         d->Print(out2);
         Int_t count=1;
         while (d->GetNNodes() > 3) {
            DecisionTreeNode *n = d->GetWeakestLink();
            multimap<Double_t, TMVA::DecisionTreeNode* > ls = d->GetLinkStrengthMap();
            multimap<Double_t, TMVA::DecisionTreeNode* >::iterator it=ls.begin();
            cout << "nodes before " << d->CountNodes() << endl;
            h->SetBinContent(count++,it->first);
            cout << " Prune Node seq: " << n->GetSequence() << " depth=" << n->GetDepth() <<endl;
            d->PruneNode(n);
            cout << "nodes after " << d->CountNodes() << endl;
            for (it=ls.begin();it!=ls.end();it++) cout << it->first << " / " ;
            cout << endl;
            out2 << "************* pruned T " << count << " ****************" <<endl;
            d->Print(out2);
         }
         h->Write();
      }
      nNodesBeforePruningCount +=nNodesBeforePruning;
      fNodesBeforePruningVsTree->SetBinContent(itree+1,nNodesBeforePruning);
      // re-weight the events according to this tree's performance and store
      // the resulting tree weight
      fBoostWeights.push_back( this->Boost(fEventSample, fForest.back(), itree) );
      fITree = itree;
      fMonitorNtuple->Fill();
   }
   fLogger << kINFO << "<Train> elapsed time: " << timer.GetElapsedTime()
           << " " << Endl;
   // ---- pass 2: prune every tree -----------------------------------------
   fLogger << kINFO << "will prune "<< fNTrees << " Decision Trees ... patience please" << Endl;
   TMVA::Timer timer2( fNTrees, GetName() );
   // record the prune strength actually used for each tree
   TH1D *alpha = new TH1D("alpha","PruneStrengths",fNTrees,0,fNTrees);
   alpha->SetXTitle("#tree");
   alpha->SetYTitle("PruneStrength");
   for (int itree=0; itree<fNTrees; itree++){
      timer2.DrawProgressBar( itree );
      fForest[itree]->SetPruneMethod(fPruneMethod);
      if (fAutomatic){
         // optimise the prune strength per tree (also updates fPruneStrength)
         fPruneStrength = this->PruneTree(fForest[itree], itree);
      }else{
         fForest[itree]->SetPruneStrength(fPruneStrength);
         fForest[itree]->PruneTree();
      }
      nNodesAfterPruning = fForest[itree]->GetNNodes();
      nNodesAfterPruningCount += nNodesAfterPruning;
      fNodesAfterPruningVsTree->SetBinContent(itree+1,nNodesAfterPruning);
      alpha->SetBinContent(itree+1,fPruneStrength);
   }
   alpha->Write();
   fLogger << kINFO << "<Train> average number of nodes before/after pruning : "
           << nNodesBeforePruningCount/fNTrees << " / "
           << nNodesAfterPruningCount/fNTrees
           << Endl;
   fLogger << kINFO << "<Train_Prune> elapsed time: " << timer2.GetElapsedTime()
           << " " << Endl;
}
//_______________________________________________________________________
// Determine the optimal prune strength for the given tree: scan increasing
// strengths, pruning a COPY of the tree at each step and rating the result
// on the validation sample (TestTreeQuality). The best strength found is
// applied to the tree itself and returned.
Double_t TMVA::MethodBDT::PruneTree( TMVA::DecisionTree *dt, Int_t itree)
{
   Double_t alpha = 0;
   Double_t delta = fDeltaPruneStrength;
   TMVA::DecisionTree* dcopy;
   vector<Double_t> q;                  // quality at each scanned strength
   multimap<Double_t,Double_t> quality; // quality -> strength, sorted ascending in quality
   Int_t nnodes=dt->GetNNodes();
   Bool_t forceStop = kFALSE;
   Int_t troubleCount=0, previousNnodes=nnodes;
   while (nnodes > 3 && !forceStop) {
      // prune a fresh copy with the next (larger) strength and rate it
      dcopy = new DecisionTree(*dt);
      dcopy->SetPruneStrength(alpha+=delta);
      dcopy->PruneTree();
      q.push_back(this->TestTreeQuality(dcopy));
      quality.insert(pair<const Double_t,Double_t>(q.back(),alpha));
      nnodes=dcopy->GetNNodes();
      // adapt the step size: no change in tree size counts as "trouble";
      // a jump to less than half the nodes means the step is too coarse
      if (previousNnodes == nnodes) troubleCount++;
      else {
         troubleCount=0;
         if (nnodes < previousNnodes / 2 ) fDeltaPruneStrength /= 2.;
      }
      previousNnodes = nnodes;
      if (troubleCount > 20) {
         // escalation ladder: enlarge the step twice for the first tree,
         // then give up and keep the strength found for the previous tree
         if (itree == 0 && fPruneStrength <=0){
            fDeltaPruneStrength *= 5;
            fLogger << kINFO << "<PruneTree> trouble determining optimal prune strength"
                    << " for Tree " << itree
                    << " --> first try to increase the step size"
                    << " currently Prunestrenght= " << alpha
                    << " stepsize " << fDeltaPruneStrength << " " << Endl;
            troubleCount = 0;
            fPruneStrength = 1;
         } else if (itree == 0 && fPruneStrength <=2){
            fDeltaPruneStrength *= 5;
            fLogger << kINFO << "<PruneTree> trouble determining optimal prune strength"
                    << " for Tree " << itree
                    << " --> try to increase the step size even more.. "
                    << " if that stitill didn't work, TRY IT BY HAND"
                    << " currently Prunestrenght= " << alpha
                    << " stepsize " << fDeltaPruneStrength << " " << Endl;
            troubleCount = 0;
            fPruneStrength = 3;
         }else{
            forceStop=kTRUE;
            fLogger << kINFO << "<PruneTree> trouble determining optimal prune strength"
                    << " for Tree " << itree << " at tested prune strength: " << alpha
                    << " --> abort forced, use same strength as for previous tree:"
                    << fPruneStrength << Endl;
         }
      }
      if (fgDebugLevel==1) cout << "bla: Pruneed with ("<<alpha
                                << ") give quality: " << q.back()
                                << " and #nodes: " << nnodes
                                << endl;
      delete dcopy;
   }
   if (!forceStop && !quality.empty()) {
      // pick the strength that produced the BEST (largest) quality: the
      // multimap is sorted ascending in quality, so that is the last entry.
      // (the original code incremented quality.rend(), which is undefined
      // behaviour; rbegin() is the correct "largest key" iterator)
      multimap<Double_t,Double_t>::reverse_iterator it=quality.rbegin();
      fPruneStrength = it->second;
      // scale the step size with the number of steps needed, aiming at ~20
      fDeltaPruneStrength *= Double_t(q.size())/20.;
   }
   // monitoring histogram: tree quality vs. scanned prune strength
   char buffer[32]; // large enough for "quad" + any Int_t (old 10-byte buffer could overflow)
   sprintf (buffer,"quad%d",itree);
   TH1D *qual=new TH1D(buffer,"Quality of tree prune steps",q.size(),0.,alpha);
   qual->SetXTitle("PruneStrength");
   qual->SetYTitle("TreeQuality (Purity)");
   for (UInt_t i=0; i< q.size(); i++){
      qual->SetBinContent(i+1,q[i]);
   }
   qual->Write();
   // finally prune the real tree with the strength found
   dt->SetPruneStrength(fPruneStrength);
   dt->PruneTree();
   return fPruneStrength;
}
//_______________________________________________________________________
// Rate a (pruned) tree on the validation sample: returns the weighted
// fraction of validation events classified correctly (a tree response
// > 0.5 counts as "signal").
Double_t TMVA::MethodBDT::TestTreeQuality( TMVA::DecisionTree *dt )
{
   Double_t ncorrect=0, nfalse=0;
   for (UInt_t ievt=0; ievt<fValidationSample.size(); ievt++) {
      Bool_t isSignalType= (dt->CheckEvent(*(fValidationSample[ievt])) > 0.5 ) ? 1 : 0;
      if (isSignalType == (fValidationSample[ievt]->IsSignal()) ) {
         ncorrect += fValidationSample[ievt]->GetWeight();
      }else{
         nfalse += fValidationSample[ievt]->GetWeight();
      }
   }
   // guard against an empty (or zero-weight) validation sample, which would
   // otherwise produce a division by zero
   const Double_t total = ncorrect + nfalse;
   return (total > 0) ? ncorrect / total : 0;
}
//_______________________________________________________________________
// Dispatch to the boosting algorithm selected via the BoostType option and
// return the weight this tree gets in the forest.
Double_t TMVA::MethodBDT::Boost( vector<TMVA::Event*> eventSample, TMVA::DecisionTree *dt, Int_t iTree )
{
   if (fBoostType=="AdaBoost") return this->AdaBoost(eventSample, dt);
   if (fBoostType=="Bagging")  return this->Bagging(eventSample, iTree);

   // unknown boost option: print the option string, then abort via kFATAL
   fLogger << kINFO << GetOptions() << Endl;
   fLogger << kFATAL << "<Boost> unknown boost option called" << Endl;
   return -1;
}
//_______________________________________________________________________
// AdaBoost: re-weight the events according to how the given tree classifies
// them. Misclassified events are multiplied by
// boostWeight = ((1-err)/err)^beta, all weights are then renormalised so
// the total sample weight is unchanged, and log(boostWeight) is returned as
// the tree's weight in the forest.
Double_t TMVA::MethodBDT::AdaBoost( vector<TMVA::Event*> eventSample, TMVA::DecisionTree *dt )
{
   Double_t adaBoostBeta=1.; // exponent of the boost formula (beta=1: standard AdaBoost)
   Double_t err=0, sumw=0, sumwfalse=0, count=0;
   // pass 1: classify every event, remember which were correct and
   // accumulate the weighted misclassification
   vector<Bool_t> correctSelected;
   correctSelected.reserve(eventSample.size());
   for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
      Bool_t isSignalType= (dt->CheckEvent(*(*e),fUseYesNoLeaf) > 0.5 ) ? 1 : 0;
      sumw+=(*e)->GetWeight();
      if (isSignalType == (*e)->IsSignal()) {
         correctSelected.push_back(kTRUE);
      }
      else{
         sumwfalse+= (*e)->GetWeight();
         count+=1; // unweighted count of misclassified events (not used further)
         correctSelected.push_back(kFALSE);
      }
   }
   err=sumwfalse/sumw; // weighted error fraction of this tree
   Double_t newSumw=0;
   Int_t i=0;
   Double_t boostWeight;
   if (err>0){
      if (adaBoostBeta == 1){
         boostWeight = (1-err)/err ;
      }
      else {
         boostWeight = pow((1-err)/err,adaBoostBeta) ;
      }
   }
   else {
      // a perfect tree (err==0) would give an infinite boost weight; cap it
      boostWeight = 1000;
   }
   // pass 2: boost (up-weight) the misclassified events
   for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++){
      if (!correctSelected[i]){
         (*e)->SetWeight( (*e)->GetWeight() * boostWeight);
      }
      newSumw+=(*e)->GetWeight();
      i++;
   }
   // pass 3: renormalise so the total sample weight stays at its old value
   for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++){
      (*e)->SetWeight( (*e)->GetWeight() * sumw / newSumw );
   }
   // monitoring: bin index = number of trees grown so far
   fBoostWeightHist->Fill(boostWeight);
   fBoostWeightVsTree->SetBinContent(fForest.size(),boostWeight);
   fErrFractHist->SetBinContent(fForest.size(),err);
   // cache for the monitoring ntuple filled in Train()
   fBoostWeight = boostWeight;
   fErrorFraction = err;
   return log(boostWeight);
}
//_______________________________________________________________________
// Bagging: give each event a new random weight, renormalised such that the
// sum of weights equals the number of events. The tree index seeds the
// generator, so every tree sees a different but reproducible resampling.
// All trees enter the forest with equal weight (returns 1).
Double_t TMVA::MethodBDT::Bagging( vector<TMVA::Event*> eventSample, Int_t iTree )
{
   Double_t newSumw=0;
   Double_t newWeight;
   // stack-allocated: the original heap-allocated TRandom was never deleted
   // (one leak per tree)
   TRandom trandom( iTree );
   for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
      newWeight = trandom.Rndm();
      (*e)->SetWeight(newWeight);
      newSumw+=(*e)->GetWeight();
   }
   // renormalise to the sample size
   for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
      (*e)->SetWeight( (*e)->GetWeight() * eventSample.size() / newSumw );
   }
   return 1.;
}
//_______________________________________________________________________
// Persist the forest: first the tree count, then for each tree a header
// line carrying its boost weight, followed by the tree structure itself.
void TMVA::MethodBDT::WriteWeightsToStream( ostream& o) const
{
   const UInt_t nTrees = fForest.size();
   o << "NTrees= " << nTrees <<endl;
   for (UInt_t i = 0; i < nTrees; i++){
      o << "-999 T " << i << " boostWeight " << fBoostWeights[i] << endl;
      fForest[i]->Print(o);
   }
}
//_______________________________________________________________________
// Rebuild the forest from a weight stream written by WriteWeightsToStream:
// first the number of trees, then per tree a header with its boost weight
// followed by the recursively stored node structure.
void TMVA::MethodBDT::ReadWeightsFromStream( istream& istr )
{
   TString var, dummy;
   istr >> dummy >> fNTrees;
   fLogger << kINFO << "read " << fNTrees << " Decision trees" << Endl;
   // drop any previously loaded forest
   for (UInt_t i=0;i<fForest.size();i++) delete fForest[i];
   fForest.clear();
   fBoostWeights.clear();
   Int_t iTree;
   Double_t boostWeight;
   istr >> var >> var; // skip two header tokens (presumably "-999 T" of the first tree)
   for (int i=0;i<fNTrees;i++){
      istr >> iTree >> dummy >> boostWeight;
      if (iTree != i) {
         // dump the last successfully read tree (if any) to aid debugging;
         // the original called fForest.back() even on an empty forest (UB)
         if (!fForest.empty()) fForest.back()->Print(cout);
         fLogger << kFATAL << "error while reading weight file; mismatch Itree="
                 << iTree << " i=" << i
                 << " dummy " << dummy
                 << " boostweight " << boostWeight
                 << Endl;
      }
      // read the node structure recursively and attach it as the root of a
      // fresh DecisionTree
      TMVA::DecisionTreeNode *n = new TMVA::DecisionTreeNode();
      char pos='s';   // initial position marker for ReadRec — presumably 's' = start/root
      UInt_t depth =0;
      n->ReadRec(istr,pos,depth);
      fForest.push_back(new TMVA::DecisionTree());
      fForest.back()->SetRoot(n);
      fBoostWeights.push_back(boostWeight);
   }
}
//_______________________________________________________________________
// Classifier response for the current event: the (optionally boost-weight
// weighted) average of the individual tree responses.
Double_t TMVA::MethodBDT::GetMvaValue()
{
   Double_t myMVA = 0;
   Double_t norm = 0;
   for (UInt_t itree=0; itree<fForest.size(); itree++){
      if (fUseWeightedTrees){
         // weight each tree's response by its boost weight
         myMVA += fBoostWeights[itree] * fForest[itree]->CheckEvent(Data().Event(),fUseYesNoLeaf);
         norm += fBoostWeights[itree];
      }
      else {
         // plain average over all trees
         myMVA += fForest[itree]->CheckEvent(Data().Event(),fUseYesNoLeaf);
         norm += 1;
      }
   }
   // guard against an empty forest / vanishing normalisation (the original
   // "return myMVA /= norm" would return NaN in that case)
   return (norm != 0) ? myMVA / norm : 0;
}
//_______________________________________________________________________
// Write all BDT monitoring histograms and the per-tree monitoring ntuple
// into the method's base directory in the output file.
void TMVA::MethodBDT::WriteMonitoringHistosToFile( void ) const
{
   fLogger << kINFO << "write monitoring histograms to file: " << BaseDir()->GetPath() << Endl;
   // make the method directory current so the Write() calls land there
   BaseDir()->cd();
   fBoostWeightHist->Write();
   fBoostWeightVsTree->Write();
   fErrFractHist->Write();
   fNodesBeforePruningVsTree->Write();
   fNodesAfterPruningVsTree->Write();
   fMonitorNtuple->Write();
}
//_______________________________________________________________________
// Relative importance of each input variable, summed over all trees in the
// forest and normalised to unit total.
vector< Double_t > TMVA::MethodBDT::GetVariableImportance()
{
   // reset to zero before summing: the original resize() left previous
   // contents in place, so repeated calls accumulated importances
   fVariableImportance.assign(GetNvar(), 0.0);
   Double_t sum=0;
   for (int itree = 0; itree < fNTrees; itree++){
      vector<Double_t> relativeImportance(fForest[itree]->GetVariableImportance());
      for (unsigned int i=0; i< relativeImportance.size(); i++) {
         fVariableImportance[i] += relativeImportance[i] ;
      }
   }
   for (unsigned int i=0; i< fVariableImportance.size(); i++) sum += fVariableImportance[i];
   // normalise only when there is something to normalise (avoids 0/0)
   if (sum != 0) {
      for (unsigned int i=0; i< fVariableImportance.size(); i++) fVariableImportance[i] /= sum;
   }
   return fVariableImportance;
}
//_______________________________________________________________________
// Importance of a single input variable; aborts (kFATAL) when the index is
// outside the valid range.
Double_t TMVA::MethodBDT::GetVariableImportance( UInt_t ivar )
{
   vector<Double_t> importance = this->GetVariableImportance();
   if (ivar >= (UInt_t)importance.size()) {
      fLogger << kFATAL << "<GetVariableImportance> ivar = " << ivar << " is out of range " << Endl;
      return -1;
   }
   return importance[ivar];
}
//_______________________________________________________________________
// Build and return the variable-importance ranking of the input variables.
const TMVA::Ranking* TMVA::MethodBDT::CreateRanking()
{
   fRanking = new Ranking( GetName(), "Variable Importance" );
   vector<Double_t> importance(this->GetVariableImportance());
   const Int_t nvar = GetNvar();
   for (Int_t ivar = 0; ivar < nvar; ivar++) {
      // NOTE(review): the Rank is handed over by reference — presumably the
      // Ranking object takes ownership and deletes it; confirm in Ranking
      fRanking->AddRank( *new Rank( GetInputExp(ivar), importance[ivar] ) );
   }
   return fRanking;
}
// ROOT page - Class index - Class Hierarchy - Top of the page
// This page has been automatically generated. If you have any comments or suggestions about the page layout send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.