ROOT Reference Guide
 
RBatchGenerator_PyTorch.py File Reference

Detailed Description

Example of feeding batches of events from a ROOT dataset into a basic PyTorch workflow.
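In the tutorial, the batch generator does the batching. As a rough mental model only, the plain-Python sketch below shows what batch_size, validation_split, shuffle and drop_remainder mean for a small dataset held in memory. The function split_into_batches and the 1000-entry dataset are hypothetical and not part of ROOT. The sketch ignores chunk_size and block_size, which, judging by their names, control how much of the dataset the real generator loads and shuffles at a time, so ROOT's actual split and ordering may differ.

```python
import random


def split_into_batches(entries, batch_size, validation_split=0.3,
                       shuffle=True, drop_remainder=True, seed=None):
    """Hypothetical conceptual sketch, NOT ROOT's implementation.

    Splits `entries` into a training and a validation part and groups each
    part into batches of `batch_size`. chunk_size and block_size are ignored.
    """
    entries = list(entries)
    if shuffle:
        random.Random(seed).shuffle(entries)

    # The last `validation_split` fraction of the (shuffled) entries is held out
    n_validation = int(len(entries) * validation_split)
    n_train = len(entries) - n_validation
    train, validation = entries[:n_train], entries[n_train:]

    def to_batches(rows):
        batches = [rows[i:i + batch_size] for i in range(0, len(rows), batch_size)]
        # With drop_remainder, a final batch smaller than batch_size is discarded
        if drop_remainder and batches and len(batches[-1]) < batch_size:
            batches.pop()
        return batches

    return to_batches(train), to_batches(validation)


# Hypothetical dataset of 1000 entries: 700 for training, 300 for validation
train_batches, val_batches = split_into_batches(range(1000), batch_size=128, seed=0)
print(len(train_batches), len(val_batches))  # prints: 5 2
```

With drop_remainder=False, the leftover 60 training and 44 validation entries would instead form one extra, smaller batch in each split.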

import torch
import ROOT

tree_name = "sig_tree"
file_name = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/Higgs_data.root"

batch_size = 128
chunk_size = 5000
block_size = 300

rdataframe = ROOT.RDataFrame(tree_name, file_name)

target = "Type"

# Returns two generators that yield training and validation batches
# as PyTorch tensors.
gen_train, gen_validation = ROOT.TMVA.Experimental.CreatePyTorchGenerators(
    rdataframe,
    batch_size,
    chunk_size,
    block_size,
    target=target,
    validation_split=0.3,
    shuffle=True,
    drop_remainder=True,
)

# Get a list of the columns used for training
input_columns = gen_train.train_columns
num_features = len(input_columns)


def calc_accuracy(targets, pred):
    # Fraction of events whose prediction, rounded to the nearest integer, equals the target
    return torch.sum(targets == pred.round()) / pred.size(0)


# Initialize PyTorch model
model = torch.nn.Sequential(
    torch.nn.Linear(num_features, 300),
    torch.nn.Tanh(),
    torch.nn.Linear(300, 300),
    torch.nn.Tanh(),
    torch.nn.Linear(300, 300),
    torch.nn.Tanh(),
    torch.nn.Linear(300, 1),
    torch.nn.Sigmoid(),
)
loss_fn = torch.nn.MSELoss(reduction="mean")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

number_of_epochs = 2

for epoch in range(number_of_epochs):
    print(f"Epoch {epoch}")

    # Loop through the training set and train the model
    model.train()
    for x_train, y_train in gen_train:
        # Make prediction and calculate loss
        pred = model(x_train)
        loss = loss_fn(pred, y_train)

        # Improve the model: reset gradients, backpropagate, update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Calculate accuracy
        accuracy = calc_accuracy(y_train, pred)
        print(f"Training => accuracy: {accuracy}")

    #################################################################
    # Validation
    #################################################################

    # Evaluate the model on the validation set (no gradients needed)
    model.eval()
    with torch.no_grad():
        for x_val, y_val in gen_validation:
            # Make prediction and calculate accuracy
            pred = model(x_val)
            accuracy = calc_accuracy(y_val, pred)
            print(f"Validation => accuracy: {accuracy}")
Epoch 0
Training => accuracy: 0.3671875
Training => accuracy: 0.453125
Training => accuracy: 0.8359375
Training => accuracy: 0.953125
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Epoch 1
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Training => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
Validation => accuracy: 1.0
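Each accuracy line above is the fraction of events in one batch whose rounded prediction equals the target, as computed by calc_accuracy. Because batch_size = 128 and drop_remainder=True, every batch is full, so the first value, 0.3671875, corresponds to 47 of 128 events classified correctly. The torch-free sketch below, using a hypothetical helper batch_accuracy and made-up values, mirrors that arithmetic with plain lists so the rounding step is explicit. Like torch.round, Python's round() breaks ties by rounding half to even, so a prediction of exactly 0.5 counts as 0.

```python
def batch_accuracy(targets, preds):
    """Hypothetical plain-Python analogue of the tutorial's calc_accuracy.

    Rounds each prediction to the nearest integer and returns the fraction
    of entries that match the target. Both sequences must have equal length.
    """
    if len(targets) != len(preds):
        raise ValueError("targets and preds must have the same length")
    matches = sum(1 for t, p in zip(targets, preds) if t == round(p))
    return matches / len(preds)


targets = [1, 0, 1, 1]
preds = [0.9, 0.4, 0.2, 0.5]  # hypothetical sigmoid outputs
print(batch_accuracy(targets, preds))  # prints: 0.5
```

The length check stands in for a shape pitfall in the tensor version: if the targets had shape (N,) while pred had shape (N, 1), the == comparison in calc_accuracy would broadcast to an (N, N) tensor and return a misleading value, so both tensors should have the same shape.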
Author
Dante Niewenhuis

Definition in file RBatchGenerator_PyTorch.py.