RModelParser_ONNX.cxx
#include "TMVA/RModelParser_ONNX.hxx"
#include "onnx_proto3.pb.h"

#include <stdexcept>
#include <string>
#include <memory>
#include <cassert>
#include <iostream>
#include <unordered_map>
#include <unordered_set>
#include <map>
#include <functional>
#include <fstream>
#include <ctime>
#include <cstring>
#include <cstdlib>
#include "TMVA/SOFIE_common.hxx"

namespace TMVA {
namespace Experimental {
namespace SOFIE {

// Declaration of operators
// Unary operators
extern ParserFuncSignature ParseSqrt;
extern ParserFuncSignature ParseReciprocal;
extern ParserFuncSignature ParseNeg;
extern ParserFuncSignature ParseExp;
extern ParserFuncSignature ParseLog;
extern ParserFuncSignature ParseErf;
// Binary operators
extern ParserFuncSignature ParseAdd;
extern ParserFuncSignature ParseSub;
extern ParserFuncSignature ParseMul;
extern ParserFuncSignature ParseDiv;
extern ParserFuncSignature ParsePow;
// Nary operators
extern ParserFuncSignature ParseMax;
extern ParserFuncSignature ParseMin;
extern ParserFuncSignature ParseMean;
extern ParserFuncSignature ParseSum;
// Reduce operators
extern ParserFuncSignature ParseReduceMean;
extern ParserFuncSignature ParseReduceSumsquare;
extern ParserFuncSignature ParseReduceProd;
// Others
extern ParserFuncSignature ParseBatchNormalization;
extern ParserFuncSignature ParseCast;
extern ParserFuncSignature ParseConcat;
extern ParserFuncSignature ParseConv;
extern ParserFuncSignature ParseConvTranspose;
extern ParserFuncSignature ParseGemm;
extern ParserFuncSignature ParseGRU;
extern ParserFuncSignature ParseIdentity;
extern ParserFuncSignature ParseLeakyRelu;
extern ParserFuncSignature ParseLSTM;
extern ParserFuncSignature ParsePool;
extern ParserFuncSignature ParseRelu;
extern ParserFuncSignature ParseReshape;
extern ParserFuncSignature ParseRNN;
extern ParserFuncSignature ParseSelu;
extern ParserFuncSignature ParseShape;
extern ParserFuncSignature ParseSigmoid;
extern ParserFuncSignature ParseSlice;
extern ParserFuncSignature ParseSoftmax;
extern ParserFuncSignature ParseTanh;
extern ParserFuncSignature ParseTranspose;
extern ParserFuncSignature ParseMatMul;
extern ParserFuncSignature ParseLayerNormalization;
extern ParserFuncSignature ParseExpand;
extern ParserFuncSignature ParseGather;
// Declaration of fused operators
extern ParserFuseFuncSignature ParseFuseMatMulAdd;
extern ParserFuseFuncSignature ParseFuseConvAdd;
extern ParserFuseFuncSignature ParseFuseConvTransposeAdd;
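
// The registry of operator parser functions is kept behind a Pimpl
// (the fOperatorsMapImpl unique_ptr member), so the container stays an
// implementation detail of this translation unit.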
// Definition of RModelParser_ONNX::OperatorsMapImpl
struct RModelParser_ONNX::OperatorsMapImpl {
   // Registered operators
   std::unordered_map<std::string, ParserFuncSignature> fOperatorsMap;
};

// Constructor of the parser
RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_unique<OperatorsMapImpl>()) {
   // Register operators
   // Unary operators
   RegisterOperator("Sqrt", ParseSqrt);
   RegisterOperator("Reciprocal", ParseReciprocal);
   RegisterOperator("Neg", ParseNeg);
   RegisterOperator("Exp", ParseExp);
   RegisterOperator("Log", ParseLog);
   RegisterOperator("Erf", ParseErf);
   // Binary operators
   RegisterOperator("Add", ParseAdd);
   RegisterOperator("Sub", ParseSub);
   RegisterOperator("Mul", ParseMul);
   RegisterOperator("Div", ParseDiv);
   RegisterOperator("Pow", ParsePow);
   // Nary operators
   RegisterOperator("Max", ParseMax);
   RegisterOperator("Min", ParseMin);
   RegisterOperator("Mean", ParseMean);
   RegisterOperator("Sum", ParseSum);
   // Reduce operators
   RegisterOperator("ReduceMean", ParseReduceMean);
   RegisterOperator("ReduceSumsquare", ParseReduceSumsquare);
   RegisterOperator("ReduceProd", ParseReduceProd);
   // Others
   RegisterOperator("BatchNormalization", ParseBatchNormalization);
   RegisterOperator("Cast", ParseCast);
   RegisterOperator("Concat", ParseConcat);
   RegisterOperator("Conv", ParseConv);
   RegisterOperator("ConvTranspose", ParseConvTranspose);
   RegisterOperator("Gemm", ParseGemm);
   RegisterOperator("GRU", ParseGRU);
   RegisterOperator("Identity", ParseIdentity);
   RegisterOperator("LeakyRelu", ParseLeakyRelu);
   RegisterOperator("LSTM", ParseLSTM);
   RegisterOperator("AveragePool", ParsePool);
   RegisterOperator("GlobalAveragePool", ParsePool);
   RegisterOperator("MaxPool", ParsePool);
   RegisterOperator("Relu", ParseRelu);
   RegisterOperator("Reshape", ParseReshape);
   RegisterOperator("Flatten", ParseReshape);
   RegisterOperator("Squeeze", ParseReshape);
   RegisterOperator("Unsqueeze", ParseReshape);
   RegisterOperator("RNN", ParseRNN);
   RegisterOperator("Selu", ParseSelu);
   RegisterOperator("Shape", ParseShape);
   RegisterOperator("Sigmoid", ParseSigmoid);
   RegisterOperator("Slice", ParseSlice);
   RegisterOperator("Softmax", ParseSoftmax);
   RegisterOperator("Tanh", ParseTanh);
   RegisterOperator("Transpose", ParseTranspose);
   RegisterOperator("MatMul", ParseMatMul);
   RegisterOperator("LayerNormalization", ParseLayerNormalization);
   RegisterOperator("Expand", ParseExpand);
   RegisterOperator("Gather", ParseGather);
}
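
// The registry is extensible: user code can register a parser function for an
// ONNX operator that is not supported out of the box. A minimal sketch, with a
// hypothetical MyRelu6 operator and MyRelu6Operator class (not part of SOFIE):
//
//    ParserFuncSignature ParseMyRelu6 = [](RModelParser_ONNX & /*parser*/,
//                                          const onnx::NodeProto &node) -> std::unique_ptr<ROperator> {
//       // build the ROperator from the node's inputs, outputs and attributes
//       return std::make_unique<MyRelu6Operator>(node.input(0), node.output(0));
//    };
//
//    RModelParser_ONNX parser;
//    parser.RegisterOperator("MyRelu6", ParseMyRelu6);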

// Destructor of the parser
RModelParser_ONNX::~RModelParser_ONNX() = default;

void RModelParser_ONNX::RegisterOperator(const std::string &name, ParserFuncSignature func)
{
   fOperatorsMapImpl->fOperatorsMap[name] = func;
}

bool RModelParser_ONNX::IsRegisteredOperator(const std::string &name)
{
   return fOperatorsMapImpl->fOperatorsMap.find(name) != fOperatorsMapImpl->fOperatorsMap.end();
}

std::vector<std::string> RModelParser_ONNX::GetRegisteredOperators()
{
   std::vector<std::string> ops;
   ops.reserve(fOperatorsMapImpl->fOperatorsMap.size());
   for (auto &it : fOperatorsMapImpl->fOperatorsMap) {
      ops.emplace_back(it.first);
   }
   return ops;
}

void RModelParser_ONNX::RegisterTensorType(const std::string &name, ETensorType type)
{
   fTensorTypeMap[UTILITY::Clean_name(name)] = type;
}

bool RModelParser_ONNX::IsRegisteredTensorType(const std::string &name)
{
   return fTensorTypeMap.find(UTILITY::Clean_name(name)) != fTensorTypeMap.end();
}

ETensorType RModelParser_ONNX::GetTensorType(const std::string &name)
{
   return fTensorTypeMap[UTILITY::Clean_name(name)];
}

// Parse an operator
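// Builds the ROperator for the node at position i of the re-ordered node list.
// Returns nullptr when the node was already consumed by a fusion with the
// previous node (e.g. the Add following a MatMul); throws std::runtime_error
// when no parser function is registered for the operator type.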
std::unique_ptr<ROperator>
RModelParser_ONNX::ParseOperator(const size_t i, const onnx::GraphProto &graphproto, const std::vector<size_t> &nodes)
{
   int idx = (nodes.size() > i) ? nodes[i] : (int)i;
   const auto &nodeproto = graphproto.node(idx);
   const std::string op_type = nodeproto.op_type();
   if (fVerbose)
      std::cout << "Parsing an operator " << op_type << std::endl;

   if (op_type == "MatMul") {
      // Fuse MatMul and Add
      int idx2 = (nodes.size() > i + 1) ? nodes[i + 1] : (int)i + 1;
      if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Add") {
         return ParseFuseMatMulAdd(*this, graphproto.node(idx), graphproto.node(idx2));
      } else {
         // no following Add to fuse with (or MatMul is the last node)
         return ParseMatMul(*this, graphproto.node(idx));
      }
   } else if (nodeproto.op_type() == "Conv" || nodeproto.op_type() == "ConvTranspose") {
      // Fuse Conv or ConvTranspose without bias and Add
      int j = (nodes.size() > i + 1) ? nodes[i + 1] : (int)i + 1;
      if (j < graphproto.node_size() && graphproto.node(j).op_type() == "Add") {
         if (nodeproto.op_type() == "Conv") {
            return ParseFuseConvAdd(*this, graphproto.node(idx), graphproto.node(j));
         } else {
            return ParseFuseConvTransposeAdd(*this, graphproto.node(idx), graphproto.node(j));
         }
      }
   }

   // skip the following Add if it was fused with the preceding node
   if (i > 0 && op_type == "Add") {
      int idx0 = (nodes.size() > i) ? nodes[i - 1] : (int)i - 1;
      const std::string &prev_op = graphproto.node(idx0).op_type();
      if (prev_op == "MatMul" || prev_op == "Conv" || prev_op == "ConvTranspose")
         return nullptr;
   }

   auto it = fOperatorsMapImpl->fOperatorsMap.find(op_type);
   if (it == fOperatorsMapImpl->fOperatorsMap.end()) {
      throw std::runtime_error("TMVA::SOFIE Operator type " + op_type + " is not yet supported");
   }
   if (fVerbose) {
      std::cout << "\tCreating operator " << op_type << std::endl;
   }
   return it->second(*this, nodeproto);
}

// Parse a model
RModel RModelParser_ONNX::Parse(std::string filename, bool verbose)
{
   fVerbose = verbose;
   char sep = '/';
#ifdef _WIN32
   sep = '\\';
#endif
   size_t isep = filename.rfind(sep, filename.length());
   std::string filename_nodir = filename;
   if (isep != std::string::npos) {
      filename_nodir = filename.substr(isep + 1);
   }

   std::time_t ttime = std::time(nullptr);
   std::tm *gmt_time = std::gmtime(&ttime);
   std::string parsetime(std::asctime(gmt_time));

   GOOGLE_PROTOBUF_VERIFY_VERSION;
   // model I/O
   onnx::ModelProto model;
   RModel rmodel(filename_nodir, parsetime);

   fTensorTypeMap.clear();

   std::fstream input(filename, std::ios::in | std::ios::binary);
   if (!model.ParseFromIstream(&input)) {
      throw std::runtime_error("TMVA::SOFIE - Failed to parse onnx file " + filename);
   }

   const onnx::GraphProto &graph = model.graph(); // not a memory leak; model is freed automatically at the end of the function
   google::protobuf::ShutdownProtobufLibrary();

   // the ONNX version is given by ir_version() (model_version() returns 0)
   if (fVerbose) {
      std::cout << "ONNX Version " << model.ir_version() << std::endl;
   }

   std::unordered_set<std::string> initializer_names;
   for (int i = 0; i < graph.initializer_size(); i++) {
      initializer_names.insert(graph.initializer(i).name());
   }

   if (verbose)
      std::cout << "Parsing model inputs...." << std::endl;
   /// Loop on model inputs
   for (int i = 0; i < graph.input_size(); i++) {
      RegisterTensorType(graph.input(i).name(),
                         static_cast<ETensorType>(graph.input(i).type().tensor_type().elem_type()));

      if (verbose)
         std::cout << "\tgraph input " << i << " name " << graph.input(i).name() << " type "
                   << graph.input(i).type().tensor_type().elem_type() << std::endl;

      // skip inputs that are weights (they appear in the initializer list)
      if (initializer_names.find(graph.input(i).name()) != initializer_names.end())
         continue;

      // input data node is not a weight node (has no initializer)
      const onnx::ValueInfoProto &valueinfoproto = graph.input(i);
      std::string input_name = valueinfoproto.name();

      ETensorType type = static_cast<ETensorType>(valueinfoproto.type().tensor_type().elem_type());
      // supported input types; other types are rejected here
      if (type != ETensorType::FLOAT && type != ETensorType::INT32 && type != ETensorType::INT64) {
         throw std::runtime_error("TMVA::SOFIE Data type in input tensor " + input_name + " not supported!\n");
      }
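
      // An ONNX dimension is either a fixed integer (dim_value) or a named
      // symbolic parameter such as the batch size (dim_param); Dim stores both
      // cases, and existParam records whether any dimension of this input is symbolic.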
      std::vector<Dim> fShape;
      bool existParam = false;
      if (!valueinfoproto.type().tensor_type().has_shape())
         throw std::runtime_error("TMVA::SOFIE datanode with no shape restrictions is not supported yet");
      for (int j = 0; j < valueinfoproto.type().tensor_type().shape().dim_size(); j++) {
         Dim dim;
         if (valueinfoproto.type().tensor_type().shape().dim(j).value_case() ==
             onnx::TensorShapeProto_Dimension::ValueCase::kDimValue) {
            dim.dim = valueinfoproto.type().tensor_type().shape().dim(j).dim_value();
         } else if (valueinfoproto.type().tensor_type().shape().dim(j).value_case() ==
                    onnx::TensorShapeProto_Dimension::ValueCase::kDimParam) {
            dim.isParam = true;
            existParam = true;
            dim.param = valueinfoproto.type().tensor_type().shape().dim(j).dim_param();
         } else {
            throw std::runtime_error("TMVA::SOFIE ONNX file error: Valueinfoproto " + input_name +
                                     " has neither dim_value nor dim_param! \n");
         }
         fShape.push_back(dim);
      }
      if (valueinfoproto.type().tensor_type().shape().dim_size() == 0) {
         Dim dim;
         dim.dim = 1;
         fShape.push_back(dim);
      } // in case this TensorShapeProto has no dimension message: ONNX IR defines this to be a scalar

      if (!existParam) {
         std::vector<size_t> fShape_sizet;
         for (auto &j : fShape) {
            fShape_sizet.push_back(j.dim);
         }

         rmodel.AddInputTensorInfo(input_name, type, fShape_sizet);
      } else {
         rmodel.AddInputTensorInfo(input_name, type, fShape);
      }
      rmodel.AddInputTensorName(input_name); // store also the names in the given order
   }
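
   // Map from initializer (weight) tensor name to its index in the graph's
   // initializer list; used below to decide whether a node's inputs are available.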
   std::map<std::string, int> allInitializedTensors;

   if (verbose)
      std::cout << "\nParsing graph initializer list and fill model initialized tensors" << std::endl;

   for (int i = 0; i < graph.initializer_size(); i++) {
      onnx::TensorProto *tensorproto = const_cast<onnx::TensorProto *>(&graph.initializer(i));
      std::vector<std::size_t> shape;
      std::size_t fLength = 1;
      for (int j = 0; j < tensorproto->dims_size(); j++) {
         shape.push_back(tensorproto->dims(j));
         fLength *= tensorproto->dims(j);
      }
      // in case of scalars keep an empty shape but with length = 1

      std::string input_name = graph.initializer(i).name();

      if (verbose)
         std::cout << "\t initializer " << i << " name " << input_name << " type " << graph.initializer(i).data_type()
                   << std::endl;

      switch (static_cast<ETensorType>(graph.initializer(i).data_type())) {
      case ETensorType::FLOAT: {
         std::shared_ptr<void> data(malloc(fLength * sizeof(float)), free);

         if (!tensorproto->raw_data().empty()) {
            auto raw_data_ptr = reinterpret_cast<const float *>(tensorproto->raw_data().c_str());
            std::memcpy(data.get(), raw_data_ptr, fLength * sizeof(float));
         } else {
            tensorproto->mutable_float_data()->ExtractSubrange(0, tensorproto->float_data_size(),
                                                               static_cast<float *>(data.get()));
         }

         if (verbose)
            std::cout << "add FLOAT initialized tensor " << input_name << " shape " << ConvertShapeToString(shape)
                      << std::endl;
         rmodel.AddInitializedTensor(input_name, ETensorType::FLOAT, shape, data);
         allInitializedTensors[input_name] = i;
         break;
      }
      case ETensorType::INT64: {
         std::shared_ptr<void> data(malloc(fLength * sizeof(int64_t)), free);

         if (!tensorproto->raw_data().empty()) {
            auto raw_data_ptr = reinterpret_cast<const int64_t *>(tensorproto->raw_data().c_str());
            std::memcpy(data.get(), raw_data_ptr, fLength * sizeof(int64_t));
         } else {
            tensorproto->mutable_int64_data()->ExtractSubrange(0, tensorproto->int64_data_size(),
                                                               static_cast<int64_t *>(data.get()));
         }

         if (verbose)
            std::cout << "add INT64 initialized tensor " << input_name << " shape " << ConvertShapeToString(shape)
                      << std::endl;
         rmodel.AddInitializedTensor(input_name, ETensorType::INT64, shape, data);
         allInitializedTensors[input_name] = i;
         break;
      }
      default:
         throw std::runtime_error("Data type in weight tensor " + graph.initializer(i).name() + " not supported!\n");
      }
   }

   // Initial operator order
   if (verbose) {
      std::cout << "\nGraph operator list (ONNX order)\n";
      for (int i = 0; i < graph.node_size(); i++) {
         std::cout << "\tOperator " << i << " : " << graph.node(i).op_type() << " , " << graph.node(i).input_size()
                   << " inputs : {";
         for (int j = 0; j < graph.node(i).input_size(); j++) {
            std::cout << graph.node(i).input(j);
            if (j < graph.node(i).input_size() - 1)
               std::cout << ", ";
         }
         std::cout << " }" << std::endl;
      }
   }

   // determine the execution order of the nodes:
   if (verbose)
      std::cout << "\nRe-Order graph operator list\n";
   std::vector<size_t> nodesOrder;
   nodesOrder.reserve(graph.node_size());
   std::vector<bool> foundNodes(graph.node_size());
   // start from the graph inputs
   std::map<std::string, int> allInputs;
   for (int i = 0; i < graph.input_size(); i++) {
      allInputs[graph.input(i).name()] = -1;
   }
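
   // Repeatedly sweep the remaining nodes and append every node whose inputs
   // are all known, i.e. graph inputs, initializers, or outputs of nodes that
   // have already been appended (a simple topological sort).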
   do {
      auto psize = nodesOrder.size();
      for (int i = 0; i < graph.node_size(); i++) {
         if (foundNodes[i])
            continue;
         // check if all the node inputs are already available; if so, add the node to the list
         bool existInputs = true;
         int input_size = graph.node(i).input_size();
         // special case for Reshape where the shape is an input and not a weight tensor
         for (int j = 0; j < input_size; j++) {
            std::string name = graph.node(i).input(j);
            // skip empty names
            if (!name.empty()) {
               existInputs &= (allInputs.find(name) != allInputs.end() ||
                               allInitializedTensors.find(name) != allInitializedTensors.end());
               if (fVerbose) {
                  std::cout << graph.node(i).op_type() << " input " << name << " "
                            << bool(allInputs.find(name) != allInputs.end()) << " "
                            << bool(allInitializedTensors.find(name) != allInitializedTensors.end()) << " "
                            << existInputs << std::endl;
               }
            }
         }
         if (!existInputs) {
            if (fVerbose) {
               std::cout << "skip op " << graph.node(i).op_type() << " inputs are ";
               for (int j = 0; j < input_size; j++) {
                  std::cout << graph.node(i).input(j) << " ";
               }
               std::cout << std::endl;
            }
            continue;
         }
         if (verbose)
            std::cout << "\tadd node " << graph.node(i).op_type() << " order " << i << std::endl;

         nodesOrder.push_back(i);
         foundNodes[i] = true;
         // register the outputs
         for (int j = 0; j < graph.node(i).output_size(); j++) {
            allInputs[graph.node(i).output(j)] = i;
         }
      }
      // no node was added in this pass: the graph cannot be ordered (missing inputs or a cycle)
      if (nodesOrder.size() == psize) {
         throw std::runtime_error("TMVA::SOFIE - cannot find a new node");
      }
   } while ((int)nodesOrder.size() < graph.node_size());

   // print the re-ordered operator list
   if (verbose) {
      std::cout << "\nGraph operator list (re-ordered)\n";
      for (int k = 0; k < graph.node_size(); k++) {
         int i = nodesOrder[k];
         std::cout << "\tOperator " << i << " : " << graph.node(i).op_type() << " , " << graph.node(i).input_size()
                   << " inputs : {";
         for (int j = 0; j < graph.node(i).input_size(); j++) {
            std::cout << graph.node(i).input(j);
            if (j < graph.node(i).input_size() - 1)
               std::cout << ", ";
         }
         std::cout << " }" << std::endl;
      }
   }

   // fill the model with operators
   if (verbose) {
      std::cout << "Fill RModel with operators...\n";
   }
   for (int i = 0; i < graph.node_size(); i++) {
      std::string op_type = graph.node(nodesOrder[i]).op_type();

      if (verbose) {
         std::cout << "\t" << i << " " << nodesOrder[i] << " parsing operator " << op_type << std::endl;
      }

      std::unique_ptr<ROperator> op = ParseOperator(i, graph, nodesOrder);
      if (!op) {
         if (verbose) {
            std::cout << "\t\tskipping operator since it is fused with the previous one" << std::endl;
         }
         // skip fused nodes such as the Add after a MatMul
         continue;
      }
      rmodel.AddOperator(std::move(op));
   }

   std::vector<std::string> outputnames;
   if (verbose)
      std::cout << "\nParsing Graph output list\n";
   for (int i = 0; i < graph.output_size(); i++) {
      if (verbose)
         std::cout << "\toutput " << i << " name " << graph.output(i).name() << std::endl;
      outputnames.push_back(graph.output(i).name());
   }
   rmodel.AddOutputTensorNameList(outputnames);

   return rmodel;
}
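
// Typical usage of the parser (a sketch; Generate() and OutputGenerated() are
// part of the RModel interface used to emit the inference code):
//
//    using namespace TMVA::Experimental::SOFIE;
//    RModelParser_ONNX parser;
//    RModel model = parser.Parse("model.onnx");
//    model.Generate();                   // generate the inference function
//    model.OutputGenerated("model.hxx"); // write the generated code to disk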

} // namespace SOFIE
} // namespace Experimental
} // namespace TMVA