Logo ROOT  
Reference Guide
Loading...
Searching...
No Matches
MulticlassPyTorch.py
Go to the documentation of this file.
## \file
## \ingroup tutorial_tmva_pytorch
## \notebook -nodraw
## This tutorial shows how to do multiclass classification in TMVA with neural
## networks trained with PyTorch.
##
## \macro_code
##
## \date 2020
## \author Anirudh Dagar <anirudhdagar6@gmail.com> - IIT, Roorkee


# PyTorch has to be imported before ROOT to avoid crashes caused by clashing
# std::regexp symbols that are exported by cppyy.
# See also: https://github.com/wlav/cppyy/issues/227
16import torch
17from torch import nn
18
19from ROOT import TMVA, TFile, TCut, gROOT
20from os.path import isfile
21
22
# Setup TMVA
TMVA.Tools.Instance()
# Initialize the Python interpreter used by TMVA's Python-based methods
# (the kPyTorch method booked below runs the PyTorch model through it).
TMVA.PyMethodBase.PyInitialize()

# create factory without output file since it is not needed
factory = TMVA.Factory('TMVAClassification',
                       '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=multiclass')


# Load data
# Generate the example dataset with ROOT's tutorial macro if it is not present yet.
if not isfile('tmva_example_multiple_background.root'):
    createDataMacro = str(gROOT.GetTutorialDir()) + '/machine_learning/createData.C'
    print(createDataMacro)
    gROOT.ProcessLine('.L {}'.format(createDataMacro))
    gROOT.ProcessLine('create_MultipleBackground(4000)')

data = TFile.Open('tmva_example_multiple_background.root')
# TFile.Open gives back a null (falsy) object when the file cannot be opened,
# e.g. because the generator macro above failed; stop here with a clear
# message instead of failing later on the first data.Get call.
if not data or data.IsZombie():
    raise RuntimeError("Cannot open 'tmva_example_multiple_background.root'")
signal = data.Get('TreeS')
background0 = data.Get('TreeB0')
background1 = data.Get('TreeB1')
background2 = data.Get('TreeB2')

# Every branch of the signal tree becomes an input variable; the background
# trees are presumably written with the same branches -- TODO confirm.
dataloader = TMVA.DataLoader('dataset')
for branch in signal.GetListOfBranches():
    dataloader.AddVariable(branch.GetName())

# One class per tree: the signal plus three background classes.
dataloader.AddTree(signal, 'Signal')
dataloader.AddTree(background0, 'Background_0')
dataloader.AddTree(background1, 'Background_1')
dataloader.AddTree(background2, 'Background_2')
dataloader.PrepareTrainingAndTestTree(TCut(''),
                                      'SplitMode=Random:NormMode=NumEvents:!V')


# Generate model
# Define model
model = nn.Sequential()
model.add_module('linear_1', nn.Linear(in_features=4, out_features=32))
model.add_module('relu', nn.ReLU())
model.add_module('linear_2', nn.Linear(in_features=32, out_features=4))
# The softmax output hands per-class probabilities to TMVA.
model.add_module('softmax', nn.Softmax(dim=1))


# Set loss and optimizer
def _cross_entropy_from_probabilities(probabilities, target):
    """Cross-entropy loss for a model whose output is already softmax-normalized.

    ``nn.CrossEntropyLoss`` expects unnormalized logits and applies
    log-softmax itself, so feeding it the softmax output of ``model`` would
    normalize twice and flatten the gradients. Taking the log of the
    probabilities and applying NLL yields the intended cross-entropy.

    Args:
        probabilities: (batch, classes) tensor of class probabilities.
        target: (batch,) tensor of integer class indices.

    Returns:
        Scalar tensor with the mean cross-entropy over the batch.
    """
    # Clamp so an underflowed probability gives a large finite loss, not -inf.
    return nn.functional.nll_loss(torch.log(probabilities.clamp_min(1e-12)), target)


loss = _cross_entropy_from_probabilities
optimizer = torch.optim.SGD


