MethodBase.h
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodBase *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Virtual base class for all MVA methods *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Peter Speckmayer <peter.speckmayer@cern.ch> - CERN, Switzerland *
16 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
17 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
19 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
20 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
21 * *
22 * Copyright (c) 2005-2011: *
23 * CERN, Switzerland *
24 * U. of Victoria, Canada *
25 * MPI-K Heidelberg, Germany *
26 * U. of Bonn, Germany *
27 * *
28 * Redistribution and use in source and binary forms, with or without *
29 * modification, are permitted according to the terms listed in LICENSE *
30 * (http://tmva.sourceforge.net/LICENSE) *
31 **********************************************************************************/
32
33#ifndef ROOT_TMVA_MethodBase
34#define ROOT_TMVA_MethodBase
35
36//////////////////////////////////////////////////////////////////////////
37// //
38// MethodBase //
39// //
40// Virtual base class for all TMVA methods //
41// //
42//////////////////////////////////////////////////////////////////////////
43
44#include <iosfwd>
45#include <vector>
46#include <map>
47#include "assert.h"
48
49#include "TString.h"
50
51#include "TMVA/IMethod.h"
52#include "TMVA/Configurable.h"
53#include "TMVA/Types.h"
54#include "TMVA/DataSet.h"
55#include "TMVA/Event.h"
57#include <TMVA/Results.h>
59
60#include <TFile.h>
61
62class TGraph;
63class TTree;
64class TDirectory;
65class TSpline;
66class TH1F;
67class TH1D;
68class TMultiGraph;
69
70/*! \class TMVA::IPythonInteractive
71\ingroup TMVA
72
73This helper class is needed by JsMVA to track errors during training in a
74Jupyter notebook. It is only initialized in a Jupyter notebook context.
75At initialization a list of titles is specified, and a TGraph is created for each title.
76New data points can be added to all TGraphs at once. The graphs are collected in a
77TMultiGraph, which is retrieved during interactive training
78and plotted with JsROOT.
79*/
80
81namespace TMVA {
82
83 class Ranking;
84 class PDF;
85 class TSpline1;
86 class MethodCuts;
87 class MethodBoost;
88 class DataSetInfo;
89 namespace Experimental {
90 class Classification;
91 }
92 class TrainingHistory;
93
95 public:
98 void Init(std::vector<TString>& graphTitles);
99 void ClearGraphs();
100 void AddPoint(Double_t x, Double_t y1, Double_t y2);
101 void AddPoint(std::vector<Double_t>& dat);
102 inline TMultiGraph* Get() {return fMultiGraph;}
103 inline bool NotInitialized(){ return fNumGraphs==0;};
104 private:
106 std::vector<TGraph*> fGraphs;
109 };
110
111 class MethodBase : virtual public IMethod, public Configurable {
112
113 friend class CrossValidation;
114 friend class Factory;
115 friend class RootFinder;
116 friend class MethodBoost;
119
120 public:
121
123
124 // standard constructor
125 MethodBase( const TString& jobName,
126 Types::EMVA methodType,
127 const TString& methodTitle,
128 DataSetInfo& dsi,
129 const TString& theOption = "" );
130
131 // constructor used for Testing + Application of the MVA, only (no training),
132 // using given weight file
133 MethodBase( Types::EMVA methodType,
134 DataSetInfo& dsi,
135 const TString& weightFile );
136
137 // default destructor
138 virtual ~MethodBase();
139
140 // declaration, processing and checking of configuration options
141 void SetupMethod();
142 void ProcessSetup();
143 virtual void CheckSetup(); // may be overridden by derived classes
144
145 // ---------- main training and testing methods ------------------------------
146
147 // prepare tree branch with the method's discriminating variable
149
150 // performs classifier training
151 // calls methods Train() implemented by derived classes
152 void TrainMethod();
153
154 // optimize tuning parameters
155 virtual std::map<TString,Double_t> OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA");
156 virtual void SetTuneParameters(std::map<TString,Double_t> tuneParameters);
157
158 virtual void Train() = 0;
159
160 // store and retrieve time used for training
161 void SetTrainTime( Double_t trainTime ) { fTrainTime = trainTime; }
162 Double_t GetTrainTime() const { return fTrainTime; }
163
164 // store and retrieve time used for testing
165 void SetTestTime ( Double_t testTime ) { fTestTime = testTime; }
166 Double_t GetTestTime () const { return fTestTime; }
167
168 // performs classifier testing
169 virtual void TestClassification();
170 virtual Double_t GetKSTrainingVsTest(Char_t SorB, TString opt="X");
171
172 // performs multiclass classifier testing
173 virtual void TestMulticlass();
174
175 // performs regression testing
176 virtual void TestRegression( Double_t& bias, Double_t& biasT,
177 Double_t& dev, Double_t& devT,
178 Double_t& rms, Double_t& rmsT,
179 Double_t& mInf, Double_t& mInfT, // mutual information
180 Double_t& corr,
182
183 // options treatment
184 virtual void Init() = 0;
185 virtual void DeclareOptions() = 0;
186 virtual void ProcessOptions() = 0;
187 virtual void DeclareCompatibilityOptions(); // declaration of past options
188
189 // reset the method to its state right after instantiation, i.e. as if it had not been trained
190 // virtual void Reset() = 0;
191 // for the moment a dummy default (which does not actually reset anything) is provided, so that
192 // compilation and running without parameter optimisation remain possible
193 virtual void Reset(){return;}
194
195 // classifier response:
196 // some methods may return a per-event error estimate
197 // error calculation is skipped if the error pointers are null
198 virtual Double_t GetMvaValue( Double_t* errLower = 0, Double_t* errUpper = 0) = 0;
199
200 // signal/background classification response
201 Double_t GetMvaValue( const TMVA::Event* const ev, Double_t* err = 0, Double_t* errUpper = 0 );
202
203 protected:
204 // helper function to set errors to -1
205 void NoErrorCalc(Double_t* const err, Double_t* const errUpper);
206
207 // signal/background classification response for all events in the current data set
208 virtual std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false);
209
210
211 public:
212 // regression response
213 const std::vector<Float_t>& GetRegressionValues(const TMVA::Event* const ev){
214 fTmpEvent = ev;
215 const std::vector<Float_t>* ptr = &GetRegressionValues();
216 fTmpEvent = 0;
217 return (*ptr);
218 }
219
220 virtual const std::vector<Float_t>& GetRegressionValues() {
221 std::vector<Float_t>* ptr = new std::vector<Float_t>(0);
222 return (*ptr);
223 }
224
225 // multiclass classification response
226 virtual const std::vector<Float_t>& GetMulticlassValues() {
227 std::vector<Float_t>* ptr = new std::vector<Float_t>(0);
228 return (*ptr);
229 }
230
231 // Training history
232 virtual const std::vector<Float_t>& GetTrainingHistory(const char* /*name*/ ) {
233 std::vector<Float_t>* ptr = new std::vector<Float_t>(0);
234 return (*ptr);
235 }
236
237 // probability that the classifier response (mvaVal) is signal (requires the "CreateMVAPdfs" option to be set)
238 virtual Double_t GetProba( const Event *ev); // the simple version: computes mvaVal automatically and uses the same sig/bkg ratio as the training sample (typically 50/50 with NormMode=EqualNumEvents, but it can differ)
239 virtual Double_t GetProba( Double_t mvaVal, Double_t ap_sig );
240
241 // Rarity of classifier response (signal or background (default) is uniform in [0,1])
242 virtual Double_t GetRarity( Double_t mvaVal, Types::ESBType reftype = Types::kBackground ) const;
243
244 // create ranking
245 virtual const Ranking* CreateRanking() = 0;
246
247 // make ROOT-independent C++ class
248 virtual void MakeClass( const TString& classFileName = TString("") ) const;
249
250 // print help message
251 void PrintHelpMessage() const;
252
253 //
254 // streamer methods for training information (creates "weight" files) --------
255 //
256 public:
257 void WriteStateToFile () const;
258 void ReadStateFromFile ();
259
260 protected:
261 // the actual "weights"
262 virtual void AddWeightsXMLTo ( void* parent ) const = 0;
263 virtual void ReadWeightsFromXML ( void* wghtnode ) = 0;
264 virtual void ReadWeightsFromStream( std::istream& ) = 0; // backward compatibility
265 virtual void ReadWeightsFromStream( TFile& ) {} // backward compatibility
266
267 private:
268 friend class MethodCategory;
270 void WriteStateToXML ( void* parent ) const;
271 void ReadStateFromXML ( void* parent );
272 void WriteStateToStream ( std::ostream& tf ) const; // needed for MakeClass
273 void WriteVarsToStream ( std::ostream& tf, const TString& prefix = "" ) const; // needed for MakeClass
274
275
276 public: // these two need to be public, they are used to read in-memory weight-files
277 void ReadStateFromStream ( std::istream& tf ); // backward compatibility
278 void ReadStateFromStream ( TFile& rf ); // backward compatibility
279 void ReadStateFromXMLString( const char* xmlstr ); // for reading from memory
280
281 private:
282 // the variable information
283 void AddVarsXMLTo ( void* parent ) const;
284 void AddSpectatorsXMLTo ( void* parent ) const;
285 void AddTargetsXMLTo ( void* parent ) const;
286 void AddClassesXMLTo ( void* parent ) const;
287 void ReadVariablesFromXML ( void* varnode );
288 void ReadSpectatorsFromXML( void* specnode);
289 void ReadTargetsFromXML ( void* tarnode );
290 void ReadClassesFromXML ( void* clsnode );
291 void ReadVarsFromStream ( std::istream& istr ); // backward compatibility
292
293 public:
294 // ---------------------------------------------------------------------------
295
296 // write evaluation histograms into target file
297 virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype);
298
299 // write classifier-specific monitoring information to target file
300 virtual void WriteMonitoringHistosToFile() const;
301
302 // ---------- public evaluation methods --------------------------------------
303
304 // individual initialization for testing of each method
305 // overload this one for individual initialisation of the testing,
306 // it is then called automatically within the global "TestInit"
307
308 // variables (and private member functions) for the Evaluation:
309 // get the efficiency. It fills a histogram of efficiency versus background
310 // and returns the single efficiency value requested
311 // in the TString argument (watch the string format)
312 virtual Double_t GetEfficiency( const TString&, Types::ETreeType, Double_t& err );
313 virtual Double_t GetTrainingEfficiency(const TString& );
314 virtual std::vector<Float_t> GetMulticlassEfficiency( std::vector<std::vector<Float_t> >& purity );
315 virtual std::vector<Float_t> GetMulticlassTrainingEfficiency(std::vector<std::vector<Float_t> >& purity );
317 virtual Double_t GetSignificance() const;
318 virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const;
319 virtual Double_t GetROCIntegral(PDF *pdfS=0, PDF *pdfB=0) const;
320 virtual Double_t GetMaximumSignificance( Double_t SignalEvents, Double_t BackgroundEvents,
321 Double_t& optimal_significance_value ) const;
322 virtual Double_t GetSeparation( TH1*, TH1* ) const;
323 virtual Double_t GetSeparation( PDF* pdfS = 0, PDF* pdfB = 0 ) const;
324
325 virtual void GetRegressionDeviation(UInt_t tgtNum, Types::ETreeType type, Double_t& stddev,Double_t& stddev90Percent ) const;
326 // ---------- public accessors -----------------------------------------------
327
328 // classifier naming (a lot of names ... aren't they ;-)
329 const TString& GetJobName () const { return fJobName; }
330 const TString& GetMethodName () const { return fMethodName; }
333 const char* GetName () const { return fMethodName.Data(); }
334 const TString& GetTestvarName () const { return fTestvar; }
335 const TString GetProbaName () const { return fTestvar + "_Proba"; }
337
338 // build classifier name in Test tree
339 // MVA prefix (e.g., "MVA_")
340 void SetTestvarName ( const TString & v="" ) { fTestvar = (v=="") ? ("MVA_" + GetMethodName()) : v; }
341
342 // number of input variables used by classifier
343 UInt_t GetNvar() const { return DataInfo().GetNVariables(); }
345 UInt_t GetNTargets() const { return DataInfo().GetNTargets(); };
346
347 // internal names and expressions of input variables
348 const TString& GetInputVar ( Int_t i ) const { return DataInfo().GetVariableInfo(i).GetInternalName(); }
349 const TString& GetInputLabel( Int_t i ) const { return DataInfo().GetVariableInfo(i).GetLabel(); }
350 const char * GetInputTitle( Int_t i ) const { return DataInfo().GetVariableInfo(i).GetTitle(); }
351
352 // normalisation and limit accessors
353 Double_t GetMean( Int_t ivar ) const { return GetTransformationHandler().GetMean(ivar); }
354 Double_t GetRMS ( Int_t ivar ) const { return GetTransformationHandler().GetRMS(ivar); }
355 Double_t GetXmin( Int_t ivar ) const { return GetTransformationHandler().GetMin(ivar); }
356 Double_t GetXmax( Int_t ivar ) const { return GetTransformationHandler().GetMax(ivar); }
357
358 // sets the minimum requirement on the MVA output to declare an event signal-like
361
362 // sets the minimum requirement on the MVA output to declare an event signal-like
365
366 // pointers to ROOT directories
367 TDirectory* BaseDir() const;
368 TDirectory* MethodBaseDir() const;
369 TFile* GetFile() const {return fFile;}
370
371 void SetMethodDir ( TDirectory* methodDir ) { fBaseDir = fMethodBaseDir = methodDir; }
372 void SetBaseDir( TDirectory* methodDir ){ fBaseDir = methodDir; }
373 void SetMethodBaseDir( TDirectory* methodDir ){ fMethodBaseDir = methodDir; }
375
376 //Silent file
377 void SetSilentFile(Bool_t status) {fSilentFile=status;}
379
380 //Model Persistence
381 void SetModelPersistence(Bool_t status){fModelPersistence=status;}//added support to create/remove dir here if it exists or not
383
384 // the TMVA version can be obtained and checked using
385 // if (GetTrainingTMVAVersionCode()>TMVA_VERSION(3,7,2)) {...}
386 // or
387 // if (GetTrainingROOTVersionCode()>ROOT_VERSION(5,15,5)) {...}
392
394 {
395 if(fTransformationPointer && takeReroutedIfAvailable) return *fTransformationPointer; else return fTransformation;
396 }
397 const TransformationHandler& GetTransformationHandler(Bool_t takeReroutedIfAvailable=true) const
398 {
399 if(fTransformationPointer && takeReroutedIfAvailable) return *fTransformationPointer; else return fTransformation;
400 }
401
402 void RerouteTransformationHandler (TransformationHandler* fTargetTransformation) { fTransformationPointer=fTargetTransformation; }
403
404 // ---------- event accessors ------------------------------------------------
405
406 // returns reference to data set
407 // NOTE: this DataSet is the "original" dataset, i.e. the one seen by ALL Classifiers WITHOUT transformation
408 DataSet* Data() const { return DataInfo().GetDataSet(); }
409 DataSetInfo& DataInfo() const { return fDataSetInfo; }
410
411 mutable const Event* fTmpEvent; //! temporary event when testing on a different DataSet than the own one
412
413 // event reference and update
414 // NOTE: these Event accessors make sure that you get the events transformed according to the
415 // transformation chosen for the particular classifier
416 UInt_t GetNEvents () const { return Data()->GetNEvents(); }
417 const Event* GetEvent () const;
418 const Event* GetEvent ( const TMVA::Event* ev ) const;
419 const Event* GetEvent ( Long64_t ievt ) const;
420 const Event* GetEvent ( Long64_t ievt , Types::ETreeType type ) const;
421 const Event* GetTrainingEvent( Long64_t ievt ) const;
422 const Event* GetTestingEvent ( Long64_t ievt ) const;
423 const std::vector<TMVA::Event*>& GetEventCollection( Types::ETreeType type );
424
426 // ---------- public auxiliary methods ---------------------------------------
427
428 // this method is used to decide whether an event is signal- or background-like
429 // the reference cut "xC" is taken to be where
430 // Int_[-oo,xC] { PDF_S(x) dx } = Int_[xC,+oo] { PDF_B(x) dx }
431 virtual Bool_t IsSignalLike();
432 virtual Bool_t IsSignalLike(Double_t mvaVal);
433
434
435 Bool_t HasMVAPdfs() const { return fHasMVAPdfs; }
440
441 // setter method for suppressing writing to XML and writing of standalone classes
442 void DisableWriting(Bool_t setter){ fModelPersistence = setter?kFALSE:kTRUE; }//DEPRECATED
443
444 protected:
445 // helper variables for JsMVA
447 bool fExitFromTraining = false;
449
450 public:
451
452 // initializing IPythonInteractive class (for JsMVA only)
454 if (fInteractive) delete fInteractive;
456 }
457
458 // get training errors (for JsMVA only)
460
461 // stops the training process (for JsMVA only)
462 inline void ExitFromTraining(){
463 fExitFromTraining = true;
464 }
465
466 // checks if the training ended (for JsMVA only)
467 inline bool TrainingEnded(){
469 delete fInteractive;
470 fInteractive = nullptr;
471 }
472 return fExitFromTraining;
473 }
474
475 // get fIPyMaxIter
476 inline UInt_t GetMaxIter(){ return fIPyMaxIter; }
477
478 // get fIPyCurrentIter
480
481 protected:
482
483 // ---------- protected accessors -------------------------------------------
484
485 //TDirectory* LocalTDir() const { return Data().LocalRootDir(); }
486
487 // weight file name and directory (given by global config variable)
489
490 const TString& GetWeightFileDir() const { return fFileDir; }
491 void SetWeightFileDir( TString fileDir );
492
493 // are input variables normalised ?
494 Bool_t IsNormalised() const { return fNormalise; }
495 void SetNormalised( Bool_t norm ) { fNormalise = norm; }
496
497 // set number of input variables (only used by MethodCuts, could perhaps be removed)
498 // void SetNvar( Int_t n ) { fNvar = n; }
499
500 // verbose and help flags
501 Bool_t Verbose() const { return fVerbose; }
502 Bool_t Help () const { return fHelp; }
503
504 // ---------- protected event and tree accessors -----------------------------
505
506 // names of input variables (if the original names are expressions, they are
507 // transformed into regexps)
508 const TString& GetInternalVarName( Int_t ivar ) const { return (*fInputVars)[ivar]; }
509 const TString& GetOriginalVarName( Int_t ivar ) const { return DataInfo().GetVariableInfo(ivar).GetExpression(); }
510
511 Bool_t HasTrainingTree() const { return Data()->GetNTrainingEvents() != 0; }
512
513 // ---------- protected auxiliary methods ------------------------------------
514
515 protected:
516
517 // make ROOT-independent C++ class for classifier response (classifier-specific implementation)
518 virtual void MakeClassSpecific( std::ostream&, const TString& = "" ) const {}
519
520 // header and auxiliary classes
521 virtual void MakeClassSpecificHeader( std::ostream&, const TString& = "" ) const {}
522
523 // static pointer to this object - required for ROOT finder (to be solved differently)(solved by Omar)
524 //static MethodBase* GetThisBase();
525
526 // some basic statistical analysis
527 void Statistics( Types::ETreeType treeType, const TString& theVarName,
530
531 // if TRUE, write weights only to text files
532 Bool_t TxtWeightsOnly() const { return kTRUE; }
533
534 protected:
535
536 // access to event information that needs method-specific information
537
539
540 private:
541
542 // ---------- private definitions --------------------------------------------
543 // Initialisation
544 void InitBase();
545 void DeclareBaseOptions();
546 void ProcessBaseOptions();
547
548 // used in efficiency computation
551
552 // ---------- private accessors ---------------------------------------------
553
554 // reset required for RootFinder
556
557 // ---------- private auxiliary methods --------------------------------------
558
559 // PDFs for classifier response (required to compute signal probability and Rarity)
560 void CreateMVAPdfs();
561
562 // for root finder
563 // virtual method used to find the root
564 virtual Double_t GetValueForRoot ( Double_t ); // implementation
565
566 // used for file parsing
567 Bool_t GetLine( std::istream& fin, char * buf );
568
569 // fill test tree with classification or regression results
574
575 private:
576
577 void AddInfoItem( void* gi, const TString& name,
578 const TString& value) const;
579
580 // ========== class members ==================================================
581
582 protected:
583
584 // direct accessors
585 Ranking* fRanking; // pointer to ranking object (created by derived classifiers)
586 std::vector<TString>* fInputVars; // vector of input variables used in MVA
587
588 // histogram binning
589 Int_t fNbins; // number of bins in input variable histograms
590 Int_t fNbinsMVAoutput; // number of bins in MVA output histograms
591 Int_t fNbinsH; // number of bins in evaluation histograms
592
593 Types::EAnalysisType fAnalysisType; // analysis type (e.g. classification, regression, multiclass)
594
595 std::vector<Float_t>* fRegressionReturnVal; // holds the return-values for the regression
596 std::vector<Float_t>* fMulticlassReturnVal; // holds the return-values for the multiclass classification
597
598 private:
599
600 // MethodCuts redefines some of the evaluation variables and histograms -> must access private members
601 friend class MethodCuts;
602
603
604 // data sets
605 DataSetInfo& fDataSetInfo; //! the data set information (sometimes needed)
606
607 Double_t fSignalReferenceCut; // minimum requirement on the MVA output to declare an event signal-like
608 Double_t fSignalReferenceCutOrientation; // orientation (sign) applied when comparing the MVA output with the signal reference cut
609 Types::ESBType fVariableTransformType; // this is the event type (sig or bgd) assumed for variable transform
610
611 // naming and versioning
612 TString fJobName; // name of job -> user defined, appears in weight files
613 TString fMethodName; // name of the method (set in derived class)
614 Types::EMVA fMethodType; // type of method (set in derived class)
615 TString fTestvar; // variable used in evaluation, etc (mostly the MVA)
616 UInt_t fTMVATrainingVersion; // TMVA version used for training
617 UInt_t fROOTTrainingVersion; // ROOT version used for training
618 Bool_t fConstructedFromWeightFile; // is it obtained from weight file?
619
620 // Directory structure: dataloader/fMethodBaseDir/fBaseDir
621 // where the first directory name is defined by the method type
622 // and the second is user supplied (the title given in Factory::BookMethod())
623 TDirectory* fBaseDir; // base directory for the instance, needed to know where to jump back from localDir
624 mutable TDirectory* fMethodBaseDir; // base directory for the method
625 //this will be the next way to save results
627
628 //SilentFile
630 //Model Persistence
632
633 TString fParentDir; // method parent name, like booster name
634
635 TString fFileDir; // unix sub-directory for weight files (default: DataLoader's Name + "weights")
636 TString fWeightFile; // weight file name
637
638 private:
639
640 TH1* fEffS; // efficiency histogram for rootfinder
641
642 PDF* fDefaultPDF; // default PDF definitions
643 PDF* fMVAPdfS; // signal MVA PDF
644 PDF* fMVAPdfB; // background MVA PDF
645
646 // TH1D* fmvaS; // PDFs of MVA distribution (signal)
647 // TH1D* fmvaB; // PDFs of MVA distribution (background)
648 PDF* fSplS; // PDFs of MVA distribution (signal)
649 PDF* fSplB; // PDFs of MVA distribution (background)
650 TSpline* fSpleffBvsS; // splines for signal eff. versus background eff.
651
652 PDF* fSplTrainS; // PDFs of training MVA distribution (signal)
653 PDF* fSplTrainB; // PDFs of training MVA distribution (background)
654 TSpline* fSplTrainEffBvsS; // splines for training signal eff. versus background eff.
655
656 private:
657
658 // basic statistics quantities of MVA
659 Double_t fMeanS; // mean (signal)
660 Double_t fMeanB; // mean (background)
661 Double_t fRmsS; // RMS (signal)
662 Double_t fRmsB; // RMS (background)
663 Double_t fXmin; // minimum (signal and background)
664 Double_t fXmax; // maximum (signal and background)
665
666 // variable preprocessing
667 TString fVarTransformString; // labels variable transform method
668
669 TransformationHandler* fTransformationPointer; // pointer to the rest of transformations
670 TransformationHandler fTransformation; // the list of transformations
671
672
673 // help and verbosity
674 Bool_t fVerbose; // verbose flag
675 TString fVerbosityLevelString; // verbosity level (user input string)
676 EMsgType fVerbosityLevel; // verbosity level
677 Bool_t fHelp; // help flag
678 Bool_t fHasMVAPdfs; // MVA Pdfs are created for this classifier
679
680 Bool_t fIgnoreNegWeightsInTraining;// If true, events with negative weights are not used in training
681
682 protected:
683
685
686 // for signal/background
687 UInt_t fSignalClass; // index of the Signal-class
688 UInt_t fBackgroundClass; // index of the Background-class
689
690 private:
691
692 // timing variables
693 Double_t fTrainTime; // for timing measurements
694 Double_t fTestTime; // for timing measurements
695
696 // orientation of cut: depends on signal and background mean values
697 ECutOrientation fCutOrientation; // +1 if Sig>Bkg, -1 otherwise
698
699 // for root finder
700 TSpline1* fSplRefS; // helper splines for RootFinder (signal)
701 TSpline1* fSplRefB; // helper splines for RootFinder (background)
702
703 TSpline1* fSplTrainRefS; // helper splines for RootFinder (signal)
704 TSpline1* fSplTrainRefB; // helper splines for RootFinder (background)
705
706 mutable std::vector<const std::vector<TMVA::Event*>*> fEventCollections; // if the method needs the complete event collection, the transformed event collection is stored here.
707
708 public:
709 Bool_t fSetupCompleted; // is method setup
710
711 private:
712
713 // This is a workaround for OSx where static thread_local data members are
714 // not supported. The C++ solution would indeed be the following:
715// static MethodBase*& GetThisBaseThreadLocal() {TTHREAD_TLS(MethodBase*) fgThisBase(nullptr); return fgThisBase; };
716
717 // ===== deprecated options, kept for backward compatibility =====
718 private:
719
720 Bool_t fNormalise; // normalise input variables
721 Bool_t fUseDecorr; // synonym for decorrelation
722 TString fVariableTransformTypeString; // labels variable transform type
723 Bool_t fTxtWeightsOnly; // if TRUE, write weights only to text files
724 Int_t fNbinsMVAPdf; // number of bins used in histogram that creates PDF
725 Int_t fNsmoothMVAPdf; // number of times a histogram is smoothed before creating the PDF
726
727 protected:
729 ClassDef(MethodBase,0); // Virtual base class for all TMVA methods
730
731 };
732} // namespace TMVA
733
734
735
736
737
738
739
740// ========== INLINE FUNCTIONS =========================================================
741
742
743//_______________________________________________________________________
744inline const TMVA::Event* TMVA::MethodBase::GetEvent( const TMVA::Event* ev ) const
745{
747}
748
750{
751 if(fTmpEvent)
752 return GetTransformationHandler().Transform(fTmpEvent);
753 else
754 return GetTransformationHandler().Transform(Data()->GetEvent());
755}
756
758{
759 assert(fTmpEvent==0);
760 return GetTransformationHandler().Transform(Data()->GetEvent(ievt));
761}
762
764{
765 assert(fTmpEvent==0);
766 return GetTransformationHandler().Transform(Data()->GetEvent(ievt, type));
767}
768
770{
771 assert(fTmpEvent==0);
772 return GetEvent(ievt, Types::kTraining);
773}
774
776{
777 assert(fTmpEvent==0);
778 return GetEvent(ievt, Types::kTesting);
779}
780
781#endif
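
The comment above IsSignalLike() defines the reference cut xC as the point where the signal integral below the cut equals the background integral above it. The following is a minimal standalone sketch of that condition, not TMVA code: it assumes Gaussian PDFs of equal width for signal and background, and the names GaussCdf, CutBalance and FindReferenceCut, as well as the choice of a bisection solver, are hypothetical and chosen only for illustration.

```cpp
#include <cassert>
#include <cmath>

// Cumulative distribution function of a Gaussian
// (illustrative stand-in for a real MVA output PDF).
double GaussCdf(double x, double mean, double sigma) {
    return 0.5 * std::erfc(-(x - mean) / (sigma * std::sqrt(2.0)));
}

// Balance function: Int_[-oo,xC] PDF_S(x) dx - Int_[xC,+oo] PDF_B(x) dx.
// It increases monotonically with xC, so it has a single zero.
double CutBalance(double xC, double meanS, double meanB, double sigma) {
    const double signalBelow = GaussCdf(xC, meanS, sigma);
    const double backgroundAbove = 1.0 - GaussCdf(xC, meanB, sigma);
    return signalBelow - backgroundAbove;
}

// Bisection search for the zero of CutBalance inside the bracket [lo, hi].
double FindReferenceCut(double meanS, double meanB, double sigma,
                        double lo, double hi) {
    // The bracket must enclose the sign change.
    assert(CutBalance(lo, meanS, meanB, sigma) < 0.0);
    assert(CutBalance(hi, meanS, meanB, sigma) > 0.0);
    while (hi - lo > 1e-12) {
        const double mid = 0.5 * (lo + hi);
        if (CutBalance(mid, meanS, meanB, sigma) < 0.0) {
            lo = mid;
        } else {
            hi = mid;
        }
    }
    return 0.5 * (lo + hi);
}
```

For equal-width Gaussians the balance point is the midpoint of the two means, so a call such as FindReferenceCut(1.0, -1.0, 1.0, -10.0, 10.0) converges to approximately 0.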
char Char_t
Definition RtypesCore.h:37
const Bool_t kFALSE
Definition RtypesCore.h:92
double Double_t
Definition RtypesCore.h:59
long long Long64_t
Definition RtypesCore.h:73
const Bool_t kTRUE
Definition RtypesCore.h:91
#define ClassDef(name, id)
Definition Rtypes.h:325
char name[80]
Definition TGX11.cxx:110
int type
Definition TGX11.cxx:121
Describe directory structure in memory.
Definition TDirectory.h:45
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition TFile.h:54
A TGraph is an object made of two arrays X and Y with npoints each.
Definition TGraph.h:41
1-D histogram with a double per channel (see TH1 documentation)
Definition TH1.h:618
1-D histogram with a float per channel (see TH1 documentation)
Definition TH1.h:575
TH1 is the base class of all histogram classes in ROOT.
Definition TH1.h:58
Class to perform two class classification.
Class to perform cross validation, splitting the dataloader into folds.
Class that contains all the data information.
Definition DataSetInfo.h:62
UInt_t GetNVariables() const
UInt_t GetNTargets() const
DataSet * GetDataSet() const
returns data set
VariableInfo & GetVariableInfo(Int_t i)
Class that contains all the data information.
Definition DataSet.h:58
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:206
Long64_t GetNTrainingEvents() const
Definition DataSet.h:68
This is the main MVA steering class.
Definition Factory.h:80
Interface for all concrete MVA method implementations.
Definition IMethod.h:53
This class is needed by JsMVA, and it's a helper class for tracking errors during the training in Jup...
Definition MethodBase.h:94
void Init(std::vector< TString > &graphTitles)
This function gets some title and it creates a TGraph for every title.
IPythonInteractive()
standard constructor
TMultiGraph * fMultiGraph
Definition MethodBase.h:105
std::vector< TGraph * > fGraphs
Definition MethodBase.h:106
~IPythonInteractive()
standard destructor
TMultiGraph * Get()
Definition MethodBase.h:102
void ClearGraphs()
This function sets the point number to 0 for all graphs.
void AddPoint(Double_t x, Double_t y1, Double_t y2)
This function is used only in 2 TGraph case, and it will add new data points to graphs.
Virtual base class for all MVA methods.
Definition MethodBase.h:111
TransformationHandler * fTransformationPointer
Definition MethodBase.h:669
virtual void MakeClassSpecificHeader(std::ostream &, const TString &="") const
Definition MethodBase.h:521
TString fVerbosityLevelString
Definition MethodBase.h:675
TDirectory * MethodBaseDir() const
returns the ROOT directory where all instances of the corresponding MVA method are stored
virtual const std::vector< Float_t > & GetRegressionValues()
Definition MethodBase.h:220
const std::vector< Float_t > & GetRegressionValues(const TMVA::Event *const ev)
Definition MethodBase.h:213
virtual void Train()=0
virtual Double_t GetKSTrainingVsTest(Char_t SorB, TString opt="X")
TString fMethodName
Definition MethodBase.h:613
TFile * GetFile() const
Definition MethodBase.h:369
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
Bool_t HasTrainingTree() const
Definition MethodBase.h:511
virtual void DeclareOptions()=0
void SetSilentFile(Bool_t status)
Definition MethodBase.h:377
void ReadClassesFromXML(void *clsnode)
read number of classes from XML
TMultiGraph * GetInteractiveTrainingError()
Definition MethodBase.h:459
void SetWeightFileDir(TString fileDir)
set directory of weight file
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables,...
void DeclareBaseOptions()
define the options (their key words) that can be set in the option string here the options valid for ...
UInt_t GetMaxIter()
Definition MethodBase.h:476
Double_t GetXmin(Int_t ivar) const
Definition MethodBase.h:355
Bool_t Verbose() const
Definition MethodBase.h:501
virtual void TestRegression(Double_t &bias, Double_t &biasT, Double_t &dev, Double_t &devT, Double_t &rms, Double_t &rmsT, Double_t &mInf, Double_t &mInfT, Double_t &corr, Types::ETreeType type)
calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
Double_t GetMean(Int_t ivar) const
Definition MethodBase.h:353
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility; they are hence without any...
TString GetMethodTypeName() const
Definition MethodBase.h:331
Bool_t DoMulticlass() const
Definition MethodBase.h:439
virtual Double_t GetSignificance() const
compute significance of mean difference
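A minimal sketch of such a quantity, assuming it is defined as |<S> - <B>| / sqrt(RMS_S^2 + RMS_B^2). The function name and the choice to return 0 when both spreads vanish are assumptions of this sketch, not taken from the listing above.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper: significance of the difference between the signal
// and background means, in units of the combined RMS.
double meanDifferenceSignificance(double meanS, double rmsS,
                                  double meanB, double rmsB)
{
    assert(rmsS >= 0.0 && rmsB >= 0.0);
    const double rms = std::sqrt(rmsS * rmsS + rmsB * rmsB);
    // Assumed convention: no spread at all yields 0 rather than infinity.
    return rms > 0.0 ? std::fabs(meanS - meanB) / rms : 0.0;
}
```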
void DisableWriting(Bool_t setter)
Definition MethodBase.h:442
Bool_t fTxtWeightsOnly
Definition MethodBase.h:723
virtual void ReadWeightsFromXML(void *wghtnode)=0
virtual Double_t GetProba(const Event *ev)
const char * GetName() const
Definition MethodBase.h:333
virtual const std::vector< Float_t > & GetMulticlassValues()
Definition MethodBase.h:226
TSpline * fSplTrainEffBvsS
Definition MethodBase.h:654
Types::EAnalysisType GetAnalysisType() const
Definition MethodBase.h:437
virtual TMatrixD GetMulticlassConfusionMatrix(Double_t effB, Types::ETreeType type)
Construct a confusion matrix for a multiclass classifier.
UInt_t GetTrainingTMVAVersionCode() const
Definition MethodBase.h:388
UInt_t fTMVATrainingVersion
Definition MethodBase.h:616
Bool_t IsModelPersistence() const
Definition MethodBase.h:382
Types::ESBType fVariableTransformType
Definition MethodBase.h:609
const TString & GetInputVar(Int_t i) const
Definition MethodBase.h:348
void PrintHelpMessage() const
prints out the method-specific help message
void SetMethodDir(TDirectory *methodDir)
Definition MethodBase.h:371
const TString & GetJobName() const
Definition MethodBase.h:329
Bool_t IgnoreEventsWithNegWeightsInTraining() const
Definition MethodBase.h:684
Double_t fTrainTime
Definition MethodBase.h:693
virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype)
writes all MVA evaluation histograms to file
virtual void TestMulticlass()
test multiclass classification
Bool_t fModelPersistence
Definition MethodBase.h:631
const TString & GetTestvarName() const
Definition MethodBase.h:334
const std::vector< TMVA::Event * > & GetEventCollection(Types::ETreeType type)
returns the event collection (i.e.
void SetupMethod()
setup of methods
TDirectory * BaseDir() const
returns the ROOT directory where info/histograms etc of the corresponding MVA method instance are sto...
const Event * GetTestingEvent(Long64_t ievt) const
Definition MethodBase.h:775
virtual std::vector< Float_t > GetMulticlassEfficiency(std::vector< std::vector< Float_t > > &purity)
UInt_t GetNTargets() const
Definition MethodBase.h:345
EMsgType fVerbosityLevel
Definition MethodBase.h:676
const TString GetProbaName() const
Definition MethodBase.h:335
TransformationHandler fTransformation
Definition MethodBase.h:670
virtual void ReadWeightsFromStream(std::istream &)=0
void AddInfoItem(void *gi, const TString &name, const TString &value) const
xml writing
TDirectory * fMethodBaseDir
Definition MethodBase.h:624
virtual void AddClassifierOutputProb(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
virtual void MakeClassSpecific(std::ostream &, const TString &="") const
Definition MethodBase.h:518
Double_t fSignalReferenceCutOrientation
Definition MethodBase.h:608
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
TString GetTrainingTMVAVersionString() const
calculates the TMVA version string from the training version code on the fly
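A sketch of how a packed version code can be turned back into a string, assuming the code is laid out as (major << 16) + (minor << 8) + patch, as ROOT-style version macros do. The function name `versionString` and the dotted output format are assumptions for illustration.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: unpack a 24-bit version code into "major.minor.patch".
std::string versionString(unsigned int code)
{
    assert(code <= 0xFFFFFFu);             // three 8-bit fields expected
    const unsigned int major = (code >> 16) & 0xFFu;
    const unsigned int minor = (code >> 8) & 0xFFu;
    const unsigned int patch = code & 0xFFu;
    return std::to_string(major) + "." + std::to_string(minor) + "." +
           std::to_string(patch);
}
```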
TSpline * fSpleffBvsS
Definition MethodBase.h:650
TString fVariableTransformTypeString
Definition MethodBase.h:722
TSpline1 * fSplTrainRefB
Definition MethodBase.h:704
const TString & GetWeightFileDir() const
Definition MethodBase.h:490
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition MethodBase.h:436
UInt_t GetCurrentIter()
Definition MethodBase.h:479
const TString & GetMethodName() const
Definition MethodBase.h:330
Bool_t TxtWeightsOnly() const
Definition MethodBase.h:532
void Statistics(Types::ETreeType treeType, const TString &theVarName, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &)
calculates rms, mean, xmin, xmax of the event variable; this can be either done for the variables as th...
virtual void ReadWeightsFromStream(TFile &)
Definition MethodBase.h:265
void ExitFromTraining()
Definition MethodBase.h:462
UInt_t GetNEvents() const
temporary event when testing on a different DataSet than the own one
Definition MethodBase.h:416
Bool_t DoRegression() const
Definition MethodBase.h:438
std::vector< Float_t > * fRegressionReturnVal
Definition MethodBase.h:595
std::vector< Float_t > * fMulticlassReturnVal
Definition MethodBase.h:596
Bool_t GetLine(std::istream &fin, char *buf)
reads one line from the input stream, checks for certain keywords, and interprets the line if keywords ...
const Event * GetEvent() const
Definition MethodBase.h:749
TString fVarTransformString
Definition MethodBase.h:667
void ProcessSetup()
process all options; the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
virtual void ProcessOptions()=0
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
virtual Bool_t IsSignalLike()
uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation) for...
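One way such a decision can be written, assuming the orientation is +1 when larger MVA values are more signal-like and -1 otherwise, and that it is applied by multiplying both sides of the comparison. The free function below is a hypothetical stand-in, not the member function itself.

```cpp
#include <cassert>

// Hypothetical helper: decide whether an MVA output lies on the signal
// side of a reference cut, given the cut orientation (+1 or -1).
bool isSignalLike(double mvaValue, double cut, double orientation)
{
    assert(orientation == 1.0 || orientation == -1.0);
    // Multiplying by -1 flips the inequality, so one comparison covers both cases.
    return mvaValue * orientation > cut * orientation;
}
```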
void RerouteTransformationHandler(TransformationHandler *fTargetTransformation)
Definition MethodBase.h:402
virtual ~MethodBase()
destructor
Bool_t HasMVAPdfs() const
Definition MethodBase.h:435
UInt_t fBackgroundClass
Definition MethodBase.h:688
Double_t fTestTime
Definition MethodBase.h:694
virtual Double_t GetMaximumSignificance(Double_t SignalEvents, Double_t BackgroundEvents, Double_t &optimal_significance_value) const
plot the significance curve for a given number of signal and background events; returns cut for maximum ...
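The scan over cut values can be sketched as follows. The figure of merit S/sqrt(S+B) used here is an assumption of this sketch (the formula in the description above did not survive), and `BestCut` and `maximumSignificance` are hypothetical names. The inputs are signal and background efficiencies evaluated at each candidate cut.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical result type: best cut and the significance reached there.
struct BestCut { double cut; double significance; };

// Hypothetical helper: scan candidate cuts and keep the one that maximises
// the assumed figure of merit s / sqrt(s + b).
BestCut maximumSignificance(const std::vector<double>& cuts,
                            const std::vector<double>& effS,
                            const std::vector<double>& effB,
                            double nSignal, double nBackground)
{
    assert(cuts.size() == effS.size() && cuts.size() == effB.size());
    BestCut best{0.0, 0.0};
    for (std::size_t i = 0; i < cuts.size(); ++i) {
        const double s = nSignal * effS[i];        // expected signal after cut
        const double b = nBackground * effB[i];    // expected background after cut
        if (s + b <= 0.0) continue;                // nothing left, undefined
        const double z = s / std::sqrt(s + b);
        if (z > best.significance) best = BestCut{cuts[i], z};
    }
    return best;
}
```

Returning the cut and the significance together mirrors the description, where the cut is the return value and the maximum significance comes back through a reference argument.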
virtual Double_t GetTrainingEfficiency(const TString &)
void SetWeightFileName(TString)
set the weight file name (deprecated)
DataSetInfo & DataInfo() const
Definition MethodBase.h:409
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
TString GetWeightFileName() const
retrieve weight file name
void SetTestTime(Double_t testTime)
Definition MethodBase.h:165
Types::EMVA fMethodType
Definition MethodBase.h:614
virtual void TestClassification()
initialization
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
virtual void WriteMonitoringHistosToFile() const
write special monitoring histograms to file; dummy implementation here
UInt_t GetNVariables() const
Definition MethodBase.h:344
Types::EAnalysisType fAnalysisType
Definition MethodBase.h:593
virtual void AddRegressionOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
void InitBase()
default initialization called by all constructors
std::vector< const std::vector< TMVA::Event * > * > fEventCollections
Definition MethodBase.h:706
virtual void GetRegressionDeviation(UInt_t tgtNum, Types::ETreeType type, Double_t &stddev, Double_t &stddev90Percent) const
void InitIPythonInteractive()
Definition MethodBase.h:453
void ReadStateFromXMLString(const char *xmlstr)
for reading from memory
Bool_t fSetupCompleted
Definition MethodBase.h:709
void CreateMVAPdfs()
Create PDFs of the MVA output variables.
TString GetTrainingROOTVersionString() const
calculates the ROOT version string from the training version code on the fly
virtual Double_t GetValueForRoot(Double_t)
returns efficiency as function of cut
UInt_t GetTrainingROOTVersionCode() const
Definition MethodBase.h:389
void ReadStateFromFile()
Function to read options and weights from file.
void WriteVarsToStream(std::ostream &tf, const TString &prefix="") const
write the list of variables (name, min, max) for a given data transformation method to the stream
void ReadVarsFromStream(std::istream &istr)
Read the variables (name, min, max) for a given data transformation method from the stream.
virtual void AddWeightsXMLTo(void *parent) const =0
virtual void Init()=0
void ReadSpectatorsFromXML(void *specnode)
read spectator info from XML
virtual const std::vector< Float_t > & GetTrainingHistory(const char *)
Definition MethodBase.h:232
const Event * fTmpEvent
Definition MethodBase.h:411
virtual Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)=0
void SetNormalised(Bool_t norm)
Definition MethodBase.h:495
void SetTestvarName(const TString &v="")
Definition MethodBase.h:340
void ReadVariablesFromXML(void *varnode)
read variable info from XML
Bool_t fConstructedFromWeightFile
Definition MethodBase.h:618
UInt_t GetNvar() const
Definition MethodBase.h:343
virtual std::map< TString, Double_t > OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA")
call the Optimizer with the set of parameters and ranges that are meant to be tuned.
void SetTrainTime(Double_t trainTime)
Definition MethodBase.h:161
Double_t GetXmax(Int_t ivar) const
Definition MethodBase.h:356
virtual std::vector< Float_t > GetMulticlassTrainingEfficiency(std::vector< std::vector< Float_t > > &purity)
DataSetInfo & fDataSetInfo
Definition MethodBase.h:605
void WriteStateToStream(std::ostream &tf) const
general method used in writing the header of the weight files where the used variables,...
virtual Double_t GetRarity(Double_t mvaVal, Types::ESBType reftype=Types::kBackground) const
compute rarity:
virtual void SetTuneParameters(std::map< TString, Double_t > tuneParameters)
set the tuning parameters according to the argument. This is just a dummy.
Double_t GetTrainTime() const
Definition MethodBase.h:162
void SetBaseDir(TDirectory *methodDir)
Definition MethodBase.h:372
void ReadStateFromStream(std::istream &tf)
read the header from the weight files of the different MVA methods
void AddVarsXMLTo(void *parent) const
write variable info to XML
Bool_t Help() const
Definition MethodBase.h:502
TSpline1 * fSplRefB
Definition MethodBase.h:701
UInt_t fIPyCurrentIter
Definition MethodBase.h:448
TransformationHandler & GetTransformationHandler(Bool_t takeReroutedIfAvailable=true)
Definition MethodBase.h:393
Bool_t IsSilentFile() const
Definition MethodBase.h:378
Types::EMVA GetMethodType() const
Definition MethodBase.h:332
void AddTargetsXMLTo(void *parent) const
write target info to XML
Results * fResults
Definition MethodBase.h:728
void ReadTargetsFromXML(void *tarnode)
read target info from XML
void SetFile(TFile *file)
Definition MethodBase.h:374
void ProcessBaseOptions()
the option string is decoded, for available options see "DeclareOptions"
virtual void Reset()
Definition MethodBase.h:193
UInt_t fROOTTrainingVersion
Definition MethodBase.h:617
void ReadStateFromXML(void *parent)
Double_t GetSignalReferenceCutOrientation() const
Definition MethodBase.h:360
void SetSignalReferenceCut(Double_t cut)
Definition MethodBase.h:363
std::vector< TString > * fInputVars
Definition MethodBase.h:586
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Double_t fSignalReferenceCut
the data set information (sometimes needed)
Definition MethodBase.h:607
Double_t GetTestTime() const
Definition MethodBase.h:166
const TransformationHandler & GetTransformationHandler(Bool_t takeReroutedIfAvailable=true) const
Definition MethodBase.h:397
void SetSignalReferenceCutOrientation(Double_t cutOrientation)
Definition MethodBase.h:364
TDirectory * fBaseDir
Definition MethodBase.h:623
Bool_t fIgnoreNegWeightsInTraining
Definition MethodBase.h:680
void WriteStateToFile() const
write options and weights to file; note that each one text file for the main configuration information...
void AddClassesXMLTo(void *parent) const
write class info to XML
const TString & GetInputLabel(Int_t i) const
Definition MethodBase.h:349
ECutOrientation GetCutOrientation() const
Definition MethodBase.h:550
Ranking * fRanking
Definition MethodBase.h:585
TrainingHistory fTrainHistory
Definition MethodBase.h:425
void SetMethodBaseDir(TDirectory *methodDir)
Definition MethodBase.h:373
virtual void AddClassifierOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
DataSet * Data() const
Definition MethodBase.h:408
void AddSpectatorsXMLTo(void *parent) const
write spectator info to XML
void SetModelPersistence(Bool_t status)
Definition MethodBase.h:381
TString fWeightFile
Definition MethodBase.h:636
Bool_t IsNormalised() const
Definition MethodBase.h:494
const TString & GetInternalVarName(Int_t ivar) const
Definition MethodBase.h:508
const Event * GetTrainingEvent(Long64_t ievt) const
Definition MethodBase.h:769
Double_t GetSignalReferenceCut() const
Definition MethodBase.h:359
TSpline1 * fSplTrainRefS
Definition MethodBase.h:703
TSpline1 * fSplRefS
Definition MethodBase.h:700
virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const
calculate the area (integral) under the ROC curve as an overall quality measure of the classification
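A minimal sketch of a ROC integral from two binned MVA distributions, assuming identical binning and that higher bins are more signal-like. It integrates background rejection over signal efficiency with the trapezoidal rule; the function name and the vector inputs are hypothetical stand-ins for the TH1D arguments, and the actual implementation may integrate differently.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical helper: area under the (signal efficiency, background
// rejection) curve obtained by sliding a cut from the highest bin downwards.
double rocIntegral(const std::vector<double>& histS,
                   const std::vector<double>& histB)
{
    assert(histS.size() == histB.size());
    const double totS = std::accumulate(histS.begin(), histS.end(), 0.0);
    const double totB = std::accumulate(histB.begin(), histB.end(), 0.0);
    assert(totS > 0.0 && totB > 0.0);

    double area = 0.0, cumS = 0.0, cumB = 0.0;
    double prevEffS = 0.0, prevRejB = 1.0;     // cut above the last bin
    for (std::size_t i = histS.size(); i-- > 0; ) {
        cumS += histS[i];
        cumB += histB[i];
        const double effS = cumS / totS;       // signal passing the cut
        const double rejB = 1.0 - cumB / totB; // background failing the cut
        area += 0.5 * (rejB + prevRejB) * (effS - prevEffS);
        prevEffS = effS;
        prevRejB = rejB;
    }
    return area;
}
```

With this convention a perfectly separating classifier gives 1 and indistinguishable distributions give 0.5.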
ECutOrientation fCutOrientation
Definition MethodBase.h:697
virtual const Ranking * CreateRanking()=0
Double_t GetRMS(Int_t ivar) const
Definition MethodBase.h:354
IPythonInteractive * fInteractive
Definition MethodBase.h:446
const char * GetInputTitle(Int_t i) const
Definition MethodBase.h:350
const TString & GetOriginalVarName(Int_t ivar) const
Definition MethodBase.h:509
virtual void AddMulticlassOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Bool_t IsConstructedFromWeightFile() const
Definition MethodBase.h:538
virtual void CheckSetup()
check may be overridden by derived class (sometimes, e.g., fitters are used which can only be implement...
Class for boosting a TMVA method.
Definition MethodBoost.h:58
Class for categorizing the phase space.
Virtual base class for combining several TMVA methods.
Multivariate optimisation of signal efficiency for given background efficiency, applying rectangular ...
Definition MethodCuts.h:61
PDF wrapper for histograms; uses user-defined spline interpolation.
Definition PDF.h:63
Ranking for variables in method (implementation)
Definition Ranking.h:48
Class that is the base class for a vector of results.
Definition Results.h:57
Root finding using Brent's algorithm (translated from CERNLIB function RZERO)
Definition RootFinder.h:48
Linear interpolation of TGraph.
Definition TSpline1.h:43
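Linear interpolation between graph points can be sketched as below. This is a generic illustration, not the TSpline1 code: the function name is hypothetical, the x values are assumed to be sorted in strictly ascending order, and clamping outside the covered range is a choice made for this sketch only.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: piecewise-linear interpolation through the points
// (x[i], y[i]); x must be strictly ascending.
double interpolateLinear(const std::vector<double>& x,
                         const std::vector<double>& y,
                         double u)
{
    assert(x.size() == y.size() && x.size() >= 2);
    if (u <= x.front()) return y.front();      // clamp below the range
    if (u >= x.back())  return y.back();       // clamp above the range
    // First knot strictly greater than u; its predecessor brackets u from below.
    const std::size_t hi =
        static_cast<std::size_t>(std::upper_bound(x.begin(), x.end(), u) - x.begin());
    const std::size_t lo = hi - 1;
    const double t = (u - x[lo]) / (x[hi] - x[lo]);
    return y[lo] + t * (y[hi] - y[lo]);
}
```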
Tracking data from training.
Class that contains all the data information.
const Event * Transform(const Event *) const
the transformation
Double_t GetRMS(Int_t ivar, Int_t cls=-1) const
Double_t GetMean(Int_t ivar, Int_t cls=-1) const
Double_t GetMin(Int_t ivar, Int_t cls=-1) const
Double_t GetMax(Int_t ivar, Int_t cls=-1) const
TString GetMethodName(Types::EMVA method) const
Definition Types.cxx:135
static Types & Instance()
returns the single instance of "Types" if it already exists, or creates it (Singleton)
Definition Types.cxx:69
@ kBackground
Definition Types.h:138
@ kMulticlass
Definition Types.h:131
@ kRegression
Definition Types.h:130
@ kTraining
Definition Types.h:145
const TString & GetLabel() const
const TString & GetExpression() const
const TString & GetInternalName() const
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition TMultiGraph.h:36
virtual const char * GetTitle() const
Returns title of object.
Definition TNamed.h:48
Base class for spline implementation containing the Draw/Paint methods.
Definition TSpline.h:31
Basic string class.
Definition TString.h:136
const char * Data() const
Definition TString.h:369
A TTree represents a columnar dataset.
Definition TTree.h:79