35 :
TMVA::
MethodBase(jobName,
Types::kCrossValidation, methodTitle, theData, theOption), fSplitExpr(nullptr)
56 DeclareOptionRef(fEncapsulatedMethodName,
"EncapsulatedMethodName",
"");
57 DeclareOptionRef(fEncapsulatedMethodTypeName,
"EncapsulatedMethodTypeName",
"");
58 DeclareOptionRef(fNumFolds,
"NumFolds",
"Number of folds to generate");
59 DeclareOptionRef(fOutputEnsembling =
TString(
"None"),
"OutputEnsembling",
60 "Combines output from contained methods. If None, no combination is performed. (default None)");
63 DeclareOptionRef(fSplitExprString,
"SplitExpr",
"The expression used to assign events to folds");
79 Log() << kDEBUG <<
"ProcessOptions -- fNumFolds: " << fNumFolds <<
Endl;
80 Log() << kDEBUG <<
"ProcessOptions -- fEncapsulatedMethodName: " << fEncapsulatedMethodName <<
Endl;
81 Log() << kDEBUG <<
"ProcessOptions -- fEncapsulatedMethodTypeName: " << fEncapsulatedMethodTypeName <<
Endl;
83 if (fSplitExprString !=
TString(
"")) {
84 fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(
new CvSplitKFoldsExpr(DataInfo(), fSplitExprString));
87 for (
UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
88 TString weightfile = GetWeightFileNameForFold(iFold);
90 Log() << kINFO <<
"Reading weightfile: " << weightfile <<
Endl;
91 MethodBase *fold_method = InstantiateMethodFromXML(fEncapsulatedMethodTypeName, weightfile);
92 fEncapsulatedMethods.push_back(fold_method);
101 fMulticlassValues = std::vector<Float_t>(DataInfo().GetNClasses());
102 fRegressionValues = std::vector<Float_t>(DataInfo().GetNTargets());
116 if (iFold >= fNumFolds) {
117 Log() << kFATAL << iFold <<
" out of range. "
118 <<
"Should be < " << fNumFolds <<
"." <<
Endl;
123 TString weightfile = fileDir +
"/" + fJobName +
"_" + fEncapsulatedMethodName +
"_" + foldStr +
".weights.xml";
160 Log() << kFATAL <<
"MethodCategory not supported for the moment." <<
Endl;
164 m->SetWeightFileDir(fileDir);
167 m->SetAnalysisType(fAnalysisType);
169 m->ReadStateFromFile();
185 gTools().
AddAttr(wght,
"EncapsulatedMethodName", fEncapsulatedMethodName);
186 gTools().
AddAttr(wght,
"EncapsulatedMethodTypeName", fEncapsulatedMethodTypeName);
187 gTools().
AddAttr(wght,
"OutputEnsembling", fOutputEnsembling);
189 for (
UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
190 TString weightfile = GetWeightFileNameForFold(iFold);
213 gTools().
ReadAttr(parent,
"EncapsulatedMethodName", fEncapsulatedMethodName);
214 gTools().
ReadAttr(parent,
"EncapsulatedMethodTypeName", fEncapsulatedMethodTypeName);
218 for (
UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
219 TString weightfile = GetWeightFileNameForFold(iFold);
221 Log() << kINFO <<
"Reading weightfile: " << weightfile <<
Endl;
222 MethodBase *fold_method = InstantiateMethodFromXML(fEncapsulatedMethodTypeName, weightfile);
223 fEncapsulatedMethods.push_back(fold_method);
227 if (fSplitExprString !=
TString(
"")) {
228 fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(
new CvSplitKFoldsExpr(DataInfo(), fSplitExprString));
238 Log() << kFATAL <<
"CrossValidation currently supports only reading from XML." <<
Endl;
246 const Event *ev = GetEvent();
248 if (fOutputEnsembling ==
"None") {
249 if (fSplitExpr !=
nullptr) {
251 UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
252 return fEncapsulatedMethods.at(iFold)->GetMvaValue(err, errUpper);
255 UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
256 return fEncapsulatedMethods.at(iFold)->GetMvaValue(err, errUpper);
258 }
else if (fOutputEnsembling ==
"Avg") {
260 for (
auto &
m : fEncapsulatedMethods) {
261 val +=
m->GetMvaValue(err, errUpper);
263 return val / fEncapsulatedMethods.size();
265 Log() << kFATAL <<
"Ensembling type " << fOutputEnsembling <<
" unknown" <<
Endl;
275 const Event *ev = GetEvent();
277 if (fOutputEnsembling ==
"None") {
278 if (fSplitExpr !=
nullptr) {
280 UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
281 return fEncapsulatedMethods.at(iFold)->GetMulticlassValues();
284 UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
285 return fEncapsulatedMethods.at(iFold)->GetMulticlassValues();
287 }
else if (fOutputEnsembling ==
"Avg") {
289 for (
auto &
e : fMulticlassValues) {
293 for (
auto &
m : fEncapsulatedMethods) {
294 auto methodValues =
m->GetMulticlassValues();
295 for (
size_t i = 0; i < methodValues.size(); ++i) {
296 fMulticlassValues[i] += methodValues[i];
300 for (
auto &
e : fMulticlassValues) {
301 e /= fEncapsulatedMethods.size();
304 return fMulticlassValues;
307 Log() << kFATAL <<
"Ensembling type " << fOutputEnsembling <<
" unknown" <<
Endl;
308 return fMulticlassValues;
317 const Event *ev = GetEvent();
319 if (fOutputEnsembling ==
"None") {
320 if (fSplitExpr !=
nullptr) {
322 UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
323 return fEncapsulatedMethods.at(iFold)->GetRegressionValues();
326 UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
327 return fEncapsulatedMethods.at(iFold)->GetRegressionValues();
329 }
else if (fOutputEnsembling ==
"Avg") {
331 for (
auto &
e : fRegressionValues) {
335 for (
auto &
m : fEncapsulatedMethods) {
336 auto methodValues =
m->GetRegressionValues();
337 for (
size_t i = 0; i < methodValues.size(); ++i) {
338 fRegressionValues[i] += methodValues[i];
342 for (
auto &
e : fRegressionValues) {
343 e /= fEncapsulatedMethods.size();
346 return fRegressionValues;
349 Log() << kFATAL <<
"Ensembling type " << fOutputEnsembling <<
" unknown" <<
Endl;
350 return fRegressionValues;
371 <<
"Method CrossValidation should not be created manually,"
372 " only as part of using TMVA::Reader."
400 Log() << kWARNING <<
"MakeClassSpecific not implemented for CrossValidation" <<
Endl;
408 Log() << kWARNING <<
"MakeClassSpecificHeader not implemented for CrossValidation" <<
Endl;
#define REGISTER_METHOD(CLASS)
for example
char * Form(const char *fmt,...)
R__EXTERN TSystem * gSystem
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
Class that contains all the data information.
Virtual base Class for all MVA method.
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
friend class MethodCrossValidation
void GetHelpMessage() const
void MakeClassSpecific(std::ostream &, const TString &) const
Make ROOT-independent C++ class for classifier response (classifier-specific implementation).
void AddWeightsXMLTo(void *parent) const
Write weights to XML.
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
void Init(void)
Common initialisation with defaults for the Method.
void MakeClassSpecificHeader(std::ostream &, const TString &) const
Specific class header.
void Reset(void)
Reset the method, as if it had just been instantiated (forget all training etc.).
const Ranking * CreateRanking()
TString GetWeightFileNameForFold(UInt_t iFold) const
Returns filename of weight file for a given fold.
const std::vector< Float_t > & GetRegressionValues()
Get the regression value generated by the containing methods.
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
MethodBase * InstantiateMethodFromXML(TString methodTypeName, TString weightfile) const
Reads in a weight file an instantiates the corresponding method.
void Train(void)
Call the Optimizer with the set of parameters and ranges that are meant to be tuned.
const std::vector< Float_t > & GetMulticlassValues()
Get the multiclass MVA response.
void ReadWeightsFromXML(void *parent)
Reads from the xml file.
void DeclareCompatibilityOptions()
Options that are used ONLY for the READER to ensure backward compatibility.
void WriteMonitoringHistosToFile(void) const
write special monitoring histograms to file dummy implementation here --------------—
virtual ~MethodCrossValidation(void)
Destructor.
void ReadWeightsFromStream(std::istream &istr)
Read the weights.
void ProcessOptions()
The option string is decoded, for available options see "DeclareOptions".
Ranking for variables in method (implementation)
Singleton class for Global types used by TMVA.
const char * Data() const
virtual const char * DirName(const char *pathname)
Return the directory name in pathname.
std::string GetName(const std::string &scope_name)
create variable transformations
MsgLogger & Endl(MsgLogger &ml)