Logo ROOT  
Reference Guide
Loading...
Searching...
No Matches
Factory.cxx
Go to the documentation of this file.
1// @(#)Root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3// Updated by: Omar Zapata, Kim Albertsson
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : Factory *
8 * *
9 * *
10 * Description: *
11 * Implementation (see header for description) *
12 * *
13 * Authors : *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
16 * Peter Speckmayer <peter.speckmayer@cern.ch> - CERN, Switzerland *
17 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
19 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
20 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
21 * Omar Zapata <Omar.Zapata@cern.ch> - UdeA/ITM Colombia *
22 * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
23 * Sergei Gleyzer <Sergei.Gleyzer@cern.ch> - U of Florida & CERN *
24 * Kim Albertsson <kim.albertsson@cern.ch> - LTU & CERN *
25 * *
26 * Copyright (c) 2005-2015: *
27 * CERN, Switzerland *
28 * U. of Victoria, Canada *
29 * MPI-K Heidelberg, Germany *
30 * U. of Bonn, Germany *
31 * UdeA/ITM, Colombia *
32 * U. of Florida, USA *
33 * *
34 * Redistribution and use in source and binary forms, with or without *
35 * modification, are permitted according to the terms listed in LICENSE *
36 * (see tmva/doc/LICENSE) *
37 **********************************************************************************/
38
39/*! \class TMVA::Factory
40\ingroup TMVA
41
42This is the main MVA steering class.
43It creates all MVA methods, and guides them through the training, testing and
44evaluation phases.
45*/
46
#include "TMVA/Factory.h"

#include "TMVA/Config.h"
#include "TMVA/Configurable.h"
#include "TMVA/Tools.h"
#include "TMVA/Ranking.h"
#include "TMVA/DataSet.h"
#include "TMVA/IMethod.h"
#include "TMVA/MethodBase.h"
#include "TMVA/DataSetManager.h"
#include "TMVA/DataSetInfo.h"
#include "TMVA/DataLoader.h"
#include "TMVA/MethodBoost.h"
#include "TMVA/MethodCategory.h"
#include "TMVA/ROCCalc.h"
#include "TMVA/ROCCurve.h"
#include "TMVA/MsgLogger.h"

#include "TMVA/VariableInfo.h"

#include "TMVA/Results.h"
#include "TMVA/ResultsClassification.h"
#include "TMVA/ResultsMulticlass.h"

#include "TMVA/Types.h"

#include <algorithm>
#include <bitset>
#include <list>
#include <set>

