Logo ROOT  
Reference Guide
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
Loading...
Searching...
No Matches
TMVA_SOFIE_RDataFrame_JIT.C
Go to the documentation of this file.
/// \file
/// \ingroup tutorial_tmva
/// \notebook -nodraw
/// This macro provides an example of using a model trained with Keras
/// and making inference on it using SOFIE and RDataFrame.
/// The macro uses as input the Keras model generated by the
/// TMVA_Higgs_Classification.C tutorial, so
/// you need to run that macro before this one.
/// In this case the input model file is parsed, and the inference is then run in the same
/// macro, making use of the ROOT JIT (just-in-time compilation) capability.
///
///
/// \macro_code
/// \macro_output
/// \author Lorenzo Moneta
16
17using namespace TMVA::Experimental;
18
19/// function to compile the generated model and the declaration of the SofieFunctor
20/// used by RDF.
21/// Assume that the model name as in the header file
22void CompileModelForRDF(const std::string & headerModelFile, unsigned int ninputs, unsigned int nslots=0) {
23
24 std::string modelName = headerModelFile.substr(0,headerModelFile.find(".hxx"));
25 std::string cmd = std::string("#include \"") + headerModelFile + std::string("\"");
26 auto ret = gInterpreter->Declare(cmd.c_str());
27 if (!ret)
28 throw std::runtime_error("Error compiling : " + cmd);
29 std::cout << "compiled : " << cmd << std::endl;
30
31 cmd = "auto sofie_functor = TMVA::Experimental::SofieFunctor<" + std::to_string(ninputs) + ",TMVA_SOFIE_" +
32 modelName + "::Session>(" + std::to_string(nslots) + ");";
33 ret = gInterpreter->Declare(cmd.c_str());
34 if (!ret)
35 throw std::runtime_error("Error compiling : " + cmd);
36 std::cout << "compiled : " << cmd << std::endl;
37 std::cout << "Model is ready to be evaluated" << std::endl;
38 return;
39}
40
41void TMVA_SOFIE_RDataFrame_JIT(std::string modelFile = "Higgs_trained_model.h5"){
42
44
45 // check if the input file exists
46 if (gSystem->AccessPathName(modelFile.c_str())) {
47 Info("TMVA_SOFIE_RDataFrame","You need to run TMVA_Higgs_Classification.C to generate the Keras trained model");
48 return;
49 }
50
51 // parse the input Keras model into RModel object
52 SOFIE::RModel model = SOFIE::PyKeras::Parse(modelFile);
53
54 std::string modelName = modelFile.substr(0,modelFile.find(".h5"));
55 std::string modelHeaderFile = modelName + std::string(".hxx");
56 //Generating inference code
57 model.Generate();
58 model.OutputGenerated(modelHeaderFile);
59 model.PrintGenerated();
60 // check that also weigh file exists
61 std::string modelWeightFile = modelName + std::string(".dat");
62 if (gSystem->AccessPathName(modelWeightFile.c_str())) {
63 Error("TMVA_SOFIE_RDataFrame","Generated weight file is missing");
64 return;
65 }
66
67 // now compile using ROOT JIT trained model (see function above)
68 CompileModelForRDF(modelHeaderFile,7);
69
70 std::string inputFileName = "Higgs_data.root";
71 std::string inputFile = "http://root.cern.ch/files/" + inputFileName;
72
73 ROOT::RDataFrame df1("sig_tree", inputFile);
74 auto h1 = df1.Define("DNN_Value", "sofie_functor(rdfslot_,m_jj, m_jjj, m_lv, m_jlv, m_bb, m_wbb, m_wwbb)")
75 .Histo1D({"h_sig", "", 100, 0, 1},"DNN_Value");
76
77 ROOT::RDataFrame df2("bkg_tree", inputFile);
78 auto h2 = df2.Define("DNN_Value", "sofie_functor(rdfslot_,m_jj, m_jjj, m_lv, m_jlv, m_bb, m_wbb, m_wwbb)")
79 .Histo1D({"h_bkg", "", 100, 0, 1},"DNN_Value");
80
82 h2->SetLineColor(kBlue);
83
84 auto c1 = new TCanvas();
86
87 h2->DrawClone();
88 h1->DrawClone("SAME");
89 c1->BuildLegend();
90
91
92}
@ kRed
Definition Rtypes.h:66
@ kBlue
Definition Rtypes.h:66
void Info(const char *location, const char *msgfmt,...)
Use this function for informational messages.
Definition TError.cxx:230
void Error(const char *location, const char *msgfmt,...)
Use this function in case an error occurred.
Definition TError.cxx:197
#define gInterpreter
R__EXTERN TStyle * gStyle
Definition TStyle.h:414
R__EXTERN TSystem * gSystem
Definition TSystem.h:560
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree ,...
virtual void SetLineColor(Color_t lcolor)
Set the line color.
Definition TAttLine.h:40
The Canvas class.
Definition TCanvas.h:23
void OutputGenerated(std::string filename="")
Definition RModel.cxx:627
void Generate(std::underlying_type_t< Options > options, int batchSize=1)
Definition RModel.cxx:240
static void PyInitialize()
Initialize Python interpreter.
virtual TObject * DrawClone(Option_t *option="") const
Draw a clone of this object in the current selected pad with: gROOT->SetSelectedPad(c1).
Definition TObject.cxx:299
void SetOptStat(Int_t stat=1)
The type of information printed in the histogram statistics box can be selected via the parameter mod...
Definition TStyle.cxx:1589
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition TSystem.cxx:1299
return c1
Definition legend1.C:41
TH1F * h1
Definition legend1.C:5