10namespace Experimental {
25 fName = std::move(other.fName);
42 fName = std::move(other.fName);
67 std::time_t ttime = std::time(0);
68 std::tm* gmt_time = std::gmtime(&ttime);
83 fGC+=
"\n\nnamespace Edge_Update{\nstruct Session {\n";
85 std::vector<std::vector<Dim>> update_input_edges(4);
99 if(!edges_update_output_shape[1].isParam && edges_update_output_shape[1].dim != num_edge_features_input) {
103 fGC+=
"\n\nnamespace Node_Update{\nstruct Session {\n";
109 std::vector<std::vector<Dim>> update_input_nodes(3);
122 if(!nodes_update_output_shape[1].isParam && nodes_update_output_shape[1].dim != num_node_features_input) {
126 fGC+=
"\n\nnamespace Global_Update{\nstruct Session {\n";
152 fGC +=
"struct Session {\n";
153 fGC +=
"\n// Instantiating session objects for graph components\n";
154 fGC +=
"Edge_Update::Session edge_update;\n";
155 fGC +=
"Node_Update::Session node_update;\n";
156 fGC +=
"Global_Update::Session global_update;\n\n";
158 std::string e_num = std::to_string(
num_edges);
159 std::string n_num = std::to_string(
num_nodes);
160 std::string e_size_input = std::to_string(num_edge_features_input);
161 std::string n_size_input = std::to_string(num_node_features_input);
162 std::string g_size_input = std::to_string(num_global_features_input);
168 fGC +=
"std::vector<float> fEdgeUpdates = std::vector<float>(" + e_num +
"*" + e_size +
");\n";
169 fGC +=
"\n\nstd::vector<float> fNodeUpdates = std::vector<float>(" + n_num +
"*" + n_size +
");\n";
171 fGC +=
"\n// input vectors for edge update\n";
172 fGC +=
"std::vector<float> fEdgeInputs = std::vector<float>(" + e_num +
"*" + e_size_input +
");\n";
173 fGC +=
"std::vector<float> fRecNodeInputs = std::vector<float>(" + e_num +
"*" + n_size_input +
");\n";
174 fGC +=
"std::vector<float> fSndNodeInputs = std::vector<float>(" + e_num +
"*" + n_size_input +
");\n";
175 fGC +=
"std::vector<float> fGlobInputs = std::vector<float>(" + e_num +
"*" + g_size_input +
");\n\n";
177 fGC +=
"\n// input vectors for node update\n";
178 fGC +=
"std::vector<float> fNodeInputs = std::vector<float>(" + n_num +
"*" + n_size_input +
");\n";
179 fGC +=
"std::vector<float> fNodeEdgeAggregate = std::vector<float>(" + n_num +
"*" + n_size_input +
", 0);\n";
180 fGC +=
"std::vector<float> fNodeAggregateTemp;\n";
182 fGC +=
"\nvoid infer(TMVA::Experimental::SOFIE::GNN_Data& input_graph){\n";
185 fGC +=
"\n// --- Edge Update ---\n";
186 fGC +=
"size_t n_edges = input_graph.edge_data.GetShape()[0];\n";
187 fGC +=
"if (n_edges > " + e_num +
")\n";
188 fGC +=
" throw std::runtime_error(\"Number of input edges larger than " + e_num +
"\" );\n\n";
189 fGC +=
"auto receivers = input_graph.edge_index.GetData();\n";
190 fGC +=
"auto senders = input_graph.edge_index.GetData() + n_edges;\n";
192 fGC +=
"for (size_t k = 0; k < n_edges; k++) { \n";
193 fGC +=
" std::copy(input_graph.edge_data.GetData() + k * " + e_size_input +
194 ", input_graph.edge_data.GetData() + (k + 1) * " + e_size_input +
195 ", fEdgeInputs.begin() + k * " + e_size_input +
");\n";
196 fGC +=
" std::copy(input_graph.node_data.GetData() + receivers[k] * " + n_size_input +
197 ", input_graph.node_data.GetData() + (receivers[k] + 1) * " + n_size_input +
198 ", fRecNodeInputs.begin() + k * " + n_size_input +
");\n";
199 fGC +=
" std::copy(input_graph.node_data.GetData() + senders[k] * " + n_size_input +
200 ", input_graph.node_data.GetData() + (senders[k] + 1) * " + n_size_input +
201 ", fSndNodeInputs.begin() + k * " + n_size_input +
");\n";
202 fGC +=
" std::copy(input_graph.global_data.GetData()";
203 fGC +=
", input_graph.global_data.GetData() + " + g_size_input +
204 ", fGlobInputs.begin() + k * " + g_size_input +
");\n";
207 fGC +=
"fEdgeUpdates = " +
edges_update_block->Generate({
"n_edges",
"fEdgeInputs.data(), fRecNodeInputs.data(), fSndNodeInputs.data(), fGlobInputs.data()"}) +
"\n";
210 fGC +=
"\n// resize edge graph data since output feature size is not equal to input size\n";
211 fGC+=
"input_graph.edge_data = input_graph.edge_data.Resize({n_edges, "+e_size+
"});\n";
214 fGC +=
"\nfor (size_t k = 0; k < n_edges; k++) { \n";
215 fGC +=
" std::copy(fEdgeUpdates.begin()+ k * " + e_size +
", fEdgeUpdates.begin()+ (k+1) * " + e_size +
216 ",input_graph.edge_data.GetData() + k * " + e_size +
");\n";
220 fGC +=
"\n\n// --- Node Update ---\n";
221 fGC +=
"size_t n_nodes = input_graph.node_data.GetShape()[0];\n";
223 fGC +=
"for (size_t k = 0; k < n_nodes; k++) { \n";
224 fGC +=
" std::copy(input_graph.node_data.GetData() + k * " + n_size_input +
225 ", input_graph.node_data.GetData() + (k + 1) * " + n_size_input +
226 ", fNodeInputs.begin() + k * " + n_size_input +
");\n";
229 fGC +=
"\nstd::fill(fNodeEdgeAggregate.begin(), fNodeEdgeAggregate.end(), 0.);\n";
233 fGC +=
"\n// resize global vector feature to number of nodes if needed\n";
234 fGC +=
"if (n_nodes > n_edges) {\n";
235 fGC +=
" fGlobInputs.resize( n_nodes * " + std::to_string(num_global_features_input) +
");\n";
236 fGC +=
" for (size_t k = n_edges; k < n_nodes; k++)\n";
237 fGC +=
" std::copy(fGlobInputs.begin(), fGlobInputs.begin() + " + g_size_input +
238 " , fGlobInputs.begin() + k * " + g_size_input +
");\n";
242 fGC +=
"\n// aggregate edges going to a node\n";
243 fGC +=
"for (size_t j = 0; j < n_nodes; j++) {\n";
245 fGC +=
" std::vector<float *> edgesData; edgesData.reserve( int(n_edges/n_nodes) +1);\n";
247 fGC +=
" for (size_t k = 0; k < n_edges; k++) {\n";
248 fGC +=
" if (receivers[k] == j) \n";
249 fGC +=
" edgesData.emplace_back(input_graph.edge_data.GetData() + k * " + e_size +
");\n";
252 fGC +=
" std::copy(fNodeAggregateTemp.begin(), fNodeAggregateTemp.end(), fNodeEdgeAggregate.begin() + " +
258 fGC+=
"fNodeUpdates = ";
259 fGC+=
nodes_update_block->Generate({
"n_nodes",
"fNodeEdgeAggregate.data()",
"fNodeInputs.data()",
"fGlobInputs.data()"});
263 fGC +=
"\n// resize node graph data since output feature size is not equal to input size\n";
264 fGC+=
"input_graph.node_data = input_graph.node_data.Resize({n_nodes, " + n_size +
"});\n";
267 fGC +=
"\nfor (size_t k = 0; k < n_nodes; k++) { \n";
268 fGC +=
" std::copy(fNodeUpdates.begin()+ k * " + n_size +
", fNodeUpdates.begin() + (k+1) * " + n_size +
269 ",input_graph.node_data.GetData() + k * " + n_size+
");\n";
274 fGC +=
"std::vector<float *> allEdgesData; allEdgesData.reserve(n_edges);\n";
275 fGC +=
"for (size_t k = 0; k < n_edges; k++) {\n";
276 fGC +=
" allEdgesData.emplace_back(input_graph.edge_data.GetData() + k * " + e_size +
");\n";
278 fGC +=
"std::vector<float *> allNodesData; allNodesData.reserve(n_nodes);\n";
279 fGC +=
"for (size_t k = 0; k < n_nodes; k++) {\n";
280 fGC +=
" allNodesData.emplace_back(input_graph.node_data.GetData() + k * " + n_size +
");\n";
284 fGC +=
"\n// --- Global Update ---\n";
285 fGC+=
"std::vector<float> Edge_Global_Aggregate = ";
289 fGC+=
"std::vector<float> Node_Global_Aggregate = ";
294 fGC +=
"std::vector<float> Global_Data = ";
295 fGC +=
globals_update_block->Generate({
"Edge_Global_Aggregate.data()",
"Node_Global_Aggregate.data()",
"input_graph.global_data.GetData()"});
297 fGC +=
"\n// resize global graph data since output feature size is not equal to input size\n";
298 fGC+=
"input_graph.global_data = input_graph.global_data.Resize({"+g_size+
"});\n";
300 fGC +=
"\nstd::copy(Global_Data.begin(), Global_Data.end(), input_graph.global_data.GetData());";
304 fGC += (
"} //TMVA_SOFIE_" +
fName +
"\n");
305 fGC +=
"\n#endif // TMVA_SOFIE_" + hgname +
"\n";
void GenerateHeaderInfo(std::string &hgname)
std::size_t num_edge_features
std::unique_ptr< RFunction_Aggregate > node_global_agg_block
std::unique_ptr< RFunction_Update > globals_update_block
RModel_GNN & operator=(RModel_GNN &&other)
RModel_GNN()=default
Default constructor.
std::unique_ptr< RFunction_Update > edges_update_block
std::unique_ptr< RFunction_Aggregate > edge_global_agg_block
std::unique_ptr< RFunction_Aggregate > edge_node_agg_block
std::size_t num_global_features
std::size_t num_node_features
std::unique_ptr< RFunction_Update > nodes_update_block
create variable transformations
std::vector< std::pair< int, int > > edges
std::unique_ptr< RFunction_Aggregate > node_global_agg_block
std::unique_ptr< RFunction_Update > globals_update_block
std::unique_ptr< RFunction_Update > nodes_update_block
std::unique_ptr< RFunction_Aggregate > edge_node_agg_block
std::unique_ptr< RFunction_Aggregate > edge_global_agg_block
std::size_t num_node_features
std::unique_ptr< RFunction_Update > edges_update_block
std::size_t num_global_features
std::size_t num_edge_features