Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModelParser_ONNX.cxx
Go to the documentation of this file.
1#include "Byteswap.h"
3#include "onnx_proto3.pb.h"
4
5#include <stdexcept>
6#include <string>
7#include <memory>
8#include <cassert>
9#include <iostream>
10#include <unordered_map>
11#include <functional>
12#include "TMVA/SOFIE_common.hxx"
13
14namespace TMVA {
15namespace Experimental {
16namespace SOFIE {
17
18// Declaration of operators
19// Unary operators
31// Binary operators
38// Nary operators
43//Comparision Operators
49//Is Operators
53// Reduce operators
58// Others
103// Declaration of fused operators
109
110// Definition of RModelParser_ONNX::OperatorsMap
112 // Registered operators
113 std::unordered_map<std::string, ParserFuncSignature> fOperatorsMap;
114};
115
116// helper function to get initialized tensor data
117template<typename T>
119};
120// trait function to extract data from TensorProto
121template<>
122struct ExtractDataFromTP<float> {
123 static void Copy(onnx::TensorProto * tensor, void * data, int length) {
124 if (tensor->float_data_size() != length)
125 throw std::runtime_error("TMVA::SOFIE - Failed to read float initialized tensor - actual size is " + std::to_string(tensor->float_data_size()));
126 tensor->mutable_float_data()->ExtractSubrange(0, tensor->float_data_size(),
127 static_cast<float *>(data));
128 }
129};
// Specialization for double tensors (reads the proto's double_data field).
// NOTE(review): the `struct ExtractDataFromTP<double> {` header line is not
// visible in this listing (original line 131 omitted) - confirm against the
// full source before editing this block.
130template<>
 132 static void Copy(onnx::TensorProto * tensor, void * data, int length) {
// Reject a mismatch between the stored element count and the expected length.
 133 if (tensor->double_data_size() != length)
 134 throw std::runtime_error("TMVA::SOFIE - Failed to read double initialized tensor - actual size is " + std::to_string(tensor->double_data_size()));
// ExtractSubrange copies the values directly into the caller's buffer.
 135 tensor->mutable_double_data()->ExtractSubrange(0, tensor->double_data_size(),
 136 static_cast<double *>(data));
 137 }
 138};
