Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TensorDataLoader.h
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Vladimir Ilievski
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : TTensorDataLoader *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Specialization of the Tensor Data Loader Class *
12 * *
13 * Authors (alphabetical): *
14 * Vladimir Ilievski <ilievski.vladimir@live.com> - CERN, Switzerland *
15 * *
16 * Copyright (c) 2005-2015: *
17 * CERN, Switzerland *
18 * U. of Victoria, Canada *
19 * MPI-K Heidelberg, Germany *
20 * U. of Bonn, Germany *
21 * *
22 * Redistribution and use in source and binary forms, with or without *
23 * modification, are permitted according to the terms listed in LICENSE *
24 * (http://tmva.sourceforge.net/LICENSE) *
25 **********************************************************************************/
26
27//////////////////////////////////////////////////////////////////////////
28// Partial specialization of the TTensorDataLoader class to adapt //
29// it to the TMatrix class. Also the data transfer is kept simple, //
30// since this implementation (being intended as reference and fallback) //
31// is not optimized for performance. //
32//////////////////////////////////////////////////////////////////////////
33
34#ifndef TMVA_DNN_ARCHITECTURES_REFERENCE_TENSORDATALOADER
35#define TMVA_DNN_ARCHITECTURES_REFERENCE_TENSORDATALOADER
36
#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>
39
40namespace TMVA {
41namespace DNN {
42
43template <typename AReal>
44class TReference;
45
46template <typename AData, typename AReal>
48private:
50
51 const AData &fData; ///< The data that should be loaded in the batches.
52
53 size_t fNSamples; ///< The total number of samples in the dataset.
54 //size_t fBatchSize; ///< The size of a batch.
55 size_t fBatchDepth; ///< The number of matrices in the tensor.
56 size_t fBatchHeight; ///< The number od rows in each matrix.
57 size_t fBatchWidth; ///< The number of columns in each matrix.
58 size_t fNOutputFeatures; ///< The number of outputs from the classifier/regressor.
59 size_t fBatchIndex; ///< The index of the batch when there are multiple batches in parallel.
60
61 std::vector<size_t> fInputShape; ///< Defines the batch depth, no. of channels and spatial dimensions of an input tensor
62
63 std::vector<TMatrixT<AReal>> inputTensor; ///< The 3D tensor used to keep the input data.
64 TMatrixT<AReal> outputMatrix; ///< The matrix used to keep the output.
65 TMatrixT<AReal> weightMatrix; ///< The matrix used to keep the batch weights.
66
67 std::vector<size_t> fSampleIndices; ///< Ordering of the samples in the epoch.
68
69public:
70 /*! Constructor. */
71 TTensorDataLoader(const AData &data, size_t nSamples, size_t batchDepth,
72 size_t batchHeight, size_t batchWidth, size_t nOutputFeatures,
73 std::vector<size_t> inputShape, size_t nStreams = 1);
74
79
80 /** Copy input tensor into the given host buffer. Function to be specialized by
81 * the architecture-specific backend. */
82 void CopyTensorInput(std::vector<TMatrixT<AReal>> &tensor, IndexIterator_t sampleIterator);
83 /** Copy output matrix into the given host buffer. Function to be specialized
84 * by the architecture-specific backend. */
85 void CopyTensorOutput(TMatrixT<AReal> &matrix, IndexIterator_t sampleIterator);
86 /** Copy weight matrix into the given host buffer. Function to be specialized
87 * by the architecture-specific backend. */
88 void CopyTensorWeights(TMatrixT<AReal> &matrix, IndexIterator_t sampleIterator);
89
91 BatchIterator_t end() { return BatchIterator_t(*this, fNSamples / fInputShape[0]); }
92
93 /** Shuffle the order of the samples in the batch. The shuffling is indirect,
94 * i.e. only the indices are shuffled. No input data is moved by this
95 * routine. */
96 template<typename RNG>
97 void Shuffle(RNG & rng);
98
99 /** Return the next batch from the training set. The TTensorDataLoader object
100 * keeps an internal counter that cycles over the batches in the training
101 * set. */
103};
104
105//
106// TTensorDataLoader Class.
107//______________________________________________________________________________
108template <typename AData, typename AReal>
109TTensorDataLoader<AData, TReference<AReal>>::TTensorDataLoader(const AData &data, size_t nSamples, size_t batchDepth,
110 size_t batchHeight, size_t batchWidth, size_t nOutputFeatures,
111 std::vector<size_t> inputShape, size_t /* nStreams */)
112 : fData(data), fNSamples(nSamples), fBatchDepth(batchDepth), fBatchHeight(batchHeight),
113 fBatchWidth(batchWidth), fNOutputFeatures(nOutputFeatures), fBatchIndex(0), fInputShape(std::move(inputShape)), inputTensor(),
114 outputMatrix(inputShape[0], nOutputFeatures), weightMatrix(inputShape[0], 1), fSampleIndices()
115{
116
117 inputTensor.reserve(fBatchDepth);
118 for (size_t i = 0; i < fBatchDepth; i++) {
119 inputTensor.emplace_back(batchHeight, batchWidth);
120 }
121
122 fSampleIndices.reserve(fNSamples);
123 for (size_t i = 0; i < fNSamples; i++) {
124 fSampleIndices.push_back(i);
125 }
126}
127
128template <typename AData, typename AReal>
129template <typename RNG>
131{
132 std::shuffle(fSampleIndices.begin(), fSampleIndices.end(), rng);
133}
134
135template <typename AData, typename AReal>
137{
138 fBatchIndex %= (fNSamples / fInputShape[0]); // Cycle through samples.
139
140 size_t sampleIndex = fBatchIndex * fInputShape[0];
141 IndexIterator_t sampleIndexIterator = fSampleIndices.begin() + sampleIndex;
142
143 CopyTensorInput(inputTensor, sampleIndexIterator);
144 CopyTensorOutput(outputMatrix, sampleIndexIterator);
145 CopyTensorWeights(weightMatrix, sampleIndexIterator);
146
147 fBatchIndex++;
148 return TTensorBatch<TReference<AReal>>(inputTensor, outputMatrix, weightMatrix);
149}
150
151} // namespace DNN
152} // namespace TMVA
153
154#endif
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
The reference architecture class.
Definition Reference.h:53
TTensorDataLoader & operator=(const TTensorDataLoader &)=default
TMatrixT< AReal > weightMatrix
The matrix used to keep the batch weights.
size_t fBatchHeight
The number of rows in each matrix.
const AData & fData
The data that should be loaded in the batches.
std::vector< TMatrixT< AReal > > inputTensor
The 3D tensor used to keep the input data.
size_t fNSamples
The total number of samples in the dataset.
std::vector< size_t > fInputShape
Defines the batch depth, no. of channels and spatial dimensions of an input tensor.
void CopyTensorInput(std::vector< TMatrixT< AReal > > &tensor, IndexIterator_t sampleIterator)
Copy input tensor into the given host buffer.
std::vector< size_t > fSampleIndices
Ordering of the samples in the epoch.
size_t fBatchIndex
The index of the batch when there are multiple batches in parallel.
TMatrixT< AReal > outputMatrix
The matrix used to keep the output.
TTensorDataLoader(const TTensorDataLoader &)=default
size_t fBatchDepth
The number of matrices in the tensor.
size_t fBatchWidth
The number of columns in each matrix.
void CopyTensorWeights(TMatrixT< AReal > &matrix, IndexIterator_t sampleIterator)
Copy weight matrix into the given host buffer.
TTensorDataLoader & operator=(TTensorDataLoader &&)=default
size_t fNOutputFeatures
The number of outputs from the classifier/regressor.
void CopyTensorOutput(TMatrixT< AReal > &matrix, IndexIterator_t sampleIterator)
Copy output matrix into the given host buffer.
TTensorBatchIterator< Data_t, Architecture_t > BatchIterator_t
void Shuffle(RNG &rng)
Shuffle the order of the samples in the batch.
std::vector< size_t > fSampleIndices
Ordering of the samples in the epoch.
TTensorBatch< Architecture_t > GetTensorBatch()
Return the next batch from the training set.
size_t fBatchDepth
The number of matrices in the tensor.
size_t fNSamples
The total number of samples in the dataset.
TMatrixT.
Definition TMatrixT.h:39
typename std::vector< size_t >::iterator IndexIterator_t
Definition DataLoader.h:42
create variable transformations