3#include "onnx_proto3.pb.h"
6namespace Experimental {
12 std::string input_name;
13 auto ninputs = nodeproto.input_size();
14 bool isConstantOfShape =
false;
17 input_name = nodeproto.input(0);
18 isConstantOfShape =
true;
20 throw std::runtime_error(
"TMVA::SOFIE ONNX Parser ConstantOfShape op has input tensor" + input_name +
21 " but its type is not yet registered");
26 std::cout <<
"\t.... ";
27 if (isConstantOfShape)
28 std::cout <<
"ConstantOfShape " << nodeproto.input(0) <<
" -> ";
30 std::cout <<
"Constant --> ";
31 std::cout << nodeproto.output(0) << std::endl;
34 std::unique_ptr<ROperator> op;
35 std::string attr_type;
37 std::string output_name = nodeproto.output(0);
39 std::vector<std::size_t> shape;
41 if (nodeproto.attribute_size() > 1)
42 throw std::runtime_error(
"TMVA::SOFIE ONNX Parser Constant or ConstantOfShape and attribute size is larger than 1");
43 if (nodeproto.attribute_size() > 0) {
44 std::string attribute_name = nodeproto.attribute(0).name();
46 if (attribute_name ==
"value") {
47 const onnx::TensorProto & t = nodeproto.attribute(0).t();
48 output_type =
static_cast<ETensorType>(t.data_type());
51 for (
int j = 0; j < t.dims_size(); j++) {
52 shape.push_back(t.dims(j));
55 if (isConstantOfShape) {
58 throw std::runtime_error(
"TMVA::SOFIE ONNX Parser ConstantOfShape has invalid tensor size " + std::to_string(
length));
63 std::vector<int64_t> values(
length);
65 auto raw_data_ptr =
reinterpret_cast<int64_t *
>(
const_cast<char *
>(t.raw_data().c_str()));
66 std::memcpy(values.data(), raw_data_ptr,
length *
sizeof(int64_t));
71 std::vector<float> values(
length);
72 auto raw_data_ptr =
reinterpret_cast<float *
>(
const_cast<char *
>(t.raw_data().c_str()));
73 std::memcpy(values.data(), raw_data_ptr,
length *
sizeof(
float));
78 std::vector<bool> values(
length);
79 auto raw_data_ptr =
reinterpret_cast<bool *
>(
const_cast<char *
>(t.raw_data().c_str()));
81 std::copy(raw_data_ptr, raw_data_ptr +
length, values.begin());
87 throw std::runtime_error(
"Data type in Constant op attribute " +
ConvertTypeToString(output_type) +
88 " is not supported!\n");
93 if (!isConstantOfShape) {
95 if (attribute_name ==
"value_float") {
96 std::vector<float> values(1);
97 values[0] = nodeproto.attribute(0).f();
101 else if (attribute_name ==
"value_floats") {
102 auto values = std::vector<float>({nodeproto.attribute(0).floats().begin(), nodeproto.attribute(0).floats().end()});
103 shape.push_back(values.size());
106 else if (attribute_name ==
"value_int") {
107 std::vector<int64_t> values(1);
108 values[0] = nodeproto.attribute(0).i();
112 else if (attribute_name ==
"value_ints") {
113 auto values = std::vector<int64_t>({nodeproto.attribute(0).ints().begin(), nodeproto.attribute(0).ints().end()});
114 shape.push_back(values.size());
117 throw std::runtime_error(
"TMVA::SOFIE ONNX Parser Constant op: not yet supporting attribute " + attribute_name);
120 throw std::runtime_error(
"TMVA::SOFIE ONNX Parser ConstantOfShape op: parsed invalid attribute " + attribute_name);
127 if (isConstantOfShape) {
128 std::vector<float> values(1);
129 std::vector<size_t> constantShape(1,1);
132 throw std::runtime_error(
"TMVA::SOFIE ONNX Parser Constant has no attribute");
141 std::cout <<
"\t ParseConstant: operator created\n";
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
bool IsRegisteredTensorType(const std::string &)
void RegisterTensorType(const std::string &, ETensorType)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
std::string ConvertTypeToString(ETensorType type)
ParserFuncSignature ParseConstant
create variable transformations