Logo ROOT   6.14/05
Reference Guide
DecisionTreeNode.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : DecisionTreeNode *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Node for the Decision Tree *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
16  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
17  * Eckhard von Toerne <evt@physik.uni-bonn.de> - U. of Bonn, Germany *
18  * *
19  * Copyright (c) 2009: *
20  * CERN, Switzerland *
21  * U. of Victoria, Canada *
22  * MPI-K Heidelberg, Germany *
23  * U. of Bonn, Germany *
24  * *
25  * Redistribution and use in source and binary forms, with or without *
26  * modification, are permitted according to the terms listed in LICENSE *
27  * (http://tmva.sourceforge.net/LICENSE) *
28  **********************************************************************************/
29 
30 #ifndef ROOT_TMVA_DecisionTreeNode
31 #define ROOT_TMVA_DecisionTreeNode
32 
33 //////////////////////////////////////////////////////////////////////////
34 // //
35 // DecisionTreeNode //
36 // //
37 // Node for the Decision Tree //
38 // //
39 //////////////////////////////////////////////////////////////////////////
40 
41 #include "TMVA/Node.h"
42 
43 #include "TMVA/Version.h"
44 
45 #include <iostream>
46 #include <vector>
47 #include <map>
48 namespace TMVA {
49 
51  {
52  public:
54  fSampleMax(),
55  fNodeR(0),fSubTreeR(0),fAlpha(0),fG(0),fNTerminal(0),
56  fNB(0),fNS(0),fSumTarget(0),fSumTarget2(0),fCC(0),
57  fNSigEvents ( 0 ), fNBkgEvents ( 0 ),
58  fNEvents ( -1 ),
61  fNEvents_unweighted ( 0 ),
64  fNEvents_unboosted ( 0 ),
65  fSeparationIndex (-1 ),
66  fSeparationGain ( -1 )
67  {
68  }
69  std::vector< Float_t > fSampleMin; // the minima for each ivar of the sample on the node during training
70  std::vector< Float_t > fSampleMax; // the maxima for each ivar of the sample on the node during training
71  Double_t fNodeR; // node resubstitution estimate, R(t)
72  Double_t fSubTreeR; // R(T) = Sum(R(t) : t in ~T)
73  Double_t fAlpha; // critical alpha for this node
74  Double_t fG; // minimum alpha in subtree rooted at this node
75  Int_t fNTerminal; // number of terminal nodes in subtree rooted at this node
76  Double_t fNB; // sum of weights of background events from the pruning sample in this node
77  Double_t fNS; // ditto for the signal events
78  Float_t fSumTarget; // sum of weight*target used for the calculatio of the variance (regression)
79  Float_t fSumTarget2; // sum of weight*target^2 used for the calculatio of the variance (regression)
80  Double_t fCC; // debug variable for cost complexity pruning ..
81 
82  Float_t fNSigEvents; // sum of weights of signal event in the node
83  Float_t fNBkgEvents; // sum of weights of backgr event in the node
84  Float_t fNEvents; // number of events in that entered the node (during training)
85  Float_t fNSigEvents_unweighted; // sum of signal event in the node
86  Float_t fNBkgEvents_unweighted; // sum of backgr event in the node
87  Float_t fNEvents_unweighted; // number of events in that entered the node (during training)
88  Float_t fNSigEvents_unboosted; // sum of signal event in the node
89  Float_t fNBkgEvents_unboosted; // sum of backgr event in the node
90  Float_t fNEvents_unboosted; // number of events in that entered the node (during training)
91  Float_t fSeparationIndex; // measure of "purity" (separation between S and B) AT this node
92  Float_t fSeparationGain; // measure of "purity", separation, or information gained BY this nodes selection
93 
94  // copy constructor
96  fSampleMin(),fSampleMax(), // Samplemin and max are reset in copy constructor
97  fNodeR(n.fNodeR), fSubTreeR(n.fSubTreeR),
98  fAlpha(n.fAlpha), fG(n.fG),
99  fNTerminal(n.fNTerminal),
100  fNB(n.fNB), fNS(n.fNS),
101  fSumTarget(0),fSumTarget2(0), // SumTarget reset in copy constructor
102  fCC(0),
103  fNSigEvents ( n.fNSigEvents ), fNBkgEvents ( n.fNBkgEvents ),
104  fNEvents ( n.fNEvents ),
105  fNSigEvents_unweighted ( n.fNSigEvents_unweighted ),
106  fNBkgEvents_unweighted ( n.fNBkgEvents_unweighted ),
107  fNEvents_unweighted ( n.fNEvents_unweighted ),
108  fSeparationIndex( n.fSeparationIndex ),
109  fSeparationGain ( n.fSeparationGain )
110  { }
111  };
112 
113  class Event;
114  class MsgLogger;
115 
116  class DecisionTreeNode: public Node {
117 
118  public:
119 
120  // constructor of an essentially "empty" node floating in space
121  DecisionTreeNode ();
122  // constructor of a daughter node as a daughter of 'p'
123  DecisionTreeNode (Node* p, char pos);
124 
125  // copy constructor
126  DecisionTreeNode (const DecisionTreeNode &n, DecisionTreeNode* parent = NULL);
127 
128  // destructor
129  virtual ~DecisionTreeNode();
130 
131  virtual Node* CreateNode() const { return new DecisionTreeNode(); }
132 
133  inline void SetNFisherCoeff(Int_t nvars){fFisherCoeff.resize(nvars);}
134  inline UInt_t GetNFisherCoeff() const { return fFisherCoeff.size();}
135  // set fisher coefficients
136  void SetFisherCoeff(Int_t ivar, Double_t coeff);
137  // get fisher coefficients
138  Double_t GetFisherCoeff(Int_t ivar) const {return fFisherCoeff.at(ivar);}
139 
140  // test event if it decends the tree at this node to the right
141  virtual Bool_t GoesRight( const Event & ) const;
142 
143  // test event if it decends the tree at this node to the left
144  virtual Bool_t GoesLeft ( const Event & ) const;
145 
146  // set index of variable used for discrimination at this node
147  void SetSelector( Short_t i) { fSelector = i; }
148  // return index of variable used for discrimination at this node
149  Short_t GetSelector() const { return fSelector; }
150 
151  // set the cut value applied at this node
152  void SetCutValue ( Float_t c ) { fCutValue = c; }
153  // return the cut value applied at this node
154  Float_t GetCutValue ( void ) const { return fCutValue; }
155 
156  // set true: if event variable > cutValue ==> signal , false otherwise
157  void SetCutType( Bool_t t ) { fCutType = t; }
158  // return kTRUE: Cuts select signal, kFALSE: Cuts select bkg
159  Bool_t GetCutType( void ) const { return fCutType; }
160 
161  // set node type: 1 signal node, -1 bkg leave, 0 intermediate Node
162  void SetNodeType( Int_t t ) { fNodeType = t;}
163  // return node type: 1 signal node, -1 bkg leave, 0 intermediate Node
164  Int_t GetNodeType( void ) const { return fNodeType; }
165 
166  //return S/(S+B) (purity) at this node (from training)
167  Float_t GetPurity( void ) const { return fPurity;}
168  //calculate S/(S+B) (purity) at this node (from training)
169  void SetPurity( void );
170 
171  //set the response of the node (for regression)
172  void SetResponse( Float_t r ) { fResponse = r;}
173 
174  //return the response of the node (for regression)
175  Float_t GetResponse( void ) const { return fResponse;}
176 
177  //set the RMS of the response of the node (for regression)
178  void SetRMS( Float_t r ) { fRMS = r;}
179 
180  //return the RMS of the response of the node (for regression)
181  Float_t GetRMS( void ) const { return fRMS;}
182 
183  // set the sum of the signal weights in the node
184  void SetNSigEvents( Float_t s ) { fTrainInfo->fNSigEvents = s; }
185 
186  // set the sum of the backgr weights in the node
187  void SetNBkgEvents( Float_t b ) { fTrainInfo->fNBkgEvents = b; }
188 
189  // set the number of events that entered the node (during training)
190  void SetNEvents( Float_t nev ){ fTrainInfo->fNEvents =nev ; }
191 
192  // set the sum of the unweighted signal events in the node
193  void SetNSigEvents_unweighted( Float_t s ) { fTrainInfo->fNSigEvents_unweighted = s; }
194 
195  // set the sum of the unweighted backgr events in the node
196  void SetNBkgEvents_unweighted( Float_t b ) { fTrainInfo->fNBkgEvents_unweighted = b; }
197 
198  // set the number of unweighted events that entered the node (during training)
199  void SetNEvents_unweighted( Float_t nev ){ fTrainInfo->fNEvents_unweighted =nev ; }
200 
201  // set the sum of the unboosted signal events in the node
202  void SetNSigEvents_unboosted( Float_t s ) { fTrainInfo->fNSigEvents_unboosted = s; }
203 
204  // set the sum of the unboosted backgr events in the node
205  void SetNBkgEvents_unboosted( Float_t b ) { fTrainInfo->fNBkgEvents_unboosted = b; }
206 
207  // set the number of unboosted events that entered the node (during training)
208  void SetNEvents_unboosted( Float_t nev ){ fTrainInfo->fNEvents_unboosted =nev ; }
209 
210  // increment the sum of the signal weights in the node
211  void IncrementNSigEvents( Float_t s ) { fTrainInfo->fNSigEvents += s; }
212 
213  // increment the sum of the backgr weights in the node
214  void IncrementNBkgEvents( Float_t b ) { fTrainInfo->fNBkgEvents += b; }
215 
216  // increment the number of events that entered the node (during training)
217  void IncrementNEvents( Float_t nev ){ fTrainInfo->fNEvents +=nev ; }
218 
219  // increment the sum of the signal weights in the node
220  void IncrementNSigEvents_unweighted( ) { fTrainInfo->fNSigEvents_unweighted += 1; }
221 
222  // increment the sum of the backgr weights in the node
223  void IncrementNBkgEvents_unweighted( ) { fTrainInfo->fNBkgEvents_unweighted += 1; }
224 
225  // increment the number of events that entered the node (during training)
226  void IncrementNEvents_unweighted( ){ fTrainInfo->fNEvents_unweighted +=1 ; }
227 
228  // return the sum of the signal weights in the node
229  Float_t GetNSigEvents( void ) const { return fTrainInfo->fNSigEvents; }
230 
231  // return the sum of the backgr weights in the node
232  Float_t GetNBkgEvents( void ) const { return fTrainInfo->fNBkgEvents; }
233 
234  // return the number of events that entered the node (during training)
235  Float_t GetNEvents( void ) const { return fTrainInfo->fNEvents; }
236 
237  // return the sum of unweighted signal weights in the node
238  Float_t GetNSigEvents_unweighted( void ) const { return fTrainInfo->fNSigEvents_unweighted; }
239 
240  // return the sum of unweighted backgr weights in the node
241  Float_t GetNBkgEvents_unweighted( void ) const { return fTrainInfo->fNBkgEvents_unweighted; }
242 
243  // return the number of unweighted events that entered the node (during training)
244  Float_t GetNEvents_unweighted( void ) const { return fTrainInfo->fNEvents_unweighted; }
245 
246  // return the sum of unboosted signal weights in the node
247  Float_t GetNSigEvents_unboosted( void ) const { return fTrainInfo->fNSigEvents_unboosted; }
248 
249  // return the sum of unboosted backgr weights in the node
250  Float_t GetNBkgEvents_unboosted( void ) const { return fTrainInfo->fNBkgEvents_unboosted; }
251 
252  // return the number of unboosted events that entered the node (during training)
253  Float_t GetNEvents_unboosted( void ) const { return fTrainInfo->fNEvents_unboosted; }
254 
255 
256  // set the choosen index, measure of "purity" (separation between S and B) AT this node
257  void SetSeparationIndex( Float_t sep ){ fTrainInfo->fSeparationIndex =sep ; }
258  // return the separation index AT this node
259  Float_t GetSeparationIndex( void ) const { return fTrainInfo->fSeparationIndex; }
260 
261  // set the separation, or information gained BY this nodes selection
262  void SetSeparationGain( Float_t sep ){ fTrainInfo->fSeparationGain =sep ; }
263  // return the gain in separation obtained by this nodes selection
264  Float_t GetSeparationGain( void ) const { return fTrainInfo->fSeparationGain; }
265 
266  // printout of the node
267  virtual void Print( std::ostream& os ) const;
268 
269  // recursively print the node and its daughters (--> print the 'tree')
270  virtual void PrintRec( std::ostream& os ) const;
271 
272  virtual void AddAttributesToNode(void* node) const;
273  virtual void AddContentToNode(std::stringstream& s) const;
274 
275  // recursively clear the nodes content (S/N etc, but not the cut criteria)
276  void ClearNodeAndAllDaughters();
277 
278  // get pointers to children, mother in the tree
279 
280  // return pointer to the left/right daughter or parent node
281  inline virtual DecisionTreeNode* GetLeft( ) const { return static_cast<DecisionTreeNode*>(fLeft); }
282  inline virtual DecisionTreeNode* GetRight( ) const { return static_cast<DecisionTreeNode*>(fRight); }
283  inline virtual DecisionTreeNode* GetParent( ) const { return static_cast<DecisionTreeNode*>(fParent); }
284 
285  // set pointer to the left/right daughter and parent node
286  inline virtual void SetLeft (Node* l) { fLeft = l;}
287  inline virtual void SetRight (Node* r) { fRight = r;}
288  inline virtual void SetParent(Node* p) { fParent = p;}
289 
290 
291 
292 
293  // the node resubstitution estimate, R(t), for Cost Complexity pruning
294  inline void SetNodeR( Double_t r ) { fTrainInfo->fNodeR = r; }
295  inline Double_t GetNodeR( ) const { return fTrainInfo->fNodeR; }
296 
297  // the resubstitution estimate, R(T_t), of the tree rooted at this node
298  inline void SetSubTreeR( Double_t r ) { fTrainInfo->fSubTreeR = r; }
299  inline Double_t GetSubTreeR( ) const { return fTrainInfo->fSubTreeR; }
300 
301  // R(t) - R(T_t)
302  // the critical point alpha = -------------
303  // |~T_t| - 1
304  inline void SetAlpha( Double_t alpha ) { fTrainInfo->fAlpha = alpha; }
305  inline Double_t GetAlpha( ) const { return fTrainInfo->fAlpha; }
306 
307  // the minimum alpha in the tree rooted at this node
308  inline void SetAlphaMinSubtree( Double_t g ) { fTrainInfo->fG = g; }
309  inline Double_t GetAlphaMinSubtree( ) const { return fTrainInfo->fG; }
310 
311  // number of terminal nodes in the subtree rooted here
312  inline void SetNTerminal( Int_t n ) { fTrainInfo->fNTerminal = n; }
313  inline Int_t GetNTerminal( ) const { return fTrainInfo->fNTerminal; }
314 
315  // number of background/signal events from the pruning validation sample
316  inline void SetNBValidation( Double_t b ) { fTrainInfo->fNB = b; }
317  inline void SetNSValidation( Double_t s ) { fTrainInfo->fNS = s; }
318  inline Double_t GetNBValidation( ) const { return fTrainInfo->fNB; }
319  inline Double_t GetNSValidation( ) const { return fTrainInfo->fNS; }
320 
321 
322  inline void SetSumTarget(Float_t t) {fTrainInfo->fSumTarget = t; }
323  inline void SetSumTarget2(Float_t t2){fTrainInfo->fSumTarget2 = t2; }
324 
325  inline void AddToSumTarget(Float_t t) {fTrainInfo->fSumTarget += t; }
326  inline void AddToSumTarget2(Float_t t2){fTrainInfo->fSumTarget2 += t2; }
327 
328  inline Float_t GetSumTarget() const {return fTrainInfo? fTrainInfo->fSumTarget : -9999;}
329  inline Float_t GetSumTarget2() const {return fTrainInfo? fTrainInfo->fSumTarget2: -9999;}
330 
331 
332  // reset the pruning validation data
333  void ResetValidationData( );
334 
335  // flag indicates whether this node is terminal
336  inline Bool_t IsTerminal() const { return fIsTerminalNode; }
337  inline void SetTerminal( Bool_t s = kTRUE ) { fIsTerminalNode = s; }
338  void PrintPrune( std::ostream& os ) const ;
339  void PrintRecPrune( std::ostream& os ) const;
340 
341  void SetCC(Double_t cc);
342  Double_t GetCC() const {return (fTrainInfo? fTrainInfo->fCC : -1.);}
343 
344  Float_t GetSampleMin(UInt_t ivar) const;
345  Float_t GetSampleMax(UInt_t ivar) const;
346  void SetSampleMin(UInt_t ivar, Float_t xmin);
347  void SetSampleMax(UInt_t ivar, Float_t xmax);
348 
349  static bool fgIsTraining; // static variable to flag training phase in which we need fTrainInfo
350  static UInt_t fgTmva_Version_Code; // set only when read from weightfile
351 
352  virtual Bool_t ReadDataRecord( std::istream& is, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
353  virtual void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
354  virtual void ReadContent(std::stringstream& s);
355 
356  protected:
357 
358  static MsgLogger& Log();
359 
360  std::vector<Double_t> fFisherCoeff; // the fisher coeff (offset at the last element)
361 
362  Float_t fCutValue; // cut value appplied on this node to discriminate bkg against sig
363  Bool_t fCutType; // true: if event variable > cutValue ==> signal , false otherwise
364  Short_t fSelector; // index of variable used in node selection (decision tree)
365 
366  Float_t fResponse; // response value in case of regression
367  Float_t fRMS; // response RMS of the regression node
368  Int_t fNodeType; // Type of node: -1 == Bkg-leaf, 1 == Signal-leaf, 0 = internal
369  Float_t fPurity; // the node purity
370 
371  Bool_t fIsTerminalNode; //! flag to set node as terminal (i.e., without deleting its descendants)
372 
374 
375  private:
376 
377  ClassDef(DecisionTreeNode,0); // Node for the Decision Tree
378  };
379 } // namespace TMVA
380 
381 #endif
Float_t GetRMS(void) const
DTNodeTrainingInfo * fTrainInfo
flag to set node as terminal (i.e., without deleting its descendants)
float xmin
Definition: THbookFile.cxx:93
void SetSelector(Short_t i)
#define TMVA_VERSION_CODE
Definition: Version.h:47
Float_t GetNEvents(void) const
Double_t Log(Double_t x)
Definition: TMath.h:759
float Float_t
Definition: RtypesCore.h:53
Float_t GetSumTarget() const
#define g(i)
Definition: RSha256.hxx:105
Float_t GetNBkgEvents_unweighted(void) const
std::vector< Float_t > fSampleMax
virtual DecisionTreeNode * GetParent() const
Float_t GetSumTarget2() const
void IncrementNEvents(Float_t nev)
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
virtual void SetRight(Node *r)
Bool_t GetCutType(void) const
Float_t GetNSigEvents(void) const
std::vector< Float_t > fSampleMin
Float_t GetNEvents_unboosted(void) const
void SetNSigEvents_unweighted(Float_t s)
void SetResponse(Float_t r)
void SetNBValidation(Double_t b)
void SetNFisherCoeff(Int_t nvars)
Double_t GetSubTreeR() const
Float_t GetSeparationIndex(void) const
#define ClassDef(name, id)
Definition: Rtypes.h:320
Float_t GetNBkgEvents(void) const
Float_t GetCutValue(void) const
void IncrementNBkgEvents(Float_t b)
Double_t GetNodeR() const
UInt_t GetNFisherCoeff() const
void SetSeparationGain(Float_t sep)
void SetNodeR(Double_t r)
void SetNBkgEvents(Float_t b)
void SetNSValidation(Double_t s)
void AddToSumTarget(Float_t t)
void SetNEvents(Float_t nev)
void SetSumTarget2(Float_t t2)
Float_t GetNSigEvents_unweighted(void) const
Int_t GetNodeType(void) const
void SetSubTreeR(Double_t r)
ROOT::R::TRInterface & r
Definition: Object.C:4
DTNodeTrainingInfo(const DTNodeTrainingInfo &n)
virtual void SetLeft(Node *l)
void SetAlpha(Double_t alpha)
Double_t GetFisherCoeff(Int_t ivar) const
void SetCutValue(Float_t c)
unsigned int UInt_t
Definition: RtypesCore.h:42
Float_t GetNBkgEvents_unboosted(void) const
short Short_t
Definition: RtypesCore.h:35
float xmax
Definition: THbookFile.cxx:93
void SetSumTarget(Float_t t)
virtual void SetParent(Node *p)
Double_t GetAlphaMinSubtree() const
void AddToSumTarget2(Float_t t2)
Float_t GetPurity(void) const
static UInt_t fgTmva_Version_Code
Double_t GetCC() const
void Print(std::ostream &os, const OptionType &opt)
double Double_t
Definition: RtypesCore.h:55
Double_t GetNBValidation() const
void IncrementNSigEvents(Float_t s)
void SetAlphaMinSubtree(Double_t g)
void SetNEvents_unboosted(Float_t nev)
static constexpr double s
void SetNSigEvents_unboosted(Float_t s)
Float_t GetNSigEvents_unboosted(void) const
void SetTerminal(Bool_t s=kTRUE)
Float_t GetNEvents_unweighted(void) const
void SetNSigEvents(Float_t s)
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
void SetNBkgEvents_unboosted(Float_t b)
void SetNBkgEvents_unweighted(Float_t b)
Abstract ClassifierFactory template that handles arbitrary types.
Node for the BinarySearch or Decision Trees.
Definition: Node.h:56
auto * l
Definition: textangle.C:4
Float_t GetResponse(void) const
Double_t GetNSValidation() const
you should not use this method at all Int_t Int_t Double_t Double_t Double_t Int_t Double_t Double_t Double_t Double_t b
Definition: TRolke.cxx:630
virtual DecisionTreeNode * GetLeft() const
virtual Node * CreateNode() const
virtual DecisionTreeNode * GetRight() const
#define c(i)
Definition: RSha256.hxx:101
std::vector< Double_t > fFisherCoeff
void SetSeparationIndex(Float_t sep)
Double_t GetAlpha() const
Short_t GetSelector() const
const Bool_t kTRUE
Definition: RtypesCore.h:87
const Int_t n
Definition: legend1.C:16
Float_t GetSeparationGain(void) const
void SetNEvents_unweighted(Float_t nev)