Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBatchGenerator.hxx
Go to the documentation of this file.
1// Author: Dante Niewenhuis, VU Amsterdam 07/2023
2// Author: Kristupas Pranckietis, Vilnius University 05/2024
3// Author: Nopphakorn Subsa-Ard, King Mongkut's University of Technology Thonburi (KMUTT) (TH) 08/2024
4// Author: Vincenzo Eduardo Padulano, CERN 10/2024
5// Author: Martin Føll, University of Oslo (UiO) & CERN 01/2026
6
7/*************************************************************************
8 * Copyright (C) 1995-2026, Rene Brun and Fons Rademakers. *
9 * All rights reserved. *
10 * *
11 * For the licensing terms see $ROOTSYS/LICENSE. *
12 * For the list of contributors see $ROOTSYS/README/CREDITS. *
13 *************************************************************************/
14
15#ifndef TMVA_RBATCHGENERATOR
16#define TMVA_RBATCHGENERATOR
17
22
26#include "TROOT.h"
27
28#include <cmath>
29#include <memory>
30#include <mutex>
31#include <random>
32#include <thread>
33#include <variant>
34#include <vector>
35
36namespace TMVA {
37namespace Experimental {
38namespace Internal {
39
40// clang-format off
41/**
42\class TMVA::Experimental::Internal::RBatchGenerator
43\ingroup tmva
44\brief
45
46In this class, the processes of loading chunks (see RChunkLoader) and creating batches from those chunks (see RBatchLoader) are combined, allowing batches from the training and validation sets to be loaded directly from a dataset in an RDataFrame. When eager loading is requested (`loadEager`), the datasets are loaded through RDatasetLoader instead of RChunkLoader.
47*/
48
49template <typename... Args>
51private:
52 std::vector<std::string> fCols;
53 std::vector<std::size_t> fVecSizes;
54 // clang-format on
55 std::size_t fChunkSize;
56 std::size_t fMaxChunks;
57 std::size_t fBatchSize;
58 std::size_t fBlockSize;
59 std::size_t fSetSeed;
60
62
63 std::unique_ptr<RDatasetLoader<Args...>> fDatasetLoader;
64 std::unique_ptr<RChunkLoader<Args...>> fChunkLoader;
65 std::unique_ptr<RBatchLoader> fTrainingBatchLoader;
66 std::unique_ptr<RBatchLoader> fValidationBatchLoader;
67 std::unique_ptr<RSampler> fTrainingSampler;
68 std::unique_ptr<RSampler> fValidationSampler;
69
70 std::unique_ptr<RFlat2DMatrixOperators> fTensorOperators;
71
72 std::vector<ROOT::RDF::RNode> f_rdfs;
73
74 std::unique_ptr<std::thread> fLoadingThread;
75
76 std::size_t fTrainingChunkNum;
78
79 std::mutex fIsActiveMutex;
80
84 std::string fSampleType;
87
88 bool fIsActive{false}; // Whether the loading thread is active
90
91 bool fEpochActive{false};
94
97
98 std::size_t fNumTrainingChunks;
100
101 // flattened buffers for chunks and temporary tensors (rows * cols)
102 std::vector<RFlat2DMatrix> fTrainingDatasets;
103 std::vector<RFlat2DMatrix> fValidationDatasets;
104
107
110
113
116
117public:
118 RBatchGenerator(const std::vector<ROOT::RDF::RNode> &rdfs, const std::size_t chunkSize, const std::size_t blockSize,
119 const std::size_t batchSize, const std::vector<std::string> &cols,
120 const std::vector<std::size_t> &vecSizes = {}, const float vecPadding = 0.0,
121 const float validationSplit = 0.0, const std::size_t maxChunks = 0, bool shuffle = true,
122 bool dropRemainder = true, const std::size_t setSeed = 0, bool loadEager = false,
123 std::string sampleType = "", float sampleRatio = 1.0, bool replacement = false)
124
125 : f_rdfs(rdfs),
126 fCols(cols),
129 fBlockSize(blockSize),
130 fBatchSize(batchSize),
141 {
142 fTensorOperators = std::make_unique<RFlat2DMatrixOperators>(fShuffle, fSetSeed);
143
144 if (fLoadEager) {
145 fDatasetLoader = std::make_unique<RDatasetLoader<Args...>>(f_rdfs, fValidationSplit, fCols, fVecSizes,
147 // split the datasets and extract the training and validation datasets
149
150 if (fSampleType == "") {
151 fDatasetLoader->ConcatenateDatasets();
152
153 fTrainingDataset = fDatasetLoader->GetTrainingDataset();
154 fValidationDataset = fDatasetLoader->GetValidationDataset();
155
156 fNumTrainingEntries = fDatasetLoader->GetNumTrainingEntries();
157 fNumValidationEntries = fDatasetLoader->GetNumValidationEntries();
158 }
159
160 else {
161 fTrainingDatasets = fDatasetLoader->GetTrainingDatasets();
162 fValidationDatasets = fDatasetLoader->GetValidationDatasets();
163
168
169 fNumTrainingEntries = fTrainingSampler->GetNumEntries();
170 fNumValidationEntries = fValidationSampler->GetNumEntries();
171 }
172 }
173
174 else {
176 std::make_unique<RChunkLoader<Args...>>(f_rdfs[0], fChunkSize, fBlockSize, fValidationSplit,
178
179 // split the dataset into training and validation sets
180 fChunkLoader->SplitDataset();
181
182 fNumTrainingEntries = fChunkLoader->GetNumTrainingEntries();
183 fNumValidationEntries = fChunkLoader->GetNumValidationEntries();
184
185 // number of training and validation chunks, calculated in RChunkConstructor
186 fNumTrainingChunks = fChunkLoader->GetNumTrainingChunks();
187 fNumValidationChunks = fChunkLoader->GetNumValidationChunks();
188 }
189
190 fTrainingBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fCols, fVecSizes,
192 fValidationBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fCols, fVecSizes,
194 }
195
197
199 {
200 {
201 std::lock_guard<std::mutex> lock(fIsActiveMutex);
202 fIsActive = false;
203 }
204
205 fTrainingBatchLoader->DeActivate();
206 fValidationBatchLoader->DeActivate();
207
208 if (fLoadingThread) {
209 if (fLoadingThread->joinable()) {
210 fLoadingThread->join();
211 }
212 }
213 }
214
215 /// \brief Activate the loading process by starting the batchloader, and
216 /// spawning the loading thread.
217 void Activate()
218 {
219 if (fIsActive)
220 return;
221
222 {
223 std::lock_guard<std::mutex> lock(fIsActiveMutex);
224 fIsActive = true;
225 }
226
227 fTrainingBatchLoader->Activate();
228 fValidationBatchLoader->Activate();
229 // fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunks, this);
230 }
231
232 void ActivateEpoch() { fEpochActive = true; }
233
234 void DeActivateEpoch() { fEpochActive = false; }
235
237
239
241
243
244 /// \brief Create training batches by first loading a chunk (see RChunkLoader) and split it into batches (see RBatchLoader)
246 {
248 if (fLoadEager) {
249 if (fSampleType == "") {
251 }
252
253 else {
255 }
256
258 }
259
260 else {
261 fChunkLoader->CreateTrainingChunksIntervals();
266 }
267 }
268
269 /// \brief Creates validation batches by first loading a chunk (see RChunkLoader), and then split it into batches (see RBatchLoader)
271 {
273 if (fLoadEager) {
274 if (fSampleType == "") {
276 }
277
278 else {
280 }
281
283 }
284
285 else {
286 fChunkLoader->CreateValidationChunksIntervals();
291 }
292 }
293
294 /// \brief Loads a training batch from the queue
296 {
297 if (!fLoadEager) {
298 auto batchQueue = fTrainingBatchLoader->GetNumBatchQueue();
299
300 // load the next chunk if the queue is empty
306 }
307 }
308 // Get next batch if available
309 return fTrainingBatchLoader->GetBatch();
310 }
311
312 /// \brief Loads a validation batch from the queue
314 {
315 if (!fLoadEager) {
316 auto batchQueue = fValidationBatchLoader->GetNumBatchQueue();
317
318 // load the next chunk if the queue is empty
324 }
325 }
326 // Get next batch if available
327 return fValidationBatchLoader->GetBatch();
328 }
329
330 std::size_t NumberOfTrainingBatches() { return fTrainingBatchLoader->GetNumBatches(); }
331 std::size_t NumberOfValidationBatches() { return fValidationBatchLoader->GetNumBatches(); }
332
333 std::size_t TrainRemainderRows() { return fTrainingBatchLoader->GetNumRemainderRows(); }
334 std::size_t ValidationRemainderRows() { return fValidationBatchLoader->GetNumRemainderRows(); }
335
336 bool IsActive() { return fIsActive; }
338 // NOTE(review): stray comment that is not attached to any declaration. It appears to describe
339 // GetValidationBatch(), which returns an RFlat2DMatrix (not an RTensor); confirm and move or remove.
340};
341
342} // namespace Internal
343} // namespace Experimental
344} // namespace TMVA
345
346#endif // TMVA_RBATCHGENERATOR
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Building and loading the chunks from the blocks and chunks constructed in RChunkConstructor.
std::unique_ptr< RFlat2DMatrixOperators > fTensorOperators
std::unique_ptr< std::thread > fLoadingThread
std::unique_ptr< RBatchLoader > fTrainingBatchLoader
void Activate()
Activate the loading process by activating the training and validation batch loaders (spawning of the loading thread is currently disabled).
RBatchGenerator(const std::vector< ROOT::RDF::RNode > &rdfs, const std::size_t chunkSize, const std::size_t blockSize, const std::size_t batchSize, const std::vector< std::string > &cols, const std::vector< std::size_t > &vecSizes={}, const float vecPadding=0.0, const float validationSplit=0.0, const std::size_t maxChunks=0, bool shuffle=true, bool dropRemainder=true, const std::size_t setSeed=0, bool loadEager=false, std::string sampleType="", float sampleRatio=1.0, bool replacement=false)
void CreateValidationBatches()
Creates validation batches by first loading a chunk (see RChunkLoader), and then split it into batche...
void CreateTrainBatches()
Create training batches by first loading a chunk (see RChunkLoader) and split it into batches (see RB...
std::unique_ptr< RDatasetLoader< Args... > > fDatasetLoader
RFlat2DMatrix GetValidationBatch()
Loads a validation batch from the queue.
std::unique_ptr< RChunkLoader< Args... > > fChunkLoader
std::unique_ptr< RBatchLoader > fValidationBatchLoader
RFlat2DMatrix GetTrainBatch()
Loads a training batch from the queue.
void SplitDatasets()
Split the dataframes into training and validation datasets.
create variable transformations
Wrapper around ROOT::RVec<float> representing a 2D matrix.