Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
Envelope.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Omar Zapata, Kim Albertsson
3
4/*************************************************************************
5 * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12#include <TMVA/Envelope.h>
13
14#include <TMVA/Configurable.h>
15#include <TMVA/DataLoader.h>
16#include <TMVA/MethodBase.h>
17#include <TMVA/OptionMap.h>
19#include <TMVA/Types.h>
20
21#include <TMVA/VariableInfo.h>
23
24#include <TAxis.h>
25#include <TFile.h>
26#include <TH2.h>
27
28using namespace TMVA;
29
30//_______________________________________________________________________
31/**
32Constructor for the initialization of Envelopes,
33differents Envelopes may needs differents constructors then
34this is a generic one protected.
35\param name the name algorithm.
36\param dataloader TMVA::DataLoader object with the data.
37\param file optional file to save the results.
38\param options extra options for the algorithm.
39*/
40Envelope::Envelope(const TString &name, DataLoader *dataloader, TFile *file, const TString options)
41 : Configurable(options), fDataLoader(dataloader), fFile(file), fModelPersistence(kTRUE), fVerbose(kFALSE),
42 fTransformations("I"), fSilentFile(kFALSE), fJobs(1)
43{
44 SetName(name.Data());
45 // render silent
46 if (gTools().CheckForSilentOption(GetOptions()))
47 Log().InhibitOutput(); // make sure is silent if wanted to
48
50 DeclareOptionRef(fVerbose, "V", "Verbose flag");
51
52 DeclareOptionRef(fModelPersistence, "ModelPersistence",
53 "Option to save the trained model in xml file or using serialization");
54 DeclareOptionRef(fTransformations, "Transformations", "List of transformations to test; formatting example: "
55 "\"Transformations=I;D;P;U;G,D\", for identity, "
56 "decorrelation, PCA, Uniform and Gaussianisation followed by "
57 "decorrelation transformations");
58 DeclareOptionRef(fJobs, "Jobs", "Option to run hign level algorithms in parallel with multi-thread");
59}
60
61//_______________________________________________________________________
63{
64}
65
66//_______________________________________________________________________
67/**
68Method to see if a file is available to save results
69\return Boolean with the status.
70*/
72
73//_______________________________________________________________________
74/**
75Method to get the pointer to TFile object.
76\return pointer to TFile object.
77*/
78TFile* Envelope::GetFile(){return fFile.get();}
79
80//_______________________________________________________________________
81/**
82Method to set the pointer to TFile object,
83with a writable file.
84\param file pointer to TFile object.
85*/
86void Envelope::SetFile(TFile *file){fFile=std::shared_ptr<TFile>(file);}
87
88//_______________________________________________________________________
89/**
90Method to see if the algorithm should print extra information.
91\return Boolean with the status.
92*/
94
95//_______________________________________________________________________
96/**
97Method to enable printing extra information in the algorithms.
98\param status Boolean with the status.
99*/
101
102//_______________________________________________________________________
103/**
104Method get the Booked methods in a option map object.
105\return vector of TMVA::OptionMap objects with the information of the Booked method
106*/
107std::vector<OptionMap> &Envelope::GetMethods()
108{
109 return fMethods;
110}
111
112//_______________________________________________________________________
113/**
114Method to get the pointer to TMVA::DataLoader object.
115\return pointer to TMVA::DataLoader object.
116*/
117
119
120//_______________________________________________________________________
121/**
122Method to set the pointer to TMVA::DataLoader object.
123\param dataloader pointer to TMVA::DataLoader object.
124*/
125
127{
128 fDataLoader = std::shared_ptr<DataLoader>(dataloader);
129}
130
131//_______________________________________________________________________
132/**
133Method to see if the algorithm model is saved in xml or serialized files.
134\return Boolean with the status.
135*/
136Bool_t TMVA::Envelope::IsModelPersistence(){return fModelPersistence; }
137
138//_______________________________________________________________________
139/**
140Method enable model persistence, then algorithms model is saved in xml or serialized files.
141\param status Boolean with the status.
142*/
143void TMVA::Envelope::SetModelPersistence(Bool_t status){fModelPersistence=status;}
144
145//_______________________________________________________________________
146/**
147Method to book the machine learning method to perform the algorithm.
148\param method enum TMVA::Types::EMVA with the type of the mva method
149\param methodTitle String with the method title.
150\param options String with the options for the method.
151*/
152void TMVA::Envelope::BookMethod(Types::EMVA method, TString methodTitle, TString options){
153 BookMethod(Types::Instance().GetMethodName(method), methodTitle, options);
154}
155
156//_______________________________________________________________________
157/**
158Method to book the machine learning method to perform the algorithm.
159\param methodName String with the name of the mva method
160\param methodTitle String with the method title.
161\param options String with the options for the method.
162*/
163void TMVA::Envelope::BookMethod(TString methodName, TString methodTitle, TString options){
164 for (auto &meth : fMethods) {
165 if (meth.GetValue<TString>("MethodName") == methodName && meth.GetValue<TString>("MethodTitle") == methodTitle) {
166 Log() << kFATAL << "Booking failed since method with title <" << methodTitle << "> already exists "
167 << "in with DataSet Name <" << fDataLoader->GetName() << "> " << Endl;
168 }
169 }
170 OptionMap fMethod;
171 fMethod["MethodName"] = methodName;
172 fMethod["MethodTitle"] = methodTitle;
173 fMethod["MethodOptions"] = options;
174
175 fMethods.push_back(fMethod);
176}
177
178//_______________________________________________________________________
179/**
180Method to parse the internal option string.
181*/
183{
184
185 Bool_t silent = kFALSE;
186#ifdef WIN32
187 // under Windows, switch progress bar and color off by default, as the typical windows shell doesn't handle these
188 // (would need different sequences..)
189 Bool_t color = kFALSE;
190 Bool_t drawProgressBar = kFALSE;
191#else
192 Bool_t color = !gROOT->IsBatch();
193 Bool_t drawProgressBar = kTRUE;
194#endif
195 DeclareOptionRef(color, "Color", "Flag for coloured screen output (default: True, if in batch mode: False)");
196 DeclareOptionRef(drawProgressBar, "DrawProgressBar",
197 "Draw progress bar to display training, testing and evaluation schedule (default: True)");
198 DeclareOptionRef(silent, "Silent", "Batch mode: boolean silent flag inhibiting any output from TMVA after the "
199 "creation of the factory class object (default: False)");
200
202 CheckForUnusedOptions();
203
204 if (IsVerbose())
205 Log().SetMinType(kVERBOSE);
206
207 // global settings
208 gConfig().SetUseColor(color);
209 gConfig().SetSilent(silent);
210 gConfig().SetDrawProgressBar(drawProgressBar);
211}
212
213//_______________________________________________________________________
214/**
215 * function to check methods booked
216 * \param methodname Method's name.
217 * \param methodtitle title associated to the method.
218 * \return true if the method was booked.
219 */
221{
222 for (auto &meth : fMethods) {
223 if (meth.GetValue<TString>("MethodName") == methodname && meth.GetValue<TString>("MethodTitle") == methodtitle)
224 return kTRUE;
225 }
226 return kFALSE;
227}
228
229//_______________________________________________________________________
230/**
231 * method to save Train/Test information into the output file.
232 * \param fDataSetInfo TMVA::DataSetInfo object reference
233 * \param fAnalysisType Types::kMulticlass and Types::kRegression
234 */
236{
237 RootBaseDir()->cd();
238
239 if (!RootBaseDir()->GetDirectory(fDataSetInfo.GetName()))
240 RootBaseDir()->mkdir(fDataSetInfo.GetName());
241 else
242 return; // loader is now in the output file, we dont need to save again
243
244 RootBaseDir()->cd(fDataSetInfo.GetName());
245 fDataSetInfo.GetDataSet(); // builds dataset (including calculation of correlation matrix)
246
247 // correlation matrix of the default DS
248 const TMatrixD *m(0);
249 const TH2 *h(0);
250
251 if (fAnalysisType == Types::kMulticlass) {
252 for (UInt_t cls = 0; cls < fDataSetInfo.GetNClasses(); cls++) {
253 m = fDataSetInfo.CorrelationMatrix(fDataSetInfo.GetClassInfo(cls)->GetName());
254 h = fDataSetInfo.CreateCorrelationMatrixHist(
255 m, TString("CorrelationMatrix") + fDataSetInfo.GetClassInfo(cls)->GetName(),
256 TString("Correlation Matrix (") + fDataSetInfo.GetClassInfo(cls)->GetName() + TString(")"));
257 if (h != 0) {
258 h->Write();
259 delete h;
260 }
261 }
262 } else {
263 m = fDataSetInfo.CorrelationMatrix("Signal");
264 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrixS", "Correlation Matrix (signal)");
265 if (h != 0) {
266 h->Write();
267 delete h;
268 }
269
270 m = fDataSetInfo.CorrelationMatrix("Background");
271 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrixB", "Correlation Matrix (background)");
272 if (h != 0) {
273 h->Write();
274 delete h;
275 }
276
277 m = fDataSetInfo.CorrelationMatrix("Regression");
278 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrix", "Correlation Matrix");
279 if (h != 0) {
280 h->Write();
281 delete h;
282 }
283 }
284
285 // some default transformations to evaluate
286 // NOTE: all transformations are destroyed after this test
287 TString processTrfs = "I"; //"I;N;D;P;U;G,D;"
288
289 // plus some user defined transformations
290 processTrfs = fTransformations;
291
292 // remove any trace of identity transform - if given (avoid to apply it twice)
293 std::vector<TMVA::TransformationHandler *> trfs;
294 TransformationHandler *identityTrHandler = 0;
295
296 std::vector<TString> trfsDef = gTools().SplitString(processTrfs, ';');
297 std::vector<TString>::iterator trfsDefIt = trfsDef.begin();
298 for (; trfsDefIt != trfsDef.end(); ++trfsDefIt) {
299 trfs.push_back(new TMVA::TransformationHandler(fDataSetInfo, "Envelope"));
300 TString trfS = (*trfsDefIt);
301
302 // Log() << kINFO << Endl;
303 Log() << kDEBUG << "current transformation string: '" << trfS.Data() << "'" << Endl;
304 TMVA::CreateVariableTransforms(trfS, fDataSetInfo, *(trfs.back()), Log());
305
306 if (trfS.BeginsWith('I'))
307 identityTrHandler = trfs.back();
308 }
309
310 const std::vector<Event *> &inputEvents = fDataSetInfo.GetDataSet()->GetEventCollection();
311
312 // apply all transformations
313 std::vector<TMVA::TransformationHandler *>::iterator trfIt = trfs.begin();
314
315 for (; trfIt != trfs.end(); ++trfIt) {
316 // setting a Root dir causes the variables distributions to be saved to the root file
317 (*trfIt)->SetRootDir(RootBaseDir()->GetDirectory(fDataSetInfo.GetName())); // every dataloader have its own dir
318 (*trfIt)->CalcTransformations(inputEvents);
319 }
320 if (identityTrHandler)
321 identityTrHandler->PrintVariableRanking();
322
323 // clean up
324 for (trfIt = trfs.begin(); trfIt != trfs.end(); ++trfIt)
325 delete *trfIt;
326}
#define h(i)
Definition RSha256.hxx:106
const Bool_t kFALSE
Definition RtypesCore.h:92
const Bool_t kTRUE
Definition RtypesCore.h:91
char name[80]
Definition TGX11.cxx:110
#define gROOT
Definition TROOT.h:406
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition TFile.h:54
Service class for 2-Dim histogram classes.
Definition TH2.h:30
void SetDrawProgressBar(Bool_t d)
Definition Config.h:71
void SetUseColor(Bool_t uc)
Definition Config.h:62
void SetSilent(Bool_t s)
Definition Config.h:65
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
virtual void ParseOptions()
options parser
const TString & GetOptions() const
MsgLogger & Log() const
Class that contains all the data information.
Definition DataSetInfo.h:62
virtual const char * GetName() const
Returns name of object.
Definition DataSetInfo.h:71
const TMatrixD * CorrelationMatrix(const TString &className) const
UInt_t GetNClasses() const
DataSet * GetDataSet() const
returns data set
TH2 * CreateCorrelationMatrixHist(const TMatrixD *m, const TString &hName, const TString &hTitle) const
ClassInfo * GetClassInfo(Int_t clNum) const
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:216
Bool_t HasMethod(TString methodname, TString methodtitle)
function to check methods booked
Definition Envelope.cxx:220
~Envelope()
Default destructor.
Definition Envelope.cxx:62
Bool_t IsModelPersistence()
Method to see if the algorithm model is saved in xml or serialized files.
Definition Envelope.cxx:136
std::shared_ptr< TFile > fFile
data
Definition Envelope.h:48
DataLoader * GetDataLoader()
Method to get the pointer to TMVA::DataLoader object.
Definition Envelope.cxx:118
Bool_t fModelPersistence
file to save the results
Definition Envelope.h:49
Bool_t IsSilentFile()
Method to see if a file is available to save results.
Definition Envelope.cxx:71
void SetDataLoader(DataLoader *dalaloader)
Method to set the pointer to TMVA::DataLoader object.
Definition Envelope.cxx:126
virtual void BookMethod(TString methodname, TString methodtitle, TString options="")
Method to book the machine learning method to perform the algorithm.
Definition Envelope.cxx:163
std::vector< OptionMap > fMethods
Definition Envelope.h:46
void SetVerbose(Bool_t status)
Method enable print extra information in the algorithms.
Definition Envelope.cxx:100
void SetFile(TFile *file)
Method to set the pointer to TFile object, with a writable file.
Definition Envelope.cxx:86
Bool_t IsVerbose()
Method to see if the algorithm should print extra information.
Definition Envelope.cxx:93
Bool_t fVerbose
flag to save the trained model
Definition Envelope.h:50
void SetModelPersistence(Bool_t status=kTRUE)
Method enable model persistence, then algorithms model is saved in xml or serialized files.
Definition Envelope.cxx:143
std::shared_ptr< DataLoader > fDataLoader
Booked method information.
Definition Envelope.h:47
virtual void ParseOptions()
Method to parse the internal option string.
Definition Envelope.cxx:182
TFile * GetFile()
Method to get the pointer to TFile object.
Definition Envelope.cxx:78
std::vector< OptionMap > & GetMethods()
Method get the Booked methods in a option map object.
Definition Envelope.cxx:107
TString fTransformations
flag for extra information
Definition Envelope.h:51
UInt_t fJobs
procpool object
Definition Envelope.h:56
Envelope(const TString &name, DataLoader *dataloader=nullptr, TFile *file=nullptr, const TString options="")
timer to measure the time.
Definition Envelope.cxx:40
void WriteDataInformation(TMVA::DataSetInfo &fDataSetInfo, TMVA::Types::EAnalysisType fAnalysisType)
method to save Train/Test information into the output file.
Definition Envelope.cxx:235
static void InhibitOutput()
Definition MsgLogger.cxx:73
class to store options for the different methods
Definition OptionMap.h:34
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition Tools.cxx:1211
Class that contains all the data information.
void PrintVariableRanking() const
prints ranking of input variables
static Types & Instance()
the single instance of "Types" if it already exists, or create it (Singleton)
Definition Types.cxx:69
@ kMulticlass
Definition Types.h:131
virtual void SetName(const char *name)
Set the name of the TNamed.
Definition TNamed.cxx:140
virtual const char * GetName() const
Returns name of object.
Definition TNamed.h:47
virtual Int_t Write(const char *name=0, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
Definition TObject.cxx:798
Basic string class.
Definition TString.h:136
const char * Data() const
Definition TString.h:369
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
Definition TString.h:615
create variable transformations
Config & gConfig()
Tools & gTools()
void CreateVariableTransforms(const TString &trafoDefinition, TMVA::DataSetInfo &dataInfo, TMVA::TransformationHandler &transformationHandler, TMVA::MsgLogger &log)
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:158
Definition file.py:1
auto * m
Definition textangle.C:8