Logo ROOT  
Reference Guide
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
Loading...
Searching...
No Matches
RegressionPyTorch.py
Go to the documentation of this file.
1#!/usr/bin/env python
2## \file
3## \ingroup tutorial_tmva_pytorch
4## \notebook -nodraw
5## This tutorial shows how to do regression in TMVA with neural networks
6## trained with PyTorch.
7##
8## \macro_code
9##
10## \date 2020
11## \author Anirudh Dagar <anirudhdagar6@gmail.com> - IIT, Roorkee
12
13
14# PyTorch has to be imported before ROOT to avoid crashes because of clashing
15# std::regexp symbols that are exported by cppyy.
16# See also: https://github.com/wlav/cppyy/issues/227
17import torch
18from torch import nn
19
20from ROOT import TMVA, TFile, TCut, gROOT
21from subprocess import call
22from os.path import isfile
23
24
25# Setup TMVA
# NOTE(review): listing lines 26-27 are missing from this extraction (numbering
# jumps from 25 to 28), so the statements under this heading are not shown.
# Restore them from the upstream tutorial before running this script.
28
29# create factory without output file since it is not needed
# 'AnalysisType=Regression' in the option string puts the factory in regression mode.
30factory = TMVA.Factory('TMVARegression',
31 '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Regression')
32
33
34# Load data
35
# Input file is read from the ROOT tutorials directory; 'TreeR' is the tree used below.
36data = TFile.Open(str(gROOT.GetTutorialDir()) + '/tmva/data/tmva_reg_example.root')
37tree = data.Get('TreeR')
38
39dataloader = TMVA.DataLoader('dataset')
# Iterate over every branch of the tree; 'fvalue' is singled out by the check
# below (presumably as the regression target) — TODO confirm against upstream.
40for branch in tree.GetListOfBranches():
41 name = branch.GetName()
42 if name != 'fvalue':
# NOTE(review): listing lines 43-44 and 46-47 are missing: the body of the `if`
# above and the start of the call whose trailing option string
# ('nTrain_Regression=...') appears below. As extracted, this fragment is not
# valid Python; restore the missing lines from the upstream tutorial.
45
48 'nTrain_Regression=4000:SplitMode=Random:NormMode=NumEvents:!V')
49
50
51# Generate model
52
53# Define model
54model = nn.Sequential()
# in_features=2 assumes the tree provides exactly two branches besides 'fvalue'
# — TODO confirm against the input file.
55model.add_module('linear_1', nn.Linear(in_features=2, out_features=64))
# NOTE(review): listing line 56 (between the two Linear layers) is missing; as
# shown, the network is two stacked Linear layers with nothing in between.
57model.add_module('linear_2', nn.Linear(in_features=64, out_features=1))
58
59
60# Construct loss function and Optimizer.
61loss = torch.nn.MSELoss()
# Stored as a class, not an instance: train() instantiates it with the model's
# parameters and lr=0.01.
62optimizer = torch.optim.SGD
63
64
65# Define train function
def train(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):
    """Train ``model`` for ``num_epochs`` epochs and return it.

    Each epoch does four things in order:

    1. One optimisation pass over ``train_loader``, printing the mean training
       loss every 32 mini-batches.
    2. An optional call to the learning-rate schedule.
    3. A pass over ``val_loader`` that computes the mean validation loss.
    4. A print of that loss.

    Args:
        model: ``torch.nn.Module`` to train; its parameters are updated in place.
        train_loader: Iterable of ``(X, y)`` mini-batches used for training.
        val_loader: Sized iterable of ``(X, y)`` mini-batches used for
            validation. ``len(val_loader)`` must be non-zero because the
            validation loss is averaged over it.
        num_epochs: Number of passes over ``train_loader``.
        batch_size: Not used here; batching is done by the loaders. Accepted for
            interface compatibility.
        optimizer: Optimizer *class* (e.g. ``torch.optim.SGD``). It is
            instantiated here with ``lr=0.01``.
        criterion: Loss callable ``criterion(output, y)`` returning a scalar
            tensor.
        save_best: Optional callable ``save_best(model, curr_val, best_val)``
            that returns the updated best validation loss. Pass a falsy value to
            disable it.
        scheduler: Pair ``(schedule, schedulerSteps)``. If ``schedule`` is
            truthy, it is called as
            ``schedule(optimizer_instance, epoch, schedulerSteps)`` after each
            training pass.

    Returns:
        The same ``model`` object, trained.
    """
    trainer = optimizer(model.parameters(), lr=0.01)
    schedule, schedulerSteps = scheduler
    best_val = None

    for epoch in range(num_epochs):
        # Training loop: train mode, since validation below switches to eval mode.
        model.train()
        running_train_loss = 0.0
        for i, (X, y) in enumerate(train_loader):
            trainer.zero_grad()  # drop gradients left over from the previous step
            output = model(X)
            train_loss = criterion(output, y)
            train_loss.backward()
            trainer.step()

            # print train statistics
            running_train_loss += train_loss.item()
            if i % 32 == 31:  # print every 32 mini-batches
                print("[{}, {}] train loss: {:.3f}".format(epoch + 1, i + 1, running_train_loss / 32))
                running_train_loss = 0.0

        if schedule:
            # Hand over the optimizer *instance*: the learning rate lives in its
            # param_groups, which the optimizer class itself does not have.
            schedule(trainer, epoch, schedulerSteps)

        # Validation loop: eval mode and no gradient tracking.
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for X, y in val_loader:
                output = model(X)
                val_loss = criterion(output, y)
                running_val_loss += val_loss.item()

        curr_val = running_val_loss / len(val_loader)
        if save_best:
            if best_val is None:
                best_val = curr_val
            best_val = save_best(model, curr_val, best_val)

        # print val statistics per epoch
        print("[{}] val loss: {:.3f}".format(epoch + 1, curr_val))

    # Use num_epochs rather than the loop variable, which is unbound when num_epochs == 0.
    print("Finished Training on {} Epochs!".format(num_epochs))

    return model
