14#ifndef TMVA_RBATCHGENERATOR
15#define TMVA_RBATCHGENERATOR
32namespace Experimental {
35template <
typename... Args>
69 const std::vector<std::string> &cols,
const std::size_t numColumns,
70 const std::vector<std::size_t> &vecSizes = {},
const float vecPadding = 0.0,
71 const float validationSplit = 0.0,
const std::size_t maxChunks = 0,
bool shuffle =
true,
72 bool dropRemainder =
true)
73 :
fRng(std::random_device{}()),
74 fFixedSeed(std::uniform_int_distribution<std::random_device::result_type>{}(
fRng)),
88 std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>{
fChunkSize, numColumns});
98 std::size_t numAllEntries = report.begin()->GetAll();
173 std::size_t entriesForTraining =
188 std::size_t entriesForTraining =
232 for (std::size_t currentChunk = 0, currentEntry = 0;
244 currentEntry += report;
258 std::size_t currentChunk = 0;
259 for (std::size_t processedEvents = 0, currentRow = 0;
270 std::pair<std::size_t, std::size_t> report =
273 currentRow += report.first;
274 processedEvents += report.second;
294 auto &&[trainingIndices, validationIndices] =
createIndices(processedEvents);
297 fBatchLoader->CreateValidationBatches(validationIndices);
302 std::pair<std::vector<std::size_t>, std::vector<std::size_t>>
createIndices(std::size_t events)
305 std::vector<std::size_t> row_order = std::vector<std::size_t>(events);
306 std::iota(row_order.begin(), row_order.end(), 0);
310 std::shuffle(row_order.begin(), row_order.end(),
fFixedRng);
317 std::vector<std::size_t> trainingIndices =
318 std::vector<std::size_t>({row_order.begin(), row_order.end() - num_validation});
319 std::vector<std::size_t> validationIndices =
320 std::vector<std::size_t>({row_order.end() - num_validation, row_order.end()});
323 std::shuffle(trainingIndices.begin(), trainingIndices.end(),
fRng);
326 return std::make_pair(trainingIndices, validationIndices);
The public interface to the RDataFrame federation of classes.
RResultPtr< ULong64_t > Count()
Return the number of entries processed (lazy action).
RResultPtr< RCutFlowReport > Report()
Gather filtering statistics.
std::vector< std::string > GetFilterNames()
Returns the names of the filters created.
std::size_t ValidationRemainderRows()
Return number of validation remainder rows.
std::size_t TrainRemainderRows()
Return number of training remainder rows.
void LoadChunksNoFilters()
Load chunks when no filters are applied on rdataframe.
std::unique_ptr< std::thread > fLoadingThread
std::size_t NumberOfTrainingBatches()
const TMVA::Experimental::RTensor< float > & GetTrainBatch()
Returns the next batch of training data if available.
void Activate()
Activate the loading process by starting the batchloader, and spawning the loading thread.
std::size_t NumberOfValidationBatches()
Calculate number of validation batches and return it.
std::pair< std::vector< std::size_t >, std::vector< std::size_t > > createIndices(std::size_t events)
split the events of the current chunk into training and validation events, shuffle if needed
const TMVA::Experimental::RTensor< float > & GetValidationBatch()
Returns the next batch of validation data if available.
void DeActivate()
De-activate the loading process by deactivating the batchgenerator and joining the loading thread.
std::variant< std::shared_ptr< RChunkLoader< Args... > >, std::shared_ptr< RChunkLoaderFilters< Args... > > > fChunkLoader
std::unique_ptr< RBatchLoader > fBatchLoader
std::mutex fIsActiveMutex
std::random_device::result_type fFixedSeed
void CreateBatches(std::size_t processedEvents)
Create batches.
std::unique_ptr< TMVA::Experimental::RTensor< float > > fChunkTensor
RBatchGenerator(ROOT::RDF::RNode &rdf, const std::size_t chunkSize, const std::size_t batchSize, const std::vector< std::string > &cols, const std::size_t numColumns, const std::vector< std::size_t > &vecSizes={}, const float vecPadding=0.0, const float validationSplit=0.0, const std::size_t maxChunks=0, bool shuffle=true, bool dropRemainder=true)
RTensor is a container with contiguous memory and shape information.
RVec< PromoteType< T > > ceil(const RVec< T > &v)
create variable transformations