Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodRuleFit.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Fredrik Tegenfeldt
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodRuleFit *
8 * *
9 * *
10 * Description: *
11 * Friedman's RuleFit method *
12 * *
13 * Authors (alphabetical): *
14 * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
15 * *
16 * Copyright (c) 2005: *
17 * CERN, Switzerland *
18 * Iowa State U. *
19 * MPI-K Heidelberg, Germany *
20 * *
21 * Redistribution and use in source and binary forms, with or without *
22 * modification, are permitted according to the terms listed in LICENSE *
23 * *
24 **********************************************************************************/
25
26#ifndef ROOT_TMVA_MethodRuleFit
27#define ROOT_TMVA_MethodRuleFit
28
29//////////////////////////////////////////////////////////////////////////
30// //
31// MethodRuleFit //
32// //
33// J Friedman's RuleFit method //
34// //
35//////////////////////////////////////////////////////////////////////////
36
37#include "TMVA/MethodBase.h"
38#include "TMatrixDfwd.h"
39#include "TVectorD.h"
40#include "TMVA/DecisionTree.h"
41#include "TMVA/RuleFit.h"
42#include <vector>
43
44namespace TMVA {
45
46 class SeparationBase;
47
48 class MethodRuleFit : public MethodBase {
49
50 public:
51
52 MethodRuleFit( const TString& jobName,
53 const TString& methodTitle,
54 DataSetInfo& theData,
55 const TString& theOption = "");
56
57 MethodRuleFit( DataSetInfo& theData,
58 const TString& theWeightFile);
59
60 virtual ~MethodRuleFit( void );
61
62 virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/ );
63
64 // training method
65 void Train( void );
66
68
69 // write weights to file
70 void AddWeightsXMLTo ( void* parent ) const;
71
72 // read weights from file
73 void ReadWeightsFromStream( std::istream& istr );
74 void ReadWeightsFromXML ( void* wghtnode );
75
76 // calculate the MVA value
77 Double_t GetMvaValue( Double_t* err = nullptr, Double_t* errUpper = nullptr );
78
79 // write method specific histos to target file
80 void WriteMonitoringHistosToFile( void ) const;
81
82 // ranking of input variables
83 const Ranking* CreateRanking();
84
85 Bool_t UseBoost() const { return fUseBoost; }
86
87 // accessors
89 const RuleFit* GetRuleFitConstPtr() const { return &fRuleFit; }
90 TDirectory* GetMethodBaseDir() const { return BaseDir(); }
91 const std::vector<TMVA::Event*>& GetTrainingEvents() const { return fEventSample; }
92 const std::vector<TMVA::DecisionTree*>& GetForest() const { return fForest; }
93 Int_t GetNTrees() const { return fNTrees; }
101 Int_t GetNCuts() const { return fNCuts; }
102 //
108 //
110
111 const TString GetRFWorkDir() const { return fRFWorkDir; }
112 Int_t GetRFNrules() const { return fRFNrules; }
114
115 protected:
116
117 // make ROOT-independent C++ class for classifier response (classifier-specific implementation)
118 void MakeClassSpecific( std::ostream&, const TString& ) const;
119
120 void MakeClassRuleCuts( std::ostream& ) const;
121
122 void MakeClassLinear( std::ostream& ) const;
123
124 // get help message text
125 void GetHelpMessage() const;
126
127 // initialize rulefit
128 void Init( void );
129
130 // copy all training events into a stl::vector
131 void InitEventSample( void );
132
133 // initialize monitor ntuple
134 void InitMonitorNtuple();
135
136 void TrainTMVARuleFit();
137 void TrainJFRuleFit();
138
139 private:
140
141 // check variable range and set var to lower or upper if out of range
142 template<typename T>
143 inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax );
144
145 template<typename T>
146 inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef );
147
148 template<typename T>
149 inline Int_t VerifyRange( const T& var, const T& vmin, const T& vmax );
150
151 // the option handling methods
152 void DeclareOptions();
153 void ProcessOptions();
154
155 RuleFit fRuleFit; ///< RuleFit instance
156 std::vector<TMVA::Event *> fEventSample; ///< the complete training sample
157 Double_t fSignalFraction; ///< scalefactor for bkg events to modify initial s/b fraction in training data
158
159 // ntuple
160 TTree *fMonitorNtuple; ///< pointer to monitor rule ntuple
161 Double_t fNTImportance; ///< ntuple: rule importance
162 Double_t fNTCoefficient; ///< ntuple: rule coefficient
163 Double_t fNTSupport; ///< ntuple: rule support
164 Int_t fNTNcuts; ///< ntuple: rule number of cuts
165 Int_t fNTNvars; ///< ntuple: rule number of vars
166 Double_t fNTPtag; ///< ntuple: rule P(tag)
167 Double_t fNTPss; ///< ntuple: rule P(tag s, true s)
168 Double_t fNTPsb; ///< ntuple: rule P(tag s, true b)
169 Double_t fNTPbs; ///< ntuple: rule P(tag b, true s)
170 Double_t fNTPbb; ///< ntuple: rule P(tag b, true b)
171 Double_t fNTSSB; ///< ntuple: rule S/(S+B)
172 Int_t fNTType; ///< ntuple: rule type (+1->signal, -1->bkg)
173
174 // options
175 TString fRuleFitModuleS;///< which rulefit module to use
176 Bool_t fUseRuleFitJF; ///< if true interface with J.Friedmans RuleFit module
177 TString fRFWorkDir; ///< working directory from Friedmans module
178 Int_t fRFNrules; ///< max number of rules (only Friedmans module)
179 Int_t fRFNendnodes; ///< max number of rules (only Friedmans module)
180 std::vector<DecisionTree *> fForest; ///< the forest
181 Int_t fNTrees; ///< number of trees in forest
182 Double_t fTreeEveFrac; ///< fraction of events used for training each tree
183 SeparationBase *fSepType; ///< the separation used in node splitting
184 Double_t fMinFracNEve; ///< min fraction of number events
185 Double_t fMaxFracNEve; ///< ditto max
186 Int_t fNCuts; ///< grid used in cut applied in node splitting
187 TString fSepTypeS; ///< forest generation: separation type - see DecisionTree
188 TString fPruneMethodS; ///< forest generation: prune method - see DecisionTree
189 TMVA::DecisionTree::EPruneMethod fPruneMethod; ///< forest generation: method used for pruning - see DecisionTree
190 Double_t fPruneStrength; ///< forest generation: prune strength - see DecisionTree
191 TString fForestTypeS; ///< forest generation: how the trees are generated
192 Bool_t fUseBoost; ///< use boosted events for forest generation
193 //
194 Double_t fGDPathEveFrac; ///< GD path: fraction of subsamples used for the fitting
195 Double_t fGDValidEveFrac; ///< GD path: fraction of subsamples used for the fitting
196 Double_t fGDTau; ///< GD path: def threshold fraction [0..1]
197 Double_t fGDTauPrec; ///< GD path: precision of estimated tau
198 Double_t fGDTauMin; ///< GD path: min threshold fraction [0..1]
199 Double_t fGDTauMax; ///< GD path: max threshold fraction [0..1]
200 UInt_t fGDTauScan; ///< GD path: number of points to scan
201 Double_t fGDPathStep; ///< GD path: step size in path
202 Int_t fGDNPathSteps; ///< GD path: number of steps
203 Double_t fGDErrScale; ///< GD path: stop
204 Double_t fMinimp; ///< rule/linear: minimum importance
205 //
206 TString fModelTypeS; ///< rule ensemble: which model (rule,linear or both)
207 Double_t fRuleMinDist; ///< rule min distance - see RuleEnsemble
208 Double_t fLinQuantile; ///< quantile cut to remove outliers - see RuleEnsemble
209
210 ClassDef(MethodRuleFit,0); // Friedman's RuleFit method
211 };
212
213} // namespace TMVA
214
215
216//_______________________________________________________________________
217template<typename T>
218inline Int_t TMVA::MethodRuleFit::VerifyRange( const T& var, const T& vmin, const T& vmax )
219{
220 // check range and return +1 if above, -1 if below or 0 if inside
221 if (var>vmax) return 1;
222 if (var<vmin) return -1;
223 return 0;
224}
225
226//_______________________________________________________________________
227template<typename T>
228inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax )
229{
230 // verify range and print out message
231 // if outside range, set to closest limit
232 Int_t dir = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
233 Bool_t modif=kFALSE;
234 if (dir==1) {
235 modif = kTRUE;
236 var=vmax;
237 }
238 if (dir==-1) {
239 modif = kTRUE;
240 var=vmin;
241 }
242 if (modif) {
243 mlog << kWARNING << "Option <" << varstr << "> " << (dir==1 ? "above":"below") << " allowed range. Reset to new value = " << var << Endl;
244 }
245 return modif;
246}
247
248//_______________________________________________________________________
249template<typename T>
250inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef )
251{
252 // verify range and print out message
253 // if outside range, set to given default value
254 Int_t dir = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
255 Bool_t modif=kFALSE;
256 if (dir!=0) {
257 modif = kTRUE;
258 var=vdef;
259 }
260 if (modif) {
261 mlog << kWARNING << "Option <" << varstr << "> " << (dir==1 ? "above":"below") << " allowed range. Reset to default value = " << var << Endl;
262 }
263 return modif;
264}
265
266
267#endif // MethodRuleFit_H
bool Bool_t
Definition RtypesCore.h:63
constexpr Bool_t kFALSE
Definition RtypesCore.h:101
constexpr Bool_t kTRUE
Definition RtypesCore.h:100
#define ClassDef(name, id)
Definition Rtypes.h:337
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
Describe directory structure in memory.
Definition TDirectory.h:45
Class that contains all the data information.
Definition DataSetInfo.h:62
Virtual base Class for all MVA method.
Definition MethodBase.h:111
TDirectory * BaseDir() const
returns the ROOT directory where info/histograms etc of the corresponding MVA method instance are sto...
virtual void ReadWeightsFromStream(std::istream &)=0
J Friedman's RuleFit method.
Double_t GetLinQuantile() const
RuleFit fRuleFit
RuleFit instance.
RuleFit * GetRuleFitPtr()
const std::vector< TMVA::Event * > & GetTrainingEvents() const
UInt_t fGDTauScan
GD path: number of points to scan.
Double_t fNTPss
ntuple: rule P(tag s, true s)
TString fForestTypeS
forest generation: how the trees are generated
Int_t GetRFNendnodes() const
Double_t GetMinFracNEve() const
Double_t GetGDPathEveFrac() const
Double_t fMinimp
rule/linear: minimum importance
TString fRuleFitModuleS
which rulefit module to use
Double_t fLinQuantile
quantile cut to remove outliers - see RuleEnsemble
Int_t GetNCuts() const
void MakeClassSpecific(std::ostream &, const TString &) const
write specific classifier response
Double_t fMinFracNEve
min fraction of number of events
Int_t fNTType
ntuple: rule type (+1->signal, -1->bkg)
Bool_t fUseRuleFitJF
if true interface with J.Friedmans RuleFit module
Double_t fGDTauMax
GD path: max threshold fraction [0..1].
Double_t fGDPathEveFrac
GD path: fraction of subsamples used for the fitting.
Bool_t fUseBoost
use boosted events for forest generation
TMVA::DecisionTree::EPruneMethod fPruneMethod
forest generation: method used for pruning - see DecisionTree
std::vector< DecisionTree * > fForest
the forest
Double_t GetGDErrScale() const
Double_t fMaxFracNEve
max fraction of number of events
Double_t GetGDValidEveFrac() const
Int_t GetRFNrules() const
TString fRFWorkDir
working directory from Friedmans module
Double_t fTreeEveFrac
fraction of events used for training each tree
const TString GetRFWorkDir() const
void MakeClassLinear(std::ostream &) const
print out the linear terms
Double_t fNTPsb
ntuple: rule P(tag s, true b)
void GetHelpMessage() const
get help message text
Int_t fNTNvars
ntuple: rule number of vars
Int_t fGDNPathSteps
GD path: number of steps.
std::vector< TMVA::Event * > fEventSample
the complete training sample
Double_t fNTPbb
ntuple: rule P(tag b, true b)
Double_t fNTSSB
ntuple: rule S/(S+B)
void TrainJFRuleFit()
training of rules using Jerome Friedman's implementation
Double_t GetPruneStrength() const
Int_t fNTNcuts
ntuple: rule number of cuts
TString fModelTypeS
rule ensemble: which model (rule,linear or both)
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
RuleFit can handle classification with 2 classes.
Double_t fNTImportance
ntuple: rule importance
void ProcessOptions()
process the options specified by the user
const SeparationBase * GetSeparationBaseConst() const
TMVA::DecisionTree::EPruneMethod GetPruneMethod() const
void ReadWeightsFromStream(std::istream &istr)
read rules from an std::istream
void AddWeightsXMLTo(void *parent) const
add the rules to XML node
Double_t fGDValidEveFrac
GD path: fraction of subsamples used for the validation.
void InitEventSample(void)
write all Events from the Tree into a vector of Events, that are more easily manipulated.
void MakeClassRuleCuts(std::ostream &) const
print out the rule cuts
Double_t fNTCoefficient
ntuple: rule coefficient
TDirectory * GetMethodBaseDir() const
const std::vector< TMVA::DecisionTree * > & GetForest() const
TString fSepTypeS
forest generation: separation type - see DecisionTree
void InitMonitorNtuple()
initialize the monitoring ntuple
Int_t fNTrees
number of trees in forest
Double_t fNTPtag
ntuple: rule P(tag)
virtual ~MethodRuleFit(void)
destructor
Int_t fRFNendnodes
number of end nodes per tree (only Friedman's module)
void Init(void)
default initialization
Double_t GetGDPathStep() const
Double_t fNTPbs
ntuple: rule P(tag b, true s)
TTree * fMonitorNtuple
pointer to monitor rule ntuple
void WriteMonitoringHistosToFile(void) const
write special monitoring histograms to file (here ntuple)
Double_t GetTreeEveFrac() const
Double_t fNTSupport
ntuple: rule support
Double_t fGDTauMin
GD path: min threshold fraction [0..1].
void ReadWeightsFromXML(void *wghtnode)
read rules from XML node
Double_t fGDTau
GD path: def threshold fraction [0..1].
SeparationBase * fSepType
the separation used in node splitting
Double_t fGDPathStep
GD path: step size in path.
Double_t fGDTauPrec
GD path: precision of estimated tau.
Double_t fPruneStrength
forest generation: prune strength - see DecisionTree
Int_t fNCuts
grid used in cut applied in node splitting
const RuleFit * GetRuleFitConstPtr() const
Double_t fGDErrScale
GD path: error scale used as stopping criterion.
void DeclareOptions()
define the options (their key words) that can be set in the option string know options.
TString fPruneMethodS
forest generation: prune method - see DecisionTree
Double_t GetMaxFracNEve() const
Double_t fRuleMinDist
rule min distance - see RuleEnsemble
Int_t fRFNrules
max number of rules (only Friedmans module)
Bool_t VerifyRange(MsgLogger &mlog, const char *varstr, T &var, const T &vmin, const T &vmax)
Int_t GetNTrees() const
Double_t GetMvaValue(Double_t *err=nullptr, Double_t *errUpper=nullptr)
returns MVA value for given event
Bool_t UseBoost() const
SeparationBase * GetSeparationBase() const
Int_t GetGDNPathSteps() const
const Ranking * CreateRanking()
computes ranking of input variables
void TrainTMVARuleFit()
training of rules using TMVA implementation
Double_t fSignalFraction
scalefactor for bkg events to modify initial s/b fraction in training data
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
Ranking for variables in method (implementation)
Definition Ranking.h:48
A class implementing various fits of rule ensembles.
Definition RuleFit.h:46
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
Basic string class.
Definition TString.h:139
A TTree represents a columnar dataset.
Definition TTree.h:79
create variable transformations
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148