Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA_SOFIE_ONNX.py
Go to the documentation of this file.
## \file
## \ingroup tutorial_ml
## \notebook -nodraw
## This macro provides a simple example of:
## - creating a model with PyTorch and exporting it to ONNX
## - parsing the ONNX file with SOFIE and generating C++ code
## - compiling the model using ROOT Cling
## - running the code and optionally comparing with ONNXRuntime
##
##
## \macro_code
## \macro_output
## \author Lorenzo Moneta
14
15
16import torch
17import torch.nn as nn
18import ROOT
19import numpy as np
20import inspect
21
def CreateAndTrainModel(modelName):
    """Build a small dense network, train it on random data and export it to ONNX.

    Parameters
    ----------
    modelName : str
        Base name of the model; the ONNX file is written to ``modelName + ".onnx"``.

    Returns
    -------
    str
        Name of the exported ONNX file.

    Raises
    ------
    SystemExit
        Raised through ``exit()`` when ``torch.onnx.export`` fails with a
        ``TypeError``; the tutorial is then skipped.
    """
    model = nn.Sequential(
        nn.Linear(32, 16),
        nn.ReLU(),
        nn.Linear(16, 8),
        nn.ReLU(),
        nn.Linear(8, 2),
        nn.Softmax(dim=1),
    )

    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # train model with the random data
    for _ in range(500):
        x = torch.randn(2, 32)
        y = torch.randn(2, 2)
        y_pred = model(x)
        loss = criterion(y_pred, y)
        # Without these three calls the optimizer is never used and the
        # weights stay at their initial values.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # *******************************************************
    # EXPORT to ONNX
    #
    # need to evaluate the model before exporting to ONNX
    # and to provide a dummy input tensor to set the input model shape
    modelFile = modelName + ".onnx"
    dummy_x = torch.randn(1, 32)
    model(dummy_x)

    def filtered_kwargs(func, **candidate_kwargs):
        """Keep only the keyword arguments that appear in the signature of func."""
        sig = inspect.signature(func)
        return {
            k: v for k, v in candidate_kwargs.items()
            if k in sig.parameters
        }

    # check for torch.onnx.export parameters: the function to inspect must be
    # passed explicitly, otherwise filtered_kwargs itself raises a TypeError.
    kwargs = filtered_kwargs(
        torch.onnx.export,
        input_names=["input"],
        output_names=["output"],
        external_data=False,  # may not exist
        dynamo=True,  # may not exist
    )
    print("calling torch.onnx.export with parameters", kwargs)

    try:
        torch.onnx.export(model, dummy_x, modelFile, **kwargs)
        print("model exported to ONNX as", modelFile)
        return modelFile
    except TypeError:
        print("Cannot export model from pytorch to ONNX - with version ", torch.__version__)
        print("Skip tutorial execution")
        exit()
82
83
def ParseModel(modelFile, verbose=False):
    """Parse an ONNX file with SOFIE and generate the C++ inference header.

    Parameters
    ----------
    modelFile : str
        Name of the ONNX file to parse.
    verbose : bool
        When true, it is forwarded to the parser and some model weights and
        the generated code are printed.

    Returns
    -------
    str
        Name of the generated header: ``modelFile`` with ".onnx" replaced by ".hxx".
    """
    # The parser object must be created before it can be used.
    # NOTE(review): class name taken from the ROOT SOFIE API - confirm
    # against the installed ROOT version.
    parser = ROOT.TMVA.Experimental.SOFIE.RModelParser_ONNX()
    model = parser.Parse(modelFile, verbose)
    #
    # print model weights
    if verbose:
        # NOTE(review): the model presumably has to be initialized before its
        # tensor data can be read - confirm.
        model.Initialize()
        data = model.GetTensorData['float']('0weight')
        print("0weight", data)
        data = model.GetTensorData['float']('2weight')
        print("2weight", data)

    # Generating inference code
    model.Generate()
    # generate header file (and .dat file) with modelName+.hxx
    model.OutputGenerated()
    if verbose:
        model.PrintGenerated()

    modelCode = modelFile.replace(".onnx", ".hxx")
    print("Generated model header file ", modelCode)
    return modelCode
107
###################################################################
## Step 1 : Create and Train model
###################################################################

# use an arbitrary modelName
modelName = "LinearModel"
modelFile = CreateAndTrainModel(modelName)


###################################################################
## Step 2 : Parse model and generate inference code with SOFIE
###################################################################

modelCode = ParseModel(modelFile, False)

###################################################################
## Step 3 : Compile the generated C++ model code
###################################################################

ROOT.gInterpreter.Declare('#include "' + modelCode + '"')

###################################################################
## Step 4: Evaluate the model
###################################################################

# get first the SOFIE session namespace
sofie = getattr(ROOT, 'TMVA_SOFIE_' + modelName)
session = sofie.Session()

x = np.random.normal(0, 1, (1, 32)).astype(np.float32)
print("\n************************************************************")
print("Running inference with SOFIE ")
print("\ninput to model is ", x)
y = session.infer(x)
# output shape is (1,2)
y_sofie = np.asarray(y.data())
print("-> output using SOFIE = ", y_sofie)

# check inference with onnx
try:
    import onnxruntime as ort
    # Load model
    print("Running inference with ONNXRuntime ")
    ort_session = ort.InferenceSession(modelFile)

    # Run inference; "input" is the input name given at export time
    outputs = ort_session.run(None, {"input": x})
    y_ort = outputs[0]
    print("-> output using ORT =", y_ort)

    testFailed = abs(y_sofie - y_ort) > 0.01
    if np.any(testFailed):
        # Raise a real exception: the previously called raiseError is not
        # defined anywhere and would only produce a NameError.
        raise RuntimeError('Result is different between SOFIE and ONNXRT')
    else:
        print("OK")

except ImportError:
    print("Missing ONNXRuntime: skipping comparison test")
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.