#include "TROOT.h"
#include "TFile.h"
#include "TLeaf.h"
#include "TEventList.h"
#include "TH2.h"
#include "TGraph.h"
#include "TStyle.h"
#include "TMatrixF.h"
#include "TMatrixDSym.h"
#include "TMultiGraph.h"
#include "TPrincipal.h"
#include "TMath.h"
#include "TSystem.h"
#include "TCanvas.h"
96// const Int_t MinNoTestEvents = 1;
97
98
99#define READXML kTRUE
100
101// number of bits for bitset
102#define VIBITS 32
103
104////////////////////////////////////////////////////////////////////////////////
105/// Standard constructor.
106///
107/// - jobname : this name will appear in all weight file names produced by the MVAs
108/// - theTargetFile : output ROOT file; the test tree and all evaluation plots
109/// will be stored here
110/// - theOption : option string; currently: "V" for verbose
111
112TMVA::Factory::Factory(TString jobName, TFile *theTargetFile, TString theOption)
114 fROC(kTRUE), fSilentFile(theTargetFile == nullptr), fJobName(jobName), fAnalysisType(Types::kClassification),
116{
117 fName = "Factory";
118 fgTargetFile = theTargetFile;
119 fLogger->SetSource(fName.Data());
120
121 // render silent
122 if (gTools().CheckForSilentOption(GetOptions()))
123 Log().InhibitOutput(); // make sure is silent if wanted to
124
125 // init configurable
126 SetConfigDescription("Configuration options for Factory running");
128
129 // histograms are not automatically associated with the current
130 // directory and hence don't go out of scope when closing the file
131 // TH1::AddDirectory(kFALSE);
132 Bool_t silent = kFALSE;
133#ifdef WIN32
134 // under Windows, switch progress bar and color off by default, as the typical windows shell doesn't handle these
135 // (would need different sequences..)
136 Bool_t color = kFALSE;
137 Bool_t drawProgressBar = kFALSE;
138#else
139 Bool_t color = !gROOT->IsBatch();
140 Bool_t drawProgressBar = kTRUE;
141#endif
142 DeclareOptionRef(fVerbose, "V", "Verbose flag");
143 DeclareOptionRef(fVerboseLevel = TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)");
144 AddPreDefVal(TString("Debug"));
145 AddPreDefVal(TString("Verbose"));
146 AddPreDefVal(TString("Info"));
147 DeclareOptionRef(color, "Color", "Flag for coloured screen output (default: True, if in batch mode: False)");
149 fTransformations, "Transformations",
150 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for identity, "
151 "decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation transformations");
152 DeclareOptionRef(fCorrelations, "Correlations", "boolean to show correlation in output");
153 DeclareOptionRef(fROC, "ROC", "boolean to show ROC in output");
154 DeclareOptionRef(silent, "Silent",
155 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
156 "class object (default: False)");
157 DeclareOptionRef(drawProgressBar, "DrawProgressBar",
158 "Draw progress bar to display training, testing and evaluation schedule (default: True)");
159 DeclareOptionRef(fModelPersistence, "ModelPersistence",
160 "Option to save the trained model in xml file or using serialization");
161
162 TString analysisType("Auto");
163 DeclareOptionRef(analysisType, "AnalysisType",
164 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
165 AddPreDefVal(TString("Classification"));
166 AddPreDefVal(TString("Regression"));
167 AddPreDefVal(TString("Multiclass"));
168 AddPreDefVal(TString("Auto"));
169
170 ParseOptions();
172
173 if (Verbose())
174 fLogger->SetMinType(kVERBOSE);
175 if (fVerboseLevel.CompareTo("Debug") == 0)
176 fLogger->SetMinType(kDEBUG);
177 if (fVerboseLevel.CompareTo("Verbose") == 0)
178 fLogger->SetMinType(kVERBOSE);
179 if (fVerboseLevel.CompareTo("Info") == 0)
180 fLogger->SetMinType(kINFO);
181
182 // global settings
183 gConfig().SetUseColor(color);
184 gConfig().SetSilent(silent);
185 gConfig().SetDrawProgressBar(drawProgressBar);
186
187 analysisType.ToLower();
188 if (analysisType == "classification")
190 else if (analysisType == "regression")
192 else if (analysisType == "multiclass")
194 else if (analysisType == "auto")
196
197 // Greetings();
198}
199
200////////////////////////////////////////////////////////////////////////////////
201/// Constructor.
202
205 fSilentFile(kTRUE), fJobName(jobName), fAnalysisType(Types::kClassification), fModelPersistence(kTRUE)
206{
207 fName = "Factory";
208 fgTargetFile = nullptr;
209 fLogger->SetSource(fName.Data());
210
211 // render silent
212 if (gTools().CheckForSilentOption(GetOptions()))
213 Log().InhibitOutput(); // make sure is silent if wanted to
214
215 // init configurable
216 SetConfigDescription("Configuration options for Factory running");
218
219 // histograms are not automatically associated with the current
220 // directory and hence don't go out of scope when closing the file
222 Bool_t silent = kFALSE;
223#ifdef WIN32
224 // under Windows, switch progress bar and color off by default, as the typical windows shell doesn't handle these
225 // (would need different sequences..)
226 Bool_t color = kFALSE;
227 Bool_t drawProgressBar = kFALSE;
228#else
229 Bool_t color = !gROOT->IsBatch();
230 Bool_t drawProgressBar = kTRUE;
231#endif
232 DeclareOptionRef(fVerbose, "V", "Verbose flag");
233 DeclareOptionRef(fVerboseLevel = TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)");
234 AddPreDefVal(TString("Debug"));
235 AddPreDefVal(TString("Verbose"));
236 AddPreDefVal(TString("Info"));
237 DeclareOptionRef(color, "Color", "Flag for coloured screen output (default: True, if in batch mode: False)");
239 fTransformations, "Transformations",
240 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for identity, "
241 "decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation transformations");
242 DeclareOptionRef(fCorrelations, "Correlations", "boolean to show correlation in output");
243 DeclareOptionRef(fROC, "ROC", "boolean to show ROC in output");
244 DeclareOptionRef(silent, "Silent",
245 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
246 "class object (default: False)");
247 DeclareOptionRef(drawProgressBar, "DrawProgressBar",
248 "Draw progress bar to display training, testing and evaluation schedule (default: True)");
249 DeclareOptionRef(fModelPersistence, "ModelPersistence",
250 "Option to save the trained model in xml file or using serialization");
251
252 TString analysisType("Auto");
253 DeclareOptionRef(analysisType, "AnalysisType",
254 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
255 AddPreDefVal(TString("Classification"));
256 AddPreDefVal(TString("Regression"));
257 AddPreDefVal(TString("Multiclass"));
258 AddPreDefVal(TString("Auto"));
259
260 ParseOptions();
262
263 if (Verbose())
264 fLogger->SetMinType(kVERBOSE);
265 if (fVerboseLevel.CompareTo("Debug") == 0)
266 fLogger->SetMinType(kDEBUG);
267 if (fVerboseLevel.CompareTo("Verbose") == 0)
268 fLogger->SetMinType(kVERBOSE);
269 if (fVerboseLevel.CompareTo("Info") == 0)
270 fLogger->SetMinType(kINFO);
271
272 // global settings
273 gConfig().SetUseColor(color);
274 gConfig().SetSilent(silent);
275 gConfig().SetDrawProgressBar(drawProgressBar);
276
277 analysisType.ToLower();
278 if (analysisType == "classification")
280 else if (analysisType == "regression")
282 else if (analysisType == "multiclass")
284 else if (analysisType == "auto")
286
287 Greetings();
288}
289
290////////////////////////////////////////////////////////////////////////////////
291/// Print welcome message.
292/// Options are: kLogoWelcomeMsg, kIsometricWelcomeMsg, kLeanWelcomeMsg
293
295{
297 gTools().TMVAWelcomeMessage(Log(), gTools().kLogoWelcomeMsg);
299 Log() << Endl;
300}
301
302////////////////////////////////////////////////////////////////////////////////
303/// Destructor.
304
306{
307 std::vector<TMVA::VariableTransformBase *>::iterator trfIt = fDefaultTrfs.begin();
308 for (; trfIt != fDefaultTrfs.end(); ++trfIt)
309 delete (*trfIt);
310
311 this->DeleteAllMethods();
312
313 // problem with call of REGISTER_METHOD macro ...
314 // ClassifierFactory::DestroyInstance();
315 // Types::DestroyInstance();
316 // Tools::DestroyInstance();
317 // Config::DestroyInstance();
318}
319
320////////////////////////////////////////////////////////////////////////////////
321/// Delete methods.
322
324{
325 std::map<TString, MVector *>::iterator itrMap;
326
327 for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap) {
328 MVector *methods = itrMap->second;
329 // delete methods
330 MVector::iterator itrMethod = methods->begin();
331 for (; itrMethod != methods->end(); ++itrMethod) {
332 Log() << kDEBUG << "Delete method: " << (*itrMethod)->GetName() << Endl;
333 delete (*itrMethod);
334 }
335 methods->clear();
336 delete methods;
337 }
338}
339
340////////////////////////////////////////////////////////////////////////////////
341
346
347////////////////////////////////////////////////////////////////////////////////
348/// Books an MVA classifier or regression method. The option configuration
349/// string is custom for each MVA. The TString field `theNameAppendix` serves to
350/// define (and distinguish) several instances of a given MVA, e.g., when one
351/// wants to compare the performance of various configurations
352///
353/// The method is identified by `theMethodName`, which can be provided either as:
354/// - a string containing the method's name, or
355/// - a `TMVA::Types::EMVA` enum value (automatically converted to the corresponding method name).
356
// NOTE(review): this doxygen-extracted listing is missing several original source
// lines in this function (the signature line 357, the opening of the auto-analysis
// "if" at line 365, the weight-file prefix lookup at 401, and the
// "if (fModelPersistence)" / SetModelPersistence / SetAnalysisType lines at
// 401, 423, 425, 444, 446, 452, 455, 465, 467-468). Code below is byte-identical
// to the listing; compare against the upstream tmva/src/Factory.cxx before reuse.
358 TString methodTitle, TString theOption)
359{
361 gSystem->MakeDirectory(loader->GetName()); // creating directory for DataLoader output
362
363 TString datasetname = loader->GetName();
364
// Auto-detect the analysis type from the attached DataLoader: two classes named
// "Signal"/"Background" -> classification; otherwise >= 2 classes -> multiclass.
// (The enclosing "if" that guards this on Types::kAuto was dropped by the extraction.)
366 if (loader->GetDataSetInfo().GetNClasses() == 2 && loader->GetDataSetInfo().GetClassInfo("Signal") != NULL &&
367 loader->GetDataSetInfo().GetClassInfo("Background") != NULL) {
368 fAnalysisType = Types::kClassification; // default is classification
369 } else if (loader->GetDataSetInfo().GetNClasses() >= 2) {
370 fAnalysisType = Types::kMulticlass; // if two classes, but not named "Signal" and "Background"
371 } else
372 Log() << kFATAL << "No analysis type for " << loader->GetDataSetInfo().GetNClasses() << " classes and "
373 << loader->GetDataSetInfo().GetNTargets() << " regression targets." << Endl;
374 }
375
376 // booking via name; the names are translated into enums and the
377 // corresponding overloaded BookMethod is called
378
// Refuse duplicate (dataset, title) bookings; kFATAL aborts via the logger.
379 if (fMethodsMap.find(datasetname) != fMethodsMap.end()) {
380 if (GetMethod(datasetname, methodTitle) != 0) {
381 Log() << kFATAL << "Booking failed since method with title <" << methodTitle << "> already exists "
382 << "in with DataSet Name <" << loader->GetName() << "> " << Endl;
383 }
384 }
385
386 Log() << kHEADER << "Booking method: " << gTools().Color("bold")
387 << methodTitle
388 // << gTools().Color("reset")<<" DataSet Name: "<<gTools().Color("bold")<<loader->GetName()
389 << gTools().Color("reset") << Endl << Endl;
390
391 // interpret option string with respect to a request for boosting (i.e., BostNum > 0)
392 Int_t boostNum = 0;
393 TMVA::Configurable *conf = new TMVA::Configurable(theOption);
394 conf->DeclareOptionRef(boostNum = 0, "Boost_num", "Number of times the classifier will be boosted");
395 conf->ParseOptions();
396 delete conf;
397 // this is name of weight file directory
398 TString fileDir;
399 if (fModelPersistence) {
400 // find prefix in fWeightFileDir;
// (dropped line 401 presumably fetched the configured prefix into `prefix` --
// likely gConfig().GetIONames().fWeightFileDirPrefix; confirm against upstream)
402 fileDir = prefix;
403 if (!prefix.IsNull())
404 if (fileDir[fileDir.Length() - 1] != '/')
405 fileDir += "/";
406 fileDir += loader->GetName();
407 fileDir += "/" + gConfig().GetIONames().fWeightFileDir;
408 }
409 // initialize methods
410 IMethod *im;
411 if (!boostNum) {
412 im = ClassifierFactory::Instance().Create(theMethodName.tString().Data(), fJobName, methodTitle, loader->GetDataSetInfo(),
413 theOption);
414 } else {
415 // boosted classifier, requires a specific definition, making it transparent for the user
416 Log() << kDEBUG << "Boost Number is " << boostNum << " > 0: train boosted classifier" << Endl;
417 im = ClassifierFactory::Instance().Create("Boost", fJobName, methodTitle, loader->GetDataSetInfo(), theOption);
418 MethodBoost *methBoost = dynamic_cast<MethodBoost *>(im); // DSMTEST divided into two lines
419 if (!methBoost) { // DSMTEST
420 Log() << kFATAL << "Method with type kBoost cannot be casted to MethodCategory. /Factory" << Endl; // DSMTEST
421 return nullptr;
422 }
424 methBoost->SetWeightFileDir(fileDir);
426 methBoost->SetBoostedMethodName(theMethodName.tString()); // DSMTEST divided into two lines
427 methBoost->fDataSetManager = loader->GetDataSetInfo().GetDataSetManager(); // DSMTEST
428 methBoost->SetFile(fgTargetFile);
429 methBoost->SetSilentFile(IsSilentFile());
430 }
431
432 MethodBase *method = dynamic_cast<MethodBase *>(im);
433 if (method == 0)
434 return 0; // could not create method
435
436 // set fDataSetManager if MethodCategory (to enable Category to create datasetinfo objects) // DSMTEST
437 if (method->GetMethodType() == Types::kCategory) { // DSMTEST
438 MethodCategory *methCat = (dynamic_cast<MethodCategory *>(im)); // DSMTEST
439 if (!methCat) { // DSMTEST
440 Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Factory"
441 << Endl; // DSMTEST
442 return nullptr;
443 }
445 methCat->SetWeightFileDir(fileDir);
447 methCat->fDataSetManager = loader->GetDataSetInfo().GetDataSetManager(); // DSMTEST
448 methCat->SetFile(fgTargetFile);
449 methCat->SetSilentFile(IsSilentFile());
450 } // DSMTEST
451
// Reject methods that cannot handle the requested analysis type / class-target
// multiplicity (the "if (!method->HasAnalysisType(..." opener was dropped).
453 loader->GetDataSetInfo().GetNTargets())) {
454 Log() << kWARNING << "Method " << method->GetMethodTypeName() << " is not capable of handling ";
456 Log() << "regression with " << loader->GetDataSetInfo().GetNTargets() << " targets." << Endl;
457 } else if (fAnalysisType == Types::kMulticlass) {
458 Log() << "multiclass classification with " << loader->GetDataSetInfo().GetNClasses() << " classes." << Endl;
459 } else {
460 Log() << "classification with " << loader->GetDataSetInfo().GetNClasses() << " classes." << Endl;
461 }
462 return 0;
463 }
464
// Configure the freshly created method and run its option parsing/setup chain.
466 method->SetWeightFileDir(fileDir);
469 method->SetupMethod();
470 method->ParseOptions();
471 method->ProcessSetup();
472 method->SetFile(fgTargetFile);
473 method->SetSilentFile(IsSilentFile());
474
475 // check-for-unused-options is performed; may be overridden by derived classes
476 method->CheckSetup();
477
// Register the method under its dataset, creating the vector on first booking.
478 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
479 MVector *mvector = new MVector;
480 fMethodsMap[datasetname] = mvector;
481 }
482 fMethodsMap[datasetname]->push_back(method);
483 return method;
484}
485
486////////////////////////////////////////////////////////////////////////////////
487/// Adds an already constructed method to be managed by this factory.
488///
489/// \note Private.
490/// \note Know what you are doing when using this method. The method that you
491/// are loading could be trained already.
492///
493
// NOTE(review): the extraction dropped this function's signature (original lines
// 494-495, presumably "TMVA::MethodBase *TMVA::Factory::BookMethodWeightfile(
// DataLoader*, Types::EMVA, const TString&)") and lines 514, 523, 525-526, 531.
// Code below is byte-identical to the listing.
496{
497 TString datasetname = loader->GetName();
498 std::string methodTypeName = std::string(Types::Instance().GetMethodName(methodType).Data());
499 DataSetInfo &dsi = loader->GetDataSetInfo();
500
// Create the method directly from a weight file instead of training it.
501 IMethod *im = ClassifierFactory::Instance().Create(methodTypeName, dsi, weightfile);
502 MethodBase *method = (dynamic_cast<MethodBase *>(im));
503
504 if (method == nullptr)
505 return nullptr;
506
507 if (method->GetMethodType() == Types::kCategory) {
508 Log() << kERROR << "Cannot handle category methods for now." << Endl;
509 }
510
511 TString fileDir;
512 if (fModelPersistence) {
513 // find prefix in fWeightFileDir;
// (dropped line 514 presumably fetched the configured prefix into `prefix`)
515 fileDir = prefix;
516 if (!prefix.IsNull())
517 if (fileDir[fileDir.Length() - 1] != '/')
518 fileDir += "/";
// NOTE(review): plain assignment below discards the prefix built above; the
// sibling BookMethod uses "fileDir += loader->GetName();" -- looks like this
// should be "+=" too. Confirm against upstream before changing.
519 fileDir = loader->GetName();
520 fileDir += "/" + gConfig().GetIONames().fWeightFileDir;
521 }
522
524 method->SetWeightFileDir(fileDir);
527 method->SetupMethod();
528 method->SetFile(fgTargetFile);
529 method->SetSilentFile(IsSilentFile());
530
532
533 // read weight file
534 method->ReadStateFromFile();
535
536 // method->CheckSetup();
537
// Refuse duplicate (dataset, title) registrations; kFATAL aborts via the logger.
538 TString methodTitle = method->GetName();
539 if (HasMethod(datasetname, methodTitle) != 0) {
540 Log() << kFATAL << "Booking failed since method with title <" << methodTitle << "> already exists "
541 << "in with DataSet Name <" << loader->GetName() << "> " << Endl;
542 }
543
544 Log() << kINFO << "Booked classifier \"" << method->GetMethodName() << "\" of type: \""
545 << method->GetMethodTypeName() << "\"" << Endl;
546
547 if (fMethodsMap.count(datasetname) == 0) {
548 MVector *mvector = new MVector;
549 fMethodsMap[datasetname] = mvector;
550 }
551
552 fMethodsMap[datasetname]->push_back(method);
553
554 return method;
555}
556
557////////////////////////////////////////////////////////////////////////////////
558/// Returns pointer to MVA that corresponds to given method title.
559
560TMVA::IMethod *TMVA::Factory::GetMethod(const TString &datasetname, const TString &methodTitle) const
561{
562 if (fMethodsMap.find(datasetname) == fMethodsMap.end())
563 return 0;
564
565 MVector *methods = fMethodsMap.find(datasetname)->second;
566
567 MVector::const_iterator itrMethod;
568 //
569 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
570 MethodBase *mva = dynamic_cast<MethodBase *>(*itrMethod);
571 if ((mva->GetMethodName()) == methodTitle)
572 return mva;
573 }
574 return 0;
575}
576
577////////////////////////////////////////////////////////////////////////////////
578/// Checks whether a given method name is defined for a given dataset.
579
580Bool_t TMVA::Factory::HasMethod(const TString &datasetname, const TString &methodTitle) const
581{
582 if (fMethodsMap.find(datasetname) == fMethodsMap.end())
583 return 0;
584
585 std::string methodName = methodTitle.Data();
586 auto isEqualToMethodName = [&methodName](TMVA::IMethod *m) { return (0 == methodName.compare(m->GetName())); };
587
588 TMVA::Factory::MVector *methods = this->fMethodsMap.at(datasetname);
589 Bool_t isMethodNameExisting = std::any_of(methods->begin(), methods->end(), isEqualToMethodName);
590
591 return isMethodNameExisting;
592}
593
594////////////////////////////////////////////////////////////////////////////////
595
597{
598 RootBaseDir()->cd();
599
600 if (!RootBaseDir()->GetDirectory(fDataSetInfo.GetName()))
601 RootBaseDir()->mkdir(fDataSetInfo.GetName());
602 else
603 return; // loader is now in the output file, we dont need to save again
604
605 RootBaseDir()->cd(fDataSetInfo.GetName());
606 fDataSetInfo.GetDataSet(); // builds dataset (including calculation of correlation matrix)
607
608 // correlation matrix of the default DS
609 const TMatrixD *m(0);
610 const TH2 *h(0);
611
613 for (UInt_t cls = 0; cls < fDataSetInfo.GetNClasses(); cls++) {
614 m = fDataSetInfo.CorrelationMatrix(fDataSetInfo.GetClassInfo(cls)->GetName());
615 h = fDataSetInfo.CreateCorrelationMatrixHist(
616 m, TString("CorrelationMatrix") + fDataSetInfo.GetClassInfo(cls)->GetName(),
617 TString("Correlation Matrix (") + fDataSetInfo.GetClassInfo(cls)->GetName() + TString(")"));
618 if (h != 0) {
619 h->Write();
620 delete h;
621 }
622 }
623 } else {
624 m = fDataSetInfo.CorrelationMatrix("Signal");
625 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrixS", "Correlation Matrix (signal)");
626 if (h != 0) {
627 h->Write();
628 delete h;
629 }
630
631 m = fDataSetInfo.CorrelationMatrix("Background");
632 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrixB", "Correlation Matrix (background)");
633 if (h != 0) {
634 h->Write();
635 delete h;
636 }
637
638 m = fDataSetInfo.CorrelationMatrix("Regression");
639 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrix", "Correlation Matrix");
640 if (h != 0) {
641 h->Write();
642 delete h;
643 }
644 }
645
646 // some default transformations to evaluate
647 // NOTE: all transformations are destroyed after this test
648 TString processTrfs = "I"; //"I;N;D;P;U;G,D;"
649
650 // plus some user defined transformations
651 processTrfs = fTransformations;
652
653 // remove any trace of identity transform - if given (avoid to apply it twice)
654 std::vector<TMVA::TransformationHandler *> trfs;
655 TransformationHandler *identityTrHandler = 0;
656
657 std::vector<TString> trfsDef = gTools().SplitString(processTrfs, ';');
658 std::vector<TString>::iterator trfsDefIt = trfsDef.begin();
659 for (; trfsDefIt != trfsDef.end(); ++trfsDefIt) {
660 trfs.push_back(new TMVA::TransformationHandler(fDataSetInfo, "Factory"));
661 TString trfS = (*trfsDefIt);
662
663 // Log() << kINFO << Endl;
664 Log() << kDEBUG << "current transformation string: '" << trfS.Data() << "'" << Endl;
665 TMVA::CreateVariableTransforms(trfS, fDataSetInfo, *(trfs.back()), Log());
666
667 if (trfS.BeginsWith('I'))
668 identityTrHandler = trfs.back();
669 }
670
671 const std::vector<Event *> &inputEvents = fDataSetInfo.GetDataSet()->GetEventCollection();
672
673 // apply all transformations
674 std::vector<TMVA::TransformationHandler *>::iterator trfIt = trfs.begin();
675
676 for (; trfIt != trfs.end(); ++trfIt) {
677 // setting a Root dir causes the variables distributions to be saved to the root file
678 (*trfIt)->SetRootDir(RootBaseDir()->GetDirectory(fDataSetInfo.GetName())); // every dataloader have its own dir
679 (*trfIt)->CalcTransformations(inputEvents);
680 }
681 if (identityTrHandler)
682 identityTrHandler->PrintVariableRanking();
683
684 // clean up
685 for (trfIt = trfs.begin(); trfIt != trfs.end(); ++trfIt)
686 delete *trfIt;
687}
688
689////////////////////////////////////////////////////////////////////////////////
690/// Iterates through all booked methods and sees if they use parameter tuning and if so
691/// does just that, i.e.\ calls "Method::Train()" for different parameter settings and
692/// keeps in mind the "optimal one"...\ and that's the one that will later on be used
693/// in the main training loop.
694
695std::map<TString, Double_t> TMVA::Factory::OptimizeAllMethods(TString fomType, TString fitType)
696{
697
698 std::map<TString, MVector *>::iterator itrMap;
699 std::map<TString, Double_t> TunedParameters;
700 for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap) {
701 MVector *methods = itrMap->second;
702
703 MVector::iterator itrMethod;
704
705 // iterate over methods and optimize
706 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
708 MethodBase *mva = dynamic_cast<MethodBase *>(*itrMethod);
709 if (!mva) {
710 Log() << kFATAL << "Dynamic cast to MethodBase failed" << Endl;
711 return TunedParameters;
712 }
713
715 Log() << kWARNING << "Method " << mva->GetMethodName() << " not trained (training tree has less entries ["
716 << mva->Data()->GetNTrainingEvents() << "] than required [" << MinNoTrainingEvents << "]" << Endl;
717 continue;
718 }
719
720 Log() << kINFO << "Optimize method: " << mva->GetMethodName() << " for "
722 ? "Regression"
723 : (fAnalysisType == Types::kMulticlass ? "Multiclass classification" : "Classification"))
724 << Endl;
725
726 TunedParameters = mva->OptimizeTuningParameters(fomType, fitType);
727 Log() << kINFO << "Optimization of tuning parameters finished for Method:" << mva->GetName() << Endl;
728 }
729 }
730
731 return TunedParameters;
732}
733
734////////////////////////////////////////////////////////////////////////////////
735/// Private method to generate a ROCCurve instance for a given method.
736/// Handles the conversion from TMVA ResultSet to a format the ROCCurve class
737/// understands.
738///
739/// \note You own the returned pointer.
740///
741
744{
745 return GetROC((TString)loader->GetName(), theMethodName, iClass, type);
746}
747
748////////////////////////////////////////////////////////////////////////////////
749/// Private method to generate a ROCCurve instance for a given method.
750/// Handles the conversion from TMVA ResultSet to a format the ROCCurve class
751/// understands.
752///
753/// \note You own the returned pointer.
754///
755
757{
758 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
759 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
760 return nullptr;
761 }
762
763 if (!this->HasMethod(datasetname, theMethodName)) {
764 Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data())
765 << Endl;
766 return nullptr;
767 }
768
769 std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
770 if (allowedAnalysisTypes.count(this->fAnalysisType) == 0) {
771 Log() << kERROR << Form("Can only generate ROC curves for analysis type kClassification and kMulticlass.")
772 << Endl;
773 return nullptr;
774 }
775
776 TMVA::MethodBase *method = dynamic_cast<TMVA::MethodBase *>(this->GetMethod(datasetname, theMethodName));
777 TMVA::DataSet *dataset = method->Data();
778 dataset->SetCurrentType(type);
779 TMVA::Results *results = dataset->GetResults(theMethodName, type, this->fAnalysisType);
780
781 UInt_t nClasses = method->DataInfo().GetNClasses();
782 if (this->fAnalysisType == Types::kMulticlass && iClass >= nClasses) {
783 Log() << kERROR
784 << Form("Given class number (iClass = %i) does not exist. There are %i classes in dataset.", iClass,
785 nClasses)
786 << Endl;
787 return nullptr;
788 }
789
790 TMVA::ROCCurve *rocCurve = nullptr;
792
793 std::vector<Float_t> *mvaRes = dynamic_cast<ResultsClassification *>(results)->GetValueVector();
794 std::vector<Bool_t> *mvaResTypes = dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
795 std::vector<Float_t> mvaResWeights;
796
797 auto eventCollection = dataset->GetEventCollection(type);
798 mvaResWeights.reserve(eventCollection.size());
799 for (auto ev : eventCollection) {
800 mvaResWeights.push_back(ev->GetWeight());
801 }
802
803 rocCurve = new TMVA::ROCCurve(*mvaRes, *mvaResTypes, mvaResWeights);
804
805 } else if (this->fAnalysisType == Types::kMulticlass) {
806 std::vector<Float_t> mvaRes;
807 std::vector<Bool_t> mvaResTypes;
808 std::vector<Float_t> mvaResWeights;
809
810 std::vector<std::vector<Float_t>> *rawMvaRes = dynamic_cast<ResultsMulticlass *>(results)->GetValueVector();
811
812 // Vector transpose due to values being stored as
813 // [ [0, 1, 2], [0, 1, 2], ... ]
814 // in ResultsMulticlass::GetValueVector.
815 mvaRes.reserve(rawMvaRes->size());
816 for (auto item : *rawMvaRes) {
817 mvaRes.push_back(item[iClass]);
818 }
819
820 auto eventCollection = dataset->GetEventCollection(type);
821 mvaResTypes.reserve(eventCollection.size());
822 mvaResWeights.reserve(eventCollection.size());
823 for (auto ev : eventCollection) {
824 mvaResTypes.push_back(ev->GetClass() == iClass);
825 mvaResWeights.push_back(ev->GetWeight());
826 }
827
828 rocCurve = new TMVA::ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
829 }
830
831 return rocCurve;
832}
833
834////////////////////////////////////////////////////////////////////////////////
835/// Calculate the integral of the ROC curve, also known as the area under curve
836/// (AUC), for a given method.
837///
838/// Argument iClass specifies the class to generate the ROC curve in a
839/// multiclass setting. It is ignored for binary classification.
840///
841
844{
845 return GetROCIntegral((TString)loader->GetName(), theMethodName, iClass, type);
846}
847
848////////////////////////////////////////////////////////////////////////////////
849/// Calculate the integral of the ROC curve, also known as the area under curve
850/// (AUC), for a given method.
851///
852/// Argument iClass specifies the class to generate the ROC curve in a
853/// multiclass setting. It is ignored for binary classification.
854///
855
857{
858 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
859 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
860 return 0;
861 }
862
863 if (!this->HasMethod(datasetname, theMethodName)) {
864 Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data())
865 << Endl;
866 return 0;
867 }
868
869 std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
870 if (allowedAnalysisTypes.count(this->fAnalysisType) == 0) {
871 Log() << kERROR << Form("Can only generate ROC integral for analysis type kClassification. and kMulticlass.")
872 << Endl;
873 return 0;
874 }
875
876 TMVA::ROCCurve *rocCurve = GetROC(datasetname, theMethodName, iClass, type);
877 if (!rocCurve) {
878 Log() << kFATAL
879 << Form("ROCCurve object was not created in Method = %s not found with Dataset = %s ", theMethodName.Data(),
880 datasetname.Data())
881 << Endl;
882 return 0;
883 }
884
886 Double_t rocIntegral = rocCurve->GetROCIntegral(npoints);
887 delete rocCurve;
888
889 return rocIntegral;
890}
891
892////////////////////////////////////////////////////////////////////////////////
893/// Argument iClass specifies the class to generate the ROC curve in a
894/// multiclass setting. It is ignored for binary classification.
895///
896/// Returns a ROC graph for a given method, or nullptr on error.
897///
898/// Note: Evaluation of the given method must have been run prior to ROC
/// generation through Factory::EvaluateAllMethods.
900///
901/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
902/// and the others considered background. This is ok in binary classification
/// but in multi-class classification, the ROC surface is an N-dimensional
904/// shape, where N is number of classes - 1.
905
906TGraph *TMVA::Factory::GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles, UInt_t iClass,
907 Types::ETreeType type)
908{
909 return GetROCCurve((TString)loader->GetName(), theMethodName, setTitles, iClass, type);
910}
911
912////////////////////////////////////////////////////////////////////////////////
913/// Argument iClass specifies the class to generate the ROC curve in a
914/// multiclass setting. It is ignored for binary classification.
915///
916/// Returns a ROC graph for a given method, or nullptr on error.
917///
918/// Note: Evaluation of the given method must have been run prior to ROC
/// generation through Factory::EvaluateAllMethods.
920///
921/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
922/// and the others considered background. This is ok in binary classification
/// but in multi-class classification, the ROC surface is an N-dimensional
924/// shape, where N is number of classes - 1.
925
926TGraph *TMVA::Factory::GetROCCurve(TString datasetname, TString theMethodName, Bool_t setTitles, UInt_t iClass,
927 Types::ETreeType type)
928{
929 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
930 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
931 return nullptr;
932 }
933
934 if (!this->HasMethod(datasetname, theMethodName)) {
935 Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data())
936 << Endl;
937 return nullptr;
938 }
939
940 std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
941 if (allowedAnalysisTypes.count(this->fAnalysisType) == 0) {
942 Log() << kERROR << Form("Can only generate ROC curves for analysis type kClassification and kMulticlass.")
943 << Endl;
944 return nullptr;
945 }
946
947 TMVA::ROCCurve *rocCurve = GetROC(datasetname, theMethodName, iClass, type);
948 TGraph *graph = nullptr;
949
950 if (!rocCurve) {
951 Log() << kFATAL
952 << Form("ROCCurve object was not created in Method = %s not found with Dataset = %s ", theMethodName.Data(),
953 datasetname.Data())
954 << Endl;
955 return nullptr;
956 }
957
958 graph = (TGraph *)rocCurve->GetROCCurve()->Clone();
959 delete rocCurve;
960
961 if (setTitles) {
962 graph->GetYaxis()->SetTitle("Background rejection (Specificity)");
963 graph->GetXaxis()->SetTitle("Signal efficiency (Sensitivity)");
964 graph->SetTitle(TString::Format("Signal efficiency vs. Background rejection (%s)", theMethodName.Data()).Data());
965 }
966
967 return graph;
968}
969
970////////////////////////////////////////////////////////////////////////////////
971/// Generate a collection of graphs, for all methods for a given class. Suitable
972/// for comparing method performance.
973///
974/// Argument iClass specifies the class to generate the ROC curve in a
975/// multiclass setting. It is ignored for binary classification.
976///
977/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
978/// and the others considered background. This is ok in binary classification
/// but in multi-class classification, the ROC surface is an N-dimensional
980/// shape, where N is number of classes - 1.
981
986
987////////////////////////////////////////////////////////////////////////////////
988/// Generate a collection of graphs, for all methods for a given class. Suitable
989/// for comparing method performance.
990///
991/// Argument iClass specifies the class to generate the ROC curve in a
992/// multiclass setting. It is ignored for binary classification.
993///
994/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
995/// and the others considered background. This is ok in binary classification
/// but in multi-class classification, the ROC surface is an N-dimensional
997/// shape, where N is number of classes - 1.
998
1000{
1001 UInt_t line_color = 1;
1002
1003 TMultiGraph *multigraph = new TMultiGraph();
1004
1005 MVector *methods = fMethodsMap[datasetname.Data()];
1006 for (auto *method_raw : *methods) {
1007 TMVA::MethodBase *method = dynamic_cast<TMVA::MethodBase *>(method_raw);
1008 if (method == nullptr) {
1009 continue;
1010 }
1011
1012 TString methodName = method->GetMethodName();
1013 UInt_t nClasses = method->DataInfo().GetNClasses();
1014
1015 if (this->fAnalysisType == Types::kMulticlass && iClass >= nClasses) {
1016 Log() << kERROR
1017 << Form("Given class number (iClass = %i) does not exist. There are %i classes in dataset.", iClass,
1018 nClasses)
1019 << Endl;
1020 continue;
1021 }
1022
1023 TString className = method->DataInfo().GetClassInfo(iClass)->GetName();
1024
1025 TGraph *graph = this->GetROCCurve(datasetname, methodName, false, iClass, type);
1026 graph->SetTitle(methodName);
1027
1028 graph->SetLineWidth(2);
1029 graph->SetLineColor(line_color++);
1030 graph->SetFillColor(10);
1031
1032 multigraph->Add(graph);
1033 }
1034
1035 if (multigraph->GetListOfGraphs() == nullptr) {
1036 Log() << kERROR << Form("No metohds have class %i defined.", iClass) << Endl;
1037 return nullptr;
1038 }
1039
1040 return multigraph;
1041}
1042
1043////////////////////////////////////////////////////////////////////////////////
1044/// Draws ROC curves for all methods booked with the factory for a given class
1045/// onto a canvas.
1046///
1047/// Argument iClass specifies the class to generate the ROC curve in a
1048/// multiclass setting. It is ignored for binary classification.
1049///
1050/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
1051/// and the others considered background. This is ok in binary classification
/// but in multi-class classification, the ROC surface is an N-dimensional
1053/// shape, where N is number of classes - 1.
1054
// NOTE(review): the declaration line of this overload is missing from the
// extracted listing; per the preceding Doxygen comment it is the DataLoader
// flavour of the canvas-producing GetROCCurve — verify against the full source.
{
   // Resolve the loader to its dataset name and forward to the TString overload.
   return GetROCCurve((TString)loader->GetName(), iClass, type);
}
1059
1060////////////////////////////////////////////////////////////////////////////////
1061/// Draws ROC curves for all methods booked with the factory for a given class.
1062///
1063/// Argument iClass specifies the class to generate the ROC curve in a
1064/// multiclass setting. It is ignored for binary classification.
1065///
1066/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
1067/// and the others considered background. This is ok in binary classification
/// but in multi-class classification, the ROC surface is an N-dimensional
1069/// shape, where N is number of classes - 1.
1070
// NOTE(review): the declaration line of this overload is missing from the
// extracted listing — verify against the full source.
{
   // The dataset must be known to the factory; otherwise nothing can be drawn.
   if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
      Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
      return 0;
   }

   // One canvas per (dataset, class) combination; caller owns the canvas.
   TString name = TString::Format("ROCCurve %s class %i", datasetname.Data(), iClass);
   TCanvas *canvas = new TCanvas(name.Data(), "ROC Curve", 200, 10, 700, 500);
   canvas->SetGrid();

   TMultiGraph *multigraph = this->GetROCCurveAsMultiGraph(datasetname, iClass, type);

   if (multigraph) {
      multigraph->Draw("AL");

      multigraph->GetYaxis()->SetTitle("Background rejection (Specificity)");
      multigraph->GetXaxis()->SetTitle("Signal efficiency (Sensitivity)");

      TString titleString = TString::Format("Signal efficiency vs. Background rejection");
      // Only multiclass ROCs are class-specific, so only then tag the title.
      if (this->fAnalysisType == Types::kMulticlass) {
         titleString = TString::Format("%s (Class=%i)", titleString.Data(), iClass);
      }

      // Workaround for TMultigraph not drawing title correctly.
      multigraph->GetHistogram()->SetTitle(titleString.Data());
      multigraph->SetTitle(titleString.Data());

      canvas->BuildLegend(0.15, 0.15, 0.35, 0.3, "MVA Method");
   }

   // Note: the canvas is returned even when no graphs could be produced.
   return canvas;
}
1104
1105////////////////////////////////////////////////////////////////////////////////
1106/// Iterates through all booked methods and calls training
1107
// NOTE(review): several source lines of this function were dropped by the
// extraction (its declaration, the conditions feeding the ternaries below,
// the pre-training data-information / minimum-events checks, and at least one
// closing brace). The code is reproduced as visible; verify any edit against
// the full source.
{
   Log() << kHEADER << gTools().Color("bold") << "Train all methods" << gTools().Color("reset") << Endl;
   // iterates over all MVAs that have been booked, and calls their training methods

   // don't do anything if no method booked
   if (fMethodsMap.empty()) {
      Log() << kINFO << "...nothing found to train" << Endl;
      return;
   }

   // here the training starts
   // Log() << kINFO << " " << Endl;
   Log() << kDEBUG << "Train all methods for "
             ? "Regression"
             : (fAnalysisType == Types::kMulticlass ? "Multiclass" : "Classification"))
         << " ..." << Endl;

   std::map<TString, MVector *>::iterator itrMap;

   // Outer loop: one entry per booked dataset.
   for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap) {
      MVector *methods = itrMap->second;
      MVector::iterator itrMethod;

      // iterate over methods and train
      for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
         MethodBase *mva = dynamic_cast<MethodBase *>(*itrMethod);

         if (mva == 0)
            continue;

         if (mva->DataInfo().GetDataSetManager()->DataInput().GetEntries() <=
             1) { // 0 entries --> 0 events, 1 entry --> dynamical dataset (or one entry)
            Log() << kFATAL << "No input data for the training provided!" << Endl;
         }

            Log() << kFATAL << "You want to do regression training without specifying a target." << Endl;
                  mva->DataInfo().GetNClasses() < 2)
            Log() << kFATAL << "You want to do classification training, but specified less than two classes." << Endl;

         // first print some information about the default dataset
         if (!IsSilentFile())

            Log() << kWARNING << "Method " << mva->GetMethodName() << " not trained (training tree has less entries ["
                  << mva->Data()->GetNTrainingEvents() << "] than required [" << MinNoTrainingEvents << "]" << Endl;
            continue;
         }

         Log() << kHEADER << "Train method: " << mva->GetMethodName() << " for "
                  ? "Regression"
                  : (fAnalysisType == Types::kMulticlass ? "Multiclass classification" : "Classification"))
               << Endl << Endl;
         mva->TrainMethod();
         Log() << kHEADER << "Training finished" << Endl << Endl;
      }


      // variable ranking
      // Log() << Endl;
      Log() << kINFO << "Ranking input variables (method specific)..." << Endl;
      for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
         MethodBase *mva = dynamic_cast<MethodBase *>(*itrMethod);
         if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {

            // create and print ranking
            const Ranking *ranking = (*itrMethod)->CreateRanking();
            if (ranking != 0)
               ranking->Print();
            else
               Log() << kINFO << "No variable ranking supplied by classifier: "
                     << dynamic_cast<MethodBase *>(*itrMethod)->GetMethodName() << Endl;
         }
      }
   }

   // save training history in case we are not in the silent mode
   if (!IsSilentFile()) {
      for (UInt_t i = 0; i < methods->size(); i++) {
         MethodBase *m = dynamic_cast<MethodBase *>((*methods)[i]);
         if (m == 0)
            continue;
         m->BaseDir()->cd();
         m->fTrainHistory.SaveHistory(m->GetMethodName());
      }
   }

   // delete all methods and recreate them from weight file - this ensures that the application
   // of the methods (in TMVAClassificationApplication) is consistent with the results obtained
   // in the testing
   // Log() << Endl;
   if (fModelPersistence) {

      Log() << kHEADER << "=== Destroy and recreate all methods via weight files for testing ===" << Endl << Endl;

      if (!IsSilentFile())
         RootBaseDir()->cd();

      // iterate through all booked methods
      for (UInt_t i = 0; i < methods->size(); i++) {

         MethodBase *m = dynamic_cast<MethodBase *>((*methods)[i]);
         if (m == nullptr)
            continue;

         TMVA::Types::EMVA methodType = m->GetMethodType();
         TString weightfile = m->GetWeightFileName();

         // decide if .txt or .xml file should be read:
         if (READXML)
            weightfile.ReplaceAll(".txt", ".xml");

         DataSetInfo &dataSetInfo = m->DataInfo();
         TString testvarName = m->GetTestvarName();
         delete m; // itrMethod[i];

         // recreate
         m = dynamic_cast<MethodBase *>(ClassifierFactory::Instance().Create(
            Types::Instance().GetMethodName(methodType).Data(), dataSetInfo, weightfile));
         if (m->GetMethodType() == Types::kCategory) {
            MethodCategory *methCat = (dynamic_cast<MethodCategory *>(m));
            if (!methCat)
               Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Factory" << Endl;
            else
               methCat->fDataSetManager = m->DataInfo().GetDataSetManager();
         }
         // ToDo, Do we need to fill the DataSetManager of MethodBoost here too?

         // Restore the per-dataset weight-file directory and method settings
         // so the recreated method behaves like the trained one.
         TString wfileDir = m->DataInfo().GetName();
         wfileDir += "/" + gConfig().GetIONames().fWeightFileDir;
         m->SetWeightFileDir(wfileDir);
         m->SetModelPersistence(fModelPersistence);
         m->SetSilentFile(IsSilentFile());
         m->SetAnalysisType(fAnalysisType);
         m->SetupMethod();
         m->ReadStateFromFile();
         m->SetTestvarName(testvarName);

         // replace trained method by newly created one (from weight file) in methods vector
         (*methods)[i] = m;
      }
   }
   }
}
1259
1260////////////////////////////////////////////////////////////////////////////////
1261/// Evaluates all booked methods on the testing data and adds the output to the
/// Results in the corresponding DataSet.
1263///
1264
// NOTE(review): the declaration line is missing from the extracted listing,
// and at least one statement at the top of the inner loop was dropped as well
// — verify against the full source.
{
   Log() << kHEADER << gTools().Color("bold") << "Test all methods" << gTools().Color("reset") << Endl;

   // don't do anything if no method booked
   if (fMethodsMap.empty()) {
      Log() << kINFO << "...nothing found to test" << Endl;
      return;
   }
   std::map<TString, MVector *>::iterator itrMap;

   // Outer loop: one entry per booked dataset.
   for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap) {
      MVector *methods = itrMap->second;
      MVector::iterator itrMethod;

      // iterate over methods and test
      for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
         MethodBase *mva = dynamic_cast<MethodBase *>(*itrMethod);
         if (mva == 0)
            continue;
         // Evaluate on the testing tree with the method's own analysis type.
         Types::EAnalysisType analysisType = mva->GetAnalysisType();
         Log() << kHEADER << "Test method: " << mva->GetMethodName() << " for "
               << (analysisType == Types::kRegression
                      ? "Regression"
                      : (analysisType == Types::kMulticlass ? "Multiclass classification" : "Classification"))
               << " performance" << Endl << Endl;
         mva->AddOutput(Types::kTesting, analysisType);
      }
   }
}
1296
1297////////////////////////////////////////////////////////////////////////////////
1298
1299void TMVA::Factory::MakeClass(const TString &datasetname, const TString &methodTitle) const
1300{
1301 if (methodTitle != "") {
1302 IMethod *method = GetMethod(datasetname, methodTitle);
1303 if (method)
1304 method->MakeClass();
1305 else {
1306 Log() << kWARNING << "<MakeClass> Could not find classifier \"" << methodTitle << "\" in list" << Endl;
1307 }
1308 } else {
1309
1310 // no classifier specified, print all help messages
1311 MVector *methods = fMethodsMap.find(datasetname)->second;
1312 MVector::const_iterator itrMethod;
1313 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1314 MethodBase *method = dynamic_cast<MethodBase *>(*itrMethod);
1315 if (method == 0)
1316 continue;
1317 Log() << kINFO << "Make response class for classifier: " << method->GetMethodName() << Endl;
1318 method->MakeClass();
1319 }
1320 }
1321}
1322
1323////////////////////////////////////////////////////////////////////////////////
1324/// Print predefined help message of classifier.
1325/// Iterate over methods and test.
1326
1327void TMVA::Factory::PrintHelpMessage(const TString &datasetname, const TString &methodTitle) const
1328{
1329 if (methodTitle != "") {
1330 IMethod *method = GetMethod(datasetname, methodTitle);
1331 if (method)
1332 method->PrintHelpMessage();
1333 else {
1334 Log() << kWARNING << "<PrintHelpMessage> Could not find classifier \"" << methodTitle << "\" in list" << Endl;
1335 }
1336 } else {
1337
1338 // no classifier specified, print all help messages
1339 MVector *methods = fMethodsMap.find(datasetname)->second;
1340 MVector::const_iterator itrMethod;
1341 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1342 MethodBase *method = dynamic_cast<MethodBase *>(*itrMethod);
1343 if (method == 0)
1344 continue;
1345 Log() << kINFO << "Print help message for classifier: " << method->GetMethodName() << Endl;
1346 method->PrintHelpMessage();
1347 }
1348 }
1349}
1350
1351////////////////////////////////////////////////////////////////////////////////
1352/// Iterates over all MVA input variables and evaluates them.
1353
// NOTE(review): the declaration line and at least two statements are missing
// from the extracted listing — in particular the in-loop declaration of `s`,
// the option string passed to BookMethod. Verify against the full source
// before editing.
{
   Log() << kINFO << "Evaluating all variables..." << Endl;

   // Book one "Variable" pseudo-method per input variable of the dataset.
   for (UInt_t i = 0; i < loader->GetDataSetInfo().GetNVariables(); i++) {
      // Propagate the verbose flag to the booked variable method.
      if (options.Contains("V"))
         s += ":V";
      this->BookMethod(loader, "Variable", s);
   }
}
1366
1367////////////////////////////////////////////////////////////////////////////////
1368/// Iterates over all MVAs that have been booked, and calls their evaluation methods.
1369
1371{
1372 Log() << kHEADER << gTools().Color("bold") << "Evaluate all methods" << gTools().Color("reset") << Endl;
1373
1374 // don't do anything if no method booked
1375 if (fMethodsMap.empty()) {
1376 Log() << kINFO << "...nothing found to evaluate" << Endl;
1377 return;
1378 }
1379 std::map<TString, MVector *>::iterator itrMap;
1380
1381 for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap) {
1382 MVector *methods = itrMap->second;
1383
1384 // -----------------------------------------------------------------------
1385 // First part of evaluation process
1386 // --> compute efficiencies, and other separation estimators
1387 // -----------------------------------------------------------------------
1388
1389 // although equal, we now want to separate the output for the variables
1390 // and the real methods
1391 Int_t isel; // will be 0 for a Method; 1 for a Variable
1392 Int_t nmeth_used[2] = {0, 0}; // 0 Method; 1 Variable
1393
1394 std::vector<std::vector<TString>> mname(2);
1395 std::vector<std::vector<Double_t>> sig(2), sep(2), roc(2);
1396 std::vector<std::vector<Double_t>> eff01(2), eff10(2), eff30(2), effArea(2);
1397 std::vector<std::vector<Double_t>> eff01err(2), eff10err(2), eff30err(2);
1398 std::vector<std::vector<Double_t>> trainEff01(2), trainEff10(2), trainEff30(2);
1399
1400 std::vector<std::vector<Float_t>> multiclass_testEff;
1401 std::vector<std::vector<Float_t>> multiclass_trainEff;
1402 std::vector<std::vector<Float_t>> multiclass_testPur;
1403 std::vector<std::vector<Float_t>> multiclass_trainPur;
1404
1405 std::vector<std::vector<Float_t>> train_history;
1406
1407 // Multiclass confusion matrices.
1408 std::vector<TMatrixD> multiclass_trainConfusionEffB01;
1409 std::vector<TMatrixD> multiclass_trainConfusionEffB10;
1410 std::vector<TMatrixD> multiclass_trainConfusionEffB30;
1411 std::vector<TMatrixD> multiclass_testConfusionEffB01;
1412 std::vector<TMatrixD> multiclass_testConfusionEffB10;
1413 std::vector<TMatrixD> multiclass_testConfusionEffB30;
1414
1415 std::vector<std::vector<Double_t>> biastrain(1); // "bias" of the regression on the training data
1416 std::vector<std::vector<Double_t>> biastest(1); // "bias" of the regression on test data
1417 std::vector<std::vector<Double_t>> devtrain(1); // "dev" of the regression on the training data
1418 std::vector<std::vector<Double_t>> devtest(1); // "dev" of the regression on test data
1419 std::vector<std::vector<Double_t>> rmstrain(1); // "rms" of the regression on the training data
1420 std::vector<std::vector<Double_t>> rmstest(1); // "rms" of the regression on test data
1421 std::vector<std::vector<Double_t>> minftrain(1); // "minf" of the regression on the training data
1422 std::vector<std::vector<Double_t>> minftest(1); // "minf" of the regression on test data
1423 std::vector<std::vector<Double_t>> rhotrain(1); // correlation of the regression on the training data
1424 std::vector<std::vector<Double_t>> rhotest(1); // correlation of the regression on test data
1425
1426 // same as above but for 'truncated' quantities (computed for events within 2sigma of RMS)
1427 std::vector<std::vector<Double_t>> biastrainT(1);
1428 std::vector<std::vector<Double_t>> biastestT(1);
1429 std::vector<std::vector<Double_t>> devtrainT(1);
1430 std::vector<std::vector<Double_t>> devtestT(1);
1431 std::vector<std::vector<Double_t>> rmstrainT(1);
1432 std::vector<std::vector<Double_t>> rmstestT(1);
1433 std::vector<std::vector<Double_t>> minftrainT(1);
1434 std::vector<std::vector<Double_t>> minftestT(1);
1435
1436 // following vector contains all methods - with the exception of Cuts, which are special
1437 MVector methodsNoCuts;
1438
1439 Bool_t doRegression = kFALSE;
1440 Bool_t doMulticlass = kFALSE;
1441
1442 // iterate over methods and evaluate
1443 for (MVector::iterator itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1445 MethodBase *theMethod = dynamic_cast<MethodBase *>(*itrMethod);
1446 if (theMethod == 0)
1447 continue;
1448 theMethod->SetFile(fgTargetFile);
1449 theMethod->SetSilentFile(IsSilentFile());
1450 if (theMethod->GetMethodType() != Types::kCuts)
1451 methodsNoCuts.push_back(*itrMethod);
1452
1453 if (theMethod->DoRegression()) {
1454 doRegression = kTRUE;
1455
1456 Log() << kINFO << "Evaluate regression method: " << theMethod->GetMethodName() << Endl;
1457 Double_t bias, dev, rms, mInf;
1458 Double_t biasT, devT, rmsT, mInfT;
1459 Double_t rho;
1460
1461 Log() << kINFO << "TestRegression (testing)" << Endl;
1462 theMethod->TestRegression(bias, biasT, dev, devT, rms, rmsT, mInf, mInfT, rho, TMVA::Types::kTesting);
1463 biastest[0].push_back(bias);
1464 devtest[0].push_back(dev);
1465 rmstest[0].push_back(rms);
1466 minftest[0].push_back(mInf);
1467 rhotest[0].push_back(rho);
1468 biastestT[0].push_back(biasT);
1469 devtestT[0].push_back(devT);
1470 rmstestT[0].push_back(rmsT);
1471 minftestT[0].push_back(mInfT);
1472
1473 Log() << kINFO << "TestRegression (training)" << Endl;
1474 theMethod->TestRegression(bias, biasT, dev, devT, rms, rmsT, mInf, mInfT, rho, TMVA::Types::kTraining);
1475 biastrain[0].push_back(bias);
1476 devtrain[0].push_back(dev);
1477 rmstrain[0].push_back(rms);
1478 minftrain[0].push_back(mInf);
1479 rhotrain[0].push_back(rho);
1480 biastrainT[0].push_back(biasT);
1481 devtrainT[0].push_back(devT);
1482 rmstrainT[0].push_back(rmsT);
1483 minftrainT[0].push_back(mInfT);
1484
1485 mname[0].push_back(theMethod->GetMethodName());
1486 nmeth_used[0]++;
1487 if (!IsSilentFile()) {
1488 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
1491 }
1492 } else if (theMethod->DoMulticlass()) {
1493 // ====================================================================
1494 // === Multiclass evaluation
1495 // ====================================================================
1496 doMulticlass = kTRUE;
1497 Log() << kINFO << "Evaluate multiclass classification method: " << theMethod->GetMethodName() << Endl;
1498
1499 // This part uses a genetic alg. to evaluate the optimal sig eff * sig pur.
1500 // This is why it is disabled for now.
1501 // Find approximate optimal working point w.r.t. signalEfficiency * signalPurity.
1502 // theMethod->TestMulticlass(); // This is where the actual GA calc is done
1503 // multiclass_testEff.push_back(theMethod->GetMulticlassEfficiency(multiclass_testPur));
1504
1505 theMethod->TestMulticlass();
1506
1507 // Confusion matrix at three background efficiency levels
1508 multiclass_trainConfusionEffB01.push_back(theMethod->GetMulticlassConfusionMatrix(0.01, Types::kTraining));
1509 multiclass_trainConfusionEffB10.push_back(theMethod->GetMulticlassConfusionMatrix(0.10, Types::kTraining));
1510 multiclass_trainConfusionEffB30.push_back(theMethod->GetMulticlassConfusionMatrix(0.30, Types::kTraining));
1511
1512 multiclass_testConfusionEffB01.push_back(theMethod->GetMulticlassConfusionMatrix(0.01, Types::kTesting));
1513 multiclass_testConfusionEffB10.push_back(theMethod->GetMulticlassConfusionMatrix(0.10, Types::kTesting));
1514 multiclass_testConfusionEffB30.push_back(theMethod->GetMulticlassConfusionMatrix(0.30, Types::kTesting));
1515
1516 if (!IsSilentFile()) {
1517 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
1520 }
1521
1522 nmeth_used[0]++;
1523 mname[0].push_back(theMethod->GetMethodName());
1524 } else {
1525
1526 Log() << kHEADER << "Evaluate classifier: " << theMethod->GetMethodName() << Endl << Endl;
1527 isel = (theMethod->GetMethodTypeName().Contains("Variable")) ? 1 : 0;
1528
1529 // perform the evaluation
1530 theMethod->TestClassification();
1531
1532 // evaluate the classifier
1533 mname[isel].push_back(theMethod->GetMethodName());
1534 sig[isel].push_back(theMethod->GetSignificance());
1535 sep[isel].push_back(theMethod->GetSeparation());
1536 roc[isel].push_back(theMethod->GetROCIntegral());
1537
1538 Double_t err;
1539 eff01[isel].push_back(theMethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err));
1540 eff01err[isel].push_back(err);
1541 eff10[isel].push_back(theMethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err));
1542 eff10err[isel].push_back(err);
1543 eff30[isel].push_back(theMethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err));
1544 eff30err[isel].push_back(err);
1545 effArea[isel].push_back(theMethod->GetEfficiency("", Types::kTesting, err)); // computes the area (average)
1546
1547 trainEff01[isel].push_back(
1548 theMethod->GetTrainingEfficiency("Efficiency:0.01")); // the first pass takes longer
1549 trainEff10[isel].push_back(theMethod->GetTrainingEfficiency("Efficiency:0.10"));
1550 trainEff30[isel].push_back(theMethod->GetTrainingEfficiency("Efficiency:0.30"));
1551
1552 nmeth_used[isel]++;
1553
1554 if (!IsSilentFile()) {
1555 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
1558 }
1559 }
1560 }
1561 if (doRegression) {
1562
1563 std::vector<TString> vtemps = mname[0];
1564 std::vector<std::vector<Double_t>> vtmp;
1565 vtmp.push_back(devtest[0]); // this is the vector that is ranked
1566 vtmp.push_back(devtrain[0]);
1567 vtmp.push_back(biastest[0]);
1568 vtmp.push_back(biastrain[0]);
1569 vtmp.push_back(rmstest[0]);
1570 vtmp.push_back(rmstrain[0]);
1571 vtmp.push_back(minftest[0]);
1572 vtmp.push_back(minftrain[0]);
1573 vtmp.push_back(rhotest[0]);
1574 vtmp.push_back(rhotrain[0]);
1575 vtmp.push_back(devtestT[0]); // this is the vector that is ranked
1576 vtmp.push_back(devtrainT[0]);
1577 vtmp.push_back(biastestT[0]);
1578 vtmp.push_back(biastrainT[0]);
1579 vtmp.push_back(rmstestT[0]);
1580 vtmp.push_back(rmstrainT[0]);
1581 vtmp.push_back(minftestT[0]);
1582 vtmp.push_back(minftrainT[0]);
1583 gTools().UsefulSortAscending(vtmp, &vtemps);
1584 mname[0] = vtemps;
1585 devtest[0] = vtmp[0];
1586 devtrain[0] = vtmp[1];
1587 biastest[0] = vtmp[2];
1588 biastrain[0] = vtmp[3];
1589 rmstest[0] = vtmp[4];
1590 rmstrain[0] = vtmp[5];
1591 minftest[0] = vtmp[6];
1592 minftrain[0] = vtmp[7];
1593 rhotest[0] = vtmp[8];
1594 rhotrain[0] = vtmp[9];
1595 devtestT[0] = vtmp[10];
1596 devtrainT[0] = vtmp[11];
1597 biastestT[0] = vtmp[12];
1598 biastrainT[0] = vtmp[13];
1599 rmstestT[0] = vtmp[14];
1600 rmstrainT[0] = vtmp[15];
1601 minftestT[0] = vtmp[16];
1602 minftrainT[0] = vtmp[17];
1603 } else if (doMulticlass) {
1604 // TODO: fill in something meaningful
1605 // If there is some ranking of methods to be done it should be done here.
1606 // However, this is not so easy to define for multiclass so it is left out for now.
1607
1608 } else {
1609 // now sort the variables according to the best 'eff at Beff=0.10'
1610 for (Int_t k = 0; k < 2; k++) {
1611 std::vector<std::vector<Double_t>> vtemp;
1612 vtemp.push_back(effArea[k]); // this is the vector that is ranked
1613 vtemp.push_back(eff10[k]);
1614 vtemp.push_back(eff01[k]);
1615 vtemp.push_back(eff30[k]);
1616 vtemp.push_back(eff10err[k]);
1617 vtemp.push_back(eff01err[k]);
1618 vtemp.push_back(eff30err[k]);
1619 vtemp.push_back(trainEff10[k]);
1620 vtemp.push_back(trainEff01[k]);
1621 vtemp.push_back(trainEff30[k]);
1622 vtemp.push_back(sig[k]);
1623 vtemp.push_back(sep[k]);
1624 vtemp.push_back(roc[k]);
1625 std::vector<TString> vtemps = mname[k];
1626 gTools().UsefulSortDescending(vtemp, &vtemps);
1627 effArea[k] = vtemp[0];
1628 eff10[k] = vtemp[1];
1629 eff01[k] = vtemp[2];
1630 eff30[k] = vtemp[3];
1631 eff10err[k] = vtemp[4];
1632 eff01err[k] = vtemp[5];
1633 eff30err[k] = vtemp[6];
1634 trainEff10[k] = vtemp[7];
1635 trainEff01[k] = vtemp[8];
1636 trainEff30[k] = vtemp[9];
1637 sig[k] = vtemp[10];
1638 sep[k] = vtemp[11];
1639 roc[k] = vtemp[12];
1640 mname[k] = vtemps;
1641 }
1642 }
1643
1644 // -----------------------------------------------------------------------
1645 // Second part of evaluation process
1646 // --> compute correlations among MVAs
1647 // --> compute correlations between input variables and MVA (determines importance)
1648 // --> count overlaps
1649 // -----------------------------------------------------------------------
1650 if (fCorrelations) {
1651 const Int_t nmeth = methodsNoCuts.size();
1652 MethodBase *method = dynamic_cast<MethodBase *>(methods[0][0]);
1653 const Int_t nvar = method->fDataSetInfo.GetNVariables();
1654 if (!doRegression && !doMulticlass) {
1655
1656 if (nmeth > 0) {
1657
1658 // needed for correlations
1659 Double_t *dvec = new Double_t[nmeth + nvar];
1660 std::vector<Double_t> rvec;
1661
1662 // for correlations
1663 TPrincipal *tpSig = new TPrincipal(nmeth + nvar, "");
1664 TPrincipal *tpBkg = new TPrincipal(nmeth + nvar, "");
1665
1666 // set required tree branch references
1667 std::vector<TString> *theVars = new std::vector<TString>;
1668 std::vector<ResultsClassification *> mvaRes;
1669 for (MVector::iterator itrMethod = methodsNoCuts.begin(); itrMethod != methodsNoCuts.end();
1670 ++itrMethod) {
1671 MethodBase *m = dynamic_cast<MethodBase *>(*itrMethod);
1672 if (m == 0)
1673 continue;
1674 theVars->push_back(m->GetTestvarName());
1675 rvec.push_back(m->GetSignalReferenceCut());
1676 theVars->back().ReplaceAll("MVA_", "");
1677 mvaRes.push_back(dynamic_cast<ResultsClassification *>(
1678 m->Data()->GetResults(m->GetMethodName(), Types::kTesting, Types::kMaxAnalysisType)));
1679 }
1680
1681 // for overlap study
1682 TMatrixD *overlapS = new TMatrixD(nmeth, nmeth);
1683 TMatrixD *overlapB = new TMatrixD(nmeth, nmeth);
1684 (*overlapS) *= 0; // init...
1685 (*overlapB) *= 0; // init...
1686
1687 // loop over test tree
1688 DataSet *defDs = method->fDataSetInfo.GetDataSet();
1690 for (Int_t ievt = 0; ievt < defDs->GetNEvents(); ievt++) {
1691 const Event *ev = defDs->GetEvent(ievt);
1692
1693 // for correlations
1694 TMatrixD *theMat = 0;
1695 for (Int_t im = 0; im < nmeth; im++) {
1696 // check for NaN value
1697 Double_t retval = (Double_t)(*mvaRes[im])[ievt][0];
1698 if (TMath::IsNaN(retval)) {
1699 Log() << kWARNING << "Found NaN return value in event: " << ievt << " for method \""
1700 << methodsNoCuts[im]->GetName() << "\"" << Endl;
1701 dvec[im] = 0;
1702 } else
1703 dvec[im] = retval;
1704 }
1705 for (Int_t iv = 0; iv < nvar; iv++)
1706 dvec[iv + nmeth] = (Double_t)ev->GetValue(iv);
1707 if (method->fDataSetInfo.IsSignal(ev)) {
1708 tpSig->AddRow(dvec);
1709 theMat = overlapS;
1710 } else {
1711 tpBkg->AddRow(dvec);
1712 theMat = overlapB;
1713 }
1714
1715 // count overlaps
1716 for (Int_t im = 0; im < nmeth; im++) {
1717 for (Int_t jm = im; jm < nmeth; jm++) {
1718 if ((dvec[im] - rvec[im]) * (dvec[jm] - rvec[jm]) > 0) {
1719 (*theMat)(im, jm)++;
1720 if (im != jm)
1721 (*theMat)(jm, im)++;
1722 }
1723 }
1724 }
1725 }
1726
1727 // renormalise overlap matrix
1728 (*overlapS) *= (1.0 / defDs->GetNEvtSigTest()); // init...
1729 (*overlapB) *= (1.0 / defDs->GetNEvtBkgdTest()); // init...
1730
1731 tpSig->MakePrincipals();
1732 tpBkg->MakePrincipals();
1733
1734 const TMatrixD *covMatS = tpSig->GetCovarianceMatrix();
1735 const TMatrixD *covMatB = tpBkg->GetCovarianceMatrix();
1736
1737 const TMatrixD *corrMatS = gTools().GetCorrelationMatrix(covMatS);
1738 const TMatrixD *corrMatB = gTools().GetCorrelationMatrix(covMatB);
1739
1740 // print correlation matrices
1741 if (corrMatS != 0 && corrMatB != 0) {
1742
1743 // extract MVA matrix
1744 TMatrixD mvaMatS(nmeth, nmeth);
1745 TMatrixD mvaMatB(nmeth, nmeth);
1746 for (Int_t im = 0; im < nmeth; im++) {
1747 for (Int_t jm = 0; jm < nmeth; jm++) {
1748 mvaMatS(im, jm) = (*corrMatS)(im, jm);
1749 mvaMatB(im, jm) = (*corrMatB)(im, jm);
1750 }
1751 }
1752
1753 // extract variables - to MVA matrix
1754 std::vector<TString> theInputVars;
1755 TMatrixD varmvaMatS(nvar, nmeth);
1756 TMatrixD varmvaMatB(nvar, nmeth);
1757 for (Int_t iv = 0; iv < nvar; iv++) {
1758 theInputVars.push_back(method->fDataSetInfo.GetVariableInfo(iv).GetLabel());
1759 for (Int_t jm = 0; jm < nmeth; jm++) {
1760 varmvaMatS(iv, jm) = (*corrMatS)(nmeth + iv, jm);
1761 varmvaMatB(iv, jm) = (*corrMatB)(nmeth + iv, jm);
1762 }
1763 }
1764
1765 if (nmeth > 1) {
1766 Log() << kINFO << Endl;
1767 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1768 << "Inter-MVA correlation matrix (signal):" << Endl;
1769 gTools().FormattedOutput(mvaMatS, *theVars, Log());
1770 Log() << kINFO << Endl;
1771
1772 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1773 << "Inter-MVA correlation matrix (background):" << Endl;
1774 gTools().FormattedOutput(mvaMatB, *theVars, Log());
1775 Log() << kINFO << Endl;
1776 }
1777
1778 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1779 << "Correlations between input variables and MVA response (signal):" << Endl;
1780 gTools().FormattedOutput(varmvaMatS, theInputVars, *theVars, Log());
1781 Log() << kINFO << Endl;
1782
1783 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1784 << "Correlations between input variables and MVA response (background):" << Endl;
1785 gTools().FormattedOutput(varmvaMatB, theInputVars, *theVars, Log());
1786 Log() << kINFO << Endl;
1787 } else
1788 Log() << kWARNING << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1789 << "<TestAllMethods> cannot compute correlation matrices" << Endl;
1790
1791 // print overlap matrices
1792 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1793 << "The following \"overlap\" matrices contain the fraction of events for which " << Endl;
1794 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1795 << "the MVAs 'i' and 'j' have returned conform answers about \"signal-likeness\"" << Endl;
1796 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1797 << "An event is signal-like, if its MVA output exceeds the following value:" << Endl;
1798 gTools().FormattedOutput(rvec, *theVars, "Method", "Cut value", Log());
1799 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1800 << "which correspond to the working point: eff(signal) = 1 - eff(background)" << Endl;
1801
1802 // give notice that cut method has been excluded from this test
1803 if (nmeth != (Int_t)methods->size())
1804 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1805 << "Note: no correlations and overlap with cut method are provided at present" << Endl;
1806
1807 if (nmeth > 1) {
1808 Log() << kINFO << Endl;
1809 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1810 << "Inter-MVA overlap matrix (signal):" << Endl;
1811 gTools().FormattedOutput(*overlapS, *theVars, Log());
1812 Log() << kINFO << Endl;
1813
1814 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1815 << "Inter-MVA overlap matrix (background):" << Endl;
1816 gTools().FormattedOutput(*overlapB, *theVars, Log());
1817 }
1818
1819 // cleanup
1820 delete tpSig;
1821 delete tpBkg;
1822 delete corrMatS;
1823 delete corrMatB;
1824 delete theVars;
1825 delete overlapS;
1826 delete overlapB;
1827 delete[] dvec;
1828 }
1829 }
1830 }
1831 // -----------------------------------------------------------------------
1832 // Third part of evaluation process
1833 // --> output
1834 // -----------------------------------------------------------------------
1835
1836 if (doRegression) {
1837
1838 Log() << kINFO << Endl;
1839 TString hLine =
1840 "--------------------------------------------------------------------------------------------------";
1841 Log() << kINFO << "Evaluation results ranked by smallest RMS on test sample:" << Endl;
1842 Log() << kINFO << "(\"Bias\" quotes the mean deviation of the regression from true target." << Endl;
1843 Log() << kINFO << " \"MutInf\" is the \"Mutual Information\" between regression and target." << Endl;
1844 Log() << kINFO << " Indicated by \"_T\" are the corresponding \"truncated\" quantities ob-" << Endl;
1845 Log() << kINFO << " tained when removing events deviating more than 2sigma from average.)" << Endl;
1846 Log() << kINFO << hLine << Endl;
1847 // Log() << kINFO << "DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf
1848 // MutInf_T" << Endl;
1849 Log() << kINFO << hLine << Endl;
1850
1851 for (Int_t i = 0; i < nmeth_used[0]; i++) {
1852 MethodBase *theMethod = dynamic_cast<MethodBase *>((*methods)[i]);
1853 if (theMethod == 0)
1854 continue;
1855
1856 Log() << kINFO
1857 << Form("%-20s %-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f", theMethod->fDataSetInfo.GetName(),
1858 (const char *)mname[0][i], biastest[0][i], biastestT[0][i], rmstest[0][i], rmstestT[0][i],
1859 minftest[0][i], minftestT[0][i])
1860 << Endl;
1861 }
1862 Log() << kINFO << hLine << Endl;
1863 Log() << kINFO << Endl;
1864 Log() << kINFO << "Evaluation results ranked by smallest RMS on training sample:" << Endl;
1865 Log() << kINFO << "(overtraining check)" << Endl;
1866 Log() << kINFO << hLine << Endl;
1867 Log() << kINFO
1868 << "DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T"
1869 << Endl;
1870 Log() << kINFO << hLine << Endl;
1871
1872 for (Int_t i = 0; i < nmeth_used[0]; i++) {
1873 MethodBase *theMethod = dynamic_cast<MethodBase *>((*methods)[i]);
1874 if (theMethod == 0)
1875 continue;
1876 Log() << kINFO
1877 << Form("%-20s %-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f", theMethod->fDataSetInfo.GetName(),
1878 (const char *)mname[0][i], biastrain[0][i], biastrainT[0][i], rmstrain[0][i], rmstrainT[0][i],
1879 minftrain[0][i], minftrainT[0][i])
1880 << Endl;
1881 }
1882 Log() << kINFO << hLine << Endl;
1883 Log() << kINFO << Endl;
1884 } else if (doMulticlass) {
1885 // ====================================================================
1886 // === Multiclass Output
1887 // ====================================================================
1888
1889 TString hLine =
1890 "-------------------------------------------------------------------------------------------------------";
1891
1892 // This part uses a genetic alg. to evaluate the optimal sig eff * sig pur.
1893 // This is why it is disabled for now.
1894 //
1895 // // --- Acheivable signal efficiency * signal purity
1896 // // --------------------------------------------------------------------
1897 // Log() << kINFO << Endl;
1898 // Log() << kINFO << "Evaluation results ranked by best signal efficiency times signal purity " << Endl;
1899 // Log() << kINFO << hLine << Endl;
1900
1901 // // iterate over methods and evaluate
1902 // for (MVector::iterator itrMethod = methods->begin(); itrMethod != methods->end(); itrMethod++) {
1903 // MethodBase *theMethod = dynamic_cast<MethodBase *>(*itrMethod);
1904 // if (theMethod == 0) {
1905 // continue;
1906 // }
1907
1908 // TString header = "DataSet Name MVA Method ";
1909 // for (UInt_t icls = 0; icls < theMethod->fDataSetInfo.GetNClasses(); ++icls) {
1910 // header += TString::Format("%-12s ", theMethod->fDataSetInfo.GetClassInfo(icls)->GetName());
1911 // }
1912
1913 // Log() << kINFO << header << Endl;
1914 // Log() << kINFO << hLine << Endl;
1915 // for (Int_t i = 0; i < nmeth_used[0]; i++) {
1916 // TString res = TString::Format("[%-14s] %-15s", theMethod->fDataSetInfo.GetName(), mname[0][i].Data());
1917 // for (UInt_t icls = 0; icls < theMethod->fDataSetInfo.GetNClasses(); ++icls) {
1918 // res += TString::Format("%#1.3f ", (multiclass_testEff[i][icls]) * (multiclass_testPur[i][icls]));
1919 // }
1920 // Log() << kINFO << res << Endl;
1921 // }
1922
1923 // Log() << kINFO << hLine << Endl;
1924 // Log() << kINFO << Endl;
1925 // }
1926
1927 // --- 1 vs Rest ROC AUC, signal efficiency @ given background efficiency
1928 // --------------------------------------------------------------------
1929 TString header1 = TString::Format("%-15s%-15s%-15s%-15s%-15s%-15s", "Dataset", "MVA Method", "ROC AUC", "Sig eff@B=0.01",
1930 "Sig eff@B=0.10", "Sig eff@B=0.30");
1931 TString header2 = TString::Format("%-15s%-15s%-15s%-15s%-15s%-15s", "Name:", "/ Class:", "test (train)", "test (train)",
1932 "test (train)", "test (train)");
1933 Log() << kINFO << Endl;
1934 Log() << kINFO << "1-vs-rest performance metrics per class" << Endl;
1935 Log() << kINFO << hLine << Endl;
1936 Log() << kINFO << Endl;
1937 Log() << kINFO << "Considers the listed class as signal and the other classes" << Endl;
1938 Log() << kINFO << "as background, reporting the resulting binary performance." << Endl;
1939 Log() << kINFO << "A score of 0.820 (0.850) means 0.820 was acheived on the" << Endl;
1940 Log() << kINFO << "test set and 0.850 on the training set." << Endl;
1941
1942 Log() << kINFO << Endl;
1943 Log() << kINFO << header1 << Endl;
1944 Log() << kINFO << header2 << Endl;
1945 for (Int_t k = 0; k < 2; k++) {
1946 for (Int_t i = 0; i < nmeth_used[k]; i++) {
1947 if (k == 1) {
1948 mname[k][i].ReplaceAll("Variable_", "");
1949 }
1950
1951 const TString datasetName = itrMap->first;
1952 const TString mvaName = mname[k][i];
1953
1954 MethodBase *theMethod = dynamic_cast<MethodBase *>(GetMethod(datasetName, mvaName));
1955 if (theMethod == 0) {
1956 continue;
1957 }
1958
1959 Log() << kINFO << Endl;
1960 TString row = TString::Format("%-15s%-15s", datasetName.Data(), mvaName.Data());
1961 Log() << kINFO << row << Endl;
1962 Log() << kINFO << "------------------------------" << Endl;
1963
1964 UInt_t numClasses = theMethod->fDataSetInfo.GetNClasses();
1965 for (UInt_t iClass = 0; iClass < numClasses; ++iClass) {
1966
1967 ROCCurve *rocCurveTrain = GetROC(datasetName, mvaName, iClass, Types::kTraining);
1968 ROCCurve *rocCurveTest = GetROC(datasetName, mvaName, iClass, Types::kTesting);
1969
1970 const TString className = theMethod->DataInfo().GetClassInfo(iClass)->GetName();
1971 const Double_t rocaucTrain = rocCurveTrain->GetROCIntegral();
1972 const Double_t effB01Train = rocCurveTrain->GetEffSForEffB(0.01);
1973 const Double_t effB10Train = rocCurveTrain->GetEffSForEffB(0.10);
1974 const Double_t effB30Train = rocCurveTrain->GetEffSForEffB(0.30);
1975 const Double_t rocaucTest = rocCurveTest->GetROCIntegral();
1976 const Double_t effB01Test = rocCurveTest->GetEffSForEffB(0.01);
1977 const Double_t effB10Test = rocCurveTest->GetEffSForEffB(0.10);
1978 const Double_t effB30Test = rocCurveTest->GetEffSForEffB(0.30);
1979 const TString rocaucCmp = TString::Format("%5.3f (%5.3f)", rocaucTest, rocaucTrain);
1980 const TString effB01Cmp = TString::Format("%5.3f (%5.3f)", effB01Test, effB01Train);
1981 const TString effB10Cmp = TString::Format("%5.3f (%5.3f)", effB10Test, effB10Train);
1982 const TString effB30Cmp = TString::Format("%5.3f (%5.3f)", effB30Test, effB30Train);
1983 row = TString::Format("%-15s%-15s%-15s%-15s%-15s%-15s", "", className.Data(), rocaucCmp.Data(), effB01Cmp.Data(),
1984 effB10Cmp.Data(), effB30Cmp.Data());
1985 Log() << kINFO << row << Endl;
1986
1987 delete rocCurveTrain;
1988 delete rocCurveTest;
1989 }
1990 }
1991 }
1992 Log() << kINFO << Endl;
1993 Log() << kINFO << hLine << Endl;
1994 Log() << kINFO << Endl;
1995
1996 // --- Confusion matrices
1997 // --------------------------------------------------------------------
1998 auto printMatrix = [](TMatrixD const &matTraining, TMatrixD const &matTesting, std::vector<TString> classnames,
1999 UInt_t numClasses, MsgLogger &stream) {
2000 // assert (classLabledWidth >= valueLabelWidth + 2)
2001 // if (...) {Log() << kWARN << "..." << Endl; }
2002
2003 // TODO: Ensure matrices are same size.
2004
2005 TString header = TString::Format(" %-14s", " ");
2006 TString headerInfo = TString::Format(" %-14s", " ");
2007
2008 for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
2009 header += TString::Format(" %-14s", classnames[iCol].Data());
2010 headerInfo += TString::Format(" %-14s", " test (train)");
2011 }
2012 stream << kINFO << header << Endl;
2013 stream << kINFO << headerInfo << Endl;
2014
2015 for (UInt_t iRow = 0; iRow < numClasses; ++iRow) {
2016 stream << kINFO << TString::Format(" %-14s", classnames[iRow].Data());
2017
2018 for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
2019 if (iCol == iRow) {
2020 stream << kINFO << TString::Format(" %-14s", "-");
2021 } else {
2022 Double_t trainValue = matTraining[iRow][iCol];
2023 Double_t testValue = matTesting[iRow][iCol];
2024 TString entry = TString::Format("%-5.3f (%-5.3f)", testValue, trainValue);
2025 stream << kINFO << TString::Format(" %-14s", entry.Data());
2026 }
2027 }
2028 stream << kINFO << Endl;
2029 }
2030 };
2031
2032 Log() << kINFO << Endl;
2033 Log() << kINFO << "Confusion matrices for all methods" << Endl;
2034 Log() << kINFO << hLine << Endl;
2035 Log() << kINFO << Endl;
2036 Log() << kINFO << "Does a binary comparison between the two classes given by a " << Endl;
2037 Log() << kINFO << "particular row-column combination. In each case, the class " << Endl;
2038 Log() << kINFO << "given by the row is considered signal while the class given " << Endl;
2039 Log() << kINFO << "by the column index is considered background." << Endl;
2040 Log() << kINFO << Endl;
2041 for (UInt_t iMethod = 0; iMethod < methods->size(); ++iMethod) {
2042 MethodBase *theMethod = dynamic_cast<MethodBase *>(methods->at(iMethod));
2043 if (theMethod == nullptr) {
2044 continue;
2045 }
2046 UInt_t numClasses = theMethod->fDataSetInfo.GetNClasses();
2047
2048 std::vector<TString> classnames;
2049 for (UInt_t iCls = 0; iCls < numClasses; ++iCls) {
2050 classnames.push_back(theMethod->fDataSetInfo.GetClassInfo(iCls)->GetName());
2051 }
2052 Log() << kINFO
2053 << "=== Showing confusion matrix for method : " << Form("%-15s", (const char *)mname[0][iMethod])
2054 << Endl;
2055 Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.01%)" << Endl;
2056 Log() << kINFO << "---------------------------------------------------" << Endl;
2057 printMatrix(multiclass_testConfusionEffB01[iMethod], multiclass_trainConfusionEffB01[iMethod], classnames,
2058 numClasses, Log());
2059 Log() << kINFO << Endl;
2060
2061 Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.10%)" << Endl;
2062 Log() << kINFO << "---------------------------------------------------" << Endl;
2063 printMatrix(multiclass_testConfusionEffB10[iMethod], multiclass_trainConfusionEffB10[iMethod], classnames,
2064 numClasses, Log());
2065 Log() << kINFO << Endl;
2066
2067 Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.30%)" << Endl;
2068 Log() << kINFO << "---------------------------------------------------" << Endl;
2069 printMatrix(multiclass_testConfusionEffB30[iMethod], multiclass_trainConfusionEffB30[iMethod], classnames,
2070 numClasses, Log());
2071 Log() << kINFO << Endl;
2072 }
2073 Log() << kINFO << hLine << Endl;
2074 Log() << kINFO << Endl;
2075
2076 } else {
2077 // Binary classification
2078 if (fROC) {
2079 Log().EnableOutput();
2081 Log() << Endl;
2082 TString hLine = "------------------------------------------------------------------------------------------"
2083 "-------------------------";
2084 Log() << kINFO << "Evaluation results ranked by best signal efficiency and purity (area)" << Endl;
2085 Log() << kINFO << hLine << Endl;
2086 Log() << kINFO << "DataSet MVA " << Endl;
2087 Log() << kINFO << "Name: Method: ROC-integ" << Endl;
2088
2089 // Log() << kDEBUG << "DataSet MVA Signal efficiency at bkg eff.(error):
2090 // | Sepa- Signifi- " << Endl; Log() << kDEBUG << "Name: Method: @B=0.01
2091 // @B=0.10 @B=0.30 ROC-integ ROCCurve| ration: cance: " << Endl;
2092 Log() << kDEBUG << hLine << Endl;
2093 for (Int_t k = 0; k < 2; k++) {
2094 if (k == 1 && nmeth_used[k] > 0) {
2095 Log() << kINFO << hLine << Endl;
2096 Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
2097 }
2098 for (Int_t i = 0; i < nmeth_used[k]; i++) {
2099 TString datasetName = itrMap->first;
2100 TString methodName = mname[k][i];
2101
2102 if (k == 1) {
2103 methodName.ReplaceAll("Variable_", "");
2104 }
2105
2106 MethodBase *theMethod = dynamic_cast<MethodBase *>(GetMethod(datasetName, methodName));
2107 if (theMethod == 0) {
2108 continue;
2109 }
2110
2111 TMVA::DataSet *dataset = theMethod->Data();
2112 TMVA::Results *results = dataset->GetResults(methodName, Types::kTesting, this->fAnalysisType);
2113 std::vector<Bool_t> *mvaResType =
2114 dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
2115
2116 Double_t rocIntegral = 0.0;
2117 if (mvaResType->size() != 0) {
2118 rocIntegral = GetROCIntegral(datasetName, methodName);
2119 }
2120
2121 if (sep[k][i] < 0 || sig[k][i] < 0) {
2122 // cannot compute separation/significance -> no MVA (usually for Cuts)
2123 Log() << kINFO << Form("%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), effArea[k][i])
2124 << Endl;
2125
2126 // Log() << kDEBUG << Form("%-20s %-15s: %#1.3f(%02i) %#1.3f(%02i) %#1.3f(%02i)
2127 // %#1.3f %#1.3f | -- --",
2128 // datasetName.Data(),
2129 // methodName.Data(),
2130 // eff01[k][i], Int_t(1000*eff01err[k][i]),
2131 // eff10[k][i], Int_t(1000*eff10err[k][i]),
2132 // eff30[k][i], Int_t(1000*eff30err[k][i]),
2133 // effArea[k][i],rocIntegral) << Endl;
2134 } else {
2135 Log() << kINFO << Form("%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), rocIntegral)
2136 << Endl;
2137 // Log() << kDEBUG << Form("%-20s %-15s: %#1.3f(%02i) %#1.3f(%02i) %#1.3f(%02i)
2138 // %#1.3f %#1.3f | %#1.3f %#1.3f",
2139 // datasetName.Data(),
2140 // methodName.Data(),
2141 // eff01[k][i], Int_t(1000*eff01err[k][i]),
2142 // eff10[k][i], Int_t(1000*eff10err[k][i]),
2143 // eff30[k][i], Int_t(1000*eff30err[k][i]),
2144 // effArea[k][i],rocIntegral,
2145 // sep[k][i], sig[k][i]) << Endl;
2146 }
2147 }
2148 }
2149 Log() << kINFO << hLine << Endl;
2150 Log() << kINFO << Endl;
2151 Log() << kINFO << "Testing efficiency compared to training efficiency (overtraining check)" << Endl;
2152 Log() << kINFO << hLine << Endl;
2153 Log() << kINFO
2154 << "DataSet MVA Signal efficiency: from test sample (from training sample) "
2155 << Endl;
2156 Log() << kINFO << "Name: Method: @B=0.01 @B=0.10 @B=0.30 "
2157 << Endl;
2158 Log() << kINFO << hLine << Endl;
2159 for (Int_t k = 0; k < 2; k++) {
2160 if (k == 1 && nmeth_used[k] > 0) {
2161 Log() << kINFO << hLine << Endl;
2162 Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
2163 }
2164 for (Int_t i = 0; i < nmeth_used[k]; i++) {
2165 if (k == 1)
2166 mname[k][i].ReplaceAll("Variable_", "");
2167 MethodBase *theMethod = dynamic_cast<MethodBase *>((*methods)[i]);
2168 if (theMethod == 0)
2169 continue;
2170
2171 Log() << kINFO
2172 << Form("%-20s %-15s: %#1.3f (%#1.3f) %#1.3f (%#1.3f) %#1.3f (%#1.3f)",
2173 theMethod->fDataSetInfo.GetName(), (const char *)mname[k][i], eff01[k][i],
2174 trainEff01[k][i], eff10[k][i], trainEff10[k][i], eff30[k][i], trainEff30[k][i])
2175 << Endl;
2176 }
2177 }
2178 Log() << kINFO << hLine << Endl;
2179 Log() << kINFO << Endl;
2180
2181 if (gTools().CheckForSilentOption(GetOptions()))
2182 Log().InhibitOutput();
2183 } // end fROC
2184 }
2185 if (!IsSilentFile()) {
2186 std::list<TString> datasets;
2187 for (Int_t k = 0; k < 2; k++) {
2188 for (Int_t i = 0; i < nmeth_used[k]; i++) {
2189 MethodBase *theMethod = dynamic_cast<MethodBase *>((*methods)[i]);
2190 if (theMethod == 0)
2191 continue;
2192 // write test/training trees
2193 RootBaseDir()->cd(theMethod->fDataSetInfo.GetName());
2194 if (std::find(datasets.begin(), datasets.end(), theMethod->fDataSetInfo.GetName()) == datasets.end()) {
2197 datasets.push_back(theMethod->fDataSetInfo.GetName());
2198 }
2199 }
2200 }
2201 }
2202 } // end for MethodsMap
2203 // references for citation
2205}
2206
2207////////////////////////////////////////////////////////////////////////////////
2208/// Evaluate Variable Importance
2209
2210 TH1F *TMVA::Factory::EvaluateImportance(DataLoader *loader, VIType vitype, Types::EMVA theMethod, TString methodTitle,
2211 const char *theOption)
2212 {
// Dispatch to the variable-importance strategy selected by 'vitype'
// (kShort, kAll or kRandom) and return the histogram built by
// GetImportance(); returns nullptr when no strategy ran.
// NOTE(review): the rendered listing omits original line 2213 here.
2214 fSilentFile = kTRUE; // we need silent file here because we need fast classification results
2215
2216 // getting number of variables and variable names from loader
2217 const int nbits = loader->GetDataSetInfo().GetNVariables();
2218 if (vitype == VIType::kShort)
2219 return EvaluateImportanceShort(loader, theMethod, methodTitle, theOption);
2220 else if (vitype == VIType::kAll)
2221 return EvaluateImportanceAll(loader, theMethod, methodTitle, theOption);
2222 else if (vitype == VIType::kRandom) {
2223 if ( nbits > 10 && nbits < 30) {
2224 // limit nbits to less than 30 to avoid error converting from double to uint and also cannot deal with too many combinations
// number of random seeds passed down: 2^nbits, cast to UInt_t
2225 return EvaluateImportanceRandom(loader, static_cast<UInt_t>( pow(2, nbits) ), theMethod, methodTitle, theOption);
2226 } else if (nbits < 10) {
2227 Log() << kERROR << "Error in Variable Importance: Random mode require more that 10 variables in the dataset."
2228 << Endl;
2229 } else if (nbits > 30) {
2230 Log() << kERROR << "Error in Variable Importance: Number of variables is too large for Random mode"
2231 << Endl;
2232 }
// NOTE(review): nbits == 10 and nbits == 30 match none of the branches above
// and fall through to the nullptr return without any diagnostic — confirm
// whether the bounds were meant to be inclusive.
2233 }
2234 return nullptr;
2235 }
2236
2237////////////////////////////////////////////////////////////////////////////////
2238
2240 const char *theOption)
2241 {
// Evaluate variable importance in "All" mode: every non-empty subset of the
// nbits input variables (a "seed", encoded as a bitset) is booked, trained and
// tested, and its ROC integral stored in ROC[x]. A variable's importance is
// then accumulated as the ROC loss observed when that variable is removed
// from each seed containing it.
// NOTE(review): the rendered listing omits the first line of the signature
// (original line 2239) and lines 2293-2295 / 2302-2303 — the train/test/
// evaluate calls and the declaration of 'sresults' deleted below.
2242
2243 uint64_t x = 0;
2244 uint64_t y = 0;
2245
2246 // getting number of variables and variable names from loader
2247 const int nbits = loader->GetDataSetInfo().GetNVariables();
2248 std::vector<TString> varNames = loader->GetDataSetInfo().GetListOfVariables();
2249
2250 if (nbits > 60) {
2251 Log() << kERROR << "Number of combinations is too large , is 2^" << nbits << Endl;
2252 return nullptr;
2253 }
2254 if (nbits > 20) {
2255 Log() << kWARNING << "Number of combinations is very large , is 2^" << nbits << Endl;
2256 }
2257 uint64_t range = static_cast<uint64_t>(pow(2, nbits));
2258
2259
2260 // vector to save importances
2261 std::vector<Double_t> importances(nbits);
2262 // vector to save ROC
2263 std::vector<Double_t> ROC(range);
// the empty subset gets the no-separation baseline ROC of 0.5
2264 ROC[0] = 0.5;
2265 for (int i = 0; i < nbits; i++)
2266 importances[i] = 0;
2267
2268 Double_t SROC, SSROC; // computed ROC value
// first pass: train/test every non-empty subset and record its ROC integral
2269 for (x = 1; x < range; x++) {
2270
2271 std::bitset<VIBITS> xbitset(x);
2272 if (x == 0)
2273 continue; // data loader need at least one variable
2274
2275 // creating loader for seed
2276 TMVA::DataLoader *seedloader = new TMVA::DataLoader(xbitset.to_string());
2277
2278 // adding variables from seed
2279 for (int index = 0; index < nbits; index++) {
2280 if (xbitset[index])
2281 seedloader->AddVariable(varNames[index], 'F');
2282 }
2283
2284 DataLoaderCopy(seedloader, loader);
2285 seedloader->PrepareTrainingAndTestTree(loader->GetDataSetInfo().GetCut("Signal"),
2286 loader->GetDataSetInfo().GetCut("Background"),
2287 loader->GetDataSetInfo().GetSplitOptions());
2288
2289 // Booking Seed
2290 BookMethod(seedloader, theMethod, methodTitle, theOption);
2291
2292 // Train/Test/Evaluation
2296
2297 // getting ROC
2298 ROC[x] = GetROCIntegral(xbitset.to_string(), methodTitle);
2299
2300 // cleaning information to process sub-seeds
2301 TMVA::MethodBase *smethod = dynamic_cast<TMVA::MethodBase *>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
2304 delete sresults;
2305 delete seedloader;
2306 this->DeleteAllMethods();
2307
2308 fMethodsMap.clear();
2309 // removing global result because it is requiring a lot of RAM for all seeds
2310 }
2311
// second pass: for every seed x and every variable i it contains, compare the
// seed's ROC with that of the sub-seed y lacking variable i
2312 for (x = 0; x < range; x++) {
2313 SROC = ROC[x];
2314 for (uint32_t i = 0; i < VIBITS; ++i) {
2315 if (x & (uint64_t(1) << i)) {
2316 y = x & ~(uint64_t(1) << i);
2317 std::bitset<VIBITS> ybitset(y);
2318 // need at least one variable
2319 // NOTE: if sub-seed is zero then is the special case
2320 // that count in xbitset is 1
// x - y is a single-bit power of two, so log(x - y) / ln 2 (0.693147)
// recovers the bit index of the removed variable
2321 uint32_t ny = static_cast<uint32_t>( log(x - y) / 0.693147 ) ;
2322 if (y == 0) {
2323 importances[ny] = SROC - 0.5;
2324 continue;
2325 }
2326
2327 // getting ROC
2328 SSROC = ROC[y];
2329 importances[ny] += SROC - SSROC;
2330 // cleaning information
2331 }
2332 }
2333 }
2334 std::cout << "--- Variable Importance Results (All)" << std::endl;
2335 return GetImportance(nbits, importances, varNames);
2336}
2337
////////////////////////////////////////////////////////////////////////////////
/// Return 2^0 + 2^1 + ... + 2^i == 2^(i+1) - 1, i.e. a bit mask with the
/// lowest i+1 bits set. Used to build the "all variables enabled" seed for
/// the variable-importance evaluation.
/// Returns 0 when i > 62, where 2^(i+1) - 1 would overflow an uint64_t.

