Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA_SOFIE_ONNX.py
Go to the documentation of this file.
## \file
## \ingroup tutorial_ml
## \notebook -nodraw
## This macro provides a simple example for:
## - creating a model with Pytorch and export to ONNX
## - parsing the ONNX file with SOFIE and generate C++ code
## - compiling the model using ROOT Cling
## - run the code and optionally compare with ONNXRuntime
##
##
## \macro_code
## \macro_output
## \author Lorenzo Moneta
14
15
import inspect

import numpy as np
import ROOT
import torch
import torch.nn as nn
22
23
def CreateAndTrainModel(modelName):
    """Create a small Pytorch MLP, train it on random data and export it to ONNX.

    Parameters:
        modelName: base name (without extension) used for the ONNX output file.

    Returns:
        The name of the written ONNX file (modelName + ".onnx").
        If the export fails the tutorial execution is skipped via exit().
    """

    # 32 -> 16 -> 8 -> 2 fully connected network with a softmax output
    model = nn.Sequential(
        nn.Linear(32, 16),
        nn.ReLU(),
        nn.Linear(16, 8),
        nn.ReLU(),
        nn.Linear(8, 2),
        nn.Softmax(dim=1)
    )

    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # train model with the random data (the tutorial only needs trained
    # weights, not a meaningful model)
    for i in range(500):
        x = torch.randn(2, 32)
        y = torch.randn(2, 2)
        y_pred = model(x)
        loss = criterion(y_pred, y)
        # standard Pytorch optimisation step (was lost in the scraped copy:
        # without it the loop never updates the weights)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # *******************************************************
    ## EXPORT to ONNX
    #
    # need to evaluate the model before exporting to ONNX
    # and to provide a dummy input tensor to set the input model shape

    modelFile = modelName + ".onnx"
    dummy_x = torch.randn(1, 32)
    model(dummy_x)

    # torch.onnx.export keyword arguments differ between torch versions:
    # keep only the candidates accepted by the installed version
    def filtered_kwargs(func, **candidate_kwargs):
        sig = inspect.signature(func)
        return {
            k: v for k, v in candidate_kwargs.items()
            if k in sig.parameters
        }

    kwargs = filtered_kwargs(
        torch.onnx.export,  # the function whose signature is filtered against
        input_names=["input"],
        output_names=["output"],
        external_data=False,  # may not exist
        dynamo=True           # may not exist
    )
    print("calling torch.onnx.export with parameters", kwargs)

    try:
        torch.onnx.export(model, dummy_x, modelFile, **kwargs)
        print("model exported to ONNX as", modelFile)
        return modelFile
    except TypeError:
        print("Cannot export model from pytorch to ONNX - with version ", torch.__version__)
        print("Skip tutorial execution")
        exit()
84
85
def ParseModel(modelFile, verbose=False):
    """Parse an ONNX file with SOFIE and generate the C++ inference code.

    Parameters:
        modelFile: path to the ONNX model file.
        verbose:   if True, print the model weights and the generated code.

    Returns:
        The name of the generated C++ header file (modelFile with ".hxx").
    """

    # Create the SOFIE ONNX parser and parse the file into an RModel
    # (this line was lost in the scraped copy: `parser` was used undefined).
    parser = ROOT.TMVA.Experimental.SOFIE.RModelParser_ONNX()
    model = parser.Parse(modelFile, verbose)
    #
    # print model weights (tensor names follow the ONNX initializer naming)
    if (verbose):
        data = model.GetTensorData['float']('0weight')
        print("0weight", data)
        data = model.GetTensorData['float']('2weight')
        print("2weight", data)

    # Generating inference code
    # NOTE(review): the Generate/OutputGenerated calls were lost in the
    # scraped copy; reconstructed from the SOFIE RModel API — confirm.
    model.Generate()
    # generate header file (and .dat file with the weights) with modelName+.hxx
    model.OutputGenerated()
    if (verbose):
        model.PrintGenerated()

    modelCode = modelFile.replace(".onnx", ".hxx")
    print("Generated model header file ", modelCode)
    return modelCode
109
###################################################################
## Step 1 : Create and Train model
###################################################################

# use an arbitrary modelName
modelName = "LinearModel"
modelFile = CreateAndTrainModel(modelName)


###################################################################
## Step 2 : Parse model and generate inference code with SOFIE
###################################################################

modelCode = ParseModel(modelFile, False)

###################################################################
## Step 3 : Compile the generated C++ model code
###################################################################

# JIT-compile the generated header with ROOT Cling
ROOT.gInterpreter.Declare('#include "' + modelCode + '"')

###################################################################
## Step 4: Evaluate the model
###################################################################

# get first the SOFIE session namespace (generated as TMVA_SOFIE_<modelName>)
sofie = getattr(ROOT, 'TMVA_SOFIE_' + modelName)
session = sofie.Session()

x = np.random.normal(0, 1, (1, 32)).astype(np.float32)
print("\n************************************************************")
print("Running inference with SOFIE ")
print("\ninput to model is ", x)
y = session.infer(x)
# output shape is (1,2)
y_sofie = np.asarray(y.data())
print("-> output using SOFIE = ", y_sofie)

# check inference with onnx, when onnxruntime is available
try:
    import onnxruntime as ort
    # Load model
    print("Running inference with ONNXRuntime ")
    ort_session = ort.InferenceSession(modelFile)

    # Run inference
    outputs = ort_session.run(None, {"input": x})
    y_ort = outputs[0]
    print("-> output using ORT =", y_ort)

    # element-wise tolerance check between the two implementations
    testFailed = abs(y_sofie - y_ort) > 0.01
    if (np.any(testFailed)):
        # fix: the original text called undefined `raiseError(...)` (a
        # NameError); raise a real exception instead
        raise RuntimeError('Result is different between SOFIE and ONNXRT')
    else:
        print("OK")

except ImportError:
    print("Missing ONNXRuntime: skipping comparison test")
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.