ROOT 6.07/09
Reference Guide
MethodPyAdaBoost.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/pymva $Id$
2 // Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodPyAdaBoost *
8  * Web : http://oproject.org *
9  * *
10  * Description: *
11  * AdaBoost Classifier from scikit-learn *
12  * *
13  * *
14  * Redistribution and use in source and binary forms, with or without *
15  * modification, are permitted according to the terms listed in LICENSE *
16  * (http://tmva.sourceforge.net/LICENSE) *
17  * *
18  **********************************************************************************/
19 
20 #include <Python.h> // Needs to be included first to avoid redefinition of _POSIX_C_SOURCE
21 #include "TMVA/MethodPyAdaBoost.h"
22 
23 #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
24 #include <numpy/arrayobject.h>
25 
26 #pragma GCC diagnostic ignored "-Wunused-parameter"
27 
28 #include "TMVA/Config.h"
29 #include "TMVA/Configurable.h"
30 #include "TMVA/ClassifierFactory.h"
31 #include "TMVA/DataSet.h"
32 #include "TMVA/Event.h"
33 #include "TMVA/IMethod.h"
34 #include "TMVA/MsgLogger.h"
35 #include "TMVA/PDF.h"
36 #include "TMVA/Ranking.h"
37 #include "TMVA/Tools.h"
38 #include "TMVA/Types.h"
40 #include "TMVA/Results.h"
41 
42 #include "TMath.h"
43 #include "Riostream.h"
44 #include "TMatrix.h"
45 #include "TMatrixD.h"
46 #include "TVectorD.h"
47 
48 #include <iomanip>
49 #include <fstream>
50 
51 using namespace TMVA;
52 
// Registers MethodPyAdaBoost with TMVA under the name PyAdaBoost.
// NOTE(review): REGISTER_METHOD is a TMVA macro; presumably it hooks the class into the
// ClassifierFactory (TMVA/ClassifierFactory.h is included above) — confirm against its definition.
53 REGISTER_METHOD(PyAdaBoost)
54 
56 
57 //_______________________________________________________________________
59  const TString &methodTitle,
60  DataSetInfo &dsi,
61  const TString &theOption) :
62  PyMethodBase(jobName, Types::kPyAdaBoost, methodTitle, dsi, theOption),
63  base_estimator("None"),
64  n_estimators(50),
65  learning_rate(1.0),
66  algorithm("SAMME.R"),
67  random_state("None")
68 {
// Standard constructor: forwards the job name, method title, DataSetInfo and option string
// to PyMethodBase with Types::kPyAdaBoost and sets the option members to their defaults:
// BaseEstimator "None", NEstimators 50, LearningRate 1.0, Algorithm "SAMME.R",
// RandomState "None". BaseEstimator and RandomState are strings that ProcessOptions and
// Train later pass to Eval to obtain Python objects.
69 }
70 
71 //_______________________________________________________________________
73  : PyMethodBase(Types::kPyAdaBoost, theData, theWeightFile),
74  base_estimator("None"),
75  n_estimators(50),
76  learning_rate(1.0),
77  algorithm("SAMME.R"),
78  random_state("None")
79 {
// Constructor from a DataSetInfo (theData) and a weight-file name (theWeightFile), both
// forwarded to PyMethodBase together with Types::kPyAdaBoost.
// NOTE(review): presumably used when the method is re-created from a weight file — confirm.
// The option defaults duplicate those of the standard constructor; keep the two in sync.
80 }
81 
82 
83 //_______________________________________________________________________
85 {
// Destructor. The body is empty: this class itself releases none of the Python objects it
// uses (fModule, fClassifier, fTrainData, fTrainDataClasses, fTrainDataWeights — declared in
// PyMethodBase). NOTE(review): confirm that PyMethodBase's destructor releases them.
86 }
87 
88 //_______________________________________________________________________
90 {
// Reports supported analysis types: only two-class classification returns kTRUE.
// numberTargets is not used (the file disables -Wunused-parameter near the top).
91  if (type == Types::kClassification && numberClasses == 2) return kTRUE;
92  return kFALSE;
93 }
94 
95 
96 //_______________________________________________________________________
98 {
// Declares the user options that map onto sklearn.ensemble.AdaBoostClassifier parameters:
//   BaseEstimator -> base_estimator   NEstimators -> n_estimators
//   LearningRate  -> learning_rate    Algorithm   -> algorithm
//   RandomState   -> random_state
// BaseEstimator and RandomState are strings that ProcessOptions and Train pass to Eval;
// both default to "None" in the constructors.
// NOTE(review): the BaseEstimator description says default=DecisionTreeClassifier while the
// option value defaults to "None"; presumably scikit-learn substitutes a decision tree for
// None — confirm against the installed scikit-learn version.
// NOTE(review): every description continues across lines with a backslash-newline inside the
// string literal, so the leading whitespace of each continuation line becomes part of the
// help text.
// NOTE(review): original line 99 is not shown in this listing.
100 
101  DeclareOptionRef(base_estimator, "BaseEstimator", "object, optional (default=DecisionTreeClassifier)\
102  The base estimator from which the boosted ensemble is built.\
103  Support for sample weighting is required, as well as proper `classes_`\
104  and `n_classes_` attributes.");
105 
106  DeclareOptionRef(n_estimators, "NEstimators", "integer, optional (default=50)\
107  The maximum number of estimators at which boosting is terminated.\
108  In case of perfect fit, the learning procedure is stopped early.");
109 
110  DeclareOptionRef(learning_rate, "LearningRate", "float, optional (default=1.)\
111  Learning rate shrinks the contribution of each classifier by\
112  ``learning_rate``. There is a trade-off between ``learning_rate`` and\
113  ``n_estimators``.");
114 
115  DeclareOptionRef(algorithm, "Algorithm", "{'SAMME', 'SAMME.R'}, optional (default='SAMME.R')\
116  If 'SAMME.R' then use the SAMME.R real boosting algorithm.\
117  ``base_estimator`` must support calculation of class probabilities.\
118  If 'SAMME' then use the SAMME discrete boosting algorithm.\
119  The SAMME.R algorithm typically converges faster than SAMME,\
120  achieving a lower test error with fewer boosting iterations.");
121 
122  DeclareOptionRef(random_state, "RandomState", "int, RandomState instance or None, optional (default=None)\
123  If int, random_state is the seed used by the random number generator;\
124  If RandomState instance, random_state is the random number generator;\
125  If None, the random number generator is the RandomState instance used\
126  by `np.random`.");
127 }
128 
129 //_______________________________________________________________________
131 {
// Validates the option values declared in DeclareOptions:
//  - BaseEstimator and RandomState must be accepted by Eval; a NULL result is reported at
//    kFATAL. The returned objects are only used for this check and are released right away.
//  - NEstimators <= 0 is reset to 10, LearningRate <= 0 to 1.0 (reported at kERROR).
//  - Algorithm must be "SAMME" or "SAMME.R" (kFATAL otherwise).
132  PyObject *pobase_estimator = Eval(base_estimator);
133  if (!pobase_estimator) {
134  Log() << kFATAL << Form(" BaseEstimator = %s... that does not work !! ", base_estimator.Data())
135  << " The options are Object or None."
136  << Endl;
137  }
// NOTE(review): if kFATAL logging ever returns, Py_DECREF(NULL) here would crash;
// Py_XDECREF is the NULL-tolerant form. The same applies to porandom_state below.
138  Py_DECREF(pobase_estimator);
139 
// NOTE(review): the fallback of 10 differs from the constructor default of 50.
140  if (n_estimators <= 0) {
141  Log() << kERROR << " NEstimators <=0... that does not work !! "
142  << " I set it to 10 .. just so that the program does not crash"
143  << Endl;
144  n_estimators = 10;
145  }
146  if (learning_rate <= 0) {
147  Log() << kERROR << " LearningRate <=0... that does not work !! "
148  << " I set it to 1.0 .. just so that the program does not crash"
149  << Endl;
150  learning_rate = 1.0;
151  }
152 
// NOTE(review): the message below reads "SAMME of SAMME.R"; presumably "or" was intended.
153  if (algorithm != "SAMME" && algorithm != "SAMME.R") {
154  Log() << kFATAL << Form(" Algorithm = %s... that does not work !! ", algorithm.Data())
155  << " The options are SAMME of SAMME.R."
156  << Endl;
157  }
// NOTE(review): the three explanatory strings in the RandomState message are streamed
// back-to-back with no separating space, so the sentences run together in the log.
158  PyObject *porandom_state = Eval(random_state);
159  if (!porandom_state) {
160  Log() << kFATAL << Form(" RandomState = %s... that does not work !! ", random_state.Data())
161  << "If int, random_state is the seed used by the random number generator;"
162  << "If RandomState instance, random_state is the random number generator;"
163  << "If None, the random number generator is the RandomState instance used by `np.random`."
164  << Endl;
165  }
166  Py_DECREF(porandom_state);
167 }
168 
169 
170 //_______________________________________________________________________
172 {
// Prepares training:
//  - validates the options (ProcessOptions) and initialises the NumPy C-API;
//  - imports the Python module sklearn.ensemble into fModule;
//  - copies every training event into three NPY_FLOAT (float32) NumPy arrays stored in
//    PyMethodBase members:
//      fTrainData        fNrowsTraining x fNvars, filled row-major (index j + i * fNvars)
//      fTrainDataClasses one label per event, TMVA::Types::kSignal or kBackground
//      fTrainDataWeights the event weight from Event::GetWeight()
173  ProcessOptions();
174  _import_array();//require to use numpy arrays
175 
176  //Import sklearn
177  // Convert the file name to a Python string.
178  PyObject *pName = PyUnicode_FromString("sklearn.ensemble");
179  // Import the file as a Python module.
180  fModule = PyImport_Import(pName);
181  Py_DECREF(pName);
182 
183  if (!fModule) {
184  Log() << kFATAL << "Can't import sklearn.ensemble" << Endl;
185  Log() << Endl;
186  }
187 
188 
189  //Training data
// NOTE(review): GetNTrainingEvents() returns Long64_t and is narrowed to int here. Despite
// the f-prefix, fNvars and fNrowsTraining are local variables, not data members.
190  UInt_t fNvars = Data()->GetNVariables();
191  int fNrowsTraining = Data()->GetNTrainingEvents(); //every row is an event, a class type and a weight
// NOTE(review): this heap array is never deleted (memory leak); GetMvaValue uses a stack
// array `int dims[2]` for the same purpose. PyArray_FromDims is also deprecated in the NumPy
// C-API (PyArray_SimpleNew with npy_intp dimensions is the documented replacement) — confirm
// against the NumPy version in use.
192  int *dims = new int[2];
193  dims[0] = fNrowsTraining;
194  dims[1] = fNvars;
195  fTrainData = (PyArrayObject *)PyArray_FromDims(2, dims, NPY_FLOAT);
196  float *TrainData = (float *)(PyArray_DATA(fTrainData));
197 
198 
199  fTrainDataClasses = (PyArrayObject *)PyArray_FromDims(1, &fNrowsTraining, NPY_FLOAT);
200  float *TrainDataClasses = (float *)(PyArray_DATA(fTrainDataClasses));
201 
202  fTrainDataWeights = (PyArrayObject *)PyArray_FromDims(1, &fNrowsTraining, NPY_FLOAT);
203  float *TrainDataWeights = (float *)(PyArray_DATA(fTrainDataWeights));
204 
205  for (int i = 0; i < fNrowsTraining; i++) {
206  const TMVA::Event *e = Data()->GetTrainingEvent(i);
207  for (UInt_t j = 0; j < fNvars; j++) {
208  TrainData[j + i * fNvars] = e->GetValue(j);
209  }
210  if (e->GetClass() == TMVA::Types::kSignal) TrainDataClasses[i] = TMVA::Types::kSignal;
211  else TrainDataClasses[i] = TMVA::Types::kBackground;
212 
213  TrainDataWeights[i] = e->GetWeight();
214  }
215 }
216 
218 {
// Builds sklearn.ensemble.AdaBoostClassifier from the option members and fits it on the
// arrays prepared in Init (fTrainData, fTrainDataClasses, fTrainDataWeights).
// The constructor receives positional arguments packed by Py_BuildValue("(OifsO)"):
// (base_estimator, n_estimators, learning_rate, algorithm, random_state).
// If IsModelPersistence() is true, the fitted fClassifier is written with Serialize to
// <weight-file dir>/PyAdaBoostModel.PyData, the path ReadModelFromFile loads from.
// NOTE(review): this relies on AdaBoostClassifier's positional parameter order; newer
// scikit-learn releases renamed base_estimator to estimator — confirm against the installed
// version.
219  PyObject *pobase_estimator = Eval(base_estimator);
220  PyObject *porandom_state = Eval(random_state);
// NOTE(review): ProcessOptions Py_DECREFs what Eval returns, so Eval appears to hand out new
// references. These two are never released here, and Py_BuildValue's "O" format takes its own
// reference, so each object leaks one reference per call.
221 
222  PyObject *args = Py_BuildValue("(OifsO)", pobase_estimator, n_estimators, learning_rate, algorithm.Data(), porandom_state);
// NOTE(review): debugging output written straight to stdout, bypassing Log().
223  PyObject_Print(args, stdout, 0);
224  std::cout << std::endl;
// NOTE(review): per the Python C-API documentation, PyModule_GetDict and PyDict_GetItemString
// return borrowed references. The Py_DECREF(pDict) / Py_DECREF(fClassifierClass) calls in the
// else-branch therefore release references this code never owned. fClassifierClass is NULL
// when the name is missing, and Py_DECREF(NULL) crashes.
225  PyObject *pDict = PyModule_GetDict(fModule);
226  PyObject *fClassifierClass = PyDict_GetItemString(pDict, "AdaBoostClassifier");
227 
228  // Create an instance of the class
229  if (PyCallable_Check(fClassifierClass)) {
230  //instance
// NOTE(review): the result is not checked for NULL (constructor raised).
231  fClassifier = PyObject_CallObject(fClassifierClass , args);
232  PyObject_Print(fClassifier, stdout, 0);
233 
234  Py_DECREF(args);
235  } else {
236  PyErr_Print();
237  Py_DECREF(pDict);
238  Py_DECREF(fClassifierClass);
239  Log() << kFATAL << "Can't call function AdaBoostClassifier" << Endl;
240  Log() << Endl;
241 
242  }
243 
// NOTE(review): fit() returns a new reference. Assigning it to fClassifier drops the reference
// obtained from PyObject_CallObject without releasing it; scikit-learn's fit returns self, so
// one reference leaks. A NULL result (fit raised) is not checked before Serialize.
244  fClassifier = PyObject_CallMethod(fClassifier, (char *)"fit", (char *)"(OOO)", fTrainData, fTrainDataClasses, fTrainDataWeights);
245 
246  if (IsModelPersistence())
247  {
248  TString path = GetWeightFileDir() + "/PyAdaBoostModel.PyData";
249  Log() << Endl;
250  Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
251  Log() << Endl;
252  Serialize(path,fClassifier);
253  }
254 }
255 
256 //_______________________________________________________________________
258 {
// TestClassification override. Its body line (original line 259) is not shown in this
// listing, so its behaviour is not documented here.
260 }
261 
262 
263 //_______________________________________________________________________
265 {
// Returns the classifier response for the current event (Data()->GetEvent()).
// The event's variables are copied into a 1 x nvars NPY_FLOAT array and passed to
// fClassifier.predict_proba; the first entry of the returned array is the MVA value.
// errLower/errUpper go to NoErrorCalc because no error estimate is available.
// Original line 269 is not shown in this listing, so whatever it does is not documented here.
266  // cannot determine error
267  NoErrorCalc(errLower, errUpper);
268 
270 
271  Double_t mvaValue;
272  const TMVA::Event *e = Data()->GetEvent();
273  UInt_t nvars = e->GetNVariables();
274  int dims[2];
275  dims[0] = 1;
276  dims[1] = nvars;
277  PyArrayObject *pEvent= (PyArrayObject *)PyArray_FromDims(2, dims, NPY_FLOAT);
278  float *pValue = (float *)(PyArray_DATA(pEvent));
279 
280  for (UInt_t i = 0; i < nvars; i++) pValue[i] = e->GetValue(i);
281 
// NOTE(review): the result is not checked for NULL (predict_proba raised) before
// PyArray_DATA dereferences it.
// NOTE(review): the cast to double* assumes predict_proba returns float64. proba[0] is the
// probability of classes_[0], and scikit-learn orders classes_ by sorted label, so this is the
// signal probability only if kSignal's numeric value is below kBackground's — confirm.
// NOTE(review): each call allocates a NumPy array and makes one Python call per event.
282  PyArrayObject *result = (PyArrayObject *)PyObject_CallMethod(fClassifier, const_cast<char *>("predict_proba"), const_cast<char *>("(O)"), pEvent);
283  double *proba = (double *)(PyArray_DATA(result));
284  mvaValue = proba[0]; //getting signal prob
285  Py_DECREF(result);
286  Py_DECREF(pEvent);
287  return mvaValue;
288 }
289 
290 //_______________________________________________________________________
292 {
// Ensures the Python interpreter is initialised, then restores fClassifier with UnSerialize
// from <weight-file dir>/PyAdaBoostModel.PyData, the path Train writes to when model
// persistence is enabled.
// NOTE(review): the file-name literal is duplicated in Train; a shared constant would keep
// the two in sync.
293  if (!PyIsInitialized()) {
294  PyInitialize();
295  }
296 
297  TString path = GetWeightFileDir() + "/PyAdaBoostModel.PyData";
298  Log() << Endl;
299  Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
300  Log() << Endl;
301  UnSerialize(path,&fClassifier);
302 }
303 
304 //_______________________________________________________________________
306 {
307  // get help message text
308  //
309  // typical length of text line:
310  // "|--------------------------------------------------------------|"
// Prints the help text through Log(): a short description, a performance-optimisation
// heading with no content under it, and a tuning section that reads "<None>".
// NOTE(review): the short description "Decision Trees and Rule-Based Models" does not mention
// AdaBoost, and "<None>" ignores the five options declared in DeclareOptions; the text looks
// copied from another method.
311  Log() << Endl;
312  Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
313  Log() << Endl;
314  Log() << "Decision Trees and Rule-Based Models " << Endl;
315  Log() << Endl;
316  Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
317  Log() << Endl;
318  Log() << Endl;
319  Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
320  Log() << Endl;
321  Log() << "<None>" << Endl;
322 }
const TString & GetWeightFileDir() const
Definition: MethodBase.h:486
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
#define REGISTER_METHOD(CLASS)
for example
PyObject * fClassifier
Definition: PyMethodBase.h:121
const Event * GetTrainingEvent(Long64_t ievt) const
Definition: DataSet.h:99
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
DataSet * Data() const
Definition: MethodBase.h:405
EAnalysisType
Definition: Types.h:128
Basic string class.
Definition: TString.h:137
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
virtual void TestClassification()
initialization
static void Serialize(TString file, PyObject *classifier)
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not...
Definition: Event.cxx:378
PyArrayObject * fTrainDataClasses
Definition: PyMethodBase.h:125
static int PyIsInitialized()
static void PyInitialize()
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:233
const char * Data() const
Definition: TString.h:349
Tools & gTools()
Definition: Tools.cxx:79
static PyObject * Eval(TString code)
PyArrayObject * fTrainDataWeights
Definition: PyMethodBase.h:124
UInt_t GetNVariables() const
accessor to the number of variables
Definition: Event.cxx:305
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
#define None
Definition: TGWin32.h:59
PyObject * fModule
Definition: PyMethodBase.h:120
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
PyArrayObject * fTrainData
Definition: PyMethodBase.h:123
const Event * GetEvent() const
Definition: DataSet.cxx:211
#define ClassImp(name)
Definition: Rtypes.h:279
double Double_t
Definition: RtypesCore.h:55
int type
Definition: TGX11.cxx:120
MsgLogger & Log() const
Definition: Configurable.h:128
UInt_t GetNVariables() const
access the number of variables through the datasetinfo
Definition: DataSet.cxx:225
you should not use this method at all Int_t Int_t Double_t Double_t Double_t e
Definition: TRolke.cxx:630
UInt_t GetClass() const
Definition: Event.h:89
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:837
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
Abstract ClassifierFactory template that handles arbitrary types.
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
Definition: MethodBase.cxx:590
Long64_t GetNTrainingEvents() const
Definition: DataSet.h:93
double result[121]
virtual void ReadModelFromFile()
static void UnSerialize(TString file, PyObject **obj)
const Bool_t kTRUE
Definition: Rtypes.h:91
MethodPyAdaBoost(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
virtual void TestClassification()
initialization
TRandom3 R
a TMatrixD.
Definition: testIO.cxx:28
_object PyObject
Definition: TPyArg.h:22
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:819
Bool_t IsModelPersistence()
Definition: MethodBase.h:379