Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROCCurve.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer, Simon Pfreundschuh and Kim Albertsson
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : ROCCurve *
8 * *
9 * Description: *
10 * This class computes the ROC curve and the ROC integral (AUC). *
11 * *
12 * Authors : *
13 * Omar Zapata <Omar.Zapata@cern.ch> - UdeA/ITM Colombia *
14 * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
15 * Sergei Gleyzer <Sergei.Gleyzer@cern.ch> - U of Florida & CERN *
16 * Kim Albertsson <kim.albertsson@cern.ch> - LTU & CERN *
17 * *
18 * Copyright (c) 2015: *
19 * CERN, Switzerland *
20 * UdeA/ITM, Colombia *
21 * U. of Florida, USA *
22 **********************************************************************************/
23
24/*! \class TMVA::ROCCurve
25\ingroup TMVA
26
27*/
28#include "TMVA/Tools.h"
29#include "TMVA/TSpline1.h"
30#include "TMVA/ROCCurve.h"
31#include "TMVA/Config.h"
32#include "TMVA/Version.h"
33#include "TMVA/MsgLogger.h"
34#include "TGraph.h"
35
36#include <algorithm>
37#include <vector>
38#include <cassert>
39
40using namespace std;
41
42auto tupleSort = [](std::tuple<Float_t, Float_t, Bool_t> _a, std::tuple<Float_t, Float_t, Bool_t> _b) {
43 return std::get<0>(_a) < std::get<0>(_b);
44};
45
46//_______________________________________________________________________
47TMVA::ROCCurve::ROCCurve(const std::vector<std::tuple<Float_t, Float_t, Bool_t>> &mvas)
48 : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL), fMva(mvas)
49{
50}
51
52////////////////////////////////////////////////////////////////////////////////
53///
54
55TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaValues, const std::vector<Bool_t> &mvaTargets,
56 const std::vector<Float_t> &mvaWeights)
57 : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
58{
59 assert(mvaValues.size() == mvaTargets.size());
60 assert(mvaValues.size() == mvaWeights.size());
61
62 for (UInt_t i = 0; i < mvaValues.size(); i++) {
63 fMva.emplace_back(mvaValues[i], mvaWeights[i], mvaTargets[i]);
64 }
65
66 std::sort(fMva.begin(), fMva.end(), tupleSort);
67}
68
69////////////////////////////////////////////////////////////////////////////////
70///
71
72TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaValues, const std::vector<Bool_t> &mvaTargets)
73 : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
74{
75 assert(mvaValues.size() == mvaTargets.size());
76
77 for (UInt_t i = 0; i < mvaValues.size(); i++) {
78 fMva.emplace_back(mvaValues[i], 1, mvaTargets[i]);
79 }
80
81 std::sort(fMva.begin(), fMva.end(), tupleSort);
82}
83
84////////////////////////////////////////////////////////////////////////////////
85///
86
87TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaSignal, const std::vector<Float_t> &mvaBackground)
88 : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
89{
90 for (UInt_t i = 0; i < mvaSignal.size(); i++) {
91 fMva.emplace_back(mvaSignal[i], 1, kTRUE);
92 }
93
94 for (UInt_t i = 0; i < mvaBackground.size(); i++) {
95 fMva.emplace_back(mvaBackground[i], 1, kFALSE);
96 }
97
98 std::sort(fMva.begin(), fMva.end(), tupleSort);
99}
100
101////////////////////////////////////////////////////////////////////////////////
102///
103
104TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaSignal, const std::vector<Float_t> &mvaBackground,
105 const std::vector<Float_t> &mvaSignalWeights, const std::vector<Float_t> &mvaBackgroundWeights)
106 : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
107{
108 assert(mvaSignal.size() == mvaSignalWeights.size());
109 assert(mvaBackground.size() == mvaBackgroundWeights.size());
110
111 for (UInt_t i = 0; i < mvaSignal.size(); i++) {
112 fMva.emplace_back(mvaSignal[i], mvaSignalWeights[i], kTRUE);
113 }
114
115 for (UInt_t i = 0; i < mvaBackground.size(); i++) {
116 fMva.emplace_back(mvaBackground[i], mvaBackgroundWeights[i], kFALSE);
117 }
118
119 std::sort(fMva.begin(), fMva.end(), tupleSort);
120}
121
122////////////////////////////////////////////////////////////////////////////////
123/// destructor
124
126 delete fLogger;
127 if(fGraph) delete fGraph;
128}
129
131{
132 if (!fLogger)
133 fLogger = new TMVA::MsgLogger("ROCCurve");
134 return *fLogger;
135}
136
137////////////////////////////////////////////////////////////////////////////////
138///
139
140std::vector<Double_t> TMVA::ROCCurve::ComputeSpecificity(const UInt_t num_points)
141{
142 if (num_points <= 2) {
143 return {0.0, 1.0};
144 }
145
146 std::vector<Double_t> specificity_vector;
147 std::vector<Double_t> true_negatives;
148 specificity_vector.reserve(fMva.size());
149 true_negatives.reserve(fMva.size());
150
151 Double_t true_negatives_sum = 0.0;
152 for (auto &ev : fMva) {
153 // auto value = std::get<0>(ev);
154 auto weight = std::get<1>(ev);
155 auto isSignal = std::get<2>(ev);
156
157 true_negatives_sum += weight * (!isSignal ? 1. : 0.);
158 true_negatives.push_back(true_negatives_sum);
159 }
160
161 specificity_vector.push_back(0.0);
162 Double_t total_background = true_negatives_sum;
163 for (auto &tn : true_negatives) {
164 Double_t specificity =
165 (total_background <= std::numeric_limits<Double_t>::min()) ? (0.0) : (tn / total_background);
166 specificity_vector.push_back(specificity);
167 }
168 specificity_vector.push_back(1.0);
169
170 return specificity_vector;
171}
172
173////////////////////////////////////////////////////////////////////////////////
174///
175
176std::vector<Double_t> TMVA::ROCCurve::ComputeSensitivity(const UInt_t num_points)
177{
178 if (num_points <= 2) {
179 return {1.0, 0.0};
180 }
181
182 std::vector<Double_t> sensitivity_vector;
183 std::vector<Double_t> true_positives;
184 sensitivity_vector.reserve(fMva.size());
185 true_positives.reserve(fMva.size());
186
187 Double_t true_positives_sum = 0.0;
188 for (auto it = fMva.rbegin(); it != fMva.rend(); ++it) {
189 // auto value = std::get<0>(*it);
190 auto weight = std::get<1>(*it);
191 auto isSignal = std::get<2>(*it);
192
193 true_positives_sum += weight * (isSignal);
194 true_positives.push_back(true_positives_sum);
195 }
196 std::reverse(true_positives.begin(), true_positives.end());
197
198 sensitivity_vector.push_back(1.0);
199 Double_t total_signal = true_positives_sum;
200 for (auto &tp : true_positives) {
201 Double_t sensitivity = (total_signal <= std::numeric_limits<Double_t>::min()) ? (0.0) : (tp / total_signal);
202 sensitivity_vector.push_back(sensitivity);
203 }
204 sensitivity_vector.push_back(0.0);
205
206 return sensitivity_vector;
207}
208
209////////////////////////////////////////////////////////////////////////////////
210/// Calculate the signal efficiency (sensitivity) for a given background
211/// efficiency (sensitivity).
212///
213/// @param effB Background efficiency for which to calculate signal
214/// efficiency.
215/// @param num_points Number of points used for the underlying histogram.
216/// The number of bins will be num_points - 1.
217///
218
220{
221 assert(0.0 <= effB && effB <= 1.0);
222
223 auto effS_vec = ComputeSensitivity(num_points);
224 auto effB_vec = ComputeSpecificity(num_points);
225
226 // Specificity is actually rejB, so we need to transform it.
227 auto complement = [](Double_t x) { return 1 - x; };
228 std::transform(effB_vec.begin(), effB_vec.end(), effB_vec.begin(), complement);
229
230 // Since TSpline1 uses binary search (and assumes ascending sorting) we must ensure this.
231 std::reverse(effS_vec.begin(), effS_vec.end());
232 std::reverse(effB_vec.begin(), effB_vec.end());
233
234 TGraph *graph = new TGraph(effS_vec.size(), &effB_vec[0], &effS_vec[0]);
235
236 // TSpline1 does linear interpolation of ROC curve
237 TSpline1 rocSpline = TSpline1("", graph);
238 return rocSpline.Eval(effB);
239}
240
241////////////////////////////////////////////////////////////////////////////////
242/// Calculates the ROC integral (AUC)
243///
244/// @param num_points Granularity of the resulting curve used for integration.
245/// The curve will be subdivided into num_points - 1 regions
246/// where the performance of the classifier is sampled.
247/// Larger number means more accurate, but more costly,
248/// evaluation.
249
251{
252 auto sensitivity = ComputeSensitivity(num_points);
253 auto specificity = ComputeSpecificity(num_points);
254
255 Double_t integral = 0.0;
256 for (UInt_t i = 0; i < sensitivity.size() - 1; i++) {
257 // FNR, false negatigve rate = 1 - Sensitivity
258 Double_t currFnr = 1 - sensitivity[i];
259 Double_t nextFnr = 1 - sensitivity[i + 1];
260 // Trapezodial integration
261 integral += 0.5 * (nextFnr - currFnr) * (specificity[i] + specificity[i + 1]);
262 }
263
264 return integral;
265}
266
267////////////////////////////////////////////////////////////////////////////////
268/// Returns a new TGraph containing the ROC curve. Sensitivity is on the x-axis,
269/// specificity on the y-axis.
270///
271/// @param num_points Granularity of the resulting curve. The curve will be subdivided
272/// into num_points - 1 regions where the performance of the
273/// classifier is sampled. Larger number means more accurate,
274/// but more costly, evaluation.
275
277{
278 if (fGraph != nullptr) {
279 delete fGraph;
280 }
281
282 auto sensitivity = ComputeSensitivity(num_points);
283 auto specificity = ComputeSpecificity(num_points);
284
285 fGraph = new TGraph(sensitivity.size(), &sensitivity[0], &specificity[0]);
286
287 return fGraph;
288}
auto tupleSort
Definition ROCCurve.cxx:42
const Bool_t kFALSE
Definition RtypesCore.h:101
const Bool_t kTRUE
Definition RtypesCore.h:100
A TGraph is an object made of two arrays X and Y with npoints each.
Definition TGraph.h:41
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
std::vector< Double_t > ComputeSpecificity(const UInt_t num_points)
Definition ROCCurve.cxx:140
ROCCurve(const std::vector< std::tuple< Float_t, Float_t, Bool_t > > &mvas)
Definition ROCCurve.cxx:47
~ROCCurve()
destructor
Definition ROCCurve.cxx:125
Double_t GetEffSForEffB(Double_t effB, const UInt_t num_points=41)
Calculate the signal efficiency (sensitivity) for a given background efficiency (1 - specificity).
Definition ROCCurve.cxx:219
std::vector< Double_t > ComputeSensitivity(const UInt_t num_points)
Definition ROCCurve.cxx:176
Double_t GetROCIntegral(const UInt_t points=41)
Calculates the ROC integral (AUC)
Definition ROCCurve.cxx:250
MsgLogger & Log() const
message logger
Definition ROCCurve.cxx:130
std::vector< std::tuple< Float_t, Float_t, Bool_t > > fMva
Definition ROCCurve.h:74
TGraph * GetROCCurve(const UInt_t points=100)
Returns a new TGraph containing the ROC curve.
Definition ROCCurve.cxx:276
Linear interpolation of TGraph.
Definition TSpline1.h:43
virtual Double_t Eval(Double_t x) const
returns linearly interpolated TGraph entry around x
Definition TSpline1.cxx:61
Double_t x[n]
Definition legend1.C:17
create variable transformations
void mvas(TString dataset, TString fin="TMVA.root", HistType htype=kMVAType, Bool_t useTMVAStyle=kTRUE)
Definition graph.py:1