Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodPyTorch.cxx
Go to the documentation of this file.
1// @(#)root/tmva/pymva $Id$
2// Author: Anirudh Dagar, 2020
3
4#include <Python.h>
6
7#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
8#include <numpy/arrayobject.h>
9
10#include "TMVA/Types.h"
11#include "TMVA/Config.h"
13#include "TMVA/Results.h"
16#include "TMVA/Tools.h"
17#include "TMVA/Timer.h"
18
19using namespace TMVA;
20
21namespace TMVA {
22namespace Internal {
23class PyGILRAII {
24 PyGILState_STATE m_GILState;
25
26public:
29};
30} // namespace Internal
31} // namespace TMVA
32
33REGISTER_METHOD(PyTorch)
34
35
36
38 : PyMethodBase(jobName, Types::kPyTorch, methodTitle, dsi, theOption) {
39 fNumEpochs = 10;
40 fBatchSize = 100;
41
42 fContinueTraining = false;
43 fSaveBestOnly = true;
44 fLearningRateSchedule = ""; // empty string deactivates learning rate scheduler
45 fFilenameTrainedModel = ""; // empty string sets output model filename to default (in "weights/" directory.)
46}
47
48
51 fNumEpochs = 10;
52 fBatchSize = 100;
53
54 fContinueTraining = false;
55 fSaveBestOnly = true;
56 fLearningRateSchedule = ""; // empty string deactivates learning rate scheduler
57 fFilenameTrainedModel = ""; // empty string sets output model filename to default (in "weights/" directory.)
58}
59
60
62 if (fPyVals != nullptr) Py_DECREF(fPyVals);
63 if (fPyOutput != nullptr) Py_DECREF(fPyOutput);
64}
65
66
73
74
76 DeclareOptionRef(fFilenameModel, "FilenameModel", "Filename of the initial PyTorch model");
77 DeclareOptionRef(fFilenameTrainedModel, "FilenameTrainedModel", "Filename of the trained output PyTorch model");
78 DeclareOptionRef(fBatchSize, "BatchSize", "Training batch size");
79 DeclareOptionRef(fNumEpochs, "NumEpochs", "Number of training epochs");
80
81 DeclareOptionRef(fContinueTraining, "ContinueTraining", "Load weights from previous training");
82 DeclareOptionRef(fSaveBestOnly, "SaveBestOnly", "Store only weights with smallest validation loss");
83 DeclareOptionRef(fLearningRateSchedule, "LearningRateSchedule", "Set new learning rate during training at specific epochs, e.g., \"50,0.01;70,0.005\"");
84
85 DeclareOptionRef(fNumValidationString = "20%", "ValidationSize", "Part of the training data to use for validation."
86 "Specify as 0.2 or 20% to use a fifth of the data set as validation set."
87 "Specify as 100 to use exactly 100 events. (Default: 20%)");
88 DeclareOptionRef(fUserCodeName = "", "UserCode", "Necessary python code provided by the user to be executed before loading and training the PyTorch Model");
89
90}
91
92
93////////////////////////////////////////////////////////////////////////////////
94/// Validation of the ValidationSize option. Allowed formats are 20%, 0.2 and
95/// 100 etc.
96/// - 20% and 0.2 selects 20% of the training set as validation data.
97/// - 100 selects 100 events as the validation data.
98///
99/// @return number of samples in validation set
100///
102{
104 UInt_t trainingSetSize = GetEventCollection(Types::kTraining).size();
105
106 // Parsing + Validation
107 // --------------------
108 if (fNumValidationString.EndsWith("%")) {
109 // Relative spec. format 20%
110 TString intValStr = TString(fNumValidationString.Strip(TString::kTrailing, '%'));
111
112 if (intValStr.IsFloat()) {
113 Double_t valSizeAsDouble = fNumValidationString.Atof() / 100.0;
114 nValidationSamples = GetEventCollection(Types::kTraining).size() * valSizeAsDouble;
115 } else {
116 Log() << kFATAL << "Cannot parse number \"" << fNumValidationString
117 << "\". Expected string like \"20%\" or \"20.0%\"." << Endl;
118 }
119 } else if (fNumValidationString.IsFloat()) {
120 Double_t valSizeAsDouble = fNumValidationString.Atof();
121
122 if (valSizeAsDouble < 1.0) {
123 // Relative spec. format 0.2
124 nValidationSamples = GetEventCollection(Types::kTraining).size() * valSizeAsDouble;
125 } else {
126 // Absolute spec format 100 or 100.0
128 }
129 } else {
130 Log() << kFATAL << "Cannot parse number \"" << fNumValidationString << "\". Expected string like \"0.2\" or \"100\"."
131 << Endl;
132 }
133
134 // Value validation
135 // ----------------
136 if (nValidationSamples < 0) {
137 Log() << kFATAL << "Validation size \"" << fNumValidationString << "\" is negative." << Endl;
138 }
139
140 if (nValidationSamples == 0) {
141 Log() << kFATAL << "Validation size \"" << fNumValidationString << "\" is zero." << Endl;
142 }
143
145 Log() << kFATAL << "Validation size \"" << fNumValidationString
146 << "\" is larger than or equal in size to training set (size=\"" << trainingSetSize << "\")." << Endl;
147 }
148
149 return nValidationSamples;
150}
151
152
154 // Set default filename for trained model if option is not used
156 fFilenameTrainedModel = GetWeightFileDir() + "/TrainedModel_" + GetName() + ".pt";
157 }
158
159 // - set up number of threads for CPU if NumThreads option was specified
160 // `torch.set_num_threads` sets the number of threads that can be used to
161 // perform cpu operations like conv or mm (usually used by OpenMP or MKL).
162
163 Log() << kINFO << "Using PyTorch - setting special configuration options " << Endl;
164 PyRunString("import torch", "Error importing pytorch");
165
166 // run these above lines also in global namespace to make them visible overall
168
169 // check pytorch version
170 PyRunString("torch_major_version = int(torch.__version__.split('.')[0])");
171 PyObject *pyTorchVersion = PyDict_GetItemString(fLocalNS, "torch_major_version");
173 Log() << kINFO << "Using PyTorch version " << torchVersion << Endl;
174
175 // set the number of CPU threads if one was specified
177 if (num_threads > 0) {
178 Log() << kINFO << "Setting the CPU number of threads = " << num_threads << Endl;
179
180 PyRunString(TString::Format("torch.set_num_threads(%d)", num_threads));
181 PyRunString(TString::Format("torch.set_num_interop_threads(%d)", num_threads));
182 }
183
184 // Setup model, either the initial model from `fFilenameModel` or
185 // the trained model from `fFilenameTrainedModel`
186 if (fContinueTraining) Log() << kINFO << "Continue training with trained model" << Endl;
188}
189
190
192 /*
193 * Load PyTorch model from file
194 */
195
196 Log() << kINFO << " Setup PyTorch Model for training" << Endl;
197
198 if (!fUserCodeName.IsNull()) {
199 Log() << kINFO << " Executing user initialization code from " << fUserCodeName << Endl;
200
201 // run some python code provided by user for method initializations
202 FILE* fp;
203 fp = fopen(fUserCodeName, "r");
204 if (fp) {
206 fclose(fp);
207 }
208 else
209 Log() << kFATAL << "Input user code is not existing : " << fUserCodeName << Endl;
210 }
211
212 PyRunString("print('custom objects for loading model : ',load_model_custom_objects)");
213
214 // Setup the training method
215 PyRunString("fit = load_model_custom_objects[\"train_func\"]",
216 "Failed to load train function from file. Please use key: 'train_func' and pass training loop function as the value.");
217 Log() << kINFO << "Loaded pytorch train function: " << Endl;
218
219
220 // Setup Optimizer. Use SGD Optimizer as Default
221 PyRunString("if 'optimizer' in load_model_custom_objects:\n"
222 " optimizer = load_model_custom_objects['optimizer']\n"
223 "else:\n"
224 " optimizer = torch.optim.SGD\n",
225 "Please use key: 'optimizer' and pass a pytorch optimizer as the value for a custom optimizer.");
226 Log() << kINFO << "Loaded pytorch optimizer: " << Endl;
227
228
229 // Setup the loss criterion
230 PyRunString("criterion = load_model_custom_objects[\"criterion\"]",
231 "Failed to load loss function from file. Using MSE Loss as default. Please use key: 'criterion' and pass a pytorch loss function as the value.");
232 Log() << kINFO << "Loaded pytorch loss function: " << Endl;
233
234
235 // Setup the predict method
236 PyRunString("predict = load_model_custom_objects[\"predict_func\"]",
237 "Can't find user predict function object from file. Please use key: 'predict' and pass a predict function for evaluating the model as the value.");
238 Log() << kINFO << "Loaded pytorch predict function: " << Endl;
239
240
241 // Load already trained model or initial model
243 if (loadTrainedModel) {
245 }
246 else {
248 }
249 PyRunString("model = torch.jit.load('"+filenameLoadModel+"')",
250 "Failed to load PyTorch model from file: "+filenameLoadModel);
251 Log() << kINFO << "Loaded model from file: " << filenameLoadModel << Endl;
252
253 // Get variables, classes and target numbers
259 else
260 Log() << kFATAL << "Selected analysis type is not implemented" << Endl;
261
262 if (fNVars == 0 || fNOutputs == 0) {
263 Log() << kERROR << "Model does not have a number of inputs or output. Setup failed" << Endl;
264 fModelIsSetup = false;
265 }
266 else {
267 // Mark the model as setup
268 fModelIsSetup = true;
269 }
270}
271
273{
274 // initialize python arrays used in the model evaluation (prediction)
275 size_t inputSize = fNVars*nEvents;
276 size_t outputSize = fNOutputs*nEvents;
277
278 // Init evaluation by allocating the array with the right size
279
280 if (inputSize > 0 && (fVals.size() != inputSize || fPyVals == nullptr)) {
281 fVals.resize(inputSize);
282 npy_intp dimsVals[2] = {(npy_intp)nEvents, (npy_intp)fNVars};
283 if (fPyVals != nullptr) Py_DECREF(fPyVals); // delete previous object
285 if (!fPyVals)
286 Log() << kFATAL << "Failed to load data to Python array" << Endl;
288 }
289
290 if (outputSize > 0 && ( fOutput.size() != outputSize || fPyOutput == nullptr)) {
291 fOutput.resize(outputSize); // holds classification probabilities or regression output
292 // allocation of Python output array is needed only for single event evaluation
293 if (nEvents == 1) {
295 if (fPyOutput != nullptr) Py_DECREF(fPyOutput); // delete previous object
297 if (!fPyOutput)
298 Log() << kFATAL << "Failed to create output data Python array" << Endl;
300 }
301 }
302}
303
304
306
308
309 if (!PyIsInitialized()) {
310 Log() << kFATAL << "Python is not initialized" << Endl;
311 }
312 _import_array(); // required to use numpy arrays
313
314 // Import PyTorch
315 PyRunString("import sys; sys.argv = ['']", "Set sys.argv failed");
316 PyRunString("import torch", "import PyTorch failed");
317 // do import also in global namespace
318 auto ret = PyRun_String("import torch", Py_single_input, fGlobalNS, fGlobalNS);
319 if (!ret)
320 Log() << kFATAL << "import torch in global namespace failed!" << Endl;
321
322 // Set flag that model is not setup
323 fModelIsSetup = false;
324}
325
326
328 if(!fModelIsSetup) Log() << kFATAL << "Model is not setup for training" << Endl;
329
330 /*
331 * Load training data to numpy array.
332 * NOTE: These are later forced to be converted into torch tensors throughout the training loop, which may not be the ideal method.
333 */
334
338
339 Log() << kINFO << "Split TMVA training data in " << nTrainingEvents << " training events and "
340 << nValEvents << " validation events" << Endl;
341
342 float* trainDataX = new float[nTrainingEvents*fNVars];
343 float* trainDataY = new float[nTrainingEvents*fNOutputs];
344 float* trainDataWeights = new float[nTrainingEvents];
345 for (UInt_t i=0; i<nTrainingEvents; i++) {
346 const TMVA::Event* e = GetTrainingEvent(i);
347 // Fill variables
348 for (UInt_t j=0; j<fNVars; j++) {
349 trainDataX[j + i*fNVars] = e->GetValue(j);
350 }
351 // Fill targets
352 // NOTE: For classification, convert class number in one-hot vector,
353 // e.g., 1 -> [0, 1] or 0 -> [1, 0] for binary classification
355 for (UInt_t j=0; j<fNOutputs; j++) {
356 trainDataY[j + i*fNOutputs] = 0;
357 }
358 trainDataY[e->GetClass() + i*fNOutputs] = 1;
359 }
360 else if (GetAnalysisType() == Types::kRegression) {
361 for (UInt_t j=0; j<fNOutputs; j++) {
362 trainDataY[j + i*fNOutputs] = e->GetTarget(j);
363 }
364 }
365 else Log() << kFATAL << "Can not fill target vector because analysis type is not known" << Endl;
366 // Fill weights
367 // NOTE: If no weight branch is given, this defaults to ones for all events
368 trainDataWeights[i] = e->GetWeight();
369 }
370
380
381 /*
382 * Load validation data to numpy array
383 */
384
385 // NOTE: TMVA Validation data is a subset of all the training data
386 // we will not use test data for validation. They will be used for the real testing
387
388
389 float* valDataX = new float[nValEvents*fNVars];
390 float* valDataY = new float[nValEvents*fNOutputs];
391 float* valDataWeights = new float[nValEvents];
392 // validation events follow the training ones in the TMVA training vector
393 for (UInt_t i=0; i< nValEvents ; i++) {
394 UInt_t ievt = nTrainingEvents + i; // TMVA event index
396 // Fill variables
397 for (UInt_t j=0; j<fNVars; j++) {
398 valDataX[j + i*fNVars] = e->GetValue(j);
399 }
400 // Fill targets
402 for (UInt_t j=0; j<fNOutputs; j++) {
403 valDataY[j + i*fNOutputs] = 0;
404 }
405 valDataY[e->GetClass() + i*fNOutputs] = 1;
406 }
407 else if (GetAnalysisType() == Types::kRegression) {
408 for (UInt_t j=0; j<fNOutputs; j++) {
409 valDataY[j + i*fNOutputs] = e->GetTarget(j);
410 }
411 }
412 else Log() << kFATAL << "Can not fill target vector because analysis type is not known" << Endl;
413 // Fill weights
414 valDataWeights[i] = e->GetWeight();
415 }
416
426
427 /*
428 * Train PyTorch model
429 */
430 Log() << kINFO << "Print Training Model Architecture" << Endl;
431 PyRunString("print(model)");
432
433 // Setup parameters
434
439
440 // Prepare PyTorch Training DataSet
441 PyRunString("train_dataset = torch.utils.data.TensorDataset(torch.Tensor(trainX), torch.Tensor(trainY))",
442 "Failed to create pytorch train Dataset.");
443 // Prepare PyTorch Training Dataloader
444 PyRunString("train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchSize, shuffle=False)",
445 "Failed to create pytorch train Dataloader.");
446
447
448 // Prepare PyTorch Validation DataSet
449 PyRunString("val_dataset = torch.utils.data.TensorDataset(torch.Tensor(valX), torch.Tensor(valY))",
450 "Failed to create pytorch validation Dataset.");
451 // Prepare PyTorch validation Dataloader
452 PyRunString("val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batchSize, shuffle=False)",
453 "Failed to create pytorch validation Dataloader.");
454
455
456 // Learning Rate Scheduler
457 if (fLearningRateSchedule!="") {
458 // Setup a python dictionary with the desired learning rate steps
459 PyRunString("strScheduleSteps = '"+fLearningRateSchedule+"'\n"
460 "schedulerSteps = {}\n"
461 "for c in strScheduleSteps.split(';'):\n"
462 " x = c.split(',')\n"
463 " schedulerSteps[int(x[0])] = float(x[1])\n",
464 "Failed to setup steps for scheduler function from string: "+fLearningRateSchedule,
466 // Set scheduler function as piecewise function with given steps
467 PyRunString("def schedule(optimizer, epoch, schedulerSteps=schedulerSteps):\n"
468 " if epoch in schedulerSteps:\n"
469 " for param_group in optimizer.param_groups:\n"
470 " param_group['lr'] = float(schedulerSteps[epoch])\n",
471 "Failed to setup scheduler function with string: "+fLearningRateSchedule,
473
474 Log() << kINFO << "Option LearningRateSchedule: Set learning rate during training: " << fLearningRateSchedule << Endl;
475 }
476 else{
477 PyRunString("schedule = None; schedulerSteps = None", "Failed to set scheduler to None.");
478 }
479
480
481 // Save only weights with smallest validation loss
482 if (fSaveBestOnly) {
483 PyRunString("def save_best(model, curr_val, best_val, save_path='"+fFilenameTrainedModel+"'):\n"
484 " if curr_val<=best_val:\n"
485 " best_val = curr_val\n"
486 " best_model_jitted = torch.jit.script(model)\n"
487 " torch.jit.save(best_model_jitted, save_path)\n"
488 " return best_val",
489 "Failed to setup training with option: SaveBestOnly");
490 Log() << kINFO << "Option SaveBestOnly: Only model weights with smallest validation loss will be stored" << Endl;
491 }
492 else{
493 PyRunString("save_best = None", "Failed to set save_best to None.");
494 }
495
496
497 // Note: Early Stopping should not be implemented here. Can be implemented inside train loop function by user if required.
498
499 // Train model
500 PyRunString("trained_model = fit(model, train_loader, val_loader, num_epochs=numEpochs, batch_size=batchSize,"
501 "optimizer=optimizer, criterion=criterion, save_best=save_best, scheduler=(schedule, schedulerSteps))",
502 "Failed to train model");
503
504
505 // Note: PyTorch doesn't store training history data unlike Keras. A user can append and save the loss,
506 // accuracy, other metrics etc to a file for later use.
507
508 /*
509 * Store trained model to file (only if option 'SaveBestOnly' is NOT activated,
510 * because we do not want to override the best model checkpoint)
511 */
512 if (!fSaveBestOnly) {
513 PyRunString("trained_model_jitted = torch.jit.script(trained_model)",
514 "Model not scriptable. Failed to convert to torch script.");
515 PyRunString("torch.jit.save(trained_model_jitted, '"+fFilenameTrainedModel+"')",
516 "Failed to save trained model: "+fFilenameTrainedModel);
517 Log() << kINFO << "Trained model written to file: " << fFilenameTrainedModel << Endl;
518 }
519
520 /*
521 * Clean-up
522 */
523
524 delete[] trainDataX;
525 delete[] trainDataY;
526 delete[] trainDataWeights;
527 delete[] valDataX;
528 delete[] valDataY;
529 delete[] valDataWeights;
530}
531
532
536
538 // Cannot determine error
540
541 // Check whether the model is setup
542 // NOTE: unfortunately this is needed because during evaluation ProcessOptions is not called again
543 if (!fModelIsSetup) {
544 // Setup the trained model
545 SetupPyTorchModel(true);
546 }
548
549 // Get signal probability (called mvaValue here)
550 const TMVA::Event* e = GetEvent();
551 for (UInt_t i=0; i<fNVars; i++) fVals[i] = e->GetValue(i);
552 PyRunString("for i,p in enumerate(predict(model, vals)): output[i]=p\n",
553 "Failed to get predictions");
554
555
557}
558
559
/// Batch evaluation: build the input array for a range of events, run the
/// user-supplied Python predict function and return one MVA value per event.
///
/// @param firstEvt first event index of the requested range (clamped to 0 if negative)
/// @param lastEvt  end of the requested range (reset to the total number of
///                 events when it is smaller than firstEvt or exceeds that total)
/// @return vector with (lastEvt - firstEvt) entries
///
/// NOTE(review): several statements of this function (the declarations of
/// pModel, pPredict and pPredictions, the body of the final copy loop and the
/// clean-up before the return) are not visible in this listing.
560std::vector<Double_t> MethodPyTorch::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t /* logProgress */) {
561 // Check whether the model is setup
562 // NOTE: Unfortunately this is needed because during evaluation ProcessOptions is not called again
563 if (!fModelIsSetup) {
564 // Setup the trained model
565 SetupPyTorchModel(true);
566 }
567
568 // Load data to numpy array
// Normalise the requested range against the number of events in the data set.
569 Long64_t nEvents = Data()->GetNEvents();
570 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
571 if (firstEvt < 0) firstEvt = 0;
572 nEvents = lastEvt-firstEvt;
573
// (Re)allocate the evaluation buffers for nEvents events.
574 InitEvaluation(nEvents);
575
576 assert (fVals.size() == fNVars*nEvents);
// Fill fVals event by event: variable j of event i is stored at j + i*fNVars.
// NOTE(review): the current event is set to i, not firstEvt + i, so the rows
// appear to be read from the start of the data set regardless of firstEvt - confirm.
577 for (UInt_t i=0; i<nEvents; i++) {
578 Data()->SetCurrentEvent(i);
579 const TMVA::Event *e = GetEvent();
580 for (UInt_t j=0; j<fNVars; j++) {
581 fVals[j + i*fNVars] = e->GetValue(j);
582 }
583 }
584
585
586 // Get prediction for all events
// Each Python object obtained below is checked for null; a failure is logged as kFATAL.
588 if (pModel==0) Log() << kFATAL << "Failed to get model Python object" << Endl;
589
591 if (pPredict==0) Log() << kFATAL << "Failed to get Python predict function" << Endl;
592
593
594 // Using PyTorch User Defined predict function for predictions
596 if (pPredictions==0) Log() << kFATAL << "Failed to get predictions" << Endl;
597
598 // Load predictions to double vector
599 // NOTE: The signal probability is given at the output
600 std::vector<double> mvaValues(nEvents);
// NOTE(review): the data pointer is reinterpreted as float*, which assumes the
// returned array holds 32-bit floats - confirm against the predict function.
601 float* predictionsData = (float*) PyArray_DATA(pPredictions);
602 for (UInt_t i=0; i<nEvents; i++) {
604 }
605
607
608 return mvaValues;
609}
610
/// Single-event regression evaluation: copy the variables of the current event
/// into fVals, run the Python predict function, and write the inverse-transformed
/// targets into fOutput.
///
/// @return reference to the member vector fOutput holding fNOutputs target values
///
/// NOTE(review): two statements of this function (after the model setup check
/// and the declaration of eTrans2) are not visible in this listing.
611std::vector<Float_t>& MethodPyTorch::GetRegressionValues() {
612 // Check whether the model is setup
613 // NOTE: unfortunately this is needed because during evaluation ProcessOptions is not called again
614 if (!fModelIsSetup){
615 // Setup the model and load weights
616 SetupPyTorchModel(true);
617 }
619
620 // Get regression values
621 const TMVA::Event* e = GetEvent();
622 for (UInt_t i=0; i<fNVars; i++) fVals[i] = e->GetValue(i);
623
// The Python snippet calls predict(model, vals) and copies each result into output[i].
// Presumably `vals` and `output` are Python views of fVals and fOutput - confirm.
624 PyRunString("for i,p in enumerate(predict(model, vals)): output[i]=p\n",
625 "Failed to get predictions");
626
627
628 // Use inverse transformation of targets to get final regression values
// NOTE(review): eTrans is allocated with new and no matching delete is visible
// in this block - possible leak per call; confirm.
629 Event * eTrans = new Event(*e);
630 for (UInt_t i=0; i<fNOutputs; ++i) {
631 eTrans->SetTarget(i,fOutput[i]);
632 }
633
// Overwrite fOutput with the targets of the inverse-transformed event.
635 for (UInt_t i=0; i<fNOutputs; ++i) {
636 fOutput[i] = eTrans2->GetTarget(i);
637 }
638
639 return fOutput;
640}
641
643
644 if (!fModelIsSetup){
645 // Setup the model and load weights
646 SetupPyTorchModel(true);
647 }
648
649 auto nEvents = Data()->GetNEvents();
650 InitEvaluation(nEvents);
651
652
653 assert (fVals.size() == fNVars*nEvents);
654 assert (fOutput.size() == fNOutputs*nEvents);
655 for (UInt_t i=0; i<nEvents; i++) {
656 Data()->SetCurrentEvent(i);
657 const TMVA::Event *e = GetEvent();
658 for (UInt_t j=0; j<fNVars; j++) {
659 fVals[j + i*fNVars] = e->GetValue(j);
660 }
661 }
662
663 // Get prediction for all events
665 if (pModel==0) Log() << kFATAL << "Failed to get model Python object" << Endl;
666
668 if (pPredict==0) Log() << kFATAL << "Failed to get Python predict function" << Endl;
669
670 std::cout << " calling predict functon for regression \n";
671 // Using PyTorch User Defined predict function for predictions
673 if (pPredictions==0) Log() << kFATAL << "Failed to get predictions" << Endl;
674
675 // Load predictions to double vector
676 float* predictionsData = (float*) PyArray_DATA(pPredictions);
677
678 // need to loop on events since we use an inverse transformation to get final regression values
679 // this can be probably optimized
680 for (UInt_t ievt = 0; ievt < nEvents; ievt++) {
681 const TMVA::Event* e = GetEvent(ievt);
682 Event eTrans(*e);
683 for (UInt_t i = 0; i < fNOutputs; ++i) {
684 eTrans.SetTarget(i,predictionsData[ievt*fNOutputs + i]);
685 }
686 // apply the inverse transformation
688 for (UInt_t i = 0; i < fNOutputs; ++i) {
689 fOutput[ievt*fNOutputs + i] = eTrans2->GetTarget(i);
690 }
691 }
693
694 return fOutput;
695}
696
/// Single-event multiclass evaluation: copy the variables of the current event
/// into fVals, run the Python predict function and return the per-class output.
///
/// @return reference to the member vector fOutput
///
/// NOTE(review): one statement after the model setup check is not visible in
/// this listing.
697std::vector<Float_t>& MethodPyTorch::GetMulticlassValues() {
698 // Check whether the model is setup
699 // NOTE: unfortunately this is needed because during evaluation ProcessOptions is not called again
700 if (!fModelIsSetup){
701 // Setup the model and load weights
702 SetupPyTorchModel(true);
703 }
705
706 // Get class probabilities
707 const TMVA::Event* e = GetEvent();
708 for (UInt_t i=0; i<fNVars; i++) fVals[i] = e->GetValue(i);
// The Python snippet calls predict(model, vals) and copies each result into output[i].
// Presumably `vals` and `output` are Python views of fVals and fOutput - confirm.
709 PyRunString("for i,p in enumerate(predict(model, vals)): output[i]=p\n",
710 "Failed to get predictions");
711
712 return fOutput;
713}
714
716 // Check whether the model is setup
717 if (!fModelIsSetup){
718 // Setup the model and load weights
719 SetupPyTorchModel(true);
720 }
721 auto nEvents = Data()->GetNEvents();
722 InitEvaluation(nEvents);
723
724 assert (fVals.size() == fNVars*nEvents);
725 assert (fOutput.size() == fNOutputs*nEvents);
726 for (UInt_t i=0; i<nEvents; i++) {
727 Data()->SetCurrentEvent(i);
728 const TMVA::Event *e = GetEvent();
729 for (UInt_t j=0; j<fNVars; j++) {
730 fVals[j + i*fNVars] = e->GetValue(j);
731 }
732 }
733
734 // Get prediction for all events
736 if (pModel==0) Log() << kFATAL << "Failed to get model Python object" << Endl;
737
739 if (pPredict==0) Log() << kFATAL << "Failed to get Python predict function" << Endl;
740
741
742 // Using PyTorch User Defined predict function for predictions
744 if (pPredictions==0) Log() << kFATAL << "Failed to get predictions" << Endl;
745
746 // Load predictions to double vector
747 float* predictionsData = (float*) PyArray_DATA(pPredictions);
748
749 std::copy(predictionsData, predictionsData+nEvents*fNOutputs, fOutput.begin());
750
752
753 return fOutput;
754}
755
756
759
760
762 Log() << Endl;
763 Log() << "PyTorch is a scientific computing package supporting" << Endl;
764 Log() << "automatic differentiation. This method wraps the training" << Endl;
765 Log() << "and predictions steps of the PyTorch Python package for" << Endl;
766 Log() << "TMVA, so that dataloading, preprocessing and evaluation" << Endl;
767 Log() << "can be done within the TMVA system. To use this PyTorch" << Endl;
768 Log() << "interface, you need to generatea model with PyTorch first." << Endl;
769 Log() << "Then, this model can be loaded and trained in TMVA." << Endl;
770 Log() << Endl;
771}
#define REGISTER_METHOD(CLASS)
for example
_object PyObject
#define Py_single_input
#define e(i)
Definition RSha256.hxx:103
constexpr Bool_t kFALSE
Definition RtypesCore.h:108
long long Long64_t
Portable signed long integer 8 bytes.
Definition RtypesCore.h:83
constexpr Bool_t kTRUE
Definition RtypesCore.h:107
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
MsgLogger & Log() const
Class that contains all the data information.
Definition DataSetInfo.h:62
UInt_t GetNClasses() const
UInt_t GetNTargets() const
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:206
Long64_t GetNTrainingEvents() const
Definition DataSet.h:68
void SetCurrentEvent(Long64_t ievt) const
Definition DataSet.h:88
const char * GetName() const override
Definition MethodBase.h:337
Types::EAnalysisType GetAnalysisType() const
Definition MethodBase.h:440
const TString & GetWeightFileDir() const
Definition MethodBase.h:495
const Event * GetEvent() const
Definition MethodBase.h:754
DataSetInfo & DataInfo() const
Definition MethodBase.h:413
virtual void TestClassification()
initialization
UInt_t GetNVariables() const
Definition MethodBase.h:348
TransformationHandler & GetTransformationHandler(Bool_t takeReroutedIfAvailable=true)
Definition MethodBase.h:397
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
DataSet * Data() const
Definition MethodBase.h:412
const Event * GetTrainingEvent(Long64_t ievt) const
Definition MethodBase.h:774
void InitEvaluation(size_t nEvents)
void GetHelpMessage() const override
void Train() override
void Init() override
Double_t GetMvaValue(Double_t *errLower, Double_t *errUpper) override
std::vector< Float_t > & GetRegressionValues() override
void ProcessOptions() override
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t) override
std::vector< float > fOutput
void ReadModelFromFile() override
MethodPyTorch(const TString &jobName, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
std::vector< Float_t > GetAllMulticlassValues() override
Get all multi-class values.
std::vector< Double_t > GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress) override
get all the MVA values for the events of the current Data type
std::vector< float > fVals
std::vector< Float_t > & GetMulticlassValues() override
std::vector< Float_t > GetAllRegressionValues() override
Get all regression values in one call.
void TestClassification() override
initialization
UInt_t GetNumValidationSamples()
Validation of the ValidationSize option.
TString fLearningRateSchedule
TString fFilenameTrainedModel
void SetupPyTorchModel(Bool_t loadTrainedModel)
void DeclareOptions() override
Virtual base class for all TMVA method based on Python.
static int PyIsInitialized()
Check Python interpreter initialization status.
static PyObject * fGlobalNS
void PyRunString(TString code, TString errorMessage="Failed to run python code", int start=256)
Execute Python code from string.
const Event * InverseTransform(const Event *, Bool_t suppressIfNoTargets=true) const
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kSignal
Never change this number - it is elsewhere assumed to be zero !
Definition Types.h:135
@ kMulticlass
Definition Types.h:129
@ kClassification
Definition Types.h:127
@ kRegression
Definition Types.h:128
@ kTraining
Definition Types.h:143
Basic string class.
Definition TString.h:138
@ kTrailing
Definition TString.h:284
Bool_t IsNull() const
Definition TString.h:422
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2384
create variable transformations
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148