from os.path import exists
# Keras model produced by the TMVA_Higgs_Classification.C tutorial; stop
# early with an explicit message if it has not been generated yet.
modelFile = "HiggsModel.keras"
if not exists(modelFile):
    raise FileNotFoundError(
        "You need to run TMVA_Higgs_Classification.C to generate the Keras trained model")
# Parse the Keras model into a SOFIE model object.
model = ROOT.TMVA.Experimental.SOFIE.PyKeras.Parse(modelFile)

# Write the C++ inference code to a header named after the model file
# (HiggsModel.keras -> HiggsModel.hxx).
generatedHeaderFile = modelFile.replace(".keras", ".hxx")
print("Generating inference code for the Keras model from ", modelFile,
      "in the header ", generatedHeaderFile)
# The inference code has to be generated before it can be written out;
# OutputGenerated alone only dumps the (otherwise empty) generated code.
model.Generate()
model.OutputGenerated(generatedHeaderFile)
# JIT-compile the generated header with the ROOT interpreter so that the
# generated namespace (TMVA_SOFIE_<modelName>) becomes reachable from Python.
modelName = modelFile.replace(".keras", "")
print("compiling SOFIE model ", modelName)
# Declare() reports failure through its return value; check it here rather
# than failing later with an obscure missing-attribute error.
if not ROOT.gInterpreter.Declare('#include "' + generatedHeaderFile + '"'):
    raise RuntimeError(
        "Failed to compile the SOFIE generated header " + generatedHeaderFile)
# Input data file, located in the machine_learning/data folder of the ROOT
# tutorials directory.
inputFileName = "Higgs_data.root"
inputFile = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/" + inputFileName
# Model input variables, in the order the feature matrix is built.
inputVariables = ['m_jj', 'm_jjj', 'm_lv', 'm_jlv', 'm_bb', 'm_wbb', 'm_wwbb']

# Read the signal events with RDataFrame and extract the input variables as
# numpy arrays.
# NOTE(review): tree name taken from TMVA_Higgs_Classification.C ("sig_tree") — confirm.
df1 = ROOT.RDataFrame("sig_tree", inputFile)
sigData = df1.AsNumpy(columns=inputVariables)

# Stack the per-variable arrays into one (nevents x nvars) matrix, indexing
# explicitly so the column order never depends on dict ordering.
xsig = np.column_stack([sigData[var] for var in inputVariables])
dataset_size = xsig.shape[0]
print("size of signal data", dataset_size)
# The compiled inference code lives in the C++ namespace
# TMVA_SOFIE_<modelName>; create its Session, which is used below for inference.
sofie = getattr(ROOT, 'TMVA_SOFIE_' + modelName)
session = sofie.Session()
print("Evaluating SOFIE models on signal data")
hs = ROOT.TH1D("hs", "Signal result", 100, 0, 1)
# Print about ten progress lines; max() keeps the step non-zero when there
# are fewer than ten events.
printStep = max(1, dataset_size // 10)
for i in range(dataset_size):
    result = session.infer(xsig[i, :])
    # Fill the histogram so the plot and the entry count below are meaningful.
    hs.Fill(result[0])
    if i % printStep == 0:
        print("result for signal event ", i, result[0])
print("using RDataFrame to extract input data in a numpy array")

# Model input variables, in the order the feature matrix is built.
inputVariables = ['m_jj', 'm_jjj', 'm_lv', 'm_jlv', 'm_bb', 'm_wbb', 'm_wwbb']

# Read the background events with RDataFrame.
# NOTE(review): tree name taken from TMVA_Higgs_Classification.C ("bkg_tree") — confirm.
df2 = ROOT.RDataFrame("bkg_tree", inputFile)
bkgData = df2.AsNumpy(columns=inputVariables)

# Stack into an (nevents x nvars) matrix using the explicit variable order.
xbkg = np.column_stack([bkgData[var] for var in inputVariables])
dataset_size = xbkg.shape[0]
print("size of background data", dataset_size)
hb = ROOT.TH1D("hb", "Background result", 100, 0, 1)
# Print about ten progress lines; max() keeps the step non-zero when there
# are fewer than ten events.
printStep = max(1, dataset_size // 10)
for i in range(dataset_size):
    result = session.infer(xbkg[i, :])
    # Fill the histogram so the plot and the entry count below are meaningful.
    hb.Fill(result[0])
    if i % printStep == 0:
        print("result for background event ", i, result[0])
# Overlay the signal (red) and background (blue) model outputs on one canvas.
c1 = ROOT.TCanvas()
ROOT.gStyle.SetOptStat(0)
hs.SetLineColor(ROOT.kRed)
hs.Draw()
hb.SetLineColor(ROOT.kBlue)
hb.Draw("SAME")
c1.BuildLegend()
c1.Draw()

print("Number of signal entries", hs.GetEntries())
print("Number of background entries", hb.GetEntries())
# ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree, CSV and other data formats, in C++ or standalone Python scripts.