ROOT Reference Guide
 
RBatchGenerator_TensorFlow.py File Reference

Detailed Description


Example of getting batches of events from a ROOT dataset into a basic TensorFlow workflow.
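To make the roles of `batch_size`, `chunk_size` and `validation_split` concrete, here is a minimal pure-Python model of chunked batching. It is a sketch under stated assumptions, not the RBatchGenerator implementation: the function name `split_and_batch` is hypothetical, events are plain integers, nothing is shuffled, and the partial batch at the end of each chunk's training or validation part is emitted as-is. The real generator may shuffle and handle remainders differently.

```python
from typing import List, Sequence, Tuple

Batch = List[int]


def split_and_batch(
    events: Sequence[int],
    batch_size: int,
    chunk_size: int,
    validation_split: float,
) -> Tuple[List[Batch], List[Batch]]:
    """Hypothetical model of chunked batching (not the ROOT implementation).

    The dataset is read chunk_size events at a time; each chunk is split into
    a training part and a validation part, and each part is cut into batches
    of at most batch_size events. No shuffling is done.
    """
    train_batches: List[Batch] = []
    valid_batches: List[Batch] = []
    for start in range(0, len(events), chunk_size):
        chunk = list(events[start:start + chunk_size])
        # The last validation_split share of the chunk goes to validation
        n_valid = round(len(chunk) * validation_split)
        n_train = len(chunk) - n_valid
        parts = ((chunk[:n_train], train_batches), (chunk[n_train:], valid_batches))
        for part, out in parts:
            # Assumption: a partial batch at the end of each part is kept as-is
            for b in range(0, len(part), batch_size):
                out.append(part[b:b + batch_size])
    return train_batches, valid_batches


train, valid = split_and_batch(list(range(10)), batch_size=2, chunk_size=5, validation_split=0.4)
print(train)  # [[0, 1], [2], [5, 6], [7]]
print(valid)  # [[3, 4], [8, 9]]
```

Splitting inside each chunk, rather than once over the whole dataset, lets training and validation batches be produced while only one chunk is held in memory, which is the motivation for having a `chunk_size` separate from `batch_size`.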

import tensorflow as tf
import ROOT

tree_name = "sig_tree"
file_name = str(ROOT.gROOT.GetTutorialDir()) + "/tmva/data/Higgs_data.root"

batch_size = 128
chunk_size = 5_000

target = "Type"

# Returns two tf.data.Dataset objects for training and validation batches.
ds_train, ds_valid = ROOT.TMVA.Experimental.CreateTFDatasets(
    tree_name,
    file_name,
    batch_size,
    chunk_size,
    validation_split=0.3,
    target=target,
)

# Get a list of the columns used for training
input_columns = ds_train.train_columns
num_features = len(input_columns)

##############################################################################
# AI example
##############################################################################

# Define TensorFlow model
model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(
            300, activation=tf.nn.tanh, input_shape=(num_features,)
        ),  # input shape required
        # Single sigmoid output unit for the binary target
        tf.keras.layers.Dense(1, activation=tf.nn.sigmoid),
    ]
)

# Binary cross-entropy matches the sigmoid output above
loss_fn = tf.keras.losses.BinaryCrossentropy()

model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])

# Train model
model.fit(ds_train, validation_data=ds_valid, epochs=2)
Epoch 1/2
1/Unknown - 4s 4s/step - loss: 0.6381 - accuracy: 0.7656
21/Unknown - 4s 3ms/step - loss: 0.0356 - accuracy: 0.9888
43/Unknown - 4s 2ms/step - loss: 0.0174 - accuracy: 0.9945
54/54 [==============================] - 5s 6ms/step - loss: 0.0139 - accuracy: 0.9957 - val_loss: 5.9185e-07 - val_accuracy: 1.0000
Epoch 2/2
1/54 [..............................] - ETA: 1s - loss: 6.1846e-07 - accuracy: 1.0000
23/54 [===========>..................] - ETA: 0s - loss: 5.8334e-07 - accuracy: 1.0000
45/54 [========================>.....] - ETA: 0s - loss: 5.5920e-07 - accuracy: 1.0000
54/54 [==============================] - 0s 3ms/step - loss: 5.5601e-07 - accuracy: 1.0000 - val_loss: 5.6220e-07 - val_accuracy: 1.0000
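In the log above, the first epoch reports `1/Unknown`, `21/Unknown`, and so on because Keras does not know how many batches the dataset yields until it has completed one full pass; from the second epoch on it reports `54/54`. Each step consumes one batch, so the step count follows from the number of training events and `batch_size`. The helper `steps_per_epoch` below is a hypothetical illustration of that arithmetic. Its `drop_remainder` flag mirrors the option of the same name on `tf.data.Dataset.batch`; whether the ROOT generator keeps a final partial batch is an assumption here.

```python
def steps_per_epoch(n_events: int, batch_size: int, drop_remainder: bool = False) -> int:
    """Hypothetical helper: number of batches (optimizer steps) per epoch.

    Assumes every batch is full except possibly the last one, which is
    dropped when drop_remainder is True and kept otherwise.
    """
    if drop_remainder:
        return n_events // batch_size
    # Integer ceiling division: a final partial batch still counts as one step
    return (n_events + batch_size - 1) // batch_size


print(steps_per_epoch(6912, 128))  # 54
print(steps_per_epoch(6785, 128))  # 54
print(steps_per_epoch(6784, 128))  # 53
```

Under that assumption, 54 steps with `batch_size = 128` is consistent with anywhere from 6,785 to 6,912 training events. If batches at chunk boundaries can also be partial, the same step count corresponds to fewer events.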
Author
Dante Niewenhuis

Definition in file RBatchGenerator_TensorFlow.py.