Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CCTreeWrapper.cxx
Go to the documentation of this file.
1/**********************************************************************************
2 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3 * Package: TMVA *
4 * Class : CCTreeWrapper *
5 * Web : http://tmva.sourceforge.net *
6 * *
7 * Description: a light wrapper of a decision tree, used to perform cost *
8 * complexity pruning "in-place" Cost Complexity Pruning *
9 * *
10 * Author: Doug Schouten (dschoute@sfu.ca) *
11 * *
12 * *
13 * Copyright (c) 2007: *
14 * CERN, Switzerland *
15 * MPI-K Heidelberg, Germany *
16 * U. of Texas at Austin, USA *
17 * *
18 * Redistribution and use in source and binary forms, with or without *
19 * modification, are permitted according to the terms listed in LICENSE *
20 * (http://tmva.sourceforge.net/LICENSE) *
21 **********************************************************************************/
22
23/*! \class TMVA::CCTreeWrapper
24\ingroup TMVA
25
26*/
27
28#include "TMVA/CCTreeWrapper.h"
29#include "TMVA/DecisionTree.h"
30
31#include <iostream>
32#include <limits>
33
34using namespace TMVA;
35
36////////////////////////////////////////////////////////////////////////////////
37///constructor of the CCTreeNode
38
// Builds a wrapper node mirroring the DecisionTreeNode `n`. All pruning
// quantities start at sentinel values: leaf count 0, and -1.0 for R(t), R(T_t),
// g(t) and G(t). They are only meaningful after InitTree().
// If `n` is non-NULL and has both daughters, wrapper nodes for them are
// allocated with `new` (right first, then left), recursively, and re-parented
// to this node. Otherwise this wrapper is a leaf.
// The wrapped node is stored in fDTNode and is not owned by the wrapper.
//
// NOTE(review): original line 39 (the out-of-line definition header, declared
// as `CCTreeNode(DecisionTreeNode *n=NULL)` in the index below) is missing
// from this listing -- restore from upstream before compiling.
40 Node(),
41 fNLeafDaughters(0),
42 fNodeResubstitutionEstimate(-1.0),
43 fResubstitutionEstimate(-1.0),
44 fAlphaC(-1.0),
45 fMinAlphaC(-1.0),
46 fDTNode(n)
47{
48 if ( n != NULL && n->GetRight() != NULL && n->GetLeft() != NULL ) {
49 SetRight( new CCTreeNode( ((DecisionTreeNode*) n->GetRight()) ) );
50 GetRight()->SetParent(this);
51 SetLeft( new CCTreeNode( ((DecisionTreeNode*) n->GetLeft()) ) );
52 GetLeft()->SetParent(this);
53 }
54}
55
56////////////////////////////////////////////////////////////////////////////////
57/// destructor of a CCTreeNode
58
// Deletes the left and right daughter wrappers when present. Each daughter's
// destructor does the same, so the whole wrapper subtree is freed.
// The wrapped DecisionTreeNode (fDTNode) is NOT deleted here.
//
// NOTE(review): original line 59 (the definition header, declared as
// `virtual ~CCTreeNode()` in the index below) is missing from this listing.
60 if(GetLeft() != NULL) delete GetLeftDaughter();
61 if(GetRight() != NULL) delete GetRightDaughter();
62}
63
64////////////////////////////////////////////////////////////////////////////////
65/// initialize a node from a data record
66
67Bool_t TMVA::CCTreeWrapper::CCTreeNode::ReadDataRecord( std::istream& in, UInt_t /* tmva_Version_Code */ ) {
68 std::string header, title;
69 in >> header;
70 in >> title; in >> fNLeafDaughters;
71 in >> title; in >> fNodeResubstitutionEstimate;
72 in >> title; in >> fResubstitutionEstimate;
73 in >> title; in >> fAlphaC;
74 in >> title; in >> fMinAlphaC;
75 return true;
76}
77
78////////////////////////////////////////////////////////////////////////////////
79/// printout of the node (can be read in with ReadDataRecord)
80
81void TMVA::CCTreeWrapper::CCTreeNode::Print( std::ostream& os ) const {
82 os << "----------------------" << std::endl
83 << "|~T_t| " << fNLeafDaughters << std::endl
84 << "R(t): " << fNodeResubstitutionEstimate << std::endl
85 << "R(T_t): " << fResubstitutionEstimate << std::endl
86 << "g(t): " << fAlphaC << std::endl
87 << "G(t): " << fMinAlphaC << std::endl;
88}
89
90////////////////////////////////////////////////////////////////////////////////
91/// recursive printout of the node and its daughters
92
93void TMVA::CCTreeWrapper::CCTreeNode::PrintRec( std::ostream& os ) const {
94 this->Print(os);
95 if(this->GetLeft() != NULL && this->GetRight() != NULL) {
96 this->GetLeft()->PrintRec(os);
97 this->GetRight()->PrintRec(os);
98 }
99}
100
101////////////////////////////////////////////////////////////////////////////////
102/// constructor
// Wraps the decision tree T for cost-complexity pruning:
//   - stores T in fDTParent; T is not copied and is not deleted by this class;
//   - builds a CCTreeNode mirror of T's node structure starting at T's root;
//   - stores the separation-index calculator that InitTree() uses for R(t).
// If T's root is not a DecisionTreeNode, the dynamic_cast yields NULL and
// fRoot wraps a NULL node. CheckEvent() later dereferences
// fRoot->GetDTNode() without a check.
//
// NOTE(review): original lines 104 (the definition header, declared as
// `CCTreeWrapper(DecisionTree *T, SeparationBase *qualityIndex)` in the index
// below) and 110 are missing from this listing. The content of line 110 is
// unknown. The node quantities stay at their -1 sentinels until InitTree()
// runs, so check upstream whether InitTree(fRoot) is called there.
103
105 fRoot(NULL)
106{
107 fDTParent = T;
108 fRoot = new CCTreeNode( dynamic_cast<DecisionTreeNode*>(T->GetRoot()) );
109 fQualityIndex = qualityIndex;
111}
112
113////////////////////////////////////////////////////////////////////////////////
114/// destructor
// Frees the wrapper tree. Deleting fRoot recursively deletes its daughter
// wrappers. The wrapped DecisionTree (fDTParent), its nodes, and the quality
// index are not deleted here.
//
// NOTE(review): original line 116 (the definition header) is missing from
// this listing.
115
117 delete fRoot;
118}
119
120////////////////////////////////////////////////////////////////////////////////
121/// initialize the node t and all its descendants
// Post-order pass that fills in the cost-complexity quantities of t and its
// whole subtree:
//   R(t)      = (s+b) * separation index of (s,b). s comes from
//               GetNSigEvents(), presumably the weighted count given the
//               commented-out *_unweighted alternative.
//   leaf:     |~T_t| = 1, R(T_t) = R(t), g(t) = G(t) = +infinity
//             (so a leaf is never chosen as a pruning candidate)
//   interior: both daughters are initialized first, then |~T_t|, R(T_t), g(t)
//             and G(t) = min(g(t), G(left), G(right)) are set.
//
// NOTE(review): several original lines are missing from this listing:
//   - line 123, the definition header (declared as `void InitTree(CCTreeNode *t)`);
//   - line 126, which presumably declares `b` (used below) as the background
//     counterpart of `s` -- TODO confirm against upstream;
//   - lines 137-138, 140-141, 143 and 147 (see notes inline).
122
124{
125 Double_t s = t->GetDTNode()->GetNSigEvents();
127 // Double_t s = t->GetDTNode()->GetNSigEvents_unweighted();
128 // Double_t b = t->GetDTNode()->GetNBkgEvents_unweighted();
129 // set R(t) = Gini(t) or MisclassificationError(t), etc.
130 t->SetNodeResubstitutionEstimate((s+b)*fQualityIndex->GetSeparationIndex(s,b));
131
132 if(t->GetLeft() != NULL && t->GetRight() != NULL) { // n is an interior (non-leaf) node
133 // traverse the tree
134 InitTree(t->GetLeftDaughter());
135 InitTree(t->GetRightDaughter());
136 // set |~T_t|
// NOTE(review): lines 137-138 missing; presumably the sum of the daughters'
// leaf counts -- confirm.
139 // set R(T) = sum[t' in ~T]{ R(t) }
// NOTE(review): lines 140-141 missing; per the comment above, presumably the
// sum of the daughters' R(T_t) -- confirm.
142 // set g(t)
// NOTE(review): line 143 missing (start of the SetAlphaC call). The numerator
// is presumably R(t) - R(T_t), the usual cost-complexity link strength, giving
// g(t) = (R(t) - R(T_t)) / (|~T_t| - 1) -- confirm.
144 (t->GetNLeafDaughters() - 1));
145 // G(t) = min( g(t), G(l(t)), G(r(t)) )
146 t->SetMinAlphaC(std::min(t->GetAlphaC(), std::min(t->GetLeftDaughter()->GetMinAlphaC(),
// NOTE(review): line 147 missing; presumably the right daughter's G(t) plus
// the closing parentheses -- confirm.
148 }
149 else { // n is a terminal node
150 t->SetNLeafDaughters(1);
151 t->SetResubstitutionEstimate((s+b)*fQualityIndex->GetSeparationIndex(s,b));
152 t->SetAlphaC(std::numeric_limits<double>::infinity( ));
153 t->SetMinAlphaC(std::numeric_limits<double>::infinity( ));
154 }
155}
156
157////////////////////////////////////////////////////////////////////////////////
158/// remove the branch rooted at node t
// Turns interior node t into a leaf: sets its leaf count to 1 and g(t), G(t)
// to +infinity, deletes both daughter wrappers (and so their subtrees), and
// clears t's left/right links. Only the wrapper tree is modified; the
// underlying DecisionTree is untouched.
// Calling this on a leaf prints an error to std::cout and changes nothing.
// NOTE(review): only t is updated. Ancestors' |~T_t|, R(T_t), g and G are
// NOT refreshed here; presumably the caller recomputes them -- verify.
//
// NOTE(review): original lines are missing from this listing:
//   - line 160, the definition header (declared as `void PruneNode(CCTreeNode *t)`);
//   - lines 164-165, which presumably declare `l` and `r` (deleted below) as
//     t's left and right daughters -- confirm;
//   - line 167, content unknown; presumably resets R(T_t) to R(t), matching
//     how InitTree() initializes a leaf -- confirm.
159
161{
162 if( t->GetLeft() != NULL &&
163 t->GetRight() != NULL ) {
166 t->SetNLeafDaughters( 1 );
168 t->SetAlphaC( std::numeric_limits<double>::infinity( ) );
169 t->SetMinAlphaC( std::numeric_limits<double>::infinity( ) );
170 delete l;
171 delete r;
172 t->SetLeft(NULL);
173 t->SetRight(NULL);
174 }else{
175 std::cout << " ERROR in CCTreeWrapper::PruneNode: you try to prune a leaf node.. that does not make sense " << std::endl;
176 }
177}
178
179////////////////////////////////////////////////////////////////////////////////
180/// return the misclassification rate of a pruned tree for a validation event sample
181/// using an EventList
// Classifies each event with the current (possibly pruned) wrapper tree.
// An event is called "signal" when CheckEvent() (the leaf purity) exceeds
// the parent tree's node purity limit. Class 0 is taken as signal. Each
// event's weight is added to either the correct or the wrong tally.
// NOTE(review): despite the header above, the value returned is
// ncorrect / (ncorrect + nfalse), i.e. the weighted fraction of CORRECTLY
// classified events (1 - misclassification rate). An empty sample, or a
// total weight of 0, gives 0/0 (NaN).
//
// NOTE(review): original line 183 (the definition header, declared as
// `Double_t TestTreeQuality(const EventList *validationSample)` in the
// index below) is missing from this listing.
182
184{
185 Double_t ncorrect=0, nfalse=0;
186 for (UInt_t ievt=0; ievt < validationSample->size(); ievt++) {
187 Bool_t isSignalType = (CheckEvent(*(*validationSample)[ievt]) > fDTParent->GetNodePurityLimit() ) ? 1 : 0;
188
189 if (isSignalType == ((*validationSample)[ievt]->GetClass() == 0)) {
190 ncorrect += (*validationSample)[ievt]->GetWeight();
191 }
192 else{
193 nfalse += (*validationSample)[ievt]->GetWeight();
194 }
195 }
196 return ncorrect / (ncorrect + nfalse);
197}
198
199////////////////////////////////////////////////////////////////////////////////
200/// return the misclassification rate of a pruned tree for a validation event sample
201/// using the DataSet
// Same classification rule as the EventList overload, but iterates the
// DataSet's validation events. Side effect: the DataSet's current type is
// switched to Types::kValidation and is not restored afterwards.
// NOTE(review): despite the header above, the value returned is
// ncorrect / (ncorrect + nfalse), i.e. the weighted fraction of CORRECTLY
// classified events. It is 0/0 (NaN) when there are no validation events or
// their total weight is 0.
//
// NOTE(review): original line 203 (the definition header; presumably the
// `const DataSet*` overload of TestTreeQuality -- confirm) is missing from
// this listing.
202
204{
205 validationSample->SetCurrentType(Types::kValidation);
206 // test the tree quality.. in terms of Misclassification
207 Double_t ncorrect=0, nfalse=0;
208 for (Long64_t ievt=0; ievt<validationSample->GetNEvents(); ievt++){
209 const Event *ev = validationSample->GetEvent(ievt);
210
211 Bool_t isSignalType = (CheckEvent(*ev) > fDTParent->GetNodePurityLimit() ) ? 1 : 0;
212
213 if (isSignalType == (ev->GetClass() == 0)) {
214 ncorrect += ev->GetWeight();
215 }
216 else{
217 nfalse += ev->GetWeight();
218 }
219 }
220 return ncorrect / (ncorrect + nfalse);
221}
222
223////////////////////////////////////////////////////////////////////////////////
224/// return the decision tree output for an event
// Descends the WRAPPER tree from fRoot. The topology comes from the
// CCTreeNode links, so branches removed by PruneNode() are skipped, while
// each split decision uses the wrapped DecisionTreeNode's GoesRight().
// The descent stops at the first node lacking either daughter.
// Returns that node's purity, or, when useYesNoLeaf is set, +1.0 if the
// purity exceeds the parent tree's node purity limit and -1.0 otherwise.
// NOTE(review): fRoot->GetDTNode() is dereferenced without a NULL check.
//
// NOTE(review): original line 226 (the definition header, declared as
// `Double_t CheckEvent(const TMVA::Event &e, Bool_t useYesNoLeaf=false)` in
// the index below) is missing from this listing.
225
227{
228 const DecisionTreeNode* current = fRoot->GetDTNode();
229 CCTreeNode* t = fRoot;
230
231 while(//current->GetNodeType() == 0 &&
232 t->GetLeft() != NULL &&
233 t->GetRight() != NULL){ // at an interior (non-leaf) node
234 if (current->GoesRight(e)) {
235 //current = (DecisionTreeNode*)current->GetRight();
236 t = t->GetRightDaughter();
237 current = t->GetDTNode();
238 }
239 else {
240 //current = (DecisionTreeNode*)current->GetLeft();
241 t = t->GetLeftDaughter();
242 current = t->GetDTNode();
243 }
244 }
245
246 if (useYesNoLeaf) return (current->GetPurity() > fDTParent->GetNodePurityLimit() ? 1.0 : -1.0);
247 else return current->GetPurity();
248}
249
250////////////////////////////////////////////////////////////////////////////////
// Empty override: a CCTreeNode adds no attributes. The body below is empty.
//
// NOTE(review): original line 252 (the definition header, declared as
// `virtual void AddAttributesToNode(void *node) const` in the index below)
// is missing from this listing.
251
253{}
254
255////////////////////////////////////////////////////////////////////////////////
256
257void TMVA::CCTreeWrapper::CCTreeNode::AddContentToNode( std::stringstream& /*s*/ ) const
258{}
259
260////////////////////////////////////////////////////////////////////////////////
261
262void TMVA::CCTreeWrapper::CCTreeNode::ReadAttributes( void* /*node*/, UInt_t /* tmva_Version_Code */ )
263{}
264
265////////////////////////////////////////////////////////////////////////////////
266
267void TMVA::CCTreeWrapper::CCTreeNode::ReadContent( std::stringstream& /*s*/ )
268{}
ROOT::R::TRInterface & r
Definition Object.C:4
#define b(i)
Definition RSha256.hxx:100
#define e(i)
Definition RSha256.hxx:103
long long Long64_t
Definition RtypesCore.h:73
Double_t GetNodeResubstitutionEstimate() const
virtual void ReadContent(std::stringstream &s)
CCTreeNode(DecisionTreeNode *n=NULL)
constructor of the CCTreeNode
virtual void AddAttributesToNode(void *node) const
void SetMinAlphaC(Double_t alpha)
DecisionTreeNode * GetDTNode() const
void SetResubstitutionEstimate(Double_t R)
Double_t GetResubstitutionEstimate() const
virtual void PrintRec(std::ostream &os) const
recursive printout of the node and its daughters
virtual void AddContentToNode(std::stringstream &s) const
virtual Bool_t ReadDataRecord(std::istream &in, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
initialize a node from a data record
void SetAlphaC(Double_t alpha)
void SetNodeResubstitutionEstimate(Double_t R)
virtual ~CCTreeNode()
destructor of a CCTreeNode
virtual void ReadAttributes(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
virtual void Print(std::ostream &os) const
printout of the node (can be read in with ReadDataRecord)
SeparationBase * fQualityIndex
std::vector< Event * > EventList
DecisionTree * fDTParent
pointer to the used quality index calculator
CCTreeNode * fRoot
pointer to underlying DecisionTree
Double_t TestTreeQuality(const EventList *validationSample)
return the misclassification rate of a pruned tree for a validation event sample using an EventList
void InitTree(CCTreeNode *t)
initialize the node t and all its descendants
void PruneNode(CCTreeNode *t)
remove the branch rooted at node t
CCTreeWrapper(DecisionTree *T, SeparationBase *qualityIndex)
constructor
Double_t CheckEvent(const TMVA::Event &e, Bool_t useYesNoLeaf=false)
return the decision tree output for an event
Class that contains all the data information.
Definition DataSet.h:58
const Event * GetEvent() const
Definition DataSet.cxx:202
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:206
void SetCurrentType(Types::ETreeType type) const
Definition DataSet.h:89
Float_t GetNSigEvents(void) const
virtual Bool_t GoesRight(const Event &) const
test event if it descends the tree at this node to the right
Float_t GetPurity(void) const
Float_t GetNBkgEvents(void) const
Implementation of a Decision Tree.
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not.
Definition Event.cxx:381
UInt_t GetClass() const
Definition Event.h:86
Node for the BinarySearch or Decision Trees.
Definition Node.h:58
virtual Node * GetLeft() const
Definition Node.h:89
virtual void SetRight(Node *r)
Definition Node.h:95
virtual void SetLeft(Node *l)
Definition Node.h:94
virtual void SetParent(Node *p)
Definition Node.h:96
virtual Node * GetRight() const
Definition Node.h:90
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
@ kValidation
Definition Types.h:148
const Int_t n
Definition legend1.C:16
create variable transformations
auto * l
Definition textangle.C:4