ROOT Reference Guide
MethodRuleFit.cxx
// @(#)root/tmva $Id$
// Author: Fredrik Tegenfeldt

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis      *
 * Package: TMVA                                                                  *
 * Class  : MethodRuleFit                                                         *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation (see header file for description)                         *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch>  - Iowa State U., USA     *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         *
 *      Iowa State U.                                                             *
 *      MPI-K Heidelberg, Germany                                                 *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without            *
 * modification, are permitted according to the terms listed in LICENSE          *
 * (see tmva/doc/LICENSE)                                                         *
 **********************************************************************************/

/*! \class TMVA::MethodRuleFit
\ingroup TMVA
J Friedman's RuleFit method: builds a discriminating scoring function as a
linear combination of rules (series of cuts, read off a forest of decision
trees trained on the data) and optional linear terms; the coefficients are
obtained with a gradient-directed path fit.
*/
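
// Usage sketch (not part of the original source; assumes the standard TMVA
// Factory/DataLoader workflow): the method is booked via Types::kRuleFit and
// configured with the options declared in DeclareOptions() below. The option
// values shown are illustrative.
//
//    factory->BookMethod( dataloader, TMVA::Types::kRuleFit, "RuleFit",
//                         "Model=ModRuleLinear:RuleFitModule=RFTMVA:"
//                         "GDStep=0.01:GDNSteps=10000:GDErrScale=1.1:"
//                         "fEventsMin=0.1:fEventsMax=0.9:nTrees=20" );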

#include "TMVA/MethodRuleFit.h"

#include "TMVA/ClassifierFactory.h"
#include "TMVA/Config.h"
#include "TMVA/Configurable.h"
#include "TMVA/CrossEntropy.h"
#include "TMVA/DataSet.h"
#include "TMVA/DecisionTree.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/IMethod.h"
#include "TMVA/MethodBase.h"
#include "TMVA/MisClassificationError.h"
#include "TMVA/MsgLogger.h"
#include "TMVA/Ranking.h"
#include "TMVA/RuleFitAPI.h"
#include "TMVA/SdivSqrtSplusB.h"
#include "TMVA/SeparationBase.h"
#include "TMVA/Timer.h"
#include "TMVA/Tools.h"
#include "TMVA/Types.h"

#include "TRandom3.h"
#include "TMatrix.h"

#include <iostream>
#include <iomanip>
#include <algorithm>
#include <list>
#include <random>
using std::min;

REGISTER_METHOD(RuleFit)


////////////////////////////////////////////////////////////////////////////////
/// standard constructor

TMVA::MethodRuleFit::MethodRuleFit( const TString& jobName,
                                    const TString& methodTitle,
                                    DataSetInfo& theData,
                                    const TString& theOption) :
   MethodBase( jobName, Types::kRuleFit, methodTitle, theData, theOption)
   , fSignalFraction(0)
   , fNTImportance(0)
   , fNTCoefficient(0)
   , fNTSupport(0)
   , fNTNcuts(0)
   , fNTNvars(0)
   , fNTPtag(0)
   , fNTPss(0)
   , fNTPsb(0)
   , fNTPbs(0)
   , fNTPbb(0)
   , fNTSSB(0)
   , fNTType(0)
   , fUseRuleFitJF(kFALSE)
   , fRFNrules(0)
   , fRFNendnodes(0)
   , fNTrees(0)
   , fTreeEveFrac(0)
   , fSepType(0)
   , fMinFracNEve(0)
   , fMaxFracNEve(0)
   , fNCuts(0)
   , fPruneMethod(TMVA::DecisionTree::kCostComplexityPruning)
   , fPruneStrength(0)
   , fUseBoost(kFALSE)
   , fGDPathEveFrac(0)
   , fGDValidEveFrac(0)
   , fGDTau(0)
   , fGDTauPrec(0)
   , fGDTauMin(0)
   , fGDTauMax(0)
   , fGDTauScan(0)
   , fGDPathStep(0)
   , fGDNPathSteps(0)
   , fGDErrScale(0)
   , fMinimp(0)
   , fRuleMinDist(0)
   , fLinQuantile(0)
{
   fMonitorNtuple = NULL;
}

////////////////////////////////////////////////////////////////////////////////
/// constructor from weight file

TMVA::MethodRuleFit::MethodRuleFit( DataSetInfo& theData,
                                    const TString& theWeightFile) :
   MethodBase( Types::kRuleFit, theData, theWeightFile)
   , fSignalFraction(0)
   , fNTImportance(0)
   , fNTCoefficient(0)
   , fNTSupport(0)
   , fNTNcuts(0)
   , fNTNvars(0)
   , fNTPtag(0)
   , fNTPss(0)
   , fNTPsb(0)
   , fNTPbs(0)
   , fNTPbb(0)
   , fNTSSB(0)
   , fNTType(0)
   , fUseRuleFitJF(kFALSE)
   , fRFNrules(0)
   , fRFNendnodes(0)
   , fNTrees(0)
   , fTreeEveFrac(0)
   , fSepType(0)
   , fMinFracNEve(0)
   , fMaxFracNEve(0)
   , fNCuts(0)
   , fPruneMethod(TMVA::DecisionTree::kCostComplexityPruning)
   , fPruneStrength(0)
   , fUseBoost(kFALSE)
   , fGDPathEveFrac(0)
   , fGDValidEveFrac(0)
   , fGDTau(0)
   , fGDTauPrec(0)
   , fGDTauMin(0)
   , fGDTauMax(0)
   , fGDTauScan(0)
   , fGDPathStep(0)
   , fGDNPathSteps(0)
   , fGDErrScale(0)
   , fMinimp(0)
   , fRuleMinDist(0)
   , fLinQuantile(0)
{
   fMonitorNtuple = NULL;
}

////////////////////////////////////////////////////////////////////////////////
/// destructor

TMVA::MethodRuleFit::~MethodRuleFit( void )
{
   for (UInt_t i=0; i<fEventSample.size(); i++) delete fEventSample[i];
   for (UInt_t i=0; i<fForest.size(); i++)      delete fForest[i];
}

////////////////////////////////////////////////////////////////////////////////
/// RuleFit can handle classification with 2 classes

Bool_t TMVA::MethodRuleFit::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/ )
{
   if (type == Types::kClassification && numberClasses == 2) return kTRUE;
   return kFALSE;
}

////////////////////////////////////////////////////////////////////////////////
/// define the options (their key words) that can be set in the option string;
/// known options:
///
/// #### General
///
///  - RuleFitModule `<string>`
///    available values are:
///    - RFTMVA     - use the TMVA implementation
///    - RFFriedman - use Friedman's original implementation
///
/// #### Path search (fitting)
///
///  - GDTau       `<float>`  gradient-directed path: fit threshold, default
///  - GDTauPrec   `<float>`  gradient-directed path: precision of estimated tau
///  - GDStep      `<float>`  gradient-directed path: step size
///  - GDNSteps    `<int>`    gradient-directed path: number of steps
///  - GDErrScale  `<float>`  stop scan when error > scale*errmin
///
/// #### Tree generation
///
///  - fEventsMin  `<float>`  minimum fraction of events in a splittable node
///  - fEventsMax  `<float>`  maximum fraction of events in a splittable node
///  - nTrees      `<int>`    number of trees in the forest
///  - ForestType  `<string>`
///    available values are:
///    - Random   - create forest using a random subsample and only a random variable subset at each node
///    - AdaBoost - create forest with boosted events
///
/// #### Model creation
///
///  - RuleMinDist `<float>`  minimum distance allowed between rules
///  - MinImp      `<float>`  minimum rule importance accepted
///  - Model       `<string>` model to be used
///    available values are:
///    - ModRuleLinear `<default>`
///    - ModRule
///    - ModLinear
///
/// #### Friedman's module
///
///  - RFWorkDir   `<string>` directory where Friedman's module (rf_go.exe) is installed
///  - RFNrules    `<int>`    maximum number of rules allowed
///  - RFNendnodes `<int>`    average number of end nodes in the forest of trees

void TMVA::MethodRuleFit::DeclareOptions()
{
   DeclareOptionRef(fGDTau=-1, "GDTau", "Gradient-directed (GD) path: default fit cut-off");
   DeclareOptionRef(fGDTauPrec=0.01, "GDTauPrec", "GD path: precision of tau");
   DeclareOptionRef(fGDPathStep=0.01, "GDStep", "GD path: step size");
   DeclareOptionRef(fGDNPathSteps=10000, "GDNSteps", "GD path: number of steps");
   DeclareOptionRef(fGDErrScale=1.1, "GDErrScale", "Stop scan when error > scale*errmin");
   DeclareOptionRef(fLinQuantile, "LinQuantile", "Quantile of linear terms (removes outliers)");
   DeclareOptionRef(fGDPathEveFrac=0.5, "GDPathEveFrac", "Fraction of events used for the path search");
   DeclareOptionRef(fGDValidEveFrac=0.5, "GDValidEveFrac", "Fraction of events used for the validation");
   // tree options
   DeclareOptionRef(fMinFracNEve=0.1, "fEventsMin", "Minimum fraction of events in a splittable node");
   DeclareOptionRef(fMaxFracNEve=0.9, "fEventsMax", "Maximum fraction of events in a splittable node");
   DeclareOptionRef(fNTrees=20, "nTrees", "Number of trees in forest.");

   DeclareOptionRef(fForestTypeS="AdaBoost", "ForestType", "Method to use for forest generation (AdaBoost or RandomForest)");
   AddPreDefVal(TString("AdaBoost"));
   AddPreDefVal(TString("Random"));
   // rule cleanup options
   DeclareOptionRef(fRuleMinDist=0.001, "RuleMinDist", "Minimum distance between rules");
   DeclareOptionRef(fMinimp=0.01, "MinImp", "Minimum rule importance accepted");
   // rule model option
   DeclareOptionRef(fModelTypeS="ModRuleLinear", "Model", "Model to be used");
   AddPreDefVal(TString("ModRule"));
   AddPreDefVal(TString("ModRuleLinear"));
   AddPreDefVal(TString("ModLinear"));
   DeclareOptionRef(fRuleFitModuleS="RFTMVA", "RuleFitModule", "Which RuleFit module to use");
   AddPreDefVal(TString("RFTMVA"));
   AddPreDefVal(TString("RFFriedman"));

   DeclareOptionRef(fRFWorkDir="./rulefit", "RFWorkDir", "Friedman's RuleFit module (RFF): working dir");
   DeclareOptionRef(fRFNrules=2000, "RFNrules", "RFF: Maximum number of rules");
   DeclareOptionRef(fRFNendnodes=4, "RFNendnodes", "RFF: Average number of end nodes");
}
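
// Illustrative option strings (sketches, not exhaustive): the defaults above
// correspond to the TMVA module; Friedman's original implementation is
// selected via the module option plus its RFF settings, e.g.
//
//    "RuleFitModule=RFTMVA:Model=ModRuleLinear:nTrees=20"
//    "RuleFitModule=RFFriedman:RFWorkDir=./rulefit:RFNrules=2000:RFNendnodes=4"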

////////////////////////////////////////////////////////////////////////////////
/// process the options specified by the user

void TMVA::MethodRuleFit::ProcessOptions()
{
   if (IgnoreEventsWithNegWeightsInTraining()) {
      Log() << kFATAL << "Mechanism to ignore events with negative weights in training not yet available for method: "
            << GetMethodTypeName()
            << " --> please remove \"IgnoreNegWeightsInTraining\" option from booking string."
            << Endl;
   }

   fRuleFitModuleS.ToLower();
   if      (fRuleFitModuleS == "rftmva")     fUseRuleFitJF = kFALSE;
   else if (fRuleFitModuleS == "rffriedman") fUseRuleFitJF = kTRUE;
   else                                      fUseRuleFitJF = kTRUE;

   fSepTypeS.ToLower();
   if      (fSepTypeS == "misclassificationerror") fSepType = new MisClassificationError();
   else if (fSepTypeS == "giniindex")              fSepType = new GiniIndex();
   else if (fSepTypeS == "crossentropy")           fSepType = new CrossEntropy();
   else                                            fSepType = new SdivSqrtSplusB();

   fModelTypeS.ToLower();
   if      (fModelTypeS == "modlinear" ) fRuleFit.SetModelLinear();
   else if (fModelTypeS == "modrule" )   fRuleFit.SetModelRules();
   else                                  fRuleFit.SetModelFull();

   fPruneMethodS.ToLower();
   if      (fPruneMethodS == "expectederror" )  fPruneMethod = DecisionTree::kExpectedErrorPruning;
   else if (fPruneMethodS == "costcomplexity" ) fPruneMethod = DecisionTree::kCostComplexityPruning;
   else                                         fPruneMethod = DecisionTree::kNoPruning;

   fForestTypeS.ToLower();
   if      (fForestTypeS == "random" )   fUseBoost = kFALSE;
   else if (fForestTypeS == "adaboost" ) fUseBoost = kTRUE;
   else                                  fUseBoost = kTRUE;
   //
   // if creating the forest by boosting the events,
   // the full training sample is used per tree
   // -> only true for the TMVA version of RuleFit.
   if (fUseBoost && (!fUseRuleFitJF)) fTreeEveFrac = 1.0;

   // check event fraction for tree generation;
   // if <= 0, set to an automatic number
   if (fTreeEveFrac<=0) {
      Int_t nevents = Data()->GetNTrainingEvents();
      Double_t n = static_cast<Double_t>(nevents);
      fTreeEveFrac = min( 0.5, (100.0 + 6.0*sqrt(n))/n );
   }
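   // A quick check of the automatic setting above (hand-evaluated numbers,
   // not produced by this code): for n = 10000 training events,
   // (100 + 6*sqrt(10000))/10000 = 700/10000 = 0.07, so fTreeEveFrac = 0.07;
   // for n = 400, (100 + 6*20)/400 = 0.55, and the fraction saturates at 0.5.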
   // verify ranges of options
   VerifyRange(Log(), "nTrees",        fNTrees, 0, 100000, 20);
   VerifyRange(Log(), "MinImp",        fMinimp, 0.0, 1.0, 0.0);
   VerifyRange(Log(), "GDTauPrec",     fGDTauPrec, 1e-5, 5e-1);
   VerifyRange(Log(), "GDTauMin",      fGDTauMin, 0.0, 1.0);
   VerifyRange(Log(), "GDTauMax",      fGDTauMax, fGDTauMin, 1.0);
   VerifyRange(Log(), "GDPathStep",    fGDPathStep, 0.0, 100.0, 0.01);
   VerifyRange(Log(), "GDErrScale",    fGDErrScale, 1.0, 100.0, 1.1);
   VerifyRange(Log(), "GDPathEveFrac", fGDPathEveFrac, 0.01, 0.9, 0.5);
   VerifyRange(Log(), "GDValidEveFrac",fGDValidEveFrac, 0.01, 1.0-fGDPathEveFrac, 1.0-fGDPathEveFrac);
   VerifyRange(Log(), "fEventsMin",    fMinFracNEve, 0.0, 1.0);
   VerifyRange(Log(), "fEventsMax",    fMaxFracNEve, fMinFracNEve, 1.0);

   fRuleFit.GetRuleEnsemblePtr()->SetLinQuantile(fLinQuantile);
   fRuleFit.GetRuleFitParamsPtr()->SetGDTauRange(fGDTauMin,fGDTauMax);
   fRuleFit.GetRuleFitParamsPtr()->SetGDTau(fGDTau);
   fRuleFit.GetRuleFitParamsPtr()->SetGDTauPrec(fGDTauPrec);
   fRuleFit.GetRuleFitParamsPtr()->SetGDTauScan(fGDTauScan);
   fRuleFit.GetRuleFitParamsPtr()->SetGDPathStep(fGDPathStep);
   fRuleFit.GetRuleFitParamsPtr()->SetGDNPathSteps(fGDNPathSteps);
   fRuleFit.GetRuleFitParamsPtr()->SetGDErrScale(fGDErrScale);
   fRuleFit.SetImportanceCut(fMinimp);
   fRuleFit.SetRuleMinDist(fRuleMinDist);


   // check if Friedman's module is used;
   // print a message concerning the options.
   if (fUseRuleFitJF) {
      Log() << kINFO << "" << Endl;
      Log() << kINFO << "--------------------------------------" << Endl;
      Log() << kINFO << "Friedman's RuleFit module is selected." << Endl;
      Log() << kINFO << "Only the following options are used:" << Endl;
      Log() << kINFO << Endl;
      Log() << kINFO << gTools().Color("bold") << "   Model"        << gTools().Color("reset") << Endl;
      Log() << kINFO << gTools().Color("bold") << "   RFWorkDir"    << gTools().Color("reset") << Endl;
      Log() << kINFO << gTools().Color("bold") << "   RFNrules"     << gTools().Color("reset") << Endl;
      Log() << kINFO << gTools().Color("bold") << "   RFNendnodes"  << gTools().Color("reset") << Endl;
      Log() << kINFO << gTools().Color("bold") << "   GDNPathSteps" << gTools().Color("reset") << Endl;
      Log() << kINFO << gTools().Color("bold") << "   GDPathStep"   << gTools().Color("reset") << Endl;
      Log() << kINFO << gTools().Color("bold") << "   GDErrScale"   << gTools().Color("reset") << Endl;
      Log() << kINFO << "--------------------------------------" << Endl;
      Log() << kINFO << Endl;
   }

   // Select what weight to use in the 'importance' rule visualisation plots.
   // Note that if UseCoefficientsVisHists() is selected, the following weight is used:
   //    w = rule coefficient * rule support
   // The support is a positive number which is 0 if no events are accepted by the rule.
   // Normally the importance gives more useful information.
   //
   //fRuleFit.UseCoefficientsVisHists();
   fRuleFit.UseImportanceVisHists();

   fRuleFit.SetMsgType( Log().GetMinType() );

   if (HasTrainingTree()) InitEventSample();

}

////////////////////////////////////////////////////////////////////////////////
/// initialize the monitoring ntuple

void TMVA::MethodRuleFit::InitMonitorNtuple()
{
   BaseDir()->cd();
   fMonitorNtuple = new TTree("MonitorNtuple_RuleFit","RuleFit variables");
   fMonitorNtuple->Branch("importance",&fNTImportance,"importance/D");
   fMonitorNtuple->Branch("support",&fNTSupport,"support/D");
   fMonitorNtuple->Branch("coefficient",&fNTCoefficient,"coefficient/D");
   fMonitorNtuple->Branch("ncuts",&fNTNcuts,"ncuts/I");
   fMonitorNtuple->Branch("nvars",&fNTNvars,"nvars/I");
   fMonitorNtuple->Branch("type",&fNTType,"type/I");
   fMonitorNtuple->Branch("ptag",&fNTPtag,"ptag/D");
   fMonitorNtuple->Branch("pss",&fNTPss,"pss/D");
   fMonitorNtuple->Branch("psb",&fNTPsb,"psb/D");
   fMonitorNtuple->Branch("pbs",&fNTPbs,"pbs/D");
   fMonitorNtuple->Branch("pbb",&fNTPbb,"pbb/D");
   fMonitorNtuple->Branch("soversb",&fNTSSB,"soversb/D");
}
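
// Reading the monitor ntuple back offline (a sketch; the file name and the
// in-file directory path are assumptions based on the usual TMVA output
// layout created under BaseDir(), adjust them to your job):
//
//    TFile* f = TFile::Open("TMVA.root");
//    TTree* t = nullptr;
//    f->GetObject("dataset/Method_RuleFit/RuleFit/MonitorNtuple_RuleFit", t);
//    if (t) t->Draw("importance");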

////////////////////////////////////////////////////////////////////////////////
/// default initialization

void TMVA::MethodRuleFit::Init( void )
{
   // the minimum requirement to declare an event signal-like
   SetSignalReferenceCut( 0.0 );

   // set variables that used to be options;
   // any modifications are then made in ProcessOptions()
   fLinQuantile   = 0.025;       // Quantile of linear terms (remove outliers)
   fTreeEveFrac   = -1.0;        // Fraction of events used to train each tree
   fNCuts         = 20;          // Number of steps during node cut optimisation
   fSepTypeS      = "GiniIndex"; // Separation criterion for node splitting; see BDT
   fPruneMethodS  = "NONE";      // Pruning method; see BDT
   fPruneStrength = 3.5;         // Pruning strength; see BDT
   fGDTauMin      = 0.0;         // Gradient-directed path: min fit threshold (tau)
   fGDTauMax      = 1.0;         // Gradient-directed path: max fit threshold (tau)
   fGDTauScan     = 1000;        // Gradient-directed path: number of points scanning for best tau

}

////////////////////////////////////////////////////////////////////////////////
/// write all Events from the Tree into a vector of Events that is
/// more easily manipulated.
/// This method should never be called without an existing trainingTree, as it
/// reads the vector of events from the ROOT training tree.

void TMVA::MethodRuleFit::InitEventSample( void )
{
   if (Data()->GetNEvents()==0) Log() << kFATAL << "<Init> Data().TrainingTree() is zero pointer" << Endl;

   Int_t nevents = Data()->GetNEvents();
   for (Int_t ievt=0; ievt<nevents; ievt++){
      const Event * ev = GetEvent(ievt);
      fEventSample.push_back( new Event(*ev) );
   }
   if (fTreeEveFrac<=0) {
      Double_t n = static_cast<Double_t>(nevents);
      fTreeEveFrac = min( 0.5, (100.0 + 6.0*sqrt(n))/n );
   }
   if (fTreeEveFrac>1.0) fTreeEveFrac = 1.0;
   //
   std::shuffle(fEventSample.begin(), fEventSample.end(), std::default_random_engine{});
   //
   Log() << kDEBUG << "Set sub-sample fraction to " << fTreeEveFrac << Endl;
}

////////////////////////////////////////////////////////////////////////////////

void TMVA::MethodRuleFit::Train( void )
{
   TMVA::DecisionTreeNode::SetIsTraining(true);
   // training of rules

   if(!IsSilentFile()) InitMonitorNtuple();

   // fill the STL Vector with the event sample
   this->InitEventSample();

   if (fUseRuleFitJF) {
      TrainJFRuleFit();
   }
   else {
      TrainTMVARuleFit();
   }
   fRuleFit.GetRuleEnsemblePtr()->ClearRuleMap();
   TMVA::DecisionTreeNode::SetIsTraining(false);
   ExitFromTraining();
}

////////////////////////////////////////////////////////////////////////////////
/// training of rules using TMVA implementation

void TMVA::MethodRuleFit::TrainTMVARuleFit( void )
{
   if (IsNormalised()) Log() << kFATAL << "\"Normalise\" option cannot be used with RuleFit; "
                             << "please remove the option from the configuration string, or "
                             << "use \"!Normalise\""
                             << Endl;

   // timer
   Timer timer( 1, GetName() );

   // test tree nmin cut -> for debug purposes;
   // the routine will generate trees with a stopping cut on N(eve) given by
   // a fraction between [20,N(eve)-1].
   //
   // MakeForestRnd();
   // exit(1);
   //

   // Init RuleFit object and create rule ensemble
   // + make forest & rules
   fRuleFit.Initialize( this );

   // Make forest of decision trees
   // if (fRuleFit.GetRuleEnsemble().DoRules()) fRuleFit.MakeForest();

   // Fit the rules
   Log() << kDEBUG << "Fitting rule coefficients ..." << Endl;
   fRuleFit.FitCoefficients();

   // Calculate importance
   Log() << kDEBUG << "Computing rule and variable importance" << Endl;
   fRuleFit.CalcImportance();

   // Output results and fill monitor ntuple
   fRuleFit.GetRuleEnsemblePtr()->Print();
   //
   if(!IsSilentFile())
   {
      Log() << kDEBUG << "Filling rule ntuple" << Endl;
      UInt_t nrules = fRuleFit.GetRuleEnsemble().GetRulesConst().size();
      const Rule *rule;
      for (UInt_t i=0; i<nrules; i++ ) {
         rule           = fRuleFit.GetRuleEnsemble().GetRulesConst(i);
         fNTImportance  = rule->GetRelImportance();
         fNTSupport     = rule->GetSupport();
         fNTCoefficient = rule->GetCoefficient();
         fNTType        = (rule->IsSignalRule() ? 1:-1 );
         fNTNvars       = rule->GetRuleCut()->GetNvars();
         fNTNcuts       = rule->GetRuleCut()->GetNcuts();
         fNTPtag        = fRuleFit.GetRuleEnsemble().GetRulePTag(i); // should be identical with support
         fNTPss         = fRuleFit.GetRuleEnsemble().GetRulePSS(i);
         fNTPsb         = fRuleFit.GetRuleEnsemble().GetRulePSB(i);
         fNTPbs         = fRuleFit.GetRuleEnsemble().GetRulePBS(i);
         fNTPbb         = fRuleFit.GetRuleEnsemble().GetRulePBB(i);
         fNTSSB         = rule->GetSSB();
         fMonitorNtuple->Fill();
      }

      fRuleFit.MakeVisHists();
      fRuleFit.MakeDebugHists();
   }
   Log() << kDEBUG << "Training done" << Endl;

}

////////////////////////////////////////////////////////////////////////////////
/// training of rules using Jerome Friedman's implementation

void TMVA::MethodRuleFit::TrainJFRuleFit( void )
{
   fRuleFit.InitPtrs( this );
   Data()->SetCurrentType(Types::kTraining);
   UInt_t nevents = Data()->GetNTrainingEvents();
   std::vector<const TMVA::Event*> tmp;
   for (UInt_t ievt=0; ievt<nevents; ievt++) {
      const Event *event = GetEvent(ievt);
      tmp.push_back(event);
   }
   fRuleFit.SetTrainingEvents( tmp );

   RuleFitAPI *rfAPI = new RuleFitAPI( this, &fRuleFit, Log().GetMinType() );

   rfAPI->WelcomeMessage();

   // timer
   Timer timer( 1, GetName() );

   Log() << kINFO << "Training ..." << Endl;
   rfAPI->TrainRuleFit();

   Log() << kDEBUG << "reading model summary from rf_go.exe output" << Endl;
   rfAPI->ReadModelSum();

   // fRuleFit.GetRuleEnsemblePtr()->MakeRuleMap();

   Log() << kDEBUG << "calculating rule and variable importance" << Endl;
   fRuleFit.CalcImportance();

   // Output results and fill monitor ntuple
   fRuleFit.GetRuleEnsemblePtr()->Print();
   //
   if(!IsSilentFile()) fRuleFit.MakeVisHists();

   delete rfAPI;

   Log() << kDEBUG << "done training" << Endl;
}

////////////////////////////////////////////////////////////////////////////////
/// computes ranking of input variables

const TMVA::Ranking* TMVA::MethodRuleFit::CreateRanking()
{
   // create the ranking object
   fRanking = new Ranking( GetName(), "Importance" );

   for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
      fRanking->AddRank( Rank( GetInputLabel(ivar), fRuleFit.GetRuleEnsemble().GetVarImportance(ivar) ) );
   }

   return fRanking;
}

////////////////////////////////////////////////////////////////////////////////
/// add the rules to XML node

void TMVA::MethodRuleFit::AddWeightsXMLTo( void* parent ) const
{
   fRuleFit.GetRuleEnsemble().AddXMLTo( parent );
}

////////////////////////////////////////////////////////////////////////////////
/// read rules from an std::istream

void TMVA::MethodRuleFit::ReadWeightsFromStream( std::istream& istr )
{
   fRuleFit.GetRuleEnsemblePtr()->ReadRaw( istr );
}

////////////////////////////////////////////////////////////////////////////////
/// read rules from XML node

void TMVA::MethodRuleFit::ReadWeightsFromXML( void* wghtnode )
{
   fRuleFit.GetRuleEnsemblePtr()->ReadFromXML( wghtnode );
}

////////////////////////////////////////////////////////////////////////////////
/// returns MVA value for given event

Double_t TMVA::MethodRuleFit::GetMvaValue( Double_t* err, Double_t* errUpper )
{
   // cannot determine error
   NoErrorCalc(err, errUpper);

   return fRuleFit.EvalEvent( *GetEvent() );
}
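
// Application-side sketch (assumes the standard TMVA::Reader workflow; the
// weight-file path and variable name are illustrative): GetMvaValue() above
// is what ultimately serves such an evaluation.
//
//    TMVA::Reader reader( "!Color:!Silent" );
//    Float_t var1;                      // one Float_t per input variable
//    reader.AddVariable( "var1", &var1 );
//    reader.BookMVA( "RuleFit", "dataset/weights/TMVAClassification_RuleFit.weights.xml" );
//    var1 = 1.5;                        // fill with the current event's value
//    Double_t mva = reader.EvaluateMVA( "RuleFit" );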

////////////////////////////////////////////////////////////////////////////////
/// write special monitoring histograms to file (here ntuple)

void TMVA::MethodRuleFit::WriteMonitoringHistosToFile( void ) const
{
   BaseDir()->cd();
   Log() << kINFO << "Write monitoring ntuple to file: " << BaseDir()->GetPath() << Endl;
   fMonitorNtuple->Write();
}

////////////////////////////////////////////////////////////////////////////////
/// write specific classifier response

void TMVA::MethodRuleFit::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
   Int_t dp = fout.precision();
   fout << "   // not implemented for class: \"" << className << "\"" << std::endl;
   fout << "};" << std::endl;
   fout << "void   " << className << "::Initialize(){}" << std::endl;
   fout << "void   " << className << "::Clear(){}" << std::endl;
   fout << "double " << className << "::GetMvaValue__( const std::vector<double>& inputValues ) const {" << std::endl;
   fout << "   double rval=" << std::setprecision(10) << fRuleFit.GetRuleEnsemble().GetOffset() << ";" << std::endl;
   MakeClassRuleCuts(fout);
   MakeClassLinear(fout);
   fout << "   return rval;" << std::endl;
   fout << "}" << std::endl;
   fout << std::setprecision(dp);
}

////////////////////////////////////////////////////////////////////////////////
/// print out the rule cuts

void TMVA::MethodRuleFit::MakeClassRuleCuts( std::ostream& fout ) const
{
   Int_t dp = fout.precision();
   if (!fRuleFit.GetRuleEnsemble().DoRules()) {
      fout << "   //" << std::endl;
      fout << "   // ==> MODEL CONTAINS NO RULES <==" << std::endl;
      fout << "   //" << std::endl;
      return;
   }
   const RuleEnsemble *rens = &(fRuleFit.GetRuleEnsemble());
   const std::vector< Rule* > *rules = &(rens->GetRulesConst());
   const RuleCut *ruleCut;
   //
   std::list< std::pair<Double_t,Int_t> > sortedRules;
   for (UInt_t ir=0; ir<rules->size(); ir++) {
      sortedRules.push_back( std::pair<Double_t,Int_t>( (*rules)[ir]->GetImportance()/rens->GetImportanceRef(), ir ) );
   }
   sortedRules.sort();
   //
   fout << "   //" << std::endl;
   fout << "   // here follow all rules, ordered by importance (most important first)" << std::endl;
   fout << "   // at the end of each line, the relative importance of the rule is given" << std::endl;
   fout << "   //" << std::endl;
   //
   for ( std::list< std::pair<double,int> >::reverse_iterator itpair = sortedRules.rbegin();
         itpair != sortedRules.rend(); ++itpair ) {
      UInt_t ir     = itpair->second;
      Double_t impr = itpair->first;
      ruleCut = (*rules)[ir]->GetRuleCut();
      if (impr<rens->GetImportanceCut()) fout << "   //" << std::endl;
      fout << "   if (" << std::flush;
      for (UInt_t ic=0; ic<ruleCut->GetNvars(); ic++) {
         Double_t sel    = ruleCut->GetSelector(ic);
         Double_t valmin = ruleCut->GetCutMin(ic);
         Double_t valmax = ruleCut->GetCutMax(ic);
         Bool_t   domin  = ruleCut->GetCutDoMin(ic);
         Bool_t   domax  = ruleCut->GetCutDoMax(ic);
         //
         if (ic>0) fout << "&&" << std::flush;
         if (domin) {
            fout << "(" << std::setprecision(10) << valmin << std::flush;
            fout << "<inputValues[" << sel << "])" << std::flush;
         }
         if (domax) {
            if (domin) fout << "&&" << std::flush;
            fout << "(inputValues[" << sel << "]" << std::flush;
            fout << "<" << std::setprecision(10) << valmax << ")" << std::flush;
         }
      }
      fout << ") rval+=" << std::setprecision(10) << (*rules)[ir]->GetCoefficient() << ";" << std::flush;
      fout << "   // importance = " << TString::Format("%3.3f",impr) << std::endl;
   }
   fout << std::setprecision(dp);
}
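
// With made-up numbers, one emitted rule line from the loop above would read
// (illustrative only, not output of this build):
//
//    if ((1.234567891<inputValues[0])&&(inputValues[2]<3.456789012)) rval+=0.1234567891;   // importance = 0.853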

////////////////////////////////////////////////////////////////////////////////
/// print out the linear terms

void TMVA::MethodRuleFit::MakeClassLinear( std::ostream& fout ) const
{
   if (!fRuleFit.GetRuleEnsemble().DoLinear()) {
      fout << "   //" << std::endl;
      fout << "   // ==> MODEL CONTAINS NO LINEAR TERMS <==" << std::endl;
      fout << "   //" << std::endl;
      return;
   }
   fout << "   //" << std::endl;
   fout << "   // here follow all linear terms" << std::endl;
   fout << "   // at the end of each line, the relative importance of the term is given" << std::endl;
   fout << "   //" << std::endl;
   const RuleEnsemble *rens = &(fRuleFit.GetRuleEnsemble());
   UInt_t nlin = rens->GetNLinear();
   for (UInt_t il=0; il<nlin; il++) {
      if (rens->IsLinTermOK(il)) {
         Double_t norm = rens->GetLinNorm(il);
         Double_t imp  = rens->GetLinImportance(il)/rens->GetImportanceRef();
         fout << "   rval+="
            // << std::setprecision(10) << rens->GetLinCoefficients(il)*norm << "*std::min(" << setprecision(10) << rens->GetLinDP(il)
            // << ", std::max( inputValues[" << il << "]," << std::setprecision(10) << rens->GetLinDM(il) << "));"
              << std::setprecision(10) << rens->GetLinCoefficients(il)*norm
              << "*std::min( double(" << std::setprecision(10) << rens->GetLinDP(il)
              << "), std::max( double(inputValues[" << il << "]), double(" << std::setprecision(10) << rens->GetLinDM(il) << ")));"
              << std::flush;
         fout << "   // importance = " << TString::Format("%3.3f",imp) << std::endl;
      }
   }
}
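
// Again with made-up numbers, one emitted linear term would read
// (illustrative only):
//
//    rval+=0.25*std::min( double(4.2), std::max( double(inputValues[1]), double(-1.3)));   // importance = 0.421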

////////////////////////////////////////////////////////////////////////////////
/// get help message text
///
/// typical length of text line:
///   "|--------------------------------------------------------------|"

void TMVA::MethodRuleFit::GetHelpMessage() const
{
   TString col    = gConfig().WriteOptionsReference() ? TString() : gTools().Color("bold");
   TString colres = gConfig().WriteOptionsReference() ? TString() : gTools().Color("reset");
   TString brk    = gConfig().WriteOptionsReference() ? "<br>" : "";

   Log() << Endl;
   Log() << col << "--- Short description:" << colres << Endl;
   Log() << Endl;
   Log() << "This method uses a collection of so-called rules to create a" << Endl;
   Log() << "discriminating scoring function. Each rule consists of a series" << Endl;
   Log() << "of cuts in parameter space. The ensemble of rules is created" << Endl;
   Log() << "from a forest of decision trees, trained using the training data." << Endl;
   Log() << "Each node (apart from the root) corresponds to one rule." << Endl;
   Log() << "The scoring function is then obtained by linearly combining" << Endl;
   Log() << "the rules. A fitting procedure is applied to find the optimum" << Endl;
   Log() << "set of coefficients. The goal is to find a model with few rules" << Endl;
   Log() << "but with a strong discriminating power." << Endl;
   Log() << Endl;
   Log() << col << "--- Performance optimisation:" << colres << Endl;
   Log() << Endl;
   Log() << "There are two important considerations to make when optimising:" << Endl;
   Log() << Endl;
   Log() << "  1. Topology of the decision tree forest" << brk << Endl;
   Log() << "  2. Fitting of the coefficients" << Endl;
   Log() << Endl;
   Log() << "The maximum complexity of the rules is defined by the size of" << Endl;
   Log() << "the trees. Large trees will yield many complex rules and capture" << Endl;
   Log() << "higher order correlations. On the other hand, small trees will" << Endl;
   Log() << "lead to a smaller ensemble with simple rules, only capable of" << Endl;
   Log() << "modeling simple structures." << Endl;
   Log() << "Several parameters exist for controlling the complexity of the" << Endl;
   Log() << "rule ensemble." << Endl;
   Log() << Endl;
   Log() << "The fitting procedure searches for a minimum using a gradient-" << Endl;
   Log() << "directed path. Apart from step size and number of steps, the" << Endl;
   Log() << "evolution of the path is defined by a cut-off parameter, tau." << Endl;
   Log() << "This parameter is unknown and depends on the training data." << Endl;
   Log() << "A large value will tend to give large weights to a few rules." << Endl;
   Log() << "Similarly, a small value will lead to a large set of rules" << Endl;
   Log() << "with similar weights." << Endl;
   Log() << Endl;
   Log() << "A final point is the model used: rules and/or linear terms." << Endl;
   Log() << "For a given training sample, the result may improve by adding" << Endl;
   Log() << "linear terms. If best performance is obtained using only linear" << Endl;
   Log() << "terms, it is very likely that the Fisher discriminant would be" << Endl;
   Log() << "a better choice. Ideally the fitting procedure should be able to" << Endl;
   Log() << "make this choice by giving appropriate weights to either kind of term." << Endl;
   Log() << Endl;
   Log() << col << "--- Performance tuning via configuration options:" << colres << Endl;
   Log() << Endl;
   Log() << "I. TUNING OF RULE ENSEMBLE:" << Endl;
   Log() << Endl;
   Log() << "   " << col << "ForestType  " << colres
         << ": it is recommended to use the default \"AdaBoost\"." << brk << Endl;
   Log() << "   " << col << "nTrees      " << colres
         << ": more trees lead to more rules but also slower" << Endl;
   Log() << "                 performance. With too few trees the risk is" << Endl;
   Log() << "                 that the rule ensemble becomes too simple." << brk << Endl;
   Log() << "   " << col << "fEventsMin  " << colres << brk << Endl;
   Log() << "   " << col << "fEventsMax  " << colres
         << ": with a lower min, more large trees will be generated," << Endl;
   Log() << "                 leading to more complex rules." << Endl;
   Log() << "                 With a higher max, more small trees will be" << Endl;
   Log() << "                 generated, leading to simpler rules." << Endl;
   Log() << "                 By changing this range, the average complexity" << Endl;
   Log() << "                 of the rule ensemble can be controlled." << brk << Endl;
   Log() << "   " << col << "RuleMinDist " << colres
         << ": by increasing the minimum distance between" << Endl;
   Log() << "                 rules, fewer and more diverse rules will remain." << Endl;
   Log() << "                 Initially it is a good idea to keep this small" << Endl;
   Log() << "                 or zero and let the fitting do the selection of" << Endl;
   Log() << "                 rules. In order to reduce the ensemble size," << Endl;
   Log() << "                 the value can then be increased." << Endl;
   Log() << Endl;
   //         "|--------------------------------------------------------------|"
   Log() << "II. TUNING OF THE FITTING:" << Endl;
   Log() << Endl;
   Log() << "   " << col << "GDPathEveFrac " << colres
         << ": fraction of events used in the path evaluation." << Endl;
   Log() << "                 Increasing this fraction will improve the path" << Endl;
   Log() << "                 finding. However, too high a value will leave few" << Endl;
   Log() << "                 unique events available for error estimation." << Endl;
   Log() << "                 It is recommended to use the default = 0.5." << brk << Endl;
   Log() << "   " << col << "GDTau         " << colres
         << ": cut-off parameter tau." << Endl;
   Log() << "                 By default this value is set to -1.0." << Endl;
   //         "|----------------|---------------------------------------------|"
   Log() << "                 This means that the cut-off parameter is" << Endl;
   Log() << "                 automatically estimated. In most cases" << Endl;
   Log() << "                 this should be fine. However, you may want" << Endl;
   Log() << "                 to fix this value if you already know it" << Endl;
   Log() << "                 and want to reduce the training time." << brk << Endl;
   Log() << "   " << col << "GDTauPrec     " << colres
         << ": precision of the estimated tau." << Endl;
   Log() << "                 Increase this precision to find a more" << Endl;
   Log() << "                 optimal cut-off parameter." << brk << Endl;
   Log() << "   " << col << "GDNSteps      " << colres
         << ": number of steps in the path search." << Endl;
   Log() << "                 If the number of steps is too small, then" << Endl;
   Log() << "                 the program will give a warning message." << Endl;
   Log() << Endl;
   Log() << "III. WARNING MESSAGES" << Endl;
   Log() << Endl;
   Log() << col << "Risk(i+1)>=Risk(i) in path" << colres << brk << Endl;
   Log() << col << "Chaotic behaviour of risk evolution." << colres << Endl;
   //         "|----------------|---------------------------------------------|"
   Log() << "                 By construction the Risk should always decrease." << Endl;
   Log() << "                 However, if the training sample is too small or" << Endl;
   Log() << "                 the model is overtrained, such warnings can" << Endl;
   Log() << "                 occur." << Endl;
   Log() << "                 The warnings can safely be ignored if only a" << Endl;
   Log() << "                 few (<3) occur. If more warnings are generated," << Endl;
   Log() << "                 the fitting fails." << Endl;
   Log() << "                 A remedy may be to increase the value" << brk << Endl;
   Log() << "                 "
         << col << "GDValidEveFrac" << colres
         << " to 1.0 (or a larger value)." << brk << Endl;
   Log() << "                 In addition, if "
         << col << "GDPathEveFrac" << colres
         << " is too high," << Endl;
   Log() << "                 the same warnings may occur, since the events" << Endl;
   Log() << "                 used for error estimation are also used for" << Endl;
   Log() << "                 path estimation." << Endl;
   Log() << "                 Another possibility is to modify the model -" << Endl;
   Log() << "                 see above on tuning the rule ensemble." << Endl;
   Log() << Endl;
   Log() << col << "The error rate was still decreasing at the end of the path"
         << colres << Endl;
   Log() << "                 Too few steps in the path! Increase "
         << col << "GDNSteps" << colres << "." << Endl;
   Log() << Endl;
   Log() << col << "Reached minimum early in the search" << colres << Endl;

   Log() << "                 The minimum was found early in the fitting. This" << Endl;
   Log() << "                 may indicate that the step size "
         << col << "GDStep" << colres << Endl;
   Log() << "                 was too large. Reduce it and rerun." << Endl;
   Log() << "                 If the results still are not OK, modify the" << Endl;
   Log() << "                 model, either by modifying the rule ensemble" << Endl;
   Log() << "                 or by adding/removing linear terms." << Endl;
}