ROOT  6.06/09
Reference Guide
CCPruner.cxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Class : CCPruner *
5  * Web : http://tmva.sourceforge.net *
6  * *
7  * Description: Cost Complexity Pruning *
8  *
9  * Author: Doug Schouten (dschoute@sfu.ca)
10  *
11  * *
12  * Copyright (c) 2007: *
13  * CERN, Switzerland *
14  * MPI-K Heidelberg, Germany *
15  * U. of Texas at Austin, USA *
16  * *
17  * Redistribution and use in source and binary forms, with or without *
18  * modification, are permitted according to the terms listed in LICENSE *
19  * (http://tmva.sourceforge.net/LICENSE) *
20  **********************************************************************************/
21 
22 #include "TMVA/CCPruner.h"
23 #include "TMVA/SeparationBase.h"
24 #include "TMVA/GiniIndex.h"
26 #include "TMVA/CCTreeWrapper.h"
27 #include "TMVA/DataSet.h"
28 
29 #include <iostream>
30 #include <fstream>
31 #include <limits>
32 #include <math.h>
33 
34  using namespace TMVA;
35 
36 ////////////////////////////////////////////////////////////////////////////////
37 /// constructor
38 
// Construct a pruner for the tree t_max, using an EventList as validation sample.
//  t_max            - decision tree stored in fTree
//  validationSample - stored in fValidationSample (fValidationDataSet is set to NULL)
//  qualityIndex     - if non-NULL it is stored in fQualityIndex and NOT owned by this object
39 CCPruner::CCPruner( DecisionTree* t_max, const EventList* validationSample,
40  SeparationBase* qualityIndex ) :
41  fAlpha(-1.0), // not positive: Optimize() then runs without a pruning-strength stop condition
42  fValidationSample(validationSample),
43  fValidationDataSet(NULL),
44  fOptimalK(-1) // negative: GetOptimalPruneSequence() returns an empty sequence until Optimize() sets it
45 {
46  fTree = t_max;
47 
48  if(qualityIndex == NULL) {
49  fOwnQIndex = true; // the destructor deletes fQualityIndex when this flag is set
// NOTE(review): listing line 50 is absent from this listing; presumably it assigns a
// default, newly allocated quality index to fQualityIndex - confirm against the real file.
51  }
52  else {
53  fOwnQIndex = false; // caller keeps ownership of qualityIndex
54  fQualityIndex = qualityIndex;
55  }
56  fDebug = kTRUE; // unconditionally on: Optimize() opens and writes "costcomplexity.log" when fDebug is set
57 }
58 
59 ////////////////////////////////////////////////////////////////////////////////
60 /// constructor
61 
// Construct a pruner for the tree t_max, using a DataSet as validation sample.
//  t_max            - decision tree stored in fTree
//  validationSample - stored in fValidationDataSet (fValidationSample is set to NULL)
//  qualityIndex     - if non-NULL it is stored in fQualityIndex and NOT owned by this object
62 CCPruner::CCPruner( DecisionTree* t_max, const DataSet* validationSample,
63  SeparationBase* qualityIndex ) :
64  fAlpha(-1.0), // not positive: Optimize() then runs without a pruning-strength stop condition
65  fValidationSample(NULL),
66  fValidationDataSet(validationSample),
67  fOptimalK(-1) // negative: GetOptimalPruneSequence() returns an empty sequence until Optimize() sets it
68 {
69  fTree = t_max;
70 
71  if(qualityIndex == NULL) {
72  fOwnQIndex = true; // the destructor deletes fQualityIndex when this flag is set
// NOTE(review): listing line 73 is absent from this listing; presumably it assigns a
// default, newly allocated quality index to fQualityIndex - confirm against the real file.
74  }
75  else {
76  fOwnQIndex = false; // caller keeps ownership of qualityIndex
77  fQualityIndex = qualityIndex;
78  }
79  fDebug = kTRUE; // unconditionally on: Optimize() opens and writes "costcomplexity.log" when fDebug is set
80 }
81 
82 
83 ////////////////////////////////////////////////////////////////////////////////
84 
86 {
// NOTE(review): the signature (listing line 85) is not visible in this listing.
87  if(fOwnQIndex) delete fQualityIndex; // free the quality index only when the constructor flagged it as owned
88  // destructor: nothing else is released here (fTree and the validation samples are left untouched)
89 }
90 
91 ////////////////////////////////////////////////////////////////////////////////
92 /// determine the pruning sequence
93 
95 {
// Builds the sequence of pruning steps and selects the optimal one.
// Results are stored in fPruneSequence, fPruneStrengthList, fQualityIndexList and fOptimalK.
// NOTE(review): this listing is missing the signature (listing line 94) and listing lines
// 99, 102, 121, 136, 152-156 and 162; the declarations of dTWrapper, epsilon, t and n and
// part of the recalculation step are therefore not visible here.
96  Bool_t HaveStopCondition = fAlpha > 0; // keep pruning the tree until reach the limit fAlpha
97 
98  // build a wrapper tree to perform work on
100 
101  Int_t k = 0;
103  Double_t alpha = -1.0e10;
104 
105  std::ofstream outfile;
106  if (fDebug) outfile.open("costcomplexity.log");
// Without a stop condition a validation sample is required to rank the pruned trees.
// NOTE(review): when fDebug is false this failure is completely silent for the caller.
107  if(!HaveStopCondition && (fValidationSample == NULL && fValidationDataSet == NULL) ) {
108  if (fDebug) outfile << "ERROR: no validation sample, so cannot optimize pruning!" << std::endl;
109  delete dTWrapper;
110  if (fDebug) outfile.close();
111  return;
112  }
113 
114  CCTreeWrapper::CCTreeNode* R = dTWrapper->GetRoot();
115  while(R->GetNLeafDaughters() > 1) { // prune upwards to the root node
// alpha only ever increases: it tracks the largest minimum alpha_c seen at the root so far
116  if(R->GetMinAlphaC() > alpha)
117  alpha = R->GetMinAlphaC(); // initialize alpha
118 
// with a stop condition, stop as soon as the pruning strength exceeds the limit fAlpha
119  if(HaveStopCondition && alpha > fAlpha) break;
120 
122 
123  while(t->GetMinAlphaC() < t->GetAlphaC()) { // descend to the weakest link
124 
// follow the daughter whose minimum alpha_c matches this node's (relative difference below epsilon)
125  if(fabs(t->GetMinAlphaC() - t->GetLeftDaughter()->GetMinAlphaC())/fabs(t->GetMinAlphaC()) < epsilon)
126  t = t->GetLeftDaughter();
127  else
128  t = t->GetRightDaughter();
129  }
130 
// the root itself is never pruned; the sequence ends here instead
131  if( t == R ) {
132  if (fDebug) outfile << std::endl << "Caught trying to prune the root node!" << std::endl;
133  break;
134  }
135 
137 
138  if (fDebug){
139  outfile << "===========================" << std::endl
140  << "Pruning branch listed below" << std::endl
141  << "===========================" << std::endl;
142  t->PrintRec( outfile );
143 
144  }
// a node without daughters cannot be pruned any further: stop
145  if (!(t->GetLeftDaughter()) && !(t->GetRightDaughter()) ) {
146  break;
147  }
148  dTWrapper->PruneNode(t); // prune the branch rooted at node t
149 
150  while(t != R) { // go back up the (pruned) tree and recalculate R(T), alpha_c
151  t = t->GetMother();
157  t->GetRightDaughter()->GetMinAlphaC())));
158  }
159  k += 1;
// record the quality of the tree after this pruning step (constant 1.0 when a stop condition is used)
160  if(!HaveStopCondition) {
161  Double_t q;
163  else q = dTWrapper->TestTreeQuality(fValidationSample);
164  fQualityIndexList.push_back(q);
165  }
166  else {
167  fQualityIndexList.push_back(1.0);
168  }
// the three lists grow in lock-step: one entry per pruning step
169  fPruneSequence.push_back(n->GetDTNode());
170  fPruneStrengthList.push_back(alpha);
171  }
172 
// Select the optimal step: without a stop condition, the first entry with the largest quality value;
// with a stop condition, the last index of the prune sequence.
// NOTE(review): the log below labels these values "Misclassification rates" while the maximum is
// selected here - confirm that TestTreeQuality returns a value where larger is better.
173  Double_t qmax = -1.0e6;
174  if(!HaveStopCondition) {
175  for(UInt_t i = 0; i < fQualityIndexList.size(); i++) {
176  if(fQualityIndexList[i] > qmax) {
177  qmax = fQualityIndexList[i];
178  k = i;
179  }
180  }
181  fOptimalK = k;
182  }
183  else {
184  fOptimalK = fPruneSequence.size() - 1;
185  }
186 
187  if (fDebug){
188  outfile << std::endl << "************ Summary **************" << std::endl
189  << "Number of trees in the sequence: " << fPruneSequence.size() << std::endl;
190 
// NOTE(review): size() is unsigned, so size()-1 wraps around when a list is empty (e.g. when the
// pruning loop above exits before recording any step); the loops and the [size()-1] accesses
// below would then read out of bounds. An emptiness check is needed here.
191  outfile << "Pruning strength parameters: [";
192  for(UInt_t i = 0; i < fPruneStrengthList.size()-1; i++)
193  outfile << fPruneStrengthList[i] << ", ";
194  outfile << fPruneStrengthList[fPruneStrengthList.size()-1] << "]" << std::endl;
195 
196  outfile << "Misclassification rates: [";
197  for(UInt_t i = 0; i < fQualityIndexList.size()-1; i++)
198  outfile << fQualityIndexList[i] << ", ";
199  outfile << fQualityIndexList[fQualityIndexList.size()-1] << "]" << std::endl;
200 
201  outfile << "Optimal index: " << fOptimalK+1 << std::endl;
202  outfile.close();
203  }
204  delete dTWrapper;
205 }
206 
207 ////////////////////////////////////////////////////////////////////////////////
208 /// return the leading fOptimalK nodes of the prune sequence (empty while no optimal index is set)
209 
210 std::vector<DecisionTreeNode*> CCPruner::GetOptimalPruneSequence( ) const
211 {
212  std::vector<DecisionTreeNode*> optimalSequence;
213  if( fOptimalK >= 0 ) {
214  for( Int_t i = 0; i < fOptimalK; i++ ) {
215  optimalSequence.push_back(fPruneSequence[i]);
216  }
217  }
218  return optimalSequence;
219 }
220 
221 
void Optimize()
determine the pruning sequence
Definition: CCPruner.cxx:94
std::vector< Float_t > fQualityIndexList
map of R(T) -> pruning index
Definition: CCPruner.h:106
static Vc_ALWAYS_INLINE int_v min(const int_v &x, const int_v &y)
Definition: vector.h:433
Float_t fAlpha
regularization parameter in CC pruning
Definition: CCPruner.h:96
const DataSet * fValidationDataSet
the event sample to select the optimally-pruned tree
Definition: CCPruner.h:98
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
std::vector< TMVA::DecisionTreeNode * > GetOptimalPruneSequence() const
return the leading fOptimalK nodes of the prune sequence (empty while no optimal index is set)
Definition: CCPruner.cxx:210
Double_t GetResubstitutionEstimate() const
Definition: CCTreeWrapper.h:83
Int_t fOptimalK
index of the optimal tree in the pruned tree sequence
Definition: CCPruner.h:108
void SetResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:80
Double_t GetNodeResubstitutionEstimate() const
Definition: CCTreeWrapper.h:76
DecisionTree * fTree
(pruned) decision tree
Definition: CCPruner.h:102
Bool_t fDebug
debug flag; when set, Optimize() writes its log to costcomplexity.log
Definition: CCPruner.h:109
DecisionTreeNode * GetDTNode() const
SeparationBase * fQualityIndex
the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) }
Definition: CCPruner.h:99
VecExpr< UnaryOp< Fabs< T >, VecExpr< A, T, D >, T >, T, D > fabs(const VecExpr< A, T, D > &rhs)
Double_t TestTreeQuality(const EventList *validationSample)
return the misclassification rate of a pruned tree for a validation event sample using an EventList ...
Bool_t fOwnQIndex
flag indicates if fQualityIndex is owned by this
Definition: CCPruner.h:100
CCTreeNode * GetRoot()
std::vector< TMVA::DecisionTreeNode * > fPruneSequence
map of weakest links (i.e., branches to prune) -> pruning index
Definition: CCPruner.h:104
virtual void PrintRec(std::ostream &os) const
recursive printout of the node and its daughters
std::vector< Event * > EventList
Definition: CCPruner.h:67
unsigned int UInt_t
Definition: RtypesCore.h:42
void SetMinAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:96
REAL epsilon
Definition: triangle.c:617
const EventList * fValidationSample
the event sample to select the optimally-pruned tree
Definition: CCPruner.h:97
Double_t GetMinAlphaC() const
Definition: CCTreeWrapper.h:99
double Double_t
Definition: RtypesCore.h:55
void SetAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:90
Abstract ClassifierFactory template that handles arbitrary types.
#define NULL
Definition: Rtypes.h:82
CCPruner(DecisionTree *t_max, const EventList *validationSample, SeparationBase *qualityIndex=NULL)
constructor
Definition: CCPruner.cxx:39
void PruneNode(CCTreeNode *t)
remove the branch rooted at node t
std::vector< Float_t > fPruneStrengthList
map of alpha -> pruning index
Definition: CCPruner.h:105
const Bool_t kTRUE
Definition: Rtypes.h:91
float * q
Definition: THbookFile.cxx:87
const Int_t n
Definition: legend1.C:16
TRandom3 R
a TMatrixD.
Definition: testIO.cxx:28