Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodPyTorch.h
Go to the documentation of this file.
// @(#)root/tmva/pymva $Id$
// Author: Anirudh Dagar, 2020

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
 * Package: TMVA *
 * Class : MethodPyTorch *
 * *
 * *
 * Description: *
 * Interface for the PyTorch Python-based scientific package supporting *
 * automatic differentiation for machine learning. *
 * *
 * Authors (alphabetical): *
 * Anirudh Dagar <anirudhdagar6@gmail.com> - IIT, Roorkee *
 * *
 * Copyright (c) 2020: *
 * CERN, Switzerland *
 * IIT, Roorkee *
 * *
 * Redistribution and use in source and binary forms, with or without *
 * modification, are permitted according to the terms listed in LICENSE *
 * (see tmva/doc/LICENSE) *
 **********************************************************************************/

26#ifndef ROOT_TMVA_MethodPyTorch
27#define ROOT_TMVA_MethodPyTorch
28
29#include "TMVA/PyMethodBase.h"
30#include <vector>
31
32namespace TMVA {
33
34 class MethodPyTorch : public PyMethodBase {
35
36 public :
37
38 // constructors
40 const TString &methodTitle,
42 const TString &theOption = "");
44 const TString &theWeightFile);
46
47 void Train() override;
48 void Init() override;
49 void DeclareOptions() override;
50 void ProcessOptions() override;
51
52 // Check whether the given analysis type (regression, classification, ...)
53 // is supported by this method
55 // Get signal probability of given event
57 std::vector<Double_t> GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress) override;
58 // Get regression values of given event
59 std::vector<Float_t>& GetRegressionValues() override;
60 // Get class probabilities of given event
61 std::vector<Float_t>& GetMulticlassValues() override;
62
63 const Ranking *CreateRanking() override { return nullptr; }
64 void TestClassification() override;
65 void AddWeightsXMLTo(void*) const override{}
66 void ReadWeightsFromXML(void*) override{}
67 void ReadWeightsFromStream(std::istream&) override {} // backward compatibility
68 void ReadWeightsFromStream(TFile&) override{} // backward compatibility
69 void ReadModelFromFile() override;
70
71 void GetHelpMessage() const override;
72
73
74 private:
75
76 TString fFilenameModel; // Filename of the previously exported PyTorch model
77 UInt_t fBatchSize {0}; // Training batch size
78 UInt_t fNumEpochs {0}; // Number of training epochs
79 Int_t fNumThreads {0}; // Number of CPU threads (if 0 uses default values)
80
81 Bool_t fContinueTraining; // Load weights from previous training
82 Bool_t fSaveBestOnly; // Store only weights with smallest validation loss
83 TString fLearningRateSchedule; // Set new learning rate at specific epochs
84
85 TString fNumValidationString; // option string defining the number of validation events
86
87 TString fUserCodeName; // filename of the user script that will be executed before loading the PyTorch model
88
89 bool fModelIsSetup = false; // flag whether model is loaded, needed for getMvaValue during evaluation
90 float* fVals = nullptr; // variables array used for GetMvaValue
91 std::vector<float> fOutput; // probability or regression output array used for GetMvaValue
92 UInt_t fNVars {0}; // number of variables
93 UInt_t fNOutputs {0}; // number of outputs (classes or targets)
94 TString fFilenameTrainedModel; // output filename for trained model
95
96 void SetupPyTorchModel(Bool_t loadTrainedModel); // setups the needed variables, loads the model
97 UInt_t GetNumValidationSamples(); // get number of validation events according to given option
98
100 };
101
102} // namespace TMVA
103
104#endif // ROOT_TMVA_MethodPyTorch
long long Long64_t
Portable signed long integer 8 bytes.
Definition RtypesCore.h:83
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
Definition TFile.h:131
Class that contains all the data information.
Definition DataSetInfo.h:62
void ReadWeightsFromStream(std::istream &) override
void GetHelpMessage() const override
void Train() override
void Init() override
const Ranking * CreateRanking() override
Double_t GetMvaValue(Double_t *errLower, Double_t *errUpper) override
std::vector< Float_t > & GetRegressionValues() override
void ProcessOptions() override
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t) override
std::vector< float > fOutput
void ReadModelFromFile() override
MethodPyTorch(const TString &jobName, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
ClassDefOverride(MethodPyTorch, 0)
void ReadWeightsFromStream(TFile &) override
std::vector< Double_t > GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress) override
get all the MVA values for the events of the current Data type
void ReadWeightsFromXML(void *) override
std::vector< Float_t > & GetMulticlassValues() override
void TestClassification() override
initialization
UInt_t GetNumValidationSamples()
Validation of the ValidationSize option.
TString fLearningRateSchedule
void AddWeightsXMLTo(void *) const override
TString fFilenameTrainedModel
void SetupPyTorchModel(Bool_t loadTrainedModel)
void DeclareOptions() override
Virtual base class for all TMVA method based on Python.
Ranking for variables in method (implementation)
Definition Ranking.h:48
Basic string class.
Definition TString.h:138
create variable transformations