Logo ROOT  
Reference Guide
efficienciesMulticlass.cxx
Go to the documentation of this file.
1// @(#)Root/tmva $Id$
2// Author: Kim Albertsson
3/**********************************************************************************
4 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
5 * Package: TMVAGUI *
6 * Web : http://tmva.sourceforge.net *
7 * *
8 * Description: *
9 * Implementation (see header for description) *
10 * *
11 * Authors : *
12 * Kim Albertsson <kim.albertsson@cern.ch> - LTU & CERN *
13 * *
14 * Copyright (c) 2005-2017: *
15 * CERN, Switzerland *
16 * LTU, Sweden *
17 * *
18 * Redistribution and use in source and binary forms, with or without *
19 * modification, are permitted according to the terms listed in LICENSE *
20 * (http://tmva.sourceforge.net/LICENSE) *
21 **********************************************************************************/
22
24
25// TMVA
26#include "TMVA/Config.h"
27#include "TMVA/tmvaglob.h"
28
29// ROOT
30#include "TControlBar.h"
31#include "TFile.h"
32#include "TGraph.h"
33#include "TH2F.h"
34#include "TIterator.h"
35#include "TKey.h"
36#include "TROOT.h"
37
38// STL
39#include <iostream>
40
41////////////////////////////////////////////////////////////////////////////////
42///
43/// Note: This file assumes a certain structure on the input file. The structure
44/// is as follows:
45///
46/// - dataset (TDirectory)
47/// - ... some variables, plots ...
48/// - Method_XXX (TDirectory)
49/// + XXX (TDirectory)
50/// * ... some plots ...
51/// * MVA_Method_XXX_Test_#classname#
52/// * MVA_Method_XXX_Train_#classname#
53/// * ... some plots ...
54/// - Method_YYY (TDirectory)
55/// + YYY (TDirectory)
56/// * ... some plots ...
57/// * MVA_Method_YYY_Test_#classname#
58/// * MVA_Method_YYY_Train_#classname#
59/// * ... some plots ...
60/// - TestTree (TTree)
61/// + ... data...
62/// - TrainTree (TTree)
63/// + ... data...
64///
65/// Keeping this in mind makes the main loop in getRocCurves easier to follow :)
66///
67
68////////////////////////////////////////////////////////////////////////////////
69/// Private class that simplify drawing plots combining information from
70/// several methods.
71///
72/// Each wrapper will manage a canvas and a legend and provide convenience
73/// functions to add data to these. It also provides a save function for
74/// saving an image representation to disk.
75///
76/// Feel free to extend this class as you see fit. It is intended as a
77/// convenience when showing multiclass roccurves, not a fully general tool.
78///
79/// Usage:
80/// auto p = new EfficiencyPlotWrapper(name, title, dataset, i):
81/// for (TGraph * g : listOfGraphs) {
82/// p->AddGraph(g);
83/// p->AddLegendEntry(methodName);
84/// }
85/// p->save();
86///
87
88class EfficiencyPlotWrapper {
89public:
90 TCanvas *fCanvas;
91 TLegend *fLegend;
92
93 TString fDataset;
94
95 Int_t fColor;
96
97 UInt_t fNumMethods;
98
99 EfficiencyPlotWrapper(TString name, TString title, TString dataset, size_t i);
100
101 Int_t addGraph(TGraph *graph);
102 void addLegendEntry(TString methodTitle, TGraph *graph);
103
104 void save();
105
106private:
107 Float_t fx0L;
108 Float_t fdxL;
109 Float_t fy0H;
110 Float_t fdyH;
111
112 TCanvas *newEfficiencyCanvas(TString name, TString title, size_t i);
113 TLegend *newEfficiencyLegend();
114};
115
116using classcanvasmap_t = std::map<TString, EfficiencyPlotWrapper *>;
117using roccurvelist_t = std::vector<std::tuple<TString, TString, TGraph *>>;
118
119// Constants
120const char *BUTTON_TYPE = "button";
121
122// Private functions
123namespace TMVA {
124std::vector<TString> getclassnames(TString dataset, TString fin);
125roccurvelist_t getRocCurves(TDirectory *binDir, TString methodPrefix, TString graphNameRef);
127}
128
129////////////////////////////////////////////////////////////////////////////////
130/// Private (helper) functions - Implementation
131////////////////////////////////////////////////////////////////////////////////
132
133////////////////////////////////////////////////////////////////////////////////
134///
135
136std::vector<TString> TMVA::getclassnames(TString dataset, TString fin)
137{
139 TDirectory *dir = (TDirectory *)file->GetDirectory(dataset)->GetDirectory("InputVariables_Id");
140 if (!dir) {
141 std::cout << "Could not locate directory '" << dataset << "/InputVariables_Id' in file: " << fin << std::endl;
142 return {};
143 }
144
145 auto classnames = TMVA::TMVAGlob::GetClassNames(dir);
146 return classnames;
147}
148
149////////////////////////////////////////////////////////////////////////////////
150///
151
152roccurvelist_t TMVA::getRocCurves(TDirectory *binDir, TString methodPrefix, TString graphNameRef)
153{
154 roccurvelist_t rocCurves;
155
156 TList methods;
157 UInt_t nm = TMVAGlob::GetListOfMethods(methods, binDir);
158 if (nm == 0) {
159 cout << "ups .. no methods found in to plot ROC curve for ... give up" << endl;
160 return rocCurves;
161 }
162 // TIter next(file->GetListOfKeys());
163 TIter next(&methods);
164
165 // Loop over all method categories
166 TKey *key;
167 while ((key = (TKey *)next())) {
168 TDirectory *mDir = (TDirectory *)key->ReadObj();
169 TList titles;
170 TMVAGlob::GetListOfTitles(mDir, titles);
171
172 // Loop over each method within a category
173 TIter nextTitle(&titles);
174 TKey *titkey;
175 TDirectory *titDir;
176 while ((titkey = TMVAGlob::NextKey(nextTitle, "TDirectory"))) {
177 titDir = (TDirectory *)titkey->ReadObj();
178 TString methodTitle;
179 TMVAGlob::GetMethodTitle(methodTitle, titDir);
180
181 // Loop through all plots for the method
182 TIter nextKey(titDir->GetListOfKeys());
183 TKey *hkey2;
184 while ((hkey2 = TMVAGlob::NextKey(nextKey, "TGraph"))) {
185
186 TGraph *h = (TGraph *)hkey2->ReadObj();
187 TString hname = h->GetName();
188 if (hname.Contains(graphNameRef) && hname.BeginsWith(methodPrefix) && !hname.Contains("Train")) {
189
190 // Extract classname from plot name
191 UInt_t index = hname.Last('_');
192 TString classname = hname(index + 1, hname.Length() - (index + 1));
193
194 rocCurves.push_back(std::make_tuple(methodTitle, classname, h));
195 }
196 }
197 }
198 }
199 return rocCurves;
200}
201
202////////////////////////////////////////////////////////////////////////////////
203/// Public functions - Implementation
204////////////////////////////////////////////////////////////////////////////////
205
206////////////////////////////////////////////////////////////////////////////////
207/// Private convenience function.
208///
209/// Adds a given a list of roc curves provided as n-tuple on the form
210/// (methodname, classname, graph)
211/// to the canvas corresponding to the classname.
212///
213
215{
216 for (auto &item : rocCurves) {
217
218 TString methodTitle = std::get<0>(item);
219 TString classname = std::get<1>(item);
220 TGraph *h = std::get<2>(item);
221
222 try {
223 EfficiencyPlotWrapper *plotWrapper = classCanvasMap.at(classname);
224 plotWrapper->addGraph(h);
225 plotWrapper->addLegendEntry(methodTitle, h);
226 } catch (const std::out_of_range &) {
227 cout << Form("ERROR: Class %s discovered among plots but was not found by TMVAMulticlassGui. Skipping.",
228 classname.Data())
229 << endl;
230 }
231 }
232}
233
234////////////////////////////////////////////////////////////////////////////////
235/// Entry point. Called from the TMVAMulticlassGui Buttons
236///
237/// \param dataset Dataset to operate on. Should be created by the TMVA Multiclass Factory.
238/// \param filename_input Name of the input file procuded by a TMVA Multiclass Factory.
239/// \param plotType Specified what kind of ROC curve to draw. Currently only rejB vs. effS is supported.
240
241void TMVA::efficienciesMulticlass1vsRest(TString dataset, TString filename_input, EEfficiencyPlotType plotType,
242 Bool_t useTMVAStyle)
243{
244 // set style and remove existing canvas'
245 TMVAGlob::Initialize(useTMVAStyle);
246 plotEfficienciesMulticlass1vsRest(dataset, plotType, filename_input);
247 return;
248}
249
250////////////////////////////////////////////////////////////////////////////////
251/// Work horse function. Will operate on the currently open file (opened by
252/// efficienciesMulticlass).
253///
254/// \param plotType See effcienciesMulticlass.
255/// \param binDir Directory in the file on which to operate.
256
257void TMVA::plotEfficienciesMulticlass1vsRest(TString dataset, EEfficiencyPlotType plotType, TString filename_input)
258{
259 // The current multiclass version implements only type 2 - rejB vs effS
260 if (plotType != EEfficiencyPlotType::kRejBvsEffS) {
261 std::cout << "For multiclass, only rejB vs effS is currently implemented.";
262 return;
263 }
264
265 // checks if filename_input is already open, and if not opens one
266 TFile *file = TMVAGlob::OpenFile(filename_input);
267 if (file == nullptr) {
268 std::cout << "ERROR: filename \"" << filename_input << "\" is not found.";
269 return;
270 }
271 auto binDir = file->GetDirectory(dataset.Data());
272
273 size_t iPlot = 0;
274 auto classnames = getclassnames(dataset, filename_input);
275 TString methodPrefix = "MVA_";
276 TString graphNameRef = "_rejBvsS_";
277
278 classcanvasmap_t classCanvasMap;
279 for (auto &classname : classnames) {
280 TString name = Form("roc_%s_vs_rest", classname.Data());
281 TString title = Form("ROC Curve %s vs rest", classname.Data());
282 EfficiencyPlotWrapper *plotWrapper = new EfficiencyPlotWrapper(name, title, dataset, iPlot++);
283 classCanvasMap.emplace(classname.Data(), plotWrapper);
284 }
285
286 roccurvelist_t rocCurves = getRocCurves(binDir, methodPrefix, graphNameRef);
287 plotEfficienciesMulticlass(rocCurves, classCanvasMap);
288
289 for (auto const &item : classCanvasMap) {
290 auto plotWrapper = item.second;
291 plotWrapper->save();
292 }
293}
294
295////////////////////////////////////////////////////////////////////////////////
296/// Entry point. Called from the TMVAMulticlassGui Buttons
297///
298/// \param dataset
299/// \param fin
300
302{
303 std::cout << "--- Running Roc1v1Gui for input file: " << fin << std::endl;
304
306
307 // create the control bar
308 TString title = "1v1 ROC curve comparison";
309 TControlBar *cbar = new TControlBar("vertical", title, 50, 50);
310
311 gDirectory->pwd();
312 auto classnames = getclassnames(dataset, fin);
313
314 // configure buttons
315 for (auto &classname : classnames) {
316 cbar->AddButton(Form("Class: %s", classname.Data()),
317 Form("TMVA::plotEfficienciesMulticlass1vs1(\"%s\", \"%s\", \"%s\")", dataset.Data(), fin.Data(),
318 classname.Data()),
320 }
321
322 cbar->SetTextColor("blue");
323 cbar->Show();
324
325 gROOT->SaveContext();
326}
327
328////////////////////////////////////////////////////////////////////////////////
329/// Generates K-1 plots comparing a given base class against all others (except
330/// itself). For each plot, the base class is considered signal and the other
331/// class is considered background.
332///
333/// Given 3 classes in the dataset and providing "Class 0" as the base class
334/// this would generate 2 plots comparing
335/// - Class 0 vs Class 1, and
336/// - Class 0 vs Class 2.
337/// For the "Class 0 vs Class 1" plot, events from Class 2 are ignored. For the
338/// "Class 0 vs Class 2" plot, events from Class 1 are ignored.
339///
340/// \param dataset
341/// \param fin
342/// \param baseClassname name of the class which will be considered signal
343
344void TMVA::plotEfficienciesMulticlass1vs1(TString dataset, TString fin, TString baseClassname)
345{
346
348
349 auto classnames = getclassnames(dataset, fin);
350 size_t iPlot = 0;
351
352 TString methodPrefix = "MVA_";
353 TString graphNameRef = Form("_1v1rejBvsS_%s_vs_", baseClassname.Data());
354
356 if (file == nullptr) {
357 std::cout << "ERROR: filename \"" << fin << "\" is not found.";
358 return;
359 }
360 auto binDir = file->GetDirectory(dataset.Data());
361
362 classcanvasmap_t classCanvasMap;
363 for (auto &classname : classnames) {
364
365 if (classname == baseClassname) {
366 continue;
367 }
368
369 TString name = Form("1v1roc_%s_vs_%s", baseClassname.Data(), classname.Data());
370 TString title = Form("ROC Curve %s (Sig) vs %s (Bkg)", baseClassname.Data(), classname.Data());
371 EfficiencyPlotWrapper *plotWrapper = new EfficiencyPlotWrapper(name, title, dataset, iPlot++);
372 classCanvasMap.emplace(classname.Data(), plotWrapper);
373 }
374
375 roccurvelist_t rocCurves = getRocCurves(binDir, methodPrefix, graphNameRef);
376 plotEfficienciesMulticlass(rocCurves, classCanvasMap);
377
378 for (auto const &item : classCanvasMap) {
379 auto plotWrapper = item.second;
380 plotWrapper->save();
381 }
382}
383
384////////////////////////////////////////////////////////////////////////////////
385/// Private class EfficiencyPlotWrapper - Implementation
386////////////////////////////////////////////////////////////////////////////////
387
388////////////////////////////////////////////////////////////////////////////////
389/// Constructs a new canvas + auxiliary data for showing an efficiency plot.
390///
391
392EfficiencyPlotWrapper::EfficiencyPlotWrapper(TString name, TString title, TString dataset, size_t i)
393{
394 // Legend extents (init before calling newEfficiencyLegend...)
395 fx0L = 0.107;
396 fy0H = 0.899;
397 fdxL = 0.457 - fx0L;
398 fdyH = 0.22;
399 fx0L = 0.15;
400 fy0H = 1 - fy0H + fdyH + 0.07;
401
402 fColor = 1;
403 fNumMethods = 0;
404
405 fDataset = dataset;
406
407 fCanvas = newEfficiencyCanvas(name, title, i);
408 fLegend = newEfficiencyLegend();
409}
410
411////////////////////////////////////////////////////////////////////////////////
412/// Adds a new graph to the plot. The added graph should contain a single ROC
413/// curve.
414///
415
416Int_t EfficiencyPlotWrapper::addGraph(TGraph *graph)
417{
418 graph->SetLineWidth(3);
419 graph->SetLineColor(fColor);
420 fColor++;
421 if (fColor == 5 || fColor == 10 || fColor == 11) {
422 fColor++;
423 }
424
425 fCanvas->cd();
426 graph->DrawClone("");
427 fCanvas->Update();
428
429 ++fNumMethods;
430
431 return fColor;
432}
433
434////////////////////////////////////////////////////////////////////////////////
435/// WARNING: Uses the current color, thus the correct call ordering is:
436/// plotWrapper->addGraph(...);
437/// plotWrapper->addLegendEntry(...);
438///
439
440void EfficiencyPlotWrapper::addLegendEntry(TString methodTitle, TGraph *graph)
441{
442 fLegend->AddEntry(graph, methodTitle, "l");
443
444 Float_t dyH_local = fdyH * (Float_t(TMath::Min((UInt_t)10, fNumMethods) - 3.0) / 4.0);
445 fLegend->SetY2(fy0H + dyH_local);
446
447 fLegend->Paint();
448 fCanvas->Update();
449}
450
451////////////////////////////////////////////////////////////////////////////////
452/// Helper to create new Canvas
453///
454/// \param name Name...
455/// \param title Title to be displayed on canvas
456/// \param i Index to offset a collection of canvases from each other
457///
458
459TCanvas *EfficiencyPlotWrapper::newEfficiencyCanvas(TString name, TString title, size_t i)
460{
461 TCanvas *c = new TCanvas(name, title, 200 + i * 50, 0 + i * 50, 650, 500);
462 // global style settings
463 c->SetGrid();
464 c->SetTicks();
465
466 // Frame
467 TString xtit = "Signal Efficiency";
468 TString ytit = "Background Rejection (1 - eff)";
469 Double_t x1 = 0.0;
470 Double_t x2 = 1.0;
471 Double_t y1 = 0.0;
472 Double_t y2 = 1.0;
473
474 TH2F *frame = new TH2F(Form("%s_%s", title.Data(), "frame"), title, 500, x1, x2, 500, y1, y2);
475 frame->GetXaxis()->SetTitle(xtit);
476 frame->GetYaxis()->SetTitle(ytit);
478 frame->DrawClone();
479
480 return c;
481}
482
483////////////////////////////////////////////////////////////////////////////////
484/// Helper to create new legend.
485
486TLegend *EfficiencyPlotWrapper::newEfficiencyLegend()
487{
488 TLegend *legend = new TLegend(fx0L, fy0H - fdyH, fx0L + fdxL, fy0H);
489 // legend->SetTextSize( 0.05 );
490 legend->SetHeader("MVA Method:");
491 legend->SetMargin(0.4);
492 legend->Draw("");
493
494 return legend;
495}
496
497////////////////////////////////////////////////////////////////////////////////
498/// Saves the current state of the plot to disk.
499///
500
501void EfficiencyPlotWrapper::save()
502{
503 TString fname = fDataset + "/plots/" + fCanvas->GetName();
504 TMVA::TMVAGlob::imgconv(fCanvas, fname);
505}
#define c(i)
Definition: RSha256.hxx:101
#define h(i)
Definition: RSha256.hxx:106
static const double x2[5]
static const double x1[5]
int Int_t
Definition: RtypesCore.h:43
unsigned int UInt_t
Definition: RtypesCore.h:44
bool Bool_t
Definition: RtypesCore.h:61
double Double_t
Definition: RtypesCore.h:57
float Float_t
Definition: RtypesCore.h:55
#define gDirectory
Definition: TDirectory.h:229
char name[80]
Definition: TGX11.cxx:109
#define gROOT
Definition: TROOT.h:406
char * Form(const char *fmt,...)
The Canvas class.
Definition: TCanvas.h:27
A Control Bar is a fully user configurable tool which provides fast access to frequently used operati...
Definition: TControlBar.h:22
void Show()
Show control bar.
void AddButton(TControlBarButton *button)
Add button.
void SetTextColor(const char *colorName)
Sets text color for control bar buttons, e.g.
Describe directory structure in memory.
Definition: TDirectory.h:40
virtual TList * GetListOfKeys() const
Definition: TDirectory.h:166
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:53
A TGraph is an object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
TAxis * GetXaxis()
Get the behaviour adopted by the object about the statoverflows. See EStatOverflows for more informat...
Definition: TH1.h:316
TAxis * GetYaxis()
Definition: TH1.h:317
2-D histogram with a float per channel (see TH1 documentation)}
Definition: TH2.h:251
Book space in a file, create I/O buffers, to fill them, (un)compress them.
Definition: TKey.h:28
virtual TObject * ReadObj()
To read a TObject* from the file.
Definition: TKey.cxx:738
This class displays a legend box (TPaveText) containing several legend entries.
Definition: TLegend.h:23
virtual void SetHeader(const char *header="", Option_t *option="")
Sets the header, which is the "title" that appears at the top of the legend.
Definition: TLegend.cxx:1099
virtual void Draw(Option_t *option="")
Draw this legend with its current attributes.
Definition: TLegend.cxx:423
void SetMargin(Float_t margin)
Definition: TLegend.h:69
A doubly linked list.
Definition: TList.h:44
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
Definition: TNamed.cxx:164
virtual TObject * DrawClone(Option_t *option="") const
Draw a clone of this object in the current selected pad for instance with: gROOT->SetSelectedPad(gPad...
Definition: TObject.cxx:219
Basic string class.
Definition: TString.h:131
Ssiz_t Length() const
Definition: TString.h:405
const char * Data() const
Definition: TString.h:364
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition: TString.cxx:892
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
Definition: TString.h:610
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:619
std::vector< std::tuple< TString, TString, TGraph * > > roccurvelist_t
const char * BUTTON_TYPE
std::map< TString, EfficiencyPlotWrapper * > classcanvasmap_t
static constexpr double nm
UInt_t GetListOfTitles(TDirectory *rfdir, TList &titles)
Definition: tmvaglob.cxx:636
void Initialize(Bool_t useTMVAStyle=kTRUE)
Definition: tmvaglob.cxx:176
TKey * NextKey(TIter &keyIter, TString className)
Definition: tmvaglob.cxx:357
void GetMethodTitle(TString &name, TKey *ikey)
Definition: tmvaglob.cxx:341
TFile * OpenFile(const TString &fin)
Definition: tmvaglob.cxx:192
void SetFrameStyle(TH1 *frame, Float_t scale=1.0)
Definition: tmvaglob.cxx:77
UInt_t GetListOfMethods(TList &methods, TDirectory *dir=0)
Definition: tmvaglob.cxx:583
std::vector< TString > GetClassNames(TDirectory *dir)
Definition: tmvaglob.cxx:462
void imgconv(TCanvas *c, const TString &fname)
Definition: tmvaglob.cxx:212
create variable transformations
roccurvelist_t getRocCurves(TDirectory *binDir, TString methodPrefix, TString graphNameRef)
void efficienciesMulticlass1vs1(TString dataset, TString fin)
std::vector< TString > getclassnames(TString dataset, TString fin)
void plotEfficienciesMulticlass1vs1(TString dataset, TString fin, TString baseClassname)
void plotEfficienciesMulticlass1vsRest(TString dataset, EEfficiencyPlotType plotType=EEfficiencyPlotType::kRejBvsEffS, TString filename_input="TMVAMulticlass.root")
void plotEfficienciesMulticlass(roccurvelist_t rocCurves, classcanvasmap_t classCanvasMap)
void efficienciesMulticlass1vsRest(TString dataset, TString filename_input="TMVAMulticlass.root", EEfficiencyPlotType plotType=EEfficiencyPlotType::kRejBvsEffS, Bool_t useTMVAStyle=kTRUE)
Short_t Min(Short_t a, Short_t b)
Definition: TMathBase.h:180
Definition: file.py:1
Definition: graph.py:1