Logo ROOT   6.08/07
Reference Guide
CCPruner.cxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Class : CCPruner *
5  * Web : http://tmva.sourceforge.net *
6  * *
7  * Description: Cost Complexity Pruning *
8  *
9  * Author: Doug Schouten (dschoute@sfu.ca)
10  *
11  * *
12  * Copyright (c) 2007: *
13  * CERN, Switzerland *
14  * MPI-K Heidelberg, Germany *
15  * U. of Texas at Austin, USA *
16  * *
17  * Redistribution and use in source and binary forms, with or without *
18  * modification, are permitted according to the terms listed in LICENSE *
19  * (http://tmva.sourceforge.net/LICENSE) *
20  **********************************************************************************/
21 
22 #include "TMVA/CCPruner.h"
23 #include "TMVA/SeparationBase.h"
24 #include "TMVA/GiniIndex.h"
26 #include "TMVA/CCTreeWrapper.h"
27 #include "TMVA/DataSet.h"
28 
29 #include "Rtypes.h"
30 
31 #include <iostream>
32 #include <fstream>
33 #include <limits>
34 #include <math.h>
35 
36 namespace TMVA {
37  class DecisionTree;
38 }
39 
40 using namespace TMVA;
41 
42 ////////////////////////////////////////////////////////////////////////////////
43 /// constructor
44 
45 CCPruner::CCPruner( DecisionTree* t_max, const EventList* validationSample,
46  SeparationBase* qualityIndex ) :
47  fAlpha(-1.0),
48  fValidationSample(validationSample),
49  fValidationDataSet(NULL),
50  fOptimalK(-1)
51 {
52  fTree = t_max;
53 
54  if(qualityIndex == NULL) {
55  fOwnQIndex = true;
57  }
58  else {
59  fOwnQIndex = false;
60  fQualityIndex = qualityIndex;
61  }
62  fDebug = kTRUE;
63 }
64 
65 ////////////////////////////////////////////////////////////////////////////////
66 /// constructor
67 
68 CCPruner::CCPruner( DecisionTree* t_max, const DataSet* validationSample,
69  SeparationBase* qualityIndex ) :
70  fAlpha(-1.0),
72  fValidationDataSet(validationSample),
73  fOptimalK(-1)
74 {
75  fTree = t_max;
76 
77  if(qualityIndex == NULL) {
78  fOwnQIndex = true;
80  }
81  else {
82  fOwnQIndex = false;
83  fQualityIndex = qualityIndex;
84  }
85  fDebug = kTRUE;
86 }
87 
88 
89 ////////////////////////////////////////////////////////////////////////////////
90 
92 {
93  if(fOwnQIndex) delete fQualityIndex;
94  // destructor
95 }
96 
97 ////////////////////////////////////////////////////////////////////////////////
98 /// determine the pruning sequence
99 
101 {
102  Bool_t HaveStopCondition = fAlpha > 0; // keep pruning the tree until reach the limit fAlpha
103 
104  // build a wrapper tree to perform work on
105  CCTreeWrapper* dTWrapper = new CCTreeWrapper(fTree, fQualityIndex);
106 
107  Int_t k = 0;
109  Double_t alpha = -1.0e10;
110 
111  std::ofstream outfile;
112  if (fDebug) outfile.open("costcomplexity.log");
113  if(!HaveStopCondition && (fValidationSample == NULL && fValidationDataSet == NULL) ) {
114  if (fDebug) outfile << "ERROR: no validation sample, so cannot optimize pruning!" << std::endl;
115  delete dTWrapper;
116  if (fDebug) outfile.close();
117  return;
118  }
119 
120  CCTreeWrapper::CCTreeNode* R = dTWrapper->GetRoot();
121  while(R->GetNLeafDaughters() > 1) { // prune upwards to the root node
122  if(R->GetMinAlphaC() > alpha)
123  alpha = R->GetMinAlphaC(); // initialize alpha
124 
125  if(HaveStopCondition && alpha > fAlpha) break;
126 
128 
129  while(t->GetMinAlphaC() < t->GetAlphaC()) { // descend to the weakest link
130 
131  if(fabs(t->GetMinAlphaC() - t->GetLeftDaughter()->GetMinAlphaC())/fabs(t->GetMinAlphaC()) < epsilon)
132  t = t->GetLeftDaughter();
133  else
134  t = t->GetRightDaughter();
135  }
136 
137  if( t == R ) {
138  if (fDebug) outfile << std::endl << "Caught trying to prune the root node!" << std::endl;
139  break;
140  }
141 
143 
144  if (fDebug){
145  outfile << "===========================" << std::endl
146  << "Pruning branch listed below" << std::endl
147  << "===========================" << std::endl;
148  t->PrintRec( outfile );
149 
150  }
151  if (!(t->GetLeftDaughter()) && !(t->GetRightDaughter()) ) {
152  break;
153  }
154  dTWrapper->PruneNode(t); // prune the branch rooted at node t
155 
156  while(t != R) { // go back up the (pruned) tree and recalculate R(T), alpha_c
157  t = t->GetMother();
162  t->SetMinAlphaC(std::min(t->GetAlphaC(), std::min(t->GetLeftDaughter()->GetMinAlphaC(),
163  t->GetRightDaughter()->GetMinAlphaC())));
164  }
165  k += 1;
166  if(!HaveStopCondition) {
167  Double_t q;
169  else q = dTWrapper->TestTreeQuality(fValidationSample);
170  fQualityIndexList.push_back(q);
171  }
172  else {
173  fQualityIndexList.push_back(1.0);
174  }
175  fPruneSequence.push_back(n->GetDTNode());
176  fPruneStrengthList.push_back(alpha);
177  }
178 
179  Double_t qmax = -1.0e6;
180  if(!HaveStopCondition) {
181  for(UInt_t i = 0; i < fQualityIndexList.size(); i++) {
182  if(fQualityIndexList[i] > qmax) {
183  qmax = fQualityIndexList[i];
184  k = i;
185  }
186  }
187  fOptimalK = k;
188  }
189  else {
190  fOptimalK = fPruneSequence.size() - 1;
191  }
192 
193  if (fDebug){
194  outfile << std::endl << "************ Summary **************" << std::endl
195  << "Number of trees in the sequence: " << fPruneSequence.size() << std::endl;
196 
197  outfile << "Pruning strength parameters: [";
198  for(UInt_t i = 0; i < fPruneStrengthList.size()-1; i++)
199  outfile << fPruneStrengthList[i] << ", ";
200  outfile << fPruneStrengthList[fPruneStrengthList.size()-1] << "]" << std::endl;
201 
202  outfile << "Misclassification rates: [";
203  for(UInt_t i = 0; i < fQualityIndexList.size()-1; i++)
204  outfile << fQualityIndexList[i] << ", ";
205  outfile << fQualityIndexList[fQualityIndexList.size()-1] << "]" << std::endl;
206 
207  outfile << "Optimal index: " << fOptimalK+1 << std::endl;
208  outfile.close();
209  }
210  delete dTWrapper;
211 }
212 
213 ////////////////////////////////////////////////////////////////////////////////
214 /// return the prune strength (=alpha) corresponding to the prune sequence
215 
216 std::vector<DecisionTreeNode*> CCPruner::GetOptimalPruneSequence( ) const
217 {
218  std::vector<DecisionTreeNode*> optimalSequence;
219  if( fOptimalK >= 0 ) {
220  for( Int_t i = 0; i < fOptimalK; i++ ) {
221  optimalSequence.push_back(fPruneSequence[i]);
222  }
223  }
224  return optimalSequence;
225 }
226 
227 
void Optimize()
determine the pruning sequence
Definition: CCPruner.cxx:100
std::vector< Float_t > fQualityIndexList
map of R(T) -> pruning index
Definition: CCPruner.h:106
Float_t fAlpha
regularization parameter in CC pruning
Definition: CCPruner.h:96
virtual void PrintRec(std::ostream &os) const
recursive printout of the node and its daughters
const DataSet * fValidationDataSet
the event sample to select the optimally-pruned tree
Definition: CCPruner.h:98
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
Int_t fOptimalK
index of the optimal tree in the pruned tree sequence
Definition: CCPruner.h:108
void SetResubstitutionEstimate(Double_t R)
Definition: CCTreeWrapper.h:80
DecisionTree * fTree
(pruned) decision tree
Definition: CCPruner.h:102
Bool_t fDebug
debug flag; while set, Optimize() writes a log to costcomplexity.log
Definition: CCPruner.h:109
Double_t GetResubstitutionEstimate() const
Definition: CCTreeWrapper.h:83
SeparationBase * fQualityIndex
the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) }
Definition: CCPruner.h:99
VecExpr< UnaryOp< Fabs< T >, VecExpr< A, T, D >, T >, T, D > fabs(const VecExpr< A, T, D > &rhs)
Double_t TestTreeQuality(const EventList *validationSample)
return the misclassification rate of a pruned tree for a validation event sample using an EventList ...
Bool_t fOwnQIndex
flag indicating whether fQualityIndex is owned by this object
Definition: CCPruner.h:100
CCTreeNode * GetRoot()
std::vector< TMVA::DecisionTreeNode * > fPruneSequence
map of weakest links (i.e., branches to prune) -> pruning index
Definition: CCPruner.h:104
std::vector< Event * > EventList
Definition: CCPruner.h:67
unsigned int UInt_t
Definition: RtypesCore.h:42
void SetMinAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:96
REAL epsilon
Definition: triangle.c:617
const EventList * fValidationSample
the event sample to select the optimally-pruned tree
Definition: CCPruner.h:97
double Double_t
Definition: RtypesCore.h:55
std::vector< TMVA::DecisionTreeNode * > GetOptimalPruneSequence() const
return the optimal prune sequence, i.e. the first fOptimalK nodes of fPruneSequence
Definition: CCPruner.cxx:216
DecisionTreeNode * GetDTNode() const
void SetAlphaC(Double_t alpha)
Definition: CCTreeWrapper.h:90
Abstract ClassifierFactory template that handles arbitrary types.
#define NULL
Definition: Rtypes.h:82
CCPruner(DecisionTree *t_max, const EventList *validationSample, SeparationBase *qualityIndex=NULL)
constructor
Definition: CCPruner.cxx:45
void PruneNode(CCTreeNode *t)
remove the branch rooted at node t
std::vector< Float_t > fPruneStrengthList
map of alpha -> pruning index
Definition: CCPruner.h:105
const Bool_t kTRUE
Definition: Rtypes.h:91
float * q
Definition: THbookFile.cxx:87
const Int_t n
Definition: legend1.C:16
TRandom3 R
a TMatrixD.
Definition: testIO.cxx:28
Double_t GetNodeResubstitutionEstimate() const
Definition: CCTreeWrapper.h:76