Logo ROOT  
Reference Guide
CvSplit.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Kim Albertsson
3
4/*************************************************************************
5 * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12#ifndef ROOT_TMVA_CvSplit
13#define ROOT_TMVA_CvSplit
14
15#include "TMVA/Configurable.h"
16#include "TMVA/Types.h"
17
18#include <Rtypes.h>
19#include <TFormula.h>
20
21#include <memory>
22#include <vector>
23#include <map>
24
25class TString;
26
27namespace TMVA {
28
29class CrossValidation;
30class DataSetInfo;
31class Event;
32
33/* =============================================================================
34 TMVA::CvSplit
35============================================================================= */
36
37class CvSplit : public Configurable {
38public:
39 CvSplit(UInt_t numFolds);
40 virtual ~CvSplit() {}
41
42 virtual void MakeKFoldDataSet(DataSetInfo &dsi) = 0;
43 virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt);
45
48
49protected:
52
53 std::vector<std::vector<TMVA::Event *>> fTrainEvents;
54 std::vector<std::vector<TMVA::Event *>> fTestEvents;
55
56protected:
57 ClassDef(CvSplit, 0);
58};
59
60/* =============================================================================
61 TMVA::CvSplitKFoldsExpr
62============================================================================= */
63
65public:
68
69 UInt_t Eval(UInt_t numFolds, const Event *ev);
70
71 static Bool_t Validate(TString expr);
72
73private:
75
76private:
78
79 std::vector<std::pair<Int_t, Int_t>>
80 fFormulaParIdxToDsiSpecIdx; //! Maps parameter indicies in splitExpr to their spectator index in the datasetinfo.
81 Int_t fIdxFormulaParNumFolds; //! Keeps track of the index of reserved par "NumFolds" in splitExpr.
82 TString fSplitExpr; //! Expression used to split data into folds. Should output values between 0 and numFolds.
83 TFormula fSplitFormula; //! TFormula for splitExpr.
84
85 std::vector<Double_t> fParValues;
86};
87
88/* =============================================================================
89 TMVA::CvSplitKFolds
90============================================================================= */
91
92class CvSplitKFolds : public CvSplit {
93
95
96public:
97 CvSplitKFolds(UInt_t numFolds, TString splitExpr = "", Bool_t stratified = kTRUE, UInt_t seed = 100);
98 ~CvSplitKFolds() override {}
99
100 void MakeKFoldDataSet(DataSetInfo &dsi) override;
101
102private:
103 std::vector<std::vector<Event *>> SplitSets(std::vector<TMVA::Event *> &oldSet, UInt_t numFolds, UInt_t numClasses);
104 std::vector<UInt_t> GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed = 100);
105
106private:
108 TString fSplitExprString; //! Expression used to split data into folds. Should output values between 0 and numFolds.
109 std::unique_ptr<CvSplitKFoldsExpr> fSplitExpr;
110 Bool_t fStratified; // If true, use stratified split. (Balance class presence in each fold).
111
112 // Used for CrossValidation with random splits (not using the
113 // CVSplitKFoldsExpr functionality) to communicate Event to fold mapping.
114 std::map<const TMVA::Event *, UInt_t> fEventToFoldMapping;
115
116private:
118};
119
120} // end namespace TMVA
121
122#endif
int Int_t
Definition: RtypesCore.h:45
unsigned int UInt_t
Definition: RtypesCore.h:46
bool Bool_t
Definition: RtypesCore.h:63
const Bool_t kTRUE
Definition: RtypesCore.h:100
#define ClassDef(name, id)
Definition: Rtypes.h:325
char name[80]
Definition: TGX11.cxx:110
The Formula class.
Definition: TFormula.h:87
Int_t fIdxFormulaParNumFolds
Keeps track of the index of reserved par "NumFolds" in splitExpr.
Definition: CvSplit.h:81
UInt_t Eval(UInt_t numFolds, const Event *ev)
Definition: CvSplit.cxx:164
std::vector< std::pair< Int_t, Int_t > > fFormulaParIdxToDsiSpecIdx
Maps parameter indices in splitExpr to their spectator index in the datasetinfo.
Definition: CvSplit.h:80
UInt_t GetSpectatorIndexForName(DataSetInfo &dsi, TString name)
Definition: CvSplit.cxx:206
static Bool_t Validate(TString expr)
Definition: CvSplit.cxx:198
std::vector< Double_t > fParValues
Definition: CvSplit.h:85
CvSplitKFoldsExpr(DataSetInfo &dsi, TString expr)
Definition: CvSplit.cxx:139
TFormula fSplitFormula
TFormula for splitExpr.
Definition: CvSplit.h:83
DataSetInfo & fDsi
Definition: CvSplit.h:77
TString fSplitExpr
Expression used to split data into folds. Should output values between 0 and numFolds.
Definition: CvSplit.h:82
friend CrossValidation
Definition: CvSplit.h:94
Bool_t fStratified
If true, use stratified split. (Balance class presence in each fold).
Definition: CvSplit.h:110
std::vector< UInt_t > GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed=100)
Generates a vector of fold assignments.
Definition: CvSplit.cxx:293
std::unique_ptr< CvSplitKFoldsExpr > fSplitExpr
Definition: CvSplit.h:109
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
Definition: CvSplit.h:114
void MakeKFoldDataSet(DataSetInfo &dsi) override
Prepares a DataSet for cross validation.
Definition: CvSplit.cxx:255
std::vector< std::vector< Event * > > SplitSets(std::vector< TMVA::Event * > &oldSet, UInt_t numFolds, UInt_t numClasses)
Split sets into k folds.
Definition: CvSplit.cxx:319
~CvSplitKFolds() override
Definition: CvSplit.h:98
TString fSplitExprString
Expression used to split data into folds. Should output values between 0 and numFolds.
Definition: CvSplit.h:108
ClassDefOverride(CvSplitKFolds, 0)
CvSplitKFolds(UInt_t numFolds, TString splitExpr="", Bool_t stratified=kTRUE, UInt_t seed=100)
Splits a dataset into k folds, ready for use in cross validation.
Definition: CvSplit.cxx:243
UInt_t fNumFolds
Definition: CvSplit.h:50
Bool_t NeedsRebuild()
Definition: CvSplit.h:47
virtual ~CvSplit()
Definition: CvSplit.h:40
virtual void RecombineKFoldDataSet(DataSetInfo &dsi, Types::ETreeType tt=Types::kTraining)
Definition: CvSplit.cxx:114
UInt_t GetNumFolds()
Definition: CvSplit.h:46
std::vector< std::vector< TMVA::Event * > > fTestEvents
Definition: CvSplit.h:54
virtual void MakeKFoldDataSet(DataSetInfo &dsi)=0
std::vector< std::vector< TMVA::Event * > > fTrainEvents
Definition: CvSplit.h:53
virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt)
Set training and test set vectors of dataset described by dsi.
Definition: CvSplit.cxx:57
CvSplit(UInt_t numFolds)
Definition: CvSplit.cxx:38
Bool_t fMakeFoldDataSet
Definition: CvSplit.h:51
Class that contains all the data information.
Definition: DataSetInfo.h:62
@ kTraining
Definition: Types.h:143
Basic string class.
Definition: TString.h:136
create variable transformations
auto * tt
Definition: textangle.C:16