Logo ROOT   6.08/07
Reference Guide
CostComplexityPruneTool.cxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Class : TMVA::CostComplexityPruneTool *
5  * Web : http://tmva.sourceforge.net *
6  * *
7  * Description: *
8  * Cost complexity pruning of a Decision Tree *
9  * *
10  * Authors (alphabetical): *
11  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
12  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
13  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
14  * Doug Schouten <dschoute@sfu.ca> - Simon Fraser U., Canada *
15  * *
16  * Copyright (c) 2005: *
17  * CERN, Switzerland *
18  * U. of Victoria, Canada *
19  * MPI-K Heidelberg, Germany *
20  * *
21  * Redistribution and use in source and binary forms, with or without *
22  * modification, are permitted according to the terms listed in LICENSE *
23  * (http://mva.sourceforge.net/license.txt) *
24  * *
25  **********************************************************************************/
26 
28 
29 #include "TMVA/MsgLogger.h"
30 #include "TMVA/SeparationBase.h"
31 #include "TMVA/DecisionTree.h"
32 
33 #include "RtypesCore.h"
34 
35 #include <fstream>
36 #include <limits>
37 #include <math.h>
38 
39 using namespace TMVA;
40 
41 
42 ////////////////////////////////////////////////////////////////////////////////
43 /// the constructor for the cost complexity pruning
44 
46  IPruneTool(),
47  fLogger(new MsgLogger("CostComplexityPruneTool") )
48 {
49  fOptimalK = -1;
50 
51  // !! changed from Dougs code. Now use the QualityIndex stored already
52  // in the nodes when no "new" QualityIndex calculator is given. Like this
53  // I can easily implement the Regression. For Regression, the pruning uses the
54  // same sepearation index as in the tree building, hence doesn't need to re-calculate
55  // (which would need more info than simply "s" and "b")
56 
57  fQualityIndexTool = qualityIndex;
58 
59  //fLogger->SetMinType( kDEBUG );
61 }
62 
63 ////////////////////////////////////////////////////////////////////////////////
64 /// the destructor for the cost complexity pruning
65 
68 }
69 
70 ////////////////////////////////////////////////////////////////////////////////
71 
74  const IPruneTool::EventSample* validationSample,
75  Bool_t isAutomatic )
76 {
77  // the routine that basically "steers" the pruning process. Call the calculation of
78  // the pruning sequence, the tree quality and alike..
79 
80  if( isAutomatic ) SetAutomatic();
81 
82  if( dt == NULL || (IsAutomatic() && validationSample == NULL) ) {
83  // must have a valid decision tree to prune, and if the prune strength
84  // is to be chosen automatically, must have a test sample from
85  // which to calculate the quality of the pruned tree(s)
86  return NULL;
87  }
88 
89  Double_t Q = -1.0;
90  Double_t W = 1.0;
91 
92  if(IsAutomatic()) {
93  // run the pruning validation sample through the unpruned tree
94  dt->ApplyValidationSample(validationSample);
95  W = dt->GetSumWeights(validationSample); // get the sum of weights in the pruning validation sample
96  // calculate the quality of the tree in the unpruned case
97  Q = dt->TestPrunedTreeQuality();
98 
99  Log() << kDEBUG << "Node purity limit is: " << dt->GetNodePurityLimit() << Endl;
100  Log() << kDEBUG << "Sum of weights in pruning validation sample: " << W << Endl;
101  Log() << kDEBUG << "Quality of tree prior to any pruning is " << Q/W << Endl;
102  }
103 
104  // store the cost complexity metadata for the decision tree at each node
105  try {
107  }
108  catch(std::string error) {
109  Log() << kERROR << "Couldn't initialize the tree meta data because of error ("
110  << error << ")" << Endl;
111  return NULL;
112  }
113 
114  Log() << kDEBUG << "Automatic cost complexity pruning is " << (IsAutomatic()?"on":"off") << "." << Endl;
115 
116  try {
117  Optimize( dt, W ); // run the cost complexity pruning algorithm
118  }
119  catch(std::string error) {
120  Log() << kERROR << "Error optimzing pruning sequence ("
121  << error << ")" << Endl;
122  return NULL;
123  }
124 
125  Log() << kDEBUG << "Index of pruning sequence to stop at: " << fOptimalK << Endl;
126 
127  PruningInfo* info = new PruningInfo();
128 
129 
130  if(fOptimalK < 0) {
131  // no pruning necessary, or wasn't able to compute a sequence
132  info->PruneStrength = 0;
133  info->QualityIndex = Q/W;
134  info->PruneSequence.clear();
135  Log() << kINFO << "no proper pruning could be calulated. Tree "
136  << dt->GetTreeID() << " will not be pruned. Do not worry if this "
137  << " happens for a few trees " << Endl;
138  return info;
139  }
141  Log() << kDEBUG << " prune until k=" << fOptimalK << " with alpha="<<fPruneStrengthList[fOptimalK]<< Endl;
142  for( Int_t i = 0; i < fOptimalK; i++ ){
143  info->PruneSequence.push_back(fPruneSequence[i]);
144  }
145  if( IsAutomatic() ){
147  }
148  else {
150  }
151 
152  return info;
153 }
154 
155 ////////////////////////////////////////////////////////////////////////////////
156 /// initialise the pruning "meta data" for each node: the cost complexity R(t),
157 /// the critical alpha, the minimal alpha in the subtree below the node, etc.
158 
160  if( n == NULL ) return;
161 
162  Double_t s = n->GetNSigEvents();
163  Double_t b = n->GetNBkgEvents();
164  // set R(t) = N_events*Gini(t) or MisclassificationError(t), etc.
166  else n->SetNodeR( (s+b)*n->GetSeparationIndex() );
167 
168  if(n->GetLeft() != NULL && n->GetRight() != NULL) { // n is an interior (non-leaf) node
169  n->SetTerminal(kFALSE);
170  // traverse the tree
173  // set |~T_t|
174  n->SetNTerminal( n->GetLeft()->GetNTerminal() +
175  n->GetRight()->GetNTerminal());
176  // set R(T) = sum[n' in ~T]{ R(n') }
177  n->SetSubTreeR( (n->GetLeft()->GetSubTreeR() +
178  n->GetRight()->GetSubTreeR()));
179  // set alpha_c, the alpha value at which it becomes advantageaus to prune at node n
180  n->SetAlpha( ((n->GetNodeR() - n->GetSubTreeR()) /
181  (n->GetNTerminal() - 1)));
182 
183  // G(t) = min( alpha_c, G(l(n)), G(r(n)) )
184  // the minimum alpha in subtree rooted at this node
185  n->SetAlphaMinSubtree( std::min(n->GetAlpha(), std::min(n->GetLeft()->GetAlphaMinSubtree(),
186  n->GetRight()->GetAlphaMinSubtree())));
187  n->SetCC(n->GetAlpha());
188 
189  } else { // n is a terminal node
190  n->SetNTerminal( 1 ); n->SetTerminal( );
192  else n->SetSubTreeR( (s+b)*n->GetSeparationIndex() );
193  n->SetAlpha(std::numeric_limits<double>::infinity( ));
194  n->SetAlphaMinSubtree(std::numeric_limits<double>::infinity( ));
195  n->SetCC(n->GetAlpha());
196  }
197 
198  // DecisionTreeNode* R = (DecisionTreeNode*)mdt->GetRoot();
199  // Double_t x = R->GetAlphaMinSubtree();
200  // Log() << "alphaMin(Root) = " << x << Endl;
201 }
202 
203 
204 ////////////////////////////////////////////////////////////////////////////////
205 /// after the critical alpha values (at which the corresponding nodes would
206 /// be pruned away) have been established in "InitTreePruningMetaData", we need:
207 /// automatic pruning:
208 /// find the value of "alpha" for which the validation sample gives minimal error
209 /// on the tree with all nodes pruned that have alpha_critical < alpha;
210 /// fixed parameter pruning:
211 /// stop at the relative position in the prune sequence given by the prune strength (in %)
212 
214  Int_t k = 1;
215  Double_t alpha = -1.0e10;
217 
218  fQualityIndexList.clear();
219  fPruneSequence.clear();
220  fPruneStrengthList.clear();
221 
223 
224  Double_t qmin = 0.0;
225  if(IsAutomatic()){
226  // initialize the tree quality (actually at this stage, it is the quality of the yet unpruned tree
227  qmin = dt->TestPrunedTreeQuality()/weights;
228  }
229 
230  // now prune the tree in steps until it is gone. At each pruning step, the pruning
231  // takes place at the node that is regarded as the "weakest link".
232  // for automatic pruning, at each step, we calculate the current quality of the
233  // tree and in the end we will prune at the minimum of the tree quality
234  // for the fixed parameter pruing, the cut is simply set at a relative position
235  // in the sequence according to the "lenght" of the sequence of pruned trees.
236  // 100: at the end (pruned until the root node would be the next pruning candidate
237  // 50: in the middle of the sequence
238  // etc...
239  while(R->GetNTerminal() > 1) { // prune upwards to the root node
240 
241  // initialize alpha
242  alpha = TMath::Max(R->GetAlphaMinSubtree(), alpha);
243 
244  if( R->GetAlphaMinSubtree() >= R->GetAlpha() ) {
245  Log() << kDEBUG << "\nCaught trying to prune the root node!" << Endl;
246  break;
247  }
248 
249 
250  DecisionTreeNode* t = R;
251 
252  // descend to the weakest link
253  while(t->GetAlphaMinSubtree() < t->GetAlpha()) {
254  // std::cout << t->GetAlphaMinSubtree() << " " << t->GetAlpha()<< " "
255  // << t->GetAlphaMinSubtree()- t->GetAlpha()<< " t==R?" << int(t == R) << std::endl;
256  // while( (t->GetAlphaMinSubtree() - t->GetAlpha()) < epsilon) {
257  // if(TMath::Abs(t->GetAlphaMinSubtree() - t->GetLeft()->GetAlphaMinSubtree())/TMath::Abs(t->GetAlphaMinSubtree()) < epsilon) {
258  if(TMath::Abs(t->GetAlphaMinSubtree() - t->GetLeft()->GetAlphaMinSubtree()) < epsilon) {
259  t = t->GetLeft();
260  } else {
261  t = t->GetRight();
262  }
263  }
264 
265  if( t == R ) {
266  Log() << kDEBUG << "\nCaught trying to prune the root node!" << Endl;
267  break;
268  }
269 
270  DecisionTreeNode* n = t;
271 
272  // Log() << kDEBUG << "alpha[" << k << "]: " << alpha << Endl;
273  // Log() << kDEBUG << "===========================" << Endl
274  // << "Pruning branch listed below the node" << Endl;
275  // t->Print( Log() );
276  // Log() << kDEBUG << "===========================" << Endl;
277  // t->PrintRecPrune( Log() );
278 
279  dt->PruneNodeInPlace(t); // prune the branch rooted at node t
280 
281  while(t != R) { // go back up the (pruned) tree and recalculate R(T), alpha_c
282  t = t->GetParent();
283  t->SetNTerminal(t->GetLeft()->GetNTerminal() + t->GetRight()->GetNTerminal());
284  t->SetSubTreeR(t->GetLeft()->GetSubTreeR() + t->GetRight()->GetSubTreeR());
285  t->SetAlpha((t->GetNodeR() - t->GetSubTreeR())/(t->GetNTerminal() - 1));
286  t->SetAlphaMinSubtree(std::min(t->GetAlpha(), std::min(t->GetLeft()->GetAlphaMinSubtree(),
287  t->GetRight()->GetAlphaMinSubtree())));
288  t->SetCC(t->GetAlpha());
289  }
290  k += 1;
291 
292  Log() << kDEBUG << "after this pruning step I would have " << R->GetNTerminal() << " remaining terminal nodes " << Endl;
293 
294  if(IsAutomatic()) {
295  Double_t q = dt->TestPrunedTreeQuality()/weights;
296  fQualityIndexList.push_back(q);
297  }
298  else {
299  fQualityIndexList.push_back(1.0);
300  }
301  fPruneSequence.push_back(n);
302  fPruneStrengthList.push_back(alpha);
303  }
304 
305  if(fPruneSequence.empty()) {
306  fOptimalK = -1;
307  return;
308  }
309 
310  if(IsAutomatic()) {
311  k = -1;
312  for(UInt_t i = 0; i < fQualityIndexList.size(); i++) {
313  if(fQualityIndexList[i] < qmin) {
314  qmin = fQualityIndexList[i];
315  k = i;
316  }
317  }
318  fOptimalK = k;
319  }
320  else {
321  // regularize the prune strength relative to this tree
322  fOptimalK = int(fPruneStrength/100.0 * fPruneSequence.size() );
323  Log() << kDEBUG << "SequenzeSize="<<fPruneSequence.size()
324  << " fOptimalK " << fOptimalK << Endl;
325 
326  }
327 
328  Log() << kDEBUG << "\n************ Summary for Tree " << dt->GetTreeID() << " *******" << Endl
329  << "Number of trees in the sequence: " << fPruneSequence.size() << Endl;
330 
331  Log() << kDEBUG << "Pruning strength parameters: [";
332  for(UInt_t i = 0; i < fPruneStrengthList.size()-1; i++)
333  Log() << kDEBUG << fPruneStrengthList[i] << ", ";
334  Log() << kDEBUG << fPruneStrengthList[fPruneStrengthList.size()-1] << "]" << Endl;
335 
336  Log() << kDEBUG << "Misclassification rates: [";
337  for(UInt_t i = 0; i < fQualityIndexList.size()-1; i++)
338  Log() << kDEBUG << fQualityIndexList[i] << ", ";
339  Log() << kDEBUG << fQualityIndexList[fQualityIndexList.size()-1] << "]" << Endl;
340 
341  Log() << kDEBUG << "Prune index: " << fOptimalK+1 << Endl;
342 
343 }
344 
Double_t PruneStrength
the regularization parameter for pruning
Definition: IPruneTool.h:50
Int_t fOptimalK
the optimal index of the prune sequence
virtual ~CostComplexityPruneTool()
the destructor for the cost complexity pruning
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
Double_t fPruneStrength
Definition: IPruneTool.h:103
Double_t GetNodePurityLimit() const
Definition: DecisionTree.h:170
CostComplexityPruneTool(SeparationBase *qualityIndex=NULL)
the constructor for the cost complexity pruning
virtual DecisionTreeNode * GetParent() const
int Int_t
Definition: RtypesCore.h:41
std::vector< DecisionTreeNode * > PruneSequence
the sequence of pruning locations (weakest links) in the tree
Definition: IPruneTool.h:51
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
Double_t TestPrunedTreeQuality(const DecisionTreeNode *dt=NULL, Int_t mode=0) const
return the misclassification rate of a pruned tree; a "pruned tree" may have set the variable "IsTermi...
Float_t GetNSigEvents(void) const
virtual DecisionTreeNode * GetRoot() const
Definition: DecisionTree.h:102
void SetAutomatic()
Definition: IPruneTool.h:96
Short_t Abs(Short_t d)
Definition: TMathBase.h:110
Bool_t IsAutomatic() const
Definition: IPruneTool.h:97
Double_t GetSubTreeR() const
Float_t GetSeparationIndex(void) const
Float_t GetNBkgEvents(void) const
Double_t GetNodeR() const
Double_t GetSumWeights(const EventConstList *validationSample) const
calculate the normalization factor for a pruning validation sample
void SetNodeR(Double_t r)
void SetMinType(EMsgType minType)
Definition: MsgLogger.h:76
std::vector< const Event * > EventSample
Definition: IPruneTool.h:76
void Optimize(DecisionTree *dt, Double_t weights)
after the critical alpha values (at which the corresponding nodes would be pruned away) have been established, determine where to stop in the prune sequence...
void SetSubTreeR(Double_t r)
void SetAlpha(Double_t alpha)
unsigned int UInt_t
Definition: RtypesCore.h:42
void PruneNodeInPlace(TMVA::DecisionTreeNode *node)
prune a node temporarily (without actually deleting its descendants, which allows testing the pruned tree...
void ApplyValidationSample(const EventConstList *validationSample) const
run the validation sample through the (pruned) tree and fill in the nodes the variables NSValidation ...
REAL epsilon
Definition: triangle.c:617
Double_t GetAlphaMinSubtree() const
double Double_t
Definition: RtypesCore.h:55
void SetAlphaMinSubtree(Double_t g)
void SetTerminal(Bool_t s=kTRUE)
std::vector< Double_t > fPruneStrengthList
map of alpha -> pruning index
std::vector< Double_t > fQualityIndexList
map of R(T) -> pruning index
std::vector< DecisionTreeNode * > fPruneSequence
map of weakest links (i.e., branches to prune) -> pruning index
Abstract ClassifierFactory template that handles arbitrary types.
void InitTreePruningMetaData(DecisionTreeNode *n)
initialise the pruning meta data (R(t), critical alpha, minimal alpha in the subtree, ...) for each node
virtual Double_t GetSeparationIndex(const Double_t &s, const Double_t &b)=0
Short_t Max(Short_t a, Short_t b)
Definition: TMathBase.h:202
you should not use this method at all Int_t Int_t Double_t Double_t Double_t Int_t Double_t Double_t Double_t Double_t b
Definition: TRolke.cxx:630
virtual DecisionTreeNode * GetLeft() const
#define NULL
Definition: Rtypes.h:82
Double_t QualityIndex
quality measure for a pruned subtree T of T_max
Definition: IPruneTool.h:49
virtual DecisionTreeNode * GetRight() const
Double_t GetAlpha() const
float * q
Definition: THbookFile.cxx:87
const Int_t n
Definition: legend1.C:16
TRandom3 R
a TMatrixD.
Definition: testIO.cxx:28
static double Q[]
virtual PruningInfo * CalculatePruningInfo(DecisionTree *dt, const IPruneTool::EventSample *testEvents=NULL, Bool_t isAutomatic=kFALSE)
MsgLogger & Log() const
accessor for the message logger (output stream to save logging information)