Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModelParser_ONNX.cxx
Go to the documentation of this file.
1#include "Byteswap.h"
3#include "onnx_proto3.pb.h"
4
5#include <stdexcept>
6#include <string>
7#include <memory>
8#include <cassert>
9#include <iostream>
10#include <unordered_map>
11#include <functional>
12#include "TMVA/SOFIE_common.hxx"
13
14namespace TMVA {
15namespace Experimental {
16namespace SOFIE {
17
18// Declaration of operators
19// Unary operators
25// Binary operators
31// Nary operators
// Comparison operators
42// Reduce operators
47// Others
// Declaration of fused operators
86
// Definition of RModelParser_ONNX::OperatorsMapImpl: the private implementation
// struct that RModelParser_ONNX owns through fOperatorsMapImpl.
   // Registered operators: ONNX op_type string (e.g. "MatMul") -> parser function
   // building the corresponding ROperator. Written by RegisterOperator() and
   // looked up by ParseOperator().
   std::unordered_map<std::string, ParserFuncSignature> fOperatorsMap;
};
92
// Trait used by GetInitializedTensorData to copy the typed (non-raw) repeated
// data field of a TensorProto into a buffer. Explicit specializations follow
// for float, double, int32_t and int64_t.
template<typename T>
};
97// trait function to extract data from TensorProto
98template<>
99struct ExtractDataFromTP<float> {
100 static void Copy(onnx::TensorProto * tensor, void * data) {
101 tensor->mutable_float_data()->ExtractSubrange(0, tensor->float_data_size(),
102 static_cast<float *>(data));
103 }
104};
// ExtractDataFromTP<double>: moves the repeated `double_data` field of a
// TensorProto into `data`, which must have room for double_data_size() values.
// ExtractSubrange removes the elements, leaving double_data empty.
template<>
   static void Copy(onnx::TensorProto * tensor, void * data) {
      tensor->mutable_double_data()->ExtractSubrange(0, tensor->double_data_size(),
                                                     static_cast<double *>(data));
   }
};
112template<>
113struct ExtractDataFromTP<int32_t> {
114 static void Copy(onnx::TensorProto * tensor, void * data) {
115 tensor->mutable_int32_data()->ExtractSubrange(0, tensor->int32_data_size(),
116 static_cast<int32_t *>(data));
117 }
118};
119template<>
120struct ExtractDataFromTP<int64_t> {
121 static void Copy(onnx::TensorProto * tensor, void * data) {
122 tensor->mutable_int64_data()->ExtractSubrange(0, tensor->int64_data_size(),
123 static_cast<int64_t *>(data));
124 }
125};
126template<typename T>
127std::shared_ptr<void> GetInitializedTensorData(onnx::TensorProto * tensorproto, size_t length) {
128 std::shared_ptr<void> data(malloc(length * sizeof(T)), free);
129
130 if (!tensorproto->raw_data().empty()) {
131#ifdef R__BYTESWAP
132 std::memcpy(data.get(), tensorproto->raw_data().c_str(), length * sizeof(T));
133#else
134 for (std::size_t k = 0; k < fLength; ++k)
135 (reinterpret_cast<uint32_t *>(data.get()))[k] =
136 Rbswap_32((reinterpret_cast<const uint32_t *>(tensorproto->raw_data().c_str()))[k]);
137#endif
138 } else {
139 ExtractDataFromTP<T>::Copy(tensorproto, data.get());
140 }
141 return data;
142}
143
// Constructor of the parser
// Creates the operator registry and fills it. Each RegisterOperator call maps an
// ONNX op_type to the parser function that builds its ROperator. Several op types
// share one parser, e.g. Constant/ConstantOfShape -> ParseConstant,
// AveragePool/GlobalAveragePool/MaxPool -> ParsePool, and
// Reshape/Flatten/Squeeze/Unsqueeze -> ParseReshape.
// NOTE(review): the constructor is noexcept, but make_unique and the map insertions
// can throw std::bad_alloc, which would then end in std::terminate.
RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_unique<OperatorsMapImpl>()) {
   // Register operators
   // Unary operators
   RegisterOperator("Reciprocal", ParseReciprocal);
   // Binary operators
   // Nary operators
   // Comparison operators
   RegisterOperator("Equal", ParseEq);
   RegisterOperator("LessOrEqual", ParseLessEq);
   RegisterOperator("Greater", ParseGreater);
   RegisterOperator("GreaterOrEqual", ParseGreaterEq);
   // Reduce operators
   RegisterOperator("ReduceMean", ParseReduceMean);
   RegisterOperator("ReduceSum", ParseReduceSum);
   RegisterOperator("ReduceSumSquare", ParseReduceSumSquare);
   RegisterOperator("ReduceProd", ParseReduceProd);
   // Others
   RegisterOperator("BatchNormalization", ParseBatchNormalization);
   RegisterOperator("Constant", ParseConstant);
   RegisterOperator("ConstantOfShape", ParseConstant);
   RegisterOperator("Concat", ParseConcat);
   RegisterOperator("ConvTranspose", ParseConvTranspose);
   RegisterOperator("Identity", ParseIdentity);
   RegisterOperator("LeakyRelu", ParseLeakyRelu);
   RegisterOperator("AveragePool", ParsePool);
   RegisterOperator("GlobalAveragePool", ParsePool);
   RegisterOperator("MaxPool", ParsePool);
   RegisterOperator("Reshape", ParseReshape);
   RegisterOperator("Flatten", ParseReshape);
   RegisterOperator("Squeeze", ParseReshape);
   RegisterOperator("Unsqueeze", ParseReshape);
   RegisterOperator("Sigmoid", ParseSigmoid);
   RegisterOperator("Softmax", ParseSoftmax);
   RegisterOperator("Transpose", ParseTranspose);
   RegisterOperator("MatMul", ParseMatMul);
   RegisterOperator("LayerNormalization", ParseLayerNormalization);
   RegisterOperator("Expand", ParseExpand);
   RegisterOperator("Gather", ParseGather);
   RegisterOperator("EyeLike", ParseEyeLike);
}
217
218// Destructor of the parser
220
{
   // Register the parser for ONNX op_type `name`. operator[] overwrites any parser
   // registered earlier under the same name.
   fOperatorsMapImpl->fOperatorsMap[name] = func;
}
225
{
   // True if a parser has been registered for ONNX op_type `name`.
   return fOperatorsMapImpl->fOperatorsMap.find(name) != fOperatorsMapImpl->fOperatorsMap.end();
}
230
{
   // Return the op_type names of all registered parsers. The order follows
   // std::unordered_map iteration and is therefore unspecified.
   std::vector<std::string> ops;
   ops.reserve(fOperatorsMapImpl->fOperatorsMap.size());
   for (auto &it : fOperatorsMapImpl->fOperatorsMap) {
      ops.emplace_back(it.first);
   }
   return ops;
}
240
242{
244}
245
247{
249}
250
252{
254}
255
256// Parse an operator
257std::unique_ptr<ROperator>
258RModelParser_ONNX::ParseOperator(const size_t i, const onnx::GraphProto &graphproto, const std::vector<size_t> &nodes)
259{
260 if (i >= nodes.size())
261 throw std::runtime_error("TMVA::SOFIE - Error in parsing ordered operators " + std::to_string(i) + " is >= " + std::to_string(nodes.size()));
262 int idx = nodes[i];
263 const auto &nodeproto = graphproto.node(idx);
264 const std::string op_type = nodeproto.op_type();
265 if (fVerbose)
266 std::cout << "Parsing operator " << op_type << std::endl;
267
268 // try to fuse with following operator in case it is not last one
269 if (i < nodes.size() - 1) {
270 int idx2 = nodes[i+1];
271 if (op_type == "MatMul") {
272 // Fuse MatMul and Add
273 if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Add") {
274 return ParseFuseMatMulAdd(*this, graphproto.node(idx), graphproto.node(idx2));
275 }
276 else {
277 return ParseMatMul(*this, graphproto.node(idx));
278 }
279 } else if (nodeproto.op_type() == "Conv" || nodeproto.op_type() == "ConvTranspose") {
280 // Fuse Conv or ConvTranspose without bias and Add
281 if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Add") {
282 if (nodeproto.op_type() == "Conv") {
283 return ParseFuseConvAdd(*this, graphproto.node(idx), graphproto.node(idx2));
284 } else {
285 return ParseFuseConvTransposeAdd(*this, graphproto.node(idx), graphproto.node(idx2));
286 }
287 }
288 }
289 }
290
291 // skip then the following Add if it was fused before
292 if (idx > 0 && op_type == "Add") {
293 int idx0 = nodes[i - 1];
294 if (graphproto.node(idx0).op_type() == "MatMul")
295 return nullptr;
296 else if (graphproto.node(idx0).op_type() == "ConvTranspose")
297 return nullptr;
298 }
299
300 auto it = fOperatorsMapImpl->fOperatorsMap.find(op_type);
301 if (it == fOperatorsMapImpl->fOperatorsMap.end()) {
302 std::cout << "operator " << op_type << " is not supported" << std::endl;
303 throw std::runtime_error("TMVA::SOFIE Operator type " + op_type + " is not yet supported");
304 }
305 if (fVerbose) {
306 std::cout << "\tCreating operator " << op_type << std::endl;
307 }
308 return it->second(*this, nodeproto);
309}
310
311// Parse a model
312RModel RModelParser_ONNX::Parse(std::string filename, bool verbose)
313{
314 fVerbose = verbose;
315 char sep = '/';
316#ifdef _WIN32
317 sep = '\\';
318#endif
319 size_t isep = filename.rfind(sep, filename.length());
320 std::string filename_nodir = filename;
321 if (isep != std::string::npos) {
322 filename_nodir = (filename.substr(isep + 1, filename.length() - isep));
323 }
324
325
326 GOOGLE_PROTOBUF_VERIFY_VERSION;
327 // model I/O
328 onnx::ModelProto model;
329
330
331 fTensorTypeMap.clear();
332
333 std::fstream input(filename, std::ios::in | std::ios::binary);
334 if (!model.ParseFromIstream(&input)) {
335 throw std::runtime_error("TMVA::SOFIE - Failed to parse onnx file " + filename);
336 }
337
338 const onnx::GraphProto &graph = model.graph(); // not a memory leak. model freed automatically at the end.
339 google::protobuf::ShutdownProtobufLibrary();
340
341 // ONNX version is ir_version() - model_version() returns 0
342 if (fVerbose) {
343 std::cout << "ONNX Version " << model.ir_version() << std::endl;
344 }
345
346 std::time_t ttime = std::time(0);
347 std::tm *gmt_time = std::gmtime(&ttime);
348 std::string parsetime(std::asctime(gmt_time));
349
350 RModel rmodel(filename_nodir, parsetime);
351 ParseONNXGraph(rmodel, graph, filename_nodir);
352 return rmodel;
353}
354
/// Convert an ONNX graph into `rmodel`. The steps are:
///  1. register the graph inputs that are not initializers as model inputs;
///  2. load every initializer as an initialized tensor;
///  3. order the nodes so that each node comes after the producers of its inputs;
///  4. parse each node into an ROperator (fused Adds are skipped);
///  5. record the graph output names.
///
/// @param rmodel    model to fill
/// @param graph     ONNX graph to read. Its initializer protos may be modified:
///                  GetInitializedTensorData can move their typed data out, hence
///                  the const_cast below.
/// @param graphName name shown in verbose output; graph.name() is used when empty
/// @throws std::runtime_error on unsupported data types, inputs without a shape,
///         dimensions with neither value nor param, or when no valid node order
///         can be found
void RModelParser_ONNX::ParseONNXGraph(RModel & rmodel, const onnx::GraphProto & graph, std::string graphName)
{
   bool verbose = fVerbose;

   if (graphName.empty())
      graphName = graph.name();

   if (verbose)
      std::cout << "\nParsing Graph - " << graphName << std::endl;

   // Names of all initializers. Graph inputs listed here are weights, not data
   // inputs, and are skipped when registering the model inputs below.
   std::unordered_set<std::string> initializer_names;
   for (int i = 0; i < graph.initializer_size(); i++) {
      initializer_names.insert(graph.initializer(i).name());
   }

   if (verbose)
      std::cout << "Parsing model inputs...." << std::endl;
   /// Loop on model inputs
   for (int i = 0; i < graph.input_size(); i++) {
      RegisterTensorType(graph.input(i).name(),
                         static_cast<ETensorType>(graph.input(i).type().tensor_type().elem_type()));

      if (verbose)
         std::cout << "\tgraph input " << i << " name " << graph.input(i).name() << " type "
                   << graph.input(i).type().tensor_type().elem_type() << std::endl;

      if (initializer_names.find(graph.input(i).name()) != initializer_names.end())
         continue;

      // input data node is not a weight node (has no initializer)
      const onnx::ValueInfoProto &valueinfoproto = graph.input(i);
      std::string input_name = valueinfoproto.name();

      ETensorType type = static_cast<ETensorType>(valueinfoproto.type().tensor_type().elem_type());
      // NOTE(review): the condition guarding this throw (the set of supported input
      // types) is not visible in this view of the file.
         throw std::runtime_error("TMVA::SOFIE Data type in input tensor " + input_name + " not supported!\n");
      }

      std::vector<Dim> fShape;
      bool existParam = false;
      if (!valueinfoproto.type().tensor_type().has_shape())
         throw std::runtime_error("TMVA::SOFIE data node with no shape restrictions is not supported yet");
      for (int j = 0; j < valueinfoproto.type().tensor_type().shape().dim_size(); j++) {
         Dim dim;
         if (valueinfoproto.type().tensor_type().shape().dim(j).value_case() ==
             onnx::TensorShapeProto_Dimension::ValueCase::kDimValue) {
            // NOTE(review): ONNX dim_value is int64; storing it in an int truncates very large values.
            int dim_value = valueinfoproto.type().tensor_type().shape().dim(j).dim_value();
            dim.dim = dim_value;
            // case input dim is -1 - set a parametric shape
            if (dim_value < 0) {
               dim.isParam = true;
               existParam = true;
               dim.param = UTILITY::Clean_name(input_name) + "_size";
            }
         } else if (valueinfoproto.type().tensor_type().shape().dim(j).value_case() ==
                    onnx::TensorShapeProto_Dimension::ValueCase::kDimParam) {
            // symbolic (named) dimension: keep its name as the shape parameter
            dim.isParam = true;
            existParam = true;
            dim.param = valueinfoproto.type().tensor_type().shape().dim(j).dim_param();
         } else {
            throw std::runtime_error("TMVA::SOFIE ONNX file error: Valueinfoproto " + input_name +
                                     " has neither dim_value nor dim_param! \n");
         }
         fShape.push_back(dim);
      }
      if (valueinfoproto.type().tensor_type().shape().dim_size() == 0) {
         Dim dim;
         dim.dim = 1;
         fShape.push_back(dim);
      } // in case this TensorShapeProto has no dimension message: ONNX IR defines this to be a scalar

      // A fully static shape is registered with size_t dimensions; a shape with at
      // least one parametric dimension keeps the Dim description.
      if (!existParam) {
         std::vector<size_t> fShape_sizet;
         for (auto &j : fShape) {
            fShape_sizet.push_back(j.dim);
         }

         rmodel.AddInputTensorInfo(input_name, type, fShape_sizet);
      } else {
         rmodel.AddInputTensorInfo(input_name, type, fShape);
      }
      rmodel.AddInputTensorName(input_name); // store also names in given order
   }

   // Initializer name -> index in graph.initializer(). Used during node ordering
   // to treat initializers as already-available inputs.
   std::map<std::string, int> allInitializedTensors;

   if (verbose)
      std::cout << "\nParsing graph initializer list and fill model initialized tensors" << std::endl;

   for (int i = 0; i < graph.initializer_size(); i++) {
      // const_cast: GetInitializedTensorData may move the typed data fields out of the proto.
      onnx::TensorProto *tensorproto = const_cast<onnx::TensorProto *>(&graph.initializer(i));
      std::vector<std::size_t> shape;
      std::size_t fLength = 1;
      for (int j = 0; j < tensorproto->dims_size(); j++) {
         shape.push_back(tensorproto->dims(j));
         fLength *= tensorproto->dims(j);
      }
      // in case of scalars keep an empty shape but with length =1

      std::string input_name = graph.initializer(i).name();

      if (verbose)
         std::cout << "\t initializer " << i << " name " << input_name << " type " << graph.initializer(i).data_type()
                   << std::endl;

      // register also the initialized tensors
      auto tensor_type = static_cast<ETensorType>(graph.initializer(i).data_type());
      RegisterTensorType(input_name, tensor_type);

      // Only FLOAT, DOUBLE, INT32 and INT64 initializers are supported.
      switch (tensor_type) {
      case ETensorType::FLOAT: {
         std::shared_ptr<void> data = GetInitializedTensorData<float>(tensorproto, fLength);
         if (verbose) std::cout << "add FLOAT initialized tensor " << input_name << " shape " << ConvertShapeToString(shape) << std::endl;
         rmodel.AddInitializedTensor(input_name, ETensorType::FLOAT, shape, data);
         allInitializedTensors[input_name] = i;
         break;
      }
      case ETensorType::DOUBLE: {
         std::shared_ptr<void> data = GetInitializedTensorData<double>(tensorproto, fLength);
         if (verbose) std::cout << "add DOUBLE initialized tensor " << input_name << " shape " << ConvertShapeToString(shape) << std::endl;
         rmodel.AddInitializedTensor(input_name, ETensorType::DOUBLE, shape, data);
         allInitializedTensors[input_name] = i;
         break;
      }
      case ETensorType::INT32: {
         std::shared_ptr<void> data = GetInitializedTensorData<int32_t>(tensorproto, fLength);
         if (verbose) std::cout << "add INT32 initialized tensor " << input_name << " shape " << ConvertShapeToString(shape) << std::endl;
         rmodel.AddInitializedTensor(input_name, ETensorType::INT32, shape, data);
         allInitializedTensors[input_name] = i;
         break;
      }
      case ETensorType::INT64: {
         std::shared_ptr<void> data = GetInitializedTensorData<int64_t>(tensorproto, fLength);
         if (verbose) std::cout << "add INT64 initialized tensor " << input_name << " shape " << ConvertShapeToString(shape) << std::endl;
         rmodel.AddInitializedTensor(input_name, ETensorType::INT64, shape, data);
         allInitializedTensors[input_name] = i;
         break;
      }
      default:
         throw std::runtime_error("Data type in weight tensor " + graph.initializer(i).name() + " not supported!\n");
      }
   }

   // Initial operator order
   if (verbose) {
      std::cout << "\nGraph operator list (ONNX order)\n";
      for (int i = 0; i < graph.node_size(); i++) {
         std::cout << "\tOperator " << i << " : " << graph.node(i).op_type() << " , " << graph.node(i).input_size()
                   << " inputs : {";
         for (int j = 0; j < graph.node(i).input_size(); j++) {
            std::cout << graph.node(i).input(j);
            if (j < graph.node(i).input_size() - 1)
               std::cout << ", ";
         }
         std::cout << " }" << std::endl;
      }
   }

   // make order of nodes:
   // Sweep the node list repeatedly. In each sweep, append every node whose
   // non-empty inputs are all available: graph inputs, initializers, or outputs
   // of nodes already appended. Stop when every node has been placed.
   if (verbose)
      std::cout << "\nRe-Order graph operator list\n";
   std::vector<size_t> nodesOrder;
   nodesOrder.reserve(graph.node_size());
   std::vector<bool> foundNodes(graph.node_size());
   // loop at graph inputs
   // NOTE(review): allInputs is not declared in this function (the local
   // declaration below is commented out). It must come from elsewhere, e.g. a
   // class member. If so, its entries persist across calls; confirm this is intended.
   //std::map<std::string, int> allInputs;
   for (int i = 0; i < graph.input_size(); i++) {
      allInputs[graph.input(i).name()] = -1;
   }
   do {
      auto psize = nodesOrder.size();
      for (int i = 0; i < graph.node_size(); i++) {
         if (foundNodes[i])
            continue;
         // check if all input exists add to list
         bool existInputs = true;
         int input_size = graph.node(i).input_size();
         // special case for Reshape where shape is input and not a weight tensor
         // NOTE(review): nothing Reshape-specific follows; this comment looks stale.
         if (fVerbose )
            std::cout << "Checking input of Node " << i << " : " << graph.node(i).name() << std::endl;
         for (int j = 0; j < input_size; j++) {
            std::string name = graph.node(i).input(j);
            // skip empty names (ONNX uses "" for omitted optional inputs)
            if (!name.empty()) {
               existInputs &= (allInputs.find(name) != allInputs.end() ||
                               allInitializedTensors.find(name) != allInitializedTensors.end());
               if (fVerbose ) {
                  std::cout << "\t\t input " << name << " "
                     << bool(allInputs.find(name) != allInputs.end()) << " " <<
                     bool(allInitializedTensors.find(name) != allInitializedTensors.end()) <<
                     existInputs << std::endl;
               }
            }
         }
         if (!existInputs) {
            if (fVerbose) {
               std::cout << "skip node " << graph.node(i).op_type() << " " << graph.node(i).name() << " inputs are not existing ";
               for (int j = 0; j < input_size; j++) {
                  std::cout << graph.node(i).input(j) << " ";
               }
               std::cout << std::endl;
            }
            continue;
         }
         if (verbose)
            std::cout << "===> New node " << graph.node(i).op_type() << " " << graph.node(i).name() << " order " << i << std::endl;

         nodesOrder.push_back(i);
         foundNodes[i] = true;
         // register the outputs
         for (int j = 0; j < graph.node(i).output_size(); j++) {
            if (fVerbose) std::cout << "\toutput : " << graph.node(i).output(j) << std::endl;
            allInputs[graph.node(i).output(j)] = i;
         }
      }
      // no increment in nodes - something wrong
      // NOTE(review): if no node has been placed at all (e.g. a graph with zero
      // nodes, where the do-while body still runs once), nodesOrder is empty and
      // back() below is undefined behavior.
      if (nodesOrder.size() == psize) {
         int ilast = nodesOrder.back();
         std::cout << "cannot find a new node after " << graph.node(ilast).op_type() << " " << graph.node(ilast).name() << std::endl;
         throw std::runtime_error("TMVA::SOFIE - cannot find a new node ");
      }
   } while ((int)nodesOrder.size() < graph.node_size());

   // scan operators for orders
   if (verbose) {
      std::cout << "\nGraph operator list (re-ordered)\n";
      for (int k = 0; k < graph.node_size(); k++) {
         int i = nodesOrder[k];
         std::cout << "\tOperator " << i << " : " << graph.node(i).op_type() << " , " << graph.node(i).input_size()
                   << " inputs : {";
         for (int j = 0; j < graph.node(i).input_size(); j++) {
            std::cout << graph.node(i).input(j);
            if (j < graph.node(i).input_size() - 1)
               std::cout << ", ";
         }
         std::cout << " }" << std::endl;
      }
   }

   // fill model with operators
   if (verbose) {
      std::cout << "Fill RModel with operators...\n";
   }
   for (int i = 0; i < graph.node_size(); i++) {
      std::string op_type = graph.node(nodesOrder[i]).op_type();

      if (verbose) {
         std::cout << "\t" << i << " " << nodesOrder[i] << " parsing operator " << op_type << std::endl;
      }

      // `i` is the position in execution order; ParseOperator maps it to the ONNX node via nodesOrder.
      std::unique_ptr<ROperator> op = ParseOperator(i, graph, nodesOrder);
      if (!op) {
         if (verbose) {
            std::cout << "\t\tskipping operator since it is fused with previous one" << std::endl;
         }
         // for skipping the fused nodes like Add after MatMul
         continue;
      }
      rmodel.AddOperator(std::move(op));
   }

   std::vector<std::string> outputnames;
   if (verbose)
      std::cout << "\nParsing Graph output list\n";
   for (int i = 0; i < graph.output_size(); i++) {
      if (verbose)
         std::cout << "\toutput " << i << " name " << graph.output(i).name() << std::endl;
      outputnames.push_back(graph.output(i).name());
   }
   rmodel.AddOutputTensorNameList(outputnames);

   return;
}
628
629} // namespace SOFIE
630} // namespace Experimental
631} // namespace TMVA
#define Rbswap_32(x)
Definition Byteswap.h:108
dims_t fShape
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
char name[80]
Definition TGX11.cxx:110
#define free
Definition civetweb.c:1539
#define malloc
Definition civetweb.c:1536
void RegisterOperator(const std::string &name, ParserFuncSignature func)
bool IsRegisteredOperator(const std::string &name)
void ParseONNXGraph(RModel &model, const onnx::GraphProto &g, std::string name="")
std::unordered_map< std::string, ETensorType > fTensorTypeMap
RModel Parse(std::string filename, bool verbose=false)
std::unique_ptr< ROperator > ParseOperator(const size_t, const onnx::GraphProto &, const std::vector< size_t > &)
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::vector< std::string > GetRegisteredOperators()
std::unique_ptr< OperatorsMapImpl > fOperatorsMapImpl
void AddInputTensorInfo(std::string input_name, ETensorType type, std::vector< Dim > shape)
Definition RModel.cxx:132
void AddOutputTensorNameList(std::vector< std::string > output_tensor_names)
Definition RModel.cxx:241
void AddInitializedTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Definition RModel.cxx:168
void AddInputTensorName(std::string name)
Definition RModel.cxx:151
void AddOperator(std::unique_ptr< ROperator > op, int order_execution=-1)
Definition RModel.cxx:155
std::string Clean_name(std::string input_tensor_name)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &, const onnx::NodeProto &)> ParserFuseFuncSignature
ParserFuncSignature ParseSqrt
ParserFuncSignature ParseBatchNormalization
ParserFuncSignature ParseGreater
ParserFuncSignature ParseReshape
ParserFuseFuncSignature ParseFuseConvTransposeAdd
ParserFuncSignature ParseReduceMean
ParserFuseFuncSignature ParseFuseMatMulAdd
ParserFuncSignature ParseGather
ParserFuncSignature ParseNeg
ParserFuncSignature ParseLog
ParserFuncSignature ParseLeakyRelu
ParserFuncSignature ParseExp
ParserFuncSignature ParsePool
Definition ParsePool.cxx:9
ParserFuncSignature ParseDiv
ParserFuncSignature ParseLayerNormalization
ParserFuncSignature ParseConcat
ParserFuncSignature ParseTopK
Definition ParseTopK.cxx:9
ParserFuncSignature ParseMax
ParserFuncSignature ParseEq
ParserFuncSignature ParseIdentity
ParserFuncSignature ParseConvTranspose
ParserFuncSignature ParseReduceProd
ParserFuncSignature ParseSlice
Definition ParseSlice.cxx:9
ParserFuncSignature ParseTranspose
ParserFuncSignature ParseLess
ParserFuncSignature ParseShape
Definition ParseShape.cxx:9
ParserFuncSignature ParseGRU
Definition ParseGRU.cxx:9
ParserFuncSignature ParseMatMul
ParserFuncSignature ParseErf
Definition ParseErf.cxx:9
ParserFuncSignature ParseSub
ParserFuncSignature ParseAdd
std::shared_ptr< void > GetInitializedTensorData(onnx::TensorProto *tensorproto, size_t length)
ParserFuncSignature ParseIf
Definition ParseIf.cxx:9
ParserFuncSignature ParseRange
Definition ParseRange.cxx:9
ParserFuncSignature ParseExpand
ParserFuncSignature ParseRNN
Definition ParseRNN.cxx:9
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
ParserFuncSignature ParseLSTM
Definition ParseLSTM.cxx:9
ParserFuncSignature ParseCast
Definition ParseCast.cxx:9
ParserFuncSignature ParseReciprocal
std::string ConvertShapeToString(std::vector< size_t > shape)
ParserFuncSignature ParseSigmoid
ParserFuseFuncSignature ParseFuseConvAdd
ParserFuncSignature ParseSoftmax
ParserFuncSignature ParseGreaterEq
ParserFuncSignature ParseMean
ParserFuncSignature ParseSplit
Definition ParseSplit.cxx:9
ParserFuncSignature ParseConstant
ParserFuncSignature ParseSelu
Definition ParseSelu.cxx:9
ParserFuncSignature ParseLessEq
ParserFuncSignature ParseSum
ParserFuncSignature ParseEyeLike
ParserFuncSignature ParseElu
Definition ParseElu.cxx:9
ParserFuncSignature ParseMin
ParserFuncSignature ParseRelu
Definition ParseRelu.cxx:9
ParserFuncSignature ParseReduceSum
ParserFuncSignature ParseConv
Definition ParseConv.cxx:9
ParserFuncSignature ParseGemm
Definition ParseGemm.cxx:9
ParserFuncSignature ParseTile
Definition ParseTile.cxx:9
ParserFuncSignature ParseMul
ParserFuncSignature ParsePow
ParserFuncSignature ParseReduceSumSquare
ParserFuncSignature ParseTanh
Definition ParseTanh.cxx:9
create variable transformations
Definition graph.py:1
static void Copy(onnx::TensorProto *tensor, void *data)
static void Copy(onnx::TensorProto *tensor, void *data)
static void Copy(onnx::TensorProto *tensor, void *data)
static void Copy(onnx::TensorProto *tensor, void *data)
std::unordered_map< std::string, ParserFuncSignature > fOperatorsMap