Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModel_GNN.cxx
Go to the documentation of this file.
1#include <algorithm>
2#include <cctype>
3#include <fstream>
4#include <limits>
5
6#include "TMVA/RModel_GNN.hxx"
7#include "TMVA/RFunction.hxx"
8
9namespace TMVA {
10namespace Experimental {
11namespace SOFIE {
12
14 edges_update_block = std::move(other.edges_update_block);
15 nodes_update_block = std::move(other.nodes_update_block);
16 globals_update_block = std::move(other.globals_update_block);
17
18 edge_node_agg_block = std::move(other.edge_node_agg_block);
19 edge_global_agg_block = std::move(other.edge_global_agg_block);
20 node_global_agg_block = std::move(other.node_global_agg_block);
21
22 num_nodes = std::move(other.num_nodes);
23 num_edges = std::move(other.num_edges);
24
25 fName = std::move(other.fName);
26 fFileName = std::move(other.fFileName);
27 fParseTime = std::move(other.fParseTime);
28}
29
31 edges_update_block = std::move(other.edges_update_block);
32 nodes_update_block = std::move(other.nodes_update_block);
33 globals_update_block = std::move(other.globals_update_block);
34
35 edge_node_agg_block = std::move(other.edge_node_agg_block);
36 edge_global_agg_block = std::move(other.edge_global_agg_block);
37 node_global_agg_block = std::move(other.node_global_agg_block);
38
39 num_nodes = std::move(other.num_nodes);
40 num_edges = std::move(other.num_edges);
41
42 fName = std::move(other.fName);
43 fFileName = std::move(other.fFileName);
44 fParseTime = std::move(other.fParseTime);
45
46 return *this;
47}
48
49RModel_GNN::RModel_GNN(GNN_Init& graph_input_struct) {
50 edges_update_block = std::move(graph_input_struct.edges_update_block);
51 nodes_update_block = std::move(graph_input_struct.nodes_update_block);
52 globals_update_block = std::move(graph_input_struct.globals_update_block);
53
54 edge_node_agg_block = std::move(graph_input_struct.edge_node_agg_block);
55 edge_global_agg_block = std::move(graph_input_struct.edge_global_agg_block);
56 node_global_agg_block = std::move(graph_input_struct.node_global_agg_block);
57
58 num_nodes = graph_input_struct.num_nodes;
59 num_edges = graph_input_struct.edges.size();
60 num_node_features = graph_input_struct.num_node_features;
61 num_edge_features = graph_input_struct.num_edge_features;
62 num_global_features = graph_input_struct.num_global_features;
63
64 fFileName = graph_input_struct.filename;
65 fName = fFileName.substr(0, fFileName.rfind("."));
66
67 std::time_t ttime = std::time(0);
68 std::tm* gmt_time = std::gmtime(&ttime);
69 fParseTime = std::asctime(gmt_time);
70}
71
73 std::string hgname;
74 GenerateHeaderInfo(hgname);
75
76 std::ofstream f;
77 f.open(fName+".dat");
78 f.close();
79
80 // Generating Infer function definition for Edge Update function
81 long next_pos;
82 //size_t block_size = num_edges;
83 fGC+="\n\nnamespace Edge_Update{\nstruct Session {\n";
84 // there are 4 input tensors for edge updates: {edges, receiver nodes, sender nodes, globals }
85 std::vector<std::vector<Dim>> update_input_edges(4);
86 update_input_edges[0] = {Dim{"num_edges",num_edges}, Dim{num_edge_features}};
87 update_input_edges[1] = {Dim{"num_edges",num_edges}, Dim{num_node_features}};
88 update_input_edges[2] = {Dim{"num_edges",num_edges}, Dim{num_node_features}};
89 update_input_edges[3] = {Dim{"num_edges",num_edges}, Dim{num_global_features}};
90 edges_update_block->Initialize();
91 edges_update_block->AddInputTensors(update_input_edges);
92 fGC+=edges_update_block->GenerateModel(fName);
93 next_pos = edges_update_block->GetFunctionBlock()->WriteInitializedTensorsToFile(fName+".dat");
94 fGC+="};\n}\n";
95
96 // the number of output edges features can be smaller, so we need to correct here
97 auto num_edge_features_input = num_edge_features;
98 auto edges_update_output_shape = edges_update_block->GetFunctionBlock()->GetDynamicTensorShape(edges_update_block->GetFunctionBlock()->GetOutputTensorNames()[0]);
99 if(!edges_update_output_shape[1].isParam && edges_update_output_shape[1].dim != num_edge_features_input) {
100 num_edge_features = edges_update_output_shape[1].dim;
101 }
102
103 fGC+="\n\nnamespace Node_Update{\nstruct Session {\n";
104 // Generating Infer function definition for Node Update function
105 // num_node_features is the output one
106
107 //block_size = num_nodes;
108 // there are 3 input tensors for node updates: {received edges, nodes, globals }
109 std::vector<std::vector<Dim>> update_input_nodes(3);
110 update_input_nodes[0] = {Dim{"num_nodes",num_nodes}, Dim{num_edge_features}};
111 update_input_nodes[1] = {Dim{"num_nodes",num_nodes}, Dim{num_node_features}};
112 update_input_nodes[2] = {Dim{"num_nodes",num_nodes}, Dim{num_global_features}};
113 nodes_update_block->Initialize();
114 nodes_update_block->AddInputTensors(update_input_nodes);
115 fGC+=nodes_update_block->GenerateModel(fName,next_pos);
116 next_pos = nodes_update_block->GetFunctionBlock()->WriteInitializedTensorsToFile(fName+".dat");
117 fGC+="};\n}\n";
118
119 // we need to correct the output number of node features
120 auto num_node_features_input = num_node_features;
121 auto nodes_update_output_shape = nodes_update_block->GetFunctionBlock()->GetDynamicTensorShape(nodes_update_block->GetFunctionBlock()->GetOutputTensorNames()[0]);
122 if(!nodes_update_output_shape[1].isParam && nodes_update_output_shape[1].dim != num_node_features_input) {
123 num_node_features = nodes_update_output_shape[1].dim;
124 }
125
126 fGC+="\n\nnamespace Global_Update{\nstruct Session {\n";
127 // Generating Infer function definition for Global Update function
128 std::vector<std::vector<std::size_t>> update_input_globals = {{1, num_edge_features},{1, num_node_features},{1, num_global_features}};
129 globals_update_block->Initialize();
130 globals_update_block->AddInputTensors(update_input_globals);
131 fGC+=globals_update_block->GenerateModel(fName,next_pos);
132 next_pos = globals_update_block->GetFunctionBlock()->WriteInitializedTensorsToFile(fName+".dat");
133 fGC+="};\n}\n";
134
135 // correct for difference in global size (check shape[1] of output of the globals update)
136 auto num_global_features_input = num_global_features;
137 if(globals_update_block->GetFunctionBlock()->GetTensorShape(globals_update_block->GetFunctionBlock()->GetOutputTensorNames()[0])[1] != num_global_features) {
138 num_global_features = globals_update_block->GetFunctionBlock()->GetTensorShape(globals_update_block->GetFunctionBlock()->GetOutputTensorNames()[0])[1];
139 }
140
141 fGC+=edge_node_agg_block->GenerateModel();
142
143 if(edge_node_agg_block->GetFunctionType() != edge_global_agg_block->GetFunctionType()) {
144 fGC+=edge_global_agg_block->GenerateModel();
145 }
146 if((edge_node_agg_block->GetFunctionType() != node_global_agg_block->GetFunctionType()) && (edge_global_agg_block->GetFunctionType() != node_global_agg_block->GetFunctionType())) {
147 fGC+=node_global_agg_block->GenerateModel();
148 }
149 fGC+="\n\n";
150
151 // computing inplace on input graph
152 fGC += "struct Session {\n";
153 fGC += "\n// Instantiating session objects for graph components\n";
154 fGC += "Edge_Update::Session edge_update;\n";
155 fGC += "Node_Update::Session node_update;\n";
156 fGC += "Global_Update::Session global_update;\n\n";
157
158 std::string e_num = std::to_string(num_edges);
159 std::string n_num = std::to_string(num_nodes);
160 std::string e_size_input = std::to_string(num_edge_features_input);
161 std::string n_size_input = std::to_string(num_node_features_input);
162 std::string g_size_input = std::to_string(num_global_features_input);
163 std::string e_size = std::to_string(num_edge_features);
164 std::string n_size = std::to_string(num_node_features);
165 std::string g_size = std::to_string(num_global_features);
166
167 // create temp vector for edge and node updates
168 fGC += "std::vector<float> fEdgeUpdates = std::vector<float>(" + e_num + "*" + e_size + ");\n";
169 fGC += "\n\nstd::vector<float> fNodeUpdates = std::vector<float>(" + n_num + "*" + n_size + ");\n";
170
171 fGC += "\n// input vectors for edge update\n";
172 fGC += "std::vector<float> fEdgeInputs = std::vector<float>(" + e_num + "*" + e_size_input + ");\n";
173 fGC += "std::vector<float> fRecNodeInputs = std::vector<float>(" + e_num + "*" + n_size_input + ");\n";
174 fGC += "std::vector<float> fSndNodeInputs = std::vector<float>(" + e_num + "*" + n_size_input + ");\n";
175 fGC += "std::vector<float> fGlobInputs = std::vector<float>(" + e_num + "*" + g_size_input + ");\n\n";
176
177 fGC += "\n// input vectors for node update\n";
178 fGC += "std::vector<float> fNodeInputs = std::vector<float>(" + n_num + "*" + n_size_input + ");\n";
179 fGC += "std::vector<float> fNodeEdgeAggregate = std::vector<float>(" + n_num + "*" + n_size_input + ", 0);\n";
180 fGC += "std::vector<float> fNodeAggregateTemp;\n";
181
182 fGC += "\nvoid infer(TMVA::Experimental::SOFIE::GNN_Data& input_graph){\n";
183
184 // computing updated edge attributes
185 fGC += "\n// --- Edge Update ---\n";
186 fGC += "size_t n_edges = input_graph.edge_data.GetShape()[0];\n";
187 fGC += "if (n_edges > " + e_num + ")\n";
188 fGC += " throw std::runtime_error(\"Number of input edges larger than " + e_num + "\" );\n\n";
189 fGC += "auto receivers = input_graph.edge_index.GetData();\n";
190 fGC += "auto senders = input_graph.edge_index.GetData() + n_edges;\n";
191
192 fGC += "for (size_t k = 0; k < n_edges; k++) { \n";
193 fGC += " std::copy(input_graph.edge_data.GetData() + k * " + e_size_input +
194 ", input_graph.edge_data.GetData() + (k + 1) * " + e_size_input +
195 ", fEdgeInputs.begin() + k * " + e_size_input + ");\n";
196 fGC += " std::copy(input_graph.node_data.GetData() + receivers[k] * " + n_size_input +
197 ", input_graph.node_data.GetData() + (receivers[k] + 1) * " + n_size_input +
198 ", fRecNodeInputs.begin() + k * " + n_size_input + ");\n";
199 fGC += " std::copy(input_graph.node_data.GetData() + senders[k] * " + n_size_input +
200 ", input_graph.node_data.GetData() + (senders[k] + 1) * " + n_size_input +
201 ", fSndNodeInputs.begin() + k * " + n_size_input + ");\n";
202 fGC += " std::copy(input_graph.global_data.GetData()";
203 fGC += ", input_graph.global_data.GetData() + " + g_size_input +
204 ", fGlobInputs.begin() + k * " + g_size_input + ");\n";
205 fGC += "}\n";
206
207 fGC += "fEdgeUpdates = " + edges_update_block->Generate({"n_edges","fEdgeInputs.data(), fRecNodeInputs.data(), fSndNodeInputs.data(), fGlobInputs.data()"}) + "\n";
208
209 if(num_edge_features != num_edge_features_input) {
210 fGC += "\n// resize edge graph data since output feature size is not equal to input size\n";
211 fGC+="input_graph.edge_data = input_graph.edge_data.Resize({n_edges, "+e_size+"});\n";
212 }
213 // copy output
214 fGC += "\nfor (size_t k = 0; k < n_edges; k++) { \n";
215 fGC += " std::copy(fEdgeUpdates.begin()+ k * " + e_size + ", fEdgeUpdates.begin()+ (k+1) * " + e_size +
216 ",input_graph.edge_data.GetData() + k * " + e_size + ");\n";
217 fGC += "}\n";
218 fGC += "\n";
219
220 fGC += "\n\n// --- Node Update ---\n";
221 fGC += "size_t n_nodes = input_graph.node_data.GetShape()[0];\n";
222 // computing updated node attributes
223 fGC += "for (size_t k = 0; k < n_nodes; k++) { \n";
224 fGC += " std::copy(input_graph.node_data.GetData() + k * " + n_size_input +
225 ", input_graph.node_data.GetData() + (k + 1) * " + n_size_input +
226 ", fNodeInputs.begin() + k * " + n_size_input + ");\n";
227 fGC += "}\n";
228 // reset initial aggregate edge vector to zero
229 fGC += "\nstd::fill(fNodeEdgeAggregate.begin(), fNodeEdgeAggregate.end(), 0.);\n";
230 // fGlobInputs is size { n_edges, n_globals}. It needs to be here { n_nodes, n_globals}
231 // if number of nodes is larger than edges we need to resize it and copy values
232
233 fGC += "\n// resize global vector feature to number of nodes if needed\n";
234 fGC += "if (n_nodes > n_edges) {\n";
235 fGC += " fGlobInputs.resize( n_nodes * " + std::to_string(num_global_features_input) + ");\n";
236 fGC += " for (size_t k = n_edges; k < n_nodes; k++)\n";
237 fGC += " std::copy(fGlobInputs.begin(), fGlobInputs.begin() + " + g_size_input +
238 " , fGlobInputs.begin() + k * " + g_size_input + ");\n";
239 fGC += "}\n";
240
241 // loop on nodes and aggregate incoming edges
242 fGC += "\n// aggregate edges going to a node\n";
243 fGC += "for (size_t j = 0; j < n_nodes; j++) {\n";
244 // approximate number of receivers/node to allocate vector
245 fGC += " std::vector<float *> edgesData; edgesData.reserve( int(n_edges/n_nodes) +1);\n";
246 // loop on edges
247 fGC += " for (size_t k = 0; k < n_edges; k++) {\n";
248 fGC += " if (receivers[k] == j) \n";
249 fGC += " edgesData.emplace_back(input_graph.edge_data.GetData() + k * " + e_size + ");\n";
250 fGC += " }\n";
251 fGC += " fNodeAggregateTemp = " + edge_node_agg_block->Generate(num_edge_features, "edgesData") + ";\n";
252 fGC += " std::copy(fNodeAggregateTemp.begin(), fNodeAggregateTemp.end(), fNodeEdgeAggregate.begin() + " +
253 e_size + " * j);\n";
254 fGC += "}\n"; // end node loop
255
256
257 fGC+="\n";
258 fGC+="fNodeUpdates = ";
259 fGC+=nodes_update_block->Generate({"n_nodes","fNodeEdgeAggregate.data()","fNodeInputs.data()","fGlobInputs.data()"}); // computing updated node attributes
260 fGC+="\n";
261
262 if(num_node_features != num_node_features_input) {
263 fGC += "\n// resize node graph data since output feature size is not equal to input size\n";
264 fGC+="input_graph.node_data = input_graph.node_data.Resize({n_nodes, " + n_size + "});\n";
265 }
266 // copy output
267 fGC += "\nfor (size_t k = 0; k < n_nodes; k++) { \n";
268 fGC += " std::copy(fNodeUpdates.begin()+ k * " + n_size + ", fNodeUpdates.begin() + (k+1) * " + n_size +
269 ",input_graph.node_data.GetData() + k * " + n_size+ ");\n";
270 fGC += "}\n";
271 fGC += "\n";
272
273 // aggregating edges & nodes for global update
274 fGC += "std::vector<float *> allEdgesData; allEdgesData.reserve(n_edges);\n";
275 fGC += "for (size_t k = 0; k < n_edges; k++) {\n";
276 fGC += " allEdgesData.emplace_back(input_graph.edge_data.GetData() + k * " + e_size + ");\n";
277 fGC += "}\n";
278 fGC += "std::vector<float *> allNodesData; allNodesData.reserve(n_nodes);\n";
279 fGC += "for (size_t k = 0; k < n_nodes; k++) {\n";
280 fGC += " allNodesData.emplace_back(input_graph.node_data.GetData() + k * " + n_size + ");\n";
281 fGC += "}\n";
282
283
284 fGC += "\n// --- Global Update ---\n";
285 fGC+="std::vector<float> Edge_Global_Aggregate = ";
286 fGC+=edge_global_agg_block->Generate(num_edge_features, "allEdgesData"); // aggregating edge attributes globally
287 fGC+=";\n";
288
289 fGC+="std::vector<float> Node_Global_Aggregate = ";
290 fGC+=node_global_agg_block->Generate(num_node_features, "allNodesData"); // aggregating node attributes globally
291 fGC+=";\n";
292
293 // computing updated global attributes
294 fGC += "std::vector<float> Global_Data = ";
295 fGC += globals_update_block->Generate({"Edge_Global_Aggregate.data()","Node_Global_Aggregate.data()", "input_graph.global_data.GetData()"});
296 if(num_global_features != num_global_features_input) {
297 fGC += "\n// resize global graph data since output feature size is not equal to input size\n";
298 fGC+="input_graph.global_data = input_graph.global_data.Resize({"+g_size+"});\n";
299 }
300 fGC += "\nstd::copy(Global_Data.begin(), Global_Data.end(), input_graph.global_data.GetData());";
301 fGC+="\n}\n";
302 fGC+="};\n";
303
304 fGC += ("} //TMVA_SOFIE_" + fName + "\n");
305 fGC += "\n#endif // TMVA_SOFIE_" + hgname + "\n";
306}
307
308}//SOFIE
309}//Experimental
310}//TMVA
#define f(i)
Definition RSha256.hxx:104
void GenerateHeaderInfo(std::string &hgname)
std::unique_ptr< RFunction_Aggregate > node_global_agg_block
std::unique_ptr< RFunction_Update > globals_update_block
RModel_GNN & operator=(RModel_GNN &&other)
RModel_GNN()=default
Default constructor.
std::unique_ptr< RFunction_Update > edges_update_block
std::unique_ptr< RFunction_Aggregate > edge_global_agg_block
std::unique_ptr< RFunction_Aggregate > edge_node_agg_block
std::unique_ptr< RFunction_Update > nodes_update_block
create variable transformations
std::vector< std::pair< int, int > > edges
std::unique_ptr< RFunction_Aggregate > node_global_agg_block
std::unique_ptr< RFunction_Update > globals_update_block
std::unique_ptr< RFunction_Update > nodes_update_block
std::unique_ptr< RFunction_Aggregate > edge_node_agg_block
std::unique_ptr< RFunction_Aggregate > edge_global_agg_block
std::unique_ptr< RFunction_Update > edges_update_block