Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ParseConstant.cxx
Go to the documentation of this file.
3#include "onnx_proto3.pb.h"
4
5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
9// same function used to parse Constant and ConstantOfShape
10
11ParserFuncSignature ParseConstant = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
12 std::string input_name;
13 auto ninputs = nodeproto.input_size();
14 bool isConstantOfShape = false;
15 // case of ConstantOfShape (Constant has zero inputs)
16 if (ninputs > 0) {
17 input_name = nodeproto.input(0);
18 isConstantOfShape = true;
19 if (!parser.IsRegisteredTensorType(input_name)) {
20 throw std::runtime_error("TMVA::SOFIE ONNX Parser ConstantOfShape op has input tensor" + input_name +
21 " but its type is not yet registered");
22 }
23 }
24
25 if (parser.Verbose()) {
26 std::cout << "\t.... ";
27 if (isConstantOfShape)
28 std::cout << "ConstantOfShape " << nodeproto.input(0) << " -> ";
29 else
30 std::cout << "Constant --> ";
31 std::cout << nodeproto.output(0) << std::endl;
32 }
33
34 std::unique_ptr<ROperator> op;
35 std::string attr_type;
36
37 std::string output_name = nodeproto.output(0);
38 ETensorType output_type = ETensorType::FLOAT;
39 std::vector<std::size_t> shape; // output shape (use in case of constant operator)
40 // it should be only one attribute (Constant or 1 or 0 COnstant of Shape)
41 if (nodeproto.attribute_size() > 1)
42 throw std::runtime_error("TMVA::SOFIE ONNX Parser Constant or ConstantOfShape and attribute size is larger than 1");
43 if (nodeproto.attribute_size() > 0) {
44 std::string attribute_name = nodeproto.attribute(0).name();
45 // tensor input
46 if (attribute_name == "value") {
47 const onnx::TensorProto & t = nodeproto.attribute(0).t();
48 output_type = static_cast<ETensorType>(t.data_type());
49
50 std::size_t length = 1;
51 for (int j = 0; j < t.dims_size(); j++) {
52 shape.push_back(t.dims(j));
53 length *= t.dims(j);
54 }
55 if (isConstantOfShape) {
56 // value tensor should be one-element tensor
57 if (length != 1)
58 throw std::runtime_error("TMVA::SOFIE ONNX Parser ConstantOfShape has invalid tensor size " + std::to_string(length));
59 }
60 switch(output_type) {
61 // need to use raw_data() to get the tensor values
62 case ETensorType::INT64: {
63 std::vector<int64_t> values(length);
64 // case empty shape with length=1 represents scalars
65 auto raw_data_ptr = reinterpret_cast<int64_t *>(const_cast<char *>(t.raw_data().c_str()));
66 std::memcpy(values.data(), raw_data_ptr, length * sizeof(int64_t));
67 op.reset(new ROperator_Constant<int64_t>("int64_t", values, shape, input_name, output_name));
68 break;
69 }
70 case ETensorType::FLOAT: {
71 std::vector<float> values(length);
72 auto raw_data_ptr = reinterpret_cast<float *>(const_cast<char *>(t.raw_data().c_str()));
73 std::memcpy(values.data(), raw_data_ptr, length * sizeof(float));
74 op.reset(new ROperator_Constant<float>("float",values, shape, input_name, output_name));
75 break;
76 }
77 case ETensorType::BOOL: {
78 std::vector<bool> values(length);
79 auto raw_data_ptr = reinterpret_cast<bool *>(const_cast<char *>(t.raw_data().c_str()));
80 // cannot use values.data() for vector of bools
81 std::copy(raw_data_ptr, raw_data_ptr + length, values.begin());
82 //std::memcpy(values.data(), raw_data_ptr, length * sizeof(float));
83 op.reset(new ROperator_Constant<bool>("bool",values, shape, input_name, output_name));
84 break;
85 }
86 default:
87 throw std::runtime_error("Data type in Constant op attribute " + ConvertTypeToString(output_type) +
88 " is not supported!\n");
89 }
90 }
91 else {
92 // neither constant nor ConstantOfShape
93 if (!isConstantOfShape) {
94 // case of ConstantOfShape
95 if (attribute_name == "value_float") {
96 std::vector<float> values(1);
97 values[0] = nodeproto.attribute(0).f();
98 shape.push_back(1);
99 op.reset(new ROperator_Constant<float>("float",values, shape, input_name, output_name));
100 }
101 else if (attribute_name == "value_floats") {
102 auto values = std::vector<float>({nodeproto.attribute(0).floats().begin(), nodeproto.attribute(0).floats().end()});
103 shape.push_back(values.size());
104 op.reset(new ROperator_Constant<float>("float",values, shape, input_name, output_name));
105 }
106 else if (attribute_name == "value_int") {
107 std::vector<int64_t> values(1);
108 values[0] = nodeproto.attribute(0).i();
109 shape.push_back(1);
110 op.reset(new ROperator_Constant<int64_t>("int64_t",values, shape, input_name, output_name));
111 }
112 else if (attribute_name == "value_ints") {
113 auto values = std::vector<int64_t>({nodeproto.attribute(0).ints().begin(), nodeproto.attribute(0).ints().end()});
114 shape.push_back(values.size());
115 op.reset(new ROperator_Constant<int64_t>("int64_t",values, shape, input_name, output_name));
116 } else {
117 throw std::runtime_error("TMVA::SOFIE ONNX Parser Constant op: not yet supporting attribute " + attribute_name);
118 }
119 } else {
120 throw std::runtime_error("TMVA::SOFIE ONNX Parser ConstantOfShape op: parsed invalid attribute " + attribute_name);
121 }
122 }
123
124 // case when there is no attribute
125 } else {
126 // case of Constant of Shape : if attribute is not there use by default float type with zero values
127 if (isConstantOfShape) {
128 std::vector<float> values(1);
129 std::vector<size_t> constantShape(1,1);
130 op.reset(new ROperator_Constant<float>("float",values,constantShape, input_name, output_name));
131 } else {
132 throw std::runtime_error("TMVA::SOFIE ONNX Parser Constant has no attribute");
133 }
134 }
135
136 if (!parser.IsRegisteredTensorType(output_name)) {
137 parser.RegisterTensorType(output_name, output_type);
138 }
139
140 if (parser.Verbose())
141 std::cout << "\t ParseConstant: operator created\n";
142
143 return op;
144};
145
146} // namespace SOFIE
147} // namespace Experimental
148} // namespace TMVA
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
void RegisterTensorType(const std::string &, ETensorType)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
std::string ConvertTypeToString(ETensorType type)
ParserFuncSignature ParseConstant
create variable transformations