Logo ROOT   6.14/05
Reference Guide
CvSplit.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Kim Albertsson
3 
4 /*************************************************************************
5  * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 #include "TMVA/CvSplit.h"
13 
14 #include "TMVA/DataSet.h"
15 #include "TMVA/DataSetFactory.h"
16 #include "TMVA/DataSetInfo.h"
17 #include "TMVA/Event.h"
18 #include "TMVA/MsgLogger.h"
19 #include "TMVA/Tools.h"
20 
21 #include <TString.h>
22 #include <TFormula.h>
23 
24 #include <algorithm>
25 #include <numeric>
26 #include <stdexcept>
27 
30 
31 /* =============================================================================
32  TMVA::CvSplit
33 ============================================================================= */
34 
35 ////////////////////////////////////////////////////////////////////////////////
36 ///
37 
38 TMVA::CvSplit::CvSplit(UInt_t numFolds) : fNumFolds(numFolds), fMakeFoldDataSet(kFALSE) {}
39 
40 ////////////////////////////////////////////////////////////////////////////////
41 /// \brief Set training and test set vectors of dataset described by `dsi`.
42 /// \param[in] dsi DataSetInfo for data set to be split
43 /// \param[in] foldNumber Ordinal of fold to prepare
44 /// \param[in] tt The set used to prepare fold. If equal to `Types::kTraining`
45 /// splitting will be based off the original train set. If instead
46 /// equal to `Types::kTesting` the test set will be used.
47 /// The original training/test set is the set as defined by
48 /// `DataLoader::PrepareTrainingAndTestSet`.
49 ///
50 /// Sets the training and test set vectors of the DataSet described by `dsi` as
51 /// defined by the split. If `tt` is eqal to `Types::kTraining` the split will
52 /// be based off of the original training set.
53 ///
54 /// Note: Requires `MakeKFoldDataSet` to have been called first.
55 ///
56 
58 {
59  if (foldNumber >= fNumFolds) {
60  Log() << kFATAL << "DataSet prepared for \"" << fNumFolds << "\" folds, requested fold \"" << foldNumber
61  << "\" is outside of range." << Endl;
62  return;
63  }
64 
65  auto prepareDataSetInternal = [this, &dsi, foldNumber](std::vector<std::vector<Event *>> vec) {
66  UInt_t numFolds = fTrainEvents.size();
67 
68  // Events in training set (excludes current fold)
69  UInt_t nTotal = std::accumulate(vec.begin(), vec.end(), 0,
70  [&](UInt_t sum, std::vector<TMVA::Event *> v) { return sum + v.size(); });
71 
72  UInt_t nTrain = nTotal - vec.at(foldNumber).size();
73  UInt_t nTest = vec.at(foldNumber).size();
74 
75  std::vector<Event *> tempTrain;
76  std::vector<Event *> tempTest;
77 
78  tempTrain.reserve(nTrain);
79  tempTest.reserve(nTest);
80 
81  // Insert data into training set
82  for (UInt_t i = 0; i < numFolds; ++i) {
83  if (i == foldNumber) {
84  continue;
85  }
86 
87  tempTrain.insert(tempTrain.end(), vec.at(i).begin(), vec.at(i).end());
88  }
89 
90  // Insert data into test set
91  tempTest.insert(tempTest.end(), vec.at(foldNumber).begin(), vec.at(foldNumber).end());
92 
93  Log() << kDEBUG << "Fold prepared, num events in training set: " << tempTrain.size() << Endl;
94  Log() << kDEBUG << "Fold prepared, num events in test set: " << tempTest.size() << Endl;
95 
96  // Assign the vectors of the events to rebuild the dataset
97  dsi.GetDataSet()->SetEventCollection(&tempTrain, Types::kTraining, false);
98  dsi.GetDataSet()->SetEventCollection(&tempTest, Types::kTesting, false);
99  };
100 
101  if (tt == Types::kTraining) {
102  prepareDataSetInternal(fTrainEvents);
103  } else if (tt == Types::kTesting) {
104  prepareDataSetInternal(fTestEvents);
105  } else {
106  Log() << kFATAL << "PrepareFoldDataSet can only work with training and testing data sets." << std::endl;
107  return;
108  }
109 }
110 
111 ////////////////////////////////////////////////////////////////////////////////
112 ///
113 
115 {
116  if (tt != Types::kTraining) {
117  Log() << kFATAL << "Only kTraining is supported for CvSplit::RecombineKFoldDataSet currently." << std::endl;
118  }
119 
120  std::vector<Event *> *tempVec = new std::vector<Event *>;
121 
122  for (UInt_t i = 0; i < fNumFolds; ++i) {
123  tempVec->insert(tempVec->end(), fTrainEvents.at(i).begin(), fTrainEvents.at(i).end());
124  }
125 
126  dsi.GetDataSet()->SetEventCollection(tempVec, Types::kTraining, false);
127  dsi.GetDataSet()->SetEventCollection(tempVec, Types::kTesting, false);
128 
129  delete tempVec;
130 }
131 
132 /* =============================================================================
133  TMVA::CvSplitKFoldsExpr
134 ============================================================================= */
135 
136 ////////////////////////////////////////////////////////////////////////////////
137 ///
138 
140  : fDsi(dsi), fIdxFormulaParNumFolds(std::numeric_limits<UInt_t>::max()), fSplitFormula("", expr),
141  fParValues(fSplitFormula.GetNpar())
142 {
143  if (not fSplitFormula.IsValid()) {
144  throw std::runtime_error("Split expression \"" + std::string(fSplitExpr.Data()) + "\" is not a valid TFormula.");
145  }
146 
147  for (Int_t iFormulaPar = 0; iFormulaPar < fSplitFormula.GetNpar(); ++iFormulaPar) {
148  TString name = fSplitFormula.GetParName(iFormulaPar);
149 
150  // std::cout << "Found variable with name \"" << name << "\"." << std::endl;
151 
152  if (name == "NumFolds" or name == "numFolds") {
153  // std::cout << "NumFolds|numFolds is a reserved variable! Adding to context." << std::endl;
154  fIdxFormulaParNumFolds = iFormulaPar;
155  } else {
156  fFormulaParIdxToDsiSpecIdx.push_back(std::make_pair(iFormulaPar, GetSpectatorIndexForName(fDsi, name)));
157  }
158  }
159 }
160 
161 ////////////////////////////////////////////////////////////////////////////////
162 ///
/// \brief Evaluate the split expression for event `ev` and return its fold.
/// \param[in] numFolds Number of folds; presumably written into the reserved
///                     `NumFolds` parameter by the lines missing below — verify.
/// \param[in] ev       Event whose spectator values are copied into the
///                     formula parameters before evaluation.
/// \return The formula value converted to UInt_t.
/// \throws std::runtime_error if the formula value is not within 1e-5 above
///         an integer.
///
/// NOTE(review): source lines 164 (signature) and 173-174 are missing from this
/// listing; the `}` on line 175 closes a block that is not visible here.
/// NOTE(review): only integrality is checked. A negative result makes the
/// (UInt_t) cast undefined, a value just below an integer (e.g. 2.999999) is
/// rejected, and values >= numFolds are not rejected here (SplitSets'
/// `.at()` would throw std::out_of_range later). If the formula has no
/// parameters, `&fParValues[0]` indexes an empty vector — confirm this is safe.
163 
165 {
// Copy each mapped spectator value into its formula-parameter slot.
166  for (auto &p : fFormulaParIdxToDsiSpecIdx) {
167  auto iFormulaPar = p.first;
168  auto iSpectator = p.second;
169 
170  fParValues.at(iFormulaPar) = ev->GetSpectator(iSpectator);
171  }
172 
175  }
176 
177  Double_t iFold = fSplitFormula.EvalPar(nullptr, &fParValues[0]);
178 
// Reject results that are not (close to) a whole number.
179  if (fabs(iFold - (double)((UInt_t)iFold)) > 1e-5) {
180  throw std::runtime_error(
181  "Output of splitExpr should be a non-negative integer between 0 and numFolds-1 inclusive.");
182  }
183 
184  return iFold;
185 }
186 
187 ////////////////////////////////////////////////////////////////////////////////
188 ///
189 
191 {
192  return TFormula("", expr).IsValid();
193 }
194 
195 ////////////////////////////////////////////////////////////////////////////////
196 ///
197 
199 {
200  std::vector<VariableInfo> spectatorInfos = dsi.GetSpectatorInfos();
201 
202  for (UInt_t iSpectator = 0; iSpectator < spectatorInfos.size(); ++iSpectator) {
203  VariableInfo vi = spectatorInfos[iSpectator];
204  if (vi.GetName() == name) {
205  return iSpectator;
206  } else if (vi.GetLabel() == name) {
207  return iSpectator;
208  } else if (vi.GetExpression() == name) {
209  return iSpectator;
210  }
211  }
212 
213  throw std::runtime_error("Spectator \"" + std::string(name.Data()) + "\" not found.");
214 }
215 
216 /* =============================================================================
217  TMVA::CvSplitKFolds
218 ============================================================================= */
219 
220 ////////////////////////////////////////////////////////////////////////////////
221 /// \brief Splits a dataset into k folds, ready for use in cross validation.
222 /// \param numFolds[in] Number of folds to split data into
223 /// \param stratified[in] If true, use stratified splitting, balancing the
224 /// number of events across classes and folds. If false,
225 /// no such balancing is done. For
226 /// \param splitExpr[in] Expression used to split data into folds. If `""` a
227 /// random assignment will be done. Otherwise the
228 /// expression is fed into a TFormula and evaluated per
229 /// event. The resulting value is the the fold assignment.
230 /// \param seed[in] Used only when using random splitting (i.e. when
231 /// `splitExpr` is `""`). Seed is used to initialise the random
232 /// number generator when assigning events to folds.
233 ///
234 
235 TMVA::CvSplitKFolds::CvSplitKFolds(UInt_t numFolds, TString splitExpr, Bool_t stratified, UInt_t seed)
236  : CvSplit(numFolds), fSeed(seed), fSplitExprString(splitExpr), fStratified(stratified)
237 {
238  if (not CvSplitKFoldsExpr::Validate(fSplitExprString) and (splitExpr != TString(""))) {
239  Log() << kFATAL << "Split expression \"" << fSplitExprString << "\" is not a valid TFormula." << Endl;
240  }
241 
242  if (stratified) {
243  Log() << kFATAL << "Stratified KFolds not currently implemented." << std::endl;
244  }
245 }
246 
247 ////////////////////////////////////////////////////////////////////////////////
248 /// \brief Prepares a DataSet for cross validation
/// \param[in] dsi DataSetInfo whose DataSet provides the original training and
///                test event collections.
///
/// If `fSplitExprString` is non-empty, a CvSplitKFoldsExpr bound to `dsi` is
/// (re)built on every call, even when the split itself was already done.
/// The training and test collections are then each split into `fNumFolds`
/// folds via SplitSets and stored in fTrainEvents / fTestEvents.
///
/// NOTE(review): source lines 250 (the signature) and 265 are missing from
/// this listing. Line 265 presumably sets `fMakeFoldDataSet = kTRUE` —
/// confirm, otherwise the early return below can never trigger.
249 
251 {
252  // Build the split-expression evaluator (if any); its constructor also checks that every referenced spectator exists.
253  // fSpectatorIdx = GetSpectatorIndexForName(dsi, fSpectatorName);
254 
255  if (fSplitExprString != TString("")) {
256  fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(new CvSplitKFoldsExpr(dsi, fSplitExprString));
257  }
258 
259  // No need to do it again if the sets have already been split.
260  if (fMakeFoldDataSet) {
261  Log() << kINFO << "Splitting in k-folds has been already done" << Endl;
262  return;
263  }
264 
266 
267  // Get the original event vectors for testing and training from the dataset.
// These are copies: GetEventCollection returns a const reference while
// SplitSets takes a non-const reference.
268  std::vector<Event *> trainData = dsi.GetDataSet()->GetEventCollection(Types::kTraining);
269  std::vector<Event *> testData = dsi.GetDataSet()->GetEventCollection(Types::kTesting);
270 
271  // Split the sets into the number of folds.
272  fTrainEvents = SplitSets(trainData, fNumFolds);
273  fTestEvents = SplitSets(testData, fNumFolds);
274 }
275 
276 ////////////////////////////////////////////////////////////////////////////////
277 /// \brief Generates a vector of fold assignments
278 /// \param[in] nEntries Number of events in range
279 /// \param[in] numFolds Number of folds to split data into
280 /// \param[in] seed     Random seed
/// \return A vector of length `nEntries` whose element i is the fold of event i.
281 ///
282 /// Randomly assigns events to `numFolds` folds. Each fold will hold at most
283 /// `nEntries / numFolds + 1` events.
/// The assignment is built round-robin and then shuffled, so per-fold counts
/// differ by at most one event.
///
/// NOTE(review): numFolds == 0 makes `iEvent % numFolds` undefined; callers
/// must guarantee at least one fold.
/// NOTE(review): source line 297, which declares `rng`, is missing from this
/// listing; presumably it is seeded with `seed` — confirm.
284 ///
285 
286 std::vector<UInt_t> TMVA::CvSplitKFolds::GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed)
287 {
288  // Generate assignment of the pattern `0, 1, 2, 0, 1, 2, 0, 1 ...` for
289  // `numFolds = 3`.
290  std::vector<UInt_t> fOrigToFoldMapping;
291  fOrigToFoldMapping.reserve(nEntries);
292  for (UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
293  fOrigToFoldMapping.push_back(iEvent % numFolds);
294  }
295 
296  // Shuffle assignment
298  std::shuffle(fOrigToFoldMapping.begin(), fOrigToFoldMapping.end(), rng);
299 
300  return fOrigToFoldMapping;
301 }
302 
303 ////////////////////////////////////////////////////////////////////////////////
304 /// \brief Split sets for into k-folds
305 /// \param oldSet[in] Original, unsplit, events
306 /// \param numFolds[in] Number of folds to split data into
307 ///
308 
309 std::vector<std::vector<TMVA::Event *>>
310 TMVA::CvSplitKFolds::SplitSets(std::vector<TMVA::Event *> &oldSet, UInt_t numFolds)
311 {
312  const ULong64_t nEntries = oldSet.size();
313  const ULong64_t foldSize = nEntries / numFolds;
314 
315  std::vector<std::vector<Event *>> tempSets;
316  tempSets.reserve(fNumFolds);
317  for (UInt_t iFold = 0; iFold < numFolds; ++iFold) {
318  tempSets.emplace_back();
319  tempSets.at(iFold).reserve(foldSize);
320  }
321 
322  Bool_t useSplitExpr = not(fSplitExpr == nullptr or fSplitExprString == "");
323 
324  if (useSplitExpr) {
325  // Deterministic split
326  for (ULong64_t i = 0; i < nEntries; i++) {
327  TMVA::Event *ev = oldSet[i];
328  UInt_t iFold = fSplitExpr->Eval(numFolds, ev);
329  tempSets.at((UInt_t)iFold).push_back(ev);
330  }
331  } else {
332  // Random split
333  std::vector<UInt_t> fOrigToFoldMapping = GetEventIndexToFoldMapping(nEntries, numFolds, fSeed);
334 
335  for (UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
336  UInt_t iFold = fOrigToFoldMapping[iEvent];
337  TMVA::Event *ev = oldSet[iEvent];
338  tempSets.at(iFold).push_back(ev);
339 
340  fEventToFoldMapping[ev] = iFold;
341  }
342  }
343 
344  return tempSets;
345 }
std::unique_ptr< CvSplitKFoldsExpr > fSplitExpr
Expression used to split data into folds. Should output values between 0 and numFolds-1.
Definition: CvSplit.h:107
Int_t fIdxFormulaParNumFolds
Maps parameter indices in splitExpr to their spectator index in the datasetinfo. ...
Definition: CvSplit.h:79
CvSplitKFolds(UInt_t numFolds, TString splitExpr="", Bool_t stratified=kTRUE, UInt_t seed=100)
Splits a dataset into k folds, ready for use in cross validation.
Definition: CvSplit.cxx:235
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
static long int sum(long int i)
Definition: Factory.cxx:2258
std::vector< Double_t > fParValues
TFormula for splitExpr.
Definition: CvSplit.h:83
auto * tt
Definition: textangle.C:16
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
std::vector< UInt_t > GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed=100)
Generates a vector of fold assignments.
Definition: CvSplit.cxx:286
TString fSplitExprString
Definition: CvSplit.h:106
std::vector< VariableInfo > & GetSpectatorInfos()
Definition: DataSetInfo.h:104
UInt_t GetSpectatorIndexForName(DataSetInfo &dsi, TString name)
Definition: CvSplit.cxx:198
MsgLogger & Log() const
Definition: Configurable.h:122
void MakeKFoldDataSet(DataSetInfo &dsi) override
Prepares a DataSet for cross validation.
Definition: CvSplit.cxx:250
Basic string class.
Definition: TString.h:131
Int_t GetNpar() const
Definition: TFormula.h:193
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:227
STL namespace.
const TString & GetLabel() const
Definition: VariableInfo.h:59
std::vector< std::pair< Int_t, Int_t > > fFormulaParIdxToDsiSpecIdx
Definition: CvSplit.h:78
const TString & GetExpression() const
Definition: VariableInfo.h:57
virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt)
Set training and test set vectors of dataset described by dsi.
Definition: CvSplit.cxx:57
std::vector< std::vector< TMVA::Event * > > fTestEvents
Definition: CvSplit.h:52
TString fSplitExpr
Keeps track of the index of reserved par "NumFolds" in splitExpr.
Definition: CvSplit.h:80
Class that contains all the data information.
Definition: DataSetInfo.h:60
UInt_t fNumFolds
Definition: CvSplit.h:48
virtual void RecombineKFoldDataSet(DataSetInfo &dsi, Types::ETreeType tt=Types::kTraining)
Definition: CvSplit.cxx:114
Double_t EvalPar(const Double_t *x, const Double_t *params=0) const
Definition: TFormula.cxx:3061
VecExpr< UnaryOp< Fabs< T >, VecExpr< A, T, D >, T >, T, D > fabs(const VecExpr< A, T, D > &rhs)
CvSplitKFoldsExpr(DataSetInfo &dsi, TString expr)
Definition: CvSplit.cxx:139
SVector< double, 2 > v
Definition: Dict.h:5
The Formula class.
Definition: TFormula.h:83
unsigned int UInt_t
Definition: RtypesCore.h:42
Bool_t fMakeFoldDataSet
Definition: CvSplit.h:49
std::vector< std::vector< Event * > > SplitSets(std::vector< TMVA::Event *> &oldSet, UInt_t numFolds)
Split sets into k folds.
Definition: CvSplit.cxx:310
static Bool_t Validate(TString expr)
Definition: CvSplit.cxx:190
const Bool_t kFALSE
Definition: RtypesCore.h:88
CvSplit(UInt_t numFolds)
Definition: CvSplit.cxx:38
#define ClassImp(name)
Definition: Rtypes.h:359
double Double_t
Definition: RtypesCore.h:55
unsigned long long ULong64_t
Definition: RtypesCore.h:70
const char * GetParName(Int_t ipar) const
Return parameter name given by integer.
Definition: TFormula.cxx:2790
std::vector< std::vector< TMVA::Event * > > fTrainEvents
Definition: CvSplit.h:51
void SetEventCollection(std::vector< Event *> *, Types::ETreeType, Bool_t deleteEvents=true)
Sets the event collection (by DataSetFactory)
Definition: DataSet.cxx:250
you should not use this method at all Int_t Int_t Double_t Double_t Double_t e
Definition: TRolke.cxx:630
DataSetInfo & fDsi
Definition: CvSplit.h:75
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
Definition: CvSplit.h:112
TFormula fSplitFormula
Expression used to split data into folds. Should output values between 0 and numFolds-1.
Definition: CvSplit.h:81
UInt_t Eval(UInt_t numFolds, const Event *ev)
Definition: CvSplit.cxx:164
TString()
TString default ctor.
Definition: TString.cxx:87
Class for type info of MVA input variable.
Definition: VariableInfo.h:47
Bool_t IsValid() const
Definition: TFormula.h:204
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition: Event.cxx:262
const Bool_t kTRUE
Definition: RtypesCore.h:87
DataSet * GetDataSet() const
returns data set
char name[80]
Definition: TGX11.cxx:109
const char * Data() const
Definition: TString.h:364