Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBatchGenerator_PyTorch.py
Go to the documentation of this file.
### \file
### \ingroup tutorial_tmva
### \notebook -nodraw
### Example of getting batches of events from a ROOT dataset into a basic
### PyTorch workflow.
###
### \macro_code
### \macro_output
### \author Dante Niewenhuis

import torch
import ROOT

# Input dataset, fetched over HTTP, and the column the generators treat as
# the label.
tree_name = "sig_tree"
file_name = "http://root.cern/files/Higgs_data.root"
target = "Type"

# Batching parameters handed to the generator factory below.
batch_size, chunk_size = 128, 5000

rdataframe = ROOT.RDataFrame(tree_name, file_name)

# Build a pair of generators over the dataframe that yield PyTorch tensors:
# one for training batches, one for validation batches
# (validation_split=0.3 requests a 30% validation share).
gen_train, gen_validation = ROOT.TMVA.Experimental.CreatePyTorchGenerators(
    rdataframe, batch_size, chunk_size, target=target, validation_split=0.3
)

# Feature columns the training generator supplies; their count sets the
# input width of the network.
input_columns = gen_train.train_columns
num_features = len(input_columns)


def calc_accuracy(targets, pred):
    """Return the fraction of rounded predictions that equal their targets.

    ``pred`` is rounded to the nearest integer and compared element-wise
    with ``targets``. The count of matching elements is divided by the
    length of ``pred`` along its first dimension. The result is a 0-d
    tensor, not a Python float.
    """
    matches = targets == pred.round()
    n_hits = matches.sum()
    return n_hits / pred.shape[0]


# Fully connected network: three Linear+Tanh hidden stages of width 300,
# followed by a single-unit Linear layer squashed through a Sigmoid.
# Layers are created in input-to-output order, so parameter initialisation
# consumes the RNG in the same sequence as a hand-written layer list.
model = torch.nn.Sequential(
    *[
        module
        for fan_in, fan_out in ((num_features, 300), (300, 300), (300, 300))
        for module in (torch.nn.Linear(fan_in, fan_out), torch.nn.Tanh())
    ],
    torch.nn.Linear(300, 1),
    torch.nn.Sigmoid(),
)

# Mean squared error between the sigmoid output and the target, optimised
# with momentum SGD.
loss_fn = torch.nn.MSELoss(reduction="mean")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

number_of_epochs = 2

for _ in range(number_of_epochs):
    #################################################################
    # Training
    #################################################################

    # Put the model in training mode. This is a no-op for this
    # Linear/Tanh/Sigmoid stack, but it is the correct habit if
    # dropout/batch-norm layers are ever added.
    model.train()
    for x_train, y_train in gen_train:
        # Forward pass and loss.
        pred = model(x_train)
        loss = loss_fn(pred, y_train)

        # Backward pass and parameter update; gradients are cleared first
        # because PyTorch accumulates them across backward() calls.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accuracy is a metric only, so it is computed on a detached
        # prediction to avoid extending the autograd graph.
        accuracy = calc_accuracy(y_train, pred.detach())
        print(f"Training => accuracy: {accuracy}")

    #################################################################
    # Validation
    #################################################################

    # Evaluate without recording autograd history: no gradients are needed
    # here, and building graphs for every validation batch wastes memory.
    model.eval()
    with torch.no_grad():
        for x_val, y_val in gen_validation:
            pred = model(x_val)
            accuracy = calc_accuracy(y_val, pred)
            print(f"Validation => accuracy: {accuracy}")
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree, ...