10 const onnx::NodeProto &relunode) -> std::unique_ptr<ROperator> {
13 auto input_name = gemmnode.input(0);
17 throw std::runtime_error(
"TMVA::SOFIE ONNX Parser MatMul op has input tensor " + input_name +
18 " but its type is not yet registered");
22 std::unique_ptr<ROperator> op;
24 float attr_alpha = 1.0;
25 float attr_beta = 1.0;
26 int_t attr_transA = 0;
27 int_t attr_transB = 0;
29 for (
int i = 0; i < gemmnode.attribute_size(); i++) {
30 std::string attribute_name = gemmnode.attribute(i).name();
31 if (attribute_name ==
"alpha") {
32 attr_alpha = gemmnode.attribute(i).f();
33 }
else if (attribute_name ==
"beta") {
34 attr_beta = gemmnode.attribute(i).f();
35 }
else if (attribute_name ==
"transA") {
36 attr_transA = gemmnode.attribute(i).i();
37 if (attr_transA != 0 && attr_transA != 1)
38 throw std::runtime_error(
"TMVA::SOFIE Error - Model Loading - attribute transA in Operator Gemm not 0/1");
39 }
else if (attribute_name ==
"transB") {
40 attr_transB = gemmnode.attribute(i).i();
41 if (attr_transB != 0 && attr_transB != 1)
42 throw std::runtime_error(
"TMVA::SOFIE Error - Model Loading - attribute transB in Operator Gemm not 0/1");
44 std::cout <<
"TMVA::SOFIE Warning - Model Loading - Attribute " << attribute_name <<
" in OperatorNode "
45 << gemmnode.name() <<
" is not defined in ONNX IR and not applied!\n";
50 if (gemmnode.input_size() == 2) {
51 op.reset(
new ROperator_Gemm<float>(attr_alpha, attr_beta, attr_transA, attr_transB, gemmnode.input(0),
54 op.reset(
new ROperator_Gemm<float>(attr_alpha, attr_beta, attr_transA, attr_transB, gemmnode.input(0),
59 throw std::runtime_error(
"TMVA::SOFIE - Unsupported - Operator Gemm does not yet support input type " +
60 std::to_string(
static_cast<int>(input_type)));
63 std::string output_name = relunode.output(0);