59   if (foldNumber >= fNumFolds) {
 
   60      Log() << kFATAL << 
"DataSet prepared for \"" << fNumFolds << 
"\" folds, requested fold \"" << foldNumber
 
   61            << 
"\" is outside of range." << 
Endl;
 
   65   auto prepareDataSetInternal = [
this, &dsi, foldNumber](std::vector<std::vector<Event *>> 
vec) {
 
   66      UInt_t numFolds = fTrainEvents.size();
 
   69      UInt_t nTotal = std::accumulate(
vec.begin(), 
vec.end(), 0,
 
   70                                      [&](
UInt_t sum, std::vector<TMVA::Event *> 
v) { return sum + v.size(); });
 
   72      UInt_t nTrain = nTotal - 
vec.at(foldNumber).size();
 
   75      std::vector<Event *> tempTrain;
 
   76      std::vector<Event *> tempTest;
 
   78      tempTrain.reserve(nTrain);
 
   79      tempTest.reserve(nTest);
 
   82      for (
UInt_t i = 0; i < numFolds; ++i) {
 
   83         if (i == foldNumber) {
 
   87         tempTrain.insert(tempTrain.end(), 
vec.at(i).begin(), 
vec.at(i).end());
 
   91      tempTest.insert(tempTest.end(), 
vec.at(foldNumber).begin(), 
vec.at(foldNumber).end());
 
   93      Log() << kDEBUG << 
"Fold prepared, num events in training set: " << tempTrain.size() << 
Endl;
 
   94      Log() << kDEBUG << 
"Fold prepared, num events in test     set: " << tempTest.size() << 
Endl;
 
  102      prepareDataSetInternal(fTrainEvents);
 
  104      prepareDataSetInternal(fTestEvents);
 
  106      Log() << kFATAL << 
"PrepareFoldDataSet can only work with training and testing data sets." << std::endl;
 
  117      Log() << kFATAL << 
"Only kTraining is supported for CvSplit::RecombineKFoldDataSet currently." << std::endl;
 
  120   std::vector<Event *> *tempVec = 
new std::vector<Event *>;
 
  122   for (
UInt_t i = 0; i < fNumFolds; ++i) {
 
  123      tempVec->insert(tempVec->end(), fTrainEvents.at(i).begin(), fTrainEvents.at(i).end());
 
  140   : fDsi(dsi), fIdxFormulaParNumFolds(std::numeric_limits<
Int_t>::max()), fSplitFormula(
"", expr),
 
  141     fParValues(fSplitFormula.GetNpar())
 
  144      throw std::runtime_error(
"Split expression \"" + std::string(
fSplitExpr.
Data()) + 
"\" is not a valid TFormula.");
 
  152      if (
name == 
"NumFolds" || 
name == 
"numFolds") {
 
  166   for (
auto &
p : fFormulaParIdxToDsiSpecIdx) {
 
  167      auto iFormulaPar = 
p.first;
 
  168      auto iSpectator = 
p.second;
 
  170      fParValues.at(iFormulaPar) = ev->
GetSpectator(iSpectator);
 
  173   if (fIdxFormulaParNumFolds < fSplitFormula.GetNpar()) {
 
  174      fParValues[fIdxFormulaParNumFolds] = numFolds;
 
  180   Double_t iFold_d = fSplitFormula.EvalPar(
nullptr, &fParValues[0]);
 
  183      throw std::runtime_error(
"Output of splitExpr must be non-negative.");
 
  186   UInt_t iFold = std::lround(iFold_d);
 
  187   if (iFold >= numFolds) {
 
  188      throw std::runtime_error(
"Output of splitExpr should be a non-negative" 
  189                               "integer between 0 and numFolds-1 inclusive.");
 
  210   for (
UInt_t iSpectator = 0; iSpectator < spectatorInfos.size(); ++iSpectator) {
 
  221   throw std::runtime_error(
"Spectator \"" + std::string(
name.Data()) + 
"\" not found.");
 
  244   : 
CvSplit(numFolds), fSeed(seed), fSplitExprString(splitExpr), fStratified(stratified)
 
  260   if (fSplitExprString != 
TString(
"")) {
 
  261      fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(
new CvSplitKFoldsExpr(dsi, fSplitExprString));
 
  265   if (fMakeFoldDataSet) {
 
  266      Log() << kINFO << 
"Splitting in k-folds has been already done" << 
Endl;
 
  270   fMakeFoldDataSet = 
kTRUE;
 
  279   fTrainEvents = SplitSets(trainData, fNumFolds, numClasses);
 
  280   fTestEvents = SplitSets(testData, fNumFolds, numClasses);
 
  297   std::vector<UInt_t> fOrigToFoldMapping;
 
  298   fOrigToFoldMapping.reserve(nEntries);
 
  300   for (
UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
 
  301      fOrigToFoldMapping.push_back(iEvent % numFolds);
 
  306   std::shuffle(fOrigToFoldMapping.begin(), fOrigToFoldMapping.end(), rng);
 
  308   return fOrigToFoldMapping;
 
  319std::vector<std::vector<TMVA::Event *>>
 
  322   const ULong64_t nEntries = oldSet.size();
 
  323   const ULong64_t foldSize = nEntries / numFolds;
 
  325   std::vector<std::vector<Event *>> tempSets;
 
  326   tempSets.reserve(fNumFolds);
 
  327   for (
UInt_t iFold = 0; iFold < numFolds; ++iFold) {
 
  328      tempSets.emplace_back();
 
  329      tempSets.at(iFold).reserve(foldSize);
 
  332   Bool_t useSplitExpr = !(fSplitExpr == 
nullptr || fSplitExprString == 
"");
 
  336      for (
ULong64_t i = 0; i < nEntries; i++) {
 
  338         UInt_t iFold = fSplitExpr->Eval(numFolds, ev);
 
  339         tempSets.at((
UInt_t)iFold).push_back(ev);
 
  344         std::vector<UInt_t> fOrigToFoldMapping;
 
  345         fOrigToFoldMapping = GetEventIndexToFoldMapping(nEntries, numFolds, fSeed);
 
  347         for (
UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
 
  348            UInt_t iFold = fOrigToFoldMapping[iEvent];
 
  350            tempSets.at(iFold).push_back(ev);
 
  352            fEventToFoldMapping[ev] = iFold;
 
  356         std::vector<std::vector<TMVA::Event *>> oldSets;
 
  357         oldSets.reserve(numClasses);
 
  359         for(
UInt_t iClass = 0; iClass < numClasses; iClass++){
 
  360            oldSets.emplace_back();
 
  362            oldSets.reserve(nEntries);
 
  365         for(
UInt_t iEvent = 0; iEvent < nEntries; ++iEvent){
 
  369            oldSets.at(iClass).push_back(ev);
 
  372         for(
UInt_t i = 0; i<numClasses; ++i){
 
  375            std::shuffle(oldSets.at(i).begin(), oldSets.at(i).end(), rng);
 
  378         for(
UInt_t i = 0; i<numClasses; ++i) {
 
  379            std::vector<UInt_t> fOrigToFoldMapping;
 
  380            fOrigToFoldMapping = GetEventIndexToFoldMapping(oldSets.at(i).size(), numFolds, fSeed);
 
  382            for (
UInt_t iEvent = 0; iEvent < oldSets.at(i).
size(); ++iEvent) {
 
  383               UInt_t iFold = fOrigToFoldMapping[iEvent];
 
  385               tempSets.at(iFold).push_back(ev);
 
  386               fEventToFoldMapping[ev] = iFold;
 
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
 
unsigned long long ULong64_t
 
winID h TVirtualViewer3D TVirtualGLPainter p
 
Int_t fIdxFormulaParNumFolds
Maps parameter indices in splitExpr to their spectator index in the DataSetInfo.
 
UInt_t Eval(UInt_t numFolds, const Event *ev)
 
std::vector< std::pair< Int_t, Int_t > > fFormulaParIdxToDsiSpecIdx
 
UInt_t GetSpectatorIndexForName(DataSetInfo &dsi, TString name)
 
static Bool_t Validate(TString expr)
 
CvSplitKFoldsExpr(DataSetInfo &dsi, TString expr)
 
TFormula fSplitFormula
Expression used to split data into folds. Should output integer values between 0 and numFolds-1 inclusive.
 
TString fSplitExpr
Keeps track of the index of reserved par "NumFolds" in splitExpr.
 
std::vector< UInt_t > GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed=100)
Generates a vector of fold assignments.
 
void MakeKFoldDataSet(DataSetInfo &dsi) override
Prepares a DataSet for cross validation.
 
std::vector< std::vector< Event * > > SplitSets(std::vector< TMVA::Event * > &oldSet, UInt_t numFolds, UInt_t numClasses)
Split sets into k-folds.
 
TString fSplitExprString
Expression used to split data into folds. Should output integer values between 0 and numFolds-1 inclusive.
 
CvSplitKFolds(UInt_t numFolds, TString splitExpr="", Bool_t stratified=kTRUE, UInt_t seed=100)
Splits a dataset into k folds, ready for use in cross validation.
 
virtual void RecombineKFoldDataSet(DataSetInfo &dsi, Types::ETreeType tt=Types::kTraining)
 
virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt)
Sets the training and test set vectors of the dataset described by dsi.
 
Class that contains all the data information.
 
std::vector< VariableInfo > & GetSpectatorInfos()
 
UInt_t GetNClasses() const
 
DataSet * GetDataSet() const
returns data set
 
void SetEventCollection(std::vector< Event * > *, Types::ETreeType, Bool_t deleteEvents=true)
Sets the event collection (by DataSetFactory)
 
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
 
Float_t GetSpectator(UInt_t ivar) const
return spectator content
 
Class for type info of MVA input variable.
 
const TString & GetLabel() const
 
const TString & GetExpression() const
 
const char * GetName() const override
Returns name of object.
 
const char * Data() const
 
MsgLogger & Endl(MsgLogger &ml)
 
static uint64_t sum(uint64_t i)