Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_Reshape.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_RESHAPE
2#define TMVA_SOFIE_ROPERATOR_RESHAPE
3
5#include "TMVA/ROperator.hxx"
6#include "TMVA/RModel.hxx"
7
8#include <cassert>
9#include <sstream>
10
11namespace TMVA{
12namespace Experimental{
13namespace SOFIE{
14
16
17template <typename T>
18class ROperator_Reshape final : public ROperator
19{
20
21private:
22
23 ReshapeOpMode fOpMode = Reshape; // type of Reshape operator
24
25 int fAllowZero = 0; // (for Reshape) zero in tensor shape makes output shape equal to input tensor shape
26 int fAxis = 1; // (for Flatten)
27
28 std::string fNData; // input data tensor name
29 std::string fNShape; // reshape tensor name
30 std::string fNOutput; // output tensor name
31 std::vector<size_t> fShapeInput; // input shape data
32 std::vector<size_t> fShapeOutput; // output shape data
33 std::vector<int64_t> fAttrAxes; // axes attributes (provided for all version of Squeeze/Unsqueeze)
34
35public:
36
38 ROperator_Reshape(ReshapeOpMode opMode, int attr_value, std::string nameData, std::string nameShape, std::string nameOutput)
39 : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNShape(UTILITY::Clean_name(nameShape)),
40 fNOutput(UTILITY::Clean_name(nameOutput))
41 {
42 if (opMode == Reshape) fAllowZero = attr_value;
43 if (opMode == Flatten) fAxis = attr_value;
44 }
45
46 // for squeeze/unsqueezed operators following old ONNX version (< 10)
47 // In this cases axes are passed as attribute values
48 ROperator_Reshape(ReshapeOpMode opMode, std::vector<int64_t> attrAxes, std::string nameData, std::string nameOutput)
49 : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNOutput(UTILITY::Clean_name(nameOutput)),
50 fAttrAxes(attrAxes)
51 {
52 assert(fOpMode == Squeeze || fOpMode == Unsqueeze);
53 }
54
55 // output type is same as input
56 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
57 auto ret = std::vector<ETensorType>(1, input[0]);
58 return ret;
59 }
60
61 // output shape
62 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
63 std::vector<std::vector<size_t>> ret;
64 auto & input_shape = input[0];
65
66 if (fOpMode == Reshape) {
67 if (input.size() != 2) throw std::runtime_error("TMVA SOFIE Reshape Op needs 2 input tensors");
68 auto output_shape = input[1]; // the provided shape
69 size_t input_length = ConvertShapeToLength(input_shape);
70 size_t output_length = ConvertShapeToLength(output_shape);
71 // (input_length == output_length) is the easy case : (2,3,4) -> (2,12)
72 if (input_length != output_length) {
73 if (output_shape.size() > 1 && ((output_length == 0 && fAllowZero == 0) || output_length > INT64_MAX)) {
74 // in this case value 0 in shape are automatically corrected
75 for (size_t i = 0; i < output_shape.size(); i++) {
76 if (output_shape[i] == 0 || output_shape[i] == static_cast<size_t>(-1)) {
77 auto tmp = output_shape;
78 tmp.erase(tmp.begin() + i);
79 auto tmp_length = ConvertShapeToLength(tmp);
80 output_shape[i] = input_length / tmp_length;
81 break;
82 }
83 }
84 }
85 if (ConvertShapeToLength(output_shape) != input_length) {
86 throw std::runtime_error("TMVA Reshape Op : Invalid shapes : " + ConvertShapeToString(input_shape) +
87 ConvertShapeToString(output_shape));
88 }
89 }
90 ret.push_back(output_shape);
91
92 } else if (fOpMode == Flatten) {
93 // flattenig case
94 size_t inputSize = ConvertShapeToLength(input_shape);
95 size_t b = input[0][0];
96 std::vector<size_t> newShape = {b, inputSize / b};
97 ret.push_back(newShape);
98
99 } else if (fOpMode == Squeeze) {
100 // squeeze
101 // assume no axis is provided - remove all axes with value equal to 1
102 auto output_shape = input[0];
103 if (input.size() == 1) {
104 size_t i = 0;
105 while (i < output_shape.size()) {
106 if (output_shape[i] == 1 ) {
107 output_shape.erase(output_shape.begin() + i);
108 }
109 else {
110 i++;
111 }
112 }
113 } else if (input.size() == 2) {
114 auto & axes = input[1];
115 for (size_t i = 0; i < axes.size(); i++){
116 if (output_shape[axes[i]] != 1)
117 throw std::runtime_error("TMVA Squeeze Op : Invalid axes : " + ConvertShapeToString(axes) +
118 ConvertShapeToString(output_shape));
119 output_shape.erase(output_shape.begin() + axes[i]);
120 }
121 }
122 ret.push_back(output_shape);
123 }
124
125 else if (fOpMode == Unsqueeze) {
126 // unsqueeze
127 assert(input.size() == 2);
128 auto output_shape = input[0];
129 auto &axes = input[1];
130 if (axes[0] > 0) { // positive axis start from beginning
131 for (auto & i : axes)
132 output_shape.insert(output_shape.begin() + i, 1);
133 } else {
134 //negative axes
135 for (auto &i : axes) {
136 assert(i < 0);
137 output_shape.insert(output_shape.begin() + (output_shape.size() + i - 1), 1);
138 }
139 }
140 ret.push_back(output_shape);
141 }
142 return ret;
143 }
144
145 void Initialize(RModel &model)
146 {
147
148 if (model.CheckIfTensorAlreadyExist(fNData) == false) {
149 // input must be a graph input, or already initialized intermediate tensor
150 throw std::runtime_error("TMVA Reshape Op Input Tensor " + fNData + " is not found in model");
151 }
153 // check if optional shape tensor exist
154 if (!fNShape.empty()) {
156 auto dptr = model.GetInitializedTensorData(fNShape);
157 auto input_shape = static_cast<int64_t *>(dptr.get());
158 auto vec = model.GetTensorShape(fNShape);
159 assert(vec.size() == 1);
160 size_t n = vec[0]; // size of shape input tensor
161
162 std::vector<size_t> descShape(n);
163 std::copy(input_shape, input_shape + n, descShape.begin());
164 fShapeOutput = ShapeInference({fShapeInput, descShape})[0];
165 } else {
166 throw std::runtime_error("TMVA Reshape Op Shape Tensor " + fNShape + " is not found in model");
167 }
168 } else if (!fAttrAxes.empty()) {
169 // case fNShape is empty and axes are provided as attributes
170 std::vector<size_t> descShape(fAttrAxes.size());
171 std::copy(fAttrAxes.begin(), fAttrAxes.end(), descShape.begin());
172 fShapeOutput = ShapeInference({fShapeInput, descShape})[0];
173 } else if (fOpMode == Flatten || fOpMode == Squeeze) {
175 } else {
176 throw std::runtime_error("TMVA Reshape Op : Invalid Input/Attribute data");
177 }
179 }
180
181 std::string Generate(std::string OpName)
182 {
183 OpName = "op_" + OpName;
184 if (fShapeInput.empty() || fShapeOutput.empty()) {
185 throw std::runtime_error("TMVA SOFIE Reshape Op called to Generate without being initialized first");
186 }
187
188 // output of reshape is same as input
191 throw std::runtime_error("TMVA SOFIE Reshape Op : wrong output shape - is " +
192 ConvertShapeToString(fShapeOutput) + " and input is " +
194 }
195 std::stringstream out;
196 std::string opName = "Reshape";
197 if (fOpMode == Flatten)
198 opName = "Flatten";
199 else if (fOpMode == Squeeze)
200 opName = "Squeeze";
201 else if (fOpMode == Unsqueeze)
202 opName = "Unsquueze";
203
204 out << SP << "///--------" << opName << " operator\n" << std::endl;
205 out << SP << "std::copy( tensor_" << fNData << ", tensor_" << fNData << " + " << length << ", " << "tensor_" << fNOutput
206 << ");\n";
207 return out.str();
208 }
209};
210
211}//SOFIE
212}//Experimental
213}//TMVA
214
215
216#endif //TMVA_SOFIE_ROPERATOR_RESHAPE
#define b(i)
Definition RSha256.hxx:100
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
#define INT64_MAX
Definition civetweb.c:511
const ETensorType & GetTensorType(std::string name)
Definition RModel.cxx:91
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:187
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:116
const std::vector< size_t > & GetTensorShape(std::string name)
Definition RModel.cxx:56
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
Definition RModel.cxx:248
ROperator_Reshape(ReshapeOpMode opMode, std::vector< int64_t > attrAxes, std::string nameData, std::string nameOutput)
ROperator_Reshape(ReshapeOpMode opMode, int attr_value, std::string nameData, std::string nameShape, std::string nameOutput)
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input)
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input)
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:41
const Int_t n
Definition legend1.C:16
std::string ConvertShapeToString(std::vector< size_t > shape)
std::size_t ConvertShapeToLength(std::vector< size_t > shape)
create variable transformations