Logo ROOT  
Reference Guide
Loading...
Searching...
No Matches
RFunction.cxx
Go to the documentation of this file.
1#include "TMVA/RModel.hxx"
2#include "TMVA/RFunction.hxx"
3
4
5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
9
10
12 switch(target) {
14 fFuncName = "edge_update";
15 break;
16 }
18 fFuncName = "node_update";
19 break;
20 }
22 fFuncName = "global_update";
23 break;
24 }
25 default:
26 throw std::runtime_error("Invalid target for Update function");
27 }
29 fFunction_block = std::make_unique<RModel>(fFuncName);
30
33 fInputTensors = {"edge","receiver","sender","global"};
35 fInputTensors = {"edge","node","global"};
36 }
37
40 fInputTensors = {"edge"};
41 } else if(fTarget == FunctionTarget::NODES) {
42 fInputTensors = {"node"};
43 } else {
44 fInputTensors = {"global"};
45 }
46 }
47}
48
49// add input tensors, order of provided shapes must be the same as in fInputTensors
50void RFunction_Update::AddInputTensors(const std::vector<std::vector<std::size_t>>& inputShapes) {
51 for(long unsigned int i=0; i<inputShapes.size(); ++i) {
52 fFunction_block->AddInputTensorInfo(fInputTensors[i],ETensorType::FLOAT, inputShapes[i]);
53 fFunction_block->AddInputTensorName(fInputTensors[i]);
54 }
55}
56void RFunction_Update::AddInputTensors(const std::vector<std::vector<Dim>>& inputShapes) {
57 for(long unsigned int i=0; i<inputShapes.size(); ++i) {
58 fFunction_block->AddInputTensorInfo(fInputTensors[i],ETensorType::FLOAT, inputShapes[i]);
59 fFunction_block->AddInputTensorName(fInputTensors[i]);
60 }
61}
62
63std::string RFunction_Update::GenerateModel(const std::string& filename, long read_pos, long block_size, bool verbose) {
64 fFunction_block->SetFilename(filename);
65 // use batch size as block size in RModel::generate
66 fFunction_block->Generate(Options::kGNNComponent,block_size,read_pos, verbose);
67 std::string modelGenerationString;
68 modelGenerationString = "\n//--------- GNN_Update_Function---"+fFuncName+"\n"+fFunction_block->ReturnGenerated();
69 return modelGenerationString;
70}
71
72std::string RFunction_Update::Generate(const std::vector<std::string>& inputs) {
73 std::string inferFunc = fFuncName+".infer(";
74 for(auto&it : inputs) {
75 inferFunc+=it;
76 inferFunc+=",";
77 }
78 inferFunc.pop_back();
79 inferFunc+=");";
80 return inferFunc;
81}
82
83// passing as input a vector of strings for each input tensor
84std::string RFunction_Aggregate::Generate(std::size_t num_features, const std::vector<std::string>& inputTensors) {
85 std::string inferFunc = fFuncName+"("+std::to_string(num_features)+",{";
86 for(auto&it : inputTensors) {
87 inferFunc+=it;
88 inferFunc+=",";
89 }
90 inferFunc.pop_back();
91 inferFunc+="});";
92 return inferFunc;
93}
94
95// here passing directly the name of the vector containing the input tensor
96std::string RFunction_Aggregate::Generate(std::size_t num_features, const std::string & inputTensors) {
97 std::string inferFunc = fFuncName + "(" +std::to_string(num_features) + "," + inputTensors + ")";
98 return inferFunc;
99}
100
101
102
103
104}
105}
106}
std::string Generate(std::size_t num_features, const std::vector< std::string > &inputTensors)
Definition RFunction.cxx:84
std::string GenerateModel(const std::string &filename, long read_pos=0, long block_size=-1, bool verbose=false)
Definition RFunction.cxx:63
std::shared_ptr< RModel > fFunction_block
Definition RFunction.hxx:35
void AddInputTensors(const std::vector< std::vector< std::size_t > > &inputShapes)
Definition RFunction.cxx:50
std::string Generate(const std::vector< std::string > &inputPtrs)
Definition RFunction.cxx:72
std::vector< std::string > fInputTensors
Definition RFunction.hxx:38
create variable transformations