TMVA_SOFIE_RDataFrame_JIT.C File Reference

Detailed Description

This macro shows how to run inference on a trained Keras model using SOFIE and RDataFrame. It takes as input the Keras model generated by the TMVA_Higgs_Classification.C tutorial, so you need to run that macro before this one.

In this case the input file is parsed and the inference is run in the same macro, making use of ROOT's JIT compilation capability.
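The JIT step described above boils down to passing two pieces of C++ source to the interpreter: an #include of the generated header and the declaration of a global functor. The following is a minimal sketch of how those two strings can be assembled. It runs without ROOT, and the helper names StripExtension, MakeIncludeLine and MakeFunctorDeclaration are hypothetical, introduced only for illustration. Unlike the macro, which uses find(), the sketch only strips the extension when it sits at the very end of the file name.

```cpp
#include <string>

// Hypothetical helper (not part of ROOT): remove `ext` from `fileName`
// only if the name really ends with it; otherwise return the name unchanged.
std::string StripExtension(const std::string &fileName, const std::string &ext)
{
   const auto pos = fileName.rfind(ext);
   if (pos == std::string::npos || pos + ext.size() != fileName.size())
      return fileName;
   return fileName.substr(0, pos);
}

// Hypothetical helper: source line that makes the interpreter compile the
// generated model header.
std::string MakeIncludeLine(const std::string &headerFile)
{
   return "#include \"" + headerFile + "\"";
}

// Hypothetical helper: source line that declares the global functor, assuming
// (as the macro does) that the model name equals the header name without ".hxx".
std::string MakeFunctorDeclaration(const std::string &headerFile, unsigned int ninputs,
                                   unsigned int nslots)
{
   const std::string modelName = StripExtension(headerFile, ".hxx");
   return "auto sofie_functor = TMVA::Experimental::SofieFunctor<" + std::to_string(ninputs) +
          ",TMVA_SOFIE_" + modelName + "::Session>(" + std::to_string(nslots) + ");";
}
```

In the macro, strings of this form are handed to gInterpreter->Declare, and a null return value is treated as a compilation error.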

using namespace TMVA::Experimental;
/// Compile the generated model header and declare the SofieFunctor used by RDF.
/// Assumes that the model name is the header file name without the .hxx extension.
void CompileModelForRDF(const std::string & headerModelFile, unsigned int ninputs, unsigned int nslots=0) {
   std::string modelName = headerModelFile.substr(0,headerModelFile.find(".hxx"));
   // JIT-compile the generated header
   std::string cmd = std::string("#include \"") + headerModelFile + std::string("\"");
   auto ret = gInterpreter->Declare(cmd.c_str());
   if (!ret)
      throw std::runtime_error("Error compiling : " + cmd);
   std::cout << "compiled : " << cmd << std::endl;
   // declare a global functor wrapping the generated Session class
   cmd = "auto sofie_functor = TMVA::Experimental::SofieFunctor<" + std::to_string(ninputs) + ",TMVA_SOFIE_" +
         modelName + "::Session>(" + std::to_string(nslots) + ");";
   ret = gInterpreter->Declare(cmd.c_str());
   if (!ret)
      throw std::runtime_error("Error compiling : " + cmd);
   std::cout << "compiled : " << cmd << std::endl;
   std::cout << "Model is ready to be evaluated" << std::endl;
   return;
}
void TMVA_SOFIE_RDataFrame_JIT(std::string modelFile = "Higgs_trained_model.h5"){
   // the Keras parser needs the Python interpreter
   TMVA::PyMethodBase::PyInitialize();

   // check if the input file exists
   if (gSystem->AccessPathName(modelFile.c_str())) {
      Info("TMVA_SOFIE_RDataFrame","You need to run TMVA_Higgs_Classification.C to generate the Keras trained model");
      return;
   }

   // parse the input Keras model into an RModel object
   SOFIE::RModel model = SOFIE::PyKeras::Parse(modelFile);
   std::string modelName = modelFile.substr(0,modelFile.find(".h5"));
   std::string modelHeaderFile = modelName + std::string(".hxx");

   // generate the inference code and write it to the header file
   model.Generate();
   model.OutputGenerated(modelHeaderFile);
   model.PrintGenerated();

   // check that the weight file also exists
   std::string modelWeightFile = modelName + std::string(".dat");
   if (gSystem->AccessPathName(modelWeightFile.c_str())) {
      Error("TMVA_SOFIE_RDataFrame","Generated weight file is missing");
      return;
   }

   // now compile the trained model using the ROOT JIT (see function above);
   // the model takes the 7 input variables passed to sofie_functor below
   CompileModelForRDF(modelHeaderFile, 7);

   std::string inputFileName = "Higgs_data.root";
   std::string inputFile = gROOT->GetTutorialDir() + "/machine_learning/data/" + inputFileName;

   // signal tree
   ROOT::RDataFrame df1("sig_tree", inputFile);
   auto h1 = df1.Define("DNN_Value", "sofie_functor(rdfslot_,m_jj, m_jjj, m_lv, m_jlv, m_bb, m_wbb, m_wwbb)")
                .Histo1D({"h_sig", "", 100, 0, 1},"DNN_Value");

   // background tree
   ROOT::RDataFrame df2("bkg_tree", inputFile);
   auto h2 = df2.Define("DNN_Value", "sofie_functor(rdfslot_,m_jj, m_jjj, m_lv, m_jlv, m_bb, m_wbb, m_wwbb)")
                .Histo1D({"h_bkg", "", 100, 0, 1},"DNN_Value");

   h1->SetLineColor(kRed);
   h2->SetLineColor(kBlue);

   auto c1 = new TCanvas();
   gStyle->SetOptStat(0);
   h2->DrawClone();
   h1->DrawClone("SAME");
   c1->BuildLegend();
}
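
The first argument of sofie_functor in the Define expressions is rdfslot_, the processing-slot index that RDataFrame provides. A plausible reading is that the functor keeps one inference session per slot, so that concurrent slots never share mutable state. The sketch below illustrates that pattern in plain C++. SlotFunctor and ToySession are hypothetical stand-ins, not the actual TMVA::Experimental::SofieFunctor or a generated Session class, and treating nslots = 0 as a single slot is an assumption of the sketch.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a generated Session: it owns mutable scratch
// memory, which is why one instance should not be shared between threads.
struct ToySession {
   std::vector<float> scratch = std::vector<float>(1);

   // Toy "inference": returns the sum of the inputs, NOT a real network.
   std::vector<float> infer(const float *x, std::size_t n)
   {
      scratch[0] = 0.f;
      for (std::size_t i = 0; i < n; ++i)
         scratch[0] += x[i];
      return scratch;
   }
};

// Hypothetical functor holding one session and one input buffer per slot.
template <std::size_t N, typename Session>
class SlotFunctor {
   std::vector<Session> sessions_;
   std::vector<std::array<float, N>> inputs_;

public:
   // Assumption of this sketch: zero slots means a single, sequential slot.
   explicit SlotFunctor(unsigned int nslots = 0)
      : sessions_(nslots < 1 ? 1 : nslots), inputs_(nslots < 1 ? 1 : nslots)
   {
   }

   // The slot index selects the session; the remaining N arguments are the inputs.
   template <typename... Args>
   float operator()(unsigned int slot, Args... args)
   {
      static_assert(sizeof...(Args) == N, "wrong number of input variables");
      assert(slot < sessions_.size());
      inputs_[slot] = {static_cast<float>(args)...};
      return sessions_[slot].infer(inputs_[slot].data(), N)[0];
   }

   std::size_t NSlots() const { return sessions_.size(); }
};
```

With this toy, `SlotFunctor<3, ToySession> f(2);` creates two independent sessions, and `f(0, 1.0, 2.0, 3.0)` evaluates the first one on three inputs. In the macro the functor is declared with the default nslots = 0, and the slot index comes from the rdfslot_ column.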
Author
Lorenzo Moneta

Definition in file TMVA_SOFIE_RDataFrame_JIT.C.