Logo ROOT   6.07/09
Reference Guide
PyMethodBase.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/pymva $Id$
2 // Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : PyMethodBase *
8  * *
9  * Description: *
10  * Virtual base class for all MVA methods based on Python *
11  * *
12  **********************************************************************************/
13 #include <Python.h> // Needs to be included first to avoid redefinition of _POSIX_C_SOURCE
14 #include<TMVA/PyMethodBase.h>
15 
16 #pragma GCC diagnostic ignored "-Wunused-parameter"
17 #pragma GCC diagnostic ignored "-Wunused-function"
18 
19 #include "TMVA/DataSet.h"
20 #include "TMVA/DataSetInfo.h"
21 #include "TMVA/MsgLogger.h"
22 #include "TMVA/Results.h"
23 #include "TMVA/Timer.h"
24 
25 #include<TApplication.h>
26 
27 #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
28 #include <numpy/arrayobject.h>
29 
30 #include <fstream>
31 #include <wchar.h>
32 
33 using namespace TMVA;
34 
36 
37 PyObject *PyMethodBase::fModuleBuiltin = NULL;
38 PyObject *PyMethodBase::fEval = NULL;
39 PyObject *PyMethodBase::fOpen = NULL;
40 
41 PyObject *PyMethodBase::fModulePickle = NULL;
42 PyObject *PyMethodBase::fPickleDumps = NULL;
43 PyObject *PyMethodBase::fPickleLoads = NULL;
44 
45 PyObject *PyMethodBase::fMain = NULL;
46 PyObject *PyMethodBase::fGlobalNS = NULL;
47 PyObject *PyMethodBase::fLocalNS = NULL;
48 
49 class PyGILRAII {
50  PyGILState_STATE m_GILState;
51 public:
52  PyGILRAII():m_GILState(PyGILState_Ensure()){}
53  ~PyGILRAII(){PyGILState_Release(m_GILState);}
54 };
55 
56 //_______________________________________________________________________
58  Types::EMVA methodType,
59  const TString &methodTitle,
60  DataSetInfo &dsi,
61  const TString &theOption ): MethodBase(jobName, methodType, methodTitle, dsi, theOption),
62  fClassifier(NULL)
63 {
64  if (!PyIsInitialized()) {
65  PyInitialize();
66  }
67 }
68 
69 //_______________________________________________________________________
71  DataSetInfo &dsi,
72  const TString &weightFile): MethodBase(methodType, dsi, weightFile),
74 {
75  if (!PyIsInitialized()) {
76  PyInitialize();
77  }
78 }
79 
80 //_______________________________________________________________________
82 {
83 }
84 
85 //_______________________________________________________________________
87 {
    // Calls the cached builtin eval (fEval) on the text in `code`, passing fGlobalNS and
    // fLocalNS as the second and third arguments, and returns the object produced.
    // NOTE(review): listing line 88 is missing from this view; whatever it does
    // (possibly an initialization check) is not documented here.
    // NOTE(review): neither Py_BuildValue() nor PyObject_CallObject() is checked for a
    // NULL result; per the CPython docs both return NULL on failure, in which case the
    // Py_DECREF below or the caller would dereference NULL -- verify.
89  PyObject *pycode = Py_BuildValue("(sOO)", code.Data(), fGlobalNS, fLocalNS);
90  PyObject *result = PyObject_CallObject(fEval, pycode);
    // The argument tuple is only needed for the call itself.
91  Py_DECREF(pycode);
92  return result;
93 }
94 
95 //_______________________________________________________________________
97 {
    // Starts the Python interpreter if needed and fills the static handles used by the
    // rest of this file: fMain, fGlobalNS, fLocalNS, fModuleBuiltin, fEval, fOpen,
    // fModulePickle, fPickleDumps and fPickleLoads.
    // NOTE(review): listing line 98 is missing from this view; it presumably declares the
    // `Log` object used below -- confirm against the real file.
    // NOTE(review): as far as this listing shows, execution continues after each kFATAL
    // message; whether kFATAL aborts is not visible here.
99 
    // PyIsInitialized() is this class's own check (interpreter plus cached handles);
    // Py_Initialize() is the CPython start-up call.
100  bool pyIsInitialized = PyIsInitialized();
101  if (!pyIsInitialized) {
102  Py_Initialize();
103  }
104 
    // Hold the GIL for the remainder of this scope.
105  PyGILRAII thePyGILRAII;
106 
    // Import numpy's C API only when the check above reported "not initialized".
107  if (!pyIsInitialized) {
108  _import_array();
109  }
110 
111  fMain = PyImport_AddModule("__main__");
112  if (!fMain) {
113  Log << kFATAL << "Can't import __main__" << Endl;
114  Log << Endl;
115  }
116 
117  fGlobalNS = PyModule_GetDict(fMain);
118  if (!fGlobalNS) {
119  Log << kFATAL << "Can't init global namespace" << Endl;
120  Log << Endl;
121  }
122 
123  fLocalNS = PyDict_New();
    // NOTE(review): this re-tests fMain, so a failed PyDict_New() goes unnoticed; the
    // message suggests fLocalNS was meant.
124  if (!fMain) {
125  Log << kFATAL << "Can't init local namespace" << Endl;
126  Log << Endl;
127  }
128 
    // The builtins module is called "__builtin__" before Python 3 and "builtins" from
    // Python 3 on; both branches are otherwise identical.
129  #if PY_MAJOR_VERSION < 3
130  //preparing objects for eval
131  PyObject *bName = PyUnicode_FromString("__builtin__");
132  // Import the file as a Python module.
133  fModuleBuiltin = PyImport_Import(bName);
134  if (!fModuleBuiltin) {
135  Log << kFATAL << "Can't import __builtin__" << Endl;
136  Log << Endl;
137  }
138  #else
139  //preparing objects for eval
140  PyObject *bName = PyUnicode_FromString("builtins");
141  // Import the file as a Python module.
142  fModuleBuiltin = PyImport_Import(bName);
143  if (!fModuleBuiltin) {
144  Log << kFATAL << "Can't import builtins" << Endl;
145  Log << Endl;
146  }
147  #endif
148 
    // Cache the builtin eval and open callables.
149  PyObject *mDict = PyModule_GetDict(fModuleBuiltin);
150  fEval = PyDict_GetItemString(mDict, "eval");
151  fOpen = PyDict_GetItemString(mDict, "open");
152 
153  Py_DECREF(bName);
    // NOTE(review): per the CPython docs PyModule_GetDict() returns a borrowed reference,
    // so this Py_DECREF drops a reference this code never acquired -- verify.
154  Py_DECREF(mDict);
155  //preparing objects for pickle
156  PyObject *pName = PyUnicode_FromString("pickle");
157  // Import the file as a Python module.
158  fModulePickle = PyImport_Import(pName);
159  if (!fModulePickle) {
160  Log << kFATAL << "Can't import pickle" << Endl;
161  Log << Endl;
162  }
    // Despite their names, fPickleDumps/fPickleLoads are bound to pickle's "dump" and
    // "load" entries.
163  PyObject *pDict = PyModule_GetDict(fModulePickle);
164  fPickleDumps = PyDict_GetItemString(pDict, "dump");
165  fPickleLoads = PyDict_GetItemString(pDict, "load");
166 
167  Py_DECREF(pName);
    // NOTE(review): same borrowed-reference concern as for mDict above.
168  Py_DECREF(pDict);
169 
170 
171 }
172 
173 //_______________________________________________________________________
175 {
    // Shuts the interpreter down and then drops one reference on each non-null cached
    // handle (fEval, fModuleBuiltin, fPickleDumps, fPickleLoads, fMain).
    // NOTE(review): the Py_DECREF calls run after Py_Finalize(); touching Python objects
    // once the interpreter is finalized looks unsafe -- the order probably needs to be
    // reversed. Confirm against the CPython finalization docs.
    // NOTE(review): the static handles are not reset to NULL afterwards, and fOpen,
    // fModulePickle and fLocalNS are not released here.
176  Py_Finalize();
177  if (fEval) Py_DECREF(fEval);
178  if (fModuleBuiltin) Py_DECREF(fModuleBuiltin);
179  if (fPickleDumps) Py_DECREF(fPickleDumps);
180  if (fPickleLoads) Py_DECREF(fPickleLoads);
181  if(fMain) Py_DECREF(fMain);//objects fGlobalNS and fLocalNS will be free here
182 }
184 {
    // Forwards `name` to CPython's Py_SetProgramName(), which takes char* before
    // Python 3 and wchar_t* from Python 3 on.
    // NOTE(review): the Python 3 branch only casts the narrow character buffer to
    // wchar_t*; no character conversion takes place, so the interpreter would read the
    // bytes as wide characters -- verify, a real conversion is probably required.
    // NOTE(review): the CPython docs ask for storage that stays valid for the life of
    // the program; if `name` is a by-value TString its buffer does not outlive this
    // call -- confirm.
185  #if PY_MAJOR_VERSION < 3
186  Py_SetProgramName(const_cast<char*>(name.Data()));
187  #else
188  Py_SetProgramName((wchar_t *)name.Data());
189  #endif
190 }
191 
/// Length of a NUL-terminated narrow string, not counting the terminator.
/// A null pointer is treated as an empty string and yields 0 instead of being
/// passed on to strlen(), where it would be undefined behaviour.
size_t mystrlen(const char *s) { return s ? strlen(s) : 0; }

