// RModel_GNN.hxx -- ROOT Reference Guide (TMVA SOFIE)
// Source listing extracted from the Doxygen "Go to the documentation of this file" page.
1#ifndef TMVA_SOFIE_RMODEL_GNN
2#define TMVA_SOFIE_RMODEL_GNN
3
#include <cstddef>
#include <ctime>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "TMVA/RModel_GNNBase.hxx"
#include "TMVA/RModel.hxx"
#include "TMVA/RFunction.hxx"

10namespace TMVA {
11namespace Experimental {
12namespace SOFIE {
13
14class RFunction_Update;
15class RFunction_Aggregate;
16
17struct GNN_Init {
18 // update blocks
19 std::unique_ptr<RFunction_Update> edges_update_block;
20 std::unique_ptr<RFunction_Update> nodes_update_block;
21 std::unique_ptr<RFunction_Update> globals_update_block;
22
23 // aggregation blocks
24 std::unique_ptr<RFunction_Aggregate> edge_node_agg_block;
25 std::unique_ptr<RFunction_Aggregate> edge_global_agg_block;
26 std::unique_ptr<RFunction_Aggregate> node_global_agg_block;
27
28 std::size_t num_nodes;
29 std::vector<std::pair<int, int>> edges;
30
31 std::size_t num_node_features;
32 std::size_t num_edge_features;
34
35 std::string filename;
36
38 {
39 edges_update_block.reset();
40 nodes_update_block.reset();
42
43 edge_node_agg_block.reset();
46 }
47
48 template <typename T>
49 void createUpdateFunction(T &updateFunction)
50 {
51 switch (updateFunction.GetFunctionTarget()) {
53 edges_update_block.reset(new T(updateFunction));
54 break;
55 }
57 nodes_update_block.reset(new T(updateFunction));
58 break;
59 }
61 globals_update_block.reset(new T(updateFunction));
62 break;
63 }
64 default: {
65 throw std::runtime_error("TMVA SOFIE: Invalid Update function supplied for creating GNN function block.");
66 }
67 }
68 }
69
70 template <typename T>
71 void createAggregateFunction(T &aggFunction, FunctionRelation relation)
72 {
73 switch (relation) {
75 edge_node_agg_block.reset(new T(aggFunction));
76 break;
77 }
79 node_global_agg_block.reset(new T(aggFunction));
80 break;
81 }
83 edge_global_agg_block.reset(new T(aggFunction));
84 break;
85 }
86 default: {
87 throw std::runtime_error("TMVA SOFIE: Invalid Aggregate function supplied for creating GNN function block.");
88 }
89 }
90 }
91};
92
93class RModel_GNN final : public RModel_GNNBase {
94
95private:
96 // update function for edges, nodes & global attributes
97 std::unique_ptr<RFunction_Update> edges_update_block;
98 std::unique_ptr<RFunction_Update> nodes_update_block;
99 std::unique_ptr<RFunction_Update> globals_update_block;
100
101 // aggregation function for edges, nodes & global attributes
102 std::unique_ptr<RFunction_Aggregate> edge_node_agg_block;
103 std::unique_ptr<RFunction_Aggregate> edge_global_agg_block;
104 std::unique_ptr<RFunction_Aggregate> node_global_agg_block;
105
106 std::size_t num_nodes; // maximum number of nodes
107 std::size_t num_edges; // maximum number of edges
108
109 std::size_t num_node_features;
110 std::size_t num_edge_features;
112
113public:
114 /**
115 Default constructor. Needed to allow serialization of ROOT objects. See
116 https://root.cern/manual/io_custom_classes/#restrictions-on-types-root-io-can-handle
117 */
118 RModel_GNN() = default;
119 RModel_GNN(GNN_Init &graph_input_struct);
120
121 // Rule of five: explicitly define move semantics, disallow copy
122 RModel_GNN(RModel_GNN &&other);
124 RModel_GNN(const RModel_GNN &other) = delete;
125 RModel_GNN &operator=(const RModel_GNN &other) = delete;
126 ~RModel_GNN() final = default;
127
128 void Generate() final;
129};
130
131} // namespace SOFIE
132} // namespace Experimental
133} // namespace TMVA
134
135#endif // TMVA_SOFIE_RMODEL_GNN
/*
 * Doxygen symbol tooltips captured with the page (cross-reference index, not code):
 *   RModel_GNN(const RModel_GNN &other) = delete
 *   RModel_GNN &operator=(const RModel_GNN &other) = delete
 *   std::unique_ptr<RFunction_Aggregate> node_global_agg_block
 *   std::unique_ptr<RFunction_Update> globals_update_block
 *   RModel_GNN &operator=(RModel_GNN &&other)
 *   RModel_GNN() = default
 *     Default constructor.
 *   std::unique_ptr<RFunction_Update> edges_update_block
 *   std::unique_ptr<RFunction_Aggregate> edge_global_agg_block
 *   std::unique_ptr<RFunction_Aggregate> edge_node_agg_block
 *   std::unique_ptr<RFunction_Update> nodes_update_block
 *   create variable transformations
 *   std::vector<std::pair<int, int>> edges
 *   std::unique_ptr<RFunction_Aggregate> node_global_agg_block
 *   std::unique_ptr<RFunction_Update> globals_update_block
 *   std::unique_ptr<RFunction_Update> nodes_update_block
 *   std::unique_ptr<RFunction_Aggregate> edge_node_agg_block
 *   std::unique_ptr<RFunction_Aggregate> edge_global_agg_block
 *   void createAggregateFunction(T &aggFunction, FunctionRelation relation)
 *   std::unique_ptr<RFunction_Update> edges_update_block
 *   void createUpdateFunction(T &updateFunction)
 */