Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModel_GNN.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_RMODEL_GNN
2#define TMVA_SOFIE_RMODEL_GNN
3
4#include <ctime>
5
7#include "TMVA/RModel.hxx"
8#include "TMVA/RFunction.hxx"
9
10namespace TMVA {
11namespace Experimental {
12namespace SOFIE {
13
14class RFunction_Update;
15class RFunction_Aggregate;
16
17struct GNN_Init {
18 // updation blocks
19 std::unique_ptr<RFunction_Update> edges_update_block;
20 std::unique_ptr<RFunction_Update> nodes_update_block;
21 std::unique_ptr<RFunction_Update> globals_update_block;
22
23 // aggregation blocks
24 std::unique_ptr<RFunction_Aggregate> edge_node_agg_block;
25 std::unique_ptr<RFunction_Aggregate> edge_global_agg_block;
26 std::unique_ptr<RFunction_Aggregate> node_global_agg_block;
27
29 std::vector<std::pair<int,int>> edges;
30
31 std::size_t num_node_features;
32 std::size_t num_edge_features;
34
35 std::string filename;
36
38 edges_update_block.reset();
39 nodes_update_block.reset();
41
42 edge_node_agg_block.reset();
45 }
46
47 template <typename T>
48 void createUpdateFunction(T& updateFunction) {
49 switch(updateFunction.GetFunctionTarget()) {
51 edges_update_block.reset(new T(updateFunction));
52 break;
53 }
55 nodes_update_block.reset(new T(updateFunction));
56 break;
57 }
59 globals_update_block.reset(new T(updateFunction));
60 break;
61 }
62 default: {
63 throw std::runtime_error("TMVA SOFIE: Invalid Update function supplied for creating GNN function block.");
64 }
65 }
66 }
67
68 template <typename T>
69 void createAggregateFunction(T& aggFunction, FunctionRelation relation) {
70 switch(relation) {
72 edge_node_agg_block.reset(new T(aggFunction));
73 break;
74 }
76 node_global_agg_block.reset(new T(aggFunction));
77 break;
78 }
80 edge_global_agg_block.reset(new T(aggFunction));
81 break;
82 }
83 default: {
84 throw std::runtime_error("TMVA SOFIE: Invalid Aggregate function supplied for creating GNN function block.");
85 }
86 }
87 }
88
89};
90
92
93private:
94
95 // updation function for edges, nodes & global attributes
96 std::unique_ptr<RFunction_Update> edges_update_block;
97 std::unique_ptr<RFunction_Update> nodes_update_block;
98 std::unique_ptr<RFunction_Update> globals_update_block;
99
100 // aggregation function for edges, nodes & global attributes
101 std::unique_ptr<RFunction_Aggregate> edge_node_agg_block;
102 std::unique_ptr<RFunction_Aggregate> edge_global_agg_block;
103 std::unique_ptr<RFunction_Aggregate> node_global_agg_block;
104
107 std::vector<int> senders; // contains node indices
108 std::vector<int> receivers; // contains node indices
109
110 std::size_t num_node_features;
111 std::size_t num_edge_features;
113
114public:
115
116 //explicit move ctor/assn
117 RModel_GNN(RModel_GNN&& other);
118
120
121 //disallow copy
122 RModel_GNN(const RModel_GNN& other) = delete;
123 RModel_GNN& operator=(const RModel_GNN& other) = delete;
124
125 RModel_GNN(GNN_Init& graph_input_struct);
127
128 void Generate();
129
131// ClassDef(RModel_GNN,1);
132};
133
134}//SOFIE
135}//Experimental
136}//TMVA
137
138#endif //TMVA_SOFIE_RMODEL_GNN
RModel_GNN(const RModel_GNN &other)=delete
RModel_GNN & operator=(const RModel_GNN &other)=delete
std::unique_ptr< RFunction_Aggregate > node_global_agg_block
std::unique_ptr< RFunction_Update > globals_update_block
RModel_GNN & operator=(RModel_GNN &&other)
std::unique_ptr< RFunction_Update > edges_update_block
std::unique_ptr< RFunction_Aggregate > edge_global_agg_block
std::unique_ptr< RFunction_Aggregate > edge_node_agg_block
std::unique_ptr< RFunction_Update > nodes_update_block
create variable transformations
std::unique_ptr< RFunction_Aggregate > node_global_agg_block
std::unique_ptr< RFunction_Update > globals_update_block
std::unique_ptr< RFunction_Update > nodes_update_block
std::unique_ptr< RFunction_Aggregate > edge_node_agg_block
std::unique_ptr< RFunction_Aggregate > edge_global_agg_block
std::vector< std::pair< int, int > > edges
void createAggregateFunction(T &aggFunction, FunctionRelation relation)
std::unique_ptr< RFunction_Update > edges_update_block
void createUpdateFunction(T &updateFunction)