47    fGC+=
"\n\nnamespace Edge_Update{\nstruct Session {\n";
 
   67    fGC+=
"\n\nnamespace Node_Update{\nstruct Session {\n";
 
   90    fGC+=
"\n\nnamespace Global_Update{\nstruct Session {\n";
 
  116    fGC += 
"struct Session {\n";
 
  117    fGC += 
"\n// Instantiating session objects for graph components\n";
 
  118    fGC += 
"Edge_Update::Session edge_update;\n";
 
  119    fGC += 
"Node_Update::Session node_update;\n";
 
  120    fGC += 
"Global_Update::Session global_update;\n\n";
 
  132    fGC += 
"std::vector<float> fEdgeUpdates = std::vector<float>(" + 
e_num + 
"*" + 
e_size + 
");\n";
 
  133    fGC += 
"\n\nstd::vector<float> fNodeUpdates = std::vector<float>(" + 
n_num + 
"*" + 
n_size + 
");\n";
 
  135    fGC += 
"\n// input vectors for edge update\n";
 
  136    fGC += 
"std::vector<float> fEdgeInputs = std::vector<float>(" + 
e_num + 
"*" + 
e_size_input + 
");\n";
 
  137    fGC += 
"std::vector<float> fRecNodeInputs = std::vector<float>(" + 
e_num + 
"*" + 
n_size_input + 
");\n";
 
  138    fGC += 
"std::vector<float> fSndNodeInputs = std::vector<float>(" + 
e_num + 
"*" + 
n_size_input + 
");\n";
 
  139    fGC += 
"std::vector<float> fGlobInputs = std::vector<float>(" + 
e_num + 
"*" + 
g_size_input + 
");\n\n";
 
  141    fGC += 
"\n// input vectors for node update\n";
 
  142    fGC += 
"std::vector<float> fNodeInputs = std::vector<float>(" + 
n_num + 
"*" + 
n_size_input + 
");\n";
 
  143    fGC += 
"std::vector<float> fNodeEdgeAggregate = std::vector<float>(" + 
n_num + 
"*" + 
n_size_input + 
", 0);\n";
 
  144    fGC += 
"std::vector<float> fNodeAggregateTemp;\n";
 
  146    fGC += 
"\nvoid infer(TMVA::Experimental::SOFIE::GNN_Data& input_graph){\n";
 
  149    fGC += 
"\n// --- Edge Update ---\n";
 
  150    fGC +=  
"size_t n_edges = input_graph.edge_data.GetShape()[0];\n";
 
  151    fGC +=  
"if (n_edges > " + 
e_num + 
")\n";
 
  152    fGC +=  
"   throw std::runtime_error(\"Number of input edges larger than " + 
e_num + 
"\" );\n\n";
 
  153    fGC += 
"auto receivers = input_graph.edge_index.GetData();\n";
 
  154    fGC += 
"auto senders = input_graph.edge_index.GetData() + n_edges;\n";
 
  156    fGC += 
"for (size_t k = 0; k < n_edges; k++) { \n";
 
  157    fGC += 
"   std::copy(input_graph.edge_data.GetData() + k * " + 
e_size_input +
 
  158           ", input_graph.edge_data.GetData() + (k + 1) * " + 
e_size_input +
 
  160    fGC += 
"   std::copy(input_graph.node_data.GetData() + receivers[k] * " + 
n_size_input +
 
  161           ", input_graph.node_data.GetData() + (receivers[k] + 1) * " + 
n_size_input +
 
  162           ", fRecNodeInputs.begin() + k * " + 
n_size_input + 
");\n";
 
  163    fGC += 
"   std::copy(input_graph.node_data.GetData() + senders[k] * " + 
n_size_input +
 
  164           ", input_graph.node_data.GetData() + (senders[k] + 1) * " + 
n_size_input +
 
  165           ", fSndNodeInputs.begin() + k * " + 
n_size_input + 
");\n";
 
  166    fGC += 
"   std::copy(input_graph.global_data.GetData()";
 
  171    fGC += 
"fEdgeUpdates = " + 
edges_update_block->Generate({
"n_edges",
"fEdgeInputs.data(), fRecNodeInputs.data(), fSndNodeInputs.data(), fGlobInputs.data()"}) + 
"\n";
 
  174        fGC += 
"\n//  resize edge graph data since output feature size is not equal to input size\n";
 
  175        fGC+=
"input_graph.edge_data = input_graph.edge_data.Resize({n_edges, "+
e_size+
"});\n";
 
  178    fGC += 
"\nfor (size_t k = 0; k < n_edges; k++) { \n";
 
  179    fGC += 