139template<>
140struct ExtractDataFromTP<int32_t> {
141 static void Copy(onnx::TensorProto * tensor, void * data, int length) {
142 if (tensor->int32_data_size() != length)
143 throw std::runtime_error("TMVA::SOFIE - Failed to read int32 initialized tensor - actual size is " + std::to_string(tensor->int32_data_size()));
144 tensor->mutable_int32_data()->ExtractSubrange(0, tensor->int32_data_size(),
145 static_cast<int32_t *>(data));
146 }
147};
148template<>
149struct ExtractDataFromTP<int64_t> {
150 static void Copy(onnx::TensorProto * tensor, void * data, int length) {
151 if (tensor->int64_data_size() != length)
152 throw std::runtime_error("TMVA::SOFIE - Failed to read int64 initialized tensor - actual size is " + std::to_string(tensor->int64_data_size()));
153 tensor->mutable_int64_data()->ExtractSubrange(0, tensor->int64_data_size(),
154 static_cast<int64_t *>(data));
155 }
156};
157
// Return a shared buffer holding the payload of an initialized (weight) tensor.
// `tensor_size` is the expected size in bytes. The data may live inline in the
// proto (raw_data or the typed repeated fields) or in an external data file.
// @throws std::runtime_error on any size mismatch or unsupported type.
158std::shared_ptr<void> RModelParser_ONNX::GetInitializedTensorData(onnx::TensorProto *tensorproto, size_t tensor_size, ETensorType tensor_type)
159{
160
// NOTE(review): malloc result is not checked; a failed allocation would be
// written through below - confirm upstream bounds tensor sizes.
161 std::shared_ptr<void> data(malloc(tensor_size), free);
162
163 // check if initialized tensors are stored internally
164 if (tensorproto->data_location() != onnx::TensorProto::EXTERNAL) {
165 if (tensorproto->raw_data().size() > 0) {
166 if (tensorproto->raw_data().size() != tensor_size)
167 throw std::runtime_error("TMVA::SOFIE - Failed to read raw data of initialized tensor - actual raw size is " +
168 std::to_string(tensorproto->raw_data().size()));
169
170#ifdef R__BYTESWAP
171 // R__BYTESWAP is defined for little-endian architectures (most common ones)
172 std::memcpy(data.get(), tensorproto->raw_data().c_str(), tensor_size);
173#else
174 // big-endian architectures - need to swap bytes
// NOTE(review): `sizeof(T)` below has no template parameter T in scope in this
// non-template function, and swapping at uint8_t granularity is a no-op, so
// this branch looks broken for big-endian builds - verify against full source.
175 for (std::size_t k = 0; k < tensor_size; ++k)
176 (reinterpret_cast<typename RByteSwap<sizeof(uint8_t)>::value_type *>(data.get()))[k] =
177 RByteSwap<sizeof(T)>::bswap((reinterpret_cast<const typename RByteSwap<sizeof(uint8_t)>::value_type *>(
178 tensorproto->raw_data().c_str()))[k]);
179#endif
180 } else {
181 // case tensor data are stored as specific types and not in raw_data
// The divisors convert the byte size into an element count, assuming fixed
// element sizes: 4 bytes for float/int32, 8 bytes for double/int64.
182 switch (tensor_type) {
183 case ETensorType::FLOAT: {
184 ExtractDataFromTP<float>::Copy(tensorproto, data.get(), tensor_size/ 4);
185 break;
186 }
187 case ETensorType::DOUBLE: {
188 ExtractDataFromTP<double>::Copy(tensorproto, data.get(), tensor_size/ 8);
189 break;
190 }
191 case ETensorType::INT32: {
192 ExtractDataFromTP<int32_t>::Copy(tensorproto, data.get(), tensor_size/ 4);
193 break;
194 }
195 case ETensorType::INT64: {
196 ExtractDataFromTP<int64_t>::Copy(tensorproto, data.get(), tensor_size/ 8);
197 break;
198 }
199 case ETensorType::BOOL: {
200 throw std::runtime_error("TMVA::SOFIE - ExtractData from TP in BOOL not supported");
201 break;
202 }
203 case ETensorType::UINT8: {
204 throw std::runtime_error("TMVA::SOFIE - ExtractData from TP in UINT8 not supported");
205 break;
206 }
207 default:
208 throw std::runtime_error("Data type " + ConvertTypeToString(tensor_type) + " in weight tensor is not supported!\n");
209 }
210 }
211
212 } else {
213 // case of external data
214 if (fVerbose)
215 std::cout << "Initialized data are stored externally in file " << fDataFileName;
216
217 // read now tensor from file
218 std::string location;
219 size_t offset = 0, buffer_size = 0;
220
// external_data is a key/value list; pick out location, offset and length.
221 for (const auto &kv : tensorproto->external_data()) {
222 if (kv.key() == "location") location = kv.value();
223 else if (kv.key() == "offset") offset = std::stoull(kv.value());
224 else if (kv.key() == "length") buffer_size = std::stoull(kv.value());
225 }
226 if (fVerbose)
227 std::cout << " at location " << location << " offset " << offset << " and with length " << buffer_size << std::endl;
228
229 if (buffer_size != tensor_size)
230 throw std::runtime_error("TMVA::SOFIE ONNX : invalid stored data size vs tensor size");
231
232 // open the data file if needed
// NOTE(review): the `location` value parsed above is never used to build the
// path; the file opened is always fDataFileName - confirm this is intended.
233 if (!fDataFile.is_open()) {
234 fDataFile.open(fDataFileName, std::ios::binary);
235 if (!fDataFile.is_open())
236 throw std::runtime_error("TMVA::SOFIE ONNX: error reading external weight ONNX data file " + fDataFileName);
237 }
238
// The file stays open across calls so subsequent tensors reuse the handle.
239 fDataFile.seekg(offset);
240 fDataFile.read(reinterpret_cast<char *>(data.get()), buffer_size);
241 }
242
243 return data;
244}
245
246
247// Constructor of the parser
// Builds the operator registry: each ONNX op_type string is mapped to the
// parser function that converts the corresponding NodeProto into a SOFIE
// ROperator. Several ONNX ops intentionally share one parser (e.g. the pool
// variants, the reshape family, Softmax/LogSoftmax, the Random* ops).
// NOTE(review): this generated listing omits several registration lines
// (non-contiguous original line numbers) - the full source registers more
// operators than shown here.
248RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_unique<OperatorsMapImpl>()) {
249 // Register operators
250 // Unary operators
252 RegisterOperator("Reciprocal", ParseReciprocal);
259 RegisterOperator("Softplus", ParseSoftplus);
262 // Binary operators
269 // Nary operators
274 //Comparison Operators
275 RegisterOperator("Equal", ParseEq);
277 RegisterOperator("LessOrEqual", ParseLessEq);
278 RegisterOperator("Greater", ParseGreater);
279 RegisterOperator("GreaterOrEqual", ParseGreaterEq);
280 // Is If operators
284 // Reduce operators
285 RegisterOperator("ReduceMean", ParseReduceMean);
286 RegisterOperator("ReduceSum", ParseReduceSum);
287 RegisterOperator("ReduceSumSquare", ParseReduceSumSquare);
288 RegisterOperator("ReduceProd", ParseReduceProd);
289 // Others
290 RegisterOperator("BatchNormalization", ParseBatchNormalization);
291 RegisterOperator("Constant", ParseConstant);
292 RegisterOperator("ConstantOfShape", ParseConstant);
294 RegisterOperator("Concat", ParseConcat);
296 RegisterOperator("ConvTranspose", ParseConvTranspose);
299 RegisterOperator("Identity", ParseIdentity);
300 RegisterOperator("LeakyRelu", ParseLeakyRelu);
302 RegisterOperator("AveragePool", ParsePool);
303 RegisterOperator("GlobalAveragePool", ParsePool);
304 RegisterOperator("MaxPool", ParsePool);
306 RegisterOperator("Reshape", ParseReshape);
307 RegisterOperator("Flatten", ParseReshape);
308 RegisterOperator("Squeeze", ParseReshape);
309 RegisterOperator("Unsqueeze", ParseReshape);
314 RegisterOperator("Sigmoid", ParseSigmoid);
316 RegisterOperator("Softmax", ParseSoftmax);
317 RegisterOperator("LogSoftmax", ParseSoftmax);
319 RegisterOperator("Transpose", ParseTranspose);
320 RegisterOperator("MatMul", ParseMatMul);
321 RegisterOperator("LayerNormalization", ParseLayerNormalization);
322 RegisterOperator("Expand", ParseExpand);
323 RegisterOperator("Gather", ParseGather);
324 RegisterOperator("GatherND", ParseGatherND);
327 RegisterOperator("EyeLike", ParseEyeLike);
335 RegisterOperator("Einsum", ParseEinsum);
336 RegisterOperator("RandomNormal", ParseRandom);
337 RegisterOperator("RandomNormalLike", ParseRandom);
338 RegisterOperator("RandomUniform", ParseRandom);
339 RegisterOperator("RandomUniformLike", ParseRandom);
340 RegisterOperator("ScatterElements", ParseScatterElements);
341 RegisterOperator("ScatterND", ParseScatterND);
342 RegisterOperator("NonZero", ParseNonZero);
344}
345
346// Destructor of the parser
348
350{
351 fOperatorsMapImpl->fOperatorsMap[name] = func;
352}
353
355{
356 return fOperatorsMapImpl->fOperatorsMap.find(name) != fOperatorsMapImpl->fOperatorsMap.end();
357}
358
360{
361 std::vector<std::string> ops;
362 ops.reserve(fOperatorsMapImpl->fOperatorsMap.size());
363 for (auto &it : fOperatorsMapImpl->fOperatorsMap) {
364 ops.emplace_back(it.first);
365 }
366 // return sorted list in alphabetical order
367 std::sort(ops.begin(), ops.end());
368 return ops;
369}
370
375
377{
379}
380
385
386// Parse an operator
// Convert the i-th node (in topologically ordered index list `nodes`) of the
// graph into an ROperator. Handles operator fusion in two directions:
//  - if this node was already marked as the second half of a fused pair
//    (fFusedOperators), emit the corresponding fused operator;
//  - if this node can fuse with its single child (MatMul+Add, Conv+Add,
//    ConvTranspose+Add, Gemm+Relu, BatchNormalization+Relu), record the pair
//    and return nullptr so the caller skips it here and emits it at the child.
// NOTE(review): this listing omits the lines that insert into fFusedOperators
// (original lines 424, 431, 434, 441, 446) - consult the full source.
// @throws std::runtime_error when i is out of range or the op is unsupported.
387std::unique_ptr<ROperator>
388RModelParser_ONNX::ParseOperator(const size_t i, const onnx::GraphProto &graphproto, const std::vector<size_t> &nodes, const std::vector<int> & children)
389{
390 if (i >= nodes.size())
391 throw std::runtime_error("TMVA::SOFIE - Error in parsing ordered operators " + std::to_string(i) + " is >= " + std::to_string(nodes.size()));
392 int idx = nodes[i];
393 const auto &nodeproto = graphproto.node(idx);
394 const std::string op_type = nodeproto.op_type();
395 if (fVerbose)
396 std::cout << "Parsing operator " << op_type << std::endl;
397
398 // perform the fusion of operators
399 if (fFusedOperators.count(idx) == 1) {
400 int idx1 = fFusedOperators[idx].second;
401 if (fVerbose) {
// NOTE(review): bug - node(idx1).name() is printed on both sides of "with";
// the second operand should presumably be node(idx).name().
402 std::cout << "\tFusing operators " << graphproto.node(idx1).name()
403 << " with " << graphproto.node(idx1).name() << std::endl;
404 }
405 if (fFusedOperators[idx].first == EFusedOp::kMatMulAdd) {
406 return ParseFuseMatMulAdd(*this, graphproto.node(idx1), graphproto.node(idx));
407 } else if (fFusedOperators[idx].first == EFusedOp::kConvAdd) {
408 return ParseFuseConvAdd(*this, graphproto.node(idx1), graphproto.node(idx));
409 } else if (fFusedOperators[idx].first == EFusedOp::kConvTransAdd) {
410 return ParseFuseConvTransposeAdd(*this, graphproto.node(idx1), graphproto.node(idx));
411 } else if (fFusedOperators[idx].first == EFusedOp::kGemmRelu) {
412 return ParseFuseGemmRelu(*this, graphproto.node(idx1), graphproto.node(idx));
413 } else if (fFusedOperators[idx].first == EFusedOp::kBatchnormRelu) {
414 return ParseFuseBatchnormRelu(*this, graphproto.node(idx1), graphproto.node(idx));
415 }
416 }
417
418 // try to fuse with following operator in case it is not last one and having only a single child
419 if (children.size() == 1) {
420 int idx2 = children.front();
421 if (op_type == "MatMul") {
422 // Fuse MatMul and Add
423 if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Add") {
// (omitted line 424 records the kMatMulAdd fusion for idx2)
425 return nullptr;
426 }
427 } else if (nodeproto.op_type() == "Conv" || nodeproto.op_type() == "ConvTranspose") {
428 // Fuse Conv or ConvTranspose without bias and Add
429 if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Add") {
430 if (nodeproto.op_type() == "Conv") {
// (omitted line 431 records the kConvAdd fusion for idx2)
432 return nullptr;
433 } else {
// (omitted line 434 records the kConvTransAdd fusion for idx2)
435 return nullptr;
436 }
437 }
438 } else if (nodeproto.op_type() == "Gemm") {
439 // Fuse Gemm with activation operators
440 if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Relu") {
// (omitted line 441 records the kGemmRelu fusion for idx2)
442 return nullptr;
443 }
444 } else if (nodeproto.op_type() == "BatchNormalization") {
445 if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Relu") {
// (omitted line 446 records the kBatchnormRelu fusion for idx2)
447 return nullptr;
448 }
449 }
450 }
451
// No fusion: dispatch to the registered parser for this op_type.
452 auto it = fOperatorsMapImpl->fOperatorsMap.find(op_type);
453 if (it == fOperatorsMapImpl->fOperatorsMap.end()) {
454 std::cout << "operator " << op_type << " is not supported" << std::endl;
455 throw std::runtime_error("TMVA::SOFIE Operator type " + op_type + " is not yet supported");
456 }
457 if (fVerbose) {
458 std::cout << "\tCreating operator " << op_type << std::endl;
459 }
460 return it->second(*this, nodeproto);
461}
462
463// Parse a model
// Load an ONNX model from `filename` and convert it into an RModel.
// Records the parse time and strips the directory from the filename; the
// external-weights file name defaults to "<filename>.data" when not set.
// NOTE(review): the lines constructing `rmodel` and invoking ParseONNXGraph
// (original lines 493-495) are omitted in this listing - the values computed
// below (parsetime, filename_nodir) are presumably passed to the RModel
// constructor there; confirm against the full source.
// @throws std::runtime_error when the file cannot be loaded.
464RModel RModelParser_ONNX::Parse(std::string const &filename, bool verbose)
465{
466 fVerbose = verbose;
467
// drop tensor type info from any previously parsed model
468 fTensorTypeMap.clear();
469
470 auto model = LoadModel(filename);
471 if (!model)
472 throw std::runtime_error("TMVA::SOFIE - Failed to load onnx file " + filename);
473
474 const onnx::GraphProto &graph = model->graph(); // not a memory leak. model freed automatically at the end.
475
476
// record the parse timestamp (UTC) as a human-readable string
477 std::time_t ttime = std::time(0);
478 std::tm *gmt_time = std::gmtime(&ttime);
479 std::string parsetime(std::asctime(gmt_time));
480
481 // get name of model (filename without directory name)
482 char sep = '/';
483#ifdef _WIN32
484 sep = '\\';
485#endif
486 size_t isep = filename.rfind(sep, filename.length());
487 std::string filename_nodir = filename;
488 if (isep != std::string::npos) {
489 filename_nodir = (filename.substr(isep + 1, filename.length() - isep));
490 }
491
492 if (fDataFileName.empty() ) fDataFileName = filename + ".data";

496 return rmodel;
497}
498
// Parse an ONNX model read from an already-open input stream; `name` is used
// as the graph name passed to ParseONNXGraph.
// NOTE(review): the line constructing `rmodel` (original line 515) is omitted
// in this listing - confirm against the full source.
// @throws std::runtime_error when the stream cannot be parsed.
499RModel RModelParser_ONNX::Parse(std::istream &input, std::string const &name, bool verbose)
500{
501 fVerbose = verbose;
502
// drop tensor type info from any previously parsed model
503 fTensorTypeMap.clear();
504
505 auto model = LoadModel(input);
506 if (!model)
507 throw std::runtime_error("TMVA::SOFIE - Failed to parse ONNX model from input stream");
508
509 const onnx::GraphProto &graph = model->graph(); // not a memory leak. model freed automatically at the end.
510
// record the parse timestamp (UTC) as a human-readable string
511 std::time_t ttime = std::time(0);
512 std::tm *gmt_time = std::gmtime(&ttime);
513 std::string parsetime(std::asctime(gmt_time));
514
516 ParseONNXGraph(rmodel, graph, name);
517 return rmodel;
518}
519
520std::unique_ptr<onnx::ModelProto> RModelParser_ONNX::LoadModel(const std::string &filename) {
521 std::fstream input(filename, std::ios::in | std::ios::binary);
522 if (!input) {
523 std::cerr << "TMVA::SOFIE - Failed to open onnx file " << filename << std::endl;
524 return {};
525 }
526
527 return LoadModel(input);
528}
529
// Deserialize an ONNX ModelProto from an input stream.
// Prints to stderr and returns an empty pointer on parse failure.
530std::unique_ptr<onnx::ModelProto> RModelParser_ONNX::LoadModel(std::istream &input) {
 532 auto model = std::make_unique<onnx::ModelProto>();
 533
 534 if (!model->ParseFromIstream(&input)) {
 535 std::cerr << "TMVA::SOFIE - Failed to parse ONNX model from input stream" << std::endl;
 536 return {};
 537 }
 538
 539 // ONNX version is ir_version() - model_version() returns 0
 540 if (fVerbose) {
 541 std::cout << "ONNX Version " << model->ir_version() << std::endl;
 542 }
// NOTE(review): ShutdownProtobufLibrary() frees protobuf's global allocations;
// calling it after every load looks unusual if more protobuf work follows in
// the same process - confirm this is safe for repeated Parse() calls.
 543 google::protobuf::ShutdownProtobufLibrary();
 544 return model;
 545
 546}
