Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodPyRandomForest.cxx
Go to the documentation of this file.
1// @(#)root/tmva/pymva $Id$
2// Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodPyRandomForest *
8 * Web : http://oproject.org *
9 * *
10 * Description: *
11 * Random Forest Classifier from scikit-learn *
12 * *
13 * *
14 * Redistribution and use in source and binary forms, with or without *
15 * modification, are permitted according to the terms listed in LICENSE *
16 * (see tmva/doc/LICENSE) *
17 * *
18 **********************************************************************************/
19#include <Python.h> // Needs to be included first to avoid redefinition of _POSIX_C_SOURCE
21
22#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
23#include <numpy/arrayobject.h>
24
25#include "TMVA/Configurable.h"
27#include "TMVA/Config.h"
28#include "TMVA/DataSet.h"
29#include "TMVA/Event.h"
30#include "TMVA/IMethod.h"
31#include "TMVA/MsgLogger.h"
32#include "TMVA/PDF.h"
33#include "TMVA/Ranking.h"
34#include "TMVA/Results.h"
35#include "TMVA/Tools.h"
36#include "TMVA/Types.h"
37#include "TMVA/Timer.h"
39
40#include "TMatrix.h"
41
42using namespace TMVA;
43
44namespace TMVA {
45namespace Internal {
// Small helper class whose only visible state is a PyGILState_STATE handle.
// NOTE(review): the public constructor/destructor (listing lines 50-51) are
// missing from this extract; presumably they acquire and release the Python
// GIL around a scope - confirm against the upstream file.
46class PyGILRAII {
47 PyGILState_STATE m_GILState;
48
49public:
52};
53} // namespace Internal
54} // namespace TMVA
55
57
58
59//_______________________________________________________________________
// Constructor fragment (the first signature line and two further parameter
// lines are missing from this extract). Forwards jobName, methodTitle, dsi and
// theOption to PyMethodBase together with Types::kPyRandomForest, and gives
// every option member its default value. The body is empty.
61 const TString &methodTitle,
64 PyMethodBase(jobName, Types::kPyRandomForest, methodTitle, dsi, theOption),
65 fNestimators(10),
// Stored bare; the option check later in this file wraps it in quotes before Eval().
66 fCriterion("gini"),
// "None" is kept as text here; it is presumably evaluated as Python later - confirm.
67 fMaxDepth("None"),
68 fMinSamplesSplit(2),
69 fMinSamplesLeaf(1),
70 fMinWeightFractionLeaf(0),
// Unlike fCriterion, this default already carries its own Python quotes.
71 fMaxFeatures("'sqrt'"),
72 fMaxLeafNodes("None"),
73 fBootstrap(kTRUE),
74 fOobScore(kFALSE),
75 fNjobs(1),
76 fRandomState("None"),
77 fVerbose(0),
78 fWarmStart(kFALSE),
79 fClassWeight("None")
80{
81}
82
83//_______________________________________________________________________
// Second constructor fragment (its signature line is missing from this
// extract). Forwards theData and theWeightFile to PyMethodBase together with
// Types::kPyRandomForest and applies the same member defaults as the
// constructor above. The body is empty.
85 : PyMethodBase(Types::kPyRandomForest, theData, theWeightFile),
86 fNestimators(10),
87 fCriterion("gini"),
88 fMaxDepth("None"),
89 fMinSamplesSplit(2),
90 fMinSamplesLeaf(1),
91 fMinWeightFractionLeaf(0),
// Already wrapped in Python quotes, unlike the bare "gini" above.
92 fMaxFeatures("'sqrt'"),
93 fMaxLeafNodes("None"),
94 fBootstrap(kTRUE),
95 fOobScore(kFALSE),
96 fNjobs(1),
97 fRandomState("None"),
98 fVerbose(0),
99 fWarmStart(kFALSE),
100 fClassWeight("None")
101{
102}
103
104
105//_______________________________________________________________________
109
110//_______________________________________________________________________
117
118//_______________________________________________________________________
// Registers the classifier hyper-parameters as options. Each call binds a data
// member to an option name and a help text via
// DeclareOptionRef(member, name, description).
// NOTE(review): the signature line and listing line 121 are missing from this
// extract.
120{
122
123 DeclareOptionRef(fNestimators, "NEstimators", "Integer, optional (default=10). The number of trees in the forest.");
124 DeclareOptionRef(fCriterion, "Criterion", "String, optional (default='gini') \
125 The function to measure the quality of a split. Supported criteria are \
126 'gini' for the Gini impurity and 'entropy' for the information gain. \
127 Note: this parameter is tree-specific.");
128
129 DeclareOptionRef(fMaxDepth, "MaxDepth", "integer or None, optional (default=None) \
130 The maximum depth of the tree. If None, then nodes are expanded until \
131 all leaves are pure or until all leaves contain less than \
132 min_samples_split samples. \
133 Ignored if ``max_leaf_nodes`` is not None.");
134
135 DeclareOptionRef(fMinSamplesSplit, "MinSamplesSplit", "integer, optional (default=2)\
136 The minimum number of samples required to split an internal node.");
137
138 DeclareOptionRef(fMinSamplesLeaf, "MinSamplesLeaf", "integer, optional (default=1) \
139 The minimum number of samples in newly created leaves. A split is \
140 discarded if after the split, one of the leaves would contain less then \
141 ``min_samples_leaf`` samples.");
// NOTE(review): the help text below starts with a literal "//" inside the
// string - it looks like a leftover comment marker; confirm and clean up.
142 DeclareOptionRef(fMinWeightFractionLeaf, "MinWeightFractionLeaf", "//float, optional (default=0.) \
143 The minimum weighted fraction of the input samples required to be at a \
144 leaf node.");
145 DeclareOptionRef(fMaxFeatures, "MaxFeatures", "The number of features to consider when looking for the best split");
146
147 DeclareOptionRef(fMaxLeafNodes, "MaxLeafNodes", "int or None, optional (default=None)\
148 Grow trees with ``max_leaf_nodes`` in best-first fashion.\
149 Best nodes are defined as relative reduction in impurity.\
150 If None then unlimited number of leaf nodes.\
151 If not None then ``max_depth`` will be ignored.");
152
153 DeclareOptionRef(fBootstrap, "Bootstrap", "boolean, optional (default=True) \
154 Whether bootstrap samples are used when building trees.");
155
156 DeclareOptionRef(fOobScore, "OoBScore", " bool Whether to use out-of-bag samples to estimate\
157 the generalization error.");
158
// NOTE(review): this help text advertises -1 ("number of cores"), but the
// option check later in this file reports kFATAL for any NJobs < 1.
159 DeclareOptionRef(fNjobs, "NJobs", " integer, optional (default=1) \
160 The number of jobs to run in parallel for both `fit` and `predict`. \
161 If -1, then the number of jobs is set to the number of cores.");
162
163 DeclareOptionRef(fRandomState, "RandomState", "int, RandomState instance or None, optional (default=None)\
164 If int, random_state is the seed used by the random number generator;\
165 If RandomState instance, random_state is the random number generator;\
166 If None, the random number generator is the RandomState instance used\
167 by `np.random`.");
168
169 DeclareOptionRef(fVerbose, "Verbose", "int, optional (default=0)\
170 Controls the verbosity of the tree building process.");
171
172 DeclareOptionRef(fWarmStart, "WarmStart", "bool, optional (default=False)\
173 When set to ``True``, reuse the solution of the previous call to fit\
174 and add more estimators to the ensemble, otherwise, just fit a whole\
175 new forest.");
176
177 DeclareOptionRef(fClassWeight, "ClassWeight", "dict, list of dicts, \"auto\", \"subsample\" or None, optional\
178 Weights associated with classes in the form ``{class_label: weight}``.\
179 If not given, all classes are supposed to have weight one. For\
180 multi-output problems, a list of dicts can be provided in the same\
181 order as the columns of y.\
182 The \"auto\" mode uses the values of y to automatically adjust\
183 weights inversely proportional to class frequencies in the input data.\
184 The \"subsample\" mode is the same as \"auto\" except that weights are\
185 computed based on the bootstrap sample for every tree grown.\
186 For multi-output, the weights of each column of y will be multiplied.\
187 Note that these weights will be multiplied with sample_weight (passed\
188 through the fit method) if sample_weight is specified.");
189
// File name under which the trained classifier is stored (per the help text).
190 DeclareOptionRef(fFilenameClassifier, "FilenameClassifier",
191 "Store trained classifier in this file");
192}
193
194//_______________________________________________________________________
195// Check options and load them to local python namespace
// Validates each option member, logging through Log() when a value is
// rejected, and converts values to Python objects with Eval().
// NOTE(review): the signature line and many statements are missing from this
// extract (for example the assignments to pMaxDepth, pMaxFeatures,
// pMaxLeafNodes, pRandomState and pClassWeight that the null checks below
// rely on, and most PyDict_SetItemString calls).
197{
198 if (fNestimators <= 0) {
199 Log() << kFATAL << " NEstimators <=0... that does not work !! " << Endl;
200 }
203
204 if (fCriterion != "gini" && fCriterion != "entropy") {
205 Log() << kFATAL << Form(" Criterion = %s... that does not work !! ", fCriterion.Data())
206 << " The options are `gini` or `entropy`." << Endl;
207 }
// The criterion is stored without quotes, so wrap it in single quotes to make
// Eval() see a Python string literal.
208 pCriterion = Eval(Form("'%s'", fCriterion.Data()));
210
213 if (!pMaxDepth) {
214 Log() << kFATAL << Form(" MaxDepth = %s... that does not work !! ", fMaxDepth.Data())
215 << " The options are None or integer." << Endl;
216 }
217
218 if (fMinSamplesSplit < 0) {
219 Log() << kFATAL << " MinSamplesSplit < 0... that does not work !! " << Endl;
220 }
223
224 if (fMinSamplesLeaf < 0) {
225 Log() << kFATAL << " MinSamplesLeaf < 0... that does not work !! " << Endl;
226 }
229
// NOTE(review): this check logs at kERROR while every other check in this
// function uses kFATAL - confirm whether that difference is intended.
230 if (fMinWeightFractionLeaf < 0) {
231 Log() << kERROR << " MinWeightFractionLeaf < 0... that does not work !! " << Endl;
232 }
// Publishes the value in fLocalNS under the key "minWeightFractionLeaf", the
// same name the RandomForestClassifier(...) construction string in this file
// passes as min_weight_fraction_leaf.
234 PyDict_SetItemString(fLocalNS, "minWeightFractionLeaf", pMinWeightFractionLeaf);
235
// "auto" is mapped to "sqrt" first; a bare sqrt/log2 is then wrapped in single
// quotes. The constructor default "'sqrt'" is already quoted and matches
// neither comparison, so it is left untouched.
236 if (fMaxFeatures == "auto") fMaxFeatures = "sqrt"; // change in API from v 1.11
237 if (fMaxFeatures == "sqrt" || fMaxFeatures == "log2"){
238 fMaxFeatures = Form("'%s'", fMaxFeatures.Data());
239 }
242
// NOTE(review): the message below still advertises default='auto', whereas the
// constructors default to 'sqrt' and "auto" is remapped above - consider
// updating the text.
243 if (!pMaxFeatures) {
244 Log() << kFATAL << Form(" MaxFeatures = %s... that does not work !! ", fMaxFeatures.Data())
245 << "int, float, string or None, optional (default='auto')"
246 << "The number of features to consider when looking for the best split:"
247 << "If int, then consider `max_features` features at each split."
248 << "If float, then `max_features` is a percentage and"
249 << "`int(max_features * n_features)` features are considered at each split."
250 << "If 'auto', then `max_features=sqrt(n_features)`."
251 << "If 'sqrt', then `max_features=sqrt(n_features)`."
252 << "If 'log2', then `max_features=log2(n_features)`."
253 << "If None, then `max_features=n_features`." << Endl;
254 }
255
257 if (!pMaxLeafNodes) {
258 Log() << kFATAL << Form(" MaxLeafNodes = %s... that does not work !! ", fMaxLeafNodes.Data())
259 << " The options are None or integer." << Endl;
260 }
262
264 if (!pRandomState) {
265 Log() << kFATAL << Form(" RandomState = %s... that does not work !! ", fRandomState.Data())
266 << "If int, random_state is the seed used by the random number generator;"
267 << "If RandomState instance, random_state is the random number generator;"
268 << "If None, the random number generator is the RandomState instance used by `np.random`." << Endl;
269 }
271
273 if (!pClassWeight) {
274 Log() << kFATAL << Form(" ClassWeight = %s... that does not work !! ", fClassWeight.Data())
275 << "dict, list of dicts, 'auto', 'subsample' or None, optional" << Endl;
276 }
278
// NOTE(review): the NJobs help text in this file says -1 selects the number of
// cores, yet any value below 1 is rejected here - the two disagree.
279 if(fNjobs < 1) {
280 Log() << kFATAL << Form(" NJobs = %i... that does not work !! ", fNjobs)
281 << "Value has to be greater than zero." << Endl;
282 }
283 pNjobs = Eval(Form("%i", fNjobs));
285
// Booleans are translated to the Python literals True/False.
286 pBootstrap = (fBootstrap) ? Eval("True") : Eval("False");
288 pOobScore = (fOobScore) ? Eval("True") : Eval("False");
290 pVerbose = Eval(Form("%i", fVerbose));
292 pWarmStart = (fWarmStart) ? Eval("True") : Eval("False");
294
295 // If no filename is given, set default
// NOTE(review): the condition guarding this block (listing line 296) is
// missing from this extract.
297 {
298 fFilenameClassifier = GetWeightFileDir() + "/PyRFModel_" + GetName() + ".PyData";
299 }
300}
301
302//_______________________________________________________________________
// Set-up step: calls _import_array() and imports the sklearn.ensemble module
// into the Python session via PyRunString.
// NOTE(review): the signature line and the statements that belonged under the
// "Check options" and "Get data properties" comments are missing from this
// extract.
304{
306 _import_array(); //required to use numpy arrays
307
308 // Check options and load them to local python namespace
310
311 // Import module for random forest classifier
312 PyRunString("import sklearn.ensemble");
313
314 // Get data properties
317}
318
319//_______________________________________________________________________
// Training step: copies every training event (values, class, weight) into
// flat buffers, builds a sklearn RandomForestClassifier from a Python code
// string, fits it, and logs the state-file name when model persistence is on.
// NOTE(review): the signature line and several statements are missing from
// this extract (array allocation, the declarations of TrainDataClasses and
// TrainDataWeights, the assignment to fClassifier and, presumably, the
// serialisation call inside the persistence branch).
321{
322 // Load training data (data, classes, weights) to python arrays
// NOTE(review): GetNTrainingEvents() is declared as returning Long64_t; it is
// narrowed to int here.
323 int fNrowsTraining = Data()->GetNTrainingEvents(); //every row is an event, a class type and a weight
326 dimsData[1] = fNvars;
329 float *TrainData = (float *)(PyArray_DATA(fTrainData));
330
335
339
340 for (int i = 0; i < fNrowsTraining; i++) {
341 // Fill training data matrix
342 const TMVA::Event *e = Data()->GetTrainingEvent(i);
// Row-major fill: event i occupies elements [i*fNvars, (i+1)*fNvars).
343 for (UInt_t j = 0; j < fNvars; j++) {
344 TrainData[j + i * fNvars] = e->GetValue(j);
345 }
346
347 // Fill target classes
348 TrainDataClasses[i] = e->GetClass();
349
350 // Get event weight
351 TrainDataWeights[i] = e->GetWeight();
352 }
353
354 // Create classifier object
// The right-hand sides (bootstrap, classWeight, criterion, ...) are Python
// names, so they must already exist in the namespace this string runs in -
// presumably placed there by the option-processing step; confirm.
355 PyRunString("classifier = sklearn.ensemble.RandomForestClassifier(bootstrap=bootstrap, class_weight=classWeight, criterion=criterion, max_depth=maxDepth, max_features=maxFeatures, max_leaf_nodes=maxLeafNodes, min_samples_leaf=minSamplesLeaf, min_samples_split=minSamplesSplit, min_weight_fraction_leaf=minWeightFractionLeaf, n_estimators=nEstimators, n_jobs=nJobs, oob_score=oobScore, random_state=randomState, verbose=verbose, warm_start=warmStart)",
356 "Failed to setup classifier");
357
358 // Fit classifier
359 // NOTE: We dump the output to a variable so that the call does not pollute stdout
360 PyRunString("dump = classifier.fit(trainData, trainDataClasses, trainDataWeights)", "Failed to train classifier");
361
362 // Store classifier
364 if(fClassifier == 0) {
365 Log() << kFATAL << "Can't create classifier object from RandomForestClassifier" << Endl;
366 Log() << Endl;
367 }
368
369 if (IsModelPersistence()) {
370 Log() << Endl;
371 Log() << gTools().Color("bold") << "Saving state file: " << gTools().Color("reset") << fFilenameClassifier << Endl;
372 Log() << Endl;
374 }
375}
376
377//_______________________________________________________________________
382
383//_______________________________________________________________________
// Batch evaluation: packs the values of events in [firstEvt, lastEvt) into one
// float buffer, calls predict_proba on the classifier once, and returns the
// member vector mvaValues resized to the number of events.
// NOTE(review): the signature line and several statements are missing from
// this extract (array allocation, per-event selection, the copy into
// mvaValues and any reference-count cleanup).
385{
386 // Load model if not already done
387 if (fClassifier == 0) ReadModelFromFile();
388
389 // Determine number of events
390 Long64_t nEvents = Data()->GetNEvents();
// An inverted range or an upper bound past the sample is clamped to the full
// sample size. With the declared defaults (firstEvt=0, lastEvt=-1) the first
// test fires, so all events are processed.
391 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
392 if (firstEvt < 0) firstEvt = 0;
393 nEvents = lastEvt-firstEvt;
394
395
396 // Get data
397 npy_intp dims[2];
398 dims[0] = nEvents;
399 dims[1] = fNvars;
401 float *pValue = (float *)(PyArray_DATA(pEvent));
402
// NOTE(review): the counter is Int_t while nEvents is Long64_t. The statement
// that selects the event for each iteration (listing line 404) is missing
// from this extract.
403 for (Int_t ievt=0; ievt<nEvents; ievt++) {
405 const TMVA::Event *e = Data()->GetEvent();
406 for (UInt_t i = 0; i < fNvars; i++) {
407 pValue[ievt * fNvars + i] = e->GetValue(i);
408 }
409 }
410
411 // Get prediction from classifier
// NOTE(review): the call result is passed to PyArray_DATA on the very next
// line without a null check; a failed Python call would be dereferenced.
412 PyArrayObject *result = (PyArrayObject *)PyObject_CallMethod(fClassifier, const_cast<char *>("predict_proba"), const_cast<char *>("(O)"), pEvent);
413 double *proba = (double *)(PyArray_DATA(result));
414
415 // Return signal probabilities
416 if(Long64_t(mvaValues.size()) != nEvents) mvaValues.resize(nEvents);
417 for (int i = 0; i < nEvents; ++i) {
419 }
420
423
424
425 return mvaValues;
426}
427
428//_______________________________________________________________________
// Single-event evaluation: copies the current event's values into a 1 x fNvars
// float buffer, calls predict_proba on the classifier, and returns the entry
// indexed by TMVA::Types::kSignal.
// NOTE(review): the signature line and several statements are missing from
// this extract (error handling call, array allocation, the declaration of
// mvaValue and any reference-count cleanup).
430{
431 // cannot determine error
433
434 // Load model if not already done
435 if (fClassifier == 0) ReadModelFromFile();
436
437 // Get current event and load to python array
438 const TMVA::Event *e = Data()->GetEvent();
439 npy_intp dims[2];
440 dims[0] = 1;
441 dims[1] = fNvars;
443 float *pValue = (float *)(PyArray_DATA(pEvent));
444 for (UInt_t i = 0; i < fNvars; i++) pValue[i] = e->GetValue(i);
445
446 // Get prediction from classifier
// NOTE(review): the call result is passed to PyArray_DATA on the very next
// line without a null check; a failed Python call would be dereferenced.
447 PyArrayObject *result = (PyArrayObject *)PyObject_CallMethod(fClassifier, const_cast<char *>("predict_proba"), const_cast<char *>("(O)"), pEvent);
448 double *proba = (double *)(PyArray_DATA(result));
449
450 // Return MVA value
452 mvaValue = proba[TMVA::Types::kSignal]; // getting signal probability
453
456
457 return mvaValue;
458}
459
460//_______________________________________________________________________
// Multiclass evaluation of the current event: copies its values into a
// 1 x fNvars float buffer, calls predict_proba on the classifier, and copies
// the first fNoutputs entries into the member vector classValues, which is
// returned.
// NOTE(review): the signature line and several statements are missing from
// this extract (array allocation and any reference-count cleanup).
462{
463 // Load model if not already done
464 if (fClassifier == 0) ReadModelFromFile();
465
466 // Get current event and load to python array
467 const TMVA::Event *e = Data()->GetEvent();
468 npy_intp dims[2];
469 dims[0] = 1;
470 dims[1] = fNvars;
472 float *pValue = (float *)(PyArray_DATA(pEvent));
473 for (UInt_t i = 0; i < fNvars; i++) pValue[i] = e->GetValue(i);
474
475 // Get prediction from classifier
// NOTE(review): the call result is passed to PyArray_DATA on the very next
// line without a null check; a failed Python call would be dereferenced.
476 PyArrayObject *result = (PyArrayObject *)PyObject_CallMethod(fClassifier, const_cast<char *>("predict_proba"), const_cast<char *>("(O)"), pEvent);
477 double *proba = (double *)(PyArray_DATA(result));
478
479 // Return MVA values
// proba holds doubles while classValues stores Float_t, so each entry is
// narrowed on assignment.
480 if(UInt_t(classValues.size()) != fNoutputs) classValues.resize(fNoutputs);
481 for(UInt_t i = 0; i < fNoutputs; i++) classValues[i] = proba[i];
482
485
486 return classValues;
487}
488
489//_______________________________________________________________________
// Loads the stored classifier: initialises the Python interpreter if needed,
// logs the state-file name, and reports kFATAL when the load returns a
// non-zero error code.
// NOTE(review): the signature line and several statements are missing from
// this extract (the call that produces err - presumably UnSerialize on
// fFilenameClassifier - plus the dictionary booking and the data-property
// assignments announced by the comments below).
491{
492 if (!PyIsInitialized()) {
493 PyInitialize();
494 }
495
496 Log() << Endl;
497 Log() << gTools().Color("bold") << "Loading state file: " << gTools().Color("reset") << fFilenameClassifier << Endl;
498 Log() << Endl;
499
500 // Load classifier from file
502 if(err != 0)
503 {
504 Log() << kFATAL << Form("Failed to load classifier from file (error code: %i): %s", err, fFilenameClassifier.Data()) << Endl;
505 }
506
507 // Book classifier object in python dict
509
510 // Load data properties
511 // NOTE: This has to be repeated here for the reader application
514}
515
516//_______________________________________________________________________
// Builds the variable-importance ranking: allocates a Ranking titled
// "Variable Importance", stores it in fRanking and returns it.
// NOTE(review): the signature line, the statement that obtains pRanking, the
// loop body that fills the ranking and any cleanup are missing from this
// extract.
518{
519 // Get feature importance from classifier as an array with length equal
520 // number of variables, higher value signals a higher importance
522 if(pRanking == 0) Log() << kFATAL << "Failed to get ranking from classifier" << Endl;
523
524 // Fill ranking object and return it
// NOTE(review): fRanking is overwritten with a fresh allocation; no delete of
// a previous value is visible here - confirm ownership is handled elsewhere.
525 fRanking = new Ranking(GetName(), "Variable Importance");
527 for(UInt_t iVar=0; iVar<fNvars; iVar++){
529 }
530
532
533 return fRanking;
534}
535
536//_______________________________________________________________________
// Writes a short, fixed description of the random forest method to the
// logger, one Log() line per sentence fragment, and points the reader to the
// scikit-learn documentation. No state is modified in the visible lines.
// NOTE(review): the signature line is missing from this extract.
538{
539 // typical length of text line:
540 // "|--------------------------------------------------------------|"
541 Log() << "A random forest is a meta estimator that fits a number of decision" << Endl;
542 Log() << "tree classifiers on various sub-samples of the dataset and use" << Endl;
543 Log() << "averaging to improve the predictive accuracy and control over-fitting." << Endl;
544 Log() << Endl;
545 Log() << "Check out the scikit-learn documentation for more information." << Endl;
546}
#define REGISTER_METHOD(CLASS)
for example
_object PyObject
#define e(i)
Definition RSha256.hxx:103
unsigned int UInt_t
Unsigned integer 4 bytes (unsigned int)
Definition RtypesCore.h:60
constexpr Bool_t kFALSE
Definition RtypesCore.h:108
long long Long64_t
Portable signed long integer 8 bytes.
Definition RtypesCore.h:83
constexpr Bool_t kTRUE
Definition RtypesCore.h:107
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2495
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
MsgLogger & Log() const
Class that contains all the data information.
Definition DataSetInfo.h:62
UInt_t GetNClasses() const
const Event * GetEvent() const
returns event without transformations
Definition DataSet.cxx:202
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:206
Long64_t GetNTrainingEvents() const
Definition DataSet.h:68
void SetCurrentEvent(Long64_t ievt) const
Definition DataSet.h:88
const Event * GetTrainingEvent(Long64_t ievt) const
Definition DataSet.h:74
const char * GetName() const override
Definition MethodBase.h:337
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
Bool_t IsModelPersistence() const
Definition MethodBase.h:386
const TString & GetWeightFileDir() const
Definition MethodBase.h:495
DataSetInfo & DataInfo() const
Definition MethodBase.h:413
virtual void TestClassification()
initialization
UInt_t GetNVariables() const
Definition MethodBase.h:348
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
const TString & GetInputLabel(Int_t i) const
Definition MethodBase.h:353
Ranking * fRanking
Definition MethodBase.h:590
DataSet * Data() const
Definition MethodBase.h:412
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets) override
MethodPyRandomForest(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false) override
get all the MVA values for the events of the current Data type
std::vector< Float_t > & GetMulticlassValues() override
std::vector< Float_t > classValues
std::vector< Double_t > mvaValues
void GetHelpMessage() const override
void TestClassification() override
initialization
const Ranking * CreateRanking() override
Double_t GetMvaValue(Double_t *errLower=nullptr, Double_t *errUpper=nullptr) override
Virtual base class for all TMVA method based on Python.
static int PyIsInitialized()
Check Python interpreter initialization status.
PyObject * Eval(TString code)
Evaluate Python code.
static void PyInitialize()
Initialize Python interpreter.
static void Serialize(TString file, PyObject *classifier)
Serialize Python object.
static Int_t UnSerialize(TString file, PyObject **obj)
Unserialize Python object.
PyObject * fClassifier
void PyRunString(TString code, TString errorMessage="Failed to run python code", int start=256)
Execute Python code from string.
Ranking for variables in method (implementation)
Definition Ranking.h:48
virtual void AddRank(const Rank &rank)
Add a new rank take ownership of it.
Definition Ranking.cxx:85
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:803
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kSignal
Never change this number - it is elsewhere assumed to be zero !
Definition Types.h:135
@ kMulticlass
Definition Types.h:129
@ kClassification
Definition Types.h:127
Basic string class.
Definition TString.h:138
const char * Data() const
Definition TString.h:384
Bool_t IsNull() const
Definition TString.h:422
create variable transformations
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148