RBatchGenerator_TensorFlow.py File Reference

Detailed Description

Example of getting batches of events from a ROOT dataset into a basic TensorFlow workflow.

import tensorflow as tf
import ROOT
tree_name = "sig_tree"
file_name = "http://root.cern/files/Higgs_data.root"
batch_size = 128
chunk_size = 5_000
rdataframe = ROOT.RDataFrame(tree_name, file_name)
target = "Type"
# Returns two tf.data.Dataset objects: one for training batches, one for validation batches.
ds_train, ds_valid = ROOT.TMVA.Experimental.CreateTFDatasets(
    rdataframe,
    batch_size,
    chunk_size,
    validation_split=0.3,
    target=target,
)
num_of_epochs = 2
# Datasets have to be repeated as many times as there are epochs
ds_train_repeated = ds_train.repeat(num_of_epochs)
ds_valid_repeated = ds_valid.repeat(num_of_epochs)
# Number of batches per epoch must be given for model.fit
train_batches_per_epoch = ds_train.number_of_batches
validation_batches_per_epoch = ds_valid.number_of_batches
# Get a list of the columns used for training
input_columns = ds_train.train_columns
num_features = len(input_columns)
##############################################################################
# AI example
##############################################################################
# Define TensorFlow model
model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(num_features,)),
        tf.keras.layers.Dense(300, activation=tf.nn.tanh),
        tf.keras.layers.Dense(300, activation=tf.nn.tanh),
        tf.keras.layers.Dense(300, activation=tf.nn.tanh),
        tf.keras.layers.Dense(1, activation=tf.nn.sigmoid),
    ]
)
loss_fn = tf.keras.losses.BinaryCrossentropy()
model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
# Train model
model.fit(
    ds_train_repeated,
    steps_per_epoch=train_batches_per_epoch,
    validation_data=ds_valid_repeated,
    validation_steps=validation_batches_per_epoch,
    epochs=num_of_epochs,
)
Epoch 1/2
 1/54 [..............................] - ETA: 2:58 - loss: 0.5529 - accuracy: 0.9297
11/54 [=====>........................] - ETA: 0s - loss: 0.0585 - accuracy: 0.9936
24/54 [============>.................] - ETA: 0s - loss: 0.0268 - accuracy: 0.9971
38/54 [====================>.........] - ETA: 0s - loss: 0.0169 - accuracy: 0.9981
51/54 [===========================>..] - ETA: 0s - loss: 0.0126 - accuracy: 0.9986
54/54 [==============================] - 4s 11ms/step - loss: 0.0119 - accuracy: 0.9987 - val_loss: 4.9476e-07 - val_accuracy: 1.0000
Epoch 2/2
 1/54 [..............................] - ETA: 0s - loss: 4.1275e-07 - accuracy: 1.0000
17/54 [========>.....................] - ETA: 0s - loss: 4.4795e-07 - accuracy: 1.0000
32/54 [================>.............] - ETA: 0s - loss: 4.4375e-07 - accuracy: 1.0000
48/54 [=========================>....] - ETA: 0s - loss: 4.4154e-07 - accuracy: 1.0000
54/54 [==============================] - 0s 9ms/step - loss: 4.4386e-07 - accuracy: 1.0000 - val_loss: 4.7763e-07 - val_accuracy: 1.0000
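
The listing repeats each dataset once per epoch and passes the number of batches per epoch to model.fit. The reason is easier to see in a framework-free sketch. The code below is not ROOT's or Keras's implementation. It is a toy model in plain Python, under the assumption that the training loop pulls a fixed number of batches per epoch from one continuous stream. The helper names (batches, repeat, fit) and the dataset size are hypothetical and chosen only for illustration.

```python
from itertools import chain, islice


def batches(n_events, batch_size):
    """Yield lists of event indices, one list per batch (a single pass)."""
    for start in range(0, n_events, batch_size):
        yield list(range(start, min(start + batch_size, n_events)))


def repeat(make_iter, times):
    """Chain `times` fresh passes over the data into one stream."""
    return chain.from_iterable(make_iter() for _ in range(times))


def fit(stream, steps_per_epoch, epochs):
    """Toy training loop: pull a fixed number of batches per epoch.

    Returns how many batches were actually available in each epoch.
    """
    seen = []
    for _ in range(epochs):
        epoch_batches = list(islice(stream, steps_per_epoch))
        seen.append(len(epoch_batches))
    return seen


# Made-up sizes for illustration; they are not taken from Higgs_data.root.
n_events, batch_size, epochs = 1_000, 128, 2
steps = -(-n_events // batch_size)  # ceiling division: batches in one pass

# Without repeating, the stream is exhausted after the first epoch.
single_pass = fit(batches(n_events, batch_size), steps, epochs)

# Repeating once per epoch supplies steps * epochs batches in total.
repeated = fit(repeat(lambda: batches(n_events, batch_size), epochs), steps, epochs)

print(single_pass)  # [8, 0]
print(repeated)     # [8, 8]
```

In this toy model the second epoch of the single-pass run receives no batches, which mirrors why the listing calls repeat(num_of_epochs) on both datasets before training. Note that the sketch keeps the final, smaller batch; whether the ROOT batch generator does the same is not stated here.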
Author
Dante Niewenhuis

Definition in file RBatchGenerator_TensorFlow.py.