24def CreateAndTrainModel(modelName):
    # Build a small PyTorch nn.Sequential model, train it, and export it to
    # ONNX as "<modelName>.onnx".
    #
    # NOTE(review): this listing looks scraped from a rendered source page:
    # the leading numbers (24, 26, 35, ...) are original line numbers fused
    # into the text, and the gaps between them are source lines that are
    # missing here. Restore the block from upstream before running it.
26 model = nn.Sequential(
    # NOTE(review): the nn.Sequential arguments and closing parenthesis
    # (original lines 27-34) are missing from this listing.
    # Mean-squared-error loss, optimized with plain SGD at learning rate 0.01.
35 criterion = nn.MSELoss()
36 optimizer = torch.optim.SGD(model.parameters(),lr=0.01)
    # NOTE(review): y_pred and y are not defined in the visible lines.
    # Presumably the training loop (original lines 37-43) produces them, and
    # the backward/step calls after it (45-55) are missing too.
44 loss = criterion(y_pred,y)
    # The ONNX output path is derived from modelName. dummy_x has shape
    # (1, 32) and is passed to torch.onnx.export below as the example input.
56 modelFile = modelName +
".onnx"
57 dummy_x = torch.randn(1,32)
    # Return only those candidate keyword arguments whose names appear in
    # func's signature. Presumably this lets the export call work across
    # torch versions whose torch.onnx.export accepts different parameters.
61 def filtered_kwargs(func, **candidate_kwargs):
62 sig = inspect.signature(func)
    # NOTE(review): original line 63 is missing. The tokens below read as the
    # body of a dict comprehension,
    # {k: v for k, v in candidate_kwargs.items() if k in sig.parameters},
    # presumably returned from filtered_kwargs.
64 k: v
for k, v
in candidate_kwargs.items()
65 if k
in sig.parameters
    # Keyword arguments for the export, filtered against the target function's
    # signature. NOTE(review): filtered_kwargs needs `func` as its first
    # positional argument (presumably torch.onnx.export, original line 68),
    # and further candidate kwargs (71-73) are also missing from this listing.
67 kwargs = filtered_kwargs(
69 input_names=[
"input"],
70 output_names=[
"output"],
74 print(
"calling torch.onnx.export with parameters",kwargs)
    # Export the model to modelFile using the example input dummy_x.
77 torch.onnx.export(model, dummy_x, modelFile, **kwargs)
78 print(
"model exported to ONNX as",modelFile)
    # NOTE(review): the try/except around the export (and the return of
    # modelFile) are not visible (original lines 79-80). The module-level
    # caller uses this function's return value as the ONNX file path. The
    # prints below presumably form the failure branch of that try.
81 print(
"Cannot export model from pytorch to ONNX - with version ",torch.__version__)
82 print(
"Skip tutorial execution")
86def ParseModel(modelFile, verbose=False):
    # Parse an ONNX file with SOFIE's RModelParser_ONNX and generate the model
    # header, whose name is modelFile with ".onnx" replaced by ".hxx".
    #
    # Parameters:
    #   modelFile -- path to the .onnx file produced by the export step.
    #   verbose   -- forwarded to parser.Parse.
    #
    # NOTE(review): original lines 87, 90-92, 95, 97-101, 103, 105 and 108+
    # are missing from this listing (fused numbers show the gaps).
88 parser = ROOT.TMVA.Experimental.SOFIE.RModelParser_ONNX()
89 model = parser.Parse(modelFile,verbose)
    # NOTE(review): lines 90-92 are missing. The diagnostics below may be
    # guarded by a condition such as `verbose` that is not visible here;
    # confirm against upstream.
93 model.PrintInitializedTensors()
    # Fetch the float data of initialized tensors '0weight' and '2weight'.
    # These are presumably the weights of Sequential layers 0 and 2 as named
    # by the ONNX export; confirm.
    # NOTE(review): `data` is overwritten and not used in the visible lines.
    # Whatever consumed it (original 95, 97-101) is missing.
94 data = model.GetTensorData[
'float'](
'0weight')
96 data = model.GetTensorData[
'float'](
'2weight')
    # Generate the model code and print it. OutputGenerated presumably writes
    # the header whose path is computed just below; confirm.
102 model.OutputGenerated()
104 model.PrintGenerated()
    # Derive the header path from the ONNX path.
    # NOTE(review): str.replace substitutes every ".onnx" occurrence, not only
    # the suffix (e.g. "a.onnx.d/m.onnx" would be mangled). Consider
    # os.path.splitext(modelFile)[0] + ".hxx" instead.
106 modelCode = modelFile.replace(
".onnx",
".hxx")
107 print(
"Generated model header file ",modelCode)
    # NOTE(review): no return statement is visible. The module-level caller
    # uses this function's result as the header path, so presumably the
    # missing lines contain `return modelCode`.
# Driver script:
#   - train and export the PyTorch model,
#   - generate SOFIE code from the ONNX file,
#   - declare the generated header to the ROOT interpreter,
#   - run inference with SOFIE,
#   - optionally compare the result with ONNXRuntime.
#
# NOTE(review): this listing is extraction-damaged. Original line numbers are
# fused into the text (115, 116, 123, ...) and many lines are missing.
# Restore from upstream before running.
115modelName =
"LinearModel"
116modelFile = CreateAndTrainModel(modelName)
123modelCode = ParseModel(modelFile,
False)
# Make the generated header available to ROOT's C++ interpreter.
129ROOT.gInterpreter.Declare(
'#include "' + modelCode +
'"')
# Look up the generated TMVA_SOFIE_<modelName> scope on ROOT and create a
# Session for inference.
136sofie = getattr(ROOT,
'TMVA_SOFIE_' + modelName)
137session = sofie.Session()
# Random float32 input of shape (1, 32) drawn from N(0, 1). This matches the
# (1, 32) dummy input used for export.
139x = np.random.normal(0,1,(1,32)).astype(np.float32)
140print(
"\n************************************************************")
141print(
"Running inference with SOFIE ")
142print(
"\ninput to model is ",x)
# NOTE(review): `y` is not defined in the visible lines. Presumably the
# missing original lines 143-144 run inference on `session` with `x`;
# confirm against upstream.
145y_sofie = np.asarray(y.data())
146print(
"-> output using SOFIE = ", y_sofie)
    # Optional cross-check against ONNXRuntime.
    # NOTE(review): the enclosing `try:` (original ~149) and its except clause
    # (~163-166) are not visible. The final print suggests an import-failure
    # fallback when onnxruntime is unavailable.
150 import onnxruntime
as ort
152 print(
"Running inference with ONNXRuntime ")
    # Run the exported model, feeding x to the input named "input".
    # Passing None as the first argument requests all outputs.
153 ort_session = ort.InferenceSession(modelFile)
156 outputs = ort_session.run(
None, {
"input": x})
    # NOTE(review): y_ort is not assigned in the visible lines. It is
    # presumably extracted from `outputs` on the missing original line 157.
158 print(
"-> output using ORT =", y_ort)
    # Element-wise absolute-difference check with tolerance 0.01.
160 testFailed = abs(y_sofie-y_ort) > 0.01
161 if (np.any(testFailed)):
    # NOTE(review): `raiseError` is not defined anywhere in the visible
    # source. Most likely `raise RuntimeError(` (or similar) was mangled by
    # extraction. As written, a mismatch would surface as a NameError instead
    # of this message.
162 raiseError(
'Result is different between SOFIE and ONNXRT')
    # Fallback when onnxruntime is unavailable (presumably the except branch).
167 print(
"Missing ONNXRuntime: skipping comparison test")