# Name of the model, read from the first command-line argument.
38modelName = sys.argv[1]
# Context manager that watches for one specific, expected warning around the
# wrapped code. A warning counts as the expected one when its category is a
# subclass of `category` and `message` occurs in its text.
# NOTE(review): this listing omits some of the original lines, so the loop over
# the recorded warnings, the yield and the raise are not visible here.
41@contextlib.contextmanager
42def expect_warning(category, message):
43 # Silence a known third-party warning and raise if it stops firing.
45 # Notifies us to drop the workaround once the upstream library is fixed.
# Record warnings in the list `caught` rather than displaying them.
46 with warnings.catch_warnings(record=True) as caught:
# "always" turns off the default de-duplication so every occurrence is recorded.
47 warnings.simplefilter("always")
# Match test: the warning's category is a subclass of `category` and `message`
# is a substring of the warning text.
51 if issubclass(w.category, category) and message in str(w.message):
# Re-issue a recorded warning with its original message, category, file and
# line (presumably only the non-matching ones -- confirm against the full file).
54 warnings.warn_explicit(w.message, w.category, w.filename, w.lineno)
# Error text for the case where the expected warning was never emitted.
57 f"Expected {category.__name__} containing {message!r} was not "
58 "emitted. This tutorial's workaround can probably be removed."
# Build a small fully connected network, train it on random data and export it
# to "<modelName>.onnx" with torch.onnx.export.
# NOTE(review): this listing omits some of the original lines, so the training
# loop and the error handling around the export are only partly visible here.
62def CreateAndTrainModel(modelName):
# Linear layers 32 -> 16 -> 8 -> 2 with ReLU in between and a softmax over
# dim 1 on the output.
64 model = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2), nn.Softmax(dim=1))
# Mean-squared-error loss, optimised with plain SGD at learning rate 0.01.
66 criterion = nn.MSELoss()
67 optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
69 # train model with the random data
# Random input tensor of shape (2, 32).
71 x = torch.randn(2, 32)
# y_pred and y are defined on lines that are not part of this listing.
74 loss = criterion(y_pred, y)
79 # *******************************************************
82 # need to evaluate the model before exporting to ONNX
83 # and to provide a dummy input tensor to set the input model shape
86 modelFile = modelName + ".onnx"
# Dummy input of shape (1, 32) handed to the exporter.
87 dummy_x = torch.randn(1, 32)
90 # check for torch.onnx.export parameters
# Helper: keep only those keyword arguments whose names appear in the
# signature of `func`, so options that do not exist there are dropped.
91 def filtered_kwargs(func, **candidate_kwargs):
92 sig = inspect.signature(func)
93 return {k: v for k, v in candidate_kwargs.items() if k in sig.parameters}
# Candidate export options; the ones marked "may not exist" are removed by
# filtered_kwargs when the signature does not list them.
95 kwargs = filtered_kwargs(
97 input_names=["input"],
98 output_names=["output"],
99 external_data=False, # may not exist
100 dynamo=True, # may not exist
102 print("calling torch.onnx.export with parameters", kwargs)
105 # torch.onnx.export (dynamo path) pickles its export program through
106 # copyreg, which still references the deprecated LeafSpec. The warning
107 # is emitted from inside PyTorch and cannot be avoided from user code.
108 with expect_warning(FutureWarning, "isinstance(treespec, LeafSpec)"):
109 torch.onnx.export(model, dummy_x, modelFile, **kwargs)
110 print("model exported to ONNX as", modelFile)
# Failure message, printed together with the installed torch version
# (the branch that reaches it is not visible in this listing).
113 print("Cannot export model from pytorch to ONNX - with version ", torch.__version__)
114 # no .onnx file is left behind, which the parent process treats as a RuntimeError
# Create, train and export the model under the chosen name.
117CreateAndTrainModel(modelName)
# Print `data` labelled "0weight" and "2weight"; where `data` comes from is not
# visible in this listing.
130 print(
"0weight", data)
132 print(
"2weight", data)
# Report the name of the generated model header file.
142 print(
"Generated model header file ", modelCode)
# Fixed model name and the ONNX file name derived from it.
152modelName =
"LinearModel"
153modelFile = modelName +
".onnx"
# Look up the attribute named "TMVA_SOFIE_<modelName>" on the ROOT module.
177sofie =
getattr(ROOT,
"TMVA_SOFIE_" + modelName)
# Banner, then the input `x` and the SOFIE result `y_sofie`.
181print(
"\n************************************************************")
182print(
"Running inference with SOFIE ")
183print(
"\ninput to model is ", x)
187print(
"-> output using SOFIE = ", y_sofie)
# Optional cross-check against ONNXRuntime; the surrounding control flow is not
# visible in this listing.
191 import onnxruntime
as ort
194 print(
"Running inference with ONNXRuntime ")
200 print(
"-> output using ORT =", y_ort)
# Flag a mismatch where the SOFIE and ONNXRuntime outputs differ by more than
# 0.01 in absolute value.
202 testFailed = abs(y_sofie - y_ort) > 0.01
204 raise RuntimeError(
"Result is different between SOFIE and ONNXRT")
# Message used when ONNXRuntime is not available.
209 print(
"Missing ONNXRuntime: skipping comparison test")
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows typed iteration through a TCollection.