Logo ROOT  
Reference Guide
MethodBase.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodBase *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation (see header for description) *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
16 * Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland *
17 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
18 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
19 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21 * *
22 * Copyright (c) 2005-2011: *
23 * CERN, Switzerland *
24 * U. of Victoria, Canada *
25 * MPI-K Heidelberg, Germany *
26 * U. of Bonn, Germany *
27 * *
28 * Redistribution and use in source and binary forms, with or without *
29 * modification, are permitted according to the terms listed in LICENSE *
30 * (http://tmva.sourceforge.net/LICENSE) *
31 * *
32 **********************************************************************************/
33
34/*! \class TMVA::MethodBase
35\ingroup TMVA
36
37 Virtual base class for all MVA methods.
38
39 MethodBase hosts several specific evaluation methods.
40
41 The kind of MVA that provides optimal performance in an analysis strongly
42 depends on the particular application. The evaluation factory provides a
43 number of numerical benchmark results to directly assess the performance
44 of the MVA training on the independent test sample. These are:
45
46 - The _signal efficiency_ at three representative background efficiencies
47 (where background efficiency = 1 &minus; background rejection).
48 - The _significance_ of an MVA estimator, defined by the difference
49 between the MVA mean values for signal and background, divided by the
50 quadratic sum of their root mean squares.
51 - The _separation_ of an MVA _x_, defined by the integral
52 \f[
53 \frac{1}{2} \int \frac{(S(x) - B(x))^2}{(S(x) + B(x))} dx
54 \f]
55 where
56 \f$ S(x) \f$ and \f$ B(x) \f$ are the signal and background distributions,
57 respectively. The separation is zero for identical signal and background MVA
58 shapes, and it is one for disjoint (non-overlapping) shapes.
59 - The average, \f$ \int x \mu (S(x)) dx \f$, of the signal \f$ \mu_{transform} \f$.
60 The \f$ \mu_{transform} \f$ of an MVA denotes the transformation that yields
61 a uniform background distribution. In this way, the signal distributions
62 \f$ S(x) \f$ can be directly compared among the various MVAs. The stronger
63 \f$ S(x) \f$ peaks towards one, the better the discrimination of the MVA.
64 The \f$ \mu_{transform} \f$ is
65 [documented here](http://tel.ccsd.cnrs.fr/documents/archives0/00/00/29/91/index_fr.html).
66
67 The MVA standard output also prints the linear correlation coefficients between
68 signal and background, which can be useful to eliminate variables that exhibit too
69 strong correlations.
70*/
71
72#include "TMVA/MethodBase.h"
73
74#include "TMVA/Config.h"
75#include "TMVA/Configurable.h"
76#include "TMVA/DataSetInfo.h"
77#include "TMVA/DataSet.h"
78#include "TMVA/Factory.h"
79#include "TMVA/IMethod.h"
80#include "TMVA/MsgLogger.h"
81#include "TMVA/PDF.h"
82#include "TMVA/Ranking.h"
83#include "TMVA/DataLoader.h"
84#include "TMVA/Tools.h"
85#include "TMVA/Results.h"
89#include "TMVA/RootFinder.h"
90#include "TMVA/Timer.h"
91#include "TMVA/TSpline1.h"
92#include "TMVA/Types.h"
96#include "TMVA/VariableInfo.h"
100#include "TMVA/Version.h"
101
102#include "TROOT.h"
103#include "TSystem.h"
104#include "TObjString.h"
105#include "TQObject.h"
106#include "TSpline.h"
107#include "TMatrix.h"
108#include "TMath.h"
109#include "TH1F.h"
110#include "TH2F.h"
111#include "TFile.h"
112#include "TGraph.h"
113#include "TXMLEngine.h"
114
115#include <iomanip>
116#include <iostream>
117#include <fstream>
118#include <sstream>
119#include <cstdlib>
120#include <algorithm>
121#include <limits>
122
123
125
126using std::endl;
127using std::atof;
128
129//const Int_t MethodBase_MaxIterations_ = 200;
131
132//const Int_t NBIN_HIST_PLOT = 100;
133const Int_t NBIN_HIST_HIGH = 10000;
134
135#ifdef _WIN32
136/* Disable warning C4355: 'this' : used in base member initializer list */
137#pragma warning ( disable : 4355 )
138#endif
139
140
141#include "TMultiGraph.h"
142
143////////////////////////////////////////////////////////////////////////////////
144/// standard constructor
145
147{
148 fNumGraphs = 0;
149 fIndex = 0;
150}
151
152////////////////////////////////////////////////////////////////////////////////
153/// standard destructor
155{
156 if (fMultiGraph){
157 delete fMultiGraph;
158 fMultiGraph = nullptr;
159 }
160 return;
161}
162
163////////////////////////////////////////////////////////////////////////////////
164/// This function gets some title and it creates a TGraph for every title.
165/// It also sets up the style for every TGraph. All graphs are added to a single TMultiGraph.
166///
167/// \param[in] graphTitles vector of titles
168
169void TMVA::IPythonInteractive::Init(std::vector<TString>& graphTitles)
170{
171 if (fNumGraphs!=0){
172 std::cerr << kERROR << "IPythonInteractive::Init: already initialized..." << std::endl;
173 return;
174 }
175 Int_t color = 2;
176 for(auto& title : graphTitles){
177 fGraphs.push_back( new TGraph() );
178 fGraphs.back()->SetTitle(title);
179 fGraphs.back()->SetName(title);
180 fGraphs.back()->SetFillColor(color);
181 fGraphs.back()->SetLineColor(color);
182 fGraphs.back()->SetMarkerColor(color);
183 fMultiGraph->Add(fGraphs.back());
184 color += 2;
185 fNumGraphs += 1;
186 }
187 return;
188}
189
190////////////////////////////////////////////////////////////////////////////////
191/// This function sets the point number to 0 for all graphs.
192
194{
195 for(Int_t i=0; i<fNumGraphs; i++){
196 fGraphs[i]->Set(0);
197 }
198}
199
200////////////////////////////////////////////////////////////////////////////////
201/// This function is used only in 2 TGraph case, and it will add new data points to graphs.
202///
203/// \param[in] x the x coordinate
204/// \param[in] y1 the y coordinate for the first TGraph
205/// \param[in] y2 the y coordinate for the second TGraph
206
208{
209 fGraphs[0]->Set(fIndex+1);
210 fGraphs[1]->Set(fIndex+1);
211 fGraphs[0]->SetPoint(fIndex, x, y1);
212 fGraphs[1]->SetPoint(fIndex, x, y2);
213 fIndex++;
214 return;
215}
216
217////////////////////////////////////////////////////////////////////////////////
218/// This function can add data points to as many TGraphs as we have.
219///
220/// \param[in] dat vector of data points. The dat[0] contains the x coordinate,
221/// dat[1] contains the y coordinate for first TGraph, dat[2] for second, ...
222
223void TMVA::IPythonInteractive::AddPoint(std::vector<Double_t>& dat)
224{
225 for(Int_t i=0; i<fNumGraphs;i++){
226 fGraphs[i]->Set(fIndex+1);
227 fGraphs[i]->SetPoint(fIndex, dat[0], dat[i+1]);
228 }
229 fIndex++;
230 return;
231}
232
233
234////////////////////////////////////////////////////////////////////////////////
235/// standard constructor
236
238 Types::EMVA methodType,
239 const TString& methodTitle,
240 DataSetInfo& dsi,
241 const TString& theOption) :
242 IMethod(),
243 Configurable ( theOption ),
244 fTmpEvent ( 0 ),
245 fRanking ( 0 ),
246 fInputVars ( 0 ),
247 fAnalysisType ( Types::kNoAnalysisType ),
248 fRegressionReturnVal ( 0 ),
249 fMulticlassReturnVal ( 0 ),
250 fDataSetInfo ( dsi ),
251 fSignalReferenceCut ( 0.5 ),
252 fSignalReferenceCutOrientation( 1. ),
253 fVariableTransformType ( Types::kSignal ),
254 fJobName ( jobName ),
255 fMethodName ( methodTitle ),
256 fMethodType ( methodType ),
257 fTestvar ( "" ),
258 fTMVATrainingVersion ( TMVA_VERSION_CODE ),
259 fROOTTrainingVersion ( ROOT_VERSION_CODE ),
260 fConstructedFromWeightFile ( kFALSE ),
261 fBaseDir ( 0 ),
262 fMethodBaseDir ( 0 ),
263 fFile ( 0 ),
264 fSilentFile (kFALSE),
265 fModelPersistence (kTRUE),
266 fWeightFile ( "" ),
267 fEffS ( 0 ),
268 fDefaultPDF ( 0 ),
269 fMVAPdfS ( 0 ),
270 fMVAPdfB ( 0 ),
271 fSplS ( 0 ),
272 fSplB ( 0 ),
273 fSpleffBvsS ( 0 ),
274 fSplTrainS ( 0 ),
275 fSplTrainB ( 0 ),
276 fSplTrainEffBvsS ( 0 ),
277 fVarTransformString ( "None" ),
278 fTransformationPointer ( 0 ),
279 fTransformation ( dsi, methodTitle ),
280 fVerbose ( kFALSE ),
281 fVerbosityLevelString ( "Default" ),
282 fHelp ( kFALSE ),
283 fHasMVAPdfs ( kFALSE ),
284 fIgnoreNegWeightsInTraining( kFALSE ),
285 fSignalClass ( 0 ),
286 fBackgroundClass ( 0 ),
287 fSplRefS ( 0 ),
288 fSplRefB ( 0 ),
289 fSplTrainRefS ( 0 ),
290 fSplTrainRefB ( 0 ),
291 fSetupCompleted (kFALSE)
292{
295
296// // default extension for weight files
297}
298
299////////////////////////////////////////////////////////////////////////////////
300/// constructor used for Testing + Application of the MVA,
301/// only (no training), using given WeightFiles
302
304 DataSetInfo& dsi,
305 const TString& weightFile ) :
306 IMethod(),
307 Configurable(""),
308 fTmpEvent ( 0 ),
309 fRanking ( 0 ),
310 fInputVars ( 0 ),
311 fAnalysisType ( Types::kNoAnalysisType ),
312 fRegressionReturnVal ( 0 ),
313 fMulticlassReturnVal ( 0 ),
314 fDataSetInfo ( dsi ),
315 fSignalReferenceCut ( 0.5 ),
316 fVariableTransformType ( Types::kSignal ),
317 fJobName ( "" ),
318 fMethodName ( "MethodBase" ),
319 fMethodType ( methodType ),
320 fTestvar ( "" ),
321 fTMVATrainingVersion ( 0 ),
322 fROOTTrainingVersion ( 0 ),
323 fConstructedFromWeightFile ( kTRUE ),
324 fBaseDir ( 0 ),
325 fMethodBaseDir ( 0 ),
326 fFile ( 0 ),
327 fSilentFile (kFALSE),
328 fModelPersistence (kTRUE),
329 fWeightFile ( weightFile ),
330 fEffS ( 0 ),
331 fDefaultPDF ( 0 ),
332 fMVAPdfS ( 0 ),
333 fMVAPdfB ( 0 ),
334 fSplS ( 0 ),
335 fSplB ( 0 ),
336 fSpleffBvsS ( 0 ),
337 fSplTrainS ( 0 ),
338 fSplTrainB ( 0 ),
339 fSplTrainEffBvsS ( 0 ),
340 fVarTransformString ( "None" ),
341 fTransformationPointer ( 0 ),
342 fTransformation ( dsi, "" ),
343 fVerbose ( kFALSE ),
344 fVerbosityLevelString ( "Default" ),
345 fHelp ( kFALSE ),
346 fHasMVAPdfs ( kFALSE ),
347 fIgnoreNegWeightsInTraining( kFALSE ),
348 fSignalClass ( 0 ),
349 fBackgroundClass ( 0 ),
350 fSplRefS ( 0 ),
351 fSplRefB ( 0 ),
352 fSplTrainRefS ( 0 ),
353 fSplTrainRefB ( 0 ),
354 fSetupCompleted (kFALSE)
355{
357// // constructor used for Testing + Application of the MVA,
358// // only (no training), using given WeightFiles
359}
360
361////////////////////////////////////////////////////////////////////////////////
362/// destructor
363
365{
366 // destructor
367 if (!fSetupCompleted) Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Calling destructor of method which got never setup" << Endl;
368
369 // destructor
370 if (fInputVars != 0) { fInputVars->clear(); delete fInputVars; }
371 if (fRanking != 0) delete fRanking;
372
373 // PDFs
374 if (fDefaultPDF!= 0) { delete fDefaultPDF; fDefaultPDF = 0; }
375 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
376 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
377
378 // Splines
379 if (fSplS) { delete fSplS; fSplS = 0; }
380 if (fSplB) { delete fSplB; fSplB = 0; }
381 if (fSpleffBvsS) { delete fSpleffBvsS; fSpleffBvsS = 0; }
382 if (fSplRefS) { delete fSplRefS; fSplRefS = 0; }
383 if (fSplRefB) { delete fSplRefB; fSplRefB = 0; }
384 if (fSplTrainRefS) { delete fSplTrainRefS; fSplTrainRefS = 0; }
385 if (fSplTrainRefB) { delete fSplTrainRefB; fSplTrainRefB = 0; }
386 if (fSplTrainEffBvsS) { delete fSplTrainEffBvsS; fSplTrainEffBvsS = 0; }
387
388 for (Int_t i = 0; i < 2; i++ ) {
389 if (fEventCollections.at(i)) {
390 for (std::vector<Event*>::const_iterator it = fEventCollections.at(i)->begin();
391 it != fEventCollections.at(i)->end(); ++it) {
392 delete (*it);
393 }
394 delete fEventCollections.at(i);
395 fEventCollections.at(i) = 0;
396 }
397 }
398
399 if (fRegressionReturnVal) delete fRegressionReturnVal;
400 if (fMulticlassReturnVal) delete fMulticlassReturnVal;
401}
402
403////////////////////////////////////////////////////////////////////////////////
404/// setup of methods
405
407{
408 // setup of methods
409
410 if (fSetupCompleted) Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Calling SetupMethod for the second time" << Endl;
411 InitBase();
412 DeclareBaseOptions();
413 Init();
414 DeclareOptions();
415 fSetupCompleted = kTRUE;
416}
417
418////////////////////////////////////////////////////////////////////////////////
419/// process all options
420/// the "CheckForUnusedOptions" is done in an independent call, since it may be overridden by derived class
421/// (sometimes, eg, fitters are used which can only be implemented during training phase)
422
424{
425 ProcessBaseOptions();
426 ProcessOptions();
427}
428
429////////////////////////////////////////////////////////////////////////////////
430/// check may be overridden by derived class
431/// (sometimes, eg, fitters are used which can only be implemented during training phase)
432
434{
435 CheckForUnusedOptions();
436}
437
438////////////////////////////////////////////////////////////////////////////////
439/// default initialization called by all constructors
440
442{
443 SetConfigDescription( "Configuration options for classifier architecture and tuning" );
444
446 fNbinsMVAoutput = gConfig().fVariablePlotting.fNbinsMVAoutput;
447 fNbinsH = NBIN_HIST_HIGH;
448
449 fSplTrainS = 0;
450 fSplTrainB = 0;
451 fSplTrainEffBvsS = 0;
452 fMeanS = -1;
453 fMeanB = -1;
454 fRmsS = -1;
455 fRmsB = -1;
456 fXmin = DBL_MAX;
457 fXmax = -DBL_MAX;
458 fTxtWeightsOnly = kTRUE;
459 fSplRefS = 0;
460 fSplRefB = 0;
461
462 fTrainTime = -1.;
463 fTestTime = -1.;
464
465 fRanking = 0;
466
467 // temporary until the move to DataSet is complete
468 fInputVars = new std::vector<TString>;
469 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
470 fInputVars->push_back(DataInfo().GetVariableInfo(ivar).GetLabel());
471 }
472 fRegressionReturnVal = 0;
473 fMulticlassReturnVal = 0;
474
475 fEventCollections.resize( 2 );
476 fEventCollections.at(0) = 0;
477 fEventCollections.at(1) = 0;
478
479 // retrieve signal and background class index
480 if (DataInfo().GetClassInfo("Signal") != 0) {
481 fSignalClass = DataInfo().GetClassInfo("Signal")->GetNumber();
482 }
483 if (DataInfo().GetClassInfo("Background") != 0) {
484 fBackgroundClass = DataInfo().GetClassInfo("Background")->GetNumber();
485 }
486
487 SetConfigDescription( "Configuration options for MVA method" );
488 SetConfigName( TString("Method") + GetMethodTypeName() );
489}
490
491////////////////////////////////////////////////////////////////////////////////
492/// define the options (their key words) that can be set in the option string
493/// here the options valid for ALL MVA methods are declared.
494///
495/// known options:
496///
497/// - VariableTransform=None,Decorrelated,PCA to use transformed variables
498/// instead of the original ones
499/// - VariableTransformType=Signal,Background which decorrelation matrix to use
500/// in the method. Only the Likelihood
501/// Method can make proper use of independent
502/// transformations of signal and background
503/// - fNbinsMVAPdf = 60 Number of bins used to create a PDF of the MVA output
504/// - fNsmoothMVAPdf = 2 Number of times a histogram is smoothed before creating the PDF
505/// - fHasMVAPdfs create PDFs for the MVA outputs
506/// - V for verbose output (!V for non-verbose output)
507/// - H for Help message
508
510{
511 DeclareOptionRef( fVerbose, "V", "Verbose output (short form of \"VerbosityLevel\" below - overrides the latter one)" );
512
513 DeclareOptionRef( fVerbosityLevelString="Default", "VerbosityLevel", "Verbosity level" );
514 AddPreDefVal( TString("Default") ); // uses default defined in MsgLogger header
515 AddPreDefVal( TString("Debug") );
516 AddPreDefVal( TString("Verbose") );
517 AddPreDefVal( TString("Info") );
518 AddPreDefVal( TString("Warning") );
519 AddPreDefVal( TString("Error") );
520 AddPreDefVal( TString("Fatal") );
521
522 // If True (default): write all training results (weights) as text files only;
523 // if False: write also in ROOT format (not available for all methods - will abort if not
524 fTxtWeightsOnly = kTRUE; // OBSOLETE !!!
525 fNormalise = kFALSE; // OBSOLETE !!!
526
527 DeclareOptionRef( fVarTransformString, "VarTransform", "List of variable transformations performed before training, e.g., \"D_Background,P_Signal,G,N_AllClasses\" for: \"Decorrelation, PCA-transformation, Gaussianisation, Normalisation, each for the given class of events ('AllClasses' denotes all events of all classes, if no class indication is given, 'All' is assumed)\"" );
528
529 DeclareOptionRef( fHelp, "H", "Print method-specific help message" );
530
531 DeclareOptionRef( fHasMVAPdfs, "CreateMVAPdfs", "Create PDFs for classifier outputs (signal and background)" );
532
533 DeclareOptionRef( fIgnoreNegWeightsInTraining, "IgnoreNegWeightsInTraining",
534 "Events with negative weights are ignored in the training (but are included for testing and performance evaluation)" );
535}
536
537////////////////////////////////////////////////////////////////////////////////
538/// the option string is decoded, for available options see "DeclareOptions"
539
541{
542 if (HasMVAPdfs()) {
543 // setting the default bin num... maybe should be static ? ==> Please no static (JS)
544 // You can't use the logger in the constructor!!! Log() << kINFO << "Create PDFs" << Endl;
545 // reading every PDF's definition and passing the option string to the next one to be read and marked
546 fDefaultPDF = new PDF( TString(GetName())+"_PDF", GetOptions(), "MVAPdf" );
547 fDefaultPDF->DeclareOptions();
548 fDefaultPDF->ParseOptions();
549 fDefaultPDF->ProcessOptions();
550 fMVAPdfB = new PDF( TString(GetName())+"_PDFBkg", fDefaultPDF->GetOptions(), "MVAPdfBkg", fDefaultPDF );
551 fMVAPdfB->DeclareOptions();
552 fMVAPdfB->ParseOptions();
553 fMVAPdfB->ProcessOptions();
554 fMVAPdfS = new PDF( TString(GetName())+"_PDFSig", fMVAPdfB->GetOptions(), "MVAPdfSig", fDefaultPDF );
555 fMVAPdfS->DeclareOptions();
556 fMVAPdfS->ParseOptions();
557 fMVAPdfS->ProcessOptions();
558
559 // the final marked option string is written back to the original methodbase
560 SetOptions( fMVAPdfS->GetOptions() );
561 }
562
563 TMVA::CreateVariableTransforms( fVarTransformString,
564 DataInfo(),
565 GetTransformationHandler(),
566 Log() );
567
568 if (!HasMVAPdfs()) {
569 if (fDefaultPDF!= 0) { delete fDefaultPDF; fDefaultPDF = 0; }
570 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
571 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
572 }
573
574 if (fVerbose) { // overwrites other settings
575 fVerbosityLevelString = TString("Verbose");
576 Log().SetMinType( kVERBOSE );
577 }
578 else if (fVerbosityLevelString == "Debug" ) Log().SetMinType( kDEBUG );
579 else if (fVerbosityLevelString == "Verbose" ) Log().SetMinType( kVERBOSE );
580 else if (fVerbosityLevelString == "Info" ) Log().SetMinType( kINFO );
581 else if (fVerbosityLevelString == "Warning" ) Log().SetMinType( kWARNING );
582 else if (fVerbosityLevelString == "Error" ) Log().SetMinType( kERROR );
583 else if (fVerbosityLevelString == "Fatal" ) Log().SetMinType( kFATAL );
584 else if (fVerbosityLevelString != "Default" ) {
585 Log() << kFATAL << "<ProcessOptions> Verbosity level type '"
586 << fVerbosityLevelString << "' unknown." << Endl;
587 }
588 Event::SetIgnoreNegWeightsInTraining(fIgnoreNegWeightsInTraining);
589}
590
591////////////////////////////////////////////////////////////////////////////////
592/// options that are used ONLY for the READER to ensure backward compatibility
593/// they are hence without any effect (the reader is only reading the training
594/// options that HAD been used at the training of the .xml weight file at hand
595
597{
598 DeclareOptionRef( fNormalise=kFALSE, "Normalise", "Normalise input variables" ); // don't change the default !!!
599 DeclareOptionRef( fUseDecorr=kFALSE, "D", "Use-decorrelated-variables flag" );
600 DeclareOptionRef( fVariableTransformTypeString="Signal", "VarTransformType",
601 "Use signal or background events to derive for variable transformation (the transformation is applied on both types of, course)" );
602 AddPreDefVal( TString("Signal") );
603 AddPreDefVal( TString("Background") );
604 DeclareOptionRef( fTxtWeightsOnly=kTRUE, "TxtWeightFilesOnly", "If True: write all training results (weights) as text files (False: some are written in ROOT format)" );
605 // Why on earth ?? was this here? Was the verbosity level option meant to 'disappear? Not a good idea i think..
606 // DeclareOptionRef( fVerbosityLevelString="Default", "VerboseLevel", "Verbosity level" );
607 // AddPreDefVal( TString("Default") ); // uses default defined in MsgLogger header
608 // AddPreDefVal( TString("Debug") );
609 // AddPreDefVal( TString("Verbose") );
610 // AddPreDefVal( TString("Info") );
611 // AddPreDefVal( TString("Warning") );
612 // AddPreDefVal( TString("Error") );
613 // AddPreDefVal( TString("Fatal") );
614 DeclareOptionRef( fNbinsMVAPdf = 60, "NbinsMVAPdf", "Number of bins used for the PDFs of classifier outputs" );
615 DeclareOptionRef( fNsmoothMVAPdf = 2, "NsmoothMVAPdf", "Number of smoothing iterations for classifier PDFs" );
616}
617
618
619////////////////////////////////////////////////////////////////////////////////
620/// call the Optimizer with the set of parameters and ranges that
621/// are meant to be tuned.
622
623std::map<TString,Double_t> TMVA::MethodBase::OptimizeTuningParameters(TString /* fomType */ , TString /* fitType */)
624{
625 // this is just a dummy... needs to be implemented for each method
626 // individually (as long as we don't have it automatized via the
627 // configuration string
628
629 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Parameter optimization is not yet implemented for method "
630 << GetName() << Endl;
631 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Currently we need to set hardcoded which parameter is tuned in which ranges"<<Endl;
632
633 std::map<TString,Double_t> tunedParameters;
634 tunedParameters.size(); // just to get rid of "unused" warning
635 return tunedParameters;
636
637}
638
639////////////////////////////////////////////////////////////////////////////////
640/// set the tuning parameters according to the argument
641/// This is just a dummy .. have a look at the MethodBDT how you could
642/// perhaps implement the same thing for the other Classifiers..
643
644void TMVA::MethodBase::SetTuneParameters(std::map<TString,Double_t> /* tuneParameters */)
645{
646}
647
648////////////////////////////////////////////////////////////////////////////////
649
651{
652 Data()->SetCurrentType(Types::kTraining);
653 Event::SetIsTraining(kTRUE); // used to set negative event weights to zero if chosen to do so
654
655 // train the MVA method
656 if (Help()) PrintHelpMessage();
657
658 // all histograms should be created in the method's subdirectory
659 if(!IsSilentFile()) BaseDir()->cd();
660
661 // once calculate all the transformation (e.g. the sequence of Decorr:Gauss:Decorr)
662 // needed for this classifier
663 GetTransformationHandler().CalcTransformations(Data()->GetEventCollection());
664
665 // call training of derived MVA
666 Log() << kDEBUG //<<Form("\tDataset[%s] : ",DataInfo().GetName())
667 << "Begin training" << Endl;
668 Long64_t nEvents = Data()->GetNEvents();
669 Timer traintimer( nEvents, GetName(), kTRUE );
670 Train();
671 Log() << kDEBUG //<<Form("Dataset[%s] : ",DataInfo().GetName()
672 << "\tEnd of training " << Endl;
673 SetTrainTime(traintimer.ElapsedSeconds());
674 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
675 << "Elapsed time for training with " << nEvents << " events: "
676 << traintimer.GetElapsedTime() << " " << Endl;
677
678 Log() << kDEBUG //<<Form("Dataset[%s] : ",DataInfo().GetName())
679 << "\tCreate MVA output for ";
680
681 // create PDFs for the signal and background MVA distributions (if required)
682 if (DoMulticlass()) {
683 Log() <<Form("[%s] : ",DataInfo().GetName())<< "Multiclass classification on training sample" << Endl;
684 AddMulticlassOutput(Types::kTraining);
685 }
686 else if (!DoRegression()) {
687
688 Log() <<Form("[%s] : ",DataInfo().GetName())<< "classification on training sample" << Endl;
689 AddClassifierOutput(Types::kTraining);
690 if (HasMVAPdfs()) {
691 CreateMVAPdfs();
692 AddClassifierOutputProb(Types::kTraining);
693 }
694
695 } else {
696
697 Log() <<Form("Dataset[%s] : ",DataInfo().GetName())<< "regression on training sample" << Endl;
698 AddRegressionOutput( Types::kTraining );
699
700 if (HasMVAPdfs() ) {
701 Log() <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create PDFs" << Endl;
702 CreateMVAPdfs();
703 }
704 }
705
706 // write the current MVA state into stream
707 // produced are one text file and one ROOT file
708 if (fModelPersistence ) WriteStateToFile();
709
710 // produce standalone make class (presently only supported for classification)
711 if ((!DoRegression()) && (fModelPersistence)) MakeClass();
712
713 // write additional monitoring histograms to main target file (not the weight file)
714 // again, make sure the histograms go into the method's subdirectory
715 if(!IsSilentFile())
716 {
717 BaseDir()->cd();
718 WriteMonitoringHistosToFile();
719 }
720}
721
722////////////////////////////////////////////////////////////////////////////////
723
725{
726 if (!DoRegression()) Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Trying to use GetRegressionDeviation() with a classification job" << Endl;
727 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
729 bool truncate = false;
730 TH1F* h1 = regRes->QuadraticDeviation( tgtNum , truncate, 1.);
731 stddev = sqrt(h1->GetMean());
732 truncate = true;
733 Double_t yq[1], xq[]={0.9};
734 h1->GetQuantiles(1,yq,xq);
735 TH1F* h2 = regRes->QuadraticDeviation( tgtNum , truncate, yq[0]);
736 stddev90Percent = sqrt(h2->GetMean());
737 delete h1;
738 delete h2;
739}
740
741////////////////////////////////////////////////////////////////////////////////
742/// prepare tree branch with the method's discriminating variable
743
745{
746 Data()->SetCurrentType(type);
747
748 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
749
750 ResultsRegression* regRes = (ResultsRegression*)Data()->GetResults(GetMethodName(), type, Types::kRegression);
751
752 Long64_t nEvents = Data()->GetNEvents();
753
754 // use timer
755 Timer timer( nEvents, GetName(), kTRUE );
756 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName()) << "Evaluation of " << GetMethodName() << " on "
757 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
758
759 regRes->Resize( nEvents );
760
761 // Drawing the progress bar every event was causing a huge slowdown in the evaluation time
762 // So we set some parameters to draw the progress bar a total of totalProgressDraws, i.e. only draw every 1 in 100
763
764 Int_t totalProgressDraws = 100; // total number of times to update the progress bar
765 Int_t drawProgressEvery = 1; // draw every nth event such that we have a total of totalProgressDraws
766 if(nEvents >= totalProgressDraws) drawProgressEvery = nEvents/totalProgressDraws;
767
768 for (Int_t ievt=0; ievt<nEvents; ievt++) {
769
770 Data()->SetCurrentEvent(ievt);
771 std::vector< Float_t > vals = GetRegressionValues();
772 regRes->SetValue( vals, ievt );
773
774 // Only draw the progress bar once in a while, doing this every event causes the evaluation to be ridiculously slow
775 if(ievt % drawProgressEvery == 0 || ievt==nEvents-1) timer.DrawProgressBar( ievt );
776 }
777
778 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())
779 << "Elapsed time for evaluation of " << nEvents << " events: "
780 << timer.GetElapsedTime() << " " << Endl;
781
782 // store time used for testing
784 SetTestTime(timer.ElapsedSeconds());
785
786 TString histNamePrefix(GetTestvarName());
787 histNamePrefix += (type==Types::kTraining?"train":"test");
788 regRes->CreateDeviationHistograms( histNamePrefix );
789}
790
791////////////////////////////////////////////////////////////////////////////////
792/// prepare tree branch with the method's discriminating variable
793
795{
796 Data()->SetCurrentType(type);
797
798 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
799
800 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), type, Types::kMulticlass));
801 if (!resMulticlass) Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName())<< "unable to create pointer in AddMulticlassOutput, exiting."<<Endl;
802
803 Long64_t nEvents = Data()->GetNEvents();
804
805 // use timer
806 Timer timer( nEvents, GetName(), kTRUE );
807
808 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Multiclass evaluation of " << GetMethodName() << " on "
809 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
810
811 resMulticlass->Resize( nEvents );
812 for (Int_t ievt=0; ievt<nEvents; ievt++) {
813 Data()->SetCurrentEvent(ievt);
814 std::vector< Float_t > vals = GetMulticlassValues();
815 resMulticlass->SetValue( vals, ievt );
816 timer.DrawProgressBar( ievt );
817 }
818
819 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())
820 << "Elapsed time for evaluation of " << nEvents << " events: "
821 << timer.GetElapsedTime() << " " << Endl;
822
823 // store time used for testing
825 SetTestTime(timer.ElapsedSeconds());
826
827 TString histNamePrefix(GetTestvarName());
828 histNamePrefix += (type==Types::kTraining?"_Train":"_Test");
829
830 resMulticlass->CreateMulticlassHistos( histNamePrefix, fNbinsMVAoutput, fNbinsH );
831 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefix);
832}
833
834////////////////////////////////////////////////////////////////////////////////
835
836void TMVA::MethodBase::NoErrorCalc(Double_t* const err, Double_t* const errUpper) {
837 if (err) *err=-1;
838 if (errUpper) *errUpper=-1;
839}
840
841////////////////////////////////////////////////////////////////////////////////
842
843Double_t TMVA::MethodBase::GetMvaValue( const Event* const ev, Double_t* err, Double_t* errUpper ) {
844 fTmpEvent = ev;
845 Double_t val = GetMvaValue(err, errUpper);
846 fTmpEvent = 0;
847 return val;
848}
849
850////////////////////////////////////////////////////////////////////////////////
851/// uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation)
852/// for a quick determination if an event would be selected as signal or background
853
855 return GetMvaValue()*GetSignalReferenceCutOrientation() > GetSignalReferenceCut()*GetSignalReferenceCutOrientation() ? kTRUE : kFALSE;
856}
857////////////////////////////////////////////////////////////////////////////////
858/// uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation)
859/// for a quick determination if an event with this mva output value would be selected as signal or background
860
862 return mvaVal*GetSignalReferenceCutOrientation() > GetSignalReferenceCut()*GetSignalReferenceCutOrientation() ? kTRUE : kFALSE;
863}
864
865////////////////////////////////////////////////////////////////////////////////
866/// prepare tree branch with the method's discriminating variable
867
869{
870 Data()->SetCurrentType(type);
871
872 ResultsClassification* clRes =
874
875 Long64_t nEvents = Data()->GetNEvents();
876 clRes->Resize( nEvents );
877
878 // use timer
879 Timer timer( nEvents, GetName(), kTRUE );
880 std::vector<Double_t> mvaValues = GetMvaValues(0, nEvents, true);
881
882 // store time used for testing
884 SetTestTime(timer.ElapsedSeconds());
885
886 // load mva values and type to results object
887 for (Int_t ievt = 0; ievt < nEvents; ievt++) {
888 // note we do not need the trasformed event to get the signal/background information
889 // by calling Data()->GetEvent instead of this->GetEvent we access the untransformed one
890 auto ev = Data()->GetEvent(ievt);
891 clRes->SetValue(mvaValues[ievt], ievt, DataInfo().IsSignal(ev));
892 }
893}
894
895////////////////////////////////////////////////////////////////////////////////
896/// get all the MVA values for the events of the current Data type
897std::vector<Double_t> TMVA::MethodBase::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
898{
899
900 Long64_t nEvents = Data()->GetNEvents();
901 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
902 if (firstEvt < 0) firstEvt = 0;
903 std::vector<Double_t> values(lastEvt-firstEvt);
904 // log in case of looping on all the events
905 nEvents = values.size();
906
907 // use timer
908 Timer timer( nEvents, GetName(), kTRUE );
909
910 if (logProgress)
911 Log() << kHEADER << Form("[%s] : ",DataInfo().GetName())
912 << "Evaluation of " << GetMethodName() << " on "
913 << (Data()->GetCurrentType() == Types::kTraining ? "training" : "testing")
914 << " sample (" << nEvents << " events)" << Endl;
915
916 for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
917 Data()->SetCurrentEvent(ievt);
918 values[ievt] = GetMvaValue();
919
920 // print progress
921 if (logProgress) {
922 Int_t modulo = Int_t(nEvents/100);
923 if (modulo <= 0 ) modulo = 1;
924 if (ievt%modulo == 0) timer.DrawProgressBar( ievt );
925 }
926 }
927 if (logProgress) {
928 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
929 << "Elapsed time for evaluation of " << nEvents << " events: "
930 << timer.GetElapsedTime() << " " << Endl;
931 }
932
933 return values;
934}
935
936////////////////////////////////////////////////////////////////////////////////
937/// get all the MVA values for the events of the given Data type
938// (this is used by Method Category and it does not need to be re-implmented by derived classes )
939std::vector<Double_t> TMVA::MethodBase::GetDataMvaValues(DataSet * data, Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
940{
941 fTmpData = data;
942 auto result = GetMvaValues(firstEvt, lastEvt, logProgress);
943 fTmpData = nullptr;
944 return result;
945}
946
947////////////////////////////////////////////////////////////////////////////////
948/// prepare tree branch with the method's discriminating variable
949
951{
952 Data()->SetCurrentType(type);
953
954 ResultsClassification* mvaProb =
955 (ResultsClassification*)Data()->GetResults(TString("prob_")+GetMethodName(), type, Types::kClassification );
956
957 Long64_t nEvents = Data()->GetNEvents();
958
959 // use timer
960 Timer timer( nEvents, GetName(), kTRUE );
961
962 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName()) << "Evaluation of " << GetMethodName() << " on "
963 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
964
965 mvaProb->Resize( nEvents );
966 for (Int_t ievt=0; ievt<nEvents; ievt++) {
967
968 Data()->SetCurrentEvent(ievt);
969 Float_t proba = ((Float_t)GetProba( GetMvaValue(), 0.5 ));
970 if (proba < 0) break;
971 mvaProb->SetValue( proba, ievt, DataInfo().IsSignal( Data()->GetEvent()) );
972
973 // print progress
974 Int_t modulo = Int_t(nEvents/100);
975 if (modulo <= 0 ) modulo = 1;
976 if (ievt%modulo == 0) timer.DrawProgressBar( ievt );
977 }
978
979 Log() << kDEBUG <<Form("Dataset[%s] : ",DataInfo().GetName())
980 << "Elapsed time for evaluation of " << nEvents << " events: "
981 << timer.GetElapsedTime() << " " << Endl;
982}
983
984////////////////////////////////////////////////////////////////////////////////
985/// calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
986///
987/// - bias = average deviation
988/// - dev = average absolute deviation
989/// - rms = rms of deviation
990
992 Double_t& dev, Double_t& devT,
993 Double_t& rms, Double_t& rmsT,
994 Double_t& mInf, Double_t& mInfT,
995 Double_t& corr,
997{
998 Types::ETreeType savedType = Data()->GetCurrentType();
999 Data()->SetCurrentType(type);
1000
1001 bias = 0; biasT = 0; dev = 0; devT = 0; rms = 0; rmsT = 0;
1002 Double_t sumw = 0;
1003 Double_t m1 = 0, m2 = 0, s1 = 0, s2 = 0, s12 = 0; // for correlation
1004 const Int_t nevt = GetNEvents();
1005 Float_t* rV = new Float_t[nevt];
1006 Float_t* tV = new Float_t[nevt];
1007 Float_t* wV = new Float_t[nevt];
1008 Float_t xmin = 1e30, xmax = -1e30;
1009 Log() << kINFO << "Calculate regression for all events" << Endl;
1010 Timer timer( nevt, GetName(), kTRUE );
1011 for (Long64_t ievt=0; ievt<nevt; ievt++) {
1012
1013 const Event* ev = Data()->GetEvent(ievt); // NOTE: need untransformed event here !
1014 Float_t t = ev->GetTarget(0);
1015 Float_t w = ev->GetWeight();
1016 Float_t r = GetRegressionValues()[0];
1017 Float_t d = (r-t);
1018
1019 // find min/max
1022
1023 // store for truncated RMS computation
1024 rV[ievt] = r;
1025 tV[ievt] = t;
1026 wV[ievt] = w;
1027
1028 // compute deviation-squared
1029 sumw += w;
1030 bias += w * d;
1031 dev += w * TMath::Abs(d);
1032 rms += w * d * d;
1033
1034 // compute correlation between target and regression estimate
1035 m1 += t*w; s1 += t*t*w;
1036 m2 += r*w; s2 += r*r*w;
1037 s12 += t*r;
1038 if ((ievt & 0xFF) == 0) timer.DrawProgressBar(ievt);
1039 }
1040 timer.DrawProgressBar(nevt - 1);
1041 Log() << kINFO << "Elapsed time for evaluation of " << nevt << " events: "
1042 << timer.GetElapsedTime() << " " << Endl;
1043
1044 // standard quantities
1045 bias /= sumw;
1046 dev /= sumw;
1047 rms /= sumw;
1048 rms = TMath::Sqrt(rms - bias*bias);
1049
1050 // correlation
1051 m1 /= sumw;
1052 m2 /= sumw;
1053 corr = s12/sumw - m1*m2;
1054 corr /= TMath::Sqrt( (s1/sumw - m1*m1) * (s2/sumw - m2*m2) );
1055
1056 // create histogram required for computation of mutual information
1057 TH2F* hist = new TH2F( "hist", "hist", 150, xmin, xmax, 100, xmin, xmax );
1058 TH2F* histT = new TH2F( "histT", "histT", 150, xmin, xmax, 100, xmin, xmax );
1059
1060 // compute truncated RMS and fill histogram
1061 Double_t devMax = bias + 2*rms;
1062 Double_t devMin = bias - 2*rms;
1063 sumw = 0;
1064 int ic=0;
1065 for (Long64_t ievt=0; ievt<nevt; ievt++) {
1066 Float_t d = (rV[ievt] - tV[ievt]);
1067 hist->Fill( rV[ievt], tV[ievt], wV[ievt] );
1068 if (d >= devMin && d <= devMax) {
1069 sumw += wV[ievt];
1070 biasT += wV[ievt] * d;
1071 devT += wV[ievt] * TMath::Abs(d);
1072 rmsT += wV[ievt] * d * d;
1073 histT->Fill( rV[ievt], tV[ievt], wV[ievt] );
1074 ic++;
1075 }
1076 }
1077 biasT /= sumw;
1078 devT /= sumw;
1079 rmsT /= sumw;
1080 rmsT = TMath::Sqrt(rmsT - biasT*biasT);
1081 mInf = gTools().GetMutualInformation( *hist );
1082 mInfT = gTools().GetMutualInformation( *histT );
1083
1084 delete hist;
1085 delete histT;
1086
1087 delete [] rV;
1088 delete [] tV;
1089 delete [] wV;
1090
1091 Data()->SetCurrentType(savedType);
1092}
1093
1094
1095////////////////////////////////////////////////////////////////////////////////
1096/// test multiclass classification
1097
1099{
1100 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTesting, Types::kMulticlass));
1101 if (!resMulticlass) Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName())<< "unable to create pointer in TestMulticlass, exiting."<<Endl;
1102
1103 // GA evaluation of best cut for sig eff * sig pur. Slow, disabled for now.
1104 // Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Determine optimal multiclass cuts for test
1105 // data..." << Endl; for (UInt_t icls = 0; icls<DataInfo().GetNClasses(); ++icls) {
1106 // resMulticlass->GetBestMultiClassCuts(icls);
1107 // }
1108
1109 // Create histograms for use in TMVA GUI
1110 TString histNamePrefix(GetTestvarName());
1111 TString histNamePrefixTest{histNamePrefix + "_Test"};
1112 TString histNamePrefixTrain{histNamePrefix + "_Train"};
1113
1114 resMulticlass->CreateMulticlassHistos(histNamePrefixTest, fNbinsMVAoutput, fNbinsH);
1115 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefixTest);
1116
1117 resMulticlass->CreateMulticlassHistos(histNamePrefixTrain, fNbinsMVAoutput, fNbinsH);
1118 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefixTrain);
1119}
1120
1121
1122////////////////////////////////////////////////////////////////////////////////
1123/// initialization
1124
1126{
1127 Data()->SetCurrentType(Types::kTesting);
1128
1129 ResultsClassification* mvaRes = dynamic_cast<ResultsClassification*>
1130 ( Data()->GetResults(GetMethodName(),Types::kTesting, Types::kClassification) );
1131
1132 // sanity checks: tree must exist, and theVar must be in tree
1133 if (0==mvaRes && !(GetMethodTypeName().Contains("Cuts"))) {
1134 Log()<<Form("Dataset[%s] : ",DataInfo().GetName()) << "mvaRes " << mvaRes << " GetMethodTypeName " << GetMethodTypeName()
1135 << " contains " << !(GetMethodTypeName().Contains("Cuts")) << Endl;
1136 Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName()) << "<TestInit> Test variable " << GetTestvarName()
1137 << " not found in tree" << Endl;
1138 }
1139
1140 // basic statistics operations are made in base class
1141 gTools().ComputeStat( GetEventCollection(Types::kTesting), mvaRes->GetValueVector(),
1142 fMeanS, fMeanB, fRmsS, fRmsB, fXmin, fXmax, fSignalClass );
1143
1144 // choose reasonable histogram ranges, by removing outliers
1145 Double_t nrms = 10;
1146 fXmin = TMath::Max( TMath::Min( fMeanS - nrms*fRmsS, fMeanB - nrms*fRmsB ), fXmin );
1147 fXmax = TMath::Min( TMath::Max( fMeanS + nrms*fRmsS, fMeanB + nrms*fRmsB ), fXmax );
1148
1149 // determine cut orientation
1150 fCutOrientation = (fMeanS > fMeanB) ? kPositive : kNegative;
1151
1152 // fill 2 types of histograms for the various analyses
1153 // this one is for actual plotting
1154
1155 Double_t sxmax = fXmax+0.00001;
1156
1157 // classifier response distributions for training sample
1158 // MVA plots used for graphics representation (signal)
1159 TString TestvarName;
1160 if(IsSilentFile())
1161 {
1162 TestvarName=Form("[%s]%s",DataInfo().GetName(),GetTestvarName().Data());
1163 }else
1164 {
1165 TestvarName=GetTestvarName();
1166 }
1167 TH1* mva_s = new TH1D( TestvarName + "_S",TestvarName + "_S", fNbinsMVAoutput, fXmin, sxmax );
1168 TH1* mva_b = new TH1D( TestvarName + "_B",TestvarName + "_B", fNbinsMVAoutput, fXmin, sxmax );
1169 mvaRes->Store(mva_s, "MVA_S");
1170 mvaRes->Store(mva_b, "MVA_B");
1171 mva_s->Sumw2();
1172 mva_b->Sumw2();
1173
1174 TH1* proba_s = 0;
1175 TH1* proba_b = 0;
1176 TH1* rarity_s = 0;
1177 TH1* rarity_b = 0;
1178 if (HasMVAPdfs()) {
1179 // P(MVA) plots used for graphics representation
1180 proba_s = new TH1D( TestvarName + "_Proba_S", TestvarName + "_Proba_S", fNbinsMVAoutput, 0.0, 1.0 );
1181 proba_b = new TH1D( TestvarName + "_Proba_B", TestvarName + "_Proba_B", fNbinsMVAoutput, 0.0, 1.0 );
1182 mvaRes->Store(proba_s, "Prob_S");
1183 mvaRes->Store(proba_b, "Prob_B");
1184 proba_s->Sumw2();
1185 proba_b->Sumw2();
1186
1187 // R(MVA) plots used for graphics representation
1188 rarity_s = new TH1D( TestvarName + "_Rarity_S", TestvarName + "_Rarity_S", fNbinsMVAoutput, 0.0, 1.0 );
1189 rarity_b = new TH1D( TestvarName + "_Rarity_B", TestvarName + "_Rarity_B", fNbinsMVAoutput, 0.0, 1.0 );
1190 mvaRes->Store(rarity_s, "Rar_S");
1191 mvaRes->Store(rarity_b, "Rar_B");
1192 rarity_s->Sumw2();
1193 rarity_b->Sumw2();
1194 }
1195
1196 // MVA plots used for efficiency calculations (large number of bins)
1197 TH1* mva_eff_s = new TH1D( TestvarName + "_S_high", TestvarName + "_S_high", fNbinsH, fXmin, sxmax );
1198 TH1* mva_eff_b = new TH1D( TestvarName + "_B_high", TestvarName + "_B_high", fNbinsH, fXmin, sxmax );
1199 mvaRes->Store(mva_eff_s, "MVA_HIGHBIN_S");
1200 mvaRes->Store(mva_eff_b, "MVA_HIGHBIN_B");
1201 mva_eff_s->Sumw2();
1202 mva_eff_b->Sumw2();
1203
1204 // fill the histograms
1205
1206 ResultsClassification* mvaProb = dynamic_cast<ResultsClassification*>
1207 (Data()->GetResults( TString("prob_")+GetMethodName(), Types::kTesting, Types::kMaxAnalysisType ) );
1208
1209 Log() << kHEADER <<Form("[%s] : ",DataInfo().GetName())<< "Loop over test events and fill histograms with classifier response..." << Endl << Endl;
1210 if (mvaProb) Log() << kINFO << "Also filling probability and rarity histograms (on request)..." << Endl;
1211 //std::vector<Bool_t>* mvaResTypes = mvaRes->GetValueVectorTypes();
1212
1213 //LM: this is needed to avoid crashes in ROOCCURVE
1214 if ( mvaRes->GetSize() != GetNEvents() ) {
1215 Log() << kFATAL << TString::Format("Inconsistent result size %lld with number of events %u ", mvaRes->GetSize() , GetNEvents() ) << Endl;
1216 assert(mvaRes->GetSize() == GetNEvents());
1217 }
1218
1219 for (Long64_t ievt=0; ievt<GetNEvents(); ievt++) {
1220
1221 const Event* ev = GetEvent(ievt);
1222 Float_t v = (*mvaRes)[ievt][0];
1223 Float_t w = ev->GetWeight();
1224
1225 if (DataInfo().IsSignal(ev)) {
1226 //mvaResTypes->push_back(kTRUE);
1227 mva_s ->Fill( v, w );
1228 if (mvaProb) {
1229 proba_s->Fill( (*mvaProb)[ievt][0], w );
1230 rarity_s->Fill( GetRarity( v ), w );
1231 }
1232
1233 mva_eff_s ->Fill( v, w );
1234 }
1235 else {
1236 //mvaResTypes->push_back(kFALSE);
1237 mva_b ->Fill( v, w );
1238 if (mvaProb) {
1239 proba_b->Fill( (*mvaProb)[ievt][0], w );
1240 rarity_b->Fill( GetRarity( v ), w );
1241 }
1242 mva_eff_b ->Fill( v, w );
1243 }
1244 }
1245
1246 // uncomment those (and several others if you want unnormalized output
1247 gTools().NormHist( mva_s );
1248 gTools().NormHist( mva_b );
1249 gTools().NormHist( proba_s );
1250 gTools().NormHist( proba_b );
1251 gTools().NormHist( rarity_s );
1252 gTools().NormHist( rarity_b );
1253 gTools().NormHist( mva_eff_s );
1254 gTools().NormHist( mva_eff_b );
1255
1256 // create PDFs from histograms, using default splines, and no additional smoothing
1257 if (fSplS) { delete fSplS; fSplS = 0; }
1258 if (fSplB) { delete fSplB; fSplB = 0; }
1259 fSplS = new PDF( TString(GetName()) + " PDF Sig", mva_s, PDF::kSpline2 );
1260 fSplB = new PDF( TString(GetName()) + " PDF Bkg", mva_b, PDF::kSpline2 );
1261}
1262
1263////////////////////////////////////////////////////////////////////////////////
1264/// general method used in writing the header of the weight files where
1265/// the used variables, variable transformation type etc. is specified
1266
1267void TMVA::MethodBase::WriteStateToStream( std::ostream& tf ) const
1268{
1269 TString prefix = "";
1270 UserGroup_t * userInfo = gSystem->GetUserInfo();
1271
1272 tf << prefix << "#GEN -*-*-*-*-*-*-*-*-*-*-*- general info -*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1273 tf << prefix << "Method : " << GetMethodTypeName() << "::" << GetMethodName() << std::endl;
1274 tf.setf(std::ios::left);
1275 tf << prefix << "TMVA Release : " << std::setw(10) << GetTrainingTMVAVersionString() << " ["
1276 << GetTrainingTMVAVersionCode() << "]" << std::endl;
1277 tf << prefix << "ROOT Release : " << std::setw(10) << GetTrainingROOTVersionString() << " ["
1278 << GetTrainingROOTVersionCode() << "]" << std::endl;
1279 tf << prefix << "Creator : " << userInfo->fUser << std::endl;
1280 tf << prefix << "Date : "; TDatime *d = new TDatime; tf << d->AsString() << std::endl; delete d;
1281 tf << prefix << "Host : " << gSystem->GetBuildNode() << std::endl;
1282 tf << prefix << "Dir : " << gSystem->WorkingDirectory() << std::endl;
1283 tf << prefix << "Training events: " << Data()->GetNTrainingEvents() << std::endl;
1284
1285 TString analysisType(((const_cast<TMVA::MethodBase*>(this)->GetAnalysisType()==Types::kRegression) ? "Regression" : "Classification"));
1286
1287 tf << prefix << "Analysis type : " << "[" << ((GetAnalysisType()==Types::kRegression) ? "Regression" : "Classification") << "]" << std::endl;
1288 tf << prefix << std::endl;
1289
1290 delete userInfo;
1291
1292 // First write all options
1293 tf << prefix << std::endl << prefix << "#OPT -*-*-*-*-*-*-*-*-*-*-*-*- options -*-*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1294 WriteOptionsToStream( tf, prefix );
1295 tf << prefix << std::endl;
1296
1297 // Second write variable info
1298 tf << prefix << std::endl << prefix << "#VAR -*-*-*-*-*-*-*-*-*-*-*-* variables *-*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1299 WriteVarsToStream( tf, prefix );
1300 tf << prefix << std::endl;
1301}
1302
1303////////////////////////////////////////////////////////////////////////////////
1304/// xml writing
1305
1306void TMVA::MethodBase::AddInfoItem( void* gi, const TString& name, const TString& value) const
1307{
1308 void* it = gTools().AddChild(gi,"Info");
1309 gTools().AddAttr(it,"name", name);
1310 gTools().AddAttr(it,"value", value);
1311}
1312
1313////////////////////////////////////////////////////////////////////////////////
1314
1316 if (analysisType == Types::kRegression) {
1317 AddRegressionOutput( type );
1318 } else if (analysisType == Types::kMulticlass) {
1319 AddMulticlassOutput( type );
1320 } else {
1321 AddClassifierOutput( type );
1322 if (HasMVAPdfs())
1323 AddClassifierOutputProb( type );
1324 }
1325}
1326
1327////////////////////////////////////////////////////////////////////////////////
1328/// general method used in writing the header of the weight files where
1329/// the used variables, variable transformation type etc. is specified
1330
1331void TMVA::MethodBase::WriteStateToXML( void* parent ) const
1332{
1333 if (!parent) return;
1334
1335 UserGroup_t* userInfo = gSystem->GetUserInfo();
1336
1337 void* gi = gTools().AddChild(parent, "GeneralInfo");
1338 AddInfoItem( gi, "TMVA Release", GetTrainingTMVAVersionString() + " [" + gTools().StringFromInt(GetTrainingTMVAVersionCode()) + "]" );
1339 AddInfoItem( gi, "ROOT Release", GetTrainingROOTVersionString() + " [" + gTools().StringFromInt(GetTrainingROOTVersionCode()) + "]");
1340 AddInfoItem( gi, "Creator", userInfo->fUser);
1341 TDatime dt; AddInfoItem( gi, "Date", dt.AsString());
1342 AddInfoItem( gi, "Host", gSystem->GetBuildNode() );
1343 AddInfoItem( gi, "Dir", gSystem->WorkingDirectory());
1344 AddInfoItem( gi, "Training events", gTools().StringFromInt(Data()->GetNTrainingEvents()));
1345 AddInfoItem( gi, "TrainingTime", gTools().StringFromDouble(const_cast<TMVA::MethodBase*>(this)->GetTrainTime()));
1346
1347 Types::EAnalysisType aType = const_cast<TMVA::MethodBase*>(this)->GetAnalysisType();
1348 TString analysisType((aType==Types::kRegression) ? "Regression" :
1349 (aType==Types::kMulticlass ? "Multiclass" : "Classification"));
1350 AddInfoItem( gi, "AnalysisType", analysisType );
1351 delete userInfo;
1352
1353 // write options
1354 AddOptionsXMLTo( parent );
1355
1356 // write variable info
1357 AddVarsXMLTo( parent );
1358
1359 // write spectator info
1360 if (fModelPersistence)
1361 AddSpectatorsXMLTo( parent );
1362
1363 // write class info if in multiclass mode
1364 AddClassesXMLTo(parent);
1365
1366 // write target info if in regression mode
1367 if (DoRegression()) AddTargetsXMLTo(parent);
1368
1369 // write transformations
1370 GetTransformationHandler(false).AddXMLTo( parent );
1371
1372 // write MVA variable distributions
1373 void* pdfs = gTools().AddChild(parent, "MVAPdfs");
1374 if (fMVAPdfS) fMVAPdfS->AddXMLTo(pdfs);
1375 if (fMVAPdfB) fMVAPdfB->AddXMLTo(pdfs);
1376
1377 // write weights
1378 AddWeightsXMLTo( parent );
1379}
1380
1381////////////////////////////////////////////////////////////////////////////////
1383/// read reference MVA distributions (and other information)
1384/// from a ROOT type weight file
1384
1386{
1387 Bool_t addDirStatus = TH1::AddDirectoryStatus();
1388 TH1::AddDirectory( 0 ); // this avoids the binding of the hists in PDF to the current ROOT file
1389 fMVAPdfS = (TMVA::PDF*)rf.Get( "MVA_PDF_Signal" );
1390 fMVAPdfB = (TMVA::PDF*)rf.Get( "MVA_PDF_Background" );
1391
1392 TH1::AddDirectory( addDirStatus );
1393
1394 ReadWeightsFromStream( rf );
1395
1396 SetTestvarName();
1397}
1398
1399////////////////////////////////////////////////////////////////////////////////
1401/// write options and weights to file
1402/// the state is written to a single XML weight file whose name is the
1403/// weight-file name with ".txt" replaced by ".xml"
1403
1405{
1406 // ---- create the text file
1407 TString tfname( GetWeightFileName() );
1408
1409 // writing xml file
1410 TString xmlfname( tfname ); xmlfname.ReplaceAll( ".txt", ".xml" );
1411 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
1412 << "Creating xml weight file: "
1413 << gTools().Color("lightblue") << xmlfname << gTools().Color("reset") << Endl;
1414 void* doc = gTools().xmlengine().NewDoc();
1415 void* rootnode = gTools().AddChild(0,"MethodSetup", "", true);
1416 gTools().xmlengine().DocSetRootElement(doc,rootnode);
1417 gTools().AddAttr(rootnode,"Method", GetMethodTypeName() + "::" + GetMethodName());
1418 WriteStateToXML(rootnode);
1419 gTools().xmlengine().SaveDoc(doc,xmlfname);
1420 gTools().xmlengine().FreeDoc(doc);
1421}
1422
1423////////////////////////////////////////////////////////////////////////////////
1425/// Function to read options and weights from file
1425
1427{
1428 // get the filename
1429
1430 TString tfname(GetWeightFileName());
1431
1432 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
1433 << "Reading weight file: "
1434 << gTools().Color("lightblue") << tfname << gTools().Color("reset") << Endl;
1435
1436 if (tfname.EndsWith(".xml") ) {
1437 void* doc = gTools().xmlengine().ParseFile(tfname,gTools().xmlenginebuffersize()); // the default buffer size in TXMLEngine::ParseFile is 100k. Starting with ROOT 5.29 one can set the buffer size, see: http://savannah.cern.ch/bugs/?78864. This might be necessary for large XML files
1438 if (!doc) {
1439 Log() << kFATAL << "Error parsing XML file " << tfname << Endl;
1440 }
1441 void* rootnode = gTools().xmlengine().DocGetRootElement(doc); // node "MethodSetup"
1442 ReadStateFromXML(rootnode);
1443 gTools().xmlengine().FreeDoc(doc);
1444 }
1445 else {
1446 std::filebuf fb;
1447 fb.open(tfname.Data(),std::ios::in);
1448 if (!fb.is_open()) { // file not found --> Error
1449 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<ReadStateFromFile> "
1450 << "Unable to open input weight file: " << tfname << Endl;
1451 }
1452 std::istream fin(&fb);
1453 ReadStateFromStream(fin);
1454 fb.close();
1455 }
1456 if (!fTxtWeightsOnly) {
1457 // ---- read the ROOT file
1458 TString rfname( tfname ); rfname.ReplaceAll( ".txt", ".root" );
1459 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Reading root weight file: "
1460 << gTools().Color("lightblue") << rfname << gTools().Color("reset") << Endl;
1461 TFile* rfile = TFile::Open( rfname, "READ" );
1462 ReadStateFromStream( *rfile );
1463 rfile->Close();
1464 }
1465}
1466////////////////////////////////////////////////////////////////////////////////
1467/// for reading from memory
1468
1470 void* doc = gTools().xmlengine().ParseString(xmlstr);
1471 void* rootnode = gTools().xmlengine().DocGetRootElement(doc); // node "MethodSetup"
1472 ReadStateFromXML(rootnode);
1473 gTools().xmlengine().FreeDoc(doc);
1474
1475 return;
1476}
1477
1478////////////////////////////////////////////////////////////////////////////////
1479
1481{
1482
1483 TString fullMethodName;
1484 gTools().ReadAttr( methodNode, "Method", fullMethodName );
1485
1486 fMethodName = fullMethodName(fullMethodName.Index("::")+2,fullMethodName.Length());
1487
1488 // update logger
1489 Log().SetSource( GetName() );
1490 Log() << kDEBUG//<<Form("Dataset[%s] : ",DataInfo().GetName())
1491 << "Read method \"" << GetMethodName() << "\" of type \"" << GetMethodTypeName() << "\"" << Endl;
1492
1493 // after the method name is read, the testvar can be set
1494 SetTestvarName();
1495
1496 TString nodeName("");
1497 void* ch = gTools().GetChild(methodNode);
1498 while (ch!=0) {
1499 nodeName = TString( gTools().GetName(ch) );
1500
1501 if (nodeName=="GeneralInfo") {
1502 // read analysis type
1503
1504 TString name(""),val("");
1505 void* antypeNode = gTools().GetChild(ch);
1506 while (antypeNode) {
1507 gTools().ReadAttr( antypeNode, "name", name );
1508
1509 if (name == "TrainingTime")
1510 gTools().ReadAttr( antypeNode, "value", fTrainTime );
1511
1512 if (name == "AnalysisType") {
1513 gTools().ReadAttr( antypeNode, "value", val );
1514 val.ToLower();
1515 if (val == "regression" ) SetAnalysisType( Types::kRegression );
1516 else if (val == "classification" ) SetAnalysisType( Types::kClassification );
1517 else if (val == "multiclass" ) SetAnalysisType( Types::kMulticlass );
1518 else Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Analysis type " << val << " is not known." << Endl;
1519 }
1520
1521 if (name == "TMVA Release" || name == "TMVA") {
1522 TString s;
1523 gTools().ReadAttr( antypeNode, "value", s);
1524 fTMVATrainingVersion = TString(s(s.Index("[")+1,s.Index("]")-s.Index("[")-1)).Atoi();
1525 Log() << kDEBUG <<Form("[%s] : ",DataInfo().GetName()) << "MVA method was trained with TMVA Version: " << GetTrainingTMVAVersionString() << Endl;
1526 }
1527
1528 if (name == "ROOT Release" || name == "ROOT") {
1529 TString s;
1530 gTools().ReadAttr( antypeNode, "value", s);
1531 fROOTTrainingVersion = TString(s(s.Index("[")+1,s.Index("]")-s.Index("[")-1)).Atoi();
1532 Log() << kDEBUG //<<Form("Dataset[%s] : ",DataInfo().GetName())
1533 << "MVA method was trained with ROOT Version: " << GetTrainingROOTVersionString() << Endl;
1534 }
1535 antypeNode = gTools().GetNextChild(antypeNode);
1536 }
1537 }
1538 else if (nodeName=="Options") {
1539 ReadOptionsFromXML(ch);
1540 ParseOptions();
1541
1542 }
1543 else if (nodeName=="Variables") {
1544 ReadVariablesFromXML(ch);
1545 }
1546 else if (nodeName=="Spectators") {
1547 ReadSpectatorsFromXML(ch);
1548 }
1549 else if (nodeName=="Classes") {
1550 if (DataInfo().GetNClasses()==0) ReadClassesFromXML(ch);
1551 }
1552 else if (nodeName=="Targets") {
1553 if (DataInfo().GetNTargets()==0 && DoRegression()) ReadTargetsFromXML(ch);
1554 }
1555 else if (nodeName=="Transformations") {
1556 GetTransformationHandler().ReadFromXML(ch);
1557 }
1558 else if (nodeName=="MVAPdfs") {
1559 TString pdfname;
1560 if (fMVAPdfS) { delete fMVAPdfS; fMVAPdfS=0; }
1561 if (fMVAPdfB) { delete fMVAPdfB; fMVAPdfB=0; }
1562 void* pdfnode = gTools().GetChild(ch);
1563 if (pdfnode) {
1564 gTools().ReadAttr(pdfnode, "Name", pdfname);
1565 fMVAPdfS = new PDF(pdfname);
1566 fMVAPdfS->ReadXML(pdfnode);
1567 pdfnode = gTools().GetNextChild(pdfnode);
1568 gTools().ReadAttr(pdfnode, "Name", pdfname);
1569 fMVAPdfB = new PDF(pdfname);
1570 fMVAPdfB->ReadXML(pdfnode);
1571 }
1572 }
1573 else if (nodeName=="Weights") {
1574 ReadWeightsFromXML(ch);
1575 }
1576 else {
1577 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Unparsed XML node: '" << nodeName << "'" << Endl;
1578 }
1579 ch = gTools().GetNextChild(ch);
1580
1581 }
1582
1583 // update transformation handler
1584 if (GetTransformationHandler().GetCallerName() == "") GetTransformationHandler().SetCallerName( GetName() );
1585}
1586
1587////////////////////////////////////////////////////////////////////////////////
1588/// read the header from the weight files of the different MVA methods
1589
1591{
1592 char buf[512];
1593
1594 // when reading from stream, we assume the files are produced with TMVA<=397
1595 SetAnalysisType(Types::kClassification);
1596
1597
1598 // first read the method name
1599 GetLine(fin,buf);
1600 while (!TString(buf).BeginsWith("Method")) GetLine(fin,buf);
1601 TString namestr(buf);
1602
1603 TString methodType = namestr(0,namestr.Index("::"));
1604 methodType = methodType(methodType.Last(' '),methodType.Length());
1605 methodType = methodType.Strip(TString::kLeading);
1606
1607 TString methodName = namestr(namestr.Index("::")+2,namestr.Length());
1608 methodName = methodName.Strip(TString::kLeading);
1609 if (methodName == "") methodName = methodType;
1610 fMethodName = methodName;
1611
1612 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Read method \"" << GetMethodName() << "\" of type \"" << GetMethodTypeName() << "\"" << Endl;
1613
1614 // update logger
1615 Log().SetSource( GetName() );
1616
1617 // now the question is whether to read the variables first or the options (well, of course the order
1618 // of writing them needs to agree)
1619 //
1620 // the option "Decorrelation" is needed to decide if the variables we
1621 // read are decorrelated or not
1622 //
1623 // the variables are needed by some methods (TMLP) to build the NN
1624 // which is done in ProcessOptions so for the time being we first Read and Parse the options then
1625 // we read the variables, and then we process the options
1626
1627 // now read all options
1628 GetLine(fin,buf);
1629 while (!TString(buf).BeginsWith("#OPT")) GetLine(fin,buf);
1630 ReadOptionsFromStream(fin);
1631 ParseOptions();
1632
1633 // Now read variable info
1634 fin.getline(buf,512);
1635 while (!TString(buf).BeginsWith("#VAR")) fin.getline(buf,512);
1636 ReadVarsFromStream(fin);
1637
1638 // now we process the options (of the derived class)
1639 ProcessOptions();
1640
1641 if (IsNormalised()) {
1643 GetTransformationHandler().AddTransformation( new VariableNormalizeTransform(DataInfo()), -1 );
1644 norm->BuildTransformationFromVarInfo( DataInfo().GetVariableInfos() );
1645 }
1646 VariableTransformBase *varTrafo(0), *varTrafo2(0);
1647 if ( fVarTransformString == "None") {
1648 if (fUseDecorr)
1649 varTrafo = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
1650 } else if ( fVarTransformString == "Decorrelate" ) {
1651 varTrafo = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
1652 } else if ( fVarTransformString == "PCA" ) {
1653 varTrafo = GetTransformationHandler().AddTransformation( new VariablePCATransform(DataInfo()), -1 );
1654 } else if ( fVarTransformString == "Uniform" ) {
1655 varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo(),"Uniform"), -1 );
1656 } else if ( fVarTransformString == "Gauss" ) {
1657 varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo()), -1 );
1658 } else if ( fVarTransformString == "GaussDecorr" ) {
1659 varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo()), -1 );
1660 varTrafo2 = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
1661 } else {
1662 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<ProcessOptions> Variable transform '"
1663 << fVarTransformString << "' unknown." << Endl;
1664 }
1665 // Now read decorrelation matrix if available
1666 if (GetTransformationHandler().GetTransformationList().GetSize() > 0) {
1667 fin.getline(buf,512);
1668 while (!TString(buf).BeginsWith("#MAT")) fin.getline(buf,512);
1669 if (varTrafo) {
1670 TString trafo(fVariableTransformTypeString); trafo.ToLower();
1671 varTrafo->ReadTransformationFromStream(fin, trafo );
1672 }
1673 if (varTrafo2) {
1674 TString trafo(fVariableTransformTypeString); trafo.ToLower();
1675 varTrafo2->ReadTransformationFromStream(fin, trafo );
1676 }
1677 }
1678
1679
1680 if (HasMVAPdfs()) {
1681 // Now read the MVA PDFs
1682 fin.getline(buf,512);
1683 while (!TString(buf).BeginsWith("#MVAPDFS")) fin.getline(buf,512);
1684 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
1685 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
1686 fMVAPdfS = new PDF(TString(GetName()) + " MVA PDF Sig");
1687 fMVAPdfB = new PDF(TString(GetName()) + " MVA PDF Bkg");
1688 fMVAPdfS->SetReadingVersion( GetTrainingTMVAVersionCode() );
1689 fMVAPdfB->SetReadingVersion( GetTrainingTMVAVersionCode() );
1690
1691 fin >> *fMVAPdfS;
1692 fin >> *fMVAPdfB;
1693 }
1694
1695 // Now read weights
1696 fin.getline(buf,512);
1697 while (!TString(buf).BeginsWith("#WGT")) fin.getline(buf,512);
1698 fin.getline(buf,512);
1699 ReadWeightsFromStream( fin );;
1700
1701 // update transformation handler
1702 if (GetTransformationHandler().GetCallerName() == "") GetTransformationHandler().SetCallerName( GetName() );
1703
1704}
1705
1706////////////////////////////////////////////////////////////////////////////////
1707/// write the list of variables (name, min, max) for a given data
1708/// transformation method to the stream
1709
1710void TMVA::MethodBase::WriteVarsToStream( std::ostream& o, const TString& prefix ) const
1711{
1712 o << prefix << "NVar " << DataInfo().GetNVariables() << std::endl;
1713 std::vector<VariableInfo>::const_iterator varIt = DataInfo().GetVariableInfos().begin();
1714 for (; varIt!=DataInfo().GetVariableInfos().end(); ++varIt) { o << prefix; varIt->WriteToStream(o); }
1715 o << prefix << "NSpec " << DataInfo().GetNSpectators() << std::endl;
1716 varIt = DataInfo().GetSpectatorInfos().begin();
1717 for (; varIt!=DataInfo().GetSpectatorInfos().end(); ++varIt) { o << prefix; varIt->WriteToStream(o); }
1718}
1719
1720////////////////////////////////////////////////////////////////////////////////
1721/// Read the variables (name, min, max) for a given data
1722/// transformation method from the stream. In the stream we only
1723/// expect the limits which will be set
1724
1726{
1727 TString dummy;
1728 UInt_t readNVar;
1729 istr >> dummy >> readNVar;
1730
1731 if (readNVar!=DataInfo().GetNVariables()) {
1732 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "You declared "<< DataInfo().GetNVariables() << " variables in the Reader"
1733 << " while there are " << readNVar << " variables declared in the file"
1734 << Endl;
1735 }
1736
1737 // we want to make sure all variables are read in the order they are defined
1738 VariableInfo varInfo;
1739 std::vector<VariableInfo>::iterator varIt = DataInfo().GetVariableInfos().begin();
1740 int varIdx = 0;
1741 for (; varIt!=DataInfo().GetVariableInfos().end(); ++varIt, ++varIdx) {
1742 varInfo.ReadFromStream(istr);
1743 if (varIt->GetExpression() == varInfo.GetExpression()) {
1744 varInfo.SetExternalLink((*varIt).GetExternalLink());
1745 (*varIt) = varInfo;
1746 }
1747 else {
1748 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "ERROR in <ReadVarsFromStream>" << Endl;
1749 Log() << kINFO << "The definition (or the order) of the variables found in the input file is" << Endl;
1750 Log() << kINFO << "is not the same as the one declared in the Reader (which is necessary for" << Endl;
1751 Log() << kINFO << "the correct working of the method):" << Endl;
1752 Log() << kINFO << " var #" << varIdx <<" declared in Reader: " << varIt->GetExpression() << Endl;
1753 Log() << kINFO << " var #" << varIdx <<" declared in file : " << varInfo.GetExpression() << Endl;
1754 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1755 }
1756 }
1757}
1758
1759////////////////////////////////////////////////////////////////////////////////
1760/// write variable info to XML
1761
1762void TMVA::MethodBase::AddVarsXMLTo( void* parent ) const
1763{
1764 void* vars = gTools().AddChild(parent, "Variables");
1765 gTools().AddAttr( vars, "NVar", gTools().StringFromInt(DataInfo().GetNVariables()) );
1766
1767 for (UInt_t idx=0; idx<DataInfo().GetVariableInfos().size(); idx++) {
1768 VariableInfo& vi = DataInfo().GetVariableInfos()[idx];
1769 void* var = gTools().AddChild( vars, "Variable" );
1770 gTools().AddAttr( var, "VarIndex", idx );
1771 vi.AddToXML( var );
1772 }
1773}
1774
1775////////////////////////////////////////////////////////////////////////////////
1776/// write spectator info to XML
1777
1779{
1780 void* specs = gTools().AddChild(parent, "Spectators");
1781
1782 UInt_t writeIdx=0;
1783 for (UInt_t idx=0; idx<DataInfo().GetSpectatorInfos().size(); idx++) {
1784
1785 VariableInfo& vi = DataInfo().GetSpectatorInfos()[idx];
1786
1787 // we do not want to write spectators that are category-cuts,
1788 // except if the method is the category method and the spectators belong to it
1789 if (vi.GetVarType()=='C') continue;
1790
1791 void* spec = gTools().AddChild( specs, "Spectator" );
1792 gTools().AddAttr( spec, "SpecIndex", writeIdx++ );
1793 vi.AddToXML( spec );
1794 }
1795 gTools().AddAttr( specs, "NSpec", gTools().StringFromInt(writeIdx) );
1796}
1797
1798////////////////////////////////////////////////////////////////////////////////
1799/// write class info to XML
1800
1801void TMVA::MethodBase::AddClassesXMLTo( void* parent ) const
1802{
1803 UInt_t nClasses=DataInfo().GetNClasses();
1804
1805 void* classes = gTools().AddChild(parent, "Classes");
1806 gTools().AddAttr( classes, "NClass", nClasses );
1807
1808 for (UInt_t iCls=0; iCls<nClasses; ++iCls) {
1809 ClassInfo *classInfo=DataInfo().GetClassInfo (iCls);
1810 TString className =classInfo->GetName();
1811 UInt_t classNumber=classInfo->GetNumber();
1812
1813 void* classNode=gTools().AddChild(classes, "Class");
1814 gTools().AddAttr( classNode, "Name", className );
1815 gTools().AddAttr( classNode, "Index", classNumber );
1816 }
1817}
1818////////////////////////////////////////////////////////////////////////////////
1819/// write target info to XML
1820
1821void TMVA::MethodBase::AddTargetsXMLTo( void* parent ) const
1822{
1823 void* targets = gTools().AddChild(parent, "Targets");
1824 gTools().AddAttr( targets, "NTrgt", gTools().StringFromInt(DataInfo().GetNTargets()) );
1825
1826 for (UInt_t idx=0; idx<DataInfo().GetTargetInfos().size(); idx++) {
1827 VariableInfo& vi = DataInfo().GetTargetInfos()[idx];
1828 void* tar = gTools().AddChild( targets, "Target" );
1829 gTools().AddAttr( tar, "TargetIndex", idx );
1830 vi.AddToXML( tar );
1831 }
1832}
1833
1834////////////////////////////////////////////////////////////////////////////////
1835/// read variable info from XML
1836
1838{
1839 UInt_t readNVar;
1840 gTools().ReadAttr( varnode, "NVar", readNVar);
1841
1842 if (readNVar!=DataInfo().GetNVariables()) {
1843 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "You declared "<< DataInfo().GetNVariables() << " variables in the Reader"
1844 << " while there are " << readNVar << " variables declared in the file"
1845 << Endl;
1846 }
1847
1848 // we want to make sure all variables are read in the order they are defined
1849 VariableInfo readVarInfo, existingVarInfo;
1850 int varIdx = 0;
1851 void* ch = gTools().GetChild(varnode);
1852 while (ch) {
1853 gTools().ReadAttr( ch, "VarIndex", varIdx);
1854 existingVarInfo = DataInfo().GetVariableInfos()[varIdx];
1855 readVarInfo.ReadFromXML(ch);
1856
1857 if (existingVarInfo.GetExpression() == readVarInfo.GetExpression()) {
1858 readVarInfo.SetExternalLink(existingVarInfo.GetExternalLink());
1859 existingVarInfo = readVarInfo;
1860 }
1861 else {
1862 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "ERROR in <ReadVariablesFromXML>" << Endl;
1863 Log() << kINFO << "The definition (or the order) of the variables found in the input file is" << Endl;
1864 Log() << kINFO << "not the same as the one declared in the Reader (which is necessary for the" << Endl;
1865 Log() << kINFO << "correct working of the method):" << Endl;
1866 Log() << kINFO << " var #" << varIdx <<" declared in Reader: " << existingVarInfo.GetExpression() << Endl;
1867 Log() << kINFO << " var #" << varIdx <<" declared in file : " << readVarInfo.GetExpression() << Endl;
1868 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1869 }
1870 ch = gTools().GetNextChild(ch);
1871 }
1872}
1873
1874////////////////////////////////////////////////////////////////////////////////
1875/// read spectator info from XML
1876
1878{
1879 UInt_t readNSpec;
1880 gTools().ReadAttr( specnode, "NSpec", readNSpec);
1881
1882 if (readNSpec!=DataInfo().GetNSpectators(kFALSE)) {
1883 Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName()) << "You declared "<< DataInfo().GetNSpectators(kFALSE) << " spectators in the Reader"
1884 << " while there are " << readNSpec << " spectators declared in the file"
1885 << Endl;
1886 }
1887
1888 // we want to make sure all variables are read in the order they are defined
1889 VariableInfo readSpecInfo, existingSpecInfo;
1890 int specIdx = 0;
1891 void* ch = gTools().GetChild(specnode);
1892 while (ch) {
1893 gTools().ReadAttr( ch, "SpecIndex", specIdx);
1894 existingSpecInfo = DataInfo().GetSpectatorInfos()[specIdx];
1895 readSpecInfo.ReadFromXML(ch);
1896
1897 if (existingSpecInfo.GetExpression() == readSpecInfo.GetExpression()) {
1898 readSpecInfo.SetExternalLink(existingSpecInfo.GetExternalLink());
1899 existingSpecInfo = readSpecInfo;
1900 }
1901 else {
1902 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "ERROR in <ReadSpectatorsFromXML>" << Endl;
1903 Log() << kINFO << "The definition (or the order) of the spectators found in the input file is" << Endl;
1904 Log() << kINFO << "not the same as the one declared in the Reader (which is necessary for the" << Endl;
1905 Log() << kINFO << "correct working of the method):" << Endl;
1906 Log() << kINFO << " spec #" << specIdx <<" declared in Reader: " << existingSpecInfo.GetExpression() << Endl;
1907 Log() << kINFO << " spec #" << specIdx <<" declared in file : " << readSpecInfo.GetExpression() << Endl;
1908 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1909 }
1910 ch = gTools().GetNextChild(ch);
1911 }
1912}
1913
1914////////////////////////////////////////////////////////////////////////////////
1915/// read the class definitions (number of classes and their names) from XML
1916
1918{
1919 UInt_t readNCls;
1920 // coverity[tainted_data_argument]
1921 gTools().ReadAttr( clsnode, "NClass", readNCls);
1922
1923 TString className="";
1924 UInt_t classIndex=0;
1925 void* ch = gTools().GetChild(clsnode);
1926 if (!ch) {
1927 for (UInt_t icls = 0; icls<readNCls;++icls) {
1928 TString classname = Form("class%i",icls);
1929 DataInfo().AddClass(classname);
1930
1931 }
1932 }
1933 else{
1934 while (ch) {
1935 gTools().ReadAttr( ch, "Index", classIndex);
1936 gTools().ReadAttr( ch, "Name", className );
1937 DataInfo().AddClass(className);
1938
1939 ch = gTools().GetNextChild(ch);
1940 }
1941 }
1942
1943 // retrieve signal and background class index
1944 if (DataInfo().GetClassInfo("Signal") != 0) {
1945 fSignalClass = DataInfo().GetClassInfo("Signal")->GetNumber();
1946 }
1947 else
1948 fSignalClass=0;
1949 if (DataInfo().GetClassInfo("Background") != 0) {
1950 fBackgroundClass = DataInfo().GetClassInfo("Background")->GetNumber();
1951 }
1952 else
1953 fBackgroundClass=1;
1954}
1955
1956////////////////////////////////////////////////////////////////////////////////
1957/// read target info from XML
1958
1960{
1961 UInt_t readNTar;
1962 gTools().ReadAttr( tarnode, "NTrgt", readNTar);
1963
1964 int tarIdx = 0;
1965 TString expression;
1966 void* ch = gTools().GetChild(tarnode);
1967 while (ch) {
1968 gTools().ReadAttr( ch, "TargetIndex", tarIdx);
1969 gTools().ReadAttr( ch, "Expression", expression);
1970 DataInfo().AddTarget(expression,"","",0,0);
1971
1972 ch = gTools().GetNextChild(ch);
1973 }
1974}
1975
1976////////////////////////////////////////////////////////////////////////////////
1977/// returns the ROOT directory where info/histograms etc of the
1978/// corresponding MVA method instance are stored
1979
1981{
1982 if (fBaseDir != 0) return fBaseDir;
1983 Log()<<kDEBUG<<Form("Dataset[%s] : ",DataInfo().GetName())<<" Base Directory for " << GetMethodName() << " not set yet --> check if already there.." <<Endl;
1984
1985 if (IsSilentFile()) {
1986 Log() << kFATAL << Form("Dataset[%s] : ", DataInfo().GetName())
1987 << "MethodBase::BaseDir() - No directory exists when running a Method without output file. Enable the "
1988 "output when creating the factory"
1989 << Endl;
1990 }
1991
1992 TDirectory* methodDir = MethodBaseDir();
1993 if (methodDir==0)
1994 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "MethodBase::BaseDir() - MethodBaseDir() return a NULL pointer!" << Endl;
1995
1996 TString defaultDir = GetMethodName();
1997 TDirectory *sdir = methodDir->GetDirectory(defaultDir.Data());
1998 if(!sdir)
1999 {
2000 Log()<<kDEBUG<<Form("Dataset[%s] : ",DataInfo().GetName())<<" Base Directory for " << GetMethodTypeName() << " does not exist yet--> created it" <<Endl;
2001 sdir = methodDir->mkdir(defaultDir);
2002 sdir->cd();
2003 // write weight file name into target file
2004 if (fModelPersistence) {
2005 TObjString wfilePath( gSystem->WorkingDirectory() );
2006 TObjString wfileName( GetWeightFileName() );
2007 wfilePath.Write( "TrainingPath" );
2008 wfileName.Write( "WeightFileName" );
2009 }
2010 }
2011
2012 Log()<<kDEBUG<<Form("Dataset[%s] : ",DataInfo().GetName())<<" Base Directory for " << GetMethodTypeName() << " existed, return it.." <<Endl;
2013 return sdir;
2014}
2015
2016////////////////////////////////////////////////////////////////////////////////
2017/// returns the ROOT directory where all instances of the
2018/// corresponding MVA method are stored
2019
2021{
2022 if (fMethodBaseDir != 0) {
2023 return fMethodBaseDir;
2024 }
2025
2026 const char *datasetName = DataInfo().GetName();
2027
2028 Log() << kDEBUG << Form("Dataset[%s] : ", datasetName) << " Base Directory for " << GetMethodTypeName()
2029 << " not set yet --> check if already there.." << Endl;
2030
2031 TDirectory *factoryBaseDir = GetFile();
2032 if (!factoryBaseDir) return nullptr;
2033 fMethodBaseDir = factoryBaseDir->GetDirectory(datasetName);
2034 if (!fMethodBaseDir) {
2035 fMethodBaseDir = factoryBaseDir->mkdir(datasetName, Form("Base directory for dataset %s", datasetName));
2036 if (!fMethodBaseDir) {
2037 Log() << kFATAL << "Can not create dir " << datasetName;
2038 }
2039 }
2040 TString methodTypeDir = Form("Method_%s", GetMethodTypeName().Data());
2041 fMethodBaseDir = fMethodBaseDir->GetDirectory(methodTypeDir.Data());
2042
2043 if (!fMethodBaseDir) {
2044 TDirectory *datasetDir = factoryBaseDir->GetDirectory(datasetName);
2045 TString methodTypeDirHelpStr = Form("Directory for all %s methods", GetMethodTypeName().Data());
2046 fMethodBaseDir = datasetDir->mkdir(methodTypeDir.Data(), methodTypeDirHelpStr);
2047 Log() << kDEBUG << Form("Dataset[%s] : ", datasetName) << " Base Directory for " << GetMethodName()
2048 << " does not exist yet--> created it" << Endl;
2049 }
2050
2051 Log() << kDEBUG << Form("Dataset[%s] : ", datasetName)
2052 << "Return from MethodBaseDir() after creating base directory " << Endl;
2053 return fMethodBaseDir;
2054}
2055
2056////////////////////////////////////////////////////////////////////////////////
2057/// set directory of weight file
2058
2060{
2061 fFileDir = fileDir;
2062 gSystem->mkdir( fFileDir, kTRUE );
2063}
2064
2065////////////////////////////////////////////////////////////////////////////////
2066/// set the weight file name (deprecated)
2067
2069{
2070 fWeightFile = theWeightFile;
2071}
2072
2073////////////////////////////////////////////////////////////////////////////////
2074/// retrieve weight file name
2075
2077{
2078 if (fWeightFile!="") return fWeightFile;
2079
2080 // the default consists of
2081 // directory/jobname_methodname_suffix.extension.{root/txt}
2082 TString suffix = "";
2083 TString wFileDir(GetWeightFileDir());
2084 TString wFileName = GetJobName() + "_" + GetMethodName() +
2085 suffix + "." + gConfig().GetIONames().fWeightFileExtension + ".xml";
2086 if (wFileDir.IsNull() ) return wFileName;
2087 // add weight file directory of it is not null
2088 return ( wFileDir + (wFileDir[wFileDir.Length()-1]=='/' ? "" : "/")
2089 + wFileName );
2090}
2091////////////////////////////////////////////////////////////////////////////////
2092/// writes all MVA evaluation histograms to file
2093
2095{
2096 BaseDir()->cd();
2097
2098
2099 // write MVA PDFs to file - if exist
2100 if (0 != fMVAPdfS) {
2101 fMVAPdfS->GetOriginalHist()->Write();
2102 fMVAPdfS->GetSmoothedHist()->Write();
2103 fMVAPdfS->GetPDFHist()->Write();
2104 }
2105 if (0 != fMVAPdfB) {
2106 fMVAPdfB->GetOriginalHist()->Write();
2107 fMVAPdfB->GetSmoothedHist()->Write();
2108 fMVAPdfB->GetPDFHist()->Write();
2109 }
2110
2111 // write result-histograms
2112 Results* results = Data()->GetResults( GetMethodName(), treetype, Types::kMaxAnalysisType );
2113 if (!results)
2114 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<WriteEvaluationHistosToFile> Unknown result: "
2115 << GetMethodName() << (treetype==Types::kTraining?"/kTraining":"/kTesting")
2116 << "/kMaxAnalysisType" << Endl;
2117 results->GetStorage()->Write();
2118 if (treetype==Types::kTesting) {
2119 // skipping plotting of variables if too many (default is 200)
2120 if ((int) DataInfo().GetNVariables()< gConfig().GetVariablePlotting().fMaxNumOfAllowedVariables)
2121 GetTransformationHandler().PlotVariables (GetEventCollection( Types::kTesting ), BaseDir() );
2122 else
2123 Log() << kINFO << TString::Format("Dataset[%s] : ",DataInfo().GetName())
2124 << " variable plots are not produces ! The number of variables is " << DataInfo().GetNVariables()
2125 << " , it is larger than " << gConfig().GetVariablePlotting().fMaxNumOfAllowedVariables << Endl;
2126 }
2127}
2128
2129////////////////////////////////////////////////////////////////////////////////
2130/// write special monitoring histograms to file
2131/// dummy implementation here -----------------
2132
2134{
2135}
2136
2137////////////////////////////////////////////////////////////////////////////////
2138/// reads one line from the input stream
2139/// checks for certain keywords and interprets
2140/// the line if keywords are found
2141
2142Bool_t TMVA::MethodBase::GetLine(std::istream& fin, char* buf )
2143{
2144 fin.getline(buf,512);
2145 TString line(buf);
2146 if (line.BeginsWith("TMVA Release")) {
2147 Ssiz_t start = line.First('[')+1;
2148 Ssiz_t length = line.Index("]",start)-start;
2149 TString code = line(start,length);
2150 std::stringstream s(code.Data());
2151 s >> fTMVATrainingVersion;
2152 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "MVA method was trained with TMVA Version: " << GetTrainingTMVAVersionString() << Endl;
2153 }
2154 if (line.BeginsWith("ROOT Release")) {
2155 Ssiz_t start = line.First('[')+1;
2156 Ssiz_t length = line.Index("]",start)-start;
2157 TString code = line(start,length);
2158 std::stringstream s(code.Data());
2159 s >> fROOTTrainingVersion;
2160 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "MVA method was trained with ROOT Version: " << GetTrainingROOTVersionString() << Endl;
2161 }
2162 if (line.BeginsWith("Analysis type")) {
2163 Ssiz_t start = line.First('[')+1;
2164 Ssiz_t length = line.Index("]",start)-start;
2165 TString code = line(start,length);
2166 std::stringstream s(code.Data());
2167 std::string analysisType;
2168 s >> analysisType;
2169 if (analysisType == "regression" || analysisType == "Regression") SetAnalysisType( Types::kRegression );
2170 else if (analysisType == "classification" || analysisType == "Classification") SetAnalysisType( Types::kClassification );
2171 else if (analysisType == "multiclass" || analysisType == "Multiclass") SetAnalysisType( Types::kMulticlass );
2172 else Log() << kFATAL << "Analysis type " << analysisType << " from weight-file not known!" << std::endl;
2173
2174 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Method was trained for "
2175 << (GetAnalysisType() == Types::kRegression ? "Regression" :
2176 (GetAnalysisType() == Types::kMulticlass ? "Multiclass" : "Classification")) << Endl;
2177 }
2178
2179 return true;
2180}
2181
2182////////////////////////////////////////////////////////////////////////////////
2183/// Create PDFs of the MVA output variables
2184
2186{
2187 Data()->SetCurrentType(Types::kTraining);
2188
2189 // the PDF's are stored as results ONLY if the corresponding "results" are booked,
2190 // otherwise they will be only used 'online'
2191 ResultsClassification * mvaRes = dynamic_cast<ResultsClassification*>
2192 ( Data()->GetResults(GetMethodName(), Types::kTraining, Types::kClassification) );
2193
2194 if (mvaRes==0 || mvaRes->GetSize()==0) {
2195 Log() << kERROR<<Form("Dataset[%s] : ",DataInfo().GetName())<< "<CreateMVAPdfs> No result of classifier testing available" << Endl;
2196 }
2197
2198 Double_t minVal = *std::min_element(mvaRes->GetValueVector()->begin(),mvaRes->GetValueVector()->end());
2199 Double_t maxVal = *std::max_element(mvaRes->GetValueVector()->begin(),mvaRes->GetValueVector()->end());
2200
2201 // create histograms that serve as basis to create the MVA Pdfs
2202 TH1* histMVAPdfS = new TH1D( GetMethodTypeName() + "_tr_S", GetMethodTypeName() + "_tr_S",
2203 fMVAPdfS->GetHistNBins( mvaRes->GetSize() ), minVal, maxVal );
2204 TH1* histMVAPdfB = new TH1D( GetMethodTypeName() + "_tr_B", GetMethodTypeName() + "_tr_B",
2205 fMVAPdfB->GetHistNBins( mvaRes->GetSize() ), minVal, maxVal );
2206
2207
2208 // compute sum of weights properly
2209 histMVAPdfS->Sumw2();
2210 histMVAPdfB->Sumw2();
2211
2212 // fill histograms
2213 for (UInt_t ievt=0; ievt<mvaRes->GetSize(); ievt++) {
2214 Double_t theVal = mvaRes->GetValueVector()->at(ievt);
2215 Double_t theWeight = Data()->GetEvent(ievt)->GetWeight();
2216
2217 if (DataInfo().IsSignal(Data()->GetEvent(ievt))) histMVAPdfS->Fill( theVal, theWeight );
2218 else histMVAPdfB->Fill( theVal, theWeight );
2219 }
2220
2221 gTools().NormHist( histMVAPdfS );
2222 gTools().NormHist( histMVAPdfB );
2223
2224 // momentary hack for ROOT problem
2225 if(!IsSilentFile())
2226 {
2227 histMVAPdfS->Write();
2228 histMVAPdfB->Write();
2229 }
2230 // create PDFs
2231 fMVAPdfS->BuildPDF ( histMVAPdfS );
2232 fMVAPdfB->BuildPDF ( histMVAPdfB );
2233 fMVAPdfS->ValidatePDF( histMVAPdfS );
2234 fMVAPdfB->ValidatePDF( histMVAPdfB );
2235
2236 if (DataInfo().GetNClasses() == 2) { // TODO: this is an ugly hack.. adapt this to new framework
2237 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())
2238 << Form( "<CreateMVAPdfs> Separation from histogram (PDF): %1.3f (%1.3f)",
2239 GetSeparation( histMVAPdfS, histMVAPdfB ), GetSeparation( fMVAPdfS, fMVAPdfB ) )
2240 << Endl;
2241 }
2242
2243 delete histMVAPdfS;
2244 delete histMVAPdfB;
2245}
2246
2248 // the simple one, automatically calculates the mvaVal and uses the
2249 // SAME sig/bkg ratio as given in the training sample (typically 50/50
2250 // .. (NormMode=EqualNumEvents) but can be different)
2251 if (!fMVAPdfS || !fMVAPdfB) {
2252 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName()) << "<GetProba> MVA PDFs for Signal and Background don't exist yet, we'll create them on demand" << Endl;
2253 CreateMVAPdfs();
2254 }
2255 Double_t sigFraction = DataInfo().GetTrainingSumSignalWeights() / (DataInfo().GetTrainingSumSignalWeights() + DataInfo().GetTrainingSumBackgrWeights() );
2256 Double_t mvaVal = GetMvaValue(ev);
2257
2258 return GetProba(mvaVal,sigFraction);
2259
2260}
2261////////////////////////////////////////////////////////////////////////////////
2262/// compute likelihood ratio
2263
2265{
2266 if (!fMVAPdfS || !fMVAPdfB) {
2267 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetProba> MVA PDFs for Signal and Background don't exist" << Endl;
2268 return -1.0;
2269 }
2270 Double_t p_s = fMVAPdfS->GetVal( mvaVal );
2271 Double_t p_b = fMVAPdfB->GetVal( mvaVal );
2272
2273 Double_t denom = p_s*ap_sig + p_b*(1 - ap_sig);
2274
2275 return (denom > 0) ? (p_s*ap_sig) / denom : -1;
2276}
2277
2278////////////////////////////////////////////////////////////////////////////////
2279/// compute rarity:
2280/// \f[
2281/// R(x) = \int_{-\infty}^{x} \mathrm{PDF}(x')\,dx'
2282/// \f]
2283/// where PDF(x) is the PDF of the classifier's signal or background distribution
2284
2286{
2287 if ((reftype == Types::kSignal && !fMVAPdfS) || (reftype == Types::kBackground && !fMVAPdfB)) {
2288 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetRarity> Required MVA PDF for Signal or Background does not exist: "
2289 << "select option \"CreateMVAPdfs\"" << Endl;
2290 return 0.0;
2291 }
2292
2293 PDF* thePdf = ((reftype == Types::kSignal) ? fMVAPdfS : fMVAPdfB);
2294
2295 return thePdf->GetIntegral( thePdf->GetXmin(), mvaVal );
2296}
2297
2298////////////////////////////////////////////////////////////////////////////////
2299/// fill background efficiency (resp. rejection) versus signal efficiency plots
2300/// returns signal efficiency at background efficiency indicated in theString
2301
2303{
2304 Data()->SetCurrentType(type);
2305 Results* results = Data()->GetResults( GetMethodName(), type, Types::kClassification );
2306 std::vector<Float_t>* mvaRes = dynamic_cast<ResultsClassification*>(results)->GetValueVector();
2307
2308 // parse input string for required background efficiency
2309 TList* list = gTools().ParseFormatLine( theString );
2310
2311 // sanity check
2312 Bool_t computeArea = kFALSE;
2313 if (!list || list->GetSize() < 2) computeArea = kTRUE; // the area is computed
2314 else if (list->GetSize() > 2) {
2315 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetEfficiency> Wrong number of arguments"
2316 << " in string: " << theString
2317 << " | required format, e.g., Efficiency:0.05, or empty string" << Endl;
2318 delete list;
2319 return -1;
2320 }
2321
2322 // sanity check
2323 if ( results->GetHist("MVA_S")->GetNbinsX() != results->GetHist("MVA_B")->GetNbinsX() ||
2324 results->GetHist("MVA_HIGHBIN_S")->GetNbinsX() != results->GetHist("MVA_HIGHBIN_B")->GetNbinsX() ) {
2325 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetEfficiency> Binning mismatch between signal and background histos" << Endl;
2326 delete list;
2327 return -1.0;
2328 }
2329
2330 // create histograms
2331
2332 // first, get efficiency histograms for signal and background
2333 TH1 * effhist = results->GetHist("MVA_HIGHBIN_S");
2334 Double_t xmin = effhist->GetXaxis()->GetXmin();
2335 Double_t xmax = effhist->GetXaxis()->GetXmax();
2336
2337 TTHREAD_TLS(Double_t) nevtS;
2338
2339 // first round ? --> create histograms
2340 if (results->DoesExist("MVA_EFF_S")==0) {
2341
2342 // for efficiency plot
2343 TH1* eff_s = new TH1D( GetTestvarName() + "_effS", GetTestvarName() + " (signal)", fNbinsH, xmin, xmax );
2344 TH1* eff_b = new TH1D( GetTestvarName() + "_effB", GetTestvarName() + " (background)", fNbinsH, xmin, xmax );
2345 results->Store(eff_s, "MVA_EFF_S");
2346 results->Store(eff_b, "MVA_EFF_B");
2347
2348 // sign if cut
2349 Int_t sign = (fCutOrientation == kPositive) ? +1 : -1;
2350
2351 // this method is unbinned
2352 nevtS = 0;
2353 for (UInt_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
2354
2355 // read the tree
2356 Bool_t isSignal = DataInfo().IsSignal(GetEvent(ievt));
2357 Float_t theWeight = GetEvent(ievt)->GetWeight();
2358 Float_t theVal = (*mvaRes)[ievt];
2359
2360 // select histogram depending on if sig or bgd
2361 TH1* theHist = isSignal ? eff_s : eff_b;
2362
2363 // count signal and background events in tree
2364 if (isSignal) nevtS+=theWeight;
2365
2366 TAxis* axis = theHist->GetXaxis();
2367 Int_t maxbin = Int_t((theVal - axis->GetXmin())/(axis->GetXmax() - axis->GetXmin())*fNbinsH) + 1;
2368 if (sign > 0 && maxbin > fNbinsH) continue; // can happen... event doesn't count
2369 if (sign < 0 && maxbin < 1 ) continue; // can happen... event doesn't count
2370 if (sign > 0 && maxbin < 1 ) maxbin = 1;
2371 if (sign < 0 && maxbin > fNbinsH) maxbin = fNbinsH;
2372
2373 if (sign > 0)
2374 for (Int_t ibin=1; ibin<=maxbin; ibin++) theHist->AddBinContent( ibin , theWeight);
2375 else if (sign < 0)
2376 for (Int_t ibin=maxbin+1; ibin<=fNbinsH; ibin++) theHist->AddBinContent( ibin , theWeight );
2377 else
2378 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetEfficiency> Mismatch in sign" << Endl;
2379 }
2380
2381 // renormalise maximum to <=1
2382 // eff_s->Scale( 1.0/TMath::Max(1.,eff_s->GetMaximum()) );
2383 // eff_b->Scale( 1.0/TMath::Max(1.,eff_b->GetMaximum()) );
2384
2387
2388 // background efficiency versus signal efficiency
2389 TH1* eff_BvsS = new TH1D( GetTestvarName() + "_effBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2390 results->Store(eff_BvsS, "MVA_EFF_BvsS");
2391 eff_BvsS->SetXTitle( "Signal eff" );
2392 eff_BvsS->SetYTitle( "Backgr eff" );
2393
2394 // background rejection (=1-eff.) versus signal efficiency
2395 TH1* rej_BvsS = new TH1D( GetTestvarName() + "_rejBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2396 results->Store(rej_BvsS);
2397 rej_BvsS->SetXTitle( "Signal eff" );
2398 rej_BvsS->SetYTitle( "Backgr rejection (1-eff)" );
2399
2400 // inverse background eff (1/eff.) versus signal efficiency
2401 TH1* inveff_BvsS = new TH1D( GetTestvarName() + "_invBeffvsSeff",
2402 GetTestvarName(), fNbins, 0, 1 );
2403 results->Store(inveff_BvsS);
2404 inveff_BvsS->SetXTitle( "Signal eff" );
2405 inveff_BvsS->SetYTitle( "Inverse backgr. eff (1/eff)" );
2406
2407 // use root finder
2408 // spline background efficiency plot
2409 // note that there is a bin shift when going from a TH1D object to a TGraph :-(
2411 fSplRefS = new TSpline1( "spline2_signal", new TGraph( eff_s ) );
2412 fSplRefB = new TSpline1( "spline2_background", new TGraph( eff_b ) );
2413
2414 // verify spline sanity
2415 gTools().CheckSplines( eff_s, fSplRefS );
2416 gTools().CheckSplines( eff_b, fSplRefB );
2417 }
2418
2419 // make the background-vs-signal efficiency plot
2420
2421 // create root finder
2422 RootFinder rootFinder( this, fXmin, fXmax );
2423
2424 Double_t effB = 0;
2425 fEffS = eff_s; // to be set for the root finder
2426 for (Int_t bini=1; bini<=fNbins; bini++) {
2427
2428 // find cut value corresponding to a given signal efficiency
2429 Double_t effS = eff_BvsS->GetBinCenter( bini );
2430 Double_t cut = rootFinder.Root( effS );
2431
2432 // retrieve background efficiency for given cut
2433 if (Use_Splines_for_Eff_) effB = fSplRefB->Eval( cut );
2434 else effB = eff_b->GetBinContent( eff_b->FindBin( cut ) );
2435
2436 // and fill histograms
2437 eff_BvsS->SetBinContent( bini, effB );
2438 rej_BvsS->SetBinContent( bini, 1.0-effB );
2440 inveff_BvsS->SetBinContent( bini, 1.0/effB );
2441 }
2442
2443 // create splines for histogram
2444 fSpleffBvsS = new TSpline1( "effBvsS", new TGraph( eff_BvsS ) );
2445
2446 // search for overlap point where, when cutting on it,
2447 // one would obtain: eff_S = rej_B = 1 - eff_B
2448 Double_t effS = 0., rejB, effS_ = 0., rejB_ = 0.;
2449 Int_t nbins_ = 5000;
2450 for (Int_t bini=1; bini<=nbins_; bini++) {
2451
2452 // get corresponding signal and background efficiencies
2453 effS = (bini - 0.5)/Float_t(nbins_);
2454 rejB = 1.0 - fSpleffBvsS->Eval( effS );
2455
2456 // find signal efficiency that corresponds to required background efficiency
2457 if ((effS - rejB)*(effS_ - rejB_) < 0) break;
2458 effS_ = effS;
2459 rejB_ = rejB;
2460 }
2461
2462 // find cut that corresponds to signal efficiency and update signal-like criterion
2463 Double_t cut = rootFinder.Root( 0.5*(effS + effS_) );
2464 SetSignalReferenceCut( cut );
2465 fEffS = 0;
2466 }
2467
2468 // must exist...
2469 if (0 == fSpleffBvsS) {
2470 delete list;
2471 return 0.0;
2472 }
2473
2474 // now find signal efficiency that corresponds to required background efficiency
2475 Double_t effS = 0, effB = 0, effS_ = 0, effB_ = 0;
2476 Int_t nbins_ = 1000;
2477
2478 if (computeArea) {
2479
2480 // compute area of rej-vs-eff plot
2481 Double_t integral = 0;
2482 for (Int_t bini=1; bini<=nbins_; bini++) {
2483
2484 // get corresponding signal and background efficiencies
2485 effS = (bini - 0.5)/Float_t(nbins_);
2486 effB = fSpleffBvsS->Eval( effS );
2487 integral += (1.0 - effB);
2488 }
2489 integral /= nbins_;
2490
2491 delete list;
2492 return integral;
2493 }
2494 else {
2495
2496 // that will be the value of the efficiency retured (does not affect
2497 // the efficiency-vs-bkg plot which is done anyway.
2498 Float_t effBref = atof( ((TObjString*)list->At(1))->GetString() );
2499
2500 // find precise efficiency value
2501 for (Int_t bini=1; bini<=nbins_; bini++) {
2502
2503 // get corresponding signal and background efficiencies
2504 effS = (bini - 0.5)/Float_t(nbins_);
2505 effB = fSpleffBvsS->Eval( effS );
2506
2507 // find signal efficiency that corresponds to required background efficiency
2508 if ((effB - effBref)*(effB_ - effBref) <= 0) break;
2509 effS_ = effS;
2510 effB_ = effB;
2511 }
2512
2513 // take mean between bin above and bin below
2514 effS = 0.5*(effS + effS_);
2515
2516 effSerr = 0;
2517 if (nevtS > 0) effSerr = TMath::Sqrt( effS*(1.0 - effS)/nevtS );
2518
2519 delete list;
2520 return effS;
2521 }
2522
2523 return -1;
2524}
2525
2526////////////////////////////////////////////////////////////////////////////////
2527
2529{
2530 Data()->SetCurrentType(Types::kTraining);
2531
2532 Results* results = Data()->GetResults(GetMethodName(), Types::kTesting, Types::kNoAnalysisType);
2533
2534 // fill background efficiency (resp. rejection) versus signal efficiency plots
2535 // returns signal efficiency at background efficiency indicated in theString
2536
2537 // parse input string for required background efficiency
2538 TList* list = gTools().ParseFormatLine( theString );
2539 // sanity check
2540
2541 if (list->GetSize() != 2) {
2542 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetTrainingEfficiency> Wrong number of arguments"
2543 << " in string: " << theString
2544 << " | required format, e.g., Efficiency:0.05" << Endl;
2545 delete list;
2546 return -1;
2547 }
2548 // that will be the value of the efficiency retured (does not affect
2549 // the efficiency-vs-bkg plot which is done anyway.
2550 Float_t effBref = atof( ((TObjString*)list->At(1))->GetString() );
2551
2552 delete list;
2553
2554 // sanity check
2555 if (results->GetHist("MVA_S")->GetNbinsX() != results->GetHist("MVA_B")->GetNbinsX() ||
2556 results->GetHist("MVA_HIGHBIN_S")->GetNbinsX() != results->GetHist("MVA_HIGHBIN_B")->GetNbinsX() ) {
2557 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetTrainingEfficiency> Binning mismatch between signal and background histos"
2558 << Endl;
2559 return -1.0;
2560 }
2561
2562 // create histogram
2563
2564 // first, get efficiency histograms for signal and background
2565 TH1 * effhist = results->GetHist("MVA_HIGHBIN_S");
2566 Double_t xmin = effhist->GetXaxis()->GetXmin();
2567 Double_t xmax = effhist->GetXaxis()->GetXmax();
2568
2569 // first round ? --> create and fill histograms
2570 if (results->DoesExist("MVA_TRAIN_S")==0) {
2571
2572 // classifier response distributions for test sample
2573 Double_t sxmax = fXmax+0.00001;
2574
2575 // MVA plots on the training sample (check for overtraining)
2576 TH1* mva_s_tr = new TH1D( GetTestvarName() + "_Train_S",GetTestvarName() + "_Train_S", fNbinsMVAoutput, fXmin, sxmax );
2577 TH1* mva_b_tr = new TH1D( GetTestvarName() + "_Train_B",GetTestvarName() + "_Train_B", fNbinsMVAoutput, fXmin, sxmax );
2578 results->Store(mva_s_tr, "MVA_TRAIN_S");
2579 results->Store(mva_b_tr, "MVA_TRAIN_B");
2580 mva_s_tr->Sumw2();
2581 mva_b_tr->Sumw2();
2582
2583 // Training efficiency plots
2584 TH1* mva_eff_tr_s = new TH1D( GetTestvarName() + "_trainingEffS", GetTestvarName() + " (signal)",
2585 fNbinsH, xmin, xmax );
2586 TH1* mva_eff_tr_b = new TH1D( GetTestvarName() + "_trainingEffB", GetTestvarName() + " (background)",
2587 fNbinsH, xmin, xmax );
2588 results->Store(mva_eff_tr_s, "MVA_TRAINEFF_S");
2589 results->Store(mva_eff_tr_b, "MVA_TRAINEFF_B");
2590
2591 // sign if cut
2592 Int_t sign = (fCutOrientation == kPositive) ? +1 : -1;
2593
2594 std::vector<Double_t> mvaValues = GetMvaValues(0,Data()->GetNEvents());
2595 assert( (Long64_t) mvaValues.size() == Data()->GetNEvents());
2596
2597 // this method is unbinned
2598 for (Int_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
2599
2600 Data()->SetCurrentEvent(ievt);
2601 const Event* ev = GetEvent();
2602
2603 Double_t theVal = mvaValues[ievt];
2604 Double_t theWeight = ev->GetWeight();
2605
2606 TH1* theEffHist = DataInfo().IsSignal(ev) ? mva_eff_tr_s : mva_eff_tr_b;
2607 TH1* theClsHist = DataInfo().IsSignal(ev) ? mva_s_tr : mva_b_tr;
2608
2609 theClsHist->Fill( theVal, theWeight );
2610
2611 TAxis* axis = theEffHist->GetXaxis();
2612 Int_t maxbin = Int_t((theVal - axis->GetXmin())/(axis->GetXmax() - axis->GetXmin())*fNbinsH) + 1;
2613 if (sign > 0 && maxbin > fNbinsH) continue; // can happen... event doesn't count
2614 if (sign < 0 && maxbin < 1 ) continue; // can happen... event doesn't count
2615 if (sign > 0 && maxbin < 1 ) maxbin = 1;
2616 if (sign < 0 && maxbin > fNbinsH) maxbin = fNbinsH;
2617
2618 if (sign > 0) for (Int_t ibin=1; ibin<=maxbin; ibin++) theEffHist->AddBinContent( ibin , theWeight );
2619 else for (Int_t ibin=maxbin+1; ibin<=fNbinsH; ibin++) theEffHist->AddBinContent( ibin , theWeight );
2620 }
2621
2622 // normalise output distributions
2623 // uncomment those (and several others if you want unnormalized output
2624 gTools().NormHist( mva_s_tr );
2625 gTools().NormHist( mva_b_tr );
2626
2627 // renormalise to maximum
2628 mva_eff_tr_s->Scale( 1.0/TMath::Max(std::numeric_limits<double>::epsilon(), mva_eff_tr_s->GetMaximum()) );
2629 mva_eff_tr_b->Scale( 1.0/TMath::Max(std::numeric_limits<double>::epsilon(), mva_eff_tr_b->GetMaximum()) );
2630
2631 // Training background efficiency versus signal efficiency
2632 TH1* eff_bvss = new TH1D( GetTestvarName() + "_trainingEffBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2633 // Training background rejection (=1-eff.) versus signal efficiency
2634 TH1* rej_bvss = new TH1D( GetTestvarName() + "_trainingRejBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2635 results->Store(eff_bvss, "EFF_BVSS_TR");
2636 results->Store(rej_bvss, "REJ_BVSS_TR");
2637
2638 // use root finder
2639 // spline background efficiency plot
2640 // note that there is a bin shift when going from a TH1D object to a TGraph :-(
2642 if (fSplTrainRefS) delete fSplTrainRefS;
2643 if (fSplTrainRefB) delete fSplTrainRefB;
2644 fSplTrainRefS = new TSpline1( "spline2_signal", new TGraph( mva_eff_tr_s ) );
2645 fSplTrainRefB = new TSpline1( "spline2_background", new TGraph( mva_eff_tr_b ) );
2646
2647 // verify spline sanity
2648 gTools().CheckSplines( mva_eff_tr_s, fSplTrainRefS );
2649 gTools().CheckSplines( mva_eff_tr_b, fSplTrainRefB );
2650 }
2651
2652 // make the background-vs-signal efficiency plot
2653
2654 // create root finder
2655 RootFinder rootFinder(this, fXmin, fXmax );
2656
2657 Double_t effB = 0;
2658 fEffS = results->GetHist("MVA_TRAINEFF_S");
2659 for (Int_t bini=1; bini<=fNbins; bini++) {
2660
2661 // find cut value corresponding to a given signal efficiency
2662 Double_t effS = eff_bvss->GetBinCenter( bini );
2663
2664 Double_t cut = rootFinder.Root( effS );
2665
2666 // retrieve background efficiency for given cut
2667 if (Use_Splines_for_Eff_) effB = fSplTrainRefB->Eval( cut );
2668 else effB = mva_eff_tr_b->GetBinContent( mva_eff_tr_b->FindBin( cut ) );
2669
2670 // and fill histograms
2671 eff_bvss->SetBinContent( bini, effB );
2672 rej_bvss->SetBinContent( bini, 1.0-effB );
2673 }
2674 fEffS = 0;
2675
2676 // create splines for histogram
2677 fSplTrainEffBvsS = new TSpline1( "effBvsS", new TGraph( eff_bvss ) );
2678 }
2679
2680 // must exist...
2681 if (0 == fSplTrainEffBvsS) return 0.0;
2682
2683 // now find signal efficiency that corresponds to required background efficiency
2684 Double_t effS = 0., effB, effS_ = 0., effB_ = 0.;
2685 Int_t nbins_ = 1000;
2686 for (Int_t bini=1; bini<=nbins_; bini++) {
2687
2688 // get corresponding signal and background efficiencies
2689 effS = (bini - 0.5)/Float_t(nbins_);
2690 effB = fSplTrainEffBvsS->Eval( effS );
2691
2692 // find signal efficiency that corresponds to required background efficiency
2693 if ((effB - effBref)*(effB_ - effBref) <= 0) break;
2694 effS_ = effS;
2695 effB_ = effB;
2696 }
2697
2698 return 0.5*(effS + effS_); // the mean between bin above and bin below
2699}
2700
2701////////////////////////////////////////////////////////////////////////////////
2702
2703std::vector<Float_t> TMVA::MethodBase::GetMulticlassEfficiency(std::vector<std::vector<Float_t> >& purity)
2704{
2705 Data()->SetCurrentType(Types::kTesting);
2706 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTesting, Types::kMulticlass));
2707 if (!resMulticlass) Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName())<< "unable to create pointer in GetMulticlassEfficiency, exiting."<<Endl;
2708
2709 purity.push_back(resMulticlass->GetAchievablePur());
2710 return resMulticlass->GetAchievableEff();
2711}
2712
2713////////////////////////////////////////////////////////////////////////////////
2714
2715std::vector<Float_t> TMVA::MethodBase::GetMulticlassTrainingEfficiency(std::vector<std::vector<Float_t> >& purity)
2716{
2717 Data()->SetCurrentType(Types::kTraining);
2718 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTraining, Types::kMulticlass));
2719 if (!resMulticlass) Log() << kFATAL<< "unable to create pointer in GetMulticlassTrainingEfficiency, exiting."<<Endl;
2720
2721 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Determine optimal multiclass cuts for training data..." << Endl;
2722 for (UInt_t icls = 0; icls<DataInfo().GetNClasses(); ++icls) {
2723 resMulticlass->GetBestMultiClassCuts(icls);
2724 }
2725
2726 purity.push_back(resMulticlass->GetAchievablePur());
2727 return resMulticlass->GetAchievableEff();
2728}
2729
2730////////////////////////////////////////////////////////////////////////////////
2731/// Construct a confusion matrix for a multiclass classifier. The confusion
/// matrix compares, in turn, each class against all other classes in a pair-wise
2733/// fashion. In rows with index \f$ k_r = 0 ... K \f$, \f$ k_r \f$ is
2734/// considered signal for the sake of comparison and for each column
2735/// \f$ k_c = 0 ... K \f$ the corresponding class is considered background.
2736///
2737/// Note that the diagonal elements will be returned as NaN since this will
2738/// compare a class against itself.
2739///
2740/// \see TMVA::ResultsMulticlass::GetConfusionMatrix
2741///
2742/// \param[in] effB The background efficiency for which to evaluate.
2743/// \param[in] type The data set on which to evaluate (training, testing ...).
2744///
2745/// \return A matrix containing signal efficiencies for the given background
2746/// efficiency. The diagonal elements are NaN since this measure is
2747/// meaningless (comparing a class against itself).
2748///
2749
2751{
2752 if (GetAnalysisType() != Types::kMulticlass) {
2753 Log() << kFATAL << "Cannot get confusion matrix for non-multiclass analysis." << std::endl;
2754 return TMatrixD(0, 0);
2755 }
2756
2757 Data()->SetCurrentType(type);
2758 ResultsMulticlass *resMulticlass =
2759 dynamic_cast<ResultsMulticlass *>(Data()->GetResults(GetMethodName(), type, Types::kMulticlass));
2760
2761 if (resMulticlass == nullptr) {
2762 Log() << kFATAL << Form("Dataset[%s] : ", DataInfo().GetName())
2763 << "unable to create pointer in GetMulticlassEfficiency, exiting." << Endl;
2764 return TMatrixD(0, 0);
2765 }
2766
2767 return resMulticlass->GetConfusionMatrix(effB);
2768}
2769
2770////////////////////////////////////////////////////////////////////////////////
2771/// compute significance of mean difference
2772/// \f[
/// significance = \frac{|<S> - <B>|}{\sqrt{RMS_{S}^{2} + RMS_{B}^{2}}}
2774/// \f]
2775
2777{
2778 Double_t rms = sqrt( fRmsS*fRmsS + fRmsB*fRmsB );
2779
2780 return (rms > 0) ? TMath::Abs(fMeanS - fMeanB)/rms : 0;
2781}
2782
2783////////////////////////////////////////////////////////////////////////////////
2784/// compute "separation" defined as
2785/// \f[
2786/// <s2> = \frac{1}{2} \int_{-\infty}^{+\infty} { \frac{(S(x) - B(x))^2}{(S(x) + B(x))} dx }
2787/// \f]
2788
2790{
2791 return gTools().GetSeparation( histoS, histoB );
2792}
2793
2794////////////////////////////////////////////////////////////////////////////////
2795/// compute "separation" defined as
2796/// \f[
2797/// <s2> = \frac{1}{2} \int_{-\infty}^{+\infty} { \frac{(S(x) - B(x))^2}{(S(x) + B(x))} dx }
2798/// \f]
2799
2801{
2802 // note, if zero pointers given, use internal pdf
2803 // sanity check first
2804 if ((!pdfS && pdfB) || (pdfS && !pdfB))
2805 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetSeparation> Mismatch in pdfs" << Endl;
2806 if (!pdfS) pdfS = fSplS;
2807 if (!pdfB) pdfB = fSplB;
2808
2809 if (!fSplS || !fSplB) {
2810 Log()<<kDEBUG<<Form("[%s] : ",DataInfo().GetName())<< "could not calculate the separation, distributions"
2811 << " fSplS or fSplB are not yet filled" << Endl;
2812 return 0;
2813 }else{
2814 return gTools().GetSeparation( *pdfS, *pdfB );
2815 }
2816}
2817
2818////////////////////////////////////////////////////////////////////////////////
/// calculate the area (integral) under the ROC curve as an
/// overall quality measure of the classification
2821
2823{
2824 // note, if zero pointers given, use internal pdf
2825 // sanity check first
2826 if ((!histS && histB) || (histS && !histB))
2827 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetROCIntegral(TH1D*, TH1D*)> Mismatch in hists" << Endl;
2828
2829 if (histS==0 || histB==0) return 0.;
2830
2831 TMVA::PDF *pdfS = new TMVA::PDF( " PDF Sig", histS, TMVA::PDF::kSpline3 );
2832 TMVA::PDF *pdfB = new TMVA::PDF( " PDF Bkg", histB, TMVA::PDF::kSpline3 );
2833
2834
2835 Double_t xmin = TMath::Min(pdfS->GetXmin(), pdfB->GetXmin());
2836 Double_t xmax = TMath::Max(pdfS->GetXmax(), pdfB->GetXmax());
2837
2838 Double_t integral = 0;
2839 UInt_t nsteps = 1000;
2840 Double_t step = (xmax-xmin)/Double_t(nsteps);
2841 Double_t cut = xmin;
2842 for (UInt_t i=0; i<nsteps; i++) {
2843 integral += (1-pdfB->GetIntegral(cut,xmax)) * pdfS->GetVal(cut);
2844 cut+=step;
2845 }
2846 delete pdfS;
2847 delete pdfB;
2848 return integral*step;
2849}
2850
2851
2852////////////////////////////////////////////////////////////////////////////////
/// calculate the area (integral) under the ROC curve as an
/// overall quality measure of the classification
2855
2857{
2858 // note, if zero pointers given, use internal pdf
2859 // sanity check first
2860 if ((!pdfS && pdfB) || (pdfS && !pdfB))
2861 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetSeparation> Mismatch in pdfs" << Endl;
2862 if (!pdfS) pdfS = fSplS;
2863 if (!pdfB) pdfB = fSplB;
2864
2865 if (pdfS==0 || pdfB==0) return 0.;
2866
2867 Double_t xmin = TMath::Min(pdfS->GetXmin(), pdfB->GetXmin());
2868 Double_t xmax = TMath::Max(pdfS->GetXmax(), pdfB->GetXmax());
2869
2870 Double_t integral = 0;
2871 UInt_t nsteps = 1000;
2872 Double_t step = (xmax-xmin)/Double_t(nsteps);
2873 Double_t cut = xmin;
2874 for (UInt_t i=0; i<nsteps; i++) {
2875 integral += (1-pdfB->GetIntegral(cut,xmax)) * pdfS->GetVal(cut);
2876 cut+=step;
2877 }
2878 return integral*step;
2879}
2880
2881////////////////////////////////////////////////////////////////////////////////
/// plot significance, \f$ \frac{S}{\sqrt{S + B}} \f$, curve for given number
2883/// of signal and background events; returns cut for maximum significance
2884/// also returned via reference is the maximum significance
2885
2887 Double_t BackgroundEvents,
2888 Double_t& max_significance_value ) const
2889{
2890 Results* results = Data()->GetResults( GetMethodName(), Types::kTesting, Types::kMaxAnalysisType );
2891
2892 Double_t max_significance(0);
2893 Double_t effS(0),effB(0),significance(0);
2894 TH1D *temp_histogram = new TH1D("temp", "temp", fNbinsH, fXmin, fXmax );
2895
2896 if (SignalEvents <= 0 || BackgroundEvents <= 0) {
2897 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetMaximumSignificance> "
2898 << "Number of signal or background events is <= 0 ==> abort"
2899 << Endl;
2900 }
2901
2902 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Using ratio SignalEvents/BackgroundEvents = "
2903 << SignalEvents/BackgroundEvents << Endl;
2904
2905 TH1* eff_s = results->GetHist("MVA_EFF_S");
2906 TH1* eff_b = results->GetHist("MVA_EFF_B");
2907
2908 if ( (eff_s==0) || (eff_b==0) ) {
2909 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Efficiency histograms empty !" << Endl;
2910 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "no maximum cut found, return 0" << Endl;
2911 return 0;
2912 }
2913
2914 for (Int_t bin=1; bin<=fNbinsH; bin++) {
2915 effS = eff_s->GetBinContent( bin );
2916 effB = eff_b->GetBinContent( bin );
2917
2918 // put significance into a histogram
2919 significance = sqrt(SignalEvents)*( effS )/sqrt( effS + ( BackgroundEvents / SignalEvents) * effB );
2920
2921 temp_histogram->SetBinContent(bin,significance);
2922 }
2923
2924 // find maximum in histogram
2925 max_significance = temp_histogram->GetBinCenter( temp_histogram->GetMaximumBin() );
2926 max_significance_value = temp_histogram->GetBinContent( temp_histogram->GetMaximumBin() );
2927
2928 // delete
2929 delete temp_histogram;
2930
2931 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Optimal cut at : " << max_significance << Endl;
2932 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName()) << "Maximum significance: " << max_significance_value << Endl;
2933
2934 return max_significance;
2935}
2936
2937////////////////////////////////////////////////////////////////////////////////
/// calculates the weighted mean and rms of the event variable theVarName,
/// separately for signal and background events of the given tree type, as
/// well as its overall minimum (xmin) and maximum (xmax); the values are read
/// through GetEvent()
2941
2943 Double_t& meanS, Double_t& meanB,
2944 Double_t& rmsS, Double_t& rmsB,
2946{
2947 Types::ETreeType previousTreeType = Data()->GetCurrentType();
2948 Data()->SetCurrentType(treeType);
2949
2950 Long64_t entries = Data()->GetNEvents();
2951
2952 // sanity check
2953 if (entries <=0)
2954 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<CalculateEstimator> Wrong tree type: " << treeType << Endl;
2955
2956 // index of the wanted variable
2957 UInt_t varIndex = DataInfo().FindVarIndex( theVarName );
2958
2959 // first fill signal and background in arrays before analysis
2960 xmin = +DBL_MAX;
2961 xmax = -DBL_MAX;
2962 Long64_t nEventsS = -1;
2963 Long64_t nEventsB = -1;
2964
2965 // take into account event weights
2966 meanS = 0;
2967 meanB = 0;
2968 rmsS = 0;
2969 rmsB = 0;
2970 Double_t sumwS = 0, sumwB = 0;
2971
2972 // loop over all training events
2973 for (Int_t ievt = 0; ievt < entries; ievt++) {
2974
2975 const Event* ev = GetEvent(ievt);
2976
2977 Double_t theVar = ev->GetValue(varIndex);
2978 Double_t weight = ev->GetWeight();
2979
2980 if (DataInfo().IsSignal(ev)) {
2981 sumwS += weight;
2982 meanS += weight*theVar;
2983 rmsS += weight*theVar*theVar;
2984 }
2985 else {
2986 sumwB += weight;
2987 meanB += weight*theVar;
2988 rmsB += weight*theVar*theVar;
2989 }
2990 xmin = TMath::Min( xmin, theVar );
2991 xmax = TMath::Max( xmax, theVar );
2992 }
2993 ++nEventsS;
2994 ++nEventsB;
2995
2996 meanS = meanS/sumwS;
2997 meanB = meanB/sumwB;
2998 rmsS = TMath::Sqrt( rmsS/sumwS - meanS*meanS );
2999 rmsB = TMath::Sqrt( rmsB/sumwB - meanB*meanB );
3000
3001 Data()->SetCurrentType(previousTreeType);
3002}
3003
3004////////////////////////////////////////////////////////////////////////////////
3005/// create reader class for method (classification only at present)
3006
3007void TMVA::MethodBase::MakeClass( const TString& theClassFileName ) const
3008{
3009 // the default consists of
3010 TString classFileName = "";
3011 if (theClassFileName == "")
3012 classFileName = GetWeightFileDir() + "/" + GetJobName() + "_" + GetMethodName() + ".class.C";
3013 else
3014 classFileName = theClassFileName;
3015
3016 TString className = TString("Read") + GetMethodName();
3017
3018 TString tfname( classFileName );
3019 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
3020 << "Creating standalone class: "
3021 << gTools().Color("lightblue") << classFileName << gTools().Color("reset") << Endl;
3022
3023 std::ofstream fout( classFileName );
3024 if (!fout.good()) { // file could not be opened --> Error
3025 Log() << kFATAL << "<MakeClass> Unable to open file: " << classFileName << Endl;
3026 }
3027
3028 // now create the class
3029 // preamble
3030 fout << "// Class: " << className << std::endl;
3031 fout << "// Automatically generated by MethodBase::MakeClass" << std::endl << "//" << std::endl;
3032
3033 // print general information and configuration state
3034 fout << std::endl;
3035 fout << "/* configuration options =====================================================" << std::endl << std::endl;
3036 WriteStateToStream( fout );
3037 fout << std::endl;
3038 fout << "============================================================================ */" << std::endl;
3039
3040 // generate the class
3041 fout << "" << std::endl;
3042 fout << "#include <array>" << std::endl;
3043 fout << "#include <vector>" << std::endl;
3044 fout << "#include <cmath>" << std::endl;
3045 fout << "#include <string>" << std::endl;
3046 fout << "#include <iostream>" << std::endl;
3047 fout << "" << std::endl;
3048 // now if the classifier needs to write some additional classes for its response implementation
3049 // this code goes here: (at least the header declarations need to come before the main class
3050 this->MakeClassSpecificHeader( fout, className );
3051
3052 fout << "#ifndef IClassifierReader__def" << std::endl;
3053 fout << "#define IClassifierReader__def" << std::endl;
3054 fout << std::endl;
3055 fout << "class IClassifierReader {" << std::endl;
3056 fout << std::endl;
3057 fout << " public:" << std::endl;
3058 fout << std::endl;
3059 fout << " // constructor" << std::endl;
3060 fout << " IClassifierReader() : fStatusIsClean( true ) {}" << std::endl;
3061 fout << " virtual ~IClassifierReader() {}" << std::endl;
3062 fout << std::endl;
3063 fout << " // return classifier response" << std::endl;
3064 if(GetAnalysisType() == Types::kMulticlass) {
3065 fout << " virtual std::vector<double> GetMulticlassValues( const std::vector<double>& inputValues ) const = 0;" << std::endl;
3066 } else {
3067 fout << " virtual double GetMvaValue( const std::vector<double>& inputValues ) const = 0;" << std::endl;
3068 }
3069 fout << std::endl;
3070 fout << " // returns classifier status" << std::endl;
3071 fout << " bool IsStatusClean() const { return fStatusIsClean; }" << std::endl;
3072 fout << std::endl;
3073 fout << " protected:" << std::endl;
3074 fout << std::endl;
3075 fout << " bool fStatusIsClean;" << std::endl;
3076 fout << "};" << std::endl;
3077 fout << std::endl;
3078 fout << "#endif" << std::endl;
3079 fout << std::endl;
3080 fout << "class " << className << " : public IClassifierReader {" << std::endl;
3081 fout << std::endl;
3082 fout << " public:" << std::endl;
3083 fout << std::endl;
3084 fout << " // constructor" << std::endl;
3085 fout << " " << className << "( std::vector<std::string>& theInputVars )" << std::endl;
3086 fout << " : IClassifierReader()," << std::endl;
3087 fout << " fClassName( \"" << className << "\" )," << std::endl;
3088 fout << " fNvars( " << GetNvar() << " )" << std::endl;
3089 fout << " {" << std::endl;
3090 fout << " // the training input variables" << std::endl;
3091 fout << " const char* inputVars[] = { ";
3092 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
3093 fout << "\"" << GetOriginalVarName(ivar) << "\"";
3094 if (ivar<GetNvar()-1) fout << ", ";
3095 }
3096 fout << " };" << std::endl;
3097 fout << std::endl;
3098 fout << " // sanity checks" << std::endl;
3099 fout << " if (theInputVars.size() <= 0) {" << std::endl;
3100 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": empty input vector\" << std::endl;" << std::endl;
3101 fout << " fStatusIsClean = false;" << std::endl;
3102 fout << " }" << std::endl;
3103 fout << std::endl;
3104 fout << " if (theInputVars.size() != fNvars) {" << std::endl;
3105 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": mismatch in number of input values: \"" << std::endl;
3106 fout << " << theInputVars.size() << \" != \" << fNvars << std::endl;" << std::endl;
3107 fout << " fStatusIsClean = false;" << std::endl;
3108 fout << " }" << std::endl;
3109 fout << std::endl;
3110 fout << " // validate input variables" << std::endl;
3111 fout << " for (size_t ivar = 0; ivar < theInputVars.size(); ivar++) {" << std::endl;
3112 fout << " if (theInputVars[ivar] != inputVars[ivar]) {" << std::endl;
3113 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": mismatch in input variable names\" << std::endl" << std::endl;
3114 fout << " << \" for variable [\" << ivar << \"]: \" << theInputVars[ivar].c_str() << \" != \" << inputVars[ivar] << std::endl;" << std::endl;
3115 fout << " fStatusIsClean = false;" << std::endl;
3116 fout << " }" << std::endl;
3117 fout << " }" << std::endl;
3118 fout << std::endl;
3119 fout << " // initialize min and max vectors (for normalisation)" << std::endl;
3120 for (UInt_t ivar = 0; ivar < GetNvar(); ivar++) {
3121 fout << " fVmin[" << ivar << "] = " << std::setprecision(15) << GetXmin( ivar ) << ";" << std::endl;
3122 fout << " fVmax[" << ivar << "] = " << std::setprecision(15) << GetXmax( ivar ) << ";" << std::endl;
3123 }
3124 fout << std::endl;
3125 fout << " // initialize input variable types" << std::endl;
3126 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
3127 fout << " fType[" << ivar << "] = \'" << DataInfo().GetVariableInfo(ivar).GetVarType() << "\';" << std::endl;
3128 }
3129 fout << std::endl;
3130 fout << " // initialize constants" << std::endl;
3131 fout << " Initialize();" << std::endl;
3132 fout << std::endl;
3133 if (GetTransformationHandler().GetTransformationList().GetSize() != 0) {
3134 fout << " // initialize transformation" << std::endl;
3135 fout << " InitTransform();" << std::endl;
3136 }
3137 fout << " }" << std::endl;
3138 fout << std::endl;
3139 fout << " // destructor" << std::endl;
3140 fout << " virtual ~" << className << "() {" << std::endl;
3141 fout << " Clear(); // method-specific" << std::endl;
3142 fout << " }" << std::endl;
3143 fout << std::endl;
3144 fout << " // the classifier response" << std::endl;
3145 fout << " // \"inputValues\" is a vector of input values in the same order as the" << std::endl;
3146 fout << " // variables given to the constructor" << std::endl;
3147 if(GetAnalysisType() == Types::kMulticlass) {
3148 fout << " std::vector<double> GetMulticlassValues( const std::vector<double>& inputValues ) const override;" << std::endl;
3149 } else {
3150 fout << " double GetMvaValue( const std::vector<double>& inputValues ) const override;" << std::endl;
3151 }
3152 fout << std::endl;
3153 fout << " private:" << std::endl;
3154 fout << std::endl;
3155 fout << " // method-specific destructor" << std::endl;
3156 fout << " void Clear();" << std::endl;
3157 fout << std::endl;
3158 if (GetTransformationHandler().GetTransformationList().GetSize()!=0) {
3159 fout << " // input variable transformation" << std::endl;
3160 GetTransformationHandler().MakeFunction(fout, className,1);
3161 fout << " void InitTransform();" << std::endl;
3162 fout << " void Transform( std::vector<double> & iv, int sigOrBgd ) const;" << std::endl;
3163 fout << std::endl;
3164 }
3165 fout << " // common member variables" << std::endl;
3166 fout << " const char* fClassName;" << std::endl;
3167 fout << std::endl;
3168 fout << " const size_t fNvars;" << std::endl;
3169 fout << " size_t GetNvar() const { return fNvars; }" << std::endl;
3170 fout << " char GetType( int ivar ) const { return fType[ivar]; }" << std::endl;
3171 fout << std::endl;
3172 fout << " // normalisation of input variables" << std::endl;
3173 fout << " double fVmin[" << GetNvar() << "];" << std::endl;
3174 fout << " double fVmax[" << GetNvar() << "];" << std::endl;
3175 fout << " double NormVariable( double x, double xmin, double xmax ) const {" << std::endl;
3176 fout << " // normalise to output range: [-1, 1]" << std::endl;
3177 fout << " return 2*(x - xmin)/(xmax - xmin) - 1.0;" << std::endl;
3178 fout << " }" << std::endl;
3179 fout << std::endl;
3180 fout << " // type of input variable: 'F' or 'I'" << std::endl;
3181 fout << " char fType[" << GetNvar() << "];" << std::endl;
3182 fout << std::endl;
3183 fout << " // initialize internal variables" << std::endl;
3184 fout << " void Initialize();" << std::endl;
3185 if(GetAnalysisType() == Types::kMulticlass) {
3186 fout << " std::vector<double> GetMulticlassValues__( const std::vector<double>& inputValues ) const;" << std::endl;
3187 } else {
3188 fout << " double GetMvaValue__( const std::vector<double>& inputValues ) const;" << std::endl;
3189 }
3190 fout << "" << std::endl;
3191 fout << " // private members (method specific)" << std::endl;
3192
3193 // call the classifier specific output (the classifier must close the class !)
3194 MakeClassSpecific( fout, className );
3195
3196 if(GetAnalysisType() == Types::kMulticlass) {
3197 fout << "inline std::vector<double> " << className << "::GetMulticlassValues( const std::vector<double>& inputValues ) const" << std::endl;
3198 } else {
3199 fout << "inline double " << className << "::GetMvaValue( const std::vector<double>& inputValues ) const" << std::endl;
3200 }
3201 fout << "{" << std::endl;
3202 fout << " // classifier response value" << std::endl;
3203 if(GetAnalysisType() == Types::kMulticlass) {
3204 fout << " std::vector<double> retval;" << std::endl;
3205 } else {
3206 fout << " double retval = 0;" << std::endl;
3207 }
3208 fout << std::endl;
3209 fout << " // classifier response, sanity check first" << std::endl;
3210 fout << " if (!IsStatusClean()) {" << std::endl;
3211 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": cannot return classifier response\"" << std::endl;
3212 fout << " << \" because status is dirty\" << std::endl;" << std::endl;
3213 fout << " }" << std::endl;
3214 fout << " else {" << std::endl;
3215 if (IsNormalised()) {
3216 fout << " // normalise variables" << std::endl;
3217 fout << " std::vector<double> iV;" << std::endl;
3218 fout << " iV.reserve(inputValues.size());" << std::endl;
3219 fout << " int ivar = 0;" << std::endl;
3220 fout << " for (std::vector<double>::const_iterator varIt = inputValues.begin();" << std::endl;
3221 fout << " varIt != inputValues.end(); varIt++, ivar++) {" << std::endl;
3222 fout << " iV.push_back(NormVariable( *varIt, fVmin[ivar], fVmax[ivar] ));" << std::endl;
3223 fout << " }" << std::endl;
3224 if (GetTransformationHandler().GetTransformationList().GetSize() != 0 && GetMethodType() != Types::kLikelihood &&
3225 GetMethodType() != Types::kHMatrix) {
3226 fout << " Transform( iV, -1 );" << std::endl;
3227 }
3228
3229 if(GetAnalysisType() == Types::kMulticlass) {
3230 fout << " retval = GetMulticlassValues__( iV );" << std::endl;
3231 } else {
3232 fout << " retval = GetMvaValue__( iV );" << std::endl;
3233 }
3234 } else {
3235 if (GetTransformationHandler().GetTransformationList().GetSize() != 0 && GetMethodType() != Types::kLikelihood &&
3236 GetMethodType() != Types::kHMatrix) {
3237 fout << " std::vector<double> iV(inputValues);" << std::endl;
3238 fout << " Transform( iV, -1 );" << std::endl;
3239 if(GetAnalysisType() == Types::kMulticlass) {
3240 fout << " retval = GetMulticlassValues__( iV );" << std::endl;
3241 } else {
3242 fout << " retval = GetMvaValue__( iV );" << std::endl;
3243 }
3244 } else {
3245 if(GetAnalysisType() == Types::kMulticlass) {
3246 fout << " retval = GetMulticlassValues__( inputValues );" << std::endl;
3247 } else {
3248 fout << " retval = GetMvaValue__( inputValues );" << std::endl;
3249 }
3250 }
3251 }
3252 fout << " }" << std::endl;
3253 fout << std::endl;
3254 fout << " return retval;" << std::endl;
3255 fout << "}" << std::endl;
3256
3257 // create output for transformation - if any
3258 if (GetTransformationHandler().GetTransformationList().GetSize()!=0)
3259 GetTransformationHandler().MakeFunction(fout, className,2);
3260
3261 // close the file
3262 fout.close();
3263}
3264
3265////////////////////////////////////////////////////////////////////////////////
3266/// prints out method-specific help method
3267
3269{
3270 // if options are written to reference file, also append help info
3271 std::streambuf* cout_sbuf = std::cout.rdbuf(); // save original sbuf
3272 std::ofstream* o = 0;
3273 if (gConfig().WriteOptionsReference()) {
3274 Log() << kINFO << "Print Help message for class " << GetName() << " into file: " << GetReferenceFile() << Endl;
3275 o = new std::ofstream( GetReferenceFile(), std::ios::app );
3276 if (!o->good()) { // file could not be opened --> Error
3277 Log() << kFATAL << "<PrintHelpMessage> Unable to append to output file: " << GetReferenceFile() << Endl;
3278 }
3279 std::cout.rdbuf( o->rdbuf() ); // redirect 'std::cout' to file
3280 }
3281
3282 // "|--------------------------------------------------------------|"
3283 if (!o) {
3284 Log() << kINFO << Endl;
3285 Log() << gTools().Color("bold")
3286 << "================================================================"
3287 << gTools().Color( "reset" )
3288 << Endl;
3289 Log() << gTools().Color("bold")
3290 << "H e l p f o r M V A m e t h o d [ " << GetName() << " ] :"
3291 << gTools().Color( "reset" )
3292 << Endl;
3293 }
3294 else {
3295 Log() << "Help for MVA method [ " << GetName() << " ] :" << Endl;
3296 }
3297
3298 // print method-specific help message
3299 GetHelpMessage();
3300
3301 if (!o) {
3302 Log() << Endl;
3303 Log() << "<Suppress this message by specifying \"!H\" in the booking option>" << Endl;
3304 Log() << gTools().Color("bold")
3305 << "================================================================"
3306 << gTools().Color( "reset" )
3307 << Endl;
3308 Log() << Endl;
3309 }
3310 else {
3311 // indicate END
3312 Log() << "# End of Message___" << Endl;
3313 }
3314
3315 std::cout.rdbuf( cout_sbuf ); // restore the original stream buffer
3316 if (o) o->close();
3317}
3318
3319// ----------------------- r o o t f i n d i n g ----------------------------
3320
3321////////////////////////////////////////////////////////////////////////////////
3322/// returns efficiency as function of cut
3323
3325{
3326 Double_t retval=0;
3327
3328 // retrieve the class object
3330 retval = fSplRefS->Eval( theCut );
3331 }
3332 else retval = fEffS->GetBinContent( fEffS->FindBin( theCut ) );
3333
3334 // caution: here we take some "forbidden" action to hide a problem:
3335 // in some cases, in particular for likelihood, the binned efficiency distributions
3336 // do not equal 1, at xmin, and 0 at xmax; of course, in principle we have the
3337 // unbinned information available in the trees, but the unbinned minimization is
3338 // too slow, and we don't need to do a precision measurement here. Hence, we force
3339 // this property.
3340 Double_t eps = 1.0e-5;
3341 if (theCut-fXmin < eps) retval = (GetCutOrientation() == kPositive) ? 1.0 : 0.0;
3342 else if (fXmax-theCut < eps) retval = (GetCutOrientation() == kPositive) ? 0.0 : 1.0;
3343
3344 return retval;
3345}
3346
3347////////////////////////////////////////////////////////////////////////////////
3348/// returns the event collection (i.e. the dataset) TRANSFORMED using the
3349/// classifiers specific Variable Transformation (e.g. Decorr or Decorr:Gauss:Decorr)
3350
3352{
3353 // if there's no variable transformation for this classifier, just hand back the
3354 // event collection of the data set
3355 if (GetTransformationHandler().GetTransformationList().GetEntries() <= 0) {
3356 return (Data()->GetEventCollection(type));
3357 }
3358
3359 // otherwise, transform ALL the events and hand back the vector of the pointers to the
3360 // transformed events. If the pointer is already != 0, i.e. the whole thing has been
3361 // done before, I don't need to do it again, but just "hand over" the pointer to those events.
3362 Int_t idx = Data()->TreeIndex(type); //index indicating Training,Testing,... events/datasets
3363 if (fEventCollections.at(idx) == 0) {
3364 fEventCollections.at(idx) = &(Data()->GetEventCollection(type));
3365 fEventCollections.at(idx) = GetTransformationHandler().CalcTransformations(*(fEventCollections.at(idx)),kTRUE);
3366 }
3367 return *(fEventCollections.at(idx));
3368}
3369
3370////////////////////////////////////////////////////////////////////////////////
3371/// calculates the TMVA version string from the training version code on the fly
3372
3374{
3375 UInt_t a = GetTrainingTMVAVersionCode() & 0xff0000; a>>=16;
3376 UInt_t b = GetTrainingTMVAVersionCode() & 0x00ff00; b>>=8;
3377 UInt_t c = GetTrainingTMVAVersionCode() & 0x0000ff;
3378
3379 return TString(Form("%i.%i.%i",a,b,c));
3380}
3381
3382////////////////////////////////////////////////////////////////////////////////
3383/// calculates the ROOT version string from the training version code on the fly
3384
3386{
3387 UInt_t a = GetTrainingROOTVersionCode() & 0xff0000; a>>=16;
3388 UInt_t b = GetTrainingROOTVersionCode() & 0x00ff00; b>>=8;
3389 UInt_t c = GetTrainingROOTVersionCode() & 0x0000ff;
3390
3391 return TString(Form("%i.%02i/%02i",a,b,c));
3392}
3393
3394////////////////////////////////////////////////////////////////////////////////
3395
3397 ResultsClassification* mvaRes = dynamic_cast<ResultsClassification*>
3398 ( Data()->GetResults(GetMethodName(),Types::kTesting, Types::kClassification) );
3399
3400 if (mvaRes != NULL) {
3401 TH1D *mva_s = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_S"));
3402 TH1D *mva_b = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_B"));
3403 TH1D *mva_s_tr = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_TRAIN_S"));
3404 TH1D *mva_b_tr = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_TRAIN_B"));
3405
3406 if ( !mva_s || !mva_b || !mva_s_tr || !mva_b_tr) return -1;
3407
3408 if (SorB == 's' || SorB == 'S')
3409 return mva_s->KolmogorovTest( mva_s_tr, opt.Data() );
3410 else
3411 return mva_b->KolmogorovTest( mva_b_tr, opt.Data() );
3412 }
3413 return -1;
3414}
const Bool_t Use_Splines_for_Eff_
Definition: MethodBase.cxx:130
const Int_t NBIN_HIST_HIGH
Definition: MethodBase.cxx:133
ROOT::R::TRInterface & r
Definition: Object.C:4
#define d(i)
Definition: RSha256.hxx:102
#define b(i)
Definition: RSha256.hxx:100
#define c(i)
Definition: RSha256.hxx:101
#define s1(x)
Definition: RSha256.hxx:91
#define ROOT_VERSION_CODE
Definition: RVersion.h:21
int Int_t
Definition: RtypesCore.h:45
int Ssiz_t
Definition: RtypesCore.h:67
char Char_t
Definition: RtypesCore.h:33
unsigned int UInt_t
Definition: RtypesCore.h:46
const Bool_t kFALSE
Definition: RtypesCore.h:101
bool Bool_t
Definition: RtypesCore.h:63
double Double_t
Definition: RtypesCore.h:59
long long Long64_t
Definition: RtypesCore.h:80
float Float_t
Definition: RtypesCore.h:57
const Bool_t kTRUE
Definition: RtypesCore.h:100
#define ClassImp(name)
Definition: Rtypes.h:364
char name[80]
Definition: TGX11.cxx:110
int type
Definition: TGX11.cxx:121
float xmin
Definition: THbookFile.cxx:95
float xmax
Definition: THbookFile.cxx:95
double sqrt(double)
TMatrixT< Double_t > TMatrixD
Definition: TMatrixDfwd.h:22
char * Form(const char *fmt,...)
R__EXTERN TSystem * gSystem
Definition: TSystem.h:559
#define TMVA_VERSION_CODE
Definition: Version.h:47
Class to manage histogram axis.
Definition: TAxis.h:30
Double_t GetXmax() const
Definition: TAxis.h:134
Double_t GetXmin() const
Definition: TAxis.h:133
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
Definition: TCollection.h:184
virtual Int_t Write(const char *name=0, Int_t option=0, Int_t bufsize=0)
Write all objects in this collection.
This class stores the date and time with a precision of one second in an unsigned 32 bit word (950130...
Definition: TDatime.h:37
const char * AsString() const
Return the date & time as a string (ctime() format).
Definition: TDatime.cxx:102
TObject * Get(const char *namecycle) override
Return pointer to object identified by namecycle.
Describe directory structure in memory.
Definition: TDirectory.h:45
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using apath.
Definition: TDirectory.cxx:407
virtual TDirectory * mkdir(const char *name, const char *title="", Bool_t returnExistingDirectory=kFALSE)
Create a sub-directory "a" or a hierarchy of sub-directories "a/b/c/...".
virtual Bool_t cd(const char *path=nullptr)
Change current directory to "this" directory.
Definition: TDirectory.cxx:504
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:54
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:4011
void Close(Option_t *option="") override
Close a file.
Definition: TFile.cxx:889
A TGraph is an object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
1-D histogram with a double per channel (see TH1 documentation)
Definition: TH1.h:618
1-D histogram with a float per channel (see TH1 documentation)
Definition: TH1.h:575
TH1 is the base class of all histogram classes in ROOT.
Definition: TH1.h:58
virtual Double_t GetBinCenter(Int_t bin) const
Return bin center for 1D histogram.
Definition: TH1.cxx:8984
virtual Int_t GetQuantiles(Int_t nprobSum, Double_t *q, const Double_t *probSum=0)
Compute Quantiles for this histogram Quantile x_q of a probability distribution Function F is defined...
Definition: TH1.cxx:4544
virtual void AddBinContent(Int_t bin)
Increment bin content by 1.
Definition: TH1.cxx:1258
virtual Double_t GetMean(Int_t axis=1) const
For axis = 1,2 or 3 returns the mean value of the histogram along X,Y or Z axis.
Definition: TH1.cxx:7423
virtual void SetXTitle(const char *title)
Definition: TH1.h:413
static void AddDirectory(Bool_t add=kTRUE)
Sets the flag controlling the automatic add of histograms in memory.
Definition: TH1.cxx:1283
TAxis * GetXaxis()
Get the behaviour adopted by the object about the statoverflows. See EStatOverflows for more informat...
Definition: TH1.h:320
virtual Double_t GetMaximum(Double_t maxval=FLT_MAX) const
Return maximum value smaller than maxval of bins in the range, unless the value has been overridden b...
Definition: TH1.cxx:8388
virtual Int_t GetNbinsX() const
Definition: TH1.h:296
virtual Int_t Fill(Double_t x)
Increment bin with abscissa X by 1.
Definition: TH1.cxx:3351
virtual void SetBinContent(Int_t bin, Double_t content)
Set bin content see convention for numbering bins in TH1::GetBin In case the bin number is greater th...
Definition: TH1.cxx:9065
virtual Int_t GetMaximumBin() const
Return location of bin with maximum value in the range.
Definition: TH1.cxx:8420
virtual Double_t GetBinContent(Int_t bin) const
Return content of bin number bin.
Definition: TH1.cxx:4994
virtual void SetYTitle(const char *title)
Definition: TH1.h:414
virtual void Scale(Double_t c1=1, Option_t *option="")
Multiply this histogram by a constant c1.
Definition: TH1.cxx:6566
virtual Int_t FindBin(Double_t x, Double_t y=0, Double_t z=0)
Return Global bin number corresponding to x,y,z.
Definition: TH1.cxx:3681
virtual Double_t KolmogorovTest(const TH1 *h2, Option_t *option="") const
Statistical test of compatibility in shape between this histogram and h2, using Kolmogorov test.
Definition: TH1.cxx:8063
virtual void Sumw2(Bool_t flag=kTRUE)
Create structure to store sum of squares of weights.
Definition: TH1.cxx:8863
static Bool_t AddDirectoryStatus()
Static function: cannot be inlined on Windows/NT.
Definition: TH1.cxx:751
2-D histogram with a float per channel (see TH1 documentation)
Definition: TH2.h:251
Int_t Fill(Double_t)
Invalid Fill method.
Definition: TH2.cxx:358
A doubly linked list.
Definition: TList.h:44
virtual TObject * At(Int_t idx) const
Returns the object at position idx. Returns 0 if idx is out of range.
Definition: TList.cxx:357
Class that contains all the information of a class.
Definition: ClassInfo.h:49
UInt_t GetNumber() const
Definition: ClassInfo.h:65
TString fWeightFileExtension
Definition: Config.h:125
VariablePlotting & GetVariablePlotting()
Definition: Config.h:97
class TMVA::Config::VariablePlotting fVariablePlotting
IONames & GetIONames()
Definition: Config.h:98
MsgLogger * fLogger
Definition: Configurable.h:128
Class that contains all the data information.
Definition: DataSetInfo.h:62
Class that contains all the data information.
Definition: DataSet.h:58
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:236
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not.
Definition: Event.cxx:381
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:391
Float_t GetTarget(UInt_t itgt) const
Definition: Event.h:102
static void SetIgnoreNegWeightsInTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:400
Interface for all concrete MVA method implementations.
Definition: IMethod.h:53
void Init(std::vector< TString > &graphTitles)
This function gets some title and it creates a TGraph for every title.
Definition: MethodBase.cxx:169
IPythonInteractive()
standard constructor
Definition: MethodBase.cxx:146
~IPythonInteractive()
standard destructor
Definition: MethodBase.cxx:154
void ClearGraphs()
This function sets the point number to 0 for all graphs.
Definition: MethodBase.cxx:193
void AddPoint(Double_t x, Double_t y1, Double_t y2)
This function is used only in 2 TGraph case, and it will add new data points to graphs.
Definition: MethodBase.cxx:207
Virtual base Class for all MVA method.
Definition: MethodBase.h:111
TDirectory * MethodBaseDir() const
returns the ROOT directory where all instances of the corresponding MVA method are stored
virtual Double_t GetKSTrainingVsTest(Char_t SorB, TString opt="X")
MethodBase(const TString &jobName, Types::EMVA methodType, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
standard constructor
Definition: MethodBase.cxx:237
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
void ReadClassesFromXML(void *clsnode)
read number of classes from XML
void SetWeightFileDir(TString fileDir)
set directory of weight file
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables,...
void DeclareBaseOptions()
define the options (their key words) that can be set in the option string here the options valid for ...
Definition: MethodBase.cxx:509
virtual void TestRegression(Double_t &bias, Double_t &biasT, Double_t &dev, Double_t &devT, Double_t &rms, Double_t &rmsT, Double_t &mInf, Double_t &mInfT, Double_t &corr, Types::ETreeType type)
calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
Definition: MethodBase.cxx:991
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
Definition: MethodBase.cxx:596
virtual Double_t GetSignificance() const
compute significance of mean difference
virtual Double_t GetProba(const Event *ev)
const char * GetName() const
Definition: MethodBase.h:334
virtual TMatrixD GetMulticlassConfusionMatrix(Double_t effB, Types::ETreeType type)
Construct a confusion matrix for a multiclass classifier.
void PrintHelpMessage() const
prints out method-specific help method
virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype)
writes all MVA evaluation histograms to file
virtual void TestMulticlass()
test multiclass classification
const std::vector< TMVA::Event * > & GetEventCollection(Types::ETreeType type)
returns the event collection (i.e.
virtual std::vector< Double_t > GetDataMvaValues(DataSet *data=nullptr, Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the given Data type
Definition: MethodBase.cxx:939
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:406
TDirectory * BaseDir() const
returns the ROOT directory where info/histograms etc of the corresponding MVA method instance are sto...
virtual std::vector< Float_t > GetMulticlassEfficiency(std::vector< std::vector< Float_t > > &purity)
void AddInfoItem(void *gi, const TString &name, const TString &value) const
xml writing
virtual void AddClassifierOutputProb(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:950
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
TString GetTrainingTMVAVersionString() const
calculates the TMVA version string from the training version code on the fly
void Statistics(Types::ETreeType treeType, const TString &theVarName, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &)
calculates rms,mean, xmin, xmax of the event variable this can be either done for the variables as th...
Bool_t GetLine(std::istream &fin, char *buf)
reads one line from the input stream checks for certain keywords and interprets the line if keywords ...
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:423
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
Definition: MethodBase.cxx:897
virtual Bool_t IsSignalLike()
uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation) for...
Definition: MethodBase.cxx:854
virtual ~MethodBase()
destructor
Definition: MethodBase.cxx:364
virtual Double_t GetMaximumSignificance(Double_t SignalEvents, Double_t BackgroundEvents, Double_t &optimal_significance_value) const
plot significance curve for given number of signal and background events; returns cut for maximum ...
virtual Double_t GetTrainingEfficiency(const TString &)
void SetWeightFileName(TString)
set the weight file name (deprecated)
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
TString GetWeightFileName() const
retrieve weight file name
virtual void TestClassification()
initialization
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
virtual void WriteMonitoringHistosToFile() const
write special monitoring histograms to file; dummy implementation here
virtual void AddRegressionOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:744
void InitBase()
default initialization called by all constructors
Definition: MethodBase.cxx:441
virtual void GetRegressionDeviation(UInt_t tgtNum, Types::ETreeType type, Double_t &stddev, Double_t &stddev90Percent) const
Definition: MethodBase.cxx:724
void ReadStateFromXMLString(const char *xmlstr)
for reading from memory
void CreateMVAPdfs()
Create PDFs of the MVA output variables.
TString GetTrainingROOTVersionString() const
calculates the ROOT version string from the training version code on the fly
virtual Double_t GetValueForRoot(Double_t)
returns efficiency as function of cut
void ReadStateFromFile()
Function to write options and weights to file.
void WriteVarsToStream(std::ostream &tf, const TString &prefix="") const
write the list of variables (name, min, max) for a given data transformation method to the stream
void ReadVarsFromStream(std::istream &istr)
Read the variables (name, min, max) for a given data transformation method from the stream.
void ReadSpectatorsFromXML(void *specnode)
read spectator info from XML
virtual Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)=0
void SetTestvarName(const TString &v="")
Definition: MethodBase.h:341
void ReadVariablesFromXML(void *varnode)
read variable info from XML
virtual std::map< TString, Double_t > OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA")
call the Optimizer with the set of parameters and ranges that are meant to be tuned.
Definition: MethodBase.cxx:623
virtual std::vector< Float_t > GetMulticlassTrainingEfficiency(std::vector< std::vector< Float_t > > &purity)
void WriteStateToStream(std::ostream &tf) const
general method used in writing the header of the weight files where the used variables,...
virtual Double_t GetRarity(Double_t mvaVal, Types::ESBType reftype=Types::kBackground) const
compute rarity:
virtual void SetTuneParameters(std::map< TString, Double_t > tuneParameters)
set the tuning parameters according to the argument. This is just a dummy.
Definition: MethodBase.cxx:644
void ReadStateFromStream(std::istream &tf)
read the header from the weight files of the different MVA methods
void AddVarsXMLTo(void *parent) const
write variable info to XML
void AddTargetsXMLTo(void *parent) const
write target info to XML
void ReadTargetsFromXML(void *tarnode)
read target info from XML
void ProcessBaseOptions()
the option string is decoded, for available options see "DeclareOptions"
Definition: MethodBase.cxx:540
void ReadStateFromXML(void *parent)
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:836
void WriteStateToFile() const
write options and weights to file note that each one text file for the main configuration information...
void AddClassesXMLTo(void *parent) const
write class info to XML
virtual void AddClassifierOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:868
void AddSpectatorsXMLTo(void *parent) const
write spectator info to XML
virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const
calculate the area (integral) under the ROC curve as a overall quality measure of the classification
virtual void AddMulticlassOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:794
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:433
void SetSource(const std::string &source)
Definition: MsgLogger.h:68
PDF wrapper for histograms; uses user-defined spline interpolation.
Definition: PDF.h:63
Double_t GetXmin() const
Definition: PDF.h:104
Double_t GetXmax() const
Definition: PDF.h:105
Double_t GetVal(Double_t x) const
returns value PDF(x)
Definition: PDF.cxx:701
@ kSpline3
Definition: PDF.h:70
@ kSpline2
Definition: PDF.h:70
Double_t GetIntegral(Double_t xmin, Double_t xmax)
computes PDF integral within given ranges
Definition: PDF.cxx:654
Class that is the base-class for a vector of result.
std::vector< Float_t > * GetValueVector()
void SetValue(Float_t value, Int_t ievt, Bool_t type)
set MVA response
Class which takes the results of a multiclass classification.
TMatrixD GetConfusionMatrix(Double_t effB)
Returns a confusion matrix where each class is pitted against each other.
Float_t GetAchievablePur(UInt_t cls)
std::vector< Double_t > GetBestMultiClassCuts(UInt_t targetClass)
calculate the best working point (optimal cut values) for the multiclass classifier
void CreateMulticlassHistos(TString prefix, Int_t nbins, Int_t nbins_high)
this function fills the mva response histos for multiclass classification
Float_t GetAchievableEff(UInt_t cls)
void CreateMulticlassPerformanceHistos(TString prefix)
Create performance graphs for this classifier a multiclass setting.
Class that is the base-class for a vector of result.
Class that is the base-class for a vector of result.
Definition: Results.h:57
Bool_t DoesExist(const TString &alias) const
Returns true if there is an object stored in the result for a given alias, false otherwise.
Definition: Results.cxx:127
TH1 * GetHist(const TString &alias) const
Definition: Results.cxx:136
TList * GetStorage() const
Definition: Results.h:73
void Store(TObject *obj, const char *alias=0)
Definition: Results.cxx:86
Root finding using Brents algorithm (translated from CERNLIB function RZERO)
Definition: RootFinder.h:48
Double_t Root(Double_t refValue)
Root finding using Brents algorithm; taken from CERNLIB function RZERO.
Definition: RootFinder.cxx:71
Linear interpolation of TGraph.
Definition: TSpline1.h:43
Timing information for training and evaluation of MVA methods.
Definition: Timer.h:58
Double_t ElapsedSeconds(void)
computes elapsed time in seconds
Definition: Timer.cxx:137
TString GetElapsedTime(Bool_t Scientific=kTRUE)
returns pretty string with elapsed time
Definition: Timer.cxx:146
void DrawProgressBar(Int_t, const TString &comment="")
draws progress bar in color or B&W caution:
Definition: Timer.cxx:202
void ComputeStat(const std::vector< TMVA::Event * > &, std::vector< Float_t > *, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Int_t signalClass, Bool_t norm=kFALSE)
sanity check
Definition: Tools.cxx:202
TList * ParseFormatLine(TString theString, const char *sep=":")
Parse the string and cut into labels separated by ":".
Definition: Tools.cxx:401
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1162
Double_t GetSeparation(TH1 *S, TH1 *B) const
compute "separation" defined as
Definition: Tools.cxx:121
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1124
Double_t GetMutualInformation(const TH2F &)
Mutual Information method for non-linear correlations estimates in 2D histogram Author: Moritz Backes...
Definition: Tools.cxx:589
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:828
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1150
TXMLEngine & xmlengine()
Definition: Tools.h:262
Bool_t CheckSplines(const TH1 *, const TSpline *)
check quality of splining by comparing splines and histograms in each bin
Definition: Tools.cxx:479
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:329
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:347
Double_t NormHist(TH1 *theHist, Double_t norm=1.0)
normalises histogram
Definition: Tools.cxx:383
Singleton class for Global types used by TMVA.
Definition: Types.h:71
@ kSignal
Definition: Types.h:135
@ kBackground
Definition: Types.h:136
@ kLikelihood
Definition: Types.h:79
@ kHMatrix
Definition: Types.h:81
EAnalysisType
Definition: Types.h:126
@ kMulticlass
Definition: Types.h:129
@ kNoAnalysisType
Definition: Types.h:130
@ kClassification
Definition: Types.h:127
@ kMaxAnalysisType
Definition: Types.h:131
@ kRegression
Definition: Types.h:128
@ kTraining
Definition: Types.h:143
@ kTesting
Definition: Types.h:144
Linear interpolation class.
Gaussian Transformation of input variables.
Class for type info of MVA input variable.
Definition: VariableInfo.h:47
void ReadFromXML(void *varnode)
read VariableInfo from stream
const TString & GetExpression() const
Definition: VariableInfo.h:57
char GetVarType() const
Definition: VariableInfo.h:61
void ReadFromStream(std::istream &istr)
read VariableInfo from stream
void AddToXML(void *varnode)
write class to XML
void SetExternalLink(void *p)
Definition: VariableInfo.h:75
void * GetExternalLink() const
Definition: VariableInfo.h:83
void BuildTransformationFromVarInfo(const std::vector< TMVA::VariableInfo > &var)
this method is only used when building a normalization transformation from old text files in this cas...
Linear interpolation class.
Linear interpolation class.
virtual void ReadTransformationFromStream(std::istream &istr, const TString &classname="")=0
@ kDEBUG
Definition: Types.h:56
@ kVERBOSE
Definition: Types.h:57
@ kHEADER
Definition: Types.h:63
@ kERROR
Definition: Types.h:60
@ kINFO
Definition: Types.h:58
@ kWARNING
Definition: Types.h:59
@ kFATAL
Definition: Types.h:61
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:36
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
Collectable string class.
Definition: TObjString.h:28
virtual Int_t Write(const char *name=0, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
Definition: TObject.cxx:798
Basic string class.
Definition: TString.h:136
Ssiz_t Length() const
Definition: TString.h:410
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1150
Int_t Atoi() const
Return integer value of string.
Definition: TString.cxx:1946
Bool_t EndsWith(const char *pat, ECaseCompare cmp=kExact) const
Return true if string ends with the specified string.
Definition: TString.cxx:2202
TSubString Strip(EStripType s=kTrailing, char c=' ') const
Return a substring of self stripped at beginning and/or end.
Definition: TString.cxx:1131
const char * Data() const
Definition: TString.h:369
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:692
@ kLeading
Definition: TString.h:267
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition: TString.cxx:916
Bool_t IsNull() const
Definition: TString.h:407
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition: TString.cxx:2336
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition: TString.h:639
virtual const char * GetBuildNode() const
Return the build node name.
Definition: TSystem.cxx:3894
virtual int mkdir(const char *name, Bool_t recursive=kFALSE)
Make a file system directory.
Definition: TSystem.cxx:907
virtual const char * WorkingDirectory()
Return working directory.
Definition: TSystem.cxx:872
virtual UserGroup_t * GetUserInfo(Int_t uid)
Returns all user info in the UserGroup_t structure.
Definition: TSystem.cxx:1599
void SaveDoc(XMLDocPointer_t xmldoc, const char *filename, Int_t layout=1)
store document content to file if layout<=0, no any spaces or newlines will be placed between xmlnode...
void FreeDoc(XMLDocPointer_t xmldoc)
frees allocated document data and deletes document itself
XMLNodePointer_t DocGetRootElement(XMLDocPointer_t xmldoc)
returns root node of document
XMLDocPointer_t NewDoc(const char *version="1.0")
creates new xml document with provided version
XMLDocPointer_t ParseFile(const char *filename, Int_t maxbuf=100000)
Parses content of file and tries to produce xml structures.
XMLDocPointer_t ParseString(const char *xmlstring)
parses content of string and tries to produce xml structures
void DocSetRootElement(XMLDocPointer_t xmldoc, XMLNodePointer_t xmlnode)
set main (root) node for document
TLine * line
Double_t x[n]
Definition: legend1.C:17
TH1F * h1
Definition: legend1.C:5
bool BeginsWith(const std::string &theString, const std::string &theSubstring)
void Init(TClassEdit::TInterpreterLookupHelper *helper)
Definition: TClassEdit.cxx:154
static constexpr double s
static constexpr double m2
void GetMethodName(TString &name, TKey *mkey)
Definition: tmvaglob.cxx:342
Config & gConfig()
Tools & gTools()
void CreateVariableTransforms(const TString &trafoDefinition, TMVA::DataSetInfo &dataInfo, TMVA::TransformationHandler &transformationHandler, TMVA::MsgLogger &log)
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:148
Short_t Max(Short_t a, Short_t b)
Definition: TMathBase.h:208
Double_t Log(Double_t x)
Definition: TMath.h:760
Double_t Sqrt(Double_t x)
Definition: TMath.h:691
Short_t Min(Short_t a, Short_t b)
Definition: TMathBase.h:176
Short_t Abs(Short_t d)
Definition: TMathBase.h:120
TString fUser
Definition: TSystem.h:141
auto * a
Definition: textangle.C:12
REAL epsilon
Definition: triangle.c:618