Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CCPruner.cxx
Go to the documentation of this file.
1/**********************************************************************************
2 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3 * Package: TMVA *
4 * Class : CCPruner *
5 * *
6 * *
7 * Description: Cost Complexity Pruning *
8 *
9 * Author: Doug Schouten (dschoute@sfu.ca)
10 *
11 * *
12 * Copyright (c) 2007: *
13 * CERN, Switzerland *
14 * MPI-K Heidelberg, Germany *
15 * U. of Texas at Austin, USA *
16 * *
17 * Redistribution and use in source and binary forms, with or without *
18 * modification, are permitted according to the terms listed in LICENSE *
19 * (see tmva/doc/LICENSE) *
20 **********************************************************************************/
21
22#include "TMVA/CCPruner.h"
23#include "TMVA/SeparationBase.h"
24#include "TMVA/GiniIndex.h"
26#include "TMVA/CCTreeWrapper.h"
27#include "TMVA/DataSet.h"
28
29#include "Rtypes.h"
30
31#include <iostream>
32#include <fstream>
33#include <limits>
34#include <math.h>
35
36/*! \class TMVA::CCPruner
37\ingroup TMVA
38A helper class to prune a decision tree using the Cost Complexity method
39(see "Classification and Regression Trees" by Leo Breiman et al.)
40
41### Some definitions:
42
43 - \f$ T_{max} \f$ - the initial, usually highly overtrained tree, that is to be pruned back
44 - \f$ R(T) \f$ - quality index (Gini, misclassification rate, or other) of a tree \f$ T \f$
45 - \f$ \sim T \f$ - set of terminal nodes in \f$ T \f$
46 - \f$ T' \f$ - the pruned subtree of \f$ T_{max} \f$ that has the best quality index \f$ R(T') \f$
47 - \f$ \alpha \f$ - the prune strength parameter in Cost Complexity pruning: \f$ R_{\alpha}(T) = R(T) + \alpha \cdot |\sim T| \f$
48
49There are two running modes in CCPruner: (i) one may select a prune strength and prune back
50the tree \f$ T_{max}\f$ until the criterion:
51\f[
52 \alpha < \frac{R(t) - R(T_t)}{|\sim T_t| - 1}
53\f]
54
55is true for all nodes t in \f$ T \f$ (where \f$ T_t \f$ is the branch of \f$ T \f$ rooted at t), or (ii) the algorithm finds the sequence of critical points
56\f$ \alpha_1 < \alpha_2 < \ldots < \alpha_K \f$ such that \f$ T_K = \mathrm{root}(T_{max}) \f$, and then selects the optimally-pruned
57subtree, defined to be the subtree with the best quality index on the validation sample.
58*/
59
60namespace TMVA {
61 class DecisionTree;
62}
63
64using namespace TMVA;
65
66////////////////////////////////////////////////////////////////////////////////
67/// constructor: prune the tree t_max, selecting the optimal subtree with an EventList
///
/// \param t_max            tree to be pruned (usually highly overtrained); only the pointer is stored
/// \param validationSample events used by Optimize() to rate each pruned tree; may be NULL if
///                         fAlpha > 0 when Optimize() runs (fixed prune strength)
/// \param qualityIndex     separation criterion used to compute R(t); if NULL, fOwnQIndex is set and
///                         the destructor deletes fQualityIndex
68
69CCPruner::CCPruner( DecisionTree* t_max, const EventList* validationSample,
70 SeparationBase* qualityIndex ) :
71 fAlpha(-1.0), // <= 0: no fixed prune strength, Optimize() picks the subtree via the validation sample
72 fValidationSample(validationSample),
73 fValidationDataSet(NULL), // only the EventList input is used by this constructor
74 fOptimalK(-1) // -1: no subtree selected yet (GetOptimalPruneSequence() returns an empty list)
75{
76 fTree = t_max;
77
78 if(qualityIndex == NULL) {
79 fOwnQIndex = true;
// NOTE(review): source line 80 is missing from this listing; presumably it allocates the default
// quality index that the destructor later deletes -- confirm against CCPruner.cxx
81 }
82 else {
83 fOwnQIndex = false;
84 fQualityIndex = qualityIndex; // caller keeps ownership
85 }
86 fDebug = kTRUE; // Optimize() writes costcomplexity.log by default
87}
88
89////////////////////////////////////////////////////////////////////////////////
90/// constructor: prune the tree t_max, selecting the optimal subtree with a DataSet
///
/// \param t_max            tree to be pruned (usually highly overtrained); only the pointer is stored
/// \param validationSample data set used by Optimize() to rate each pruned tree; may be NULL if
///                         fAlpha > 0 when Optimize() runs (fixed prune strength)
/// \param qualityIndex     separation criterion used to compute R(t); if NULL, fOwnQIndex is set and
///                         the destructor deletes fQualityIndex
91
92CCPruner::CCPruner( DecisionTree* t_max, const DataSet* validationSample,
93 SeparationBase* qualityIndex ) :
94 fAlpha(-1.0), // <= 0: no fixed prune strength, Optimize() picks the subtree via the validation sample
95 fValidationSample(NULL), // only the DataSet input is used by this constructor
96 fValidationDataSet(validationSample),
97 fOptimalK(-1) // -1: no subtree selected yet (GetOptimalPruneSequence() returns an empty list)
98{
99 fTree = t_max;
100
101 if(qualityIndex == NULL) {
102 fOwnQIndex = true;
// NOTE(review): source line 103 is missing from this listing; presumably it allocates the default
// quality index that the destructor later deletes -- confirm against CCPruner.cxx
104 }
105 else {
106 fOwnQIndex = false;
107 fQualityIndex = qualityIndex; // caller keeps ownership
108 }
109 fDebug = kTRUE; // Optimize() writes costcomplexity.log by default
110}
111
112
113////////////////////////////////////////////////////////////////////////////////
114/// destructor: deletes fQualityIndex only if this pruner owns it, i.e. if no quality index was
/// passed to the constructor (fOwnQIndex == true); fTree and the validation samples are not deleted
116{
// NOTE(review): source line 115 (presumably the `CCPruner::~CCPruner( )` declaration) is missing
// from this listing -- confirm against CCPruner.cxx
117 if(fOwnQIndex) delete fQualityIndex;
118 // destructor
119}
120
121////////////////////////////////////////////////////////////////////////////////
122/// determine the pruning sequence
///
/// Repeatedly removes the weakest link of a CCTreeWrapper built for fTree. For every pruning step it
/// appends, in lockstep:
///  - the pruned node to fPruneSequence;
///  - the prune strength alpha of that step to fPruneStrengthList;
///  - a quality value to fQualityIndexList.
///
/// Modes:
///  - fAlpha > 0 : pruning stops as soon as alpha exceeds fAlpha. Each recorded quality is 1.0, and
///                 fOptimalK is set to the last index of the sequence.
///  - fAlpha <= 0: a validation sample (EventList or DataSet) is required; without one the function
///                 returns without pruning. After each step the pruned tree is rated with
///                 TestTreeQuality(), and fOptimalK becomes the index of the largest rating.
///
/// When fDebug is set, a log is written to "costcomplexity.log" in the current working directory.
///
/// NOTE(review): this listing omits several source lines. Confirm them against CCPruner.cxx:
///  - 124: the `void CCPruner::Optimize( )` signature;
///  - 129: presumably construction of `dTWrapper` from fTree and fQualityIndex;
///  - 151: presumably `t = R`;
///  - 166: presumably `n = t`;
///  - 182-185/187: presumably the recomputation of the leaf count, R(T) and alpha_c of each
///    ancestor, and the tail of the SetMinAlphaC call.
123
125{
126 Bool_t HaveStopCondition = fAlpha > 0; // keep pruning the tree until reach the limit fAlpha
127
128 // build a wrapper tree to perform work on
130
131 Int_t k = 0; // number of pruning steps; reused below as the index of the best step
132 Double_t epsilon = std::numeric_limits<double>::epsilon();
133 Double_t alpha = -1.0e10; // current prune strength; never decreases inside the loop below
134
135 std::ofstream outfile;
136 if (fDebug) outfile.open("costcomplexity.log");
137 if(!HaveStopCondition && (fValidationSample == NULL && fValidationDataSet == NULL) ) {
138 if (fDebug) outfile << "ERROR: no validation sample, so cannot optimize pruning!" << std::endl;
139 delete dTWrapper;
140 if (fDebug) outfile.close();
141 return;
142 }
143
144 CCTreeWrapper::CCTreeNode* R = dTWrapper->GetRoot();
145 while(R->GetNLeafDaughters() > 1) { // prune upwards to the root node
146 if(R->GetMinAlphaC() > alpha)
147 alpha = R->GetMinAlphaC(); // initialize alpha
148
149 if(HaveStopCondition && alpha > fAlpha) break;
150
152
153 while(t->GetMinAlphaC() < t->GetAlphaC()) { // descend to the weakest link
154
// follow the daughter whose subtree holds the same minimal alpha_c (relative comparison).
// NOTE(review): if t->GetMinAlphaC() is 0 this divides by zero. The quotient is then inf or NaN,
// never < epsilon, so the right daughter is always taken -- verify this cannot misroute the search.
155 if(fabs(t->GetMinAlphaC() - t->GetLeftDaughter()->GetMinAlphaC())/fabs(t->GetMinAlphaC()) < epsilon)
156 t = t->GetLeftDaughter();
157 else
158 t = t->GetRightDaughter();
159 }
160
161 if( t == R ) {
162 if (fDebug) outfile << std::endl << "Caught trying to prune the root node!" << std::endl;
163 break;
164 }
165
167
168 if (fDebug){
169 outfile << "===========================" << std::endl
170 << "Pruning branch listed below" << std::endl
171 << "===========================" << std::endl;
172 t->PrintRec( outfile );
173
174 }
175 if (!(t->GetLeftDaughter()) && !(t->GetRightDaughter()) ) {
// t is a terminal node: there is no branch left to prune
176 break;
177 }
178 dTWrapper->PruneNode(t); // prune the branch rooted at node t
179
180 while(t != R) { // go back up the (pruned) tree and recalculate R(T), alpha_c
181 t = t->GetMother();
186 t->SetMinAlphaC(std::min(t->GetAlphaC(), std::min(t->GetLeftDaughter()->GetMinAlphaC(),
188 }
189 k += 1;
190 if(!HaveStopCondition) {
// rate the tree as it stands after this pruning step
191 Double_t q;
192 if (fValidationDataSet != NULL) q = dTWrapper->TestTreeQuality(fValidationDataSet);
193 else q = dTWrapper->TestTreeQuality(fValidationSample);
194 fQualityIndexList.push_back(q);
195 }
196 else {
197 fQualityIndexList.push_back(1.0); // placeholder: no validation in fixed-strength mode
198 }
199 fPruneSequence.push_back(n->GetDTNode());
200 fPruneStrengthList.push_back(alpha);
201 }
202
// select the step to keep. Every candidate is a tree obtained after at least one pruning step;
// the unpruned tree itself is never rated.
203 Double_t qmax = -1.0e6;
204 if(!HaveStopCondition) {
// the first maximum wins (strict >), i.e. the least-pruned tree among equally rated ones
205 for(UInt_t i = 0; i < fQualityIndexList.size(); i++) {
206 if(fQualityIndexList[i] > qmax) {
207 qmax = fQualityIndexList[i];
208 k = i;
209 }
210 }
211 fOptimalK = k;
212 }
213 else {
// NOTE(review): if no step was recorded, size() - 1 wraps around (unsigned) before being
// stored in fOptimalK -- guard against an empty sequence.
214 fOptimalK = fPruneSequence.size() - 1;
215 }
216
217 if (fDebug){
218 outfile << std::endl << "************ Summary **************" << std::endl
219 << "Number of trees in the sequence: " << fPruneSequence.size() << std::endl;
220
// NOTE(review): with an empty sequence the unsigned `size()-1` bounds below underflow, and these
// loops read out of bounds -- print the lists only when they are non-empty.
221 outfile << "Pruning strength parameters: [";
222 for(UInt_t i = 0; i < fPruneStrengthList.size()-1; i++)
223 outfile << fPruneStrengthList[i] << ", ";
224 outfile << fPruneStrengthList[fPruneStrengthList.size()-1] << "]" << std::endl;
225
// NOTE(review): labelled "Misclassification rates", but the selection above keeps the largest
// value, so TestTreeQuality() presumably returns a larger-is-better quality -- confirm.
226 outfile << "Misclassification rates: [";
227 for(UInt_t i = 0; i < fQualityIndexList.size()-1; i++)
228 outfile << fQualityIndexList[i] << ", ";
229 outfile << fQualityIndexList[fQualityIndexList.size()-1] << "]" << std::endl;
230
231 outfile << "Optimal index: " << fOptimalK+1 << std::endl; // printed 1-based
232 outfile.close();
233 }
234 delete dTWrapper;
235}
236
237////////////////////////////////////////////////////////////////////////////////
238/// return the prune strength (=alpha) corresponding to the prune sequence
239
240std::vector<DecisionTreeNode*> CCPruner::GetOptimalPruneSequence( ) const
241{
242 std::vector<DecisionTreeNode*> optimalSequence;
243 if( fOptimalK >= 0 ) {
244 for( Int_t i = 0; i < fOptimalK; i++ ) {
245 optimalSequence.push_back(fPruneSequence[i]);
246 }
247 }
248 return optimalSequence;
249}
250
251
#define R(a, b, c, d, e, f, g, h, i)
Definition RSha256.hxx:110
constexpr Bool_t kTRUE
Definition RtypesCore.h:100
float * q
CCPruner(DecisionTree *t_max, const EventList *validationSample, SeparationBase *qualityIndex=nullptr)
constructor
Definition CCPruner.cxx:69
Float_t fAlpha
! regularization parameter in CC pruning
Definition CCPruner.h:93
std::vector< Float_t > fQualityIndexList
! map of R(T) -> pruning index
Definition CCPruner.h:103
void Optimize()
determine the pruning sequence
Definition CCPruner.cxx:124
Bool_t fDebug
! debug flag
Definition CCPruner.h:106
Bool_t fOwnQIndex
! flag indicates if fQualityIndex is owned by this
Definition CCPruner.h:97
std::vector< Event * > EventList
Definition CCPruner.h:64
std::vector< TMVA::DecisionTreeNode * > fPruneSequence
! map of weakest links (i.e., branches to prune) -> pruning index
Definition CCPruner.h:101
const EventList * fValidationSample
! the event sample to select the optimally-pruned tree
Definition CCPruner.h:94
std::vector< TMVA::DecisionTreeNode * > GetOptimalPruneSequence() const
return the sequence of nodes to prune to obtain the optimally-pruned tree
Definition CCPruner.cxx:240
Int_t fOptimalK
! index of the optimal tree in the pruned tree sequence
Definition CCPruner.h:105
const DataSet * fValidationDataSet
! the event sample to select the optimally-pruned tree
Definition CCPruner.h:95
std::vector< Float_t > fPruneStrengthList
! map of alpha -> pruning index
Definition CCPruner.h:102
SeparationBase * fQualityIndex
! the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) }
Definition CCPruner.h:96
DecisionTree * fTree
! (pruned) decision tree
Definition CCPruner.h:99
Double_t GetNodeResubstitutionEstimate() const
void SetMinAlphaC(Double_t alpha)
void SetResubstitutionEstimate(Double_t R)
Double_t GetResubstitutionEstimate() const
virtual void PrintRec(std::ostream &os) const
recursive printout of the node and its daughters
void SetAlphaC(Double_t alpha)
CCTreeNode * GetRoot()
Double_t TestTreeQuality(const EventList *validationSample)
return the misclassification rate of a pruned tree for a validation event sample using an EventList
void PruneNode(CCTreeNode *t)
remove the branch rooted at node t
Class that contains all the data information.
Definition DataSet.h:58
Implementation of a Decision Tree.
Implementation of the MisClassificationError as separation criterion.
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
const Int_t n
Definition legend1.C:16
create variable transformations