MethodPyTorch.cxx
// @(#)root/tmva/pymva $Id$
// Author: Anirudh Dagar, 2020

#include <Python.h>
#include "TMVA/MethodPyTorch.h"

#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <numpy/arrayobject.h>

#include "TMVA/Types.h"
#include "TMVA/Config.h"
#include "TMVA/Results.h"
#include "TMVA/Tools.h"
#include "TMVA/Timer.h"

using namespace TMVA;

namespace TMVA {
namespace Internal {
// Small RAII helper: acquires the Python GIL on construction and releases it on destruction.
class PyGILRAII {
   PyGILState_STATE m_GILState;

public:
   PyGILRAII() : m_GILState(PyGILState_Ensure()) {}
   ~PyGILRAII() { PyGILState_Release(m_GILState); }
};
} // namespace Internal
} // namespace TMVA
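// Usage sketch (illustrative, not part of the original file): the guard is meant for
// scopes that call into the Python C API and may not yet hold the GIL, e.g.
//
//    {
//       TMVA::Internal::PyGILRAII raii;   // GIL acquired here
//       // ... Python C API calls ...
//    }                                    // GIL released when raii goes out of scope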

REGISTER_METHOD(PyTorch)


MethodPyTorch::MethodPyTorch(const TString &jobName, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption)
   : PyMethodBase(jobName, Types::kPyTorch, methodTitle, dsi, theOption) {
   fNumEpochs = 10;
   fBatchSize = 100;

   fContinueTraining = false;
   fSaveBestOnly = true;
   fLearningRateSchedule = ""; // empty string deactivates learning rate scheduler
   fFilenameTrainedModel = ""; // empty string sets output model filename to default (in "weights/" directory)
}

MethodPyTorch::MethodPyTorch(DataSetInfo &theData, const TString &theWeightFile)
   : PyMethodBase(Types::kPyTorch, theData, theWeightFile) {
   fNumEpochs = 10;
   fBatchSize = 100;

   fContinueTraining = false;
   fSaveBestOnly = true;
   fLearningRateSchedule = ""; // empty string deactivates learning rate scheduler
   fFilenameTrainedModel = ""; // empty string sets output model filename to default (in "weights/" directory)
}


MethodPyTorch::~MethodPyTorch() {
}


Bool_t MethodPyTorch::HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t) {
   if (type == Types::kRegression) return kTRUE;
   if (type == Types::kClassification && numberClasses == 2) return kTRUE;
   if (type == Types::kMulticlass && numberClasses >= 2) return kTRUE;
   return kFALSE;
}


void MethodPyTorch::DeclareOptions() {
   DeclareOptionRef(fFilenameModel, "FilenameModel", "Filename of the initial PyTorch model");
   DeclareOptionRef(fFilenameTrainedModel, "FilenameTrainedModel", "Filename of the trained output PyTorch model");
   DeclareOptionRef(fBatchSize, "BatchSize", "Training batch size");
   DeclareOptionRef(fNumEpochs, "NumEpochs", "Number of training epochs");

   DeclareOptionRef(fContinueTraining, "ContinueTraining", "Load weights from previous training");
   DeclareOptionRef(fSaveBestOnly, "SaveBestOnly", "Store only weights with smallest validation loss");
   DeclareOptionRef(fLearningRateSchedule, "LearningRateSchedule", "Set new learning rate during training at specific epochs, e.g., \"50,0.01;70,0.005\"");

   DeclareOptionRef(fNumValidationString = "20%", "ValidationSize", "Part of the training data to use for validation. "
                    "Specify as 0.2 or 20% to use a fifth of the data set as validation set. "
                    "Specify as 100 to use exactly 100 events. (Default: 20%)");
   DeclareOptionRef(fUserCodeName = "", "UserCode", "Necessary Python code provided by the user to be executed before loading and training the PyTorch model");

}
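/*
 * Booking sketch (illustrative, not part of this file): the options declared above are
 * passed through the usual TMVA option string when the method is booked, e.g. from
 * PyROOT (file names and option values below are example choices):
 *
 *   factory.BookMethod(dataloader, ROOT.TMVA.Types.kPyTorch, 'PyTorch',
 *                      'H:!V:VarTransform=D,G:FilenameModel=PyTorchModelClassification.pt:'
 *                      'FilenameTrainedModel=trainedModelClassification.pt:'
 *                      'NumEpochs=20:BatchSize=32')
 */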


////////////////////////////////////////////////////////////////////////////////
/// Validation of the ValidationSize option. Allowed formats are 20%, 0.2 and
/// 100 etc.
/// - 20% and 0.2 select 20% of the training set as validation data.
/// - 100 selects 100 events as the validation data.
///
/// @return number of samples in validation set
///
UInt_t MethodPyTorch::GetNumValidationSamples()
{
   Int_t nValidationSamples = 0;
   UInt_t trainingSetSize = GetEventCollection(Types::kTraining).size();

   // Parsing + Validation
   // --------------------
   if (fNumValidationString.EndsWith("%")) {
      // Relative spec. format 20%
      TString intValStr = TString(fNumValidationString.Strip(TString::kTrailing, '%'));

      if (intValStr.IsFloat()) {
         Double_t valSizeAsDouble = fNumValidationString.Atof() / 100.0;
         nValidationSamples = GetEventCollection(Types::kTraining).size() * valSizeAsDouble;
      } else {
         Log() << kFATAL << "Cannot parse number \"" << fNumValidationString
               << "\". Expected string like \"20%\" or \"20.0%\"." << Endl;
      }
   } else if (fNumValidationString.IsFloat()) {
      Double_t valSizeAsDouble = fNumValidationString.Atof();

      if (valSizeAsDouble < 1.0) {
         // Relative spec. format 0.2
         nValidationSamples = GetEventCollection(Types::kTraining).size() * valSizeAsDouble;
      } else {
         // Absolute spec format 100 or 100.0
         nValidationSamples = valSizeAsDouble;
      }
   } else {
      Log() << kFATAL << "Cannot parse number \"" << fNumValidationString << "\". Expected string like \"0.2\" or \"100\"."
            << Endl;
   }

   // Value validation
   // ----------------
   if (nValidationSamples < 0) {
      Log() << kFATAL << "Validation size \"" << fNumValidationString << "\" is negative." << Endl;
   }

   if (nValidationSamples == 0) {
      Log() << kFATAL << "Validation size \"" << fNumValidationString << "\" is zero." << Endl;
   }

   if (nValidationSamples >= (Int_t)trainingSetSize) {
      Log() << kFATAL << "Validation size \"" << fNumValidationString
            << "\" is larger than or equal in size to training set (size=\"" << trainingSetSize << "\")." << Endl;
   }

   return nValidationSamples;
}
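// Worked example (illustrative): with a training set of 1000 events,
//   ValidationSize="20%"  ->  200 validation events
//   ValidationSize="0.2"  ->  200 validation events
//   ValidationSize="100"  ->  100 validation events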


void MethodPyTorch::ProcessOptions() {
   // Set default filename for trained model if option is not used
   if (fFilenameTrainedModel.IsNull()) {
      fFilenameTrainedModel = GetWeightFileDir() + "/TrainedModel_" + GetName() + ".pt";
   }

   // - set up number of threads for CPU if NumThreads option was specified
   //   `torch.set_num_threads` sets the number of threads that can be used to
   //   perform CPU operations like conv or mm (usually used by OpenMP or MKL).

   Log() << kINFO << "Using PyTorch - setting special configuration options " << Endl;
   PyRunString("import torch", "Error importing pytorch");

   // run the import also in the global namespace to make it visible overall
   PyRun_String("import torch", Py_single_input, fGlobalNS, fGlobalNS);

   // check pytorch version
   PyRunString("torch_major_version = int(torch.__version__.split('.')[0])");
   PyObject *pyTorchVersion = PyDict_GetItemString(fLocalNS, "torch_major_version");
   int torchVersion = PyLong_AsLong(pyTorchVersion);
   Log() << kINFO << "Using PyTorch version " << torchVersion << Endl;

   // in case a number of threads was specified
   if (num_threads > 0) {
      Log() << kINFO << "Setting the CPU number of threads = " << num_threads << Endl;

      PyRunString(TString::Format("torch.set_num_threads(%d)", num_threads));
      PyRunString(TString::Format("torch.set_num_interop_threads(%d)", num_threads));
   }

   // Setup model, either the initial model from `fFilenameModel` or
   // the trained model from `fFilenameTrainedModel`
   if (fContinueTraining) Log() << kINFO << "Continue training with trained model" << Endl;
   SetupPyTorchModel(fContinueTraining);
}


void MethodPyTorch::SetupPyTorchModel(Bool_t loadTrainedModel) {
   /*
    * Load PyTorch model from file
    */

   Log() << kINFO << " Setup PyTorch Model for training" << Endl;

   if (!fUserCodeName.IsNull()) {
      Log() << kINFO << " Executing user initialization code from " << fUserCodeName << Endl;

      // run some python code provided by user for method initializations
      FILE* fp;
      fp = fopen(fUserCodeName, "r");
      if (fp) {
         fclose(fp);
      }
      else
         Log() << kFATAL << "User code file does not exist: " << fUserCodeName << Endl;
   }

   PyRunString("print('custom objects for loading model : ',load_model_custom_objects)");

   // Setup the training method
   PyRunString("fit = load_model_custom_objects[\"train_func\"]",
               "Failed to load train function from file. Please use key: 'train_func' and pass the training loop function as the value.");
   Log() << kINFO << "Loaded pytorch train function: " << Endl;


   // Setup Optimizer. Use SGD Optimizer as Default
   PyRunString("if 'optimizer' in load_model_custom_objects:\n"
               "    optimizer = load_model_custom_objects['optimizer']\n"
               "else:\n"
               "    optimizer = torch.optim.SGD\n",
               "Please use key: 'optimizer' and pass a pytorch optimizer as the value for a custom optimizer.");
   Log() << kINFO << "Loaded pytorch optimizer: " << Endl;


   // Setup the loss criterion
   PyRunString("criterion = load_model_custom_objects[\"criterion\"]",
               "Failed to load loss function from file. Using MSE Loss as default. Please use key: 'criterion' and pass a pytorch loss function as the value.");
   Log() << kINFO << "Loaded pytorch loss function: " << Endl;


   // Setup the predict method
   PyRunString("predict = load_model_custom_objects[\"predict_func\"]",
               "Can't find user predict function object from file. Please use key: 'predict_func' and pass a predict function for evaluating the model as the value.");
   Log() << kINFO << "Loaded pytorch predict function: " << Endl;


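   /*
    * Illustrative sketch (not part of this file): the Python code executed via the
    * UserCode option is expected to define `load_model_custom_objects` with the keys
    * used above and to save a TorchScript model under the FilenameModel path. All
    * names below (fit, predict, layer sizes, file name) are example choices, assuming
    * a simple binary classifier with 4 input variables:
    *
    *   import torch
    *   from torch import nn
    *
    *   model = nn.Sequential(nn.Linear(4, 64), nn.ReLU(),
    *                         nn.Linear(64, 2), nn.Softmax(dim=1))
    *
    *   def fit(model, train_loader, val_loader, num_epochs, batch_size,
    *           optimizer, criterion, save_best, scheduler):
    *       trainer = optimizer(model.parameters(), lr=0.01)
    *       schedule, schedulerSteps = scheduler
    *       best_val = None
    *       for epoch in range(num_epochs):
    *           if schedule:
    *               schedule(trainer, epoch, schedulerSteps)
    *           model.train()
    *           for x, y in train_loader:
    *               trainer.zero_grad()
    *               loss = criterion(model(x), y)
    *               loss.backward()
    *               trainer.step()
    *           model.eval()
    *           val_loss = sum(criterion(model(x), y).item() for x, y in val_loader)
    *           if save_best:
    *               best_val = save_best(model, val_loss,
    *                                    best_val if best_val is not None else val_loss)
    *       return model
    *
    *   def predict(model, test_X):
    *       model.eval()
    *       with torch.no_grad():
    *           return model(torch.Tensor(test_X)).numpy()
    *
    *   load_model_custom_objects = {"optimizer": torch.optim.SGD,
    *                                "criterion": nn.MSELoss(),
    *                                "train_func": fit,
    *                                "predict_func": predict}
    *
    *   torch.jit.save(torch.jit.script(model), "PyTorchModelClassification.pt")
    */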
   // Load already trained model or initial model
   TString filenameLoadModel;
   if (loadTrainedModel) {
      filenameLoadModel = fFilenameTrainedModel;
   }
   else {
      filenameLoadModel = fFilenameModel;
   }
   PyRunString("model = torch.jit.load('"+filenameLoadModel+"')",
               "Failed to load PyTorch model from file: "+filenameLoadModel);
   Log() << kINFO << "Loaded model from file: " << filenameLoadModel << Endl;


   /*
    * Init variables and weights
    */

   // Get variables, classes and target numbers
   fNVars = GetNVariables();
   if (GetAnalysisType() == Types::kClassification || GetAnalysisType() == Types::kMulticlass) fNOutputs = DataInfo().GetNClasses();
   else if (GetAnalysisType() == Types::kRegression) fNOutputs = DataInfo().GetNTargets();
   else Log() << kFATAL << "Selected analysis type is not implemented" << Endl;

   // Init evaluation (needed for getMvaValue)
   fVals = new float[fNVars];   // holds values used for classification and regression,
                                // exposed to the Python namespace as the numpy array "vals"

   fOutput.resize(fNOutputs);   // holds classification probabilities or regression output,
                                // exposed to the Python namespace as the numpy array "output"

   // Mark the model as setup
   fModelIsSetup = true;
}


void MethodPyTorch::Init() {

   TMVA::Internal::PyGILRAII raii;

   if (!PyIsInitialized()) {
      Log() << kFATAL << "Python is not initialized" << Endl;
   }
   _import_array(); // required to use numpy arrays

   // Import PyTorch
   PyRunString("import sys; sys.argv = ['']", "Set sys.argv failed");
   PyRunString("import torch", "import PyTorch failed");
   // do the import also in the global namespace
   auto ret = PyRun_String("import torch", Py_single_input, fGlobalNS, fGlobalNS);
   if (!ret)
      Log() << kFATAL << "import torch in global namespace failed!" << Endl;

   // Set flag that model is not setup
   fModelIsSetup = false;
}


void MethodPyTorch::Train() {
   if (!fModelIsSetup) Log() << kFATAL << "Model is not setup for training" << Endl;

   /*
    * Load training data to numpy array.
    * NOTE: These are later converted into torch tensors inside the training loop, which may not be the ideal method.
    */

   UInt_t nAllEvents = Data()->GetNTrainingEvents();
   UInt_t nValEvents = GetNumValidationSamples();
   UInt_t nTrainingEvents = nAllEvents - nValEvents;

   Log() << kINFO << "Split TMVA training data in " << nTrainingEvents << " training events and "
         << nValEvents << " validation events" << Endl;

   float* trainDataX = new float[nTrainingEvents*fNVars];
   float* trainDataY = new float[nTrainingEvents*fNOutputs];
   float* trainDataWeights = new float[nTrainingEvents];
   for (UInt_t i=0; i<nTrainingEvents; i++) {
      const TMVA::Event* e = GetTrainingEvent(i);
      // Fill variables
      for (UInt_t j=0; j<fNVars; j++) {
         trainDataX[j + i*fNVars] = e->GetValue(j);
      }
      // Fill targets
      // NOTE: For classification, convert the class number into a one-hot vector,
      // e.g., 1 -> [0, 1] or 0 -> [1, 0] for binary classification
      if (GetAnalysisType() == Types::kClassification || GetAnalysisType() == Types::kMulticlass) {
         for (UInt_t j=0; j<fNOutputs; j++) {
            trainDataY[j + i*fNOutputs] = 0;
         }
         trainDataY[e->GetClass() + i*fNOutputs] = 1;
      }
      else if (GetAnalysisType() == Types::kRegression) {
         for (UInt_t j=0; j<fNOutputs; j++) {
            trainDataY[j + i*fNOutputs] = e->GetTarget(j);
         }
      }
      else Log() << kFATAL << "Cannot fill target vector because analysis type is not known" << Endl;
      // Fill weights
      // NOTE: If no weight branch is given, this defaults to ones for all events
      trainDataWeights[i] = e->GetWeight();
   }

   // (The training arrays are exposed to the Python namespace as the numpy arrays
   //  "trainX" and "trainY", which the training code below refers to.)

   /*
    * Load validation data to numpy array
    */

   // NOTE: TMVA validation data is a subset of all the training data;
   // we will not use test data for validation. They will be used for the real testing.


   float* valDataX = new float[nValEvents*fNVars];
   float* valDataY = new float[nValEvents*fNOutputs];
   float* valDataWeights = new float[nValEvents];
   // validation events follow the training ones in the TMVA training vector
   for (UInt_t i=0; i< nValEvents ; i++) {
      UInt_t ievt = nTrainingEvents + i; // TMVA event index
      const TMVA::Event* e = GetTrainingEvent(ievt);
      // Fill variables
      for (UInt_t j=0; j<fNVars; j++) {
         valDataX[j + i*fNVars] = e->GetValue(j);
      }
      // Fill targets
      if (GetAnalysisType() == Types::kClassification || GetAnalysisType() == Types::kMulticlass) {
         for (UInt_t j=0; j<fNOutputs; j++) {
            valDataY[j + i*fNOutputs] = 0;
         }
         valDataY[e->GetClass() + i*fNOutputs] = 1;
      }
      else if (GetAnalysisType() == Types::kRegression) {
         for (UInt_t j=0; j<fNOutputs; j++) {
            valDataY[j + i*fNOutputs] = e->GetTarget(j);
         }
      }
      else Log() << kFATAL << "Cannot fill target vector because analysis type is not known" << Endl;
      // Fill weights
      valDataWeights[i] = e->GetWeight();
   }

   // (The validation arrays are likewise exposed to the Python namespace as the numpy
   //  arrays "valX" and "valY", which the code below refers to.)

   /*
    * Train PyTorch model
    */
   Log() << kINFO << "Print Training Model Architecture" << Endl;
   PyRunString("print(model)");

   // Setup parameters
   // (the Python names "batchSize" and "numEpochs" used below are filled from the
   //  BatchSize and NumEpochs options)

   // Prepare PyTorch Training DataSet
   PyRunString("train_dataset = torch.utils.data.TensorDataset(torch.Tensor(trainX), torch.Tensor(trainY))",
               "Failed to create pytorch train Dataset.");
   // Prepare PyTorch Training Dataloader
   PyRunString("train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchSize, shuffle=False)",
               "Failed to create pytorch train Dataloader.");


   // Prepare PyTorch Validation DataSet
   PyRunString("val_dataset = torch.utils.data.TensorDataset(torch.Tensor(valX), torch.Tensor(valY))",
               "Failed to create pytorch validation Dataset.");
   // Prepare PyTorch validation Dataloader
   PyRunString("val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batchSize, shuffle=False)",
               "Failed to create pytorch validation Dataloader.");


   // Learning Rate Scheduler
   if (fLearningRateSchedule!="") {
      // Setup a python dictionary with the desired learning rate steps
      PyRunString("strScheduleSteps = '"+fLearningRateSchedule+"'\n"
                  "schedulerSteps = {}\n"
                  "for c in strScheduleSteps.split(';'):\n"
                  "    x = c.split(',')\n"
                  "    schedulerSteps[int(x[0])] = float(x[1])\n",
                  "Failed to setup steps for scheduler function from string: "+fLearningRateSchedule,
                  Py_file_input);
      // Set scheduler function as piecewise function with given steps
      PyRunString("def schedule(optimizer, epoch, schedulerSteps=schedulerSteps):\n"
                  "    if epoch in schedulerSteps:\n"
                  "        for param_group in optimizer.param_groups:\n"
                  "            param_group['lr'] = float(schedulerSteps[epoch])\n",
                  "Failed to setup scheduler function with string: "+fLearningRateSchedule,
                  Py_file_input);

      Log() << kINFO << "Option LearningRateSchedule: Set learning rate during training: " << fLearningRateSchedule << Endl;
   }
   else {
      PyRunString("schedule = None; schedulerSteps = None", "Failed to set scheduler to None.");
   }
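   // Example (illustrative): LearningRateSchedule="50,0.01;70,0.005" is parsed into
   // schedulerSteps = {50: 0.01, 70: 0.005}, i.e. the learning rate is set to 0.01
   // when the training loop reaches epoch 50 and to 0.005 at epoch 70; all other
   // epochs keep the current rate.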


   // Save only weights with smallest validation loss
   if (fSaveBestOnly) {
      PyRunString("def save_best(model, curr_val, best_val, save_path='"+fFilenameTrainedModel+"'):\n"
                  "    if curr_val<=best_val:\n"
                  "        best_val = curr_val\n"
                  "        best_model_jitted = torch.jit.script(model)\n"
                  "        torch.jit.save(best_model_jitted, save_path)\n"
                  "    return best_val",
                  "Failed to setup training with option: SaveBestOnly");
      Log() << kINFO << "Option SaveBestOnly: Only model weights with smallest validation loss will be stored" << Endl;
   }
   else {
      PyRunString("save_best = None", "Failed to set save_best to None.");
   }


   // Note: Early Stopping should not be implemented here. Can be implemented inside train loop function by user if required.

   // Train model
   PyRunString("trained_model = fit(model, train_loader, val_loader, num_epochs=numEpochs, batch_size=batchSize,"
               "optimizer=optimizer, criterion=criterion, save_best=save_best, scheduler=(schedule, schedulerSteps))",
               "Failed to train model");


   // Note: PyTorch doesn't store training history data unlike Keras. A user can append and save the loss,
   // accuracy, other metrics etc. to a file for later use.

   /*
    * Store trained model to file (only if option 'SaveBestOnly' is NOT activated,
    * because we do not want to override the best model checkpoint)
    */
   if (!fSaveBestOnly) {
      PyRunString("trained_model_jitted = torch.jit.script(trained_model)",
                  "Model not scriptable. Failed to convert to torch script.");
      PyRunString("torch.jit.save(trained_model_jitted, '"+fFilenameTrainedModel+"')",
                  "Failed to save trained model: "+fFilenameTrainedModel);
      Log() << kINFO << "Trained model written to file: " << fFilenameTrainedModel << Endl;
   }

   /*
    * Clean-up
    */

   delete[] trainDataX;
   delete[] trainDataY;
   delete[] trainDataWeights;
   delete[] valDataX;
   delete[] valDataY;
   delete[] valDataWeights;
}
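/*
 * Illustrative usage note (not part of this file): with the method booked as in the
 * sketch after DeclareOptions() above, Train() is invoked through the standard TMVA
 * flow, e.g.
 *
 *   factory.TrainAllMethods()
 *
 * and the trained TorchScript model ends up in FilenameTrainedModel
 * (default: "weights/TrainedModel_<method name>.pt", see ProcessOptions()).
 */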


Double_t MethodPyTorch::GetMvaValue(Double_t *errLower, Double_t *errUpper) {

   // Cannot determine error
   NoErrorCalc(errLower, errUpper);

   // Check whether the model is setup
   // NOTE: unfortunately this is needed because during evaluation ProcessOptions is not called again
   if (!fModelIsSetup) {
      // Setup the trained model
      SetupPyTorchModel(true);
   }

   // Get signal probability (called mvaValue here)
   const TMVA::Event* e = GetEvent();
   for (UInt_t i=0; i<fNVars; i++) fVals[i] = e->GetValue(i);
   PyRunString("for i,p in enumerate(predict(model, vals)): output[i]=p\n",
               "Failed to get predictions");

   return fOutput[TMVA::Types::kSignal];
}


std::vector<Double_t> MethodPyTorch::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress) {
   // Check whether the model is setup
   // NOTE: Unfortunately this is needed because during evaluation ProcessOptions is not called again
   if (!fModelIsSetup) {
      // Setup the trained model
      SetupPyTorchModel(true);
   }

   // Load data to numpy array
   Long64_t nEvents = Data()->GetNEvents();
   if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
   if (firstEvt < 0) firstEvt = 0;
   nEvents = lastEvt-firstEvt;

   // use timer
   Timer timer( nEvents, GetName(), kTRUE );

   if (logProgress)
      Log() << kHEADER << Form("[%s] : ",DataInfo().GetName())
            << "Evaluation of " << GetMethodName() << " on "
            << (Data()->GetCurrentType() == Types::kTraining ? "training" : "testing")
            << " sample (" << nEvents << " events)" << Endl;

   float* data = new float[nEvents*fNVars];
   for (UInt_t i=0; i<nEvents; i++) {
      Data()->SetCurrentEvent(i);
      const TMVA::Event *e = GetEvent();
      for (UInt_t j=0; j<fNVars; j++) {
         data[j + i*fNVars] = e->GetValue(j);
      }
   }

   npy_intp dimsData[2] = {(npy_intp)nEvents, (npy_intp)fNVars};
   PyArrayObject* pDataMvaValues = (PyArrayObject*)PyArray_SimpleNewFromData(2, dimsData, NPY_FLOAT, (void*)data);
   if (pDataMvaValues==0) Log() << "Failed to load data to Python array" << Endl;


   // Get prediction for all events
   PyObject* pModel = PyDict_GetItemString(fLocalNS, "model");
   if (pModel==0) Log() << kFATAL << "Failed to get model Python object" << Endl;

   PyObject* pPredict = PyDict_GetItemString(fLocalNS, "predict");
   if (pPredict==0) Log() << kFATAL << "Failed to get Python predict function" << Endl;


   // Using PyTorch User Defined predict function for predictions
   PyArrayObject* pPredictions = (PyArrayObject*) PyObject_CallFunctionObjArgs(pPredict, pModel, (PyObject*)pDataMvaValues, NULL);
   if (pPredictions==0) Log() << kFATAL << "Failed to get predictions" << Endl;
   delete[] data;

   // Load predictions to double vector
   // NOTE: The signal probability is given at the output
   std::vector<double> mvaValues(nEvents);
   float* predictionsData = (float*) PyArray_DATA(pPredictions);
   for (UInt_t i=0; i<nEvents; i++) {
      mvaValues[i] = (double) predictionsData[i*fNOutputs + TMVA::Types::kSignal];
   }

   if (logProgress) {
      Log() << kINFO
            << "Elapsed time for evaluation of " << nEvents << " events: "
            << timer.GetElapsedTime() << " " << Endl;
   }

   return mvaValues;
}
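// Note on the interface assumed above (descriptive, based on how the predictions are
// consumed here): the user's predict function receives the model and a numpy array of
// shape (nEvents, nVars) and is expected to return an array of shape (nEvents, nOutputs);
// for classification the signal probability is read from the Types::kSignal (= 0) column.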

std::vector<Float_t>& MethodPyTorch::GetRegressionValues() {
   // Check whether the model is setup
   // NOTE: unfortunately this is needed because during evaluation ProcessOptions is not called again
   if (!fModelIsSetup) {
      // Setup the model and load weights
      SetupPyTorchModel(true);
   }

   // Get regression values
   const TMVA::Event* e = GetEvent();
   for (UInt_t i=0; i<fNVars; i++) fVals[i] = e->GetValue(i);

   PyRunString("for i,p in enumerate(predict(model, vals)): output[i]=p\n",
               "Failed to get predictions");


   // Use inverse transformation of targets to get final regression values
   // (the targets may have been transformed by the preprocessing chain, so the network
   //  output is mapped back to the original target scale before being returned)
   Event * eTrans = new Event(*e);
   for (UInt_t i=0; i<fNOutputs; ++i) {
      eTrans->SetTarget(i,fOutput[i]);
   }

   const Event* eTrans2 = GetTransformationHandler().InverseTransform(eTrans, kFALSE);
   for (UInt_t i=0; i<fNOutputs; ++i) {
      fOutput[i] = eTrans2->GetTarget(i);
   }

   return fOutput;
}

std::vector<Float_t>& MethodPyTorch::GetMulticlassValues() {
   // Check whether the model is setup
   // NOTE: unfortunately this is needed because during evaluation ProcessOptions is not called again
   if (!fModelIsSetup) {
      // Setup the model and load weights
      SetupPyTorchModel(true);
   }

   // Get class probabilities
   const TMVA::Event* e = GetEvent();
   for (UInt_t i=0; i<fNVars; i++) fVals[i] = e->GetValue(i);
   PyRunString("for i,p in enumerate(predict(model, vals)): output[i]=p\n",
               "Failed to get predictions");

   return fOutput;
}


void MethodPyTorch::TestClassification() {
   MethodBase::TestClassification();
}


void MethodPyTorch::GetHelpMessage() const {
   Log() << Endl;
   Log() << "PyTorch is a scientific computing package supporting" << Endl;
   Log() << "automatic differentiation. This method wraps the training" << Endl;
   Log() << "and prediction steps of the PyTorch Python package for" << Endl;
   Log() << "TMVA, so that data loading, preprocessing and evaluation" << Endl;
   Log() << "can be done within the TMVA system. To use this PyTorch" << Endl;
   Log() << "interface, you need to generate a model with PyTorch first." << Endl;
   Log() << "Then, this model can be loaded and trained in TMVA." << Endl;
   Log() << Endl;
}