Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
tmva102_Testing.py
Go to the documentation of this file.
## \file
## \ingroup tutorial_tmva
## \notebook -nodraw
## This tutorial illustrates how you can test a trained BDT model using the fast
## tree inference engine offered by TMVA and external tools such as scikit-learn.
##
## \macro_code
## \macro_output
##
## \date August 2019
## \author Stefan Wunsch

import ROOT
import pickle  # NOTE(review): appears unused in this script

from tmva100_DataPreparation import variables  # NOTE(review): appears unused here
from tmva101_Training import load_data


# Load the test sample from the signal and background files.
# NOTE(review): load_data lives in tmva101_Training; presumably x is the feature
# matrix, y_true the signal/background labels and w the per-event weights --
# confirm against that module.
x, y_true, w = load_data("test_signal.root", "test_background.root")

# Load the trained model "myBDT" from tmva101.root into TMVA's fast tree
# inference engine (RBDT).
bdt = ROOT.TMVA.Experimental.RBDT[""]("myBDT", "tmva101.root")

# Make prediction: one BDT response per test event.
y_pred = bdt.Compute(x)

# Compute the weighted ROC curve and its area using scikit-learn.
from sklearn.metrics import roc_curve, auc
fpr, tpr, _ = roc_curve(y_true, y_pred, sample_weight=w)
# roc_curve already returns fpr in increasing order, so auc() needs no
# reordering. The old `reorder=True` argument was deprecated in scikit-learn
# 0.20 and removed in 0.22; passing it now raises a TypeError.
score = auc(fpr, tpr)

# Plot the ROC curve on a square canvas. Both axes are restricted to [0, 1],
# and the AUC is reported in the graph title.
c = ROOT.TCanvas("roc", "", 600, 600)

g = ROOT.TGraph(len(fpr), fpr, tpr)
g.SetTitle("AUC = {:.2f}".format(score))
g.SetLineColor(ROOT.kRed)
g.SetLineWidth(3)
g.Draw("AC")  # "A": draw axes around the graph, "C": smooth curve through points

# Configure each axis after drawing: clip the range to [0, 1] and set its label.
axis_settings = (
    (g.GetXaxis(), "False-positive rate"),
    (g.GetYaxis(), "True-positive rate"),
)
for axis, title in axis_settings:
    axis.SetRangeUser(0, 1)
    axis.SetTitle(title)

c.Draw()