Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA_SOFIE_Keras_HiggsModel.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva
3/// \notebook -nodraw
4/// This macro runs the SOFIE parser on the Keras model
5/// obtained by running TMVA_Higgs_Classification.C.
6/// You need to run that macro before this one.
7///
8/// \author Lorenzo Moneta
9
10using namespace TMVA::Experimental;
11
12
13void TMVA_SOFIE_Keras_HiggsModel(const char * modelFile = "Higgs_trained_model.h5"){
14
15 // check if the input file exists
16 if (gSystem->AccessPathName(modelFile)) {
17 Error("TMVA_SOFIE_RDataFrame","You need to run TMVA_Higgs_Classification.C to generate the Keras trained model");
18 return;
19 }
20
21 // parse the input Keras model into RModel object
22 SOFIE::RModel model = SOFIE::PyKeras::Parse(modelFile);
23
24 TString modelHeaderFile = modelFile;
25 modelHeaderFile.ReplaceAll(".h5",".hxx");
26 //Generating inference code
27 model.Generate();
28 model.OutputGenerated(std::string(modelHeaderFile));
29
30 // copy include in $ROOTSYS/tutorials/
31 std::cout << "include is in " << gROOT->GetIncludeDir() << std::endl;
32}
void Error(const char *location, const char *msgfmt,...)
Use this function in case an error occurred.
Definition TError.cxx:197
#define gROOT
Definition TROOT.h:405
R__EXTERN TSystem * gSystem
Definition TSystem.h:560
void OutputGenerated(std::string filename="")
Definition RModel.cxx:627
void Generate(std::underlying_type_t< Options > options, int batchSize=1)
Definition RModel.cxx:240
Basic string class.
Definition TString.h:139
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition TString.h:704
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition TSystem.cxx:1299