Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
Factory.cxx
Go to the documentation of this file.
1// @(#)Root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3// Updated by: Omar Zapata, Kim Albertsson
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : Factory *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation (see header for description) *
12 * *
13 * Authors : *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
16 * Peter Speckmayer <peter.speckmayer@cern.ch> - CERN, Switzerland *
17 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
19 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
20 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
21 * Omar Zapata <Omar.Zapata@cern.ch> - UdeA/ITM Colombia *
22 * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
23 * Sergei Gleyzer <Sergei.Gleyzer@cern.ch> - U of Florida & CERN *
24 * Kim Albertsson <kim.albertsson@cern.ch> - LTU & CERN *
25 * *
26 * Copyright (c) 2005-2015: *
27 * CERN, Switzerland *
28 * U. of Victoria, Canada *
29 * MPI-K Heidelberg, Germany *
30 * U. of Bonn, Germany *
31 * UdeA/ITM, Colombia *
32 * U. of Florida, USA *
33 * *
34 * Redistribution and use in source and binary forms, with or without *
35 * modification, are permitted according to the terms listed in LICENSE *
36 * (http://tmva.sourceforge.net/LICENSE) *
37 **********************************************************************************/
38
39/*! \class TMVA::Factory
40\ingroup TMVA
41
42This is the main MVA steering class.
43It creates all MVA methods, and guides them through the training, testing and
44evaluation phases.
45*/
46
47#include "TMVA/Factory.h"
48
50#include "TMVA/Config.h"
51#include "TMVA/Configurable.h"
52#include "TMVA/Tools.h"
53#include "TMVA/Ranking.h"
54#include "TMVA/DataSet.h"
55#include "TMVA/IMethod.h"
56#include "TMVA/MethodBase.h"
58#include "TMVA/DataSetManager.h"
59#include "TMVA/DataSetInfo.h"
60#include "TMVA/DataLoader.h"
61#include "TMVA/MethodBoost.h"
62#include "TMVA/MethodCategory.h"
63#include "TMVA/ROCCalc.h"
64#include "TMVA/ROCCurve.h"
65#include "TMVA/MsgLogger.h"
66
67#include "TMVA/VariableInfo.h"
69
70#include "TMVA/Results.h"
74#include <list>
75#include <bitset>
76#include <set>
77
78#include "TMVA/Types.h"
79
80#include "TROOT.h"
81#include "TFile.h"
82#include "TLeaf.h"
83#include "TEventList.h"
84#include "TH2.h"
85#include "TGraph.h"
86#include "TStyle.h"
87#include "TMatrixF.h"
88#include "TMatrixDSym.h"
89#include "TMultiGraph.h"
90#include "TPrincipal.h"
91#include "TMath.h"
92#include "TSystem.h"
93#include "TCanvas.h"
94
96// const Int_t MinNoTestEvents = 1;
97
99
100#define READXML kTRUE
101
102// number of bits for bitset
103#define VIBITS 32
104
105////////////////////////////////////////////////////////////////////////////////
106/// Standard constructor.
107///
108/// - jobName : this name will appear in all weight file names produced by the MVAs
109/// - theTargetFile : output ROOT file; the test tree and all evaluation plots
110/// will be stored here
111/// - theOption : option string, e.g. "V" for verbose (all accepted options are declared below)
112
113TMVA::Factory::Factory(TString jobName, TFile *theTargetFile, TString theOption)
114 : Configurable(theOption), fTransformations("I"), fVerbose(kFALSE), fVerboseLevel(kINFO), fCorrelations(kFALSE),
115 fROC(kTRUE), fSilentFile(theTargetFile == nullptr), fJobName(jobName), fAnalysisType(Types::kClassification),
116 fModelPersistence(kTRUE)
117{
118 fName = "Factory";
119 fgTargetFile = theTargetFile;
121
122 // render silent
123 if (gTools().CheckForSilentOption(GetOptions()))
124 Log().InhibitOutput(); // make sure is silent if wanted to
125
126 // init configurable
127 SetConfigDescription("Configuration options for Factory running");
129
130 // histograms are not automatically associated with the current
131 // directory and hence don't go out of scope when closing the file
132 // TH1::AddDirectory(kFALSE);
133 Bool_t silent = kFALSE;
134#ifdef WIN32
135 // under Windows, switch progress bar and color off by default, as the typical windows shell doesn't handle these
136 // (would need different sequences..)
137 Bool_t color = kFALSE;
138 Bool_t drawProgressBar = kFALSE;
139#else
140 Bool_t color = !gROOT->IsBatch();
141 Bool_t drawProgressBar = kTRUE;
142#endif
143 DeclareOptionRef(fVerbose, "V", "Verbose flag");
144 DeclareOptionRef(fVerboseLevel = TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)");
145 AddPreDefVal(TString("Debug"));
146 AddPreDefVal(TString("Verbose"));
147 AddPreDefVal(TString("Info"));
148 DeclareOptionRef(color, "Color", "Flag for coloured screen output (default: True, if in batch mode: False)");
150 fTransformations, "Transformations",
151 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for identity, "
152 "decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation transformations");
153 DeclareOptionRef(fCorrelations, "Correlations", "boolean to show correlation in output");
154 DeclareOptionRef(fROC, "ROC", "boolean to show ROC in output");
155 DeclareOptionRef(silent, "Silent",
156 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
157 "class object (default: False)");
158 DeclareOptionRef(drawProgressBar, "DrawProgressBar",
159 "Draw progress bar to display training, testing and evaluation schedule (default: True)");
160 DeclareOptionRef(fModelPersistence, "ModelPersistence",
161 "Option to save the trained model in xml file or using serialization");
162
163 TString analysisType("Auto");
164 DeclareOptionRef(analysisType, "AnalysisType",
165 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
166 AddPreDefVal(TString("Classification"));
167 AddPreDefVal(TString("Regression"));
168 AddPreDefVal(TString("Multiclass"));
169 AddPreDefVal(TString("Auto"));
170
171 ParseOptions();
173
174 if (Verbose())
175 fLogger->SetMinType(kVERBOSE);
176 if (fVerboseLevel.CompareTo("Debug") == 0)
177 fLogger->SetMinType(kDEBUG);
178 if (fVerboseLevel.CompareTo("Verbose") == 0)
179 fLogger->SetMinType(kVERBOSE);
180 if (fVerboseLevel.CompareTo("Info") == 0)
181 fLogger->SetMinType(kINFO);
182
183 // global settings
184 gConfig().SetUseColor(color);
185 gConfig().SetSilent(silent);
186 gConfig().SetDrawProgressBar(drawProgressBar);
187
188 analysisType.ToLower();
189 if (analysisType == "classification")
191 else if (analysisType == "regression")
193 else if (analysisType == "multiclass")
195 else if (analysisType == "auto")
197
198 // Greetings();
199}
200
201////////////////////////////////////////////////////////////////////////////////
202/// Constructor.
203
205 : Configurable(theOption), fTransformations("I"), fVerbose(kFALSE), fCorrelations(kFALSE), fROC(kTRUE),
206 fSilentFile(kTRUE), fJobName(jobName), fAnalysisType(Types::kClassification), fModelPersistence(kTRUE)
207{
208 fName = "Factory";
209 fgTargetFile = nullptr;
211
212 // render silent
213 if (gTools().CheckForSilentOption(GetOptions()))
214 Log().InhibitOutput(); // make sure is silent if wanted to
215
216 // init configurable
217 SetConfigDescription("Configuration options for Factory running");
219
220 // histograms are not automatically associated with the current
221 // directory and hence don't go out of scope when closing the file
223 Bool_t silent = kFALSE;
224#ifdef WIN32
225 // under Windows, switch progress bar and color off by default, as the typical windows shell doesn't handle these
226 // (would need different sequences..)
227 Bool_t color = kFALSE;
228 Bool_t drawProgressBar = kFALSE;
229#else
230 Bool_t color = !gROOT->IsBatch();
231 Bool_t drawProgressBar = kTRUE;
232#endif
233 DeclareOptionRef(fVerbose, "V", "Verbose flag");
234 DeclareOptionRef(fVerboseLevel = TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)");
235 AddPreDefVal(TString("Debug"));
236 AddPreDefVal(TString("Verbose"));
237 AddPreDefVal(TString("Info"));
238 DeclareOptionRef(color, "Color", "Flag for coloured screen output (default: True, if in batch mode: False)");
240 fTransformations, "Transformations",
241 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for identity, "
242 "decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation transformations");
243 DeclareOptionRef(fCorrelations, "Correlations", "boolean to show correlation in output");
244 DeclareOptionRef(fROC, "ROC", "boolean to show ROC in output");
245 DeclareOptionRef(silent, "Silent",
246 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
247 "class object (default: False)");
248 DeclareOptionRef(drawProgressBar, "DrawProgressBar",
249 "Draw progress bar to display training, testing and evaluation schedule (default: True)");
250 DeclareOptionRef(fModelPersistence, "ModelPersistence",
251 "Option to save the trained model in xml file or using serialization");
252
253 TString analysisType("Auto");
254 DeclareOptionRef(analysisType, "AnalysisType",
255 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
256 AddPreDefVal(TString("Classification"));
257 AddPreDefVal(TString("Regression"));
258 AddPreDefVal(TString("Multiclass"));
259 AddPreDefVal(TString("Auto"));
260
261 ParseOptions();
263
264 if (Verbose())
265 fLogger->SetMinType(kVERBOSE);
266 if (fVerboseLevel.CompareTo("Debug") == 0)
267 fLogger->SetMinType(kDEBUG);
268 if (fVerboseLevel.CompareTo("Verbose") == 0)
269 fLogger->SetMinType(kVERBOSE);
270 if (fVerboseLevel.CompareTo("Info") == 0)
271 fLogger->SetMinType(kINFO);
272
273 // global settings
274 gConfig().SetUseColor(color);
275 gConfig().SetSilent(silent);
276 gConfig().SetDrawProgressBar(drawProgressBar);
277
278 analysisType.ToLower();
279 if (analysisType == "classification")
281 else if (analysisType == "regression")
283 else if (analysisType == "multiclass")
285 else if (analysisType == "auto")
287
288 Greetings();
289}
290
291////////////////////////////////////////////////////////////////////////////////
292/// Print welcome message.
293/// Options are: kLogoWelcomeMsg, kIsometricWelcomeMsg, kLeanWelcomeMsg
294
296{
297 gTools().ROOTVersionMessage(Log());
298 gTools().TMVAWelcomeMessage(Log(), gTools().kLogoWelcomeMsg);
299 gTools().TMVAVersionMessage(Log());
300 Log() << Endl;
301}
302
303////////////////////////////////////////////////////////////////////////////////
304/// Destructor.
305
307{
308 std::vector<TMVA::VariableTransformBase *>::iterator trfIt = fDefaultTrfs.begin();
309 for (; trfIt != fDefaultTrfs.end(); ++trfIt)
310 delete (*trfIt);
311
312 this->DeleteAllMethods();
313
314 // problem with call of REGISTER_METHOD macro ...
315 // ClassifierFactory::DestroyInstance();
316 // Types::DestroyInstance();
317 // Tools::DestroyInstance();
318 // Config::DestroyInstance();
319}
320
321////////////////////////////////////////////////////////////////////////////////
322/// Delete methods.
323
325{
326 std::map<TString, MVector *>::iterator itrMap;
327
328 for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap) {
329 MVector *methods = itrMap->second;
330 // delete methods
331 MVector::iterator itrMethod = methods->begin();
332 for (; itrMethod != methods->end(); ++itrMethod) {
333 Log() << kDEBUG << "Delete method: " << (*itrMethod)->GetName() << Endl;
334 delete (*itrMethod);
335 }
336 methods->clear();
337 delete methods;
338 }
339}
340
341////////////////////////////////////////////////////////////////////////////////
342
344{
345 fVerbose = v;
346}
347
348////////////////////////////////////////////////////////////////////////////////
349/// Book a classifier or regression method.
350
352TMVA::Factory::BookMethod(TMVA::DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption)
353{
354 if (fModelPersistence)
355 gSystem->MakeDirectory(loader->GetName()); // creating directory for DataLoader output
356
357 TString datasetname = loader->GetName();
358
359 if (fAnalysisType == Types::kNoAnalysisType) {
360 if (loader->GetDataSetInfo().GetNClasses() == 2 && loader->GetDataSetInfo().GetClassInfo("Signal") != NULL &&
361 loader->GetDataSetInfo().GetClassInfo("Background") != NULL) {
362 fAnalysisType = Types::kClassification; // default is classification
363 } else if (loader->GetDataSetInfo().GetNClasses() >= 2) {
364 fAnalysisType = Types::kMulticlass; // if two classes, but not named "Signal" and "Background"
365 } else
366 Log() << kFATAL << "No analysis type for " << loader->GetDataSetInfo().GetNClasses() << " classes and "
367 << loader->GetDataSetInfo().GetNTargets() << " regression targets." << Endl;
368 }
369
370 // booking via name; the names are translated into enums and the
371 // corresponding overloaded BookMethod is called
372
373 if (fMethodsMap.find(datasetname) != fMethodsMap.end()) {
374 if (GetMethod(datasetname, methodTitle) != 0) {
375 Log() << kFATAL << "Booking failed since method with title <" << methodTitle << "> already exists "
376 << "in with DataSet Name <" << loader->GetName() << "> " << Endl;
377 }
378 }
379
380 Log() << kHEADER << "Booking method: " << gTools().Color("bold")
381 << methodTitle
382 // << gTools().Color("reset")<<" DataSet Name: "<<gTools().Color("bold")<<loader->GetName()
383 << gTools().Color("reset") << Endl << Endl;
384
385 // interpret option string with respect to a request for boosting (i.e., BostNum > 0)
386 Int_t boostNum = 0;
387 TMVA::Configurable *conf = new TMVA::Configurable(theOption);
388 conf->DeclareOptionRef(boostNum = 0, "Boost_num", "Number of times the classifier will be boosted");
389 conf->ParseOptions();
390 delete conf;
391 // this is name of weight file directory
392 TString fileDir;
393 if (fModelPersistence) {
394 // find prefix in fWeightFileDir;
396 fileDir = prefix;
397 if (!prefix.IsNull())
398 if (fileDir[fileDir.Length() - 1] != '/')
399 fileDir += "/";
400 fileDir += loader->GetName();
401 fileDir += "/" + gConfig().GetIONames().fWeightFileDir;
402 }
403 // initialize methods
404 IMethod *im;
405 if (!boostNum) {
406 im = ClassifierFactory::Instance().Create(theMethodName.Data(), fJobName, methodTitle, loader->GetDataSetInfo(),
407 theOption);
408 } else {
409 // boosted classifier, requires a specific definition, making it transparent for the user
410 Log() << kDEBUG << "Boost Number is " << boostNum << " > 0: train boosted classifier" << Endl;
411 im = ClassifierFactory::Instance().Create("Boost", fJobName, methodTitle, loader->GetDataSetInfo(), theOption);
412 MethodBoost *methBoost = dynamic_cast<MethodBoost *>(im); // DSMTEST divided into two lines
413 if (!methBoost) { // DSMTEST
414 Log() << kFATAL << "Method with type kBoost cannot be casted to MethodCategory. /Factory" << Endl; // DSMTEST
415 return nullptr;
416 }
417 if (fModelPersistence)
418 methBoost->SetWeightFileDir(fileDir);
419 methBoost->SetModelPersistence(fModelPersistence);
420 methBoost->SetBoostedMethodName(theMethodName); // DSMTEST divided into two lines
421 methBoost->fDataSetManager = loader->GetDataSetInfo().GetDataSetManager(); // DSMTEST
422 methBoost->SetFile(fgTargetFile);
423 methBoost->SetSilentFile(IsSilentFile());
424 }
425
426 MethodBase *method = dynamic_cast<MethodBase *>(im);
427 if (method == 0)
428 return 0; // could not create method
429
430 // set fDataSetManager if MethodCategory (to enable Category to create datasetinfo objects) // DSMTEST
431 if (method->GetMethodType() == Types::kCategory) { // DSMTEST
432 MethodCategory *methCat = (dynamic_cast<MethodCategory *>(im)); // DSMTEST
433 if (!methCat) { // DSMTEST
434 Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Factory"
435 << Endl; // DSMTEST
436 return nullptr;
437 }
438 if (fModelPersistence)
439 methCat->SetWeightFileDir(fileDir);
440 methCat->SetModelPersistence(fModelPersistence);
441 methCat->fDataSetManager = loader->GetDataSetInfo().GetDataSetManager(); // DSMTEST
442 methCat->SetFile(fgTargetFile);
443 methCat->SetSilentFile(IsSilentFile());
444 } // DSMTEST
445
446 if (!method->HasAnalysisType(fAnalysisType, loader->GetDataSetInfo().GetNClasses(),
447 loader->GetDataSetInfo().GetNTargets())) {
448 Log() << kWARNING << "Method " << method->GetMethodTypeName() << " is not capable of handling ";
449 if (fAnalysisType == Types::kRegression) {
450 Log() << "regression with " << loader->GetDataSetInfo().GetNTargets() << " targets." << Endl;
451 } else if (fAnalysisType == Types::kMulticlass) {
452 Log() << "multiclass classification with " << loader->GetDataSetInfo().GetNClasses() << " classes." << Endl;
453 } else {
454 Log() << "classification with " << loader->GetDataSetInfo().GetNClasses() << " classes." << Endl;
455 }
456 return 0;
457 }
458
459 if (fModelPersistence)
460 method->SetWeightFileDir(fileDir);
461 method->SetModelPersistence(fModelPersistence);
462 method->SetAnalysisType(fAnalysisType);
463 method->SetupMethod();
464 method->ParseOptions();
465 method->ProcessSetup();
466 method->SetFile(fgTargetFile);
467 method->SetSilentFile(IsSilentFile());
468
469 // check-for-unused-options is performed; may be overridden by derived classes
470 method->CheckSetup();
471
472 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
473 MVector *mvector = new MVector;
474 fMethodsMap[datasetname] = mvector;
475 }
476 fMethodsMap[datasetname]->push_back(method);
477 return method;
478}
479
480////////////////////////////////////////////////////////////////////////////////
481/// Books an MVA method. The option configuration string is specific to each MVA;
482/// the TString field "methodTitle" serves to define (and distinguish)
483/// several instances of a given MVA, e.g., when one wants to compare the
484/// performance of various configurations.
485
487TMVA::Factory::BookMethod(TMVA::DataLoader *loader, Types::EMVA theMethod, TString methodTitle, TString theOption)
488{
489 return BookMethod(loader, Types::Instance().GetMethodName(theMethod), methodTitle, theOption);
490}
491
492////////////////////////////////////////////////////////////////////////////////
493/// Adds an already constructed method to be managed by this factory.
494///
495/// \note Private.
496/// \note Know what you are doing when using this method. The method that you
497/// are loading could be trained already.
498///
499
502{
// Re-creates a method of type `methodType` from `weightfile`, configures it with
// this factory's settings, reads its state from the weight file and registers it
// under the DataLoader's dataset name. Returns nullptr if the created object is
// not a MethodBase.
503 TString datasetname = loader->GetName();
504 std::string methodTypeName = std::string(Types::Instance().GetMethodName(methodType).Data());
505 DataSetInfo &dsi = loader->GetDataSetInfo();
506
507 IMethod *im = ClassifierFactory::Instance().Create(methodTypeName, dsi, weightfile);
508 MethodBase *method = (dynamic_cast<MethodBase *>(im));
509
// NOTE(review): when the cast fails, `im` is not released before returning;
// possible leak if Create() hands ownership to the caller -- confirm.
510 if (method == nullptr)
511 return nullptr;
512
// NOTE(review): category methods are only reported with kERROR; loading
// continues below regardless.
513 if (method->GetMethodType() == Types::kCategory) {
514 Log() << kERROR << "Cannot handle category methods for now." << Endl;
515 }
516
517 TString fileDir;
518 if (fModelPersistence) {
519 // find prefix in fWeightFileDir;
521 fileDir = prefix;
522 if (!prefix.IsNull())
523 if (fileDir[fileDir.Length() - 1] != '/')
524 fileDir += "/";
// NOTE(review): this plain assignment overwrites the prefix copied into
// fileDir just above, so the prefix never reaches the weight-file path.
// BookMethod appends here instead (`fileDir += loader->GetName();`);
// this line likely should match -- confirm.
525 fileDir = loader->GetName();
526 fileDir += "/" + gConfig().GetIONames().fWeightFileDir;
527 }
528
529 if (fModelPersistence)
530 method->SetWeightFileDir(fileDir);
531 method->SetModelPersistence(fModelPersistence);
532 method->SetAnalysisType(fAnalysisType);
533 method->SetupMethod();
534 method->SetFile(fgTargetFile);
535 method->SetSilentFile(IsSilentFile());
536
538
539 // read weight file
540 method->ReadStateFromFile();
541
542 // method->CheckSetup();
543
// The duplicate-title check runs only after the weight file has been read,
// using the name stored in the loaded method.
544 TString methodTitle = method->GetName();
545 if (HasMethod(datasetname, methodTitle) != 0) {
546 Log() << kFATAL << "Booking failed since method with title <" << methodTitle << "> already exists "
547 << "in with DataSet Name <" << loader->GetName() << "> " << Endl;
548 }
549
550 Log() << kINFO << "Booked classifier \"" << method->GetMethodName() << "\" of type: \""
551 << method->GetMethodTypeName() << "\"" << Endl;
552
// Register the method, creating the per-dataset method vector on first use.
553 if (fMethodsMap.count(datasetname) == 0) {
554 MVector *mvector = new MVector;
555 fMethodsMap[datasetname] = mvector;
556 }
557
558 fMethodsMap[datasetname]->push_back(method);
559
560 return method;
561}
562
563////////////////////////////////////////////////////////////////////////////////
564/// Returns pointer to MVA that corresponds to given method title.
565
566TMVA::IMethod *TMVA::Factory::GetMethod(const TString &datasetname, const TString &methodTitle) const
567{
568 if (fMethodsMap.find(datasetname) == fMethodsMap.end())
569 return 0;
570
571 MVector *methods = fMethodsMap.find(datasetname)->second;
572
573 MVector::const_iterator itrMethod;
574 //
575 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
576 MethodBase *mva = dynamic_cast<MethodBase *>(*itrMethod);
577 if ((mva->GetMethodName()) == methodTitle)
578 return mva;
579 }
580 return 0;
581}
582
583////////////////////////////////////////////////////////////////////////////////
584/// Checks whether a given method name is defined for a given dataset.
585
586Bool_t TMVA::Factory::HasMethod(const TString &datasetname, const TString &methodTitle) const
587{
588 if (fMethodsMap.find(datasetname) == fMethodsMap.end())
589 return 0;
590
591 std::string methodName = methodTitle.Data();
592 auto isEqualToMethodName = [&methodName](TMVA::IMethod *m) { return (0 == methodName.compare(m->GetName())); };
593
594 TMVA::Factory::MVector *methods = this->fMethodsMap.at(datasetname);
595 Bool_t isMethodNameExisting = std::any_of(methods->begin(), methods->end(), isEqualToMethodName);
596
597 return isMethodNameExisting;
598}
599
600////////////////////////////////////////////////////////////////////////////////
601
603{
604 RootBaseDir()->cd();
605
606 if (!RootBaseDir()->GetDirectory(fDataSetInfo.GetName()))
607 RootBaseDir()->mkdir(fDataSetInfo.GetName());
608 else
609 return; // loader is now in the output file, we dont need to save again
610
611 RootBaseDir()->cd(fDataSetInfo.GetName());
612 fDataSetInfo.GetDataSet(); // builds dataset (including calculation of correlation matrix)
613
614 // correlation matrix of the default DS
615 const TMatrixD *m(0);
616 const TH2 *h(0);
617
618 if (fAnalysisType == Types::kMulticlass) {
619 for (UInt_t cls = 0; cls < fDataSetInfo.GetNClasses(); cls++) {
620 m = fDataSetInfo.CorrelationMatrix(fDataSetInfo.GetClassInfo(cls)->GetName());
621 h = fDataSetInfo.CreateCorrelationMatrixHist(
622 m, TString("CorrelationMatrix") + fDataSetInfo.GetClassInfo(cls)->GetName(),
623 TString("Correlation Matrix (") + fDataSetInfo.GetClassInfo(cls)->GetName() + TString(")"));
624 if (h != 0) {
625 h->Write();
626 delete h;
627 }
628 }
629 } else {
630 m = fDataSetInfo.CorrelationMatrix("Signal");
631 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrixS", "Correlation Matrix (signal)");
632 if (h != 0) {
633 h->Write();
634 delete h;
635 }
636
637 m = fDataSetInfo.CorrelationMatrix("Background");
638 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrixB", "Correlation Matrix (background)");
639 if (h != 0) {
640 h->Write();
641 delete h;
642 }
643
644 m = fDataSetInfo.CorrelationMatrix("Regression");
645 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrix", "Correlation Matrix");
646 if (h != 0) {
647 h->Write();
648 delete h;
649 }
650 }
651
652 // some default transformations to evaluate
653 // NOTE: all transformations are destroyed after this test
654 TString processTrfs = "I"; //"I;N;D;P;U;G,D;"
655
656 // plus some user defined transformations
657 processTrfs = fTransformations;
658
659 // remove any trace of identity transform - if given (avoid to apply it twice)
660 std::vector<TMVA::TransformationHandler *> trfs;
661 TransformationHandler *identityTrHandler = 0;
662
663 std::vector<TString> trfsDef = gTools().SplitString(processTrfs, ';');
664 std::vector<TString>::iterator trfsDefIt = trfsDef.begin();
665 for (; trfsDefIt != trfsDef.end(); ++trfsDefIt) {
666 trfs.push_back(new TMVA::TransformationHandler(fDataSetInfo, "Factory"));
667 TString trfS = (*trfsDefIt);
668
669 // Log() << kINFO << Endl;
670 Log() << kDEBUG << "current transformation string: '" << trfS.Data() << "'" << Endl;
671 TMVA::CreateVariableTransforms(trfS, fDataSetInfo, *(trfs.back()), Log());
672
673 if (trfS.BeginsWith('I'))
674 identityTrHandler = trfs.back();
675 }
676
677 const std::vector<Event *> &inputEvents = fDataSetInfo.GetDataSet()->GetEventCollection();
678
679 // apply all transformations
680 std::vector<TMVA::TransformationHandler *>::iterator trfIt = trfs.begin();
681
682 for (; trfIt != trfs.end(); ++trfIt) {
683 // setting a Root dir causes the variables distributions to be saved to the root file
684 (*trfIt)->SetRootDir(RootBaseDir()->GetDirectory(fDataSetInfo.GetName())); // every dataloader have its own dir
685 (*trfIt)->CalcTransformations(inputEvents);
686 }
687 if (identityTrHandler)
688 identityTrHandler->PrintVariableRanking();
689
690 // clean up
691 for (trfIt = trfs.begin(); trfIt != trfs.end(); ++trfIt)
692 delete *trfIt;
693}
694
695////////////////////////////////////////////////////////////////////////////////
696/// Iterates through all booked methods and sees if they use parameter tuning and if so
697/// does just that, i.e.\ calls "Method::Train()" for different parameter settings and
698/// keeps in mind the "optimal one"...\ and that's the one that will later on be used
699/// in the main training loop.
700
701std::map<TString, Double_t> TMVA::Factory::OptimizeAllMethods(TString fomType, TString fitType)
702{
703
704 std::map<TString, MVector *>::iterator itrMap;
705 std::map<TString, Double_t> TunedParameters;
706 for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap) {
707 MVector *methods = itrMap->second;
708
709 MVector::iterator itrMethod;
710
711 // iterate over methods and optimize
712 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
714 MethodBase *mva = dynamic_cast<MethodBase *>(*itrMethod);
715 if (!mva) {
716 Log() << kFATAL << "Dynamic cast to MethodBase failed" << Endl;
717 return TunedParameters;
718 }
719
721 Log() << kWARNING << "Method " << mva->GetMethodName() << " not trained (training tree has less entries ["
722 << mva->Data()->GetNTrainingEvents() << "] than required [" << MinNoTrainingEvents << "]" << Endl;
723 continue;
724 }
725
726 Log() << kINFO << "Optimize method: " << mva->GetMethodName() << " for "
727 << (fAnalysisType == Types::kRegression
728 ? "Regression"
729 : (fAnalysisType == Types::kMulticlass ? "Multiclass classification" : "Classification"))
730 << Endl;
731
732 TunedParameters = mva->OptimizeTuningParameters(fomType, fitType);
733 Log() << kINFO << "Optimization of tuning parameters finished for Method:" << mva->GetName() << Endl;
734 }
735 }
736
737 return TunedParameters;
738}
739
740////////////////////////////////////////////////////////////////////////////////
741/// Private method to generate a ROCCurve instance for a given method.
742/// Handles the conversion from TMVA ResultSet to a format the ROCCurve class
743/// understands.
744///
745/// \note You own the returned pointer.
746///
747
750{
751 return GetROC((TString)loader->GetName(), theMethodName, iClass, type);
752}
753
754////////////////////////////////////////////////////////////////////////////////
755/// Private method to generate a ROCCurve instance for a given method.
756/// Handles the conversion from TMVA ResultSet to a format the ROCCurve class
757/// understands.
758///
759/// \note You own the retured pointer.
760///
761
763{
764 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
765 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
766 return nullptr;
767 }
768
769 if (!this->HasMethod(datasetname, theMethodName)) {
770 Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data())
771 << Endl;
772 return nullptr;
773 }
774
775 std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
776 if (allowedAnalysisTypes.count(this->fAnalysisType) == 0) {
777 Log() << kERROR << Form("Can only generate ROC curves for analysis type kClassification and kMulticlass.")
778 << Endl;
779 return nullptr;
780 }
781
782 TMVA::MethodBase *method = dynamic_cast<TMVA::MethodBase *>(this->GetMethod(datasetname, theMethodName));
783 TMVA::DataSet *dataset = method->Data();
784 dataset->SetCurrentType(type);
785 TMVA::Results *results = dataset->GetResults(theMethodName, type, this->fAnalysisType);
786
787 UInt_t nClasses = method->DataInfo().GetNClasses();
788 if (this->fAnalysisType == Types::kMulticlass && iClass >= nClasses) {
789 Log() << kERROR
790 << Form("Given class number (iClass = %i) does not exist. There are %i classes in dataset.", iClass,
791 nClasses)
792 << Endl;
793 return nullptr;
794 }
795
796 TMVA::ROCCurve *rocCurve = nullptr;
797 if (this->fAnalysisType == Types::kClassification) {
798
799 std::vector<Float_t> *mvaRes = dynamic_cast<ResultsClassification *>(results)->GetValueVector();
800 std::vector<Bool_t> *mvaResTypes = dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
801 std::vector<Float_t> mvaResWeights;
802
803 auto eventCollection = dataset->GetEventCollection(type);
804 mvaResWeights.reserve(eventCollection.size());
805 for (auto ev : eventCollection) {
806 mvaResWeights.push_back(ev->GetWeight());
807 }
808
809 rocCurve = new TMVA::ROCCurve(*mvaRes, *mvaResTypes, mvaResWeights);
810
811 } else if (this->fAnalysisType == Types::kMulticlass) {
812 std::vector<Float_t> mvaRes;
813 std::vector<Bool_t> mvaResTypes;
814 std::vector<Float_t> mvaResWeights;
815
816 std::vector<std::vector<Float_t>> *rawMvaRes = dynamic_cast<ResultsMulticlass *>(results)->GetValueVector();
817
818 // Vector transpose due to values being stored as
819 // [ [0, 1, 2], [0, 1, 2], ... ]
820 // in ResultsMulticlass::GetValueVector.
821 mvaRes.reserve(rawMvaRes->size());
822 for (auto item : *rawMvaRes) {
823 mvaRes.push_back(item[iClass]);
824 }
825
826 auto eventCollection = dataset->GetEventCollection(type);
827 mvaResTypes.reserve(eventCollection.size());
828 mvaResWeights.reserve(eventCollection.size());
829 for (auto ev : eventCollection) {
830 mvaResTypes.push_back(ev->GetClass() == iClass);
831 mvaResWeights.push_back(ev->GetWeight());
832 }
833
834 rocCurve = new TMVA::ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
835 }
836
837 return rocCurve;
838}
839
840////////////////////////////////////////////////////////////////////////////////
841/// Calculate the integral of the ROC curve, also known as the area under curve
842/// (AUC), for a given method.
843///
844/// Argument iClass specifies the class to generate the ROC curve in a
845/// multiclass setting. It is ignored for binary classification.
846///
847
850{
851 return GetROCIntegral((TString)loader->GetName(), theMethodName, iClass, type);
852}
853
854////////////////////////////////////////////////////////////////////////////////
855/// Calculate the integral of the ROC curve, also known as the area under curve
856/// (AUC), for a given method.
857///
858/// Argument iClass specifies the class to generate the ROC curve in a
859/// multiclass setting. It is ignored for binary classification.
860///
/// Returns 0 (after logging at kERROR level) if the dataset is not in the
/// methods map, the method is not booked for that dataset, or the Factory's
/// analysis type is neither kClassification nor kMulticlass. If GetROC() does
/// not produce a ROCCurve, the failure is logged at kFATAL level.
/// The ROCCurve built here is temporary and is deleted before returning.
861
863{
864 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
865 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
866 return 0;
867 }
868
869 if (!this->HasMethod(datasetname, theMethodName)) {
870 Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data())
871 << Endl;
872 return 0;
873 }
874
875 std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
876 if (allowedAnalysisTypes.count(this->fAnalysisType) == 0) {
877 Log() << kERROR << Form("Can only generate ROC integral for analysis type kClassification. and kMulticlass.")
878 << Endl;
879 return 0;
880 }
881
882 TMVA::ROCCurve *rocCurve = GetROC(datasetname, theMethodName, iClass, type);
883 if (!rocCurve) {
884 Log() << kFATAL
885 << Form("ROCCurve object was not created in Method = %s not found with Dataset = %s ", theMethodName.Data(),
886 datasetname.Data())
887 << Endl;
888 return 0;
889 }
890
// NOTE(review): `npoints` is declared on source line 891, which is absent from
// this copy of the file; confirm its definition against the original source.
892 Double_t rocIntegral = rocCurve->GetROCIntegral(npoints);
893 delete rocCurve; // the ROCCurve from GetROC() is only needed for this integral
894
895 return rocIntegral;
896}
897
898////////////////////////////////////////////////////////////////////////////////
/// Return the ROC curve of a given method as a TGraph.
///
899/// Argument iClass specifies the class to generate the ROC curve in a
900/// multiclass setting. It is ignored for binary classification.
901///
902/// Returns a ROC graph for a given method, or nullptr on error.
903///
904/// Note: Evaluation of the given method must have been run prior to ROC
905/// generation through Factory::EvaluateAllMethods.
906///
907/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
908/// and the others considered background. This is ok in binary classification
909/// but in multi class classification, the ROC surface is an N dimensional
910/// shape, where N is number of classes - 1.
///
/// Convenience overload taking a DataLoader: it forwards to the overload that
/// takes the dataset name, using loader->GetName() as that name.
911
912TGraph *TMVA::Factory::GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles, UInt_t iClass,
914{
915 return GetROCCurve((TString)loader->GetName(), theMethodName, setTitles, iClass, type);
916}
917
918////////////////////////////////////////////////////////////////////////////////
/// Return the ROC curve of a given method in a given dataset as a TGraph.
///
919/// Argument iClass specifies the class to generate the ROC curve in a
920/// multiclass setting. It is ignored for binary classification.
921///
922/// Returns a ROC graph for a given method, or nullptr on error.
923///
924/// Note: Evaluation of the given method must have been run prior to ROC
925/// generation through Factory::EvaluateAllMethods.
926///
927/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
928/// and the others considered background. This is ok in binary classification
929/// but in multi class classification, the ROC surface is an N dimensional
930/// shape, where N is number of classes - 1.
///
/// Error cases: nullptr is returned after logging at kERROR level if the
/// dataset is unknown, the method is not booked for it, or the analysis type is
/// neither kClassification nor kMulticlass; a missing ROCCurve is logged at
/// kFATAL level. If setTitles is true, axis titles and a title containing the
/// method name are set on the returned graph. The returned graph is a clone
/// that the Factory does not keep, so presumably the caller owns it.
931
932TGraph *TMVA::Factory::GetROCCurve(TString datasetname, TString theMethodName, Bool_t setTitles, UInt_t iClass,
934{
935 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
936 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
937 return nullptr;
938 }
939
940 if (!this->HasMethod(datasetname, theMethodName)) {
941 Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data())
942 << Endl;
943 return nullptr;
944 }
945
946 std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
947 if (allowedAnalysisTypes.count(this->fAnalysisType) == 0) {
948 Log() << kERROR << Form("Can only generate ROC curves for analysis type kClassification and kMulticlass.")
949 << Endl;
950 return nullptr;
951 }
952
953 TMVA::ROCCurve *rocCurve = GetROC(datasetname, theMethodName, iClass, type);
954 TGraph *graph = nullptr;
955
956 if (!rocCurve) {
957 Log() << kFATAL
958 << Form("ROCCurve object was not created in Method = %s not found with Dataset = %s ", theMethodName.Data(),
959 datasetname.Data())
960 << Endl;
961 return nullptr;
962 }
963
964 graph = (TGraph *)rocCurve->GetROCCurve()->Clone(); // clone so the graph survives deleting rocCurve below
965 delete rocCurve;
966
967 if (setTitles) {
968 graph->GetYaxis()->SetTitle("Background rejection (Specificity)");
969 graph->GetXaxis()->SetTitle("Signal efficiency (Sensitivity)");
970 graph->SetTitle(TString::Format("Signal efficiency vs. Background rejection (%s)", theMethodName.Data()).Data());
971 }
972
973 return graph;
974}
975
976////////////////////////////////////////////////////////////////////////////////
977/// Generate a collection of graphs, for all methods for a given class. Suitable
978/// for comparing method performance.
979///
980/// Argument iClass specifies the class to generate the ROC curve in a
981/// multiclass setting. It is ignored for binary classification.
982///
983/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
984/// and the others considered background. This is ok in binary classification
985/// but in multi class classification, the ROC surface is an N dimensional
986/// shape, where N is number of classes - 1.
///
/// Convenience overload taking a DataLoader: it forwards to the overload that
/// takes the dataset name, using loader->GetName() as that name.
987
989{
990 return GetROCCurveAsMultiGraph((TString)loader->GetName(), iClass, type);
991}
992
993////////////////////////////////////////////////////////////////////////////////
994/// Generate a collection of graphs, for all methods for a given class. Suitable
995/// for comparing method performance.
996///
997/// Argument iClass specifies the class to generate the ROC curve in a
998/// multiclass setting. It is ignored for binary classification.
999///
1000/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
1001/// and the others considered background. This is ok in binary classification
1002/// but in in multi class classification, the ROC surface is an N dimensional
1003/// shape, where N is number of classes - 1.
1004
1006{
1007 UInt_t line_color = 1;
1008
1009 TMultiGraph *multigraph = new TMultiGraph();
1010
1011 MVector *methods = fMethodsMap[datasetname.Data()];
1012 for (auto *method_raw : *methods) {
1013 TMVA::MethodBase *method = dynamic_cast<TMVA::MethodBase *>(method_raw);
1014 if (method == nullptr) {
1015 continue;
1016 }
1017
1018 TString methodName = method->GetMethodName();
1019 UInt_t nClasses = method->DataInfo().GetNClasses();
1020
1021 if (this->fAnalysisType == Types::kMulticlass && iClass >= nClasses) {
1022 Log() << kERROR
1023 << Form("Given class number (iClass = %i) does not exist. There are %i classes in dataset.", iClass,
1024 nClasses)
1025 << Endl;
1026 continue;
1027 }
1028
1029 TString className = method->DataInfo().GetClassInfo(iClass)->GetName();
1030
1031 TGraph *graph = this->GetROCCurve(datasetname, methodName, false, iClass, type);
1032 graph->SetTitle(methodName);
1033
1034 graph->SetLineWidth(2);
1035 graph->SetLineColor(line_color++);
1036 graph->SetFillColor(10);
1037
1038 multigraph->Add(graph);
1039 }
1040
1041 if (multigraph->GetListOfGraphs() == nullptr) {
1042 Log() << kERROR << Form("No metohds have class %i defined.", iClass) << Endl;
1043 return nullptr;
1044 }
1045
1046 return multigraph;
1047}
1048
1049////////////////////////////////////////////////////////////////////////////////
1050/// Draws ROC curves for all methods booked with the factory for a given class
1051/// onto a canvas.
1052///
1053/// Argument iClass specifies the class to generate the ROC curve in a
1054/// multiclass setting. It is ignored for binary classification.
1055///
1056/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
1057/// and the others considered background. This is ok in binary classification
1058/// but in multi class classification, the ROC surface is an N dimensional
1059/// shape, where N is number of classes - 1.
///
/// Convenience overload taking a DataLoader: it forwards to the canvas-producing
/// overload that takes the dataset name, using loader->GetName() as that name.
1060
1062{
1063 return GetROCCurve((TString)loader->GetName(), iClass, type);
1064}
1065
1066////////////////////////////////////////////////////////////////////////////////
1067/// Draws ROC curves for all methods booked with the factory for a given class.
1068///
1069/// Argument iClass specifies the class to generate the ROC curve in a
1070/// multiclass setting. It is ignored for binary classification.
1071///
1072/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
1073/// and the others considered background. This is ok in binary classification
1074/// but in multi class classification, the ROC surface is an N dimensional
1075/// shape, where N is number of classes - 1.
///
/// Returns a newly allocated TCanvas named "ROCCurve <dataset> class <iClass>",
/// or 0 (after logging at kERROR level) if the dataset is unknown. If
/// GetROCCurveAsMultiGraph() yields nullptr, the canvas is still returned, with
/// only its grid set and nothing drawn on it.
1076
1078{
1079 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
1080 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
1081 return 0;
1082 }
1083
1084 TString name = TString::Format("ROCCurve %s class %i", datasetname.Data(), iClass);
1085 TCanvas *canvas = new TCanvas(name.Data(), "ROC Curve", 200, 10, 700, 500);
1086 canvas->SetGrid();
1087
1088 TMultiGraph *multigraph = this->GetROCCurveAsMultiGraph(datasetname, iClass, type);
1089
1090 if (multigraph) {
1091 multigraph->Draw("AL");
1092
1093 multigraph->GetYaxis()->SetTitle("Background rejection (Specificity)");
1094 multigraph->GetXaxis()->SetTitle("Signal efficiency (Sensitivity)");
1095
1096 TString titleString = TString::Format("Signal efficiency vs. Background rejection");
1097 if (this->fAnalysisType == Types::kMulticlass) {
1098 titleString = TString::Format("%s (Class=%i)", titleString.Data(), iClass);
1099 }
1100
1101 // Workaround for TMultigraph not drawing title correctly.
1102 multigraph->GetHistogram()->SetTitle(titleString.Data());
1103 multigraph->SetTitle(titleString.Data());
1104
1105 canvas->BuildLegend(0.15, 0.15, 0.35, 0.3, "MVA Method");
1106 }
1107
1108 return canvas;
1109}
1110
1111////////////////////////////////////////////////////////////////////////////////
1112/// Iterates through all booked methods and calls training
///
/// For each dataset in fMethodsMap this:
///  - logs at kFATAL level if a method has no input data, a regression has no
///    target, or a (multi)classification has fewer than two classes;
///  - trains every MethodBase, except that methods with fewer than
///    MinNoTrainingEvents training events are skipped with a warning;
///  - for non-regression analyses, prints each trained method's variable ranking;
///  - unless the output file is silent, saves each method's training history;
///  - if fModelPersistence is set, deletes each trained method and recreates it
///    from its weight file, so testing uses the persisted state.
1113
1115{
1116 Log() << kHEADER << gTools().Color("bold") << "Train all methods" << gTools().Color("reset") << Endl;
1117 // iterates over all MVAs that have been booked, and calls their training methods
1118
1119 // don't do anything if no method booked
1120 if (fMethodsMap.empty()) {
1121 Log() << kINFO << "...nothing found to train" << Endl;
1122 return;
1123 }
1124
1125 // here the training starts
1126 // Log() << kINFO << " " << Endl;
1127 Log() << kDEBUG << "Train all methods for "
1128 << (fAnalysisType == Types::kRegression
1129 ? "Regression"
1130 : (fAnalysisType == Types::kMulticlass ? "Multiclass" : "Classification"))
1131 << " ..." << Endl;
1132
1133 std::map<TString, MVector *>::iterator itrMap;
1134
1135 for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap) {
1136 MVector *methods = itrMap->second;
1137 MVector::iterator itrMethod;
1138
1139 // iterate over methods and train
1140 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1142 MethodBase *mva = dynamic_cast<MethodBase *>(*itrMethod);
1143
1144 if (mva == 0)
1145 continue;
1146
1147 if (mva->DataInfo().GetDataSetManager()->DataInput().GetEntries() <=
1148 1) { // 0 entries --> 0 events, 1 entry --> dynamical dataset (or one entry)
1149 Log() << kFATAL << "No input data for the training provided!" << Endl;
1150 }
1151
1152 if (fAnalysisType == Types::kRegression && mva->DataInfo().GetNTargets() < 1)
1153 Log() << kFATAL << "You want to do regression training without specifying a target." << Endl;
1154 else if ((fAnalysisType == Types::kMulticlass || fAnalysisType == Types::kClassification) &&
1155 mva->DataInfo().GetNClasses() < 2)
1156 Log() << kFATAL << "You want to do classification training, but specified less than two classes." << Endl;
1157
1158 // first print some information about the default dataset
1159 if (!IsSilentFile())
1160 WriteDataInformation(mva->fDataSetInfo);
1161
// NOTE(review): source line 1162 is absent from this copy. Judging from the
// warning below and the ranking loop's `>= MinNoTrainingEvents` test, it
// presumably opens `if (mva->Data()->GetNTrainingEvents() < MinNoTrainingEvents) {`,
// which is closed at 1166 -- confirm against the original file.
1163 Log() << kWARNING << "Method " << mva->GetMethodName() << " not trained (training tree has less entries ["
1164 << mva->Data()->GetNTrainingEvents() << "] than required [" << MinNoTrainingEvents << "]" << Endl;
1165 continue;
1166 }
1167
1168 Log() << kHEADER << "Train method: " << mva->GetMethodName() << " for "
1169 << (fAnalysisType == Types::kRegression
1170 ? "Regression"
1171 : (fAnalysisType == Types::kMulticlass ? "Multiclass classification" : "Classification"))
1172 << Endl << Endl;
1173 mva->TrainMethod();
1174 Log() << kHEADER << "Training finished" << Endl << Endl;
1175 }
1176
1177 if (fAnalysisType != Types::kRegression) {
1178
1179 // variable ranking
1180 // Log() << Endl;
1181 Log() << kINFO << "Ranking input variables (method specific)..." << Endl;
1182 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1183 MethodBase *mva = dynamic_cast<MethodBase *>(*itrMethod);
1184 if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
1185
1186 // create and print ranking
1187 const Ranking *ranking = (*itrMethod)->CreateRanking();
1188 if (ranking != 0)
1189 ranking->Print();
1190 else
1191 Log() << kINFO << "No variable ranking supplied by classifier: "
1192 << dynamic_cast<MethodBase *>(*itrMethod)->GetMethodName() << Endl;
1193 }
1194 }
1195 }
1196
1197 // save training history in case we are not in the silent mode
1198 if (!IsSilentFile()) {
1199 for (UInt_t i = 0; i < methods->size(); i++) {
1200 MethodBase *m = dynamic_cast<MethodBase *>((*methods)[i]);
1201 if (m == 0)
1202 continue;
1203 m->BaseDir()->cd();
1204 m->fTrainHistory.SaveHistory(m->GetMethodName());
1205 }
1206 }
1207
1208 // delete all methods and recreate them from weight file - this ensures that the application
1209 // of the methods (in TMVAClassificationApplication) is consistent with the results obtained
1210 // in the testing
1211 // Log() << Endl;
1212 if (fModelPersistence) {
1213
1214 Log() << kHEADER << "=== Destroy and recreate all methods via weight files for testing ===" << Endl << Endl;
1215
1216 if (!IsSilentFile())
1217 RootBaseDir()->cd();
1218
1219 // iterate through all booked methods
1220 for (UInt_t i = 0; i < methods->size(); i++) {
1221
1222 MethodBase *m = dynamic_cast<MethodBase *>((*methods)[i]);
1223 if (m == nullptr)
1224 continue;
1225
1226 TMVA::Types::EMVA methodType = m->GetMethodType();
1227 TString weightfile = m->GetWeightFileName();
1228
1229 // decide if .txt or .xml file should be read:
1230 if (READXML)
1231 weightfile.ReplaceAll(".txt", ".xml");
1232
// NOTE(review): dataSetInfo is used after `delete m` below; this presumably
// relies on the DataSetInfo not being owned by the method -- confirm.
1233 DataSetInfo &dataSetInfo = m->DataInfo();
1234 TString testvarName = m->GetTestvarName();
1235 delete m; // itrMethod[i];
1236
1237 // recreate
1238 m = dynamic_cast<MethodBase *>(ClassifierFactory::Instance().Create(
1239 Types::Instance().GetMethodName(methodType).Data(), dataSetInfo, weightfile));
// NOTE(review): the result of the dynamic_cast above is not null-checked
// before m->GetMethodType() is called.
1240 if (m->GetMethodType() == Types::kCategory) {
1241 MethodCategory *methCat = (dynamic_cast<MethodCategory *>(m));
1242 if (!methCat)
1243 Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Factory" << Endl;
1244 else
1245 methCat->fDataSetManager = m->DataInfo().GetDataSetManager();
1246 }
1247 // ToDo, Do we need to fill the DataSetManager of MethodBoost here too?
1248
1249 TString wfileDir = m->DataInfo().GetName();
1250 wfileDir += "/" + gConfig().GetIONames().fWeightFileDir;
1251 m->SetWeightFileDir(wfileDir);
1252 m->SetModelPersistence(fModelPersistence);
1253 m->SetSilentFile(IsSilentFile());
1254 m->SetAnalysisType(fAnalysisType);
1255 m->SetupMethod();
1256 m->ReadStateFromFile();
1257 m->SetTestvarName(testvarName);
1258
1259 // replace trained method by newly created one (from weight file) in methods vector
1260 (*methods)[i] = m;
1261 }
1262 }
1263 }
1264}
1265
1266////////////////////////////////////////////////////////////////////////////////
1267/// Evaluates all booked methods on the testing data and adds the output to the
1268/// Results in the corresponding DataSet.
1269///
/// The analysis type passed to AddOutput() is taken from each method
/// (GetAnalysisType()), not from the Factory. Entries that are not MethodBase
/// instances are skipped.
1270
1272{
1273 Log() << kHEADER << gTools().Color("bold") << "Test all methods" << gTools().Color("reset") << Endl;
1274
1275 // don't do anything if no method booked
1276 if (fMethodsMap.empty()) {
1277 Log() << kINFO << "...nothing found to test" << Endl;
1278 return;
1279 }
1280 std::map<TString, MVector *>::iterator itrMap;
1281
1282 for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap) {
1283 MVector *methods = itrMap->second;
1284 MVector::iterator itrMethod;
1285
1286 // iterate over methods and test
1287 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1289 MethodBase *mva = dynamic_cast<MethodBase *>(*itrMethod);
1290 if (mva == 0)
1291 continue;
1292 Types::EAnalysisType analysisType = mva->GetAnalysisType();
1293 Log() << kHEADER << "Test method: " << mva->GetMethodName() << " for "
1294 << (analysisType == Types::kRegression
1295 ? "Regression"
1296 : (analysisType == Types::kMulticlass ? "Multiclass classification" : "Classification"))
1297 << " performance" << Endl << Endl;
1298 mva->AddOutput(Types::kTesting, analysisType);
1299 }
1300 }
1301}
1302
1303////////////////////////////////////////////////////////////////////////////////
1304
1305void TMVA::Factory::MakeClass(const TString &datasetname, const TString &methodTitle) const
1306{
1307 if (methodTitle != "") {
1308 IMethod *method = GetMethod(datasetname, methodTitle);
1309 if (method)
1310 method->MakeClass();
1311 else {
1312 Log() << kWARNING << "<MakeClass> Could not find classifier \"" << methodTitle << "\" in list" << Endl;
1313 }
1314 } else {
1315
1316 // no classifier specified, print all help messages
1317 MVector *methods = fMethodsMap.find(datasetname)->second;
1318 MVector::const_iterator itrMethod;
1319 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1320 MethodBase *method = dynamic_cast<MethodBase *>(*itrMethod);
1321 if (method == 0)
1322 continue;
1323 Log() << kINFO << "Make response class for classifier: " << method->GetMethodName() << Endl;
1324 method->MakeClass();
1325 }
1326 }
1327}
1328
1329////////////////////////////////////////////////////////////////////////////////
1330/// Print predefined help message of classifier.
1331/// Iterate over methods and test.
1332
1333void TMVA::Factory::PrintHelpMessage(const TString &datasetname, const TString &methodTitle) const
1334{
1335 if (methodTitle != "") {
1336 IMethod *method = GetMethod(datasetname, methodTitle);
1337 if (method)
1338 method->PrintHelpMessage();
1339 else {
1340 Log() << kWARNING << "<PrintHelpMessage> Could not find classifier \"" << methodTitle << "\" in list" << Endl;
1341 }
1342 } else {
1343
1344 // no classifier specified, print all help messages
1345 MVector *methods = fMethodsMap.find(datasetname)->second;
1346 MVector::const_iterator itrMethod;
1347 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1348 MethodBase *method = dynamic_cast<MethodBase *>(*itrMethod);
1349 if (method == 0)
1350 continue;
1351 Log() << kINFO << "Print help message for classifier: " << method->GetMethodName() << Endl;
1352 method->PrintHelpMessage();
1353 }
1354 }
1355}
1356
1357////////////////////////////////////////////////////////////////////////////////
1358/// Iterates over all MVA input variables and evaluates them.
///
/// For each input variable of loader->GetDataSetInfo(), a method of type
/// "Variable" is booked on the loader, so the raw variable can later be
/// evaluated like a classifier. ":V" is appended to the option string when
/// `options` contains "V".
/// NOTE(review): the signature (source line 1360) and the declarations on
/// source lines 1363 and 1366 (including `s`, presumably the per-variable
/// option string) are absent from this copy -- confirm against the original.
1359
1361{
1362 Log() << kINFO << "Evaluating all variables..." << Endl;
1364
1365 for (UInt_t i = 0; i < loader->GetDataSetInfo().GetNVariables(); i++) {
1367 if (options.Contains("V"))
1368 s += ":V";
1369 this->BookMethod(loader, "Variable", s);
1370 }
1371}
1372
1373////////////////////////////////////////////////////////////////////////////////
1374/// Iterates over all MVAs that have been booked, and calls their evaluation methods.
1375
1377{
1378 Log() << kHEADER << gTools().Color("bold") << "Evaluate all methods" << gTools().Color("reset") << Endl;
1379
1380 // don't do anything if no method booked
1381 if (fMethodsMap.empty()) {
1382 Log() << kINFO << "...nothing found to evaluate" << Endl;
1383 return;
1384 }
1385 std::map<TString, MVector *>::iterator itrMap;
1386
1387 for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap) {
1388 MVector *methods = itrMap->second;
1389
1390 // -----------------------------------------------------------------------
1391 // First part of evaluation process
1392 // --> compute efficiencies, and other separation estimators
1393 // -----------------------------------------------------------------------
1394
1395 // although equal, we now want to separate the output for the variables
1396 // and the real methods
1397 Int_t isel; // will be 0 for a Method; 1 for a Variable
1398 Int_t nmeth_used[2] = {0, 0}; // 0 Method; 1 Variable
1399
1400 std::vector<std::vector<TString>> mname(2);
1401 std::vector<std::vector<Double_t>> sig(2), sep(2), roc(2);
1402 std::vector<std::vector<Double_t>> eff01(2), eff10(2), eff30(2), effArea(2);
1403 std::vector<std::vector<Double_t>> eff01err(2), eff10err(2), eff30err(2);
1404 std::vector<std::vector<Double_t>> trainEff01(2), trainEff10(2), trainEff30(2);
1405
1406 std::vector<std::vector<Float_t>> multiclass_testEff;
1407 std::vector<std::vector<Float_t>> multiclass_trainEff;
1408 std::vector<std::vector<Float_t>> multiclass_testPur;
1409 std::vector<std::vector<Float_t>> multiclass_trainPur;
1410
1411 std::vector<std::vector<Float_t>> train_history;
1412
1413 // Multiclass confusion matrices.
1414 std::vector<TMatrixD> multiclass_trainConfusionEffB01;
1415 std::vector<TMatrixD> multiclass_trainConfusionEffB10;
1416 std::vector<TMatrixD> multiclass_trainConfusionEffB30;
1417 std::vector<TMatrixD> multiclass_testConfusionEffB01;
1418 std::vector<TMatrixD> multiclass_testConfusionEffB10;
1419 std::vector<TMatrixD> multiclass_testConfusionEffB30;
1420
1421 std::vector<std::vector<Double_t>> biastrain(1); // "bias" of the regression on the training data
1422 std::vector<std::vector<Double_t>> biastest(1); // "bias" of the regression on test data
1423 std::vector<std::vector<Double_t>> devtrain(1); // "dev" of the regression on the training data
1424 std::vector<std::vector<Double_t>> devtest(1); // "dev" of the regression on test data
1425 std::vector<std::vector<Double_t>> rmstrain(1); // "rms" of the regression on the training data
1426 std::vector<std::vector<Double_t>> rmstest(1); // "rms" of the regression on test data
1427 std::vector<std::vector<Double_t>> minftrain(1); // "minf" of the regression on the training data
1428 std::vector<std::vector<Double_t>> minftest(1); // "minf" of the regression on test data
1429 std::vector<std::vector<Double_t>> rhotrain(1); // correlation of the regression on the training data
1430 std::vector<std::vector<Double_t>> rhotest(1); // correlation of the regression on test data
1431
1432 // same as above but for 'truncated' quantities (computed for events within 2sigma of RMS)
1433 std::vector<std::vector<Double_t>> biastrainT(1);
1434 std::vector<std::vector<Double_t>> biastestT(1);
1435 std::vector<std::vector<Double_t>> devtrainT(1);
1436 std::vector<std::vector<Double_t>> devtestT(1);
1437 std::vector<std::vector<Double_t>> rmstrainT(1);
1438 std::vector<std::vector<Double_t>> rmstestT(1);
1439 std::vector<std::vector<Double_t>> minftrainT(1);
1440 std::vector<std::vector<Double_t>> minftestT(1);
1441
1442 // following vector contains all methods - with the exception of Cuts, which are special
1443 MVector methodsNoCuts;
1444
1445 Bool_t doRegression = kFALSE;
1446 Bool_t doMulticlass = kFALSE;
1447
1448 // iterate over methods and evaluate
1449 for (MVector::iterator itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1451 MethodBase *theMethod = dynamic_cast<MethodBase *>(*itrMethod);
1452 if (theMethod == 0)
1453 continue;
1454 theMethod->SetFile(fgTargetFile);
1455 theMethod->SetSilentFile(IsSilentFile());
1456 if (theMethod->GetMethodType() != Types::kCuts)
1457 methodsNoCuts.push_back(*itrMethod);
1458
1459 if (theMethod->DoRegression()) {
1460 doRegression = kTRUE;
1461
1462 Log() << kINFO << "Evaluate regression method: " << theMethod->GetMethodName() << Endl;
1463 Double_t bias, dev, rms, mInf;
1464 Double_t biasT, devT, rmsT, mInfT;
1465 Double_t rho;
1466
1467 Log() << kINFO << "TestRegression (testing)" << Endl;
1468 theMethod->TestRegression(bias, biasT, dev, devT, rms, rmsT, mInf, mInfT, rho, TMVA::Types::kTesting);
1469 biastest[0].push_back(bias);
1470 devtest[0].push_back(dev);
1471 rmstest[0].push_back(rms);
1472 minftest[0].push_back(mInf);
1473 rhotest[0].push_back(rho);
1474 biastestT[0].push_back(biasT);
1475 devtestT[0].push_back(devT);
1476 rmstestT[0].push_back(rmsT);
1477 minftestT[0].push_back(mInfT);
1478
1479 Log() << kINFO << "TestRegression (training)" << Endl;
1480 theMethod->TestRegression(bias, biasT, dev, devT, rms, rmsT, mInf, mInfT, rho, TMVA::Types::kTraining);
1481 biastrain[0].push_back(bias);
1482 devtrain[0].push_back(dev);
1483 rmstrain[0].push_back(rms);
1484 minftrain[0].push_back(mInf);
1485 rhotrain[0].push_back(rho);
1486 biastrainT[0].push_back(biasT);
1487 devtrainT[0].push_back(devT);
1488 rmstrainT[0].push_back(rmsT);
1489 minftrainT[0].push_back(mInfT);
1490
1491 mname[0].push_back(theMethod->GetMethodName());
1492 nmeth_used[0]++;
1493 if (!IsSilentFile()) {
1494 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
1497 }
1498 } else if (theMethod->DoMulticlass()) {
1499 // ====================================================================
1500 // === Multiclass evaluation
1501 // ====================================================================
1502 doMulticlass = kTRUE;
1503 Log() << kINFO << "Evaluate multiclass classification method: " << theMethod->GetMethodName() << Endl;
1504
1505 // This part uses a genetic alg. to evaluate the optimal sig eff * sig pur.
1506 // This is why it is disabled for now.
1507 // Find approximate optimal working point w.r.t. signalEfficiency * signalPurity.
1508 // theMethod->TestMulticlass(); // This is where the actual GA calc is done
1509 // multiclass_testEff.push_back(theMethod->GetMulticlassEfficiency(multiclass_testPur));
1510
1511 theMethod->TestMulticlass();
1512
1513 // Confusion matrix at three background efficiency levels
1514 multiclass_trainConfusionEffB01.push_back(theMethod->GetMulticlassConfusionMatrix(0.01, Types::kTraining));
1515 multiclass_trainConfusionEffB10.push_back(theMethod->GetMulticlassConfusionMatrix(0.10, Types::kTraining));
1516 multiclass_trainConfusionEffB30.push_back(theMethod->GetMulticlassConfusionMatrix(0.30, Types::kTraining));
1517
1518 multiclass_testConfusionEffB01.push_back(theMethod->GetMulticlassConfusionMatrix(0.01, Types::kTesting));
1519 multiclass_testConfusionEffB10.push_back(theMethod->GetMulticlassConfusionMatrix(0.10, Types::kTesting));
1520 multiclass_testConfusionEffB30.push_back(theMethod->GetMulticlassConfusionMatrix(0.30, Types::kTesting));
1521
1522 if (!IsSilentFile()) {
1523 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
1526 }
1527
1528 nmeth_used[0]++;
1529 mname[0].push_back(theMethod->GetMethodName());
1530 } else {
1531
1532 Log() << kHEADER << "Evaluate classifier: " << theMethod->GetMethodName() << Endl << Endl;
1533 isel = (theMethod->GetMethodTypeName().Contains("Variable")) ? 1 : 0;
1534
1535 // perform the evaluation
1536 theMethod->TestClassification();
1537
1538 // evaluate the classifier
1539 mname[isel].push_back(theMethod->GetMethodName());
1540 sig[isel].push_back(theMethod->GetSignificance());
1541 sep[isel].push_back(theMethod->GetSeparation());
1542 roc[isel].push_back(theMethod->GetROCIntegral());
1543
1544 Double_t err;
1545 eff01[isel].push_back(theMethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err));
1546 eff01err[isel].push_back(err);
1547 eff10[isel].push_back(theMethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err));
1548 eff10err[isel].push_back(err);
1549 eff30[isel].push_back(theMethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err));
1550 eff30err[isel].push_back(err);
1551 effArea[isel].push_back(theMethod->GetEfficiency("", Types::kTesting, err)); // computes the area (average)
1552
1553 trainEff01[isel].push_back(
1554 theMethod->GetTrainingEfficiency("Efficiency:0.01")); // the first pass takes longer
1555 trainEff10[isel].push_back(theMethod->GetTrainingEfficiency("Efficiency:0.10"));
1556 trainEff30[isel].push_back(theMethod->GetTrainingEfficiency("Efficiency:0.30"));
1557
1558 nmeth_used[isel]++;
1559
1560 if (!IsSilentFile()) {
1561 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
1564 }
1565 }
1566 }
1567 if (doRegression) {
1568
1569 std::vector<TString> vtemps = mname[0];
1570 std::vector<std::vector<Double_t>> vtmp;
1571 vtmp.push_back(devtest[0]); // this is the vector that is ranked
1572 vtmp.push_back(devtrain[0]);
1573 vtmp.push_back(biastest[0]);
1574 vtmp.push_back(biastrain[0]);
1575 vtmp.push_back(rmstest[0]);
1576 vtmp.push_back(rmstrain[0]);
1577 vtmp.push_back(minftest[0]);
1578 vtmp.push_back(minftrain[0]);
1579 vtmp.push_back(rhotest[0]);
1580 vtmp.push_back(rhotrain[0]);
1581 vtmp.push_back(devtestT[0]); // this is the vector that is ranked
1582 vtmp.push_back(devtrainT[0]);
1583 vtmp.push_back(biastestT[0]);
1584 vtmp.push_back(biastrainT[0]);
1585 vtmp.push_back(rmstestT[0]);
1586 vtmp.push_back(rmstrainT[0]);
1587 vtmp.push_back(minftestT[0]);
1588 vtmp.push_back(minftrainT[0]);
1589 gTools().UsefulSortAscending(vtmp, &vtemps);
1590 mname[0] = vtemps;
1591 devtest[0] = vtmp[0];
1592 devtrain[0] = vtmp[1];
1593 biastest[0] = vtmp[2];
1594 biastrain[0] = vtmp[3];
1595 rmstest[0] = vtmp[4];
1596 rmstrain[0] = vtmp[5];
1597 minftest[0] = vtmp[6];
1598 minftrain[0] = vtmp[7];
1599 rhotest[0] = vtmp[8];
1600 rhotrain[0] = vtmp[9];
1601 devtestT[0] = vtmp[10];
1602 devtrainT[0] = vtmp[11];
1603 biastestT[0] = vtmp[12];
1604 biastrainT[0] = vtmp[13];
1605 rmstestT[0] = vtmp[14];
1606 rmstrainT[0] = vtmp[15];
1607 minftestT[0] = vtmp[16];
1608 minftrainT[0] = vtmp[17];
1609 } else if (doMulticlass) {
1610 // TODO: fill in something meaningful
1611 // If there is some ranking of methods to be done it should be done here.
1612 // However, this is not so easy to define for multiclass so it is left out for now.
1613
1614 } else {
1615 // now sort the variables according to the best 'eff at Beff=0.10'
1616 for (Int_t k = 0; k < 2; k++) {
1617 std::vector<std::vector<Double_t>> vtemp;
1618 vtemp.push_back(effArea[k]); // this is the vector that is ranked
1619 vtemp.push_back(eff10[k]);
1620 vtemp.push_back(eff01[k]);
1621 vtemp.push_back(eff30[k]);
1622 vtemp.push_back(eff10err[k]);
1623 vtemp.push_back(eff01err[k]);
1624 vtemp.push_back(eff30err[k]);
1625 vtemp.push_back(trainEff10[k]);
1626 vtemp.push_back(trainEff01[k]);
1627 vtemp.push_back(trainEff30[k]);
1628 vtemp.push_back(sig[k]);
1629 vtemp.push_back(sep[k]);
1630 vtemp.push_back(roc[k]);
1631 std::vector<TString> vtemps = mname[k];
1632 gTools().UsefulSortDescending(vtemp, &vtemps);
1633 effArea[k] = vtemp[0];
1634 eff10[k] = vtemp[1];
1635 eff01[k] = vtemp[2];
1636 eff30[k] = vtemp[3];
1637 eff10err[k] = vtemp[4];
1638 eff01err[k] = vtemp[5];
1639 eff30err[k] = vtemp[6];
1640 trainEff10[k] = vtemp[7];
1641 trainEff01[k] = vtemp[8];
1642 trainEff30[k] = vtemp[9];
1643 sig[k] = vtemp[10];
1644 sep[k] = vtemp[11];
1645 roc[k] = vtemp[12];
1646 mname[k] = vtemps;
1647 }
1648 }
1649
1650 // -----------------------------------------------------------------------
1651 // Second part of evaluation process
1652 // --> compute correlations among MVAs
1653 // --> compute correlations between input variables and MVA (determines importance)
1654 // --> count overlaps
1655 // -----------------------------------------------------------------------
1656 if (fCorrelations) {
1657 const Int_t nmeth = methodsNoCuts.size();
1658 MethodBase *method = dynamic_cast<MethodBase *>(methods[0][0]);
1659 const Int_t nvar = method->fDataSetInfo.GetNVariables();
1660 if (!doRegression && !doMulticlass) {
1661
1662 if (nmeth > 0) {
1663
1664 // needed for correlations
1665 Double_t *dvec = new Double_t[nmeth + nvar];
1666 std::vector<Double_t> rvec;
1667
1668 // for correlations
1669 TPrincipal *tpSig = new TPrincipal(nmeth + nvar, "");
1670 TPrincipal *tpBkg = new TPrincipal(nmeth + nvar, "");
1671
1672 // set required tree branch references
1673 Int_t ivar = 0;
1674 std::vector<TString> *theVars = new std::vector<TString>;
1675 std::vector<ResultsClassification *> mvaRes;
1676 for (MVector::iterator itrMethod = methodsNoCuts.begin(); itrMethod != methodsNoCuts.end();
1677 ++itrMethod, ++ivar) {
1678 MethodBase *m = dynamic_cast<MethodBase *>(*itrMethod);
1679 if (m == 0)
1680 continue;
1681 theVars->push_back(m->GetTestvarName());
1682 rvec.push_back(m->GetSignalReferenceCut());
1683 theVars->back().ReplaceAll("MVA_", "");
1684 mvaRes.push_back(dynamic_cast<ResultsClassification *>(
1685 m->Data()->GetResults(m->GetMethodName(), Types::kTesting, Types::kMaxAnalysisType)));
1686 }
1687
1688 // for overlap study
1689 TMatrixD *overlapS = new TMatrixD(nmeth, nmeth);
1690 TMatrixD *overlapB = new TMatrixD(nmeth, nmeth);
1691 (*overlapS) *= 0; // init...
1692 (*overlapB) *= 0; // init...
1693
1694 // loop over test tree
1695 DataSet *defDs = method->fDataSetInfo.GetDataSet();
1697 for (Int_t ievt = 0; ievt < defDs->GetNEvents(); ievt++) {
1698 const Event *ev = defDs->GetEvent(ievt);
1699
1700 // for correlations
1701 TMatrixD *theMat = 0;
1702 for (Int_t im = 0; im < nmeth; im++) {
1703 // check for NaN value
1704 Double_t retval = (Double_t)(*mvaRes[im])[ievt][0];
1705 if (TMath::IsNaN(retval)) {
1706 Log() << kWARNING << "Found NaN return value in event: " << ievt << " for method \""
1707 << methodsNoCuts[im]->GetName() << "\"" << Endl;
1708 dvec[im] = 0;
1709 } else
1710 dvec[im] = retval;
1711 }
1712 for (Int_t iv = 0; iv < nvar; iv++)
1713 dvec[iv + nmeth] = (Double_t)ev->GetValue(iv);
1714 if (method->fDataSetInfo.IsSignal(ev)) {
1715 tpSig->AddRow(dvec);
1716 theMat = overlapS;
1717 } else {
1718 tpBkg->AddRow(dvec);
1719 theMat = overlapB;
1720 }
1721
1722 // count overlaps
1723 for (Int_t im = 0; im < nmeth; im++) {
1724 for (Int_t jm = im; jm < nmeth; jm++) {
1725 if ((dvec[im] - rvec[im]) * (dvec[jm] - rvec[jm]) > 0) {
1726 (*theMat)(im, jm)++;
1727 if (im != jm)
1728 (*theMat)(jm, im)++;
1729 }
1730 }
1731 }
1732 }
1733
1734 // renormalise overlap matrix
1735 (*overlapS) *= (1.0 / defDs->GetNEvtSigTest()); // init...
1736 (*overlapB) *= (1.0 / defDs->GetNEvtBkgdTest()); // init...
1737
1738 tpSig->MakePrincipals();
1739 tpBkg->MakePrincipals();
1740
1741 const TMatrixD *covMatS = tpSig->GetCovarianceMatrix();
1742 const TMatrixD *covMatB = tpBkg->GetCovarianceMatrix();
1743
1744 const TMatrixD *corrMatS = gTools().GetCorrelationMatrix(covMatS);
1745 const TMatrixD *corrMatB = gTools().GetCorrelationMatrix(covMatB);
1746
1747 // print correlation matrices
1748 if (corrMatS != 0 && corrMatB != 0) {
1749
1750 // extract MVA matrix
1751 TMatrixD mvaMatS(nmeth, nmeth);
1752 TMatrixD mvaMatB(nmeth, nmeth);
1753 for (Int_t im = 0; im < nmeth; im++) {
1754 for (Int_t jm = 0; jm < nmeth; jm++) {
1755 mvaMatS(im, jm) = (*corrMatS)(im, jm);
1756 mvaMatB(im, jm) = (*corrMatB)(im, jm);
1757 }
1758 }
1759
1760 // extract variables - to MVA matrix
1761 std::vector<TString> theInputVars;
1762 TMatrixD varmvaMatS(nvar, nmeth);
1763 TMatrixD varmvaMatB(nvar, nmeth);
1764 for (Int_t iv = 0; iv < nvar; iv++) {
1765 theInputVars.push_back(method->fDataSetInfo.GetVariableInfo(iv).GetLabel());
1766 for (Int_t jm = 0; jm < nmeth; jm++) {
1767 varmvaMatS(iv, jm) = (*corrMatS)(nmeth + iv, jm);
1768 varmvaMatB(iv, jm) = (*corrMatB)(nmeth + iv, jm);
1769 }
1770 }
1771
1772 if (nmeth > 1) {
1773 Log() << kINFO << Endl;
1774 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1775 << "Inter-MVA correlation matrix (signal):" << Endl;
1776 gTools().FormattedOutput(mvaMatS, *theVars, Log());
1777 Log() << kINFO << Endl;
1778
1779 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1780 << "Inter-MVA correlation matrix (background):" << Endl;
1781 gTools().FormattedOutput(mvaMatB, *theVars, Log());
1782 Log() << kINFO << Endl;
1783 }
1784
1785 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1786 << "Correlations between input variables and MVA response (signal):" << Endl;
1787 gTools().FormattedOutput(varmvaMatS, theInputVars, *theVars, Log());
1788 Log() << kINFO << Endl;
1789
1790 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1791 << "Correlations between input variables and MVA response (background):" << Endl;
1792 gTools().FormattedOutput(varmvaMatB, theInputVars, *theVars, Log());
1793 Log() << kINFO << Endl;
1794 } else
1795 Log() << kWARNING << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1796 << "<TestAllMethods> cannot compute correlation matrices" << Endl;
1797
1798 // print overlap matrices
1799 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1800 << "The following \"overlap\" matrices contain the fraction of events for which " << Endl;
1801 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1802 << "the MVAs 'i' and 'j' have returned conform answers about \"signal-likeness\"" << Endl;
1803 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1804 << "An event is signal-like, if its MVA output exceeds the following value:" << Endl;
1805 gTools().FormattedOutput(rvec, *theVars, "Method", "Cut value", Log());
1806 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1807 << "which correspond to the working point: eff(signal) = 1 - eff(background)" << Endl;
1808
1809 // give notice that cut method has been excluded from this test
1810 if (nmeth != (Int_t)methods->size())
1811 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1812 << "Note: no correlations and overlap with cut method are provided at present" << Endl;
1813
1814 if (nmeth > 1) {
1815 Log() << kINFO << Endl;
1816 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1817 << "Inter-MVA overlap matrix (signal):" << Endl;
1818 gTools().FormattedOutput(*overlapS, *theVars, Log());
1819 Log() << kINFO << Endl;
1820
1821 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1822 << "Inter-MVA overlap matrix (background):" << Endl;
1823 gTools().FormattedOutput(*overlapB, *theVars, Log());
1824 }
1825
1826 // cleanup
1827 delete tpSig;
1828 delete tpBkg;
1829 delete corrMatS;
1830 delete corrMatB;
1831 delete theVars;
1832 delete overlapS;
1833 delete overlapB;
1834 delete[] dvec;
1835 }
1836 }
1837 }
1838 // -----------------------------------------------------------------------
1839 // Third part of evaluation process
1840 // --> output
1841 // -----------------------------------------------------------------------
1842
1843 if (doRegression) {
1844
1845 Log() << kINFO << Endl;
1846 TString hLine =
1847 "--------------------------------------------------------------------------------------------------";
1848 Log() << kINFO << "Evaluation results ranked by smallest RMS on test sample:" << Endl;
1849 Log() << kINFO << "(\"Bias\" quotes the mean deviation of the regression from true target." << Endl;
1850 Log() << kINFO << " \"MutInf\" is the \"Mutual Information\" between regression and target." << Endl;
1851 Log() << kINFO << " Indicated by \"_T\" are the corresponding \"truncated\" quantities ob-" << Endl;
1852 Log() << kINFO << " tained when removing events deviating more than 2sigma from average.)" << Endl;
1853 Log() << kINFO << hLine << Endl;
1854 // Log() << kINFO << "DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf
1855 // MutInf_T" << Endl;
1856 Log() << kINFO << hLine << Endl;
1857
1858 for (Int_t i = 0; i < nmeth_used[0]; i++) {
1859 MethodBase *theMethod = dynamic_cast<MethodBase *>((*methods)[i]);
1860 if (theMethod == 0)
1861 continue;
1862
1863 Log() << kINFO
1864 << Form("%-20s %-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f", theMethod->fDataSetInfo.GetName(),
1865 (const char *)mname[0][i], biastest[0][i], biastestT[0][i], rmstest[0][i], rmstestT[0][i],
1866 minftest[0][i], minftestT[0][i])
1867 << Endl;
1868 }
1869 Log() << kINFO << hLine << Endl;
1870 Log() << kINFO << Endl;
1871 Log() << kINFO << "Evaluation results ranked by smallest RMS on training sample:" << Endl;
1872 Log() << kINFO << "(overtraining check)" << Endl;
1873 Log() << kINFO << hLine << Endl;
1874 Log() << kINFO
1875 << "DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T"
1876 << Endl;
1877 Log() << kINFO << hLine << Endl;
1878
1879 for (Int_t i = 0; i < nmeth_used[0]; i++) {
1880 MethodBase *theMethod = dynamic_cast<MethodBase *>((*methods)[i]);
1881 if (theMethod == 0)
1882 continue;
1883 Log() << kINFO
1884 << Form("%-20s %-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f", theMethod->fDataSetInfo.GetName(),
1885 (const char *)mname[0][i], biastrain[0][i], biastrainT[0][i], rmstrain[0][i], rmstrainT[0][i],
1886 minftrain[0][i], minftrainT[0][i])
1887 << Endl;
1888 }
1889 Log() << kINFO << hLine << Endl;
1890 Log() << kINFO << Endl;
1891 } else if (doMulticlass) {
1892 // ====================================================================
1893 // === Multiclass Output
1894 // ====================================================================
1895
1896 TString hLine =
1897 "-------------------------------------------------------------------------------------------------------";
1898
1899 // This part uses a genetic alg. to evaluate the optimal sig eff * sig pur.
1900 // This is why it is disabled for now.
1901 //
1902 // // --- Acheivable signal efficiency * signal purity
1903 // // --------------------------------------------------------------------
1904 // Log() << kINFO << Endl;
1905 // Log() << kINFO << "Evaluation results ranked by best signal efficiency times signal purity " << Endl;
1906 // Log() << kINFO << hLine << Endl;
1907
1908 // // iterate over methods and evaluate
1909 // for (MVector::iterator itrMethod = methods->begin(); itrMethod != methods->end(); itrMethod++) {
1910 // MethodBase *theMethod = dynamic_cast<MethodBase *>(*itrMethod);
1911 // if (theMethod == 0) {
1912 // continue;
1913 // }
1914
1915 // TString header = "DataSet Name MVA Method ";
1916 // for (UInt_t icls = 0; icls < theMethod->fDataSetInfo.GetNClasses(); ++icls) {
1917 // header += TString::Format("%-12s ", theMethod->fDataSetInfo.GetClassInfo(icls)->GetName());
1918 // }
1919
1920 // Log() << kINFO << header << Endl;
1921 // Log() << kINFO << hLine << Endl;
1922 // for (Int_t i = 0; i < nmeth_used[0]; i++) {
1923 // TString res = TString::Format("[%-14s] %-15s", theMethod->fDataSetInfo.GetName(), mname[0][i].Data());
1924 // for (UInt_t icls = 0; icls < theMethod->fDataSetInfo.GetNClasses(); ++icls) {
1925 // res += TString::Format("%#1.3f ", (multiclass_testEff[i][icls]) * (multiclass_testPur[i][icls]));
1926 // }
1927 // Log() << kINFO << res << Endl;
1928 // }
1929
1930 // Log() << kINFO << hLine << Endl;
1931 // Log() << kINFO << Endl;
1932 // }
1933
1934 // --- 1 vs Rest ROC AUC, signal efficiency @ given background efficiency
1935 // --------------------------------------------------------------------
1936 TString header1 = TString::Format("%-15s%-15s%-15s%-15s%-15s%-15s", "Dataset", "MVA Method", "ROC AUC", "Sig eff@B=0.01",
1937 "Sig eff@B=0.10", "Sig eff@B=0.30");
1938 TString header2 = TString::Format("%-15s%-15s%-15s%-15s%-15s%-15s", "Name:", "/ Class:", "test (train)", "test (train)",
1939 "test (train)", "test (train)");
1940 Log() << kINFO << Endl;
1941 Log() << kINFO << "1-vs-rest performance metrics per class" << Endl;
1942 Log() << kINFO << hLine << Endl;
1943 Log() << kINFO << Endl;
1944 Log() << kINFO << "Considers the listed class as signal and the other classes" << Endl;
1945 Log() << kINFO << "as background, reporting the resulting binary performance." << Endl;
1946 Log() << kINFO << "A score of 0.820 (0.850) means 0.820 was acheived on the" << Endl;
1947 Log() << kINFO << "test set and 0.850 on the training set." << Endl;
1948
1949 Log() << kINFO << Endl;
1950 Log() << kINFO << header1 << Endl;
1951 Log() << kINFO << header2 << Endl;
1952 for (Int_t k = 0; k < 2; k++) {
1953 for (Int_t i = 0; i < nmeth_used[k]; i++) {
1954 if (k == 1) {
1955 mname[k][i].ReplaceAll("Variable_", "");
1956 }
1957
1958 const TString datasetName = itrMap->first;
1959 const TString mvaName = mname[k][i];
1960
1961 MethodBase *theMethod = dynamic_cast<MethodBase *>(GetMethod(datasetName, mvaName));
1962 if (theMethod == 0) {
1963 continue;
1964 }
1965
1966 Log() << kINFO << Endl;
1967 TString row = TString::Format("%-15s%-15s", datasetName.Data(), mvaName.Data());
1968 Log() << kINFO << row << Endl;
1969 Log() << kINFO << "------------------------------" << Endl;
1970
1971 UInt_t numClasses = theMethod->fDataSetInfo.GetNClasses();
1972 for (UInt_t iClass = 0; iClass < numClasses; ++iClass) {
1973
1974 ROCCurve *rocCurveTrain = GetROC(datasetName, mvaName, iClass, Types::kTraining);
1975 ROCCurve *rocCurveTest = GetROC(datasetName, mvaName, iClass, Types::kTesting);
1976
1977 const TString className = theMethod->DataInfo().GetClassInfo(iClass)->GetName();
1978 const Double_t rocaucTrain = rocCurveTrain->GetROCIntegral();
1979 const Double_t effB01Train = rocCurveTrain->GetEffSForEffB(0.01);
1980 const Double_t effB10Train = rocCurveTrain->GetEffSForEffB(0.10);
1981 const Double_t effB30Train = rocCurveTrain->GetEffSForEffB(0.30);
1982 const Double_t rocaucTest = rocCurveTest->GetROCIntegral();
1983 const Double_t effB01Test = rocCurveTest->GetEffSForEffB(0.01);
1984 const Double_t effB10Test = rocCurveTest->GetEffSForEffB(0.10);
1985 const Double_t effB30Test = rocCurveTest->GetEffSForEffB(0.30);
1986 const TString rocaucCmp = TString::Format("%5.3f (%5.3f)", rocaucTest, rocaucTrain);
1987 const TString effB01Cmp = TString::Format("%5.3f (%5.3f)", effB01Test, effB01Train);
1988 const TString effB10Cmp = TString::Format("%5.3f (%5.3f)", effB10Test, effB10Train);
1989 const TString effB30Cmp = TString::Format("%5.3f (%5.3f)", effB30Test, effB30Train);
1990 row = TString::Format("%-15s%-15s%-15s%-15s%-15s%-15s", "", className.Data(), rocaucCmp.Data(), effB01Cmp.Data(),
1991 effB10Cmp.Data(), effB30Cmp.Data());
1992 Log() << kINFO << row << Endl;
1993
1994 delete rocCurveTrain;
1995 delete rocCurveTest;
1996 }
1997 }
1998 }
1999 Log() << kINFO << Endl;
2000 Log() << kINFO << hLine << Endl;
2001 Log() << kINFO << Endl;
2002
2003 // --- Confusion matrices
2004 // --------------------------------------------------------------------
2005 auto printMatrix = [](TMatrixD const &matTraining, TMatrixD const &matTesting, std::vector<TString> classnames,
2006 UInt_t numClasses, MsgLogger &stream) {
2007 // assert (classLabledWidth >= valueLabelWidth + 2)
2008 // if (...) {Log() << kWARN << "..." << Endl; }
2009
2010 // TODO: Ensure matrices are same size.
2011
2012 TString header = TString::Format(" %-14s", " ");
2013 TString headerInfo = TString::Format(" %-14s", " ");
2014
2015 for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
2016 header += TString::Format(" %-14s", classnames[iCol].Data());
2017 headerInfo += TString::Format(" %-14s", " test (train)");
2018 }
2019 stream << kINFO << header << Endl;
2020 stream << kINFO << headerInfo << Endl;
2021
2022 for (UInt_t iRow = 0; iRow < numClasses; ++iRow) {
2023 stream << kINFO << TString::Format(" %-14s", classnames[iRow].Data());
2024
2025 for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
2026 if (iCol == iRow) {
2027 stream << kINFO << TString::Format(" %-14s", "-");
2028 } else {
2029 Double_t trainValue = matTraining[iRow][iCol];
2030 Double_t testValue = matTesting[iRow][iCol];
2031 TString entry = TString::Format("%-5.3f (%-5.3f)", testValue, trainValue);
2032 stream << kINFO << TString::Format(" %-14s", entry.Data());
2033 }
2034 }
2035 stream << kINFO << Endl;
2036 }
2037 };
2038
2039 Log() << kINFO << Endl;
2040 Log() << kINFO << "Confusion matrices for all methods" << Endl;
2041 Log() << kINFO << hLine << Endl;
2042 Log() << kINFO << Endl;
2043 Log() << kINFO << "Does a binary comparison between the two classes given by a " << Endl;
2044 Log() << kINFO << "particular row-column combination. In each case, the class " << Endl;
2045 Log() << kINFO << "given by the row is considered signal while the class given " << Endl;
2046 Log() << kINFO << "by the column index is considered background." << Endl;
2047 Log() << kINFO << Endl;
2048 for (UInt_t iMethod = 0; iMethod < methods->size(); ++iMethod) {
2049 MethodBase *theMethod = dynamic_cast<MethodBase *>(methods->at(iMethod));
2050 if (theMethod == nullptr) {
2051 continue;
2052 }
2053 UInt_t numClasses = theMethod->fDataSetInfo.GetNClasses();
2054
2055 std::vector<TString> classnames;
2056 for (UInt_t iCls = 0; iCls < numClasses; ++iCls) {
2057 classnames.push_back(theMethod->fDataSetInfo.GetClassInfo(iCls)->GetName());
2058 }
2059 Log() << kINFO
2060 << "=== Showing confusion matrix for method : " << Form("%-15s", (const char *)mname[0][iMethod])
2061 << Endl;
2062 Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.01%)" << Endl;
2063 Log() << kINFO << "---------------------------------------------------" << Endl;
2064 printMatrix(multiclass_testConfusionEffB01[iMethod], multiclass_trainConfusionEffB01[iMethod], classnames,
2065 numClasses, Log());
2066 Log() << kINFO << Endl;
2067
2068 Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.10%)" << Endl;
2069 Log() << kINFO << "---------------------------------------------------" << Endl;
2070 printMatrix(multiclass_testConfusionEffB10[iMethod], multiclass_trainConfusionEffB10[iMethod], classnames,
2071 numClasses, Log());
2072 Log() << kINFO << Endl;
2073
2074 Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.30%)" << Endl;
2075 Log() << kINFO << "---------------------------------------------------" << Endl;
2076 printMatrix(multiclass_testConfusionEffB30[iMethod], multiclass_trainConfusionEffB30[iMethod], classnames,
2077 numClasses, Log());
2078 Log() << kINFO << Endl;
2079 }
2080 Log() << kINFO << hLine << Endl;
2081 Log() << kINFO << Endl;
2082
2083 } else {
2084 // Binary classification
2085 if (fROC) {
2086 Log().EnableOutput();
2088 Log() << Endl;
2089 TString hLine = "------------------------------------------------------------------------------------------"
2090 "-------------------------";
2091 Log() << kINFO << "Evaluation results ranked by best signal efficiency and purity (area)" << Endl;
2092 Log() << kINFO << hLine << Endl;
2093 Log() << kINFO << "DataSet MVA " << Endl;
2094 Log() << kINFO << "Name: Method: ROC-integ" << Endl;
2095
2096 // Log() << kDEBUG << "DataSet MVA Signal efficiency at bkg eff.(error):
2097 // | Sepa- Signifi- " << Endl; Log() << kDEBUG << "Name: Method: @B=0.01
2098 // @B=0.10 @B=0.30 ROC-integ ROCCurve| ration: cance: " << Endl;
2099 Log() << kDEBUG << hLine << Endl;
2100 for (Int_t k = 0; k < 2; k++) {
2101 if (k == 1 && nmeth_used[k] > 0) {
2102 Log() << kINFO << hLine << Endl;
2103 Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
2104 }
2105 for (Int_t i = 0; i < nmeth_used[k]; i++) {
2106 TString datasetName = itrMap->first;
2107 TString methodName = mname[k][i];
2108
2109 if (k == 1) {
2110 methodName.ReplaceAll("Variable_", "");
2111 }
2112
2113 MethodBase *theMethod = dynamic_cast<MethodBase *>(GetMethod(datasetName, methodName));
2114 if (theMethod == 0) {
2115 continue;
2116 }
2117
2118 TMVA::DataSet *dataset = theMethod->Data();
2119 TMVA::Results *results = dataset->GetResults(methodName, Types::kTesting, this->fAnalysisType);
2120 std::vector<Bool_t> *mvaResType =
2121 dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
2122
2123 Double_t rocIntegral = 0.0;
2124 if (mvaResType->size() != 0) {
2125 rocIntegral = GetROCIntegral(datasetName, methodName);
2126 }
2127
2128 if (sep[k][i] < 0 || sig[k][i] < 0) {
2129 // cannot compute separation/significance -> no MVA (usually for Cuts)
2130 Log() << kINFO << Form("%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), effArea[k][i])
2131 << Endl;
2132
2133 // Log() << kDEBUG << Form("%-20s %-15s: %#1.3f(%02i) %#1.3f(%02i) %#1.3f(%02i)
2134 // %#1.3f %#1.3f | -- --",
2135 // datasetName.Data(),
2136 // methodName.Data(),
2137 // eff01[k][i], Int_t(1000*eff01err[k][i]),
2138 // eff10[k][i], Int_t(1000*eff10err[k][i]),
2139 // eff30[k][i], Int_t(1000*eff30err[k][i]),
2140 // effArea[k][i],rocIntegral) << Endl;
2141 } else {
2142 Log() << kINFO << Form("%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), rocIntegral)
2143 << Endl;
2144 // Log() << kDEBUG << Form("%-20s %-15s: %#1.3f(%02i) %#1.3f(%02i) %#1.3f(%02i)
2145 // %#1.3f %#1.3f | %#1.3f %#1.3f",
2146 // datasetName.Data(),
2147 // methodName.Data(),
2148 // eff01[k][i], Int_t(1000*eff01err[k][i]),
2149 // eff10[k][i], Int_t(1000*eff10err[k][i]),
2150 // eff30[k][i], Int_t(1000*eff30err[k][i]),
2151 // effArea[k][i],rocIntegral,
2152 // sep[k][i], sig[k][i]) << Endl;
2153 }
2154 }
2155 }
2156 Log() << kINFO << hLine << Endl;
2157 Log() << kINFO << Endl;
2158 Log() << kINFO << "Testing efficiency compared to training efficiency (overtraining check)" << Endl;
2159 Log() << kINFO << hLine << Endl;
2160 Log() << kINFO
2161 << "DataSet MVA Signal efficiency: from test sample (from training sample) "
2162 << Endl;
2163 Log() << kINFO << "Name: Method: @B=0.01 @B=0.10 @B=0.30 "
2164 << Endl;
2165 Log() << kINFO << hLine << Endl;
2166 for (Int_t k = 0; k < 2; k++) {
2167 if (k == 1 && nmeth_used[k] > 0) {
2168 Log() << kINFO << hLine << Endl;
2169 Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
2170 }
2171 for (Int_t i = 0; i < nmeth_used[k]; i++) {
2172 if (k == 1)
2173 mname[k][i].ReplaceAll("Variable_", "");
2174 MethodBase *theMethod = dynamic_cast<MethodBase *>((*methods)[i]);
2175 if (theMethod == 0)
2176 continue;
2177
2178 Log() << kINFO
2179 << Form("%-20s %-15s: %#1.3f (%#1.3f) %#1.3f (%#1.3f) %#1.3f (%#1.3f)",
2180 theMethod->fDataSetInfo.GetName(), (const char *)mname[k][i], eff01[k][i],
2181 trainEff01[k][i], eff10[k][i], trainEff10[k][i], eff30[k][i], trainEff30[k][i])
2182 << Endl;
2183 }
2184 }
2185 Log() << kINFO << hLine << Endl;
2186 Log() << kINFO << Endl;
2187
2188 if (gTools().CheckForSilentOption(GetOptions()))
2189 Log().InhibitOutput();
2190 } // end fROC
2191 }
2192 if (!IsSilentFile()) {
2193 std::list<TString> datasets;
2194 for (Int_t k = 0; k < 2; k++) {
2195 for (Int_t i = 0; i < nmeth_used[k]; i++) {
2196 MethodBase *theMethod = dynamic_cast<MethodBase *>((*methods)[i]);
2197 if (theMethod == 0)
2198 continue;
2199 // write test/training trees
2200 RootBaseDir()->cd(theMethod->fDataSetInfo.GetName());
2201 if (std::find(datasets.begin(), datasets.end(), theMethod->fDataSetInfo.GetName()) == datasets.end()) {
2204 datasets.push_back(theMethod->fDataSetInfo.GetName());
2205 }
2206 }
2207 }
2208 }
2209 } // end for MethodsMap
2210 // references for citation
2212}
2213
2214////////////////////////////////////////////////////////////////////////////////
2215/// Evaluate Variable Importance
2216
2217TH1F *TMVA::Factory::EvaluateImportance(DataLoader *loader, VIType vitype, Types::EMVA theMethod, TString methodTitle,
2218 const char *theOption)
2219{
2220 fModelPersistence = kFALSE;
2221 fSilentFile = kTRUE; // we need silent file here because we need fast classification results
2222
2223 // getting number of variables and variable names from loader
2224 const int nbits = loader->GetDataSetInfo().GetNVariables();
2225 if (vitype == VIType::kShort)
2226 return EvaluateImportanceShort(loader, theMethod, methodTitle, theOption);
2227 else if (vitype == VIType::kAll)
2228 return EvaluateImportanceAll(loader, theMethod, methodTitle, theOption);
2229 else if (vitype == VIType::kRandom) {
2230 if ( nbits > 10 && nbits < 30) {
2231 // limit nbits to less than 30 to avoid error converting from double to uint and also cannot deal with too many combinations
2232 return EvaluateImportanceRandom(loader, static_cast<UInt_t>( pow(2, nbits) ), theMethod, methodTitle, theOption);
2233 } else if (nbits < 10) {
2234 Log() << kERROR << "Error in Variable Importance: Random mode require more that 10 variables in the dataset."
2235 << Endl;
2236 } else if (nbits > 30) {
2237 Log() << kERROR << "Error in Variable Importance: Number of variables is too large for Random mode"
2238 << Endl;
2239 }
2240 }
2241 return nullptr;
2242}
2243
2244////////////////////////////////////////////////////////////////////////////////
2245
2247 const char *theOption)
2248{
2249
2250 uint64_t x = 0;
2251 uint64_t y = 0;
2252
2253 // getting number of variables and variable names from loader
2254 const int nbits = loader->GetDataSetInfo().GetNVariables();
2255 std::vector<TString> varNames = loader->GetDataSetInfo().GetListOfVariables();
2256
2257 if (nbits > 60) {
2258 Log() << kERROR << "Number of combinations is too large , is 2^" << nbits << Endl;
2259 return nullptr;
2260 }
2261 if (nbits > 20) {
2262 Log() << kWARNING << "Number of combinations is very large , is 2^" << nbits << Endl;
2263 }
2264 uint64_t range = static_cast<uint64_t>(pow(2, nbits));
2265
2266
2267 // vector to save importances
2268 std::vector<Double_t> importances(nbits);
2269 // vector to save ROC
2270 std::vector<Double_t> ROC(range);
2271 ROC[0] = 0.5;
2272 for (int i = 0; i < nbits; i++)
2273 importances[i] = 0;
2274
2275 Double_t SROC, SSROC; // computed ROC value
2276 for (x = 1; x < range; x++) {
2277
2278 std::bitset<VIBITS> xbitset(x);
2279 if (x == 0)
2280 continue; // data loader need at least one variable
2281
2282 // creating loader for seed
2283 TMVA::DataLoader *seedloader = new TMVA::DataLoader(xbitset.to_string());
2284
2285 // adding variables from seed
2286 for (int index = 0; index < nbits; index++) {
2287 if (xbitset[index])
2288 seedloader->AddVariable(varNames[index], 'F');
2289 }
2290
2291 DataLoaderCopy(seedloader, loader);
2292 seedloader->PrepareTrainingAndTestTree(loader->GetDataSetInfo().GetCut("Signal"),
2293 loader->GetDataSetInfo().GetCut("Background"),
2294 loader->GetDataSetInfo().GetSplitOptions());
2295
2296 // Booking Seed
2297 BookMethod(seedloader, theMethod, methodTitle, theOption);
2298
2299 // Train/Test/Evaluation
2300 TrainAllMethods();
2301 TestAllMethods();
2302 EvaluateAllMethods();
2303
2304 // getting ROC
2305 ROC[x] = GetROCIntegral(xbitset.to_string(), methodTitle);
2306
2307 // cleaning information to process sub-seeds
2308 TMVA::MethodBase *smethod = dynamic_cast<TMVA::MethodBase *>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
2311 delete sresults;
2312 delete seedloader;
2313 this->DeleteAllMethods();
2314
2315 fMethodsMap.clear();
2316 // removing global result because it is requiring a lot of RAM for all seeds
2317 }
2318
2319 for (x = 0; x < range; x++) {
2320 SROC = ROC[x];
2321 for (uint32_t i = 0; i < VIBITS; ++i) {
2322 if (x & (uint64_t(1) << i)) {
2323 y = x & ~(1 << i);
2324 std::bitset<VIBITS> ybitset(y);
2325 // need at least one variable
2326 // NOTE: if sub-seed is zero then is the special case
2327 // that count in xbitset is 1
2328 uint32_t ny = static_cast<uint32_t>( log(x - y) / 0.693147 ) ;
2329 if (y == 0) {
2330 importances[ny] = SROC - 0.5;
2331 continue;
2332 }
2333
2334 // getting ROC
2335 SSROC = ROC[y];
2336 importances[ny] += SROC - SSROC;
2337 // cleaning information
2338 }
2339 }
2340 }
2341 std::cout << "--- Variable Importance Results (All)" << std::endl;
2342 return GetImportance(nbits, importances, varNames);
2343}
2344
/// Return the bit mask with the lowest `i` bits set: 2^i - 1 == sum_{n=0}^{i-1} 2^n.
///
/// EvaluateImportanceShort uses the result as the seed that selects all `nbits`
/// input variables, so exactly `i` bits must be set. The previous closed form,
/// 2^(i+1) - 1, set one bit too many. That produced a spurious sub-seed and an
/// out-of-range write to `importances[nbits]` there.
///
/// \param[in] i number of low bits to set
/// \return 2^i - 1 for i < 64 (0 for i == 0, an empty mask); 0 for i >= 64,
///         signalling overflow to the caller
static uint64_t sum(uint64_t i)
{
   // shifting a 64-bit value by 64 or more is undefined behaviour: report overflow as 0
   if (i >= 64)
      return 0;
   return (uint64_t(1) << i) - 1;
}
2355
2356////////////////////////////////////////////////////////////////////////////////
2357
2359 const char *theOption)
2360{
2361 uint64_t x = 0;
2362 uint64_t y = 0;
2363
2364 // getting number of variables and variable names from loader
2365 const int nbits = loader->GetDataSetInfo().GetNVariables();
2366 std::vector<TString> varNames = loader->GetDataSetInfo().GetListOfVariables();
2367
2368 if (nbits > 60) {
2369 Log() << kERROR << "Number of combinations is too large , is 2^" << nbits << Endl;
2370 return nullptr;
2371 }
2372 long int range = sum(nbits);
2373 // std::cout<<range<<std::endl;
2374 // vector to save importances
2375 std::vector<Double_t> importances(nbits);
2376 for (int i = 0; i < nbits; i++)
2377 importances[i] = 0;
2378
2379 Double_t SROC, SSROC; // computed ROC value
2380
2381 x = range;
2382
2383 std::bitset<VIBITS> xbitset(x);
2384 if (x == 0)
2385 Log() << kFATAL << "Error: need at least one variable."; // data loader need at least one variable
2386
2387 // creating loader for seed
2388 TMVA::DataLoader *seedloader = new TMVA::DataLoader(xbitset.to_string());
2389
2390 // adding variables from seed
2391 for (int index = 0; index < nbits; index++) {
2392 if (xbitset[index])
2393 seedloader->AddVariable(varNames[index], 'F');
2394 }
2395
2396 // Loading Dataset
2397 DataLoaderCopy(seedloader, loader);
2398
2399 // Booking Seed
2400 BookMethod(seedloader, theMethod, methodTitle, theOption);
2401
2402 // Train/Test/Evaluation
2403 TrainAllMethods();
2404 TestAllMethods();
2405 EvaluateAllMethods();
2406
2407 // getting ROC
2408 SROC = GetROCIntegral(xbitset.to_string(), methodTitle);
2409
2410 // cleaning information to process sub-seeds
2411 TMVA::MethodBase *smethod = dynamic_cast<TMVA::MethodBase *>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
2414 delete sresults;
2415 delete seedloader;
2416 this->DeleteAllMethods();
2417 fMethodsMap.clear();
2418
2419 // removing global result because it is requiring a lot of RAM for all seeds
2420
2421 for (uint32_t i = 0; i < VIBITS; ++i) {
2422 if (x & (1 << i)) {
2423 y = x & ~(uint64_t(1) << i);
2424 std::bitset<VIBITS> ybitset(y);
2425 // need at least one variable
2426 // NOTE: if sub-seed is zero then is the special case
2427 // that count in xbitset is 1
2428 uint32_t ny = static_cast<uint32_t>(log(x - y) / 0.693147);
2429 if (y == 0) {
2430 importances[ny] = SROC - 0.5;
2431 continue;
2432 }
2433
2434 // creating loader for sub-seed
2435 TMVA::DataLoader *subseedloader = new TMVA::DataLoader(ybitset.to_string());
2436 // adding variables from sub-seed
2437 for (int index = 0; index < nbits; index++) {
2438 if (ybitset[index])
2439 subseedloader->AddVariable(varNames[index], 'F');
2440 }
2441
2442 // Loading Dataset
2443 DataLoaderCopy(subseedloader, loader);
2444
2445 // Booking SubSeed
2446 BookMethod(subseedloader, theMethod, methodTitle, theOption);
2447
2448 // Train/Test/Evaluation
2449 TrainAllMethods();
2450 TestAllMethods();
2451 EvaluateAllMethods();
2452
2453 // getting ROC
2454 SSROC = GetROCIntegral(ybitset.to_string(), methodTitle);
2455 importances[ny] += SROC - SSROC;
2456
2457 // cleaning information
2458 TMVA::MethodBase *ssmethod = dynamic_cast<TMVA::MethodBase *>(fMethodsMap[ybitset.to_string().c_str()][0][0]);
2461 delete ssresults;
2462 delete subseedloader;
2463 this->DeleteAllMethods();
2464 fMethodsMap.clear();
2465 }
2466 }
2467 std::cout << "--- Variable Importance Results (Short)" << std::endl;
2468 return GetImportance(nbits, importances, varNames);
2469}
2470
2471////////////////////////////////////////////////////////////////////////////////
2472
2474 TString methodTitle, const char *theOption)
2475{
2476 TRandom3 *rangen = new TRandom3(0); // Random Gen.
2477
2478 uint64_t x = 0;
2479 uint64_t y = 0;
2480
2481 // getting number of variables and variable names from loader
2482 const int nbits = loader->GetDataSetInfo().GetNVariables();
2483 std::vector<TString> varNames = loader->GetDataSetInfo().GetListOfVariables();
2484
2485 long int range = pow(2, nbits);
2486
2487 // vector to save importances
2488 std::vector<Double_t> importances(nbits);
2489 for (int i = 0; i < nbits; i++)
2490 importances[i] = 0;
2491
2492 Double_t SROC, SSROC; // computed ROC value
2493 for (UInt_t n = 0; n < nseeds; n++) {
2494 x = rangen->Integer(range);
2495
2496 std::bitset<32> xbitset(x);
2497 if (x == 0)
2498 continue; // data loader need at least one variable
2499
2500 // creating loader for seed
2501 TMVA::DataLoader *seedloader = new TMVA::DataLoader(xbitset.to_string());
2502
2503 // adding variables from seed
2504 for (int index = 0; index < nbits; index++) {
2505 if (xbitset[index])
2506 seedloader->AddVariable(varNames[index], 'F');
2507 }
2508
2509 // Loading Dataset
2510 DataLoaderCopy(seedloader, loader);
2511
2512 // Booking Seed
2513 BookMethod(seedloader, theMethod, methodTitle, theOption);
2514
2515 // Train/Test/Evaluation
2516 TrainAllMethods();
2517 TestAllMethods();
2518 EvaluateAllMethods();
2519
2520 // getting ROC
2521 SROC = GetROCIntegral(xbitset.to_string(), methodTitle);
2522 // std::cout << "Seed: n " << n << " x " << x << " xbitset:" << xbitset << " ROC " << SROC << std::endl;
2523
2524 // cleaning information to process sub-seeds
2525 TMVA::MethodBase *smethod = dynamic_cast<TMVA::MethodBase *>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
2528 delete sresults;
2529 delete seedloader;
2530 this->DeleteAllMethods();
2531 fMethodsMap.clear();
2532
2533 // removing global result because it is requiring a lot of RAM for all seeds
2534
2535 for (uint32_t i = 0; i < 32; ++i) {
2536 if (x & (uint64_t(1) << i)) {
2537 y = x & ~(1 << i);
2538 std::bitset<32> ybitset(y);
2539 // need at least one variable
2540 // NOTE: if sub-seed is zero then is the special case
2541 // that count in xbitset is 1
2542 Double_t ny = log(x - y) / 0.693147;
2543 if (y == 0) {
2544 importances[ny] = SROC - 0.5;
2545 // std::cout << "SubSeed: " << y << " y:" << ybitset << "ROC " << 0.5 << std::endl;
2546 continue;
2547 }
2548
2549 // creating loader for sub-seed
2550 TMVA::DataLoader *subseedloader = new TMVA::DataLoader(ybitset.to_string());
2551 // adding variables from sub-seed
2552 for (int index = 0; index < nbits; index++) {
2553 if (ybitset[index])
2554 subseedloader->AddVariable(varNames[index], 'F');
2555 }
2556
2557 // Loading Dataset
2558 DataLoaderCopy(subseedloader, loader);
2559
2560 // Booking SubSeed
2561 BookMethod(subseedloader, theMethod, methodTitle, theOption);
2562
2563 // Train/Test/Evaluation
2564 TrainAllMethods();
2565 TestAllMethods();
2566 EvaluateAllMethods();
2567
2568 // getting ROC
2569 SSROC = GetROCIntegral(ybitset.to_string(), methodTitle);
2570 importances[ny] += SROC - SSROC;
2571 // std::cout << "SubSeed: " << y << " y:" << ybitset << " x-y " << x - y << " " << std::bitset<32>(x - y) <<
2572 // " ny " << ny << " SROC " << SROC << " SSROC " << SSROC << " Importance = " << importances[ny] <<
2573 // std::endl; cleaning information
2574 TMVA::MethodBase *ssmethod =
2575 dynamic_cast<TMVA::MethodBase *>(fMethodsMap[ybitset.to_string().c_str()][0][0]);
2578 delete ssresults;
2579 delete subseedloader;
2580 this->DeleteAllMethods();
2581 fMethodsMap.clear();
2582 }
2583 }
2584 }
2585 std::cout << "--- Variable Importance Results (Random)" << std::endl;
2586 return GetImportance(nbits, importances, varNames);
2587}
2588
2589////////////////////////////////////////////////////////////////////////////////
2590
2591TH1F *TMVA::Factory::GetImportance(const int nbits, std::vector<Double_t> importances, std::vector<TString> varNames)
2592{
2593 TH1F *vih1 = new TH1F("vih1", "", nbits, 0, nbits);
2594
2595 gStyle->SetOptStat(000000);
2596
2597 Float_t normalization = 0.0;
2598 for (int i = 0; i < nbits; i++) {
2599 normalization = normalization + importances[i];
2600 }
2601
2602 Float_t roc = 0.0;
2603
2604 gStyle->SetTitleXOffset(0.4);
2605 gStyle->SetTitleXOffset(1.2);
2606
2607 std::vector<Double_t> x_ie(nbits), y_ie(nbits);
2608 for (Int_t i = 1; i < nbits + 1; i++) {
2609 x_ie[i - 1] = (i - 1) * 1.;
2610 roc = 100.0 * importances[i - 1] / normalization;
2611 y_ie[i - 1] = roc;
2612 std::cout << "--- " << varNames[i - 1] << " = " << roc << " %" << std::endl;
2613 vih1->GetXaxis()->SetBinLabel(i, varNames[i - 1].Data());
2614 vih1->SetBinContent(i, roc);
2615 }
2616 TGraph *g_ie = new TGraph(nbits + 2, &x_ie[0], &y_ie[0]);
2617 g_ie->SetTitle("");
2618
2619 vih1->LabelsOption("v >", "X");
2620 vih1->SetBarWidth(0.97);
2621 Int_t ca = TColor::GetColor("#006600");
2622 vih1->SetFillColor(ca);
2623 // Int_t ci = TColor::GetColor("#990000");
2624
2625 vih1->GetYaxis()->SetTitle("Importance (%)");
2626 vih1->GetYaxis()->SetTitleSize(0.045);
2627 vih1->GetYaxis()->CenterTitle();
2628 vih1->GetYaxis()->SetTitleOffset(1.24);
2629
2630 vih1->GetYaxis()->SetRangeUser(-7, 50);
2631 vih1->SetDirectory(nullptr);
2632
2633 // vih1->Draw("B");
2634 return vih1;
2635}
#define MinNoTrainingEvents
#define h(i)
Definition RSha256.hxx:106
void printMatrix(const TMatrixD &mat)
write a matrix
int Int_t
Definition RtypesCore.h:45
float Float_t
Definition RtypesCore.h:57
constexpr Bool_t kFALSE
Definition RtypesCore.h:101
double Double_t
Definition RtypesCore.h:59
constexpr Bool_t kTRUE
Definition RtypesCore.h:100
#define ClassImp(name)
Definition Rtypes.h:377
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t index
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
char name[80]
Definition TGX11.cxx:110
TMatrixT< Double_t > TMatrixD
Definition TMatrixDfwd.h:23
#define gROOT
Definition TROOT.h:407
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2467
R__EXTERN TStyle * gStyle
Definition TStyle.h:433
R__EXTERN TSystem * gSystem
Definition TSystem.h:560
virtual void SetTitleOffset(Float_t offset=1)
Set distance between the axis and the axis title.
Definition TAttAxis.cxx:298
virtual void SetTitleSize(Float_t size=0.04)
Set size of axis title.
Definition TAttAxis.cxx:309
virtual void SetFillColor(Color_t fcolor)
Set the fill area color.
Definition TAttFill.h:37
virtual void SetBinLabel(Int_t bin, const char *label)
Set label for bin.
Definition TAxis.cxx:886
void CenterTitle(Bool_t center=kTRUE)
Center axis title.
Definition TAxis.h:194
virtual void SetRangeUser(Double_t ufirst, Double_t ulast)
Set the viewing range for the axis from ufirst to ulast (in user coordinates, that is,...
Definition TAxis.cxx:1078
The Canvas class.
Definition TCanvas.h:23
static Int_t GetColor(const char *hexcolor)
Static method returning color number for color specified by hex color string of form: "#rrggbb",...
Definition TColor.cxx:1823
A ROOT file is composed of a header, followed by consecutive data records (TKey instances) with a wel...
Definition TFile.h:53
A TGraph is an object made of two arrays X and Y with npoints each.
Definition TGraph.h:41
void SetTitle(const char *title="") override
Change (i.e.
Definition TGraph.cxx:2370
1-D histogram with a float per channel (see TH1 documentation)}
Definition TH1.h:577
virtual void SetDirectory(TDirectory *dir)
By default, when a histogram is created, it is added to the list of histogram objects in the current ...
Definition TH1.cxx:8854
void SetTitle(const char *title) override
Change/set the title.
Definition TH1.cxx:6707
virtual void LabelsOption(Option_t *option="h", Option_t *axis="X")
Sort bins with labels or set option(s) to draw axis with labels.
Definition TH1.cxx:5353
static void AddDirectory(Bool_t add=kTRUE)
Sets the flag controlling the automatic add of histograms in memory.
Definition TH1.cxx:1267
TAxis * GetXaxis()
Definition TH1.h:322
TAxis * GetYaxis()
Definition TH1.h:323
virtual void SetBinContent(Int_t bin, Double_t content)
Set bin content see convention for numbering bins in TH1::GetBin In case the bin number is greater th...
Definition TH1.cxx:9139
virtual void SetBarWidth(Float_t width=0.5)
Set the width of bars as fraction of the bin width for drawing mode "B".
Definition TH1.h:362
Service class for 2-D histogram classes.
Definition TH2.h:30
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
TString fWeightFileDir
Definition Config.h:124
TString fWeightFileDirPrefix
Definition Config.h:123
void SetDrawProgressBar(Bool_t d)
Definition Config.h:69
void SetUseColor(Bool_t uc)
Definition Config.h:60
class TMVA::Config::VariablePlotting fVariablePlotting
void SetSilent(Bool_t s)
Definition Config.h:63
IONames & GetIONames()
Definition Config.h:98
void SetConfigDescription(const char *d)
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
void AddPreDefVal(const T &)
void SetConfigName(const char *n)
virtual void ParseOptions()
options parser
const TString & GetOptions() const
MsgLogger & Log() const
MsgLogger * fLogger
! message logger
void CheckForUnusedOptions() const
checks for unused options in option string
UInt_t GetEntries(const TString &name) const
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
DataSetInfo & GetDataSetInfo()
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Class that contains all the data information.
Definition DataSetInfo.h:62
UInt_t GetNVariables() const
virtual const char * GetName() const
Returns name of object.
Definition DataSetInfo.h:71
const TMatrixD * CorrelationMatrix(const TString &className) const
UInt_t GetNClasses() const
const TString & GetSplitOptions() const
UInt_t GetNTargets() const
DataSet * GetDataSet() const
returns data set
TH2 * CreateCorrelationMatrixHist(const TMatrixD *m, const TString &hName, const TString &hTitle) const
std::vector< TString > GetListOfVariables() const
returns list of variables
ClassInfo * GetClassInfo(Int_t clNum) const
const TCut & GetCut(Int_t i) const
VariableInfo & GetVariableInfo(Int_t i)
Bool_t IsSignal(const Event *ev) const
DataSetManager * GetDataSetManager()
DataInputHandler & DataInput()
Class that contains all the data information.
Definition DataSet.h:58
Long64_t GetNEvtSigTest()
return number of signal test events in dataset
Definition DataSet.cxx:427
TTree * GetTree(Types::ETreeType type)
create the test/trainings tree with all the variables, the weights, the classes, the targets,...
Definition DataSet.cxx:609
const Event * GetEvent() const
returns event without transformations
Definition DataSet.cxx:202
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:206
Results * GetResults(const TString &, Types::ETreeType type, Types::EAnalysisType analysistype)
Definition DataSet.cxx:265
Long64_t GetNTrainingEvents() const
Definition DataSet.h:68
void SetCurrentType(Types::ETreeType type) const
Definition DataSet.h:89
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:216
Long64_t GetNEvtBkgdTest()
return number of background test events in dataset
Definition DataSet.cxx:435
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition Event.cxx:236
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition Event.cxx:399
This is the main MVA steering class.
Definition Factory.h:80
void PrintHelpMessage(const TString &datasetname, const TString &methodTitle="") const
Print predefined help message of classifier.
Definition Factory.cxx:1333
Bool_t fCorrelations
! enable to calculate correlations
Definition Factory.h:215
std::vector< IMethod * > MVector
Definition Factory.h:84
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition Factory.cxx:1114
Bool_t Verbose(void) const
Definition Factory.h:134
void WriteDataInformation(DataSetInfo &fDataSetInfo)
Definition Factory.cxx:602
MethodBase * BookMethod(DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption="")
Book a classifier or regression method.
Definition Factory.cxx:352
Factory(TString theJobName, TFile *theTargetFile, TString theOption="")
Standard constructor.
Definition Factory.cxx:113
void TestAllMethods()
Evaluates all booked methods on the testing data and adds the output to the Results in the corresponi...
Definition Factory.cxx:1271
Bool_t fVerbose
! verbose mode
Definition Factory.h:213
void EvaluateAllMethods(void)
Iterates over all MVAs that have been booked, and calls their evaluation methods.
Definition Factory.cxx:1376
TH1F * EvaluateImportanceRandom(DataLoader *loader, UInt_t nseeds, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition Factory.cxx:2473
TH1F * GetImportance(const int nbits, std::vector< Double_t > importances, std::vector< TString > varNames)
Definition Factory.cxx:2591
Bool_t fROC
! enable to calculate ROC values
Definition Factory.h:216
void EvaluateAllVariables(DataLoader *loader, TString options="")
Iterates over all MVA input variables and evaluates them.
Definition Factory.cxx:1360
TString fVerboseLevel
! verbosity level, controls granularity of logging
Definition Factory.h:214
TMultiGraph * GetROCCurveAsMultiGraph(DataLoader *loader, UInt_t iClass, Types::ETreeType type=Types::kTesting)
Generate a collection of graphs, for all methods for a given class.
Definition Factory.cxx:988
TH1F * EvaluateImportance(DataLoader *loader, VIType vitype, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Evaluate Variable Importance.
Definition Factory.cxx:2217
Double_t GetROCIntegral(DataLoader *loader, TString theMethodName, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Calculate the integral of the ROC curve, also known as the area under curve (AUC),...
Definition Factory.cxx:849
virtual ~Factory()
Destructor.
Definition Factory.cxx:306
virtual void MakeClass(const TString &datasetname, const TString &methodTitle="") const
Definition Factory.cxx:1305
MethodBase * BookMethodWeightfile(DataLoader *dataloader, TMVA::Types::EMVA methodType, const TString &weightfile)
Adds an already constructed method to be managed by this factory.
Definition Factory.cxx:501
Bool_t fModelPersistence
! option to save the trained model in xml file or using serialization
Definition Factory.h:222
std::map< TString, Double_t > OptimizeAllMethods(TString fomType="ROCIntegral", TString fitType="FitGA")
Iterates through all booked methods and sees if they use parameter tuning and if so does just that,...
Definition Factory.cxx:701
ROCCurve * GetROC(DataLoader *loader, TString theMethodName, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Private method to generate a ROCCurve instance for a given method.
Definition Factory.cxx:749
TH1F * EvaluateImportanceShort(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition Factory.cxx:2358
Types::EAnalysisType fAnalysisType
! the training type
Definition Factory.h:221
Bool_t HasMethod(const TString &datasetname, const TString &title) const
Checks whether a given method name is defined for a given dataset.
Definition Factory.cxx:586
TGraph * GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles=kTRUE, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Argument iClass specifies the class to generate the ROC curve in a multiclass setting.
Definition Factory.cxx:912
TH1F * EvaluateImportanceAll(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition Factory.cxx:2246
void SetVerbose(Bool_t v=kTRUE)
Definition Factory.cxx:343
TFile * fgTargetFile
! ROOT output file
Definition Factory.h:205
IMethod * GetMethod(const TString &datasetname, const TString &title) const
Returns pointer to MVA that corresponds to given method title.
Definition Factory.cxx:566
void DeleteAllMethods(void)
Delete methods.
Definition Factory.cxx:324
TString fTransformations
! list of transformations to test
Definition Factory.h:212
void Greetings()
Print welcome message.
Definition Factory.cxx:295
Interface for all concrete MVA method implementations.
Definition IMethod.h:53
virtual void PrintHelpMessage() const =0
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
virtual void MakeClass(const TString &classFileName=TString("")) const =0
Virtual base Class for all MVA method.
Definition MethodBase.h:111
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
void SetSilentFile(Bool_t status)
Definition MethodBase.h:378
void SetWeightFileDir(TString fileDir)
set directory of weight file
virtual void TestRegression(Double_t &bias, Double_t &biasT, Double_t &dev, Double_t &devT, Double_t &rms, Double_t &rmsT, Double_t &mInf, Double_t &mInfT, Double_t &corr, Types::ETreeType type)
calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
TString GetMethodTypeName() const
Definition MethodBase.h:332
Bool_t DoMulticlass() const
Definition MethodBase.h:439
virtual Double_t GetSignificance() const
compute significance of mean difference
const char * GetName() const
Definition MethodBase.h:334
Types::EAnalysisType GetAnalysisType() const
Definition MethodBase.h:437
virtual TMatrixD GetMulticlassConfusionMatrix(Double_t effB, Types::ETreeType type)
Construct a confusion matrix for a multiclass classifier.
void PrintHelpMessage() const
prints out method-specific help method
virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype)
writes all MVA evaluation histograms to file
virtual void TestMulticlass()
test multiclass classification
void SetupMethod()
setup of methods
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition MethodBase.h:436
const TString & GetMethodName() const
Definition MethodBase.h:331
Bool_t DoRegression() const
Definition MethodBase.h:438
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
virtual Double_t GetTrainingEfficiency(const TString &)
DataSetInfo & DataInfo() const
Definition MethodBase.h:410
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
virtual void TestClassification()
initialization
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
void ReadStateFromFile()
Function to write options and weights to file.
virtual std::map< TString, Double_t > OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA")
call the Optimizer with the set of parameters and ranges that are meant to be tuned.
DataSetInfo & fDataSetInfo
Definition MethodBase.h:607
Types::EMVA GetMethodType() const
Definition MethodBase.h:333
void SetFile(TFile *file)
Definition MethodBase.h:375
DataSet * Data() const
Definition MethodBase.h:409
void SetModelPersistence(Bool_t status)
Definition MethodBase.h:382
virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const
calculate the area (integral) under the ROC curve as a overall quality measure of the classification
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Class for boosting a TMVA method.
Definition MethodBoost.h:58
void SetBoostedMethodName(TString methodName)
Definition MethodBoost.h:86
DataSetManager * fDataSetManager
DSMTEST.
Class for categorizing the phase space.
DataSetManager * fDataSetManager
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
void SetMinType(EMsgType minType)
Definition MsgLogger.h:70
void SetSource(const std::string &source)
Definition MsgLogger.h:68
static void InhibitOutput()
Definition MsgLogger.cxx:67
Double_t GetEffSForEffB(Double_t effB, const UInt_t num_points=41)
Calculate the signal efficiency (sensitivity) for a given background efficiency (sensitivity).
Definition ROCCurve.cxx:219
Double_t GetROCIntegral(const UInt_t points=41)
Calculates the ROC integral (AUC)
Definition ROCCurve.cxx:250
TGraph * GetROCCurve(const UInt_t points=100)
Returns a new TGraph containing the ROC curve.
Definition ROCCurve.cxx:276
Ranking for variables in method (implementation)
Definition Ranking.h:48
virtual void Print() const
get maximum length of variable names
Definition Ranking.cxx:111
Class that is the base-class for a vector of result.
Class which takes the results of a multiclass classification.
Class that is the base-class for a vector of result.
Definition Results.h:57
void FormattedOutput(const std::vector< Double_t > &, const std::vector< TString > &, const TString titleVars, const TString titleValues, MsgLogger &logger, TString format="%+1.3f")
formatted output of simple table
Definition Tools.cxx:887
void ROOTVersionMessage(MsgLogger &logger)
prints the ROOT release number and date
Definition Tools.cxx:1325
void UsefulSortDescending(std::vector< std::vector< Double_t > > &, std::vector< TString > *vs=nullptr)
sort 2D vector (AND in parallel a TString vector) in such a way that the "first vector is sorted" and...
Definition Tools.cxx:564
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition Tools.cxx:1199
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:828
const TMatrixD * GetCorrelationMatrix(const TMatrixD *covMat)
turns covariance into correlation matrix
Definition Tools.cxx:324
@ kHtmlLink
Definition Tools.h:212
void UsefulSortAscending(std::vector< std::vector< Double_t > > &, std::vector< TString > *vs=nullptr)
sort 2D vector (AND in parallel a TString vector) in such a way that the "first vector is sorted" and...
Definition Tools.cxx:538
void TMVACitation(MsgLogger &logger, ECitation citType=kPlainText)
kinds of TMVA citation
Definition Tools.cxx:1441
void TMVAVersionMessage(MsgLogger &logger)
prints the TMVA release number and date
Definition Tools.cxx:1316
void TMVAWelcomeMessage()
direct output, eg, when starting ROOT session -> no use of Logger here
Definition Tools.cxx:1302
Class that contains all the data information.
void PrintVariableRanking() const
prints ranking of input variables
Singleton class for Global types used by TMVA.
Definition Types.h:71
static Types & Instance()
The single instance of "Types" if existing already, or create it (Singleton)
Definition Types.cxx:70
@ kCategory
Definition Types.h:97
@ kMulticlass
Definition Types.h:129
@ kNoAnalysisType
Definition Types.h:130
@ kClassification
Definition Types.h:127
@ kMaxAnalysisType
Definition Types.h:131
@ kRegression
Definition Types.h:128
@ kTraining
Definition Types.h:143
const TString & GetLabel() const
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition TMultiGraph.h:34
TList * GetListOfGraphs() const
Definition TMultiGraph.h:68
virtual void Add(TGraph *graph, Option_t *chopt="")
Add a new graph to the list of graphs.
TH1F * GetHistogram()
Returns a pointer to the histogram used to draw the axis.
void Draw(Option_t *chopt="") override
Draw this multigraph with its current attributes.
TAxis * GetYaxis()
Get y axis of the graph.
TAxis * GetXaxis()
Get x axis of the graph.
TObject * Clone(const char *newname="") const override
Make a clone of an object using the Streamer facility.
Definition TNamed.cxx:74
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
Definition TNamed.cxx:164
const char * GetName() const override
Returns name of object.
Definition TNamed.h:47
TString fName
Definition TNamed.h:32
@ kOverwrite
overwrite existing object with same name
Definition TObject.h:92
virtual const char * GetName() const
Returns name of object.
Definition TObject.cxx:439
virtual Int_t Write(const char *name=nullptr, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
Definition TObject.cxx:880
void SetGrid(Int_t valuex=1, Int_t valuey=1) override
Definition TPad.h:331
TLegend * BuildLegend(Double_t x1=0.3, Double_t y1=0.21, Double_t x2=0.3, Double_t y2=0.21, const char *title="", Option_t *option="") override
Build a legend from the graphical objects in the pad.
Definition TPad.cxx:502
Principal Components Analysis (PCA)
Definition TPrincipal.h:21
virtual void AddRow(const Double_t *x)
Add a data point and update the covariance matrix.
const TMatrixD * GetCovarianceMatrix() const
Definition TPrincipal.h:59
virtual void MakePrincipals()
Perform the principal components analysis.
Random number generator class based on M.
Definition TRandom3.h:27
virtual UInt_t Integer(UInt_t imax)
Returns a random integer uniformly distributed on the interval [ 0, imax-1 ].
Definition TRandom.cxx:360
Basic string class.
Definition TString.h:139
Ssiz_t Length() const
Definition TString.h:421
void ToLower()
Change string to lower-case.
Definition TString.cxx:1170
int CompareTo(const char *cs, ECaseCompare cmp=kExact) const
Compare a string to char *cs2.
Definition TString.cxx:450
const char * Data() const
Definition TString.h:380
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition TString.h:704
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
Definition TString.h:627
Bool_t IsNull() const
Definition TString.h:418
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2356
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition TString.h:636
void SetOptStat(Int_t stat=1)
The type of information printed in the histogram statistics box can be selected via the parameter mod...
Definition TStyle.cxx:1636
void SetTitleXOffset(Float_t offset=1)
Definition TStyle.h:406
virtual int MakeDirectory(const char *name)
Make a directory.
Definition TSystem.cxx:814
Int_t Write(const char *name=nullptr, Int_t option=0, Int_t bufsize=0) override
Write this object to the current directory.
Definition TTree.cxx:9740
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16
void DataLoaderCopy(TMVA::DataLoader *des, TMVA::DataLoader *src)
Config & gConfig()
Tools & gTools()
void CreateVariableTransforms(const TString &trafoDefinition, TMVA::DataSetInfo &dataInfo, TMVA::TransformationHandler &transformationHandler, TMVA::MsgLogger &log)
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
Bool_t IsNaN(Double_t x)
Definition TMath.h:892
Definition graph.py:1
TMarker m
Definition textangle.C:8
#define VIBITS
Definition Factory.cxx:103
static uint64_t sum(uint64_t i)
Definition Factory.cxx:2345
const Int_t MinNoTrainingEvents
Definition Factory.cxx:95
#define READXML
Definition Factory.cxx:100