Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_Custom.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_Custom
2#define TMVA_SOFIE_ROPERATOR_Custom
3
4
6#include "TMVA/ROperator.hxx"
7#include "TMVA/RModel.hxx"
8
9namespace TMVA{
10namespace Experimental{
11namespace SOFIE{
12
13
14template<typename T>
15class ROperator_Custom final : public ROperator
16{
17
18private:
19 std::string fOpName;
20 std::vector<std::string> fInputNames;
21 std::vector<std::string> fOutputNames;
22 std::vector<std::vector<std::size_t>> fOutputShapes;
23 std::string fHeaderName;
24
25public:
27 ROperator_Custom(std::string OpName, std::vector<std::string>Inputs, std::vector<std::string>Outputs, std::vector<std::vector<std::size_t>> OutputShapes, std::string HeaderName){
28 fOpName = OpName;
29 fOutputShapes = OutputShapes;
30 fHeaderName = HeaderName;
31 for(auto& it:Inputs){
32 fInputNames.emplace_back(UTILITY::Clean_name(it));
33 }
34 for(auto& it:Outputs){
35 fOutputNames.emplace_back(UTILITY::Clean_name(it));
36 }
37 }
38
39 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>>) {return {{}};};
40 std::vector<ETensorType> TypeInference(std::vector<ETensorType>){ return {};};
41
42 void Initialize(RModel& model){
44 for(auto& it:fInputNames){
45 if (model.CheckIfTensorAlreadyExist(it) == false){
46 throw std::runtime_error("TMVA SOFIE Custom " + fOpName + " Op Input Tensor " + it + " is not found in model");
47 }
48 }
49
50 if(fOutputNames.size() != fOutputShapes.size()){
51 throw std::runtime_error("TMVA SOFIE Custom "+ fOpName + " Op was not intialized with the names/shapes of all the output tensors");
52 }
53
54 for(long unsigned int i=0; i<fOutputNames.size(); ++i){
56 }
58 if (model.Verbose()) {
59 std::cout << "Custom operator using " << fHeaderName;
60 for (auto & i : fInputNames) std::cout << " " << i;
61 std::cout << " ---> ";
62 for (auto & i : fOutputNames) std::cout << " " << i;
63 std::cout << "\n";
64 }
65 }
66
67 std::string Generate(std::string OpName){
68 OpName = "op_" + OpName;
69 std::stringstream out;
70 out << "\n//------ "<<fOpName<<" \n";
71 std::string args;
72 for(long unsigned int i = 0; i<fInputNames.size(); ++i){
73 args+="fTensor_"+fInputNames[i]+",";
74 }
75
76 for(long unsigned int i = 0; i<fOutputNames.size(); ++i){
77 args+="fTensor_"+fOutputNames[i]+",";
78 }
79 args.pop_back();
80 out << SP << fOpName<<"::Compute("+args+");\n";
81 return out.str();
82 }
83
84};
85
86
87}//SOFIE
88}//Experimental
89}//TMVA
90
91
92#endif //TMVA_SOFIE_ROPERATOR_Custom
void AddNeededCustomHeader(std::string filename)
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:203
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:122
void UpdateOutputTensorList(std::vector< std::string > curr_output_tensor, std::vector< std::string > modify_output_tensor)
Definition RModel.cxx:248
ROperator_Custom(std::string OpName, std::vector< std::string >Inputs, std::vector< std::string >Outputs, std::vector< std::vector< std::size_t > > OutputShapes, std::string HeaderName)
std::string Generate(std::string OpName)
std::vector< ETensorType > TypeInference(std::vector< ETensorType >)
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > >)
std::vector< std::vector< std::size_t > > fOutputShapes
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:41
std::string Clean_name(std::string input_tensor_name)
create variable transformations