Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA_SOFIE_Keras.py
Go to the documentation of this file.
### \file
### \ingroup tutorial_ml
### \notebook -nodraw
### This macro provides a simple example of parsing a Keras .keras file
### into an RModel object and then generating the .hxx header file for inference.
###
### \macro_code
### \macro_output
### \author Sanjiban Sengupta and Lorenzo Moneta


import ROOT

# Enable ROOT in batch mode (same effect as -nodraw)
ROOT.gROOT.SetBatch(True)

# -----------------------------------------------------------------------------
# Step 1: Create and train a simple Keras model
# -----------------------------------------------------------------------------

import numpy as np
from tensorflow.keras.layers import Activation, Dense, Input, Softmax
from tensorflow.keras.models import Model

# Small fully connected network: 4 input features -> 2 softmax outputs.
# The batch size is fixed to 2 in the Input layer, so the inference input
# built in Step 3 must have the matching shape (2, 4).
# (Named `inputs` rather than `input` to avoid shadowing the builtin.)
inputs = Input(shape=(4,), batch_size=2)
x = Dense(32)(inputs)
x = Activation('relu')(x)
x = Dense(16, activation='relu')(x)
x = Dense(8, activation='relu')(x)
x = Dense(2)(x)
output = Softmax()(x)
model = Model(inputs=inputs, outputs=output)

# Toy training data from a seeded NumPy generator, so the data is reproducible.
randomGenerator = np.random.RandomState(0)
x_train = randomGenerator.rand(4, 4)
y_train = randomGenerator.rand(4, 2)

model.compile(loss='mse', optimizer='adam')
model.fit(x_train, y_train, epochs=3, batch_size=2)
# This .keras file is the one parsed by SOFIE in Step 2.
model.save('KerasModel.keras')

44# -----------------------------------------------------------------------------
45# Step 2: Use TMVA::SOFIE to parse the ONNX model
46# -----------------------------------------------------------------------------
47
48import ROOT
49
50# Parse the ONNX model
51
52model = ROOT.TMVA.Experimental.SOFIE.PyKeras.Parse("KerasModel.keras")
53
54# Generate inference code
57#print generated code
58print("\n**************************************************")
59print(" Generated code")
60print("**************************************************\n")
62print("**************************************************\n\n\n")
63
64# Compile the generated code
65ROOT.gInterpreter.Declare('#include "KerasModel.hxx"')
66
67
# -----------------------------------------------------------------------------
# Step 3: Run inference
# -----------------------------------------------------------------------------

# Instantiate the SOFIE session class declared by the generated KerasModel.hxx
# header (namespace TMVA_SOFIE_<model name>, here TMVA_SOFIE_KerasModel).
session = ROOT.TMVA_SOFIE_KerasModel.Session()

# Input tensor: 2 rows x 4 features, matching the batch size (2) and input
# shape (4,) fixed in the Keras Input layer.
x = np.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], dtype=np.float32)

# Run inference
y = session.infer(x)

print("Inference output:", y)
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.