ROOT Reference Guide
ROperator_Expand.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROperator_Expand
2#define TMVA_SOFIE_ROperator_Expand
3
5#include "TMVA/ROperator.hxx"
6#include "TMVA/RModel.hxx"
7
8#include <sstream>
9
10namespace TMVA{
11namespace Experimental{
12namespace SOFIE{
13
14template<typename T>
// NOTE(review): original line 15 (the class declaration itself) is missing from this
// Doxygen listing; presumably it declares ROperator_Expand deriving from ROperator
// (the methods below are marked `override` and use ROperator members). Restore from upstream.
16private:
17
// Shape of input X as Dim entries (may contain parametric dimensions).
18 std::vector<Dim> fShapeX;
// Shape of the target-shape input tensor; Initialize requires it to be 1-d.
19 std::vector<size_t> fShape;
// Output shape: the broadcast of X's shape with the requested target shape.
20 std::vector<Dim> fShapeY;
// Target shape values as Dim; when only known at run time these are parameters
// named "v_<shapeTensor>_<i>".
21 std::vector<Dim> fShapeDim;
22
// Cleaned names of the input X, the target-shape tensor and the output Y.
23 std::string fNX;
24 std::string fNShape;
25 std::string fNY;
// Element type name spliced into the generated code (e.g. as the template argument
// of UnidirectionalBroadcast and of std::span).
26 std::string fType;
27
// True when X is an initialized tensor (its data is read in Initialize).
// NOTE(review): the assignment is not visible in this listing.
28 bool fInitialized = false;
// True when the target shape is known at build time (initialized tensor or shape tensor).
29 bool fInitializedShape = false;
// True when Initialize can broadcast X to the output shape at build time.
30 bool fInitBroadcast = false;
31
32public:
// Construct an Expand operator that broadcasts tensor nameX to the shape stored in
// tensor nameShape, writing the result to nameY. All names are sanitized with
// UTILITY::Clean_name.
// NOTE(review): original lines 33 and 36-37 are missing from this listing (the body
// presumably sets fType and the input/output tensor name lists) — restore from upstream.
34 ROperator_Expand(std::string nameX, std::string nameShape, std::string nameY):
35 fNX(UTILITY::Clean_name(nameX)), fNShape(UTILITY::Clean_name(nameShape)), fNY(UTILITY::Clean_name(nameY)){
38 }
39
40
// Validate the inputs and compute the output shape fShapeY as the broadcast of X's
// shape with the requested target shape.
//
// The target shape comes from one of three sources:
//  1. an initialized tensor: its int64 values are read now and must be non-negative;
//  2. a shape tensor: fInitializedShape is set;
//  3. otherwise: one runtime shape parameter "v_<shape>_<i>" is registered per element.
//
// When X is initialized and every shape is known (fInitBroadcast), X is broadcast here
// at build time and the operator output is flagged constant (fIsOutputConstant).
//
// Throws std::runtime_error if X is not in the model, if the target-shape tensor is not
// 1-d, or if it contains a negative value.
//
// NOTE(review): this Doxygen listing drops original lines 46, 51, 62, 65, 69-70, 77, 79,
// 84-86, 100-101, 103-104, 108, 111, 119 and 121 (e.g. the assignments of fShapeX,
// fShape, fShapeDim, `ret`, shapeX/shapeY and the tensor registrations). Restore them
// from upstream before compiling.
41 void Initialize(RModel& model) override {
42 // input must be a graph input, or already initialized intermediate tensor
43 if (!model.CheckIfTensorAlreadyExist(fNX)) {
44 throw std::runtime_error("TMVA SOFIE Expand Op Input Tensor " + fNX + " is not found in model");
45 }
// Case 1: target shape stored as an initialized (int64) tensor.
47 if (model.IsInitializedTensor(fNShape)) {
48 fInitializedShape = true;
49 int64_t *shapeData =
50 static_cast<int64_t *>(model.GetInitializedTensorData(fNShape).get());
52 if (fShape.size() != 1) {
53 throw std::runtime_error("TMVA::SOFIE - Expand operator shape must be a 1d tensor.");
54 }
55 size_t N = fShape[0];
56 // what do we do if shapeData contains negative values?
// Negative target dimensions are rejected outright.
57 for (size_t i = 0; i < N; i++) {
58 if ( shapeData[i] < 0)
59 throw std::runtime_error("TMVA::SOFIE - Expand: invalid shape value " + std::to_string(shapeData[i]));
60 }
61 std::vector<size_t> shape(shapeData, shapeData + N);
63 } else if (model.IsShapeTensor(fNShape)) {
64 // case input shape is a shape tensor
66 fInitializedShape = true;
67 } else {
68 // assume shape of input shape is known (size is 1)
// Case 3: values only known at run time; name each dimension "v_<shape>_<i>" and
// register it as a model shape parameter (Generate later declares these variables).
71 for (size_t i = 0; i < fShapeDim.size(); i++) {
72 fShapeDim[i] = Dim{std::string("v_") + fNShape + "_" + std::to_string(i)};
73 model.AddShapeParam(fShapeDim[i].param);
74 }
75 }
76 // Y is the common shape of fShapeX and shape
78 fShapeY = ret.second;
80 std::vector<size_t> shapeX;
81 std::vector<size_t> shapeY;
82 // case shape tensor and input shape are known
// X has a static (non-dynamic, non-Dim-input) shape and the target shape is known,
// so the broadcast can be done at build time.
83 if (!model.IsDynamicTensor(fNX) && !model.IsDimInputTensor(fNX) && fInitializedShape) {
87 fInitBroadcast = true;
88 }
89 if (fInitialized) {
90 // cannot have Dim initialized tensors
// NOTE(review): shapeX/shapeY start empty above; if they are only filled inside the
// previous block (lines elided in this listing), an initialized X with a target shape
// unknown at build time trips this assert — confirm whether that case is unsupported.
91 assert(!shapeX.empty() && !shapeY.empty());
92 // Broadcast X to the common shape shapeY
93 // If X is an initialized tensor (constant)
94 auto data = model.GetInitializedTensorData(fNX);
95 if (fInitBroadcast) {
// UnidirectionalBroadcast's result is owned by this shared_ptr and released with
// delete[] (std::default_delete<T[]>).
96 std::shared_ptr<void> broadcastedData(
97 UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), shapeX, shapeY),
98 std::default_delete<T[]>());
99 // Update the data and the shape of X
102 // need to set as a not writable tensor
105 }
106 if (fInitBroadcast || model.IsConstantTensor(fNX)) {
107 fIsOutputConstant = true; // constant output in this case
// Output is constant, so Y is removed from this operator's output tensor name list.
109 fOutputTensorNames.pop_back();
110 } else {
112 }
113 } else {
114 // // case input is not initialized
115 // if (shapeX.empty() && shapeDim.empty()) {
116
117 // }
118 // if (fInitializedShape)
120 }
122 if (model.Verbose()) {
123 std::cout << "Expand - input " << fNX << " shape " << ConvertShapeToString(fShapeX) << " --> " << fNY << " shape "
124 << ConvertShapeToString(fShapeY) << (fIsOutputConstant ? ConvertValuesToString(model.GetTensorData<T>(fNY)) + " (constant)" : "") << std::endl;
125 }
126 }
127
// Emit code executed once at session initialization. The visible branch copies the
// initialized input tensor X into the output tensor Y with std::copy (X and Y have the
// same shape in that case). The returned string may be empty.
// NOTE(review): original lines 130 and 132 are missing from this listing. The closing
// brace below implies that 130 opens a conditional, and 132 presumably defines `length`
// (the element count used in std::copy). Restore them from upstream.
128 std::string GenerateInitCode() override {
129 std::stringstream out;
131 // shapeX and shapeY are the same in this case
133 out << "// Copying initialized tensor " << fNX << " to " << fNY << "\n";
134 out << SP << "std::copy(tensor_" << fNX << ", " << "tensor_" << fNX << " + " << length << ", tensor_" << fNY << ");\n";
135 }
136 return out.str();
137 }
138
139 std::string Generate(std::string opName) override {
140 if (fIsOutputConstant) return "";
141 opName = "op_" + opName;
142 if (fShapeY.empty()) {
143 throw std::runtime_error("TMVA SOFIE Expand Op called to Generate without being initialized first");
144 }
145 std::stringstream out;
146 out << SP << "\n//------ Expand " << opName << " --> " << ConvertShapeToString(fShapeY) << "\n";
147 // need to declare shape parameters for non initialized shapes
148 if (!fInitializedShape) {
149 for (size_t i = 0; i < fShapeDim.size(); i++) {
150 out << SP << "size_t " << fShapeDim[i] << " = " << "tensor_" << fNShape << "[" << i << "];\n";
151 }
152 }
153 // No need to broadcast A if it's an initialized tensor or shapes are the same
154 if (!fInitialized && fShapeX != fShapeY) {
155 out << SP << "// Broadcasting uninitialized tensor " << fNX << "\n";
156 out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << fType << ">(tensor_" << fNX << ", " << ConvertShapeToString(fShapeX) << ", " << ConvertShapeToString(fShapeY)
157 << ", std::span<"<<fType<<">(tensor_"<<fNY<<", "<<ConvertDimShapeToLength(fShapeY)<<"));\n";
158 }
159 return out.str();
160 }
161
162};
163
164}//SOFIE
165}//Experimental
166}//TMVA
167
168
169#endif //TMVA_SOFIE_ROperator_Expand
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
#define N
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
void AddShapeParam(const std::string &name, size_t def_value=0)
Definition RModel.cxx:281
std::vector< size_t > GetTensorShape(const std::string &name) const
Definition RModel.cxx:29
std::vector< Dim > GetDimTensorShape(const std::string &name) const
Definition RModel.cxx:65
bool IsDynamicTensor(const std::string &name) const
Definition RModel.cxx:232
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:247
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:122
void AddConstantTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Definition RModel.cxx:193
bool IsDimInputTensor(const std::string &name) const
Definition RModel.cxx:237
bool IsShapeTensor(const std::string &name) const
check if a tensor is a shape tensor
Definition RModel.cxx:211
bool IsInitializedTensor(const std::string &name) const
Definition RModel.cxx:220
bool IsConstantTensor(const std::string &name) const
Definition RModel.cxx:224
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
Definition RModel.cxx:312
std::vector< T > GetTensorData(const std::string &name)
Definition RModel.hxx:242
void SetNotWritableInitializedTensor(const std::string &tensor_name)
Definition RModel.cxx:321
ETensorType GetTensorType(std::string name) const
Definition RModel.cxx:90
void UpdateInitializedTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Definition RModel.cxx:303
const std::vector< Dim > & GetShapeTensorValues(const std::string &tensor_name) const
Definition RModel.cxx:215
ROperator_Expand(std::string nameX, std::string nameShape, std::string nameY)
std::string Generate(std::string opName) override
std::vector< std::string_view > fInputTensorNames
Definition ROperator.hxx:47
bool fIsOutputConstant
flag to identify if operator has a constant output (no need to generate code)
Definition ROperator.hxx:44
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:42
std::vector< std::string_view > fOutputTensorNames
Definition ROperator.hxx:48
bool AreSameShape(const std::vector< size_t > &, const std::vector< size_t > &)
std::vector< size_t > MultidirectionalBroadcastShape(std::vector< std::vector< size_t > >)
std::vector< Dim > ConvertShapeToDim(const std::vector< size_t > &shape)
Convert shape from integer format to dynamic one (based on Dim)
std::string ConvertValuesToString(size_t n, const T *data)
std::vector< size_t > ConvertShapeToInt(const std::vector< Dim > &shape)
Convert shape based on Dim to integer format.
std::string ConvertTypeToString(ETensorType type)
std::string ConvertDimShapeToLength(const std::vector< Dim > &shape)
std::string ConvertShapeToString(const std::vector< size_t > &shape)
create variable transformations