Logo ROOT  
Reference Guide
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
Loading...
Searching...
No Matches
ApplicationClassificationPyTorch.py File Reference

Namespaces

namespace  ApplicationClassificationPyTorch
 

Detailed Description

View in nbviewer Open in SWAN
This tutorial shows how to apply a trained PyTorch model to new data with the TMVA Reader.

# PyTorch has to be imported before ROOT to avoid crashes because of clashing
# std::regexp symbols that are exported by cppyy.
# See also: https://github.com/wlav/cppyy/issues/227
import torch
from ROOT import TMVA, TFile, TString
from array import array
from subprocess import call
from os.path import isfile
# Setup TMVA: initialise TMVA's global tools and the PyMVA Python bridge
# explicitly before creating the reader that books the PyTorch method.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()
reader = TMVA.Reader("Color:!Silent")

# Load data, downloading the example file on first use.
# --fail makes curl exit non-zero on an HTTP error instead of saving the
# server's error page under the data file's name.
if not isfile('tmva_class_example.root'):
    status = call(['curl', '-L', '--fail', '-O',
                   'https://root.cern.ch/files/tmva_class_example.root'])
    if status != 0:
        raise RuntimeError(
            f'Downloading tmva_class_example.root failed (curl exit code {status})')
data = TFile.Open('tmva_class_example.root')
# TFile.Open yields a null (falsy) pointer when the file cannot be opened;
# stop here with a clear message instead of failing later on data.Get().
if not data:
    raise RuntimeError('Could not open tmva_class_example.root')
signal = data.Get('TreeS')
background = data.Get('TreeB')

# Bind one shared float buffer per branch to the reader's input variable and
# to the same branch of both trees, so a later GetEntry() on either tree
# fills exactly the values that reader.EvaluateMVA() reads.
branches = {}
for branch in signal.GetListOfBranches():
    branchName = branch.GetName()
    branches[branchName] = array('f', [-999])
    reader.AddVariable(branchName, branches[branchName])
    signal.SetBranchAddress(branchName, branches[branchName])
    background.SetBranchAddress(branchName, branches[branchName])
# Define predict function
def predict(model, test_X, batch_size=32):
    """Evaluate ``model`` on ``test_X`` and return its outputs as a NumPy array.

    The model is put into eval mode and run under ``torch.no_grad()`` in
    mini-batches of ``batch_size`` rows; row order is preserved.

    Args:
        model: a ``torch.nn.Module`` mapping a batch tensor to an output tensor.
        test_X: array-like of input rows, converted with ``torch.Tensor``.
        batch_size: number of rows per forward pass.

    Returns:
        numpy.ndarray: the batch outputs concatenated along the first axis.
        For empty input, the model's output for the empty tensor.
    """
    # Set to eval mode (note: this leaves the model in eval mode for the caller).
    model.eval()
    inputs = torch.Tensor(test_X)
    test_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(inputs),
        batch_size=batch_size,
        shuffle=False,  # keep predictions aligned with the input rows
    )
    with torch.no_grad():
        # Each batch from a single-tensor TensorDataset is a 1-element sequence.
        predictions = [model(X) for (X,) in test_loader]
        if not predictions:
            # No batches (empty input): torch.cat([]) would raise, so feed the
            # empty tensor through the model once to get an empty result that
            # keeps the model's output width.
            return model(inputs).numpy()
    return torch.cat(predictions).numpy()
# This module-level name is not referenced again in this script; it is
# presumably looked up by name by TMVA's PyTorch method when BookMVA loads
# the model (TODO confirm), so keep the name and keys unchanged. Only the
# prediction hook is supplied; the training-related entries are None.
load_model_custom_objects = {"optimizer": None, "criterion": None, "train_func": None, "predict_func": predict}

# Book methods (the weight file must already exist in the working directory).
weightFile = 'dataset/weights/TMVAClassification_PyTorch.weights.xml'
if not isfile(weightFile):
    raise FileNotFoundError(
        f'TMVA weight file not found: {weightFile} (train the PyTorch method first)')
reader.BookMVA('PyTorch', TString(weightFile))

# Print some example classifications for each sample. Stay within the tree:
# GetEntry() on a non-existent entry does not refill the branch buffers, so
# looping past the end would re-evaluate stale values.
n_examples = 20
for position, (label, tree) in enumerate((('signal', signal), ('background', background))):
    if position > 0:
        print('')  # blank line between the two groups
    print(f'Some {label} example classifications:')
    for i in range(min(n_examples, tree.GetEntries())):
        tree.GetEntry(i)
        print(reader.EvaluateMVA('PyTorch'))
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4075
static void PyInitialize()
Initialize Python interpreter.
The Reader class serves to use the MVAs in a specific analysis context.
Definition Reader.h:64
static Tools & Instance()
Definition Tools.cxx:71
Basic string class.
Definition TString.h:139
Date
2020
Author
Anirudh Dagar aniru.nosp@m.dhda.nosp@m.gar6@.nosp@m.gmai.nosp@m.l.com - IIT, Roorkee

Definition in file ApplicationClassificationPyTorch.py.