Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBatchGenerator_NumPy.py
Go to the documentation of this file.
### \file
### \ingroup tutorial_tmva
### \notebook -nodraw
###
### Example of getting batches of events from a ROOT dataset as Python
### generators of numpy arrays.
###
### \macro_code
### \macro_output
### \author Dante Niewenhuis

import ROOT

# Source dataset: the TTree `sig_tree` inside a remotely hosted ROOT file.
tree_name = "sig_tree"
file_name = "http://root.cern/files/Higgs_data.root"

# Generator configuration, passed positionally to CreateNumPyGenerators.
# NOTE(review): chunk_size is presumably the number of entries loaded per
# chunk before being split into batches -- confirm against the ROOT docs.
batch_size = 128
chunk_size = 5_000

# The factory returns a pair of generators: one for training, one for
# validation. validation_split=0.3 presumably reserves 30% of the entries
# for the validation generator; shuffle=True requests shuffled batches.
generators = ROOT.TMVA.Experimental.CreateNumPyGenerators(
    tree_name,
    file_name,
    batch_size,
    chunk_size,
    validation_split=0.3,
    shuffle=True,
)
ds_train, ds_validation = generators

# Walk the training generator, reporting the shape of every batch it yields.
for batch_index, batch in enumerate(ds_train):
    print(f"Training batch {batch_index} => {batch.shape}")

# Then walk the validation generator the same way.
for batch_index, batch in enumerate(ds_validation):
    print(f"Validation batch {batch_index} => {batch.shape}")