3#include "onnx_proto3.pb.h"
6namespace Experimental {
12 auto input_name = nodeproto.input(0);
16 throw std::runtime_error(
"TMVA::SOFIE ONNX Parser gemm op has input tensor" + input_name +
17 " but its type is not yet registered");
20 std::unique_ptr<ROperator> op;
22 float attr_alpha = 1.0;
23 float attr_beta = 1.0;
24 int_t attr_transA = 0;
25 int_t attr_transB = 0;
27 for (
int i = 0; i < nodeproto.attribute_size(); i++) {
28 std::string attribute_name = nodeproto.attribute(i).name();
29 if (attribute_name ==
"alpha") {
30 attr_alpha = nodeproto.attribute(i).f();
31 }
else if (attribute_name ==
"beta") {
32 attr_beta = nodeproto.attribute(i).f();
33 }
else if (attribute_name ==
"transA") {
34 attr_transA = nodeproto.attribute(i).i();
35 if (attr_transA != 0 && attr_transA != 1)
36 throw std::runtime_error(
"TMVA::SOFIE Error - Model Loading - attribute transA in Operator Gemm not 0/1");
37 }
else if (attribute_name ==
"transB") {
38 attr_transB = nodeproto.attribute(i).i();
39 if (attr_transB != 0 && attr_transB != 1)
40 throw std::runtime_error(
"TMVA::SOFIE Error - Model Loading - attribute transB in Operator Gemm not 0/1");
42 std::cout <<
"TMVA::SOFIE Warning - Model Loading - Attribute " << attribute_name <<
" in OperatorNode "
43 << nodeproto.name() <<
" is not defined in ONNX IR and not applied!\n";
47 std::string output_name = nodeproto.output(0);
50 if (nodeproto.input_size() == 2) {
51 op.reset(
new ROperator_Gemm<float>(attr_alpha, attr_beta, attr_transA, attr_transB, nodeproto.input(0),
52 nodeproto.input(1), output_name));
54 op.reset(
new ROperator_Gemm<float>(attr_alpha, attr_beta, attr_transA, attr_transB, nodeproto.input(0),
55 nodeproto.input(1), nodeproto.input(2), output_name));
59 throw std::runtime_error(
"TMVA::SOFIE - Unsupported - Operator Gemm does not yet support input type " +
60 std::to_string(
static_cast<int>(input_type)));
bool IsRegisteredTensorType(const std::string &)
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
ParserFuncSignature ParseGemm
create variable transformations