Logo ROOT   6.16/01
Reference Guide
Factory.cxx
Go to the documentation of this file.
1// @(#)Root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3// Updated by: Omar Zapata, Kim Albertsson
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : Factory *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation (see header for description) *
12 * *
13 * Authors : *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
16 * Peter Speckmayer <peter.speckmayer@cern.ch> - CERN, Switzerland *
17 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
19 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
20 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
21 * Omar Zapata <Omar.Zapata@cern.ch> - UdeA/ITM Colombia *
22 * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
23 * Sergei Gleyzer <Sergei.Gleyzer@cern.ch> - U of Florida & CERN *
24 * Kim Albertsson <kim.albertsson@cern.ch> - LTU & CERN *
25 * *
26 * Copyright (c) 2005-2015: *
27 * CERN, Switzerland *
28 * U. of Victoria, Canada *
29 * MPI-K Heidelberg, Germany *
30 * U. of Bonn, Germany *
31 * UdeA/ITM, Colombia *
32 * U. of Florida, USA *
33 * *
34 * Redistribution and use in source and binary forms, with or without *
35 * modification, are permitted according to the terms listed in LICENSE *
36 * (http://tmva.sourceforge.net/LICENSE) *
37 **********************************************************************************/
38
39/*! \class TMVA::Factory
40\ingroup TMVA
41
42This is the main MVA steering class.
43It creates all MVA methods, and guides them through the training, testing and
44evaluation phases.
45*/
46
47#include "TMVA/Factory.h"
48
50#include "TMVA/Config.h"
51#include "TMVA/Configurable.h"
52#include "TMVA/Tools.h"
53#include "TMVA/Ranking.h"
54#include "TMVA/DataSet.h"
55#include "TMVA/IMethod.h"
56#include "TMVA/MethodBase.h"
58#include "TMVA/DataSetManager.h"
59#include "TMVA/DataSetInfo.h"
60#include "TMVA/DataLoader.h"
61#include "TMVA/MethodBoost.h"
62#include "TMVA/MethodCategory.h"
63#include "TMVA/ROCCalc.h"
64#include "TMVA/ROCCurve.h"
65#include "TMVA/MsgLogger.h"
66
67#include "TMVA/VariableInfo.h"
69
70#include "TMVA/Results.h"
74#include <list>
75#include <bitset>
76
77#include "TMVA/Types.h"
78
79#include "TROOT.h"
80#include "TFile.h"
81#include "TTree.h"
82#include "TLeaf.h"
83#include "TEventList.h"
84#include "TH2.h"
85#include "TText.h"
86#include "TLegend.h"
87#include "TGraph.h"
88#include "TStyle.h"
89#include "TMatrixF.h"
90#include "TMatrixDSym.h"
91#include "TMultiGraph.h"
92#include "TPaletteAxis.h"
93#include "TPrincipal.h"
94#include "TMath.h"
95#include "TObjString.h"
96#include "TSystem.h"
97#include "TCanvas.h"
98
100//const Int_t MinNoTestEvents = 1;
101
103
104#define READXML kTRUE
105
106//number of bits for bitset
107#define VIBITS 32
108
109
110
111////////////////////////////////////////////////////////////////////////////////
112/// Standard constructor.
113///
114/// - jobname : this name will appear in all weight file names produced by the MVAs
115/// - theTargetFile : output ROOT file; the test tree and all evaluation plots
116/// will be stored here
117/// - theOption : option string; currently: "V" for verbose
118
119TMVA::Factory::Factory( TString jobName, TFile* theTargetFile, TString theOption )
120: Configurable ( theOption ),
121 fTransformations ( "I" ),
122 fVerbose ( kFALSE ),
123 fVerboseLevel ( kINFO ),
124 fCorrelations ( kFALSE ),
125 fROC ( kTRUE ),
126 fSilentFile ( kFALSE ),
127 fJobName ( jobName ),
128 fAnalysisType ( Types::kClassification ),
129 fModelPersistence (kTRUE)
130{
 // Store the caller-supplied output file in fgTargetFile.
131 fgTargetFile = theTargetFile;
133
134 // render silent
135 if (gTools().CheckForSilentOption( GetOptions() )) Log().InhibitOutput(); // make sure is silent if wanted to
136
137
138 // init configurable
139 SetConfigDescription( "Configuration options for Factory running" );
141
142 // histograms are not automatically associated with the current
143 // directory and hence don't go out of scope when closing the file
144 // TH1::AddDirectory(kFALSE);
 // silent, color and drawProgressBar are locals: they are declared as options
 // below and their values are handed to gConfig() near the end of the constructor.
145 Bool_t silent = kFALSE;
146#ifdef WIN32
147 // under Windows, switch progress bar and color off by default, as the typical windows shell doesn't handle these (would need different sequences..)
148 Bool_t color = kFALSE;
149 Bool_t drawProgressBar = kFALSE;
150#else
151 Bool_t color = !gROOT->IsBatch();
152 Bool_t drawProgressBar = kTRUE;
153#endif
 // Declare the Factory options (variable, option name, description) ahead of ParseOptions().
 // fVerboseLevel is set to "Info" here, before the option string is parsed.
154 DeclareOptionRef( fVerbose, "V", "Verbose flag" );
155 DeclareOptionRef( fVerboseLevel=TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)" );
156 AddPreDefVal(TString("Debug"));
157 AddPreDefVal(TString("Verbose"));
158 AddPreDefVal(TString("Info"));
159 DeclareOptionRef( color, "Color", "Flag for coloured screen output (default: True, if in batch mode: False)" );
160 DeclareOptionRef( fTransformations, "Transformations", "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation transformations" );
161 DeclareOptionRef( fCorrelations, "Correlations", "boolean to show correlation in output" );
162 DeclareOptionRef( fROC, "ROC", "boolean to show ROC in output" );
163 DeclareOptionRef( silent, "Silent", "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory class object (default: False)" );
164 DeclareOptionRef( drawProgressBar,
165 "DrawProgressBar", "Draw progress bar to display training, testing and evaluation schedule (default: True)" );
 // NOTE(review): the two lines below are the tail of a call whose opening line is
 // not present in this listing (presumably the declaration of the ModelPersistence
 // option variable) -- confirm against the upstream file.
167 "ModelPersistence",
168 "Option to save the trained model in xml file or using serialization");
169
 // The analysis type is read as a string option and translated to the enum further down.
170 TString analysisType("Auto");
171 DeclareOptionRef( analysisType,
172 "AnalysisType", "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)" );
173 AddPreDefVal(TString("Classification"));
174 AddPreDefVal(TString("Regression"));
175 AddPreDefVal(TString("Multiclass"));
176 AddPreDefVal(TString("Auto"));
177
178 ParseOptions();
180
 // Logger threshold: the checks run in sequence, so a later match overrides an earlier one.
 // NOTE(review): if ParseOptions() leaves fVerboseLevel at "Info" when the option is not
 // given, the last check resets the threshold to kINFO even when "V" was set -- confirm
 // whether that is intended.
181 if (Verbose()) fLogger->SetMinType( kVERBOSE );
182 if (fVerboseLevel.CompareTo("Debug") ==0) fLogger->SetMinType( kDEBUG );
183 if (fVerboseLevel.CompareTo("Verbose") ==0) fLogger->SetMinType( kVERBOSE );
184 if (fVerboseLevel.CompareTo("Info") ==0) fLogger->SetMinType( kINFO );
185
186 // global settings
187 gConfig().SetUseColor( color );
188 gConfig().SetSilent( silent );
189 gConfig().SetDrawProgressBar( drawProgressBar );
190
 // The option value is lower-cased before comparison, so matching is case-insensitive.
 // "auto" maps to kNoAnalysisType; a value matching none of the four branches leaves
 // fAnalysisType at its initialiser value (kClassification).
191 analysisType.ToLower();
192 if ( analysisType == "classification" ) fAnalysisType = Types::kClassification;
193 else if( analysisType == "regression" ) fAnalysisType = Types::kRegression;
194 else if( analysisType == "multiclass" ) fAnalysisType = Types::kMulticlass;
195 else if( analysisType == "auto" ) fAnalysisType = Types::kNoAnalysisType;
196
 // The welcome message call is commented out in this constructor.
197// Greetings();
198}
199
200////////////////////////////////////////////////////////////////////////////////
201/// Constructor.
202
204: Configurable ( theOption ),
205 fTransformations ( "I" ),
206 fVerbose ( kFALSE ),
207 fCorrelations ( kFALSE ),
208 fROC ( kTRUE ),
209 fSilentFile ( kTRUE ),
210 fJobName ( jobName ),
211 fAnalysisType ( Types::kClassification ),
212 fModelPersistence (kTRUE)
213{
214 fgTargetFile = 0;
216
217
218 // render silent
219 if (gTools().CheckForSilentOption( GetOptions() )) Log().InhibitOutput(); // make sure is silent if wanted to
220
221
222 // init configurable
223 SetConfigDescription( "Configuration options for Factory running" );
225
226 // histograms are not automatically associated with the current
227 // directory and hence don't go out of scope when closing the file
229 Bool_t silent = kFALSE;
230#ifdef WIN32
231 // under Windows, switch progress bar and color off by default, as the typical windows shell doesn't handle these (would need different sequences..)
232 Bool_t color = kFALSE;
233 Bool_t drawProgressBar = kFALSE;
234#else
235 Bool_t color = !gROOT->IsBatch();
236 Bool_t drawProgressBar = kTRUE;
237#endif
238 DeclareOptionRef( fVerbose, "V", "Verbose flag" );
239 DeclareOptionRef( fVerboseLevel=TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)" );
240 AddPreDefVal(TString("Debug"));
241 AddPreDefVal(TString("Verbose"));
242 AddPreDefVal(TString("Info"));
243 DeclareOptionRef( color, "Color", "Flag for coloured screen output (default: True, if in batch mode: False)" );
244 DeclareOptionRef( fTransformations, "Transformations", "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation transformations" );
245 DeclareOptionRef( fCorrelations, "Correlations", "boolean to show correlation in output" );
246 DeclareOptionRef( fROC, "ROC", "boolean to show ROC in output" );
247 DeclareOptionRef( silent, "Silent", "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory class object (default: False)" );
248 DeclareOptionRef( drawProgressBar,
249 "DrawProgressBar", "Draw progress bar to display training, testing and evaluation schedule (default: True)" );
251 "ModelPersistence",
252 "Option to save the trained model in xml file or using serialization");
253
254 TString analysisType("Auto");
255 DeclareOptionRef( analysisType,
256 "AnalysisType", "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)" );
257 AddPreDefVal(TString("Classification"));
258 AddPreDefVal(TString("Regression"));
259 AddPreDefVal(TString("Multiclass"));
260 AddPreDefVal(TString("Auto"));
261
262 ParseOptions();
264
265 if (Verbose()) fLogger->SetMinType( kVERBOSE );
266 if (fVerboseLevel.CompareTo("Debug") ==0) fLogger->SetMinType( kDEBUG );
267 if (fVerboseLevel.CompareTo("Verbose") ==0) fLogger->SetMinType( kVERBOSE );
268 if (fVerboseLevel.CompareTo("Info") ==0) fLogger->SetMinType( kINFO );
269
270 // global settings
271 gConfig().SetUseColor( color );
272 gConfig().SetSilent( silent );
273 gConfig().SetDrawProgressBar( drawProgressBar );
274
275 analysisType.ToLower();
276 if ( analysisType == "classification" ) fAnalysisType = Types::kClassification;
277 else if( analysisType == "regression" ) fAnalysisType = Types::kRegression;
278 else if( analysisType == "multiclass" ) fAnalysisType = Types::kMulticlass;
279 else if( analysisType == "auto" ) fAnalysisType = Types::kNoAnalysisType;
280
281 Greetings();
282}
283
284////////////////////////////////////////////////////////////////////////////////
285/// Print welcome message.
286/// Options are: kLogoWelcomeMsg, kIsometricWelcomeMsg, kLeanWelcomeMsg
287
289{
291 gTools().TMVAWelcomeMessage( Log(), gTools().kLogoWelcomeMsg );
292 gTools().TMVAVersionMessage( Log() ); Log() << Endl;
293}
294
295////////////////////////////////////////////////////////////////////////////////
296
298{
299 return fSilentFile;
300}
301
302////////////////////////////////////////////////////////////////////////////////
303
305{
306 return fModelPersistence;
307}
308
309////////////////////////////////////////////////////////////////////////////////
310/// Destructor.
311
313{
314 std::vector<TMVA::VariableTransformBase*>::iterator trfIt = fDefaultTrfs.begin();
315 for (;trfIt != fDefaultTrfs.end(); ++trfIt) delete (*trfIt);
316
317 this->DeleteAllMethods();
318
319
320 // problem with call of REGISTER_METHOD macro ...
321 // ClassifierFactory::DestroyInstance();
322 // Types::DestroyInstance();
323 //Tools::DestroyInstance();
324 //Config::DestroyInstance();
325}
326
327////////////////////////////////////////////////////////////////////////////////
328/// Delete methods.
329
331{
332 std::map<TString,MVector*>::iterator itrMap;
333
334 for(itrMap = fMethodsMap.begin();itrMap != fMethodsMap.end();++itrMap)
335 {
336 MVector *methods=itrMap->second;
337 // delete methods
338 MVector::iterator itrMethod = methods->begin();
339 for (; itrMethod != methods->end(); ++itrMethod) {
340 Log() << kDEBUG << "Delete method: " << (*itrMethod)->GetName() << Endl;
341 delete (*itrMethod);
342 }
343 methods->clear();
344 delete methods;
345 }
346}
347
348////////////////////////////////////////////////////////////////////////////////
349
351{
352 fVerbose = v;
353}
354
355////////////////////////////////////////////////////////////////////////////////
356/// Book a classifier or regression method.
357
358TMVA::MethodBase* TMVA::Factory::BookMethod( TMVA::DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption )
359{
 // Overview: creates a method through ClassifierFactory for the loader's default
 // DataSetInfo, configures it, and appends it to fMethodsMap under the loader's name.
 // Returns 0 when the created object is not a MethodBase or when the method reports
 // that it cannot handle the current analysis type.
360 if(fModelPersistence) gSystem->MakeDirectory(loader->GetName());//creating directory for DataLoader output
361
362 TString datasetname=loader->GetName();
363
 // Analysis type still undecided (kNoAnalysisType): derive it from the loader's classes.
364 if( fAnalysisType == Types::kNoAnalysisType ){
365 if( loader->DefaultDataSetInfo().GetNClasses()==2
366 && loader->DefaultDataSetInfo().GetClassInfo("Signal") != NULL
367 && loader->DefaultDataSetInfo().GetClassInfo("Background") != NULL
368 ){
369 fAnalysisType = Types::kClassification; // default is classification
370 } else if( loader->DefaultDataSetInfo().GetNClasses() >= 2 ){
371 fAnalysisType = Types::kMulticlass; // if two classes, but not named "Signal" and "Background"
372 } else
373 Log() << kFATAL << "No analysis type for " << loader->DefaultDataSetInfo().GetNClasses() << " classes and "
374 << loader->DefaultDataSetInfo().GetNTargets() << " regression targets." << Endl;
375 }
376
377 // booking via name; the names are translated into enums and the
378 // corresponding overloaded BookMethod is called
379
 // A title may be booked only once per dataset name.
380 if(fMethodsMap.find(datasetname)!=fMethodsMap.end())
381 {
382 if (GetMethod( datasetname,methodTitle ) != 0) {
383 Log() << kFATAL << "Booking failed since method with title <"
384 << methodTitle <<"> already exists "<< "in with DataSet Name <"<< loader->GetName()<<"> "
385 << Endl;
386 }
387 }
388
389
390 Log() << kHEADER << "Booking method: " << gTools().Color("bold") << methodTitle
391 // << gTools().Color("reset")<<" DataSet Name: "<<gTools().Color("bold")<<loader->GetName()
392 << gTools().Color("reset") << Endl << Endl;
393
394 // interpret option string with respect to a request for boosting (i.e., BostNum > 0)
 // A temporary Configurable is used only to read "Boost_num" from theOption; it is deleted right after parsing.
395 Int_t boostNum = 0;
396 TMVA::Configurable* conf = new TMVA::Configurable( theOption );
397 conf->DeclareOptionRef( boostNum = 0, "Boost_num",
398 "Number of times the classifier will be boosted" );
399 conf->ParseOptions();
400 delete conf;
 // Weight-file directory is "<loader name>/<configured weight file dir>"; left empty when persistence is off.
401 TString fFileDir;
402 if(fModelPersistence)
403 {
404 fFileDir=loader->GetName();
405 fFileDir+="/"+gConfig().GetIONames().fWeightFileDir;
406 }
407 // initialize methods
408 IMethod* im;
409 if (!boostNum) {
410 im = ClassifierFactory::Instance().Create(theMethodName.Data(), fJobName, methodTitle,
411 loader->DefaultDataSetInfo(), theOption);
412 }
413 else {
414 // boosted classifier, requires a specific definition, making it transparent for the user
415 Log() << kDEBUG <<"Boost Number is " << boostNum << " > 0: train boosted classifier" << Endl;
416 im = ClassifierFactory::Instance().Create("Boost", fJobName, methodTitle, loader->DefaultDataSetInfo(), theOption);
417 MethodBoost *methBoost = dynamic_cast<MethodBoost *>(im); // DSMTEST divided into two lines
 // NOTE(review): the message below names MethodCategory although the cast target is
 // MethodBoost; it is a runtime string and is left untouched here.
 // NOTE(review): methBoost is used right after this check; that is only safe if a
 // kFATAL log message does not return -- confirm.
418 if (!methBoost) // DSMTEST
419 Log() << kFATAL << "Method with type kBoost cannot be casted to MethodCategory. /Factory" << Endl; // DSMTEST
420
421 if (fModelPersistence)
422 methBoost->SetWeightFileDir(fFileDir);
423 methBoost->SetModelPersistence(fModelPersistence);
424 methBoost->SetBoostedMethodName(theMethodName); // DSMTEST divided into two lines
425 methBoost->fDataSetManager = loader->fDataSetManager; // DSMTEST
426 methBoost->SetFile(fgTargetFile);
427 methBoost->SetSilentFile(IsSilentFile());
428 }
429
430 MethodBase *method = dynamic_cast<MethodBase*>(im);
431 if (method==0) return 0; // could not create method
432
433 // set fDataSetManager if MethodCategory (to enable Category to create datasetinfo objects) // DSMTEST
434 if (method->GetMethodType() == Types::kCategory) { // DSMTEST
435 MethodCategory *methCat = (dynamic_cast<MethodCategory*>(im)); // DSMTEST
436 if (!methCat) // DSMTEST
437 Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Factory" << Endl; // DSMTEST
438
439 if(fModelPersistence) methCat->SetWeightFileDir(fFileDir);
440 methCat->SetModelPersistence(fModelPersistence);
441 methCat->fDataSetManager = loader->fDataSetManager; // DSMTEST
442 methCat->SetFile(fgTargetFile);
443 methCat->SetSilentFile(IsSilentFile());
444 } // DSMTEST
445
446
 // NOTE(review): one line of this argument list appears to be missing from this
 // listing (the embedded numbering skips a line) -- check against the upstream file.
447 if (!method->HasAnalysisType( fAnalysisType,
449 loader->DefaultDataSetInfo().GetNTargets() )) {
450 Log() << kWARNING << "Method " << method->GetMethodTypeName() << " is not capable of handling " ;
451 if (fAnalysisType == Types::kRegression) {
452 Log() << "regression with " << loader->DefaultDataSetInfo().GetNTargets() << " targets." << Endl;
453 }
454 else if (fAnalysisType == Types::kMulticlass ) {
455 Log() << "multiclass classification with " << loader->DefaultDataSetInfo().GetNClasses() << " classes." << Endl;
456 }
457 else {
458 Log() << "classification with " << loader->DefaultDataSetInfo().GetNClasses() << " classes." << Endl;
459 }
 // NOTE(review): the created method is neither stored nor deleted on this path --
 // possible leak; confirm who owns the object returned by ClassifierFactory::Create.
460 return 0;
461 }
462
 // Configure the method; the calls below run in this fixed order.
463 if(fModelPersistence) method->SetWeightFileDir(fFileDir);
464 method->SetModelPersistence(fModelPersistence);
465 method->SetAnalysisType( fAnalysisType );
466 method->SetupMethod();
467 method->ParseOptions();
468 method->ProcessSetup();
469 method->SetFile(fgTargetFile);
470 method->SetSilentFile(IsSilentFile());
471
472 // check-for-unused-options is performed; may be overridden by derived classes
473 method->CheckSetup();
474
 // First method booked for this dataset name: create its method list before appending.
475 if(fMethodsMap.find(datasetname)==fMethodsMap.end())
476 {
477 MVector *mvector=new MVector;
478 fMethodsMap[datasetname]=mvector;
479 }
480 fMethodsMap[datasetname]->push_back( method );
481 return method;
482}
483
484////////////////////////////////////////////////////////////////////////////////
485/// Books MVA method. The option configuration string is custom for each MVA
486/// the TString field "theNameAppendix" serves to define (and distinguish)
487/// several instances of a given MVA, eg, when one wants to compare the
488/// performance of various configurations
489
491{
492 return BookMethod(loader, Types::Instance().GetMethodName( theMethod ), methodTitle, theOption );
493}
494
495////////////////////////////////////////////////////////////////////////////////
496/// Adds an already constructed method to be managed by this factory.
497///
498/// \note Private.
499/// \note Know what you are doing when using this method. The method that you
500/// are loading could be trained already.
501///
502
504{
505 TString datasetname = loader->GetName();
506 std::string methodTypeName = std::string(Types::Instance().GetMethodName(methodType).Data());
507 DataSetInfo &dsi = loader->DefaultDataSetInfo();
508
509 IMethod *im = ClassifierFactory::Instance().Create(methodTypeName, dsi, weightfile );
510 MethodBase *method = (dynamic_cast<MethodBase*>(im));
511
512 if (method == nullptr) return nullptr;
513
514 if( method->GetMethodType() == Types::kCategory ){
515 Log() << kERROR << "Cannot handle category methods for now." << Endl;
516 }
517
518 TString fFileDir;
519 if(fModelPersistence) {
520 fFileDir=loader->GetName();
521 fFileDir+="/"+gConfig().GetIONames().fWeightFileDir;
522 }
523
524 if(fModelPersistence) method->SetWeightFileDir(fFileDir);
525 method->SetModelPersistence(fModelPersistence);
526 method->SetAnalysisType( fAnalysisType );
527 method->SetupMethod();
528 method->SetFile(fgTargetFile);
529 method->SetSilentFile(IsSilentFile());
530
532
533 // read weight file
534 method->ReadStateFromFile();
535
536 //method->CheckSetup();
537
538 TString methodTitle = method->GetName();
539 if (HasMethod(datasetname, methodTitle) != 0) {
540 Log() << kFATAL << "Booking failed since method with title <"
541 << methodTitle <<"> already exists "<< "in with DataSet Name <"<< loader->GetName()<<"> "
542 << Endl;
543 }
544
545 Log() << kINFO << "Booked classifier \"" << method->GetMethodName()
546 << "\" of type: \"" << method->GetMethodTypeName() << "\"" << Endl;
547
548 if(fMethodsMap.count(datasetname) == 0) {
549 MVector *mvector = new MVector;
550 fMethodsMap[datasetname] = mvector;
551 }
552
553 fMethodsMap[datasetname]->push_back( method );
554
555 return method;
556}
557
558////////////////////////////////////////////////////////////////////////////////
559/// Returns pointer to MVA that corresponds to given method title.
560
561TMVA::IMethod* TMVA::Factory::GetMethod(const TString& datasetname, const TString &methodTitle ) const
562{
563 if(fMethodsMap.find(datasetname)==fMethodsMap.end()) return 0;
564
565 MVector *methods=fMethodsMap.find(datasetname)->second;
566
567 MVector::const_iterator itrMethod;
568 //
569 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
570 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
571 if ( (mva->GetMethodName())==methodTitle ) return mva;
572 }
573 return 0;
574}
575
576////////////////////////////////////////////////////////////////////////////////
577/// Checks whether a given method name is defined for a given dataset.
578
579Bool_t TMVA::Factory::HasMethod(const TString& datasetname, const TString &methodTitle ) const
580{
581 if(fMethodsMap.find(datasetname)==fMethodsMap.end()) return 0;
582
583 std::string methodName = methodTitle.Data();
584 auto isEqualToMethodName = [&methodName](TMVA::IMethod * m) {
585 return ( 0 == methodName.compare( m->GetName() ) );
586 };
587
588 TMVA::Factory::MVector * methods = this->fMethodsMap.at(datasetname);
589 Bool_t isMethodNameExisting = std::any_of( methods->begin(), methods->end(), isEqualToMethodName);
590
591 return isMethodNameExisting;
592}
593
594////////////////////////////////////////////////////////////////////////////////
595
597{
598 RootBaseDir()->cd();
599
600 if(!RootBaseDir()->GetDirectory(fDataSetInfo.GetName())) RootBaseDir()->mkdir(fDataSetInfo.GetName());
601 else return; //loader is now in the output file, we dont need to save again
602
603 RootBaseDir()->cd(fDataSetInfo.GetName());
604 fDataSetInfo.GetDataSet(); // builds dataset (including calculation of correlation matrix)
605
606
607 // correlation matrix of the default DS
608 const TMatrixD* m(0);
609 const TH2* h(0);
610
611 if(fAnalysisType == Types::kMulticlass){
612 for (UInt_t cls = 0; cls < fDataSetInfo.GetNClasses() ; cls++) {
613 m = fDataSetInfo.CorrelationMatrix(fDataSetInfo.GetClassInfo(cls)->GetName());
614 h = fDataSetInfo.CreateCorrelationMatrixHist(m, TString("CorrelationMatrix")+fDataSetInfo.GetClassInfo(cls)->GetName(),
615 TString("Correlation Matrix (")+ fDataSetInfo.GetClassInfo(cls)->GetName() +TString(")"));
616 if (h!=0) {
617 h->Write();
618 delete h;
619 }
620 }
621 }
622 else{
623 m = fDataSetInfo.CorrelationMatrix( "Signal" );
624 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrixS", "Correlation Matrix (signal)");
625 if (h!=0) {
626 h->Write();
627 delete h;
628 }
629
630 m = fDataSetInfo.CorrelationMatrix( "Background" );
631 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrixB", "Correlation Matrix (background)");
632 if (h!=0) {
633 h->Write();
634 delete h;
635 }
636
637 m = fDataSetInfo.CorrelationMatrix( "Regression" );
638 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrix", "Correlation Matrix");
639 if (h!=0) {
640 h->Write();
641 delete h;
642 }
643 }
644
645 // some default transformations to evaluate
646 // NOTE: all transformations are destroyed after this test
647 TString processTrfs = "I"; //"I;N;D;P;U;G,D;"
648
649 // plus some user defined transformations
650 processTrfs = fTransformations;
651
652 // remove any trace of identity transform - if given (avoid to apply it twice)
653 std::vector<TMVA::TransformationHandler*> trfs;
654 TransformationHandler* identityTrHandler = 0;
655
656 std::vector<TString> trfsDef = gTools().SplitString(processTrfs,';');
657 std::vector<TString>::iterator trfsDefIt = trfsDef.begin();
658 for (; trfsDefIt!=trfsDef.end(); ++trfsDefIt) {
659 trfs.push_back(new TMVA::TransformationHandler(fDataSetInfo, "Factory"));
660 TString trfS = (*trfsDefIt);
661
662 //Log() << kINFO << Endl;
663 Log() << kDEBUG << "current transformation string: '" << trfS.Data() << "'" << Endl;
665 fDataSetInfo,
666 *(trfs.back()),
667 Log() );
668
669 if (trfS.BeginsWith('I')) identityTrHandler = trfs.back();
670 }
671
672 const std::vector<Event*>& inputEvents = fDataSetInfo.GetDataSet()->GetEventCollection();
673
674 // apply all transformations
675 std::vector<TMVA::TransformationHandler*>::iterator trfIt = trfs.begin();
676
677 for (;trfIt != trfs.end(); ++trfIt) {
678 // setting a Root dir causes the variables distributions to be saved to the root file
679 (*trfIt)->SetRootDir(RootBaseDir()->GetDirectory(fDataSetInfo.GetName()));// every dataloader have its own dir
680 (*trfIt)->CalcTransformations(inputEvents);
681 }
682 if(identityTrHandler) identityTrHandler->PrintVariableRanking();
683
684 // clean up
685 for (trfIt = trfs.begin(); trfIt != trfs.end(); ++trfIt) delete *trfIt;
686}
687
688////////////////////////////////////////////////////////////////////////////////
689/// Iterates through all booked methods and sees if they use parameter tuning and if so..
690/// does just that i.e. calls "Method::Train()" for different parameter settings and
691/// keeps in mind the "optimal one"... and that's the one that will later on be used
692/// in the main training loop.
693
694std::map<TString,Double_t> TMVA::Factory::OptimizeAllMethods(TString fomType, TString fitType)
695{
696
697 std::map<TString,MVector*>::iterator itrMap;
 // Receives the result of each OptimizeTuningParameters() call; it is overwritten per
 // method, so the value returned at the end is that of the last method optimised.
698 std::map<TString,Double_t> TunedParameters;
 // Outer loop: one entry of fMethodsMap per dataset name.
699 for(itrMap = fMethodsMap.begin();itrMap != fMethodsMap.end();++itrMap)
700 {
701 MVector *methods=itrMap->second;
702
703 MVector::iterator itrMethod;
704
705 // iterate over methods and optimize
706 for( itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod ) {
708 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
709 if (!mva) {
710 Log() << kFATAL << "Dynamic cast to MethodBase failed" <<Endl;
711 return TunedParameters;
712 }
713
 // NOTE(review): the condition that guards the warning below (and opens the brace
 // closed after 'continue') is not present in this listing; judging by the message
 // it compares the number of training events with MinNoTrainingEvents -- confirm
 // against the upstream file.
715 Log() << kWARNING << "Method " << mva->GetMethodName()
716 << " not trained (training tree has less entries ["
717 << mva->Data()->GetNTrainingEvents()
718 << "] than required [" << MinNoTrainingEvents << "]" << Endl;
719 continue;
720 }
721
722 Log() << kINFO << "Optimize method: " << mva->GetMethodName() << " for "
723 << (fAnalysisType == Types::kRegression ? "Regression" :
724 (fAnalysisType == Types::kMulticlass ? "Multiclass classification" : "Classification")) << Endl;
725
 // fomType and fitType are passed through unchanged to the method.
726 TunedParameters = mva->OptimizeTuningParameters(fomType,fitType);
727 Log() << kINFO << "Optimization of tuning parameters finished for Method:"<<mva->GetName() << Endl;
728 }
729 }
730
731 return TunedParameters;
732
733}
734
735////////////////////////////////////////////////////////////////////////////////
736/// Private method to generate a ROCCurve instance for a given method.
737/// Handles the conversion from TMVA ResultSet to a format the ROCCurve class
738/// understands.
739///
740/// \note You own the returned pointer.
741///
742
745{
746 return GetROC((TString)loader->GetName(), theMethodName, iClass, type);
747}
748
749////////////////////////////////////////////////////////////////////////////////
750/// Private method to generate a ROCCurve instance for a given method.
751/// Handles the conversion from TMVA ResultSet to a format the ROCCurve class
752/// understands.
753///
754/// \note You own the returned pointer.
755///
756
758{
759 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
760 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
761 return nullptr;
762 }
763
764 if (!this->HasMethod(datasetname, theMethodName)) {
765 Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data())
766 << Endl;
767 return nullptr;
768 }
769
770 std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
771 if (allowedAnalysisTypes.count(this->fAnalysisType) == 0) {
772 Log() << kERROR << Form("Can only generate ROC curves for analysis type kClassification and kMulticlass.")
773 << Endl;
774 return nullptr;
775 }
776
777 TMVA::MethodBase *method = dynamic_cast<TMVA::MethodBase *>(this->GetMethod(datasetname, theMethodName));
778 TMVA::DataSet *dataset = method->Data();
779 dataset->SetCurrentType(type);
780 TMVA::Results *results = dataset->GetResults(theMethodName, type, this->fAnalysisType);
781
782 UInt_t nClasses = method->DataInfo().GetNClasses();
783 if (this->fAnalysisType == Types::kMulticlass && iClass >= nClasses) {
784 Log() << kERROR << Form("Given class number (iClass = %i) does not exist. There are %i classes in dataset.",
785 iClass, nClasses)
786 << Endl;
787 return nullptr;
788 }
789
790 TMVA::ROCCurve *rocCurve = nullptr;
791 if (this->fAnalysisType == Types::kClassification) {
792
793 std::vector<Float_t> *mvaRes = dynamic_cast<ResultsClassification *>(results)->GetValueVector();
794 std::vector<Bool_t> *mvaResTypes = dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
795 std::vector<Float_t> mvaResWeights;
796
797 auto eventCollection = dataset->GetEventCollection(type);
798 mvaResWeights.reserve(eventCollection.size());
799 for (auto ev : eventCollection) {
800 mvaResWeights.push_back(ev->GetWeight());
801 }
802
803 rocCurve = new TMVA::ROCCurve(*mvaRes, *mvaResTypes, mvaResWeights);
804
805 } else if (this->fAnalysisType == Types::kMulticlass) {
806 std::vector<Float_t> mvaRes;
807 std::vector<Bool_t> mvaResTypes;
808 std::vector<Float_t> mvaResWeights;
809
810 std::vector<std::vector<Float_t>> *rawMvaRes = dynamic_cast<ResultsMulticlass *>(results)->GetValueVector();
811
812 // Vector transpose due to values being stored as
813 // [ [0, 1, 2], [0, 1, 2], ... ]
814 // in ResultsMulticlass::GetValueVector.
815 mvaRes.reserve(rawMvaRes->size());
816 for (auto item : *rawMvaRes) {
817 mvaRes.push_back(item[iClass]);
818 }
819
820 auto eventCollection = dataset->GetEventCollection(type);
821 mvaResTypes.reserve(eventCollection.size());
822 mvaResWeights.reserve(eventCollection.size());
823 for (auto ev : eventCollection) {
824 mvaResTypes.push_back(ev->GetClass() == iClass);
825 mvaResWeights.push_back(ev->GetWeight());
826 }
827
828 rocCurve = new TMVA::ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
829 }
830
831 return rocCurve;
832}
833
834////////////////////////////////////////////////////////////////////////////////
835/// Calculate the integral of the ROC curve, also known as the area under curve
836/// (AUC), for a given method.
837///
838/// Argument iClass specifies the class to generate the ROC curve in a
839/// multiclass setting. It is ignored for binary classification.
840///
841
843{
844 return GetROCIntegral((TString)loader->GetName(), theMethodName, iClass);
845}
846
847////////////////////////////////////////////////////////////////////////////////
848/// Calculate the integral of the ROC curve, also known as the area under curve
849/// (AUC), for a given method.
850///
851/// Argument iClass specifies the class to generate the ROC curve in a
852/// multiclass setting. It is ignored for binary classification.
853///
/// Returns 0 after logging a message when the dataset is not in the methods
/// map, when the method is not booked for that dataset, when the analysis type
/// is neither kClassification nor kMulticlass, or when GetROC returns no curve.
///
854
856{
   // The dataset must be a key of the methods map.
857 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
858 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
859 return 0;
860 }
861
   // The method must be booked for this dataset.
