Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_Reshape.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_RESHAPE
2#define TMVA_SOFIE_ROPERATOR_RESHAPE
3
#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
10
11namespace TMVA{
12namespace Experimental{
13namespace SOFIE{
14
16
17
18class ROperator_Reshape final : public ROperator
19{
20
21private:
22
23 bool fVerbose = false;
24 ReshapeOpMode fOpMode = Reshape; // type of Reshape operator
25
26 int fAllowZero = 0; // (for Reshape) zero in tensor shape makes output shape equal to input tensor shape
27 int fAxis = 1; // (for Flatten)
28
29 std::string fNData; // input data tensor name
30 std::string fNShape; // reshape tensor name
31 std::string fNOutput; // output tensor name
32 std::vector<size_t> fShapeInput; // input shape data
33 std::vector<size_t> fShapeOutput; // output shape data
34 std::vector<int64_t> fAttrAxes; // axes attributes (provided for all version of Squeeze/Unsqueeze)
35
36public:
37
38 std::string Name() const {
39 if (fOpMode == Reshape) return "Reshape";
40 if (fOpMode == Flatten) return "Flatten";
41 if (fOpMode == Squeeze) return "Squeeze";
42 if (fOpMode == Unsqueeze) return "Unsqueeze";
43 return "";
44 }
45
47 ROperator_Reshape(ReshapeOpMode opMode, int attr_value, std::string nameData, std::string nameShape, std::string nameOutput)
48 : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNShape(UTILITY::Clean_name(nameShape)),
49 fNOutput(UTILITY::Clean_name(nameOutput))
50 {
51 if (opMode == Reshape) fAllowZero = attr_value;
52 if (opMode == Flatten) fAxis = attr_value;
53 }
54
55 // for squeeze/unsqueezed operators following old ONNX version (< 10)
56 // In this cases axes are passed as attribute values
57 ROperator_Reshape(ReshapeOpMode opMode, std::vector<int64_t> attrAxes, std::string nameData, std::string nameOutput)
58 : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNOutput(UTILITY::Clean_name(nameOutput)),
59 fAttrAxes(attrAxes)
60 {
61 assert(fOpMode == Squeeze || fOpMode == Unsqueeze);
62 }
63
64 // output type is same as input
65 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
66 auto ret = std::vector<ETensorType>(1, input[0]);
67 return ret;
68 }
69
70 // output shape
71 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
72 std::vector<std::vector<size_t>> ret;
73 auto & input_shape = input[0];
74
75 if (fOpMode == Reshape) {
76 if (input.size() != 2) throw std::runtime_error("TMVA SOFIE Reshape Op needs 2 input tensors");
77 auto output_shape = input[1]; // the provided shape
78 size_t input_length = ConvertShapeToLength(input_shape);
79 size_t output_length = ConvertShapeToLength(output_shape);
80 // (input_length == output_length) is the easy case : (2,3,4) -> (2,12)
81 if (input_length != output_length) {
82 if ((output_length == 0 && fAllowZero == 0) || output_length > INT64_MAX) {
83 // in this case value 0 or -1 in shape are automatically corrected
84 bool replacementDone = false;
85 for (size_t i = 0; i < output_shape.size(); i++) {
86 if (output_shape[i] == 0 || output_shape[i] == static_cast<size_t>(-1)) {
87 if (replacementDone) {
88 throw std::runtime_error("TMVA Reshape Op : output shape has multiple negative or zero values");
89 }
90 auto tmp = output_shape;
91 tmp.erase(tmp.begin() + i);
92 auto tmp_length = ConvertShapeToLength(tmp);
93 output_shape[i] = input_length / tmp_length;
94 replacementDone = true;
95 }
96 }
97 if (fVerbose)
98 std::cout << "Reshape: correct output shape from " << ConvertShapeToString(input[1])
99 << " to " << ConvertShapeToString(output_shape) << std::endl;
100 }
101 if (ConvertShapeToLength(output_shape) != input_length) {
102 throw std::runtime_error("TMVA Reshape Op : Invalid shapes : " + ConvertShapeToString(input_shape) +
103 ConvertShapeToString(output_shape));
104 }
105 }
106 ret.push_back(output_shape);
107
108 } else if (fOpMode == Flatten) {
109 // flattenig case
110 size_t inputSize = ConvertShapeToLength(input_shape);
111 size_t b = input[0][0];
112 std::vector<size_t> newShape = {b, inputSize / b};
113 ret.push_back(newShape);
114
115 } else if (fOpMode == Squeeze) {
116 // squeeze
117 // assume no axis is provided - remove all axes with value equal to 1
118 auto output_shape = input[0];
119 if (input.size() == 1) {
120 size_t i = 0;
121 while (i < output_shape.size()) {
122 if (output_shape[i] == 1 ) {
123 output_shape.erase(output_shape.begin() + i);
124 }
125 else {
126 i++;
127 }
128 }
129 } else if (input.size() == 2) {
130 auto & axes = input[1];
131 for (size_t i = 0; i < axes.size(); i++){
132 if (output_shape[axes[i]] != 1)
133 throw std::runtime_error("TMVA Squeeze Op : Invalid axes : " + ConvertShapeToString(axes) +
134 ConvertShapeToString(output_shape));
135 output_shape.erase(output_shape.begin() + axes[i]);
136 }
137 }
138 ret.push_back(output_shape);
139 }
140
141 else if (fOpMode == Unsqueeze) {
142 // unsqueeze
143 assert(input.size() == 2);
144 auto output_shape = input[0];
145 auto &axes = input[1];
146 // output rank
147 int64_t r = input[0].size() + axes.size();
148 for (auto & a : axes) {
149 int64_t i = static_cast<int64_t>(a);
150 if ( i < -r || i > r - 1 )
151 throw std::runtime_error("TMVA Unsqueeze Op - axes input is not in correct range");
152 if (i >= 0)
153 output_shape.insert(output_shape.begin() + i, 1);
154 else
155 //negative axes
156 output_shape.insert(output_shape.end() + i + 1, 1);
157 }
158 ret.push_back(output_shape);
159 }
160 return ret;
161 }
162
163 void Initialize(RModel &model)
164 {
165 fVerbose = model.Verbose();
166 if (model.CheckIfTensorAlreadyExist(fNData) == false) {
167 // input must be a graph input, or already initialized intermediate tensor
168 throw std::runtime_error("TMVA Reshape Op Input Tensor " + fNData + " is not found in model");
169 }
171 // check if optional shape tensor exist
172 if (!fNShape.empty()) {
174 auto dptr = model.GetInitializedTensorData(fNShape);
175 auto input_shape = static_cast<int64_t *>(dptr.get());
176 auto vec = model.GetTensorShape(fNShape);
177 assert(vec.size() == 1);
178 size_t n = vec[0]; // size of shape input tensor
179
180 std::vector<size_t> descShape(n);
181 std::copy(input_shape, input_shape + n, descShape.begin());
182 fShapeOutput = ShapeInference({fShapeInput, descShape})[0];
183 // set flag to not write tensor in weight file. Its data will be hard-coded in way model is constructed
185 } else {
186 throw std::runtime_error("TMVA Reshape Op Shape Tensor " + fNShape + " is not found in model");
187 }
188 } else if (!fAttrAxes.empty()) {
189 // case fNShape is empty and axes are provided as attributes
190 std::vector<size_t> descShape(fAttrAxes.size());
191 std::copy(fAttrAxes.begin(), fAttrAxes.end(), descShape.begin());
192 fShapeOutput = ShapeInference({fShapeInput, descShape})[0];
193 } else if (fOpMode == Flatten || fOpMode == Squeeze) {
195 } else {
196 throw std::runtime_error("TMVA Reshape Op : Invalid Input/Attribute data");
197 }
198 // check if output is constant or not
200 fIsOutputConstant = true;
201 auto inputData = static_cast<int64_t*>(model.GetInitializedTensorData(fNData).get());
203 throw std::runtime_error("TMVA Reshape Op : Invalid Input/Output lengths");
204 model.AddConstantTensor<int64_t>(fNOutput, fShapeOutput, inputData);
205 if (model.Verbose()) {
206 std::cout << Name() << " : " << fNData << " " << ConvertShapeToString(fShapeInput) << " --> " << fNOutput << " (constant) " << ConvertShapeToString(fShapeOutput) << " : " <<
208 }
209 } else {
210 // non-constant case
212 if (model.Verbose())
213 std::cout << Name() << " : " << fNData << " " << ConvertShapeToString(fShapeInput) << " --> "<< fNOutput << " " << ConvertShapeToString(fShapeOutput) << std::endl;
214 }
215 }
216
217 std::string Generate(std::string OpName)
218 {
219 if (fIsOutputConstant) return ""; //no op for constant tensors
220
221 OpName = "op_" + OpName;
222
223 // output of reshape is same as input
226 throw std::runtime_error("TMVA SOFIE Reshape Op : wrong output shape - is " +
227 ConvertShapeToString(fShapeOutput) + " and input is " +
229 }
230 std::stringstream out;
231 std::string opName = "Reshape";
232 if (fOpMode == Flatten)
233 opName = "Flatten";
234 else if (fOpMode == Squeeze)
235 opName = "Squeeze";
236 else if (fOpMode == Unsqueeze)
237 opName = "Unsquueze";
238
239 out << SP << "///--------" << opName << " operator\n" << std::endl;
240 out << SP << "std::copy( tensor_" << fNData << ", tensor_" << fNData << " + " << length << ", " << "tensor_" << fNOutput
241 << ");\n";
242 return out.str();
243 }
244};
245
246}//SOFIE
247}//Experimental
248}//TMVA
249
250
251#endif //TMVA_SOFIE_ROPERATOR_RESHAPE
#define b(i)
Definition RSha256.hxx:100
#define a(i)
Definition RSha256.hxx:99
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
#define INT64_MAX
Definition civetweb.c:511
const ETensorType & GetTensorType(std::string name)
Definition RModel.cxx:94
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:203
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:122
void AddConstantTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Definition RModel.cxx:178
bool IsInitializedTensor(const std::string &name) const
Definition RModel.cxx:188
const std::vector< size_t > & GetTensorShape(std::string name)
Definition RModel.cxx:56
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
Definition RModel.cxx:264
void SetNotWritableInitializedTensor(const std::string &tensor_name)
Definition RModel.cxx:273
ROperator_Reshape(ReshapeOpMode opMode, std::vector< int64_t > attrAxes, std::string nameData, std::string nameOutput)
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input)
ROperator_Reshape(ReshapeOpMode opMode, int attr_value, std::string nameData, std::string nameShape, std::string nameOutput)
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input)
bool fIsOutputConstant
flag to identify if operator has a constant output (no need to generate code)
Definition ROperator.hxx:43
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:41
const Int_t n
Definition legend1.C:16
std::string ConvertValuesToString(size_t n, const T *data)
std::string ConvertShapeToString(std::vector< size_t > shape)
std::size_t ConvertShapeToLength(std::vector< size_t > shape)
create variable transformations