import os
import sys
import subprocess
import numpy as np
import ROOT
# Source of a standalone PyTorch script kept as a string. Its own comments
# refer to a "parent process", so it is meant to run in a separate Python
# interpreter (presumably launched via the `subprocess` module imported above;
# the launching code is not part of this section). Called with the model name
# as its only argument (sys.argv[1]), it trains a small network on random data
# and writes "<modelName>.onnx".
# The embedded code is indentation-sensitive: keep the indentation intact.
EXPORT_SCRIPT = r"""
import sys
import inspect
import warnings
import contextlib

import torch
import torch.nn as nn

modelName = sys.argv[1]


@contextlib.contextmanager
def expect_warning(category, message):
    # Silence a known third-party warning and raise if it stops firing.
    # Notifies us to drop the workaround once the upstream library is fixed.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        yield

    # Back outside catch_warnings: re-emit every unrelated warning so it
    # reaches the normal handlers. (Doing this inside the block would append
    # to `caught` while iterating it.)
    seen = False
    for w in caught:
        if issubclass(w.category, category) and message in str(w.message):
            seen = True
        else:
            warnings.warn_explicit(w.message, w.category, w.filename, w.lineno)

    if not seen:
        raise RuntimeError(
            f"Expected {category.__name__} containing {message!r} was not "
            "emitted. This tutorial's workaround can probably be removed."
        )


def CreateAndTrainModel(modelName):
    # Train a small MLP on random data and export it to <modelName>.onnx.
    # Returns the ONNX file name on success; if torch.onnx.export rejects the
    # call with a TypeError, exits the process (sys.exit) instead.
    model = nn.Sequential(
        nn.Linear(32, 16), nn.ReLU(),
        nn.Linear(16, 8), nn.ReLU(),
        nn.Linear(8, 2), nn.Softmax(dim=1),
    )
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # train model with the random data
    for _ in range(500):
        x = torch.randn(2, 32)
        y = torch.randn(2, 2)
        y_pred = model(x)
        loss = criterion(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # *******************************************************
    ## EXPORT to ONNX
    #
    # need to evaluate the model before exporting to ONNX
    # and to provide a dummy input tensor to set the input model shape
    model.eval()

    modelFile = modelName + ".onnx"
    dummy_x = torch.randn(1, 32)
    model(dummy_x)

    # check for torch.onnx.export parameters: only pass the keyword arguments
    # that the installed torch version actually accepts
    def filtered_kwargs(func, **candidate_kwargs):
        sig = inspect.signature(func)
        return {k: v for k, v in candidate_kwargs.items() if k in sig.parameters}

    kwargs = filtered_kwargs(
        torch.onnx.export,
        input_names=["input"],
        output_names=["output"],
        external_data=False,  # may not exist
        dynamo=True,  # may not exist
    )
    print("calling torch.onnx.export with parameters", kwargs)

    try:
        # torch.onnx.export (dynamo path) pickles its export program through
        # copyreg, which still references the deprecated LeafSpec. The warning
        # is emitted from inside PyTorch and cannot be avoided from user code.
        with expect_warning(FutureWarning, "isinstance(treespec, LeafSpec)"):
            torch.onnx.export(model, dummy_x, modelFile, **kwargs)
        print("model exported to ONNX as", modelFile)
        return modelFile
    except TypeError:
        print("Cannot export model from pytorch to ONNX - with version ", torch.__version__)
        # Exit without writing an .onnx file; the parent process treats the
        # missing file as a RuntimeError.
        sys.exit()


CreateAndTrainModel(modelName)
"""
# NOTE(review): these lines are the tail of a function body whose `def` line
# and earlier statements are missing from this file, and whose indentation
# has been lost. As written this does not run: `verbose`, `data` and
# `modelCode` are never assigned here, and `return` at module level is a
# SyntaxError. Restore the complete function from the original tutorial
# source rather than patching these lines in place.
# Presumably (unconfirmed) it returns the name of a generated model header;
# see the message printed below.
if verbose:
print("0weight", data)
print("2weight", data)
# Report the generated header name only in verbose mode, then return it.
if verbose:
print("Generated model header file ", modelCode)
return modelCode
# Main script: run inference with SOFIE and, if onnxruntime is installed,
# compare the result against ONNXRuntime.
# NOTE(review): this section is incomplete as extracted, and its indentation
# has been lost:
#   - `x`, `y_sofie` and `outputs` are used but never assigned here; the steps
#     that produce them are missing;
#   - `sofie =` followed by a bare line break is a SyntaxError (needs
#     parentheses or a backslash to continue onto the `getattr` line);
#   - the `if` guarding the `raise` below is missing, leaving `else:` unmatched.
modelName = "LinearModel"
modelFile = modelName + ".onnx"
# Look up the SOFIE-generated attribute "TMVA_SOFIE_<modelName>" on the ROOT
# module.
sofie =
getattr(ROOT,
"TMVA_SOFIE_" + modelName)
print("\n************************************************************")
print("Running inference with SOFIE ")
print("\ninput to model is ", x)
print("-> output using SOFIE = ", y_sofie)
# onnxruntime is optional: an ImportError skips the comparison.
try:
import onnxruntime as ort
print("Running inference with ONNXRuntime ")
# NOTE(review): `outputs` should come from running the ONNX model through
# onnxruntime; that call is missing.
y_ort = outputs[0]
print("-> output using ORT =", y_ort)
# Element-wise tolerance check (absolute difference > 0.01). The SOFIE output
# in the log is printed as an array of two values, so the missing guard
# presumably needs np.any(testFailed) rather than a plain truth test - confirm.
testFailed = abs(y_sofie - y_ort) > 0.01
raise RuntimeError(
"Result is different between SOFIE and ONNXRT")
else:
print("OK")
except ImportError:
print("Missing ONNXRuntime: skipping comparison test")
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
calling torch.onnx.export with parameters {'input_names': ['input'], 'output_names': ['output'], 'external_data': False, 'dynamo': True}
[torch.onnx] Obtain model graph for `Sequential([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `Sequential([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decompositions...
[torch.onnx] Run decompositions... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
[torch.onnx] Optimize the ONNX graph...
[torch.onnx] Optimize the ONNX graph... ✅
model exported to ONNX as LinearModel.onnx
Generated model header file LinearModel.hxx
************************************************************
Running inference with SOFIE
input to model is [[ 0.7331435 -1.1810836 0.8479717 -2.8299522 1.6055895 0.958995
-1.1113143 0.17913957 0.94363946 -0.18613863 -0.2952656 -0.5260549
0.36408323 0.7399261 -2.2180364 1.9918501 -0.21362494 0.12576789
0.6231531 -1.6836269 -0.78191465 1.2402849 0.01697755 1.7398267
0.27317137 1.2935599 1.0760139 -0.02842077 -1.68563 0.2623278
0.33979574 0.50666916]]
-> output using SOFIE = [0.47936755 0.5206325 ]
Missing ONNXRuntime: skipping comparison test