115
116
117# Define predict function
def predict(model, test_X, batch_size=32):
    """Return the model's predictions for ``test_X`` as a NumPy array.

    The network is switched to evaluation mode, and ``test_X`` is wrapped in a
    ``torch.Tensor``. The rows are fed to the model in their original order,
    ``batch_size`` at a time, with gradient tracking disabled. The per-batch
    outputs are concatenated along the first dimension.
    """
    model.eval()
    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.Tensor(test_X)),
        batch_size=batch_size,
        shuffle=False,  # keep input order so outputs line up with test_X rows
    )
    with torch.no_grad():
        # A single-tensor TensorDataset yields one-element batches; take the inputs.
        batch_outputs = [model(batch[0]) for batch in loader]
    return torch.cat(batch_outputs).numpy()
134
135
# Objects to be used when the stored model is loaded back (presumably looked
# up by TMVA's PyTorch method under these keys): the optimizer class, the loss
# criterion, and the custom train/predict functions.
load_model_custom_objects = dict(
    optimizer=optimizer,
    criterion=loss,
    train_func=train,
    predict_func=predict,
)


# Store model to file: compile the network to TorchScript and save it.
m = torch.jit.script(model)
torch.jit.save(m, "modelRegression.pt")
print(m)


# Book methods: the PyTorch network read from modelRegression.pt, plus a
# gradient-boosted BDT (BoostType=Grad) for comparison.
factory.BookMethod(
    dataloader,
    TMVA.Types.kPyTorch,
    'PyTorch',
    'H:!V:VarTransform=D,G:FilenameModel=modelRegression.pt:FilenameTrainedModel=trainedModelRegression.pt:NumEpochs=20:BatchSize=32',
)
factory.BookMethod(
    dataloader,
    TMVA.Types.kBDT,
    'BDTG',
    '!H:!V:VarTransform=D,G:NTrees=1000:BoostType=Grad:Shrinkage=0.1:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=4',
)
152
153# Run TMVA
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t format
A specialized string object used for TTree selections.
Definition TCut.h:25
This is the main MVA steering class.
Definition Factory.h:80