19 #ifndef TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER 20 #define TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER 29 template <
typename AReal>
32 template <
typename AData,
typename AReal>
52 TDataLoader(
const AData &
data,
size_t nSamples,
size_t batchSize,
size_t nInputFeatures,
size_t nOutputFeatures,
83 template <
typename AData,
typename AReal>
85 size_t nInputFeatures,
size_t nOutputFeatures,
size_t )
88 outputMatrix(batchSize, nOutputFeatures), weightMatrix(batchSize, 1),
fSampleIndices()
96 template <
typename AData,
typename AReal>
104 CopyInput(inputMatrix, sampleIndexIterator);
105 CopyOutput(outputMatrix, sampleIndexIterator);
114 template <
typename AData,
typename AReal>
123 #endif // TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER void CopyOutput(HostBuffer_t &buffer, IndexIterator_t begin, size_t batchSize)
Copy output matrix into the given host buffer.
TBatch< AArchitecture > GetBatch()
Return the next batch from the training set.
typename std::vector< size_t >::iterator IndexIterator_t
void CopyInput(HostBuffer_t &buffer, IndexIterator_t begin, size_t batchSize)
Copy input matrix into the given host buffer.
TMatrixT< AReal > weightMatrix
std::vector< size_t > fSampleIndices
Ordering of the samples in the epoch.
The reference architecture class.
void Shuffle()
Shuffle the order of the samples in the batch.
TBatchIterator< Data_t, AArchitecture > BatchIterator_t
TDataLoader(const Data_t &data, size_t nSamples, size_t batchSize, size_t nInputFeatures, size_t nOutputFeatures, size_t nStreams=1)
TMatrixT< AReal > inputMatrix
TMatrixT< AReal > outputMatrix
std::vector< size_t > fSampleIndices
Ordering of the samples in the epoch.
Abstract ClassifierFactory template that handles arbitrary types.
void CopyWeights(HostBuffer_t &buffer, IndexIterator_t begin, size_t batchSize)
Copy weight matrix into the given host buffer.
TDataLoader & operator=(const TDataLoader &)=default