/// Length of a NUL-terminated wide string in wchar_t units, not counting the
/// terminator. A null pointer is treated as an empty string and yields 0.
size_t mystrlen(const wchar_t *s) { return s ? wcslen(s) : 0; }
194 
195 //_______________________________________________________________________
197 {
    // Fetches the program name from the global-scope ::Py_GetProgramName() and copies
    // the characters [progName, progName + length) into a std::string.
    // mystrlen() has char and wchar_t overloads, so the length lookup follows whichever
    // character type the Python API returns.
    // NOTE(review): if that type is wchar_t, the iterator-range constructor narrows each
    // element to char -- presumably fine for ASCII names only; confirm.
198 auto progName = ::Py_GetProgramName();
199 return std::string(progName, progName + mystrlen(progName));
200 }
201 //_______________________________________________________________________
203 {
    // Returns kTRUE only when the interpreter reports itself initialized and the cached
    // eval, builtins-module and pickle dump/load handles are all non-null; kFALSE
    // otherwise.
    // NOTE(review): fOpen, fModulePickle, fMain and the namespace handles are not part
    // of this check.
204  if (!Py_IsInitialized()) return kFALSE;
205  if (!fEval) return kFALSE;
206  if (!fModuleBuiltin) return kFALSE;
207  if (!fPickleDumps) return kFALSE;
208  if (!fPickleLoads) return kFALSE;
209  return kTRUE;
210 }
211 
213 {
    // Writes `obj` to the file at `path`: opens the path with the cached builtin open in
    // "wb" mode, then calls fPickleDumps (bound to pickle's "dump") with (obj, file).
    // NOTE(review): none of the four Python calls is checked for a NULL result; if the
    // open fails, `file` is NULL and the calls and Py_DECREFs that follow would operate
    // on NULL -- verify and add error handling.
    // NOTE(review): the file object is never closed explicitly; presumably it is closed
    // when its last reference is dropped below -- confirm.
214  if(!PyIsInitialized()) PyInitialize();
215  PyObject *file_arg = Py_BuildValue("(ss)", path.Data(),"wb");
216  PyObject *file = PyObject_CallObject(fOpen,file_arg);
217  PyObject *model_arg = Py_BuildValue("(OO)", obj,file);
218  PyObject *model_data = PyObject_CallObject(fPickleDumps , model_arg);
219 
    // Release every temporary created above.
220  Py_DECREF(file_arg);
221  Py_DECREF(file);
222  Py_DECREF(model_arg);
223  Py_DECREF(model_data);
224 }
225 
227 {
    // Reads an object back from the file at `path`: opens the path with the cached
    // builtin open in "rb" mode, calls fPickleLoads (bound to pickle's "load") with the
    // file object and stores the result in *obj.
    // NOTE(review): unlike the "wb" counterpart in this file, there is no
    // PyIsInitialized() check here, so fOpen/fPickleLoads must already be set.
    // NOTE(review): no NULL checks; a failed open leaves `file` NULL for the following
    // calls, and *obj receives NULL if the load call fails -- callers should test it.
228  PyObject *file_arg = Py_BuildValue("(ss)", path.Data(),"rb");
229  PyObject *file = PyObject_CallObject(fOpen,file_arg);
230 
231  PyObject *model_arg = Py_BuildValue("(O)", file);
232  *obj = PyObject_CallObject(fPickleLoads , model_arg);
233 
    // Release the temporaries; the loaded object itself is handed to the caller.
234  Py_DECREF(file_arg);
235  Py_DECREF(file);
236  Py_DECREF(model_arg);
237 }
238 
239 
240 ////////////////////////////////////////////////////////////////////////////////
241 /// get all the MVA values for the events of the current Data type
///
/// Copies the input variables of the selected events into a 2-D float numpy array,
/// calls `predict_proba` on fClassifier once for the whole batch and returns one value
/// per event, read from every second entry of the returned buffer.
///
/// \param firstEvt    first event index requested; negative values are raised to 0
/// \param lastEvt     end of the requested range; replaced by nEvents when it is
///                    smaller than firstEvt or larger than nEvents
/// \param logProgress when true, a kINFO line is logged before and after the evaluation
/// \return vector with (lastEvt - firstEvt) entries, after the clamping described above
///
/// NOTE(review): listing lines 245 and 247 are missing from this view; the declaration
/// of nEvents is among what is not shown.
242 std::vector<Double_t> PyMethodBase::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
243 {
244 
246 
    // Clamp the requested range.
248  if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
249  if (firstEvt < 0) firstEvt = 0;
    // `values` is only used to derive the number of events to process; the vector
    // returned at the end is mvaValues.
250  std::vector<Double_t> values(lastEvt-firstEvt);
251 
252  nEvents = values.size();
253 
254  UInt_t nvars = Data()->GetNVariables();
255 
    // Allocate an (nEvents x nvars) float array and get a raw pointer to its storage.
    // NOTE(review): PyArray_FromDims is, to my knowledge, a deprecated numpy entry point
    // and its result is not checked for NULL -- confirm against the numpy C API docs.
256  int dims[2];
257  dims[0] = nEvents;
258  dims[1] = nvars;
259  PyArrayObject *pEvent= (PyArrayObject *)PyArray_FromDims(2, dims, NPY_FLOAT);
260  float *pValue = (float *)(PyArray_DATA(pEvent));
261 
262 // int dims2[2];
263 // dims2[0] = 1;
264 // dims2[1] = nvars;
265 
266  // use timer
267  Timer timer( nEvents, GetName(), kTRUE );
268  if (logProgress)
269  Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())<< "Evaluation of " << GetMethodName() << " on "
270  << (Data()->GetCurrentType()==Types::kTraining?"training":"testing") << " sample (" << nEvents << " events)" << Endl;
271 
272 
273  // fill numpy array with events data
    // Row ievt of the array holds the nvars input values of one event.
    // NOTE(review): SetCurrentEvent() receives the loop index without firstEvt added, so
    // for firstEvt > 0 this appears to read events [0, nEvents) rather than
    // [firstEvt, lastEvt) -- verify the intended behaviour.
