Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RuleFit.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : Rule *
8 * *
9 * *
10 * Description: *
11 * A class describing a 'rule' *
12 * Each internal node of a tree defines a rule from all the parental nodes. *
13 * A rule with 0 or 1 nodes in the list is a root rule -> corresponds to a0. *
14 * Input: a decision tree (in the constructor) *
15 * its coefficient *
16 * *
17 * *
18 * Authors (alphabetical): *
19 * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
20 * *
21 * Copyright (c) 2005: *
22 * CERN, Switzerland *
23 * Iowa State U. *
24 * MPI-K Heidelberg, Germany *
25 * *
26 * Redistribution and use in source and binary forms, with or without *
27 * modification, are permitted according to the terms listed in LICENSE *
28 * (see tmva/doc/LICENSE) *
29 **********************************************************************************/
30
31/*! \class TMVA::RuleFit
32\ingroup TMVA
33A class implementing various fits of rule ensembles
34*/
35#include "TMVA/RuleFit.h"
36
37#include "TMVA/DataSet.h"
38#include "TMVA/DecisionTree.h"
39#include "TMVA/Event.h"
40#include "TMVA/Factory.h" // for root base dir
41#include "TMVA/GiniIndex.h"
42#include "TMVA/MethodBase.h"
43#include "TMVA/MethodRuleFit.h"
44#include "TMVA/MsgLogger.h"
45#include "TMVA/Timer.h"
46#include "TMVA/Tools.h"
47#include "TMVA/Types.h"
48#include "TMVA/SeparationBase.h"
49
50#include "TDirectory.h"
51#include "TH2F.h"
52#include "TKey.h"
53#include "TRandom3.h"
54#include "TROOT.h" // for gROOT
55
56#include <algorithm>
57#include <random>
58
59
60////////////////////////////////////////////////////////////////////////////////
61/// constructor
62
64 : fVisHistsUseImp( kTRUE )
65 , fLogger(new MsgLogger("RuleFit"))
66{
68 fRNGEngine.seed(randSEED);
69}
70
71////////////////////////////////////////////////////////////////////////////////
72/// default constructor
73
75 : fNTreeSample(0)
76 , fNEveEffTrain(0)
77 , fMethodRuleFit(0)
78 , fMethodBase(0)
79 , fVisHistsUseImp(kTRUE)
80 , fLogger(new MsgLogger("RuleFit"))
81{
82 fRNGEngine.seed(randSEED);
83}
84
85////////////////////////////////////////////////////////////////////////////////
86/// destructor
87
89{
90 delete fLogger;
91}
92
93////////////////////////////////////////////////////////////////////////////////
94/// init effective number of events (using event weights)
95
97{
98 UInt_t neve = fTrainingEvents.size();
99 if (neve==0) return;
100 //
101 fNEveEffTrain = CalcWeightSum( &fTrainingEvents );
102 //
103}
104
105////////////////////////////////////////////////////////////////////////////////
106/// initialize pointers
107
109{
110 this->SetMethodBase(rfbase);
111 fRuleEnsemble.Initialize( this );
112 fRuleFitParams.SetRuleFit( this );
113}
114
115////////////////////////////////////////////////////////////////////////////////
116/// initialize the parameters of the RuleFit method and make rules
117
119{
120 InitPtrs(rfbase);
121
122 if (fMethodRuleFit){
123 fMethodRuleFit->Data()->SetCurrentType(Types::kTraining);
124 UInt_t nevents = fMethodRuleFit->Data()->GetNTrainingEvents();
125 std::vector<const TMVA::Event*> tmp;
126 for (Long64_t ievt=0; ievt<nevents; ievt++) {
127 const Event *event = fMethodRuleFit->GetEvent(ievt);
128 tmp.push_back(event);
129 }
130 SetTrainingEvents( tmp );
131 }
132 // SetTrainingEvents( fMethodRuleFit->GetTrainingEvents() );
133
134 InitNEveEff();
135
136 MakeForest();
137
138 // Make the model - Rule + Linear (if fDoLinear is true)
139 fRuleEnsemble.MakeModel();
140
141 // init rulefit params
142 fRuleFitParams.Init();
143
144}
145
146////////////////////////////////////////////////////////////////////////////////
147/// set MethodBase
148
150{
151 fMethodBase = rfbase;
152 fMethodRuleFit = dynamic_cast<const MethodRuleFit *>(rfbase);
153}
154
155////////////////////////////////////////////////////////////////////////////////
156/// copy method
157
159{
160 if(this != &other) {
161 fMethodRuleFit = other.GetMethodRuleFit();
162 fMethodBase = other.GetMethodBase();
163 fTrainingEvents = other.GetTrainingEvents();
164 // fSubsampleEvents = other.GetSubsampleEvents();
165
166 fForest = other.GetForest();
167 fRuleEnsemble = other.GetRuleEnsemble();
168 }
169}
170
171////////////////////////////////////////////////////////////////////////////////
172/// calculate the sum of weights
173
174Double_t TMVA::RuleFit::CalcWeightSum( const std::vector<const Event *> *events, UInt_t neve )
175{
176 if (events==0) return 0.0;
177 if (neve==0) neve=events->size();
178 //
179 Double_t sumw=0;
180 for (UInt_t ie=0; ie<neve; ie++) {
181 sumw += ((*events)[ie])->GetWeight();
182 }
183 return sumw;
184}
185
186////////////////////////////////////////////////////////////////////////////////
187/// set the current message type to that of mlog for this class and all other subtools
188
190{
191 fLogger->SetMinType(t);
192 fRuleEnsemble.SetMsgType(t);
193 fRuleFitParams.SetMsgType(t);
194}
195
196////////////////////////////////////////////////////////////////////////////////
197/// build the decision tree using fNTreeSample events from fTrainingEventsRndm
198
200{
201 if (dt==0) return;
202 if (fMethodRuleFit==0) {
203 Log() << kFATAL << "RuleFit::BuildTree() - Attempting to build a tree NOT from a MethodRuleFit" << Endl;
204 }
205 std::vector<const Event *> evevec;
206 for (UInt_t ie=0; ie<fNTreeSample; ie++) {
207 evevec.push_back(fTrainingEventsRndm[ie]);
208 }
209 dt->BuildTree(evevec);
210 if (fMethodRuleFit->GetPruneMethod() != DecisionTree::kNoPruning) {
211 dt->SetPruneMethod(fMethodRuleFit->GetPruneMethod());
212 dt->SetPruneStrength(fMethodRuleFit->GetPruneStrength());
213 dt->PruneTree();
214 }
215}
216
217////////////////////////////////////////////////////////////////////////////////
218/// make a forest of decisiontrees
219
221{
222 if (fMethodRuleFit==0) {
223 Log() << kFATAL << "RuleFit::BuildTree() - Attempting to build a tree NOT from a MethodRuleFit" << Endl;
224 }
225 Log() << kDEBUG << "Creating a forest with " << fMethodRuleFit->GetNTrees() << " decision trees" << Endl;
226 Log() << kDEBUG << "Each tree is built using a random subsample with " << fNTreeSample << " events" << Endl;
227 //
228 Timer timer( fMethodRuleFit->GetNTrees(), "RuleFit" );
229
230 //
232 //
233 // First save all event weights.
234 // Weights are modified by the boosting.
235 // Those weights we do not want for the later fitting.
236 //
237 Bool_t useBoost = fMethodRuleFit->UseBoost(); // (AdaBoost (True) or RandomForest/Tree (False)
238
239 if (useBoost) SaveEventWeights();
240
241 for (Int_t i=0; i<fMethodRuleFit->GetNTrees(); i++) {
242 // timer.DrawProgressBar(i);
243 if (!useBoost) ReshuffleEvents();
244
245 DecisionTree *dt=nullptr;
247 Int_t ntries=0;
248 const Int_t ntriesMax=10;
249 Double_t frnd = 0.;
250 while (tryAgain) {
251 frnd = 100*rndGen.Uniform( fMethodRuleFit->GetMinFracNEve(), 0.5*fMethodRuleFit->GetMaxFracNEve() );
252 Int_t iclass = 0; // event class being treated as signal during training
254 dt = new DecisionTree( fMethodRuleFit->GetSeparationBase(), frnd, fMethodRuleFit->GetNCuts(), &(fMethodRuleFit->DataInfo()), iclass, useRandomisedTree);
255 dt->SetNVars(fMethodBase->GetNvar());
256
257 BuildTree(dt); // reads fNTreeSample events from fTrainingEventsRndm
258 if (dt->GetNNodes()<3) {
259 delete dt;
260 dt=0;
261 }
262 ntries++;
263 tryAgain = ((dt==0) && (ntries<ntriesMax));
264 }
265 if (dt) {
266 fForest.push_back(dt);
267 if (useBoost) Boost(dt);
268
269 } else {
270
271 Log() << kWARNING << "------------------------------------------------------------------" << Endl;
272 Log() << kWARNING << " Failed growing a tree even after " << ntriesMax << " trials" << Endl;
273 Log() << kWARNING << " Possible solutions: " << Endl;
274 Log() << kWARNING << " 1. increase the number of training events" << Endl;
275 Log() << kWARNING << " 2. set a lower min fraction cut (fEventsMin)" << Endl;
276 Log() << kWARNING << " 3. maybe also decrease the max fraction cut (fEventsMax)" << Endl;
277 Log() << kWARNING << " If the above warning occurs rarely only, it can be ignored" << Endl;
278 Log() << kWARNING << "------------------------------------------------------------------" << Endl;
279 }
280
281 Log() << kDEBUG << "Built tree with minimum cut at N = " << frnd <<"% events"
282 << " => N(nodes) = " << fForest.back()->GetNNodes()
283 << " ; n(tries) = " << ntries
284 << Endl;
285 }
286
287 // Now restore event weights
288 if (useBoost) RestoreEventWeights();
289
290 // print statistics on the forest created
291 ForestStatistics();
292}
293
294////////////////////////////////////////////////////////////////////////////////
295/// save event weights - must be done before making the forest
296
298{
299 fEventWeights.clear();
300 for (std::vector<const Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); ++e) {
301 Double_t w = (*e)->GetBoostWeight();
302 fEventWeights.push_back(w);
303 }
304}
305
306////////////////////////////////////////////////////////////////////////////////
307/// restore the saved event weights - must be done after making the forest
308
310{
311 UInt_t ie=0;
312 if (fEventWeights.size() != fTrainingEvents.size()) {
313 Log() << kERROR << "RuleFit::RestoreEventWeights() called without having called SaveEventWeights() before!" << Endl;
314 return;
315 }
316 for (std::vector<const Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); ++e) {
317 (*e)->SetBoostWeight(fEventWeights[ie]);
318 ie++;
319 }
320}
321
322////////////////////////////////////////////////////////////////////////////////
323/// Boost the events. The algorithm below is called AdaBoost.
324/// See MethodBDT for details.
325/// Actually, this is a more or less copy of MethodBDT::AdaBoost().
326
328{
329 Double_t sumw=0; // sum of initial weights - all events
330 Double_t sumwfalse=0; // idem, only misclassified events
331 //
332 std::vector<Char_t> correctSelected; // <--- boolean stored
333 //
334 for (std::vector<const Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); ++e) {
335 Bool_t isSignalType = (dt->CheckEvent(*e,kTRUE) > 0.5 );
336 Double_t w = (*e)->GetWeight();
337 sumw += w;
338 //
339 if (isSignalType == fMethodBase->DataInfo().IsSignal(*e)) { // correctly classified
340 correctSelected.push_back(kTRUE);
341 }
342 else { // misclassified
343 sumwfalse+= w;
344 correctSelected.push_back(kFALSE);
345 }
346 }
347 // misclassification error
348 Double_t err = sumwfalse/sumw;
349 // calculate boost weight for misclassified events
350 // use for now the exponent = 1.0
351 // one could have w = ((1-err)/err)^beta
352 Double_t boostWeight = (err>0 ? (1.0-err)/err : 1000.0);
353 Double_t newSumw=0.0;
354 UInt_t ie=0;
355 // set new weight to misclassified events
356 for (std::vector<const Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); ++e) {
357 if (!correctSelected[ie])
358 (*e)->SetBoostWeight( (*e)->GetBoostWeight() * boostWeight);
359 newSumw+=(*e)->GetWeight();
360 ie++;
361 }
362 // reweight all events
364 for (std::vector<const Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); ++e) {
365 (*e)->SetBoostWeight( (*e)->GetBoostWeight() * scale);
366 }
367 Log() << kDEBUG << "boostWeight = " << boostWeight << " scale = " << scale << Endl;
368}
369
370////////////////////////////////////////////////////////////////////////////////
371/// summary of statistics of all trees
372/// - end-nodes: average and spread
373
375{
376 UInt_t ntrees = fForest.size();
377 if (ntrees==0) return;
378 const DecisionTree *tree;
379 Double_t sumn2 = 0;
380 Double_t sumn = 0;
381 Double_t nd;
382 for (UInt_t i=0; i<ntrees; i++) {
383 tree = fForest[i];
384 nd = Double_t(tree->GetNNodes());
385 sumn += nd;
386 sumn2 += nd*nd;
387 }
388 Double_t sig = TMath::Sqrt( gTools().ComputeVariance( sumn2, sumn, ntrees ));
389 Log() << kVERBOSE << "Nodes in trees: average & std dev = " << sumn/ntrees << " , " << sig << Endl;
390}
391
392////////////////////////////////////////////////////////////////////////////////
393///
394/// Fit the coefficients for the rule ensemble
395///
396
398{
399 Log() << kVERBOSE << "Fitting rule/linear terms" << Endl;
400 fRuleFitParams.MakeGDPath();
401}
402
403////////////////////////////////////////////////////////////////////////////////
404/// calculates the importance of each rule
405
407{
408 Log() << kVERBOSE << "Calculating importance" << Endl;
409 fRuleEnsemble.CalcImportance();
410 fRuleEnsemble.CleanupRules();
411 fRuleEnsemble.CleanupLinear();
412 fRuleEnsemble.CalcVarImportance();
413 Log() << kVERBOSE << "Filling rule statistics" << Endl;
414 fRuleEnsemble.RuleResponseStats();
415}
416
417////////////////////////////////////////////////////////////////////////////////
418/// evaluate single event
419
421{
422 return fRuleEnsemble.EvalEvent( e );
423}
424
425////////////////////////////////////////////////////////////////////////////////
426/// set the training events randomly
427
428void TMVA::RuleFit::SetTrainingEvents( const std::vector<const Event *>& el )
429{
430 if (fMethodRuleFit==0) Log() << kFATAL << "RuleFit::SetTrainingEvents - MethodRuleFit not initialized" << Endl;
431 UInt_t neve = el.size();
432 if (neve==0) Log() << kWARNING << "An empty sample of training events was given" << Endl;
433
434 // copy vector
435 fTrainingEvents.clear();
436 fTrainingEventsRndm.clear();
437 for (UInt_t i=0; i<neve; i++) {
438 fTrainingEvents.push_back(static_cast< const Event *>(el[i]));
439 fTrainingEventsRndm.push_back(static_cast< const Event *>(el[i]));
440 }
441
442 // Re-shuffle the vector, ie, recreate it in a random order
443 std::shuffle(fTrainingEventsRndm.begin(), fTrainingEventsRndm.end(), fRNGEngine);
444
445 // fraction events per tree
446 fNTreeSample = static_cast<UInt_t>(neve*fMethodRuleFit->GetTreeEveFrac());
447 Log() << kDEBUG << "Number of events per tree : " << fNTreeSample
448 << " ( N(events) = " << neve << " )"
449 << " randomly drawn without replacement" << Endl;
450}
451
452////////////////////////////////////////////////////////////////////////////////
453/// draw a random subsample of the training events without replacement
454
455void TMVA::RuleFit::GetRndmSampleEvents(std::vector< const Event * > & evevec, UInt_t nevents)
456{
457 ReshuffleEvents();
458 if ((nevents<fTrainingEventsRndm.size()) && (nevents>0)) {
459 evevec.resize(nevents);
460 for (UInt_t ie=0; ie<nevents; ie++) {
461 evevec[ie] = fTrainingEventsRndm[ie];
462 }
463 }
464 else {
465 Log() << kWARNING << "GetRndmSampleEvents() : requested sub sample size larger than total size (BUG!).";
466 }
467}
468////////////////////////////////////////////////////////////////////////////////
469/// normalize rule importance hists
470///
471/// if all weights are positive, the scale will be 1/maxweight
472/// if minimum weight < 0, then the scale will be 1/max(maxweight,abs(minweight))
473
474void TMVA::RuleFit::NormVisHists(std::vector<TH2F *> & hlist)
475{
476 if (hlist.empty()) return;
477 //
478 Double_t wmin=0;
479 Double_t wmax=0;
480 Double_t w,wm;
483 for (UInt_t i=0; i<hlist.size(); i++) {
484 TH2F *hs = hlist[i];
485 w = hs->GetMaximum();
486 wm = hs->GetMinimum();
487 if (i==0) {
488 wmin=wm;
489 wmax=w;
490 }
491 else {
492 if (w>wmax) wmax=w;
493 if (wm<wmin) wmin=wm;
494 }
495 }
498 if (awmin>wmax) {
499 scale = 1.0/awmin;
500 usemin = -1.0;
501 usemax = scale*wmax;
502 }
503 else {
504 scale = 1.0/wmax;
505 usemin = scale*wmin;
506 usemax = 1.0;
507 }
508
509 //
510 for (UInt_t i=0; i<hlist.size(); i++) {
511 TH2F *hs = hlist[i];
512 hs->Scale(scale);
513 hs->SetMinimum(usemin);
514 hs->SetMaximum(usemax);
515 }
516}
517
518////////////////////////////////////////////////////////////////////////////////
519/// Fill cut
520
522{
523 if (rule==0) return;
524 if (h2==0) return;
525 //
528 Bool_t ruleHasVar = rule->GetRuleCut()->GetCutRange(vind,rmin,rmax,dormin,dormax);
529 if (!ruleHasVar) return;
530 //
531 Int_t firstbin = h2->GetBin(1,1,1);
532 if(firstbin<0) firstbin=0;
533 Int_t lastbin = h2->GetBin(h2->GetNbinsX(),1,1);
534 Int_t binmin=(dormin ? h2->FindBin(rmin,0.5):firstbin);
535 Int_t binmax=(dormax ? h2->FindBin(rmax,0.5):lastbin);
536 Int_t fbin;
537 Double_t xbinw = h2->GetXaxis()->GetBinWidth(firstbin);
538 Double_t fbmin = h2->GetXaxis()->GetBinLowEdge(binmin-firstbin+1);
539 Double_t lbmax = h2->GetXaxis()->GetBinLowEdge(binmax-firstbin+1)+xbinw;
540 Double_t fbfrac = (dormin ? ((fbmin+xbinw-rmin)/xbinw):1.0);
541 Double_t lbfrac = (dormax ? ((rmax-lbmax+xbinw)/xbinw):1.0);
542 Double_t f;
543 Double_t xc;
544 Double_t val;
545
546 for (Int_t bin = binmin; bin<binmax+1; bin++) {
547 fbin = bin-firstbin+1;
548 if (bin==binmin) {
549 f = fbfrac;
550 }
551 else if (bin==binmax) {
552 f = lbfrac;
553 }
554 else {
555 f = 1.0;
556 }
557 xc = h2->GetXaxis()->GetBinCenter(fbin);
558 //
559 if (fVisHistsUseImp) {
560 val = rule->GetImportance();
561 }
562 else {
563 val = rule->GetCoefficient()*rule->GetSupport();
564 }
565 h2->Fill(xc,0.5,val*f);
566 }
567}
568
569////////////////////////////////////////////////////////////////////////////////
570/// fill lin
571
573{
574 if (h2==0) return;
575 if (!fRuleEnsemble.DoLinear()) return;
576 //
577 Int_t firstbin = 1;
578 Int_t lastbin = h2->GetNbinsX();
579 Double_t xc;
580 Double_t val;
581 if (fVisHistsUseImp) {
582 val = fRuleEnsemble.GetLinImportance(vind);
583 }
584 else {
585 val = fRuleEnsemble.GetLinCoefficients(vind);
586 }
587 for (Int_t bin = firstbin; bin<lastbin+1; bin++) {
588 xc = h2->GetXaxis()->GetBinCenter(bin);
589 h2->Fill(xc,0.5,val);
590 }
591}
592
593////////////////////////////////////////////////////////////////////////////////
594/// fill rule correlation between vx and vy, weighted with either the importance or the coefficient
595
597{
598 if (rule==0) return;
599 if (h2==0) return;
600 Double_t val;
601 if (fVisHistsUseImp) {
602 val = rule->GetImportance();
603 }
604 else {
605 val = rule->GetCoefficient()*rule->GetSupport();
606 }
607 //
610 //
611 // Get range in rule for X and Y
612 //
613 Bool_t ruleHasVarX = rule->GetRuleCut()->GetCutRange(vx,rxmin,rxmax,dorxmin,dorxmax);
614 Bool_t ruleHasVarY = rule->GetRuleCut()->GetCutRange(vy,rymin,rymax,dorymin,dorymax);
615 if (!(ruleHasVarX || ruleHasVarY)) return;
616 // min max of varX and varY in hist
617 Double_t vxmin = (dorxmin ? rxmin:h2->GetXaxis()->GetXmin());
618 Double_t vxmax = (dorxmax ? rxmax:h2->GetXaxis()->GetXmax());
619 Double_t vymin = (dorymin ? rymin:h2->GetYaxis()->GetXmin());
620 Double_t vymax = (dorymax ? rymax:h2->GetYaxis()->GetXmax());
621 // min max bin in X and Y
622 Int_t binxmin = h2->GetXaxis()->FindBin(vxmin);
623 Int_t binxmax = h2->GetXaxis()->FindBin(vxmax);
624 Int_t binymin = h2->GetYaxis()->FindBin(vymin);
625 Int_t binymax = h2->GetYaxis()->FindBin(vymax);
626 // bin widths
627 Double_t xbinw = h2->GetXaxis()->GetBinWidth(binxmin);
628 Double_t ybinw = h2->GetYaxis()->GetBinWidth(binxmin);
629 Double_t xbinmin = h2->GetXaxis()->GetBinLowEdge(binxmin);
630 Double_t xbinmax = h2->GetXaxis()->GetBinLowEdge(binxmax)+xbinw;
631 Double_t ybinmin = h2->GetYaxis()->GetBinLowEdge(binymin);
632 Double_t ybinmax = h2->GetYaxis()->GetBinLowEdge(binymax)+ybinw;
633 // fraction of edges
638 //
639 Double_t fx,fy;
640 Double_t xc,yc;
641 // fill histo
642 for (Int_t binx = binxmin; binx<binxmax+1; binx++) {
643 if (binx==binxmin) {
644 fx = fxbinmin;
645 }
646 else if (binx==binxmax) {
647 fx = fxbinmax;
648 }
649 else {
650 fx = 1.0;
651 }
652 xc = h2->GetXaxis()->GetBinCenter(binx);
653 for (Int_t biny = binymin; biny<binymax+1; biny++) {
654 if (biny==binymin) {
655 fy = fybinmin;
656 }
657 else if (biny==binymax) {
658 fy = fybinmax;
659 }
660 else {
661 fy = 1.0;
662 }
663 yc = h2->GetYaxis()->GetBinCenter(biny);
664 h2->Fill(xc,yc,val*fx*fy);
665 }
666 }
667}
668
669////////////////////////////////////////////////////////////////////////////////
670/// help routine to MakeVisHists() - fills for all variables
671
672void TMVA::RuleFit::FillVisHistCut(const Rule* rule, std::vector<TH2F *> & hlist)
673{
674 Int_t nhists = hlist.size();
675 Int_t nvar = fMethodBase->GetNvar();
676 if (nhists!=nvar) Log() << kFATAL << "BUG TRAP: number of hists is not equal the number of variables!" << Endl;
677 //
678 std::vector<Int_t> vindex;
680 // not a nice way to do a check...
681 for (Int_t ih=0; ih<nhists; ih++) {
682 hstr = hlist[ih]->GetTitle();
683 for (Int_t iv=0; iv<nvar; iv++) {
684 if (fMethodBase->GetInputTitle(iv) == hstr)
685 vindex.push_back(iv);
686 }
687 }
688 //
689 for (Int_t iv=0; iv<nvar; iv++) {
690 if (rule) {
691 if (rule->ContainsVariable(vindex[iv])) {
692 FillCut(hlist[iv],rule,vindex[iv]);
693 }
694 }
695 else {
696 FillLin(hlist[iv],vindex[iv]);
697 }
698 }
699}
700////////////////////////////////////////////////////////////////////////////////
701/// help routine to MakeVisHists() - fills for all correlation plots
702
703void TMVA::RuleFit::FillVisHistCorr(const Rule * rule, std::vector<TH2F *> & hlist)
704{
705 if (rule==0) return;
706 Double_t ruleimp = rule->GetImportance();
707 if (!(ruleimp>0)) return;
708 if (ruleimp<fRuleEnsemble.GetImportanceCut()) return;
709 //
710 Int_t nhists = hlist.size();
711 Int_t nvar = fMethodBase->GetNvar();
712 Int_t ncorr = (nvar*(nvar+1)/2)-nvar;
713 if (nhists!=ncorr) Log() << kERROR << "BUG TRAP: number of corr hists is not correct! ncorr = "
714 << ncorr << " nvar = " << nvar << " nhists = " << nhists << Endl;
715 //
716 std::vector< std::pair<Int_t,Int_t> > vindex;
718 Int_t iv1=0,iv2=0;
719 // not a nice way to do a check...
720 for (Int_t ih=0; ih<nhists; ih++) {
721 hstr = hlist[ih]->GetName();
722 if (GetCorrVars( hstr, var1, var2 )) {
723 iv1 = fMethodBase->DataInfo().FindVarIndex( var1 );
724 iv2 = fMethodBase->DataInfo().FindVarIndex( var2 );
725 vindex.push_back( std::pair<Int_t,Int_t>(iv2,iv1) ); // pair X, Y
726 }
727 else {
728 Log() << kERROR << "BUG TRAP: should not be here - failed getting var1 and var2" << Endl;
729 }
730 }
731 //
732 for (Int_t ih=0; ih<nhists; ih++) {
733 if ( (rule->ContainsVariable(vindex[ih].first)) ||
734 (rule->ContainsVariable(vindex[ih].second)) ) {
735 FillCorr(hlist[ih],rule,vindex[ih].first,vindex[ih].second);
736 }
737 }
738}
739////////////////////////////////////////////////////////////////////////////////
740/// get first and second variables from title
741
743{
744 var1="";
745 var2="";
746 if(!title.BeginsWith("scat_")) return kFALSE;
747
748 TString titleCopy = title(5,title.Length());
749 if(titleCopy.Index("_RF2D")>=0) titleCopy.Remove(titleCopy.Index("_RF2D"));
750
751 Int_t splitPos = titleCopy.Index("_vs_");
752 if(splitPos>=0) { // there is a _vs_ in the string
754 var2 = titleCopy(splitPos+4, titleCopy.Length());
755 return kTRUE;
756 }
757 else {
758 var1 = titleCopy;
759 return kFALSE;
760 }
761}
762////////////////////////////////////////////////////////////////////////////////
763/// this will create histograms visualizing the rule ensemble
764
766{
767 const TString directories[5] = { "InputVariables_Id",
768 "InputVariables_Deco",
769 "InputVariables_PCA",
770 "InputVariables_Gauss",
771 "InputVariables_Gauss_Deco" };
772
773 const TString corrDirName = "CorrelationPlots";
774
775 TDirectory* rootDir = fMethodBase->GetFile();
776 TDirectory* varDir = 0;
777 TDirectory* corrDir = 0;
778
779 TDirectory* methodDir = fMethodBase->BaseDir();
781 //
782 Bool_t done=(rootDir==0);
783 Int_t type=0;
784 if (done) {
785 Log() << kWARNING << "No basedir - BUG??" << Endl;
786 return;
787 }
788 while (!done) {
790 type++;
791 done = ((varDir!=0) || (type>4));
792 }
793 if (varDir==0) {
794 Log() << kWARNING << "No input variable directory found - BUG?" << Endl;
795 return;
796 }
798 if (corrDir==0) {
799 Log() << kWARNING << "No correlation directory found" << Endl;
800 Log() << kWARNING << "Check for other warnings related to correlation histograms" << Endl;
801 return;
802 }
803 if (methodDir==0) {
804 Log() << kWARNING << "No rulefit method directory found - BUG?" << Endl;
805 return;
806 }
807
808 varDirName = varDir->GetName();
809 varDir->cd();
810 //
811 // get correlation plot directory
813 if (corrDir==0) {
814 Log() << kWARNING << "No correlation directory found : " << corrDirName << Endl;
815 return;
816 }
817
818 // how many plots are in the var directory?
819 Int_t noPlots = ((varDir->GetListOfKeys())->GetEntries()) / 2;
820 Log() << kDEBUG << "Got number of plots = " << noPlots << Endl;
821
822 // loop over all objects in directory
823 std::vector<TH2F *> h1Vector;
824 std::vector<TH2F *> h2CorrVector;
825 TIter next(varDir->GetListOfKeys());
826 TKey *key;
827 while ((key = (TKey*)next())) {
828 // make sure, that we only look at histograms
829 TClass *cl = gROOT->GetClass(key->GetClassName());
830 if (!cl->InheritsFrom(TH1F::Class())) continue;
831 TH1F *sig = (TH1F*)key->ReadObj();
832 TString hname= sig->GetName();
833 Log() << kDEBUG << "Got histogram : " << hname << Endl;
834
835 // check for all signal histograms
836 if (hname.Contains("__S")){ // found a new signal plot
837 TString htitle = sig->GetTitle();
838 htitle.ReplaceAll("signal","");
840 newname.ReplaceAll("__Signal","__RF");
841 newname.ReplaceAll("__S","__RF");
842
843 methodDir->cd();
844 TH2F *newhist = new TH2F(newname,htitle,sig->GetNbinsX(),sig->GetXaxis()->GetXmin(),sig->GetXaxis()->GetXmax(),
845 1,sig->GetYaxis()->GetXmin(),sig->GetYaxis()->GetXmax());
846 varDir->cd();
847 h1Vector.push_back( newhist );
848 }
849 }
850 //
851 corrDir->cd();
853 TIter nextCorr(corrDir->GetListOfKeys());
854 while ((key = (TKey*)nextCorr())) {
855 // make sure, that we only look at histograms
856 TClass *cl = gROOT->GetClass(key->GetClassName());
857 if (!cl->InheritsFrom(TH2F::Class())) continue;
858 TH2F *sig = (TH2F*)key->ReadObj();
859 TString hname= sig->GetName();
860
861 // check for all signal histograms
862 if ((hname.Contains("scat_")) && (hname.Contains("_Signal"))) {
863 Log() << kDEBUG << "Got histogram (2D) : " << hname << Endl;
864 TString htitle = sig->GetTitle();
865 htitle.ReplaceAll("(Signal)","");
867 newname.ReplaceAll("_Signal","_RF2D");
868
869 methodDir->cd();
870 const Int_t rebin=2;
872 sig->GetNbinsX()/rebin,sig->GetXaxis()->GetXmin(),sig->GetXaxis()->GetXmax(),
873 sig->GetNbinsY()/rebin,sig->GetYaxis()->GetXmin(),sig->GetYaxis()->GetXmax());
874 if (GetCorrVars( newname, var1, var2 )) {
875 Int_t iv1 = fMethodBase->DataInfo().FindVarIndex(var1);
876 Int_t iv2 = fMethodBase->DataInfo().FindVarIndex(var2);
877 if (iv1<0) {
878 sig->GetYaxis()->SetTitle(var1);
879 }
880 else {
881 sig->GetYaxis()->SetTitle(fMethodBase->GetInputTitle(iv1));
882 }
883 if (iv2<0) {
884 sig->GetXaxis()->SetTitle(var2);
885 }
886 else {
887 sig->GetXaxis()->SetTitle(fMethodBase->GetInputTitle(iv2));
888 }
889 }
890 corrDir->cd();
891 h2CorrVector.push_back( newhist );
892 }
893 }
894
895 varDir->cd();
896 // fill rules
897 UInt_t nrules = fRuleEnsemble.GetNRules();
898 const Rule *rule;
899 for (UInt_t i=0; i<nrules; i++) {
900 rule = fRuleEnsemble.GetRulesConst(i);
901 FillVisHistCut(rule, h1Vector);
902 }
903 // fill linear terms and normalise hists
904 FillVisHistCut(0, h1Vector);
905 NormVisHists(h1Vector);
906
907 //
908 corrDir->cd();
909 // fill rules
910 for (UInt_t i=0; i<nrules; i++) {
911 rule = fRuleEnsemble.GetRulesConst(i);
912 FillVisHistCorr(rule, h2CorrVector);
913 }
914 NormVisHists(h2CorrVector);
915
916 // write histograms to file
917 methodDir->cd();
918 for (UInt_t i=0; i<h1Vector.size(); i++) h1Vector[i]->Write();
919 for (UInt_t i=0; i<h2CorrVector.size(); i++) h2CorrVector[i]->Write();
920}
921
922////////////////////////////////////////////////////////////////////////////////
923/// this will create histograms intended mainly for debugging or for the curious user
924
926{
927 TDirectory* methodDir = fMethodBase->BaseDir();
928 if (methodDir==0) {
929 Log() << kWARNING << "<MakeDebugHists> No rulefit method directory found - bug?" << Endl;
930 return;
931 }
932 //
933 methodDir->cd();
934 std::vector<Double_t> distances;
935 std::vector<Double_t> fncuts;
936 std::vector<Double_t> fnvars;
937 const Rule *ruleA;
938 const Rule *ruleB;
939 Double_t dABmin=1000000.0;
940 Double_t dABmax=-1.0;
941 UInt_t nrules = fRuleEnsemble.GetNRules();
942 for (UInt_t i=0; i<nrules; i++) {
943 ruleA = fRuleEnsemble.GetRulesConst(i);
944 for (UInt_t j=i+1; j<nrules; j++) {
945 ruleB = fRuleEnsemble.GetRulesConst(j);
946 Double_t dAB = ruleA->RuleDist( *ruleB, kTRUE );
947 if (dAB>-0.5) {
948 UInt_t nc = ruleA->GetNcuts();
949 UInt_t nv = ruleA->GetNumVarsUsed();
950 distances.push_back(dAB);
951 fncuts.push_back(static_cast<Double_t>(nc));
952 fnvars.push_back(static_cast<Double_t>(nv));
953 if (dAB<dABmin) dABmin=dAB;
954 if (dAB>dABmax) dABmax=dAB;
955 }
956 }
957 }
958 //
959 TH1F *histDist = new TH1F("RuleDist","Rule distances",100,dABmin,dABmax);
960 TTree *distNtuple = new TTree("RuleDistNtuple","RuleDist ntuple");
964 distNtuple->Branch("dist", &ntDist, "dist/D");
965 distNtuple->Branch("ncuts",&ntNcuts, "ncuts/D");
966 distNtuple->Branch("nvars",&ntNvars, "nvars/D");
967 //
968 for (UInt_t i=0; i<distances.size(); i++) {
969 histDist->Fill(distances[i]);
970 ntDist = distances[i];
971 ntNcuts = fncuts[i];
972 ntNvars = fnvars[i];
973 distNtuple->Fill();
974 }
975 distNtuple->Write();
976}
#define f(i)
Definition RSha256.hxx:104
#define e(i)
Definition RSha256.hxx:103
constexpr Bool_t kFALSE
Definition RtypesCore.h:108
double Double_t
Double 8 bytes.
Definition RtypesCore.h:73
constexpr Bool_t kTRUE
Definition RtypesCore.h:107
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t wmin
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t wmax
#define gROOT
Definition TROOT.h:411
Double_t GetXmax() const
Definition TAxis.h:142
Double_t GetXmin() const
Definition TAxis.h:141
TClass instances represent classes, structs and namespaces in the ROOT type system.
Definition TClass.h:84
Bool_t InheritsFrom(const char *cl) const override
Return kTRUE if this class inherits from a class with name "classname".
Definition TClass.cxx:4901
static TClass * GetClass(const char *name, Bool_t load=kTRUE, Bool_t silent=kFALSE)
Static method returning pointer to TClass of the specified class name.
Definition TClass.cxx:2973
Describe directory structure in memory.
Definition TDirectory.h:45
1-D histogram with a float per channel (see TH1 documentation)
Definition TH1.h:878
static TClass * Class()
virtual Int_t GetNbinsY() const
Definition TH1.h:542
TAxis * GetXaxis()
Definition TH1.h:571
virtual Int_t GetNbinsX() const
Definition TH1.h:541
TAxis * GetYaxis()
Definition TH1.h:572
2-D histogram with a float per channel (see TH1 documentation)
Definition TH2.h:345
static TClass * Class()
Book space in a file, create I/O buffers, to fill them, (un)compress them.
Definition TKey.h:28
virtual const char * GetClassName() const
Definition TKey.h:75
virtual TObject * ReadObj()
To read a TObject* from the file.
Definition TKey.cxx:760
Implementation of a Decision Tree.
Virtual base class for all MVA methods.
Definition MethodBase.h:111
J Friedman's RuleFit method.
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
A class implementing various fits of rule ensembles.
Definition RuleFit.h:46
void GetRndmSampleEvents(std::vector< const TMVA::Event * > &evevec, UInt_t nevents)
draw a random subsample of the training events without replacement
Definition RuleFit.cxx:455
Double_t EvalEvent(const Event &e)
evaluate single event
Definition RuleFit.cxx:420
void SetMethodBase(const MethodBase *rfbase)
set MethodBase
Definition RuleFit.cxx:149
void InitPtrs(const TMVA::MethodBase *rfbase)
initialize pointers
Definition RuleFit.cxx:108
void Boost(TMVA::DecisionTree *dt)
Boost the events.
Definition RuleFit.cxx:327
void ForestStatistics()
summary of statistics of all trees
Definition RuleFit.cxx:374
static const Int_t randSEED
Definition RuleFit.h:176
void CalcImportance()
calculates the importance of each rule
Definition RuleFit.cxx:406
void SetMsgType(EMsgType t)
set the current message type to that of mlog for this class and all other subtools
Definition RuleFit.cxx:189
void Initialize(const TMVA::MethodBase *rfbase)
initialize the parameters of the RuleFit method and make rules
Definition RuleFit.cxx:118
virtual ~RuleFit(void)
destructor
Definition RuleFit.cxx:88
void FillVisHistCorr(const Rule *rule, std::vector< TH2F * > &hlist)
help routine to MakeVisHists() - fills for all correlation plots
Definition RuleFit.cxx:703
std::default_random_engine fRNGEngine
Definition RuleFit.h:177
void InitNEveEff()
init effective number of events (using event weights)
Definition RuleFit.cxx:96
void SaveEventWeights()
save event weights - must be done before making the forest
Definition RuleFit.cxx:297
void FillCut(TH2F *h2, const TMVA::Rule *rule, Int_t vind)
Fill cut.
Definition RuleFit.cxx:521
void FillLin(TH2F *h2, Int_t vind)
fill lin
Definition RuleFit.cxx:572
Bool_t GetCorrVars(TString &title, TString &var1, TString &var2)
get first and second variables from title
Definition RuleFit.cxx:742
void MakeForest()
make a forest of decision trees
Definition RuleFit.cxx:220
void FitCoefficients()
Fit the coefficients for the rule ensemble.
Definition RuleFit.cxx:397
void FillCorr(TH2F *h2, const TMVA::Rule *rule, Int_t v1, Int_t v2)
fill rule correlation between vx and vy, weighted with either the importance or the coefficient
Definition RuleFit.cxx:596
void NormVisHists(std::vector< TH2F * > &hlist)
normalize rule importance hists
Definition RuleFit.cxx:474
void RestoreEventWeights()
restore the previously saved event weights
Definition RuleFit.cxx:309
void MakeVisHists()
this will create histograms visualizing the rule ensemble
Definition RuleFit.cxx:765
void FillVisHistCut(const Rule *rule, std::vector< TH2F * > &hlist)
help routine to MakeVisHists() - fills for all variables
Definition RuleFit.cxx:672
void BuildTree(TMVA::DecisionTree *dt)
build the decision tree using fNTreeSample events from fTrainingEventsRndm
Definition RuleFit.cxx:199
void SetTrainingEvents(const std::vector< const TMVA::Event * > &el)
set the training events randomly
Definition RuleFit.cxx:428
void Copy(const RuleFit &other)
copy method
Definition RuleFit.cxx:158
Double_t CalcWeightSum(const std::vector< const TMVA::Event * > *events, UInt_t neve=0)
calculate the sum of weights
Definition RuleFit.cxx:174
RuleFit(void)
default constructor
Definition RuleFit.cxx:74
void MakeDebugHists()
this will create histograms intended rather for debugging or for the curious user
Definition RuleFit.cxx:925
Implementation of a rule.
Definition Rule.h:50
Timing information for training and evaluation of MVA methods.
Definition Timer.h:58
@ kTraining
Definition Types.h:143
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
Definition TNamed.cxx:173
const char * GetName() const override
Returns name of object.
Definition TNamed.h:49
const char * GetTitle() const override
Returns title of object.
Definition TNamed.h:50
Random number generator class based on M. Matsumoto and T. Nishimura's Mersenne Twister generator.
Definition TRandom3.h:27
Basic string class.
Definition TString.h:138
Ssiz_t Length() const
Definition TString.h:425
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
Definition TString.h:631
A TTree represents a columnar dataset.
Definition TTree.h:89
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
Double_t Sqrt(Double_t x)
Returns the square root of x.
Definition TMath.h:673
Short_t Abs(Short_t d)
Returns the absolute value of parameter Short_t d.
Definition TMathBase.h:124