Example of getting batches of events from a ROOT dataset as Python generators of NumPy arrays.
import ROOT
tree_name = "sig_tree"
file_name = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/Higgs_data.root"
batch_size = 128
target = "Type"
num_of_epochs = 2
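# Build the RDataFrame that the data loader reads from
rdataframe = ROOT.RDataFrame(tree_name, file_name)
dl = ROOT.Experimental.ML.RDataLoader(
    rdataframe,
    batch_size,
    target=target,
    shuffle=True,
    drop_remainder=True,
)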
gen_train, gen_validation = dl.train_test_split(test_size=0.3)
for epoch in range(num_of_epochs):
    # Loop over the training batches
    for i, (x_train, y_train) in enumerate(gen_train.as_numpy()):
        print(f"Training batch {i + 1} => x: {x_train.shape}, y: {y_train.shape}")
    # Loop over the validation batches
    for i, (x_validation, y_validation) in enumerate(gen_validation.as_numpy()):
        print(f"Validation batch {i + 1} => x: {x_validation.shape}, y: {y_validation.shape}")
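
To make the batching options used above more concrete, the following is a minimal NumPy-only sketch of what shuffling, dropping the incomplete final batch, and a 70/30 split mean for the batches a loop receives. It is not how the ROOT data loader is implemented. The helper names `split_rows` and `batches`, the random stand-in data, the column count, and the `(batch_size, 1)` shape chosen for the target are all assumptions made for illustration.

```python
import numpy as np


def split_rows(data, test_size, rng):
    """Shuffle the rows once and split them into a training and a validation part."""
    indices = rng.permutation(len(data))
    n_test = int(round(len(data) * test_size))
    return data[indices[n_test:]], data[indices[:n_test]]


def batches(data, target_column, batch_size, shuffle, drop_remainder, rng):
    """Yield (x, y) NumPy batches; y is the target column, x is everything else."""
    order = rng.permutation(len(data)) if shuffle else np.arange(len(data))
    for start in range(0, len(data), batch_size):
        rows = data[order[start:start + batch_size]]
        if drop_remainder and len(rows) < batch_size:
            break  # discard the final, incomplete batch
        x = np.delete(rows, target_column, axis=1)
        y = rows[:, target_column:target_column + 1]
        yield x, y


rng = np.random.default_rng(seed=0)
# Hypothetical stand-in data: 1000 events, 7 feature columns plus 1 target column.
events = rng.normal(size=(1000, 8))
train, validation = split_rows(events, test_size=0.3, rng=rng)

train_shapes = [(x.shape, y.shape) for x, y in batches(train, 7, 128, True, True, rng)]
validation_shapes = [(x.shape, y.shape) for x, y in batches(validation, 7, 128, True, True, rng)]
print(len(train_shapes), len(validation_shapes))
```

In this sketch the 700 training rows give 5 full batches of 128 and the 300 validation rows give 2, so the leftover 60 and 44 rows are never seen within one pass. Because the generator reshuffles on each call, a different set of rows is dropped in each epoch.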
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree, ...
Author: Dante Niewenhuis
Definition in file ml_dataloader_NumPy.py.