1#ifndef TMVA_SOFIE_ROPERATOR_GATHER
2#define TMVA_SOFIE_ROPERATOR_GATHER
36 ROperator_Gather(int64_t attrAxis, std::string nameX, std::string nameIndices, std::string nameY):
42 std::vector<ETensorType>
TypeInference(std::vector<ETensorType> input)
override {
46 std::vector<std::vector<size_t>>
ShapeInference(std::vector<std::vector<size_t>> input)
override {
52 if (!model.CheckIfTensorAlreadyExist(
fNX)) {
53 throw std::runtime_error(
"TMVA SOFIE Gather Op Input Tensor " +
fNX +
" is not found in model");
71 if (model.IsInitializedTensor(
fNIndices)) {
74 int64_t* indicesData =
static_cast<int64_t*
>(model.GetInitializedTensorData(
fNIndices).get());
76 for (
size_t i = 0; i < indicesLength; i++) {
79 if (indicesData[i] < 0) {
85 fIndices = std::vector<int64_t>(indicesData, indicesData + indicesLength);
98 for (
size_t i = 0; i <
q; i++) {
105 if (model.IsInitializedTensor(
fNX) &&
q <= 1 &&
r == 1 &&
fIndices.size() > 0) {
109 auto inputData =
static_cast<int64_t*
>(model.GetInitializedTensorData(
fNX).get());
111 std::vector<int64_t> outputData(1);
112 outputData[0] = inputData[
fIndices[0]];
113 model.AddConstantTensor(
fNY, shapeY, outputData.data());
121 else if (model.IsShapeTensor(
fNX) &&
q <=1 &&
fIndices.size() > 0) {
122 auto inputData = model.GetShapeTensorValues(
fNX);
124 std::vector<Dim> outputData(1);
125 outputData[0] = inputData[
fIndices[0]];
126 if (outputData[0].isParam) {
129 model.AddShapeTensor(
fNY, outputData,
fShapeY.size() == 0);
134 int64_t value =
static_cast<int64_t
>(outputData[0].dim);
136 model.AddConstantTensor(
fNY, shapeY, &value);
140 <<
" and values {" << value <<
"} (constant) " << std::endl;
145 model.AddIntermediateTensor(
fNY, model.GetTensorType(
fNX),
fShapeY);
153 std::string
Generate(std::string opName)
override {
154 opName =
"op_" + opName;
155 std::stringstream out;
159 out <<
"//--------------------(constant)----------\n";
174 out <<
SP <<
"// correct in case of negative gather indices\n";
175 out <<
SP <<
"for (size_t i = 0; i < " << indicesLength <<
"; i++){\n";
176 out <<
SP <<
SP <<
"if (tensor_" <<
fNIndices <<
"[i] < 0)\n";
186 for (
size_t j = 0; j < size_t(
fAttrAxis); j++) {
187 std::string index =
"j_" + std::to_string(j);
188 for (
size_t k = 0; k <= j; k++) out <<
SP;
189 out <<
"for (size_t " << index <<
" = 0; " << index <<
" < " <<
fShapeY[j] <<
"; " << index <<
"++) {\n";
192 for (
size_t i = 0; i <
q; i++) {
193 std::string index =
"i_" + std::to_string(i);
194 for (
size_t k = 0; k <= i +
fAttrAxis; k++) out <<
SP;
195 out <<
"for (size_t " << index <<
" = " << 0 <<
"; " << index <<
" < " <<
fShapeIndices[i] <<
"; " << index <<
"++) {\n";
199 std::string index =
"j_" + std::to_string(
q+j);
200 for (
size_t k = 0; k <=
q + j; k++) out <<
SP;
201 out <<
"for (size_t " << index <<
" = 0; " << index <<
" < " <<
fShapeY[
q + j] <<
"; " << index <<
"++) {\n";
206 out <<
SP <<
"{ // scalar case \n";
209 for (
size_t k = 0; k <
q +
r; k++) out <<
SP;
210 out <<
"size_t y_index = ";
211 for (
size_t j = 0; j < size_t(
fAttrAxis); j++) {
212 if (j > 0) out <<
" + ";
214 if (stridesY[j].dim != 1) out <<
" * " << stridesY[j];
216 for (
size_t i = 0; i <
q; i++) {
222 if (j +
q > 0) out <<
" + ";
224 if (stridesY[
q+j].dim != 1) out <<
" * " << stridesY[
q+j];
232 for (
size_t k = 0; k <
q +
r; k++) out <<
SP;
233 out <<
"size_t i_index = ";
234 for (
size_t i = 0; i <
q; i++) {
235 if (i > 0) out <<
" + ";
237 if (stridesIndices[i].dim != 1) out <<
" * " << stridesIndices[i];
245 for (
size_t k = 0; k <
q +
r; k++) out <<
SP;
246 out <<
"size_t k = static_cast<size_t>(" <<
"tensor_" <<
fNIndices <<
"[i_index]" <<
");\n";
248 for (
size_t k = 0; k <
q +
r; k++) out <<
SP;
249 out <<
"size_t x_index = k";
251 for (
size_t j = 0; j < size_t(
fAttrAxis); j++) {
254 if (stridesX[j].dim != 1) out <<
" * " << stridesX[j];
261 if (stridesX[j+1].dim != 1) out <<
" * " << stridesX[j+1];
264 for (
size_t k = 0; k <
q +
r; k++) out <<
SP;
265 out <<
"tensor_" <<
fNY <<
"[y_index] = tensor_" <<
fNX <<
"[x_index];\n";
268 for (
size_t j =
q+
r-1; j > 0; j--) {
269 for (
size_t k = 0; k <j; k++) out <<
SP;
274 out <<
SP <<
"} // close Gather scope for scalar case \n";
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input) override
void Initialize(RModel &model) override
std::vector< int64_t > fIndices
std::string Generate(std::string opName) override
ROperator_Gather(int64_t attrAxis, std::string nameX, std::string nameIndices, std::string nameY)
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input) override
std::vector< Dim > fShapeY
std::vector< Dim > fShapeX
std::vector< Dim > fShapeIndices
std::vector< std::string_view > fInputTensorNames
bool fIsOutputConstant
flag to identify if operator has a constant output (no need to generate code)
const std::string SP
space used to correctly indent the generated C++ code
std::vector< std::string_view > fOutputTensorNames
std::vector< size_t > ComputeStrideFromShape(const std::vector< size_t > &shape)
compute stride of a tensor given its shape (assume layout is row-major)
std::string ConvertDimShapeToString(const std::vector< Dim > &shape)
std::size_t ConvertShapeToLength(const std::vector< size_t > &shape)
std::string ConvertValuesToString(size_t n, const T *data, size_t maxprint=-1)
std::vector< size_t > ConvertShapeToInt(const std::vector< Dim > &shape)
Convert shape based on Dim to integer format.
std::string ConvertTypeToString(ETensorType type)
std::string ConvertDimShapeToLength(const std::vector< Dim > &shape)
std::string ConvertShapeToString(const std::vector< size_t > &shape)
create variable transformations