Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RClusterLoader.hxx
Go to the documentation of this file.
1// Author: Dante Niewenhuis, VU Amsterdam 07/2023
2// Author: Kristupas Pranckietis, Vilnius University 05/2024
3// Author: Nopphakorn Subsa-Ard, King Mongkut's University of Technology Thonburi (KMUTT) (TH) 08/2024
4// Author: Vincenzo Eduardo Padulano, CERN 10/2024
5// Author: Silia Taider, CERN 03/2026
6
7/*************************************************************************
8 * Copyright (C) 1995-2025, Rene Brun and Fons Rademakers. *
9 * All rights reserved. *
10 * *
11 * For the licensing terms see $ROOTSYS/LICENSE. *
12 * For the list of contributors see $ROOTSYS/README/CREDITS. *
13 *************************************************************************/
14
15#ifndef ROOT_INTERNAL_ML_RCLUSTERLOADER
16#define ROOT_INTERNAL_ML_RCLUSTERLOADER
17
18#include <algorithm>
19#include <numeric>
20#include <random>
21#include <string>
22#include <utility>
23#include <vector>
24
27#include "ROOT/RDataFrame.hxx"
28#include "ROOT/RDFHelpers.hxx"
29#include "ROOT/RDF/Utils.hxx"
30
32
33/**
34 * \struct RClusterRange
35 * \brief Describes a contiguous range of entries within a single RDataFrame,
36 * corresponding to one TTree/RNTuple cluster boundary.
37 *
38 * For filtered RDataFrames the \p numEntries field may be smaller than `end - start`
39 * because it tracks the number of entries that actually pass the filter,
40 * discovered and set lazily during the first epoch.
41 */
43 std::size_t rdfIdx; // which rdf this cluster belongs to
44 std::uint64_t start; // first raw entry (incl)
45 std::uint64_t end; // one-past-last entry (excl)
46 std::size_t numEntries{
47 static_cast<std::size_t>(end - start)}; // number of entries in the cluster (that pass filters, if any)
48
/// \brief Return the entry count stored in `numEntries` for this range.
/// Per the struct documentation this may be smaller than `end - start` for filtered RDataFrames.
49 std::size_t GetNumEntries() const { return numEntries; }
/// \brief Overwrite the stored entry count (`numEntries`) with \p num.
/// `start` and `end` are left untouched.
50 void SetNumEntries(std::size_t num) { numEntries = num; }
51};
52
53/**
54 * \class ROOT::Experimental::Internal::ML::RClusterLoaderFunctor
55 * \brief Functor invoked by RDataFrame::Foreach to fill one row of an RFlat2DMatrix.
56 *
57 */
58
59template <typename... ColTypes>
61 std::size_t fOffset{};
62 std::size_t fVecSizeIdx{};
63 float fVecPadding{};
64 std::vector<std::size_t> fMaxVecSizes{};
66
67 std::size_t fNumChunkCols;
68
69 int fI;
71
72 //////////////////////////////////////////////////////////////////////////
73 /// \brief Copy the content of a column into the current tensor when the column consists of vectors
75 void AssignToTensor(const T &vec, int i, int numColumns)
76 {
77 std::size_t max_vec_size = fMaxVecSizes[fVecSizeIdx++];
78 std::size_t vec_size = vec.size();
79
80 float *dst = fChunkTensor.GetData() + fOffset + numColumns * i;
81 if (vec_size < max_vec_size) // Padding vector column to max_vec_size with fVecPadding
82 {
83 std::copy(vec.begin(), vec.end(), dst);
84 std::fill(dst + vec_size, dst + max_vec_size, fVecPadding);
85 } else // Copy only max_vec_size length from vector column
86 {
87 std::copy(vec.begin(), vec.begin() + max_vec_size, dst);
88 }
90 }
91
92 //////////////////////////////////////////////////////////////////////////
93 /// \brief Copy the content of a column into the current tensor when the column consists of scalar values
95 void AssignToTensor(const T &val, int i, int numColumns)
96 {
98 fOffset++;
99 }
100
101public:
103 const std::vector<std::size_t> &maxVecSizes, float vecPadding, int i,
104 std::size_t rowOffset = 0)
108 fI(i),
111 {
112 }
113
114 void operator()(const ColTypes &...cols)
115 {
116 fVecSizeIdx = 0;
118 }
119};
120
121/**
122 * \class ROOT::Experimental::Internal::ML::RClusterLoader
123 * \brief Loads TTree/RNTuple clusters from one or more RDataFrames into RFlat2DMatrix
124 * buffers for ML training and validation.
125 *
126 * ### Overview
127 * At construction the loader scans the cluster boundaries of every
128 * provided RDataFrame and stores them as a flat list of \ref RClusterRange objects.
129 * SplitDataset() then partitions those ranges into training and validation sets according to \p validationSplit.
130 *
131 * ### The split strategy depends on whether shuffling is enabled or not
132 * - **Unshuffled**: one cut is made so that the first `(1 - validationSplit)`
133 * fraction of entries goes to training. At most one cluster is split at the boundary.
134 * - **Shuffled**: each cluster is split proportionally (according to `validationSplit`)
135 * so both sets draw entries from every part of the dataset. ShuffleTrainingClusters()
136 * and ShuffleValidationClusters() re-order the cluster lists at the start of each epoch.
137 * A second shuffling step, at the entries level, happens inside LoadTrainingClusterInto()
138 * and LoadValidationClusterInto() when loading the data into the tensors.
139 *
140 * ### Filtered RDataFrames
141 * When any RDataFrame carries a filter, the true entry count is not known
142 * until the computation graph is executed. In this case SplitDataset() is a
143 * no-op and the split is discovered lazily inside LoadTrainingClusterInto()
144 * during the first epoch.
145 * After the first epoch FinaliseSplitDiscovery() marks the split as stable and
146 * all subsequent epochs use the same pre-computed ranges.
147 */
148template <typename... Args>
150private:
151 std::vector<ROOT::RDF::RNode> &fRdfs;
152 std::vector<std::size_t> fRdfSizes;
153 std::vector<std::string> fCols;
154 std::vector<std::size_t> fVecSizes;
158 std::size_t fSetSeed;
159
160 std::size_t fNumCols;
161 std::size_t fSumVecSizes;
162 std::size_t fNumChunkCols;
163
164 std::vector<RClusterRange> fAllClusters;
165 std::vector<RClusterRange> fTrainingClusters;
166 std::vector<RClusterRange> fValidationClusters;
167
168 std::size_t fTotalEntries{0};
169 std::size_t fNumTrainingEntries{0};
170 std::size_t fNumValidationEntries{0};
171
172 bool fIsFiltered{false};
173 bool fSplitDiscovered{false};
175
176public:
177 RClusterLoader(std::vector<ROOT::RDF::RNode> &rdfs, const std::vector<std::string> &cols,
178 const std::vector<std::size_t> &vecSizes, float vecPadding, float validationSplit, bool shuffle,
179 std::size_t setSeed)
180 : fRdfs(rdfs),
181 fCols(cols),
187 {
188 fNumCols = fCols.size();
189 fSumVecSizes = std::accumulate(fVecSizes.begin(), fVecSizes.end(), 0UL);
191
192 for (auto &rdf : fRdfs) {
193 // TODO(staider) We need a better API in RDF to detect generically whether there's a filter or not
194 if (!rdf.GetFilterNames().empty()) {
195 fIsFiltered = true;
196 break;
197 }
198 }
199
200 fRdfSizes.resize(fRdfs.size(), 0);
201
202 // scan cluster boundaries across files
203 // TODO(staider) Add progress bar to inform the user about this potentially long operation
204 for (std::size_t rdfIdx = 0; rdfIdx < fRdfs.size(); ++rdfIdx) {
206 fAllClusters.push_back({rdfIdx, r.first, r.second});
207 auto numEntries = r.second - r.first;
208 fRdfSizes[rdfIdx] += numEntries;
209 fTotalEntries += numEntries;
210 }
211 }
212 }
213
214 //////////////////////////////////////////////////////////////////////////
215 /// \brief Distribute the clusters into training and validation datasets.
216 /// No-op for filtered RDataFrames, the split is discovered lazily during the first epoch.
218 {
219 if (fAllClusters.empty())
220 throw std::runtime_error("RClusterLoader::SplitDataset: no clusters found.");
221
222 if (fIsFiltered) {
223 return;
224 }
225
226 if (fShuffle) {
227 // --- Shuffled path
228 // Every cluster contributes a prefix to training and a suffix to validation.
229 // Cost: Each cluster is read twice per epoch, only when validation split is more than 0.
230 // TODO(staider) Switch between prefix or suffix for validation randomly per cluster
231 for (const RClusterRange &c : fAllClusters) {
232 const std::size_t sz = c.GetNumEntries();
233 const std::size_t trainSz = static_cast<std::size_t>((1.0f - fValidationSplit) * sz);
234 const std::size_t valSz = sz - trainSz;
235
236 if (trainSz > 0) {
237 fTrainingClusters.push_back({c.rdfIdx, c.start, c.start + static_cast<std::uint64_t>(trainSz)});
239 }
240 if (valSz > 0) {
241 fValidationClusters.push_back({c.rdfIdx, c.start + static_cast<std::uint64_t>(trainSz), c.end});
243 }
244 }
245 } else {
246 // --- Unshuffled path
247 // Contiguous split: first (1 - validationSplit) fraction of entries go to
248 // training, the remainder to validation. At most one cluster is split at
249 // the boundary.
250 const std::size_t targetTraining = fTotalEntries - static_cast<std::size_t>(fValidationSplit * fTotalEntries);
251
252 std::size_t accumulated = 0;
253 std::size_t splitIdx = 0;
254 for (; splitIdx < fAllClusters.size(); ++splitIdx) {
255 const std::size_t sz = fAllClusters[splitIdx].GetNumEntries();
256 if (accumulated + sz > targetTraining) {
257 break;
258 }
259 accumulated += sz;
260 }
261
262 // Assign whole train/val clusters
263 fTrainingClusters.assign(fAllClusters.begin(), fAllClusters.begin() + splitIdx);
265
267 // Split the boundary cluster
269 const std::uint64_t splitPoint = boundary.start + static_cast<std::uint64_t>(targetTraining - accumulated);
270
271 fTrainingClusters.push_back({boundary.rdfIdx, boundary.start, splitPoint});
272 fValidationClusters.push_back({boundary.rdfIdx, splitPoint, boundary.end});
274 fAllClusters.end());
275
277 } else {
278 fValidationClusters.assign(fAllClusters.begin() + splitIdx, fAllClusters.end());
279 }
280
282 }
283
284 if (fTrainingClusters.empty())
285 throw std::runtime_error("RClusterLoader::SplitDataset: no entries for training after split. "
286 "Reduce validation_split.");
287
288 if (fValidationSplit > 0.0f && fValidationClusters.empty())
289 throw std::runtime_error("RClusterLoader::SplitDataset: no entries for validation after split. "
290 "Increase validation_split.");
291 }
292
293 //////////////////////////////////////////////////////////////////////////
294 /// \brief Re-order training clusters for the upcoming epoch
296 {
297 if (!fShuffle) {
298 return;
299 }
300
301 std::mt19937 g(fSetSeed == 0 ? std::random_device{}() : fSetSeed ^ epochIdx);
302 std::shuffle(fTrainingClusters.begin(), fTrainingClusters.end(), g);
303 }
304
305 //////////////////////////////////////////////////////////////////////////
306 /// \brief Re-order validation clusters for the upcoming epoch
308 {
309 if (!fShuffle) {
310 return;
311 }
312 std::mt19937 g(fSetSeed == 0 ? std::random_device{}() : fSetSeed ^ epochIdx);
313 std::shuffle(fValidationClusters.begin(), fValidationClusters.end(), g);
314 }
315
316 void LoadClusterInto(RFlat2DMatrix &dest, std::size_t rdfIdx, std::uint64_t startRow, std::uint64_t endRow,
317 std::size_t rowOffset = 0)
318 {
319 ROOT::RDF::RNode &rdf = fRdfs[rdfIdx];
322 rdf.Foreach(func, fCols);
324 }
325
326 //////////////////////////////////////////////////////////////////////////
327 /// \brief Load one training cluster and return the number of rows written.
328 ///
329 /// **Unfiltered**: delegates directly to `LoadClusterInto()`
330 /// **Filtered**, epoch 1 (!fSplitDiscovered):
331 /// - On the first call, Count() is called across all RDFs to obtain
332 /// the total filtered entry count, fNumTrainingEntries and
333 /// fNumValidationEntries are set as targets.
334 /// - A single Foreach on the full raw cluster range loads data and captures
335 /// rdfentry_ simultaneously. The real train/val boundary is computed from
336 /// the accumulated filtered count vs the target, then the train sub-range
337 /// is pushed to fTrainingClusters and the val sub-range to fValidationClusters.
338 /// - Only the train rows are written into \p dest.
339 /// - All subsequent epochs: delegates directly to `LoadClusterInto()`
340 std::size_t LoadTrainingClusterInto(RFlat2DMatrix &dest, std::size_t rdfIdx, std::uint64_t startRow,
341 std::uint64_t endRow, std::size_t rowOffset = 0)
342 {
344 // First call: discover total filtered count and set split targets.
346 std::vector<ROOT::RDF::RResultPtr<ULong64_t>> counts;
347 counts.reserve(fRdfs.size());
348 for (auto &rdf : fRdfs) {
349 counts.push_back(rdf.Count());
350 }
352
353 std::size_t totalFiltered = 0;
354 for (auto &c : counts) {
355 totalFiltered += c.GetValue();
356 }
357 fNumTrainingEntries = static_cast<std::size_t>(totalFiltered * (1.0f - fValidationSplit));
359 }
360
361 ROOT::RDF::RNode &rdf = fRdfs[rdfIdx];
362
363 // Fill data and collect raw entry indices that pass the filter
364 std::vector<ULong64_t> rdfEntries;
365 rdfEntries.reserve(endRow - startRow);
366
369
370 std::vector<std::string> colsWithEntry;
371 colsWithEntry.reserve(fCols.size() + 1);
372 colsWithEntry.push_back("rdfentry_");
373 colsWithEntry.insert(colsWithEntry.end(), fCols.begin(), fCols.end());
374
375 rdf.Foreach(
376 [&](ULong64_t entry, const Args &...cols) {
377 rdfEntries.push_back(entry);
378 loader(cols...);
379 },
381
383
384 const std::size_t totalFiltered = rdfEntries.size();
385 if (totalFiltered == 0) {
386 return 0;
387 }
388 std::sort(rdfEntries.begin(), rdfEntries.end());
389
391 const std::size_t trainCount =
392 std::min(static_cast<std::size_t>(totalFiltered * (1.0f - fValidationSplit)), trainRemaining);
393 const std::size_t valCount = totalFiltered - trainCount;
394
395 // The boundary is the raw entry index of the first entry assigned to validation.
396 // Stable across epochs since the same filter always produces the same ordered entries.
397 const std::uint64_t boundary = (valCount > 0) ? rdfEntries[trainCount] : endRow;
398
399 if (trainCount > 0)
400 fTrainingClusters.push_back({rdfIdx, startRow, boundary, trainCount});
401 if (valCount > 0)
402 fValidationClusters.push_back({rdfIdx, boundary, endRow, valCount});
403
405 return trainCount;
406 }
407
409 return endRow - startRow;
410 }
411
412 //////////////////////////////////////////////////////////////////////////
413 /// \brief Load one validation cluster into \p dest starting at \p rowOffset
414 void LoadValidationClusterInto(RFlat2DMatrix &dest, std::size_t rdfIdx, std::uint64_t startRow, std::uint64_t endRow,
415 std::size_t rowOffset = 0)
416 {
418 }
419
420 //////////////////////////////////////////////////////////////////////////
421 /// \brief Mark the train/val split as finalised after the first epoch
423 {
424 if (fIsFiltered)
425 fSplitDiscovered = true;
426 }
427
/// \brief Tell whether the train/validation split is already fixed.
/// True when no RDataFrame is filtered (`fIsFiltered` is false), or when
/// `fSplitDiscovered` has been set for a filtered dataset.
428 bool IsSplitDiscovered() const { return !fIsFiltered || fSplitDiscovered; }
429
430 //////////////////////////////////////////////////////////////////////////
431 // Accessors
/// \brief Return `fNumTrainingEntries`, the number of entries attributed to the training set.
/// For filtered RDataFrames this is a target derived from the total filtered count on the
/// first training load, so it is 0 until then.
432 std::size_t GetNumTrainingEntries() const { return fNumTrainingEntries; }
/// \brief Return `fNumValidationEntries`, the number of entries attributed to the validation set.
/// NOTE(review): the assignment of this member is not visible in this listing; presumably it is
/// set alongside fNumTrainingEntries - confirm.
433 std::size_t GetNumValidationEntries() const { return fNumValidationEntries; }
/// \brief Return `fNumChunkCols`.
/// NOTE(review): the initialisation of this member is not visible in this listing; presumably it
/// is the number of float columns per row of the destination matrix - confirm.
434 std::size_t GetNumChunkCols() const { return fNumChunkCols; }
435
436 const std::vector<RClusterRange> &GetTrainingClusters() const
437 {
439 }
/// \brief Read-only access to the list of validation cluster ranges (`fValidationClusters`).
/// The reference is to the loader's own member, so it reflects later re-ordering or additions.
440 const std::vector<RClusterRange> &GetValidationClusters() const { return fValidationClusters; }
441
/// \brief Number of clusters to iterate over for training.
/// While a filtered dataset's split has not been discovered yet, the size of `fAllClusters`
/// is reported instead of `fTrainingClusters`, which is only filled in as clusters are loaded
/// during that first pass. Otherwise the size of `fTrainingClusters` is returned.
442 std::size_t GetNumTrainingClusters() const
443 {
444 return (fIsFiltered && !fSplitDiscovered) ? fAllClusters.size() : fTrainingClusters.size();
445 }
/// \brief Return the current number of validation cluster ranges (size of `fValidationClusters`).
446 std::size_t GetNumValidationClusters() const { return fValidationClusters.size(); }
/// \brief Return the number of cluster ranges found when scanning all RDataFrames (size of `fAllClusters`).
/// NOTE(review): the name looks like a typo of "GetNumTotalClusters"; left unchanged here because
/// callers outside this file may depend on it.
447 std::size_t GetNmTotalClusters() const { return fAllClusters.size(); }
448};
449
450} // namespace ROOT::Experimental::Internal::ML
451#endif // ROOT_INTERNAL_ML_RCLUSTERLOADER
#define c(i)
Definition RSha256.hxx:101
#define g(i)
Definition RSha256.hxx:105
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t dest
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Functor invoked by RDataFrame::Foreach to fill one row of an RFlat2DMatrix.
void AssignToTensor(const T &vec, int i, int numColumns)
Copy the content of a column into the current tensor when the column consists of vectors.
RClusterLoaderFunctor(RFlat2DMatrix &chunkTensor, std::size_t numColumns, const std::vector< std::size_t > &maxVecSizes, float vecPadding, int i, std::size_t rowOffset=0)
void AssignToTensor(const T &val, int i, int numColumns)
Copy the content of a column into the current tensor when the column consists of scalar values.
Loads TTree/RNTuple clusters from one or more RDataFrames into RFlat2DMatrix buffers for ML training ...
void ShuffleTrainingClusters(std::size_t epochIdx)
Re-order training clusters for the upcoming epoch.
void FinaliseSplitDiscovery()
Mark the train/val split as finalised after the first epoch.
void LoadClusterInto(RFlat2DMatrix &dest, std::size_t rdfIdx, std::uint64_t startRow, std::uint64_t endRow, std::size_t rowOffset=0)
void ShuffleValidationClusters(std::size_t epochIdx)
Re-order validation clusters for the upcoming epoch.
std::size_t LoadTrainingClusterInto(RFlat2DMatrix &dest, std::size_t rdfIdx, std::uint64_t startRow, std::uint64_t endRow, std::size_t rowOffset=0)
Load one training cluster and return the number of rows written.
void SplitDataset()
Distribute the clusters into training and validation datasets. No-op for filtered RDataFrames,...
void LoadValidationClusterInto(RFlat2DMatrix &dest, std::size_t rdfIdx, std::uint64_t startRow, std::uint64_t endRow, std::size_t rowOffset=0)
Load one validation cluster into dest starting at rowOffset.
const std::vector< RClusterRange > & GetTrainingClusters() const
RClusterLoader(std::vector< ROOT::RDF::RNode > &rdfs, const std::vector< std::string > &cols, const std::vector< std::size_t > &vecSizes, float vecPadding, float validationSplit, bool shuffle, std::size_t setSeed)
const std::vector< RClusterRange > & GetValidationClusters() const
The public interface to the RDataFrame federation of classes.
const_iterator begin() const
const_iterator end() const
std::vector< std::pair< std::uint64_t, std::uint64_t > > GetDatasetGlobalClusterBoundaries(const RNode &node)
Retrieve the cluster boundaries for each cluster in the dataset, across files, with a global offset.
void ChangeBeginAndEndEntries(const RNode &node, Long64_t begin, Long64_t end)
unsigned int RunGraphs(std::vector< RResultHandle > handles)
Run the event loops of multiple RDataFrames concurrently.
Describes a contiguous range of entries within a single RDataFrame, corresponding to one TTree/RNTupl...
Wrapper around ROOT::RVec<float> representing a 2D matrix.