87 std::vector<std::vector<Dim>>
ShapeInference(
const std::vector<std::vector<Dim>> & input) {
88 std::vector<std::vector<Dim>>
ret;
89 auto & input_shape = input[0];
92 std::vector<Dim> output_shape(
fShape.size());
94 for (
size_t i = 0; i < output_shape.size(); i++) {
96 output_shape[i] =
Dim{
static_cast<size_t>(
fShape[i]) };
98 output_shape[i] = input_shape[i];
101 for (
size_t i = 0; i < output_shape.size(); i++) {
103 auto tmp = output_shape;
104 tmp.erase(tmp.begin() + i);
109 << input_length <<
" to " << tmp_length << std::endl;
112 output_shape[i] =
Dim{
static_cast<size_t>(std::stoi(input_length) / std::stoi(tmp_length))};
113 else if (
IsInteger(tmp_length) && std::stoi(tmp_length) == 1) {
114 output_shape[i] =
Dim{input_length,
static_cast<size_t>(-1)};
119 bool canSimplify =
false;
120 std::vector <Dim> reduced_input;
125 std::stringstream ss(input_length);
130 while(getline(ss, token,
'*'))
133 token.erase(std::remove_if(token.begin(), token.end(),
134 [](
unsigned char x) { return std::isspace(x); }), token.end());
135 if (token != tmp_length) {
137 size_t il =
static_cast<size_t>(std::stoi(input_length));
138 size_t tl =
static_cast<size_t>(std::stoi(tmp_length));
139 if ((il % tl) == 0) {
141 reduced_input.push_back(
Dim{il / tl});
144 reduced_input.push_back(
Dim{token});
155 if (res_shape.find(
'*') != std::string::npos)
156 output_shape[i] =
Dim{std::string(
"(") + res_shape +
")",
static_cast<size_t>(-1)};
158 output_shape[i] =
Dim{res_shape};
161 output_shape[i] =
Dim{std::string(
"(") + input_length +
" / (" + tmp_length +
"))",
static_cast<size_t>(-1)};
177 ret.push_back(output_shape);
182 fAxis += input_shape.size();
183 auto s1 = std::vector<Dim>(input_shape.begin(), input_shape.begin() +
fAxis);
184 auto s2 = std::vector<Dim>(input_shape.begin() +
fAxis, input_shape.end());
187 std::vector<Dim> newShape = {
Dim{l1},
Dim{l2}};
188 ret.push_back(newShape);
192 auto output_shape = input_shape;
195 while (i < output_shape.size()) {
196 if (output_shape[i] ==
Dim{1}) {
197 output_shape.erase(output_shape.begin() + i);
203 std::cout <<
"getting shape for Squeeze...from attribute\n";
205 for (
size_t i = 0; i < axes.size(); i++) {
206 std::cout << i <<
" " << axes[i] << std::endl;
208 axes[i] += input_shape.size();
209 if (!(output_shape[axes[i]] ==
Dim{1}))
210 throw std::runtime_error(
"TMVA Squeeze Op : Invalid axis value " + std::to_string(axes[i]) +
214 std::sort(axes.begin(), axes.end(), std::greater<int>());
215 for (
auto & axis : axes) {
216 std::cout <<
"erase give axis " << axis <<
" -> ";
217 for (
auto & o : output_shape) std::cout << o <<
" , ";
218 std::cout << std::endl;
219 output_shape.erase(output_shape.begin() + axis);
222 ret.push_back(output_shape);
227 auto output_shape = input_shape;
230 int64_t
r = input[0].size() + axes.size();
231 for (
auto &
a : axes) {
232 int64_t i =
static_cast<int64_t
>(
a);
233 if (i < -r || i >
r - 1)
234 throw std::runtime_error(
"TMVA Unsqueeze Op - axes input is not in correct range");
236 output_shape.insert(output_shape.begin() + i,
Dim{1});
239 output_shape.insert(output_shape.end() + i + 1,
Dim{1});
241 ret.push_back(output_shape);
250 std::cout <<
"initialize reshape op type " <<
fOpMode <<
" - " <<
fNInput2 <<
" " <<
fNData << std::endl;
252 if (model.CheckIfTensorAlreadyExist(
fNData) ==
false) {
254 throw std::runtime_error(
"TMVA Reshape Op Input Tensor " +
fNData +
" is not found in model");
260 if (model.CheckIfTensorAlreadyExist(
fNInput2)) {
261 if (model.IsInitializedTensor(
fNInput2)) {
263 auto dptr = model.GetInitializedTensorData(
fNInput2);
264 auto values =
static_cast<int64_t *
>(dptr.get());
271 fShape = std::vector<int64_t>(values, values +
n);
273 fAttrAxes = std::vector<int64_t>(values, values +
n);
277 model.SetNotWritableInitializedTensor(
fNInput2);
278 }
else if (model.IsShapeTensor(
fNInput2)) {
279 auto shapeData = model.GetShapeTensorValues(
fNInput2);
285 auto shapeInput2 = model.GetTensorShape(
fNInput2);
292 throw std::runtime_error(
"TMVA Reshape Op 2nd input Tensor " +
fNInput2 +
" is not found in model");
296 std::cout <<
"attribute axes exists\n";
301 throw std::runtime_error(
"TMVA Reshape Op : Invalid Input/Attribute data");
306 auto inputData =
static_cast<int64_t*
>(model.GetInitializedTensorData(
fNData).get());
309 throw std::runtime_error(
"TMVA Reshape Op : Invalid Input/Output lengths");
310 model.AddConstantTensor<int64_t>(
fNOutput, o_shape, inputData);
311 if (model.Verbose()) {
319 auto inputData = model.GetShapeTensorValues(
fNData);
320 model.AddShapeTensor(
fNOutput, inputData);
321 if (model.Verbose()) {