Logo ROOT  
Reference Guide
DecisionTreeNode.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : TMVA::DecisionTreeNode *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation of a Decision Tree Node *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
16 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
17 * Eckhard von Toerne <evt@physik.uni-bonn.de> - U. of Bonn, Germany *
18 * *
19 * CopyRight (c) 2009: *
20 * CERN, Switzerland *
21 * U. of Victoria, Canada *
22 * MPI-K Heidelberg, Germany *
23 * U. of Bonn, Germany *
24 * *
25 * Redistribution and use in source and binary forms, with or without *
26 * modification, are permitted according to the terms listed in LICENSE *
27 * (http://tmva.sourceforge.net/LICENSE) *
28 **********************************************************************************/
29
/*! \class TMVA::DecisionTreeNode
31\ingroup TMVA
32
33Node for the Decision Tree.
34
The node specifies ONE variable out of the given set of selection variables
that is used to split the sample which "arrives" at the node into a left
(background-enhanced) and a right (signal-enhanced) sample.
38
39*/
40
42
43#include "TMVA/Types.h"
44#include "TMVA/MsgLogger.h"
45#include "TMVA/Tools.h"
46#include "TMVA/Event.h"
47
48#include "ThreadLocalStorage.h"
49#include "TString.h"
50
51#include <algorithm>
52#include <exception>
53#include <iomanip>
54#include <limits>
55#include <sstream>
56
57using std::string;
58
60
63
64////////////////////////////////////////////////////////////////////////////////
65/// constructor of an essentially "empty" node floating in space
/// The members in the initializer list start at fixed defaults: cut value 0,
/// cut type kTRUE, selector -1, response -99, RMS 0, node type -99,
/// purity -99, and the node is not terminal.
/// In the else-branch fTrainInfo is set to 0 (no training information is kept).
/// NOTE(review): the signature line and the if-condition / allocation of
/// fTrainInfo (original lines 67 and 78-79) are missing from this listing —
/// restore them from the upstream source before relying on this text.
66
68 : TMVA::Node(),
69 fCutValue(0),
70 fCutType ( kTRUE ),
71 fSelector ( -1 ),
72 fResponse(-99 ),
73 fRMS(0),
74 fNodeType (-99 ),
75 fPurity (-99),
76 fIsTerminalNode( kFALSE )
77{
80 //std::cout << "Node constructor with TrainingINFO"<<std::endl;
81 }
82 else {
83 //std::cout << "**Node constructor WITHOUT TrainingINFO"<<std::endl;
84 fTrainInfo = 0;
85 }
86}
87
88////////////////////////////////////////////////////////////////////////////////
89/// constructor of a daughter node as a daughter of 'p'
/// The base part is built with TMVA::Node(p, pos) (presumably linking the new
/// node to parent `p` at position `pos` — verify in Node.h). All other members
/// start at the same defaults as an empty node: cut value 0, cut type kTRUE,
/// selector -1, response -99, RMS 0, node type -99, purity -99, not terminal.
/// In the else-branch fTrainInfo is set to 0.
/// NOTE(review): the signature line and the if-condition / allocation of
/// fTrainInfo (original lines 91 and 102-103) are missing from this listing.
90
92 : TMVA::Node(p, pos),
93 fCutValue( 0 ),
94 fCutType ( kTRUE ),
95 fSelector( -1 ),
96 fResponse(-99 ),
97 fRMS(0),
98 fNodeType( -99 ),
99 fPurity (-99),
100 fIsTerminalNode( kFALSE )
101{
104 //std::cout << "Node constructor with TrainingINFO"<<std::endl;
105 }
106 else {
107 //std::cout << "**Node constructor WITHOUT TrainingINFO"<<std::endl;
108 fTrainInfo = 0;
109 }
110}
111
112////////////////////////////////////////////////////////////////////////////////
113/// copy constructor of a node. It will result in an explicit copy of
114/// the node and recursively all its daughters
/// The copy is attached to `parent` (the parent pointer of `n` is not reused).
/// Each existing daughter of `n` is deep-copied with this constructor and
/// re-parented to the new node. Training information is deep-copied
/// (new DTNodeTrainingInfo(*n.fTrainInfo)) in the if-branch; otherwise
/// fTrainInfo is set to 0.
/// NOTE(review): fFisherCoeff does not appear in the initializer list below,
/// so the Fisher coefficients of `n` may not be copied — TODO verify.
/// NOTE(review): the first signature line and the if-condition (original
/// lines 116 and 135) are missing from this listing.
115
117 DecisionTreeNode* parent)
118 : TMVA::Node(n),
119 fCutValue( n.fCutValue ),
120 fCutType ( n.fCutType ),
121 fSelector( n.fSelector ),
122 fResponse( n.fResponse ),
123 fRMS ( n.fRMS),
124 fNodeType( n.fNodeType ),
125 fPurity ( n.fPurity),
126 fIsTerminalNode( n.fIsTerminalNode )
127{
128 this->SetParent( parent );
129 if (n.GetLeft() == 0 ) this->SetLeft(NULL);
130 else this->SetLeft( new DecisionTreeNode( *((DecisionTreeNode*)(n.GetLeft())),this));
131
132 if (n.GetRight() == 0 ) this->SetRight(NULL);
133 else this->SetRight( new DecisionTreeNode( *((DecisionTreeNode*)(n.GetRight())),this));
134
136 fTrainInfo = new DTNodeTrainingInfo(*(n.fTrainInfo));
137 //std::cout << "Node constructor with TrainingINFO"<<std::endl;
138 }
139 else {
140 //std::cout << "**Node constructor WITHOUT TrainingINFO"<<std::endl;
141 fTrainInfo = 0;
142 }
143}
144
145////////////////////////////////////////////////////////////////////////////////
146/// destructor
147
149 delete fTrainInfo;
150}
151
152////////////////////////////////////////////////////////////////////////////////
153/// test event if it descends the tree at this node to the right
154
156{
157 Bool_t result;
158 // first check if the fisher criterium is used or ordinary cuts:
159 if (GetNFisherCoeff() == 0){
160
161 result = (e.GetValueFast(this->GetSelector()) >= this->GetCutValue() );
162
163 }else{
164
165 Double_t fisher = this->GetFisherCoeff(fFisherCoeff.size()-1); // the offset
166 for (UInt_t ivar=0; ivar<fFisherCoeff.size()-1; ivar++)
167 fisher += this->GetFisherCoeff(ivar)*(e.GetValueFast(ivar));
168
169 result = fisher > this->GetCutValue();
170 }
171
172 if (fCutType == kTRUE) return result; //the cuts are selecting Signal ;
173 else return !result;
174}
175
176////////////////////////////////////////////////////////////////////////////////
177/// test event if it descends the tree at this node to the left
178
180{
181 if (!this->GoesRight(e)) return kTRUE;
182 else return kFALSE;
183}
184
185
186////////////////////////////////////////////////////////////////////////////////
187/// return the S/(S+B) (purity) for the node
188/// REM: even if nodes with purity 0.01 are very PURE background nodes, they still
189/// get a small value of the purity.
190
192{
193 if ( ( this->GetNSigEvents() + this->GetNBkgEvents() ) > 0 ) {
194 fPurity = this->GetNSigEvents() / ( this->GetNSigEvents() + this->GetNBkgEvents());
195 }
196 else {
197 Log() << kINFO << "Zero events in purity calculation , return purity=0.5" << Endl;
198 std::ostringstream oss;
199 this->Print(oss);
200 Log() <<oss.str();
201 fPurity = 0.5;
202 }
203 return;
204}
205
206////////////////////////////////////////////////////////////////////////////////
207///print the node
208
209void TMVA::DecisionTreeNode::Print(std::ostream& os) const
210{
211 os << "< *** " << std::endl;
212 os << " d: " << this->GetDepth()
213 << std::setprecision(6)
214 << "NCoef: " << this->GetNFisherCoeff();
215 for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++) { os << "fC"<<i<<": " << this->GetFisherCoeff(i);}
216 os << " ivar: " << this->GetSelector()
217 << " cut: " << this->GetCutValue()
218 << " cType: " << this->GetCutType()
219 << " s: " << this->GetNSigEvents()
220 << " b: " << this->GetNBkgEvents()
221 << " nEv: " << this->GetNEvents()
222 << " suw: " << this->GetNSigEvents_unweighted()
223 << " buw: " << this->GetNBkgEvents_unweighted()
224 << " nEvuw: " << this->GetNEvents_unweighted()
225 << " sepI: " << this->GetSeparationIndex()
226 << " sepG: " << this->GetSeparationGain()
227 << " nType: " << this->GetNodeType()
228 << std::endl;
229
230 os << "My address is " << long(this) << ", ";
231 if (this->GetParent() != NULL) os << " parent at addr: " << long(this->GetParent()) ;
232 if (this->GetLeft() != NULL) os << " left daughter at addr: " << long(this->GetLeft());
233 if (this->GetRight() != NULL) os << " right daughter at addr: " << long(this->GetRight()) ;
234
235 os << " **** > " << std::endl;
236}
237
238////////////////////////////////////////////////////////////////////////////////
239/// recursively print the node and its daughters (--> print the 'tree')
240
241void TMVA::DecisionTreeNode::PrintRec(std::ostream& os) const
242{
243 os << this->GetDepth()
244 << std::setprecision(6)
245 << " " << this->GetPos()
246 << "NCoef: " << this->GetNFisherCoeff();
247 for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++) {os << "fC"<<i<<": " << this->GetFisherCoeff(i);}
248 os << " ivar: " << this->GetSelector()
249 << " cut: " << this->GetCutValue()
250 << " cType: " << this->GetCutType()
251 << " s: " << this->GetNSigEvents()
252 << " b: " << this->GetNBkgEvents()
253 << " nEv: " << this->GetNEvents()
254 << " suw: " << this->GetNSigEvents_unweighted()
255 << " buw: " << this->GetNBkgEvents_unweighted()
256 << " nEvuw: " << this->GetNEvents_unweighted()
257 << " sepI: " << this->GetSeparationIndex()
258 << " sepG: " << this->GetSeparationGain()
259 << " res: " << this->GetResponse()
260 << " rms: " << this->GetRMS()
261 << " nType: " << this->GetNodeType();
262 if (this->GetCC() > 10000000000000.) os << " CC: " << 100000. << std::endl;
263 else os << " CC: " << this->GetCC() << std::endl;
264
265 if (this->GetLeft() != NULL) this->GetLeft() ->PrintRec(os);
266 if (this->GetRight() != NULL) this->GetRight()->PrintRec(os);
267}
268
269////////////////////////////////////////////////////////////////////////////////
270/// Read the data block
271
272Bool_t TMVA::DecisionTreeNode::ReadDataRecord( std::istream& is, UInt_t tmva_Version_Code )
273{
274 fgTmva_Version_Code=tmva_Version_Code;
275 string tmp;
276
277 Float_t cutVal, cutType, nsig, nbkg, nEv, nsig_unweighted, nbkg_unweighted, nEv_unweighted;
278 Float_t separationIndex, separationGain, response(-99), cc(0);
279 Int_t depth, ivar, nodeType;
280 ULong_t lseq;
281 char pos;
282
283 is >> depth; // 2
284 if ( depth==-1 ) { return kFALSE; }
285 // if ( depth==-1 ) { delete this; return kFALSE; }
286 is >> pos ; // r
287 this->SetDepth(depth);
288 this->SetPos(pos);
289
290 if (tmva_Version_Code < TMVA_VERSION(4,0,0)) {
291 is >> tmp >> lseq
292 >> tmp >> ivar
293 >> tmp >> cutVal
294 >> tmp >> cutType
295 >> tmp >> nsig
296 >> tmp >> nbkg
297 >> tmp >> nEv
298 >> tmp >> nsig_unweighted
299 >> tmp >> nbkg_unweighted
300 >> tmp >> nEv_unweighted
301 >> tmp >> separationIndex
302 >> tmp >> separationGain
303 >> tmp >> nodeType;
304 } else {
305 is >> tmp >> lseq
306 >> tmp >> ivar
307 >> tmp >> cutVal
308 >> tmp >> cutType
309 >> tmp >> nsig
310 >> tmp >> nbkg
311 >> tmp >> nEv
312 >> tmp >> nsig_unweighted
313 >> tmp >> nbkg_unweighted
314 >> tmp >> nEv_unweighted
315 >> tmp >> separationIndex
316 >> tmp >> separationGain
317 >> tmp >> response
318 >> tmp >> nodeType
319 >> tmp >> cc;
320 }
321
322 this->SetSelector((UInt_t)ivar);
323 this->SetCutValue(cutVal);
324 this->SetCutType(cutType);
325 this->SetNodeType(nodeType);
326 if (fTrainInfo){
327 this->SetNSigEvents(nsig);
328 this->SetNBkgEvents(nbkg);
329 this->SetNEvents(nEv);
330 this->SetNSigEvents_unweighted(nsig_unweighted);
331 this->SetNBkgEvents_unweighted(nbkg_unweighted);
332 this->SetNEvents_unweighted(nEv_unweighted);
333 this->SetSeparationIndex(separationIndex);
334 this->SetSeparationGain(separationGain);
335 this->SetPurity();
336 // this->SetResponse(response); old .txt weightfiles don't know regression yet
337 this->SetCC(cc);
338 }
339
340 return kTRUE;
341}
342
343////////////////////////////////////////////////////////////////////////////////
344/// clear the nodes (their S/N, Nevents etc), just keep the structure of the tree
345
347{
348 SetNSigEvents(0);
349 SetNBkgEvents(0);
350 SetNEvents(0);
351 SetNSigEvents_unweighted(0);
352 SetNBkgEvents_unweighted(0);
353 SetNEvents_unweighted(0);
354 SetSeparationIndex(-1);
355 SetSeparationGain(-1);
356 SetPurity();
357
358 if (this->GetLeft() != NULL) ((DecisionTreeNode*)(this->GetLeft()))->ClearNodeAndAllDaughters();
359 if (this->GetRight() != NULL) ((DecisionTreeNode*)(this->GetRight()))->ClearNodeAndAllDaughters();
360}
361
362////////////////////////////////////////////////////////////////////////////////
363/// temporary stored node values (number of events, etc.) that originate
364/// not from the training but from the validation data (used in pruning)
365
367 SetNBValidation( 0.0 );
368 SetNSValidation( 0.0 );
369 SetSumTarget( 0 );
370 SetSumTarget2( 0 );
371
372 if(GetLeft() != NULL && GetRight() != NULL) {
373 GetLeft()->ResetValidationData();
374 GetRight()->ResetValidationData();
375 }
376}
377
378////////////////////////////////////////////////////////////////////////////////
379/// printout of the node (can be read in with ReadDataRecord)
380
381void TMVA::DecisionTreeNode::PrintPrune( std::ostream& os ) const {
382 os << "----------------------" << std::endl
383 << "|~T_t| " << GetNTerminal() << std::endl
384 << "R(t): " << GetNodeR() << std::endl
385 << "R(T_t): " << GetSubTreeR() << std::endl
386 << "g(t): " << GetAlpha() << std::endl
387 << "G(t): " << GetAlphaMinSubtree() << std::endl;
388}
389
390////////////////////////////////////////////////////////////////////////////////
391/// recursive printout of the node and its daughters
392
393void TMVA::DecisionTreeNode::PrintRecPrune( std::ostream& os ) const {
394 this->PrintPrune(os);
395 if(this->GetLeft() != NULL && this->GetRight() != NULL) {
396 ((DecisionTreeNode*)this->GetLeft())->PrintRecPrune(os);
397 ((DecisionTreeNode*)this->GetRight())->PrintRecPrune(os);
398 }
399}
400
401////////////////////////////////////////////////////////////////////////////////
/// Store the CC value `cc` in the node's training information
/// (NOTE(review): presumably the cost-complexity used by pruning — confirm).
/// If the node has no training information, a kFATAL message is logged
/// instead; ReadDataRecord therefore calls this only when fTrainInfo is set.
/// NOTE(review): the signature line (original line 403) is missing from this
/// listing, so the parameter type of `cc` cannot be confirmed here.
402
404{
405 if (fTrainInfo) fTrainInfo->fCC = cc;
406 else Log() << kFATAL << "call to SetCC without trainingInfo" << Endl;
407}
408
409////////////////////////////////////////////////////////////////////////////////
410/// return the minimum of variable ivar from the training sample
411/// that pass/end up in this node
412
414 if (fTrainInfo && ivar < fTrainInfo->fSampleMin.size()) return fTrainInfo->fSampleMin[ivar];
415 else Log() << kFATAL << "You asked for Min of the event sample in node for variable "
416 << ivar << " that is out of range" << Endl;
417 return -9999;
418}
419
420////////////////////////////////////////////////////////////////////////////////
421/// return the maximum of variable ivar from the training sample
422/// that pass/end up in this node
423
425 if (fTrainInfo && ivar < fTrainInfo->fSampleMin.size()) return fTrainInfo->fSampleMax[ivar];
426 else Log() << kFATAL << "You asked for Max of the event sample in node for variable "
427 << ivar << " that is out of range" << Endl;
428 return 9999;
429}
430
431////////////////////////////////////////////////////////////////////////////////
432/// set the minimum of variable ivar from the training sample
433/// that pass/end up in this node
434
436 if ( fTrainInfo) {
437 if ( ivar >= fTrainInfo->fSampleMin.size()) fTrainInfo->fSampleMin.resize(ivar+1);
438 fTrainInfo->fSampleMin[ivar]=xmin;
439 }
440}
441
442////////////////////////////////////////////////////////////////////////////////
443/// set the maximum of variable ivar from the training sample
444/// that pass/end up in this node
445
447 if( ! fTrainInfo ) return;
448 if ( ivar >= fTrainInfo->fSampleMax.size() )
449 fTrainInfo->fSampleMax.resize(ivar+1);
450 fTrainInfo->fSampleMax[ivar]=xmax;
451}
452
453////////////////////////////////////////////////////////////////////////////////
454
455void TMVA::DecisionTreeNode::ReadAttributes(void* node, UInt_t /* tmva_Version_Code */ )
456{
457 Float_t tempNSigEvents,tempNBkgEvents;
458
459 Int_t nCoef;
460 if (gTools().HasAttr(node, "NCoef")){
461 gTools().ReadAttr(node, "NCoef", nCoef );
462 this->SetNFisherCoeff(nCoef);
463 Double_t tmp;
464 for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++) {
465 gTools().ReadAttr(node, Form("fC%d",i), tmp );
466 this->SetFisherCoeff(i,tmp);
467 }
468 }else{
469 this->SetNFisherCoeff(0);
470 }
471 gTools().ReadAttr(node, "IVar", fSelector );
472 gTools().ReadAttr(node, "Cut", fCutValue );
473 gTools().ReadAttr(node, "cType", fCutType );
474 if (gTools().HasAttr(node,"res")) gTools().ReadAttr(node, "res", fResponse);
475 if (gTools().HasAttr(node,"rms")) gTools().ReadAttr(node, "rms", fRMS);
476 // else {
477 if( gTools().HasAttr(node, "purity") ) {
478 gTools().ReadAttr(node, "purity",fPurity );
479 } else {
480 gTools().ReadAttr(node, "nS", tempNSigEvents );
481 gTools().ReadAttr(node, "nB", tempNBkgEvents );
482 fPurity = tempNSigEvents / (tempNSigEvents + tempNBkgEvents);
483 }
484 // }
485 gTools().ReadAttr(node, "nType", fNodeType );
486}
487
488
489////////////////////////////////////////////////////////////////////////////////
490/// add attribute to xml
491
493{
494 gTools().AddAttr(node, "NCoef", GetNFisherCoeff());
495 for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++)
496 gTools().AddAttr(node, Form("fC%d",i), this->GetFisherCoeff(i));
497
498 gTools().AddAttr(node, "IVar", GetSelector());
499 gTools().AddAttr(node, "Cut", GetCutValue());
500 gTools().AddAttr(node, "cType", GetCutType());
501
502 //UInt_t analysisType = (dynamic_cast<const TMVA::DecisionTree*>(GetParentTree()) )->GetAnalysisType();
503 // if ( analysisType == TMVA::Types:: kRegression) {
504 gTools().AddAttr(node, "res", GetResponse());
505 gTools().AddAttr(node, "rms", GetRMS());
506 //} else if ( analysisType == TMVA::Types::kClassification) {
507 gTools().AddAttr(node, "purity",GetPurity());
508 //}
509 gTools().AddAttr(node, "nType", GetNodeType());
510}
511
512////////////////////////////////////////////////////////////////////////////////
513/// set fisher coefficients
514
516{
517 if ((Int_t) fFisherCoeff.size()<ivar+1) fFisherCoeff.resize(ivar+1) ;
518 fFisherCoeff[ivar]=coeff;
519}
520
521////////////////////////////////////////////////////////////////////////////////
522/// adding attributes to tree node (well, was used in BinarySearchTree,
523/// and somehow I guess someone programmed it such that we need this in
524/// this tree too, although we don't..)
525
526void TMVA::DecisionTreeNode::AddContentToNode( std::stringstream& /*s*/ ) const
527{
528}
529
530////////////////////////////////////////////////////////////////////////////////
531/// reading attributes from tree node (well, was used in BinarySearchTree,
532/// and somehow I guess someone programmed it such that we need this in
533/// this tree too, although we don't..)
534
535void TMVA::DecisionTreeNode::ReadContent( std::stringstream& /*s*/ )
536{
537}
538////////////////////////////////////////////////////////////////////////////////
539
541 TTHREAD_TLS_DECL_ARG(MsgLogger,logger,"DecisionTreeNode"); // static because there is a huge number of nodes...
542 return logger;
543}
#define e(i)
Definition: RSha256.hxx:103
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
unsigned long ULong_t
Definition: RtypesCore.h:51
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
float Float_t
Definition: RtypesCore.h:53
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define ClassImp(name)
Definition: Rtypes.h:365
float xmin
Definition: THbookFile.cxx:93
float xmax
Definition: THbookFile.cxx:93
char * Form(const char *fmt,...)
#define TMVA_VERSION(a, b, c)
Definition: Version.h:48
virtual void AddContentToNode(std::stringstream &s) const
adding attributes to tree node (well, was used in BinarySearchTree, and somehow I guess someone progr...
DTNodeTrainingInfo * fTrainInfo
flag to set node as terminal (i.e., without deleting its descendants)
virtual ~DecisionTreeNode()
destructor
void PrintPrune(std::ostream &os) const
printout of the node (can be read in with ReadDataRecord)
void PrintRecPrune(std::ostream &os) const
recursive printout of the node and its daughters
void SetFisherCoeff(Int_t ivar, Double_t coeff)
set fisher coefficients
static UInt_t fgTmva_Version_Code
virtual void SetLeft(Node *l)
void SetSampleMax(UInt_t ivar, Float_t xmax)
set the maximum of variable ivar from the training sample that pass/end up in this node
void ClearNodeAndAllDaughters()
clear the nodes (their S/N, Nevents etc), just keep the structure of the tree
virtual Bool_t GoesLeft(const Event &) const
test event if it descends the tree at this node to the left
virtual void ReadContent(std::stringstream &s)
reading attributes from tree node (well, was used in BinarySearchTree, and somehow I guess someone pr...
void SetPurity(void)
return the S/(S+B) (purity) for the node REM: even if nodes with purity 0.01 are very PURE background...
virtual void Print(std::ostream &os) const
print the node
virtual Bool_t GoesRight(const Event &) const
test event if it descends the tree at this node to the right
DecisionTreeNode()
constructor of an essentially "empty" node floating in space
virtual void AddAttributesToNode(void *node) const
add attribute to xml
virtual void ReadAttributes(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
static MsgLogger & Log()
void ResetValidationData()
temporary stored node values (number of events, etc.) that originate not from the training but from t...
virtual void PrintRec(std::ostream &os) const
recursively print the node and its daughters (--> print the 'tree')
virtual void SetRight(Node *r)
virtual Bool_t ReadDataRecord(std::istream &is, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
Read the data block.
virtual void SetParent(Node *p)
Float_t GetSampleMax(UInt_t ivar) const
return the maximum of variable ivar from the training sample that pass/end up in this node
Float_t GetSampleMin(UInt_t ivar) const
return the minimum of variable ivar from the training sample that pass/end up in this node
void SetSampleMin(UInt_t ivar, Float_t xmin)
set the minimum of variable ivar from the training sample that pass/end up in this node
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
Node for the BinarySearch or Decision Trees.
Definition: Node.h:56
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:337
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:355
const Int_t n
Definition: legend1.C:16
void Print(std::ostream &os, const OptionType &opt)
create variable transformations
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:750