3#include "onnx_proto3.pb.h"
6namespace Experimental {
10-> std::unique_ptr<ROperator> {
12 const std::string input_name = nodeproto.input(0);
16 throw std::runtime_error(
"TMVA::SOFIE ONNX Parser LayerNormalizaion op has input tensor " + input_name +
17 " but its type is not yet registered");
22 int64_t stash_type = 1;
23 for (int64_t i = 0; i < nodeproto.attribute_size(); i++) {
24 std::string attribute_name = nodeproto.attribute(i).name();
25 if (attribute_name ==
"axis") {
26 axis = nodeproto.attribute(i).i();
27 }
else if (attribute_name ==
"epsilon") {
28 epsilon = nodeproto.attribute(i).f();
29 }
else if (attribute_name ==
"stash_type") {
30 stash_type = nodeproto.attribute(i).i();
33 size_t input_size = nodeproto.input_size();
34 std::string name_scale =
"";
36 name_scale = nodeproto.input(1);
38 std::string name_bias =
"";
40 name_bias = nodeproto.input(2);
43 const std::string output_name = nodeproto.output(0);
44 size_t output_size = nodeproto.output_size();
45 std::string name_mean =
"";
46 if (output_size > 1) {
47 name_mean = nodeproto.output(1);
49 std::string name_std =
"";
50 if (output_size > 2) {
51 name_std = nodeproto.output(2);
54 std::unique_ptr<ROperator> op;
58 output_name, name_mean, name_std));
61 throw std::runtime_error(
"TMVA::SOFIE ONNX parser Operator with input type " +
ConvertTypeToString(input_type) +
bool IsRegisteredTensorType(const std::string &)
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
ParserFuncSignature ParseLayerNormalization
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
std::string ConvertTypeToString(ETensorType type)
create variable transformations