1#ifndef TMVA_SOFIE_ROPERATOR_ScatterND
2#define TMVA_SOFIE_ROPERATOR_ScatterND
39 ROperator_ScatterND(
const std::string & nameX,
const std::string & nameI,
const std::string & nameU,
const std::string & nameY,
40 std::string reduction):
51 if (!model.CheckIfTensorAlreadyExist(
fNX)){
52 throw std::runtime_error(std::string(
"TMVA SOFIE ScatterND Op Input Tensor ") +
fNX +
"is not found in model");
54 if (!model.CheckIfTensorAlreadyExist(
fNI)) {
55 throw std::runtime_error(std::string(
"TMVA SOFIE ScatterND Op Input Tensor ") +
fNI +
"is not found in model");
57 if (!model.CheckIfTensorAlreadyExist(
fNU)) {
58 throw std::runtime_error(std::string(
"TMVA SOFIE ScatterND Op Input Tensor ") +
fNU +
"is not found in model");
64 auto shapeU = model.GetDimTensorShape(
fNU);
71 if (!(
fShapeI.back().isParam) ) {
72 const size_t k =
fShapeI.back().dim;
75 throw std::invalid_argument(
76 "ScatterND: last dim of indices (" + std::to_string(k) +
77 ") must be <= rank of data (" + std::to_string(
r) +
")");
80 int64_t expected_updates_rank =
q - 1 +
r - k;
81 if ((int64_t) shapeU.size() != expected_updates_rank)
82 throw std::invalid_argument(
"ScatterND: updates rank mismatch");
85 throw std::runtime_error(
"TMVA SOFIE ScatterND : Index_shape(-1) is not known. This case is not supported");
91 model.AddIntermediateTensor(
fNY, model.GetTensorType(
fNX),
fShapeY);
92 if (model.Verbose()) {
100 std::string
Generate(std::string opName)
override {
103 return "//---------------------------------------\n";
105 opName =
"op_" + opName;
106 std::stringstream out;
130 out <<
SP <<
"// Step 1: copy input data to output\n";
131 out <<
SP <<
"std::copy(tensor_" <<
fNX <<
", tensor_" <<
fNX <<
" + " << data_length <<
", tensor_" <<
fNY <<
");\n";
134 out <<
SP <<
"// Step 2: data strides (row-major)\n";
136 out <<
SP <<
"size_t " << opName <<
"_data_strides[" <<
r <<
"] = {";
137 for (
size_t i = 0; i <
r; ++i)
138 out << stridesX[i] << (i + 1 <
r ?
", " :
"");
142 out <<
SP <<
"// Step 3: scatter updates into output\n";
143 out <<
SP <<
"for (int64_t idx = 0; idx < " << num_index_tuples <<
"; idx++) {\n";
146 out <<
SP <<
SP <<
"int64_t data_offset = 0;\n";
147 for (
size_t dim = 0; dim < k; ++dim) {
148 out <<
SP <<
SP <<
"{\n";
149 out <<
SP <<
SP <<
SP <<
"int64_t coord = tensor_" <<
fNI
150 <<
"[idx * " << k <<
" + " << dim <<
"];\n";
152 out <<
SP <<
SP <<
SP <<
"if (coord < 0) coord += " <<
fShapeX[dim] <<
";\n";
153 out <<
SP <<
SP <<
SP <<
"data_offset += coord * "
154 << opName <<
"_data_strides[" << dim <<
"];\n";
155 out <<
SP <<
SP <<
"}\n";
159 out <<
SP <<
SP <<
"for (int64_t s = 0; s < " << slice_size <<
"; s++) {\n";
160 out <<
SP <<
SP <<
SP <<
"auto upd = tensor_" <<
fNU
161 <<
"[idx * " << slice_size <<
" + s];\n";
164 out <<
SP <<
SP <<
SP <<
"tensor_" <<
fNY <<
"[data_offset + s] = upd;\n";
166 out <<
SP <<
SP <<
SP <<
"tensor_" <<
fNY<<
"[data_offset + s] += upd;\n";
168 out <<
SP <<
SP <<
SP <<
"tensor_" <<
fNY <<
"[data_offset + s] *= upd;\n";
170 out <<
SP <<
SP <<
SP <<
"tensor_" <<
fNY<<
"[data_offset + s] = "
171 <<
"std::min(tensor_" <<
fNY <<
"[data_offset + s], upd);\n";
173 out <<
SP <<
SP <<
SP <<
"tensor_" <<
fNY <<
"[data_offset + s] = "
174 <<
"std::max(tensor_" <<
fNY <<
"[data_offset + s], upd);\n";
176 throw std::runtime_error(
177 "TMVA SOFIE ScatterND: unsupported reduction '" +
fReduction +
"'");
180 out <<
SP <<
SP <<
"}\n";
std::vector< Dim > fShapeI
std::string Generate(std::string opName) override
std::vector< Dim > fShapeX
ROperator_ScatterND(const std::string &nameX, const std::string &nameI, const std::string &nameU, const std::string &nameY, std::string reduction)
void Initialize(RModel &model) override
std::vector< int64_t > fIndices
std::vector< Dim > fShapeY
std::vector< std::string_view > fInputTensorNames
bool fIsOutputConstant
flag to identify if operator has a constant output (no need to generate code)
const std::string SP
space used to correctly indent the generated C++ code
std::vector< std::string_view > fOutputTensorNames
std::vector< size_t > ComputeStrideFromShape(const std::vector< size_t > &shape)
compute stride of a tensor given its shape (assume layout is row-major)
std::string ConvertDimShapeToString(const std::vector< Dim > &shape)
std::string ConvertDimShapeToLength(const std::vector< Dim > &shape)
create variable transformations