ROOT Reference Guide
 
GenerateModel.py File Reference

Namespaces

namespace  GenerateModel
 

Detailed Description

This tutorial shows how to define and generate a Keras model for use with TMVA.
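The script that follows hard-codes the settings for binary classification. As a minimal sketch, the per-task recommendations from its NOTE comments (output nodes, output activation, loss) can be collected in one place. The helper name `output_settings` and its `task` strings are hypothetical and are not part of TMVA or Keras.

```python
def output_settings(task, num_classes=2):
    """Return (num_output_nodes, output_activation, loss) for a task.

    Hypothetical helper: it only encodes the recommendations written as
    NOTE comments in GenerateModel.py; it is not a TMVA or Keras API.
    """
    if task == "binary":
        # Binary classification: 2 output nodes with 'softmax'
        return 2, "softmax", "categorical_crossentropy"
    if task == "multiclass":
        # Multiclass classification: one 'softmax' output node per class
        if num_classes < 2:
            raise ValueError("multiclass classification needs at least 2 classes")
        return num_classes, "softmax", "categorical_crossentropy"
    if task == "regression":
        # Regression: a single output node, 'linear' activation recommended
        return 1, "linear", "mean_squared_error"
    raise ValueError(f"unknown task: {task!r}")


num_output_nodes, activation, loss = output_settings("binary")
print(num_output_nodes, activation, loss)  # 2 softmax categorical_crossentropy
```

The three returned values correspond to `num_output_nodes`, the `activation` argument of the output `Dense` layer, and the `loss` argument of `model.compile` in the script.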

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.utils import plot_model
# Setup the model here
num_input_nodes = 4
num_output_nodes = 2
num_hidden_layers = 1
nodes_hidden_layer = 64
l2_val = 1e-5
model = Sequential()
# Hidden layer 1
# NOTE: The number of input nodes needs to be defined in this layer
model.add(Dense(nodes_hidden_layer, activation='relu', kernel_regularizer=l2(l2_val), input_dim=num_input_nodes))
# Hidden layers 2 to num_hidden_layers
# NOTE: The layers added here can be configured freely
for _ in range(num_hidden_layers - 1):
    model.add(Dense(nodes_hidden_layer, activation='relu', kernel_regularizer=l2(l2_val)))
# Output layer
# NOTE: Use the following output configurations for the different tasks:
# Binary classification: 2 output nodes with 'softmax' activation
# Regression: 1 output node with any activation ('linear' recommended)
# Multiclass classification: one output node per class with 'softmax' activation
model.add(Dense(num_output_nodes, activation='softmax'))
# Compile model
# NOTE: Use the following settings for the different tasks:
# Any classification: 'categorical_crossentropy' is the recommended loss function
# Regression: 'mean_squared_error' is the recommended loss function
model.compile(loss='categorical_crossentropy', optimizer=SGD(learning_rate=0.01), metrics=['accuracy',])
# Save model
model.save('model.h5')
# Additional information about the model
# NOTE: This is not needed to run the model
# Print summary
model.summary()
# Visualize model as graph
try:
    plot_model(model, to_file='model.png', show_shapes=True)
except Exception:
    # plot_model needs the optional pydot and graphviz packages
    print('[INFO] Failed to make model plot')
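To make the architecture concrete, here is a framework-free NumPy sketch of what the compiled model computes with the script's default settings: ReLU hidden layers, a softmax output, categorical cross-entropy, and the L2 kernel penalty on the hidden layers only (the output layer in the script has no regularizer). The weights are random stand-ins, not a trained or reloaded Keras model, and the names `init_layer`, `forward` and `loss` are illustrative assumptions rather than TMVA or Keras API.

```python
import numpy as np

# Settings copied from the tutorial script
num_input_nodes = 4
num_output_nodes = 2
num_hidden_layers = 1
nodes_hidden_layer = 64
l2_val = 1e-5

rng = np.random.default_rng(0)

def init_layer(n_in, n_out):
    # Random stand-in weights for illustration; Keras defaults to
    # Glorot-uniform kernels and zero biases for Dense layers
    return rng.normal(scale=0.1, size=(n_in, n_out)), np.zeros(n_out)

# Same layer sizes as the Sequential model: input -> hidden layer(s) -> output
sizes = [num_input_nodes] + [nodes_hidden_layer] * num_hidden_layers + [num_output_nodes]
layers = [init_layer(a, b) for a, b in zip(sizes[:-1], sizes[1:])]

def forward(x):
    for w, b in layers[:-1]:
        x = np.maximum(x @ w + b, 0.0)        # Dense + 'relu'
    w, b = layers[-1]
    z = x @ w + b
    z = z - z.max(axis=1, keepdims=True)      # shift for a numerically stable softmax
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)   # 'softmax': each row sums to 1

def loss(x, y_onehot):
    p = forward(x)
    # Categorical cross-entropy averaged over the batch (small epsilon avoids log(0))
    ce = -np.mean(np.sum(y_onehot * np.log(p + 1e-12), axis=1))
    # l2(l2_val) adds l2_val * sum(W**2) for each regularized kernel (hidden layers only)
    penalty = sum(l2_val * np.sum(w ** 2) for w, _ in layers[:-1])
    return ce + penalty

num_params = sum(w.size + b.size for w, b in layers)
print(num_params)  # 450 = (4*64 + 64) + (64*2 + 2)

x = rng.normal(size=(8, num_input_nodes))
y = np.eye(num_output_nodes)[rng.integers(0, num_output_nodes, size=8)]
print(forward(x).shape, loss(x, y))
```

Each `Dense` layer holds `n_in * n_out` kernel weights plus `n_out` biases, so with these settings the count is (4·64 + 64) + (64·2 + 2) = 450, which should match the total parameter count printed by `model.summary()`.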
Date: 2017
Author: TMVA Team

Definition in file GenerateModel.py.