Logo ROOT  
Reference Guide
CrossValidation.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Thomas James Stevenson, Pourya Vakilipourtakalou, Kim Albertsson
3 
4 /*************************************************************************
5  * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 #ifndef ROOT_TMVA_CROSS_EVALUATION
13 #define ROOT_TMVA_CROSS_EVALUATION
14 
15 #include "TGraph.h"
16 #include "TMultiGraph.h"
17 #include "TString.h"
18 #include <vector>
19 #include <map>
20 
21 #include "TMVA/IMethod.h"
22 #include "TMVA/Configurable.h"
23 #include "TMVA/Types.h"
24 #include "TMVA/DataSet.h"
25 #include "TMVA/Event.h"
26 #include <TMVA/Results.h>
27 #include <TMVA/Factory.h>
28 #include <TMVA/DataLoader.h>
29 #include <TMVA/OptionMap.h>
30 #include <TMVA/Envelope.h>
31 
32 /*! \class TMVA::CrossValidationResult
33  * Class to save the results of cross validation.
34  * The metric for classification is ROC; you can retrieve ROC curves,
35  * ROC integrals, the ROC average and the ROC standard deviation.
36 \ingroup TMVA
37 */
38 
39 /*! \class TMVA::CrossValidation
40  * Class to perform cross validation, splitting the dataloader into folds.
41 \ingroup TMVA
42 */
43 
44 namespace TMVA {
45 
46 class CvSplitKFolds;
47 
48 using EventCollection_t = std::vector<Event *>; // vector of raw Event pointers; ownership is not expressed by this type
49 using EventTypes_t = std::vector<Bool_t>; // vector of Bool_t flags (presumably one per event, e.g. signal/background — TODO confirm)
50 using EventOutputs_t = std::vector<Float_t>; // vector of Float_t values (presumably one method output per event — TODO confirm)
51 using EventOutputsMulticlass_t = std::vector<std::vector<Float_t>>; // vector of Float_t vectors (presumably per-event, per-class outputs — TODO confirm)
52 
54 public:
55  CrossValidationFoldResult() {} // For multi-proc serialisation
57  : fFold(iFold)
58  {}
59 
61 
64 
74 };
75 
76 // Used internally to keep per-fold aggregate statistics
77 // such as ROC curves, ROC integrals and efficiencies.
79  friend class CrossValidation;
80 
81 private:
82  std::map<UInt_t, Float_t> fROCs;
83  std::shared_ptr<TMultiGraph> fROCCurves;
84 
85  std::vector<Double_t> fSigs;
86  std::vector<Double_t> fSeps;
87  std::vector<Double_t> fEff01s;
88  std::vector<Double_t> fEff10s;
89  std::vector<Double_t> fEff30s;
90  std::vector<Double_t> fEffAreas;
91  std::vector<Double_t> fTrainEff01s;
92  std::vector<Double_t> fTrainEff10s;
93  std::vector<Double_t> fTrainEff30s;
94 
95 public:
96  CrossValidationResult(UInt_t numFolds);
99 
100  std::map<UInt_t, Float_t> GetROCValues() const { return fROCs; }
101  Float_t GetROCAverage() const;
103  TMultiGraph *GetROCCurves(Bool_t fLegend = kTRUE);
104  TGraph *GetAvgROCCurve(UInt_t numSamples = 100) const;
105  void Print() const;
106 
107  TCanvas *Draw(const TString name = "CrossValidation") const;
108  TCanvas *DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const;
109 
110  std::vector<Double_t> GetSigValues() const { return fSigs; }
111  std::vector<Double_t> GetSepValues() const { return fSeps; }
112  std::vector<Double_t> GetEff01Values() const { return fEff01s; }
113  std::vector<Double_t> GetEff10Values() const { return fEff10s; }
114  std::vector<Double_t> GetEff30Values() const { return fEff30s; }
115  std::vector<Double_t> GetEffAreaValues() const { return fEffAreas; }
116  std::vector<Double_t> GetTrainEff01Values() const { return fTrainEff01s; }
117  std::vector<Double_t> GetTrainEff10Values() const { return fTrainEff10s; }
118  std::vector<Double_t> GetTrainEff30Values() const { return fTrainEff30s; }
119 
120 private:
121  void Fill(CrossValidationFoldResult const & fr);
122 };
123 
124 class CrossValidation : public Envelope {
125 
126 public:
127  explicit CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options);
128  explicit CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TFile *outputFile, TString options);
130 
131  void InitOptions();
132  void ParseOptions();
133 
134  void SetNumFolds(UInt_t i);
135  void SetSplitExpr(TString splitExpr);
136 
139 
140  Factory &GetFactory() { return *fFactory; }
141 
142  const std::vector<CrossValidationResult> &GetResults() const;
143 
144  void Evaluate();
145 
146 private:
147  CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap & methodInfo);
148 
155  Bool_t fFoldFileOutput; //! If true: generate output file for each fold
156  Bool_t fFoldStatus; //! If true: dataset is prepared
158  UInt_t fNumFolds; //! Number of folds to prepare
159  UInt_t fNumWorkerProcs; //! Number of processes to use for fold evaluation.
160  //!(Default, no parallel evaluation)
162  TString fOutputEnsembling; //! How to combine output of individual folds
166  std::vector<CrossValidationResult> fResults; //!
171 
172  std::unique_ptr<Factory> fFoldFactory;
173  std::unique_ptr<Factory> fFactory;
174  std::unique_ptr<CvSplitKFolds> fSplit;
175 
177  };
178 
179 } // namespace TMVA
180 
181 #endif // ROOT_TMVA_CROSS_EVALUATION
TMVA::OptionMap
Class to store options for the different methods.
Definition: OptionMap.h:34
TMVA::CrossValidation::fROC
Bool_t fROC
Definition: CrossValidation.h:167
TMVA::CrossValidationResult::~CrossValidationResult
~CrossValidationResult()
Definition: CrossValidation.h:98
TMVA::CrossValidationFoldResult::fTrainEff01
Double_t fTrainEff01
Definition: CrossValidation.h:71
TMVA::CrossValidationResult::GetROCValues
std::map< UInt_t, Float_t > GetROCValues() const
Definition: CrossValidation.h:100
TMVA::CrossValidationResult::fROCCurves
std::shared_ptr< TMultiGraph > fROCCurves
Definition: CrossValidation.h:83
TMVA::CrossValidation::CrossValidation
CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options)
Definition: CrossValidation.cxx:310
kTRUE
const Bool_t kTRUE
Definition: RtypesCore.h:100
TMVA::CrossValidationFoldResult::CrossValidationFoldResult
CrossValidationFoldResult(UInt_t iFold)
Definition: CrossValidation.h:56
TMVA::CrossValidationResult::fROCs
std::map< UInt_t, Float_t > fROCs
Definition: CrossValidation.h:82
TMVA::Envelope
Abstract base class for all high-level ML algorithms; you can book ML methods like BDT, ...
Definition: Envelope.h:44
TMVA::EventCollection_t
std::vector< Event * > EventCollection_t
Definition: CrossValidation.h:48
TMVA::CrossValidationFoldResult::fEff01
Double_t fEff01
Definition: CrossValidation.h:67
TMVA::CrossValidationResult::GetEff10Values
std::vector< Double_t > GetEff10Values() const
Definition: CrossValidation.h:113
TMVA::CrossValidationResult::GetEffAreaValues
std::vector< Double_t > GetEffAreaValues() const
Definition: CrossValidation.h:115
TMVA::CrossValidation::fJobName
TString fJobName
Name of the cross-validation job (presumably the jobName passed to the constructor).
Definition: CrossValidation.h:157
TMVA::CrossValidation::fOutputEnsembling
TString fOutputEnsembling
Definition: CrossValidation.h:162
TMVA::CrossValidationResult::GetROCAverage
Float_t GetROCAverage() const
Definition: CrossValidation.cxx:132
TMVA::CrossValidationResult::fEff10s
std::vector< Double_t > fEff10s
Definition: CrossValidation.h:88
TMVA::CrossValidation::fOutputFile
TFile * fOutputFile
Output file for results (presumably the outputFile passed to the constructor).
Definition: CrossValidation.h:163
TMVA::CrossValidationFoldResult::fEff30
Double_t fEff30
Definition: CrossValidation.h:69
TMVA::CrossValidationResult::GetSepValues
std::vector< Double_t > GetSepValues() const
Definition: CrossValidation.h:111
TGraph.h
IMethod.h
TMVA::CrossValidation::~CrossValidation
~CrossValidation()
TMVA::CrossValidationResult::GetSigValues
std::vector< Double_t > GetSigValues() const
Definition: CrossValidation.h:110
TMVA::CrossValidationFoldResult::fEffArea
Double_t fEffArea
Definition: CrossValidation.h:70
TMVA::EventOutputsMulticlass_t
std::vector< std::vector< Float_t > > EventOutputsMulticlass_t
Definition: CrossValidation.h:51
DataLoader.h
Float_t
float Float_t
Definition: RtypesCore.h:57
TMVA::CrossValidation::fSplit
std::unique_ptr< CvSplitKFolds > fSplit
Definition: CrossValidation.h:174
TMVA::CrossValidationFoldResult
Definition: CrossValidation.h:53
TMVA::CrossValidation::fVerbose
Bool_t fVerbose
Definition: CrossValidation.h:169
TMVA::CrossValidation::GetSplitExpr
TString GetSplitExpr()
Definition: CrossValidation.h:138
TMVA::CrossValidation::fFoldFileOutput
Bool_t fFoldFileOutput
Definition: CrossValidation.h:155
TMVA::CrossValidation::fSplitExprString
TString fSplitExprString
Definition: CrossValidation.h:165
TMVA::CrossValidation::SetNumFolds
void SetNumFolds(UInt_t i)
Definition: CrossValidation.cxx:474
TMVA::CrossValidation::fNumFolds
UInt_t fNumFolds
Definition: CrossValidation.h:158
TMVA::CrossValidation::fResults
std::vector< CrossValidationResult > fResults
Definition: CrossValidation.h:166
TString
Basic string class.
Definition: TString.h:136
TMVA::EventOutputs_t
std::vector< Float_t > EventOutputs_t
Definition: CrossValidation.h:50
TMVA::CrossValidationFoldResult::fSep
Double_t fSep
Definition: CrossValidation.h:66
Bool_t
bool Bool_t
Definition: RtypesCore.h:63
TMVA::CrossValidationFoldResult::fEff10
Double_t fEff10
Definition: CrossValidation.h:68
TMVA::CrossValidationResult::GetEff30Values
std::vector< Double_t > GetEff30Values() const
Definition: CrossValidation.h:114
TString.h
bool
TMVA::CrossValidationFoldResult::fSig
Double_t fSig
Definition: CrossValidation.h:65
TMVA::CrossValidation::fCorrelations
Bool_t fCorrelations
Definition: CrossValidation.h:152
TMVA::CrossValidationResult::fTrainEff10s
std::vector< Double_t > fTrainEff10s
Definition: CrossValidation.h:92
Envelope.h
TMultiGraph.h
TMVA::CrossValidationResult::GetEff01Values
std::vector< Double_t > GetEff01Values() const
Definition: CrossValidation.h:112
TMVA::CrossValidation::fAnalysisType
Types::EAnalysisType fAnalysisType
Definition: CrossValidation.h:149
TMVA::CrossValidationFoldResult::fROC
TGraph fROC
Definition: CrossValidation.h:63
TMVA::CrossValidationResult::DrawAvgROCCurve
TCanvas * DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const
Definition: CrossValidation.cxx:187
TMVA::CrossValidation::GetFactory
Factory & GetFactory()
Definition: CrossValidation.h:140
TMVA::CrossValidationResult::GetTrainEff10Values
std::vector< Double_t > GetTrainEff10Values() const
Definition: CrossValidation.h:117
TMVA::CrossValidationResult::fEff01s
std::vector< Double_t > fEff01s
Definition: CrossValidation.h:87
TMVA::CrossValidation::fSplitTypeStr
TString fSplitTypeStr
Definition: CrossValidation.h:151
TMVA::CrossValidation::GetResults
const std::vector< CrossValidationResult > & GetResults() const
Definition: CrossValidation.cxx:700
TMVA::CrossValidation::fVerboseLevel
TString fVerboseLevel
Definition: CrossValidation.h:170
TMVA::Types::EAnalysisType
EAnalysisType
Definition: Types.h:128
TMVA::CrossValidationResult::GetROCStandardDeviation
Float_t GetROCStandardDeviation() const
Definition: CrossValidation.cxx:142
TMVA::CrossValidationResult::fTrainEff30s
std::vector< Double_t > fTrainEff30s
Definition: CrossValidation.h:93
TMVA::CrossValidation::fNumWorkerProcs
UInt_t fNumWorkerProcs
Number of processes to use for fold evaluation (default: no parallel evaluation).
Definition: CrossValidation.h:159
TMVA::CrossValidationResult::Fill
void Fill(CrossValidationFoldResult const &fr)
Definition: CrossValidation.cxx:73
TMVA::CrossValidation::Evaluate
void Evaluate()
Does training, test-set evaluation and performance evaluation using cross-evaluation.
Definition: CrossValidation.cxx:588
kFALSE
const Bool_t kFALSE
Definition: RtypesCore.h:101
TMVA::CrossValidation::fFoldFactory
std::unique_ptr< Factory > fFoldFactory
Definition: CrossValidation.h:172
TMVA::CrossValidationFoldResult::fTrainEff30
Double_t fTrainEff30
Definition: CrossValidation.h:73
TMVA::EventTypes_t
std::vector< Bool_t > EventTypes_t
Definition: CrossValidation.h:49
TMVA::CrossValidationResult::fTrainEff01s
std::vector< Double_t > fTrainEff01s
Definition: CrossValidation.h:91
TMVA::CrossValidation::fCvFactoryOptions
TString fCvFactoryOptions
Definition: CrossValidation.h:153
Event.h
TMVA::CrossValidation::GetNumFolds
UInt_t GetNumFolds()
Definition: CrossValidation.h:137
TMVA::CrossValidation::ProcessFold
CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap &methodInfo)
Evaluates each fold in turn.
Definition: CrossValidation.cxx:507
TMVA::CrossValidationFoldResult::fFold
UInt_t fFold
Definition: CrossValidation.h:60
TMVA::CrossValidation::fFactory
std::unique_ptr< Factory > fFactory
Definition: CrossValidation.h:173
TMVA::Factory
This is the main MVA steering class.
Definition: Factory.h:80
UInt_t
unsigned int UInt_t
Definition: RtypesCore.h:46
TMVA::CrossValidationResult
Class to save the results of cross validation. The metric for classification is ROC; you can retrieve ...
Definition: CrossValidation.h:78
Types.h
Configurable.h
TMVA::CrossValidation::fAnalysisTypeStr
TString fAnalysisTypeStr
Definition: CrossValidation.h:150
TMVA::CrossValidation::SetSplitExpr
void SetSplitExpr(TString splitExpr)
Definition: CrossValidation.cxx:487
TFile
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:54
unsigned int
TMVA::CrossValidationResult::CrossValidationResult
CrossValidationResult(UInt_t numFolds)
Definition: CrossValidation.cxx:41
TMultiGraph
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:36
OptionMap.h
TMVA::CrossValidationResult::GetTrainEff30Values
std::vector< Double_t > GetTrainEff30Values() const
Definition: CrossValidation.h:118
TMVA::CrossValidation::fSilent
Bool_t fSilent
Definition: CrossValidation.h:164
Double_t
double Double_t
Definition: RtypesCore.h:59
TGraph
A TGraph is an object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
TCanvas
The Canvas class.
Definition: TCanvas.h:23
TMVA::CrossValidation::fDrawProgressBar
Bool_t fDrawProgressBar
Definition: CrossValidation.h:154
TMVA::CrossValidation::ParseOptions
void ParseOptions()
Method to parse the internal option string.
Definition: CrossValidation.cxx:380
ClassDef
#define ClassDef(name, id)
Definition: Rtypes.h:325
Factory.h
TMVA::CrossValidationResult::Print
void Print() const
Definition: CrossValidation.cxx:154
name
char name[80]
Definition: TGX11.cxx:110
TMVA::CrossValidation::InitOptions
void InitOptions()
Definition: CrossValidation.cxx:323
TMVA::CrossValidationFoldResult::fTrainEff10
Double_t fTrainEff10
Definition: CrossValidation.h:72
TMVA::CrossValidationResult::GetROCCurves
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
Definition: CrossValidation.cxx:92
TMVA::CrossValidation::fOutputFactoryOptions
TString fOutputFactoryOptions
Option string for the output factory (presumably).
Definition: CrossValidation.h:161
TMVA::CrossValidationResult::GetTrainEff01Values
std::vector< Double_t > GetTrainEff01Values() const
Definition: CrossValidation.h:116
TMVA::CrossValidationResult::fEffAreas
std::vector< Double_t > fEffAreas
Definition: CrossValidation.h:90
Results.h
TMVA::CrossValidationResult::fEff30s
std::vector< Double_t > fEff30s
Definition: CrossValidation.h:89
TMVA::CrossValidation::fTransformations
TString fTransformations
Definition: CrossValidation.h:168
TMVA::CrossValidationResult::GetAvgROCCurve
TGraph * GetAvgROCCurve(UInt_t numSamples=100) const
Generates a graph that contains an average ROC curve.
Definition: CrossValidation.cxx:107
TMVA::CrossValidationResult::fSeps
std::vector< Double_t > fSeps
Definition: CrossValidation.h:86
TMVA::CrossValidationFoldResult::CrossValidationFoldResult
CrossValidationFoldResult()
Definition: CrossValidation.h:55
TMVA::CrossValidationFoldResult::fROCIntegral
Float_t fROCIntegral
Definition: CrossValidation.h:62
DataSet.h
TMVA::CrossValidationResult::fSigs
std::vector< Double_t > fSigs
Definition: CrossValidation.h:85
TMVA::CrossValidationResult::Draw
TCanvas * Draw(const TString name="CrossValidation") const
Definition: CrossValidation.cxx:173
TMVA::CrossValidation::fFoldStatus
Bool_t fFoldStatus
If true: dataset is prepared.
Definition: CrossValidation.h:156
TMVA::CrossValidation
Class to perform cross validation, splitting the dataloader into folds.
Definition: CrossValidation.h:124
TMVA
create variable transformations
Definition: GeneticMinimizer.h:22
TMVA::DataLoader
Definition: DataLoader.h:50