Logo ROOT   6.16/01
Reference Guide
MethodCrossValidation.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Kim Albertsson
3
4/*************************************************************************
5 * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12#ifndef ROOT_TMVA_MethodCrossValidation
13#define ROOT_TMVA_MethodCrossValidation
14
15//////////////////////////////////////////////////////////////////////////
16// //
17// MethodCrossValidation //
18// //
19//////////////////////////////////////////////////////////////////////////
20
21#include "TMVA/CvSplit.h"
22#include "TMVA/DataSetInfo.h"
23#include "TMVA/MethodBase.h"
24
25#include "TString.h"
26
27#include <iostream>
28#include <memory>
29
30namespace TMVA {
31
32class CrossValidation;
33class Ranking;
34
35// Looks for serialised methods of the form methodTitle + "_fold" + iFold;
37
39
40public:
41 // constructor for training and reading
42 MethodCrossValidation(const TString &jobName, const TString &methodTitle, DataSetInfo &theData,
43 const TString &theOption = "");
44
45 // constructor taking a weight file (theWeightFile); presumably used when restoring a trained method from it
46 MethodCrossValidation(DataSetInfo &theData, const TString &theWeightFile);
47
48 virtual ~MethodCrossValidation(void);
49
50 // optimize tuning parameters (currently disabled: the declarations below are commented out)
51 // virtual std::map<TString,Double_t> OptimizeTuningParameters(TString fomType="ROCIntegral", TString
52 // fitType="FitGA"); virtual void SetTuneParameters(std::map<TString,Double_t> tuneParameters);
53
54 // training method
55 void Train(void);
56
57 // reset the method as if it had just been instantiated (forget all training)
58 void Reset(void);
59
61
62 // write weights as XML, attached to the given parent node
63 void AddWeightsXMLTo(void *parent) const;
64
65 // read weights, either from a stream or from an XML node
66 void ReadWeightsFromStream(std::istream &istr);
67 void ReadWeightsFromXML(void *parent);
68
69 // write method specific histos to target file
70 void WriteMonitoringHistosToFile(void) const;
71
72 // calculate the MVA response (err / errUpper default to null), plus the multiclass and regression responses
73 Double_t GetMvaValue(Double_t *err = 0, Double_t *errUpper = 0);
74 const std::vector<Float_t> &GetMulticlassValues();
75 const std::vector<Float_t> &GetRegressionValues();
76
77 // the option handling methods
78 void DeclareOptions();
79 void ProcessOptions();
80
81 // make ROOT-independent C++ class for classifier response (classifier-specific implementation)
82 void MakeClassSpecific(std::ostream &, const TString &) const;
83 void MakeClassSpecificHeader(std::ostream &, const TString &) const;
84
85 void GetHelpMessage() const;
86
87 const Ranking *CreateRanking();
88 Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);
89
90protected:
91 void Init(void); // common initialisation with defaults for the method
93
94private:
96 MethodBase *InstantiateMethodFromXML(TString methodTypeName, TString weightfile) const; // reads a weight file and instantiates the corresponding method
97
98private:
103
105 std::unique_ptr<CvSplitKFoldsExpr> fSplitExpr; // NOTE(review): presumably the split expression used to map events to folds — confirm
106
107 std::vector<Float_t> fMulticlassValues; // storage for the reference returned by GetMulticlassValues()
108 std::vector<Float_t> fRegressionValues; // storage for the reference returned by GetRegressionValues()
109
110 std::vector<MethodBase *> fEncapsulatedMethods; // NOTE(review): presumably one method per fold (cf. the "_fold" naming above) — confirm
111
112 // Used for CrossValidation with random splits (i.e. not using the split
113 // expression, presumably the CvSplitKFoldsExpr held in fSplitExpr) to
114 // communicate the Event-to-fold mapping.
115 std::map<const TMVA::Event *, UInt_t> fEventToFoldMapping;
116
117 // for backward compatibility
119};
120
121} // namespace TMVA
122
123#endif
unsigned int UInt_t
Definition: RtypesCore.h:42
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
#define ClassDef(name, id)
Definition: Rtypes.h:324
int type
Definition: TGX11.cxx:120
Class that contains all the data information.
Definition: DataSetInfo.h:60
Virtual base class for all MVA methods.
Definition: MethodBase.h:109
virtual void ReadWeightsFromStream(std::istream &)=0
friend class MethodCrossValidation
Definition: MethodBase.h:115
std::vector< Float_t > fRegressionValues
std::vector< Float_t > fMulticlassValues
std::unique_ptr< CvSplitKFoldsExpr > fSplitExpr
void MakeClassSpecific(std::ostream &, const TString &) const
Make ROOT-independent C++ class for classifier response (classifier-specific implementation).
void AddWeightsXMLTo(void *parent) const
Write weights to XML.
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
void Init(void)
Common initialisation with defaults for the Method.
std::vector< MethodBase * > fEncapsulatedMethods
void MakeClassSpecificHeader(std::ostream &, const TString &) const
Specific class header.
void Reset(void)
Reset the method, as if it had just been instantiated (forget all training etc.).
TString GetWeightFileNameForFold(UInt_t iFold) const
Returns filename of weight file for a given fold.
const std::vector< Float_t > & GetRegressionValues()
Get the regression values generated by the contained methods.
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
MethodBase * InstantiateMethodFromXML(TString methodTypeName, TString weightfile) const
Reads in a weight file and instantiates the corresponding method.
void Train(void)
Training method.
const std::vector< Float_t > & GetMulticlassValues()
Get the multiclass MVA response.
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
void ReadWeightsFromXML(void *parent)
Reads from the xml file.
void DeclareCompatibilityOptions()
Options that are used ONLY for the READER to ensure backward compatibility.
void WriteMonitoringHistosToFile(void) const
Write special monitoring histograms to file (dummy implementation here).
virtual ~MethodCrossValidation(void)
Destructor.
void ReadWeightsFromStream(std::istream &istr)
Read the weights.
void ProcessOptions()
The option string is decoded, for available options see "DeclareOptions".
Ranking for variables in method (implementation)
Definition: Ranking.h:48
EAnalysisType
Definition: Types.h:127
Basic string class.
Definition: TString.h:131
Abstract ClassifierFactory template that handles arbitrary types.