ROOT  6.07/01
Reference Guide
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
CCTreeWrapper.h
Go to the documentation of this file.
1 
2 /**********************************************************************************
3  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
4  * Package: TMVA *
5  * Class : CCTreeWrapper *
6  * Web : http://tmva.sourceforge.net *
7  * *
8  * Description: a light wrapper of a decision tree, used to perform cost *
9  * complexity pruning "in-place" Cost Complexity Pruning *
10  * *
11  * Author: Doug Schouten (dschoute@sfu.ca) *
12  * *
13  * *
14  * Copyright (c) 2007: *
15  * CERN, Switzerland *
16  * MPI-K Heidelberg, Germany *
17  * U. of Texas at Austin, USA *
18  * *
19  * Redistribution and use in source and binary forms, with or without *
20  * modification, are permitted according to the terms listed in LICENSE *
21  * (http://tmva.sourceforge.net/LICENSE) *
22  **********************************************************************************/
23 
24 #ifndef ROOT_TMVA_CCTreeWrapper
25 #define ROOT_TMVA_CCTreeWrapper
26 
27 #ifndef ROOT_TMVA_Event
28 #include "TMVA/Event.h"
29 #endif
30 #ifndef ROOT_TMVA_SeparationBase
31 #include "TMVA/SeparationBase.h"
32 #endif
33 #ifndef ROOT_TMVA_DecisionTree
34 #include "TMVA/DecisionTree.h"
35 #endif
36 #ifndef ROOT_TMVA_DataSet
37 #include "TMVA/DataSet.h"
38 #endif
39 #ifndef ROOT_TMVA_Version
40 #include "TMVA/Version.h"
41 #endif
42 
43 
44 namespace TMVA {
45 
46  class CCTreeWrapper {
47 
48  public:
49 
50  typedef std::vector<Event*> EventList;
51 
52  /////////////////////////////////////////////////////////////
53  // CCTreeNode - a light wrapper of a decision tree node //
54  // //
55  /////////////////////////////////////////////////////////////
56 
57  class CCTreeNode : virtual public Node {
58 
59  public:
60 
62  virtual ~CCTreeNode( );
63 
64  virtual Node* CreateNode() const { return new CCTreeNode(); }
65 
66  // set |~T_t|, the number of terminal descendants of node t
67  inline void SetNLeafDaughters( Int_t N ) { fNLeafDaughters = (N > 0 ? N : 0); }
68 
69  // return |~T_t|
70  inline Int_t GetNLeafDaughters() const { return fNLeafDaughters; }
71 
72  // set R(t), the node resubstitution estimate (Gini, misclassification, etc.) for the node t
73  inline void SetNodeResubstitutionEstimate( Double_t R ) { fNodeResubstitutionEstimate = (R >= 0 ? R : 0.0); }
74 
75  // return R(t) for node t
77 
78  // set R(T_t) = sum[t' in ~T_t]{ R(t) }, the resubstitution estimate for the branch rooted at
79  // node t (it is an estimate because it is calculated from the training dataset, i.e., the original tree)
80  inline void SetResubstitutionEstimate( Double_t R ) { fResubstitutionEstimate = (R >= 0 ? R : 0.0); }
81 
82  // return R(T_t) for node t
84 
85  // set the critical point of alpha
86  // R(t) - R(T_t)
87  // alpha_c < ------------- := g(t)
88  // |~T_t| - 1
89  // which is the value of alpha such that the branch rooted at node t is pruned
90  inline void SetAlphaC( Double_t alpha ) { fAlphaC = alpha; }
91 
92  // get the critical alpha value for this node
93  inline Double_t GetAlphaC( ) const { return fAlphaC; }
94 
95  // set the minimum critical alpha value for descendants of node t ( G(t) = min(alpha_c, g(t_l), g(t_r)) )
96  inline void SetMinAlphaC( Double_t alpha ) { fMinAlphaC = alpha; }
97 
98  // get the minimum critical alpha value
99  inline Double_t GetMinAlphaC( ) const { return fMinAlphaC; }
100 
101  // get the pointer to the wrapped DT node
102  inline DecisionTreeNode* GetDTNode( ) const { return fDTNode; }
103 
104  // get pointers to children, mother in the CC tree
105  inline CCTreeNode* GetLeftDaughter( ) { return dynamic_cast<CCTreeNode*>(GetLeft()); }
106  inline CCTreeNode* GetRightDaughter( ) { return dynamic_cast<CCTreeNode*>(GetRight()); }
107  inline CCTreeNode* GetMother( ) { return dynamic_cast<CCTreeNode*>(GetParent()); }
108 
109  // printout of the node (can be read in with ReadDataRecord)
110  virtual void Print( std::ostream& os ) const;
111 
112  // recursive printout of the node and its daughters
113  virtual void PrintRec ( std::ostream& os ) const;
114 
115  virtual void AddAttributesToNode(void* node) const;
116  virtual void AddContentToNode(std::stringstream& s) const;
117 
118 
119  // test event if it decends the tree at this node to the right
120  inline virtual Bool_t GoesRight( const Event& e ) const { return (GetDTNode() != NULL ?
121  GetDTNode()->GoesRight(e) : false); }
122 
123  // test event if it decends the tree at this node to the left
124  inline virtual Bool_t GoesLeft ( const Event& e ) const { return (GetDTNode() != NULL ?
125  GetDTNode()->GoesLeft(e) : false); }
126 
127  private:
128 
129  // initialize a node from a data record
130  virtual void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
131  virtual Bool_t ReadDataRecord( std::istream& in, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
132  virtual void ReadContent(std::stringstream& s);
133 
134  Int_t fNLeafDaughters; //! number of terminal descendants
135  Double_t fNodeResubstitutionEstimate; //! R(t) = misclassification rate for node t
136  Double_t fResubstitutionEstimate; //! R(T_t) = sum[t' in ~T_t]{ R(t) }
137  Double_t fAlphaC; //! critical point, g(t) = alpha_c(t)
138  Double_t fMinAlphaC; //! G(t), minimum critical point of t and its descendants
139  DecisionTreeNode* fDTNode; //! pointer to wrapped node in the decision tree
140  };
141 
142  CCTreeWrapper( DecisionTree* T, SeparationBase* qualityIndex );
143  ~CCTreeWrapper( );
144 
145  // return the decision tree output for an event
146  Double_t CheckEvent( const TMVA::Event & e, Bool_t useYesNoLeaf = false );
147  // return the misclassification rate of a pruned tree for a validation event sample
148  Double_t TestTreeQuality( const EventList* validationSample );
149  Double_t TestTreeQuality( const DataSet* validationSample );
150 
151  // remove the branch rooted at node t
152  void PruneNode( CCTreeNode* t );
153  // initialize the node t and all its descendants
154  void InitTree( CCTreeNode* t );
155 
156  // return the root node for this tree
157  CCTreeNode* GetRoot() { return fRoot; }
158  private:
159  SeparationBase* fQualityIndex; //! pointer to the used quality index calculator
160  DecisionTree* fDTParent; //! pointer to underlying DecisionTree
161  CCTreeNode* fRoot; //! the root node of the (wrapped) decision Tree
162  };
163 
164 }
165 
166 #endif
167 
168 
169 
virtual ~CCTreeNode()
destructor of a CCTreeNode
virtual Bool_t GoesRight(const Event &e) const
CCTreeNode * fRoot
the root node of the (wrapped) decision tree.
#define TMVA_VERSION_CODE
Definition: Version.h:47
#define N
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
Double_t GetResubstitutionEstimate() const
Definition: CCTreeWrapper.h:83
Double_t fAlphaC
critical point, g(t) = alpha_c(t).
void SetResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:80
void InitTree(CCTreeNode *t)
initialize the node t and all its descendants
TTree * T
Double_t CheckEvent(const TMVA::Event &e, Bool_t useYesNoLeaf=false)
return the decision tree output for an event
CCTreeWrapper(DecisionTree *T, SeparationBase *qualityIndex)
constructor
Double_t GetNodeResubstitutionEstimate() const
Definition: CCTreeWrapper.h:76
DecisionTreeNode * GetDTNode() const
DecisionTreeNode * fDTNode
pointer to wrapped node in the decision tree.
Double_t fNodeResubstitutionEstimate
R(t) = misclassification rate for node t.
virtual void ReadContent(std::stringstream &s)
DecisionTree * fDTParent
pointer to underlying DecisionTree.
TThread * t[5]
Definition: threadsh1.C:13
std::vector< Event * > EventList
Definition: CCTreeWrapper.h:50
Double_t TestTreeQuality(const EventList *validationSample)
return the misclassification rate of a pruned tree for a validation event sample using an EventList ...
SeparationBase * fQualityIndex
pointer to the used quality index calculator.
CCTreeNode * GetRoot()
virtual Node * CreateNode() const
Definition: CCTreeWrapper.h:64
virtual void PrintRec(std::ostream &os) const
recursive printout of the node and its daughters
virtual void AddAttributesToNode(void *node) const
void SetNodeResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:73
unsigned int UInt_t
Definition: RtypesCore.h:42
Double_t fResubstitutionEstimate
R(T_t) = sum[t' in ~T_t]{ R(t) }.
virtual void AddContentToNode(std::stringstream &s) const
void SetMinAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:96
virtual Bool_t ReadDataRecord(std::istream &in, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
initialize a node from a data record
virtual void Print(std::ostream &os) const
printout of the node (can be read in with ReadDataRecord)
Double_t GetMinAlphaC() const
Definition: CCTreeWrapper.h:99
double Double_t
Definition: RtypesCore.h:55
virtual Node * GetParent() const
Definition: Node.h:93
virtual Node * GetRight() const
Definition: Node.h:92
void SetAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:90
virtual void ReadAttributes(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
virtual Bool_t GoesLeft(const Event &e) const
#define NULL
Definition: Rtypes.h:82
~CCTreeWrapper()
destructor
CCTreeNode(DecisionTreeNode *n=NULL)
constructor of the CCTreeNode
void PruneNode(CCTreeNode *t)
remove the branch rooted at node t
const Int_t n
Definition: legend1.C:16
TRandom3 R
a TMatrixD.
Definition: testIO.cxx:28
Double_t fMinAlphaC
G(t), minimum critical point of t and its descendants.
virtual Node * GetLeft() const
Definition: Node.h:91