Logo ROOT  
Reference Guide
ROperator_Reshape.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_RESHAPE
2#define TMVA_SOFIE_ROPERATOR_RESHAPE
3
5#include "TMVA/ROperator.hxx"
6#include "TMVA/RModel.hxx"
7
8#include <cassert>
9#include <sstream>
10
11namespace TMVA{
12namespace Experimental{
13namespace SOFIE{
14
16
17template <typename T>
18class ROperator_Reshape final : public ROperator
19{
20
21private:
22
// Operator configuration, fixed at construction time.
23 ReshapeOpMode fOpMode = Reshape; // which flavour this instance implements: Reshape, Flatten, Squeeze or Unsqueeze
24
25 int fAllowZero = 0; // (Reshape only) taken from the constructor's attr_value; when 0, a zero entry in the requested shape is a placeholder to be resolved by ShapeInference
26 int fAxis = 1; // (Flatten only) axis attribute, taken from the constructor's attr_value
27
// Tensor names (passed through UTILITY::Clean_name by the constructors) and shapes.
28 std::string fNData; // name of the input data tensor
29 std::string fNShape; // name of the tensor holding the target shape / axes; left empty by the attribute-axes constructor
30 std::string fNOutput; // name of the output tensor
31 std::vector<size_t> fShapeInput; // shape of the input data tensor
32 std::vector<size_t> fShapeOutput; // shape of the output tensor, as computed by ShapeInference
33 std::vector<int64_t> fAttrAxes; // axes supplied as attribute values (Squeeze/Unsqueeze constructor); empty otherwise
34
35public:
36
38 ROperator_Reshape(ReshapeOpMode opMode, int attr_value, std::string nameData, std::string nameShape, std::string nameOutput)
39 : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNShape(UTILITY::Clean_name(nameShape)),
40 fNOutput(UTILITY::Clean_name(nameOutput))
41 {
42 if (opMode == Reshape) fAllowZero = attr_value;
43 if (opMode == Flatten) fAxis = attr_value;
44 }
45
46 // for squeeze/unsqueezed operators following old ONNX version (< 10)
47 // In this cases axes are passed as attribute values
48 ROperator_Reshape(ReshapeOpMode opMode, std::vector<int64_t> attrAxes, std::string nameData, std::string nameOutput)
49 : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNOutput(UTILITY::Clean_name(nameOutput)),
50 fAttrAxes(attrAxes)
51 {
52 assert(fOpMode == Squeeze || fOpMode == Unsqueeze);
53 }
54
55 // output type is same as input
56 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
57 auto ret = std::vector<ETensorType>(1, input[0]);
58 return ret;
59 }
60
61 // output shape
62 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
63 std::vector<std::vector<size_t>> ret;
64 auto & input_shape = input[0];
65
66 if (fOpMode == Reshape) {
67 if (input.size() != 2) throw std::runtime_error("TMVA SOFIE Reshape Op needs 2 input tensors");
68 auto output_shape = input[1]; // the provided shape
69 size_t input_length = ConvertShapeToLength(input_shape);
70 size_t output_length = ConvertShapeToLength(output_shape);
71 // (input_length == output_length) is the easy case : (2,3,4) -> (2,12)
72 if (input_length != output_length) {
73 if (output_shape.size() > 1 && ((output_length == 0 && fAllowZero == 0) || output_length > INT64_MAX)) {
74 // in this case value 0 in shape are automatically corrected
75 for (size_t i = 0; i < output_shape.size(); i++) {
76 if (output_shape[i] == 0 || output_shape[i] == static_cast<size_t>(-1)) {
77 auto tmp = output_shape;
78 tmp.erase(tmp.begin() + i);
79 auto tmp_length = ConvertShapeToLength(tmp);
80 output_shape[i] = input_length / tmp_length;
81 break;
82 }
83 }
84 }
85 if (ConvertShapeToLength(output_shape) != input_length) {
86 throw std::runtime_error("TMVA Reshape Op : Invalid shapes : " + ConvertShapeToString(input_shape) +
87 ConvertShapeToString(output_shape));
88 }
89 }
90 ret.push_back(output_shape);
91
92 } else if (fOpMode == Flatten) {
93 // flattenig case
94 size_t inputSize = ConvertShapeToLength(input_shape);
95 size_t b = input[0][0];
96 std::vector<size_t> newShape = {b, inputSize / b};
97 ret.push_back(newShape);
98
99 } else if (fOpMode == Squeeze) {
100 // squeeze
101 // assume no axis is provided - remove all axes with value equal to 1
102 auto output_shape = input[0];
103 if (input.size() == 1) {
104 for (size_t i = 0; i < output_shape.size(); i++) {
105 if (output_shape[i] == 1 ) {
106 output_shape.erase(output_shape.begin() + i);
107 }
108 }
109 } else if (input.size() == 2) {
110 auto & axes = input[1];
111 for (size_t i = 0; i < axes.size(); i++){
112 if (output_shape[axes[i]] != 1)
113 throw std::runtime_error("TMVA Squeeze Op : Invalid axes : " + ConvertShapeToString(axes) +
114 ConvertShapeToString(output_shape));
115 output_shape.erase(output_shape.begin() + axes[i]);
116 }
117 }
118 ret.push_back(output_shape);
119 }
120
121 else if (fOpMode == Unsqueeze) {
122 // unsqueeze
123 assert(input.size() == 2);
124 auto output_shape = input[0];
125 auto &axes = input[1];
126 if (axes[0] > 0) { // positive axis start from beginning
127 for (auto & i : axes)
128 output_shape.insert(output_shape.begin() + i, 1);
129 } else {
130 //negative axes
131 for (auto &i : axes) {
132 assert(i < 0);
133 output_shape.insert(output_shape.begin() + (output_shape.size() + i - 1), 1);
134 }
135 }
136 ret.push_back(output_shape);
137 }
138 return ret;
139 }
140
// Resolve this operator against the model: check that the data tensor exists and
// compute fShapeOutput through ShapeInference.
//
// Three sources for the second ShapeInference input are visible below:
//   - fNShape is set     : the shape/axes values are read from the model's initialized tensor data
//   - fAttrAxes is set   : the axes given as attributes are used
//   - Flatten / Squeeze  : branch body is not visible in this listing (see note below)
// Any other combination throws std::runtime_error.
//
// NOTE(review): this listing is incomplete. Its embedded line numbers skip 148, 152, 171 and 175,
// and as shown the braces do not balance (the `} else {` at 162 has no visible matching `if`).
// Confirm the missing statements against the upstream header before editing this method.
141 void Initialize(RModel &model)
142 {
143
144 if (model.CheckIfTensorAlreadyExist(fNData) == false) {
145 // input must be a graph input, or already initialized intermediate tensor
146 throw std::runtime_error("TMVA Reshape Op Input Tensor " + fNData + " is not found in model");
147 }
// NOTE(review): line 148 is missing here; fShapeInput is read below but never assigned in the
// visible lines, so it is presumably filled at this point - TODO confirm.
149
150 // check if optional shape tensor exist
151 if (!fNShape.empty()) {
// NOTE(review): line 152 is missing here; presumably a check that pairs with the `else` at 162.
// The tensor data is reinterpreted as int64_t values; its shape is asserted to be 1-D and its
// single dimension n is the number of entries copied (converted to size_t) into descShape.
153 auto dptr = model.GetInitializedTensorData(fNShape);
154 auto input_shape = static_cast<int64_t *>(dptr.get());
155 auto vec = model.GetTensorShape(fNShape);
156 assert(vec.size() == 1);
157 size_t n = vec[0]; // size of shape input tensor
158
159 std::vector<size_t> descShape(n);
160 std::copy(input_shape, input_shape + n, descShape.begin());
161 fShapeOutput = ShapeInference({fShapeInput, descShape})[0];
162 } else {
163 throw std::runtime_error("TMVA Reshape Op Shape Tensor " + fNShape + " is not found in model");
164 }
165 } else if (!fAttrAxes.empty()) {
166 // case fNShape is empty and axes are provided as attributes
// the int64_t axes are converted to size_t element by element by std::copy
167 std::vector<size_t> descShape(fAttrAxes.size());
168 std::copy(fAttrAxes.begin(), fAttrAxes.end(), descShape.begin());
169 fShapeOutput = ShapeInference({fShapeInput, descShape})[0];
170 } else if (fOpMode == Flatten || fOpMode == Squeeze) {
// NOTE(review): line 171 is missing here, so this branch appears empty in the listing.
172 } else {
173 throw std::runtime_error("TMVA Reshape Op : Invalid Input/Attribute data");
174 }
// NOTE(review): line 175 is missing here.
176 }
177
// Produce the generated C++ code for this operator as a string: a banner comment naming the
// operator flavour, followed by a std::copy of fTensor_<fNData> into fTensor_<fNOutput>.
//
// Throws std::runtime_error if fShapeInput or fShapeOutput is empty, i.e. the operator was not
// initialized, and (per the partially visible check below) on an output-shape mismatch.
//
// NOTE(review): this listing is incomplete. Its embedded line numbers skip 186, 187 and 190, so the
// declaration of `length`, the condition guarding the second throw and the end of that throw
// expression are not visible. Confirm against the upstream header before editing this method.
178 std::string Generate(std::string OpName)
179 {
// OpName is prefixed here but not used again in the visible lines
180 OpName = "op_" + OpName;
181 if (fShapeInput.empty() || fShapeOutput.empty()) {
182 throw std::runtime_error("TMVA SOFIE Reshape Op called to Generate without being initialized first");
183 }
184
185 // output of reshape is same as input
// NOTE(review): lines 186-187 are missing here.
188 throw std::runtime_error("TMVA SOFIE Reshape Op : wrong output shape - is " +
189 ConvertShapeToString(fShapeOutput) + " and input is " +
// NOTE(review): line 190 is missing here (remainder of the error message expression).
191 }
// NOTE(review): `length` is not read again in the visible lines after this loop.
192 for (auto &i : fShapeOutput) {
193 length *= i;
194 }
195 std::stringstream out;
// pick the operator name used in the generated banner comment
196 std::string opName = "Reshape";
197 if (fOpMode == Flatten)
198 opName = "Flatten";
199 else if (fOpMode == Squeeze)
200 opName = "Squeeze";
201 else if (fOpMode == Unsqueeze)
// NOTE(review): "Unsquueze" is misspelled; it only appears in the emitted comment line.
202 opName = "Unsquueze";
203
// the data is copied unchanged from the input tensor to the output tensor
204 out << SP << "///--------" << opName << " operator\n" << std::endl;
205 out << SP << "std::copy( fTensor_" << fNData << ".begin(), fTensor_" << fNData << ".end(), fTensor_" << fNOutput
206 << ".begin() );\n";
207 return out.str();
208 }
209};
210
211}//SOFIE
212}//Experimental
213}//TMVA
214
215
216#endif //TMVA_SOFIE_ROPERATOR_RESHAPE
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t b
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
#define INT64_MAX
Definition: civetweb.c:511
const ETensorType & GetTensorType(std::string name)
Definition: RModel.cxx:79
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape)
Definition: RModel.cxx:149
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition: RModel.cxx:100
const std::vector< size_t > & GetTensorShape(std::string name)
Definition: RModel.cxx:58
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
Definition: RModel.cxx:173
ROperator_Reshape(ReshapeOpMode opMode, std::vector< int64_t > attrAxes, std::string nameData, std::string nameOutput)
ROperator_Reshape(ReshapeOpMode opMode, int attr_value, std::string nameData, std::string nameShape, std::string nameOutput)
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input)
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input)
const std::string SP
space used to correctly indent the generated C++ code
Definition: ROperator.hxx:39
const Int_t n
Definition: legend1.C:16
std::string Clean_name(std::string input_tensor_name)
std::string ConvertShapeToString(std::vector< size_t > shape)
std::size_t ConvertShapeToLength(std::vector< size_t > shape)
create variable transformations
Definition: civetweb.c:1856