Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR
2#define TMVA_SOFIE_ROPERATOR
3
4#include <vector>
5#include <memory>
6
8//#include "RModel.hxx"
9
10
11
12namespace TMVA{
13namespace Experimental{
14namespace SOFIE{
15
16class RModel;
17
19
20
21public:
22 virtual std::vector<std::string> GetBlasRoutines() { return {}; }
23 virtual std::vector<std::string> GetStdLibs() { return {}; }
24 virtual std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>>) { return {}; };
25 virtual std::vector<ETensorType> TypeInference(std::vector<ETensorType>) { return {}; };
26 virtual void Initialize(RModel&) = 0;
27 virtual std::string Generate(std::string OpName) = 0; //expect unique opName for each operator within the same RModel
28 // generate initialization code for session constructor
29 virtual std::string GenerateInitCode() { return "";}
30 // generate some specific declaration code for Session
31 virtual std::string GenerateDeclCode() { return "";}
32 // generate session data members specific to operator
33 virtual std::string GenerateSessionMembersCode(std::string /*opName*/) { return ""; }
34 virtual std::string Header() { return "";}
35
36 /// check if the output of the operator is Constant and is evaluated at initialization time
37 bool IsOutputConstant() const { return fIsOutputConstant; }
38
39 //virtual void Forward_reference() = 0;
40 //virtual void Forward_blas() = 0;
41 virtual ~ROperator(){}
42
43protected:
44
45 const std::string SP = " "; ///< space used to correctly indent the generated C++ code
46 bool fUseSession = false; ///< flag to identify if using the session class
47 bool fIsOutputConstant = false; ///< flag to identify if operator has a constant output (no need to generate code)
48 bool fIsOutputParamShape = false; ///< flag to identify of the output represents a parametric shape (can be knwon at compile time)
49
50 mutable std::vector<std::string_view> fInputTensorNames;
51 mutable std::vector<std::string_view> fOutputTensorNames;
52
53public:
54 std::span<const std::string_view> GetOpInputTensors() const {
55 return fInputTensorNames;
56 }
57
58 std::span<const std::string_view> GetOpOutputTensors() const {
59 return fOutputTensorNames;
60 }
61
62};
63
64
65
66}//SOFIE
67}//Experimental
68}//TMVA
69
70
71#endif //TMVA_SOFIE_ROPERATOR
std::vector< std::string_view > fInputTensorNames
Definition ROperator.hxx:50
virtual std::vector< std::string > GetBlasRoutines()
Definition ROperator.hxx:22
virtual void Initialize(RModel &)=0
bool fIsOutputParamShape
flag to identify if the output represents a parametric shape (can be known at compile time)
Definition ROperator.hxx:48
bool fIsOutputConstant
flag to identify if operator has a constant output (no need to generate code)
Definition ROperator.hxx:47
virtual std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > >)
Definition ROperator.hxx:24
virtual std::string GenerateInitCode()
Definition ROperator.hxx:29
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:45
virtual std::vector< ETensorType > TypeInference(std::vector< ETensorType >)
Definition ROperator.hxx:25
virtual std::string GenerateSessionMembersCode(std::string)
Definition ROperator.hxx:33
std::span< const std::string_view > GetOpInputTensors() const
Definition ROperator.hxx:54
bool fUseSession
flag to identify if using the session class
Definition ROperator.hxx:46
virtual std::string Generate(std::string OpName)=0
std::span< const std::string_view > GetOpOutputTensors() const
Definition ROperator.hxx:58
virtual std::string GenerateDeclCode()
Definition ROperator.hxx:31
bool IsOutputConstant() const
check if the output of the operator is Constant and is evaluated at initialization time
Definition ROperator.hxx:37
virtual std::vector< std::string > GetStdLibs()
Definition ROperator.hxx:23
std::vector< std::string_view > fOutputTensorNames
Definition ROperator.hxx:51
create variable transformations