* *
**********************************************************************************/
#include "TMVA/MethodBDT.h"
#include "TMVA/Tools.h"
#include "TMVA/Timer.h"
#include "Riostream.h"
#include "TRandom.h"
#include <algorithm>
#include "TObjString.h"
using std::vector;
// ROOT ClassImp macro for TMVA::MethodBDT (presumably provides ROOT class-info /
// dictionary support for this class) — NOTE(review): confirm against the ROOT version in use.
ClassImp(TMVA::MethodBDT)
//_______________________________________________________________________
// Standard (training) constructor.
//
// The option string is lower-cased and split with TMVA::Tools::ParseFormatLine.
// Positional fields:
//   [0] number of trees                          -> fNTrees
//   [1] boost type ("adaboost" / "bagging")      -> fBoostType
//   [2] separation criterion                     -> fSepType
//       one of "misclassificationerror", "giniindex", "crossentropy",
//       "sdivsqrtsplusb"; GiniIndex is used when the field is absent
//   [3] node parameter passed to DecisionTree    -> fNodeMinEvents
//   [4] dummy option                             -> fDummyOpt
//   [5] cut parameter passed to DecisionTree     -> fNCuts
//   [6] signal fraction                          -> fSignalFraction
// Absent fields keep the defaults set in InitBDT(). An unrecognised
// separation criterion terminates the program with exit(1).
//
// If a training tree is available, the event sample is built immediately.
// The monitoring histograms and ntuple are booked in every case.
TMVA::MethodBDT::MethodBDT( TString jobName, vector<TString>* theVariables,
                            TTree* theTree, TString theOption, TDirectory* theTargetDir )
   : TMVA::MethodBase( jobName, theVariables, theTree, theOption, theTargetDir )
{
   InitBDT();

   // BUGFIX: this used to test `fOptions.Sizeof()<0`. That can never be true,
   // because Sizeof() reports an object size, so the message was dead code.
   if (fOptions.IsNull()) {
      cout << "--- " << GetName() << ": using default options= "<< fOptions <<endl;
   }
   cout << "--- "<<GetName() << " options:" << fOptions <<endl;

   fOptions.ToLower();
   TList* list = TMVA::Tools::ParseFormatLine( fOptions );

   if (list->GetSize() > 0) {
      fNTrees = atoi( static_cast<TObjString*>(list->At(0))->GetString() ) ;
   }
   if (list->GetSize() > 1) {
      fBoostType = static_cast<TObjString*>(list->At(1))->GetString();
   }
   if (list->GetSize() > 2) {
      TString sepType = static_cast<TObjString*>(list->At(2))->GetString();
      if (sepType.Contains("misclassificationerror")) {
         fSepType = new TMVA::MisClassificationError();
      }
      else if (sepType.Contains("giniindex")) {
         fSepType = new TMVA::GiniIndex();
      }
      else if (sepType.Contains("crossentropy")) {
         fSepType = new TMVA::CrossEntropy();
      }
      else if (sepType.Contains("sdivsqrtsplusb")) {
         fSepType = new TMVA::SdivSqrtSplusB();
      }
      else {
         // The message used to name TMVA::DecisionTree::TrainNode, but the
         // option is actually rejected here, in the MethodBDT constructor.
         cout << "--- " << GetName() << ": Error!! separation Routine not found\n" << endl;
         cout << sepType <<endl;
         exit(1);
      }
   }
   else {
      cout <<"---" <<GetName() <<": using default GiniIndex as separation criterion"<<endl;
      fSepType = new TMVA::GiniIndex();
   }

   fMethodName = "BDT"+fSepType->GetName();
   fTestvar    = fTestvarPrefix+GetMethodName();

   // BUGFIX: fields 3 and 4 used to be read together, and only when more than
   // four fields were given. A NodeMinEvents passed as the 4th (last) field
   // was therefore silently ignored.
   if (list->GetSize() > 3) {
      fNodeMinEvents = atoi( static_cast<TObjString*>(list->At(3))->GetString() ) ;
   }
   if (list->GetSize() > 4) {
      fDummyOpt = Double_t(atof( static_cast<TObjString*>(list->At(4))->GetString() )) ;
   }
   if (list->GetSize() > 5) {
      fNCuts = atoi( static_cast<TObjString*>(list->At(5))->GetString() ) ;
   }
   if (list->GetSize() > 6) {
      fSignalFraction = atof( static_cast<TObjString*>(list->At(6))->GetString() ) ;
   }
   delete list;   // no further option fields are read below

   cout << "--- " << GetName() << ": Called with "<<fNTrees <<" trees in the forest"<<endl;
   cout << "--- " << GetName() << ": Booked with options: "<<endl;
   cout << "--- " << GetName() << ": separation criteria in Node training: "
        << fSepType->GetName()<<endl;
   cout << "--- " << GetName() << ": BoostType: "
        << fBoostType << " nTress "<< fNTrees<<endl;
   cout << "--- " << GetName() << ": NodeMinEvents: " << fNodeMinEvents << endl
        << "--- " << GetName() << ": dummy: " << fDummyOpt << endl
        << "--- " << GetName() << ": NCuts: " << fNCuts << endl
        << "--- " << GetName() << ": SignalFraction: " << fSignalFraction << endl;

   if (0 != fTrainingTree) {
      if (Verbose())
         cout << "--- " << GetName() << " called " << endl;
      this->InitEventSample();
   }
   else {
      cout << "--- " << GetName() << ": Warning: no training Tree given " <<endl;
      cout << "--- " << GetName() << " you'll not allowed to cal Train e.t.c..."<<endl;
   }

   // Monitoring objects filled during Train()/AdaBoost() and written by WriteHistosToFile().
   fBoostWeightHist = new TH1F("fBoostWeight","Ada Boost weights",100,1,100);
   fErrFractHist = new TH2F("fErrFractHist","error fraction vs tree number",
                            fNTrees,0,fNTrees,50,0,0.5);
   fMonitorNtuple= new TTree("fMonitorNtuple","BDT variables");
   fMonitorNtuple->Branch("iTree",&fITree,"iTree/I");
   fMonitorNtuple->Branch("boostWeight",&fBoostWeight,"boostWeight/D");
   fMonitorNtuple->Branch("errorFraction",&fErrorFraction,"errorFraction/D");
   fMonitorNtuple->Branch("nNodes",&fNnodes,"nNodes/I");
}
//_______________________________________________________________________
// Constructor for the weight-file path: it forwards the variable list,
// weight-file name and target directory to MethodBase and then applies the
// defaults from InitBDT(). No weights are read in this constructor.
TMVA::MethodBDT::MethodBDT( vector<TString> *theVariables,
                            TString theWeightFile,
                            TDirectory* theTargetDir )
   : TMVA::MethodBase( theVariables, theWeightFile, theTargetDir )
{
   this->InitBDT();
}
void TMVA::MethodBDT::InitBDT( void )
{
fMethodName = "BDT";
fMethod = TMVA::Types::BDT;
fNTrees = 100;
fBoostType = "AdaBoost";
fNodeMinEvents = 10;
fDummyOpt = 0.;
fNCuts = 20;
fSignalFraction =-1.;
}
//_______________________________________________________________________
// Destructor: deletes every event in fEventSample and every tree in fForest.
// Both containers hold objects allocated with new in this file
// (InitEventSample, Train, ReadWeightsFromFile).
TMVA::MethodBDT::~MethodBDT( void )
{
   const UInt_t nEvents = fEventSample.size();
   for (UInt_t iev = 0; iev < nEvents; ++iev) {
      delete fEventSample[iev];
   }

   const UInt_t nTrees = fForest.size();
   for (UInt_t jt = 0; jt < nTrees; ++jt) {
      delete fForest[jt];
   }
}
void TMVA::MethodBDT::InitEventSample( void )
{
if (0 == fTrainingTree) {
cout << "--- " << GetName() << ": Error in ::Init(): fTrainingTree is zero pointer"
<< " --> exit(1)" << endl;
exit(1);
}
Int_t nevents = fTrainingTree->GetEntries();
for (int ievt=0; ievt<nevents; ievt++){
fEventSample.push_back(new TMVA::Event(fTrainingTree, ievt, fInputVars));
if (fSignalFraction > 0){
if (fEventSample.back()->GetType2() < 0) fEventSample.back()->SetWeight(fSignalFraction*fEventSample.back()->GetWeight());
}
}
}
void TMVA::MethodBDT::Train( void )
{
if (!CheckSanity()) {
cout << "--- " << GetName() << ": Error: sanity check failed" << endl;
exit(1);
}
cout << "--- " << GetName() << ": I will train "<< fNTrees << " Decision Trees"
<< " ... patience please" << endl;
TMVA::Timer timer( fNTrees, GetName() );
for (int itree=0; itree<fNTrees; itree++){
timer.DrawProgressBar( itree );
fForest.push_back(new TMVA::DecisionTree(fSepType,
fNodeMinEvents,fNCuts));
fNnodes = fForest.back()->BuildTree(fEventSample);
fBoostWeights.push_back( this->Boost(fEventSample, fForest.back(), itree) );
fITree = itree;
fMonitorNtuple->Fill();
}
cout << "--- " << GetName() << ": elapsed time: " << timer.GetElapsedTime()
<< endl;
WriteWeightsToFile();
WriteHistosToFile();
}
//_______________________________________________________________________
// Applies the configured boosting algorithm after tree number iTree (dt)
// has been grown, and returns that tree's weight for the forest vote.
//
// BUGFIX: the dispatch now uses fBoostType (case-insensitive) instead of
// searching the raw option string. When the option string omits the
// boost-type field, fBoostType keeps its default "AdaBoost", but fOptions
// contains no "adaboost". The old test therefore terminated the program on
// the default configuration. An unknown boost type still terminates the
// program.
Double_t TMVA::MethodBDT::Boost( vector<TMVA::Event*> eventSample, TMVA::DecisionTree *dt, Int_t iTree )
{
   if (fBoostType.Contains("adaboost", TString::kIgnoreCase)) return this->AdaBoost(eventSample, dt);
   if (fBoostType.Contains("bagging",  TString::kIgnoreCase)) return this->Bagging(eventSample, iTree);

   cout << "--- " << this->GetName() << "::Boost: ERROR Unknow boost option called\n";
   cout << fOptions << endl;
   exit(1);
}
//_______________________________________________________________________
// AdaBoost re-weighting after tree dt has been grown.
//
// An event is classified as +1 when dt->CheckEvent() >= 0.5, otherwise -1.
// It counts as misclassified when this differs from GetType2().
// With err = (misclassified weight) / (total weight):
//   - misclassified events have their weight multiplied by
//     ((1-err)/err)^fAdaBoostBeta, or by 1000 when err == 0;
//   - all weights are then rescaled so the total weight is unchanged.
// The boost weight and err are filled into the monitoring histograms and
// ntuple buffers. Returns log(boostWeight), the tree's weight in the vote.
Double_t TMVA::MethodBDT::AdaBoost( vector<TMVA::Event*> eventSample, TMVA::DecisionTree *dt )
{
   fAdaBoostBeta = 1.;

   // Pass 1: classify each event with the new tree. Accumulate the total
   // weight and the misclassified weight.
   Double_t sumw = 0, sumwfalse = 0;
   vector<Bool_t> correctSelected;
   correctSelected.reserve(eventSample.size());
   for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end(); ++e) {
      Int_t evType = ( 0.5 > dt->CheckEvent(*e) ) ? -1 : 1;
      sumw += (*e)->GetWeight();
      if (evType != (*e)->GetType2()) {
         sumwfalse += (*e)->GetWeight();
         correctSelected.push_back(kFALSE);
      }
      else {
         correctSelected.push_back(kTRUE);
      }
   }

   // Weighted misclassification rate. An empty or zero-total-weight sample
   // used to give 0/0 = NaN here; treat it as "no misclassification".
   const Double_t err = (sumw != 0) ? sumwfalse/sumw : 0.;

   Double_t boostWeight;
   if (err > 0) {
      if (fAdaBoostBeta == 1) {
         boostWeight = (1-err)/err ;
      }
      else {
         boostWeight = pow((1-err)/err,fAdaBoostBeta) ;
      }
   }
   else {
      boostWeight = 1000;   // no misclassified weight: fixed large boost weight
   }

   // Pass 2: boost the misclassified events and accumulate the new total.
   Double_t newSumw = 0;
   UInt_t i = 0;
   for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end(); ++e, ++i) {
      if (!correctSelected[i]) {
         (*e)->SetWeight( (*e)->GetWeight() * boostWeight);
      }
      newSumw += (*e)->GetWeight();
   }

   // Pass 3: restore the original total weight. When every event is
   // misclassified (err == 1, so boostWeight == 0), all weights are now zero.
   // Skip the division by zero instead of filling the sample with NaN weights.
   if (newSumw != 0) {
      for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end(); ++e) {
         (*e)->SetWeight( (*e)->GetWeight() * sumw / newSumw );
      }
   }

   fBoostWeightHist->Fill(boostWeight);
   fErrFractHist->Fill(fForest.size(),err);
   fBoostWeight   = boostWeight;
   fErrorFraction = err;
   return log(boostWeight);
}
//_______________________________________________________________________
// Bagging-style re-weighting.
//
// Each event gets a weight drawn from TRandom::Rndm(), using a generator
// seeded with iTree. The weights are then rescaled so they sum to the
// number of events. Always returns 1, i.e. every tree gets the same weight.
//
// BUGFIX: the generator is now a local object. It used to be allocated with
// new and never deleted, leaking one TRandom per tree.
// NOTE(review): for iTree == 0 the seed is 0, which some TRandom versions
// treat as a request for a non-reproducible seed. Confirm if
// reproducibility matters.
Double_t TMVA::MethodBDT::Bagging( vector<TMVA::Event*> eventSample, Int_t iTree )
{
   TRandom trandom(iTree);

   Double_t newSumw = 0;
   for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end(); ++e) {
      (*e)->SetWeight(trandom.Rndm());
      newSumw += (*e)->GetWeight();
   }

   for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end(); ++e) {
      (*e)->SetWeight( (*e)->GetWeight() * eventSample.size() / newSumw );
   }
   return 1.;
}
void TMVA::MethodBDT::WriteWeightsToFile( void )
{
TString fname = GetWeightFileName();
cout << "--- " << GetName() << ": creating Weight file: " << fname << endl;
ofstream fout( fname );
if (!fout.good( )) {
cout << "--- " << GetName() << ": Error in ::WriteWeightsToFile: "
<< "unable to open output Weight file: " << fname << endl;
exit(1);
}
fout << this->GetMethodName() <<endl;
fout << "NVars= " << fNvar <<endl;
for (Int_t ivar=0; ivar<fNvar; ivar++) {
TString var = (*fInputVars)[ivar];
fout << var << " " << GetXminNorm( var ) << " " << GetXmaxNorm( var ) << endl;
}
fout << "NTrees= " << fForest.size() <<endl;
for (UInt_t i=0; i< fForest.size(); i++){
fout << "-999 *******Tree " << i << " boostWeight " << fBoostWeights[i] << endl;
(fForest[i])->Print(fout);
}
fout.close();
}
//_______________________________________________________________________
// Rebuilds the forest from the weight file given by GetWeightFileName().
//
// Reads, in order:
//   - the method name;
//   - "NVars= <n>" and <n> "<var> <xmin> <xmax>" entries, which are stored
//     into fInputVars and the normalisation ranges;
//   - "NTrees= <n>".
// It then replaces any existing trees with <n> trees, each read by
// DecisionTreeNode::ReadRec together with its boost weight.
// Terminates the program on:
//   - an unopenable file;
//   - a corrupt header;
//   - more variables than fInputVars can hold;
//   - a tree-index mismatch.
void TMVA::MethodBDT::ReadWeightsFromFile( void )
{
   TString fname = GetWeightFileName();
   cout << "--- " << GetName() << ": reading Weight file: " << fname << endl;
   ifstream fin( fname );
   if (!fin.good( )) {
      cout << "--- " << GetName() << ": Error in ::ReadWeightsFromFile: "
           << "unable to open input file: " << fname << endl;
      exit(1);
   }

   TString var, dummy;
   Double_t xmin, xmax;

   // method name
   fin >> dummy;
   this->SetMethodName(dummy);

   // "NVars= <n>". Writing past the end of fInputVars below would be
   // undefined behaviour, so reject a count it cannot hold.
   fin >> dummy >> fNvar;
   if (!fin || fNvar < 0 || fNvar > Int_t(fInputVars->size())) {
      cout << "--- " << GetName() << ": Error in ::ReadWeightsFromFile: "
           << "bad number of variables (" << fNvar << ") in " << fname << endl;
      exit(1);
   }
   for (Int_t ivar=0; ivar<fNvar; ivar++) {
      fin >> var >> xmin >> xmax;
      (*fInputVars)[ivar] = var;
      this->SetXminNorm( ivar, xmin );
      this->SetXmaxNorm( ivar, xmax );
   }

   // "NTrees= <n>"
   fin >> dummy >> fNTrees;
   if (!fin) {
      cout << "--- " << GetName() << ": Error in ::ReadWeightsFromFile: "
           << "corrupt header in " << fname << endl;
      exit(1);
   }
   cout << "--- " << GetName() << ": Read "<<fNTrees<<" Decision trees\n";

   // discard any previously loaded forest
   for (UInt_t i=0;i<fForest.size();i++) delete fForest[i];
   fForest.clear();
   fBoostWeights.clear();

   Int_t iTree;
   Double_t boostWeight;
   fin >> var >> var;   // skips two tokens ahead of the first tree entry
   for (int i=0;i<fNTrees;i++) {
      fin >> iTree >> dummy >> boostWeight;
      if (iTree != i) {
         // These messages previously omitted GetName(), unlike the rest of the file.
         cout << "--- " << GetName() << ": Error while reading Weight file \n ";
         cout << "--- " << GetName() << ": mismatch Itree="<<iTree<<" i="<<i<<endl;
         exit(1);
      }
      TMVA::DecisionTreeNode *n = new TMVA::DecisionTreeNode();
      TMVA::NodeID id;
      n->ReadRec(fin,id);
      fForest.push_back(new TMVA::DecisionTree());
      fForest.back()->SetRoot(n);
      fBoostWeights.push_back(boostWeight);
   }
   fin.close();
}
//_______________________________________________________________________
// Classifier response for event e: the average of CheckEvent(e) over all
// trees.
//
// The compile-time switch selects between a plain average (current choice)
// and an average weighted by each tree's entry in fBoostWeights.
// With an empty forest the result is 0/0 (NaN).
Double_t TMVA::MethodBDT::GetMvaValue(TMVA::Event *e)
{
   const bool useWeightedMajorityVote = kFALSE;

   Double_t weightedSum  = 0;
   Double_t sumOfWeights = 0;
   const UInt_t ntrees = fForest.size();
   for (UInt_t k = 0; k < ntrees; ++k) {
      const Double_t w = useWeightedMajorityVote ? fBoostWeights[k] : 1.;
      weightedSum  += w * fForest[k]->CheckEvent(e);
      sumOfWeights += w;
   }
   return weightedSum / sumOfWeights;
}
void TMVA::MethodBDT::WriteHistosToFile( void )
{
cout << "--- " << GetName() << ": write " << GetName()
<<" special histos to file: " << fBaseDir->GetPath() << endl;
gDirectory->GetListOfKeys()->Print();
fLocalTDir = fBaseDir->mkdir(GetName()+GetMethodName());
fLocalTDir->cd();
fBoostWeightHist->Write();
fErrFractHist->Write();
fMonitorNtuple->Write();
}
/* ROOT page - Class index - Class Hierarchy - Top of the page
   This page has been automatically generated. If you have any comments or suggestions about the page layout, send a mail to ROOT support; contact the developers with any questions or problems regarding ROOT. */