Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodCrossValidation.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Kim Albertsson
3
4/*************************************************************************
5 * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12/*! \class TMVA::MethodCrossValidation
13\ingroup TMVA
14*/
16
18#include "TMVA/Config.h"
19#include "TMVA/CvSplit.h"
20#include "TMVA/MethodCategory.h"
21#include "TMVA/Tools.h"
22#include "TMVA/Types.h"
23
24#include "TSystem.h"
25
27
28
29////////////////////////////////////////////////////////////////////////////////
30///
31
34 : TMVA::MethodBase(jobName, Types::kCrossValidation, methodTitle, theData, theOption), fSplitExpr(nullptr)
35{
   // Empty on purpose: everything is done in the initializer list above, which forwards the
   // arguments to MethodBase with the kCrossValidation type and starts with no split expression.
36}
37
38////////////////////////////////////////////////////////////////////////////////
39
44
45////////////////////////////////////////////////////////////////////////////////
46/// Destructor.
47///
48
50
51////////////////////////////////////////////////////////////////////////////////
52
54{
55 DeclareOptionRef(fEncapsulatedMethodName, "EncapsulatedMethodName", "");
56 DeclareOptionRef(fEncapsulatedMethodTypeName, "EncapsulatedMethodTypeName", "");
57 DeclareOptionRef(fNumFolds, "NumFolds", "Number of folds to generate");
58 DeclareOptionRef(fOutputEnsembling = TString("None"), "OutputEnsembling",
59 "Combines output from contained methods. If None, no combination is performed. (default None)");
60 AddPreDefVal(TString("None"));
61 AddPreDefVal(TString("Avg"));
62 DeclareOptionRef(fSplitExprString, "SplitExpr", "The expression used to assign events to folds");
63}
64
65////////////////////////////////////////////////////////////////////////////////
66/// Options that are used ONLY for the READER to ensure backward compatibility.
67
72
73////////////////////////////////////////////////////////////////////////////////
74/// The option string is decoded, for available options see "DeclareOptions".
75
77{
78 Log() << kDEBUG << "ProcessOptions -- fNumFolds: " << fNumFolds << Endl;
79 Log() << kDEBUG << "ProcessOptions -- fEncapsulatedMethodName: " << fEncapsulatedMethodName << Endl;
80 Log() << kDEBUG << "ProcessOptions -- fEncapsulatedMethodTypeName: " << fEncapsulatedMethodTypeName << Endl;
81
82 if (fSplitExprString != TString("")) {
83 fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(new CvSplitKFoldsExpr(DataInfo(), fSplitExprString));
84 }
85
86 for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
87 TString weightfile = GetWeightFileNameForFold(iFold);
88
89 Log() << kINFO << "Reading weightfile: " << weightfile << Endl;
90 MethodBase *fold_method = InstantiateMethodFromXML(fEncapsulatedMethodTypeName, weightfile);
91 fEncapsulatedMethods.push_back(fold_method);
92 }
93}
94
95////////////////////////////////////////////////////////////////////////////////
96/// Common initialisation with defaults for the Method.
97
99{
100 fMulticlassValues = std::vector<Float_t>(DataInfo().GetNClasses());
101 fRegressionValues = std::vector<Float_t>(DataInfo().GetNTargets());
102}
103
104////////////////////////////////////////////////////////////////////////////////
105/// Reset the method, as if it had just been instantiated (forget all training etc.).
106
108
109////////////////////////////////////////////////////////////////////////////////
110/// \brief Returns filename of weight file for a given fold.
111/// \param[in] iFold Ordinal of the fold. Range: 0 to NumFolds exclusive.
112///
114{
115 if (iFold >= fNumFolds) {
116 Log() << kFATAL << iFold << " out of range. "
117 << "Should be < " << fNumFolds << "." << Endl;
118 }
119
120 TString foldStr = TString::Format("fold%i", iFold + 1);
121 TString fileDir = gSystem->GetDirName(GetWeightFileName());
122 TString weightfile = fileDir + "/" + fJobName + "_" + fEncapsulatedMethodName + "_" + foldStr + ".weights.xml";
123
124 return weightfile;
125}
126
127////////////////////////////////////////////////////////////////////////////////
128/// Call the Optimizer with the set of parameters and ranges that
129/// are meant to be tuned.
130
131// std::map<TString,Double_t> TMVA::MethodCrossValidation::OptimizeTuningParameters(TString fomType, TString fitType)
132// {
133// }
134
135////////////////////////////////////////////////////////////////////////////////
136/// Set the tuning parameters according to the argument.
137
138// void TMVA::MethodCrossValidation::SetTuneParameters(std::map<TString,Double_t> tuneParameters)
139// {
140// }
141
142////////////////////////////////////////////////////////////////////////////////
143/// training.
144
146
147////////////////////////////////////////////////////////////////////////////////
148/// \brief Reads in a weight file and instantiates the corresponding method
149/// \param[in] methodTypeName Canonical name of the method type. E.g. `"BDT"`
150/// for Boosted Decision Trees.
151/// \param[in] weightfile File to read method parameters from
154{
155 TMVA::MethodBase *m = dynamic_cast<MethodBase *>(
156 ClassifierFactory::Instance().Create(std::string(methodTypeName.Data()), DataInfo(), weightfile));
157
158 if (m->GetMethodType() == Types::kCategory) {
159 Log() << kFATAL << "MethodCategory not supported for the moment." << Endl;
160 }
161
162 TString fileDir = TString(DataInfo().GetName()) + "/" + gConfig().GetIONames().fWeightFileDir;
163 m->SetWeightFileDir(fileDir);
164 // m->SetModelPersistence(fModelPersistence);
165 // m->SetSilentFile(IsSilentFile());
166 m->SetAnalysisType(fAnalysisType);
167 m->SetupMethod();
168 m->ReadStateFromFile();
169 // m->SetTestvarName(testvarName);
170
171 return m;
172}
173
174////////////////////////////////////////////////////////////////////////////////
175/// Write weights to XML.
176
178{
179 void *wght = gTools().AddChild(parent, "Weights");
180
181 gTools().AddAttr(wght, "JobName", fJobName);
182 gTools().AddAttr(wght, "SplitExpr", fSplitExprString);
183 gTools().AddAttr(wght, "NumFolds", fNumFolds);
184 gTools().AddAttr(wght, "EncapsulatedMethodName", fEncapsulatedMethodName);
185 gTools().AddAttr(wght, "EncapsulatedMethodTypeName", fEncapsulatedMethodTypeName);
186 gTools().AddAttr(wght, "OutputEnsembling", fOutputEnsembling);
187
188 for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
189 TString weightfile = GetWeightFileNameForFold(iFold);
190
191 // TODO: Add a swithch in options for using either split files or only one.
192 // TODO: This would store the method inside MethodCrossValidation
193 // Another option is to store the folds as separate files.
194 // // Retrieve encap. method for fold n
195 // MethodBase * method = InstantiateMethodFromXML(fEncapsulatedMethodTypeName, weightfile);
196 //
197 // // Serialise encapsulated method for fold n
198 // void* foldNode = gTools().AddChild(parent, foldStr);
199 // method->WriteStateToXML(foldNode);
200 }
201}
202
203////////////////////////////////////////////////////////////////////////////////
204/// Reads from the xml file.
205///
206
208{
209 gTools().ReadAttr(parent, "JobName", fJobName);
210 gTools().ReadAttr(parent, "SplitExpr", fSplitExprString);
211 gTools().ReadAttr(parent, "NumFolds", fNumFolds);
212 gTools().ReadAttr(parent, "EncapsulatedMethodName", fEncapsulatedMethodName);
213 gTools().ReadAttr(parent, "EncapsulatedMethodTypeName", fEncapsulatedMethodTypeName);
214 gTools().ReadAttr(parent, "OutputEnsembling", fOutputEnsembling);
215
216 // Read in methods for all folds
217 for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
218 TString weightfile = GetWeightFileNameForFold(iFold);
219
220 Log() << kINFO << "Reading weightfile: " << weightfile << Endl;
221 MethodBase *fold_method = InstantiateMethodFromXML(fEncapsulatedMethodTypeName, weightfile);
222 fEncapsulatedMethods.push_back(fold_method);
223 }
224
225 // SplitExpr
226 if (fSplitExprString != TString("")) {
227 fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(new CvSplitKFoldsExpr(DataInfo(), fSplitExprString));
228 } else {
229 Log() << kFATAL << "MethodCrossValidation supports XML reading only for deterministic splitting !" << Endl;
230 }
231}
232
233////////////////////////////////////////////////////////////////////////////////
234/// Read the weights
235///
236
238{
   // Reading from a stream is not implemented: the only action is a fatal-level message
   // stating that XML is the supported format.
239 Log() << kFATAL << "CrossValidation currently supports only reading from XML." << Endl;
240}
241
242////////////////////////////////////////////////////////////////////////////////
243///
244
246{
247 const Event *ev = GetEvent();
248
249 if (fOutputEnsembling == "None") {
250 if (fSplitExpr != nullptr) {
251 // K-folds with a deterministic split
252 UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
253 return fEncapsulatedMethods.at(iFold)->GetMvaValue(err, errUpper);
254 } else {
255 // K-folds with a random split was used
256 UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
257 return fEncapsulatedMethods.at(iFold)->GetMvaValue(err, errUpper);
258 }
259 } else if (fOutputEnsembling == "Avg") {
260 Double_t val = 0.0;
261 for (auto &m : fEncapsulatedMethods) {
262 val += m->GetMvaValue(err, errUpper);
263 }
264 return val / fEncapsulatedMethods.size();
265 } else {
266 Log() << kFATAL << "Ensembling type " << fOutputEnsembling << " unknown" << Endl;
267 return 0; // Cannot happen
268 }
269}
270
271////////////////////////////////////////////////////////////////////////////////
272/// Get the multiclass MVA response.
273
275{
276 const Event *ev = GetEvent();
277
278 if (fOutputEnsembling == "None") {
279 if (fSplitExpr != nullptr) {
280 // K-folds with a deterministic split
281 UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
282 return fEncapsulatedMethods.at(iFold)->GetMulticlassValues();
283 } else {
284 // K-folds with a random split was used
285 UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
286 return fEncapsulatedMethods.at(iFold)->GetMulticlassValues();
287 }
288 } else if (fOutputEnsembling == "Avg") {
289
290 for (auto &e : fMulticlassValues) {
291 e = 0;
292 }
293
294 for (auto &m : fEncapsulatedMethods) {
295 auto methodValues = m->GetMulticlassValues();
296 for (size_t i = 0; i < methodValues.size(); ++i) {
297 fMulticlassValues[i] += methodValues[i];
298 }
299 }
300
301 for (auto &e : fMulticlassValues) {
302 e /= fEncapsulatedMethods.size();
303 }
304
305 return fMulticlassValues;
306
307 } else {
308 Log() << kFATAL << "Ensembling type " << fOutputEnsembling << " unknown" << Endl;
309 return fMulticlassValues; // Cannot happen
310 }
311}
312
313////////////////////////////////////////////////////////////////////////////////
314/// Get the regression value generated by the containing methods.
315
317{
318 const Event *ev = GetEvent();
319
320 if (fOutputEnsembling == "None") {
321 if (fSplitExpr != nullptr) {
322 // K-folds with a deterministic split
323 UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
324 return fEncapsulatedMethods.at(iFold)->GetRegressionValues();
325 } else {
326 // K-folds with a random split was used
327 UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
328 return fEncapsulatedMethods.at(iFold)->GetRegressionValues();
329 }
330 } else if (fOutputEnsembling == "Avg") {
331
332 for (auto &e : fRegressionValues) {
333 e = 0;
334 }
335
336 for (auto &m : fEncapsulatedMethods) {
337 auto methodValues = m->GetRegressionValues();
338 for (size_t i = 0; i < methodValues.size(); ++i) {
339 fRegressionValues[i] += methodValues[i];
340 }
341 }
342
343 for (auto &e : fRegressionValues) {
344 e /= fEncapsulatedMethods.size();
345 }
346
347 return fRegressionValues;
348
349 } else {
350 Log() << kFATAL << "Ensembling type " << fOutputEnsembling << " unknown" << Endl;
351 return fRegressionValues; // Cannot happen
352 }
353}
354
355////////////////////////////////////////////////////////////////////////////////
356///
357
359{
   // Intentionally a no-op: the body holds only the disabled code below, which would have
   // raised a fatal message about creating this method manually.
360 // // Used for evaluation, which is outside the life time of MethodCrossEval.
361 // Log() << kFATAL << "Method CrossValidation should not be created manually,"
362 // " only as part of using TMVA::Reader." << Endl;
363 // return;
364}
365
366////////////////////////////////////////////////////////////////////////////////
367///
368
370{
   // The only action is a warning-level message telling the user that this method is meant
   // to be created through TMVA::Reader rather than by hand.
371 Log() << kWARNING
372 << "Method CrossValidation should not be created manually,"
373 " only as part of using TMVA::Reader."
374 << Endl;
375}
376
377////////////////////////////////////////////////////////////////////////////////
378///
379
381{
   // No ranking object is built here; callers always receive a null pointer.
382 return nullptr;
383}
384
385////////////////////////////////////////////////////////////////////////////////
386
388 UInt_t /*numberTargets*/)
389{
   // Every request is answered with kTRUE. Delegating the decision to the first encapsulated
   // method is kept below as disabled code only.
390 return kTRUE;
391 // if (fEncapsulatedMethods.size() == 0) {return kFALSE;}
392 // if (fEncapsulatedMethods.at(0) == nullptr) {return kFALSE;}
393 // return fEncapsulatedMethods.at(0)->HasAnalysisType(type, numberClasses, numberTargets);
394}
395
396////////////////////////////////////////////////////////////////////////////////
397/// Make ROOT-independent C++ class for classifier response (classifier-specific implementation).
398
399void TMVA::MethodCrossValidation::MakeClassSpecific(std::ostream & /*fout*/, const TString & /*className*/) const
400{
401 Log() << kWARNING << "MakeClassSpecific not implemented for CrossValidation" << Endl;
402}
403
404////////////////////////////////////////////////////////////////////////////////
405/// Specific class header.
406
407void TMVA::MethodCrossValidation::MakeClassSpecificHeader(std::ostream & /*fout*/, const TString & /*className*/) const
408{
409 Log() << kWARNING << "MakeClassSpecificHeader not implemented for CrossValidation" << Endl;
410}
#define REGISTER_METHOD(CLASS)
for example
#define e(i)
Definition RSha256.hxx:103
constexpr Bool_t kTRUE
Definition RtypesCore.h:107
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
R__EXTERN TSystem * gSystem
Definition TSystem.h:572
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
TString fWeightFileDir
Definition Config.h:124
IONames & GetIONames()
Definition Config.h:98
Class that contains all the data information.
Definition DataSetInfo.h:62
Virtual base Class for all MVA method.
Definition MethodBase.h:111
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
friend class MethodCrossValidation
Definition MethodBase.h:117
void Train(void) override
Call the Optimizer with the set of parameters and ranges that are meant to be tuned.
const Ranking * CreateRanking() override
void WriteMonitoringHistosToFile(void) const override
Write special monitoring histograms to file (dummy implementation here).
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets) override
void AddWeightsXMLTo(void *parent) const override
Write weights to XML.
void DeclareCompatibilityOptions() override
Options that are used ONLY for the READER to ensure backward compatibility.
void Init(void) override
Common initialisation with defaults for the Method.
void MakeClassSpecific(std::ostream &, const TString &) const override
Make ROOT-independent C++ class for classifier response (classifier-specific implementation).
TString GetWeightFileNameForFold(UInt_t iFold) const
Returns filename of weight file for a given fold.
const std::vector< Float_t > & GetMulticlassValues() override
Get the multiclass MVA response.
MethodBase * InstantiateMethodFromXML(TString methodTypeName, TString weightfile) const
Reads in a weight file and instantiates the corresponding method.
void ReadWeightsFromStream(std::istream &istr) override
Read the weights.
void ProcessOptions() override
The option string is decoded, for available options see "DeclareOptions".
Double_t GetMvaValue(Double_t *err=nullptr, Double_t *errUpper=nullptr) override
void Reset(void) override
Reset the method, as if it had just been instantiated (forget all training etc.).
const std::vector< Float_t > & GetRegressionValues() override
Get the regression value generated by the containing methods.
void ReadWeightsFromXML(void *parent) override
Reads from the xml file.
void MakeClassSpecificHeader(std::ostream &, const TString &) const override
Specific class header.
virtual ~MethodCrossValidation(void)
Destructor.
Ranking for variables in method (implementation)
Definition Ranking.h:48
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition Tools.h:329
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition Tools.h:347
void * AddChild(void *parent, const char *childname, const char *content=nullptr, bool isRootNode=false)
add child node
Definition Tools.cxx:1124
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kCategory
Definition Types.h:97
Basic string class.
Definition TString.h:138
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2384
virtual TString GetDirName(const char *pathname)
Return the directory name in pathname.
Definition TSystem.cxx:1042
create variable transformations
Config & gConfig()
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
TMarker m
Definition textangle.C:8