Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_Reshape.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_RESHAPE
2#define TMVA_SOFIE_ROPERATOR_RESHAPE
3
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
9
10namespace TMVA{
11namespace Experimental{
12namespace SOFIE{
13
15
16template <typename T>
17class ROperator_Reshape final : public ROperator
18{
19
20private:
21
22 ReshapeOpMode fOpMode = Reshape; // type of Reshape operator
23
24 int fAllowZero = 0; // (for Reshape) zero in tensor shape makes output shape equal to input tensor shape
25 int fAxis = 1; // (for Flatten)
26
27 std::string fNData; // input data tensor name
28 std::string fNShape; // reshape tensor name
29 std::string fNOutput; // output tensor name
30 std::vector<size_t> fShapeInput; // input shape data
31 std::vector<size_t> fShapeOutput; // output shape data
32 std::vector<int64_t> fAttrAxes; // axes attributes (provided for all version of Squeeze/Unsqueeze)
33
34public:
35
37 ROperator_Reshape(ReshapeOpMode opMode, int attr_value, std::string nameData, std::string nameShape, std::string nameOutput)
38 : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNShape(UTILITY::Clean_name(nameShape)),
39 fNOutput(UTILITY::Clean_name(nameOutput))
40 {
41 if (opMode == Reshape) fAllowZero = attr_value;
42 if (opMode == Flatten) fAxis = attr_value;
43 }
44
45 // for squeeze/unsqueezed operators following old ONNX version (< 10)
46 // IN this cases axes are passed as attribute values
47 ROperator_Reshape(ReshapeOpMode opMode, std::vector<int64_t> attrAxes, std::string nameData, std::string nameOutput)
48 : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNOutput(UTILITY::Clean_name(nameOutput)),
49 fAttrAxes(attrAxes)
50 {
51 assert(fOpMode == Squeeze || fOpMode == Unsqueeze);
52 }
53
54 // output type is same as input
55 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
56 auto ret = std::vector<ETensorType>(1, input[0]);
57 return ret;
58 }
59
60 // output shape
61 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
62 std::vector<std::vector<size_t>> ret;
63 auto & input_shape = input[0];
64
65 if (fOpMode == Reshape) {
66 if (input.size() != 2) throw std::runtime_error("TMVA SOFIE Reshape Op needs 2 input tensors");
67 auto output_shape = input[1]; // the provided shape
68 size_t input_length = ConvertShapeToLength(input_shape);
69 size_t output_length = ConvertShapeToLength(output_shape);
70 // input_length == output_length) is the easy case : (2,3,4) -> (2,12)
71 if (input_length != output_length) {
72 if (output_shape.size() > 1 && ((output_length == 0 && fAllowZero == 0) || output_length > INT64_MAX)) {
73 // in this case value 0 in shape are automatically corrected
74 for (size_t i = 0; i < output_shape.size(); i++) {
75 if (output_shape[i] == 0 || output_shape[i] == static_cast<size_t>(-1)) {
76 auto tmp = output_shape;
77 tmp.erase(tmp.begin() + i);
78 auto tmp_length = ConvertShapeToLength(tmp);
79 output_shape[i] = input_length / tmp_length;
80 break;
81 }
82 }
83 }
84 if (ConvertShapeToLength(output_shape) != input_length) {
85 throw std::runtime_error("TMVA Reshape Op : Invalid shapes : " + ConvertShapeToString(input_shape) +
86 ConvertShapeToString(output_shape));
87 }
88 }
89 ret.push_back(output_shape);
90
91 } else if (fOpMode == Flatten) {
92 // flattenig case
93 size_t inputSize = ConvertShapeToLength(input_shape);
94 size_t b = input[0][0];
95 std::vector<size_t> newShape = {b, inputSize / b};
96 ret.push_back(newShape);
97
98 } else if (fOpMode == Squeeze) {
99 // squeeze
100 // assume no axis is provided - remove all axes with value equal to 1
101 auto output_shape = input[0];
102 if (input.size() == 1) {
103 for (size_t i = 0; i < output_shape.size(); i++) {
104 if (output_shape[i] == 1 ) {
105 output_shape.erase(output_shape.begin() + i);
106 }
107 }
108 } else if (input.size() == 2) {
109 auto & axes = input[1];
110 for (size_t i = 0; i < axes.size(); i++){
111 if (output_shape[axes[i]] != 1)
112 throw std::runtime_error("TMVA Squeeze Op : Invalid axes : " + ConvertShapeToString(axes) +
113 ConvertShapeToString(output_shape));
114 output_shape.erase(output_shape.begin() + axes[i]);
115 }
116 }
117 ret.push_back(output_shape);
118 }
119
120 else if (fOpMode == Unsqueeze) {
121 // unsqueeze
122 assert(input.size() == 2);
123 auto output_shape = input[0];
124 auto &axes = input[1];
125 if (axes[0] > 0) { // positive axis start from beginning
126 for (auto & i : axes)
127 output_shape.insert(output_shape.begin() + i, 1);
128 } else {
129 //negative axes
130 for (auto &i : axes) {
131 assert(i < 0);
132 output_shape.insert(output_shape.begin() + (output_shape.size() + i - 1), 1);
133 }
134 }
135 ret.push_back(output_shape);
136 }
137 return ret;
138 }
139
140 void Initialize(RModel &model)
141 {
142
143 if (model.CheckIfTensorAlreadyExist(fNData) == false) {
144 // input must be a graph input, or already initialized intermediate tensor
145 throw std::runtime_error("TMVA Reshape Op Input Tensor is not found in model");
146 }
148
149 // check if optional shape tensor exist
150 if (!fNShape.empty()) {
152 auto dptr = model.GetInitializedTensorData(fNShape);
153 auto input_shape = static_cast<int64_t *>(dptr.get());
154 auto vec = model.GetTensorShape(fNShape);
155 assert(vec.size() == 1);
156 size_t n = vec[0]; // size of shape input tensor
157
158 std::vector<size_t> descShape(n);
159 std::copy(input_shape, input_shape + n, descShape.begin());
160 fShapeOutput = ShapeInference({fShapeInput, descShape})[0];
161 } else {
162 throw std::runtime_error("TMVA Reshape Op Input Tensor is not found in model");
163 }
164 } else if (!fAttrAxes.empty()) {
165 // case fNShape is empty and axes are provided as attributes
166 std::vector<size_t> descShape(fAttrAxes.size());
167 std::copy(fAttrAxes.begin(), fAttrAxes.end(), descShape.begin());
168 fShapeOutput = ShapeInference({fShapeInput, descShape})[0];
169 } else if (fOpMode == Flatten || fOpMode == Squeeze) {
171 } else {
172 throw std::runtime_error("TMVA Reshape Op : Invalid Input/Attribute data");
173 }
175 }
176
177 std::string Generate(std::string OpName)
178 {
179 OpName = "op_" + OpName;
180 if (fShapeInput.empty() || fShapeOutput.empty()) {
181 throw std::runtime_error("TMVA SOFIE Reshape Op called to Generate without being initialized first");
182 }
183
184 // output of reshape is same as input
185 size_t length = ConvertShapeToLength(fShapeOutput);
186 if (length != ConvertShapeToLength(fShapeInput)) {
187 throw std::runtime_error("TMVA SOFIE Reshape Op : wrong output shape - is " +
188 ConvertShapeToString(fShapeOutput) + " and input is " +
190 }
191 for (auto &i : fShapeOutput) {
192 length *= i;
193 }
194 std::stringstream out;
195 std::string opName = "Reshape";
196 if (fOpMode == Flatten)
197 opName = "Flatten";
198 else if (fOpMode == Squeeze)
199 opName = "Squeeze";
200 else if (fOpMode == Unsqueeze)
201 opName = "Unsquueze";
202
203 out << SP << "///--------" << opName << " operator\n" << std::endl;
204 out << SP << "std::copy( fTensor_" << fNData << ".begin(), fTensor_" << fNData << ".end(), fTensor_" << fNOutput
205 << ".begin() );\n";
206 return out.str();
207 }
208};
209
210}//SOFIE
211}//Experimental
212}//TMVA
213
214
215#endif //TMVA_SOFIE_ROPERATOR_RESHAPE
#define b(i)
Definition RSha256.hxx:100
#define INT64_MAX
Definition civetweb.c:439
const ETensorType & GetTensorType(std::string name)
Definition RModel.cxx:70
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape)
Definition RModel.cxx:136
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:91
const std::vector< size_t > & GetTensorShape(std::string name)
Definition RModel.cxx:49
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
Definition RModel.cxx:160
ROperator_Reshape(ReshapeOpMode opMode, std::vector< int64_t > attrAxes, std::string nameData, std::string nameOutput)
ROperator_Reshape(ReshapeOpMode opMode, int attr_value, std::string nameData, std::string nameShape, std::string nameOutput)
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input)
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input)
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:39
const Int_t n
Definition legend1.C:16
std::string ConvertShapeToString(std::vector< size_t > shape)
std::size_t ConvertShapeToLength(std::vector< size_t > shape)
create variable transformations