862 if ( ! this->HasMethod(datasetname, theMethodName) ) {
863 Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data()) << Endl;
864 return 0;
865 }
866
   // ROC integrals are only produced for (multiclass) classification.
867 std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
868 if ( allowedAnalysisTypes.count(this->fAnalysisType) == 0 ) {
869 Log() << kERROR << Form("Can only generate ROC integral for analysis type kClassification. and kMulticlass.")
870 << Endl;
871 return 0;
872 }
873
   // The curve object obtained here is deleted below, once the integral is computed.
874 TMVA::ROCCurve *rocCurve = GetROC(datasetname, theMethodName, iClass);
875 if (!rocCurve) {
876 Log() << kFATAL << Form("ROCCurve object was not created in Method = %s not found with Dataset = %s ",
877 theMethodName.Data(), datasetname.Data())
878 << Endl;
879 return 0;
880 }
881
   // NOTE(review): npoints is not declared in the visible lines; presumably it is
   // defined on the listing line missing just above (882) - confirm.
883 Double_t rocIntegral = rocCurve->GetROCIntegral(npoints);
884 delete rocCurve;
885
886 return rocIntegral;
887}
888
889////////////////////////////////////////////////////////////////////////////////
890/// Argument iClass specifies the class to generate the ROC curve in a
891/// multiclass setting. It is ignored for binary classification.
892///
893/// Returns a ROC graph for a given method, or nullptr on error.
894///
895/// Note: Evaluation of the given method must have been run prior to ROC
896/// generation through Factory::EvaluateAllMetods.
897///
898/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
899/// and the others considered background. This is ok in binary classification
900/// but in in multi class classification, the ROC surface is an N dimensional
901/// shape, where N is number of classes - 1.
902
903TGraph* TMVA::Factory::GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles, UInt_t iClass)
904{
905 return GetROCCurve( (TString)loader->GetName(), theMethodName, setTitles, iClass );
906}
907
908////////////////////////////////////////////////////////////////////////////////
909/// Argument iClass specifies the class to generate the ROC curve in a
910/// multiclass setting. It is ignored for binary classification.
911///
912/// Returns a ROC graph for a given method, or nullptr on error.
913///
914/// Note: Evaluation of the given method must have been run prior to ROC
915/// generation through Factory::EvaluateAllMetods.
916///
917/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
918/// and the others considered background. This is ok in binary classification
919/// but in in multi class classification, the ROC surface is an N dimensional
920/// shape, where N is number of classes - 1.
921
922TGraph* TMVA::Factory::GetROCCurve(TString datasetname, TString theMethodName, Bool_t setTitles, UInt_t iClass)
923{
924 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
925 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
926 return nullptr;
927 }
928
929 if ( ! this->HasMethod(datasetname, theMethodName) ) {
930 Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data()) << Endl;
931 return nullptr;
932 }
933
934 std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
935 if ( allowedAnalysisTypes.count(this->fAnalysisType) == 0 ) {
936 Log() << kERROR << Form("Can only generate ROC curves for analysis type kClassification and kMulticlass.") << Endl;
937 return nullptr;
938 }
939
940 TMVA::ROCCurve *rocCurve = GetROC(datasetname, theMethodName, iClass);
941 TGraph *graph = nullptr;
942
943 if ( ! rocCurve ) {
944 Log() << kFATAL << Form("ROCCurve object was not created in Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data()) << Endl;
945 return nullptr;
946 }
947
948 graph = (TGraph *)rocCurve->GetROCCurve()->Clone();
949 delete rocCurve;
950
951 if(setTitles) {
952 graph->GetYaxis()->SetTitle("Background rejection (Specificity)");
953 graph->GetXaxis()->SetTitle("Signal efficiency (Sensitivity)");
954 graph->SetTitle(Form("Signal efficiency vs. Background rejection (%s)", theMethodName.Data()));
955 }
956
957 return graph;
958}
959
960////////////////////////////////////////////////////////////////////////////////
961/// Generate a collection of graphs, for all methods for a given class. Suitable
962/// for comparing method performance.
963///
964/// Argument iClass specifies the class to generate the ROC curve in a
965/// multiclass setting. It is ignored for binary classification.
966///
/// This overload only resolves the dataset name through loader->GetName() and
/// forwards to the GetROCCurveAsMultiGraph overload taking a dataset name.
///
967/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
968/// and the others considered background. This is ok in binary classification
969/// but in multiclass classification, the ROC surface is an N dimensional
970/// shape, where N is number of classes - 1.
971
973{
   // loader is dereferenced without a null check; callers must pass a valid pointer.
974 return GetROCCurveAsMultiGraph((TString)loader->GetName(), iClass);
975}
976
977////////////////////////////////////////////////////////////////////////////////
978/// Generate a collection of graphs, for all methods for a given class. Suitable
979/// for comparing method performance.
980///
981/// Argument iClass specifies the class to generate the ROC curve in a
982/// multiclass setting. It is ignored for binary classification.
983///
/// Returns nullptr (after logging an error) when no graph was added to the
/// multigraph. Each added graph is titled with its method name and gets its
/// own line colour, starting at 1 and incremented per graph.
///
984/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
985/// and the others considered background. This is ok in binary classification
986/// but in multiclass classification, the ROC surface is an N dimensional
987/// shape, where N is number of classes - 1.
988
990{
991 UInt_t line_color = 1;
992
993 TMultiGraph *multigraph = new TMultiGraph();
994
   // NOTE(review): unlike the sibling functions there is no find() guard here.
   // operator[] on the map creates an entry holding a null pointer for an
   // unknown dataset, and that pointer is dereferenced by the loop below.
995 MVector *methods = fMethodsMap[datasetname.Data()];
996 for (auto * method_raw : *methods) {
997 TMVA::MethodBase *method = dynamic_cast<TMVA::MethodBase *>(method_raw);
   // Entries that are not MethodBase instances are skipped.
998 if (method == nullptr) { continue; }
999
1000 TString methodName = method->GetMethodName();
1001 UInt_t nClasses = method->DataInfo().GetNClasses();
1002
   // In multiclass mode, skip methods whose dataset has no class iClass.
1003 if ( this->fAnalysisType == Types::kMulticlass && iClass >= nClasses ) {
1004 Log() << kERROR << Form("Given class number (iClass = %i) does not exist. There are %i classes in dataset.", iClass, nClasses) << Endl;
1005 continue;
1006 }
1007
   // NOTE(review): className is never used after this line.
1008 TString className = method->DataInfo().GetClassInfo(iClass)->GetName();
1009
   // NOTE(review): GetROCCurve can return nullptr on error; graph is used
   // below without a check.
1010 TGraph *graph = this->GetROCCurve(datasetname, methodName, false, iClass);
1011 graph->SetTitle(methodName);
1012
1013 graph->SetLineWidth(2);
1014 graph->SetLineColor(line_color++);
1015 graph->SetFillColor(10);
1016
1017 multigraph->Add(graph);
1018 }
1019
   // NOTE(review): multigraph is not deleted on this error path, and "metohds"
   // in the message is a typo for "methods".
1020 if ( multigraph->GetListOfGraphs() == nullptr ) {
1021 Log() << kERROR << Form("No metohds have class %i defined.", iClass) << Endl;
1022 return nullptr;
1023 }
1024
1025 return multigraph;
1026}
1027
1028////////////////////////////////////////////////////////////////////////////////
1029/// Draws ROC curves for all methods booked with the factory for a given class
1030/// onto a canvas.
1031///
1032/// Argument iClass specifies the class to generate the ROC curve in a
1033/// multiclass setting. It is ignored for binary classification.
1034///
/// This overload only resolves the dataset name through loader->GetName() and
/// forwards to the GetROCCurve overload taking a dataset name and a class.
///
1035/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
1036/// and the others considered background. This is ok in binary classification
1037/// but in multiclass classification, the ROC surface is an N dimensional
1038/// shape, where N is number of classes - 1.
1039
1041{
   // loader is dereferenced without a null check; callers must pass a valid pointer.
1042 return GetROCCurve((TString)loader->GetName(), iClass);
1043}
1044
1045////////////////////////////////////////////////////////////////////////////////
1046/// Draws ROC curves for all methods booked with the factory for a given class.
1047///
1048/// Argument iClass specifies the class to generate the ROC curve in a
1049/// multiclass setting. It is ignored for binary classification.
1050///
/// Returns 0 (after logging an error) when the dataset is not in the methods
/// map. Otherwise a new canvas is returned; it is returned without any drawing
/// on it when GetROCCurveAsMultiGraph yields no multigraph.
///
1051/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
1052/// and the others considered background. This is ok in binary classification
1053/// but in multiclass classification, the ROC surface is an N dimensional
1054/// shape, where N is number of classes - 1.
1055
1057{
1058 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
1059 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
1060 return 0;
1061 }
1062
   // The canvas name embeds the dataset name and the class index.
