Logo ROOT   6.12/07
Reference Guide
ROCCurve.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer, Simon Pfreundschuh and Kim Albertsson
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : ROCCurve *
8  * *
9  * Description: *
10  * This class computes the ROC integral (AUC) *
11  * *
12  * Authors : *
13  * Omar Zapata <Omar.Zapata@cern.ch> - UdeA/ITM Colombia *
14  * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
15  * Sergei Gleyzer <Sergei.Gleyzer@cern.ch> - U of Florida & CERN *
16  * Kim Albertsson <kim.albertsson@cern.ch> - LTU & CERN *
17  * *
18  * Copyright (c) 2015: *
19  * CERN, Switzerland *
20  * UdeA/ITM, Colombia *
21  * U. of Florida, USA *
22  **********************************************************************************/
23 
24 /*! \class TMVA::ROCCurve
25 \ingroup TMVA
26 
27 */
#include "TMVA/Tools.h"
#include "TMVA/TSpline1.h"
#include "TMVA/ROCCurve.h"
#include "TMVA/Config.h"
#include "TMVA/Version.h"
#include "TMVA/MsgLogger.h"
#include "TGraph.h"
#include "TMath.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <tuple>
#include <vector>
40 
41 using namespace std;
42 
43 auto tupleSort = [](std::tuple<Float_t, Float_t, Bool_t> _a, std::tuple<Float_t, Float_t, Bool_t> _b) {
44  return std::get<0>(_a) < std::get<0>(_b);
45 };
46 
47 //_______________________________________________________________________
48 TMVA::ROCCurve::ROCCurve(const std::vector<std::tuple<Float_t, Float_t, Bool_t>> &mvas)
49  : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL), fMva(mvas)
50 {
51 }
52 
53 ////////////////////////////////////////////////////////////////////////////////
54 ///
55 
56 TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaValues, const std::vector<Bool_t> &mvaTargets,
57  const std::vector<Float_t> &mvaWeights)
58  : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
59 {
60  assert(mvaValues.size() == mvaTargets.size());
61  assert(mvaValues.size() == mvaWeights.size());
62 
63  for (UInt_t i = 0; i < mvaValues.size(); i++) {
64  fMva.emplace_back(mvaValues[i], mvaWeights[i], mvaTargets[i]);
65  }
66 
67  std::sort(fMva.begin(), fMva.end(), tupleSort);
68 }
69 
70 ////////////////////////////////////////////////////////////////////////////////
71 ///
72 
73 TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaValues, const std::vector<Bool_t> &mvaTargets)
74  : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
75 {
76  assert(mvaValues.size() == mvaTargets.size());
77 
78  for (UInt_t i = 0; i < mvaValues.size(); i++) {
79  fMva.emplace_back(mvaValues[i], 1, mvaTargets[i]);
80  }
81 
82  std::sort(fMva.begin(), fMva.end(), tupleSort);
83 }
84 
85 ////////////////////////////////////////////////////////////////////////////////
86 ///
87 
88 TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaSignal, const std::vector<Float_t> &mvaBackground)
89  : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
90 {
91  for (UInt_t i = 0; i < mvaSignal.size(); i++) {
92  fMva.emplace_back(mvaSignal[i], 1, kTRUE);
93  }
94 
95  for (UInt_t i = 0; i < mvaBackground.size(); i++) {
96  fMva.emplace_back(mvaBackground[i], 1, kFALSE);
97  }
98 
99  std::sort(fMva.begin(), fMva.end(), tupleSort);
100 }
101 
102 ////////////////////////////////////////////////////////////////////////////////
103 ///
104 
105 TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> &mvaSignal, const std::vector<Float_t> &mvaBackground,
106  const std::vector<Float_t> &mvaSignalWeights, const std::vector<Float_t> &mvaBackgroundWeights)
107  : fLogger(new TMVA::MsgLogger("ROCCurve")), fGraph(NULL)
108 {
109  assert(mvaSignal.size() == mvaSignalWeights.size());
110  assert(mvaBackground.size() == mvaBackgroundWeights.size());
111 
112  for (UInt_t i = 0; i < mvaSignal.size(); i++) {
113  fMva.emplace_back(mvaSignal[i], mvaSignalWeights[i], kTRUE);
114  }
115 
116  for (UInt_t i = 0; i < mvaBackground.size(); i++) {
117  fMva.emplace_back(mvaBackground[i], mvaBackgroundWeights[i], kFALSE);
118  }
119 
120  std::sort(fMva.begin(), fMva.end(), tupleSort);
121 }
122 
123 ////////////////////////////////////////////////////////////////////////////////
124 /// destructor
125 
127  delete fLogger;
128  if(fGraph) delete fGraph;
129 }
130 
132 {
133  if (!fLogger)
134  fLogger = new TMVA::MsgLogger("ROCCurve");
135  return *fLogger;
136 }
137 
138 ////////////////////////////////////////////////////////////////////////////////
139 ///
140 
141 std::vector<Double_t> TMVA::ROCCurve::ComputeSpecificity(const UInt_t num_points)
142 {
143  if (num_points <= 2) {
144  return {0.0, 1.0};
145  }
146 
147  std::vector<Double_t> specificity_vector;
148  std::vector<Double_t> true_negatives;
149  specificity_vector.reserve(fMva.size());
150  true_negatives.reserve(fMva.size());
151 
152  Double_t true_negatives_sum = 0.0;
153  for (auto &ev : fMva) {
154  // auto value = std::get<0>(ev);
155  auto weight = std::get<1>(ev);
156  auto isSignal = std::get<2>(ev);
157 
158  true_negatives_sum += weight * (not isSignal);
159  true_negatives.push_back(true_negatives_sum);
160  }
161 
162  specificity_vector.push_back(0.0);
163  Double_t total_background = true_negatives_sum;
164  for (auto &tn : true_negatives) {
165  Double_t specificity =
166  (total_background <= std::numeric_limits<Double_t>::min()) ? (0.0) : (tn / total_background);
167  specificity_vector.push_back(specificity);
168  }
169  specificity_vector.push_back(1.0);
170 
171  return specificity_vector;
172 }
173 
174 ////////////////////////////////////////////////////////////////////////////////
175 ///
176 
177 std::vector<Double_t> TMVA::ROCCurve::ComputeSensitivity(const UInt_t num_points)
178 {
179  if (num_points <= 2) {
180  return {1.0, 0.0};
181  }
182 
183  std::vector<Double_t> sensitivity_vector;
184  std::vector<Double_t> true_positives;
185  sensitivity_vector.reserve(fMva.size());
186  true_positives.reserve(fMva.size());
187 
188  Double_t true_positives_sum = 0.0;
189  for (auto it = fMva.rbegin(); it != fMva.rend(); ++it) {
190  // auto value = std::get<0>(*it);
191  auto weight = std::get<1>(*it);
192  auto isSignal = std::get<2>(*it);
193 
194  true_positives_sum += weight * (isSignal);
195  true_positives.push_back(true_positives_sum);
196  }
197  std::reverse(true_positives.begin(), true_positives.end());
198 
199  sensitivity_vector.push_back(1.0);
200  Double_t total_signal = true_positives_sum;
201  for (auto &tp : true_positives) {
202  Double_t sensitivity = (total_signal <= std::numeric_limits<Double_t>::min()) ? (0.0) : (tp / total_signal);
203  sensitivity_vector.push_back(sensitivity);
204  }
205  sensitivity_vector.push_back(0.0);
206 
207  return sensitivity_vector;
208 }
209 
210 ////////////////////////////////////////////////////////////////////////////////
211 /// Calculate the signal efficiency (sensitivity) for a given background
212 /// efficiency (sensitivity).
213 ///
214 /// @param effB Background efficiency for which to calculate signal
215 /// efficiency.
216 /// @param num_points Number of points used for the underlying histogram.
217 /// The number of bins will be num_points - 1.
218 ///
219 
221 {
222  assert(0.0 <= effB and effB <= 1.0);
223 
224  auto effS_vec = ComputeSensitivity(num_points);
225  auto effB_vec = ComputeSpecificity(num_points);
226 
227  // Specificity is actually rejB, so we need to transform it.
228  auto complement = [](Double_t x) { return 1 - x; };
229  std::transform(effB_vec.begin(), effB_vec.end(), effB_vec.begin(), complement);
230 
231  // Since TSpline1 uses binary search (and assumes ascending sorting) we must ensure this.
232  std::reverse(effS_vec.begin(), effS_vec.end());
233  std::reverse(effB_vec.begin(), effB_vec.end());
234 
235  TGraph *graph = new TGraph(effS_vec.size(), &effB_vec[0], &effS_vec[0]);
236 
237  // TSpline1 does linear interpolation of ROC curve
238  TSpline1 rocSpline = TSpline1("", graph);
239  return rocSpline.Eval(effB);
240 }
241 
242 ////////////////////////////////////////////////////////////////////////////////
243 /// Calculates the ROC integral (AUC)
244 ///
245 /// @param num_points Granularity of the resulting curve used for integration.
246 /// The curve will be subdivided into num_points - 1 regions
247 /// where the performance of the classifier is sampled.
248 /// Larger number means more accurate, but more costly,
249 /// evaluation.
250 
252 {
253  auto sensitivity = ComputeSensitivity(num_points);
254  auto specificity = ComputeSpecificity(num_points);
255 
256  Double_t integral = 0.0;
257  for (UInt_t i = 0; i < sensitivity.size() - 1; i++) {
258  // FNR, false negatigve rate = 1 - Sensitivity
259  Double_t currFnr = 1 - sensitivity[i];
260  Double_t nextFnr = 1 - sensitivity[i + 1];
261  // Trapezodial integration
262  integral += 0.5 * (nextFnr - currFnr) * (specificity[i] + specificity[i + 1]);
263  }
264 
265  return integral;
266 }
267 
268 ////////////////////////////////////////////////////////////////////////////////
269 /// Returns a new TGraph containing the ROC curve. Specificity is on the x-axis,
270 /// sensitivity on the y-axis.
271 ///
272 /// @param num_points Granularity of the resulting curve. The curve will be subdivided
273 /// into num_points - 1 regions where the performance of the
274 /// classifier is sampled. Larger number means more accurate,
275 /// but more costly, evaluation.
276 
278 {
279  if (fGraph != nullptr) {
280  delete fGraph;
281  }
282 
283  auto sensitivity = ComputeSensitivity(num_points);
284  auto specificity = ComputeSpecificity(num_points);
285 
286  fGraph = new TGraph(sensitivity.size(), &sensitivity[0], &specificity[0]);
287 
288  return fGraph;
289 }
MsgLogger * fLogger
Definition: ROCCurve.h:71
Double_t GetEffSForEffB(Double_t effB, const UInt_t num_points=41)
Calculate the signal efficiency (sensitivity) for a given background efficiency (sensitivity).
Definition: ROCCurve.cxx:220
Double_t GetROCIntegral(const UInt_t points=41)
Calculates the ROC integral (AUC)
Definition: ROCCurve.cxx:251
TGraph * fGraph
Definition: ROCCurve.h:74
~ROCCurve()
destructor
Definition: ROCCurve.cxx:126
std::vector< Double_t > ComputeSpecificity(const UInt_t num_points)
Definition: ROCCurve.cxx:141
void mvas(TString dataset, TString fin="TMVA.root", HistType htype=kMVAType, Bool_t useTMVAStyle=kTRUE)
STL namespace.
Double_t x[n]
Definition: legend1.C:17
ROCCurve(const std::vector< std::tuple< Float_t, Float_t, Bool_t >> &mvas)
Definition: ROCCurve.cxx:48
unsigned int UInt_t
Definition: RtypesCore.h:42
Definition: graph.py:1
Linear interpolation of TGraph.
Definition: TSpline1.h:43
const Bool_t kFALSE
Definition: RtypesCore.h:88
double Double_t
Definition: RtypesCore.h:55
auto tupleSort
Definition: ROCCurve.cxx:43
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
MsgLogger & Log() const
message logger
Definition: ROCCurve.cxx:131
std::vector< Double_t > ComputeSensitivity(const UInt_t num_points)
Definition: ROCCurve.cxx:177
Abstract ClassifierFactory template that handles arbitrary types.
TGraph * GetROCCurve(const UInt_t points=100)
Returns a new TGraph containing the ROC curve.
Definition: ROCCurve.cxx:277
A Graph is a graphics object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
std::vector< std::tuple< Float_t, Float_t, Bool_t > > fMva
Definition: ROCCurve.h:76
const Bool_t kTRUE
Definition: RtypesCore.h:87