19#ifndef TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER
20#define TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER
29template <
typename AReal>
32template <
typename AData,
typename AReal>
52 TDataLoader(
const AData &data,
size_t nSamples,
size_t batchSize,
size_t nInputFeatures,
size_t nOutputFeatures,
83template <
typename AData,
typename AReal>
85 size_t nInputFeatures,
size_t nOutputFeatures,
size_t )
86 : fData(data), fNSamples(nSamples), fBatchSize(batchSize), fNInputFeatures(nInputFeatures),
87 fNOutputFeatures(nOutputFeatures), fBatchIndex(0), inputMatrix(batchSize, nInputFeatures),
88 outputMatrix(batchSize, nOutputFeatures), weightMatrix(batchSize, 1), fSampleIndices()
96template <
typename AData,
typename AReal>
99 fBatchIndex %= (fNSamples / fBatchSize);
101 size_t sampleIndex = fBatchIndex * fBatchSize;
104 CopyInput(inputMatrix, sampleIndexIterator);
105 CopyOutput(outputMatrix, sampleIndexIterator);
106 CopyWeights(weightMatrix, sampleIndexIterator);
114template <
typename AData,
typename AReal>
117 std::shuffle(fSampleIndices.begin(), fSampleIndices.end(), std::default_random_engine{});
void CopyWeights(TMatrixT< AReal > &matrix, IndexIterator_t begin)
Copy weight matrix into the given host buffer.
TDataLoader & operator=(const TDataLoader &)=default
void CopyOutput(TMatrixT< AReal > &matrix, IndexIterator_t begin)
Copy output matrix into the given host buffer.
void CopyInput(TMatrixT< AReal > &matrix, IndexIterator_t begin)
Copy input matrix into the given host buffer.
TDataLoader(TDataLoader &&)=default
TDataLoader & operator=(TDataLoader &&)=default
TMatrixT< AReal > outputMatrix
std::vector< size_t > fSampleIndices
Ordering of the samples in the epoch.
TMatrixT< AReal > inputMatrix
TDataLoader(const TDataLoader &)=default
TMatrixT< AReal > weightMatrix
TBatchIterator< Data_t, AArchitecture > BatchIterator_t
std::vector< size_t > fSampleIndices
Ordering of the samples in the epoch.
TBatch< AArchitecture > GetBatch()
Return the next batch from the training set.
void Shuffle()
Shuffle the order of the samples in the batch.
The reference architecture class.
typename std::vector< size_t >::iterator IndexIterator_t
create variable transformations