/// In this case we parse the input file and then run the inference in the same macro, making use of ROOT's JIT-compilation capability.
void CompileModelForRDF(const std::string & headerModelFile, unsigned int ninputs, unsigned int nslots=0) {
std::string modelName = headerModelFile.substr(0,headerModelFile.find(".hxx"));
std::string cmd = std::string("#include \"") + headerModelFile + std::string("\"");
throw std::runtime_error("Error compiling : " + cmd);
std::cout << "compiled : " << cmd << std::endl;
cmd = "auto sofie_functor = TMVA::Experimental::SofieFunctor<" + std::to_string(ninputs) + ",TMVA_SOFIE_" +
modelName + "::Session>(" + std::to_string(nslots) + ");";
throw std::runtime_error("Error compiling : " + cmd);
std::cout << "compiled : " << cmd << std::endl;
std::cout << "Model is ready to be evaluated" << std::endl;
return;
}
void TMVA_SOFIE_RDataFrame_JIT(std::string modelName = "HiggsModel"){
std::string modelHeaderFile = modelName + ".hxx";
if (
gSystem->AccessPathName(modelHeaderFile.c_str())) {
Info(
"TMVA_SOFIE_RDataFrame",
"You need to run TMVA_SOFIE_Keras_Higgs_Model.py to generate the SOFIE header for the Keras trained model");
return;
}
std::string modelWeightFile = modelName + std::string(".dat");
if (
gSystem->AccessPathName(modelWeightFile.c_str())) {
Error(
"TMVA_SOFIE_RDataFrame",
"Generated weight file is missing");
return;
}
CompileModelForRDF(modelHeaderFile,7);
std::string inputFileName = "Higgs_data.root";
std::string inputFile = std::string{
gROOT->GetTutorialDir()} +
"/machine_learning/data/" + inputFileName;
auto h1 = df1.Define(
"DNN_Value",
"sofie_functor(rdfslot_,m_jj, m_jjj, m_lv, m_jlv, m_bb, m_wbb, m_wwbb)")
.Histo1D({"h_sig", "", 100, 0, 1},"DNN_Value");
auto h2 = df2.Define("DNN_Value", "sofie_functor(rdfslot_,m_jj, m_jjj, m_lv, m_jlv, m_bb, m_wbb, m_wwbb)")
.Histo1D({"h_bkg", "", 100, 0, 1},"DNN_Value");
h2->DrawClone();
}
Error("WriteTObject","The current directory (%s) is not associated with a file. The object (%s) has not been written.", GetName(), objname)
void Info(const char *location, const char *msgfmt,...)
Use this function for informational messages.
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree ,...