Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ml_dataloader_PyTorch.py
Go to the documentation of this file.
1### \file
2### \ingroup tutorial_ml
3### \notebook -nodraw
4### Example of getting batches of events from a ROOT dataset into a basic
5### PyTorch workflow.
6###
7### \macro_code
8### \macro_output
9### \author Dante Niewenhuis
10
11import ROOT
12import torch
13
14tree_name = "sig_tree"
15file_name = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/Higgs_data.root"
16
17batch_size = 128
18
19rdataframe = ROOT.RDataFrame(tree_name, file_name)
20
21target = "Type"
22
23# Returns two generators that return training and validation batches
24# as PyTorch tensors.
26 rdataframe,
27 batch_size,
28 target=target,
29 shuffle=True,
30 drop_remainder=True,
31)
32
33gen_train, gen_validation = dl.train_test_split(test_size=0.3)
34
35# Get a list of the columns used for training
36input_columns = gen_train.train_columns
37num_features = len(input_columns)
38
39
def calc_accuracy(targets, pred):
    """Return the fraction of rounded predictions that equal the targets.

    Args:
        targets: Tensor of reference labels.
        pred: Tensor of model outputs; each value is rounded to the
            nearest integer before being compared with ``targets``.

    Returns:
        A scalar tensor: the number of matching elements divided by
        ``pred.size(0)``.
    """
    predicted = pred.round()
    # Guard against silent broadcasting: comparing targets of shape (N,)
    # with predictions of shape (N, 1) would yield an (N, N) matrix and an
    # "accuracy" that can exceed 1. When both hold the same number of
    # elements, line the targets up with the predictions first. Any other
    # shape combination keeps the regular broadcasting behaviour.
    if targets.shape != predicted.shape and targets.numel() == predicted.numel():
        targets = targets.reshape(predicted.shape)
    return torch.sum(targets == predicted) / pred.size(0)
42
43
44# Initialize PyTorch model
46 torch.nn.Linear(num_features, 300),
48 torch.nn.Linear(300, 300),
50 torch.nn.Linear(300, 300),
52 torch.nn.Linear(300, 1),
54)
55loss_fn = torch.nn.MSELoss(reduction="mean")
56optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
57
58number_of_epochs = 2
59
60for i in range(number_of_epochs):
61 print("Epoch ", i)
63 # Loop through the training set and train model
64 for i, (x_train, y_train) in enumerate(gen_train.as_torch()):
65 # Make prediction and calculate loss
66 pred = model(x_train)
67 loss = loss_fn(pred, y_train)
68
69 # improve model
73
74 # Calculate accuracy
75 accuracy = calc_accuracy(y_train, pred)
76
77 print(f"Training => accuracy: {accuracy}")
78
79 # #################################################################
80 # # Validation
81 # #################################################################
82
84 # Evaluate the model on the validation set
85 for i, (x_val, y_val) in enumerate(gen_validation.as_torch()):
86 # Make prediction and calculate accuracy
87 pred = model(x_val)
88 accuracy = calc_accuracy(y_val, pred)
89
90 print(f"Validation => accuracy: {accuracy}")
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree, ...