Logo ROOT   6.12/07
Reference Guide
CCTreeWrapper.h
Go to the documentation of this file.
1 
2 /**********************************************************************************
3  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
4  * Package: TMVA *
5  * Class : CCTreeWrapper *
6  * Web : http://tmva.sourceforge.net *
7  * *
8  * Description: a light wrapper of a decision tree, used to perform *
9  * "in-place" cost complexity pruning *
10  * *
11  * Author: Doug Schouten (dschoute@sfu.ca) *
12  * *
13  * *
14  * Copyright (c) 2007: *
15  * CERN, Switzerland *
16  * MPI-K Heidelberg, Germany *
17  * U. of Texas at Austin, USA *
18  * *
19  * Redistribution and use in source and binary forms, with or without *
20  * modification, are permitted according to the terms listed in LICENSE *
21  * (http://tmva.sourceforge.net/LICENSE) *
22  **********************************************************************************/
23 
24 #ifndef ROOT_TMVA_CCTreeWrapper
25 #define ROOT_TMVA_CCTreeWrapper
26 
27 #include "TMVA/Event.h"
28 #include "TMVA/SeparationBase.h"
29 #include "TMVA/DecisionTree.h"
30 #include "TMVA/DataSet.h"
31 #include "TMVA/Version.h"
32 
33 
34 namespace TMVA {
35 
36  class CCTreeWrapper {
37 
38  public:
39 
40  typedef std::vector<Event*> EventList;
41 
42  /////////////////////////////////////////////////////////////
43  // CCTreeNode - a light wrapper of a decision tree node //
44  // //
45  /////////////////////////////////////////////////////////////
46 
47  class CCTreeNode : virtual public Node {
48 
49  public:
50 
51  CCTreeNode( DecisionTreeNode* n = NULL );
52  virtual ~CCTreeNode( );
53 
54  virtual Node* CreateNode() const { return new CCTreeNode(); }
55 
56  // set |~T_t|, the number of terminal descendants of node t
57  inline void SetNLeafDaughters( Int_t N ) { fNLeafDaughters = (N > 0 ? N : 0); }
58 
59  // return |~T_t|
60  inline Int_t GetNLeafDaughters() const { return fNLeafDaughters; }
61 
62  // set R(t), the node resubstitution estimate (Gini, misclassification, etc.) for the node t
63  inline void SetNodeResubstitutionEstimate( Double_t R ) { fNodeResubstitutionEstimate = (R >= 0 ? R : 0.0); }
64 
65  // return R(t) for node t
67 
68  // set R(T_t) = sum[t' in ~T_t]{ R(t) }, the resubstitution estimate for the branch rooted at
69  // node t (it is an estimate because it is calculated from the training dataset, i.e., the original tree)
70  inline void SetResubstitutionEstimate( Double_t R ) { fResubstitutionEstimate = (R >= 0 ? R : 0.0); }
71 
72  // return R(T_t) for node t
74 
75  // set the critical point of alpha
76  // R(t) - R(T_t)
77  // alpha_c < ------------- := g(t)
78  // |~T_t| - 1
79  // which is the value of alpha such that the branch rooted at node t is pruned
80  inline void SetAlphaC( Double_t alpha ) { fAlphaC = alpha; }
81 
82  // get the critical alpha value for this node
83  inline Double_t GetAlphaC( ) const { return fAlphaC; }
84 
85  // set the minimum critical alpha value for descendants of node t ( G(t) = min(alpha_c, g(t_l), g(t_r)) )
86  inline void SetMinAlphaC( Double_t alpha ) { fMinAlphaC = alpha; }
87 
88  // get the minimum critical alpha value
89  inline Double_t GetMinAlphaC( ) const { return fMinAlphaC; }
90 
91  // get the pointer to the wrapped DT node
92  inline DecisionTreeNode* GetDTNode( ) const { return fDTNode; }
93 
94  // get pointers to children, mother in the CC tree
95  inline CCTreeNode* GetLeftDaughter( ) { return dynamic_cast<CCTreeNode*>(GetLeft()); }
96  inline CCTreeNode* GetRightDaughter( ) { return dynamic_cast<CCTreeNode*>(GetRight()); }
97  inline CCTreeNode* GetMother( ) { return dynamic_cast<CCTreeNode*>(GetParent()); }
98 
99  // printout of the node (can be read in with ReadDataRecord)
100  virtual void Print( std::ostream& os ) const;
101 
102  // recursive printout of the node and its daughters
103  virtual void PrintRec ( std::ostream& os ) const;
104 
105  virtual void AddAttributesToNode(void* node) const;
106  virtual void AddContentToNode(std::stringstream& s) const;
107 
108 
109  // test event if it decends the tree at this node to the right
110  inline virtual Bool_t GoesRight( const Event& e ) const { return (GetDTNode() != NULL ?
111  GetDTNode()->GoesRight(e) : false); }
112 
113  // test event if it decends the tree at this node to the left
114  inline virtual Bool_t GoesLeft ( const Event& e ) const { return (GetDTNode() != NULL ?
115  GetDTNode()->GoesLeft(e) : false); }
116  // initialize a node from a data record
117  virtual void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
118  virtual void ReadContent(std::stringstream& s);
119  virtual Bool_t ReadDataRecord( std::istream& in, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
120 
121  private:
122 
123  Int_t fNLeafDaughters; //! number of terminal descendants
124  Double_t fNodeResubstitutionEstimate; //! R(t) = misclassification rate for node t
125  Double_t fResubstitutionEstimate; //! R(T_t) = sum[t' in ~T_t]{ R(t) }
126  Double_t fAlphaC; //! critical point, g(t) = alpha_c(t)
127  Double_t fMinAlphaC; //! G(t), minimum critical point of t and its descendants
128  DecisionTreeNode* fDTNode; //! pointer to wrapped node in the decision tree
129  };
130 
131  CCTreeWrapper( DecisionTree* T, SeparationBase* qualityIndex );
132  ~CCTreeWrapper( );
133 
134  // return the decision tree output for an event
135  Double_t CheckEvent( const TMVA::Event & e, Bool_t useYesNoLeaf = false );
136  // return the misclassification rate of a pruned tree for a validation event sample
137  Double_t TestTreeQuality( const EventList* validationSample );
138  Double_t TestTreeQuality( const DataSet* validationSample );
139 
140  // remove the branch rooted at node t
141  void PruneNode( CCTreeNode* t );
142  // initialize the node t and all its descendants
143  void InitTree( CCTreeNode* t );
144 
145  // return the root node for this tree
146  CCTreeNode* GetRoot() { return fRoot; }
147  private:
148  SeparationBase* fQualityIndex; //! pointer to the used quality index calculator
149  DecisionTree* fDTParent; //! pointer to underlying DecisionTree
150  CCTreeNode* fRoot; //! the root node of the (wrapped) decision Tree
151  };
152 
153 }
154 
155 #endif
156 
157 
158 
virtual ~CCTreeNode()
destructor of a CCTreeNode
CCTreeNode * fRoot
the root node of the (wrapped) decision Tree
#define TMVA_VERSION_CODE
Definition: Version.h:47
virtual void PrintRec(std::ostream &os) const
recursive printout of the node and its daughters
double T(double x)
Definition: ChebyshevPol.h:34
#define N
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
Double_t fAlphaC
critical point, g(t) = alpha_c(t).
virtual Node * CreateNode() const
Definition: CCTreeWrapper.h:54
virtual void AddContentToNode(std::stringstream &s) const
void SetResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:70
void InitTree(CCTreeNode *t)
initialize the node t and all its descendants
Double_t CheckEvent(const TMVA::Event &e, Bool_t useYesNoLeaf=false)
return the decision tree output for an event
CCTreeWrapper(DecisionTree *T, SeparationBase *qualityIndex)
constructor
virtual Node * GetRight() const
Definition: Node.h:88
virtual Node * GetLeft() const
Definition: Node.h:87
DecisionTreeNode * fDTNode
pointer to wrapped node in the decision tree.
Double_t GetResubstitutionEstimate() const
Definition: CCTreeWrapper.h:73
Double_t fNodeResubstitutionEstimate
R(t) = misclassification rate for node t
virtual void ReadContent(std::stringstream &s)
Class that contains all the data information.
Definition: DataSet.h:69
DecisionTree * fDTParent
pointer to underlying DecisionTree
virtual Node * GetParent() const
Definition: Node.h:89
std::vector< Event * > EventList
Definition: CCTreeWrapper.h:40
virtual Bool_t GoesLeft(const Event &e) const
Double_t TestTreeQuality(const EventList *validationSample)
return the misclassification rate of a pruned tree for a validation event sample using an EventList ...
SeparationBase * fQualityIndex
CCTreeNode * GetRoot()
void SetNodeResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:63
Implementation of a Decision Tree.
Definition: DecisionTree.h:59
unsigned int UInt_t
Definition: RtypesCore.h:42
Double_t fResubstitutionEstimate
R(T_t) = sum[t' in ~T_t]{ R(t') }.
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
void SetMinAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:86
virtual void Print(std::ostream &os) const
printout of the node (can be read in with ReadDataRecord)
virtual Bool_t ReadDataRecord(std::istream &in, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
initialize a node from a data record
virtual Bool_t GoesRight(const Event &e) const
double Double_t
Definition: RtypesCore.h:55
static constexpr double s
you should not use this method at all Int_t Int_t Double_t Double_t Double_t e
Definition: TRolke.cxx:630
DecisionTreeNode * GetDTNode() const
Definition: CCTreeWrapper.h:92
void SetAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:80
Abstract ClassifierFactory template that handles arbitrary types.
Node for the BinarySearch or Decision Trees.
Definition: Node.h:56
virtual void ReadAttributes(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
~CCTreeWrapper()
destructor
CCTreeNode(DecisionTreeNode *n=NULL)
constructor of the CCTreeNode
void PruneNode(CCTreeNode *t)
remove the branch rooted at node t
constexpr Double_t R()
Definition: TMath.h:213
const Int_t n
Definition: legend1.C:16
virtual void AddAttributesToNode(void *node) const
Double_t fMinAlphaC
G(t), minimum critical point of t and its descendants
Double_t GetNodeResubstitutionEstimate() const
Definition: CCTreeWrapper.h:66