Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBatchGenerator.hxx
Go to the documentation of this file.
1// Author: Dante Niewenhuis, VU Amsterdam 07/2023
2// Author: Kristupas Pranckietis, Vilnius University 05/2024
3// Author: Nopphakorn Subsa-Ard, King Mongkut's University of Technology Thonburi (KMUTT) (TH) 08/2024
4// Author: Vincenzo Eduardo Padulano, CERN 10/2024
5// Author: Martin Føll, University of Oslo (UiO) & CERN 05/2025
6
7/*************************************************************************
8 * Copyright (C) 1995-2025, Rene Brun and Fons Rademakers. *
9 * All rights reserved. *
10 * *
11 * For the licensing terms see $ROOTSYS/LICENSE. *
12 * For the list of contributors see $ROOTSYS/README/CREDITS. *
13 *************************************************************************/
14
15#ifndef TMVA_RBATCHGENERATOR
16#define TMVA_RBATCHGENERATOR
17
18#include "TMVA/RTensor.hxx"
22#include "TROOT.h"
23
24#include <cmath>
25#include <memory>
26#include <mutex>
27#include <random>
28#include <thread>
29#include <variant>
30#include <vector>
31
32namespace TMVA {
33namespace Experimental {
34namespace Internal {
35
36// clang-format off
37/**
38\class TMVA::Experimental::Internal::RBatchGenerator
39\ingroup tmva
40\brief
41
42In this class, the processes of loading chunks (see RChunkLoader) and creating batches from those chunks (see RBatchLoader) are combined, allowing batches from the training and validation sets to be loaded directly from a dataset in an RDataFrame.
43*/
44
45template <typename... Args>
47private:
48 std::vector<std::string> fCols;
49 // clang-format on
50 std::size_t fChunkSize;
51 std::size_t fMaxChunks;
52 std::size_t fBatchSize;
53 std::size_t fBlockSize;
54 std::size_t fNumColumns;
55 std::size_t fNumChunkCols;
56 std::size_t fNumEntries;
57 std::size_t fSetSeed;
58 std::size_t fSumVecSizes;
59
62
63 std::unique_ptr<RChunkLoader<Args...>> fChunkLoader;
64 std::unique_ptr<RBatchLoader> fBatchLoader;
65
66 std::unique_ptr<std::thread> fLoadingThread;
67
68 std::size_t fTrainingChunkNum;
70
72
73 std::mutex fIsActiveMutex;
74
77 bool fIsActive{false}; // Whether the loading thread is active
80
81 bool fEpochActive{false};
84
87
88 std::size_t fNumTrainingChunks;
90
93
96
99
102
105
108
109public:
110 RBatchGenerator(ROOT::RDF::RNode &rdf, const std::size_t chunkSize, const std::size_t blockSize,
111 const std::size_t batchSize, const std::vector<std::string> &cols,
112 const std::vector<std::size_t> &vecSizes = {}, const float vecPadding = 0.0,
113 const float validationSplit = 0.0, const std::size_t maxChunks = 0, bool shuffle = true,
114 bool dropRemainder = true, const std::size_t setSeed = 0)
115
116 : f_rdf(rdf),
117 fCols(cols),
119 fBlockSize(blockSize),
120 fBatchSize(batchSize),
126 fNotFiltered(f_rdf.GetFilterNames().empty()),
129 fTrainTensor({0, 0}),
130 fTrainChunkTensor({0, 0}),
131 fValidationTensor({0, 0}),
133 {
134
135 fNumEntries = f_rdf.Count().GetValue();
136 fEntries = f_rdf.Take<ULong64_t>("rdfentry_");
137
138 fSumVecSizes = std::accumulate(vecSizes.begin(), vecSizes.end(), 0);
140
141 // add the last element in entries to not go out of range when filling chunks
142 fEntries->push_back((*fEntries)[fNumEntries - 1] + 1);
143
147 fBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fNumChunkCols);
148
149 // split the dataset into training and validation sets
150 fChunkLoader->SplitDataset();
151
152 // number of training and validation entries after the split
153 fNumValidationEntries = static_cast<std::size_t>(fValidationSplit * fNumEntries);
155
158
161
164
165 if (dropRemainder) {
168 }
169
170 else {
173 }
174
175 // number of training and validation chunks, calculated in RChunkConstructor
176 fNumTrainingChunks = fChunkLoader->GetNumTrainingChunks();
177 fNumValidationChunks = fChunkLoader->GetNumValidationChunks();
178
181 }
182
184
186 {
187 {
188 std::lock_guard<std::mutex> lock(fIsActiveMutex);
189 fIsActive = false;
190 }
191
192 fBatchLoader->DeActivate();
193
194 if (fLoadingThread) {
195 if (fLoadingThread->joinable()) {
196 fLoadingThread->join();
197 }
198 }
199 }
200
201 /// \brief Activate the loading process by starting the batchloader, and
202 /// spawning the loading thread.
203 void Activate()
204 {
205 if (fIsActive)
206 return;
207
208 {
209 std::lock_guard<std::mutex> lock(fIsActiveMutex);
210 fIsActive = true;
211 }
212
213 fBatchLoader->Activate();
214 // fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunks, this);
215 }
216
217 void ActivateEpoch() { fEpochActive = true; }
218
219 void DeActivateEpoch() { fEpochActive = false; }
220
222
224
226
228
229 /// \brief Creates training batches by first loading a chunk (see RChunkLoader) and then splitting it into batches (see RBatchLoader)
231 {
232
233 fChunkLoader->CreateTrainingChunksIntervals();
241 }
242
243 /// \brief Creates validation batches by first loading a chunk (see RChunkLoader) and then splitting it into batches (see RBatchLoader)
256
257 /// \brief Loads a training batch from the queue
259 {
260 auto batchQueue = fBatchLoader->GetNumTrainingBatchQueue();
261
262 // load the next chunk if the queue is empty
269 }
270
271 else {
273 }
274
275 // Get next batch if available
276 return fBatchLoader->GetTrainBatch();
277 }
278
279 /// \brief Loads a validation batch from the queue
281 {
282 auto batchQueue = fBatchLoader->GetNumValidationBatchQueue();
283
284 // load the next chunk if the queue is empty
291 }
292
293 else {
295 }
296
297 // Get next batch if available
298 return fBatchLoader->GetValidationBatch();
299 }
300
303
306
307 bool IsActive() { return fIsActive; }
309 /// \brief Returns the next batch of validation data if available.
310 /// Returns empty RTensor otherwise.
311};
312
313} // namespace Internal
314} // namespace Experimental
315} // namespace TMVA
316
317#endif // TMVA_RBATCHGENERATOR
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
unsigned long long ULong64_t
Portable unsigned long integer 8 bytes.
Definition RtypesCore.h:84
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
The public interface to the RDataFrame federation of classes.
RResultPtr< ULong64_t > Count()
Return the number of entries processed (lazy action).
RResultPtr< COLL > Take(std::string_view column="")
Return a collection of values of a column (lazy action, returns a std::vector by default).
Smart pointer for the return type of actions.
const_iterator begin() const
const_iterator end() const
Building and loading the chunks from the blocks and chunks constructed in RChunkConstructor.
RBatchGenerator(ROOT::RDF::RNode &rdf, const std::size_t chunkSize, const std::size_t blockSize, const std::size_t batchSize, const std::vector< std::string > &cols, const std::vector< std::size_t > &vecSizes={}, const float vecPadding=0.0, const float validationSplit=0.0, const std::size_t maxChunks=0, bool shuffle=true, bool dropRemainder=true, const std::size_t setSeed=0)
std::unique_ptr< std::thread > fLoadingThread
void Activate()
Activate the loading process by starting the batch loader (spawning of the loading thread is currently disabled).
TMVA::Experimental::RTensor< float > fValidationTensor
void CreateValidationBatches()
Creates validation batches by first loading a chunk (see RChunkLoader) and then splitting it into batche...
void CreateTrainBatches()
Creates training batches by first loading a chunk (see RChunkLoader) and then splitting it into batches (see RB...
TMVA::Experimental::RTensor< float > fTrainTensor
TMVA::Experimental::RTensor< float > GetValidationBatch()
Loads a validation batch from the queue
std::unique_ptr< RChunkLoader< Args... > > fChunkLoader
std::unique_ptr< RBatchLoader > fBatchLoader
TMVA::Experimental::RTensor< float > GetTrainBatch()
Loads a training batch from the queue.
ROOT::RDF::RResultPtr< std::vector< ULong64_t > > fEntries
TMVA::Experimental::RTensor< float > fValidationChunkTensor
TMVA::Experimental::RTensor< float > fTrainChunkTensor
RTensor is a container with contiguous memory and shape information.
Definition RTensor.hxx:163
void ChangeBeginAndEndEntries(const RNode &node, Long64_t begin, Long64_t end)
create variable transformations