ROOT 6.10/09
Reference Guide
HyperParameterOptimisation.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Thomas James Stevenson.
3 
5 
6 #include "TMVA/Configurable.h"
7 #include "TMVA/DataSet.h"
8 #include "TMVA/Event.h"
9 #include "TMVA/MethodBase.h"
11 #include "TMVA/Types.h"
12 
13 #include "TGraph.h"
14 #include "TMultiGraph.h"
15 #include "TString.h"
16 #include "TSystem.h"
17 
18 #include <iostream>
19 #include <memory>
20 #include <vector>
21 
22 /*! \class TMVA::HyperParameterOptimisationResult
23 \ingroup TMVA
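24 Holds the outcome of a TMVA::HyperParameterOptimisation run: the name of the optimised method and, for each fold, the map of tuned parameter names to values returned by OptimizeTuningParameters; Print() logs these per fold, and GetROCCurves() returns the associated TMultiGraph of ROC curves.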
25 */
26 
27 /*! \class TMVA::HyperParameterOptimisation
28 \ingroup TMVA
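29 Tunes the hyper-parameters of a TMVA method with k-fold splitting of the DataLoader: for each fold (5 by default) it books the configured method in an internal TMVA::Factory and calls OptimizeTuningParameters with the figure-of-merit type ("Separation" by default) and fitter type ("Minuit" by default), collecting the tuned parameters of every fold in a HyperParameterOptimisationResult.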
30 */
31 
33  : fROCAVG(0.0), fROCCurves(std::make_shared<TMultiGraph>())
34 {
35 }
36 
38 {
39 }
40 
42 {
43 
44  return fROCCurves.get();
45 }
46 
48 {
51 
52  MsgLogger fLogger("HyperParameterOptimisation");
53 
54  for(UInt_t j=0; j<fFoldParameters.size(); ++j) {
55  fLogger<<kHEADER<< "===========================================================" << Endl;
56  fLogger<<kINFO<< "Optimisation for " << fMethodName << " fold " << j+1 << Endl;
57 
58  for(auto &it : fFoldParameters.at(j)) {
59  fLogger<<kINFO<< it.first << " " << it.second << Endl;
60  }
61  }
62 
64 
65 }
66 
68  fFomType("Separation"),
69  fFitType("Minuit"),
70  fNumFolds(5),
71  fResults(),
72  fClassifier(new TMVA::Factory("HyperParameterOptimisation","!V:!ROC:Silent:!ModelPersistence:!Color:!DrawProgressBar:AnalysisType=Classification"))
73 {
75 }
76 
78 {
79  fClassifier=nullptr;
80 }
81 
83 {
84  fNumFolds=i;
85  fDataLoader->MakeKFoldDataSet(fNumFolds);
87 }
88 
90 {
91  TString methodName = fMethod.GetValue<TString>("MethodName");
92  TString methodTitle = fMethod.GetValue<TString>("MethodTitle");
93  TString methodOptions = fMethod.GetValue<TString>("MethodOptions");
94 
95  if(!fFoldStatus)
96  {
97  fDataLoader->MakeKFoldDataSet(fNumFolds);
99  }
100  fResults.fMethodName = methodName;
101 
102  for(UInt_t i = 0; i < fNumFolds; ++i) {
103 
104  TString foldTitle = methodTitle;
105  foldTitle += "_opt";
106  foldTitle += i+1;
107 
109  fDataLoader->PrepareFoldDataSet(i, TMVA::Types::kTraining);
110 
111  auto smethod = fClassifier->BookMethod(fDataLoader.get(), methodName, methodTitle, methodOptions);
112 
113  auto params=smethod->OptimizeTuningParameters(fFomType,fFitType);
114  fResults.fFoldParameters.push_back(params);
115 
116  smethod->Data()->DeleteResults(smethod->GetMethodName(), Types::kTraining, Types::kClassification);
117 
118  fClassifier->DeleteAllMethods();
119 
120  fClassifier->fMethodsMap.clear();
121 
122  }
123 
124 }
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
T GetValue(const TString &key)
Definition: OptionMap.h:144
Config & gConfig()
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:35
Basic string class.
Definition: TString.h:129
bool Bool_t
Definition: RtypesCore.h:59
HyperParameterOptimisationResult fResults
STL namespace.
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:392
OptionMap fMethod
Definition: Envelope.h:38
Base class for all machine learning algorithms.
Definition: Envelope.h:35
virtual void Evaluate()
Virtual method to be implemented with your algorithm.
HyperParameterOptimisation(DataLoader *dataloader)
unsigned int UInt_t
Definition: RtypesCore.h:42
This is the main MVA steering class.
Definition: Factory.h:81
const Bool_t kFALSE
Definition: RtypesCore.h:92
std::shared_ptr< DataLoader > fDataLoader
Definition: Envelope.h:39
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
std::vector< std::map< TString, Double_t > > fFoldParameters
Abstract ClassifierFactory template that handles arbitrary types.
void SetSilent(Bool_t s)
Definition: Config.h:60
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
static void EnableOutput()
Definition: MsgLogger.cxx:75
const Bool_t kTRUE
Definition: RtypesCore.h:91