In this case we parse the input file and then run the inference in the same macro, making use of ROOT's JITing capability
/// Prepare a generated SOFIE model header for use from RDataFrame.
///
/// Two source snippets are assembled as strings: an `#include` of the model
/// header, and the declaration of a variable named `sofie_functor` of type
/// `TMVA::Experimental::SofieFunctor<ninputs, TMVA_SOFIE_<model>::Session>`.
/// After each snippet `ret` is checked and a progress line is printed to std::cout.
///
/// \param headerModelFile name of the generated model header; the text preceding
///        the first ".hxx" is taken as the model name
/// \param ninputs number of model inputs, used as the first template argument of SofieFunctor
/// \param nslots value passed to the SofieFunctor constructor (defaults to 0)
/// \throws std::runtime_error if `ret` evaluates to false after either snippet
///
/// NOTE(review): `ret` is tested below but is never declared or assigned in the
/// text visible here. Presumably each `cmd` is handed to the ROOT interpreter
/// right before the corresponding check and `ret` holds the result -- confirm
/// against the original tutorial source; as shown, this block does not compile.
void CompileModelForRDF(const std::string & headerModelFile, unsigned int ninputs, unsigned int nslots=0) {
// Model name = header file name up to ".hxx"; if ".hxx" is absent, find() yields
// npos and substr() keeps the whole string.
std::string modelName = headerModelFile.substr(0,headerModelFile.find(".hxx"));
// First snippet: #include "<headerModelFile>"
std::string cmd = std::string("#include \"") + headerModelFile + std::string("\"");
// NOTE(review): the statement that sets `ret` for this snippet is not visible here.
if (!ret)
throw std::runtime_error("Error compiling : " + cmd);
std::cout << "compiled : " << cmd << std::endl;
// Second snippet: declares `sofie_functor`, the name used later inside the
// RDataFrame Define() expression strings.
cmd = "auto sofie_functor = TMVA::Experimental::SofieFunctor<" + std::to_string(ninputs) + ",TMVA_SOFIE_" +
modelName + "::Session>(" + std::to_string(nslots) + ");";
// NOTE(review): the statement that updates `ret` for this snippet is not visible here.
if (!ret)
throw std::runtime_error("Error compiling : " + cmd);
std::cout << "compiled : " << cmd << std::endl;
std::cout << "Model is ready to be evaluated" << std::endl;
return;
}
/// Tutorial entry point: evaluate a SOFIE model on the Higgs data set through RDataFrame.
///
/// \param modelFile path of the Keras model file (defaults to "Higgs_trained_model.h5")
///
/// The function returns early, after logging a message, when `modelFile` or the
/// derived ".dat" weight file is reported as unavailable by gSystem->AccessPathName.
///
/// NOTE(review): `df1` and `df2` are used below but are not declared in the text
/// visible here, and other steps between the visible statements may be missing
/// as well -- confirm against the original tutorial source.
void TMVA_SOFIE_RDataFrame_JIT(std::string modelFile = "Higgs_trained_model.h5"){
// A truthy result from AccessPathName is treated as "model file not available".
if (
gSystem->AccessPathName(modelFile.c_str())) {
Info(
"TMVA_SOFIE_RDataFrame",
"You need to run TMVA_Higgs_Classification.C to generate the Keras trained model");
return;
}
// Derive the header (.hxx) and weight (.dat) file names from the model file
// name with everything from the first ".h5" onwards removed.
std::string modelName = modelFile.substr(0,modelFile.find(".h5"));
std::string modelHeaderFile = modelName + std::string(".hxx");
std::string modelWeightFile = modelName + std::string(".dat");
// Stop if the weight file is reported as unavailable.
if (
gSystem->AccessPathName(modelWeightFile.c_str())) {
Error(
"TMVA_SOFIE_RDataFrame",
"Generated weight file is missing");
return;
}
// 7 inputs: matches the seven columns passed to sofie_functor in the Define()
// expressions below. nslots is left at its default of 0.
CompileModelForRDF(modelHeaderFile,7);
// Input data file: <tutorial dir>/tmva/data/Higgs_data.root
std::string inputFileName = "Higgs_data.root";
std::string inputFile = std::string(
gROOT->GetTutorialDir()) +
"/tmva/data/" + inputFileName;
// Add a "DNN_Value" column computed by sofie_functor and book a 100-bin
// histogram of it over [0, 1], named "h_sig".
// NOTE(review): `df1` is not declared in the visible text.
auto h1 = df1.Define(
"DNN_Value",
"sofie_functor(rdfslot_,m_jj, m_jjj, m_lv, m_jlv, m_bb, m_wbb, m_wwbb)")
.Histo1D({"h_sig", "", 100, 0, 1},"DNN_Value");
// Same column definition and binning on `df2`, histogram named "h_bkg".
// NOTE(review): `df2` is not declared in the visible text.
auto h2 = df2.Define(
"DNN_Value",
"sofie_functor(rdfslot_,m_jj, m_jjj, m_lv, m_jlv, m_bb, m_wbb, m_wwbb)")
.Histo1D({"h_bkg", "", 100, 0, 1},"DNN_Value");
}
void Info(const char *location, const char *msgfmt,...)
Use this function for informational messages.
void Error(const char *location, const char *msgfmt,...)
Use this function in case an error occurred.
R__EXTERN TStyle * gStyle
R__EXTERN TSystem * gSystem
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree, ...
void OutputGenerated(std::string filename="", bool append=false)
void Generate(std::underlying_type_t< Options > options, int batchSize=-1, long pos=0)
static void PyInitialize()
Initialize Python interpreter.
RModel Parse(std::string filename, int batch_size=-1)
Parser function for translating a Keras .h5 model into an RModel object.