Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
efficienciesMulticlass.cxx
Go to the documentation of this file.
1// @(#)Root/tmva $Id$
2// Author: Kim Albertsson
3/**********************************************************************************
4 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
5 * Package: TMVAGUI *
6 * Web : http://tmva.sourceforge.net *
7 * *
8 * Description: *
9 * Implementation (see header for description) *
10 * *
11 * Authors : *
12 * Kim Albertsson <kim.albertsson@cern.ch> - LTU & CERN *
13 * *
14 * Copyright (c) 2005-2017: *
15 * CERN, Switzerland *
16 * LTU, Sweden *
17 * *
18 * Redistribution and use in source and binary forms, with or without *
19 * modification, are permitted according to the terms listed in LICENSE *
20 * (http://tmva.sourceforge.net/LICENSE) *
21 **********************************************************************************/
22
24
25// TMVA
26#include "TMVA/Config.h"
27#include "TMVA/tmvaglob.h"
28
29// ROOT
30#include "TControlBar.h"
31#include "TFile.h"
32#include "TGraph.h"
33#include "TH2F.h"
34#include "TIterator.h"
35#include "TKey.h"
36#include "TROOT.h"
37
38// STL
39#include <iostream>
40
41////////////////////////////////////////////////////////////////////////////////
42///
43/// Note: This file assumes a certain structure on the input file. The structure
44/// is as follows:
45///
46/// - dataset (TDirectory)
47/// - ... some variables, plots ...
48/// - Method_XXX (TDirectory)
49/// + XXX (TDirectory)
50/// * ... some plots ...
51/// * MVA_Method_XXX_Test_#classname#
52/// * MVA_Method_XXX_Train_#classname#
53/// * ... some plots ...
54/// - Method_YYY (TDirectory)
55/// + YYY (TDirectory)
56/// * ... some plots ...
57/// * MVA_Method_YYY_Test_#classname#
58/// * MVA_Method_YYY_Train_#classname#
59/// * ... some plots ...
60/// - TestTree (TTree)
61/// + ... data...
62/// - TrainTree (TTree)
63/// + ... data...
64///
65/// Keeping this in mind makes the main loop in getRocCurves easier to follow :)
66///
67
68////////////////////////////////////////////////////////////////////////////////
69/// Private class that simplify drawing plots combining information from
70/// several methods.
71///
72/// Each wrapper will manage a canvas and a legend and provide convenience
73/// functions to add data to these. It also provides a save function for
74/// saving an image representation to disk.
75///
76/// Feel free to extend this class as you see fit. It is intended as a
77/// convenience when showing multiclass roccurves, not a fully general tool.
78///
79/// Usage:
80/// auto p = new EfficiencyPlotWrapper(name, title, dataset, i):
81/// for (TGraph * g : listOfGraphs) {
82/// p->AddGraph(g);
83/// p->AddLegendEntry(methodName);
84/// }
85/// p->save();
86///
87
88class EfficiencyPlotWrapper {
89public:
90 TCanvas *fCanvas;
91 TLegend *fLegend;
92
93 TString fDataset;
94
95 Int_t fColor;
96
97 UInt_t fNumMethods;
98
99 EfficiencyPlotWrapper(TString name, TString title, TString dataset, size_t i);
100
101 Int_t addGraph(TGraph *graph);
102 void addLegendEntry(TString methodTitle, TGraph *graph);
103
104 void save();
105
106private:
107 Float_t fx0L;
108 Float_t fdxL;
109 Float_t fy0H;
110 Float_t fdyH;
111
112 TCanvas *newEfficiencyCanvas(TString name, TString title, size_t i);
113 TLegend *newEfficiencyLegend();
114};
115
116using classcanvasmap_t = std::map<TString, EfficiencyPlotWrapper *>;
117using roccurvelist_t = std::vector<std::tuple<TString, TString, TGraph *>>;
118
119// Constants
120const char *BUTTON_TYPE = "button";
121
122// Private functions
123namespace TMVA {
124std::vector<TString> getclassnames(TString dataset, TString fin);
125roccurvelist_t getRocCurves(TDirectory *binDir, TString methodPrefix, TString graphNameRef);
127}
128
129////////////////////////////////////////////////////////////////////////////////
130/// Private (helper) functions - Implementation
131////////////////////////////////////////////////////////////////////////////////
132
133////////////////////////////////////////////////////////////////////////////////
134///
135
136std::vector<TString> TMVA::getclassnames(TString dataset, TString fin)
137{
139 TDirectory *dir = (TDirectory *)file->GetDirectory(dataset)->GetDirectory("InputVariables_Id");
140 if (!dir) {
141 std::cout << "Could not locate directory '" << dataset << "/InputVariables_Id' in file: " << fin << std::endl;
142 return {};
143 }
144
145 auto classnames = TMVA::TMVAGlob::GetClassNames(dir);
146 return classnames;
147}
148
149////////////////////////////////////////////////////////////////////////////////
150///
151
152roccurvelist_t TMVA::getRocCurves(TDirectory *binDir, TString methodPrefix, TString graphNameRef)
153{
154 roccurvelist_t rocCurves;
155
156 TList methods;
157 UInt_t nm = TMVAGlob::GetListOfMethods(methods, binDir);
158 if (nm == 0) {
159 cout << "ups .. no methods found in to plot ROC curve for ... give up" << endl;
160 return rocCurves;
161 }
162 // TIter next(file->GetListOfKeys());
163 TIter next(&methods);
164
165 // Loop over all method categories
166 TKey *key;
167 while ((key = (TKey *)next())) {
168 TDirectory *mDir = (TDirectory *)key->ReadObj();
169 TList titles;
170 TMVAGlob::GetListOfTitles(mDir, titles);
171
172 // Loop over each method within a category
173 TIter nextTitle(&titles);
174 TKey *titkey;
175 TDirectory *titDir;
176 while ((titkey = TMVAGlob::NextKey(nextTitle, "TDirectory"))) {
177 titDir = (TDirectory *)titkey->ReadObj();
178 TString methodTitle;
179 TMVAGlob::GetMethodTitle(methodTitle, titDir);
180
181 // Loop through all plots for the method
182 TIter nextKey(titDir->GetListOfKeys());
183 TKey *hkey2;
184 while ((hkey2 = TMVAGlob::NextKey(nextKey, "TGraph"))) {
185
186 TGraph *h = (TGraph *)hkey2->ReadObj();
187 TString hname = h->GetName();
188 if (hname.Contains(graphNameRef) && hname.BeginsWith(methodPrefix) && !hname.Contains("Train")) {
189
190 // Extract classname from plot name
191 // classname is string after nameref
192 Int_t index = hname.Index(graphNameRef) + graphNameRef.Length();
193 TString classname = hname(index, hname.Length() - index);
194
195 //std::cout << "Found TGraph " << hname << " with classname " << classname << std::endl;
196
197 rocCurves.push_back(std::make_tuple(methodTitle, classname, h));
198 }
199 }
200 }
201 }
202 return rocCurves;
203}
204
205////////////////////////////////////////////////////////////////////////////////
206/// Public functions - Implementation
207////////////////////////////////////////////////////////////////////////////////
208
209////////////////////////////////////////////////////////////////////////////////
210/// Private convenience function.
211///
212/// Adds a given a list of roc curves provided as n-tuple on the form
213/// (methodname, classname, graph)
214/// to the canvas corresponding to the classname.
215///
216
218{
219 for (auto &item : rocCurves) {
220
221 TString methodTitle = std::get<0>(item);
222 TString classname = std::get<1>(item);
223 TGraph *h = std::get<2>(item);
224
225 try {
226 EfficiencyPlotWrapper *plotWrapper = classCanvasMap.at(classname);
227 plotWrapper->addGraph(h);
228 plotWrapper->addLegendEntry(methodTitle, h);
229 } catch (const std::out_of_range &) {
230 cout << Form("ERROR: Class %s discovered among plots but was not found by TMVAMulticlassGui. Skipping.",
231 classname.Data())
232 << endl;
233 }
234 }
235}
236
237////////////////////////////////////////////////////////////////////////////////
238/// Entry point. Called from the TMVAMulticlassGui Buttons
239///
240/// \param dataset Dataset to operate on. Should be created by the TMVA Multiclass Factory.
241/// \param filename_input Name of the input file procuded by a TMVA Multiclass Factory.
242/// \param plotType Specified what kind of ROC curve to draw. Currently only rejB vs. effS is supported.
243
244void TMVA::efficienciesMulticlass1vsRest(TString dataset, TString filename_input, EEfficiencyPlotType plotType,
245 Bool_t useTMVAStyle)
246{
247 // set style and remove existing canvas'
248 TMVAGlob::Initialize(useTMVAStyle);
249 plotEfficienciesMulticlass1vsRest(dataset, plotType, filename_input);
250 return;
251}
252
253////////////////////////////////////////////////////////////////////////////////
254/// Work horse function. Will operate on the currently open file (opened by
255/// efficienciesMulticlass).
256///
257/// \param plotType See effcienciesMulticlass.
258/// \param binDir Directory in the file on which to operate.
259
260void TMVA::plotEfficienciesMulticlass1vsRest(TString dataset, EEfficiencyPlotType plotType, TString filename_input)
261{
262 // The current multiclass version implements only type 2 - rejB vs effS
263 if (plotType != EEfficiencyPlotType::kRejBvsEffS) {
264 std::cout << "For multiclass, only rejB vs effS is currently implemented.";
265 return;
266 }
267
268 // checks if filename_input is already open, and if not opens one
269 TFile *file = TMVAGlob::OpenFile(filename_input);
270 if (file == nullptr) {
271 std::cout << "ERROR: filename \"" << filename_input << "\" is not found.";
272 return;
273 }
274 auto binDir = file->GetDirectory(dataset.Data());
275
276 size_t iPlot = 0;
277 auto classnames = getclassnames(dataset, filename_input);
278 TString methodPrefix = "MVA_";
279 TString graphNameRef = "_rejBvsS_";
280
281 classcanvasmap_t classCanvasMap;
282 for (auto &classname : classnames) {
283 TString name = Form("roc_%s_vs_rest", classname.Data());
284 TString title = Form("ROC Curve %s vs rest", classname.Data());
285 EfficiencyPlotWrapper *plotWrapper = new EfficiencyPlotWrapper(name, title, dataset, iPlot++);
286 classCanvasMap.emplace(classname.Data(), plotWrapper);
287 }
288
289 roccurvelist_t rocCurves = getRocCurves(binDir, methodPrefix, graphNameRef);
290 plotEfficienciesMulticlass(rocCurves, classCanvasMap);
291
292 for (auto const &item : classCanvasMap) {
293 auto plotWrapper = item.second;
294 plotWrapper->save();
295 }
296}
297
298////////////////////////////////////////////////////////////////////////////////
299/// Entry point. Called from the TMVAMulticlassGui Buttons
300///
301/// \param dataset
302/// \param fin
303
305{
306 std::cout << "--- Running Roc1v1Gui for input file: " << fin << std::endl;
307
308 TMVAGlob::Initialize();
309
310 // create the control bar
311 TString title = "1v1 ROC curve comparison";
312 TControlBar *cbar = new TControlBar("vertical", title, 50, 50);
313
314 gDirectory->pwd();
315 auto classnames = getclassnames(dataset, fin);
316
317 // configure buttons
318 for (auto &classname : classnames) {
319 cbar->AddButton(Form("Class: %s", classname.Data()),
320 Form("TMVA::plotEfficienciesMulticlass1vs1(\"%s\", \"%s\", \"%s\")", dataset.Data(), fin.Data(),
321 classname.Data()),
323 }
324
325 cbar->SetTextColor("blue");
326 cbar->Show();
327
328 gROOT->SaveContext();
329}
330
331////////////////////////////////////////////////////////////////////////////////
332/// Generates K-1 plots comparing a given base class against all others (except
333/// itself). For each plot, the base class is considered signal and the other
334/// class is considered background.
335///
336/// Given 3 classes in the dataset and providing "Class 0" as the base class
337/// this would generate 2 plots comparing
338/// - Class 0 vs Class 1, and
339/// - Class 0 vs Class 2.
340/// For the "Class 0 vs Class 1" plot, events from Class 2 are ignored. For the
341/// "Class 0 vs Class 2" plot, events from Class 1 are ignored.
342///
343/// \param dataset
344/// \param fin
345/// \param baseClassname name of the class which will be considered signal
346
347void TMVA::plotEfficienciesMulticlass1vs1(TString dataset, TString fin, TString baseClassname)
348{
349
350 TMVAGlob::Initialize();
351
352 auto classnames = getclassnames(dataset, fin);
353 size_t iPlot = 0;
354
355 TString methodPrefix = "MVA_";
356 TString graphNameRef = Form("_1v1rejBvsS_%s_vs_", baseClassname.Data());
357
358 TFile *file = TMVAGlob::OpenFile(fin);
359 if (file == nullptr) {
360 std::cout << "ERROR: filename \"" << fin << "\" is not found.";
361 return;
362 }
363 auto binDir = file->GetDirectory(dataset.Data());
364
365 classcanvasmap_t classCanvasMap;
366 for (auto &classname : classnames) {
367
368 if (classname == baseClassname) {
369 continue;
370 }
371
372 TString name = Form("1v1roc_%s_vs_%s", baseClassname.Data(), classname.Data());
373 TString title = Form("ROC Curve %s (Sig) vs %s (Bkg)", baseClassname.Data(), classname.Data());
374 EfficiencyPlotWrapper *plotWrapper = new EfficiencyPlotWrapper(name, title, dataset, iPlot++);
375 classCanvasMap.emplace(classname.Data(), plotWrapper);
376 }
377
378 roccurvelist_t rocCurves = getRocCurves(binDir, methodPrefix, graphNameRef);
379 plotEfficienciesMulticlass(rocCurves, classCanvasMap);
380
381 for (auto const &item : classCanvasMap) {
382 auto plotWrapper = item.second;
383 plotWrapper->save();
384 }
385}
386
387////////////////////////////////////////////////////////////////////////////////
388/// Private class EfficiencyPlotWrapper - Implementation
389////////////////////////////////////////////////////////////////////////////////
390
391////////////////////////////////////////////////////////////////////////////////
392/// Constructs a new canvas + auxiliary data for showing an efficiency plot.
393///
394
395EfficiencyPlotWrapper::EfficiencyPlotWrapper(TString name, TString title, TString dataset, size_t i)
396{
397 // Legend extents (init before calling newEfficiencyLegend...)
398 fx0L = 0.107;
399 fy0H = 0.899;
400 fdxL = 0.457 - fx0L;
401 fdyH = 0.22;
402 fx0L = 0.15;
403 fy0H = 1 - fy0H + fdyH + 0.07;
404
405 fColor = 1;
406 fNumMethods = 0;
407
408 fDataset = dataset;
409
410 fCanvas = newEfficiencyCanvas(name, title, i);
411 fLegend = newEfficiencyLegend();
412}
413
414////////////////////////////////////////////////////////////////////////////////
415/// Adds a new graph to the plot. The added graph should contain a single ROC
416/// curve.
417///
418
419Int_t EfficiencyPlotWrapper::addGraph(TGraph *graph)
420{
421 graph->SetLineWidth(3);
422 graph->SetLineColor(fColor);
423 fColor++;
424 if (fColor == 5 || fColor == 10 || fColor == 11) {
425 fColor++;
426 }
427
428 fCanvas->cd();
429 graph->Draw("");
430 fCanvas->Update();
431
432 ++fNumMethods;
433
434 return fColor;
435}
436
437////////////////////////////////////////////////////////////////////////////////
438/// WARNING: Uses the current color, thus the correct call ordering is:
439/// plotWrapper->addGraph(...);
440/// plotWrapper->addLegendEntry(...);
441///
442
443void EfficiencyPlotWrapper::addLegendEntry(TString methodTitle, TGraph *graph)
444{
445 fLegend->AddEntry(graph, methodTitle, "l");
446
447 Float_t dyH_local = fdyH * (Float_t(TMath::Min((UInt_t)10, fNumMethods) - 3.0) / 4.0);
448 fLegend->SetY2(fy0H + dyH_local);
449
450 fLegend->Paint();
451 fCanvas->Update();
452}
453
454////////////////////////////////////////////////////////////////////////////////
455/// Helper to create new Canvas
456///
457/// \param name Name...
458/// \param title Title to be displayed on canvas
459/// \param i Index to offset a collection of canvases from each other
460///
461
462TCanvas *EfficiencyPlotWrapper::newEfficiencyCanvas(TString name, TString title, size_t i)
463{
464 TCanvas *c = new TCanvas(name, title, 200 + i * 50, 0 + i * 50, 650, 500);
465 // global style settings
466 c->SetGrid();
467 c->SetTicks();
468
469 // Frame
470 TString xtit = "Signal Efficiency";
471 TString ytit = "Background Rejection (1 - eff)";
472 Double_t x1 = 0.0;
473 Double_t x2 = 1.0;
474 Double_t y1 = 0.0;
475 Double_t y2 = 1.0;
476
477 TH1F *frame = new TH1F(Form("%s_%s", title.Data(), "frame"), title, 500, x1, x2);
478 frame->SetMinimum(y1);
479 frame->SetMaximum(y2);
480
481 frame->GetXaxis()->SetTitle(xtit);
482 frame->GetYaxis()->SetTitle(ytit);
484 frame->Draw();
485
486 return c;
487}
488
489////////////////////////////////////////////////////////////////////////////////
490/// Helper to create new legend.
491
492TLegend *EfficiencyPlotWrapper::newEfficiencyLegend()
493{
494 TLegend *legend = new TLegend(fx0L, fy0H - fdyH, fx0L + fdxL, fy0H);
495 // legend->SetTextSize( 0.05 );
496 legend->SetHeader("MVA Method:");
497 legend->SetMargin(0.4);
498 legend->Draw("");
499
500 return legend;
501}
502
503////////////////////////////////////////////////////////////////////////////////
504/// Saves the current state of the plot to disk.
505///
506
507void EfficiencyPlotWrapper::save()
508{
509 TString fname = fDataset + "/plots/" + fCanvas->GetName();
510 TMVA::TMVAGlob::imgconv(fCanvas, fname);
511}
#define c(i)
Definition RSha256.hxx:101
#define h(i)
Definition RSha256.hxx:106
static const double x2[5]
static const double x1[5]
int Int_t
Definition RtypesCore.h:45
unsigned int UInt_t
Definition RtypesCore.h:46
bool Bool_t
Definition RtypesCore.h:63
double Double_t
Definition RtypesCore.h:59
float Float_t
Definition RtypesCore.h:57
#define gDirectory
Definition TDirectory.h:290
char name[80]
Definition TGX11.cxx:110
#define gROOT
Definition TROOT.h:406
char * Form(const char *fmt,...)
The Canvas class.
Definition TCanvas.h:23
A Control Bar is a fully user configurable tool which provides fast access to frequently used operati...
Definition TControlBar.h:26
void Show()
Show control bar.
void AddButton(TControlBarButton *button)
Add button.
void SetTextColor(const char *colorName)
Sets text color for control bar buttons, e.g.:
Describe directory structure in memory.
Definition TDirectory.h:45
virtual TList * GetListOfKeys() const
Definition TDirectory.h:177
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition TFile.h:54
A TGraph is an object made of two arrays X and Y with npoints each.
Definition TGraph.h:41
1-D histogram with a float per channel (see TH1 documentation)}
Definition TH1.h:575
TAxis * GetXaxis()
Get the behaviour adopted by the object about the statoverflows. See EStatOverflows for more informat...
Definition TH1.h:320
virtual void SetMaximum(Double_t maximum=-1111)
Definition TH1.h:398
TAxis * GetYaxis()
Definition TH1.h:321
virtual void SetMinimum(Double_t minimum=-1111)
Definition TH1.h:399
virtual void Draw(Option_t *option="")
Draw this histogram with options.
Definition TH1.cxx:3073
Book space in a file, create I/O buffers, to fill them, (un)compress them.
Definition TKey.h:28
virtual TObject * ReadObj()
To read a TObject* from the file.
Definition TKey.cxx:750
This class displays a legend box (TPaveText) containing several legend entries.
Definition TLegend.h:23
virtual void SetHeader(const char *header="", Option_t *option="")
Sets the header, which is the "title" that appears at the top of the legend.
Definition TLegend.cxx:1099
virtual void Draw(Option_t *option="")
Draw this legend with its current attributes.
Definition TLegend.cxx:423
void SetMargin(Float_t margin)
Definition TLegend.h:69
A doubly linked list.
Definition TList.h:44
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
Definition TNamed.cxx:164
Basic string class.
Definition TString.h:136
Ssiz_t Length() const
Definition TString.h:410
const char * Data() const
Definition TString.h:369
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
Definition TString.h:615
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition TString.h:624
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition TString.h:639
std::vector< std::tuple< TString, TString, TGraph * > > roccurvelist_t
const char * BUTTON_TYPE
std::map< TString, EfficiencyPlotWrapper * > classcanvasmap_t
TFile * OpenFile(const TString &fin)
Definition tmvaglob.cxx:192
void SetFrameStyle(TH1 *frame, Float_t scale=1.0)
Definition tmvaglob.cxx:77
std::vector< TString > GetClassNames(TDirectory *dir)
Definition tmvaglob.cxx:469
void imgconv(TCanvas *c, const TString &fname)
Definition tmvaglob.cxx:212
create variable transformations
roccurvelist_t getRocCurves(TDirectory *binDir, TString methodPrefix, TString graphNameRef)
void efficienciesMulticlass1vs1(TString dataset, TString fin)
std::vector< TString > getclassnames(TString dataset, TString fin)
void plotEfficienciesMulticlass1vs1(TString dataset, TString fin, TString baseClassname)
void plotEfficienciesMulticlass1vsRest(TString dataset, EEfficiencyPlotType plotType=EEfficiencyPlotType::kRejBvsEffS, TString filename_input="TMVAMulticlass.root")
void plotEfficienciesMulticlass(roccurvelist_t rocCurves, classcanvasmap_t classCanvasMap)
void efficienciesMulticlass1vsRest(TString dataset, TString filename_input="TMVAMulticlass.root", EEfficiencyPlotType plotType=EEfficiencyPlotType::kRejBvsEffS, Bool_t useTMVAStyle=kTRUE)
Short_t Min(Short_t a, Short_t b)
Definition TMathBase.h:180
Definition file.py:1
Definition graph.py:1