Logo ROOT  
Reference Guide
Loading...
Searching...
No Matches
HyperParameterOptimisation.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Omar Zapata, Thomas James Stevenson.
3
5
6#include "TMVA/Configurable.h"
7#include "TMVA/CvSplit.h"
8#include "TMVA/DataSet.h"
9#include "TMVA/Event.h"
10#include "TMVA/MethodBase.h"
12#include "TMVA/Types.h"
13
14#include "TMultiGraph.h"
15#include "TString.h"
16
17#include <memory>
18#include <vector>
19
20/*! \class TMVA::HyperParameterOptimisationResult
21\ingroup TMVA
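22Holds the outcome of a hyper-parameter optimisation: the name of the optimised method and, for each cross-validation fold, a map from tuning-parameter name to its optimised value.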
23*/
24
25/*! \class TMVA::HyperParameterOptimisation
26\ingroup TMVA
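27Envelope that, for every booked method, splits the data set into k folds (5 by default) and calls MethodBase::OptimizeTuningParameters on each fold, using the configured figure of merit (default "Separation") and fitter (default "Minuit"). The per-fold optimised parameters are collected in a HyperParameterOptimisationResult.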
28*/
29
30//_______________________________________________________________________
35
36//_______________________________________________________________________
40
41//_______________________________________________________________________
47
48//_______________________________________________________________________
50{
53
54 MsgLogger fLogger("HyperParameterOptimisation");
55
56 for(UInt_t j=0; j<fFoldParameters.size(); ++j) {
57 fLogger<<kHEADER<< "===========================================================" << Endl;
58 fLogger<<kINFO<< "Optimisation for " << fMethodName << " fold " << j+1 << Endl;
59
60 for(auto &it : fFoldParameters.at(j)) {
61 fLogger<<kINFO<< it.first << " " << it.second << Endl;
62 }
63 }
64
66
67}
68
69//_______________________________________________________________________
71 fFomType("Separation"),
72 fFitType("Minuit"),
73 fNumFolds(5),
74 fResults(),
75 fClassifier(new TMVA::Factory("HyperParameterOptimisation","!V:!ROC:Silent:!ModelPersistence:!Color:!DrawProgressBar:AnalysisType=Classification"))
76{
78}
79
80//_______________________________________________________________________
85
86//_______________________________________________________________________
88{
89 fNumFolds = i;
90 // fDataLoader->MakeKFoldDataSet(fNumFolds);
92}
93
94//_______________________________________________________________________
96{
97 for (auto &meth : fMethods) {
98 TString methodName = meth.GetValue<TString>("MethodName");
99 TString methodTitle = meth.GetValue<TString>("MethodTitle");
100 TString methodOptions = meth.GetValue<TString>("MethodOptions");
101
102 CvSplitKFolds split{fNumFolds, "", kFALSE, 0};
103 if (!fFoldStatus) {
104 fDataLoader->MakeKFoldDataSet(split);
106 }
107 fResults.fMethodName = methodName;
108
109 for (UInt_t i = 0; i < fNumFolds; ++i) {
110 TString foldTitle = methodTitle;
111 foldTitle += "_opt";
112 foldTitle += i + 1;
113
115 fDataLoader->PrepareFoldDataSet(split, i, TMVA::Types::kTraining);
116
117 auto smethod = fClassifier->BookMethod(fDataLoader.get(), methodName, methodTitle, methodOptions);
118
119 auto params = smethod->OptimizeTuningParameters(fFomType, fFitType);
120 fResults.fFoldParameters.push_back(params);
121
122 smethod->Data()->DeleteResults(smethod->GetMethodName(), Types::kTraining, Types::kClassification);
123
124 fClassifier->DeleteAllMethods();
125
126 fClassifier->fMethodsMap.clear();
127 }
128 }
129}
unsigned int UInt_t
Unsigned integer 4 bytes (unsigned int).
Definition RtypesCore.h:60
bool Bool_t
Boolean (0=false, 1=true) (bool).
Definition RtypesCore.h:77
constexpr Bool_t kFALSE
Definition RtypesCore.h:108
constexpr Bool_t kTRUE
Definition RtypesCore.h:107
void SetSilent(Bool_t s)
Definition Config.h:63
std::vector< OptionMap > fMethods
! Booked method information
Definition Envelope.h:46
std::shared_ptr< DataLoader > fDataLoader
! data
Definition Envelope.h:47
Envelope(const TString &name, DataLoader *dataloader=nullptr, TFile *file=nullptr, const TString options="")
Constructor for the initialization of Envelopes; different Envelopes may need different constructors...
Definition Envelope.cxx:40
static void SetIsTraining(Bool_t)
When this static function is called, it sets the flag that determines whether events with negative event weights should...
Definition Event.cxx:399
This is the main MVA steering class.
Definition Factory.h:80
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
std::vector< std::map< TString, Double_t > > fFoldParameters
HyperParameterOptimisationResult fResults
!
void Evaluate() override
Virtual method to be implemented with your algorithm.
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
static void EnableOutput()
Definition MsgLogger.cxx:67
@ kClassification
Definition Types.h:127
@ kTraining
Definition Types.h:143
Basic string class.
Definition TString.h:138
create variable transformations
Config & gConfig()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148