547
// Walk a graph (and, recursively, any sub-graphs held in node attributes) and
// collect operators the parser does not support into `missingOperators`,
// mapped to the nesting `level` at which they were found.
// NOTE(review): this listing omits the condition guarding the insertion
// (original line 565, presumably `if (!IsRegisteredOperator(opType))`) and the
// recursive CheckGraph call on `subGraph` (original line 573) - confirm
// against the full source.
548void RModelParser_ONNX::CheckGraph(const onnx::GraphProto & graph, int & level, std::map<std::string, int> & missingOperators) {
 549 if (fVerbose)
 550 std::cout << "\n" << graph.name() << " Graph operator list\n";
 551 for (int i = 0; i < graph.node_size(); i++) {
 552 const auto & node = graph.node(i);
 553 const std::string opType = node.op_type();
 554 if (fVerbose) {
 555 std::cout << "\tOperator " << i << " : " << opType << " (" << node.name() << "), " << graph.node(i).input_size()
 556 << " inputs : {";
 557 for (int j = 0; j < graph.node(i).input_size(); j++) {
 558 std::cout << graph.node(i).input(j);
 559 if (j < graph.node(i).input_size() - 1)
 560 std::cout << ", ";
 561 }
 562 std::cout << " }" << std::endl;
 563 }
 564 // check if operator exists
 566 missingOperators[opType] = level;
 567 // see if sub-graph exists as node attributes
 568 for (int j = 0; j < node.attribute_size(); j++) {
 569 const auto & attribute = node.attribute(j);
 570 if (attribute.has_g()) {
 571 const auto & subGraph = attribute.g();
 572 level += 1;
 574 }
 575 }
 576 }
 577}
