Logo ROOT  
Reference Guide
MethodPyKeras.h
Go to the documentation of this file.
1 // @(#)root/tmva/pymva $Id$
2 // Author: Stefan Wunsch
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodPyKeras *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Interface for Keras python package which is a wrapper for the Theano, *
12  * Tensorflow and CNTK libraries *
13  * *
14  * Authors (alphabetical): *
15  * Stefan Wunsch <stefan.wunsch@cern.ch> - KIT, Germany *
16  * *
17  * Copyright (c) 2016: *
18  * CERN, Switzerland *
19  * KIT, Germany *
20  * *
21  * Redistribution and use in source and binary forms, with or without *
22  * modification, are permitted according to the terms listed in LICENSE *
23  * (http://tmva.sourceforge.net/LICENSE) *
24  **********************************************************************************/
25 
26 #ifndef ROOT_TMVA_MethodPyKeras
27 #define ROOT_TMVA_MethodPyKeras
28 
29 #include "TMVA/PyMethodBase.h"
30 #include <vector>
31 
32 namespace TMVA {
33 
34  class MethodPyKeras : public PyMethodBase {
35 
36  public :
37 
38  // constructors
39  MethodPyKeras(const TString &jobName,
40  const TString &methodTitle,
41  DataSetInfo &dsi,
42  const TString &theOption = "");
44  const TString &theWeightFile);
46 
47  void Train();
48  void Init();
49  void DeclareOptions();
50  void ProcessOptions();
51 
52  // Check whether the given analysis type (regression, classification, ...)
53  // is supported by this method
55  // Get signal probability of given event
56  Double_t GetMvaValue(Double_t *errLower, Double_t *errUpper);
57  std::vector<Double_t> GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress);
58  // Get regression values of given event
59  std::vector<Float_t>& GetRegressionValues();
60  // Get class probabilities of given event
61  std::vector<Float_t>& GetMulticlassValues();
62 
63  const Ranking *CreateRanking() { return 0; }
64  virtual void TestClassification();
65  virtual void AddWeightsXMLTo(void*) const{}
66  virtual void ReadWeightsFromXML(void*){}
67  virtual void ReadWeightsFromStream(std::istream&) {} // backward compatibility
68  virtual void ReadWeightsFromStream(TFile&){} // backward compatibility
69  void ReadModelFromFile();
70 
71  void GetHelpMessage() const;
72 
73  /// enumeration defining the used Keras backend
74  enum EBackendType { kUndefined = -1, kTensorFlow = 0, kTheano = 1, kCNTK = 2 };
75 
76  /// Get the Keras backend (can be: TensorFlow, Theano or CNTK)
79  // flag to indicate we are using the Keras shipped with Tensorflow 2
80  Bool_t UseTFKeras() const { return fUseTFKeras; }
81 
82  private:
83 
84  TString fFilenameModel; // Filename of the previously exported Keras model
85  UInt_t fBatchSize {0}; // Training batch size
86  UInt_t fNumEpochs {0}; // Number of training epochs
87  Int_t fNumThreads {0}; // Number of CPU threads (if 0 uses default values)
88  Int_t fVerbose; // Keras verbosity during training
89  Bool_t fUseTFKeras { kFALSE}; // use Keras from Tensorflow (-1, default, 0 false, 1, true)
90  Bool_t fContinueTraining; // Load weights from previous training
91  Bool_t fSaveBestOnly; // Store only weights with smallest validation loss
92  Int_t fTriesEarlyStopping; // Stop training if validation loss is not decreasing for several epochs
93  TString fLearningRateSchedule; // Set new learning rate at specific epochs
94  TString fTensorBoard; // Store log files during training
95  TString fNumValidationString; // option string defining the number of validation events
96  TString fGpuOptions; // GPU options (for Tensorflow to set in session_config.gpu_options)
97  TString fUserCodeName; // filename of an optional user script that will be executed before loading the Keras model
98  TString fKerasString; // string identifying keras or tf.keras
99 
100  bool fModelIsSetup = false; // flag whether model is loaded, needed for getMvaValue during evaluation
101  float* fVals = nullptr; // variables array used for GetMvaValue
102  std::vector<float> fOutput; // probability or regression output array used for GetMvaValue
103  UInt_t fNVars {0}; // number of variables
104  UInt_t fNOutputs {0}; // number of outputs (classes or targets)
105  TString fFilenameTrainedModel; // output filename for trained model
106 
107  void SetupKerasModel(Bool_t loadTrainedModel); // setups the needed variables, loads the model
108  UInt_t GetNumValidationSamples(); // get number of validation events according to given option
109 
111  };
112 
113 } // namespace TMVA
114 
115 #endif // ROOT_TMVA_MethodPyKeras
TMVA::MethodPyKeras::GetMvaValue
Double_t GetMvaValue(Double_t *errLower, Double_t *errUpper)
Definition: MethodPyKeras.cxx:630
TMVA::MethodPyKeras::GetMulticlassValues
std::vector< Float_t > & GetMulticlassValues()
Definition: MethodPyKeras.cxx:739
TMVA::MethodPyKeras::SetupKerasModel
void SetupKerasModel(Bool_t loadTrainedModel)
Definition: MethodPyKeras.cxx:176
TMVA::MethodPyKeras::fTriesEarlyStopping
Int_t fTriesEarlyStopping
Definition: MethodPyKeras.h:92
TMVA::PyMethodBase
Definition: PyMethodBase.h:56
TMVA::MethodPyKeras::fVals
float * fVals
Definition: MethodPyKeras.h:101
TMVA::MethodPyKeras::fBatchSize
UInt_t fBatchSize
Definition: MethodPyKeras.h:85
TMVA::MethodPyKeras::fUserCodeName
TString fUserCodeName
Definition: MethodPyKeras.h:97
TMVA::MethodPyKeras::fNumEpochs
UInt_t fNumEpochs
Definition: MethodPyKeras.h:86
TMVA::Ranking
Ranking for variables in method (implementation)
Definition: Ranking.h:48
TMVA::MethodPyKeras::fNumThreads
Int_t fNumThreads
Definition: MethodPyKeras.h:87
Long64_t
long long Long64_t
Definition: RtypesCore.h:80
TMVA::MethodPyKeras::fUseTFKeras
Bool_t fUseTFKeras
Definition: MethodPyKeras.h:89
TMVA::MethodPyKeras::ProcessOptions
void ProcessOptions()
Function processing the options This is called only when creating the method before training not when...
Definition: MethodPyKeras.cxx:163
TMVA::MethodPyKeras::GetMvaValues
std::vector< Double_t > GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
get all the MVA values for the events of the current Data type
Definition: MethodPyKeras.cxx:650
TMVA::MethodPyKeras::fNVars
UInt_t fNVars
Definition: MethodPyKeras.h:103
TMVA::MethodPyKeras::fFilenameModel
TString fFilenameModel
Definition: MethodPyKeras.h:84
TMVA::MethodPyKeras::GetKerasBackend
EBackendType GetKerasBackend()
Get the Keras backend (can be: TensorFlow, Theano or CNTK)
Definition: MethodPyKeras.cxx:772
TMVA::MethodPyKeras::fKerasString
TString fKerasString
Definition: MethodPyKeras.h:98
TMVA::MethodPyKeras::fContinueTraining
Bool_t fContinueTraining
Definition: MethodPyKeras.h:90
TString
Basic string class.
Definition: TString.h:136
bool
TMVA::MethodPyKeras::fGpuOptions
TString fGpuOptions
Definition: MethodPyKeras.h:96
TMVA::MethodPyKeras::MethodPyKeras
MethodPyKeras(const TString &jobName, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
Definition: MethodPyKeras.cxx:38
TMVA::MethodPyKeras::UseTFKeras
Bool_t UseTFKeras() const
Definition: MethodPyKeras.h:80
TMVA::MethodPyKeras::GetNumValidationSamples
UInt_t GetNumValidationSamples()
Validation of the ValidationSize option.
Definition: MethodPyKeras.cxx:108
TMVA::DataSetInfo
Class that contains all the data information.
Definition: DataSetInfo.h:62
TMVA::MethodPyKeras::Init
void Init()
Initialization function called from MethodBase::SetupMethod() Note that option string are not yet fil...
Definition: MethodPyKeras.cxx:382
TMVA::MethodPyKeras::kUndefined
@ kUndefined
Definition: MethodPyKeras.h:74
TMVA::MethodPyKeras::GetKerasBackendName
TString GetKerasBackendName()
Definition: MethodPyKeras.cxx:797
TMVA::Types::EAnalysisType
EAnalysisType
Definition: Types.h:128
TMVA::MethodPyKeras::ReadModelFromFile
void ReadModelFromFile()
Definition: MethodPyKeras.cxx:756
kFALSE
const Bool_t kFALSE
Definition: RtypesCore.h:101
TMVA::MethodPyKeras::~MethodPyKeras
~MethodPyKeras()
Definition: MethodPyKeras.cxx:65
TMVA::MethodPyKeras::ReadWeightsFromXML
virtual void ReadWeightsFromXML(void *)
Definition: MethodPyKeras.h:66
TFile
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:54
TMVA::MethodPyKeras
Definition: MethodPyKeras.h:34
unsigned int
TMVA::MethodPyKeras::fNOutputs
UInt_t fNOutputs
Definition: MethodPyKeras.h:104
TMVA::MethodPyKeras::ReadWeightsFromStream
virtual void ReadWeightsFromStream(TFile &)
Definition: MethodPyKeras.h:68
TMVA::MethodPyKeras::CreateRanking
const Ranking * CreateRanking()
Definition: MethodPyKeras.h:63
Double_t
double Double_t
Definition: RtypesCore.h:59
TMVA::MethodPyKeras::fVerbose
Int_t fVerbose
Definition: MethodPyKeras.h:88
TMVA::MethodPyKeras::fSaveBestOnly
Bool_t fSaveBestOnly
Definition: MethodPyKeras.h:91
TMVA::MethodPyKeras::TestClassification
virtual void TestClassification()
initialization
Definition: MethodPyKeras.cxx:626
ClassDef
#define ClassDef(name, id)
Definition: Rtypes.h:325
TMVA::MethodPyKeras::AddWeightsXMLTo
virtual void AddWeightsXMLTo(void *) const
Definition: MethodPyKeras.h:65
TMVA::MethodPyKeras::EBackendType
EBackendType
enumeration defining the used Keras backend
Definition: MethodPyKeras.h:74
TMVA::MethodPyKeras::kCNTK
@ kCNTK
Definition: MethodPyKeras.h:74
TMVA::MethodPyKeras::GetHelpMessage
void GetHelpMessage() const
Definition: MethodPyKeras.cxx:759
TMVA::MethodPyKeras::fFilenameTrainedModel
TString fFilenameTrainedModel
Definition: MethodPyKeras.h:105
type
int type
Definition: TGX11.cxx:121
PyMethodBase.h
TMVA::MethodPyKeras::fModelIsSetup
bool fModelIsSetup
Definition: MethodPyKeras.h:100
TMVA::MethodPyKeras::GetRegressionValues
std::vector< Float_t > & GetRegressionValues()
Definition: MethodPyKeras.cxx:711
TMVA::MethodPyKeras::fOutput
std::vector< float > fOutput
Definition: MethodPyKeras.h:102
TMVA::MethodPyKeras::fNumValidationString
TString fNumValidationString
Definition: MethodPyKeras.h:95
TMVA::MethodPyKeras::HasAnalysisType
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
Definition: MethodPyKeras.cxx:68
TMVA::MethodPyKeras::DeclareOptions
void DeclareOptions()
Definition: MethodPyKeras.cxx:77
TMVA::MethodPyKeras::kTensorFlow
@ kTensorFlow
Definition: MethodPyKeras.h:74
TMVA
create variable transformations
Definition: GeneticMinimizer.h:22
TMVA::MethodPyKeras::Train
void Train()
Definition: MethodPyKeras.cxx:398
int
TMVA::MethodPyKeras::fLearningRateSchedule
TString fLearningRateSchedule
Definition: MethodPyKeras.h:93
TMVA::MethodPyKeras::ReadWeightsFromStream
virtual void ReadWeightsFromStream(std::istream &)
Definition: MethodPyKeras.h:67
TMVA::MethodPyKeras::fTensorBoard
TString fTensorBoard
Definition: MethodPyKeras.h:94
TMVA::MethodPyKeras::kTheano
@ kTheano
Definition: MethodPyKeras.h:74