ml_dataloader_NumPy.py
### \file
### \ingroup tutorial_ml
### \notebook -nodraw
### Example of getting batches of events from a ROOT dataset as Python
### generators of NumPy arrays.
###
### \macro_code
### \macro_output
### \author Dante Niewenhuis

import ROOT

# Input tree and the tutorial data file that contains it
tree_name = "sig_tree"
file_name = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/Higgs_data.root"

# Number of events per batch
batch_size = 128

rdataframe = ROOT.RDataFrame(tree_name, file_name)

# Column passed to the loader as the training target
target = "Type"

num_of_epochs = 2

# [Line 24 of the original listing is missing from this page: it opens the
#  call below and assigns its result to `dl`. Restore it before running.]
    rdataframe,
    batch_size,
    target=target,
    shuffle=True,
    drop_remainder=True,
)

# Split into a training and a validation generator (30% held out)
gen_train, gen_validation = dl.train_test_split(test_size=0.3)

for epoch in range(num_of_epochs):
    # Loop through the training set
    for i, (x_train, y_train) in enumerate(gen_train.as_numpy()):
        print(f"Training batch {i + 1} => x: {x_train.shape}, y: {y_train.shape}")

    # Loop through the validation set
    for i, (x_validation, y_validation) in enumerate(gen_validation.as_numpy()):
        print(f"Validation batch {i + 1} => x: {x_validation.shape}, y: {y_validation.shape}")
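
The listing above relies on three ideas: shuffling, a train/validation split, and dropping the final incomplete batch. The sketch below illustrates them with plain NumPy on synthetic data. It is not ROOT's implementation; the `batches` helper, the synthetic `features`/`labels` arrays, and the fixed seeds are all hypothetical stand-ins chosen for illustration, and the array shapes ROOT yields may differ.

```python
import numpy as np


def batches(x, y, batch_size, shuffle=True, drop_remainder=True, seed=0):
    """Yield (x_batch, y_batch) pairs of NumPy arrays (hypothetical helper)."""
    n = len(x)
    # Visit the events in random order when shuffling is requested
    order = np.random.default_rng(seed).permutation(n) if shuffle else np.arange(n)
    # With drop_remainder, stop before the final, incomplete batch
    stop = n - n % batch_size if drop_remainder else n
    for start in range(0, stop, batch_size):
        idx = order[start:start + batch_size]
        yield x[idx], y[idx]


# Synthetic stand-in for the dataset: 1000 events, 4 features, binary label
rng = np.random.default_rng(42)
features = rng.normal(size=(1000, 4))
labels = rng.integers(0, 2, size=1000)

# Hold out 30% of the events for validation
n_val = int(0.3 * len(features))
x_val, y_val = features[:n_val], labels[:n_val]
x_train, y_train = features[n_val:], labels[n_val:]

train_shapes = [(xb.shape, yb.shape) for xb, yb in batches(x_train, y_train, 128)]
val_shapes = [(xb.shape, yb.shape) for xb, yb in batches(x_val, y_val, 128)]

for i, (x_shape, y_shape) in enumerate(train_shapes):
    print(f"Training batch {i + 1} => x: {x_shape}, y: {y_shape}")
for i, (x_shape, y_shape) in enumerate(val_shapes):
    print(f"Validation batch {i + 1} => x: {x_shape}, y: {y_shape}")
```

With 700 training events and a batch size of 128, this sketch yields five full batches and discards the remaining 60 events. Passing `drop_remainder=False` would instead yield a sixth, shorter batch.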