Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CrossValidation.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Omar Zapata, Thomas James Stevenson, Pourya Vakilipourtakalou, Kim Albertsson
3
4/*************************************************************************
5 * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12#ifndef ROOT_TMVA_CROSS_EVALUATION
13#define ROOT_TMVA_CROSS_EVALUATION
14
15#include "TGraph.h"
16#include "TMultiGraph.h"
17#include "TString.h"
18#include <vector>
19#include <map>
20
21#include "TMVA/IMethod.h"
22#include "TMVA/Configurable.h"
23#include "TMVA/Types.h"
24#include "TMVA/DataSet.h"
25#include "TMVA/Event.h"
26#include <TMVA/Results.h>
27#include <TMVA/Factory.h>
28#include <TMVA/DataLoader.h>
29#include <TMVA/OptionMap.h>
30#include <TMVA/Envelope.h>
31
32/*! \class TMVA::CrossValidationResult
33 * Class to save the results of cross validation,
34 * the metric for the classification is ROC and you can retrieve ROC curves,
35 * ROC integrals, ROC average and ROC standard deviation.
36\ingroup TMVA
37*/
38
39/*! \class TMVA::CrossValidation
40 * Class to perform cross validation, splitting the dataloader into folds.
41\ingroup TMVA
42*/
43
44namespace TMVA {
45
46class CvSplitKFolds;
47
48using EventCollection_t = std::vector<Event *>;
49using EventTypes_t = std::vector<Bool_t>;
50using EventOutputs_t = std::vector<Float_t>;
51using EventOutputsMulticlass_t = std::vector<std::vector<Float_t>>;
52
54public:
55 CrossValidationFoldResult() {} // For multi-proc serialisation
57 : fFold(iFold)
58 {}
59
61
64
74};
75
76// Used internally to keep per-fold aggregate statistics
77// such as ROC curves, ROC integrals and efficiencies.
79 friend class CrossValidation;
80
81private:
82 std::map<UInt_t, Float_t> fROCs;
83 std::shared_ptr<TMultiGraph> fROCCurves;
84
85 std::vector<Double_t> fSigs;
86 std::vector<Double_t> fSeps;
87 std::vector<Double_t> fEff01s;
88 std::vector<Double_t> fEff10s;
89 std::vector<Double_t> fEff30s;
90 std::vector<Double_t> fEffAreas;
91 std::vector<Double_t> fTrainEff01s;
92 std::vector<Double_t> fTrainEff10s;
93 std::vector<Double_t> fTrainEff30s;
94
95public:
99
100 std::map<UInt_t, Float_t> GetROCValues() const { return fROCs; }
101 Float_t GetROCAverage() const;
104 TGraph *GetAvgROCCurve(UInt_t numSamples = 100) const;
105 void Print() const;
106
107 TCanvas *Draw(const TString name = "CrossValidation") const;
108 TCanvas *DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const;
109
110 std::vector<Double_t> GetSigValues() const { return fSigs; }
111 std::vector<Double_t> GetSepValues() const { return fSeps; }
112 std::vector<Double_t> GetEff01Values() const { return fEff01s; }
113 std::vector<Double_t> GetEff10Values() const { return fEff10s; }
114 std::vector<Double_t> GetEff30Values() const { return fEff30s; }
115 std::vector<Double_t> GetEffAreaValues() const { return fEffAreas; }
116 std::vector<Double_t> GetTrainEff01Values() const { return fTrainEff01s; }
117 std::vector<Double_t> GetTrainEff10Values() const { return fTrainEff10s; }
118 std::vector<Double_t> GetTrainEff30Values() const { return fTrainEff30s; }
119
120private:
121 void Fill(CrossValidationFoldResult const & fr);
122};
123
124class CrossValidation : public Envelope {
// Class to perform cross validation, splitting the dataloader into folds.
// NOTE(review): the embedded listing numbers skip values (129, 137-138, 140,
// 148-154, ...), so some declarations of this class appear to be missing
// from this extract - consult the original header for the full interface.
125
126public:
 // Two construction paths: both take a job name, a DataLoader and an option
 // string; the second overload additionally receives an output TFile.
127 explicit CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options);
128 explicit CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TFile *outputFile, TString options);
130
 // Option handling. ParseOptions() parses the internal option string;
 // InitOptions() presumably declares the options beforehand - TODO confirm.
131 void InitOptions();
132 void ParseOptions();
133
 // Fold configuration setters (presumably the fold count and the expression
 // used to assign events to folds - verify against the implementation).
134 void SetNumFolds(UInt_t i);
135 void SetSplitExpr(TString splitExpr);
136
139
141
 // Read-only access to the collected results; returned by const reference,
 // so no copy of the vector is made.
