Logo ROOT   6.18/05
Reference Guide
MethodBase.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodBase *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation (see header for description) *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
16 * Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland *
17 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
18 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
19 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21 * *
22 * Copyright (c) 2005-2011: *
23 * CERN, Switzerland *
24 * U. of Victoria, Canada *
25 * MPI-K Heidelberg, Germany *
26 * U. of Bonn, Germany *
27 * *
28 * Redistribution and use in source and binary forms, with or without *
29 * modification, are permitted according to the terms listed in LICENSE *
30 * (http://tmva.sourceforge.net/LICENSE) *
31 * *
32 **********************************************************************************/
33
34/*! \class TMVA::MethodBase
35\ingroup TMVA
36
37 Virtual base Class for all MVA methods
38
39 MethodBase hosts several specific evaluation methods.
40
41 The kind of MVA that provides optimal performance in an analysis strongly
42 depends on the particular application. The evaluation factory provides a
43 number of numerical benchmark results to directly assess the performance
44 of the MVA training on the independent test sample. These are:
45
46 - The _signal efficiency_ at three representative background efficiencies
47 (which is 1 − rejection).
48 - The _significance_ of an MVA estimator, defined by the difference
49 between the MVA mean values for signal and background, divided by the
50 quadratic sum of their root mean squares.
51 - The _separation_ of an MVA _x_, defined by the integral
52 \f[
53 \frac{1}{2} \int \frac{(S(x) - B(x))^2}{(S(x) + B(x))} dx
54 \f]
55 where
56 \f$ S(x) \f$ and \f$ B(x) \f$ are the signal and background distributions,
57 respectively. The separation is zero for identical signal and background MVA
58 shapes, and it is one for disjunctive shapes.
59 - The average, \f$ \int x \mu (S(x)) dx \f$, of the signal \f$ \mu_{transform} \f$.
60 The \f$ \mu_{transform} \f$ of an MVA denotes the transformation that yields
61 a uniform background distribution. In this way, the signal distributions
62 \f$ S(x) \f$ can be directly compared among the various MVAs. The stronger
63 \f$ S(x) \f$ peaks towards one, the better is the discrimination of the MVA.
64 The \f$ \mu_{transform} \f$ is
65 [documented here](http://tel.ccsd.cnrs.fr/documents/archives0/00/00/29/91/index_fr.html).
66
67 The MVA standard output also prints the linear correlation coefficients between
68 signal and background, which can be useful to eliminate variables that exhibit too
69 strong correlations.
70*/
71
72#include "TMVA/MethodBase.h"
73
74#include "TMVA/Config.h"
75#include "TMVA/Configurable.h"
76#include "TMVA/DataSetInfo.h"
77#include "TMVA/DataSet.h"
78#include "TMVA/Factory.h"
79#include "TMVA/IMethod.h"
80#include "TMVA/MsgLogger.h"
81#include "TMVA/PDF.h"
82#include "TMVA/Ranking.h"
83#include "TMVA/Factory.h"
84#include "TMVA/DataLoader.h"
85#include "TMVA/Tools.h"
86#include "TMVA/Results.h"
90#include "TMVA/RootFinder.h"
91#include "TMVA/Timer.h"
92#include "TMVA/Tools.h"
93#include "TMVA/TSpline1.h"
94#include "TMVA/Types.h"
98#include "TMVA/VariableInfo.h"
102#include "TMVA/Version.h"
103
104#include "TROOT.h"
105#include "TSystem.h"
106#include "TObjString.h"
107#include "TQObject.h"
108#include "TSpline.h"
109#include "TMatrix.h"
110#include "TMath.h"
111#include "TH1F.h"
112#include "TH2F.h"
113#include "TFile.h"
114#include "TKey.h"
115#include "TGraph.h"
116#include "Riostream.h"
117#include "TXMLEngine.h"
118
119#include <iomanip>
120#include <iostream>
121#include <fstream>
122#include <sstream>
123#include <cstdlib>
124#include <algorithm>
125#include <limits>
126
127
129
130using std::endl;
131using std::atof;
132
133//const Int_t MethodBase_MaxIterations_ = 200;
135
136//const Int_t NBIN_HIST_PLOT = 100;
137const Int_t NBIN_HIST_HIGH = 10000;
138
139#ifdef _WIN32
140/* Disable warning C4355: 'this' : used in base member initializer list */
141#pragma warning ( disable : 4355 )
142#endif
143
144
145#include "TGraph.h"
146#include "TMultiGraph.h"
147
148////////////////////////////////////////////////////////////////////////////////
149/// standard constructor
150
// Start with no graphs registered; the next point will be written at index 0.
152{
153 fNumGraphs = 0;
154 fIndex = 0;
155}
156
157////////////////////////////////////////////////////////////////////////////////
158/// standard destructor
160{
// Free the owned multigraph (if any) and null the pointer so a stale
// handle cannot be dereferenced afterwards.
161 if (fMultiGraph){
162 delete fMultiGraph;
163 fMultiGraph = nullptr;
164 }
165 return;
166}
167
168////////////////////////////////////////////////////////////////////////////////
169/// This function gets some title and it creates a TGraph for every title.
170/// It also sets up the style for every TGraph. All graphs are added to a single TMultiGraph.
171///
172/// \param[in] graphTitles vector of titles
173
174void TMVA::IPythonInteractive::Init(std::vector<TString>& graphTitles)
175{
176 if (fNumGraphs!=0){
177 std::cerr << kERROR << "IPythonInteractive::Init: already initialized..." << std::endl;
178 return;
179 }
180 Int_t color = 2;
181 for(auto& title : graphTitles){
182 fGraphs.push_back( new TGraph() );
183 fGraphs.back()->SetTitle(title);
184 fGraphs.back()->SetName(title);
185 fGraphs.back()->SetFillColor(color);
186 fGraphs.back()->SetLineColor(color);
187 fGraphs.back()->SetMarkerColor(color);
188 fMultiGraph->Add(fGraphs.back());
189 color += 2;
190 fNumGraphs += 1;
191 }
192 return;
193}
194
195////////////////////////////////////////////////////////////////////////////////
196/// This function sets the point number to 0 for all graphs.
197
// Truncate every registered graph to zero points; the TGraph objects
// themselves are kept and can be refilled.
200 for(Int_t i=0; i<fNumGraphs; i++){
201 fGraphs[i]->Set(0);
202 }
203}
204
205////////////////////////////////////////////////////////////////////////////////
206/// This function is used only in 2 TGraph case, and it will add new data points to graphs.
207///
208/// \param[in] x the x coordinate
209/// \param[in] y1 the y coordinate for the first TGraph
210/// \param[in] y2 the y coordinate for the second TGraph
211
213{
// Grow both graphs by one slot, then write the shared abscissa with the
// per-graph ordinates at the current index. Assumes exactly two graphs
// were registered via Init (see the function documentation above).
214 fGraphs[0]->Set(fIndex+1);
215 fGraphs[1]->Set(fIndex+1);
216 fGraphs[0]->SetPoint(fIndex, x, y1);
217 fGraphs[1]->SetPoint(fIndex, x, y2);
218 fIndex++;
219 return;
220}
221
222////////////////////////////////////////////////////////////////////////////////
223/// This function can add data points to as many TGraphs as we have.
224///
225/// \param[in] dat vector of data points. The dat[0] contains the x coordinate,
226/// dat[1] contains the y coordinate for first TGraph, dat[2] for second, ...
227
228void TMVA::IPythonInteractive::AddPoint(std::vector<Double_t>& dat)
229{
230 for(Int_t i=0; i<fNumGraphs;i++){
231 fGraphs[i]->Set(fIndex+1);
232 fGraphs[i]->SetPoint(fIndex, dat[0], dat[i+1]);
233 }
234 fIndex++;
235 return;
236}
237
238
239////////////////////////////////////////////////////////////////////////////////
240/// standard constructor
241
// NOTE(review): the first line of this constructor's signature (the jobName
// parameter) is not visible in this listing.
243 Types::EMVA methodType,
244 const TString& methodTitle,
245 DataSetInfo& dsi,
246 const TString& theOption) :
247 IMethod(),
248 Configurable ( theOption ),
// transient evaluation state and result containers start out empty
249 fTmpEvent ( 0 ),
250 fRanking ( 0 ),
251 fInputVars ( 0 ),
252 fAnalysisType ( Types::kNoAnalysisType ),
253 fRegressionReturnVal ( 0 ),
254 fMulticlassReturnVal ( 0 ),
255 fDataSetInfo ( dsi ),
// default decision boundary on the MVA output and its orientation (+1)
256 fSignalReferenceCut ( 0.5 ),
257 fSignalReferenceCutOrientation( 1. ),
258 fVariableTransformType ( Types::kSignal ),
259 fJobName ( jobName ),
260 fMethodName ( methodTitle ),
261 fMethodType ( methodType ),
262 fTestvar ( "" ),
// record the TMVA/ROOT versions this instance is being trained with
263 fTMVATrainingVersion ( TMVA_VERSION_CODE ),
264 fROOTTrainingVersion ( ROOT_VERSION_CODE ),
265 fConstructedFromWeightFile ( kFALSE ),
266 fBaseDir ( 0 ),
267 fMethodBaseDir ( 0 ),
268 fFile ( 0 ),
269 fSilentFile (kFALSE),
270 fModelPersistence (kTRUE),
271 fWeightFile ( "" ),
272 fEffS ( 0 ),
// PDFs and splines are created lazily later (see ProcessBaseOptions)
273 fDefaultPDF ( 0 ),
274 fMVAPdfS ( 0 ),
275 fMVAPdfB ( 0 ),
276 fSplS ( 0 ),
277 fSplB ( 0 ),
278 fSpleffBvsS ( 0 ),
279 fSplTrainS ( 0 ),
280 fSplTrainB ( 0 ),
281 fSplTrainEffBvsS ( 0 ),
282 fVarTransformString ( "None" ),
283 fTransformationPointer ( 0 ),
284 fTransformation ( dsi, methodTitle ),
285 fVerbose ( kFALSE ),
286 fVerbosityLevelString ( "Default" ),
287 fHelp ( kFALSE ),
288 fHasMVAPdfs ( kFALSE ),
289 fIgnoreNegWeightsInTraining( kFALSE ),
290 fSignalClass ( 0 ),
291 fBackgroundClass ( 0 ),
292 fSplRefS ( 0 ),
293 fSplRefB ( 0 ),
294 fSplTrainRefS ( 0 ),
295 fSplTrainRefB ( 0 ),
296 fSetupCompleted (kFALSE)
297{
300
301// // default extension for weight files
302}
303
304////////////////////////////////////////////////////////////////////////////////
305/// constructor used for Testing + Application of the MVA,
306/// only (no training), using given WeightFiles
307
// NOTE(review): the first line of this constructor's signature (the
// methodType parameter) is not visible in this listing.
309 DataSetInfo& dsi,
310 const TString& weightFile ) :
311 IMethod(),
// empty option string: the configuration is read back from the weight file
312 Configurable(""),
313 fTmpEvent ( 0 ),
314 fRanking ( 0 ),
315 fInputVars ( 0 ),
316 fAnalysisType ( Types::kNoAnalysisType ),
317 fRegressionReturnVal ( 0 ),
318 fMulticlassReturnVal ( 0 ),
319 fDataSetInfo ( dsi ),
320 fSignalReferenceCut ( 0.5 ),
// NOTE(review): unlike the training constructor, fSignalReferenceCutOrientation
// is not initialised in this list — confirm it is set before first use.
321 fVariableTransformType ( Types::kSignal ),
322 fJobName ( "" ),
323 fMethodName ( "MethodBase" ),
324 fMethodType ( methodType ),
325 fTestvar ( "" ),
// training versions unknown here; they are restored from the weight file
326 fTMVATrainingVersion ( 0 ),
327 fROOTTrainingVersion ( 0 ),
328 fConstructedFromWeightFile ( kTRUE ),
329 fBaseDir ( 0 ),
330 fMethodBaseDir ( 0 ),
331 fFile ( 0 ),
332 fSilentFile (kFALSE),
333 fModelPersistence (kTRUE),
334 fWeightFile ( weightFile ),
335 fEffS ( 0 ),
336 fDefaultPDF ( 0 ),
337 fMVAPdfS ( 0 ),
338 fMVAPdfB ( 0 ),
339 fSplS ( 0 ),
340 fSplB ( 0 ),
341 fSpleffBvsS ( 0 ),
342 fSplTrainS ( 0 ),
343 fSplTrainB ( 0 ),
344 fSplTrainEffBvsS ( 0 ),
345 fVarTransformString ( "None" ),
346 fTransformationPointer ( 0 ),
347 fTransformation ( dsi, "" ),
348 fVerbose ( kFALSE ),
349 fVerbosityLevelString ( "Default" ),
350 fHelp ( kFALSE ),
351 fHasMVAPdfs ( kFALSE ),
352 fIgnoreNegWeightsInTraining( kFALSE ),
353 fSignalClass ( 0 ),
354 fBackgroundClass ( 0 ),
355 fSplRefS ( 0 ),
356 fSplRefB ( 0 ),
357 fSplTrainRefS ( 0 ),
358 fSplTrainRefB ( 0 ),
359 fSetupCompleted (kFALSE)
360{
362// // constructor used for Testing + Application of the MVA,
363// // only (no training), using given WeightFiles
364}
365
366////////////////////////////////////////////////////////////////////////////////
367/// destructor
368
370{
371 // destructor
// A method destroyed before SetupMethod() indicates a programming error;
// kFATAL aborts via the logger.
372 if (!fSetupCompleted) Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Calling destructor of method which got never setup" << Endl;
373
374 // destructor
375 if (fInputVars != 0) { fInputVars->clear(); delete fInputVars; }
376 if (fRanking != 0) delete fRanking;
377
378 // PDFs
379 if (fDefaultPDF!= 0) { delete fDefaultPDF; fDefaultPDF = 0; }
380 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
381 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
382
383 // Splines
384 if (fSplS) { delete fSplS; fSplS = 0; }
385 if (fSplB) { delete fSplB; fSplB = 0; }
386 if (fSpleffBvsS) { delete fSpleffBvsS; fSpleffBvsS = 0; }
387 if (fSplRefS) { delete fSplRefS; fSplRefS = 0; }
388 if (fSplRefB) { delete fSplRefB; fSplRefB = 0; }
389 if (fSplTrainRefS) { delete fSplTrainRefS; fSplTrainRefS = 0; }
390 if (fSplTrainRefB) { delete fSplTrainRefB; fSplTrainRefB = 0; }
391 if (fSplTrainEffBvsS) { delete fSplTrainEffBvsS; fSplTrainEffBvsS = 0; }
392
// Delete the owned Event objects of both event collections (index 0 and 1),
// then the collections themselves.
393 for (Int_t i = 0; i < 2; i++ ) {
394 if (fEventCollections.at(i)) {
395 for (std::vector<Event*>::const_iterator it = fEventCollections.at(i)->begin();
396 it != fEventCollections.at(i)->end(); ++it) {
397 delete (*it);
398 }
399 delete fEventCollections.at(i);
400 fEventCollections.at(i) = 0;
401 }
402 }
403
404 if (fRegressionReturnVal) delete fRegressionReturnVal;
405 if (fMulticlassReturnVal) delete fMulticlassReturnVal;
406}
407
408////////////////////////////////////////////////////////////////////////////////
409/// setup of methods
410
412{
413 // setup of methods
414
// Guard against double setup, then run base-class initialisation followed by
// the derived-class hooks Init() and DeclareOptions().
415 if (fSetupCompleted) Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Calling SetupMethod for the second time" << Endl;
416 InitBase();
417 DeclareBaseOptions();
418 Init();
419 DeclareOptions();
420 fSetupCompleted = kTRUE;
422
423////////////////////////////////////////////////////////////////////////////////
424/// process all options
425/// the "CheckForUnusedOptions" is done in an independent call, since it may be overridden by derived class
426/// (sometimes, eg, fitters are used which can only be implemented during training phase)
427
429{
// Base options first, then the derived-class ProcessOptions() hook; the
// unused-option check is deliberately deferred to CheckSetup() (see above).
430 ProcessBaseOptions();
431 ProcessOptions();
432}
433
434////////////////////////////////////////////////////////////////////////////////
435/// check may be overridden by derived class
436/// (sometimes, eg, fitters are used which can only be implemented during training phase)
437
439{
// Default check: flag any option-string entries no method consumed.
440 CheckForUnusedOptions();
441}
442
443////////////////////////////////////////////////////////////////////////////////
444/// default initialization called by all constructors
445
447{
448 SetConfigDescription( "Configuration options for classifier architecture and tuning" );
449
451 fNbinsMVAoutput = gConfig().fVariablePlotting.fNbinsMVAoutput;
452 fNbinsH = NBIN_HIST_HIGH;
453
454 fSplTrainS = 0;
455 fSplTrainB = 0;
456 fSplTrainEffBvsS = 0;
// -1 marks "not yet computed" for the cached signal/background statistics
457 fMeanS = -1;
458 fMeanB = -1;
459 fRmsS = -1;
460 fRmsB = -1;
// extreme initial bounds so the first value seen always updates the range
461 fXmin = DBL_MAX;
462 fXmax = -DBL_MAX;
463 fTxtWeightsOnly = kTRUE;
464 fSplRefS = 0;
465 fSplRefB = 0;
466
467 fTrainTime = -1.;
468 fTestTime = -1.;
469
470 fRanking = 0;
471
472 // temporary until the move to DataSet is complete
473 fInputVars = new std::vector<TString>;
474 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
475 fInputVars->push_back(DataInfo().GetVariableInfo(ivar).GetLabel());
476 }
477 fRegressionReturnVal = 0;
478 fMulticlassReturnVal = 0;
479
// two slots: one event collection per dataset type (training / testing)
480 fEventCollections.resize( 2 );
481 fEventCollections.at(0) = 0;
482 fEventCollections.at(1) = 0;
483
484 // retrieve signal and background class index
485 if (DataInfo().GetClassInfo("Signal") != 0) {
486 fSignalClass = DataInfo().GetClassInfo("Signal")->GetNumber();
487 }
488 if (DataInfo().GetClassInfo("Background") != 0) {
489 fBackgroundClass = DataInfo().GetClassInfo("Background")->GetNumber();
490 }
491
492 SetConfigDescription( "Configuration options for MVA method" );
493 SetConfigName( TString("Method") + GetMethodTypeName() );
494}
495
496////////////////////////////////////////////////////////////////////////////////
497/// define the options (their key words) that can be set in the option string
498/// here the options valid for ALL MVA methods are declared.
499///
500/// know options:
501///
502/// - VariableTransform=None,Decorrelated,PCA to use transformed variables
503/// instead of the original ones
504/// - VariableTransformType=Signal,Background which decorrelation matrix to use
505/// in the method. Only the Likelihood
506/// Method can make proper use of independent
507/// transformations of signal and background
508/// - fNbinsMVAPdf = 50 Number of bins used to create a PDF of MVA
509/// - fNsmoothMVAPdf = 2 Number of times a histogram is smoothed before creating the PDF
510/// - fHasMVAPdfs create PDFs for the MVA outputs
511/// - V for Verbose output (!V) for non verbos
512/// - H for Help message
513
515{
516 DeclareOptionRef( fVerbose, "V", "Verbose output (short form of \"VerbosityLevel\" below - overrides the latter one)" );
517
// "VerbosityLevel" accepts the message-level names understood by MsgLogger
518 DeclareOptionRef( fVerbosityLevelString="Default", "VerbosityLevel", "Verbosity level" );
519 AddPreDefVal( TString("Default") ); // uses default defined in MsgLogger header
520 AddPreDefVal( TString("Debug") );
521 AddPreDefVal( TString("Verbose") );
522 AddPreDefVal( TString("Info") );
523 AddPreDefVal( TString("Warning") );
524 AddPreDefVal( TString("Error") );
525 AddPreDefVal( TString("Fatal") );
526
527 // If True (default): write all training results (weights) as text files only;
528 // if False: write also in ROOT format (not available for all methods - will abort if not
529 fTxtWeightsOnly = kTRUE; // OBSOLETE !!!
530 fNormalise = kFALSE; // OBSOLETE !!!
531
532 DeclareOptionRef( fVarTransformString, "VarTransform", "List of variable transformations performed before training, e.g., \"D_Background,P_Signal,G,N_AllClasses\" for: \"Decorrelation, PCA-transformation, Gaussianisation, Normalisation, each for the given class of events ('AllClasses' denotes all events of all classes, if no class indication is given, 'All' is assumed)\"" );
533
534 DeclareOptionRef( fHelp, "H", "Print method-specific help message" );
535
536 DeclareOptionRef( fHasMVAPdfs, "CreateMVAPdfs", "Create PDFs for classifier outputs (signal and background)" );
537
538 DeclareOptionRef( fIgnoreNegWeightsInTraining, "IgnoreNegWeightsInTraining",
539 "Events with negative weights are ignored in the training (but are included for testing and performance evaluation)" );
540}
541
542////////////////////////////////////////////////////////////////////////////////
543/// the option string is decoded, for available options see "DeclareOptions"
544
546{
547 if (HasMVAPdfs()) {
548 // setting the default bin num... maybe should be static ? ==> Please no static (JS)
549 // You can't use the logger in the constructor!!! Log() << kINFO << "Create PDFs" << Endl;
550 // reading every PDF's definition and passing the option string to the next one to be read and marked
551 fDefaultPDF = new PDF( TString(GetName())+"_PDF", GetOptions(), "MVAPdf" );
552 fDefaultPDF->DeclareOptions();
553 fDefaultPDF->ParseOptions();
554 fDefaultPDF->ProcessOptions();
555 fMVAPdfB = new PDF( TString(GetName())+"_PDFBkg", fDefaultPDF->GetOptions(), "MVAPdfBkg", fDefaultPDF );
556 fMVAPdfB->DeclareOptions();
557 fMVAPdfB->ParseOptions();
558 fMVAPdfB->ProcessOptions();
559 fMVAPdfS = new PDF( TString(GetName())+"_PDFSig", fMVAPdfB->GetOptions(), "MVAPdfSig", fDefaultPDF );
560 fMVAPdfS->DeclareOptions();
561 fMVAPdfS->ParseOptions();
562 fMVAPdfS->ProcessOptions();
563
564 // the final marked option string is written back to the original methodbase
565 SetOptions( fMVAPdfS->GetOptions() );
566 }
567
// Build the variable-transformation chain requested via "VarTransform".
568 TMVA::CreateVariableTransforms( fVarTransformString,
569 DataInfo(),
570 GetTransformationHandler(),
571 Log() );
572
573 if (!HasMVAPdfs()) {
574 if (fDefaultPDF!= 0) { delete fDefaultPDF; fDefaultPDF = 0; }
575 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
576 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
577 }
578
// Map the verbosity option onto the logger's minimum message level;
// an unknown value is a fatal configuration error.
579 if (fVerbose) { // overwrites other settings
580 fVerbosityLevelString = TString("Verbose");
581 Log().SetMinType( kVERBOSE );
582 }
583 else if (fVerbosityLevelString == "Debug" ) Log().SetMinType( kDEBUG );
584 else if (fVerbosityLevelString == "Verbose" ) Log().SetMinType( kVERBOSE );
585 else if (fVerbosityLevelString == "Info" ) Log().SetMinType( kINFO );
586 else if (fVerbosityLevelString == "Warning" ) Log().SetMinType( kWARNING );
587 else if (fVerbosityLevelString == "Error" ) Log().SetMinType( kERROR );
588 else if (fVerbosityLevelString == "Fatal" ) Log().SetMinType( kFATAL );
589 else if (fVerbosityLevelString != "Default" ) {
590 Log() << kFATAL << "<ProcessOptions> Verbosity level type '"
591 << fVerbosityLevelString << "' unknown." << Endl;
592 }
593 Event::SetIgnoreNegWeightsInTraining(fIgnoreNegWeightsInTraining);
594}
595
596////////////////////////////////////////////////////////////////////////////////
597/// options that are used ONLY for the READER to ensure backward compatibility
598/// they are hence without any effect (the reader is only reading the training
599/// options that HAD been used at the training of the .xml weight file at hand
600
602{
// These options exist only so old weight files still parse; their values
// have no effect in current code (see the function documentation above).
603 DeclareOptionRef( fNormalise=kFALSE, "Normalise", "Normalise input variables" ); // don't change the default !!!
604 DeclareOptionRef( fUseDecorr=kFALSE, "D", "Use-decorrelated-variables flag" );
605 DeclareOptionRef( fVariableTransformTypeString="Signal", "VarTransformType",
606 "Use signal or background events to derive for variable transformation (the transformation is applied on both types of, course)" );
607 AddPreDefVal( TString("Signal") );
608 AddPreDefVal( TString("Background") );
609 DeclareOptionRef( fTxtWeightsOnly=kTRUE, "TxtWeightFilesOnly", "If True: write all training results (weights) as text files (False: some are written in ROOT format)" );
610 // Why on earth ?? was this here? Was the verbosity level option meant to 'disappear? Not a good idea i think..
611 // DeclareOptionRef( fVerbosityLevelString="Default", "VerboseLevel", "Verbosity level" );
612 // AddPreDefVal( TString("Default") ); // uses default defined in MsgLogger header
613 // AddPreDefVal( TString("Debug") );
614 // AddPreDefVal( TString("Verbose") );
615 // AddPreDefVal( TString("Info") );
616 // AddPreDefVal( TString("Warning") );
617 // AddPreDefVal( TString("Error") );
618 // AddPreDefVal( TString("Fatal") );
619 DeclareOptionRef( fNbinsMVAPdf = 60, "NbinsMVAPdf", "Number of bins used for the PDFs of classifier outputs" );
620 DeclareOptionRef( fNsmoothMVAPdf = 2, "NsmoothMVAPdf", "Number of smoothing iterations for classifier PDFs" );
621}
622
623
624////////////////////////////////////////////////////////////////////////////////
625/// call the Optimizer with the set of parameters and ranges that
626/// are meant to be tuned.
627
628std::map<TString,Double_t> TMVA::MethodBase::OptimizeTuningParameters(TString /* fomType */ , TString /* fitType */)
629{
630 // this is just a dummy... needs to be implemented for each method
631 // individually (as long as we don't have it automatized via the
632 // configuration string
633
634 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Parameter optimization is not yet implemented for method "
635 << GetName() << Endl;
636 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Currently we need to set hardcoded which parameter is tuned in which ranges"<<Endl;
637
638 std::map<TString,Double_t> tunedParameters;
639 tunedParameters.size(); // just to get rid of "unused" warning
640 return tunedParameters;
641
642}
643
644////////////////////////////////////////////////////////////////////////////////
645/// set the tuning parameters according to the argument
646/// This is just a dummy .. have a look at the MethodBDT how you could
647/// perhaps implement the same thing for the other Classifiers..
648
649void TMVA::MethodBase::SetTuneParameters(std::map<TString,Double_t> /* tuneParameters */)
650{
   // Intentionally empty: methods supporting automatic tuning (e.g. MethodBDT)
   // override this to apply the map returned by OptimizeTuningParameters().
651}
652
653////////////////////////////////////////////////////////////////////////////////
654
656{
// Switch to the training sample before anything else; all evaluation below
// therefore runs on training events.
657 Data()->SetCurrentType(Types::kTraining);
658 Event::SetIsTraining(kTRUE); // used to set negative event weights to zero if chosen to do so
659
660 // train the MVA method
661 if (Help()) PrintHelpMessage();
662
663 // all histograms should be created in the method's subdirectory
664 if(!IsSilentFile()) BaseDir()->cd();
665
666 // once calculate all the transformation (e.g. the sequence of Decorr:Gauss:Decorr)
667 // needed for this classifier
668 GetTransformationHandler().CalcTransformations(Data()->GetEventCollection());
669
670 // call training of derived MVA
671 Log() << kDEBUG //<<Form("\tDataset[%s] : ",DataInfo().GetName())
672 << "Begin training" << Endl;
673 Long64_t nEvents = Data()->GetNEvents();
674 Timer traintimer( nEvents, GetName(), kTRUE );
675 Train();
676 Log() << kDEBUG //<<Form("Dataset[%s] : ",DataInfo().GetName()
677 << "\tEnd of training " << Endl;
678 SetTrainTime(traintimer.ElapsedSeconds());
679 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
680 << "Elapsed time for training with " << nEvents << " events: "
681 << traintimer.GetElapsedTime() << " " << Endl;
682
683 Log() << kDEBUG //<<Form("Dataset[%s] : ",DataInfo().GetName())
684 << "\tCreate MVA output for ";
685
686 // create PDFs for the signal and background MVA distributions (if required)
// Dispatch on analysis type: multiclass, binary classification, or regression.
687 if (DoMulticlass()) {
688 Log() <<Form("[%s] : ",DataInfo().GetName())<< "Multiclass classification on training sample" << Endl;
689 AddMulticlassOutput(Types::kTraining);
690 }
691 else if (!DoRegression()) {
692
693 Log() <<Form("[%s] : ",DataInfo().GetName())<< "classification on training sample" << Endl;
694 AddClassifierOutput(Types::kTraining);
695 if (HasMVAPdfs()) {
696 CreateMVAPdfs();
697 AddClassifierOutputProb(Types::kTraining);
698 }
699
700 } else {
701
702 Log() <<Form("Dataset[%s] : ",DataInfo().GetName())<< "regression on training sample" << Endl;
703 AddRegressionOutput( Types::kTraining );
704
705 if (HasMVAPdfs() ) {
706 Log() <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create PDFs" << Endl;
707 CreateMVAPdfs();
708 }
709 }
710
711 // write the current MVA state into stream
712 // produced are one text file and one ROOT file
713 if (fModelPersistence ) WriteStateToFile();
714
715 // produce standalone make class (presently only supported for classification)
716 if ((!DoRegression()) && (fModelPersistence)) MakeClass();
717
718 // write additional monitoring histograms to main target file (not the weight file)
719 // again, make sure the histograms go into the method's subdirectory
720 if(!IsSilentFile())
721 {
722 BaseDir()->cd();
723 WriteMonitoringHistosToFile();
724 }
725}
726
727////////////////////////////////////////////////////////////////////////////////
728
730{
731 if (!DoRegression()) Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Trying to use GetRegressionDeviation() with a classification job" << Endl;
732 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
// NOTE(review): the line declaring `regRes` (presumably a ResultsRegression*
// obtained from Data()->GetResults, cf. AddRegressionOutput) is missing from
// this listing — confirm against the full source.
// First pass: untruncated quadratic deviation -> overall std deviation.
734 bool truncate = false;
735 TH1F* h1 = regRes->QuadraticDeviation( tgtNum , truncate, 1.);
736 stddev = sqrt(h1->GetMean());
// Second pass: truncate at the 90% quantile of the first histogram to get
// the std deviation of the best 90% of events.
737 truncate = true;
738 Double_t yq[1], xq[]={0.9};
739 h1->GetQuantiles(1,yq,xq);
740 TH1F* h2 = regRes->QuadraticDeviation( tgtNum , truncate, yq[0]);
741 stddev90Percent = sqrt(h2->GetMean());
742 delete h1;
743 delete h2;
744}
745
746////////////////////////////////////////////////////////////////////////////////
747/// prepare tree branch with the method's discriminating variable
748
750{
751 Data()->SetCurrentType(type);
752
753 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
754
755 ResultsRegression* regRes = (ResultsRegression*)Data()->GetResults(GetMethodName(), type, Types::kRegression);
756
757 Long64_t nEvents = Data()->GetNEvents();
758
759 // use timer
760 Timer timer( nEvents, GetName(), kTRUE );
761 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName()) << "Evaluation of " << GetMethodName() << " on "
762 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
763
764 regRes->Resize( nEvents );
765
766 // Drawing the progress bar every event was causing a huge slowdown in the evaluation time
767 // So we set some parameters to draw the progress bar a total of totalProgressDraws, i.e. only draw every 1 in 100
768
769 Int_t totalProgressDraws = 100; // total number of times to update the progress bar
770 Int_t drawProgressEvery = 1; // draw every nth event such that we have a total of totalProgressDraws
771 if(nEvents >= totalProgressDraws) drawProgressEvery = nEvents/totalProgressDraws;
772
// Evaluate the regression targets event by event and store them in regRes.
773 for (Int_t ievt=0; ievt<nEvents; ievt++) {
774
775 Data()->SetCurrentEvent(ievt);
776 std::vector< Float_t > vals = GetRegressionValues();
777 regRes->SetValue( vals, ievt );
778
779 // Only draw the progress bar once in a while, doing this every event causes the evaluation to be ridiculously slow
780 if(ievt % drawProgressEvery == 0 || ievt==nEvents-1) timer.DrawProgressBar( ievt );
781 }
782
783 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())
784 << "Elapsed time for evaluation of " << nEvents << " events: "
785 << timer.GetElapsedTime() << " " << Endl;
786
787 // store time used for testing
789 SetTestTime(timer.ElapsedSeconds());
790
791 TString histNamePrefix(GetTestvarName());
792 histNamePrefix += (type==Types::kTraining?"train":"test");
793 regRes->CreateDeviationHistograms( histNamePrefix );
794}
795
796////////////////////////////////////////////////////////////////////////////////
797/// prepare tree branch with the method's discriminating variable
798
800{
801 Data()->SetCurrentType(type);
802
803 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
804
// dynamic_cast may yield null on a type mismatch; treat that as fatal.
805 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), type, Types::kMulticlass));
806 if (!resMulticlass) Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName())<< "unable to create pointer in AddMulticlassOutput, exiting."<<Endl;
807
808 Long64_t nEvents = Data()->GetNEvents();
809
810 // use timer
811 Timer timer( nEvents, GetName(), kTRUE );
812
813 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Multiclass evaluation of " << GetMethodName() << " on "
814 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
815
// Evaluate the per-class response for every event and store it.
816 resMulticlass->Resize( nEvents );
817 for (Int_t ievt=0; ievt<nEvents; ievt++) {
818 Data()->SetCurrentEvent(ievt);
819 std::vector< Float_t > vals = GetMulticlassValues();
820 resMulticlass->SetValue( vals, ievt );
821 timer.DrawProgressBar( ievt );
822 }
823
824 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())
825 << "Elapsed time for evaluation of " << nEvents << " events: "
826 << timer.GetElapsedTime() << " " << Endl;
827
828 // store time used for testing
830 SetTestTime(timer.ElapsedSeconds());
831
832 TString histNamePrefix(GetTestvarName());
833 histNamePrefix += (type==Types::kTraining?"_Train":"_Test");
834
835 resMulticlass->CreateMulticlassHistos( histNamePrefix, fNbinsMVAoutput, fNbinsH );
836 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefix);
837}
838
839////////////////////////////////////////////////////////////////////////////////
840
841void TMVA::MethodBase::NoErrorCalc(Double_t* const err, Double_t* const errUpper) {
842 if (err) *err=-1;
843 if (errUpper) *errUpper=-1;
844}
845
846////////////////////////////////////////////////////////////////////////////////
847
848Double_t TMVA::MethodBase::GetMvaValue( const Event* const ev, Double_t* err, Double_t* errUpper ) {
849 fTmpEvent = ev;
850 Double_t val = GetMvaValue(err, errUpper);
851 fTmpEvent = 0;
852 return val;
853}
854
855////////////////////////////////////////////////////////////////////////////////
856/// uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation)
857/// for a quick determination if an event would be selected as signal or background
858
// Signal-like iff the MVA value lies on the signal side of the reference cut;
// multiplying both sides by the cut orientation (set via
// SetSignalReferenceCutOrientation) flips the comparison for methods where
// smaller output is more signal-like.
860 return GetMvaValue()*GetSignalReferenceCutOrientation() > GetSignalReferenceCut()*GetSignalReferenceCutOrientation() ? kTRUE : kFALSE;
861}
862////////////////////////////////////////////////////////////////////////////////
863/// uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation)
864/// for a quick determination if an event with this mva output value would be selected as signal or background
865
// Same cut logic as IsSignalLike() above, but applied to a caller-supplied
// MVA output value instead of evaluating the current event.
867 return mvaVal*GetSignalReferenceCutOrientation() > GetSignalReferenceCut()*GetSignalReferenceCutOrientation() ? kTRUE : kFALSE;
868}
869
870////////////////////////////////////////////////////////////////////////////////
871/// prepare tree branch with the method's discriminating variable
872
874{
875 Data()->SetCurrentType(type);
876
877 ResultsClassification* clRes =
879
880 Long64_t nEvents = Data()->GetNEvents();
881 clRes->Resize( nEvents );
882
883 // use timer
884 Timer timer( nEvents, GetName(), kTRUE );
885 std::vector<Double_t> mvaValues = GetMvaValues(0, nEvents, true);
886
887 // store time used for testing
889 SetTestTime(timer.ElapsedSeconds());
890
891 // load mva values to results object
892 for (Int_t ievt=0; ievt<nEvents; ievt++) {
893 clRes->SetValue( mvaValues[ievt], ievt );
894 }
895}
896
897////////////////////////////////////////////////////////////////////////////////
898/// get all the MVA values for the events of the current Data type
899std::vector<Double_t> TMVA::MethodBase::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
900{
901
902 Long64_t nEvents = Data()->GetNEvents();
903 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
904 if (firstEvt < 0) firstEvt = 0;
905 std::vector<Double_t> values(lastEvt-firstEvt);
906 // log in case of looping on all the events
907 nEvents = values.size();
908
909 // use timer
910 Timer timer( nEvents, GetName(), kTRUE );
911
912 if (logProgress)
913 Log() << kHEADER << Form("[%s] : ",DataInfo().GetName())
914 << "Evaluation of " << GetMethodName() << " on "
915 << (Data()->GetCurrentType() == Types::kTraining ? "training" : "testing")
916 << " sample (" << nEvents << " events)" << Endl;
917
918 for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
919 Data()->SetCurrentEvent(ievt);
920 values[ievt] = GetMvaValue();
921
922 // print progress
923 if (logProgress) {
924 Int_t modulo = Int_t(nEvents/100);
925 if (modulo <= 0 ) modulo = 1;
926 if (ievt%modulo == 0) timer.DrawProgressBar( ievt );
927 }
928 }
929 if (logProgress) {
930 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
931 << "Elapsed time for evaluation of " << nEvents << " events: "
932 << timer.GetElapsedTime() << " " << Endl;
933 }
934
935 return values;
936}
937
938////////////////////////////////////////////////////////////////////////////////
939/// prepare tree branch with the method's discriminating variable
940
942{
943 Data()->SetCurrentType(type);
944
945 ResultsClassification* mvaProb =
946 (ResultsClassification*)Data()->GetResults(TString("prob_")+GetMethodName(), type, Types::kClassification );
947
948 Long64_t nEvents = Data()->GetNEvents();
949
950 // use timer
951 Timer timer( nEvents, GetName(), kTRUE );
952
953 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName()) << "Evaluation of " << GetMethodName() << " on "
954 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
955
956 mvaProb->Resize( nEvents );
957 for (Int_t ievt=0; ievt<nEvents; ievt++) {
958
959 Data()->SetCurrentEvent(ievt);
960 Float_t proba = ((Float_t)GetProba( GetMvaValue(), 0.5 ));
961 if (proba < 0) break;
962 mvaProb->SetValue( proba, ievt );
963
964 // print progress
965 Int_t modulo = Int_t(nEvents/100);
966 if (modulo <= 0 ) modulo = 1;
967 if (ievt%modulo == 0) timer.DrawProgressBar( ievt );
968 }
969
970 Log() << kDEBUG <<Form("Dataset[%s] : ",DataInfo().GetName())
971 << "Elapsed time for evaluation of " << nEvents << " events: "
972 << timer.GetElapsedTime() << " " << Endl;
973}
974
975////////////////////////////////////////////////////////////////////////////////
976/// calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
977///
978/// - bias = average deviation
979/// - dev = average absolute deviation
980/// - rms = rms of deviation
981
983 Double_t& dev, Double_t& devT,
984 Double_t& rms, Double_t& rmsT,
985 Double_t& mInf, Double_t& mInfT,
986 Double_t& corr,
988{
989 Types::ETreeType savedType = Data()->GetCurrentType();
990 Data()->SetCurrentType(type);
991
992 bias = 0; biasT = 0; dev = 0; devT = 0; rms = 0; rmsT = 0;
993 Double_t sumw = 0;
994 Double_t m1 = 0, m2 = 0, s1 = 0, s2 = 0, s12 = 0; // for correlation
995 const Int_t nevt = GetNEvents();
996 Float_t* rV = new Float_t[nevt];
997 Float_t* tV = new Float_t[nevt];
998 Float_t* wV = new Float_t[nevt];
999 Float_t xmin = 1e30, xmax = -1e30;
1000 Log() << kINFO << "Calculate regression for all events" << Endl;
1001 Timer timer( nevt, GetName(), kTRUE );
1002 for (Long64_t ievt=0; ievt<nevt; ievt++) {
1003
1004 const Event* ev = Data()->GetEvent(ievt); // NOTE: need untransformed event here !
1005 Float_t t = ev->GetTarget(0);
1006 Float_t w = ev->GetWeight();
1007 Float_t r = GetRegressionValues()[0];
1008 Float_t d = (r-t);
1009
1010 // find min/max
1013
1014 // store for truncated RMS computation
1015 rV[ievt] = r;
1016 tV[ievt] = t;
1017 wV[ievt] = w;
1018
1019 // compute deviation-squared
1020 sumw += w;
1021 bias += w * d;
1022 dev += w * TMath::Abs(d);
1023 rms += w * d * d;
1024
1025 // compute correlation between target and regression estimate
1026 m1 += t*w; s1 += t*t*w;
1027 m2 += r*w; s2 += r*r*w;
1028 s12 += t*r;
1029 if ((ievt & 0xFF) == 0) timer.DrawProgressBar(ievt);
1030 }
1031 timer.DrawProgressBar(nevt - 1);
1032 Log() << kINFO << "Elapsed time for evaluation of " << nevt << " events: "
1033 << timer.GetElapsedTime() << " " << Endl;
1034
1035 // standard quantities
1036 bias /= sumw;
1037 dev /= sumw;
1038 rms /= sumw;
1039 rms = TMath::Sqrt(rms - bias*bias);
1040
1041 // correlation
1042 m1 /= sumw;
1043 m2 /= sumw;
1044 corr = s12/sumw - m1*m2;
1045 corr /= TMath::Sqrt( (s1/sumw - m1*m1) * (s2/sumw - m2*m2) );
1046
1047 // create histogram required for computation of mutual information
1048 TH2F* hist = new TH2F( "hist", "hist", 150, xmin, xmax, 100, xmin, xmax );
1049 TH2F* histT = new TH2F( "histT", "histT", 150, xmin, xmax, 100, xmin, xmax );
1050
1051 // compute truncated RMS and fill histogram
1052 Double_t devMax = bias + 2*rms;
1053 Double_t devMin = bias - 2*rms;
1054 sumw = 0;
1055 int ic=0;
1056 for (Long64_t ievt=0; ievt<nevt; ievt++) {
1057 Float_t d = (rV[ievt] - tV[ievt]);
1058 hist->Fill( rV[ievt], tV[ievt], wV[ievt] );
1059 if (d >= devMin && d <= devMax) {
1060 sumw += wV[ievt];
1061 biasT += wV[ievt] * d;
1062 devT += wV[ievt] * TMath::Abs(d);
1063 rmsT += wV[ievt] * d * d;
1064 histT->Fill( rV[ievt], tV[ievt], wV[ievt] );
1065 ic++;
1066 }
1067 }
1068 biasT /= sumw;
1069 devT /= sumw;
1070 rmsT /= sumw;
1071 rmsT = TMath::Sqrt(rmsT - biasT*biasT);
1072 mInf = gTools().GetMutualInformation( *hist );
1073 mInfT = gTools().GetMutualInformation( *histT );
1074
1075 delete hist;
1076 delete histT;
1077
1078 delete [] rV;
1079 delete [] tV;
1080 delete [] wV;
1081
1082 Data()->SetCurrentType(savedType);
1083}
1084
1085
1086////////////////////////////////////////////////////////////////////////////////
1087/// test multiclass classification
1088
1090{
1091 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTesting, Types::kMulticlass));
1092 if (!resMulticlass) Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName())<< "unable to create pointer in TestMulticlass, exiting."<<Endl;
1093
1094 // GA evaluation of best cut for sig eff * sig pur. Slow, disabled for now.
1095 // Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Determine optimal multiclass cuts for test
1096 // data..." << Endl; for (UInt_t icls = 0; icls<DataInfo().GetNClasses(); ++icls) {
1097 // resMulticlass->GetBestMultiClassCuts(icls);
1098 // }
1099
1100 // Create histograms for use in TMVA GUI
1101 TString histNamePrefix(GetTestvarName());
1102 TString histNamePrefixTest{histNamePrefix + "_Test"};
1103 TString histNamePrefixTrain{histNamePrefix + "_Train"};
1104
1105 resMulticlass->CreateMulticlassHistos(histNamePrefixTest, fNbinsMVAoutput, fNbinsH);
1106 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefixTest);
1107
1108 resMulticlass->CreateMulticlassHistos(histNamePrefixTrain, fNbinsMVAoutput, fNbinsH);
1109 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefixTrain);
1110}
1111
1112
1113////////////////////////////////////////////////////////////////////////////////
1114/// initialization
1115
1117{
1118 Data()->SetCurrentType(Types::kTesting);
1119
1120 ResultsClassification* mvaRes = dynamic_cast<ResultsClassification*>
1121 ( Data()->GetResults(GetMethodName(),Types::kTesting, Types::kClassification) );
1122
1123 // sanity checks: tree must exist, and theVar must be in tree
1124 if (0==mvaRes && !(GetMethodTypeName().Contains("Cuts"))) {
1125 Log()<<Form("Dataset[%s] : ",DataInfo().GetName()) << "mvaRes " << mvaRes << " GetMethodTypeName " << GetMethodTypeName()
1126 << " contains " << !(GetMethodTypeName().Contains("Cuts")) << Endl;
1127 Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName()) << "<TestInit> Test variable " << GetTestvarName()
1128 << " not found in tree" << Endl;
1129 }
1130
1131 // basic statistics operations are made in base class
1132 gTools().ComputeStat( GetEventCollection(Types::kTesting), mvaRes->GetValueVector(),
1133 fMeanS, fMeanB, fRmsS, fRmsB, fXmin, fXmax, fSignalClass );
1134
1135 // choose reasonable histogram ranges, by removing outliers
1136 Double_t nrms = 10;
1137 fXmin = TMath::Max( TMath::Min( fMeanS - nrms*fRmsS, fMeanB - nrms*fRmsB ), fXmin );
1138 fXmax = TMath::Min( TMath::Max( fMeanS + nrms*fRmsS, fMeanB + nrms*fRmsB ), fXmax );
1139
1140 // determine cut orientation
1141 fCutOrientation = (fMeanS > fMeanB) ? kPositive : kNegative;
1142
1143 // fill 2 types of histograms for the various analyses
1144 // this one is for actual plotting
1145
1146 Double_t sxmax = fXmax+0.00001;
1147
1148 // classifier response distributions for training sample
1149 // MVA plots used for graphics representation (signal)
1150 TString TestvarName;
1151 if(IsSilentFile())
1152 {
1153 TestvarName=Form("[%s]%s",DataInfo().GetName(),GetTestvarName().Data());
1154 }else
1155 {
1156 TestvarName=GetTestvarName();
1157 }
1158 TH1* mva_s = new TH1D( TestvarName + "_S",TestvarName + "_S", fNbinsMVAoutput, fXmin, sxmax );
1159 TH1* mva_b = new TH1D( TestvarName + "_B",TestvarName + "_B", fNbinsMVAoutput, fXmin, sxmax );
1160 mvaRes->Store(mva_s, "MVA_S");
1161 mvaRes->Store(mva_b, "MVA_B");
1162 mva_s->Sumw2();
1163 mva_b->Sumw2();
1164
1165 TH1* proba_s = 0;
1166 TH1* proba_b = 0;
1167 TH1* rarity_s = 0;
1168 TH1* rarity_b = 0;
1169 if (HasMVAPdfs()) {
1170 // P(MVA) plots used for graphics representation
1171 proba_s = new TH1D( TestvarName + "_Proba_S", TestvarName + "_Proba_S", fNbinsMVAoutput, 0.0, 1.0 );
1172 proba_b = new TH1D( TestvarName + "_Proba_B", TestvarName + "_Proba_B", fNbinsMVAoutput, 0.0, 1.0 );
1173 mvaRes->Store(proba_s, "Prob_S");
1174 mvaRes->Store(proba_b, "Prob_B");
1175 proba_s->Sumw2();
1176 proba_b->Sumw2();
1177
1178 // R(MVA) plots used for graphics representation
1179 rarity_s = new TH1D( TestvarName + "_Rarity_S", TestvarName + "_Rarity_S", fNbinsMVAoutput, 0.0, 1.0 );
1180 rarity_b = new TH1D( TestvarName + "_Rarity_B", TestvarName + "_Rarity_B", fNbinsMVAoutput, 0.0, 1.0 );
1181 mvaRes->Store(rarity_s, "Rar_S");
1182 mvaRes->Store(rarity_b, "Rar_B");
1183 rarity_s->Sumw2();
1184 rarity_b->Sumw2();
1185 }
1186
1187 // MVA plots used for efficiency calculations (large number of bins)
1188 TH1* mva_eff_s = new TH1D( TestvarName + "_S_high", TestvarName + "_S_high", fNbinsH, fXmin, sxmax );
1189 TH1* mva_eff_b = new TH1D( TestvarName + "_B_high", TestvarName + "_B_high", fNbinsH, fXmin, sxmax );
1190 mvaRes->Store(mva_eff_s, "MVA_HIGHBIN_S");
1191 mvaRes->Store(mva_eff_b, "MVA_HIGHBIN_B");
1192 mva_eff_s->Sumw2();
1193 mva_eff_b->Sumw2();
1194
1195 // fill the histograms
1196
1197 ResultsClassification* mvaProb = dynamic_cast<ResultsClassification*>
1198 (Data()->GetResults( TString("prob_")+GetMethodName(), Types::kTesting, Types::kMaxAnalysisType ) );
1199
1200 Log() << kHEADER <<Form("[%s] : ",DataInfo().GetName())<< "Loop over test events and fill histograms with classifier response..." << Endl << Endl;
1201 if (mvaProb) Log() << kINFO << "Also filling probability and rarity histograms (on request)..." << Endl;
1202 std::vector<Bool_t>* mvaResTypes = mvaRes->GetValueVectorTypes();
1203
1204 //LM: this is needed to avoid crashes in ROOCCURVE
1205 if ( mvaRes->GetSize() != GetNEvents() ) {
1206 Log() << kFATAL << TString::Format("Inconsistent result size %lld with number of events %u ", mvaRes->GetSize() , GetNEvents() ) << Endl;
1207 assert(mvaRes->GetSize() == GetNEvents());
1208 }
1209
1210 for (Long64_t ievt=0; ievt<GetNEvents(); ievt++) {
1211
1212 const Event* ev = GetEvent(ievt);
1213 Float_t v = (*mvaRes)[ievt][0];
1214 Float_t w = ev->GetWeight();
1215
1216 if (DataInfo().IsSignal(ev)) {
1217 mvaResTypes->push_back(kTRUE);
1218 mva_s ->Fill( v, w );
1219 if (mvaProb) {
1220 proba_s->Fill( (*mvaProb)[ievt][0], w );
1221 rarity_s->Fill( GetRarity( v ), w );
1222 }
1223
1224 mva_eff_s ->Fill( v, w );
1225 }
1226 else {
1227 mvaResTypes->push_back(kFALSE);
1228 mva_b ->Fill( v, w );
1229 if (mvaProb) {
1230 proba_b->Fill( (*mvaProb)[ievt][0], w );
1231 rarity_b->Fill( GetRarity( v ), w );
1232 }
1233 mva_eff_b ->Fill( v, w );
1234 }
1235 }
1236
1237 // uncomment those (and several others if you want unnormalized output
1238 gTools().NormHist( mva_s );
1239 gTools().NormHist( mva_b );
1240 gTools().NormHist( proba_s );
1241 gTools().NormHist( proba_b );
1242 gTools().NormHist( rarity_s );
1243 gTools().NormHist( rarity_b );
1244 gTools().NormHist( mva_eff_s );
1245 gTools().NormHist( mva_eff_b );
1246
1247 // create PDFs from histograms, using default splines, and no additional smoothing
1248 if (fSplS) { delete fSplS; fSplS = 0; }
1249 if (fSplB) { delete fSplB; fSplB = 0; }
1250 fSplS = new PDF( TString(GetName()) + " PDF Sig", mva_s, PDF::kSpline2 );
1251 fSplB = new PDF( TString(GetName()) + " PDF Bkg", mva_b, PDF::kSpline2 );
1252}
1253
1254////////////////////////////////////////////////////////////////////////////////
1255/// general method used in writing the header of the weight files where
1256/// the used variables, variable transformation type etc. is specified
1257
1258void TMVA::MethodBase::WriteStateToStream( std::ostream& tf ) const
1259{
1260 TString prefix = "";
1261 UserGroup_t * userInfo = gSystem->GetUserInfo();
1262
1263 tf << prefix << "#GEN -*-*-*-*-*-*-*-*-*-*-*- general info -*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1264 tf << prefix << "Method : " << GetMethodTypeName() << "::" << GetMethodName() << std::endl;
1265 tf.setf(std::ios::left);
1266 tf << prefix << "TMVA Release : " << std::setw(10) << GetTrainingTMVAVersionString() << " ["
1267 << GetTrainingTMVAVersionCode() << "]" << std::endl;
1268 tf << prefix << "ROOT Release : " << std::setw(10) << GetTrainingROOTVersionString() << " ["
1269 << GetTrainingROOTVersionCode() << "]" << std::endl;
1270 tf << prefix << "Creator : " << userInfo->fUser << std::endl;
1271 tf << prefix << "Date : "; TDatime *d = new TDatime; tf << d->AsString() << std::endl; delete d;
1272 tf << prefix << "Host : " << gSystem->GetBuildNode() << std::endl;
1273 tf << prefix << "Dir : " << gSystem->WorkingDirectory() << std::endl;
1274 tf << prefix << "Training events: " << Data()->GetNTrainingEvents() << std::endl;
1275
1276 TString analysisType(((const_cast<TMVA::MethodBase*>(this)->GetAnalysisType()==Types::kRegression) ? "Regression" : "Classification"));
1277
1278 tf << prefix << "Analysis type : " << "[" << ((GetAnalysisType()==Types::kRegression) ? "Regression" : "Classification") << "]" << std::endl;
1279 tf << prefix << std::endl;
1280
1281 delete userInfo;
1282
1283 // First write all options
1284 tf << prefix << std::endl << prefix << "#OPT -*-*-*-*-*-*-*-*-*-*-*-*- options -*-*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1285 WriteOptionsToStream( tf, prefix );
1286 tf << prefix << std::endl;
1287
1288 // Second write variable info
1289 tf << prefix << std::endl << prefix << "#VAR -*-*-*-*-*-*-*-*-*-*-*-* variables *-*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1290 WriteVarsToStream( tf, prefix );
1291 tf << prefix << std::endl;
1292}
1293
1294////////////////////////////////////////////////////////////////////////////////
1295/// xml writing
1296
1297void TMVA::MethodBase::AddInfoItem( void* gi, const TString& name, const TString& value) const
1298{
1299 void* it = gTools().AddChild(gi,"Info");
1300 gTools().AddAttr(it,"name", name);
1301 gTools().AddAttr(it,"value", value);
1302}
1303
1304////////////////////////////////////////////////////////////////////////////////
1305
1307 if (analysisType == Types::kRegression) {
1308 AddRegressionOutput( type );
1309 } else if (analysisType == Types::kMulticlass) {
1310 AddMulticlassOutput( type );
1311 } else {
1312 AddClassifierOutput( type );
1313 if (HasMVAPdfs())
1314 AddClassifierOutputProb( type );
1315 }
1316}
1317
1318////////////////////////////////////////////////////////////////////////////////
1319/// general method used in writing the header of the weight files where
1320/// the used variables, variable transformation type etc. is specified
1321
1322void TMVA::MethodBase::WriteStateToXML( void* parent ) const
1323{
1324 if (!parent) return;
1325
1326 UserGroup_t* userInfo = gSystem->GetUserInfo();
1327
1328 void* gi = gTools().AddChild(parent, "GeneralInfo");
1329 AddInfoItem( gi, "TMVA Release", GetTrainingTMVAVersionString() + " [" + gTools().StringFromInt(GetTrainingTMVAVersionCode()) + "]" );
1330 AddInfoItem( gi, "ROOT Release", GetTrainingROOTVersionString() + " [" + gTools().StringFromInt(GetTrainingROOTVersionCode()) + "]");
1331 AddInfoItem( gi, "Creator", userInfo->fUser);
1332 TDatime dt; AddInfoItem( gi, "Date", dt.AsString());
1333 AddInfoItem( gi, "Host", gSystem->GetBuildNode() );
1334 AddInfoItem( gi, "Dir", gSystem->WorkingDirectory());
1335 AddInfoItem( gi, "Training events", gTools().StringFromInt(Data()->GetNTrainingEvents()));
1336 AddInfoItem( gi, "TrainingTime", gTools().StringFromDouble(const_cast<TMVA::MethodBase*>(this)->GetTrainTime()));
1337
1338 Types::EAnalysisType aType = const_cast<TMVA::MethodBase*>(this)->GetAnalysisType();
1339 TString analysisType((aType==Types::kRegression) ? "Regression" :
1340 (aType==Types::kMulticlass ? "Multiclass" : "Classification"));
1341 AddInfoItem( gi, "AnalysisType", analysisType );
1342 delete userInfo;
1343
1344 // write options
1345 AddOptionsXMLTo( parent );
1346
1347 // write variable info
1348 AddVarsXMLTo( parent );
1349
1350 // write spectator info
1351 if (fModelPersistence)
1352 AddSpectatorsXMLTo( parent );
1353
1354 // write class info if in multiclass mode
1355 AddClassesXMLTo(parent);
1356
1357 // write target info if in regression mode
1358 if (DoRegression()) AddTargetsXMLTo(parent);
1359
1360 // write transformations
1361 GetTransformationHandler(false).AddXMLTo( parent );
1362
1363 // write MVA variable distributions
1364 void* pdfs = gTools().AddChild(parent, "MVAPdfs");
1365 if (fMVAPdfS) fMVAPdfS->AddXMLTo(pdfs);
1366 if (fMVAPdfB) fMVAPdfB->AddXMLTo(pdfs);
1367
1368 // write weights
1369 AddWeightsXMLTo( parent );
1370}
1371
1372////////////////////////////////////////////////////////////////////////////////
1373/// write reference MVA distributions (and other information)
1374/// to a ROOT type weight file
1375
1377{
1378 Bool_t addDirStatus = TH1::AddDirectoryStatus();
1379 TH1::AddDirectory( 0 ); // this avoids the binding of the hists in PDF to the current ROOT file
1380 fMVAPdfS = (TMVA::PDF*)rf.Get( "MVA_PDF_Signal" );
1381 fMVAPdfB = (TMVA::PDF*)rf.Get( "MVA_PDF_Background" );
1382
1383 TH1::AddDirectory( addDirStatus );
1384
1385 ReadWeightsFromStream( rf );
1386
1387 SetTestvarName();
1388}
1389
1390////////////////////////////////////////////////////////////////////////////////
1391/// write options and weights to file
1392/// note that each one text file for the main configuration information
1393/// and one ROOT file for ROOT objects are created
1394
1396{
1397 // ---- create the text file
1398 TString tfname( GetWeightFileName() );
1399
1400 // writing xml file
1401 TString xmlfname( tfname ); xmlfname.ReplaceAll( ".txt", ".xml" );
1402 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
1403 << "Creating xml weight file: "
1404 << gTools().Color("lightblue") << xmlfname << gTools().Color("reset") << Endl;
1405 void* doc = gTools().xmlengine().NewDoc();
1406 void* rootnode = gTools().AddChild(0,"MethodSetup", "", true);
1407 gTools().xmlengine().DocSetRootElement(doc,rootnode);
1408 gTools().AddAttr(rootnode,"Method", GetMethodTypeName() + "::" + GetMethodName());
1409 WriteStateToXML(rootnode);
1410 gTools().xmlengine().SaveDoc(doc,xmlfname);
1411 gTools().xmlengine().FreeDoc(doc);
1412}
1413
1414////////////////////////////////////////////////////////////////////////////////
1415/// Function to write options and weights to file
1416
1418{
1419 // get the filename
1420
1421 TString tfname(GetWeightFileName());
1422
1423 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
1424 << "Reading weight file: "
1425 << gTools().Color("lightblue") << tfname << gTools().Color("reset") << Endl;
1426
1427 if (tfname.EndsWith(".xml") ) {
1428#if ROOT_VERSION_CODE >= ROOT_VERSION(5,29,0)
1429 void* doc = gTools().xmlengine().ParseFile(tfname,gTools().xmlenginebuffersize()); // the default buffer size in TXMLEngine::ParseFile is 100k. Starting with ROOT 5.29 one can set the buffer size, see: http://savannah.cern.ch/bugs/?78864. This might be necessary for large XML files
1430#else
1431 void* doc = gTools().xmlengine().ParseFile(tfname);
1432#endif
1433 if (!doc) {
1434 Log() << kFATAL << "Error parsing XML file " << tfname << Endl;
1435 }
1436 void* rootnode = gTools().xmlengine().DocGetRootElement(doc); // node "MethodSetup"
1437 ReadStateFromXML(rootnode);
1438 gTools().xmlengine().FreeDoc(doc);
1439 }
1440 else {
1441 std::filebuf fb;
1442 fb.open(tfname.Data(),std::ios::in);
1443 if (!fb.is_open()) { // file not found --> Error
1444 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<ReadStateFromFile> "
1445 << "Unable to open input weight file: " << tfname << Endl;
1446 }
1447 std::istream fin(&fb);
1448 ReadStateFromStream(fin);
1449 fb.close();
1450 }
1451 if (!fTxtWeightsOnly) {
1452 // ---- read the ROOT file
1453 TString rfname( tfname ); rfname.ReplaceAll( ".txt", ".root" );
1454 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Reading root weight file: "
1455 << gTools().Color("lightblue") << rfname << gTools().Color("reset") << Endl;
1456 TFile* rfile = TFile::Open( rfname, "READ" );
1457 ReadStateFromStream( *rfile );
1458 rfile->Close();
1459 }
1460}
1461////////////////////////////////////////////////////////////////////////////////
1462/// for reading from memory
1463
1465#if ROOT_VERSION_CODE >= ROOT_VERSION(5,26,00)
1466 void* doc = gTools().xmlengine().ParseString(xmlstr);
1467 void* rootnode = gTools().xmlengine().DocGetRootElement(doc); // node "MethodSetup"
1468 ReadStateFromXML(rootnode);
1469 gTools().xmlengine().FreeDoc(doc);
1470#else
1471 Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName()) << "Method MethodBase::ReadStateFromXMLString( const char* xmlstr = "
1472 << xmlstr << " ) is not available for ROOT versions prior to 5.26/00." << Endl;
1473#endif
1474
1475 return;
1476}
1477
1478////////////////////////////////////////////////////////////////////////////////
1479
1481{
1482
1483 TString fullMethodName;
1484 gTools().ReadAttr( methodNode, "Method", fullMethodName );
1485
1486 fMethodName = fullMethodName(fullMethodName.Index("::")+2,fullMethodName.Length());
1487
1488 // update logger
1489 Log().SetSource( GetName() );
1490 Log() << kDEBUG//<<Form("Dataset[%s] : ",DataInfo().GetName())
1491 << "Read method \"" << GetMethodName() << "\" of type \"" << GetMethodTypeName() << "\"" << Endl;
1492
1493 // after the method name is read, the testvar can be set
1494 SetTestvarName();
1495
1496 TString nodeName("");
1497 void* ch = gTools().GetChild(methodNode);
1498 while (ch!=0) {
1499 nodeName = TString( gTools().GetName(ch) );
1500
1501 if (nodeName=="GeneralInfo") {
1502 // read analysis type
1503
1504 TString name(""),val("");
1505 void* antypeNode = gTools().GetChild(ch);
1506 while (antypeNode) {
1507 gTools().ReadAttr( antypeNode, "name", name );
1508
1509 if (name == "TrainingTime")
1510 gTools().ReadAttr( antypeNode, "value", fTrainTime );
1511
1512 if (name == "AnalysisType") {
1513 gTools().ReadAttr( antypeNode, "value", val );
1514 val.ToLower();
1515 if (val == "regression" ) SetAnalysisType( Types::kRegression );
1516 else if (val == "classification" ) SetAnalysisType( Types::kClassification );
1517 else if (val == "multiclass" ) SetAnalysisType( Types::kMulticlass );
1518 else Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Analysis type " << val << " is not known." << Endl;
1519 }
1520
1521 if (name == "TMVA Release" || name == "TMVA") {
1522 TString s;
1523 gTools().ReadAttr( antypeNode, "value", s);
1524 fTMVATrainingVersion = TString(s(s.Index("[")+1,s.Index("]")-s.Index("[")-1)).Atoi();
1525 Log() << kDEBUG <<Form("[%s] : ",DataInfo().GetName()) << "MVA method was trained with TMVA Version: " << GetTrainingTMVAVersionString() << Endl;
1526 }
1527
1528 if (name == "ROOT Release" || name == "ROOT") {
1529 TString s;
1530 gTools().ReadAttr( antypeNode, "value", s);
1531 fROOTTrainingVersion = TString(s(s.Index("[")+1,s.Index("]")-s.Index("[")-1)).Atoi();
1532 Log() << kDEBUG //<<Form("Dataset[%s] : ",DataInfo().GetName())
1533 << "MVA method was trained with ROOT Version: " << GetTrainingROOTVersionString() << Endl;
1534 }
1535 antypeNode = gTools().GetNextChild(antypeNode);
1536 }
1537 }
1538 else if (nodeName=="Options") {
1539 ReadOptionsFromXML(ch);
1540 ParseOptions();
1541
1542 }
1543 else if (nodeName=="Variables") {
1544 ReadVariablesFromXML(ch);
1545 }
1546 else if (nodeName=="Spectators") {
1547 ReadSpectatorsFromXML(ch);
1548 }
1549 else if (nodeName=="Classes") {
1550 if (DataInfo().GetNClasses()==0) ReadClassesFromXML(ch);
1551 }
1552 else if (nodeName=="Targets") {
1553 if (DataInfo().GetNTargets()==0 && DoRegression()) ReadTargetsFromXML(ch);
1554 }
1555 else if (nodeName=="Transformations") {
1556 GetTransformationHandler().ReadFromXML(ch);
1557 }
1558 else if (nodeName=="MVAPdfs") {
1559 TString pdfname;
1560 if (fMVAPdfS) { delete fMVAPdfS; fMVAPdfS=0; }
1561 if (fMVAPdfB) { delete fMVAPdfB; fMVAPdfB=0; }
1562 void* pdfnode = gTools().GetChild(ch);
1563 if (pdfnode) {
1564 gTools().ReadAttr(pdfnode, "Name", pdfname);
1565 fMVAPdfS = new PDF(pdfname);
1566 fMVAPdfS->ReadXML(pdfnode);
1567 pdfnode = gTools().GetNextChild(pdfnode);
1568 gTools().ReadAttr(pdfnode, "Name", pdfname);
1569 fMVAPdfB = new PDF(pdfname);
1570 fMVAPdfB->ReadXML(pdfnode);
1571 }
1572 }
1573 else if (nodeName=="Weights") {
1574 ReadWeightsFromXML(ch);
1575 }
1576 else {
1577 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Unparsed XML node: '" << nodeName << "'" << Endl;
1578 }
1579 ch = gTools().GetNextChild(ch);
1580
1581 }
1582
1583 // update transformation handler
1584 if (GetTransformationHandler().GetCallerName() == "") GetTransformationHandler().SetCallerName( GetName() );
1585}
1586
1587////////////////////////////////////////////////////////////////////////////////
1588/// read the header from the weight files of the different MVA methods
1589
1591{
1592 char buf[512];
1593
1594 // when reading from stream, we assume the files are produced with TMVA<=397
1595 SetAnalysisType(Types::kClassification);
1596
1597
1598 // first read the method name
1599 GetLine(fin,buf);
1600 while (!TString(buf).BeginsWith("Method")) GetLine(fin,buf);
1601 TString namestr(buf);
1602
1603 TString methodType = namestr(0,namestr.Index("::"));
1604 methodType = methodType(methodType.Last(' '),methodType.Length());
1605 methodType = methodType.Strip(TString::kLeading);
1606
1607 TString methodName = namestr(namestr.Index("::")+2,namestr.Length());
1608 methodName = methodName.Strip(TString::kLeading);
1609 if (methodName == "") methodName = methodType;
1610 fMethodName = methodName;
1611
1612 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Read method \"" << GetMethodName() << "\" of type \"" << GetMethodTypeName() << "\"" << Endl;
1613
1614 // update logger
1615 Log().SetSource( GetName() );
1616
1617 // now the question is whether to read the variables first or the options (well, of course the order
1618 // of writing them needs to agree)
1619 //
1620 // the option "Decorrelation" is needed to decide if the variables we
1621 // read are decorrelated or not
1622 //
1623 // the variables are needed by some methods (TMLP) to build the NN
1624 // which is done in ProcessOptions so for the time being we first Read and Parse the options then
1625 // we read the variables, and then we process the options
1626
1627 // now read all options
1628 GetLine(fin,buf);
1629 while (!TString(buf).BeginsWith("#OPT")) GetLine(fin,buf);
1630 ReadOptionsFromStream(fin);
1631 ParseOptions();
1632
1633 // Now read variable info
1634 fin.getline(buf,512);
1635 while (!TString(buf).BeginsWith("#VAR")) fin.getline(buf,512);
1636 ReadVarsFromStream(fin);
1637
1638 // now we process the options (of the derived class)
1639 ProcessOptions();
1640
1641 if (IsNormalised()) {
1643 GetTransformationHandler().AddTransformation( new VariableNormalizeTransform(DataInfo()), -1 );
1644 norm->BuildTransformationFromVarInfo( DataInfo().GetVariableInfos() );
1645 }
1646 VariableTransformBase *varTrafo(0), *varTrafo2(0);
1647 if ( fVarTransformString == "None") {
1648 if (fUseDecorr)
1649 varTrafo = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
1650 } else if ( fVarTransformString == "Decorrelate" ) {
1651 varTrafo = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
1652 } else if ( fVarTransformString == "PCA" ) {
1653 varTrafo = GetTransformationHandler().AddTransformation( new VariablePCATransform(DataInfo()), -1 );
1654 } else if ( fVarTransformString == "Uniform" ) {
1655 varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo(),"Uniform"), -1 );
1656 } else if ( fVarTransformString == "Gauss" ) {
1657 varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo()), -1 );
1658 } else if ( fVarTransformString == "GaussDecorr" ) {
1659 varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo()), -1 );
1660 varTrafo2 = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
1661 } else {
1662 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<ProcessOptions> Variable transform '"
1663 << fVarTransformString << "' unknown." << Endl;
1664 }
1665 // Now read decorrelation matrix if available
1666 if (GetTransformationHandler().GetTransformationList().GetSize() > 0) {
1667 fin.getline(buf,512);
1668 while (!TString(buf).BeginsWith("#MAT")) fin.getline(buf,512);
1669 if (varTrafo) {
1670 TString trafo(fVariableTransformTypeString); trafo.ToLower();
1671 varTrafo->ReadTransformationFromStream(fin, trafo );
1672 }
1673 if (varTrafo2) {
1674 TString trafo(fVariableTransformTypeString); trafo.ToLower();
1675 varTrafo2->ReadTransformationFromStream(fin, trafo );
1676 }
1677 }
1678
1679
1680 if (HasMVAPdfs()) {
1681 // Now read the MVA PDFs
1682 fin.getline(buf,512);
1683 while (!TString(buf).BeginsWith("#MVAPDFS")) fin.getline(buf,512);
1684 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
1685 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
1686 fMVAPdfS = new PDF(TString(GetName()) + " MVA PDF Sig");
1687 fMVAPdfB = new PDF(TString(GetName()) + " MVA PDF Bkg");
1688 fMVAPdfS->SetReadingVersion( GetTrainingTMVAVersionCode() );
1689 fMVAPdfB->SetReadingVersion( GetTrainingTMVAVersionCode() );
1690
1691 fin >> *fMVAPdfS;
1692 fin >> *fMVAPdfB;
1693 }
1694
1695 // Now read weights
1696 fin.getline(buf,512);
1697 while (!TString(buf).BeginsWith("#WGT")) fin.getline(buf,512);
1698 fin.getline(buf,512);
1699 ReadWeightsFromStream( fin );;
1700
1701 // update transformation handler
1702 if (GetTransformationHandler().GetCallerName() == "") GetTransformationHandler().SetCallerName( GetName() );
1703
1704}
1705
1706////////////////////////////////////////////////////////////////////////////////
1707/// write the list of variables (name, min, max) for a given data
1708/// transformation method to the stream
1709
1710void TMVA::MethodBase::WriteVarsToStream( std::ostream& o, const TString& prefix ) const
1711{
1712 o << prefix << "NVar " << DataInfo().GetNVariables() << std::endl;
1713 std::vector<VariableInfo>::const_iterator varIt = DataInfo().GetVariableInfos().begin();
1714 for (; varIt!=DataInfo().GetVariableInfos().end(); ++varIt) { o << prefix; varIt->WriteToStream(o); }
1715 o << prefix << "NSpec " << DataInfo().GetNSpectators() << std::endl;
1716 varIt = DataInfo().GetSpectatorInfos().begin();
1717 for (; varIt!=DataInfo().GetSpectatorInfos().end(); ++varIt) { o << prefix; varIt->WriteToStream(o); }
1718}
1719
1720////////////////////////////////////////////////////////////////////////////////
1721/// Read the variables (name, min, max) for a given data
1722/// transformation method from the stream. In the stream we only
1723/// expect the limits which will be set
1724
1726{
1727 TString dummy;
1728 UInt_t readNVar;
1729 istr >> dummy >> readNVar;
1730
1731 if (readNVar!=DataInfo().GetNVariables()) {
1732 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "You declared "<< DataInfo().GetNVariables() << " variables in the Reader"
1733 << " while there are " << readNVar << " variables declared in the file"
1734 << Endl;
1735 }
1736
1737 // we want to make sure all variables are read in the order they are defined
1738 VariableInfo varInfo;
1739 std::vector<VariableInfo>::iterator varIt = DataInfo().GetVariableInfos().begin();
1740 int varIdx = 0;
1741 for (; varIt!=DataInfo().GetVariableInfos().end(); ++varIt, ++varIdx) {
1742 varInfo.ReadFromStream(istr);
1743 if (varIt->GetExpression() == varInfo.GetExpression()) {
1744 varInfo.SetExternalLink((*varIt).GetExternalLink());
1745 (*varIt) = varInfo;
1746 }
1747 else {
1748 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "ERROR in <ReadVarsFromStream>" << Endl;
1749 Log() << kINFO << "The definition (or the order) of the variables found in the input file is" << Endl;
1750 Log() << kINFO << "is not the same as the one declared in the Reader (which is necessary for" << Endl;
1751 Log() << kINFO << "the correct working of the method):" << Endl;
1752 Log() << kINFO << " var #" << varIdx <<" declared in Reader: " << varIt->GetExpression() << Endl;
1753 Log() << kINFO << " var #" << varIdx <<" declared in file : " << varInfo.GetExpression() << Endl;
1754 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1755 }
1756 }
1757}
1758
1759////////////////////////////////////////////////////////////////////////////////
1760/// write variable info to XML
1761
1762void TMVA::MethodBase::AddVarsXMLTo( void* parent ) const
1763{
1764 void* vars = gTools().AddChild(parent, "Variables");
1765 gTools().AddAttr( vars, "NVar", gTools().StringFromInt(DataInfo().GetNVariables()) );
1766
1767 for (UInt_t idx=0; idx<DataInfo().GetVariableInfos().size(); idx++) {
1768 VariableInfo& vi = DataInfo().GetVariableInfos()[idx];
1769 void* var = gTools().AddChild( vars, "Variable" );
1770 gTools().AddAttr( var, "VarIndex", idx );
1771 vi.AddToXML( var );
1772 }
1773}
1774
1775////////////////////////////////////////////////////////////////////////////////
1776/// write spectator info to XML
1777
1779{
1780 void* specs = gTools().AddChild(parent, "Spectators");
1781
1782 UInt_t writeIdx=0;
1783 for (UInt_t idx=0; idx<DataInfo().GetSpectatorInfos().size(); idx++) {
1784
1785 VariableInfo& vi = DataInfo().GetSpectatorInfos()[idx];
1786
1787 // we do not want to write spectators that are category-cuts,
1788 // except if the method is the category method and the spectators belong to it
1789 if (vi.GetVarType()=='C') continue;
1790
1791 void* spec = gTools().AddChild( specs, "Spectator" );
1792 gTools().AddAttr( spec, "SpecIndex", writeIdx++ );
1793 vi.AddToXML( spec );
1794 }
1795 gTools().AddAttr( specs, "NSpec", gTools().StringFromInt(writeIdx) );
1796}
1797
1798////////////////////////////////////////////////////////////////////////////////
1799/// write class info to XML
1800
1801void TMVA::MethodBase::AddClassesXMLTo( void* parent ) const
1802{
1803 UInt_t nClasses=DataInfo().GetNClasses();
1804
1805 void* classes = gTools().AddChild(parent, "Classes");
1806 gTools().AddAttr( classes, "NClass", nClasses );
1807
1808 for (UInt_t iCls=0; iCls<nClasses; ++iCls) {
1809 ClassInfo *classInfo=DataInfo().GetClassInfo (iCls);
1810 TString className =classInfo->GetName();
1811 UInt_t classNumber=classInfo->GetNumber();
1812
1813 void* classNode=gTools().AddChild(classes, "Class");
1814 gTools().AddAttr( classNode, "Name", className );
1815 gTools().AddAttr( classNode, "Index", classNumber );
1816 }
1817}
1818////////////////////////////////////////////////////////////////////////////////
1819/// write target info to XML
1820
1821void TMVA::MethodBase::AddTargetsXMLTo( void* parent ) const
1822{
1823 void* targets = gTools().AddChild(parent, "Targets");
1824 gTools().AddAttr( targets, "NTrgt", gTools().StringFromInt(DataInfo().GetNTargets()) );
1825
1826 for (UInt_t idx=0; idx<DataInfo().GetTargetInfos().size(); idx++) {
1827 VariableInfo& vi = DataInfo().GetTargetInfos()[idx];
1828 void* tar = gTools().AddChild( targets, "Target" );
1829 gTools().AddAttr( tar, "TargetIndex", idx );
1830 vi.AddToXML( tar );
1831 }
1832}
1833
1834////////////////////////////////////////////////////////////////////////////////
1835/// read variable info from XML
1836
1838{
1839 UInt_t readNVar;
1840 gTools().ReadAttr( varnode, "NVar", readNVar);
1841
1842 if (readNVar!=DataInfo().GetNVariables()) {
1843 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "You declared "<< DataInfo().GetNVariables() << " variables in the Reader"
1844 << " while there are " << readNVar << " variables declared in the file"
1845 << Endl;
1846 }
1847
1848 // we want to make sure all variables are read in the order they are defined
1849 VariableInfo readVarInfo, existingVarInfo;
1850 int varIdx = 0;
1851 void* ch = gTools().GetChild(varnode);
1852 while (ch) {
1853 gTools().ReadAttr( ch, "VarIndex", varIdx);
1854 existingVarInfo = DataInfo().GetVariableInfos()[varIdx];
1855 readVarInfo.ReadFromXML(ch);
1856
1857 if (existingVarInfo.GetExpression() == readVarInfo.GetExpression()) {
1858 readVarInfo.SetExternalLink(existingVarInfo.GetExternalLink());
1859 existingVarInfo = readVarInfo;
1860 }
1861 else {
1862 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "ERROR in <ReadVariablesFromXML>" << Endl;
1863 Log() << kINFO << "The definition (or the order) of the variables found in the input file is" << Endl;
1864 Log() << kINFO << "not the same as the one declared in the Reader (which is necessary for the" << Endl;
1865 Log() << kINFO << "correct working of the method):" << Endl;
1866 Log() << kINFO << " var #" << varIdx <<" declared in Reader: " << existingVarInfo.GetExpression() << Endl;
1867 Log() << kINFO << " var #" << varIdx <<" declared in file : " << readVarInfo.GetExpression() << Endl;
1868 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1869 }
1870 ch = gTools().GetNextChild(ch);
1871 }
1872}
1873
1874////////////////////////////////////////////////////////////////////////////////
1875/// read spectator info from XML
1876
1878{
1879 UInt_t readNSpec;
1880 gTools().ReadAttr( specnode, "NSpec", readNSpec);
1881
1882 if (readNSpec!=DataInfo().GetNSpectators(kFALSE)) {
1883 Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName()) << "You declared "<< DataInfo().GetNSpectators(kFALSE) << " spectators in the Reader"
1884 << " while there are " << readNSpec << " spectators declared in the file"
1885 << Endl;
1886 }
1887
1888 // we want to make sure all variables are read in the order they are defined
1889 VariableInfo readSpecInfo, existingSpecInfo;
1890 int specIdx = 0;
1891 void* ch = gTools().GetChild(specnode);
1892 while (ch) {
1893 gTools().ReadAttr( ch, "SpecIndex", specIdx);
1894 existingSpecInfo = DataInfo().GetSpectatorInfos()[specIdx];
1895 readSpecInfo.ReadFromXML(ch);
1896
1897 if (existingSpecInfo.GetExpression() == readSpecInfo.GetExpression()) {
1898 readSpecInfo.SetExternalLink(existingSpecInfo.GetExternalLink());
1899 existingSpecInfo = readSpecInfo;
1900 }
1901 else {
1902 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "ERROR in <ReadSpectatorsFromXML>" << Endl;
1903 Log() << kINFO << "The definition (or the order) of the spectators found in the input file is" << Endl;
1904 Log() << kINFO << "not the same as the one declared in the Reader (which is necessary for the" << Endl;
1905 Log() << kINFO << "correct working of the method):" << Endl;
1906 Log() << kINFO << " spec #" << specIdx <<" declared in Reader: " << existingSpecInfo.GetExpression() << Endl;
1907 Log() << kINFO << " spec #" << specIdx <<" declared in file : " << readSpecInfo.GetExpression() << Endl;
1908 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1909 }
1910 ch = gTools().GetNextChild(ch);
1911 }
1912}
1913
1914////////////////////////////////////////////////////////////////////////////////
1915/// read number of classes from XML
1916
1918{
1919 UInt_t readNCls;
1920 // coverity[tainted_data_argument]
1921 gTools().ReadAttr( clsnode, "NClass", readNCls);
1922
1923 TString className="";
1924 UInt_t classIndex=0;
1925 void* ch = gTools().GetChild(clsnode);
1926 if (!ch) {
1927 for (UInt_t icls = 0; icls<readNCls;++icls) {
1928 TString classname = Form("class%i",icls);
1929 DataInfo().AddClass(classname);
1930
1931 }
1932 }
1933 else{
1934 while (ch) {
1935 gTools().ReadAttr( ch, "Index", classIndex);
1936 gTools().ReadAttr( ch, "Name", className );
1937 DataInfo().AddClass(className);
1938
1939 ch = gTools().GetNextChild(ch);
1940 }
1941 }
1942
1943 // retrieve signal and background class index
1944 if (DataInfo().GetClassInfo("Signal") != 0) {
1945 fSignalClass = DataInfo().GetClassInfo("Signal")->GetNumber();
1946 }
1947 else
1948 fSignalClass=0;
1949 if (DataInfo().GetClassInfo("Background") != 0) {
1950 fBackgroundClass = DataInfo().GetClassInfo("Background")->GetNumber();
1951 }
1952 else
1953 fBackgroundClass=1;
1954}
1955
1956////////////////////////////////////////////////////////////////////////////////
1957/// read target info from XML
1958
1960{
1961 UInt_t readNTar;
1962 gTools().ReadAttr( tarnode, "NTrgt", readNTar);
1963
1964 int tarIdx = 0;
1965 TString expression;
1966 void* ch = gTools().GetChild(tarnode);
1967 while (ch) {
1968 gTools().ReadAttr( ch, "TargetIndex", tarIdx);
1969 gTools().ReadAttr( ch, "Expression", expression);
1970 DataInfo().AddTarget(expression,"","",0,0);
1971
1972 ch = gTools().GetNextChild(ch);
1973 }
1974}
1975
1976////////////////////////////////////////////////////////////////////////////////
1977/// returns the ROOT directory where info/histograms etc of the
1978/// corresponding MVA method instance are stored
1979
1981{
1982 if (fBaseDir != 0) return fBaseDir;
1983 Log()<<kDEBUG<<Form("Dataset[%s] : ",DataInfo().GetName())<<" Base Directory for " << GetMethodName() << " not set yet --> check if already there.." <<Endl;
1984
1985 TDirectory* methodDir = MethodBaseDir();
1986 if (methodDir==0)
1987 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "MethodBase::BaseDir() - MethodBaseDir() return a NULL pointer!" << Endl;
1988
1989 TString defaultDir = GetMethodName();
1990 TDirectory *sdir = methodDir->GetDirectory(defaultDir.Data());
1991 if(!sdir)
1992 {
1993 Log()<<kDEBUG<<Form("Dataset[%s] : ",DataInfo().GetName())<<" Base Directory for " << GetMethodTypeName() << " does not exist yet--> created it" <<Endl;
1994 sdir = methodDir->mkdir(defaultDir);
1995 sdir->cd();
1996 // write weight file name into target file
1997 if (fModelPersistence) {
1998 TObjString wfilePath( gSystem->WorkingDirectory() );
1999 TObjString wfileName( GetWeightFileName() );
2000 wfilePath.Write( "TrainingPath" );
2001 wfileName.Write( "WeightFileName" );
2002 }
2003 }
2004
2005 Log()<<kDEBUG<<Form("Dataset[%s] : ",DataInfo().GetName())<<" Base Directory for " << GetMethodTypeName() << " existed, return it.." <<Endl;
2006 return sdir;
2007}
2008
2009////////////////////////////////////////////////////////////////////////////////
2010/// returns the ROOT directory where all instances of the
2011/// corresponding MVA method are stored
2012
2014{
2015 if (fMethodBaseDir != 0) {
2016 return fMethodBaseDir;
2017 }
2018
2019 const char *datasetName = DataInfo().GetName();
2020
2021 Log() << kDEBUG << Form("Dataset[%s] : ", datasetName) << " Base Directory for " << GetMethodTypeName()
2022 << " not set yet --> check if already there.." << Endl;
2023
2024 TDirectory *factoryBaseDir = GetFile();
2025 fMethodBaseDir = factoryBaseDir->GetDirectory(datasetName);
2026 if (!fMethodBaseDir) {
2027 fMethodBaseDir = factoryBaseDir->mkdir(datasetName, Form("Base directory for dataset %s", datasetName));
2028 if (!fMethodBaseDir) {
2029 Log() << kFATAL << "Can not create dir " << datasetName;
2030 }
2031 }
2032 TString methodTypeDir = Form("Method_%s", GetMethodTypeName().Data());
2033 fMethodBaseDir = fMethodBaseDir->GetDirectory(methodTypeDir.Data());
2034
2035 if (!fMethodBaseDir) {
2036 TDirectory *datasetDir = factoryBaseDir->GetDirectory(datasetName);
2037 TString methodTypeDirHelpStr = Form("Directory for all %s methods", GetMethodTypeName().Data());
2038 fMethodBaseDir = datasetDir->mkdir(methodTypeDir.Data(), methodTypeDirHelpStr);
2039 Log() << kDEBUG << Form("Dataset[%s] : ", datasetName) << " Base Directory for " << GetMethodName()
2040 << " does not exist yet--> created it" << Endl;
2041 }
2042
2043 Log() << kDEBUG << Form("Dataset[%s] : ", datasetName)
2044 << "Return from MethodBaseDir() after creating base directory " << Endl;
2045 return fMethodBaseDir;
2046}
2047
2048////////////////////////////////////////////////////////////////////////////////
2049/// set directory of weight file
2050
2052{
2053 fFileDir = fileDir;
2054 gSystem->MakeDirectory( fFileDir );
2055}
2056
2057////////////////////////////////////////////////////////////////////////////////
2058/// set the weight file name (depreciated)
2059
2061{
2062 fWeightFile = theWeightFile;
2063}
2064
2065////////////////////////////////////////////////////////////////////////////////
2066/// retrieve weight file name
2067
2069{
2070 if (fWeightFile!="") return fWeightFile;
2071
2072 // the default consists of
2073 // directory/jobname_methodname_suffix.extension.{root/txt}
2074 TString suffix = "";
2075 TString wFileDir(GetWeightFileDir());
2076 TString wFileName = GetJobName() + "_" + GetMethodName() +
2077 suffix + "." + gConfig().GetIONames().fWeightFileExtension + ".xml";
2078 if (wFileDir.IsNull() ) return wFileName;
2079 // add weight file directory of it is not null
2080 return ( wFileDir + (wFileDir[wFileDir.Length()-1]=='/' ? "" : "/")
2081 + wFileName );
2082}
2083////////////////////////////////////////////////////////////////////////////////
2084/// writes all MVA evaluation histograms to file
2085
2087{
2088 BaseDir()->cd();
2089
2090
2091 // write MVA PDFs to file - if exist
2092 if (0 != fMVAPdfS) {
2093 fMVAPdfS->GetOriginalHist()->Write();
2094 fMVAPdfS->GetSmoothedHist()->Write();
2095 fMVAPdfS->GetPDFHist()->Write();
2096 }
2097 if (0 != fMVAPdfB) {
2098 fMVAPdfB->GetOriginalHist()->Write();
2099 fMVAPdfB->GetSmoothedHist()->Write();
2100 fMVAPdfB->GetPDFHist()->Write();
2101 }
2102
2103 // write result-histograms
2104 Results* results = Data()->GetResults( GetMethodName(), treetype, Types::kMaxAnalysisType );
2105 if (!results)
2106 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<WriteEvaluationHistosToFile> Unknown result: "
2107 << GetMethodName() << (treetype==Types::kTraining?"/kTraining":"/kTesting")
2108 << "/kMaxAnalysisType" << Endl;
2109 results->GetStorage()->Write();
2110 if (treetype==Types::kTesting) {
2111 GetTransformationHandler().PlotVariables (GetEventCollection( Types::kTesting ), BaseDir() );
2112 }
2113}
2114
2115////////////////////////////////////////////////////////////////////////////////
2116/// write special monitoring histograms to file
2117/// dummy implementation here -----------------
2118
2120{
2121}
2122
2123////////////////////////////////////////////////////////////////////////////////
2124/// reads one line from the input stream
2125/// checks for certain keywords and interprets
2126/// the line if keywords are found
2127
2128Bool_t TMVA::MethodBase::GetLine(std::istream& fin, char* buf )
2129{
2130 fin.getline(buf,512);
2131 TString line(buf);
2132 if (line.BeginsWith("TMVA Release")) {
2133 Ssiz_t start = line.First('[')+1;
2134 Ssiz_t length = line.Index("]",start)-start;
2135 TString code = line(start,length);
2136 std::stringstream s(code.Data());
2137 s >> fTMVATrainingVersion;
2138 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "MVA method was trained with TMVA Version: " << GetTrainingTMVAVersionString() << Endl;
2139 }
2140 if (line.BeginsWith("ROOT Release")) {
2141 Ssiz_t start = line.First('[')+1;
2142 Ssiz_t length = line.Index("]",start)-start;
2143 TString code = line(start,length);
2144 std::stringstream s(code.Data());
2145 s >> fROOTTrainingVersion;
2146 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "MVA method was trained with ROOT Version: " << GetTrainingROOTVersionString() << Endl;
2147 }
2148 if (line.BeginsWith("Analysis type")) {
2149 Ssiz_t start = line.First('[')+1;
2150 Ssiz_t length = line.Index("]",start)-start;
2151 TString code = line(start,length);
2152 std::stringstream s(code.Data());
2153 std::string analysisType;
2154 s >> analysisType;
2155 if (analysisType == "regression" || analysisType == "Regression") SetAnalysisType( Types::kRegression );
2156 else if (analysisType == "classification" || analysisType == "Classification") SetAnalysisType( Types::kClassification );
2157 else if (analysisType == "multiclass" || analysisType == "Multiclass") SetAnalysisType( Types::kMulticlass );
2158 else Log() << kFATAL << "Analysis type " << analysisType << " from weight-file not known!" << std::endl;
2159
2160 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Method was trained for "
2161 << (GetAnalysisType() == Types::kRegression ? "Regression" :
2162 (GetAnalysisType() == Types::kMulticlass ? "Multiclass" : "Classification")) << Endl;
2163 }
2164
2165 return true;
2166}
2167
2168////////////////////////////////////////////////////////////////////////////////
2169/// Create PDFs of the MVA output variables
2170
2172{
2173 Data()->SetCurrentType(Types::kTraining);
2174
2175 // the PDF's are stored as results ONLY if the corresponding "results" are booked,
2176 // otherwise they will be only used 'online'
2177 ResultsClassification * mvaRes = dynamic_cast<ResultsClassification*>
2178 ( Data()->GetResults(GetMethodName(), Types::kTraining, Types::kClassification) );
2179
2180 if (mvaRes==0 || mvaRes->GetSize()==0) {
2181 Log() << kERROR<<Form("Dataset[%s] : ",DataInfo().GetName())<< "<CreateMVAPdfs> No result of classifier testing available" << Endl;
2182 }
2183
2184 Double_t minVal = *std::min_element(mvaRes->GetValueVector()->begin(),mvaRes->GetValueVector()->end());
2185 Double_t maxVal = *std::max_element(mvaRes->GetValueVector()->begin(),mvaRes->GetValueVector()->end());
2186
2187 // create histograms that serve as basis to create the MVA Pdfs
2188 TH1* histMVAPdfS = new TH1D( GetMethodTypeName() + "_tr_S", GetMethodTypeName() + "_tr_S",
2189 fMVAPdfS->GetHistNBins( mvaRes->GetSize() ), minVal, maxVal );
2190 TH1* histMVAPdfB = new TH1D( GetMethodTypeName() + "_tr_B", GetMethodTypeName() + "_tr_B",
2191 fMVAPdfB->GetHistNBins( mvaRes->GetSize() ), minVal, maxVal );
2192
2193
2194 // compute sum of weights properly
2195 histMVAPdfS->Sumw2();
2196 histMVAPdfB->Sumw2();
2197
2198 // fill histograms
2199 for (UInt_t ievt=0; ievt<mvaRes->GetSize(); ievt++) {
2200 Double_t theVal = mvaRes->GetValueVector()->at(ievt);
2201 Double_t theWeight = Data()->GetEvent(ievt)->GetWeight();
2202
2203 if (DataInfo().IsSignal(Data()->GetEvent(ievt))) histMVAPdfS->Fill( theVal, theWeight );
2204 else histMVAPdfB->Fill( theVal, theWeight );
2205 }
2206
2207 gTools().NormHist( histMVAPdfS );
2208 gTools().NormHist( histMVAPdfB );
2209
2210 // momentary hack for ROOT problem
2211 if(!IsSilentFile())
2212 {
2213 histMVAPdfS->Write();
2214 histMVAPdfB->Write();
2215 }
2216 // create PDFs
2217 fMVAPdfS->BuildPDF ( histMVAPdfS );
2218 fMVAPdfB->BuildPDF ( histMVAPdfB );
2219 fMVAPdfS->ValidatePDF( histMVAPdfS );
2220 fMVAPdfB->ValidatePDF( histMVAPdfB );
2221
2222 if (DataInfo().GetNClasses() == 2) { // TODO: this is an ugly hack.. adapt this to new framework
2223 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())
2224 << Form( "<CreateMVAPdfs> Separation from histogram (PDF): %1.3f (%1.3f)",
2225 GetSeparation( histMVAPdfS, histMVAPdfB ), GetSeparation( fMVAPdfS, fMVAPdfB ) )
2226 << Endl;
2227 }
2228
2229 delete histMVAPdfS;
2230 delete histMVAPdfB;
2231}
2232
   // the simple one, automatically calculates the mvaVal and uses the
   // SAME sig/bkg ratio as given in the training sample (typically 50/50
   // .. (NormMode=EqualNumEvents) but can be different)
   if (!fMVAPdfS || !fMVAPdfB) {
      // PDFs are optional; build them lazily from the training response.
      Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName()) << "<GetProba> MVA PDFs for Signal and Background don't exist yet, we'll create them on demand" << Endl;
      CreateMVAPdfs();
   }
   // Prior signal fraction taken from the training sample's weight sums.
   Double_t sigFraction = DataInfo().GetTrainingSumSignalWeights() / (DataInfo().GetTrainingSumSignalWeights() + DataInfo().GetTrainingSumBackgrWeights() );
   Double_t mvaVal = GetMvaValue(ev);

   // Delegate to the (mvaVal, prior) overload for the Bayes computation.
   return GetProba(mvaVal,sigFraction);

}
2247////////////////////////////////////////////////////////////////////////////////
2248/// compute likelihood ratio
2249
2251{
2252 if (!fMVAPdfS || !fMVAPdfB) {
2253 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetProba> MVA PDFs for Signal and Background don't exist" << Endl;
2254 return -1.0;
2255 }
2256 Double_t p_s = fMVAPdfS->GetVal( mvaVal );
2257 Double_t p_b = fMVAPdfB->GetVal( mvaVal );
2258
2259 Double_t denom = p_s*ap_sig + p_b*(1 - ap_sig);
2260
2261 return (denom > 0) ? (p_s*ap_sig) / denom : -1;
2262}
2263
2264////////////////////////////////////////////////////////////////////////////////
2265/// compute rarity:
2266/// \f[
2267/// R(x) = \int_{[-\infty..x]} { PDF(x') dx' }
2268/// \f]
2269/// where PDF(x) is the PDF of the classifier's signal or background distribution
2270
2272{
2273 if ((reftype == Types::kSignal && !fMVAPdfS) || (reftype == Types::kBackground && !fMVAPdfB)) {
2274 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetRarity> Required MVA PDF for Signal or Background does not exist: "
2275 << "select option \"CreateMVAPdfs\"" << Endl;
2276 return 0.0;
2277 }
2278
2279 PDF* thePdf = ((reftype == Types::kSignal) ? fMVAPdfS : fMVAPdfB);
2280
2281 return thePdf->GetIntegral( thePdf->GetXmin(), mvaVal );
2282}
2283
2284////////////////////////////////////////////////////////////////////////////////
2285/// fill background efficiency (resp. rejection) versus signal efficiency plots
2286/// returns signal efficiency at background efficiency indicated in theString
2287
2289{
2290 Data()->SetCurrentType(type);
2291 Results* results = Data()->GetResults( GetMethodName(), type, Types::kClassification );
2292 std::vector<Float_t>* mvaRes = dynamic_cast<ResultsClassification*>(results)->GetValueVector();
2293
2294 // parse input string for required background efficiency
2295 TList* list = gTools().ParseFormatLine( theString );
2296
2297 // sanity check
2298 Bool_t computeArea = kFALSE;
2299 if (!list || list->GetSize() < 2) computeArea = kTRUE; // the area is computed
2300 else if (list->GetSize() > 2) {
2301 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetEfficiency> Wrong number of arguments"
2302 << " in string: " << theString
2303 << " | required format, e.g., Efficiency:0.05, or empty string" << Endl;
2304 delete list;
2305 return -1;
2306 }
2307
2308 // sanity check
2309 if ( results->GetHist("MVA_S")->GetNbinsX() != results->GetHist("MVA_B")->GetNbinsX() ||
2310 results->GetHist("MVA_HIGHBIN_S")->GetNbinsX() != results->GetHist("MVA_HIGHBIN_B")->GetNbinsX() ) {
2311 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetEfficiency> Binning mismatch between signal and background histos" << Endl;
2312 delete list;
2313 return -1.0;
2314 }
2315
2316 // create histograms
2317
2318 // first, get efficiency histograms for signal and background
2319 TH1 * effhist = results->GetHist("MVA_HIGHBIN_S");
2320 Double_t xmin = effhist->GetXaxis()->GetXmin();
2321 Double_t xmax = effhist->GetXaxis()->GetXmax();
2322
2323 TTHREAD_TLS(Double_t) nevtS;
2324
2325 // first round ? --> create histograms
2326 if (results->DoesExist("MVA_EFF_S")==0) {
2327
2328 // for efficiency plot
2329 TH1* eff_s = new TH1D( GetTestvarName() + "_effS", GetTestvarName() + " (signal)", fNbinsH, xmin, xmax );
2330 TH1* eff_b = new TH1D( GetTestvarName() + "_effB", GetTestvarName() + " (background)", fNbinsH, xmin, xmax );
2331 results->Store(eff_s, "MVA_EFF_S");
2332 results->Store(eff_b, "MVA_EFF_B");
2333
2334 // sign if cut
2335 Int_t sign = (fCutOrientation == kPositive) ? +1 : -1;
2336
2337 // this method is unbinned
2338 nevtS = 0;
2339 for (UInt_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
2340
2341 // read the tree
2342 Bool_t isSignal = DataInfo().IsSignal(GetEvent(ievt));
2343 Float_t theWeight = GetEvent(ievt)->GetWeight();
2344 Float_t theVal = (*mvaRes)[ievt];
2345
2346 // select histogram depending on if sig or bgd
2347 TH1* theHist = isSignal ? eff_s : eff_b;
2348
2349 // count signal and background events in tree
2350 if (isSignal) nevtS+=theWeight;
2351
2352 TAxis* axis = theHist->GetXaxis();
2353 Int_t maxbin = Int_t((theVal - axis->GetXmin())/(axis->GetXmax() - axis->GetXmin())*fNbinsH) + 1;
2354 if (sign > 0 && maxbin > fNbinsH) continue; // can happen... event doesn't count
2355 if (sign < 0 && maxbin < 1 ) continue; // can happen... event doesn't count
2356 if (sign > 0 && maxbin < 1 ) maxbin = 1;
2357 if (sign < 0 && maxbin > fNbinsH) maxbin = fNbinsH;
2358
2359 if (sign > 0)
2360 for (Int_t ibin=1; ibin<=maxbin; ibin++) theHist->AddBinContent( ibin , theWeight);
2361 else if (sign < 0)
2362 for (Int_t ibin=maxbin+1; ibin<=fNbinsH; ibin++) theHist->AddBinContent( ibin , theWeight );
2363 else
2364 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetEfficiency> Mismatch in sign" << Endl;
2365 }
2366
2367 // renormalise maximum to <=1
2368 // eff_s->Scale( 1.0/TMath::Max(1.,eff_s->GetMaximum()) );
2369 // eff_b->Scale( 1.0/TMath::Max(1.,eff_b->GetMaximum()) );
2370
2373
2374 // background efficiency versus signal efficiency
2375 TH1* eff_BvsS = new TH1D( GetTestvarName() + "_effBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2376 results->Store(eff_BvsS, "MVA_EFF_BvsS");
2377 eff_BvsS->SetXTitle( "Signal eff" );
2378 eff_BvsS->SetYTitle( "Backgr eff" );
2379
2380 // background rejection (=1-eff.) versus signal efficiency
2381 TH1* rej_BvsS = new TH1D( GetTestvarName() + "_rejBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2382 results->Store(rej_BvsS);
2383 rej_BvsS->SetXTitle( "Signal eff" );
2384 rej_BvsS->SetYTitle( "Backgr rejection (1-eff)" );
2385
2386 // inverse background eff (1/eff.) versus signal efficiency
2387 TH1* inveff_BvsS = new TH1D( GetTestvarName() + "_invBeffvsSeff",
2388 GetTestvarName(), fNbins, 0, 1 );
2389 results->Store(inveff_BvsS);
2390 inveff_BvsS->SetXTitle( "Signal eff" );
2391 inveff_BvsS->SetYTitle( "Inverse backgr. eff (1/eff)" );
2392
2393 // use root finder
2394 // spline background efficiency plot
2395 // note that there is a bin shift when going from a TH1D object to a TGraph :-(
2397 fSplRefS = new TSpline1( "spline2_signal", new TGraph( eff_s ) );
2398 fSplRefB = new TSpline1( "spline2_background", new TGraph( eff_b ) );
2399
2400 // verify spline sanity
2401 gTools().CheckSplines( eff_s, fSplRefS );
2402 gTools().CheckSplines( eff_b, fSplRefB );
2403 }
2404
2405 // make the background-vs-signal efficiency plot
2406
2407 // create root finder
2408 RootFinder rootFinder( this, fXmin, fXmax );
2409
2410 Double_t effB = 0;
2411 fEffS = eff_s; // to be set for the root finder
2412 for (Int_t bini=1; bini<=fNbins; bini++) {
2413
2414 // find cut value corresponding to a given signal efficiency
2415 Double_t effS = eff_BvsS->GetBinCenter( bini );
2416 Double_t cut = rootFinder.Root( effS );
2417
2418 // retrieve background efficiency for given cut
2419 if (Use_Splines_for_Eff_) effB = fSplRefB->Eval( cut );
2420 else effB = eff_b->GetBinContent( eff_b->FindBin( cut ) );
2421
2422 // and fill histograms
2423 eff_BvsS->SetBinContent( bini, effB );
2424 rej_BvsS->SetBinContent( bini, 1.0-effB );
2426 inveff_BvsS->SetBinContent( bini, 1.0/effB );
2427 }
2428
2429 // create splines for histogram
2430 fSpleffBvsS = new TSpline1( "effBvsS", new TGraph( eff_BvsS ) );
2431
2432 // search for overlap point where, when cutting on it,
2433 // one would obtain: eff_S = rej_B = 1 - eff_B
2434 Double_t effS = 0., rejB, effS_ = 0., rejB_ = 0.;
2435 Int_t nbins_ = 5000;
2436 for (Int_t bini=1; bini<=nbins_; bini++) {
2437
2438 // get corresponding signal and background efficiencies
2439 effS = (bini - 0.5)/Float_t(nbins_);
2440 rejB = 1.0 - fSpleffBvsS->Eval( effS );
2441
2442 // find signal efficiency that corresponds to required background efficiency
2443 if ((effS - rejB)*(effS_ - rejB_) < 0) break;
2444 effS_ = effS;
2445 rejB_ = rejB;
2446 }
2447
2448 // find cut that corresponds to signal efficiency and update signal-like criterion
2449 Double_t cut = rootFinder.Root( 0.5*(effS + effS_) );
2450 SetSignalReferenceCut( cut );
2451 fEffS = 0;
2452 }
2453
2454 // must exist...
2455 if (0 == fSpleffBvsS) {
2456 delete list;
2457 return 0.0;
2458 }
2459
2460 // now find signal efficiency that corresponds to required background efficiency
2461 Double_t effS = 0, effB = 0, effS_ = 0, effB_ = 0;
2462 Int_t nbins_ = 1000;
2463
2464 if (computeArea) {
2465
2466 // compute area of rej-vs-eff plot
2467 Double_t integral = 0;
2468 for (Int_t bini=1; bini<=nbins_; bini++) {
2469
2470 // get corresponding signal and background efficiencies
2471 effS = (bini - 0.5)/Float_t(nbins_);
2472 effB = fSpleffBvsS->Eval( effS );
2473 integral += (1.0 - effB);
2474 }
2475 integral /= nbins_;
2476
2477 delete list;
2478 return integral;
2479 }
2480 else {
2481
2482 // that will be the value of the efficiency retured (does not affect
2483 // the efficiency-vs-bkg plot which is done anyway.
2484 Float_t effBref = atof( ((TObjString*)list->At(1))->GetString() );
2485
2486 // find precise efficiency value
2487 for (Int_t bini=1; bini<=nbins_; bini++) {
2488
2489 // get corresponding signal and background efficiencies
2490 effS = (bini - 0.5)/Float_t(nbins_);
2491 effB = fSpleffBvsS->Eval( effS );
2492
2493 // find signal efficiency that corresponds to required background efficiency
2494 if ((effB - effBref)*(effB_ - effBref) <= 0) break;
2495 effS_ = effS;
2496 effB_ = effB;
2497 }
2498
2499 // take mean between bin above and bin below
2500 effS = 0.5*(effS + effS_);
2501
2502 effSerr = 0;
2503 if (nevtS > 0) effSerr = TMath::Sqrt( effS*(1.0 - effS)/nevtS );
2504
2505 delete list;
2506 return effS;
2507 }
2508
2509 return -1;
2510}
2511
2512////////////////////////////////////////////////////////////////////////////////
2513
2515{
2516 Data()->SetCurrentType(Types::kTraining);
2517
2518 Results* results = Data()->GetResults(GetMethodName(), Types::kTesting, Types::kNoAnalysisType);
2519
2520 // fill background efficiency (resp. rejection) versus signal efficiency plots
2521 // returns signal efficiency at background efficiency indicated in theString
2522
2523 // parse input string for required background efficiency
2524 TList* list = gTools().ParseFormatLine( theString );
2525 // sanity check
2526
2527 if (list->GetSize() != 2) {
2528 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetTrainingEfficiency> Wrong number of arguments"
2529 << " in string: " << theString
2530 << " | required format, e.g., Efficiency:0.05" << Endl;
2531 delete list;
2532 return -1;
2533 }
2534 // that will be the value of the efficiency returned (does not affect
2535 // the efficiency-vs-bkg plot, which is done anyway).
2536 Float_t effBref = atof( ((TObjString*)list->At(1))->GetString() );
2537
2538 delete list;
2539
2540 // sanity check
2541 if (results->GetHist("MVA_S")->GetNbinsX() != results->GetHist("MVA_B")->GetNbinsX() ||
2542 results->GetHist("MVA_HIGHBIN_S")->GetNbinsX() != results->GetHist("MVA_HIGHBIN_B")->GetNbinsX() ) {
2543 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetTrainingEfficiency> Binning mismatch between signal and background histos"
2544 << Endl;
2545 return -1.0;
2546 }
2547
2548 // create histogram
2549
2550 // first, get efficiency histograms for signal and background
2551 TH1 * effhist = results->GetHist("MVA_HIGHBIN_S");
2552 Double_t xmin = effhist->GetXaxis()->GetXmin();
2553 Double_t xmax = effhist->GetXaxis()->GetXmax();
2554
2555 // first round ? --> create and fill histograms
2556 if (results->DoesExist("MVA_TRAIN_S")==0) {
2557
2558 // classifier response distributions for test sample
2559 Double_t sxmax = fXmax+0.00001;
2560
2561 // MVA plots on the training sample (check for overtraining)
2562 TH1* mva_s_tr = new TH1D( GetTestvarName() + "_Train_S",GetTestvarName() + "_Train_S", fNbinsMVAoutput, fXmin, sxmax );
2563 TH1* mva_b_tr = new TH1D( GetTestvarName() + "_Train_B",GetTestvarName() + "_Train_B", fNbinsMVAoutput, fXmin, sxmax );
2564 results->Store(mva_s_tr, "MVA_TRAIN_S");
2565 results->Store(mva_b_tr, "MVA_TRAIN_B");
2566 mva_s_tr->Sumw2();
2567 mva_b_tr->Sumw2();
2568
2569 // Training efficiency plots
2570 TH1* mva_eff_tr_s = new TH1D( GetTestvarName() + "_trainingEffS", GetTestvarName() + " (signal)",
2571 fNbinsH, xmin, xmax );
2572 TH1* mva_eff_tr_b = new TH1D( GetTestvarName() + "_trainingEffB", GetTestvarName() + " (background)",
2573 fNbinsH, xmin, xmax );
2574 results->Store(mva_eff_tr_s, "MVA_TRAINEFF_S");
2575 results->Store(mva_eff_tr_b, "MVA_TRAINEFF_B");
2576
2577 // sign of cut
2578 Int_t sign = (fCutOrientation == kPositive) ? +1 : -1;
2579
2580 std::vector<Double_t> mvaValues = GetMvaValues(0,Data()->GetNEvents());
2581 assert( (Long64_t) mvaValues.size() == Data()->GetNEvents());
2582
2583 // this method is unbinned
2584 for (Int_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
2585
2586 Data()->SetCurrentEvent(ievt);
2587 const Event* ev = GetEvent();
2588
2589 Double_t theVal = mvaValues[ievt];
2590 Double_t theWeight = ev->GetWeight();
2591
2592 TH1* theEffHist = DataInfo().IsSignal(ev) ? mva_eff_tr_s : mva_eff_tr_b;
2593 TH1* theClsHist = DataInfo().IsSignal(ev) ? mva_s_tr : mva_b_tr;
2594
2595 theClsHist->Fill( theVal, theWeight );
2596
2597 TAxis* axis = theEffHist->GetXaxis();
2598 Int_t maxbin = Int_t((theVal - axis->GetXmin())/(axis->GetXmax() - axis->GetXmin())*fNbinsH) + 1;
2599 if (sign > 0 && maxbin > fNbinsH) continue; // can happen... event doesn't count
2600 if (sign < 0 && maxbin < 1 ) continue; // can happen... event doesn't count
2601 if (sign > 0 && maxbin < 1 ) maxbin = 1;
2602 if (sign < 0 && maxbin > fNbinsH) maxbin = fNbinsH;
2603
2604 if (sign > 0) for (Int_t ibin=1; ibin<=maxbin; ibin++) theEffHist->AddBinContent( ibin , theWeight );
2605 else for (Int_t ibin=maxbin+1; ibin<=fNbinsH; ibin++) theEffHist->AddBinContent( ibin , theWeight );
2606 }
2607
2608 // normalise output distributions
2609 // uncomment those (and several others) if you want unnormalized output
2610 gTools().NormHist( mva_s_tr );
2611 gTools().NormHist( mva_b_tr );
2612
2613 // renormalise to maximum
2614 mva_eff_tr_s->Scale( 1.0/TMath::Max(std::numeric_limits<double>::epsilon(), mva_eff_tr_s->GetMaximum()) );
2615 mva_eff_tr_b->Scale( 1.0/TMath::Max(std::numeric_limits<double>::epsilon(), mva_eff_tr_b->GetMaximum()) );
2616
2617 // Training background efficiency versus signal efficiency
2618 TH1* eff_bvss = new TH1D( GetTestvarName() + "_trainingEffBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2619 // Training background rejection (=1-eff.) versus signal efficiency
2620 TH1* rej_bvss = new TH1D( GetTestvarName() + "_trainingRejBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2621 results->Store(eff_bvss, "EFF_BVSS_TR");
2622 results->Store(rej_bvss, "REJ_BVSS_TR");
2623
2624 // use root finder
2625 // spline background efficiency plot
2626 // note that there is a bin shift when going from a TH1D object to a TGraph :-(
2628 if (fSplTrainRefS) delete fSplTrainRefS;
2629 if (fSplTrainRefB) delete fSplTrainRefB;
2630 fSplTrainRefS = new TSpline1( "spline2_signal", new TGraph( mva_eff_tr_s ) );
2631 fSplTrainRefB = new TSpline1( "spline2_background", new TGraph( mva_eff_tr_b ) );
2632
2633 // verify spline sanity
2634 gTools().CheckSplines( mva_eff_tr_s, fSplTrainRefS );
2635 gTools().CheckSplines( mva_eff_tr_b, fSplTrainRefB );
2636 }
2637
2638 // make the background-vs-signal efficiency plot
2639
2640 // create root finder
2641 RootFinder rootFinder(this, fXmin, fXmax );
2642
2643 Double_t effB = 0;
2644 fEffS = results->GetHist("MVA_TRAINEFF_S");
2645 for (Int_t bini=1; bini<=fNbins; bini++) {
2646
2647 // find cut value corresponding to a given signal efficiency
2648 Double_t effS = eff_bvss->GetBinCenter( bini );
2649
2650 Double_t cut = rootFinder.Root( effS );
2651
2652 // retrieve background efficiency for given cut
2653 if (Use_Splines_for_Eff_) effB = fSplTrainRefB->Eval( cut );
2654 else effB = mva_eff_tr_b->GetBinContent( mva_eff_tr_b->FindBin( cut ) );
2655
2656 // and fill histograms
2657 eff_bvss->SetBinContent( bini, effB );
2658 rej_bvss->SetBinContent( bini, 1.0-effB );
2659 }
2660 fEffS = 0;
2661
2662 // create splines for histogram
2663 fSplTrainEffBvsS = new TSpline1( "effBvsS", new TGraph( eff_bvss ) );
2664 }
2665
2666 // must exist...
2667 if (0 == fSplTrainEffBvsS) return 0.0;
2668
2669 // now find signal efficiency that corresponds to required background efficiency
2670 Double_t effS = 0., effB, effS_ = 0., effB_ = 0.;
2671 Int_t nbins_ = 1000;
2672 for (Int_t bini=1; bini<=nbins_; bini++) {
2673
2674 // get corresponding signal and background efficiencies
2675 effS = (bini - 0.5)/Float_t(nbins_);
2676 effB = fSplTrainEffBvsS->Eval( effS );
2677
2678 // find signal efficiency that corresponds to required background efficiency
2679 if ((effB - effBref)*(effB_ - effBref) <= 0) break;
2680 effS_ = effS;
2681 effB_ = effB;
2682 }
2683
2684 return 0.5*(effS + effS_); // the mean between bin above and bin below
2685}
2686
2687////////////////////////////////////////////////////////////////////////////////
2688
2689std::vector<Float_t> TMVA::MethodBase::GetMulticlassEfficiency(std::vector<std::vector<Float_t> >& purity)
2690{
2691 Data()->SetCurrentType(Types::kTesting);
2692 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTesting, Types::kMulticlass));
2693 if (!resMulticlass) Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName())<< "unable to create pointer in GetMulticlassEfficiency, exiting."<<Endl;
2694
2695 purity.push_back(resMulticlass->GetAchievablePur());
2696 return resMulticlass->GetAchievableEff();
2697}
2698
2699////////////////////////////////////////////////////////////////////////////////
2700
2701std::vector<Float_t> TMVA::MethodBase::GetMulticlassTrainingEfficiency(std::vector<std::vector<Float_t> >& purity)
2702{
2703 Data()->SetCurrentType(Types::kTraining);
2704 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTraining, Types::kMulticlass));
2705 if (!resMulticlass) Log() << kFATAL<< "unable to create pointer in GetMulticlassTrainingEfficiency, exiting."<<Endl;
2706
2707 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Determine optimal multiclass cuts for training data..." << Endl;
2708 for (UInt_t icls = 0; icls<DataInfo().GetNClasses(); ++icls) {
2709 resMulticlass->GetBestMultiClassCuts(icls);
2710 }
2711
2712 purity.push_back(resMulticlass->GetAchievablePur());
2713 return resMulticlass->GetAchievableEff();
2714}
2715
2716////////////////////////////////////////////////////////////////////////////////
2717/// Construct a confusion matrix for a multiclass classifier. The confusion
2719 /// matrix compares, in turn, each class against all other classes in a pair-wise
2719/// fashion. In rows with index \f$ k_r = 0 ... K \f$, \f$ k_r \f$ is
2720/// considered signal for the sake of comparison and for each column
2721/// \f$ k_c = 0 ... K \f$ the corresponding class is considered background.
2722///
2723/// Note that the diagonal elements will be returned as NaN since this will
2724/// compare a class against itself.
2725///
2726/// \see TMVA::ResultsMulticlass::GetConfusionMatrix
2727///
2728/// \param[in] effB The background efficiency for which to evaluate.
2729/// \param[in] type The data set on which to evaluate (training, testing ...).
2730///
2731/// \return A matrix containing signal efficiencies for the given background
2732/// efficiency. The diagonal elements are NaN since this measure is
2733/// meaningless (comparing a class against itself).
2734///
2735
2737{
2738 if (GetAnalysisType() != Types::kMulticlass) {
2739 Log() << kFATAL << "Cannot get confusion matrix for non-multiclass analysis." << std::endl;
2740 return TMatrixD(0, 0);
2741 }
2742
2743 Data()->SetCurrentType(type);
2744 ResultsMulticlass *resMulticlass =
2745 dynamic_cast<ResultsMulticlass *>(Data()->GetResults(GetMethodName(), type, Types::kMulticlass));
2746
2747 if (resMulticlass == nullptr) {
2748 Log() << kFATAL << Form("Dataset[%s] : ", DataInfo().GetName())
2749 << "unable to create pointer in GetMulticlassEfficiency, exiting." << Endl;
2750 return TMatrixD(0, 0);
2751 }
2752
2753 return resMulticlass->GetConfusionMatrix(effB);
2754}
2755
2756////////////////////////////////////////////////////////////////////////////////
2757/// compute significance of mean difference
2758/// \f[
2759 /// significance = \frac{|<S> - <B>|}{\sqrt{RMS_S^2 + RMS_B^2}}
2760/// \f]
2761
2763{
2764 Double_t rms = sqrt( fRmsS*fRmsS + fRmsB*fRmsB );
2765
2766 return (rms > 0) ? TMath::Abs(fMeanS - fMeanB)/rms : 0;
2767}
2768
2769////////////////////////////////////////////////////////////////////////////////
2770/// compute "separation" defined as
2771/// \f[
2772/// <s2> = \frac{1}{2} \int_{-\infty}^{+\infty} { \frac{(S(x) - B(x))^2}{(S(x) + B(x))} dx }
2773/// \f]
2774
2776{
2777 return gTools().GetSeparation( histoS, histoB );
2778}
2779
2780////////////////////////////////////////////////////////////////////////////////
2781/// compute "separation" defined as
2782/// \f[
2783/// <s2> = \frac{1}{2} \int_{-\infty}^{+\infty} { \frac{(S(x) - B(x))^2}{(S(x) + B(x))} dx }
2784/// \f]
2785
2787{
2788 // note, if zero pointers given, use internal pdf
2789 // sanity check first
2790 if ((!pdfS && pdfB) || (pdfS && !pdfB))
2791 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetSeparation> Mismatch in pdfs" << Endl;
2792 if (!pdfS) pdfS = fSplS;
2793 if (!pdfB) pdfB = fSplB;
2794
2795 if (!fSplS || !fSplB) {
2796 Log()<<kDEBUG<<Form("[%s] : ",DataInfo().GetName())<< "could not calculate the separation, distributions"
2797 << " fSplS or fSplB are not yet filled" << Endl;
2798 return 0;
2799 }else{
2800 return gTools().GetSeparation( *pdfS, *pdfB );
2801 }
2802}
2803
2804////////////////////////////////////////////////////////////////////////////////
2805/// calculate the area (integral) under the ROC curve as a
2806/// overall quality measure of the classification
2807
2809{
2810 // note, if zero pointers given, use internal pdf
2811 // sanity check first
2812 if ((!histS && histB) || (histS && !histB))
2813 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetROCIntegral(TH1D*, TH1D*)> Mismatch in hists" << Endl;
2814
2815 if (histS==0 || histB==0) return 0.;
2816
2817 TMVA::PDF *pdfS = new TMVA::PDF( " PDF Sig", histS, TMVA::PDF::kSpline3 );
2818 TMVA::PDF *pdfB = new TMVA::PDF( " PDF Bkg", histB, TMVA::PDF::kSpline3 );
2819
2820
2821 Double_t xmin = TMath::Min(pdfS->GetXmin(), pdfB->GetXmin());
2822 Double_t xmax = TMath::Max(pdfS->GetXmax(), pdfB->GetXmax());
2823
2824 Double_t integral = 0;
2825 UInt_t nsteps = 1000;
2826 Double_t step = (xmax-xmin)/Double_t(nsteps);
2827 Double_t cut = xmin;
2828 for (UInt_t i=0; i<nsteps; i++) {
2829 integral += (1-pdfB->GetIntegral(cut,xmax)) * pdfS->GetVal(cut);
2830 cut+=step;
2831 }
2832 delete pdfS;
2833 delete pdfB;
2834 return integral*step;
2835}
2836
2837
2838////////////////////////////////////////////////////////////////////////////////
2839/// calculate the area (integral) under the ROC curve as a
2840/// overall quality measure of the classification
2841
2843{
2844 // note, if zero pointers given, use internal pdf
2845 // sanity check first
2846 if ((!pdfS && pdfB) || (pdfS && !pdfB))
2847 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetSeparation> Mismatch in pdfs" << Endl;
2848 if (!pdfS) pdfS = fSplS;
2849 if (!pdfB) pdfB = fSplB;
2850
2851 if (pdfS==0 || pdfB==0) return 0.;
2852
2853 Double_t xmin = TMath::Min(pdfS->GetXmin(), pdfB->GetXmin());
2854 Double_t xmax = TMath::Max(pdfS->GetXmax(), pdfB->GetXmax());
2855
2856 Double_t integral = 0;
2857 UInt_t nsteps = 1000;
2858 Double_t step = (xmax-xmin)/Double_t(nsteps);
2859 Double_t cut = xmin;
2860 for (UInt_t i=0; i<nsteps; i++) {
2861 integral += (1-pdfB->GetIntegral(cut,xmax)) * pdfS->GetVal(cut);
2862 cut+=step;
2863 }
2864 return integral*step;
2865}
2866
2867////////////////////////////////////////////////////////////////////////////////
2868/// plot significance, \f$ \frac{S}{\sqrt{S^2 + B^2}} \f$, curve for given number
2869/// of signal and background events; returns cut for maximum significance
2870/// also returned via reference is the maximum significance
2871
2873 Double_t BackgroundEvents,
2874 Double_t& max_significance_value ) const
2875{
2876 Results* results = Data()->GetResults( GetMethodName(), Types::kTesting, Types::kMaxAnalysisType );
2877
2878 Double_t max_significance(0);
2879 Double_t effS(0),effB(0),significance(0);
2880 TH1D *temp_histogram = new TH1D("temp", "temp", fNbinsH, fXmin, fXmax );
2881
2882 if (SignalEvents <= 0 || BackgroundEvents <= 0) {
2883 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetMaximumSignificance> "
2884 << "Number of signal or background events is <= 0 ==> abort"
2885 << Endl;
2886 }
2887
2888 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Using ratio SignalEvents/BackgroundEvents = "
2889 << SignalEvents/BackgroundEvents << Endl;
2890
2891 TH1* eff_s = results->GetHist("MVA_EFF_S");
2892 TH1* eff_b = results->GetHist("MVA_EFF_B");
2893
2894 if ( (eff_s==0) || (eff_b==0) ) {
2895 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Efficiency histograms empty !" << Endl;
2896 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "no maximum cut found, return 0" << Endl;
2897 return 0;
2898 }
2899
2900 for (Int_t bin=1; bin<=fNbinsH; bin++) {
2901 effS = eff_s->GetBinContent( bin );
2902 effB = eff_b->GetBinContent( bin );
2903
2904 // put significance into a histogram
2905 significance = sqrt(SignalEvents)*( effS )/sqrt( effS + ( BackgroundEvents / SignalEvents) * effB );
2906
2907 temp_histogram->SetBinContent(bin,significance);
2908 }
2909
2910 // find maximum in histogram
2911 max_significance = temp_histogram->GetBinCenter( temp_histogram->GetMaximumBin() );
2912 max_significance_value = temp_histogram->GetBinContent( temp_histogram->GetMaximumBin() );
2913
2914 // delete
2915 delete temp_histogram;
2916
2917 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Optimal cut at : " << max_significance << Endl;
2918 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName()) << "Maximum significance: " << max_significance_value << Endl;
2919
2920 return max_significance;
2921}
2922
2923////////////////////////////////////////////////////////////////////////////////
2924/// calculates rms,mean, xmin, xmax of the event variable
2925/// this can be either done for the variables as they are or for
2926/// normalised variables (in the range of 0-1) if "norm" is set to kTRUE
2927
2929 Double_t& meanS, Double_t& meanB,
2930 Double_t& rmsS, Double_t& rmsB,
2932{
2933 Types::ETreeType previousTreeType = Data()->GetCurrentType();
2934 Data()->SetCurrentType(treeType);
2935
2936 Long64_t entries = Data()->GetNEvents();
2937
2938 // sanity check
2939 if (entries <=0)
2940 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<CalculateEstimator> Wrong tree type: " << treeType << Endl;
2941
2942 // index of the wanted variable
2943 UInt_t varIndex = DataInfo().FindVarIndex( theVarName );
2944
2945 // first fill signal and background in arrays before analysis
2946 xmin = +DBL_MAX;
2947 xmax = -DBL_MAX;
2948 Long64_t nEventsS = -1;
2949 Long64_t nEventsB = -1;
2950
2951 // take into account event weights
2952 meanS = 0;
2953 meanB = 0;
2954 rmsS = 0;
2955 rmsB = 0;
2956 Double_t sumwS = 0, sumwB = 0;
2957
2958 // loop over all training events
2959 for (Int_t ievt = 0; ievt < entries; ievt++) {
2960
2961 const Event* ev = GetEvent(ievt);
2962
2963 Double_t theVar = ev->GetValue(varIndex);
2964 Double_t weight = ev->GetWeight();
2965
2966 if (DataInfo().IsSignal(ev)) {
2967 sumwS += weight;
2968 meanS += weight*theVar;
2969 rmsS += weight*theVar*theVar;
2970 }
2971 else {
2972 sumwB += weight;
2973 meanB += weight*theVar;
2974 rmsB += weight*theVar*theVar;
2975 }
2976 xmin = TMath::Min( xmin, theVar );
2977 xmax = TMath::Max( xmax, theVar );
2978 }
2979 ++nEventsS;
2980 ++nEventsB;
2981
2982 meanS = meanS/sumwS;
2983 meanB = meanB/sumwB;
2984 rmsS = TMath::Sqrt( rmsS/sumwS - meanS*meanS );
2985 rmsB = TMath::Sqrt( rmsB/sumwB - meanB*meanB );
2986
2987 Data()->SetCurrentType(previousTreeType);
2988}
2989
2990////////////////////////////////////////////////////////////////////////////////
2991/// create reader class for method (classification only at present)
2992
2993 void TMVA::MethodBase::MakeClass( const TString& theClassFileName ) const
2994 {
2995 // default file name: <weight-file-dir>/<job>_<method>.class.C
2996 TString classFileName = "";
2997 if (theClassFileName == "")
2998 classFileName = GetWeightFileDir() + "/" + GetJobName() + "_" + GetMethodName() + ".class.C";
2999 else
3000 classFileName = theClassFileName;
3001 
3002 TString className = TString("Read") + GetMethodName();
3003 
// NOTE(review): tfname is never read after this point — looks unused; confirm before removing
3004 TString tfname( classFileName );
3005 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
3006 << "Creating standalone class: "
3007 << gTools().Color("lightblue") << classFileName << gTools().Color("reset") << Endl;
3008 
// open output file; kFATAL aborts on failure, so fout is usable below
3009 std::ofstream fout( classFileName );
3010 if (!fout.good()) { // file could not be opened --> Error
3011 Log() << kFATAL << "<MakeClass> Unable to open file: " << classFileName << Endl;
3012 }
3013 
3014 // now create the class
3015 // preamble
3016 fout << "// Class: " << className << std::endl;
3017 fout << "// Automatically generated by MethodBase::MakeClass" << std::endl << "//" << std::endl;
3018 
3019 // print general information and configuration state
3020 fout << std::endl;
3021 fout << "/* configuration options =====================================================" << std::endl << std::endl;
3022 WriteStateToStream( fout );
3023 fout << std::endl;
3024 fout << "============================================================================ */" << std::endl;
3025 
3026 // generate the class: standard-library includes needed by the generated reader
3027 fout << "" << std::endl;
3028 fout << "#include <array>" << std::endl;
3029 fout << "#include <vector>" << std::endl;
3030 fout << "#include <cmath>" << std::endl;
3031 fout << "#include <string>" << std::endl;
3032 fout << "#include <iostream>" << std::endl;
3033 fout << "" << std::endl;
3034 // now if the classifier needs to write some additional classes for its response implementation
3035 // this code goes here (at least the header declarations need to come before the main class)
3036 this->MakeClassSpecificHeader( fout, className );
3037 
// emit the abstract IClassifierReader interface, guarded so that several
// generated classes can be compiled into the same translation unit
3038 fout << "#ifndef IClassifierReader__def" << std::endl;
3039 fout << "#define IClassifierReader__def" << std::endl;
3040 fout << std::endl;
3041 fout << "class IClassifierReader {" << std::endl;
3042 fout << std::endl;
3043 fout << " public:" << std::endl;
3044 fout << std::endl;
3045 fout << " // constructor" << std::endl;
3046 fout << " IClassifierReader() : fStatusIsClean( true ) {}" << std::endl;
3047 fout << " virtual ~IClassifierReader() {}" << std::endl;
3048 fout << std::endl;
3049 fout << " // return classifier response" << std::endl;
// the interface differs for multiclass vs. two-class response
3050 if(GetAnalysisType() == Types::kMulticlass) {
3051 fout << " virtual std::vector<double> GetMulticlassValues( const std::vector<double>& inputValues ) const = 0;" << std::endl;
3052 } else {
3053 fout << " virtual double GetMvaValue( const std::vector<double>& inputValues ) const = 0;" << std::endl;
3054 }
3055 fout << std::endl;
3056 fout << " // returns classifier status" << std::endl;
3057 fout << " bool IsStatusClean() const { return fStatusIsClean; }" << std::endl;
3058 fout << std::endl;
3059 fout << " protected:" << std::endl;
3060 fout << std::endl;
3061 fout << " bool fStatusIsClean;" << std::endl;
3062 fout << "};" << std::endl;
3063 fout << std::endl;
3064 fout << "#endif" << std::endl;
3065 fout << std::endl;
// emit the concrete reader class derived from IClassifierReader
3066 fout << "class " << className << " : public IClassifierReader {" << std::endl;
3067 fout << std::endl;
3068 fout << " public:" << std::endl;
3069 fout << std::endl;
3070 fout << " // constructor" << std::endl;
3071 fout << " " << className << "( std::vector<std::string>& theInputVars )" << std::endl;
3072 fout << " : IClassifierReader()," << std::endl;
3073 fout << " fClassName( \"" << className << "\" )," << std::endl;
3074 fout << " fNvars( " << GetNvar() << " )" << std::endl;
3075 fout << " {" << std::endl;
3076 fout << " // the training input variables" << std::endl;
3077 fout << " const char* inputVars[] = { ";
3078 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
3079 fout << "\"" << GetOriginalVarName(ivar) << "\"";
3080 if (ivar<GetNvar()-1) fout << ", ";
3081 }
3082 fout << " };" << std::endl;
3083 fout << std::endl;
// generated sanity checks: the caller-supplied variable list must match
// name-for-name the variables used during training
3084 fout << " // sanity checks" << std::endl;
3085 fout << " if (theInputVars.size() <= 0) {" << std::endl;
3086 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": empty input vector\" << std::endl;" << std::endl;
3087 fout << " fStatusIsClean = false;" << std::endl;
3088 fout << " }" << std::endl;
3089 fout << std::endl;
3090 fout << " if (theInputVars.size() != fNvars) {" << std::endl;
3091 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": mismatch in number of input values: \"" << std::endl;
3092 fout << " << theInputVars.size() << \" != \" << fNvars << std::endl;" << std::endl;
3093 fout << " fStatusIsClean = false;" << std::endl;
3094 fout << " }" << std::endl;
3095 fout << std::endl;
3096 fout << " // validate input variables" << std::endl;
3097 fout << " for (size_t ivar = 0; ivar < theInputVars.size(); ivar++) {" << std::endl;
3098 fout << " if (theInputVars[ivar] != inputVars[ivar]) {" << std::endl;
3099 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": mismatch in input variable names\" << std::endl" << std::endl;
3100 fout << " << \" for variable [\" << ivar << \"]: \" << theInputVars[ivar].c_str() << \" != \" << inputVars[ivar] << std::endl;" << std::endl;
3101 fout << " fStatusIsClean = false;" << std::endl;
3102 fout << " }" << std::endl;
3103 fout << " }" << std::endl;
3104 fout << std::endl;
// bake the per-variable min/max into the generated constructor
// (15 digits so the generated constants round-trip a double)
3105 fout << " // initialize min and max vectors (for normalisation)" << std::endl;
3106 for (UInt_t ivar = 0; ivar < GetNvar(); ivar++) {
3107 fout << " fVmin[" << ivar << "] = " << std::setprecision(15) << GetXmin( ivar ) << ";" << std::endl;
3108 fout << " fVmax[" << ivar << "] = " << std::setprecision(15) << GetXmax( ivar ) << ";" << std::endl;
3109 }
3110 fout << std::endl;
3111 fout << " // initialize input variable types" << std::endl;
3112 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
3113 fout << " fType[" << ivar << "] = \'" << DataInfo().GetVariableInfo(ivar).GetVarType() << "\';" << std::endl;
3114 }
3115 fout << std::endl;
3116 fout << " // initialize constants" << std::endl;
3117 fout << " Initialize();" << std::endl;
3118 fout << std::endl;
// only emit the transformation hook when this method actually has one
3119 if (GetTransformationHandler().GetTransformationList().GetSize() != 0) {
3120 fout << " // initialize transformation" << std::endl;
3121 fout << " InitTransform();" << std::endl;
3122 }
3123 fout << " }" << std::endl;
3124 fout << std::endl;
3125 fout << " // destructor" << std::endl;
3126 fout << " virtual ~" << className << "() {" << std::endl;
3127 fout << " Clear(); // method-specific" << std::endl;
3128 fout << " }" << std::endl;
3129 fout << std::endl;
3130 fout << " // the classifier response" << std::endl;
3131 fout << " // \"inputValues\" is a vector of input values in the same order as the" << std::endl;
3132 fout << " // variables given to the constructor" << std::endl;
3133 if(GetAnalysisType() == Types::kMulticlass) {
3134 fout << " std::vector<double> GetMulticlassValues( const std::vector<double>& inputValues ) const override;" << std::endl;
3135 } else {
3136 fout << " double GetMvaValue( const std::vector<double>& inputValues ) const override;" << std::endl;
3137 }
3138 fout << std::endl;
3139 fout << " private:" << std::endl;
3140 fout << std::endl;
3141 fout << " // method-specific destructor" << std::endl;
3142 fout << " void Clear();" << std::endl;
3143 fout << std::endl;
3144 if (GetTransformationHandler().GetTransformationList().GetSize()!=0) {
3145 fout << " // input variable transformation" << std::endl;
3146 GetTransformationHandler().MakeFunction(fout, className,1);
3147 fout << " void InitTransform();" << std::endl;
3148 fout << " void Transform( std::vector<double> & iv, int sigOrBgd ) const;" << std::endl;
3149 fout << std::endl;
3150 }
3151 fout << " // common member variables" << std::endl;
3152 fout << " const char* fClassName;" << std::endl;
3153 fout << std::endl;
3154 fout << " const size_t fNvars;" << std::endl;
3155 fout << " size_t GetNvar() const { return fNvars; }" << std::endl;
3156 fout << " char GetType( int ivar ) const { return fType[ivar]; }" << std::endl;
3157 fout << std::endl;
3158 fout << " // normalisation of input variables" << std::endl;
3159 fout << " double fVmin[" << GetNvar() << "];" << std::endl;
3160 fout << " double fVmax[" << GetNvar() << "];" << std::endl;
3161 fout << " double NormVariable( double x, double xmin, double xmax ) const {" << std::endl;
3162 fout << " // normalise to output range: [-1, 1]" << std::endl;
3163 fout << " return 2*(x - xmin)/(xmax - xmin) - 1.0;" << std::endl;
3164 fout << " }" << std::endl;
3165 fout << std::endl;
3166 fout << " // type of input variable: 'F' or 'I'" << std::endl;
3167 fout << " char fType[" << GetNvar() << "];" << std::endl;
3168 fout << std::endl;
3169 fout << " // initialize internal variables" << std::endl;
3170 fout << " void Initialize();" << std::endl;
3171 if(GetAnalysisType() == Types::kMulticlass) {
3172 fout << " std::vector<double> GetMulticlassValues__( const std::vector<double>& inputValues ) const;" << std::endl;
3173 } else {
3174 fout << " double GetMvaValue__( const std::vector<double>& inputValues ) const;" << std::endl;
3175 }
3176 fout << "" << std::endl;
3177 fout << " // private members (method specific)" << std::endl;
3178 
3179 // call the classifier specific output (the classifier must close the class !)
3180 MakeClassSpecific( fout, className );
3181 
// emit the public response wrapper: validates status, optionally
// normalises / transforms the inputs, then forwards to the
// method-specific GetMvaValue__ / GetMulticlassValues__
3182 if(GetAnalysisType() == Types::kMulticlass) {
3183 fout << "inline std::vector<double> " << className << "::GetMulticlassValues( const std::vector<double>& inputValues ) const" << std::endl;
3184 } else {
3185 fout << "inline double " << className << "::GetMvaValue( const std::vector<double>& inputValues ) const" << std::endl;
3186 }
3187 fout << "{" << std::endl;
3188 fout << " // classifier response value" << std::endl;
3189 if(GetAnalysisType() == Types::kMulticlass) {
3190 fout << " std::vector<double> retval;" << std::endl;
3191 } else {
3192 fout << " double retval = 0;" << std::endl;
3193 }
3194 fout << std::endl;
3195 fout << " // classifier response, sanity check first" << std::endl;
3196 fout << " if (!IsStatusClean()) {" << std::endl;
3197 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": cannot return classifier response\"" << std::endl;
3198 fout << " << \" because status is dirty\" << std::endl;" << std::endl;
3199 fout << " }" << std::endl;
3200 fout << " else {" << std::endl;
// normalised methods get an explicit NormVariable() loop in the
// generated code before any transformation is applied
3201 if (IsNormalised()) {
3202 fout << " // normalise variables" << std::endl;
3203 fout << " std::vector<double> iV;" << std::endl;
3204 fout << " iV.reserve(inputValues.size());" << std::endl;
3205 fout << " int ivar = 0;" << std::endl;
3206 fout << " for (std::vector<double>::const_iterator varIt = inputValues.begin();" << std::endl;
3207 fout << " varIt != inputValues.end(); varIt++, ivar++) {" << std::endl;
3208 fout << " iV.push_back(NormVariable( *varIt, fVmin[ivar], fVmax[ivar] ));" << std::endl;
3209 fout << " }" << std::endl;
// likelihood and H-matrix handle their transformation internally,
// so no Transform() call is generated for them
3210 if (GetTransformationHandler().GetTransformationList().GetSize() != 0 && GetMethodType() != Types::kLikelihood &&
3211 GetMethodType() != Types::kHMatrix) {
3212 fout << " Transform( iV, -1 );" << std::endl;
3213 }
3214 
3215 if(GetAnalysisType() == Types::kMulticlass) {
3216 fout << " retval = GetMulticlassValues__( iV );" << std::endl;
3217 } else {
3218 fout << " retval = GetMvaValue__( iV );" << std::endl;
3219 }
3220 } else {
3221 if (GetTransformationHandler().GetTransformationList().GetSize() != 0 && GetMethodType() != Types::kLikelihood &&
3222 GetMethodType() != Types::kHMatrix) {
3223 fout << " std::vector<double> iV(inputValues);" << std::endl;
3224 fout << " Transform( iV, -1 );" << std::endl;
3225 if(GetAnalysisType() == Types::kMulticlass) {
3226 fout << " retval = GetMulticlassValues__( iV );" << std::endl;
3227 } else {
3228 fout << " retval = GetMvaValue__( iV );" << std::endl;
3229 }
3230 } else {
3231 if(GetAnalysisType() == Types::kMulticlass) {
3232 fout << " retval = GetMulticlassValues__( inputValues );" << std::endl;
3233 } else {
3234 fout << " retval = GetMvaValue__( inputValues );" << std::endl;
3235 }
3236 }
3237 }
3238 fout << " }" << std::endl;
3239 fout << std::endl;
3240 fout << " return retval;" << std::endl;
3241 fout << "}" << std::endl;
3242 
3243 // create output for transformation - if any
3244 if (GetTransformationHandler().GetTransformationList().GetSize()!=0)
3245 GetTransformationHandler().MakeFunction(fout, className,2);
3246 
3247 // close the file
3248 fout.close();
3249}
3250
3251////////////////////////////////////////////////////////////////////////////////
3252/// prints out method-specific help method
3253
3255{
3256 // if options are written to reference file, also append help info
3257 std::streambuf* cout_sbuf = std::cout.rdbuf(); // save original sbuf
3258 std::ofstream* o = 0;
3259 if (gConfig().WriteOptionsReference()) {
3260 Log() << kINFO << "Print Help message for class " << GetName() << " into file: " << GetReferenceFile() << Endl;
3261 o = new std::ofstream( GetReferenceFile(), std::ios::app );
3262 if (!o->good()) { // file could not be opened --> Error
3263 Log() << kFATAL << "<PrintHelpMessage> Unable to append to output file: " << GetReferenceFile() << Endl;
3264 }
3265 std::cout.rdbuf( o->rdbuf() ); // redirect 'std::cout' to file
3266 }
3267
3268 // "|--------------------------------------------------------------|"
3269 if (!o) {
3270 Log() << kINFO << Endl;
3271 Log() << gTools().Color("bold")
3272 << "================================================================"
3273 << gTools().Color( "reset" )
3274 << Endl;
3275 Log() << gTools().Color("bold")
3276 << "H e l p f o r M V A m e t h o d [ " << GetName() << " ] :"
3277 << gTools().Color( "reset" )
3278 << Endl;
3279 }
3280 else {
3281 Log() << "Help for MVA method [ " << GetName() << " ] :" << Endl;
3282 }
3283
3284 // print method-specific help message
3285 GetHelpMessage();
3286
3287 if (!o) {
3288 Log() << Endl;
3289 Log() << "<Suppress this message by specifying \"!H\" in the booking option>" << Endl;
3290 Log() << gTools().Color("bold")
3291 << "================================================================"
3292 << gTools().Color( "reset" )
3293 << Endl;
3294 Log() << Endl;
3295 }
3296 else {
3297 // indicate END
3298 Log() << "# End of Message___" << Endl;
3299 }
3300
3301 std::cout.rdbuf( cout_sbuf ); // restore the original stream buffer
3302 if (o) o->close();
3303}
3304
3305// ----------------------- r o o t f i n d i n g ----------------------------
3306
3307////////////////////////////////////////////////////////////////////////////////
3308/// returns efficiency as function of cut
3309
3311{
3312 Double_t retval=0;
3313
3314 // retrieve the class object
3316 retval = fSplRefS->Eval( theCut );
3317 }
3318 else retval = fEffS->GetBinContent( fEffS->FindBin( theCut ) );
3319
3320 // caution: here we take some "forbidden" action to hide a problem:
3321 // in some cases, in particular for likelihood, the binned efficiency distributions
3322 // do not equal 1, at xmin, and 0 at xmax; of course, in principle we have the
3323 // unbinned information available in the trees, but the unbinned minimization is
3324 // too slow, and we don't need to do a precision measurement here. Hence, we force
3325 // this property.
3326 Double_t eps = 1.0e-5;
3327 if (theCut-fXmin < eps) retval = (GetCutOrientation() == kPositive) ? 1.0 : 0.0;
3328 else if (fXmax-theCut < eps) retval = (GetCutOrientation() == kPositive) ? 0.0 : 1.0;
3329
3330 return retval;
3331}
3332
3333////////////////////////////////////////////////////////////////////////////////
3334/// returns the event collection (i.e. the dataset) TRANSFORMED using the
3335/// classifier's specific Variable Transformation (e.g. Decorr or Decorr:Gauss:Decorr)
3336
3338{
3339 // if there's no variable transformation for this classifier, just hand back the
3340 // event collection of the data set
3341 if (GetTransformationHandler().GetTransformationList().GetEntries() <= 0) {
3342 return (Data()->GetEventCollection(type));
3343 }
3344
3345 // otherwise, transform ALL the events and hand back the vector of the pointers to the
3346 // transformed events. If the pointer is already != 0, i.e. the whole thing has been
3347 // done before, I don't need to do it again, but just "hand over" the pointer to those events.
3348 Int_t idx = Data()->TreeIndex(type); //index indicating Training,Testing,... events/datasets
3349 if (fEventCollections.at(idx) == 0) {
3350 fEventCollections.at(idx) = &(Data()->GetEventCollection(type));
3351 fEventCollections.at(idx) = GetTransformationHandler().CalcTransformations(*(fEventCollections.at(idx)),kTRUE);
3352 }
3353 return *(fEventCollections.at(idx));
3354}
3355
3356////////////////////////////////////////////////////////////////////////////////
3357/// calculates the TMVA version string from the training version code on the fly
3358
3360{
3361 UInt_t a = GetTrainingTMVAVersionCode() & 0xff0000; a>>=16;
3362 UInt_t b = GetTrainingTMVAVersionCode() & 0x00ff00; b>>=8;
3363 UInt_t c = GetTrainingTMVAVersionCode() & 0x0000ff;
3364
3365 return TString(Form("%i.%i.%i",a,b,c));
3366}
3367
3368////////////////////////////////////////////////////////////////////////////////
3369/// calculates the ROOT version string from the training version code on the fly
3370
3372{
3373 UInt_t a = GetTrainingROOTVersionCode() & 0xff0000; a>>=16;
3374 UInt_t b = GetTrainingROOTVersionCode() & 0x00ff00; b>>=8;
3375 UInt_t c = GetTrainingROOTVersionCode() & 0x0000ff;
3376
3377 return TString(Form("%i.%02i/%02i",a,b,c));
3378}
3379
3380////////////////////////////////////////////////////////////////////////////////
3381
3383 ResultsClassification* mvaRes = dynamic_cast<ResultsClassification*>
3384 ( Data()->GetResults(GetMethodName(),Types::kTesting, Types::kClassification) );
3385
3386 if (mvaRes != NULL) {
3387 TH1D *mva_s = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_S"));
3388 TH1D *mva_b = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_B"));
3389 TH1D *mva_s_tr = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_TRAIN_S"));
3390 TH1D *mva_b_tr = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_TRAIN_B"));
3391
3392 if ( !mva_s || !mva_b || !mva_s_tr || !mva_b_tr) return -1;
3393
3394 if (SorB == 's' || SorB == 'S')
3395 return mva_s->KolmogorovTest( mva_s_tr, opt.Data() );
3396 else
3397 return mva_b->KolmogorovTest( mva_b_tr, opt.Data() );
3398 }
3399 return -1;
3400}
SVector< double, 2 > v
Definition: Dict.h:5
const Bool_t Use_Splines_for_Eff_
Definition: MethodBase.cxx:134
const Int_t NBIN_HIST_HIGH
Definition: MethodBase.cxx:137
ROOT::R::TRInterface & r
Definition: Object.C:4
#define d(i)
Definition: RSha256.hxx:102
#define b(i)
Definition: RSha256.hxx:100
#define c(i)
Definition: RSha256.hxx:101
#define s1(x)
Definition: RSha256.hxx:91
#define ROOT_VERSION_CODE
Definition: RVersion.h:21
static RooMathCoreReg dummy
int Int_t
Definition: RtypesCore.h:41
int Ssiz_t
Definition: RtypesCore.h:63
char Char_t
Definition: RtypesCore.h:29
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
long long Long64_t
Definition: RtypesCore.h:69
float Float_t
Definition: RtypesCore.h:53
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define ClassImp(name)
Definition: Rtypes.h:365
char name[80]
Definition: TGX11.cxx:109
int type
Definition: TGX11.cxx:120
float xmin
Definition: THbookFile.cxx:93
float xmax
Definition: THbookFile.cxx:93
double sqrt(double)
TMatrixT< Double_t > TMatrixD
Definition: TMatrixDfwd.h:22
char * Form(const char *fmt,...)
R__EXTERN TSystem * gSystem
Definition: TSystem.h:560
#define TMVA_VERSION_CODE
Definition: Version.h:47
Class to manage histogram axis.
Definition: TAxis.h:30
Double_t GetXmax() const
Definition: TAxis.h:134
Double_t GetXmin() const
Definition: TAxis.h:133
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
Definition: TCollection.h:182
virtual Int_t Write(const char *name=0, Int_t option=0, Int_t bufsize=0)
Write all objects in this collection.
This class stores the date and time with a precision of one second in an unsigned 32 bit word (950130...
Definition: TDatime.h:37
const char * AsString() const
Return the date & time as a string (ctime() format).
Definition: TDatime.cxx:101
virtual TObject * Get(const char *namecycle)
Return pointer to object identified by namecycle.
Describe directory structure in memory.
Definition: TDirectory.h:34
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using a path.
Definition: TDirectory.cxx:400
virtual Bool_t cd(const char *path=0)
Change current directory to "this" directory.
Definition: TDirectory.cxx:497
virtual TDirectory * mkdir(const char *name, const char *title="")
Create a sub-directory "a" or a hierarchy of sub-directories "a/b/c/...".
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:48
virtual void Close(Option_t *option="")
Close a file.
Definition: TFile.cxx:914
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseGeneralPurpose, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3980
A Graph is a graphics object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
1-D histogram with a double per channel (see TH1 documentation)}
Definition: TH1.h:614
1-D histogram with a float per channel (see TH1 documentation)}
Definition: TH1.h:571
The TH1 histogram class.
Definition: TH1.h:56
virtual Double_t GetBinCenter(Int_t bin) const
Return bin center for 1D histogram.
Definition: TH1.cxx:8554
virtual Int_t GetQuantiles(Int_t nprobSum, Double_t *q, const Double_t *probSum=0)
Compute Quantiles for this histogram Quantile x_q of a probability distribution Function F is defined...
Definition: TH1.cxx:4434
virtual void AddBinContent(Int_t bin)
Increment bin content by 1.
Definition: TH1.cxx:1200
virtual Double_t GetMean(Int_t axis=1) const
For axis = 1,2 or 3 returns the mean value of the histogram along X,Y or Z axis.
Definition: TH1.cxx:7050
virtual void SetXTitle(const char *title)
Definition: TH1.h:409
static void AddDirectory(Bool_t add=kTRUE)
Sets the flag controlling the automatic add of histograms in memory.
Definition: TH1.cxx:1225
TAxis * GetXaxis()
Get the behaviour adopted by the object about the statoverflows. See EStatOverflows for more information.
Definition: TH1.h:316
virtual Double_t GetMaximum(Double_t maxval=FLT_MAX) const
Return maximum value smaller than maxval of bins in the range, unless the value has been overridden b...
Definition: TH1.cxx:7964
virtual Int_t GetNbinsX() const
Definition: TH1.h:292
virtual Int_t Fill(Double_t x)
Increment bin with abscissa X by 1.
Definition: TH1.cxx:3258
virtual void SetBinContent(Int_t bin, Double_t content)
Set bin content see convention for numbering bins in TH1::GetBin In case the bin number is greater th...
Definition: TH1.cxx:8635
virtual Int_t GetMaximumBin() const
Return location of bin with maximum value in the range.
Definition: TH1.cxx:7994
virtual Double_t GetBinContent(Int_t bin) const
Return content of bin number bin.
Definition: TH1.cxx:4882
virtual void SetYTitle(const char *title)
Definition: TH1.h:410
virtual void Scale(Double_t c1=1, Option_t *option="")
Multiply this histogram by a constant c1.
Definition: TH1.cxx:6218
virtual Int_t FindBin(Double_t x, Double_t y=0, Double_t z=0)
Return Global bin number corresponding to x,y,z.
Definition: TH1.cxx:3579
virtual Double_t KolmogorovTest(const TH1 *h2, Option_t *option="") const
Statistical test of compatibility in shape between this histogram and h2, using Kolmogorov test.
Definition: TH1.cxx:7647
virtual void Sumw2(Bool_t flag=kTRUE)
Create structure to store sum of squares of weights.
Definition: TH1.cxx:8433
static Bool_t AddDirectoryStatus()
Static function: cannot be inlined on Windows/NT.
Definition: TH1.cxx:705
2-D histogram with a float per channel (see TH1 documentation)}
Definition: TH2.h:248
Int_t Fill(Double_t)
Invalid Fill method.
Definition: TH2.cxx:292
A doubly linked list.
Definition: TList.h:44
virtual TObject * At(Int_t idx) const
Returns the object at position idx. Returns 0 if idx is out of range.
Definition: TList.cxx:354
Class that contains all the information of a class.
Definition: ClassInfo.h:49
UInt_t GetNumber() const
Definition: ClassInfo.h:65
TString fWeightFileExtension
Definition: Config.h:123
class TMVA::Config::VariablePlotting fVariablePlotting
IONames & GetIONames()
Definition: Config.h:100
MsgLogger * fLogger
Definition: Configurable.h:128
Class that contains all the data information.
Definition: DataSetInfo.h:60
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:237
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not.
Definition: Event.cxx:382
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:392
Float_t GetTarget(UInt_t itgt) const
Definition: Event.h:103
static void SetIgnoreNegWeightsInTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:401
Interface for all concrete MVA method implementations.
Definition: IMethod.h:54
void Init(std::vector< TString > &graphTitles)
This function gets some title and it creates a TGraph for every title.
Definition: MethodBase.cxx:174
IPythonInteractive()
standard constructor
Definition: MethodBase.cxx:151
~IPythonInteractive()
standard destructor
Definition: MethodBase.cxx:159
void ClearGraphs()
This function sets the point number to 0 for all graphs.
Definition: MethodBase.cxx:198
void AddPoint(Double_t x, Double_t y1, Double_t y2)
This function is used only in 2 TGraph case, and it will add new data points to graphs.
Definition: MethodBase.cxx:212
Virtual base Class for all MVA method.
Definition: MethodBase.h:109
TDirectory * MethodBaseDir() const
returns the ROOT directory where all instances of the corresponding MVA method are stored
virtual Double_t GetKSTrainingVsTest(Char_t SorB, TString opt="X")
MethodBase(const TString &jobName, Types::EMVA methodType, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
standard constructor
Definition: MethodBase.cxx:242
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
void ReadClassesFromXML(void *clsnode)
read number of classes from XML
void SetWeightFileDir(TString fileDir)
set directory of weight file
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables,...
void DeclareBaseOptions()
define the options (their key words) that can be set in the option string here the options valid for ...
Definition: MethodBase.cxx:514
virtual void TestRegression(Double_t &bias, Double_t &biasT, Double_t &dev, Double_t &devT, Double_t &rms, Double_t &rmsT, Double_t &mInf, Double_t &mInfT, Double_t &corr, Types::ETreeType type)
calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
Definition: MethodBase.cxx:982
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
Definition: MethodBase.cxx:601
virtual Double_t GetSignificance() const
compute significance of mean difference
virtual Double_t GetProba(const Event *ev)
const char * GetName() const
Definition: MethodBase.h:325
virtual TMatrixD GetMulticlassConfusionMatrix(Double_t effB, Types::ETreeType type)
Construct a confusion matrix for a multiclass classifier.
void PrintHelpMessage() const
prints out method-specific help method
virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype)
writes all MVA evaluation histograms to file
virtual void TestMulticlass()
test multiclass classification
const std::vector< TMVA::Event * > & GetEventCollection(Types::ETreeType type)
returns the event collection (i.e. the dataset) TRANSFORMED using the classifier's specific variable transformation (e.g. Decorr or Decorr:Gauss:Decorr)
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:411
TDirectory * BaseDir() const
returns the ROOT directory where info/histograms etc of the corresponding MVA method instance are sto...
virtual std::vector< Float_t > GetMulticlassEfficiency(std::vector< std::vector< Float_t > > &purity)
void AddInfoItem(void *gi, const TString &name, const TString &value) const
xml writing
virtual void AddClassifierOutputProb(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:941
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
TString GetTrainingTMVAVersionString() const
calculates the TMVA version string from the training version code on the fly
void Statistics(Types::ETreeType treeType, const TString &theVarName, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &)
calculates rms,mean, xmin, xmax of the event variable this can be either done for the variables as th...
Bool_t GetLine(std::istream &fin, char *buf)
reads one line from the input stream checks for certain keywords and interprets the line if keywords ...
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:428
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
Definition: MethodBase.cxx:899
virtual Bool_t IsSignalLike()
uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation) for...
Definition: MethodBase.cxx:859
virtual ~MethodBase()
destructor
Definition: MethodBase.cxx:369
virtual Double_t GetMaximumSignificance(Double_t SignalEvents, Double_t BackgroundEvents, Double_t &optimal_significance_value) const
plot significance, , curve for given number of signal and background events; returns cut for maximum ...
virtual Double_t GetTrainingEfficiency(const TString &)
void SetWeightFileName(TString)
set the weight file name (deprecated)
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
TString GetWeightFileName() const
retrieve weight file name
virtual void TestClassification()
initialization
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
virtual void WriteMonitoringHistosToFile() const
write special monitoring histograms to file dummy implementation here --------------—
virtual void AddRegressionOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:749
void InitBase()
default initialization called by all constructors
Definition: MethodBase.cxx:446
virtual void GetRegressionDeviation(UInt_t tgtNum, Types::ETreeType type, Double_t &stddev, Double_t &stddev90Percent) const
Definition: MethodBase.cxx:729
void ReadStateFromXMLString(const char *xmlstr)
for reading from memory
void CreateMVAPdfs()
Create PDFs of the MVA output variables.
TString GetTrainingROOTVersionString() const
calculates the ROOT version string from the training version code on the fly
virtual Double_t GetValueForRoot(Double_t)
returns efficiency as function of cut
void ReadStateFromFile()
Function to write options and weights to file.
void WriteVarsToStream(std::ostream &tf, const TString &prefix="") const
write the list of variables (name, min, max) for a given data transformation method to the stream
void ReadVarsFromStream(std::istream &istr)
Read the variables (name, min, max) for a given data transformation method from the stream.
void ReadSpectatorsFromXML(void *specnode)
read spectator info from XML
virtual Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)=0
void SetTestvarName(const TString &v="")
Definition: MethodBase.h:332
void ReadVariablesFromXML(void *varnode)
read variable info from XML
virtual std::map< TString, Double_t > OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA")
call the Optimizer with the set of parameters and ranges that are meant to be tuned.
Definition: MethodBase.cxx:628
virtual std::vector< Float_t > GetMulticlassTrainingEfficiency(std::vector< std::vector< Float_t > > &purity)
void WriteStateToStream(std::ostream &tf) const
general method used in writing the header of the weight files where the used variables,...
virtual Double_t GetRarity(Double_t mvaVal, Types::ESBType reftype=Types::kBackground) const
compute rarity:
virtual void SetTuneParameters(std::map< TString, Double_t > tuneParameters)
set the tuning parameters according to the argument This is just a dummy .
Definition: MethodBase.cxx:649
void ReadStateFromStream(std::istream &tf)
read the header from the weight files of the different MVA methods
void AddVarsXMLTo(void *parent) const
write variable info to XML
void AddTargetsXMLTo(void *parent) const
write target info to XML
void ReadTargetsFromXML(void *tarnode)
read target info from XML
void ProcessBaseOptions()
the option string is decoded, for available options see "DeclareOptions"
Definition: MethodBase.cxx:545
void ReadStateFromXML(void *parent)
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:841
void WriteStateToFile() const
write options and weights to file note that each one text file for the main configuration information...
void AddClassesXMLTo(void *parent) const
write class info to XML
virtual void AddClassifierOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:873
void AddSpectatorsXMLTo(void *parent) const
write spectator info to XML
virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const
calculate the area (integral) under the ROC curve as a overall quality measure of the classification
virtual void AddMulticlassOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:799
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:438
void SetSource(const std::string &source)
Definition: MsgLogger.h:70
PDF wrapper for histograms; uses user-defined spline interpolation.
Definition: PDF.h:63
Double_t GetXmin() const
Definition: PDF.h:104
Double_t GetXmax() const
Definition: PDF.h:105
Double_t GetVal(Double_t x) const
returns value PDF(x)
Definition: PDF.cxx:704
@ kSpline3
Definition: PDF.h:70
@ kSpline2
Definition: PDF.h:70
Double_t GetIntegral(Double_t xmin, Double_t xmax)
computes PDF integral within given ranges
Definition: PDF.cxx:657
Class that is the base-class for a vector of result.
std::vector< Bool_t > * GetValueVectorTypes()
void SetValue(Float_t value, Int_t ievt)
set MVA response
std::vector< Float_t > * GetValueVector()
Class which takes the results of a multiclass classification.
TMatrixD GetConfusionMatrix(Double_t effB)
Returns a confusion matrix where each class is pitted against each other.
Float_t GetAchievablePur(UInt_t cls)
std::vector< Double_t > GetBestMultiClassCuts(UInt_t targetClass)
calculate the best working point (optimal cut values) for the multiclass classifier
void CreateMulticlassHistos(TString prefix, Int_t nbins, Int_t nbins_high)
this function fills the mva response histos for multiclass classification
Float_t GetAchievableEff(UInt_t cls)
void CreateMulticlassPerformanceHistos(TString prefix)
Create performance graphs for this classifier a multiclass setting.
Class that is the base-class for a vector of result.
Class that is the base-class for a vector of result.
Definition: Results.h:57
Bool_t DoesExist(const TString &alias) const
Returns true if there is an object stored in the result for a given alias, false otherwise.
Definition: Results.cxx:127
TH1 * GetHist(const TString &alias) const
Definition: Results.cxx:136
TList * GetStorage() const
Definition: Results.h:73
void Store(TObject *obj, const char *alias=0)
Definition: Results.cxx:86
Root finding using Brents algorithm (translated from CERNLIB function RZERO)
Definition: RootFinder.h:48
Double_t Root(Double_t refValue)
Root finding using Brents algorithm; taken from CERNLIB function RZERO.
Definition: RootFinder.cxx:72
Linear interpolation of TGraph.
Definition: TSpline1.h:43
Timing information for training and evaluation of MVA methods.
Definition: Timer.h:58
Double_t ElapsedSeconds(void)
computes elapsed time in seconds
Definition: Timer.cxx:125
TString GetElapsedTime(Bool_t Scientific=kTRUE)
returns pretty string with elapsed time
Definition: Timer.cxx:134
void DrawProgressBar(Int_t, const TString &comment="")
draws progress bar in color or B&W caution:
Definition: Timer.cxx:190
void ComputeStat(const std::vector< TMVA::Event * > &, std::vector< Float_t > *, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Int_t signalClass, Bool_t norm=kFALSE)
sanity check
Definition: Tools.cxx:214
TList * ParseFormatLine(TString theString, const char *sep=":")
Parse the string and cut into labels separated by ":".
Definition: Tools.cxx:413
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1174
Double_t GetSeparation(TH1 *S, TH1 *B) const
compute "separation" defined as
Definition: Tools.cxx:133
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1136
Double_t GetMutualInformation(const TH2F &)
Mutual Information method for non-linear correlations estimates in 2D histogram Author: Moritz Backes...
Definition: Tools.cxx:601
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:840
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1162
TXMLEngine & xmlengine()
Definition: Tools.h:270
Bool_t CheckSplines(const TH1 *, const TSpline *)
check quality of splining by comparing splines and histograms in each bin
Definition: Tools.cxx:491
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:337
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:355
Double_t NormHist(TH1 *theHist, Double_t norm=1.0)
normalises histogram
Definition: Tools.cxx:395
Singleton class for Global types used by TMVA.
Definition: Types.h:73
@ kSignal
Definition: Types.h:136
@ kBackground
Definition: Types.h:137
@ kLikelihood
Definition: Types.h:81
@ kHMatrix
Definition: Types.h:83
EAnalysisType
Definition: Types.h:127
@ kMulticlass
Definition: Types.h:130
@ kNoAnalysisType
Definition: Types.h:131
@ kClassification
Definition: Types.h:128
@ kMaxAnalysisType
Definition: Types.h:132
@ kRegression
Definition: Types.h:129
@ kTraining
Definition: Types.h:144
@ kTesting
Definition: Types.h:145
Linear interpolation class.
Gaussian Transformation of input variables.
Class for type info of MVA input variable.
Definition: VariableInfo.h:47
void ReadFromXML(void *varnode)
read VariableInfo from stream
const TString & GetExpression() const
Definition: VariableInfo.h:57
char GetVarType() const
Definition: VariableInfo.h:61
void ReadFromStream(std::istream &istr)
read VariableInfo from stream
void AddToXML(void *varnode)
write class to XML
void SetExternalLink(void *p)
Definition: VariableInfo.h:73
void * GetExternalLink() const
Definition: VariableInfo.h:81
void BuildTransformationFromVarInfo(const std::vector< TMVA::VariableInfo > &var)
this method is only used when building a normalization transformation from old text files in this cas...
Linear interpolation class.
Linear interpolation class.
virtual void ReadTransformationFromStream(std::istream &istr, const TString &classname="")=0
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:35
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
Collectable string class.
Definition: TObjString.h:28
virtual Int_t Write(const char *name=0, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
Definition: TObject.cxx:785
Basic string class.
Definition: TString.h:131
Ssiz_t Length() const
Definition: TString.h:405
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1125
Int_t Atoi() const
Return integer value of string.
Definition: TString.cxx:1921
Bool_t EndsWith(const char *pat, ECaseCompare cmp=kExact) const
Return true if string ends with the specified string.
Definition: TString.cxx:2177
TSubString Strip(EStripType s=kTrailing, char c=' ') const
Return a substring of self stripped at beginning and/or end.
Definition: TString.cxx:1106
const char * Data() const
Definition: TString.h:364
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:687
@ kLeading
Definition: TString.h:262
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition: TString.cxx:892
Bool_t IsNull() const
Definition: TString.h:402
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition: TString.cxx:2311
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition: TString.h:634
virtual const char * GetBuildNode() const
Return the build node name.
Definition: TSystem.cxx:3828
virtual int MakeDirectory(const char *name)
Make a directory.
Definition: TSystem.cxx:834
virtual const char * WorkingDirectory()
Return working directory.
Definition: TSystem.cxx:878
virtual UserGroup_t * GetUserInfo(Int_t uid)
Returns all user info in the UserGroup_t structure.
Definition: TSystem.cxx:1588
void SaveDoc(XMLDocPointer_t xmldoc, const char *filename, Int_t layout=1)
store document content to file if layout<=0, no any spaces or newlines will be placed between xmlnode...
void FreeDoc(XMLDocPointer_t xmldoc)
frees allocated document data and deletes document itself
XMLNodePointer_t DocGetRootElement(XMLDocPointer_t xmldoc)
returns root node of document
XMLDocPointer_t NewDoc(const char *version="1.0")
creates new xml document with provided version
XMLDocPointer_t ParseFile(const char *filename, Int_t maxbuf=100000)
Parses content of file and tries to produce xml structures.
XMLDocPointer_t ParseString(const char *xmlstring)
parses content of string and tries to produce xml structures
void DocSetRootElement(XMLDocPointer_t xmldoc, XMLNodePointer_t xmlnode)
set main (root) node for document
TLine * line
Double_t x[n]
Definition: legend1.C:17
TH1F * h1
Definition: legend1.C:5
std::string GetMethodName(TCppMethod_t)
Definition: Cppyy.cxx:753
std::string GetName(const std::string &scope_name)
Definition: Cppyy.cxx:146
void Init(TClassEdit::TInterpreterLookupHelper *helper)
Definition: TClassEdit.cxx:144
static constexpr double s
static constexpr double m2
Config & gConfig()
Tools & gTools()
void CreateVariableTransforms(const TString &trafoDefinition, TMVA::DataSetInfo &dataInfo, TMVA::TransformationHandler &transformationHandler, TMVA::MsgLogger &log)
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Short_t Max(Short_t a, Short_t b)
Definition: TMathBase.h:212
Double_t Log(Double_t x)
Definition: TMath.h:748
Double_t Sqrt(Double_t x)
Definition: TMath.h:679
Short_t Min(Short_t a, Short_t b)
Definition: TMathBase.h:180
Short_t Abs(Short_t d)
Definition: TMathBase.h:120
TString fUser
Definition: TSystem.h:142
auto * a
Definition: textangle.C:12
REAL epsilon
Definition: triangle.c:617