Logo ROOT  
Reference Guide
Factory.cxx
Go to the documentation of this file.
1// @(#)Root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3// Updated by: Omar Zapata, Kim Albertsson
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : Factory *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation (see header for description) *
12 * *
13 * Authors : *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
16 * Peter Speckmayer <peter.speckmayer@cern.ch> - CERN, Switzerland *
17 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
19 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
20 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
21 * Omar Zapata <Omar.Zapata@cern.ch> - UdeA/ITM Colombia *
22 * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
23 * Sergei Gleyzer <Sergei.Gleyzer@cern.ch> - U of Florida & CERN *
24 * Kim Albertsson <kim.albertsson@cern.ch> - LTU & CERN *
25 * *
26 * Copyright (c) 2005-2015: *
27 * CERN, Switzerland *
28 * U. of Victoria, Canada *
29 * MPI-K Heidelberg, Germany *
30 * U. of Bonn, Germany *
31 * UdeA/ITM, Colombia *
32 * U. of Florida, USA *
33 * *
34 * Redistribution and use in source and binary forms, with or without *
35 * modification, are permitted according to the terms listed in LICENSE *
36 * (http://tmva.sourceforge.net/LICENSE) *
37 **********************************************************************************/
38
39/*! \class TMVA::Factory
40\ingroup TMVA
41
42This is the main MVA steering class.
43It creates all MVA methods, and guides them through the training, testing and
44evaluation phases.
45*/
46
47#include "TMVA/Factory.h"
48
50#include "TMVA/Config.h"
51#include "TMVA/Configurable.h"
52#include "TMVA/Tools.h"
53#include "TMVA/Ranking.h"
54#include "TMVA/DataSet.h"
55#include "TMVA/IMethod.h"
56#include "TMVA/MethodBase.h"
58#include "TMVA/DataSetManager.h"
59#include "TMVA/DataSetInfo.h"
60#include "TMVA/DataLoader.h"
61#include "TMVA/MethodBoost.h"
62#include "TMVA/MethodCategory.h"
63#include "TMVA/ROCCalc.h"
64#include "TMVA/ROCCurve.h"
65#include "TMVA/MsgLogger.h"
66
67#include "TMVA/VariableInfo.h"
69
70#include "TMVA/Results.h"
74#include <list>
75#include <bitset>
76#include <set>
77
78#include "TMVA/Types.h"
79
80#include "TROOT.h"
81#include "TFile.h"
82#include "TLeaf.h"
83#include "TEventList.h"
84#include "TH2.h"
85#include "TGraph.h"
86#include "TStyle.h"
87#include "TMatrixF.h"
88#include "TMatrixDSym.h"
89#include "TMultiGraph.h"
90#include "TPrincipal.h"
91#include "TMath.h"
92#include "TSystem.h"
93#include "TCanvas.h"
94
96// const Int_t MinNoTestEvents = 1;
97
99
100#define READXML kTRUE
101
102// number of bits for bitset
103#define VIBITS 32
104
105////////////////////////////////////////////////////////////////////////////////
106/// Standard constructor.
107///
108/// - jobname : this name will appear in all weight file names produced by the MVAs
109/// - theTargetFile : output ROOT file; the test tree and all evaluation plots
110/// will be stored here
111/// - theOption : option string; currently: "V" for verbose
112
// Constructs a Factory that is bound to an output file.
//  - jobName       : stored in fJobName
//  - theTargetFile : stored in fgTargetFile; fSilentFile becomes true when it is null
//  - theOption     : option string handed to Configurable and parsed below
// The initializer list sets the defaults (identity transformation "I", verbose off, ROC on,
// classification, model persistence on). The body declares the options understood by the
// factory, parses theOption and forwards the colour / silent / progress-bar choices to gConfig().
113TMVA::Factory::Factory(TString jobName, TFile *theTargetFile, TString theOption)
114 : Configurable(theOption), fTransformations("I"), fVerbose(kFALSE), fVerboseLevel(kINFO), fCorrelations(kFALSE),
115 fROC(kTRUE), fSilentFile(theTargetFile == nullptr), fJobName(jobName), fAnalysisType(Types::kClassification),
116 fModelPersistence(kTRUE)
117{
118 fName = "Factory";
119 fgTargetFile = theTargetFile;
121
122 // render silent
123 if (gTools().CheckForSilentOption(GetOptions()))
124 Log().InhibitOutput(); // make sure is silent if wanted to
125
126 // init configurable
127 SetConfigDescription("Configuration options for Factory running");
129
130 // histograms are not automatically associated with the current
131 // directory and hence don't go out of scope when closing the file
132 // TH1::AddDirectory(kFALSE);
// Locals that back the "Silent", "Color" and "DrawProgressBar" options declared below;
// their parsed values are pushed to gConfig() after ParseOptions().
133 Bool_t silent = kFALSE;
134#ifdef WIN32
135 // under Windows, switch progress bar and color off by default, as the typical windows shell doesn't handle these
136 // (would need different sequences..)
137 Bool_t color = kFALSE;
138 Bool_t drawProgressBar = kFALSE;
139#else
140 Bool_t color = !gROOT->IsBatch();
141 Bool_t drawProgressBar = kTRUE;
142#endif
143 DeclareOptionRef(fVerbose, "V", "Verbose flag");
144 DeclareOptionRef(fVerboseLevel = TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)");
145 AddPreDefVal(TString("Debug"));
146 AddPreDefVal(TString("Verbose"));
147 AddPreDefVal(TString("Info"));
148 DeclareOptionRef(color, "Color", "Flag for coloured screen output (default: True, if in batch mode: False)");
// NOTE(review): listing line 149 is absent here; presumably it opens the call whose arguments
// follow on lines 150-152 - confirm against the upstream file.
150 fTransformations, "Transformations",
151 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for identity, "
152 "decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation transformations");
153 DeclareOptionRef(fCorrelations, "Correlations", "boolean to show correlation in output");
154 DeclareOptionRef(fROC, "ROC", "boolean to show ROC in output");
155 DeclareOptionRef(silent, "Silent",
156 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
157 "class object (default: False)");
158 DeclareOptionRef(drawProgressBar, "DrawProgressBar",
159 "Draw progress bar to display training, testing and evaluation schedule (default: True)");
160 DeclareOptionRef(fModelPersistence, "ModelPersistence",
161 "Option to save the trained model in xml file or using serialization");
162
// The analysis type is parsed into a local string and compared in lower case further down.
163 TString analysisType("Auto");
164 DeclareOptionRef(analysisType, "AnalysisType",
165 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
166 AddPreDefVal(TString("Classification"));
167 AddPreDefVal(TString("Regression"));
168 AddPreDefVal(TString("Multiclass"));
169 AddPreDefVal(TString("Auto"));
170
171 ParseOptions();
173
// NOTE(review): the listing numbers jump after each condition below (175, 177, 179, 181 are
// absent), so the statement guarded by each condition is not visible here - confirm upstream.
174 if (Verbose())
176 if (fVerboseLevel.CompareTo("Debug") == 0)
178 if (fVerboseLevel.CompareTo("Verbose") == 0)
180 if (fVerboseLevel.CompareTo("Info") == 0)
182
183 // global settings
184 gConfig().SetUseColor(color);
185 gConfig().SetSilent(silent);
186 gConfig().SetDrawProgressBar(drawProgressBar);
187
// NOTE(review): as above, the statement belonging to each branch (lines 190, 192, 194, 196)
// is absent from this listing.
188 analysisType.ToLower();
189 if (analysisType == "classification")
191 else if (analysisType == "regression")
193 else if (analysisType == "multiclass")
195 else if (analysisType == "auto")
197
198 // Greetings();
199}
200
201////////////////////////////////////////////////////////////////////////////////
202/// Constructor.
203
// Constructor variant that is not bound to an output file: fgTargetFile is set to nullptr and
// fSilentFile to kTRUE. It declares and parses the same options as the file-bound constructor
// and, unlike that one, calls Greetings() at the end.
// NOTE(review): the signature line (204) is absent from this listing; the body uses the names
// jobName and theOption.
205 : Configurable(theOption), fTransformations("I"), fVerbose(kFALSE), fCorrelations(kFALSE), fROC(kTRUE),
206 fSilentFile(kTRUE), fJobName(jobName), fAnalysisType(Types::kClassification), fModelPersistence(kTRUE)
207{
208 fName = "Factory";
209 fgTargetFile = nullptr;
211
212 // render silent
213 if (gTools().CheckForSilentOption(GetOptions()))
214 Log().InhibitOutput(); // make sure is silent if wanted to
215
216 // init configurable
217 SetConfigDescription("Configuration options for Factory running");
219
220 // histograms are not automatically associated with the current
221 // directory and hence don't go out of scope when closing the file
// Locals that back the "Silent", "Color" and "DrawProgressBar" options declared below;
// their parsed values are pushed to gConfig() after ParseOptions().
223 Bool_t silent = kFALSE;
224#ifdef WIN32
225 // under Windows, switch progress bar and color off by default, as the typical windows shell doesn't handle these
226 // (would need different sequences..)
227 Bool_t color = kFALSE;
228 Bool_t drawProgressBar = kFALSE;
229#else
230 Bool_t color = !gROOT->IsBatch();
231 Bool_t drawProgressBar = kTRUE;
232#endif
233 DeclareOptionRef(fVerbose, "V", "Verbose flag");
234 DeclareOptionRef(fVerboseLevel = TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)");
235 AddPreDefVal(TString("Debug"));
236 AddPreDefVal(TString("Verbose"));
237 AddPreDefVal(TString("Info"));
238 DeclareOptionRef(color, "Color", "Flag for coloured screen output (default: True, if in batch mode: False)");
// NOTE(review): listing line 239 is absent here; presumably it opens the call whose arguments
// follow on lines 240-242 - confirm against the upstream file.
240 fTransformations, "Transformations",
241 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for identity, "
242 "decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation transformations");
243 DeclareOptionRef(fCorrelations, "Correlations", "boolean to show correlation in output");
244 DeclareOptionRef(fROC, "ROC", "boolean to show ROC in output");
245 DeclareOptionRef(silent, "Silent",
246 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
247 "class object (default: False)");
248 DeclareOptionRef(drawProgressBar, "DrawProgressBar",
249 "Draw progress bar to display training, testing and evaluation schedule (default: True)");
250 DeclareOptionRef(fModelPersistence, "ModelPersistence",
251 "Option to save the trained model in xml file or using serialization");
252
// The analysis type is parsed into a local string and compared in lower case further down.
253 TString analysisType("Auto");
254 DeclareOptionRef(analysisType, "AnalysisType",
255 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
256 AddPreDefVal(TString("Classification"));
257 AddPreDefVal(TString("Regression"));
258 AddPreDefVal(TString("Multiclass"));
259 AddPreDefVal(TString("Auto"));
260
261 ParseOptions();
263
// NOTE(review): the listing numbers jump after each condition below (265, 267, 269, 271 are
// absent), so the statement guarded by each condition is not visible here - confirm upstream.
264 if (Verbose())
266 if (fVerboseLevel.CompareTo("Debug") == 0)
268 if (fVerboseLevel.CompareTo("Verbose") == 0)
270 if (fVerboseLevel.CompareTo("Info") == 0)
272
273 // global settings
274 gConfig().SetUseColor(color);
275 gConfig().SetSilent(silent);
276 gConfig().SetDrawProgressBar(drawProgressBar);
277
// NOTE(review): as above, the statement belonging to each branch (lines 280, 282, 284, 286)
// is absent from this listing.
278 analysisType.ToLower();
279 if (analysisType == "classification")
281 else if (analysisType == "regression")
283 else if (analysisType == "multiclass")
285 else if (analysisType == "auto")
287
288 Greetings();
289}
290
291////////////////////////////////////////////////////////////////////////////////
292/// Print welcome message.
293/// Options are: kLogoWelcomeMsg, kIsometricWelcomeMsg, kLeanWelcomeMsg
294
296{
// Emits the welcome message in its kLogoWelcomeMsg form through this object's logger and
// finishes the log line.
// NOTE(review): the signature line and listing lines 297 and 299 are absent from this listing.
298 gTools().TMVAWelcomeMessage(Log(), gTools().kLogoWelcomeMsg);
300 Log() << Endl;
301}
302
303////////////////////////////////////////////////////////////////////////////////
304/// Destructor.
305
307{
308 std::vector<TMVA::VariableTransformBase *>::iterator trfIt = fDefaultTrfs.begin();
309 for (; trfIt != fDefaultTrfs.end(); ++trfIt)
310 delete (*trfIt);
311
312 this->DeleteAllMethods();
313
314 // problem with call of REGISTER_METHOD macro ...
315 // ClassifierFactory::DestroyInstance();
316 // Types::DestroyInstance();
317 // Tools::DestroyInstance();
318 // Config::DestroyInstance();
319}
320
321////////////////////////////////////////////////////////////////////////////////
322/// Delete methods.
323
325{
326 std::map<TString, MVector *>::iterator itrMap;
327
328 for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap) {
329 MVector *methods = itrMap->second;
330 // delete methods
331 MVector::iterator itrMethod = methods->begin();
332 for (; itrMethod != methods->end(); ++itrMethod) {
333 Log() << kDEBUG << "Delete method: " << (*itrMethod)->GetName() << Endl;
334 delete (*itrMethod);
335 }
336 methods->clear();
337 delete methods;
338 }
339}
340
341////////////////////////////////////////////////////////////////////////////////
342
344{
// Stores the given flag v in fVerbose; nothing else is updated here.
// NOTE(review): the signature line is absent from this listing.
345 fVerbose = v;
346}
347
348////////////////////////////////////////////////////////////////////////////////
349/// Book a classifier or regression method.
350
// Books a method, identified by its type name, for the dataset held by `loader`.
//  - loader        : supplies the dataset name and the DataSetInfo used to create the method
//  - theMethodName : method type name handed to ClassifierFactory::Create
//  - methodTitle   : title of this instance; a title already booked for the dataset is logged as kFATAL
//  - theOption     : option string; a non-zero "Boost_num" entry books a "Boost" method instead
// Returns the configured method, which is also appended to fMethodsMap under the loader's name,
// or a null pointer when creation, a required cast or the analysis-type check fails.
// NOTE(review): the return-type line (351) is absent from this listing.
352TMVA::Factory::BookMethod(TMVA::DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption)
353{
354 if (fModelPersistence)
355 gSystem->MakeDirectory(loader->GetName()); // creating directory for DataLoader output
356
357 TString datasetname = loader->GetName();
358
// With no analysis type chosen yet, derive it from the classes declared in the loader.
359 if (fAnalysisType == Types::kNoAnalysisType) {
360 if (loader->GetDataSetInfo().GetNClasses() == 2 && loader->GetDataSetInfo().GetClassInfo("Signal") != NULL &&
361 loader->GetDataSetInfo().GetClassInfo("Background") != NULL) {
362 fAnalysisType = Types::kClassification; // default is classification
363 } else if (loader->GetDataSetInfo().GetNClasses() >= 2) {
364 fAnalysisType = Types::kMulticlass; // if two classes, but not named "Signal" and "Background"
365 } else
366 Log() << kFATAL << "No analysis type for " << loader->GetDataSetInfo().GetNClasses() << " classes and "
367 << loader->GetDataSetInfo().GetNTargets() << " regression targets." << Endl;
368 }
369
370 // booking via name; the names are translated into enums and the
371 // corresponding overloaded BookMethod is called
372
// Refuse to book a second method with the same title for the same dataset.
373 if (fMethodsMap.find(datasetname) != fMethodsMap.end()) {
374 if (GetMethod(datasetname, methodTitle) != 0) {
375 Log() << kFATAL << "Booking failed since method with title <" << methodTitle << "> already exists "
376 << "in with DataSet Name <" << loader->GetName() << "> " << Endl;
377 }
378 }
379
380 Log() << kHEADER << "Booking method: " << gTools().Color("bold")
381 << methodTitle
382 // << gTools().Color("reset")<<" DataSet Name: "<<gTools().Color("bold")<<loader->GetName()
383 << gTools().Color("reset") << Endl << Endl;
384
385 // interpret option string with respect to a request for boosting (i.e., Boost_num > 0)
// A temporary Configurable is used only to read "Boost_num" out of the option string.
386 Int_t boostNum = 0;
387 TMVA::Configurable *conf = new TMVA::Configurable(theOption);
388 conf->DeclareOptionRef(boostNum = 0, "Boost_num", "Number of times the classifier will be boosted");
389 conf->ParseOptions();
390 delete conf;
391 // this is name of weight file directory
392 TString fileDir;
393 if (fModelPersistence) {
394 // find prefix in fWeightFileDir;
// NOTE(review): listing line 395, which presumably declares `prefix`, is absent from this listing.
// The resulting directory is "<prefix>/<loader name>/<fWeightFileDir>".
396 fileDir = prefix;
397 if (!prefix.IsNull())
398 if (fileDir[fileDir.Length() - 1] != '/')
399 fileDir += "/";
400 fileDir += loader->GetName();
401 fileDir += "/" + gConfig().GetIONames().fWeightFileDir;
402 }
403 // initialize methods
404 IMethod *im;
405 if (!boostNum) {
406 im = ClassifierFactory::Instance().Create(theMethodName.Data(), fJobName, methodTitle, loader->GetDataSetInfo(),
407 theOption);
408 } else {
409 // boosted classifier, requires a specific definition, making it transparent for the user
410 Log() << kDEBUG << "Boost Number is " << boostNum << " > 0: train boosted classifier" << Endl;
411 im = ClassifierFactory::Instance().Create("Boost", fJobName, methodTitle, loader->GetDataSetInfo(), theOption);
412 MethodBoost *methBoost = dynamic_cast<MethodBoost *>(im); // DSMTEST divided into two lines
413 if (!methBoost) { // DSMTEST
// NOTE(review): the message below names MethodCategory although the failed cast targets MethodBoost.
414 Log() << kFATAL << "Method with type kBoost cannot be casted to MethodCategory. /Factory" << Endl; // DSMTEST
415 return nullptr;
416 }
417 if (fModelPersistence)
418 methBoost->SetWeightFileDir(fileDir);
419 methBoost->SetModelPersistence(fModelPersistence);
420 methBoost->SetBoostedMethodName(theMethodName); // DSMTEST divided into two lines
421 methBoost->fDataSetManager = loader->GetDataSetInfo().GetDataSetManager(); // DSMTEST
422 methBoost->SetFile(fgTargetFile);
423 methBoost->SetSilentFile(IsSilentFile());
424 }
425
426 MethodBase *method = dynamic_cast<MethodBase *>(im);
427 if (method == 0)
428 return 0; // could not create method
429
430 // set fDataSetManager if MethodCategory (to enable Category to create datasetinfo objects) // DSMTEST
431 if (method->GetMethodType() == Types::kCategory) { // DSMTEST
432 MethodCategory *methCat = (dynamic_cast<MethodCategory *>(im)); // DSMTEST
433 if (!methCat) { // DSMTEST
434 Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Factory"
435 << Endl; // DSMTEST
436 return nullptr;
437 }
438 if (fModelPersistence)
439 methCat->SetWeightFileDir(fileDir);
440 methCat->SetModelPersistence(fModelPersistence);
441 methCat->fDataSetManager = loader->GetDataSetInfo().GetDataSetManager(); // DSMTEST
442 methCat->SetFile(fgTargetFile);
443 methCat->SetSilentFile(IsSilentFile());
444 } // DSMTEST
445
// NOTE(review): on this failure path the freshly created method is returned from without being
// deleted or stored in fMethodsMap - confirm whether something else releases it.
446 if (!method->HasAnalysisType(fAnalysisType, loader->GetDataSetInfo().GetNClasses(),
447 loader->GetDataSetInfo().GetNTargets())) {
448 Log() << kWARNING << "Method " << method->GetMethodTypeName() << " is not capable of handling ";
449 if (fAnalysisType == Types::kRegression) {
450 Log() << "regression with " << loader->GetDataSetInfo().GetNTargets() << " targets." << Endl;
451 } else if (fAnalysisType == Types::kMulticlass) {
452 Log() << "multiclass classification with " << loader->GetDataSetInfo().GetNClasses() << " classes." << Endl;
453 } else {
454 Log() << "classification with " << loader->GetDataSetInfo().GetNClasses() << " classes." << Endl;
455 }
456 return 0;
457 }
458
459 if (fModelPersistence)
460 method->SetWeightFileDir(fileDir);
461 method->SetModelPersistence(fModelPersistence);
462 method->SetAnalysisType(fAnalysisType);
463 method->SetupMethod();
464 method->ParseOptions();
465 method->ProcessSetup();
466 method->SetFile(fgTargetFile);
467 method->SetSilentFile(IsSilentFile());
468
469 // check-for-unused-options is performed; may be overridden by derived classes
470 method->CheckSetup();
471
// Create the per-dataset method vector on first use, then register the method in it.
472 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
473 MVector *mvector = new MVector;
474 fMethodsMap[datasetname] = mvector;
475 }
476 fMethodsMap[datasetname]->push_back(method);
477 return method;
478}
479
480////////////////////////////////////////////////////////////////////////////////
481/// Books MVA method. The option configuration string is custom for each MVA
482/// the TString field "theNameAppendix" serves to define (and distinguish)
483/// several instances of a given MVA, eg, when one wants to compare the
484/// performance of various configurations
485
// Enum-based overload: translates theMethod into its method name through
// Types::Instance().GetMethodName() and forwards to the name-based BookMethod with the other
// arguments unchanged.
// NOTE(review): the return-type line (486) is absent from this listing.
487TMVA::Factory::BookMethod(TMVA::DataLoader *loader, Types::EMVA theMethod, TString methodTitle, TString theOption)
488{
489 return BookMethod(loader, Types::Instance().GetMethodName(theMethod), methodTitle, theOption);
490}
491
492////////////////////////////////////////////////////////////////////////////////
493/// Adds an already constructed method to be managed by this factory.
494///
495/// \note Private.
496/// \note Know what you are doing when using this method. The method that you
497/// are loading could be trained already.
498///
499
502{
// Creates a method of type `methodType` from `weightfile`, restores its state through
// ReadStateFromFile() and registers it in fMethodsMap under the loader's name.
// Returns the method, or nullptr when the created object is not a MethodBase.
// NOTE(review): the signature lines (500-501) and listing lines 520 and 537 are absent from this
// listing; the body uses the names loader, methodType and weightfile.
503 TString datasetname = loader->GetName();
504 std::string methodTypeName = std::string(Types::Instance().GetMethodName(methodType).Data());
505 DataSetInfo &dsi = loader->GetDataSetInfo();
506
507 IMethod *im = ClassifierFactory::Instance().Create(methodTypeName, dsi, weightfile);
508 MethodBase *method = (dynamic_cast<MethodBase *>(im));
509
510 if (method == nullptr)
511 return nullptr;
512
// A category method is only reported at kERROR level here; execution continues afterwards.
513 if (method->GetMethodType() == Types::kCategory) {
514 Log() << kERROR << "Cannot handle category methods for now." << Endl;
515 }
516
517 TString fileDir;
518 if (fModelPersistence) {
519 // find prefix in fWeightFileDir;
// NOTE(review): listing line 520, which presumably declares `prefix`, is absent from this listing.
521 fileDir = prefix;
522 if (!prefix.IsNull())
523 if (fileDir[fileDir.Length() - 1] != '/')
524 fileDir += "/";
// NOTE(review): the plain assignment below discards the prefix assembled just above; the
// name-based BookMethod appends the loader name with "+=" at the same point. Looks unintended -
// confirm and align.
525 fileDir = loader->GetName();
526 fileDir += "/" + gConfig().GetIONames().fWeightFileDir;
527 }
528
529 if (fModelPersistence)
530 method->SetWeightFileDir(fileDir);
531 method->SetModelPersistence(fModelPersistence);
532 method->SetAnalysisType(fAnalysisType);
533 method->SetupMethod();
534 method->SetFile(fgTargetFile);
535 method->SetSilentFile(IsSilentFile());
536
538
539 // read weight file
540 method->ReadStateFromFile();
541
542 // method->CheckSetup();
543
// The title used for the duplicate check is the name the method reports after its state was read.
544 TString methodTitle = method->GetName();
545 if (HasMethod(datasetname, methodTitle) != 0) {
546 Log() << kFATAL << "Booking failed since method with title <" << methodTitle << "> already exists "
547 << "in with DataSet Name <" << loader->GetName() << "> " << Endl;
548 }
549
550 Log() << kINFO << "Booked classifier \"" << method->GetMethodName() << "\" of type: \""
551 << method->GetMethodTypeName() << "\"" << Endl;
552
// Create the per-dataset method vector on first use, then register the method in it.
553 if (fMethodsMap.count(datasetname) == 0) {
554 MVector *mvector = new MVector;
555 fMethodsMap[datasetname] = mvector;
556 }
557
558 fMethodsMap[datasetname]->push_back(method);
559
560 return method;
561}
562
563////////////////////////////////////////////////////////////////////////////////
564/// Returns pointer to MVA that corresponds to given method title.
565
566TMVA::IMethod *TMVA::Factory::GetMethod(const TString &datasetname, const TString &methodTitle) const
567{
568 if (fMethodsMap.find(datasetname) == fMethodsMap.end())
569 return 0;
570
571 MVector *methods = fMethodsMap.find(datasetname)->second;
572
573 MVector::const_iterator itrMethod;
574 //
575 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
576 MethodBase *mva = dynamic_cast<MethodBase *>(*itrMethod);
577 if ((mva->GetMethodName()) == methodTitle)
578 return mva;
579 }
580 return 0;
581}
582
583////////////////////////////////////////////////////////////////////////////////
584/// Checks whether a given method name is defined for a given dataset.
585
586Bool_t TMVA::Factory::HasMethod(const TString &datasetname, const TString &methodTitle) const
587{
588 if (fMethodsMap.find(datasetname) == fMethodsMap.end())
589 return 0;
590
591 std::string methodName = methodTitle.Data();
592 auto isEqualToMethodName = [&methodName](TMVA::IMethod *m) { return (0 == methodName.compare(m->GetName())); };
593
594 TMVA::Factory::MVector *methods = this->fMethodsMap.at(datasetname);
595 Bool_t isMethodNameExisting = std::any_of(methods->begin(), methods->end(), isEqualToMethodName);
596
597 return isMethodNameExisting;
598}
599
600////////////////////////////////////////////////////////////////////////////////
601
603{
604 RootBaseDir()->cd();
605
606 if (!RootBaseDir()->GetDirectory(fDataSetInfo.GetName()))
607 RootBaseDir()->mkdir(fDataSetInfo.GetName());
608 else
609 return; // loader is now in the output file, we dont need to save again
610
611 RootBaseDir()->cd(fDataSetInfo.GetName());
612 fDataSetInfo.GetDataSet(); // builds dataset (including calculation of correlation matrix)
613
614 // correlation matrix of the default DS
615 const TMatrixD *m(0);
616 const TH2 *h(0);
617
618 if (fAnalysisType == Types::kMulticlass) {
619 for (UInt_t cls = 0; cls < fDataSetInfo.GetNClasses(); cls++) {
620 m = fDataSetInfo.CorrelationMatrix(fDataSetInfo.GetClassInfo(cls)->GetName());
621 h = fDataSetInfo.CreateCorrelationMatrixHist(
622 m, TString("CorrelationMatrix") + fDataSetInfo.GetClassInfo(cls)->GetName(),
623 TString("Correlation Matrix (") + fDataSetInfo.GetClassInfo(cls)->GetName() + TString(")"));
624 if (h != 0) {
625 h->Write();
626 delete h;
627 }
628 }
629 } else {
630 m = fDataSetInfo.CorrelationMatrix("Signal");
631 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrixS", "Correlation Matrix (signal)");
632 if (h != 0) {
633 h->Write();
634 delete h;
635 }
636
637 m = fDataSetInfo.CorrelationMatrix("Background");
638 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrixB", "Correlation Matrix (background)");
639 if (h != 0) {
640 h->Write();
641 delete h;
642 }
643
644 m = fDataSetInfo.CorrelationMatrix("Regression");
645 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrix", "Correlation Matrix");
646 if (h != 0) {
647 h->Write();
648 delete h;
649 }
650 }
651
652 // some default transformations to evaluate
653 // NOTE: all transformations are destroyed after this test
654 TString processTrfs = "I"; //"I;N;D;P;U;G,D;"
655
656 // plus some user defined transformations
657 processTrfs = fTransformations;
658
659 // remove any trace of identity transform - if given (avoid to apply it twice)
660 std::vector<TMVA::TransformationHandler *> trfs;
661 TransformationHandler *identityTrHandler = 0;
662
663 std::vector<TString> trfsDef = gTools().SplitString(processTrfs, ';');
664 std::vector<TString>::iterator trfsDefIt = trfsDef.begin();
665 for (; trfsDefIt != trfsDef.end(); ++trfsDefIt) {
666 trfs.push_back(new TMVA::TransformationHandler(fDataSetInfo, "Factory"));
667 TString trfS = (*trfsDefIt);
668
669 // Log() << kINFO << Endl;
670 Log() << kDEBUG << "current transformation string: '" << trfS.Data() << "'" << Endl;
671 TMVA::CreateVariableTransforms(trfS, fDataSetInfo, *(trfs.back()), Log());
672
673 if (trfS.BeginsWith('I'))
674 identityTrHandler = trfs.back();
675 }
676
677 const std::vector<Event *> &inputEvents = fDataSetInfo.GetDataSet()->GetEventCollection();
678
679 // apply all transformations
680 std::vector<TMVA::TransformationHandler *>::iterator trfIt = trfs.begin();
681
682 for (; trfIt != trfs.end(); ++trfIt) {
683 // setting a Root dir causes the variables distributions to be saved to the root file
684 (*trfIt)->SetRootDir(RootBaseDir()->GetDirectory(fDataSetInfo.GetName())); // every dataloader have its own dir
685 (*trfIt)->CalcTransformations(inputEvents);
686 }
687 if (identityTrHandler)
688 identityTrHandler->PrintVariableRanking();
689
690 // clean up
691 for (trfIt = trfs.begin(); trfIt != trfs.end(); ++trfIt)
692 delete *trfIt;
693}
694
695////////////////////////////////////////////////////////////////////////////////
696/// Iterates through all booked methods and sees if they use parameter tuning and if so
697/// does just that, i.e.\ calls "Method::Train()" for different parameter settings and
698/// keeps in mind the "optimal one"...\ and that's the one that will later on be used
699/// in the main training loop.
700
// Calls OptimizeTuningParameters(fomType, fitType) on every booked method of every dataset.
// Returns the parameter map produced by the last method that was optimised: TunedParameters is
// reassigned on each iteration, so results of earlier methods are not kept in the return value.
// Returns early, after a kFATAL log, when a booked method is not a MethodBase.
701std::map<TString, Double_t> TMVA::Factory::OptimizeAllMethods(TString fomType, TString fitType)
702{
703
704 std::map<TString, MVector *>::iterator itrMap;
705 std::map<TString, Double_t> TunedParameters;
706 for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap) {
707 MVector *methods = itrMap->second;
708
709 MVector::iterator itrMethod;
710
711 // iterate over methods and optimize
712 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
// NOTE(review): listing line 713 is absent from this listing.
714 MethodBase *mva = dynamic_cast<MethodBase *>(*itrMethod);
715 if (!mva) {
716 Log() << kFATAL << "Dynamic cast to MethodBase failed" << Endl;
717 return TunedParameters;
718 }
719
// NOTE(review): listing line 720 is absent; presumably it holds the condition that opens the
// block below, which warns about too few training events and skips the method.
721 Log() << kWARNING << "Method " << mva->GetMethodName() << " not trained (training tree has less entries ["
722 << mva->Data()->GetNTrainingEvents() << "] than required [" << MinNoTrainingEvents << "]" << Endl;
723 continue;
724 }
725
726 Log() << kINFO << "Optimize method: " << mva->GetMethodName() << " for "
727 << (fAnalysisType == Types::kRegression
728 ? "Regression"
729 : (fAnalysisType == Types::kMulticlass ? "Multiclass classification" : "Classification"))
730 << Endl;
731
732 TunedParameters = mva->OptimizeTuningParameters(fomType, fitType);
733 Log() << kINFO << "Optimization of tuning parameters finished for Method:" << mva->GetName() << Endl;
734 }
735 }
736
737 return TunedParameters;
738}
739
740////////////////////////////////////////////////////////////////////////////////
741/// Private method to generate a ROCCurve instance for a given method.
742/// Handles the conversion from TMVA ResultSet to a format the ROCCurve class
743/// understands.
744///
745/// \note You own the returned pointer.
746///
747
750{
// Convenience overload: forwards to the dataset-name overload, passing the loader's name and
// the remaining arguments unchanged.
// NOTE(review): the signature lines are absent from this listing.
751 return GetROC((TString)loader->GetName(), theMethodName, iClass, type);
752}
753
754////////////////////////////////////////////////////////////////////////////////
755/// Private method to generate a ROCCurve instance for a given method.
756/// Handles the conversion from TMVA ResultSet to a format the ROCCurve class
757/// understands.
758///
759/// \note You own the returned pointer.
760///
761
763{
764 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
765 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
766 return nullptr;
767 }
768
769 if (!this->HasMethod(datasetname, theMethodName)) {
770 Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data())
771 << Endl;
772 return nullptr;
773 }
774
775 std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
776 if (allowedAnalysisTypes.count(this->fAnalysisType) == 0) {
777 Log() << kERROR << Form("Can only generate ROC curves for analysis type kClassification and kMulticlass.")
778 << Endl;
779 return nullptr;
780 }
781
782 TMVA::MethodBase *method = dynamic_cast<TMVA::MethodBase *>(this->GetMethod(datasetname, theMethodName));
783 TMVA::DataSet *dataset = method->Data();
784 dataset->SetCurrentType(type);
785 TMVA::Results *results = dataset->GetResults(theMethodName, type, this->fAnalysisType);
786
787 UInt_t nClasses = method->DataInfo().GetNClasses();
788 if (this->fAnalysisType == Types::kMulticlass && iClass >= nClasses) {
789 Log() << kERROR
790 << Form("Given class number (iClass = %i) does not exist. There are %i classes in dataset.", iClass,
791 nClasses)
792 << Endl;
793 return nullptr;
794 }
795
796 TMVA::ROCCurve *rocCurve = nullptr;
797 if (this->fAnalysisType == Types::kClassification) {
798
799 std::vector<Float_t> *mvaRes = dynamic_cast<ResultsClassification *>(results)->GetValueVector();
800 std::vector<Bool_t> *mvaResTypes = dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
801 std::vector<Float_t> mvaResWeights;
802
803 auto eventCollection = dataset->GetEventCollection(type);
804 mvaResWeights.reserve(eventCollection.size());
805 for (auto ev : eventCollection) {
806 mvaResWeights.push_back(ev->GetWeight());
807 }
808
809 rocCurve = new TMVA::ROCCurve(*mvaRes, *mvaResTypes, mvaResWeights);
810
811 } else if (this->fAnalysisType == Types::kMulticlass) {
812 std::vector<Float_t> mvaRes;
813 std::vector<Bool_t> mvaResTypes;
814 std::vector<Float_t> mvaResWeights;
815
816 std::vector<std::vector<Float_t>> *rawMvaRes = dynamic_cast<ResultsMulticlass *>(results)->GetValueVector();
817
818 // Vector transpose due to values being stored as
819 // [ [0, 1, 2], [0, 1, 2], ... ]
820 // in ResultsMulticlass::GetValueVector.
821 mvaRes.reserve(rawMvaRes->size());
822 for (auto item : *rawMvaRes) {
823 mvaRes.push_back(item[iClass]);
824 }
825
826 auto eventCollection = dataset->GetEventCollection(type);
827 mvaResTypes.reserve(eventCollection.size());
828 mvaResWeights.reserve(eventCollection.size());
829 for (auto ev : eventCollection) {
830 mvaResTypes.push_back(ev->GetClass() == iClass);
831 mvaResWeights.push_back(ev->GetWeight());
832 }
833
834 rocCurve = new TMVA::ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
835 }
836
837 return rocCurve;
838}
839
840////////////////////////////////////////////////////////////////////////////////
841/// Calculate the integral of the ROC curve, also known as the area under curve
842/// (AUC), for a given method.
843///
844/// Argument iClass specifies the class to generate the ROC curve in a
845/// multiclass setting. It is ignored for binary classification.
846///
847
850{
// Convenience overload: forwards to the dataset-name overload, passing the loader's name and
// the remaining arguments unchanged.
// NOTE(review): the signature lines are absent from this listing.
851 return GetROCIntegral((TString)loader->GetName(), theMethodName, iClass, type);
852}
853
854////////////////////////////////////////////////////////////////////////////////
855/// Calculate the integral of the ROC curve, also known as the area under curve
856/// (AUC), for a given method.
857///
858/// Argument iClass specifies the class to generate the ROC curve in a
859/// multiclass setting. It is ignored for binary classification.
860///
861
863{
// Computes the ROC integral of one booked method: validates the request, obtains a ROCCurve
// from GetROC(), asks it for the integral over `npoints` points and deletes the curve.
// Returns 0 after logging when the dataset or method is unknown, when the analysis type is
// neither classification nor multiclass, or when GetROC() returns no curve.
// NOTE(review): the signature line and listing line 891, which presumably declares `npoints`,
// are absent from this listing.
864 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
865 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
866 return 0;
867 }
868
869 if (!this->HasMethod(datasetname, theMethodName)) {
870 Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data())
871 << Endl;
872 return 0;
873 }
874
875 std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
876 if (allowedAnalysisTypes.count(this->fAnalysisType) == 0) {
877 Log() << kERROR << Form("Can only generate ROC integral for analysis type kClassification. and kMulticlass.")
878 << Endl;
879 return 0;
880 }
881
// The curve returned by GetROC() is deleted below, after the integral has been read.
882 TMVA::ROCCurve *rocCurve = GetROC(datasetname, theMethodName, iClass, type);
883 if (!rocCurve) {
884 Log() << kFATAL
885 << Form("ROCCurve object was not created in Method = %s not found with Dataset = %s ", theMethodName.Data(),
886 datasetname.Data())
887 << Endl;
888 return 0;
889 }
890
892 Double_t rocIntegral = rocCurve->GetROCIntegral(npoints);
893 delete rocCurve;
894
895 return rocIntegral;
896}
897
898////////////////////////////////////////////////////////////////////////////////
899/// Argument iClass specifies the class to generate the ROC curve in a
900/// multiclass setting. It is ignored for binary classification.
901///
902/// Returns a ROC graph for a given method, or nullptr on error.
903///
904/// Note: Evaluation of the given method must have been run prior to ROC
905/// generation through Factory::EvaluateAllMethods.
906///
907/// NOTE: The ROC curve is 1 vs. all, where the given class is considered signal
908/// and the others are considered background. This is fine in binary classification,
909/// but in multiclass classification the ROC surface is an N-dimensional
910/// shape, where N is the number of classes - 1.
911
912TGraph *TMVA::Factory::GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles, UInt_t iClass,
914{
915 return GetROCCurve((TString)loader->GetName(), theMethodName, setTitles, iClass, type);
916}
917
918////////////////////////////////////////////////////////////////////////////////
919/// Argument iClass specifies the class to generate the ROC curve in a
920/// multiclass setting. It is ignored for binary classification.
921///
922/// Returns a ROC graph for a given method, or nullptr on error.
923///
924/// Note: Evaluation of the given method must have been run prior to ROC
925/// generation through Factory::EvaluateAllMethods.
926///
927/// NOTE: The ROC curve is 1 vs. all, where the given class is considered signal
928/// and the others are considered background. This is fine in binary classification,
929/// but in multiclass classification the ROC surface is an N-dimensional
930/// shape, where N is the number of classes - 1.
931
932TGraph *TMVA::Factory::GetROCCurve(TString datasetname, TString theMethodName, Bool_t setTitles, UInt_t iClass,
934{
935 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
936 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
937 return nullptr;
938 }
939
940 if (!this->HasMethod(datasetname, theMethodName)) {
941 Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data())
942 << Endl;
943 return nullptr;
944 }
945
946 std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
947 if (allowedAnalysisTypes.count(this->fAnalysisType) == 0) {
948 Log() << kERROR << Form("Can only generate ROC curves for analysis type kClassification and kMulticlass.")
949 << Endl;
950 return nullptr;
951 }
952
953 TMVA::ROCCurve *rocCurve = GetROC(datasetname, theMethodName, iClass, type);
954 TGraph *graph = nullptr;
955
956 if (!rocCurve) {
957 Log() << kFATAL
958 << Form("ROCCurve object was not created in Method = %s not found with Dataset = %s ", theMethodName.Data(),
959 datasetname.Data())
960 << Endl;
961 return nullptr;
962 }
963
964 graph = (TGraph *)rocCurve->GetROCCurve()->Clone();
965 delete rocCurve;
966
967 if (setTitles) {
968 graph->GetYaxis()->SetTitle("Background rejection (Specificity)");
969 graph->GetXaxis()->SetTitle("Signal efficiency (Sensitivity)");
970 graph->SetTitle(Form("Signal efficiency vs. Background rejection (%s)", theMethodName.Data()));
971 }
972
973 return graph;
974}
975
976////////////////////////////////////////////////////////////////////////////////
977/// Generate a collection of graphs, for all methods for a given class. Suitable
978/// for comparing method performance.
979///
980/// Argument iClass specifies the class to generate the ROC curve in a
981/// multiclass setting. It is ignored for binary classification.
982///
983/// NOTE: The ROC curve is 1 vs. all, where the given class is considered signal
984/// and the others are considered background. This is fine in binary classification,
985/// but in multiclass classification the ROC surface is an N-dimensional
986/// shape, where N is the number of classes - 1.
987
989{
990 return GetROCCurveAsMultiGraph((TString)loader->GetName(), iClass, type);
991}
992
993////////////////////////////////////////////////////////////////////////////////
994/// Generate a collection of graphs, for all methods for a given class. Suitable
995/// for comparing method performance.
996///
997/// Argument iClass specifies the class to generate the ROC curve in a
998/// multiclass setting. It is ignored for binary classification.
999///
1000/// NOTE: The ROC curve is 1 vs. all, where the given class is considered signal
1001/// and the others are considered background. This is fine in binary classification,
1002/// but in multiclass classification the ROC surface is an N-dimensional
1003/// shape, where N is the number of classes - 1.
1004
1006{
1007 UInt_t line_color = 1;
1008
1009 TMultiGraph *multigraph = new TMultiGraph();
1010
1011 MVector *methods = fMethodsMap[datasetname.Data()];
1012 for (auto *method_raw : *methods) {
1013 TMVA::MethodBase *method = dynamic_cast<TMVA::MethodBase *>(method_raw);
1014 if (method == nullptr) {
1015 continue;
1016 }
1017
1018 TString methodName = method->GetMethodName();
1019 UInt_t nClasses = method->DataInfo().GetNClasses();
1020
1021 if (this->fAnalysisType == Types::kMulticlass && iClass >= nClasses) {
1022 Log() << kERROR
1023 << Form("Given class number (iClass = %i) does not exist. There are %i classes in dataset.", iClass,
1024 nClasses)
1025 << Endl;
1026 continue;
1027 }
1028
1029 TString className = method->DataInfo().GetClassInfo(iClass)->GetName();
1030
1031 TGraph *graph = this->GetROCCurve(datasetname, methodName, false, iClass, type);
1032 graph->SetTitle(methodName);
1033
1034 graph->SetLineWidth(2);
1035 graph->SetLineColor(line_color++);
1036 graph->SetFillColor(10);
1037
1038 multigraph->Add(graph);
1039 }
1040
1041 if (multigraph->GetListOfGraphs() == nullptr) {
1042 Log() << kERROR << Form("No metohds have class %i defined.", iClass) << Endl;
1043 return nullptr;
1044 }
1045
1046 return multigraph;
1047}
1048
1049////////////////////////////////////////////////////////////////////////////////
1050/// Draws ROC curves for all methods booked with the factory for a given class
1051/// onto a canvas.
1052///
1053/// Argument iClass specifies the class to generate the ROC curve in a
1054/// multiclass setting. It is ignored for binary classification.
1055///
1056/// NOTE: The ROC curve is 1 vs. all, where the given class is considered signal
1057/// and the others are considered background. This is fine in binary classification,
1058/// but in multiclass classification the ROC surface is an N-dimensional
1059/// shape, where N is the number of classes - 1.
1060
1062{
1063 return GetROCCurve((TString)loader->GetName(), iClass, type);
1064}
1065
1066////////////////////////////////////////////////////////////////////////////////
1067/// Draws ROC curves for all methods booked with the factory for a given class.
1068///
1069/// Argument iClass specifies the class to generate the ROC curve in a
1070/// multiclass setting. It is ignored for binary classification.
1071///
1072/// NOTE: The ROC curve is 1 vs. all, where the given class is considered signal
1073/// and the others are considered background. This is fine in binary classification,
1074/// but in multiclass classification the ROC surface is an N-dimensional
1075/// shape, where N is the number of classes - 1.
1076
1078{
1079 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
1080 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
1081 return 0;
1082 }
1083
1084 TString name = Form("ROCCurve %s class %i", datasetname.Data(), iClass);
1085 TCanvas *canvas = new TCanvas(name, "ROC Curve", 200, 10, 700, 500);
1086 canvas->SetGrid();
1087
1088 TMultiGraph *multigraph = this->GetROCCurveAsMultiGraph(datasetname, iClass, type);
1089
1090 if (multigraph) {
1091 multigraph->Draw("AL");
1092
1093 multigraph->GetYaxis()->SetTitle("Background rejection (Specificity)");
1094 multigraph->GetXaxis()->SetTitle("Signal efficiency (Sensitivity)");
1095
1096 TString titleString = Form("Signal efficiency vs. Background rejection");
1097 if (this->fAnalysisType == Types::kMulticlass) {
1098 titleString = Form("%s (Class=%i)", titleString.Data(), iClass);
1099 }
1100
1101 // Workaround for TMultigraph not drawing title correctly.
1102 multigraph->GetHistogram()->SetTitle(titleString);
1103 multigraph->SetTitle(titleString);
1104
1105 canvas->BuildLegend(0.15, 0.15, 0.35, 0.3, "MVA Method");
1106 }
1107
1108 return canvas;
1109}
1110
1111////////////////////////////////////////////////////////////////////////////////
1112/// Iterates through all booked methods and calls training
1113
1115{
1116 Log() << kHEADER << gTools().Color("bold") << "Train all methods" << gTools().Color("reset") << Endl;
1117 // iterates over all MVAs that have been booked, and calls their training methods
1118
1119 // don't do anything if no method booked
1120 if (fMethodsMap.empty()) {
1121 Log() << kINFO << "...nothing found to train" << Endl;
1122 return;
1123 }
1124
1125 // here the training starts
1126 // Log() << kINFO << " " << Endl;
1127 Log() << kDEBUG << "Train all methods for "
1128 << (fAnalysisType == Types::kRegression
1129 ? "Regression"
1130 : (fAnalysisType == Types::kMulticlass ? "Multiclass" : "Classification"))
1131 << " ..." << Endl;
1132
1133 std::map<TString, MVector *>::iterator itrMap;
1134
1135 for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap) {
1136 MVector *methods = itrMap->second;
1137 MVector::iterator itrMethod;
1138
1139 // iterate over methods and train
1140 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1142 MethodBase *mva = dynamic_cast<MethodBase *>(*itrMethod);
1143
1144 if (mva == 0)
1145 continue;
1146
1147 if (mva->DataInfo().GetDataSetManager()->DataInput().GetEntries() <=
1148 1) { // 0 entries --> 0 events, 1 entry --> dynamical dataset (or one entry)
1149 Log() << kFATAL << "No input data for the training provided!" << Endl;
1150 }
1151
1152 if (fAnalysisType == Types::kRegression && mva->DataInfo().GetNTargets() < 1)
1153 Log() << kFATAL << "You want to do regression training without specifying a target." << Endl;
1154 else if ((fAnalysisType == Types::kMulticlass || fAnalysisType == Types::kClassification) &&
1155 mva->DataInfo().GetNClasses() < 2)
1156 Log() << kFATAL << "You want to do classification training, but specified less than two classes." << Endl;
1157
1158 // first print some information about the default dataset
1159 if (!IsSilentFile())
1160 WriteDataInformation(mva->fDataSetInfo);
1161
1163 Log() << kWARNING << "Method " << mva->GetMethodName() << " not trained (training tree has less entries ["
1164 << mva->Data()->GetNTrainingEvents() << "] than required [" << MinNoTrainingEvents << "]" << Endl;
1165 continue;
1166 }
1167
1168 Log() << kHEADER << "Train method: " << mva->GetMethodName() << " for "
1169 << (fAnalysisType == Types::kRegression
1170 ? "Regression"
1171 : (fAnalysisType == Types::kMulticlass ? "Multiclass classification" : "Classification"))
1172 << Endl << Endl;
1173 mva->TrainMethod();
1174 Log() << kHEADER << "Training finished" << Endl << Endl;
1175 }
1176
1177 if (fAnalysisType != Types::kRegression) {
1178
1179 // variable ranking
1180 // Log() << Endl;
1181 Log() << kINFO << "Ranking input variables (method specific)..." << Endl;
1182 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1183 MethodBase *mva = dynamic_cast<MethodBase *>(*itrMethod);
1184 if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
1185
1186 // create and print ranking
1187 const Ranking *ranking = (*itrMethod)->CreateRanking();
1188 if (ranking != 0)
1189 ranking->Print();
1190 else
1191 Log() << kINFO << "No variable ranking supplied by classifier: "
1192 << dynamic_cast<MethodBase *>(*itrMethod)->GetMethodName() << Endl;
1193 }
1194 }
1195 }
1196
1197 // save training history in case we are not in the silent mode
1198 if (!IsSilentFile()) {
1199 for (UInt_t i = 0; i < methods->size(); i++) {
1200 MethodBase *m = dynamic_cast<MethodBase *>((*methods)[i]);
1201 if (m == 0)
1202 continue;
1203 m->BaseDir()->cd();
1204 m->fTrainHistory.SaveHistory(m->GetMethodName());
1205 }
1206 }
1207
1208 // delete all methods and recreate them from weight file - this ensures that the application
1209 // of the methods (in TMVAClassificationApplication) is consistent with the results obtained
1210 // in the testing
1211 // Log() << Endl;
1212 if (fModelPersistence) {
1213
1214 Log() << kHEADER << "=== Destroy and recreate all methods via weight files for testing ===" << Endl << Endl;
1215
1216 if (!IsSilentFile())
1217 RootBaseDir()->cd();
1218
1219 // iterate through all booked methods
1220 for (UInt_t i = 0; i < methods->size(); i++) {
1221
1222 MethodBase *m = dynamic_cast<MethodBase *>((*methods)[i]);
1223 if (m == nullptr)
1224 continue;
1225
1226 TMVA::Types::EMVA methodType = m->GetMethodType();
1227 TString weightfile = m->GetWeightFileName();
1228
1229 // decide if .txt or .xml file should be read:
1230 if (READXML)
1231 weightfile.ReplaceAll(".txt", ".xml");
1232
1233 DataSetInfo &dataSetInfo = m->DataInfo();
1234 TString testvarName = m->GetTestvarName();
1235 delete m; // itrMethod[i];
1236
1237 // recreate
1238 m = dynamic_cast<MethodBase *>(ClassifierFactory::Instance().Create(
1239 Types::Instance().GetMethodName(methodType).Data(), dataSetInfo, weightfile));
1240 if (m->GetMethodType() == Types::kCategory) {
1241 MethodCategory *methCat = (dynamic_cast<MethodCategory *>(m));
1242 if (!methCat)
1243 Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Factory" << Endl;
1244 else
1245 methCat->fDataSetManager = m->DataInfo().GetDataSetManager();
1246 }
1247 // ToDo, Do we need to fill the DataSetManager of MethodBoost here too?
1248
1249 TString wfileDir = m->DataInfo().GetName();
1250 wfileDir += "/" + gConfig().GetIONames().fWeightFileDir;
1251 m->SetWeightFileDir(wfileDir);
1252 m->SetModelPersistence(fModelPersistence);
1253 m->SetSilentFile(IsSilentFile());
1254 m->SetAnalysisType(fAnalysisType);
1255 m->SetupMethod();
1256 m->ReadStateFromFile();
1257 m->SetTestvarName(testvarName);
1258
1259 // replace trained method by newly created one (from weight file) in methods vector
1260 (*methods)[i] = m;
1261 }
1262 }
1263 }
1264}
1265
1266////////////////////////////////////////////////////////////////////////////////
1267/// Evaluates all booked methods on the testing data and adds the output to the
1268/// Results in the corresponding DataSet.
1269///
1270
1272{
1273 Log() << kHEADER << gTools().Color("bold") << "Test all methods" << gTools().Color("reset") << Endl;
1274
1275 // don't do anything if no method booked
1276 if (fMethodsMap.empty()) {
1277 Log() << kINFO << "...nothing found to test" << Endl;
1278 return;
1279 }
1280 std::map<TString, MVector *>::iterator itrMap;
1281
1282 for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap) {
1283 MVector *methods = itrMap->second;
1284 MVector::iterator itrMethod;
1285
1286 // iterate over methods and test
1287 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1289 MethodBase *mva = dynamic_cast<MethodBase *>(*itrMethod);
1290 if (mva == 0)
1291 continue;
1292 Types::EAnalysisType analysisType = mva->GetAnalysisType();
1293 Log() << kHEADER << "Test method: " << mva->GetMethodName() << " for "
1294 << (analysisType == Types::kRegression
1295 ? "Regression"
1296 : (analysisType == Types::kMulticlass ? "Multiclass classification" : "Classification"))
1297 << " performance" << Endl << Endl;
1298 mva->AddOutput(Types::kTesting, analysisType);
1299 }
1300 }
1301}
1302
1303////////////////////////////////////////////////////////////////////////////////
1304
1305void TMVA::Factory::MakeClass(const TString &datasetname, const TString &methodTitle) const
1306{
1307 if (methodTitle != "") {
1308 IMethod *method = GetMethod(datasetname, methodTitle);
1309 if (method)
1310 method->MakeClass();
1311 else {
1312 Log() << kWARNING << "<MakeClass> Could not find classifier \"" << methodTitle << "\" in list" << Endl;
1313 }
1314 } else {
1315
1316 // no classifier specified, print all help messages
1317 MVector *methods = fMethodsMap.find(datasetname)->second;
1318 MVector::const_iterator itrMethod;
1319 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1320 MethodBase *method = dynamic_cast<MethodBase *>(*itrMethod);
1321 if (method == 0)
1322 continue;
1323 Log() << kINFO << "Make response class for classifier: " << method->GetMethodName() << Endl;
1324 method->MakeClass();
1325 }
1326 }
1327}
1328
1329////////////////////////////////////////////////////////////////////////////////
1330/// Print predefined help message of classifier.
1331/// Iterate over methods and test.
1332
1333void TMVA::Factory::PrintHelpMessage(const TString &datasetname, const TString &methodTitle) const
1334{
1335 if (methodTitle != "") {
1336 IMethod *method = GetMethod(datasetname, methodTitle);
1337 if (method)
1338 method->PrintHelpMessage();
1339 else {
1340 Log() << kWARNING << "<PrintHelpMessage> Could not find classifier \"" << methodTitle << "\" in list" << Endl;
1341 }
1342 } else {
1343
1344 // no classifier specified, print all help messages
1345 MVector *methods = fMethodsMap.find(datasetname)->second;
1346 MVector::const_iterator itrMethod;
1347 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1348 MethodBase *method = dynamic_cast<MethodBase *>(*itrMethod);
1349 if (method == 0)
1350 continue;
1351 Log() << kINFO << "Print help message for classifier: " << method->GetMethodName() << Endl;
1352 method->PrintHelpMessage();
1353 }
1354 }
1355}
1356
1357////////////////////////////////////////////////////////////////////////////////
1358/// Iterates over all MVA input variables and evaluates them.
1359
1361{
1362 Log() << kINFO << "Evaluating all variables..." << Endl;
1364
1365 for (UInt_t i = 0; i < loader->GetDataSetInfo().GetNVariables(); i++) {
1367 if (options.Contains("V"))
1368 s += ":V";
1369 this->BookMethod(loader, "Variable", s);
1370 }
1371}
1372
1373////////////////////////////////////////////////////////////////////////////////
1374/// Iterates over all MVAs that have been booked, and calls their evaluation methods.
1375
1377{
1378 Log() << kHEADER << gTools().Color("bold") << "Evaluate all methods" << gTools().Color("reset") << Endl;
1379
1380 // don't do anything if no method booked
1381 if (fMethodsMap.empty()) {
1382 Log() << kINFO << "...nothing found to evaluate" << Endl;
1383 return;
1384 }
1385 std::map<TString, MVector *>::iterator itrMap;
1386
1387 for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap) {
1388 MVector *methods = itrMap->second;
1389
1390 // -----------------------------------------------------------------------
1391 // First part of evaluation process
1392 // --> compute efficiencies, and other separation estimators
1393 // -----------------------------------------------------------------------
1394
1395 // although equal, we now want to separate the output for the variables
1396 // and the real methods
1397 Int_t isel; // will be 0 for a Method; 1 for a Variable
1398 Int_t nmeth_used[2] = {0, 0}; // 0 Method; 1 Variable
1399
1400 std::vector<std::vector<TString>> mname(2);
1401 std::vector<std::vector<Double_t>> sig(2), sep(2), roc(2);
1402 std::vector<std::vector<Double_t>> eff01(2), eff10(2), eff30(2), effArea(2);
1403 std::vector<std::vector<Double_t>> eff01err(2), eff10err(2), eff30err(2);
1404 std::vector<std::vector<Double_t>> trainEff01(2), trainEff10(2), trainEff30(2);
1405
1406 std::vector<std::vector<Float_t>> multiclass_testEff;
1407 std::vector<std::vector<Float_t>> multiclass_trainEff;
1408 std::vector<std::vector<Float_t>> multiclass_testPur;
1409 std::vector<std::vector<Float_t>> multiclass_trainPur;
1410
1411 std::vector<std::vector<Float_t>> train_history;
1412
1413 // Multiclass confusion matrices.
1414 std::vector<TMatrixD> multiclass_trainConfusionEffB01;
1415 std::vector<TMatrixD> multiclass_trainConfusionEffB10;
1416 std::vector<TMatrixD> multiclass_trainConfusionEffB30;
1417 std::vector<TMatrixD> multiclass_testConfusionEffB01;
1418 std::vector<TMatrixD> multiclass_testConfusionEffB10;
1419 std::vector<TMatrixD> multiclass_testConfusionEffB30;
1420
1421 std::vector<std::vector<Double_t>> biastrain(1); // "bias" of the regression on the training data
1422 std::vector<std::vector<Double_t>> biastest(1); // "bias" of the regression on test data
1423 std::vector<std::vector<Double_t>> devtrain(1); // "dev" of the regression on the training data
1424 std::vector<std::vector<Double_t>> devtest(1); // "dev" of the regression on test data
1425 std::vector<std::vector<Double_t>> rmstrain(1); // "rms" of the regression on the training data
1426 std::vector<std::vector<Double_t>> rmstest(1); // "rms" of the regression on test data
1427 std::vector<std::vector<Double_t>> minftrain(1); // "minf" of the regression on the training data
1428 std::vector<std::vector<Double_t>> minftest(1); // "minf" of the regression on test data
1429 std::vector<std::vector<Double_t>> rhotrain(1); // correlation of the regression on the training data
1430 std::vector<std::vector<Double_t>> rhotest(1); // correlation of the regression on test data
1431
1432 // same as above but for 'truncated' quantities (computed for events within 2sigma of RMS)
1433 std::vector<std::vector<Double_t>> biastrainT(1);
1434 std::vector<std::vector<Double_t>> biastestT(1);
1435 std::vector<std::vector<Double_t>> devtrainT(1);
1436 std::vector<std::vector<Double_t>> devtestT(1);
1437 std::vector<std::vector<Double_t>> rmstrainT(1);
1438 std::vector<std::vector<Double_t>> rmstestT(1);
1439 std::vector<std::vector<Double_t>> minftrainT(1);
1440 std::vector<std::vector<Double_t>> minftestT(1);
1441
1442 // following vector contains all methods - with the exception of Cuts, which are special
1443 MVector methodsNoCuts;
1444
1445 Bool_t doRegression = kFALSE;
1446 Bool_t doMulticlass = kFALSE;
1447
1448 // iterate over methods and evaluate
1449 for (MVector::iterator itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1451 MethodBase *theMethod = dynamic_cast<MethodBase *>(*itrMethod);
1452 if (theMethod == 0)
1453 continue;
1454 theMethod->SetFile(fgTargetFile);
1455 theMethod->SetSilentFile(IsSilentFile());
1456 if (theMethod->GetMethodType() != Types::kCuts)
1457 methodsNoCuts.push_back(*itrMethod);
1458
1459 if (theMethod->DoRegression()) {
1460 doRegression = kTRUE;
1461
1462 Log() << kINFO << "Evaluate regression method: " << theMethod->GetMethodName() << Endl;
1463 Double_t bias, dev, rms, mInf;
1464 Double_t biasT, devT, rmsT, mInfT;
1465 Double_t rho;
1466
1467 Log() << kINFO << "TestRegression (testing)" << Endl;
1468 theMethod->TestRegression(bias, biasT, dev, devT, rms, rmsT, mInf, mInfT, rho, TMVA::Types::kTesting);
1469 biastest[0].push_back(bias);
1470 devtest[0].push_back(dev);
1471 rmstest[0].push_back(rms);
1472 minftest[0].push_back(mInf);
1473 rhotest[0].push_back(rho);
1474 biastestT[0].push_back(biasT);
1475 devtestT[0].push_back(devT);
1476 rmstestT[0].push_back(rmsT);
1477 minftestT[0].push_back(mInfT);
1478
1479 Log() << kINFO << "TestRegression (training)" << Endl;
1480 theMethod->TestRegression(bias, biasT, dev, devT, rms, rmsT, mInf, mInfT, rho, TMVA::Types::kTraining);
1481 biastrain[0].push_back(bias);
1482 devtrain[0].push_back(dev);
1483 rmstrain[0].push_back(rms);
1484 minftrain[0].push_back(mInf);
1485 rhotrain[0].push_back(rho);
1486 biastrainT[0].push_back(biasT);
1487 devtrainT[0].push_back(devT);
1488 rmstrainT[0].push_back(rmsT);
1489 minftrainT[0].push_back(mInfT);
1490
1491 mname[0].push_back(theMethod->GetMethodName());
1492 nmeth_used[0]++;
1493 if (!IsSilentFile()) {
1494 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
1497 }
1498 } else if (theMethod->DoMulticlass()) {
1499 // ====================================================================
1500 // === Multiclass evaluation
1501 // ====================================================================
1502 doMulticlass = kTRUE;
1503 Log() << kINFO << "Evaluate multiclass classification method: " << theMethod->GetMethodName() << Endl;
1504
1505 // This part uses a genetic alg. to evaluate the optimal sig eff * sig pur.
1506 // This is why it is disabled for now.
1507 // Find approximate optimal working point w.r.t. signalEfficiency * signalPurity.
1508 // theMethod->TestMulticlass(); // This is where the actual GA calc is done
1509 // multiclass_testEff.push_back(theMethod->GetMulticlassEfficiency(multiclass_testPur));
1510
1511 theMethod->TestMulticlass();
1512
1513 // Confusion matrix at three background efficiency levels
1514 multiclass_trainConfusionEffB01.push_back(theMethod->GetMulticlassConfusionMatrix(0.01, Types::kTraining));
1515 multiclass_trainConfusionEffB10.push_back(theMethod->GetMulticlassConfusionMatrix(0.10, Types::kTraining));
1516 multiclass_trainConfusionEffB30.push_back(theMethod->GetMulticlassConfusionMatrix(0.30, Types::kTraining));
1517
1518 multiclass_testConfusionEffB01.push_back(theMethod->GetMulticlassConfusionMatrix(0.01, Types::kTesting));
1519 multiclass_testConfusionEffB10.push_back(theMethod->GetMulticlassConfusionMatrix(0.10, Types::kTesting));
1520 multiclass_testConfusionEffB30.push_back(theMethod->GetMulticlassConfusionMatrix(0.30, Types::kTesting));
1521
1522 if (!IsSilentFile()) {
1523 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
1526 }
1527
1528 nmeth_used[0]++;
1529 mname[0].push_back(theMethod->GetMethodName());
1530 } else {
1531
1532 Log() << kHEADER << "Evaluate classifier: " << theMethod->GetMethodName() << Endl << Endl;
1533 isel = (theMethod->GetMethodTypeName().Contains("Variable")) ? 1 : 0;
1534
1535 // perform the evaluation
1536 theMethod->TestClassification();
1537
1538 // evaluate the classifier
1539 mname[isel].push_back(theMethod->GetMethodName());
1540 sig[isel].push_back(theMethod->GetSignificance());
1541 sep[isel].push_back(theMethod->GetSeparation());
1542 roc[isel].push_back(theMethod->GetROCIntegral());
1543
1544 Double_t err;
1545 eff01[isel].push_back(theMethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err));
1546 eff01err[isel].push_back(err);
1547 eff10[isel].push_back(theMethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err));
1548 eff10err[isel].push_back(err);
1549 eff30[isel].push_back(theMethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err));
1550 eff30err[isel].push_back(err);
1551 effArea[isel].push_back(theMethod->GetEfficiency("", Types::kTesting, err)); // computes the area (average)
1552
1553 trainEff01[isel].push_back(
1554 theMethod->GetTrainingEfficiency("Efficiency:0.01")); // the first pass takes longer
1555 trainEff10[isel].push_back(theMethod->GetTrainingEfficiency("Efficiency:0.10"));
1556 trainEff30[isel].push_back(theMethod->GetTrainingEfficiency("Efficiency:0.30"));
1557
1558 nmeth_used[isel]++;
1559
1560 if (!IsSilentFile()) {
1561 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
1564 }
1565 }
1566 }
1567 if (doRegression) {
1568
1569 std::vector<TString> vtemps = mname[0];
1570 std::vector<std::vector<Double_t>> vtmp;
1571 vtmp.push_back(devtest[0]); // this is the vector that is ranked
1572 vtmp.push_back(devtrain[0]);
1573 vtmp.push_back(biastest[0]);
1574 vtmp.push_back(biastrain[0]);
1575 vtmp.push_back(rmstest[0]);
1576 vtmp.push_back(rmstrain[0]);
1577 vtmp.push_back(minftest[0]);
1578 vtmp.push_back(minftrain[0]);
1579 vtmp.push_back(rhotest[0]);
1580 vtmp.push_back(rhotrain[0]);
1581 vtmp.push_back(devtestT[0]); // this is the vector that is ranked
1582 vtmp.push_back(devtrainT[0]);
1583 vtmp.push_back(biastestT[0]);
1584 vtmp.push_back(biastrainT[0]);
1585 vtmp.push_back(rmstestT[0]);
1586 vtmp.push_back(rmstrainT[0]);
1587 vtmp.push_back(minftestT[0]);
1588 vtmp.push_back(minftrainT[0]);
1589 gTools().UsefulSortAscending(vtmp, &vtemps);
1590 mname[0] = vtemps;
1591 devtest[0] = vtmp[0];
1592 devtrain[0] = vtmp[1];
1593 biastest[0] = vtmp[2];
1594 biastrain[0] = vtmp[3];
1595 rmstest[0] = vtmp[4];
1596 rmstrain[0] = vtmp[5];
1597 minftest[0] = vtmp[6];
1598 minftrain[0] = vtmp[7];
1599 rhotest[0] = vtmp[8];
1600 rhotrain[0] = vtmp[9];
1601 devtestT[0] = vtmp[10];
1602 devtrainT[0] = vtmp[11];
1603 biastestT[0] = vtmp[12];
1604 biastrainT[0] = vtmp[13];
1605 rmstestT[0] = vtmp[14];
1606 rmstrainT[0] = vtmp[15];
1607 minftestT[0] = vtmp[16];
1608 minftrainT[0] = vtmp[17];
1609 } else if (doMulticlass) {
1610 // TODO: fill in something meaningful
1611 // If there is some ranking of methods to be done it should be done here.
1612 // However, this is not so easy to define for multiclass so it is left out for now.
1613
1614 } else {
1615 // now sort the variables according to the best 'eff at Beff=0.10'
1616 for (Int_t k = 0; k < 2; k++) {
1617 std::vector<std::vector<Double_t>> vtemp;
1618 vtemp.push_back(effArea[k]); // this is the vector that is ranked
1619 vtemp.push_back(eff10[k]);
1620 vtemp.push_back(eff01[k]);
1621 vtemp.push_back(eff30[k]);
1622 vtemp.push_back(eff10err[k]);
1623 vtemp.push_back(eff01err[k]);
1624 vtemp.push_back(eff30err[k]);
1625 vtemp.push_back(trainEff10[k]);
1626 vtemp.push_back(trainEff01[k]);
1627 vtemp.push_back(trainEff30[k]);
1628 vtemp.push_back(sig[k]);
1629 vtemp.push_back(sep[k]);
1630 vtemp.push_back(roc[k]);
1631 std::vector<TString> vtemps = mname[k];
1632 gTools().UsefulSortDescending(vtemp, &vtemps);
1633 effArea[k] = vtemp[0];
1634 eff10[k] = vtemp[1];
1635 eff01[k] = vtemp[2];
1636 eff30[k] = vtemp[3];
1637 eff10err[k] = vtemp[4];
1638 eff01err[k] = vtemp[5];
1639 eff30err[k] = vtemp[6];
1640 trainEff10[k] = vtemp[7];
1641 trainEff01[k] = vtemp[8];
1642 trainEff30[k] = vtemp[9];
1643 sig[k] = vtemp[10];
1644 sep[k] = vtemp[11];
1645 roc[k] = vtemp[12];
1646 mname[k] = vtemps;
1647 }
1648 }
1649
1650 // -----------------------------------------------------------------------
1651 // Second part of evaluation process
1652 // --> compute correlations among MVAs
1653 // --> compute correlations between input variables and MVA (determines importance)
1654 // --> count overlaps
1655 // -----------------------------------------------------------------------
1656 if (fCorrelations) {
1657 const Int_t nmeth = methodsNoCuts.size();
1658 MethodBase *method = dynamic_cast<MethodBase *>(methods[0][0]);
1659 const Int_t nvar = method->fDataSetInfo.GetNVariables();
1660 if (!doRegression && !doMulticlass) {
1661
1662 if (nmeth > 0) {
1663
1664 // needed for correlations
1665 Double_t *dvec = new Double_t[nmeth + nvar];
1666 std::vector<Double_t> rvec;
1667
1668 // for correlations
1669 TPrincipal *tpSig = new TPrincipal(nmeth + nvar, "");
1670 TPrincipal *tpBkg = new TPrincipal(nmeth + nvar, "");
1671
1672 // set required tree branch references
1673 Int_t ivar = 0;
1674 std::vector<TString> *theVars = new std::vector<TString>;
1675 std::vector<ResultsClassification *> mvaRes;
1676 for (MVector::iterator itrMethod = methodsNoCuts.begin(); itrMethod != methodsNoCuts.end();
1677 ++itrMethod, ++ivar) {
1678 MethodBase *m = dynamic_cast<MethodBase *>(*itrMethod);
1679 if (m == 0)
1680 continue;
1681 theVars->push_back(m->GetTestvarName());
1682 rvec.push_back(m->GetSignalReferenceCut());
1683 theVars->back().ReplaceAll("MVA_", "");
1684 mvaRes.push_back(dynamic_cast<ResultsClassification *>(
1685 m->Data()->GetResults(m->GetMethodName(), Types::kTesting, Types::kMaxAnalysisType)));
1686 }
1687
1688 // for overlap study
1689 TMatrixD *overlapS = new TMatrixD(nmeth, nmeth);
1690 TMatrixD *overlapB = new TMatrixD(nmeth, nmeth);
1691 (*overlapS) *= 0; // init...
1692 (*overlapB) *= 0; // init...
1693
1694 // loop over test tree
1695 DataSet *defDs = method->fDataSetInfo.GetDataSet();
1697 for (Int_t ievt = 0; ievt < defDs->GetNEvents(); ievt++) {
1698 const Event *ev = defDs->GetEvent(ievt);
1699
1700 // for correlations
1701 TMatrixD *theMat = 0;
1702 for (Int_t im = 0; im < nmeth; im++) {
1703 // check for NaN value
1704 Double_t retval = (Double_t)(*mvaRes[im])[ievt][0];
1705 if (TMath::IsNaN(retval)) {
1706 Log() << kWARNING << "Found NaN return value in event: " << ievt << " for method \""
1707 << methodsNoCuts[im]->GetName() << "\"" << Endl;
1708 dvec[im] = 0;
1709 } else
1710 dvec[im] = retval;
1711 }
1712 for (Int_t iv = 0; iv < nvar; iv++)
1713 dvec[iv + nmeth] = (Double_t)ev->GetValue(iv);
1714 if (method->fDataSetInfo.IsSignal(ev)) {
1715 tpSig->AddRow(dvec);
1716 theMat = overlapS;
1717 } else {
1718 tpBkg->AddRow(dvec);
1719 theMat = overlapB;
1720 }
1721
1722 // count overlaps
1723 for (Int_t im = 0; im < nmeth; im++) {
1724 for (Int_t jm = im; jm < nmeth; jm++) {
1725 if ((dvec[im] - rvec[im]) * (dvec[jm] - rvec[jm]) > 0) {
1726 (*theMat)(im, jm)++;
1727 if (im != jm)
1728 (*theMat)(jm, im)++;
1729 }
1730 }
1731 }
1732 }
1733
1734 // renormalise overlap matrix
1735 (*overlapS) *= (1.0 / defDs->GetNEvtSigTest()); // init...
1736 (*overlapB) *= (1.0 / defDs->GetNEvtBkgdTest()); // init...
1737
1738 tpSig->MakePrincipals();
1739 tpBkg->MakePrincipals();
1740
1741 const TMatrixD *covMatS = tpSig->GetCovarianceMatrix();
1742 const TMatrixD *covMatB = tpBkg->GetCovarianceMatrix();
1743
1744 const TMatrixD *corrMatS = gTools().GetCorrelationMatrix(covMatS);
1745 const TMatrixD *corrMatB = gTools().GetCorrelationMatrix(covMatB);
1746
1747 // print correlation matrices
1748 if (corrMatS != 0 && corrMatB != 0) {
1749
1750 // extract MVA matrix
1751 TMatrixD mvaMatS(nmeth, nmeth);
1752 TMatrixD mvaMatB(nmeth, nmeth);
1753 for (Int_t im = 0; im < nmeth; im++) {
1754 for (Int_t jm = 0; jm < nmeth; jm++) {
1755 mvaMatS(im, jm) = (*corrMatS)(im, jm);
1756 mvaMatB(im, jm) = (*corrMatB)(im, jm);
1757 }
1758 }
1759
1760 // extract variables - to MVA matrix
1761 std::vector<TString> theInputVars;
1762 TMatrixD varmvaMatS(nvar, nmeth);
1763 TMatrixD varmvaMatB(nvar, nmeth);
1764 for (Int_t iv = 0; iv < nvar; iv++) {
1765 theInputVars.push_back(method->fDataSetInfo.GetVariableInfo(iv).GetLabel());
1766 for (Int_t jm = 0; jm < nmeth; jm++) {
1767 varmvaMatS(iv, jm) = (*corrMatS)(nmeth + iv, jm);
1768 varmvaMatB(iv, jm) = (*corrMatB)(nmeth + iv, jm);
1769 }
1770 }
1771
1772 if (nmeth > 1) {
1773 Log() << kINFO << Endl;
1774 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1775 << "Inter-MVA correlation matrix (signal):" << Endl;
1776 gTools().FormattedOutput(mvaMatS, *theVars, Log());
1777 Log() << kINFO << Endl;
1778
1779 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1780 << "Inter-MVA correlation matrix (background):" << Endl;
1781 gTools().FormattedOutput(mvaMatB, *theVars, Log());
1782 Log() << kINFO << Endl;
1783 }
1784
1785 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1786 << "Correlations between input variables and MVA response (signal):" << Endl;
1787 gTools().FormattedOutput(varmvaMatS, theInputVars, *theVars, Log());
1788 Log() << kINFO << Endl;
1789
1790 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1791 << "Correlations between input variables and MVA response (background):" << Endl;
1792 gTools().FormattedOutput(varmvaMatB, theInputVars, *theVars, Log());
1793 Log() << kINFO << Endl;
1794 } else
1795 Log() << kWARNING << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1796 << "<TestAllMethods> cannot compute correlation matrices" << Endl;
1797
1798 // print overlap matrices
1799 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1800 << "The following \"overlap\" matrices contain the fraction of events for which " << Endl;
1801 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1802 << "the MVAs 'i' and 'j' have returned conform answers about \"signal-likeness\"" << Endl;
1803 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1804 << "An event is signal-like, if its MVA output exceeds the following value:" << Endl;
1805 gTools().FormattedOutput(rvec, *theVars, "Method", "Cut value", Log());
1806 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1807 << "which correspond to the working point: eff(signal) = 1 - eff(background)" << Endl;
1808
1809 // give notice that cut method has been excluded from this test
1810 if (nmeth != (Int_t)methods->size())
1811 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1812 << "Note: no correlations and overlap with cut method are provided at present" << Endl;
1813
1814 if (nmeth > 1) {
1815 Log() << kINFO << Endl;
1816 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1817 << "Inter-MVA overlap matrix (signal):" << Endl;
1818 gTools().FormattedOutput(*overlapS, *theVars, Log());
1819 Log() << kINFO << Endl;
1820
1821 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
1822 << "Inter-MVA overlap matrix (background):" << Endl;
1823 gTools().FormattedOutput(*overlapB, *theVars, Log());
1824 }
1825
1826 // cleanup
1827 delete tpSig;
1828 delete tpBkg;
1829 delete corrMatS;
1830 delete corrMatB;
1831 delete theVars;
1832 delete overlapS;
1833 delete overlapB;
1834 delete[] dvec;
1835 }
1836 }
1837 }
1838 // -----------------------------------------------------------------------
1839 // Third part of evaluation process
1840 // --> output
1841 // -----------------------------------------------------------------------
1842
1843 if (doRegression) {
1844
1845 Log() << kINFO << Endl;
1846 TString hLine =
1847 "--------------------------------------------------------------------------------------------------";
1848 Log() << kINFO << "Evaluation results ranked by smallest RMS on test sample:" << Endl;
1849 Log() << kINFO << "(\"Bias\" quotes the mean deviation of the regression from true target." << Endl;
1850 Log() << kINFO << " \"MutInf\" is the \"Mutual Information\" between regression and target." << Endl;
1851 Log() << kINFO << " Indicated by \"_T\" are the corresponding \"truncated\" quantities ob-" << Endl;
1852 Log() << kINFO << " tained when removing events deviating more than 2sigma from average.)" << Endl;
1853 Log() << kINFO << hLine << Endl;
1854 // Log() << kINFO << "DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf
1855 // MutInf_T" << Endl;
1856 Log() << kINFO << hLine << Endl;
1857
1858 for (Int_t i = 0; i < nmeth_used[0]; i++) {
1859 MethodBase *theMethod = dynamic_cast<MethodBase *>((*methods)[i]);
1860 if (theMethod == 0)
1861 continue;
1862
1863 Log() << kINFO
1864 << Form("%-20s %-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f", theMethod->fDataSetInfo.GetName(),
1865 (const char *)mname[0][i], biastest[0][i], biastestT[0][i], rmstest[0][i], rmstestT[0][i],
1866 minftest[0][i], minftestT[0][i])
1867 << Endl;
1868 }
1869 Log() << kINFO << hLine << Endl;
1870 Log() << kINFO << Endl;
1871 Log() << kINFO << "Evaluation results ranked by smallest RMS on training sample:" << Endl;
1872 Log() << kINFO << "(overtraining check)" << Endl;
1873 Log() << kINFO << hLine << Endl;
1874 Log() << kINFO
1875 << "DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T"
1876 << Endl;
1877 Log() << kINFO << hLine << Endl;
1878
1879 for (Int_t i = 0; i < nmeth_used[0]; i++) {
1880 MethodBase *theMethod = dynamic_cast<MethodBase *>((*methods)[i]);
1881 if (theMethod == 0)
1882 continue;
1883 Log() << kINFO
1884 << Form("%-20s %-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f", theMethod->fDataSetInfo.GetName(),
1885 (const char *)mname[0][i], biastrain[0][i], biastrainT[0][i], rmstrain[0][i], rmstrainT[0][i],
1886 minftrain[0][i], minftrainT[0][i])
1887 << Endl;
1888 }
1889 Log() << kINFO << hLine << Endl;
1890 Log() << kINFO << Endl;
1891 } else if (doMulticlass) {
1892 // ====================================================================
1893 // === Multiclass Output
1894 // ====================================================================
1895
1896 TString hLine =
1897 "-------------------------------------------------------------------------------------------------------";
1898
1899 // This part uses a genetic alg. to evaluate the optimal sig eff * sig pur.
1900 // This is why it is disabled for now.
1901 //
1902 // // --- Acheivable signal efficiency * signal purity
1903 // // --------------------------------------------------------------------
1904 // Log() << kINFO << Endl;
1905 // Log() << kINFO << "Evaluation results ranked by best signal efficiency times signal purity " << Endl;
1906 // Log() << kINFO << hLine << Endl;
1907
1908 // // iterate over methods and evaluate
1909 // for (MVector::iterator itrMethod = methods->begin(); itrMethod != methods->end(); itrMethod++) {
1910 // MethodBase *theMethod = dynamic_cast<MethodBase *>(*itrMethod);
1911 // if (theMethod == 0) {
1912 // continue;
1913 // }
1914
1915 // TString header = "DataSet Name MVA Method ";
1916 // for (UInt_t icls = 0; icls < theMethod->fDataSetInfo.GetNClasses(); ++icls) {
1917 // header += Form("%-12s ", theMethod->fDataSetInfo.GetClassInfo(icls)->GetName());
1918 // }
1919
1920 // Log() << kINFO << header << Endl;
1921 // Log() << kINFO << hLine << Endl;
1922 // for (Int_t i = 0; i < nmeth_used[0]; i++) {
1923 // TString res = Form("[%-14s] %-15s", theMethod->fDataSetInfo.GetName(), (const char *)mname[0][i]);
1924 // for (UInt_t icls = 0; icls < theMethod->fDataSetInfo.GetNClasses(); ++icls) {
1925 // res += Form("%#1.3f ", (multiclass_testEff[i][icls]) * (multiclass_testPur[i][icls]));
1926 // }
1927 // Log() << kINFO << res << Endl;
1928 // }
1929
1930 // Log() << kINFO << hLine << Endl;
1931 // Log() << kINFO << Endl;
1932 // }
1933
1934 // --- 1 vs Rest ROC AUC, signal efficiency @ given background efficiency
1935 // --------------------------------------------------------------------
1936 TString header1 = Form("%-15s%-15s%-15s%-15s%-15s%-15s", "Dataset", "MVA Method", "ROC AUC", "Sig eff@B=0.01",
1937 "Sig eff@B=0.10", "Sig eff@B=0.30");
1938 TString header2 = Form("%-15s%-15s%-15s%-15s%-15s%-15s", "Name:", "/ Class:", "test (train)", "test (train)",
1939 "test (train)", "test (train)");
1940 Log() << kINFO << Endl;
1941 Log() << kINFO << "1-vs-rest performance metrics per class" << Endl;
1942 Log() << kINFO << hLine << Endl;
1943 Log() << kINFO << Endl;
1944 Log() << kINFO << "Considers the listed class as signal and the other classes" << Endl;
1945 Log() << kINFO << "as background, reporting the resulting binary performance." << Endl;
1946 Log() << kINFO << "A score of 0.820 (0.850) means 0.820 was acheived on the" << Endl;
1947 Log() << kINFO << "test set and 0.850 on the training set." << Endl;
1948
1949 Log() << kINFO << Endl;
1950 Log() << kINFO << header1 << Endl;
1951 Log() << kINFO << header2 << Endl;
1952 for (Int_t k = 0; k < 2; k++) {
1953 for (Int_t i = 0; i < nmeth_used[k]; i++) {
1954 if (k == 1) {
1955 mname[k][i].ReplaceAll("Variable_", "");
1956 }
1957
1958 const TString datasetName = itrMap->first;
1959 const TString mvaName = mname[k][i];
1960
1961 MethodBase *theMethod = dynamic_cast<MethodBase *>(GetMethod(datasetName, mvaName));
1962 if (theMethod == 0) {
1963 continue;
1964 }
1965
1966 Log() << kINFO << Endl;
1967 TString row = Form("%-15s%-15s", datasetName.Data(), mvaName.Data());
1968 Log() << kINFO << row << Endl;
1969 Log() << kINFO << "------------------------------" << Endl;
1970
1971 UInt_t numClasses = theMethod->fDataSetInfo.GetNClasses();
1972 for (UInt_t iClass = 0; iClass < numClasses; ++iClass) {
1973
1974 ROCCurve *rocCurveTrain = GetROC(datasetName, mvaName, iClass, Types::kTraining);
1975 ROCCurve *rocCurveTest = GetROC(datasetName, mvaName, iClass, Types::kTesting);
1976
1977 const TString className = theMethod->DataInfo().GetClassInfo(iClass)->GetName();
1978 const Double_t rocaucTrain = rocCurveTrain->GetROCIntegral();
1979 const Double_t effB01Train = rocCurveTrain->GetEffSForEffB(0.01);
1980 const Double_t effB10Train = rocCurveTrain->GetEffSForEffB(0.10);
1981 const Double_t effB30Train = rocCurveTrain->GetEffSForEffB(0.30);
1982 const Double_t rocaucTest = rocCurveTest->GetROCIntegral();
1983 const Double_t effB01Test = rocCurveTest->GetEffSForEffB(0.01);
1984 const Double_t effB10Test = rocCurveTest->GetEffSForEffB(0.10);
1985 const Double_t effB30Test = rocCurveTest->GetEffSForEffB(0.30);
1986 const TString rocaucCmp = Form("%5.3f (%5.3f)", rocaucTest, rocaucTrain);
1987 const TString effB01Cmp = Form("%5.3f (%5.3f)", effB01Test, effB01Train);
1988 const TString effB10Cmp = Form("%5.3f (%5.3f)", effB10Test, effB10Train);
1989 const TString effB30Cmp = Form("%5.3f (%5.3f)", effB30Test, effB30Train);
1990 row = Form("%-15s%-15s%-15s%-15s%-15s%-15s", "", className.Data(), rocaucCmp.Data(), effB01Cmp.Data(),
1991 effB10Cmp.Data(), effB30Cmp.Data());
1992 Log() << kINFO << row << Endl;
1993
1994 delete rocCurveTrain;
1995 delete rocCurveTest;
1996 }
1997 }
1998 }
1999 Log() << kINFO << Endl;
2000 Log() << kINFO << hLine << Endl;
2001 Log() << kINFO << Endl;
2002
2003 // --- Confusion matrices
2004 // --------------------------------------------------------------------
2005 auto printMatrix = [](TMatrixD const &matTraining, TMatrixD const &matTesting, std::vector<TString> classnames,
2006 UInt_t numClasses, MsgLogger &stream) {
2007 // assert (classLabledWidth >= valueLabelWidth + 2)
2008 // if (...) {Log() << kWARN << "..." << Endl; }
2009
2010 // TODO: Ensure matrices are same size.
2011
2012 TString header = Form(" %-14s", " ");
2013 TString headerInfo = Form(" %-14s", " ");
2014 ;
2015 for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
2016 header += Form(" %-14s", classnames[iCol].Data());
2017 headerInfo += Form(" %-14s", " test (train)");
2018 }
2019 stream << kINFO << header << Endl;
2020 stream << kINFO << headerInfo << Endl;
2021
2022 for (UInt_t iRow = 0; iRow < numClasses; ++iRow) {
2023 stream << kINFO << Form(" %-14s", classnames[iRow].Data());
2024
2025 for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
2026 if (iCol == iRow) {
2027 stream << kINFO << Form(" %-14s", "-");
2028 } else {
2029 Double_t trainValue = matTraining[iRow][iCol];
2030 Double_t testValue = matTesting[iRow][iCol];
2031 TString entry = Form("%-5.3f (%-5.3f)", testValue, trainValue);
2032 stream << kINFO << Form(" %-14s", entry.Data());
2033 }
2034 }
2035 stream << kINFO << Endl;
2036 }
2037 };
2038
2039 Log() << kINFO << Endl;
2040 Log() << kINFO << "Confusion matrices for all methods" << Endl;
2041 Log() << kINFO << hLine << Endl;
2042 Log() << kINFO << Endl;
2043 Log() << kINFO << "Does a binary comparison between the two classes given by a " << Endl;
2044 Log() << kINFO << "particular row-column combination. In each case, the class " << Endl;
2045 Log() << kINFO << "given by the row is considered signal while the class given " << Endl;
2046 Log() << kINFO << "by the column index is considered background." << Endl;
2047 Log() << kINFO << Endl;
2048 for (UInt_t iMethod = 0; iMethod < methods->size(); ++iMethod) {
2049 MethodBase *theMethod = dynamic_cast<MethodBase *>(methods->at(iMethod));
2050 if (theMethod == nullptr) {
2051 continue;
2052 }
2053 UInt_t numClasses = theMethod->fDataSetInfo.GetNClasses();
2054
2055 std::vector<TString> classnames;
2056 for (UInt_t iCls = 0; iCls < numClasses; ++iCls) {
2057 classnames.push_back(theMethod->fDataSetInfo.GetClassInfo(iCls)->GetName());
2058 }
2059 Log() << kINFO
2060 << "=== Showing confusion matrix for method : " << Form("%-15s", (const char *)mname[0][iMethod])
2061 << Endl;
2062 Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.01%)" << Endl;
2063 Log() << kINFO << "---------------------------------------------------" << Endl;
2064 printMatrix(multiclass_testConfusionEffB01[iMethod], multiclass_trainConfusionEffB01[iMethod], classnames,
2065 numClasses, Log());
2066 Log() << kINFO << Endl;
2067
2068 Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.10%)" << Endl;
2069 Log() << kINFO << "---------------------------------------------------" << Endl;
2070 printMatrix(multiclass_testConfusionEffB10[iMethod], multiclass_trainConfusionEffB10[iMethod], classnames,
2071 numClasses, Log());
2072 Log() << kINFO << Endl;
2073
2074 Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.30%)" << Endl;
2075 Log() << kINFO << "---------------------------------------------------" << Endl;
2076 printMatrix(multiclass_testConfusionEffB30[iMethod], multiclass_trainConfusionEffB30[iMethod], classnames,
2077 numClasses, Log());
2078 Log() << kINFO << Endl;
2079 }
2080 Log() << kINFO << hLine << Endl;
2081 Log() << kINFO << Endl;
2082
2083 } else {
2084 // Binary classification
2085 if (fROC) {
2086 Log().EnableOutput();
2088 Log() << Endl;
2089 TString hLine = "------------------------------------------------------------------------------------------"
2090 "-------------------------";
2091 Log() << kINFO << "Evaluation results ranked by best signal efficiency and purity (area)" << Endl;
2092 Log() << kINFO << hLine << Endl;
2093 Log() << kINFO << "DataSet MVA " << Endl;
2094 Log() << kINFO << "Name: Method: ROC-integ" << Endl;
2095
2096 // Log() << kDEBUG << "DataSet MVA Signal efficiency at bkg eff.(error):
2097 // | Sepa- Signifi- " << Endl; Log() << kDEBUG << "Name: Method: @B=0.01
2098 // @B=0.10 @B=0.30 ROC-integ ROCCurve| ration: cance: " << Endl;
2099 Log() << kDEBUG << hLine << Endl;
2100 for (Int_t k = 0; k < 2; k++) {
2101 if (k == 1 && nmeth_used[k] > 0) {
2102 Log() << kINFO << hLine << Endl;
2103 Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
2104 }
2105 for (Int_t i = 0; i < nmeth_used[k]; i++) {
2106 TString datasetName = itrMap->first;
2107 TString methodName = mname[k][i];
2108
2109 if (k == 1) {
2110 methodName.ReplaceAll("Variable_", "");
2111 }
2112
2113 MethodBase *theMethod = dynamic_cast<MethodBase *>(GetMethod(datasetName, methodName));
2114 if (theMethod == 0) {
2115 continue;
2116 }
2117
2118 TMVA::DataSet *dataset = theMethod->Data();
2119 TMVA::Results *results = dataset->GetResults(methodName, Types::kTesting, this->fAnalysisType);
2120 std::vector<Bool_t> *mvaResType =
2121 dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
2122
2123 Double_t rocIntegral = 0.0;
2124 if (mvaResType->size() != 0) {
2125 rocIntegral = GetROCIntegral(datasetName, methodName);
2126 }
2127
2128 if (sep[k][i] < 0 || sig[k][i] < 0) {
2129 // cannot compute separation/significance -> no MVA (usually for Cuts)
2130 Log() << kINFO << Form("%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), effArea[k][i])
2131 << Endl;
2132
2133 // Log() << kDEBUG << Form("%-20s %-15s: %#1.3f(%02i) %#1.3f(%02i) %#1.3f(%02i)
2134 // %#1.3f %#1.3f | -- --",
2135 // datasetName.Data(),
2136 // methodName.Data(),
2137 // eff01[k][i], Int_t(1000*eff01err[k][i]),
2138 // eff10[k][i], Int_t(1000*eff10err[k][i]),
2139 // eff30[k][i], Int_t(1000*eff30err[k][i]),
2140 // effArea[k][i],rocIntegral) << Endl;
2141 } else {
2142 Log() << kINFO << Form("%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), rocIntegral)
2143 << Endl;
2144 // Log() << kDEBUG << Form("%-20s %-15s: %#1.3f(%02i) %#1.3f(%02i) %#1.3f(%02i)
2145 // %#1.3f %#1.3f | %#1.3f %#1.3f",
2146 // datasetName.Data(),
2147 // methodName.Data(),
2148 // eff01[k][i], Int_t(1000*eff01err[k][i]),
2149 // eff10[k][i], Int_t(1000*eff10err[k][i]),
2150 // eff30[k][i], Int_t(1000*eff30err[k][i]),
2151 // effArea[k][i],rocIntegral,
2152 // sep[k][i], sig[k][i]) << Endl;
2153 }
2154 }
2155 }
2156 Log() << kINFO << hLine << Endl;
2157 Log() << kINFO << Endl;
2158 Log() << kINFO << "Testing efficiency compared to training efficiency (overtraining check)" << Endl;
2159 Log() << kINFO << hLine << Endl;
2160 Log() << kINFO
2161 << "DataSet MVA Signal efficiency: from test sample (from training sample) "
2162 << Endl;
2163 Log() << kINFO << "Name: Method: @B=0.01 @B=0.10 @B=0.30 "
2164 << Endl;
2165 Log() << kINFO << hLine << Endl;
2166 for (Int_t k = 0; k < 2; k++) {
2167 if (k == 1 && nmeth_used[k] > 0) {
2168 Log() << kINFO << hLine << Endl;
2169 Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
2170 }
2171 for (Int_t i = 0; i < nmeth_used[k]; i++) {
2172 if (k == 1)
2173 mname[k][i].ReplaceAll("Variable_", "");
2174 MethodBase *theMethod = dynamic_cast<MethodBase *>((*methods)[i]);
2175 if (theMethod == 0)
2176 continue;
2177
2178 Log() << kINFO
2179 << Form("%-20s %-15s: %#1.3f (%#1.3f) %#1.3f (%#1.3f) %#1.3f (%#1.3f)",
2180 theMethod->fDataSetInfo.GetName(), (const char *)mname[k][i], eff01[k][i],
2181 trainEff01[k][i], eff10[k][i], trainEff10[k][i], eff30[k][i], trainEff30[k][i])
2182 << Endl;
2183 }
2184 }
2185 Log() << kINFO << hLine << Endl;
2186 Log() << kINFO << Endl;
2187
2188 if (gTools().CheckForSilentOption(GetOptions()))
2189 Log().InhibitOutput();
2190 } // end fROC
2191 }
2192 if (!IsSilentFile()) {
2193 std::list<TString> datasets;
2194 for (Int_t k = 0; k < 2; k++) {
2195 for (Int_t i = 0; i < nmeth_used[k]; i++) {
2196 MethodBase *theMethod = dynamic_cast<MethodBase *>((*methods)[i]);
2197 if (theMethod == 0)
2198 continue;
2199 // write test/training trees
2200 RootBaseDir()->cd(theMethod->fDataSetInfo.GetName());
2201 if (std::find(datasets.begin(), datasets.end(), theMethod->fDataSetInfo.GetName()) == datasets.end()) {
2204 datasets.push_back(theMethod->fDataSetInfo.GetName());
2205 }
2206 }
2207 }
2208 }
2209 } // end for MethodsMap
2210 // references for citation
2212}
2213
2214////////////////////////////////////////////////////////////////////////////////
2215/// Evaluate Variable Importance
2216
2218 const char *theOption)
2219{
// Dispatcher for the variable-importance study: sets fModelPersistence to kFALSE and
// fSilentFile to kTRUE on this factory, then forwards to the Short, All or Random
// implementation selected by `vitype` and returns that implementation's result.
// Returns nullptr when no implementation is run.
// NOTE(review): the first line of the signature is not visible in this excerpt; the
// parameter names used below (loader, vitype, theMethod, methodTitle) come from it.
2220 fModelPersistence = kFALSE;
2221 fSilentFile = kTRUE; // we need silent file here because we need fast classification results
2222
2223 // getting number of variables and variable names from loader
2224 const int nbits = loader->GetDataSetInfo().GetNVariables();
2225 if (vitype == VIType::kShort)
2226 return EvaluateImportanceShort(loader, theMethod, methodTitle, theOption);
2227 else if (vitype == VIType::kAll)
2228 return EvaluateImportanceAll(loader, theMethod, methodTitle, theOption);
2229 else if (vitype == VIType::kRandom) {
// Random mode only runs for 10 < nbits < 30 and is given 2^nbits as its second argument.
// NOTE(review): nbits == 10 and nbits == 30 satisfy none of the three conditions below,
// so those two cases fall through to `return nullptr` without any log message, although
// the error texts suggest the bounds were meant to be exhaustive - confirm intended bounds.
2230 if ( nbits > 10 && nbits < 30) {
2231 // limit nbits to less than 30 to avoid error converting from double to uint and also cannot deal with too many combinations
2232 return EvaluateImportanceRandom(loader, static_cast<UInt_t>( pow(2, nbits) ), theMethod, methodTitle, theOption);
2233 } else if (nbits < 10) {
2234 Log() << kERROR << "Error in Variable Importance: Random mode require more that 10 variables in the dataset."
2235 << Endl;
2236 } else if (nbits > 30) {
2237 Log() << kERROR << "Error in Variable Importance: Number of variables is too large for Random mode"
2238 << Endl;
2239 }
2240 }
// Reached for rejected Random-mode sizes and for any vitype not tested above.
2241 return nullptr;
2242}
2243
2244////////////////////////////////////////////////////////////////////////////////
2245
2247 const char *theOption)
2248{
// Exhaustive variable-importance scan. Every non-empty subset of the loader's variables
// (encoded as the bits of x, bit `index` selecting varNames[index]) is booked, trained,
// tested and evaluated once, and its ROC integral is stored in ROC[x]. Afterwards the
// importance of each variable is accumulated from ROC differences between a subset and the
// same subset with that variable removed. Returns the histogram built by GetImportance,
// or nullptr when there are more than 60 variables.
// NOTE(review): the first line of the signature is not visible in this excerpt.
2249
2250 uint64_t x = 0;
2251 uint64_t y = 0;
2252
2253 // getting number of variables and variable names from loader
2254 const int nbits = loader->GetDataSetInfo().GetNVariables();
2255 std::vector<TString> varNames = loader->GetDataSetInfo().GetListOfVariables();
2256
2257 if (nbits > 60) {
2258 Log() << kERROR << "Number of combinations is too large , is 2^" << nbits << Endl;
2259 return nullptr;
2260 }
2261 if (nbits > 20) {
2262 Log() << kWARNING << "Number of combinations is very large , is 2^" << nbits << Endl;
2263 }
// Number of subsets, including the empty one: 2^nbits.
2264 uint64_t range = static_cast<uint64_t>(pow(2, nbits));
2265
2266
2267 // vector to save importances
2268 std::vector<Double_t> importances(nbits);
2269 // vector to save ROC
// NOTE(review): one Double_t per subset, i.e. 2^nbits entries; the nbits <= 60 guard above
// still admits sizes that cannot be allocated - confirm the practical limit.
2270 std::vector<Double_t> ROC(range);
// Entry 0 stands for the empty subset, which is never trained; it is given 0.5.
2271 ROC[0] = 0.5;
2272 for (int i = 0; i < nbits; i++)
2273 importances[i] = 0;
2274
2275 Double_t SROC, SSROC; // computed ROC value
// Pass 1: train and evaluate every non-empty subset.
2276 for (x = 1; x < range; x++) {
2277
2278 std::bitset<VIBITS> xbitset(x);
// The loop starts at x = 1, so this check never fires here.
2279 if (x == 0)
2280 continue; // data loader need at least one variable
2281
2282 // creating loader for seed
// The bit string of x is used as the DataLoader name and, below, as the key for
// GetROCIntegral and fMethodsMap.
2283 TMVA::DataLoader *seedloader = new TMVA::DataLoader(xbitset.to_string());
2284
2285 // adding variables from seed
2286 for (int index = 0; index < nbits; index++) {
2287 if (xbitset[index])
2288 seedloader->AddVariable(varNames[index], 'F');
2289 }
2290
2291 DataLoaderCopy(seedloader, loader);
2292 seedloader->PrepareTrainingAndTestTree(loader->GetDataSetInfo().GetCut("Signal"),
2293 loader->GetDataSetInfo().GetCut("Background"),
2294 loader->GetDataSetInfo().GetSplitOptions());
2295
2296 // Booking Seed
2297 BookMethod(seedloader, theMethod, methodTitle, theOption);
2298
2299 // Train/Test/Evaluation
2300 TrainAllMethods();
2301 TestAllMethods();
2302 EvaluateAllMethods();
2303
2304 // getting ROC
2305 ROC[x] = GetROCIntegral(xbitset.to_string(), methodTitle);
2306
2307 // cleaning information to process sub-seeds
2308 TMVA::MethodBase *smethod = dynamic_cast<TMVA::MethodBase *>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
// NOTE(review): the declaration of `sresults` is not visible in this excerpt (the listing
// jumps from 2308 to 2311); presumably it is obtained from `smethod` - confirm.
2311 delete sresults;
2312 delete seedloader;
2313 this->DeleteAllMethods();
2314
2315 fMethodsMap.clear();
2316 // removing global result because it is requiring a lot of RAM for all seeds
2317 }
2318
// Pass 2: for every subset x and every variable i contained in it, compare ROC[x] with
// ROC[y], where y is x with bit i cleared.
2319 for (x = 0; x < range; x++) {
2320 SROC = ROC[x];
2321 for (uint32_t i = 0; i < VIBITS; ++i) {
2322 if (x & (uint64_t(1) << i)) {
// NOTE(review): `1` is an int here, unlike the uint64_t(1) used in the test above; for
// i at or beyond the width of int the shift is undefined, and the int mask is widened
// before the AND - confirm behaviour for subsets using bits >= 31.
2323 y = x & ~(1 << i);
2324 std::bitset<VIBITS> ybitset(y);
2325 // need at least one variable
2326 // NOTE: if sub-seed is zero then is the special case
2327 // that count in xbitset is 1
// x - y is the single removed bit; dividing its natural log by 0.693147 (approximately
// ln 2) and truncating yields the index of that bit, i.e. of the removed variable.
2328 uint32_t ny = static_cast<uint32_t>( log(x - y) / 0.693147 ) ;
2329 if (y == 0) {
// Single-variable subset: the importance is assigned (not accumulated) relative to 0.5.
2330 importances[ny] = SROC - 0.5;
2331 continue;
2332 }
2333
2334 // getting ROC
2335 SSROC = ROC[y];
2336 importances[ny] += SROC - SSROC;
2337 // cleaning information
2338 }
2339 }
2340 }
2341 std::cout << "--- Variable Importance Results (All)" << std::endl;
2342 return GetImportance(nbits, importances, varNames);
2343}
2344
/// Return a mask with the `i` lowest bits set, i.e. sum_{n=0}^{i-1} 2^n = 2^i - 1.
///
/// EvaluateImportanceShort uses the result as the seed in which every one of the
/// `i` input variables is selected, and indexes its importance vector (size `i`)
/// by the position of each set bit, so the mask must not contain bit `i` or above.
/// The earlier pow-based expression returned 2^(i+1) - 1, which set that extra bit.
///
/// \param[in] i number of low bits to set
/// \return 2^i - 1 for i <= 62, and 0 for i > 62
static uint64_t sum(uint64_t i)
{
   // Overflow limit kept unchanged: the caller treats a zero seed as "no variables".
   if (i > 62)
      return 0;
   // Integer shift is exact for every admitted i; no double round trip is needed.
   return (uint64_t{1} << i) - 1;
}
2355
2356////////////////////////////////////////////////////////////////////////////////
2357
2359 const char *theOption)
2360{
// Short variable-importance scan. The method is trained once on the seed x = sum(nbits);
// then, for each set bit i of x, it is retrained with that bit cleared and the ROC
// difference SROC - SSROC is added to the importance of the removed variable.
// Returns the histogram built by GetImportance, or nullptr for more than 60 variables.
// NOTE(review): the first line of the signature is not visible in this excerpt.
2361 uint64_t x = 0;
2362 uint64_t y = 0;
2363
2364 // getting number of variables and variable names from loader
2365 const int nbits = loader->GetDataSetInfo().GetNVariables();
2366 std::vector<TString> varNames = loader->GetDataSetInfo().GetListOfVariables();
2367
2368 if (nbits > 60) {
2369 Log() << kERROR << "Number of combinations is too large , is 2^" << nbits << Endl;
2370 return nullptr;
2371 }
// Seed mask for the full variable set.
// NOTE(review): the loop at the end of this function assumes the seed has no bit set at
// position >= nbits; such a bit would give ny >= nbits and index past the end of
// `importances` (size nbits). Confirm that sum(nbits) sets exactly the nbits low bits.
2372 long int range = sum(nbits);
2373 // std::cout<<range<<std::endl;
2374 // vector to save importances
2375 std::vector<Double_t> importances(nbits);
2376 for (int i = 0; i < nbits; i++)
2377 importances[i] = 0;
2378
2379 Double_t SROC, SSROC; // computed ROC value
2380
2381 x = range;
2382
2383 std::bitset<VIBITS> xbitset(x);
// A zero seed means no variable is selected; this is reported at kFATAL level.
2384 if (x == 0)
2385 Log() << kFATAL << "Error: need at least one variable."; // data loader need at least one variable
2386
2387 // creating loader for seed
// The bit string of x is used as the DataLoader name and as the lookup key below.
2388 TMVA::DataLoader *seedloader = new TMVA::DataLoader(xbitset.to_string());
2389
2390 // adding variables from seed
2391 for (int index = 0; index < nbits; index++) {
2392 if (xbitset[index])
2393 seedloader->AddVariable(varNames[index], 'F');
2394 }
2395
2396 // Loading Dataset
2397 DataLoaderCopy(seedloader, loader);
2398
2399 // Booking Seed
2400 BookMethod(seedloader, theMethod, methodTitle, theOption);
2401
2402 // Train/Test/Evaluation
2403 TrainAllMethods();
2404 TestAllMethods();
2405 EvaluateAllMethods();
2406
2407 // getting ROC
2408 SROC = GetROCIntegral(xbitset.to_string(), methodTitle);
2409
2410 // cleaning information to process sub-seeds
2411 TMVA::MethodBase *smethod = dynamic_cast<TMVA::MethodBase *>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
// NOTE(review): the declaration of `sresults` is not visible in this excerpt (the listing
// jumps from 2411 to 2414); presumably it is obtained from `smethod` - confirm.
2414 delete sresults;
2415 delete seedloader;
2416 this->DeleteAllMethods();
2417 fMethodsMap.clear();
2418
2419 // removing global result because it is requiring a lot of RAM for all seeds
2420
// One retraining per set bit of the seed, with that variable left out.
2421 for (uint32_t i = 0; i < VIBITS; ++i) {
// NOTE(review): `1` is an int here, unlike the uint64_t(1) used two lines below; the shift
// is undefined once i reaches the width of int - confirm against the value of VIBITS.
2422 if (x & (1 << i)) {
2423 y = x & ~(uint64_t(1) << i);
2424 std::bitset<VIBITS> ybitset(y);
2425 // need at least one variable
2426 // NOTE: if sub-seed is zero then is the special case
2427 // that count in xbitset is 1
// x - y is the single removed bit; its natural log divided by 0.693147 (approximately
// ln 2), truncated, is the index of the removed variable.
2428 uint32_t ny = static_cast<uint32_t>(log(x - y) / 0.693147);
2429 if (y == 0) {
// Only one variable in the seed: importance is assigned relative to 0.5, no retraining.
2430 importances[ny] = SROC - 0.5;
2431 continue;
2432 }
2433
2434 // creating loader for sub-seed
2435 TMVA::DataLoader *subseedloader = new TMVA::DataLoader(ybitset.to_string());
2436 // adding variables from sub-seed
2437 for (int index = 0; index < nbits; index++) {
2438 if (ybitset[index])
2439 subseedloader->AddVariable(varNames[index], 'F');
2440 }
2441
2442 // Loading Dataset
2443 DataLoaderCopy(subseedloader, loader);
2444
2445 // Booking SubSeed
2446 BookMethod(subseedloader, theMethod, methodTitle, theOption);
2447
2448 // Train/Test/Evaluation
2449 TrainAllMethods();
2450 TestAllMethods();
2451 EvaluateAllMethods();
2452
2453 // getting ROC
2454 SSROC = GetROCIntegral(ybitset.to_string(), methodTitle);
2455 importances[ny] += SROC - SSROC;
2456
2457 // cleaning information
2458 TMVA::MethodBase *ssmethod = dynamic_cast<TMVA::MethodBase *>(fMethodsMap[ybitset.to_string().c_str()][0][0]);
// NOTE(review): the declaration of `ssresults` is not visible in this excerpt (the listing
// jumps from 2458 to 2461).
2461 delete ssresults;
2462 delete subseedloader;
2463 this->DeleteAllMethods();
2464 fMethodsMap.clear();
2465 }
2466 }
2467 std::cout << "--- Variable Importance Results (Short)" << std::endl;
2468 return GetImportance(nbits, importances, varNames);
2469}
2470
2471////////////////////////////////////////////////////////////////////////////////
2472
2474 TString methodTitle, const char *theOption)
2475{
// Randomised variable-importance scan. For each of `nseeds` iterations a variable subset x
// is drawn from the generator, the method is trained on it, and then retrained once per
// contained variable with that variable removed; the ROC differences are accumulated per
// variable. Returns the histogram built by GetImportance.
// NOTE(review): the first line of the signature (which declares loader, nseeds and
// theMethod) is not visible in this excerpt.
// NOTE(review): `rangen` is allocated with new and no matching delete appears in the
// visible code - looks like a leak; confirm.
2476 TRandom3 *rangen = new TRandom3(0); // Random Gen.
2477
2478 uint64_t x = 0;
2479 uint64_t y = 0;
2480
2481 // getting number of variables and variable names from loader
2482 const int nbits = loader->GetDataSetInfo().GetNVariables();
2483 std::vector<TString> varNames = loader->GetDataSetInfo().GetListOfVariables();
2484
// Number of subsets, including the empty one: 2^nbits (double result converted to long int).
2485 long int range = pow(2, nbits);
2486
2487 // vector to save importances
2488 std::vector<Double_t> importances(nbits);
// NOTE(review): importances_norm is updated below but never read in the visible code.
2489 Double_t importances_norm = 0;
2490 for (int i = 0; i < nbits; i++)
2491 importances[i] = 0;
2492
2493 Double_t SROC, SSROC; // computed ROC value
2494 for (UInt_t n = 0; n < nseeds; n++) {
// Presumably a uniform draw in [0, range) - confirm against TRandom::Integer.
2495 x = rangen->Integer(range);
2496
// This variant uses a fixed 32-bit bitset rather than VIBITS.
2497 std::bitset<32> xbitset(x);
// An empty draw still consumes one of the nseeds iterations.
2498 if (x == 0)
2499 continue; // data loader need at least one variable
2500
2501 // creating loader for seed
2502 TMVA::DataLoader *seedloader = new TMVA::DataLoader(xbitset.to_string());
2503
2504 // adding variables from seed
2505 for (int index = 0; index < nbits; index++) {
2506 if (xbitset[index])
2507 seedloader->AddVariable(varNames[index], 'F');
2508 }
2509
2510 // Loading Dataset
2511 DataLoaderCopy(seedloader, loader);
2512
2513 // Booking Seed
2514 BookMethod(seedloader, theMethod, methodTitle, theOption);
2515
2516 // Train/Test/Evaluation
2517 TrainAllMethods();
2518 TestAllMethods();
2519 EvaluateAllMethods();
2520
2521 // getting ROC
2522 SROC = GetROCIntegral(xbitset.to_string(), methodTitle);
2523 // std::cout << "Seed: n " << n << " x " << x << " xbitset:" << xbitset << " ROC " << SROC << std::endl;
2524
2525 // cleaning information to process sub-seeds
2526 TMVA::MethodBase *smethod = dynamic_cast<TMVA::MethodBase *>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
// NOTE(review): the declaration of `sresults` is not visible in this excerpt (the listing
// jumps from 2526 to 2529).
2529 delete sresults;
2530 delete seedloader;
2531 this->DeleteAllMethods();
2532 fMethodsMap.clear();
2533
2534 // removing global result because it is requiring a lot of RAM for all seeds
2535
// One retraining per variable contained in the seed, with that variable left out.
2536 for (uint32_t i = 0; i < 32; ++i) {
2537 if (x & (uint64_t(1) << i)) {
// NOTE(review): `1` is an int here, unlike the uint64_t(1) in the test above; i reaches 31,
// where the int shift overflows - confirm the resulting mask is what is intended.
2538 y = x & ~(1 << i);
2539 std::bitset<32> ybitset(y);
2540 // need at least one variable
2541 // NOTE: if sub-seed is zero then is the special case
2542 // that count in xbitset is 1
// x - y is the single removed bit; its natural log divided by 0.693147 (approximately
// ln 2) is the bit index.
// NOTE(review): ny is a Double_t and is used directly as a vector index below, relying on
// implicit truncation; the sibling functions cast to uint32_t explicitly.
2543 Double_t ny = log(x - y) / 0.693147;
2544 if (y == 0) {
// Single-variable seed: the importance is assigned (overwriting any accumulated value).
2545 importances[ny] = SROC - 0.5;
2546 importances_norm += importances[ny];
2547 // std::cout << "SubSeed: " << y << " y:" << ybitset << "ROC " << 0.5 << std::endl;
2548 continue;
2549 }
2550
2551 // creating loader for sub-seed
2552 TMVA::DataLoader *subseedloader = new TMVA::DataLoader(ybitset.to_string());
2553 // adding variables from sub-seed
2554 for (int index = 0; index < nbits; index++) {
2555 if (ybitset[index])
2556 subseedloader->AddVariable(varNames[index], 'F');
2557 }
2558
2559 // Loading Dataset
2560 DataLoaderCopy(subseedloader, loader);
2561
2562 // Booking SubSeed
2563 BookMethod(subseedloader, theMethod, methodTitle, theOption);
2564
2565 // Train/Test/Evaluation
2566 TrainAllMethods();
2567 TestAllMethods();
2568 EvaluateAllMethods();
2569
2570 // getting ROC
2571 SSROC = GetROCIntegral(ybitset.to_string(), methodTitle);
2572 importances[ny] += SROC - SSROC;
2573 // std::cout << "SubSeed: " << y << " y:" << ybitset << " x-y " << x - y << " " << std::bitset<32>(x - y) <<
2574 // " ny " << ny << " SROC " << SROC << " SSROC " << SSROC << " Importance = " << importances[ny] <<
2575 // std::endl; cleaning information
2576 TMVA::MethodBase *ssmethod =
2577 dynamic_cast<TMVA::MethodBase *>(fMethodsMap[ybitset.to_string().c_str()][0][0]);
// NOTE(review): the declaration of `ssresults` is not visible in this excerpt (the listing
// jumps from 2577 to 2580).
2580 delete ssresults;
2581 delete subseedloader;
2582 this->DeleteAllMethods();
2583 fMethodsMap.clear();
2584 }
2585 }
2586 }
2587 std::cout << "--- Variable Importance Results (Random)" << std::endl;
2588 return GetImportance(nbits, importances, varNames);
2589}
2590
2591////////////////////////////////////////////////////////////////////////////////
2592
2593TH1F *TMVA::Factory::GetImportance(const int nbits, std::vector<Double_t> importances, std::vector<TString> varNames)
2594{
2595 TH1F *vih1 = new TH1F("vih1", "", nbits, 0, nbits);
2596
2597 gStyle->SetOptStat(000000);
2598
2599 Float_t normalization = 0.0;
2600 for (int i = 0; i < nbits; i++) {
2601 normalization = normalization + importances[i];
2602 }
2603
2604 Float_t roc = 0.0;
2605
2606 gStyle->SetTitleXOffset(0.4);
2607 gStyle->SetTitleXOffset(1.2);
2608
2609 std::vector<Double_t> x_ie(nbits), y_ie(nbits);
2610 for (Int_t i = 1; i < nbits + 1; i++) {
2611 x_ie[i - 1] = (i - 1) * 1.;
2612 roc = 100.0 * importances[i - 1] / normalization;
2613 y_ie[i - 1] = roc;
2614 std::cout << "--- " << varNames[i - 1] << " = " << roc << " %" << std::endl;
2615 vih1->GetXaxis()->SetBinLabel(i, varNames[i - 1].Data());
2616 vih1->SetBinContent(i, roc);
2617 }
2618 TGraph *g_ie = new TGraph(nbits + 2, &x_ie[0], &y_ie[0]);
2619 g_ie->SetTitle("");
2620
2621 vih1->LabelsOption("v >", "X");
2622 vih1->SetBarWidth(0.97);
2623 Int_t ca = TColor::GetColor("#006600");
2624 vih1->SetFillColor(ca);
2625 // Int_t ci = TColor::GetColor("#990000");
2626
2627 vih1->GetYaxis()->SetTitle("Importance (%)");
2628 vih1->GetYaxis()->SetTitleSize(0.045);
2629 vih1->GetYaxis()->CenterTitle();
2630 vih1->GetYaxis()->SetTitleOffset(1.24);
2631
2632 vih1->GetYaxis()->SetRangeUser(-7, 50);
2633 vih1->SetDirectory(0);
2634
2635 // vih1->Draw("B");
2636 return vih1;
2637}
#define h(i)
Definition: RSha256.hxx:106
void printMatrix(const TMatrixD &mat)
write a matrix
int Int_t
Definition: RtypesCore.h:45
unsigned int UInt_t
Definition: RtypesCore.h:46
const Bool_t kFALSE
Definition: RtypesCore.h:101
bool Bool_t
Definition: RtypesCore.h:63
double Double_t
Definition: RtypesCore.h:59
float Float_t
Definition: RtypesCore.h:57
const Bool_t kTRUE
Definition: RtypesCore.h:100
#define ClassImp(name)
Definition: Rtypes.h:364
char name[80]
Definition: TGX11.cxx:110
int type
Definition: TGX11.cxx:121
double pow(double, double)
double log(double)
TMatrixT< Double_t > TMatrixD
Definition: TMatrixDfwd.h:22
#define gROOT
Definition: TROOT.h:404
char * Form(const char *fmt,...)
R__EXTERN TStyle * gStyle
Definition: TStyle.h:412
R__EXTERN TSystem * gSystem
Definition: TSystem.h:559
virtual void SetTitleOffset(Float_t offset=1)
Set distance between the axis and the axis title.
Definition: TAttAxis.cxx:302
virtual void SetTitleSize(Float_t size=0.04)
Set size of axis title.
Definition: TAttAxis.cxx:312
virtual void SetFillColor(Color_t fcolor)
Set the fill area color.
Definition: TAttFill.h:37
virtual void SetBinLabel(Int_t bin, const char *label)
Set label for bin.
Definition: TAxis.cxx:823
void CenterTitle(Bool_t center=kTRUE)
Center axis title.
Definition: TAxis.h:184
virtual void SetRangeUser(Double_t ufirst, Double_t ulast)
Set the viewing range for the axis from ufirst to ulast (in user coordinates, that is,...
Definition: TAxis.cxx:946
The Canvas class.
Definition: TCanvas.h:23
static Int_t GetColor(const char *hexcolor)
Static method returning color number for color specified by hex color string of form: "#rrggbb",...
Definition: TColor.cxx:1775
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:54
A TGraph is an object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
virtual void SetTitle(const char *title="")
Change (i.e. set) the title.
Definition: TGraph.cxx:2333
1-D histogram with a float per channel (see TH1 documentation)
Definition: TH1.h:575
virtual void SetDirectory(TDirectory *dir)
By default when an histogram is created, it is added to the list of histogram objects in the current ...
Definition: TH1.cxx:8780
virtual void SetTitle(const char *title)
See GetStatOverflows for more information.
Definition: TH1.cxx:6680
virtual void LabelsOption(Option_t *option="h", Option_t *axis="X")
Sort bins with labels or set option(s) to draw axis with labels.
Definition: TH1.cxx:5315
static void AddDirectory(Bool_t add=kTRUE)
Sets the flag controlling the automatic add of histograms in memory.
Definition: TH1.cxx:1283
TAxis * GetXaxis()
Get the behaviour adopted by the object about the statoverflows. See EStatOverflows for more informat...
Definition: TH1.h:320
TAxis * GetYaxis()
Definition: TH1.h:321
virtual void SetBinContent(Int_t bin, Double_t content)
Set bin content see convention for numbering bins in TH1::GetBin In case the bin number is greater th...
Definition: TH1.cxx:9065
virtual void SetBarWidth(Float_t width=0.5)
Set the width of bars as fraction of the bin width for drawing mode "B".
Definition: TH1.h:360
Service class for 2-D histogram classes.
Definition: TH2.h:30
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
TString fWeightFileDir
Definition: Config.h:124
TString fWeightFileDirPrefix
Definition: Config.h:123
void SetDrawProgressBar(Bool_t d)
Definition: Config.h:69
void SetUseColor(Bool_t uc)
Definition: Config.h:60
class TMVA::Config::VariablePlotting fVariablePlotting
void SetSilent(Bool_t s)
Definition: Config.h:63
IONames & GetIONames()
Definition: Config.h:98
void SetConfigDescription(const char *d)
Definition: Configurable.h:64
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
void AddPreDefVal(const T &)
Definition: Configurable.h:168
void SetConfigName(const char *n)
Definition: Configurable.h:63
virtual void ParseOptions()
options parser
const TString & GetOptions() const
Definition: Configurable.h:84
MsgLogger & Log() const
Definition: Configurable.h:122
MsgLogger * fLogger
Definition: Configurable.h:128
void CheckForUnusedOptions() const
checks for unused options in option string
UInt_t GetEntries(const TString &name) const
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
Definition: DataLoader.cxx:632
DataSetInfo & GetDataSetInfo()
Definition: DataLoader.cxx:137
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Definition: DataLoader.cxx:485
Class that contains all the data information.
Definition: DataSetInfo.h:62
UInt_t GetNVariables() const
Definition: DataSetInfo.h:127
virtual const char * GetName() const
Returns name of object.
Definition: DataSetInfo.h:71
const TMatrixD * CorrelationMatrix(const TString &className) const
UInt_t GetNClasses() const
Definition: DataSetInfo.h:155
const TString & GetSplitOptions() const
Definition: DataSetInfo.h:186
UInt_t GetNTargets() const
Definition: DataSetInfo.h:128
DataSet * GetDataSet() const
returns data set
TH2 * CreateCorrelationMatrixHist(const TMatrixD *m, const TString &hName, const TString &hTitle) const
std::vector< TString > GetListOfVariables() const
returns list of variables
ClassInfo * GetClassInfo(Int_t clNum) const
const TCut & GetCut(Int_t i) const
Definition: DataSetInfo.h:168
VariableInfo & GetVariableInfo(Int_t i)
Definition: DataSetInfo.h:105
Bool_t IsSignal(const Event *ev) const
DataSetManager * GetDataSetManager()
Definition: DataSetInfo.h:194
DataInputHandler & DataInput()
Class that contains all the data information.
Definition: DataSet.h:58
Long64_t GetNEvtSigTest()
return number of signal test events in dataset
Definition: DataSet.cxx:427
TTree * GetTree(Types::ETreeType type)
create the test/trainings tree with all the variables, the weights, the classes, the targets,...
Definition: DataSet.cxx:609
const Event * GetEvent() const
Definition: DataSet.cxx:202
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:206
Results * GetResults(const TString &, Types::ETreeType type, Types::EAnalysisType analysistype)
Definition: DataSet.cxx:265
Long64_t GetNTrainingEvents() const
Definition: DataSet.h:68
void SetCurrentType(Types::ETreeType type) const
Definition: DataSet.h:89
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:216
Long64_t GetNEvtBkgdTest()
return number of background test events in dataset
Definition: DataSet.cxx:435
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:236
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:391
This is the main MVA steering class.
Definition: Factory.h:80
void PrintHelpMessage(const TString &datasetname, const TString &methodTitle="") const
Print predefined help message of classifier.
Definition: Factory.cxx:1333
Bool_t fCorrelations
verbosity level, controls granularity of logging
Definition: Factory.h:215
std::vector< IMethod * > MVector
Definition: Factory.h:84
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition: Factory.cxx:1114
Bool_t Verbose(void) const
Definition: Factory.h:134
void WriteDataInformation(DataSetInfo &fDataSetInfo)
Definition: Factory.cxx:602
MethodBase * BookMethod(DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption="")
Book a classifier or regression method.
Definition: Factory.cxx:352
Factory(TString theJobName, TFile *theTargetFile, TString theOption="")
Standard constructor.
Definition: Factory.cxx:113
void TestAllMethods()
Evaluates all booked methods on the testing data and adds the output to the Results in the corresponding DataSet.
Definition: Factory.cxx:1271
Bool_t fVerbose
list of transformations to test
Definition: Factory.h:213
void EvaluateAllMethods(void)
Iterates over all MVAs that have been booked, and calls their evaluation methods.
Definition: Factory.cxx:1376
TH1F * EvaluateImportanceRandom(DataLoader *loader, UInt_t nseeds, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition: Factory.cxx:2473
TH1F * GetImportance(const int nbits, std::vector< Double_t > importances, std::vector< TString > varNames)
Definition: Factory.cxx:2593
Bool_t fROC
enable to calculate correlations
Definition: Factory.h:216
void EvaluateAllVariables(DataLoader *loader, TString options="")
Iterates over all MVA input variables and evaluates them.
Definition: Factory.cxx:1360
TString fVerboseLevel
verbose mode
Definition: Factory.h:214
TMultiGraph * GetROCCurveAsMultiGraph(DataLoader *loader, UInt_t iClass, Types::ETreeType type=Types::kTesting)
Generate a collection of graphs, for all methods for a given class.
Definition: Factory.cxx:988
TH1F * EvaluateImportance(DataLoader *loader, VIType vitype, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Evaluate Variable Importance.
Definition: Factory.cxx:2217
Double_t GetROCIntegral(DataLoader *loader, TString theMethodName, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Calculate the integral of the ROC curve, also known as the area under curve (AUC),...
Definition: Factory.cxx:849
virtual ~Factory()
Destructor.
Definition: Factory.cxx:306
virtual void MakeClass(const TString &datasetname, const TString &methodTitle="") const
Definition: Factory.cxx:1305
MethodBase * BookMethodWeightfile(DataLoader *dataloader, TMVA::Types::EMVA methodType, const TString &weightfile)
Adds an already constructed method to be managed by this factory.
Definition: Factory.cxx:501
Bool_t fModelPersistence
the training type
Definition: Factory.h:222
std::map< TString, Double_t > OptimizeAllMethods(TString fomType="ROCIntegral", TString fitType="FitGA")
Iterates through all booked methods and sees if they use parameter tuning and if so does just that,...
Definition: Factory.cxx:701
ROCCurve * GetROC(DataLoader *loader, TString theMethodName, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Private method to generate a ROCCurve instance for a given method.
Definition: Factory.cxx:749
TH1F * EvaluateImportanceShort(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition: Factory.cxx:2358
Types::EAnalysisType fAnalysisType
jobname, used as extension in weight file names
Definition: Factory.h:221
Bool_t HasMethod(const TString &datasetname, const TString &title) const
Checks whether a given method name is defined for a given dataset.
Definition: Factory.cxx:586
TGraph * GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles=kTRUE, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Argument iClass specifies the class to generate the ROC curve in a multiclass setting.
Definition: Factory.cxx:912
TH1F * EvaluateImportanceAll(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition: Factory.cxx:2246
void SetVerbose(Bool_t v=kTRUE)
Definition: Factory.cxx:343
TFile * fgTargetFile
Definition: Factory.h:205
IMethod * GetMethod(const TString &datasetname, const TString &title) const
Returns pointer to MVA that corresponds to given method title.
Definition: Factory.cxx:566
void DeleteAllMethods(void)
Delete methods.
Definition: Factory.cxx:324
TString fTransformations
option string given by construction (presently only "V")
Definition: Factory.h:212
void Greetings()
Print welcome message.
Definition: Factory.cxx:295
Interface for all concrete MVA method implementations.
Definition: IMethod.h:53
virtual void PrintHelpMessage() const =0
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
virtual void MakeClass(const TString &classFileName=TString("")) const =0
Virtual base Class for all MVA method.
Definition: MethodBase.h:111
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
void SetSilentFile(Bool_t status)
Definition: MethodBase.h:378
void SetWeightFileDir(TString fileDir)
set directory of weight file
virtual void TestRegression(Double_t &bias, Double_t &biasT, Double_t &dev, Double_t &devT, Double_t &rms, Double_t &rmsT, Double_t &mInf, Double_t &mInfT, Double_t &corr, Types::ETreeType type)
calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
Definition: MethodBase.cxx:991
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
Definition: MethodBase.cxx:596
TString GetMethodTypeName() const
Definition: MethodBase.h:332
Bool_t DoMulticlass() const
Definition: MethodBase.h:439
virtual Double_t GetSignificance() const
compute significance of mean difference
const char * GetName() const
Definition: MethodBase.h:334
Types::EAnalysisType GetAnalysisType() const
Definition: MethodBase.h:437
virtual TMatrixD GetMulticlassConfusionMatrix(Double_t effB, Types::ETreeType type)
Construct a confusion matrix for a multiclass classifier.
void PrintHelpMessage() const
prints out method-specific help method
virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype)
writes all MVA evaluation histograms to file
virtual void TestMulticlass()
test multiclass classification
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:406
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition: MethodBase.h:436
const TString & GetMethodName() const
Definition: MethodBase.h:331
Bool_t DoRegression() const
Definition: MethodBase.h:438
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:423
virtual Double_t GetTrainingEfficiency(const TString &)
DataSetInfo & DataInfo() const
Definition: MethodBase.h:410
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
virtual void TestClassification()
initialization
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
void ReadStateFromFile()
Function to write options and weights to file.
virtual std::map< TString, Double_t > OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA")
call the Optimizer with the set of parameters and ranges that are meant to be tuned.
Definition: MethodBase.cxx:623
DataSetInfo & fDataSetInfo
Definition: MethodBase.h:607
Types::EMVA GetMethodType() const
Definition: MethodBase.h:333
void SetFile(TFile *file)
Definition: MethodBase.h:375
DataSet * Data() const
Definition: MethodBase.h:409
void SetModelPersistence(Bool_t status)
Definition: MethodBase.h:382
virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const
calculate the area (integral) under the ROC curve as a overall quality measure of the classification
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:433
Class for boosting a TMVA method.
Definition: MethodBoost.h:58
void SetBoostedMethodName(TString methodName)
Definition: MethodBoost.h:86
DataSetManager * fDataSetManager
Definition: MethodBoost.h:193
Class for categorizing the phase space.
DataSetManager * fDataSetManager
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:57
void SetMinType(EMsgType minType)
Definition: MsgLogger.h:70
void SetSource(const std::string &source)
Definition: MsgLogger.h:68
static void InhibitOutput()
Definition: MsgLogger.cxx:67
Double_t GetEffSForEffB(Double_t effB, const UInt_t num_points=41)
Calculate the signal efficiency (sensitivity) for a given background efficiency (sensitivity).
Definition: ROCCurve.cxx:219
Double_t GetROCIntegral(const UInt_t points=41)
Calculates the ROC integral (AUC)
Definition: ROCCurve.cxx:250
TGraph * GetROCCurve(const UInt_t points=100)
Returns a new TGraph containing the ROC curve.
Definition: ROCCurve.cxx:276
Ranking for variables in method (implementation)
Definition: Ranking.h:48
virtual void Print() const
get maximum length of variable names
Definition: Ranking.cxx:111
Class that is the base-class for a vector of result.
Class which takes the results of a multiclass classification.
Class that is the base-class for a vector of result.
Definition: Results.h:57
void FormattedOutput(const std::vector< Double_t > &, const std::vector< TString > &, const TString titleVars, const TString titleValues, MsgLogger &logger, TString format="%+1.3f")
formatted output of simple table
Definition: Tools.cxx:887
void UsefulSortDescending(std::vector< std::vector< Double_t > > &, std::vector< TString > *vs=0)
sort 2D vector (AND in parallel a TString vector) in such a way that the "first vector is sorted" and...
Definition: Tools.cxx:564
void ROOTVersionMessage(MsgLogger &logger)
prints the ROOT release number and date
Definition: Tools.cxx:1325
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition: Tools.cxx:1199
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:828
const TMatrixD * GetCorrelationMatrix(const TMatrixD *covMat)
turns covariance into correlation matrix
Definition: Tools.cxx:324
@ kHtmlLink
Definition: Tools.h:212
void TMVACitation(MsgLogger &logger, ECitation citType=kPlainText)
kinds of TMVA citation
Definition: Tools.cxx:1441
void TMVAVersionMessage(MsgLogger &logger)
prints the TMVA release number and date
Definition: Tools.cxx:1316
void TMVAWelcomeMessage()
direct output, eg, when starting ROOT session -> no use of Logger here
Definition: Tools.cxx:1302
void UsefulSortAscending(std::vector< std::vector< Double_t > > &, std::vector< TString > *vs=0)
sort 2D vector (AND in parallel a TString vector) in such a way that the "first vector is sorted" and...
Definition: Tools.cxx:538
Class that contains all the data information.
void PrintVariableRanking() const
prints ranking of input variables
Singleton class for Global types used by TMVA.
Definition: Types.h:71
static Types & Instance()
Returns the single instance of "Types" if it already exists, or creates it (Singleton)
Definition: Types.cxx:70
@ kCategory
Definition: Types.h:97
@ kCuts
Definition: Types.h:78
EAnalysisType
Definition: Types.h:126
@ kMulticlass
Definition: Types.h:129
@ kNoAnalysisType
Definition: Types.h:130
@ kClassification
Definition: Types.h:127
@ kMaxAnalysisType
Definition: Types.h:131
@ kRegression
Definition: Types.h:128
@ kTraining
Definition: Types.h:143
@ kTesting
Definition: Types.h:144
const TString & GetLabel() const
Definition: VariableInfo.h:59
VIType
Definition: Types.h:69
@ kDEBUG
Definition: Types.h:56
@ kVERBOSE
Definition: Types.h:57
@ kHEADER
Definition: Types.h:63
@ kERROR
Definition: Types.h:60
@ kINFO
Definition: Types.h:58
@ kWARNING
Definition: Types.h:59
@ kFATAL
Definition: Types.h:61
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:36
TList * GetListOfGraphs() const
Definition: TMultiGraph.h:70
virtual void Add(TGraph *graph, Option_t *chopt="")
Add a new graph to the list of graphs.
TH1F * GetHistogram()
Returns a pointer to the histogram used to draw the axis.
virtual void Draw(Option_t *chopt="")
Draw this multigraph with its current attributes.
TAxis * GetYaxis()
Get y axis of the graph.
TAxis * GetXaxis()
Get x axis of the graph.
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
Definition: TNamed.cxx:164
TString fName
Definition: TNamed.h:32
virtual TObject * Clone(const char *newname="") const
Make a clone of an object using the Streamer facility.
Definition: TNamed.cxx:74
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
@ kOverwrite
overwrite existing object with same name
Definition: TObject.h:88
void SetGrid(Int_t valuex=1, Int_t valuey=1) override
Definition: TPad.h:327
TLegend * BuildLegend(Double_t x1=0.3, Double_t y1=0.21, Double_t x2=0.3, Double_t y2=0.21, const char *title="", Option_t *option="") override
Build a legend from the graphical objects in the pad.
Definition: TPad.cxx:503
Principal Components Analysis (PCA)
Definition: TPrincipal.h:21
virtual void AddRow(const Double_t *x)
Add a data point and update the covariance matrix.
Definition: TPrincipal.cxx:410
const TMatrixD * GetCovarianceMatrix() const
Definition: TPrincipal.h:59
virtual void MakePrincipals()
Perform the principal components analysis.
Definition: TPrincipal.cxx:875
Random number generator class based on M. Matsumoto and T. Nishimura's Mersenne Twister generator.
Definition: TRandom3.h:27
virtual UInt_t Integer(UInt_t imax)
Returns a random integer uniformly distributed on the interval [ 0, imax-1 ].
Definition: TRandom.cxx:360
Basic string class.
Definition: TString.h:136
Ssiz_t Length() const
Definition: TString.h:410
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1150
int CompareTo(const char *cs, ECaseCompare cmp=kExact) const
Compare a string to char *cs2.
Definition: TString.cxx:442
const char * Data() const
Definition: TString.h:369
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:692
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
Definition: TString.h:615
Bool_t IsNull() const
Definition: TString.h:407
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:624
void SetOptStat(Int_t stat=1)
The type of information printed in the histogram statistics box can be selected via the parameter mod...
Definition: TStyle.cxx:1589
void SetTitleXOffset(Float_t offset=1)
Definition: TStyle.h:392
virtual int MakeDirectory(const char *name)
Make a directory.
Definition: TSystem.cxx:828
virtual Int_t Write(const char *name=0, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
Definition: TTree.cxx:9698
Double_t y[n]
Definition: legend1.C:17
Double_t x[n]
Definition: legend1.C:17
const Int_t n
Definition: legend1.C:16
RPY_EXPORTED TCppMethod_t GetMethod(TCppScope_t scope, TCppIndex_t imeth)
static constexpr double s
void GetMethodName(TString &name, TKey *mkey)
Definition: tmvaglob.cxx:342
void DataLoaderCopy(TMVA::DataLoader *des, TMVA::DataLoader *src)
Config & gConfig()
Tools & gTools()
void CreateVariableTransforms(const TString &trafoDefinition, TMVA::DataSetInfo &dataInfo, TMVA::TransformationHandler &transformationHandler, TMVA::MsgLogger &log)
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:148
Bool_t IsNaN(Double_t x)
Definition: TMath.h:892
Double_t Log(Double_t x)
Definition: TMath.h:760
Definition: graph.py:1
auto * m
Definition: textangle.C:8
#define VIBITS
Definition: Factory.cxx:103
static uint64_t sum(uint64_t i)
Definition: Factory.cxx:2345
const Int_t MinNoTrainingEvents
Definition: Factory.cxx:95
#define READXML
Definition: Factory.cxx:100