ROOT logo
ROOT » TMVA » TMVA::DecisionTree

class TMVA::DecisionTree: public TMVA::BinaryTree


 Implementation of a Decision Tree

 In a decision tree successive decision nodes are used to categorize the
 events out of the sample as either signal or background. Each node
 uses only a single discriminating variable to decide if the event is
 signal-like ("goes right") or background-like ("goes left"). This
 forms a tree-like structure with "baskets" at the end (leaf nodes),
 and an event is classified as either signal or background according to
 whether the basket where it ends up has been classified signal or
 background during the training. Training of a decision tree is the
 process to define the "cut criteria" for each node. The training
 starts with the root node. Here one takes the full training event
 sample and selects the variable and corresponding cut value that gives
 the best separation between signal and background at this stage. Using
 this cut criterion, the sample is then divided into two subsamples, a
 signal-like (right) and a background-like (left) sample. Two new nodes
 are then created for each of the two sub-samples and they are
 constructed using the same mechanism as described for the root
 node. The division is stopped once a certain node has reached either a
 minimum number of events, or a minimum or maximum signal purity. These
 leaf nodes are then called "signal" or "background" if they contain
 more signal or background events, respectively, from the training sample.

Function Members (Methods)

public:
virtual~DecisionTree()
virtual void*TMVA::BinaryTree::AddXMLTo(void* parent) const
voidApplyValidationSample(const TMVA::DecisionTree::EventConstList* validationSample) const
UInt_tBuildTree(const TMVA::DecisionTree::EventConstList& eventSample, TMVA::DecisionTreeNode* node = NULL)
Double_tCheckEvent(const TMVA::Event*, Bool_t UseYesNoLeaf = kFALSE) const
voidCheckEventWithPrunedTree(const TMVA::Event*) const
static TClass*Class()
virtual const char*ClassName() const
UInt_tCleanTree(TMVA::DecisionTreeNode* node = NULL)
voidClearTree()
UInt_tCountLeafNodes(TMVA::Node* n = NULL)
UInt_tTMVA::BinaryTree::CountNodes(TMVA::Node* n = NULL)
static TMVA::DecisionTree*CreateFromXML(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE)
virtual TMVA::DecisionTreeNode*CreateNode(UInt_t) const
virtual TMVA::BinaryTree*CreateTree() const
TMVA::DecisionTreeDecisionTree()
TMVA::DecisionTreeDecisionTree(const TMVA::DecisionTree& d)
TMVA::DecisionTreeDecisionTree(TMVA::SeparationBase* sepType, Float_t minSize, Int_t nCuts, TMVA::DataSetInfo* = NULL, UInt_t cls = 0, Bool_t randomisedTree = kFALSE, Int_t useNvars = 0, Bool_t usePoissonNvars = kFALSE, UInt_t nMaxDepth = 9999999, Int_t iSeed = fgRandomSeed, Float_t purityLimit = 0.5, Int_t treeID = 0)
voidDescendTree(TMVA::Node* n = NULL)
Bool_tDoRegression() const
voidFillEvent(const TMVA::Event& event, TMVA::DecisionTreeNode* node)
voidFillTree(const TMVA::DecisionTree::EventList& eventSample)
TMVA::Types::EAnalysisTypeGetAnalysisType()
TMVA::DecisionTreeNode*GetEventNode(const TMVA::Event& e) const
vector<Double_t>GetFisherCoefficients(const TMVA::DecisionTree::EventConstList& eventSample, UInt_t nFisherVars, UInt_t* mapVarInFisher)
TMVA::Node*TMVA::BinaryTree::GetLeftDaughter(TMVA::Node* n)
UInt_tTMVA::BinaryTree::GetNNodes() const
Int_tGetNNodesBeforePruning()
TMVA::Node*GetNode(ULong_t sequence, UInt_t depth)
Double_tGetNodePurityLimit() const
Double_tGetPruneStrength() const
voidGetRandomisedVariables(Bool_t* useVariable, UInt_t* variableMap, UInt_t& nVars)
TMVA::Node*TMVA::BinaryTree::GetRightDaughter(TMVA::Node* n)
virtual TMVA::DecisionTreeNode*GetRoot() const
Double_tGetSumWeights(const TMVA::DecisionTree::EventConstList* validationSample) const
UInt_tTMVA::BinaryTree::GetTotalTreeDepth() const
Int_tGetTreeID()
vector<Double_t>GetVariableImportance()
Double_tGetVariableImportance(UInt_t ivar)
virtual TClass*IsA() const
TMVA::DecisionTree&operator=(const TMVA::DecisionTree&)
virtual voidTMVA::BinaryTree::Print(ostream& os) const
voidPruneNode(TMVA::DecisionTreeNode* node)
voidPruneNodeInPlace(TMVA::DecisionTreeNode* node)
Double_tPruneTree(const TMVA::DecisionTree::EventConstList* validationSample = NULL)
virtual voidTMVA::BinaryTree::Read(istream& istr, UInt_t tmva_Version_Code = TMVA_VERSION_CODE)
virtual voidTMVA::BinaryTree::ReadXML(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE)
voidSetAnalysisType(TMVA::Types::EAnalysisType t)
voidSetMinLinCorrForFisher(Double_t min)
voidSetNodePurityLimit(Double_t p)
voidSetNVars(Int_t n)
voidSetParentTreeInNodes(TMVA::Node* n = NULL)
voidSetPruneMethod(TMVA::DecisionTree::EPruneMethod m = kCostComplexityPruning)
voidSetPruneStrength(Double_t p)
voidTMVA::BinaryTree::SetRoot(TMVA::Node* r)
voidTMVA::BinaryTree::SetTotalTreeDepth(Int_t depth)
voidTMVA::BinaryTree::SetTotalTreeDepth(TMVA::Node* n = NULL)
voidSetTreeID(Int_t treeID)
voidSetUseExclusiveVars(Bool_t t = kTRUE)
voidSetUseFisherCuts(Bool_t t = kTRUE)
virtual voidShowMembers(TMemberInspector&)
virtual voidStreamer(TBuffer&)
voidStreamerNVirtual(TBuffer& ClassDef_StreamerNVirtual_b)
Double_tTestPrunedTreeQuality(const TMVA::DecisionTreeNode* dt = NULL, Int_t mode = 0) const
Double_tTrainNode(const TMVA::DecisionTree::EventConstList& eventSample, TMVA::DecisionTreeNode* node)
Double_tTrainNodeFast(const TMVA::DecisionTree::EventConstList& eventSample, TMVA::DecisionTreeNode* node)
Double_tTrainNodeFull(const TMVA::DecisionTree::EventConstList& eventSample, TMVA::DecisionTreeNode* node)

