#include <iostream>
#include <algorithm>
#include "TMVA/DecisionTree.h"
#include "TMVA/DecisionTreeNode.h"
#include "TMVA/BinarySearchTree.h"
#include "TMVA/Tools.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/CrossEntropy.h"
#include "TMVA/MisClassificationError.h"
#include "TMVA/SdivSqrtSplusB.h"
#include "TMVA/Event.h"
using std::vector;
// Compile-time switches selecting the TrainNode implementation:
//   USE_HELGESCODE == 1 -> Helge's TrainNode (the #if branch further below);
//                          any other value compiles the alternative (Doug's)
//                          TrainNode in the #else branch instead.
//   USE_HELGE_V1   == 1 -> inside Helge's TrainNode the outer loop runs over
//                          the variables and the inner loop over the events;
//                     0 -> outer loop over the events, inner over variables.
#define USE_HELGESCODE 1 // 1: Helge's TrainNode; otherwise Doug's implementation
#define USE_HELGE_V1 0 // 1: outer loop over variables, inner loop over events
// ROOT class-implementation macro for TMVA::DecisionTree.
// NOTE(review): presumably pairs with a ClassDef in DecisionTree.h - confirm.
ClassImp(TMVA::DecisionTree)
; // stray ';' after the macro invocation
//_______________________________________________________________________
// Default constructor: an empty tree without root, with fNCuts = -1 and no
// separation / quality criteria (fSepType and fQualityIndex are NULL).
// NOTE(review): BuildTree dereferences fSepType and the cost-complexity
// pruning dereferences fQualityIndex, so such a tree must get those set
// before being trained/pruned that way - confirm how callers handle this.
TMVA::DecisionTree::DecisionTree( void )
   : BinaryTree(),
     fNvars      (0),
     fNCuts      (-1),
     fSepType    (NULL),
     fMinSize    (0),
     fPruneMethod(kExpectedErrorPruning),
     fDepth      (0),
     fQualityIndex(NULL)
{
   // fPruneStrength is read by every pruning method (GetNodeError,
   // PruneTreeCC, PruneTreeMCC) but was left uninitialised before.
   fPruneStrength = 0;
   fLogger.SetSource( "DecisionTree" );
}
//_______________________________________________________________________
// Constructor adopting an existing node hierarchy: n becomes the root and
// SetParentTreeInNodes() registers this tree in every node below it (which
// also updates fDepth to the deepest node found).
TMVA::DecisionTree::DecisionTree( DecisionTreeNode* n )
   : BinaryTree(),
     fNvars      (0),
     fNCuts      (-1),
     fSepType    (NULL),
     fMinSize    (0),
     fPruneMethod(kExpectedErrorPruning),
     fDepth      (0),
     fQualityIndex(NULL)
{
   // read by all pruning methods; must not be left uninitialised
   fPruneStrength = 0;
   fLogger.SetSource( "DecisionTree" );
   this->SetRoot( n );
   this->SetParentTreeInNodes();
}
//_______________________________________________________________________
// Constructor for training:
//   sepType - separation criterion used by BuildTree/TrainNode
//   minSize - minimum (purity-estimated) number of signal and of background
//             events a node needs to be split further (see BuildTree)
//   nCuts   - number of equidistant cut values tried per variable
//   qtype   - quality index used by the cost-complexity pruning
TMVA::DecisionTree::DecisionTree( TMVA::SeparationBase *sepType,Int_t minSize,
                                  Int_t nCuts, TMVA::SeparationBase *qtype):
   BinaryTree(),
   fNvars      (0),
   fNCuts      (nCuts),
   fSepType    (sepType),
   fMinSize    (minSize),
   fPruneMethod(kExpectedErrorPruning),
   fDepth      (0),
   fQualityIndex(qtype)
{
   // read by all pruning methods; must not be left uninitialised
   fPruneStrength = 0;
   fLogger.SetSource( "DecisionTree" );
}
//_______________________________________________________________________
// Copy constructor: copies the configuration, the prune strength and the
// accumulated variable importances, and deep-copies the node hierarchy via
// the DecisionTreeNode copy constructor (if d has a root at all).
TMVA::DecisionTree::DecisionTree( const DecisionTree &d):
   BinaryTree(),
   fNvars      (d.fNvars),
   fNCuts      (d.fNCuts),
   fSepType    (d.fSepType),
   fMinSize    (d.fMinSize),
   fPruneMethod(d.fPruneMethod),
   fDepth      (d.fDepth),
   fQualityIndex(d.fQualityIndex)
{
   fPruneStrength      = d.fPruneStrength;
   // fNvars is copied above, so the importances must match it; otherwise
   // GetVariableImportance() would index an empty vector
   fVariableImportance = d.fVariableImportance;
   // an empty tree (no root) cannot be dereferenced for copying
   if (d.GetRoot() != NULL) {
      this->SetRoot( new DecisionTreeNode ( *((DecisionTreeNode*)(d.GetRoot())) ) );
      this->SetParentTreeInNodes();
   }
   fNNodes = d.fNNodes;
   fLogger.SetSource( "DecisionTree" );
}
//_______________________________________________________________________
// Destructor: releases nothing itself.
// NOTE(review): the nodes are presumably deleted by the BinaryTree base
// destructor - confirm.
TMVA::DecisionTree::~DecisionTree( void )
{
   // intentionally empty
}
//_______________________________________________________________________
// Recursively walk the subtree rooted at n (the whole tree if n is NULL),
// set this tree as parent tree of every node and raise fDepth to the
// largest node depth encountered. A node with exactly one daughter means
// the tree is inconsistent and is reported with a fatal message.
void TMVA::DecisionTree::SetParentTreeInNodes( DecisionTreeNode *n)
{
   if (n == NULL){
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) return ;
   }
   DecisionTreeNode *l = this->GetLeftDaughter(n);
   DecisionTreeNode *r = this->GetRightDaughter(n);
   if ((l == NULL) != (r == NULL)) {
      // kFATAL is the TMVA message type; kFatal (ROOT TError level, an Int_t)
      // would just be streamed as the number 6000
      fLogger << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
      return;
   }
   if (l != NULL) this->SetParentTreeInNodes( l );
   if (r != NULL) this->SetParentTreeInNodes( r );
   n->SetParentTree(this);
   if (n->GetDepth() > fDepth) fDepth = n->GetDepth();
}
//_______________________________________________________________________
// Build the tree recursively from eventSample, starting at node (a new root
// node is created when node is NULL).
// Every node stores its weighted signal/background sums and its separation
// index. If the node holds enough events (see the fMinSize condition below),
// TrainNode searches the best cut; if that cut has a strictly positive
// separation gain, the sample is split with it and both daughters are built
// recursively. Otherwise the node becomes a leaf with node type +1 if
// GetSoverSB() > 0.5 and -1 otherwise.
// Returns the number of nodes in the tree (fNNodes).
Int_t TMVA::DecisionTree::BuildTree( vector<TMVA::Event*> & eventSample,
                                     TMVA::DecisionTreeNode *node )
{
   if (node==NULL) {
      // first call: create the root node of a fresh tree
      node = new TMVA::DecisionTreeNode();
      fNNodes = 1;
      this->SetRoot(node);
      this->GetRoot()->SetPos('s');
      this->GetRoot()->SetDepth(0);
      this->GetRoot()->SetParentTree(this);
   }

   UInt_t nevents = eventSample.size();
   if (nevents > 0 ) {
      fNvars = eventSample[0]->GetNVars();
      fVariableImportance.resize(fNvars);
   }
   else fLogger << kFATAL << ":<BuildTree> eventsample Size == 0 " << Endl;

   // weighted signal and background content of this node
   Double_t s=0, b=0;
   for (UInt_t i=0; i<nevents; i++){
      if (eventSample[i]->IsSignal())
         s += eventSample[i]->GetWeight();
      else
         b += eventSample[i]->GetWeight();
   }
   node->SetNSigEvents(s);
   node->SetNBkgEvents(b);
   // daughters get their event count from the parent when it is split
   if (node == this->GetRoot()) node->SetNEvents(s+b);
   node->SetSeparationIndex(fSepType->GetSeparationIndex(s,b));

   // split only if the purity-estimated numbers of signal events
   // (GetSoverSB()*N) and of background events ((1-GetSoverSB())*N) both
   // exceed fMinSize
   if ( eventSample.size() > fMinSize &&
        node->GetSoverSB()*eventSample.size() > fMinSize &&
        node->GetSoverSB()*eventSample.size() < eventSample.size()-fMinSize ) {

      Double_t separationGain = this->TrainNode(eventSample, node);

      // Only a strictly positive gain justifies a split. Testing "== 0" let
      // negative (e.g. rounding - NOTE(review): depends on the SeparationBase
      // implementation) or NaN gains through to a useless split.
      if (!(separationGain > 0)) {
         if (node->GetSoverSB() > 0.5) node->SetNodeType(1);
         else node->SetNodeType(-1);
         if (node->GetDepth() > fDepth) fDepth = node->GetDepth();
      }else{
         vector<TMVA::Event*> leftSample; leftSample.reserve(nevents);
         vector<TMVA::Event*> rightSample; rightSample.reserve(nevents);
         Double_t nRight=0, nLeft=0;
         for (UInt_t ie=0; ie< nevents ; ie++){
            if (node->GoesRight(*eventSample[ie])){
               rightSample.push_back(eventSample[ie]);
               nRight += eventSample[ie]->GetWeight();
            }
            else {
               leftSample.push_back(eventSample[ie]);
               nLeft += eventSample[ie]->GetWeight();
            }
         }
         if (leftSample.size() == 0 || rightSample.size() == 0) {
            fLogger << kFATAL << "<TrainNode> all events went to the same branch" << Endl
                    << "--- Hence new node == old node ... check" << Endl
                    << "--- left:" << leftSample.size()
                    << " right:" << rightSample.size() << Endl
                    << "--- this should never happen, please write a bug report to Helge.Voss@cern.ch"
                    << Endl;
         }
         TMVA::DecisionTreeNode *rightNode = new TMVA::DecisionTreeNode(node,'r');
         fNNodes++;
         rightNode->SetNEvents(nRight);
         TMVA::DecisionTreeNode *leftNode = new TMVA::DecisionTreeNode(node,'l');
         fNNodes++;
         leftNode->SetNEvents(nLeft);
         node->SetNodeType(0);
         node->SetLeft(leftNode);
         node->SetRight(rightNode);
         this->BuildTree(rightSample, rightNode);
         this->BuildTree(leftSample, leftNode );
      }
   }
   else{
      // too few events to split: make this node a leaf
      if (node->GetSoverSB() > 0.5) node->SetNodeType(1);
      else node->SetNodeType(-1);
      if (node->GetDepth() > fDepth) fDepth = node->GetDepth();
   }
   return fNNodes;
}
//_______________________________________________________________________
// Pass every event of eventSample down the existing tree starting at the
// root, updating the node counters along its path (see FillEvent).
void TMVA::DecisionTree::FillTree( vector<TMVA::Event*> & eventSample)
{
   vector<TMVA::Event*>::iterator it  = eventSample.begin();
   vector<TMVA::Event*>::iterator end = eventSample.end();
   for ( ; it != end; ++it) {
      this->FillEvent( **it, NULL );
   }
}
//_______________________________________________________________________
// Add the weight of event to node (the root if node is NULL) and to every
// node below it on the event's path, until a node that is not an internal
// node (node type != 0) is reached. Each visited node gets its total and
// signal or background count incremented and its separation index
// recomputed from the updated counts.
void TMVA::DecisionTree::FillEvent( TMVA::Event & event,
                                    TMVA::DecisionTreeNode *node )
{
   TMVA::DecisionTreeNode *current = node;
   if (current == NULL) current = (TMVA::DecisionTreeNode*)this->GetRoot();

   for (;;) {
      current->IncrementNEvents( event.GetWeight() );
      if (event.IsSignal()) current->IncrementNSigEvents( event.GetWeight() );
      else                  current->IncrementNBkgEvents( event.GetWeight() );
      current->SetSeparationIndex(fSepType->GetSeparationIndex(current->GetNSigEvents(),
                                                               current->GetNBkgEvents()));
      // leaves end the descent
      if (current->GetNodeType() != 0) break;
      if (current->GoesRight(event)) current = (TMVA::DecisionTreeNode*)(current->GetRight());
      else                           current = (TMVA::DecisionTreeNode*)(current->GetLeft());
   }
}
void TMVA::DecisionTree::ClearTree()
{
if (this->GetRoot()!=NULL)
((DecisionTreeNode*)(this->GetRoot()))->ClearNodeAndAllDaughters();
}
void TMVA::DecisionTree::PruneTree()
{
if (fPruneMethod == kExpectedErrorPruning) {
this->PruneTreeEEP((DecisionTreeNode *)this->GetRoot());
} else if (fPruneMethod == kCostComplexityPruning) {
this->PruneTreeCC();
} else if (fPruneMethod == kMCC) {
this->PruneTreeMCC();
} else {
fLogger << kFatal << "Selected pruning method not yet implemented "
<< Endl;
}
this->CountNodes();
};
//_______________________________________________________________________
// Expected-error pruning, bottom-up: first prune both daughters of an
// internal node (node type 0), then replace the node's subtree by a leaf
// if the subtree's estimated error is not smaller than the node's own
// estimated error. A NULL node (e.g. tree without root) is ignored.
void TMVA::DecisionTree::PruneTreeEEP(DecisionTreeNode *node)
{
   if (node == NULL) return;
   if (node->GetNodeType() != 0) return;   // leaves cannot be pruned

   DecisionTreeNode *l = (DecisionTreeNode*)node->GetLeft();
   DecisionTreeNode *r = (DecisionTreeNode*)node->GetRight();
   this->PruneTreeEEP(l);
   this->PruneTreeEEP(r);
   if (this->GetSubTreeError(node) >= this->GetNodeError(node)) {
      this->PruneNode(node);
   }
}
void TMVA::DecisionTree::PruneTreeCC()
{
Double_t currentCC = this->GetCostComplexity(fPruneStrength);
Double_t nextCC = this->GetCostComplexityIfNextPruneStep(fPruneStrength);
while (currentCC > nextCC && this->GetNNodes() > 3 ){
this->PruneNode( this->FindCCPruneCandidate() );
currentCC = this->GetCostComplexity(fPruneStrength);
nextCC = this->GetCostComplexityIfNextPruneStep(fPruneStrength);
}
return;
}
//_______________________________________________________________________
// Minimal cost-complexity pruning: repeatedly prune the weakest link (the
// entry of fLinkStrengthMap with the smallest link strength alpha) as long
// as that alpha is below fPruneStrength and the tree has more than 3 nodes.
// The map is rebuilt after every prune step: PruneNode deletes nodes, so
// the previous entries (and their alphas) are stale. Reading the next alpha
// from the stale map re-tested the link just pruned, so pruning continued
// down to 3 nodes regardless of fPruneStrength.
void TMVA::DecisionTree::PruneTreeMCC()
{
   this->FillLinkStrengthMap();
   while (!fLinkStrengthMap.empty() &&
          fLinkStrengthMap.begin()->first < fPruneStrength &&
          this->GetNNodes() > 3) {
      this->PruneNode( fLinkStrengthMap.begin()->second );
      this->FillLinkStrengthMap();
   }
}
//_______________________________________________________________________
// Rebuild fLinkStrengthMap and return the node of its first entry (the
// smallest link strength, assuming the map's default ascending order).
// Returns NULL when the tree has no node with two daughters (empty map);
// dereferencing begin() of an empty map was undefined behaviour.
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetWeakestLink()
{
   this->FillLinkStrengthMap();
   if (fLinkStrengthMap.empty()) return NULL;
   return fLinkStrengthMap.begin()->second;
}
//_______________________________________________________________________
// (Re)build fLinkStrengthMap for the subtree rooted at n. Called with NULL
// it starts from the root and clears the map first. Each node with two
// daughters is inserted with key
//   alpha = (MisClassificationCostOfNode(n) - MisClassificationCostOfSubTree(n))
//           / (n->CountMeAndAllDaughters() - 1)
void TMVA::DecisionTree::FillLinkStrengthMap(TMVA::DecisionTreeNode *n)
{
   if (n == NULL) {
      fLinkStrengthMap.clear();
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) return;
   }

   DecisionTreeNode *left  = this->GetLeftDaughter(n);
   DecisionTreeNode *right = this->GetRightDaughter(n);
   if (left  != NULL) this->FillLinkStrengthMap( left );
   if (right != NULL) this->FillLinkStrengthMap( right );

   // only nodes with both daughters are links that can be cut
   if (left == NULL || right == NULL) return;

   Double_t alpha = ( this->MisClassificationCostOfNode(n) -
                      this->MisClassificationCostOfSubTree(n) ) /
                    (n->CountMeAndAllDaughters() - 1);
   fLinkStrengthMap.insert(pair<const Double_t, TMVA::DecisionTreeNode* > ( alpha, n ));
}
//_______________________________________________________________________
// Misclassification cost of n if it were a leaf:
// (1 - n->GetPurity()) * n->GetNEvents().
// NOTE(review): 1 - purity is the misclassified fraction only for a node
// labelled signal; nodes are labelled background when GetSoverSB() <= 0.5
// (see BuildTree/PruneNode), for which the misclassified fraction would
// presumably be the purity itself - confirm what GetPurity() returns.
Double_t TMVA::DecisionTree::MisClassificationCostOfNode(TMVA::DecisionTreeNode *n)
{
return (1 - n->GetPurity()) * n->GetNEvents();
}
//_______________________________________________________________________
// Misclassification cost of the subtree rooted at n (the whole tree if n is
// NULL; 0 for a tree without root): the sum of
// MisClassificationCostOfNode over all leaves of the subtree.
Double_t TMVA::DecisionTree::MisClassificationCostOfSubTree(TMVA::DecisionTreeNode *n)
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) return 0;
   }
   DecisionTreeNode *left  = this->GetLeftDaughter(n);
   DecisionTreeNode *right = this->GetRightDaughter(n);

   // a leaf contributes its own cost
   if (left == NULL && right == NULL) return this->MisClassificationCostOfNode(n);

   Double_t cost = 0;
   if (left  != NULL) cost += this->MisClassificationCostOfSubTree( left );
   if (right != NULL) cost += this->MisClassificationCostOfSubTree( right );
   return cost;
}
//_______________________________________________________________________
// Number of leaves (nodes without daughters) in the subtree rooted at n
// (the whole tree if n is NULL; 0 for a tree without root).
UInt_t TMVA::DecisionTree::CountLeafNodes(TMVA::DecisionTreeNode *n)
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) return 0;
   }
   DecisionTreeNode *left  = this->GetLeftDaughter(n);
   DecisionTreeNode *right = this->GetRightDaughter(n);
   if (left == NULL && right == NULL) return 1;

   UInt_t nLeaves = 0;
   if (left  != NULL) nLeaves += this->CountLeafNodes( left );
   if (right != NULL) nLeaves += this->CountLeafNodes( right );
   return nLeaves;
}
//_______________________________________________________________________
// Cost complexity of the current tree: rebuilds fQualityMap (leaf quality
// -> leaf) and returns sum over leaves of (NSig + NBkg) * quality, plus
// alpha per leaf.
Double_t TMVA::DecisionTree::GetCostComplexity(Double_t alpha)
{
   this->FillQualityMap();

   Double_t weightedQuality = 0.;
   Int_t nLeaves = 0;
   for (multimap<Double_t, TMVA::DecisionTreeNode* >::iterator leafIt = fQualityMap.begin();
        leafIt != fQualityMap.end(); ++leafIt) {
      Double_t s = leafIt->second->GetNSigEvents();
      Double_t b = leafIt->second->GetNBkgEvents();
      weightedQuality += (s+b) * leafIt->first;
      ++nLeaves;
   }
   return weightedQuality + alpha * nLeaves;
}
//_______________________________________________________________________
// Cost complexity the tree would have after the next cost-complexity prune
// step. The step merges the two leaf daughters of the first candidate in
// fQualityGainMap. Leaves whose parent is that candidate are replaced by
// the candidate itself, evaluated with fQualityIndex, and alpha is added
// per remaining leaf.
// If there is no candidate (e.g. the root is a leaf), nothing can be pruned
// and the current cost complexity is returned. Previously begin() of the
// empty map was dereferenced.
Double_t TMVA::DecisionTree::GetCostComplexityIfNextPruneStep(Double_t alpha)
{
   this->FillQualityMap();
   this->FillQualityGainMap();
   if (fQualityGainMap.empty()) return this->GetCostComplexity(alpha);

   TMVA::DecisionTreeNode *candidate = fQualityGainMap.begin()->second;

   Double_t cc=0.;
   Int_t count=0;
   multimap<Double_t, TMVA::DecisionTreeNode* >::iterator it=fQualityMap.begin();
   for (;it!=fQualityMap.end(); it++){
      // the candidate's two leaves disappear in the prune step
      if (it->second->GetParent() != candidate ) {
         Double_t s=it->second->GetNSigEvents();
         Double_t b=it->second->GetNBkgEvents();
         cc += (s+b) * it->first ;
         count++;
      }
   }
   // ... and the candidate becomes a leaf
   Double_t s=candidate->GetNSigEvents();
   Double_t b=candidate->GetNBkgEvents();
   cc += (s+b) * fQualityIndex->GetSeparationIndex(s,b);
   count++;
   cc+=alpha*count;
   return cc;
}
//_______________________________________________________________________
// (Re)build fQualityGainMap for the subtree rooted at n. Called with NULL
// it starts from the root and clears the map first. Every node whose two
// daughters are both leaves (a cost-complexity prune candidate) is inserted
// with key fQualityIndex->GetSeparationGain(right daughter S, B, node S, B).
void TMVA::DecisionTree::FillQualityGainMap(DecisionTreeNode* n )
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      fQualityGainMap.clear();
      if (n == NULL) return;
   }

   DecisionTreeNode *left  = this->GetLeftDaughter(n);
   DecisionTreeNode *right = this->GetRightDaughter(n);
   if (left  != NULL) this->FillQualityGainMap( left );
   if (right != NULL) this->FillQualityGainMap( right );

   if (left == NULL || right == NULL) return;
   Bool_t leftIsLeaf  = (left->GetLeft()  == NULL) && (left->GetRight()  == NULL);
   Bool_t rightIsLeaf = (right->GetLeft() == NULL) && (right->GetRight() == NULL);
   if (!(leftIsLeaf && rightIsLeaf)) return;

   Double_t gain = fQualityIndex->GetSeparationGain( right->GetNSigEvents(),
                                                     right->GetNBkgEvents(),
                                                     n->GetNSigEvents(),
                                                     n->GetNBkgEvents() );
   fQualityGainMap.insert(pair<const Double_t, TMVA::DecisionTreeNode* >( gain, n ));
}
//_______________________________________________________________________
// (Re)build fQualityMap for the subtree rooted at n. Called with NULL it
// starts from the root and clears the map first. Every leaf is inserted
// with key fQualityIndex->GetSeparationIndex(NSig, NBkg) of that leaf.
void TMVA::DecisionTree::FillQualityMap(DecisionTreeNode* n )
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      fQualityMap.clear();
      if (n == NULL) return;
   }

   DecisionTreeNode *left  = this->GetLeftDaughter(n);
   DecisionTreeNode *right = this->GetRightDaughter(n);
   if (left  != NULL) this->FillQualityMap( left );
   if (right != NULL) this->FillQualityMap( right );

   // only leaves enter the map
   if (left != NULL || right != NULL) return;

   Double_t quality = fQualityIndex->GetSeparationIndex( n->GetNSigEvents(),
                                                         n->GetNBkgEvents() );
   fQualityMap.insert(pair<const Double_t, TMVA::DecisionTreeNode* >( quality, n ));
}
//_______________________________________________________________________
// Walk the subtree rooted at n (the whole tree if n is NULL) without
// modifying it. The only effect is a fatal message for a node that has
// exactly one daughter.
void TMVA::DecisionTree::DescendTree( DecisionTreeNode *n)
{
   if (n == NULL){
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) return ;
   }
   DecisionTreeNode *l = this->GetLeftDaughter(n);
   DecisionTreeNode *r = this->GetRightDaughter(n);
   if (l == NULL && r == NULL) return;   // leaf: nothing below
   if (l == NULL || r == NULL) {
      // kFATAL is the TMVA message type; kFatal (ROOT TError level, an Int_t)
      // would just be printed as a number
      fLogger << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
      return;
   }
   this->DescendTree( l );
   this->DescendTree( r );
}
//_______________________________________________________________________
// Rebuild fQualityGainMap and return the node of its first entry, i.e. the
// next cost-complexity prune candidate. Returns NULL when there is no node
// whose two daughters are both leaves (empty map); dereferencing begin() of
// an empty map was undefined behaviour.
TMVA::DecisionTreeNode* TMVA::DecisionTree::FindCCPruneCandidate()
{
   this->FillQualityGainMap();
   if (fQualityGainMap.empty()) return NULL;
   return fQualityGainMap.begin()->second;
}
//_______________________________________________________________________
// Turn node into a leaf. Detach its daughters and delete them via
// DeleteNode (NOTE(review): presumably recursively). Label the node signal
// (+1) if GetSoverSB() > 0.5, otherwise background (-1). Finally recount
// the tree's nodes.
// A NULL node (e.g. no prune candidate was found) is ignored.
void TMVA::DecisionTree::PruneNode(DecisionTreeNode *node)
{
   if (node == NULL) return;
   DecisionTreeNode *l = (DecisionTreeNode*)node->GetLeft();
   DecisionTreeNode *r = (DecisionTreeNode*)node->GetRight();
   node->SetRight(NULL);
   node->SetLeft(NULL);
   if (node->GetSoverSB() > 0.5) node->SetNodeType(1);
   else node->SetNodeType(-1);
   if (l != NULL) this->DeleteNode(l);
   if (r != NULL) this->DeleteNode(r);
   this->CountNodes();
}
//_______________________________________________________________________
// Pessimistic error estimate of node treated as a leaf. With f the
// fraction of the majority class (GetSoverSB() if > 0.5, else
// 1 - GetSoverSB()) and df = sqrt(f*(1-f)/N) its statistical uncertainty,
// the error is 1 - (f - fPruneStrength*df), capped at 1.
Double_t TMVA::DecisionTree::GetNodeError(DecisionTreeNode *node)
{
   const Double_t nEvts = node->GetNEvents();
   const Double_t f  = (node->GetSoverSB() > 0.5) ? node->GetSoverSB()
                                                   : (1-node->GetSoverSB());
   const Double_t df = sqrt(f*(1-f)/nEvts );
   return std::min(1., (1 - (f-fPruneStrength*df) ));
}
//_______________________________________________________________________
// Estimated error of the subtree rooted at node. For an internal node
// (type 0) it is the event-weighted average of the daughters' subtree
// errors; for a leaf it is GetNodeError(node).
Double_t TMVA::DecisionTree::GetSubTreeError(DecisionTreeNode *node)
{
   if (node->GetNodeType() != 0) return this->GetNodeError(node);

   DecisionTreeNode *l = (DecisionTreeNode*)node->GetLeft();
   DecisionTreeNode *r = (DecisionTreeNode*)node->GetRight();
   return (l->GetNEvents() * this->GetSubTreeError(l) +
           r->GetNEvents() * this->GetSubTreeError(r)) /
          node->GetNEvents();
}
//_______________________________________________________________________
// Left daughter of n as a DecisionTreeNode (NULL if n has none).
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetLeftDaughter( DecisionTreeNode *n)
{
   return static_cast<DecisionTreeNode*>( n->GetLeft() );
}
//_______________________________________________________________________
// Right daughter of n as a DecisionTreeNode (NULL if n has none).
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetRightDaughter( DecisionTreeNode *n)
{
   return static_cast<DecisionTreeNode*>( n->GetRight() );
}
//_______________________________________________________________________
// Walk `depth` levels down from the root. Bit i of `sequence` (least
// significant first) selects the right (1) or left (0) daughter at level i;
// levels beyond the width of ULong_t go left.
// Returns the node reached, or NULL if the path leaves the tree (or the
// tree has no root) instead of dereferencing a NULL daughter.
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetNode(ULong_t sequence, UInt_t depth)
{
   const UInt_t nBits = 8*sizeof(ULong_t);
   DecisionTreeNode* current = (DecisionTreeNode*) this->GetRoot();
   for (UInt_t i=0; i < depth && current != NULL; i++){
      // shift the ULong_t itself: the former `1 << i` was an int shift,
      // undefined for i >= 31 and unable to reach the upper bits
      Bool_t goRight = (i < nBits) && ((sequence >> i) & 1UL);
      if (goRight) current = this->GetRightDaughter(current);
      else         current = this->GetLeftDaughter(current);
   }
   return current;
}
//_______________________________________________________________________
// Fill xmin/xmax (which must already hold fNvars entries) with the minimum
// and maximum of every variable over eventSample. An empty sample leaves
// xmin/xmax unchanged instead of reading eventSample[0].
void TMVA::DecisionTree::FindMinAndMax(vector<TMVA::Event*> & eventSample,
                                       vector<Double_t> & xmin,
                                       vector<Double_t> & xmax)
{
   UInt_t num_events = eventSample.size();
   if (num_events == 0) return;

   for (Int_t ivar=0; ivar < fNvars; ivar++){
      xmin[ivar]=xmax[ivar]=eventSample[0]->GetVal(ivar);
   }
   for (UInt_t i=1;i<num_events;i++){
      for (Int_t ivar=0; ivar < fNvars; ivar++){
         Double_t val = eventSample[i]->GetVal(ivar);
         if (xmin[ivar] > val) xmin[ivar] = val;
         if (xmax[ivar] < val) xmax[ivar] = val;
      }
   }
}
//_______________________________________________________________________
// Place num_gridpoints cut values at the centres of num_gridpoints equal
// bins spanning [xmin, xmax]. cut_points must hold at least num_gridpoints
// entries; a non-positive num_gridpoints leaves it untouched.
// Each point is computed directly from its index (as the active TrainNode
// does) instead of by repeated `x += step`, so rounding errors do not
// accumulate along the grid.
void TMVA::DecisionTree::SetCutPoints(vector<Double_t> & cut_points,
                                      Double_t xmin,
                                      Double_t xmax,
                                      Int_t num_gridpoints)
{
   if (num_gridpoints <= 0) return;
   Double_t step = (xmax - xmin)/num_gridpoints;
   for (Int_t j=0; j < num_gridpoints; j++){
      cut_points[j] = xmin + (j + 0.5)*step;
   }
}
#if USE_HELGESCODE==1
Double_t TMVA::DecisionTree::TrainNode(vector<TMVA::Event*> & eventSample,
                                       TMVA::DecisionTreeNode *node)
{
   // Find the best cut for node on the events in eventSample.
   // For every variable, fNCuts cut values are placed at the centres of
   // fNCuts equal bins between the variable's minimum and maximum in the
   // sample. For each cut, the weighted signal/background sums of the events
   // above it (value > cut) are accumulated. The (variable, cut) pair with
   // the largest fSepType separation gain is chosen. The node receives the
   // selector (variable index), cut value, cut type and gain, and the gain
   // is added to fVariableImportance of that variable.
   //
   // Returns the best separation gain. Returns 0 without modifying the node
   // when no cut can be evaluated: empty sample, fNvars <= 0, or
   // fNCuts <= 0 (e.g. the default-constructed fNCuts = -1, which used to
   // make vector::resize fail). Also returns 0 when no candidate was
   // accepted, instead of indexing with mxVar = -1.

   UInt_t nevents = eventSample.size();
   if (nevents == 0 || fNvars <= 0 || fNCuts <= 0) return 0;

   // --- range of every variable in this sample (automatic storage: the
   //     former new/delete pair leaked on any early exit)
   vector<Double_t> xmin( fNvars );
   vector<Double_t> xmax( fNvars );
   for (Int_t ivar=0; ivar < fNvars; ivar++){
      xmin[ivar] = xmax[ivar] = eventSample[0]->GetVal(ivar);
   }
   for (UInt_t iev=1; iev<nevents; iev++){
      for (Int_t ivar=0; ivar < fNvars; ivar++){
         Double_t eventData = eventSample[iev]->GetVal(ivar);
         if (xmin[ivar] > eventData) xmin[ivar] = eventData;
         if (xmax[ivar] < eventData) xmax[ivar] = eventData;
      }
   }

   // --- candidate cuts (bin centres) and weighted sums above each cut
   vector< vector<Double_t> > cutValues( fNvars, vector<Double_t>(fNCuts, 0.) );
   vector< vector<Double_t> > nSelS    ( fNvars, vector<Double_t>(fNCuts, 0.) );
   vector< vector<Double_t> > nSelB    ( fNvars, vector<Double_t>(fNCuts, 0.) );
   for (Int_t ivar=0; ivar < fNvars; ivar++){
      Double_t istepSize = ( xmax[ivar] - xmin[ivar] ) / Double_t(fNCuts);
      for (Int_t icut=0; icut<fNCuts; icut++){
         cutValues[ivar][icut] = xmin[ivar] + (Float_t(icut)+0.5)*istepSize;
      }
   }

   Double_t nTotS=0, nTotB=0;
#if USE_HELGE_V1==1
   // variant 1: outer loop over variables, inner loop over events
   for (Int_t ivar=0; ivar < fNvars; ivar++){
      for (UInt_t iev=0; iev<nevents; iev++){
         Double_t eventData   = eventSample[iev]->GetVal(ivar);
         Int_t    eventType   = eventSample[iev]->Type();
         Double_t eventWeight = eventSample[iev]->GetWeight();
         if (ivar==0){
            if (eventType==1) nTotS+=eventWeight;
            else              nTotB+=eventWeight;
         }
         for (Int_t icut=0; icut<fNCuts; icut++){
            if (eventData > cutValues[ivar][icut]){
               if (eventType==1) nSelS[ivar][icut]+=eventWeight;
               else              nSelB[ivar][icut]+=eventWeight;
            }
         }
      }
   }
#else
   // variant 2: outer loop over events, inner loop over variables
   for (UInt_t iev=0; iev<nevents; iev++){
      Int_t    eventType   = eventSample[iev]->Type();
      Double_t eventWeight = eventSample[iev]->GetWeight();
      if (eventType==1) nTotS+=eventWeight;
      else              nTotB+=eventWeight;
      for (Int_t ivar=0; ivar < fNvars; ivar++){
         Double_t eventData = eventSample[iev]->GetVal(ivar);
         for (Int_t icut=0; icut<fNCuts; icut++){
            if (eventData > cutValues[ivar][icut]){
               if (eventType==1) nSelS[ivar][icut]+=eventWeight;
               else              nSelB[ivar][icut]+=eventWeight;
            }
         }
      }
   }
#endif

   // --- pick the (variable, cut) with the largest separation gain
   Double_t separationGain = -1;
   Int_t mxVar=-1, cutIndex=0;
   for (Int_t ivar=0; ivar < fNvars; ivar++){
      for (Int_t icut=0; icut<fNCuts; icut++){
         Double_t sepTmp = fSepType->GetSeparationGain(nSelS[ivar][icut], nSelB[ivar][icut], nTotS, nTotB);
         if (separationGain < sepTmp) {
            separationGain = sepTmp;
            mxVar = ivar;
            cutIndex = icut;
         }
      }
   }
   if (mxVar < 0) return 0;   // no candidate accepted (e.g. NaN gains)

   // kTRUE if the fraction of all signal weight above the cut exceeds the
   // fraction of all background weight above it
   Bool_t cutType = (nSelS[mxVar][cutIndex]/nTotS > nSelB[mxVar][cutIndex]/nTotB) ? kTRUE : kFALSE;
   Double_t cutValue = cutValues[mxVar][cutIndex];

   node->SetSelector((UInt_t)mxVar);
   node->SetCutValue(cutValue);
   node->SetCutType(cutType);
   node->SetSeparationGain(separationGain);
   fVariableImportance[mxVar] += separationGain*separationGain * (nTotS+nTotB) * (nTotS+nTotB) ;
   return separationGain;
}
#else
Double_t TMVA::DecisionTree::TrainNode(vector<TMVA::Event*> & eventSample,
                                       TMVA::DecisionTreeNode *node)
{
   // Alternative (Doug's) best-cut search, compiled only when
   // USE_HELGESCODE != 1. Same contract as the other TrainNode: the best
   // (variable, cut) by fSepType separation gain is stored in node and its
   // gain returned. Returns 0 and leaves the node untouched if no cut can be
   // evaluated (empty sample, fNvars <= 0, fNCuts <= 0) or none is accepted.
   UInt_t num_events = eventSample.size();
   if (num_events == 0 || fNvars <= 0 || fNCuts <= 0) return 0;

   vector<Double_t> xmin ( fNvars );
   vector<Double_t> xmax ( fNvars );
   this->FindMinAndMax(eventSample, xmin, xmax);

   // weighted signal/background sums above each cut point, per variable
   vector<vector<Double_t> > signal_counts     (fNvars);
   vector<vector<Double_t> > background_counts (fNvars);
   vector<vector<Double_t> > cut_points        (fNvars);
   for (Int_t i=0; i < fNvars; i++){
      signal_counts[i].resize(fNCuts);
      background_counts[i].resize(fNCuts);
      cut_points[i].resize(fNCuts);
      this->SetCutPoints(cut_points[i], xmin[i], xmax[i], fNCuts);
   }

   Double_t nTotS=0., nTotB=0.;
   for (UInt_t event=0; event < num_events; event++){
      // same Event accessors as the active TrainNode (Type(), GetVal())
      Int_t    event_type   = eventSample[event]->Type();
      Double_t event_weight = eventSample[event]->GetWeight();
      if (event_type == 1) nTotS += event_weight;
      else                 nTotB += event_weight;
      for (Int_t variable = 0; variable < fNvars; variable++){
         Double_t event_val = eventSample[event]->GetVal(variable);
         for (Int_t cut=0; cut < fNCuts; cut++){
            if (event_val > cut_points[variable][cut]){
               if (event_type == 1) signal_counts[variable][cut]     += event_weight;
               else                 background_counts[variable][cut] += event_weight;
            }
         }
      }
   }

   Double_t separationGain = -1;
   Double_t cutValue = -999;
   Int_t    mxVar = -1;
   Bool_t   cutType = kTRUE;
   for (Int_t var = 0; var < fNvars; var++){
      for (Int_t cut=0; cut < fNCuts; cut++){
         Double_t cur_sep = fSepType->GetSeparationGain(signal_counts[var][cut],
                                                        background_counts[var][cut],
                                                        nTotS, nTotB);
         if (separationGain < cur_sep) {
            separationGain = cur_sep;
            cutValue = cut_points[var][cut];
            // cut type from the fractions selected by THIS cut; the former
            // nSelS/nSelB locals were never filled (always 0)
            cutType = (signal_counts[var][cut]/nTotS > background_counts[var][cut]/nTotB) ? kTRUE : kFALSE;
            mxVar = var;
         }
      }
   }
   if (mxVar < 0) return 0;   // no candidate accepted

   node->SetSelector((UInt_t)mxVar);
   node->SetCutValue(cutValue);
   node->SetCutType(cutType);
   node->SetSeparationGain(separationGain);
   fVariableImportance[mxVar] += separationGain*separationGain * (nTotS+nTotB)* (nTotS+nTotB);
   return separationGain;
}
#endif
//_______________________________________________________________________
// Classify event e: descend from the root following the cut decisions
// until a node with node type != 0 (a leaf) is reached. Returns that
// leaf's node type (as Double_t) if UseYesNoLeaf, otherwise its
// GetSoverSB().
Double_t TMVA::DecisionTree::CheckEvent(const TMVA::Event & e, Bool_t UseYesNoLeaf)
{
   TMVA::DecisionTreeNode *leaf = (TMVA::DecisionTreeNode*)this->GetRoot();
   while (leaf->GetNodeType() == 0) {
      if (leaf->GoesRight(e)) leaf = (TMVA::DecisionTreeNode*)leaf->GetRight();
      else                    leaf = (TMVA::DecisionTreeNode*)leaf->GetLeft();
   }
   if (!UseYesNoLeaf) return leaf->GetSoverSB();
   return Double_t( leaf->GetNodeType() );
}
//_______________________________________________________________________
// Weighted signal purity sumsig/(sumsig+sumbkg) of eventSample, counting
// Type()==1 as signal and Type()==0 as background. Returns -1 if the total
// weight is not positive. A fatal message is issued if events of other
// types make the total differ from sumsig+sumbkg.
Double_t TMVA::DecisionTree::SamplePurity(vector<TMVA::Event*> eventSample)
{
   Double_t sumsig=0, sumbkg=0, sumtot=0;
   for (vector<TMVA::Event*>::const_iterator it = eventSample.begin();
        it != eventSample.end(); ++it) {
      Int_t type = (*it)->Type();
      if      (type == 0) sumbkg += (*it)->GetWeight();
      else if (type == 1) sumsig += (*it)->GetWeight();
      sumtot += (*it)->GetWeight();
   }
   if (sumtot != (sumsig+sumbkg)){
      fLogger << kFATAL << "<SamplePurity> sumtot != sumsig+sumbkg"
              << sumtot << " " << sumsig << " " << sumbkg << Endl;
   }
   return (sumtot > 0) ? sumsig/(sumsig + sumbkg) : -1;
}
//_______________________________________________________________________
// Relative importance of each of the fNvars variables: the accumulated
// fVariableImportance entries normalised to sum 1.
// All zeros are returned if nothing was accumulated (tree without splits),
// instead of 0/0 = NaN. Entries beyond fVariableImportance's size (e.g. a
// tree that was never built) are 0 instead of being read out of bounds.
vector< Double_t > TMVA::DecisionTree::GetVariableImportance()
{
   vector<Double_t> relativeImportance(fNvars, 0.);
   const Int_t nFilled = std::min<Int_t>(fNvars, fVariableImportance.size());

   Double_t sum=0;
   for (Int_t i=0; i < nFilled; i++) {
      sum += fVariableImportance[i];
      relativeImportance[i] = fVariableImportance[i];
   }
   if (sum > 0) {
      for (Int_t i=0; i < nFilled; i++) relativeImportance[i] /= sum;
   }
   return relativeImportance;
}
//_______________________________________________________________________
// Relative importance of variable ivar (see GetVariableImportance()).
// An out-of-range ivar gives a fatal message and returns -1.
Double_t TMVA::DecisionTree::GetVariableImportance(Int_t ivar)
{
   if (ivar < 0 || ivar >= fNvars) {
      fLogger << kFATAL << "<GetVariableImportance>" << Endl
              << "--- ivar = " << ivar << " is out of range " << Endl;
      return -1;
   }
   vector<Double_t> all = this->GetVariableImportance();
   return all[ivar];
}
//_______________________________________________________________________
// Create (with new) and return a TH2D named hname that shows the tree:
// every node is filled with its event count at an (x, y) position computed
// by DrawNode, starting with the root at y = 2*fDepth, x = 0.
TH2D* TMVA::DecisionTree::DrawTree(TString hname)
{
   // 2^fDepth possible positions on the deepest level
   ULong_t nbins = 1;
   for (UInt_t level=0; level<fDepth; level++) nbins *= 2;

   const Double_t xmax = 2*fDepth + 0.5;
   const Double_t xmin = -xmax;
   TH2D* hist = new TH2D(hname, "", 2*nbins+1, xmin, xmax,
                         2*fDepth+2, -0.5, 2*fDepth+0.5);
   this->DrawNode( hist, (DecisionTreeNode*)this->GetRoot(), 2*fDepth, 0, Double_t(fDepth) );
   return hist;
}
//_______________________________________________________________________
// Fill h at (x, y) with the event count of n, then draw its daughters one
// row lower (y-2) and shifted by -scale (left) / +scale (right), halving
// the shift at every level. A NULL node, e.g. the root of an empty tree
// passed in by DrawTree, is ignored instead of being dereferenced.
void TMVA::DecisionTree::DrawNode( TH2D* h, DecisionTreeNode *n,
                                   Double_t y, Double_t x, Double_t scale)
{
   if (n == NULL) return;
   this->DrawNode( h, this->GetLeftDaughter(n),  y-2, x-scale, scale/2. );
   this->DrawNode( h, this->GetRightDaughter(n), y-2, x+scale, scale/2. );
   h->Fill(x,y,n->GetNEvents());
}
// ROOT page - Class index - Class Hierarchy - Top of the page
// This page has been automatically generated. If you have any comments or suggestions about the page layout, send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.