274  for (Int_t ievt=0; ievt<nEvents; ievt++) {
275  Data()->SetCurrentEvent(ievt);
276  const TMVA::Event *e = Data()->GetEvent();
277  assert(nvars == e->GetNVariables());
278  for (UInt_t i = 0; i < nvars; i++) {
279  pValue[ievt * nvars + i] = e->GetValue(i);
280  }
281  // if (ievt%100 == 0)
282  // std::cout << "Event " << ievt << " type" << DataInfo().IsSignal(e) << " : " << pValue[ievt*nvars] << " " << pValue[ievt*nvars+1] << " " << pValue[ievt*nvars+2] << std::endl;
283  }
284 
285  // pass all the events to Scikit and evaluate the probabilities
    // NOTE(review): the call result is not checked for NULL before PyArray_DATA(), and
    // the cast assumes the returned array stores contiguous doubles -- confirm.
286  PyArrayObject *result = (PyArrayObject *)PyObject_CallMethod(fClassifier, const_cast<char *>("predict_proba"), const_cast<char *>("(O)"), pEvent);
287  double *proba = (double *)(PyArray_DATA(result));
288 
289  // the return probabilities is a vector of pairs of (p_sig,p_backg)
290  // we are interested only in the signal probability
    // Hence the stride of 2: entry 2*i is the first value of the i-th pair.
