Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBatchLoader.cxx
Go to the documentation of this file.
1
3
4#include <algorithm>
5#include <numeric>
6#include <utility>
7
9
10RBatchLoader::RBatchLoader(std::size_t batchSize, const std::vector<std::string> &cols, std::mutex &sharedMutex,
11 std::condition_variable &sharedCV, const std::vector<std::size_t> &vecSizes,
12 std::size_t numEntries, bool dropRemainder)
13 : fBatchSize(batchSize),
14 fCols(cols),
15 fLock(sharedMutex),
16 fCV(sharedCV),
17 fVecSizes(vecSizes),
18 fNumEntries(numEntries),
19 fDropRemainder(dropRemainder)
20{
21 fSumVecSizes = std::accumulate(fVecSizes.begin(), fVecSizes.end(), 0);
22 fNumColumns = fCols.size() + fSumVecSizes - fVecSizes.size();
23
24 RecalculateBatchCounts(numEntries);
25
26 fPrimaryLeftoverBatch = std::make_unique<RFlat2DMatrix>();
27 fSecondaryLeftoverBatch = std::make_unique<RFlat2DMatrix>();
28}
29
30/// \brief Activate the batchloader. This means that batches can be created and loaded.
32{
33 {
34 std::lock_guard<std::mutex> lock(fLock);
35 if (fIsActive)
36 return;
37 fIsActive = true;
38 fProducerDone = false;
39 }
40
41 fCV.notify_all();
42}
43
44/// \brief DeActivate the batchloader. This means that no more batches are created.
45/// Batches can still be returned if they are already loaded.
47{
48 {
49 std::lock_guard<std::mutex> lock(fLock);
50 if (!fIsActive)
51 return;
52 fIsActive = false;
53 }
54
55 fCV.notify_all();
56}
57
58/// \brief Reset the batchloader state.
60{
61 {
62 std::lock_guard<std::mutex> lock(fLock);
63
64 while (!fBatchQueue.empty()) {
65 fBatchQueue.pop();
66 }
67
68 fCurrentBatch.reset();
69 fPrimaryLeftoverBatch = std::make_unique<RFlat2DMatrix>();
70 fSecondaryLeftoverBatch = std::make_unique<RFlat2DMatrix>();
71 }
72
73 fCV.notify_all();
74}
75
76/// \brief Signal that the producer has finished pushing all batches for this epoch.
78{
79 fProducerDone = true;
80 fCV.notify_all();
81}
82
83/// \brief Return a batch of data as a unique pointer.
84/// After the batch has been processed, it should be destroyed.
85/// \param[in] chunkTensor Tensor with the data from the chunk
86/// \param[in] idxs Index of batch in the chunk
87/// \return Batch
88std::unique_ptr<RFlat2DMatrix> RBatchLoader::CreateBatch(RFlat2DMatrix &chunkTensor, std::size_t idxs)
89{
90 auto batch = std::make_unique<RFlat2DMatrix>(fBatchSize, fNumColumns);
91 std::copy(chunkTensor.GetData() + (idxs * fBatchSize * fNumColumns),
92 chunkTensor.GetData() + ((idxs + 1) * fBatchSize * fNumColumns), batch->GetData());
93
94 return batch;
95}
96
97/// \brief Loading the batch from the queue.
98/// \return Batch
100{
101 std::unique_lock<std::mutex> lock(fLock);
102
103 // Wait until:
104 // - there is data in the queue
105 // - or producer declares "done"
106 // - or we are deactivated
107 fCV.wait(lock, [&] { return !fBatchQueue.empty() || fProducerDone || !fIsActive; });
108
109 if (fBatchQueue.empty()) {
110 // producer done and no queued data -> end-of-epoch signal
111 fCurrentBatch = std::make_unique<RFlat2DMatrix>();
112 return *fCurrentBatch;
113 }
114
115 fCurrentBatch = std::move(fBatchQueue.front());
116 fBatchQueue.pop();
117 // Notify the loading thread that the queue has drained
118 fCV.notify_all();
119
120 return *fCurrentBatch;
121}
122
// NOTE(review): this listing of CreateBatches is incomplete. The embedded source line numbers skip
// 126, 142, 144, 149, 152-153, 155, 158, 165, 173, 175, 178, 180, 189 and 196, so the function
// signature (listed in the member index as
// `void CreateBatches(RFlat2DMatrix &chunkTensor, bool isLastBatch)`) and several statements, among
// them the declarations of `LeftoverBatch` and `emptySlots`, are not visible. Restore them from the
// original file before editing the logic.
//
// Visible behaviour: every full batch of the chunk is built with CreateBatch(), the remaining rows
// are merged into the leftover buffers, and all resulting batches are pushed onto fBatchQueue under
// fLock before the waiters on fCV are notified.
123/// \brief Creating the batches from a chunk and add them to the queue.
124/// \param[in] chunkTensor Tensor with the data from the chunk
125/// \param[in] isLastBatch Check if the batch in the chunk is the last one
127{
128 std::size_t chunkSize = chunkTensor.GetRows();
129 std::size_t numCols = chunkTensor.GetCols();
130 std::size_t numBatches = chunkSize / fBatchSize;
131 std::size_t leftoverBatchSize = chunkSize % fBatchSize;
132
133 // create a vector of batches
134 std::vector<std::unique_ptr<RFlat2DMatrix>> batches;
135
136 // fill the full batches from the chunk into a vector
137 for (std::size_t i = 0; i < numBatches; i++) {
138 batches.emplace_back(CreateBatch(chunkTensor, i));
139 }
140
141 // copy the remaining entries from the chunk into a leftover batch
143 std::copy(chunkTensor.GetData() + (numBatches * fBatchSize * numCols),
145 LeftoverBatch.GetData());
146
147 // calculate how many empty slots are left in fPrimaryLeftoverBatch
148 std::size_t PrimaryLeftoverSize = fPrimaryLeftoverBatch->GetRows();
150
151 // copy LeftoverBatch to end of fPrimaryLeftoverBatch
// NOTE(review): this copy sizes its range with fNumColumns while the neighbouring copies use
// numCols (chunkTensor.GetCols()); presumably both are equal - confirm.
154 std::copy(LeftoverBatch.GetData(), LeftoverBatch.GetData() + (leftoverBatchSize * fNumColumns),
156
157 // copy LeftoverBatch to end of fPrimaryLeftoverBatch and add it to the batch
159 auto copy = std::make_unique<RFlat2DMatrix>(fBatchSize, fNumColumns);
160 std::copy(fPrimaryLeftoverBatch->GetData(), fPrimaryLeftoverBatch->GetData() + (fBatchSize * fNumColumns),
161 copy->GetData());
162 batches.emplace_back(std::move(copy));
163
164 // reset fPrimaryLeftoverBatch and fSecondaryLeftoverBatch
166 fSecondaryLeftoverBatch = std::make_unique<RFlat2DMatrix>();
167 }
168 }
169
170 // copy LeftoverBatch to both fPrimaryLeftoverBatch and fSecondaryLeftoverBatch
171 else if (emptySlots < leftoverBatchSize) {
172 // copy the first part of LeftoverBatch to end of fPrimaryLeftoverTrainingBatch
174 std::copy(LeftoverBatch.GetData(), LeftoverBatch.GetData() + (emptySlots * numCols),
176
177 // copy the last part of LeftoverBatch to the end of fSecondaryLeftoverBatch
179 std::copy(LeftoverBatch.GetData() + (emptySlots * numCols),
181
182 // add fPrimaryLeftoverBatch to the batch vector
183 auto copy = std::make_unique<RFlat2DMatrix>(fBatchSize, fNumColumns);
184 std::copy(fPrimaryLeftoverBatch->GetData(), fPrimaryLeftoverBatch->GetData() + (fBatchSize * fNumColumns),
185 copy->GetData());
186 batches.emplace_back(std::move(copy));
187
188 // exchange fPrimaryLeftoverBatch and fSecondaryLeftoverBatch
190 // reset fSecondaryLeftoverTrainingBatch
191 fSecondaryLeftoverBatch = std::make_unique<RFlat2DMatrix>();
192 }
193
194 // copy the content of fPrimaryLeftoverBatch to the leftover batch from the chunk
// On the last batch a final batch of fLeftoverBatchSize rows is emitted from fPrimaryLeftoverBatch
// and both leftover buffers are cleared. NOTE(review): the guard on listing line 196 is not visible.
195 if (isLastBatch) {
197 auto copy = std::make_unique<RFlat2DMatrix>(fLeftoverBatchSize, fNumColumns);
198 std::copy(fPrimaryLeftoverBatch->GetData(),
199 fPrimaryLeftoverBatch->GetData() + (fLeftoverBatchSize * fNumColumns), copy->GetData());
200 batches.emplace_back(std::move(copy));
201 }
202
203 fPrimaryLeftoverBatch = std::make_unique<RFlat2DMatrix>();
204 fSecondaryLeftoverBatch = std::make_unique<RFlat2DMatrix>();
205 }
206
// The batches were assembled without holding the lock; only the hand-over to the queue is locked.
207 {
208 std::lock_guard<std::mutex> lock(fLock);
209 for (auto &batch : batches) {
210 fBatchQueue.push(std::move(batch));
211 }
212 }
213
// Wake the threads waiting on the shared condition variable (GetBatch waits on it).
214 fCV.notify_all();
215}
216
// NOTE(review): this listing of RecalculateBatchCounts is incomplete. The embedded source line
// numbers skip 224, 227-228 and 231, so the statement inside the zero-batch-size branch, the
// statements that presumably derive the full-batch count and fLeftoverBatchSize, and the final
// statement are not visible. Restore them from the original file before editing the logic.
217/// \brief Recalculate batch counts from the given number of entries.
218/// Used at construction or when the true entry count is discovered lazily (filtered case).
219void RBatchLoader::RecalculateBatchCounts(std::size_t numEntries)
220{
221 fNumEntries = numEntries;
222
// A batch size of zero is special-cased; presumably this avoids dividing by zero further down -
// TODO confirm against the missing line.
223 if (fBatchSize == 0) {
225 }
226
229
// At most one extra, partial batch: it is counted only when fLeftoverBatchSize is non-zero.
230 const std::size_t numLeftoverBatches = fLeftoverBatchSize == 0 ? 0 : 1;
232}
233} // namespace ROOT::Experimental::Internal::ML
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
RFlat2DMatrix GetBatch()
Loading the batch from the queue.
std::unique_ptr< RFlat2DMatrix > fSecondaryLeftoverBatch
std::queue< std::unique_ptr< RFlat2DMatrix > > fBatchQueue
void RecalculateBatchCounts(std::size_t numEntries)
Recalculate batch counts from the given number of entries.
void Reset()
Reset the batchloader state.
void Activate()
Activate the batchloader. This means that batches can be created and loaded.
void MarkProducerDone()
Signal that the producer has finished pushing all batches for this epoch.
void CreateBatches(RFlat2DMatrix &chunkTensor, bool isLastBatch)
Creating the batches from a chunk and add them to the queue.
void DeActivate()
DeActivate the batchloader.
std::unique_ptr< RFlat2DMatrix > fPrimaryLeftoverBatch
std::unique_ptr< RFlat2DMatrix > fCurrentBatch
RBatchLoader(std::size_t batchSize, const std::vector< std::string > &cols, std::mutex &sharedMutex, std::condition_variable &sharedCV, const std::vector< std::size_t > &vecSizes={}, std::size_t numEntries=0, bool dropRemainder=false)
std::unique_ptr< RFlat2DMatrix > CreateBatch(RFlat2DMatrix &chunkTensor, std::size_t idxs)
Return a batch of data as a unique pointer.
Wrapper around ROOT::RVec<float> representing a 2D matrix.