MethodRuleFit.cxx (ROOT 6.07/09 Reference Guide)
// @(#)root/tmva $Id$
// Author: Fredrik Tegenfeldt

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis      *
 * Package: TMVA                                                                  *
 * Class  : MethodRuleFit                                                         *
 * Web    : http://tmva.sourceforge.net                                          *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation (see header file for description)                         *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch>  - Iowa State U., USA     *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         *
 *      Iowa State U.                                                             *
 *      MPI-K Heidelberg, Germany                                                 *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without            *
 * modification, are permitted according to the terms listed in LICENSE          *
 * (http://tmva.sourceforge.net/LICENSE)                                         *
 **********************************************************************************/

//_______________________________________________________________________
//
// J Friedman's RuleFit method
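//
// In brief, the method builds a scoring function of the form
//
//    F(x) = a0 + sum_m a_m r_m(x) + sum_i b_i x_i
//
// where each rule r_m(x) is a conjunction of cuts harvested from a forest
// of decision trees, and the optional linear terms b_i x_i complement the
// rules (see GetHelpMessage() below for details).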
//_______________________________________________________________________

#include "TMVA/MethodRuleFit.h"

#include "TMVA/ClassifierFactory.h"
#include "TMVA/Config.h"
#include "TMVA/Configurable.h"
#include "TMVA/CrossEntropy.h"
#include "TMVA/DataSet.h"
#include "TMVA/DecisionTree.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/IMethod.h"
#include "TMVA/MethodBase.h"
#include "TMVA/MisClassificationError.h"
#include "TMVA/MsgLogger.h"
#include "TMVA/Ranking.h"
#include "TMVA/RuleFitAPI.h"
#include "TMVA/SdivSqrtSplusB.h"
#include "TMVA/SeparationBase.h"
#include "TMVA/Timer.h"
#include "TMVA/Tools.h"
#include "TMVA/Types.h"

#include "Riostream.h"
#include "TRandom3.h"
#include "TMath.h"
#include "TMatrix.h"
#include "TDirectory.h"

#include <algorithm>
#include <list>

using std::min;

REGISTER_METHOD(RuleFit)

ClassImp(TMVA::MethodRuleFit)

////////////////////////////////////////////////////////////////////////////////
/// standard constructor

TMVA::MethodRuleFit::MethodRuleFit( const TString& jobName,
                                    const TString& methodTitle,
                                    DataSetInfo& theData,
                                    const TString& theOption) :
   MethodBase( jobName, Types::kRuleFit, methodTitle, theData, theOption)
   , fSignalFraction(0)
   , fNTImportance(0)
   , fNTCoefficient(0)
   , fNTSupport(0)
   , fNTNcuts(0)
   , fNTNvars(0)
   , fNTPtag(0)
   , fNTPss(0)
   , fNTPsb(0)
   , fNTPbs(0)
   , fNTPbb(0)
   , fNTSSB(0)
   , fNTType(0)
   , fUseRuleFitJF(kFALSE)
   , fRFNrules(0)
   , fRFNendnodes(0)
   , fNTrees(0)
   , fTreeEveFrac(0)
   , fSepType(0)
   , fMinFracNEve(0)
   , fMaxFracNEve(0)
   , fNCuts(0)
   , fPruneMethod(TMVA::DecisionTree::kCostComplexityPruning)
   , fPruneStrength(0)
   , fUseBoost(kFALSE)
   , fGDPathEveFrac(0)
   , fGDValidEveFrac(0)
   , fGDTau(0)
   , fGDTauPrec(0)
   , fGDTauMin(0)
   , fGDTauMax(0)
   , fGDTauScan(0)
   , fGDPathStep(0)
   , fGDNPathSteps(0)
   , fGDErrScale(0)
   , fMinimp(0)
   , fRuleMinDist(0)
   , fLinQuantile(0)
{
}

////////////////////////////////////////////////////////////////////////////////
/// constructor from weight file

TMVA::MethodRuleFit::MethodRuleFit( DataSetInfo& theData,
                                    const TString& theWeightFile) :
   MethodBase( Types::kRuleFit, theData, theWeightFile)
   , fSignalFraction(0)
   , fNTImportance(0)
   , fNTCoefficient(0)
   , fNTSupport(0)
   , fNTNcuts(0)
   , fNTNvars(0)
   , fNTPtag(0)
   , fNTPss(0)
   , fNTPsb(0)
   , fNTPbs(0)
   , fNTPbb(0)
   , fNTSSB(0)
   , fNTType(0)
   , fUseRuleFitJF(kFALSE)
   , fRFNrules(0)
   , fRFNendnodes(0)
   , fNTrees(0)
   , fTreeEveFrac(0)
   , fSepType(0)
   , fMinFracNEve(0)
   , fMaxFracNEve(0)
   , fNCuts(0)
   , fPruneMethod(TMVA::DecisionTree::kCostComplexityPruning)
   , fPruneStrength(0)
   , fUseBoost(kFALSE)
   , fGDPathEveFrac(0)
   , fGDValidEveFrac(0)
   , fGDTau(0)
   , fGDTauPrec(0)
   , fGDTauMin(0)
   , fGDTauMax(0)
   , fGDTauScan(0)
   , fGDPathStep(0)
   , fGDNPathSteps(0)
   , fGDErrScale(0)
   , fMinimp(0)
   , fRuleMinDist(0)
   , fLinQuantile(0)
{
}

////////////////////////////////////////////////////////////////////////////////
/// destructor

TMVA::MethodRuleFit::~MethodRuleFit( void )
{
   for (UInt_t i=0; i<fEventSample.size(); i++) delete fEventSample[i];
   for (UInt_t i=0; i<fForest.size();      i++) delete fForest[i];
}

////////////////////////////////////////////////////////////////////////////////
/// RuleFit can handle classification with 2 classes

Bool_t TMVA::MethodRuleFit::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/ )
{
   if (type == Types::kClassification && numberClasses == 2) return kTRUE;
   return kFALSE;
}

////////////////////////////////////////////////////////////////////////////////
/// define the options (their key words) that can be set in the option string.
/// known options:
///
/// ---------
/// general
/// ---------
/// RuleFitModule  <string>
///      available values are: RFTMVA     - use TMVA implementation
///                            RFFriedman - use Friedman's original implementation
///
/// ----------------------
/// Path search (fitting)
/// ----------------------
/// GDTau        <float>   gradient-directed path: fit threshold, default
/// GDTauPrec    <float>   gradient-directed path: precision of estimated tau
/// GDStep       <float>   gradient-directed path: step size
/// GDNSteps     <float>   gradient-directed path: number of steps
/// GDErrScale   <float>   stop scan when error > scale*errmin
///
/// -----------------
/// Tree generation
/// -----------------
/// fEventsMin   <float>   minimum fraction of events in a splittable node
/// fEventsMax   <float>   maximum fraction of events in a splittable node
/// nTrees       <float>   number of trees in forest
/// ForestType   <string>
///      available values are: Random   - create forest using random subsamples
///                                       and only a random subset of variables at each node
///                            AdaBoost - create forest with boosted events
///
/// -----------------
/// Model creation
/// -----------------
/// RuleMinDist  <float>   minimum distance allowed between rules
/// MinImp       <float>   minimum rule importance accepted
/// Model        <string>  model to be used
///      available values are: ModRuleLinear <default>
///                            ModRule
///                            ModLinear
///
/// -----------------
/// Friedman's module
/// -----------------
/// RFWorkDir    <string>  directory where Friedman's module (rf_go.exe) is installed
/// RFNrules     <int>     maximum number of rules allowed
/// RFNendnodes  <int>     average number of end nodes in the forest of trees
///
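/// Example booking (illustrative only; `factory` is assumed to be a
/// TMVA::Factory instance):
///
///     factory->BookMethod( TMVA::Types::kRuleFit, "RuleFit",
///         "Model=ModRuleLinear:GDTau=-1:GDTauPrec=0.01:GDStep=0.01:GDNSteps=10000:nTrees=20" );
///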

void TMVA::MethodRuleFit::DeclareOptions()
{
   DeclareOptionRef(fGDTau=-1,           "GDTau",          "Gradient-directed (GD) path: default fit cut-off");
   DeclareOptionRef(fGDTauPrec=0.01,     "GDTauPrec",      "GD path: precision of tau");
   DeclareOptionRef(fGDPathStep=0.01,    "GDStep",         "GD path: step size");
   DeclareOptionRef(fGDNPathSteps=10000, "GDNSteps",       "GD path: number of steps");
   DeclareOptionRef(fGDErrScale=1.1,     "GDErrScale",     "Stop scan when error > scale*errmin");
   DeclareOptionRef(fLinQuantile,        "LinQuantile",    "Quantile of linear terms (removes outliers)");
   DeclareOptionRef(fGDPathEveFrac=0.5,  "GDPathEveFrac",  "Fraction of events used for the path search");
   DeclareOptionRef(fGDValidEveFrac=0.5, "GDValidEveFrac", "Fraction of events used for the validation");
   // tree options
   DeclareOptionRef(fMinFracNEve=0.1,    "fEventsMin",     "Minimum fraction of events in a splittable node");
   DeclareOptionRef(fMaxFracNEve=0.9,    "fEventsMax",     "Maximum fraction of events in a splittable node");
   DeclareOptionRef(fNTrees=20,          "nTrees",         "Number of trees in forest.");

   DeclareOptionRef(fForestTypeS="AdaBoost", "ForestType", "Method to use for forest generation (AdaBoost or RandomForest)");
   AddPreDefVal(TString("AdaBoost"));
   AddPreDefVal(TString("Random"));
   // rule cleanup options
   DeclareOptionRef(fRuleMinDist=0.001,  "RuleMinDist",    "Minimum distance between rules");
   DeclareOptionRef(fMinimp=0.01,        "MinImp",         "Minimum rule importance accepted");
   // rule model option
   DeclareOptionRef(fModelTypeS="ModRuleLinear", "Model",  "Model to be used");
   AddPreDefVal(TString("ModRule"));
   AddPreDefVal(TString("ModRuleLinear"));
   AddPreDefVal(TString("ModLinear"));
   DeclareOptionRef(fRuleFitModuleS="RFTMVA", "RuleFitModule", "Which RuleFit module to use");
   AddPreDefVal(TString("RFTMVA"));
   AddPreDefVal(TString("RFFriedman"));

   DeclareOptionRef(fRFWorkDir="./rulefit", "RFWorkDir",   "Friedman's RuleFit module (RFF): working dir");
   DeclareOptionRef(fRFNrules=2000,      "RFNrules",       "RFF: Maximum number of rules");
   DeclareOptionRef(fRFNendnodes=4,      "RFNendnodes",    "RFF: Average number of end nodes");
}

////////////////////////////////////////////////////////////////////////////////
/// process the options specified by the user

void TMVA::MethodRuleFit::ProcessOptions()
{
   if (IgnoreEventsWithNegWeightsInTraining()) {
      Log() << kFATAL << "Mechanism to ignore events with negative weights in training not yet available for method: "
            << GetMethodTypeName()
            << " --> please remove \"IgnoreNegWeightsInTraining\" option from booking string."
            << Endl;
   }

   fRuleFitModuleS.ToLower();
   if      (fRuleFitModuleS == "rftmva")     fUseRuleFitJF = kFALSE;
   else if (fRuleFitModuleS == "rffriedman") fUseRuleFitJF = kTRUE;
   else                                      fUseRuleFitJF = kTRUE;

   fSepTypeS.ToLower();
   if      (fSepTypeS == "misclassificationerror") fSepType = new MisClassificationError();
   else if (fSepTypeS == "giniindex")              fSepType = new GiniIndex();
   else if (fSepTypeS == "crossentropy")           fSepType = new CrossEntropy();
   else                                            fSepType = new SdivSqrtSplusB();

   fModelTypeS.ToLower();
   if      (fModelTypeS == "modlinear") fRuleFit.SetModelLinear();
   else if (fModelTypeS == "modrule")   fRuleFit.SetModelRules();
   else                                 fRuleFit.SetModelFull();

   fPruneMethodS.ToLower();
   if      (fPruneMethodS == "expectederror")  fPruneMethod = DecisionTree::kExpectedErrorPruning;
   else if (fPruneMethodS == "costcomplexity") fPruneMethod = DecisionTree::kCostComplexityPruning;
   else                                        fPruneMethod = DecisionTree::kNoPruning;

   fForestTypeS.ToLower();
   if      (fForestTypeS == "random")   fUseBoost = kFALSE;
   else if (fForestTypeS == "adaboost") fUseBoost = kTRUE;
   else                                 fUseBoost = kTRUE;
   //
   // if creating the forest by boosting the events,
   // the full training sample is used per tree
   // -> only true for the TMVA version of RuleFit.
   if (fUseBoost && (!fUseRuleFitJF)) fTreeEveFrac = 1.0;

   // check event fraction for tree generation
   // if <=0, set to an automatic value
   if (fTreeEveFrac<=0) {
      Int_t nevents = Data()->GetNTrainingEvents();
      Double_t n = static_cast<Double_t>(nevents);
      fTreeEveFrac = min( 0.5, (100.0 + 6.0*sqrt(n))/n );
   }
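   // e.g. for n = 10000 training events: min(0.5, (100 + 6*sqrt(10000))/10000)
   //      = min(0.5, 700/10000) = 0.07, i.e. each tree sees 7% of the sample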
   // verify ranges of options
   VerifyRange(Log(), "nTrees",        fNTrees,0,100000,20);
   VerifyRange(Log(), "MinImp",        fMinimp,0.0,1.0,0.0);
   VerifyRange(Log(), "GDTauPrec",     fGDTauPrec,1e-5,5e-1);
   VerifyRange(Log(), "GDTauMin",      fGDTauMin,0.0,1.0);
   VerifyRange(Log(), "GDTauMax",      fGDTauMax,fGDTauMin,1.0);
   VerifyRange(Log(), "GDPathStep",    fGDPathStep,0.0,100.0,0.01);
   VerifyRange(Log(), "GDErrScale",    fGDErrScale,1.0,100.0,1.1);
   VerifyRange(Log(), "GDPathEveFrac", fGDPathEveFrac,0.01,0.9,0.5);
   VerifyRange(Log(), "GDValidEveFrac",fGDValidEveFrac,0.01,1.0-fGDPathEveFrac,1.0-fGDPathEveFrac);
   VerifyRange(Log(), "fEventsMin",    fMinFracNEve,0.0,1.0);
   VerifyRange(Log(), "fEventsMax",    fMaxFracNEve,fMinFracNEve,1.0);

   fRuleFit.GetRuleEnsemblePtr()->SetLinQuantile(fLinQuantile);
   fRuleFit.GetRuleFitParamsPtr()->SetGDTauRange(fGDTauMin,fGDTauMax);
   fRuleFit.GetRuleFitParamsPtr()->SetGDTau(fGDTau);
   fRuleFit.GetRuleFitParamsPtr()->SetGDTauPrec(fGDTauPrec);
   fRuleFit.GetRuleFitParamsPtr()->SetGDTauScan(fGDTauScan);
   fRuleFit.GetRuleFitParamsPtr()->SetGDPathStep(fGDPathStep);
   fRuleFit.GetRuleFitParamsPtr()->SetGDNPathSteps(fGDNPathSteps);
   fRuleFit.GetRuleFitParamsPtr()->SetGDErrScale(fGDErrScale);
   fRuleFit.SetImportanceCut(fMinimp);
   fRuleFit.SetRuleMinDist(fRuleMinDist);

   // check if Friedman's module is used.
   // print a message concerning the options.
   if (fUseRuleFitJF) {
      Log() << kINFO << "" << Endl;
      Log() << kINFO << "--------------------------------------" << Endl;
      Log() << kINFO << "Friedman's RuleFit module is selected." << Endl;
      Log() << kINFO << "Only the following options are used:" << Endl;
      Log() << kINFO << Endl;
      Log() << kINFO << gTools().Color("bold") << "   Model"        << gTools().Color("reset") << Endl;
      Log() << kINFO << gTools().Color("bold") << "   RFWorkDir"    << gTools().Color("reset") << Endl;
      Log() << kINFO << gTools().Color("bold") << "   RFNrules"     << gTools().Color("reset") << Endl;
      Log() << kINFO << gTools().Color("bold") << "   RFNendnodes"  << gTools().Color("reset") << Endl;
      Log() << kINFO << gTools().Color("bold") << "   GDNPathSteps" << gTools().Color("reset") << Endl;
      Log() << kINFO << gTools().Color("bold") << "   GDPathStep"   << gTools().Color("reset") << Endl;
      Log() << kINFO << gTools().Color("bold") << "   GDErrScale"   << gTools().Color("reset") << Endl;
      Log() << kINFO << "--------------------------------------" << Endl;
      Log() << kINFO << Endl;
   }

   // Select what weight to use in the 'importance' rule visualisation plots.
   // Note that if UseCoefficientsVisHists() is selected, the following weight is used:
   //    w = rule coefficient * rule support
   // The support is a positive number which is 0 if no events are accepted by the rule.
   // Normally the importance gives more useful information.
   //
   //fRuleFit.UseCoefficientsVisHists();
   fRuleFit.UseImportanceVisHists();

   fRuleFit.SetMsgType( Log().GetMinType() );

   if (HasTrainingTree()) InitEventSample();

}

////////////////////////////////////////////////////////////////////////////////
/// initialize the monitoring ntuple

void TMVA::MethodRuleFit::InitMonitorNtuple()
{
   BaseDir()->cd();
   fMonitorNtuple= new TTree("MonitorNtuple_RuleFit","RuleFit variables");
   fMonitorNtuple->Branch("importance",&fNTImportance,"importance/D");
   fMonitorNtuple->Branch("support",&fNTSupport,"support/D");
   fMonitorNtuple->Branch("coefficient",&fNTCoefficient,"coefficient/D");
   fMonitorNtuple->Branch("ncuts",&fNTNcuts,"ncuts/I");
   fMonitorNtuple->Branch("nvars",&fNTNvars,"nvars/I");
   fMonitorNtuple->Branch("type",&fNTType,"type/I");
   fMonitorNtuple->Branch("ptag",&fNTPtag,"ptag/D");
   fMonitorNtuple->Branch("pss",&fNTPss,"pss/D");
   fMonitorNtuple->Branch("psb",&fNTPsb,"psb/D");
   fMonitorNtuple->Branch("pbs",&fNTPbs,"pbs/D");
   fMonitorNtuple->Branch("pbb",&fNTPbb,"pbb/D");
   fMonitorNtuple->Branch("soversb",&fNTSSB,"soversb/D");
}

////////////////////////////////////////////////////////////////////////////////
/// default initialization

void TMVA::MethodRuleFit::Init( void )
{
   // the minimum requirement to declare an event signal-like
   SetSignalReferenceCut( 0.0 );

   // set variables that used to be options
   // any modifications are then made in ProcessOptions()
   fLinQuantile   = 0.025;       // Quantile of linear terms (remove outliers)
   fTreeEveFrac   = -1.0;        // Fraction of events used to train each tree
   fNCuts         = 20;          // Number of steps during node cut optimisation
   fSepTypeS      = "GiniIndex"; // Separation criterion for node splitting; see BDT
   fPruneMethodS  = "NONE";      // Pruning method; see BDT
   fPruneStrength = 3.5;         // Pruning strength; see BDT
   fGDTauMin      = 0.0;         // Gradient-directed path: min fit threshold (tau)
   fGDTauMax      = 1.0;         // Gradient-directed path: max fit threshold (tau)
   fGDTauScan     = 1000;        // Gradient-directed path: number of points scanning for best tau
}

////////////////////////////////////////////////////////////////////////////////
/// write all Events from the Tree into a vector of Events, that are
/// more easily manipulated.
/// This method should never be called without an existing trainingTree, as it
/// reads the vector of events from the ROOT training tree.

void TMVA::MethodRuleFit::InitEventSample( void )
{
   if (Data()->GetNEvents()==0) Log() << kFATAL << "<Init> Data().TrainingTree() is zero pointer" << Endl;

   Int_t nevents = Data()->GetNEvents();
   for (Int_t ievt=0; ievt<nevents; ievt++){
      const Event * ev = GetEvent(ievt);
      fEventSample.push_back( new Event(*ev));
   }
   if (fTreeEveFrac<=0) {
      Double_t n = static_cast<Double_t>(nevents);
      fTreeEveFrac = min( 0.5, (100.0 + 6.0*sqrt(n))/n );
   }
   if (fTreeEveFrac>1.0) fTreeEveFrac=1.0;
   //
   std::random_shuffle(fEventSample.begin(), fEventSample.end());
   //
   Log() << kDEBUG << "Set sub-sample fraction to " << fTreeEveFrac << Endl;
}

////////////////////////////////////////////////////////////////////////////////

void TMVA::MethodRuleFit::Train( void )
{
   // training of rules

   InitMonitorNtuple();

   // fill the STL Vector with the event sample
   this->InitEventSample();

   if (fUseRuleFitJF) {
      TrainJFRuleFit();
   }
   else {
      TrainTMVARuleFit();
   }
   ExitFromTraining();
}

////////////////////////////////////////////////////////////////////////////////
/// training of rules using TMVA implementation

void TMVA::MethodRuleFit::TrainTMVARuleFit( void )
{
   if (IsNormalised()) Log() << kFATAL << "\"Normalise\" option cannot be used with RuleFit; "
                             << "please remove the option from the configuration string, or "
                             << "use \"!Normalise\""
                             << Endl;

   // timer
   Timer timer( 1, GetName() );

   // test tree nmin cut -> for debug purposes
   // the routine will generate trees with stopping cut on N(eve) given by
   // a fraction between [20,N(eve)-1].
   //
   // MakeForestRnd();
   // exit(1);
   //

   // Init RuleFit object and create rule ensemble
   // + make forest & rules
   fRuleFit.Initialize( this );

   // Make forest of decision trees
   // if (fRuleFit.GetRuleEnsemble().DoRules()) fRuleFit.MakeForest();

   // Fit the rules
   Log() << kDEBUG << "Fitting rule coefficients ..." << Endl;
   fRuleFit.FitCoefficients();

   // Calculate importance
   Log() << kDEBUG << "Computing rule and variable importance" << Endl;
   fRuleFit.CalcImportance();

   // Output results and fill monitor ntuple
   fRuleFit.GetRuleEnsemblePtr()->Print();
   //
   if(!IsSilentFile())
   {
      Log() << kDEBUG << "Filling rule ntuple" << Endl;
      UInt_t nrules = fRuleFit.GetRuleEnsemble().GetRulesConst().size();
      const Rule *rule;
      for (UInt_t i=0; i<nrules; i++ ) {
         rule            = fRuleFit.GetRuleEnsemble().GetRulesConst()[i];
         fNTImportance   = rule->GetRelImportance();
         fNTSupport      = rule->GetSupport();
         fNTCoefficient  = rule->GetCoefficient();
         fNTType         = (rule->IsSignalRule() ? 1:-1 );
         fNTNvars        = rule->GetRuleCut()->GetNvars();
         fNTNcuts        = rule->GetRuleCut()->GetNcuts();
         fNTPtag         = fRuleFit.GetRuleEnsemble().GetRulePTag(i); // should be identical with support
         fNTPss          = fRuleFit.GetRuleEnsemble().GetRulePSS(i);
         fNTPsb          = fRuleFit.GetRuleEnsemble().GetRulePSB(i);
         fNTPbs          = fRuleFit.GetRuleEnsemble().GetRulePBS(i);
         fNTPbb          = fRuleFit.GetRuleEnsemble().GetRulePBB(i);
         fNTSSB          = rule->GetSSB();
         fMonitorNtuple->Fill();
      }

      fRuleFit.MakeVisHists();
      fRuleFit.MakeDebugHists();
   }
   Log() << kDEBUG << "Training done" << Endl;
}

////////////////////////////////////////////////////////////////////////////////
/// training of rules using Jerome Friedman's implementation

void TMVA::MethodRuleFit::TrainJFRuleFit( void )
{
   fRuleFit.InitPtrs( this );
   Data()->SetCurrentType(Types::kTraining);
   UInt_t nevents = Data()->GetNTrainingEvents();
   std::vector<const TMVA::Event*> tmp;
   for (Long64_t ievt=0; ievt<nevents; ievt++) {
      const Event *event = GetEvent(ievt);
      tmp.push_back(event);
   }
   fRuleFit.SetTrainingEvents( tmp );

   RuleFitAPI *rfAPI = new RuleFitAPI( this, &fRuleFit, Log().GetMinType() );

   rfAPI->WelcomeMessage();

   // timer
   Timer timer( 1, GetName() );

   Log() << kINFO << "Training ..." << Endl;
   rfAPI->TrainRuleFit();

   Log() << kDEBUG << "reading model summary from rf_go.exe output" << Endl;
   rfAPI->ReadModelSum();

   // fRuleFit.GetRuleEnsemblePtr()->MakeRuleMap();

   Log() << kDEBUG << "calculating rule and variable importance" << Endl;
   fRuleFit.CalcImportance();

   // Output results and fill monitor ntuple
   fRuleFit.GetRuleEnsemblePtr()->Print();
   //
   if (!IsSilentFile()) fRuleFit.MakeVisHists();

   delete rfAPI;

   Log() << kDEBUG << "done training" << Endl;
}

////////////////////////////////////////////////////////////////////////////////
/// computes ranking of input variables

const TMVA::Ranking* TMVA::MethodRuleFit::CreateRanking()
{
   // create the ranking object
   fRanking = new Ranking( GetName(), "Importance" );

   for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
      fRanking->AddRank( Rank( GetInputLabel(ivar), fRuleFit.GetRuleEnsemble().GetVarImportance(ivar) ) );
   }

   return fRanking;
}

////////////////////////////////////////////////////////////////////////////////
/// add the rules to XML node

void TMVA::MethodRuleFit::AddWeightsXMLTo( void* parent ) const
{
   fRuleFit.GetRuleEnsemble().AddXMLTo( parent );
}

////////////////////////////////////////////////////////////////////////////////
/// read rules from an std::istream

void TMVA::MethodRuleFit::ReadWeightsFromStream( std::istream & istr )
{
   fRuleFit.GetRuleEnsemblePtr()->ReadRaw( istr );
}

////////////////////////////////////////////////////////////////////////////////
/// read rules from XML node

void TMVA::MethodRuleFit::ReadWeightsFromXML( void* wghtnode )
{
   fRuleFit.GetRuleEnsemblePtr()->ReadFromXML( wghtnode );
}

////////////////////////////////////////////////////////////////////////////////
/// returns MVA value for given event

Double_t TMVA::MethodRuleFit::GetMvaValue( Double_t* err, Double_t* errUpper )
{
   // cannot determine error
   NoErrorCalc(err, errUpper);

   return fRuleFit.EvalEvent( *GetEvent() );
}

////////////////////////////////////////////////////////////////////////////////
/// write special monitoring histograms to file (here ntuple)

void TMVA::MethodRuleFit::WriteMonitoringHistosToFile( void ) const
{
   BaseDir()->cd();
   Log() << kINFO << "Write monitoring ntuple to file: " << BaseDir()->GetPath() << Endl;
   fMonitorNtuple->Write();
}

////////////////////////////////////////////////////////////////////////////////
/// write specific classifier response

void TMVA::MethodRuleFit::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
   Int_t dp = fout.precision();
   fout << "   // not implemented for class: \"" << className << "\"" << std::endl;
   fout << "};" << std::endl;
   fout << "void   " << className << "::Initialize(){}" << std::endl;
   fout << "void   " << className << "::Clear(){}" << std::endl;
   fout << "double " << className << "::GetMvaValue__( const std::vector<double>& inputValues ) const {" << std::endl;
   fout << "   double rval=" << std::setprecision(10) << fRuleFit.GetRuleEnsemble().GetOffset() << ";" << std::endl;
   MakeClassRuleCuts(fout);
   MakeClassLinear(fout);
   fout << "   return rval;" << std::endl;
   fout << "}" << std::endl;
   fout << std::setprecision(dp);
}

////////////////////////////////////////////////////////////////////////////////
/// print out the rule cuts

void TMVA::MethodRuleFit::MakeClassRuleCuts( std::ostream& fout ) const
{
   Int_t dp = fout.precision();
   if (!fRuleFit.GetRuleEnsemble().DoRules()) {
      fout << "   //" << std::endl;
      fout << "   // ==> MODEL CONTAINS NO RULES <==" << std::endl;
      fout << "   //" << std::endl;
      return;
   }
   const RuleEnsemble *rens = &(fRuleFit.GetRuleEnsemble());
   const std::vector< Rule* > *rules = &(rens->GetRulesConst());
   const RuleCut *ruleCut;
   //
   std::list< std::pair<Double_t,Int_t> > sortedRules;
   for (UInt_t ir=0; ir<rules->size(); ir++) {
      sortedRules.push_back( std::pair<Double_t,Int_t>( (*rules)[ir]->GetImportance()/rens->GetImportanceRef(), ir ) );
   }
   sortedRules.sort();
   //
   fout << "   //" << std::endl;
   fout << "   // here follow all rules ordered by importance (most important first)" << std::endl;
   fout << "   // at the end of each line, the relative importance of the rule is given" << std::endl;
   fout << "   //" << std::endl;
   //
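   // Each rule below is emitted as a guarded increment of rval; an
   // illustrative (made-up) output line would be:
   //    if ((1.2<inputValues[0])&&(inputValues[0]<3.4)) rval+=0.5678; // importance = 0.321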
   for ( std::list< std::pair<double,int> >::reverse_iterator itpair = sortedRules.rbegin();
         itpair != sortedRules.rend(); itpair++ ) {
      UInt_t ir = itpair->second;
      Double_t impr = itpair->first;
      ruleCut = (*rules)[ir]->GetRuleCut();
      if (impr<rens->GetImportanceCut()) fout << "   //" << std::endl;
      fout << "   if (" << std::flush;
      for (UInt_t ic=0; ic<ruleCut->GetNvars(); ic++) {
         UInt_t   sel    = ruleCut->GetSelector(ic);
         Double_t valmin = ruleCut->GetCutMin(ic);
         Double_t valmax = ruleCut->GetCutMax(ic);
         Bool_t   domin  = ruleCut->GetCutDoMin(ic);
         Bool_t   domax  = ruleCut->GetCutDoMax(ic);
         //
         if (ic>0) fout << "&&" << std::flush;
         if (domin) {
            fout << "(" << std::setprecision(10) << valmin << std::flush;
            fout << "<inputValues[" << sel << "])" << std::flush;
         }
         if (domax) {
            if (domin) fout << "&&" << std::flush;
            fout << "(inputValues[" << sel << "]" << std::flush;
            fout << "<" << std::setprecision(10) << valmax << ")" << std::flush;
         }
      }
      fout << ") rval+=" << std::setprecision(10) << (*rules)[ir]->GetCoefficient() << ";" << std::flush;
      fout << " // importance = " << Form("%3.3f",impr) << std::endl;
   }
   fout << std::setprecision(dp);
}

////////////////////////////////////////////////////////////////////////////////
/// print out the linear terms

void TMVA::MethodRuleFit::MakeClassLinear( std::ostream& fout ) const
{
   if (!fRuleFit.GetRuleEnsemble().DoLinear()) {
      fout << "   //" << std::endl;
      fout << "   // ==> MODEL CONTAINS NO LINEAR TERMS <==" << std::endl;
      fout << "   //" << std::endl;
      return;
   }
   fout << "   //" << std::endl;
   fout << "   // here follow all linear terms" << std::endl;
   fout << "   // at the end of each line, the relative importance of the term is given" << std::endl;
   fout << "   //" << std::endl;
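   // Each linear term is emitted with the input clipped to the [DM,DP] range
   // observed in training; an illustrative (made-up) output line would be:
   //    rval+=0.123*std::min( double(4.5), std::max( double(inputValues[1]), double(-2.3)));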
   const RuleEnsemble *rens = &(fRuleFit.GetRuleEnsemble());
   UInt_t nlin = rens->GetNLinear();
   for (UInt_t il=0; il<nlin; il++) {
      if (rens->IsLinTermOK(il)) {
         Double_t norm = rens->GetLinNorm(il);
         Double_t imp  = rens->GetLinImportance(il)/rens->GetImportanceRef();
         fout << "   rval+="
            // << std::setprecision(10) << rens->GetLinCoefficients(il)*norm << "*std::min(" << setprecision(10) << rens->GetLinDP(il)
            // << ", std::max( inputValues[" << il << "]," << std::setprecision(10) << rens->GetLinDM(il) << "));"
              << std::setprecision(10) << rens->GetLinCoefficients(il)*norm
              << "*std::min( double(" << std::setprecision(10) << rens->GetLinDP(il)
              << "), std::max( double(inputValues[" << il << "]), double(" << std::setprecision(10) << rens->GetLinDM(il) << ")));"
              << std::flush;
         fout << " // importance = " << Form("%3.3f",imp) << std::endl;
      }
   }
}

////////////////////////////////////////////////////////////////////////////////
/// get help message text
///
/// typical length of text line:
///         "|--------------------------------------------------------------|"

void TMVA::MethodRuleFit::GetHelpMessage() const
{
   TString col    = gConfig().WriteOptionsReference() ? TString() : gTools().Color("bold");
   TString colres = gConfig().WriteOptionsReference() ? TString() : gTools().Color("reset");
   TString brk    = gConfig().WriteOptionsReference() ? "<br>" : "";

   Log() << Endl;
   Log() << col << "--- Short description:" << colres << Endl;
   Log() << Endl;
   Log() << "This method uses a collection of so called rules to create a" << Endl;
   Log() << "discriminating scoring function. Each rule consists of a series" << Endl;
   Log() << "of cuts in parameter space. The ensemble of rules is created" << Endl;
   Log() << "from a forest of decision trees, trained using the training data." << Endl;
   Log() << "Each node (apart from the root) corresponds to one rule." << Endl;
   Log() << "The scoring function is then obtained by linearly combining" << Endl;
   Log() << "the rules. A fitting procedure is applied to find the optimum" << Endl;
   Log() << "set of coefficients. The goal is to find a model with few rules" << Endl;
   Log() << "but with a strong discriminating power." << Endl;
   Log() << Endl;
   Log() << col << "--- Performance optimisation:" << colres << Endl;
   Log() << Endl;
   Log() << "There are two important considerations to make when optimising:" << Endl;
   Log() << Endl;
   Log() << "  1. Topology of the decision tree forest" << brk << Endl;
   Log() << "  2. Fitting of the coefficients" << Endl;
   Log() << Endl;
   Log() << "The maximum complexity of the rules is defined by the size of" << Endl;
   Log() << "the trees. Large trees will yield many complex rules and capture" << Endl;
   Log() << "higher order correlations. On the other hand, small trees will" << Endl;
   Log() << "lead to a smaller ensemble with simple rules, only capable of" << Endl;
   Log() << "modeling simple structures." << Endl;
   Log() << "Several parameters exist for controlling the complexity of the" << Endl;
   Log() << "rule ensemble." << Endl;
   Log() << Endl;
   Log() << "The fitting procedure searches for a minimum using a gradient" << Endl;
   Log() << "directed path. Apart from step size and number of steps, the" << Endl;
   Log() << "evolution of the path is defined by a cut-off parameter, tau." << Endl;
   Log() << "This parameter is unknown and depends on the training data." << Endl;
   Log() << "A large value will tend to give large weights to a few rules." << Endl;
   Log() << "Similarly, a small value will lead to a large set of rules" << Endl;
   Log() << "with similar weights." << Endl;
   Log() << Endl;
   Log() << "A final point is the model used; rules and/or linear terms." << Endl;
   Log() << "For a given training sample, the result may improve by adding" << Endl;
   Log() << "linear terms. If best performance is obtained using only linear" << Endl;
   Log() << "terms, it is very likely that the Fisher discriminant would be" << Endl;
   Log() << "a better choice. Ideally the fitting procedure should be able to" << Endl;
   Log() << "make this choice by giving appropriate weights to either term." << Endl;
   Log() << Endl;
   Log() << col << "--- Performance tuning via configuration options:" << colres << Endl;
   Log() << Endl;
   Log() << "I.  TUNING OF RULE ENSEMBLE:" << Endl;
   Log() << Endl;
   Log() << "   " << col << "ForestType  " << colres
         << ": Recommended is to use the default \"AdaBoost\"." << brk << Endl;
   Log() << "   " << col << "nTrees      " << colres
         << ": More trees lead to more rules but also slower" << Endl;
   Log() << "                 performance. With too few trees the risk is" << Endl;
   Log() << "                 that the rule ensemble becomes too simple." << brk << Endl;
   Log() << "   " << col << "fEventsMin  " << colres << brk << Endl;
   Log() << "   " << col << "fEventsMax  " << colres
         << ": With a lower min, more large trees will be generated" << Endl;
   Log() << "                 leading to more complex rules." << Endl;
   Log() << "                 With a higher max, more small trees will be" << Endl;
   Log() << "                 generated leading to more simple rules." << Endl;
   Log() << "                 By changing this range, the average complexity" << Endl;
   Log() << "                 of the rule ensemble can be controlled." << brk << Endl;
   Log() << "   " << col << "RuleMinDist " << colres
         << ": By increasing the minimum distance between" << Endl;
   Log() << "                 rules, fewer and more diverse rules will remain." << Endl;
   Log() << "                 Initially it is a good idea to keep this small" << Endl;
   Log() << "                 or zero and let the fitting do the selection of" << Endl;
   Log() << "                 rules. In order to reduce the ensemble size," << Endl;
   Log() << "                 the value can then be increased." << Endl;
   Log() << Endl;
   //         "|--------------------------------------------------------------|"
   Log() << "II. TUNING OF THE FITTING:" << Endl;
   Log() << Endl;
   Log() << "   " << col << "GDPathEveFrac " << colres
         << ": fraction of events in path evaluation" << Endl;
   Log() << "                 Increasing this fraction will improve the path" << Endl;
   Log() << "                 finding. However, a too high value will give few" << Endl;
   Log() << "                 unique events available for error estimation." << Endl;
   Log() << "                 It is recommended to use the default = 0.5." << brk << Endl;
   Log() << "   " << col << "GDTau         " << colres
         << ": cut-off parameter tau" << Endl;
   Log() << "                 By default this value is set to -1.0." << Endl;
   //         "|----------------|---------------------------------------------|"
   Log() << "                 This means that the cut-off parameter is" << Endl;
   Log() << "                 automatically estimated. In most cases" << Endl;
   Log() << "                 this should be fine. However, you may want" << Endl;
   Log() << "                 to fix this value if you already know it" << Endl;
   Log() << "                 and want to reduce the training time." << brk << Endl;
   Log() << "   " << col << "GDTauPrec     " << colres
         << ": precision of estimated tau" << Endl;
   Log() << "                 Increase this precision to find a more" << Endl;
   Log() << "                 optimal cut-off parameter." << brk << Endl;
   Log() << "   " << col << "GDNSteps      " << colres
         << ": number of steps in path search" << Endl;
   Log() << "                 If the number of steps is too small, then" << Endl;
   Log() << "                 the program will give a warning message." << Endl;
   Log() << Endl;
   Log() << "III. WARNING MESSAGES" << Endl;
   Log() << Endl;
   Log() << col << "Risk(i+1)>=Risk(i) in path" << colres << brk << Endl;
   Log() << col << "Chaotic behaviour of risk evolution." << colres << Endl;
   //         "|----------------|---------------------------------------------|"
   Log() << "                 By construction the Risk should always decrease." << Endl;
   Log() << "                 However, if the training sample is too small or" << Endl;
   Log() << "                 the model is overtrained, such warnings can" << Endl;
   Log() << "                 occur." << Endl;
   Log() << "                 The warnings can safely be ignored if only a" << Endl;
   Log() << "                 few (<3) occur. If more warnings are generated," << Endl;
   Log() << "                 the fitting fails." << Endl;
   Log() << "                 A remedy may be to increase the value" << brk << Endl;
   Log() << "                 "
         << col << "GDValidEveFrac" << colres
         << " to 1.0 (or a larger value)." << brk << Endl;
   Log() << "                 In addition, if "
         << col << "GDPathEveFrac" << colres
         << " is too high" << Endl;
   Log() << "                 the same warnings may occur since the events" << Endl;
   Log() << "                 used for error estimation are also used for" << Endl;
   Log() << "                 path estimation." << Endl;
   Log() << "                 Another possibility is to modify the model -" << Endl;
   Log() << "                 see above on tuning the rule ensemble." << Endl;
   Log() << Endl;
   Log() << col << "The error rate was still decreasing at the end of the path"
         << colres << Endl;
   Log() << "                 Too few steps in path! Increase "
         << col << "GDNSteps" << colres << "." << Endl;
   Log() << Endl;
   Log() << col << "Reached minimum early in the search" << colres << Endl;
   Log() << "                 Minimum was found early in the fitting. This" << Endl;
   Log() << "                 may indicate that the used step size "
         << col << "GDStep" << colres << Endl;
   Log() << "                 was too large. Reduce it and rerun." << Endl;
   Log() << "                 If the results still are not OK, modify the" << Endl;
   Log() << "                 model either by modifying the rule ensemble" << Endl;
   Log() << "                 or by adding/removing linear terms." << Endl;
}