Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RDataLoaderEngine.hxx
Go to the documentation of this file.
1// Author: Dante Niewenhuis, VU Amsterdam 07/2023
2// Author: Kristupas Pranckietis, Vilnius University 05/2024
3// Author: Nopphakorn Subsa-Ard, King Mongkut's University of Technology Thonburi (KMUTT) (TH) 08/2024
4// Author: Vincenzo Eduardo Padulano, CERN 10/2024
5// Author: Martin Føll, University of Oslo (UiO) & CERN 01/2026
6// Author: Silia Taider, CERN 02/2026
7
8/*************************************************************************
9 * Copyright (C) 1995-2026, Rene Brun and Fons Rademakers. *
10 * All rights reserved. *
11 * *
12 * For the licensing terms see $ROOTSYS/LICENSE. *
13 * For the list of contributors see $ROOTSYS/README/CREDITS. *
14 *************************************************************************/
15
16#ifndef ROOT_INTERNAL_ML_RDATALOADERENGINE
17#define ROOT_INTERNAL_ML_RDATALOADERENGINE
18
19#include <condition_variable>
20#include <memory>
21#include <mutex>
22#include <string>
23#include <thread>
24#include <vector>
25
31#include "ROOT/ML/RSampler.hxx"
33
34// Empty namespace to create a hook for the Pythonization
36}
37
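/**
 \class ROOT::Experimental::Internal::ML::RDataLoaderEngine
\brief Loads data from one or more RDataFrames and serves training and validation batches for ML workflows.

In this class, the processes of loading clusters (see RClusterLoader) and creating batches from those clusters (see
RBatchLoader) are combined, allowing batches from the training and validation sets to be loaded directly from datasets
in one or more RDataFrames. In eager mode the full dataset is loaded up front (see RDatasetLoader); otherwise clusters
are streamed into memory by a background loading thread.
*/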
47
48template <typename... Args>
50private:
51 std::vector<std::string> fCols;
52 std::vector<std::size_t> fVecSizes;
53 std::size_t fBatchSize;
54 std::size_t fSetSeed;
55
56 // buffer quantities
57 std::size_t fBatchesInMemory;
58 std::size_t fBufferCapacity;
59 std::size_t fLowWatermark;
60 std::size_t fHighWatermark;
61
62 std::size_t fTrainingClusterIdx{0};
63 std::size_t fValidationClusterIdx{0};
64
65 float fTestSize;
66
67 std::unique_ptr<RDatasetLoader<Args...>> fDatasetLoader;
68 std::unique_ptr<RClusterLoader<Args...>> fClusterLoader;
69 std::unique_ptr<RBatchLoader> fTrainingBatchLoader;
70 std::unique_ptr<RBatchLoader> fValidationBatchLoader;
71 std::unique_ptr<RSampler> fTrainingSampler;
72 std::unique_ptr<RSampler> fValidationSampler;
73
74 std::unique_ptr<RFlat2DMatrixOperators> fTensorOperators;
75
76 std::vector<ROOT::RDF::RNode> fRdfs;
77
78 std::unique_ptr<std::thread> fLoadingThread;
79 std::condition_variable fLoadingCondition;
80 std::mutex fLoadingMutex;
81
85 std::string fSampleType;
88
89 bool fIsActive{false}; // Whether the loading thread is active
90
91 bool fEpochActive{false};
94
97
98 // flattened buffers for chunks and temporary tensors (rows * cols)
99 std::vector<RFlat2DMatrix> fTrainingDatasets;
100 std::vector<RFlat2DMatrix> fValidationDatasets;
101
104
107
108 std::size_t fTrainingEpochCount{0};
109 std::size_t fValidationEpochCount{0};
110
111public:
112 RDataLoaderEngine(const std::vector<ROOT::RDF::RNode> &rdfs, const std::size_t batchSize,
113 const std::size_t batchesInMemory, const std::vector<std::string> &cols,
114 const std::vector<std::size_t> &vecSizes = {}, const float vecPadding = 0.0,
115 const float testSize = 0.0, bool shuffle = true, bool dropRemainder = true,
116 const std::size_t setSeed = 0, bool loadEager = false, std::string sampleType = "",
117 float sampleRatio = 1.0, bool replacement = false)
118 : fRdfs(rdfs),
119 fCols(cols),
121 fBatchSize(batchSize),
131 {
132 fTensorOperators = std::make_unique<RFlat2DMatrixOperators>(fShuffle, fSetSeed);
133
134 if (fLoadEager) {
135 fDatasetLoader = std::make_unique<RDatasetLoader<Args...>>(fRdfs, fTestSize, fCols, fVecSizes, vecPadding,
138
139 if (fSampleType == "") {
140 fDatasetLoader->ConcatenateDatasets();
141
142 fTrainingDataset = fDatasetLoader->GetTrainingDataset();
143 fValidationDataset = fDatasetLoader->GetValidationDataset();
144
145 fNumTrainingEntries = fDatasetLoader->GetNumTrainingEntries();
146 fNumValidationEntries = fDatasetLoader->GetNumValidationEntries();
147 }
148
149 else {
150 fTrainingDatasets = fDatasetLoader->GetTrainingDatasets();
151 fValidationDatasets = fDatasetLoader->GetValidationDatasets();
152
155 fValidationSampler = std::make_unique<RSampler>(fValidationDatasets, fSampleType, fSampleRatio,
157
158 fNumTrainingEntries = fTrainingSampler->GetNumEntries();
159 fNumValidationEntries = fValidationSampler->GetNumEntries();
160 }
161 }
162
163 else {
164 // scan cluster boundaries
165 fClusterLoader = std::make_unique<RClusterLoader<Args...>>(fRdfs, fCols, fVecSizes, vecPadding, fTestSize,
167
168 // derive buffer quantities
172
173 // split cluster list into training and validation
174 fClusterLoader->SplitDataset();
175 fNumTrainingEntries = fClusterLoader->GetNumTrainingEntries();
176 fNumValidationEntries = fClusterLoader->GetNumValidationEntries();
177 }
178
179 fTrainingBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fCols, fLoadingMutex, fLoadingCondition,
181 fValidationBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fCols, fLoadingMutex, fLoadingCondition,
183 }
184
186
188 {
189 {
190 std::lock_guard<std::mutex> lock(fLoadingMutex);
191 if (!fIsActive)
192 return;
193 fIsActive = false;
194 }
195
196 fLoadingCondition.notify_all();
197
198 if (fLoadingThread) {
199 if (fLoadingThread->joinable()) {
200 fLoadingThread->join();
201 }
202 }
203
204 fLoadingThread.reset();
205 }
206
207 /// \brief Activate the loading process by spawning the loading thread.
208 void Activate()
209 {
210 {
211 std::lock_guard<std::mutex> lock(fLoadingMutex);
212 if (fIsActive)
213 return;
214
215 fIsActive = true;
216 }
217
218 if (fLoadEager) {
219 return;
220 }
221
222 fLoadingThread = std::make_unique<std::thread>(&RDataLoaderEngine::LoadData, this);
223 }
224
225 /// \brief Activate the training epoch by starting the batchloader.
227 {
228 {
229 std::lock_guard<std::mutex> lock(fLoadingMutex);
232 if (!fLoadEager) {
233 // Shuffle the cluster indices at the beginning of each epoch
234 fClusterLoader->ShuffleTrainingClusters(fTrainingEpochCount++);
235 }
236 }
237
238 fTrainingBatchLoader->Activate();
239 fLoadingCondition.notify_all();
240 }
241
243 {
244 {
245 std::lock_guard<std::mutex> lock(fLoadingMutex);
246 fTrainingEpochActive = false;
247 }
248
249 fTrainingBatchLoader->Reset();
250 fTrainingBatchLoader->DeActivate();
251 fLoadingCondition.notify_all();
252 }
253
255 {
256 {
257 std::lock_guard<std::mutex> lock(fLoadingMutex);
260 if (!fLoadEager) {
261 fClusterLoader->ShuffleValidationClusters(fValidationEpochCount++);
262 }
263 }
264
265 fValidationBatchLoader->Activate();
266 fLoadingCondition.notify_all();
267 }
268
270 {
271 {
272 std::lock_guard<std::mutex> lock(fLoadingMutex);
274 }
275
276 fValidationBatchLoader->Reset();
277 fValidationBatchLoader->DeActivate();
278 fLoadingCondition.notify_all();
279 }
280
281 /// \brief Main loop for loading clusters and creating batches.
282 /// The producer (loading thread) will keep loading clusters and creating batches until the end of the epoch is
283 /// reached, or the generator is deactivated.
284 void LoadData()
285 {
286 std::unique_lock<std::mutex> lock(fLoadingMutex);
287
288 while (true) {
289 // Wait until we have work or shutdown
290 fLoadingCondition.wait(lock, [&] {
291 return !fIsActive ||
292 (fTrainingEpochActive && fTrainingClusterIdx < fClusterLoader->GetNumTrainingClusters()) ||
293 (fValidationEpochActive && fValidationClusterIdx < fClusterLoader->GetNumValidationClusters());
294 });
295
296 if (!fIsActive) {
297 break;
298 }
299
300 // Helper: check if validation queue below watermark and needs the producer
301 auto validationEmpty = [&] {
302 if (!fValidationEpochActive || fValidationClusterIdx >= fClusterLoader->GetNumValidationClusters())
303 return false;
304 if (fValidationBatchLoader->isProducerDone())
305 return false;
306 return fValidationBatchLoader->GetNumBatchQueue() < fLowWatermark / fBatchSize;
307 };
308
309 // -- TRAINING --
311 const std::size_t numTrainingClusters = fClusterLoader->GetNumTrainingClusters();
312
313 while (true) {
314 // Stop conditions (shutdown or epoch end)
316 break;
317
318 // No more chunks to load: signal consumers
320 fTrainingBatchLoader->MarkProducerDone();
321 break;
322 }
323
324 // In the case of training prefetching, we could start requesting data for the next training loop while
325 // validation is active and might need data. To avoid getting stuck in the training loop, we check if the
326 // validation queue is below watermark and if so, we break out of the training loop.
327 if (validationEmpty()) {
328 break;
329 }
330
331 // If queue is not empty, wait until it drains below watermark, or validation needs data, or we are
332 // deactivated.
333 if (fTrainingBatchLoader->GetNumBatchQueue() >= fLowWatermark / fBatchSize) {
334 fLoadingCondition.wait(lock, [&] {
335 return !fIsActive || !fTrainingEpochActive ||
336 fTrainingBatchLoader->GetNumBatchQueue() < (fLowWatermark / fBatchSize) ||
338 });
339 continue;
340 }
341
342 // Accumulate clusters to load, enough to fill the buffer, or until we run out of clusters
343 std::vector<RClusterRange> trainClustersToLoad;
344 auto accumulatedEntries = 0;
345 const bool discovering = !fClusterLoader->IsSplitDiscovered();
347 (!discovering || trainClustersToLoad.empty())) {
348 const auto &cluster = fClusterLoader->GetTrainingClusters()[fTrainingClusterIdx++];
349 trainClustersToLoad.push_back(cluster);
350 accumulatedEntries += cluster.GetNumEntries();
351 }
352
354
355 // Release lock while reading and loading data to allow the consumer to access the queue freely in
356 // parallel. The loading thread re-acquires the lock in CreateBatches when it needs to push batches to
357 // the queue.
358 lock.unlock();
360 std::size_t rowOffset = 0;
361
362 for (auto &cluster : trainClustersToLoad) {
363 auto loadedEntries = fClusterLoader->LoadTrainingClusterInto(stagingBuffer, cluster.rdfIdx,
364 cluster.start, cluster.end, rowOffset);
365 if (discovering) {
366 // For the first epoch, we might discover that the cluster has fewer entries than expected because
367 // of filters
368 cluster.SetNumEntries(loadedEntries);
369 }
370 rowOffset += cluster.GetNumEntries();
371 }
372
373 if (discovering && fNumTrainingEntries == 0 && fClusterLoader->GetNumTrainingEntries() > 0) {
374 fNumTrainingEntries = fClusterLoader->GetNumTrainingEntries();
375 fNumValidationEntries = fClusterLoader->GetNumValidationEntries();
376 fTrainingBatchLoader->RecalculateBatchCounts(fNumTrainingEntries);
377 fValidationBatchLoader->RecalculateBatchCounts(fNumValidationEntries);
378 }
379
380 if (rowOffset < static_cast<std::size_t>(accumulatedEntries)) {
381 stagingBuffer.Resize(rowOffset, stagingBuffer.GetCols());
382 }
383
387
388 // Re-acquire the lock before the next iteration to check conditions and update indices
389 lock.lock();
390
391 if (isLastBuffer && discovering) {
392 fClusterLoader->FinaliseSplitDiscovery();
393 }
394 }
395 }
396
397 // -- VALIDATION --
399 const std::size_t numValidationClusters = fClusterLoader->GetNumValidationClusters();
400
401 while (true) {
402 // Stop conditions (shutdown or epoch end)
404 break;
405
406 // No more chunks to load: signal consumers
408 fValidationBatchLoader->MarkProducerDone();
409 break;
410 }
411
412 // If queue is not hungry, wait until it drains below watermark, or we are deactivated
413 if (fValidationBatchLoader->GetNumBatchQueue() >= (fLowWatermark / fBatchSize)) {
414 fLoadingCondition.wait(lock, [&] {
415 return !fIsActive || !fValidationEpochActive ||
416 fValidationBatchLoader->GetNumBatchQueue() < (fLowWatermark / fBatchSize);
417 });
418 continue;
419 }
420
421 // Accumulate clusters to load, enough to fill the buffer, or until we run out of clusters
422 std::vector<RClusterRange> valClustersToLoad;
423 auto accumulatedEntries = 0;
425 const auto &cluster = fClusterLoader->GetValidationClusters()[fValidationClusterIdx++];
426 valClustersToLoad.push_back(cluster);
427 accumulatedEntries += cluster.GetNumEntries();
428 }
429
431
432 lock.unlock();
433
435 std::size_t rowOffset = 0;
436
437 for (const auto &cluster : valClustersToLoad) {
438 fClusterLoader->LoadValidationClusterInto(stagingBuffer, cluster.rdfIdx, cluster.start, cluster.end,
439 rowOffset);
440 rowOffset += cluster.GetNumEntries();
441 }
442
446
447 lock.lock();
448 }
449 }
450 }
451 }
452
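/// \brief Create training batches by first loading a chunk (see RClusterLoader) and then splitting it into batches
/// (see RBatchLoader). In eager mode the batches are created here from fSampledTrainingDataset and the producer is
/// marked done; otherwise this only activates the training batch loader, and chunks are produced by LoadData().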
456 {
457 fTrainingBatchLoader->Activate();
458
459 if (fLoadEager) {
460 if (fSampleType == "") {
462 }
463
464 else {
466 }
467
468 fTrainingBatchLoader->CreateBatches(fSampledTrainingDataset, true);
469 fTrainingBatchLoader->MarkProducerDone();
470 }
471 }
472
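/// \brief Create validation batches by first loading a chunk (see RClusterLoader) and then splitting it into batches
/// (see RBatchLoader). In eager mode the batches are created here and the producer is marked done; otherwise this
/// only activates the validation batch loader, and chunks are produced by LoadData().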
476 {
477 fValidationBatchLoader->Activate();
478
479 if (fLoadEager) {
480 if (fSampleType == "") {
482 }
483
484 else {
486 }
487
489 fValidationBatchLoader->MarkProducerDone();
490 }
491 }
492
493 /// \brief Loads a training batch from the queue
495 {
496 // Get next batch if available
497 return fTrainingBatchLoader->GetBatch();
498 }
499
500 /// \brief Loads a validation batch from the queue
502 {
503 // Get next batch if available
504 return fValidationBatchLoader->GetBatch();
505 }
506
507 std::size_t NumberOfTrainingBatches() { return fTrainingBatchLoader->GetNumBatches(); }
508 std::size_t NumberOfValidationBatches() { return fValidationBatchLoader->GetNumBatches(); }
509
510 std::size_t TrainRemainderRows() { return fTrainingBatchLoader->GetNumRemainderRows(); }
511 std::size_t ValidationRemainderRows() { return fValidationBatchLoader->GetNumRemainderRows(); }
512
513 bool IsActive()
514 {
515 std::lock_guard<std::mutex> lock(fLoadingMutex);
516 return fIsActive;
517 }
518
520 {
521 std::lock_guard<std::mutex> lock(fLoadingMutex);
523 }
524
526 {
527 std::lock_guard<std::mutex> lock(fLoadingMutex);
529 }
530};
531
532} // namespace ROOT::Experimental::Internal::ML
533
534#endif // ROOT_INTERNAL_ML_RDATALOADERENGINE
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Loads TTree/RNTuple clusters from one or more RDataFrames into RFlat2DMatrix buffers for ML training ...
In this class, the processes of loading clusters (see RClusterLoader) and creating batches from those...
void ActivateTrainingEpoch()
Activate the training epoch by starting the batchloader.
RFlat2DMatrix GetTrainBatch()
Loads a training batch from the queue.
std::unique_ptr< RFlat2DMatrixOperators > fTensorOperators
void CreateValidationBatches()
Creates validation batches by first loading a chunk (see RClusterLoader), and then split it into batc...
void LoadData()
Main loop for loading clusters and creating batches.
void CreateTrainBatches()
Create training batches by first loading a chunk (see RClusterLoader) and split it into batches (see ...
RFlat2DMatrix GetValidationBatch()
Loads a validation batch from the queue.
void Activate()
Activate the loading process by spawning the loading thread.
std::unique_ptr< RDatasetLoader< Args... > > fDatasetLoader
RDataLoaderEngine(const std::vector< ROOT::RDF::RNode > &rdfs, const std::size_t batchSize, const std::size_t batchesInMemory, const std::vector< std::string > &cols, const std::vector< std::size_t > &vecSizes={}, const float vecPadding=0.0, const float testSize=0.0, bool shuffle=true, bool dropRemainder=true, const std::size_t setSeed=0, bool loadEager=false, std::string sampleType="", float sampleRatio=1.0, bool replacement=false)
std::unique_ptr< RClusterLoader< Args... > > fClusterLoader
void SplitDatasets()
Split the dataframes in a training and validation dataset.
const_iterator end() const
Wrapper around ROOT::RVec<float> representing a 2D matrix.