66 if (!model.CheckIfTensorAlreadyExist(
fNX)){
67 throw std::runtime_error(std::string(
"TMVA SOFIE ScatterElements Op Input Tensor ") +
fNX +
"is not found in model");
69 if (!model.CheckIfTensorAlreadyExist(
fNI)) {
70 throw std::runtime_error(std::string(
"TMVA SOFIE ScatterElements Op Input Tensor ") +
fNI +
"is not found in model");
72 if (!model.CheckIfTensorAlreadyExist(
fNU)) {
73 throw std::runtime_error(std::string(
"TMVA SOFIE ScatterElements Op Input Tensor ") +
fNU +
"is not found in model");
79 auto shapeU = model.GetDimTensorShape(
fNU);
80 if (model.Verbose()) {
85 if (!model.IsDynamicTensor(
fNI) && !model.IsDynamicTensor(
fNU)) {
87 throw std::runtime_error(std::string(
"TMVA SOFIE ScatterElements - update tensor has invalid shape ")) ;
90 throw std::runtime_error(std::string(
"TMVA SOFIE ScatterElements - input tensor has zero rank ")) ;
92 throw std::runtime_error(std::string(
"TMVA SOFIE ScatterElements - index tensor has invalid rank ")) ;
98 model.AddIntermediateTensor(
fNY, model.GetTensorType(
fNX),
fShapeY);
108 std::string
Generate(std::string opName)
override {
113 throw std::runtime_error(
"TMVA SOFIE ScatterElements Op called to Generate without being initialized first");
115 std::stringstream out;
116 out <<
SP <<
"\n//-------- ScatterElements --- " << opName <<
"\n";
124 auto tensorIndex = [](
const std::vector<Dim> & stride,
const std::vector<std::string> & idx) {
125 std::stringstream strst;
126 int dims = idx.size();
127 assert (dims == (
int) stride.size());
128 for (
int i = 0; i < dims; i++) {
129 if (stride[i].GetVal() !=
"1")
130 strst << stride[i] <<
"*" << idx[i];
139 auto tensorIndexOpt = [](
const std::vector<std::string> & sdx,
const std::vector<std::string> & idx) {
140 std::stringstream strst;
141 int dims = idx.size();
142 for (
int i = 0; i < dims-1; i++) {
146 strst << idx[dims-1];
152 out <<
SP <<
"std::copy(tensor_" <<
fNX <<
", tensor_" <<
fNX <<
" + " << length <<
", tensor_" <<
fNY <<
");\n";
156 std::vector<std::string> idx(dims);
157 std::vector<std::string> sdx(dims);
158 for (
int i = 0; i < dims; i++) {
159 idx[i] = std::string(
"i") + std::to_string(i);
160 sdx[i] = std::string(
"s") + std::to_string(i);
161 for (
int j = 0; j <= i; j++) out <<
SP;
162 out <<
"for (int " << idx[i] <<
" = 0; " << idx[i] <<
" < " <<
fShapeI[i] <<
"; " << idx[i] <<
"++) {\n";
164 for (
int j = 0; j <= i+1 ; j++) out <<
SP;
165 if (strideI[i].GetVal() !=
"1")
166 out <<
"int "<< sdx[i] <<
" = " << strideI[i] <<
" * " << idx[i] <<
";\n";
168 out <<
"int "<< sdx[i] <<
" = " << idx[i] <<
";\n";
172 for (
int j = 0; j <= dims; j++) out <<
SP;
174 out <<
"int updateIndex = " << tensorIndexOpt(sdx,idx) <<
";\n";
175 for (
int j = 0; j <= dims; j++) out <<
SP;
176 out <<
"int iAxis = tensor_" <<
fNI <<
"[updateIndex];\n";
177 for (
int j = 0; j <= dims; j++) out <<
SP;
178 out <<
"if (iAxis < 0) iAxis += " <<
fShapeY[
fAxis] <<
";\n";
179 idx[
fAxis] =
"iAxis";
180 for (
int j = 0; j <= dims; j++) out <<
SP;
181 out <<
"int outIndex = " << tensorIndex(strideY, idx) <<
";\n";
182 for (
int j = 0; j <= dims; j++) out <<
SP;
183 out <<
"tensor_" <<
fNY <<
"[outIndex] = "
184 <<
ReductionFunction(std::string(
"tensor_") +
fNY +
"[outIndex]", std::string(
"tensor_") +
fNU +
"[updateIndex]") <<
";\n";
186 for (
int i = dims; i > 0; i--) {
187 for (
int j = 0; j < i; j++) out <<
SP;