#include <iostream>
#include <algorithm>
#include <limits>
#include <vector>
#include "TMath.h"
#include "TMVA/DecisionTree.h"
#include "TMVA/DecisionTreeNode.h"
#include "TMVA/BinarySearchTree.h"
#include "TMVA/Tools.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/CrossEntropy.h"
#include "TMVA/MisClassificationError.h"
#include "TMVA/SdivSqrtSplusB.h"
#include "TMVA/Event.h"
using std::vector;
// Compile-time switches for TrainNode (defined further below):
//  USE_HELGESCODE == 1 builds Helge's implementation; any other value builds
//  the alternative (Doug's) implementation found in the #else branch.
#define USE_HELGESCODE 1 // the other one is Dougs implementation of the TrainNode
//  USE_HELGE_V1 only matters for Helge's implementation: 1 = outer loop over the
//  variables, inner loop over the events; 0 = event loop outermost.
#define USE_HELGE_V1 0 // outer loop is over NVAR in TrainNode, inner loop is Eventloop
// ROOT class-implementation macro for TMVA::DecisionTree.
ClassImp(TMVA::DecisionTree)
//_______________________________________________________________________
// Default constructor: an empty tree (no root node), no separation
// criterion, no quality index, cost-complexity pruning selected.
TMVA::DecisionTree::DecisionTree( void )
   : BinaryTree(),
     fNvars      (0),
     fNCuts      (-1),
     fSepType    (NULL),
     fMinSize    (0),
     fPruneMethod(kCostComplexityPruning),
     fQualityIndex(NULL)
{
   fLogger.SetSource( "DecisionTree" );
   // fPruneStrength is not covered by the initializer list above; give it a
   // defined value so that the pruning routines never read an indeterminate
   // number. Assigned here (not in the init list) so it does not depend on the
   // member declaration order in the header.
   fPruneStrength = 0;
}
//_______________________________________________________________________
// Constructs a tree around an existing node structure whose top node is n.
// All nodes below n are told that they belong to this tree, and the total
// tree depth is updated accordingly. Passing n == NULL yields an empty tree,
// just like the default constructor.
TMVA::DecisionTree::DecisionTree( DecisionTreeNode* n )
   : BinaryTree(),
     fNvars      (0),
     fNCuts      (-1),
     fSepType    (NULL),
     fMinSize    (0),
     fPruneMethod(kCostComplexityPruning),
     fQualityIndex(NULL)
{
   fLogger.SetSource( "DecisionTree" );
   // not in the initializer list: give it a defined value (see default ctor)
   fPruneStrength = 0;

   this->SetRoot( n );
   // With a NULL root, SetParentTreeInNodes() would only emit a kFATAL
   // "undefined ROOT node" message, so only walk an actual tree.
   if (n != NULL) this->SetParentTreeInNodes();
}
//_______________________________________________________________________
// Constructs an (as yet empty) tree to be grown with BuildTree().
//   sepType : separation criterion used by TrainNode to choose the cuts
//   minSize : minimum number of (unweighted) events required on each side of
//             a cut; a node with fewer than 2*minSize events is not split
//   nCuts   : number of equidistant cut candidates scanned per variable
//   qtype   : separation measure used by the cost-complexity pruning
TMVA::DecisionTree::DecisionTree( TMVA::SeparationBase *sepType,Int_t minSize,
                                  Int_t nCuts, TMVA::SeparationBase *qtype):
   BinaryTree(),
   fNvars      (0),
   fNCuts      (nCuts),
   fSepType    (sepType),
   fMinSize    (minSize),
   fPruneMethod(kCostComplexityPruning),
   fQualityIndex(qtype)
{
   fLogger.SetSource( "DecisionTree" );
   // not in the initializer list: give it a defined value so pruning never
   // reads an indeterminate number before a prune strength is set
   fPruneStrength = 0;
}
//_______________________________________________________________________
// Copy constructor: deep-copies the node structure of d (via the
// DecisionTreeNode copy constructor) and copies the configuration as well as
// the accumulated variable importance and the prune strength.
TMVA::DecisionTree::DecisionTree( const DecisionTree &d):
   BinaryTree(),
   fNvars      (d.fNvars),
   fNCuts      (d.fNCuts),
   fSepType    (d.fSepType),
   fMinSize    (d.fMinSize),
   fPruneMethod(d.fPruneMethod),
   fQualityIndex(d.fQualityIndex)
{
   fLogger.SetSource( "DecisionTree" );

   // These members are not covered by the initializer list. Without copying
   // fVariableImportance, GetVariableImportance() on the copy would index an
   // empty vector up to fNvars.
   fPruneStrength      = d.fPruneStrength;
   fVariableImportance = d.fVariableImportance;

   // an empty source tree has no root to copy
   if (d.GetRoot() != NULL) {
      this->SetRoot( new DecisionTreeNode ( *((DecisionTreeNode*)(d.GetRoot())) ) );
      this->SetParentTreeInNodes();
   }
   fNNodes = d.fNNodes;
}
//_______________________________________________________________________
// Destructor. Nothing is released in this file.
// NOTE(review): presumably the BinaryTree base destructor owns and deletes
// the node structure -- confirm against BinaryTree.
TMVA::DecisionTree::~DecisionTree( void )
{
}
//_______________________________________________________________________
// Walks the (sub)tree below n -- or the whole tree when n is NULL -- and
// tells every node that it belongs to this tree. While doing so, the stored
// total tree depth is raised to the largest node depth encountered.
// A node with exactly one daughter is reported as a fatal inconsistency.
void TMVA::DecisionTree::SetParentTreeInNodes( DecisionTreeNode *n)
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) {
         fLogger << kFATAL << "SetParentTreeNodes: started with undefined ROOT node" <<Endl;
         return ;
      }
   }

   DecisionTreeNode *leftChild  = this->GetLeftDaughter(n);
   DecisionTreeNode *rightChild = this->GetRightDaughter(n);
   const Bool_t hasLeft  = (leftChild  != NULL);
   const Bool_t hasRight = (rightChild != NULL);

   if (hasLeft != hasRight) {
      fLogger << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
      return;
   }

   if (hasLeft) {   // internal node: both daughters exist, visit left first
      this->SetParentTreeInNodes( leftChild );
      this->SetParentTreeInNodes( rightChild );
   }

   n->SetParentTree(this);
   if (n->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(n->GetDepth());
}
//_______________________________________________________________________
// Recursively grows the decision tree from eventSample.
// - node == NULL marks the top-level call: a new root node is created
//   (position 's', depth 0) and the node counter fNNodes is reset to 1.
// - The weighted and unweighted signal/background content of the sample is
//   stored in 'node'. The total event counts are set here only for the root;
//   for daughters they are set by the parent before recursing.
// - If the sample holds at least 2*fMinSize events, TrainNode() searches for the
//   best cut. A returned gain of exactly 0 means "no split": the node becomes
//   a leaf of type +1 if its purity is > 0.5, otherwise -1.
// - Otherwise the events are divided with node->GoesRight(), a right ('r') and
//   a left ('l') daughter are created, the node is marked internal (type 0),
//   and the right and then the left daughter are built recursively.
// Returns the running node count fNNodes.
Int_t TMVA::DecisionTree::BuildTree( vector<TMVA::Event*> & eventSample,
TMVA::DecisionTreeNode *node )
{
// top-level call: create and register the root node
if (node==NULL) {
node = new TMVA::DecisionTreeNode();
fNNodes = 1;
this->SetRoot(node);
this->GetRoot()->SetPos('s');
this->GetRoot()->SetDepth(0);
this->GetRoot()->SetParentTree(this);
}
UInt_t nevents = eventSample.size();
// the number of input variables is taken from the first event
if (nevents > 0 ) {
fNvars = eventSample[0]->GetNVars();
fVariableImportance.resize(fNvars);
}
// NOTE(review): if kFATAL does not terminate, execution continues below with an empty sample
else fLogger << kFATAL << ":<BuildTree> eventsample Size == 0 " << Endl;
// weighted (s, b) and unweighted (suw, buw) signal/background content of this node
Double_t s=0, b=0;
Double_t suw=0, buw=0;
for (UInt_t i=0; i<eventSample.size(); i++){
if (eventSample[i]->IsSignal()){
s += eventSample[i]->GetWeight();
suw += 1;
}
else {
b += eventSample[i]->GetWeight();
buw += 1;
}
}
node->SetNSigEvents(s);
node->SetNBkgEvents(b);
node->SetNSigEvents_unweighted(suw);
node->SetNBkgEvents_unweighted(buw);
// daughters get their totals from the parent (see below); only the root sets them here
if (node == this->GetRoot()) {
node->SetNEvents(s+b);
node->SetNEvents_unweighted(suw+buw);
}
node->SetSeparationIndex(fSepType->GetSeparationIndex(s,b));
// only try to split if both sides could in principle hold fMinSize events
if ( eventSample.size() >= 2*fMinSize){
Double_t separationGain;
separationGain = this->TrainNode(eventSample, node);
// TrainNode returns exactly 0 when no admissible cut exists: make a leaf
if (separationGain == 0) {
if (node->GetPurity() > 0.5) node->SetNodeType(1);
else node->SetNodeType(-1);
if (node->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(node->GetDepth());
}
else {
// divide the events according to the cut just stored in 'node'
vector<TMVA::Event*> leftSample; leftSample.reserve(nevents);
vector<TMVA::Event*> rightSample; rightSample.reserve(nevents);
Double_t nRight=0, nLeft=0;
for (UInt_t ie=0; ie< nevents ; ie++){
if (node->GoesRight(*eventSample[ie])){
rightSample.push_back(eventSample[ie]);
nRight += eventSample[ie]->GetWeight();
}
else {
leftSample.push_back(eventSample[ie]);
nLeft += eventSample[ie]->GetWeight();
}
}
// a non-zero gain with an empty side would mean the daughter equals the parent
if (leftSample.size() == 0 || rightSample.size() == 0) {
fLogger << kFATAL << "<TrainNode> all events went to the same branch" << Endl
<< "--- Hence new node == old node ... check" << Endl
<< "--- left:" << leftSample.size()
<< " right:" << rightSample.size() << Endl
<< "--- this should never happen, please write a bug report to Helge.Voss@cern.ch"
<< Endl;
}
TMVA::DecisionTreeNode *rightNode = new TMVA::DecisionTreeNode(node,'r');
fNNodes++;
rightNode->SetNEvents(nRight);
rightNode->SetNEvents_unweighted(rightSample.size());
TMVA::DecisionTreeNode *leftNode = new TMVA::DecisionTreeNode(node,'l');
fNNodes++;
leftNode->SetNEvents(nLeft);
leftNode->SetNEvents_unweighted(leftSample.size());
// type 0 marks an internal node (CheckEvent descends while the type is 0)
node->SetNodeType(0);
node->SetLeft(leftNode);
node->SetRight(rightNode);
this->BuildTree(rightSample, rightNode);
this->BuildTree(leftSample, leftNode );
}
}
else{
// too few events to split: leaf classified by purity
if (node->GetPurity() > 0.5) node->SetNodeType(1);
else node->SetNodeType(-1);
if (node->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(node->GetDepth());
}
return fNNodes;
}
//_______________________________________________________________________
// Sends every event of the sample through the existing tree, starting at
// the root, adding it to the counters of all nodes along its path.
void TMVA::DecisionTree::FillTree( vector<TMVA::Event*> & eventSample)
{
   typedef vector<TMVA::Event*>::const_iterator EventIter;
   for (EventIter it = eventSample.begin(); it != eventSample.end(); ++it) {
      this->FillEvent( **it, NULL );
   }
}
//_______________________________________________________________________
// Adds one event to 'node' (the root when node is NULL) and then recursively
// to the daughter it falls into, until a leaf is reached. At every visited
// node the weighted and unweighted totals and the signal or background
// counters are incremented, and the separation index is recomputed.
void TMVA::DecisionTree::FillEvent( TMVA::Event & event,
                                    TMVA::DecisionTreeNode *node )
{
   if (node == NULL) {
      node = (TMVA::DecisionTreeNode*)this->GetRoot();
      if (node == NULL) {
         fLogger << kFATAL << "FillEvent: started with undefined ROOT node" << Endl;
         return;
      }
   }

   node->IncrementNEvents( event.GetWeight() );
   node->IncrementNEvents_unweighted( );

   if (event.IsSignal()){
      node->IncrementNSigEvents( event.GetWeight() );
      node->IncrementNSigEvents_unweighted( );
   }
   else {
      node->IncrementNBkgEvents( event.GetWeight() );
      // background events must go to the background counter
      // (previously the unweighted signal counter was incremented here)
      node->IncrementNBkgEvents_unweighted( );
   }
   node->SetSeparationIndex(fSepType->GetSeparationIndex(node->GetNSigEvents(),
                                                         node->GetNBkgEvents()));

   // internal node (type 0): continue into the daughter selected by the cut
   if (node->GetNodeType() == 0){
      if (node->GoesRight(event))
         this->FillEvent(event,(TMVA::DecisionTreeNode*)(node->GetRight())) ;
      else
         this->FillEvent(event,(TMVA::DecisionTreeNode*)(node->GetLeft())) ;
   }
}
void TMVA::DecisionTree::ClearTree()
{
if (this->GetRoot()!=NULL)
((DecisionTreeNode*)(this->GetRoot()))->ClearNodeAndAllDaughters();
}
//_______________________________________________________________________
// Prunes the tree with the method selected in fPruneMethod (expected-error,
// cost-complexity or MCC pruning) and recounts the nodes afterwards.
// Any other method is reported as fatal.
void TMVA::DecisionTree::PruneTree()
{
   switch (fPruneMethod) {
   case kExpectedErrorPruning:
      this->PruneTreeEEP((DecisionTreeNode *)this->GetRoot());
      break;
   case kCostComplexityPruning:
      this->PruneTreeCC();
      break;
   case kMCC:
      this->PruneTreeMCC();
      break;
   default:
      fLogger << kFATAL << "Selected pruning method not yet implemented "
              << Endl;
      break;
   }
   this->CountNodes();
}
//_______________________________________________________________________
// Expected-error pruning, bottom-up: the daughters of an internal node are
// pruned first. The node itself is then collapsed into a leaf if the error
// estimate of its subtree is not smaller than its own error estimate.
void TMVA::DecisionTree::PruneTreeEEP(DecisionTreeNode *node)
{
   if (node->GetNodeType() != 0) return;   // leaf: nothing below to prune

   this->PruneTreeEEP( (DecisionTreeNode*)node->GetLeft()  );
   this->PruneTreeEEP( (DecisionTreeNode*)node->GetRight() );

   const Bool_t subTreeNotBetter = this->GetSubTreeError(node) >= this->GetNodeError(node);
   if (subTreeNotBetter) this->PruneNode(node);
}
void TMVA::DecisionTree::PruneTreeCC()
{
Double_t currentCC = this->GetCostComplexity(fPruneStrength);
Double_t nextCC = this->GetCostComplexityIfNextPruneStep(fPruneStrength);
while (currentCC > nextCC && this->GetNNodes() > 3 ){
this->PruneNode( this->FindCCPruneCandidate() );
currentCC = this->GetCostComplexity(fPruneStrength);
nextCC = this->GetCostComplexityIfNextPruneStep(fPruneStrength);
}
return;
}
//_______________________________________________________________________
// Minimal cost-complexity pruning: repeatedly removes the "weakest link",
// i.e. the internal node with the smallest link strength alpha (see
// FillLinkStrengthMap). Pruning continues while that alpha is below
// fPruneStrength and the tree still has more than 3 nodes.
// The link-strength map is rebuilt after every step. PruneNode deletes nodes,
// so the old map would hold outdated alphas and dangling node pointers.
// Reading it made the stop criterion lag one step behind.
void TMVA::DecisionTree::PruneTreeMCC()
{
   this->FillLinkStrengthMap();
   while (!fLinkStrengthMap.empty() &&
          fLinkStrengthMap.begin()->first < fPruneStrength &&
          this->GetNNodes() > 3 ){
      this->PruneNode( fLinkStrengthMap.begin()->second );
      this->FillLinkStrengthMap();   // refresh after the tree changed
   }
   return;
}
//_______________________________________________________________________
// Returns the internal node with the smallest link strength (first entry of
// the freshly filled fLinkStrengthMap), or NULL if the tree has no internal
// node (then the map is empty and begin() must not be dereferenced).
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetWeakestLink()
{
   this->FillLinkStrengthMap();
   if (fLinkStrengthMap.empty()) return NULL;
   return fLinkStrengthMap.begin()->second;
}
//_______________________________________________________________________
// Fills fLinkStrengthMap (alpha -> node) for every internal node below n.
// A call with n == NULL starts at the root and clears the map first.
//   alpha = (cost(node as leaf) - cost(subtree)) / (number of nodes in subtree - 1)
// NOTE(review): textbook minimal cost-complexity pruning divides by
// (number of leaves - 1); here the count of all subtree nodes is used.
void TMVA::DecisionTree::FillLinkStrengthMap(TMVA::DecisionTreeNode *n)
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      fLinkStrengthMap.clear();
      if (n == NULL) {
         fLogger << kFATAL << "FillLinkStrengthMap: started with undefined ROOT node" <<Endl;
         return ;
      }
   }

   DecisionTreeNode *leftChild  = this->GetLeftDaughter(n);
   DecisionTreeNode *rightChild = this->GetRightDaughter(n);

   if (leftChild  != NULL) this->FillLinkStrengthMap( leftChild );
   if (rightChild != NULL) this->FillLinkStrengthMap( rightChild );

   // only nodes with both daughters get a link strength
   if (leftChild == NULL || rightChild == NULL) return;

   const Double_t alpha = ( this->MisClassificationCostOfNode(n) -
                            this->MisClassificationCostOfSubTree(n) ) /
                          (n->CountMeAndAllDaughters() - 1);
   fLinkStrengthMap.insert( std::pair<const Double_t, TMVA::DecisionTreeNode* >( alpha, n ) );
}
//_______________________________________________________________________
// Misclassification cost of n treated as a signal leaf: the number of
// (weighted) events in the node times its impurity (1 - purity).
Double_t TMVA::DecisionTree::MisClassificationCostOfNode(TMVA::DecisionTreeNode *n)
{
   return n->GetNEvents() * (1 - n->GetPurity());
}
//_______________________________________________________________________
// Misclassification cost of the subtree below n (the whole tree when n is
// NULL). This is the sum of MisClassificationCostOfNode over all leaves of
// that subtree.
Double_t TMVA::DecisionTree::MisClassificationCostOfSubTree(TMVA::DecisionTreeNode *n)
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) {
         fLogger << kFATAL << "MisClassificationCostOfSubTree: started with undefined ROOT node" <<Endl;
         return 0.;
      }
   }

   DecisionTreeNode *leftChild  = this->GetLeftDaughter(n);
   DecisionTreeNode *rightChild = this->GetRightDaughter(n);

   if (leftChild == NULL && rightChild == NULL) return this->MisClassificationCostOfNode(n);

   Double_t cost = 0;
   if (leftChild  != NULL) cost += this->MisClassificationCostOfSubTree( leftChild );
   if (rightChild != NULL) cost += this->MisClassificationCostOfSubTree( rightChild );
   return cost;
}
//_______________________________________________________________________
// Number of leaves (nodes without daughters) in the subtree below n, or in
// the whole tree when n is NULL.
UInt_t TMVA::DecisionTree::CountLeafNodes(TMVA::DecisionTreeNode *n)
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) {
         fLogger << kFATAL << "CountLeafNodes: started with undefined ROOT node" <<Endl;
         return 0;
      }
   }

   DecisionTreeNode *leftChild  = this->GetLeftDaughter(n);
   DecisionTreeNode *rightChild = this->GetRightDaughter(n);

   if (leftChild == NULL && rightChild == NULL) return 1;   // n itself is a leaf

   UInt_t nLeaves = 0;
   if (leftChild  != NULL) nLeaves += this->CountLeafNodes( leftChild );
   if (rightChild != NULL) nLeaves += this->CountLeafNodes( rightChild );
   return nLeaves;
}
//_______________________________________________________________________
// Cost complexity of the current tree:
//   sum over leaves of (unweighted #events in leaf) * (quality index of leaf)
//   + alpha * (number of leaves)
// where the quality index comes from fQualityIndex (see FillQualityMap).
Double_t TMVA::DecisionTree::GetCostComplexity(Double_t alpha)
{
   this->FillQualityMap();

   Double_t weightedQuality = 0.;
   Int_t    nLeaves         = 0;
   std::multimap<Double_t, TMVA::DecisionTreeNode* >::iterator leaf;
   for (leaf = fQualityMap.begin(); leaf != fQualityMap.end(); ++leaf, ++nLeaves) {
      const Double_t nSig = leaf->second->GetNSigEvents_unweighted();
      const Double_t nBkg = leaf->second->GetNBkgEvents_unweighted();
      weightedQuality += (nSig + nBkg) * leaf->first;
   }
   return weightedQuality + alpha * nLeaves;
}
//_______________________________________________________________________
// Cost complexity the tree would have after pruning the next candidate, i.e.
// the first node of fQualityGainMap (an internal node whose two daughters
// are both leaves). Its two daughters are removed from the leaf sum and the
// candidate itself is counted as a leaf instead. Returns 0 (after an error
// message) if either map is empty.
Double_t TMVA::DecisionTree::GetCostComplexityIfNextPruneStep(Double_t alpha)
{
   this->FillQualityMap();
   this->FillQualityGainMap();

   // NOTE(review): these messages use kError rather than kERROR/kWARNING --
   // confirm this is the intended MsgLogger level.
   if (fQualityMap.size() == 0 ) {
      fLogger << kError << "The Quality Map in the BDT-Pruning is empty.. maybe your Tree has "
              << " absolutely no splits ?? e.g.. minimun number of events for node splitting"
              << " being larger than the number of events available ??? " << Endl;
      return 0.;
   }
   if (fQualityGainMap.size() == 0 ) {
      fLogger << kError << "The QualityGain Map in the BDT-Pruning is empty.. This can happen"
              << " if your Tree has absolutely no splits ?? e.g.. minimun number of events for"
              << " node splitting being larger than the number of events available ??? " << Endl;
      return 0.;
   }

   TMVA::DecisionTreeNode *candidate = fQualityGainMap.begin()->second;

   Double_t cc    = 0.;
   Int_t    count = 0;
   std::multimap<Double_t, TMVA::DecisionTreeNode* >::iterator leaf;
   for (leaf = fQualityMap.begin(); leaf != fQualityMap.end(); ++leaf) {
      if (leaf->second->GetParent() == candidate) continue;   // would vanish with the prune
      const Double_t nSig = leaf->second->GetNSigEvents_unweighted();
      const Double_t nBkg = leaf->second->GetNBkgEvents_unweighted();
      cc += (nSig + nBkg) * leaf->first;
      ++count;
   }

   // the pruned candidate becomes a leaf itself
   const Double_t s = candidate->GetNSigEvents_unweighted();
   const Double_t b = candidate->GetNBkgEvents_unweighted();
   cc += (s+b) * fQualityIndex->GetSeparationIndex(s,b);
   ++count;

   return cc + alpha*count;
}
//_______________________________________________________________________
// Fills fQualityGainMap (gain -> node) with every internal node below n whose
// two daughters are both leaves, i.e. the nodes that can be pruned next. The
// key is the fQualityIndex separation gain computed from the right daughter's
// and the node's unweighted signal/background counts. A call with n == NULL
// starts at the root and clears the map first.
void TMVA::DecisionTree::FillQualityGainMap(DecisionTreeNode* n )
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      fQualityGainMap.clear();
      if (n == NULL) {
         fLogger << kFATAL << "FillQualityGainMap: started with undefined ROOT node" <<Endl;
         return ;
      }
   }

   DecisionTreeNode *leftChild  = this->GetLeftDaughter(n);
   DecisionTreeNode *rightChild = this->GetRightDaughter(n);

   if (leftChild  != NULL) this->FillQualityGainMap( leftChild );
   if (rightChild != NULL) this->FillQualityGainMap( rightChild );

   if (leftChild == NULL || rightChild == NULL) return;   // not an internal node

   const Bool_t leftIsLeaf  = (leftChild->GetLeft()  == NULL && leftChild->GetRight()  == NULL);
   const Bool_t rightIsLeaf = (rightChild->GetLeft() == NULL && rightChild->GetRight() == NULL);
   if (!(leftIsLeaf && rightIsLeaf)) return;

   fQualityGainMap.insert(std::pair<const Double_t, TMVA::DecisionTreeNode* >
                          ( fQualityIndex->GetSeparationGain (rightChild->GetNSigEvents_unweighted(),
                                                              rightChild->GetNBkgEvents_unweighted(),
                                                              n->GetNSigEvents_unweighted(),
                                                              n->GetNBkgEvents_unweighted()),
                            n));
}
//_______________________________________________________________________
// Fills fQualityMap (quality index -> leaf) with every leaf below n. The key
// is fQualityIndex's separation index of the leaf's unweighted
// signal/background counts. A call with n == NULL starts at the root and
// clears the map first.
void TMVA::DecisionTree::FillQualityMap(DecisionTreeNode* n )
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      fQualityMap.clear();
      if (n == NULL) {
         fLogger << kFATAL << "FillQualityMap: started with undefined ROOT node" <<Endl;
         return ;
      }
   }

   DecisionTreeNode *leftChild  = this->GetLeftDaughter(n);
   DecisionTreeNode *rightChild = this->GetRightDaughter(n);

   if (leftChild == NULL && rightChild == NULL) {   // leaf: record it
      fQualityMap.insert(std::pair<const Double_t, TMVA::DecisionTreeNode* >
                         ( fQualityIndex->GetSeparationIndex (n->GetNSigEvents_unweighted(),
                                                              n->GetNBkgEvents_unweighted()),
                           n));
      return;
   }

   if (leftChild  != NULL) this->FillQualityMap( leftChild );
   if (rightChild != NULL) this->FillQualityMap( rightChild );
}
//_______________________________________________________________________
// Walks the subtree below n (the whole tree when n is NULL) without changing
// anything. It reports a fatal error for any node that has exactly one
// daughter, so it acts as a structural consistency check.
void TMVA::DecisionTree::DescendTree( DecisionTreeNode *n)
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) {
         fLogger << kFATAL << "DescendTree: started with undefined ROOT node" <<Endl;
         return ;
      }
   }

   DecisionTreeNode *leftChild  = this->GetLeftDaughter(n);
   DecisionTreeNode *rightChild = this->GetRightDaughter(n);

   if (leftChild == NULL && rightChild == NULL) return;   // reached a leaf

   if (leftChild == NULL || rightChild == NULL) {
      fLogger << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
      return;
   }

   this->DescendTree( leftChild );
   this->DescendTree( rightChild );
}
//_______________________________________________________________________
// Returns the next cost-complexity prune candidate: the node with the
// smallest separation gain in the freshly filled fQualityGainMap. Returns
// NULL when the map is empty (tree without splits) instead of dereferencing
// begin() of an empty map.
TMVA::DecisionTreeNode* TMVA::DecisionTree::FindCCPruneCandidate()
{
   this->FillQualityGainMap();
   if (fQualityGainMap.empty()) return NULL;
   return fQualityGainMap.begin()->second;
}
//_______________________________________________________________________
// Turns 'node' into a leaf. The daughter pointers are detached, the cut
// information is reset (selector, separation index and gain set to -1), the
// node type is set by purity (+1 if > 0.5, else -1), the two former
// daughters are deleted and the nodes are recounted.
// A NULL node (e.g. no prune candidate was found) is ignored.
void TMVA::DecisionTree::PruneNode(DecisionTreeNode *node)
{
   if (node == NULL) return;

   DecisionTreeNode *l = (DecisionTreeNode*)node->GetLeft();
   DecisionTreeNode *r = (DecisionTreeNode*)node->GetRight();

   node->SetRight(NULL);
   node->SetLeft(NULL);
   node->SetSelector(-1);
   node->SetSeparationIndex(-1);
   node->SetSeparationGain(-1);
   if (node->GetPurity() > 0.5) node->SetNodeType(1);
   else node->SetNodeType(-1);

   this->DeleteNode(l);
   this->DeleteNode(r);
   this->CountNodes();
}
//_______________________________________________________________________
// Pessimistic error estimate of 'node' used by expected-error pruning.
// f is the fraction of the majority class in the node and df its binomial
// uncertainty sqrt(f(1-f)/N). The estimate is 1 - (f - fPruneStrength*df),
// capped at 1.
Double_t TMVA::DecisionTree::GetNodeError(DecisionTreeNode *node)
{
   const Double_t nEvts = node->GetNEvents();
   const Double_t f     = (node->GetPurity() > 0.5) ? node->GetPurity() : (1-node->GetPurity());
   const Double_t df    = TMath::Sqrt(f*(1-f)/nEvts );
   return std::min(1.,(1 - (f-fPruneStrength*df) ));
}
//_______________________________________________________________________
// Error estimate of the subtree below 'node'. For a leaf (type != 0) this is
// GetNodeError. For an internal node it is the event-weighted average of
// the subtree errors of its two daughters.
Double_t TMVA::DecisionTree::GetSubTreeError(DecisionTreeNode *node)
{
   if (node->GetNodeType() != 0) return this->GetNodeError(node);

   DecisionTreeNode *l = (DecisionTreeNode*)node->GetLeft();
   DecisionTreeNode *r = (DecisionTreeNode*)node->GetRight();
   const Double_t leftPart  = l->GetNEvents() * this->GetSubTreeError(l);
   const Double_t rightPart = r->GetNEvents() * this->GetSubTreeError(r);
   return (leftPart + rightPart) / node->GetNEvents();
}
//_______________________________________________________________________
// Left daughter of n as a DecisionTreeNode (NULL if n has none).
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetLeftDaughter( DecisionTreeNode *n)
{
   return static_cast<DecisionTreeNode*>( n->GetLeft() );
}
//_______________________________________________________________________
// Right daughter of n as a DecisionTreeNode (NULL if n has none).
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetRightDaughter( DecisionTreeNode *n)
{
   return static_cast<DecisionTreeNode*>( n->GetRight() );
}
//_______________________________________________________________________
// Returns the node reached from the root by 'depth' steps. Bit i of
// 'sequence' selects the direction of step i (1 = right, 0 = left).
// Returns NULL if the path runs past a leaf. Levels beyond the bit width of
// ULong_t cannot be encoded and are treated as "left".
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetNode(ULong_t sequence, UInt_t depth)
{
   const UInt_t nBits = 8*sizeof(ULong_t);
   DecisionTreeNode* current = (DecisionTreeNode*) this->GetRoot();
   for (UInt_t i = 0; i < depth && current != NULL; i++) {
      // shift an unsigned long, not an int: '1 << i' overflows for i >= 31
      const ULong_t bit = (i < nBits) ? (ULong_t(1) << i) : ULong_t(0);
      if (bit & sequence) current = this->GetRightDaughter(current);
      else                current = this->GetLeftDaughter(current);
   }
   return current;
}
//_______________________________________________________________________
// Determines, for each of the fNvars variables, the smallest (xmin) and
// largest (xmax) value found in eventSample.
// xmin and xmax must already hold at least fNvars entries. An empty sample
// leaves them untouched; previously it indexed eventSample[0].
void TMVA::DecisionTree::FindMinAndMax(vector<TMVA::Event*> & eventSample,
                                       vector<Double_t> & xmin,
                                       vector<Double_t> & xmax)
{
   if (eventSample.empty()) return;

   // seed both limits with the first event
   for (Int_t ivar=0; ivar < fNvars; ivar++){
      xmin[ivar] = xmax[ivar] = eventSample[0]->GetVal(ivar);
   }
   for (UInt_t i=1; i<eventSample.size(); i++){
      for (Int_t ivar=0; ivar < fNvars; ivar++){
         const Double_t val = eventSample[i]->GetVal(ivar);
         if (xmin[ivar] > val) xmin[ivar] = val;
         if (xmax[ivar] < val) xmax[ivar] = val;
      }
   }
}
//_______________________________________________________________________
// Fills cut_points[0..num_gridpoints-1] with the centres of num_gridpoints
// equal bins spanning [xmin, xmax]. cut_points must already hold at least
// num_gridpoints entries. Each point is computed directly as
// xmin + (j+0.5)*step, the same formula TrainNode uses, instead of
// accumulating 'x += step', so rounding errors do not build up along the grid.
void TMVA::DecisionTree::SetCutPoints(vector<Double_t> & cut_points,
                                      Double_t xmin,
                                      Double_t xmax,
                                      Int_t num_gridpoints)
{
   if (num_gridpoints <= 0) return;   // no grid; also avoids dividing by zero
   const Double_t step = (xmax - xmin)/num_gridpoints;
   for (Int_t j=0; j < num_gridpoints; j++){
      cut_points[j] = xmin + (Double_t(j) + 0.5)*step;
   }
}
#if USE_HELGESCODE==1
//_______________________________________________________________________
// Finds the best cut for 'node' (Helge's implementation).
// For every variable, fNCuts equidistant cut values are placed at the bin
// centres of its range in this sample. For each candidate, the (weighted and
// unweighted) signal and background content with value > cut is counted.
// Among the candidates leaving at least fMinSize unweighted events on both
// sides, the one with the largest fSepType separation gain wins. Its
// variable, cut value, cut type and gain are stored in 'node', and the
// variable importance is updated.
// Returns the best gain, or 0 if no admissible cut exists (BuildTree then
// turns the node into a leaf).
Double_t TMVA::DecisionTree::TrainNode(vector<TMVA::Event*> & eventSample,
                                       TMVA::DecisionTreeNode *node)
{
   UInt_t nevents = eventSample.size();
   // Nothing to optimise without events or cut candidates. A negative fNCuts
   // (default constructor) would otherwise make the resize() calls below fail.
   if (nevents == 0 || fNCuts <= 0) return 0;

   Double_t separationGain = -1;
   Int_t mxVar=-1, cutIndex=0;
   Double_t nTotS, nTotB;
   Int_t nTotS_unWeighted, nTotB_unWeighted;

   // range of every variable in this sample (automatic storage: no new/delete)
   vector<Double_t> xmin( fNvars );
   vector<Double_t> xmax( fNvars );
   for (Int_t ivar=0; ivar < fNvars; ivar++){
      xmin[ivar]=xmax[ivar]=eventSample[0]->GetVal(ivar);
   }
   for (UInt_t iev=1;iev<nevents;iev++){
      for (Int_t ivar=0; ivar < fNvars; ivar++){
         Double_t eventData = eventSample[iev]->GetVal(ivar);
         if (xmin[ivar]>eventData) xmin[ivar]=eventData;
         if (xmax[ivar]<eventData) xmax[ivar]=eventData;
      }
   }

   // per variable and cut candidate: the cut value and the signal/background
   // content (weighted and unweighted) of the events with value > cut
   vector< vector<Double_t> > nSelS (fNvars);
   vector< vector<Double_t> > nSelB (fNvars);
   vector< vector<Int_t> >    nSelS_unWeighted (fNvars);
   vector< vector<Int_t> >    nSelB_unWeighted (fNvars);
   vector< vector<Double_t> > cutValues(fNvars);
   for (Int_t ivar=0; ivar < fNvars; ivar++){
      cutValues[ivar].resize(fNCuts);
      nSelS[ivar].resize(fNCuts);
      nSelB[ivar].resize(fNCuts);
      nSelS_unWeighted[ivar].resize(fNCuts);
      nSelB_unWeighted[ivar].resize(fNCuts);
      Double_t istepSize =( xmax[ivar] - xmin[ivar] ) / Double_t(fNCuts);
      for (Int_t icut=0; icut<fNCuts; icut++){
         cutValues[ivar][icut]=xmin[ivar]+(Double_t(icut)+0.5)*istepSize;
      }
   }

   nTotS=0; nTotB=0;
   nTotS_unWeighted=0; nTotB_unWeighted=0;
#if USE_HELGE_V1==1
   // variable loop outermost; totals are accumulated during the first pass
   for (Int_t ivar=0; ivar < fNvars; ivar++){
      for (UInt_t iev=0; iev<nevents; iev++){
         Double_t eventData  = eventSample[iev]->GetVal(ivar);
         Int_t    eventType  = eventSample[iev]->Type();
         Double_t eventWeight= eventSample[iev]->GetWeight();
         if (ivar==0){
            if (eventType==1){
               nTotS+=eventWeight;
               nTotS_unWeighted++;
            }
            else {
               nTotB+=eventWeight;
               nTotB_unWeighted++;
            }
         }
         for (Int_t icut=0; icut<fNCuts; icut++){
            if (eventData > cutValues[ivar][icut]){
               if (eventType==1) {
                  nSelS[ivar][icut]+=eventWeight;
                  nSelS_unWeighted[ivar][icut]++;
               }
               else {
                  nSelB[ivar][icut]+=eventWeight;
                  nSelB_unWeighted[ivar][icut]++;
               }
            }
         }
      }
   }
#else
   // event loop outermost
   for (UInt_t iev=0; iev<nevents; iev++){
      Int_t    eventType   = eventSample[iev]->Type();
      Double_t eventWeight = eventSample[iev]->GetWeight();
      if (eventType==1){
         nTotS+=eventWeight;
         nTotS_unWeighted++;
      }
      else {
         nTotB+=eventWeight;
         nTotB_unWeighted++;
      }
      for (Int_t ivar=0; ivar < fNvars; ivar++){
         Double_t eventData = eventSample[iev]->GetVal(ivar);
         for (Int_t icut=0; icut<fNCuts; icut++){
            if (eventData > cutValues[ivar][icut]){
               if (eventType==1) {
                  nSelS[ivar][icut]+=eventWeight;
                  nSelS_unWeighted[ivar][icut]++;
               }
               else {
                  nSelB[ivar][icut]+=eventWeight;
                  nSelB_unWeighted[ivar][icut]++;
               }
            }
         }
      }
   }
#endif

   // best admissible cut: both sides must keep at least fMinSize unweighted events
   for (Int_t ivar=0; ivar < fNvars; ivar++) {
      for (Int_t icut=0; icut<fNCuts; icut++){
         Int_t nSelected = nSelS_unWeighted[ivar][icut] + nSelB_unWeighted[ivar][icut];
         Int_t nRejected = (nTotS_unWeighted + nTotB_unWeighted) - nSelected;
         if (nSelected >= fMinSize && nRejected >= fMinSize) {
            Double_t sepTmp = fSepType->GetSeparationGain(nSelS[ivar][icut], nSelB[ivar][icut], nTotS, nTotB);
            if (separationGain < sepTmp) {
               separationGain = sepTmp;
               mxVar = ivar;
               cutIndex = icut;
            }
         }
      }
   }

   if (mxVar < 0) return 0;   // no admissible cut: no split

   // kTRUE if the "value > cut" side keeps a larger fraction of all signal
   // than of all background
   Bool_t   cutType  = (nSelS[mxVar][cutIndex]/nTotS > nSelB[mxVar][cutIndex]/nTotB) ? kTRUE : kFALSE;
   Double_t cutValue = cutValues[mxVar][cutIndex];

   node->SetSelector((UInt_t)mxVar);
   node->SetCutValue(cutValue);
   node->SetCutType(cutType);
   node->SetSeparationGain(separationGain);
   fVariableImportance[mxVar] += separationGain*separationGain * (nTotS+nTotB) * (nTotS+nTotB) ;

   return separationGain;
}
#else
//_______________________________________________________________________
// Finds the best cut for 'node' (Doug's implementation; only compiled when
// USE_HELGESCODE != 1).
// Cut candidates are the fNCuts bin centres of each variable's range (see
// FindMinAndMax / SetCutPoints). The candidate with the largest fSepType
// separation gain is stored in 'node'.
// Returns the best gain, or 0 if no cut could be evaluated (BuildTree then
// turns the node into a leaf).
// NOTE(review): unlike Helge's version, no fMinSize requirement is applied
// to the two sides of the cut.
Double_t TMVA::DecisionTree::TrainNode(vector<TMVA::Event*> & eventSample,
                                       TMVA::DecisionTreeNode *node)
{
   UInt_t num_events = eventSample.size();
   if (num_events == 0 || fNCuts <= 0) return 0;

   vector<Double_t> xmin ( fNvars );
   vector<Double_t> xmax ( fNvars );
   Double_t separationGain = -1;
   Double_t cutValue = -999;
   Int_t    mxVar    = -1;
   Bool_t   cutType  = kTRUE;
   Double_t nTotS = 0., nTotB = 0.;

   // weighted signal/background content with value > cut, per variable and cut
   vector<vector<Double_t> > signal_counts (fNvars);
   vector<vector<Double_t> > background_counts (fNvars);
   vector<vector<Double_t> > cut_points (fNvars);

   this->FindMinAndMax(eventSample, xmin, xmax);
   for (Int_t i=0; i < fNvars; i++){
      signal_counts[i].resize(fNCuts);
      background_counts[i].resize(fNCuts);
      cut_points[i].resize(fNCuts);
      this->SetCutPoints(cut_points[i], xmin[i], xmax[i], fNCuts);
   }

   for (UInt_t event=0; event < num_events; event++){
      Int_t    event_type   = eventSample[event]->Type();
      Double_t event_weight = eventSample[event]->GetWeight();
      if (event_type == 1) nTotS += event_weight;
      else                 nTotB += event_weight;
      for (Int_t variable = 0; variable < fNvars; variable++){
         Double_t event_val = eventSample[event]->GetVal(variable);
         for (Int_t cut=0; cut < fNCuts; cut++){
            if (event_val > cut_points[variable][cut]){
               if (event_type == 1) signal_counts[variable][cut]     += event_weight;
               else                 background_counts[variable][cut] += event_weight;
            }
         }
      }
   }

   for (Int_t var = 0; var < fNvars; var++){
      for (Int_t cut=0; cut < fNCuts; cut++){
         Double_t cur_sep = fSepType->GetSeparationGain(signal_counts[var][cut],
                                                        background_counts[var][cut],
                                                        nTotS, nTotB);
         if (separationGain < cur_sep) {
            separationGain = cur_sep;
            cutValue       = cut_points[var][cut];
            // Use the counts of THIS cut. The former nSelS/nSelB locals were
            // never filled (always 0), so cutType was always kFALSE.
            cutType = (signal_counts[var][cut]/nTotS > background_counts[var][cut]/nTotB) ? kTRUE : kFALSE;
            mxVar   = var;
         }
      }
   }

   // no cut evaluated: avoid fVariableImportance[-1] and SetSelector(-1)
   if (mxVar < 0) return 0;

   node->SetSelector((UInt_t)mxVar);
   node->SetCutValue(cutValue);
   node->SetCutType(cutType);
   node->SetSeparationGain(separationGain);
   fVariableImportance[mxVar] += separationGain*separationGain * (nTotS+nTotB)* (nTotS+nTotB);
   return separationGain;
}
#endif
//_______________________________________________________________________
// Classifies event e. The event is passed down from the root until it
// reaches a leaf (node type != 0). Returns the leaf type (+1/-1) when
// UseYesNoLeaf is set, otherwise the leaf purity.
Double_t TMVA::DecisionTree::CheckEvent(const TMVA::Event & e, Bool_t UseYesNoLeaf)
{
   TMVA::DecisionTreeNode *leaf = (TMVA::DecisionTreeNode*)this->GetRoot();
   while (leaf->GetNodeType() == 0) {
      TMVA::Node *next = leaf->GoesRight(e) ? leaf->GetRight() : leaf->GetLeft();
      leaf = (TMVA::DecisionTreeNode*)next;
   }

   if (!UseYesNoLeaf) return leaf->GetPurity();
   return Double_t( leaf->GetNodeType() );
}
//_______________________________________________________________________
// Weighted signal purity sumsig/(sumsig+sumbkg) of the sample, where Type()==1
// counts as signal and Type()==0 as background. Returns -1 for a sample
// without positive total weight. An event of any other type is reported as
// fatal.
Double_t TMVA::DecisionTree::SamplePurity(vector<TMVA::Event*> eventSample)
{
   Double_t sumsig=0, sumbkg=0, sumtot=0;
   Bool_t   foundOtherType = kFALSE;
   for (UInt_t ievt=0; ievt<eventSample.size(); ievt++) {
      const Double_t weight = eventSample[ievt]->GetWeight();
      if      (eventSample[ievt]->Type()==0) sumbkg+=weight;
      else if (eventSample[ievt]->Type()==1) sumsig+=weight;
      else                                   foundOtherType = kTRUE;
      sumtot+=weight;
   }
   // The old exact test 'sumtot != sumsig+sumbkg' could fire from floating-point
   // rounding alone, since the same weights are summed in different orders.
   // Detect the real problem (events neither signal nor background) directly.
   if (foundOtherType){
      fLogger << kFATAL << "<SamplePurity> sumtot != sumsig+sumbkg"
              << sumtot << " " << sumsig << " " << sumbkg << Endl;
   }

   const Double_t denom = sumsig + sumbkg;
   if (sumtot>0 && denom>0) return sumsig/denom;
   else return -1;
}
//_______________________________________________________________________
// Relative importance of each of the fNvars variables: the accumulated
// fVariableImportance values normalised to sum 1. If the total does not
// exceed double epsilon, all entries are 0.
vector< Double_t > TMVA::DecisionTree::GetVariableImportance()
{
   vector<Double_t> relativeImportance(fNvars);

   Double_t total = 0;
   for (Int_t ivar = 0; ivar < fNvars; ivar++) {
      relativeImportance[ivar] = fVariableImportance[ivar];
      total += fVariableImportance[ivar];
   }

   const Bool_t normalisable = (total > std::numeric_limits<double>::epsilon());
   for (Int_t ivar = 0; ivar < fNvars; ivar++) {
      if (normalisable) relativeImportance[ivar] /= total;
      else              relativeImportance[ivar]  = 0;
   }
   return relativeImportance;
}
//_______________________________________________________________________
// Relative importance of variable ivar (see GetVariableImportance()).
// An out-of-range index is reported as fatal and -1 is returned.
Double_t TMVA::DecisionTree::GetVariableImportance(Int_t ivar)
{
   vector<Double_t> relativeImportance = this->GetVariableImportance();

   const Bool_t inRange = (ivar >= 0 && ivar < fNvars);
   if (!inRange) {
      fLogger << kFATAL << "<GetVariableImportance>" << Endl
              << "--- ivar = " << ivar << " is out of range " << Endl;
      return -1;
   }
   return relativeImportance[ivar];
}
Last update: Thu Jan 17 08:58:48 2008
This page has been automatically generated. If you have any comments or suggestions about the page layout send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.