ROOT Reference Guide
 
GenerateModel.py File Reference

Detailed Description

This tutorial shows how to define and generate a Keras model for use with TMVA.

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.utils import plot_model
# Setup the model here
num_input_nodes = 4
num_output_nodes = 2
num_hidden_layers = 1
nodes_hidden_layer = 64
l2_val = 1e-5
model = Sequential()
# Hidden layer 1
# NOTE: The number of input nodes needs to be defined in this layer
model.add(Dense(nodes_hidden_layer, activation='relu', kernel_regularizer=l2(l2_val), input_dim=num_input_nodes))
# Hidden layer 2 to num_hidden_layers
# NOTE: Here, you can do what you want
for k in range(num_hidden_layers-1):
    model.add(Dense(nodes_hidden_layer, activation='relu', kernel_regularizer=l2(l2_val)))
# Output layer
# NOTE: Use the following output types for the different tasks
# Binary classification: 2 output nodes with 'softmax' activation
# Regression: 1 output with any activation ('linear' recommended)
# Multiclass classification: (number of classes) output nodes with 'softmax' activation
model.add(Dense(num_output_nodes, activation='softmax'))
# Compile model
# NOTE: Use the following settings for the different tasks
# Any classification: 'categorical_crossentropy' is the recommended loss function
# Regression: 'mean_squared_error' is the recommended loss function
model.compile(loss='categorical_crossentropy', optimizer=SGD(learning_rate=0.01), weighted_metrics=['accuracy',])
# Save model
model.save('model.h5')
# Additional information about the model
# NOTE: This is not needed to run the model
# Print summary
model.summary()
# Visualize model as graph
try:
    plot_model(model, to_file='model.png', show_shapes=True)
except Exception:
    print('[INFO] Failed to make model plot')
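
The output-layer notes in the script pair a 'softmax' activation with the 'categorical_crossentropy' loss, which suggests that class targets are one-hot encoded. The following is a minimal pure-Python sketch of that combination for a single event; it does not use Keras or TMVA. The helper names (one_hot, softmax, categorical_crossentropy), the example output values, and the choice of class index 1 are all illustrative assumptions, not part of the tutorial.

```python
import math

def one_hot(label, num_classes):
    """Return a one-hot list, e.g. label 1 of 2 classes -> [0.0, 1.0]."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]

def softmax(logits):
    """Numerically stable softmax over a list of raw output values."""
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def categorical_crossentropy(target, probs, eps=1e-12):
    """Cross-entropy between a one-hot target and predicted probabilities."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(target, probs))

# Hypothetical raw values of the two output nodes for one event
logits = [0.3, 1.2]
probs = softmax(logits)      # two probabilities that sum to 1
target = one_hot(1, 2)       # illustrative: the event belongs to class index 1
loss = categorical_crossentropy(target, probs)
print(probs, loss)
```

With a one-hot target, the loss reduces to the negative logarithm of the probability assigned to the true class, so it shrinks as that probability approaches 1.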
Date
2017
Author
TMVA Team

Definition in file GenerateModel.py.
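
As a sanity check on the model summary, the number of trainable parameters can be estimated by hand: a fully connected layer has one weight per input-unit pair plus one bias per unit. The sketch below applies that rule to the layer sizes set in the script. The function names are hypothetical, and the calculation assumes at least one hidden layer, as the script always adds one.

```python
def dense_param_count(num_inputs, num_units):
    """Weights plus biases of one fully connected layer."""
    return num_inputs * num_units + num_units

def model_param_count(num_input_nodes, nodes_hidden_layer,
                      num_hidden_layers, num_output_nodes):
    """Total parameters of the stacked dense layers (num_hidden_layers >= 1)."""
    total = 0
    fan_in = num_input_nodes
    for _ in range(num_hidden_layers):
        total += dense_param_count(fan_in, nodes_hidden_layer)
        fan_in = nodes_hidden_layer
    # Output layer is fed by the last hidden layer
    total += dense_param_count(fan_in, num_output_nodes)
    return total

# Values taken from the script: 4 inputs, 64 hidden nodes, 1 hidden layer, 2 outputs
default_count = model_param_count(4, 64, 1, 2)
print(default_count)  # 450
```

For the default settings this gives 320 parameters in the hidden layer and 130 in the output layer.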