Logo ROOT  
Reference Guide
Loading...
Searching...
No Matches
MethodCrossValidation.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Kim Albertsson
3
4/*************************************************************************
5 * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12#ifndef ROOT_TMVA_MethodCrossValidation
13#define ROOT_TMVA_MethodCrossValidation
14
15//////////////////////////////////////////////////////////////////////////
16// //
17// MethodCrossValidation //
18// //
19//////////////////////////////////////////////////////////////////////////
20
21#include "TMVA/CvSplit.h"
22#include "TMVA/DataSetInfo.h"
23#include "TMVA/MethodBase.h"
24
25#include "TString.h"
26
27#include <iostream>
28#include <memory>
29#include <vector>
30#include <map>
31
32namespace TMVA {
33
34class CrossValidation;
35class Ranking;
36
37// Looks for serialised methods of the form methodTitle + "_fold" + iFold;
39
41
42public:
43 // constructor for training and reading
44 MethodCrossValidation(const TString &jobName, const TString &methodTitle, DataSetInfo &theData,
45 const TString &theOption = "");
46
47 // constructor for calculating BDT-MVA using previously generated decision trees
48 MethodCrossValidation(DataSetInfo &theData, const TString &theWeightFile);
49
50 virtual ~MethodCrossValidation(void);
51
52 // optimize tuning parameters
53 // virtual std::map<TString,Double_t> OptimizeTuningParameters(TString fomType="ROCIntegral", TString
54 // fitType="FitGA"); virtual void SetTuneParameters(std::map<TString,Double_t> tuneParameters);
55
56 // training method
57 void Train(void) override;
58
59 // revoke training
60 void Reset(void) override;
61
63
64 // write weights to file
65 void AddWeightsXMLTo(void *parent) const override;
66
67 // read weights from file
68 void ReadWeightsFromStream(std::istream &istr) override;
69 void ReadWeightsFromXML(void *parent) override;
70
71 // write method specific histos to target file
72 void WriteMonitoringHistosToFile(void) const override;
73
74 // calculate the MVA value
75 Double_t GetMvaValue(Double_t *err = nullptr, Double_t *errUpper = nullptr) override;
76 const std::vector<Float_t> &GetMulticlassValues() override;
77 const std::vector<Float_t> &GetRegressionValues() override;
78
79 // the option handling methods
80 void DeclareOptions() override;
81 void ProcessOptions() override;
82
83 // make ROOT-independent C++ class for classifier response (classifier-specific implementation)
84 void MakeClassSpecific(std::ostream &, const TString &) const override;
85 void MakeClassSpecificHeader(std::ostream &, const TString &) const override;
86
87 void GetHelpMessage() const override;
88
89 const Ranking *CreateRanking() override;
90 Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets) override;
91
92protected:
93 void Init(void) override;
94 void DeclareCompatibilityOptions() override;
95
96private:
98 MethodBase *InstantiateMethodFromXML(TString methodTypeName, TString weightfile) const;
99
100private:
105
107 std::unique_ptr<CvSplitKFoldsExpr> fSplitExpr;
108
109 std::vector<Float_t> fMulticlassValues;
110 std::vector<Float_t> fRegressionValues;
111
112 std::vector<MethodBase *> fEncapsulatedMethods;
113
114 // Used for CrossValidation with random splits (not using the
115 // CVSplitCrossValisationExpr functionality) to communicate Event to fold
116 // mapping.
117 std::map<const TMVA::Event *, UInt_t> fEventToFoldMapping;
118
119 // for backward compatibility
121};
122
123} // namespace TMVA
124
125#endif
unsigned int UInt_t
Unsigned integer 4 bytes (unsigned int).
Definition RtypesCore.h:60
bool Bool_t
Boolean (0=false, 1=true) (bool).
Definition RtypesCore.h:77
double Double_t
Double 8 bytes.
Definition RtypesCore.h:73
#define ClassDefOverride(name, id)
Definition Rtypes.h:348
Double_t err
Class to perform cross validation, splitting the dataloader into folds.
Class that contains all the data information.
Definition DataSetInfo.h:62
MethodBase(const TString &jobName, Types::EMVA methodType, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
standard constructor
void ReadWeightsFromStream(std::istream &) override=0
friend class MethodCrossValidation
Definition MethodBase.h:117
std::vector< Float_t > fRegressionValues
void Train(void) override
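Training method. (Note: the brief previously shown here, "Call the Optimizer with the set of parameters and ranges that are meant to be tuned.", describes the commented-out OptimizeTuningParameters, not Train.)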
std::vector< Float_t > fMulticlassValues
std::unique_ptr< CvSplitKFoldsExpr > fSplitExpr
const Ranking * CreateRanking() override
void WriteMonitoringHistosToFile(void) const override
Write special monitoring histograms to file (dummy implementation here).
std::vector< MethodBase * > fEncapsulatedMethods
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets) override
void AddWeightsXMLTo(void *parent) const override
Write weights to XML.
void DeclareCompatibilityOptions() override
Options that are used ONLY for the READER to ensure backward compatibility.
void Init(void) override
Common initialisation with defaults for the Method.
void MakeClassSpecific(std::ostream &, const TString &) const override
Make ROOT-independent C++ class for classifier response (classifier-specific implementation).
TString GetWeightFileNameForFold(UInt_t iFold) const
Returns filename of weight file for a given fold.
const std::vector< Float_t > & GetMulticlassValues() override
Get the multiclass MVA response.
MethodBase * InstantiateMethodFromXML(TString methodTypeName, TString weightfile) const
Reads in a weight file and instantiates the corresponding method.
void ReadWeightsFromStream(std::istream &istr) override
Read the weights.
void ProcessOptions() override
The option string is decoded, for available options see "DeclareOptions".
Double_t GetMvaValue(Double_t *err=nullptr, Double_t *errUpper=nullptr) override
void Reset(void) override
Reset the method, as if it had just been instantiated (forget all training etc.).
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
MethodCrossValidation(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
const std::vector< Float_t > & GetRegressionValues() override
Get the regression values generated by the contained (per-fold) methods.
void ReadWeightsFromXML(void *parent) override
Reads from the xml file.
void MakeClassSpecificHeader(std::ostream &, const TString &) const override
Specific class header.
virtual ~MethodCrossValidation(void)
Destructor.
Ranking for variables in method (implementation).
Definition Ranking.h:48
Basic string class.
Definition TString.h:138
create variable transformations