Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBatchGenerator.hxx
Go to the documentation of this file.
1// Author: Dante Niewenhuis, VU Amsterdam 07/2023
2// Author: Kristupas Pranckietis, Vilnius University 05/2024
3// Author: Nopphakorn Subsa-Ard, King Mongkut's University of Technology Thonburi (KMUTT) (TH) 08/2024
4// Author: Vincenzo Eduardo Padulano, CERN 10/2024
5// Author: Martin Føll, University of Oslo (UiO) & CERN 05/2025
6
7/*************************************************************************
8 * Copyright (C) 1995-2025, Rene Brun and Fons Rademakers. *
9 * All rights reserved. *
10 * *
11 * For the licensing terms see $ROOTSYS/LICENSE. *
12 * For the list of contributors see $ROOTSYS/README/CREDITS. *
13 *************************************************************************/
14
15#ifndef TMVA_RBATCHGENERATOR
16#define TMVA_RBATCHGENERATOR
17
22#include "TROOT.h"
23
24#include <cmath>
25#include <memory>
26#include <mutex>
27#include <random>
28#include <thread>
29#include <variant>
30#include <vector>
31
32namespace TMVA {
33namespace Experimental {
34namespace Internal {
35
36// clang-format off
37/**
38\class TMVA::Experimental::Internal::RBatchGenerator
39\ingroup tmva
40\brief
41
42In this class, the processes of loading chunks (see RChunkLoader) and creating batches from those chunks (see RBatchLoader) are combined, allowing batches from the training and validation sets to be loaded directly from a dataset in an RDataFrame.
43*/
44
45template <typename... Args>
47private:
48 std::vector<std::string> fCols;
49 // clang-format on
50 std::size_t fChunkSize;
51 std::size_t fMaxChunks;
52 std::size_t fBatchSize;
53 std::size_t fBlockSize;
54 std::size_t fNumColumns;
55 std::size_t fNumChunkCols;
56 std::size_t fNumEntries;
57 std::size_t fSetSeed;
58 std::size_t fSumVecSizes;
59
62
63 std::unique_ptr<RChunkLoader<Args...>> fChunkLoader;
64 std::unique_ptr<RBatchLoader> fBatchLoader;
65
66 std::unique_ptr<std::thread> fLoadingThread;
67
68 std::size_t fTrainingChunkNum;
70
72
73 std::mutex fIsActiveMutex;
74
77 bool fIsActive{false}; // Whether the loading thread is active
80
81 bool fEpochActive{false};
84
87
88 std::size_t fNumTrainingChunks;
90
93
96
99
102
103 // flattened buffers for chunks and temporary tensors (rows * cols)
106
109
110public:
111 RBatchGenerator(ROOT::RDF::RNode &rdf, const std::size_t chunkSize, const std::size_t blockSize,
112 const std::size_t batchSize, const std::vector<std::string> &cols,
113 const std::vector<std::size_t> &vecSizes = {}, const float vecPadding = 0.0,
114 const float validationSplit = 0.0, const std::size_t maxChunks = 0, bool shuffle = true,
115 bool dropRemainder = true, const std::size_t setSeed = 0)
116
117 : f_rdf(rdf),
118 fCols(cols),
120 fBlockSize(blockSize),
121 fBatchSize(batchSize),
127 fNotFiltered(f_rdf.GetFilterNames().empty()),
130 {
131
132 fNumEntries = f_rdf.Count().GetValue();
133 fEntries = f_rdf.Take<ULong64_t>("rdfentry_");
134
135 fSumVecSizes = std::accumulate(vecSizes.begin(), vecSizes.end(), 0);
137
138 // append a sentinel (last entry + 1) to the entries so that filling chunks does not go out of range
139 fEntries->push_back((*fEntries)[fNumEntries - 1] + 1);
140
144 fBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fNumChunkCols);
145
146 // split the dataset into training and validation sets
147 fChunkLoader->SplitDataset();
148
149 // number of training and validation entries after the split
150 fNumValidationEntries = static_cast<std::size_t>(fValidationSplit * fNumEntries);
152
155
158
161
162 if (dropRemainder) {
165 }
166
167 else {
170 }
171
172 // number of training and validation chunks, calculated in RChunkConstructor
173 fNumTrainingChunks = fChunkLoader->GetNumTrainingChunks();
174 fNumValidationChunks = fChunkLoader->GetNumValidationChunks();
175
178 }
179
181
183 {
184 {
185 std::lock_guard<std::mutex> lock(fIsActiveMutex);
186 fIsActive = false;
187 }
188
189 fBatchLoader->DeActivate();
190
191 if (fLoadingThread) {
192 if (fLoadingThread->joinable()) {
193 fLoadingThread->join();
194 }
195 }
196 }
197
198 /// \brief Activate the loading process by starting the batchloader, and
199 /// spawning the loading thread.
200 void Activate()
201 {
202 if (fIsActive)
203 return;
204
205 {
206 std::lock_guard<std::mutex> lock(fIsActiveMutex);
207 fIsActive = true;
208 }
209
210 fBatchLoader->Activate();
211 // fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunks, this);
212 }
213
214 void ActivateEpoch() { fEpochActive = true; }
215
216 void DeActivateEpoch() { fEpochActive = false; }
217
219
221
223
225
226 /// \brief Create training batches by first loading a chunk (see RChunkLoader) and then splitting it into batches (see RBatchLoader)
228 {
229
230 fChunkLoader->CreateTrainingChunksIntervals();
238 }
239
240 /// \brief Creates validation batches by first loading a chunk (see RChunkLoader) and then splitting it into batches (see RBatchLoader)
253
254 /// \brief Loads a training batch from the queue
256 {
257 auto batchQueue = fBatchLoader->GetNumTrainingBatchQueue();
258
259 // load the next chunk if the queue is empty
266 }
267
268 else {
270 }
271
272 // Get next batch if available
273 return fBatchLoader->GetTrainBatch();
274 }
275
276 /// \brief Loads a validation batch from the queue
278 {
279 auto batchQueue = fBatchLoader->GetNumValidationBatchQueue();
280
281 // load the next chunk if the queue is empty
288 }
289
290 else {
292 }
293
294 // Get next batch if available
295 return fBatchLoader->GetValidationBatch();
296 }
297
300
303
304 bool IsActive() { return fIsActive; }
306 /// \brief Returns the next batch of validation data if available.
307 /// Returns empty RTensor otherwise.
308};
309
310} // namespace Internal
311} // namespace Experimental
312} // namespace TMVA
313
314#endif // TMVA_RBATCHGENERATOR
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
The public interface to the RDataFrame federation of classes.
RResultPtr< ULong64_t > Count()
Return the number of entries processed (lazy action).
RResultPtr< COLL > Take(std::string_view column="")
Return a collection of values of a column (lazy action, returns a std::vector by default).
Smart pointer for the return type of actions.
const_iterator begin() const
const_iterator end() const
RBatchGenerator(ROOT::RDF::RNode &rdf, const std::size_t chunkSize, const std::size_t blockSize, const std::size_t batchSize, const std::vector< std::string > &cols, const std::vector< std::size_t > &vecSizes={}, const float vecPadding=0.0, const float validationSplit=0.0, const std::size_t maxChunks=0, bool shuffle=true, bool dropRemainder=true, const std::size_t setSeed=0)
std::unique_ptr< std::thread > fLoadingThread
void Activate()
Activate the loading process by starting the batchloader, and spawning the loading thread.
void CreateValidationBatches()
Creates validation batches by first loading a chunk (see RChunkLoader), and then split it into batche...
void CreateTrainBatches()
Create training batches by first loading a chunk (see RChunkLoader) and split it into batches (see RB...
RFlat2DMatrix GetValidationBatch()
Loads a validation batch from the queue.
std::unique_ptr< RChunkLoader< Args... > > fChunkLoader
std::unique_ptr< RBatchLoader > fBatchLoader
ROOT::RDF::RResultPtr< std::vector< ULong64_t > > fEntries
RFlat2DMatrix GetTrainBatch()
Loads a training batch from the queue.
void ChangeBeginAndEndEntries(const RNode &node, Long64_t begin, Long64_t end)
create variable transformations
Wrapper around ROOT::RVec<float> representing a 2D matrix.