Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodPyRandomForest.cxx
Go to the documentation of this file.
1// @(#)root/tmva/pymva $Id$
2// Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodPyRandomForest *
8 * Web : http://oproject.org *
9 * *
10 * Description: *
11 * Random Forest Classifier from scikit-learn *
12 * *
13 * *
14 * Redistribution and use in source and binary forms, with or without *
15 * modification, are permitted according to the terms listed in LICENSE *
16 * (see tmva/doc/LICENSE) *
17 * *
18 **********************************************************************************/
19#include <Python.h> // Needs to be included first to avoid redefinition of _POSIX_C_SOURCE
21
22#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
23#include <numpy/arrayobject.h>
24
25#include "TMVA/Configurable.h"
27#include "TMVA/Config.h"
28#include "TMVA/DataSet.h"
29#include "TMVA/Event.h"
30#include "TMVA/IMethod.h"
31#include "TMVA/MsgLogger.h"
32#include "TMVA/PDF.h"
33#include "TMVA/Ranking.h"
34#include "TMVA/Results.h"
35#include "TMVA/Tools.h"
36#include "TMVA/Types.h"
37#include "TMVA/Timer.h"
39
40#include "TMatrix.h"
41
42using namespace TMVA;
43
44namespace TMVA {
45namespace Internal {
46class PyGILRAII {
47 PyGILState_STATE m_GILState;
48
49public:
52};
53} // namespace Internal
54} // namespace TMVA
55
57
58
59//_______________________________________________________________________
61 const TString &methodTitle,
64 PyMethodBase(jobName, Types::kPyRandomForest, methodTitle, dsi, theOption),
65 fNestimators(10),
66 fCriterion("gini"),
67 fMaxDepth("None"),
68 fMinSamplesSplit(2),
69 fMinSamplesLeaf(1),
70 fMinWeightFractionLeaf(0),
71 fMaxFeatures("'sqrt'"),
72 fMaxLeafNodes("None"),
73 fBootstrap(kTRUE),
74 fOobScore(kFALSE),
75 fNjobs(1),
76 fRandomState("None"),
77 fVerbose(0),
78 fWarmStart(kFALSE),
79 fClassWeight("None")
80{
81}
82
83//_______________________________________________________________________
85 : PyMethodBase(Types::kPyRandomForest, theData, theWeightFile),
86 fNestimators(10),
87 fCriterion("gini"),
88 fMaxDepth("None"),
89 fMinSamplesSplit(2),
90 fMinSamplesLeaf(1),
91 fMinWeightFractionLeaf(0),
92 fMaxFeatures("'sqrt'"),
93 fMaxLeafNodes("None"),
94 fBootstrap(kTRUE),
95 fOobScore(kFALSE),
96 fNjobs(1),
97 fRandomState("None"),
98 fVerbose(0),
99 fWarmStart(kFALSE),
100 fClassWeight("None")
101{
102}
103
104
105//_______________________________________________________________________
109
110//_______________________________________________________________________
117
118//_______________________________________________________________________
120{
122
123 DeclareOptionRef(fNestimators, "NEstimators", "Integer, optional (default=10). The number of trees in the forest.");
124 DeclareOptionRef(fCriterion, "Criterion", "String, optional (default='gini') \
125 The function to measure the quality of a split. Supported criteria are \
126 'gini' for the Gini impurity and 'entropy' for the information gain. \
127 Note: this parameter is tree-specific.");
128
129 DeclareOptionRef(fMaxDepth, "MaxDepth", "integer or None, optional (default=None) \
130 The maximum depth of the tree. If None, then nodes are expanded until \
131 all leaves are pure or until all leaves contain less than \
132 min_samples_split samples. \
133 Ignored if ``max_leaf_nodes`` is not None.");
134
135 DeclareOptionRef(fMinSamplesSplit, "MinSamplesSplit", "integer, optional (default=2)\
136 The minimum number of samples required to split an internal node.");
137
138 DeclareOptionRef(fMinSamplesLeaf, "MinSamplesLeaf", "integer, optional (default=1) \
139 The minimum number of samples in newly created leaves. A split is \
140 discarded if after the split, one of the leaves would contain less then \
141 ``min_samples_leaf`` samples.");
142 DeclareOptionRef(fMinWeightFractionLeaf, "MinWeightFractionLeaf", "//float, optional (default=0.) \
143 The minimum weighted fraction of the input samples required to be at a \
144 leaf node.");
145 DeclareOptionRef(fMaxFeatures, "MaxFeatures", "The number of features to consider when looking for the best split");
146
147 DeclareOptionRef(fMaxLeafNodes, "MaxLeafNodes", "int or None, optional (default=None)\
148 Grow trees with ``max_leaf_nodes`` in best-first fashion.\
149 Best nodes are defined as relative reduction in impurity.\
150 If None then unlimited number of leaf nodes.\
151 If not None then ``max_depth`` will be ignored.");
152
153 DeclareOptionRef(fBootstrap, "Bootstrap", "boolean, optional (default=True) \
154 Whether bootstrap samples are used when building trees.");
155
156 DeclareOptionRef(fOobScore, "OoBScore", " bool Whether to use out-of-bag samples to estimate\
157 the generalization error.");
158
159 DeclareOptionRef(fNjobs, "NJobs", " integer, optional (default=1) \
160 The number of jobs to run in parallel for both `fit` and `predict`. \
161 If -1, then the number of jobs is set to the number of cores.");
162
163 DeclareOptionRef(fRandomState, "RandomState", "int, RandomState instance or None, optional (default=None)\
164 If int, random_state is the seed used by the random number generator;\
165 If RandomState instance, random_state is the random number generator;\
166 If None, the random number generator is the RandomState instance used\
167 by `np.random`.");
168
169 DeclareOptionRef(fVerbose, "Verbose", "int, optional (default=0)\
170 Controls the verbosity of the tree building process.");
171
172 DeclareOptionRef(fWarmStart, "WarmStart", "bool, optional (default=False)\
173 When set to ``True``, reuse the solution of the previous call to fit\
174 and add more estimators to the ensemble, otherwise, just fit a whole\
175 new forest.");
176
177 DeclareOptionRef(fClassWeight, "ClassWeight", "dict, list of dicts, \"auto\", \"subsample\" or None, optional\
178 Weights associated with classes in the form ``{class_label: weight}``.\
179 If not given, all classes are supposed to have weight one. For\
180 multi-output problems, a list of dicts can be provided in the same\
181 order as the columns of y.\
182 The \"auto\" mode uses the values of y to automatically adjust\
183 weights inversely proportional to class frequencies in the input data.\
184 The \"subsample\" mode is the same as \"auto\" except that weights are\
185 computed based on the bootstrap sample for every tree grown.\
186 For multi-output, the weights of each column of y will be multiplied.\
187 Note that these weights will be multiplied with sample_weight (passed\
188 through the fit method) if sample_weight is specified.");
189
190 DeclareOptionRef(fFilenameClassifier, "FilenameClassifier",
191 "Store trained classifier in this file");
192}
193
194//_______________________________________________________________________
195// Check options and load them to local python namespace
197{
198 if (fNestimators <= 0) {
199 Log() << kFATAL << " NEstimators <=0... that does not work !! " << Endl;
200 }
203
204 if (fCriterion != "gini" && fCriterion != "entropy") {
205 Log() << kFATAL << Form(" Criterion = %s... that does not work !! ", fCriterion.Data())
206 << " The options are `gini` or `entropy`." << Endl;
207 }
208 pCriterion = Eval(Form("'%s'", fCriterion.Data()));
210
213 if (!pMaxDepth) {
214 Log() << kFATAL << Form(" MaxDepth = %s... that does not work !! ", fMaxDepth.Data())
215 << " The options are None or integer." << Endl;
216 }
217
218 if (fMinSamplesSplit < 0) {
219 Log() << kFATAL << " MinSamplesSplit < 0... that does not work !! " << Endl;
220 }
223
224 if (fMinSamplesLeaf < 0) {
225 Log() << kFATAL << " MinSamplesLeaf < 0... that does not work !! " << Endl;
226 }
229
230 if (fMinWeightFractionLeaf < 0) {
231 Log() << kERROR << " MinWeightFractionLeaf < 0... that does not work !! " << Endl;
232 }
234 PyDict_SetItemString(fLocalNS, "minWeightFractionLeaf", pMinWeightFractionLeaf);
235
236 if (fMaxFeatures == "auto") fMaxFeatures = "sqrt"; // change in API from v 1.11
237 if (fMaxFeatures == "sqrt" || fMaxFeatures == "log2"){
238 fMaxFeatures = Form("'%s'", fMaxFeatures.Data());
239 }
242
243 if (!pMaxFeatures) {
244 Log() << kFATAL << Form(" MaxFeatures = %s... that does not work !! ", fMaxFeatures.Data())
245 << "int, float, string or None, optional (default='auto')"
246 << "The number of features to consider when looking for the best split:"
247 << "If int, then consider `max_features` features at each split."
248 << "If float, then `max_features` is a percentage and"
249 << "`int(max_features * n_features)` features are considered at each split."
250 << "If 'auto', then `max_features=sqrt(n_features)`."
251 << "If 'sqrt', then `max_features=sqrt(n_features)`."
252 << "If 'log2', then `max_features=log2(n_features)`."
253 << "If None, then `max_features=n_features`." << Endl;
254 }
255
257 if (!pMaxLeafNodes) {
258 Log() << kFATAL << Form(" MaxLeafNodes = %s... that does not work !! ", fMaxLeafNodes.Data())
259 << " The options are None or integer." << Endl;
260 }
262
264 if (!pRandomState) {
265 Log() << kFATAL << Form(" RandomState = %s... that does not work !! ", fRandomState.Data())
266 << "If int, random_state is the seed used by the random number generator;"
267 << "If RandomState instance, random_state is the random number generator;"
268 << "If None, the random number generator is the RandomState instance used by `np.random`." << Endl;
269 }
271
273 if (!pClassWeight) {
274 Log() << kFATAL << Form(" ClassWeight = %s... that does not work !! ", fClassWeight.Data())
275 << "dict, list of dicts, 'auto', 'subsample' or None, optional" << Endl;
276 }
278
279 if(fNjobs < 1) {
280 Log() << kFATAL << Form(" NJobs = %i... that does not work !! ", fNjobs)
281 << "Value has to be greater than zero." << Endl;
282 }
283 pNjobs = Eval(Form("%i", fNjobs));
285
286 pBootstrap = (fBootstrap) ? Eval("True") : Eval("False");
288 pOobScore = (fOobScore) ? Eval("True") : Eval("False");
290 pVerbose = Eval(Form("%i", fVerbose));
292 pWarmStart = (fWarmStart) ? Eval("True") : Eval("False");
294
295 // If no filename is given, set default
297 {
298 fFilenameClassifier = GetWeightFileDir() + "/PyRFModel_" + GetName() + ".PyData";
299 }
300}
301
302//_______________________________________________________________________
304{
306 _import_array(); //require to use numpy arrays
307
308 // Check options and load them to local python namespace
310
311 // Import module for random forest classifier
312 PyRunString("import sklearn.ensemble");
313
314 // Get data properties
317}
318
319//_______________________________________________________________________
321{
322 // Load training data (data, classes, weights) to python arrays
323 int fNrowsTraining = Data()->GetNTrainingEvents(); //every row is an event, a class type and a weight
326 dimsData[1] = fNvars;
329 float *TrainData = (float *)(PyArray_DATA(fTrainData));
330
335
339
340 for (int i = 0; i < fNrowsTraining; i++) {
341 // Fill training data matrix
342 const TMVA::Event *e = Data()->GetTrainingEvent(i);
343 for (UInt_t j = 0; j < fNvars; j++) {
344 TrainData[j + i * fNvars] = e->GetValue(j);
345 }
346
347 // Fill target classes
348 TrainDataClasses[i] = e->GetClass();
349
350 // Get event weight
351 TrainDataWeights[i] = e->GetWeight();
352 }
353
354 // Create classifier object
355 PyRunString("classifier = sklearn.ensemble.RandomForestClassifier(bootstrap=bootstrap, class_weight=classWeight, criterion=criterion, max_depth=maxDepth, max_features=maxFeatures, max_leaf_nodes=maxLeafNodes, min_samples_leaf=minSamplesLeaf, min_samples_split=minSamplesSplit, min_weight_fraction_leaf=minWeightFractionLeaf, n_estimators=nEstimators, n_jobs=nJobs, oob_score=oobScore, random_state=randomState, verbose=verbose, warm_start=warmStart)",
356 "Failed to setup classifier");
357
358 // Fit classifier
359 // NOTE: We dump the output to a variable so that the call does not pollute stdout
360 PyRunString("dump = classifier.fit(trainData, trainDataClasses, trainDataWeights)", "Failed to train classifier");
361
362 // Store classifier
364 if(fClassifier == 0) {
365 Log() << kFATAL << "Can't create classifier object from RandomForestClassifier" << Endl;
366 Log() << Endl;
367 }
368
369 if (IsModelPersistence()) {
370 Log() << Endl;
371 Log() << gTools().Color("bold") << "Saving state file: " << gTools().Color("reset") << fFilenameClassifier << Endl;
372 Log() << Endl;
374 }
375}
376
377//_______________________________________________________________________
382
383//_______________________________________________________________________
385{
386 // Load model if not already done
387 if (fClassifier == 0) ReadModelFromFile();
388
389 // Determine number of events
390 Long64_t nEvents = Data()->GetNEvents();
391 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
392 if (firstEvt < 0) firstEvt = 0;
393 nEvents = lastEvt-firstEvt;
394
395 // use timer
396 Timer timer( nEvents, GetName(), kTRUE );
397
398 if (logProgress)
399 Log() << kHEADER << Form("[%s] : ",DataInfo().GetName())
400 << "Evaluation of " << GetMethodName() << " on "
401 << (Data()->GetCurrentType() == Types::kTraining ? "training" : "testing")
402 << " sample (" << nEvents << " events)" << Endl;
403
404 // Get data
405 npy_intp dims[2];
406 dims[0] = nEvents;
407 dims[1] = fNvars;
409 float *pValue = (float *)(PyArray_DATA(pEvent));
410
411 for (Int_t ievt=0; ievt<nEvents; ievt++) {
413 const TMVA::Event *e = Data()->GetEvent();
414 for (UInt_t i = 0; i < fNvars; i++) {
415 pValue[ievt * fNvars + i] = e->GetValue(i);
416 }
417 }
418
419 // Get prediction from classifier
420 PyArrayObject *result = (PyArrayObject *)PyObject_CallMethod(fClassifier, const_cast<char *>("predict_proba"), const_cast<char *>("(O)"), pEvent);
421 double *proba = (double *)(PyArray_DATA(result));
422
423 // Return signal probabilities
424 if(Long64_t(mvaValues.size()) != nEvents) mvaValues.resize(nEvents);
425 for (int i = 0; i < nEvents; ++i) {
427 }
428
431
432 if (logProgress) {
433 Log() << kINFO
434 << "Elapsed time for evaluation of " << nEvents << " events: "
435 << timer.GetElapsedTime() << " " << Endl;
436 }
437
438 return mvaValues;
439}
440
441//_______________________________________________________________________
443{
444 // cannot determine error
446
447 // Load model if not already done
448 if (fClassifier == 0) ReadModelFromFile();
449
450 // Get current event and load to python array
451 const TMVA::Event *e = Data()->GetEvent();
452 npy_intp dims[2];
453 dims[0] = 1;
454 dims[1] = fNvars;
456 float *pValue = (float *)(PyArray_DATA(pEvent));
457 for (UInt_t i = 0; i < fNvars; i++) pValue[i] = e->GetValue(i);
458
459 // Get prediction from classifier
460 PyArrayObject *result = (PyArrayObject *)PyObject_CallMethod(fClassifier, const_cast<char *>("predict_proba"), const_cast<char *>("(O)"), pEvent);
461 double *proba = (double *)(PyArray_DATA(result));
462
463 // Return MVA value
465 mvaValue = proba[TMVA::Types::kSignal]; // getting signal probability
466
469
470 return mvaValue;
471}
472
473//_______________________________________________________________________
475{
476 // Load model if not already done
477 if (fClassifier == 0) ReadModelFromFile();
478
479 // Get current event and load to python array
480 const TMVA::Event *e = Data()->GetEvent();
481 npy_intp dims[2];
482 dims[0] = 1;
483 dims[1] = fNvars;
485 float *pValue = (float *)(PyArray_DATA(pEvent));
486 for (UInt_t i = 0; i < fNvars; i++) pValue[i] = e->GetValue(i);
487
488 // Get prediction from classifier
489 PyArrayObject *result = (PyArrayObject *)PyObject_CallMethod(fClassifier, const_cast<char *>("predict_proba"), const_cast<char *>("(O)"), pEvent);
490 double *proba = (double *)(PyArray_DATA(result));
491
492 // Return MVA values
493 if(UInt_t(classValues.size()) != fNoutputs) classValues.resize(fNoutputs);
494 for(UInt_t i = 0; i < fNoutputs; i++) classValues[i] = proba[i];
495
498
499 return classValues;
500}
501
502//_______________________________________________________________________
504{
505 if (!PyIsInitialized()) {
506 PyInitialize();
507 }
508
509 Log() << Endl;
510 Log() << gTools().Color("bold") << "Loading state file: " << gTools().Color("reset") << fFilenameClassifier << Endl;
511 Log() << Endl;
512
513 // Load classifier from file
515 if(err != 0)
516 {
517 Log() << kFATAL << Form("Failed to load classifier from file (error code: %i): %s", err, fFilenameClassifier.Data()) << Endl;
518 }
519
520 // Book classifier object in python dict
522
523 // Load data properties
524 // NOTE: This has to be repeated here for the reader application
527}
528
529//_______________________________________________________________________
531{
532 // Get the feature importances from the classifier as an array whose length equals
533 // the number of variables; a higher value signals a higher importance
535 if(pRanking == 0) Log() << kFATAL << "Failed to get ranking from classifier" << Endl;
536
537 // Fill ranking object and return it
538 fRanking = new Ranking(GetName(), "Variable Importance");
540 for(UInt_t iVar=0; iVar<fNvars; iVar++){
542 }
543
545
546 return fRanking;
547}
548
549//_______________________________________________________________________
551{
552 // typical length of text line:
553 // "|--------------------------------------------------------------|"
554 Log() << "A random forest is a meta estimator that fits a number of decision" << Endl;
555 Log() << "tree classifiers on various sub-samples of the dataset and use" << Endl;
556 Log() << "averaging to improve the predictive accuracy and control over-fitting." << Endl;
557 Log() << Endl;
558 Log() << "Check out the scikit-learn documentation for more information." << Endl;
559}
#define REGISTER_METHOD(CLASS)
for example
_object PyObject
#define e(i)
Definition RSha256.hxx:103
unsigned int UInt_t
Unsigned integer 4 bytes (unsigned int)
Definition RtypesCore.h:60
constexpr Bool_t kFALSE
Definition RtypesCore.h:108
long long Long64_t
Portable signed long integer 8 bytes.
Definition RtypesCore.h:83
constexpr Bool_t kTRUE
Definition RtypesCore.h:107
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2495
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
MsgLogger & Log() const
Class that contains all the data information.
Definition DataSetInfo.h:62
UInt_t GetNClasses() const
const Event * GetEvent() const
returns event without transformations
Definition DataSet.cxx:202
Types::ETreeType GetCurrentType() const
Definition DataSet.h:194
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:206
Long64_t GetNTrainingEvents() const
Definition DataSet.h:68
void SetCurrentEvent(Long64_t ievt) const
Definition DataSet.h:88
const Event * GetTrainingEvent(Long64_t ievt) const
Definition DataSet.h:74
const char * GetName() const override
Definition MethodBase.h:334
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility; they are hence without any...
Bool_t IsModelPersistence() const
Definition MethodBase.h:383
const TString & GetWeightFileDir() const
Definition MethodBase.h:492
const TString & GetMethodName() const
Definition MethodBase.h:331
DataSetInfo & DataInfo() const
Definition MethodBase.h:410
virtual void TestClassification()
initialization
UInt_t GetNVariables() const
Definition MethodBase.h:345
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
const TString & GetInputLabel(Int_t i) const
Definition MethodBase.h:350
Ranking * fRanking
Definition MethodBase.h:587
DataSet * Data() const
Definition MethodBase.h:409
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets) override
MethodPyRandomForest(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false) override
get all the MVA values for the events of the current Data type
std::vector< Float_t > & GetMulticlassValues() override
std::vector< Float_t > classValues
std::vector< Double_t > mvaValues
void GetHelpMessage() const override
void TestClassification() override
initialization
const Ranking * CreateRanking() override
Double_t GetMvaValue(Double_t *errLower=nullptr, Double_t *errUpper=nullptr) override
Virtual base class for all TMVA method based on Python.
static int PyIsInitialized()
Check Python interpreter initialization status.
PyObject * Eval(TString code)
Evaluate Python code.
static void PyInitialize()
Initialize Python interpreter.
static void Serialize(TString file, PyObject *classifier)
Serialize Python object.
static Int_t UnSerialize(TString file, PyObject **obj)
Unserialize Python object.
PyObject * fClassifier
void PyRunString(TString code, TString errorMessage="Failed to run python code", int start=256)
Execute Python code from string.
Ranking for variables in method (implementation)
Definition Ranking.h:48
virtual void AddRank(const Rank &rank)
Add a new rank and take ownership of it.
Definition Ranking.cxx:85
Timing information for training and evaluation of MVA methods.
Definition Timer.h:58
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:828
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kSignal
Never change this number - it is elsewhere assumed to be zero !
Definition Types.h:135
@ kMulticlass
Definition Types.h:129
@ kClassification
Definition Types.h:127
@ kTraining
Definition Types.h:143
Basic string class.
Definition TString.h:138
const char * Data() const
Definition TString.h:384
Bool_t IsNull() const
Definition TString.h:422
create variable transformations
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148