Logo ROOT  
Reference Guide
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
Loading...
Searching...
No Matches
ROperator_ScatterElements.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROperator_ScatterElements
2#define TMVA_SOFIE_ROperator_ScatterElements
3
5#include "TMVA/ROperator.hxx"
6#include "TMVA/RModel.hxx"
7
8#include <sstream>
9
10namespace TMVA{
11namespace Experimental{
12namespace SOFIE{
13
14
16private:
17
// Axis along which values read from the indices tensor select output positions.
// A negative value is normalized in Initialize by adding the rank of X.
18 int64_t fAxis;
19
// Tensor names, each passed through UTILITY::Clean_name by the constructor:
// fNX - data input, copied into the output before scattering,
// fNI - indices input, whose values are the coordinates along fAxis,
// fNU - updates input, whose elements are written/combined into the output,
// fNY - output tensor.
20 std::string fNX;
21 std::string fNI;
22 std::string fNU;
23 std::string fNY;
// Reduction mode interpreted by ReductionFunction: "" or "none" (overwrite),
// "add", "mul", "max", "min"; anything else throws at code-generation time.
// NOTE(review): its initialization is not visible in this rendering — presumably
// set from the constructor's `reduction` argument; confirm.
24 std::string fReduction;
25
// Shapes of X and I, cached by Initialize.
26 std::vector<size_t> fShapeX;
27 std::vector<size_t> fShapeI;
// Output shape. Generate treats an empty fShapeY as "Initialize not called".
// NOTE(review): where it is assigned is not visible here (presumably Initialize) — confirm.
28 std::vector<size_t> fShapeY;
29
30 // define reduction function. Possibilities are:
31 // none (default), add, mul, max, min
32 std::string ReductionFunction(const std::string & t1, const std::string & t2 ) {
33 std::string name = fReduction;
34 if (name.empty() || name == "none")
35 return t2;
36 else if (name == "add")
37 return t1 + " + " + t2;
38 else if (name == "mul")
39 return t1 + " * " + t2;
40 else if (name == "max")
41 return "std::max(" + t1 + "," + t2 + ")";
42 else if (name == "min")
43 return "std::min(" + t1 + "," + t2 + ")";
44 else
45 throw std::runtime_error("TMVA SOFIE ScatterElements : invalid reduction attribute");
46
47 return std::string();
48 }
49
50public:
// Construct a ScatterElements operator.
//   nameX     - data input tensor name
//   nameI     - indices input tensor name
//   nameU     - updates input tensor name
//   nameY     - output tensor name
//   axis      - scatter axis; may be negative (normalized in Initialize)
//   reduction - reduction mode string (see ReductionFunction)
// All tensor names are normalized with UTILITY::Clean_name.
// NOTE(review): in this rendering the initializer list ends with a trailing comma
// and the body is empty because Doxygen lines 57 and 59-60 are absent — presumably
// they initialize fReduction and register input/output tensor names; confirm
// against the real header.
52 ROperator_ScatterElements(const std::string & nameX, const std::string & nameI, const std::string & nameU, const std::string & nameY,
53 int axis, std::string reduction):
54 fAxis(axis),
55 fNX(UTILITY::Clean_name(nameX)), fNI(UTILITY::Clean_name(nameI)), fNU(UTILITY::Clean_name(nameU)),
56 fNY(UTILITY::Clean_name(nameY)),
58 {
61 }
62
63 // type of output given input
64 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
65 return input;
66 }
67
68 // shape of output tensors given input tensors
69 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
70 auto ret = std::vector<std::vector<size_t>>(1, input[0]); // return vector size 1 with first input
71 return ret;
72 }
73
// Validate the operator's inputs against `model` and cache their shapes.
// Throws std::runtime_error if:
//   - X, I or U is not already present in the model,
//   - U's shape differs from I's shape,
//   - X has rank 0,
//   - I's rank differs from X's rank.
// A negative fAxis is normalized by adding the rank of X.
// NOTE(review): the "not found" messages append "is not found" directly to the
// tensor name with no separating space (e.g. "...Tensor Xis not found in model").
74 void Initialize(RModel& model) override {
75 // input must be a graph input, or already initialized intermediate tensor
76 if (!model.CheckIfTensorAlreadyExist(fNX)){
77 throw std::runtime_error(std::string("TMVA SOFIE ScatterElements Op Input Tensor ") + fNX + "is not found in model");
78 }
79 if (!model.CheckIfTensorAlreadyExist(fNI)) {
80 throw std::runtime_error(std::string("TMVA SOFIE ScatterElements Op Input Tensor ") + fNI + "is not found in model");
81 }
82 if (!model.CheckIfTensorAlreadyExist(fNU)) {
83 throw std::runtime_error(std::string("TMVA SOFIE ScatterElements Op Input Tensor ") + fNU + "is not found in model");
84 }
85 //tbd check for constant tensors
86
87 fShapeX = model.GetTensorShape(fNX);
88 fShapeI = model.GetTensorShape(fNI);
// Updates must have exactly the indices' shape: each update element pairs
// with one index element.
89 if (model.GetTensorShape(fNU) != fShapeI)
90 throw std::runtime_error(std::string("TMVA SOFIE ScatterElements - update tensor has invalid shape ")) ;
91 if (fShapeX.size() == 0)
92 throw std::runtime_error(std::string("TMVA SOFIE ScatterElements - input tensor has zero rank ")) ;
// Generate relies on equal ranks: it loops over fShapeI with fShapeY.size() loop levels.
// NOTE(review): per-dimension extents of I are not compared with X. Generate
// indexes the output with the loop indices on the non-axis dimensions, so an
// index dimension larger than X's would write out of bounds.
93 if (fShapeX.size() != fShapeI.size())
94 throw std::runtime_error(std::string("TMVA SOFIE ScatterElements - index tensor has invalid rank ")) ;
95
// NOTE(review): after this normalization fAxis is not checked to lie in
// [0, rank); Generate later reads fShapeY[fAxis] and idx[fAxis], which would be
// out of bounds for an invalid axis.
96 if (fAxis < 0) fAxis += fShapeX.size();
97
// NOTE(review): Doxygen lines 99-100 are absent from this rendering; presumably
// they set fShapeY and register the output tensor (this page references
// RModel::AddIntermediateTensor and RModel::GetTensorType) — confirm.
98 // assume output shape is identical to input shape
101 }
102
103 std::string GenerateInitCode() override {
104 std::stringstream out;
105 return out.str();
106 }
107
// Emit the C++ inference code for this ScatterElements node.
// Returns "" when fIsOutputConstant is set.
// Throws std::runtime_error when fShapeY is empty (Initialize not run).
// The generated code does the following:
//   1. copies tensor X into tensor Y;
//   2. opens one nested loop per dimension, bounded by fShapeI;
//   3. reads the target coordinate along fAxis from tensor I, wrapping
//      negative values by adding fShapeY[fAxis];
//   4. writes, or combines via ReductionFunction, the matching U element
//      into Y.
// NOTE(review): Doxygen lines 118-121 are absent from this rendering. They
// define `strideI`, `strideY` and `length` used below, presumably via
// ComputeStrideFromShape / ConvertShapeToLength (referenced on this page) —
// confirm.
108 std::string Generate(std::string opName) override {
109
110 if (fIsOutputConstant) return "";
111
112 if (fShapeY.empty()) {
113 throw std::runtime_error("TMVA SOFIE ScatterElements Op called to Generate without being initialized first");
114 }
115 std::stringstream out;
116 out << SP << "\n//-------- ScatterElements --- " << opName << "\n";
117
120
122
123 // function to write compute expression for global index from axes indices
// Produces "s0*i0 + s1*i1 + ...", omitting the "1*" factor for unit strides.
// NOTE(review): `assert` needs <cassert>, which is not among the visible includes.
124 auto tensorIndex = [](const std::vector<size_t> & stride, const std::vector<std::string> & idx) {
125 std::stringstream strst;
126 int dims = idx.size();
127 assert (dims == (int) stride.size());
128 for (int i = 0; i < dims; i++) {
129 if (stride[i] != 1)
130 strst << stride[i] << "*" << idx[i];
131 else
132 strst << idx[i];
133 if (i < dims-1)
134 strst << " + ";
135 }
136 return strst.str();
137 };
138
139
140 // copy first input in output (maybe can be avoided??)
141 out << SP << "std::copy(tensor_" << fNX << ", tensor_" << fNX << " + " << length << ", tensor_" << fNY << ");\n";
142
143 // loop on tensor rank
// Loop bounds come from fShapeI. Initialize guarantees rank(I) == rank(X), so
// dims also equals fShapeI.size().
144 int dims = fShapeY.size();
145 std::vector<std::string> idx(dims);
146 for (int i = 0; i < dims; i++) {
147 idx[i] = std::string("i") + std::to_string(i);
148 for (int j = 0; j <= i; j++) out << SP;
149 out << "for (int " << idx[i] << " = 0; " << idx[i] << " < " << fShapeI[i] << "; " << idx[i] << "++) {\n";
150 }
151 // correct index for specific axis
// updateIndex: flat offset of the current element in I (and in U, which has the same shape).
152 for (int j = 0; j <= dims; j++) out << SP;
153 out << "int updateIndex = " << tensorIndex(strideI,idx) << ";\n";
// NOTE(review): the index value is narrowed to int in the generated code. If the
// indices tensor holds int64 values, large values would be truncated. The flat
// offsets are also `int`.
154 for (int j = 0; j <= dims; j++) out << SP;
155 out << "int iAxis = tensor_" << fNI << "[updateIndex];\n";
// NOTE(review): after wrapping negatives, iAxis is not range-checked in the
// generated code; an out-of-range index in I would write outside tensor Y.
156 for (int j = 0; j <= dims; j++) out << SP;
157 out << "if (iAxis < 0) iAxis += " << fShapeY[fAxis] << ";\n";
// From here on the output coordinate along fAxis is the value read from I,
// not the loop index.
158 idx[fAxis] = "iAxis";
159 for (int j = 0; j <= dims; j++) out << SP;
160 out << "int outIndex = " << tensorIndex(strideY, idx) << ";\n";
161 for (int j = 0; j <= dims; j++) out << SP;
162 out << "tensor_" << fNY << "[outIndex] = "
163 << ReductionFunction(std::string("tensor_") + fNY + "[outIndex]", std::string("tensor_") + fNU + "[updateIndex]") << ";\n";
164
// Close the `dims` loops opened above, with matching decreasing indentation.
165 for (int i = dims; i > 0; i--) {
166 for (int j = 0; j < i; j++) out << SP;
167 out << "}\n";
168 }
169 return out.str();
170 }
171
172};
173
174}//SOFIE
175}//Experimental
176}//TMVA
177
178
179#endif //TMVA_SOFIE_ROperator_ScatterElements
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
char name[80]
Definition TGX11.cxx:110
const ETensorType & GetTensorType(std::string name)
Definition RModel.cxx:94
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:227
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:122
const std::vector< size_t > & GetTensorShape(std::string name)
Definition RModel.cxx:56
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input) override
std::string ReductionFunction(const std::string &t1, const std::string &t2)
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input) override
ROperator_ScatterElements(const std::string &nameX, const std::string &nameI, const std::string &nameU, const std::string &nameY, int axis, std::string reduction)
std::vector< std::string_view > fInputTensorNames
Definition ROperator.hxx:46
bool fIsOutputConstant
flag to identify if operator has a constant output (no need to generate code)
Definition ROperator.hxx:44
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:42
std::vector< std::string_view > fOutputTensorNames
Definition ROperator.hxx:47
std::vector< size_t > ComputeStrideFromShape(const std::vector< size_t > &shape)
compute stride of a tensor given its shape (assume layout is row-major)
std::size_t ConvertShapeToLength(std::vector< size_t > shape)
create variable transformations
auto * t1
Definition textangle.C:20