Logo ROOT  
Reference Guide
Loading...
Searching...
No Matches
MethodRuleFit.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Fredrik Tegenfeldt
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodRuleFit *
8 * *
9 * *
10 * Description: *
11 * Implementation (see header file for description) *
12 * *
13 * Authors (alphabetical): *
14 * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
15 * *
16 * Copyright (c) 2005: *
17 * CERN, Switzerland *
18 * Iowa State U. *
19 * MPI-K Heidelberg, Germany *
20 * *
21 * Redistribution and use in source and binary forms, with or without *
22 * modification, are permitted according to the terms listed in LICENSE *
23 * (see tmva/doc/LICENSE) *
24 **********************************************************************************/
25
26/*! \class TMVA::MethodRuleFit
27\ingroup TMVA
28J Friedman's RuleFit method
29*/
30
31#include "TMVA/MethodRuleFit.h"
32
34#include "TMVA/Config.h"
35#include "TMVA/Configurable.h"
36#include "TMVA/CrossEntropy.h"
37#include "TMVA/DataSet.h"
38#include "TMVA/DecisionTree.h"
39#include "TMVA/GiniIndex.h"
40#include "TMVA/IMethod.h"
41#include "TMVA/MethodBase.h"
43#include "TMVA/MsgLogger.h"
44#include "TMVA/Ranking.h"
45#include "TMVA/RuleFitAPI.h"
46#include "TMVA/SdivSqrtSplusB.h"
47#include "TMVA/SeparationBase.h"
48#include "TMVA/Timer.h"
49#include "TMVA/Tools.h"
50#include "TMVA/Types.h"
51
52#include "TRandom3.h"
53#include "TMatrix.h"
54
55#include <iostream>
56#include <iomanip>
57#include <algorithm>
58#include <list>
59#include <random>
60
61using std::min;
62
63REGISTER_METHOD(RuleFit)
64
65
66////////////////////////////////////////////////////////////////////////////////
67/// standard constructor
68
// Standard "training" constructor: forwards to MethodBase and zero-initialises
// all monitoring-ntuple buffers, forest-generation and gradient-directed-path
// (GD) fit parameters. Real defaults are assigned later in Init()/DeclareOptions().
// NOTE(review): this listing is a Doxygen extraction; the opening signature line
// (original line 69, "TMVA::MethodRuleFit::MethodRuleFit( const TString& jobName,")
// and a few initializer lines (74, 76, 87, 97-99) were dropped by the extractor --
// confirm against the repository source.
70 const TString& methodTitle,
71 DataSetInfo& theData,
72 const TString& theOption) :
// register this method with the framework as Types::kRuleFit
73 MethodBase( jobName, Types::kRuleFit, methodTitle, theData, theOption)
// --- monitoring-ntuple branch buffers (filled per rule in TrainTMVARuleFit) ---
75 , fNTImportance(0)
77 , fNTSupport(0)
78 , fNTNcuts(0)
79 , fNTNvars(0)
80 , fNTPtag(0)
81 , fNTPss(0)
82 , fNTPsb(0)
83 , fNTPbs(0)
84 , fNTPbb(0)
85 , fNTSSB(0)
86 , fNTType(0)
// --- Friedman-module (rf_go.exe) and forest-generation parameters ---
88 , fRFNrules(0)
89 , fRFNendnodes(0)
90 , fNTrees(0)
91 , fTreeEveFrac(0)
92 , fSepType(0)
93 , fMinFracNEve(0)
94 , fMaxFracNEve(0)
95 , fNCuts(0)
96 , fPruneMethod(TMVA::DecisionTree::kCostComplexityPruning)
// --- gradient-directed-path (coefficient fit) parameters ---
100 , fGDValidEveFrac(0)
101 , fGDTau(0)
102 , fGDTauPrec(0)
103 , fGDTauMin(0)
104 , fGDTauMax(0)
105 , fGDTauScan(0)
106 , fGDPathStep(0)
107 , fGDNPathSteps(0)
108 , fGDErrScale(0)
109 , fMinimp(0)
110 , fRuleMinDist(0)
111 , fLinQuantile(0)
112{
// the monitoring TTree is created lazily in InitMonitorNtuple()
113 fMonitorNtuple = NULL;
114}
115
116////////////////////////////////////////////////////////////////////////////////
117/// constructor from weight file
118
// Constructor used when instantiating the method from a weight file (application
// phase): same zero-initialisation of all fit/forest/ntuple members as the
// training constructor.
// NOTE(review): Doxygen extraction dropped the signature line (original line 119,
// "TMVA::MethodRuleFit::MethodRuleFit( DataSetInfo& theData,") and initializer
// lines 135/146 -- confirm against the repository source.
120 const TString& theWeightFile) :
121 MethodBase( Types::kRuleFit, theData, theWeightFile)
122 , fSignalFraction(0)
// monitoring-ntuple branch buffers
123 , fNTImportance(0)
124 , fNTCoefficient(0)
125 , fNTSupport(0)
126 , fNTNcuts(0)
127 , fNTNvars(0)
128 , fNTPtag(0)
129 , fNTPss(0)
130 , fNTPsb(0)
131 , fNTPbs(0)
132 , fNTPbb(0)
133 , fNTSSB(0)
134 , fNTType(0)
// Friedman-module and forest parameters
136 , fRFNrules(0)
137 , fRFNendnodes(0)
138 , fNTrees(0)
139 , fTreeEveFrac(0)
140 , fSepType(0)
141 , fMinFracNEve(0)
142 , fMaxFracNEve(0)
143 , fNCuts(0)
144 , fPruneMethod(TMVA::DecisionTree::kCostComplexityPruning)
145 , fPruneStrength(0)
// gradient-directed-path parameters
147 , fGDPathEveFrac(0)
148 , fGDValidEveFrac(0)
149 , fGDTau(0)
150 , fGDTauPrec(0)
151 , fGDTauMin(0)
152 , fGDTauMax(0)
153 , fGDTauScan(0)
154 , fGDPathStep(0)
155 , fGDNPathSteps(0)
156 , fGDErrScale(0)
157 , fMinimp(0)
158 , fRuleMinDist(0)
159 , fLinQuantile(0)
160{
// created lazily in InitMonitorNtuple()
161 fMonitorNtuple = NULL;
162}
163
164////////////////////////////////////////////////////////////////////////////////
165/// destructor
166
// Destructor: this method owns both the deep-copied training events
// (fEventSample, filled in InitEventSample) and the decision-tree forest
// (fForest), so both are deleted element by element here.
// NOTE(review): the signature line (original 167) was dropped by the extractor.
168{
169 for (UInt_t i=0; i<fEventSample.size(); i++) delete fEventSample[i];
170 for (UInt_t i=0; i<fForest.size(); i++) delete fForest[i];
171}
172
173////////////////////////////////////////////////////////////////////////////////
174/// RuleFit can handle classification with 2 classes
175
// Capability query: RuleFit supports only binary (2-class) classification;
// every other (type, numberClasses) combination is rejected.
// NOTE(review): the signature line (original 176) was dropped by the extractor.
177{
178 if (type == Types::kClassification && numberClasses == 2) return kTRUE;
179 return kFALSE;
180}
181
182////////////////////////////////////////////////////////////////////////////////
183/// define the options (their key words) that can be set in the option string
184/// know options.
185///
186/// #### general
187///
188/// - RuleFitModule `<string>`
189/// available values are:
190/// - RFTMVA - use TMVA implementation
191/// - RFFriedman - use Friedmans original implementation
192///
193/// #### Path search (fitting)
194///
195/// - GDTau `<float>` gradient-directed path: fit threshold, default
196/// - GDTauPrec `<float>` gradient-directed path: precision of estimated tau
197/// - GDStep `<float>` gradient-directed path: step size
198/// - GDNSteps `<float>` gradient-directed path: number of steps
199/// - GDErrScale `<float>` stop scan when error>scale*errmin
200///
201/// #### Tree generation
202///
203/// - fEventsMin `<float>` minimum fraction of events in a splittable node
204/// - fEventsMax `<float>` maximum fraction of events in a splittable node
205/// - nTrees `<float>` number of trees in forest.
206/// - ForestType `<string>`
207/// available values are:
208/// - Random - create forest using random subsample and only random variables subset at each node
209/// - AdaBoost - create forest with boosted events
210///
211/// #### Model creation
212///
213/// - RuleMinDist `<float>` min distance allowed between rules
214/// - MinImp `<float>` minimum rule importance accepted
215/// - Model `<string>` model to be used
216/// available values are:
217/// - ModRuleLinear `<default>`
218/// - ModRule
219/// - ModLinear
220///
221/// #### Friedmans module
222///
223/// - RFWorkDir `<string>` directory where Friedmans module (rf_go.exe) is installed
224/// - RFNrules `<int>` maximum number of rules allowed
225/// - RFNendnodes `<int>` average number of end nodes in the forest of trees
226
// Register all user-configurable options (with defaults) for the booking string:
// GD-path fit parameters, tree-generation parameters, rule-cleanup thresholds,
// model choice, and the settings for Friedman's external rf_go.exe module.
// NOTE(review): the signature line (original 227) was dropped by the extractor.
228{
// --- gradient-directed-path (fitting) options; GDTau=-1 means "estimate automatically" ---
229 DeclareOptionRef(fGDTau=-1, "GDTau", "Gradient-directed (GD) path: default fit cut-off");
230 DeclareOptionRef(fGDTauPrec=0.01, "GDTauPrec", "GD path: precision of tau");
231 DeclareOptionRef(fGDPathStep=0.01, "GDStep", "GD path: step size");
232 DeclareOptionRef(fGDNPathSteps=10000, "GDNSteps", "GD path: number of steps");
233 DeclareOptionRef(fGDErrScale=1.1, "GDErrScale", "Stop scan when error > scale*errmin");
234 DeclareOptionRef(fLinQuantile, "LinQuantile", "Quantile of linear terms (removes outliers)");
235 DeclareOptionRef(fGDPathEveFrac=0.5, "GDPathEveFrac", "Fraction of events used for the path search");
236 DeclareOptionRef(fGDValidEveFrac=0.5, "GDValidEveFrac", "Fraction of events used for the validation");
237 // tree options
238 DeclareOptionRef(fMinFracNEve=0.1, "fEventsMin", "Minimum fraction of events in a splittable node");
239 DeclareOptionRef(fMaxFracNEve=0.9, "fEventsMax", "Maximum fraction of events in a splittable node");
240 DeclareOptionRef(fNTrees=20, "nTrees", "Number of trees in forest.");
241
242 DeclareOptionRef(fForestTypeS="AdaBoost", "ForestType", "Method to use for forest generation (AdaBoost or RandomForest)");
243 AddPreDefVal(TString("AdaBoost"));
244 AddPreDefVal(TString("Random"));
245 // rule cleanup options
246 DeclareOptionRef(fRuleMinDist=0.001, "RuleMinDist", "Minimum distance between rules");
247 DeclareOptionRef(fMinimp=0.01, "MinImp", "Minimum rule importance accepted");
248 // rule model option
249 DeclareOptionRef(fModelTypeS="ModRuleLinear", "Model", "Model to be used");
250 AddPreDefVal(TString("ModRule"));
251 AddPreDefVal(TString("ModRuleLinear"));
252 AddPreDefVal(TString("ModLinear"));
253 DeclareOptionRef(fRuleFitModuleS="RFTMVA", "RuleFitModule","Which RuleFit module to use");
254 AddPreDefVal(TString("RFTMVA"));
255 AddPreDefVal(TString("RFFriedman"));
256
// --- options consumed only by Friedman's external module (RFF) ---
// NOTE(review): "Mximum" below is a typo for "Maximum" in user-visible help
// text; left untouched here since this edit changes comments only.
257 DeclareOptionRef(fRFWorkDir="./rulefit", "RFWorkDir", "Friedman\'s RuleFit module (RFF): working dir");
258 DeclareOptionRef(fRFNrules=2000, "RFNrules", "RFF: Mximum number of rules");
259 DeclareOptionRef(fRFNendnodes=4, "RFNendnodes", "RFF: Average number of end nodes");
260}
261
262////////////////////////////////////////////////////////////////////////////////
263/// process the options specified by the user
264
// Translate the parsed option strings into concrete objects/flags, clamp all
// numeric options into their valid ranges, and push the resulting parameters
// into the fRuleFit worker object.
// NOTE(review): the extractor dropped the signature (original 265) and the
// opening of the negative-weight guard (original 267, presumably
// "if (IgnoreEventsWithNegWeightsInTraining()) {") plus lines 269/291/293/367 --
// confirm against the repository source.
266{
268 Log() << kFATAL << "Mechanism to ignore events with negative weights in training not yet available for method: "
270 << " --> please remove \"IgnoreNegWeightsInTraining\" option from booking string."
271 << Endl;
272 }
273
// select TMVA-internal vs Friedman implementation; unknown value falls back to Friedman
274 fRuleFitModuleS.ToLower();
275 if (fRuleFitModuleS == "rftmva") fUseRuleFitJF = kFALSE;
276 else if (fRuleFitModuleS == "rffriedman") fUseRuleFitJF = kTRUE;
277 else fUseRuleFitJF = kTRUE;
278
// node-splitting separation criterion; default (fallthrough) is S/sqrt(S+B)
279 fSepTypeS.ToLower();
280 if (fSepTypeS == "misclassificationerror") fSepType = new MisClassificationError();
281 else if (fSepTypeS == "giniindex") fSepType = new GiniIndex();
282 else if (fSepTypeS == "crossentropy") fSepType = new CrossEntropy();
283 else fSepType = new SdivSqrtSplusB();
284
// model composition: rules only, linear only, or both (default)
285 fModelTypeS.ToLower();
286 if (fModelTypeS == "modlinear" ) fRuleFit.SetModelLinear();
287 else if (fModelTypeS == "modrule" ) fRuleFit.SetModelRules();
288 else fRuleFit.SetModelFull();
289
290 fPruneMethodS.ToLower();
292 else if (fPruneMethodS == "costcomplexity" ) fPruneMethod = DecisionTree::kCostComplexityPruning;
294
// forest generation: random subsampling vs AdaBoost (default, also the fallback)
295 fForestTypeS.ToLower();
296 if (fForestTypeS == "random" ) fUseBoost = kFALSE;
297 else if (fForestTypeS == "adaboost" ) fUseBoost = kTRUE;
298 else fUseBoost = kTRUE;
299 //
300 // if creating the forest by boosting the events
301 // the full training sample is used per tree
302 // -> only true for the TMVA version of RuleFit.
303 if (fUseBoost && (!fUseRuleFitJF)) fTreeEveFrac = 1.0;
304
305 // check event fraction for tree generation
306 // if <0 set to automatic number
307 if (fTreeEveFrac<=0) {
308 Int_t nevents = Data()->GetNTrainingEvents();
309 Double_t n = static_cast<Double_t>(nevents);
// heuristic: fraction shrinks with sample size, capped at 50%
310 fTreeEveFrac = min( 0.5, (100.0 +6.0*sqrt(n))/n);
311 }
312 // verify ranges of options
313 VerifyRange(Log(), "nTrees", fNTrees,0,100000,20);
314 VerifyRange(Log(), "MinImp", fMinimp,0.0,1.0,0.0);
315 VerifyRange(Log(), "GDTauPrec", fGDTauPrec,1e-5,5e-1);
316 VerifyRange(Log(), "GDTauMin", fGDTauMin,0.0,1.0);
317 VerifyRange(Log(), "GDTauMax", fGDTauMax,fGDTauMin,1.0);
318 VerifyRange(Log(), "GDPathStep", fGDPathStep,0.0,100.0,0.01);
319 VerifyRange(Log(), "GDErrScale", fGDErrScale,1.0,100.0,1.1);
320 VerifyRange(Log(), "GDPathEveFrac", fGDPathEveFrac,0.01,0.9,0.5);
// path + validation fractions must not exceed 1.0 in total
321 VerifyRange(Log(), "GDValidEveFrac",fGDValidEveFrac,0.01,1.0-fGDPathEveFrac,1.0-fGDPathEveFrac);
322 VerifyRange(Log(), "fEventsMin", fMinFracNEve,0.0,1.0);
323 VerifyRange(Log(), "fEventsMax", fMaxFracNEve,fMinFracNEve,1.0);
324
// propagate validated options into the RuleFit worker and its fit parameters
325 fRuleFit.GetRuleEnsemblePtr()->SetLinQuantile(fLinQuantile);
326 fRuleFit.GetRuleFitParamsPtr()->SetGDTauRange(fGDTauMin,fGDTauMax);
327 fRuleFit.GetRuleFitParamsPtr()->SetGDTau(fGDTau);
328 fRuleFit.GetRuleFitParamsPtr()->SetGDTauPrec(fGDTauPrec);
329 fRuleFit.GetRuleFitParamsPtr()->SetGDTauScan(fGDTauScan);
330 fRuleFit.GetRuleFitParamsPtr()->SetGDPathStep(fGDPathStep);
331 fRuleFit.GetRuleFitParamsPtr()->SetGDNPathSteps(fGDNPathSteps);
332 fRuleFit.GetRuleFitParamsPtr()->SetGDErrScale(fGDErrScale);
333 fRuleFit.SetImportanceCut(fMinimp);
334 fRuleFit.SetRuleMinDist(fRuleMinDist);
335
336
337 // check if Friedmans module is used.
338 // print a message concerning the options.
339 if (fUseRuleFitJF) {
340 Log() << kINFO << "" << Endl;
341 Log() << kINFO << "--------------------------------------" <<Endl;
342 Log() << kINFO << "Friedmans RuleFit module is selected." << Endl;
343 Log() << kINFO << "Only the following options are used:" << Endl;
344 Log() << kINFO << Endl;
345 Log() << kINFO << gTools().Color("bold") << " Model" << gTools().Color("reset") << Endl;
346 Log() << kINFO << gTools().Color("bold") << " RFWorkDir" << gTools().Color("reset") << Endl;
347 Log() << kINFO << gTools().Color("bold") << " RFNrules" << gTools().Color("reset") << Endl;
348 Log() << kINFO << gTools().Color("bold") << " RFNendnodes" << gTools().Color("reset") << Endl;
349 Log() << kINFO << gTools().Color("bold") << " GDNPathSteps" << gTools().Color("reset") << Endl;
350 Log() << kINFO << gTools().Color("bold") << " GDPathStep" << gTools().Color("reset") << Endl;
351 Log() << kINFO << gTools().Color("bold") << " GDErrScale" << gTools().Color("reset") << Endl;
352 Log() << kINFO << "--------------------------------------" <<Endl;
353 Log() << kINFO << Endl;
354 }
355
356 // Select what weight to use in the 'importance' rule visualisation plots.
357 // Note that if UseCoefficientsVisHists() is selected, the following weight is used:
358 // w = rule coefficient * rule support
359 // The support is a positive number which is 0 if no events are accepted by the rule.
360 // Normally the importance gives more useful information.
361 //
362 //fRuleFit.UseCoefficientsVisHists();
363 fRuleFit.UseImportanceVisHists();
364
365 fRuleFit.SetMsgType( Log().GetMinType() );
366
368
369}
370
371////////////////////////////////////////////////////////////////////////////////
372/// initialize the monitoring ntuple
373
// Create the monitoring TTree in the method's ROOT base directory and wire one
// branch to each fNT* member; TrainTMVARuleFit() later fills one entry per rule.
// NOTE(review): the signature line (original 374) was dropped by the extractor.
375{
// cd into the method's directory so the TTree is owned by the output file
376 BaseDir()->cd();
377 fMonitorNtuple= new TTree("MonitorNtuple_RuleFit","RuleFit variables");
378 fMonitorNtuple->Branch("importance",&fNTImportance,"importance/D");
379 fMonitorNtuple->Branch("support",&fNTSupport,"support/D");
380 fMonitorNtuple->Branch("coefficient",&fNTCoefficient,"coefficient/D");
381 fMonitorNtuple->Branch("ncuts",&fNTNcuts,"ncuts/I");
382 fMonitorNtuple->Branch("nvars",&fNTNvars,"nvars/I");
383 fMonitorNtuple->Branch("type",&fNTType,"type/I");
// rule-selection probabilities: P(tag) and the signal/background transition matrix
384 fMonitorNtuple->Branch("ptag",&fNTPtag,"ptag/D");
385 fMonitorNtuple->Branch("pss",&fNTPss,"pss/D");
386 fMonitorNtuple->Branch("psb",&fNTPsb,"psb/D");
387 fMonitorNtuple->Branch("pbs",&fNTPbs,"pbs/D");
388 fMonitorNtuple->Branch("pbb",&fNTPbb,"pbb/D");
389 fMonitorNtuple->Branch("soversb",&fNTSSB,"soversb/D");
390}
391
392////////////////////////////////////////////////////////////////////////////////
393/// default initialization
394
// Default initialization of parameters that used to be user options; any
// user-driven modification happens later in ProcessOptions().
// NOTE(review): the extractor dropped the signature (original 395) and the
// statement under the "minimum requirement" comment (original 398, presumably
// SetSignalReferenceCut) -- confirm against the repository source.
396{
397 // the minimum requirement to declare an event signal-like
399
400 // set variables that used to be options
401 // any modifications are then made in ProcessOptions()
402 fLinQuantile = 0.025; // Quantile of linear terms (remove outliers)
403 fTreeEveFrac = -1.0; // Fraction of events used to train each tree
404 fNCuts = 20; // Number of steps during node cut optimisation
405 fSepTypeS = "GiniIndex"; // Separation criterion for node splitting; see BDT
406 fPruneMethodS = "NONE"; // Pruning method; see BDT
407 fPruneStrength = 3.5; // Pruning strength; see BDT
408 fGDTauMin = 0.0; // Gradient-directed path: min fit threshold (tau)
409 fGDTauMax = 1.0; // Gradient-directed path: max fit threshold (tau)
410 fGDTauScan = 1000; // Gradient-directed path: number of points scanning for best tau
411
412}
413
414////////////////////////////////////////////////////////////////////////////////
415/// write all Events from the Tree into a vector of Events, that are
416/// more easily manipulated.
417/// This method should never be called without existing trainingTree, as it
418/// the vector of events from the ROOT training tree
419
// Deep-copy all events of the dataset into fEventSample (owned; deleted in the
// destructor), determine the per-tree subsample fraction if not set, and
// shuffle the sample.
// NOTE(review): the signature line (original 420) was dropped by the extractor.
421{
422 if (Data()->GetNEvents()==0) Log() << kFATAL << "<Init> Data().TrainingTree() is zero pointer" << Endl;
423
424 Int_t nevents = Data()->GetNEvents();
425 for (Int_t ievt=0; ievt<nevents; ievt++){
426 const Event * ev = GetEvent(ievt);
427 fEventSample.push_back( new Event(*ev));
428 }
// same size-dependent heuristic as in ProcessOptions(), capped at 50%
429 if (fTreeEveFrac<=0) {
430 Double_t n = static_cast<Double_t>(nevents);
431 fTreeEveFrac = min( 0.5, (100.0 +6.0*sqrt(n))/n);
432 }
433 if (fTreeEveFrac>1.0) fTreeEveFrac=1.0;
434 //
// NOTE(review): std::default_random_engine{} is default-seeded, so this shuffle
// is deterministic across runs (reproducible but not random per invocation).
435 std::shuffle(fEventSample.begin(), fEventSample.end(), std::default_random_engine{});
436 //
437 Log() << kDEBUG << "Set sub-sample fraction to " << fTreeEveFrac << Endl;
438}
439
440////////////////////////////////////////////////////////////////////////////////
441
// Main training entry point: builds the event sample, then dispatches to either
// Friedman's external module or the TMVA-internal implementation.
// NOTE(review): the extractor dropped the signature (original 442) and several
// statements (originals 444/447/453/456/459-460, presumably the TrainJFRuleFit/
// TrainTMVARuleFit calls inside the if/else) -- confirm against the repository.
443{
445 // training of rules
446
448
449 // fill the STL Vector with the event sample
450 this->InitEventSample();
451
452 if (fUseRuleFitJF) {
454 }
455 else {
457 }
// drop the per-event rule map: it is only needed during training
458 fRuleFit.GetRuleEnsemblePtr()->ClearRuleMap();
461}
462
463////////////////////////////////////////////////////////////////////////////////
464/// training of rules using TMVA implementation
465
// Train with the TMVA-internal implementation: build the forest/rule ensemble,
// fit the rule coefficients along the gradient-directed path, compute
// importances, and (unless the output file is silent) fill the monitoring
// ntuple with one entry per rule and produce the visualisation histograms.
// NOTE(review): the extractor dropped the signature (original 466) and the
// assignments at originals 509/511 (presumably fNTImportance/fNTCoefficient) --
// confirm against the repository source.
467{
// RuleFit requires unnormalised inputs
468 if (IsNormalised()) Log() << kFATAL << "\"Normalise\" option cannot be used with RuleFit; "
469 << "please remove the option from the configuration string, or "
470 << "use \"!Normalise\""
471 << Endl;
472
473 // timer
474 Timer timer( 1, GetName() );
475
476 // test tree nmin cut -> for debug purposes
477 // the routine will generate trees with stopping cut on N(eve) given by
478 // a fraction between [20,N(eve)-1].
479 //
480 // MakeForestRnd();
481 // exit(1);
482 //
483
484 // Init RuleFit object and create rule ensemble
485 // + make forest & rules
486 fRuleFit.Initialize( this );
487
488 // Make forest of decision trees
489 // if (fRuleFit.GetRuleEnsemble().DoRules()) fRuleFit.MakeForest();
490
491 // Fit the rules
492 Log() << kDEBUG << "Fitting rule coefficients ..." << Endl;
493 fRuleFit.FitCoefficients();
494
495 // Calculate importance
496 Log() << kDEBUG << "Computing rule and variable importance" << Endl;
497 fRuleFit.CalcImportance();
498
499 // Output results and fill monitor ntuple
500 fRuleFit.GetRuleEnsemblePtr()->Print();
501 //
502 if(!IsSilentFile())
503 {
504 Log() << kDEBUG << "Filling rule ntuple" << Endl;
505 UInt_t nrules = fRuleFit.GetRuleEnsemble().GetRulesConst().size();
506 const Rule *rule;
// copy each rule's properties into the branch buffers and commit one entry
507 for (UInt_t i=0; i<nrules; i++ ) {
508 rule = fRuleFit.GetRuleEnsemble().GetRulesConst(i);
510 fNTSupport = rule->GetSupport();
512 fNTType = (rule->IsSignalRule() ? 1:-1 );
513 fNTNvars = rule->GetRuleCut()->GetNvars();
514 fNTNcuts = rule->GetRuleCut()->GetNcuts();
515 fNTPtag = fRuleFit.GetRuleEnsemble().GetRulePTag(i); // should be identical with support
516 fNTPss = fRuleFit.GetRuleEnsemble().GetRulePSS(i);
517 fNTPsb = fRuleFit.GetRuleEnsemble().GetRulePSB(i);
518 fNTPbs = fRuleFit.GetRuleEnsemble().GetRulePBS(i);
519 fNTPbb = fRuleFit.GetRuleEnsemble().GetRulePBB(i);
520 fNTSSB = rule->GetSSB();
521 fMonitorNtuple->Fill();
522 }
523
524 fRuleFit.MakeVisHists();
525 fRuleFit.MakeDebugHists();
526 }
527 Log() << kDEBUG << "Training done" << Endl;
528
529}
530
531////////////////////////////////////////////////////////////////////////////////
532/// training of rules using Jerome Friedmans implementation
533
// Train via Jerome Friedman's external rf_go.exe module: hand the training
// events to the RuleFit worker, run the external program through RuleFitAPI,
// read back the model summary, and compute importances/histograms.
// NOTE(review): the signature line (original 534) was dropped by the extractor.
535{
536 fRuleFit.InitPtrs( this );
537 Data()->SetCurrentType(Types::kTraining);
538 UInt_t nevents = Data()->GetNTrainingEvents();
// collect non-owning pointers to the training events for the worker
539 std::vector<const TMVA::Event*> tmp;
540 for (Long64_t ievt=0; ievt<nevents; ievt++) {
541 const Event *event = GetEvent(ievt);
542 tmp.push_back(event);
543 }
544 fRuleFit.SetTrainingEvents( tmp );
545
// the API object drives the external executable; owned locally and deleted below
546 RuleFitAPI *rfAPI = new RuleFitAPI( this, &fRuleFit, Log().GetMinType() );
547
548 rfAPI->WelcomeMessage();
549
550 // timer
551 Timer timer( 1, GetName() );
552
553 Log() << kINFO << "Training ..." << Endl;
554 rfAPI->TrainRuleFit();
555
556 Log() << kDEBUG << "reading model summary from rf_go.exe output" << Endl;
557 rfAPI->ReadModelSum();
558
559 // fRuleFit.GetRuleEnsemblePtr()->MakeRuleMap();
560
561 Log() << kDEBUG << "calculating rule and variable importance" << Endl;
562 fRuleFit.CalcImportance();
563
564 // Output results and fill monitor ntuple
565 fRuleFit.GetRuleEnsemblePtr()->Print();
566 //
567 if(!IsSilentFile())fRuleFit.MakeVisHists();
568
569 delete rfAPI;
570
571 Log() << kDEBUG << "done training" << Endl;
572}
573
574////////////////////////////////////////////////////////////////////////////////
575/// computes ranking of input variables
576
// Build and return the variable ranking, one entry per input variable, ranked
// by the rule-ensemble variable importance. The Ranking object is stored in
// the fRanking member (owned by the base class machinery).
// NOTE(review): the signature line (original 577) was dropped by the extractor.
578{
579 // create the ranking object
580 fRanking = new Ranking( GetName(), "Importance" );
581
582 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
583 fRanking->AddRank( Rank( GetInputLabel(ivar), fRuleFit.GetRuleEnsemble().GetVarImportance(ivar) ) );
584 }
585
586 return fRanking;
587}
588
589////////////////////////////////////////////////////////////////////////////////
590/// add the rules to XML node
591
592void TMVA::MethodRuleFit::AddWeightsXMLTo( void* parent ) const
593{
594 fRuleFit.GetRuleEnsemble().AddXMLTo( parent );
595}
596
597////////////////////////////////////////////////////////////////////////////////
598/// read rules from an std::istream
599
// Read the rule-ensemble weights from a plain-text weight stream (legacy
// format); delegates entirely to RuleEnsemble::ReadRaw.
// NOTE(review): the signature line (original 600) was dropped by the extractor.
601{
602 fRuleFit.GetRuleEnsemblePtr()->ReadRaw( istr );
603}
604
605////////////////////////////////////////////////////////////////////////////////
606/// read rules from XML node
607
// Read the rule-ensemble weights from the XML weight-file node; delegates to
// RuleEnsemble::ReadFromXML (counterpart of AddWeightsXMLTo).
// NOTE(review): the signature line (original 608) was dropped by the extractor.
609{
610 fRuleFit.GetRuleEnsemblePtr()->ReadFromXML( wghtnode );
611}
612
613////////////////////////////////////////////////////////////////////////////////
614/// returns MVA value for given event
615
// Evaluate the rule+linear model on the current event and return the MVA score.
// No error estimate is available for this method, hence NoErrorCalc.
// NOTE(review): the signature line (original 616) was dropped by the extractor.
617{
618 // cannot determine error
619 NoErrorCalc(err, errUpper);
620
621 return fRuleFit.EvalEvent( *GetEvent() );
622}
623
624////////////////////////////////////////////////////////////////////////////////
625/// write special monitoring histograms to file (here ntuple)
626
// Write the monitoring ntuple (filled during TrainTMVARuleFit) into the
// method's ROOT base directory.
// NOTE(review): the signature line (original 627) was dropped by the extractor.
628{
629 BaseDir()->cd();
630 Log() << kINFO << "Write monitoring ntuple to file: " << BaseDir()->GetPath() << Endl;
631 fMonitorNtuple->Write();
632}
633
634////////////////////////////////////////////////////////////////////////////////
635/// write specific classifier response
636
637void TMVA::MethodRuleFit::MakeClassSpecific( std::ostream& fout, const TString& className ) const
638{
639 Int_t dp = fout.precision();
640 fout << " // not implemented for class: \"" << className << "\"" << std::endl;
641 fout << "};" << std::endl;
642 fout << "void " << className << "::Initialize(){}" << std::endl;
643 fout << "void " << className << "::Clear(){}" << std::endl;
644 fout << "double " << className << "::GetMvaValue__( const std::vector<double>& inputValues ) const {" << std::endl;
645 fout << " double rval=" << std::setprecision(10) << fRuleFit.GetRuleEnsemble().GetOffset() << ";" << std::endl;
646 MakeClassRuleCuts(fout);
647 MakeClassLinear(fout);
648 fout << " return rval;" << std::endl;
649 fout << "}" << std::endl;
650 fout << std::setprecision(dp);
651}
652
653////////////////////////////////////////////////////////////////////////////////
654/// print out the rule cuts
655
656void TMVA::MethodRuleFit::MakeClassRuleCuts( std::ostream& fout ) const
657{
658 Int_t dp = fout.precision();
659 if (!fRuleFit.GetRuleEnsemble().DoRules()) {
660 fout << " //" << std::endl;
661 fout << " // ==> MODEL CONTAINS NO RULES <==" << std::endl;
662 fout << " //" << std::endl;
663 return;
664 }
665 const RuleEnsemble *rens = &(fRuleFit.GetRuleEnsemble());
666 const std::vector< Rule* > *rules = &(rens->GetRulesConst());
667 const RuleCut *ruleCut;
668 //
669 std::list< std::pair<Double_t,Int_t> > sortedRules;
670 for (UInt_t ir=0; ir<rules->size(); ir++) {
671 sortedRules.push_back( std::pair<Double_t,Int_t>( (*rules)[ir]->GetImportance()/rens->GetImportanceRef(),ir ) );
672 }
673 sortedRules.sort();
674 //
675 fout << " //" << std::endl;
676 fout << " // here follows all rules ordered in importance (most important first)" << std::endl;
677 fout << " // at the end of each line, the relative importance of the rule is given" << std::endl;
678 fout << " //" << std::endl;
679 //
680 for ( std::list< std::pair<double,int> >::reverse_iterator itpair = sortedRules.rbegin();
681 itpair != sortedRules.rend(); ++itpair ) {
682 UInt_t ir = itpair->second;
683 Double_t impr = itpair->first;
684 ruleCut = (*rules)[ir]->GetRuleCut();
685 if (impr<rens->GetImportanceCut()) fout << " //" << std::endl;
686 fout << " if (" << std::flush;
687 for (UInt_t ic=0; ic<ruleCut->GetNvars(); ic++) {
688 Double_t sel = ruleCut->GetSelector(ic);
689 Double_t valmin = ruleCut->GetCutMin(ic);
690 Double_t valmax = ruleCut->GetCutMax(ic);
691 Bool_t domin = ruleCut->GetCutDoMin(ic);
692 Bool_t domax = ruleCut->GetCutDoMax(ic);
693 //
694 if (ic>0) fout << "&&" << std::flush;
695 if (domin) {
696 fout << "(" << std::setprecision(10) << valmin << std::flush;
697 fout << "<inputValues[" << sel << "])" << std::flush;
698 }
699 if (domax) {
700 if (domin) fout << "&&" << std::flush;
701 fout << "(inputValues[" << sel << "]" << std::flush;
702 fout << "<" << std::setprecision(10) << valmax << ")" <<std::flush;
703 }
704 }
705 fout << ") rval+=" << std::setprecision(10) << (*rules)[ir]->GetCoefficient() << ";" << std::flush;
706 fout << " // importance = " << TString::Format("%3.3f",impr) << std::endl;
707 }
708 fout << std::setprecision(dp);
709}
710
711////////////////////////////////////////////////////////////////////////////////
712/// print out the linear terms
713
714void TMVA::MethodRuleFit::MakeClassLinear( std::ostream& fout ) const
715{
716 if (!fRuleFit.GetRuleEnsemble().DoLinear()) {
717 fout << " //" << std::endl;
718 fout << " // ==> MODEL CONTAINS NO LINEAR TERMS <==" << std::endl;
719 fout << " //" << std::endl;
720 return;
721 }
722 fout << " //" << std::endl;
723 fout << " // here follows all linear terms" << std::endl;
724 fout << " // at the end of each line, the relative importance of the term is given" << std::endl;
725 fout << " //" << std::endl;
726 const RuleEnsemble *rens = &(fRuleFit.GetRuleEnsemble());
727 UInt_t nlin = rens->GetNLinear();
728 for (UInt_t il=0; il<nlin; il++) {
729 if (rens->IsLinTermOK(il)) {
730 Double_t norm = rens->GetLinNorm(il);
731 Double_t imp = rens->GetLinImportance(il)/rens->GetImportanceRef();
732 fout << " rval+="
733 // << std::setprecision(10) << rens->GetLinCoefficients(il)*norm << "*std::min(" << setprecision(10) << rens->GetLinDP(il)
734 // << ", std::max( inputValues[" << il << "]," << std::setprecision(10) << rens->GetLinDM(il) << "));"
735 << std::setprecision(10) << rens->GetLinCoefficients(il)*norm
736 << "*std::min( double(" << std::setprecision(10) << rens->GetLinDP(il)
737 << "), std::max( double(inputValues[" << il << "]), double(" << std::setprecision(10) << rens->GetLinDM(il) << ")));"
738 << std::flush;
739 fout << " // importance = " << TString::Format("%3.3f",imp) << std::endl;
740 }
741 }
742}
743
744////////////////////////////////////////////////////////////////////////////////
745/// get help message text
746///
747/// typical length of text line:
748/// "|--------------------------------------------------------------|"
749
751{
752 TString col = gConfig().WriteOptionsReference() ? TString() : gTools().Color("bold");
753 TString colres = gConfig().WriteOptionsReference() ? TString() : gTools().Color("reset");
754 TString brk = gConfig().WriteOptionsReference() ? "<br>" : "";
755
756 Log() << Endl;
757 Log() << col << "--- Short description:" << colres << Endl;
758 Log() << Endl;
759 Log() << "This method uses a collection of so called rules to create a" << Endl;
760 Log() << "discriminating scoring function. Each rule consists of a series" << Endl;
761 Log() << "of cuts in parameter space. The ensemble of rules are created" << Endl;
762 Log() << "from a forest of decision trees, trained using the training data." << Endl;
763 Log() << "Each node (apart from the root) corresponds to one rule." << Endl;
764 Log() << "The scoring function is then obtained by linearly combining" << Endl;
765 Log() << "the rules. A fitting procedure is applied to find the optimum" << Endl;
766 Log() << "set of coefficients. The goal is to find a model with few rules" << Endl;
767 Log() << "but with a strong discriminating power." << Endl;
768 Log() << Endl;
769 Log() << col << "--- Performance optimisation:" << colres << Endl;
770 Log() << Endl;
771 Log() << "There are two important considerations to make when optimising:" << Endl;
772 Log() << Endl;
773 Log() << " 1. Topology of the decision tree forest" << brk << Endl;
774 Log() << " 2. Fitting of the coefficients" << Endl;
775 Log() << Endl;
776 Log() << "The maximum complexity of the rules is defined by the size of" << Endl;
777 Log() << "the trees. Large trees will yield many complex rules and capture" << Endl;
778 Log() << "higher order correlations. On the other hand, small trees will" << Endl;
779 Log() << "lead to a smaller ensemble with simple rules, only capable of" << Endl;
780 Log() << "modeling simple structures." << Endl;
781 Log() << "Several parameters exists for controlling the complexity of the" << Endl;
782 Log() << "rule ensemble." << Endl;
783 Log() << Endl;
784 Log() << "The fitting procedure searches for a minimum using a gradient" << Endl;
785 Log() << "directed path. Apart from step size and number of steps, the" << Endl;
786 Log() << "evolution of the path is defined by a cut-off parameter, tau." << Endl;
787 Log() << "This parameter is unknown and depends on the training data." << Endl;
788 Log() << "A large value will tend to give large weights to a few rules." << Endl;
789 Log() << "Similarly, a small value will lead to a large set of rules" << Endl;
790 Log() << "with similar weights." << Endl;
791 Log() << Endl;
792 Log() << "A final point is the model used; rules and/or linear terms." << Endl;
793 Log() << "For a given training sample, the result may improve by adding" << Endl;
794 Log() << "linear terms. If best performance is obtained using only linear" << Endl;
795 Log() << "terms, it is very likely that the Fisher discriminant would be" << Endl;
796 Log() << "a better choice. Ideally the fitting procedure should be able to" << Endl;
797 Log() << "make this choice by giving appropriate weights for either terms." << Endl;
798 Log() << Endl;
799 Log() << col << "--- Performance tuning via configuration options:" << colres << Endl;
800 Log() << Endl;
801 Log() << "I. TUNING OF RULE ENSEMBLE:" << Endl;
802 Log() << Endl;
803 Log() << " " << col << "ForestType " << colres
804 << ": Recommended is to use the default \"AdaBoost\"." << brk << Endl;
805 Log() << " " << col << "nTrees " << colres
806 << ": More trees leads to more rules but also slow" << Endl;
807 Log() << " performance. With too few trees the risk is" << Endl;
808 Log() << " that the rule ensemble becomes too simple." << brk << Endl;
809 Log() << " " << col << "fEventsMin " << colres << brk << Endl;
810 Log() << " " << col << "fEventsMax " << colres
811 << ": With a lower min, more large trees will be generated" << Endl;
812 Log() << " leading to more complex rules." << Endl;
813 Log() << " With a higher max, more small trees will be" << Endl;
814 Log() << " generated leading to more simple rules." << Endl;
815 Log() << " By changing this range, the average complexity" << Endl;
816 Log() << " of the rule ensemble can be controlled." << brk << Endl;
817 Log() << " " << col << "RuleMinDist " << colres
818 << ": By increasing the minimum distance between" << Endl;
819 Log() << " rules, fewer and more diverse rules will remain." << Endl;
820 Log() << " Initially it is a good idea to keep this small" << Endl;
821 Log() << " or zero and let the fitting do the selection of" << Endl;
822 Log() << " rules. In order to reduce the ensemble size," << Endl;
823 Log() << " the value can then be increased." << Endl;
824 Log() << Endl;
825 // "|--------------------------------------------------------------|"
826 Log() << "II. TUNING OF THE FITTING:" << Endl;
827 Log() << Endl;
828 Log() << " " << col << "GDPathEveFrac " << colres
829 << ": fraction of events in path evaluation" << Endl;
830 Log() << " Increasing this fraction will improve the path" << Endl;
831 Log() << " finding. However, a too high value will give few" << Endl;
832 Log() << " unique events available for error estimation." << Endl;
833 Log() << " It is recommended to use the default = 0.5." << brk << Endl;
834 Log() << " " << col << "GDTau " << colres
835 << ": cutoff parameter tau" << Endl;
836 Log() << " By default this value is set to -1.0." << Endl;
837 // "|----------------|---------------------------------------------|"
838 Log() << " This means that the cut off parameter is" << Endl;
839 Log() << " automatically estimated. In most cases" << Endl;
840 Log() << " this should be fine. However, you may want" << Endl;
841 Log() << " to fix this value if you already know it" << Endl;
842 Log() << " and want to reduce on training time." << brk << Endl;
843 Log() << " " << col << "GDTauPrec " << colres
844 << ": precision of estimated tau" << Endl;
845 Log() << " Increase this precision to find a more" << Endl;
846 Log() << " optimum cut-off parameter." << brk << Endl;
847 Log() << " " << col << "GDNStep " << colres
848 << ": number of steps in path search" << Endl;
849 Log() << " If the number of steps is too small, then" << Endl;
850 Log() << " the program will give a warning message." << Endl;
851 Log() << Endl;
852 Log() << "III. WARNING MESSAGES" << Endl;
853 Log() << Endl;
854 Log() << col << "Risk(i+1)>=Risk(i) in path" << colres << brk << Endl;
855 Log() << col << "Chaotic behaviour of risk evolution." << colres << Endl;
856 // "|----------------|---------------------------------------------|"
857 Log() << " The error rate was still decreasing at the end" << Endl;
858 Log() << " By construction the Risk should always decrease." << Endl;
859 Log() << " However, if the training sample is too small or" << Endl;
860 Log() << " the model is overtrained, such warnings can" << Endl;
861 Log() << " occur." << Endl;
862 Log() << " The warnings can safely be ignored if only a" << Endl;
863 Log() << " few (<3) occur. If more warnings are generated," << Endl;
864 Log() << " the fitting fails." << Endl;
865 Log() << " A remedy may be to increase the value" << brk << Endl;
866 Log() << " "
867 << col << "GDValidEveFrac" << colres
868 << " to 1.0 (or a larger value)." << brk << Endl;
869 Log() << " In addition, if "
870 << col << "GDPathEveFrac" << colres
871 << " is too high" << Endl;
872 Log() << " the same warnings may occur since the events" << Endl;
873 Log() << " used for error estimation are also used for" << Endl;
874 Log() << " path estimation." << Endl;
875 Log() << " Another possibility is to modify the model - " << Endl;
876 Log() << " See above on tuning the rule ensemble." << Endl;
877 Log() << Endl;
878 Log() << col << "The error rate was still decreasing at the end of the path"
879 << colres << Endl;
880 Log() << " Too few steps in path! Increase "
881 << col << "GDNSteps" << colres << "." << Endl;
882 Log() << Endl;
883 Log() << col << "Reached minimum early in the search" << colres << Endl;
884
885 Log() << " Minimum was found early in the fitting. This" << Endl;
886 Log() << " may indicate that the used step size "
887 << col << "GDStep" << colres << "." << Endl;
888 Log() << " was too large. Reduce it and rerun." << Endl;
889 Log() << " If the results still are not OK, modify the" << Endl;
890 Log() << " model either by modifying the rule ensemble" << Endl;
891 Log() << " or add/remove linear terms" << Endl;
892}
#define REGISTER_METHOD(CLASS)
for example
#define e(i)
Definition RSha256.hxx:103
int Int_t
Signed integer 4 bytes (int).
Definition RtypesCore.h:59
unsigned int UInt_t
Unsigned integer 4 bytes (unsigned int).
Definition RtypesCore.h:60
bool Bool_t
Boolean (0=false, 1=true) (bool).
Definition RtypesCore.h:77
constexpr Bool_t kFALSE
Definition RtypesCore.h:108
double Double_t
Double 8 bytes.
Definition RtypesCore.h:73
long long Long64_t
Portable signed long integer 8 bytes.
Definition RtypesCore.h:83
constexpr Bool_t kTRUE
Definition RtypesCore.h:107
Double_t err
Bool_t WriteOptionsReference() const
Definition Config.h:65
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
void AddPreDefVal(const T &)
MsgLogger & Log() const
Implementation of the CrossEntropy as separation criterion.
Class that contains all the data information.
Definition DataSetInfo.h:62
static void SetIsTraining(bool on)
Implementation of a Decision Tree.
Implementation of the GiniIndex as separation criterion.
Definition GiniIndex.h:63
MethodBase(const TString &jobName, Types::EMVA methodType, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
standard constructor
Bool_t HasTrainingTree() const
Definition MethodBase.h:516
const char * GetName() const override
Definition MethodBase.h:337
TString GetMethodTypeName() const
Definition MethodBase.h:335
Bool_t IgnoreEventsWithNegWeightsInTraining() const
Definition MethodBase.h:689
TDirectory * BaseDir() const
returns the ROOT directory where info/histograms etc of the corresponding MVA method instance are sto...
void ExitFromTraining()
Definition MethodBase.h:467
UInt_t GetNEvents() const
Definition MethodBase.h:419
const Event * GetEvent() const
Definition MethodBase.h:754
UInt_t GetNvar() const
Definition MethodBase.h:347
Bool_t IsSilentFile() const
Definition MethodBase.h:382
void SetSignalReferenceCut(Double_t cut)
Definition MethodBase.h:367
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
const TString & GetInputLabel(Int_t i) const
Definition MethodBase.h:353
Ranking * fRanking
Definition MethodBase.h:590
DataSet * Data() const
Definition MethodBase.h:412
Bool_t IsNormalised() const
Definition MethodBase.h:499
RuleFit fRuleFit
RuleFit instance.
UInt_t fGDTauScan
GD path: number of points to scan.
Double_t fNTPss
ntuple: rule P(tag s, true s)
TString fForestTypeS
forest generation: how the trees are generated
Double_t fMinimp
rule/linear: minimum importance
TString fRuleFitModuleS
which rulefit module to use
Double_t fLinQuantile
quantile cut to remove outliers - see RuleEnsemble
void GetHelpMessage() const override
get help message text
void ReadWeightsFromXML(void *wghtnode) override
read rules from XML node
Double_t fMinFracNEve
min fraction of number events
Int_t fNTType
ntuple: rule type (+1->signal, -1->bkg)
Bool_t fUseRuleFitJF
if true interface with J.Friedmans RuleFit module
Double_t fGDTauMax
GD path: max threshold fraction [0..1].
void DeclareOptions() override
define the options (their key words) that can be set in the option string know options.
Double_t fGDPathEveFrac
GD path: fraction of subsamples used for the fitting.
Bool_t fUseBoost
use boosted events for forest generation
TMVA::DecisionTree::EPruneMethod fPruneMethod
forest generation: method used for pruning - see DecisionTree
std::vector< DecisionTree * > fForest
the forest
Double_t fMaxFracNEve
ditto max
Double_t GetMvaValue(Double_t *err=nullptr, Double_t *errUpper=nullptr) override
returns MVA value for given event
TString fRFWorkDir
working directory from Friedmans module
Double_t fTreeEveFrac
fraction of events used for training each tree
void MakeClassLinear(std::ostream &) const
print out the linear terms
Double_t fNTPsb
ntuple: rule P(tag s, true b)
Int_t fNTNvars
ntuple: rule number of vars
Int_t fGDNPathSteps
GD path: number of steps.
std::vector< TMVA::Event * > fEventSample
the complete training sample
Double_t fNTPbb
ntuple: rule P(tag b, true b)
Double_t fNTSSB
ntuple: rule S/(S+B)
void TrainJFRuleFit()
training of rules using Jerome Friedmans implementation
Int_t fNTNcuts
ntuple: rule number of cuts
TString fModelTypeS
rule ensemble: which model (rule,linear or both)
Double_t fNTImportance
ntuple: rule importance
Double_t fGDValidEveFrac
GD path: fraction of subsamples used for the fitting.
void InitEventSample(void)
write all Events from the Tree into a vector of Events, that are more easily manipulated.
void MakeClassRuleCuts(std::ostream &) const
print out the rule cuts
Double_t fNTCoefficient
ntuple: rule coefficient
void Init(void) override
default initialization
TString fSepTypeS
forest generation: separation type - see DecisionTree
void InitMonitorNtuple()
initialize the monitoring ntuple
Int_t fNTrees
number of trees in forest
Double_t fNTPtag
ntuple: rule P(tag)
virtual ~MethodRuleFit(void)
destructor
Int_t fRFNendnodes
max number of rules (only Friedmans module)
Double_t fNTPbs
ntuple: rule P(tag b, true s)
void ReadWeightsFromStream(std::istream &istr) override
read rules from an std::istream
TTree * fMonitorNtuple
pointer to monitor rule ntuple
void ProcessOptions() override
process the options specified by the user
void AddWeightsXMLTo(void *parent) const override
add the rules to XML node
MethodRuleFit(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
standard constructor
Double_t fNTSupport
ntuple: rule support
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t) override
RuleFit can handle classification with 2 classes.
Double_t fGDTauMin
GD path: min threshold fraction [0..1].
Double_t fGDTau
GD path: def threshold fraction [0..1].
SeparationBase * fSepType
the separation used in node splitting
const Ranking * CreateRanking() override
computes ranking of input variables
Double_t fGDPathStep
GD path: step size in path.
Double_t fGDTauPrec
GD path: precision of estimated tau.
void MakeClassSpecific(std::ostream &, const TString &) const override
write specific classifier response
Double_t fPruneStrength
forest generation: prune strength - see DecisionTree
void Train(void) override
Int_t fNCuts
grid used in cut applied in node splitting
void WriteMonitoringHistosToFile(void) const override
write special monitoring histograms to file (here ntuple)
Double_t fGDErrScale
GD path: stop.
TString fPruneMethodS
forest generation: prune method - see DecisionTree
Double_t fRuleMinDist
rule min distance - see RuleEnsemble
Int_t fRFNrules
max number of rules (only Friedmans module)
Bool_t VerifyRange(MsgLogger &mlog, const char *varstr, T &var, const T &vmin, const T &vmax)
void TrainTMVARuleFit()
training of rules using TMVA implementation
Double_t fSignalFraction
scalefactor for bkg events to modify initial s/b fraction in training data
Implementation of the MisClassificationError as separation criterion.
Ranking for variables in method (implementation).
Definition Ranking.h:48
A class describing a 'rule cut'.
Definition RuleCut.h:36
UInt_t GetNvars() const
Definition RuleCut.h:72
Double_t GetCutMin(Int_t is) const
Definition RuleCut.h:74
UInt_t GetSelector(Int_t is) const
Definition RuleCut.h:73
Char_t GetCutDoMin(Int_t is) const
Definition RuleCut.h:76
Char_t GetCutDoMax(Int_t is) const
Definition RuleCut.h:77
UInt_t GetNcuts() const
get number of cuts
Definition RuleCut.cxx:164
Double_t GetCutMax(Int_t is) const
Definition RuleCut.h:75
Double_t GetLinDP(int i) const
Double_t GetLinDM(int i) const
const std::vector< Double_t > & GetLinCoefficients() const
Double_t GetImportanceRef() const
const std::vector< Double_t > & GetLinNorm() const
UInt_t GetNLinear() const
const std::vector< TMVA::Rule * > & GetRulesConst() const
const std::vector< Double_t > & GetLinImportance() const
Bool_t IsLinTermOK(int i) const
J Friedman's RuleFit method.
Definition RuleFitAPI.h:51
Bool_t ReadModelSum()
read model from rulefit.sum
void WelcomeMessage()
welcome message
Implementation of a rule.
Definition Rule.h:50
Double_t GetSupport() const
Definition Rule.h:142
const RuleCut * GetRuleCut() const
Definition Rule.h:139
Bool_t IsSignalRule() const
Definition Rule.h:119
Double_t GetCoefficient() const
Definition Rule.h:141
Double_t GetSSB() const
Definition Rule.h:117
Double_t GetRelImportance() const
Definition Rule.h:102
Implementation of the SdivSqrtSplusB as separation criterion.
Timing information for training and evaluation of MVA methods.
Definition Timer.h:58
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:803
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kClassification
Definition Types.h:127
@ kTraining
Definition Types.h:143
Basic string class.
Definition TString.h:138
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2385
A TTree represents a columnar dataset.
Definition TTree.h:89
const Int_t n
Definition legend1.C:16
create variable transformations
Config & gConfig()
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148