Logo ROOT   6.12/07
Reference Guide
MethodPyKeras.h
Go to the documentation of this file.
1 // @(#)root/tmva/pymva $Id$
2 // Author: Stefan Wunsch
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodPyKeras *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Interface for Keras python package which is a wrapper for the Theano and *
12  * Tensorflow libraries *
13  * *
14  * Authors (alphabetical): *
15  * Stefan Wunsch <stefan.wunsch@cern.ch> - KIT, Germany *
16  * *
17  * Copyright (c) 2016: *
18  * CERN, Switzerland *
19  * KIT, Germany *
20  * *
21  * Redistribution and use in source and binary forms, with or without *
22  * modification, are permitted according to the terms listed in LICENSE *
23  * (http://tmva.sourceforge.net/LICENSE) *
24  **********************************************************************************/
25 
26 #ifndef ROOT_TMVA_MethodPyKeras
27 #define ROOT_TMVA_MethodPyKeras
28 
29 #include "TMVA/PyMethodBase.h"
30 
31 namespace TMVA {
32 
33  class MethodPyKeras : public PyMethodBase {
34  // Interface to the Keras python package, exposed as a TMVA method derived from PyMethodBase.
35  public :
36 
37  // constructors
38  MethodPyKeras(const TString &jobName,
39  const TString &methodTitle,
40  DataSetInfo &dsi,
41  const TString &theOption = "");
43  const TString &theWeightFile); // NOTE(review): tail of a second constructor; listing line 42 (its first line) is missing from this extract
45  // The following member functions are only declared here; their definitions are not part of this header
46  void Train();
47  void Init();
48  void DeclareOptions();
49  void ProcessOptions();
50 
51  // Check whether the given analysis type (regression, classification, ...)
52  // is supported by this method. NOTE(review): listing line 53, which should hold the declaration (HasAnalysisType per the page index), is missing from this extract
54  // Get signal probability of given event
55  Double_t GetMvaValue(Double_t *errLower, Double_t *errUpper);
56  std::vector<Double_t> GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress);
57  // Get regression values of given event
58  std::vector<Float_t>& GetRegressionValues();
59  // Get class probabilities of given event
60  std::vector<Float_t>& GetMulticlassValues();
61 
62  const Ranking *CreateRanking() { return 0; } // always returns a null pointer: no Ranking object is created here
63  virtual void TestClassification();
64  virtual void AddWeightsXMLTo(void*) const{} // empty body: no-op
65  virtual void ReadWeightsFromXML(void*){} // empty body: no-op
66  virtual void ReadWeightsFromStream(std::istream&) {} // empty body, kept for backward compatibility
67  virtual void ReadWeightsFromStream(TFile&){} // empty body, kept for backward compatibility
68  void ReadModelFromFile();
69 
70  void GetHelpMessage() const;
71 
72  private:
73 
74  TString fFilenameModel; // Filename of the previously exported Keras model
75  UInt_t fBatchSize {0}; // Training batch size
76  UInt_t fNumEpochs {0}; // Number of training epochs
77  Int_t fVerbose; // Keras verbosity during training
78  Bool_t fContinueTraining; // Load weights from previous training
79  Bool_t fSaveBestOnly; // Store only weights with smallest validation loss
80  Int_t fTriesEarlyStopping; // Stop training if validation loss is not decreasing for several epochs
81  TString fLearningRateSchedule; // Set new learning rate at specific epochs
82  // NOTE(review): fVerbose, fContinueTraining, fSaveBestOnly and fTriesEarlyStopping have no in-class initializer (unlike fBatchSize/fNumEpochs); confirm they are assigned before first use
83  bool fModelIsSetup = false; // flag whether the model is loaded, needed for GetMvaValue during evaluation
84  float* fVals = nullptr; // variables array used for GetMvaValue
85  std::vector<float> fOutput; // probability or regression output array used for GetMvaValue
86  UInt_t fNVars {0}; // number of variables
87  UInt_t fNOutputs {0}; // number of outputs (classes or targets)
88  TString fFilenameTrainedModel; // output filename for trained model
89 
90  void SetupKerasModel(Bool_t loadTrainedModel); // sets up the needed variables and loads the model
91  // NOTE(review): listing line 92 is missing from this extract; presumably the ClassDef macro listed in the page index - confirm against the original header
93  };
94 
95 } // namespace TMVA
96 
97 #endif // ROOT_TMVA_MethodPyKeras
Double_t GetMvaValue(Double_t *errLower, Double_t *errUpper)
long long Long64_t
Definition: RtypesCore.h:69
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
Definition: TFile.h:46
EAnalysisType
Definition: Types.h:125
virtual void ReadWeightsFromStream(TFile &)
Definition: MethodPyKeras.h:67
Basic string class.
Definition: TString.h:125
Ranking for variables in method (implementation)
Definition: Ranking.h:48
int Int_t
Definition: RtypesCore.h:41
virtual void ReadWeightsFromXML(void *)
Definition: MethodPyKeras.h:65
bool Bool_t
Definition: RtypesCore.h:59
std::vector< Float_t > & GetRegressionValues()
void GetHelpMessage() const
#define ClassDef(name, id)
Definition: Rtypes.h:320
TString fFilenameTrainedModel
Definition: MethodPyKeras.h:88
virtual void AddWeightsXMLTo(void *) const
Definition: MethodPyKeras.h:64
Class that contains all the data information.
Definition: DataSetInfo.h:60
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
TString fLearningRateSchedule
Definition: MethodPyKeras.h:81
unsigned int UInt_t
Definition: RtypesCore.h:42
std::vector< Float_t > & GetMulticlassValues()
const Ranking * CreateRanking()
Definition: MethodPyKeras.h:62
void SetupKerasModel(Bool_t loadTrainedModel)
double Double_t
Definition: RtypesCore.h:55
std::vector< Double_t > GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
get all the MVA values for the events of the current Data type
int type
Definition: TGX11.cxx:120
virtual void TestClassification()
initialization
Abstract ClassifierFactory template that handles arbitrary types.
MethodPyKeras(const TString &jobName, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
virtual void ReadWeightsFromStream(std::istream &)
Definition: MethodPyKeras.h:66
std::vector< float > fOutput
Definition: MethodPyKeras.h:85