Logo ROOT  
Reference Guide
CCTreeWrapper.h
Go to the documentation of this file.
1
2/**********************************************************************************
3 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
4 * Package: TMVA *
5 * Class : CCTreeWrapper *
6 * Web : http://tmva.sourceforge.net *
7 * *
8 * Description: a light wrapper of a decision tree, used to perform cost *
9 * complexity pruning "in-place" *
10 * *
11 * Author: Doug Schouten (dschoute@sfu.ca) *
12 * *
13 * *
14 * Copyright (c) 2007: *
15 * CERN, Switzerland *
16 * MPI-K Heidelberg, Germany *
17 * U. of Texas at Austin, USA *
18 * *
19 * Redistribution and use in source and binary forms, with or without *
20 * modification, are permitted according to the terms listed in LICENSE *
21 * (http://tmva.sourceforge.net/LICENSE) *
22 **********************************************************************************/
23
24#ifndef ROOT_TMVA_CCTreeWrapper
25#define ROOT_TMVA_CCTreeWrapper
26
27#include "TMVA/Event.h"
28#include "TMVA/SeparationBase.h"
29#include "TMVA/DecisionTree.h"
30#include "TMVA/DataSet.h"
31#include "TMVA/Version.h"
32
33
34namespace TMVA {
35
37
38 public:
39
40 typedef std::vector<Event*> EventList;
41
42 /////////////////////////////////////////////////////////////
43 // CCTreeNode - a light wrapper of a decision tree node //
44 // //
45 /////////////////////////////////////////////////////////////
46
47 class CCTreeNode : virtual public Node {
48
49 public:
50
51 CCTreeNode( DecisionTreeNode* n = NULL );
52 virtual ~CCTreeNode( );
53
54 virtual Node* CreateNode() const { return new CCTreeNode(); }
55
56 // set |~T_t|, the number of terminal descendants of node t
57 inline void SetNLeafDaughters( Int_t N ) { fNLeafDaughters = (N > 0 ? N : 0); }
58
59 // return |~T_t|
60 inline Int_t GetNLeafDaughters() const { return fNLeafDaughters; }
61
62 // set R(t), the node resubstitution estimate (Gini, misclassification, etc.) for the node t
64
65 // return R(t) for node t
67
68 // set R(T_t) = sum[t' in ~T_t]{ R(t) }, the resubstitution estimate for the branch rooted at
69 // node t (it is an estimate because it is calculated from the training dataset, i.e., the original tree)
70 inline void SetResubstitutionEstimate( Double_t R ) { fResubstitutionEstimate = (R >= 0 ? R : 0.0); }
71
72 // return R(T_t) for node t
74
75 // set the critical point of alpha
76 // R(t) - R(T_t)
77 // alpha_c < ------------- := g(t)
78 // |~T_t| - 1
79 // which is the value of alpha such that the branch rooted at node t is pruned
80 inline void SetAlphaC( Double_t alpha ) { fAlphaC = alpha; }
81
82 // get the critical alpha value for this node
83 inline Double_t GetAlphaC( ) const { return fAlphaC; }
84
85 // set the minimum critical alpha value for descendants of node t ( G(t) = min(alpha_c, g(t_l), g(t_r)) )
86 inline void SetMinAlphaC( Double_t alpha ) { fMinAlphaC = alpha; }
87
88 // get the minimum critical alpha value
89 inline Double_t GetMinAlphaC( ) const { return fMinAlphaC; }
90
91 // get the pointer to the wrapped DT node
92 inline DecisionTreeNode* GetDTNode( ) const { return fDTNode; }
93
94 // get pointers to children, mother in the CC tree
95 inline CCTreeNode* GetLeftDaughter( ) { return dynamic_cast<CCTreeNode*>(GetLeft()); }
96 inline CCTreeNode* GetRightDaughter( ) { return dynamic_cast<CCTreeNode*>(GetRight()); }
97 inline CCTreeNode* GetMother( ) { return dynamic_cast<CCTreeNode*>(GetParent()); }
98
99 // printout of the node (can be read in with ReadDataRecord)
100 virtual void Print( std::ostream& os ) const;
101
102 // recursive printout of the node and its daughters
103 virtual void PrintRec ( std::ostream& os ) const;
104
105 virtual void AddAttributesToNode(void* node) const;
106 virtual void AddContentToNode(std::stringstream& s) const;
107
108
109 // test event if it decends the tree at this node to the right
110 inline virtual Bool_t GoesRight( const Event& e ) const { return (GetDTNode() != NULL ?
111 GetDTNode()->GoesRight(e) : false); }
112
113 // test event if it decends the tree at this node to the left
114 inline virtual Bool_t GoesLeft ( const Event& e ) const { return (GetDTNode() != NULL ?
115 GetDTNode()->GoesLeft(e) : false); }
116 // initialize a node from a data record
117 virtual void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
118 virtual void ReadContent(std::stringstream& s);
119 virtual Bool_t ReadDataRecord( std::istream& in, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
120
121 private:
122
123 Int_t fNLeafDaughters; //! number of terminal descendants
124 Double_t fNodeResubstitutionEstimate; //! R(t) = misclassification rate for node t
125 Double_t fResubstitutionEstimate; //! R(T_t) = sum[t' in ~T_t]{ R(t) }
126 Double_t fAlphaC; //! critical point, g(t) = alpha_c(t)
127 Double_t fMinAlphaC; //! G(t), minimum critical point of t and its descendants
128 DecisionTreeNode* fDTNode; //! pointer to wrapped node in the decision tree
129 };
130
131 CCTreeWrapper( DecisionTree* T, SeparationBase* qualityIndex );
133
134 // return the decision tree output for an event
135 Double_t CheckEvent( const TMVA::Event & e, Bool_t useYesNoLeaf = false );
136 // return the misclassification rate of a pruned tree for a validation event sample
137 Double_t TestTreeQuality( const EventList* validationSample );
138 Double_t TestTreeQuality( const DataSet* validationSample );
139
140 // remove the branch rooted at node t
141 void PruneNode( CCTreeNode* t );
142 // initialize the node t and all its descendants
143 void InitTree( CCTreeNode* t );
144
145 // return the root node for this tree
146 CCTreeNode* GetRoot() { return fRoot; }
147 private:
148 SeparationBase* fQualityIndex; //! pointer to the used quality index calculator
149 DecisionTree* fDTParent; //! pointer to underlying DecisionTree
150 CCTreeNode* fRoot; //! the root node of the (wrapped) decision Tree
151 };
152
153}
154
155#endif
156
157
158
#define R(a, b, c, d, e, f, g, h, i)
Definition: RSha256.hxx:110
#define e(i)
Definition: RSha256.hxx:103
double Double_t
Definition: RtypesCore.h:57
#define N
#define TMVA_VERSION_CODE
Definition: Version.h:47
Double_t fMinAlphaC
G(t), minimum critical point of t and its descendants.
Double_t GetNodeResubstitutionEstimate() const
Definition: CCTreeWrapper.h:66
virtual void ReadContent(std::stringstream &s)
CCTreeNode(DecisionTreeNode *n=NULL)
constructor of the CCTreeNode
Double_t fNodeResubstitutionEstimate
R(t) = misclassification rate for node t.
virtual void AddAttributesToNode(void *node) const
void SetMinAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:86
DecisionTreeNode * GetDTNode() const
Definition: CCTreeWrapper.h:92
void SetResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:70
Double_t fAlphaC
critical point, g(t) = alpha_c(t).
virtual Bool_t GoesRight(const Event &e) const
Double_t GetResubstitutionEstimate() const
Definition: CCTreeWrapper.h:73
virtual void PrintRec(std::ostream &os) const
recursive printout of the node and its daughters
Double_t fResubstitutionEstimate
R(T_t) = sum[t' in ~T_t]{ R(t) }.
virtual Bool_t GoesLeft(const Event &e) const
DecisionTreeNode * fDTNode
pointer to wrapped node in the decision tree.
virtual void AddContentToNode(std::stringstream &s) const
virtual Bool_t ReadDataRecord(std::istream &in, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
initialize a node from a data record
void SetAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:80
void SetNodeResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:63
virtual ~CCTreeNode()
destructor of a CCTreeNode
virtual void ReadAttributes(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
virtual Node * CreateNode() const
Definition: CCTreeWrapper.h:54
virtual void Print(std::ostream &os) const
printout of the node (can be read in with ReadDataRecord)
SeparationBase * fQualityIndex
std::vector< Event * > EventList
Definition: CCTreeWrapper.h:40
DecisionTree * fDTParent
pointer to underlying DecisionTree.
CCTreeNode * fRoot
the root node of the (wrapped) decision Tree.
CCTreeNode * GetRoot()
Double_t TestTreeQuality(const EventList *validationSample)
return the misclassification rate of a pruned tree for a validation event sample using an EventList
~CCTreeWrapper()
destructor
void InitTree(CCTreeNode *t)
initialize the node t and all its descendants
void PruneNode(CCTreeNode *t)
remove the branch rooted at node t
CCTreeWrapper(DecisionTree *T, SeparationBase *qualityIndex)
constructor
Double_t CheckEvent(const TMVA::Event &e, Bool_t useYesNoLeaf=false)
return the decision tree output for an event
Class that contains all the data information.
Definition: DataSet.h:69
Implementation of a Decision Tree.
Definition: DecisionTree.h:64
Node for the BinarySearch or Decision Trees.
Definition: Node.h:56
virtual Node * GetLeft() const
Definition: Node.h:87
virtual Node * GetParent() const
Definition: Node.h:89
virtual Node * GetRight() const
Definition: Node.h:88
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
const Int_t n
Definition: legend1.C:16
double T(double x)
Definition: ChebyshevPol.h:34
static constexpr double s
create variable transformations