Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ParseSplit.cxx
Go to the documentation of this file.
3#include "onnx_proto3.pb.h"
4
5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
9ParserFuncSignature ParseSplit = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
10 ETensorType input_type;
11
12 std::string input_name = nodeproto.input(0);
13 if (parser.IsRegisteredTensorType(input_name)) {
14 input_type = parser.GetTensorType(input_name);
15 } else {
16 throw std::runtime_error("TMVA::SOFIE ONNX Parser Split op has input tensor" + input_name +
17 " but its type is not yet registered");
18 }
19
20 std::string split_name;
21 if (nodeproto.input_size() > 1) {
22 split_name = nodeproto.input(1);
23 if (!parser.IsRegisteredTensorType(split_name)) {
24 throw std::runtime_error("TMVA::SOFIE ONNX Parser Split op has input tensor" + split_name +
25 " but its type is not yet registered");
26 }
27 }
28
29 // ignore for time being attributes
30 if (nodeproto.attribute_size() > 0 )
31 std::cout << "WARNING: TMVA::SOFIE ONNX Parser Split operator: attributes are not yet supported- they are ignored" << std::endl;
32
33 // number of splits are given by the number of output tensors
34 size_t output_size = nodeproto.output_size();
35 std::vector<std::string> output_names(output_size);
36 for (size_t i = 0; i < output_size; i++)
37 output_names[i] = nodeproto.output(i);
38
39 std::unique_ptr<ROperator> op(new ROperator_Split<float>(input_name, split_name, output_names));
40
41 for (size_t i = 0; i < output_size; i++) {
42 if (!parser.IsRegisteredTensorType(output_names[i])) {
43 parser.RegisterTensorType(output_names[i], input_type);
44 }
45 }
46
47 return op;
48};
49
50} // namespace SOFIE
51} // namespace Experimental
52} // namespace TMVA
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
ParserFuncSignature ParseSplit
Definition ParseSplit.cxx:9
create variable transformations