1#ifndef TMVA_SOFIE_ROPERATOR_Concat
2 #define TMVA_SOFIE_ROPERATOR_Concat
32 ROperator_Concat(std::vector<std::string> inputs,
int axis,
int newAxis, std::string output):
40 [](
const std::string& s) -> std::string_view { return s; });
44 std::vector<ETensorType>
TypeInference(std::vector<ETensorType> input)
override {
49 std::vector<std::vector<size_t>>
ShapeInference(std::vector<std::vector<size_t>> inputs)
override {
50 std::vector<std::vector<size_t>>
ret(1);
56 throw std::runtime_error(
"TMVA SOFIE Concat Op - invalid axis value ");
61 for (
size_t i = 0; i < inputs.size(); i++) {
62 if (i > 0 && inputs[i].
size() != inputs[i - 1].
size())
63 throw std::runtime_error(
"TMVA SOFIE Concat Op - input tensors have different shapes " +
65 for (
size_t iaxis = 0; iaxis < inputs[i].size(); iaxis++) {
66 if ((
int)iaxis ==
fAxis)
67 concat_dim += inputs[i][iaxis];
68 else if (i > 0 && inputs[i][iaxis] != inputs[i - 1][iaxis])
69 throw std::runtime_error(
"TMVA SOFIE Concat Op - input tensors have wrong shapes " +
79 std::vector<int> stack;
82 for(
size_t i = 0; i < inputs.size(); i++) {
83 if (i > 0 && inputs[i].
size() != inputs[i-1].
size() )
84 throw std::runtime_error(
"TMVA SOFIE Concat Op - input tensors have different shapes " +
fInputs[i] +
" : " +
86 for (
size_t iaxis = 0; iaxis < inputs[i].size(); iaxis++) {
87 if ((
int) iaxis ==
fAxis)
88 stack.push_back(inputs[i][iaxis]);
90 if (i> 0 && inputs[i][iaxis] != inputs[i-1][iaxis])
91 throw std::runtime_error(
"TMVA SOFIE Concat Op - input tensors have wrong shapes " +
105 std::vector<Dim>
ret(inputs[0].
size());
111 throw std::runtime_error(
"TMVA SOFIE Concat Op - invalid axis value ");
115 for (
size_t i = 0; i < inputs.size(); i++) {
116 if (i > 0 && inputs[i].
size() != inputs[i - 1].
size())
117 throw std::runtime_error(
"TMVA SOFIE Concat Op - input tensors have different shapes " +
fInputs[i] +
" : " +
119 for (
size_t iaxis = 0; iaxis < inputs[i].size(); iaxis++) {
120 if ((
int)iaxis ==
fAxis) {
122 if (concat_dim.
param.empty() && concat_dim.
dim == 0)
123 concat_dim = inputs[i][iaxis];
124 else if (inputs[i][iaxis].isParam || concat_dim.
isParam) {
126 Dim{ concat_dim.
GetVal() + std::string(
" + ") + inputs[i][iaxis].GetVal(),
127 static_cast<size_t>(-1)};
129 concat_dim =
Dim { concat_dim.
dim + inputs[i][iaxis].dim };
133 ret[iaxis] = inputs[i][iaxis];
135 else if ((!inputs[i][iaxis].isParam && !
ret[iaxis].isParam) && (inputs[i][iaxis].dim !=
ret[iaxis].dim)) {
136 throw std::runtime_error(
"TMVA SOFIE Concat Op - input tensors have wrong shapes " +
140 else if (!inputs[i][iaxis].isParam &&
ret[iaxis].isParam){
142 ret[iaxis] = inputs[i][iaxis];
144 else if (inputs[i][iaxis].isParam &&
ret[iaxis].isParam) {
146 auto & dimNames = model.GetDimShapeNames();
147 auto p1 = std::find(dimNames.begin(), dimNames.end(), inputs[i][iaxis].param);
148 auto p2 = std::find(dimNames.begin(), dimNames.end(),
ret[iaxis].param);
149 if (p1 < p2)
ret[iaxis] = inputs[i][iaxis];
154 if (concat_dim.
isParam && concat_dim.
dim ==
static_cast<size_t>(-1))
155 concat_dim =
Dim{ std::string(
"(") + concat_dim.
GetVal() + std::string(
")"), concat_dim.
dim };
167 throw std::runtime_error(
"TMVA SOFIE Concat Op - stacking (i.e. COncatFromSequence with new_axis=1) is not supported ");
174 if (model.CheckIfTensorAlreadyExist(it) ==
false) {
175 throw std::runtime_error(
"TMVA SOFIE Concat Op Input Tensor " + it +
" is not found in model");
184 bool isOutputShape =
false;
187 isOutputShape =
true;
189 for (
auto & input :
fInputs) {
190 if (!model.IsInitializedTensor(input)) {
192 if (!model.IsShapeTensor(input)) {
193 isOutputShape =
false;
202 for (
auto & input :
fInputs) {
203 auto inputData =
static_cast<int64_t*
>(model.GetInitializedTensorData(input).get());
204 auto inputShape = model.GetTensorShape(input);
206 std::copy(inputData, inputData + inputLength, outputData.begin() + offset );
207 offset += inputLength;
209 model.SetNotWritableInitializedTensor(input);
211 model.AddConstantTensor<int64_t>(
fOutput, outputShape, outputData.data());
212 if (model.Verbose()) {
213 std::cout <<
"output of Concat is a constant tensor " <<
ConvertShapeToString(outputShape) <<
" : "
216 }
else if (isOutputShape) {
220 for (
auto & input :
fInputs) {
221 std::vector<Dim> inputData;
222 auto inputShape = model.GetTensorShape(input);
224 if (model.IsShapeTensor(input)) {
225 inputData = model.GetShapeTensorValues(input);
226 }
else if (model.IsInitializedTensor(input)) {
227 inputData.resize(inputLength);
228 auto intData =
static_cast<int64_t*
>(model.GetInitializedTensorData(input).get());
229 for (
size_t i = 0; i < inputData.size(); i++)
230 inputData[i] =
Dim{
static_cast<size_t>(intData[i])};
234 throw std::runtime_error(
"TMVA SOFIE Concat Operator- invalid input type for shape output type");
236 std::copy(inputData.begin(), inputData.end(), outputData.begin() + offset );
237 offset += inputLength;
240 model.AddShapeTensor(
fOutput,outputData,
false);
241 if (model.Verbose()) {
242 std::cout <<
"output of Concat is a shape tensor " <<
ConvertShapeToString(outputShape) <<
" : "
250 if (model.Verbose()) {
256 std::string
Generate(std::string opName)
override {
257 opName =
"op_" + opName;
258 std::stringstream out;
264 throw std::runtime_error(
"TMVA SOFIE Concat called to Generate without being initialized first");
267 bool hasShapeOnes =
true;
268 for(
int i = 0; i<
fAxis; ++i){
270 hasShapeOnes =
false;
274 if (
fAxis == 0 || hasShapeOnes) {
276 for(
size_t i=0; i<
fInputs.size(); ++i) {
278 out <<
SP <<
"TMVA::Experimental::SOFIE::Copy(tensor_" <<
fOutput;
281 offset +=
" + " + length;
282 out <<
", " <<
"tensor_" <<
fInputs[i] <<
", " + length <<
");\n";
288 std::vector<std::vector<Dim>> inStrides(
fInputs.size());
290 for (
auto &s : inStrides) {
294 for (
int i = 0; i <
fAxis; ++i) {
296 out <<
SP <<
"for (size_t i" << i <<
" = 0; i" << i <<
" < " <<
fOutputShape[i].GetVal() <<
"; ++i" << i <<
") {\n";
299 out <<
SP <<
SP <<
SP <<
"int idxOut = ";
300 for (
int k = 0; k <
fAxis; k++) {
301 if (k > 0) out <<
" + ";
302 out << outStride[k].GetVal() <<
"*i" << k;
306 for (
size_t j = 0; j <
fInputs.size(); j++) {
308 out <<
SP <<
SP <<
SP <<
"idxOut += " << inStrides[j-1][
fAxis-1].GetVal() <<
";\n";
309 out <<
SP <<
SP <<
SP <<
"int idxIn" << j <<
" = ";
310 for (
int k = 0; k <
fAxis; k++) {
311 if (k > 0) out <<
" + ";
312 out << inStrides[j][k].GetVal() <<
"*i" << k;
315 out <<
SP <<
SP <<
SP <<
"for (size_t iC = 0; iC < " << inStrides[j][
fAxis-1].GetVal() <<
"; ++iC) {\n";
316 out <<
SP <<
SP <<
SP <<
SP <<
"tensor_" <<
fOutput <<
"[idxOut+iC] = tensor_" <<
fInputs[j] <<
"[idxIn" << j <<
"+iC];\n";
317 out <<
SP <<
SP <<
SP <<
"}\n";
320 for (
int i = 0; i <
fAxis; ++i) {
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
std::vector< std::string > fInputs
std::vector< Dim > ShapeInference(const std::vector< std::vector< Dim > > &inputs, const RModel &model)
std::vector< Dim > fOutputShape
std::vector< std::vector< Dim > > fInputShapes
ROperator_Concat(std::vector< std::string > inputs, int axis, int newAxis, std::string output)
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input) override
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > inputs) override
void Initialize(RModel &model) override
std::string Generate(std::string opName) override
std::vector< std::string_view > fInputTensorNames
bool fIsOutputConstant
flag to identify if operator has a constant output (no need to generate code)
const std::string SP
space used to correctly indent the generated C++ code
std::vector< std::string_view > fOutputTensorNames
std::string Clean_name(std::string input_tensor_name)
std::vector< size_t > ComputeStrideFromShape(const std::vector< size_t > &shape)
compute stride of a tensor given its shape (assume layout is row-major)
std::string ConvertDimShapeToString(const std::vector< Dim > &shape)
std::size_t ConvertShapeToLength(const std::vector< size_t > &shape)
std::string ConvertValuesToString(size_t n, const T *data, size_t maxprint=-1)
std::vector< size_t > ConvertShapeToInt(const std::vector< Dim > &shape)
Convert shape based on Dim to integer format.
std::string ConvertDimShapeToLength(const std::vector< Dim > &shape)
std::string ConvertShapeToString(const std::vector< size_t > &shape)
create variable transformations
std::string GetVal() const