Logo ROOT   6.07/09
Reference Guide
ROCCurve.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer and Simon Pfreundschuh
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : ROCCurve *
8  * *
9  * Description: *
10  * This is class to compute ROC Integral (AUC) *
11  * *
12  * Authors : *
13  * Omar Zapata <Omar.Zapata@cern.ch> - UdeA/ITM Colombia *
14  * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
15  * Sergei Gleyzer <Sergei.Gleyzer@cern.ch> - U of Florida & CERN *
16  * *
17  * Copyright (c) 2015: *
18  * CERN, Switzerland *
19  * UdeA/ITM, Colombia *
20  * U. of Florida, USA *
21  **********************************************************************************/
22 
23 #ifndef ROOT_TMVA_Tools
24 #include "TMVA/Tools.h"
25 #endif
26 #ifndef ROOT_TMVA_ROCCurve
27 #include "TMVA/ROCCurve.h"
28 #endif
29 #ifndef ROOT_TMVA_Config
30 #include "TMVA/Config.h"
31 #endif
32 #ifndef ROOT_TMVA_Version
33 #include "TMVA/Version.h"
34 #endif
35 #ifndef ROOT_TMVA_MsgLogger
36 #include "TMVA/MsgLogger.h"
37 #endif
38 #ifndef ROOT_TGraph
39 #include "TGraph.h"
40 #endif
41 
42 #include<vector>
43 #include <cassert>
44 
45 using namespace std;
46 
47 
48 TMVA::ROCCurve::ROCCurve(const std::vector<Float_t> & mva, const std::vector<Bool_t> & mvat) :
49  fLogger ( new TMVA::MsgLogger("ROCCurve") ),fGraph(NULL)
50 {
51  assert(mva.size() == mvat.size() );
52  for(UInt_t i=0;i<mva.size();i++)
53  {
54  if(mvat[i] ) fMvaS.push_back(mva[i]);
55  else fMvaB.push_back(mva[i]);
56  }
57 }
58 
59 
60 
61 ////////////////////////////////////////////////////////////////////////////////
62 /// destructor
63 
65  delete fLogger;
66  if(fGraph) delete fGraph;
67 }
68 
69 ////////////////////////////////////////////////////////////////////////////////
70 /// ROC Integral (AUC)
71 
73 
74  Float_t integral=0;
75  int ndivisions = 40;
76  fEpsilonSig.push_back(0);
77  fEpsilonBgk.push_back(0);
78 
79  Float_t epsilon_s = 0.0;
80  Float_t epsilon_b = 0.0;
81 
82  for(Float_t i=-1.0;i<1.0;i+=(1.0/ndivisions))
83  {
84  Float_t acounter = 0.0;
85  Float_t bcounter = 0.0;
86  Float_t ccounter = 0.0;
87  Float_t dcounter = 0.0;
88 
89  for(UInt_t j=0;j<fMvaS.size();j++)
90  {
91  if(fMvaS[j] > i) acounter++;
92  else bcounter++;
93 
94  if(fMvaB[j] > i) ccounter++;
95  else dcounter++;
96  }
97 
98  if(acounter != 0 || bcounter != 0)
99  {
100  epsilon_s = 1.0*bcounter/(acounter+bcounter);
101  }
102  fEpsilonSig.push_back(epsilon_s);
103 
104  if(ccounter != 0 || dcounter != 0)
105  {
106  epsilon_b = 1.0*dcounter/(ccounter+dcounter);
107  }
108  fEpsilonBgk.push_back(epsilon_b);
109  }
110  fEpsilonSig.push_back(1.0);
111  fEpsilonBgk.push_back(1.0);
112  for(UInt_t i=0;i<fEpsilonSig.size()-1;i++)
113  {
114  integral += 0.5*(fEpsilonSig[i+1]-fEpsilonSig[i])*(fEpsilonBgk[i]+fEpsilonBgk[i+1]);
115  }
116  return integral;
117 }
118 
119 
121 {
122 
123  const UInt_t ndivisions = points - 1;
124  fEpsilonSig.resize(points);
125  fEpsilonBgk.resize(points);
126  // Fixed values.
127  fEpsilonSig[0] = 0.0;
128  fEpsilonSig[ndivisions] = 1.0;
129  fEpsilonBgk[0] = 1.0;
130  fEpsilonBgk[ndivisions] = 0.0;
131 
132  for(UInt_t i = 1; i < ndivisions; i++)
133  {
134  Float_t threshold = -1.0 + i * 2.0 / (Float_t) ndivisions;
135  Float_t true_positives = 0.0;
136  Float_t false_positives = 0.0;
137  Float_t true_negatives = 0.0;
138  Float_t false_negatives = 0.0;
139 
140  for(UInt_t j=0; j<fMvaS.size(); j++)
141  {
142  if(fMvaS[j] > threshold)
143  true_positives += 1.0;
144  else
145  false_negatives += 1.0;
146 
147  if(fMvaB[j] > threshold)
148  false_positives += 1.0;
149  else
150  true_negatives += 1.0;
151  }
152 
153  fEpsilonSig[ndivisions - i] = 0.0;
154  if ((true_positives > 0.0) || (false_negatives > 0.0))
155  fEpsilonSig[ndivisions - i] =
156  true_positives / (true_positives + false_negatives);
157 
158  fEpsilonBgk[ndivisions - i] =0.0;
159  if ((true_negatives > 0.0) || (false_positives > 0.0))
160  fEpsilonBgk[ndivisions - i] =
161  true_negatives / (true_negatives + false_positives);
162 
163  }
164 
165  if(!fGraph) fGraph=new TGraph(fEpsilonSig.size(),&fEpsilonSig[0],&fEpsilonBgk[0]);
166  return fGraph;
167 }
168 
169 
170 
171 
MsgLogger * fLogger
Definition: ROCCurve.h:60
ROCCurve(const std::vector< Float_t > &mvaS, const std::vector< Bool_t > &mvat)
Definition: ROCCurve.cxx:48
float Float_t
Definition: RtypesCore.h:53
TGraph * fGraph
Definition: ROCCurve.h:62
~ROCCurve()
destructor
Definition: ROCCurve.cxx:64
STL namespace.
std::vector< Float_t > fEpsilonSig
Definition: ROCCurve.h:65
std::vector< Float_t > fMvaB
Definition: ROCCurve.h:64
point * points
Definition: X3DBuffer.c:20
Double_t GetROCIntegral()
ROC Integral (AUC)
Definition: ROCCurve.cxx:72
unsigned int UInt_t
Definition: RtypesCore.h:42
std::vector< Float_t > fMvaS
Definition: ROCCurve.h:63
double Double_t
Definition: RtypesCore.h:55
Abstract ClassifierFactory template that handles arbitrary types.
TGraph * GetROCCurve(const UInt_t points=100)
Definition: ROCCurve.cxx:120
A Graph is a graphics object made of two arrays X and Y with npoints each.
Definition: TGraph.h:53
#define NULL
Definition: Rtypes.h:82
std::vector< Float_t > fEpsilonBgk
Definition: ROCCurve.h:66