Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ParseBasicBinary.cxx
Go to the documentation of this file.
3#include "onnx_proto3.pb.h"
4
5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
// ParseBasicBinary: builds an ROperator for the ONNX binary node `nodeproto`,
// templated on the binary operator kind `Op`.
//  - Both inputs (nodeproto.input(0) and nodeproto.input(1)) must already have
//    a tensor type registered in `parser`; otherwise std::runtime_error is
//    thrown ("... but its type is not yet registered").
//  - The type of the second input must equal the type of the first one; a
//    mismatch throws std::runtime_error naming both inputs and their types.
//  - The operator is created by a switch on the common input type; any type
//    reaching the `default` label throws std::runtime_error.
// Returns the created operator, owned by the caller through std::unique_ptr.
// NOTE(review): this listing comes from generated documentation and several
// original lines are absent (the input-type lookups, the switch case labels
// and bodies, and apparently the output-type registration) — confirm against
// the real source before relying on details of those parts.
9template <EBasicBinaryOperator Op>
10std::unique_ptr<ROperator> ParseBasicBinary(RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto)
11{
13
// Validate the two inputs: both registered, both of the same tensor type.
14 for (int i = 0; i < 2; ++i) {
15 auto input_name = nodeproto.input(i);
17 // according to ONNX both inputs have same type
18 if (i == 0)
20 else {
22 if (input_type2 != input_type) {
23 throw
24 std::runtime_error("TMVA::SOFIE ONNX parser Binary op has input tensors of different types: " +
26 " and " + nodeproto.input(0) + " : " + ConvertTypeToString(input_type));
27 }
28 }
29 } else {
30 throw std::runtime_error("TMVA::SOFIE ONNX Parser Binary op has input tensor " + input_name +
31 " but its type is not yet registered");
32 }
33 }
34
35 std::unique_ptr<ROperator> op;
36 std::string output_name = nodeproto.output(0);
37
// Dispatch on the common input type; unsupported types fall to `default`.
38 switch (input_type) {
41 break;
44 break;
47 break;
50 break;
51 default:
52 throw std::runtime_error("TMVA::SOFIE - Unsupported - Binary Operator does not yet support input type " +
53 std::to_string(static_cast<int>(input_type)));
54 }
55
56 // Infer the output type
59 }
60
61 return op;
62};
63
64
65// Parse Add
69
70// Parse Sub
74
75// Parse Mul
79
80// Parse Div
84
85// Parse Pow
89
90// Mod (and fmod) is a special case of BasicBinary: it is parsed separately because it can carry an optional fmod attribute
91
// ParseMod: parser for the ONNX Mod node. It is handled outside the generic
// ParseBasicBinary path because it reads an extra integer attribute (fmod).
//  - Both inputs must already have a registered tensor type in `parser`, and
//    the two types must match; otherwise std::runtime_error is thrown.
//  - `fmod` defaults to 0 and is taken from nodeproto.attribute(0) when the
//    node has any attribute.
//    NOTE(review): attribute(0) is assumed to be `fmod` without checking its
//    name — confirm that Mod nodes never carry another attribute first.
//  - For floating-point inputs fmod must be 1; any other value now throws
//    std::runtime_error. Previously the exception object was constructed and
//    discarded because `throw` was missing, so the check never fired.
//  - The operator is created by a switch on the input type. In some cases
//    `fmod == 1` selects between two operator variants. Any type reaching
//    `default` throws std::runtime_error.
// Returns the created operator, owned by the caller through std::unique_ptr.
// NOTE(review): this listing comes from generated documentation and several
// original lines are absent (type lookups, the floating-type condition,
// switch case labels/bodies, output-type registration) — verify against the
// real source.
92ParserFuncSignature ParseMod = [] (RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
93
95 for (int i = 0; i < 2; ++i) {
96 auto input_name = nodeproto.input(i);
98 // according to ONNX both inputs have same type
99 if (i == 0)
101 else {
103 if (input_type2 != input_type) {
104 throw
105 std::runtime_error("TMVA::SOFIE ONNX parser Binary op has input tensors of different types: " +
107 " and " + nodeproto.input(0) + " : " + ConvertTypeToString(input_type));
108 }
109 }
110 } else {
111 throw std::runtime_error("TMVA::SOFIE ONNX Parser Binary op has input tensor " + input_name +
112 " but its type is not yet registered");
113 }
114 }
115 // in case of Mod there can be an attribute
116 int fmod = 0;
117 if (nodeproto.attribute_size() > 0) {
118 fmod = nodeproto.attribute(0).i();
119 }
120 // case of float or double fmod must be 1
122 if (fmod != 1)
// `throw` is required here: a bare std::runtime_error(...) expression only
// builds a temporary and discards it, which silently skips the check.
123 throw std::runtime_error("TMVA::SOFIE ONNX parser Mod operator has fmod = 0 for floating inputs");
124 }
125
126 std::unique_ptr<ROperator> op;
127 std::string output_name = nodeproto.output(0);
128
129 switch (input_type) {
132 break;
135 break;
137 if (fmod == 1)
139 else
141 break;
143 if (fmod == 1)
145 else
147 break;
148 default:
149 throw std::runtime_error("TMVA::SOFIE - Unsupported - Binary Operator does not yet support input type " +
150 std::to_string(static_cast<int>(input_type)));
151 }
152
153 // Infer the output type
154 if (!parser.IsRegisteredTensorType(output_name)) {
156 }
157
158 return op;
159};
160
161
162} // namespace SOFIE
163} // namespace Experimental
164} // namespace TMVA
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
ParserFuncSignature ParseDiv
ParserFuncSignature ParseSub
std::unique_ptr< ROperator > ParseBasicBinary(RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto)
ParserFuncSignature ParseAdd
ParserFuncSignature ParseMod
std::string ConvertTypeToString(ETensorType type)
ParserFuncSignature ParseMul
ParserFuncSignature ParsePow
create variable transformations