Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodPyKeras.h
Go to the documentation of this file.
// @(#)root/tmva/pymva $Id$
// Author: Stefan Wunsch

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
 * Package: TMVA *
 * Class : MethodPyKeras *
 * *
 * *
 * Description: *
 * Interface for Keras python package which is a wrapper for the Theano and *
 * Tensorflow libraries *
 * *
 * Authors (alphabetical): *
 * Stefan Wunsch <stefan.wunsch@cern.ch> - KIT, Germany *
 * *
 * Copyright (c) 2016: *
 * CERN, Switzerland *
 * KIT, Germany *
 * *
 * Redistribution and use in source and binary forms, with or without *
 * modification, are permitted according to the terms listed in LICENSE *
 * (see tmva/doc/LICENSE) *
 **********************************************************************************/
25
26#ifndef ROOT_TMVA_MethodPyKeras
27#define ROOT_TMVA_MethodPyKeras
28
29#include "TMVA/PyMethodBase.h"
30#include <vector>
31
32//class PyArrayObject;
33
34namespace TMVA {
35
36 class MethodPyKeras : public PyMethodBase {
37
38 public :
39
40 // constructors
42 const TString &methodTitle,
44 const TString &theOption = "");
46 const TString &theWeightFile);
48
49 void Train() override;
50 void Init() override;
51 void DeclareOptions() override;
52 void ProcessOptions() override;
53
54 // Check whether the given analysis type (regression, classification, ...)
55 // is supported by this method
57 // Get signal probability of given event
59 std::vector<Double_t> GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress) override;
60 // Get regression values of given event
61 std::vector<Float_t>& GetRegressionValues() override;
62 // Get all regression values for all the events in the data set
63 std::vector<Float_t> GetAllRegressionValues() override;
64 // Get class probabilities of given event
65 std::vector<Float_t>& GetMulticlassValues() override;
66 // Get all multiclass values for all the events in the data set
67 std::vector<Float_t> GetAllMulticlassValues() override;
68
69 const Ranking *CreateRanking() override { return nullptr; }
70 void TestClassification() override;
71 void AddWeightsXMLTo(void*) const override{}
72 void ReadWeightsFromXML(void*) override{}
73 void ReadWeightsFromStream(std::istream&) override {} // backward compatibility
74 void ReadWeightsFromStream(TFile&) override{} // backward compatibility
75 void ReadModelFromFile() override;
76
77 void GetHelpMessage() const override;
78
79 /// enumeration defining the used Keras backend
80 enum EBackendType { kUndefined = -1, kTensorFlow = 0, kTheano = 1, kCNTK = 2 };
81
82 /// Get the Keras backend (can be: TensorFlow, Theano or CNTK)
85 // flag to indicate we are using the Keras shipped with Tensorflow 2
86 Bool_t UseTFKeras() const { return fUseTFKeras; }
87
88 private:
89
90 TString fFilenameModel; // Filename of the previously exported Keras model
91 UInt_t fBatchSize {0}; // Training batch size
92 UInt_t fNumEpochs {0}; // Number of training epochs
93 Int_t fNumThreads {0}; // Number of CPU threads (if 0 uses default values)
94 Int_t fVerbose; // Keras verbosity during training
95 Bool_t fUseTFKeras { true}; // use Keras from Tensorflow default is true
96 Bool_t fContinueTraining; // Load weights from previous training
97 Bool_t fSaveBestOnly; // Store only weights with smallest validation loss
98 Int_t fTriesEarlyStopping; // Stop training if validation loss is not decreasing for several epochs
99 TString fLearningRateSchedule; // Set new learning rate at specific epochs
100 TString fTensorBoard; // Store log files during training
101 TString fNumValidationString; // option string defining the number of validation events
102 TString fGpuOptions; // GPU options (for Tensorflow to set in session_config.gpu_options)
103 TString fUserCodeName; // filename of an optional user script that will be executed before loading the Keras model
104 TString fKerasString; // string identifying keras or tf.keras
105
106 bool fUseKeras3 = false; // use new Keras API (available from TF 2.16)
107 bool fModelIsSetup = false; // flag whether current model is setup for being used
108 bool fModelIsSetupForEval = false; // flag to indicate whether model is setup for evaluation
109 std::vector<float> fVals; // variables array used for GetMvaValue
110 std::vector<float> fOutput; // probability or regression output array used for GetMvaValue
111 UInt_t fNVars {0}; // number of variables
112 UInt_t fNOutputs {0}; // number of outputs (classes or targets)
113 TString fFilenameTrainedModel; // output filename for trained model
114 PyObject * fPyVals = nullptr; // Python array object for input data
115 PyObject * fPyOutput = nullptr; // Python array object for output data
116
117
118 void InitKeras(); // initialize Keras (importing the readed modules)
119 void SetupKerasModel(Bool_t loadTrainedModel); // setups the needed variables, loads the model
120 void SetupKerasModelForEval(); // optimizes model for evaluation
121 void InitEvaluation(size_t nEvents); // allocate arrays for evaluation
122 UInt_t GetNumValidationSamples(); // get number of validation events according to given option
123
125 };
126
127} // namespace TMVA
128
129#endif // ROOT_TMVA_MethodPyKeras
_object PyObject
long long Long64_t
Portable signed long integer 8 bytes.
Definition RtypesCore.h:83
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
Definition TFile.h:130
Class that contains all the data information.
Definition DataSetInfo.h:62
std::vector< Float_t > GetAllRegressionValues() override
Get all regression values in one call.
void GetHelpMessage() const override
void ProcessOptions() override
Function processing the options This is called only when creating the method before training not when...
std::vector< float > fOutput
std::vector< Float_t > & GetRegressionValues() override
const Ranking * CreateRanking() override
void AddWeightsXMLTo(void *) const override
Bool_t UseTFKeras() const
void ReadWeightsFromXML(void *) override
void ReadWeightsFromStream(TFile &) override
EBackendType
enumeration defining the used Keras backend
void SetupKerasModel(Bool_t loadTrainedModel)
Double_t GetMvaValue(Double_t *errLower, Double_t *errUpper0) override
void DeclareOptions() override
UInt_t GetNumValidationSamples()
Validation of the ValidationSize option.
void TestClassification() override
initialization
void SetupKerasModelForEval()
Setting up model for evaluation Add here some needed optimizations like disabling eager execution.
ClassDefOverride(MethodPyKeras, 0)
std::vector< Float_t > GetAllMulticlassValues() override
Get all multi-class values.
std::vector< float > fVals
std::vector< Double_t > GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress) override
get all the MVA values for the events of the current Data type
MethodPyKeras(const TString &jobName, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
void Train() override
TString fLearningRateSchedule
std::vector< Float_t > & GetMulticlassValues() override
void ReadWeightsFromStream(std::istream &) override
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t) override
void InitEvaluation(size_t nEvents)
void Init() override
Initialization function called from MethodBase::SetupMethod() Note that option string are not yet fil...
void ReadModelFromFile() override
EBackendType GetKerasBackend()
Get the Keras backend (can be: TensorFlow, Theano or CNTK)
Virtual base class for all TMVA method based on Python.
Ranking for variables in method (implementation)
Definition Ranking.h:48
Basic string class.
Definition TString.h:138
create variable transformations