Logo ROOT   6.12/07
Reference Guide
CrossValidation.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Thomas James Stevenson and Pourya Vakilipourtakalou. 2016
3 
4 #ifndef ROOT_TMVA_CrossValidation
5 #define ROOT_TMVA_CrossValidation
6 
7 #include "TString.h"
8 
9 #include "TMultiGraph.h"
10 
11 #include "TMVA/IMethod.h"
12 
13 #include "TMVA/Configurable.h"
14 
15 #include "TMVA/Types.h"
16 
17 #include "TMVA/DataSet.h"
18 
19 #include "TMVA/Event.h"
20 
21 #include <TMVA/Results.h>
22 
23 #include <TMVA/Factory.h>
24 
25 #include <TMVA/DataLoader.h>
26 
27 #include <TMVA/OptionMap.h>
28 
29 #include <TMVA/Envelope.h>
30 
31 /*! \class TMVA::CrossValidationResult
32  * Class to save the results of cross validation;
33  * the metric for classification is ROC, and you can get the ROC curves,
34  * ROC integrals, ROC average and ROC standard deviation.
35 \ingroup TMVA
36 */
37 
38 /*! \class TMVA::CrossValidation
39  * Class to perform cross validation, splitting the dataloader into folds.
40 \ingroup TMVA
41 */
42 
43 namespace TMVA {
44 
46  friend class CrossValidation;
47 
48  private:
49  std::map<UInt_t,Float_t> fROCs;
50  std::shared_ptr<TMultiGraph> fROCCurves;
51 
52  std::vector<Double_t> fSigs;
53  std::vector<Double_t> fSeps;
54  std::vector<Double_t> fEff01s;
55  std::vector<Double_t> fEff10s;
56  std::vector<Double_t> fEff30s;
57  std::vector<Double_t> fEffAreas;
58  std::vector<Double_t> fTrainEff01s;
59  std::vector<Double_t> fTrainEff10s;
60  std::vector<Double_t> fTrainEff30s;
61 
62  public:
65  ~CrossValidationResult(){fROCCurves=nullptr;}
66 
67  std::map<UInt_t,Float_t> GetROCValues(){return fROCs;}
68  Float_t GetROCAverage() const;
71  void Print() const ;
72 
73  TCanvas* Draw(const TString name="CrossValidation") const;
74 
75  std::vector<Double_t> GetSigValues() {return fSigs;}
76  std::vector<Double_t> GetSepValues() {return fSeps;}
77  std::vector<Double_t> GetEff01Values() {return fEff01s;}
78  std::vector<Double_t> GetEff10Values() {return fEff10s;}
79  std::vector<Double_t> GetEff30Values() {return fEff30s;}
80  std::vector<Double_t> GetEffAreaValues() {return fEffAreas;}
81  std::vector<Double_t> GetTrainEff01Values() {return fTrainEff01s;}
82  std::vector<Double_t> GetTrainEff10Values() {return fTrainEff10s;}
83  std::vector<Double_t> GetTrainEff30Values() {return fTrainEff30s;}
84  };
85 
86 
87  class CrossValidation : public Envelope {
89  std::vector<CrossValidationResult> fResults; //!
91  public:
92  explicit CrossValidation(DataLoader *loader);
93  ~CrossValidation();
94 
95  void SetNumFolds(UInt_t i);
96  UInt_t GetNumFolds() {return fNumFolds;}
97 
98  virtual void Evaluate();
99 // void EvaluateFold(UInt_t fold);//used in ParallelExecution
100 
101  const std::vector<CrossValidationResult> &GetResults() const;
102 
103  private:
104  std::unique_ptr<Factory> fClassifier;
106  };
107 
108 } // namespace TMVA
109 
110 #endif // ROOT_TMVA_CrossValidation
std::vector< Double_t > fSigs
float Float_t
Definition: RtypesCore.h:53
std::map< UInt_t, Float_t > GetROCValues()
std::vector< Double_t > GetEff01Values()
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:35
Class to save the results of cross validation; the metric for classification is ROC, and you can ...
Basic string class.
Definition: TString.h:125
bool Bool_t
Definition: RtypesCore.h:59
std::unique_ptr< Factory > fClassifier
std::vector< Double_t > GetEff10Values()
std::vector< Double_t > GetSigValues()
std::vector< Double_t > fEff10s
std::vector< Double_t > GetEff30Values()
#define ClassDef(name, id)
Definition: Rtypes.h:320
std::vector< CrossValidationResult > fResults
std::vector< Double_t > GetEffAreaValues()
Abstract base class for all high-level ML algorithms; you can book ML methods like BDT...
Definition: Envelope.h:43
std::vector< Double_t > fTrainEff01s
std::vector< Double_t > fTrainEff10s
std::vector< Double_t > fEff01s
unsigned int UInt_t
Definition: RtypesCore.h:42
std::vector< Double_t > GetTrainEff01Values()
Float_t GetROCStandardDeviation() const
std::vector< Double_t > fTrainEff30s
The Canvas class.
Definition: TCanvas.h:31
Class to perform cross validation, splitting the dataloader into folds.
std::vector< Double_t > GetTrainEff30Values()
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
std::vector< Double_t > fEffAreas
std::vector< Double_t > fSeps
Abstract ClassifierFactory template that handles arbitrary types.
std::map< UInt_t, Float_t > fROCs
TCanvas * Draw(const TString name="CrossValidation") const
std::vector< Double_t > fEff30s
std::vector< Double_t > GetTrainEff10Values()
const Bool_t kTRUE
Definition: RtypesCore.h:87
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< Double_t > GetSepValues()
char name[80]
Definition: TGX11.cxx:109