ROOT 6.06/09 Reference Guide
MethodPyRandomForest.h
// @(#)root/tmva/pymva $Id$
// Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis
 * Package: TMVA
 * Class  : MethodPyRandomForest
 * Web    : http://oproject.org
 *
 * Description:
 *      scikit-learn package RandomForestClassifier method, based on Python
 *
 **********************************************************************************/

#ifndef ROOT_TMVA_MethodPyRandomForest
#define ROOT_TMVA_MethodPyRandomForest

//////////////////////////////////////////////////////////////////////////
//
// MethodPyRandomForest
//
//////////////////////////////////////////////////////////////////////////

#ifndef ROOT_TMVA_PyMethodBase
#include "TMVA/PyMethodBase.h"
#endif
namespace TMVA {

   class Factory;        // DSMTEST
   class Reader;         // DSMTEST
   class DataSetManager; // DSMTEST
   class Types;

   class MethodPyRandomForest : public PyMethodBase {

   public:

      // constructors
      MethodPyRandomForest(const TString &jobName,
                           const TString &methodTitle,
                           DataSetInfo &theData,
                           const TString &theOption = "",
                           TDirectory *theTargetDir = NULL);

      MethodPyRandomForest(DataSetInfo &theData,
                           const TString &theWeightFile,
                           TDirectory *theTargetDir = NULL);

      ~MethodPyRandomForest(void);

      void Train();

      // options treatment
      void Init();
      void DeclareOptions();
      void ProcessOptions();

      // create ranking
      const Ranking *CreateRanking()
      {
         return NULL; // = 0;
      }

      Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);

      // performs classifier testing
      virtual void TestClassification();

      Double_t GetMvaValue(Double_t *errLower = 0, Double_t *errUpper = 0);

      using MethodBase::ReadWeightsFromStream;
      // the actual "weights"
      virtual void AddWeightsXMLTo(void * /* parent */) const {}   // = 0;
      virtual void ReadWeightsFromXML(void * /* wghtnode */) {}    // = 0;
      virtual void ReadWeightsFromStream(std::istream &) {}        // = 0; backward compatibility
      void ReadStateFromFile();

   private:
      DataSetManager *fDataSetManager; // DSMTEST
      friend class Factory;            // DSMTEST
      friend class Reader;             // DSMTEST

   protected:
      // RandomForest options
      Int_t n_estimators;      // integer, optional (default=10)
      // The number of trees in the forest.

      TString criterion;       // string, optional (default="gini")
      // The function to measure the quality of a split. Supported criteria are
      // "gini" for the Gini impurity and "entropy" for the information gain.
      // Note: this parameter is tree-specific.

      TString max_depth;       // integer or None, optional (default=None)
      // The maximum depth of the tree. If None, then nodes are expanded until
      // all leaves are pure or until all leaves contain less than
      // min_samples_split samples.

      Int_t min_samples_split; // integer, optional (default=2)
      // The minimum number of samples required to split an internal node.

      Int_t min_samples_leaf;  // integer, optional (default=1)
      // The minimum number of samples in newly created leaves. A split is
      // discarded if, after the split, one of the leaves would contain less than
      // ``min_samples_leaf`` samples.
      // Note: this parameter is tree-specific.

      Double_t min_weight_fraction_leaf; // float, optional (default=0.)
      // The minimum weighted fraction of the input samples required to be at a
      // leaf node.
      // Note: this parameter is tree-specific.

      TString max_features;    // int, float, string or None, optional (default="auto")
      // The number of features to consider when looking for the best split:
      // - If int, then consider `max_features` features at each split.
      // - If float, then `max_features` is a percentage and
      //   `int(max_features * n_features)` features are considered at each split.
      // - If "auto", then `max_features=sqrt(n_features)`.
      // - If "sqrt", then `max_features=sqrt(n_features)`.
      // - If "log2", then `max_features=log2(n_features)`.
      // - If None, then `max_features=n_features`.
      // Note: the search for a split does not stop until at least one valid
      // partition of the node samples is found, even if it requires effectively
      // inspecting more than ``max_features`` features.
      // Note: this parameter is tree-specific.

      TString max_leaf_nodes;  // int or None, optional (default=None)
      // Grow trees with ``max_leaf_nodes`` in best-first fashion.
      // Best nodes are defined as relative reduction in impurity.
      // If None, then an unlimited number of leaf nodes.
      // If not None, then ``max_depth`` will be ignored.

      Bool_t bootstrap;        // boolean, optional (default=True)
      // Whether bootstrap samples are used when building trees.

      Bool_t oob_score;        // Whether to use out-of-bag samples to estimate
      // the generalization error.

      Int_t n_jobs;            // integer, optional (default=1)
      // The number of jobs to run in parallel for both `fit` and `predict`.
      // If -1, then the number of jobs is set to the number of cores.

      TString random_state;    // int, RandomState instance or None, optional (default=None)
      // If int, random_state is the seed used by the random number generator;
      // if RandomState instance, random_state is the random number generator;
      // if None, the random number generator is the RandomState instance used
      // by `np.random`.

      Int_t verbose;           // Controls the verbosity of the tree-building process.

      Bool_t warm_start;       // bool, optional (default=False)
      // When set to ``True``, reuse the solution of the previous call to fit
      // and add more estimators to the ensemble; otherwise, just fit a whole
      // new forest.

      TString class_weight;    // dict, list of dicts, "auto", "subsample" or None, optional
      // Weights associated with classes in the form ``{class_label: weight}``.
      // If not given, all classes are supposed to have weight one. For
      // multi-output problems, a list of dicts can be provided in the same
      // order as the columns of y.
      // The "auto" mode uses the values of y to automatically adjust
      // weights inversely proportional to class frequencies in the input data.
      // The "subsample" mode is the same as "auto" except that weights are
      // computed based on the bootstrap sample for every tree grown.
      // For multi-output, the weights of each column of y will be multiplied.
      // Note that these weights will be multiplied with sample_weight (passed
      // through the fit method) if sample_weight is specified.

      // get help message text
      void GetHelpMessage() const;

      ClassDef(MethodPyRandomForest, 0)
   };

} // namespace TMVA

#endif
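
For orientation, the sketch below shows how this method is typically booked and trained from C++ through the TMVA Factory in the ROOT 6.06 API (no DataLoader). It is a minimal sketch, not taken from this file: the option tokens in the booking string (NEstimators, Criterion, MaxDepth, MinSamplesLeaf) are assumed to be the names registered by MethodPyRandomForest::DeclareOptions for the members documented above, and the input file, tree, and variable names are placeholders. PyMVA additionally requires a Python installation with NumPy and scikit-learn available.

// Minimal sketch: booking PyRandomForest through the TMVA Factory (ROOT 6.06 API).
// Option names and all file/tree/variable names are assumptions, not taken from this header.
#include "TFile.h"
#include "TTree.h"
#include "TMVA/Factory.h"
#include "TMVA/Tools.h"
#include "TMVA/Types.h"

void TrainPyRandomForest()
{
   TMVA::Tools::Instance();

   TFile *input  = TFile::Open("input.root");       // placeholder input file
   TTree *sig    = (TTree *)input->Get("TreeS");    // placeholder signal tree
   TTree *bkg    = (TTree *)input->Get("TreeB");    // placeholder background tree
   TFile *output = TFile::Open("TMVA_PyRF.root", "RECREATE");

   TMVA::Factory factory("TMVAClassification", output,
                         "!V:!Silent:AnalysisType=Classification");

   factory.AddVariable("var1", 'F');
   factory.AddVariable("var2", 'F');
   factory.AddSignalTree(sig, 1.0);
   factory.AddBackgroundTree(bkg, 1.0);
   factory.PrepareTrainingAndTestTree("", "SplitMode=Random:NormMode=NumEvents:!V");

   // The option tokens map onto the data members documented above
   // (n_estimators, criterion, max_depth, min_samples_leaf).
   factory.BookMethod(TMVA::Types::kPyRandomForest, "PyRandomForest",
                      "!V:NEstimators=100:Criterion=gini:MaxDepth=6:MinSamplesLeaf=1");

   factory.TrainAllMethods();   // calls Train(), which drives scikit-learn's fit
   factory.TestAllMethods();    // calls TestClassification()
   factory.EvaluateAllMethods();

   output->Close();
}

Depending on the build, TMVA::PyMethodBase::PyInitialize() may also need to be called before booking so that the embedded Python interpreter is set up; see PyMethodBase.h for the exact entry point.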
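
Applying a trained method goes through the standard TMVA::Reader workflow (note the friend class Reader declaration above), and Reader::EvaluateMVA() ultimately dispatches to GetMvaValue() for this class. The following is another minimal sketch, assuming the usual weights/<JobName>_<MethodTitle>.weights.xml naming; for PyMVA the trained scikit-learn model itself is persisted separately as a pickled Python object that ReadStateFromFile() restores, so the exact file layout may differ from other TMVA methods. Variable names and paths are placeholders.

// Minimal sketch: applying the trained PyRandomForest with TMVA::Reader.
// Variable names and the weight-file path are placeholders (assumptions).
#include <iostream>
#include "TMVA/Reader.h"

void ApplyPyRandomForest()
{
   // The Reader stores pointers to these, so they must outlive the EvaluateMVA calls.
   Float_t var1 = 0.f;
   Float_t var2 = 0.f;

   TMVA::Reader reader("!Color:!Silent");
   reader.AddVariable("var1", &var1);
   reader.AddVariable("var2", &var2);

   // Placeholder path; produced by the training job sketched above.
   reader.BookMVA("PyRandomForest", "weights/TMVAClassification_PyRandomForest.weights.xml");

   var1 = 1.2f;
   var2 = -0.7f;
   // EvaluateMVA dispatches to MethodPyRandomForest::GetMvaValue().
   Double_t mva = reader.EvaluateMVA("PyRandomForest");
   std::cout << "PyRandomForest response: " << mva << std::endl;
}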