3#include "onnx_proto3.pb.h"
6namespace Experimental {
13 auto input_name = nodeproto.input(0);
17 throw std::runtime_error(
"TMVA::SOFIE ONNX Parser Slice op has input tensor" + input_name +
18 " but its type is not yet registered");
21 std::vector<std::string> axisTensorNames;
22 if (nodeproto.input_size() > 1)
23 axisTensorNames.push_back(nodeproto.input(1));
24 if (nodeproto.input_size() > 2)
25 axisTensorNames.push_back(nodeproto.input(1));
26 if (nodeproto.input_size() > 3)
27 axisTensorNames.push_back(nodeproto.input(3));
28 if (nodeproto.input_size() > 4)
29 axisTensorNames.push_back(nodeproto.input(4));
// Slice parameters given as node attributes (the attribute form used by older
// ONNX opsets). They remain empty unless the node has exactly one input.
34 std::vector<int64_t> attr_starts = {};
35 std::vector<int64_t> attr_ends = {};
36 std::vector<int64_t> attr_axes = {};
// Attributes are only consulted when no parameter tensors were passed as
// extra inputs (i.e. only input 0, the data tensor, is present).
37 if (nodeproto.input_size() == 1) {
38 for (
int_t i = 0; i < nodeproto.attribute_size(); i++) {
39 std::string attribute_name = nodeproto.attribute(i).name();
// Copy the integer list of each recognised attribute. Any other attribute name
// is ignored, and a repeated name overwrites the earlier value.
40 if (attribute_name ==
"starts")
41 attr_starts = {nodeproto.attribute(i).ints().begin(), nodeproto.attribute(i).ints().end()};
42 if (attribute_name ==
"ends")
43 attr_ends = {nodeproto.attribute(i).ints().begin(), nodeproto.attribute(i).ints().end()};
44 if (attribute_name ==
"axes")
45 attr_axes = {nodeproto.attribute(i).ints().begin(), nodeproto.attribute(i).ints().end()};
// Operator produced by this parser function. It is presumably filled in by the
// branches below and returned (construction is not visible here — TODO confirm).
49 std::unique_ptr<ROperator> op;
// Name of the node's first output tensor.
50 std::string output_name = nodeproto.output(0);
// Input-tensor form: the Slice parameters arrive as extra node inputs, whose
// names were collected into axisTensorNames.
53 if (axisTensorNames.size() > 0) {
// Raised when the descriptor tensors' type (axis_type) is not supported; the
// type is reported as its integer enum value. NOTE(review): the guarding type
// check lies outside the lines shown.
60 throw std::runtime_error(
61 "TMVA::SOFIE - Unsupported - Operator Slice has invalid input type for input axis descriptors " +
62 std::to_string(
static_cast<int>(axis_type)));
63 }
// Attribute form: both the "starts" and "ends" attributes were read by the
// attribute loop earlier in this function.
else if (attr_starts.size() > 0 && attr_ends.size() > 0) {
66 throw std::runtime_error(
"TMVA::SOFIE - Unsupported - Operator Slice has invalid attribues");
70 throw std::runtime_error(
"TMVA::SOFIE - Unsupported - Operator Slice does not yet support input type " +
71 std::to_string(
static_cast<int>(input_type)));
bool IsRegisteredTensorType(const std::string &)
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
ParserFuncSignature ParseSlice
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
create variable transformations