Logo ROOT   6.16/01
Reference Guide
MethodPyKeras.h
Go to the documentation of this file.
// @(#)root/tmva/pymva $Id$
// Author: Stefan Wunsch

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis      *
 * Package: TMVA                                                                  *
 * Class  : MethodPyKeras                                                         *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Interface for the Keras Python package, which is a wrapper for the        *
 *      Theano and TensorFlow libraries                                           *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Stefan Wunsch <stefan.wunsch@cern.ch> - KIT, Germany                      *
 *                                                                                *
 * Copyright (c) 2016:                                                            *
 *      CERN, Switzerland                                                         *
 *      KIT, Germany                                                              *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without            *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

26#ifndef ROOT_TMVA_MethodPyKeras
27#define ROOT_TMVA_MethodPyKeras
28
29#include "TMVA/PyMethodBase.h"
30
31namespace TMVA {
32
33 class MethodPyKeras : public PyMethodBase {
34
35 public :
36
37 // constructors
38 MethodPyKeras(const TString &jobName,
39 const TString &methodTitle,
40 DataSetInfo &dsi,
41 const TString &theOption = "");
43 const TString &theWeightFile);
45
46 void Train();
47 void Init();
48 void DeclareOptions();
49 void ProcessOptions();
50
51 // Check whether the given analysis type (regression, classification, ...)
52 // is supported by this method
54 // Get signal probability of given event
55 Double_t GetMvaValue(Double_t *errLower, Double_t *errUpper);
56 std::vector<Double_t> GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress);
57 // Get regression values of given event
58 std::vector<Float_t>& GetRegressionValues();
59 // Get class probabilities of given event
60 std::vector<Float_t>& GetMulticlassValues();
61
62 const Ranking *CreateRanking() { return 0; }
63 virtual void TestClassification();
64 virtual void AddWeightsXMLTo(void*) const{}
65 virtual void ReadWeightsFromXML(void*){}
66 virtual void ReadWeightsFromStream(std::istream&) {} // backward compatibility
67 virtual void ReadWeightsFromStream(TFile&){} // backward compatibility
68 void ReadModelFromFile();
69
70 void GetHelpMessage() const;
71
72 private:
73
74 TString fFilenameModel; // Filename of the previously exported Keras model
75 UInt_t fBatchSize {0}; // Training batch size
76 UInt_t fNumEpochs {0}; // Number of training epochs
77 Int_t fVerbose; // Keras verbosity during training
78 Bool_t fContinueTraining; // Load weights from previous training
79 Bool_t fSaveBestOnly; // Store only weights with smallest validation loss
80 Int_t fTriesEarlyStopping; // Stop training if validation loss is not decreasing for several epochs
81 TString fLearningRateSchedule; // Set new learning rate at specific epochs
82 TString fTensorBoard; // Store log files during training
83 TString fNumValidationString; // option string defining the number of validation events
84
85 bool fModelIsSetup = false; // flag whether model is loaded, neede for getMvaValue during evaluation
86 float* fVals = nullptr; // variables array used for GetMvaValue
87 std::vector<float> fOutput; // probability or regression output array used for GetMvaValue
88 UInt_t fNVars {0}; // number of variables
89 UInt_t fNOutputs {0}; // number of outputs (classes or targets)
90 TString fFilenameTrainedModel; // output filename for trained model
91
92 void SetupKerasModel(Bool_t loadTrainedModel); // setups the needed variables loads the model
93 UInt_t GetNumValidationSamples(); // get numer of validation events according to given option
94
96 };
97
98} // namespace TMVA
99
100#endif // ROOT_TMVA_MethodPyKeras
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
long long Long64_t
Definition: RtypesCore.h:69
#define ClassDef(name, id)
Definition: Rtypes.h:324
int type
Definition: TGX11.cxx:120
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:48
Class that contains all the data information.
Definition: DataSetInfo.h:60
void GetHelpMessage() const
std::vector< float > fOutput
Definition: MethodPyKeras.h:87
virtual void AddWeightsXMLTo(void *) const
Definition: MethodPyKeras.h:64
virtual void TestClassification()
initialization
virtual void ReadWeightsFromStream(std::istream &)
Definition: MethodPyKeras.h:66
virtual void ReadWeightsFromXML(void *)
Definition: MethodPyKeras.h:65
void SetupKerasModel(Bool_t loadTrainedModel)
std::vector< Float_t > & GetMulticlassValues()
UInt_t GetNumValidationSamples()
Validation of the ValidationSize option.
Double_t GetMvaValue(Double_t *errLower, Double_t *errUpper)
std::vector< Float_t > & GetRegressionValues()
const Ranking * CreateRanking()
Definition: MethodPyKeras.h:62
TString fNumValidationString
Definition: MethodPyKeras.h:83
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
MethodPyKeras(const TString &jobName, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
TString fLearningRateSchedule
Definition: MethodPyKeras.h:81
TString fFilenameTrainedModel
Definition: MethodPyKeras.h:90
std::vector< Double_t > GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
get all the MVA values for the events of the current Data type
virtual void ReadWeightsFromStream(TFile &)
Definition: MethodPyKeras.h:67
Ranking for variables in method (implementation)
Definition: Ranking.h:48
EAnalysisType
Definition: Types.h:127
Basic string class.
Definition: TString.h:131
Abstract ClassifierFactory template that handles arbitrary types.