Logo ROOT   6.16/01
Reference Guide
CvSplit.h
Go to the documentation of this file.
// @(#)root/tmva $Id$
// Author: Kim Albertsson

/*************************************************************************
 * Copyright (C) 2018, Rene Brun and Fons Rademakers.                    *
 * All rights reserved.                                                  *
 *                                                                       *
 * For the licensing terms see $ROOTSYS/LICENSE.                         *
 * For the list of contributors see $ROOTSYS/README/CREDITS.             *
 *************************************************************************/
11
12#ifndef ROOT_TMVA_CvSplit
13#define ROOT_TMVA_CvSplit
14
#include "TMVA/Configurable.h"
#include "TMVA/Types.h"

#include <Rtypes.h>
#include <TFormula.h>

#include <map>
#include <memory>
#include <utility>
#include <vector>
22
23class TString;
24
25namespace TMVA {
26
27class CrossValidation;
28class DataSetInfo;
29class Event;

/* =============================================================================
      TMVA::CvSplit
============================================================================= */

35class CvSplit : public Configurable {
36public:
37 CvSplit(UInt_t numFolds);
38 virtual ~CvSplit() {}
39
40 virtual void MakeKFoldDataSet(DataSetInfo &dsi) = 0;
41 virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt);
43
46
47protected:
50
51 std::vector<std::vector<TMVA::Event *>> fTrainEvents;
52 std::vector<std::vector<TMVA::Event *>> fTestEvents;
53
54protected:
55 ClassDef(CvSplit, 0);
56};

/* =============================================================================
      TMVA::CvSplitKFoldsExpr
============================================================================= */

63public:
66
67 UInt_t Eval(UInt_t numFolds, const Event *ev);
68
69 static Bool_t Validate(TString expr);
70
71private:
73
74private:
76
77 std::vector<std::pair<Int_t, Int_t>>
78 fFormulaParIdxToDsiSpecIdx; //! Maps parameter indicies in splitExpr to their spectator index in the datasetinfo.
79 Int_t fIdxFormulaParNumFolds; //! Keeps track of the index of reserved par "NumFolds" in splitExpr.
80 TString fSplitExpr; //! Expression used to split data into folds. Should output values between 0 and numFolds.
81 TFormula fSplitFormula; //! TFormula for splitExpr.
82
83 std::vector<Double_t> fParValues;
84};

/* =============================================================================
      TMVA::CvSplitKFolds
============================================================================= */

90class CvSplitKFolds : public CvSplit {
91
93
94public:
95 CvSplitKFolds(UInt_t numFolds, TString splitExpr = "", Bool_t stratified = kTRUE, UInt_t seed = 100);
96 ~CvSplitKFolds() override {}
97
98 void MakeKFoldDataSet(DataSetInfo &dsi) override;
99
100private:
101 std::vector<std::vector<Event *>> SplitSets(std::vector<TMVA::Event *> &oldSet, UInt_t numFolds, UInt_t numClasses);
102 std::vector<UInt_t> GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed = 100);
103
104private:
106 TString fSplitExprString; //! Expression used to split data into folds. Should output values between 0 and numFolds.
107 std::unique_ptr<CvSplitKFoldsExpr> fSplitExpr;
108 Bool_t fStratified; // If true, use stratified split. (Balance class presence in each fold).
109
110 // Used for CrossValidation with random splits (not using the
111 // CVSplitKFoldsExpr functionality) to communicate Event to fold mapping.
112 std::map<const TMVA::Event *, UInt_t> fEventToFoldMapping;
113
114private:
116};
117
118} // end namespace TMVA
119
120#endif
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define ClassDef(name, id)
Definition: Rtypes.h:324
The Formula class.
Definition: TFormula.h:84
Int_t fIdxFormulaParNumFolds
Keeps track of the index of reserved par "NumFolds" in splitExpr.
Definition: CvSplit.h:79
UInt_t Eval(UInt_t numFolds, const Event *ev)
Definition: CvSplit.cxx:164
std::vector< std::pair< Int_t, Int_t > > fFormulaParIdxToDsiSpecIdx
Maps parameter indices in splitExpr to their spectator index in the datasetinfo.
Definition: CvSplit.h:78
UInt_t GetSpectatorIndexForName(DataSetInfo &dsi, TString name)
Definition: CvSplit.cxx:206
static Bool_t Validate(TString expr)
Definition: CvSplit.cxx:198
std::vector< Double_t > fParValues
Definition: CvSplit.h:83
CvSplitKFoldsExpr(DataSetInfo &dsi, TString expr)
Definition: CvSplit.cxx:139
TFormula fSplitFormula
TFormula for splitExpr.
Definition: CvSplit.h:81
DataSetInfo & fDsi
Definition: CvSplit.h:75
TString fSplitExpr
Expression used to split data into folds. Should output values between 0 and numFolds.
Definition: CvSplit.h:80
friend CrossValidation
Definition: CvSplit.h:92
Bool_t fStratified
Definition: CvSplit.h:108
std::vector< UInt_t > GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed=100)
Generates a vector of fold assignments.
Definition: CvSplit.cxx:293
std::unique_ptr< CvSplitKFoldsExpr > fSplitExpr
Definition: CvSplit.h:107
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
Definition: CvSplit.h:112
void MakeKFoldDataSet(DataSetInfo &dsi) override
Prepares a DataSet for cross validation.
Definition: CvSplit.cxx:255
std::vector< std::vector< Event * > > SplitSets(std::vector< TMVA::Event * > &oldSet, UInt_t numFolds, UInt_t numClasses)
Split sets for into k-folds.
Definition: CvSplit.cxx:319
~CvSplitKFolds() override
Definition: CvSplit.h:96
TString fSplitExprString
Expression used to split data into folds. Should output values between 0 and numFolds.
Definition: CvSplit.h:106
ClassDefOverride(CvSplitKFolds, 0)
CvSplitKFolds(UInt_t numFolds, TString splitExpr="", Bool_t stratified=kTRUE, UInt_t seed=100)
Splits a dataset into k folds, ready for use in cross validation.
Definition: CvSplit.cxx:243
UInt_t fNumFolds
Definition: CvSplit.h:48
Bool_t NeedsRebuild()
Definition: CvSplit.h:45
virtual ~CvSplit()
Definition: CvSplit.h:38
virtual void RecombineKFoldDataSet(DataSetInfo &dsi, Types::ETreeType tt=Types::kTraining)
Definition: CvSplit.cxx:114
UInt_t GetNumFolds()
Definition: CvSplit.h:44
std::vector< std::vector< TMVA::Event * > > fTestEvents
Definition: CvSplit.h:52
virtual void MakeKFoldDataSet(DataSetInfo &dsi)=0
std::vector< std::vector< TMVA::Event * > > fTrainEvents
Definition: CvSplit.h:51
virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt)
Set training and test set vectors of dataset described by dsi.
Definition: CvSplit.cxx:57
CvSplit(UInt_t numFolds)
Definition: CvSplit.cxx:38
Bool_t fMakeFoldDataSet
Definition: CvSplit.h:49
Class that contains all the data information.
Definition: DataSetInfo.h:60
@ kTraining
Definition: Types.h:144
Basic string class.
Definition: TString.h:131
Abstract ClassifierFactory template that handles arbitrary types.
auto * tt
Definition: textangle.C:16