Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_Where.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROperator_Where
2#define TMVA_SOFIE_ROperator_Where
3
5#include "TMVA/ROperator.hxx"
6#include "TMVA/RModel.hxx"
7
8#include <sstream>
9
10namespace TMVA {
11namespace Experimental {
12namespace SOFIE {
13
// NOTE(review): Doxygen-extraction artifact — original line 15, the class
// declaration (presumably `class ROperator_Where : public ROperator {`),
// is missing between the next two lines. Confirm against the original header.
14template <typename T>
16private:
17
// True when the bool condition tensor is a live (ready) model input;
// set in Initialize() via model.IsReadyInputTensor(fNC).
18 bool fIsInputBoolTensor = false;
19
20 // Tensor names: C = condition, X = true branch, Y = false branch, Z = output
21 std::string fNC; // condition (bool)
22 std::string fNX; // true-branch values
23 std::string fNY; // false-branch values
24 std::string fNZ; // output
// Names of compile-time-broadcast copies of C/X/Y; empty when no broadcast
// was needed (used by the constant-folding path in Initialize()).
25 std::string fNBroadcastedC;
26 std::string fNBroadcastedX;
27 std::string fNBroadcastedY;
28
29 // Static shapes (used when all inputs are non-dynamic)
30 std::vector<size_t> fShapeC;
31 std::vector<size_t> fShapeX;
32 std::vector<size_t> fShapeY;
33 std::vector<size_t> fShapeZ;
34
35 // Dynamic shapes (Dim-aware, used when any input is dynamic)
36 std::vector<Dim> fDimShapeC;
37 std::vector<Dim> fDimShapeX;
38 std::vector<Dim> fDimShapeY;
39 std::vector<Dim> fDimShapeZ;
40
41 // Broadcast flag: mirrors convention of BasicBinary
42 // bit 0: broadcast Y->X (Y needs expanding)
43 // bit 1: broadcast X->Y (X needs expanding)
44 // bit 2: broadcast C->Z (C needs expanding)
45 // bit 4: shapes may differ at runtime (dynamic)
// NOTE(review): the code below tests `fBroadcastFlag & 4` (mask value 4,
// i.e. bit 2) for the "runtime shapes may differ" case, so the labels
// "bit 2"/"bit 4" above appear to mean mask VALUES, not bit indices —
// confirm against ROperator_BasicBinary.
// NOTE(review): original line 46 — presumably the `fBroadcastFlag`
// declaration itself — is missing from this extraction.
47
48public:
/// Constructor.
/// Stores sanitized (UTILITY::Clean_name) tensor names for the
/// condition (C), true branch (X), false branch (Y) and output (Z).
50 ROperator_Where(const std::string &nameC,
51 const std::string &nameX,
52 const std::string &nameY,
53 const std::string &nameZ)
54 : fNC(UTILITY::Clean_name(nameC)),
55 fNX(UTILITY::Clean_name(nameX)),
56 fNY(UTILITY::Clean_name(nameY)),
57 fNZ(UTILITY::Clean_name(nameZ))
58 {
// NOTE(review): original lines 59-60 are missing from this extraction;
// they presumably populated fInputTensorNames / fOutputTensorNames
// (fOutputTensorNames.pop_back() is called later in Initialize()) —
// confirm against the original header.
61 }
62
63 // type of output given input
64 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override
65 {
66 // output type follows X (and Y), not C (which is bool)
67 return { input[1] };
68 }
69
70 // shape of output tensors given input tensors
71 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override
72 {
73 // conservative: assume same shape (broadcasting resolved in Initialize)
74 return { input[1] };
75 }
76
/// Initialize the operator from the model.
/// Verifies the three inputs exist, collects their (static or dynamic)
/// shapes, computes the broadcast output shape, and — when all inputs are
/// initialized constants — folds the Where into a constant output tensor.
/// NOTE(review): this listing is a Doxygen extraction with several original
/// lines dropped (marked inline below); the code as shown is incomplete.
77 void Initialize(RModel &model) override
78 {
79 // ---------------------------------------------------------------- //
80 // Check all inputs exist
81 // ---------------------------------------------------------------- //
// NOTE(review): the guarding conditions (original lines 82/84/86, presumably
// `if (!model.CheckIfTensorAlreadyExist(...))`) are missing before each of
// the three throws below — confirm against the original header.
83 throw std::runtime_error(std::string("TMVA SOFIE Where Op: condition tensor ") + fNC + " not found in model");
85 throw std::runtime_error(std::string("TMVA SOFIE Where Op: X tensor ") + fNX + " not found in model");
87 throw std::runtime_error(std::string("TMVA SOFIE Where Op: Y tensor ") + fNY + " not found in model");
88
89 // condition tensor is bool (uint8) - mark if it is a live input tensor
90 if (model.IsReadyInputTensor(fNC))
91 fIsInputBoolTensor = true;
92
93 // ---------------------------------------------------------------- //
94 // Collect shapes – dynamic or static
95 // ---------------------------------------------------------------- //
96 int dynamicInputs = 0; // bitmask: bit0=C, bit1=X, bit2=Y
97
// NOTE(review): in each dynamic branch below one line is missing (original
// 99/106/113) — presumably `fDimShape? = model.GetDynamicTensorShape(...)`;
// similarly lines 103/110/117 (presumably converting static shapes to Dim).
98 if (model.IsDynamicTensor(fNC)) {
100 dynamicInputs |= 1;
101 } else {
102 fShapeC = model.GetTensorShape(fNC);
104 }
105 if (model.IsDynamicTensor(fNX)) {
107 dynamicInputs |= 2;
108 } else {
109 fShapeX = model.GetTensorShape(fNX);
111 }
112 if (model.IsDynamicTensor(fNY)) {
114 dynamicInputs |= 4;
115 } else {
116 fShapeY = model.GetTensorShape(fNY);
118 }
119
120 if (model.Verbose()) {
121 if (dynamicInputs & 1)
122 std::cout << "Where : condition " << fNC << " is dynamic " << ConvertDimShapeToString(fDimShapeC) << "\n";
123 if (dynamicInputs & 2)
124 std::cout << "Where : X " << fNX << " is dynamic " << ConvertDimShapeToString(fDimShapeX) << "\n";
125 if (dynamicInputs & 4)
126 std::cout << "Where : Y " << fNY << " is dynamic " << ConvertDimShapeToString(fDimShapeY) << "\n";
127 }
128
129 // ---------------------------------------------------------------- //
130 // Static path: all shapes known at code-gen time
131 // ---------------------------------------------------------------- //
132 if (dynamicInputs == 0) {
133
134 // Multidirectional broadcast over all three tensors
// NOTE(review): the definitions of `retXY` (line 135) and `retCZ` (line 139)
// are missing — presumably calls to a UTILITY broadcast-shape helper
// returning a {flag, shape} pair. Confirm against the original header.
136 fBroadcastFlag = retXY.first;
137 fShapeZ = retXY.second;
138 // also factor in C
140 fBroadcastFlag |= retCZ.first;
141 fShapeZ = retCZ.second;
142
// NOTE(review): line 145 (presumably `model.IsInitializedTensor(fNY);`) is
// missing — the conjunction below is syntactically incomplete as shown.
143 bool allConstant = model.IsInitializedTensor(fNC) &&
144 model.IsInitializedTensor(fNX) &&
146
147 if (allConstant) {
148 // ----------------------------------------------------------
149 // Constant folding: evaluate Where at model initialisation
150 // ----------------------------------------------------------
// Helper: if `shape` differs from the output shape, create a broadcast
// copy of the named constant tensor (line 161, presumably the
// model.AddConstantTensor registration, is missing from this extraction).
151 auto broadcastIfNeeded = [&](const std::string &name,
152 const std::vector<size_t> &shape,
153 std::string &bcName,
154 const std::string &prefix) {
155 if (shape != fShapeZ) {
156 bcName = prefix + name + "to" + fNZ;
157 auto data = model.GetInitializedTensorData(name);
158 std::shared_ptr<void> bcData(
159 UTILITY::UnidirectionalBroadcast(static_cast<T *>(data.get()), shape, fShapeZ),
160 std::default_delete<T[]>());
162 }
163 };
164
// NOTE(review): lines 165-167 — presumably the three broadcastIfNeeded(...)
// calls for C, X and Y — are missing from this extraction.
168
// Use the broadcast copy of each tensor when one was created, else the
// original tensor.
169 const std::string &nameC = fNBroadcastedC.empty() ? fNC : fNBroadcastedC;
170 const std::string &nameX = fNBroadcastedX.empty() ? fNX : fNBroadcastedX;
171 const std::string &nameY = fNBroadcastedY.empty() ? fNY : fNBroadcastedY;
172
173 auto dataC = static_cast<bool *>(model.GetInitializedTensorData(nameC).get());
174 auto dataX = static_cast<T *> (model.GetInitializedTensorData(nameX).get());
175 auto dataY = static_cast<T *> (model.GetInitializedTensorData(nameY).get());
176
// NOTE(review): line 177 — presumably
// `size_t len = ConvertShapeToLength(fShapeZ);` — is missing.
// Element-wise select: Z[i] = C[i] ? X[i] : Y[i].
178 std::vector<T> dataZ(len);
179 for (size_t i = 0; i < len; ++i)
180 dataZ[i] = dataC[i] ? dataX[i] : dataY[i];
181
// Register the folded result as a constant tensor; no code is generated
// for this operator afterwards (Generate() returns "" when
// fIsOutputConstant is set).
182 model.AddConstantTensor<T>(fNZ, fShapeZ, dataZ.data());
186 fIsOutputConstant = true;
187 fOutputTensorNames.pop_back();
188
189 if (model.Verbose())
190 std::cout << "Where --> " << fNZ << " " << ConvertShapeToString(fShapeZ)
191 << " : " << ConvertValuesToString(dataZ) << " (constant)\n";
192 } else {
193 // ----------------------------------------------------------
194 // Non-constant static tensors - we don't need to broadcast tensors
195 // ----------------------------------------------------------
196
// NOTE(review): lines 197-198 — presumably registering the output as an
// intermediate tensor (model.AddIntermediateTensor) and setting
// fDimShapeZ from fShapeZ — are missing from this extraction.
199
200 if (model.Verbose())
201 std::cout << "Where : C=" << fNC << " " << ConvertShapeToString(fShapeC)
202 << " X=" << fNX << " " << ConvertShapeToString(fShapeX)
203 << " Y=" << fNY << " " << ConvertShapeToString(fShapeY)
204 << " --> Z=" << fNZ << " " << ConvertShapeToString(fShapeZ) << "\n";
205 }
206
207 } else {
208 // ---------------------------------------------------------------- //
209 // Dynamic path: at least one input has a parametric shape
210 // Need to use BroadcastShape to find output shape
211 // ---------------------------------------------------------------- //
// NOTE(review): the definitions of `retXY` (line 212) and `retCZ`
// (line 215) are missing — presumably Dim-aware broadcast-shape calls.
213 fBroadcastFlag = retXY.first;
214 fDimShapeZ = retXY.second;
216 fBroadcastFlag |= retCZ.first;
217 fDimShapeZ = retCZ.second;
218
219 // Resolve std::max params to actual input dim params (same logic as BasicBinary)
220 if (fBroadcastFlag & 4) {
// True when parameter name `p` appears in the shape of some model input.
221 auto IsInputDimParam = [&](const std::string &p) {
222 for (auto &input : model.GetInputTensorNames())
223 for (auto &s : model.GetDimTensorShape(input))
224 if (s.isParam && s.param == p) return true;
225 return false;
226 };
// Replace synthesized "std::max(...)" dim parameters with a concrete
// input dim where one side is known to be non-1.
227 for (size_t i = 0; i < fDimShapeZ.size(); i++) {
228 auto &s = fDimShapeZ[i];
229 if (s.isParam && s.param.find("std::max") != std::string::npos) {
230 // prefer X dim over Y dim
231 if (i < fDimShapeX.size() && IsInputDimParam(fDimShapeX[i].param)) {
232 s = (fDimShapeX[i].dim != 1) ? fDimShapeX[i] : fDimShapeY[i];
233 } else if (i < fDimShapeY.size() && IsInputDimParam(fDimShapeY[i].param)) {
234 s = (fDimShapeY[i].dim != 1) ? fDimShapeY[i] : fDimShapeX[i];
235 }
236 }
237 }
238 }
239
// NOTE(review): line 240 — presumably
// `model.AddIntermediateTensor(fNZ, ..., fDimShapeZ);` — is missing.
241
242 if (model.Verbose())
// NOTE(review): lines 244-245 (the X= and Y= parts of this log line) are
// missing from this extraction.
243 std::cout << "Where (dynamic) : C=" << ConvertDimShapeToString(fDimShapeC)
246 << " --> Z=" << ConvertDimShapeToString(fDimShapeZ) << "\n";
247 }
248 }
249
250 std::string GenerateInitCode() override
251 {
252 std::stringstream out;
253 return out.str();
254 }
255
/// Generate the C++ inference code for this Where operator.
/// Emits optional runtime broadcast-validation checks, then nested loops
/// over the output shape computing Z[i] = C[i] ? X[i] : Y[i] with
/// stride-based index arithmetic (so size-1 dims broadcast for free).
/// Returns "" when the output was constant-folded in Initialize().
/// NOTE(review): this listing is a Doxygen extraction with several original
/// lines dropped (marked inline below); the code as shown is incomplete.
256 std::string Generate(std::string opName) override
257 {
258 if (fIsOutputConstant) return "";
259
260 opName = "op_" + opName;
261
262 if (fDimShapeZ.empty()) {
263 throw std::runtime_error("TMVA SOFIE Where Op called to Generate without being initialized first");
264 }
265
266 std::stringstream out;
267 out << SP << "\n//------ WHERE " << opName << " --> " << ConvertDimShapeToString(fDimShapeZ) << "\n";
268
269 // ---------------------------------------------------------------- //
270 // Runtime broadcast validation (dynamic shapes, flag bit 4)
271 // ---------------------------------------------------------------- //
272 if (fBroadcastFlag & 4) {
// NOTE(review): lines 273-275 — the definitions of `lengthX`, `lengthY`
// and `lengthC` (presumably via ConvertDimShapeToLength) — are missing
// from this extraction.
// Emit a guarded block: only when total lengths disagree, check each
// parametric dim of X/Y/C against the output dim (must be 1 or equal).
276 out << SP << "if (" << lengthX << " != " << lengthY << " || "
277 << lengthX << " != " << lengthC << ") {\n";
278 for (size_t i = 0; i < fDimShapeZ.size(); i++) {
279 // validate X vs Z
280 if (i < fDimShapeX.size() && fDimShapeX[i].isParam) {
281 out << SP << SP << "if (" << fDimShapeX[i] << " != 1 && "
282 << fDimShapeX[i] << " != " << fDimShapeZ[i] << ")\n";
283 out << SP << SP << SP
284 << "throw std::runtime_error(\"SOFIE Where: cannot broadcast X dim " << i << " in " << opName << "\");\n";
285 }
286 // validate Y vs Z
287 if (i < fDimShapeY.size() && fDimShapeY[i].isParam) {
288 out << SP << SP << "if (" << fDimShapeY[i] << " != 1 && "
289 << fDimShapeY[i] << " != " << fDimShapeZ[i] << ")\n";
290 out << SP << SP << SP
291 << "throw std::runtime_error(\"SOFIE Where: cannot broadcast Y dim " << i << " in " << opName << "\");\n";
292 }
293 // validate C vs Z
294 if (i < fDimShapeC.size() && fDimShapeC[i].isParam) {
295 out << SP << SP << "if (" << fDimShapeC[i] << " != 1 && "
296 << fDimShapeC[i] << " != " << fDimShapeZ[i] << ")\n";
297 out << SP << SP << SP
298 << "throw std::runtime_error(\"SOFIE Where: cannot broadcast C dim " << i << " in " << opName << "\");\n";
299 }
300 }
301 out << SP << "}\n";
302 }
303
304 // ---------------------------------------------------------------- //
305 // Runtime for non-constant, non-initialised tensors
306 //
307 // Generate loop(s) with per-dimension stride-based index arithmetic
308 // ---------------------------------------------------------------- //
// NOTE(review): lines 309-312 — the definitions of `stridesX`, `stridesY`,
// `stridesC` and `stridesZ` (presumably stride computations from the Dim
// shapes) — are missing from this extraction.
313
// Build the flat-index expression for one input tensor: sum of
// "idx_k * stride" terms, right-aligned against the output rank;
// dims of size 1 contribute nothing (broadcast), all-1 shapes give "0".
314 auto buildIdxExpr = [&](const std::vector<Dim> &dimShape,
315 const std::vector<Dim> &strides,
316 size_t rankZ) -> std::string {
317 if (dimShape.empty() ||
318 std::all_of(dimShape.begin(), dimShape.end(),
319 [](Dim d) { return d.dim == 1 || d.GetVal() == "1"; }))
320 return "0";
321 std::string expr;
322 size_t offset = rankZ - dimShape.size();
323 for (size_t i = 0; i < dimShape.size(); ++i) {
324 if (dimShape[i].dim == 1 || dimShape[i].GetVal() == "1") continue;
325 expr += "idx_" + std::to_string(i + offset);
326 if (strides[i].GetVal() != "1")
327 expr += " * " + strides[i].GetVal();
328 expr += " + ";
329 }
330 if (expr.size() >= 3)
331 for (int j = 0; j < 3; j++) expr.pop_back(); // remove trailing " + "
332 return expr.empty() ? "0" : expr;
333 };
334
335 std::string idxX = buildIdxExpr(fDimShapeX, stridesX, fDimShapeZ.size());
336 std::string idxY = buildIdxExpr(fDimShapeY, stridesY, fDimShapeZ.size());
337 std::string idxC = buildIdxExpr(fDimShapeC, stridesC, fDimShapeZ.size());
338
339 // Emit nested loops over output shape
340 int nloop = 0;
341 std::string idxZ;
342 if (fDimShapeZ.empty() ||
343 std::all_of(fDimShapeZ.begin(), fDimShapeZ.end(),
344 [](Dim d) { return d.dim == 1 || d.GetVal() == "1"; })) {
// Scalar (or all-1) output: single element, index 0, no loops.
345 idxZ = "0";
346 } else {
// One `for` per non-trivial output dim; build idxZ as the matching
// stride-weighted sum while emitting the loops.
347 for (size_t i = 0; i < fDimShapeZ.size(); ++i) {
348 if (fDimShapeZ[i].dim != 1 && fDimShapeZ[i].GetVal() != "1") {
349 nloop++;
350 for (int j = 0; j < nloop; j++) out << SP;
351 out << "for (size_t idx_" << i << " = 0; idx_" << i
352 << " < " << fDimShapeZ[i] << "; ++idx_" << i << ") {\n";
353 idxZ += "idx_" + std::to_string(i);
354 if (stridesZ[i].GetVal() != "1")
355 idxZ += " * " + stridesZ[i].GetVal();
356 idxZ += " + ";
357 }
358 }
359 if (idxZ.size() >= 3)
360 for (int j = 0; j < 3; j++) idxZ.pop_back(); // drop trailing " + "
361 }
362
363 // Inner assignment
364 for (int j = 0; j < nloop + 1; j++) out << SP;
365 out << "tensor_" << fNZ << "[" << idxZ << "] = "
366 << "tensor_" << fNC << "[" << idxC << "] ? "
367 << "tensor_" << fNX << "[" << idxX << "] : "
368 << "tensor_" << fNY << "[" << idxY << "];\n";
369
370 // Close loops
371 for (int i = nloop; i > 0; i--) {
372 for (int j = 0; j < i; j++) out << SP;
373 out << "}\n";
374 }
375
376 return out.str();
377 }
378};
379
380} // namespace SOFIE
381} // namespace Experimental
382} // namespace TMVA
383
384#endif // TMVA_SOFIE_ROperator_Where
#define d(i)
Definition RSha256.hxx:102
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
winID h TVirtualViewer3D TVirtualGLPainter p
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len
char name[80]
Definition TGX11.cxx:148
const_iterator begin() const
const_iterator end() const
std::vector< size_t > GetTensorShape(const std::string &name) const
Definition RModel.cxx:51
std::vector< Dim > GetDimTensorShape(const std::string &name) const
Definition RModel.cxx:87
bool IsDynamicTensor(const std::string &name) const
Definition RModel.cxx:269
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:284
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:144
void AddConstantTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Definition RModel.cxx:215
bool IsInitializedTensor(const std::string &name) const
Definition RModel.cxx:256
std::vector< Dim > GetDynamicTensorShape(const std::string &name) const
Definition RModel.cxx:98
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
Definition RModel.cxx:349
void SetNotWritableInitializedTensor(const std::string &tensor_name)
Definition RModel.cxx:358
ETensorType GetTensorType(std::string name) const
Definition RModel.cxx:112
const std::vector< std::string > & GetInputTensorNames() const
Definition RModel.hxx:203
bool IsReadyInputTensor(const std::string &name) const
Definition RModel.cxx:278
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input) override
std::string Generate(std::string opName) override
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input) override
ROperator_Where(const std::string &nameC, const std::string &nameX, const std::string &nameY, const std::string &nameZ)
std::vector< std::string_view > fInputTensorNames
Definition ROperator.hxx:50
bool fIsOutputConstant
flag to identify if operator has a constant output (no need to generate code)
Definition ROperator.hxx:47
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:45
std::vector< std::string_view > fOutputTensorNames
Definition ROperator.hxx:51
std::vector< size_t > MultidirectionalBroadcastShape(std::vector< std::vector< size_t > >)
T * UnidirectionalBroadcast(const T *data, const std::vector< size_t > &shape, const std::vector< size_t > &targetShape)
std::vector< size_t > ComputeStrideFromShape(const std::vector< size_t > &shape)
compute stride of a tensor given its shape (assume layout is row-major)
std::string ConvertDimShapeToString(const std::vector< Dim > &shape)
std::size_t ConvertShapeToLength(const std::vector< size_t > &shape)
std::string ConvertValuesToString(size_t n, const T *data, size_t maxprint=-1)
std::vector< Dim > ConvertShapeToDim(const std::vector< size_t > &shape)
Convert shape from integer format to dynamic one (based on Dim)
std::string ConvertDimShapeToLength(const std::vector< Dim > &shape)
std::string ConvertShapeToString(const std::vector< size_t > &shape)
create variable transformations