Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBatchGenerator_TensorFlow.py
Go to the documentation of this file.
## \file
## \ingroup tutorial_ml
## \notebook -nodraw
## Example of getting batches of events from a ROOT dataset into a basic
## TensorFlow workflow.
##
## \macro_code
## \macro_output
## \author Dante Niewenhuis

import ROOT

# TensorFlow has to be imported after ROOT to avoid LLVM symbol clashes if ROOT
# was built with LLVM in Debug mode and TensorFlow>=2.20.0.
import tensorflow as tf

tree_name = "sig_tree"
file_name = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/Higgs_data.root"

# Batching parameters forwarded to CreateTFDatasets below.
# NOTE(review): chunk_size / block_size presumably control how many entries are
# loaded into memory at once and how they are grouped for shuffling -- confirm
# against the TMVA batch-generator documentation.
batch_size = 128
chunk_size = 5000
block_size = 300

rdataframe = ROOT.RDataFrame(tree_name, file_name)
# Column(s) used as the training label.
target = ["Type"]

# Returns two tf.data.Dataset objects: one yielding training batches and one
# yielding validation batches (validation_split=0.3 holds out 30% of the data).
ds_train, ds_valid = ROOT.TMVA.Experimental.CreateTFDatasets(
    rdataframe,
    batch_size,
    chunk_size,
    block_size,
    target=target,
    validation_split=0.3,
    shuffle=True,
    drop_remainder=True,
)

num_of_epochs = 2

# Datasets have to be repeated as many times as there are epochs
ds_train_repeated = ds_train.repeat(num_of_epochs)
ds_valid_repeated = ds_valid.repeat(num_of_epochs)

# Number of batches per epoch must be given for model.fit: the repeated
# datasets span all epochs, so Keras needs the per-epoch step counts to know
# where one epoch ends and the next begins.
train_batches_per_epoch = ds_train.number_of_batches
validation_batches_per_epoch = ds_valid.number_of_batches

# Get a list of the columns used for training
input_columns = ds_train.train_columns
num_features = len(input_columns)

##############################################################################
# AI example
##############################################################################

# Define TensorFlow model: three tanh hidden layers followed by a single
# sigmoid output unit, matching the binary cross-entropy loss below.
model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(num_features,)),
        tf.keras.layers.Dense(300, activation=tf.nn.tanh),
        tf.keras.layers.Dense(300, activation=tf.nn.tanh),
        tf.keras.layers.Dense(300, activation=tf.nn.tanh),
        tf.keras.layers.Dense(1, activation=tf.nn.sigmoid),
    ]
)

loss_fn = tf.keras.losses.BinaryCrossentropy()
model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])

model.fit(
    ds_train_repeated,
    steps_per_epoch=train_batches_per_epoch,
    validation_data=ds_valid_repeated,
    validation_steps=validation_batches_per_epoch,
    epochs=num_of_epochs,
)
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree ,...