Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodCrossValidation.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Kim Albertsson
3
4/*************************************************************************
5 * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12/*! \class TMVA::MethodCrossValidation
13\ingroup TMVA
14*/
16
18#include "TMVA/Config.h"
19#include "TMVA/CvSplit.h"
20#include "TMVA/MethodCategory.h"
21#include "TMVA/Tools.h"
22#include "TMVA/Types.h"
23
24#include "TSystem.h"
25
// NOTE(review): presumably registers this method type with the TMVA ClassifierFactory
// under the name "CrossValidation" -- confirm against the REGISTER_METHOD macro definition.
26REGISTER_METHOD(CrossValidation)
27
29
30////////////////////////////////////////////////////////////////////////////////
31/// Constructor taking a job name, method title, dataset info and option string.
/// All four are forwarded to MethodBase with method type Types::kCrossValidation.
/// fSplitExpr starts out null; it is only created later (in ProcessOptions or
/// ReadWeightsFromXML) when a non-empty SplitExpr is configured.
32
34 DataSetInfo &theData, const TString &theOption)
35 : TMVA::MethodBase(jobName, Types::kCrossValidation, methodTitle, theData, theOption), fSplitExpr(nullptr)
36{
37}
38
39////////////////////////////////////////////////////////////////////////////////
/// Constructor taking dataset info and the name of a weight file.
/// Both are forwarded to MethodBase with method type Types::kCrossValidation;
/// fSplitExpr starts out null.
40
42 : TMVA::MethodBase(Types::kCrossValidation, theData, theWeightFile), fSplitExpr(nullptr)
43{
44}
45
46////////////////////////////////////////////////////////////////////////////////
47/// Destructor.
48///
49
51
52////////////////////////////////////////////////////////////////////////////////
53
55{
56 DeclareOptionRef(fEncapsulatedMethodName, "EncapsulatedMethodName", "");
57 DeclareOptionRef(fEncapsulatedMethodTypeName, "EncapsulatedMethodTypeName", "");
58 DeclareOptionRef(fNumFolds, "NumFolds", "Number of folds to generate");
59 DeclareOptionRef(fOutputEnsembling = TString("None"), "OutputEnsembling",
60 "Combines output from contained methods. If None, no combination is performed. (default None)");
61 AddPreDefVal(TString("None"));
62 AddPreDefVal(TString("Avg"));
63 DeclareOptionRef(fSplitExprString, "SplitExpr", "The expression used to assign events to folds");
64}
65
66////////////////////////////////////////////////////////////////////////////////
67/// Options that are used ONLY for the READER to ensure backward compatibility.
68
70{
// NOTE(review): listing line 71 is missing from this extract (see the gap in numbering).
// It presumably forwards to the base-class MethodBase::DeclareCompatibilityOptions();
// confirm against the full source. No CrossValidation-specific compatibility options are visible here.
72}
73
74////////////////////////////////////////////////////////////////////////////////
75/// The option string is decoded, for available options see "DeclareOptions".
76
78{
79 Log() << kDEBUG << "ProcessOptions -- fNumFolds: " << fNumFolds << Endl;
80 Log() << kDEBUG << "ProcessOptions -- fEncapsulatedMethodName: " << fEncapsulatedMethodName << Endl;
81 Log() << kDEBUG << "ProcessOptions -- fEncapsulatedMethodTypeName: " << fEncapsulatedMethodTypeName << Endl;
82
83 if (fSplitExprString != TString("")) {
84 fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(new CvSplitKFoldsExpr(DataInfo(), fSplitExprString));
85 }
86
87 for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
88 TString weightfile = GetWeightFileNameForFold(iFold);
89
90 Log() << kINFO << "Reading weightfile: " << weightfile << Endl;
91 MethodBase *fold_method = InstantiateMethodFromXML(fEncapsulatedMethodTypeName, weightfile);
92 fEncapsulatedMethods.push_back(fold_method);
93 }
94}
95
96////////////////////////////////////////////////////////////////////////////////
97/// Common initialisation with defaults for the Method.
98
100{
101 fMulticlassValues = std::vector<Float_t>(DataInfo().GetNClasses());
102 fRegressionValues = std::vector<Float_t>(DataInfo().GetNTargets());
103}
104
105////////////////////////////////////////////////////////////////////////////////
106/// Reset the method, as if it had just been instantiated (forget all training etc.).
107
109
110////////////////////////////////////////////////////////////////////////////////
111/// \brief Returns filename of weight file for a given fold.
112/// \param[in] iFold Ordinal of the fold. Range: 0 to NumFolds exclusive.
113///
115{
116 if (iFold >= fNumFolds) {
117 Log() << kFATAL << iFold << " out of range. "
118 << "Should be < " << fNumFolds << "." << Endl;
119 }
120
121 TString foldStr = Form("fold%i", iFold + 1);
122 TString fileDir = gSystem->GetDirName(GetWeightFileName());
123 TString weightfile = fileDir + "/" + fJobName + "_" + fEncapsulatedMethodName + "_" + foldStr + ".weights.xml";
124
125 return weightfile;
126}
127
128////////////////////////////////////////////////////////////////////////////////
129/// Call the Optimizer with the set of parameters and ranges that
130/// are meant to be tuned.
131
132// std::map<TString,Double_t> TMVA::MethodCrossValidation::OptimizeTuningParameters(TString fomType, TString fitType)
133// {
134// }
135
136////////////////////////////////////////////////////////////////////////////////
137/// Set the tuning parameters according to the argument.
138
139// void TMVA::MethodCrossValidation::SetTuneParameters(std::map<TString,Double_t> tuneParameters)
140// {
141// }
142
143////////////////////////////////////////////////////////////////////////////////
144/// training.
145
147
148////////////////////////////////////////////////////////////////////////////////
149/// \brief Reads in a weight file and instantiates the corresponding method
150/// \param[in] methodTypeName Canonical name of the method type. E.g. `"BDT"`
151/// for Boosted Decision Trees.
152/// \param[in] weightfile File to read method parameters from
155{
156 TMVA::MethodBase *m = dynamic_cast<MethodBase *>(
157 ClassifierFactory::Instance().Create(std::string(methodTypeName.Data()), DataInfo(), weightfile));
158
159 if (m->GetMethodType() == Types::kCategory) {
160 Log() << kFATAL << "MethodCategory not supported for the moment." << Endl;
161 }
162
163 TString fileDir = TString(DataInfo().GetName()) + "/" + gConfig().GetIONames().fWeightFileDir;
164 m->SetWeightFileDir(fileDir);
165 // m->SetModelPersistence(fModelPersistence);
166 // m->SetSilentFile(IsSilentFile());
167 m->SetAnalysisType(fAnalysisType);
168 m->SetupMethod();
169 m->ReadStateFromFile();
170 // m->SetTestvarName(testvarName);
171
172 return m;
173}
174
175////////////////////////////////////////////////////////////////////////////////
176/// Write weights to XML.
177
179{
180 void *wght = gTools().AddChild(parent, "Weights");
181
182 gTools().AddAttr(wght, "JobName", fJobName);
183 gTools().AddAttr(wght, "SplitExpr", fSplitExprString);
184 gTools().AddAttr(wght, "NumFolds", fNumFolds);
185 gTools().AddAttr(wght, "EncapsulatedMethodName", fEncapsulatedMethodName);
186 gTools().AddAttr(wght, "EncapsulatedMethodTypeName", fEncapsulatedMethodTypeName);
187 gTools().AddAttr(wght, "OutputEnsembling", fOutputEnsembling);
188
189 for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
190 TString weightfile = GetWeightFileNameForFold(iFold);
191
192 // TODO: Add a swithch in options for using either split files or only one.
193 // TODO: This would store the method inside MethodCrossValidation
194 // Another option is to store the folds as separate files.
195 // // Retrieve encap. method for fold n
196 // MethodBase * method = InstantiateMethodFromXML(fEncapsulatedMethodTypeName, weightfile);
197 //
198 // // Serialise encapsulated method for fold n
199 // void* foldNode = gTools().AddChild(parent, foldStr);
200 // method->WriteStateToXML(foldNode);
201 }
202}
203
204////////////////////////////////////////////////////////////////////////////////
205/// Reads from the xml file.
206///
207
209{
210 gTools().ReadAttr(parent, "JobName", fJobName);
211 gTools().ReadAttr(parent, "SplitExpr", fSplitExprString);
212 gTools().ReadAttr(parent, "NumFolds", fNumFolds);
213 gTools().ReadAttr(parent, "EncapsulatedMethodName", fEncapsulatedMethodName);
214 gTools().ReadAttr(parent, "EncapsulatedMethodTypeName", fEncapsulatedMethodTypeName);
215 gTools().ReadAttr(parent, "OutputEnsembling", fOutputEnsembling);
216
217 // Read in methods for all folds
218 for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
219 TString weightfile = GetWeightFileNameForFold(iFold);
220
221 Log() << kINFO << "Reading weightfile: " << weightfile << Endl;
222 MethodBase *fold_method = InstantiateMethodFromXML(fEncapsulatedMethodTypeName, weightfile);
223 fEncapsulatedMethods.push_back(fold_method);
224 }
225
226 // SplitExpr
227 if (fSplitExprString != TString("")) {
228 fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(new CvSplitKFoldsExpr(DataInfo(), fSplitExprString));
229 } else {
230 Log() << kFATAL << "MethodCrossValidation supports XML reading only for deterministic splitting !" << Endl;
231 }
232}
233
234////////////////////////////////////////////////////////////////////////////////
235/// Read the weights
236///
237
239{
240 Log() << kFATAL << "CrossValidation currently supports only reading from XML." << Endl;
241}
242
243////////////////////////////////////////////////////////////////////////////////
244///
245
247{
248 const Event *ev = GetEvent();
249
250 if (fOutputEnsembling == "None") {
251 if (fSplitExpr != nullptr) {
252 // K-folds with a deterministic split
253 UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
254 return fEncapsulatedMethods.at(iFold)->GetMvaValue(err, errUpper);
255 } else {
256 // K-folds with a random split was used
257 UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
258 return fEncapsulatedMethods.at(iFold)->GetMvaValue(err, errUpper);
259 }
260 } else if (fOutputEnsembling == "Avg") {
261 Double_t val = 0.0;
262 for (auto &m : fEncapsulatedMethods) {
263 val += m->GetMvaValue(err, errUpper);
264 }
265 return val / fEncapsulatedMethods.size();
266 } else {
267 Log() << kFATAL << "Ensembling type " << fOutputEnsembling << " unknown" << Endl;
268 return 0; // Cannot happen
269 }
270}
271
272////////////////////////////////////////////////////////////////////////////////
273/// Get the multiclass MVA response.
274
276{
277 const Event *ev = GetEvent();
278
279 if (fOutputEnsembling == "None") {
280 if (fSplitExpr != nullptr) {
281 // K-folds with a deterministic split
282 UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
283 return fEncapsulatedMethods.at(iFold)->GetMulticlassValues();
284 } else {
285 // K-folds with a random split was used
286 UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
287 return fEncapsulatedMethods.at(iFold)->GetMulticlassValues();
288 }
289 } else if (fOutputEnsembling == "Avg") {
290
291 for (auto &e : fMulticlassValues) {
292 e = 0;
293 }
294
295 for (auto &m : fEncapsulatedMethods) {
296 auto methodValues = m->GetMulticlassValues();
297 for (size_t i = 0; i < methodValues.size(); ++i) {
298 fMulticlassValues[i] += methodValues[i];
299 }
300 }
301
302 for (auto &e : fMulticlassValues) {
303 e /= fEncapsulatedMethods.size();
304 }
305
306 return fMulticlassValues;
307
308 } else {
309 Log() << kFATAL << "Ensembling type " << fOutputEnsembling << " unknown" << Endl;
310 return fMulticlassValues; // Cannot happen
311 }
312}
313
314////////////////////////////////////////////////////////////////////////////////
315/// Get the regression value generated by the containing methods.
316
318{
319 const Event *ev = GetEvent();
320
321 if (fOutputEnsembling == "None") {
322 if (fSplitExpr != nullptr) {
323 // K-folds with a deterministic split
324 UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
325 return fEncapsulatedMethods.at(iFold)->GetRegressionValues();
326 } else {
327 // K-folds with a random split was used
328 UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
329 return fEncapsulatedMethods.at(iFold)->GetRegressionValues();
330 }
331 } else if (fOutputEnsembling == "Avg") {
332
333 for (auto &e : fRegressionValues) {
334 e = 0;
335 }
336
337 for (auto &m : fEncapsulatedMethods) {
338 auto methodValues = m->GetRegressionValues();
339 for (size_t i = 0; i < methodValues.size(); ++i) {
340 fRegressionValues[i] += methodValues[i];
341 }
342 }
343
344 for (auto &e : fRegressionValues) {
345 e /= fEncapsulatedMethods.size();
346 }
347
348 return fRegressionValues;
349
350 } else {
351 Log() << kFATAL << "Ensembling type " << fOutputEnsembling << " unknown" << Endl;
352 return fRegressionValues; // Cannot happen
353 }
354}
355
356////////////////////////////////////////////////////////////////////////////////
357/// Intentionally does nothing: every statement in the body is commented out.
/// The commented lines preserve an earlier kFATAL that forbade manual creation.
358
360{
361 // // Used for evaluation, which is outside the life time of MethodCrossEval.
362 // Log() << kFATAL << "Method CrossValidation should not be created manually,"
363 // " only as part of using TMVA::Reader." << Endl;
364 // return;
365}
366
367////////////////////////////////////////////////////////////////////////////////
368///
369
371{
372 Log() << kWARNING
373 << "Method CrossValidation should not be created manually,"
374 " only as part of using TMVA::Reader."
375 << Endl;
376}
377
378////////////////////////////////////////////////////////////////////////////////
379///
380
382{
383 return nullptr;
384}
385
386////////////////////////////////////////////////////////////////////////////////
/// Reports every analysis type as supported: the function unconditionally returns kTRUE.
/// The commented-out alternative would delegate to the first fold method, returning
/// kFALSE while no fold methods are loaded.
387
389 UInt_t /*numberTargets*/)
390{
391 return kTRUE;
392 // if (fEncapsulatedMethods.size() == 0) {return kFALSE;}
393 // if (fEncapsulatedMethods.at(0) == nullptr) {return kFALSE;}
394 // return fEncapsulatedMethods.at(0)->HasAnalysisType(type, numberClasses, numberTargets);
395}
396
397////////////////////////////////////////////////////////////////////////////////
398/// Make ROOT-independent C++ class for classifier response (classifier-specific implementation).
399
400void TMVA::MethodCrossValidation::MakeClassSpecific(std::ostream & /*fout*/, const TString & /*className*/) const
401{
402 Log() << kWARNING << "MakeClassSpecific not implemented for CrossValidation" << Endl;
403}
404
405////////////////////////////////////////////////////////////////////////////////
406/// Specific class header.
407
408void TMVA::MethodCrossValidation::MakeClassSpecificHeader(std::ostream & /*fout*/, const TString & /*className*/) const
409{
410 Log() << kWARNING << "MakeClassSpecificHeader not implemented for CrossValidation" << Endl;
411}
#define REGISTER_METHOD(CLASS)
for example
#define e(i)
Definition RSha256.hxx:103
double Double_t
Definition RtypesCore.h:59
const Bool_t kTRUE
Definition RtypesCore.h:100
#define ClassImp(name)
Definition Rtypes.h:364
char * Form(const char *fmt,...)
R__EXTERN TSystem * gSystem
Definition TSystem.h:559
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
TString fWeightFileDir
Definition Config.h:124
IONames & GetIONames()
Definition Config.h:98
Class that contains all the data information.
Definition DataSetInfo.h:62
Virtual base Class for all MVA method.
Definition MethodBase.h:111
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
friend class MethodCrossValidation
Definition MethodBase.h:117
void MakeClassSpecific(std::ostream &, const TString &) const
Make ROOT-independent C++ class for classifier response (classifier-specific implementation).
void AddWeightsXMLTo(void *parent) const
Write weights to XML.
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
void Init(void)
Common initialisation with defaults for the Method.
void MakeClassSpecificHeader(std::ostream &, const TString &) const
Specific class header.
void Reset(void)
Reset the method, as if it had just been instantiated (forget all training etc.).
TString GetWeightFileNameForFold(UInt_t iFold) const
Returns filename of weight file for a given fold.
const std::vector< Float_t > & GetRegressionValues()
Get the regression value generated by the containing methods.
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
MethodBase * InstantiateMethodFromXML(TString methodTypeName, TString weightfile) const
Reads in a weight file and instantiates the corresponding method.
void Train(void)
Call the Optimizer with the set of parameters and ranges that are meant to be tuned.
const std::vector< Float_t > & GetMulticlassValues()
Get the multiclass MVA response.
void ReadWeightsFromXML(void *parent)
Reads from the xml file.
void DeclareCompatibilityOptions()
Options that are used ONLY for the READER to ensure backward compatibility.
void WriteMonitoringHistosToFile(void) const
write special monitoring histograms to file (dummy implementation here)
virtual ~MethodCrossValidation(void)
Destructor.
void ReadWeightsFromStream(std::istream &istr)
Read the weights.
void ProcessOptions()
The option string is decoded, for available options see "DeclareOptions".
Ranking for variables in method (implementation)
Definition Ranking.h:48
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition Tools.cxx:1124
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition Tools.h:329
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition Tools.h:347
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kCategory
Definition Types.h:97
Basic string class.
Definition TString.h:136
const char * Data() const
Definition TString.h:369
virtual TString GetDirName(const char *pathname)
Return the directory name in pathname.
Definition TSystem.cxx:1032
create variable transformations
Config & gConfig()
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
auto * m
Definition textangle.C:8