Logo ROOT   6.12/07
Reference Guide
DecisionTree.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Jan Therhaag, Eckhard von Toerne
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : DecisionTree *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation of a Decision Tree *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
16  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
17  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18  * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
19  * *
20  * Copyright (c) 2005-2011: *
21  * CERN, Switzerland *
22  * U. of Victoria, Canada *
23  * MPI-K Heidelberg, Germany *
24  * U. of Bonn, Germany *
25  * *
26  * Redistribution and use in source and binary forms, with or without *
27  * modification, are permitted according to the terms listed in LICENSE *
28  * (http://mva.sourceforge.net/license.txt) *
29  * *
30  **********************************************************************************/
31 
32 #ifndef ROOT_TMVA_DecisionTree
33 #define ROOT_TMVA_DecisionTree
34 
35 //////////////////////////////////////////////////////////////////////////
36 // //
37 // DecisionTree //
38 // //
39 // Implementation of a Decision Tree //
40 // //
41 //////////////////////////////////////////////////////////////////////////
42 
43 #include "TH2.h"
44 
45 #include "TMVA/Types.h"
46 #include "TMVA/DecisionTreeNode.h"
47 #include "TMVA/BinaryTree.h"
48 #include "TMVA/BinarySearchTree.h"
49 #include "TMVA/SeparationBase.h"
51 #include "TMVA/DataSetInfo.h"
52 
53 class TRandom3;
54 
55 namespace TMVA {
56 
57  class Event;
58 
59  class DecisionTree : public BinaryTree {
60 
61  private:
62 
63  static const Int_t fgRandomSeed; // set nonzero for debugging and zero for random seeds
64 
65  public:
66 
67  typedef std::vector<TMVA::Event*> EventList;
68  typedef std::vector<const TMVA::Event*> EventConstList;
69 
70  // the constructor needed for the "reading" of the decision tree from weight files
71  DecisionTree( void );
72 
73  // the constructor needed for constructing the decision tree via training with events
74  DecisionTree( SeparationBase *sepType, Float_t minSize,
75  Int_t nCuts, DataSetInfo* = NULL,
76  UInt_t cls =0,
77  Bool_t randomisedTree=kFALSE, Int_t useNvars=0, Bool_t usePoissonNvars=kFALSE,
78  UInt_t nMaxDepth=9999999,
79  Int_t iSeed=fgRandomSeed, Float_t purityLimit=0.5,
80  Int_t treeID = 0);
81 
82  // copy constructor
83  DecisionTree (const DecisionTree &d);
84 
85  virtual ~DecisionTree( void );
86 
87  // Retrieves the address of the root node
88  virtual DecisionTreeNode* GetRoot() const { return static_cast<TMVA::DecisionTreeNode*>(fRoot); }
89  virtual DecisionTreeNode * CreateNode(UInt_t) const { return new DecisionTreeNode(); } // returns a newly allocated, empty node (the UInt_t argument is unused)
90  virtual BinaryTree* CreateTree() const { return new DecisionTree(); } // returns a newly allocated, default-constructed DecisionTree
91  static DecisionTree* CreateFromXML(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE); // re-create a new tree from the given XML node
92  virtual const char* ClassName() const { return "DecisionTree"; } // class name as a string
93 
94  // build a tree by recursively splitting the nodes
95 
96  // UInt_t BuildTree( const EventList & eventSample,
97  // DecisionTreeNode *node = NULL);
98  UInt_t BuildTree( const EventConstList & eventSample,
99  DecisionTreeNode *node = NULL);
100  // determine how a node is split (which variable, which cut value)
101 
102  Double_t TrainNode( const EventConstList & eventSample, DecisionTreeNode *node ) { return TrainNodeFast( eventSample, node ); } // always delegates to TrainNodeFast
103  Double_t TrainNodeFast( const EventConstList & eventSample, DecisionTreeNode *node ); // decide how to split a node using the variable that gives the best separation
104  Double_t TrainNodeFull( const EventConstList & eventSample, DecisionTreeNode *node ); // find the single optimal cut on a single variable for this node
105  void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t & nVars); // NOTE(review): presumably selects the variable subset for randomised trees -- confirm
106  std::vector<Double_t> GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher); // Fisher coefficients for the sample and the nFisherVars variables mapped by mapVarInFisher
107 
108  // fill a tree that already has a given structure (just see how many signal/bkg
109  // events end up in each node)
110 
111  void FillTree( const EventList & eventSample);
112 
113  // fill the existing decision tree structure by passing an event
114  // in from the top node and seeing where it ends up
115  void FillEvent( const TMVA::Event & event,
116  TMVA::DecisionTreeNode *node );
117 
118  // returns: 1 = Signal (right), -1 = Bkg (left)
119 
120  Double_t CheckEvent( const TMVA::Event * , Bool_t UseYesNoLeaf = kFALSE ) const;
122 
123  // return the individual relative variable importance
124  std::vector< Double_t > GetVariableImportance();
125 
127 
128  // clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree
129 
130  void ClearTree();
131 
132  // set pruning method (NOTE(review): the SetPruneMethod declaration is missing from this listing)
135 
136  // recursive pruning of the tree, validation sample required for automatic pruning
137  Double_t PruneTree( const EventConstList* validationSample = NULL );
138 
139  // manage the pruning strength parameter (if < 0 -> automate the pruning process) (NOTE(review): the SetPruneStrength/GetPruneStrength declarations are missing from this listing)
142 
143  // apply pruning validation sample to a decision tree
144  void ApplyValidationSample( const EventConstList* validationSample ) const;
145 
146  // return the misclassification rate of a pruned tree
147  Double_t TestPrunedTreeQuality( const DecisionTreeNode* dt = NULL, Int_t mode=0 ) const;
148 
149  // pass a single validation event through a pruned decision tree
150  void CheckEventWithPrunedTree( const TMVA::Event* ) const;
151 
152  // calculate the normalization factor for a pruning validation sample
153  Double_t GetSumWeights( const EventConstList* validationSample ) const;
154 
157 
158  void DescendTree( Node *n = NULL ); // descend the tree below n to find all its leaf nodes
159  void SetParentTreeInNodes( Node *n = NULL ); // descend the tree below n (also fills the maximal depth reached in the tree)
160 
161  // retrieve node from the tree. Its position (up to a maximal tree depth of 64)
162  // is coded as a sequence of left-right moves starting from the root, coded as
163  // 0-1 bit patterns stored in the ULong_t "sequence", together with the "depth"
164  Node* GetNode( ULong_t sequence, UInt_t depth );
165 
166  UInt_t CleanTree(DecisionTreeNode *node=NULL); // remove last splits that result in two leaf nodes of the same type
167 
168  void PruneNode(TMVA::DecisionTreeNode *node); // prune away the subtree below the node
169 
170  // prune a node from the tree without deleting its descendants; allows one to
171  // effectively prune a tree many times without making deep copies (NOTE(review): the PruneNodeInPlace declaration is missing from this listing)
173 
175 
176 
178 
179  void SetTreeID(Int_t treeID){fTreeID = treeID;}; // set the tree's ID number
180  Int_t GetTreeID(){return fTreeID;}; // return the tree's ID number (NOTE(review): could be a const member function)
181 
188  inline void SetNVars(Int_t n){fNvars = n;} // set fNvars (NOTE(review): fNvars is UInt_t, so a negative n would wrap)
189 
190 
191  private:
192  // utility functions
193 
194  // calculate the Purity out of the number of sig and bkg events collected
195  // from individual samples.
196 
197  // calculates the purity S/(S+B) of a given event sample
198  Double_t SamplePurity(EventList eventSample); // NOTE(review): takes the event list by value, i.e. copies the vector
199 
200  UInt_t fNvars; // number of variables used to separate S and B
201  Int_t fNCuts; // number of grid points in variable cut scans
202  Bool_t fUseFisherCuts; // use multivariate splits using the Fisher criterion
203  Double_t fMinLinCorrForFisher; // the minimum linear correlation between two variables required for use in the Fisher criterion in node splitting
204  Bool_t fUseExclusiveVars; // individual variables already used in the Fisher criterion are no longer analysed individually for node splitting
205 
206  SeparationBase *fSepType; // the separation criterion
207  RegressionVariance *fRegType; // the separation criterion used in regression
208 
209  Double_t fMinSize; // min number of events in node
210  Double_t fMinNodeSize; // min fraction of training events in node
211  Double_t fMinSepGain; // minimum separation gain required to perform node splitting
212 
213  Bool_t fUseSearchTree; // cut scan done with binary trees or a simple event loop
214  Double_t fPruneStrength; // a parameter to set the "amount" of pruning; needs to be adjusted
215 
216  EPruneMethod fPruneMethod; // method used for pruning
217  Int_t fNNodesBeforePruning; // number of nodes before pruning (allows monitoring the before/after effect of pruning)
218 
219  Double_t fNodePurityLimit;// purity limit to decide whether a node is signal
220 
221  Bool_t fRandomisedTree; // choose at each node splitting a random set of variables
222  Int_t fUseNvars; // the number of variables used in randomised trees;
223  Bool_t fUsePoissonNvars; // use "fUseNvars" not as a fixed number but as the mean of a Poisson distribution in each split
224 
225  TRandom3 *fMyTrandom; // random number generator for randomised trees
226 
227  std::vector< Double_t > fVariableImportance; // the relative importance of the different variables
228 
229  UInt_t fMaxDepth; // maximum tree depth
230  UInt_t fSigClass; // class which is treated as signal when building the tree
231  static const Int_t fgDebugLevel = 0; // debug level determining some printout/control plots etc.
232  Int_t fTreeID; // just an ID number given to the tree; makes debugging easier as the tree knows its own ID
233 
234  Types::EAnalysisType fAnalysisType; // kClassification(=0=false) or kRegression(=1=true)
235 
237 
238 
239  ClassDef(DecisionTree,0); // implementation of a Decision Tree
240  };
241 
242 } // namespace TMVA
243 
244 #endif
void SetPruneMethod(EPruneMethod m=kCostComplexityPruning)
Definition: DecisionTree.h:134
DataSetInfo * fDataSetInfo
Definition: DecisionTree.h:236
virtual BinaryTree * CreateTree() const
Definition: DecisionTree.h:90
Random number generator class based on M.
Definition: TRandom3.h:27
#define TMVA_VERSION_CODE
Definition: Version.h:47
auto * m
Definition: textangle.C:8
float Float_t
Definition: RtypesCore.h:53
Double_t CheckEvent(const TMVA::Event *, Bool_t UseYesNoLeaf=kFALSE) const
the event e is put into the decision tree (starting at the root node) and the output is NodeType (sig...
Double_t GetNodePurityLimit() const
Definition: DecisionTree.h:156
UInt_t GetNNodes() const
Definition: BinaryTree.h:86
EPruneMethod fPruneMethod
Definition: DecisionTree.h:216
EAnalysisType
Definition: Types.h:125
Types::EAnalysisType GetAnalysisType(void)
Definition: DecisionTree.h:184
Calculate the "SeparationGain" for Regression analysis separation criteria used in various training a...
std::vector< Double_t > GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher)
calculate the fisher coefficients for the event sample and the variables used
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
void SetUseExclusiveVars(Bool_t t=kTRUE)
Definition: DecisionTree.h:187
Double_t fNodePurityLimit
Definition: DecisionTree.h:219
virtual ~DecisionTree(void)
destructor
Double_t TestPrunedTreeQuality(const DecisionTreeNode *dt=NULL, Int_t mode=0) const
return the misclassification rate of a pruned tree. A "pruned tree" may have set the variable "IsTermi...
virtual DecisionTreeNode * GetRoot() const
Definition: DecisionTree.h:88
void CheckEventWithPrunedTree(const TMVA::Event *) const
pass a single validation event through a pruned decision tree on the way down the tree...
void SetNodePurityLimit(Double_t p)
Definition: DecisionTree.h:155
std::vector< Double_t > GetVariableImportance()
Return the relative variable importance, normalized to all variables together having the importance 1...
void SetAnalysisType(Types::EAnalysisType t)
Definition: DecisionTree.h:183
std::vector< const TMVA::Event * > EventConstList
Definition: DecisionTree.h:68
Base class for BinarySearch and Decision Trees.
Definition: BinaryTree.h:62
#define ClassDef(name, id)
Definition: Rtypes.h:320
static const Int_t fgRandomSeed
Definition: DecisionTree.h:63
void FillTree(const EventList &eventSample)
fill the existing decision tree structure by filling events in from the top node and seeing where the...
Double_t SamplePurity(EventList eventSample)
calculates the purity S/(S+B) of a given event sample
std::vector< Double_t > fVariableImportance
Definition: DecisionTree.h:227
Double_t GetSumWeights(const EventConstList *validationSample) const
calculate the normalization factor for a pruning validation sample
Class that contains all the data information.
Definition: DataSetInfo.h:60
void SetTreeID(Int_t treeID)
Definition: DecisionTree.h:179
UInt_t CountLeafNodes(TMVA::Node *n=NULL)
return the number of terminal nodes in the sub-tree below Node n
Double_t TrainNodeFast(const EventConstList &eventSample, DecisionTreeNode *node)
Decide how to split a node using one of the variables that gives the best separation of signal/backgr...
void DescendTree(Node *n=NULL)
descend a tree to find all its leaf nodes
void FillEvent(const TMVA::Event &event, TMVA::DecisionTreeNode *node)
fill the existing decision tree structure by filling events in from the top node and seeing where the...
Double_t fPruneStrength
Definition: DecisionTree.h:214
Bool_t DoRegression() const
Definition: DecisionTree.h:182
Double_t fMinLinCorrForFisher
Definition: DecisionTree.h:203
void SetNVars(Int_t n)
Definition: DecisionTree.h:188
void SetMinLinCorrForFisher(Double_t min)
Definition: DecisionTree.h:186
UInt_t CleanTree(DecisionTreeNode *node=NULL)
remove those last splits that result in two leaf nodes that are both of the same type (i.e.
virtual DecisionTreeNode * CreateNode(UInt_t) const
Definition: DecisionTree.h:89
Int_t GetNNodesBeforePruning()
Definition: DecisionTree.h:174
void SetPruneStrength(Double_t p)
Definition: DecisionTree.h:140
void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t &nVars)
Implementation of a Decision Tree.
Definition: DecisionTree.h:59
unsigned int UInt_t
Definition: RtypesCore.h:42
Double_t TrainNodeFull(const EventConstList &eventSample, DecisionTreeNode *node)
train a node by finding the single optimal cut for a single variable that best separates signal and b...
void SetParentTreeInNodes(Node *n=NULL)
descend a tree to find all its leaf nodes, fill max depth reached in the tree at the same time...
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
void PruneNodeInPlace(TMVA::DecisionTreeNode *node)
prune a node temporarily (without actually deleting its descendants which allows testing the pruned t...
TMVA::DecisionTreeNode * GetEventNode(const TMVA::Event &e) const
get the pointer to the leaf node where a particular event ends up in...
std::vector< TMVA::Event * > EventList
Definition: DecisionTree.h:67
void SetUseFisherCuts(Bool_t t=kTRUE)
Definition: DecisionTree.h:185
void ApplyValidationSample(const EventConstList *validationSample) const
run the validation sample through the (pruned) tree and fill in the nodes the variables NSValidation ...
const Bool_t kFALSE
Definition: RtypesCore.h:88
TRandom3 * fMyTrandom
Definition: DecisionTree.h:225
double Double_t
Definition: RtypesCore.h:55
Node * GetNode(ULong_t sequence, UInt_t depth)
retrieve node from the tree.
static const Int_t fgDebugLevel
Definition: DecisionTree.h:231
void ClearTree()
clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree ...
Types::EAnalysisType fAnalysisType
Definition: DecisionTree.h:234
unsigned long ULong_t
Definition: RtypesCore.h:51
static DecisionTree * CreateFromXML(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
re-create a new tree (decision tree or search tree) from XML
you should not use this method at all Int_t Int_t Double_t Double_t Double_t e
Definition: TRolke.cxx:630
RegressionVariance * fRegType
Definition: DecisionTree.h:207
SeparationBase * fSepType
Definition: DecisionTree.h:206
Double_t PruneTree(const EventConstList *validationSample=NULL)
prune (get rid of internal nodes) the Decision tree to avoid overtraining. Several different pruning m...
Abstract ClassifierFactory template that handles arbitrary types.
Node for the BinarySearch or Decision Trees.
Definition: Node.h:56
UInt_t BuildTree(const EventConstList &eventSample, DecisionTreeNode *node=NULL)
building the decision tree by recursively calling the splitting of one (root-) node into two daughter...
DecisionTree(void)
default constructor using the GiniIndex as separation criterion, no restrictions on minimum number of ...
Double_t GetPruneStrength() const
Definition: DecisionTree.h:141
const Bool_t kTRUE
Definition: RtypesCore.h:87
virtual const char * ClassName() const
Definition: DecisionTree.h:92
const Int_t n
Definition: legend1.C:16
Double_t TrainNode(const EventConstList &eventSample, DecisionTreeNode *node)
Definition: DecisionTree.h:102
void PruneNode(TMVA::DecisionTreeNode *node)
prune away the subtree below the node