Logo ROOT   6.10/09
Reference Guide
MethodBase.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodBase *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Virtual base class for all MVA methods *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Peter Speckmayer <peter.speckmayer@cern.ch> - CERN, Switzerland *
16  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
17  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18  * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
19  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
20  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
21  * *
22  * Copyright (c) 2005-2011: *
23  * CERN, Switzerland *
24  * U. of Victoria, Canada *
25  * MPI-K Heidelberg, Germany *
26  * U. of Bonn, Germany *
27  * *
28  * Redistribution and use in source and binary forms, with or without *
29  * modification, are permitted according to the terms listed in LICENSE *
30  * (http://tmva.sourceforge.net/LICENSE) *
31  **********************************************************************************/
32 
33 #ifndef ROOT_TMVA_MethodBase
34 #define ROOT_TMVA_MethodBase
35 
36 //////////////////////////////////////////////////////////////////////////
37 // //
38 // MethodBase //
39 // //
40 // Virtual base class for all TMVA methods //
41 // //
42 //////////////////////////////////////////////////////////////////////////
43 
44 #include <iosfwd>
45 #include <vector>
46 #include <map>
47 #include "assert.h"
48 
49 #include "TString.h"
50 
51 #include "TMVA/IMethod.h"
52 #include "TMVA/Configurable.h"
53 #include "TMVA/Types.h"
54 #include "TMVA/DataSet.h"
55 #include "TMVA/Event.h"
57 #include <TMVA/Results.h>
58 
59 #include <TFile.h>
60 
61 class TGraph;
62 class TTree;
63 class TDirectory;
64 class TSpline;
65 class TH1F;
66 class TH1D;
67 class TMultiGraph;
68 
69 /*! \class TMVA::IPythonInteractive
70 \ingroup TMVA
71 
72 This class is needed by JsMVA, and it's a helper class for tracking errors during
73 the training in a Jupyter notebook. It is only initialized in a Jupyter notebook context.
74 In initialization we specify some title, and a TGraph will be created for every title.
75 We can add new data points easily to all TGraphs. These graphs are added to a
76 TMultiGraph, and during an interactive training we get this TMultiGraph object
77 and plot it with JsROOT.
78 */
79 
80 namespace TMVA {
81 
82  class Ranking;
83  class PDF;
84  class TSpline1;
85  class MethodCuts;
86  class MethodBoost;
87  class DataSetInfo;
88 
90  public:
93  void Init(std::vector<TString>& graphTitles);
94  void ClearGraphs();
95  void AddPoint(Double_t x, Double_t y1, Double_t y2);
96  void AddPoint(std::vector<Double_t>& dat);
97  inline TMultiGraph* Get() {return fMultiGraph;}
98  inline bool NotInitialized(){ return fNumGraphs==0;};
99  private:
101  std::vector<TGraph*> fGraphs;
104  };
105 
106  class MethodBase : virtual public IMethod, public Configurable {
107 
108  friend class Factory;
109  friend class RootFinder;
110  friend class MethodBoost;
111  public:
112 
113  enum EWeightFileType { kROOT=0, kTEXT };
114 
115  // default constructor
116  MethodBase( const TString& jobName,
117  Types::EMVA methodType,
118  const TString& methodTitle,
119  DataSetInfo& dsi,
120  const TString& theOption = "" );
121 
122  // constructor used for Testing + Application of the MVA, only (no training),
123  // using given weight file
124  MethodBase( Types::EMVA methodType,
125  DataSetInfo& dsi,
126  const TString& weightFile );
127 
128  // default destructor
129  virtual ~MethodBase();
130 
131  // declaration, processing and checking of configuration options
132  void SetupMethod();
133  void ProcessSetup();
134  virtual void CheckSetup(); // may be overwritten by derived classes
135 
136  // ---------- main training and testing methods ------------------------------
137 
138  // prepare tree branch with the method's discriminating variable
139  void AddOutput( Types::ETreeType type, Types::EAnalysisType analysisType );
140 
141  // performs classifier training
142  // calls methods Train() implemented by derived classes
143  void TrainMethod();
144 
145  // optimize tuning parameters
146  virtual std::map<TString,Double_t> OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA");
147  virtual void SetTuneParameters(std::map<TString,Double_t> tuneParameters);
148 
149  virtual void Train() = 0;
150 
151  // store and retrieve time used for training
152  void SetTrainTime( Double_t trainTime ) { fTrainTime = trainTime; }
153  Double_t GetTrainTime() const { return fTrainTime; }
154 
155  // store and retrieve time used for testing
156  void SetTestTime ( Double_t testTime ) { fTestTime = testTime; }
157  Double_t GetTestTime () const { return fTestTime; }
158 
159  // performs classifier testing
160  virtual void TestClassification();
161  virtual Double_t GetKSTrainingVsTest(Char_t SorB, TString opt="X");
162 
163  // performs multiclass classifier testing
164  virtual void TestMulticlass();
165 
166  // performs regression testing
167  virtual void TestRegression( Double_t& bias, Double_t& biasT,
168  Double_t& dev, Double_t& devT,
169  Double_t& rms, Double_t& rmsT,
170  Double_t& mInf, Double_t& mInfT, // mutual information
171  Double_t& corr,
172  Types::ETreeType type );
173 
174  // options treatment
175  virtual void Init() = 0;
176  virtual void DeclareOptions() = 0;
177  virtual void ProcessOptions() = 0;
178  virtual void DeclareCompatibilityOptions(); // declaration of past options
179 
180  // reset the Method --> As if it was not yet trained, just instantiated
181  // virtual void Reset() = 0;
182  //for the moment, I provide a dummy (that would not work) default, just to make
183  // compilation/running w/o parameter optimisation still possible
184  virtual void Reset(){return;}
185 
186  // classifier response:
187  // some methods may return a per-event error estimate
188  // error calculation is skipped if err==0
189  virtual Double_t GetMvaValue( Double_t* errLower = 0, Double_t* errUpper = 0) = 0;
190 
191  // signal/background classification response
192  Double_t GetMvaValue( const TMVA::Event* const ev, Double_t* err = 0, Double_t* errUpper = 0 );
193 
194  protected:
195  // helper function to set errors to -1
196  void NoErrorCalc(Double_t* const err, Double_t* const errUpper);
197 
198  // signal/background classification response for all current set of data
199  virtual std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false);
200 
201 
202  public:
203  // regression response
204  const std::vector<Float_t>& GetRegressionValues(const TMVA::Event* const ev){
205  fTmpEvent = ev;
206  const std::vector<Float_t>* ptr = &GetRegressionValues();
207  fTmpEvent = 0;
208  return (*ptr);
209  }
210 
211  virtual const std::vector<Float_t>& GetRegressionValues() {
212  std::vector<Float_t>* ptr = new std::vector<Float_t>(0);
213  return (*ptr);
214  }
215 
216  // multiclass classification response
217  virtual const std::vector<Float_t>& GetMulticlassValues() {
218  std::vector<Float_t>* ptr = new std::vector<Float_t>(0);
219  return (*ptr);
220  }
221 
222  // probability of classifier response (mvaval) to be signal (requires "CreateMvaPdf" option set)
223  virtual Double_t GetProba( const Event *ev); // the simple one, automatically calculates the mvaVal and uses the SAME sig/bkg ratio as given in the training sample (typically 50/50 .. (NormMode=EqualNumEvents) but can be different)
224  virtual Double_t GetProba( Double_t mvaVal, Double_t ap_sig );
225 
226  // Rarity of classifier response (signal or background (default) is uniform in [0,1])
227  virtual Double_t GetRarity( Double_t mvaVal, Types::ESBType reftype = Types::kBackground ) const;
228 
229  // create ranking
230  virtual const Ranking* CreateRanking() = 0;
231 
232  // make ROOT-independent C++ class
233  virtual void MakeClass( const TString& classFileName = TString("") ) const;
234 
235  // print help message
236  void PrintHelpMessage() const;
237 
238  //
239  // streamer methods for training information (creates "weight" files) --------
240  //
241  public:
242  void WriteStateToFile () const;
243  void ReadStateFromFile ();
244 
245  protected:
246  // the actual "weights"
247  virtual void AddWeightsXMLTo ( void* parent ) const = 0;
248  virtual void ReadWeightsFromXML ( void* wghtnode ) = 0;
249  virtual void ReadWeightsFromStream( std::istream& ) = 0; // backward compatibility
250  virtual void ReadWeightsFromStream( TFile& ) {} // backward compatibility
251 
252  private:
253  friend class MethodCategory;
254  friend class MethodCompositeBase;
255  void WriteStateToXML ( void* parent ) const;
256  void ReadStateFromXML ( void* parent );
257  void WriteStateToStream ( std::ostream& tf ) const; // needed for MakeClass
258  void WriteVarsToStream ( std::ostream& tf, const TString& prefix = "" ) const; // needed for MakeClass
259 
260 
261  public: // these two need to be public, they are used to read in-memory weight-files
262  void ReadStateFromStream ( std::istream& tf ); // backward compatibility
263  void ReadStateFromStream ( TFile& rf ); // backward compatibility
264  void ReadStateFromXMLString( const char* xmlstr ); // for reading from memory
265 
266  private:
267  // the variable information
268  void AddVarsXMLTo ( void* parent ) const;
269  void AddSpectatorsXMLTo ( void* parent ) const;
270  void AddTargetsXMLTo ( void* parent ) const;
271  void AddClassesXMLTo ( void* parent ) const;
272  void ReadVariablesFromXML ( void* varnode );
273  void ReadSpectatorsFromXML( void* specnode);
274  void ReadTargetsFromXML ( void* tarnode );
275  void ReadClassesFromXML ( void* clsnode );
276  void ReadVarsFromStream ( std::istream& istr ); // backward compatibility
277 
278  public:
279  // ---------------------------------------------------------------------------
280 
281  // write evaluation histograms into target file
282  virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype);
283 
284  // write classifier-specific monitoring information to target file
285  virtual void WriteMonitoringHistosToFile() const;
286 
287  // ---------- public evaluation methods --------------------------------------
288 
289  // individual initialization for testing of each method
290  // overload this one for individual initialisation of the testing,
291  // it is then called automatically within the global "TestInit"
292 
293  // variables (and private member functions) for the Evaluation:
294  // get the efficiency. It fills a histogram for efficiency/vs/bkg
295  // and returns the one value fo the efficiency demanded for
296  // in the TString argument. (Watch the string format)
297  virtual Double_t GetEfficiency( const TString&, Types::ETreeType, Double_t& err );
298  virtual Double_t GetTrainingEfficiency(const TString& );
299  virtual std::vector<Float_t> GetMulticlassEfficiency( std::vector<std::vector<Float_t> >& purity );
300  virtual std::vector<Float_t> GetMulticlassTrainingEfficiency(std::vector<std::vector<Float_t> >& purity );
301  virtual TMatrixD GetMulticlassConfusionMatrix(Double_t effB, Types::ETreeType type);
302  virtual Double_t GetSignificance() const;
303  virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const;
304  virtual Double_t GetROCIntegral(PDF *pdfS=0, PDF *pdfB=0) const;
305  virtual Double_t GetMaximumSignificance( Double_t SignalEvents, Double_t BackgroundEvents,
306  Double_t& optimal_significance_value ) const;
307  virtual Double_t GetSeparation( TH1*, TH1* ) const;
308  virtual Double_t GetSeparation( PDF* pdfS = 0, PDF* pdfB = 0 ) const;
309 
310  virtual void GetRegressionDeviation(UInt_t tgtNum, Types::ETreeType type, Double_t& stddev,Double_t& stddev90Percent ) const;
311  // ---------- public accessors -----------------------------------------------
312 
313  // classifier naming (a lot of names ... aren't they ;-)
314  const TString& GetJobName () const { return fJobName; }
315  const TString& GetMethodName () const { return fMethodName; }
316  TString GetMethodTypeName() const { return Types::Instance().GetMethodName(fMethodType); }
317  Types::EMVA GetMethodType () const { return fMethodType; }
318  const char* GetName () const { return fMethodName.Data(); }
319  const TString& GetTestvarName () const { return fTestvar; }
320  const TString GetProbaName () const { return fTestvar + "_Proba"; }
321  TString GetWeightFileName() const;
322 
323  // build classifier name in Test tree
324  // MVA prefix (e.g., "TMVA_")
325  void SetTestvarName ( const TString & v="" ) { fTestvar = (v=="") ? ("MVA_" + GetMethodName()) : v; }
326 
327  // number of input variable used by classifier
328  UInt_t GetNvar() const { return DataInfo().GetNVariables(); }
329  UInt_t GetNVariables() const { return DataInfo().GetNVariables(); }
330  UInt_t GetNTargets() const { return DataInfo().GetNTargets(); };
331 
332  // internal names and expressions of input variables
333  const TString& GetInputVar ( Int_t i ) const { return DataInfo().GetVariableInfo(i).GetInternalName(); }
334  const TString& GetInputLabel( Int_t i ) const { return DataInfo().GetVariableInfo(i).GetLabel(); }
335  const char * GetInputTitle( Int_t i ) const { return DataInfo().GetVariableInfo(i).GetTitle(); }
336 
337  // normalisation and limit accessors
338  Double_t GetMean( Int_t ivar ) const { return GetTransformationHandler().GetMean(ivar); }
339  Double_t GetRMS ( Int_t ivar ) const { return GetTransformationHandler().GetRMS(ivar); }
340  Double_t GetXmin( Int_t ivar ) const { return GetTransformationHandler().GetMin(ivar); }
341  Double_t GetXmax( Int_t ivar ) const { return GetTransformationHandler().GetMax(ivar); }
342 
343  // sets the minimum requirement on the MVA output to declare an event signal-like
344  Double_t GetSignalReferenceCut() const { return fSignalReferenceCut; }
345  Double_t GetSignalReferenceCutOrientation() const { return fSignalReferenceCutOrientation; }
346 
347  // sets the minimum requirement on the MVA output to declare an event signal-like
348  void SetSignalReferenceCut( Double_t cut ) { fSignalReferenceCut = cut; }
349  void SetSignalReferenceCutOrientation( Double_t cutOrientation ) { fSignalReferenceCutOrientation = cutOrientation; }
350 
351  // pointers to ROOT directories
352  TDirectory* BaseDir() const;
353  TDirectory* MethodBaseDir() const;
354  TFile* GetFile() const {return fFile;}
355 
356  void SetMethodDir ( TDirectory* methodDir ) { fBaseDir = fMethodBaseDir = methodDir; }
357  void SetBaseDir( TDirectory* methodDir ){ fBaseDir = methodDir; }
358  void SetMethodBaseDir( TDirectory* methodDir ){ fMethodBaseDir = methodDir; }
359  void SetFile(TFile* file){fFile=file;}
360 
361  //Silent file
362  void SetSilentFile(Bool_t status){fSilentFile=status;}
363  Bool_t IsSilentFile(){return fSilentFile;}
364 
365  //Model Persistence
366  void SetModelPersistence(Bool_t status){fModelPersistence=status;}//added support to create/remove dir here if exits or not
367  Bool_t IsModelPersistence(){return fModelPersistence;}
368 
369  // the TMVA version can be obtained and checked using
370  // if (GetTrainingTMVAVersionCode()>TMVA_VERSION(3,7,2)) {...}
371  // or
372  // if (GetTrainingROOTVersionCode()>ROOT_VERSION(5,15,5)) {...}
373  UInt_t GetTrainingTMVAVersionCode() const { return fTMVATrainingVersion; }
374  UInt_t GetTrainingROOTVersionCode() const { return fROOTTrainingVersion; }
375  TString GetTrainingTMVAVersionString() const;
376  TString GetTrainingROOTVersionString() const;
377 
379  {
380  if(fTransformationPointer && takeReroutedIfAvailable) return *fTransformationPointer; else return fTransformation;
381  }
382  const TransformationHandler& GetTransformationHandler(Bool_t takeReroutedIfAvailable=true) const
383  {
384  if(fTransformationPointer && takeReroutedIfAvailable) return *fTransformationPointer; else return fTransformation;
385  }
386 
387  void RerouteTransformationHandler (TransformationHandler* fTargetTransformation) { fTransformationPointer=fTargetTransformation; }
388 
389  // ---------- event accessors ------------------------------------------------
390 
391  // returns reference to data set
392  // NOTE: this DataSet is the "original" dataset, i.e. the one seen by ALL Classifiers WITHOUT transformation
393  DataSet* Data() const { return DataInfo().GetDataSet(); }
394  DataSetInfo& DataInfo() const { return fDataSetInfo; }
395 
396  mutable const Event* fTmpEvent; //! temporary event when testing on a different DataSet than the own one
397 
398  // event reference and update
399  // NOTE: these Event accessors make sure that you get the events transformed according to the
400  // particular classifiers transformation chosen
401  UInt_t GetNEvents () const { return Data()->GetNEvents(); }
402  const Event* GetEvent () const;
403  const Event* GetEvent ( const TMVA::Event* ev ) const;
404  const Event* GetEvent ( Long64_t ievt ) const;
405  const Event* GetEvent ( Long64_t ievt , Types::ETreeType type ) const;
406  const Event* GetTrainingEvent( Long64_t ievt ) const;
407  const Event* GetTestingEvent ( Long64_t ievt ) const;
408  const std::vector<TMVA::Event*>& GetEventCollection( Types::ETreeType type );
409 
410  // ---------- public auxiliary methods ---------------------------------------
411 
412  // this method is used to decide whether an event is signal- or background-like
413  // the reference cut "xC" is taken to be where
414  // Int_[-oo,xC] { PDF_S(x) dx } = Int_[xC,+oo] { PDF_B(x) dx }
415  virtual Bool_t IsSignalLike();
416  virtual Bool_t IsSignalLike(Double_t mvaVal);
417 
418 
419  Bool_t HasMVAPdfs() const { return fHasMVAPdfs; }
420  virtual void SetAnalysisType( Types::EAnalysisType type ) { fAnalysisType = type; }
421  Types::EAnalysisType GetAnalysisType() const { return fAnalysisType; }
422  Bool_t DoRegression() const { return fAnalysisType == Types::kRegression; }
423  Bool_t DoMulticlass() const { return fAnalysisType == Types::kMulticlass; }
424 
425  // setter method for suppressing writing to XML and writing of standalone classes
426  void DisableWriting(Bool_t setter){ fModelPersistence = setter?kFALSE:kTRUE; }//DEPRECATED
427 
428  protected:
429  // helper variables for JsMVA
430  IPythonInteractive *fInteractive = nullptr;
431  bool fExitFromTraining = false;
432  UInt_t fIPyMaxIter = 0, fIPyCurrentIter = 0;
433 
434  public:
435 
436  // initializing IPythonInteractive class (for JsMVA only)
437  inline void InitIPythonInteractive(){
438  if (fInteractive) delete fInteractive;
439  fInteractive = new IPythonInteractive();
440  }
441 
442  // get training errors (for JsMVA only)
443  inline TMultiGraph* GetInteractiveTrainingError(){return fInteractive->Get();}
444 
445  // stop's the training process (for JsMVA only)
446  inline void ExitFromTraining(){
447  fExitFromTraining = true;
448  }
449 
450  // check's if the training ended (for JsMVA only)
451  inline bool TrainingEnded(){
452  if (fExitFromTraining && fInteractive){
453  delete fInteractive;
454  fInteractive = nullptr;
455  }
456  return fExitFromTraining;
457  }
458 
459  // get fIPyMaxIter
460  inline UInt_t GetMaxIter(){ return fIPyMaxIter; }
461 
462  // get fIPyCurrentIter
463  inline UInt_t GetCurrentIter(){ return fIPyCurrentIter; }
464 
465  protected:
466 
467  // ---------- protected accessors -------------------------------------------
468 
469  //TDirectory* LocalTDir() const { return Data().LocalRootDir(); }
470 
471  // weight file name and directory (given by global config variable)
472  void SetWeightFileName( TString );
473 
474  const TString& GetWeightFileDir() const { return fFileDir; }
475  void SetWeightFileDir( TString fileDir );
476 
477  // are input variables normalised ?
478  Bool_t IsNormalised() const { return fNormalise; }
479  void SetNormalised( Bool_t norm ) { fNormalise = norm; }
480 
481  // set number of input variables (only used by MethodCuts, could perhaps be removed)
482  // void SetNvar( Int_t n ) { fNvar = n; }
483 
484  // verbose and help flags
485  Bool_t Verbose() const { return fVerbose; }
486  Bool_t Help () const { return fHelp; }
487 
488  // ---------- protected event and tree accessors -----------------------------
489 
490  // names of input variables (if the original names are expressions, they are
491  // transformed into regexps)
492  const TString& GetInternalVarName( Int_t ivar ) const { return (*fInputVars)[ivar]; }
493  const TString& GetOriginalVarName( Int_t ivar ) const { return DataInfo().GetVariableInfo(ivar).GetExpression(); }
494 
495  Bool_t HasTrainingTree() const { return Data()->GetNTrainingEvents() != 0; }
496 
497  // ---------- protected auxiliary methods ------------------------------------
498 
499  protected:
500 
501  // make ROOT-independent C++ class for classifier response (classifier-specific implementation)
502  virtual void MakeClassSpecific( std::ostream&, const TString& = "" ) const {}
503 
504  // header and auxiliary classes
505  virtual void MakeClassSpecificHeader( std::ostream&, const TString& = "" ) const {}
506 
507  // static pointer to this object - required for ROOT finder (to be solved differently)(solved by Omar)
508  //static MethodBase* GetThisBase();
509 
510  // some basic statistical analysis
511  void Statistics( Types::ETreeType treeType, const TString& theVarName,
513  Double_t&, Double_t&, Double_t& );
514 
515  // if TRUE, write weights only to text files
516  Bool_t TxtWeightsOnly() const { return kTRUE; }
517 
518  protected:
519 
520  // access to event information that needs method-specific information
521 
522  Bool_t IsConstructedFromWeightFile() const { return fConstructedFromWeightFile; }
523 
524  private:
525 
526  // ---------- private definitions --------------------------------------------
527  // Initialisation
528  void InitBase();
529  void DeclareBaseOptions();
530  void ProcessBaseOptions();
531 
532  // used in efficiency computation
533  enum ECutOrientation { kNegative = -1, kPositive = +1 };
534  ECutOrientation GetCutOrientation() const { return fCutOrientation; }
535 
536  // ---------- private accessors ---------------------------------------------
537 
538  // reset required for RootFinder
539  void ResetThisBase();
540 
541  // ---------- private auxiliary methods --------------------------------------
542 
543  // PDFs for classifier response (required to compute signal probability and Rarity)
544  void CreateMVAPdfs();
545 
546  // for root finder
547  //virtual method to find ROOT
548  virtual Double_t GetValueForRoot ( Double_t ); // implementation
549 
550  // used for file parsing
551  Bool_t GetLine( std::istream& fin, char * buf );
552 
553  // fill test tree with classification or regression results
554  virtual void AddClassifierOutput ( Types::ETreeType type );
555  virtual void AddClassifierOutputProb( Types::ETreeType type );
556  virtual void AddRegressionOutput ( Types::ETreeType type );
557  virtual void AddMulticlassOutput ( Types::ETreeType type );
558 
559  private:
560 
561  void AddInfoItem( void* gi, const TString& name,
562  const TString& value) const;
563 
564  // ========== class members ==================================================
565 
566  protected:
567 
568  // direct accessors
569  Ranking* fRanking; // pointer to ranking object (created by derived classifiers)
570  std::vector<TString>* fInputVars; // vector of input variables used in MVA
571 
572  // histogram binning
573  Int_t fNbins; // number of bins in input variable histograms
574  Int_t fNbinsMVAoutput; // number of bins in MVA output histograms
575  Int_t fNbinsH; // number of bins in evaluation histograms
576 
577  Types::EAnalysisType fAnalysisType; // method-mode : true --> regression, false --> classification
578 
579  std::vector<Float_t>* fRegressionReturnVal; // holds the return-values for the regression
580  std::vector<Float_t>* fMulticlassReturnVal; // holds the return-values for the multiclass classification
581 
582  private:
583 
584  // MethodCuts redefines some of the evaluation variables and histograms -> must access private members
585  friend class MethodCuts;
586 
587 
588  // data sets
589  DataSetInfo& fDataSetInfo; //! the data set information (sometimes needed)
590 
591  Double_t fSignalReferenceCut; // minimum requirement on the MVA output to declare an event signal-like
592  Double_t fSignalReferenceCutOrientation; // minimum requirement on the MVA output to declare an event signal-like
593  Types::ESBType fVariableTransformType; // this is the event type (sig or bgd) assumed for variable transform
594 
595  // naming and versioning
596  TString fJobName; // name of job -> user defined, appears in weight files
597  TString fMethodName; // name of the method (set in derived class)
598  Types::EMVA fMethodType; // type of method (set in derived class)
599  TString fTestvar; // variable used in evaluation, etc (mostly the MVA)
600  UInt_t fTMVATrainingVersion; // TMVA version used for training
601  UInt_t fROOTTrainingVersion; // ROOT version used for training
602  Bool_t fConstructedFromWeightFile; // is it obtained from weight file?
603 
604  // Directory structure: dataloader/fMethodBaseDir/fBaseDir
605  // where the first directory name is defined by the method type
606  // and the second is user supplied (the title given in Factory::BookMethod())
607  TDirectory* fBaseDir; // base directory for the instance, needed to know where to jump back from localDir
608  mutable TDirectory* fMethodBaseDir; // base directory for the method
609  //this will be the next way to save results
611 
612  //SilentFile
614  //Model Persistence
616 
617  TString fParentDir; // method parent name, like booster name
618 
619  TString fFileDir; // unix sub-directory for weight files (default: DataLoader's Name + "weights")
620  TString fWeightFile; // weight file name
621 
622  private:
623 
624  TH1* fEffS; // efficiency histogram for rootfinder
625 
626  PDF* fDefaultPDF; // default PDF definitions
627  PDF* fMVAPdfS; // signal MVA PDF
628  PDF* fMVAPdfB; // background MVA PDF
629 
630  // TH1D* fmvaS; // PDFs of MVA distribution (signal)
631  // TH1D* fmvaB; // PDFs of MVA distribution (background)
632  PDF* fSplS; // PDFs of MVA distribution (signal)
633  PDF* fSplB; // PDFs of MVA distribution (background)
634  TSpline* fSpleffBvsS; // splines for signal eff. versus background eff.
635 
636  PDF* fSplTrainS; // PDFs of training MVA distribution (signal)
637  PDF* fSplTrainB; // PDFs of training MVA distribution (background)
638  TSpline* fSplTrainEffBvsS; // splines for training signal eff. versus background eff.
639 
640  private:
641 
642  // basic statistics quantities of MVA
643  Double_t fMeanS; // mean (signal)
644  Double_t fMeanB; // mean (background)
645  Double_t fRmsS; // RMS (signal)
646  Double_t fRmsB; // RMS (background)
647  Double_t fXmin; // minimum (signal and background)
648  Double_t fXmax; // maximum (signal and background)
649 
650  // variable preprocessing
651  TString fVarTransformString; // labels variable transform method
652 
653  TransformationHandler* fTransformationPointer; // pointer to the rest of transformations
654  TransformationHandler fTransformation; // the list of transformations
655 
656 
657  // help and verbosity
658  Bool_t fVerbose; // verbose flag
659  TString fVerbosityLevelString; // verbosity level (user input string)
660  EMsgType fVerbosityLevel; // verbosity level
661  Bool_t fHelp; // help flag
662  Bool_t fHasMVAPdfs; // MVA Pdfs are created for this classifier
663 
664  Bool_t fIgnoreNegWeightsInTraining;// If true, events with negative weights are not used in training
665 
666  protected:
667 
668  Bool_t IgnoreEventsWithNegWeightsInTraining() const { return fIgnoreNegWeightsInTraining; }
669 
670  // for signal/background
671  UInt_t fSignalClass; // index of the Signal-class
672  UInt_t fBackgroundClass; // index of the Background-class
673 
674  private:
675 
676  // timing variables
677  Double_t fTrainTime; // for timing measurements
678  Double_t fTestTime; // for timing measurements
679 
680  // orientation of cut: depends on signal and background mean values
681  ECutOrientation fCutOrientation; // +1 if Sig>Bkg, -1 otherwise
682 
683  // for root finder
684  TSpline1* fSplRefS; // helper splines for RootFinder (signal)
685  TSpline1* fSplRefB; // helper splines for RootFinder (background)
686 
687  TSpline1* fSplTrainRefS; // helper splines for RootFinder (signal)
688  TSpline1* fSplTrainRefB; // helper splines for RootFinder (background)
689 
690  mutable std::vector<const std::vector<TMVA::Event*>*> fEventCollections; // if the method needs the complete event-collection, the transformed event coll. ist stored here.
691 
692  public:
693  Bool_t fSetupCompleted; // is method setup
694 
695  private:
696 
697  // This is a workaround for OSx where static thread_local data members are
698  // not supported. The C++ solution would indeed be the following:
699 // static MethodBase*& GetThisBaseThreadLocal() {TTHREAD_TLS(MethodBase*) fgThisBase(nullptr); return fgThisBase; };
700 
701  // ===== depreciated options, kept for backward compatibility =====
702  private:
703 
704  Bool_t fNormalise; // normalise input variables
705  Bool_t fUseDecorr; // synonymous for decorrelation
706  TString fVariableTransformTypeString; // labels variable transform type
707  Bool_t fTxtWeightsOnly; // if TRUE, write weights only to text files
708  Int_t fNbinsMVAPdf; // number of bins used in histogram that creates PDF
709  Int_t fNsmoothMVAPdf; // number of times a histogram is smoothed before creating the PDF
710 
711  protected:
713  ClassDef(MethodBase,0); // Virtual base class for all TMVA method
714 
715  };
716 } // namespace TMVA
717 
718 
719 
720 
721 
722 
723 
724 // ========== INLINE FUNCTIONS =========================================================
725 
726 
727 //_______________________________________________________________________
728 inline const TMVA::Event* TMVA::MethodBase::GetEvent( const TMVA::Event* ev ) const
729 {
730  return GetTransformationHandler().Transform(ev);
731 }
732 
734 {
735  if(fTmpEvent)
736  return GetTransformationHandler().Transform(fTmpEvent);
737  else
738  return GetTransformationHandler().Transform(Data()->GetEvent());
739 }
740 
742 {
743  assert(fTmpEvent==0);
744  return GetTransformationHandler().Transform(Data()->GetEvent(ievt));
745 }
746 
748 {
749  assert(fTmpEvent==0);
750  return GetTransformationHandler().Transform(Data()->GetEvent(ievt, type));
751 }
752 
754 {
755  assert(fTmpEvent==0);
756  return GetEvent(ievt, Types::kTraining);
757 }
758 
760 {
761  assert(fTmpEvent==0);
762  return GetEvent(ievt, Types::kTesting);
763 }
764 
765 #endif
Bool_t HasMVAPdfs() const
Definition: MethodBase.h:419
Types::EAnalysisType fAnalysisType
Definition: MethodBase.h:577
void SetModelPersistence(Bool_t status)
Definition: MethodBase.h:366
TString fMethodName
Definition: MethodBase.h:597
virtual void ReadWeightsFromStream(TFile &)
Definition: MethodBase.h:250
virtual const std::vector< Float_t > & GetMulticlassValues()
Definition: MethodBase.h:217
long long Long64_t
Definition: RtypesCore.h:69
void AddPoint(Double_t x, Double_t y1, Double_t y2)
This function is used only in the two-TGraph case; it adds new data points to the graphs.
Definition: MethodBase.cxx:212
TString GetMethodName(Types::EMVA method) const
Definition: Types.cxx:136
Bool_t fIgnoreNegWeightsInTraining
Definition: MethodBase.h:664
Bool_t IsConstructedFromWeightFile() const
Definition: MethodBase.h:522
virtual void MakeClassSpecificHeader(std::ostream &, const TString &="") const
Definition: MethodBase.h:505
TSpline1 * fSplTrainRefS
Definition: MethodBase.h:687
const TString GetProbaName() const
Definition: MethodBase.h:320
std::vector< TGraph * > fGraphs
Definition: MethodBase.h:101
const TransformationHandler & GetTransformationHandler(Bool_t takeReroutedIfAvailable=true) const
Definition: MethodBase.h:382
UInt_t GetNvar() const
Definition: MethodBase.h:328
static Types & Instance()
Return the single instance of "Types" if it already exists, or create it (Singleton).
Definition: Types.cxx:70
const TString & GetOriginalVarName(Int_t ivar) const
Definition: MethodBase.h:493
TString fWeightFile
Definition: MethodBase.h:620
TString fVariableTransformTypeString
Definition: MethodBase.h:706
void SetMethodBaseDir(TDirectory *methodDir)
Definition: MethodBase.h:358
Base class for spline implementation containing the Draw/Paint methods.
Definition: TSpline.h:20
TransformationHandler * fTransformationPointer
Definition: MethodBase.h:653
Types::ESBType fVariableTransformType
Definition: MethodBase.h:593
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
Definition: TFile.h:46
EAnalysisType
Definition: Types.h:125
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:35
void InitIPythonInteractive()
Definition: MethodBase.h:437
Virtual base class for all MVA methods.
Definition: MethodBase.h:106
const std::vector< Float_t > & GetRegressionValues(const TMVA::Event *const ev)
Definition: MethodBase.h:204
void SetSignalReferenceCutOrientation(Double_t cutOrientation)
Definition: MethodBase.h:349
virtual const std::vector< Float_t > & GetRegressionValues()
Definition: MethodBase.h:211
Basic string class.
Definition: TString.h:129
1-D histogram with a float per channel (see TH1 documentation)
Definition: TH1.h:551
void SetTrainTime(Double_t trainTime)
Definition: MethodBase.h:152
TMultiGraph * fMultiGraph
Definition: MethodBase.h:98
const TString & GetInternalVarName(Int_t ivar) const
Definition: MethodBase.h:492
TransformationHandler & GetTransformationHandler(Bool_t takeReroutedIfAvailable=true)
Definition: MethodBase.h:378
Ranking for variables in method (implementation)
Definition: Ranking.h:48
int Int_t
Definition: RtypesCore.h:41
TMultiGraph * Get()
Definition: MethodBase.h:97
bool Bool_t
Definition: RtypesCore.h:59
Results * fResults
Definition: MethodBase.h:712
TString fJobName
Definition: MethodBase.h:596
TSpline1 * fSplRefB
Definition: MethodBase.h:685
UInt_t GetNTargets() const
Definition: MethodBase.h:330
TSpline1 * fSplRefS
Definition: MethodBase.h:684
std::vector< TString > * fInputVars
Definition: MethodBase.h:570
const char * GetInputTitle(Int_t i) const
Definition: MethodBase.h:335
void SetSilentFile(Bool_t status)
Definition: MethodBase.h:362
Double_t fTrainTime
Definition: MethodBase.h:677
Double_t fTestTime
Definition: MethodBase.h:678
Double_t GetMean(Int_t ivar) const
Definition: MethodBase.h:338
Double_t GetTrainTime() const
Definition: MethodBase.h:153
const TString & GetInputLabel(Int_t i) const
Definition: MethodBase.h:334
void SetMethodDir(TDirectory *methodDir)
Definition: MethodBase.h:356
const TString & GetWeightFileDir() const
Definition: MethodBase.h:474
UInt_t fSignalClass
Definition: MethodBase.h:671
const TString & GetInputVar(Int_t i) const
Definition: MethodBase.h:333
Double_t x[n]
Definition: legend1.C:17
DataSetInfo & fDataSetInfo
Definition: MethodBase.h:589
#define ClassDef(name, id)
Definition: Rtypes.h:297
ECutOrientation fCutOrientation
Definition: MethodBase.h:681
Bool_t TxtWeightsOnly() const
Definition: MethodBase.h:516
UInt_t GetTrainingTMVAVersionCode() const
Definition: MethodBase.h:373
const Event * GetEvent() const
Definition: MethodBase.h:733
DataSet * Data() const
Definition: MethodBase.h:393
void ClearGraphs()
This function sets the point number to 0 for all graphs.
Definition: MethodBase.cxx:198
Virtual base class for combining several TMVA methods.
~IPythonInteractive()
standard destructor
Definition: MethodBase.cxx:159
Double_t fMeanB
Definition: MethodBase.h:644
std::vector< std::vector< double > > Data
Double_t GetXmin(Int_t ivar) const
Definition: MethodBase.h:340
void Init(std::vector< TString > &graphTitles)
This function takes a list of titles and creates a TGraph for each title.
Definition: MethodBase.cxx:174
DataSetInfo & DataInfo() const
Definition: MethodBase.h:394
Bool_t DoRegression() const
Definition: MethodBase.h:422
TString fTestvar
Definition: MethodBase.h:599
Class that contains all the data information.
Definition: DataSetInfo.h:60
TFile * GetFile() const
Definition: MethodBase.h:354
PDF wrapper for histograms; uses user-defined spline interpolation.
Definition: PDF.h:63
TSpline * fSpleffBvsS
Definition: MethodBase.h:634
Bool_t fModelPersistence
Definition: MethodBase.h:615
const Event * GetTrainingEvent(Long64_t ievt) const
Definition: MethodBase.h:753
Bool_t Verbose() const
Definition: MethodBase.h:485
UInt_t fTMVATrainingVersion
Definition: MethodBase.h:600
UInt_t GetNEvents() const
temporary event when testing on a different DataSet than the own one
Definition: MethodBase.h:401
Class for boosting a TMVA method.
Definition: MethodBoost.h:56
Double_t GetXmax(Int_t ivar) const
Definition: MethodBase.h:341
TransformationHandler fTransformation
Definition: MethodBase.h:654
Bool_t DoMulticlass() const
Definition: MethodBase.h:423
Class that contains all the data information.
Definition: DataSet.h:69
virtual void MakeClassSpecific(std::ostream &, const TString &="") const
Definition: MethodBase.h:502
const Event * GetTestingEvent(Long64_t ievt) const
Definition: MethodBase.h:759
Bool_t HasTrainingTree() const
Definition: MethodBase.h:495
Double_t fRmsB
Definition: MethodBase.h:646
Double_t fXmin
Definition: MethodBase.h:647
std::string GetMethodName(TCppMethod_t)
Definition: Cppyy.cxx:733
TSpline1 * fSplTrainRefB
Definition: MethodBase.h:688
TDirectory * fMethodBaseDir
Definition: MethodBase.h:608
SVector< double, 2 > v
Definition: Dict.h:5
UInt_t fROOTTrainingVersion
Definition: MethodBase.h:601
const char * GetName() const
Definition: MethodBase.h:318
UInt_t GetTrainingROOTVersionCode() const
Definition: MethodBase.h:374
unsigned int UInt_t
Definition: RtypesCore.h:42
Double_t fMeanS
Definition: MethodBase.h:643
Bool_t Help() const
Definition: MethodBase.h:486
Int_t fNsmoothMVAPdf
Definition: MethodBase.h:709
Bool_t fTxtWeightsOnly
Definition: MethodBase.h:707
const TString & GetJobName() const
Definition: MethodBase.h:314
const TString & GetMethodName() const
Definition: MethodBase.h:315
TDirectory * fBaseDir
Definition: MethodBase.h:607
Bool_t fHasMVAPdfs
Definition: MethodBase.h:662
TSpline * fSplTrainEffBvsS
Definition: MethodBase.h:638
Class that contains all the data information.
This is the main MVA steering class.
Definition: Factory.h:81
1-D histogram with a double per channel (see TH1 documentation)
Definition: TH1.h:594
Bool_t IsSilentFile()
Definition: MethodBase.h:363
Linear interpolation of TGraph.
Definition: TSpline1.h:43
Double_t GetSignalReferenceCutOrientation() const
Definition: MethodBase.h:345
void SetNormalised(Bool_t norm)
Definition: MethodBase.h:479
Double_t GetTestTime() const
Definition: MethodBase.h:157
UInt_t GetNVariables() const
Definition: MethodBase.h:329
std::vector< const std::vector< TMVA::Event * > * > fEventCollections
Definition: MethodBase.h:690
const Bool_t kFALSE
Definition: RtypesCore.h:92
TString fVerbosityLevelString
Definition: MethodBase.h:659
Class for categorizing the phase space.
Double_t fRmsS
Definition: MethodBase.h:645
UInt_t fBackgroundClass
Definition: MethodBase.h:672
Bool_t IgnoreEventsWithNegWeightsInTraining() const
Definition: MethodBase.h:668
void RerouteTransformationHandler(TransformationHandler *fTargetTransformation)
Definition: MethodBase.h:387
void SetTestTime(Double_t testTime)
Definition: MethodBase.h:156
Multivariate optimisation of signal efficiency for given background efficiency, applying rectangular ...
Definition: MethodCuts.h:61
UInt_t GetMaxIter()
Definition: MethodBase.h:460
double Double_t
Definition: RtypesCore.h:55
EMsgType fVerbosityLevel
Definition: MethodBase.h:660
Describe directory structure in memory.
Definition: TDirectory.h:34
std::vector< Float_t > * fMulticlassReturnVal
Definition: MethodBase.h:580
Bool_t IsNormalised() const
Definition: MethodBase.h:478
int type
Definition: TGX11.cxx:120
void SetFile(TFile *file)
Definition: MethodBase.h:359
virtual void Reset()
Definition: MethodBase.h:184
The TH1 histogram class.
Definition: TH1.h:56
IPythonInteractive()
standard constructor
Definition: MethodBase.cxx:151
void ExitFromTraining()
Definition: MethodBase.h:446
TString fParentDir
Definition: MethodBase.h:617
Bool_t fConstructedFromWeightFile
Definition: MethodBase.h:602
TString fVarTransformString
Definition: MethodBase.h:651
Interface for all concrete MVA method implementations.
Definition: IMethod.h:54
Types::EMVA fMethodType
Definition: MethodBase.h:598
char Char_t
Definition: RtypesCore.h:29
Double_t GetRMS(Int_t ivar) const
Definition: MethodBase.h:339
Root finding using Brent's algorithm (translated from CERNLIB function RZERO)
Definition: RootFinder.h:48
This class is needed by JsMVA; it's a helper class for tracking errors during the training in Jup...
Definition: MethodBase.h:89
Abstract ClassifierFactory template that handles arbitrary types.
Ranking * fRanking
Definition: MethodBase.h:569
TString GetMethodTypeName() const
Definition: MethodBase.h:316
Definition: file.py:1
bool TrainingEnded()
Definition: MethodBase.h:451
Class that is the base-class for a vector of result.
Definition: Results.h:57
Double_t fSignalReferenceCut
the data set information (sometimes needed)
Definition: MethodBase.h:591
const Event * fTmpEvent
Definition: MethodBase.h:396
Double_t GetSignalReferenceCut() const
Definition: MethodBase.h:344
A Graph is a graphics object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
Int_t fNbinsMVAoutput
Definition: MethodBase.h:574
Bool_t fSilentFile
Definition: MethodBase.h:613
UInt_t GetCurrentIter()
Definition: MethodBase.h:463
Double_t fXmax
Definition: MethodBase.h:648
void DisableWriting(Bool_t setter)
Definition: MethodBase.h:426
ECutOrientation GetCutOrientation() const
Definition: MethodBase.h:534
std::vector< Float_t > * fRegressionReturnVal
Definition: MethodBase.h:579
Types::EAnalysisType GetAnalysisType() const
Definition: MethodBase.h:421
A TTree object has a header with a name and a title.
Definition: TTree.h:78
const TString & GetTestvarName() const
Definition: MethodBase.h:319
void SetTestvarName(const TString &v="")
Definition: MethodBase.h:325
TString fFileDir
Definition: MethodBase.h:619
TMultiGraph * GetInteractiveTrainingError()
Definition: MethodBase.h:443
const Bool_t kTRUE
Definition: RtypesCore.h:91
double norm(double *x, double *p)
Definition: unuranDistr.cxx:40
Types::EMVA GetMethodType() const
Definition: MethodBase.h:317
void SetBaseDir(TDirectory *methodDir)
Definition: MethodBase.h:357
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition: MethodBase.h:420
Bool_t fSetupCompleted
Definition: MethodBase.h:693
void SetSignalReferenceCut(Double_t cut)
Definition: MethodBase.h:348
Double_t fSignalReferenceCutOrientation
Definition: MethodBase.h:592
Bool_t IsModelPersistence()
Definition: MethodBase.h:367