Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA_SOFIE_ONNX.py
Go to the documentation of this file.
## \file
## \ingroup tutorial_ml
## \notebook -nodraw
## This macro provides a simple example for:
##  - creating a model with PyTorch and exporting it to ONNX
##  - parsing the ONNX file with SOFIE and generating C++ code
##  - compiling the model using ROOT Cling
##  - running the code and optionally comparing with ONNXRuntime
##
## The PyTorch exporter and ROOT's SOFIE parser are both linked against protobuf,
## but usually against different versions, so loading them in the same process
## leads to a symbol clash. We therefore run the PyTorch -> ONNX export in a
## separate Python process, so PyTorch is never loaded into the process that
## imports ROOT.
##
## \macro_code
## \macro_output
## \author Lorenzo Moneta
18
19import os
20import sys
21import subprocess
22
23import numpy as np
24import ROOT
25
26
# The PyTorch export, as a small standalone script run in its own process.
# It takes the model name as its only argument and writes <modelName>.onnx.
# If the export API is incompatible, it exits without writing that file,
# and the parent process turns the missing file into an error.
EXPORT_SCRIPT = r"""
import sys
import inspect
import warnings
import contextlib

import torch
import torch.nn as nn

modelName = sys.argv[1]


@contextlib.contextmanager
def expect_warning(category, message):
    # Silence a known third-party warning and raise if it stops firing,
    # which notifies us to drop the workaround once upstream is fixed.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        yield
    # Process the recorded warnings only after catch_warnings has restored the
    # caller's filters: re-emitting them while still recording would append
    # them to `caught` again and they would never be shown.
    seen = False
    for w in caught:
        if issubclass(w.category, category) and message in str(w.message):
            seen = True
        else:
            warnings.warn_explicit(w.message, w.category, w.filename, w.lineno)
    if not seen:
        raise RuntimeError(
            f"Expected {category.__name__} containing {message!r} was not "
            "emitted. This tutorial's workaround can probably be removed."
        )


def CreateAndTrainModel(modelName):

    model = nn.Sequential(
        nn.Linear(32, 16),
        nn.ReLU(),
        nn.Linear(16, 8),
        nn.ReLU(),
        nn.Linear(8, 2),
        nn.Softmax(dim=1),
    )

    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # train model with the random data
    for _ in range(500):
        x = torch.randn(2, 32)
        y = torch.randn(2, 2)
        y_pred = model(x)
        loss = criterion(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # *******************************************************
    ## EXPORT to ONNX
    #
    # need to evaluate the model before exporting to ONNX
    # and to provide a dummy input tensor to set the input model shape
    model.eval()

    modelFile = modelName + ".onnx"
    dummy_x = torch.randn(1, 32)
    model(dummy_x)

    # keep only the torch.onnx.export keyword arguments this PyTorch accepts
    def filtered_kwargs(func, **candidate_kwargs):
        sig = inspect.signature(func)
        return {k: v for k, v in candidate_kwargs.items() if k in sig.parameters}

    kwargs = filtered_kwargs(
        torch.onnx.export,
        input_names=["input"],
        output_names=["output"],
        external_data=False,  # may not exist
        dynamo=True,  # may not exist
    )
    print("calling torch.onnx.export with parameters", kwargs)

    # torch.onnx.export (dynamo path) pickles its export program through
    # copyreg, which still references the deprecated LeafSpec. The warning
    # is emitted from inside PyTorch and cannot be avoided from user code.
    # The legacy exporter (no `dynamo` argument) does not take that path,
    # so the warning is only expected when dynamo export is requested.
    if kwargs.get("dynamo"):
        guard = expect_warning(FutureWarning, "isinstance(treespec, LeafSpec)")
    else:
        guard = contextlib.nullcontext()

    try:
        with guard:
            torch.onnx.export(model, dummy_x, modelFile, **kwargs)
    except TypeError:
        print("Cannot export model from pytorch to ONNX - with version ", torch.__version__)
        # Exit without writing the .onnx file; the parent process treats the
        # missing file as a RuntimeError.
        sys.exit()

    print("model exported to ONNX as", modelFile)
    return modelFile


CreateAndTrainModel(modelName)
"""
119
120
def ParseModel(modelFile, verbose=False):
    """Parse an ONNX file with SOFIE and generate the C++ inference code.

    Parameters
    ----------
    modelFile : str
        Path of the ``.onnx`` file to parse.
    verbose : bool
        If True, also print the parsed weights and the generated code.

    Returns
    -------
    str
        Name of the generated header, i.e. ``modelFile`` with ``.onnx``
        replaced by ``.hxx``.
    """
    parser = ROOT.TMVA.Experimental.SOFIE.RModelParser_ONNX()
    model = parser.Parse(modelFile, verbose)
    #
    # print model weights
    if verbose:
        model.PrintInitializedTensors()
        data = model.GetTensorData["float"]("0weight")
        print("0weight", data)
        data = model.GetTensorData["float"]("2weight")
        print("2weight", data)

    # Generating inference code
    model.Generate()
    # generate header file (and .dat file) with modelName+.hxx
    model.OutputGenerated()
    if verbose:
        model.PrintGenerated()

    modelCode = modelFile.replace(".onnx", ".hxx")
    print("Generated model header file ", modelCode)
    return modelCode
