#include <iostream>
#include <algorithm>
#include <vector>
#include <limits>
#include <fstream>
#include <algorithm>
#include <cassert>
#include "TRandom3.h"
#include "TMath.h"
#include "TMatrix.h"
#include "TMVA/MsgLogger.h"
#include "TMVA/DecisionTree.h"
#include "TMVA/DecisionTreeNode.h"
#include "TMVA/BinarySearchTree.h"
#include "TMVA/Tools.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/CrossEntropy.h"
#include "TMVA/MisClassificationError.h"
#include "TMVA/SdivSqrtSplusB.h"
#include "TMVA/Event.h"
#include "TMVA/BDTEventWrapper.h"
#include "TMVA/IPruneTool.h"
#include "TMVA/CostComplexityPruneTool.h"
#include "TMVA/ExpectedErrorPruneTool.h"
// Seed used for the private TRandom3 instance created by the copy constructor.
const Int_t TMVA::DecisionTree::fgRandomSeed = 0;
using std::vector;
// ROOT class-implementation macro for TMVA::DecisionTree.
ClassImp(TMVA::DecisionTree)
// Default constructor: an empty classification tree without separation
// criterion, regression criterion or random generator (all pointers NULL).
// CreateFromXML uses it before reading the tree structure back with ReadXML.
TMVA::DecisionTree::DecisionTree():
BinaryTree(),
fNvars (0),
fNCuts (-1),
fUseFisherCuts (kFALSE),
fMinLinCorrForFisher (1),
fUseExclusiveVars (kTRUE),
fSepType (NULL),
fRegType (NULL),
fMinSize (0),
fMinNodeSize (1),
fMinSepGain (0),
fUseSearchTree(kFALSE),
fPruneStrength(0),
fPruneMethod (kNoPruning),
fNNodesBeforePruning(0),
fNodePurityLimit(0.5),
fRandomisedTree (kFALSE),
fUseNvars (0),
fUsePoissonNvars(kFALSE),
fMyTrandom (NULL),
fMaxDepth (999999),
fSigClass (0),
fTreeID (0),
fAnalysisType (Types::kClassification)
{
}
// Training constructor.
//   sepType         separation criterion; NULL selects regression mode, in which
//                   the tree creates (and owns) a RegressionVariance criterion
//   minSize         minimal node size in percent of the training sample
//   nCuts           number of grid points for the cut scan (<=0: full scan)
//   cls             class index treated as "signal"
//   iSeed           seed for the tree's private TRandom3 generator
// In regression mode a non-positive nCuts is replaced by a 200-point grid,
// because the non-grid scan is not available there.
TMVA::DecisionTree::DecisionTree( TMVA::SeparationBase *sepType, Float_t minSize, Int_t nCuts, UInt_t cls,
                                  Bool_t randomisedTree, Int_t useNvars, Bool_t usePoissonNvars,
                                  UInt_t nMaxDepth, Int_t iSeed, Float_t purityLimit, Int_t treeID):
   BinaryTree(),
   fNvars          (0),
   fNCuts          (nCuts),
   fUseFisherCuts  (kFALSE),
   fMinLinCorrForFisher (1),
   fUseExclusiveVars (kTRUE),
   fSepType        (sepType),
   fRegType        (NULL),
   fMinSize        (0),
   fMinNodeSize    (minSize),
   fMinSepGain     (0),
   fUseSearchTree  (kFALSE),
   fPruneStrength  (0),
   fPruneMethod    (kNoPruning),
   fNNodesBeforePruning(0),
   fNodePurityLimit(purityLimit),
   fRandomisedTree (randomisedTree),
   fUseNvars       (useNvars),
   fUsePoissonNvars(usePoissonNvars),
   fMyTrandom      (new TRandom3(iSeed)),
   fMaxDepth       (nMaxDepth),
   fSigClass       (cls),
   fTreeID         (treeID),
   fAnalysisType   (Types::kClassification)
{
   const Bool_t isClassification = (sepType != NULL);
   if (isClassification) {
      fAnalysisType = Types::kClassification;
      return;
   }

   // regression: no separation criterion was given
   fAnalysisType = Types::kRegression;
   fRegType = new RegressionVariance();

   if (nCuts > 0) return;   // a proper grid was requested, nothing to fix

   fNCuts = 200;
   Log() << kWARNING << " You had choosen the training mode using optimal cuts, not\n"
         << " based on a grid of " << fNCuts << " by setting the option NCuts < 0\n"
         << " as this doesn't exist yet, I set it to " << fNCuts << " and use the grid"
         << Endl;
}
// Copy constructor: deep-copies the node structure of 'd' and its settings.
// The separation criterion fSepType is shared (it is not deleted by the
// destructor). The regression criterion fRegType is owned and deleted by the
// destructor, so the copy gets its own RegressionVariance instead of sharing
// d's pointer; sharing it would double-delete it. The copy also gets a fresh
// TRandom3 seeded with fgRandomSeed.
TMVA::DecisionTree::DecisionTree( const DecisionTree &d ):
   BinaryTree(),
   fNvars      (d.fNvars),
   fNCuts      (d.fNCuts),
   fUseFisherCuts  (d.fUseFisherCuts),
   fMinLinCorrForFisher (d.fMinLinCorrForFisher),
   fUseExclusiveVars (d.fUseExclusiveVars),
   fSepType    (d.fSepType),
   fRegType    (d.fRegType ? new RegressionVariance() : NULL),
   fMinSize    (d.fMinSize),
   fMinNodeSize(d.fMinNodeSize),
   fMinSepGain (d.fMinSepGain),
   fUseSearchTree  (d.fUseSearchTree),
   fPruneStrength  (d.fPruneStrength),
   fPruneMethod    (d.fPruneMethod),
   fNNodesBeforePruning(d.fNNodesBeforePruning),
   fNodePurityLimit(d.fNodePurityLimit),
   fRandomisedTree (d.fRandomisedTree),
   fUseNvars       (d.fUseNvars),
   fUsePoissonNvars(d.fUsePoissonNvars),
   fMyTrandom      (new TRandom3(fgRandomSeed)),
   fMaxDepth       (d.fMaxDepth),
   fSigClass       (d.fSigClass),
   fTreeID         (d.fTreeID),
   fAnalysisType   (d.fAnalysisType)
{
   // an empty source tree has no root to copy
   if (d.GetRoot() != NULL) {
      this->SetRoot( new TMVA::DecisionTreeNode ( *((DecisionTreeNode*)(d.GetRoot())) ) );
      this->SetParentTreeInNodes();
   }
   fNNodes = d.fNNodes;
}
// Destructor: releases the private random generator and the regression
// criterion. fSepType is not deleted here.
TMVA::DecisionTree::~DecisionTree()
{
   // deleting a NULL pointer is a no-op, so no guards are required
   delete fMyTrandom;
   delete fRegType;
}
// Recursively sets 'this' as parent tree of node n and all nodes below it
// (n == NULL starts at the root) and updates the total tree depth on the way.
// A node with exactly one daughter is reported as a fatal inconsistency.
void TMVA::DecisionTree::SetParentTreeInNodes( Node *n )
{
   if (n == NULL) n = this->GetRoot();
   if (n == NULL) {
      Log() << kFATAL << "SetParentTreeNodes: started with undefined ROOT node" <<Endl;
      return ;
   }

   Node* leftDaughter  = this->GetLeftDaughter(n);
   Node* rightDaughter = this->GetRightDaughter(n);
   const Bool_t hasLeft  = (leftDaughter  != NULL);
   const Bool_t hasRight = (rightDaughter != NULL);

   if (hasLeft != hasRight) {
      Log() << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
      return;
   }

   if (hasLeft) {
      this->SetParentTreeInNodes( leftDaughter );
      this->SetParentTreeInNodes( rightDaughter );
   }

   n->SetParentTree(this);
   if (n->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(n->GetDepth());
}
// Creates a new (heap-allocated) tree and fills it from the XML node via
// ReadXML. The caller takes ownership of the returned tree. The "type"
// attribute is read into a local string that is not used further here.
TMVA::DecisionTree* TMVA::DecisionTree::CreateFromXML(void* node, UInt_t tmva_Version_Code ) {
   std::string type;
   gTools().ReadAttr(node, "type", type);

   DecisionTree* tree = new DecisionTree();
   tree->ReadXML( node, tmva_Version_Code );
   return tree;
}
// Recursively grows the tree below 'node' from the events in 'eventSample'.
// If 'node' is NULL, a new root node is created, fNNodes is reset to 1 and the
// absolute minimal node size fMinSize is derived from fMinNodeSize (given in
// percent of the training-sample size). Each call fills the statistics of
// 'node'. It then either splits the node, recursing into the right and then
// the left daughter, or finalises it as a leaf.
// Returns the running node count fNNodes.
UInt_t TMVA::DecisionTree::BuildTree( const std::vector<const TMVA::Event*> & eventSample,
TMVA::DecisionTreeNode *node)
{
if (node==NULL) {
node = new TMVA::DecisionTreeNode();
fNNodes = 1;
this->SetRoot(node);
this->GetRoot()->SetPos('s');
this->GetRoot()->SetDepth(0);
this->GetRoot()->SetParentTree(this);
// fMinNodeSize is a percentage: convert it to an absolute event count
fMinSize = fMinNodeSize/100. * eventSample.size();
if (GetTreeID()==0){
Log() << kINFO << "The minimal node size MinNodeSize=" << fMinNodeSize << " fMinNodeSize="<<fMinNodeSize<< "% is translated to an actual number of events = "<< fMinSize<< " for the training sample size of " << eventSample.size() << Endl;
Log() << kINFO << "Note: This number will be taken as absolute minimum in the node, " << Endl;
Log() << kINFO << " in terms of 'weighted events' and unweighted ones !! " << Endl;
}
}
UInt_t nevents = eventSample.size();
if (nevents > 0 ) {
// the number of input variables is taken from the first event on first use
if (fNvars==0) fNvars = eventSample[0]->GetNVariables();
fVariableImportance.resize(fNvars);
}
else Log() << kFATAL << ":<BuildTree> eventsample Size == 0 " << Endl;
// s/b: weighted, suw/buw: unweighted (counts), sub/bub: original (unboosted) weights
Double_t s=0, b=0;
Double_t suw=0, buw=0;
Double_t sub=0, bub=0;
// weighted first and second moments of the regression target
Double_t target=0, target2=0;
Float_t *xmin = new Float_t[fNvars];
Float_t *xmax = new Float_t[fNvars];
for (UInt_t ivar=0; ivar<fNvars; ivar++) {
xmin[ivar]=xmax[ivar]=0;
}
// one pass over the sample: class sums, target moments and per-variable range
for (UInt_t iev=0; iev<eventSample.size(); iev++) {
const TMVA::Event* evt = eventSample[iev];
const Double_t weight = evt->GetWeight();
const Double_t orgWeight = evt->GetOriginalWeight();
if (evt->GetClass() == fSigClass) {
s += weight;
suw += 1;
sub += orgWeight;
}
else {
b += weight;
buw += 1;
bub += orgWeight;
}
if ( DoRegression() ) {
const Double_t tgt = evt->GetTarget(0);
target +=weight*tgt;
target2+=weight*tgt*tgt;
}
for (UInt_t ivar=0; ivar<fNvars; ivar++) {
const Double_t val = evt->GetValue(ivar);
if (iev==0) xmin[ivar]=xmax[ivar]=val;
if (val < xmin[ivar]) xmin[ivar]=val;
if (val > xmax[ivar]) xmax[ivar]=val;
}
}
// a negative total weight can only come from negative event weights; only warn
// (and dump background weights at debug level)
if (s+b < 0) {
Log() << kWARNING << " One of the Decision Tree nodes has negative total number of signal or background events. "
<< "(Nsig="<<s<<" Nbkg="<<b<<" Probaby you use a Monte Carlo with negative weights. That should in principle "
<< "be fine as long as on average you end up with something positive. For this you have to make sure that the "
<< "minimul number of (unweighted) events demanded for a tree node (currently you use: MinNodeSize="<<fMinNodeSize
<< "% of training events, you can set this via the BDT option string when booking the classifier) is large enough "
<< "to allow for reasonable averaging!!!" << Endl
<< " If this does not help.. maybe you want to try the option: NoNegWeightsInTraining which ignores events "
<< "with negative weight in the training." << Endl;
double nBkg=0.;
for (UInt_t i=0; i<eventSample.size(); i++) {
if (eventSample[i]->GetClass() != fSigClass) {
nBkg += eventSample[i]->GetWeight();
Log() << kDEBUG << "Event "<< i<< " has (original) weight: " << eventSample[i]->GetWeight()/eventSample[i]->GetBoostWeight()
<< " boostWeight: " << eventSample[i]->GetBoostWeight() << Endl;
}
}
Log() << kDEBUG << " that gives in total: " << nBkg<<Endl;
}
node->SetNSigEvents(s);
node->SetNBkgEvents(b);
node->SetNSigEvents_unweighted(suw);
node->SetNBkgEvents_unweighted(buw);
node->SetNSigEvents_unboosted(sub);
node->SetNBkgEvents_unboosted(bub);
node->SetPurity();
// daughters get their event totals from the parent below; only the root sets them here
if (node == this->GetRoot()) {
node->SetNEvents(s+b);
node->SetNEvents_unweighted(suw+buw);
node->SetNEvents_unboosted(sub+bub);
}
for (UInt_t ivar=0; ivar<fNvars; ivar++) {
node->SetSampleMin(ivar,xmin[ivar]);
node->SetSampleMax(ivar,xmax[ivar]);
}
delete[] xmin;
delete[] xmax;
// try to split only if both daughters could reach fMinSize, the depth limit is
// not reached, and (classification) both classes are present or
// (regression) the total weight is non-zero
if ((eventSample.size() >= 2*fMinSize && s+b >= 2*fMinSize) && node->GetDepth() < fMaxDepth
&& ( ( s!=0 && b !=0 && !DoRegression()) || ( (s+b)!=0 && DoRegression()) ) ) {
Double_t separationGain;
// fNCuts > 0: scan a grid of cut values; otherwise do the full scan
if (fNCuts > 0){
separationGain = this->TrainNodeFast(eventSample, node);
} else {
separationGain = this->TrainNodeFull(eventSample, node);
}
// no useful cut found: finalise this node as a leaf
if (separationGain < std::numeric_limits<double>::epsilon()) {
if (DoRegression()) {
node->SetSeparationIndex(fRegType->GetSeparationIndex(s+b,target,target2));
node->SetResponse(target/(s+b));
// guard the sqrt against rounding making the variance slightly negative
if( (target2/(s+b) - target/(s+b)*target/(s+b)) < std::numeric_limits<double>::epsilon() ){
node->SetRMS(0);
}else{
node->SetRMS(TMath::Sqrt(target2/(s+b) - target/(s+b)*target/(s+b)));
}
}
else {
node->SetSeparationIndex(fSepType->GetSeparationIndex(s,b));
// leaf type: +1 signal-like, -1 background-like
if (node->GetPurity() > fNodePurityLimit) node->SetNodeType(1);
else node->SetNodeType(-1);
}
if (node->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(node->GetDepth());
} else {
// partition the events according to the cut stored in 'node' by TrainNode*
std::vector<const TMVA::Event*> leftSample; leftSample.reserve(nevents);
std::vector<const TMVA::Event*> rightSample; rightSample.reserve(nevents);
Double_t nRight=0, nLeft=0;
Double_t nRightUnBoosted=0, nLeftUnBoosted=0;
for (UInt_t ie=0; ie< nevents ; ie++) {
if (node->GoesRight(*eventSample[ie])) {
rightSample.push_back(eventSample[ie]);
nRight += eventSample[ie]->GetWeight();
nRightUnBoosted += eventSample[ie]->GetOriginalWeight();
}
else {
leftSample.push_back(eventSample[ie]);
nLeft += eventSample[ie]->GetWeight();
nLeftUnBoosted += eventSample[ie]->GetOriginalWeight();
}
}
if (leftSample.empty() || rightSample.empty()) {
Log() << kFATAL << "<TrainNode> all events went to the same branch" << Endl
<< "--- Hence new node == old node ... check" << Endl
<< "--- left:" << leftSample.size()
<< " right:" << rightSample.size() << Endl
<< "--- this should never happen, please write a bug report to Helge.Voss@cern.ch"
<< Endl;
}
TMVA::DecisionTreeNode *rightNode = new TMVA::DecisionTreeNode(node,'r');
fNNodes++;
rightNode->SetNEvents(nRight);
rightNode->SetNEvents_unboosted(nRightUnBoosted);
rightNode->SetNEvents_unweighted(rightSample.size());
TMVA::DecisionTreeNode *leftNode = new TMVA::DecisionTreeNode(node,'l');
fNNodes++;
leftNode->SetNEvents(nLeft);
leftNode->SetNEvents_unboosted(nLeftUnBoosted);
leftNode->SetNEvents_unweighted(leftSample.size());
// node type 0 marks an intermediate (split) node
node->SetNodeType(0);
node->SetLeft(leftNode);
node->SetRight(rightNode);
this->BuildTree(rightSample, rightNode);
this->BuildTree(leftSample, leftNode );
}
}
else{
// splitting not allowed: finalise this node as a leaf
if (DoRegression()) {
node->SetSeparationIndex(fRegType->GetSeparationIndex(s+b,target,target2));
node->SetResponse(target/(s+b));
if( (target2/(s+b) - target/(s+b)*target/(s+b)) < std::numeric_limits<double>::epsilon() ) {
node->SetRMS(0);
}else{
node->SetRMS(TMath::Sqrt(target2/(s+b) - target/(s+b)*target/(s+b)));
}
}
else {
node->SetSeparationIndex(fSepType->GetSeparationIndex(s,b));
if (node->GetPurity() > fNodePurityLimit) node->SetNodeType(1);
else node->SetNodeType(-1);
}
if (node->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(node->GetDepth());
}
return fNNodes;
}
// Passes every event of the sample through the tree from the root
// (see FillEvent), accumulating the node statistics.
void TMVA::DecisionTree::FillTree( const std::vector<TMVA::Event*> & eventSample )
{
   typedef std::vector<TMVA::Event*>::const_iterator EventIter;
   for (EventIter it = eventSample.begin(); it != eventSample.end(); ++it) {
      this->FillEvent(**it, NULL);
   }
}
// Adds one event to the statistics of 'node' (NULL means the root) and
// recomputes its separation index. If the node is an intermediate node
// (type 0), the event continues into the daughter selected by its cut.
void TMVA::DecisionTree::FillEvent( const TMVA::Event & event,
                                    TMVA::DecisionTreeNode *node )
{
   if (node == NULL) node = this->GetRoot();

   node->IncrementNEvents( event.GetWeight() );
   node->IncrementNEvents_unweighted( );

   const Bool_t isSignal = (event.GetClass() == fSigClass);
   if (isSignal) {
      node->IncrementNSigEvents( event.GetWeight() );
      node->IncrementNSigEvents_unweighted( );
   } else {
      node->IncrementNBkgEvents( event.GetWeight() );
      node->IncrementNBkgEvents_unweighted( );
   }
   node->SetSeparationIndex(fSepType->GetSeparationIndex(node->GetNSigEvents(),
                                                         node->GetNBkgEvents()));

   if (node->GetNodeType() != 0) return;   // leaf: nothing further to fill

   TMVA::Node* next = node->GoesRight(event) ? static_cast<TMVA::Node*>(node->GetRight())
                                             : static_cast<TMVA::Node*>(node->GetLeft());
   this->FillEvent(event, dynamic_cast<TMVA::DecisionTreeNode*>(next));
}
void TMVA::DecisionTree::ClearTree()
{
if (this->GetRoot()!=NULL) this->GetRoot()->ClearNodeAndAllDaughters();
}
// Bottom-up removal of useless splits below 'node' (NULL means the root).
// An intermediate node (type 0) is collapsed with PruneNode when, after its
// daughters have been cleaned, both daughters have node types of the same
// sign. Returns the result of CountNodes().
UInt_t TMVA::DecisionTree::CleanTree( DecisionTreeNode *node )
{
   if (node == NULL) node = this->GetRoot();

   DecisionTreeNode *leftChild  = node->GetLeft();
   DecisionTreeNode *rightChild = node->GetRight();

   const Bool_t isIntermediate = (node->GetNodeType() == 0);
   if (isIntermediate) {
      this->CleanTree(leftChild);
      this->CleanTree(rightChild);
      const Bool_t sameVerdict = (leftChild->GetNodeType() * rightChild->GetNodeType() > 0);
      if (sameVerdict) this->PruneNode(node);
   }
   return this->CountNodes();
}
// Prunes the tree with the tool matching fPruneMethod (expected-error or
// cost-complexity pruning). The nodes listed in the tool's prune sequence are
// pruned with PruneNode. Tools flagged IsAutomatic() need a non-empty
// validation sample. Returns the prune strength reported by the tool
// (0 when no pruning is configured or no tool is available).
Double_t TMVA::DecisionTree::PruneTree( const EventConstList* validationSample )
{
   if (fPruneMethod == kNoPruning) return 0.0;

   IPruneTool* tool = NULL;
   switch (fPruneMethod) {
   case kExpectedErrorPruning:
      tool = new ExpectedErrorPruneTool();
      break;
   case kCostComplexityPruning:
      tool = new CostComplexityPruneTool();
      break;
   default:
      Log() << kFATAL << "Selected pruning method not yet implemented "
            << Endl;
      break;
   }
   if (tool == NULL) return 0.0;

   tool->SetPruneStrength(GetPruneStrength());

   if (tool->IsAutomatic()) {
      if (validationSample == NULL) {
         Log() << kFATAL << "Cannot automate the pruning algorithm without an "
               << "independent validation sample!" << Endl;
      } else if (validationSample->empty()) {
         Log() << kFATAL << "Cannot automate the pruning algorithm with "
               << "independent validation sample of ZERO events!" << Endl;
      }
   }

   PruningInfo* info = tool->CalculatePruningInfo(this, validationSample);
   Double_t pruneStrength = 0;
   if (info == NULL) {
      Log() << kFATAL << "Error pruning tree! Check prune.log for more information."
            << Endl;
   } else {
      pruneStrength = info->PruneStrength;
      const UInt_t nToPrune = info->PruneSequence.size();
      for (UInt_t k = 0; k < nToPrune; ++k) PruneNode(info->PruneSequence[k]);
      this->CountNodes();
   }

   delete tool;
   delete info;
   return pruneStrength;
}
void TMVA::DecisionTree::ApplyValidationSample( const EventConstList* validationSample ) const
{
GetRoot()->ResetValidationData();
for (UInt_t ievt=0; ievt < validationSample->size(); ievt++) {
CheckEventWithPrunedTree((*validationSample)[ievt]);
}
}
// Sums a validation-error measure over all terminal nodes below n
// (NULL means the root). Recursion stops at nodes flagged IsTerminal().
//   regression: weighted squared deviation of the validation targets from
//               the node response
//   mode 0:     misclassified validation weight (background in signal-like
//               leaves, signal in background-like leaves)
//   mode 1:     purity-weighted misclassification
// Any other mode throws std::string("Unknown ValidationQualityMode").
Double_t TMVA::DecisionTree::TestPrunedTreeQuality( const DecisionTreeNode* n, Int_t mode ) const
{
   if (n == NULL) n = this->GetRoot();
   if (n == NULL) {
      Log() << kFATAL << "TestPrunedTreeQuality: started with undefined ROOT node" <<Endl;
      return 0;
   }

   const Bool_t isInnerNode = n->GetLeft() != NULL && n->GetRight() != NULL && !n->IsTerminal();
   if (isInnerNode) {
      return (TestPrunedTreeQuality( n->GetLeft(), mode ) +
              TestPrunedTreeQuality( n->GetRight(), mode ));
   }

   if (DoRegression()) {
      Double_t sumw = n->GetNSValidation() + n->GetNBValidation();
      return n->GetSumTarget2() - 2*n->GetSumTarget()*n->GetResponse() + sumw*n->GetResponse()*n->GetResponse();
   }

   switch (mode) {
   case 0:
      return (n->GetPurity() > this->GetNodePurityLimit()) ? n->GetNBValidation()
                                                           : n->GetNSValidation();
   case 1:
      return (n->GetPurity() * n->GetNBValidation() + (1.0 - n->GetPurity()) * n->GetNSValidation());
   default:
      throw std::string("Unknown ValidationQualityMode");
   }
}
void TMVA::DecisionTree::CheckEventWithPrunedTree( const Event* e ) const
{
DecisionTreeNode* current = this->GetRoot();
if (current == NULL) {
Log() << kFATAL << "CheckEventWithPrunedTree: started with undefined ROOT node" <<Endl;
}
while(current != NULL) {
if(e->GetClass() == fSigClass)
current->SetNSValidation(current->GetNSValidation() + e->GetWeight());
else
current->SetNBValidation(current->GetNBValidation() + e->GetWeight());
if (e->GetNTargets() > 0) {
current->AddToSumTarget(e->GetWeight()*e->GetTarget(0));
current->AddToSumTarget2(e->GetWeight()*e->GetTarget(0)*e->GetTarget(0));
}
if (current->GetRight() == NULL || current->GetLeft() == NULL) {
current = NULL;
}
else {
if (current->GoesRight(*e))
current = (TMVA::DecisionTreeNode*)current->GetRight();
else
current = (TMVA::DecisionTreeNode*)current->GetLeft();
}
}
}
// Returns the sum of GetWeight() over all events of the given sample.
Double_t TMVA::DecisionTree::GetSumWeights( const EventConstList* validationSample ) const
{
   Double_t total = 0.0;
   const EventConstList::size_type nEvents = validationSample->size();
   for (EventConstList::size_type i = 0; i < nEvents; ++i) {
      total += (*validationSample)[i]->GetWeight();
   }
   return total;
}
// Counts the nodes without any daughter below n (NULL means the root).
UInt_t TMVA::DecisionTree::CountLeafNodes( TMVA::Node *n )
{
   if (n == NULL) n = this->GetRoot();
   if (n == NULL) {
      Log() << kFATAL << "CountLeafNodes: started with undefined ROOT node" <<Endl;
      return 0;
   }

   TMVA::Node* leftDaughter  = this->GetLeftDaughter(n);
   TMVA::Node* rightDaughter = this->GetRightDaughter(n);
   if (leftDaughter == NULL && rightDaughter == NULL) return 1;

   UInt_t nLeaves = 0;
   if (leftDaughter  != NULL) nLeaves += this->CountLeafNodes( leftDaughter );
   if (rightDaughter != NULL) nLeaves += this->CountLeafNodes( rightDaughter );
   return nLeaves;
}
// Recursively visits every node below n (NULL means the root) and reports a
// node with exactly one daughter as a fatal inconsistency. Nothing is
// modified.
void TMVA::DecisionTree::DescendTree( Node* n )
{
   if (n == NULL) n = this->GetRoot();
   if (n == NULL) {
      Log() << kFATAL << "DescendTree: started with undefined ROOT node" <<Endl;
      return ;
   }

   Node* leftDaughter  = this->GetLeftDaughter(n);
   Node* rightDaughter = this->GetRightDaughter(n);

   if ((leftDaughter == NULL) != (rightDaughter == NULL)) {
      Log() << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
      return;
   }

   if (leftDaughter != NULL) {
      this->DescendTree( leftDaughter );
      this->DescendTree( rightDaughter );
   }
}
// Turns 'node' into a leaf. Its daughters are detached and deleted, the cut
// selector and separation gain are reset to -1, and the leaf type is set from
// the purity (+1 above fNodePurityLimit, else -1). The node count is refreshed.
void TMVA::DecisionTree::PruneNode( DecisionTreeNode* node )
{
   DecisionTreeNode *leftChild  = node->GetLeft();
   DecisionTreeNode *rightChild = node->GetRight();

   // detach first, then delete the former sub-trees
   node->SetRight(NULL);
   node->SetLeft(NULL);
   node->SetSelector(-1);
   node->SetSeparationGain(-1);
   node->SetNodeType( (node->GetPurity() > fNodePurityLimit) ? 1 : -1 );

   this->DeleteNode(leftChild);
   this->DeleteNode(rightChild);
   this->CountNodes();
}
// Marks 'node' as terminal without deleting its daughters. It sets one
// terminal node, copies NodeR to SubTreeR, and sets alpha and the minimal
// sub-tree alpha to +infinity. A NULL node is ignored.
void TMVA::DecisionTree::PruneNodeInPlace( DecisionTreeNode* node ) {
   if (node != NULL) {
      const Double_t infinity = std::numeric_limits<double>::infinity();
      node->SetNTerminal(1);
      node->SetSubTreeR( node->GetNodeR() );
      node->SetAlpha( infinity );
      node->SetAlphaMinSubtree( infinity );
      node->SetTerminal(kTRUE);
   }
}
// Returns the node reached from the root after 'depth' steps. Bit i of
// 'sequence' selects the direction of step i: 1 = right daughter,
// 0 = left daughter. Bits beyond the width of ULong_t are treated as 0.
// Returns NULL if the walk leaves the tree, or if the tree has no root.
TMVA::Node* TMVA::DecisionTree::GetNode( ULong_t sequence, UInt_t depth )
{
   const UInt_t nBits = std::numeric_limits<ULong_t>::digits;
   Node* current = this->GetRoot();
   for (UInt_t i = 0; i < depth && current != NULL; i++) {
      // shift the ULong_t itself: the former 'int(1) << i' overflowed for i >= 31
      const Bool_t goRight = (i < nBits) && (((sequence >> i) & 1UL) != 0);
      current = goRight ? this->GetRightDaughter(current) : this->GetLeftDaughter(current);
   }
   return current;
}
// Randomly selects the subset of variables considered for the next cut.
//   useVariable[ivar]  set kTRUE for selected variables (size >= fNvars)
//   mapVariable[k]     index of the k-th selected variable, in ascending order
//   useNvars           (out) number of selected variables
// If fUseNvars is 0 it defaults to round(sqrt(fNvars)). With fUsePoissonNvars
// the count is drawn from a Poisson distribution around fUseNvars, at least 1.
// In every case the count is capped at fNvars. Without that cap, asking for
// more variables than exist made the selection loop below never terminate.
void TMVA::DecisionTree::GetRandomisedVariables(Bool_t *useVariable, UInt_t *mapVariable, UInt_t &useNvars){
   for (UInt_t ivar=0; ivar<fNvars; ivar++) useVariable[ivar]=kFALSE;
   if (fUseNvars==0) {
      fUseNvars = UInt_t(TMath::Sqrt(fNvars)+0.6);
   }
   if (fUsePoissonNvars) useNvars=TMath::Min(fNvars,TMath::Max(UInt_t(1),(UInt_t) fMyTrandom->Poisson(fUseNvars)));
   else useNvars = TMath::Min(fNvars, UInt_t(fUseNvars));

   UInt_t nSelectedVars = 0;
   while (nSelectedVars < useNvars) {
      // Rndm() is in (0,1); clamp anyway so the index can never reach fNvars
      // (fNvars > 0 here, since useNvars <= fNvars and the loop is entered)
      const UInt_t pick = TMath::Min(UInt_t(fMyTrandom->Rndm()*fNvars), fNvars-1);
      useVariable[pick] = kTRUE;
      // rebuild the ascending index map of all variables selected so far
      nSelectedVars = 0;
      for (UInt_t ivar=0; ivar < fNvars; ivar++) {
         if (useVariable[ivar] == kTRUE) {
            mapVariable[nSelectedVars] = ivar;
            nSelectedVars++;
         }
      }
   }
   if (nSelectedVars != useNvars) { std::cout << "Bug in TrainNode - GetRandisedVariables()... sorry" << std::endl; std::exit(1);}
}
// Finds the best cut for 'node' by scanning a grid of fNCuts equidistant cut
// values between the node's sample minimum and maximum. The scan covers every
// usable input variable and, with fUseFisherCuts, a Fisher discriminant built
// from linearly correlated variables (stored in slot fNvars).
// Only cuts leaving at least fMinSize events (weighted and unweighted) on both
// sides are considered. The best cut (selector, cut value, cut type,
// separation gain and, for Fisher cuts, the coefficients) is stored in 'node',
// and the variable importances are updated. Returns the best separation gain,
// or 0 if no admissible cut was found.
// Compared with earlier versions:
//  - the covariance/correlation matrices used to select Fisher variables are
//    freed (they were leaked for every node);
//  - the Fisher discriminant range starts from +-DBL_MAX instead of +-999, and
//    a degenerate range disables it instead of dividing by zero when binning;
//  - the regression node response/separation index use target sums over all
//    events instead of the histogram of variable 0, which stays empty
//    whenever variable 0 is not used at this node.
Double_t TMVA::DecisionTree::TrainNodeFast( const EventConstList & eventSample,
                                            TMVA::DecisionTreeNode *node )
{
   Double_t separationGainTotal = -1, sepTmp;
   Double_t *separationGain = new Double_t[fNvars+1];
   Int_t *cutIndex = new Int_t[fNvars+1];
   for (UInt_t ivar=0; ivar <= fNvars; ivar++) {
      separationGain[ivar]=-1;
      cutIndex[ivar]=-1;
   }
   Int_t mxVar = -1;
   Bool_t cutType = kTRUE;
   Double_t nTotS, nTotB;
   Int_t nTotS_unWeighted, nTotB_unWeighted;
   // weighted first/second moments of the regression target over all node events
   Double_t nTotTarget = 0, nTotTarget2 = 0;
   UInt_t nevents = eventSample.size();
   Bool_t *useVariable = new Bool_t[fNvars+1];
   UInt_t *mapVariable = new UInt_t[fNvars+1];
   std::vector<Double_t> fisherCoeff;
   if (fRandomisedTree) {
      UInt_t tmp=fUseNvars;
      GetRandomisedVariables(useVariable,mapVariable,tmp);
   }
   else {
      for (UInt_t ivar=0; ivar < fNvars; ivar++) {
         useVariable[ivar] = kTRUE;
         mapVariable[ivar] = ivar;
      }
   }
   // slot fNvars is reserved for the optional Fisher discriminant
   useVariable[fNvars] = kFALSE;
   Bool_t fisherOK = kFALSE;
   if (fUseFisherCuts) {
      useVariable[fNvars] = kTRUE;
      Bool_t *useVarInFisher = new Bool_t[fNvars];
      UInt_t *mapVarInFisher = new UInt_t[fNvars];
      for (UInt_t ivar=0; ivar < fNvars; ivar++) {
         useVarInFisher[ivar] = kFALSE;
         mapVarInFisher[ivar] = ivar;
      }
      std::vector<TMatrixDSym*>* covMatrices;
      covMatrices = gTools().CalcCovarianceMatrices( eventSample, 2 );
      if (!covMatrices){
         Log() << kWARNING << " in TrainNodeFast, the covariance Matrices needed for the Fisher-Cuts returned error --> revert to just normal cuts for this node" << Endl;
         fisherOK = kFALSE;
      }else{
         TMatrixD *ss = new TMatrixD(*(covMatrices->at(0)));
         TMatrixD *bb = new TMatrixD(*(covMatrices->at(1)));
         const TMatrixD *s = gTools().GetCorrelationMatrix(ss);
         const TMatrixD *b = gTools().GetCorrelationMatrix(bb);
         // a variable enters the Fisher discriminant if it is correlated with
         // another one above fMinLinCorrForFisher in signal or background
         for (UInt_t ivar=0; ivar < fNvars; ivar++) {
            for (UInt_t jvar=ivar+1; jvar < fNvars; jvar++) {
               if ( ( TMath::Abs( (*s)(ivar, jvar)) > fMinLinCorrForFisher) ||
                    ( TMath::Abs( (*b)(ivar, jvar)) > fMinLinCorrForFisher) ){
                  useVarInFisher[ivar] = kTRUE;
                  useVarInFisher[jvar] = kTRUE;
               }
            }
         }
         // all of these matrices were allocated for this node only
         delete s;
         delete b;
         delete ss;
         delete bb;
         for (UInt_t imat=0; imat < covMatrices->size(); imat++) delete (*covMatrices)[imat];
         delete covMatrices;

         UInt_t nFisherVars = 0;
         for (UInt_t ivar=0; ivar < fNvars; ivar++) {
            if (useVarInFisher[ivar] && useVariable[ivar]) {
               mapVarInFisher[nFisherVars++]=ivar;
               if (fUseExclusiveVars) useVariable[ivar] = kFALSE;
            }
         }
         fisherCoeff = this->GetFisherCoefficients(eventSample, nFisherVars, mapVarInFisher);
         fisherOK = kTRUE;
      }
      delete [] useVarInFisher;
      delete [] mapVarInFisher;
   }
   const UInt_t nBins = fNCuts+1;
   UInt_t cNvars = fNvars;
   if (fUseFisherCuts && fisherOK) cNvars++;
   // per-variable histograms (cumulated below) of weighted/unweighted class
   // counts and target moments, plus the grid of cut values
   Double_t** nSelS = new Double_t* [cNvars];
   Double_t** nSelB = new Double_t* [cNvars];
   Double_t** nSelS_unWeighted = new Double_t* [cNvars];
   Double_t** nSelB_unWeighted = new Double_t* [cNvars];
   Double_t** target = new Double_t* [cNvars];
   Double_t** target2 = new Double_t* [cNvars];
   Double_t** cutValues = new Double_t* [cNvars];
   for (UInt_t i=0; i<cNvars; i++) {
      nSelS[i] = new Double_t [nBins];
      nSelB[i] = new Double_t [nBins];
      nSelS_unWeighted[i] = new Double_t [nBins];
      nSelB_unWeighted[i] = new Double_t [nBins];
      target[i] = new Double_t [nBins];
      target2[i] = new Double_t [nBins];
      cutValues[i] = new Double_t [nBins];
   }
   Double_t *xmin = new Double_t[cNvars];
   Double_t *xmax = new Double_t[cNvars];
   for (UInt_t ivar=0; ivar < cNvars; ivar++) {
      if (ivar < fNvars){
         xmin[ivar]=node->GetSampleMin(ivar);
         xmax[ivar]=node->GetSampleMax(ivar);
         if (xmax[ivar]-xmin[ivar] < std::numeric_limits<double>::epsilon() ) {
            useVariable[ivar]=kFALSE;
         }
      } else {
         // range of the Fisher discriminant over the events of this node
         xmin[ivar]= std::numeric_limits<double>::max();
         xmax[ivar]=-std::numeric_limits<double>::max();
         for (UInt_t iev=0; iev<nevents; iev++) {
            Double_t result = fisherCoeff[fNvars];
            for (UInt_t jvar=0; jvar<fNvars; jvar++)
               result += fisherCoeff[jvar]*(eventSample[iev])->GetValue(jvar);
            if (result > xmax[ivar]) xmax[ivar]=result;
            if (result < xmin[ivar]) xmin[ivar]=result;
         }
         // a constant discriminant cannot be cut on (and would divide by zero below)
         if (xmax[ivar]-xmin[ivar] < std::numeric_limits<double>::epsilon() ) {
            useVariable[ivar]=kFALSE;
         }
      }
      for (UInt_t ibin=0; ibin<nBins; ibin++) {
         nSelS[ivar][ibin]=0;
         nSelB[ivar][ibin]=0;
         nSelS_unWeighted[ivar][ibin]=0;
         nSelB_unWeighted[ivar][ibin]=0;
         target[ivar][ibin]=0;
         target2[ivar][ibin]=0;
         cutValues[ivar][ibin]=0;
      }
   }
   for (UInt_t ivar=0; ivar < cNvars; ivar++) {
      if ( useVariable[ivar] ) {
         Double_t istepSize =( xmax[ivar] - xmin[ivar] ) / Double_t(nBins);
         for (Int_t icut=0; icut<fNCuts; icut++) {
            cutValues[ivar][icut]=xmin[ivar]+(Double_t(icut+1))*istepSize;
         }
      }
   }
   nTotS=0; nTotB=0;
   nTotS_unWeighted=0; nTotB_unWeighted=0;
   for (UInt_t iev=0; iev<nevents; iev++) {
      Double_t eventWeight = eventSample[iev]->GetWeight();
      if (eventSample[iev]->GetClass() == fSigClass) {
         nTotS+=eventWeight;
         nTotS_unWeighted++;
      }
      else {
         nTotB+=eventWeight;
         nTotB_unWeighted++;
      }
      if (DoRegression()) {
         const Double_t tgt = eventSample[iev]->GetTarget(0);
         nTotTarget  += eventWeight*tgt;
         nTotTarget2 += eventWeight*tgt*tgt;
      }
      Int_t iBin=-1;
      for (UInt_t ivar=0; ivar < cNvars; ivar++) {
         if ( useVariable[ivar] ) {
            Double_t eventData;
            if (ivar < fNvars) eventData = eventSample[iev]->GetValue(ivar);
            else {
               eventData = fisherCoeff[fNvars];
               for (UInt_t jvar=0; jvar<fNvars; jvar++)
                  eventData += fisherCoeff[jvar]*(eventSample[iev])->GetValue(jvar);
            }
            iBin = TMath::Min(Int_t(nBins-1),TMath::Max(0,int (nBins*(eventData-xmin[ivar])/(xmax[ivar]-xmin[ivar]) ) ));
            if (eventSample[iev]->GetClass() == fSigClass) {
               nSelS[ivar][iBin]+=eventWeight;
               nSelS_unWeighted[ivar][iBin]++;
            }
            else {
               nSelB[ivar][iBin]+=eventWeight;
               nSelB_unWeighted[ivar][iBin]++;
            }
            if (DoRegression()) {
               target[ivar][iBin] +=eventWeight*eventSample[iev]->GetTarget(0);
               target2[ivar][iBin]+=eventWeight*eventSample[iev]->GetTarget(0)*eventSample[iev]->GetTarget(0);
            }
         }
      }
   }
   // turn the histograms into cumulative distributions ("events left of cut")
   for (UInt_t ivar=0; ivar < cNvars; ivar++) {
      if (useVariable[ivar]) {
         for (UInt_t ibin=1; ibin < nBins; ibin++) {
            nSelS[ivar][ibin]+=nSelS[ivar][ibin-1];
            nSelS_unWeighted[ivar][ibin]+=nSelS_unWeighted[ivar][ibin-1];
            nSelB[ivar][ibin]+=nSelB[ivar][ibin-1];
            nSelB_unWeighted[ivar][ibin]+=nSelB_unWeighted[ivar][ibin-1];
            if (DoRegression()) {
               target[ivar][ibin] +=target[ivar][ibin-1] ;
               target2[ivar][ibin]+=target2[ivar][ibin-1];
            }
         }
         if (nSelS_unWeighted[ivar][nBins-1] +nSelB_unWeighted[ivar][nBins-1] != eventSample.size()) {
            Log() << kFATAL << "Helge, you have a bug ....nSelS_unw..+nSelB_unw..= "
                  << nSelS_unWeighted[ivar][nBins-1] +nSelB_unWeighted[ivar][nBins-1]
                  << " while eventsample size = " << eventSample.size()
                  << Endl;
         }
         double lastBins=nSelS[ivar][nBins-1] +nSelB[ivar][nBins-1];
         double totalSum=nTotS+nTotB;
         if (TMath::Abs(lastBins-totalSum)/totalSum>0.01) {
            Log() << kFATAL << "Helge, you have another bug ....nSelS+nSelB= "
                  << lastBins
                  << " while total number of events = " << totalSum
                  << Endl;
         }
      }
   }
   // evaluate the separation gain for every admissible cut position
   for (UInt_t ivar=0; ivar < cNvars; ivar++) {
      if (useVariable[ivar]) {
         for (UInt_t iBin=0; iBin<nBins-1; iBin++) {
            Double_t sl = nSelS_unWeighted[ivar][iBin];
            Double_t bl = nSelB_unWeighted[ivar][iBin];
            Double_t s = nTotS_unWeighted;
            Double_t b = nTotB_unWeighted;
            Double_t slW = nSelS[ivar][iBin];
            Double_t blW = nSelB[ivar][iBin];
            Double_t sW = nTotS;
            Double_t bW = nTotB;
            Double_t sr = s-sl;
            Double_t br = b-bl;
            Double_t srW = sW-slW;
            Double_t brW = bW-blW;
            if ( ((sl+bl)>=fMinSize && (sr+br)>=fMinSize)
                 && ((slW+blW)>=fMinSize && (srW+brW)>=fMinSize)
                 ) {
               if (DoRegression()) {
                  sepTmp = fRegType->GetSeparationGain(nSelS[ivar][iBin]+nSelB[ivar][iBin],
                                                       target[ivar][iBin],target2[ivar][iBin],
                                                       nTotS+nTotB,
                                                       target[ivar][nBins-1],target2[ivar][nBins-1]);
               } else {
                  sepTmp = fSepType->GetSeparationGain(nSelS[ivar][iBin], nSelB[ivar][iBin], nTotS, nTotB);
               }
               if (separationGain[ivar] < sepTmp) {
                  separationGain[ivar] = sepTmp;
                  cutIndex[ivar] = iBin;
               }
            }
         }
      }
   }
   for (UInt_t ivar=0; ivar < cNvars; ivar++) {
      if (useVariable[ivar] ) {
         if (separationGainTotal < separationGain[ivar]) {
            separationGainTotal = separationGain[ivar];
            mxVar = ivar;
         }
      }
   }
   if (DoRegression()) {
      const Double_t nTot = nTotS+nTotB;
      node->SetSeparationIndex(fRegType->GetSeparationIndex(nTot,nTotTarget,nTotTarget2));
      node->SetResponse(nTotTarget/nTot);
      const Double_t variance = nTotTarget2/nTot - nTotTarget/nTot*nTotTarget/nTot;
      // guard the sqrt against rounding making the variance slightly negative
      if ( variance < std::numeric_limits<double>::epsilon() ) {
         node->SetRMS(0);
      }else{
         node->SetRMS(TMath::Sqrt(variance));
      }
   }
   else {
      node->SetSeparationIndex(fSepType->GetSeparationIndex(nTotS,nTotB));
      if (mxVar >=0){
         // cut type: kTRUE if the signal fraction left of the cut exceeds the background fraction
         if (nSelS[mxVar][cutIndex[mxVar]]/nTotS > nSelB[mxVar][cutIndex[mxVar]]/nTotB) cutType=kTRUE;
         else cutType=kFALSE;
      }
   }
   if (mxVar >= 0) {
      node->SetSelector((UInt_t)mxVar);
      node->SetCutValue(cutValues[mxVar][cutIndex[mxVar]]);
      node->SetCutType(cutType);
      node->SetSeparationGain(separationGainTotal);
      if (mxVar < (Int_t) fNvars){
         node->SetNFisherCoeff(0);
         fVariableImportance[mxVar] += separationGainTotal*separationGainTotal * (nTotS+nTotB) * (nTotS+nTotB) ;
      }else{
         node->SetNFisherCoeff(fNvars+1);
         for (UInt_t ivar=0; ivar<=fNvars; ivar++) {
            node->SetFisherCoeff(ivar,fisherCoeff[ivar]);
            if (ivar<fNvars){
               fVariableImportance[ivar] += fisherCoeff[ivar]*fisherCoeff[ivar]*separationGainTotal*separationGainTotal * (nTotS+nTotB) * (nTotS+nTotB) ;
            }
         }
      }
   }
   else {
      separationGainTotal = 0;
   }
   for (UInt_t i=0; i<cNvars; i++) {
      delete [] nSelS[i];
      delete [] nSelB[i];
      delete [] nSelS_unWeighted[i];
      delete [] nSelB_unWeighted[i];
      delete [] target[i];
      delete [] target2[i];
      delete [] cutValues[i];
   }
   delete [] nSelS;
   delete [] nSelB;
   delete [] nSelS_unWeighted;
   delete [] nSelB_unWeighted;
   delete [] target;
   delete [] target2;
   delete [] cutValues;
   delete [] xmin;
   delete [] xmax;
   delete [] useVariable;
   delete [] mapVariable;
   delete [] separationGain;
   delete [] cutIndex;
   return separationGainTotal;
}
std::vector<Double_t> TMVA::DecisionTree::GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher){
std::vector<Double_t> fisherCoeff(fNvars+1);
TMatrixD* meanMatx = new TMatrixD( nFisherVars, 3 );
TMatrixD* betw = new TMatrixD( nFisherVars, nFisherVars );
TMatrixD* with = new TMatrixD( nFisherVars, nFisherVars );
TMatrixD* cov = new TMatrixD( nFisherVars, nFisherVars );
Double_t sumOfWeightsS = 0;
Double_t sumOfWeightsB = 0;
Double_t* sumS = new Double_t[nFisherVars];
Double_t* sumB = new Double_t[nFisherVars];
for (UInt_t ivar=0; ivar<nFisherVars; ivar++) { sumS[ivar] = sumB[ivar] = 0; }
UInt_t nevents = eventSample.size();
for (UInt_t ievt=0; ievt<nevents; ievt++) {
const Event * ev = eventSample[ievt];
Double_t weight = ev->GetWeight();
if (ev->GetClass() == fSigClass) sumOfWeightsS += weight;
else sumOfWeightsB += weight;
Double_t* sum = ev->GetClass() == fSigClass ? sumS : sumB;
for (UInt_t ivar=0; ivar<nFisherVars; ivar++) {
sum[ivar] += ev->GetValue( mapVarInFisher[ivar] )*weight;
}
}
for (UInt_t ivar=0; ivar<nFisherVars; ivar++) {
(*meanMatx)( ivar, 2 ) = sumS[ivar];
(*meanMatx)( ivar, 0 ) = sumS[ivar]/sumOfWeightsS;
(*meanMatx)( ivar, 2 ) += sumB[ivar];
(*meanMatx)( ivar, 1 ) = sumB[ivar]/sumOfWeightsB;
(*meanMatx)( ivar, 2 ) /= (sumOfWeightsS + sumOfWeightsB);
}
delete [] sumS;
delete [] sumB;
assert( sumOfWeightsS > 0 && sumOfWeightsB > 0 );
const Int_t nFisherVars2 = nFisherVars*nFisherVars;
Double_t *sum2Sig = new Double_t[nFisherVars2];
Double_t *sum2Bgd = new Double_t[nFisherVars2];
Double_t *xval = new Double_t[nFisherVars2];
memset(sum2Sig,0,nFisherVars2*sizeof(Double_t));
memset(sum2Bgd,0,nFisherVars2*sizeof(Double_t));
for (UInt_t ievt=0; ievt<nevents; ievt++) {
const Event* ev = eventSample.at(ievt);
Double_t weight = ev->GetWeight();
for (UInt_t x=0; x<nFisherVars; x++) {
xval[x] = ev->GetValue( mapVarInFisher[x] );
}
Int_t k=0;
for (UInt_t x=0; x<nFisherVars; x++) {
for (UInt_t y=0; y<nFisherVars; y++) {
if ( ev->GetClass() == fSigClass ) sum2Sig[k] += ( (xval[x] - (*meanMatx)(x, 0))*(xval[y] - (*meanMatx)(y, 0)) )*weight;
else sum2Bgd[k] += ( (xval[x] - (*meanMatx)(x, 1))*(xval[y] - (*meanMatx)(y, 1)) )*weight;
k++;
}
}
}
Int_t k=0;
for (UInt_t x=0; x<nFisherVars; x++) {
for (UInt_t y=0; y<nFisherVars; y++) {
(*with)(x, y) = sum2Sig[k]/sumOfWeightsS + sum2Bgd[k]/sumOfWeightsB;
k++;
}
}
delete [] sum2Sig;
delete [] sum2Bgd;
delete [] xval;
Double_t prodSig, prodBgd;
for (UInt_t x=0; x<nFisherVars; x++) {
for (UInt_t y=0; y<nFisherVars; y++) {
prodSig = ( ((*meanMatx)(x, 0) - (*meanMatx)(x, 2))*
((*meanMatx)(y, 0) - (*meanMatx)(y, 2)) );
prodBgd = ( ((*meanMatx)(x, 1) - (*meanMatx)(x, 2))*
((*meanMatx)(y, 1) - (*meanMatx)(y, 2)) );
(*betw)(x, y) = (sumOfWeightsS*prodSig + sumOfWeightsB*prodBgd) / (sumOfWeightsS + sumOfWeightsB);
}
}
for (UInt_t x=0; x<nFisherVars; x++)
for (UInt_t y=0; y<nFisherVars; y++)
(*cov)(x, y) = (*with)(x, y) + (*betw)(x, y);
TMatrixD* theMat = with;
TMatrixD invCov( *theMat );
if ( TMath::Abs(invCov.Determinant()) < 10E-24 ) {
Log() << kWARNING << "FisherCoeff matrix is almost singular with deterninant="
<< TMath::Abs(invCov.Determinant())
<< " did you use the variables that are linear combinations or highly correlated?"
<< Endl;
}
if ( TMath::Abs(invCov.Determinant()) < 10E-120 ) {
Log() << kFATAL << "FisherCoeff matrix is singular with determinant="
<< TMath::Abs(invCov.Determinant())
<< " did you use the variables that are linear combinations?"
<< Endl;
}
invCov.Invert();
Double_t xfact = TMath::Sqrt( sumOfWeightsS*sumOfWeightsB ) / (sumOfWeightsS + sumOfWeightsB);
std::vector<Double_t> diffMeans( nFisherVars );
for (UInt_t ivar=0; ivar<=fNvars; ivar++) fisherCoeff[ivar] = 0;
for (UInt_t ivar=0; ivar<nFisherVars; ivar++) {
for (UInt_t jvar=0; jvar<nFisherVars; jvar++) {
Double_t d = (*meanMatx)(jvar, 0) - (*meanMatx)(jvar, 1);
fisherCoeff[mapVarInFisher[ivar]] += invCov(ivar, jvar)*d;
}
fisherCoeff[mapVarInFisher[ivar]] *= xfact;
}
Double_t f0 = 0.0;
for (UInt_t ivar=0; ivar<nFisherVars; ivar++){
f0 += fisherCoeff[mapVarInFisher[ivar]]*((*meanMatx)(ivar, 0) + (*meanMatx)(ivar, 1));
}
f0 /= -2.0;
fisherCoeff[fNvars] = f0;
return fisherCoeff;
}
Double_t TMVA::DecisionTree::TrainNodeFull( const EventConstList & eventSample,
                                            TMVA::DecisionTreeNode *node )
{
   // Find the best single-variable cut for this node. For each candidate
   // variable the events are sorted by that variable, and the separation gain
   // is evaluated at every point where two neighbouring values differ.
   //
   // On success, node receives the chosen selector, cut value, separation
   // gain and cut type; fVariableImportance[selector] is increased; and the
   // gain is returned. If no admissible cut exists, 0 is returned and node is
   // left unchanged.

   Double_t nTotS = 0.0, nTotB = 0.0;
   Int_t nTotS_unWeighted = 0, nTotB_unWeighted = 0;

   std::vector<TMVA::BDTEventWrapper> bdtEventSample;
   bdtEventSample.reserve( eventSample.size() );

   // best cut found for each variable
   std::vector<Double_t> lCutValue( fNvars, 0.0 );
   std::vector<Double_t> lSepGain( fNvars, -1.0e6 );
   std::vector<Char_t>   lCutType( fNvars, Char_t(kFALSE) );

   for( std::vector<const TMVA::Event*>::const_iterator it = eventSample.begin(); it != eventSample.end(); ++it ) {
      if((*it)->GetClass() == fSigClass) {
         nTotS += (*it)->GetWeight();
         ++nTotS_unWeighted;
      }
      else {
         nTotB += (*it)->GetWeight();
         ++nTotB_unWeighted;
      }
      bdtEventSample.push_back(TMVA::BDTEventWrapper(*it));
   }

   // --- choose which variables to scan
   std::vector<Char_t> useVariable( fNvars, Char_t(kFALSE) );
   if (fRandomisedTree) {
      if (fUseNvars == 0) {
         fUseNvars = UInt_t(TMath::Sqrt(fNvars)+0.6);
      }
      // Only fNvars distinct variables exist; asking for more would never terminate.
      Int_t nToSelect = fUseNvars;
      if (nToSelect > Int_t(fNvars)) nToSelect = Int_t(fNvars);
      Int_t nSelectedVars = 0;
      while (nSelectedVars < nToSelect) {
         const Int_t ivar = Int_t( fMyTrandom->Rndm()*fNvars );
         if (useVariable[ivar] == Char_t(kFALSE)) {
            useVariable[ivar] = Char_t(kTRUE);
            ++nSelectedVars;
         }
      }
   }
   else {
      useVariable.assign( fNvars, Char_t(kTRUE) );
   }

   // --- scan each selected variable for its best cut
   const Double_t fPMin = 1.0e-6;
   for( UInt_t ivar = 0; ivar < fNvars; ivar++ ) {
      if(!useVariable[ivar]) continue;

      TMVA::BDTEventWrapper::SetVarIndex(ivar);
      std::sort( bdtEventSample.begin(),bdtEventSample.end() );

      // Cumulative weights: the entry at position i includes events 0..i of the sorted sample.
      Double_t bkgWeightCtr = 0.0, sigWeightCtr = 0.0;
      std::vector<TMVA::BDTEventWrapper>::iterator it = bdtEventSample.begin(), it_end = bdtEventSample.end();
      for( ; it != it_end; ++it ) {
         if((**it)->GetClass() == fSigClass )
            sigWeightCtr += (**it)->GetWeight();
         else
            bkgWeightCtr += (**it)->GetWeight();
         it->SetCumulativeWeight(false,bkgWeightCtr);
         it->SetCumulativeWeight(true,sigWeightCtr);
      }

      Bool_t cutType = kFALSE;
      Long64_t index = 0;
      Double_t separationGain = -1.0, sepTmp = 0.0, cutValue = 0.0, dVal = 0.0, norm = 0.0;
      for( it = bdtEventSample.begin(); it != it_end; ++it ) {
         if( index == 0 ) { ++index; continue; }
         if( *(*it) == NULL ) {
            Log() << kFATAL << "In TrainNodeFull(): have a null event! Where index="
                  << index << ", and parent node=" << node->GetParent() << Endl;
            break;
         }
         TMVA::BDTEventWrapper& prev = bdtEventSample[index-1];
         dVal = bdtEventSample[index].GetVal() - prev.GetVal();
         norm = TMath::Abs(bdtEventSample[index].GetVal() + prev.GetVal());
         // Both daughters must hold at least fMinSize events, and the two
         // neighbouring values must actually differ.
         if( index >= fMinSize && (nTotS_unWeighted + nTotB_unWeighted) - index >= fMinSize && TMath::Abs(dVal/(0.5*norm + 1)) > fPMin ) {
            // The cut lies between events index-1 and index, so the lower
            // daughter holds events 0..index-1. Its weights are the cumulative
            // weights of the previous event; those of *it would also count
            // event index, which lies above the cut.
            const Double_t nSelS = prev.GetCumulativeWeight(true);
            const Double_t nSelB = prev.GetCumulativeWeight(false);
            sepTmp = fSepType->GetSeparationGain( nSelS, nSelB, sigWeightCtr, bkgWeightCtr );
            if( sepTmp > separationGain ) {
               separationGain = sepTmp;
               cutValue = it->GetVal() - 0.5*dVal;
               // kTRUE when the lower side keeps a larger fraction of the signal
               // weight than of the background weight
               cutType = ( nSelS/sigWeightCtr > nSelB/bkgWeightCtr ) ? kTRUE : kFALSE;
            }
         }
         ++index;
      }
      lCutType[ivar] = Char_t(cutType);
      lCutValue[ivar] = cutValue;
      lSepGain[ivar] = separationGain;
   }

   // --- pick the variable with the largest gain
   Double_t separationGain = -1.0;
   Int_t iVarIndex = -1;
   for( UInt_t ivar = 0; ivar < fNvars; ivar++ ) {
      if( lSepGain[ivar] > separationGain ) {
         iVarIndex = ivar;
         separationGain = lSepGain[ivar];
      }
   }

   if(iVarIndex >= 0) {
      node->SetSelector(iVarIndex);
      node->SetCutValue(lCutValue[iVarIndex]);
      node->SetSeparationGain(lSepGain[iVarIndex]);
      node->SetCutType(lCutType[iVarIndex]);
      fVariableImportance[iVarIndex] += separationGain*separationGain * (nTotS+nTotB) * (nTotS+nTotB);
   }
   else {
      separationGain = 0.0;
   }
   return separationGain;
}
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetEventNode(const TMVA::Event & e) const
{
   // Walk down from the root, following the branch each internal node
   // selects for event e. Return the first node whose type is not 0.
   // NOTE(review): unlike CheckEvent, neither the root nor the children are
   // checked for null here.
   TMVA::DecisionTreeNode *node = this->GetRoot();
   while (node->GetNodeType() == 0) {
      if (node->GoesRight(e)) node = node->GetRight();
      else                    node = node->GetLeft();
   }
   return node;
}
Double_t TMVA::DecisionTree::CheckEvent( const TMVA::Event * e, Bool_t UseYesNoLeaf ) const
{
   // Descend from the root to the node event e ends up in (the first node
   // whose type is not 0). Return:
   //  - the node's response, for regression;
   //  - its node type as a Double_t, if UseYesNoLeaf is set;
   //  - its purity, otherwise.
   // A missing root or child is reported with kFATAL, and 0 is returned.
   TMVA::DecisionTreeNode *current = this->GetRoot();
   if (!current){
      Log() << kFATAL << "CheckEvent: started with undefined ROOT node" <<Endl;
      return 0;
   }

   while (current->GetNodeType() == 0) {
      current = (current->GoesRight(*e)) ? current->GetRight() : current->GetLeft();
      if (!current) {
         Log() << kFATAL << "DT::CheckEvent: inconsistent tree structure" <<Endl;
         // Mirror the root check: never go on to dereference a missing child.
         return 0;
      }
   }

   if ( DoRegression() ) return current->GetResponse();
   if (UseYesNoLeaf) return Double_t( current->GetNodeType() );
   return current->GetPurity();
}
Double_t TMVA::DecisionTree::SamplePurity( std::vector<TMVA::Event*> eventSample )
{
   // Weighted signal purity sumsig/(sumsig+sumbkg) of eventSample, where an
   // event counts as signal when its class equals fSigClass. Returns -1 when
   // the total weight is not positive.
   Double_t sumsig=0, sumbkg=0, sumtot=0, sumabs=0;
   for (UInt_t ievt=0; ievt<eventSample.size(); ievt++) {
      const Double_t w = eventSample[ievt]->GetWeight();
      if (eventSample[ievt]->GetClass() != fSigClass) sumbkg += w;
      else sumsig += w;
      sumtot += w;
      sumabs += TMath::Abs(w);
   }

   // Every event is added to exactly one of sumsig/sumbkg, so sumtot and
   // sumsig+sumbkg can differ only through rounding (different summation
   // order). Exact == would raise spurious fatals. Allow the first-order
   // rounding bound (n+1)*eps*sum|w|; the negated <= keeps NaN results fatal.
   const Double_t tolerance = Double_t(eventSample.size() + 1)
                            * std::numeric_limits<Double_t>::epsilon() * sumabs;
   if (sumtot != (sumsig+sumbkg) && !(TMath::Abs(sumtot - (sumsig+sumbkg)) <= tolerance)){
      Log() << kFATAL << "<SamplePurity> sumtot != sumsig+sumbkg"
            << sumtot << " " << sumsig << " " << sumbkg << Endl;
   }
   if (sumtot>0) return sumsig/(sumsig + sumbkg);
   else return -1;
}
vector< Double_t > TMVA::DecisionTree::GetVariableImportance()
{
   // Return the importance of each of the fNvars variables divided by the
   // total over all of them. If that total does not exceed double machine
   // epsilon, every entry is 0 instead.
   Double_t total = 0;
   for (UInt_t ivar = 0; ivar < fNvars; ++ivar) {
      total += fVariableImportance[ivar];
   }

   const Bool_t normalisable = total > std::numeric_limits<double>::epsilon();
   std::vector<Double_t> normalised(fNvars);
   for (UInt_t ivar = 0; ivar < fNvars; ++ivar) {
      normalised[ivar] = normalisable ? fVariableImportance[ivar] / total : 0.0;
   }
   return normalised;
}
Double_t TMVA::DecisionTree::GetVariableImportance( UInt_t ivar )
{
   // Relative importance of variable ivar, as given by the normalised
   // all-variables overload. An index outside [0, fNvars) is reported with
   // kFATAL, and -1 is returned.
   if (ivar >= fNvars) {
      Log() << kFATAL << "<GetVariableImportance>" << Endl
            << "--- ivar = " << ivar << " is out of range " << Endl;
      return -1;
   }
   return this->GetVariableImportance()[ivar];
}