82 if (!model.CheckIfTensorAlreadyExist(
fNC))
83 throw std::runtime_error(std::string(
"TMVA SOFIE Where Op: condition tensor ") +
fNC +
" not found in model");
84 if (!model.CheckIfTensorAlreadyExist(
fNX))
85 throw std::runtime_error(std::string(
"TMVA SOFIE Where Op: X tensor ") +
fNX +
" not found in model");
86 if (!model.CheckIfTensorAlreadyExist(
fNY))
87 throw std::runtime_error(std::string(
"TMVA SOFIE Where Op: Y tensor ") +
fNY +
" not found in model");
90 if (model.IsReadyInputTensor(
fNC))
96 int dynamicInputs = 0;
98 if (model.IsDynamicTensor(
fNC)) {
105 if (model.IsDynamicTensor(
fNX)) {
112 if (model.IsDynamicTensor(
fNY)) {
120 if (model.Verbose()) {
121 if (dynamicInputs & 1)
123 if (dynamicInputs & 2)
125 if (dynamicInputs & 4)
132 if (dynamicInputs == 0) {
143 bool allConstant = model.IsInitializedTensor(
fNC) &&
144 model.IsInitializedTensor(
fNX) &&
145 model.IsInitializedTensor(
fNY);
151 auto broadcastIfNeeded = [&](
const std::string &
name,
152 const std::vector<size_t> &shape,
154 const std::string &prefix) {
156 bcName = prefix +
name +
"to" +
fNZ;
157 auto data = model.GetInitializedTensorData(
name);
158 std::shared_ptr<void> bcData(
160 std::default_delete<T[]>());
161 model.AddConstantTensor(bcName, model.GetTensorType(
name),
fShapeZ, bcData);
173 auto dataC =
static_cast<bool *
>(model.GetInitializedTensorData(nameC).get());
174 auto dataX =
static_cast<T *
> (model.GetInitializedTensorData(nameX).get());
175 auto dataY =
static_cast<T *
> (model.GetInitializedTensorData(nameY).get());
178 std::vector<T> dataZ(len);
179 for (
size_t i = 0; i < len; ++i)
180 dataZ[i] = dataC[i] ? dataX[i] : dataY[i];
182 model.AddConstantTensor<T>(
fNZ,
fShapeZ, dataZ.data());
183 model.SetNotWritableInitializedTensor(nameC);
184 model.SetNotWritableInitializedTensor(nameX);
185 model.SetNotWritableInitializedTensor(nameY);
198 model.AddIntermediateTensor(
fNZ, model.GetTensorType(
fNX),
fShapeZ);
221 auto IsInputDimParam = [&](
const std::string &p) {
222 for (
auto &input : model.GetInputTensorNames())
223 for (
auto &s : model.GetDimTensorShape(input))
224 if (s.isParam && s.param == p)
return true;
227 for (
size_t i = 0; i <
fDimShapeZ.size(); i++) {
229 if (s.isParam && s.param.find(
"std::max") != std::string::npos) {
260 opName =
"op_" + opName;
263 throw std::runtime_error(
"TMVA SOFIE Where Op called to Generate without being initialized first");
266 std::stringstream out;
276 out <<
SP <<
"if (" << lengthX <<
" != " << lengthY <<
" || "
277 << lengthX <<
" != " << lengthC <<
") {\n";
278 for (
size_t i = 0; i <
fDimShapeZ.size(); i++) {
284 <<
"throw std::runtime_error(\"SOFIE Where: cannot broadcast X dim " << i <<
" in " << opName <<
"\");\n";
291 <<
"throw std::runtime_error(\"SOFIE Where: cannot broadcast Y dim " << i <<
" in " << opName <<
"\");\n";
298 <<
"throw std::runtime_error(\"SOFIE Where: cannot broadcast C dim " << i <<
" in " << opName <<
"\");\n";
314 auto buildIdxExpr = [&](
const std::vector<Dim> &dimShape,
315 const std::vector<Dim> &strides,
316 size_t rankZ) -> std::string {
317 if (dimShape.empty() ||
318 std::all_of(dimShape.begin(), dimShape.end(),
319 [](
Dim d) { return d.dim == 1 || d.GetVal() ==
"1"; }))
322 size_t offset = rankZ - dimShape.size();
323 for (
size_t i = 0; i < dimShape.size(); ++i) {
324 if (dimShape[i].dim == 1 || dimShape[i].GetVal() ==
"1")
continue;
325 expr +=
"idx_" + std::to_string(i + offset);
326 if (strides[i].GetVal() !=
"1")
327 expr +=
" * " + strides[i].GetVal();
330 if (expr.size() >= 3)
331 for (
int j = 0; j < 3; j++) expr.pop_back();
332 return expr.empty() ?
"0" : expr;
344 [](
Dim d) { return d.dim == 1 || d.GetVal() ==
"1"; })) {
347 for (
size_t i = 0; i <
fDimShapeZ.size(); ++i) {
350 for (
int j = 0; j < nloop; j++) out <<
SP;
351 out <<
"for (size_t idx_" << i <<
" = 0; idx_" << i
352 <<
" < " <<
fDimShapeZ[i] <<
"; ++idx_" << i <<
") {\n";
353 idxZ +=
"idx_" + std::to_string(i);
354 if (stridesZ[i].GetVal() !=
"1")
355 idxZ +=
" * " + stridesZ[i].GetVal();
359 if (idxZ.size() >= 3)
360 for (
int j = 0; j < 3; j++) idxZ.pop_back();
364 for (
int j = 0; j < nloop + 1; j++) out <<
SP;
365 out <<
"tensor_" <<
fNZ <<
"[" << idxZ <<
"] = "
366 <<
"tensor_" <<
fNC <<
"[" << idxC <<
"] ? "
367 <<
"tensor_" <<
fNX <<
"[" << idxX <<
"] : "
368 <<
"tensor_" <<
fNY <<
"[" << idxY <<
"];\n";
371 for (
int i = nloop; i > 0; i--) {
372 for (
int j = 0; j < i; j++) out <<
SP;