Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CvSplit.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Kim Albertsson
3
4/*************************************************************************
5 * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12#include "TMVA/CvSplit.h"
13
14#include "TMVA/DataSet.h"
15#include "TMVA/DataSetFactory.h"
16#include "TMVA/DataSetInfo.h"
17#include "TMVA/Event.h"
18#include "TMVA/MsgLogger.h"
19#include "TMVA/Tools.h"
20
21#include <TString.h>
22#include <TFormula.h>
23
24#include <algorithm>
25#include <numeric>
26#include <stdexcept>
27
28
29/* =============================================================================
30 TMVA::CvSplit
31============================================================================= */
32
33////////////////////////////////////////////////////////////////////////////////
34///
35
// Base-class constructor: records the number of folds and flags the fold data
// set as not yet prepared. MakeKFoldDataSet (further down in this file) sets
// fMakeFoldDataSet to kTRUE once the split has been done, and returns early on
// later calls.
36TMVA::CvSplit::CvSplit(UInt_t numFolds) : fNumFolds(numFolds), fMakeFoldDataSet(kFALSE) {}
37
38////////////////////////////////////////////////////////////////////////////////
39/// \brief Set training and test set vectors of dataset described by `dsi`.
40/// \param[in] dsi DataSetInfo for data set to be split
41/// \param[in] foldNumber Ordinal of fold to prepare
42/// \param[in] tt The set used to prepare fold. If equal to `Types::kTraining`
43/// splitting will be based on the original training set. If instead
44/// equal to `Types::kTesting` the test set will be used.
45/// The original training/test set is the set as defined by
46/// `DataLoader::PrepareTrainingAndTestSet`.
47///
48/// Sets the training and test set vectors of the DataSet described by `dsi` as
49/// defined by the split. If `tt` is equal to `Types::kTraining` the split will
50/// be based on the original training set.
51///
52/// Note: Requires `MakeKFoldDataSet` to have been called first.
53///
54
56{
 // Builds the event collections for fold `foldNumber`: that fold becomes the
 // test set, and all other folds are concatenated into the training set.
57 if (foldNumber >= fNumFolds) {
58 Log() << kFATAL << "DataSet prepared for \"" << fNumFolds << "\" folds, requested fold \"" << foldNumber
59 << "\" is outside of range." << Endl;
60 return;
61 }
62
 // NOTE(review): `vec` is taken by value, so each call copies every fold
 // vector (fTrainEvents / fTestEvents). A `const std::vector<...> &`
 // parameter would avoid the copy.
63 auto prepareDataSetInternal = [this, &dsi, foldNumber](std::vector<std::vector<Event *>> vec) {
 // NOTE(review): the fold count is always taken from fTrainEvents, even when
 // `vec` is fTestEvents. This assumes both were split into the same number of
 // folds (MakeKFoldDataSet splits both with fNumFolds); vec.size() would be
 // the self-consistent choice.
64 UInt_t numFolds = fTrainEvents.size();
65
66 // Events in training set (excludes current fold)
 // NOTE(review): the initial value `0` makes std::accumulate accumulate in
 // `int`, and the lambda takes each fold vector by value (one copy per fold).
 // `UInt_t(0)` and a `const std::vector<TMVA::Event *> &` parameter would
 // avoid both issues.
67 UInt_t nTotal = std::accumulate(vec.begin(), vec.end(), 0,
68 [&](UInt_t sum, std::vector<TMVA::Event *> v) { return sum + v.size(); });
69
70 UInt_t nTrain = nTotal - vec.at(foldNumber).size();
71 UInt_t nTest = vec.at(foldNumber).size();
72
73 std::vector<Event *> tempTrain;
74 std::vector<Event *> tempTest;
75
76 tempTrain.reserve(nTrain);
77 tempTest.reserve(nTest);
78
79 // Insert data into training set
80 for (UInt_t i = 0; i < numFolds; ++i) {
81 if (i == foldNumber) {
82 continue;
83 }
84
85 tempTrain.insert(tempTrain.end(), vec.at(i).begin(), vec.at(i).end());
86 }
87
88 // Insert data into test set
89 tempTest.insert(tempTest.end(), vec.at(foldNumber).begin(), vec.at(foldNumber).end());
90
91 Log() << kDEBUG << "Fold prepared, num events in training set: " << tempTrain.size() << Endl;
92 Log() << kDEBUG << "Fold prepared, num events in test set: " << tempTest.size() << Endl;
93
94 // Assign the vectors of the events to rebuild the dataset
 // tempTrain/tempTest are locals of this lambda. Passing their addresses
 // presumably relies on SetEventCollection copying the events rather than
 // keeping the pointer — TODO confirm.
95 dsi.GetDataSet()->SetEventCollection(&tempTrain, Types::kTraining, false);
96 dsi.GetDataSet()->SetEventCollection(&tempTest, Types::kTesting, false);
97 };
98
99 if (tt == Types::kTraining) {
100 prepareDataSetInternal(fTrainEvents);
101 } else if (tt == Types::kTesting) {
102 prepareDataSetInternal(fTestEvents);
103 } else {
 // NOTE(review): this kFATAL message ends with std::endl rather than the
 // MsgLogger `Endl` manipulator used everywhere else in this file. Confirm
 // the message is still emitted and treated as fatal, or switch to Endl.
104 Log() << kFATAL << "PrepareFoldDataSet can only work with training and testing data sets." << std::endl;
105 return;
106 }
107}
108
109////////////////////////////////////////////////////////////////////////////////
110///
111
113{
 // Concatenates all training folds back into a single collection and installs
 // it as both the training and the testing event collection of `dsi`.
114 if (tt != Types::kTraining) {
 // NOTE(review): this uses std::endl instead of the MsgLogger `Endl`
 // manipulator used elsewhere in this file, and no `return` follows it. If
 // kFATAL does not terminate, the training folds are recombined regardless
 // of `tt`.
115 Log() << kFATAL << "Only kTraining is supported for CvSplit::RecombineKFoldDataSet currently." << std::endl;
116 }
117
 // NOTE(review): the heap allocation looks unnecessary. PrepareFoldDataSet
 // passes the address of a stack-local vector to the same SetEventCollection
 // call. A local std::vector would also avoid leaking tempVec if anything
 // between `new` and `delete` throws.
118 std::vector<Event *> *tempVec = new std::vector<Event *>;
119
120 for (UInt_t i = 0; i < fNumFolds; ++i) {
121 tempVec->insert(tempVec->end(), fTrainEvents.at(i).begin(), fTrainEvents.at(i).end());
122 }
123
124 dsi.GetDataSet()->SetEventCollection(tempVec, Types::kTraining, false);
125 dsi.GetDataSet()->SetEventCollection(tempVec, Types::kTesting, false);
126
127 delete tempVec;
128}
129
130/* =============================================================================
131 TMVA::CvSplitKFoldsExpr
132============================================================================= */
133
134////////////////////////////////////////////////////////////////////////////////
135///
136
// fIdxFormulaParNumFolds starts at the largest Int_t as a "reserved NumFolds
// parameter not present" sentinel; Eval only uses it when it is < GetNpar().
// fParValues gets one slot per formula parameter.
// NOTE(review): std::numeric_limits needs <limits>, which is not among this
// file's visible includes. Presumably it arrives transitively — TODO confirm.
138 : fDsi(dsi), fIdxFormulaParNumFolds(std::numeric_limits<Int_t>::max()), fSplitFormula("", expr),
139 fParValues(fSplitFormula.GetNpar())
140{
 // NOTE(review): fSplitExpr is not in the member-initializer list above and is
 // not assigned before this check. Unless CvSplit.h gives it a default member
 // initializer, the message below prints an empty expression; `expr` looks
 // like what was intended.
141 if (!fSplitFormula.IsValid()) {
142 throw std::runtime_error("Split expression \"" + std::string(fSplitExpr.Data()) + "\" is not a valid TFormula.");
143 }
144
147
148 // std::cout << "Found variable with name \"" << name << "\"." << std::endl;
149
 // "NumFolds"/"numFolds" is a reserved parameter: Eval fills it with its
 // `numFolds` argument. Other parameters are resolved to spectators; Eval
 // reads them through fFormulaParIdxToDsiSpecIdx.
150 if (name == "NumFolds" || name == "numFolds") {
151 // std::cout << "NumFolds|numFolds is a reserved variable! Adding to context." << std::endl;
153 } else {
155 }
156 }
157}
158
159////////////////////////////////////////////////////////////////////////////////
160///
161
163{
 // Evaluates the split expression for event `ev` and returns its fold index.
 // Spectator-backed formula parameters are refreshed from the event, and the
 // reserved NumFolds parameter (if present) is set to `numFolds`. The rounded
 // result must lie in [0, numFolds), otherwise std::runtime_error is thrown.
164 for (auto &p : fFormulaParIdxToDsiSpecIdx) {
165 auto iFormulaPar = p.first;
166 auto iSpectator = p.second;
167
168 fParValues.at(iFormulaPar) = ev->GetSpectator(iSpectator);
169 }
170
171 if (fIdxFormulaParNumFolds < fSplitFormula.GetNpar()) {
172 fParValues[fIdxFormulaParNumFolds] = numFolds;
173 }
174
175 // NOTE: We are using a double to represent an integer here. This _will_
176 // lead to problems if the norm of the double grows too large. A quick test
177 // with python suggests that problems arise at a magnitude of ~1e16.
 // NOTE(review): the constructor sizes fParValues with GetNpar(), so it is
 // empty for a parameter-free expression. Assuming fParValues is a
 // std::vector, `&fParValues[0]` is then undefined behaviour;
 // `fParValues.data()` is the safe spelling.
178 Double_t iFold_d = fSplitFormula.EvalPar(nullptr, &fParValues[0]);
179
180 if (iFold_d < 0) {
181 throw std::runtime_error("Output of splitExpr must be non-negative.");
182 }
183
 // NOTE(review): std::lround returns `long`, and converting it to UInt_t
 // wraps modulo 2^32. A very large formula output could therefore alias to a
 // small, valid fold index and pass the range check below. (std::lround also
 // needs <cmath>, which is not among the visible includes.)
184 UInt_t iFold = std::lround(iFold_d);
185 if (iFold >= numFolds) {
 // NOTE(review): the two adjacent literals concatenate without a space, so
 // the message reads "...non-negativeinteger between...".
186 throw std::runtime_error("Output of splitExpr should be a non-negative"
187 "integer between 0 and numFolds-1 inclusive.");
188 }
189
190 return iFold;
191}
192
193////////////////////////////////////////////////////////////////////////////////
194///
195
200
201////////////////////////////////////////////////////////////////////////////////
202///
203
205{
 // Returns the index of the first spectator in `dsi` whose name, label or
 // expression equals `name`. Throws std::runtime_error if none matches.
 // NOTE(review): this copies the whole spectator-info vector. If
 // GetSpectatorInfos() returns a reference, binding `const auto &` would avoid
 // the copy — TODO confirm.
206 std::vector<VariableInfo> spectatorInfos = dsi.GetSpectatorInfos();
207
208 for (UInt_t iSpectator = 0; iSpectator < spectatorInfos.size(); ++iSpectator) {
 // `vi` is declared on a line not shown in this listing; presumably it is
 // spectatorInfos[iSpectator].
210 if (vi.GetName() == name) {
211 return iSpectator;
212 } else if (vi.GetLabel() == name) {
213 return iSpectator;
214 } else if (vi.GetExpression() == name) {
215 return iSpectator;
216 }
217 }
218
219 throw std::runtime_error("Spectator \"" + std::string(name.Data()) + "\" not found.");
220}
221
222/* =============================================================================
223 TMVA::CvSplitKFolds
224============================================================================= */
225
226////////////////////////////////////////////////////////////////////////////////
227/// \brief Splits a dataset into k folds, ready for use in cross validation.
228/// \param[in] numFolds Number of folds to split data into
229/// \param[in] stratified If true, use stratified splitting, balancing the
230/// number of events across classes and folds. If false,
231/// no such balancing is done.
232/// \param[in] splitExpr Expression used to split data into folds. If `""` a
233/// random assignment will be done. Otherwise the
234/// expression is fed into a TFormula and evaluated per
235/// event. The resulting value is the fold assignment.
236/// \param[in] seed Used only when using random splitting (i.e. when
237/// `splitExpr` is `""`). Seed is used to initialise the random
238/// number generator when assigning events to folds.
239///
240
242 : CvSplit(numFolds), fSeed(seed), fSplitExprString(splitExpr), fStratified(stratified)
243{
 // Reports an invalid split expression at construction time (kFATAL). The
 // guarding condition is not part of this listing. Only the expression string
 // is stored here; the CvSplitKFoldsExpr object is created later, in
 // MakeKFoldDataSet.
245 Log() << kFATAL << "Split expression \"" << fSplitExprString << "\" is not a valid TFormula." << Endl;
246 }
247
248}
249
250////////////////////////////////////////////////////////////////////////////////
251/// \brief Prepares a DataSet for cross validation
252
254{
 // Splits the current training and testing collections of `dsi` into
 // fNumFolds folds each, storing them in fTrainEvents / fTestEvents. The
 // split is done once; later calls log a message and return early.
255 // Validate spectator
256 // fSpectatorIdx = GetSpectatorIndexForName(dsi, fSpectatorName);
257
 // NOTE(review): the expression object is (re)built before the "already
 // split" early return below. Repeated calls therefore redo this work (and
 // the spectator lookups in its constructor) even though the folds are not
 // recomputed.
258 if (fSplitExprString != TString("")) {
259 fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(new CvSplitKFoldsExpr(dsi, fSplitExprString));
260 }
261
262 // No need to do it again if the sets have already been split.
263 if (fMakeFoldDataSet) {
264 Log() << kINFO << "Splitting in k-folds has been already done" << Endl;
265 return;
266 }
267
268 fMakeFoldDataSet = kTRUE;
269
270 UInt_t numClasses = dsi.GetNClasses();
271
272 // Get the original event vectors for testing and training from the dataset.
273 std::vector<Event *> trainData = dsi.GetDataSet()->GetEventCollection(Types::kTraining);
274 std::vector<Event *> testData = dsi.GetDataSet()->GetEventCollection(Types::kTesting);
275
276 // Split the sets into the number of folds.
277 fTrainEvents = SplitSets(trainData, fNumFolds, numClasses);
278 fTestEvents = SplitSets(testData, fNumFolds, numClasses);
279}
280
281////////////////////////////////////////////////////////////////////////////////
282/// \brief Generates a vector of fold assignments
283/// \param[in] nEntries Number of events in range
284/// \param[in] numFolds Number of folds to split data into
285/// \param[in] seed Random seed
286///
287/// Randomly assigns events to `numFolds` folds. Each fold will hold at most
288/// `nEntries / numFolds + 1` events.
289///
290
292{
293 // Generate assignment of the pattern `0, 1, 2, 0, 1, 2, 0, 1 ...` for
294 // `numFolds = 3`.
 // NOTE(review): despite the `f` member prefix this is a local variable; a
 // local-style name (e.g. `mapping`) would avoid confusion with a data member.
295 std::vector<UInt_t> fOrigToFoldMapping;
297
298 for (UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
300 }
301
302 // Shuffle assignment
305
306 return fOrigToFoldMapping;
307}
308
309
310////////////////////////////////////////////////////////////////////////////////
311/// \brief Split a set of events into k folds
312/// \param[in] oldSet Original, unsplit, events
313/// \param[in] numFolds Number of folds to split data into
314/// \param[in] numClasses number of classes to stratify into
315///
316
// Distributes `oldSet` over `numFolds` folds using one of three strategies:
//  - split expression present: fold = fSplitExpr->Eval(numFolds, ev) per event;
//  - not stratified: a shuffled round-robin mapping from
//    GetEventIndexToFoldMapping;
//  - stratified: events are grouped by class, each class is shuffled, and each
//    class is spread over the folds with its own round-robin mapping.
// In the random and stratified cases each event's fold is also recorded in
// fEventToFoldMapping.
317std::vector<std::vector<TMVA::Event *>>
318TMVA::CvSplitKFolds::SplitSets(std::vector<TMVA::Event *> &oldSet, UInt_t numFolds, UInt_t numClasses)
319{
320 const ULong64_t nEntries = oldSet.size();
322
323 std::vector<std::vector<Event *>> tempSets;
 // NOTE(review): this reserves with the member fNumFolds while the loop below
 // uses the `numFolds` parameter. The two agree for the calls in
 // MakeKFoldDataSet, but `numFolds` is the consistent choice.
324 tempSets.reserve(fNumFolds);
325 for (UInt_t iFold = 0; iFold < numFolds; ++iFold) {
326 tempSets.emplace_back();
327 tempSets.at(iFold).reserve(foldSize);
328 }
329
330 Bool_t useSplitExpr = !(fSplitExpr == nullptr || fSplitExprString == "");
331
332 if (useSplitExpr) {
333 // Deterministic split
334 for (ULong64_t i = 0; i < nEntries; i++) {
335 TMVA::Event *ev = oldSet[i];
336 UInt_t iFold = fSplitExpr->Eval(numFolds, ev);
337 tempSets.at((UInt_t)iFold).push_back(ev);
338 }
339 } else {
340 if(!fStratified){
341 // Random split
342 std::vector<UInt_t> fOrigToFoldMapping;
343 fOrigToFoldMapping = GetEventIndexToFoldMapping(nEntries, numFolds, fSeed);
344
 // NOTE(review): the index is UInt_t while nEntries is ULong64_t (the
 // deterministic branch above uses ULong64_t).
345 for (UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
348 tempSets.at(iFold).push_back(ev);
349
350 fEventToFoldMapping[ev] = iFold;
351 }
352 } else {
353 // Stratified Split
354 std::vector<std::vector<TMVA::Event *>> oldSets;
355 oldSets.reserve(numClasses);
356
357 for(UInt_t iClass = 0; iClass < numClasses; iClass++){
358 oldSets.emplace_back();
359 //find a way to get number of events in each class
 // NOTE(review): this reserves the *outer* vector (room for nEntries
 // per-class vectors), not the per-class event storage the comment above
 // is about. `oldSets.back().reserve(...)` was presumably intended.
360 oldSets.reserve(nEntries);
361 }
362
363 for(UInt_t iEvent = 0; iEvent < nEntries; ++iEvent){
364 // check the class of event and add to its vector of events
366 UInt_t iClass = ev->GetClass();
 // .at() throws std::out_of_range if an event's class index is >= numClasses.
367 oldSets.at(iClass).push_back(ev);
368 }
369
370 for(UInt_t i = 0; i<numClasses; ++i){
371 // Shuffle each vector individually
373 std::shuffle(oldSets.at(i).begin(), oldSets.at(i).end(), rng);
374 }
375
 // NOTE(review): every class gets its own round-robin mapping (pattern
 // 0, 1, 2, ... per GetEventIndexToFoldMapping) with the same seed. When a
 // class size is not divisible by numFolds, its leftover events therefore
 // always land in the lowest-numbered folds, so fold sizes skew toward fold 0.
376 for(UInt_t i = 0; i<numClasses; ++i) {
377 std::vector<UInt_t> fOrigToFoldMapping;
378 fOrigToFoldMapping = GetEventIndexToFoldMapping(oldSets.at(i).size(), numFolds, fSeed);
379
380 for (UInt_t iEvent = 0; iEvent < oldSets.at(i).size(); ++iEvent) {
382 TMVA::Event *ev = oldSets.at(i)[iEvent];
383 tempSets.at(iFold).push_back(ev);
384 fEventToFoldMapping[ev] = iFold;
385 }
386 }
387 }
388 }
389 return tempSets;
390}
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
constexpr Bool_t kFALSE
Definition RtypesCore.h:108
constexpr Bool_t kTRUE
Definition RtypesCore.h:107
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
winID h TVirtualViewer3D TVirtualGLPainter p
char name[80]
Definition TGX11.cxx:110
const_iterator begin() const
const_iterator end() const
The Formula class.
Definition TFormula.h:89
const char * GetParName(Int_t ipar) const
Return parameter name given by integer.
Bool_t IsValid() const
Definition TFormula.h:272
Int_t GetNpar() const
Definition TFormula.h:261
MsgLogger & Log() const
Int_t fIdxFormulaParNumFolds
Maps parameter indices in splitExpr to their spectator index in the datasetinfo.
Definition CvSplit.h:81
UInt_t Eval(UInt_t numFolds, const Event *ev)
Definition CvSplit.cxx:162
std::vector< std::pair< Int_t, Int_t > > fFormulaParIdxToDsiSpecIdx
Definition CvSplit.h:80
UInt_t GetSpectatorIndexForName(DataSetInfo &dsi, TString name)
Definition CvSplit.cxx:204
static Bool_t Validate(TString expr)
Definition CvSplit.cxx:196
CvSplitKFoldsExpr(DataSetInfo &dsi, TString expr)
Definition CvSplit.cxx:137
TFormula fSplitFormula
Expression used to split data into folds. Should output integer values between 0 and numFolds-1 inclusive.
Definition CvSplit.h:83
DataSetInfo & fDsi
Definition CvSplit.h:77
TString fSplitExpr
Keeps track of the index of reserved par "NumFolds" in splitExpr.
Definition CvSplit.h:82
std::vector< UInt_t > GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed=100)
Generates a vector of fold assignments.
Definition CvSplit.cxx:291
void MakeKFoldDataSet(DataSetInfo &dsi) override
Prepares a DataSet for cross validation.
Definition CvSplit.cxx:253
std::vector< std::vector< Event * > > SplitSets(std::vector< TMVA::Event * > &oldSet, UInt_t numFolds, UInt_t numClasses)
Split a set of events into k folds.
Definition CvSplit.cxx:318
TString fSplitExprString
! Expression used to split data into folds. Should output integer values between 0 and numFolds-1 inclusive.
Definition CvSplit.h:108
CvSplitKFolds(UInt_t numFolds, TString splitExpr="", Bool_t stratified=kTRUE, UInt_t seed=100)
Splits a dataset into k folds, ready for use in cross validation.
Definition CvSplit.cxx:241
virtual void RecombineKFoldDataSet(DataSetInfo &dsi, Types::ETreeType tt=Types::kTraining)
Definition CvSplit.cxx:112
virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt)
Set training and test set vectors of dataset described by dsi.
Definition CvSplit.cxx:55
CvSplit(UInt_t numFolds)
Definition CvSplit.cxx:36
Class that contains all the data information.
Definition DataSetInfo.h:62
@ kTraining
Definition Types.h:143
Class for type info of MVA input variable.
Basic string class.
Definition TString.h:138
const char * Data() const
Definition TString.h:384
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
auto * tt
Definition textangle.C:16
static uint64_t sum(uint64_t i)
Definition Factory.cxx:2339