578
579bool RModelParser_ONNX::CheckModel(std::string filename, bool verbose) {
580
581 fVerbose = verbose;
582 auto model = LoadModel(filename);
583 if (!model) return false;
584
585 const onnx::GraphProto &graph = model->graph();
586 // Initial operator order
587 if (fVerbose)
588 std::cout << "\nModel operator list " << model->producer_name() << "\n";
589
590 std::map<std::string, int> missingOperators;
591 int level = 1;
592 CheckGraph(graph, level, missingOperators);
593
594 if (!missingOperators.empty()) {
595 std::cout << "List of missing operators for model loaded from file " << filename << std::endl;
596 for (auto & op : missingOperators) {
597 std::cout << op.first << " " << op.second << std::endl;
598 }
599 return false;
600 }
601 std::cout << "All operators in the loaded model are supported!\n";
602 return true;
603}
604
// Convert an onnx::GraphProto into the given RModel:
//  1. register input tensors (fixed or parametric shapes),
//  2. extract initializer (weight) tensors,
//  3. topologically order the nodes so every input is produced before use,
//  4. compute each node's children (used for operator fusion),
//  5. create and add the ROperators in execution order.
// NOTE(review): this generated listing omits several original lines (712,
// 718, 720, 724, 728, 732, 781, 785) - in particular the
// GetInitializedTensorData call and part of the input-existence check; the
// annotations below describe only what is visible.
605void RModelParser_ONNX::ParseONNXGraph(RModel & rmodel, const onnx::GraphProto & graph, std::string graphName)
606{
607 bool verbose = fVerbose;
608
609 if (graphName.empty())
610 graphName = graph.name();
611
612 if (verbose)
613 std::cout << "\nParsing Graph - " << graphName << std::endl;
614
// names of weight tensors, used to tell real inputs apart from weights below
615 std::unordered_set<std::string> initializer_names;
616 for (int i = 0; i < graph.initializer_size(); i++) {
617 initializer_names.insert(graph.initializer(i).name());
618 }
619
620 if (verbose)
621 std::cout << "Parsing model inputs...." << std::endl;
622 /// Loop on model inputs
623 for (int i = 0; i < graph.input_size(); i++) {
624 RegisterTensorType(graph.input(i).name(),
625 static_cast<ETensorType>(graph.input(i).type().tensor_type().elem_type()));
626
627 if (verbose)
628 std::cout << "\tgraph input " << i << " name " << graph.input(i).name() << " type "
629 << graph.input(i).type().tensor_type().elem_type() << std::endl;
630
// weights also appear in the input list - skip them here
631 if (initializer_names.find(graph.input(i).name()) != initializer_names.end())
632 continue;
633
634 // input data node is not a weight node (has no initializer)
635 const onnx::ValueInfoProto &valueinfoproto = graph.input(i);
636 std::string input_name = valueinfoproto.name();
637
638 ETensorType type = static_cast<ETensorType>(valueinfoproto.type().tensor_type().elem_type());
639
// build the shape; a dimension may be a fixed value or a named parameter
640 std::vector<Dim> fShape;
641 bool existParam = false;
642 if (!valueinfoproto.type().tensor_type().has_shape())
643 throw std::runtime_error("TMVA::SOFIE data node with no shape restrictions is not supported yet");
644 for (int j = 0; j < valueinfoproto.type().tensor_type().shape().dim_size(); j++) {
645 Dim dim;
646 if (valueinfoproto.type().tensor_type().shape().dim(j).value_case() ==
647 onnx::TensorShapeProto_Dimension::ValueCase::kDimValue) {
648 int dim_value = valueinfoproto.type().tensor_type().shape().dim(j).dim_value();
649 dim.dim = dim_value;
650 // case input dim is -1 - set a parametric shape
651 if (dim_value < 0) {
652 dim.isParam = true;
653 existParam = true;
654 dim.param = UTILITY::Clean_name(input_name) + "_size";
655 }
656 } else if (valueinfoproto.type().tensor_type().shape().dim(j).value_case() ==
657 onnx::TensorShapeProto_Dimension::ValueCase::kDimParam) {
658 dim.isParam = true;
659 existParam = true;
660 dim.param = valueinfoproto.type().tensor_type().shape().dim(j).dim_param();
661 } else {
662 throw std::runtime_error("TMVA::SOFIE ONNX file error: Valueinfoproto " + input_name +
663 " has neither dim_value nor dim_param! \n");
664 }
665 fShape.push_back(dim);
666 }
667 if (valueinfoproto.type().tensor_type().shape().dim_size() == 0) {
668 Dim dim;
669 dim.dim = 1;
670 fShape.push_back(dim);
671 } // in case this TensorShapeProto has no dimension message: ONNX IR defines this to be a scalar
672
// fully fixed shapes use the size_t overload, parametric ones the Dim overload
673 if (!existParam) {
674 std::vector<size_t> fShape_sizet;
675 for (auto &j : fShape) {
676 fShape_sizet.push_back(j.dim);
677 }
678
679 rmodel.AddInputTensorInfo(input_name, type, fShape_sizet);
680 } else {
681 rmodel.AddInputTensorInfo(input_name, type, fShape);
682 }
683 rmodel.AddInputTensorName(input_name); // store also names in given order
684 }
685
686 std::map<std::string, int> allInitializedTensors;
687
688 if (verbose)
689 std::cout << "\nParsing graph initializer list and fill model initialized tensors" << std::endl;
690
691 for (int i = 0; i < graph.initializer_size(); i++) {
692 onnx::TensorProto *tensorproto = const_cast<onnx::TensorProto *>(&graph.initializer(i));
693 std::vector<std::size_t> shape;
694 std::size_t tensor_length = 1;
695 for (int j = 0; j < tensorproto->dims_size(); j++) {
696 shape.push_back(tensorproto->dims(j));
697 tensor_length *= tensorproto->dims(j);
698 }
699 // in case of scalars keep an empty shape but with length =1
700
701 std::string tensor_name = graph.initializer(i).name();
702
703 if (verbose)
704 std::cout << "\t initializer " << i << " name " << tensor_name << " type " << graph.initializer(i).data_type()
705 << " and length " << tensor_length << std::endl;
706
707
708 // register also the initialized tensors
709 auto tensor_type = static_cast<ETensorType>(graph.initializer(i).data_type());
710 RegisterTensorType(tensor_name, tensor_type);
711
// (omitted line 712 presumably obtains `data` via GetInitializedTensorData)
713 rmodel.AddInitializedTensor(tensor_name, tensor_type, shape, data);
714 allInitializedTensors[tensor_name] = i;
715
// verbose dump of the tensor values; the printing loops themselves are on
// lines omitted from this listing (718, 720, 724, 728, 732)
716 if (verbose) {
717 std::cout << "add initialized tensor " << tensor_name << "with shape " << ConvertShapeToString(shape) << "and ";
719 std::cout << " float data: ";
721 }
722 else if (tensor_type == ETensorType::INT64) {
723 std::cout << " int64 data: ";
725 }
726 else if (tensor_type == ETensorType::UINT8) {
727 std::cout << " uint8 data: ";
729 }
730 else if (tensor_type == ETensorType::BOOL) {
731 std::cout << " Boolean data: ";
733 }
734 std::cout << std::endl;
735 }
736 } // end initializer list
737
738 // Initial operator order
739 if (verbose) {
740 std::cout << "\nGraph operator list (ONNX order)\n";
741 for (int i = 0; i < graph.node_size(); i++) {
742 std::cout << "\tOperator " << i << " : " << graph.node(i).op_type() << " , " << graph.node(i).input_size()
743 << " inputs : {";
744 for (int j = 0; j < graph.node(i).input_size(); j++) {
745 std::cout << graph.node(i).input(j);
746 if (j < graph.node(i).input_size() - 1)
747 std::cout << ", ";
748 }
749 std::cout << " }" << std::endl;
750 }
751 }
752
753 // make order of nodes:
// Topological sort: repeatedly sweep the nodes, appending any node whose
// inputs are all available (graph inputs, weights, or outputs of already
// placed nodes), until every node is ordered. A sweep that places nothing
// means the graph is inconsistent and an exception is thrown.
754 if (verbose)
755 std::cout << "\n***********************\nRe-Order graph operator list\n*************************\n";
756 std::vector<size_t> nodesOrder;
757 nodesOrder.reserve(graph.node_size());
758 std::vector<bool> foundNodes(graph.node_size());

