Logo ROOT   6.12/07
Reference Guide
CrossValidation.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Thomas James Stevenson.
3 
4 #include "TMVA/CrossValidation.h"
5 
6 #include "TMVA/Config.h"
7 #include "TMVA/DataSet.h"
8 #include "TMVA/Event.h"
9 #include "TMVA/MethodBase.h"
10 #include "TMVA/MsgLogger.h"
12 #include "TMVA/tmvaglob.h"
13 #include "TMVA/Types.h"
14 
15 #include "TSystem.h"
16 #include "TAxis.h"
17 #include "TCanvas.h"
18 #include "TGraph.h"
19 #include "TMath.h"
20 
21 #include <iostream>
22 #include <memory>
23 
24 //_______________________________________________________________________
// Empty body: no statements are executed here.
// NOTE(review): the line that introduces this body (listing line 25) is not
// visible in this view, so the owning function cannot be confirmed from here.
26 {
27 }
28 
29 //_______________________________________________________________________
// Copies two members from obj: the fROCs map and the fROCCurves handle.
// fROCCurves is listed as std::shared_ptr<TMultiGraph> in the reference
// entries of this page, so the assignment shares the same TMultiGraph with
// obj instead of cloning it.
// NOTE(review): the signature (listing line 30) is not visible in this view.
31 {
32  fROCs=obj.fROCs;
33  fROCCurves = obj.fROCCurves;
34 }
35 
36 //_______________________________________________________________________
// Returns the raw pointer held by fROCCurves (via get()); the smart pointer
// keeps ownership, so the caller receives a non-owning pointer.
// NOTE(review): presumably the body of GetROCCurves(Bool_t fLegend=kTRUE)
// from the reference entries - confirm; no parameter is used in this body.
38 {
39  return fROCCurves.get();
40 }
41 
42 //_______________________________________________________________________
// Arithmetic mean of the values stored in fROCs: sums the mapped value
// (.second) of every entry and divides by the number of entries.
// NOTE(review): the divisor is fROCs.size(), which is zero when fROCs is
// empty - confirm callers never reach this with no entries.
// NOTE(review): the signature (listing line 43) is not visible in this view.
44 {
45  Float_t avg=0;
46  for(auto &roc:fROCs) avg+=roc.second;
47  return avg/fROCs.size();
48 }
49 
50 //_______________________________________________________________________
// Standard deviation of the fROCs values around GetROCAverage(): accumulates
// the squared deviations and divides by (number of entries - 1) before taking
// the square root.
// NOTE(review): with exactly one entry the divisor is zero - confirm whether
// a single-fold result can reach this code.
// NOTE(review): presumably the body of GetROCStandardDeviation() const from
// the reference entries; the signature (listing line 51) is not visible here.
52 {
53  // NOTE: We are using here the unbiased estimation of the standard deviation.
54  Float_t std=0;
55  Float_t avg=GetROCAverage();
56  for(auto &roc:fROCs) std+=TMath::Power(roc.second-avg, 2);
57  return TMath::Sqrt(std/float(fROCs.size()-1.0));
58 }
59 
60 //_______________________________________________________________________
// Writes a results summary through a local MsgLogger named "CrossValidation":
// a header, one "Fold i ROC-Int" line per fROCs entry (key and value), a
// separator, then the average and the standard deviation of the ROC integrals.
// NOTE(review): listing lines 61, 63-64 and 75 are not visible in this view,
// so the signature and some statements of this body are missing here.
62 {
65 
66  MsgLogger fLogger("CrossValidation");
67  fLogger << kHEADER << " ==== Results ====" << Endl;
68  for(auto &item:fROCs)
// NOTE(review): this line ends with std::endl while every other message in
// this body ends with Endl - confirm that the difference is intended.
69  fLogger << kINFO << Form("Fold %i ROC-Int : %.4f",item.first,item.second) << std::endl;
70 
71  fLogger << kINFO << "------------------------" << Endl;
72  fLogger << kINFO << Form("Average ROC-Int : %.4f",GetROCAverage()) << Endl;
73  fLogger << kINFO << Form("Std-Dev ROC-Int : %.4f",GetROCStandardDeviation()) << Endl;
74 
76 }
77 
78 //_______________________________________________________________________
// Creates a TCanvas named after `name`, draws fROCCurves on it with option
// "AL", labels both axes, builds a legend, sets the canvas title, draws the
// canvas and returns it.
// NOTE(review): the canvas is allocated with new and returned as a raw
// pointer; the visible code does not establish who deletes it - confirm.
// NOTE(review): presumably the body of Draw(const TString name) const from
// the reference entries; the signature (listing line 79) is not visible here.
80 {
81  TCanvas *c=new TCanvas(name.Data());
82  fROCCurves->Draw("AL");
83  fROCCurves->GetXaxis()->SetTitle(" Signal Efficiency ");
84  fROCCurves->GetYaxis()->SetTitle(" Background Rejection ");
// adjust grows by 0.01 per fROCs entry and scales the legend's x2/y2 corner.
85  Float_t adjust=1+fROCs.size()*0.01;
86  c->BuildLegend(0.15,0.15,0.4*adjust,0.5*adjust);
87  c->SetTitle("Cross Validation ROC Curves");
88  c->Draw();
89  return c;
90 }
91 
92 //_______________________________________________________________________
// Tail of a member-initializer list followed by the body it belongs to:
// fNumFolds starts at 5 and fClassifier is given a newly created
// TMVA::Factory configured with the option string below; the body then calls
// ParseOptions().
// NOTE(review): the line that opens this definition (listing line 93) and
// listing line 96 are not visible in this view.
94 fNumFolds(5),fClassifier(new TMVA::Factory("CrossValidation","!V:!ROC:Silent:!ModelPersistence:!Color:!DrawProgressBar:AnalysisType=Classification"))
95 {
97  ParseOptions();
98 }
99 
100 //_______________________________________________________________________
// Assigns nullptr to fClassifier. The reference entries of this page list
// fClassifier as std::unique_ptr<Factory>, so this destroys the Factory it
// was holding.
// NOTE(review): the signature (listing line 101) is not visible in this view.
102 {
103  fClassifier=nullptr;
104 }
105 
106 //_______________________________________________________________________
// Stores i in fNumFolds and asks fDataLoader to build the k-fold dataset with
// that number of folds.
// NOTE(review): presumably the body of SetNumFolds(UInt_t i) from the
// reference entries - confirm; the signature (listing line 107) and listing
// line 111 are not visible in this view.
108 {
109  fNumFolds=i;
110  fDataLoader->MakeKFoldDataSet(fNumFolds);
112 }
113 
114 //_______________________________________________________________________
// Runs the cross-validation loop. fResults is resized to one entry per
// element of fMethods. For each method j the name, title and options are read
// from fMethods[j]; the k-fold dataset is built once, guarded by fFoldStatus.
// Then, for each of the fNumFolds folds, the fold dataset is prepared, the
// method is booked on fClassifier, trained and tested, and the fold's ROC
// integral, ROC curve, significance, separation and efficiency values are
// stored in fResults[j]. The booked methods are cleared at the end of every
// fold.
// NOTE(review): presumably the body of Evaluate() from the reference entries;
// the signature (listing line 115) and listing lines 126-127, 129, 149, 153,
// 179-180, 185 and 187 are not visible in this view, so this body is
// incomplete here.
116 {
117  fResults.resize(fMethods.size());
118  for (UInt_t j = 0; j < fMethods.size(); j++) {
119 
120  TString methodName = fMethods[j].GetValue<TString>("MethodName");
121  TString methodTitle = fMethods[j].GetValue<TString>("MethodTitle");
122  TString methodOptions = fMethods[j].GetValue<TString>("MethodOptions");
123  if (methodName == "")
124  Log() << kFATAL << "No method booked for cross-validation" << Endl;
125 
128  Log() << kINFO << "Evaluate method: " << methodTitle << Endl;
130 
131  // Generate K folds on given dataset
132  if (!fFoldStatus) {
133  fDataLoader->MakeKFoldDataSet(fNumFolds);
134  fFoldStatus = kTRUE;
135  }
136 
137  // Process K folds
138  for (UInt_t i = 0; i < fNumFolds; ++i) {
139  Log() << kDEBUG << "Fold (" << methodTitle << "): " << i << Endl;
140  // Get specific fold of dataset and setup method
141  TString foldTitle = methodTitle;
142  foldTitle += "_fold";
143  foldTitle += i + 1;
144 
145  fDataLoader->PrepareFoldDataSet(i, TMVA::Types::kTesting);
// The method is booked under methodTitle; foldTitle is only used further
// down as the title of the fold's ROC graph.
146  MethodBase *smethod = fClassifier->BookMethod(fDataLoader.get(), methodName, methodTitle, methodOptions);
147 
148  // Train method
150  smethod->TrainMethod();
151 
152  // Test method
154  smethod->AddOutput(Types::kTesting, smethod->GetAnalysisType());
155  smethod->TestClassification();
156 
157  // Store results
158  fResults[j].fROCs[i] = fClassifier->GetROCIntegral(fDataLoader->GetName(), methodTitle);
159 
160  TGraph *gr = fClassifier->GetROCCurve(fDataLoader->GetName(), methodTitle, true);
// Line colour is derived from the fold index (i + 1).
161  gr->SetLineColor(i + 1);
162  gr->SetLineWidth(2);
163  gr->SetTitle(foldTitle.Data());
164  fResults[j].fROCCurves->Add(gr);
165 
166  fResults[j].fSigs.push_back(smethod->GetSignificance());
167  fResults[j].fSeps.push_back(smethod->GetSeparation());
168 
// err is handed to GetEfficiency by reference (see its reference entry) and
// is not read afterwards in the visible code.
169  Double_t err;
170  fResults[j].fEff01s.push_back(smethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err));
171  fResults[j].fEff10s.push_back(smethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err));
172  fResults[j].fEff30s.push_back(smethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err));
173  fResults[j].fEffAreas.push_back(smethod->GetEfficiency("", Types::kTesting, err));
174  fResults[j].fTrainEff01s.push_back(smethod->GetTrainingEfficiency("Efficiency:0.01"));
175  fResults[j].fTrainEff10s.push_back(smethod->GetTrainingEfficiency("Efficiency:0.10"));
176  fResults[j].fTrainEff30s.push_back(smethod->GetTrainingEfficiency("Efficiency:0.30"));
177 
178  // Clean-up for this fold
181  fClassifier->DeleteAllMethods();
182  fClassifier->fMethodsMap.clear();
183  }
184  }
186  Log() << kINFO << "Evaluation done." << Endl;
188 }
189 
190 //_______________________________________________________________________
191 const std::vector<TMVA::CrossValidationResult> &TMVA::CrossValidation::GetResults() const
192 {
193  if (fResults.size() == 0)
194  Log() << kFATAL << "No cross-validation results available" << Endl;
195  return fResults;
196 }
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
Definition: TAttLine.h:43
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
float Float_t
Definition: RtypesCore.h:53
void SetTitle(const char *title="")
Set canvas title.
Definition: TCanvas.cxx:1956
Config & gConfig()
MsgLogger & Log() const
Definition: Configurable.h:122
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:35
Virtual base class for all MVA methods.
Definition: MethodBase.h:109
Class to save the results of cross validation, the metric for the classification is ROC and you can ...
Basic string class.
Definition: TString.h:125
bool Bool_t
Definition: RtypesCore.h:59
std::unique_ptr< Factory > fClassifier
virtual void SetTitle(const char *title="")
Set graph title.
Definition: TGraph.cxx:2208
STL namespace.
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
Definition: TMath.h:627
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:392
std::vector< CrossValidationResult > fResults
DataSet * Data() const
Definition: MethodBase.h:398
void SetNumFolds(UInt_t i)
Abstract base class for all high level ml algorithms, you can book ml methods like BDT...
Definition: Envelope.h:43
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual void SetLineColor(Color_t lcolor)
Set the line color.
Definition: TAttLine.h:40
virtual void ParseOptions()
Method to parse the internal option string.
Definition: Envelope.cxx:177
void DeleteResults(const TString &, Types::ETreeType type, Types::EAnalysisType analysistype)
delete the results stored for this particular Method instance.
Definition: DataSet.cxx:316
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
Float_t GetROCStandardDeviation() const
const TString & GetMethodName() const
Definition: MethodBase.h:320
const std::vector< CrossValidationResult > & GetResults() const
This is the main MVA steering class.
Definition: Factory.h:81
virtual Double_t GetSignificance() const
compute significance of mean difference
TGraphErrors * gr
Definition: legend1.C:25
CrossValidation(DataLoader *loader)
virtual void Evaluate()
Virtual method to be implemented with your algorithm.
const Bool_t kFALSE
Definition: RtypesCore.h:88
The Canvas class.
Definition: TCanvas.h:31
double Double_t
Definition: RtypesCore.h:55
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
std::shared_ptr< DataLoader > fDataLoader
Booked method information.
Definition: Envelope.h:47
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
virtual void Draw(Option_t *option="")
Draw a canvas.
Definition: TCanvas.cxx:826
Abstract ClassifierFactory template that handles arbitrary types.
virtual TLegend * BuildLegend(Double_t x1=0.3, Double_t y1=0.21, Double_t x2=0.3, Double_t y2=0.21, const char *title="", Option_t *option="")
Build a legend from the graphical objects in the pad.
Definition: TPad.cxx:485
void SetSilent(Bool_t s)
Definition: Config.h:68
std::map< UInt_t, Float_t > fROCs
A Graph is a graphics object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
TCanvas * Draw(const TString name="CrossValidation") const
virtual Double_t GetTrainingEfficiency(const TString &)
Types::EAnalysisType GetAnalysisType() const
Definition: MethodBase.h:426
Double_t Sqrt(Double_t x)
Definition: TMath.h:590
static void EnableOutput()
Definition: MsgLogger.cxx:75
const Bool_t kTRUE
Definition: RtypesCore.h:87
virtual void TestClassification()
initialization
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< OptionMap > fMethods
Definition: Envelope.h:46
char name[80]
Definition: TGX11.cxx:109
const char * Data() const
Definition: TString.h:345