Factory.cxx
// @(#)Root/tmva $Id$
// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
// Updated by: Omar Zapata, Kim Albertsson
/**********************************************************************************
 * Project: TMVA - a ROOT-integrated toolkit for multivariate data analysis      *
 * Package: TMVA                                                                  *
 * Class  : Factory                                                               *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation (see header for description)                               *
 *                                                                                *
 * Authors :                                                                      *
 *      Andreas Hoecker   <Andreas.Hocker@cern.ch>   - CERN, Switzerland          *
 *      Joerg Stelzer     <stelzer@cern.ch>          - DESY, Germany              *
 *      Peter Speckmayer  <peter.speckmayer@cern.ch> - CERN, Switzerland          *
 *      Jan Therhaag      <Jan.Therhaag@cern.ch>     - U of Bonn, Germany         *
 *      Eckhard v. Toerne <evt@uni-bonn.de>          - U of Bonn, Germany         *
 *      Helge Voss        <Helge.Voss@cern.ch>       - MPI-K Heidelberg, Germany  *
 *      Kai Voss          <Kai.Voss@cern.ch>         - U. of Victoria, Canada     *
 *      Omar Zapata       <Omar.Zapata@cern.ch>      - UdeA/ITM, Colombia         *
 *      Lorenzo Moneta    <Lorenzo.Moneta@cern.ch>   - CERN, Switzerland          *
 *      Sergei Gleyzer    <Sergei.Gleyzer@cern.ch>   - U of Florida & CERN        *
 *      Kim Albertsson    <kim.albertsson@cern.ch>   - LTU & CERN                 *
 *                                                                                *
 * Copyright (c) 2005-2015:                                                       *
 *      CERN, Switzerland                                                         *
 *      U. of Victoria, Canada                                                    *
 *      MPI-K Heidelberg, Germany                                                 *
 *      U. of Bonn, Germany                                                       *
 *      UdeA/ITM, Colombia                                                        *
 *      U. of Florida, USA                                                        *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

/*! \class TMVA::Factory
\ingroup TMVA

This is the main MVA steering class.
It creates all MVA methods, and guides them through the training, testing and
evaluation phases.
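
A minimal usage sketch (illustrative only: the file, dataset, variable, and
method names below are placeholders, not fixed by this class):

~~~ {.cpp}
TFile *outputFile = TFile::Open("TMVA.root", "RECREATE");
TMVA::Factory factory("TMVAClassification", outputFile,
                      "V:Color:DrawProgressBar:AnalysisType=Classification");

TMVA::DataLoader loader("dataset");
loader.AddVariable("var1", 'F');
loader.AddVariable("var2", 'F');
// ... register input trees (AddSignalTree/AddBackgroundTree) and call
// PrepareTrainingAndTestTree on the loader, then:

factory.BookMethod(&loader, TMVA::Types::kBDT, "BDT", "NTrees=400");
factory.TrainAllMethods();
factory.TestAllMethods();
factory.EvaluateAllMethods();
outputFile->Close();
~~~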
*/

#include "TMVA/Factory.h"

#include "TMVA/ClassifierFactory.h"
#include "TMVA/Config.h"
#include "TMVA/Configurable.h"
#include "TMVA/Tools.h"
#include "TMVA/Ranking.h"
#include "TMVA/DataSet.h"
#include "TMVA/IMethod.h"
#include "TMVA/MethodBase.h"
#include "TMVA/DataInputHandler.h"
#include "TMVA/DataSetManager.h"
#include "TMVA/DataSetInfo.h"
#include "TMVA/DataLoader.h"
#include "TMVA/MethodBoost.h"
#include "TMVA/MethodCategory.h"
#include "TMVA/ROCCalc.h"
#include "TMVA/ROCCurve.h"
#include "TMVA/MsgLogger.h"

#include "TMVA/VariableInfo.h"
#include "TMVA/VariableTransform.h"

#include "TMVA/Results.h"
#include "TMVA/ResultsClassification.h"
#include "TMVA/ResultsRegression.h"
#include "TMVA/ResultsMulticlass.h"

#include <algorithm>
#include <list>
#include <bitset>
#include <set>

#include "TMVA/Types.h"

#include "TROOT.h"
#include "TFile.h"
#include "TTree.h"
#include "TLeaf.h"
#include "TEventList.h"
#include "TH2.h"
#include "TText.h"
#include "TLegend.h"
#include "TGraph.h"
#include "TStyle.h"
#include "TMatrixF.h"
#include "TMatrixDSym.h"
#include "TMultiGraph.h"
#include "TPaletteAxis.h"
#include "TPrincipal.h"
#include "TMath.h"
#include "TObjString.h"
#include "TSystem.h"
#include "TCanvas.h"

const Int_t MinNoTrainingEvents = 10;
//const Int_t MinNoTestEvents   = 1;

ClassImp(TMVA::Factory);

#define READXML kTRUE

// number of bits for bitset
#define VIBITS 32


////////////////////////////////////////////////////////////////////////////////
/// Standard constructor.
///
///   - jobname       : this name will appear in all weight file names produced by the MVAs
///   - theTargetFile : output ROOT file; the test tree and all evaluation plots
///                     will be stored here
///   - theOption     : option string; currently: "V" for verbose

TMVA::Factory::Factory( TString jobName, TFile* theTargetFile, TString theOption )
: Configurable          ( theOption ),
  fTransformations      ( "I" ),
  fVerbose              ( kFALSE ),
  fVerboseLevel         ( kINFO ),
  fCorrelations         ( kFALSE ),
  fROC                  ( kTRUE ),
  fSilentFile           ( theTargetFile == nullptr ),
  fJobName              ( jobName ),
  fAnalysisType         ( Types::kClassification ),
  fModelPersistence     ( kTRUE )
{
   fName = "Factory";
   fgTargetFile = theTargetFile;
   fLogger->SetSource(fName.Data());

   // render silent
   if (gTools().CheckForSilentOption( GetOptions() )) Log().InhibitOutput(); // make sure the Factory stays silent if requested

   // init configurable
   SetConfigDescription( "Configuration options for Factory running" );
   SetConfigName( GetName() );

   // histograms are not automatically associated with the current
   // directory and hence don't go out of scope when closing the file
   // TH1::AddDirectory(kFALSE);
   Bool_t silent          = kFALSE;
#ifdef WIN32
   // under Windows, switch progress bar and color off by default, as the typical Windows shell doesn't handle these (would need different escape sequences)
   Bool_t color           = kFALSE;
   Bool_t drawProgressBar = kFALSE;
#else
   Bool_t color           = !gROOT->IsBatch();
   Bool_t drawProgressBar = kTRUE;
#endif
   DeclareOptionRef( fVerbose, "V", "Verbose flag" );
   DeclareOptionRef( fVerboseLevel=TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)" );
   AddPreDefVal(TString("Debug"));
   AddPreDefVal(TString("Verbose"));
   AddPreDefVal(TString("Info"));
   DeclareOptionRef( color, "Color", "Flag for coloured screen output (default: True, if in batch mode: False)" );
   DeclareOptionRef( fTransformations, "Transformations", "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for identity, decorrelation, PCA, uniform and Gaussianisation followed by decorrelation transformations" );
   DeclareOptionRef( fCorrelations, "Correlations", "Boolean flag to show the correlations in the output" );
   DeclareOptionRef( fROC, "ROC", "Boolean flag to show the ROC curve in the output" );
   DeclareOptionRef( silent, "Silent", "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory class object (default: False)" );
   DeclareOptionRef( drawProgressBar,
                     "DrawProgressBar", "Draw progress bar to display training, testing and evaluation schedule (default: True)" );
   DeclareOptionRef( fModelPersistence,
                     "ModelPersistence",
                     "Option to save the trained model in an XML file or via serialization" );

   TString analysisType("Auto");
   DeclareOptionRef( analysisType,
                     "AnalysisType", "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)" );
   AddPreDefVal(TString("Classification"));
   AddPreDefVal(TString("Regression"));
   AddPreDefVal(TString("Multiclass"));
   AddPreDefVal(TString("Auto"));

   ParseOptions();
   CheckForUnusedOptions();

   if (Verbose()) fLogger->SetMinType( kVERBOSE );
   if (fVerboseLevel.CompareTo("Debug")   == 0) fLogger->SetMinType( kDEBUG );
   if (fVerboseLevel.CompareTo("Verbose") == 0) fLogger->SetMinType( kVERBOSE );
   if (fVerboseLevel.CompareTo("Info")    == 0) fLogger->SetMinType( kINFO );

   // global settings
   gConfig().SetUseColor( color );
   gConfig().SetSilent( silent );
   gConfig().SetDrawProgressBar( drawProgressBar );

   analysisType.ToLower();
   if      ( analysisType == "classification" ) fAnalysisType = Types::kClassification;
   else if ( analysisType == "regression" )     fAnalysisType = Types::kRegression;
   else if ( analysisType == "multiclass" )     fAnalysisType = Types::kMulticlass;
   else if ( analysisType == "auto" )           fAnalysisType = Types::kNoAnalysisType;

//   Greetings();
}

////////////////////////////////////////////////////////////////////////////////
/// Constructor without an output target file; all file output is disabled
/// (the factory runs in silent-file mode).

TMVA::Factory::Factory( TString jobName, TString theOption )
: Configurable          ( theOption ),
  fTransformations      ( "I" ),
  fVerbose              ( kFALSE ),
  fCorrelations         ( kFALSE ),
  fROC                  ( kTRUE ),
  fSilentFile           ( kTRUE ),
  fJobName              ( jobName ),
  fAnalysisType         ( Types::kClassification ),
  fModelPersistence     ( kTRUE )
{
   fName = "Factory";
   fgTargetFile = nullptr;
   fLogger->SetSource(fName.Data());

   // render silent
   if (gTools().CheckForSilentOption( GetOptions() )) Log().InhibitOutput(); // make sure the Factory stays silent if requested

   // init configurable
   SetConfigDescription( "Configuration options for Factory running" );
   SetConfigName( GetName() );

   // histograms are not automatically associated with the current
   // directory and hence don't go out of scope when closing the file
   // TH1::AddDirectory(kFALSE);
   Bool_t silent          = kFALSE;
#ifdef WIN32
   // under Windows, switch progress bar and color off by default, as the typical Windows shell doesn't handle these (would need different escape sequences)
   Bool_t color           = kFALSE;
   Bool_t drawProgressBar = kFALSE;
#else
   Bool_t color           = !gROOT->IsBatch();
   Bool_t drawProgressBar = kTRUE;
#endif
   DeclareOptionRef( fVerbose, "V", "Verbose flag" );
   DeclareOptionRef( fVerboseLevel=TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)" );
   AddPreDefVal(TString("Debug"));
   AddPreDefVal(TString("Verbose"));
   AddPreDefVal(TString("Info"));
   DeclareOptionRef( color, "Color", "Flag for coloured screen output (default: True, if in batch mode: False)" );
   DeclareOptionRef( fTransformations, "Transformations", "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for identity, decorrelation, PCA, uniform and Gaussianisation followed by decorrelation transformations" );
   DeclareOptionRef( fCorrelations, "Correlations", "Boolean flag to show the correlations in the output" );
   DeclareOptionRef( fROC, "ROC", "Boolean flag to show the ROC curve in the output" );
   DeclareOptionRef( silent, "Silent", "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory class object (default: False)" );
   DeclareOptionRef( drawProgressBar,
                     "DrawProgressBar", "Draw progress bar to display training, testing and evaluation schedule (default: True)" );
   DeclareOptionRef( fModelPersistence,
                     "ModelPersistence",
                     "Option to save the trained model in an XML file or via serialization" );

   TString analysisType("Auto");
   DeclareOptionRef( analysisType,
                     "AnalysisType", "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)" );
   AddPreDefVal(TString("Classification"));
   AddPreDefVal(TString("Regression"));
   AddPreDefVal(TString("Multiclass"));
   AddPreDefVal(TString("Auto"));

   ParseOptions();
   CheckForUnusedOptions();

   if (Verbose()) fLogger->SetMinType( kVERBOSE );
   if (fVerboseLevel.CompareTo("Debug")   == 0) fLogger->SetMinType( kDEBUG );
   if (fVerboseLevel.CompareTo("Verbose") == 0) fLogger->SetMinType( kVERBOSE );
   if (fVerboseLevel.CompareTo("Info")    == 0) fLogger->SetMinType( kINFO );

   // global settings
   gConfig().SetUseColor( color );
   gConfig().SetSilent( silent );
   gConfig().SetDrawProgressBar( drawProgressBar );

   analysisType.ToLower();
   if      ( analysisType == "classification" ) fAnalysisType = Types::kClassification;
   else if ( analysisType == "regression" )     fAnalysisType = Types::kRegression;
   else if ( analysisType == "multiclass" )     fAnalysisType = Types::kMulticlass;
   else if ( analysisType == "auto" )           fAnalysisType = Types::kNoAnalysisType;

   Greetings();
}

////////////////////////////////////////////////////////////////////////////////
/// Print welcome message.
/// Options are: kLogoWelcomeMsg, kIsometricWelcomeMsg, kLeanWelcomeMsg

void TMVA::Factory::Greetings()
{
   gTools().ROOTVersionMessage( Log() );
   gTools().TMVAWelcomeMessage( Log(), gTools().kLogoWelcomeMsg );
   gTools().TMVAVersionMessage( Log() ); Log() << Endl;
}

////////////////////////////////////////////////////////////////////////////////
/// Destructor.

TMVA::Factory::~Factory( void )
{
   std::vector<TMVA::VariableTransformBase*>::iterator trfIt = fDefaultTrfs.begin();
   for (; trfIt != fDefaultTrfs.end(); ++trfIt) delete (*trfIt);

   this->DeleteAllMethods();

   // problem with call of REGISTER_METHOD macro ...
   // ClassifierFactory::DestroyInstance();
   // Types::DestroyInstance();
   // Tools::DestroyInstance();
   // Config::DestroyInstance();
}

////////////////////////////////////////////////////////////////////////////////
/// Delete methods.

void TMVA::Factory::DeleteAllMethods( void )
{
   std::map<TString,MVector*>::iterator itrMap;

   for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap)
   {
      MVector *methods = itrMap->second;
      // delete methods
      MVector::iterator itrMethod = methods->begin();
      for (; itrMethod != methods->end(); ++itrMethod) {
         Log() << kDEBUG << "Delete method: " << (*itrMethod)->GetName() << Endl;
         delete (*itrMethod);
      }
      methods->clear();
      delete methods;
   }
}

////////////////////////////////////////////////////////////////////////////////
/// Set the verbosity of the factory.

void TMVA::Factory::SetVerbose( Bool_t v )
{
   fVerbose = v;
}

////////////////////////////////////////////////////////////////////////////////
/// Book a classifier or regression method.
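///
/// Illustrative example (the method name, title, and option string are
/// placeholders, not defaults):
/// ~~~ {.cpp}
/// factory.BookMethod(&loader, "BDT", "myBDT",
///                    "NTrees=400:BoostType=Grad:Shrinkage=0.1");
/// ~~~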

TMVA::MethodBase* TMVA::Factory::BookMethod( TMVA::DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption )
{
   if (fModelPersistence) gSystem->MakeDirectory(loader->GetName()); // create directory for DataLoader output

   TString datasetname = loader->GetName();

   if( fAnalysisType == Types::kNoAnalysisType ){
      if( loader->GetDataSetInfo().GetNClasses()==2
          && loader->GetDataSetInfo().GetClassInfo("Signal") != NULL
          && loader->GetDataSetInfo().GetClassInfo("Background") != NULL
          ){
         fAnalysisType = Types::kClassification; // default is classification
      } else if( loader->GetDataSetInfo().GetNClasses() >= 2 ){
         fAnalysisType = Types::kMulticlass;    // two or more classes, but not named "Signal" and "Background"
      } else
         Log() << kFATAL << "No analysis type for " << loader->GetDataSetInfo().GetNClasses() << " classes and "
               << loader->GetDataSetInfo().GetNTargets() << " regression targets." << Endl;
   }

   // booking via name; the names are translated into enums and the
   // corresponding overloaded BookMethod is called

   if (fMethodsMap.find(datasetname) != fMethodsMap.end())
   {
      if (GetMethod( datasetname, methodTitle ) != 0) {
         Log() << kFATAL << "Booking failed since method with title <"
               << methodTitle << "> already exists in dataset <" << loader->GetName() << ">"
               << Endl;
      }
   }

   Log() << kHEADER << "Booking method: " << gTools().Color("bold") << methodTitle
         // << gTools().Color("reset") << " DataSet Name: " << gTools().Color("bold") << loader->GetName()
         << gTools().Color("reset") << Endl << Endl;

   // interpret option string with respect to a request for boosting (i.e., BoostNum > 0)
   Int_t boostNum = 0;
   TMVA::Configurable* conf = new TMVA::Configurable( theOption );
   conf->DeclareOptionRef( boostNum = 0, "Boost_num",
                           "Number of times the classifier will be boosted" );
   conf->ParseOptions();
   delete conf;
   // this is the name of the weight file directory
   TString fileDir;
   if (fModelPersistence)
   {
      // find prefix in fWeightFileDir;
      TString prefix = gConfig().GetIONames().fWeightFileDirPrefix;
      fileDir = prefix;
      if (!prefix.IsNull())
         if (fileDir[fileDir.Length()-1] != '/') fileDir += "/";
      fileDir += loader->GetName();
      fileDir += "/" + gConfig().GetIONames().fWeightFileDir;
   }
   // initialize methods
   IMethod* im;
   if (!boostNum) {
      im = ClassifierFactory::Instance().Create(theMethodName.Data(), fJobName, methodTitle,
                                                loader->GetDataSetInfo(), theOption);
   }
   else {
      // boosted classifier, requires a specific definition, making it transparent for the user
      Log() << kDEBUG << "Boost Number is " << boostNum << " > 0: train boosted classifier" << Endl;
      im = ClassifierFactory::Instance().Create("Boost", fJobName, methodTitle, loader->GetDataSetInfo(), theOption);
      MethodBoost *methBoost = dynamic_cast<MethodBoost *>(im); // DSMTEST divided into two lines
      if (!methBoost) { // DSMTEST
         Log() << kFATAL << "Method with type kBoost cannot be cast to MethodBoost. /Factory" << Endl; // DSMTEST
         return nullptr;
      }
      if (fModelPersistence) methBoost->SetWeightFileDir(fileDir);
      methBoost->SetModelPersistence(fModelPersistence);
      methBoost->SetBoostedMethodName(theMethodName); // DSMTEST divided into two lines
      methBoost->fDataSetManager = loader->GetDataSetInfo().GetDataSetManager(); // DSMTEST
      methBoost->SetFile(fgTargetFile);
      methBoost->SetSilentFile(IsSilentFile());
   }

   MethodBase *method = dynamic_cast<MethodBase*>(im);
   if (method == 0) return 0; // could not create method

   // set fDataSetManager if MethodCategory (to enable Category to create datasetinfo objects) // DSMTEST
   if (method->GetMethodType() == Types::kCategory) { // DSMTEST
      MethodCategory *methCat = (dynamic_cast<MethodCategory*>(im)); // DSMTEST
      if (!methCat) { // DSMTEST
         Log() << kFATAL << "Method with type kCategory cannot be cast to MethodCategory. /Factory" << Endl; // DSMTEST
         return nullptr;
      }
      if (fModelPersistence) methCat->SetWeightFileDir(fileDir);
      methCat->SetModelPersistence(fModelPersistence);
      methCat->fDataSetManager = loader->GetDataSetInfo().GetDataSetManager(); // DSMTEST
      methCat->SetFile(fgTargetFile);
      methCat->SetSilentFile(IsSilentFile());
   } // DSMTEST

   if (!method->HasAnalysisType( fAnalysisType,
                                 loader->GetDataSetInfo().GetNClasses(),
                                 loader->GetDataSetInfo().GetNTargets() )) {
      Log() << kWARNING << "Method " << method->GetMethodTypeName() << " is not capable of handling ";
      if (fAnalysisType == Types::kRegression) {
         Log() << "regression with " << loader->GetDataSetInfo().GetNTargets() << " targets." << Endl;
      }
      else if (fAnalysisType == Types::kMulticlass ) {
         Log() << "multiclass classification with " << loader->GetDataSetInfo().GetNClasses() << " classes." << Endl;
      }
      else {
         Log() << "classification with " << loader->GetDataSetInfo().GetNClasses() << " classes." << Endl;
      }
      return 0;
   }

   if (fModelPersistence) method->SetWeightFileDir(fileDir);
   method->SetModelPersistence(fModelPersistence);
   method->SetAnalysisType( fAnalysisType );
   method->SetupMethod();
   method->ParseOptions();
   method->ProcessSetup();
   method->SetFile(fgTargetFile);
   method->SetSilentFile(IsSilentFile());

   // check-for-unused-options is performed; may be overridden by derived classes
   method->CheckSetup();

   if (fMethodsMap.find(datasetname) == fMethodsMap.end())
   {
      MVector *mvector = new MVector;
      fMethodsMap[datasetname] = mvector;
   }
   fMethodsMap[datasetname]->push_back( method );
   return method;
}

////////////////////////////////////////////////////////////////////////////////
/// Books an MVA method. The option configuration string is custom for each MVA;
/// the TString field "theNameAppendix" serves to define (and distinguish)
/// several instances of a given MVA, e.g., when one wants to compare the
/// performance of various configurations.
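///
/// For instance, two differently configured BDTs can be booked side by side
/// (illustrative titles and options):
/// ~~~ {.cpp}
/// factory.BookMethod(&loader, TMVA::Types::kBDT, "BDT_fast", "NTrees=100");
/// factory.BookMethod(&loader, TMVA::Types::kBDT, "BDT_deep", "NTrees=800:MaxDepth=4");
/// ~~~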

TMVA::MethodBase* TMVA::Factory::BookMethod( TMVA::DataLoader *loader, Types::EMVA theMethod, TString methodTitle, TString theOption )
{
   return BookMethod(loader, Types::Instance().GetMethodName( theMethod ), methodTitle, theOption );
}

////////////////////////////////////////////////////////////////////////////////
/// Adds an already constructed method to be managed by this factory.
///
/// \note Private.
/// \note Know what you are doing when using this method. The method that you
/// are loading could be trained already.
///

TMVA::MethodBase* TMVA::Factory::BookMethodWeightfile(DataLoader *loader, TMVA::Types::EMVA methodType, const TString &weightfile)
{
   TString datasetname = loader->GetName();
   std::string methodTypeName = std::string(Types::Instance().GetMethodName(methodType).Data());
   DataSetInfo &dsi = loader->GetDataSetInfo();

   IMethod *im = ClassifierFactory::Instance().Create(methodTypeName, dsi, weightfile );
   MethodBase *method = (dynamic_cast<MethodBase*>(im));

   if (method == nullptr) return nullptr;

   if( method->GetMethodType() == Types::kCategory ){
      Log() << kERROR << "Cannot handle category methods for now." << Endl;
   }

   TString fileDir;
   if (fModelPersistence) {
      // find prefix in fWeightFileDir;
      TString prefix = gConfig().GetIONames().fWeightFileDirPrefix;
      fileDir = prefix;
      if (!prefix.IsNull())
         if (fileDir[fileDir.Length() - 1] != '/')
            fileDir += "/";
      fileDir += loader->GetName();
      fileDir += "/" + gConfig().GetIONames().fWeightFileDir;
   }

   if (fModelPersistence) method->SetWeightFileDir(fileDir);
   method->SetModelPersistence(fModelPersistence);
   method->SetAnalysisType( fAnalysisType );
   method->SetupMethod();
   method->SetFile(fgTargetFile);
   method->SetSilentFile(IsSilentFile());

   // read weight file
   method->ReadStateFromFile();

   //method->CheckSetup();

   TString methodTitle = method->GetName();
   if (HasMethod(datasetname, methodTitle)) {
      Log() << kFATAL << "Booking failed since method with title <"
            << methodTitle << "> already exists in dataset <" << loader->GetName() << ">"
            << Endl;
   }

   Log() << kINFO << "Booked classifier \"" << method->GetMethodName()
         << "\" of type: \"" << method->GetMethodTypeName() << "\"" << Endl;

   if (fMethodsMap.count(datasetname) == 0) {
      MVector *mvector = new MVector;
      fMethodsMap[datasetname] = mvector;
   }

   fMethodsMap[datasetname]->push_back( method );

   return method;
}

////////////////////////////////////////////////////////////////////////////////
/// Returns pointer to MVA that corresponds to given method title.

TMVA::IMethod* TMVA::Factory::GetMethod(const TString& datasetname, const TString &methodTitle ) const
{
   if (fMethodsMap.find(datasetname) == fMethodsMap.end()) return 0;

   MVector *methods = fMethodsMap.find(datasetname)->second;

   MVector::const_iterator itrMethod;
   //
   for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
      MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
      if ( (mva->GetMethodName()) == methodTitle ) return mva;
   }
   return 0;
}

////////////////////////////////////////////////////////////////////////////////
/// Checks whether a given method name is defined for a given dataset.
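///
/// Illustrative use (names are placeholders):
/// ~~~ {.cpp}
/// if (factory.HasMethod("dataset", "BDT")) { /* the method is booked */ }
/// ~~~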

Bool_t TMVA::Factory::HasMethod(const TString& datasetname, const TString &methodTitle ) const
{
   if (fMethodsMap.find(datasetname) == fMethodsMap.end()) return 0;

   std::string methodName = methodTitle.Data();
   auto isEqualToMethodName = [&methodName](TMVA::IMethod * m) {
      return ( 0 == methodName.compare( m->GetName() ) );
   };

   TMVA::Factory::MVector * methods = this->fMethodsMap.at(datasetname);
   Bool_t isMethodNameExisting = std::any_of( methods->begin(), methods->end(), isEqualToMethodName);

   return isMethodNameExisting;
}

////////////////////////////////////////////////////////////////////////////////
/// Write the data information (correlation matrices and the variable
/// distributions for the configured transformations) for the given dataset
/// to the output file.

void TMVA::Factory::WriteDataInformation(DataSetInfo& fDataSetInfo)
{
   RootBaseDir()->cd();

   if (!RootBaseDir()->GetDirectory(fDataSetInfo.GetName())) RootBaseDir()->mkdir(fDataSetInfo.GetName());
   else return; // loader is already in the output file, we don't need to save it again

   RootBaseDir()->cd(fDataSetInfo.GetName());
   fDataSetInfo.GetDataSet(); // builds dataset (including calculation of correlation matrix)

   // correlation matrix of the default DS
   const TMatrixD* m(0);
   const TH2* h(0);

   if (fAnalysisType == Types::kMulticlass) {
      for (UInt_t cls = 0; cls < fDataSetInfo.GetNClasses() ; cls++) {
         m = fDataSetInfo.CorrelationMatrix(fDataSetInfo.GetClassInfo(cls)->GetName());
         h = fDataSetInfo.CreateCorrelationMatrixHist(m, TString("CorrelationMatrix")+fDataSetInfo.GetClassInfo(cls)->GetName(),
                                                         TString("Correlation Matrix (")+ fDataSetInfo.GetClassInfo(cls)->GetName() +TString(")"));
         if (h!=0) {
            h->Write();
            delete h;
         }
      }
   }
   else {
      m = fDataSetInfo.CorrelationMatrix( "Signal" );
      h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrixS", "Correlation Matrix (signal)");
      if (h!=0) {
         h->Write();
         delete h;
      }

      m = fDataSetInfo.CorrelationMatrix( "Background" );
      h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrixB", "Correlation Matrix (background)");
      if (h!=0) {
         h->Write();
         delete h;
      }

      m = fDataSetInfo.CorrelationMatrix( "Regression" );
      h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrix", "Correlation Matrix");
      if (h!=0) {
         h->Write();
         delete h;
      }
   }

   // some default transformations to evaluate
   // NOTE: all transformations are destroyed after this test
   TString processTrfs = "I"; //"I;N;D;P;U;G,D;"

   // plus some user-defined transformations
   processTrfs = fTransformations;

   // remove any trace of identity transform - if given (avoid applying it twice)
   std::vector<TMVA::TransformationHandler*> trfs;
   TransformationHandler* identityTrHandler = 0;

   std::vector<TString> trfsDef = gTools().SplitString(processTrfs,';');
   std::vector<TString>::iterator trfsDefIt = trfsDef.begin();
   for (; trfsDefIt != trfsDef.end(); ++trfsDefIt) {
      trfs.push_back(new TMVA::TransformationHandler(fDataSetInfo, "Factory"));
      TString trfS = (*trfsDefIt);

      //Log() << kINFO << Endl;
      Log() << kDEBUG << "current transformation string: '" << trfS.Data() << "'" << Endl;
      TMVA::CreateVariableTransforms( trfS,
                                      fDataSetInfo,
                                      *(trfs.back()),
                                      Log() );

      if (trfS.BeginsWith('I')) identityTrHandler = trfs.back();
   }

   const std::vector<Event*>& inputEvents = fDataSetInfo.GetDataSet()->GetEventCollection();

   // apply all transformations
   std::vector<TMVA::TransformationHandler*>::iterator trfIt = trfs.begin();

   for (; trfIt != trfs.end(); ++trfIt) {
      // setting a ROOT dir causes the variable distributions to be saved to the ROOT file
      (*trfIt)->SetRootDir(RootBaseDir()->GetDirectory(fDataSetInfo.GetName())); // every DataLoader has its own directory
      (*trfIt)->CalcTransformations(inputEvents);
   }
   if (identityTrHandler) identityTrHandler->PrintVariableRanking();

   // clean up
   for (trfIt = trfs.begin(); trfIt != trfs.end(); ++trfIt) delete *trfIt;
}

////////////////////////////////////////////////////////////////////////////////
/// Iterates through all booked methods and, for those that support parameter
/// tuning, calls "Method::Train()" for different parameter settings, keeps
/// track of the optimal configuration, and uses that one later on in the main
/// training loop.
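///
/// Example (argument values are illustrative; see
/// MethodBase::OptimizeTuningParameters for the supported figures of merit
/// and fit types):
/// ~~~ {.cpp}
/// std::map<TString, Double_t> tuned = factory.OptimizeAllMethods("ROCIntegral", "FitGA");
/// ~~~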

std::map<TString,Double_t> TMVA::Factory::OptimizeAllMethods(TString fomType, TString fitType)
{

   std::map<TString,MVector*>::iterator itrMap;
   std::map<TString,Double_t> TunedParameters;
   for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap)
   {
      MVector *methods = itrMap->second;

      MVector::iterator itrMethod;

      // iterate over methods and optimize
      for( itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod ) {
         Event::SetIsTraining(kTRUE);
         MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
         if (!mva) {
            Log() << kFATAL << "Dynamic cast to MethodBase failed" << Endl;
            return TunedParameters;
         }

         if (mva->Data()->GetNTrainingEvents() < MinNoTrainingEvents) {
            Log() << kWARNING << "Method " << mva->GetMethodName()
                  << " not trained (training tree has fewer entries ["
                  << mva->Data()->GetNTrainingEvents()
                  << "] than required [" << MinNoTrainingEvents << "])" << Endl;
            continue;
         }

         Log() << kINFO << "Optimize method: " << mva->GetMethodName() << " for "
               << (fAnalysisType == Types::kRegression ? "Regression" :
                   (fAnalysisType == Types::kMulticlass ? "Multiclass classification" : "Classification")) << Endl;

         TunedParameters = mva->OptimizeTuningParameters(fomType,fitType);
         Log() << kINFO << "Optimization of tuning parameters finished for Method: " << mva->GetName() << Endl;
      }
   }

   return TunedParameters;

}

////////////////////////////////////////////////////////////////////////////////
/// Private method to generate a ROCCurve instance for a given method.
/// Handles the conversion from TMVA ResultSet to a format the ROCCurve class
/// understands.
///
/// \note You own the returned pointer.
///

TMVA::ROCCurve *TMVA::Factory::GetROC(TMVA::DataLoader *loader, TString theMethodName, UInt_t iClass, Types::ETreeType type)
{
   return GetROC((TString)loader->GetName(), theMethodName, iClass, type);
}

////////////////////////////////////////////////////////////////////////////////
/// Private method to generate a ROCCurve instance for a given method.
/// Handles the conversion from TMVA ResultSet to a format the ROCCurve class
/// understands.
///
/// \note You own the returned pointer.
///

TMVA::ROCCurve *TMVA::Factory::GetROC(TString datasetname, TString theMethodName, UInt_t iClass, Types::ETreeType type)
{
   if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
      Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
      return nullptr;
   }

   if (!this->HasMethod(datasetname, theMethodName)) {
      Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data())
            << Endl;
      return nullptr;
   }

   std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
   if (allowedAnalysisTypes.count(this->fAnalysisType) == 0) {
      Log() << kERROR << Form("Can only generate ROC curves for analysis type kClassification and kMulticlass.")
            << Endl;
      return nullptr;
   }

   TMVA::MethodBase *method = dynamic_cast<TMVA::MethodBase *>(this->GetMethod(datasetname, theMethodName));
   TMVA::DataSet *dataset = method->Data();
   dataset->SetCurrentType(type);
   TMVA::Results *results = dataset->GetResults(theMethodName, type, this->fAnalysisType);

   UInt_t nClasses = method->DataInfo().GetNClasses();
   if (this->fAnalysisType == Types::kMulticlass && iClass >= nClasses) {
      Log() << kERROR << Form("Given class number (iClass = %i) does not exist. There are %i classes in dataset.",
                              iClass, nClasses)
            << Endl;
      return nullptr;
   }

   TMVA::ROCCurve *rocCurve = nullptr;
   if (this->fAnalysisType == Types::kClassification) {

      std::vector<Float_t> *mvaRes = dynamic_cast<ResultsClassification *>(results)->GetValueVector();
      std::vector<Bool_t> *mvaResTypes = dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
      std::vector<Float_t> mvaResWeights;

      auto eventCollection = dataset->GetEventCollection(type);
      mvaResWeights.reserve(eventCollection.size());
      for (auto ev : eventCollection) {
         mvaResWeights.push_back(ev->GetWeight());
      }

      rocCurve = new TMVA::ROCCurve(*mvaRes, *mvaResTypes, mvaResWeights);

   } else if (this->fAnalysisType == Types::kMulticlass) {
      std::vector<Float_t> mvaRes;
      std::vector<Bool_t> mvaResTypes;
      std::vector<Float_t> mvaResWeights;

      std::vector<std::vector<Float_t>> *rawMvaRes = dynamic_cast<ResultsMulticlass *>(results)->GetValueVector();

      // Vector transpose due to values being stored as
      //    [ [0, 1, 2], [0, 1, 2], ... ]
      // in ResultsMulticlass::GetValueVector.
      mvaRes.reserve(rawMvaRes->size());
      for (auto item : *rawMvaRes) {
         mvaRes.push_back(item[iClass]);
      }

      auto eventCollection = dataset->GetEventCollection(type);
      mvaResTypes.reserve(eventCollection.size());
      mvaResWeights.reserve(eventCollection.size());
      for (auto ev : eventCollection) {
         mvaResTypes.push_back(ev->GetClass() == iClass);
         mvaResWeights.push_back(ev->GetWeight());
      }

      rocCurve = new TMVA::ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
   }

   return rocCurve;
}

////////////////////////////////////////////////////////////////////////////////
/// Calculate the integral of the ROC curve, also known as the area under curve
/// (AUC), for a given method.
///
/// Argument iClass specifies the class to generate the ROC curve in a
/// multiclass setting. It is ignored for binary classification.
///
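/// Example (illustrative; assumes a booked and evaluated method titled "BDT"):
/// ~~~ {.cpp}
/// Double_t auc = factory.GetROCIntegral(&loader, "BDT");
/// ~~~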

Double_t TMVA::Factory::GetROCIntegral(TMVA::DataLoader *loader, TString theMethodName, UInt_t iClass)
{
   return GetROCIntegral((TString)loader->GetName(), theMethodName, iClass);
}

////////////////////////////////////////////////////////////////////////////////
/// Calculate the integral of the ROC curve, also known as the area under curve
/// (AUC), for a given method.
///
/// Argument iClass specifies the class to generate the ROC curve in a
/// multiclass setting. It is ignored for binary classification.
///

Double_t TMVA::Factory::GetROCIntegral(TString datasetname, TString theMethodName, UInt_t iClass)
{
   if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
      Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
      return 0;
   }

   if ( ! this->HasMethod(datasetname, theMethodName) ) {
      Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data()) << Endl;
      return 0;
   }

   std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
   if ( allowedAnalysisTypes.count(this->fAnalysisType) == 0 ) {
      Log() << kERROR << Form("Can only generate ROC integral for analysis type kClassification and kMulticlass.")
            << Endl;
      return 0;
   }

   TMVA::ROCCurve *rocCurve = GetROC(datasetname, theMethodName, iClass);
   if (!rocCurve) {
      Log() << kFATAL << Form("ROCCurve object was not created for Method = %s with Dataset = %s ",
                              theMethodName.Data(), datasetname.Data())
            << Endl;
      return 0;
   }

   Int_t npoints = TMVA::gConfig().fVariablePlotting.fNbinsXOfROCCurve;
   Double_t rocIntegral = rocCurve->GetROCIntegral(npoints);
   delete rocCurve;

   return rocIntegral;
}

////////////////////////////////////////////////////////////////////////////////
/// Argument iClass specifies the class to generate the ROC curve in a
/// multiclass setting. It is ignored for binary classification.
///
/// Returns a ROC graph for a given method, or nullptr on error.
///
/// Note: Evaluation of the given method must have been run prior to ROC
/// generation through Factory::EvaluateAllMethods.
///
/// NOTE: The ROC curve is 1 vs. all, where the given class is considered signal
/// and the others background. This is fine in binary classification, but in
/// multiclass classification the full ROC surface is an N-dimensional shape,
/// where N is the number of classes - 1.
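///
/// Example (illustrative):
/// ~~~ {.cpp}
/// TGraph *rocGraph = factory.GetROCCurve(&loader, "BDT");
/// if (rocGraph) rocGraph->Draw("AL");
/// ~~~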

TGraph* TMVA::Factory::GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles, UInt_t iClass)
{
   return GetROCCurve( (TString)loader->GetName(), theMethodName, setTitles, iClass );
}

////////////////////////////////////////////////////////////////////////////////
/// Argument iClass specifies the class to generate the ROC curve in a
/// multiclass setting. It is ignored for binary classification.
///
/// Returns a ROC graph for a given method, or nullptr on error.
///
/// Note: Evaluation of the given method must have been run prior to ROC
/// generation through Factory::EvaluateAllMethods.
///
/// NOTE: The ROC curve is 1 vs. all, where the given class is considered signal
/// and the others background. This is fine in binary classification, but in
/// multiclass classification the full ROC surface is an N-dimensional shape,
/// where N is the number of classes - 1.

TGraph* TMVA::Factory::GetROCCurve(TString datasetname, TString theMethodName, Bool_t setTitles, UInt_t iClass)
{
   if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
      Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
      return nullptr;
   }

   if ( ! this->HasMethod(datasetname, theMethodName) ) {
      Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data()) << Endl;
      return nullptr;
   }

   std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
   if ( allowedAnalysisTypes.count(this->fAnalysisType) == 0 ) {
      Log() << kERROR << Form("Can only generate ROC curves for analysis type kClassification and kMulticlass.") << Endl;
      return nullptr;
   }

   TMVA::ROCCurve *rocCurve = GetROC(datasetname, theMethodName, iClass);
   TGraph *graph = nullptr;

   if ( ! rocCurve ) {
      Log() << kFATAL << Form("ROCCurve object was not created for Method = %s with Dataset = %s ", theMethodName.Data(), datasetname.Data()) << Endl;
      return nullptr;
   }

   graph = (TGraph *)rocCurve->GetROCCurve()->Clone();
   delete rocCurve;

   if (setTitles) {
      graph->GetYaxis()->SetTitle("Background rejection (Specificity)");
      graph->GetXaxis()->SetTitle("Signal efficiency (Sensitivity)");
      graph->SetTitle(Form("Signal efficiency vs. Background rejection (%s)", theMethodName.Data()));
   }

   return graph;
}

////////////////////////////////////////////////////////////////////////////////
/// Generate a collection of graphs, for all methods for a given class. Suitable
/// for comparing method performance.
///
/// Argument iClass specifies the class to generate the ROC curve in a
/// multiclass setting. It is ignored for binary classification.
///
/// NOTE: The ROC curve is 1 vs. all, where the given class is considered signal
/// and the others background. This is fine in binary classification, but in
/// multiclass classification the full ROC surface is an N-dimensional shape,
/// where N is the number of classes - 1.
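///
/// Example (illustrative):
/// ~~~ {.cpp}
/// TMultiGraph *mg = factory.GetROCCurveAsMultiGraph(&loader, 0);
/// if (mg) mg->Draw("AL");
/// ~~~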

TMultiGraph* TMVA::Factory::GetROCCurveAsMultiGraph(DataLoader *loader, UInt_t iClass)
{
   return GetROCCurveAsMultiGraph((TString)loader->GetName(), iClass);
}

////////////////////////////////////////////////////////////////////////////////
/// Generate a collection of graphs, for all methods for a given class. Suitable
/// for comparing method performance.
///
/// Argument iClass specifies the class to generate the ROC curve in a
/// multiclass setting. It is ignored for binary classification.
///
/// NOTE: The ROC curve is 1 vs. all, where the given class is considered signal
/// and the others background. This is fine in binary classification, but in
/// multiclass classification the full ROC surface is an N-dimensional shape,
/// where N is the number of classes - 1.

TMultiGraph* TMVA::Factory::GetROCCurveAsMultiGraph(TString datasetname, UInt_t iClass)
{
   UInt_t line_color = 1;

   TMultiGraph *multigraph = new TMultiGraph();

   MVector *methods = fMethodsMap[datasetname.Data()];
   for (auto * method_raw : *methods) {
      TMVA::MethodBase *method = dynamic_cast<TMVA::MethodBase *>(method_raw);
      if (method == nullptr) { continue; }

      TString methodName = method->GetMethodName();
      UInt_t nClasses = method->DataInfo().GetNClasses();

      if ( this->fAnalysisType == Types::kMulticlass && iClass >= nClasses ) {
         Log() << kERROR << Form("Given class number (iClass = %i) does not exist. There are %i classes in dataset.", iClass, nClasses) << Endl;
         continue;
      }

      TString className = method->DataInfo().GetClassInfo(iClass)->GetName();

      TGraph *graph = this->GetROCCurve(datasetname, methodName, false, iClass);
      graph->SetTitle(methodName);

      graph->SetLineWidth(2);
      graph->SetLineColor(line_color++);
      graph->SetFillColor(10);

      multigraph->Add(graph);
   }

   if ( multigraph->GetListOfGraphs() == nullptr ) {
      Log() << kERROR << Form("No methods have class %i defined.", iClass) << Endl;
      return nullptr;
   }

   return multigraph;
}

////////////////////////////////////////////////////////////////////////////////
/// Draws ROC curves for all methods booked with the factory for a given class
/// onto a canvas.
///
/// Argument iClass specifies the class to generate the ROC curve in a
/// multiclass setting. It is ignored for binary classification.
///
/// NOTE: The ROC curve is 1 vs. all, where the given class is considered signal
/// and the others background. This is fine in binary classification, but in
/// multiclass classification the full ROC surface is an N-dimensional shape,
/// where N is the number of classes - 1.
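///
/// Example (illustrative):
/// ~~~ {.cpp}
/// TCanvas *rocCanvas = factory.GetROCCurve(&loader);
/// if (rocCanvas) rocCanvas->SaveAs("ROC.png");
/// ~~~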

TCanvas* TMVA::Factory::GetROCCurve(TMVA::DataLoader *loader, UInt_t iClass)
{
   return GetROCCurve((TString)loader->GetName(), iClass);
}

////////////////////////////////////////////////////////////////////////////////
/// Draws ROC curves for all methods booked with the factory for a given class.
///
/// Argument iClass specifies the class to generate the ROC curve in a
/// multiclass setting. It is ignored for binary classification.
///
/// NOTE: The ROC curve is 1 vs. all, where the given class is considered signal
/// and the others background. This is fine in binary classification, but in
/// multiclass classification the full ROC surface is an N-dimensional shape,
/// where N is the number of classes - 1.

TCanvas* TMVA::Factory::GetROCCurve(TString datasetname, UInt_t iClass)
{
   if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
      Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
      return 0;
   }

   TString name = Form("ROCCurve %s class %i", datasetname.Data(), iClass);
   TCanvas *canvas = new TCanvas(name, "ROC Curve", 200, 10, 700, 500);
   canvas->SetGrid();

   TMultiGraph *multigraph = this->GetROCCurveAsMultiGraph(datasetname, iClass);

   if ( multigraph ) {
      multigraph->Draw("AL");

      multigraph->GetYaxis()->SetTitle("Background rejection (Specificity)");
      multigraph->GetXaxis()->SetTitle("Signal efficiency (Sensitivity)");

      TString titleString = Form("Signal efficiency vs. Background rejection");
      if (this->fAnalysisType == Types::kMulticlass) {
         titleString = Form("%s (Class=%i)", titleString.Data(), iClass);
      }

      // Workaround for TMultiGraph not drawing title correctly.
      multigraph->GetHistogram()->SetTitle( titleString );
      multigraph->SetTitle( titleString );

      canvas->BuildLegend(0.15, 0.15, 0.35, 0.3, "MVA Method");
   }

   return canvas;
}

////////////////////////////////////////////////////////////////////////////////
/// Iterates through all booked methods and calls training.

void TMVA::Factory::TrainAllMethods()
{
   Log() << kHEADER << gTools().Color("bold") << "Train all methods" << gTools().Color("reset") << Endl;
   // iterates over all MVAs that have been booked, and calls their training methods

   // don't do anything if no method booked
   if (fMethodsMap.empty()) {
      Log() << kINFO << "...nothing found to train" << Endl;
      return;
   }

   // here the training starts
   //Log() << kINFO << " " << Endl;
   Log() << kDEBUG << "Train all methods for "
         << (fAnalysisType == Types::kRegression ? "Regression" :
             (fAnalysisType == Types::kMulticlass ? "Multiclass" : "Classification") ) << " ..." << Endl;

   std::map<TString,MVector*>::iterator itrMap;

   for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap)
   {
      MVector *methods = itrMap->second;
      MVector::iterator itrMethod;

      // iterate over methods and train
      for( itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod ) {
         Event::SetIsTraining(kTRUE);
         MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);

         if (mva == 0) continue;

         if (mva->DataInfo().GetDataSetManager()->DataInput().GetEntries() <= 1) { // 0 entries --> 0 events, 1 entry --> dynamical dataset (or one entry)
            Log() << kFATAL << "No input data for the training provided!" << Endl;
         }

         if (fAnalysisType == Types::kRegression && mva->DataInfo().GetNTargets() < 1)
            Log() << kFATAL << "You want to do regression training without specifying a target." << Endl;
         else if ( (fAnalysisType == Types::kMulticlass || fAnalysisType == Types::kClassification)
                   && mva->DataInfo().GetNClasses() < 2 )
            Log() << kFATAL << "You want to do classification training, but specified less than two classes." << Endl;

         // first print some information about the default dataset
         if (!IsSilentFile()) WriteDataInformation(mva->fDataSetInfo);

         if (mva->Data()->GetNTrainingEvents() < MinNoTrainingEvents) {
            Log() << kWARNING << "Method " << mva->GetMethodName()
                  << " not trained (training tree has fewer entries ["
                  << mva->Data()->GetNTrainingEvents()
                  << "] than required [" << MinNoTrainingEvents << "])" << Endl;
            continue;
         }

         Log() << kHEADER << "Train method: " << mva->GetMethodName() << " for "
               << (fAnalysisType == Types::kRegression ? "Regression" :
                   (fAnalysisType == Types::kMulticlass ? "Multiclass classification" : "Classification")) << Endl << Endl;
         mva->TrainMethod();
         Log() << kHEADER << "Training finished" << Endl << Endl;
      }

      if (fAnalysisType != Types::kRegression) {

         // variable ranking
         //Log() << Endl;
         Log() << kINFO << "Ranking input variables (method specific)..." << Endl;
         for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
            MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
            if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {

               // create and print ranking
               const Ranking* ranking = (*itrMethod)->CreateRanking();
               if (ranking != 0) ranking->Print();
               else Log() << kINFO << "No variable ranking supplied by classifier: "
                          << dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
            }
         }
      }

      // save training history in case we are not in silent mode
      if (!IsSilentFile()) {
         for (UInt_t i=0; i<methods->size(); i++) {
            MethodBase* m = dynamic_cast<MethodBase*>((*methods)[i]);
            if (m == 0) continue;
            m->BaseDir()->cd();
            m->fTrainHistory.SaveHistory(m->GetMethodName());
         }
      }

      // delete all methods and recreate them from weight file - this ensures that the application
      // of the methods (in TMVAClassificationApplication) is consistent with the results obtained
      // in the testing
      //Log() << Endl;
      if (fModelPersistence) {

         Log() << kHEADER << "=== Destroy and recreate all methods via weight files for testing ===" << Endl << Endl;

         if (!IsSilentFile()) RootBaseDir()->cd();

         // iterate through all booked methods
         for (UInt_t i=0; i<methods->size(); i++) {

            MethodBase *m = dynamic_cast<MethodBase *>((*methods)[i]);
            if (m == nullptr)
               continue;

            TMVA::Types::EMVA methodType = m->GetMethodType();
            TString weightfile = m->GetWeightFileName();

            // decide if .txt or .xml file should be read:
            if (READXML)
               weightfile.ReplaceAll(".txt", ".xml");

            DataSetInfo &dataSetInfo = m->DataInfo();
            TString testvarName = m->GetTestvarName();
            delete m; // itrMethod[i];

            // recreate
            m = dynamic_cast<MethodBase *>(ClassifierFactory::Instance().Create(
               Types::Instance().GetMethodName(methodType).Data(), dataSetInfo, weightfile));
            if (m->GetMethodType() == Types::kCategory) {
               MethodCategory *methCat = (dynamic_cast<MethodCategory *>(m));
               if (!methCat)
                  Log() << kFATAL << "Method with type kCategory cannot be cast to MethodCategory. /Factory" << Endl;
               else
                  methCat->fDataSetManager = m->DataInfo().GetDataSetManager();
            }
            // ToDo, Do we need to fill the DataSetManager of MethodBoost here too?

            TString wfileDir = m->DataInfo().GetName();
            wfileDir += "/" + gConfig().GetIONames().fWeightFileDir;
            m->SetWeightFileDir(wfileDir);
            m->SetModelPersistence(fModelPersistence);
            m->SetSilentFile(IsSilentFile());
            m->SetAnalysisType(fAnalysisType);
            m->SetupMethod();
            m->ReadStateFromFile();
            m->SetTestvarName(testvarName);

            // replace trained method by newly created one (from weight file) in methods vector
            (*methods)[i] = m;
         }
      }
   }
}

////////////////////////////////////////////////////////////////////////////////
/// Evaluates all booked methods on the testing data and adds the output to the
/// Results in the corresponding DataSet.
///

void TMVA::Factory::TestAllMethods()
{
   Log() << kHEADER << gTools().Color("bold") << "Test all methods" << gTools().Color("reset") << Endl;

   // don't do anything if no method booked
   if (fMethodsMap.empty()) {
      Log() << kINFO << "...nothing found to test" << Endl;
      return;
   }
   std::map<TString,MVector*>::iterator itrMap;

   for (itrMap = fMethodsMap.begin(); itrMap != fMethodsMap.end(); ++itrMap)
   {
      MVector *methods = itrMap->second;
      MVector::iterator itrMethod;

      // iterate over methods and test
      for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
         Event::SetIsTraining(kFALSE);
         MethodBase *mva = dynamic_cast<MethodBase *>(*itrMethod);
         if (mva == 0)
            continue;
         Types::EAnalysisType analysisType = mva->GetAnalysisType();
         Log() << kHEADER << "Test method: " << mva->GetMethodName() << " for "
               << (analysisType == Types::kRegression
                      ? "Regression"
                      : (analysisType == Types::kMulticlass ? "Multiclass classification" : "Classification"))
               << " performance" << Endl << Endl;
         mva->AddOutput(Types::kTesting, analysisType);
      }
   }
}

////////////////////////////////////////////////////////////////////////////////
/// Write the standalone C++ response class for the given method (or for all
/// booked methods, if no method title is given).

void TMVA::Factory::MakeClass(const TString& datasetname, const TString& methodTitle ) const
{
   if (methodTitle != "") {
      IMethod* method = GetMethod(datasetname, methodTitle);
      if (method) method->MakeClass();
      else {
         Log() << kWARNING << "<MakeClass> Could not find classifier \"" << methodTitle
               << "\" in list" << Endl;
      }
   }
   else {

      // no classifier specified, write response classes for all booked methods
      MVector *methods = fMethodsMap.find(datasetname)->second;
      MVector::const_iterator itrMethod;
      for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
         MethodBase* method = dynamic_cast<MethodBase*>(*itrMethod);
         if (method == 0) continue;
         Log() << kINFO << "Make response class for classifier: " << method->GetMethodName() << Endl;
         method->MakeClass();
      }
   }
}

////////////////////////////////////////////////////////////////////////////////
/// Print the predefined help message of the given classifier.
/// If no classifier is specified, the help messages of all booked methods are printed.

void TMVA::Factory::PrintHelpMessage(const TString& datasetname, const TString& methodTitle ) const
{
   if (methodTitle != "") {
      IMethod* method = GetMethod(datasetname, methodTitle );
      if (method) method->PrintHelpMessage();
      else {
         Log() << kWARNING << "<PrintHelpMessage> Could not find classifier \"" << methodTitle
               << "\" in list" << Endl;
      }
   }
   else {

      // no classifier specified, print all help messages
      MVector *methods = fMethodsMap.find(datasetname)->second;
      MVector::const_iterator itrMethod;
      for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
         MethodBase* method = dynamic_cast<MethodBase*>(*itrMethod);
         if (method == 0) continue;
         Log() << kINFO << "Print help message for classifier: " << method->GetMethodName() << Endl;
         method->PrintHelpMessage();
      }
   }
}

////////////////////////////////////////////////////////////////////////////////
/// Iterates over all MVA input variables and evaluates them.
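///
/// Illustrative call (books one "Variable" method per input variable):
/// ~~~ {.cpp}
/// factory.EvaluateAllVariables(&loader, "");
/// ~~~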

void TMVA::Factory::EvaluateAllVariables(DataLoader *loader, TString options )
{
   Log() << kINFO << "Evaluating all variables..." << Endl;
   Event::SetIsTraining(kFALSE);

   for (UInt_t i=0; i<loader->GetDataSetInfo().GetNVariables(); i++) {
      TString s = loader->GetDataSetInfo().GetVariableInfo(i).GetLabel();
      if (options.Contains("V")) s += ":V";
      this->BookMethod(loader, "Variable", s );
   }
}
1346
1347////////////////////////////////////////////////////////////////////////////////
1348/// Iterates over all MVAs that have been booked, and calls their evaluation methods.
1349
1351{
1352 Log() << kHEADER << gTools().Color("bold") << "Evaluate all methods" << gTools().Color("reset") << Endl;
1353
1354 // don't do anything if no method booked
1355 if (fMethodsMap.empty()) {
1356 Log() << kINFO << "...nothing found to evaluate" << Endl;
1357 return;
1358 }
1359 std::map<TString,MVector*>::iterator itrMap;
1360
1361 for(itrMap = fMethodsMap.begin();itrMap != fMethodsMap.end();++itrMap)
1362 {
1363 MVector *methods=itrMap->second;
1364
1365 // -----------------------------------------------------------------------
1366 // First part of evaluation process
1367 // --> compute efficiencies, and other separation estimators
1368 // -----------------------------------------------------------------------
1369
1370 // although equal, we now want to separate the output for the variables
1371 // and the real methods
1372 Int_t isel; // will be 0 for a Method; 1 for a Variable
1373 Int_t nmeth_used[2] = {0,0}; // 0 Method; 1 Variable
1374
1375 std::vector<std::vector<TString> > mname(2);
1376 std::vector<std::vector<Double_t> > sig(2), sep(2), roc(2);
1377 std::vector<std::vector<Double_t> > eff01(2), eff10(2), eff30(2), effArea(2);
1378 std::vector<std::vector<Double_t> > eff01err(2), eff10err(2), eff30err(2);
1379 std::vector<std::vector<Double_t> > trainEff01(2), trainEff10(2), trainEff30(2);
1380
1381 std::vector<std::vector<Float_t> > multiclass_testEff;
1382 std::vector<std::vector<Float_t> > multiclass_trainEff;
1383 std::vector<std::vector<Float_t> > multiclass_testPur;
1384 std::vector<std::vector<Float_t> > multiclass_trainPur;
1385
1386 std::vector<std::vector<Float_t> > train_history;
1387
1388 // Multiclass confusion matrices.
1389 std::vector<TMatrixD> multiclass_trainConfusionEffB01;
1390 std::vector<TMatrixD> multiclass_trainConfusionEffB10;
1391 std::vector<TMatrixD> multiclass_trainConfusionEffB30;
1392 std::vector<TMatrixD> multiclass_testConfusionEffB01;
1393 std::vector<TMatrixD> multiclass_testConfusionEffB10;
1394 std::vector<TMatrixD> multiclass_testConfusionEffB30;
1395
1396 std::vector<std::vector<Double_t> > biastrain(1); // "bias" of the regression on the training data
1397 std::vector<std::vector<Double_t> > biastest(1); // "bias" of the regression on test data
1398 std::vector<std::vector<Double_t> > devtrain(1); // "dev" of the regression on the training data
1399 std::vector<std::vector<Double_t> > devtest(1); // "dev" of the regression on test data
1400 std::vector<std::vector<Double_t> > rmstrain(1); // "rms" of the regression on the training data
1401 std::vector<std::vector<Double_t> > rmstest(1); // "rms" of the regression on test data
1402 std::vector<std::vector<Double_t> > minftrain(1); // "minf" of the regression on the training data
1403 std::vector<std::vector<Double_t> > minftest(1); // "minf" of the regression on test data
1404 std::vector<std::vector<Double_t> > rhotrain(1); // correlation of the regression on the training data
1405 std::vector<std::vector<Double_t> > rhotest(1); // correlation of the regression on test data
1406
1407 // same as above but for 'truncated' quantities (computed for events within 2sigma of RMS)
1408 std::vector<std::vector<Double_t> > biastrainT(1);
1409 std::vector<std::vector<Double_t> > biastestT(1);
1410 std::vector<std::vector<Double_t> > devtrainT(1);
1411 std::vector<std::vector<Double_t> > devtestT(1);
1412 std::vector<std::vector<Double_t> > rmstrainT(1);
1413 std::vector<std::vector<Double_t> > rmstestT(1);
1414 std::vector<std::vector<Double_t> > minftrainT(1);
1415 std::vector<std::vector<Double_t> > minftestT(1);
1416
1417 // following vector contains all methods - with the exception of Cuts, which are special
1418 MVector methodsNoCuts;
1419
1420 Bool_t doRegression = kFALSE;
1421 Bool_t doMulticlass = kFALSE;
1422
1423 // iterate over methods and evaluate
1424 for (MVector::iterator itrMethod =methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1426 MethodBase* theMethod = dynamic_cast<MethodBase*>(*itrMethod);
1427 if(theMethod==0) continue;
1428 theMethod->SetFile(fgTargetFile);
1429 theMethod->SetSilentFile(IsSilentFile());
1430 if (theMethod->GetMethodType() != Types::kCuts) methodsNoCuts.push_back( *itrMethod );
1431
1432 if (theMethod->DoRegression()) {
1433 doRegression = kTRUE;
1434
1435 Log() << kINFO << "Evaluate regression method: " << theMethod->GetMethodName() << Endl;
1436 Double_t bias, dev, rms, mInf;
1437 Double_t biasT, devT, rmsT, mInfT;
1438 Double_t rho;
1439
1440 Log() << kINFO << "TestRegression (testing)" << Endl;
1441 theMethod->TestRegression( bias, biasT, dev, devT, rms, rmsT, mInf, mInfT, rho, TMVA::Types::kTesting );
1442 biastest[0] .push_back( bias );
1443 devtest[0] .push_back( dev );
1444 rmstest[0] .push_back( rms );
1445 minftest[0] .push_back( mInf );
1446 rhotest[0] .push_back( rho );
1447 biastestT[0] .push_back( biasT );
1448 devtestT[0] .push_back( devT );
1449 rmstestT[0] .push_back( rmsT );
1450 minftestT[0] .push_back( mInfT );
1451
1452 Log() << kINFO << "TestRegression (training)" << Endl;
1453 theMethod->TestRegression( bias, biasT, dev, devT, rms, rmsT, mInf, mInfT, rho, TMVA::Types::kTraining );
1454 biastrain[0] .push_back( bias );
1455 devtrain[0] .push_back( dev );
1456 rmstrain[0] .push_back( rms );
1457 minftrain[0] .push_back( mInf );
1458 rhotrain[0] .push_back( rho );
1459 biastrainT[0].push_back( biasT );
1460 devtrainT[0] .push_back( devT );
1461 rmstrainT[0] .push_back( rmsT );
1462 minftrainT[0].push_back( mInfT );
1463
1464 mname[0].push_back( theMethod->GetMethodName() );
1465 nmeth_used[0]++;
1466 if (!IsSilentFile()) {
1467 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
1468 theMethod->WriteEvaluationHistosToFile(Types::kTesting);
1469 theMethod->WriteEvaluationHistosToFile(Types::kTraining);
1470 }
1471 } else if (theMethod->DoMulticlass()) {
1472 // ====================================================================
1473 // === Multiclass evaluation
1474 // ====================================================================
1475 doMulticlass = kTRUE;
1476 Log() << kINFO << "Evaluate multiclass classification method: " << theMethod->GetMethodName() << Endl;
1477
1478 // This part uses a genetic algorithm to evaluate the optimal sig eff * sig pur;
1479 // the search is computationally expensive, which is why it is disabled for now.
1480 // Find approximate optimal working point w.r.t. signalEfficiency * signalPurity.
1481 // theMethod->TestMulticlass(); // This is where the actual GA calc is done
1482 // multiclass_testEff.push_back(theMethod->GetMulticlassEfficiency(multiclass_testPur));
1483
1484 theMethod->TestMulticlass();
1485
1486 // Confusion matrix at three background efficiency levels
1487 multiclass_trainConfusionEffB01.push_back(theMethod->GetMulticlassConfusionMatrix(0.01, Types::kTraining));
1488 multiclass_trainConfusionEffB10.push_back(theMethod->GetMulticlassConfusionMatrix(0.10, Types::kTraining));
1489 multiclass_trainConfusionEffB30.push_back(theMethod->GetMulticlassConfusionMatrix(0.30, Types::kTraining));
1490
1491 multiclass_testConfusionEffB01.push_back(theMethod->GetMulticlassConfusionMatrix(0.01, Types::kTesting));
1492 multiclass_testConfusionEffB10.push_back(theMethod->GetMulticlassConfusionMatrix(0.10, Types::kTesting));
1493 multiclass_testConfusionEffB30.push_back(theMethod->GetMulticlassConfusionMatrix(0.30, Types::kTesting));
1494
1495 if (!IsSilentFile()) {
1496 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
1497 theMethod->WriteEvaluationHistosToFile(Types::kTesting);
1498 theMethod->WriteEvaluationHistosToFile(Types::kTraining);
1499 }
1500
1501 nmeth_used[0]++;
1502 mname[0].push_back(theMethod->GetMethodName());
1503 } else {
1504
1505 Log() << kHEADER << "Evaluate classifier: " << theMethod->GetMethodName() << Endl << Endl;
1506 isel = (theMethod->GetMethodTypeName().Contains("Variable")) ? 1 : 0;
1507
1508 // perform the evaluation
1509 theMethod->TestClassification();
1510
1511 // evaluate the classifier
1512 mname[isel].push_back(theMethod->GetMethodName());
1513 sig[isel].push_back(theMethod->GetSignificance());
1514 sep[isel].push_back(theMethod->GetSeparation());
1515 roc[isel].push_back(theMethod->GetROCIntegral());
1516
1517 Double_t err;
1518 eff01[isel].push_back(theMethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err));
1519 eff01err[isel].push_back(err);
1520 eff10[isel].push_back(theMethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err));
1521 eff10err[isel].push_back(err);
1522 eff30[isel].push_back(theMethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err));
1523 eff30err[isel].push_back(err);
1524 effArea[isel].push_back(theMethod->GetEfficiency("", Types::kTesting, err)); // computes the area (average)
1525
1526 trainEff01[isel].push_back(theMethod->GetTrainingEfficiency("Efficiency:0.01")); // the first pass takes longer
1527 trainEff10[isel].push_back(theMethod->GetTrainingEfficiency("Efficiency:0.10"));
1528 trainEff30[isel].push_back(theMethod->GetTrainingEfficiency("Efficiency:0.30"));
1529
1530 nmeth_used[isel]++;
1531
1532 if (!IsSilentFile()) {
1533 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
1534 theMethod->WriteEvaluationHistosToFile(Types::kTesting);
1535 theMethod->WriteEvaluationHistosToFile(Types::kTraining);
1536 }
1537 }
1538 }
1539 if (doRegression) {
1540
1541 std::vector<TString> vtemps = mname[0];
1542 std::vector< std::vector<Double_t> > vtmp;
1543 vtmp.push_back( devtest[0] ); // this is the vector that is ranked
1544 vtmp.push_back( devtrain[0] );
1545 vtmp.push_back( biastest[0] );
1546 vtmp.push_back( biastrain[0] );
1547 vtmp.push_back( rmstest[0] );
1548 vtmp.push_back( rmstrain[0] );
1549 vtmp.push_back( minftest[0] );
1550 vtmp.push_back( minftrain[0] );
1551 vtmp.push_back( rhotest[0] );
1552 vtmp.push_back( rhotrain[0] );
1553 vtmp.push_back( devtestT[0] );
1554 vtmp.push_back( devtrainT[0] );
1555 vtmp.push_back( biastestT[0] );
1556 vtmp.push_back( biastrainT[0]);
1557 vtmp.push_back( rmstestT[0] );
1558 vtmp.push_back( rmstrainT[0] );
1559 vtmp.push_back( minftestT[0] );
1560 vtmp.push_back( minftrainT[0]);
1561 gTools().UsefulSortAscending( vtmp, &vtemps );
1562 mname[0] = vtemps;
1563 devtest[0] = vtmp[0];
1564 devtrain[0] = vtmp[1];
1565 biastest[0] = vtmp[2];
1566 biastrain[0] = vtmp[3];
1567 rmstest[0] = vtmp[4];
1568 rmstrain[0] = vtmp[5];
1569 minftest[0] = vtmp[6];
1570 minftrain[0] = vtmp[7];
1571 rhotest[0] = vtmp[8];
1572 rhotrain[0] = vtmp[9];
1573 devtestT[0] = vtmp[10];
1574 devtrainT[0] = vtmp[11];
1575 biastestT[0] = vtmp[12];
1576 biastrainT[0] = vtmp[13];
1577 rmstestT[0] = vtmp[14];
1578 rmstrainT[0] = vtmp[15];
1579 minftestT[0] = vtmp[16];
1580 minftrainT[0] = vtmp[17];
1581 } else if (doMulticlass) {
1582 // TODO: fill in something meaningful
1583 // If there is some ranking of methods to be done it should be done here.
1584 // However, this is not so easy to define for multiclass so it is left out for now.
1585
1586 }
1587 else {
1588 // now sort the methods according to the best 'efficiency area' (the first vector pushed below is the ranking key)
1589 for (Int_t k=0; k<2; k++) {
1590 std::vector< std::vector<Double_t> > vtemp;
1591 vtemp.push_back( effArea[k] ); // this is the vector that is ranked
1592 vtemp.push_back( eff10[k] );
1593 vtemp.push_back( eff01[k] );
1594 vtemp.push_back( eff30[k] );
1595 vtemp.push_back( eff10err[k] );
1596 vtemp.push_back( eff01err[k] );
1597 vtemp.push_back( eff30err[k] );
1598 vtemp.push_back( trainEff10[k] );
1599 vtemp.push_back( trainEff01[k] );
1600 vtemp.push_back( trainEff30[k] );
1601 vtemp.push_back( sig[k] );
1602 vtemp.push_back( sep[k] );
1603 vtemp.push_back( roc[k] );
1604 std::vector<TString> vtemps = mname[k];
1605 gTools().UsefulSortDescending( vtemp, &vtemps );
1606 effArea[k] = vtemp[0];
1607 eff10[k] = vtemp[1];
1608 eff01[k] = vtemp[2];
1609 eff30[k] = vtemp[3];
1610 eff10err[k] = vtemp[4];
1611 eff01err[k] = vtemp[5];
1612 eff30err[k] = vtemp[6];
1613 trainEff10[k] = vtemp[7];
1614 trainEff01[k] = vtemp[8];
1615 trainEff30[k] = vtemp[9];
1616 sig[k] = vtemp[10];
1617 sep[k] = vtemp[11];
1618 roc[k] = vtemp[12];
1619 mname[k] = vtemps;
1620 }
1621 }
1622
1623 // -----------------------------------------------------------------------
1624 // Second part of evaluation process
1625 // --> compute correlations among MVAs
1626 // --> compute correlations between input variables and MVA (determines importance)
1627 // --> count overlaps
1628 // -----------------------------------------------------------------------
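// As an illustration of the quantities computed below (notation follows the code):
// for methods i and j with per-event outputs x_i, x_j and signal-reference cuts
// c_i, c_j, the "overlap" matrix counts the fraction of test events on which both
// methods fall on the same side of their respective cuts,
//
//    overlap(i,j) = (1/N) * #{ events : (x_i - c_i) * (x_j - c_j) > 0 },
//
// filled and normalised separately for signal and background events.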
1629 if(fCorrelations)
1630 {
1631 const Int_t nmeth = methodsNoCuts.size();
1632 MethodBase* method = dynamic_cast<MethodBase*>((*methods)[0]);
1633 const Int_t nvar = method->fDataSetInfo.GetNVariables();
1634 if (!doRegression && !doMulticlass ) {
1635
1636 if (nmeth > 0) {
1637
1638 // needed for correlations
1639 Double_t *dvec = new Double_t[nmeth+nvar];
1640 std::vector<Double_t> rvec;
1641
1642 // for correlations
1643 TPrincipal* tpSig = new TPrincipal( nmeth+nvar, "" );
1644 TPrincipal* tpBkg = new TPrincipal( nmeth+nvar, "" );
1645
1646 // set required tree branch references
1647 Int_t ivar = 0;
1648 std::vector<TString>* theVars = new std::vector<TString>;
1649 std::vector<ResultsClassification*> mvaRes;
1650 for (MVector::iterator itrMethod = methodsNoCuts.begin(); itrMethod != methodsNoCuts.end(); ++itrMethod, ++ivar) {
1651 MethodBase* m = dynamic_cast<MethodBase*>(*itrMethod);
1652 if(m==0) continue;
1653 theVars->push_back( m->GetTestvarName() );
1654 rvec.push_back( m->GetSignalReferenceCut() );
1655 theVars->back().ReplaceAll( "MVA_", "" );
1656 mvaRes.push_back( dynamic_cast<ResultsClassification*>( m->Data()->GetResults( m->GetMethodName(),
1657 Types::kTesting,
1658 Types::kMaxAnalysisType) ) );
1659 }
1660
1661 // for overlap study
1662 TMatrixD* overlapS = new TMatrixD( nmeth, nmeth );
1663 TMatrixD* overlapB = new TMatrixD( nmeth, nmeth );
1664 (*overlapS) *= 0; // init...
1665 (*overlapB) *= 0; // init...
1666
1667 // loop over test tree
1668 DataSet* defDs = method->fDataSetInfo.GetDataSet();
1669 defDs->SetCurrentType(Types::kTesting);
1670 for (Int_t ievt=0; ievt<defDs->GetNEvents(); ievt++) {
1671 const Event* ev = defDs->GetEvent(ievt);
1672
1673 // for correlations
1674 TMatrixD* theMat = 0;
1675 for (Int_t im=0; im<nmeth; im++) {
1676 // check for NaN value
1677 Double_t retval = (Double_t)(*mvaRes[im])[ievt][0];
1678 if (TMath::IsNaN(retval)) {
1679 Log() << kWARNING << "Found NaN return value in event: " << ievt
1680 << " for method \"" << methodsNoCuts[im]->GetName() << "\"" << Endl;
1681 dvec[im] = 0;
1682 }
1683 else dvec[im] = retval;
1684 }
1685 for (Int_t iv=0; iv<nvar; iv++) dvec[iv+nmeth] = (Double_t)ev->GetValue(iv);
1686 if (method->fDataSetInfo.IsSignal(ev)) { tpSig->AddRow( dvec ); theMat = overlapS; }
1687 else { tpBkg->AddRow( dvec ); theMat = overlapB; }
1688
1689 // count overlaps
1690 for (Int_t im=0; im<nmeth; im++) {
1691 for (Int_t jm=im; jm<nmeth; jm++) {
1692 if ((dvec[im] - rvec[im])*(dvec[jm] - rvec[jm]) > 0) {
1693 (*theMat)(im,jm)++;
1694 if (im != jm) (*theMat)(jm,im)++;
1695 }
1696 }
1697 }
1698 }
1699
1700 // renormalise overlap matrix
1701 (*overlapS) *= (1.0/defDs->GetNEvtSigTest()); // normalise to the number of signal test events
1702 (*overlapB) *= (1.0/defDs->GetNEvtBkgdTest()); // normalise to the number of background test events
1703
1704 tpSig->MakePrincipals();
1705 tpBkg->MakePrincipals();
1706
1707 const TMatrixD* covMatS = tpSig->GetCovarianceMatrix();
1708 const TMatrixD* covMatB = tpBkg->GetCovarianceMatrix();
1709
1710 const TMatrixD* corrMatS = gTools().GetCorrelationMatrix( covMatS );
1711 const TMatrixD* corrMatB = gTools().GetCorrelationMatrix( covMatB );
1712
1713 // print correlation matrices
1714 if (corrMatS != 0 && corrMatB != 0) {
1715
1716 // extract MVA matrix
1717 TMatrixD mvaMatS(nmeth,nmeth);
1718 TMatrixD mvaMatB(nmeth,nmeth);
1719 for (Int_t im=0; im<nmeth; im++) {
1720 for (Int_t jm=0; jm<nmeth; jm++) {
1721 mvaMatS(im,jm) = (*corrMatS)(im,jm);
1722 mvaMatB(im,jm) = (*corrMatB)(im,jm);
1723 }
1724 }
1725
1726 // extract the variable-to-MVA correlation matrix
1727 std::vector<TString> theInputVars;
1728 TMatrixD varmvaMatS(nvar,nmeth);
1729 TMatrixD varmvaMatB(nvar,nmeth);
1730 for (Int_t iv=0; iv<nvar; iv++) {
1731 theInputVars.push_back( method->fDataSetInfo.GetVariableInfo( iv ).GetLabel() );
1732 for (Int_t jm=0; jm<nmeth; jm++) {
1733 varmvaMatS(iv,jm) = (*corrMatS)(nmeth+iv,jm);
1734 varmvaMatB(iv,jm) = (*corrMatB)(nmeth+iv,jm);
1735 }
1736 }
1737
1738 if (nmeth > 1) {
1739 Log() << kINFO << Endl;
1740 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Inter-MVA correlation matrix (signal):" << Endl;
1741 gTools().FormattedOutput( mvaMatS, *theVars, Log() );
1742 Log() << kINFO << Endl;
1743
1744 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Inter-MVA correlation matrix (background):" << Endl;
1745 gTools().FormattedOutput( mvaMatB, *theVars, Log() );
1746 Log() << kINFO << Endl;
1747 }
1748
1749 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Correlations between input variables and MVA response (signal):" << Endl;
1750 gTools().FormattedOutput( varmvaMatS, theInputVars, *theVars, Log() );
1751 Log() << kINFO << Endl;
1752
1753 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Correlations between input variables and MVA response (background):" << Endl;
1754 gTools().FormattedOutput( varmvaMatB, theInputVars, *theVars, Log() );
1755 Log() << kINFO << Endl;
1756 }
1757 else Log() << kWARNING <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "<TestAllMethods> cannot compute correlation matrices" << Endl;
1758
1759 // print overlap matrices
1760 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "The following \"overlap\" matrices contain the fraction of events for which " << Endl;
1761 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "the MVAs 'i' and 'j' have returned conforming answers about \"signal-likeness\"" << Endl;
1762 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "An event is signal-like if its MVA output exceeds the following value:" << Endl;
1763 gTools().FormattedOutput( rvec, *theVars, "Method" , "Cut value", Log() );
1764 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "which correspond to the working point: eff(signal) = 1 - eff(background)" << Endl;
1765
1766 // give notice that cut method has been excluded from this test
1767 if (nmeth != (Int_t)methods->size())
1768 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Note: no correlations and overlap with cut method are provided at present" << Endl;
1769
1770 if (nmeth > 1) {
1771 Log() << kINFO << Endl;
1772 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Inter-MVA overlap matrix (signal):" << Endl;
1773 gTools().FormattedOutput( *overlapS, *theVars, Log() );
1774 Log() << kINFO << Endl;
1775
1776 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Inter-MVA overlap matrix (background):" << Endl;
1777 gTools().FormattedOutput( *overlapB, *theVars, Log() );
1778 }
1779
1780 // cleanup
1781 delete tpSig;
1782 delete tpBkg;
1783 delete corrMatS;
1784 delete corrMatB;
1785 delete theVars;
1786 delete overlapS;
1787 delete overlapB;
1788 delete [] dvec;
1789 }
1790 }
1791 }
1792 // -----------------------------------------------------------------------
1793 // Third part of evaluation process
1794 // --> output
1795 // -----------------------------------------------------------------------
1796
1797 if (doRegression) {
1798
1799 Log() << kINFO << Endl;
1800 TString hLine = "--------------------------------------------------------------------------------------------------";
1801 Log() << kINFO << "Evaluation results ranked by smallest RMS on test sample:" << Endl;
1802 Log() << kINFO << "(\"Bias\" denotes the mean deviation of the regression from the true target." << Endl;
1803 Log() << kINFO << " \"MutInf\" is the \"Mutual Information\" between regression and target." << Endl;
1804 Log() << kINFO << " Indicated by \"_T\" are the corresponding \"truncated\" quantities ob-" << Endl;
1805 Log() << kINFO << " tained when removing events deviating more than 2sigma from average.)" << Endl;
1806 Log() << kINFO << hLine << Endl;
1807 Log() << kINFO << "DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T" << Endl;
1808 Log() << kINFO << hLine << Endl;
1809
1810 for (Int_t i=0; i<nmeth_used[0]; i++) {
1811 MethodBase* theMethod = dynamic_cast<MethodBase*>((*methods)[i]);
1812 if(theMethod==0) continue;
1813
1814 Log() << kINFO << Form("%-20s %-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f",
1815 theMethod->fDataSetInfo.GetName(),
1816 (const char*)mname[0][i],
1817 biastest[0][i], biastestT[0][i],
1818 rmstest[0][i], rmstestT[0][i],
1819 minftest[0][i], minftestT[0][i] )
1820 << Endl;
1821 }
1822 Log() << kINFO << hLine << Endl;
1823 Log() << kINFO << Endl;
1824 Log() << kINFO << "Evaluation results ranked by smallest RMS on training sample:" << Endl;
1825 Log() << kINFO << "(overtraining check)" << Endl;
1826 Log() << kINFO << hLine << Endl;
1827 Log() << kINFO << "DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T" << Endl;
1828 Log() << kINFO << hLine << Endl;
1829
1830 for (Int_t i=0; i<nmeth_used[0]; i++) {
1831 MethodBase* theMethod = dynamic_cast<MethodBase*>((*methods)[i]);
1832 if(theMethod==0) continue;
1833 Log() << kINFO << Form("%-20s %-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f",
1834 theMethod->fDataSetInfo.GetName(),
1835 (const char*)mname[0][i],
1836 biastrain[0][i], biastrainT[0][i],
1837 rmstrain[0][i], rmstrainT[0][i],
1838 minftrain[0][i], minftrainT[0][i] )
1839 << Endl;
1840 }
1841 Log() << kINFO << hLine << Endl;
1842 Log() << kINFO << Endl;
1843 } else if (doMulticlass) {
1844 // ====================================================================
1845 // === Multiclass Output
1846 // ====================================================================
1847
1848 TString hLine =
1849 "-------------------------------------------------------------------------------------------------------";
1850
1851 // This part uses a genetic algorithm to evaluate the optimal sig eff * sig pur;
1852 // the search is computationally expensive, which is why it is disabled for now.
1853 //
1854 // // --- Achievable signal efficiency * signal purity
1855 // // --------------------------------------------------------------------
1856 // Log() << kINFO << Endl;
1857 // Log() << kINFO << "Evaluation results ranked by best signal efficiency times signal purity " << Endl;
1858 // Log() << kINFO << hLine << Endl;
1859
1860 // // iterate over methods and evaluate
1861 // for (MVector::iterator itrMethod = methods->begin(); itrMethod != methods->end(); itrMethod++) {
1862 // MethodBase *theMethod = dynamic_cast<MethodBase *>(*itrMethod);
1863 // if (theMethod == 0) {
1864 // continue;
1865 // }
1866
1867 // TString header = "DataSet Name MVA Method ";
1868 // for (UInt_t icls = 0; icls < theMethod->fDataSetInfo.GetNClasses(); ++icls) {
1869 // header += Form("%-12s ", theMethod->fDataSetInfo.GetClassInfo(icls)->GetName());
1870 // }
1871
1872 // Log() << kINFO << header << Endl;
1873 // Log() << kINFO << hLine << Endl;
1874 // for (Int_t i = 0; i < nmeth_used[0]; i++) {
1875 // TString res = Form("[%-14s] %-15s", theMethod->fDataSetInfo.GetName(), (const char *)mname[0][i]);
1876 // for (UInt_t icls = 0; icls < theMethod->fDataSetInfo.GetNClasses(); ++icls) {
1877 // res += Form("%#1.3f ", (multiclass_testEff[i][icls]) * (multiclass_testPur[i][icls]));
1878 // }
1879 // Log() << kINFO << res << Endl;
1880 // }
1881
1882 // Log() << kINFO << hLine << Endl;
1883 // Log() << kINFO << Endl;
1884 // }
1885
1886 // --- 1 vs Rest ROC AUC, signal efficiency @ given background efficiency
1887 // --------------------------------------------------------------------
1888 TString header1 = Form("%-15s%-15s%-15s%-15s%-15s%-15s", "Dataset", "MVA Method", "ROC AUC", "Sig eff@B=0.01",
1889 "Sig eff@B=0.10", "Sig eff@B=0.30");
1890 TString header2 = Form("%-15s%-15s%-15s%-15s%-15s%-15s", "Name:", "/ Class:", "test (train)", "test (train)",
1891 "test (train)", "test (train)");
1892 Log() << kINFO << Endl;
1893 Log() << kINFO << "1-vs-rest performance metrics per class" << Endl;
1894 Log() << kINFO << hLine << Endl;
1895 Log() << kINFO << Endl;
1896 Log() << kINFO << "Considers the listed class as signal and the other classes" << Endl;
1897 Log() << kINFO << "as background, reporting the resulting binary performance." << Endl;
1898 Log() << kINFO << "A score of 0.820 (0.850) means 0.820 was achieved on the" << Endl;
1899 Log() << kINFO << "test set and 0.850 on the training set." << Endl;
1900
1901 Log() << kINFO << Endl;
1902 Log() << kINFO << header1 << Endl;
1903 Log() << kINFO << header2 << Endl;
1904 for (Int_t k = 0; k < 2; k++) {
1905 for (Int_t i = 0; i < nmeth_used[k]; i++) {
1906 if (k == 1) {
1907 mname[k][i].ReplaceAll("Variable_", "");
1908 }
1909
1910 const TString datasetName = itrMap->first;
1911 const TString mvaName = mname[k][i];
1912
1913 MethodBase *theMethod = dynamic_cast<MethodBase *>(GetMethod(datasetName, mvaName));
1914 if (theMethod == 0) {
1915 continue;
1916 }
1917
1918 Log() << kINFO << Endl;
1919 TString row = Form("%-15s%-15s", datasetName.Data(), mvaName.Data());
1920 Log() << kINFO << row << Endl;
1921 Log() << kINFO << "------------------------------" << Endl;
1922
1923 UInt_t numClasses = theMethod->fDataSetInfo.GetNClasses();
1924 for (UInt_t iClass = 0; iClass < numClasses; ++iClass) {
1925
1926 ROCCurve *rocCurveTrain = GetROC(datasetName, mvaName, iClass, Types::kTraining);
1927 ROCCurve *rocCurveTest = GetROC(datasetName, mvaName, iClass, Types::kTesting);
1928
1929 const TString className = theMethod->DataInfo().GetClassInfo(iClass)->GetName();
1930 const Double_t rocaucTrain = rocCurveTrain->GetROCIntegral();
1931 const Double_t effB01Train = rocCurveTrain->GetEffSForEffB(0.01);
1932 const Double_t effB10Train = rocCurveTrain->GetEffSForEffB(0.10);
1933 const Double_t effB30Train = rocCurveTrain->GetEffSForEffB(0.30);
1934 const Double_t rocaucTest = rocCurveTest->GetROCIntegral();
1935 const Double_t effB01Test = rocCurveTest->GetEffSForEffB(0.01);
1936 const Double_t effB10Test = rocCurveTest->GetEffSForEffB(0.10);
1937 const Double_t effB30Test = rocCurveTest->GetEffSForEffB(0.30);
1938 const TString rocaucCmp = Form("%5.3f (%5.3f)", rocaucTest, rocaucTrain);
1939 const TString effB01Cmp = Form("%5.3f (%5.3f)", effB01Test, effB01Train);
1940 const TString effB10Cmp = Form("%5.3f (%5.3f)", effB10Test, effB10Train);
1941 const TString effB30Cmp = Form("%5.3f (%5.3f)", effB30Test, effB30Train);
1942 row = Form("%-15s%-15s%-15s%-15s%-15s%-15s", "", className.Data(), rocaucCmp.Data(), effB01Cmp.Data(),
1943 effB10Cmp.Data(), effB30Cmp.Data());
1944 Log() << kINFO << row << Endl;
1945
1946 delete rocCurveTrain;
1947 delete rocCurveTest;
1948 }
1949 }
1950 }
1951 Log() << kINFO << Endl;
1952 Log() << kINFO << hLine << Endl;
1953 Log() << kINFO << Endl;
1954
1955 // --- Confusion matrices
1956 // --------------------------------------------------------------------
1957 auto printMatrix = [](TMatrixD const &matTesting, TMatrixD const &matTraining, std::vector<TString> classnames,
1958 UInt_t numClasses, MsgLogger &stream) {
1959 // assert (classLabelWidth >= valueLabelWidth + 2)
1960 // if (...) {Log() << kWARNING << "..." << Endl; }
1961
1962 // TODO: Ensure matrices are same size.
1963
1964 TString header = Form(" %-14s", " ");
1965 TString headerInfo = Form(" %-14s", " ");
1967 for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
1968 header += Form(" %-14s", classnames[iCol].Data());
1969 headerInfo += Form(" %-14s", " test (train)");
1970 }
1971 stream << kINFO << header << Endl;
1972 stream << kINFO << headerInfo << Endl;
1973
1974 for (UInt_t iRow = 0; iRow < numClasses; ++iRow) {
1975 stream << kINFO << Form(" %-14s", classnames[iRow].Data());
1976
1977 for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
1978 if (iCol == iRow) {
1979 stream << kINFO << Form(" %-14s", "-");
1980 } else {
1981 Double_t trainValue = matTraining[iRow][iCol];
1982 Double_t testValue = matTesting[iRow][iCol];
1983 TString entry = Form("%-5.3f (%-5.3f)", testValue, trainValue);
1984 stream << kINFO << Form(" %-14s", entry.Data());
1985 }
1986 }
1987 stream << kINFO << Endl;
1988 }
1989 };
1990
1991 Log() << kINFO << Endl;
1992 Log() << kINFO << "Confusion matrices for all methods" << Endl;
1993 Log() << kINFO << hLine << Endl;
1994 Log() << kINFO << Endl;
1995 Log() << kINFO << "Does a binary comparison between the two classes given by a " << Endl;
1996 Log() << kINFO << "particular row-column combination. In each case, the class " << Endl;
1997 Log() << kINFO << "given by the row is considered signal while the class given " << Endl;
1998 Log() << kINFO << "by the column index is considered background." << Endl;
1999 Log() << kINFO << Endl;
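// For example (hypothetical numbers and class names), an entry 0.95 (0.97) in row
// "electron", column "muon" would mean: treating electrons as signal and muons as
// background, a signal efficiency of 0.95 was reached on the test set (0.97 on the
// training set) at the quoted background efficiency.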
2000 for (UInt_t iMethod = 0; iMethod < methods->size(); ++iMethod) {
2001 MethodBase *theMethod = dynamic_cast<MethodBase *>(methods->at(iMethod));
2002 if (theMethod == nullptr) {
2003 continue;
2004 }
2005 UInt_t numClasses = theMethod->fDataSetInfo.GetNClasses();
2006
2007 std::vector<TString> classnames;
2008 for (UInt_t iCls = 0; iCls < numClasses; ++iCls) {
2009 classnames.push_back(theMethod->fDataSetInfo.GetClassInfo(iCls)->GetName());
2010 }
2011 Log() << kINFO
2012 << "=== Showing confusion matrix for method : " << Form("%-15s", (const char *)mname[0][iMethod])
2013 << Endl;
2014 Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.01)" << Endl;
2015 Log() << kINFO << "---------------------------------------------------" << Endl;
2016 printMatrix(multiclass_testConfusionEffB01[iMethod], multiclass_trainConfusionEffB01[iMethod], classnames,
2017 numClasses, Log());
2018 Log() << kINFO << Endl;
2019
2020 Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.10)" << Endl;
2021 Log() << kINFO << "---------------------------------------------------" << Endl;
2022 printMatrix(multiclass_testConfusionEffB10[iMethod], multiclass_trainConfusionEffB10[iMethod], classnames,
2023 numClasses, Log());
2024 Log() << kINFO << Endl;
2025
2026 Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.30)" << Endl;
2027 Log() << kINFO << "---------------------------------------------------" << Endl;
2028 printMatrix(multiclass_testConfusionEffB30[iMethod], multiclass_trainConfusionEffB30[iMethod], classnames,
2029 numClasses, Log());
2030 Log() << kINFO << Endl;
2031 }
2032 Log() << kINFO << hLine << Endl;
2033 Log() << kINFO << Endl;
2034
2035 } else {
2036 // Binary classification
2037 if (fROC) {
2038 Log().EnableOutput();
2039 gTools().TMVAVersionMessage(Log());
2040 Log() << Endl;
2041 TString hLine = "------------------------------------------------------------------------------------------"
2042 "-------------------------";
2043 Log() << kINFO << "Evaluation results ranked by best signal efficiency and purity (area)" << Endl;
2044 Log() << kINFO << hLine << Endl;
2045 Log() << kINFO << "DataSet MVA " << Endl;
2046 Log() << kINFO << "Name: Method: ROC-integ" << Endl;
2047
2048 // Log() << kDEBUG << "DataSet MVA Signal efficiency at bkg eff.(error): | Sepa- Signifi- " << Endl;
2049 // Log() << kDEBUG << "Name: Method: @B=0.01 @B=0.10 @B=0.30 ROC-integ ROCCurve| ration: cance: " << Endl;
2051 Log() << kDEBUG << hLine << Endl;
2052 for (Int_t k = 0; k < 2; k++) {
2053 if (k == 1 && nmeth_used[k] > 0) {
2054 Log() << kINFO << hLine << Endl;
2055 Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
2056 }
2057 for (Int_t i = 0; i < nmeth_used[k]; i++) {
2058 TString datasetName = itrMap->first;
2059 TString methodName = mname[k][i];
2060
2061 if (k == 1) {
2062 methodName.ReplaceAll("Variable_", "");
2063 }
2064
2065 MethodBase *theMethod = dynamic_cast<MethodBase *>(GetMethod(datasetName, methodName));
2066 if (theMethod == 0) {
2067 continue;
2068 }
2069
2070 TMVA::DataSet *dataset = theMethod->Data();
2071 TMVA::Results *results = dataset->GetResults(methodName, Types::kTesting, this->fAnalysisType);
2072 std::vector<Bool_t> *mvaResType =
2073 dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
2074
2075 Double_t rocIntegral = 0.0;
2076 if (mvaResType->size() != 0) {
2077 rocIntegral = GetROCIntegral(datasetName, methodName);
2078 }
2079
2080 if (sep[k][i] < 0 || sig[k][i] < 0) {
2081 // cannot compute separation/significance -> no MVA (usually for Cuts)
2082 Log() << kINFO << Form("%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), effArea[k][i])
2083 << Endl;
2084
2085 // Log() << kDEBUG << Form("%-20s %-15s: %#1.3f(%02i) %#1.3f(%02i) %#1.3f(%02i) %#1.3f %#1.3f | -- --",
2086 //                         datasetName.Data(), methodName.Data(),
2087 //                         eff01[k][i], Int_t(1000*eff01err[k][i]),
2088 //                         eff10[k][i], Int_t(1000*eff10err[k][i]),
2089 //                         eff30[k][i], Int_t(1000*eff30err[k][i]),
2090 //                         effArea[k][i], rocIntegral) << Endl;
2093 } else {
2094 Log() << kINFO << Form("%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), rocIntegral)
2095 << Endl;
2096 // Log() << kDEBUG << Form("%-20s %-15s: %#1.3f(%02i) %#1.3f(%02i) %#1.3f(%02i) %#1.3f %#1.3f | %#1.3f %#1.3f",
2097 //                         datasetName.Data(), methodName.Data(),
2098 //                         eff01[k][i], Int_t(1000*eff01err[k][i]),
2099 //                         eff10[k][i], Int_t(1000*eff10err[k][i]),
2100 //                         eff30[k][i], Int_t(1000*eff30err[k][i]),
2101 //                         effArea[k][i], rocIntegral,
2102 //                         sep[k][i], sig[k][i]) << Endl;
2105 }
2106 }
2107 }
2108 Log() << kINFO << hLine << Endl;
2109 Log() << kINFO << Endl;
2110 Log() << kINFO << "Testing efficiency compared to training efficiency (overtraining check)" << Endl;
2111 Log() << kINFO << hLine << Endl;
2112 Log() << kINFO
2113 << "DataSet MVA Signal efficiency: from test sample (from training sample) "
2114 << Endl;
2115 Log() << kINFO << "Name: Method: @B=0.01 @B=0.10 @B=0.30 "
2116 << Endl;
2117 Log() << kINFO << hLine << Endl;
2118 for (Int_t k = 0; k < 2; k++) {
2119 if (k == 1 && nmeth_used[k] > 0) {
2120 Log() << kINFO << hLine << Endl;
2121 Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
2122 }
2123 for (Int_t i = 0; i < nmeth_used[k]; i++) {
2124 if (k == 1) mname[k][i].ReplaceAll("Variable_", "");
2125 MethodBase *theMethod = dynamic_cast<MethodBase *>((*methods)[i]);
2126 if (theMethod == 0) continue;
2127
2128 Log() << kINFO << Form("%-20s %-15s: %#1.3f (%#1.3f) %#1.3f (%#1.3f) %#1.3f (%#1.3f)",
2129 theMethod->fDataSetInfo.GetName(), (const char *)mname[k][i], eff01[k][i],
2130 trainEff01[k][i], eff10[k][i], trainEff10[k][i], eff30[k][i], trainEff30[k][i])
2131 << Endl;
2132 }
2133 }
2134 Log() << kINFO << hLine << Endl;
2135 Log() << kINFO << Endl;
2136
2137 if (gTools().CheckForSilentOption(GetOptions())) Log().InhibitOutput();
2138 } // end fROC
2139 }
2140 if(!IsSilentFile())
2141 {
2142 std::list<TString> datasets;
2143 for (Int_t k=0; k<2; k++) {
2144 for (Int_t i=0; i<nmeth_used[k]; i++) {
2145 MethodBase* theMethod = dynamic_cast<MethodBase*>((*methods)[i]);
2146 if(theMethod==0) continue;
2147 // write test/training trees
2148 RootBaseDir()->cd(theMethod->fDataSetInfo.GetName());
2149 if(std::find(datasets.begin(), datasets.end(), theMethod->fDataSetInfo.GetName()) == datasets.end())
2150 {
2151 theMethod->fDataSetInfo.GetDataSet()->GetTree(Types::kTesting)->Write( "", TObject::kOverwrite );
2152 theMethod->fDataSetInfo.GetDataSet()->GetTree(Types::kTraining)->Write( "", TObject::kOverwrite );
2153 datasets.push_back(theMethod->fDataSetInfo.GetName());
2155 }
2156 }
2157 }
2158 }//end for MethodsMap
2159 // references for citation
2160 gTools().TMVACitation(Log(), Tools::kHtmlLink);
2161 }
2162
2163////////////////////////////////////////////////////////////////////////////////
2164/// Evaluate Variable Importance
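///
/// A minimal usage sketch; the file name, loader setup, booked method and its option
/// string are illustrative only:
/// \code
/// TFile *outputFile = TFile::Open("VariableImportance.root", "RECREATE");
/// TMVA::Factory factory("VI", outputFile, "!V:Silent");
/// TMVA::DataLoader *loader = new TMVA::DataLoader("dataset");
/// // ... loader->AddVariable(...), loader->AddSignalTree(...),
/// //     loader->AddBackgroundTree(...), loader->PrepareTrainingAndTestTree(...) ...
/// TH1F *vi = factory.EvaluateImportance(loader, TMVA::VIType::kShort,
///                                       TMVA::Types::kBDT, "BDT", "NTrees=50");
/// if (vi) vi->Draw("B");
/// \endcode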
2165
2166TH1F* TMVA::Factory::EvaluateImportance(DataLoader *loader,VIType vitype, Types::EMVA theMethod, TString methodTitle, const char *theOption)
2167{
2168 fModelPersistence=kFALSE;
2169 fSilentFile=kTRUE; // use a silent file: only fast classification results are needed here
2170
2171 // get the number of variables from the loader
2172 const int nbits = loader->GetDataSetInfo().GetNVariables();
2173 if(vitype==VIType::kShort)
2174 return EvaluateImportanceShort(loader,theMethod,methodTitle,theOption);
2175 else if(vitype==VIType::kAll)
2176 return EvaluateImportanceAll(loader,theMethod,methodTitle,theOption);
2177 else if(vitype==VIType::kRandom&&nbits>10)
2178 {
2179 return EvaluateImportanceRandom(loader,pow(2,nbits),theMethod,methodTitle,theOption);
2180 }else
2181 {
2182 std::cerr<<"Error in Variable Importance: Random mode requires more than 10 variables in the dataset."<<std::endl;
2183 return nullptr;
2184 }
2185}
2186
2187////////////////////////////////////////////////////////////////////////////////
2188
2189TH1F* TMVA::Factory::EvaluateImportanceAll(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption)
2190{
2191
2192 uint64_t x = 0;
2193 uint64_t y = 0;
2194
2195 //getting number of variables and variable names from loader
2196 const int nbits = loader->GetDataSetInfo().GetNVariables();
2197 std::vector<TString> varNames = loader->GetDataSetInfo().GetListOfVariables();
2198
2199 uint64_t range = pow(2, nbits); // number of variable subsets: 2^nbits
2200
2201 //vector to save importances
2202 std::vector<Double_t> importances(nbits);
2203 //vector to save ROC
2204 std::vector<Double_t> ROC(range);
2205 ROC[0]=0.5;
2206 for (int i = 0; i < nbits; i++)importances[i] = 0;
2207
2208 Double_t SROC, SSROC; //computed ROC value
2209 for ( x = 1; x <range ; x++) {
2210
2211 std::bitset<VIBITS> xbitset(x);
2212 if (x == 0) continue; // the data loader needs at least one variable
2213
2214 //creating loader for seed
2215 TMVA::DataLoader *seedloader = new TMVA::DataLoader(xbitset.to_string());
2216
2217 //adding variables from seed
2218 for (int index = 0; index < nbits; index++) {
2219 if (xbitset[index]) seedloader->AddVariable(varNames[index], 'F');
2220 }
2221
2222 DataLoaderCopy(seedloader,loader);
2223 seedloader->PrepareTrainingAndTestTree(loader->GetDataSetInfo().GetCut("Signal"), loader->GetDataSetInfo().GetCut("Background"), loader->GetDataSetInfo().GetSplitOptions());
2224
2225 //Booking Seed
2226 BookMethod(seedloader, theMethod, methodTitle, theOption);
2227
2228 //Train/Test/Evaluation
2229 TrainAllMethods();
2230 TestAllMethods();
2231 EvaluateAllMethods();
2232
2233 //getting ROC
2234 ROC[x] = GetROCIntegral(xbitset.to_string(), methodTitle);
2235
2236 //cleaning information to process sub-seeds
2237 TMVA::MethodBase *smethod=dynamic_cast<TMVA::MethodBase*>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
2238 TMVA::ResultsClassification *sresults = dynamic_cast<TMVA::ResultsClassification*>(smethod->Data()->GetResults(smethod->GetMethodName(), Types::kTesting, Types::kClassification));
2239 delete sresults;
2240 delete seedloader;
2241 this->DeleteAllMethods();
2242
2243 fMethodsMap.clear();
2244 // remove the global result, since keeping it for every seed would require a lot of RAM
2245 }
2246
2247
2248 for ( x = 0; x <range ; x++)
2249 {
2250 SROC=ROC[x];
2251 for (uint32_t i = 0; i < VIBITS; ++i) {
2252 if (x & (1 << i)) {
2253 y = x & ~(1 << i);
2254 std::bitset<VIBITS> ybitset(y);
2255 // need at least one variable
2256 // NOTE: if the sub-seed is zero, this is the special case
2257 // where the bit count of xbitset is 1
2258 Double_t ny = log(x - y) / 0.693147; // log2(x - y): index of the variable just removed
2259 if (y == 0) {
2260 importances[ny] = SROC - 0.5;
2261 continue;
2262 }
2263
2264 //getting ROC
2265 SSROC = ROC[y];
2266 importances[ny] += SROC - SSROC;
2267 //cleaning information
2268 }
2269
2270 }
2271 }
2272 std::cout<<"--- Variable Importance Results (All)"<<std::endl;
2273 return GetImportance(nbits,importances,varNames);
2274}
2275
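// Returns 2^i - 1, i.e. the bitmask with the lowest i bits set
// (used below as the full-variable seed in EvaluateImportanceShort).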
2276static long int sum(long int i)
2277{
2278 long int _sum=0;
2279 for(long int n=0;n<i;n++) _sum+=pow(2,n);
2280 return _sum;
2281}
2282
2283////////////////////////////////////////////////////////////////////////////////
2284
2285TH1F* TMVA::Factory::EvaluateImportanceShort(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption)
2286{
2287 uint64_t x = 0;
2288 uint64_t y = 0;
2289
2290 //getting number of variables and variable names from loader
2291 const int nbits = loader->GetDataSetInfo().GetNVariables();
2292 std::vector<TString> varNames = loader->GetDataSetInfo().GetListOfVariables();
2293
2294 long int range = sum(nbits);
2295// std::cout<<range<<std::endl;
2296 //vector to save importances
2297 std::vector<Double_t> importances(nbits);
2298 for (int i = 0; i < nbits; i++)importances[i] = 0;
2299
2300 Double_t SROC, SSROC; //computed ROC value
2301
2302 x = range;
2303
2304 std::bitset<VIBITS> xbitset(x);
2305 if (x == 0) Log()<<kFATAL<<"Error: need at least one variable."<<Endl; // the data loader needs at least one variable
2306
2307
2308 //creating loader for seed
2309 TMVA::DataLoader *seedloader = new TMVA::DataLoader(xbitset.to_string());
2310
2311 //adding variables from seed
2312 for (int index = 0; index < nbits; index++) {
2313 if (xbitset[index]) seedloader->AddVariable(varNames[index], 'F');
2314 }
2315
2316 //Loading Dataset
2317 DataLoaderCopy(seedloader,loader);
2318
2319 //Booking Seed
2320 BookMethod(seedloader, theMethod, methodTitle, theOption);
2321
2322 //Train/Test/Evaluation
2323 TrainAllMethods();
2324 TestAllMethods();
2325 EvaluateAllMethods();
2326
2327 //getting ROC
2328 SROC = GetROCIntegral(xbitset.to_string(), methodTitle);
2329
2330 //cleaning information to process sub-seeds
2331 TMVA::MethodBase *smethod=dynamic_cast<TMVA::MethodBase*>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
2332 TMVA::ResultsClassification *sresults = dynamic_cast<TMVA::ResultsClassification*>(smethod->Data()->GetResults(smethod->GetMethodName(), Types::kTesting, Types::kClassification));
2333 delete sresults;
2334 delete seedloader;
2335 this->DeleteAllMethods();
2336 fMethodsMap.clear();
2337
2338 // remove the global result, since keeping it for every seed would require a lot of RAM
2339
2340 for (uint32_t i = 0; i < VIBITS; ++i) {
2341 if (x & (1 << i)) {
2342 y = x & ~(1 << i);
2343 std::bitset<VIBITS> ybitset(y);
2344 // need at least one variable
2345 // NOTE: if the sub-seed is zero, this is the special case
2346 // where the bit count of xbitset is 1
2347 Double_t ny = log(x - y) / 0.693147; // log2(x - y): index of the variable just removed
2348 if (y == 0) {
2349 importances[ny] = SROC - 0.5;
2350 continue;
2351 }
2352
2353 //creating loader for sub-seed
2354 TMVA::DataLoader *subseedloader = new TMVA::DataLoader(ybitset.to_string());
2355 //adding variables from sub-seed
2356 for (int index = 0; index < nbits; index++) {
2357 if (ybitset[index]) subseedloader->AddVariable(varNames[index], 'F');
2358 }
2359
2360 //Loading Dataset
2361 DataLoaderCopy(subseedloader,loader);
2362
2363 //Booking SubSeed
2364 BookMethod(subseedloader, theMethod, methodTitle, theOption);
2365
2366 //Train/Test/Evaluation
2367 TrainAllMethods();
2368 TestAllMethods();
2369 EvaluateAllMethods();
2370
2371 //getting ROC
2372 SSROC = GetROCIntegral(ybitset.to_string(), methodTitle);
2373 importances[ny] += SROC - SSROC;
2374
2375 //cleaning information
2376 TMVA::MethodBase *ssmethod=dynamic_cast<TMVA::MethodBase*>(fMethodsMap[ybitset.to_string().c_str()][0][0]);
2377 TMVA::ResultsClassification *ssresults = dynamic_cast<TMVA::ResultsClassification*>(ssmethod->Data()->GetResults(ssmethod->GetMethodName(), Types::kTesting, Types::kClassification));
2378 delete ssresults;
2379 delete subseedloader;
2380 this->DeleteAllMethods();
2381 fMethodsMap.clear();
2382 }
2383 }
2384 std::cout<<"--- Variable Importance Results (Short)"<<std::endl;
2385 return GetImportance(nbits,importances,varNames);
2386}
2387
2388////////////////////////////////////////////////////////////////////////////////
2389
2390TH1F* TMVA::Factory::EvaluateImportanceRandom(DataLoader *loader, UInt_t nseeds, Types::EMVA theMethod, TString methodTitle, const char *theOption)
2391{
2392 TRandom3 *rangen = new TRandom3(0); // random generator; seed 0 produces a unique, time-dependent seed
2393
2394 uint64_t x = 0;
2395 uint64_t y = 0;
2396
2397 //getting number of variables and variable names from loader
2398 const int nbits = loader->GetDataSetInfo().GetNVariables();
2399 std::vector<TString> varNames = loader->GetDataSetInfo().GetListOfVariables();
2400
2401 long int range = pow(2, nbits);
2402
2403 //vector to save importances
2404 std::vector<Double_t> importances(nbits);
2405 Double_t importances_norm = 0;
2406 for (int i = 0; i < nbits; i++)importances[i] = 0;
2407
2408 Double_t SROC, SSROC; //computed ROC value
2409 for (UInt_t n = 0; n < nseeds; n++) {
2410 x = rangen -> Integer(range);
2411
2412 std::bitset<32> xbitset(x);
2413 if (x == 0) continue; // the data loader needs at least one variable
2414
2415
2416 //creating loader for seed
2417 TMVA::DataLoader *seedloader = new TMVA::DataLoader(xbitset.to_string());
2418
2419 //adding variables from seed
2420 for (int index = 0; index < nbits; index++) {
2421 if (xbitset[index]) seedloader->AddVariable(varNames[index], 'F');
2422 }
2423
2424 //Loading Dataset
2425 DataLoaderCopy(seedloader,loader);
2426
2427 //Booking Seed
2428 BookMethod(seedloader, theMethod, methodTitle, theOption);
2429
2430 //Train/Test/Evaluation
2431 TrainAllMethods();
2432 TestAllMethods();
2433 EvaluateAllMethods();
2434
2435 //getting ROC
2436 SROC = GetROCIntegral(xbitset.to_string(), methodTitle);
2437// std::cout << "Seed: n " << n << " x " << x << " xbitset:" << xbitset << " ROC " << SROC << std::endl;
2438
2439 //cleaning information to process sub-seeds
2440 TMVA::MethodBase *smethod=dynamic_cast<TMVA::MethodBase*>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
2441 TMVA::ResultsClassification *sresults = dynamic_cast<TMVA::ResultsClassification*>(smethod->Data()->GetResults(smethod->GetMethodName(), Types::kTesting, Types::kClassification));
2442 delete sresults;
2443 delete seedloader;
2444 this->DeleteAllMethods();
2445 fMethodsMap.clear();
2446
2447 // remove the global result, since keeping it for every seed would require a lot of RAM
2448
2449 for (uint32_t i = 0; i < 32; ++i) {
2450 if (x & (1 << i)) {
2451 y = x & ~(1 << i);
2452 std::bitset<32> ybitset(y);
2453 // need at least one variable
2454 // NOTE: if the sub-seed is zero, this is the special case
2455 // where the bit count of xbitset is 1
2456 Double_t ny = log(x - y) / 0.693147; // log2(x - y): index of the variable just removed
2457 if (y == 0) {
2458 importances[ny] = SROC - 0.5;
2459 importances_norm += importances[ny];
2460 // std::cout << "SubSeed: " << y << " y:" << ybitset << "ROC " << 0.5 << std::endl;
2461 continue;
2462 }
2463
2464 //creating loader for sub-seed
2465 TMVA::DataLoader *subseedloader = new TMVA::DataLoader(ybitset.to_string());
2466 //adding variables from sub-seed
2467 for (int index = 0; index < nbits; index++) {
2468 if (ybitset[index]) subseedloader->AddVariable(varNames[index], 'F');
2469 }
2470
2471 //Loading Dataset
2472 DataLoaderCopy(subseedloader,loader);
2473
2474 //Booking SubSeed
2475 BookMethod(subseedloader, theMethod, methodTitle, theOption);
2476
2477 //Train/Test/Evaluation
2478 TrainAllMethods();
2479 TestAllMethods();
2480 EvaluateAllMethods();
2481
2482 //getting ROC
2483 SSROC = GetROCIntegral(ybitset.to_string(), methodTitle);
2484 importances[ny] += SROC - SSROC;
2485 //std::cout << "SubSeed: " << y << " y:" << ybitset << " x-y " << x - y << " " << std::bitset<32>(x - y) << " ny " << ny << " SROC " << SROC << " SSROC " << SSROC << " Importance = " << importances[ny] << std::endl;
2486 //cleaning information
2487 TMVA::MethodBase *ssmethod=dynamic_cast<TMVA::MethodBase*>(fMethodsMap[ybitset.to_string().c_str()][0][0]);
2488 TMVA::ResultsClassification *ssresults = dynamic_cast<TMVA::ResultsClassification*>(ssmethod->Data()->GetResults(ssmethod->GetMethodName(), Types::kTesting, Types::kClassification));
2489 delete ssresults;
2490 delete subseedloader;
2491 this->DeleteAllMethods();
2492 fMethodsMap.clear();
2493 }
2494 }
2495 }
2496 std::cout<<"--- Variable Importance Results (Random)"<<std::endl;
delete rangen; // release the random generator allocated at the top of this method
2497 return GetImportance(nbits,importances,varNames);
2498}
2499
2500////////////////////////////////////////////////////////////////////////////////
2501
2502TH1F* TMVA::Factory::GetImportance(const int nbits,std::vector<Double_t> importances,std::vector<TString> varNames)
2503{
2504 TH1F *vih1 = new TH1F("vih1", "", nbits, 0, nbits);
2505
2506 gStyle->SetOptStat(000000);
2507
2508 Float_t normalization = 0.0;
2509 for (int i = 0; i < nbits; i++) {
2510 normalization = normalization + importances[i];
2511 }
2512
2513 Float_t roc = 0.0;
2514
2516 gStyle->SetTitleXOffset(1.2);
2517
2518
2519 std::vector<Double_t> x_ie(nbits), y_ie(nbits);
2520 for (Int_t i = 1; i < nbits + 1; i++) {
2521 x_ie[i - 1] = (i - 1) * 1.;
2522 roc = 100.0 * importances[i - 1] / normalization;
2523 y_ie[i - 1] = roc;
2524 std::cout<<"--- "<<varNames[i-1]<<" = "<<roc<<" %"<<std::endl;
2525 vih1->GetXaxis()->SetBinLabel(i, varNames[i - 1].Data());
2526 vih1->SetBinContent(i, roc);
2527 }
2528 TGraph *g_ie = new TGraph(nbits, &x_ie[0], &y_ie[0]); // x_ie and y_ie hold exactly nbits points
2529 g_ie->SetTitle("");
2530
2531 vih1->LabelsOption("v >", "X");
2532 vih1->SetBarWidth(0.97);
2533 Int_t ca = TColor::GetColor("#006600");
2534 vih1->SetFillColor(ca);
2535 //Int_t ci = TColor::GetColor("#990000");
2536
2537 vih1->GetYaxis()->SetTitle("Importance (%)");
2538 vih1->GetYaxis()->SetTitleSize(0.045);
2539 vih1->GetYaxis()->CenterTitle();
2540 vih1->GetYaxis()->SetTitleOffset(1.24);
2541
2542 vih1->GetYaxis()->SetRangeUser(-7, 50);
2543 vih1->SetDirectory(0);
2544
2545// vih1->Draw("B");
2546 return vih1;
2547}
#define h(i)
Definition: RSha256.hxx:106
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
float Float_t
Definition: RtypesCore.h:53
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define ClassImp(name)
Definition: Rtypes.h:365
char name[80]
Definition: TGX11.cxx:109
int type
Definition: TGX11.cxx:120
double pow(double, double)
double log(double)
TMatrixT< Double_t > TMatrixD
Definition: TMatrixDfwd.h:22
#define gROOT
Definition: TROOT.h:415
char * Form(const char *fmt,...)
R__EXTERN TStyle * gStyle
Definition: TStyle.h:407
R__EXTERN TSystem * gSystem
Definition: TSystem.h:560
virtual void SetTitleOffset(Float_t offset=1)
Set distance between the axis and the axis title.
Definition: TAttAxis.cxx:294
virtual void SetTitleSize(Float_t size=0.04)
Set size of axis title.
Definition: TAttAxis.cxx:304
virtual void SetFillColor(Color_t fcolor)
Set the fill area color.
Definition: TAttFill.h:37
virtual void SetBinLabel(Int_t bin, const char *label)
Set label for bin.
Definition: TAxis.cxx:809
void CenterTitle(Bool_t center=kTRUE)
Center axis title.
Definition: TAxis.h:184
virtual void SetRangeUser(Double_t ufirst, Double_t ulast)
Set the viewing range for the axis from ufirst to ulast (in user coordinates).
Definition: TAxis.cxx:928
The Canvas class.
Definition: TCanvas.h:31
static Int_t GetColor(const char *hexcolor)
Static method returning color number for color specified by hex color string of form: "#rrggbb",...
Definition: TColor.cxx:1764
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:48
A Graph is a graphics object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
virtual void SetTitle(const char *title="")
Change (i.e.
Definition: TGraph.cxx:2312
1-D histogram with a float per channel (see TH1 documentation)}
Definition: TH1.h:571
virtual void SetDirectory(TDirectory *dir)
By default when an histogram is created, it is added to the list of histogram objects in the current ...
Definition: TH1.cxx:8381
virtual void SetTitle(const char *title)
See GetStatOverflows for more information.
Definition: TH1.cxx:6333
virtual void LabelsOption(Option_t *option="h", Option_t *axis="X")
Set option(s) to draw axis with labels.
Definition: TH1.cxx:5214
static void AddDirectory(Bool_t add=kTRUE)
Sets the flag controlling the automatic add of histograms in memory.
Definition: TH1.cxx:1226
TAxis * GetXaxis()
Get the behaviour adopted by the object about the statoverflows. See EStatOverflows for more informat...
Definition: TH1.h:316
TAxis * GetYaxis()
Definition: TH1.h:317
virtual void SetBarWidth(Float_t width=0.5)
Definition: TH1.h:356
virtual void SetBinContent(Int_t bin, Double_t content)
Set bin content see convention for numbering bins in TH1::GetBin In case the bin number is greater th...
Definition: TH1.cxx:8666
Service class for 2-Dim histogram classes.
Definition: TH2.h:30
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
TString fWeightFileDir
Definition: Config.h:124
TString fWeightFileDirPrefix
Definition: Config.h:123
void SetDrawProgressBar(Bool_t d)
Definition: Config.h:71
void SetUseColor(Bool_t uc)
Definition: Config.h:62
class TMVA::Config::VariablePlotting fVariablePlotting
void SetSilent(Bool_t s)
Definition: Config.h:65
IONames & GetIONames()
Definition: Config.h:100
void SetConfigDescription(const char *d)
Definition: Configurable.h:64
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
void AddPreDefVal(const T &)
Definition: Configurable.h:168
void SetConfigName(const char *n)
Definition: Configurable.h:63
virtual void ParseOptions()
options parser
const TString & GetOptions() const
Definition: Configurable.h:84
MsgLogger & Log() const
Definition: Configurable.h:122
MsgLogger * fLogger
Definition: Configurable.h:128
void CheckForUnusedOptions() const
checks for unused options in option string
UInt_t GetEntries(const TString &name) const
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
Definition: DataLoader.cxx:617
DataSetInfo & GetDataSetInfo()
Definition: DataLoader.cxx:123
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Definition: DataLoader.cxx:470
Class that contains all the data information.
Definition: DataSetInfo.h:60
UInt_t GetNVariables() const
Definition: DataSetInfo.h:125
virtual const char * GetName() const
Returns name of object.
Definition: DataSetInfo.h:69
const TMatrixD * CorrelationMatrix(const TString &className) const
UInt_t GetNClasses() const
Definition: DataSetInfo.h:153
const TString & GetSplitOptions() const
Definition: DataSetInfo.h:184
UInt_t GetNTargets() const
Definition: DataSetInfo.h:126
DataSet * GetDataSet() const
returns data set
TH2 * CreateCorrelationMatrixHist(const TMatrixD *m, const TString &hName, const TString &hTitle) const
std::vector< TString > GetListOfVariables() const
returns list of variables
ClassInfo * GetClassInfo(Int_t clNum) const
const TCut & GetCut(Int_t i) const
Definition: DataSetInfo.h:166
VariableInfo & GetVariableInfo(Int_t i)
Definition: DataSetInfo.h:103
Bool_t IsSignal(const Event *ev) const
DataSetManager * GetDataSetManager()
Definition: DataSetInfo.h:192
DataInputHandler & DataInput()
Class that contains all the data information.
Definition: DataSet.h:69
Long64_t GetNEvtSigTest()
return number of signal test events in dataset
Definition: DataSet.cxx:427
TTree * GetTree(Types::ETreeType type)
create the test/trainings tree with all the variables, the weights, the classes, the targets,...
Definition: DataSet.cxx:609
const Event * GetEvent() const
Definition: DataSet.cxx:202
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:217
Results * GetResults(const TString &, Types::ETreeType type, Types::EAnalysisType analysistype)
Definition: DataSet.cxx:265
Long64_t GetNTrainingEvents() const
Definition: DataSet.h:79
void SetCurrentType(Types::ETreeType type) const
Definition: DataSet.h:100
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:227
Long64_t GetNEvtBkgdTest()
return number of background test events in dataset
Definition: DataSet.cxx:435
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:237
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:392
This is the main MVA steering class.
Definition: Factory.h:81
void PrintHelpMessage(const TString &datasetname, const TString &methodTitle="") const
Print predefined help message of classifier.
Definition: Factory.cxx:1308
Double_t GetROCIntegral(DataLoader *loader, TString theMethodName, UInt_t iClass=0)
Calculate the integral of the ROC curve, also known as the area under curve (AUC),...
Definition: Factory.cxx:843
Bool_t fCorrelations
verbosity level, controls granularity of logging
Definition: Factory.h:212
std::vector< IMethod * > MVector
Definition: Factory.h:85
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition: Factory.cxx:1094
Bool_t Verbose(void) const
Definition: Factory.h:135
void WriteDataInformation(DataSetInfo &fDataSetInfo)
Definition: Factory.cxx:597
MethodBase * BookMethod(DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption="")
Book a classifier or regression method.
Definition: Factory.cxx:346
Factory(TString theJobName, TFile *theTargetFile, TString theOption="")
Standard constructor.
Definition: Factory.cxx:119
void TestAllMethods()
Evaluates all booked methods on the testing data and adds the output to the Results in the corresponi...
Definition: Factory.cxx:1245
Bool_t fVerbose
list of transformations to test
Definition: Factory.h:210
void EvaluateAllMethods(void)
Iterates over all MVAs that have been booked, and calls their evaluation methods.
Definition: Factory.cxx:1350
TH1F * EvaluateImportanceRandom(DataLoader *loader, UInt_t nseeds, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition: Factory.cxx:2390
TH1F * GetImportance(const int nbits, std::vector< Double_t > importances, std::vector< TString > varNames)
Definition: Factory.cxx:2502
Bool_t fROC
enable to calculate corelations
Definition: Factory.h:213
void EvaluateAllVariables(DataLoader *loader, TString options="")
Iterates over all MVA input variables and evaluates them.
Definition: Factory.cxx:1335
TString fVerboseLevel
verbose mode
Definition: Factory.h:211
TH1F * EvaluateImportance(DataLoader *loader, VIType vitype, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Evaluate Variable Importance.
Definition: Factory.cxx:2166
virtual ~Factory()
Destructor.
Definition: Factory.cxx:300
TGraph * GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles=kTRUE, UInt_t iClass=0)
Argument iClass specifies the class to generate the ROC curve in a multiclass setting.
Definition: Factory.cxx:904
virtual void MakeClass(const TString &datasetname, const TString &methodTitle="") const
Definition: Factory.cxx:1280
MethodBase * BookMethodWeightfile(DataLoader *dataloader, TMVA::Types::EMVA methodType, const TString &weightfile)
Adds an already constructed method to be managed by this factory.
Definition: Factory.cxx:498
Bool_t fModelPersistence
the training type
Definition: Factory.h:219
std::map< TString, Double_t > OptimizeAllMethods(TString fomType="ROCIntegral", TString fitType="FitGA")
Iterates through all booked methods and sees if they use parameter tuning and if so.
Definition: Factory.cxx:695
ROCCurve * GetROC(DataLoader *loader, TString theMethodName, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Private method to generate a ROCCurve instance for a given method.
Definition: Factory.cxx:744
TH1F * EvaluateImportanceShort(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition: Factory.cxx:2285
Types::EAnalysisType fAnalysisType
jobname, used as extension in weight file names
Definition: Factory.h:218
Bool_t HasMethod(const TString &datasetname, const TString &title) const
Checks whether a given method name is defined for a given dataset.
Definition: Factory.cxx:580
TH1F * EvaluateImportanceAll(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition: Factory.cxx:2189
void SetVerbose(Bool_t v=kTRUE)
Definition: Factory.cxx:338
TFile * fgTargetFile
Definition: Factory.h:202
IMethod * GetMethod(const TString &datasetname, const TString &title) const
Returns pointer to MVA that corresponds to given method title.
Definition: Factory.cxx:562
void DeleteAllMethods(void)
Delete methods.
Definition: Factory.cxx:318
TString fTransformations
option string given by construction (presently only "V")
Definition: Factory.h:209
void Greetings()
Print welcome message.
Definition: Factory.cxx:290
TMultiGraph * GetROCCurveAsMultiGraph(DataLoader *loader, UInt_t iClass)
Generate a collection of graphs, for all methods for a given class.
Definition: Factory.cxx:973
Interface for all concrete MVA method implementations.
Definition: IMethod.h:54
virtual void PrintHelpMessage() const =0
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
virtual void MakeClass(const TString &classFileName=TString("")) const =0
Virtual base Class for all MVA method.
Definition: MethodBase.h:111
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
void SetSilentFile(Bool_t status)
Definition: MethodBase.h:377
void SetWeightFileDir(TString fileDir)
set directory of weight file
virtual void TestRegression(Double_t &bias, Double_t &biasT, Double_t &dev, Double_t &devT, Double_t &rms, Double_t &rmsT, Double_t &mInf, Double_t &mInfT, Double_t &corr, Types::ETreeType type)
calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
Definition: MethodBase.cxx:982
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
Definition: MethodBase.cxx:601
TString GetMethodTypeName() const
Definition: MethodBase.h:331
Bool_t DoMulticlass() const
Definition: MethodBase.h:439
virtual Double_t GetSignificance() const
compute significance of mean difference
const char * GetName() const
Definition: MethodBase.h:333
Types::EAnalysisType GetAnalysisType() const
Definition: MethodBase.h:437
virtual TMatrixD GetMulticlassConfusionMatrix(Double_t effB, Types::ETreeType type)
Construct a confusion matrix for a multiclass classifier.
void PrintHelpMessage() const
prints out method-specific help method
virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype)
writes all MVA evaluation histograms to file
virtual void TestMulticlass()
test multiclass classification
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:411
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition: MethodBase.h:436
const TString & GetMethodName() const
Definition: MethodBase.h:330
Bool_t DoRegression() const
Definition: MethodBase.h:438
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:428
virtual Double_t GetTrainingEfficiency(const TString &)
DataSetInfo & DataInfo() const
Definition: MethodBase.h:409
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
virtual void TestClassification()
initialization
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
void ReadStateFromFile()
Function to read options and weights from file.
virtual std::map< TString, Double_t > OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA")
call the Optimizer with the set of parameters and ranges that are meant to be tuned.
Definition: MethodBase.cxx:628
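A sketch of a tuning call before training (the arguments shown are the documented defaults):
   std::map<TString, Double_t> tuned =
      method->OptimizeTuningParameters("ROCIntegral", "FitGA");
   for (const auto &par : tuned)
      std::cout << par.first << " = " << par.second << std::endl;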
DataSetInfo & fDataSetInfo
Definition: MethodBase.h:605
Types::EMVA GetMethodType() const
Definition: MethodBase.h:332
void SetFile(TFile *file)
Definition: MethodBase.h:374
DataSet * Data() const
Definition: MethodBase.h:408
void SetModelPersistence(Bool_t status)
Definition: MethodBase.h:381
virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const
calculate the area (integral) under the ROC curve as an overall quality measure of the classification
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implemented during the training phase)
Definition: MethodBase.cxx:438
Class for boosting a TMVA method.
Definition: MethodBoost.h:58
void SetBoostedMethodName(TString methodName)
Definition: MethodBoost.h:86
DataSetManager * fDataSetManager
Definition: MethodBoost.h:193
Class for categorizing the phase space.
DataSetManager * fDataSetManager
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
void SetMinType(EMsgType minType)
Definition: MsgLogger.h:72
void SetSource(const std::string &source)
Definition: MsgLogger.h:70
static void InhibitOutput()
Definition: MsgLogger.cxx:74
Double_t GetEffSForEffB(Double_t effB, const UInt_t num_points=41)
Calculate the signal efficiency (sensitivity) for a given background efficiency (1 - specificity).
Definition: ROCCurve.cxx:220
Double_t GetROCIntegral(const UInt_t points=41)
Calculates the ROC integral (AUC)
Definition: ROCCurve.cxx:251
TGraph * GetROCCurve(const UInt_t points=100)
Returns a new TGraph containing the ROC curve.
Definition: ROCCurve.cxx:277
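A standalone sketch of this interface; the (scores, targets) constructor used here is an assumption, since only the member functions are listed above:
   std::vector<Float_t> scores   = {0.9f, 0.8f, 0.3f, 0.1f};       // per-event MVA output
   std::vector<Bool_t>  isSignal = {kTRUE, kTRUE, kFALSE, kFALSE}; // per-event truth
   TMVA::ROCCurve roc(scores, isSignal);             // constructor signature assumed
   std::cout << "AUC: " << roc.GetROCIntegral() << std::endl; // default: 41 points
   std::cout << roc.GetEffSForEffB(0.1) << std::endl; // signal eff. at 10% background eff.
   roc.GetROCCurve()->Draw("AL");                     // a new TGraph is returned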
Ranking for variables in method (implementation)
Definition: Ranking.h:48
virtual void Print() const
prints the ranking of the input variables
Definition: Ranking.cxx:111
Class that is the base-class for a vector of results.
Class which takes the results of a multiclass classification.
Class that is the base-class for a vector of results.
Definition: Results.h:57
void FormattedOutput(const std::vector< Double_t > &, const std::vector< TString > &, const TString titleVars, const TString titleValues, MsgLogger &logger, TString format="%+1.3f")
formatted output of a simple table
Definition: Tools.cxx:899
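For example, in a ROOT macro (the logger source name and table contents are illustrative):
   TMVA::MsgLogger logger("Example");
   std::vector<Double_t> values = {0.912, 0.874};
   std::vector<TString>  names  = {"BDT", "MLP"};
   // one row per method, values printed with the default "%+1.3f" format
   TMVA::gTools().FormattedOutput(values, names, "Method", "ROC-integral", logger);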
void UsefulSortDescending(std::vector< std::vector< Double_t > > &, std::vector< TString > *vs=0)
sort 2D vector (AND in parallel a TString vector) in such a way that the "first vector is sorted" and the other vectors are reshuffled in the same way
Definition: Tools.cxx:576
void ROOTVersionMessage(MsgLogger &logger)
prints the ROOT release number and date
Definition: Tools.cxx:1337
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition: Tools.cxx:1211
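For example, splitting a transformation definition string at ';':
   std::vector<TString> tokens = TMVA::gTools().SplitString("I;D;P;G", ';');
   for (const TString &t : tokens)
      std::cout << t << std::endl; // prints I, D, P, G on separate lines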
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:840
const TMatrixD * GetCorrelationMatrix(const TMatrixD *covMat)
turns covariance into correlation matrix
Definition: Tools.cxx:336
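A standalone sketch with a 2x2 covariance matrix:
   TMatrixD cov(2, 2);
   cov(0, 0) = 4.0; cov(0, 1) = 1.2;
   cov(1, 0) = 1.2; cov(1, 1) = 9.0;
   const TMatrixD *corr = TMVA::gTools().GetCorrelationMatrix(&cov);
   corr->Print(); // off-diagonal becomes 1.2 / (2 * 3) = 0.2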
@ kHtmlLink
Definition: Tools.h:216
void TMVACitation(MsgLogger &logger, ECitation citType=kPlainText)
prints the TMVA citation in the format selected by citType
Definition: Tools.cxx:1453
void TMVAVersionMessage(MsgLogger &logger)
prints the TMVA release number and date
Definition: Tools.cxx:1328
void TMVAWelcomeMessage()
direct output, eg, when starting ROOT session -> no use of Logger here
Definition: Tools.cxx:1314
void UsefulSortAscending(std::vector< std::vector< Double_t > > &, std::vector< TString > *vs=0)
sort 2D vector (AND in parallel a TString vector) in such a way that the "first vector is sorted" and the other vectors are reshuffled in the same way
Definition: Tools.cxx:550
Class that contains all the data information.
void PrintVariableRanking() const
prints ranking of input variables
Singleton class for Global types used by TMVA.
Definition: Types.h:73
static Types & Instance()
returns the single instance of "Types" if it exists already, or creates it (Singleton)
Definition: Types.cxx:70
@ kCategory
Definition: Types.h:99
@ kCuts
Definition: Types.h:80
EAnalysisType
Definition: Types.h:127
@ kMulticlass
Definition: Types.h:130
@ kNoAnalysisType
Definition: Types.h:131
@ kClassification
Definition: Types.h:128
@ kMaxAnalysisType
Definition: Types.h:132
@ kRegression
Definition: Types.h:129
@ kTraining
Definition: Types.h:144
@ kTesting
Definition: Types.h:145
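The analysis-type and tree-type enums above steer the per-method bookkeeping; a minimal sketch of how they are typically queried (method is assumed to be a trained TMVA::MethodBase*):
   if (method->GetAnalysisType() == TMVA::Types::kMulticlass)
      method->TestMulticlass(); // test multiclass classification
   else if (method->DoRegression())
      std::cout << method->GetMethodName() << " is a regression method" << std::endl;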
const TString & GetLabel() const
Definition: VariableInfo.h:59
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:35
TList * GetListOfGraphs() const
Definition: TMultiGraph.h:69
virtual void Add(TGraph *graph, Option_t *chopt="")
Add a new graph to the list of graphs.
TH1F * GetHistogram()
Returns a pointer to the histogram used to draw the axis.
virtual void Draw(Option_t *chopt="")
Draw this multigraph with its current attributes.
TAxis * GetYaxis()
Get y axis of the graph.
TAxis * GetXaxis()
Get x axis of the graph.
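A sketch of the typical TMultiGraph usage pattern (graph1 and graph2 are assumed pre-filled TGraph pointers; the axes only exist after drawing):
   TMultiGraph *mg = new TMultiGraph();
   mg->Add(graph1);
   mg->Add(graph2);
   mg->Draw("AL");
   mg->GetXaxis()->SetTitle("signal efficiency");
   mg->GetYaxis()->SetTitle("background rejection");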
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
Definition: TNamed.cxx:164
TString fName
Definition: TNamed.h:32
virtual TObject * Clone(const char *newname="") const
Make a clone of an object using the Streamer facility.
Definition: TNamed.cxx:74
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
@ kOverwrite
overwrite existing object with same name
Definition: TObject.h:88
virtual TLegend * BuildLegend(Double_t x1=0.3, Double_t y1=0.21, Double_t x2=0.3, Double_t y2=0.21, const char *title="", Option_t *option="")
Build a legend from the graphical objects in the pad.
Definition: TPad.cxx:491
virtual void SetGrid(Int_t valuex=1, Int_t valuey=1)
Definition: TPad.h:330
Principal Components Analysis (PCA)
Definition: TPrincipal.h:20
virtual void AddRow(const Double_t *x)
Add a data point and update the covariance matrix.
Definition: TPrincipal.cxx:410
const TMatrixD * GetCovarianceMatrix() const
Definition: TPrincipal.h:58
virtual void MakePrincipals()
Perform the principal components analysis.
Definition: TPrincipal.cxx:869
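A standalone PCA sketch with two correlated variables (gRandom is ROOT's global random generator; the constructor arguments are illustrative):
   TPrincipal pca(2, "");
   Double_t row[2];
   for (Int_t i = 0; i < 1000; ++i) {
      row[0] = gRandom->Gaus(0, 1);
      row[1] = 0.5 * row[0] + gRandom->Gaus(0, 0.1);
      pca.AddRow(row); // updates the running covariance matrix
   }
   pca.MakePrincipals();
   pca.GetCovarianceMatrix()->Print();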
Random number generator class based on M. Matsumoto and T. Nishimura, the Mersenne Twister generator.
Definition: TRandom3.h:27
Basic string class.
Definition: TString.h:131
Ssiz_t Length() const
Definition: TString.h:405
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1125
int CompareTo(const char *cs, ECaseCompare cmp=kExact) const
Compare a string to the character array cs.
Definition: TString.cxx:418
const char * Data() const
Definition: TString.h:364
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:687
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
Definition: TString.h:610
Bool_t IsNull() const
Definition: TString.h:402
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:619
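A sketch of the TString operations used throughout the option parsing (the option string is illustrative):
   TString opt("V:Transformations=I;D");
   opt.ToLower(); // lower-case for case-insensitive comparison
   if (opt.Contains("transformations"))
      opt.ReplaceAll("transformations=", "");
   if (!opt.IsNull() && opt.BeginsWith("v"))
      std::cout << opt.Data() << " (length " << opt.Length() << ")" << std::endl;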
void SetOptStat(Int_t stat=1)
The type of information printed in the histogram statistics box can be selected via the parameter mode.
Definition: TStyle.cxx:1450
void SetTitleXOffset(Float_t offset=1)
Definition: TStyle.h:387
virtual int MakeDirectory(const char *name)
Make a directory.
Definition: TSystem.cxx:835
virtual Int_t Write(const char *name=0, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
Definition: TTree.cxx:9485
Double_t y[n]
Definition: legend1.C:17
Double_t x[n]
Definition: legend1.C:17
const Int_t n
Definition: legend1.C:16
std::string GetMethodName(TCppMethod_t)
Definition: Cppyy.cxx:757
TCppMethod_t GetMethod(TCppScope_t scope, TCppIndex_t imeth)
Definition: Cppyy.cxx:751
static constexpr double s
void DataLoaderCopy(TMVA::DataLoader *des, TMVA::DataLoader *src)
Config & gConfig()
Tools & gTools()
void CreateVariableTransforms(const TString &trafoDefinition, TMVA::DataSetInfo &dataInfo, TMVA::TransformationHandler &transformationHandler, TMVA::MsgLogger &log)
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Bool_t IsNaN(Double_t x)
Definition: TMath.h:882
Double_t Log(Double_t x)
Definition: TMath.h:750
Definition: graph.py:1
auto * m
Definition: textangle.C:8
#define VIBITS
Definition: Factory.cxx:107
static long int sum(long int i)
Definition: Factory.cxx:2276
const Int_t MinNoTrainingEvents
Definition: Factory.cxx:99
#define READXML
Definition: Factory.cxx:104