58 Log() << kFATAL <<
"DataSet prepared for \"" <<
fNumFolds <<
"\" folds, requested fold \"" << foldNumber
59 <<
"\" is outside of range." <<
Endl;
63 auto prepareDataSetInternal = [
this, &dsi, foldNumber](std::vector<std::vector<Event *>>
vec) {
67 UInt_t nTotal = std::accumulate(
vec.begin(),
vec.end(), 0,
68 [&](
UInt_t sum, std::vector<TMVA::Event *>
v) { return sum + v.size(); });
70 UInt_t nTrain = nTotal -
vec.at(foldNumber).size();
73 std::vector<Event *> tempTrain;
74 std::vector<Event *> tempTest;
76 tempTrain.reserve(nTrain);
77 tempTest.reserve(nTest);
80 for (
UInt_t i = 0; i < numFolds; ++i) {
81 if (i == foldNumber) {
85 tempTrain.insert(tempTrain.end(),
vec.at(i).begin(),
vec.at(i).end());
89 tempTest.insert(tempTest.end(),
vec.at(foldNumber).begin(),
vec.at(foldNumber).end());
91 Log() << kDEBUG <<
"Fold prepared, num events in training set: " << tempTrain.size() <<
Endl;
92 Log() << kDEBUG <<
"Fold prepared, num events in test set: " << tempTest.size() <<
Endl;
104 Log() << kFATAL <<
"PrepareFoldDataSet can only work with training and testing data sets." << std::endl;
115 Log() << kFATAL <<
"Only kTraining is supported for CvSplit::RecombineKFoldDataSet currently." << std::endl;
118 std::vector<Event *> *tempVec =
new std::vector<Event *>;
142 throw std::runtime_error(
"Split expression \"" + std::string(fSplitExpr.Data()) +
"\" is not a valid TFormula.");
150 if (
name ==
"NumFolds" ||
name ==
"numFolds") {
165 auto iFormulaPar = p.first;
166 auto iSpectator = p.second;
181 throw std::runtime_error(
"Output of splitExpr must be non-negative.");
184 UInt_t iFold = std::lround(iFold_d);
185 if (iFold >= numFolds) {
186 throw std::runtime_error(
"Output of splitExpr should be a non-negative"
187 "integer between 0 and numFolds-1 inclusive.");
208 for (
UInt_t iSpectator = 0; iSpectator < spectatorInfos.size(); ++iSpectator) {
219 throw std::runtime_error(
"Spectator \"" + std::string(
name.Data()) +
"\" not found.");
264 Log() << kINFO <<
"Splitting in k-folds has been already done" <<
Endl;
295 std::vector<UInt_t> fOrigToFoldMapping;
296 fOrigToFoldMapping.reserve(nEntries);
298 for (
UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
299 fOrigToFoldMapping.push_back(iEvent % numFolds);
304 std::shuffle(fOrigToFoldMapping.begin(), fOrigToFoldMapping.end(), rng);
306 return fOrigToFoldMapping;
317std::vector<std::vector<TMVA::Event *>>
320 const ULong64_t nEntries = oldSet.size();
321 const ULong64_t foldSize = nEntries / numFolds;
323 std::vector<std::vector<Event *>> tempSets;
325 for (
UInt_t iFold = 0; iFold < numFolds; ++iFold) {
326 tempSets.emplace_back();
327 tempSets.at(iFold).reserve(foldSize);
334 for (
ULong64_t i = 0; i < nEntries; i++) {
337 tempSets.at((
UInt_t)iFold).push_back(ev);
342 std::vector<UInt_t> fOrigToFoldMapping;
345 for (
UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
346 UInt_t iFold = fOrigToFoldMapping[iEvent];
348 tempSets.at(iFold).push_back(ev);
354 std::vector<std::vector<TMVA::Event *>> oldSets;
355 oldSets.reserve(numClasses);
357 for(
UInt_t iClass = 0; iClass < numClasses; iClass++){
358 oldSets.emplace_back();
360 oldSets.reserve(nEntries);
363 for(
UInt_t iEvent = 0; iEvent < nEntries; ++iEvent){
367 oldSets.at(iClass).push_back(ev);
370 for(
UInt_t i = 0; i<numClasses; ++i){
373 std::shuffle(oldSets.at(i).begin(), oldSets.at(i).end(), rng);
376 for(
UInt_t i = 0; i<numClasses; ++i) {
377 std::vector<UInt_t> fOrigToFoldMapping;
380 for (
UInt_t iEvent = 0; iEvent < oldSets.at(i).
size(); ++iEvent) {
381 UInt_t iFold = fOrigToFoldMapping[iEvent];
383 tempSets.at(iFold).push_back(ev);
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
int Int_t
Signed integer 4 bytes (int).
unsigned int UInt_t
Unsigned integer 4 bytes (unsigned int).
bool Bool_t
Boolean (0=false, 1=true) (bool).
double Double_t
Double 8 bytes.
unsigned long long ULong64_t
Portable unsigned long integer 8 bytes.
Int_t fIdxFormulaParNumFolds
! Keeps track of the index of reserved par "NumFolds" in splitExpr.
UInt_t Eval(UInt_t numFolds, const Event *ev)
std::vector< std::pair< Int_t, Int_t > > fFormulaParIdxToDsiSpecIdx
! Maps parameter indices in splitExpr to their spectator index in the datasetinfo.
UInt_t GetSpectatorIndexForName(DataSetInfo &dsi, TString name)
static Bool_t Validate(TString expr)
std::vector< Double_t > fParValues
CvSplitKFoldsExpr(DataSetInfo &dsi, TString expr)
TFormula fSplitFormula
! TFormula for splitExpr.
Bool_t fStratified
If true, use stratified split. (Balance class presence in each fold).
std::vector< UInt_t > GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed=100)
Generates a vector of fold assignments.
std::unique_ptr< CvSplitKFoldsExpr > fSplitExpr
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
void MakeKFoldDataSet(DataSetInfo &dsi) override
Prepares a DataSet for cross validation.
std::vector< std::vector< Event * > > SplitSets(std::vector< TMVA::Event * > &oldSet, UInt_t numFolds, UInt_t numClasses)
Split sets into k-folds.
TString fSplitExprString
! Expression used to split data into folds. Should output integer values between 0 and numFolds-1 inclusive.
CvSplitKFolds(UInt_t numFolds, TString splitExpr="", Bool_t stratified=kTRUE, UInt_t seed=100)
Splits a dataset into k folds, ready for use in cross validation.
virtual void RecombineKFoldDataSet(DataSetInfo &dsi, Types::ETreeType tt=Types::kTraining)
std::vector< std::vector< TMVA::Event * > > fTestEvents
std::vector< std::vector< TMVA::Event * > > fTrainEvents
virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt)
Set training and test set vectors of dataset described by dsi.
Class that contains all the data information.
std::vector< VariableInfo > & GetSpectatorInfos()
UInt_t GetNClasses() const
DataSet * GetDataSet() const
returns data set
void SetEventCollection(std::vector< Event * > *, Types::ETreeType, Bool_t deleteEvents=true)
Sets the event collection (by DataSetFactory).
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Class for type info of MVA input variable.
const TString & GetLabel() const
const TString & GetExpression() const
const char * GetName() const override
Returns name of object.
MsgLogger & Endl(MsgLogger &ml)
static uint64_t sum(uint64_t i)