44 if (!model.CheckIfTensorAlreadyExist(
fNX)) {
45 throw std::runtime_error(
"TMVA SOFIE Expand Op Input Tensor " +
fNX +
" is not found in model");
48 if (model.IsInitializedTensor(
fNShape)) {
51 static_cast<int64_t *
>(model.GetInitializedTensorData(
fNShape).get());
54 throw std::runtime_error(
"TMVA::SOFIE - Expand operator shape must be a 1d tensor.");
58 for (
size_t i = 0; i <
N; i++) {
59 if ( shapeData[i] < 0)
60 throw std::runtime_error(
"TMVA::SOFIE - Expand: invalid shape value " + std::to_string(shapeData[i]));
62 std::vector<size_t> shape(shapeData, shapeData +
N);
64 }
else if (model.IsShapeTensor(
fNShape)) {
70 auto shapeOfInputShape = model.GetTensorShape(
fNShape);
72 for (
size_t i = 0; i <
fShapeDim.size(); i++) {
81 std::vector<size_t> shapeX;
82 std::vector<size_t> shapeY;
92 assert(!shapeX.empty() && !shapeY.empty());
95 auto data = model.GetInitializedTensorData(
fNX);
97 std::shared_ptr<void> broadcastedData(
99 std::default_delete<T[]>());
101 model.UpdateInitializedTensor(
fNX, model.GetTensorType(
fNX), shapeY, broadcastedData);
104 model.SetNotWritableInitializedTensor(
fNX);
105 data = broadcastedData;
109 model.AddConstantTensor(
fNY, model.GetTensorType(
fNX), shapeY, data);
112 model.AddIntermediateTensor(
fNY, model.GetTensorType(
fNX), shapeY);
116 model.AddIntermediateTensor(
fNY, model.GetTensorType(
fNX),
fShapeY);
119 if (model.Verbose()) {