8namespace Experimental {
19 fName = std::move(other.fName);
32 fName = std::move(other.fName);
53 std::time_t ttime = std::time(0);
54 std::tm* gmt_time = std::gmtime(&ttime);
75 fGC +=
"\n\nnamespace Edge_Update{\nstruct Session {\n";
86 if(!edges_update_output_shape[1].isParam && edges_update_output_shape[1].dim != num_edge_features_input) {
92 fGC+=
"\n\nnamespace Node_Update{\nstruct Session {\n";
105 if(!nodes_update_output_shape[1].isParam && nodes_update_output_shape[1].dim != num_node_features_input) {
112 fGC+=
"\n\nnamespace Global_Update{\nstruct Session {\n";
124 if(!globals_update_output_shape[1].isParam && globals_update_output_shape[1].dim != num_global_features_input) {
135 fGC +=
"struct Session {\n";
136 fGC +=
"\n// Instantiating session objects for graph components\n";
139 fGC +=
"Edge_Update::Session edge_update;\n";
141 fGC +=
"std::vector<float> fEdgeInputs = std::vector<float>(" + std::to_string(
num_edges) +
"*" + std::to_string(num_edge_features_input) +
");\n";
145 fGC +=
"Node_Update::Session node_update;\n";
146 fGC +=
"std::vector<float> fNodeInputs = std::vector<float>(" + std::to_string(
num_nodes) +
"*" + std::to_string(num_node_features_input) +
");\n";
150 fGC +=
"Global_Update::Session global_update;\n\n";
154 fGC +=
"\nvoid infer(TMVA::Experimental::SOFIE::GNN_Data& input_graph){\n";
159 fGC +=
"\n// --- Edge Update ---\n";
161 std::string e_size_input = std::to_string(num_edge_features_input);
162 fGC +=
"size_t n_edges = input_graph.edge_data.GetShape()[0];\n";
163 fGC +=
"for (size_t k = 0; k < n_edges; k++) { \n";
164 fGC +=
" std::copy(input_graph.edge_data.GetData() + k * " + e_size_input +
165 ", input_graph.edge_data.GetData() + (k + 1) * " + e_size_input +
", fEdgeInputs.begin() + k * " +
166 e_size_input +
");\n";
169 fGC +=
"auto edgeUpdates = " +
edges_update_block->Generate({
"n_edges",
"fEdgeInputs.data()"}) +
"\n";
172 fGC +=
"\n// resize edge graph data since output feature size is not equal to input size\n";
173 fGC +=
"input_graph.edge_data = input_graph.edge_data.Resize({ n_edges, " +
177 fGC +=
"\nfor (size_t k = 0; k < n_edges; k++) { \n";
180 ",input_graph.edge_data.GetData() + k * " + std::to_string(
num_edge_features) +
");\n";
187 std::string n_size_input = std::to_string(num_node_features_input);
188 fGC +=
"\n// --- Node Update ---\n";
189 fGC +=
"size_t n_nodes = input_graph.node_data.GetShape()[0];\n";
190 fGC +=
"for (size_t k = 0; k < n_nodes; k++) { \n";
191 fGC +=
" std::copy(input_graph.node_data.GetData() + k * " + n_size_input +
192 ", input_graph.node_data.GetData() + (k + 1) * " + n_size_input +
", fNodeInputs.begin() + k * " +
193 n_size_input +
");\n";
196 fGC +=
"auto nodeUpdates = ";
201 fGC +=
"\n// resize node graph data since output feature size is not equal to input size\n";
202 fGC +=
"input_graph.node_data = input_graph.node_data.Resize({ n_nodes, " +
206 fGC +=
"\nfor (size_t k = 0; k < n_nodes; k++) { \n";
209 ",input_graph.node_data.GetData() + k * " + std::to_string(
num_node_features) +
");\n";
216 fGC +=
"\n// --- Global Update ---\n";
217 fGC +=
"std::vector<float> Global_Data = ";
222 fGC +=
"\n// resize global graph data since output feature size is not equal to input size\n";
223 fGC +=
"input_graph.global_data = input_graph.global_data.Resize({" + std::to_string(
num_global_features) +
227 fGC +=
"\nstd::copy(Global_Data.begin(), Global_Data.end(), input_graph.global_data.GetData());";
231 fGC += (
"}\n};\n} //TMVA_SOFIE_" +
fName +
"\n");
232 fGC +=
"\n#endif // TMVA_SOFIE_" + hgname +
"\n";
void GenerateHeaderInfo(std::string &hgname)
std::size_t num_global_features
RModel_GraphIndependent & operator=(RModel_GraphIndependent &&other)
std::size_t num_node_features
std::unique_ptr< RFunction_Update > edges_update_block
RModel_GraphIndependent()=default
Default constructor.
std::size_t num_edge_features
std::unique_ptr< RFunction_Update > globals_update_block
std::unique_ptr< RFunction_Update > nodes_update_block
create variable transformations
std::unique_ptr< RFunction_Update > nodes_update_block
std::vector< std::pair< int, int > > edges
std::unique_ptr< RFunction_Update > globals_update_block
std::unique_ptr< RFunction_Update > edges_update_block