ROOT 6.08/07 Reference Guide
MethodPyRandomForest.h
// @(#)root/tmva/pymva $Id$
// Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis      *
 * Package: TMVA                                                                  *
 * Class  : MethodPyRandomForest                                                  *
 * Web    : http://oproject.org                                                   *
 *                                                                                *
 * Description:                                                                   *
 *      scikit-learn RandomForestClassifier method, interfaced through Python    *
 *                                                                                *
 **********************************************************************************/

#ifndef ROOT_TMVA_MethodPyRandomForest
#define ROOT_TMVA_MethodPyRandomForest

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// MethodPyRandomForest                                                 //
//                                                                      //
//////////////////////////////////////////////////////////////////////////

#ifndef ROOT_TMVA_PyMethodBase
#include "TMVA/PyMethodBase.h"
#endif
namespace TMVA {

   class Factory;        // DSMTEST
   class Reader;         // DSMTEST
   class DataSetManager; // DSMTEST
   class Types;

   class MethodPyRandomForest : public PyMethodBase {

   public:

      // constructors
      MethodPyRandomForest(const TString &jobName,
                           const TString &methodTitle,
                           DataSetInfo &theData,
                           const TString &theOption = "");

      MethodPyRandomForest(DataSetInfo &theData,
                           const TString &theWeightFile);

      ~MethodPyRandomForest(void);

      void Train();

      // options treatment
      void Init();
      void DeclareOptions();
      void ProcessOptions();

      // create ranking
      const Ranking *CreateRanking()
      {
         return NULL; // = 0;
      }

      Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);

      // performs classifier testing
      virtual void TestClassification();

      Double_t GetMvaValue(Double_t *errLower = 0, Double_t *errUpper = 0);

      using MethodBase::ReadWeightsFromStream;

      // the actual "weights"
      virtual void AddWeightsXMLTo(void * /* parent */) const {}   // = 0;
      virtual void ReadWeightsFromXML(void * /* wghtnode */) {}    // = 0;
      virtual void ReadWeightsFromStream(std::istream &) {}        // = 0; backward compatibility

      void ReadModelFromFile();

   private:
      DataSetManager *fDataSetManager; // DSMTEST
      friend class Factory;            // DSMTEST
      friend class Reader;             // DSMTEST

   protected:

      // RandomForest options

      Int_t n_estimators;   // integer, optional (default=10)
      // The number of trees in the forest.

      TString criterion;    // string, optional (default="gini")
      // The function to measure the quality of a split. Supported criteria are
      // "gini" for the Gini impurity and "entropy" for the information gain.
      // Note: this parameter is tree-specific.

      TString max_depth;    // integer or None, optional (default=None)
      // The maximum depth of the tree. If None, then nodes are expanded until
      // all leaves are pure or until all leaves contain less than
      // ``min_samples_split`` samples.

      Int_t min_samples_split; // integer, optional (default=2)
      // The minimum number of samples required to split an internal node.

      Int_t min_samples_leaf;  // integer, optional (default=1)
      // The minimum number of samples in newly created leaves. A split is
      // discarded if, after the split, one of the leaves would contain less than
      // ``min_samples_leaf`` samples.
      // Note: this parameter is tree-specific.

      Double_t min_weight_fraction_leaf; // float, optional (default=0.)
      // The minimum weighted fraction of the input samples required to be at a
      // leaf node.
      // Note: this parameter is tree-specific.

      TString max_features; // int, float, string or None, optional (default="auto")
      // The number of features to consider when looking for the best split:
      // - If int, then consider `max_features` features at each split.
      // - If float, then `max_features` is a percentage and
      //   `int(max_features * n_features)` features are considered at each split.
      // - If "auto", then `max_features=sqrt(n_features)`.
      // - If "sqrt", then `max_features=sqrt(n_features)`.
      // - If "log2", then `max_features=log2(n_features)`.
      // - If None, then `max_features=n_features`.
      // Note: the search for a split does not stop until at least one
      // valid partition of the node samples is found, even if it requires
      // effectively inspecting more than ``max_features`` features.
      // Note: this parameter is tree-specific.

      TString max_leaf_nodes; // int or None, optional (default=None)
      // Grow trees with ``max_leaf_nodes`` in best-first fashion.
      // Best nodes are defined as relative reduction in impurity.
      // If None, then unlimited number of leaf nodes.
      // If not None, then ``max_depth`` will be ignored.

      Bool_t bootstrap;     // boolean, optional (default=True)
      // Whether bootstrap samples are used when building trees.

      Bool_t oob_score;     // Whether to use out-of-bag samples to estimate
      // the generalization error.

      Int_t n_jobs;         // integer, optional (default=1)
      // The number of jobs to run in parallel for both `fit` and `predict`.
      // If -1, then the number of jobs is set to the number of cores.

      TString random_state; // int, RandomState instance or None, optional (default=None)
      // If int, random_state is the seed used by the random number generator;
      // If RandomState instance, random_state is the random number generator;
      // If None, the random number generator is the RandomState instance used
      // by `np.random`.

      Int_t verbose;        // Controls the verbosity of the tree-building process.

      Bool_t warm_start;    // bool, optional (default=False)
      // When set to ``True``, reuse the solution of the previous call to fit
      // and add more estimators to the ensemble; otherwise, just fit a whole
      // new forest.

      TString class_weight; // dict, list of dicts, "auto", "subsample" or None, optional
      // Weights associated with classes in the form ``{class_label: weight}``.
      // If not given, all classes are supposed to have weight one. For
      // multi-output problems, a list of dicts can be provided in the same
      // order as the columns of y.
      // The "auto" mode uses the values of y to automatically adjust
      // weights inversely proportional to class frequencies in the input data.
      // The "subsample" mode is the same as "auto" except that weights are
      // computed based on the bootstrap sample for every tree grown.
      // For multi-output, the weights of each column of y will be multiplied.
      // Note that these weights will be multiplied with sample_weight (passed
      // through the fit method) if sample_weight is specified.

      // get help message text
      void GetHelpMessage() const;

      ClassDef(MethodPyRandomForest, 0)
   };
} // namespace TMVA
#endif
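
For orientation, the sketch below shows how a classifier of this kind is typically booked through the TMVA Factory. It is a minimal, illustrative example: the input file, tree and variable names are hypothetical, and the option names in the booking string (NEstimators, Criterion, MaxDepth, MinSamplesLeaf, MaxFeatures, Bootstrap) are assumed to map onto the data members documented above; MethodPyRandomForest::DeclareOptions holds the authoritative list.

// Minimal booking sketch (assumptions: a file "data.root" with trees "TreeS"/"TreeB"
// and variables var1, var2; option names follow the members documented above).
#include "TFile.h"
#include "TTree.h"
#include "TMVA/Factory.h"
#include "TMVA/DataLoader.h"
#include "TMVA/Types.h"
#include "TMVA/PyMethodBase.h"

void BookPyRandomForest()
{
   // Initialize the embedded Python interpreter used by PyMVA methods.
   TMVA::PyMethodBase::PyInitialize();

   TFile *input  = TFile::Open("data.root");               // hypothetical input file
   TFile *output = TFile::Open("TMVA_PyRF.root", "RECREATE");

   TMVA::Factory factory("TMVAClassification", output,
                         "!V:!Silent:AnalysisType=Classification");
   TMVA::DataLoader loader("dataset");

   loader.AddVariable("var1");                             // hypothetical variables
   loader.AddVariable("var2");
   loader.AddSignalTree((TTree *)input->Get("TreeS"), 1.0);
   loader.AddBackgroundTree((TTree *)input->Get("TreeB"), 1.0);
   loader.PrepareTrainingAndTestTree("", "SplitMode=Random:NormMode=NumEvents:!V");

   // Each option is forwarded to a scikit-learn RandomForestClassifier argument,
   // e.g. NEstimators -> n_estimators, MaxDepth -> max_depth.
   factory.BookMethod(&loader, TMVA::Types::kPyRandomForest, "PyRandomForest",
                      "!V:NEstimators=100:Criterion=gini:MaxDepth=6:"
                      "MinSamplesLeaf=3:MaxFeatures=auto:Bootstrap=kTRUE");

   factory.TrainAllMethods();
   factory.TestAllMethods();
   factory.EvaluateAllMethods();

   output->Close();
}

After training, the stored model is loaded back through ReadModelFromFile() and evaluated per event via GetMvaValue(), which is the path a TMVA::Reader-based application uses.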