Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBatchLoader.cxx
Go to the documentation of this file.
1
3
4#include <algorithm>
5#include <numeric>
6#include <utility>
7
9
/// \brief Construct a batch loader that cuts chunks into batches of batchSize rows.
/// \param[in] batchSize Number of rows in a full batch (stored in fBatchSize)
/// \param[in] cols Column names; their count enters the computation of fNumColumns
/// \param[in] sharedMutex Mutex bound to fLock and locked around queue and flag updates
/// \param[in] sharedCV Condition variable bound to fCV and notified whenever the loader state changes
/// \param[in] vecSizes Sizes that are summed into fSumVecSizes (presumably one per vector column -- TODO confirm)
/// \param[in] numEntries Stored in fNumEntries; its use is not visible in this listing
/// \param[in] dropRemainder Stored in fDropRemainder; selects between two branches whose bodies are not visible here
///
/// NOTE(review): this listing skips its own lines 25, 28-29, 34 and 36, so several statements of the
/// constructor body are absent below.
10RBatchLoader::RBatchLoader(std::size_t batchSize, const std::vector<std::string> &cols, std::mutex &sharedMutex,
11 std::condition_variable &sharedCV, const std::vector<std::size_t> &vecSizes,
12 std::size_t numEntries, bool dropRemainder)
13 : fBatchSize(batchSize),
14 fCols(cols),
15 fLock(sharedMutex),
16 fCV(sharedCV),
17 fVecSizes(vecSizes),
18 fNumEntries(numEntries),
19 fDropRemainder(dropRemainder)
20{
// NOTE(review): the int literal 0 makes std::accumulate add up in int although the elements are
// std::size_t; passing std::size_t{0} would keep the sum in the element type.
21 fSumVecSizes = std::accumulate(fVecSizes.begin(), fVecSizes.end(), 0);
// Total column count: the number of names in fCols, plus the summed sizes, minus one per entry of
// fVecSizes (presumably because each vector column is already counted once in fCols -- TODO confirm).
22 fNumColumns = fCols.size() + fSumVecSizes - fVecSizes.size();
23
// NOTE(review): the body of this zero-batch-size branch (listing line 25) is missing here.
24 if (fBatchSize == 0) {
26 }
27
// NOTE(review): listing lines 28-29 are missing; fLeftoverBatchSize, read just below, is presumably set there.
30
// At most one leftover batch exists: none when fLeftoverBatchSize is zero, one otherwise.
31 std::size_t numLeftoverBatches = fLeftoverBatchSize == 0 ? 0 : 1;
32
// NOTE(review): the bodies of both branches (listing lines 34 and 36) are missing here.
33 if (fDropRemainder) {
35 } else {
37 }
38
// Both leftover buffers start out as default-constructed matrices.
39 fPrimaryLeftoverBatch = std::make_unique<RFlat2DMatrix>();
40 fSecondaryLeftoverBatch = std::make_unique<RFlat2DMatrix>();
41}
42
43/// \brief Activate the batchloader. This means that batches can be created and loaded.
45{
46 {
47 std::lock_guard<std::mutex> lock(fLock);
48 if (fIsActive)
49 return;
50 fIsActive = true;
51 fProducerDone = false;
52 }
53
54 fCV.notify_all();
55}
56
57/// \brief DeActivate the batchloader. This means that no more batches are created.
58/// Batches can still be returned if they are already loaded.
60{
61 {
62 std::lock_guard<std::mutex> lock(fLock);
63 if (!fIsActive)
64 return;
65 fIsActive = false;
66 }
67
68 fCV.notify_all();
69}
70
71/// \brief Reset the batchloader state.
73{
74 {
75 std::lock_guard<std::mutex> lock(fLock);
76
77 while (!fBatchQueue.empty()) {
78 fBatchQueue.pop();
79 }
80
81 fCurrentBatch.reset();
82 fPrimaryLeftoverBatch = std::make_unique<RFlat2DMatrix>();
83 fSecondaryLeftoverBatch = std::make_unique<RFlat2DMatrix>();
84 }
85
86 fCV.notify_all();
87}
88
89/// \brief Signal that the producer has finished pushing all batches for this epoch.
91{
92 fProducerDone = true;
93 fCV.notify_all();
94}
95
96/// \brief Return a batch of data as a unique pointer.
97/// After the batch has been processed, it should be destroyed.
98/// \param[in] chunkTensor Tensor with the data from the chunk
99/// \param[in] idxs Index of batch in the chunk
100/// \return Batch
101std::unique_ptr<RFlat2DMatrix> RBatchLoader::CreateBatch(RFlat2DMatrix &chunkTensor, std::size_t idxs)
102{
103 auto batch = std::make_unique<RFlat2DMatrix>(fBatchSize, fNumColumns);
104 std::copy(chunkTensor.GetData() + (idxs * fBatchSize * fNumColumns),
105 chunkTensor.GetData() + ((idxs + 1) * fBatchSize * fNumColumns), batch->GetData());
106
107 return batch;
108}
109
110/// \brief Loading the batch from the queue.
111/// \return Batch
113{
114 std::unique_lock<std::mutex> lock(fLock);
115
116 // Wait until:
117 // - there is data in the queue
118 // - or producer declares "done"
119 // - or we are deactivated
120 fCV.wait(lock, [&] { return !fBatchQueue.empty() || fProducerDone || !fIsActive; });
121
122 if (fBatchQueue.empty()) {
123 // producer done and no queued data -> end-of-epoch signal
124 fCurrentBatch = std::make_unique<RFlat2DMatrix>();
125 return *fCurrentBatch;
126 }
127
128 fCurrentBatch = std::move(fBatchQueue.front());
129 fBatchQueue.pop();
130 // Notify the loading thread that the queue has drained
131 fCV.notify_all();
132
133 return *fCurrentBatch;
134}
135
// NOTE(review): this listing is incomplete. The definition header (listed in the member index as
// void CreateBatches(RFlat2DMatrix &chunkTensor, bool isLastBatch)) and listing lines 155, 157, 162,
// 165-166, 168, 171, 178, 186, 188, 191, 193, 202 and 209 are missing, so several statements and
// conditions below are cut off. Comments describe only what is visible.
136/// \brief Create the batches from a chunk and add them to the queue.
137/// \param[in] chunkTensor Tensor with the data from the chunk
138/// \param[in] isLastBatch If true, the rows remaining in fPrimaryLeftoverBatch are handled at the end of the call
140{
141 std::size_t chunkSize = chunkTensor.GetRows();
142 std::size_t numCols = chunkTensor.GetCols();
143 std::size_t numBatches = chunkSize / fBatchSize;
144 std::size_t leftoverBatchSize = chunkSize % fBatchSize;
145
146 // collect the batches locally first; they are moved into the shared queue at the end
147 std::vector<std::unique_ptr<RFlat2DMatrix>> batches;
148
149 // fill the full batches from the chunk into the vector
150 for (std::size_t i = 0; i < numBatches; i++) {
151 batches.emplace_back(CreateBatch(chunkTensor, i));
152 }
153
154 // copy the remaining entries from the chunk (those after the last full batch) into a leftover batch
156 std::copy(chunkTensor.GetData() + (numBatches * fBatchSize * numCols),
158 LeftoverBatch.GetData());
159
160 // calculate how many empty slots are left in fPrimaryLeftoverBatch
161 std::size_t PrimaryLeftoverSize = fPrimaryLeftoverBatch->GetRows();
163
164 // copy LeftoverBatch to end of fPrimaryLeftoverBatch
167 std::copy(LeftoverBatch.GetData(), LeftoverBatch.GetData() + (leftoverBatchSize * fNumColumns),
169
170 // copy the full fPrimaryLeftoverBatch into a new batch and add it to the batch vector
172 auto copy = std::make_unique<RFlat2DMatrix>(fBatchSize, fNumColumns);
173 std::copy(fPrimaryLeftoverBatch->GetData(), fPrimaryLeftoverBatch->GetData() + (fBatchSize * fNumColumns),
174 copy->GetData());
175 batches.emplace_back(std::move(copy));
176
177 // reset fPrimaryLeftoverBatch and fSecondaryLeftoverBatch
179 fSecondaryLeftoverBatch = std::make_unique<RFlat2DMatrix>();
180 }
181 }
182
183 // copy LeftoverBatch to both fPrimaryLeftoverBatch and fSecondaryLeftoverBatch
184 else if (emptySlots < leftoverBatchSize) {
185 // copy the first part of LeftoverBatch to end of fPrimaryLeftoverBatch
187 std::copy(LeftoverBatch.GetData(), LeftoverBatch.GetData() + (emptySlots * numCols),
189
190 // copy the last part of LeftoverBatch to the end of fSecondaryLeftoverBatch
192 std::copy(LeftoverBatch.GetData() + (emptySlots * numCols),
194
195 // add fPrimaryLeftoverBatch to the batch vector
196 auto copy = std::make_unique<RFlat2DMatrix>(fBatchSize, fNumColumns);
197 std::copy(fPrimaryLeftoverBatch->GetData(), fPrimaryLeftoverBatch->GetData() + (fBatchSize * fNumColumns),
198 copy->GetData());
199 batches.emplace_back(std::move(copy));
200
201 // exchange fPrimaryLeftoverBatch and fSecondaryLeftoverBatch
203 // reset fSecondaryLeftoverBatch
204 fSecondaryLeftoverBatch = std::make_unique<RFlat2DMatrix>();
205 }
206
207 // on the last call, copy fLeftoverBatchSize rows of fPrimaryLeftoverBatch into a final batch
// NOTE(review): the inner condition guarding this copy (listing line 209) is missing here.
208 if (isLastBatch) {
210 auto copy = std::make_unique<RFlat2DMatrix>(fLeftoverBatchSize, fNumColumns);
211 std::copy(fPrimaryLeftoverBatch->GetData(),
212 fPrimaryLeftoverBatch->GetData() + (fLeftoverBatchSize * fNumColumns), copy->GetData());
213 batches.emplace_back(std::move(copy));
214 }
215
// both leftover buffers are replaced by default-constructed matrices once the last call is handled
216 fPrimaryLeftoverBatch = std::make_unique<RFlat2DMatrix>();
217 fSecondaryLeftoverBatch = std::make_unique<RFlat2DMatrix>();
218 }
219
// move all collected batches into the shared queue under the mutex
220 {
221 std::lock_guard<std::mutex> lock(fLock);
222 for (auto &batch : batches) {
223 fBatchQueue.push(std::move(batch));
224 }
225 }
226
// wake the threads waiting on the shared condition variable now that batches are queued
227 fCV.notify_all();
228}
229
230} // namespace ROOT::Experimental::Internal::ML
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
RFlat2DMatrix GetBatch()
Loading the batch from the queue.
std::unique_ptr< RFlat2DMatrix > fSecondaryLeftoverBatch
std::queue< std::unique_ptr< RFlat2DMatrix > > fBatchQueue
void Reset()
Reset the batchloader state.
void Activate()
Activate the batchloader. This means that batches can be created and loaded.
void MarkProducerDone()
Signal that the producer has finished pushing all batches for this epoch.
void CreateBatches(RFlat2DMatrix &chunkTensor, bool isLastBatch)
Creating the batches from a chunk and add them to the queue.
void DeActivate()
DeActivate the batchloader.
std::unique_ptr< RFlat2DMatrix > fPrimaryLeftoverBatch
std::unique_ptr< RFlat2DMatrix > fCurrentBatch
RBatchLoader(std::size_t batchSize, const std::vector< std::string > &cols, std::mutex &sharedMutex, std::condition_variable &sharedCV, const std::vector< std::size_t > &vecSizes={}, std::size_t numEntries=0, bool dropRemainder=false)
std::unique_ptr< RFlat2DMatrix > CreateBatch(RFlat2DMatrix &chunkTensor, std::size_t idxs)
Return a batch of data as a unique pointer.
Wrapper around ROOT::RVec<float> representing a 2D matrix.