Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBatchGenerator.hxx
Go to the documentation of this file.
1#ifndef TMVA_BATCHGENERATOR
2#define TMVA_BATCHGENERATOR
3
4#include <iostream>
5#include <vector>
6#include <thread>
7#include <memory>
8#include <cmath>
9#include <mutex>
10
11#include "TMVA/RTensor.hxx"
13#include "TMVA/RChunkLoader.hxx"
14#include "TMVA/RBatchLoader.hxx"
15#include "TMVA/Tools.h"
16#include "TRandom3.h"
17#include "TROOT.h"
18
19namespace TMVA {
20namespace Experimental {
21namespace Internal {
22
23template <typename... Args>
25private:
27
 // NOTE(review): the class declaration line and several member declarations
 // (e.g. fRng, fChunkLoader, fVecPadding, fValidationSplit) are elided from
 // this listing; see the member index for their full types.
 // Input source: file path and name of the TTree inside it
28 std::string fFileName;
29 std::string fTreeName;
30
 // Columns to load into each chunk, and filter expressions
 // (filters presumably applied while loading chunks — TODO confirm in RChunkLoader)
31 std::vector<std::string> fCols;
32 std::vector<std::string> fFilters;
33
 // Sizing: rows per chunk, chunk/batch limits, and column count
34 std::size_t fChunkSize;
35 std::size_t fMaxChunks;
36 std::size_t fBatchSize;
37 std::size_t fMaxBatches;
38 std::size_t fNumColumns;
 // Total entries in the tree (presumably set from the tree's entry count;
 // the assigning line is elided from this listing)
39 std::size_t fNumEntries;
 // Next row of the dataset to load; advanced by LoadChunks()
40 std::size_t fCurrentRow = 0;
41
43
 // Turns loaded chunks into training/validation batches
45 std::unique_ptr<TMVA::Experimental::Internal::RBatchLoader> fBatchLoader;
46
 // Background thread running LoadChunks(); joined in DeActivate()
47 std::unique_ptr<std::thread> fLoadingThread;
48
 // True when maxChunks == 0: keep loading until the file is exhausted
49 bool fUseWholeFile = true;
50
 // Staging tensor each chunk is loaded into; fCurrentBatch usage is not
 // visible in this listing
51 std::unique_ptr<TMVA::Experimental::RTensor<float>> fChunkTensor;
52 std::unique_ptr<TMVA::Experimental::RTensor<float>> fCurrentBatch;
53
 // Per-chunk row indices of the training/validation split (reused across epochs)
54 std::vector<std::vector<std::size_t>> fTrainingIdxs;
55 std::vector<std::vector<std::size_t>> fValidationIdxs;
56
57 // filled batch elements
 // Guards fIsActive against concurrent Activate()/DeActivate()
58 std::mutex fIsActiveLock;
59
60 bool fShuffle = true;
61 bool fIsActive = false;
62
 // Sizes of vector-valued columns — TODO confirm against RChunkLoader usage
63 std::vector<std::size_t> fVecSizes;
65
66public:
 /// \brief Construct a batch generator reading from one tree of one file.
 /// \param treeName        name of the TTree inside fileName
 /// \param fileName        path of the ROOT file to open
 /// \param chunkSize       number of rows loaded per chunk
 /// \param batchSize       number of rows per batch
 /// \param cols            columns to load
 /// \param filters         filter expressions (usage not visible in this listing)
 /// \param vecSizes        sizes of vector-valued columns — TODO confirm
 /// \param vecPadding      padding value for short vectors — TODO confirm
 /// \param validationSplit fraction of each chunk reserved for validation
 /// \param maxChunks       maximum number of chunks to load; 0 means whole file
 /// \param numColumns      number of output columns; defaults to cols.size()
 /// \param shuffle         whether to shuffle rows for splitting/batching
67 RBatchGenerator(const std::string &treeName, const std::string &fileName, const std::size_t chunkSize,
68 const std::size_t batchSize, const std::vector<std::string> &cols,
69 const std::vector<std::string> &filters = {}, const std::vector<std::size_t> &vecSizes = {},
70 const float vecPadding = 0.0, const float validationSplit = 0.0, const std::size_t maxChunks = 0,
71 const std::size_t numColumns = 0, bool shuffle = true)
72 : fTreeName(treeName),
73 fFileName(fileName),
74 fChunkSize(chunkSize),
75 fBatchSize(batchSize),
76 fCols(cols),
78 fVecSizes(vecSizes),
79 fVecPadding(vecPadding),
80 fValidationSplit(validationSplit),
81 fMaxChunks(maxChunks),
82 fNumColumns((numColumns != 0) ? numColumns : cols.size()),
83 fShuffle(shuffle),
84 fUseWholeFile(maxChunks == 0)
85 {
86 // limits the number of batches that can be contained in the batchqueue based on the chunksize
 // NOTE(review): the fMaxBatches computation line is elided from this listing
88
89 // get the number of fNumEntries in the dataframe
90 std::unique_ptr<TFile> f{TFile::Open(fFileName.c_str())};
91 std::unique_ptr<TTree> t{f->Get<TTree>(fTreeName.c_str())};
 // NOTE(review): the line assigning fNumEntries (presumably from
 // t->GetEntries()) is elided from this listing — TODO confirm
93
 // Hand the batch geometry to the loader that will queue up batches
96 fBatchLoader = std::make_unique<TMVA::Experimental::Internal::RBatchLoader>(fBatchSize, fNumColumns, fMaxBatches);
97
98 // Create tensor to load the chunk into
 // NOTE(review): the "fChunkTensor =" line is elided from this listing
100 std::make_unique<TMVA::Experimental::RTensor<float>>((std::vector<std::size_t>){fChunkSize, fNumColumns});
101 }
102
104
105 /// \brief De-activate the loading process by deactivating the batchgenerator
106 /// and joining the loading thread
 // NOTE(review): the "void DeActivate()" signature line is elided from this listing
108 {
109 {
 // Clear the active flag under the lock so LoadChunks() exits at its next check
110 std::lock_guard<std::mutex> lock(fIsActiveLock);
111 fIsActive = false;
112 }
113
 // Signal the batch loader to stop — TODO confirm semantics in RBatchLoader
114 fBatchLoader->DeActivate();
115
 // Join the background loading thread, if one was ever started
116 if (fLoadingThread) {
117 if (fLoadingThread->joinable()) {
118 fLoadingThread->join();
119 }
120 }
121 }
122
123 /// \brief Activate the loading process by starting the batchloader, and
124 /// spawning the loading thread.
125 void Activate()
126 {
127 if (fIsActive)
128 return;
129
130 {
131 std::lock_guard<std::mutex> lock(fIsActiveLock);
132 fIsActive = true;
133 }
134
135 fCurrentRow = 0;
136 fBatchLoader->Activate();
137 fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunks, this);
138 }
139
140 /// \brief Returns the next batch of training data if available.
141 /// Returns empty RTensor otherwise.
142 /// \return
 // NOTE(review): the method signature line (returning a reference to an
 // RTensor<float> per the member index) is elided from this listing
144 {
145 // Get next batch if available
146 return fBatchLoader->GetTrainBatch();
147 }
148
149 /// \brief Returns the next batch of validation data if available.
150 /// Returns empty RTensor otherwise.
151 /// \return
 // NOTE(review): the method signature line (returning a reference to an
 // RTensor<float> per the member index) is elided from this listing
153 {
154 // Get next batch if available
155 return fBatchLoader->GetValidationBatch();
156 }
157
158 bool HasTrainData() { return fBatchLoader->HasTrainData(); }
159
160 bool HasValidationData() { return fBatchLoader->HasValidationData(); }
161
 /// \brief Background-thread loop: load chunks one after another, creating
 /// batches from each, until the file (or the fMaxChunks limit) is exhausted.
 // NOTE(review): the "void LoadChunks()" signature line and the body's first
 // statement are elided from this listing
163 {
165
 // Keep loading while more chunks are allowed (or the whole file is wanted)
 // and rows remain in the tree
166 for (std::size_t current_chunk = 0; ((current_chunk < fMaxChunks) || fUseWholeFile) && fCurrentRow < fNumEntries;
167 current_chunk++) {
168
169 // stop the loop when the loading is not active anymore
170 {
171 std::lock_guard<std::mutex> lock(fIsActiveLock);
172 if (!fIsActive)
173 return;
174 }
175
176 // Pair of (processed, passed) event counts reported while loading the chunk
177 std::pair<std::size_t, std::size_t> report = fChunkLoader->LoadChunk(*fChunkTensor, fCurrentRow);
178 fCurrentRow += report.first;
179
 // Split (if needed) and enqueue the batches for this chunk
180 CreateBatches(current_chunk, report.second);
181
182 // Stop loading if the number of processed events is smaller than the desired chunk size
183 if (report.first < fChunkSize) {
184 break;
185 }
186 }
187
 // No more chunks: tell the batch loader the producer side is done
188 fBatchLoader->DeActivate();
189 }
190
191 /// \brief Create batches for the current_chunk.
192 /// \param currentChunk
193 /// \param processedEvents
194 void CreateBatches(std::size_t currentChunk, std::size_t processedEvents)
195 {
196
197 // Check if the indices in this chunk where already split in train and validations
198 if (fTrainingIdxs.size() > currentChunk) {
199 fBatchLoader->CreateTrainingBatches(*fChunkTensor, fTrainingIdxs[currentChunk], fShuffle);
200 } else {
201 // Create the Validation batches if this is not the first epoch
202 createIdxs(processedEvents);
203 fBatchLoader->CreateTrainingBatches(*fChunkTensor, fTrainingIdxs[currentChunk], fShuffle);
204 fBatchLoader->CreateValidationBatches(*fChunkTensor, fValidationIdxs[currentChunk]);
205 }
206 }
207
208 /// \brief plit the events of the current chunk into validation and training events
209 /// \param processedEvents
210 void createIdxs(std::size_t processedEvents)
211 {
212 // Create a vector of number 1..processedEvents
213 std::vector<std::size_t> row_order = std::vector<std::size_t>(processedEvents);
214 std::iota(row_order.begin(), row_order.end(), 0);
215
216 if (fShuffle) {
217 std::shuffle(row_order.begin(), row_order.end(), fRng);
218 }
219
220 // calculate the number of events used for validation
221 std::size_t num_validation = ceil(processedEvents * fValidationSplit);
222
223 // Devide the vector into training and validation
224 std::vector<std::size_t> valid_idx({row_order.begin(), row_order.begin() + num_validation});
225 std::vector<std::size_t> train_idx({row_order.begin() + num_validation, row_order.end()});
226
227 fTrainingIdxs.push_back(train_idx);
228 fValidationIdxs.push_back(valid_idx);
229 }
230
231 void StartValidation() { fBatchLoader->StartValidation(); }
232 bool IsActive() { return fIsActive; }
233};
234
235} // namespace Internal
236} // namespace Experimental
237} // namespace TMVA
238
239#endif // TMVA_BATCHGENERATOR
#define f(i)
Definition RSha256.hxx:104
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
const char * filters[]
TObject * Get(const char *namecycle) override
Return pointer to object identified by namecycle.
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4075
void CreateBatches(std::size_t currentChunk, std::size_t processedEvents)
Create batches for the current_chunk.
std::unique_ptr< std::thread > fLoadingThread
const TMVA::Experimental::RTensor< float > & GetTrainBatch()
Returns the next batch of training data if available.
RBatchGenerator(const std::string &treeName, const std::string &fileName, const std::size_t chunkSize, const std::size_t batchSize, const std::vector< std::string > &cols, const std::vector< std::string > &filters={}, const std::vector< std::size_t > &vecSizes={}, const float vecPadding=0.0, const float validationSplit=0.0, const std::size_t maxChunks=0, const std::size_t numColumns=0, bool shuffle=true)
void Activate()
Activate the loading process by starting the batchloader, and spawning the loading thread.
const TMVA::Experimental::RTensor< float > & GetValidationBatch()
Returns the next batch of validation data if available.
std::vector< std::vector< std::size_t > > fTrainingIdxs
void DeActivate()
De-activate the loading process by deactivating the batchgenerator and joining the loading thread.
void createIdxs(std::size_t processedEvents)
Split the events of the current chunk into validation and training events
std::unique_ptr< TMVA::Experimental::Internal::RBatchLoader > fBatchLoader
std::vector< std::vector< std::size_t > > fValidationIdxs
TMVA::RandomGenerator< TRandom3 > fRng
std::unique_ptr< TMVA::Experimental::RTensor< float > > fCurrentBatch
std::unique_ptr< TMVA::Experimental::Internal::RChunkLoader< Args... > > fChunkLoader
std::unique_ptr< TMVA::Experimental::RTensor< float > > fChunkTensor
RTensor is a container with contiguous memory and shape information.
Definition RTensor.hxx:162
A TTree represents a columnar dataset.
Definition TTree.h:79
virtual Long64_t GetEntries() const
Definition TTree.h:463
void EnableThreadSafety()
Enable support for multi-threading within the ROOT code in particular, enables the global mutex to ma...
Definition TROOT.cxx:501
create variable transformations