Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModel_GNN.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_RMODEL_GNN
2#define TMVA_SOFIE_RMODEL_GNN
3
#include <cstddef>
#include <ctime>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "TMVA/RModel_GNNBase.hxx"
#include "TMVA/RModel.hxx"
#include "TMVA/RFunction.hxx"
9
10namespace TMVA {
11namespace Experimental {
12namespace SOFIE {
13
14class RFunction_Update;
15class RFunction_Aggregate;
16
17struct GNN_Init {
18
19 // Explicitly define default constructor so cppyy doesn't attempt
20 // aggregate initialization.
22
23 // update blocks
24 std::unique_ptr<RFunction_Update> edges_update_block;
25 std::unique_ptr<RFunction_Update> nodes_update_block;
26 std::unique_ptr<RFunction_Update> globals_update_block;
27
28 // aggregation blocks
29 std::unique_ptr<RFunction_Aggregate> edge_node_agg_block;
30 std::unique_ptr<RFunction_Aggregate> edge_global_agg_block;
31 std::unique_ptr<RFunction_Aggregate> node_global_agg_block;
32
33 std::size_t num_nodes;
34 std::vector<std::pair<int, int>> edges;
35
36 std::size_t num_node_features;
37 std::size_t num_edge_features;
39
40 std::string filename;
41
43 {
44 edges_update_block.reset();
45 nodes_update_block.reset();
47
48 edge_node_agg_block.reset();
51 }
52
53 template <typename T>
54 void createUpdateFunction(T &updateFunction)
55 {
56 switch (updateFunction.GetFunctionTarget()) {
58 edges_update_block.reset(new T(updateFunction));
59 break;
60 }
62 nodes_update_block.reset(new T(updateFunction));
63 break;
64 }
66 globals_update_block.reset(new T(updateFunction));
67 break;
68 }
69 default: {
70 throw std::runtime_error("TMVA SOFIE: Invalid Update function supplied for creating GNN function block.");
71 }
72 }
73 }
74
75 template <typename T>
76 void createAggregateFunction(T &aggFunction, FunctionRelation relation)
77 {
78 switch (relation) {
80 edge_node_agg_block.reset(new T(aggFunction));
81 break;
82 }
84 node_global_agg_block.reset(new T(aggFunction));
85 break;
86 }
88 edge_global_agg_block.reset(new T(aggFunction));
89 break;
90 }
91 default: {
92 throw std::runtime_error("TMVA SOFIE: Invalid Aggregate function supplied for creating GNN function block.");
93 }
94 }
95 }
96};
97
98class RModel_GNN final : public RModel_GNNBase {
99
100private:
101 // update function for edges, nodes & global attributes
102 std::unique_ptr<RFunction_Update> edges_update_block;
103 std::unique_ptr<RFunction_Update> nodes_update_block;
104 std::unique_ptr<RFunction_Update> globals_update_block;
105
106 // aggregation function for edges, nodes & global attributes
107 std::unique_ptr<RFunction_Aggregate> edge_node_agg_block;
108 std::unique_ptr<RFunction_Aggregate> edge_global_agg_block;
109 std::unique_ptr<RFunction_Aggregate> node_global_agg_block;
110
111 std::size_t num_nodes; // maximum number of nodes
112 std::size_t num_edges; // maximum number of edges
113
114 std::size_t num_node_features;
115 std::size_t num_edge_features;
117
118public:
119 /**
120 Default constructor. Needed to allow serialization of ROOT objects. See
121 https://root.cern/manual/io_custom_classes/#restrictions-on-types-root-io-can-handle
122 */
123 RModel_GNN() = default;
124 RModel_GNN(GNN_Init &graph_input_struct);
125
126 // Rule of five: explicitly define move semantics, disallow copy
127 RModel_GNN(RModel_GNN &&other);
129 RModel_GNN(const RModel_GNN &other) = delete;
130 RModel_GNN &operator=(const RModel_GNN &other) = delete;
131 ~RModel_GNN() final = default;
132
133 void Generate() final;
134};
135
136} // namespace SOFIE
137} // namespace Experimental
138} // namespace TMVA
139
140#endif // TMVA_SOFIE_RMODEL_GNN
RModel_GNN(const RModel_GNN &other)=delete
RModel_GNN & operator=(const RModel_GNN &other)=delete
std::unique_ptr< RFunction_Aggregate > node_global_agg_block
std::unique_ptr< RFunction_Update > globals_update_block
RModel_GNN & operator=(RModel_GNN &&other)
RModel_GNN()=default
Default constructor.
std::unique_ptr< RFunction_Update > edges_update_block
std::unique_ptr< RFunction_Aggregate > edge_global_agg_block
std::unique_ptr< RFunction_Aggregate > edge_node_agg_block
std::unique_ptr< RFunction_Update > nodes_update_block
create variable transformations
std::vector< std::pair< int, int > > edges
std::unique_ptr< RFunction_Aggregate > node_global_agg_block
std::unique_ptr< RFunction_Update > globals_update_block
std::unique_ptr< RFunction_Update > nodes_update_block
std::unique_ptr< RFunction_Aggregate > edge_node_agg_block
std::unique_ptr< RFunction_Aggregate > edge_global_agg_block
void createAggregateFunction(T &aggFunction, FunctionRelation relation)
std::unique_ptr< RFunction_Update > edges_update_block
void createUpdateFunction(T &updateFunction)