Logo ROOT   6.07/09
Reference Guide
PyMethodBase.h
Go to the documentation of this file.
1 // @(#)root/tmva/pymva $Id$
2 // Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : PyMethodBase *
8  * Web : http://oproject.org *
9  * *
10  * Description: *
11  * Virtual base class for all MVA methods based on Python *
12  * *
13  **********************************************************************************/
14 
15 #ifndef ROOT_TMVA_PyMethodBase
16 #define ROOT_TMVA_PyMethodBase
17 
18 ////////////////////////////////////////////////////////////////////////////////
19 // //
20 // PyMethodBase //
21 // //
22 // Virtual base class for all TMVA methods based on Python/scikit-learn //
23 // //
24 ////////////////////////////////////////////////////////////////////////////////
25 
26 #include "TMVA/MethodBase.h"
27 #include "TMVA/Types.h"
28 
29 #include "Rtypes.h"
30 #include "TString.h"
31 
32 class TFile; // forward declarations only: these types are named here without including their headers
33 class TGraph;
34 class TTree;
35 class TDirectory;
36 class TSpline;
37 class TH1F;
38 class TH1D;
39 // Opaque declaration of the Python object type; presumably avoids including Python.h in this header - confirm.
40 struct _object;
41 typedef _object PyObject;
42 
43 // needed by NPY_API_VERSION
44 #include "numpy/numpyconfig.h"
45 
46 #if (NPY_API_VERSION >= 0x00000007 ) // the struct tag behind PyArrayObject depends on the NumPy API version
47 struct tagPyArrayObject;
48 typedef tagPyArrayObject PyArrayObject; // API version >= 0x7: PyArrayObject is a typedef of tagPyArrayObject
49 #else
50 struct PyArrayObject; // older API versions: PyArrayObject is itself the struct name
51 #endif
52 
53 
54 namespace TMVA {
55 // forward declarations of TMVA types named in this header
56  class Ranking;
57  class PDF;
58  class TSpline1;
59  class MethodCuts;
60  class MethodBoost;
61  class DataSetInfo;
62 // Abstract base class (it declares pure virtual members) for TMVA methods implemented through Python.
63  class PyMethodBase : public MethodBase {
64 
65  friend class Factory;
66  public:
67 
68  // main constructor: job name, method type and title, dataset info and an option string (empty by default)
69  PyMethodBase(const TString &jobName,
70  Types::EMVA methodType,
71  const TString &methodTitle,
72  DataSetInfo &dsi,
73  const TString &theOption = "");
74 
75  // constructor used for Testing + Application of the MVA, only (no training),
76  // using given weight file
77  PyMethodBase(Types::EMVA methodType,
78  DataSetInfo &dsi,
79  const TString &weightFile);
80 
81  // destructor (virtual)
82  virtual ~PyMethodBase();
83  // basic python related functions (static); the names mirror the CPython embedding API - presumably thin wrappers, confirm in the implementation
84  static void PyInitialize();
85  static int PyIsInitialized();
86  static void PyFinalize();
87  static void PySetProgramName(TString name);
88  static TString Py_GetProgramName();
89 
90  static PyObject *Eval(TString code);//required to parse booking options from string to pyobjects
91  static void Serialize(TString file,PyObject *classifier); // NOTE(review): presumably pickle-based (see fPickleDumps) - confirm
92  static void UnSerialize(TString file,PyObject** obj); // result is handed back through the obj out-parameter
93  // the pure virtual members below must be implemented by each concrete Python-based method
94  virtual void Train() = 0;
95  // options treatment
96  virtual void Init() = 0;
97  virtual void DeclareOptions() = 0;
98  virtual void ProcessOptions() = 0;
99  // create ranking
100  virtual const Ranking *CreateRanking() = 0;
101 
102  virtual Double_t GetMvaValue(Double_t *errLower = 0, Double_t *errUpper = 0) = 0;
103  // NOTE(review): no `virtual` keyword below; `= 0` is only valid if this overrides a virtual of MethodBase - presumably it does, confirm
104  Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets) = 0;
105  protected:
106  // the actual "weights"
107  virtual void AddWeightsXMLTo(void *parent) const = 0;
108  virtual void ReadWeightsFromXML(void *wghtnode) = 0;
109  virtual void ReadWeightsFromStream(std::istream &) = 0; // backward compatibility
110  virtual void ReadWeightsFromStream(TFile &) {} // backward compatibility; empty inline body, i.e. a no-op unless overridden
111 
112 
113  virtual void ReadModelFromFile() = 0;
114 
115  // signal/background classification response for the current set of data, for events firstEvt..lastEvt
116  virtual std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false);
117 
118  protected:
119 
120  PyObject *fModule;//Module to load
121  PyObject *fClassifier;//Classifier object
122 
123  PyArrayObject *fTrainData; // NOTE(review): presumably the array of training inputs - confirm
124  PyArrayObject *fTrainDataWeights;//array of weights
125  PyArrayObject *fTrainDataClasses;//array with sig/bkg class
126  private:
127 // NOTE(review): this listing jumps from 127 to 129; the cross-reference index of this page lists fModuleBuiltin at PyMethodBase.h:128
129  static PyObject *fEval;//eval function from python
130  static PyObject *fOpen;//open function for files
131  protected:
132  static PyObject *fModulePickle; //Module for model persistence
133  static PyObject *fPickleDumps; //Function to dump PyObject information into string
134  static PyObject *fPickleLoads; //Function to load PyObject information from string
135 
136  static PyObject *fMain;//module __main__ to get local and global namespaces
137  static PyObject *fGlobalNS;//global namespace
138  static PyObject *fLocalNS;//local namespace
139 
140 
141  ClassDef(PyMethodBase, 0) // Virtual base class for all TMVA methods
142 
143  };
144 } // namespace TMVA
145 
146 #endif
147 
148 
PyMethodBase(const TString &jobName, Types::EMVA methodType, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
long long Long64_t
Definition: RtypesCore.h:69
PyObject * fClassifier
Definition: PyMethodBase.h:121
virtual void AddWeightsXMLTo(void *parent) const =0
static PyObject * fModulePickle
Definition: PyMethodBase.h:132
virtual void DeclareOptions()=0
Base class for spline implementation containing the Draw/Paint methods.
Definition: TSpline.h:22
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
Definition: TFile.h:50
EAnalysisType
Definition: Types.h:128
virtual void Init()=0
Basic string class.
Definition: TString.h:137
1-D histogram with a float per channel (see TH1 documentation).
Definition: TH1.h:575
virtual void Train()=0
bool Bool_t
Definition: RtypesCore.h:59
static void Serialize(TString file, PyObject *classifier)
PyArrayObject * fTrainDataClasses
Definition: PyMethodBase.h:125
static int PyIsInitialized()
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
static void PyInitialize()
#define ClassDef(name, id)
Definition: Rtypes.h:254
virtual void ReadModelFromFile()=0
virtual void ReadWeightsFromStream(TFile &)
Definition: PyMethodBase.h:110
static PyObject * fEval
Definition: PyMethodBase.h:129
static PyObject * Eval(TString code)
PyArrayObject * fTrainDataWeights
Definition: PyMethodBase.h:124
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
virtual void ReadWeightsFromStream(std::istream &)=0
PyObject * fModule
Definition: PyMethodBase.h:120
static void PyFinalize()
static TString Py_GetProgramName()
_object PyObject
Definition: PyMethodBase.h:40
virtual const Ranking * CreateRanking()=0
static void PySetProgramName(TString name)
unsigned int UInt_t
Definition: RtypesCore.h:42
PyArrayObject * fTrainData
Definition: PyMethodBase.h:123
static PyObject * fOpen
Definition: PyMethodBase.h:130
1-D histogram with a double per channel (see TH1 documentation).
Definition: TH1.h:618
virtual void ReadWeightsFromXML(void *wghtnode)=0
static PyObject * fLocalNS
Definition: PyMethodBase.h:138
static PyObject * fModuleBuiltin
Definition: PyMethodBase.h:128
double Double_t
Definition: RtypesCore.h:55
Describe directory structure in memory.
Definition: TDirectory.h:44
virtual Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)=0
int type
Definition: TGX11.cxx:120
virtual void ProcessOptions()=0
Abstract ClassifierFactory template that handles arbitrary types.
static PyObject * fPickleLoads
Definition: PyMethodBase.h:134
static PyObject * fPickleDumps
Definition: PyMethodBase.h:133
Definition: file.py:1
virtual ~PyMethodBase()
A Graph is a graphics object made of two arrays X and Y with npoints each.
Definition: TGraph.h:53
A TTree object has a header with a name and a title.
Definition: TTree.h:98
static PyObject * fMain
Definition: PyMethodBase.h:136
static void UnSerialize(TString file, PyObject **obj)
char name[80]
Definition: TGX11.cxx:109
static PyObject * fGlobalNS
Definition: PyMethodBase.h:137
_object PyObject
Definition: TPyArg.h:22