ROOT
Reference Guide
 
ApplicationRegressionPyTorch.py File Reference

Detailed Description

This tutorial shows how to apply a trained model to new data (regression).

# PyTorch has to be imported before ROOT to avoid crashes because of clashing
# std::regexp symbols that are exported by cppyy.
# See also: https://github.com/wlav/cppyy/issues/227
import torch
from ROOT import TMVA, TFile, TString, gROOT
from array import array
from subprocess import call
from os.path import isfile
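# Setup TMVA (PyInitialize starts the Python interface used by PyMVA methods)
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()
reader = TMVA.Reader("Color:!Silent")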
# Load data
data = TFile.Open(str(gROOT.GetTutorialDir()) + "/tmva/data/tmva_reg_example.root")
tree = data.Get('TreeR')
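branches = {}
for branch in tree.GetListOfBranches():
    branchName = branch.GetName()
    # One-element float buffer that the tree fills on every GetEntry call
    branches[branchName] = array('f', [-999])
    tree.SetBranchAddress(branchName, branches[branchName])
    # 'fvalue' is the regression target, so it is not an input variable
    if branchName != 'fvalue':
        reader.AddVariable(branchName, branches[branchName])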
# Book methods
reader.BookMVA('PyTorch', TString('dataset/weights/TMVARegression_PyTorch.weights.xml'))
# Define predict function
def predict(model, test_X, batch_size=32):
    # Set to eval mode
    model.eval()
    # Wrap the input tensor so that it can be iterated in batches
    test_dataset = torch.utils.data.TensorDataset(test_X)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    predictions = []
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            X = data[0]
            outputs = model(X)
            predictions.append(outputs)
        preds = torch.cat(predictions)
    return preds.numpy()
load_model_custom_objects = {"optimizer": None, "criterion": None, "train_func": None, "predict_func": predict}
# Print some example regressions
print('Some example regressions:')
for i in range(20):
    # Load entry i into the branch buffers before evaluating the method
    tree.GetEntry(i)
    print('True/MVA value: {}/{}'.format(branches['fvalue'][0], reader.EvaluateMVA('PyTorch')))
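The listing binds each branch to a one-element array('f') and registers the same array object with the reader, so the reader only sees new input values once a tree entry has been loaded into those buffers. The following pure-Python sketch illustrates that shared-buffer pattern without ROOT. FakeTree and FakeReader are hypothetical stand-ins written for this illustration, not ROOT classes, the branch names and row values are made up, and the sum of the inputs is only a placeholder for the trained model.

```python
from array import array


class FakeTree:
    """Hypothetical stand-in for a TTree: copies one row into bound buffers."""

    def __init__(self, rows):
        self.rows = rows      # list of dicts: branch name -> value
        self.buffers = {}     # branch name -> bound array('f')

    def set_branch_address(self, name, buf):
        self.buffers[name] = buf

    def get_entry(self, i):
        # Overwrite the bound buffers in place; no new objects are created.
        for name, buf in self.buffers.items():
            buf[0] = self.rows[i][name]


class FakeReader:
    """Hypothetical stand-in for TMVA.Reader: reads inputs from bound buffers."""

    def __init__(self):
        self.inputs = []

    def add_variable(self, name, buf):
        self.inputs.append(buf)

    def evaluate(self):
        # Placeholder "model": the sum of the current input values.
        return sum(buf[0] for buf in self.inputs)


# Illustrative data; the values are exactly representable as 32-bit floats.
rows = [{"var1": 1.0, "var2": 2.0, "fvalue": 3.0},
        {"var1": 4.0, "var2": 0.5, "fvalue": 4.5}]
tree = FakeTree(rows)
reader = FakeReader()

branches = {}
for name in ("var1", "var2", "fvalue"):
    branches[name] = array('f', [-999])
    tree.set_branch_address(name, branches[name])
    if name != "fvalue":
        reader.add_variable(name, branches[name])

results = []
for i in range(len(rows)):
    tree.get_entry(i)  # without this call the buffers would still hold -999
    results.append((branches["fvalue"][0], reader.evaluate()))
    print("True/MVA value: {}/{}".format(*results[-1]))
```

Because the tree and the reader hold references to the same array objects, no values are passed explicitly between them; skipping the entry-loading call would evaluate the placeholder model on the initial -999 values for every iteration.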
Date
2020
Author
Anirudh Dagar <anirudhdagar6@gmail.com> - IIT, Roorkee

Definition in file ApplicationRegressionPyTorch.py.