ROOT  6.07/01
Reference Guide
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
RuleFit.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : RuleFit *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * A class implementing various fits of rule ensembles *
12  * *
13  * Authors (alphabetical): *
14  * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
15  * Helge Voss <Helge.Voss@cern.ch> - MPI-KP Heidelberg, Ger. *
16  * *
17  * Copyright (c) 2005: *
18  * CERN, Switzerland *
19  * Iowa State U. *
20  * MPI-K Heidelberg, Germany *
21  * *
22  * Redistribution and use in source and binary forms, with or without *
23  * modification, are permitted according to the terms listed in LICENSE *
24  * (http://tmva.sourceforge.net/LICENSE) *
25  **********************************************************************************/
26 
27 #ifndef ROOT_TMVA_RuleFit
28 #define ROOT_TMVA_RuleFit
29 
#include <algorithm>
#include <random>
#include <vector>

#ifndef ROOT_TMVA_DecisionTree
#include "TMVA/DecisionTree.h"
#endif
#ifndef ROOT_TMVA_RuleEnsemble
#include "TMVA/RuleEnsemble.h"
#endif
#ifndef ROOT_TMVA_RuleFitParams
#include "TMVA/RuleFitParams.h"
#endif
#ifndef ROOT_TMVA_Event
#include "TMVA/Event.h"
#endif
44 
45 namespace TMVA {
46 
47 
48  class MethodBase;
49  class MethodRuleFit;
50  class MsgLogger;
51 
52  class RuleFit {
53 
54  public:
55 
56  // main constructor
57  RuleFit( const TMVA::MethodBase *rfbase );
58 
59  // empty constructor
60  RuleFit( void );
61 
62  virtual ~RuleFit( void );
63 
64  void InitNEveEff();
65  void InitPtrs( const TMVA::MethodBase *rfbase );
66  void Initialize( const TMVA::MethodBase *rfbase );
67 
68  void SetMsgType( EMsgType t );
69 
70  void SetTrainingEvents( const std::vector<const TMVA::Event *> & el );
71 
72  void ReshuffleEvents() { std::random_shuffle(fTrainingEventsRndm.begin(),fTrainingEventsRndm.end()); }
73 
74  void SetMethodBase( const MethodBase *rfbase );
75 
76  // make the forest of trees for rule generation
77  void MakeForest();
78 
79  // build a tree
80  void BuildTree( TMVA::DecisionTree *dt );
81 
82  // save event weights
83  void SaveEventWeights();
84 
85  // restore saved event weights
86  void RestoreEventWeights();
87 
88  // boost events based on the given tree
89  void Boost( TMVA::DecisionTree *dt );
90 
91  // calculate and print some statistics on the given forest
92  void ForestStatistics();
93 
94  // calculate the discriminating variable for the given event
95  Double_t EvalEvent( const Event& e );
96 
97  // calculate sum of
98  Double_t CalcWeightSum( const std::vector<const TMVA::Event *> *events, UInt_t neve=0 );
99 
100  // do the fitting of the coefficients
101  void FitCoefficients();
102 
103  // calculate variable and rule importance from a set of events
104  void CalcImportance();
105 
106  // set usage of linear term
108  // set usage of rules
110  // set usage of linear term
112  // set minimum importance allowed
114  // set minimum rule distance - see RuleEnsemble
116  // set path related parameters
117  void SetGDTau( Double_t t=0.0 ) { fRuleFitParams.SetGDTau(t); }
120  // make visualization histograms
124  void MakeVisHists();
125  void FillVisHistCut(const Rule * rule, std::vector<TH2F *> & hlist);
126  void FillVisHistCorr(const Rule * rule, std::vector<TH2F *> & hlist);
127  void FillCut(TH2F* h2,const TMVA::Rule *rule,Int_t vind);
128  void FillLin(TH2F* h2,Int_t vind);
129  void FillCorr(TH2F* h2,const TMVA::Rule *rule,Int_t v1, Int_t v2);
130  void NormVisHists(std::vector<TH2F *> & hlist);
131  void MakeDebugHists();
132  Bool_t GetCorrVars(TString & title, TString & var1, TString & var2);
133  // accessors
134  UInt_t GetNTreeSample() const { return fNTreeSample; }
135  Double_t GetNEveEff() const { return fNEveEffTrain; } // reweighted number of events = sum(wi)
136  const Event* GetTrainingEvent(UInt_t i) const { return static_cast< const Event *>(fTrainingEvents[i]); }
137  Double_t GetTrainingEventWeight(UInt_t i) const { return fTrainingEvents[i]->GetWeight(); }
138 
139  // const Event* GetTrainingEvent(UInt_t i, UInt_t isub) const { return &(fTrainingEvents[fSubsampleEvents[isub]])[i]; }
140 
141  const std::vector< const TMVA::Event * > & GetTrainingEvents() const { return fTrainingEvents; }
142  // const std::vector< Int_t > & GetSubsampleEvents() const { return fSubsampleEvents; }
143 
144  // void GetSubsampleEvents(Int_t sub, UInt_t & ibeg, UInt_t & iend) const;
145  void GetRndmSampleEvents(std::vector< const TMVA::Event * > & evevec, UInt_t nevents);
146  //
147  const std::vector< const TMVA::DecisionTree *> & GetForest() const { return fForest; }
148  const RuleEnsemble & GetRuleEnsemble() const { return fRuleEnsemble; }
150  const RuleFitParams & GetRuleFitParams() const { return fRuleFitParams; }
152  const MethodRuleFit * GetMethodRuleFit() const { return fMethodRuleFit; }
153  const MethodBase * GetMethodBase() const { return fMethodBase; }
154 
155  private:
156 
157  // copy constructor
158  RuleFit( const RuleFit & other );
159 
160  // copy method
161  void Copy( const RuleFit & other );
162 
163  std::vector<const TMVA::Event *> fTrainingEvents; // all training events
164  std::vector<const TMVA::Event *> fTrainingEventsRndm; // idem, but randomly shuffled
165  std::vector<Double_t> fEventWeights; // original weights of the events - follows fTrainingEvents
166  UInt_t fNTreeSample; // number of events in sub sample = frac*neve
167 
168  Double_t fNEveEffTrain; // reweighted number of events = sum(wi)
169  std::vector< const TMVA::DecisionTree *> fForest; // the input forest of decision trees
170  RuleEnsemble fRuleEnsemble; // the ensemble of rules
171  RuleFitParams fRuleFitParams; // fit rule parameters
172  const MethodRuleFit *fMethodRuleFit; // pointer the method which initialized this RuleFit instance
173  const MethodBase *fMethodBase; // pointer the method base which initialized this RuleFit instance
174  Bool_t fVisHistsUseImp; // if true, use importance as weight; else coef in vis hists
175 
176  mutable MsgLogger* fLogger; // message logger
177  MsgLogger& Log() const { return *fLogger; }
178 
179  static const Int_t randSEED = 0; // set to 1 for debugging purposes or to zero for random seeds
180 
181  ClassDef(RuleFit,0) // Calculations for Friedman's RuleFit method
182  };
183 }
184 
185 #endif
std::vector< const TMVA::Event * > fTrainingEventsRndm
Definition: RuleFit.h:164
void ForestStatistics()
summary of statistics of all trees
Definition: RuleFit.cxx:379
void MakeForest()
make a forest of decision trees
Definition: RuleFit.cxx:214
const RuleFitParams & GetRuleFitParams() const
Definition: RuleFit.h:150
void SetVisHistsUseImp(Bool_t f)
Definition: RuleFit.h:121
const Double_t * v1
Definition: TArcBall.cxx:33
void SetGDTau(Double_t t=0.0)
Definition: RuleFit.h:117
const std::vector< const TMVA::Event * > & GetTrainingEvents() const
Definition: RuleFit.h:141
void CalcImportance()
calculates the importance of each rule
Definition: RuleFit.cxx:411
MsgLogger & Log() const
Definition: RuleFit.h:177
void FillVisHistCorr(const Rule *rule, std::vector< TH2F * > &hlist)
help routine to MakeVisHists() - fills for all correlation plots
Definition: RuleFit.cxx:709
void SetGDTau(Double_t t)
Definition: RuleFitParams.h:90
void SetMsgType(EMsgType t)
set the current message type to that of mlog for this class and all other subtools ...
Definition: RuleFit.cxx:183
const MethodBase * fMethodBase
Definition: RuleFit.h:173
Bool_t GetCorrVars(TString &title, TString &var1, TString &var2)
get first and second variables from title
Definition: RuleFit.cxx:748
void InitNEveEff()
init effective number of events (using event weights)
Definition: RuleFit.cxx:90
void FitCoefficients()
Fit the coefficients for the rule ensemble.
Definition: RuleFit.cxx:402
Basic string class.
Definition: TString.h:137
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
void SetModelFull()
Definition: RuleFit.h:111
std::vector< Double_t > fEventWeights
Definition: RuleFit.h:165
Double_t fNEveEffTrain
Definition: RuleFit.h:168
void SetModelLinear()
Definition: RuleFit.h:107
TFile * f
void SetModelRules()
Definition: RuleFit.h:109
RuleFit(void)
default constructor
Definition: RuleFit.cxx:68
#define ClassDef(name, id)
Definition: Rtypes.h:254
void BuildTree(TMVA::DecisionTree *dt)
build the decision tree using fNTreeSample events from fTrainingEventsRndm
Definition: RuleFit.cxx:193
const MethodBase * GetMethodBase() const
Definition: RuleFit.h:153
void UseImportanceVisHists()
Definition: RuleFit.h:122
int d
Definition: tornado.py:11
void ReshuffleEvents()
Definition: RuleFit.h:72
void SetTrainingEvents(const std::vector< const TMVA::Event * > &el)
set the training events randomly
Definition: RuleFit.cxx:433
void GetRndmSampleEvents(std::vector< const TMVA::Event * > &evevec, UInt_t nevents)
draw a random subsample of the training events without replacement
Definition: RuleFit.cxx:460
TH2D * h2
Definition: fit2dHist.C:45
virtual ~RuleFit(void)
destructor
Definition: RuleFit.cxx:82
RuleEnsemble * GetRuleEnsemblePtr()
Definition: RuleFit.h:149
void SetGDNPathSteps(Int_t n=100)
Definition: RuleFit.h:119
void SetMethodBase(const MethodBase *rfbase)
set MethodBase
Definition: RuleFit.cxx:143
void SetGDNPathSteps(Int_t np)
Definition: RuleFitParams.h:73
UInt_t fNTreeSample
Definition: RuleFit.h:166
const std::vector< const TMVA::DecisionTree * > & GetForest() const
Definition: RuleFit.h:147
TThread * t[5]
Definition: threadsh1.C:13
void RestoreEventWeights()
restore the event weights saved by SaveEventWeights()
Definition: RuleFit.cxx:314
void FillVisHistCut(const Rule *rule, std::vector< TH2F * > &hlist)
help routine to MakeVisHists() - fills for all variables
Definition: RuleFit.cxx:678
RuleFitParams * GetRuleFitParamsPtr()
Definition: RuleFit.h:151
void Copy(const RuleFit &other)
copy method
Definition: RuleFit.cxx:152
TPaveLabel title(3, 27.1, 15, 28.7,"ROOT Environment and Tools")
void FillCorr(TH2F *h2, const TMVA::Rule *rule, Int_t v1, Int_t v2)
fill rule correlation between variables v1 and v2, weighted with either the importance or the coefficient ...
Definition: RuleFit.cxx:602
void MakeDebugHists()
this will create histograms intended for debugging or for the curious user ...
Definition: RuleFit.cxx:932
static const Int_t randSEED
Definition: RuleFit.h:179
2-D histogram with a float per channel (see TH1 documentation)
Definition: TH2.h:256
const RuleEnsemble & GetRuleEnsemble() const
Definition: RuleFit.h:148
EMsgType
Definition: Types.h:61
unsigned int UInt_t
Definition: RtypesCore.h:42
UInt_t GetNTreeSample() const
Definition: RuleFit.h:134
void SetImportanceCut(Double_t minimp=0)
Definition: RuleEnsemble.h:141
void FillLin(TH2F *h2, Int_t vind)
fill the linear term of variable vind into h2
Definition: RuleFit.cxx:578
RuleEnsemble fRuleEnsemble
Definition: RuleFit.h:170
const MethodRuleFit * GetMethodRuleFit() const
Definition: RuleFit.h:152
void Boost(TMVA::DecisionTree *dt)
Boost the events.
Definition: RuleFit.cxx:332
void SetGDPathStep(Double_t s=0.01)
Definition: RuleFit.h:118
void SaveEventWeights()
save event weights - must be done before making the forest
Definition: RuleFit.cxx:302
double Double_t
Definition: RtypesCore.h:55
const MethodRuleFit * fMethodRuleFit
Definition: RuleFit.h:172
void MakeVisHists()
this will create histograms visualizing the rule ensemble
Definition: RuleFit.cxx:771
void InitPtrs(const TMVA::MethodBase *rfbase)
initialize pointers
Definition: RuleFit.cxx:102
Double_t CalcWeightSum(const std::vector< const TMVA::Event * > *events, UInt_t neve=0)
calculate the sum of weights
Definition: RuleFit.cxx:168
void NormVisHists(std::vector< TH2F * > &hlist)
normalize rule importance hists
Definition: RuleFit.cxx:480
Double_t GetTrainingEventWeight(UInt_t i) const
Definition: RuleFit.h:137
void SetRuleMinDist(Double_t d)
Definition: RuleFit.h:115
void FillCut(TH2F *h2, const TMVA::Rule *rule, Int_t vind)
Fill cut.
Definition: RuleFit.cxx:527
Bool_t fVisHistsUseImp
Definition: RuleFit.h:174
RuleFitParams fRuleFitParams
Definition: RuleFit.h:171
Double_t EvalEvent(const Event &e)
evaluate single event
Definition: RuleFit.cxx:425
std::vector< const TMVA::DecisionTree * > fForest
Definition: RuleFit.h:169
void UseCoefficientsVisHists()
Definition: RuleFit.h:123
void Initialize(const TMVA::MethodBase *rfbase)
initialize the parameters of the RuleFit method and make rules
Definition: RuleFit.cxx:112
const Bool_t kTRUE
Definition: Rtypes.h:91
std::vector< const TMVA::Event * > fTrainingEvents
Definition: RuleFit.h:163
const Int_t n
Definition: legend1.C:16
const Event * GetTrainingEvent(UInt_t i) const
Definition: RuleFit.h:136
void SetGDPathStep(Double_t s)
Definition: RuleFitParams.h:76
void SetRuleMinDist(Double_t d)
Definition: RuleEnsemble.h:138
void SetImportanceCut(Double_t minimp=0)
Definition: RuleFit.h:113
MsgLogger * fLogger
Definition: RuleFit.h:176
Double_t GetNEveEff() const
Definition: RuleFit.h:135