import os
import sys
import subprocess
import numpy as np
import ROOT
# Source of a standalone helper script: it trains a small PyTorch model on
# random data and exports it to ONNX as "<modelName>.onnx", with modelName read
# from sys.argv[1]. It is kept as a string so it can run in a separate Python
# process (its TypeError handler refers to a "parent process"); the code that
# launches it is not visible in this chunk.
# The text is parsed as Python source, so its indentation is significant.
EXPORT_SCRIPT = r"""
import sys
import inspect
import warnings
import contextlib
import torch
import torch.nn as nn

modelName = sys.argv[1]


@contextlib.contextmanager
def expect_warning(category, message):
    # Silence a known third-party warning and raise if it stops firing.
    # Notifies us to drop the workaround once the upstream library is fixed.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        yield
    # Scan the recorded warnings only after the recording context has exited:
    # re-emitting them inside it (filter "always") would append to `caught`
    # while it is being iterated. Unrelated warnings go to the normal filters.
    seen = False
    for w in caught:
        if issubclass(w.category, category) and message in str(w.message):
            seen = True
        else:
            warnings.warn_explicit(w.message, w.category, w.filename, w.lineno)
    if not seen:
        raise RuntimeError(
            f"Expected {category.__name__} containing {message!r} was not "
            "emitted. This tutorial's workaround can probably be removed."
        )


def CreateAndTrainModel(modelName):
    # Small fully connected network: 32 inputs -> 2 softmax outputs.
    model = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2), nn.Softmax(dim=1))
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # train model with the random data
    for i in range(500):
        x = torch.randn(2, 32)
        y = torch.randn(2, 2)
        y_pred = model(x)
        loss = criterion(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # *******************************************************
    ## EXPORT to ONNX
    #
    # need to evaluate the model before exporting to ONNX
    # and to provide a dummy input tensor to set the input model shape
    model.eval()
    modelFile = modelName + ".onnx"
    dummy_x = torch.randn(1, 32)
    model(dummy_x)

    # check for torch.onnx.export parameters
    def filtered_kwargs(func, **candidate_kwargs):
        # Keep only the keyword arguments that `func` accepts, so the same
        # call works with torch versions lacking some of them.
        sig = inspect.signature(func)
        return {k: v for k, v in candidate_kwargs.items() if k in sig.parameters}

    kwargs = filtered_kwargs(
        torch.onnx.export,
        input_names=["input"],
        output_names=["output"],
        external_data=False,  # may not exist
        dynamo=True,  # may not exist
    )
    print("calling torch.onnx.export with parameters", kwargs)
    try:
        # torch.onnx.export (dynamo path) pickles its export program through
        # copyreg, which still references the deprecated LeafSpec. The warning
        # is emitted from inside PyTorch and cannot be avoided from user code.
        with expect_warning(FutureWarning, "isinstance(treespec, LeafSpec)"):
            torch.onnx.export(model, dummy_x, modelFile, **kwargs)
        print("model exported to ONNX as", modelFile)
        return modelFile
    except TypeError:
        print("Cannot export model from pytorch to ONNX - with version ", torch.__version__)
        # Exit (status 0) without producing an .onnx file; the parent process
        # treats the missing file as a RuntimeError.
        sys.exit()


CreateAndTrainModel(modelName)
"""
if verbose:
print("0weight", data)
print("2weight", data)
if verbose:
print("Generated model header file ", modelCode)
return modelCode
# Model name; the ONNX file name and the SOFIE-generated namespace derive from it.
modelName = "LinearModel"
modelFile = modelName + ".onnx"

# Namespace of the SOFIE-generated inference code (TMVA_SOFIE_<modelName>).
# NOTE(review): this lookup presumably only resolves once the generated header
# has been loaded into ROOT's interpreter; that step is not visible here.
sofie = getattr(ROOT, "TMVA_SOFIE_" + modelName)

print("\n************************************************************")
print("Running inference with SOFIE ")
# NOTE(review): `x` (model input) and `y_sofie` (SOFIE result) are used below
# but never assigned in the visible code. The step that creates a session from
# `sofie` and runs inference appears to be missing. The captured output shows
# `x` as a (1, 32) array and `y_sofie` as 2 values; confirm against the full
# tutorial.
print("\ninput to model is ", x)
print("-> output using SOFIE = ", y_sofie)

# Optional cross-check against ONNXRuntime; skipped when it is not installed.
# Only the import sits in the try, so unrelated ImportErrors are not masked.
try:
    import onnxruntime as ort
except ImportError:
    print("Missing ONNXRuntime: skipping comparison test")
else:
    print("Running inference with ONNXRuntime ")
    # Pin the CPU provider: builds with several providers refuse a session
    # created without an explicit list.
    ort_session = ort.InferenceSession(modelFile, providers=["CPUExecutionProvider"])
    # Query the input name rather than hard-coding the one chosen at export.
    input_name = ort_session.get_inputs()[0].name
    outputs = ort_session.run(None, {input_name: np.asarray(x, dtype=np.float32)})
    y_ort = outputs[0]
    print("-> output using ORT =", y_ort)
    # Flatten both results so the comparison does not depend on whether either
    # side keeps the batch dimension.
    tolerance = 0.01
    testFailed = np.abs(np.ravel(y_sofie) - np.ravel(y_ort)) > tolerance
    if np.any(testFailed):
        raise RuntimeError(
            "Result is different between SOFIE and ONNXRT")
    print("OK")
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
calling torch.onnx.export with parameters {'input_names': ['input'], 'output_names': ['output'], 'external_data': False, 'dynamo': True}
[torch.onnx] Obtain model graph for `Sequential([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `Sequential([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decompositions...
[torch.onnx] Run decompositions... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
[torch.onnx] Optimize the ONNX graph...
[torch.onnx] Optimize the ONNX graph... ✅
model exported to ONNX as LinearModel.onnx
Generated model header file LinearModel.hxx
************************************************************
Running inference with SOFIE
input to model is [[ 1.2114118 1.451967 -0.14347067 0.97046924 0.16031976 -0.32432693
-0.15014982 -0.93951786 -1.4363766 1.892936 0.69244826 -1.468675
-1.1324115 -0.5894106 0.5675628 -1.4235257 -0.2687674 -1.5356702
0.01113858 -0.98209095 2.127685 -0.2864464 0.75798196 0.4604458
0.02719817 1.2260139 -1.01243 -0.6166083 0.36727965 -1.4795494
1.7457016 -1.143139 ]]
-> output using SOFIE = [0.51601905 0.48398095]
Missing ONNXRuntime: skipping comparison test