Loading [MathJax]/extensions/tex2jax.js
Logo ROOT   6.08/07
Reference Guide
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
GenerateModel.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 from keras.models import Sequential
4 from keras.layers.core import Dense, Activation
5 from keras.regularizers import l2
6 from keras import initializations
7 from keras.optimizers import SGD
8 
# Model configuration -- edit these values to set up the network.
num_input_nodes = 4      # input size; passed as input_dim to the first hidden Dense layer
num_output_nodes = 2     # units in the final softmax Dense layer
num_hidden_layers = 1    # hidden Dense layers (the first one is always added)
nodes_hidden_layer = 64  # units in each hidden Dense layer
l2_val = 1.0e-5          # factor given to l2() for every hidden layer's W_regularizer
15 
# NOTE: You can either pass one of the predefined initializations (see the
# Keras documentation) or define your own initialization in a function like
# the one below.

def normal(shape, name=None):
    """Custom weight initializer used by every Dense layer in this script.

    Delegates to ``initializations.normal`` with a fixed ``scale`` of 0.05.

    Args:
        shape: shape of the weights to initialize (forwarded unchanged).
        name: optional name for the created weights (forwarded unchanged).

    Returns:
        The result of ``initializations.normal`` for these arguments.
    """
    init_scale = 0.05
    return initializations.normal(shape, name=name, scale=init_scale)
21 
def _hidden_dense(**extra_kwargs):
    """Build one ReLU hidden ``Dense`` layer with its own L2 weight regularizer.

    A new ``l2(l2_val)`` object is created on every call, so no regularizer
    instance is shared between layers. Extra keyword arguments (for example
    ``input_dim``) are forwarded to ``Dense``.
    """
    return Dense(nodes_hidden_layer, activation='relu', init=normal,
                 W_regularizer=l2(l2_val), **extra_kwargs)


model = Sequential()

# Hidden layer 1
# NOTE: The number of input nodes needs to be defined in this layer.
model.add(_hidden_dense(input_dim=num_input_nodes))

# Hidden layers 2 to num_hidden_layers
# NOTE: Here you can do whatever you want.
for _ in range(num_hidden_layers - 1):
    model.add(_hidden_dense())

# Output layer
# NOTE: Use the following output types for the different tasks:
#   Binary classification: 2 output nodes with 'softmax' activation
#   Regression: 1 output node with any activation ('linear' recommended)
#   Multiclass classification: (number of classes) output nodes with 'softmax' activation
model.add(Dense(num_output_nodes, activation='softmax', init=normal))

# Compile model
# NOTE: Use the following settings for the different tasks:
#   Any classification: 'categorical_crossentropy' is the recommended loss function
#   Regression: 'mean_squared_error' is the recommended loss function
model.compile(
    loss='categorical_crossentropy',
    optimizer=SGD(lr=0.01),
    metrics=['accuracy'],
)
45 
# Save model
model.save('model.h5')

# Additional information about the model
# NOTE: This is not needed to run the model.

# Print summary
model.summary()

# Visualize model as graph.
# This is best effort: the import of keras.utils.visualize_util or the plot
# call itself may fail in a given environment, and that must not abort the
# script (the model has already been saved above).
try:
    from keras.utils.visualize_util import plot
    plot(model, to_file='model.png', show_shapes=True)
except Exception as err:
    # Catch Exception instead of using a bare ``except:`` so that
    # KeyboardInterrupt/SystemExit still propagate. Also report why
    # plotting failed instead of hiding the cause.
    print('[INFO] Failed to make model plot: {}'.format(err))
def normal(shape, name=None)