Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_Split.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_Split
2#define TMVA_SOFIE_ROPERATOR_Split
3
5#include "TMVA/ROperator.hxx"
6#include "TMVA/RModel.hxx"
7
8#include <sstream>
9
10namespace TMVA{
11namespace Experimental{
12namespace SOFIE{
13
14
16{
17
18private:
19
/// split axis; a negative value is normalized (counted from the last dimension) in Initialize()
20 int fAxis = 0;
/// cleaned name of the tensor to split
21 std::string fNX;
/// cleaned name of the optional split-sizes tensor; empty means the axis is split as evenly as possible
22 std::string fNSplit;
/// cleaned names of the output tensors, one per chunk
23 std::vector<std::string> fNYs;
/// shape of the input tensor; dimensions may be parametric (see the isParam check in Initialize())
24 std::vector<Dim> fInputShape;
/// size of every chunk along fAxis, filled by Initialize()
25 std::vector<int64_t> fSplit;
/// shape of every output tensor, filled by Initialize()
26 std::vector<std::vector<Dim>> fOutputShapes;
27
28
29
30public:
32 ROperator_Split(const std::string & nameX, const std::string & nameS, int axis, const std::vector<std::string> & namesY):
33 fAxis(axis), fNX(UTILITY::Clean_name(nameX)), fNSplit(UTILITY::Clean_name(nameS)) {
34 fNYs.reserve(namesY.size());
35 for (auto & name : namesY)
36 fNYs.push_back(UTILITY::Clean_name(name));
37
39 fOutputTensorNames.resize(fNYs.size());
40 std::transform(fNYs.begin(), fNYs.end(), fOutputTensorNames.begin(),
41 [](const std::string& s) -> std::string_view { return s; });
42 }
43
44 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
45 return input;
46 }
47
48 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
49 auto ret = input; //suggest copy to compiler
50 return ret;
51 }
52
/// Resolve the operator against the model.
/// Steps: check that the input tensor exists, normalize the split axis, determine
/// the size of every chunk (fSplit), build the output shapes (fOutputShapes) and
/// check that the chunk sizes add up to the axis length.
/// Without a split tensor (fNSplit empty), the axis length is divided evenly when
/// possible. Otherwise it becomes nsplit-1 chunks of ceil(length/nsplit) plus one
/// remainder chunk. With a split tensor, the sizes are read from its initialized data.
/// @param model the model that owns the input tensor and the optional split tensor
/// @throws std::runtime_error if any of the following holds:
///         - the input tensor is missing;
///         - the axis is out of range or has a parametric size;
///         - the split tensor is not initialized or has the wrong shape;
///         - the chunk sizes do not add up to the axis length.
53 void Initialize(RModel& model) override {
54 if (model.CheckIfTensorAlreadyExist(fNX) == false){ //input must be a graph input, or already initialized intermediate tensor
55 throw std::runtime_error("TMVA SOFIE Split Op Input Tensor is not found in model");
56 }
// NOTE(review): fInputShape is read from here on but is not assigned anywhere in
// the code shown, and the listing numbering skips line 57 at this point.
// Presumably the input shape is fetched from the model here (GetDimTensorShape
// is referenced on this page); confirm against upstream.
58
59 // correct for negative axis: a negative value counts from the last dimension
60 if (fAxis < 0) fAxis += fInputShape.size();
61 if (fAxis < 0 || fAxis >= static_cast<int>(fInputShape.size()) )
62 throw std::runtime_error("TMVA SOFIE Split - invalid axis " + std::to_string(fAxis));
63
64 // for the time being, splitting is supported only along axes with a fixed (non-parametric) size
65 if (fInputShape[fAxis].isParam)
66 throw std::runtime_error("TMVA SOFIE Split - splitting in dynamic axis is not supported");
67
68 size_t origValue = fInputShape[fAxis].dim;
69
70 // compute output shapes
71 size_t nsplit = fNYs.size();
// NOTE(review): with no outputs (nsplit == 0) the modulo below divides by zero.
72 // case: no split tensor given (fNSplit is empty)
73 if (fNSplit.empty()) {
74 int64_t splitValue = 0;
75 if (origValue % nsplit == 0) {
// NOTE(review): in the code shown, splitValue is still 0 here. Every chunk would
// therefore be 0, and the sum check below would throw for a non-empty axis. The
// listing numbering skips line 76, which presumably sets
// splitValue = origValue / nsplit; confirm against upstream.
77 fSplit = std::vector<int64_t>(nsplit, splitValue);
78 } else {
79 // case of uneven splitting: nsplit-1 chunks of ceil(origValue/nsplit), last chunk origValue % splitValue
// e.g. 7 into 3 gives 3,3,1. 5 into 4 gives 2,2,2,1, which sums to 7 and is
// rejected by the sum check below.
80 splitValue = std::ceil(double(origValue)/nsplit);
81 fSplit = std::vector<int64_t>(nsplit-1, splitValue);
82 fSplit.push_back(origValue % splitValue);
83 }
84 } else {
85 // NB : in this case we could support dynamic split axes
86 // get split tensor values
87 if (!model.IsInitializedTensor(fNSplit))
88 throw std::runtime_error("TMVA SOFIE Split - non-initialized split tensors are not supported");
89 auto splitShape = model.GetTensorShape(fNSplit);
90 if (splitShape.size() != 1 || splitShape[0] != nsplit)
91 throw std::runtime_error("TMVA SOFIE Split - split input tensor has invalid shape");
// NOTE(review): the data is reinterpreted as int64_t without checking the element
// type of fNSplit. Negative entries are not rejected either: they wrap when
// converted to size_t below and can even pass the size_t sum check
// (e.g. {-1, 6} for an axis of length 5).
92 auto split_data = static_cast<int64_t *>(model.GetInitializedTensorData(fNSplit).get());
93 fSplit = std::vector<int64_t>(split_data, split_data + nsplit);
94 }
95 // compute now the output shapes
// NOTE(review): fOutputShapes is appended to and never cleared, so a second call
// to Initialize() would leave duplicate entries.
96 size_t tot_split = 0;
97 for (size_t i = 0; i < fNYs.size(); i++) {
98 std::vector<Dim> outputShape = fInputShape;
99 outputShape[fAxis] = Dim{ static_cast<size_t>(fSplit[i]) };
100 tot_split += fSplit[i];
// NOTE(review): the listing numbering skips line 101 here, and in the code shown
// the output tensors are never registered with the model. Presumably each output
// is added at this point (AddIntermediateTensor / GetTensorType are referenced
// on this page); confirm against upstream.
102 fOutputShapes.push_back(outputShape);
103 }
104 if (tot_split != origValue)
105 throw std::runtime_error("TMVA SOFIE Split - Sum of split sizes must match the input dimension along the axis");
106
107
// print the input shape and every output shape when the model is verbose
108 if (model.Verbose()) {
109 std::cout << "Split - input shape " << ConvertShapeToString(fInputShape) << " --> ";
110 for (auto & s : fOutputShapes)
111 std::cout << ConvertShapeToString(s) << " ";
112 std::cout << std::endl;
113 }
114 }
115
116
/// Emit the C++ code that fills every output tensor of the split.
/// The loop over outputs is unrolled at generation time. For each output, a
/// runtime loop maps the flat output index to the matching flat input index and
/// shifts the coordinate along the split axis by the sizes of the preceding
/// chunks. Those sizes accumulate in the generated <OpName>_axis_offset variable.
/// @param OpName operator name; generated variable names are prefixed with "op_" + OpName
/// @return the generated code
/// @throws std::runtime_error if Initialize() has not produced output shapes yet
117 std::string Generate(std::string OpName) override {
118 OpName = "op_" + OpName;
119 if (fOutputShapes.empty()){
120 throw std::runtime_error("TMVA SOFIE Operator Split called to Generate without being initialized first");
121 }
122
// NOTE(review): input_strides (used below) is not declared in the code shown,
// and the listing numbering skips line 123 here. Presumably the strides of
// fInputShape are computed at this point (ComputeStrideFromShape is referenced
// on this page); confirm against upstream.
124
125 // generate now the code for split
126 std::stringstream out;
127 out << "\n" << SP << "//------ Split\n";
128 out << SP << "size_t " << OpName << "_axis_offset = 0;\n";
129 // unroll the loop on split outputs
130 for (size_t i = 0; i < fNYs.size(); i++) {
// NOTE(review): output_strides and length (used below) are not declared in the
// code shown, and the listing numbering skips lines 131-132. Presumably the
// strides and total element count of fOutputShapes[i] are computed here
// (ConvertDimShapeToLength is referenced on this page); confirm upstream.
133
// NOTE(review): the generated code indexes with int (id, input_index, remaining),
// which would overflow for tensors with more than INT_MAX elements.
134 out << SP << "for (int id = 0; id < " << length << " ; id++){\n";
135 // convert output index to input index
136 out << SP << SP << "int input_index = 0;\n";
137 out << SP << SP << "int remaining = id;\n";
138 // unroll the loop over dimensions: decompose id with the output strides, recompose with the input strides
139 for (size_t k = 0; k < fOutputShapes[i].size(); ++k) {
140 out << SP << SP << "// dim " << k << "\n";
141 if (k < fOutputShapes[i].size()-1) {
142 out << SP << SP << "input_index += (int(remaining / " << output_strides[k] << ")";
143 // for the split axis we need to consider the offset in the splits when converting to input coordinates
// (the offset is still 0 for the first output, so the term is only emitted when i > 0)
144 if (k == static_cast<size_t>(fAxis) && i > 0)
145 out << " + " << OpName << "_axis_offset";
146 out << ") * " << input_strides[k] << ";\n";
147 out << SP << SP << "remaining %= " << output_strides[k] << ";\n";
148 } else {
149 // last dimension: its stride is one, so the remainder is the coordinate itself
150 out << SP << SP << "input_index += remaining";
151 if (k == static_cast<size_t>(fAxis) && i > 0)
152 out << " + " << OpName << "_axis_offset";
153 out << ";\n\n";
154 }
155 }
156
157 out << SP << SP << "tensor_" << fNYs[i] << "[id] = tensor_" << fNX <<"[input_index];\n";
158 out << SP << "}\n";
// advance the axis offset by the size of the chunk just written (not needed after the last one)
159 if (i < fNYs.size()-1) out << SP << OpName << "_axis_offset += " << fSplit[i] << ";\n";
160 }
161 return out.str();
162 }
163
164};
165
166}//SOFIE
167}//Experimental
168}//TMVA
169
170
171#endif //TMVA_SOFIE_ROPERATOR_Split
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
char name[80]
Definition TGX11.cxx:110
std::vector< size_t > GetTensorShape(const std::string &name) const
Definition RModel.cxx:29
std::vector< Dim > GetDimTensorShape(const std::string &name) const
Definition RModel.cxx:65
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:247
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:122
bool IsInitializedTensor(const std::string &name) const
Definition RModel.cxx:220
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
Definition RModel.cxx:312
ETensorType GetTensorType(std::string name) const
Definition RModel.cxx:90
std::vector< std::vector< Dim > > fOutputShapes
std::string Generate(std::string OpName) override
ROperator_Split(const std::string &nameX, const std::string &nameS, int axis, const std::vector< std::string > &namesY)
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input) override
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input) override
std::vector< std::string_view > fInputTensorNames
Definition ROperator.hxx:47
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:42
std::vector< std::string_view > fOutputTensorNames
Definition ROperator.hxx:48
std::string Clean_name(std::string input_tensor_name)
std::vector< size_t > ComputeStrideFromShape(const std::vector< size_t > &shape)
Compute the strides of a tensor from its shape (assumes a row-major layout).
std::string ConvertDimShapeToLength(const std::vector< Dim > &shape)
std::string ConvertShapeToString(const std::vector< size_t > &shape)
create variable transformations