Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_Add.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_ADD
2#define TMVA_SOFIE_ROPERATOR_ADD
3
5#include "TMVA/ROperator.hxx"
6#include "TMVA/RModel.hxx"
7
8#include <sstream>
9
10namespace TMVA{
11namespace Experimental{
12namespace SOFIE{
13
14template <typename T>
15class ROperator_Add final : public ROperator
16{
17
18private:
19
20 std::string fNX1;
21 std::string fNX2;
22 std::string fNY;
23 std::vector<size_t> fShape;
24
25public:
27 ROperator_Add(std::string nameX1, std::string nameX2, std::string nameY):
28 fNX1(UTILITY::Clean_name(nameX1)), fNX2(UTILITY::Clean_name(nameX2)), fNY(UTILITY::Clean_name(nameY)){}
29
30 // type of output given input
31 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
32 return input;
33 }
34
35 // shape of output tensors given input tensors
36 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
37 // assume now inputs have same shape (no broadcasting)
38 auto ret = std::vector<std::vector<size_t>>(1, input[0]); // return vector size 1 with first input
39 return ret;
40 }
41
42 void Initialize(RModel& model){
43 // input must be a graph input, or already initialized intermediate tensor
44 if (model.CheckIfTensorAlreadyExist(fNX1) == false){
45 throw std::runtime_error(std::string("TMVA SOFIE Add Op Input Tensor ") + fNX1 + "is not found in model");
46 }
47 if (model.CheckIfTensorAlreadyExist(fNX2) == false) {
48 throw std::runtime_error(std::string("TMVA SOFIE Add Op Input Tensor ") + fNX1 + "is not found in model");
49 }
50 auto shapeX1 = model.GetTensorShape(fNX1);
51 auto shapeX2 = model.GetTensorShape(fNX2);
52 // assume same shape X1 and X2
53 if (shapeX1 != shapeX2) {
54 std::string msg = "TMVA SOFIE Add Op: Support only inputs with same shape, shape 1 is " +
55 ConvertShapeToString(shapeX1) + "shape 2 is " + ConvertShapeToString(shapeX2);
56 throw std::runtime_error(msg);
57 }
58 fShape = shapeX1;
60 }
61
62
63 std::string Generate(std::string OpName){
64 OpName = "op_" + OpName;
65 if (fShape.empty()) {
66 throw std::runtime_error("TMVA SOFIE Add called to Generate without being initialized first");
67 }
68 std::stringstream out;
69 // int length = 1;
70 // for(auto& i: fShape){
71 // length *= i;
72 // }
73 size_t length = ConvertShapeToLength(fShape);
74 out << "\n//------ Add\n";
75 out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
76 out << SP << SP << "tensor_" << fNY << "[id] = tensor_" << fNX1 << "[id] + tensor_" << fNX2 << "[id];\n";
77 out << SP << "}\n";
78 return out.str();
79 }
80
81};
82
83}//SOFIE
84}//Experimental
85}//TMVA
86
87
88#endif //TMVA_SOFIE_ROPERATOR_ADD
const ETensorType & GetTensorType(std::string name)
Definition RModel.cxx:70
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape)
Definition RModel.cxx:136
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:91
const std::vector< size_t > & GetTensorShape(std::string name)
Definition RModel.cxx:49
ROperator_Add(std::string nameX1, std::string nameX2, std::string nameY)
std::string Generate(std::string OpName)
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input)
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input)
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:39
std::string ConvertShapeToString(std::vector< size_t > shape)
std::size_t ConvertShapeToLength(std::vector< size_t > shape)
create variable transformations