Logo ROOT  
Reference Guide
Loading...
Searching...
No Matches
tmva102_Testing.py
Go to the documentation of this file.
## \file
## \ingroup tutorial_ml
## \notebook -nodraw
## This tutorial illustrates how you can test a trained BDT model using the fast
## tree inference engine offered by TMVA and external tools such as scikit-learn.
##
## \macro_code
## \macro_output
##
## \date August 2019
## \author Stefan Wunsch


import ROOT
from tmva101_Training import load_data

# Load the test data: inputs `x`, labels `y_true` and event weights `w`
x, y_true, w = load_data("test_signal.root", "test_background.root")

# Load trained model. The model file is presumably produced by the training
# tutorial; fail early with a clear message instead of an opaque error from
# the RBDT constructor when it is missing.
File = "tmva101.root"
# NOTE: TSystem::AccessPathName returns True when the path is NOT accessible.
if ROOT.gSystem.AccessPathName(File):
    raise FileNotFoundError(
        f"{File} does not exist; run the training tutorial (tmva101_Training.py) first"
    )

bdt = ROOT.TMVA.Experimental.RBDT("myBDT", File)

# Make prediction: one BDT response per row of `x`
y_pred = bdt.Compute(x)

# Compute the weighted ROC curve and its area with scikit-learn.
# roc_curve returns (false-positive rates, true-positive rates, thresholds);
# the thresholds are not used here.
from sklearn.metrics import auc, roc_curve

roc = roc_curve(y_true, y_pred, sample_weight=w)
false_positive_rate, true_positive_rate, _ = roc
score = auc(x=false_positive_rate, y=true_positive_rate)

# Plot ROC: true-positive rate versus false-positive rate, AUC in the title
c = ROOT.TCanvas("roc", "", 600, 600)
g = ROOT.TGraph(len(false_positive_rate), false_positive_rate, true_positive_rate)
g.SetTitle(f"AUC = {score:.2f}")
g.SetLineWidth(3)
# SetLineColor takes a Color_t index, so pass the ROOT.kRed constant rather
# than the string "kRed".
g.SetLineColor(ROOT.kRed)
# "A": draw axes around the graph, "C": draw a smooth curve through the points
g.Draw("AC")
# Both rates lie in [0, 1]; show exactly that range on each axis
g.GetXaxis().SetRangeUser(0, 1)
g.GetYaxis().SetRangeUser(0, 1)
g.GetXaxis().SetTitle("False-positive rate")
g.GetYaxis().SetTitle("True-positive rate")
c.Draw()