Logo ROOT   6.14/05
Reference Guide
CvSplit.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Kim Albertsson
3 
4 /*************************************************************************
5  * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 #ifndef ROOT_TMVA_CvSplit
13 #define ROOT_TMVA_CvSplit
14 
15 #include "TMVA/Configurable.h"
16 #include "TMVA/Types.h"
17 
18 #include <Rtypes.h>
19 #include <TFormula.h>
20 
21 #include <memory>
22 
23 class TString;
24 
25 namespace TMVA {
26 
27 class CrossValidation;
28 class DataSetInfo;
29 class Event;
30 
31 /* =============================================================================
32  TMVA::CvSplit
33 ============================================================================= */
34 
35 class CvSplit : public Configurable {
36 public:
37  CvSplit(UInt_t numFolds);
38  virtual ~CvSplit() {}
39 
40  virtual void MakeKFoldDataSet(DataSetInfo &dsi) = 0;
41  virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt);
43 
46 
47 protected:
50 
51  std::vector<std::vector<TMVA::Event *>> fTrainEvents;
52  std::vector<std::vector<TMVA::Event *>> fTestEvents;
53 
54 protected:
55  ClassDef(CvSplit, 0);
56 };
57 
58 /* =============================================================================
59  TMVA::CvSplitKFoldsExpr
60 ============================================================================= */
61 
63 public:
66 
67  UInt_t Eval(UInt_t numFolds, const Event *ev);
68 
69  static Bool_t Validate(TString expr);
70 
71 private:
72  UInt_t GetSpectatorIndexForName(DataSetInfo &dsi, TString name);
73 
74 private:
76 
77  std::vector<std::pair<Int_t, Int_t>>
78  fFormulaParIdxToDsiSpecIdx; //! Maps parameter indicies in splitExpr to their spectator index in the datasetinfo.
79  Int_t fIdxFormulaParNumFolds; //! Keeps track of the index of reserved par "NumFolds" in splitExpr.
80  TString fSplitExpr; //! Expression used to split data into folds. Should output values between 0 and numFolds.
81  TFormula fSplitFormula; //! TFormula for splitExpr.
82 
83  std::vector<Double_t> fParValues;
84 };
85 
86 /* =============================================================================
87  TMVA::CvSplitKFolds
88 ============================================================================= */
89 
90 class CvSplitKFolds : public CvSplit {
91 
93 
94 public:
95  CvSplitKFolds(UInt_t numFolds, TString splitExpr = "", Bool_t stratified = kTRUE, UInt_t seed = 100);
96  ~CvSplitKFolds() override {}
97 
98  void MakeKFoldDataSet(DataSetInfo &dsi) override;
99 
100 private:
101  std::vector<std::vector<Event *>> SplitSets(std::vector<TMVA::Event *> &oldSet, UInt_t numFolds);
102  std::vector<UInt_t> GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed = 100);
103 
104 private:
106  TString fSplitExprString; //! Expression used to split data into folds. Should output values between 0 and numFolds.
107  std::unique_ptr<CvSplitKFoldsExpr> fSplitExpr;
108  Bool_t fStratified; // If true, use stratified split. (Balance class presence in each fold).
109 
110  // Used for CrossValidation with random splits (not using the
111  // CVSplitKFoldsExpr functionality) to communicate Event to fold mapping.
112  std::map<const TMVA::Event *, UInt_t> fEventToFoldMapping;
113 
114 private:
116 };
117 
118 } // end namespace TMVA
119 
120 #endif
std::unique_ptr< CvSplitKFoldsExpr > fSplitExpr
Owning pointer to the CvSplitKFoldsExpr object used to evaluate the fold-split expression.
Definition: CvSplit.h:107
Int_t fIdxFormulaParNumFolds
Keeps track of the index of the reserved parameter "NumFolds" in splitExpr.
Definition: CvSplit.h:79
std::vector< Double_t > fParValues
No description in the header; presumably the parameter values passed to fSplitFormula when it is evaluated.
Definition: CvSplit.h:83
auto * tt
Definition: textangle.C:16
TString fSplitExprString
Expression used to split data into folds. Should output values between 0 and numFolds.
Definition: CvSplit.h:106
Basic string class.
Definition: TString.h:131
Bool_t NeedsRebuild()
Definition: CvSplit.h:45
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
UInt_t GetNumFolds()
Definition: CvSplit.h:44
std::vector< std::pair< Int_t, Int_t > > fFormulaParIdxToDsiSpecIdx
Maps parameter indices in splitExpr to their spectator index in the DataSetInfo.
Definition: CvSplit.h:78
virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt)
Set training and test set vectors of dataset described by dsi.
Definition: CvSplit.cxx:57
#define ClassDef(name, id)
Definition: Rtypes.h:320
std::vector< std::vector< TMVA::Event * > > fTestEvents
Definition: CvSplit.h:52
TString fSplitExpr
Expression used to split data into folds. Should output values between 0 and numFolds.
Definition: CvSplit.h:80
Class that contains all the data information.
Definition: DataSetInfo.h:60
UInt_t fNumFolds
Definition: CvSplit.h:48
friend CrossValidation
Definition: CvSplit.h:92
virtual void RecombineKFoldDataSet(DataSetInfo &dsi, Types::ETreeType tt=Types::kTraining)
Definition: CvSplit.cxx:114
virtual void MakeKFoldDataSet(DataSetInfo &dsi)=0
The Formula class.
Definition: TFormula.h:83
unsigned int UInt_t
Definition: RtypesCore.h:42
Bool_t fMakeFoldDataSet
Definition: CvSplit.h:49
Bool_t fStratified
Definition: CvSplit.h:108
CvSplit(UInt_t numFolds)
Definition: CvSplit.cxx:38
std::vector< std::vector< TMVA::Event * > > fTrainEvents
Definition: CvSplit.h:51
DataSetInfo & fDsi
Definition: CvSplit.h:75
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
Definition: CvSplit.h:112
TFormula fSplitFormula
TFormula for splitExpr.
Definition: CvSplit.h:81
Abstract ClassifierFactory template that handles arbitrary types.
virtual ~CvSplit()
Definition: CvSplit.h:38
~CvSplitKFolds() override
Definition: CvSplit.h:96
#define ClassDefOverride(name, id)
Definition: Rtypes.h:324
const Bool_t kTRUE
Definition: RtypesCore.h:87
char name[80]
Definition: TGX11.cxx:109