ROOT Reference Guide
 
Loading...
Searching...
No Matches
ParseConvTranspose.cxx
Go to the documentation of this file.
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "onnx_proto3.pb.h"

5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
10 const onnx::NodeProto &nodeproto) -> std::unique_ptr<ROperator> {
11 auto inputName = nodeproto.input(0);
12 if (!model.IsRegisteredTensorType(inputName)) {
13 throw std::runtime_error("TMVA::SOFIE ONNX Parser ConvTranspose op has input tensor " + inputName +
14 " but its type is not yet registered");
15 }
16 ETensorType inputType = model.GetTensorType(inputName);
17
18 std::string autoPad = "NOTSET";
19 std::vector<size_t> dilations;
20 size_t group = 0;
21 std::vector<size_t> kernelShape;
22 std::vector<size_t> outputPadding;
23 std::vector<size_t> outputShape;
24 std::vector<size_t> pads;
25 std::vector<size_t> strides;
26
27 for (int_t i = 0; i < nodeproto.attribute_size(); i++) {
28 std::string attributeName = nodeproto.attribute(i).name();
29 if (attributeName == "auto_pad") {
30 autoPad = nodeproto.attribute(i).s();
31 } else if (attributeName == "dilations") {
32 dilations = std::vector<size_t>({nodeproto.attribute(i).ints().begin(), nodeproto.attribute(i).ints().end()});
33 } else if (attributeName == "group") {
34 group = nodeproto.attribute(i).i();
35 } else if (attributeName == "kernel_shape") {
36 kernelShape =
37 std::vector<size_t>({nodeproto.attribute(i).ints().begin(), nodeproto.attribute(i).ints().end()});
38 } else if (attributeName == "output_padding") {
39 outputPadding =
40 std::vector<size_t>({nodeproto.attribute(i).ints().begin(), nodeproto.attribute(i).ints().end()});
41 } else if (attributeName == "output_shape") {
42 outputShape =
43 std::vector<size_t>({nodeproto.attribute(i).ints().begin(), nodeproto.attribute(i).ints().end()});
44 } else if (attributeName == "pads") {
45 pads = std::vector<size_t>({nodeproto.attribute(i).ints().begin(), nodeproto.attribute(i).ints().end()});
46 } else if (attributeName == "strides") {
47 strides = std::vector<size_t>({nodeproto.attribute(i).ints().begin(), nodeproto.attribute(i).ints().end()});
48 } else {
49 std::cout << "TMVA::SOFIE Warning - Model Loading - Attribute " << attributeName << " in OperatorNode "
50 << nodeproto.name() << " is not defined in ONNX IR and not applied!\n";
51 }
52 }
53
54 std::string nameW = nodeproto.input(1);
55 std::string nameBias;
56 if (nodeproto.input_size() > 2) {
57 nameBias = nodeproto.input(2);
58 }
59 std::string outputName = nodeproto.output(0);
60
61 std::unique_ptr<ROperator> op;
62 switch (inputType) {
64 op.reset(new ROperator_ConvTranspose<float>(autoPad, dilations, group, kernelShape, outputPadding, outputShape,
65 pads, strides, inputName, nameW, nameBias, outputName));
66 break;
67 default:
68 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator ConvTranspose does not yet support input type " +
69 std::to_string(static_cast<int>(inputType)));
70 }
71
72 if (!model.IsRegisteredTensorType(outputName)) {
73 model.RegisterTensorType(outputName, inputType);
74 }
75
76 return op;
77};
78
79} // namespace SOFIE
80} // namespace Experimental
81} // namespace TMVA
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
ParserFuncSignature ParseConvTranspose
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
create variable transformations