144
145
###################################################################
## Step 1 : Create and train the model, export it to ONNX
## (done in a separate process to avoid the protobuf clash)
###################################################################

# use an arbitrary modelName
modelName = "LinearModel"
modelFile = modelName + ".onnx"

# The existence check below is the only success signal from the export
# process, so remove any file left over from a previous run: a stale file
# would otherwise hide a failed export.
if os.path.exists(modelFile):
    os.remove(modelFile)

subprocess.run([sys.executable, "-c", EXPORT_SCRIPT, modelName])
if not os.path.exists(modelFile):
    raise RuntimeError("ONNX model could not be exported")
158
159
###################################################################
## Step 2 : Parse model and generate inference code with SOFIE
###################################################################

modelCode = ParseModel(modelFile, verbose=False)

###################################################################
## Step 3 : Compile the generated C++ model code
###################################################################

# hand the generated header to Cling so the model's session class
# becomes reachable from Python through the ROOT module
ROOT.gInterpreter.Declare(f'#include "{modelCode}"')
171
###################################################################
## Step 4: Evaluate the model
###################################################################

# get first the SOFIE session namespace
sofie = getattr(ROOT, "TMVA_SOFIE_" + modelName)
session = sofie.Session()

x = np.random.normal(0, 1, (1, 32)).astype(np.float32)
print("\n************************************************************")
print("Running inference with SOFIE ")
print("\ninput to model is ", x)
y = session.infer(x)
# the model output has shape (1, 2)
y_sofie = np.asarray(y.data())
print("-> output using SOFIE = ", y_sofie)

# check inference with onnx; only the import itself is optional, so an
# ImportError raised later is not mistaken for a missing ONNXRuntime
try:
    import onnxruntime as ort
except ImportError:
    print("Missing ONNXRuntime: skipping comparison test")
else:
    # Load model. The provider is given explicitly because ONNX Runtime
    # builds with several execution providers require it.
    print("Running inference with ONNXRuntime ")
    ort_session = ort.InferenceSession(modelFile, providers=["CPUExecutionProvider"])

    # Run inference
    outputs = ort_session.run(None, {"input": x})
    y_ort = outputs[0]
    print("-> output using ORT =", y_ort)

    # compare element-wise, independent of how each library shapes the output
    testFailed = np.abs(np.ravel(y_sofie) - np.ravel(y_ort)) > 0.01
    if np.any(testFailed):
        raise RuntimeError("Result is different between SOFIE and ONNXRT")
    print("OK")
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.