Logo ROOT  
Reference Guide
MethodPyRandomForest.h
Go to the documentation of this file.
// @(#)root/tmva/pymva $Id$
// Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : MethodPyRandomForest                                                  *
 * Web    : http://oproject.org                                                   *
 *                                                                                *
 * Description:                                                                   *
 *      scikit-learn Package RandomForestClassifier method based on python        *
 *                                                                                *
 **********************************************************************************/
14 
15 #ifndef ROOT_TMVA_MethodPyRandomForest
16 #define ROOT_TMVA_MethodPyRandomForest
17 
18 //////////////////////////////////////////////////////////////////////////
19 // //
20 // MethodPyRandomForest //
21 // //
22 //////////////////////////////////////////////////////////////////////////
23 
24 #include "TMVA/PyMethodBase.h"
25 #include <vector>
26 
27 namespace TMVA {
28 
29  class Factory; // DSMTEST
30  class Reader; // DSMTEST
31  class DataSetManager; // DSMTEST
32  class Types;
34 
35  public :
36  // constructors
37  MethodPyRandomForest(const TString &jobName,
38  const TString &methodTitle,
39  DataSetInfo &theData,
40  const TString &theOption = "");
41 
43  const TString &theWeightFile);
44 
46  void Train();
47 
48  // options treatment
49  void Init();
50  void DeclareOptions();
51  void ProcessOptions();
52 
53  // create ranking
54  const Ranking *CreateRanking();
55 
56  Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);
57 
58  // performs classifier testing
59  virtual void TestClassification();
60 
61  // Get class probabilities of given event
62  Double_t GetMvaValue(Double_t *errLower = 0, Double_t *errUpper = 0);
63  std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false);
64  std::vector<Float_t>& GetMulticlassValues();
65 
67  // the actual "weights"
68  virtual void AddWeightsXMLTo(void * /* parent */) const {} // = 0;
69  virtual void ReadWeightsFromXML(void * /* wghtnode */) {} // = 0;
70  virtual void ReadWeightsFromStream(std::istream &) {} //= 0; // backward compatibility
71 
72  void ReadModelFromFile();
73 
74  private :
76  friend class Factory; // DSMTEST
77  friend class Reader; // DSMTEST
78 
79  protected:
80  std::vector<Double_t> mvaValues;
81  std::vector<Float_t> classValues;
82 
83  UInt_t fNvars; // number of variables
84  UInt_t fNoutputs; // number of outputs
85  TString fFilenameClassifier; // Path to serialized classifier (default in `weights` folder)
86 
87  // RandomForest options
88 
90  Int_t fNestimators; //integer, optional (default=10)
91  //The number of trees in the forest.
92 
94  TString fCriterion; //string, optional (default="gini")
95  //The function to measure the quality of a split. Supported criteria are
96  //"gini" for the Gini impurity and "entropy" for the information gain.
97  //Note: this parameter is tree-specific.
98 
100  TString fMaxDepth; //integer or None, optional (default=None)
101  //The maximum depth of the tree. If None, then nodes are expanded until
102  //all leaves are pure or until all leaves contain less than `fMinSamplesSplit`.
103 
105  Int_t fMinSamplesSplit; //integer, optional (default=2)
106  //The minimum number of samples required to split an internal node.
107 
109  Int_t fMinSamplesLeaf; //integer, optional (default=1)
110  //The minimum number of samples in newly created leaves. A split is
111  //discarded if after the split, one of the leaves would contain less then
112  //``min_samples_leaf`` samples.
113  //Note: this parameter is tree-specific.
114 
116  Double_t fMinWeightFractionLeaf; //float, optional (default=0.)
117  //The minimum weighted fraction of the input samples required to be at a
118  //leaf node.
119  //Note: this parameter is tree-specific.
120 
122  TString fMaxFeatures; //int, float, string or None, optional (default="auto")
123  //The number of features to consider when looking for the best split:
124  //- If int, then consider `max_features` features at each split.
125  //- If float, then `max_features` is a percentage and
126  //`int(max_features * n_features)` features are considered at each split.
127  //- If "auto", then `max_features=sqrt(n_features)`.
128  //- If "sqrt", then `max_features=sqrt(n_features)`.
129  //- If "log2", then `max_features=log2(n_features)`.
130  //- If None, then `max_features=n_features`.
131  // Note: the search for a split does not stop until at least one
132  // valid partition of the node samples is found, even if it requires to
133  // effectively inspect more than ``max_features`` features.
134  // Note: this parameter is tree-specific.
135 
137  TString fMaxLeafNodes; //int or None, optional (default=None)
138  //Grow trees with ``max_leaf_nodes`` in best-first fashion.
139  //Best nodes are defined as relative reduction in impurity.
140  //If None then unlimited number of leaf nodes.
141  //If not None then ``max_depth`` will be ignored.
142 
144  Bool_t fBootstrap; //boolean, optional (default=True)
145  //Whether bootstrap samples are used when building trees.
146 
148  Bool_t fOobScore; //Whether to use out-of-bag samples to estimate
149  //the generalization error.
150 
152  Int_t fNjobs; // integer, optional (default=1)
153  //The number of jobs to run in parallel for both `fit` and `predict`.
154  //If -1, then the number of jobs is set to the number of cores.
155 
157  TString fRandomState; //int, RandomState instance or None, optional (default=None)
158  //If int, random_state is the seed used by the random number generator;
159  //If RandomState instance, random_state is the random number generator;
160  //If None, the random number generator is the RandomState instance used
161  //by `np.random`.
162 
164  Int_t fVerbose; //Controls the verbosity of the tree building process.
165 
167  Bool_t fWarmStart; //bool, optional (default=False)
168  //When set to ``True``, reuse the solution of the previous call to fit
169  //and add more estimators to the ensemble, otherwise, just fit a whole
170  //new forest.
171 
173  TString fClassWeight; //dict, list of dicts, "auto", "subsample" or None, optional
174  //Weights associated with classes in the form ``{class_label: weight}``.
175  //If not given, all classes are supposed to have weight one. For
176  //multi-output problems, a list of dicts can be provided in the same
177  //order as the columns of y.
178  //The "auto" mode uses the values of y to automatically adjust
179  //weights inversely proportional to class frequencies in the input data.
180  //The "subsample" mode is the same as "auto" except that weights are
181  //computed based on the bootstrap sample for every tree grown.
182  //For multi-output, the weights of each column of y will be multiplied.
183  //Note that these weights will be multiplied with sample_weight (passed
184  //through the fit method) if sample_weight is specified.
185 
186  // get help message text
187  void GetHelpMessage() const;
188 
190  };
191 
192 } // namespace TMVA
193 
194 #endif // ROOT_TMVA_MethodPyRandomForest
TMVA::MethodPyRandomForest::pMaxDepth
PyObject * pMaxDepth
Definition: MethodPyRandomForest.h:99
TMVA::MethodPyRandomForest::fCriterion
TString fCriterion
Definition: MethodPyRandomForest.h:94
TMVA::MethodPyRandomForest::fMaxFeatures
TString fMaxFeatures
Definition: MethodPyRandomForest.h:122
TMVA::MethodPyRandomForest::pRandomState
PyObject * pRandomState
Definition: MethodPyRandomForest.h:156
TMVA::MethodPyRandomForest::GetHelpMessage
void GetHelpMessage() const
Definition: MethodPyRandomForest.cxx:550
TMVA::PyMethodBase
Definition: PyMethodBase.h:56
PyObject
_object PyObject
Definition: PyMethodBase.h:42
TMVA::MethodBase::ReadWeightsFromStream
virtual void ReadWeightsFromStream(std::istream &)=0
TMVA::MethodPyRandomForest::GetMulticlassValues
std::vector< Float_t > & GetMulticlassValues()
Definition: MethodPyRandomForest.cxx:474
TMVA::MethodPyRandomForest::Train
void Train()
Definition: MethodPyRandomForest.cxx:320
TMVA::Ranking
Ranking for variables in method (implementation)
Definition: Ranking.h:48
Long64_t
long long Long64_t
Definition: RtypesCore.h:80
TMVA::MethodPyRandomForest::CreateRanking
const Ranking * CreateRanking()
Definition: MethodPyRandomForest.cxx:530
TMVA::MethodPyRandomForest::fNvars
UInt_t fNvars
Definition: MethodPyRandomForest.h:83
TMVA::MethodPyRandomForest::ReadWeightsFromXML
virtual void ReadWeightsFromXML(void *)
Definition: MethodPyRandomForest.h:69
TMVA::MethodPyRandomForest::fRandomState
TString fRandomState
Definition: MethodPyRandomForest.h:157
TMVA::MethodPyRandomForest::pClassWeight
PyObject * pClassWeight
Definition: MethodPyRandomForest.h:172
TMVA::MethodPyRandomForest::mvaValues
std::vector< Double_t > mvaValues
Definition: MethodPyRandomForest.h:80
TMVA::MethodPyRandomForest::MethodPyRandomForest
MethodPyRandomForest(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
Definition: MethodPyRandomForest.cxx:61
TMVA::MethodPyRandomForest::~MethodPyRandomForest
~MethodPyRandomForest(void)
Definition: MethodPyRandomForest.cxx:107
TString
Basic string class.
Definition: TString.h:136
TMVA::MethodPyRandomForest::pBootstrap
PyObject * pBootstrap
Definition: MethodPyRandomForest.h:143
TMVA::MethodPyRandomForest::fMinSamplesSplit
Int_t fMinSamplesSplit
Definition: MethodPyRandomForest.h:105
TMVA::MethodPyRandomForest::pCriterion
PyObject * pCriterion
Definition: MethodPyRandomForest.h:93
TMVA::MethodPyRandomForest::fBootstrap
Bool_t fBootstrap
Definition: MethodPyRandomForest.h:144
bool
TMVA::MethodPyRandomForest::fNoutputs
UInt_t fNoutputs
Definition: MethodPyRandomForest.h:84
TMVA::MethodPyRandomForest::fClassWeight
TString fClassWeight
Definition: MethodPyRandomForest.h:173
TMVA::MethodPyRandomForest::fMaxLeafNodes
TString fMaxLeafNodes
Definition: MethodPyRandomForest.h:137
TMVA::MethodPyRandomForest::pMinSamplesSplit
PyObject * pMinSamplesSplit
Definition: MethodPyRandomForest.h:104
TMVA::MethodPyRandomForest::pMinSamplesLeaf
PyObject * pMinSamplesLeaf
Definition: MethodPyRandomForest.h:108
TMVA::DataSetInfo
Class that contains all the data information.
Definition: DataSetInfo.h:62
TMVA::MethodPyRandomForest::pNestimators
PyObject * pNestimators
Definition: MethodPyRandomForest.h:89
TMVA::MethodPyRandomForest::GetMvaValue
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
Definition: MethodPyRandomForest.cxx:442
TMVA::MethodPyRandomForest::ReadModelFromFile
void ReadModelFromFile()
Definition: MethodPyRandomForest.cxx:503
TMVA::MethodPyRandomForest::pMaxFeatures
PyObject * pMaxFeatures
Definition: MethodPyRandomForest.h:121
TMVA::MethodPyRandomForest::Init
void Init()
Definition: MethodPyRandomForest.cxx:303
TMVA::Types::EAnalysisType
EAnalysisType
Definition: Types.h:128
TMVA::MethodPyRandomForest::fDataSetManager
DataSetManager * fDataSetManager
Definition: MethodPyRandomForest.h:75
TMVA::MethodPyRandomForest::fVerbose
Int_t fVerbose
Definition: MethodPyRandomForest.h:164
TMVA::Factory
This is the main MVA steering class.
Definition: Factory.h:80
TMVA::MethodPyRandomForest::pWarmStart
PyObject * pWarmStart
Definition: MethodPyRandomForest.h:166
TMVA::MethodPyRandomForest::fWarmStart
Bool_t fWarmStart
Definition: MethodPyRandomForest.h:167
TMVA::MethodPyRandomForest::pOobScore
PyObject * pOobScore
Definition: MethodPyRandomForest.h:147
unsigned int
TMVA::DataSetManager
Class that contains all the data information.
Definition: DataSetManager.h:51
TMVA::MethodPyRandomForest::classValues
std::vector< Float_t > classValues
Definition: MethodPyRandomForest.h:81
TMVA::MethodPyRandomForest::fMaxDepth
TString fMaxDepth
Definition: MethodPyRandomForest.h:100
TMVA::MethodPyRandomForest::ReadWeightsFromStream
virtual void ReadWeightsFromStream(std::istream &)
Definition: MethodPyRandomForest.h:70
TMVA::MethodPyRandomForest::fMinWeightFractionLeaf
Double_t fMinWeightFractionLeaf
Definition: MethodPyRandomForest.h:116
Double_t
double Double_t
Definition: RtypesCore.h:59
TMVA::MethodPyRandomForest::HasAnalysisType
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
Definition: MethodPyRandomForest.cxx:112
TMVA::MethodPyRandomForest::fMinSamplesLeaf
Int_t fMinSamplesLeaf
Definition: MethodPyRandomForest.h:109
TMVA::MethodPyRandomForest::pMaxLeafNodes
PyObject * pMaxLeafNodes
Definition: MethodPyRandomForest.h:136
ClassDef
#define ClassDef(name, id)
Definition: Rtypes.h:325
TMVA::MethodPyRandomForest::GetMvaValues
std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
Definition: MethodPyRandomForest.cxx:384
TMVA::MethodPyRandomForest::ProcessOptions
void ProcessOptions()
Definition: MethodPyRandomForest.cxx:197
TMVA::MethodPyRandomForest::TestClassification
virtual void TestClassification()
initialization
Definition: MethodPyRandomForest.cxx:378
TMVA::MethodPyRandomForest::fNjobs
Int_t fNjobs
Definition: MethodPyRandomForest.h:152
type
int type
Definition: TGX11.cxx:121
TMVA::MethodPyRandomForest::fFilenameClassifier
TString fFilenameClassifier
Definition: MethodPyRandomForest.h:85
PyMethodBase.h
TMVA::Reader
The Reader class serves to use the MVAs in a specific analysis context.
Definition: Reader.h:64
TMVA::MethodPyRandomForest::pMinWeightFractionLeaf
PyObject * pMinWeightFractionLeaf
Definition: MethodPyRandomForest.h:115
TMVA::MethodPyRandomForest::fOobScore
Bool_t fOobScore
Definition: MethodPyRandomForest.h:148
TMVA::MethodPyRandomForest::fNestimators
Int_t fNestimators
Definition: MethodPyRandomForest.h:90
TMVA::MethodPyRandomForest
Definition: MethodPyRandomForest.h:33
TMVA::MethodPyRandomForest::pVerbose
PyObject * pVerbose
Definition: MethodPyRandomForest.h:163
TMVA::MethodPyRandomForest::DeclareOptions
void DeclareOptions()
Definition: MethodPyRandomForest.cxx:120
TMVA::MethodPyRandomForest::AddWeightsXMLTo
virtual void AddWeightsXMLTo(void *) const
Definition: MethodPyRandomForest.h:68
TMVA::MethodPyRandomForest::pNjobs
PyObject * pNjobs
Definition: MethodPyRandomForest.h:151
TMVA
create variable transformations
Definition: GeneticMinimizer.h:22
int