1#ifndef TMVA_SOFIE_ROPERATOR_Split
2#define TMVA_SOFIE_ROPERATOR_Split
23 std::vector<std::string>
fNYs;
32 ROperator_Split(
const std::string & nameX,
const std::string & nameS,
int axis,
const std::vector<std::string> & namesY):
34 fNYs.reserve(namesY.size());
41 [](
const std::string& s) -> std::string_view { return s; });
44 std::vector<ETensorType>
TypeInference(std::vector<ETensorType> input)
override {
48 std::vector<std::vector<size_t>>
ShapeInference(std::vector<std::vector<size_t>> input)
override {
54 if (model.CheckIfTensorAlreadyExist(
fNX) ==
false){
55 throw std::runtime_error(
"TMVA SOFIE Split Op Input Tensor is not found in model");
62 throw std::runtime_error(
"TMVA SOFIE Split - invalid axis " + std::to_string(
fAxis));
66 throw std::runtime_error(
"TMVA SOFIE Split - splitting in dynamic axis is not supported");
71 size_t nsplit =
fNYs.size();
74 int64_t splitValue = 0;
75 if (origValue % nsplit == 0) {
76 splitValue = origValue/nsplit;
77 fSplit = std::vector<int64_t>(nsplit, splitValue);
80 splitValue = std::ceil(
double(origValue)/nsplit);
81 fSplit = std::vector<int64_t>(nsplit-1, splitValue);
82 fSplit.push_back(origValue % splitValue);
87 if (!model.IsInitializedTensor(
fNSplit))
88 throw std::runtime_error(
"TMVA SOFIE Split - non-initialized split tensors are not supported");
89 auto splitShape = model.GetTensorShape(
fNSplit);
90 if (splitShape.size() != 1 || splitShape[0] != nsplit)
91 throw std::runtime_error(
"TMVA SOFIE Split - split input tensor has invalid shape");
92 auto split_data =
static_cast<int64_t *
>(model.GetInitializedTensorData(
fNSplit).get());
93 fSplit = std::vector<int64_t>(split_data, split_data + nsplit);
97 for (
size_t i = 0; i <
fNYs.size(); i++) {
101 model.AddIntermediateTensor(
fNYs[i], model.GetTensorType(
fNX), outputShape);
104 if (tot_split != origValue)
105 throw std::runtime_error(
"TMVA SOFIE Split - Sum of split sizes must match the input dimension along the axis");
108 if (model.Verbose()) {
112 std::cout << std::endl;
117 std::string
Generate(std::string OpName)
override {
118 OpName =
"op_" + OpName;
120 throw std::runtime_error(
"TMVA SOFIE Operator Split called to Generate without being initialized first");
126 std::stringstream out;
127 out <<
"\n" <<
SP <<
"//------ Split\n";
128 out <<
SP <<
"size_t " << OpName <<
"_axis_offset = 0;\n";
130 for (
size_t i = 0; i <
fNYs.size(); i++) {
134 out <<
SP <<
"for (int id = 0; id < " << length <<
" ; id++){\n";
136 out <<
SP <<
SP <<
"int input_index = 0;\n";
137 out <<
SP <<
SP <<
"int remaining = id;\n";
140 out <<
SP <<
SP <<
"// dim " << k <<
"\n";
142 out <<
SP <<
SP <<
"input_index += (int(remaining / " << output_strides[k] <<
")";
144 if (k ==
static_cast<size_t>(
fAxis) && i > 0)
145 out <<
" + " << OpName <<
"_axis_offset";
146 out <<
") * " << input_strides[k] <<
";\n";
147 out <<
SP <<
SP <<
"remaining %= " << output_strides[k] <<
";\n";
150 out <<
SP <<
SP <<
"input_index += remaining";
151 if (k ==
static_cast<size_t>(
fAxis) && i > 0)
152 out <<
" + " << OpName <<
"_axis_offset";
157 out <<
SP <<
SP <<
"tensor_" <<
fNYs[i] <<
"[id] = tensor_" <<
fNX <<
"[input_index];\n";
159 if (i <
fNYs.size()-1) out <<
SP << OpName <<
"_axis_offset += " <<
fSplit[i] <<
";\n";
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
std::vector< std::vector< Dim > > fOutputShapes
std::string Generate(std::string OpName) override
ROperator_Split(const std::string &nameX, const std::string &nameS, int axis, const std::vector< std::string > &namesY)
void Initialize(RModel &model) override
std::vector< Dim > fInputShape
std::vector< std::string > fNYs
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input) override
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input) override
std::vector< int64_t > fSplit
std::vector< std::string_view > fInputTensorNames
const std::string SP
space used to correctly indent the generated C++ code
std::vector< std::string_view > fOutputTensorNames
std::string Clean_name(std::string input_tensor_name)
std::vector< size_t > ComputeStrideFromShape(const std::vector< size_t > &shape)
compute stride of a tensor given its shape (assume layout is row-major)
std::string ConvertDimShapeToString(const std::vector< Dim > &shape)
std::string ConvertDimShapeToLength(const std::vector< Dim > &shape)
create variable transformations