13import graph_nets
as gn
14from graph_nets
import utils_tf
40 print(s,
"memory:",memoryUse,
"(MB)")
get_dynamic_graph_data_dict(NODE_FEATURE_SIZE=2, EDGE_FEATURE_SIZE=2, GLOBAL_FEATURE_SIZE=1):
get_fix_graph_data_dict(num_nodes, num_edges, NODE_FEATURE_SIZE=2, EDGE_FEATURE_SIZE=2, GLOBAL_FEATURE_SIZE=1):
71 snt.nets.MLP([LATENT_SIZE]*NUM_LAYERS, activate_final=
True),
77 def __init__(self, name="MLPGraphIndependent"):
80 edge_model_fn =
lambda:
snt.nets.MLP([LATENT_SIZE]*NUM_LAYERS, activate_final=
True),
81 node_model_fn =
lambda:
snt.nets.MLP([LATENT_SIZE]*NUM_LAYERS, activate_final=
True),
82 global_model_fn =
lambda:
snt.nets.MLP([LATENT_SIZE]*NUM_LAYERS, activate_final=
True))
92 edge_model_fn=make_mlp_model,
93 node_model_fn=make_mlp_model,
94 global_model_fn=make_mlp_model)
103 name="EncodeProcessDecode"):
110 def __call__(self, input_op, num_processing_steps):
114 for _
in range(num_processing_steps):
116 latent = self.
_core(core_input)
126printMemory(
"before instantiating")
128printMemory(
"after instantiating")
get_fix_graph_data_dict(num_max_nodes, num_max_edges, node_size, edge_size, global_size)
utils_tf.data_dicts_to_graphs_tuple([GraphData])
get_fix_graph_data_dict(num_max_nodes, num_max_edges, 2*LATENT_SIZE, 2*LATENT_SIZE, 2*LATENT_SIZE)
utils_tf.data_dicts_to_graphs_tuple([CoreGraphData])
get_fix_graph_data_dict(num_max_nodes, num_max_edges, LATENT_SIZE, LATENT_SIZE, LATENT_SIZE)
145printMemory(
"before first eval")
ep_model(input_graph_data, processing_steps)
147printMemory(
"after first eval")
ep_model._encoder._network, GraphData, filename = "encoder")
ep_model._core._network, CoreGraphData, filename = "core")
ep_model._decoder._network, DecodeGraphData, filename = "decoder")
ep_model._output_transform._network, DecodeGraphData, filename = "output_transform")
ROOT.TFile.Open("graph_data.root","RECREATE")
ROOT.TTree("gdata","GNN data")
ROOT.std.vector['float'](num_max_nodes*node_size)
ROOT.std.vector['float'](num_max_edges*edge_size)
ROOT.std.vector['float'](global_size)
ROOT.std.vector['int'](num_max_edges)
ROOT.std.vector['int'](num_max_edges)
ROOT.std.vector['float'](3)
186tree.Branch(
"node_data",
"std::vector<float>" , node_data)
187tree.Branch(
"edge_data",
"std::vector<float>" , edge_data)
188tree.Branch(
"global_data",
"std::vector<float>" , global_data)
189tree.Branch(
"receivers",
"std::vector<int>" , receivers)
190tree.Branch(
"senders",
"std::vector<int>" , senders)
193print(
"\n\nSaving data in a ROOT File:")
ROOT.TH1D("h1","GraphNet nodes output",40,1,0)
ROOT.TH1D("h2","GraphNet edges output",40,1,0)
ROOT.TH1D("h3","GraphNet global output",40,1,0)
198for i
in range(0,numevts):
get_dynamic_graph_data_dict(node_size, edge_size, global_size)
200 s_nodes = graphData[
'nodes'].size
201 s_edges = graphData[
'edges'].size
202 num_edges = graphData[
'edges'].shape[0]
reshape((graphData['nodes'].size)))
214 if (i < 1
and verbose) :
215 print(
"Nodes - shape:",
int(
node_data.size()/node_size),node_size,
"data: ",node_data)
216 print(
"Edges - shape:",num_edges, edge_size,
"data: ", edge_data)
217 print(
"Globals : ",global_data)
218 print(
"Receivers : ",receivers)
219 print(
"Senders : ",senders)
utils_tf.data_dicts_to_graphs_tuple([graphData])
230printMemory(
"before eval1")
ep_model(dataset[0], processing_steps)
232printMemory(
"after eval1")
236for tf_graph_data
in dataset:
237 output_gnn = ep_model(tf_graph_data, processing_steps)
241 outgnn[0] =
np.mean(output_nodes)
242 outgnn[1] =
np.mean(output_edges)
243 outgnn[2] =
np.mean(output_globals)
247 if (firstEvent
and verbose) :
248 print(
"Output of first event")
257print(
"time to evaluate events",end-start)
258printMemory(
"after eval Nevts")
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
__init__(self, name="EncodeProcessDecode")
__call__(self, input_op, num_processing_steps)
__init__(self, name="MLPGraphIndependent")
__init__(self, name="MLPGraphNetwork")
get_dynamic_graph_data_dict(NODE_FEATURE_SIZE=2, EDGE_FEATURE_SIZE=2, GLOBAL_FEATURE_SIZE=1)
get_fix_graph_data_dict(num_nodes, num_edges, NODE_FEATURE_SIZE=2, EDGE_FEATURE_SIZE=2, GLOBAL_FEATURE_SIZE=1)