"   std::copy(fEdgeUpdates.begin()+ k * " + 
e_size + 
", fEdgeUpdates.begin()+ (k+1) * " + 
e_size +
 
  180           ",input_graph.edge_data.GetData() + k * " + 
e_size + 
");\n";
 
  184    fGC += 
"\n\n// --- Node Update ---\n";
 
  185    fGC += 
"size_t n_nodes = input_graph.node_data.GetShape()[0];\n";
 
  187    fGC += 
"for (size_t k = 0; k < n_nodes; k++) { \n";
 
  188    fGC += 
"   std::copy(input_graph.node_data.GetData() + k * " + 
n_size_input +
 
  189           ", input_graph.node_data.GetData() + (k + 1) * " + 
n_size_input +
 
  193    fGC += 
"\nstd::fill(fNodeEdgeAggregate.begin(), fNodeEdgeAggregate.end(), 0.);\n";
 
  197    fGC += 
"\n// resize global vector feature to number of nodes if needed\n";
 
  198    fGC += 
"if (n_nodes > n_edges) {\n";
 
  200    fGC += 
"   for (size_t k = n_edges; k < n_nodes; k++)\n";
 
  201    fGC += 
"      std::copy(fGlobInputs.begin(), fGlobInputs.begin() + " + 
g_size_input +
 
  202                   " , fGlobInputs.begin() + k * " + 
g_size_input + 
");\n";
 
  206    fGC += 
"\n// aggregate edges going to a node\n";
 
  207    fGC += 
"for (size_t j = 0; j < n_nodes; j++) {\n";
 
  209    fGC += 
"   std::vector<float *> edgesData; edgesData.reserve( int(n_edges/n_nodes) +1);\n";
 
  211    fGC += 
"   for (size_t k = 0; k < n_edges; k++) {\n";
 
  212    fGC += 
"      if (receivers[k] == j) \n";
 
  213    fGC += 
"         edgesData.emplace_back(input_graph.edge_data.GetData() + k * " + 
e_size + 
");\n";
 
  216    fGC += 
"   std::copy(fNodeAggregateTemp.begin(), fNodeAggregateTemp.end(), fNodeEdgeAggregate.begin() + " +
 
  222    fGC+=
"fNodeUpdates = ";
 
  223    fGC+=
nodes_update_block->Generate({
"n_nodes",
"fNodeEdgeAggregate.data()",
"fNodeInputs.data()",
"fGlobInputs.data()"});    
 
  227        fGC += 
"\n//  resize node graph data since output feature size is not equal to input size\n";
 
  228        fGC+=
"input_graph.node_data = input_graph.node_data.Resize({n_nodes, " + 
n_size + 
"});\n";
 
  231    fGC += 
"\nfor (size_t k = 0; k < n_nodes; k++) { \n";
 
  232    fGC += 
"   std::copy(fNodeUpdates.begin()+ k * " + 
n_size + 
", fNodeUpdates.begin() + (k+1) * " + 
n_size +
 
  233           ",input_graph.node_data.GetData() + k * " + 
n_size+ 
");\n";
 
  238    fGC += 
"std::vector<float *> allEdgesData; allEdgesData.reserve(n_edges);\n";
 
  239    fGC += 
"for (size_t k = 0; k < n_edges; k++) {\n";
 
  240    fGC += 
"   allEdgesData.emplace_back(input_graph.edge_data.GetData() + k * " + 
e_size + 
");\n";
 
  242    fGC += 
"std::vector<float *> allNodesData; allNodesData.reserve(n_nodes);\n";
 
  243    fGC += 
"for (size_t k = 0; k < n_nodes; k++) {\n";
 
  244    fGC += 
"   allNodesData.emplace_back(input_graph.node_data.GetData() + k * " + 
n_size + 
");\n";
 
  248    fGC += 
"\n// --- Global Update ---\n";
 
  249    fGC+=
"std::vector<float> Edge_Global_Aggregate = ";
 
  253    fGC+=
"std::vector<float> Node_Global_Aggregate = ";
 
  258    fGC += 
"std::vector<float> Global_Data = ";
 
  259    fGC += 
globals_update_block->Generate({
"Edge_Global_Aggregate.data()",
"Node_Global_Aggregate.data()", 
"input_graph.global_data.GetData()"});
 
  261        fGC += 
"\n//  resize global graph data since output feature size is not equal to input size\n";
 
  262        fGC+=
"input_graph.global_data = input_graph.global_data.Resize({"+
g_size+
"});\n";
 
  264    fGC += 
"\nstd::copy(Global_Data.begin(), Global_Data.end(), input_graph.global_data.GetData());";
 
  268    fGC += (
"} //TMVA_SOFIE_" + 
fName + 
"\n");
 
  269    fGC += 
"\n#endif  // TMVA_SOFIE_" + 
hgname + 
"\n";