1063 TString name = Form("ROCCurve %s class %i", datasetname.Data(), iClass);
1064 TCanvas *canvas = new TCanvas(name, "ROC Curve", 200, 10, 700, 500);
1065 canvas->SetGrid();
1066
1067 TMultiGraph *multigraph = this->GetROCCurveAsMultiGraph(datasetname, iClass);
1068
1069 if ( multigraph ) {
   // The axis titles are set after Draw().
   // NOTE(review): presumably the axes only exist once the graph is drawn - confirm.
1070 multigraph->Draw("AL");
1071
1072 multigraph->GetYaxis()->SetTitle("Background rejection (Specificity)");
1073 multigraph->GetXaxis()->SetTitle("Signal efficiency (Sensitivity)");
1074
   // The class index is appended to the title only in multiclass mode.
1075 TString titleString = Form("Signal efficiency vs. Background rejection");
1076 if (this->fAnalysisType == Types::kMulticlass) {
1077 titleString = Form("%s (Class=%i)", titleString.Data(), iClass);
1078 }
1079
1080 // Workaround for TMultigraph not drawing title correctly.
1081 multigraph->GetHistogram()->SetTitle( titleString );
1082 multigraph->SetTitle( titleString );
1083
1084 canvas->BuildLegend(0.15, 0.15, 0.35, 0.3, "MVA Method");
1085 }
1086
1087 return canvas;
1088}
1089
1090////////////////////////////////////////////////////////////////////////////////
1091/// Iterates through all booked methods and calls training
///
/// Returns early (after an info message) when no method is booked. For every
/// dataset in the methods map the function then
///  - trains each booked method through TrainMethod(),
///  - prints the variable ranking of each method unless the analysis type is
///    kRegression,
///  - when fModelPersistence is set, deletes each method and recreates it
///    from its weight file.
1092
1094{
1095 Log() << kHEADER << gTools().Color("bold") << "Train all methods" << gTools().Color("reset") << Endl;
1096 // iterates over all MVAs that have been booked, and calls their training methods
1097
1098
1099 // don't do anything if no method booked
1100 if (fMethodsMap.empty()) {
1101 Log() << kINFO << "...nothing found to train" << Endl;
1102 return;
1103 }
1104
1105 // here the training starts
1106 //Log() << kINFO << " " << Endl;
1107 Log() << kDEBUG << "Train all methods for "
1108 << (fAnalysisType == Types::kRegression ? "Regression" :
1109 (fAnalysisType == Types::kMulticlass ? "Multiclass" : "Classification") ) << " ..." << Endl;
1110
1111 std::map<TString,MVector*>::iterator itrMap;
1112
   // outer loop: one entry of the methods map per dataset
1113 for(itrMap = fMethodsMap.begin();itrMap != fMethodsMap.end();++itrMap)
1114 {
1115 MVector *methods=itrMap->second;
1116 MVector::iterator itrMethod;
1117
1118 // iterate over methods and train
1119 for( itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod ) {
1121 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
1122
   // entries that are not MethodBase instances are skipped
1123 if(mva==0) continue;
1124
1125 if(mva->DataInfo().GetDataSetManager()->DataInput().GetEntries() <=1) { // 0 entries --> 0 events, 1 entry --> dynamical dataset (or one entry)
1126 Log() << kFATAL << "No input data for the training provided!" << Endl;
1127 }
1128
   // sanity checks: regression needs a target, classification at least two classes
1129 if(fAnalysisType == Types::kRegression && mva->DataInfo().GetNTargets() < 1 )
1130 Log() << kFATAL << "You want to do regression training without specifying a target." << Endl;
1131 else if( (fAnalysisType == Types::kMulticlass || fAnalysisType == Types::kClassification)
1132 && mva->DataInfo().GetNClasses() < 2 )
1133 Log() << kFATAL << "You want to do classification training, but specified less than two classes." << Endl;
1134
1135 // first print some information about the default dataset
1136 if(!IsSilentFile()) WriteDataInformation(mva->fDataSetInfo);
1137
1138
   // NOTE(review): the condition opening the block below is not present in this
   // listing (the numbering jumps from 1138 to 1140); judging by the message it
   // compares GetNTrainingEvents() with MinNoTrainingEvents - confirm.
1140 Log() << kWARNING << "Method " << mva->GetMethodName()
1141 << " not trained (training tree has less entries ["
1142 << mva->Data()->GetNTrainingEvents()
1143 << "] than required [" << MinNoTrainingEvents << "]" << Endl;
1144 continue;
1145 }
1146
1147 Log() << kHEADER << "Train method: " << mva->GetMethodName() << " for "
1148 << (fAnalysisType == Types::kRegression ? "Regression" :
1149 (fAnalysisType == Types::kMulticlass ? "Multiclass classification" : "Classification")) << Endl << Endl;
1150 mva->TrainMethod();
1151 Log() << kHEADER << "Training finished" << Endl << Endl;
1152 }
1153
1154 if (fAnalysisType != Types::kRegression) {
1155
1156 // variable ranking
1157 //Log() << Endl;
1158 Log() << kINFO << "Ranking input variables (method specific)..." << Endl;
1159 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1160 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
   // only methods with at least MinNoTrainingEvents training events are ranked
1161 if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
1162
1163 // create and print ranking
1164 const Ranking* ranking = (*itrMethod)->CreateRanking();
1165 if (ranking != 0) ranking->Print();
1166 else Log() << kINFO << "No variable ranking supplied by classifier: "
1167 << dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
1168 }
1169 }
1170 }
1171
1172 // delete all methods and recreate them from weight file - this ensures that the application
1173 // of the methods (in TMVAClassificationApplication) is consistent with the results obtained
1174 // in the testing
1175 //Log() << Endl;
1176 if (fModelPersistence) {
1177
1178 Log() << kHEADER << "=== Destroy and recreate all methods via weight files for testing ===" << Endl << Endl;
1179
1180 if(!IsSilentFile())RootBaseDir()->cd();
1181
1182 // iterate through all booked methods
1183 for (UInt_t i=0; i<methods->size(); i++) {
1184
1185 MethodBase* m = dynamic_cast<MethodBase*>((*methods)[i]);
1186 if(m==0) continue;
1187
   // everything still needed from the old instance is read before it is deleted
1188 TMVA::Types::EMVA methodType = m->GetMethodType();
1189 TString weightfile = m->GetWeightFileName();
1190
1191 // decide if .txt or .xml file should be read:
1192 if (READXML) weightfile.ReplaceAll(".txt",".xml");
1193
1194 DataSetInfo& dataSetInfo = m->DataInfo();
1195 TString testvarName = m->GetTestvarName();
   // (*methods)[i] still holds the deleted pointer until it is overwritten at
   // the end of this iteration
1196 delete m; //itrMethod[i];
1197
1198 // recreate
   // NOTE(review): the result of the dynamic_cast is dereferenced on the next
   // line without a null check.
1199 m = dynamic_cast<MethodBase *>(ClassifierFactory::Instance().Create(
1200 Types::Instance().GetMethodName(methodType).Data(), dataSetInfo, weightfile));
1201 if( m->GetMethodType() == Types::kCategory ){
1202 MethodCategory *methCat = (dynamic_cast<MethodCategory*>(m));
1203 if( !methCat ) Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Factory" << Endl;
1204 else methCat->fDataSetManager = m->DataInfo().GetDataSetManager();
1205 }
1206 //ToDo, Do we need to fill the DataSetManager of MethodBoost here too?
1207
1208
   // weight file directory: <dataset name>/<configured weight file dir>
1209 TString fFileDir= m->DataInfo().GetName();
1210 fFileDir+="/"+gConfig().GetIONames().fWeightFileDir;
1211 m->SetWeightFileDir(fFileDir);
1212 m->SetModelPersistence(fModelPersistence);
1213 m->SetSilentFile(IsSilentFile());
1214 m->SetAnalysisType(fAnalysisType);
1215 m->SetupMethod();
1216 m->ReadStateFromFile();
1217 m->SetTestvarName(testvarName);
1218
1219 // replace trained method by newly created one (from weight file) in methods vector
1220 (*methods)[i] = m;
1221 }
1222 }
1223 }
1224}
1225
1226////////////////////////////////////////////////////////////////////////////////
1227/// Evaluates all booked methods on the testing data and adds the output to the
1228/// Results in the corresponding DataSet.
1229///
/// Returns early (after an info message) when no method is booked. Each method
/// is processed with its own analysis type, as reported by GetAnalysisType().
1230
1232{
1233 Log() << kHEADER << gTools().Color("bold") << "Test all methods" << gTools().Color("reset") << Endl;
1234
1235 // don't do anything if no method booked
1236 if (fMethodsMap.empty()) {
1237 Log() << kINFO << "...nothing found to test" << Endl;
1238 return;
1239 }
1240 std::map<TString,MVector*>::iterator itrMap;
1241
   // outer loop: one entry of the methods map per dataset
1242 for(itrMap = fMethodsMap.begin();itrMap != fMethodsMap.end();++itrMap)
1243 {
1244 MVector *methods=itrMap->second;
1245 MVector::iterator itrMethod;
1246
1247 // iterate over methods and test
1248 for( itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod ) {
1250 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
   // entries that are not MethodBase instances are skipped
1251 if(mva==0) continue;
1252 Types::EAnalysisType analysisType = mva->GetAnalysisType();
1253 Log() << kHEADER << "Test method: " << mva->GetMethodName() << " for "
1254 << (analysisType == Types::kRegression ? "Regression" :
1255 (analysisType == Types::kMulticlass ? "Multiclass classification" : "Classification")) << " performance" << Endl << Endl;
   // adds the method's output for the kTesting events
1256 mva->AddOutput( Types::kTesting, analysisType );
1257 }
1258 }
1259}
1260
1261////////////////////////////////////////////////////////////////////////////////
1262
1263void TMVA::Factory::MakeClass(const TString& datasetname , const TString& methodTitle ) const
1264{
1265 if (methodTitle != "") {
1266 IMethod* method = GetMethod(datasetname, methodTitle);
1267 if (method) method->MakeClass();
1268 else {
1269 Log() << kWARNING << "<MakeClass> Could not find classifier \"" << methodTitle
1270 << "\" in list" << Endl;
1271 }
1272 }
1273 else {
1274
1275 // no classifier specified, print all help messages
1276 MVector *methods=fMethodsMap.find(datasetname)->second;
1277 MVector::const_iterator itrMethod;
1278 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1279 MethodBase* method = dynamic_cast<MethodBase*>(*itrMethod);
1280 if(method==0) continue;
1281 Log() << kINFO << "Make response class for classifier: " << method->GetMethodName() << Endl;
1282 method->MakeClass();
1283 }
1284 }
1285}
1286
1287////////////////////////////////////////////////////////////////////////////////
1288/// Print predefined help message of classifier.
1289/// Iterate over methods and test.
1290
1291void TMVA::Factory::PrintHelpMessage(const TString& datasetname , const TString& methodTitle ) const
1292{
1293 if (methodTitle != "") {
1294 IMethod* method = GetMethod(datasetname , methodTitle );
1295 if (method) method->PrintHelpMessage();
1296 else {
1297 Log() << kWARNING << "<PrintHelpMessage> Could not find classifier \"" << methodTitle
1298 << "\" in list" << Endl;
1299 }
1300 }
1301 else {
1302
1303 // no classifier specified, print all help messages
1304 MVector *methods=fMethodsMap.find(datasetname)->second;
1305 MVector::const_iterator itrMethod ;
1306 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1307 MethodBase* method = dynamic_cast<MethodBase*>(*itrMethod);
1308 if(method==0) continue;
1309 Log() << kINFO << "Print help message for classifier: " << method->GetMethodName() << Endl;
1310 method->PrintHelpMessage();
1311 }
1312 }
1313}
1314
1315////////////////////////////////////////////////////////////////////////////////
1316/// Iterates over all MVA input variables and evaluates them.
///
/// For each variable of the loader's default dataset info, one method is
/// booked through BookMethod(loader, "Variable", s). ":V" is appended to s
/// when options contains "V".
1317
1319{
1320 Log() << kINFO << "Evaluating all variables..." << Endl;
1322
1323 for (UInt_t i=0; i<loader->DefaultDataSetInfo().GetNVariables(); i++) {
   // NOTE(review): s is not declared in the visible lines; presumably it is set
   // up per variable on the listing line missing just above (1324) - confirm.
1325 if (options.Contains("V")) s += ":V";
1326 this->BookMethod(loader, "Variable", s );
1327 }
1328}
1329
1330////////////////////////////////////////////////////////////////////////////////
1331/// Iterates over all MVAs that have been booked, and calls their evaluation methods.
1332
1334{
1335 Log() << kHEADER << gTools().Color("bold") << "Evaluate all methods" << gTools().Color("reset") << Endl;
1336
1337 // don't do anything if no method booked
1338 if (fMethodsMap.empty()) {
1339 Log() << kINFO << "...nothing found to evaluate" << Endl;
1340 return;
1341 }
1342 std::map<TString,MVector*>::iterator itrMap;
1343
1344 for(itrMap = fMethodsMap.begin();itrMap != fMethodsMap.end();++itrMap)
1345 {
1346 MVector *methods=itrMap->second;
1347
1348 // -----------------------------------------------------------------------
1349 // First part of evaluation process
1350 // --> compute efficiencies, and other separation estimators
1351 // -----------------------------------------------------------------------
1352
1353 // although equal, we now want to separate the output for the variables
1354 // and the real methods
1355 Int_t isel; // will be 0 for a Method; 1 for a Variable
1356 Int_t nmeth_used[2] = {0,0}; // 0 Method; 1 Variable
1357
1358 std::vector<std::vector<TString> > mname(2);
1359 std::vector<std::vector<Double_t> > sig(2), sep(2), roc(2);
1360 std::vector<std::vector<Double_t> > eff01(2), eff10(2), eff30(2), effArea(2);
1361 std::vector<std::vector<Double_t> > eff01err(2), eff10err(2), eff30err(2);
1362 std::vector<std::vector<Double_t> > trainEff01(2), trainEff10(2), trainEff30(2);
1363
1364 std::vector<std::vector<Float_t> > multiclass_testEff;
1365 std::vector<std::vector<Float_t> > multiclass_trainEff;
1366 std::vector<std::vector<Float_t> > multiclass_testPur;
1367 std::vector<std::vector<Float_t> > multiclass_trainPur;
1368
1369 // Multiclass confusion matrices.
1370 std::vector<TMatrixD> multiclass_trainConfusionEffB01;
1371 std::vector<TMatrixD> multiclass_trainConfusionEffB10;
1372 std::vector<TMatrixD> multiclass_trainConfusionEffB30;
1373 std::vector<TMatrixD> multiclass_testConfusionEffB01;
1374 std::vector<TMatrixD> multiclass_testConfusionEffB10;
1375 std::vector<TMatrixD> multiclass_testConfusionEffB30;
1376
1377 std::vector<std::vector<Double_t> > biastrain(1); // "bias" of the regression on the training data
1378 std::vector<std::vector<Double_t> > biastest(1); // "bias" of the regression on test data
1379 std::vector<std::vector<Double_t> > devtrain(1); // "dev" of the regression on the training data
1380 std::vector<std::vector<Double_t> > devtest(1); // "dev" of the regression on test data
1381 std::vector<std::vector<Double_t> > rmstrain(1); // "rms" of the regression on the training data
1382 std::vector<std::vector<Double_t> > rmstest(1); // "rms" of the regression on test data
1383 std::vector<std::vector<Double_t> > minftrain(1); // "minf" of the regression on the training data
1384 std::vector<std::vector<Double_t> > minftest(1); // "minf" of the regression on test data
1385 std::vector<std::vector<Double_t> > rhotrain(1); // correlation of the regression on the training data
1386 std::vector<std::vector<Double_t> > rhotest(1); // correlation of the regression on test data
1387
1388 // same as above but for 'truncated' quantities (computed for events within 2sigma of RMS)
1389 std::vector<std::vector<Double_t> > biastrainT(1);
1390 std::vector<std::vector<Double_t> > biastestT(1);
1391 std::vector<std::vector<Double_t> > devtrainT(1);
1392 std::vector<std::vector<Double_t> > devtestT(1);
1393 std::vector<std::vector<Double_t> > rmstrainT(1);
1394 std::vector<std::vector<Double_t> > rmstestT(1);
1395 std::vector<std::vector<Double_t> > minftrainT(1);
1396 std::vector<std::vector<Double_t> > minftestT(1);
1397
1398 // following vector contains all methods - with the exception of Cuts, which are special
1399 MVector methodsNoCuts;
1400
1401 Bool_t doRegression = kFALSE;
1402 Bool_t doMulticlass = kFALSE;
1403
1404 // iterate over methods and evaluate
1405 for (MVector::iterator itrMethod =methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1407 MethodBase* theMethod = dynamic_cast<MethodBase*>(*itrMethod);
1408 if(theMethod==0) continue;
1409 theMethod->SetFile(fgTargetFile);
1410 theMethod->SetSilentFile(IsSilentFile());
1411 if (theMethod->GetMethodType() != Types::kCuts) methodsNoCuts.push_back( *itrMethod );
1412
1413 if (theMethod->DoRegression()) {
1414 doRegression = kTRUE;
1415
1416 Log() << kINFO << "Evaluate regression method: " << theMethod->GetMethodName() << Endl;
1417 Double_t bias, dev, rms, mInf;
1418 Double_t biasT, devT, rmsT, mInfT;
1419 Double_t rho;
1420
1421 Log() << kINFO << "TestRegression (testing)" << Endl;
1422 theMethod->TestRegression( bias, biasT, dev, devT, rms, rmsT, mInf, mInfT, rho, TMVA::Types::kTesting );
1423 biastest[0] .push_back( bias );
1424 devtest[0] .push_back( dev );
1425 rmstest[0] .push_back( rms );
1426 minftest[0] .push_back( mInf );
1427 rhotest[0] .push_back( rho );
1428 biastestT[0] .push_back( biasT );
1429 devtestT[0] .push_back( devT );
1430 rmstestT[0] .push_back( rmsT );
1431 minftestT[0] .push_back( mInfT );
1432
1433 Log() << kINFO << "TestRegression (training)" << Endl;
1434 theMethod->TestRegression( bias, biasT, dev, devT, rms, rmsT, mInf, mInfT, rho, TMVA::Types::kTraining );
1435 biastrain[0] .push_back( bias );
1436 devtrain[0] .push_back( dev );
1437 rmstrain[0] .push_back( rms );
1438 minftrain[0] .push_back( mInf );
1439 rhotrain[0] .push_back( rho );
1440 biastrainT[0].push_back( biasT );
1441 devtrainT[0] .push_back( devT );
1442 rmstrainT[0] .push_back( rmsT );
1443 minftrainT[0].push_back( mInfT );
1444
1445 mname[0].push_back( theMethod->GetMethodName() );
1446 nmeth_used[0]++;
1447 if(!IsSilentFile())
1448 {
1449 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
1452 }
1453 } else if (theMethod->DoMulticlass()) {
1454 // ====================================================================
1455 // === Multiclass evaluation
1456 // ====================================================================
1457 doMulticlass = kTRUE;
1458 Log() << kINFO << "Evaluate multiclass classification method: " << theMethod->GetMethodName() << Endl;
1459
1460 // This part uses a genetic alg. to evaluate the optimal sig eff * sig pur.
1461 // This is why it is disabled for now.
1462 // Find approximate optimal working point w.r.t. signalEfficiency * signalPurity.
1463 // theMethod->TestMulticlass(); // This is where the actual GA calc is done
1464 // multiclass_testEff.push_back(theMethod->GetMulticlassEfficiency(multiclass_testPur));
1465
1466 theMethod->TestMulticlass();
1467
1468 // Confusion matrix at three background efficiency levels
1469 multiclass_trainConfusionEffB01.push_back(theMethod->GetMulticlassConfusionMatrix(0.01, Types::kTraining));
1470 multiclass_trainConfusionEffB10.push_back(theMethod->GetMulticlassConfusionMatrix(0.10, Types::kTraining));
1471 multiclass_trainConfusionEffB30.push_back(theMethod->GetMulticlassConfusionMatrix(0.30, Types::kTraining));
1472
1473 multiclass_testConfusionEffB01.push_back(theMethod->GetMulticlassConfusionMatrix(0.01, Types::kTesting));
1474 multiclass_testConfusionEffB10.push_back(theMethod->GetMulticlassConfusionMatrix(0.10, Types::kTesting));
1475 multiclass_testConfusionEffB30.push_back(theMethod->GetMulticlassConfusionMatrix(0.30, Types::kTesting));
1476
1477 if (not IsSilentFile()) {
1478 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
1481 }
1482
1483 nmeth_used[0]++;
1484 mname[0].push_back(theMethod->GetMethodName());
1485 } else {
1486
1487 Log() << kHEADER << "Evaluate classifier: " << theMethod->GetMethodName() << Endl << Endl;
1488 isel = (theMethod->GetMethodTypeName().Contains("Variable")) ? 1 : 0;
1489
1490 // perform the evaluation
1491 theMethod->TestClassification();
1492
1493 // evaluate the classifier
1494 mname[isel].push_back(theMethod->GetMethodName());
1495 sig[isel].push_back(theMethod->GetSignificance());
1496 sep[isel].push_back(theMethod->GetSeparation());
1497 roc[isel].push_back(theMethod->GetROCIntegral());
1498
1499 Double_t err;
1500 eff01[isel].push_back(theMethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err));
1501 eff01err[isel].push_back(err);
1502 eff10[isel].push_back(theMethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err));
1503 eff10err[isel].push_back(err);
1504 eff30[isel].push_back(theMethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err));
1505 eff30err[isel].push_back(err);
1506 effArea[isel].push_back(theMethod->GetEfficiency("", Types::kTesting, err)); // computes the area (average)
1507
1508 trainEff01[isel].push_back(theMethod->GetTrainingEfficiency("Efficiency:0.01")); // the first pass takes longer
1509 trainEff10[isel].push_back(theMethod->GetTrainingEfficiency("Efficiency:0.10"));
1510 trainEff30[isel].push_back(theMethod->GetTrainingEfficiency("Efficiency:0.30"));
1511
1512 nmeth_used[isel]++;
1513
1514 if (!IsSilentFile()) {
1515 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
1518 }
1519 }
1520 }
1521 if (doRegression) {
1522
1523 std::vector<TString> vtemps = mname[0];
1524 std::vector< std::vector<Double_t> > vtmp;
1525 vtmp.push_back( devtest[0] ); // this is the vector that is ranked
1526 vtmp.push_back( devtrain[0] );
1527 vtmp.push_back( biastest[0] );
1528 vtmp.push_back( biastrain[0] );
1529 vtmp.push_back( rmstest[0] );
1530 vtmp.push_back( rmstrain[0] );
1531 vtmp.push_back( minftest[0] );
1532 vtmp.push_back( minftrain[0] );
1533 vtmp.push_back( rhotest[0] );
1534 vtmp.push_back( rhotrain[0] );
1535 vtmp.push_back( devtestT[0] ); // this is the vector that is ranked
1536 vtmp.push_back( devtrainT[0] );
1537 vtmp.push_back( biastestT[0] );
1538 vtmp.push_back( biastrainT[0]);
1539 vtmp.push_back( rmstestT[0] );
1540 vtmp.push_back( rmstrainT[0] );
1541 vtmp.push_back( minftestT[0] );
1542 vtmp.push_back( minftrainT[0]);
1543 gTools().UsefulSortAscending( vtmp, &vtemps );
1544 mname[0] = vtemps;
1545 devtest[0] = vtmp[0];
1546 devtrain[0] = vtmp[1];
1547 biastest[0] = vtmp[2];
1548 biastrain[0] = vtmp[3];
1549 rmstest[0] = vtmp[4];
1550 rmstrain[0] = vtmp[5];
1551 minftest[0] = vtmp[6];
1552 minftrain[0] = vtmp[7];
1553 rhotest[0] = vtmp[8];
1554 rhotrain[0] = vtmp[9];
1555 devtestT[0] = vtmp[10];
1556 devtrainT[0] = vtmp[11];
1557 biastestT[0] = vtmp[12];
1558 biastrainT[0] = vtmp[13];
1559 rmstestT[0] = vtmp[14];
1560 rmstrainT[0] = vtmp[15];
1561 minftestT[0] = vtmp[16];
1562 minftrainT[0] = vtmp[17];
1563 } else if (doMulticlass) {
1564 // TODO: fill in something meaningful
1565 // If there is some ranking of methods to be done it should be done here.
1566 // However, this is not so easy to define for multiclass so it is left out for now.
1567
1568 }
1569 else {
1570 // now sort the variables according to the best 'eff at Beff=0.10'
1571 for (Int_t k=0; k<2; k++) {
1572 std::vector< std::vector<Double_t> > vtemp;
1573 vtemp.push_back( effArea[k] ); // this is the vector that is ranked
1574 vtemp.push_back( eff10[k] );
1575 vtemp.push_back( eff01[k] );
1576 vtemp.push_back( eff30[k] );
1577 vtemp.push_back( eff10err[k] );
1578 vtemp.push_back( eff01err[k] );
1579 vtemp.push_back( eff30err[k] );
1580 vtemp.push_back( trainEff10[k] );
1581 vtemp.push_back( trainEff01[k] );
1582 vtemp.push_back( trainEff30[k] );
1583 vtemp.push_back( sig[k] );
1584 vtemp.push_back( sep[k] );
1585 vtemp.push_back( roc[k] );
1586 std::vector<TString> vtemps = mname[k];
1587 gTools().UsefulSortDescending( vtemp, &vtemps );
1588 effArea[k] = vtemp[0];
1589 eff10[k] = vtemp[1];
1590 eff01[k] = vtemp[2];
1591 eff30[k] = vtemp[3];
1592 eff10err[k] = vtemp[4];
1593 eff01err[k] = vtemp[5];
1594 eff30err[k] = vtemp[6];
1595 trainEff10[k] = vtemp[7];
1596 trainEff01[k] = vtemp[8];
1597 trainEff30[k] = vtemp[9];
1598 sig[k] = vtemp[10];
1599 sep[k] = vtemp[11];
1600 roc[k] = vtemp[12];
1601 mname[k] = vtemps;
1602 }
1603 }
1604
1605 // -----------------------------------------------------------------------
1606 // Second part of evaluation process
1607 // --> compute correlations among MVAs
1608 // --> compute correlations between input variables and MVA (determines importance)
1609 // --> count overlaps
1610 // -----------------------------------------------------------------------
1611 if(fCorrelations)
1612 {
1613 const Int_t nmeth = methodsNoCuts.size();
1614 MethodBase* method = dynamic_cast<MethodBase*>(methods[0][0]);
1615 const Int_t nvar = method->fDataSetInfo.GetNVariables();
1616 if (!doRegression && !doMulticlass ) {
1617
1618 if (nmeth > 0) {
1619
1620 // needed for correlations
1621 Double_t *dvec = new Double_t[nmeth+nvar];
1622 std::vector<Double_t> rvec;
1623
1624 // for correlations
1625 TPrincipal* tpSig = new TPrincipal( nmeth+nvar, "" );
1626 TPrincipal* tpBkg = new TPrincipal( nmeth+nvar, "" );
1627
1628 // set required tree branch references
1629 Int_t ivar = 0;
1630 std::vector<TString>* theVars = new std::vector<TString>;
1631 std::vector<ResultsClassification*> mvaRes;
1632 for (MVector::iterator itrMethod = methodsNoCuts.begin(); itrMethod != methodsNoCuts.end(); ++itrMethod, ++ivar) {
1633 MethodBase* m = dynamic_cast<MethodBase*>(*itrMethod);
1634 if(m==0) continue;
1635 theVars->push_back( m->GetTestvarName() );
1636 rvec.push_back( m->GetSignalReferenceCut() );
1637 theVars->back().ReplaceAll( "MVA_", "" );
1638 mvaRes.push_back( dynamic_cast<ResultsClassification*>( m->Data()->GetResults( m->GetMethodName(),
1641 }
1642
1643 // for overlap study
1644 TMatrixD* overlapS = new TMatrixD( nmeth, nmeth );
1645 TMatrixD* overlapB = new TMatrixD( nmeth, nmeth );
1646 (*overlapS) *= 0; // init...
1647 (*overlapB) *= 0; // init...
1648
1649 // loop over test tree
1650 DataSet* defDs = method->fDataSetInfo.GetDataSet();
1652 for (Int_t ievt=0; ievt<defDs->GetNEvents(); ievt++) {
1653 const Event* ev = defDs->GetEvent(ievt);
1654
1655 // for correlations
1656 TMatrixD* theMat = 0;
1657 for (Int_t im=0; im<nmeth; im++) {
1658 // check for NaN value
1659 Double_t retval = (Double_t)(*mvaRes[im])[ievt][0];
1660 if (TMath::IsNaN(retval)) {
1661 Log() << kWARNING << "Found NaN return value in event: " << ievt
1662 << " for method \"" << methodsNoCuts[im]->GetName() << "\"" << Endl;
1663 dvec[im] = 0;
1664 }
1665 else dvec[im] = retval;
1666 }
1667 for (Int_t iv=0; iv<nvar; iv++) dvec[iv+nmeth] = (Double_t)ev->GetValue(iv);
1668 if (method->fDataSetInfo.IsSignal(ev)) { tpSig->AddRow( dvec ); theMat = overlapS; }
1669 else { tpBkg->AddRow( dvec ); theMat = overlapB; }
1670
1671 // count overlaps
1672 for (Int_t im=0; im<nmeth; im++) {
1673 for (Int_t jm=im; jm<nmeth; jm++) {
1674 if ((dvec[im] - rvec[im])*(dvec[jm] - rvec[jm]) > 0) {
1675 (*theMat)(im,jm)++;
1676 if (im != jm) (*theMat)(jm,im)++;
1677 }
1678 }
1679 }
1680 }
1681
1682 // renormalise overlap matrix
1683 (*overlapS) *= (1.0/defDs->GetNEvtSigTest()); // init...
1684 (*overlapB) *= (1.0/defDs->GetNEvtBkgdTest()); // init...
1685
1686 tpSig->MakePrincipals();
1687 tpBkg->MakePrincipals();
1688
1689 const TMatrixD* covMatS = tpSig->GetCovarianceMatrix();
1690 const TMatrixD* covMatB = tpBkg->GetCovarianceMatrix();
1691
1692 const TMatrixD* corrMatS = gTools().GetCorrelationMatrix( covMatS );
1693 const TMatrixD* corrMatB = gTools().GetCorrelationMatrix( covMatB );
1694
1695 // print correlation matrices
1696 if (corrMatS != 0 && corrMatB != 0) {
1697
1698 // extract MVA matrix
1699 TMatrixD mvaMatS(nmeth,nmeth);
1700 TMatrixD mvaMatB(nmeth,nmeth);
1701 for (Int_t im=0; im<nmeth; im++) {
1702 for (Int_t jm=0; jm<nmeth; jm++) {
1703 mvaMatS(im,jm) = (*corrMatS)(im,jm);
1704 mvaMatB(im,jm) = (*corrMatB)(im,jm);
1705 }
1706 }
1707
1708 // extract variables - to MVA matrix
1709 std::vector<TString> theInputVars;
1710 TMatrixD varmvaMatS(nvar,nmeth);
1711 TMatrixD varmvaMatB(nvar,nmeth);
1712 for (Int_t iv=0; iv<nvar; iv++) {
1713 theInputVars.push_back( method->fDataSetInfo.GetVariableInfo( iv ).GetLabel() );
1714 for (Int_t jm=0; jm<nmeth; jm++) {
1715 varmvaMatS(iv,jm) = (*corrMatS)(nmeth+iv,jm);
1716 varmvaMatB(iv,jm) = (*corrMatB)(nmeth+iv,jm);
1717 }
1718 }
1719
1720 if (nmeth > 1) {
1721 Log() << kINFO << Endl;
1722 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Inter-MVA correlation matrix (signal):" << Endl;
1723 gTools().FormattedOutput( mvaMatS, *theVars, Log() );
1724 Log() << kINFO << Endl;
1725
1726 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Inter-MVA correlation matrix (background):" << Endl;
1727 gTools().FormattedOutput( mvaMatB, *theVars, Log() );
1728 Log() << kINFO << Endl;
1729 }
1730
1731 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Correlations between input variables and MVA response (signal):" << Endl;
1732 gTools().FormattedOutput( varmvaMatS, theInputVars, *theVars, Log() );
1733 Log() << kINFO << Endl;
1734
1735 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Correlations between input variables and MVA response (background):" << Endl;
1736 gTools().FormattedOutput( varmvaMatB, theInputVars, *theVars, Log() );
1737 Log() << kINFO << Endl;
1738 }
1739 else Log() << kWARNING <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "<TestAllMethods> cannot compute correlation matrices" << Endl;
1740
1741 // print overlap matrices
1742 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "The following \"overlap\" matrices contain the fraction of events for which " << Endl;
1743 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "the MVAs 'i' and 'j' have returned conform answers about \"signal-likeness\"" << Endl;
1744 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "An event is signal-like, if its MVA output exceeds the following value:" << Endl;
1745 gTools().FormattedOutput( rvec, *theVars, "Method" , "Cut value", Log() );
1746 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "which correspond to the working point: eff(signal) = 1 - eff(background)" << Endl;
1747
1748 // give notice that cut method has been excluded from this test
1749 if (nmeth != (Int_t)methods->size())
1750 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Note: no correlations and overlap with cut method are provided at present" << Endl;
1751
1752 if (nmeth > 1) {
1753 Log() << kINFO << Endl;
1754 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Inter-MVA overlap matrix (signal):" << Endl;
1755 gTools().FormattedOutput( *overlapS, *theVars, Log() );
1756 Log() << kINFO << Endl;
1757
1758 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Inter-MVA overlap matrix (background):" << Endl;
1759 gTools().FormattedOutput( *overlapB, *theVars, Log() );
1760 }
1761
1762 // cleanup
1763 delete tpSig;
1764 delete tpBkg;
1765 delete corrMatS;
1766 delete corrMatB;
1767 delete theVars;
1768 delete overlapS;
1769 delete overlapB;
1770 delete [] dvec;
1771 }
1772 }
1773 }
1774 // -----------------------------------------------------------------------
1775 // Third part of evaluation process
1776 // --> output
1777 // -----------------------------------------------------------------------
1778
1779 if (doRegression) {
1780
1781 Log() << kINFO << Endl;
1782 TString hLine = "--------------------------------------------------------------------------------------------------";
1783 Log() << kINFO << "Evaluation results ranked by smallest RMS on test sample:" << Endl;
1784 Log() << kINFO << "(\"Bias\" quotes the mean deviation of the regression from true target." << Endl;
1785 Log() << kINFO << " \"MutInf\" is the \"Mutual Information\" between regression and target." << Endl;
1786 Log() << kINFO << " Indicated by \"_T\" are the corresponding \"truncated\" quantities ob-" << Endl;
1787 Log() << kINFO << " tained when removing events deviating more than 2sigma from average.)" << Endl;
1788 Log() << kINFO << hLine << Endl;
1789 //Log() << kINFO << "DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T" << Endl;
1790 Log() << kINFO << hLine << Endl;
1791
1792 for (Int_t i=0; i<nmeth_used[0]; i++) {
1793 MethodBase* theMethod = dynamic_cast<MethodBase*>((*methods)[i]);
1794 if(theMethod==0) continue;
1795
1796 Log() << kINFO << Form("%-20s %-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f",
1797 theMethod->fDataSetInfo.GetName(),
1798 (const char*)mname[0][i],
1799 biastest[0][i], biastestT[0][i],
1800 rmstest[0][i], rmstestT[0][i],
1801 minftest[0][i], minftestT[0][i] )
1802 << Endl;
1803 }
1804 Log() << kINFO << hLine << Endl;
1805 Log() << kINFO << Endl;
1806 Log() << kINFO << "Evaluation results ranked by smallest RMS on training sample:" << Endl;
1807 Log() << kINFO << "(overtraining check)" << Endl;
1808 Log() << kINFO << hLine << Endl;
1809 Log() << kINFO << "DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T" << Endl;
1810 Log() << kINFO << hLine << Endl;
1811
1812 for (Int_t i=0; i<nmeth_used[0]; i++) {
1813 MethodBase* theMethod = dynamic_cast<MethodBase*>((*methods)[i]);
1814 if(theMethod==0) continue;
1815 Log() << kINFO << Form("%-20s %-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f",
1816 theMethod->fDataSetInfo.GetName(),
1817 (const char*)mname[0][i],
1818 biastrain[0][i], biastrainT[0][i],
1819 rmstrain[0][i], rmstrainT[0][i],
1820 minftrain[0][i], minftrainT[0][i] )
1821 << Endl;
1822 }
1823 Log() << kINFO << hLine << Endl;
1824 Log() << kINFO << Endl;
1825 } else if (doMulticlass) {
1826 // ====================================================================
1827 // === Multiclass Output
1828 // ====================================================================
1829
1830 TString hLine =
1831 "-------------------------------------------------------------------------------------------------------";
1832
1833 // This part uses a genetic alg. to evaluate the optimal sig eff * sig pur.
1834 // This is why it is disabled for now.
1835 //
1836 // // --- Acheivable signal efficiency * signal purity
1837 // // --------------------------------------------------------------------
1838 // Log() << kINFO << Endl;
1839 // Log() << kINFO << "Evaluation results ranked by best signal efficiency times signal purity " << Endl;
1840 // Log() << kINFO << hLine << Endl;
1841
1842 // // iterate over methods and evaluate
1843 // for (MVector::iterator itrMethod = methods->begin(); itrMethod != methods->end(); itrMethod++) {
1844 // MethodBase *theMethod = dynamic_cast<MethodBase *>(*itrMethod);
1845 // if (theMethod == 0) {
1846 // continue;
1847 // }
1848
1849 // TString header = "DataSet Name MVA Method ";
1850 // for (UInt_t icls = 0; icls < theMethod->fDataSetInfo.GetNClasses(); ++icls) {
1851 // header += Form("%-12s ", theMethod->fDataSetInfo.GetClassInfo(icls)->GetName());
1852 // }
1853
1854 // Log() << kINFO << header << Endl;
1855 // Log() << kINFO << hLine << Endl;
1856 // for (Int_t i = 0; i < nmeth_used[0]; i++) {
1857 // TString res = Form("[%-14s] %-15s", theMethod->fDataSetInfo.GetName(), (const char *)mname[0][i]);
1858 // for (UInt_t icls = 0; icls < theMethod->fDataSetInfo.GetNClasses(); ++icls) {
1859 // res += Form("%#1.3f ", (multiclass_testEff[i][icls]) * (multiclass_testPur[i][icls]));
1860 // }
1861 // Log() << kINFO << res << Endl;
1862 // }
1863
1864 // Log() << kINFO << hLine << Endl;
1865 // Log() << kINFO << Endl;
1866 // }
1867
1868 // --- 1 vs Rest ROC AUC, signal efficiency @ given background efficiency
1869 // --------------------------------------------------------------------
1870 TString header1 = Form("%-15s%-15s%-15s%-15s%-15s%-15s", "Dataset", "MVA Method", "ROC AUC", "Sig eff@B=0.01",
1871 "Sig eff@B=0.10", "Sig eff@B=0.30");
1872 TString header2 = Form("%-15s%-15s%-15s%-15s%-15s%-15s", "Name:", "/ Class:", "test (train)", "test (train)",
1873 "test (train)", "test (train)");
1874 Log() << kINFO << Endl;
1875 Log() << kINFO << "1-vs-rest performance metrics per class" << Endl;
1876 Log() << kINFO << hLine << Endl;
1877 Log() << kINFO << Endl;
1878 Log() << kINFO << "Considers the listed class as signal and the other classes" << Endl;
1879 Log() << kINFO << "as background, reporting the resulting binary performance." << Endl;
1880 Log() << kINFO << "A score of 0.820 (0.850) means 0.820 was acheived on the" << Endl;
1881 Log() << kINFO << "test set and 0.850 on the training set." << Endl;
1882
1883 Log() << kINFO << Endl;
1884 Log() << kINFO << header1 << Endl;
1885 Log() << kINFO << header2 << Endl;
1886 for (Int_t k = 0; k < 2; k++) {
1887 for (Int_t i = 0; i < nmeth_used[k]; i++) {
1888 if (k == 1) {
1889 mname[k][i].ReplaceAll("Variable_", "");
1890 }
1891
1892 const TString datasetName = itrMap->first;
1893 const TString mvaName = mname[k][i];
1894
1895 MethodBase *theMethod = dynamic_cast<MethodBase *>(GetMethod(datasetName, mvaName));
1896 if (theMethod == 0) {
1897 continue;
1898 }
1899
1900 Log() << kINFO << Endl;
1901 TString row = Form("%-15s%-15s", datasetName.Data(), mvaName.Data());
1902 Log() << kINFO << row << Endl;
1903 Log() << kINFO << "------------------------------" << Endl;
1904
1905 UInt_t numClasses = theMethod->fDataSetInfo.GetNClasses();
1906 for (UInt_t iClass = 0; iClass < numClasses; ++iClass) {
1907
1908 ROCCurve *rocCurveTrain = GetROC(datasetName, mvaName, iClass, Types::kTraining);
1909 ROCCurve *rocCurveTest = GetROC(datasetName, mvaName, iClass, Types::kTesting);
1910
1911 const TString className = theMethod->DataInfo().GetClassInfo(iClass)->GetName();
1912 const Double_t rocaucTrain = rocCurveTrain->GetROCIntegral();
1913 const Double_t effB01Train = rocCurveTrain->GetEffSForEffB(0.01);
1914 const Double_t effB10Train = rocCurveTrain->GetEffSForEffB(0.10);
1915 const Double_t effB30Train = rocCurveTrain->GetEffSForEffB(0.30);
1916 const Double_t rocaucTest = rocCurveTest->GetROCIntegral();
1917 const Double_t effB01Test = rocCurveTest->GetEffSForEffB(0.01);
1918 const Double_t effB10Test = rocCurveTest->GetEffSForEffB(0.10);
1919 const Double_t effB30Test = rocCurveTest->GetEffSForEffB(0.30);
1920 const TString rocaucCmp = Form("%5.3f (%5.3f)", rocaucTest, rocaucTrain);
1921 const TString effB01Cmp = Form("%5.3f (%5.3f)", effB01Test, effB01Train);
1922 const TString effB10Cmp = Form("%5.3f (%5.3f)", effB10Test, effB10Train);
1923 const TString effB30Cmp = Form("%5.3f (%5.3f)", effB30Test, effB30Train);
1924 row = Form("%-15s%-15s%-15s%-15s%-15s%-15s", "", className.Data(), rocaucCmp.Data(), effB01Cmp.Data(),
1925 effB10Cmp.Data(), effB30Cmp.Data());
1926 Log() << kINFO << row << Endl;
1927
1928 delete rocCurveTrain;
1929 delete rocCurveTest;
1930 }
1931 }
1932 }
1933 Log() << kINFO << Endl;
1934 Log() << kINFO << hLine << Endl;
1935 Log() << kINFO << Endl;
1936
1937 // --- Confusion matrices
1938 // --------------------------------------------------------------------
1939 auto printMatrix = [](TMatrixD const &matTraining, TMatrixD const &matTesting, std::vector<TString> classnames,
1940 UInt_t numClasses, MsgLogger &stream) {
1941 // assert (classLabledWidth >= valueLabelWidth + 2)
1942 // if (...) {Log() << kWARN << "..." << Endl; }
1943
1944 // TODO: Ensure matrices are same size.
1945
1946 TString header = Form(" %-14s", " ");
1947 TString headerInfo = Form(" %-14s", " ");
1948 ;
1949 for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
1950 header += Form(" %-14s", classnames[iCol].Data());
1951 headerInfo += Form(" %-14s", " test (train)");
1952 }
1953 stream << kINFO << header << Endl;
1954 stream << kINFO << headerInfo << Endl;
1955
1956 for (UInt_t iRow = 0; iRow < numClasses; ++iRow) {
1957 stream << kINFO << Form(" %-14s", classnames[iRow].Data());
1958
1959 for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
1960 if (iCol == iRow) {
1961 stream << kINFO << Form(" %-14s", "-");
1962 } else {
1963 Double_t trainValue = matTraining[iRow][iCol];
1964 Double_t testValue = matTesting[iRow][iCol];
1965 TString entry = Form("%-5.3f (%-5.3f)", testValue, trainValue);
1966 stream << kINFO << Form(" %-14s", entry.Data());
1967 }
1968 }
1969 stream << kINFO << Endl;
1970 }
1971 };
1972
1973 Log() << kINFO << Endl;
1974 Log() << kINFO << "Confusion matrices for all methods" << Endl;
1975 Log() << kINFO << hLine << Endl;
1976 Log() << kINFO << Endl;
1977 Log() << kINFO << "Does a binary comparison between the two classes given by a " << Endl;
1978 Log() << kINFO << "particular row-column combination. In each case, the class " << Endl;
1979 Log() << kINFO << "given by the row is considered signal while the class given " << Endl;
1980 Log() << kINFO << "by the column index is considered background." << Endl;
1981 Log() << kINFO << Endl;
1982 for (UInt_t iMethod = 0; iMethod < methods->size(); ++iMethod) {
1983 MethodBase *theMethod = dynamic_cast<MethodBase *>(methods->at(iMethod));
1984 if (theMethod == nullptr) {
1985 continue;
1986 }
1987 UInt_t numClasses = theMethod->fDataSetInfo.GetNClasses();
1988
1989 std::vector<TString> classnames;
1990 for (UInt_t iCls = 0; iCls < numClasses; ++iCls) {
1991 classnames.push_back(theMethod->fDataSetInfo.GetClassInfo(iCls)->GetName());
1992 }
1993 Log() << kINFO
1994 << "=== Showing confusion matrix for method : " << Form("%-15s", (const char *)mname[0][iMethod])
1995 << Endl;
1996 Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.01%)" << Endl;
1997 Log() << kINFO << "---------------------------------------------------" << Endl;
1998 printMatrix(multiclass_testConfusionEffB01[iMethod], multiclass_trainConfusionEffB01[iMethod], classnames,
1999 numClasses, Log());
2000 Log() << kINFO << Endl;
2001
2002 Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.10%)" << Endl;
2003 Log() << kINFO << "---------------------------------------------------" << Endl;
2004 printMatrix(multiclass_testConfusionEffB10[iMethod], multiclass_trainConfusionEffB10[iMethod], classnames,
2005 numClasses, Log());
2006 Log() << kINFO << Endl;
2007
2008 Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.30%)" << Endl;
2009 Log() << kINFO << "---------------------------------------------------" << Endl;
2010 printMatrix(multiclass_testConfusionEffB30[iMethod], multiclass_trainConfusionEffB30[iMethod], classnames,
2011 numClasses, Log());
2012 Log() << kINFO << Endl;
2013 }
2014 Log() << kINFO << hLine << Endl;
2015 Log() << kINFO << Endl;
2016
2017 } else {
2018 // Binary classification
2019 if (fROC) {
2020 Log().EnableOutput();
2022 Log() << Endl;
2023 TString hLine = "------------------------------------------------------------------------------------------"
2024 "-------------------------";
2025 Log() << kINFO << "Evaluation results ranked by best signal efficiency and purity (area)" << Endl;
2026 Log() << kINFO << hLine << Endl;
2027 Log() << kINFO << "DataSet MVA " << Endl;
2028 Log() << kINFO << "Name: Method: ROC-integ" << Endl;
2029
2030 // Log() << kDEBUG << "DataSet MVA Signal efficiency at bkg eff.(error):
2031 // | Sepa- Signifi- " << Endl; Log() << kDEBUG << "Name: Method: @B=0.01
2032 // @B=0.10 @B=0.30 ROC-integ ROCCurve| ration: cance: " << Endl;
2033 Log() << kDEBUG << hLine << Endl;
2034 for (Int_t k = 0; k < 2; k++) {
2035 if (k == 1 && nmeth_used[k] > 0) {
2036 Log() << kINFO << hLine << Endl;
2037 Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
2038 }
2039 for (Int_t i = 0; i < nmeth_used[k]; i++) {
2040 TString datasetName = itrMap->first;
2041 TString methodName = mname[k][i];
2042
2043 if (k == 1) {
2044 methodName.ReplaceAll("Variable_", "");
2045 }
2046
2047 MethodBase *theMethod = dynamic_cast<MethodBase *>(GetMethod(datasetName, methodName));
2048 if (theMethod == 0) {
2049 continue;
2050 }
2051
2052 TMVA::DataSet *dataset = theMethod->Data();
2053 TMVA::Results *results = dataset->GetResults(methodName, Types::kTesting, this->fAnalysisType);
2054 std::vector<Bool_t> *mvaResType =
2055 dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
2056
2057 Double_t rocIntegral = 0.0;
2058 if (mvaResType->size() != 0) {
2059 rocIntegral = GetROCIntegral(datasetName, methodName);
2060 }
2061
2062 if (sep[k][i] < 0 || sig[k][i] < 0) {
2063 // cannot compute separation/significance -> no MVA (usually for Cuts)
2064 Log() << kINFO << Form("%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), effArea[k][i])
2065 << Endl;
2066
2067 // Log() << kDEBUG << Form("%-20s %-15s: %#1.3f(%02i) %#1.3f(%02i) %#1.3f(%02i)
2068 // %#1.3f %#1.3f | -- --",
2069 // datasetName.Data(),
2070 // methodName.Data(),
2071 // eff01[k][i], Int_t(1000*eff01err[k][i]),
2072 // eff10[k][i], Int_t(1000*eff10err[k][i]),
2073 // eff30[k][i], Int_t(1000*eff30err[k][i]),
2074 // effArea[k][i],rocIntegral) << Endl;
2075 } else {
2076 Log() << kINFO << Form("%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), rocIntegral)
2077 << Endl;
2078 // Log() << kDEBUG << Form("%-20s %-15s: %#1.3f(%02i) %#1.3f(%02i) %#1.3f(%02i)
2079 // %#1.3f %#1.3f | %#1.3f %#1.3f",
2080 // datasetName.Data(),
2081 // methodName.Data(),
2082 // eff01[k][i], Int_t(1000*eff01err[k][i]),
2083 // eff10[k][i], Int_t(1000*eff10err[k][i]),
2084 // eff30[k][i], Int_t(1000*eff30err[k][i]),
2085 // effArea[k][i],rocIntegral,
2086 // sep[k][i], sig[k][i]) << Endl;
2087 }
2088 }
2089 }
2090 Log() << kINFO << hLine << Endl;
2091 Log() << kINFO << Endl;
2092 Log() << kINFO << "Testing efficiency compared to training efficiency (overtraining check)" << Endl;
2093 Log() << kINFO << hLine << Endl;
2094 Log() << kINFO
2095 << "DataSet MVA Signal efficiency: from test sample (from training sample) "
2096 << Endl;
2097 Log() << kINFO << "Name: Method: @B=0.01 @B=0.10 @B=0.30 "
2098 << Endl;
2099 Log() << kINFO << hLine << Endl;
2100 for (Int_t k = 0; k < 2; k++) {
2101 if (k == 1 && nmeth_used[k] > 0) {
2102 Log() << kINFO << hLine << Endl;
2103 Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
2104 }
2105 for (Int_t i = 0; i < nmeth_used[k]; i++) {
2106 if (k == 1) mname[k][i].ReplaceAll("Variable_", "");
2107 MethodBase *theMethod = dynamic_cast<MethodBase *>((*methods)[i]);
2108 if (theMethod == 0) continue;
2109
2110 Log() << kINFO << Form("%-20s %-15s: %#1.3f (%#1.3f) %#1.3f (%#1.3f) %#1.3f (%#1.3f)",
2111 theMethod->fDataSetInfo.GetName(), (const char *)mname[k][i], eff01[k][i],
2112 trainEff01[k][i], eff10[k][i], trainEff10[k][i], eff30[k][i], trainEff30[k][i])
2113 << Endl;
2114 }
2115 }
2116 Log() << kINFO << hLine << Endl;
2117 Log() << kINFO << Endl;
2118
2119 if (gTools().CheckForSilentOption(GetOptions())) Log().InhibitOutput();
2120 } // end fROC
2121 }
2122 if(!IsSilentFile())
2123 {
2124 std::list<TString> datasets;
2125 for (Int_t k=0; k<2; k++) {
2126 for (Int_t i=0; i<nmeth_used[k]; i++) {
2127 MethodBase* theMethod = dynamic_cast<MethodBase*>((*methods)[i]);
2128 if(theMethod==0) continue;
2129 // write test/training trees
2130 RootBaseDir()->cd(theMethod->fDataSetInfo.GetName());
2131 if(std::find(datasets.begin(), datasets.end(), theMethod->fDataSetInfo.GetName()) == datasets.end())
2132 {
2135 datasets.push_back(theMethod->fDataSetInfo.GetName());
2136 }
2137 }
2138 }
2139 }
2140 }//end for MethodsMap
2141 // references for citation
2143}
2144
2145////////////////////////////////////////////////////////////////////////////////
2146/// Evaluate Variable Importance
2147
2148TH1F* TMVA::Factory::EvaluateImportance(DataLoader *loader,VIType vitype, Types::EMVA theMethod, TString methodTitle, const char *theOption)
2149{
2150 fModelPersistence=kFALSE;
2151 fSilentFile=kTRUE;//we need silent file here because we need fast classification results
2152
2153 //getting number of variables and variable names from loader
2154 const int nbits = loader->DefaultDataSetInfo().GetNVariables();
2155 if(vitype==VIType::kShort)
2156 return EvaluateImportanceShort(loader,theMethod,methodTitle,theOption);
2157 else if(vitype==VIType::kAll)
2158 return EvaluateImportanceAll(loader,theMethod,methodTitle,theOption);
2159 else if(vitype==VIType::kRandom&&nbits>10)
2160 {
2161 return EvaluateImportanceRandom(loader,pow(2,nbits),theMethod,methodTitle,theOption);
2162 }else
2163 {
2164 std::cerr<<"Error in Variable Importance: Random mode require more that 10 variables in the dataset."<<std::endl;
2165 return nullptr;
2166 }
2167}
2168
2169////////////////////////////////////////////////////////////////////////////////
2170
/// Variable importance from an exhaustive scan over all variable subsets.
///
/// Every bitmask x in [1, range) with range = pow(2, nbits) selects the
/// variables whose bit is set. For each x a DataLoader is built from the
/// bitset's string form, the method is booked, trained, tested and evaluated,
/// and the ROC integral is stored in ROC[x] (ROC[0] is preset to 0.5).
/// A second pass then walks every x and every set bit i of x and accumulates
/// ROC[x] - ROC[x with bit i cleared] into the importance slot of that bit.
///
/// \param[in] loader       source loader; supplies variable names, cuts and split options
/// \param[in] theMethod    method type passed to BookMethod
/// \param[in] methodTitle  method title passed to BookMethod and GetROCIntegral
/// \param[in] theOption    option string passed to BookMethod
/// \return the result of GetImportance(nbits, importances, varNames)
///
/// NOTE(review): `sresults` is deleted inside the first loop but its declaration
/// is not visible in this listing (embedded source line 2220 is absent) - confirm
/// against the upstream file before touching the cleanup sequence.
2171TH1F* TMVA::Factory::EvaluateImportanceAll(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption)
2172{
2173
   // x: bitmask of the current variable subset; y: x with one bit cleared
2174 uint64_t x = 0;
2175 uint64_t y = 0;
2176
2177 //getting number of variables and variable names from loader
2178 const int nbits = loader->DefaultDataSetInfo().GetNVariables();
2179 std::vector<TString> varNames = loader->DefaultDataSetInfo().GetListOfVariables();
2180
   // number of subsets including the empty one; also the size of the ROC table below
2181 uint64_t range = pow(2, nbits);
2182
2183 //vector to save importances
2184 std::vector<Double_t> importances(nbits);
2185 //vector to save ROC
2186 std::vector<Double_t> ROC(range);
   // the empty subset is never trained; its ROC entry is fixed to 0.5
2187 ROC[0]=0.5;
2188 for (int i = 0; i < nbits; i++)importances[i] = 0;
2189
2190 Double_t SROC, SSROC; //computed ROC value
   // first pass: train one model per non-empty subset and record its ROC integral
2191 for ( x = 1; x <range ; x++) {
2192
2193 std::bitset<VIBITS> xbitset(x);
   // the loop starts at x = 1, so this guard does not trigger here
2194 if (x == 0) continue; //data loader need at least one variable
2195
2196 //creating loader for seed
2197 TMVA::DataLoader *seedloader = new TMVA::DataLoader(xbitset.to_string());
2198
2199 //adding variables from seed
2200 for (int index = 0; index < nbits; index++) {
2201 if (xbitset[index]) seedloader->AddVariable(varNames[index], 'F');
2202 }
2203
   // presumably transfers the dataset of `loader` into `seedloader` - helper is defined elsewhere, verify
2204 DataLoaderCopy(seedloader,loader);
2205 seedloader->PrepareTrainingAndTestTree(loader->DefaultDataSetInfo().GetCut("Signal"), loader->DefaultDataSetInfo().GetCut("Background"), loader->DefaultDataSetInfo().GetSplitOptions());
2206
2207 //Booking Seed
2208 BookMethod(seedloader, theMethod, methodTitle, theOption);
2209
2210 //Train/Test/Evaluation
2211 TrainAllMethods();
2212 TestAllMethods();
2213 EvaluateAllMethods();
2214
2215 //getting ROC
   // the bitset string is the same key that was used to construct the seed loader
2216 ROC[x] = GetROCIntegral(xbitset.to_string(), methodTitle);
2217
2218 //cleaning information to process sub-seeds
2219 TMVA::MethodBase *smethod=dynamic_cast<TMVA::MethodBase*>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
   // NOTE(review): the line declaring `sresults` is missing from this listing
2221 delete sresults;
2222 delete seedloader;
2223 this->DeleteAllMethods();
2224
2225 fMethodsMap.clear();
2226 //removing global result because it is requiring a lot of RAM for all seeds
2227 }
2228
2229
   // second pass: turn the ROC table into per-variable importances
2230 for ( x = 0; x <range ; x++)
2231 {
2232 SROC=ROC[x];
2233 for (uint32_t i = 0; i < VIBITS; ++i) {
   // `1 << i` is evaluated as int before being combined with the 64-bit x
   // NOTE(review): confirm VIBITS keeps i below the width of int
2234 if (x & (1 << i)) {
2235 y = x & ~(1 << i);
2236 std::bitset<VIBITS> ybitset(y);
2237 //need at least one variable
2238 //NOTE: if sub-seed is zero then is the special case
2239 //that count in xbitset is 1
   // x - y is the single cleared bit; 0.693147 approximates ln 2, so ny is its
   // base-2 logarithm and is used (converted from Double_t) as the variable index
2240 Double_t ny = log(x - y) / 0.693147;
2241 if (y == 0) {
   // single-variable subset: assigned (not accumulated) relative to the 0.5 baseline
2242 importances[ny] = SROC - 0.5;
2243 continue;
2244 }
2245
2246 //getting ROC
2247 SSROC = ROC[y];
2248 importances[ny] += SROC - SSROC;
2249 //cleaning information
2250 }
2251
2252 }
2253 }
2254 std::cout<<"--- Variable Importance Results (All)"<<std::endl;
2255 return GetImportance(nbits,importances,varNames);
2256}
2257
/// Return 2^0 + 2^1 + ... + 2^(i-1), i.e. 2^i - 1: a bitmask with the lowest
/// `i` bits set.
///
/// \param[in] i  number of low bits to set
/// \return 0 for i <= 0; the largest positive long int when `i` is at least the
///         number of non-sign bits of long int (the exact result for equality,
///         a saturated one beyond); 2^i - 1 otherwise.
///
/// The value is computed with integer arithmetic only. The former loop added
/// pow(2,n) in floating point, which stops being exact for large `i` and ended
/// in an out-of-range double -> long conversion.
static long int sum(long int i)
{
   if (i <= 0) return 0;

   // number of value (non-sign) bits in long int
   const long int valueBits = static_cast<long int>(8 * sizeof(long int)) - 1;
   if (i >= valueBits) return static_cast<long int>(~0UL >> 1); // all value bits set

   return (1L << i) - 1;
}
2264
2265////////////////////////////////////////////////////////////////////////////////
2266
/// Variable importance from the full variable set and its leave-one-out subsets.
///
/// The seed x = sum(nbits) is the starting bitmask. The method is booked,
/// trained, tested and evaluated once on the seed to obtain SROC; then, for
/// every set bit i of x, the same is repeated on the sub-seed y (x with bit i
/// cleared) to obtain SSROC, and SROC - SSROC is accumulated into the
/// importance slot of that bit.
///
/// \param[in] loader       source loader; supplies the variable names
/// \param[in] theMethod    method type passed to BookMethod
/// \param[in] methodTitle  method title passed to BookMethod and GetROCIntegral
/// \param[in] theOption    option string passed to BookMethod
/// \return the result of GetImportance(nbits, importances, varNames)
///
/// NOTE(review): `sresults` and `ssresults` are deleted below but their
/// declarations are not visible in this listing (embedded source lines 2314 and
/// 2359 are absent) - confirm against the upstream file.
2267TH1F* TMVA::Factory::EvaluateImportanceShort(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption)
2268{
   // x: seed bitmask; y: x with one bit cleared
2269 uint64_t x = 0;
2270 uint64_t y = 0;
2271
2272 //getting number of variables and variable names from loader
2273 const int nbits = loader->DefaultDataSetInfo().GetNVariables();
2274 std::vector<TString> varNames = loader->DefaultDataSetInfo().GetListOfVariables();
2275
   // sum(nbits) adds pow(2,n) for n in [0, nbits)
2276 long int range = sum(nbits);
2277// std::cout<<range<<std::endl;
2278 //vector to save importances
2279 std::vector<Double_t> importances(nbits);
2280 for (int i = 0; i < nbits; i++)importances[i] = 0;
2281
2282 Double_t SROC, SSROC; //computed ROC value
2283
2284 x = range;
2285
2286 std::bitset<VIBITS> xbitset(x);
   // x is 0 only when sum(nbits) returned 0
2287 if (x == 0) Log()<<kFATAL<<"Error: need at least one variable."; //data loader need at least one variable
2288
2289
2290 //creating loader for seed
2291 TMVA::DataLoader *seedloader = new TMVA::DataLoader(xbitset.to_string());
2292
2293 //adding variables from seed
2294 for (int index = 0; index < nbits; index++) {
2295 if (xbitset[index]) seedloader->AddVariable(varNames[index], 'F');
2296 }
2297
2298 //Loading Dataset
   // helper is defined elsewhere - presumably copies the dataset of `loader`, verify
2299 DataLoaderCopy(seedloader,loader);
2300
2301 //Booking Seed
2302 BookMethod(seedloader, theMethod, methodTitle, theOption);
2303
2304 //Train/Test/Evaluation
2305 TrainAllMethods();
2306 TestAllMethods();
2307 EvaluateAllMethods();
2308
2309 //getting ROC
2310 SROC = GetROCIntegral(xbitset.to_string(), methodTitle);
2311
2312 //cleaning information to process sub-seeds
2313 TMVA::MethodBase *smethod=dynamic_cast<TMVA::MethodBase*>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
   // NOTE(review): the line declaring `sresults` is missing from this listing
2315 delete sresults;
2316 delete seedloader;
2317 this->DeleteAllMethods();
2318 fMethodsMap.clear();
2319
2320 //removing global result because it is requiring a lot of RAM for all seeds
2321
   // one retraining per set bit of the seed, each with that variable removed
2322 for (uint32_t i = 0; i < VIBITS; ++i) {
   // `1 << i` is evaluated as int before being combined with the 64-bit x
   // NOTE(review): confirm VIBITS keeps i below the width of int
2323 if (x & (1 << i)) {
2324 y = x & ~(1 << i);
2325 std::bitset<VIBITS> ybitset(y);
2326 //need at least one variable
2327 //NOTE: if sub-seed is zero then is the special case
2328 //that count in xbitset is 1
   // x - y is the single cleared bit; 0.693147 approximates ln 2, so ny is its
   // base-2 logarithm and is used (converted from Double_t) as the variable index
2329 Double_t ny = log(x - y) / 0.693147;
2330 if (y == 0) {
   // nothing left to train on: assigned (not accumulated) relative to the 0.5 baseline
2331 importances[ny] = SROC - 0.5;
2332 continue;
2333 }
2334
2335 //creating loader for sub-seed
2336 TMVA::DataLoader *subseedloader = new TMVA::DataLoader(ybitset.to_string());
2337 //adding variables from sub-seed
2338 for (int index = 0; index < nbits; index++) {
2339 if (ybitset[index]) subseedloader->AddVariable(varNames[index], 'F');
2340 }
2341
2342 //Loading Dataset
2343 DataLoaderCopy(subseedloader,loader);
2344
2345 //Booking SubSeed
2346 BookMethod(subseedloader, theMethod, methodTitle, theOption);
2347
2348 //Train/Test/Evaluation
2349 TrainAllMethods();
2350 TestAllMethods();
2351 EvaluateAllMethods();
2352
2353 //getting ROC
2354 SSROC = GetROCIntegral(ybitset.to_string(), methodTitle);
2355 importances[ny] += SROC - SSROC;
2356
2357 //cleaning information
2358 TMVA::MethodBase *ssmethod=dynamic_cast<TMVA::MethodBase*>(fMethodsMap[ybitset.to_string().c_str()][0][0]);
   // NOTE(review): the line declaring `ssresults` is missing from this listing
2360 delete ssresults;
2361 delete subseedloader;
2362 this->DeleteAllMethods();
2363 fMethodsMap.clear();
2364 }
2365 }
2366 std::cout<<"--- Variable Importance Results (Short)"<<std::endl;
2367 return GetImportance(nbits,importances,varNames);
2368}
2369
2370////////////////////////////////////////////////////////////////////////////////
2371
/// Variable importance from randomly drawn variable subsets.
///
/// For each of `nseeds` iterations a bitmask x is drawn with
/// rangen->Integer(range), where range = pow(2, nbits); a draw of 0 is skipped.
/// The method is booked, trained, tested and evaluated on the seed to obtain
/// SROC; then, for every set bit i of x, the same is repeated on the sub-seed y
/// (x with bit i cleared) to obtain SSROC, and SROC - SSROC is accumulated into
/// the importance slot of that bit.
///
/// \param[in] loader       source loader; supplies the variable names
/// \param[in] nseeds       number of random draws
/// \param[in] theMethod    method type passed to BookMethod
/// \param[in] methodTitle  method title passed to BookMethod and GetROCIntegral
/// \param[in] theOption    option string passed to BookMethod
/// \return the result of GetImportance(nbits, importances, varNames)
///
/// NOTE(review): `rangen` is allocated with new and no matching delete is
/// visible in this function.
/// NOTE(review): `importances_norm` is only written here, never read.
/// NOTE(review): `sresults` and `ssresults` are deleted below but their
/// declarations are not visible in this listing (embedded source lines 2423 and
/// 2470 are absent) - confirm against the upstream file.
/// NOTE(review): this function hard-codes std::bitset<32> and a bit loop bound
/// of 32, whereas EvaluateImportanceAll/Short use VIBITS.
2372TH1F* TMVA::Factory::EvaluateImportanceRandom(DataLoader *loader, UInt_t nseeds, Types::EMVA theMethod, TString methodTitle, const char *theOption)
2373{
2374 TRandom3 *rangen = new TRandom3(0); //Random Gen.
2375
   // x: seed bitmask; y: x with one bit cleared
2376 uint64_t x = 0;
2377 uint64_t y = 0;
2378
2379 //getting number of variables and variable names from loader
2380 const int nbits = loader->DefaultDataSetInfo().GetNVariables();
2381 std::vector<TString> varNames = loader->DefaultDataSetInfo().GetListOfVariables();
2382
   // exclusive upper bound passed to rangen->Integer below
2383 long int range = pow(2, nbits);
2384
2385 //vector to save importances
2386 std::vector<Double_t> importances(nbits);
2387 Double_t importances_norm = 0;
2388 for (int i = 0; i < nbits; i++)importances[i] = 0;
2389
2390 Double_t SROC, SSROC; //computed ROC value
2391 for (UInt_t n = 0; n < nseeds; n++) {
2392 x = rangen -> Integer(range);
2393
2394 std::bitset<32> xbitset(x);
   // a skipped draw still counts towards nseeds
2395 if (x == 0) continue; //data loader need at least one variable
2396
2397
2398 //creating loader for seed
2399 TMVA::DataLoader *seedloader = new TMVA::DataLoader(xbitset.to_string());
2400
2401 //adding variables from seed
2402 for (int index = 0; index < nbits; index++) {
2403 if (xbitset[index]) seedloader->AddVariable(varNames[index], 'F');
2404 }
2405
2406 //Loading Dataset
   // helper is defined elsewhere - presumably copies the dataset of `loader`, verify
2407 DataLoaderCopy(seedloader,loader);
2408
2409 //Booking Seed
2410 BookMethod(seedloader, theMethod, methodTitle, theOption);
2411
2412 //Train/Test/Evaluation
2413 TrainAllMethods();
2414 TestAllMethods();
2415 EvaluateAllMethods();
2416
2417 //getting ROC
2418 SROC = GetROCIntegral(xbitset.to_string(), methodTitle);
2419// std::cout << "Seed: n " << n << " x " << x << " xbitset:" << xbitset << " ROC " << SROC << std::endl;
2420
2421 //cleaning information to process sub-seeds
2422 TMVA::MethodBase *smethod=dynamic_cast<TMVA::MethodBase*>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
   // NOTE(review): the line declaring `sresults` is missing from this listing
2424 delete sresults;
2425 delete seedloader;
2426 this->DeleteAllMethods();
2427 fMethodsMap.clear();
2428
2429 //removing global result because it is requiring a lot of RAM for all seeds
2430
   // one retraining per set bit of the seed, each with that variable removed
2431 for (uint32_t i = 0; i < 32; ++i) {
   // `1 << i` is evaluated as int before being combined with the 64-bit x
   // NOTE(review): i reaches 31 here - confirm the int shift is intended
2432 if (x & (1 << i)) {
2433 y = x & ~(1 << i);
2434 std::bitset<32> ybitset(y);
2435 //need at least one variable
2436 //NOTE: if sub-seed is zero then is the special case
2437 //that count in xbitset is 1
   // x - y is the single cleared bit; 0.693147 approximates ln 2, so ny is its
   // base-2 logarithm and is used (converted from Double_t) as the variable index
2438 Double_t ny = log(x - y) / 0.693147;
2439 if (y == 0) {
   // single-variable seed: assigned (not accumulated) relative to the 0.5 baseline
2440 importances[ny] = SROC - 0.5;
2441 importances_norm += importances[ny];
2442 // std::cout << "SubSeed: " << y << " y:" << ybitset << "ROC " << 0.5 << std::endl;
2443 continue;
2444 }
2445
2446 //creating loader for sub-seed
2447 TMVA::DataLoader *subseedloader = new TMVA::DataLoader(ybitset.to_string());
2448 //adding variables from sub-seed
2449 for (int index = 0; index < nbits; index++) {
2450 if (ybitset[index]) subseedloader->AddVariable(varNames[index], 'F');
2451 }
2452
2453 //Loading Dataset
2454 DataLoaderCopy(subseedloader,loader);
2455
2456 //Booking SubSeed
2457 BookMethod(subseedloader, theMethod, methodTitle, theOption);
2458
2459 //Train/Test/Evaluation
2460 TrainAllMethods();
2461 TestAllMethods();
2462 EvaluateAllMethods();
2463
2464 //getting ROC
2465 SSROC = GetROCIntegral(ybitset.to_string(), methodTitle);
2466 importances[ny] += SROC - SSROC;
2467 //std::cout << "SubSeed: " << y << " y:" << ybitset << " x-y " << x - y << " " << std::bitset<32>(x - y) << " ny " << ny << " SROC " << SROC << " SSROC " << SSROC << " Importance = " << importances[ny] << std::endl;
2468 //cleaning information
2469 TMVA::MethodBase *ssmethod=dynamic_cast<TMVA::MethodBase*>(fMethodsMap[ybitset.to_string().c_str()][0][0]);
   // NOTE(review): the line declaring `ssresults` is missing from this listing
2471 delete ssresults;
2472 delete subseedloader;
2473 this->DeleteAllMethods();
2474 fMethodsMap.clear();
2475 }
2476 }
2477 }
2478 std::cout<<"--- Variable Importance Results (Random)"<<std::endl;
2479 return GetImportance(nbits,importances,varNames);
2480}
2481
2482////////////////////////////////////////////////////////////////////////////////
2483
2484TH1F* TMVA::Factory::GetImportance(const int nbits,std::vector<Double_t> importances,std::vector<TString> varNames)
2485{
2486 TH1F *vih1 = new TH1F("vih1", "", nbits, 0, nbits);
2487
2488 gStyle->SetOptStat(000000);
2489
2490 Float_t normalization = 0.0;
2491 for (int i = 0; i < nbits; i++) {
2492 normalization = normalization + importances[i];
2493 }
2494
2495 Float_t roc = 0.0;
2496
2497 gStyle->SetTitleXOffset(0.4);
2498 gStyle->SetTitleXOffset(1.2);
2499
2500
2501 Double_t x_ie[nbits], y_ie[nbits];
2502 for (Int_t i = 1; i < nbits + 1; i++) {
2503 x_ie[i - 1] = (i - 1) * 1.;
2504 roc = 100.0 * importances[i - 1] / normalization;
2505 y_ie[i - 1] = roc;
2506 std::cout<<"--- "<<varNames[i-1]<<" = "<<roc<<" %"<<std::endl;
2507 vih1->GetXaxis()->SetBinLabel(i, varNames[i - 1].Data());
2508 vih1->SetBinContent(i, roc);
2509 }
2510 TGraph *g_ie = new TGraph(nbits + 2, x_ie, y_ie);
2511 g_ie->SetTitle("");
2512
2513 vih1->LabelsOption("v >", "X");
2514 vih1->SetBarWidth(0.97);
2515 Int_t ca = TColor::GetColor("#006600");
2516 vih1->SetFillColor(ca);
2517 //Int_t ci = TColor::GetColor("#990000");
2518
2519 vih1->GetYaxis()->SetTitle("Importance (%)");
2520 vih1->GetYaxis()->SetTitleSize(0.045);
2521 vih1->GetYaxis()->CenterTitle();
2522 vih1->GetYaxis()->SetTitleOffset(1.24);
2523
2524 vih1->GetYaxis()->SetRangeUser(-7, 50);
2525 vih1->SetDirectory(0);
2526
2527// vih1->Draw("B");
2528 return vih1;
2529}
2530
SVector< double, 2 > v
Definition: Dict.h:5
#define h(i)
Definition: RSha256.hxx:106
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
float Float_t
Definition: RtypesCore.h:53
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define ClassImp(name)
Definition: Rtypes.h:363
int type
Definition: TGX11.cxx:120
double pow(double, double)
double log(double)
TMatrixT< Double_t > TMatrixD
Definition: TMatrixDfwd.h:22
#define gROOT
Definition: TROOT.h:410
char * Form(const char *fmt,...)
R__EXTERN TStyle * gStyle
Definition: TStyle.h:406
R__EXTERN TSystem * gSystem
Definition: TSystem.h:540
virtual void SetTitleOffset(Float_t offset=1)
Set distance between the axis and the axis title. Offset is a correction factor with respect to the "s...
Definition: TAttAxis.cxx:294
virtual void SetTitleSize(Float_t size=0.04)
Set size of axis title. The size is expressed in per cent of the pad width.
Definition: TAttAxis.cxx:304
virtual void SetFillColor(Color_t fcolor)
Set the fill area color.
Definition: TAttFill.h:37
virtual void SetBinLabel(Int_t bin, const char *label)
Set label for bin.
Definition: TAxis.cxx:809
void CenterTitle(Bool_t center=kTRUE)
Center axis title.
Definition: TAxis.h:184
virtual void SetRangeUser(Double_t ufirst, Double_t ulast)
Set the viewing range for the axis from ufirst to ulast (in user coordinates).
Definition: TAxis.cxx:928
The Canvas class.
Definition: TCanvas.h:31
static Int_t GetColor(const char *hexcolor)
Static method returning color number for color specified by hex color string of form: "#rrggbb",...
Definition: TColor.cxx:1758
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:48
A Graph is a graphics object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
virtual void SetTitle(const char *title="")
Set graph title.
Definition: TGraph.cxx:2232
1-D histogram with a float per channel (see TH1 documentation)
Definition: TH1.h:571
virtual void SetDirectory(TDirectory *dir)
By default when a histogram is created, it is added to the list of histogram objects in the current ...
Definition: TH1.cxx:8259
virtual void SetTitle(const char *title)
See GetStatOverflows for more information.
Definition: TH1.cxx:6217
virtual void LabelsOption(Option_t *option="h", Option_t *axis="X")
Set option(s) to draw axis with labels.
Definition: TH1.cxx:5105
static void AddDirectory(Bool_t add=kTRUE)
Sets the flag controlling the automatic add of histograms in memory.
Definition: TH1.cxx:1225
TAxis * GetXaxis()
Get the behaviour adopted by the object about the statoverflows. See EStatOverflows for more informat...
Definition: TH1.h:316
TAxis * GetYaxis()
Definition: TH1.h:317
virtual void SetBarWidth(Float_t width=0.5)
Definition: TH1.h:356
virtual void SetBinContent(Int_t bin, Double_t content)
Set bin content; see convention for numbering bins in TH1::GetBin. In case the bin number is greater th...
Definition: TH1.cxx:8542
Service class for 2-Dim histogram classes.
Definition: TH2.h:30
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
TString fWeightFileDir
Definition: Config.h:112
void SetDrawProgressBar(Bool_t d)
Definition: Config.h:75
void SetUseColor(Bool_t uc)
Definition: Config.h:66
class TMVA::Config::VariablePlotting fVariablePlotting
void SetSilent(Bool_t s)
Definition: Config.h:69
IONames & GetIONames()
Definition: Config.h:90
void SetConfigDescription(const char *d)
Definition: Configurable.h:64
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
void AddPreDefVal(const T &)
Definition: Configurable.h:168
void SetConfigName(const char *n)
Definition: Configurable.h:63
virtual void ParseOptions()
options parser
const TString & GetOptions() const
Definition: Configurable.h:84
MsgLogger & Log() const
Definition: Configurable.h:122
MsgLogger * fLogger
Definition: Configurable.h:128
void CheckForUnusedOptions() const
checks for unused options in option string
UInt_t GetEntries(const TString &name) const
DataSetInfo & DefaultDataSetInfo()
default creation
Definition: DataLoader.cxx:530
DataSetManager * fDataSetManager
Definition: DataLoader.h:189
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
Definition: DataLoader.cxx:629
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Definition: DataLoader.cxx:491
Class that contains all the data information.
Definition: DataSetInfo.h:60
UInt_t GetNVariables() const
Definition: DataSetInfo.h:110
virtual const char * GetName() const
Returns name of object.
Definition: DataSetInfo.h:67
const TMatrixD * CorrelationMatrix(const TString &className) const
UInt_t GetNClasses() const
Definition: DataSetInfo.h:136
const TString & GetSplitOptions() const
Definition: DataSetInfo.h:167
UInt_t GetNTargets() const
Definition: DataSetInfo.h:111
DataSet * GetDataSet() const
returns data set
TH2 * CreateCorrelationMatrixHist(const TMatrixD *m, const TString &hName, const TString &hTitle) const
std::vector< TString > GetListOfVariables() const
returns list of variables
ClassInfo * GetClassInfo(Int_t clNum) const
const TCut & GetCut(Int_t i) const
Definition: DataSetInfo.h:149
VariableInfo & GetVariableInfo(Int_t i)
Definition: DataSetInfo.h:96
Bool_t IsSignal(const Event *ev) const
DataSetManager * GetDataSetManager()
Definition: DataSetInfo.h:175
DataInputHandler & DataInput()
Class that contains all the data information.
Definition: DataSet.h:69
Long64_t GetNEvtSigTest()
return number of signal test events in dataset
Definition: DataSet.cxx:427
TTree * GetTree(Types::ETreeType type)
create the test/trainings tree with all the variables, the weights, the classes, the targets,...
Definition: DataSet.cxx:609
const Event * GetEvent() const
Definition: DataSet.cxx:202
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:217
Results * GetResults(const TString &, Types::ETreeType type, Types::EAnalysisType analysistype)
Definition: DataSet.cxx:265
Long64_t GetNTrainingEvents() const
Definition: DataSet.h:79
void SetCurrentType(Types::ETreeType type) const
Definition: DataSet.h:100
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:227
Long64_t GetNEvtBkgdTest()
return number of background test events in dataset
Definition: DataSet.cxx:435
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:237
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:392
This is the main MVA steering class.
Definition: Factory.h:81
void PrintHelpMessage(const TString &datasetname, const TString &methodTitle="") const
Print predefined help message of classifier.
Definition: Factory.cxx:1291
Double_t GetROCIntegral(DataLoader *loader, TString theMethodName, UInt_t iClass=0)
Calculate the integral of the ROC curve, also known as the area under curve (AUC),...
Definition: Factory.cxx:842
Bool_t fCorrelations
verbosity level, controls granularity of logging
Definition: Factory.h:211
std::vector< IMethod * > MVector
Definition: Factory.h:85
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition: Factory.cxx:1093
Bool_t Verbose(void) const
Definition: Factory.h:134
void WriteDataInformation(DataSetInfo &fDataSetInfo)
Definition: Factory.cxx:596
MethodBase * BookMethod(DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption="")
Book a classifier or regression method.
Definition: Factory.cxx:358
Factory(TString theJobName, TFile *theTargetFile, TString theOption="")
Standard constructor.
Definition: Factory.cxx:119
void TestAllMethods()
Evaluates all booked methods on the testing data and adds the output to the Results in the corresponi...
Definition: Factory.cxx:1231
Bool_t fVerbose
list of transformations to test
Definition: Factory.h:209
void EvaluateAllMethods(void)
Iterates over all MVAs that have been booked, and calls their evaluation methods.
Definition: Factory.cxx:1333
TH1F * EvaluateImportanceRandom(DataLoader *loader, UInt_t nseeds, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition: Factory.cxx:2372
TH1F * GetImportance(const int nbits, std::vector< Double_t > importances, std::vector< TString > varNames)
Definition: Factory.cxx:2484
Bool_t fROC
enable to calculate correlations
Definition: Factory.h:212
void EvaluateAllVariables(DataLoader *loader, TString options="")
Iterates over all MVA input variables and evaluates them.
Definition: Factory.cxx:1318
TString fVerboseLevel
verbose mode
Definition: Factory.h:210
TH1F * EvaluateImportance(DataLoader *loader, VIType vitype, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Evaluate Variable Importance.
Definition: Factory.cxx:2148
virtual const char * GetName() const
Returns name of object.
Definition: Factory.h:97
virtual ~Factory()
Destructor.
Definition: Factory.cxx:312
TGraph * GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles=kTRUE, UInt_t iClass=0)
Argument iClass specifies the class to generate the ROC curve in a multiclass setting.
Definition: Factory.cxx:903
virtual void MakeClass(const TString &datasetname, const TString &methodTitle="") const
Definition: Factory.cxx:1263
Bool_t IsModelPersistence()
Definition: Factory.cxx:304
MethodBase * BookMethodWeightfile(DataLoader *dataloader, TMVA::Types::EMVA methodType, const TString &weightfile)
Adds an already constructed method to be managed by this factory.
Definition: Factory.cxx:503
Bool_t fModelPersistence
the training type
Definition: Factory.h:218
std::map< TString, Double_t > OptimizeAllMethods(TString fomType="ROCIntegral", TString fitType="FitGA")
Iterates through all booked methods and sees if they use parameter tuning and if so.
Definition: Factory.cxx:694
ROCCurve * GetROC(DataLoader *loader, TString theMethodName, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Private method to generate a ROCCurve instance for a given method.
Definition: Factory.cxx:743
TH1F * EvaluateImportanceShort(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition: Factory.cxx:2267
Bool_t IsSilentFile()
Definition: Factory.cxx:297
Types::EAnalysisType fAnalysisType
jobname, used as extension in weight file names
Definition: Factory.h:217
Bool_t HasMethod(const TString &datasetname, const TString &title) const
Checks whether a given method name is defined for a given dataset.
Definition: Factory.cxx:579
TH1F * EvaluateImportanceAll(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition: Factory.cxx:2171
void SetVerbose(Bool_t v=kTRUE)
Definition: Factory.cxx:350
TFile * fgTargetFile
Definition: Factory.h:201
IMethod * GetMethod(const TString &datasetname, const TString &title) const
Returns pointer to MVA that corresponds to given method title.
Definition: Factory.cxx:561
void DeleteAllMethods(void)
Delete methods.
Definition: Factory.cxx:330
TString fTransformations
option string given by construction (presently only "V")
Definition: Factory.h:208
void Greetings()
Print welcome message.
Definition: Factory.cxx:288
TMultiGraph * GetROCCurveAsMultiGraph(DataLoader *loader, UInt_t iClass)
Generate a collection of graphs, for all methods for a given class.
Definition: Factory.cxx:972
Interface for all concrete MVA method implementations.
Definition: IMethod.h:54
virtual void PrintHelpMessage() const =0
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
virtual void MakeClass(const TString &classFileName=TString("")) const =0
Virtual base Class for all MVA method.
Definition: MethodBase.h:109
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
void SetSilentFile(Bool_t status)
Definition: MethodBase.h:369
void SetWeightFileDir(TString fileDir)
set directory of weight file
virtual void TestRegression(Double_t &bias, Double_t &biasT, Double_t &dev, Double_t &devT, Double_t &rms, Double_t &rmsT, Double_t &mInf, Double_t &mInfT, Double_t &corr, Types::ETreeType type)
calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
Definition: MethodBase.cxx:982
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
Definition: MethodBase.cxx:601
TString GetMethodTypeName() const
Definition: MethodBase.h:323
Bool_t DoMulticlass() const
Definition: MethodBase.h:430
virtual Double_t GetSignificance() const
compute significance of mean difference
const char * GetName() const
Definition: MethodBase.h:325
Types::EAnalysisType GetAnalysisType() const
Definition: MethodBase.h:428
virtual TMatrixD GetMulticlassConfusionMatrix(Double_t effB, Types::ETreeType type)
Construct a confusion matrix for a multiclass classifier.
void PrintHelpMessage() const
prints out method-specific help method
virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype)
writes all MVA evaluation histograms to file
virtual void TestMulticlass()
test multiclass classification
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:411
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition: MethodBase.h:427
const TString & GetMethodName() const
Definition: MethodBase.h:322
Bool_t DoRegression() const
Definition: MethodBase.h:429
void ProcessSetup()
process all options; the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:428
virtual Double_t GetTrainingEfficiency(const TString &)
DataSetInfo & DataInfo() const
Definition: MethodBase.h:401
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
virtual void TestClassification()
initialization
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
void ReadStateFromFile()
Function to write options and weights to file.
virtual std::map< TString, Double_t > OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA")
call the Optimizer with the set of parameters and ranges that are meant to be tuned.
Definition: MethodBase.cxx:628
DataSetInfo & fDataSetInfo
Definition: MethodBase.h:596
Types::EMVA GetMethodType() const
Definition: MethodBase.h:324
void SetFile(TFile *file)
Definition: MethodBase.h:366
DataSet * Data() const
Definition: MethodBase.h:400
void SetModelPersistence(Bool_t status)
Definition: MethodBase.h:373
virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const
calculate the area (integral) under the ROC curve as a overall quality measure of the classification
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:438
Class for boosting a TMVA method.
Definition: MethodBoost.h:58
void SetBoostedMethodName(TString methodName)
Definition: MethodBoost.h:86
DataSetManager * fDataSetManager
Definition: MethodBoost.h:193
Class for categorizing the phase space.
DataSetManager * fDataSetManager
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
void SetMinType(EMsgType minType)
Definition: MsgLogger.h:72
void SetSource(const std::string &source)
Definition: MsgLogger.h:70
static void InhibitOutput()
Definition: MsgLogger.cxx:74
Double_t GetEffSForEffB(Double_t effB, const UInt_t num_points=41)
Calculate the signal efficiency (sensitivity) for a given background efficiency (sensitivity).
Definition: ROCCurve.cxx:220
Double_t GetROCIntegral(const UInt_t points=41)
Calculates the ROC integral (AUC)
Definition: ROCCurve.cxx:251
TGraph * GetROCCurve(const UInt_t points=100)
Returns a new TGraph containing the ROC curve.
Definition: ROCCurve.cxx:277
Ranking for variables in method (implementation)
Definition: Ranking.h:48
virtual void Print() const
get maximum length of variable names
Definition: Ranking.cxx:111
Class that is the base-class for a vector of result.
Class which takes the results of a multiclass classification.
Class that is the base-class for a vector of result.
Definition: Results.h:57
void FormattedOutput(const std::vector< Double_t > &, const std::vector< TString > &, const TString titleVars, const TString titleValues, MsgLogger &logger, TString format="%+1.3f")
formatted output of simple table
Definition: Tools.cxx:899
void UsefulSortDescending(std::vector< std::vector< Double_t > > &, std::vector< TString > *vs=0)
sort 2D vector (AND in parallel a TString vector) in such a way that the "first vector is sorted" and...
Definition: Tools.cxx:576
void ROOTVersionMessage(MsgLogger &logger)
prints the ROOT release number and date
Definition: Tools.cxx:1337
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition: Tools.cxx:1211
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:840
const TMatrixD * GetCorrelationMatrix(const TMatrixD *covMat)
turns covariance into correlation matrix
Definition: Tools.cxx:336
@ kHtmlLink
Definition: Tools.h:216
void TMVACitation(MsgLogger &logger, ECitation citType=kPlainText)
kinds of TMVA citation
Definition: Tools.cxx:1453
void TMVAVersionMessage(MsgLogger &logger)
prints the TMVA release number and date
Definition: Tools.cxx:1328
void TMVAWelcomeMessage()
direct output, eg, when starting ROOT session -> no use of Logger here
Definition: Tools.cxx:1314
void UsefulSortAscending(std::vector< std::vector< Double_t > > &, std::vector< TString > *vs=0)
sort 2D vector (AND in parallel a TString vector) in such a way that the "first vector is sorted" and...
Definition: Tools.cxx:550
Class that contains all the data information.
void PrintVariableRanking() const
prints ranking of input variables
Singleton class for Global types used by TMVA.
Definition: Types.h:73
static Types & Instance()
returns the single instance of "Types" if it exists already, or creates it (Singleton)
Definition: Types.cxx:70
@ kCategory
Definition: Types.h:99
@ kCuts
Definition: Types.h:80
EAnalysisType
Definition: Types.h:127
@ kMulticlass
Definition: Types.h:130
@ kNoAnalysisType
Definition: Types.h:131
@ kClassification
Definition: Types.h:128
@ kMaxAnalysisType
Definition: Types.h:132
@ kRegression
Definition: Types.h:129
@ kTraining
Definition: Types.h:144
@ kTesting
Definition: Types.h:145
const TString & GetLabel() const
Definition: VariableInfo.h:59
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:35
TList * GetListOfGraphs() const
Definition: TMultiGraph.h:69
virtual void Add(TGraph *graph, Option_t *chopt="")
Add a new graph to the list of graphs.
TH1F * GetHistogram()
Returns a pointer to the histogram used to draw the axis.
virtual void Draw(Option_t *chopt="")
Draw this multigraph with its current attributes.
TAxis * GetYaxis()
Get y axis of the graph.
TAxis * GetXaxis()
Get x axis of the graph.
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
Definition: TNamed.cxx:164
virtual TObject * Clone(const char *newname="") const
Make a clone of an object using the Streamer facility.
Definition: TNamed.cxx:74
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
@ kOverwrite
overwrite existing object with same name
Definition: TObject.h:88
virtual TLegend * BuildLegend(Double_t x1=0.3, Double_t y1=0.21, Double_t x2=0.3, Double_t y2=0.21, const char *title="", Option_t *option="")
Build a legend from the graphical objects in the pad.
Definition: TPad.cxx:494
virtual void SetGrid(Int_t valuex=1, Int_t valuey=1)
Definition: TPad.h:328
Principal Components Analysis (PCA)
Definition: TPrincipal.h:20
virtual void AddRow(const Double_t *x)
Add a data point and update the covariance matrix.
Definition: TPrincipal.cxx:410
const TMatrixD * GetCovarianceMatrix() const
Definition: TPrincipal.h:58
virtual void MakePrincipals()
Perform the principal components analysis.
Definition: TPrincipal.cxx:869
Random number generator class based on M.
Definition: TRandom3.h:27
Basic string class.
Definition: TString.h:131
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1100
int CompareTo(const char *cs, ECaseCompare cmp=kExact) const
Compare a string to char *cs2.
Definition: TString.cxx:406
const char * Data() const
Definition: TString.h:364
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:687
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
Definition: TString.h:610
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:619
void SetOptStat(Int_t stat=1)
The type of information printed in the histogram statistics box can be selected via the parameter mod...
Definition: TStyle.cxx:1444
void SetTitleXOffset(Float_t offset=1)
Definition: TStyle.h:386
virtual int MakeDirectory(const char *name)
Make a directory.
Definition: TSystem.cxx:834
virtual Int_t Write(const char *name=0, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
Definition: TTree.cxx:9338
Double_t y[n]
Definition: legend1.C:17
Double_t x[n]
Definition: legend1.C:17
const Int_t n
Definition: legend1.C:16
std::string GetMethodName(TCppMethod_t)
Definition: Cppyy.cxx:750
TCppMethod_t GetMethod(TCppScope_t scope, TCppIndex_t imeth)
Definition: Cppyy.cxx:744
static constexpr double s
void DataLoaderCopy(TMVA::DataLoader *des, TMVA::DataLoader *src)
Config & gConfig()
Tools & gTools()
void CreateVariableTransforms(const TString &trafoDefinition, TMVA::DataSetInfo &dataInfo, TMVA::TransformationHandler &transformationHandler, TMVA::MsgLogger &log)
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Bool_t IsNaN(Double_t x)
Definition: TMath.h:880
Double_t Log(Double_t x)
Definition: TMath.h:748
Definition: graph.py:1
auto * m
Definition: textangle.C:8
#define VIBITS
Definition: Factory.cxx:107
static long int sum(long int i)
Definition: Factory.cxx:2258
const Int_t MinNoTrainingEvents
Definition: Factory.cxx:99
#define READXML
Definition: Factory.cxx:104