static uint64_t sum(uint64_t i)
{
   // add a limit for overflows
   if (i > 62)
      return 0;
   // exact integer arithmetic; the previous std::pow(2, i + 1) - 1 detoured
   // through floating point for a purely integral quantity
   return (uint64_t(1) << (i + 1)) - 1;
}
2348
2349////////////////////////////////////////////////////////////////////////////////
2350
2352 const char *theOption)
2353 {
// Evaluate variable importance in "Short" mode: train/test once on the full
// variable set (the seed with all nbits bits set, see sum(nbits)) and once
// per sub-seed with a single variable removed; each variable's importance is
// the ROC drop caused by removing it from the full set.
// NOTE(review): the rendered listing omits the first line of the signature
// (original line 2351) and lines 2396-2398 / 2405-2406 / 2442-2444 /
// 2452-2453 — the train/test/evaluate calls and the declarations of
// 'sresults'/'ssresults' deleted below.
2354 uint64_t x = 0;
2355 uint64_t y = 0;
2356
2357 // getting number of variables and variable names from loader
2358 const int nbits = loader->GetDataSetInfo().GetNVariables();
2359 std::vector<TString> varNames = loader->GetDataSetInfo().GetListOfVariables();
2360
2361 if (nbits > 60) {
2362 Log() << kERROR << "Number of combinations is too large , is 2^" << nbits << Endl;
2363 return nullptr;
2364 }
// seed with all variables enabled: sum(nbits) == 2^nbits - 1 (all bits set)
2365 long int range = sum(nbits);
2366 // std::cout<<range<<std::endl;
2367 // vector to save importances
2368 std::vector<Double_t> importances(nbits);
2369 for (int i = 0; i < nbits; i++)
2370 importances[i] = 0;
2371
2372 Double_t SROC, SSROC; // computed ROC value
2373
2374 x = range;
2375
2376 std::bitset<VIBITS> xbitset(x);
2377 if (x == 0)
2378 Log() << kFATAL << "Error: need at least one variable."; // data loader need at least one variable
2379
2380 // creating loader for seed
2381 TMVA::DataLoader *seedloader = new TMVA::DataLoader(xbitset.to_string());
2382
2383 // adding variables from seed
2384 for (int index = 0; index < nbits; index++) {
2385 if (xbitset[index])
2386 seedloader->AddVariable(varNames[index], 'F');
2387 }
2388
2389 // Loading Dataset
2390 DataLoaderCopy(seedloader, loader);
2391
2392 // Booking Seed
2393 BookMethod(seedloader, theMethod, methodTitle, theOption);
2394
2395 // Train/Test/Evaluation
2399
2400 // getting ROC
2401 SROC = GetROCIntegral(xbitset.to_string(), methodTitle);
2402
2403 // cleaning information to process sub-seeds
2404 TMVA::MethodBase *smethod = dynamic_cast<TMVA::MethodBase *>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
2407 delete sresults;
2408 delete seedloader;
2409 this->DeleteAllMethods();
2410 fMethodsMap.clear();
2411
2412 // removing global result because it is requiring a lot of RAM for all seeds
2413
// for each variable in the full seed, train/test the sub-seed with that
// variable removed and accumulate the resulting ROC drop
2414 for (uint32_t i = 0; i < VIBITS; ++i) {
2415 if (x & (uint64_t(1) << i)) {
2416 y = x & ~(uint64_t(1) << i);
2417 std::bitset<VIBITS> ybitset(y);
2418 // need at least one variable
2419 // NOTE: if sub-seed is zero then is the special case
2420 // that count in xbitset is 1
// x - y is a single-bit power of two, so log(x - y) / ln 2 (0.693147)
// recovers the bit index of the removed variable
2421 uint32_t ny = static_cast<uint32_t>(log(x - y) / 0.693147);
2422 if (y == 0) {
2423 importances[ny] = SROC - 0.5;
2424 continue;
2425 }
2426
2427 // creating loader for sub-seed
2428 TMVA::DataLoader *subseedloader = new TMVA::DataLoader(ybitset.to_string());
2429 // adding variables from sub-seed
2430 for (int index = 0; index < nbits; index++) {
2431 if (ybitset[index])
2432 subseedloader->AddVariable(varNames[index], 'F');
2433 }
2434
2435 // Loading Dataset
2436 DataLoaderCopy(subseedloader, loader);
2437
2438 // Booking SubSeed
2439 BookMethod(subseedloader, theMethod, methodTitle, theOption);
2440
2441 // Train/Test/Evaluation
2445
2446 // getting ROC
2447 SSROC = GetROCIntegral(ybitset.to_string(), methodTitle);
2448 importances[ny] += SROC - SSROC;
2449
2450 // cleaning information
2451 TMVA::MethodBase *ssmethod = dynamic_cast<TMVA::MethodBase *>(fMethodsMap[ybitset.to_string().c_str()][0][0]);
2454 delete ssresults;
2455 delete subseedloader;
2456 this->DeleteAllMethods();
2457 fMethodsMap.clear();
2458 }
2459 }
2460 std::cout << "--- Variable Importance Results (Short)" << std::endl;
2461 return GetImportance(nbits, importances, varNames);
2462}
2463
2464////////////////////////////////////////////////////////////////////////////////
2465
2467 TString methodTitle, const char *theOption)
2468 {
// Evaluate variable importance on 'nseeds' randomly drawn variable subsets:
// each random seed is trained/tested, and for every variable in the seed the
// ROC drop from removing it is accumulated into that variable's importance.
// NOTE(review): the rendered listing omits the first line of the signature
// (original line 2466) and lines 2509-2511 / 2519-2520 / 2557-2559 /
// 2569-2570 — the train/test/evaluate calls and the declarations of
// 'sresults'/'ssresults' deleted below.
2469 TRandom3 *rangen = new TRandom3(0); // Random Gen.
// NOTE(review): 'rangen' is never deleted — leaks one TRandom3 per call.
2470
2471 uint64_t x = 0;
2472 uint64_t y = 0;
2473
2474 // getting number of variables and variable names from loader
2475 const int nbits = loader->GetDataSetInfo().GetNVariables();
2476 std::vector<TString> varNames = loader->GetDataSetInfo().GetListOfVariables();
2477
2478 long int range = pow(2, nbits);
2479
2480 // vector to save importances
2481 std::vector<Double_t> importances(nbits);
2482 for (int i = 0; i < nbits; i++)
2483 importances[i] = 0;
2484
2485 Double_t SROC, SSROC; // computed ROC value
2486 for (UInt_t n = 0; n < nseeds; n++) {
// draw a random variable subset in [0, 2^nbits); 0 is skipped below
2487 x = rangen->Integer(range);
2488
2489 std::bitset<32> xbitset(x);
2490 if (x == 0)
2491 continue; // data loader need at least one variable
2492
2493 // creating loader for seed
2494 TMVA::DataLoader *seedloader = new TMVA::DataLoader(xbitset.to_string());
2495
2496 // adding variables from seed
2497 for (int index = 0; index < nbits; index++) {
2498 if (xbitset[index])
2499 seedloader->AddVariable(varNames[index], 'F');
2500 }
2501
2502 // Loading Dataset
2503 DataLoaderCopy(seedloader, loader);
2504
2505 // Booking Seed
2506 BookMethod(seedloader, theMethod, methodTitle, theOption);
2507
2508 // Train/Test/Evaluation
2512
2513 // getting ROC
2514 SROC = GetROCIntegral(xbitset.to_string(), methodTitle);
2515 // std::cout << "Seed: n " << n << " x " << x << " xbitset:" << xbitset << " ROC " << SROC << std::endl;
2516
2517 // cleaning information to process sub-seeds
2518 TMVA::MethodBase *smethod = dynamic_cast<TMVA::MethodBase *>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
2521 delete sresults;
2522 delete seedloader;
2523 this->DeleteAllMethods();
2524 fMethodsMap.clear();
2525
2526 // removing global result because it is requiring a lot of RAM for all seeds
2527
2528 for (uint32_t i = 0; i < 32; ++i) {
2529 if (x & (uint64_t(1) << i)) {
2530 y = x & ~(uint64_t(1) << i);
2531 std::bitset<32> ybitset(y);
2532 // need at least one variable
2533 // NOTE: if sub-seed is zero then is the special case
2534 // that count in xbitset is 1
// x - y is a single-bit power of two, so log(x - y) / ln 2 (0.693147)
// recovers the bit index of the removed variable.
// NOTE(review): 'ny' is a Double_t used below as a vector index — it is
// implicitly truncated; an integer type (as in the All/Short variants)
// would state the intent.
2535 Double_t ny = log(x - y) / 0.693147;
2536 if (y == 0) {
2537 importances[ny] = SROC - 0.5;
2538 // std::cout << "SubSeed: " << y << " y:" << ybitset << "ROC " << 0.5 << std::endl;
2539 continue;
2540 }
2541
2542 // creating loader for sub-seed
2543 TMVA::DataLoader *subseedloader = new TMVA::DataLoader(ybitset.to_string());
2544 // adding variables from sub-seed
2545 for (int index = 0; index < nbits; index++) {
2546 if (ybitset[index])
2547 subseedloader->AddVariable(varNames[index], 'F');
2548 }
2549
2550 // Loading Dataset
2551 DataLoaderCopy(subseedloader, loader);
2552
2553 // Booking SubSeed
2554 BookMethod(subseedloader, theMethod, methodTitle, theOption);
2555
2556 // Train/Test/Evaluation
2560
2561 // getting ROC
2562 SSROC = GetROCIntegral(ybitset.to_string(), methodTitle);
2563 importances[ny] += SROC - SSROC;
2564 // std::cout << "SubSeed: " << y << " y:" << ybitset << " x-y " << x - y << " " << std::bitset<32>(x - y) <<
2565 // " ny " << ny << " SROC " << SROC << " SSROC " << SSROC << " Importance = " << importances[ny] <<
2566 // std::endl; cleaning information
2567 TMVA::MethodBase *ssmethod =
2568 dynamic_cast<TMVA::MethodBase *>(fMethodsMap[ybitset.to_string().c_str()][0][0]);
2571 delete ssresults;
2572 delete subseedloader;
2573 this->DeleteAllMethods();
2574 fMethodsMap.clear();
2575 }
2576 }
2577 }
2578 std::cout << "--- Variable Importance Results (Random)" << std::endl;
2579 return GetImportance(nbits, importances, varNames);
2580}
2581
2582////////////////////////////////////////////////////////////////////////////////
2583
2584TH1F *TMVA::Factory::GetImportance(const int nbits, std::vector<Double_t> importances, std::vector<TString> varNames)
2585{
2586 TH1F *vih1 = new TH1F("vih1", "", nbits, 0, nbits);
2587
2588 gStyle->SetOptStat(000000);
2589
2590 Float_t normalization = 0.0;
2591 for (int i = 0; i < nbits; i++) {
2592 normalization = normalization + importances[i];
2593 }
2594
2595 Float_t roc = 0.0;
2596
2597 gStyle->SetTitleXOffset(0.4);
2598 gStyle->SetTitleXOffset(1.2);
2599
2600 std::vector<Double_t> x_ie(nbits), y_ie(nbits);
2601 for (Int_t i = 1; i < nbits + 1; i++) {
2602 x_ie[i - 1] = (i - 1) * 1.;
2603 roc = 100.0 * importances[i - 1] / normalization;
2604 y_ie[i - 1] = roc;
2605 std::cout << "--- " << varNames[i - 1] << " = " << roc << " %" << std::endl;
2606 vih1->GetXaxis()->SetBinLabel(i, varNames[i - 1].Data());
2607 vih1->SetBinContent(i, roc);
2608 }
2609 TGraph *g_ie = new TGraph(nbits + 2, &x_ie[0], &y_ie[0]);
2610 g_ie->SetTitle("");
2611
2612 vih1->LabelsOption("v >", "X");
2613 vih1->SetBarWidth(0.97);
2614 Int_t ca = TColor::GetColor("#006600");
2615 vih1->SetFillColor(ca);
2616 // Int_t ci = TColor::GetColor("#990000");
2617
2618 vih1->GetYaxis()->SetTitle("Importance (%)");
2619 vih1->GetYaxis()->SetTitleSize(0.045);
2620 vih1->GetYaxis()->CenterTitle();
2621 vih1->GetYaxis()->SetTitleOffset(1.24);
2622
2623 vih1->GetYaxis()->SetRangeUser(-7, 50);
2624 vih1->SetDirectory(nullptr);
2625
2626 // vih1->Draw("B");
2627 return vih1;
2628}
#define MinNoTrainingEvents
#define h(i)
Definition RSha256.hxx:106
void printMatrix(const TMatrixD &mat)
write a matrix
int Int_t
Signed integer 4 bytes (int).
Definition RtypesCore.h:59
unsigned int UInt_t
Unsigned integer 4 bytes (unsigned int).
Definition RtypesCore.h:60
bool Bool_t
Boolean (0=false, 1=true) (bool).
Definition RtypesCore.h:77
constexpr Bool_t kFALSE
Definition RtypesCore.h:108
double Double_t
Double 8 bytes.
Definition RtypesCore.h:73
float Float_t
Float 4 bytes (float).
Definition RtypesCore.h:71
constexpr Bool_t kTRUE
Definition RtypesCore.h:107
if(name) objname
char name[80]
Definition TGX11.cxx:148
TMatrixT< Double_t > TMatrixD
Definition TMatrixDfwd.h:23
Double_t err
#define gROOT
Definition TROOT.h:417
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2496
extern TStyle * gStyle
Definition TStyle.h:442
extern TSystem * gSystem
Definition TSystem.h:582
virtual void SetTitleOffset(Float_t offset=1)
Set distance between the axis and the axis title.
Definition TAttAxis.cxx:279
virtual void SetTitleSize(Float_t size=0.04)
Set size of axis title.
Definition TAttAxis.cxx:290
virtual void SetFillColor(Color_t fcolor)
Set the fill area color.
Definition TAttFill.h:40
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
Definition TAttLine.h:47
virtual void SetLineColor(Color_t lcolor)
Set the line color.
Definition TAttLine.h:44
virtual void SetBinLabel(Int_t bin, const char *label)
Set label for bin.
Definition TAxis.cxx:891
void CenterTitle(Bool_t center=kTRUE)
Center axis title.
Definition TAxis.h:196
virtual void SetRangeUser(Double_t ufirst, Double_t ulast)
Set the viewing range for the axis from ufirst to ulast (in user coordinates, that is,...
Definition TAxis.cxx:1090
The Canvas class.
Definition TCanvas.h:23
static Int_t GetColor(const char *hexcolor)
A file, usually with extension .root, that stores data and code in the form of serialized objects in ...
Definition TFile.h:130
TAxis * GetXaxis() const
TAxis * GetYaxis() const
void SetTitle(const char *title="") override
Set the title of the TNamed.
1-D histogram with a float per channel (see TH1 documentation)
Definition TH1.h:878
virtual void SetDirectory(TDirectory *dir)
By default, when a histogram is created, it is added to the list of histogram objects in the current ...
Definition TH1.cxx:9074
void SetTitle(const char *title) override
Change/set the title.
Definition TH1.cxx:6836
virtual void LabelsOption(Option_t *option="h", Option_t *axis="X")
Sort bins with labels or set option(s) to draw axis with labels.
Definition TH1.cxx:5464
static void AddDirectory(Bool_t add=kTRUE)
Sets the flag controlling the automatic add of histograms in memory.
Definition TH1.cxx:1309
TAxis * GetXaxis()
Definition TH1.h:571
TAxis * GetYaxis()
Definition TH1.h:572
virtual void SetBarWidth(Float_t width=0.5)
Set the width of bars as fraction of the bin width for drawing mode "B".
Definition TH1.h:613
virtual void SetBinContent(Int_t bin, Double_t content)
Set bin content see convention for numbering bins in TH1::GetBin In case the bin number is greater th...
Definition TH1.cxx:9356
Service class for 2-D histogram classes.
Definition TH2.h:30
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
TString fWeightFileDir
Definition Config.h:124
TString fWeightFileDirPrefix
Definition Config.h:123
void SetDrawProgressBar(Bool_t d)
Definition Config.h:69
void SetUseColor(Bool_t uc)
Definition Config.h:60
class TMVA::Config::VariablePlotting fVariablePlotting
void SetSilent(Bool_t s)
Definition Config.h:63
IONames & GetIONames()
Definition Config.h:98
void SetConfigDescription(const char *d)
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
void AddPreDefVal(const T &)
void SetConfigName(const char *n)
Configurable(const TString &theOption="")
constructor
virtual void ParseOptions()
options parser
const TString & GetOptions() const
MsgLogger & Log() const
MsgLogger * fLogger
! message logger
void CheckForUnusedOptions() const
checks for unused options in option string
UInt_t GetEntries(const TString &name) const
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
DataSetInfo & GetDataSetInfo()
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Class that contains all the data information.
Definition DataSetInfo.h:62
UInt_t GetNVariables() const
const TMatrixD * CorrelationMatrix(const TString &className) const
UInt_t GetNClasses() const
const TString & GetSplitOptions() const
UInt_t GetNTargets() const
DataSet * GetDataSet() const
returns data set
TH2 * CreateCorrelationMatrixHist(const TMatrixD *m, const TString &hName, const TString &hTitle) const
std::vector< TString > GetListOfVariables() const
returns list of variables
const char * GetName() const override
Returns name of object.
Definition DataSetInfo.h:71
ClassInfo * GetClassInfo(Int_t clNum) const
const TCut & GetCut(Int_t i) const
VariableInfo & GetVariableInfo(Int_t i)
Bool_t IsSignal(const Event *ev) const
DataSetManager * GetDataSetManager()
DataInputHandler & DataInput()
Class that contains all the data information.
Definition DataSet.h:58
Long64_t GetNEvtSigTest()
return number of signal test events in dataset
Definition DataSet.cxx:427
TTree * GetTree(Types::ETreeType type)
create the test/trainings tree with all the variables, the weights, the classes, the targets,...
Definition DataSet.cxx:609
const Event * GetEvent() const
returns event without transformations
Definition DataSet.cxx:202
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:206
Results * GetResults(const TString &, Types::ETreeType type, Types::EAnalysisType analysistype)
Definition DataSet.cxx:265
Long64_t GetNTrainingEvents() const
Definition DataSet.h:68
void SetCurrentType(Types::ETreeType type) const
Definition DataSet.h:89
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:216
Long64_t GetNEvtBkgdTest()
return number of background test events in dataset
Definition DataSet.cxx:435
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition Event.cxx:236
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition Event.cxx:399
TString const & tString() const
Definition Factory.h:109
void PrintHelpMessage(const TString &datasetname, const TString &methodTitle="") const
Print predefined help message of classifier.
Definition Factory.cxx:1327
Bool_t fSilentFile
! used in constructor without file
Definition Factory.h:226
Bool_t fCorrelations
! enable to calculate correlations
Definition Factory.h:224
std::vector< IMethod * > MVector
Definition Factory.h:84
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition Factory.cxx:1108
Bool_t Verbose(void) const
Definition Factory.h:143
void WriteDataInformation(DataSetInfo &fDataSetInfo)
Definition Factory.cxx:596
Factory(TString theJobName, TFile *theTargetFile, TString theOption="")
Standard constructor.
Definition Factory.cxx:112
void TestAllMethods()
Evaluates all booked methods on the testing data and adds the output to the Results in the corresponi...
Definition Factory.cxx:1265
Bool_t fVerbose
! verbose mode
Definition Factory.h:222
void EvaluateAllMethods(void)
Iterates over all MVAs that have been booked, and calls their evaluation methods.
Definition Factory.cxx:1370
TH1F * EvaluateImportanceRandom(DataLoader *loader, UInt_t nseeds, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition Factory.cxx:2466
TH1F * GetImportance(const int nbits, std::vector< Double_t > importances, std::vector< TString > varNames)
Definition Factory.cxx:2584
Bool_t fROC
! enable to calculate ROC values
Definition Factory.h:225
void EvaluateAllVariables(DataLoader *loader, TString options="")
Iterates over all MVA input variables and evaluates them.
Definition Factory.cxx:1354
TDirectory * RootBaseDir()
Definition Factory.h:158
TString fVerboseLevel
! verbosity level, controls granularity of logging
Definition Factory.h:223
TMultiGraph * GetROCCurveAsMultiGraph(DataLoader *loader, UInt_t iClass, Types::ETreeType type=Types::kTesting)
Generate a collection of graphs, for all methods for a given class.
Definition Factory.cxx:982
TH1F * EvaluateImportance(DataLoader *loader, VIType vitype, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Evaluate Variable Importance.
Definition Factory.cxx:2210
Double_t GetROCIntegral(DataLoader *loader, TString theMethodName, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Calculate the integral of the ROC curve, also known as the area under curve (AUC),...
Definition Factory.cxx:843
std::map< TString, MVector * > fMethodsMap
Definition Factory.h:85
virtual ~Factory()
Destructor.
Definition Factory.cxx:305
MethodBase * BookMethod(DataLoader *loader, MethodName theMethodName, TString methodTitle, TString theOption="")
Books an MVA classifier or regression method.
Definition Factory.cxx:357
virtual void MakeClass(const TString &datasetname, const TString &methodTitle="") const
Definition Factory.cxx:1299
MethodBase * BookMethodWeightfile(DataLoader *dataloader, TMVA::Types::EMVA methodType, const TString &weightfile)
Adds an already constructed method to be managed by this factory.
Definition Factory.cxx:495
Bool_t fModelPersistence
! option to save the trained model in xml file or using serialization
Definition Factory.h:231
std::map< TString, Double_t > OptimizeAllMethods(TString fomType="ROCIntegral", TString fitType="FitGA")
Iterates through all booked methods and sees if they use parameter tuning and if so does just that,...
Definition Factory.cxx:695
ROCCurve * GetROC(DataLoader *loader, TString theMethodName, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Private method to generate a ROCCurve instance for a given method.
Definition Factory.cxx:743
Bool_t IsSilentFile() const
Definition Factory.h:160
TH1F * EvaluateImportanceShort(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition Factory.cxx:2351
Types::EAnalysisType fAnalysisType
! the training type
Definition Factory.h:230
TString fJobName
! jobname, used as extension in weight file names
Definition Factory.h:228
Bool_t HasMethod(const TString &datasetname, const TString &title) const
Checks whether a given method name is defined for a given dataset.
Definition Factory.cxx:580
TGraph * GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles=kTRUE, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Argument iClass specifies the class to generate the ROC curve in a multiclass setting.
Definition Factory.cxx:906
TH1F * EvaluateImportanceAll(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition Factory.cxx:2239
void SetVerbose(Bool_t v=kTRUE)
Definition Factory.cxx:342
TFile * fgTargetFile
! ROOT output file
Definition Factory.h:214
std::vector< TMVA::VariableTransformBase * > fDefaultTrfs
! list of transformations on default DataSet
Definition Factory.h:217
IMethod * GetMethod(const TString &datasetname, const TString &title) const
Returns pointer to MVA that corresponds to given method title.
Definition Factory.cxx:560
void DeleteAllMethods(void)
Delete methods.
Definition Factory.cxx:323
TString fTransformations
! list of transformations to test
Definition Factory.h:221
void Greetings()
Print welcome message.
Definition Factory.cxx:294
Interface for all concrete MVA method implementations.
Definition IMethod.h:53
virtual void PrintHelpMessage() const =0
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
virtual void MakeClass(const TString &classFileName=TString("")) const =0
Virtual base Class for all MVA method.
Definition MethodBase.h:111
void PrintHelpMessage() const override
prints out method-specific help method
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
const char * GetName() const override
Definition MethodBase.h:337
void SetSilentFile(Bool_t status)
Definition MethodBase.h:381
void SetWeightFileDir(TString fileDir)
set directory of weight file
virtual void TestRegression(Double_t &bias, Double_t &biasT, Double_t &dev, Double_t &devT, Double_t &rms, Double_t &rmsT, Double_t &mInf, Double_t &mInfT, Double_t &corr, Types::ETreeType type)
calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
TString GetMethodTypeName() const
Definition MethodBase.h:335
Bool_t DoMulticlass() const
Definition MethodBase.h:442
virtual Double_t GetSignificance() const
compute significance of mean difference
Types::EAnalysisType GetAnalysisType() const
Definition MethodBase.h:440
virtual TMatrixD GetMulticlassConfusionMatrix(Double_t effB, Types::ETreeType type)
Construct a confusion matrix for a multiclass classifier.
virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype)
writes all MVA evaluation histograms to file
virtual void TestMulticlass()
test multiclass classification
void SetupMethod()
setup of methods
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition MethodBase.h:439
const TString & GetMethodName() const
Definition MethodBase.h:334
Bool_t DoRegression() const
Definition MethodBase.h:441
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
virtual Double_t GetTrainingEfficiency(const TString &)
DataSetInfo & DataInfo() const
Definition MethodBase.h:413
virtual void TestClassification()
initialization
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
void MakeClass(const TString &classFileName=TString("")) const override
create reader class for method (classification only at present)
void ReadStateFromFile()
Function to write options and weights to file.
virtual std::map< TString, Double_t > OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA")
call the Optimizer with the set of parameters and ranges that are meant to be tuned.
DataSetInfo & fDataSetInfo
! the data set information (sometimes needed)
Definition MethodBase.h:610
Types::EMVA GetMethodType() const
Definition MethodBase.h:336
void SetFile(TFile *file)
Definition MethodBase.h:378
DataSet * Data() const
Definition MethodBase.h:412
void SetModelPersistence(Bool_t status)
Definition MethodBase.h:385
virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const
calculate the area (integral) under the ROC curve as a overall quality measure of the classification
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Class for boosting a TMVA method.
Definition MethodBoost.h:58
void SetBoostedMethodName(TString methodName)
Definition MethodBoost.h:86
DataSetManager * fDataSetManager
DSMTEST.
Class for categorizing the phase space.
DataSetManager * fDataSetManager
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
Double_t GetEffSForEffB(Double_t effB, const UInt_t num_points=41)
Calculate the signal efficiency (sensitivity) for a given background efficiency (sensitivity).
Definition ROCCurve.cxx:217
Double_t GetROCIntegral(const UInt_t points=41)
Calculates the ROC integral (AUC).
Definition ROCCurve.cxx:248
TGraph * GetROCCurve(const UInt_t points=100)
Returns a new TGraph containing the ROC curve.
Definition ROCCurve.cxx:274
Ranking for variables in method (implementation).
Definition Ranking.h:48
virtual void Print() const
get maximum length of variable names
Definition Ranking.cxx:110
Class that is the base-class for a vector of result.
Class which takes the results of a multiclass classification.
Class that is the base-class for a vector of result.
Definition Results.h:57
void FormattedOutput(const std::vector< Double_t > &, const std::vector< TString > &, const TString titleVars, const TString titleValues, MsgLogger &logger, TString format="%+1.3f")
formatted output of simple table
Definition Tools.cxx:862
void ROOTVersionMessage(MsgLogger &logger)
prints the ROOT release number and date
Definition Tools.cxx:1300
void UsefulSortDescending(std::vector< std::vector< Double_t > > &, std::vector< TString > *vs=nullptr)
sort 2D vector (AND in parallel a TString vector) in such a way that the "first vector is sorted" and...
Definition Tools.cxx:539
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition Tools.cxx:1174
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:803
const TMatrixD * GetCorrelationMatrix(const TMatrixD *covMat)
turns covariance into correlation matrix
Definition Tools.cxx:299
@ kHtmlLink
Definition Tools.h:212
void UsefulSortAscending(std::vector< std::vector< Double_t > > &, std::vector< TString > *vs=nullptr)
sort 2D vector (AND in parallel a TString vector) in such a way that the "first vector is sorted" and...
Definition Tools.cxx:513
void TMVACitation(MsgLogger &logger, ECitation citType=kPlainText)
kinds of TMVA citation
Definition Tools.cxx:1415
void TMVAVersionMessage(MsgLogger &logger)
prints the TMVA release number and date
Definition Tools.cxx:1291
void TMVAWelcomeMessage()
direct output, eg, when starting ROOT session -> no use of Logger here
Definition Tools.cxx:1277
Class that contains all the data information.
void PrintVariableRanking() const
prints ranking of input variables
Singleton class for Global types used by TMVA.
Definition Types.h:71
static Types & Instance()
The single instance of "Types" if existing already, or create it (Singleton).
Definition Types.cxx:70
@ kCategory
Definition Types.h:97
@ kMulticlass
Definition Types.h:129
@ kNoAnalysisType
Definition Types.h:130
@ kClassification
Definition Types.h:127
@ kMaxAnalysisType
Definition Types.h:131
@ kRegression
Definition Types.h:128
@ kTraining
Definition Types.h:143
const TString & GetLabel() const
virtual void Add(TGraph *graph, Option_t *chopt="")
TList * GetListOfGraphs() const
Definition TMultiGraph.h:67
TH1F * GetHistogram()
void Draw(Option_t *chopt="") override
Default Draw method for all objects.
TAxis * GetYaxis()
TAxis * GetXaxis()
TObject * Clone(const char *newname="") const override
Make a clone of an object using the Streamer facility.
Definition TNamed.cxx:73
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
Definition TNamed.cxx:173
const char * GetName() const override
Returns name of object.
Definition TNamed.h:49
TString fName
Definition TNamed.h:32
@ kOverwrite
overwrite existing object with same name
Definition TObject.h:101
void SetGrid(Int_t valuex=1, Int_t valuey=1) override
Definition TPad.h:341
TLegend * BuildLegend(Double_t x1=0.3, Double_t y1=0.21, Double_t x2=0.3, Double_t y2=0.21, const char *title="", Option_t *option="") override
Build a legend from the graphical objects in the pad.
Definition TPad.cxx:556
virtual void MakePrincipals()
virtual void AddRow(const Double_t *x)
const TMatrixD * GetCovarianceMatrix() const
Return the covariance matrix.
Definition TPrincipal.h:60
Random number generator class based on M.
Definition TRandom3.h:27
virtual UInt_t Integer(UInt_t imax)
Returns a random integer uniformly distributed on the interval [ 0, imax-1 ].
Definition TRandom.cxx:360
Basic string class.
Definition TString.h:138
Ssiz_t Length() const
Definition TString.h:425
void ToLower()
Change string to lower-case.
Definition TString.cxx:1189
const char * Data() const
Definition TString.h:384
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition TString.h:713
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
Definition TString.h:632
Bool_t IsNull() const
Definition TString.h:422
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2385
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition TString.h:641
Int_t Write(const char *name=nullptr, Int_t option=0, Int_t bufsize=0) override
Write this object to the current directory.
Definition TTree.cxx:10021
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16
void DataLoaderCopy(TMVA::DataLoader *des, TMVA::DataLoader *src)
Config & gConfig()
Tools & gTools()
void CreateVariableTransforms(const TString &trafoDefinition, TMVA::DataSetInfo &dataInfo, TMVA::TransformationHandler &transformationHandler, TMVA::MsgLogger &log)
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
Bool_t IsNaN(Double_t x)
Definition TMath.h:903
TMarker m
Definition textangle.C:8
#define VIBITS
Definition Factory.cxx:102
static uint64_t sum(uint64_t i)
Definition Factory.cxx:2338
#define READXML
Definition Factory.cxx:99