Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ParseGemm.cxx
Go to the documentation of this file.
#include "onnx_proto3.pb.h"

#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
9ParserFuncSignature ParseGemm = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
10 ETensorType input_type;
11
12 auto input_name = nodeproto.input(0);
13 if (parser.IsRegisteredTensorType(input_name)) {
14 input_type = parser.GetTensorType(input_name);
15 } else {
16 throw std::runtime_error("TMVA::SOFIE ONNX Parser gemm op has input tensor" + input_name +
17 " but its type is not yet registered");
18 }
19
20 std::unique_ptr<ROperator> op;
21
22 float attr_alpha = 1.0;
23 float attr_beta = 1.0;
24 int_t attr_transA = 0;
25 int_t attr_transB = 0;
26
27 for (int i = 0; i < nodeproto.attribute_size(); i++) {
28 std::string attribute_name = nodeproto.attribute(i).name();
29 if (attribute_name == "alpha") {
30 attr_alpha = nodeproto.attribute(i).f();
31 } else if (attribute_name == "beta") {
32 attr_beta = nodeproto.attribute(i).f();
33 } else if (attribute_name == "transA") {
34 attr_transA = nodeproto.attribute(i).i();
35 if (attr_transA != 0 && attr_transA != 1)
36 throw std::runtime_error("TMVA::SOFIE Error - Model Loading - attribute transA in Operator Gemm not 0/1");
37 } else if (attribute_name == "transB") {
38 attr_transB = nodeproto.attribute(i).i();
39 if (attr_transB != 0 && attr_transB != 1)
40 throw std::runtime_error("TMVA::SOFIE Error - Model Loading - attribute transB in Operator Gemm not 0/1");
41 } else {
42 std::cout << "TMVA::SOFIE Warning - Model Loading - Attribute " << attribute_name << " in OperatorNode "
43 << nodeproto.name() << " is not defined in ONNX IR and not applied!\n";
44 }
45 }
46
47 std::string output_name = nodeproto.output(0);
48 switch (input_type) {
50 if (nodeproto.input_size() == 2) {
51 op.reset(new ROperator_Gemm<float>(attr_alpha, attr_beta, attr_transA, attr_transB, nodeproto.input(0),
52 nodeproto.input(1), output_name));
53 } else {
54 op.reset(new ROperator_Gemm<float>(attr_alpha, attr_beta, attr_transA, attr_transB, nodeproto.input(0),
55 nodeproto.input(1), nodeproto.input(2), output_name));
56 }
57 break;
58 default:
59 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Gemm does not yet support input type " +
60 std::to_string(static_cast<int>(input_type)));
61 }
62
63 if (!parser.IsRegisteredTensorType(output_name)) {
64 parser.RegisterTensorType(output_name, input_type);
65 }
66
67 return op;
68};
69
70} // namespace SOFIE
71} // namespace Experimental
72} // namespace TMVA
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
ParserFuncSignature ParseGemm
Definition ParseGemm.cxx:9
create variable transformations