Logo ROOT   6.14/05
Reference Guide
MethodCrossValidation.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Kim Albertsson
3 
4 /*************************************************************************
5  * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 /*! \class TMVA::MethodCrossValidation
13 \ingroup TMVA
14 */
16 
17 #include "TMVA/ClassifierFactory.h"
18 #include "TMVA/Config.h"
19 #include "TMVA/CvSplit.h"
20 #include "TMVA/MethodCategory.h"
21 #include "TMVA/Tools.h"
22 #include "TMVA/Types.h"
23 
24 #include "TSystem.h"
25 
26 REGISTER_METHOD(CrossValidation)
27 
29 
30 ////////////////////////////////////////////////////////////////////////////////
31 ///
32 
34  DataSetInfo &theData, const TString &theOption)
35  : TMVA::MethodBase(jobName, Types::kCrossValidation, methodTitle, theData, theOption), fSplitExpr(nullptr)
36 {
37 }
38 
39 ////////////////////////////////////////////////////////////////////////////////
40 
42  : TMVA::MethodBase(Types::kCrossValidation, theData, theWeightFile), fSplitExpr(nullptr)
43 {
44 }
45 
46 ////////////////////////////////////////////////////////////////////////////////
47 /// Destructor.
48 ///
49 
51 
52 ////////////////////////////////////////////////////////////////////////////////
53 
55 {
56  DeclareOptionRef(fEncapsulatedMethodName, "EncapsulatedMethodName", "");
57  DeclareOptionRef(fEncapsulatedMethodTypeName, "EncapsulatedMethodTypeName", "");
58  DeclareOptionRef(fNumFolds, "NumFolds", "Number of folds to generate");
59  DeclareOptionRef(fOutputEnsembling = TString("None"), "OutputEnsembling",
60  "Combines output from contained methods. If None, no combination is performed. (default None)");
61  AddPreDefVal(TString("None"));
62  AddPreDefVal(TString("Avg"));
63  DeclareOptionRef(fSplitExprString, "SplitExpr", "The expression used to assign events to folds");
64 }
65 
66 ////////////////////////////////////////////////////////////////////////////////
67 /// Options that are used ONLY for the READER to ensure backward compatibility.
68 
70 {
72 }
73 
74 ////////////////////////////////////////////////////////////////////////////////
75 /// The option string is decoded, for available options see "DeclareOptions".
76 
78 {
79  Log() << kDEBUG << "ProcessOptions -- fNumFolds: " << fNumFolds << Endl;
80  Log() << kDEBUG << "ProcessOptions -- fEncapsulatedMethodName: " << fEncapsulatedMethodName << Endl;
81  Log() << kDEBUG << "ProcessOptions -- fEncapsulatedMethodTypeName: " << fEncapsulatedMethodTypeName << Endl;
82 
83  if (fSplitExprString != TString("")) {
84  fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(new CvSplitKFoldsExpr(DataInfo(), fSplitExprString));
85  }
86 
87  for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
88  TString weightfile = GetWeightFileNameForFold(iFold);
89 
90  Log() << kINFO << "Reading weightfile: " << weightfile << Endl;
92  fEncapsulatedMethods.push_back(fold_method);
93  }
94 }
95 
96 ////////////////////////////////////////////////////////////////////////////////
97 /// Common initialisation with defaults for the Method.
98 
100 {
101  fMulticlassValues = std::vector<Float_t>(DataInfo().GetNClasses());
102  fRegressionValues = std::vector<Float_t>(DataInfo().GetNTargets());
103 }
104 
105 ////////////////////////////////////////////////////////////////////////////////
106 /// Reset the method, as if it had just been instantiated (forget all training etc.).
107 
109 
110 ////////////////////////////////////////////////////////////////////////////////
111 /// \brief Returns filename of weight file for a given fold.
112 /// \param[in] iFold Ordinal of the fold. Range: 0 to NumFolds exclusive.
113 ///
115 {
116  if (iFold >= fNumFolds) {
117  Log() << kFATAL << iFold << " out of range. "
118  << "Should be < " << fNumFolds << "." << Endl;
119  }
120 
121  TString foldStr = Form("fold%i", iFold + 1);
123  TString weightfile = fileDir + "/" + fJobName + "_" + fEncapsulatedMethodName + "_" + foldStr + ".weights.xml";
124 
125  return weightfile;
126 }
127 
128 ////////////////////////////////////////////////////////////////////////////////
129 /// Call the Optimizer with the set of parameters and ranges that
130 /// are meant to be tuned.
131 
132 // std::map<TString,Double_t> TMVA::MethodCrossValidation::OptimizeTuningParameters(TString fomType, TString fitType)
133 // {
134 // }
135 
136 ////////////////////////////////////////////////////////////////////////////////
137 /// Set the tuning parameters according to the argument.
138 
139 // void TMVA::MethodCrossValidation::SetTuneParameters(std::map<TString,Double_t> tuneParameters)
140 // {
141 // }
142 
143 ////////////////////////////////////////////////////////////////////////////////
144 /// training.
145 
147 
148 ////////////////////////////////////////////////////////////////////////////////
149 /// \brief Reads in a weight file and instantiates the corresponding method
150 /// \param[in] methodTypeName Canonical name of the method type. E.g. `"BDT"`
151 /// for Boosted Decision Trees.
152 /// \param[in] weightfile File to read method parameters from
155 {
156  TMVA::MethodBase *m = dynamic_cast<MethodBase *>(
157  ClassifierFactory::Instance().Create(std::string(methodTypeName.Data()), DataInfo(), weightfile));
158 
159  if (m->GetMethodType() == Types::kCategory) {
160  Log() << kFATAL << "MethodCategory not supported for the moment." << Endl;
161  }
162 
163  TString fileDir = TString(DataInfo().GetName()) + "/" + gConfig().GetIONames().fWeightFileDir;
164  m->SetWeightFileDir(fileDir);
165  // m->SetModelPersistence(fModelPersistence);
166  // m->SetSilentFile(IsSilentFile());
168  m->SetupMethod();
169  m->ReadStateFromFile();
170  // m->SetTestvarName(testvarName);
171 
172  return m;
173 }
174 
175 ////////////////////////////////////////////////////////////////////////////////
176 /// Write weights to XML.
177 
179 {
180  void *wght = gTools().AddChild(parent, "Weights");
181 
182  gTools().AddAttr(wght, "JobName", fJobName);
183  gTools().AddAttr(wght, "SplitExpr", fSplitExprString);
184  gTools().AddAttr(wght, "NumFolds", fNumFolds);
185  gTools().AddAttr(wght, "EncapsulatedMethodName", fEncapsulatedMethodName);
186  gTools().AddAttr(wght, "EncapsulatedMethodTypeName", fEncapsulatedMethodTypeName);
187  gTools().AddAttr(wght, "OutputEnsembling", fOutputEnsembling);
188 
189  for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
190  TString weightfile = GetWeightFileNameForFold(iFold);
191 
192  // TODO: Add a swithch in options for using either split files or only one.
193  // TODO: This would store the method inside MethodCrossValidation
194  // Another option is to store the folds as separate files.
195  // // Retrieve encap. method for fold n
196  // MethodBase * method = InstantiateMethodFromXML(fEncapsulatedMethodTypeName, weightfile);
197  //
198  // // Serialise encapsulated method for fold n
199  // void* foldNode = gTools().AddChild(parent, foldStr);
200  // method->WriteStateToXML(foldNode);
201  }
202 }
203 
204 ////////////////////////////////////////////////////////////////////////////////
205 /// Reads from the xml file.
206 ///
207 
209 {
210  gTools().ReadAttr(parent, "JobName", fJobName);
211  gTools().ReadAttr(parent, "SplitExpr", fSplitExprString);
212  gTools().ReadAttr(parent, "NumFolds", fNumFolds);
213  gTools().ReadAttr(parent, "EncapsulatedMethodName", fEncapsulatedMethodName);
214  gTools().ReadAttr(parent, "EncapsulatedMethodTypeName", fEncapsulatedMethodTypeName);
215  gTools().ReadAttr(parent, "OutputEnsembling", fOutputEnsembling);
216 
217  // Read in methods for all folds
218  for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
219  TString weightfile = GetWeightFileNameForFold(iFold);
220 
221  Log() << kINFO << "Reading weightfile: " << weightfile << Endl;
223  fEncapsulatedMethods.push_back(fold_method);
224  }
225 
226  // SplitExpr
227  if (fSplitExprString != TString("")) {
228  fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(new CvSplitKFoldsExpr(DataInfo(), fSplitExprString));
229  }
230 }
231 
232 ////////////////////////////////////////////////////////////////////////////////
233 /// Read the weights
234 ///
235 
237 {
238  Log() << kFATAL << "CrossValidation currently supports only reading from XML." << Endl;
239 }
240 
241 ////////////////////////////////////////////////////////////////////////////////
242 ///
243 
245 {
246  const Event *ev = GetEvent();
247 
248  if (fOutputEnsembling == "None") {
249  if (fSplitExpr != nullptr) {
250  // K-folds with a deterministic split
251  UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
252  return fEncapsulatedMethods.at(iFold)->GetMvaValue(err, errUpper);
253  } else {
254  // K-folds with a random split was used
255  UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
256  return fEncapsulatedMethods.at(iFold)->GetMvaValue(err, errUpper);
257  }
258  } else if (fOutputEnsembling == "Avg") {
259  Double_t val = 0.0;
260  for (auto &m : fEncapsulatedMethods) {
261  val += m->GetMvaValue(err, errUpper);
262  }
263  return val / fEncapsulatedMethods.size();
264  } else {
265  Log() << kFATAL << "Ensembling type " << fOutputEnsembling << " unknown" << Endl;
266  return 0; // Cannot happen
267  }
268 }
269 
270 ////////////////////////////////////////////////////////////////////////////////
271 /// Get the multiclass MVA response.
272 
274 {
275  const Event *ev = GetEvent();
276 
277  if (fOutputEnsembling == "None") {
278  if (fSplitExpr != nullptr) {
279  // K-folds with a deterministic split
280  UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
281  return fEncapsulatedMethods.at(iFold)->GetMulticlassValues();
282  } else {
283  // K-folds with a random split was used
284  UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
285  return fEncapsulatedMethods.at(iFold)->GetMulticlassValues();
286  }
287  } else if (fOutputEnsembling == "Avg") {
288 
289  for (auto &e : fMulticlassValues) {
290  e = 0;
291  }
292 
293  for (auto &m : fEncapsulatedMethods) {
294  auto methodValues = m->GetMulticlassValues();
295  for (size_t i = 0; i < methodValues.size(); ++i) {
296  fMulticlassValues[i] += methodValues[i];
297  }
298  }
299 
300  for (auto &e : fMulticlassValues) {
301  e /= fEncapsulatedMethods.size();
302  }
303 
304  return fMulticlassValues;
305 
306  } else {
307  Log() << kFATAL << "Ensembling type " << fOutputEnsembling << " unknown" << Endl;
308  return fMulticlassValues; // Cannot happen
309  }
310 }
311 
312 ////////////////////////////////////////////////////////////////////////////////
313 /// Get the regression value generated by the containing methods.
314 
316 {
317  const Event *ev = GetEvent();
318 
319  if (fOutputEnsembling == "None") {
320  if (fSplitExpr != nullptr) {
321  // K-folds with a deterministic split
322  UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
323  return fEncapsulatedMethods.at(iFold)->GetRegressionValues();
324  } else {
325  // K-folds with a random split was used
326  UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
327  return fEncapsulatedMethods.at(iFold)->GetRegressionValues();
328  }
329  } else if (fOutputEnsembling == "Avg") {
330 
331  for (auto &e : fRegressionValues) {
332  e = 0;
333  }
334 
335  for (auto &m : fEncapsulatedMethods) {
336  auto methodValues = m->GetRegressionValues();
337  for (size_t i = 0; i < methodValues.size(); ++i) {
338  fRegressionValues[i] += methodValues[i];
339  }
340  }
341 
342  for (auto &e : fRegressionValues) {
343  e /= fEncapsulatedMethods.size();
344  }
345 
346  return fRegressionValues;
347 
348  } else {
349  Log() << kFATAL << "Ensembling type " << fOutputEnsembling << " unknown" << Endl;
350  return fRegressionValues; // Cannot happen
351  }
352 }
353 
354 ////////////////////////////////////////////////////////////////////////////////
355 ///
356 
358 {
359  // // Used for evaluation, which is outside the life time of MethodCrossEval.
360  // Log() << kFATAL << "Method CrossValidation should not be created manually,"
361  // " only as part of using TMVA::Reader." << Endl;
362  // return;
363 }
364 
365 ////////////////////////////////////////////////////////////////////////////////
366 ///
367 
369 {
370  Log() << kWARNING
371  << "Method CrossValidation should not be created manually,"
372  " only as part of using TMVA::Reader."
373  << Endl;
374 }
375 
376 ////////////////////////////////////////////////////////////////////////////////
377 ///
378 
380 {
381  return nullptr;
382 }
383 
384 ////////////////////////////////////////////////////////////////////////////////
385 
387  UInt_t /*numberTargets*/)
388 {
389  return kTRUE;
390  // if (fEncapsulatedMethods.size() == 0) {return kFALSE;}
391  // if (fEncapsulatedMethods.at(0) == nullptr) {return kFALSE;}
392  // return fEncapsulatedMethods.at(0)->HasAnalysisType(type, numberClasses, numberTargets);
393 }
394 
395 ////////////////////////////////////////////////////////////////////////////////
396 /// Make ROOT-independent C++ class for classifier response (classifier-specific implementation).
397 
398 void TMVA::MethodCrossValidation::MakeClassSpecific(std::ostream & /*fout*/, const TString & /*className*/) const
399 {
400  Log() << kWARNING << "MakeClassSpecific not implemented for CrossValidation" << Endl;
401 }
402 
403 ////////////////////////////////////////////////////////////////////////////////
404 /// Specific class header.
405 
406 void TMVA::MethodCrossValidation::MakeClassSpecificHeader(std::ostream & /*fout*/, const TString & /*className*/) const
407 {
408  Log() << kWARNING << "MakeClassSpecificHeader not implemented for CrossValidation" << Endl;
409 }
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
Types::EAnalysisType fAnalysisType
Definition: MethodBase.h:584
void ReadWeightsFromStream(std::istream &istr)
Read the weights.
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
MethodBase * InstantiateMethodFromXML(TString methodTypeName, TString weightfile) const
Reads in a weight file and instantiates the corresponding method.
Singleton class for Global types used by TMVA.
Definition: Types.h:73
auto * m
Definition: textangle.C:8
void WriteMonitoringHistosToFile(void) const
write special monitoring histograms to file; dummy implementation here ...
void ReadWeightsFromXML(void *parent)
Reads from the xml file.
Config & gConfig()
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
MsgLogger & Log() const
Definition: Configurable.h:122
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
EAnalysisType
Definition: Types.h:127
Virtual base Class for all MVA method.
Definition: MethodBase.h:109
void DeclareCompatibilityOptions()
Options that are used ONLY for the READER to ensure backward compatibility.
Basic string class.
Definition: TString.h:131
Ranking for variables in method (implementation)
Definition: Ranking.h:48
virtual const char * DirName(const char *pathname)
Return the directory name in pathname.
Definition: TSystem.cxx:1004
bool Bool_t
Definition: RtypesCore.h:59
TString fJobName
Definition: MethodBase.h:603
UInt_t GetNClasses() const
Definition: DataSetInfo.h:136
void Reset(void)
Reset the method, as if it had just been instantiated (forget all training etc.). ...
const std::vector< Float_t > & GetRegressionValues()
Get the regression value generated by the containing methods.
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:353
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1136
const Event * GetEvent() const
Definition: MethodBase.h:740
DataSet * Data() const
Definition: MethodBase.h:400
TString fWeightFileDir
Definition: Config.h:112
void ReadStateFromFile()
Function to read options and weights from file.
MethodCrossValidation(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
IONames & GetIONames()
Definition: Config.h:90
std::vector< Float_t > fMulticlassValues
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:411
DataSetInfo & DataInfo() const
Definition: MethodBase.h:401
Class that contains all the data information.
Definition: DataSetInfo.h:60
std::unique_ptr< CvSplitKFoldsExpr > fSplitExpr
virtual ~MethodCrossValidation(void)
Destructor.
UInt_t GetNTargets() const
Definition: DataSetInfo.h:111
R__EXTERN TSystem * gSystem
Definition: TSystem.h:540
const char * GetName() const
Definition: MethodBase.h:325
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
unsigned int UInt_t
Definition: RtypesCore.h:42
void ProcessOptions()
The option string is decoded, for available options see "DeclareOptions".
char * Form(const char *fmt,...)
void Train(void)
Call the Optimizer with the set of parameters and ranges that are meant to be tuned.
void MakeClassSpecificHeader(std::ostream &, const TString &) const
Specific class header.
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:335
Tools & gTools()
TString GetWeightFileName() const
retrieve weight file name
void Init(void)
Common initialisation with defaults for the Method.
#define ClassImp(name)
Definition: Rtypes.h:359
double Double_t
Definition: RtypesCore.h:55
you should not use this method at all Int_t Int_t Double_t Double_t Double_t e
Definition: TRolke.cxx:630
std::vector< Float_t > fRegressionValues
void AddPreDefVal(const T &)
Definition: Configurable.h:168
#define REGISTER_METHOD(CLASS)
for example
TString GetWeightFileNameForFold(UInt_t iFold) const
Returns filename of weight file for a given fold.
Abstract ClassifierFactory template that handles arbitrary types.
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility; they are hence without any...
Definition: MethodBase.cxx:601
void AddWeightsXMLTo(void *parent) const
Write weights to XML.
void SetWeightFileDir(TString fileDir)
set directory of weight file
const std::vector< Float_t > & GetMulticlassValues()
Get the multiclass MVA response.
std::vector< MethodBase * > fEncapsulatedMethods
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
const Bool_t kTRUE
Definition: RtypesCore.h:87
Types::EMVA GetMethodType() const
Definition: MethodBase.h:324
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition: MethodBase.h:427
void MakeClassSpecific(std::ostream &, const TString &) const
Make ROOT-independent C++ class for classifier response (classifier-specific implementation).
const char * Data() const
Definition: TString.h:364