Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodRuleFit.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Fredrik Tegenfeldt
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodRuleFit *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Friedman's RuleFit method *
12 * *
13 * Authors (alphabetical): *
14 * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
15 * *
16 * Copyright (c) 2005: *
17 * CERN, Switzerland *
18 * Iowa State U. *
19 * MPI-K Heidelberg, Germany *
20 * *
21 * Redistribution and use in source and binary forms, with or without *
22 * modification, are permitted according to the terms listed in LICENSE *
23 * *
24 **********************************************************************************/
25
26#ifndef ROOT_TMVA_MethodRuleFit
27#define ROOT_TMVA_MethodRuleFit
28
29//////////////////////////////////////////////////////////////////////////
30// //
31// MethodRuleFit //
32// //
33// J Friedman's RuleFit method //
34// //
35//////////////////////////////////////////////////////////////////////////
36
37#include "TMVA/MethodBase.h"
38#include "TMatrixDfwd.h"
39#include "TVectorD.h"
40#include "TMVA/DecisionTree.h"
41#include "TMVA/RuleFit.h"
42#include <vector>
43
44namespace TMVA {
45
46 class SeparationBase;
47
 // MethodRuleFit: TMVA method class for Friedman's RuleFit, derived from MethodBase.
 // The class owns a RuleFit instance (fRuleFit), a vector of training events
 // (fEventSample) and a vector of decision trees (fForest), plus the option
 // values declared/processed through DeclareOptions()/ProcessOptions().
 // Two training entry points are declared: TrainTMVARuleFit() and TrainJFRuleFit().
 // NOTE(review): this listing omits some accessor lines of the original header
 // (gaps in the listing numbers below); do not treat it as the full interface.
48 class MethodRuleFit : public MethodBase {
49
50 public:
51
 // constructor taking job name, method title, dataset info and an option string
52 MethodRuleFit( const TString& jobName,
53 const TString& methodTitle,
54 DataSetInfo& theData,
55 const TString& theOption = "");
56
 // constructor taking dataset info and a weight file name
57 MethodRuleFit( DataSetInfo& theData,
58 const TString& theWeightFile);
59
60 virtual ~MethodRuleFit( void );
61
 // reports whether the given analysis type / number of classes is supported
62 virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/ );
63
64 // training method
65 void Train( void );
66
68
69 // write weights to file (as XML below the given parent node)
70 void AddWeightsXMLTo ( void* parent ) const;
71
72 // read weights from file (plain stream or XML node)
73 void ReadWeightsFromStream( std::istream& istr );
74 void ReadWeightsFromXML ( void* wghtnode );
75
76 // calculate the MVA value
77 Double_t GetMvaValue( Double_t* err = 0, Double_t* errUpper = 0 );
78
79 // write method specific histos to target file
80 void WriteMonitoringHistosToFile( void ) const;
81
82 // ranking of input variables
83 const Ranking* CreateRanking();
84
 // true if boosted events are used for forest generation (returns fUseBoost)
85 Bool_t UseBoost() const { return fUseBoost; }
86
87 // accessors (each returns the corresponding data member or base-class value)
89 const RuleFit* GetRuleFitConstPtr() const { return &fRuleFit; }
90 TDirectory* GetMethodBaseDir() const { return BaseDir(); }
91 const std::vector<TMVA::Event*>& GetTrainingEvents() const { return fEventSample; }
92 const std::vector<TMVA::DecisionTree*>& GetForest() const { return fForest; }
93 Int_t GetNTrees() const { return fNTrees; }
101 Int_t GetNCuts() const { return fNCuts; }
102 //
108 //
110
 // accessors for the options marked "only Friedmans module" below
111 const TString GetRFWorkDir() const { return fRFWorkDir; }
112 Int_t GetRFNrules() const { return fRFNrules; }
114
115 protected:
116
117 // make ROOT-independent C++ class for classifier response (classifier-specific implementation)
118 void MakeClassSpecific( std::ostream&, const TString& ) const;
119
 // write the rule cuts to the given stream
120 void MakeClassRuleCuts( std::ostream& ) const;
121
 // write the linear terms to the given stream
122 void MakeClassLinear( std::ostream& ) const;
123
124 // get help message text
125 void GetHelpMessage() const;
126
127 // initialize rulefit
128 void Init( void );
129
130 // copy all training events into a std::vector
131 void InitEventSample( void );
132
133 // initialize monitor ntuple
134 void InitMonitorNtuple();
135
 // training using the TMVA implementation / using J. Friedman's implementation
136 void TrainTMVARuleFit();
137 void TrainJFRuleFit();
138
139 private:
140
141 // check variable range and set var to lower or upper limit if out of range;
 // logs a warning on mlog and returns kTRUE if var was modified
142 template<typename T>
143 inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax );
144
 // same check, but an out-of-range var is reset to the default value vdef