759
760 // loop at graph inputs
761 std::map<std::string, int> allInputs;
762 for (int i = 0; i < graph.input_size(); i++) {
763 allInputs[graph.input(i).name()] = -1;
764 }
765 do {
766 auto psize = nodesOrder.size();
767 for (int i = 0; i < graph.node_size(); i++) {
768 if (foundNodes[i])
769 continue;
770 // check if all input exists add to list
771 bool existInputs = true;
772 int input_size = graph.node(i).input_size();
773 // special case for Reshape where shape is input and not a weight tensor
774 if (fVerbose)
775 std::cout << "Checking input of Node " << i << " : " << graph.node(i).name() << std::endl;
776 for (int j = 0; j < input_size; j++) {
777 std::string name = graph.node(i).input(j);
778 // skip empty names
779 if (!name.empty()) {
// (omitted line 781 presumably also accepts names found in allInitializedTensors)
780 existInputs &= (allInputs.find(name) != allInputs.end() ||
782 if (fVerbose) {
783 std::cout << "\t\t input " << name << " "
784 << bool(allInputs.find(name) != allInputs.end()) << " " <<
786 existInputs << std::endl;
787 }
788 }
789 }
790 if (!existInputs) {
791 if (fVerbose) {
792 std::cout << "skip node " << graph.node(i).op_type() << " " << graph.node(i).name() << " inputs are not existing ";
793 for (int j = 0; j < input_size; j++) {
794 std::cout << graph.node(i).input(j) << " ";
795 }
796 std::cout << std::endl;
797 }
798 continue;
799 }
800
801 // adding node to the correctly ordered list
802 if (verbose)
803 std::cout << "===> New node " << graph.node(i).op_type() << " " << graph.node(i).name() << " order " << i << std::endl;
804
805 nodesOrder.push_back(i);
806 foundNodes[i] = true;
807 // register the outputs
808 for (int j = 0; j < graph.node(i).output_size(); j++) {
809 if (fVerbose) std::cout << "\toutput : " << graph.node(i).output(j) << std::endl;
810 allInputs[graph.node(i).output(j)] = i;
811 }
812 }
813 // no increment in nodes - something wrong
814 if (nodesOrder.size() == psize) {
815 int ilast = nodesOrder.back();
816 std::cout << "cannot find a new node after " << graph.node(ilast).op_type() << " " << graph.node(ilast).name() << std::endl;
817 throw std::runtime_error("TMVA::SOFIE - cannot find a new node ");
818 }
819 } while ((int)nodesOrder.size() < graph.node_size());
820
821
822 // find list of children for each operator (used for fusing operators)
// brute-force scan: a node j is a child of i when one of j's inputs matches
// one of i's outputs; only nodes ordered at or after i are considered
823 std::vector<std::vector<int>> nodesChildren(graph.node_size());
824
825 for (int k = 0; k < graph.node_size(); k++) {
826 int i = nodesOrder[k];
827 // compute the number of output for the operators
828 if (graph.node(i).output_size() > 0) nodesChildren[i].reserve(graph.node(i).output_size());
829 for (const auto& output_name : graph.node(i).output()) {
830 // loop on all nodes
831 for (int l = k; l < graph.node_size(); l++) {
832 int j = nodesOrder[l];
833 for (const auto& input_name : graph.node(j).input()) {
834 if (input_name == output_name)
835 nodesChildren[i].push_back(j);
836 }
837 }
838 }
839 }
840
841 // print list of ordered operators with list of inputs and list of children nodes
842 if (verbose) {
843 std::cout << "\nGraph operator list (re-ordered)\n";
844 for (int k = 0; k < graph.node_size(); k++) {
845 int i = nodesOrder[k];
846 std::cout << "\tOperator " << i << " : " << graph.node(i).op_type() << " , " << graph.node(i).name() << " input tensors : {";
847 for (int j = 0; j < graph.node(i).input_size(); j++) {
848 std::cout << graph.node(i).input(j);
849 if (j < graph.node(i).input_size() - 1)
850 std::cout << ", ";
851 }
852 std::cout << " } ";
853 std::cout << " children : {";
854 for ( const auto & ichild : nodesChildren[i]) {
855 std::cout << " [ " << ichild << " " << graph.node(ichild).op_type() << " , " << graph.node(ichild).name() << "]";
856 }
857 std::cout << "}" << std::endl;
858 }
859 }
860
861 // fill model with operators
862 if (verbose) {
863 std::cout << "Fill RModel with operators...\n";
864 }
865
866 // we have to record order of node execution separately to
867 // account for fused operators
868 size_t node_order_exec = 0;
869 for (int i = 0; i < graph.node_size(); i++) {
870 std::string op_type = graph.node(nodesOrder[i]).op_type();
871
872 if (verbose) {
873 std::cout << "\t" << i << " " << nodesOrder[i] << " parsing operator " << op_type << std::endl;
874 }
875
// ParseOperator returns nullptr when this node will be emitted later as part
// of a fused pair - those are skipped here
876 std::unique_ptr<ROperator> op = ParseOperator(i, graph, nodesOrder, nodesChildren[i]);
877 if (!op) {
878 if (verbose) {
879 std::cout << "\t\tskipping operator since it is fused with previous one" << std::endl;
880 }
881 // for skipping the fused nodes like Add after MatMul
882 continue;
883 }
884 rmodel.AddOperator(std::move(op), node_order_exec++);
885 }
886
// finally record the graph's declared outputs, in order
887 std::vector<std::string> outputnames;
888 if (verbose)
889 std::cout << "\nParsing Graph output list\n";
890 for (int i = 0; i < graph.output_size(); i++) {
891 if (verbose)
892 std::cout << "\toutput " << i << " name " << graph.output(i).name() << std::endl;
893 outputnames.push_back(graph.output(i).name());
894 }
895 rmodel.AddOutputTensorNameList(outputnames);
896
897 return;
898}
899
900} // namespace SOFIE
901} // namespace Experimental
902} // namespace TMVA
dims_t fShape
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
char name[80]
Definition TGX11.cxx:146
#define malloc
Definition civetweb.c:1575
const_iterator begin() const
const_iterator end() const
void RegisterOperator(const std::string &name, ParserFuncSignature func)
std::unique_ptr< ROperator > ParseOperator(const size_t, const onnx::GraphProto &, const std::vector< size_t > &, const std::vector< int > &)
bool IsRegisteredOperator(const std::string &name)
void CheckGraph(const onnx::GraphProto &g, int &level, std::map< std::string, int > &missingOperators)
void ParseONNXGraph(RModel &model, const onnx::GraphProto &g, std::string name="")
std::unordered_map< std::string, ETensorType > fTensorTypeMap
RModel Parse(std::string const &filename, bool verbose=false)
std::shared_ptr< void > GetInitializedTensorData(onnx::TensorProto *tensorproto, size_t tensor_length, ETensorType type)
std::map< int, std::pair< EFusedOp, int > > fFusedOperators
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::vector< std::string > GetRegisteredOperators()
std::unique_ptr< onnx::ModelProto > LoadModel(const std::string &filename)
std::unique_ptr< OperatorsMapImpl > fOperatorsMapImpl
bool CheckModel(std::string filename, bool verbose=false)
std::string Clean_name(std::string input_tensor_name)
ParserFuncSignature ParseIsNaN
ParserFuncSignature ParseSqrt
ParserFuncSignature ParseBatchNormalization
ParserFuncSignature ParseGreater
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &, const onnx::NodeProto &)> ParserFuseFuncSignature
ParserFuncSignature ParseReshape
ParserFuseFuncSignature ParseFuseConvTransposeAdd
ParserFuncSignature ParseReduceMean
ParserFuseFuncSignature ParseFuseMatMulAdd
ParserFuncSignature ParseGather
ParserFuncSignature ParseNeg
ParserFuncSignature ParseWhere
Definition ParseWhere.cxx:9
ParserFuncSignature ParseCos
ParserFuncSignature ParseLog
ParserFuncSignature ParseLeakyRelu
ParserFuncSignature ParseExp
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
ParserFuncSignature ParseEinsum
ParserFuncSignature ParsePool
Definition ParsePool.cxx:9
ParserFuncSignature ParseDiv
ParserFuncSignature ParseLayerNormalization
ParserFuncSignature ParseConcat
ParserFuncSignature ParseTopK
Definition ParseTopK.cxx:9
ParserFuncSignature ParseMax
ParserFuncSignature ParseEq
ParserFuncSignature ParseIdentity
ParserFuncSignature ParseConvTranspose
ParserFuncSignature ParseReduceProd
ParserFuncSignature ParseNot
Definition ParseNot.cxx:9
ParserFuncSignature ParseSlice
Definition ParseSlice.cxx:9
ParserFuncSignature ParseRandom
ParserFuncSignature ParseTranspose
ParserFuncSignature ParseLess
ParserFuncSignature ParseShape
Definition ParseShape.cxx:9
ParserFuncSignature ParseClip
Definition ParseClip.cxx:25
constexpr size_t GetTypeSize(ETensorType type)
ParserFuncSignature ParseScatterND
ParserFuncSignature ParseGRU
Definition ParseGRU.cxx:9
ParserFuncSignature ParseMatMul
ParserFuncSignature ParseErf
Definition ParseErf.cxx:9
ParserFuncSignature ParseSub
ParserFuncSignature ParseAdd
ParserFuncSignature ParseNonZero
ParserFuncSignature ParseIf
Definition ParseIf.cxx:9
ParserFuncSignature ParseRange
Definition ParseRange.cxx:9
ParserFuncSignature ParseSoftplus
ParserFuncSignature ParseExpand
ParserFuncSignature ParseRNN
Definition ParseRNN.cxx:9
ParserFuncSignature ParseLSTM
Definition ParseLSTM.cxx:9
ParserFuncSignature ParseCast
Definition ParseCast.cxx:9
ParserFuncSignature ParseReciprocal
ParserFuncSignature ParseSigmoid
ParserFuseFuncSignature ParseFuseConvAdd
ParserFuncSignature ParseAtan
ParserFuncSignature ParseFloor
ParserFuseFuncSignature ParseFuseBatchnormRelu
ParserFuncSignature ParseIsInf
ParserFuncSignature ParseSoftmax
ParserFuncSignature ParseGreaterEq
ParserFuncSignature ParseMod
std::string ConvertTypeToString(ETensorType type)
ParserFuncSignature ParseGelu
Definition ParseGelu.cxx:9
ParserFuncSignature ParseMean
ParserFuncSignature ParseSplit
Definition ParseSplit.cxx:9
ParserFuncSignature ParseConstant
ParserFuncSignature ParseSelu
Definition ParseSelu.cxx:9
ParserFuncSignature ParseLessEq
ParserFuncSignature ParseGatherND
ParserFuncSignature ParseSum
ParserFuncSignature ParseEyeLike
ParserFuncSignature ParsePad
Definition ParsePad.cxx:9
ParserFuncSignature ParseElu
Definition ParseElu.cxx:9
std::string ConvertShapeToString(const std::vector< size_t > &shape)
ParserFuncSignature ParseMin
ParserFuncSignature ParseRelu
Definition ParseRelu.cxx:9
ParserFuncSignature ParseReduceSum
ParserFuncSignature ParseConv
Definition ParseConv.cxx:9
ParserFuncSignature ParseScatterElements
ParserFuncSignature ParseGemm
Definition ParseGemm.cxx:9
ParserFuncSignature ParseTile
Definition ParseTile.cxx:9
ParserFuncSignature ParseMul
ParserFuseFuncSignature ParseFuseGemmRelu
ParserFuncSignature ParsePow
ParserFuncSignature ParseAbs
ParserFuncSignature ParseSin
ParserFuncSignature ParseReduceSumSquare
ParserFuncSignature ParseTanh
Definition ParseTanh.cxx:9
create variable transformations
Helper templated class for swapping bytes; specializations for N={2,4,8} are provided below.
Definition Byteswap.h:124
static void Copy(onnx::TensorProto *tensor, void *data, int length)
static void Copy(onnx::TensorProto *tensor, void *data, int length)
static void Copy(onnx::TensorProto *tensor, void *data, int length)
static void Copy(onnx::TensorProto *tensor, void *data, int length)
std::unordered_map< std::string, ParserFuncSignature > fOperatorsMap
TLine l
Definition textangle.C:4