Logo ROOT  
Reference Guide
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
Loading...
Searching...
No Matches
ClassificationPyTorch.py
Go to the documentation of this file.
1## \file
2## \ingroup tutorial_tmva_pytorch
3## \notebook -nodraw
4## This tutorial shows how to do classification in TMVA with neural networks
5## trained with PyTorch.
6##
7## \macro_code
8##
9## \date 2020
10## \author Anirudh Dagar <anirudhdagar6@gmail.com> - IIT, Roorkee
11
12
13# PyTorch has to be imported before ROOT to avoid crashes because of clashing
14# std::regexp symbols that are exported by cppyy.
15# See also: https://github.com/wlav/cppyy/issues/227
16import torch
17from torch import nn
18
19from ROOT import TMVA, TFile, TCut, gROOT
20from subprocess import call
21from os.path import isfile
22
23
# Setup TMVA: initialise TMVA and the PyMVA (Python-method) interface that the
# PyTorch method booked later relies on.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

# create factory without output file since it is not needed
factory = TMVA.Factory('TMVAClassification',
                       '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Classification')
31
32
# Load data
data = TFile.Open(str(gROOT.GetTutorialDir()) + '/machine_learning/data/tmva_class_example.root')
# TFile.Open yields a null (falsy) object when the file cannot be opened;
# fail loudly here instead of with an obscure error on data.Get below.
if not data or data.IsZombie():
    raise FileNotFoundError('Cannot open machine_learning/data/tmva_class_example.root in the ROOT tutorials directory')
signal = data.Get('TreeS')
background = data.Get('TreeB')

# Register every branch of the signal tree as an input variable.
# NOTE(review): the network expects in_features=4, so presumably the tree has
# exactly 4 branches — confirm if the input file changes.
dataloader = TMVA.DataLoader('dataset')
for branch in signal.GetListOfBranches():
    dataloader.AddVariable(branch.GetName())

dataloader.AddSignalTree(signal, 1.0)
dataloader.AddBackgroundTree(background, 1.0)
# Empty selection cut; 4000 randomly chosen training events each for signal and background.
dataloader.PrepareTrainingAndTestTree(TCut(''),
                                      'nTrain_Signal=4000:nTrain_Background=4000:SplitMode=Random:NormMode=NumEvents:!V')
46
47
# Generate model

# Define model: 4 inputs -> 64 hidden units (ReLU) -> 2 outputs normalised by softmax.
# The ReLU is required: without a non-linearity between them the two Linear
# layers compose into a single linear map and the hidden layer adds nothing.
model = nn.Sequential()
model.add_module('linear_1', nn.Linear(in_features=4, out_features=64))
model.add_module('relu', nn.ReLU())
model.add_module('linear_2', nn.Linear(in_features=64, out_features=2))
model.add_module('softmax', nn.Softmax(dim=1))


