Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBatchGenerator_filters_vectors.py
Go to the documentation of this file.
##################################################
# Tutorial: applying filters and reading vector
# columns when using RBatchGenerator
##################################################

import ROOT


# Tree to read, and the ROOT file that holds it. The file path is built
# by appending a fixed relative path to the ROOT tutorial directory.
tree_name = "test_tree"
tutorial_dir = ROOT.gROOT.GetTutorialDir().Data()
file_name = (
    tutorial_dir + "/tmva/RBatchGenerator_filters_vectors_hvector.root"
)

chunk_size = 50  # Size of the chunks
batch_size = 5  # Size of the batches returned by the generators

# Filters are given as strings, one expression per entry.
filters = ["f1 > 30", "f2 < 70", "f3 == true"]

# Per-column sizes for f4, f5 and f6, forwarded as `max_vec_sizes`.
# NOTE(review): presumably the maximum number of elements kept for each
# vector column - confirm against the RBatchGenerator documentation.
max_vec_sizes = {"f4": 3, "f5": 2, "f6": 1}

# The call returns two generators: one for training, one for validation.
# NOTE(review): validation_split=0.3 presumably reserves 30% of the data
# for validation - confirm against the RBatchGenerator documentation.
train_batches, validation_batches = (
    ROOT.TMVA.Experimental.CreateNumPyGenerators(
        tree_name,
        file_name,
        batch_size,
        chunk_size,
        filters=filters,
        max_vec_sizes=max_vec_sizes,
        validation_split=0.3,
        shuffle=True,
    )
)

print(f"Columns: {train_batches.columns}")

# Walk through every training batch and report its shape.
for index, batch in enumerate(train_batches):
    print(f"Training batch {index} => {batch.shape}")

# Same report for the validation batches.
for index, batch in enumerate(validation_batches):
    print(f"Validation batch {index} => {batch.shape}")