Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RDataLoaderEngine.hxx
Go to the documentation of this file.
1// Author: Dante Niewenhuis, VU Amsterdam 07/2023
2// Author: Kristupas Pranckietis, Vilnius University 05/2024
3// Author: Nopphakorn Subsa-Ard, King Mongkut's University of Technology Thonburi (KMUTT) (TH) 08/2024
4// Author: Vincenzo Eduardo Padulano, CERN 10/2024
5// Author: Martin Føll, University of Oslo (UiO) & CERN 01/2026
6// Author: Silia Taider, CERN 02/2026
7
8/*************************************************************************
9 * Copyright (C) 1995-2026, Rene Brun and Fons Rademakers. *
10 * All rights reserved. *
11 * *
12 * For the licensing terms see $ROOTSYS/LICENSE. *
13 * For the list of contributors see $ROOTSYS/README/CREDITS. *
14 *************************************************************************/
15
16#ifndef ROOT_INTERNAL_ML_RDATALOADERENGINE
17#define ROOT_INTERNAL_ML_RDATALOADERENGINE
18
19#include <condition_variable>
20#include <memory>
21#include <mutex>
22#include <string>
23#include <thread>
24#include <vector>
25
31#include "ROOT/ML/RSampler.hxx"
33
34// Empty namespace to create a hook for the Pythonization
36}
37
39/**
40 \class ROOT::Experimental::Internal::ML::RDataLoaderEngine
41\brief
42
This class combines loading clusters (see RClusterLoader) and creating batches from those clusters (see
RBatchLoader), so that training and validation batches can be loaded directly from the datasets held in one or more
RDataFrames.
46*/
47
48template <typename... Args>
50private:
51 std::vector<std::string> fCols;
52 std::vector<std::size_t> fVecSizes;
53 std::size_t fBatchSize;
54 std::size_t fSetSeed;
55
56 // buffer quantities
57 std::size_t fBatchesInMemory;
58 std::size_t fBufferCapacity;
59 std::size_t fLowWatermark;
60 std::size_t fHighWatermark;
61
62 std::size_t fTrainingClusterIdx{0};
63 std::size_t fValidationClusterIdx{0};
64
65 float fTestSize;
66
67 std::unique_ptr<RDatasetLoader<Args...>> fDatasetLoader;
68 std::unique_ptr<RClusterLoader<Args...>> fClusterLoader;
69 std::unique_ptr<RBatchLoader> fTrainingBatchLoader;
70 std::unique_ptr<RBatchLoader> fValidationBatchLoader;
71 std::unique_ptr<RSampler> fTrainingSampler;
72 std::unique_ptr<RSampler> fValidationSampler;
73
74 std::unique_ptr<RFlat2DMatrixOperators> fTensorOperators;
75
76 std::vector<ROOT::RDF::RNode> fRdfs;
77
78 std::unique_ptr<std::thread> fLoadingThread;
79 std::condition_variable fLoadingCondition;
80 std::mutex fLoadingMutex;
81
85 std::string fSampleType;
88
89 bool fIsActive{false}; // Whether the loading thread is active
90
91 bool fEpochActive{false};
94
97
98 // flattened buffers for chunks and temporary tensors (rows * cols)
99 std::vector<RFlat2DMatrix> fTrainingDatasets;
100 std::vector<RFlat2DMatrix> fValidationDatasets;
101
104
107
108 std::size_t fTrainingEpochCount{0};
109 std::size_t fValidationEpochCount{0};
110
111public:
112 RDataLoaderEngine(const std::vector<ROOT::RDF::RNode> &rdfs, const std::size_t batchSize,
113 const std::size_t batchesInMemory, const std::vector<std::string> &cols,
114 const std::vector<std::size_t> &vecSizes = {}, const float vecPadding = 0.0,
115 const float testSize = 0.0, bool shuffle = true, bool dropRemainder = true,
116 const std::size_t setSeed = 0, bool loadEager = false, std::string sampleType = "",
117 float sampleRatio = 1.0, bool replacement = false)
118 : fRdfs(rdfs),
119 fCols(cols),
121 fBatchSize(batchSize),
131 {
132 fTensorOperators = std::make_unique<RFlat2DMatrixOperators>(fShuffle, fSetSeed);
133
134 if (fLoadEager) {
135 fDatasetLoader = std::make_unique<RDatasetLoader<Args...>>(fRdfs, fTestSize, fCols, fVecSizes, vecPadding,
138
139 if (fSampleType == "") {
140 fDatasetLoader->ConcatenateDatasets();
141
142 fTrainingDataset = fDatasetLoader->GetTrainingDataset();
143 fValidationDataset = fDatasetLoader->GetValidationDataset();
144
145 fNumTrainingEntries = fDatasetLoader->GetNumTrainingEntries();
146 fNumValidationEntries = fDatasetLoader->GetNumValidationEntries();
147 }
148
149 else {
150 fTrainingDatasets = fDatasetLoader->GetTrainingDatasets();
151 fValidationDatasets = fDatasetLoader->GetValidationDatasets();
152
155 fValidationSampler = std::make_unique<RSampler>(fValidationDatasets, fSampleType, fSampleRatio,
157
158 fNumTrainingEntries = fTrainingSampler->GetNumEntries();
159 fNumValidationEntries = fValidationSampler->GetNumEntries();
160 }
161 }
162
163 else {
164 // scan cluster boundaries
165 fClusterLoader = std::make_unique<RClusterLoader<Args...>>(fRdfs, fCols, fVecSizes, vecPadding, fTestSize,
167
168 // derive buffer quantities
172
173 // split cluster list into training and validation
174 fClusterLoader->SplitDataset();
175 fNumTrainingEntries = fClusterLoader->GetNumTrainingEntries();
176 fNumValidationEntries = fClusterLoader->GetNumValidationEntries();
177 }
178
179 fTrainingBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fCols, fLoadingMutex, fLoadingCondition,
181 fValidationBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fCols, fLoadingMutex, fLoadingCondition,
183 }
184
186
188 {
189 {
190 std::lock_guard<std::mutex> lock(fLoadingMutex);
191 if (!fIsActive)
192 return;
193 fIsActive = false;
194 }
195
196 fLoadingCondition.notify_all();
197
198 if (fLoadingThread) {
199 if (fLoadingThread->joinable()) {
200 fLoadingThread->join();
201 }
202 }
203
204 fLoadingThread.reset();
205 }
206
207 /// \brief Activate the loading process by spawning the loading thread.
208 void Activate()
209 {
210 {
211 std::lock_guard<std::mutex> lock(fLoadingMutex);
212 if (fIsActive)
213 return;
214
215 fIsActive = true;
216 }
217
218 if (fLoadEager) {
219 return;
220 }
221
222 fLoadingThread = std::make_unique<std::thread>(&RDataLoaderEngine::LoadData, this);
223 }
224
225 /// \brief Activate the training epoch by starting the batchloader.
227 {
228 {
229 std::lock_guard<std::mutex> lock(fLoadingMutex);
232 if (!fLoadEager) {
233 // Shuffle the cluster indices at the beginning of each epoch
234 fClusterLoader->ShuffleTrainingClusters(fTrainingEpochCount++);
235 }
236 }
237
238 fTrainingBatchLoader->Activate();
239 fLoadingCondition.notify_all();
240 }
241
243 {
244 {
245 std::lock_guard<std::mutex> lock(fLoadingMutex);
246 fTrainingEpochActive = false;
247 }
248
249 fTrainingBatchLoader->Reset();
250 fTrainingBatchLoader->DeActivate();
251 fLoadingCondition.notify_all();
252 }
253
255 {
256 {
257 std::lock_guard<std::mutex> lock(fLoadingMutex);
260 if (!fLoadEager) {
261 fClusterLoader->ShuffleValidationClusters(fValidationEpochCount++);
262 }
263 }
264
265 fValidationBatchLoader->Activate();
266 fLoadingCondition.notify_all();
267 }
268
270 {
271 {
272 std::lock_guard<std::mutex> lock(fLoadingMutex);
274 }
275
276 fValidationBatchLoader->Reset();
277 fValidationBatchLoader->DeActivate();
278 fLoadingCondition.notify_all();
279 }
280
281 /// \brief Main loop for loading clusters and creating batches.
282 /// The producer (loading thread) will keep loading clusters and creating batches until the end of the epoch is
283 /// reached, or the generator is deactivated.
284 void LoadData()
285 {
286 std::unique_lock<std::mutex> lock(fLoadingMutex);
287
288 while (true) {
289 // Wait until we have work or shutdown
290 fLoadingCondition.wait(lock, [&] {
291 return !fIsActive ||
292 (fTrainingEpochActive && fTrainingClusterIdx < fClusterLoader->GetNumTrainingClusters()) ||
293 (fValidationEpochActive && fValidationClusterIdx < fClusterLoader->GetNumValidationClusters());
294 });
295
296 if (!fIsActive) {
297 break;
298 }
299
300 const std::size_t numTrainingClusters = fClusterLoader->GetNumTrainingClusters();
301 const std::size_t numValidationClusters = fClusterLoader->GetNumValidationClusters();
302
303 // Helper: check if validation queue below watermark and needs the producer
304 auto validationEmpty = [&] {
306 return false;
307 if (fValidationBatchLoader->isProducerDone())
308 return false;
309 return fValidationBatchLoader->GetNumBatchQueue() < fLowWatermark / fBatchSize;
310 };
311
312 // -- TRAINING --
314 while (true) {
315 // Stop conditions (shutdown or epoch end)
317 break;
318
319 // No more chunks to load: signal consumers
321 fTrainingBatchLoader->MarkProducerDone();
322 break;
323 }
324
325 // In the case of training prefetching, we could start requesting data for the next training loop while
326 // validation is active and might need data. To avoid getting stuck in the training loop, we check if the
327 // validation queue is below watermark and if so, we break out of the training loop.
328 if (validationEmpty()) {
329 break;
330 }
331
332 // If queue is not empty, wait until it drains below watermark, or validation needs data, or we are
333 // deactivated.
334 if (fTrainingBatchLoader->GetNumBatchQueue() >= fLowWatermark / fBatchSize) {
335 fLoadingCondition.wait(lock, [&] {
336 return !fIsActive || !fTrainingEpochActive ||
337 fTrainingBatchLoader->GetNumBatchQueue() < (fLowWatermark / fBatchSize) ||
339 });
340 continue;
341 }
342
343 // Accumulate clusters to load, enough to fill the buffer, or until we run out of clusters
344 std::vector<RClusterRange> trainClustersToLoad;
345 auto accumulatedEntries = 0;
346 const bool discovering = !fClusterLoader->IsSplitDiscovered();
348 (!discovering || trainClustersToLoad.empty())) {
349 const auto &cluster = fClusterLoader->GetTrainingClusters()[fTrainingClusterIdx++];
350 trainClustersToLoad.push_back(cluster);
351 accumulatedEntries += cluster.GetNumEntries();
352 }
353
355
356 // Release lock while reading and loading data to allow the consumer to access the queue freely in
357 // parallel. The loading thread re-acquires the lock in CreateBatches when it needs to push batches to
358 // the queue.
359 lock.unlock();
361 std::size_t rowOffset = 0;
362
363 for (auto &cluster : trainClustersToLoad) {
364 auto loadedEntries = fClusterLoader->LoadTrainingClusterInto(stagingBuffer, cluster.rdfIdx,
365 cluster.start, cluster.end, rowOffset);
366 if (discovering) {
367 // For the first epoch, we might discover that the cluster has fewer entries than expected because
368 // of filters
369 cluster.SetNumEntries(loadedEntries);
370 }
371 rowOffset += cluster.GetNumEntries();
372 }
373
374 if (discovering && fNumTrainingEntries == 0 && fClusterLoader->GetNumTrainingEntries() > 0) {
375 fNumTrainingEntries = fClusterLoader->GetNumTrainingEntries();
376 fNumValidationEntries = fClusterLoader->GetNumValidationEntries();
377 fTrainingBatchLoader->RecalculateBatchCounts(fNumTrainingEntries);
378 fValidationBatchLoader->RecalculateBatchCounts(fNumValidationEntries);
379 }
380
381 if (rowOffset < static_cast<std::size_t>(accumulatedEntries)) {
382 stagingBuffer.Resize(rowOffset, stagingBuffer.GetCols());
383 }
384
388
389 // Re-acquire the lock before the next iteration to check conditions and update indices
390 lock.lock();
391
392 if (isLastBuffer && discovering) {
393 fClusterLoader->FinaliseSplitDiscovery();
394 }
395 }
396 }
397
398 // -- VALIDATION --
400 while (true) {
401 // Stop conditions (shutdown or epoch end)
403 break;
404
405 // No more chunks to load: signal consumers
407 fValidationBatchLoader->MarkProducerDone();
408 break;
409 }
410
411 // If queue is not hungry, wait until it drains below watermark, or we are deactivated
412 if (fValidationBatchLoader->GetNumBatchQueue() >= (fLowWatermark / fBatchSize)) {
413 fLoadingCondition.wait(lock, [&] {
414 return !fIsActive || !fValidationEpochActive ||
415 fValidationBatchLoader->GetNumBatchQueue() < (fLowWatermark / fBatchSize);
416 });
417 continue;
418 }
419
420 // Accumulate clusters to load, enough to fill the buffer, or until we run out of clusters
421 std::vector<RClusterRange> valClustersToLoad;
422 auto accumulatedEntries = 0;
424 const auto &cluster = fClusterLoader->GetValidationClusters()[fValidationClusterIdx++];
425 valClustersToLoad.push_back(cluster);
426 accumulatedEntries += cluster.GetNumEntries();
427 }
428
430
431 lock.unlock();
432
434 std::size_t rowOffset = 0;
435
436 for (const auto &cluster : valClustersToLoad) {
437 fClusterLoader->LoadValidationClusterInto(stagingBuffer, cluster.rdfIdx, cluster.start, cluster.end,
438 rowOffset);
439 rowOffset += cluster.GetNumEntries();
440 }
441
445
446 lock.lock();
447 }
448 }
449 }
450 }
451
/// \brief Create training batches by first loading a chunk (see RClusterLoader) and then splitting it into batches
/// (see RBatchLoader).
455 {
456 fTrainingBatchLoader->Activate();
457
458 if (fLoadEager) {
459 if (fSampleType == "") {
461 }
462
463 else {
465 }
466
467 fTrainingBatchLoader->CreateBatches(fSampledTrainingDataset, true);
468 fTrainingBatchLoader->MarkProducerDone();
469 }
470 }
471
/// \brief Create validation batches by first loading a chunk (see RClusterLoader) and then splitting it into batches
/// (see RBatchLoader).
475 {
476 fValidationBatchLoader->Activate();
477
478 if (fLoadEager) {
479 if (fSampleType == "") {
481 }
482
483 else {
485 }
486
488 fValidationBatchLoader->MarkProducerDone();
489 }
490 }
491
492 /// \brief Loads a training batch from the queue
494 {
495 // Get next batch if available
496 return fTrainingBatchLoader->GetBatch();
497 }
498
499 /// \brief Loads a validation batch from the queue
501 {
502 // Get next batch if available
503 return fValidationBatchLoader->GetBatch();
504 }
505
506 std::size_t NumberOfTrainingBatches() { return fTrainingBatchLoader->GetNumBatches(); }
507 std::size_t NumberOfValidationBatches() { return fValidationBatchLoader->GetNumBatches(); }
508
509 std::size_t TrainRemainderRows() { return fTrainingBatchLoader->GetNumRemainderRows(); }
510 std::size_t ValidationRemainderRows() { return fValidationBatchLoader->GetNumRemainderRows(); }
511
512 bool IsActive()
513 {
514 std::lock_guard<std::mutex> lock(fLoadingMutex);
515 return fIsActive;
516 }
517
519 {
520 std::lock_guard<std::mutex> lock(fLoadingMutex);
522 }
523
525 {
526 std::lock_guard<std::mutex> lock(fLoadingMutex);
528 }
529};
530
531} // namespace ROOT::Experimental::Internal::ML
532
533#endif // ROOT_INTERNAL_ML_RDATALOADERENGINE
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Loads TTree/RNTuple clusters from one or more RDataFrames into RFlat2DMatrix buffers for ML training ...
In this class, the processes of loading clusters (see RClusterLoader) and creating batches from those...
void ActivateTrainingEpoch()
Activate the training epoch by starting the batchloader.
RFlat2DMatrix GetTrainBatch()
Loads a training batch from the queue.
std::unique_ptr< RFlat2DMatrixOperators > fTensorOperators
void CreateValidationBatches()
Creates validation batches by first loading a chunk (see RClusterLoader), and then split it into batc...
void LoadData()
Main loop for loading clusters and creating batches.
void CreateTrainBatches()
Create training batches by first loading a chunk (see RClusterLoader) and split it into batches (see ...
RFlat2DMatrix GetValidationBatch()
Loads a validation batch from the queue.
void Activate()
Activate the loading process by spawning the loading thread.
std::unique_ptr< RDatasetLoader< Args... > > fDatasetLoader
RDataLoaderEngine(const std::vector< ROOT::RDF::RNode > &rdfs, const std::size_t batchSize, const std::size_t batchesInMemory, const std::vector< std::string > &cols, const std::vector< std::size_t > &vecSizes={}, const float vecPadding=0.0, const float testSize=0.0, bool shuffle=true, bool dropRemainder=true, const std::size_t setSeed=0, bool loadEager=false, std::string sampleType="", float sampleRatio=1.0, bool replacement=false)
std::unique_ptr< RClusterLoader< Args... > > fClusterLoader
void SplitDatasets()
Split the dataframes into a training and a validation dataset.
const_iterator end() const
Wrapper around ROOT::RVec<float> representing a 2D matrix.