Loading [MathJax]/jax/output/HTML-CSS/config.js
Logo ROOT  
Reference Guide
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : Rule *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * A class describing a 'rule' *
12 * Each internal node of a tree defines a rule from all the parental nodes. *
13 * A rule with 0 or 1 nodes in the list is a root rule -> corresponds to a0. *
14 * Input: a decision tree (in the constructor) *
15 * its coefficient *
16 * *
17 * *
18 * Authors (alphabetical): *
19 * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
20 * *
21 * Copyright (c) 2005: *
22 * CERN, Switzerland *
23 * Iowa State U. *
24 * MPI-K Heidelberg, Germany *
25 * *
26 * Redistribution and use in source and binary forms, with or without *
27 * modification, are permitted according to the terms listed in LICENSE *
28 * (http://tmva.sourceforge.net/LICENSE) *
29 **********************************************************************************/
31/*! \class TMVA::RuleFit
32\ingroup TMVA
33A class implementing various fits of rule ensembles
35#include "TMVA/RuleFit.h"
37#include "TMVA/DataSet.h"
38#include "TMVA/DecisionTree.h"
39#include "TMVA/Event.h"
40#include "TMVA/Factory.h" // for root base dir
41#include "TMVA/GiniIndex.h"
42#include "TMVA/MethodBase.h"
43#include "TMVA/MethodRuleFit.h"
44#include "TMVA/MsgLogger.h"
45#include "TMVA/Timer.h"
46#include "TMVA/Tools.h"
47#include "TMVA/Types.h"
48#include "TMVA/SeparationBase.h"
50#include "TDirectory.h"
51#include "TH2F.h"
52#include "TFile.h"
53#include "TKey.h"
54#include "TRandom3.h"
55#include "TROOT.h" // for gROOT
57#include <algorithm>
58#include <random>
63/// constructor
66 : fVisHistsUseImp( kTRUE )
67 , fLogger(new MsgLogger("RuleFit"))
69 Initialize(rfbase);
70 fRNGEngine.seed(randSEED);
74/// default constructor
77 : fNTreeSample(0)
78 , fNEveEffTrain(0)
79 , fMethodRuleFit(0)
80 , fMethodBase(0)
81 , fVisHistsUseImp(kTRUE)
82 , fLogger(new MsgLogger("RuleFit"))
84 fRNGEngine.seed(randSEED);
88/// destructor
92 delete fLogger;
96/// init effective number of events (using event weights)
100 UInt_t neve = fTrainingEvents.size();
101 if (neve==0) return;
102 //
103 fNEveEffTrain = CalcWeightSum( &fTrainingEvents );
104 //
108/// initialize pointers
112 this->SetMethodBase(rfbase);
113 fRuleEnsemble.Initialize( this );
114 fRuleFitParams.SetRuleFit( this );
118/// initialize the parameters of the RuleFit method and make rules
122 InitPtrs(rfbase);
124 if (fMethodRuleFit){
125 fMethodRuleFit->Data()->SetCurrentType(Types::kTraining);
126 UInt_t nevents = fMethodRuleFit->Data()->GetNTrainingEvents();
127 std::vector<const TMVA::Event*> tmp;
128 for (Long64_t ievt=0; ievt<nevents; ievt++) {
129 const Event *event = fMethodRuleFit->GetEvent(ievt);
130 tmp.push_back(event);
131 }
132 SetTrainingEvents( tmp );
133 }
134 // SetTrainingEvents( fMethodRuleFit->GetTrainingEvents() );
136 InitNEveEff();
138 MakeForest();
140 // Make the model - Rule + Linear (if fDoLinear is true)
141 fRuleEnsemble.MakeModel();
143 // init rulefit params
144 fRuleFitParams.Init();
149/// set MethodBase
153 fMethodBase = rfbase;
154 fMethodRuleFit = dynamic_cast<const MethodRuleFit *>(rfbase);
158/// copy method
160void TMVA::RuleFit::Copy( const RuleFit& other )
162 if(this != &other) {
163 fMethodRuleFit = other.GetMethodRuleFit();
164 fMethodBase = other.GetMethodBase();
165 fTrainingEvents = other.GetTrainingEvents();
166 // fSubsampleEvents = other.GetSubsampleEvents();
168 fForest = other.GetForest();
169 fRuleEnsemble = other.GetRuleEnsemble();
170 }
174/// calculate the sum of weights
176Double_t TMVA::RuleFit::CalcWeightSum( const std::vector<const Event *> *events, UInt_t neve )
178 if (events==0) return 0.0;
179 if (neve==0) neve=events->size();
180 //
181 Double_t sumw=0;
182 for (UInt_t ie=0; ie<neve; ie++) {
183 sumw += ((*events)[ie])->GetWeight();
184 }
185 return sumw;
189/// set the current message type to that of mlog for this class and all other subtools
193 fLogger->SetMinType(t);
194 fRuleEnsemble.SetMsgType(t);
195 fRuleFitParams.SetMsgType(t);
199/// build the decision tree using fNTreeSample events from fTrainingEventsRndm
203 if (dt==0) return;
204 if (fMethodRuleFit==0) {
205 Log() << kFATAL << "RuleFit::BuildTree() - Attempting to build a tree NOT from a MethodRuleFit" << Endl;
206 }
207 std::vector<const Event *> evevec;
208 for (UInt_t ie=0; ie<fNTreeSample; ie++) {
209 evevec.push_back(fTrainingEventsRndm[ie]);
210 }
211 dt->BuildTree(evevec);
212 if (fMethodRuleFit->GetPruneMethod() != DecisionTree::kNoPruning) {
213 dt->SetPruneMethod(fMethodRuleFit->GetPruneMethod());
214 dt->SetPruneStrength(fMethodRuleFit->GetPruneStrength());
215 dt->PruneTree();
216 }
220/// make a forest of decision trees
224 if (fMethodRuleFit==0) {
225 Log() << kFATAL << "RuleFit::BuildTree() - Attempting to build a tree NOT from a MethodRuleFit" << Endl;
226 }
227 Log() << kDEBUG << "Creating a forest with " << fMethodRuleFit->GetNTrees() << " decision trees" << Endl;
228 Log() << kDEBUG << "Each tree is built using a random subsample with " << fNTreeSample << " events" << Endl;
229 //
230 Timer timer( fMethodRuleFit->GetNTrees(), "RuleFit" );
232 // Double_t fsig;
233 Int_t nsig,nbkg;
234 //
235 TRandom3 rndGen;
236 //
237 // First save all event weights.
238 // Weights are modified by the boosting.
239 // Those weights we do not want for the later fitting.
240 //
241 Bool_t useBoost = fMethodRuleFit->UseBoost(); // (AdaBoost (True) or RandomForest/Tree (False)
243 if (useBoost) SaveEventWeights();
245 for (Int_t i=0; i<fMethodRuleFit->GetNTrees(); i++) {
246 // timer.DrawProgressBar(i);
247 if (!useBoost) ReshuffleEvents();
248 nsig=0;
249 nbkg=0;
250 for (UInt_t ie = 0; ie<fNTreeSample; ie++) {
251 if (fMethodBase->DataInfo().IsSignal(fTrainingEventsRndm[ie])) nsig++; // ignore weights here
252 else nbkg++;
253 }
254 // fsig = Double_t(nsig)/Double_t(nsig+nbkg);
255 // do not implement the above in this release...just set it to default
257 DecisionTree *dt=nullptr;
258 Bool_t tryAgain=kTRUE;
259 Int_t ntries=0;
260 const Int_t ntriesMax=10;
261 Double_t frnd = 0.;
262 while (tryAgain) {
263 frnd = 100*rndGen.Uniform( fMethodRuleFit->GetMinFracNEve(), 0.5*fMethodRuleFit->GetMaxFracNEve() );
264 Int_t iclass = 0; // event class being treated as signal during training
265 Bool_t useRandomisedTree = !useBoost;
266 dt = new DecisionTree( fMethodRuleFit->GetSeparationBase(), frnd, fMethodRuleFit->GetNCuts(), &(fMethodRuleFit->DataInfo()), iclass, useRandomisedTree);
267 dt->SetNVars(fMethodBase->GetNvar());
269 BuildTree(dt); // reads fNTreeSample events from fTrainingEventsRndm
270 if (dt->GetNNodes()<3) {
271 delete dt;
272 dt=0;
273 }
274 ntries++;
275 tryAgain = ((dt==0) && (ntries<ntriesMax));
276 }
277 if (dt) {
278 fForest.push_back(dt);
279 if (useBoost) Boost(dt);
281 } else {
283 Log() << kWARNING << "------------------------------------------------------------------" << Endl;
284 Log() << kWARNING << " Failed growing a tree even after " << ntriesMax << " trials" << Endl;
285 Log() << kWARNING << " Possible solutions: " << Endl;
286 Log() << kWARNING << " 1. increase the number of training events" << Endl;
287 Log() << kWARNING << " 2. set a lower min fraction cut (fEventsMin)" << Endl;
288 Log() << kWARNING << " 3. maybe also decrease the max fraction cut (fEventsMax)" << Endl;
289 Log() << kWARNING << " If the above warning occurs rarely only, it can be ignored" << Endl;
290 Log() << kWARNING << "------------------------------------------------------------------" << Endl;
291 }
293 Log() << kDEBUG << "Built tree with minimum cut at N = " << frnd <<"% events"
294 << " => N(nodes) = " << fForest.back()->GetNNodes()
295 << " ; n(tries) = " << ntries
296 << Endl;
297 }
299 // Now restore event weights
300 if (useBoost) RestoreEventWeights();
302 // print statistics on the forest created
303 ForestStatistics();
307/// save event weights - must be done before making the forest
311 fEventWeights.clear();
312 for (std::vector<const Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); ++e) {
313 Double_t w = (*e)->GetBoostWeight();
314 fEventWeights.push_back(w);
315 }
319/// restore the event weights saved by SaveEventWeights() - must be done after making the forest
323 UInt_t ie=0;
324 if (fEventWeights.size() != fTrainingEvents.size()) {
325 Log() << kERROR << "RuleFit::RestoreEventWeights() called without having called SaveEventWeights() before!" << Endl;
326 return;
327 }
328 for (std::vector<const Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); ++e) {
329 (*e)->SetBoostWeight(fEventWeights[ie]);
330 ie++;
331 }
335/// Boost the events. The algorithm below is called AdaBoost.
336/// See MethodBDT for details.
337/// Actually, this is more or less a copy of MethodBDT::AdaBoost().
341 Double_t sumw=0; // sum of initial weights - all events
342 Double_t sumwfalse=0; // idem, only misclassified events
343 //
344 std::vector<Char_t> correctSelected; // <--- boolean stored
345 //
346 for (std::vector<const Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); ++e) {
347 Bool_t isSignalType = (dt->CheckEvent(*e,kTRUE) > 0.5 );
348 Double_t w = (*e)->GetWeight();
349 sumw += w;
350 //
351 if (isSignalType == fMethodBase->DataInfo().IsSignal(*e)) { // correctly classified
352 correctSelected.push_back(kTRUE);
353 }
354 else { // misclassified
355 sumwfalse+= w;
356 correctSelected.push_back(kFALSE);
357 }
358 }
359 // misclassification error
360 Double_t err = sumwfalse/sumw;
361 // calculate boost weight for misclassified events
362 // use for now the exponent = 1.0
363 // one could have w = ((1-err)/err)^beta
364 Double_t boostWeight = (err>0 ? (1.0-err)/err : 1000.0);
365 Double_t newSumw=0.0;
366 UInt_t ie=0;
367 // set new weight to misclassified events
368 for (std::vector<const Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); ++e) {
369 if (!correctSelected[ie])
370 (*e)->SetBoostWeight( (*e)->GetBoostWeight() * boostWeight);
371 newSumw+=(*e)->GetWeight();
372 ie++;
373 }
374 // reweight all events
375 Double_t scale = sumw/newSumw;
376 for (std::vector<const Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); ++e) {
377 (*e)->SetBoostWeight( (*e)->GetBoostWeight() * scale);
378 }
379 Log() << kDEBUG << "boostWeight = " << boostWeight << " scale = " << scale << Endl;
383/// summary of statistics of all trees
384/// - end-nodes: average and spread
388 UInt_t ntrees = fForest.size();
389 if (ntrees==0) return;
390 const DecisionTree *tree;
391 Double_t sumn2 = 0;
392 Double_t sumn = 0;
393 Double_t nd;
394 for (UInt_t i=0; i<ntrees; i++) {
395 tree = fForest[i];
396 nd = Double_t(tree->GetNNodes());
397 sumn += nd;
398 sumn2 += nd*nd;
399 }
400 Double_t sig = TMath::Sqrt( gTools().ComputeVariance( sumn2, sumn, ntrees ));
401 Log() << kVERBOSE << "Nodes in trees: average & std dev = " << sumn/ntrees << " , " << sig << Endl;
406/// Fit the coefficients for the rule ensemble
411 Log() << kVERBOSE << "Fitting rule/linear terms" << Endl;
412 fRuleFitParams.MakeGDPath();
416/// calculates the importance of each rule
420 Log() << kVERBOSE << "Calculating importance" << Endl;
421 fRuleEnsemble.CalcImportance();
422 fRuleEnsemble.CleanupRules();
423 fRuleEnsemble.CleanupLinear();
424 fRuleEnsemble.CalcVarImportance();
425 Log() << kVERBOSE << "Filling rule statistics" << Endl;
426 fRuleEnsemble.RuleResponseStats();
430/// evaluate single event
434 return fRuleEnsemble.EvalEvent( e );
438/// set the training events randomly
440void TMVA::RuleFit::SetTrainingEvents( const std::vector<const Event *>& el )
442 if (fMethodRuleFit==0) Log() << kFATAL << "RuleFit::SetTrainingEvents - MethodRuleFit not initialized" << Endl;
443 UInt_t neve = el.size();
444 if (neve==0) Log() << kWARNING << "An empty sample of training events was given" << Endl;
446 // copy vector
447 fTrainingEvents.clear();
448 fTrainingEventsRndm.clear();
449 for (UInt_t i=0; i<neve; i++) {
450 fTrainingEvents.push_back(static_cast< const Event *>(el[i]));
451 fTrainingEventsRndm.push_back(static_cast< const Event *>(el[i]));
452 }
454 // Re-shuffle the vector, ie, recreate it in a random order
455 std::shuffle(fTrainingEventsRndm.begin(), fTrainingEventsRndm.end(), fRNGEngine);
457 // fraction events per tree
458 fNTreeSample = static_cast<UInt_t>(neve*fMethodRuleFit->GetTreeEveFrac());
459 Log() << kDEBUG << "Number of events per tree : " << fNTreeSample
460 << " ( N(events) = " << neve << " )"
461 << " randomly drawn without replacement" << Endl;
465/// draw a random subsample of the training events without replacement
467void TMVA::RuleFit::GetRndmSampleEvents(std::vector< const Event * > & evevec, UInt_t nevents)
469 ReshuffleEvents();
470 if ((nevents<fTrainingEventsRndm.size()) && (nevents>0)) {
471 evevec.resize(nevents);
472 for (UInt_t ie=0; ie<nevents; ie++) {
473 evevec[ie] = fTrainingEventsRndm[ie];
474 }
475 }
476 else {
477 Log() << kWARNING << "GetRndmSampleEvents() : requested sub sample size larger than total size (BUG!).";
478 }
481/// normalize rule importance hists
483/// if all weights are positive, the scale will be 1/maxweight
484/// if minimum weight < 0, then the scale will be 1/max(maxweight,abs(minweight))
486void TMVA::RuleFit::NormVisHists(std::vector<TH2F *> & hlist)
488 if (hlist.empty()) return;
489 //
490 Double_t wmin=0;
491 Double_t wmax=0;
492 Double_t w,wm;
493 Double_t awmin;
494 Double_t scale;
495 for (UInt_t i=0; i<hlist.size(); i++) {
496 TH2F *hs = hlist[i];
497 w = hs->GetMaximum();
498 wm = hs->GetMinimum();
499 if (i==0) {
500 wmin=wm;
501 wmax=w;
502 }
503 else {
504 if (w>wmax) wmax=w;
505 if (wm<wmin) wmin=wm;
506 }
507 }
508 awmin = TMath::Abs(wmin);
509 Double_t usemin,usemax;
510 if (awmin>wmax) {
511 scale = 1.0/awmin;
512 usemin = -1.0;
513 usemax = scale*wmax;
514 }
515 else {
516 scale = 1.0/wmax;
517 usemin = scale*wmin;
518 usemax = 1.0;
519 }
521 //
522 for (UInt_t i=0; i<hlist.size(); i++) {
523 TH2F *hs = hlist[i];
524 hs->Scale(scale);
525 hs->SetMinimum(usemin);
526 hs->SetMaximum(usemax);
527 }
531/// Fill cut
533void TMVA::RuleFit::FillCut(TH2F* h2, const Rule *rule, Int_t vind)
535 if (rule==0) return;
536 if (h2==0) return;
537 //
538 Double_t rmin, rmax;
539 Bool_t dormin,dormax;
540 Bool_t ruleHasVar = rule->GetRuleCut()->GetCutRange(vind,rmin,rmax,dormin,dormax);
541 if (!ruleHasVar) return;
542 //
543 Int_t firstbin = h2->GetBin(1,1,1);
544 if(firstbin<0) firstbin=0;
545 Int_t lastbin = h2->GetBin(h2->GetNbinsX(),1,1);
546 Int_t binmin=(dormin ? h2->FindBin(rmin,0.5):firstbin);
547 Int_t binmax=(dormax ? h2->FindBin(rmax,0.5):lastbin);
548 Int_t fbin;
549 Double_t xbinw = h2->GetXaxis()->GetBinWidth(firstbin);
550 Double_t fbmin = h2->GetXaxis()->GetBinLowEdge(binmin-firstbin+1);
551 Double_t lbmax = h2->GetXaxis()->GetBinLowEdge(binmax-firstbin+1)+xbinw;
552 Double_t fbfrac = (dormin ? ((fbmin+xbinw-rmin)/xbinw):1.0);
553 Double_t lbfrac = (dormax ? ((rmax-lbmax+xbinw)/xbinw):1.0);
554 Double_t f;
555 Double_t xc;
556 Double_t val;
558 for (Int_t bin = binmin; bin<binmax+1; bin++) {
559 fbin = bin-firstbin+1;
560 if (bin==binmin) {
561 f = fbfrac;
562 }
563 else if (bin==binmax) {
564 f = lbfrac;
565 }
566 else {
567 f = 1.0;
568 }
569 xc = h2->GetXaxis()->GetBinCenter(fbin);
570 //
571 if (fVisHistsUseImp) {
572 val = rule->GetImportance();
573 }
574 else {
575 val = rule->GetCoefficient()*rule->GetSupport();
576 }
577 h2->Fill(xc,0.5,val*f);
578 }
582/// fill lin
586 if (h2==0) return;
587 if (!fRuleEnsemble.DoLinear()) return;
588 //
589 Int_t firstbin = 1;
590 Int_t lastbin = h2->GetNbinsX();
591 Double_t xc;
592 Double_t val;
593 if (fVisHistsUseImp) {
594 val = fRuleEnsemble.GetLinImportance(vind);
595 }
596 else {
597 val = fRuleEnsemble.GetLinCoefficients(vind);
598 }
599 for (Int_t bin = firstbin; bin<lastbin+1; bin++) {
600 xc = h2->GetXaxis()->GetBinCenter(bin);
601 h2->Fill(xc,0.5,val);
602 }
606/// fill rule correlation between vx and vy, weighted with either the importance or the coefficient
608void TMVA::RuleFit::FillCorr(TH2F* h2,const Rule *rule,Int_t vx, Int_t vy)
610 if (rule==0) return;
611 if (h2==0) return;
612 Double_t val;
613 if (fVisHistsUseImp) {
614 val = rule->GetImportance();
615 }
616 else {
617 val = rule->GetCoefficient()*rule->GetSupport();
618 }
619 //
620 Double_t rxmin, rxmax, rymin, rymax;
621 Bool_t dorxmin, dorxmax, dorymin, dorymax;
622 //
623 // Get range in rule for X and Y
624 //
625 Bool_t ruleHasVarX = rule->GetRuleCut()->GetCutRange(vx,rxmin,rxmax,dorxmin,dorxmax);
626 Bool_t ruleHasVarY = rule->GetRuleCut()->GetCutRange(vy,rymin,rymax,dorymin,dorymax);
627 if (!(ruleHasVarX || ruleHasVarY)) return;
628 // min max of varX and varY in hist
629 Double_t vxmin = (dorxmin ? rxmin:h2->GetXaxis()->GetXmin());
630 Double_t vxmax = (dorxmax ? rxmax:h2->GetXaxis()->GetXmax());
631 Double_t vymin = (dorymin ? rymin:h2->GetYaxis()->GetXmin());
632 Double_t vymax = (dorymax ? rymax:h2->GetYaxis()->GetXmax());
633 // min max bin in X and Y
634 Int_t binxmin = h2->GetXaxis()->FindBin(vxmin);
635 Int_t binxmax = h2->GetXaxis()->FindBin(vxmax);
636 Int_t binymin = h2->GetYaxis()->FindBin(vymin);
637 Int_t binymax = h2->GetYaxis()->FindBin(vymax);
638 // bin widths
639 Double_t xbinw = h2->GetXaxis()->GetBinWidth(binxmin);
640 Double_t ybinw = h2->GetYaxis()->GetBinWidth(binxmin);
641 Double_t xbinmin = h2->GetXaxis()->GetBinLowEdge(binxmin);
642 Double_t xbinmax = h2->GetXaxis()->GetBinLowEdge(binxmax)+xbinw;
643 Double_t ybinmin = h2->GetYaxis()->GetBinLowEdge(binymin);
644 Double_t ybinmax = h2->GetYaxis()->GetBinLowEdge(binymax)+ybinw;
645 // fraction of edges
646 Double_t fxbinmin = (dorxmin ? ((xbinmin+xbinw-vxmin)/xbinw):1.0);
647 Double_t fxbinmax = (dorxmax ? ((vxmax-xbinmax+xbinw)/xbinw):1.0);
648 Double_t fybinmin = (dorymin ? ((ybinmin+ybinw-vymin)/ybinw):1.0);
649 Double_t fybinmax = (dorymax ? ((vymax-ybinmax+ybinw)/ybinw):1.0);
650 //
651 Double_t fx,fy;
652 Double_t xc,yc;
653 // fill histo
654 for (Int_t binx = binxmin; binx<binxmax+1; binx++) {
655 if (binx==binxmin) {
656 fx = fxbinmin;
657 }
658 else if (binx==binxmax) {
659 fx = fxbinmax;
660 }
661 else {
662 fx = 1.0;
663 }
664 xc = h2->GetXaxis()->GetBinCenter(binx);
665 for (Int_t biny = binymin; biny<binymax+1; biny++) {
666 if (biny==binymin) {
667 fy = fybinmin;
668 }
669 else if (biny==binymax) {
670 fy = fybinmax;
671 }
672 else {
673 fy = 1.0;
674 }
675 yc = h2->GetYaxis()->GetBinCenter(biny);
676 h2->Fill(xc,yc,val*fx*fy);
677 }
678 }
682/// help routine to MakeVisHists() - fills for all variables
684void TMVA::RuleFit::FillVisHistCut(const Rule* rule, std::vector<TH2F *> & hlist)
686 Int_t nhists = hlist.size();
687 Int_t nvar = fMethodBase->GetNvar();
688 if (nhists!=nvar) Log() << kFATAL << "BUG TRAP: number of hists is not equal the number of variables!" << Endl;
689 //
690 std::vector<Int_t> vindex;
691 TString hstr;
692 // not a nice way to do a check...
693 for (Int_t ih=0; ih<nhists; ih++) {
694 hstr = hlist[ih]->GetTitle();
695 for (Int_t iv=0; iv<nvar; iv++) {
696 if (fMethodBase->GetInputTitle(iv) == hstr)
697 vindex.push_back(iv);
698 }
699 }
700 //
701 for (Int_t iv=0; iv<nvar; iv++) {
702 if (rule) {
703 if (rule->ContainsVariable(vindex[iv])) {
704 FillCut(hlist[iv],rule,vindex[iv]);
705 }
706 }
707 else {
708 FillLin(hlist[iv],vindex[iv]);
709 }
710 }
713/// help routine to MakeVisHists() - fills for all correlation plots
715void TMVA::RuleFit::FillVisHistCorr(const Rule * rule, std::vector<TH2F *> & hlist)
717 if (rule==0) return;
718 Double_t ruleimp = rule->GetImportance();
719 if (!(ruleimp>0)) return;
720 if (ruleimp<fRuleEnsemble.GetImportanceCut()) return;
721 //
722 Int_t nhists = hlist.size();
723 Int_t nvar = fMethodBase->GetNvar();
724 Int_t ncorr = (nvar*(nvar+1)/2)-nvar;
725 if (nhists!=ncorr) Log() << kERROR << "BUG TRAP: number of corr hists is not correct! ncorr = "
726 << ncorr << " nvar = " << nvar << " nhists = " << nhists << Endl;
727 //
728 std::vector< std::pair<Int_t,Int_t> > vindex;
729 TString hstr, var1, var2;
730 Int_t iv1=0,iv2=0;
731 // not a nice way to do a check...
732 for (Int_t ih=0; ih<nhists; ih++) {
733 hstr = hlist[ih]->GetName();
734 if (GetCorrVars( hstr, var1, var2 )) {
735 iv1 = fMethodBase->DataInfo().FindVarIndex( var1 );
736 iv2 = fMethodBase->DataInfo().FindVarIndex( var2 );
737 vindex.push_back( std::pair<Int_t,Int_t>(iv2,iv1) ); // pair X, Y
738 }
739 else {
740 Log() << kERROR << "BUG TRAP: should not be here - failed getting var1 and var2" << Endl;
741 }
742 }
743 //
744 for (Int_t ih=0; ih<nhists; ih++) {
745 if ( (rule->ContainsVariable(vindex[ih].first)) ||
746 (rule->ContainsVariable(vindex[ih].second)) ) {
747 FillCorr(hlist[ih],rule,vindex[ih].first,vindex[ih].second);
748 }
749 }
752/// get first and second variables from title
756 var1="";
757 var2="";
758 if(!title.BeginsWith("scat_")) return kFALSE;
760 TString titleCopy = title(5,title.Length());
761 if(titleCopy.Index("_RF2D")>=0) titleCopy.Remove(titleCopy.Index("_RF2D"));
763 Int_t splitPos = titleCopy.Index("_vs_");
764 if(splitPos>=0) { // there is a _vs_ in the string
765 var1 = titleCopy(0,splitPos);
766 var2 = titleCopy(splitPos+4, titleCopy.Length());
767 return kTRUE;
768 }
769 else {
770 var1 = titleCopy;
771 return kFALSE;
772 }
775/// this will create histograms visualizing the rule ensemble
779 const TString directories[5] = { "InputVariables_Id",
780 "InputVariables_Deco",
781 "InputVariables_PCA",
782 "InputVariables_Gauss",
783 "InputVariables_Gauss_Deco" };
785 const TString corrDirName = "CorrelationPlots";
787 TDirectory* rootDir = fMethodBase->GetFile();
788 TDirectory* varDir = 0;
789 TDirectory* corrDir = 0;
791 TDirectory* methodDir = fMethodBase->BaseDir();
792 TString varDirName;
793 //
794 Bool_t done=(rootDir==0);
795 Int_t type=0;
796 if (done) {
797 Log() << kWARNING << "No basedir - BUG??" << Endl;
798 return;
799 }
800 while (!done) {
801 varDir = (TDirectory*)rootDir->Get( directories[type] );
802 type++;
803 done = ((varDir!=0) || (type>4));
804 }
805 if (varDir==0) {
806 Log() << kWARNING << "No input variable directory found - BUG?" << Endl;
807 return;
808 }
809 corrDir = (TDirectory*)varDir->Get( corrDirName );
810 if (corrDir==0) {
811 Log() << kWARNING << "No correlation directory found" << Endl;
812 Log() << kWARNING << "Check for other warnings related to correlation histograms" << Endl;
813 return;
814 }
815 if (methodDir==0) {
816 Log() << kWARNING << "No rulefit method directory found - BUG?" << Endl;
817 return;
818 }
820 varDirName = varDir->GetName();
821 varDir->cd();
822 //
823 // get correlation plot directory
824 corrDir = (TDirectory *)varDir->Get(corrDirName);
825 if (corrDir==0) {
826 Log() << kWARNING << "No correlation directory found : " << corrDirName << Endl;
827 return;
828 }
830 // how many plots are in the var directory?
831 Int_t noPlots = ((varDir->GetListOfKeys())->GetEntries()) / 2;
832 Log() << kDEBUG << "Got number of plots = " << noPlots << Endl;
834 // loop over all objects in directory
835 std::vector<TH2F *> h1Vector;
836 std::vector<TH2F *> h2CorrVector;
837 TIter next(varDir->GetListOfKeys());
838 TKey *key;
839 while ((key = (TKey*)next())) {
840 // make sure, that we only look at histograms
841 TClass *cl = gROOT->GetClass(key->GetClassName());
842 if (!cl->InheritsFrom(TH1F::Class())) continue;
843 TH1F *sig = (TH1F*)key->ReadObj();
844 TString hname= sig->GetName();
845 Log() << kDEBUG << "Got histogram : " << hname << Endl;
847 // check for all signal histograms
848 if (hname.Contains("__S")){ // found a new signal plot
849 TString htitle = sig->GetTitle();
850 htitle.ReplaceAll("signal","");
851 TString newname = hname;
852 newname.ReplaceAll("__Signal","__RF");
853 newname.ReplaceAll("__S","__RF");
855 methodDir->cd();
856 TH2F *newhist = new TH2F(newname,htitle,sig->GetNbinsX(),sig->GetXaxis()->GetXmin(),sig->GetXaxis()->GetXmax(),
857 1,sig->GetYaxis()->GetXmin(),sig->GetYaxis()->GetXmax());
858 varDir->cd();
859 h1Vector.push_back( newhist );
860 }
861 }
862 //
863 corrDir->cd();
864 TString var1,var2;
865 TIter nextCorr(corrDir->GetListOfKeys());
866 while ((key = (TKey*)nextCorr())) {
867 // make sure, that we only look at histograms
868 TClass *cl = gROOT->GetClass(key->GetClassName());
869 if (!cl->InheritsFrom(TH2F::Class())) continue;
870 TH2F *sig = (TH2F*)key->ReadObj();
871 TString hname= sig->GetName();
873 // check for all signal histograms
874 if ((hname.Contains("scat_")) && (hname.Contains("_Signal"))) {
875 Log() << kDEBUG << "Got histogram (2D) : " << hname << Endl;
876 TString htitle = sig->GetTitle();
877 htitle.ReplaceAll("(Signal)","");
878 TString newname = hname;
879 newname.ReplaceAll("_Signal","_RF2D");
881 methodDir->cd();
882 const Int_t rebin=2;
883 TH2F *newhist = new TH2F(newname,htitle,
884 sig->GetNbinsX()/rebin,sig->GetXaxis()->GetXmin(),sig->GetXaxis()->GetXmax(),
885 sig->GetNbinsY()/rebin,sig->GetYaxis()->GetXmin(),sig->GetYaxis()->GetXmax());
886 if (GetCorrVars( newname, var1, var2 )) {
887 Int_t iv1 = fMethodBase->DataInfo().FindVarIndex(var1);
888 Int_t iv2 = fMethodBase->DataInfo().FindVarIndex(var2);
889 if (iv1<0) {
890 sig->GetYaxis()->SetTitle(var1);
891 }
892 else {
893 sig->GetYaxis()->SetTitle(fMethodBase->GetInputTitle(iv1));
894 }
895 if (iv2<0) {
896 sig->GetXaxis()->SetTitle(var2);
897 }
898 else {
899 sig->GetXaxis()->SetTitle(fMethodBase->GetInputTitle(iv2));
900 }
901 }
902 corrDir->cd();
903 h2CorrVector.push_back( newhist );
904 }
905 }
907 varDir->cd();
908 // fill rules
909 UInt_t nrules = fRuleEnsemble.GetNRules();
910 const Rule *rule;
911 for (UInt_t i=0; i<nrules; i++) {
912 rule = fRuleEnsemble.GetRulesConst(i);
913 FillVisHistCut(rule, h1Vector);
914 }
915 // fill linear terms and normalise hists
916 FillVisHistCut(0, h1Vector);
917 NormVisHists(h1Vector);
919 //
920 corrDir->cd();
921 // fill rules
922 for (UInt_t i=0; i<nrules; i++) {
923 rule = fRuleEnsemble.GetRulesConst(i);
924 FillVisHistCorr(rule, h2CorrVector);
925 }
926 NormVisHists(h2CorrVector);
928 // write histograms to file
929 methodDir->cd();
930 for (UInt_t i=0; i<h1Vector.size(); i++) h1Vector[i]->Write();
931 for (UInt_t i=0; i<h2CorrVector.size(); i++) h2CorrVector[i]->Write();
935/// this will create histograms intended for debugging or for the curious user
939 TDirectory* methodDir = fMethodBase->BaseDir();
940 if (methodDir==0) {
941 Log() << kWARNING << "<MakeDebugHists> No rulefit method directory found - bug?" << Endl;
942 return;
943 }
944 //
945 methodDir->cd();
946 std::vector<Double_t> distances;
947 std::vector<Double_t> fncuts;
948 std::vector<Double_t> fnvars;
949 const Rule *ruleA;
950 const Rule *ruleB;
951 Double_t dABmin=1000000.0;
952 Double_t dABmax=-1.0;
953 UInt_t nrules = fRuleEnsemble.GetNRules();
954 for (UInt_t i=0; i<nrules; i++) {
955 ruleA = fRuleEnsemble.GetRulesConst(i);
956 for (UInt_t j=i+1; j<nrules; j++) {
957 ruleB = fRuleEnsemble.GetRulesConst(j);
958 Double_t dAB = ruleA->RuleDist( *ruleB, kTRUE );
959 if (dAB>-0.5) {
960 UInt_t nc = ruleA->GetNcuts();
961 UInt_t nv = ruleA->GetNumVarsUsed();
962 distances.push_back(dAB);
963 fncuts.push_back(static_cast<Double_t>(nc));
964 fnvars.push_back(static_cast<Double_t>(nv));
965 if (dAB<dABmin) dABmin=dAB;
966 if (dAB>dABmax) dABmax=dAB;
967 }
968 }
969 }
970 //
971 TH1F *histDist = new TH1F("RuleDist","Rule distances",100,dABmin,dABmax);
972 TTree *distNtuple = new TTree("RuleDistNtuple","RuleDist ntuple");
973 Double_t ntDist;
974 Double_t ntNcuts;
975 Double_t ntNvars;
976 distNtuple->Branch("dist", &ntDist, "dist/D");
977 distNtuple->Branch("ncuts",&ntNcuts, "ncuts/D");
978 distNtuple->Branch("nvars",&ntNvars, "nvars/D");
979 //
980 for (UInt_t i=0; i<distances.size(); i++) {
981 histDist->Fill(distances[i]);
982 ntDist = distances[i];
983 ntNcuts = fncuts[i];
984 ntNvars = fnvars[i];
985 distNtuple->Fill();
986 }
987 distNtuple->Write();
void Class()
Definition: Class.C:29
#define f(i)
Definition: RSha256.hxx:104
#define e(i)
Definition: RSha256.hxx:103
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
long long Long64_t
Definition: RtypesCore.h:69
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define ClassImp(name)
Definition: Rtypes.h:365
int type
Definition: TGX11.cxx:120
#define gROOT
Definition: TROOT.h:415
virtual Double_t GetBinCenter(Int_t bin) const
Return center of bin.
Definition: TAxis.cxx:464
Double_t GetXmax() const
Definition: TAxis.h:134
virtual Int_t FindBin(Double_t x)
Find bin number corresponding to abscissa x.
Definition: TAxis.cxx:279
virtual Double_t GetBinLowEdge(Int_t bin) const
Return low edge of bin.
Definition: TAxis.cxx:504
Double_t GetXmin() const
Definition: TAxis.h:133
virtual Double_t GetBinWidth(Int_t bin) const
Return bin width.
Definition: TAxis.cxx:526
TClass instances represent classes, structs and namespaces in the ROOT type system.
Definition: TClass.h:75
Bool_t InheritsFrom(const char *cl) const
Return kTRUE if this class inherits from a class with name "classname".
Definition: TClass.cxx:4708
Describe directory structure in memory.
Definition: TDirectory.h:34
virtual TObject * Get(const char *namecycle)
Return pointer to object identified by namecycle.
Definition: TDirectory.cxx:805
virtual TFile * GetFile() const
Definition: TDirectory.h:157
virtual TList * GetListOfKeys() const
Definition: TDirectory.h:160
virtual Bool_t cd(const char *path=nullptr)
Change current directory to "this" directory.
Definition: TDirectory.cxx:497
1-D histogram with a float per channel (see TH1 documentation)
Definition: TH1.h:571
virtual Int_t GetNbinsY() const
Definition: TH1.h:293
TAxis * GetXaxis()
Get the behaviour adopted by the object about the statoverflows. See EStatOverflows for more informat...
Definition: TH1.h:316
virtual Double_t GetMaximum(Double_t maxval=FLT_MAX) const
Return maximum value smaller than maxval of bins in the range, unless the value has been overridden b...
Definition: TH1.cxx:7994
virtual Int_t GetNbinsX() const
Definition: TH1.h:292
virtual void SetMaximum(Double_t maximum=-1111)
Definition: TH1.h:394
virtual Int_t Fill(Double_t x)
Increment bin with abscissa X by 1.
Definition: TH1.cxx:3275
TAxis * GetYaxis()
Definition: TH1.h:317
virtual void SetMinimum(Double_t minimum=-1111)
Definition: TH1.h:395
virtual void Scale(Double_t c1=1, Option_t *option="")
Multiply this histogram by a constant c1.
Definition: TH1.cxx:6234
virtual Int_t FindBin(Double_t x, Double_t y=0, Double_t z=0)
Return Global bin number corresponding to x,y,z.
Definition: TH1.cxx:3596
virtual Double_t GetMinimum(Double_t minval=-FLT_MAX) const
Return minimum value larger than minval of bins in the range, unless the value has been overridden by...
Definition: TH1.cxx:8079
2-D histogram with a float per channel (see TH1 documentation)
Definition: TH2.h:251
Int_t Fill(Double_t)
Invalid Fill method.
Definition: TH2.cxx:292
virtual Int_t GetBin(Int_t binx, Int_t biny, Int_t binz=0) const
Return Global bin number corresponding to binx,y,z.
Definition: TH2.cxx:928
Book space in a file, create I/O buffers, to fill them, (un)compress them.
Definition: TKey.h:24
virtual const char * GetClassName() const
Definition: TKey.h:72
virtual TObject * ReadObj()
To read a TObject* from the file.
Definition: TKey.cxx:729
UInt_t GetNNodes() const
Definition: BinaryTree.h:86
Implementation of a Decision Tree.
Definition: DecisionTree.h:64
void SetPruneMethod(EPruneMethod m=kCostComplexityPruning)
Definition: DecisionTree.h:139
void SetPruneStrength(Double_t p)
Definition: DecisionTree.h:145
Double_t CheckEvent(const TMVA::Event *, Bool_t UseYesNoLeaf=kFALSE) const
the event e is put into the decision tree (starting at the root node) and the output is NodeType (sig...
UInt_t BuildTree(const EventConstList &eventSample, DecisionTreeNode *node=NULL)
building the decision tree by recursively calling the splitting of one (root-) node into two daughter...
Double_t PruneTree(const EventConstList *validationSample=NULL)
prune (get rid of internal nodes) the Decision tree to avoid overtraining several different pruning m...
void SetNVars(Int_t n)
Definition: DecisionTree.h:193
Virtual base class for all MVA methods.
Definition: MethodBase.h:111
J Friedman's RuleFit method.
Definition: MethodRuleFit.h:47
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
Bool_t GetCutRange(Int_t sel, Double_t &rmin, Double_t &rmax, Bool_t &dormin, Bool_t &dormax) const
get cut range for a given selector
Definition: RuleCut.cxx:176
A class implementing various fits of rule ensembles.
Definition: RuleFit.h:45
void GetRndmSampleEvents(std::vector< const TMVA::Event * > &evevec, UInt_t nevents)
draw a random subsample of the training events without replacement
Definition: RuleFit.cxx:467
Double_t EvalEvent(const Event &e)
evaluate single event
Definition: RuleFit.cxx:432
void SetMethodBase(const MethodBase *rfbase)
set MethodBase
Definition: RuleFit.cxx:151
void InitPtrs(const TMVA::MethodBase *rfbase)
initialize pointers
Definition: RuleFit.cxx:110
void Boost(TMVA::DecisionTree *dt)
Boost the events.
Definition: RuleFit.cxx:339
void ForestStatistics()
summary of statistics of all trees
Definition: RuleFit.cxx:386
static const Int_t randSEED
Definition: RuleFit.h:175
void CalcImportance()
calculates the importance of each rule
Definition: RuleFit.cxx:418
void SetMsgType(EMsgType t)
set the current message type to that of mlog for this class and all other subtools
Definition: RuleFit.cxx:191
void Initialize(const TMVA::MethodBase *rfbase)
initialize the parameters of the RuleFit method and make rules
Definition: RuleFit.cxx:120
virtual ~RuleFit(void)
Definition: RuleFit.cxx:90
void FillVisHistCorr(const Rule *rule, std::vector< TH2F * > &hlist)
help routine to MakeVisHists() - fills for all correlation plots
Definition: RuleFit.cxx:715
std::default_random_engine fRNGEngine
Definition: RuleFit.h:176
void InitNEveEff()
init effective number of events (using event weights)
Definition: RuleFit.cxx:98
void SaveEventWeights()
save event weights - must be done before making the forest
Definition: RuleFit.cxx:309
void FillCut(TH2F *h2, const TMVA::Rule *rule, Int_t vind)
Fill cut.
Definition: RuleFit.cxx:533
void FillLin(TH2F *h2, Int_t vind)
fill lin
Definition: RuleFit.cxx:584
Bool_t GetCorrVars(TString &title, TString &var1, TString &var2)
get first and second variables from title
Definition: RuleFit.cxx:754
void MakeForest()
make a forest of decisiontrees
Definition: RuleFit.cxx:222
const std::vector< const TMVA::DecisionTree * > & GetForest() const
Definition: RuleFit.h:143
void FitCoefficients()
Fit the coefficients for the rule ensemble.
Definition: RuleFit.cxx:409
const MethodBase * GetMethodBase() const
Definition: RuleFit.h:149
void FillCorr(TH2F *h2, const TMVA::Rule *rule, Int_t v1, Int_t v2)
fill rule correlation between vx and vy, weighted with either the importance or the coefficient
Definition: RuleFit.cxx:608
void NormVisHists(std::vector< TH2F * > &hlist)
normalize rule importance hists
Definition: RuleFit.cxx:486
void RestoreEventWeights()
restore the event weights previously stored by SaveEventWeights()
Definition: RuleFit.cxx:321
void MakeVisHists()
this will create histograms visualizing the rule ensemble
Definition: RuleFit.cxx:777
void FillVisHistCut(const Rule *rule, std::vector< TH2F * > &hlist)
help routine to MakeVisHists() - fills for all variables
Definition: RuleFit.cxx:684
void BuildTree(TMVA::DecisionTree *dt)
build the decision tree using fNTreeSample events from fTrainingEventsRndm
Definition: RuleFit.cxx:201
const std::vector< const TMVA::Event * > & GetTrainingEvents() const
Definition: RuleFit.h:137
const MethodRuleFit * GetMethodRuleFit() const
Definition: RuleFit.h:148
void SetTrainingEvents(const std::vector< const TMVA::Event * > &el)
set the training events randomly
Definition: RuleFit.cxx:440
void Copy(const RuleFit &other)
copy method
Definition: RuleFit.cxx:160
const RuleEnsemble & GetRuleEnsemble() const
Definition: RuleFit.h:144
Double_t CalcWeightSum(const std::vector< const TMVA::Event * > *events, UInt_t neve=0)
calculate the sum of weights
Definition: RuleFit.cxx:176
default constructor
Definition: RuleFit.cxx:76
void MakeDebugHists()
this will create histograms intended rather for debugging or for the curious user
Definition: RuleFit.cxx:937
Implementation of a rule.
Definition: Rule.h:48
Double_t GetSupport() const
Definition: Rule.h:140
UInt_t GetNcuts() const
Definition: Rule.h:131
UInt_t GetNumVarsUsed() const
Definition: Rule.h:128
const RuleCut * GetRuleCut() const
Definition: Rule.h:137
Double_t GetCoefficient() const
Definition: Rule.h:139
Double_t GetImportance() const
Definition: Rule.h:143
Double_t RuleDist(const Rule &other, Bool_t useCutValue) const
Definition: Rule.cxx:190
Bool_t ContainsVariable(UInt_t iv) const
check if variable in node
Definition: Rule.cxx:137
Timing information for training and evaluation of MVA methods.
Definition: Timer.h:58
@ kTraining
Definition: Types.h:144
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
Definition: TNamed.cxx:164
virtual const char * GetTitle() const
Returns title of object.
Definition: TNamed.h:48
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
Random number generator class based on M. Matsumoto and T. Nishimura's Mersenne Twister generator.
Definition: TRandom3.h:27
virtual Double_t Uniform(Double_t x1=1)
Returns a uniform deviate on the interval (0, x1).
Definition: TRandom.cxx:635
Basic string class.
Definition: TString.h:131
Ssiz_t Length() const
Definition: TString.h:405
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:687
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
Definition: TString.h:610
TString & Remove(Ssiz_t pos)
Definition: TString.h:668
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:619
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition: TString.h:634
A TTree represents a columnar dataset.
Definition: TTree.h:72
virtual Int_t Fill()
Fill all branches.
Definition: TTree.cxx:4487
TBranch * Branch(const char *name, T *obj, Int_t bufsize=32000, Int_t splitlevel=99)
Add a new branch, and infer the data type from the type of obj being passed.
Definition: TTree.h:341
virtual Int_t Write(const char *name=0, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
Definition: TTree.cxx:9485
static constexpr double second
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:750
Double_t Sqrt(Double_t x)
Definition: TMath.h:681
Short_t Abs(Short_t d)
Definition: TMathBase.h:120
Definition: first.py:1
Definition: tree.py:1