Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CostComplexityPruneTool.cxx
Go to the documentation of this file.
/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : TMVA::CostComplexityPruneTool                                         *
 *                                                                                *
 *                                                                                *
 * Description:                                                                   *
 *      Cost Complexity pruning of a decision tree                                *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
 *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
 *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
 *      Doug Schouten   <dschoute@sfu.ca>        - Simon Fraser U., Canada        *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         *
 *      U. of Victoria, Canada                                                    *
 *      MPI-K Heidelberg, Germany                                                 *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://mva.sourceforge.net/license.txt)                                       *
 *                                                                                *
 **********************************************************************************/

/*! \class TMVA::CostComplexityPruneTool
\ingroup TMVA
A class to prune a decision tree using the Cost Complexity method
(see "Classification and Regression Trees" by Leo Breiman et al).

### Some definitions:

 - \f$ T_{max} \f$ - the initial, usually highly overtrained tree, that is to be pruned back
 - \f$ R(T) \f$ - quality index (Gini, misclassification rate, or other) of a tree \f$ T \f$
 - \f$ \sim T \f$ - set of terminal nodes in \f$ T \f$
 - \f$ T_t \f$ - the subtree rooted at node \f$ t \f$, and \f$ R(t) \f$ the quality index of node \f$ t \f$ alone
 - \f$ T' \f$ - the pruned subtree of \f$ T_{max} \f$ that has the best quality index \f$ R(T') \f$
 - \f$ \alpha \f$ - the prune strength parameter in Cost Complexity pruning \f$ (R_{\alpha}(T) = R(T) + \alpha*|\sim T|) \f$

There are two running modes in CostComplexityPruneTool: (i) one may select a prune strength and prune back
the tree \f$ T_{max}\f$ until the criterion:
\f[
 \alpha < \frac{R(t) - R(T_t)}{|\sim T_t| - 1}
\f]

is true for all nodes t in \f$ T \f$ (in this implementation the prune strength is interpreted as a percentage
of the length of the pruning sequence), or (ii) the algorithm finds the sequence of critical points
\f$ \alpha_k < \alpha_{k+1} < \dots < \alpha_K \f$ such that \f$ T_K = root(T_{max}) \f$ and then selects the optimally-pruned
subtree, defined to be the subtree with the best quality index for the validation sample.
*/
50
52
53#include "TMVA/MsgLogger.h"
54#include "TMVA/SeparationBase.h"
55#include "TMVA/DecisionTree.h"
56
57#include "RtypesCore.h"
58
59#include <limits>
60#include <cmath>
61
62using namespace TMVA;
63
64
65////////////////////////////////////////////////////////////////////////////////
66/// the constructor for the cost complexity pruning
67
69 IPruneTool(),
70 fLogger(new MsgLogger("CostComplexityPruneTool") )
71{
72 fOptimalK = -1;
73
74 // !! changed from Dougs code. Now use the QualityIndex stored already
75 // in the nodes when no "new" QualityIndex calculator is given. Like this
76 // I can easily implement the Regression. For Regression, the pruning uses the
77 // same separation index as in the tree building, hence doesn't need to re-calculate
78 // (which would need more info than simply "s" and "b")
79
80 fQualityIndexTool = qualityIndex;
81
82 //fLogger->SetMinType( kDEBUG );
83 fLogger->SetMinType( kWARNING );
84}
85
86////////////////////////////////////////////////////////////////////////////////
87/// the destructor for the cost complexity pruning
88
90 if(fQualityIndexTool != NULL) delete fQualityIndexTool;
91}
92
93////////////////////////////////////////////////////////////////////////////////
94/// the routine that basically "steers" the pruning process. Call the calculation of
95/// the pruning sequence, the tree quality and alike..
96
99 const IPruneTool::EventSample* validationSample,
100 Bool_t isAutomatic )
101{
102 if( isAutomatic ) SetAutomatic();
103
104 if( dt == NULL || (IsAutomatic() && validationSample == NULL) ) {
105 // must have a valid decision tree to prune, and if the prune strength
106 // is to be chosen automatically, must have a test sample from
107 // which to calculate the quality of the pruned tree(s)
108 return NULL;
109 }
110
111 Double_t Q = -1.0;
112 Double_t W = 1.0;
113
114 if(IsAutomatic()) {
115 // run the pruning validation sample through the unpruned tree
116 dt->ApplyValidationSample(validationSample);
117 W = dt->GetSumWeights(validationSample); // get the sum of weights in the pruning validation sample
118 // calculate the quality of the tree in the unpruned case
119 Q = dt->TestPrunedTreeQuality();
120
121 Log() << kDEBUG << "Node purity limit is: " << dt->GetNodePurityLimit() << Endl;
122 Log() << kDEBUG << "Sum of weights in pruning validation sample: " << W << Endl;
123 Log() << kDEBUG << "Quality of tree prior to any pruning is " << Q/W << Endl;
124 }
125
126 // store the cost complexity metadata for the decision tree at each node
127 try {
129 }
130 catch(const std::string &error) {
131 Log() << kERROR << "Couldn't initialize the tree meta data because of error ("
132 << error << ")" << Endl;
133 return NULL;
134 }
135
136 Log() << kDEBUG << "Automatic cost complexity pruning is " << (IsAutomatic()?"on":"off") << "." << Endl;
137
138 try {
139 Optimize( dt, W ); // run the cost complexity pruning algorithm
140 }
141 catch(const std::string &error) {
142 Log() << kERROR << "Error optimizing pruning sequence ("
143 << error << ")" << Endl;
144 return NULL;
145 }
146
147 Log() << kDEBUG << "Index of pruning sequence to stop at: " << fOptimalK << Endl;
148
149 PruningInfo* info = new PruningInfo();
150
151
152 if(fOptimalK < 0) {
153 // no pruning necessary, or wasn't able to compute a sequence
154 info->PruneStrength = 0;
155 info->QualityIndex = Q/W;
156 info->PruneSequence.clear();
157 Log() << kINFO << "no proper pruning could be calculated. Tree "
158 << dt->GetTreeID() << " will not be pruned. Do not worry if this "
159 << " happens for a few trees " << Endl;
160 return info;
161 }
163 Log() << kDEBUG << " prune until k=" << fOptimalK << " with alpha="<<fPruneStrengthList[fOptimalK]<< Endl;
164 for( Int_t i = 0; i < fOptimalK; i++ ){
165 info->PruneSequence.push_back(fPruneSequence[i]);
166 }
167 if( IsAutomatic() ){
169 }
170 else {
172 }
173
174 return info;
175}
176
177////////////////////////////////////////////////////////////////////////////////
178/// initialise "meta data" for the pruning, like the "costcomplexity", the
179/// critical alpha, the minimal alpha down the tree, etc... for each node!!
180
182 if( n == NULL ) return;
183
184 Double_t s = n->GetNSigEvents();
185 Double_t b = n->GetNBkgEvents();
186 // set R(t) = N_events*Gini(t) or MisclassificationError(t), etc.
188 else n->SetNodeR( (s+b)*n->GetSeparationIndex() );
189
190 if(n->GetLeft() != NULL && n->GetRight() != NULL) { // n is an interior (non-leaf) node
191 n->SetTerminal(kFALSE);
192 // traverse the tree
193 InitTreePruningMetaData(n->GetLeft());
194 InitTreePruningMetaData(n->GetRight());
195 // set |~T_t|
196 n->SetNTerminal( n->GetLeft()->GetNTerminal() +
197 n->GetRight()->GetNTerminal());
198 // set R(T) = sum[n' in ~T]{ R(n') }
199 n->SetSubTreeR( (n->GetLeft()->GetSubTreeR() +
200 n->GetRight()->GetSubTreeR()));
201 // set alpha_c, the alpha value at which it becomes advantageous to prune at node n
202 n->SetAlpha( ((n->GetNodeR() - n->GetSubTreeR()) /
203 (n->GetNTerminal() - 1)));
204
205 // G(t) = min( alpha_c, G(l(n)), G(r(n)) )
206 // the minimum alpha in subtree rooted at this node
207 n->SetAlphaMinSubtree( std::min(n->GetAlpha(), std::min(n->GetLeft()->GetAlphaMinSubtree(),
208 n->GetRight()->GetAlphaMinSubtree())));
209 n->SetCC(n->GetAlpha());
210
211 } else { // n is a terminal node
212 n->SetNTerminal( 1 ); n->SetTerminal( );
213 if (fQualityIndexTool) n->SetSubTreeR(((s+b)*fQualityIndexTool->GetSeparationIndex(s,b)));
214 else n->SetSubTreeR( (s+b)*n->GetSeparationIndex() );
215 n->SetAlpha(std::numeric_limits<double>::infinity( ));
216 n->SetAlphaMinSubtree(std::numeric_limits<double>::infinity( ));
217 n->SetCC(n->GetAlpha());
218 }
219
220 // DecisionTreeNode* R = (DecisionTreeNode*)mdt->GetRoot();
221 // Double_t x = R->GetAlphaMinSubtree();
222 // Log() << "alphaMin(Root) = " << x << Endl;
223}
224
225
226////////////////////////////////////////////////////////////////////////////////
227/// after the critical \f$ \alpha \f$ values (at which the corresponding nodes would
228/// be pruned away) had been established in the "InitMetaData" we need now:
229/// automatic pruning:
230///
231/// find the value of \f$ \alpha \f$ for which the test sample gives minimal error,
232/// on the tree with all nodes pruned that have \f$ \alpha_{critical} < \alpha \f$,
233/// fixed parameter pruning
234///
235
237 Int_t k = 1;
238 Double_t alpha = -1.0e10;
239 Double_t epsilon = std::numeric_limits<double>::epsilon();
240
241 fQualityIndexList.clear();
242 fPruneSequence.clear();
243 fPruneStrengthList.clear();
244
246
247 Double_t qmin = 0.0;
248 if(IsAutomatic()){
249 // initialize the tree quality (actually at this stage, it is the quality of the yet unpruned tree
250 qmin = dt->TestPrunedTreeQuality()/weights;
251 }
252
253 // now prune the tree in steps until it is gone. At each pruning step, the pruning
254 // takes place at the node that is regarded as the "weakest link".
255 // for automatic pruning, at each step, we calculate the current quality of the
256 // tree and in the end we will prune at the minimum of the tree quality
257 // for the fixed parameter pruning, the cut is simply set at a relative position
258 // in the sequence according to the "length" of the sequence of pruned trees.
259 // 100: at the end (pruned until the root node would be the next pruning candidate
260 // 50: in the middle of the sequence
261 // etc...
262 while(R->GetNTerminal() > 1) { // prune upwards to the root node
263
264 // initialize alpha
265 alpha = TMath::Max(R->GetAlphaMinSubtree(), alpha);
266
267 if( R->GetAlphaMinSubtree() >= R->GetAlpha() ) {
268 Log() << kDEBUG << "\nCaught trying to prune the root node!" << Endl;
269 break;
270 }
271
272
273 DecisionTreeNode* t = R;
274
275 // descend to the weakest link
276 while(t->GetAlphaMinSubtree() < t->GetAlpha()) {
277 // std::cout << t->GetAlphaMinSubtree() << " " << t->GetAlpha()<< " "
278 // << t->GetAlphaMinSubtree()- t->GetAlpha()<< " t==R?" << int(t == R) << std::endl;
279 // while( (t->GetAlphaMinSubtree() - t->GetAlpha()) < epsilon) {
280 // if(TMath::Abs(t->GetAlphaMinSubtree() - t->GetLeft()->GetAlphaMinSubtree())/TMath::Abs(t->GetAlphaMinSubtree()) < epsilon) {
281 if(TMath::Abs(t->GetAlphaMinSubtree() - t->GetLeft()->GetAlphaMinSubtree()) < epsilon) {
282 t = t->GetLeft();
283 } else {
284 t = t->GetRight();
285 }
286 }
287
288 if( t == R ) {
289 Log() << kDEBUG << "\nCaught trying to prune the root node!" << Endl;
290 break;
291 }
292
293 DecisionTreeNode* n = t;
294
295 // Log() << kDEBUG << "alpha[" << k << "]: " << alpha << Endl;
296 // Log() << kDEBUG << "===========================" << Endl
297 // << "Pruning branch listed below the node" << Endl;
298 // t->Print( Log() );
299 // Log() << kDEBUG << "===========================" << Endl;
300 // t->PrintRecPrune( Log() );
301
302 dt->PruneNodeInPlace(t); // prune the branch rooted at node t
303
304 while(t != R) { // go back up the (pruned) tree and recalculate R(T), alpha_c
305 t = t->GetParent();
307 t->SetSubTreeR(t->GetLeft()->GetSubTreeR() + t->GetRight()->GetSubTreeR());
308 t->SetAlpha((t->GetNodeR() - t->GetSubTreeR())/(t->GetNTerminal() - 1));
309 t->SetAlphaMinSubtree(std::min(t->GetAlpha(), std::min(t->GetLeft()->GetAlphaMinSubtree(),
310 t->GetRight()->GetAlphaMinSubtree())));
311 t->SetCC(t->GetAlpha());
312 }
313 k += 1;
314
315 Log() << kDEBUG << "after this pruning step I would have " << R->GetNTerminal() << " remaining terminal nodes " << Endl;
316
317 if(IsAutomatic()) {
318 Double_t q = dt->TestPrunedTreeQuality()/weights;
319 fQualityIndexList.push_back(q);
320 }
321 else {
322 fQualityIndexList.push_back(1.0);
323 }
324 fPruneSequence.push_back(n);
325 fPruneStrengthList.push_back(alpha);
326 }
327
328 if(fPruneSequence.empty()) {
329 fOptimalK = -1;
330 return;
331 }
332
333 if(IsAutomatic()) {
334 k = -1;
335 for(UInt_t i = 0; i < fQualityIndexList.size(); i++) {
336 if(fQualityIndexList[i] < qmin) {
337 qmin = fQualityIndexList[i];
338 k = i;
339 }
340 }
341 fOptimalK = k;
342 }
343 else {
344 // regularize the prune strength relative to this tree
345 fOptimalK = int(fPruneStrength/100.0 * fPruneSequence.size() );
346 Log() << kDEBUG << "SequenzeSize="<<fPruneSequence.size()
347 << " fOptimalK " << fOptimalK << Endl;
348
349 }
350
351 Log() << kDEBUG << "\n************ Summary for Tree " << dt->GetTreeID() << " *******" << Endl
352 << "Number of trees in the sequence: " << fPruneSequence.size() << Endl;
353
354 Log() << kDEBUG << "Pruning strength parameters: [";
355 for(UInt_t i = 0; i < fPruneStrengthList.size()-1; i++)
356 Log() << kDEBUG << fPruneStrengthList[i] << ", ";
357 Log() << kDEBUG << fPruneStrengthList[fPruneStrengthList.size()-1] << "]" << Endl;
358
359 Log() << kDEBUG << "Misclassification rates: [";
360 for(UInt_t i = 0; i < fQualityIndexList.size()-1; i++)
361 Log() << kDEBUG << fQualityIndexList[i] << ", ";
362 Log() << kDEBUG << fQualityIndexList[fQualityIndexList.size()-1] << "]" << Endl;
363
364 Log() << kDEBUG << "Prune index: " << fOptimalK+1 << Endl;
365
366}
367
#define b(i)
Definition RSha256.hxx:100
constexpr Bool_t kFALSE
Definition RtypesCore.h:101
float * q
void InitTreePruningMetaData(DecisionTreeNode *n)
initialise "meta data" for the pruning, like the "costcomplexity", the critical alpha,...
std::vector< Double_t > fPruneStrengthList
! map of alpha -> pruning index
virtual ~CostComplexityPruneTool()
the destructor for the cost complexity pruning
CostComplexityPruneTool(SeparationBase *qualityIndex=nullptr)
the constructor for the cost complexity pruning
std::vector< DecisionTreeNode * > fPruneSequence
! map of weakest links (i.e., branches to prune) -> pruning index
std::vector< Double_t > fQualityIndexList
! map of R(T) -> pruning index
void Optimize(DecisionTree *dt, Double_t weights)
after the critical values (at which the corresponding nodes would be pruned away) had been establish...
SeparationBase * fQualityIndexTool
! the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) }
MsgLogger & Log() const
output stream to save logging information
virtual PruningInfo * CalculatePruningInfo(DecisionTree *dt, const IPruneTool::EventSample *testEvents=nullptr, Bool_t isAutomatic=kFALSE)
the routine that basically "steers" the pruning process.
Int_t fOptimalK
! the optimal index of the prune sequence
void SetCC(Double_t cc)
Set CC, if traininfo defined, otherwise Log Fatal.
Double_t GetSubTreeR() const
return the resubstitution estimate, R(T_t), of the tree rooted at this node, or -1 if traininfo undef...
void SetAlphaMinSubtree(Double_t g)
set the minimum alpha in the tree rooted at this node, if traininfo defined
Double_t GetAlphaMinSubtree() const
return the minimum alpha in the tree rooted at this node, or -1 if traininfo undefined
void SetSubTreeR(Double_t r)
set the resubstitution estimate, R(T_t), of the tree rooted at this node, if traininfo defined
virtual DecisionTreeNode * GetLeft() const
Double_t GetNodeR() const
return the node resubstitution estimate, R(t), for Cost Complexity pruning, or -1 if traininfo undefi...
Double_t GetAlpha() const
return the critical point alpha, or -1 if traininfo undefined
Int_t GetNTerminal() const
return number of terminal nodes in the subtree rooted here, or -1 if traininfo undefined
void SetAlpha(Double_t alpha)
set the critical point alpha, if traininfo defined
virtual DecisionTreeNode * GetParent() const
void SetNTerminal(Int_t n)
set number of terminal nodes in the subtree rooted here, if traininfo defined
virtual DecisionTreeNode * GetRight() const
Implementation of a Decision Tree.
Double_t GetNodePurityLimit() const
void ApplyValidationSample(const EventConstList *validationSample) const
run the validation sample through the (pruned) tree and fill in the nodes the variables NSValidation ...
virtual DecisionTreeNode * GetRoot() const
void PruneNodeInPlace(TMVA::DecisionTreeNode *node)
prune a node temporarily (without actually deleting its descendants which allows testing the pruned t...
Double_t TestPrunedTreeQuality(const DecisionTreeNode *dt=nullptr, Int_t mode=0) const
return the misclassification rate of a pruned tree a "pruned tree" may have set the variable "IsTermi...
Double_t GetSumWeights(const EventConstList *validationSample) const
calculate the normalization factor for a pruning validation sample
IPruneTool - a helper interface class to prune a decision tree.
Definition IPruneTool.h:70
void SetAutomatic()
Definition IPruneTool.h:94
Double_t fPruneStrength
! regularization parameter in pruning
Definition IPruneTool.h:101
std::vector< const Event * > EventSample
Definition IPruneTool.h:74
Bool_t IsAutomatic() const
Definition IPruneTool.h:95
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
void SetMinType(EMsgType minType)
Definition MsgLogger.h:70
Double_t QualityIndex
Definition IPruneTool.h:45
std::vector< DecisionTreeNode * > PruneSequence
the regularization parameter for pruning
Definition IPruneTool.h:47
Double_t PruneStrength
quality measure for a pruned subtree T of T_max
Definition IPruneTool.h:46
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
virtual Double_t GetSeparationIndex(const Double_t s, const Double_t b)=0
const Int_t n
Definition legend1.C:16
create variable transformations
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
Short_t Max(Short_t a, Short_t b)
Returns the largest of a and b.
Definition TMathBase.h:250
Short_t Abs(Short_t d)
Returns the absolute value of parameter Short_t d.
Definition TMathBase.h:123