Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
Factory.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3// Updated by: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer
4
5/**********************************************************************************
6 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
7 * Package: TMVA *
8 * Class : Factory *
9 * *
10 * *
11 * Description: *
12 * This is the main MVA steering class: it creates (books) all MVA methods, *
13 * and guides them through the training, testing and evaluation phases. *
14 * *
15 * Authors (alphabetical): *
16 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
17 * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
18 * Peter Speckmayer <peter.speckmayer@cern.ch> - CERN, Switzerland *
19 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
22 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
23 * Omar Zapata <Omar.Zapata@cern.ch> - UdeA/ITM Colombia *
24 * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
25 * Sergei Gleyzer <Sergei.Gleyzer@cern.ch> - U of Florida & CERN *
26 * *
27 * Copyright (c) 2005-2011: *
28 * CERN, Switzerland *
29 * U. of Victoria, Canada *
30 * MPI-K Heidelberg, Germany *
31 * U. of Bonn, Germany *
32 * UdeA/ITM, Colombia *
33 * U. of Florida, USA *
34 * *
35 * Redistribution and use in source and binary forms, with or without *
36 * modification, are permitted according to the terms listed in LICENSE *
37 * (see tmva/doc/LICENSE) *
38 **********************************************************************************/
39
40#ifndef ROOT_TMVA_Factory
41#define ROOT_TMVA_Factory
42
43//////////////////////////////////////////////////////////////////////////
44// //
45// Factory //
46// //
47// This is the main MVA steering class: it creates all MVA methods, //
48// and guides them through the training, testing and evaluation //
49// phases //
50// //
51//////////////////////////////////////////////////////////////////////////
52
53#include <vector>
54#include <map>
55#include "TCut.h"
56
57#include "TMVA/Configurable.h"
58#include "TMVA/Types.h"
59#include "TMVA/DataSet.h"
60
61class TCanvas;
62class TDirectory;
63class TFile;
64class TGraph;
65class TH1F;
66class TMultiGraph;
67class TTree;
68namespace TMVA {
69
70 class IMethod;
71 class MethodBase;
72 class DataInputHandler;
73 class DataSetInfo;
74 class DataSetManager;
75 class DataLoader;
76 class ROCCurve;
77 class VariableTransformBase;
78
79
// Main MVA steering class: creates (books) the MVA methods and guides them
// through the training, testing and evaluation phases.
// NOTE(review): the listing's own line numbering has gaps (e.g. 88, 91, 102,
// 105, 112), so some declarations of this class are not visible in this extract.
80 class Factory : public Configurable {
81 friend class CrossValidation;
82 public:
83
// MVector is a vector of IMethod pointers; fMethodsMap maps a TString to a
// pointer to such a vector.
84 typedef std::vector<IMethod*> MVector;
85 std::map<TString,MVector*> fMethodsMap;//all methods for every dataset with the same name
86
87 // no default constructor
89
90 // constructor to work without file
92
93 // virtual destructor (declared here, not defaulted in this header)
94 virtual ~Factory();
95
96 // use TName::GetName and define correct name in constructor
97 //virtual const char* GetName() const { return "Factory"; }
98
99 // Internal wrapper type that can be constructed either like a TString or
100 // from a Types::EMVA enum value and stores the resolved TString. This
101 // avoids the need for multiple overloads of BookMethod.
// NOTE(review): the wrapper's class header, the declarator of the templated
// constructor and (presumably) the fName member are on listing lines that are
// missing from this extract (102, 105, 112).
103 public:
// Constructor template: only enabled when a TString is constructible from T&&.
// NOTE(review): std::enable_if_t / std::is_constructible_v live in <type_traits>,
// which is not among the includes visible in this header — presumably it is
// pulled in transitively; confirm.
104 template <typename T, typename = std::enable_if_t<std::is_constructible_v<TString, T &&>>>
106 {
107 }
// Construct from the method enum: stores Types::Instance().GetMethodName(method).
108 MethodName(Types::EMVA method) : fName(Types::Instance().GetMethodName(method)) {}
// Read-only access to the stored name.
109 TString const &tString() const { return fName; }
110
111 private:
113 };
114
116
117 // optimize all booked methods (well, if desired by the method):
// iterates through all booked methods and runs parameter tuning for those
// that use it
118 std::map<TString,Double_t> OptimizeAllMethods (TString fomType="ROCIntegral", TString fitType="FitGA");
121
122 // training: iterates through all booked methods and calls their training
123 void TrainAllMethods ();
126
127 // testing: evaluates all booked methods on the testing data
128 void TestAllMethods();
129
130 // performance evaluation:
// EvaluateAllMethods iterates over all booked MVAs and calls their evaluation
// methods; EvaluateAllVariables iterates over all MVA input variables and
// evaluates them
131 void EvaluateAllMethods( void );
132 void EvaluateAllVariables(DataLoader *loader, TString options = "" );
133
// evaluate variable importance
// NOTE(review): VIType is not declared in this extract; presumably it selects
// between the private short/all/random importance helpers — confirm.
134 TH1F* EvaluateImportance( DataLoader *loader,VIType vitype, Types::EMVA theMethod, TString methodTitle, const char *theOption = "" );
135
136 // delete all methods and reset the method vector
137 void DeleteAllMethods( void );
138
139 // accessors:
// GetMethod returns a pointer to the MVA that corresponds to the given method
// title; HasMethod checks whether a given method name is defined for a given
// dataset
140 IMethod* GetMethod( const TString& datasetname, const TString& title ) const;
141 Bool_t HasMethod( const TString& datasetname, const TString& title ) const;
142
// verbosity flag: Verbose() returns fVerbose
143 Bool_t Verbose( void ) const { return fVerbose; }
144 void SetVerbose( Bool_t v=kTRUE );
145
146 // make ROOT-independent C++ class for classifier response
147 // (classifier-specific implementation)
148 // If no classifier name is given, help messages for all booked
149 // classifiers are printed
// NOTE(review): the last sentence is identical to the PrintHelpMessage comment
// below and looks copy-pasted; confirm what MakeClass does when methodTitle
// is empty.
150 virtual void MakeClass(const TString& datasetname , const TString& methodTitle = "" ) const;
151
152 // prints classifier-specific help messages, dedicated to
153 // help with the optimisation and configuration options tuning.
154 // If no classifier name is given, help messages for all booked
155 // classifiers are printed
156 void PrintHelpMessage(const TString& datasetname , const TString& methodTitle = "" ) const;
157
159
// returns fSilentFile
160 Bool_t IsSilentFile() const { return fSilentFile;}
162
167
168 // Methods to get a TGraph for an indicated method in dataset.
169 // Optional title and axis added with fLegend=kTRUE.
170 // Argument iClass used in multiclass settings, otherwise ignored.
// NOTE(review): the GetROCCurve signature listed in the index of this page
// names the flag setTitles, not fLegend — the comment above may be out of date.
175
176 // Methods to get a TMultiGraph for a given class and all methods in dataset.
179
180 // Draw all ROC curves of a given class for all methods in the dataset.
183
184 private:
185
186 // the beautiful greeting message (prints the welcome message)
187 void Greetings();
188
189 //evaluate the simple case that is removing 1 variable at a time
191 //evaluate all variables combinations
193 //evaluate randomly given a number of seeds
195
// NOTE(review): importances and varNames are taken by value, so each call
// copies both vectors.
196 TH1F* GetImportance(const int nbits,std::vector<Double_t> importances,std::vector<TString> varNames);
197
198 // Helpers for public facing ROC methods
203
// NOTE(review): the parameter is named like a data member (f prefix) although
// it is a function argument.
204 void WriteDataInformation(DataSetInfo& fDataSetInfo);
205
207
209
210 private:
211
212 // data members
213
// NOTE(review): despite the "fg" prefix this member is declared non-static here.
214 TFile* fgTargetFile; ///<! ROOT output file
215
216
217 std::vector<TMVA::VariableTransformBase*> fDefaultTrfs; ///<! list of transformations on default DataSet
218
219 // cd to local directory
// NOTE(review): no directory-related declaration follows in this extract; the
// comment above appears stale.
220 TString fOptions; ///<! option string given by construction (presently only "V")
221 TString fTransformations; ///<! list of transformations to test
222 Bool_t fVerbose; ///<! verbose mode
223 TString fVerboseLevel; ///<! verbosity level, controls granularity of logging
224 Bool_t fCorrelations; ///<! enable to calculate correlations
225 Bool_t fROC; ///<! enable to calculate ROC values
226 Bool_t fSilentFile; ///<! used in constructor without file
227
228 TString fJobName; ///<! jobname, used as extension in weight file names
229
230 Types::EAnalysisType fAnalysisType; ///<! the training type
231 Bool_t fModelPersistence;///<! option to save the trained model in xml file or using serialization
232
233
234 protected:
235
236 ClassDefOverride(Factory,0); // The factory creates all MVA methods, and performs their training and testing
237 };
238
239} // namespace TMVA
240
241#endif
bool Bool_t
Boolean (0=false, 1=true) (bool)
Definition RtypesCore.h:77
unsigned int UInt_t
Unsigned integer 4 bytes (unsigned int)
Definition RtypesCore.h:60
double Double_t
Double 8 bytes.
Definition RtypesCore.h:73
constexpr Bool_t kTRUE
Definition RtypesCore.h:107
#define ClassDefOverride(name, id)
Definition Rtypes.h:348
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
char name[80]
Definition TGX11.cxx:110
The Canvas class.
Definition TCanvas.h:23
Describe directory structure in memory.
Definition TDirectory.h:45
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-like logical structure, possibly including subdirectory hierarchies.
Definition TFile.h:131
A TGraph is an object made of two arrays X and Y with npoints each.
Definition TGraph.h:41
1-D histogram with a float per channel (see TH1 documentation)
Definition TH1.h:879
Class to perform cross validation, splitting the dataloader into folds.
Class that contains all the data information.
Definition DataSetInfo.h:62
MethodName(Types::EMVA method)
Definition Factory.h:108
TString const & tString() const
Definition Factory.h:109
This is the main MVA steering class.
Definition Factory.h:80
void PrintHelpMessage(const TString &datasetname, const TString &methodTitle="") const
Print predefined help message of classifier.
Definition Factory.cxx:1328
Bool_t fSilentFile
! used in constructor without file
Definition Factory.h:226
Bool_t fCorrelations
! enable to calculate correlations
Definition Factory.h:224
Bool_t IsModelPersistence() const
Definition Factory.h:161
TString fOptions
! option string given by construction (presently only "V")
Definition Factory.h:220
std::vector< IMethod * > MVector
Definition Factory.h:84
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition Factory.cxx:1109
Bool_t Verbose(void) const
Definition Factory.h:143
void WriteDataInformation(DataSetInfo &fDataSetInfo)
Definition Factory.cxx:597
Factory(TString theJobName, TFile *theTargetFile, TString theOption="")
Standard constructor.
Definition Factory.cxx:113
void TestAllMethods()
Evaluates all booked methods on the testing data and adds the output to the Results in the corresponding DataSet.
Definition Factory.cxx:1266
void TrainAllMethodsForClassification(void)
Definition Factory.h:124
Bool_t fVerbose
! verbose mode
Definition Factory.h:222
void EvaluateAllMethods(void)
Iterates over all MVAs that have been booked, and calls their evaluation methods.
Definition Factory.cxx:1371
TH1F * EvaluateImportanceRandom(DataLoader *loader, UInt_t nseeds, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition Factory.cxx:2468
TH1F * GetImportance(const int nbits, std::vector< Double_t > importances, std::vector< TString > varNames)
Definition Factory.cxx:2586
Bool_t fROC
! enable to calculate ROC values
Definition Factory.h:225
void EvaluateAllVariables(DataLoader *loader, TString options="")
Iterates over all MVA input variables and evaluates them.
Definition Factory.cxx:1355
TDirectory * RootBaseDir()
Definition Factory.h:158
TString fVerboseLevel
! verbosity level, controls granularity of logging
Definition Factory.h:223
TMultiGraph * GetROCCurveAsMultiGraph(DataLoader *loader, UInt_t iClass, Types::ETreeType type=Types::kTesting)
Generate a collection of graphs, for all methods for a given class.
Definition Factory.cxx:983
TH1F * EvaluateImportance(DataLoader *loader, VIType vitype, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Evaluate Variable Importance.
Definition Factory.cxx:2212
void OptimizeAllMethodsForRegression(TString fomType="ROCIntegral", TString fitType="FitGA")
Definition Factory.h:120
Double_t GetROCIntegral(DataLoader *loader, TString theMethodName, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Calculate the integral of the ROC curve, also known as the area under curve (AUC), for a given method.
Definition Factory.cxx:844
std::map< TString, MVector * > fMethodsMap
Definition Factory.h:85
void SetInputTreesFromEventAssignTrees()
virtual ~Factory()
Destructor.
Definition Factory.cxx:306
MethodBase * BookMethod(DataLoader *loader, MethodName theMethodName, TString methodTitle, TString theOption="")
Books an MVA classifier or regression method.
Definition Factory.cxx:358
virtual void MakeClass(const TString &datasetname, const TString &methodTitle="") const
Definition Factory.cxx:1300
MethodBase * BookMethodWeightfile(DataLoader *dataloader, TMVA::Types::EMVA methodType, const TString &weightfile)
Adds an already constructed method to be managed by this factory.
Definition Factory.cxx:496
Bool_t fModelPersistence
! option to save the trained model in xml file or using serialization
Definition Factory.h:231
std::map< TString, Double_t > OptimizeAllMethods(TString fomType="ROCIntegral", TString fitType="FitGA")
Iterates through all booked methods and sees if they use parameter tuning and if so does just that,...
Definition Factory.cxx:696
void OptimizeAllMethodsForClassification(TString fomType="ROCIntegral", TString fitType="FitGA")
Definition Factory.h:119
ROCCurve * GetROC(DataLoader *loader, TString theMethodName, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Private method to generate a ROCCurve instance for a given method.
Definition Factory.cxx:744
Bool_t IsSilentFile() const
Definition Factory.h:160
TH1F * EvaluateImportanceShort(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition Factory.cxx:2353
Types::EAnalysisType fAnalysisType
! the training type
Definition Factory.h:230
TString fJobName
! jobname, used as extension in weight file names
Definition Factory.h:228
Bool_t HasMethod(const TString &datasetname, const TString &title) const
Checks whether a given method name is defined for a given dataset.
Definition Factory.cxx:581
TGraph * GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles=kTRUE, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Argument iClass specifies the class to generate the ROC curve in a multiclass setting.
Definition Factory.cxx:907
void TrainAllMethodsForRegression(void)
Definition Factory.h:125
TH1F * EvaluateImportanceAll(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition Factory.cxx:2241
void SetVerbose(Bool_t v=kTRUE)
Definition Factory.cxx:343
TFile * fgTargetFile
! ROOT output file
Definition Factory.h:214
std::vector< TMVA::VariableTransformBase * > fDefaultTrfs
! list of transformations on default DataSet
Definition Factory.h:217
IMethod * GetMethod(const TString &datasetname, const TString &title) const
Returns pointer to MVA that corresponds to given method title.
Definition Factory.cxx:561
void DeleteAllMethods(void)
Delete methods.
Definition Factory.cxx:324
TString fTransformations
! list of transformations to test
Definition Factory.h:221
void Greetings()
Print welcome message.
Definition Factory.cxx:295
Interface for all concrete MVA method implementations.
Definition IMethod.h:53
Virtual base class for all MVA methods.
Definition MethodBase.h:111
Singleton class for Global types used by TMVA.
Definition Types.h:71
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition TMultiGraph.h:34
Basic string class.
Definition TString.h:138
A TTree represents a columnar dataset.
Definition TTree.h:89
void forward(const LAYERDATA &prevLayerData, LAYERDATA &currLayerData)
apply the weights (and functions) in forward direction of the DNN
create variable transformations