ROOT master Reference Guide
TMVA_SOFIE_RDataFrame.py
```python
### \file
### \ingroup tutorial_tmva
### \notebook -nodraw
### Example of inference with SOFIE and RDataFrame, of a model trained with Keras.
### First, generate the input model by running `TMVA_Higgs_Classification.C`.
###
### This tutorial parses the input model and runs the inference using ROOT's JITing capability.
###
### \macro_code
### \macro_output
### \author Lorenzo Moneta

import ROOT
from os.path import exists

ROOT.TMVA.PyMethodBase.PyInitialize()
```
```python
# check that the trained Keras model file exists
modelFile = "Higgs_trained_model.h5"
modelName = "Higgs_trained_model"

if not exists(modelFile):
    raise FileNotFoundError("You need to run TMVA_Higgs_Classification.C to generate the Keras trained model")
```
```python
# parse the input Keras model into an RModel object
model = ROOT.TMVA.Experimental.SOFIE.PyKeras.Parse(modelFile)

# generate the C++ inference code, write it to a header file
# (the weights are written to the companion .dat file) and print it
model.Generate()
model.OutputGenerated("Higgs_trained_model_generated.hxx")
model.PrintGenerated()
```
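SOFIE's `Generate` step emits standalone C++ source that implements the model's forward pass, and ROOT then compiles that source at run time. The pure-Python sketch below illustrates only the general generate-then-compile idea. It uses a hypothetical single dense layer with a sigmoid activation and a hypothetical helper, `generate_dense_source`. Neither is SOFIE's API or the Higgs model's actual architecture. The weights are kept out of the generated text and passed in at call time, loosely mirroring the separate weight file.

```python
import math


def generate_dense_source(n_inputs):
    """Emit Python source computing sigmoid(w . x + b) for n_inputs features.

    The weights are not embedded in the generated text; they are passed in
    at call time, loosely mirroring a separate weight file.
    """
    args = ", ".join(f"x{i}" for i in range(n_inputs))
    terms = " + ".join(f"w[{i}] * x{i}" for i in range(n_inputs))
    return (
        f"def infer(w, b, {args}):\n"
        f"    z = {terms} + b\n"
        f"    return 1.0 / (1.0 + math.exp(-z))\n"
    )


def compile_source(src):
    """Compile the generated source at run time and return its `infer` function."""
    namespace = {"math": math}
    exec(src, namespace)
    return namespace["infer"]


src = generate_dense_source(3)
print(src)
infer = compile_source(src)
# hypothetical weights [0.5, -0.25, 1.0] and bias 0.0, for illustration only
print(infer([0.5, -0.25, 1.0], 0.0, 1.0, 2.0, 0.0))  # 0.5
```

In this sketch, keeping the weights out of the generated text means the same compiled function can be reused with different parameter values. This is the motivation the sketch assumes for a separate weight file.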
33
34
# compile using ROOT JIT trained model
35
print(
"compiling SOFIE model and functor...."
)
36
ROOT.gInterpreter.Declare
(
'#include "Higgs_trained_model_generated.hxx"'
)
37
ROOT.gInterpreter.Declare
(
'auto sofie_functor = TMVA::Experimental::SofieFunctor<7,TMVA_SOFIE_'
+modelName+
'::Session>(0,"Higgs_trained_model_generated.dat");'
)
38
```python
# run inference over the input data
inputFile = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/Higgs_data.root"

# rdfslot_ is the processing-slot index, passed so the functor can use a per-slot session
df1 = ROOT.RDataFrame("sig_tree", inputFile)
h1 = df1.Define("DNN_Value", "sofie_functor(rdfslot_, m_jj, m_jjj, m_lv, m_jlv, m_bb, m_wbb, m_wwbb)").Histo1D(
    ("h_sig", "", 100, 0, 1), "DNN_Value"
)

df2 = ROOT.RDataFrame("bkg_tree", inputFile)
h2 = df2.Define("DNN_Value", "sofie_functor(rdfslot_, m_jj, m_jjj, m_lv, m_jlv, m_bb, m_wbb, m_wwbb)").Histo1D(
    ("h_bkg", "", 100, 0, 1), "DNN_Value"
)
```
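The model tuple `("h_sig", "", 100, 0, 1)` books a histogram named `h_sig` with an empty title and 100 equal-width bins between 0 and 1. This range assumes the network output is a probability-like score. The pure-Python sketch below reproduces ROOT's fixed-width binning convention. Bin 0 is the underflow and bins 1..N cover the range with the upper edge excluded, so a value of exactly 1.0 falls into the overflow bin N+1. `find_bin` and `fill` are hypothetical helpers, not ROOT functions.

```python
def find_bin(x, nbins, xmin, xmax):
    """Return the ROOT-style bin index for x on a fixed-width axis.

    0 is underflow, 1..nbins are in-range bins, nbins + 1 is overflow.
    """
    if x < xmin:
        return 0
    if not x < xmax:
        return nbins + 1
    return 1 + int(nbins * (x - xmin) / (xmax - xmin))


def fill(values, nbins, xmin, xmax):
    """Fill nbins + 2 counters, including the underflow and overflow bins."""
    counts = [0] * (nbins + 2)
    for v in values:
        counts[find_bin(v, nbins, xmin, xmax)] += 1
    return counts


counts = fill([0.0, 0.005, 0.015, 0.999, 1.0, -0.1], 100, 0.0, 1.0)
print(counts[0], counts[1], counts[2], counts[100], counts[101])  # 1 2 1 1 1
```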
```python
# trigger the event loops of both RDataFrame computation graphs in a single call
# (run concurrently when implicit multithreading is enabled)
ROOT.RDF.RunGraphs([h1, h2])

print("Number of signal entries", h1.GetEntries())
print("Number of background entries", h2.GetEntries())
```
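`Histo1D` only books the action. The event loop runs when a result is first accessed, or when `ROOT.RDF.RunGraphs` triggers several computation graphs at once. The pure-Python sketch below mimics that lazy behaviour with a hypothetical `LazyResult` class and a hypothetical `run_graphs` function. It counts how often each stand-in "event loop" executes, showing that reading the results after `run_graphs` does not rerun anything. It runs the computations sequentially and does not model RDataFrame's threading.

```python
class LazyResult:
    """Hypothetical lazily computed result, loosely mimicking an RDataFrame result."""

    def __init__(self, compute):
        self._compute = compute
        self._value = None
        self._ready = False
        self.runs = 0  # how many times the stand-in "event loop" executed

    def run(self):
        if not self._ready:
            self._value = self._compute()
            self._ready = True
            self.runs += 1

    def get_value(self):
        self.run()  # first access triggers the computation
        return self._value


def run_graphs(results):
    """Trigger every pending computation (sequentially in this sketch)."""
    for r in results:
        r.run()


signal = LazyResult(lambda: sum(1 for _ in range(3)))
background = LazyResult(lambda: sum(1 for _ in range(5)))
run_graphs([signal, background])
print(signal.get_value(), background.get_value())  # 3 5
print(signal.runs, background.runs)  # 1 1
```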
```python
h1.SetLineColor(ROOT.kRed)
h2.SetLineColor(ROOT.kBlue)

c1 = ROOT.TCanvas()
ROOT.gStyle.SetOptStat(0)

# draw the background first, then overlay the signal
h2.DrawClone()
h1.DrawClone("SAME")
```
ROOT master - Reference Guide Generated on Tue Mar 11 2025 12:22:39 (GVA Time) using Doxygen 1.10.0