Logo ROOT  
Reference Guide
MethodPyRandomForest.h
Go to the documentation of this file.
1// @(#)root/tmva/pymva $Id$
2// Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodPyRandomForest *
8 * Web : http://oproject.org *
9 * *
10 * Description: *
11 * scikit-learn Package RandomForestClassifier method based on python *
12 * *
13 **********************************************************************************/
14
15#ifndef ROOT_TMVA_MethodPyRandomForest
16#define ROOT_TMVA_MethodPyRandomForest
17
18//////////////////////////////////////////////////////////////////////////
19// //
20// MethodPyRandomForest //
21// //
22//////////////////////////////////////////////////////////////////////////
23
24#include "TMVA/PyMethodBase.h"
25#include <vector>
26
27namespace TMVA {
28
29 class Factory; // DSMTEST
30 class Reader; // DSMTEST
31 class DataSetManager; // DSMTEST
32 class Types;
34
35 public :
36 // constructors
37 MethodPyRandomForest(const TString &jobName,
38 const TString &methodTitle,
39 DataSetInfo &theData,
40 const TString &theOption = "");
41
43 const TString &theWeightFile);
44
46 void Train();
47
48 // options treatment
49 void Init();
50 void DeclareOptions();
51 void ProcessOptions();
52
53 // create ranking
54 const Ranking *CreateRanking();
55
56 Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);
57
58 // performs classifier testing
59 virtual void TestClassification();
60
61 // Get class probabilities of given event
62 Double_t GetMvaValue(Double_t *errLower = 0, Double_t *errUpper = 0);
63 std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false);
64 std::vector<Float_t>& GetMulticlassValues();
65
67 // the actual "weights"
68 virtual void AddWeightsXMLTo(void * /* parent */) const {} // = 0;
69 virtual void ReadWeightsFromXML(void * /* wghtnode */) {} // = 0;
70 virtual void ReadWeightsFromStream(std::istream &) {} //= 0; // backward compatibility
71
72 void ReadModelFromFile();
73
74 private :
76 friend class Factory; // DSMTEST
77 friend class Reader; // DSMTEST
78
79 protected:
80 std::vector<Double_t> mvaValues;
81 std::vector<Float_t> classValues;
82
83 UInt_t fNvars; // number of variables
84 UInt_t fNoutputs; // number of outputs
85 TString fFilenameClassifier; // Path to serialized classifier (default in `weights` folder)
86
87 // RandomForest options
88
90 Int_t fNestimators; //integer, optional (default=10)
91 //The number of trees in the forest.
92
94 TString fCriterion; //string, optional (default="gini")
95 //The function to measure the quality of a split. Supported criteria are
96 //"gini" for the Gini impurity and "entropy" for the information gain.
97 //Note: this parameter is tree-specific.
98
100 TString fMaxDepth; //integer or None, optional (default=None)
101 //The maximum depth of the tree. If None, then nodes are expanded until
102 //all leaves are pure or until all leaves contain fewer than `fMinSamplesSplit` samples.
103
105 Int_t fMinSamplesSplit; //integer, optional (default=2)
106 //The minimum number of samples required to split an internal node.
107
109 Int_t fMinSamplesLeaf; //integer, optional (default=1)
110 //The minimum number of samples in newly created leaves. A split is
111 //discarded if, after the split, one of the leaves would contain fewer than
112 //``min_samples_leaf`` samples.
113 //Note: this parameter is tree-specific.
114
116 Double_t fMinWeightFractionLeaf; //float, optional (default=0.)
117 //The minimum weighted fraction of the input samples required to be at a
118 //leaf node.
119 //Note: this parameter is tree-specific.
120
122 TString fMaxFeatures; //int, float, string or None, optional (default="auto")
123 //The number of features to consider when looking for the best split:
124 //- If int, then consider `max_features` features at each split.
125 //- If float, then `max_features` is a fraction and
126 //`int(max_features * n_features)` features are considered at each split.
127 //- If "auto", then `max_features=sqrt(n_features)`.
128 //- If "sqrt", then `max_features=sqrt(n_features)`.
129 //- If "log2", then `max_features=log2(n_features)`.
130 //- If None, then `max_features=n_features`.
131 // Note: the search for a split does not stop until at least one
132 // valid partition of the node samples is found, even if that requires
133 // effectively inspecting more than ``max_features`` features.
134 // Note: this parameter is tree-specific.
135
137 TString fMaxLeafNodes; //int or None, optional (default=None)
138 //Grow trees with ``max_leaf_nodes`` in best-first fashion.
139 //Best nodes are defined by relative reduction in impurity.
140 //If None then unlimited number of leaf nodes.
141 //If not None then ``max_depth`` will be ignored.
142
144 Bool_t fBootstrap; //boolean, optional (default=True)
145 //Whether bootstrap samples are used when building trees.
146
148 Bool_t fOobScore; //Whether to use out-of-bag samples to estimate
149 //the generalization error.
150
152 Int_t fNjobs; // integer, optional (default=1)
153 //The number of jobs to run in parallel for both `fit` and `predict`.
154 //If -1, then the number of jobs is set to the number of cores.
155
157 TString fRandomState; //int, RandomState instance or None, optional (default=None)
158 //If int, random_state is the seed used by the random number generator;
159 //If RandomState instance, random_state is the random number generator;
160 //If None, the random number generator is the RandomState instance used
161 //by `np.random`.
162
164 Int_t fVerbose; //Controls the verbosity of the tree building process.
165
167 Bool_t fWarmStart; //bool, optional (default=False)
168 //When set to ``True``, reuse the solution of the previous call to fit
169 //and add more estimators to the ensemble, otherwise, just fit a whole
170 //new forest.
171
173 TString fClassWeight; //dict, list of dicts, "auto", "subsample" or None, optional
174 //Weights associated with classes in the form ``{class_label: weight}``.
175 //If not given, all classes are supposed to have weight one. For
176 //multi-output problems, a list of dicts can be provided in the same
177 //order as the columns of y.
178 //The "auto" mode uses the values of y to automatically adjust
179 //weights inversely proportional to class frequencies in the input data.
180 //The "subsample" mode is the same as "auto" except that weights are
181 //computed based on the bootstrap sample for every tree grown.
182 //For multi-output, the weights of each column of y will be multiplied.
183 //Note that these weights will be multiplied with sample_weight (passed
184 //through the fit method) if sample_weight is specified.
185
186 // get help message text
187 void GetHelpMessage() const;
188
190 };
191
192} // namespace TMVA
193
194#endif // ROOT_TMVA_MethodPyRandomForest
_object PyObject
Definition: PyMethodBase.h:42
int Int_t
Definition: RtypesCore.h:45
unsigned int UInt_t
Definition: RtypesCore.h:46
bool Bool_t
Definition: RtypesCore.h:63
double Double_t
Definition: RtypesCore.h:59
long long Long64_t
Definition: RtypesCore.h:80
#define ClassDef(name, id)
Definition: Rtypes.h:325
int type
Definition: TGX11.cxx:121
Class that contains all the data information.
Definition: DataSetInfo.h:62
Class that contains all the data information.
This is the main MVA steering class.
Definition: Factory.h:80
virtual void ReadWeightsFromStream(std::istream &)=0
MethodPyRandomForest(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
std::vector< Float_t > & GetMulticlassValues()
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
std::vector< Float_t > classValues
virtual void AddWeightsXMLTo(void *) const
std::vector< Double_t > mvaValues
virtual void ReadWeightsFromStream(std::istream &)
virtual void TestClassification()
initialization
virtual void ReadWeightsFromXML(void *)
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
Ranking for variables in method (implementation)
Definition: Ranking.h:48
The Reader class serves to use the MVAs in a specific analysis context.
Definition: Reader.h:64
EAnalysisType
Definition: Types.h:128
Basic string class.
Definition: TString.h:136
create variable transformations