ROOT  6.07/01
Reference Guide
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
ExpectedErrorPruneTool.cxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Class : TMVA::ExpectedErrorPruneTool *
5  * Web : http://tmva.sourceforge.net *
6  * *
7  * Description: *
8  * Implementation of the Expected Error Pruning (EEP) tool for decision trees *
9  * *
10  * Authors (alphabetical): *
11  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
12  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
13  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
14  * Doug Schouten <dschoute@sfu.ca> - Simon Fraser U., Canada *
15  * *
16  * Copyright (c) 2005: *
17  * CERN, Switzerland *
18  * U. of Victoria, Canada *
19  * MPI-K Heidelberg, Germany *
20  * *
21  * Redistribution and use in source and binary forms, with or without *
22  * modification, are permitted according to the terms listed in LICENSE *
23  * (http://mva.sourceforge.net/license.txt) *
24  * *
25  **********************************************************************************/
26 
28 #include "TMVA/DecisionTree.h"
29 #include "TMVA/IPruneTool.h"
30 #include "TMVA/MsgLogger.h"
31 #include "TMVA/Types.h"
32 
33 #include "RtypesCore.h"
34 #include "Rtypes.h"
35 #include "TMath.h"
36 
37 #include <map>
38 
39 ////////////////////////////////////////////////////////////////////////////////
40 
42  IPruneTool(),
43  fDeltaPruneStrength(0),
44  fNodePurityLimit(1),
45  fLogger( new MsgLogger("ExpectedErrorPruneTool") )
46 {}
47 
48 ////////////////////////////////////////////////////////////////////////////////
49 
51 {
52  delete fLogger;
53 }
54 
55 ////////////////////////////////////////////////////////////////////////////////
56 
59  const IPruneTool::EventSample* validationSample,
60  Bool_t isAutomatic )
61 {
62  if( isAutomatic ) {
63  //SetAutomatic( );
64  isAutomatic = kFALSE;
65  Log() << kWARNING << "Sorry autmoatic pruning strength determination is not implemented yet" << Endl;
66  }
67  if( dt == NULL || (IsAutomatic() && validationSample == NULL) ) {
68  // must have a valid decision tree to prune, and if the prune strength
69  // is to be chosen automatically, must have a test sample from
70  // which to calculate the quality of the pruned tree(s)
71  return NULL;
72  }
73  fNodePurityLimit = dt->GetNodePurityLimit();
74 
75  if(IsAutomatic()) {
76  Log() << kFATAL << "Sorry autmoatic pruning strength determination is not implemented yet" << Endl;
77  /*
78  dt->ApplyValidationSample(validationSample);
79  Double_t weights = dt->GetSumWeights(validationSample);
80  // set the initial prune strength
81  fPruneStrength = 1.0e-3; //! FIXME somehow make this automatic, it depends strongly on the tree structure
82  // better to set it too small, it will be increased automatically
83  fDeltaPruneStrength = 1.0e-5;
84  Int_t nnodes = this->CountNodes((DecisionTreeNode*)dt->GetRoot());
85 
86  Bool_t forceStop = kFALSE;
87  Int_t errCount = 0,
88  lastNodeCount = nnodes;
89 
90 // find the maximum prune strength that still leaves the root's daughter nodes
91 
92  while ( nnodes > 1 && !forceStop ) {
93  fPruneStrength += fDeltaPruneStrength;
94  Log() << "----------------------------------------------------" << Endl;
95  FindListOfNodes((DecisionTreeNode*)dt->GetRoot());
96  for( UInt_t i = 0; i < fPruneSequence.size(); i++ )
97  fPruneSequence[i]->SetTerminal(); // prune all the nodes from the sequence
98  // test the quality of the pruned tree
99  Double_t quality = 1.0 - dt->TestPrunedTreeQuality()/weights;
100  fQualityMap.insert(std::make_pair<const Double_t,Double_t>(quality,fPruneStrength));
101 
102  nnodes = CountNodes((DecisionTreeNode*)dt->GetRoot()); // count the number of nodes in the pruned tree
103 
104  Log() << "Prune strength : " << fPruneStrength << Endl;
105  Log() << "Had " << lastNodeCount << " nodes, now have " << nnodes << Endl;
106  Log() << "Quality index is: " << quality << Endl;
107 
108  if (lastNodeCount == nnodes) errCount++;
109  else {
110  errCount=0; // reset counter
111  if ( nnodes < lastNodeCount / 2 ) {
112  Log() << "Decreasing fDeltaPruneStrength to " << fDeltaPruneStrength/2.0
113  << " because the number of nodes in the tree decreased by a factor of 2." << Endl;
114  fDeltaPruneStrength /= 2.;
115  }
116  }
117  lastNodeCount = nnodes;
118  if (errCount > 20) {
119  Log() << "Increasing fDeltaPruneStrength to " << fDeltaPruneStrength*2.0
120  << " because the number of nodes in the tree didn't change." << Endl;
121  fDeltaPruneStrength *= 2.0;
122  }
123  if (errCount > 40) {
124  Log() << "Having difficulty determining the optimal prune strength, bailing out!" << Endl;
125  forceStop = kTRUE;
126  }
127  // reset the tree for the next iteration
128  for( UInt_t i = 0; i < fPruneSequence.size(); i++ )
129  fPruneSequence[i]->SetTerminal(false);
130  fPruneSequence.clear();
131  }
132  // from the set of pruned trees, find the one with the optimal quality index
133  std::multimap<Double_t,Double_t>::reverse_iterator it = fQualityMap.rend(); ++it;
134  fPruneStrength = it->second;
135  FindListOfNodes((DecisionTreeNode*)dt->GetRoot());
136 
137  // adjust the step size for the next tree automatically
138  fPruneStrength = 1.0e-3;
139  fDeltaPruneStrength = (fPruneStrength - 1.0)/(Double_t)fQualityMap.size();
140 
141  return new PruningInfo(it->first, it->second, fPruneSequence);
142  */
143  return NULL;
144  }
145  else { // no automatic pruning - just use the provided prune strength parameter
146  FindListOfNodes( (DecisionTreeNode*)dt->GetRoot() );
147  return new PruningInfo( -1.0, fPruneStrength, fPruneSequence );
148  }
149 }
150 
151 ////////////////////////////////////////////////////////////////////////////////
152 /// recursive pruning of nodes using the Expected Error Pruning (EEP)
153 
155 {
158  if (node->GetNodeType() == 0 && !(node->IsTerminal())) { // check all internal nodes
159  this->FindListOfNodes(l);
160  this->FindListOfNodes(r);
161  if (this->GetSubTreeError(node) >= this->GetNodeError(node)) {
162  //node->Print(Log());
163  fPruneSequence.push_back(node);
164  }
165  }
166 }
167 
168 ////////////////////////////////////////////////////////////////////////////////
169 /// calculate the expected statistical error on the subtree below "node"
170 /// which is used in the expected error pruning
171 
173 {
176  if (node->GetNodeType() == 0 && !(node->IsTerminal())) {
177  Double_t subTreeError =
178  (l->GetNEvents() * this->GetSubTreeError(l) +
179  r->GetNEvents() * this->GetSubTreeError(r)) /
180  node->GetNEvents();
181  return subTreeError;
182  }
183  else {
184  return this->GetNodeError(node);
185  }
186 }
187 
188 ////////////////////////////////////////////////////////////////////////////////
189 /// Calculate an UPPER limit on the error made by the classification done
190 /// by this node. If the S/S+B of the node is f, then according to the
191 /// training sample, the error rate (fraction of misclassified events by
192 /// this node) is (1-f)
193 /// Now f has a statistical error according to the binomial distribution,
194 /// hence the error on f can be estimated (the same binomial error as used
195 /// for efficiency calculations: sigma = sqrt( eff*(1-eff)/nEvts ))
196 
198 {
199  Double_t errorRate = 0;
200 
201  Double_t nEvts = node->GetNEvents();
202 
203  // fraction of correctly classified events by this node:
204  Double_t f = 0;
205  if (node->GetPurity() > fNodePurityLimit) f = node->GetPurity();
206  else f = (1-node->GetPurity());
207 
208  Double_t df = TMath::Sqrt(f*(1-f)/nEvts);
209 
210  errorRate = std::min(1.0,(1.0 - (f-fPruneStrength*df)));
211 
212  // -------------------------------------------------------------------
213  // standard algorithm:
214  // step 1: Estimate error on node using Laplace estimate
215  // NodeError = (N - n + k -1 ) / (N + k)
216  // N: number of events
217  // k: number of event classes (2 for Signal, Background)
218  // n: n event out of N belong to the class which has the majority in the node
219  // step 2: Approximate "backed-up" error assuming we did not prune
220  // (I'm never quite sure if they consider whole subtrees, or only 'next-to-leaf'
221  // nodes)...
222  // Subtree error = Sum_children ( P_i * NodeError_i)
223  // P_i = probability of the node to make the decision, i.e. fraction of events in
224  // leaf node ( N_leaf / N_parent)
225  // step 3:
226 
227  // Minimum Error Pruning (MEP) according to Niblett/Bratko
228  //# of correctly classified events by this node:
229  //Double_t n=f*nEvts ;
230  //Double_t p_apriori = 0.5, m=100;
231  //errorRate = (nEvts - n + (1-p_apriori) * m ) / (nEvts + m);
232 
233  // Pessimistic Error Pruning, proposed by Quinlan (error estimate with continuity approximation)
234  //# of correctly classified events by this node:
235  //Double_t n=f*nEvts ;
236  //errorRate = (nEvts - n + 0.5) / nEvts ;
237 
238  //const Double Z=.65;
239  //# of correctly classified events by this node:
240  //Double_t n=f*nEvts ;
241  //errorRate = (f + Z*Z/(2*nEvts ) + Z*sqrt(f/nEvts - f*f/nEvts + Z*Z/4/nEvts /nEvts ) ) / (1 + Z*Z/nEvts );
242  //errorRate = (n + Z*Z/2 + Z*sqrt(n - n*n/nEvts + Z*Z/4) )/ (nEvts + Z*Z);
243  //errorRate = 1 - errorRate;
244  // -------------------------------------------------------------------
245 
246  return errorRate;
247 }
248 
249 
static Vc_ALWAYS_INLINE int_v min(const int_v &x, const int_v &y)
Definition: vector.h:433
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
Double_t GetNodePurityLimit() const
Definition: DecisionTree.h:170
virtual DecisionTreeNode * GetRight() const
Int_t GetNodeType(void) const
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
virtual DecisionTreeNode * GetLeft() const
TFile * f
virtual DecisionTreeNode * GetRoot() const
Definition: DecisionTree.h:102
Float_t GetPurity(void) const
std::vector< const Event * > EventSample
Definition: IPruneTool.h:76
virtual PruningInfo * CalculatePruningInfo(DecisionTree *dt, const IPruneTool::EventSample *testEvents=NULL, Bool_t isAutomatic=kFALSE)
ROOT::R::TRInterface & r
Definition: Object.C:4
void FindListOfNodes(DecisionTreeNode *node)
recursive pruning of nodes using the Expected Error Pruning (EEP)
Float_t GetNEvents(void) const
TLine * l
Definition: textangle.C:4
Bool_t IsTerminal() const
double Double_t
Definition: RtypesCore.h:55
Double_t GetNodeError(DecisionTreeNode *node) const
Calculate an UPPER limit on the error made by the classification done by this node.
#define NULL
Definition: Rtypes.h:82
Double_t Sqrt(Double_t x)
Definition: TMath.h:464
Double_t GetSubTreeError(DecisionTreeNode *node) const
calculate the expected statistical error on the subtree below "node" which is used in the expected er...
Definition: math.cpp:60