142 const std::vector<CrossValidationResult> &GetResults() const;
143
 // Does training, test set evaluation and performance evaluation using
 // cross-evaluation.
144 void Evaluate();
145
146private:
 // Evaluates a single fold, identified by iFold, for the method described
 // by methodInfo, and returns that fold's result by value.
147 CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap & methodInfo);
148
155 Bool_t fFoldFileOutput; //! If true: generate output file for each fold
156 Bool_t fFoldStatus; //! If true: dataset is prepared
158 UInt_t fNumFolds; //! Number of folds to prepare
159 UInt_t fNumWorkerProcs; //! Number of processes to use for fold evaluation.
160 //!(Default, no parallel evaluation)
162 TString fOutputEnsembling; //! How to combine output of individual folds
 // Collected CrossValidationResult objects - presumably what GetResults()
 // exposes; confirm in the implementation.
166 std::vector<CrossValidationResult> fResults; //!
171
 // Helper objects held through std::unique_ptr, so they are released
 // automatically when this object is destroyed.
172 std::unique_ptr<Factory> fFoldFactory;
173 std::unique_ptr<Factory> fFactory;
174 std::unique_ptr<CvSplitKFolds> fSplit;
175
177 };
178
179} // namespace TMVA
180
181#endif // ROOT_TMVA_CROSS_EVALUATION
unsigned int UInt_t
Definition RtypesCore.h:46
const Bool_t kFALSE
Definition RtypesCore.h:101
bool Bool_t
Definition RtypesCore.h:63
double Double_t
Definition RtypesCore.h:59
float Float_t
Definition RtypesCore.h:57
const Bool_t kTRUE
Definition RtypesCore.h:100
#define ClassDef(name, id)
Definition Rtypes.h:325
char name[80]
Definition TGX11.cxx:110
The Canvas class.
Definition TCanvas.h:23
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition TFile.h:54
A TGraph is an object made of two arrays X and Y with npoints each.
Definition TGraph.h:41
Class to save the results of cross validation, the metric for the classification is ROC and you can ...
std::vector< Double_t > fSeps
std::vector< Double_t > fEff01s
std::vector< Double_t > GetTrainEff10Values() const
std::vector< Double_t > fTrainEff30s
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< Double_t > GetTrainEff30Values() const
std::vector< Double_t > fSigs
std::vector< Double_t > fEff30s
void Fill(CrossValidationFoldResult const &fr)
Float_t GetROCStandardDeviation() const
std::vector< Double_t > fEff10s
std::vector< Double_t > fTrainEff01s
std::vector< Double_t > GetEff10Values() const
std::map< UInt_t, Float_t > fROCs
std::vector< Double_t > fTrainEff10s
std::vector< Double_t > GetTrainEff01Values() const
std::vector< Double_t > fEffAreas
std::vector< Double_t > GetEff01Values() const
TCanvas * DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const
std::vector< Double_t > GetSigValues() const
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
TGraph * GetAvgROCCurve(UInt_t numSamples=100) const
Generates a multigraph that contains an average ROC Curve.
std::map< UInt_t, Float_t > GetROCValues() const
std::vector< Double_t > GetEffAreaValues() const
std::vector< Double_t > GetSepValues() const
std::vector< Double_t > GetEff30Values() const
Class to perform cross validation, splitting the dataloader into folds.
void ParseOptions()
Method to parse the internal option string.
const std::vector< CrossValidationResult > & GetResults() const
std::vector< CrossValidationResult > fResults
std::unique_ptr< Factory > fFoldFactory
Bool_t fFoldStatus
If true: generate output file for each fold.
std::unique_ptr< CvSplitKFolds > fSplit
TFile * fOutputFile
How to combine output of individual folds.
Types::EAnalysisType fAnalysisType
void SetSplitExpr(TString splitExpr)
void Evaluate()
Does training, test set evaluation and performance evaluation using cross-evaluation.
TString fOutputFactoryOptions
Number of processes to use for fold evaluation.
std::unique_ptr< Factory > fFactory
UInt_t fNumWorkerProcs
Number of folds to prepare.
TString fJobName
If true: dataset is prepared.
CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap &methodInfo)
Evaluates each fold in turn.
Abstract base class for all high level ml algorithms, you can book ml methods like BDT,...
Definition Envelope.h:44
This is the main MVA steering class.
Definition Factory.h:80
Class to store options for the different methods.
Definition OptionMap.h:34
std::vector< Float_t > EventOutputs_t
std::vector< std::vector< Float_t > > EventOutputsMulticlass_t
std::vector< Event * > EventCollection_t
std::vector< Bool_t > EventTypes_t
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition TMultiGraph.h:36
Basic string class.
Definition TString.h:136
create variable transformations
th1 Draw()