# In this example we parse the trained Keras model file and then run the inference in the same macro, making use of ROOT's JIT-compilation capability.
import ROOT
import numpy as np
# Initialise TMVA's Python interface before the SOFIE Keras parser is used below.
ROOT.TMVA.PyMethodBase.PyInitialize()

modelFile = "Higgs_trained_model.h5"
# TSystem.AccessPathName follows an inverted convention: it returns True when
# the file can NOT be accessed, so this branch handles a missing model file.
if ROOT.gSystem.AccessPathName(modelFile):
    ROOT.Info("TMVA_SOFIE_RDataFrame","You need to run TMVA_Higgs_Classification.C to generate the Keras trained model")
    # Same exit status as the former bare exit(), but without depending on the
    # `site`-injected builtin, which is absent under `python -S` / embedded interpreters.
    raise SystemExit()
# Parse the trained Keras model with SOFIE, write the generated C++ inference
# code to `generatedHeaderFile`, print it, and JIT-compile that header with
# cling so the generated TMVA_SOFIE_* namespace becomes reachable via ROOT.
model = ROOT.TMVA.Experimental.SOFIE.PyKeras.Parse(modelFile)
generatedHeaderFile = modelFile.replace(".h5",".hxx")
print("Generating inference code for the Keras model from ",modelFile,"in the header ", generatedHeaderFile)
model.Generate()
model.OutputGenerated(generatedHeaderFile)
model.PrintGenerated()

modelName = modelFile.replace(".h5","")
print("compiling SOFIE model ", modelName)
# gInterpreter.Declare returns False when cling fails to compile the code;
# stop here instead of failing later on a missing ROOT.TMVA_SOFIE_* attribute.
if not ROOT.gInterpreter.Declare('#include "' + generatedHeaderFile + '"'):
    raise RuntimeError("Failed to compile the generated SOFIE header " + generatedHeaderFile)
inputFileName = "Higgs_data.root"
inputFile = str(ROOT.gROOT.GetTutorialDir()) + "/tmva/data/" + inputFileName

# Feature columns fed to the network, in the order they are stacked into each
# input row. NOTE(review): assumed to match the feature order used when the
# Keras model was trained — confirm against TMVA_Higgs_Classification.C.
inputColumns = ['m_jj', 'm_jjj', 'm_lv', 'm_jlv', 'm_bb', 'm_wbb', 'm_wwbb']

# --- Signal sample ---------------------------------------------------------
# NOTE(review): tree names "sig_tree"/"bkg_tree" are those of the TMVA Higgs
# tutorial dataset — confirm if a different input file is used.
df1 = ROOT.RDataFrame("sig_tree", inputFile)
sigData = df1.AsNumpy(columns=inputColumns)
# Stack explicitly in inputColumns order (one row per event, one column per
# feature) rather than relying on the ordering of the returned dict.
xsig = np.column_stack([sigData[col] for col in inputColumns])
dataset_size = xsig.shape[0]
print("size of data", dataset_size)

# Session class of the SOFIE code declared to the interpreter earlier.
session = ROOT.TMVA_SOFIE_Higgs_trained_model.Session()

# Histogram the first network output of every signal event.
hs = ROOT.TH1D("hs","Signal result",100,0,1)
for row in xsig:
    result = session.infer(row)
    hs.Fill(result[0])

# --- Background sample -----------------------------------------------------
df2 = ROOT.RDataFrame("bkg_tree", inputFile)
bkgData = df2.AsNumpy(columns=inputColumns)
xbkg = np.column_stack([bkgData[col] for col in inputColumns])
dataset_size = xbkg.shape[0]

# Histogram the first network output of every background event.
hb = ROOT.TH1D("hb","Background result",100,0,1)
for row in xbkg:
    result = session.infer(row)
    hb.Fill(result[0])
# Overlay the signal (red) and background (blue) output histograms on a single
# canvas with the statistics box disabled, add a legend, and report how many
# entries each histogram received.
c1 = ROOT.TCanvas()
ROOT.gStyle.SetOptStat(0)
for hist, lineColor, drawOption in ((hs, ROOT.kRed, ""), (hb, ROOT.kBlue, "SAME")):
    hist.SetLineColor(lineColor)
    hist.Draw(drawOption)  # the second histogram is drawn on top with "SAME"
c1.BuildLegend()
c1.Draw()

print("Number of signal entries", hs.GetEntries())
print("Number of background entries", hb.GetEntries())
# ROOT's RDataFrame offers a modern, high-level interface for analysing data stored in TTree, ...