Data Members

public:
enum EPruneMethod { kExpectedErrorPruning
kCostComplexityPruning
kNoPruning
};
protected:
UInt_tTMVA::BinaryTree::fDepthmaximal depth in tree reached
UInt_tTMVA::BinaryTree::fNNodestotal number of nodes in the tree (counted)
TMVA::Node*TMVA::BinaryTree::fRootthe root node of the tree
static TMVA::MsgLogger*TMVA::BinaryTree::fgLoggermessage logger, static to save resources
private:
TMVA::Types::EAnalysisTypefAnalysisTypekClassification(=0=false) or kRegression(=1=true)
TMVA::DataSetInfo*fDataSetInfo
UInt_tfMaxDepthmax depth
Double_tfMinLinCorrForFisherthe minimum linear correlation between two variables demanded for use in the Fisher criterion in node splitting
Double_tfMinNodeSizemin fraction of training events in node
Double_tfMinSepGainmin number of separation gain to perform node splitting
Double_tfMinSizemin number of events in node
TRandom3*fMyTrandomrandom number generator for randomised trees
Int_tfNCutsnumber of grid point in variable cut scans
Int_tfNNodesBeforePruningremember this one (in case of pruning, it allows to monitor the before/after
Double_tfNodePurityLimitpurity limit to decide whether a node is signal
UInt_tfNvarsnumber of variables used to separate S and B
TMVA::DecisionTree::EPruneMethodfPruneMethodmethod used for pruning
Double_tfPruneStrengtha parameter to set the "amount" of pruning..needs to be adjusted
Bool_tfRandomisedTreechoose at each node splitting a random set of variables
TMVA::RegressionVariance*fRegTypethe separation criterion used in regression
TMVA::SeparationBase*fSepTypethe separation criterion
UInt_tfSigClassclass which is treated as signal when building the tree
Int_tfTreeIDjust an ID number given to the tree.. makes debugging easier as tree knows who he is.
Bool_tfUseExclusiveVarsindividual variables already used in the Fisher criterion are not analysed individually any more for node splitting
Bool_tfUseFisherCutsuse multivariate splits using the Fisher criterion
Int_tfUseNvarsthe number of variables used in randomised trees;
Bool_tfUsePoissonNvarsuse "fUseNvars" not as a fixed number but as the mean of a Poisson distribution in each split
Bool_tfUseSearchTreecut scan done with binary trees or simple event loop.
vector<Double_t>fVariableImportancethe relative importance of the different variables
static const Int_tfgDebugLeveldebug level determining some printout/control plots etc.
static const Int_tfgRandomSeedset nonzero for debugging and zero for random seeds

Class Charts

Inheritance Inherited Members Includes Libraries
Class Charts

Function documentation

DecisionTree()
 default constructor using the GiniIndex as separation criterion,
 no restrictions on the minimum number of events in a leaf node or the
 separation gain in the node splitting
DecisionTree(TMVA::SeparationBase* sepType, Float_t minSize, Int_t nCuts, TMVA::DataSetInfo* = NULL, UInt_t cls = 0, Bool_t randomisedTree = kFALSE, Int_t useNvars = 0, Bool_t usePoissonNvars = kFALSE, UInt_t nMaxDepth = 9999999, Int_t iSeed = fgRandomSeed, Float_t purityLimit = 0.5, Int_t treeID = 0)
 constructor specifying the separation type, the min number of
 events in a node that is still subjected to further splitting, the
 number of bins in the grid used in applying the cut for the node
 splitting.
DecisionTree(const TMVA::DecisionTree& d)
 copy constructor that creates a true copy, i.e. a completely independent tree
 the node copy will recursively copy all the nodes
~DecisionTree()
 destructor
void SetParentTreeInNodes(TMVA::Node* n = NULL)
 descend a tree to find all its leaf nodes, fill max depth reached in the
 tree at the same time.
TMVA::DecisionTree* CreateFromXML(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE)
 re-create a new tree (decision tree or search tree) from XML
UInt_t BuildTree(const TMVA::DecisionTree::EventConstList& eventSample, TMVA::DecisionTreeNode* node = NULL)
 building the decision tree by recursively calling the splitting of
 one (root-) node into two daughter nodes (returns the number of nodes)
void FillTree(const TMVA::DecisionTree::EventList& eventSample)
 fill the existing decision tree structure by filling events
 in from the top node and seeing where they happen to end up
void FillEvent(const TMVA::Event& event, TMVA::DecisionTreeNode* node)
 fill the existing decision tree structure by filling events
 in from the top node and seeing where they happen to end up
void ClearTree()
 clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree
UInt_t CleanTree(TMVA::DecisionTreeNode* node = NULL)
 remove those last splits that result in two leaf nodes that
 are both of the same type (i.e. both signal or both background);
 this of course is only a reasonable thing to do when you use
 "YesOrNo" leaves, while it might lose something if you use the
 purity information in the nodes.
 --> hence I don't call it automatically in the tree building
Double_t PruneTree(const TMVA::DecisionTree::EventConstList* validationSample = NULL)
 prune (get rid of internal nodes) the decision tree to avoid overtraining;
 several different pruning methods can be applied as selected by the
 variable "fPruneMethod".
void ApplyValidationSample(const TMVA::DecisionTree::EventConstList* validationSample) const
 run the validation sample through the (pruned) tree and fill in the nodes
 the variables NSValidation and NBValidation (i.e. how many of the signal
 and background events from the validation sample end up in each node).
 This is then later used when asking for the "tree quality" ..
Double_t TestPrunedTreeQuality(const TMVA::DecisionTreeNode* dt = NULL, Int_t mode = 0) const
 return the misclassification rate of a pruned tree
 a "pruned tree" may have set the variable "IsTerminal" to "arbitrary" at
 any node, hence this tree quality testing will stop there, hence test
 the pruned tree (while the full tree is still in place for normal/later use)
void CheckEventWithPrunedTree(const TMVA::Event* ) const
 pass a single validation event through a pruned decision tree
 on the way down the tree, fill in all the "intermediate" information
 that would normally be there from training.
Double_t GetSumWeights(const TMVA::DecisionTree::EventConstList* validationSample) const
 calculate the normalization factor for a pruning validation sample
UInt_t CountLeafNodes(TMVA::Node* n = NULL)
 return the number of terminal nodes in the sub-tree below Node n
void DescendTree(TMVA::Node* n = NULL)
 descend a tree to find all its leaf nodes
void PruneNode(TMVA::DecisionTreeNode* node)
 prune away the subtree below the node
void PruneNodeInPlace(TMVA::DecisionTreeNode* node)
 prune a node temporarily (without actually deleting its descendants),
 which allows testing the pruned tree quality for many different
 pruning stages without "touching" the tree.
TMVA::Node* GetNode(ULong_t sequence, UInt_t depth)
 retrieve a node from the tree. Its position (up to a maximal tree depth of 64)
 is coded as a sequence of left-right moves starting from the root, coded as
 0-1 bit patterns stored in the "long-integer" (i.e. 0:left ; 1:right)
void GetRandomisedVariables(Bool_t* useVariable, UInt_t* variableMap, UInt_t& nVars)
Double_t TrainNodeFast(const TMVA::DecisionTree::EventConstList& eventSample, TMVA::DecisionTreeNode* node)
 Decide how to split a node using one of the variables that gives
 the best separation of signal/background. In order to do this, for each
 variable a scan of the different cut values in a grid (grid = fNCuts) is
 performed and the resulting separation gains are compared.
 in addition to the individual variables, one can also ask for a fisher
 discriminant being built out of (some) of the variables and used as a
 possible multivariate split.
std::vector<Double_t> GetFisherCoefficients(const TMVA::DecisionTree::EventConstList& eventSample, UInt_t nFisherVars, UInt_t* mapVarInFisher)
 calculate the fisher coefficients for the event sample and the variables used
Double_t TrainNodeFull(const TMVA::DecisionTree::EventConstList& eventSample, TMVA::DecisionTreeNode* node)
TMVA::DecisionTreeNode* GetEventNode(const TMVA::Event& e) const
 get the pointer to the leaf node where a particular event ends up in...
 (used in gradient boosting)
Double_t CheckEvent(const TMVA::Event* , Bool_t UseYesNoLeaf = kFALSE) const
 the event e is put into the decision tree (starting at the root node)
 and the output is NodeType (signal) or (background) of the final node (basket)
 in which the given events ends up. I.e. the result of the classification if
 the event for this decision tree.
Double_t SamplePurity(TMVA::DecisionTree::EventList eventSample)
 calculates the purity S/(S+B) of a given event sample
vector< Double_t > GetVariableImportance()
 Return the relative variable importance, normalized to all
 variables together having the importance 1. The importance in
 evaluated as the total separation-gain that this variable had in
 the decision trees (weighted by the number of events)
Double_t GetVariableImportance(UInt_t ivar)
 returns the relative importance of variable ivar
DecisionTreeNode* GetRoot() const
 Retrieves the address of the root node
{ return dynamic_cast<TMVA::DecisionTreeNode*>(fRoot); }
DecisionTreeNode * CreateNode(UInt_t ) const
{ return new DecisionTreeNode(); }
BinaryTree* CreateTree() const
{ return new DecisionTree(); }
const char* ClassName() const
{ return "DecisionTree"; }
Double_t TrainNode(const TMVA::DecisionTree::EventConstList& eventSample, TMVA::DecisionTreeNode* node)
 determine the way how a node is split (which variable, which cut value)
{ return TrainNodeFast( eventSample, node ); }
void SetPruneMethod(TMVA::DecisionTree::EPruneMethod m = kCostComplexityPruning)
{ fPruneMethod = m; }
void SetPruneStrength(Double_t p)
 manage the pruning strength parameter (iff < 0 -> automate the pruning process)
Double_t GetPruneStrength() const
{ return fPruneStrength; }
void SetNodePurityLimit(Double_t p)
Double_t GetNodePurityLimit() const
{ return fNodePurityLimit; }
Int_t GetNNodesBeforePruning()
void SetTreeID(Int_t treeID)
{fTreeID = treeID;}
Int_t GetTreeID()
{return fTreeID;}
Bool_t DoRegression() const
{ return fAnalysisType == Types::kRegression; }
void SetAnalysisType(TMVA::Types::EAnalysisType t)
Types::EAnalysisType GetAnalysisType( void )
{ return fAnalysisType;}
void SetUseFisherCuts(Bool_t t = kTRUE)
void SetMinLinCorrForFisher(Double_t min)
void SetUseExclusiveVars(Bool_t t = kTRUE)
void SetNVars(Int_t n)
{fNvars = n;}