Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBatchGenerator_TensorFlow.py
Go to the documentation of this file.
### \file
### \ingroup tutorial_tmva
### \notebook -nodraw
### Example of getting batches of events from a ROOT dataset into a basic
### TensorFlow workflow.
###
### \macro_code
### \macro_output
### \author Dante Niewenhuis

import tensorflow as tf
import ROOT

tree_name = "sig_tree"
file_name = "http://root.cern/files/Higgs_data.root"

batch_size = 128
chunk_size = 5_000

rdataframe = ROOT.RDataFrame(tree_name, file_name)

# Column used as the training target (label).
target = "Type"

# Returns two TensorFlow datasets: one yielding training batches and one
# yielding validation batches.
ds_train, ds_valid = ROOT.TMVA.Experimental.CreateTFDatasets(
    rdataframe,
    batch_size,
    chunk_size,
    validation_split=0.3,
    target=target,
)

num_of_epochs = 2

# Datasets have to be repeated as many times as there are epochs, so that
# model.fit can keep drawing batches for every epoch.
ds_train_repeated = ds_train.repeat(num_of_epochs)
ds_valid_repeated = ds_valid.repeat(num_of_epochs)

# Number of batches per epoch must be given for model.fit: it is what splits
# the repeated datasets back into individual epochs.
train_batches_per_epoch = ds_train.number_of_batches
validation_batches_per_epoch = ds_valid.number_of_batches

# Get a list of the columns used for training
input_columns = ds_train.train_columns
num_features = len(input_columns)

##############################################################################
# AI example
##############################################################################

# Define TensorFlow model: three tanh hidden layers and a single sigmoid
# output unit, which pairs with the binary cross-entropy loss below.
model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(num_features,)),
        tf.keras.layers.Dense(300, activation=tf.nn.tanh),
        tf.keras.layers.Dense(300, activation=tf.nn.tanh),
        tf.keras.layers.Dense(300, activation=tf.nn.tanh),
        tf.keras.layers.Dense(1, activation=tf.nn.sigmoid),
    ]
)
loss_fn = tf.keras.losses.BinaryCrossentropy()
model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])

# Train model
model.fit(
    ds_train_repeated,
    steps_per_epoch=train_batches_per_epoch,
    validation_data=ds_valid_repeated,
    validation_steps=validation_batches_per_epoch,
    epochs=num_of_epochs,
)
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree, ...