# Define train function
def train(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):
    """Train ``model`` and print training/validation losses.

    Args:
        model: ``torch.nn.Module`` to train; it is updated in place.
        train_loader: iterable of ``(X, y)`` training mini-batches. Class
            indices are taken as ``torch.max(y, 1)[1]``, so ``y`` is presumably
            one-hot encoded.
        val_loader: iterable of ``(X, y)`` validation mini-batches in the
            same format; ``len(val_loader)`` is used to average the loss.
        num_epochs: number of passes over ``train_loader``.
        batch_size: not used here (batching is done by the loaders); kept
            for the calling interface.
        optimizer: optimizer *class* (e.g. ``torch.optim.SGD``); it is
            instantiated here with ``lr=0.01``.
        criterion: loss callable ``criterion(output, target_indices)``.
        save_best: optional callable ``save_best(model, curr_val, best_val)``
            whose return value becomes the new best validation loss; falsy
            to disable.
        scheduler: pair ``(schedule, schedulerSteps)``. When ``schedule`` is
            truthy it is called as ``schedule(trainer, epoch, schedulerSteps)``
            after each training epoch.

    Returns:
        The trained ``model``.
    """
    trainer = optimizer(model.parameters(), lr=0.01)
    schedule, schedulerSteps = scheduler
    best_val = None

    for epoch in range(num_epochs):
        # Training loop
        model.train()
        running_train_loss = 0.0
        for i, (X, y) in enumerate(train_loader):
            trainer.zero_grad()
            output = model(X)
            target = torch.max(y, 1)[1]
            train_loss = criterion(output, target)
            train_loss.backward()
            trainer.step()

            # print train statistics
            running_train_loss += train_loss.item()
            if i % 32 == 31:  # print every 32 mini-batches
                print("[{}, {}] train loss: {:.3f}".format(epoch+1, i+1, running_train_loss / 32))
                running_train_loss = 0.0

        if schedule:
            # Pass the optimizer *instance*: a learning-rate schedule edits its
            # param_groups, which the optimizer class does not have.
            schedule(trainer, epoch, schedulerSteps)

        # Validation loop
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for X, y in val_loader:
                output = model(X)
                target = torch.max(y, 1)[1]
                val_loss = criterion(output, target)
                running_val_loss += val_loss.item()

        curr_val = running_val_loss / len(val_loader)
        if save_best:
            if best_val is None:
                best_val = curr_val
            best_val = save_best(model, curr_val, best_val)

        # print val statistics per epoch
        print("[{}] val loss: {:.3f}".format(epoch+1, curr_val))

    # Use num_epochs rather than the loop variable, which is unbound when
    # num_epochs == 0.
    print("Finished Training on {} Epochs!".format(num_epochs))

    return model


# Define predict function
def predict(model, test_X, batch_size=32):
    """Run ``model`` over ``test_X`` in mini-batches and return the outputs.

    Args:
        model: ``torch.nn.Module`` to evaluate; it is switched to eval mode.
        test_X: array-like input, converted with ``torch.Tensor``.
        batch_size: number of rows fed to the model per forward pass.

    Returns:
        NumPy array with the model outputs for all rows, concatenated in
        input order.
    """
    model.eval()

    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.Tensor(test_X)),
        batch_size=batch_size,
        shuffle=False,  # keep outputs aligned with the input rows
    )

    with torch.no_grad():
        batch_outputs = [model(batch[0]) for batch in loader]

    return torch.cat(batch_outputs).numpy()


# Callables handed to TMVA's PyTorch method; presumably looked up by these
# exact key names, so keep them unchanged.
load_model_custom_objects = dict(
    optimizer=optimizer,
    criterion=loss,
    train_func=train,
    predict_func=predict,
)


# Store model to file
# Compile the model to TorchScript and write it where the kPyTorch booking
# expects it (FilenameModel=modelMultiClass.pt).
m = torch.jit.script(model)
m.save("modelMultiClass.pt")
print(m)


# Book methods: a Fisher discriminant as baseline, then the PyTorch model.
for method_type, method_name, method_options in (
        (TMVA.Types.kFisher, 'Fisher',
         '!H:!V:Fisher:VarTransform=D,G'),
        (TMVA.Types.kPyTorch, "PyTorch",
         'H:!V:VarTransform=D,G:FilenameModel=modelMultiClass.pt:FilenameTrainedModel=trainedModelMultiClass.pt:NumEpochs=20:BatchSize=32'),
):
    factory.BookMethod(dataloader, method_type, method_name, method_options)


# Run TMVA: train, then test, then evaluate all booked methods, in that order.
for run_step in (factory.TrainAllMethods,
                 factory.TestAllMethods,
                 factory.EvaluateAllMethods):
    run_step()

# Plot ROC Curves
roc = factory.GetROCCurve(dataloader)
roc.SaveAs('ROC_MulticlassPyTorch.png')
A specialized string object used for TTree selections.
Definition TCut.h:25
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:3787
This is the main MVA steering class.
Definition Factory.h:80
static void PyInitialize()
Initialize Python interpreter.
static Tools & Instance()
Definition Tools.cxx:72