Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ParseLayerNormalization.cxx
Go to the documentation of this file.
#include <memory>
#include <stdexcept>
#include <string>

#include "onnx_proto3.pb.h"

5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
9ParserFuncSignature ParseLayerNormalization = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto)
10-> std::unique_ptr<ROperator> {
12 const std::string input_name = nodeproto.input(0);
13 if (parser.IsRegisteredTensorType(input_name)) {
14 input_type = parser.GetTensorType(input_name);
15 } else {
16 throw std::runtime_error("TMVA::SOFIE ONNX Parser LayerNormalizaion op has input tensor " + input_name +
17 " but its type is not yet registered");
18 }
19
20 int64_t axis = -1;
21 float epsilon = 1e-5;
22 int64_t stash_type = 1;
23 for (int64_t i = 0; i < nodeproto.attribute_size(); i++) {
24 std::string attribute_name = nodeproto.attribute(i).name();
25 if (attribute_name == "axis") {
26 axis = nodeproto.attribute(i).i();
27 } else if (attribute_name == "epsilon") {
28 epsilon = nodeproto.attribute(i).f();
29 } else if (attribute_name == "stash_type") {
30 stash_type = nodeproto.attribute(i).i();
31 }
32 }
33 size_t input_size = nodeproto.input_size();
34 std::string name_scale = "";
35 if (input_size > 1) {
36 name_scale = nodeproto.input(1);
37 }
38 std::string name_bias = "";
39 if (input_size > 2) {
40 name_bias = nodeproto.input(2);
41 }
42
43 const std::string output_name = nodeproto.output(0);
44 size_t output_size = nodeproto.output_size();
45 std::string name_mean = "";
46 if (output_size > 1) {
47 name_mean = nodeproto.output(1);
48 }
49 std::string name_std = "";
50 if (output_size > 2) {
51 name_std = nodeproto.output(2);
52 }
53
54 std::unique_ptr<ROperator> op;
55 switch (input_type) {
57 op.reset(new ROperator_LayerNormalization<float>(axis, epsilon, stash_type, input_name, name_scale, name_bias,
58 output_name, name_mean, name_std));
59 break;
60 default:
61 throw std::runtime_error("TMVA::SOFIE ONNX parser Operator with input type " + ConvertTypeToString(input_type) +
62 " not supported.");
63 break;
64 }
65
66 if (!parser.IsRegisteredTensorType(output_name)) {
67 parser.RegisterTensorType(output_name, input_type);
68 }
69
70 return op;
71};
72
73}
74} // namespace Experimental
75} // namespace TMVA
#define e(i)
Definition RSha256.hxx:103
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
ParserFuncSignature ParseLayerNormalization
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
std::string ConvertTypeToString(ETensorType type)
create variable transformations