Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ResultsMulticlass.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Jan Therhaag
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : ResultsMulticlass *
8 * *
9 * *
10 * Description: *
11 * Implementation (see header for description) *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland *
16 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
17 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
19 * *
20 * Copyright (c) 2006: *
21 * CERN, Switzerland *
22 * MPI-K Heidelberg, Germany *
23 * U. of Bonn, Germany *
24 * *
25 * Redistribution and use in source and binary forms, with or without *
26 * modification, are permitted according to the terms listed in LICENSE *
27 * (see tmva/doc/LICENSE) *
28 **********************************************************************************/
29
30/*! \class TMVA::ResultsMulticlass
31\ingroup TMVA
32Class which takes the results of a multiclass classification
33*/
34
36
37#include "TMVA/DataSet.h"
38#include "TMVA/DataSetInfo.h"
40#include "TMVA/GeneticFitter.h"
41#include "TMVA/MsgLogger.h"
42#include "TMVA/Results.h"
43#include "TMVA/ROCCurve.h"
44#include "TMVA/Tools.h"
45#include "TMVA/Types.h"
46
47#include "TGraph.h"
48#include "TH1F.h"
49#include "TMatrixD.h"
50
51#include <limits>
52#include <vector>
53
54
55////////////////////////////////////////////////////////////////////////////////
56/// constructor
57
59 : Results( dsi, resultsName ),
61 fLogger( new MsgLogger(TString::Format("ResultsMultiClass%s",resultsName.Data()).Data() , kINFO) ),
62 fClassToOptimize(0),
63 fAchievableEff(dsi->GetNClasses()),
64 fAchievablePur(dsi->GetNClasses()),
65 fBestCuts(dsi->GetNClasses(),std::vector<Double_t>(dsi->GetNClasses()))
66{
67}
68
69////////////////////////////////////////////////////////////////////////////////
70/// destructor
71
73{
74 delete fLogger;
75}
76
77////////////////////////////////////////////////////////////////////////////////
78
79void TMVA::ResultsMulticlass::SetValue( std::vector<Float_t>& value, Int_t ievt )
80{
81 if (ievt >= (Int_t)fMultiClassValues.size()) fMultiClassValues.resize( ievt+1 );
82 fMultiClassValues[ievt] = value;
83}
84
85////////////////////////////////////////////////////////////////////////////////
86/// Returns a confusion matrix where each class is pitted against each other.
87/// Results are
88
90{
91 const DataSet *ds = GetDataSet();
92 const DataSetInfo *dsi = GetDataSetInfo();
93 ds->SetCurrentType(GetTreeType());
94
95 UInt_t numClasses = dsi->GetNClasses();
96 TMatrixD mat(numClasses, numClasses);
97
98 // class == iRow is considered signal class
99 for (UInt_t iRow = 0; iRow < numClasses; ++iRow) {
100 for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
101
102 // Number is meaningless with only one class
103 if (iRow == iCol) {
104 mat(iRow, iCol) = std::numeric_limits<double>::quiet_NaN();
105 }
106
107 std::vector<Float_t> valueVector;
108 std::vector<Bool_t> classVector;
109 std::vector<Float_t> weightVector;
110
111 for (UInt_t iEvt = 0; iEvt < ds->GetNEvents(); ++iEvt) {
112 const Event *ev = ds->GetEvent(iEvt);
113 const UInt_t cls = ev->GetClass();
114 const Float_t weight = ev->GetWeight();
115 const Float_t mvaValue = fMultiClassValues[iEvt][iRow];
116
117 if (cls != iRow && cls != iCol) {
118 continue;
119 }
120
121 classVector.push_back(cls == iRow);
122 weightVector.push_back(weight);
123 valueVector.push_back(mvaValue);
124 }
125
126 ROCCurve roc(valueVector, classVector, weightVector);
127 mat(iRow, iCol) = roc.GetEffSForEffB(effB);
128 }
129 }
130
131 return mat;
132}
133
134////////////////////////////////////////////////////////////////////////////////
135
136Double_t TMVA::ResultsMulticlass::EstimatorFunction( std::vector<Double_t> & cutvalues ){
137
138 DataSet* ds = GetDataSet();
139 ds->SetCurrentType( GetTreeType() );
140
141 // Cache optimisation, count true and false positives with memory access
142 // instead of code branch.
143 Float_t positives[2] = {0, 0};
144
145 for (Int_t ievt = 0; ievt < ds->GetNEvents(); ievt++) {
146 UInt_t evClass = fEventClasses[ievt];
147 Float_t w = fEventWeights[ievt];
148
149 Bool_t break_outer_loop = false;
150 for (UInt_t icls = 0; icls < cutvalues.size(); ++icls) {
151 auto value = fMultiClassValues[ievt][icls];
152 auto cutvalue = cutvalues.at(icls);
153 if (cutvalue < 0. ? (-value < cutvalue) : (+value <= cutvalue)) {
154 break_outer_loop = true;
155 break;
156 }
157 }
158
159 if (break_outer_loop) {
160 continue;
161 }
162
163 Bool_t isEvCurrClass = (evClass == fClassToOptimize);
164 positives[isEvCurrClass] += w;
165 }
166
167 const Float_t truePositive = positives[1];
168 const Float_t falsePositive = positives[0];
169
170 Float_t eff = truePositive / fClassSumWeights[fClassToOptimize];
171 Float_t pur = truePositive / (truePositive + falsePositive);
172 Float_t effTimesPur = eff*pur;
173
174 Float_t toMinimize = std::numeric_limits<float>::max();
175 if (effTimesPur > std::numeric_limits<float>::min())
176 toMinimize = 1./(effTimesPur); // we want to minimize 1/efficiency*purity
177
178 fAchievableEff.at(fClassToOptimize) = eff;
179 fAchievablePur.at(fClassToOptimize) = pur;
180
181 return toMinimize;
182}
183
184////////////////////////////////////////////////////////////////////////////////
185///calculate the best working point (optimal cut values)
186///for the multiclass classifier
187
189
190 const DataSetInfo* dsi = GetDataSetInfo();
191 Log() << kINFO << "Calculating best set of cuts for class "
192 << dsi->GetClassInfo( targetClass )->GetName() << Endl;
193
194 fClassToOptimize = targetClass;
195 std::vector<Interval*> ranges(dsi->GetNClasses(), new Interval(-1,1));
196
197 fClassSumWeights.clear();
198 fEventWeights.clear();
199 fEventClasses.clear();
200
201 for (UInt_t icls = 0; icls < dsi->GetNClasses(); ++icls) {
202 fClassSumWeights.push_back(0);
203 }
204
205 DataSet *ds = GetDataSet();
206 for (Int_t ievt = 0; ievt < ds->GetNEvents(); ievt++) {
207 const Event *ev = ds->GetEvent(ievt);
208 fClassSumWeights[ev->GetClass()] += ev->GetWeight();
209 fEventWeights.push_back(ev->GetWeight());
210 fEventClasses.push_back(ev->GetClass());
211 }
212
213 const TString name( "MulticlassGA" );
214 const TString opts( "PopSize=100:Steps=30" );
215 GeneticFitter mg( *this, name, ranges, opts);
216
217 std::vector<Double_t> result;
218 mg.Run(result);
219
220 fBestCuts.at(targetClass) = result;
221
222 UInt_t n = 0;
223 for( std::vector<Double_t>::iterator it = result.begin(); it<result.end(); ++it ){
224 Log() << kINFO << " cutValue[" <<dsi->GetClassInfo( n )->GetName() << "] = " << (*it) << ";"<< Endl;
225 n++;
226 }
227
228 return result;
229}
230
231////////////////////////////////////////////////////////////////////////////////
232/// Create performance graphs for this classifier a multiclass setting.
233/// Requires that the method has already been evaluated (that a resultset
234/// already exists.)
235///
236/// Currently uses the new way of calculating ROC Curves. If anything looks
237/// fishy, please contact the ROOT TMVA team.
238///
239
241{
242
243 Log() << kINFO << "Creating multiclass performance histograms..." << Endl;
244
245 DataSet *ds = GetDataSet();
246 ds->SetCurrentType(GetTreeType());
247 const DataSetInfo *dsi = GetDataSetInfo();
248
249 UInt_t numClasses = dsi->GetNClasses();
250
251 std::vector<std::vector<Float_t>> *rawMvaRes = GetValueVector();
252
253 //
254 // 1-vs-rest ROC curves
255 //
256 for (size_t iClass = 0; iClass < numClasses; ++iClass) {
257
258 TString className = dsi->GetClassInfo(iClass)->GetName();
259 TString name = TString::Format("%s_rejBvsS_%s", prefix.Data(), className.Data());
260 TString title = TString::Format("%s_%s", prefix.Data(), className.Data());
261
262 // Histograms are already generated, skip.
263 if ( DoesExist(name) ) {
264 return;
265 }
266
267 // Format data
268 std::vector<Float_t> mvaRes;
269 std::vector<Bool_t> mvaResTypes;
270 std::vector<Float_t> mvaResWeights;
271
272 // Vector transpose due to values being stored as
273 // [ [0, 1, 2], [0, 1, 2], ... ]
274 // in ResultsMulticlass::GetValueVector.
275 mvaRes.reserve(rawMvaRes->size());
276 for (auto item : *rawMvaRes) {
277 mvaRes.push_back(item[iClass]);
278 }
279
280 auto eventCollection = ds->GetEventCollection();
281 mvaResTypes.reserve(eventCollection.size());
282 mvaResWeights.reserve(eventCollection.size());
283 for (auto ev : eventCollection) {
284 mvaResTypes.push_back(ev->GetClass() == iClass);
285 mvaResWeights.push_back(ev->GetWeight());
286 }
287
288 // Get ROC Curve
289 ROCCurve *roc = new ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
290 TGraph *rocGraph = new TGraph(*(roc->GetROCCurve()));
291 delete roc;
292
293 // Style ROC Curve
294 rocGraph->SetName(name);
295 rocGraph->SetTitle(title);
296
297 // Store ROC Curve
298 Store(rocGraph);
299 }
300
301 //
302 // 1-vs-1 ROC curves
303 //
304 for (size_t iClass = 0; iClass < numClasses; ++iClass) {
305 for (size_t jClass = 0; jClass < numClasses; ++jClass) {
306 if (iClass == jClass) {
307 continue;
308 }
309
310 auto eventCollection = ds->GetEventCollection();
311
312 // Format data
313 std::vector<Float_t> mvaRes;
314 std::vector<Bool_t> mvaResTypes;
315 std::vector<Float_t> mvaResWeights;
316
317 mvaRes.reserve(rawMvaRes->size());
318 mvaResTypes.reserve(eventCollection.size());
319 mvaResWeights.reserve(eventCollection.size());
320
321 for (size_t iEvent = 0; iEvent < eventCollection.size(); ++iEvent) {
322 Event *ev = eventCollection[iEvent];
323
324 if (ev->GetClass() == iClass || ev->GetClass() == jClass) {
325 Float_t output_value = (*rawMvaRes)[iEvent][iClass];
326 mvaRes.push_back(output_value);
327 mvaResTypes.push_back(ev->GetClass() == iClass);
328 mvaResWeights.push_back(ev->GetWeight());
329 }
330 }
331
332 // Get ROC Curve
333 ROCCurve *roc = new ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
334 TGraph *rocGraph = new TGraph(*(roc->GetROCCurve()));
335 delete roc;
336
337 // Style ROC Curve
338 TString iClassName = dsi->GetClassInfo(iClass)->GetName();
339 TString jClassName = dsi->GetClassInfo(jClass)->GetName();
340 TString name = TString::Format("%s_1v1rejBvsS_%s_vs_%s", prefix.Data(), iClassName.Data(), jClassName.Data());
341 TString title = TString::Format("%s_%s_vs_%s", prefix.Data(), iClassName.Data(), jClassName.Data());
342 rocGraph->SetName(name);
343 rocGraph->SetTitle(title);
344
345 // Store ROC Curve
346 Store(rocGraph);
347 }
348 }
349}
350
351////////////////////////////////////////////////////////////////////////////////
352/// this function fills the mva response histos for multiclass classification
353
355{
356 Log() << kINFO << "Creating multiclass response histograms..." << Endl;
357
358 DataSet* ds = GetDataSet();
359 ds->SetCurrentType( GetTreeType() );
360 const DataSetInfo* dsi = GetDataSetInfo();
361
362 std::vector<std::vector<TH1F*> > histos;
363 Float_t xmin = 0.-0.0002;
364 Float_t xmax = 1.+0.0002;
365 for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
366 histos.push_back(std::vector<TH1F*>(0));
367 for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
368 TString name = TString::Format("%s_%s_prob_for_%s",prefix.Data(),
369 dsi->GetClassInfo( jCls )->GetName(),
370 dsi->GetClassInfo( iCls )->GetName());
371
372 // Histograms are already generated, skip.
373 if ( DoesExist(name) ) {
374 return;
375 }
376
377 histos.at(iCls).push_back(new TH1F(name,name,nbins,xmin,xmax));
378 }
379 }
380
381 for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
382 const Event* ev = ds->GetEvent(ievt);
383 Int_t cls = ev->GetClass();
384 Float_t w = ev->GetWeight();
385 for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
386 histos.at(cls).at(jCls)->Fill(fMultiClassValues[ievt][jCls],w);
387 }
388 }
389 for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
390 for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
391 gTools().NormHist( histos.at(iCls).at(jCls) );
392 Store(histos.at(iCls).at(jCls));
393 }
394 }
395
396 /*
397 //fill fine binned histos for testing
398 if(prefix.Contains("Test")){
399 std::vector<std::vector<TH1F*> > histos_highbin;
400 for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
401 histos_highbin.push_back(std::vector<TH1F*>(0));
402 for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
403 TString name = TString::Format("%s_%s_prob_for_%s_HIGHBIN",prefix.Data(),
404 dsi->GetClassInfo( jCls )->GetName().Data(),
405 dsi->GetClassInfo( iCls )->GetName().Data());
406 histos_highbin.at(iCls).push_back(new TH1F(name,name,nbins_high,xmin,xmax));
407 }
408 }
409
410 for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
411 const Event* ev = ds->GetEvent(ievt);
412 Int_t cls = ev->GetClass();
413 Float_t w = ev->GetWeight();
414 for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
415 histos_highbin.at(cls).at(jCls)->Fill(fMultiClassValues[ievt][jCls],w);
416 }
417 }
418 for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
419 for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
420 gTools().NormHist( histos_highbin.at(iCls).at(jCls) );
421 Store(histos_highbin.at(iCls).at(jCls));
422 }
423 }
424 }
425 */
426}
float Float_t
Definition RtypesCore.h:57
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
char name[80]
Definition TGX11.cxx:110
float xmin
float xmax
static char * Format(const char *format, va_list ap)
Format a string in a circular formatting buffer (using a printf style format descriptor).
Definition TString.cxx:2442
A TGraph is an object made of two arrays X and Y with npoints each.
Definition TGraph.h:41
void SetName(const char *name="") override
Set graph name.
Definition TGraph.cxx:2364
void SetTitle(const char *title="") override
Change (i.e. set) the title.
Definition TGraph.cxx:2380
1-D histogram with a float per channel (see TH1 documentation)
Definition TH1.h:621
Class that contains all the data information.
Definition DataSetInfo.h:62
UInt_t GetNClasses() const
ClassInfo * GetClassInfo(Int_t clNum) const
Class that contains all the data information.
Definition DataSet.h:58
const Event * GetEvent() const
returns event without transformations
Definition DataSet.cxx:202
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:206
void SetCurrentType(Types::ETreeType type) const
Definition DataSet.h:89
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:216
Double_t GetWeight() const
Returns the event weight, depending on whether the flag IgnoreNegWeightsInTraining is set or not.
Definition Event.cxx:389
UInt_t GetClass() const
Definition Event.h:86
Fitter using a Genetic Algorithm.
Double_t Run(std::vector< Double_t > &pars)
Execute fitting.
Interface for a fitter 'target'.
The TMVA::Interval Class.
Definition Interval.h:61
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
Double_t GetEffSForEffB(Double_t effB, const UInt_t num_points=41)
Calculate the signal efficiency (sensitivity) for a given background efficiency (sensitivity).
Definition ROCCurve.cxx:217
TGraph * GetROCCurve(const UInt_t points=100)
Returns a new TGraph containing the ROC curve.
Definition ROCCurve.cxx:274
TMatrixD GetConfusionMatrix(Double_t effB)
Returns a confusion matrix where each class is pitted against each other.
Double_t EstimatorFunction(std::vector< Double_t > &) override
std::vector< Double_t > GetBestMultiClassCuts(UInt_t targetClass)
calculate the best working point (optimal cut values) for the multiclass classifier
ResultsMulticlass(const DataSetInfo *dsi, TString resultsName)
constructor
void CreateMulticlassHistos(TString prefix, Int_t nbins, Int_t nbins_high)
this function fills the mva response histos for multiclass classification
void CreateMulticlassPerformanceHistos(TString prefix)
Create performance graphs for this classifier a multiclass setting.
void SetValue(std::vector< Float_t > &value, Int_t ievt)
Class that is the base-class for a vector of result.
Definition Results.h:57
Double_t NormHist(TH1 *theHist, Double_t norm=1.0)
normalises histogram
Definition Tools.cxx:383
const char * GetName() const override
Returns name of object.
Definition TNamed.h:47
Basic string class.
Definition TString.h:139
const char * Data() const
Definition TString.h:376
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2378
const Int_t n
Definition legend1.C:16
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148