ROOT  6.06/09
Reference Guide
CCTreeWrapper.cxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Class : CCTreeWrapper *
5  * Web : http://tmva.sourceforge.net *
6  * *
7  * Description: a light wrapper of a decision tree, used to perform *
8  * "in-place" cost complexity pruning *
9  * *
10  * Author: Doug Schouten (dschoute@sfu.ca) *
11  * *
12  * *
13  * Copyright (c) 2007: *
14  * CERN, Switzerland *
15  * MPI-K Heidelberg, Germany *
16  * U. of Texas at Austin, USA *
17  * *
18  * Redistribution and use in source and binary forms, with or without *
19  * modification, are permitted according to the terms listed in LICENSE *
20  * (http://tmva.sourceforge.net/LICENSE) *
21  **********************************************************************************/
22 
23 #include "TMVA/CCTreeWrapper.h"
24 
25 #include <iostream>
26 #include <limits>
27 
28 using namespace TMVA;
29 
30 ////////////////////////////////////////////////////////////////////////////////
31 ///constructor of the CCTreeNode
32 
34  Node(),
35  fNLeafDaughters(0),
36  fNodeResubstitutionEstimate(-1.0),
37  fResubstitutionEstimate(-1.0),
38  fAlphaC(-1.0),
39  fMinAlphaC(-1.0),
40  fDTNode(n)
41 {
42  if ( n != NULL && n->GetRight() != NULL && n->GetLeft() != NULL ) {
43  SetRight( new CCTreeNode( ((DecisionTreeNode*) n->GetRight()) ) );
44  GetRight()->SetParent(this);
45  SetLeft( new CCTreeNode( ((DecisionTreeNode*) n->GetLeft()) ) );
46  GetLeft()->SetParent(this);
47  }
48 }
49 
50 ////////////////////////////////////////////////////////////////////////////////
51 /// destructor of a CCTreeNode
52 
54  if(GetLeft() != NULL) delete GetLeftDaughter();
55  if(GetRight() != NULL) delete GetRightDaughter();
56 }
57 
58 ////////////////////////////////////////////////////////////////////////////////
59 /// initialize a node from a data record
60 
61 Bool_t TMVA::CCTreeWrapper::CCTreeNode::ReadDataRecord( std::istream& in, UInt_t /* tmva_Version_Code */ ) {
62  std::string header, title;
63  in >> header;
64  in >> title; in >> fNLeafDaughters;
65  in >> title; in >> fNodeResubstitutionEstimate;
66  in >> title; in >> fResubstitutionEstimate;
67  in >> title; in >> fAlphaC;
68  in >> title; in >> fMinAlphaC;
69  return true;
70 }
71 
72 ////////////////////////////////////////////////////////////////////////////////
73 /// printout of the node (can be read in with ReadDataRecord)
74 
75 void TMVA::CCTreeWrapper::CCTreeNode::Print( std::ostream& os ) const {
76  os << "----------------------" << std::endl
77  << "|~T_t| " << fNLeafDaughters << std::endl
78  << "R(t): " << fNodeResubstitutionEstimate << std::endl
79  << "R(T_t): " << fResubstitutionEstimate << std::endl
80  << "g(t): " << fAlphaC << std::endl
81  << "G(t): " << fMinAlphaC << std::endl;
82 }
83 
84 ////////////////////////////////////////////////////////////////////////////////
85 /// recursive printout of the node and its daughters
86 
87 void TMVA::CCTreeWrapper::CCTreeNode::PrintRec( std::ostream& os ) const {
88  this->Print(os);
89  if(this->GetLeft() != NULL && this->GetRight() != NULL) {
90  this->GetLeft()->PrintRec(os);
91  this->GetRight()->PrintRec(os);
92  }
93 }
94 
95 ////////////////////////////////////////////////////////////////////////////////
96 /// constructor
97 
99  fRoot(NULL)
100 {
101  fDTParent = T;
102  fRoot = new CCTreeNode( dynamic_cast<DecisionTreeNode*>(T->GetRoot()) );
103  fQualityIndex = qualityIndex;
104  InitTree(fRoot);
105 }
106 
107 ////////////////////////////////////////////////////////////////////////////////
108 /// destructor
109 
111  delete fRoot;
112 }
113 
114 ////////////////////////////////////////////////////////////////////////////////
115 /// initialize the node t and all its descendants
116 
118 {
119  Double_t s = t->GetDTNode()->GetNSigEvents();
120  Double_t b = t->GetDTNode()->GetNBkgEvents();
121  // Double_t s = t->GetDTNode()->GetNSigEvents_unweighted();
122  // Double_t b = t->GetDTNode()->GetNBkgEvents_unweighted();
123  // set R(t) = Gini(t) or MisclassificationError(t), etc.
124  t->SetNodeResubstitutionEstimate((s+b)*fQualityIndex->GetSeparationIndex(s,b));
125 
126  if(t->GetLeft() != NULL && t->GetRight() != NULL) { // n is an interior (non-leaf) node
127  // traverse the tree
128  InitTree(t->GetLeftDaughter());
129  InitTree(t->GetRightDaughter());
130  // set |~T_t|
133  // set R(T) = sum[t' in ~T]{ R(t) }
136  // set g(t)
138  (t->GetNLeafDaughters() - 1));
139  // G(t) = min( g(t), G(l(t)), G(r(t)) )
141  t->GetRightDaughter()->GetMinAlphaC())));
142  }
143  else { // n is a terminal node
144  t->SetNLeafDaughters(1);
145  t->SetResubstitutionEstimate((s+b)*fQualityIndex->GetSeparationIndex(s,b));
148  }
149 }
150 
151 ////////////////////////////////////////////////////////////////////////////////
152 /// remove the branch rooted at node t
153 
155 {
156  if( t->GetLeft() != NULL &&
157  t->GetRight() != NULL ) {
158  CCTreeNode* l = t->GetLeftDaughter();
159  CCTreeNode* r = t->GetRightDaughter();
160  t->SetNLeafDaughters( 1 );
164  delete l;
165  delete r;
166  t->SetLeft(NULL);
167  t->SetRight(NULL);
168  }else{
169  std::cout << " ERROR in CCTreeWrapper::PruneNode: you try to prune a leaf node.. that does not make sense " << std::endl;
170  }
171 }
172 
173 ////////////////////////////////////////////////////////////////////////////////
174 /// return the misclassification rate of a pruned tree for a validation event sample
175 /// using an EventList
176 
178 {
179  Double_t ncorrect=0, nfalse=0;
180  for (UInt_t ievt=0; ievt < validationSample->size(); ievt++) {
181  Bool_t isSignalType = (CheckEvent(*(*validationSample)[ievt]) > fDTParent->GetNodePurityLimit() ) ? 1 : 0;
182 
183  if (isSignalType == ((*validationSample)[ievt]->GetClass() == 0)) {
184  ncorrect += (*validationSample)[ievt]->GetWeight();
185  }
186  else{
187  nfalse += (*validationSample)[ievt]->GetWeight();
188  }
189  }
190  return ncorrect / (ncorrect + nfalse);
191 }
192 
193 ////////////////////////////////////////////////////////////////////////////////
194 /// return the misclassification rate of a pruned tree for a validation event sample
195 /// using the DataSet
196 
198 {
199  validationSample->SetCurrentType(Types::kValidation);
200  // test the tree quality.. in terms of Miscalssification
201  Double_t ncorrect=0, nfalse=0;
202  for (Long64_t ievt=0; ievt<validationSample->GetNEvents(); ievt++){
203  const Event *ev = validationSample->GetEvent(ievt);
204 
205  Bool_t isSignalType = (CheckEvent(*ev) > fDTParent->GetNodePurityLimit() ) ? 1 : 0;
206 
207  if (isSignalType == (ev->GetClass() == 0)) {
208  ncorrect += ev->GetWeight();
209  }
210  else{
211  nfalse += ev->GetWeight();
212  }
213  }
214  return ncorrect / (ncorrect + nfalse);
215 }
216 
217 ////////////////////////////////////////////////////////////////////////////////
218 /// return the decision tree output for an event
219 
221 {
222  const DecisionTreeNode* current = fRoot->GetDTNode();
223  CCTreeNode* t = fRoot;
224 
225  while(//current->GetNodeType() == 0 &&
226  t->GetLeft() != NULL &&
227  t->GetRight() != NULL){ // at an interior (non-leaf) node
228  if (current->GoesRight(e)) {
229  //current = (DecisionTreeNode*)current->GetRight();
230  t = t->GetRightDaughter();
231  current = t->GetDTNode();
232  }
233  else {
234  //current = (DecisionTreeNode*)current->GetLeft();
235  t = t->GetLeftDaughter();
236  current = t->GetDTNode();
237  }
238  }
239 
240  if (useYesNoLeaf) return (current->GetPurity() > fDTParent->GetNodePurityLimit() ? 1.0 : -1.0);
241  else return current->GetPurity();
242 }
243 
244 ////////////////////////////////////////////////////////////////////////////////
245 
247 {}
248 
249 ////////////////////////////////////////////////////////////////////////////////
250 
251 void TMVA::CCTreeWrapper::CCTreeNode::AddContentToNode( std::stringstream& /*s*/ ) const
252 {}
253 
254 ////////////////////////////////////////////////////////////////////////////////
255 
256 void TMVA::CCTreeWrapper::CCTreeNode::ReadAttributes( void* /*node*/, UInt_t /* tmva_Version_Code */ )
257 {}
258 
259 ////////////////////////////////////////////////////////////////////////////////
260 
261 void TMVA::CCTreeWrapper::CCTreeNode::ReadContent( std::stringstream& /*s*/ )
262 {}
virtual ~CCTreeNode()
destructor of a CCTreeNode
CCTreeNode * fRoot
pointer to underlying DecisionTree
static Vc_ALWAYS_INLINE int_v min(const int_v &x, const int_v &y)
Definition: vector.h:433
long long Long64_t
Definition: RtypesCore.h:69
double T(double x)
Definition: ChebyshevPol.h:34
virtual DecisionTreeNode * GetRight() const
bool Bool_t
Definition: RtypesCore.h:59
virtual void SetRight(Node *r)
Definition: Node.h:97
Double_t GetResubstitutionEstimate() const
Definition: CCTreeWrapper.h:83
virtual DecisionTreeNode * GetLeft() const
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not...
Definition: Event.cxx:376
virtual DecisionTreeNode * GetRoot() const
Definition: DecisionTree.h:102
void SetResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:80
void InitTree(CCTreeNode *t)
initialize the node t and all its descendants
Double_t CheckEvent(const TMVA::Event &e, Bool_t useYesNoLeaf=false)
return the decision tree output for an event
TClass * GetClass(T *)
Definition: TClass.h:555
virtual void SetLeft(Node *l)
Definition: Node.h:96
CCTreeWrapper(DecisionTree *T, SeparationBase *qualityIndex)
constructor
Double_t GetNodeResubstitutionEstimate() const
Definition: CCTreeWrapper.h:76
Float_t GetPurity(void) const
DecisionTreeNode * GetDTNode() const
virtual void ReadContent(std::stringstream &s)
Float_t GetNBkgEvents(void) const
DecisionTree * fDTParent
pointer to the used quality index calculator
std::vector< Event * > EventList
Definition: CCTreeWrapper.h:50
Double_t TestTreeQuality(const EventList *validationSample)
return the misclassification rate of a pruned tree for a validation event sample using an EventList ...
ROOT::R::TRInterface & r
Definition: Object.C:4
SeparationBase * fQualityIndex
virtual void PrintRec(std::ostream &os) const
recursive printout of the node and its daughters
virtual void AddAttributesToNode(void *node) const
void SetNodeResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:73
unsigned int UInt_t
Definition: RtypesCore.h:42
TLine * l
Definition: textangle.C:4
virtual void AddContentToNode(std::stringstream &s) const
const Double_t infinity
Definition: CsgOps.cxx:85
void SetMinAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:96
virtual Bool_t ReadDataRecord(std::istream &in, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
initialize a node from a data record
virtual void SetParent(Node *p)
Definition: Node.h:98
virtual void Print(std::ostream &os) const
printout of the node (can be read in with ReadDataRecord)
const Event * GetEvent() const
Definition: DataSet.cxx:180
void SetCurrentType(Types::ETreeType type) const
Definition: DataSet.h:111
void Print(std::ostream &os, const OptionType &opt)
Double_t GetMinAlphaC() const
Definition: CCTreeWrapper.h:99
double Double_t
Definition: RtypesCore.h:55
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:225
virtual Node * GetRight() const
Definition: Node.h:92
UInt_t GetClass() const
Definition: Event.h:86
void SetAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:90
Abstract ClassifierFactory template that handles arbitrary types.
virtual void ReadAttributes(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
#define NULL
Definition: Rtypes.h:82
Float_t GetNSigEvents(void) const
~CCTreeWrapper()
destructor
CCTreeNode(DecisionTreeNode *n=NULL)
constructor of the CCTreeNode
void PruneNode(CCTreeNode *t)
remove the branch rooted at node t
virtual Bool_t GoesRight(const Event &) const
test whether the event descends the tree to the right at this node
const Int_t n
Definition: legend1.C:16
virtual Node * GetLeft() const
Definition: Node.h:91