# Construct loss function and Optimizer.
# The optimizer is kept as a class; train() instantiates it with the model's parameters.
loss = torch.nn.MSELoss()
optimizer = torch.optim.SGD
61
62
63# Define train function
def train(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):
    """Train ``model`` and report training and validation loss.

    Registered in ``load_model_custom_objects`` under ``"train_func"``;
    presumably TMVA's PyTorch method calls it with these keyword names —
    keep them unchanged.

    Args:
        model: ``torch.nn.Module`` whose parameters are optimised in place.
        train_loader: iterable of ``(X, y)`` mini-batches used for training.
        val_loader: sized iterable of ``(X, y)`` mini-batches used for validation.
        num_epochs: number of passes over ``train_loader``.
        batch_size: not used here; kept so the call signature is unchanged.
        optimizer: optimizer *class* (e.g. ``torch.optim.SGD``); it is
            instantiated here with ``lr=0.01``.
        criterion: callable ``criterion(output, y)`` returning a scalar loss tensor.
        save_best: optional callable ``save_best(model, curr_val, best_val)``
            that returns the updated best validation loss; falsy to skip.
        scheduler: pair ``(schedule, schedulerSteps)``; when ``schedule`` is
            truthy it is called once per epoch as
            ``schedule(optimizer_instance, epoch, schedulerSteps)``.

    Returns:
        The same ``model`` object, after training.
    """
    trainer = optimizer(model.parameters(), lr=0.01)
    schedule, schedulerSteps = scheduler
    best_val = None

    for epoch in range(num_epochs):
        # Training loop: switch the model to training mode.
        model.train()
        running_train_loss = 0.0
        running_val_loss = 0.0
        for i, (X, y) in enumerate(train_loader):
            trainer.zero_grad()
            output = model(X)
            train_loss = criterion(output, y)
            train_loss.backward()
            trainer.step()

            # print train statistics
            running_train_loss += train_loss.item()
            if i % 32 == 31:  # print every 32 mini-batches
                print("[{}, {}] train loss: {:.3f}".format(epoch+1, i+1, running_train_loss / 32))
                running_train_loss = 0.0

        if schedule:
            # Pass the optimizer *instance*: a learning-rate schedule typically
            # edits its param_groups, which the optimizer class does not have.
            schedule(trainer, epoch, schedulerSteps)

        # Validation loop: evaluation mode and no gradient tracking.
        model.eval()
        with torch.no_grad():
            for X, y in val_loader:
                output = model(X)
                val_loss = criterion(output, y)
                running_val_loss += val_loss.item()

        curr_val = running_val_loss / len(val_loader)
        if save_best:
            if best_val is None:
                best_val = curr_val
            best_val = save_best(model, curr_val, best_val)

        # print val statistics per epoch
        print("[{}] val loss: {:.3f}".format(epoch+1, curr_val))
        running_val_loss = 0.0

    # Report num_epochs rather than the loop variable, which is unbound when num_epochs == 0.
    print("Finished Training on {} Epochs!".format(num_epochs))

    return model
113
114
115# Define predict function
def predict(model, test_X, batch_size=32):
    """Run ``model`` over ``test_X`` in mini-batches and return its outputs.

    Registered in ``load_model_custom_objects`` under ``"predict_func"``.

    Args:
        model: ``torch.nn.Module``; it is left in evaluation mode afterwards.
        test_X: array-like of input rows, converted with ``torch.Tensor``.
        batch_size: number of rows per forward pass.

    Returns:
        NumPy array holding the model output for every input row, in input order.
    """
    model.eval()

    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.Tensor(test_X)),
        batch_size=batch_size,
        shuffle=False,
    )

    with torch.no_grad():
        # A single-tensor TensorDataset yields 1-element batches; unpack them.
        batch_outputs = [model(batch) for (batch,) in loader]
        stacked = torch.cat(batch_outputs)

    return stacked.numpy()
132
133
# Objects handed over to TMVA's PyTorch method.
# NOTE(review): presumably looked up by key on the TMVA side — keep the key names unchanged.
load_model_custom_objects = dict(
    optimizer=optimizer,
    criterion=loss,
    train_func=train,
    predict_func=predict,
)
135
136
# Store the model to file as TorchScript; this archive name is what the
# PyTorch method is booked with (FilenameModel=modelClassification.pt).
m = torch.jit.script(model)
m.save("modelClassification.pt")
print(m)
142
143
# Book methods: a Fisher discriminant as baseline and the scripted PyTorch network.
factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                   '!H:!V:Fisher:VarTransform=D,G')
factory.BookMethod(dataloader, TMVA.Types.kPyTorch, 'PyTorch',
                   'H:!V:VarTransform=D,G:FilenameModel=modelClassification.pt:FilenameTrainedModel=trainedModelClassification.pt:NumEpochs=20:BatchSize=32')


# Run training, test and evaluation; the ROC curve below needs their results.
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()


# Plot ROC Curves
roc = factory.GetROCCurve(dataloader)
roc.SaveAs('ROC_ClassificationPyTorch.png')
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t format
A specialized string object used for TTree selections.
Definition TCut.h:25
This is the main MVA steering class.
Definition Factory.h:80