Logo ROOT   6.10/09
Reference Guide
ROCCurve.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer, Simon Pfreundschuh and Kim Albertsson
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : ROCCurve *
8  * *
9  * Description: *
10  * This class computes the ROC integral (AUC). *
11  * *
12  * Authors : *
13  * Omar Zapata <Omar.Zapata@cern.ch> - UdeA/ITM Colombia *
14  * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
15  * Sergei Gleyzer <Sergei.Gleyzer@cern.ch> - U of Florida & CERN *
16  * Kim Albertsson <kim.albertsson@cern.ch> - LTU & CERN *
17  * *
18  * Copyright (c) 2015: *
19  * CERN, Switzerland *
20  * UdeA/ITM, Colombia *
21  * U. of Florida, USA *
22  **********************************************************************************/
23 
24 /*! \class TMVA::ROCCurve
25 \ingroup TMVA
26 
27 */
#include "TMVA/Tools.h"
#include "TMVA/TSpline1.h"
#include "TMVA/ROCCurve.h"
#include "TMVA/Config.h"
#include "TMVA/Version.h"
#include "TMVA/MsgLogger.h"
#include "TGraph.h"
#include "TMath.h"

#include <algorithm>
#include <cassert>
#include <limits>
#include <vector>
39 
40 using namespace std;
41 
42 ////////////////////////////////////////////////////////////////////////////////
43 ///
44 
45 TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaValues, const std::vector<Bool_t> &mvaTargets,
46  const std::vector<Float_t> &mvaWeights)
47  : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
48 {
49  assert(mvaValues.size() == mvaTargets.size());
50  assert(mvaValues.size() == mvaWeights.size());
51 
52  for (UInt_t i = 0; i < mvaValues.size(); i++) {
53  if (mvaTargets[i]) {
54  fMvaSignal.push_back(mvaValues[i]);
55  fMvaSignalWeights.push_back(mvaWeights[i]);
56  } else {
57  fMvaBackground.push_back(mvaValues[i]);
58  fMvaBackgroundWeights.push_back(mvaWeights[i]);
59  }
60  }
61 }
62 
63 ////////////////////////////////////////////////////////////////////////////////
64 ///
65 
66 TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaValues, const std::vector<Bool_t> &mvaTargets)
67  : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
68 {
69  assert(mvaValues.size() == mvaTargets.size());
70 
71  for (UInt_t i = 0; i < mvaValues.size(); i++) {
72  if (mvaTargets[i]) {
73  fMvaSignal.push_back(mvaValues[i]);
74  } else {
75  fMvaBackground.push_back(mvaValues[i]);
76  }
77  }
78 }
79 
80 ////////////////////////////////////////////////////////////////////////////////
81 ///
82 
83 TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaSignal, const std::vector<Float_t> &mvaBackground)
84  : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
85 {
86  fMvaSignal = mvaSignal;
87  fMvaBackground = mvaBackground;
88 }
89 
90 ////////////////////////////////////////////////////////////////////////////////
91 ///
92 
93 TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaSignal, const std::vector<Float_t> &mvaBackground,
94  const std::vector<Float_t> &mvaSignalWeights, const std::vector<Float_t> &mvaBackgroundWeights)
95  : ROCCurve(mvaSignal, mvaBackground)
96 {
97  assert(mvaSignal.size() == mvaSignalWeights.size());
98  assert(mvaBackground.size() == mvaBackgroundWeights.size());
99 
100  fMvaSignalWeights = mvaSignalWeights;
101  fMvaBackgroundWeights = mvaBackgroundWeights;
102 }
103 
104 ////////////////////////////////////////////////////////////////////////////////
105 /// destructor
106 
108  delete fLogger;
109  if(fGraph) delete fGraph;
110 }
111 
112 ////////////////////////////////////////////////////////////////////////////////
113 ///
114 
115 std::vector<Double_t> TMVA::ROCCurve::ComputeSpecificity(const UInt_t num_points)
116 {
117  if (num_points <= 2) {
118  return {0.0, 1.0};
119  }
120 
121  UInt_t num_divisions = num_points - 1;
122  std::vector<Double_t> specificity_vector;
123  specificity_vector.push_back(0.0);
124 
125  for (Double_t threshold = -1.0; threshold < 1.0; threshold += (1.0 / num_divisions)) {
126  Double_t false_positives = 0.0;
127  Double_t true_negatives = 0.0;
128 
129  for (size_t i = 0; i < fMvaBackground.size(); ++i) {
130  auto value = fMvaBackground.at(i);
131  auto weight = fMvaBackgroundWeights.empty() ? (1.0) : fMvaBackgroundWeights.at(i);
132 
133  if (value > threshold) {
134  false_positives += weight;
135  } else {
136  true_negatives += weight;
137  }
138  }
139 
140  Double_t total_background = false_positives + true_negatives;
141  Double_t specificity =
142  (total_background <= std::numeric_limits<Double_t>::min()) ? (0.0) : (true_negatives / total_background);
143 
144  specificity_vector.push_back(specificity);
145  }
146 
147  specificity_vector.push_back(1.0);
148  return specificity_vector;
149 }
150 
151 ////////////////////////////////////////////////////////////////////////////////
152 ///
153 
154 std::vector<Double_t> TMVA::ROCCurve::ComputeSensitivity(const UInt_t num_points)
155 {
156  if (num_points <= 2) {
157  return {1.0, 0.0};
158  }
159 
160  UInt_t num_divisions = num_points - 1;
161  std::vector<Double_t> sensitivity_vector;
162  sensitivity_vector.push_back(1.0);
163 
164  for (Double_t threshold = -1.0; threshold < 1.0; threshold += (1.0 / num_divisions)) {
165  Double_t true_positives = 0.0;
166  Double_t false_negatives = 0.0;
167 
168  for (size_t i = 0; i < fMvaSignal.size(); ++i) {
169  auto value = fMvaSignal.at(i);
170  auto weight = fMvaSignalWeights.empty() ? (1.0) : fMvaSignalWeights.at(i);
171 
172  if (value > threshold) {
173  true_positives += weight;
174  } else {
175  false_negatives += weight;
176  }
177  }
178 
179  Double_t total_signal = true_positives + false_negatives;
180  Double_t sensitivity =
181  (total_signal <= std::numeric_limits<Double_t>::min()) ? (0.0) : (true_positives / total_signal);
182  sensitivity_vector.push_back(sensitivity);
183  }
184 
185  sensitivity_vector.push_back(0.0);
186  return sensitivity_vector;
187 }
188 
189 ////////////////////////////////////////////////////////////////////////////////
190 /// Calculate the signal efficiency (sensitivity) for a given background
191 /// efficiency (sensitivity).
192 ///
193 /// @param effB Background efficiency for which to calculate signal
194 /// efficiency.
195 /// @param num_points Number of points used for the underlying histogram.
196 /// The number of bins will be num_points - 1.
197 ///
198 
200 {
201  assert(0.0 <= effB and effB <= 1.0);
202 
203  auto effS_vec = ComputeSensitivity(num_points);
204  auto effB_vec = ComputeSpecificity(num_points);
205 
206  // Specificity is actually rejB, so we need to transform it.
207  auto complement = [](Double_t x) { return 1 - x; };
208  std::transform(effB_vec.begin(), effB_vec.end(), effB_vec.begin(), complement);
209 
210  // Since TSpline1 uses binary search (and assumes ascending sorting) we must ensure this.
211  std::reverse(effS_vec.begin(), effS_vec.end());
212  std::reverse(effB_vec.begin(), effB_vec.end());
213 
214  TGraph *graph = new TGraph(effS_vec.size(), &effB_vec[0], &effS_vec[0]);
215 
216  // TSpline1 does linear interpolation of ROC curve
217  TSpline1 rocSpline = TSpline1("", graph);
218  return rocSpline.Eval(effB);
219 }
220 
221 ////////////////////////////////////////////////////////////////////////////////
222 /// Calculates the ROC integral (AUC)
223 ///
224 /// @param num_points Granularity of the resulting curve used for integration.
225 /// The curve will be subdivided into num_points - 1 regions
226 /// where the performance of the classifier is sampled.
227 /// Larger number means more accurate, but more costly,
228 /// evaluation.
229 
231 {
232  auto sensitivity = ComputeSensitivity(num_points);
233  auto specificity = ComputeSpecificity(num_points);
234 
235  Double_t integral = 0.0;
236  for (UInt_t i = 0; i < sensitivity.size() - 1; i++) {
237  // FNR, false negatigve rate = 1 - Sensitivity
238  Double_t currFnr = 1 - sensitivity[i];
239  Double_t nextFnr = 1 - sensitivity[i + 1];
240  // Trapezodial integration
241  integral += 0.5 * (nextFnr - currFnr) * (specificity[i] + specificity[i + 1]);
242  }
243 
244  return integral;
245 }
246 
247 ////////////////////////////////////////////////////////////////////////////////
248 /// Returns a new TGraph containing the ROC curve. Specificity is on the x-axis,
249 /// sensitivity on the y-axis.
250 ///
251 /// @param num_points Granularity of the resulting curve. The curve will be subdivided
252 /// into num_points - 1 regions where the performance of the
253 /// classifier is sampled. Larger number means more accurate,
254 /// but more costly, evaluation.
255 
257 {
258  if (fGraph != nullptr) {
259  delete fGraph;
260  }
261 
262  auto sensitivity = ComputeSensitivity(num_points);
263  auto specificity = ComputeSpecificity(num_points);
264 
265  fGraph = new TGraph(sensitivity.size(), &sensitivity[0], &specificity[0]);
266 
267  return fGraph;
268 }
MsgLogger * fLogger
Definition: ROCCurve.h:68
std::vector< Float_t > fMvaSignal
Definition: ROCCurve.h:73
Double_t GetEffSForEffB(Double_t effB, const UInt_t num_points=41)
Calculate the signal efficiency (sensitivity) for a given background efficiency (1 - specificity).
Definition: ROCCurve.cxx:199
ROCCurve(const std::vector< Float_t > &mvaValues, const std::vector< Bool_t > &mvaTargets, const std::vector< Float_t > &mvaWeights)
Definition: ROCCurve.cxx:45
Double_t GetROCIntegral(const UInt_t points=41)
Calculates the ROC integral (AUC)
Definition: ROCCurve.cxx:230
TGraph * fGraph
Definition: ROCCurve.h:71
~ROCCurve()
destructor
Definition: ROCCurve.cxx:107
std::vector< Double_t > ComputeSpecificity(const UInt_t num_points)
Definition: ROCCurve.cxx:115
STL namespace.
#define NULL
Definition: RtypesCore.h:88
Double_t x[n]
Definition: legend1.C:17
std::vector< Float_t > fMvaSignalWeights
Definition: ROCCurve.h:75
unsigned int UInt_t
Definition: RtypesCore.h:42
Definition: graph.py:1
Linear interpolation of TGraph.
Definition: TSpline1.h:43
double Double_t
Definition: RtypesCore.h:55
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
std::vector< Double_t > ComputeSensitivity(const UInt_t num_points)
Definition: ROCCurve.cxx:154
Abstract ClassifierFactory template that handles arbitrary types.
TGraph * GetROCCurve(const UInt_t points=100)
Returns a new TGraph containing the ROC curve.
Definition: ROCCurve.cxx:256
A Graph is a graphics object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
std::vector< Float_t > fMvaBackgroundWeights
Definition: ROCCurve.h:76
std::vector< Float_t > fMvaBackground
Definition: ROCCurve.h:74