Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBatchGenerator_PyTorch.py
Go to the documentation of this file.
1### \file
2### \ingroup tutorial_ml
3### \notebook -nodraw
4### Example of getting batches of events from a ROOT dataset into a basic
5### PyTorch workflow.
6###
7### \macro_code
8### \macro_output
9### \author Dante Niewenhuis
10
11import torch
12import ROOT
13
# Input dataset: tree name, the column passed to the generators as `target`,
# and the Higgs sample file located under the ROOT tutorial directory.
tree_name, target = "sig_tree", "Type"
file_name = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/Higgs_data.root"
rdataframe = ROOT.RDataFrame(tree_name, file_name)

# Sizes handed positionally to CreatePyTorchGenerators below.
# NOTE(review): presumably events per batch / per loaded chunk / per block;
# confirm the exact semantics against the ROOT documentation.
batch_size, chunk_size, block_size = 128, 5000, 300

# Build the two generators that return training and validation batches
# as PyTorch tensors.
gen_train, gen_validation = ROOT.TMVA.Experimental.CreatePyTorchGenerators(
    rdataframe, batch_size, chunk_size, block_size,
    target=target, validation_split=0.3, shuffle=True, drop_remainder=True,
)

# Columns used for training; their count sets the model's input width.
input_columns = gen_train.train_columns
num_features = len(input_columns)
41
42
def calc_accuracy(targets, pred):
    """Return the fraction of predictions that round to their target value.

    Args:
        targets: Tensor holding the true labels.
        pred: Tensor holding the model outputs. Each value is rounded to the
            nearest integer before comparison, and the size of its first
            dimension is used as the number of samples.

    Returns:
        A scalar tensor: number of matching entries divided by ``pred.size(0)``.
    """
    # A flat (N,) target compared against an (N, 1) prediction would silently
    # broadcast to an (N, N) comparison and inflate the match count. When both
    # tensors hold the same number of elements, align their layout first.
    if targets.shape != pred.shape and targets.numel() == pred.numel():
        targets = targets.reshape(pred.shape)

    return torch.sum(targets == pred.round()) / pred.size(0)
45
46
# Initialize PyTorch model: a fully connected network with one output value.
# NOTE(review): the extracted listing had lost the `torch.nn.Sequential(`
# opener and every layer between the Linear modules (listing lines 48, 50,
# 52, 54 and 56). The Tanh/Sigmoid activations below are a reconstruction;
# confirm them against the upstream tutorial.
model = torch.nn.Sequential(
    torch.nn.Linear(num_features, 300),
    torch.nn.Tanh(),
    torch.nn.Linear(300, 300),
    torch.nn.Tanh(),
    torch.nn.Linear(300, 300),
    torch.nn.Tanh(),
    torch.nn.Linear(300, 1),
    # Bounds the output to [0, 1] so that the round() in calc_accuracy
    # yields a 0/1 class label.
    torch.nn.Sigmoid(),
)
loss_fn = torch.nn.MSELoss(reduction="mean")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

number_of_epochs = 2

for epoch in range(number_of_epochs):
    print("Epoch ", epoch)

    # #################################################################
    # # Training
    # #################################################################

    # NOTE(review): this call was missing from the listing (line 65) and is
    # restored to mirror the model.eval() call below.
    model.train()
    # Loop through the training set and train model
    for x_train, y_train in gen_train:
        # Make prediction and calculate loss
        pred = model(x_train)
        loss = loss_fn(pred, y_train)

        # Improve model: clear stale gradients, backpropagate, update weights.
        # (These three steps were missing from the listing, lines 73-75.)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Calculate accuracy
        accuracy = calc_accuracy(y_train, pred)

        print(f"Training => accuracy: {accuracy}")

    # #################################################################
    # # Validation
    # #################################################################

    model.eval()
    # Evaluate the model on the validation set; no gradients are needed here,
    # so autograd bookkeeping is switched off.
    with torch.no_grad():
        for x_val, y_val in gen_validation:
            # Make prediction and calculate accuracy
            pred = model(x_val)
            accuracy = calc_accuracy(y_val, pred)

            print(f"Validation => accuracy: {accuracy}")
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree ,...