ROOT Reference Guide
 
GenerateModel.py File Reference

Namespaces

namespace GenerateModel
 

Detailed Description

This tutorial shows how to define and generate a Keras model for use with TMVA.

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.utils import plot_model
# Setup the model here
num_input_nodes = 4
num_output_nodes = 2
num_hidden_layers = 1
nodes_hidden_layer = 64
l2_val = 1e-5
model = Sequential()
# Hidden layer 1
# NOTE: The number of input nodes needs to be defined in this layer
model.add(Dense(nodes_hidden_layer, activation='relu', kernel_regularizer=l2(l2_val), input_dim=num_input_nodes))
# Hidden layers 2 to num_hidden_layers
# NOTE: Further hidden layers can be configured freely here
for _ in range(num_hidden_layers - 1):
    model.add(Dense(nodes_hidden_layer, activation='relu', kernel_regularizer=l2(l2_val)))
# Output layer
# NOTE: Use the following output types for the different tasks
# Binary classification: 2 output nodes with 'softmax' activation
# Regression: 1 output with any activation ('linear' recommended)
# Multiclass classification: (number of classes) output nodes with 'softmax' activation
model.add(Dense(num_output_nodes, activation='softmax'))
# Compile model
# NOTE: Use the following settings for the different tasks
# Any classification: 'categorical_crossentropy' is the recommended loss function
# Regression: 'mean_squared_error' is the recommended loss function
model.compile(loss='categorical_crossentropy', optimizer=SGD(learning_rate=0.01), metrics=['accuracy',])
# Save model
model.save('model.h5')
# Additional information about the model
# NOTE: This is not needed to run the model
# Print summary
model.summary()
# Visualize model as graph
try:
    plot_model(model, to_file='model.png', show_shapes=True)
except Exception:
    # plot_model needs the optional pydot and graphviz packages
    print('[INFO] Failed to make model plot')
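A quick cross-check of the architecture: a Dense layer with n_in inputs and n_out units holds an n_in * n_out weight matrix plus n_out biases, so the trainable-parameter count printed by `model.summary()` follows directly from the settings at the top of the script. The sketch below is plain Python with no Keras dependency; the helper names `dense_params` and `count_params` are hypothetical. The `l2` regularizer adds no parameters; it only adds a penalty of `l2_val` times the sum of the squared kernel weights to the loss.

```python
def dense_params(n_in, n_out):
    """Trainable parameters of a fully connected layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out


def count_params(num_input_nodes, num_output_nodes, num_hidden_layers, nodes_hidden_layer):
    """Hypothetical helper mirroring the layer stack built in GenerateModel.py."""
    # First hidden layer maps the input variables to the hidden nodes.
    total = dense_params(num_input_nodes, nodes_hidden_layer)
    # Hidden layers 2..num_hidden_layers each map hidden -> hidden.
    total += (num_hidden_layers - 1) * dense_params(nodes_hidden_layer, nodes_hidden_layer)
    # Output layer maps the last hidden layer to the output nodes.
    total += dense_params(nodes_hidden_layer, num_output_nodes)
    return total


# Same settings as the script: 4 inputs, 2 outputs, 1 hidden layer of 64 nodes.
print(count_params(4, 2, 1, 64))  # → 450
```

With the default settings this gives (4 * 64 + 64) + (64 * 2 + 2) = 320 + 130 = 450 parameters, which is the total `model.summary()` should report for this model. Raising `num_hidden_layers` to 2 adds one 64-to-64 layer of 4160 parameters.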
Date: 2017
Author: TMVA Team

Definition in file GenerateModel.py.
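The output-layer and loss recommendations in GenerateModel.py (two 'softmax' output nodes with 'categorical_crossentropy' for binary classification) can be illustrated without Keras. The sketch below, in plain Python with hypothetical helper names `softmax` and `categorical_crossentropy_single`, computes the softmax probabilities of a 2-node output layer for one event and the cross-entropy against a one-hot label. The logits are made-up values, and the small floor `eps` inside the logarithm is an assumption that roughly mirrors the clipping Keras applies to avoid `log(0)`.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the largest logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_crossentropy_single(one_hot_target, probs, eps=1e-7):
    """Cross-entropy for one event: -sum_k y_k * log(p_k), with p_k floored at eps."""
    return -sum(y * math.log(max(p, eps)) for y, p in zip(one_hot_target, probs))


# Hypothetical raw outputs (logits) of the 2-node output layer for one event.
logits = [2.0, 0.5]
probs = softmax(logits)

# One-hot label: this event belongs to class 0 (for example, signal).
target = [1.0, 0.0]
loss = categorical_crossentropy_single(target, probs)

print(probs, loss)
```

For two output nodes, softmax reduces to a logistic sigmoid of the logit difference, p_0 = 1 / (1 + exp(-(z_0 - z_1))), so two 'softmax' nodes carry the same information as a single sigmoid node for binary classification.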