3#include "onnx_proto3.pb.h"
6namespace Experimental {
10 const onnx::NodeProto &addnode) -> std::unique_ptr<ROperator> {
12 if (convnode.output(0) != addnode.input(0)) {
13 throw std::runtime_error(
"Cannot fuse ConvTranspose and Add operators");
18 auto input_name = convnode.input(0);
22 throw std::runtime_error(
"TMVA::SOFIE ONNX Parser ConvTranspose op has input tensor " + input_name +
23 " but its type is not yet registered");
26 std::unique_ptr<ROperator> op;
28 std::string attr_auto_pad =
"NOTSET";
29 std::vector<size_t> attr_dilations;
30 size_t attr_group = 0;
31 std::vector<size_t> attr_kernel_shape;
32 std::vector<size_t> attr_output_padding;
33 std::vector<size_t> attr_output_shape;
34 std::vector<size_t> attr_pads;
35 std::vector<size_t> attr_strides;
37 for (
int_t i = 0; i < convnode.attribute_size(); i++) {
38 std::string attribute_name = convnode.attribute(i).name();
39 if (attribute_name ==
"auto_pad") {
40 attr_auto_pad = convnode.attribute(i).s();
41 }
else if (attribute_name ==
"dilations") {
43 std::vector<size_t>({convnode.attribute(i).ints().begin(), convnode.attribute(i).ints().end()});
44 }
else if (attribute_name ==
"group") {
45 attr_group = convnode.attribute(i).i();
46 }
else if (attribute_name ==
"kernel_shape") {
48 std::vector<size_t>({convnode.attribute(i).ints().begin(), convnode.attribute(i).ints().end()});
49 }
else if (attribute_name ==
"output_padding") {
51 std::vector<size_t>({convnode.attribute(i).ints().begin(), convnode.attribute(i).ints().end()});
52 }
else if (attribute_name ==
"output_shape") {
54 std::vector<size_t>({convnode.attribute(i).ints().begin(), convnode.attribute(i).ints().end()});
55 }
else if (attribute_name ==
"pads") {
56 attr_pads = std::vector<size_t>({convnode.attribute(i).ints().begin(), convnode.attribute(i).ints().end()});
57 }
else if (attribute_name ==
"strides") {
58 attr_strides = std::vector<size_t>({convnode.attribute(i).ints().begin(), convnode.attribute(i).ints().end()});
60 std::cout <<
"TMVA::SOFIE Warning - Model Loading - Attribute " << attribute_name <<
" in OperatorNode "
61 << convnode.name() <<
" is not defined in ONNX IR and not applied!\n";
64 if (addnode.input_size() != 2) {
65 throw std::runtime_error(
"TMVA::SOFIE - Cannote fuse ConvTranspose - Add is input size of add is not 2");
68 if (convnode.output(0) == addnode.input(0) )
69 name_b = addnode.input(1);
70 else if (convnode.output(0) == addnode.input(1))
71 name_b = addnode.input(0);
73 throw std::runtime_error(
"TMVA::SOFIE - Cannote fuse ConvTranspose - Output of ConvTrans is not input to Add");
78 attr_output_padding, attr_output_shape, attr_pads, attr_strides,
79 convnode.input(0), convnode.input(1), name_b, addnode.output(0)));
82 throw std::runtime_error(
"TMVA::SOFIE - Unsupported - Operator ConvTranspose does not yet support input type " +
83 std::to_string(
static_cast<int>(input_type)));
87 std::string output_name = addnode.output(0);
bool IsRegisteredTensorType(const std::string &)
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
Transposed Convolution operator.
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &, const onnx::NodeProto &)> ParserFuseFuncSignature
ParserFuseFuncSignature ParseFuseConvTransposeAdd
create variable transformations