Logo ROOT   6.12/07
Reference Guide
CCPruner.cxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Class : CCPruner *
5  * Web : http://tmva.sourceforge.net *
6  * *
7  * Description: Cost Complexity Pruning *
8  *
9  * Author: Doug Schouten (dschoute@sfu.ca)
10  *
11  * *
12  * Copyright (c) 2007: *
13  * CERN, Switzerland *
14  * MPI-K Heidelberg, Germany *
15  * U. of Texas at Austin, USA *
16  * *
17  * Redistribution and use in source and binary forms, with or without *
18  * modification, are permitted according to the terms listed in LICENSE *
19  * (http://tmva.sourceforge.net/LICENSE) *
20  **********************************************************************************/
21 
#include "TMVA/CCPruner.h"
#include "TMVA/SeparationBase.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/MisClassificationError.h"
#include "TMVA/CCTreeWrapper.h"
#include "TMVA/DataSet.h"

#include "Rtypes.h"

#include <iostream>
#include <fstream>
#include <limits>
#include <math.h>
35 
36 /*! \class TMVA::CCPruner
37 \ingroup TMVA
38 A helper class to prune a decision tree using the Cost Complexity method
39 (see Classification and Regression Trees by Leo Breiman et al)
40 
41 ### Some definitions:
42 
43  - \f$ T_{max} \f$ - the initial, usually highly overtrained tree, that is to be pruned back
44  - \f$ R(T) \f$ - quality index (Gini, misclassification rate, or other) of a tree \f$ T \f$
45  - \f$ \sim T \f$ - set of terminal nodes in \f$ T \f$
46  - \f$ T' \f$ - the pruned subtree of \f$ T_{max} \f$ that has the best quality index \f$ R(T') \f$
47  - \f$ \alpha \f$ - the prune strength parameter in Cost Complexity pruning \f$ (R_{\alpha}(T) = R(T) + \alpha*|\sim T|) \f$
48 
49 There are two running modes in CCPruner: (i) one may select a prune strength and prune back
50 the tree \f$ T_{max}\f$ until the criterion:
51 \f[
52  \alpha < \frac{R(t) - R(T_t)}{|\sim T_t| - 1}
53 \f]
54 
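55 is true for all nodes t in \f$ T \f$ (here \f$ R(t) \f$ is the quality index of node t taken as a leaf and \f$ T_t \f$ is the branch of \f$ T \f$ rooted at t), or (ii) the algorithm finds the sequence of critical points
56 \f$ \alpha_1 < ... < \alpha_k < \alpha_{k+1} < ... < \alpha_K \f$ such that \f$ T_K = root(T_{max}) \f$ and then selects the optimally-pruned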
57 subtree, defined to be the subtree with the best quality index for the validation sample.
58 */
59 
60 namespace TMVA {
61  class DecisionTree;
62 }
63 
64 using namespace TMVA;
65 
66 ////////////////////////////////////////////////////////////////////////////////
67 /// constructor
68 
69 CCPruner::CCPruner( DecisionTree* t_max, const EventList* validationSample,
70  SeparationBase* qualityIndex ) :
71  fAlpha(-1.0),
72  fValidationSample(validationSample),
73  fValidationDataSet(NULL),
74  fOptimalK(-1)
75 {
76  fTree = t_max;
77 
78  if(qualityIndex == NULL) {
79  fOwnQIndex = true;
81  }
82  else {
83  fOwnQIndex = false;
84  fQualityIndex = qualityIndex;
85  }
86  fDebug = kTRUE;
87 }
88 
89 ////////////////////////////////////////////////////////////////////////////////
90 /// constructor
91 
92 CCPruner::CCPruner( DecisionTree* t_max, const DataSet* validationSample,
93  SeparationBase* qualityIndex ) :
94  fAlpha(-1.0),
95  fValidationSample(NULL),
96  fValidationDataSet(validationSample),
97  fOptimalK(-1)
98 {
99  fTree = t_max;
100 
101  if(qualityIndex == NULL) {
102  fOwnQIndex = true;
104  }
105  else {
106  fOwnQIndex = false;
107  fQualityIndex = qualityIndex;
108  }
109  fDebug = kTRUE;
110 }
111 
112 
113 ////////////////////////////////////////////////////////////////////////////////
114 
116 {
117  if(fOwnQIndex) delete fQualityIndex;
118  // destructor
119 }
120 
121 ////////////////////////////////////////////////////////////////////////////////
122 /// determine the pruning sequence
123 
125 {
126  Bool_t HaveStopCondition = fAlpha > 0; // keep pruning the tree until reach the limit fAlpha
127 
128  // build a wrapper tree to perform work on
129  CCTreeWrapper* dTWrapper = new CCTreeWrapper(fTree, fQualityIndex);
130 
131  Int_t k = 0;
133  Double_t alpha = -1.0e10;
134 
135  std::ofstream outfile;
136  if (fDebug) outfile.open("costcomplexity.log");
137  if(!HaveStopCondition && (fValidationSample == NULL && fValidationDataSet == NULL) ) {
138  if (fDebug) outfile << "ERROR: no validation sample, so cannot optimize pruning!" << std::endl;
139  delete dTWrapper;
140  if (fDebug) outfile.close();
141  return;
142  }
143 
144  CCTreeWrapper::CCTreeNode* R = dTWrapper->GetRoot();
145  while(R->GetNLeafDaughters() > 1) { // prune upwards to the root node
146  if(R->GetMinAlphaC() > alpha)
147  alpha = R->GetMinAlphaC(); // initialize alpha
148 
149  if(HaveStopCondition && alpha > fAlpha) break;
150 
152 
153  while(t->GetMinAlphaC() < t->GetAlphaC()) { // descend to the weakest link
154 
155  if(fabs(t->GetMinAlphaC() - t->GetLeftDaughter()->GetMinAlphaC())/fabs(t->GetMinAlphaC()) < epsilon)
156  t = t->GetLeftDaughter();
157  else
158  t = t->GetRightDaughter();
159  }
160 
161  if( t == R ) {
162  if (fDebug) outfile << std::endl << "Caught trying to prune the root node!" << std::endl;
163  break;
164  }
165 
167 
168  if (fDebug){
169  outfile << "===========================" << std::endl
170  << "Pruning branch listed below" << std::endl
171  << "===========================" << std::endl;
172  t->PrintRec( outfile );
173 
174  }
175  if (!(t->GetLeftDaughter()) && !(t->GetRightDaughter()) ) {
176  break;
177  }
178  dTWrapper->PruneNode(t); // prune the branch rooted at node t
179 
180  while(t != R) { // go back up the (pruned) tree and recalculate R(T), alpha_c
181  t = t->GetMother();
186  t->SetMinAlphaC(std::min(t->GetAlphaC(), std::min(t->GetLeftDaughter()->GetMinAlphaC(),
187  t->GetRightDaughter()->GetMinAlphaC())));
188  }
189  k += 1;
190  if(!HaveStopCondition) {
191  Double_t q;
192  if (fValidationDataSet != NULL) q = dTWrapper->TestTreeQuality(fValidationDataSet);
193  else q = dTWrapper->TestTreeQuality(fValidationSample);
194  fQualityIndexList.push_back(q);
195  }
196  else {
197  fQualityIndexList.push_back(1.0);
198  }
199  fPruneSequence.push_back(n->GetDTNode());
200  fPruneStrengthList.push_back(alpha);
201  }
202 
203  Double_t qmax = -1.0e6;
204  if(!HaveStopCondition) {
205  for(UInt_t i = 0; i < fQualityIndexList.size(); i++) {
206  if(fQualityIndexList[i] > qmax) {
207  qmax = fQualityIndexList[i];
208  k = i;
209  }
210  }
211  fOptimalK = k;
212  }
213  else {
214  fOptimalK = fPruneSequence.size() - 1;
215  }
216 
217  if (fDebug){
218  outfile << std::endl << "************ Summary **************" << std::endl
219  << "Number of trees in the sequence: " << fPruneSequence.size() << std::endl;
220 
221  outfile << "Pruning strength parameters: [";
222  for(UInt_t i = 0; i < fPruneStrengthList.size()-1; i++)
223  outfile << fPruneStrengthList[i] << ", ";
224  outfile << fPruneStrengthList[fPruneStrengthList.size()-1] << "]" << std::endl;
225 
226  outfile << "Misclassification rates: [";
227  for(UInt_t i = 0; i < fQualityIndexList.size()-1; i++)
228  outfile << fQualityIndexList[i] << ", ";
229  outfile << fQualityIndexList[fQualityIndexList.size()-1] << "]" << std::endl;
230 
231  outfile << "Optimal index: " << fOptimalK+1 << std::endl;
232  outfile.close();
233  }
234  delete dTWrapper;
235 }
236 
237 ////////////////////////////////////////////////////////////////////////////////
238 /// return the prune strength (=alpha) corresponding to the prune sequence
239 
240 std::vector<DecisionTreeNode*> CCPruner::GetOptimalPruneSequence( ) const
241 {
242  std::vector<DecisionTreeNode*> optimalSequence;
243  if( fOptimalK >= 0 ) {
244  for( Int_t i = 0; i < fOptimalK; i++ ) {
245  optimalSequence.push_back(fPruneSequence[i]);
246  }
247  }
248  return optimalSequence;
249 }
250 
251 
void Optimize()
determine the pruning sequence
Definition: CCPruner.cxx:124
std::vector< Float_t > fQualityIndexList
map of alpha -> pruning index
Definition: CCPruner.h:102
Float_t fAlpha
Definition: CCPruner.h:92
virtual void PrintRec(std::ostream &os) const
recursive printout of the node and its daughters
const DataSet * fValidationDataSet
the event sample to select the optimally-pruned tree
Definition: CCPruner.h:94
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
Int_t fOptimalK
map of R(T) -> pruning index
Definition: CCPruner.h:104
void SetResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:70
DecisionTree * fTree
flag indicates if fQualityIndex is owned by this
Definition: CCPruner.h:98
Bool_t fDebug
index of the optimal tree in the pruned tree sequence
Definition: CCPruner.h:105
Double_t GetResubstitutionEstimate() const
Definition: CCTreeWrapper.h:73
SeparationBase * fQualityIndex
the event sample to select the optimally-pruned tree
Definition: CCPruner.h:95
Implementation of the MisClassificationError as separation criterion.
Class that contains all the data information.
Definition: DataSet.h:69
VecExpr< UnaryOp< Fabs< T >, VecExpr< A, T, D >, T >, T, D > fabs(const VecExpr< A, T, D > &rhs)
Double_t TestTreeQuality(const EventList *validationSample)
return the misclassification rate of a pruned tree for a validation event sample using an EventList ...
Bool_t fOwnQIndex
the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) }
Definition: CCPruner.h:96
CCTreeNode * GetRoot()
std::vector< TMVA::DecisionTreeNode * > fPruneSequence
(pruned) decision tree
Definition: CCPruner.h:100
std::vector< Event * > EventList
Definition: CCPruner.h:63
Implementation of a Decision Tree.
Definition: DecisionTree.h:59
unsigned int UInt_t
Definition: RtypesCore.h:42
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
void SetMinAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:86
REAL epsilon
Definition: triangle.c:617
const EventList * fValidationSample
regularization parameter in CC pruning
Definition: CCPruner.h:93
double Double_t
Definition: RtypesCore.h:55
std::vector< TMVA::DecisionTreeNode * > GetOptimalPruneSequence() const
return the prune strength (=alpha) corresponding to the prune sequence
Definition: CCPruner.cxx:240
DecisionTreeNode * GetDTNode() const
Definition: CCTreeWrapper.h:92
void SetAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:80
Abstract ClassifierFactory template that handles arbitrary types.
CCPruner(DecisionTree *t_max, const EventList *validationSample, SeparationBase *qualityIndex=NULL)
constructor
Definition: CCPruner.cxx:69
void PruneNode(CCTreeNode *t)
remove the branch rooted at node t
std::vector< Float_t > fPruneStrengthList
map of weakest links (i.e., branches to prune) -> pruning index
Definition: CCPruner.h:101
float * q
Definition: THbookFile.cxx:87
constexpr Double_t R()
Definition: TMath.h:213
const Bool_t kTRUE
Definition: RtypesCore.h:87
const Int_t n
Definition: legend1.C:16
Double_t GetNodeResubstitutionEstimate() const
Definition: CCTreeWrapper.h:66