145 template<typename T>
146 inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef );
147
 // pure range check: returns +1 if var>vmax, -1 if var<vmin, 0 otherwise
148 template<typename T>
149 inline Int_t VerifyRange( const T& var, const T& vmin, const T& vmax );
150
151 // the option handling methods
152 void DeclareOptions();
153 void ProcessOptions();
154
155 RuleFit fRuleFit; // RuleFit instance
156 std::vector<TMVA::Event *> fEventSample; // the complete training sample
157 Double_t fSignalFraction; // scalefactor for bkg events to modify initial s/b fraction in training data
158
159 // ntuple
160 TTree *fMonitorNtuple; // pointer to monitor rule ntuple
161 Double_t fNTImportance; // ntuple: rule importance
162 Double_t fNTCoefficient; // ntuple: rule coefficient
163 Double_t fNTSupport; // ntuple: rule support
164 Int_t fNTNcuts; // ntuple: rule number of cuts
165 Int_t fNTNvars; // ntuple: rule number of vars
166 Double_t fNTPtag; // ntuple: rule P(tag)
167 Double_t fNTPss; // ntuple: rule P(tag s, true s)
168 Double_t fNTPsb; // ntuple: rule P(tag s, true b)
169 Double_t fNTPbs; // ntuple: rule P(tag b, true s)
170 Double_t fNTPbb; // ntuple: rule P(tag b, true b)
171 Double_t fNTSSB; // ntuple: rule S/(S+B)
172 Int_t fNTType; // ntuple: rule type (+1->signal, -1->bkg)
173
174 // options
175 TString fRuleFitModuleS;// which rulefit module to use
176 Bool_t fUseRuleFitJF; // if true interface with J.Friedmans RuleFit module
177 TString fRFWorkDir; // working directory from Friedmans module
178 Int_t fRFNrules; // max number of rules (only Friedmans module)
179 Int_t fRFNendnodes; // max number of end nodes (only Friedmans module) - TODO confirm; old comment was a copy of fRFNrules
180 std::vector<DecisionTree *> fForest; // the forest
181 Int_t fNTrees; // number of trees in forest
182 Double_t fTreeEveFrac; // fraction of events used for training each tree
183 SeparationBase *fSepType; // the separation used in node splitting
184 Double_t fMinFracNEve; // min fraction of number events
185 Double_t fMaxFracNEve; // max fraction of number events
186 Int_t fNCuts; // grid used in cut applied in node splitting
187 TString fSepTypeS; // forest generation: separation type - see DecisionTree
188 TString fPruneMethodS; // forest generation: prune method - see DecisionTree
189 TMVA::DecisionTree::EPruneMethod fPruneMethod; // forest generation: method used for pruning - see DecisionTree
190 Double_t fPruneStrength; // forest generation: prune strength - see DecisionTree
191 TString fForestTypeS; // forest generation: how the trees are generated
192 Bool_t fUseBoost; // use boosted events for forest generation
193 //
194 Double_t fGDPathEveFrac; // GD path: fraction of subsamples used for the fitting
195 Double_t fGDValidEveFrac; // GD path: fraction of subsamples used for validation (per name) - TODO confirm; old comment was a copy of fGDPathEveFrac
196 Double_t fGDTau; // GD path: def threshold fraction [0..1]
197 Double_t fGDTauPrec; // GD path: precision of estimated tau
198 Double_t fGDTauMin; // GD path: min threshold fraction [0..1]
199 Double_t fGDTauMax; // GD path: max threshold fraction [0..1]
200 UInt_t fGDTauScan; // GD path: number of points to scan
201 Double_t fGDPathStep; // GD path: step size in path
202 Int_t fGDNPathSteps; // GD path: number of steps
203 Double_t fGDErrScale; // GD path: error scale used for the stop criterion - TODO confirm exact rule
204 Double_t fMinimp; // rule/linear: minimum importance
205 //
206 TString fModelTypeS; // rule ensemble: which model (rule,linear or both)
207 Double_t fRuleMinDist; // rule min distance - see RuleEnsemble
208 Double_t fLinQuantile; // quantile cut to remove outliers - see RuleEnsemble
209
210 ClassDef(MethodRuleFit,0); // Friedman's RuleFit method
211 };
212
213} // namespace TMVA
214
215
216//_______________________________________________________________________
217template<typename T>
218inline Int_t TMVA::MethodRuleFit::VerifyRange( const T& var, const T& vmin, const T& vmax )
219{
220 // check range and return +1 if above, -1 if below or 0 if inside
221 if (var>vmax) return 1;
222 if (var<vmin) return -1;
223 return 0;
224}
225
226//_______________________________________________________________________
227template<typename T>
228inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax )
229{
230 // verify range and print out message
231 // if outside range, set to closest limit
232 Int_t dir = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
233 Bool_t modif=kFALSE;
234 if (dir==1) {
235 modif = kTRUE;
236 var=vmax;
237 }
238 if (dir==-1) {
239 modif = kTRUE;
240 var=vmin;
241 }
242 if (modif) {
243 mlog << kWARNING << "Option <" << varstr << "> " << (dir==1 ? "above":"below") << " allowed range. Reset to new value = " << var << Endl;
244 }
245 return modif;
246}
247
248//_______________________________________________________________________
249template<typename T>
250inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef )
251{
252 // verify range and print out message
253 // if outside range, set to given default value
254 Int_t dir = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
255 Bool_t modif=kFALSE;
256 if (dir!=0) {
257 modif = kTRUE;
258 var=vdef;
259 }
260 if (modif) {
261 mlog << kWARNING << "Option <" << varstr << "> " << (dir==1 ? "above":"below") << " allowed range. Reset to default value = " << var << Endl;
262 }
263 return modif;
264}
265
266
267#endif // ROOT_TMVA_MethodRuleFit
const Bool_t kFALSE
Definition RtypesCore.h:101
bool Bool_t
Definition RtypesCore.h:63
double Double_t
Definition RtypesCore.h:59
const Bool_t kTRUE
Definition RtypesCore.h:100
#define ClassDef(name, id)
Definition Rtypes.h:325
int type
Definition TGX11.cxx:121
Describe directory structure in memory.
Definition TDirectory.h:45
Class that contains all the data information.
Definition DataSetInfo.h:62
Virtual base Class for all MVA method.
Definition MethodBase.h:111
TDirectory * BaseDir() const
returns the ROOT directory where info/histograms etc of the corresponding MVA method instance are sto...
virtual void ReadWeightsFromStream(std::istream &)=0
J Friedman's RuleFit method.
Double_t GetLinQuantile() const
RuleFit * GetRuleFitPtr()
const std::vector< TMVA::Event * > & GetTrainingEvents() const
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
returns MVA value for given event
Int_t GetRFNendnodes() const
Double_t GetMinFracNEve() const
Double_t GetGDPathEveFrac() const
Int_t GetNCuts() const
void MakeClassSpecific(std::ostream &, const TString &) const
write specific classifier response
TMVA::DecisionTree::EPruneMethod fPruneMethod
std::vector< DecisionTree * > fForest
Double_t GetGDErrScale() const
Double_t GetGDValidEveFrac() const
Int_t GetRFNrules() const
const TString GetRFWorkDir() const
void MakeClassLinear(std::ostream &) const
print out the linear terms
void GetHelpMessage() const
get help message text
std::vector< TMVA::Event * > fEventSample
void TrainJFRuleFit()
training of rules using Jerome Friedmans implementation
Double_t GetPruneStrength() const
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
RuleFit can handle classification with 2 classes.
void ProcessOptions()
process the options specified by the user
const SeparationBase * GetSeparationBaseConst() const
TMVA::DecisionTree::EPruneMethod GetPruneMethod() const
void ReadWeightsFromStream(std::istream &istr)
read rules from an std::istream
void AddWeightsXMLTo(void *parent) const
add the rules to XML node
void InitEventSample(void)
write all Events from the Tree into a vector of Events, which is more easily manipulated.
void MakeClassRuleCuts(std::ostream &) const
print out the rule cuts
TDirectory * GetMethodBaseDir() const
const std::vector< TMVA::DecisionTree * > & GetForest() const
void InitMonitorNtuple()
initialize the monitoring ntuple
virtual ~MethodRuleFit(void)
destructor
void Init(void)
default initialization
Double_t GetGDPathStep() const
void WriteMonitoringHistosToFile(void) const
write special monitoring histograms to file (here ntuple)
Double_t GetTreeEveFrac() const
void ReadWeightsFromXML(void *wghtnode)
read rules from XML node
SeparationBase * fSepType
const RuleFit * GetRuleFitConstPtr() const
void DeclareOptions()
define the options (their key words) that can be set in the option string; see the known options.
Double_t GetMaxFracNEve() const
Bool_t VerifyRange(MsgLogger &mlog, const char *varstr, T &var, const T &vmin, const T &vmax)
Int_t GetNTrees() const
Bool_t UseBoost() const
SeparationBase * GetSeparationBase() const
Int_t GetGDNPathSteps() const
const Ranking * CreateRanking()
computes ranking of input variables
void TrainTMVARuleFit()
training of rules using TMVA implementation
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
Ranking for variables in method (implementation)
Definition Ranking.h:48
A class implementing various fits of rule ensembles.
Definition RuleFit.h:46
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
Basic string class.
Definition TString.h:136
A TTree represents a columnar dataset.
Definition TTree.h:79
create variable transformations
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148