#ifndef TMVA_BATCHGENERATOR
#define TMVA_BATCHGENERATOR

namespace Experimental {

template <typename... Args>

std::unique_ptr<TMVA::Experimental::Internal::RBatchLoader> fBatchLoader;
RBatchGenerator(const std::string &treeName, const std::string &fileName, const std::size_t chunkSize,
                const std::size_t batchSize, const std::vector<std::string> &cols,
                const std::vector<std::string> &filters = {}, const std::vector<std::size_t> &vecSizes = {},
                const float vecPadding = 0.0, const float validationSplit = 0.0, const std::size_t maxChunks = 0,
                const std::size_t numColumns = 0, bool shuffle = true)
std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>{fChunkSize, fNumColumns});
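fChunkTensor (declared further below) is allocated this way as a single contiguous RTensor<float> of shape {fChunkSize, fNumColumns}; each chunk of events read from the tree is written into this buffer before being sliced into batches. A minimal sketch of the same allocation with hypothetical sizes (chunkSize = 1000, numColumns = 4; both names are placeholders, not members of this class) could look like this:

#include <cstddef>
#include <memory>
#include <vector>

#include "TMVA/RTensor.hxx"

int main()
{
   // Hypothetical sizes; in RBatchGenerator they come from the constructor arguments.
   const std::size_t chunkSize = 1000;
   const std::size_t numColumns = 4;

   // One contiguous buffer of chunkSize * numColumns floats, row-major by default.
   auto chunkTensor =
      std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>{chunkSize, numColumns});

   // chunkTensor->GetShape() is {1000, 4}; chunkTensor->GetData() points to the raw floats.
   return 0;
}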
std::vector<std::size_t> row_order = std::vector<std::size_t>(processedEvents);
std::iota(row_order.begin(), row_order.end(), 0);

std::shuffle(row_order.begin(), row_order.end(), fRng);

std::vector<std::size_t> valid_idx({row_order.begin(), row_order.begin() + num_validation});
std::vector<std::size_t> train_idx({row_order.begin() + num_validation, row_order.end()});
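This fragment is the core of createIdxs (documented below): enumerate the row indices of the chunk that was just processed, shuffle them, and split them into a validation slice and a training slice. A self-contained sketch of the same idea in plain C++, using std::mt19937 in place of the member fRng and a hypothetical 20% validation split, might read:

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

int main()
{
   const std::size_t processedEvents = 10;  // hypothetical number of events in the chunk
   const float validationSplit = 0.2f;      // hypothetical fraction reserved for validation

   // Enumerate the rows of the chunk: 0, 1, ..., processedEvents - 1.
   std::vector<std::size_t> row_order(processedEvents);
   std::iota(row_order.begin(), row_order.end(), 0);

   // Shuffle so the validation rows are a random subset of the chunk.
   std::mt19937 rng{42};
   std::shuffle(row_order.begin(), row_order.end(), rng);

   // The first num_validation indices become validation rows, the rest training rows.
   const auto num_validation = static_cast<std::size_t>(processedEvents * validationSplit);
   std::vector<std::size_t> valid_idx(row_order.begin(), row_order.begin() + num_validation);
   std::vector<std::size_t> train_idx(row_order.begin() + num_validation, row_order.end());

   return 0;
}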
void CreateBatches(std::size_t currentChunk, std::size_t processedEvents)
Create batches for the current chunk.
std::unique_ptr< std::thread > fLoadingThread
const TMVA::Experimental::RTensor< float > & GetTrainBatch()
Returns the next batch of training data if available.
RBatchGenerator(const std::string &treeName, const std::string &fileName, const std::size_t chunkSize, const std::size_t batchSize, const std::vector< std::string > &cols, const std::vector< std::string > &filters={}, const std::vector< std::size_t > &vecSizes={}, const float vecPadding=0.0, const float validationSplit=0.0, const std::size_t maxChunks=0, const std::size_t numColumns=0, bool shuffle=true)
void Activate()
Activate the loading process by starting the batch loader and spawning the loading thread.
const TMVA::Experimental::RTensor< float > & GetValidationBatch()
Returns the next batch of validation data if available.
std::vector< std::vector< std::size_t > > fTrainingIdxs
std::vector< std::string > fCols
void DeActivate()
Deactivate the loading process by deactivating the batch generator and joining the loading thread.
void createIdxs(std::size_t processedEvents)
Split the events of the current chunk into validation and training events.
std::unique_ptr< TMVA::Experimental::Internal::RBatchLoader > fBatchLoader
std::vector< std::string > fFilters
std::vector< std::size_t > fVecSizes
std::vector< std::vector< std::size_t > > fValidationIdxs
TMVA::RandomGenerator< TRandom3 > fRng
std::unique_ptr< TMVA::Experimental::RTensor< float > > fCurrentBatch
std::unique_ptr< TMVA::Experimental::Internal::RChunkLoader< Args... > > fChunkLoader
std::unique_ptr< TMVA::Experimental::RTensor< float > > fChunkTensor
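Taken together, the members listed above suggest the intended lifecycle: construct the generator, call Activate() to start the batch loader and the loading thread, pull batches with GetTrainBatch() and GetValidationBatch(), and call DeActivate() when done. The sketch below illustrates that flow; the header path, the exact namespace qualification (the fragment above only shows Experimental), the template arguments, and all file, tree, and column names are assumptions, not taken from this header:

#include <iostream>

#include "TMVA/RBatchGenerator.hxx"  // assumed header path for this file

int main()
{
   // Template arguments are assumed to be the C++ types of the requested columns.
   TMVA::Experimental::RBatchGenerator<float, float> generator(
      "myTree",          // treeName (hypothetical)
      "myFile.root",     // fileName (hypothetical)
      1000,              // chunkSize: events loaded into the chunk tensor at once
      100,               // batchSize: rows per returned batch
      {"var1", "var2"},  // cols: columns to read (hypothetical)
      {},                // filters
      {},                // vecSizes
      0.0,               // vecPadding
      0.2);              // validationSplit: reserve 20% of each chunk for validation

   generator.Activate();  // start the batch loader and spawn the loading thread

   // Each call returns an RTensor<float>; per the documentation above, batches are
   // returned "if available", so a real loop would check for an exhausted generator.
   const auto &trainBatch = generator.GetTrainBatch();
   const auto &validationBatch = generator.GetValidationBatch();
   std::cout << "training rows:   " << trainBatch.GetShape()[0] << "\n";
   std::cout << "validation rows: " << validationBatch.GetShape()[0] << "\n";

   generator.DeActivate();  // stop loading and join the loading thread
   return 0;
}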