Logo ROOT   6.16/01
Reference Guide
MethodRuleFit.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Fredrik Tegenfeldt
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodRuleFit *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Friedman's RuleFit method *
12 * *
13 * Authors (alphabetical): *
14 * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
15 * *
16 * Copyright (c) 2005: *
17 * CERN, Switzerland *
18 * Iowa State U. *
19 * MPI-K Heidelberg, Germany *
20 * *
21 * Redistribution and use in source and binary forms, with or without *
22 * modification, are permitted according to the terms listed in LICENSE *
23 * *
24 **********************************************************************************/
25
26#ifndef ROOT_TMVA_MethodRuleFit
27#define ROOT_TMVA_MethodRuleFit
28
29//////////////////////////////////////////////////////////////////////////
30// //
31// MethodRuleFit //
32// //
33// J Friedman's RuleFit method //
34// //
35//////////////////////////////////////////////////////////////////////////
36
37#include "TMVA/MethodBase.h"
38#include "TMatrixDfwd.h"
39#include "TVectorD.h"
40#include "TMVA/DecisionTree.h"
41#include "TMVA/RuleFit.h"
42
43namespace TMVA {
44
45 class SeparationBase;
46
47 class MethodRuleFit : public MethodBase {
48
49 public:
50
51 MethodRuleFit( const TString& jobName,
52 const TString& methodTitle,
53 DataSetInfo& theData,
54 const TString& theOption = "");
55
56 MethodRuleFit( DataSetInfo& theData,
57 const TString& theWeightFile);
58
59 virtual ~MethodRuleFit( void );
60
61 virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/ );
62
63 // training method
64 void Train( void );
65
67
68 // write weights to file
69 void AddWeightsXMLTo ( void* parent ) const;
70
71 // read weights from file
72 void ReadWeightsFromStream( std::istream& istr );
73 void ReadWeightsFromXML ( void* wghtnode );
74
75 // calculate the MVA value
76 Double_t GetMvaValue( Double_t* err = 0, Double_t* errUpper = 0 );
77
78 // write method specific histos to target file
79 void WriteMonitoringHistosToFile( void ) const;
80
81 // ranking of input variables
82 const Ranking* CreateRanking();
83
84 Bool_t UseBoost() const { return fUseBoost; }
85
86 // accessors
88 const RuleFit* GetRuleFitConstPtr() const { return &fRuleFit; }
89 TDirectory* GetMethodBaseDir() const { return BaseDir(); }
90 const std::vector<TMVA::Event*>& GetTrainingEvents() const { return fEventSample; }
91 const std::vector<TMVA::DecisionTree*>& GetForest() const { return fForest; }
92 Int_t GetNTrees() const { return fNTrees; }
100 Int_t GetNCuts() const { return fNCuts; }
101 //
107 //
109
110 const TString GetRFWorkDir() const { return fRFWorkDir; }
111 Int_t GetRFNrules() const { return fRFNrules; }
113
114 protected:
115
116 // make ROOT-independent C++ class for classifier response (classifier-specific implementation)
117 void MakeClassSpecific( std::ostream&, const TString& ) const;
118
119 void MakeClassRuleCuts( std::ostream& ) const;
120
121 void MakeClassLinear( std::ostream& ) const;
122
123 // get help message text
124 void GetHelpMessage() const;
125
126 // initialize rulefit
127 void Init( void );
128
129 // copy all training events into a stl::vector
130 void InitEventSample( void );
131
132 // initialize monitor ntuple
133 void InitMonitorNtuple();
134
135 void TrainTMVARuleFit();
136 void TrainJFRuleFit();
137
138 private:
139
140 // check variable range and set var to lower or upper if out of range
141 template<typename T>
142 inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax );
143
144 template<typename T>
145 inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef );
146
147 template<typename T>
148 inline Int_t VerifyRange( const T& var, const T& vmin, const T& vmax );
149
150 // the option handling methods
151 void DeclareOptions();
152 void ProcessOptions();
153
154 RuleFit fRuleFit; // RuleFit instance
155 std::vector<TMVA::Event *> fEventSample; // the complete training sample
156 Double_t fSignalFraction; // scalefactor for bkg events to modify initial s/b fraction in training data
157
158 // ntuple
159 TTree *fMonitorNtuple; // pointer to monitor rule ntuple
160 Double_t fNTImportance; // ntuple: rule importance
161 Double_t fNTCoefficient; // ntuple: rule coefficient
162 Double_t fNTSupport; // ntuple: rule support
163 Int_t fNTNcuts; // ntuple: rule number of cuts
164 Int_t fNTNvars; // ntuple: rule number of vars
165 Double_t fNTPtag; // ntuple: rule P(tag)
166 Double_t fNTPss; // ntuple: rule P(tag s, true s)
167 Double_t fNTPsb; // ntuple: rule P(tag s, true b)
168 Double_t fNTPbs; // ntuple: rule P(tag b, true s)
169 Double_t fNTPbb; // ntuple: rule P(tag b, true b)
170 Double_t fNTSSB; // ntuple: rule S/(S+B)
171 Int_t fNTType; // ntuple: rule type (+1->signal, -1->bkg)
172
173 // options
174 TString fRuleFitModuleS;// which rulefit module to use
175 Bool_t fUseRuleFitJF; // if true interface with J.Friedmans RuleFit module
176 TString fRFWorkDir; // working directory from Friedmans module
177 Int_t fRFNrules; // max number of rules (only Friedmans module)
178 Int_t fRFNendnodes; // max number of rules (only Friedmans module)
179 std::vector<DecisionTree *> fForest; // the forest
180 Int_t fNTrees; // number of trees in forest
181 Double_t fTreeEveFrac; // fraction of events used for training each tree
182 SeparationBase *fSepType; // the separation used in node splitting
183 Double_t fMinFracNEve; // min fraction of number events
184 Double_t fMaxFracNEve; // ditto max
185 Int_t fNCuts; // grid used in cut applied in node splitting
186 TString fSepTypeS; // forest generation: separation type - see DecisionTree
187 TString fPruneMethodS; // forest generation: prune method - see DecisionTree
188 TMVA::DecisionTree::EPruneMethod fPruneMethod; // forest generation: method used for pruning - see DecisionTree
189 Double_t fPruneStrength; // forest generation: prune strength - see DecisionTree
190 TString fForestTypeS; // forest generation: how the trees are generated
191 Bool_t fUseBoost; // use boosted events for forest generation
192 //
193 Double_t fGDPathEveFrac; // GD path: fraction of subsamples used for the fitting
194 Double_t fGDValidEveFrac; // GD path: fraction of subsamples used for the fitting
195 Double_t fGDTau; // GD path: def threshold fraction [0..1]
196 Double_t fGDTauPrec; // GD path: precision of estimated tau
197 Double_t fGDTauMin; // GD path: min threshold fraction [0..1]
198 Double_t fGDTauMax; // GD path: max threshold fraction [0..1]
199 UInt_t fGDTauScan; // GD path: number of points to scan
200 Double_t fGDPathStep; // GD path: step size in path
201 Int_t fGDNPathSteps; // GD path: number of steps
202 Double_t fGDErrScale; // GD path: stop
203 Double_t fMinimp; // rule/linear: minimum importance
204 //
205 TString fModelTypeS; // rule ensemble: which model (rule,linear or both)
206 Double_t fRuleMinDist; // rule min distance - see RuleEnsemble
207 Double_t fLinQuantile; // quantile cut to remove outliers - see RuleEnsemble
208
209 ClassDef(MethodRuleFit,0); // Friedman's RuleFit method
210 };
211
212} // namespace TMVA
213
214
215//_______________________________________________________________________
216template<typename T>
217inline Int_t TMVA::MethodRuleFit::VerifyRange( const T& var, const T& vmin, const T& vmax )
218{
219 // check range and return +1 if above, -1 if below or 0 if inside
220 if (var>vmax) return 1;
221 if (var<vmin) return -1;
222 return 0;
223}
224
225//_______________________________________________________________________
226template<typename T>
227inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax )
228{
229 // verify range and print out message
230 // if outside range, set to closest limit
231 Int_t dir = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
232 Bool_t modif=kFALSE;
233 if (dir==1) {
234 modif = kTRUE;
235 var=vmax;
236 }
237 if (dir==-1) {
238 modif = kTRUE;
239 var=vmin;
240 }
241 if (modif) {
242 mlog << kWARNING << "Option <" << varstr << "> " << (dir==1 ? "above":"below") << " allowed range. Reset to new value = " << var << Endl;
243 }
244 return modif;
245}
246
247//_______________________________________________________________________
248template<typename T>
249inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef )
250{
251 // verify range and print out message
252 // if outside range, set to given default value
253 Int_t dir = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
254 Bool_t modif=kFALSE;
255 if (dir!=0) {
256 modif = kTRUE;
257 var=vdef;
258 }
259 if (modif) {
260 mlog << kWARNING << "Option <" << varstr << "> " << (dir==1 ? "above":"below") << " allowed range. Reset to default value = " << var << Endl;
261 }
262 return modif;
263}
264
265
266#endif // MethodRuleFit_H
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define ClassDef(name, id)
Definition: Rtypes.h:324
int type
Definition: TGX11.cxx:120
Describe directory structure in memory.
Definition: TDirectory.h:34
Class that contains all the data information.
Definition: DataSetInfo.h:60
Virtual base Class for all MVA method.
Definition: MethodBase.h:109
TDirectory * BaseDir() const
returns the ROOT directory where info/histograms etc. of the corresponding MVA method instance are stored.
virtual void ReadWeightsFromStream(std::istream &)=0
J Friedman's RuleFit method.
Definition: MethodRuleFit.h:47
Double_t GetLinQuantile() const
RuleFit * GetRuleFitPtr()
Definition: MethodRuleFit.h:87
const std::vector< TMVA::Event * > & GetTrainingEvents() const
Definition: MethodRuleFit.h:90
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
returns MVA value for given event
Int_t GetRFNendnodes() const
Double_t GetMinFracNEve() const
Definition: MethodRuleFit.h:98
Double_t GetGDPathEveFrac() const
Int_t GetNCuts() const
void MakeClassSpecific(std::ostream &, const TString &) const
write specific classifier response
TMVA::DecisionTree::EPruneMethod fPruneMethod
std::vector< DecisionTree * > fForest
Double_t GetGDErrScale() const
Double_t GetGDValidEveFrac() const
Int_t GetRFNrules() const
const TString GetRFWorkDir() const
void MakeClassLinear(std::ostream &) const
print out the linear terms
void GetHelpMessage() const
get help message text
std::vector< TMVA::Event * > fEventSample
void TrainJFRuleFit()
training of rules using Jerome Friedman's implementation
Double_t GetPruneStrength() const
Definition: MethodRuleFit.h:97
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
RuleFit can handle classification with 2 classes.
void ProcessOptions()
process the options specified by the user
const SeparationBase * GetSeparationBaseConst() const
Definition: MethodRuleFit.h:94
TMVA::DecisionTree::EPruneMethod GetPruneMethod() const
Definition: MethodRuleFit.h:96
void ReadWeightsFromStream(std::istream &istr)
read rules from an std::istream
void AddWeightsXMLTo(void *parent) const
add the rules to XML node
void InitEventSample(void)
write all Events from the Tree into a vector of Events, which are more easily manipulated.
void MakeClassRuleCuts(std::ostream &) const
print out the rule cuts
TDirectory * GetMethodBaseDir() const
Definition: MethodRuleFit.h:89
const std::vector< TMVA::DecisionTree * > & GetForest() const
Definition: MethodRuleFit.h:91
void InitMonitorNtuple()
initialize the monitoring ntuple
virtual ~MethodRuleFit(void)
destructor
void Init(void)
default initialization
Double_t GetGDPathStep() const
void WriteMonitoringHistosToFile(void) const
write special monitoring histograms to file (here ntuple)
Double_t GetTreeEveFrac() const
Definition: MethodRuleFit.h:93
MethodRuleFit(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
standard constructor
void ReadWeightsFromXML(void *wghtnode)
read rules from XML node
SeparationBase * fSepType
const RuleFit * GetRuleFitConstPtr() const
Definition: MethodRuleFit.h:88
void DeclareOptions()
define the options (their key words) that can be set in the option string, i.e. the known options.
Double_t GetMaxFracNEve() const
Definition: MethodRuleFit.h:99
Bool_t VerifyRange(MsgLogger &mlog, const char *varstr, T &var, const T &vmin, const T &vmax)
Int_t GetNTrees() const
Definition: MethodRuleFit.h:92
Bool_t UseBoost() const
Definition: MethodRuleFit.h:84
SeparationBase * GetSeparationBase() const
Definition: MethodRuleFit.h:95
Int_t GetGDNPathSteps() const
const Ranking * CreateRanking()
computes ranking of input variables
void TrainTMVARuleFit()
training of rules using TMVA implementation
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
Ranking for variables in method (implementation)
Definition: Ranking.h:48
A class implementing various fits of rule ensembles.
Definition: RuleFit.h:45
An interface to calculate the "SeparationGain" for different separation criteria used in various training algorithms.
EAnalysisType
Definition: Types.h:127
Basic string class.
Definition: TString.h:131
A TTree object has a header with a name and a title.
Definition: TTree.h:71
double T(double x)
Definition: ChebyshevPol.h:34
Abstract ClassifierFactory template that handles arbitrary types.
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158