Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBatchLoader.hxx
Go to the documentation of this file.
1#ifndef TMVA_RBatchLoader
2#define TMVA_RBatchLoader
3
4#include <iostream>
5#include <vector>
6#include <memory>
7
8// Imports for threading
9#include <queue>
10#include <mutex>
11#include <condition_variable>
12
13#include "TMVA/RTensor.hxx"
14#include "TMVA/Tools.h"
15#include "TRandom3.h"
16
17namespace TMVA {
18namespace Experimental {
19namespace Internal {
20
22private:
23 std::size_t fBatchSize;
24 std::size_t fNumColumns;
25 std::size_t fMaxBatches;
26
27 bool fIsActive = false;
29
30 std::mutex fBatchLock;
31 std::condition_variable fBatchCondition;
32
33 std::queue<std::unique_ptr<TMVA::Experimental::RTensor<float>>> fTrainingBatchQueue;
34 std::vector<std::unique_ptr<TMVA::Experimental::RTensor<float>>> fValidationBatches;
35 std::unique_ptr<TMVA::Experimental::RTensor<float>> fCurrentBatch;
36
37 std::size_t fValidationIdx = 0;
38
40
41public:
42 RBatchLoader(const std::size_t batchSize, const std::size_t numColumns, const std::size_t maxBatches)
43 : fBatchSize(batchSize), fNumColumns(numColumns), fMaxBatches(maxBatches)
44 {
45 }
46
48
49public:
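50 /// \brief Return a reference to the next training batch, blocking until one is queued or the loader is deactivated.
51 /// If no batch remains, an empty tensor is returned. The reference stays valid until the next call to this method.
52 /// \return Training batch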
54 {
55 std::unique_lock<std::mutex> lock(fBatchLock);
56 fBatchCondition.wait(lock, [this]() { return !fTrainingBatchQueue.empty() || !fIsActive; });
57
58 if (fTrainingBatchQueue.empty()) {
59 fCurrentBatch = std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>({0}));
60 return *fCurrentBatch;
61 }
62
63 fCurrentBatch = std::move(fTrainingBatchQueue.front());
65
66 fBatchCondition.notify_all();
67
68 return *fCurrentBatch;
69 }
70
71 /// \brief Returns the next batch of data for validation.
72 /// Ownership of this batch stays with the RBatchLoader,
73 /// because the same validation batches are reused in every epoch.
74 /// \return Validation batch, or an empty tensor once every validation batch has been returned since the last StartValidation()
76 {
77 if (HasValidationData()) {
78 return *fValidationBatches[fValidationIdx++].get();
79 }
80
81 return fEmptyTensor;
82 }
83
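84 /// \brief Checks if there are more training batches available
85 /// \return true if a training batch is queued or the loader is still active (so more batches may still arrive)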
87 {
88 {
89 std::unique_lock<std::mutex> lock(fBatchLock);
90 if (!fTrainingBatchQueue.empty() || fIsActive)
91 return true;
92 }
93
94 return false;
95 }
96
97 /// \brief Checks if there are more validation batches available
98 /// \return true while not every validation batch has been returned since the last StartValidation()
100 {
101 std::unique_lock<std::mutex> lock(fBatchLock);
102 return fValidationIdx < fValidationBatches.size();
103 }
104
105 /// \brief Activate the batchloader so it will accept chunks to batch
106 void Activate()
107 {
108 {
109 std::lock_guard<std::mutex> lock(fBatchLock);
110 fIsActive = true;
111 }
112 fBatchCondition.notify_all();
113 }
114
115 /// \brief Deactivate the batch loader so that no more batches are created, and wake all waiting threads.
116 /// Batches that are already queued can still be retrieved with GetTrainBatch().
118 {
119 {
120 std::lock_guard<std::mutex> lock(fBatchLock);
121 fIsActive = false;
122 }
123 fBatchCondition.notify_all();
124 }
125
126 /// \brief Create a batch filled with the events on the given idx
127 /// \param chunkTensor
128 /// \param idx
129 /// \return
130 std::unique_ptr<TMVA::Experimental::RTensor<float>>
131 CreateBatch(const TMVA::Experimental::RTensor<float> &chunkTensor, const std::vector<std::size_t> idx)
132 {
133 auto batch =
134 std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>({fBatchSize, fNumColumns}));
135
136 for (std::size_t i = 0; i < fBatchSize; i++) {
137 std::copy(chunkTensor.GetData() + (idx[i] * fNumColumns), chunkTensor.GetData() + ((idx[i] + 1) * fNumColumns),
138 batch->GetData() + i * fNumColumns);
139 }
140
141 return batch;
142 }
143
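144 /// \brief Create training batches from the given chunk of data based on the given event indices, and add them to the training queue.
145 /// Blocks until fewer than fMaxBatches batches are queued; returns without batching if the loader is deactivated.
146 /// The eventIndices can be shuffled to randomize the order in each epoch.
147 /// \param chunkTensor chunk data, read as consecutive rows of fNumColumns values
148 /// \param eventIndices rows of chunkTensor to batch; indices left over after the last complete batch are dropped
149 /// \param shuffle if true, shuffle eventIndices before batching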
151 std::vector<std::size_t> eventIndices, const bool shuffle = true)
152 {
153 // Wait until fewer than fMaxBatches batches are queued (or the loader is deactivated) before splitting the next
154 // chunk into batches
155 {
156 std::unique_lock<std::mutex> lock(fBatchLock);
157 fBatchCondition.wait(lock, [this]() { return (fTrainingBatchQueue.size() < fMaxBatches) || !fIsActive; });
158 if (!fIsActive)
159 return;
160 }
161
162 if (shuffle)
163 std::shuffle(eventIndices.begin(), eventIndices.end(), fRng); // Shuffle the order of idx
164
165 std::vector<std::unique_ptr<TMVA::Experimental::RTensor<float>>> batches;
166
167 // Create batches of fBatchSize events until no complete batch is left
168 for (std::size_t start = 0; (start + fBatchSize) <= eventIndices.size(); start += fBatchSize) {
169
170 // Collect the fBatchSize indices belonging to this batch
171 std::vector<std::size_t> idx;
172 for (std::size_t i = start; i < (start + fBatchSize); i++) {
173 idx.push_back(eventIndices[i]);
174 }
175
176 // Fill a batch
177 batches.emplace_back(CreateBatch(chunkTensor, idx));
178 }
179
180 {
181 std::unique_lock<std::mutex> lock(fBatchLock);
182 for (std::size_t i = 0; i < batches.size(); i++) {
183 fTrainingBatchQueue.push(std::move(batches[i]));
184 }
185 }
186
187 fBatchCondition.notify_one();
188 }
189
190 /// \brief Create validation batches from the given chunk based on the given event indices
191 /// Batches are added to the vector of validation batches
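192 /// \param chunkTensor chunk data, read as consecutive rows of fNumColumns values
193 /// \param eventIndices rows of chunkTensor to batch; indices left over after the last complete batch are dropped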
195 const std::vector<std::size_t> eventIndices)
196 {
197 // Create batches of fBatchSize events until no complete batch is left
198 for (std::size_t start = 0; (start + fBatchSize) <= eventIndices.size(); start += fBatchSize) {
199
200 std::vector<std::size_t> idx;
201
202 for (std::size_t i = start; i < (start + fBatchSize); i++) {
203 idx.push_back(eventIndices[i]);
204 }
205
206 {
207 std::unique_lock<std::mutex> lock(fBatchLock);
208 fValidationBatches.emplace_back(CreateBatch(chunkTensor, idx));
209 }
210 }
211 }
212
213 /// \brief Reset the validation process so that the next GetValidationBatch() call starts again from the first validation batch
215 {
216 std::unique_lock<std::mutex> lock(fBatchLock);
217 fValidationIdx = 0;
218 }
219};
220
221} // namespace Internal
222} // namespace Experimental
223} // namespace TMVA
224
225#endif // TMVA_RBatchLoader
void Activate()
Activate the batchloader so it will accept chunks to batch.
void CreateValidationBatches(const TMVA::Experimental::RTensor< float > &chunkTensor, const std::vector< std::size_t > eventIndices)
Create validation batches from the given chunk based on the given event indices. Batches are added to ...
RBatchLoader(const std::size_t batchSize, const std::size_t numColumns, const std::size_t maxBatches)
std::unique_ptr< TMVA::Experimental::RTensor< float > > CreateBatch(const TMVA::Experimental::RTensor< float > &chunkTensor, const std::vector< std::size_t > idx)
Create a batch filled with the events on the given idx.
TMVA::Experimental::RTensor< float > fEmptyTensor
void CreateTrainingBatches(const TMVA::Experimental::RTensor< float > &chunkTensor, std::vector< std::size_t > eventIndices, const bool shuffle=true)
Create training batches from the given chunk of data based on the given event indices. Batches are added ...
std::vector< std::unique_ptr< TMVA::Experimental::RTensor< float > > > fValidationBatches
bool HasTrainData()
Checks if there are more training batches available.
const TMVA::Experimental::RTensor< float > & GetTrainBatch()
Return a reference to the next training batch.
bool HasValidationData()
Checks if there are more validation batches available.
std::unique_ptr< TMVA::Experimental::RTensor< float > > fCurrentBatch
TMVA::RandomGenerator< TRandom3 > fRng
void StartValidation()
Reset the validation process.
const TMVA::Experimental::RTensor< float > & GetValidationBatch()
Returns a batch of data for validation. Ownership of this batch stays with the RBatchLoader.
void DeActivate()
DeActivate the batchloader.
std::queue< std::unique_ptr< TMVA::Experimental::RTensor< float > > > fTrainingBatchQueue
RTensor is a container with contiguous memory and shape information.
Definition RTensor.hxx:162
create variable transformations