Logo ROOT   6.10/09
Reference Guide
CrossValidation.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Thomas James Stevenson and Pourya Vakilipourtakalou. 2016
3 
4 #ifndef ROOT_TMVA_CrossValidation
5 #define ROOT_TMVA_CrossValidation
6 
7 #include "TString.h"
8 
9 #include "TMultiGraph.h"
10 
11 #include "TMVA/IMethod.h"
12 
13 #include "TMVA/Configurable.h"
14 
15 #include "TMVA/Types.h"
16 
17 #include "TMVA/DataSet.h"
18 
19 #include "TMVA/Event.h"
20 
21 #include <TMVA/Results.h>
22 
23 #include <TMVA/Factory.h>
24 
25 #include <TMVA/DataLoader.h>
26 
27 #include <TMVA/OptionMap.h>
28 
29 #include <TMVA/Envelope.h>
30 
31 namespace TMVA {
32 
34  friend class CrossValidation; ///< grants CrossValidation access to the private result members below (presumably so it can fill them during evaluation)
35 
36  private:
37  std::map<UInt_t,Float_t> fROCs; ///< ROC values keyed by an unsigned index; keys presumably fold numbers — TODO confirm
38  std::shared_ptr<TMultiGraph> fROCCurves; ///< shared-ownership collection of ROC curve graphs
39 
40  std::vector<Double_t> fSigs; ///< significance values, presumably one entry per fold
41  std::vector<Double_t> fSeps; ///< separation values, presumably one entry per fold
42  std::vector<Double_t> fEff01s; ///< efficiencies at the "01" working point (presumably 1% background efficiency) — TODO confirm
43  std::vector<Double_t> fEff10s; ///< efficiencies at the "10" working point (presumably 10% background efficiency) — TODO confirm
44  std::vector<Double_t> fEff30s; ///< efficiencies at the "30" working point (presumably 30% background efficiency) — TODO confirm
45  std::vector<Double_t> fEffAreas; ///< efficiency-area values, presumably one entry per fold
46  std::vector<Double_t> fTrainEff01s; ///< training-sample counterpart of fEff01s (assumed from the name)
47  std::vector<Double_t> fTrainEff10s; ///< training-sample counterpart of fEff10s (assumed from the name)
48  std::vector<Double_t> fTrainEff30s; ///< training-sample counterpart of fEff30s (assumed from the name)
49 
50  public:
53  ~CrossValidationResult(){fROCCurves=nullptr;} ///< releases this object's share of the ROC curves (the explicit reset is redundant: shared_ptr's destructor already does it)
54 
55  std::map<UInt_t,Float_t> GetROCValues() const {return fROCs;} ///< copy of the stored ROC values; const so it works on the const& returned by CrossValidation::GetResults()
56  Float_t GetROCAverage() const;
59  void Print() const;
60 
61  TCanvas* Draw(const TString name="CrossValidation") const;
62 
63  std::vector<Double_t> GetSigValues() const {return fSigs;} ///< copy of fSigs (all getters below are const for the same reason as GetROCValues)
64  std::vector<Double_t> GetSepValues() const {return fSeps;} ///< copy of fSeps
65  std::vector<Double_t> GetEff01Values() const {return fEff01s;} ///< copy of fEff01s
66  std::vector<Double_t> GetEff10Values() const {return fEff10s;} ///< copy of fEff10s
67  std::vector<Double_t> GetEff30Values() const {return fEff30s;} ///< copy of fEff30s
68  std::vector<Double_t> GetEffAreaValues() const {return fEffAreas;} ///< copy of fEffAreas
69  std::vector<Double_t> GetTrainEff01Values() const {return fTrainEff01s;} ///< copy of fTrainEff01s
70  std::vector<Double_t> GetTrainEff10Values() const {return fTrainEff10s;} ///< copy of fTrainEff10s
71  std::vector<Double_t> GetTrainEff30Values() const {return fTrainEff30s;} ///< copy of fTrainEff30s
72  };
73 
74 
75  class CrossValidation : public Envelope { ///< cross-validation driver built on the Envelope base class
79  public:
80  explicit CrossValidation(DataLoader *loader); ///< explicit: no implicit conversion from DataLoader*
81  ~CrossValidation();
82 
83  void SetNumFolds(UInt_t i); ///< sets the number of folds (defined out of line; any validation is not visible here)
84  UInt_t GetNumFolds() const {return fNumFolds;} ///< const accessor so it can be queried through const references/pointers
85 
86  virtual void Evaluate(); ///< runs the evaluation; results are presumably read back via GetResults() — TODO confirm
87 // void EvaluateFold(UInt_t fold);//used in ParallelExecution
88 
89  const CrossValidationResult& GetResults() const; ///< read-only access to the stored results
90 
91  private:
92  std::unique_ptr<Factory> fClassifier; ///< exclusively-owned Factory instance
94  };
95 
96 } // namespace TMVA
97 
98 #endif // ROOT_TMVA_CrossValidation
std::vector< Double_t > fSigs
float Float_t
Definition: RtypesCore.h:53
std::map< UInt_t, Float_t > GetROCValues()
CrossValidationResult fResults
std::vector< Double_t > GetEff01Values()
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:35
Basic string class.
Definition: TString.h:129
bool Bool_t
Definition: RtypesCore.h:59
std::unique_ptr< Factory > fClassifier
std::vector< Double_t > GetEff10Values()
std::vector< Double_t > GetSigValues()
std::vector< Double_t > fEff10s
std::vector< Double_t > GetEff30Values()
#define ClassDef(name, id)
Definition: Rtypes.h:297
std::vector< Double_t > GetEffAreaValues()
Base class for all machine learning algorithms.
Definition: Envelope.h:35
std::vector< Double_t > fTrainEff01s
std::vector< Double_t > fTrainEff10s
std::vector< Double_t > fEff01s
unsigned int UInt_t
Definition: RtypesCore.h:42
std::vector< Double_t > GetTrainEff01Values()
Float_t GetROCStandardDeviation() const
std::vector< Double_t > fTrainEff30s
The Canvas class.
Definition: TCanvas.h:31
std::vector< Double_t > GetTrainEff30Values()
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
std::vector< Double_t > fEffAreas
std::vector< Double_t > fSeps
Abstract ClassifierFactory template that handles arbitrary types.
std::map< UInt_t, Float_t > fROCs
TCanvas * Draw(const TString name="CrossValidation") const
std::vector< Double_t > fEff30s
std::vector< Double_t > GetTrainEff10Values()
const Bool_t kTRUE
Definition: RtypesCore.h:91
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< Double_t > GetSepValues()