Logo ROOT   6.08/07
Reference Guide
CCTreeWrapper.cxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Class : CCTreeWrapper *
5  * Web : http://tmva.sourceforge.net *
6  * *
7  * Description: a light wrapper of a decision tree, used to perform cost *
8  * complexity pruning "in-place" Cost Complexity Pruning *
9  * *
10  * Author: Doug Schouten (dschoute@sfu.ca) *
11  * *
12  * *
13  * Copyright (c) 2007: *
14  * CERN, Switzerland *
15  * MPI-K Heidelberg, Germany *
16  * U. of Texas at Austin, USA *
17  * *
18  * Redistribution and use in source and binary forms, with or without *
19  * modification, are permitted according to the terms listed in LICENSE *
20  * (http://tmva.sourceforge.net/LICENSE) *
21  **********************************************************************************/
22 
23 #include "TMVA/CCTreeWrapper.h"
24 #include "TMVA/DecisionTree.h"
25 
26 #include <iostream>
27 #include <limits>
28 
29 using namespace TMVA;
30 
31 ////////////////////////////////////////////////////////////////////////////////
32 ///constructor of the CCTreeNode
33 
35  Node(),
36  fNLeafDaughters(0),
37  fNodeResubstitutionEstimate(-1.0),
38  fResubstitutionEstimate(-1.0),
39  fAlphaC(-1.0),
40  fMinAlphaC(-1.0),
41  fDTNode(n)
42 {
43  if ( n != NULL && n->GetRight() != NULL && n->GetLeft() != NULL ) {
44  SetRight( new CCTreeNode( ((DecisionTreeNode*) n->GetRight()) ) );
45  GetRight()->SetParent(this);
46  SetLeft( new CCTreeNode( ((DecisionTreeNode*) n->GetLeft()) ) );
47  GetLeft()->SetParent(this);
48  }
49 }
50 
51 ////////////////////////////////////////////////////////////////////////////////
52 /// destructor of a CCTreeNode
53 
55  if(GetLeft() != NULL) delete GetLeftDaughter();
56  if(GetRight() != NULL) delete GetRightDaughter();
57 }
58 
59 ////////////////////////////////////////////////////////////////////////////////
60 /// initialize a node from a data record
61 
62 Bool_t TMVA::CCTreeWrapper::CCTreeNode::ReadDataRecord( std::istream& in, UInt_t /* tmva_Version_Code */ ) {
63  std::string header, title;
64  in >> header;
65  in >> title; in >> fNLeafDaughters;
66  in >> title; in >> fNodeResubstitutionEstimate;
67  in >> title; in >> fResubstitutionEstimate;
68  in >> title; in >> fAlphaC;
69  in >> title; in >> fMinAlphaC;
70  return true;
71 }
72 
73 ////////////////////////////////////////////////////////////////////////////////
74 /// printout of the node (can be read in with ReadDataRecord)
75 
76 void TMVA::CCTreeWrapper::CCTreeNode::Print( std::ostream& os ) const {
77  os << "----------------------" << std::endl
78  << "|~T_t| " << fNLeafDaughters << std::endl
79  << "R(t): " << fNodeResubstitutionEstimate << std::endl
80  << "R(T_t): " << fResubstitutionEstimate << std::endl
81  << "g(t): " << fAlphaC << std::endl
82  << "G(t): " << fMinAlphaC << std::endl;
83 }
84 
85 ////////////////////////////////////////////////////////////////////////////////
86 /// recursive printout of the node and its daughters
87 
88 void TMVA::CCTreeWrapper::CCTreeNode::PrintRec( std::ostream& os ) const {
89  this->Print(os);
90  if(this->GetLeft() != NULL && this->GetRight() != NULL) {
91  this->GetLeft()->PrintRec(os);
92  this->GetRight()->PrintRec(os);
93  }
94 }
95 
96 ////////////////////////////////////////////////////////////////////////////////
97 /// constructor
98 
100  fRoot(NULL)
101 {
102  fDTParent = T;
103  fRoot = new CCTreeNode( dynamic_cast<DecisionTreeNode*>(T->GetRoot()) );
104  fQualityIndex = qualityIndex;
105  InitTree(fRoot);
106 }
107 
108 ////////////////////////////////////////////////////////////////////////////////
109 /// destructor
110 
112  delete fRoot;
113 }
114 
115 ////////////////////////////////////////////////////////////////////////////////
116 /// initialize the node t and all its descendants
117 
119 {
120  Double_t s = t->GetDTNode()->GetNSigEvents();
121  Double_t b = t->GetDTNode()->GetNBkgEvents();
122  // Double_t s = t->GetDTNode()->GetNSigEvents_unweighted();
123  // Double_t b = t->GetDTNode()->GetNBkgEvents_unweighted();
124  // set R(t) = Gini(t) or MisclassificationError(t), etc.
126 
127  if(t->GetLeft() != NULL && t->GetRight() != NULL) { // n is an interior (non-leaf) node
128  // traverse the tree
131  // set |~T_t|
134  // set R(T) = sum[t' in ~T]{ R(t) }
137  // set g(t)
139  (t->GetNLeafDaughters() - 1));
140  // G(t) = min( g(t), G(l(t)), G(r(t)) )
141  t->SetMinAlphaC(std::min(t->GetAlphaC(), std::min(t->GetLeftDaughter()->GetMinAlphaC(),
142  t->GetRightDaughter()->GetMinAlphaC())));
143  }
144  else { // n is a terminal node
145  t->SetNLeafDaughters(1);
147  t->SetAlphaC(std::numeric_limits<double>::infinity( ));
148  t->SetMinAlphaC(std::numeric_limits<double>::infinity( ));
149  }
150 }
151 
152 ////////////////////////////////////////////////////////////////////////////////
153 /// remove the branch rooted at node t
154 
156 {
157  if( t->GetLeft() != NULL &&
158  t->GetRight() != NULL ) {
159  CCTreeNode* l = t->GetLeftDaughter();
160  CCTreeNode* r = t->GetRightDaughter();
161  t->SetNLeafDaughters( 1 );
163  t->SetAlphaC( std::numeric_limits<double>::infinity( ) );
164  t->SetMinAlphaC( std::numeric_limits<double>::infinity( ) );
165  delete l;
166  delete r;
167  t->SetLeft(NULL);
168  t->SetRight(NULL);
169  }else{
170  std::cout << " ERROR in CCTreeWrapper::PruneNode: you try to prune a leaf node.. that does not make sense " << std::endl;
171  }
172 }
173 
174 ////////////////////////////////////////////////////////////////////////////////
175 /// return the classification accuracy (weighted fraction of correctly classified events)
176 /// of a pruned tree for a validation event sample, using an EventList
177 
179 {
180  Double_t ncorrect=0, nfalse=0;
181  for (UInt_t ievt=0; ievt < validationSample->size(); ievt++) {
182  Bool_t isSignalType = (CheckEvent(*(*validationSample)[ievt]) > fDTParent->GetNodePurityLimit() ) ? 1 : 0;
183 
184  if (isSignalType == ((*validationSample)[ievt]->GetClass() == 0)) {
185  ncorrect += (*validationSample)[ievt]->GetWeight();
186  }
187  else{
188  nfalse += (*validationSample)[ievt]->GetWeight();
189  }
190  }
191  return ncorrect / (ncorrect + nfalse);
192 }
193 
194 ////////////////////////////////////////////////////////////////////////////////
195 /// return the classification accuracy (weighted fraction of correctly classified events)
196 /// of a pruned tree for a validation event sample, using the DataSet
197 
199 {
200  validationSample->SetCurrentType(Types::kValidation);
201  // test the tree quality.. in terms of Miscalssification
202  Double_t ncorrect=0, nfalse=0;
203  for (Long64_t ievt=0; ievt<validationSample->GetNEvents(); ievt++){
204  const Event *ev = validationSample->GetEvent(ievt);
205 
206  Bool_t isSignalType = (CheckEvent(*ev) > fDTParent->GetNodePurityLimit() ) ? 1 : 0;
207 
208  if (isSignalType == (ev->GetClass() == 0)) {
209  ncorrect += ev->GetWeight();
210  }
211  else{
212  nfalse += ev->GetWeight();
213  }
214  }
215  return ncorrect / (ncorrect + nfalse);
216 }
217 
218 ////////////////////////////////////////////////////////////////////////////////
219 /// return the decision tree output for an event
220 
222 {
223  const DecisionTreeNode* current = fRoot->GetDTNode();
224  CCTreeNode* t = fRoot;
225 
226  while(//current->GetNodeType() == 0 &&
227  t->GetLeft() != NULL &&
228  t->GetRight() != NULL){ // at an interior (non-leaf) node
229  if (current->GoesRight(e)) {
230  //current = (DecisionTreeNode*)current->GetRight();
231  t = t->GetRightDaughter();
232  current = t->GetDTNode();
233  }
234  else {
235  //current = (DecisionTreeNode*)current->GetLeft();
236  t = t->GetLeftDaughter();
237  current = t->GetDTNode();
238  }
239  }
240 
241  if (useYesNoLeaf) return (current->GetPurity() > fDTParent->GetNodePurityLimit() ? 1.0 : -1.0);
242  else return current->GetPurity();
243 }
244 
245 ////////////////////////////////////////////////////////////////////////////////
246 
248 {}
249 
250 ////////////////////////////////////////////////////////////////////////////////
251 
252 void TMVA::CCTreeWrapper::CCTreeNode::AddContentToNode( std::stringstream& /*s*/ ) const
253 {}
254 
255 ////////////////////////////////////////////////////////////////////////////////
256 
257 void TMVA::CCTreeWrapper::CCTreeNode::ReadAttributes( void* /*node*/, UInt_t /* tmva_Version_Code */ )
258 {}
259 
260 ////////////////////////////////////////////////////////////////////////////////
261 
262 void TMVA::CCTreeWrapper::CCTreeNode::ReadContent( std::stringstream& /*s*/ )
263 {}
virtual void PrintRec(std::ostream &os) const =0
virtual ~CCTreeNode()
destructor of a CCTreeNode
CCTreeNode * fRoot
root node of the wrapped tree (a CCTreeNode built from the DecisionTree root)
long long Long64_t
Definition: RtypesCore.h:69
virtual void PrintRec(std::ostream &os) const
recursive printout of the node and its daughters
double T(double x)
Definition: ChebyshevPol.h:34
Double_t GetNodePurityLimit() const
Definition: DecisionTree.h:170
bool Bool_t
Definition: RtypesCore.h:59
virtual void SetRight(Node *r)
Definition: Node.h:97
Float_t GetNSigEvents(void) const
Double_t fAlphaC
critical point, g(t) = alpha_c(t).
virtual DecisionTreeNode * GetRoot() const
Definition: DecisionTree.h:102
virtual void AddContentToNode(std::stringstream &s) const
void SetResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:80
void InitTree(CCTreeNode *t)
initialize the node t and all its descendants
Double_t CheckEvent(const TMVA::Event &e, Bool_t useYesNoLeaf=false)
return the decision tree output for an event
TClass * GetClass(T *)
Definition: TClass.h:555
virtual void SetLeft(Node *l)
Definition: Node.h:96
Float_t GetNBkgEvents(void) const
CCTreeWrapper(DecisionTree *T, SeparationBase *qualityIndex)
constructor
virtual Node * GetRight() const
Definition: Node.h:92
virtual Node * GetLeft() const
Definition: Node.h:91
UInt_t GetClass() const
Definition: Event.h:89
Double_t GetResubstitutionEstimate() const
Definition: CCTreeWrapper.h:83
Double_t fNodeResubstitutionEstimate
R(t): resubstitution estimate of node t (e.g. Gini index or misclassification error).
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is set or not...
Definition: Event.cxx:378
virtual void ReadContent(std::stringstream &s)
DecisionTree * fDTParent
pointer to the underlying DecisionTree
std::vector< Event * > EventList
Definition: CCTreeWrapper.h:50
Double_t TestTreeQuality(const EventList *validationSample)
return the classification accuracy (weighted fraction of correctly classified events) of a pruned tree for a validation event sample using an EventList ...
TRandom2 r(17)
SeparationBase * fQualityIndex
void SetNodeResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:73
unsigned int UInt_t
Definition: RtypesCore.h:42
Double_t fResubstitutionEstimate
R(T_t) = sum[t' in ~T_t]{ R(t') }, the resubstitution estimate of the subtree rooted at t.
TLine * l
Definition: textangle.C:4
void SetMinAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:96
virtual void Print(std::ostream &os) const
printout of the node (can be read in with ReadDataRecord)
virtual Bool_t ReadDataRecord(std::istream &in, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
initialize a node from a data record
virtual void SetParent(Node *p)
Definition: Node.h:98
Float_t GetPurity(void) const
virtual Bool_t GoesRight(const Event &) const
test whether the event descends the tree at this node to the right
double Double_t
Definition: RtypesCore.h:55
void SetCurrentType(Types::ETreeType type) const
Definition: DataSet.h:114
you should not use this method at all Int_t Int_t Double_t Double_t Double_t e
Definition: TRolke.cxx:630
DecisionTreeNode * GetDTNode() const
void SetAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:90
Abstract ClassifierFactory template that handles arbitrary types.
virtual void ReadAttributes(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
virtual Double_t GetSeparationIndex(const Double_t &s, const Double_t &b)=0
you should not use this method at all Int_t Int_t Double_t Double_t Double_t Int_t Double_t Double_t Double_t Double_t b
Definition: TRolke.cxx:630
virtual DecisionTreeNode * GetLeft() const
#define NULL
Definition: Rtypes.h:82
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:229
virtual DecisionTreeNode * GetRight() const
~CCTreeWrapper()
destructor
CCTreeNode(DecisionTreeNode *n=NULL)
constructor of the CCTreeNode
void PruneNode(CCTreeNode *t)
remove the branch rooted at node t
const Int_t n
Definition: legend1.C:16
const Event * GetEvent() const
Definition: DataSet.cxx:211
virtual void AddAttributesToNode(void *node) const
Double_t fMinAlphaC
G(t) = min( g(t), G(l(t)), G(r(t)) ), the smallest critical point in the subtree rooted at t
Double_t GetNodeResubstitutionEstimate() const
Definition: CCTreeWrapper.h:76