ROOT Reference Guide
 
Loading...
Searching...
No Matches
RegressionPyTorch.py
Go to the documentation of this file.
1#!/usr/bin/env python
2## \file
3## \ingroup tutorial_tmva_pytorch
4## \notebook -nodraw
5## This tutorial shows how to do regression in TMVA with neural networks
6## trained with PyTorch.
7##
8## \macro_code
9##
10## \date 2020
11## \author Anirudh Dagar <anirudhdagar6@gmail.com> - IIT, Roorkee
12
13
14# PyTorch has to be imported before ROOT to avoid crashes because of clashing
15# std::regexp symbols that are exported by cppyy.
16# See also: https://github.com/wlav/cppyy/issues/227
17import torch
18from torch import nn
19
20from ROOT import TMVA, TFile, TTree, TCut
21from subprocess import call
22from os.path import isfile
23
24
# Setup TMVA: initialise the TMVA tools singleton and the embedded Python
# interpreter used by the Python-based TMVA methods (TMVA.Types.kPyTorch below).
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

# Create the factory without an output file since it is not needed.
factory = TMVA.Factory('TMVARegression',
                       '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Regression')


# Load data.
# Cache the remote file in the current directory so "CACHEREAD" can reuse it.
TFile.SetCacheFileDir(".")
data = TFile.Open("http://root.cern.ch/files/tmva_reg_example.root", "CACHEREAD")
# A failed open may come back from PyROOT as a null (falsy) TFile proxy rather
# than None, so test truthiness and also reject zombie (unusable) files.
if not data or data.IsZombie():
    raise FileNotFoundError("Input file cannot be downloaded - exit")

tree = data.Get('TreeR')

# Every branch except the regression target 'fvalue' becomes an input variable.
dataloader = TMVA.DataLoader('dataset')
for branch in tree.GetListOfBranches():
    name = branch.GetName()
    if name != 'fvalue':
        dataloader.AddVariable(name)
dataloader.AddTarget('fvalue')

dataloader.AddRegressionTree(tree, 1.0)
dataloader.PrepareTrainingAndTestTree(TCut(''),
                                      'nTrain_Regression=4000:SplitMode=Random:NormMode=NumEvents:!V')


# Generate model

# Define model: one hidden layer of 64 units with a tanh activation, mapping
# the input variables to a single regression output.
# The input width follows the variables registered on the DataLoader (every
# branch except the target) instead of being hard-coded.
num_inputs = dataloader.GetDataSetInfo().GetNVariables()
model = nn.Sequential()
model.add_module('linear_1', nn.Linear(in_features=num_inputs, out_features=64))
model.add_module('tanh', nn.Tanh())
model.add_module('linear_2', nn.Linear(in_features=64, out_features=1))


# Construct loss function and Optimizer.
# The optimizer is kept as a class; ``train`` instantiates it over the model
# parameters.
loss = torch.nn.MSELoss()
optimizer = torch.optim.SGD


