Logo ROOT  
Reference Guide
Loading...
Searching...
No Matches
ml_dataloader_PyTorch.py
Go to the documentation of this file.
### \file
### \ingroup tutorial_ml
### \notebook -nodraw
### Example of getting batches of events from a ROOT dataset into a basic
### PyTorch workflow.
###
### \macro_code
### \macro_output
### \author Dante Niewenhuis

import ROOT
import torch

# Input dataset: the "sig_tree" tree of the Higgs sample shipped with the
# ROOT tutorials.
tree_name = "sig_tree"
file_name = f"{ROOT.gROOT.GetTutorialDir()!s}/machine_learning/data/Higgs_data.root"

# Name of the column the network learns to predict, and events per batch.
target = "Type"
batch_size = 128

rdataframe = ROOT.RDataFrame(tree_name, file_name)

# Wrap the RDataFrame in a data loader and split it into two generators that
# yield training and validation batches (served as PyTorch tensors via
# .as_torch()). test_size=0.3 presumably reserves ~30% of the events for
# validation -- confirm against the RDataLoader documentation.
dl = ROOT.Experimental.ML.RDataLoader(
    rdataframe,
    batch_size,
    target=target,
    shuffle=True,
    drop_remainder=True,
)
gen_train, gen_validation = dl.train_test_split(test_size=0.3)

# The network's input width is the number of columns used for training.
input_columns = gen_train.train_columns
num_features = len(input_columns)


def calc_accuracy(targets, pred):
    """Return the fraction of rounded predictions that equal their targets.

    Each value of ``pred`` is rounded to the nearest integer (0 or 1 for
    outputs in [0, 1], as produced by a final Sigmoid) and compared with the
    corresponding entry of ``targets``.

    Both tensors are flattened before the comparison. Without this, a
    column-vector ``(N, 1)`` prediction compared against a flat ``(N,)``
    target would broadcast to ``(N, N)`` and silently compare every target
    with every prediction, which can yield "accuracies" above 1.

    Args:
        targets: tensor of labels.
        pred: tensor of model outputs with the same number of elements.

    Returns:
        A 0-d floating-point tensor holding the accuracy (NaN for empty input).

    Raises:
        ValueError: if ``targets`` and ``pred`` have different element counts.
    """
    if targets.numel() != pred.numel():
        raise ValueError(
            f"targets has {targets.numel()} elements but pred has {pred.numel()}"
        )
    matches = targets.reshape(-1) == pred.round().reshape(-1)
    return torch.sum(matches) / matches.numel()


def _make_network(n_inputs, width=300):
    """Build the fully connected network: three tanh hidden layers, one sigmoid output."""
    # Layers are constructed left to right, so weight initialisation happens
    # in the same order as listing them one per line.
    return torch.nn.Sequential(
        torch.nn.Linear(n_inputs, width), torch.nn.Tanh(),
        torch.nn.Linear(width, width), torch.nn.Tanh(),
        torch.nn.Linear(width, width), torch.nn.Tanh(),
        torch.nn.Linear(width, 1), torch.nn.Sigmoid(),
    )


# Initialize the PyTorch model, its loss (mean squared error) and an SGD
# optimizer with momentum.
model = _make_network(num_features)
loss_fn = torch.nn.MSELoss(reduction="mean")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

number_of_epochs = 2

for epoch in range(number_of_epochs):
    print("Epoch ", epoch)

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    model.train()
    # Loop through the training set and train the model.
    for x_train, y_train in gen_train.as_torch():
        # Make a prediction and calculate the loss.
        # NOTE(review): MSELoss broadcasts (with a warning) if y_train's shape
        # differs from pred's (N, 1) -- confirm the loader yields matching shapes.
        pred = model(x_train)
        loss = loss_fn(pred, y_train)

        # Improve the model. The optimizer owns all of model.parameters(), so
        # optimizer.zero_grad() clears exactly the gradients it will apply.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accuracy of the predictions made before this update step.
        accuracy = calc_accuracy(y_train, pred)

        print(f"Training => accuracy: {accuracy}")

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------
    model.eval()
    # Gradients are not needed for evaluation; disabling autograd avoids
    # building a computation graph for every validation batch.
    with torch.no_grad():
        for x_val, y_val in gen_validation.as_torch():
            # Make a prediction and calculate the accuracy.
            pred = model(x_val)
            accuracy = calc_accuracy(y_val, pred)

            print(f"Validation => accuracy: {accuracy}")
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree, ...