Logo ROOT  
Reference Guide
Loading...
Searching...
No Matches
ml_dataloader_TensorFlow.py
Go to the documentation of this file.
1### \file
2### \ingroup tutorial_ml
3### \notebook -nodraw
4### Example of getting batches of events from a ROOT dataset into a basic
5### TensorFlow workflow.
6###
7### \macro_code
8### \macro_output
9### \author Dante Niewenhuis
10
11import ROOT
12
13# TensorFlow has to be imported after ROOT to avoid LLVM symbol clashes if ROOT
14# was built with LLVM in Debug mode and TensorFlow>=2.20.0.
15import tensorflow as tf
16
# Input data: the signal tree of the Higgs tutorial dataset shipped with ROOT.
tree_name = "sig_tree"
file_name = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/Higgs_data.root"

# Number of events per batch handed to TensorFlow.
batch_size = 128
# Presumably an approximate cap on how many batches the loader keeps in memory
# at once (inferred from the name) -- confirm against the RDataLoader docs.
approx_batches_in_memory = 50

rdataframe = ROOT.RDataFrame(tree_name, file_name)
# Column(s) passed to the loader as the training target.
target = ["Type"]

# Create ONE data loader over the RDataFrame. It does not yield datasets by
# itself: below it is split with train_test_split() into a training part and a
# validation part, and each part is turned into a tf.data.Dataset with
# as_tensorflow().
# NOTE(review): drop_remainder=True presumably discards a final incomplete
# batch -- confirm.
dl = ROOT.Experimental.ML.RDataLoader(
    rdataframe,
    batch_size,
    approx_batches_in_memory,
    target=target,
    shuffle=True,
    drop_remainder=True,
)
35
# Split the loader into training and validation datasets
# (test_size=0.3 presumably routes ~30% of the events to validation).
ds_train, ds_valid = dl.train_test_split(test_size=0.3)

num_of_epochs = 2

# Each dataset is converted to a tf.data.Dataset and repeated once per epoch,
# because the data has to be available for every epoch of training.
ds_train_repeated, ds_valid_repeated = (
    split.as_tensorflow().repeat(num_of_epochs) for split in (ds_train, ds_valid)
)

# model.fit has to be told how many batches make up one epoch.
train_batches_per_epoch, validation_batches_per_epoch = (
    ds_train.num_batches,
    ds_valid.num_batches,
)

# Columns used as training inputs; their count fixes the network's input width.
input_columns = ds_train.train_columns
num_features = len(input_columns)
51
##############################################################################
# Model definition and training
##############################################################################

# Fully connected binary classifier: three tanh hidden layers of 300 units
# each, followed by a single sigmoid output unit.
model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(num_features,)),
        *(tf.keras.layers.Dense(300, activation=tf.nn.tanh) for _ in range(3)),
        tf.keras.layers.Dense(1, activation=tf.nn.sigmoid),
    ]
)

# Binary cross-entropy on the sigmoid output, optimized with Adam.
loss_fn = tf.keras.losses.BinaryCrossentropy()
model.compile(loss=loss_fn, optimizer="adam", metrics=["accuracy"])

# The input datasets were repeated num_of_epochs times, so the number of
# batches per epoch is passed explicitly for both training and validation.
model.fit(
    ds_train_repeated,
    epochs=num_of_epochs,
    steps_per_epoch=train_batches_per_epoch,
    validation_data=ds_valid_repeated,
    validation_steps=validation_batches_per_epoch,
)
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree, ...