Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodPyKeras.h
Go to the documentation of this file.
1// @(#)root/tmva/pymva $Id$
2// Author: Stefan Wunsch
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodPyKeras *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Interface for Keras python package which is a wrapper for the Theano and *
12 * Tensorflow libraries *
13 * *
14 * Authors (alphabetical): *
15 * Stefan Wunsch <stefan.wunsch@cern.ch> - KIT, Germany *
16 * *
17 * Copyright (c) 2016: *
18 * CERN, Switzerland *
19 * KIT, Germany *
20 * *
21 * Redistribution and use in source and binary forms, with or without *
22 * modification, are permitted according to the terms listed in LICENSE *
23 * (http://tmva.sourceforge.net/LICENSE) *
24 **********************************************************************************/
25
26#ifndef ROOT_TMVA_MethodPyKeras
27#define ROOT_TMVA_MethodPyKeras
28
29#include "TMVA/PyMethodBase.h"
30#include <vector>
31
32namespace TMVA {
33
34 class MethodPyKeras : public PyMethodBase {
35
36 public :
37
38 // constructors
39 MethodPyKeras(const TString &jobName,
40 const TString &methodTitle,
41 DataSetInfo &dsi,
42 const TString &theOption = "");
44 const TString &theWeightFile);
46
47 void Train();
48 void Init();
49 void DeclareOptions();
50 void ProcessOptions();
51
52 // Check whether the given analysis type (regression, classification, ...)
53 // is supported by this method
55 // Get signal probability of given event
56 Double_t GetMvaValue(Double_t *errLower, Double_t *errUpper);
57 std::vector<Double_t> GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress);
58 // Get regression values of given event
59 std::vector<Float_t>& GetRegressionValues();
60 // Get class probabilities of given event
61 std::vector<Float_t>& GetMulticlassValues();
62
63 const Ranking *CreateRanking() { return nullptr; }
64 virtual void TestClassification();
65 virtual void AddWeightsXMLTo(void*) const{}
66 virtual void ReadWeightsFromXML(void*){}
67 virtual void ReadWeightsFromStream(std::istream&) {} // backward compatibility
68 virtual void ReadWeightsFromStream(TFile&){} // backward compatibility
69 void ReadModelFromFile();
70
71 void GetHelpMessage() const;
72
73 /// enumeration defining the used Keras backend
74 enum EBackendType { kUndefined = -1, kTensorFlow = 0, kTheano = 1, kCNTK = 2 };
75
76 /// Get the Keras backend (can be: TensorFlow, Theano or CNTK)
79 // flag to indicate we are using the Keras shipped with Tensorflow 2
80 Bool_t UseTFKeras() const { return fUseTFKeras; }
81
82 private:
83
84 TString fFilenameModel; // Filename of the previously exported Keras model
85 UInt_t fBatchSize {0}; // Training batch size
86 UInt_t fNumEpochs {0}; // Number of training epochs
87 Int_t fNumThreads {0}; // Number of CPU threads (if 0 uses default values)
88 Int_t fVerbose; // Keras verbosity during training
89 Bool_t fUseTFKeras { true}; // use Keras from Tensorflow default is true
90 Bool_t fContinueTraining; // Load weights from previous training
91 Bool_t fSaveBestOnly; // Store only weights with smallest validation loss
92 Int_t fTriesEarlyStopping; // Stop training if validation loss is not decreasing for several epochs
93 TString fLearningRateSchedule; // Set new learning rate at specific epochs
94 TString fTensorBoard; // Store log files during training
95 TString fNumValidationString; // option string defining the number of validation events
96 TString fGpuOptions; // GPU options (for Tensorflow to set in session_config.gpu_options)
97 TString fUserCodeName; // filename of an optional user script that will be executed before loading the Keras model
98 TString fKerasString; // string identifying keras or tf.keras
99
100 bool fModelIsSetup = false; // flag whether current model is setup for being used
101 bool fModelIsSetupForEval = false; // flag to indicate whether model is setup for evaluation
102 std::vector<float> fVals; // variables array used for GetMvaValue
103 std::vector<float> fOutput; // probability or regression output array used for GetMvaValue
104 UInt_t fNVars {0}; // number of variables
105 UInt_t fNOutputs {0}; // number of outputs (classes or targets)
106 TString fFilenameTrainedModel; // output filename for trained model
107
108 void InitKeras(); // initialize Keras (importing the readed modules)
109 void SetupKerasModel(Bool_t loadTrainedModel); // setups the needed variables, loads the model
110 void SetupKerasModelForEval(); // optimizes model for evaluation
111 UInt_t GetNumValidationSamples(); // get number of validation events according to given option
112
114 };
115
116} // namespace TMVA
117
118#endif // ROOT_TMVA_MethodPyKeras
long long Long64_t
Definition RtypesCore.h:80
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
A ROOT file is composed of a header, followed by consecutive data records (TKey instances) with a wel...
Definition TFile.h:53
Class that contains all the data information.
Definition DataSetInfo.h:62
void GetHelpMessage() const
void Init()
Initialization function called from MethodBase::SetupMethod(). Note that option strings are not yet fil...
std::vector< float > fOutput
virtual void AddWeightsXMLTo(void *) const
ClassDef(MethodPyKeras, 0)
virtual void TestClassification()
initialization
void ProcessOptions()
Function processing the options. This is called only when creating the method before training, not when...
Bool_t UseTFKeras() const
virtual void ReadWeightsFromStream(std::istream &)
virtual void ReadWeightsFromXML(void *)
EBackendType
Enumeration defining the Keras backend in use.
void SetupKerasModel(Bool_t loadTrainedModel)
std::vector< Float_t > & GetMulticlassValues()
UInt_t GetNumValidationSamples()
Validation of the ValidationSize option.
Double_t GetMvaValue(Double_t *errLower, Double_t *errUpper)
void SetupKerasModelForEval()
Sets up the model for evaluation. Add any needed optimizations here, such as disabling eager execution.
std::vector< Float_t > & GetRegressionValues()
const Ranking * CreateRanking()
std::vector< float > fVals
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
TString fLearningRateSchedule
EBackendType GetKerasBackend()
Get the Keras backend (can be: TensorFlow, Theano or CNTK)
std::vector< Double_t > GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
Get all the MVA values for the events of the current data type.
virtual void ReadWeightsFromStream(TFile &)
Ranking for variables in method (implementation)
Definition Ranking.h:48
Basic string class.
Definition TString.h:139
create variable transformations