291  std::vector<double> mvaValues(nEvents);
292  for (int i = 0; i < nEvents; ++i)
293  mvaValues[i] = proba[2*i];
294 
295  if (logProgress) {
296  Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Elapsed time for evaluation of " << nEvents << " events: "
297  << timer.GetElapsedTime() << " " << Endl;
298  }
299 
    // Drop the references to the two numpy arrays created in this function.
300  Py_DECREF(result);
301  Py_DECREF(pEvent);
302 
303  return mvaValues;
304 }
PyMethodBase(const TString &jobName, Types::EMVA methodType, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
long long Long64_t
Definition: RtypesCore.h:69
PyObject * fClassifier
Definition: PyMethodBase.h:121
static PyObject * fModulePickle
Definition: PyMethodBase.h:132
const char * GetName() const
Definition: MethodBase.h:330
DataSet * Data() const
Definition: MethodBase.h:405
Basic string class.
Definition: TString.h:137
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
static void Serialize(TString file, PyObject *classifier)
TString GetElapsedTime(Bool_t Scientific=kTRUE)
Definition: Timer.cxx:129
static int PyIsInitialized()
static void PyInitialize()
const TString & GetMethodName() const
Definition: MethodBase.h:327
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:233
const char * Data() const
Definition: TString.h:349
TStopwatch timer
Definition: pirndm.C:37
virtual void ReadModelFromFile()=0
static PyObject * fEval
Definition: PyMethodBase.h:129
static PyObject * Eval(TString code)
void SetCurrentEvent(Long64_t ievt) const
Definition: DataSet.h:113
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
UInt_t GetNVariables() const
accessor to the number of variables
Definition: Event.cxx:305
const int nEvents
Definition: testRooFit.cxx:42
static void PyFinalize()
static TString Py_GetProgramName()
static void PySetProgramName(TString name)
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
Types::ETreeType GetCurrentType() const
Definition: DataSet.h:217
static PyObject * fOpen
Definition: PyMethodBase.h:130
size_t mystrlen(const char *s)
const Event * GetEvent() const
Definition: DataSet.cxx:211
static PyObject * fLocalNS
Definition: PyMethodBase.h:138
static PyObject * fModuleBuiltin
Definition: PyMethodBase.h:128
#define ClassImp(name)
Definition: Rtypes.h:279
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:229
MsgLogger & Log() const
Definition: Configurable.h:128
UInt_t GetNVariables() const
access the number of variables through the datasetinfo
Definition: DataSet.cxx:225
you should not use this method at all Int_t Int_t Double_t Double_t Double_t e
Definition: TRolke.cxx:630
DataSetInfo & DataInfo() const
Definition: MethodBase.h:406
Abstract ClassifierFactory template that handles arbitrary types.
static PyObject * fPickleLoads
Definition: PyMethodBase.h:134
static PyObject * fPickleDumps
Definition: PyMethodBase.h:133
Definition: file.py:1
virtual ~PyMethodBase()
#define NULL
Definition: Rtypes.h:82
double result[121]
static PyObject * fMain
Definition: PyMethodBase.h:136
static void UnSerialize(TString file, PyObject **obj)
const Bool_t kTRUE
Definition: Rtypes.h:91
char name[80]
Definition: TGX11.cxx:109
static PyObject * fGlobalNS
Definition: PyMethodBase.h:137
_object PyObject
Definition: TPyArg.h:22