1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata
4 #include <TMVA/Envelope.h>
6 #include <TMVA/Configurable.h>
7 #include <TMVA/DataLoader.h>
8 #include <TMVA/MethodBase.h>
9 #include <TMVA/OptionMap.h>
11 #include <TMVA/Types.h>
13 #include <TMVA/VariableInfo.h>
14 #include <TMVA/VariableTransform.h>
16 #include <TAxis.h>
17 #include <TCanvas.h>
18 #include <TFile.h>
19 #include <TGraph.h>
20 #include <TSystem.h>
21 #include <TH2.h>
23 #include <iostream>
25 using namespace TMVA;
27 //_______________________________________________________________________
28 /**
Constructor for the initialization of Envelopes.
Different Envelopes may need different constructors, so this
is a generic, protected one.
\param name the name of the algorithm.
33 \param dataloader TMVA::DataLoader object with the data.
34 \param file optional file to save the results.
35 \param options extra options for the algorithm.
36 */
37 Envelope::Envelope(const TString &name, DataLoader *dalaloader, TFile *file, const TString options)
38  : Configurable(options), fDataLoader(dalaloader), fFile(file), fModelPersistence(kTRUE), fVerbose(kFALSE),
39  fTransformations("I"), fSilentFile(kFALSE), fJobs(1)
40 {
41  SetName(name.Data());
42  // render silent
44  Log().InhibitOutput(); // make sure is silent if wanted to
47  DeclareOptionRef(fVerbose, "V", "Verbose flag");
49  DeclareOptionRef(fModelPersistence, "ModelPersistence",
50  "Option to save the trained model in xml file or using serialization");
51  DeclareOptionRef(fTransformations, "Transformations", "List of transformations to test; formatting example: "
52  "\"Transformations=I;D;P;U;G,D\", for identity, "
53  "decorrelation, PCA, Uniform and Gaussianisation followed by "
54  "decorrelation transformations");
55  DeclareOptionRef(fJobs, "Jobs", "Option to run hign level algorithms in parallel with multi-thread");
56 }
58 //_______________________________________________________________________
60 {
61 }
63 //_______________________________________________________________________
64 /**
65 Method to see if a file is available to save results
66 \return Boolean with the status.
67 */
68 Bool_t Envelope::IsSilentFile(){return fFile==nullptr;}
70 //_______________________________________________________________________
71 /**
72 Method to get the pointer to TFile object.
73 \return pointer to TFile object.
74 */
75 TFile* Envelope::GetFile(){return fFile.get();}
77 //_______________________________________________________________________
78 /**
79 Method to set the pointer to TFile object,
80 with a writable file.
81 \param file pointer to TFile object.
82 */
83 void Envelope::SetFile(TFile *file){fFile=std::shared_ptr<TFile>(file);}
85 //_______________________________________________________________________
86 /**
87 Method to see if the algorithm should print extra information.
88 \return Boolean with the status.
89 */
92 //_______________________________________________________________________
93 /**
94 Method enable print extra information in the algorithms.
95 \param status Boolean with the status.
96 */
97 void Envelope::SetVerbose(Bool_t status){fVerbose=status;}
99 //_______________________________________________________________________
100 /**
101 Method get the Booked methods in a option map object.
102 \return vector of TMVA::OptionMap objects with the information of the Booked method
103 */
104 std::vector<OptionMap> &Envelope::GetMethods()
105 {
106  return fMethods;
107 }
109 //_______________________________________________________________________
110 /**
111 Method to get the pointer to TMVA::DataLoader object.
112 \return pointer to TMVA::DataLoader object.
113 */
117 //_______________________________________________________________________
118 /**
119 Method to set the pointer to TMVA::DataLoader object.
120 \param dalaloader pointer to TMVA::DataLoader object.
121 */
123  fDataLoader=std::shared_ptr<DataLoader>(dalaloader) ;
124 }
126 //_______________________________________________________________________
127 /**
128 Method to see if the algorithm model is saved in xml or serialized files.
129 \return Boolean with the status.
130 */
133 //_______________________________________________________________________
134 /**
Method to enable or disable model persistence; when enabled, the algorithm's model is saved in xml or serialized files.
136 \param status Boolean with the status.
137 */
140 //_______________________________________________________________________
141 /**
142 Method to book the machine learning method to perform the algorithm.
143 \param method enum TMVA::Types::EMVA with the type of the mva method
144 \param methodtitle String with the method title.
145 \param options String with the options for the method.
146 */
147 void TMVA::Envelope::BookMethod(Types::EMVA method, TString methodTitle, TString options){
148  BookMethod(Types::Instance().GetMethodName(method), methodTitle, options);
149 }
151 //_______________________________________________________________________
152 /**
153 Method to book the machine learning method to perform the algorithm.
154 \param methodname String with the name of the mva method
155 \param methodtitle String with the method title.
156 \param options String with the options for the method.
157 */
158 void TMVA::Envelope::BookMethod(TString methodName, TString methodTitle, TString options){
159  for (auto &meth : fMethods) {
160  if (meth.GetValue<TString>("MethodName") == methodName && meth.GetValue<TString>("MethodTitle") == methodTitle) {
161  Log() << kFATAL << "Booking failed since method with title <" << methodTitle << "> already exists "
162  << "in with DataSet Name <" << fDataLoader->GetName() << "> " << Endl;
163  }
164  }
165  OptionMap fMethod;
166  fMethod["MethodName"] = methodName;
167  fMethod["MethodTitle"] = methodTitle;
168  fMethod["MethodOptions"] = options;
170  fMethods.push_back(fMethod);
171 }
173 //_______________________________________________________________________
174 /**
175 Method to parse the internal option string.
176 */
178 {
180  Bool_t silent = kFALSE;
181 #ifdef WIN32
182  // under Windows, switch progress bar and color off by default, as the typical windows shell doesn't handle these
183  // (would need different sequences..)
184  Bool_t color = kFALSE;
185  Bool_t drawProgressBar = kFALSE;
186 #else
187  Bool_t color = !gROOT->IsBatch();
188  Bool_t drawProgressBar = kTRUE;
189 #endif
190  DeclareOptionRef(color, "Color", "Flag for coloured screen output (default: True, if in batch mode: False)");
191  DeclareOptionRef(drawProgressBar, "DrawProgressBar",
192  "Draw progress bar to display training, testing and evaluation schedule (default: True)");
193  DeclareOptionRef(silent, "Silent", "Batch mode: boolean silent flag inhibiting any output from TMVA after the "
194  "creation of the factory class object (default: False)");
199  if (IsVerbose())
200  Log().SetMinType(kVERBOSE);
202  // global settings
203  gConfig().SetUseColor(color);
204  gConfig().SetSilent(silent);
205  gConfig().SetDrawProgressBar(drawProgressBar);
206 }
208 //_______________________________________________________________________
209 /**
210  * function to check methods booked
211  * \param methodname Method's name.
212  * \param methodtitle title associated to the method.
213  * \return true if the method was booked.
214  */
215 Bool_t TMVA::Envelope::HasMethod(TString methodname, TString methodtitle)
216 {
217  for (auto &meth : fMethods) {
218  if (meth.GetValue<TString>("MethodName") == methodname && meth.GetValue<TString>("MethodTitle") == methodtitle)
219  return kTRUE;
220  }
221  return kFALSE;
222 }
224 //_______________________________________________________________________
225 /**
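 * Method to save dataset information into the output file: the correlation matrices
 * and the outputs of the configured variable transformations. Nothing is written
 * if a directory for this dataset already exists.
 * \param fDataSetInfo TMVA::DataSetInfo object reference
 * \param fAnalysisType analysis type: Types::kMulticlass writes one correlation matrix per class;
 *        any other type writes the signal, background and regression correlation matrices.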
229  */
231 {
232  RootBaseDir()->cd();
234  if (!RootBaseDir()->GetDirectory(fDataSetInfo.GetName()))
235  RootBaseDir()->mkdir(fDataSetInfo.GetName());
236  else
237  return; // loader is now in the output file, we dont need to save again
239  RootBaseDir()->cd(fDataSetInfo.GetName());
240  fDataSetInfo.GetDataSet(); // builds dataset (including calculation of correlation matrix)
242  // correlation matrix of the default DS
243  const TMatrixD *m(0);
244  const TH2 *h(0);
246  if (fAnalysisType == Types::kMulticlass) {
247  for (UInt_t cls = 0; cls < fDataSetInfo.GetNClasses(); cls++) {
248  m = fDataSetInfo.CorrelationMatrix(fDataSetInfo.GetClassInfo(cls)->GetName());
249  h = fDataSetInfo.CreateCorrelationMatrixHist(
250  m, TString("CorrelationMatrix") + fDataSetInfo.GetClassInfo(cls)->GetName(),
251  TString("Correlation Matrix (") + fDataSetInfo.GetClassInfo(cls)->GetName() + TString(")"));
252  if (h != 0) {
253  h->Write();
254  delete h;
255  }
256  }
257  } else {
258  m = fDataSetInfo.CorrelationMatrix("Signal");
259  h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrixS", "Correlation Matrix (signal)");
260  if (h != 0) {
261  h->Write();
262  delete h;
263  }
265  m = fDataSetInfo.CorrelationMatrix("Background");
266  h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrixB", "Correlation Matrix (background)");
267  if (h != 0) {
268  h->Write();
269  delete h;
270  }
272  m = fDataSetInfo.CorrelationMatrix("Regression");
273  h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrix", "Correlation Matrix");
274  if (h != 0) {
275  h->Write();
276  delete h;
277  }
278  }
280  // some default transformations to evaluate
281  // NOTE: all transformations are destroyed after this test
282  TString processTrfs = "I"; //"I;N;D;P;U;G,D;"
284  // plus some user defined transformations
285  processTrfs = fTransformations;
287  // remove any trace of identity transform - if given (avoid to apply it twice)
288  std::vector<TMVA::TransformationHandler *> trfs;
289  TransformationHandler *identityTrHandler = 0;
291  std::vector<TString> trfsDef = gTools().SplitString(processTrfs, ';');
292  std::vector<TString>::iterator trfsDefIt = trfsDef.begin();
293  for (; trfsDefIt != trfsDef.end(); trfsDefIt++) {
294  trfs.push_back(new TMVA::TransformationHandler(fDataSetInfo, "Envelope"));
295  TString trfS = (*trfsDefIt);
297  // Log() << kINFO << Endl;
298  Log() << kDEBUG << "current transformation string: '" << trfS.Data() << "'" << Endl;
299  TMVA::CreateVariableTransforms(trfS, fDataSetInfo, *(trfs.back()), Log());
301  if (trfS.BeginsWith('I'))
302  identityTrHandler = trfs.back();
303  }
305  const std::vector<Event *> &inputEvents = fDataSetInfo.GetDataSet()->GetEventCollection();
307  // apply all transformations
308  std::vector<TMVA::TransformationHandler *>::iterator trfIt = trfs.begin();
310  for (; trfIt != trfs.end(); trfIt++) {
311  // setting a Root dir causes the variables distributions to be saved to the root file
312  (*trfIt)->SetRootDir(RootBaseDir()->GetDirectory(fDataSetInfo.GetName())); // every dataloader have its own dir
313  (*trfIt)->CalcTransformations(inputEvents);
314  }
315  if (identityTrHandler)
316  identityTrHandler->PrintVariableRanking();
318  // clean up
319  for (trfIt = trfs.begin(); trfIt != trfs.end(); trfIt++)
320  delete *trfIt;
321 }
