RBatchGenerator_NumPy.py
### \file
### \ingroup tutorial_tmva
### \notebook -nodraw
### Example of getting batches of events from a ROOT dataset as Python
### generators of numpy arrays.
###
### \macro_code
### \macro_output
### \author Dante Niewenhuis

import ROOT

tree_name = "sig_tree"
file_name = "http://root.cern/files/Higgs_data.root"

batch_size = 128
chunk_size = 5_000

rdataframe = ROOT.RDataFrame(tree_name, file_name)

gen_train, gen_validation = ROOT.TMVA.Experimental.CreateNumPyGenerators(
    rdataframe,
    batch_size,
    chunk_size,
    validation_split=0.3,
    shuffle=True,
    drop_remainder=False
)

# Loop through training set
for i, b in enumerate(gen_train):
    print(f"Training batch {i} => {b.shape}")


# Loop through Validation set
for i, b in enumerate(gen_validation):
    print(f"Validation batch {i} => {b.shape}")
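
The options passed to the generators above (batch_size, validation_split, shuffle, drop_remainder) can be illustrated without ROOT. The following is a minimal plain-NumPy sketch, not ROOT's implementation: the helper names numpy_batches and split_and_batch, the seed argument, and the toy events array are all hypothetical, and chunked loading (chunk_size) is not modelled at all.

```python
import numpy as np


def numpy_batches(data, batch_size, drop_remainder=False):
    """Yield consecutive row slices of `data` with at most `batch_size` rows."""
    n_rows = data.shape[0]
    for start in range(0, n_rows, batch_size):
        batch = data[start:start + batch_size]
        if drop_remainder and batch.shape[0] < batch_size:
            return  # skip the final, incomplete batch
        yield batch


def split_and_batch(data, batch_size, validation_split=0.3, shuffle=True,
                    drop_remainder=False, seed=0):
    """Hypothetical helper: optionally shuffle rows, split them, return two generators."""
    n_rows = data.shape[0]
    rng = np.random.default_rng(seed)
    indices = rng.permutation(n_rows) if shuffle else np.arange(n_rows)
    # Assumption: the validation set takes the first `validation_split` fraction of rows.
    n_validation = int(round(n_rows * validation_split))
    validation_rows = data[indices[:n_validation]]
    training_rows = data[indices[n_validation:]]
    return (numpy_batches(training_rows, batch_size, drop_remainder),
            numpy_batches(validation_rows, batch_size, drop_remainder))


# Toy stand-in for a dataset: 1000 events with 5 columns each.
events = np.arange(1000 * 5, dtype=np.float32).reshape(1000, 5)
toy_train, toy_validation = split_and_batch(events, batch_size=128)

train_shapes = [b.shape for b in toy_train]
validation_shapes = [b.shape for b in toy_validation]
print(train_shapes)
print(validation_shapes)
```

With drop_remainder=False the last batch of each set may hold fewer than batch_size rows, which is why the shape is printed per batch in the loops above; setting it to True would discard that short batch in this sketch.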