Logo ROOT  
Reference Guide
Loading...
Searching...
No Matches
RFunction.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_RFUNCTION
2#define TMVA_SOFIE_RFUNCTION
3
6
7#include <memory>
8#include <string>
9
10namespace TMVA {
11namespace Experimental {
12namespace SOFIE {
13
14class RModel;
15
16
17class RFunction {
18protected:
19 std::string fFuncName;
21public:
23 virtual ~RFunction() {}
25 return fType;
26 }
27
28 RFunction(std::string funcName, FunctionType type):
29 fFuncName(UTILITY::Clean_name(funcName)),fType(type) {}
30
31};
32
34protected:
35 std::shared_ptr<RModel> fFunction_block;
38 std::vector<std::string> fInputTensors;
39 std::vector<ROperator*> fAddlOp; // temporary vector to store pointer that will be moved in a unique_ptr
40
41public:
42 virtual ~RFunction_Update() {}
45
46 virtual void AddInitializedTensors(const std::vector<std::vector<std::string>>&) {};
47 virtual void Initialize() {};
48 virtual void AddLayerNormalization(int, float, size_t, const std::string&,
49 const std::string&, const std::string&, const std::string&) {};
50 void AddInputTensors(const std::vector<std::vector<std::size_t>>& inputShapes);
51 void AddInputTensors(const std::vector<std::vector<Dim>>& inputShapes);
52 std::shared_ptr<RModel> GetFunctionBlock() {
53 return fFunction_block;
54 }
55 std::string GenerateModel(const std::string& filename, long read_pos = 0, long block_size = -1, bool verbose = false);
56 std::string Generate(const std::vector<std::string>& inputPtrs);
60};
61
63protected:
65public:
71 virtual std::string GenerateModel() = 0;
72 std::string GetFunctionName() {
73 return fFuncName;
74 }
78 std::string Generate(std::size_t num_features, const std::vector<std::string>& inputTensors);
79 std::string Generate(std::size_t num_features, const std::string & inputTensors);
80
81};
82
83
84}//SOFIE
85}//Experimental
86}//TMVA
87
88
89#endif //TMVA_SOFIE_RFUNCTION
std::string Generate(std::size_t num_features, const std::vector< std::string > &inputTensors)
Definition RFunction.cxx:84
std::shared_ptr< RModel > fFunction_block
Definition RFunction.hxx:35
virtual void AddLayerNormalization(int, float, size_t, const std::string &, const std::string &, const std::string &, const std::string &)
Definition RFunction.hxx:48
void AddInputTensors(const std::vector< std::vector< std::size_t > > &inputShapes)
Definition RFunction.cxx:50
std::shared_ptr< RModel > GetFunctionBlock()
Definition RFunction.hxx:52
virtual void AddInitializedTensors(const std::vector< std::vector< std::string > > &)
Definition RFunction.hxx:46
std::string Generate(const std::vector< std::string > &inputPtrs)
Definition RFunction.cxx:72
std::vector< std::string > fInputTensors
Definition RFunction.hxx:38
RFunction(std::string funcName, FunctionType type)
Definition RFunction.hxx:28
create variable transformations