Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ParseConstant.cxx
Go to the documentation of this file.
3#include "onnx_proto3.pb.h"
4
5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
9// same function used to parse Constant and ConstantOfShape
10
11ParserFuncSignature ParseConstant = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
12 std::string input_name;
13 auto ninputs = nodeproto.input_size();
14 bool isConstantOfShape = false;
15 // case of ConstantOfShape (Constant has zero inputs)
16 if (ninputs > 0) {
17 input_name = nodeproto.input(0);
18 isConstantOfShape = true;
20 throw std::runtime_error("TMVA::SOFIE ONNX Parser ConstantOfShape op has input tensor" + input_name +
21 " but its type is not yet registered");
22 }
23 }
24
25 if (parser.Verbose()) {
26 std::cout << "\t.... ";
28 std::cout << "ConstantOfShape " << nodeproto.input(0) << " -> ";
29 else
30 std::cout << "Constant --> ";
31 std::cout << nodeproto.output(0) << std::endl;
32 }
33
34 std::unique_ptr<ROperator> op;
35 std::string attr_type;
36
37 std::string output_name = nodeproto.output(0);
39 std::vector<std::size_t> shape; // output shape (use in case of constant operator)
40 // it should be only one attribute (Constant or 1 or 0 COnstant of Shape)
41 if (nodeproto.attribute_size() > 1)
42 throw std::runtime_error("TMVA::SOFIE ONNX Parser Constant or ConstantOfShape and attribute size is larger than 1");
43 if (nodeproto.attribute_size() > 0) {
44 std::string attribute_name = nodeproto.attribute(0).name();
45 // tensor input
46 if (attribute_name == "value") {
47 const onnx::TensorProto & t = nodeproto.attribute(0).t();
48 output_type = static_cast<ETensorType>(t.data_type());
49
50 std::size_t length = 1;
51 for (int j = 0; j < t.dims_size(); j++) {
52 shape.push_back(t.dims(j));
53 length *= t.dims(j);
54 }
56 // value tensor should be one-element tensor
57 if (length != 1)
58 throw std::runtime_error("TMVA::SOFIE ONNX Parser ConstantOfShape has invalid tensor size " + std::to_string(length));
59 }
60 switch(output_type) {
61 // to get the tensor values one needs to use the given data types or the raw_data.
62 // it depends how the operator was created. We cannot get size of the raw_data
63 case ETensorType::INT32: {
64 std::vector<int32_t> values(length);
65 if (t.int32_data_size() == int(length)) {
66 for (size_t i = 0; i < length; i++)
67 values[i] = t.int32_data(i);
68 } else {
69 auto raw_data_ptr = reinterpret_cast<int32_t *>(const_cast<char *>(t.raw_data().c_str()));
70 std::memcpy(values.data(), raw_data_ptr, length * sizeof(int32_t));
71 }
72 op.reset(new ROperator_Constant<int32_t>("int32_t", values, shape, input_name, output_name));
73 break;
74 }
75 case ETensorType::INT64: {
76 std::vector<int64_t> values(length);
77 if (t.int64_data_size() == int(length)) {
78 for (size_t i = 0; i < length; i++)
79 values[i] = t.int64_data(i);
80 } else { // cannot get size of raw data : assume is ok
81 auto raw_data_ptr = reinterpret_cast<int64_t *>(const_cast<char *>(t.raw_data().c_str()));
82 std::memcpy(values.data(), raw_data_ptr, length * sizeof(int64_t));
83 }
84 op.reset(new ROperator_Constant<int64_t>("int64_t", values, shape, input_name, output_name));
85 break;
86 }
87 case ETensorType::FLOAT: {
88 std::vector<float> values(length);
89 if (t.float_data_size() == int(length)) {
90 for (size_t i = 0; i < length; i++)
91 values[i] = t.float_data(i);
92 } else {
93 auto raw_data_ptr = reinterpret_cast<float *>(const_cast<char *>(t.raw_data().c_str()));
94 std::memcpy(values.data(), raw_data_ptr, length * sizeof(float));
95 }
96 op.reset(new ROperator_Constant<float>("float",values, shape, input_name, output_name));
97 break;
98 }
100 std::vector<double> values(length);
101 if (t.double_data_size() == int(length)) {
102 for (size_t i = 0; i < length; i++)
103 values[i] = t.double_data(i);
104 } else {
105 auto raw_data_ptr = reinterpret_cast<double *>(const_cast<char *>(t.raw_data().c_str()));
106 std::memcpy(values.data(), raw_data_ptr, length * sizeof(double));
107 }
108 op.reset(new ROperator_Constant<double>("double",values, shape, input_name, output_name));
109 break;
110 }
111 case ETensorType::BOOL: {
112 //values are int32 in ONNX
113 std::vector<int8_t> values(length);
114 if (t.int32_data_size() == int(length)) {
115 for (size_t i = 0; i < length; i++) {
116 auto val = t.int32_data(i);
117 if (val < 0 || val > 1)
118 throw std::runtime_error("TMVA::SOFIE ONNX Parser Constant has invalid boolean value " + std::to_string(val));
119 values[i] = static_cast<int8_t>(val);
120 }
121 } else
122 throw std::runtime_error("TMVA::SOFIE ONNX Parser COnstant : invalid tensor data values");
123
124 op.reset(new ROperator_Constant<int8_t>("bool",values, shape, input_name, output_name));
125 break;
126 }
127 default:
128 throw std::runtime_error("Data type in Constant op attribute " + ConvertTypeToString(output_type) +
129 " is not supported!\n");
130 }
131 }
132 else {
133 // neither constant nor ConstantOfShape
134 if (!isConstantOfShape) {
135 // case of ConstantOfShape
136 if (attribute_name == "value_float") {
137 std::vector<float> values(1);
138 values[0] = nodeproto.attribute(0).f();
139 shape.push_back(1);
140 op.reset(new ROperator_Constant<float>("float",values, shape, input_name, output_name));
141 }
142 else if (attribute_name == "value_floats") {
143 auto values = std::vector<float>({nodeproto.attribute(0).floats().begin(), nodeproto.attribute(0).floats().end()});
144 shape.push_back(values.size());
145 op.reset(new ROperator_Constant<float>("float",values, shape, input_name, output_name));
146 }
147 else if (attribute_name == "value_int") {
148 std::vector<int64_t> values(1);
149 values[0] = nodeproto.attribute(0).i();
150 shape.push_back(1);
151 op.reset(new ROperator_Constant<int64_t>("int64_t",values, shape, input_name, output_name));
152 }
153 else if (attribute_name == "value_ints") {
154 auto values = std::vector<int64_t>({nodeproto.attribute(0).ints().begin(), nodeproto.attribute(0).ints().end()});
155 shape.push_back(values.size());
156 op.reset(new ROperator_Constant<int64_t>("int64_t",values, shape, input_name, output_name));
157 } else {
158 throw std::runtime_error("TMVA::SOFIE ONNX Parser Constant op: not yet supporting attribute " + attribute_name);
159 }
160 } else {
161 throw std::runtime_error("TMVA::SOFIE ONNX Parser ConstantOfShape op: parsed invalid attribute " + attribute_name);
162 }
163 }
164
165 // case when there is no attribute
166 } else {
167 // case of Constant of Shape : if attribute is not there use by default float type with zero values
168 if (isConstantOfShape) {
169 std::vector<float> values(1);
170 std::vector<size_t> constantShape(1,1);
172 } else {
173 throw std::runtime_error("TMVA::SOFIE ONNX Parser Constant has no attribute");
174 }
175 }
176
177 if (!parser.IsRegisteredTensorType(output_name)) {
179 }
180
181 if (parser.Verbose())
182 std::cout << "\t ParseConstant: operator created\n";
183
184 return op;
185};
186
187} // namespace SOFIE
188} // namespace Experimental
189} // namespace TMVA
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
const_iterator end() const
void RegisterTensorType(const std::string &, ETensorType)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
std::string ConvertTypeToString(ETensorType type)
ParserFuncSignature ParseConstant
create variable transformations