// ROOT Reference Guide: ParseFuseConvTransposeAdd.cxx
// (Doxygen source listing of this file; page navigation text removed.)
#include "onnx_proto3.pb.h"

#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
9ParserFuseFuncSignature ParseFuseConvTransposeAdd = [](RModelParser_ONNX &parser, const onnx::NodeProto &convnode,
10 const onnx::NodeProto &addnode) -> std::unique_ptr<ROperator> {
11 // Output of ConvTranspose must be the input of Add
12 if (convnode.output(0) != addnode.input(0)) {
13 throw std::runtime_error("Cannot fuse ConvTranspose and Add operators");
14 }
15
16 // input type of ConvTranspose
17 ETensorType input_type;
18 auto input_name = convnode.input(0);
19 if (parser.IsRegisteredTensorType(input_name)) {
20 input_type = parser.GetTensorType(input_name);
21 } else {
22 throw std::runtime_error("TMVA::SOFIE ONNX Parser ConvTranspose op has input tensor " + input_name +
23 " but its type is not yet registered");
24 }
25
26 std::unique_ptr<ROperator> op;
27
28 std::string attr_auto_pad = "NOTSET";
29 std::vector<size_t> attr_dilations;
30 size_t attr_group = 0;
31 std::vector<size_t> attr_kernel_shape;
32 std::vector<size_t> attr_output_padding;
33 std::vector<size_t> attr_output_shape;
34 std::vector<size_t> attr_pads;
35 std::vector<size_t> attr_strides;
36
37 for (int_t i = 0; i < convnode.attribute_size(); i++) {
38 std::string attribute_name = convnode.attribute(i).name();
39 if (attribute_name == "auto_pad") {
40 attr_auto_pad = convnode.attribute(i).s();
41 } else if (attribute_name == "dilations") {
42 attr_dilations =
43 std::vector<size_t>({convnode.attribute(i).ints().begin(), convnode.attribute(i).ints().end()});
44 } else if (attribute_name == "group") {
45 attr_group = convnode.attribute(i).i();
46 } else if (attribute_name == "kernel_shape") {
47 attr_kernel_shape =
48 std::vector<size_t>({convnode.attribute(i).ints().begin(), convnode.attribute(i).ints().end()});
49 } else if (attribute_name == "output_padding") {
50 attr_output_padding =
51 std::vector<size_t>({convnode.attribute(i).ints().begin(), convnode.attribute(i).ints().end()});
52 } else if (attribute_name == "output_shape") {
53 attr_output_shape =
54 std::vector<size_t>({convnode.attribute(i).ints().begin(), convnode.attribute(i).ints().end()});
55 } else if (attribute_name == "pads") {
56 attr_pads = std::vector<size_t>({convnode.attribute(i).ints().begin(), convnode.attribute(i).ints().end()});
57 } else if (attribute_name == "strides") {
58 attr_strides = std::vector<size_t>({convnode.attribute(i).ints().begin(), convnode.attribute(i).ints().end()});
59 } else {
60 std::cout << "TMVA::SOFIE Warning - Model Loading - Attribute " << attribute_name << " in OperatorNode "
61 << convnode.name() << " is not defined in ONNX IR and not applied!\n";
62 }
63 }
64 if (addnode.input_size() != 2) {
65 throw std::runtime_error("TMVA::SOFIE - Cannote fuse ConvTranspose - Add is input size of add is not 2");
66 }
67 std::string name_b;
68 if (convnode.output(0) == addnode.input(0) )
69 name_b = addnode.input(1);
70 else if (convnode.output(0) == addnode.input(1))
71 name_b = addnode.input(0);
72 else
73 throw std::runtime_error("TMVA::SOFIE - Cannote fuse ConvTranspose - Output of ConvTrans is not input to Add");
74
75 switch (input_type) {
77 op.reset(new ROperator_ConvTranspose<float>(attr_auto_pad, attr_dilations, attr_group, attr_kernel_shape,
78 attr_output_padding, attr_output_shape, attr_pads, attr_strides,
79 convnode.input(0), convnode.input(1), name_b, addnode.output(0)));
80 break;
81 default:
82 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator ConvTranspose does not yet support input type " +
83 std::to_string(static_cast<int>(input_type)));
84 }
85
86 // Output of ConvTranspose and input of Add must be have the same type
87 std::string output_name = addnode.output(0);
88 if (!parser.IsRegisteredTensorType(output_name)) {
89 parser.RegisterTensorType(output_name, input_type);
90 }
91
92 return op;
93};
94
95} // namespace SOFIE
96} // namespace Experimental
97} // namespace TMVA
// Symbols referenced in this listing (from the Doxygen cross-reference tooltips):
//   void RegisterTensorType(const std::string &, ETensorType)
//   ETensorType GetTensorType(const std::string &name)
//   std::function<std::unique_ptr<ROperator>(RModelParser_ONNX &, const onnx::NodeProto &,
//                                            const onnx::NodeProto &)> ParserFuseFuncSignature
//   ParserFuseFuncSignature ParseFuseConvTransposeAdd