# Define train function
def train(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):
    """Train ``model`` and print training/validation losses.

    Args:
        model: torch module to train; its parameters are updated in place.
        train_loader: iterable of ``(X, y)`` mini-batches used for training.
        val_loader: sized iterable of ``(X, y)`` mini-batches used for
            validation; ``len(val_loader)`` averages the validation loss.
        num_epochs: number of passes over ``train_loader``.
        batch_size: not used here; batching is done by the loaders.
        optimizer: optimizer *class* (e.g. ``torch.optim.SGD``), instantiated
            over ``model.parameters()`` with ``lr=0.01``.
        criterion: loss function called as ``criterion(output, y)``.
        save_best: optional callable ``save_best(model, curr_val, best_val)``
            returning the updated best validation loss; skipped when falsy.
        scheduler: pair ``(schedule, schedulerSteps)``. When ``schedule`` is
            truthy it is called after each epoch's training loop as
            ``schedule(trainer, epoch, schedulerSteps)``, where ``trainer`` is
            the optimizer instance.

    Returns:
        The trained ``model``.
    """
    trainer = optimizer(model.parameters(), lr=0.01)
    schedule, schedulerSteps = scheduler
    best_val = None

    for epoch in range(num_epochs):
        # Training loop (train mode).
        model.train()
        running_train_loss = 0.0
        for i, (X, y) in enumerate(train_loader):
            trainer.zero_grad()
            output = model(X)
            train_loss = criterion(output, y)
            train_loss.backward()
            trainer.step()

            # print train statistics
            running_train_loss += train_loss.item()
            if i % 32 == 31:  # print every 32 mini-batches
                print("[{}, {}] train loss: {:.3f}".format(epoch+1, i+1, running_train_loss / 32))
                running_train_loss = 0.0

        if schedule:
            # Pass the optimizer *instance*: a learning-rate schedule needs its
            # state (e.g. ``param_groups``), which the optimizer class lacks.
            schedule(trainer, epoch, schedulerSteps)

        # Validation loop (eval mode, no gradient tracking).
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for X, y in val_loader:
                output = model(X)
                val_loss = criterion(output, y)
                running_val_loss += val_loss.item()

        curr_val = running_val_loss / len(val_loader)
        if save_best:
            if best_val is None:
                best_val = curr_val
            best_val = save_best(model, curr_val, best_val)

        # print val statistics per epoch
        print("[{}] val loss: {:.3f}".format(epoch+1, curr_val))

    # Report num_epochs, not the loop variable, which is unbound when num_epochs == 0.
    print("Finished Training on {} Epochs!".format(num_epochs))

    return model


# Define predict function
def predict(model, test_X, batch_size=32):
    """Evaluate ``model`` on ``test_X`` in mini-batches.

    Args:
        model: torch module to evaluate; it is switched to eval mode.
        test_X: input rows, converted with ``torch.Tensor``.
        batch_size: number of rows per forward pass.

    Returns:
        NumPy array of the concatenated model outputs, in input order.
    """
    model.eval()

    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.Tensor(test_X)),
        batch_size=batch_size,
        shuffle=False,
    )

    # Forward passes only: no autograd graph is needed.
    with torch.no_grad():
        batch_outputs = [model(batch[0]) for batch in loader]

    return torch.cat(batch_outputs).numpy()


# Objects bundled for loading the model: the optimizer class, the loss and the
# train/predict functions defined above. Presumably read by the TMVA PyTorch
# method by this exact name and keys — keep both unchanged.
load_model_custom_objects = dict(
    optimizer=optimizer,
    criterion=loss,
    train_func=train,
    predict_func=predict,
)


# Store model to file: compile it to TorchScript, save it under the name that
# the PyTorch booking below refers to, and print the scripted module.
m = torch.jit.script(model)
torch.jit.save(m, "modelRegression.pt")
print(m)


# Book methods: the PyTorch network and gradient-boosted decision trees.
factory.BookMethod(
    dataloader,
    TMVA.Types.kPyTorch,
    'PyTorch',
    'H:!V:VarTransform=D,G:'
    'FilenameModel=modelRegression.pt:'
    'FilenameTrainedModel=trainedModelRegression.pt:'
    'NumEpochs=20:BatchSize=32',
)
factory.BookMethod(
    dataloader,
    TMVA.Types.kBDT,
    'BDTG',
    '!H:!V:VarTransform=D,G:'
    'NTrees=1000:BoostType=Grad:Shrinkage=0.1:'
    'UseBaggedBoost:BaggedSampleFraction=0.5:'
    'nCuts=20:MaxDepth=4',
)


# Run TMVA: train, test and evaluate every booked method, in that order.
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t format
A specialized string object used for TTree selections.
Definition TCut.h:25
static Bool_t SetCacheFileDir(ROOT::Internal::TStringView cacheDir, Bool_t operateDisconnected=kTRUE, Bool_t forceCacheread=kFALSE)
Definition TFile.h:323
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4053
This is the main MVA steering class.
Definition Factory.h:80
static void PyInitialize()
Initialize Python interpreter.
static Tools & Instance()
Definition Tools.cxx:71