MethodPyTorch.cxx
1 // @(#)root/tmva/pymva $Id$
2 // Author: Anirudh Dagar, 2020
3 
4 #include <Python.h>
5 #include "TMVA/MethodPyTorch.h"
6 
7 #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
8 #include <numpy/arrayobject.h>
9 
10 #include "TMVA/Types.h"
11 #include "TMVA/Config.h"
12 #include "TMVA/ClassifierFactory.h"
13 #include "TMVA/Results.h"
14 #include "TMVA/TransformationHandler.h"
15 #include "TMVA/VariableTransformBase.h"
16 #include "TMVA/Tools.h"
17 #include "TMVA/Timer.h"
18 
19 using namespace TMVA;
20 
21 namespace TMVA {
22 namespace Internal {
23 class PyGILRAII {
24  PyGILState_STATE m_GILState;
25 
26 public:
27  PyGILRAII() : m_GILState(PyGILState_Ensure()) {}
28  ~PyGILRAII() { PyGILState_Release(m_GILState); }
29 };
30 } // namespace Internal
31 } // namespace TMVA
32 
33 REGISTER_METHOD(PyTorch)
34 
35 ClassImp(MethodPyTorch);
36 
37 
38 MethodPyTorch::MethodPyTorch(const TString &jobName, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption)
39  : PyMethodBase(jobName, Types::kPyTorch, methodTitle, dsi, theOption) {
40  fNumEpochs = 10;
41  fBatchSize = 100;
42 
43  fContinueTraining = false;
44  fSaveBestOnly = true;
45  fLearningRateSchedule = ""; // empty string deactivates learning rate scheduler
46  fFilenameTrainedModel = ""; // empty string sets output model filename to default (in "weights/" directory.)
47 }
48 
49 
50 MethodPyTorch::MethodPyTorch(DataSetInfo &theData, const TString &theWeightFile)
51  : PyMethodBase(Types::kPyTorch, theData, theWeightFile) {
52  fNumEpochs = 10;
53  fBatchSize = 100;
54 
55  fContinueTraining = false;
56  fSaveBestOnly = true;
57  fLearningRateSchedule = ""; // empty string deactivates learning rate scheduler
58  fFilenameTrainedModel = ""; // empty string sets output model filename to default (in "weights/" directory.)
59 }
60 
61 
62 MethodPyTorch::~MethodPyTorch() {
63 }
64 
65 
66 Bool_t MethodPyTorch::HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t) {
67  if (type == Types::kRegression) return kTRUE;
68  if (type == Types::kClassification && numberClasses == 2) return kTRUE;
69  if (type == Types::kMulticlass && numberClasses >= 2) return kTRUE;
70  return kFALSE;
71 }
72 
73 
74 void MethodPyTorch::DeclareOptions() {
75  DeclareOptionRef(fFilenameModel, "FilenameModel", "Filename of the initial PyTorch model");
76  DeclareOptionRef(fFilenameTrainedModel, "FilenameTrainedModel", "Filename of the trained output PyTorch model");
77  DeclareOptionRef(fBatchSize, "BatchSize", "Training batch size");
78  DeclareOptionRef(fNumEpochs, "NumEpochs", "Number of training epochs");
79 
80  DeclareOptionRef(fContinueTraining, "ContinueTraining", "Load weights from previous training");
81  DeclareOptionRef(fSaveBestOnly, "SaveBestOnly", "Store only weights with smallest validation loss");
82  DeclareOptionRef(fLearningRateSchedule, "LearningRateSchedule", "Set new learning rate during training at specific epochs, e.g., \"50,0.01;70,0.005\"");
83 
84  DeclareOptionRef(fNumValidationString = "20%", "ValidationSize", "Part of the training data to use for validation. "
85  "Specify as 0.2 or 20% to use a fifth of the data set as validation set. "
86  "Specify as 100 to use exactly 100 events. (Default: 20%)");
87  DeclareOptionRef(fUserCodeName = "", "UserCode", "Necessary python code provided by the user to be executed before loading and training the PyTorch Model");
88 
89 }
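For orientation, these options are normally passed as a colon-separated string when the method is booked from a TMVA Factory. The following PyROOT sketch is illustrative only: the file names, option values and the abbreviated factory/dataloader setup are assumptions, not part of this file.

import ROOT
from ROOT import TMVA

TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

output = ROOT.TFile.Open("TMVA_PyTorch.root", "RECREATE")
factory = TMVA.Factory("TMVAClassification", output,
                       "!V:!Silent:Color:AnalysisType=Classification")
dataloader = TMVA.DataLoader("dataset")
# ... declare variables and add signal/background trees to the dataloader here ...

factory.BookMethod(dataloader, TMVA.Types.kPyTorch, "PyTorch",
                   "H:!V:FilenameModel=PyTorchModel.pt:"
                   "FilenameTrainedModel=TrainedPyTorchModel.pt:"
                   "NumEpochs=20:BatchSize=32:"
                   "UserCode=generate_model.py")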
90 
91 
92 ////////////////////////////////////////////////////////////////////////////////
93 /// Validation of the ValidationSize option. Allowed formats are 20%, 0.2 and
94 /// 100 etc.
95 /// - 20% and 0.2 selects 20% of the training set as validation data.
96 /// - 100 selects 100 events as the validation data.
97 ///
98 /// @return number of samples in validation set
99 ///
100 UInt_t MethodPyTorch::GetNumValidationSamples()
101 {
102  Int_t nValidationSamples = 0;
103  UInt_t trainingSetSize = GetEventCollection(Types::kTraining).size();
104 
105  // Parsing + Validation
106  // --------------------
107  if (fNumValidationString.EndsWith("%")) {
108  // Relative spec. format 20%
109  TString intValStr = TString(fNumValidationString.Strip(TString::kTrailing, '%'));
110 
111  if (intValStr.IsFloat()) {
112  Double_t valSizeAsDouble = fNumValidationString.Atof() / 100.0;
113  nValidationSamples = GetEventCollection(Types::kTraining).size() * valSizeAsDouble;
114  } else {
115  Log() << kFATAL << "Cannot parse number \"" << fNumValidationString
116  << "\". Expected string like \"20%\" or \"20.0%\"." << Endl;
117  }
118  } else if (fNumValidationString.IsFloat()) {
119  Double_t valSizeAsDouble = fNumValidationString.Atof();
120 
121  if (valSizeAsDouble < 1.0) {
122  // Relative spec. format 0.2
123  nValidationSamples = GetEventCollection(Types::kTraining).size() * valSizeAsDouble;
124  } else {
125  // Absolute spec format 100 or 100.0
126  nValidationSamples = valSizeAsDouble;
127  }
128  } else {
129  Log() << kFATAL << "Cannot parse number \"" << fNumValidationString << "\". Expected string like \"0.2\" or \"100\"."
130  << Endl;
131  }
132 
133  // Value validation
134  // ----------------
135  if (nValidationSamples < 0) {
136  Log() << kFATAL << "Validation size \"" << fNumValidationString << "\" is negative." << Endl;
137  }
138 
139  if (nValidationSamples == 0) {
140  Log() << kFATAL << "Validation size \"" << fNumValidationString << "\" is zero." << Endl;
141  }
142 
143  if (nValidationSamples >= (Int_t)trainingSetSize) {
144  Log() << kFATAL << "Validation size \"" << fNumValidationString
145  << "\" is larger than or equal in size to training set (size=\"" << trainingSetSize << "\")." << Endl;
146  }
147 
148  return nValidationSamples;
149 }
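As a quick illustration of the accepted formats, here is a small standalone Python sketch (a hypothetical helper, not part of TMVA) that mirrors the parsing rules above for a training set of 1000 events.

def num_validation_samples(spec, n_train):
    # "20%" -> fraction of the training set, "0.2" -> fraction, "100" -> absolute number of events
    if spec.endswith('%'):
        return int(n_train * float(spec[:-1]) / 100.0)
    value = float(spec)
    return int(n_train * value) if value < 1.0 else int(value)

print(num_validation_samples("20%", 1000))   # 200
print(num_validation_samples("0.2", 1000))   # 200
print(num_validation_samples("100", 1000))   # 100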
150 
151 
152 void MethodPyTorch::ProcessOptions() {
153  // Set default filename for trained model if option is not used
154  if (fFilenameTrainedModel.IsNull()) {
155  fFilenameTrainedModel = GetWeightFileDir() + "/TrainedModel_" + GetName() + ".pt";
156  }
157 
158  // - set up number of threads for CPU if NumThreads option was specified
159  // `torch.set_num_threads` sets the number of threads that can be used to
160  // perform cpu operations like conv or mm (usually used by OpenMP or MKL).
161 
162  Log() << kINFO << "Using PyTorch - setting special configuration options " << Endl;
163  PyRunString("import torch", "Error importing pytorch");
164 
165  // also run the import above in the global namespace so it is visible everywhere
166  PyRun_String("import torch", Py_single_input, fGlobalNS, fGlobalNS);
167 
168  // check pytorch version
169  PyRunString("torch_major_version = int(torch.__version__.split('.')[0])");
170  PyObject *pyTorchVersion = PyDict_GetItemString(fLocalNS, "torch_major_version");
171  int torchVersion = PyLong_AsLong(pyTorchVersion);
172  Log() << kINFO << "Using PyTorch version " << torchVersion << Endl;
173 
174  // set the number of CPU threads if the NumThreads option was specified
175  int num_threads = fNumThreads;
176  if (num_threads > 0) {
177  Log() << kINFO << "Setting the CPU number of threads = " << num_threads << Endl;
178 
179  PyRunString(TString::Format("torch.set_num_threads(%d)", num_threads));
180  PyRunString(TString::Format("torch.set_num_interop_threads(%d)", num_threads));
181  }
182 
183  // Setup model, either the initial model from `fFilenameModel` or
184  // the trained model from `fFilenameTrainedModel`
185  if (fContinueTraining) Log() << kINFO << "Continue training with trained model" << Endl;
186  SetupPyTorchModel(fContinueTraining);
187 }
188 
189 
190 void MethodPyTorch::SetupPyTorchModel(bool loadTrainedModel) {
191  /*
192  * Load PyTorch model from file
193  */
194 
195  Log() << kINFO << " Setup PyTorch Model " << Endl;
196 
197  if (!fUserCodeName.IsNull()) {
198  Log() << kINFO << " Executing user initialization code from " << fUserCodeName << Endl;
199 
200  // run some python code provided by user for method initializations
201  FILE* fp = fopen(fUserCodeName, "r");
202  if (!fp) Log() << kFATAL << "Failed to open user code file: " << fUserCodeName << Endl;
203  PyRun_SimpleFile(fp, fUserCodeName);
204  fclose(fp);
205  }
206 
207  PyRunString("print('custom objects for loading model : ',load_model_custom_objects)");
208 
209  // Setup the training method
210  PyRunString("fit = load_model_custom_objects[\"train_func\"]",
211  "Failed to load train function from file. Please use key: 'train_func' and pass training loop function as the value.");
212  Log() << kINFO << "Loaded pytorch train function: " << Endl;
213 
214 
215  // Setup Optimizer. Use SGD Optimizer as Default
216  PyRunString("if 'optimizer' in load_model_custom_objects:\n"
217  " optimizer = load_model_custom_objects['optimizer']\n"
218  "else:\n"
219  " optimizer = torch.optim.SGD\n",
220  "Please use key: 'optimizer' and pass a pytorch optimizer as the value for a custom optimizer.");
221  Log() << kINFO << "Loaded pytorch optimizer: " << Endl;
222 
223 
224  // Setup the loss criterion
225  PyRunString("criterion = load_model_custom_objects[\"criterion\"]",
226  "Failed to load loss function from file. Please use key: 'criterion' and pass a pytorch loss function as the value.");
227  Log() << kINFO << "Loaded pytorch loss function: " << Endl;
228 
229 
230  // Setup the predict method
231  PyRunString("predict = load_model_custom_objects[\"predict_func\"]",
232  "Can't find user predict function object from file. Please use key: 'predict_func' and pass a predict function for evaluating the model as the value.");
233  Log() << kINFO << "Loaded pytorch predict function: " << Endl;
234 
235 
236  // Load already trained model or initial model
237  TString filenameLoadModel;
238  if (loadTrainedModel) {
239  filenameLoadModel = fFilenameTrainedModel;
240  }
241  else {
242  filenameLoadModel = fFilenameModel;
243  }
244  PyRunString("model = torch.jit.load('"+filenameLoadModel+"')",
245  "Failed to load PyTorch model from file: "+filenameLoadModel);
246  Log() << kINFO << "Load model from file: " << filenameLoadModel << Endl;
247 
248 
249  /*
250  * Init variables and weights
251  */
252 
253  // Get variables, classes and target numbers
254  fNVars = GetNVariables();
255  if (GetAnalysisType() == Types::kClassification || GetAnalysisType() == Types::kMulticlass) fNOutputs = DataInfo().GetNClasses();
256  else if (GetAnalysisType() == Types::kRegression) fNOutputs = DataInfo().GetNTargets();
257  else Log() << kFATAL << "Selected analysis type is not implemented" << Endl;
258 
259  // Init evaluation (needed for getMvaValue)
260  fVals = new float[fNVars]; // holds values used for classification and regression
261  npy_intp dimsVals[2] = {(npy_intp)1, (npy_intp)fNVars};
262  PyArrayObject* pVals = (PyArrayObject*)PyArray_SimpleNewFromData(2, dimsVals, NPY_FLOAT, (void*)fVals);
263  PyDict_SetItemString(fLocalNS, "vals", (PyObject*)pVals);
264 
265  fOutput.resize(fNOutputs); // holds classification probabilities or regression output
266  npy_intp dimsOutput[2] = {(npy_intp)1, (npy_intp)fNOutputs};
267  PyArrayObject* pOutput = (PyArrayObject*)PyArray_SimpleNewFromData(2, dimsOutput, NPY_FLOAT, (void*)&fOutput[0]);
268  PyDict_SetItemString(fLocalNS, "output", (PyObject*)pOutput);
269 
270  // Mark the model as setup
271  fModelIsSetup = true;
272 }
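The dictionary load_model_custom_objects referenced above must be provided by the user, typically in the script passed via the UserCode option, together with a TorchScript model saved under FilenameModel. The following is a minimal, hypothetical sketch of such a script: the network architecture, the file name PyTorchModel.pt and the optimizer/criterion choices are illustrative assumptions, while the function signatures follow the calls made by Train() below.

import torch
from torch import nn

def fit(model, train_loader, val_loader, num_epochs, batch_size,
        optimizer, criterion, save_best, scheduler):
    # batch_size is already encoded in the data loaders; it is passed for convenience
    trainer = optimizer(model.parameters(), lr=0.01)
    schedule, schedulerSteps = scheduler
    best_val = None
    for epoch in range(num_epochs):
        if schedule:
            schedule(trainer, epoch, schedulerSteps)
        model.train()
        for X, y in train_loader:
            trainer.zero_grad()
            loss = criterion(model(X), y)
            loss.backward()
            trainer.step()
        model.eval()
        with torch.no_grad():
            val_loss = sum(criterion(model(X), y).item() for X, y in val_loader) / len(val_loader)
        if save_best:
            best_val = val_loss if best_val is None else best_val
            best_val = save_best(model, val_loss, best_val)
    return model

def predict(model, test_X):
    # evaluate the (TorchScript) model on a numpy array and return a numpy array of outputs
    model.eval()
    with torch.no_grad():
        return model(torch.Tensor(test_X)).numpy()

load_model_custom_objects = {"optimizer": torch.optim.SGD,
                             "criterion": nn.MSELoss(),
                             "train_func": fit,
                             "predict_func": predict}

# An initial TorchScript model must also be written to the file given by FilenameModel,
# e.g. for four input variables and two output classes:
model = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2), nn.Softmax(dim=1))
torch.jit.script(model).save("PyTorchModel.pt")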
273 
274 
275 void MethodPyTorch::Init() {
276 
277  TMVA::Internal::PyGILRAII raii;
278 
279  if (!PyIsInitialized()) {
280  Log() << kFATAL << "Python is not initialized" << Endl;
281  }
282  _import_array(); // required to use numpy arrays
283 
284  // Import PyTorch
285  PyRunString("import sys; sys.argv = ['']", "Set sys.argv failed");
286  PyRunString("import torch", "import PyTorch failed");
287  // do import also in global namespace
288  auto ret = PyRun_String("import torch", Py_single_input, fGlobalNS, fGlobalNS);
289  if (!ret)
290  Log() << kFATAL << "import torch in global namespace failed!" << Endl;
291 
292  // Set flag that model is not setup
293  fModelIsSetup = false;
294 }
295 
296 
297 void MethodPyTorch::Train() {
298  if(!fModelIsSetup) Log() << kFATAL << "Model is not setup for training" << Endl;
299 
300  /*
301  * Load training data to numpy array.
302  * NOTE: These are later converted into torch tensors inside the training loop, which may not be the ideal approach.
303  */
304 
305  UInt_t nAllEvents = Data()->GetNTrainingEvents();
306  UInt_t nValEvents = GetNumValidationSamples();
307  UInt_t nTrainingEvents = nAllEvents - nValEvents;
308 
309  Log() << kINFO << "Split TMVA training data in " << nTrainingEvents << " training events and "
310  << nValEvents << " validation events" << Endl;
311 
312  float* trainDataX = new float[nTrainingEvents*fNVars];
313  float* trainDataY = new float[nTrainingEvents*fNOutputs];
314  float* trainDataWeights = new float[nTrainingEvents];
315  for (UInt_t i=0; i<nTrainingEvents; i++) {
316  const TMVA::Event* e = GetTrainingEvent(i);
317  // Fill variables
318  for (UInt_t j=0; j<fNVars; j++) {
319  trainDataX[j + i*fNVars] = e->GetValue(j);
320  }
321  // Fill targets
322  // NOTE: For classification, convert class number in one-hot vector,
323  // e.g., 1 -> [0, 1] or 0 -> [1, 0] for binary classification
324  if (GetAnalysisType() == Types::kClassification || GetAnalysisType() == Types::kMulticlass) {
325  for (UInt_t j=0; j<fNOutputs; j++) {
326  trainDataY[j + i*fNOutputs] = 0;
327  }
328  trainDataY[e->GetClass() + i*fNOutputs] = 1;
329  }
330  else if (GetAnalysisType() == Types::kRegression) {
331  for (UInt_t j=0; j<fNOutputs; j++) {
332  trainDataY[j + i*fNOutputs] = e->GetTarget(j);
333  }
334  }
335  else Log() << kFATAL << "Can not fill target vector because analysis type is not known" << Endl;
336  // Fill weights
337  // NOTE: If no weight branch is given, this defaults to ones for all events
338  trainDataWeights[i] = e->GetWeight();
339  }
340 
341  npy_intp dimsTrainX[2] = {(npy_intp)nTrainingEvents, (npy_intp)fNVars};
342  npy_intp dimsTrainY[2] = {(npy_intp)nTrainingEvents, (npy_intp)fNOutputs};
343  npy_intp dimsTrainWeights[1] = {(npy_intp)nTrainingEvents};
344  PyArrayObject* pTrainDataX = (PyArrayObject*)PyArray_SimpleNewFromData(2, dimsTrainX, NPY_FLOAT, (void*)trainDataX);
345  PyArrayObject* pTrainDataY = (PyArrayObject*)PyArray_SimpleNewFromData(2, dimsTrainY, NPY_FLOAT, (void*)trainDataY);
346  PyArrayObject* pTrainDataWeights = (PyArrayObject*)PyArray_SimpleNewFromData(1, dimsTrainWeights, NPY_FLOAT, (void*)trainDataWeights);
347  PyDict_SetItemString(fLocalNS, "trainX", (PyObject*)pTrainDataX);
348  PyDict_SetItemString(fLocalNS, "trainY", (PyObject*)pTrainDataY);
349  PyDict_SetItemString(fLocalNS, "trainWeights", (PyObject*)pTrainDataWeights);
350 
351  /*
352  * Load validation data to numpy array
353  */
354 
355  // NOTE: TMVA Validation data is a subset of all the training data
356  // we do not use the test data for validation; it is reserved for the final testing
357 
358 
359  float* valDataX = new float[nValEvents*fNVars];
360  float* valDataY = new float[nValEvents*fNOutputs];
361  float* valDataWeights = new float[nValEvents];
362  // validation events follow the training ones in the TMVA training vector
363  for (UInt_t i=0; i< nValEvents ; i++) {
364  UInt_t ievt = nTrainingEvents + i; // TMVA event index
365  const TMVA::Event* e = GetTrainingEvent(ievt);
366  // Fill variables
367  for (UInt_t j=0; j<fNVars; j++) {
368  valDataX[j + i*fNVars] = e->GetValue(j);
369  }
370  // Fill targets
371  if (GetAnalysisType() == Types::kClassification || GetAnalysisType() == Types::kMulticlass) {
372  for (UInt_t j=0; j<fNOutputs; j++) {
373  valDataY[j + i*fNOutputs] = 0;
374  }
375  valDataY[e->GetClass() + i*fNOutputs] = 1;
376  }
377  else if (GetAnalysisType() == Types::kRegression) {
378  for (UInt_t j=0; j<fNOutputs; j++) {
379  valDataY[j + i*fNOutputs] = e->GetTarget(j);
380  }
381  }
382  else Log() << kFATAL << "Can not fill target vector because analysis type is not known" << Endl;
383  // Fill weights
384  valDataWeights[i] = e->GetWeight();
385  }
386 
387  npy_intp dimsValX[2] = {(npy_intp)nValEvents, (npy_intp)fNVars};
388  npy_intp dimsValY[2] = {(npy_intp)nValEvents, (npy_intp)fNOutputs};
389  npy_intp dimsValWeights[1] = {(npy_intp)nValEvents};
390  PyArrayObject* pValDataX = (PyArrayObject*)PyArray_SimpleNewFromData(2, dimsValX, NPY_FLOAT, (void*)valDataX);
391  PyArrayObject* pValDataY = (PyArrayObject*)PyArray_SimpleNewFromData(2, dimsValY, NPY_FLOAT, (void*)valDataY);
392  PyArrayObject* pValDataWeights = (PyArrayObject*)PyArray_SimpleNewFromData(1, dimsValWeights, NPY_FLOAT, (void*)valDataWeights);
393  PyDict_SetItemString(fLocalNS, "valX", (PyObject*)pValDataX);
394  PyDict_SetItemString(fLocalNS, "valY", (PyObject*)pValDataY);
395  PyDict_SetItemString(fLocalNS, "valWeights", (PyObject*)pValDataWeights);
396 
397  /*
398  * Train PyTorch model
399  */
400  Log() << kINFO << "Print Training Model Architecture" << Endl;
401  PyRunString("print(model)");
402 
403  // Setup parameters
404 
405  PyObject* pBatchSize = PyLong_FromLong(fBatchSize);
406  PyObject* pNumEpochs = PyLong_FromLong(fNumEpochs);
407  PyDict_SetItemString(fLocalNS, "batchSize", pBatchSize);
408  PyDict_SetItemString(fLocalNS, "numEpochs", pNumEpochs);
409 
410  // Prepare PyTorch Training DataSet
411  PyRunString("train_dataset = torch.utils.data.TensorDataset(torch.Tensor(trainX), torch.Tensor(trainY))",
412  "Failed to create pytorch train Dataset.");
413  // Prepare PyTorch Training Dataloader
414  PyRunString("train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchSize, shuffle=False)",
415  "Failed to create pytorch train Dataloader.");
416 
417 
418  // Prepare PyTorch Validation DataSet
419  PyRunString("val_dataset = torch.utils.data.TensorDataset(torch.Tensor(valX), torch.Tensor(valY))",
420  "Failed to create pytorch validation Dataset.");
421  // Prepare PyTorch validation Dataloader
422  PyRunString("val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batchSize, shuffle=False)",
423  "Failed to create pytorch validation Dataloader.");
424 
425 
426  // Learning Rate Scheduler
427  if (fLearningRateSchedule!="") {
428  // Setup a python dictionary with the desired learning rate steps
429  PyRunString("strScheduleSteps = '"+fLearningRateSchedule+"'\n"
430  "schedulerSteps = {}\n"
431  "for c in strScheduleSteps.split(';'):\n"
432  " x = c.split(',')\n"
433  " schedulerSteps[int(x[0])] = float(x[1])\n",
434  "Failed to setup steps for scheduler function from string: "+fLearningRateSchedule,
435  Py_file_input);
436  // Set scheduler function as piecewise function with given steps
437  PyRunString("def schedule(optimizer, epoch, schedulerSteps=schedulerSteps):\n"
438  " if epoch in schedulerSteps:\n"
439  " for param_group in optimizer.param_groups:\n"
440  " param_group['lr'] = float(schedulerSteps[epoch])\n",
441  "Failed to setup scheduler function with string: "+fLearningRateSchedule,
442  Py_file_input);
443 
444  Log() << kINFO << "Option LearningRateSchedule: Set learning rate during training: " << fLearningRateSchedule << Endl;
445  }
446  else{
447  PyRunString("schedule = None; schedulerSteps = None", "Failed to set scheduler to None.");
448  }
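For illustration, running the embedded parsing code above on the example value from the option help produces the following schedule (standalone Python; the values are the ones quoted in the documentation string of LearningRateSchedule):

strScheduleSteps = "50,0.01;70,0.005"
schedulerSteps = {}
for c in strScheduleSteps.split(';'):
    x = c.split(',')
    schedulerSteps[int(x[0])] = float(x[1])
print(schedulerSteps)   # {50: 0.01, 70: 0.005} -> lr set to 0.01 at epoch 50 and 0.005 at epoch 70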
449 
450 
451  // Save only weights with smallest validation loss
452  if (fSaveBestOnly) {
453  PyRunString("def save_best(model, curr_val, best_val, save_path='"+fFilenameTrainedModel+"'):\n"
454  " if curr_val<=best_val:\n"
455  " best_val = curr_val\n"
456  " best_model_jitted = torch.jit.script(model)\n"
457  " torch.jit.save(best_model_jitted, save_path)\n"
458  " return best_val",
459  "Failed to setup training with option: SaveBestOnly");
460  Log() << kINFO << "Option SaveBestOnly: Only model weights with smallest validation loss will be stored" << Endl;
461  }
462  else{
463  PyRunString("save_best = None", "Failed to set save_best to None.");
464  }
465 
466 
467  // Note: Early stopping is not implemented here; it can be implemented by the user inside the training loop function if required.
468 
469  // Train model
470  PyRunString("trained_model = fit(model, train_loader, val_loader, num_epochs=numEpochs, batch_size=batchSize,"
471  "optimizer=optimizer, criterion=criterion, save_best=save_best, scheduler=(schedule, schedulerSteps))",
472  "Failed to train model");
473 
474 
475  // Note: Unlike Keras, PyTorch does not store the training history. A user can append the loss,
476  // accuracy and other metrics to a file inside the training loop for later use.
477 
478  /*
479  * Store trained model to file (only if option 'SaveBestOnly' is NOT activated,
480  * because we do not want to override the best model checkpoint)
481  */
482  if (!fSaveBestOnly) {
483  PyRunString("trained_model_jitted = torch.jit.script(trained_model)",
484  "Model not scriptable. Failed to convert to torch script.");
485  PyRunString("torch.jit.save(trained_model_jitted, '"+fFilenameTrainedModel+"')",
486  "Failed to save trained model: "+fFilenameTrainedModel);
487  Log() << kINFO << "Trained model written to file: " << fFilenameTrainedModel << Endl;
488  }
489 
490  /*
491  * Clean-up
492  */
493 
494  delete[] trainDataX;
495  delete[] trainDataY;
496  delete[] trainDataWeights;
497  delete[] valDataX;
498  delete[] valDataY;
499  delete[] valDataWeights;
500 }
501 
502 
503 void MethodPyTorch::TestClassification() {
504  MethodBase::TestClassification();
505 }
506 
507 Double_t MethodPyTorch::GetMvaValue(Double_t *errLower, Double_t *errUpper) {
508  // Cannot determine error
509  NoErrorCalc(errLower, errUpper);
510 
511  // Check whether the model is setup
512  // NOTE: unfortunately this is needed because during evaluation ProcessOptions is not called again
513  if (!fModelIsSetup) {
514  // Setup the trained model
515  SetupPyTorchModel(true);
516  }
517 
518  // Get signal probability (called mvaValue here)
519  const TMVA::Event* e = GetEvent();
520  for (UInt_t i=0; i<fNVars; i++) fVals[i] = e->GetValue(i);
521  PyRunString("for i,p in enumerate(predict(model, vals)): output[i]=p\n",
522  "Failed to get predictions");
523 
524 
525  return fOutput[TMVA::Types::kSignal];
526 }
527 
528 
529 std::vector<Double_t> MethodPyTorch::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress) {
530  // Check whether the model is setup
531  // NOTE: Unfortunately this is needed because during evaluation ProcessOptions is not called again
532  if (!fModelIsSetup) {
533  // Setup the trained model
534  SetupPyTorchModel(true);
535  }
536 
537  // Load data to numpy array
538  Long64_t nEvents = Data()->GetNEvents();
539  if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
540  if (firstEvt < 0) firstEvt = 0;
541  nEvents = lastEvt-firstEvt;
542 
543  // use timer
544  Timer timer( nEvents, GetName(), kTRUE );
545 
546  if (logProgress)
547  Log() << kHEADER << Form("[%s] : ",DataInfo().GetName())
548  << "Evaluation of " << GetMethodName() << " on "
549  << (Data()->GetCurrentType() == Types::kTraining ? "training" : "testing")
550  << " sample (" << nEvents << " events)" << Endl;
551 
552  float* data = new float[nEvents*fNVars];
553  for (UInt_t i=0; i<nEvents; i++) {
554  Data()->SetCurrentEvent(i);
555  const TMVA::Event *e = GetEvent();
556  for (UInt_t j=0; j<fNVars; j++) {
557  data[j + i*fNVars] = e->GetValue(j);
558  }
559  }
560 
561  npy_intp dimsData[2] = {(npy_intp)nEvents, (npy_intp)fNVars};
562  PyArrayObject* pDataMvaValues = (PyArrayObject*)PyArray_SimpleNewFromData(2, dimsData, NPY_FLOAT, (void*)data);
563  if (pDataMvaValues==0) Log() << kFATAL << "Failed to load data to Python array" << Endl;
564 
565 
566  // Get prediction for all events
567  PyObject* pModel = PyDict_GetItemString(fLocalNS, "model");
568  if (pModel==0) Log() << kFATAL << "Failed to get model Python object" << Endl;
569 
570  PyObject* pPredict = PyDict_GetItemString(fLocalNS, "predict");
571  if (pPredict==0) Log() << kFATAL << "Failed to get Python predict function" << Endl;
572 
573 
574  // Using PyTorch User Defined predict function for predictions
575  PyArrayObject* pPredictions = (PyArrayObject*) PyObject_CallFunctionObjArgs(pPredict, pModel, pDataMvaValues, NULL);
576  if (pPredictions==0) Log() << kFATAL << "Failed to get predictions" << Endl;
577  delete[] data;
578 
579  // Load predictions to double vector
580  // NOTE: The signal probability is given at the output
581  std::vector<double> mvaValues(nEvents);
582  float* predictionsData = (float*) PyArray_DATA(pPredictions);
583  for (UInt_t i=0; i<nEvents; i++) {
584  mvaValues[i] = (double) predictionsData[i*fNOutputs + TMVA::Types::kSignal];
585  }
586 
587  if (logProgress) {
588  Log() << kINFO
589  << "Elapsed time for evaluation of " << nEvents << " events: "
590  << timer.GetElapsedTime() << " " << Endl;
591  }
592 
593  return mvaValues;
594 }
595 
596 std::vector<Float_t>& MethodPyTorch::GetRegressionValues() {
597  // Check whether the model is setup
598  // NOTE: unfortunately this is needed because during evaluation ProcessOptions is not called again
599  if (!fModelIsSetup){
600  // Setup the model and load weights
601  SetupPyTorchModel(true);
602  }
603 
604  // Get regression values
605  const TMVA::Event* e = GetEvent();
606  for (UInt_t i=0; i<fNVars; i++) fVals[i] = e->GetValue(i);
607 
608  PyRunString("for i,p in enumerate(predict(model, vals)): output[i]=p\n",
609  "Failed to get predictions");
610 
611 
612  // Use inverse transformation of targets to get final regression values
613  Event * eTrans = new Event(*e);
614  for (UInt_t i=0; i<fNOutputs; ++i) {
615  eTrans->SetTarget(i,fOutput[i]);
616  }
617 
618  const Event* eTrans2 = GetTransformationHandler().InverseTransform(eTrans);
619  for (UInt_t i=0; i<fNOutputs; ++i) {
620  fOutput[i] = eTrans2->GetTarget(i);
621  }
622 
623  return fOutput;
624 }
625 
626 std::vector<Float_t>& MethodPyTorch::GetMulticlassValues() {
627  // Check whether the model is setup
628  // NOTE: unfortunately this is needed because during evaluation ProcessOptions is not called again
629  if (!fModelIsSetup){
630  // Setup the model and load weights
631  SetupPyTorchModel(true);
632  }
633 
634  // Get class probabilities
635  const TMVA::Event* e = GetEvent();
636  for (UInt_t i=0; i<fNVars; i++) fVals[i] = e->GetValue(i);
637  PyRunString("for i,p in enumerate(predict(model, vals)): output[i]=p\n",
638  "Failed to get predictions");
639 
640  return fOutput;
641 }
642 
643 
644 void MethodPyTorch::ReadModelFromFile() {
645 }
646 
647 
648 void MethodPyTorch::GetHelpMessage() const {
649  Log() << Endl;
650  Log() << "PyTorch is a scientific computing package supporting" << Endl;
651  Log() << "automatic differentiation. This method wraps the training" << Endl;
652  Log() << "and predictions steps of the PyTorch Python package for" << Endl;
653  Log() << "TMVA, so that dataloading, preprocessing and evaluation" << Endl;
654  Log() << "can be done within the TMVA system. To use this PyTorch" << Endl;
655  Log() << "interface, you need to generate a model with PyTorch first." << Endl;
656  Log() << "Then, this model can be loaded and trained in TMVA." << Endl;
657  Log() << Endl;
658 }