Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RFunction.cxx
Go to the documentation of this file.
1#include "TMVA/RModel.hxx"
2#include "TMVA/RFunction.hxx"
3
4
5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
9
10
12 switch(target) {
14 fFuncName = "edge_update";
15 break;
16 }
18 fFuncName = "node_update";
19 break;
20 }
22 fFuncName = "global_update";
23 break;
24 }
25 default:
26 throw std::runtime_error("Invalid target for Update function");
27 }
29 function_block = std::make_unique<RModel>(fFuncName);
30
33 fInputTensors = {"edge","receiver","sender","global"};
35 fInputTensors = {"edge","node","global"};
36 }
37
40 fInputTensors = {"edge"};
41 } else if(fTarget == FunctionTarget::NODES) {
42 fInputTensors = {"node"};
43 } else {
44 fInputTensors = {"global"};
45 }
46 }
47}
48
49// add inpuit tensors, order of provided shapes must be the same as in fInputTensors
50void RFunction_Update::AddInputTensors(const std::vector<std::vector<std::size_t>>& inputShapes) {
51 for(long unsigned int i=0; i<inputShapes.size(); ++i) {
52 function_block->AddInputTensorInfo(fInputTensors[i],ETensorType::FLOAT, inputShapes[i]);
53 function_block->AddInputTensorName(fInputTensors[i]);
54 }
55}
56void RFunction_Update::AddInputTensors(const std::vector<std::vector<Dim>>& inputShapes) {
57 for(long unsigned int i=0; i<inputShapes.size(); ++i) {
58 function_block->AddInputTensorInfo(fInputTensors[i],ETensorType::FLOAT, inputShapes[i]);
59 function_block->AddInputTensorName(fInputTensors[i]);
60 }
61}
62
63std::string RFunction_Update::GenerateModel(const std::string& filename, long read_pos, long block_size) {
64 function_block->SetFilename(filename);
65 // use batch size as block size in RModel::generate
66 function_block->PrintRequiredInputTensors();
67 function_block->PrintDynamicTensors();
68 function_block->Generate(Options::kGNNComponent,block_size,read_pos);
69 std::string modelGenerationString;
70 modelGenerationString = "\n//--------- GNN_Update_Function---"+fFuncName+"\n"+function_block->ReturnGenerated();
71 return modelGenerationString;
72}
73
74std::string RFunction_Update::Generate(const std::vector<std::string>& inputs) {
75 std::string inferFunc = fFuncName+".infer(";
76 for(auto&it : inputs) {
77 inferFunc+=it;
78 inferFunc+=",";
79 }
80 inferFunc.pop_back();
81 inferFunc+=");";
82 return inferFunc;
83}
84
85// passing as input a vector of strings for each input tensor
86std::string RFunction_Aggregate::Generate(std::size_t num_features, const std::vector<std::string>& inputTensors) {
87 std::string inferFunc = fFuncName+"("+std::to_string(num_features)+",{";
88 for(auto&it : inputTensors) {
89 inferFunc+=it;
90 inferFunc+=",";
91 }
92 inferFunc.pop_back();
93 inferFunc+="});";
94 return inferFunc;
95}
96
97// here passing directly the name of the vector containing the input tensor
98std::string RFunction_Aggregate::Generate(std::size_t num_features, const std::string & inputTensors) {
99 std::string inferFunc = fFuncName + "(" +std::to_string(num_features) + "," + inputTensors + ")";
100 return inferFunc;
101}
102
103
104
105
106}
107}
108}
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t target
std::string Generate(std::size_t num_features, const std::vector< std::string > &inputTensors)
Definition RFunction.cxx:86
void AddInputTensors(const std::vector< std::vector< std::size_t > > &inputShapes)
Definition RFunction.cxx:50
std::string GenerateModel(const std::string &filename, long read_pos=0, long block_size=-1)
Definition RFunction.cxx:63
std::shared_ptr< RModel > function_block
Definition RFunction.hxx:35
std::string Generate(const std::vector< std::string > &inputPtrs)
Definition RFunction.cxx:74
std::vector< std::string > fInputTensors
Definition RFunction.hxx:38
create variable transformations
static int gType