27#ifndef TMVA_DNN_RESHAPELAYER
28#define TMVA_DNN_RESHAPELAYER
40template <
typename Architecture_t>
43 using Matrix_t =
typename Architecture_t::Matrix_t;
44 using Scalar_t =
typename Architecture_t::Scalar_t;
51 TReshapeLayer(
size_t BatchSize,
size_t InputDepth,
size_t InputHeight,
size_t InputWidth,
size_t Depth,
52 size_t Height,
size_t Width,
size_t OutputNSlices,
size_t OutputNRows,
size_t OutputNCols,
67 void Forward(std::vector<Matrix_t> &input,
bool applyDropout =
false);
69 void Backward(std::vector<Matrix_t> &gradients_backward,
const std::vector<Matrix_t> &activations_backward,
70 std::vector<Matrix_t> &inp1, std::vector<Matrix_t> &inp2);
92template <
typename Architecture_t>
94 size_t depth,
size_t height,
size_t width,
size_t outputNSlices,
95 size_t outputNRows,
size_t outputNCols,
bool flattening)
96 :
VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, depth, height,
width, 0, 0, 0, 0, 0,
98 fFlattening(flattening)
102 std::cout <<
"Reshape Dimensions not compatible \n"
110template <
typename Architecture_t>
112 :
VGeneralLayer<Architecture_t>(layer), fFlattening(layer->isFlattening())
117template <
typename Architecture_t>
119 :
VGeneralLayer<Architecture_t>(layer), fFlattening(layer.fFlattening)
125template <
typename Architecture_t>
132template <
typename Architecture_t>
136 size_t size = input.size();
137 size_t nRows = input[0].GetNrows();
138 size_t nCols = input[0].GetNcols();
139 Architecture_t::Flatten(this->GetOutputAt(0), input, size, nRows, nCols);
141 for (
size_t i = 0; i < this->GetBatchSize(); i++) {
142 Architecture_t::Reshape(this->GetOutputAt(i), input[i]);
148template <
typename Architecture_t>
150 const std::vector<Matrix_t> & ,
151 std::vector<Matrix_t> & , std::vector<Matrix_t> &
155 if (gradients_backward.size() == 0)
return;
157 size_t size = gradients_backward.size();
158 size_t nRows = gradients_backward[0].GetNrows();
159 size_t nCols = gradients_backward[0].GetNcols();
160 Architecture_t::Deflatten(gradients_backward, this->GetActivationGradientsAt(0), size, nRows, nCols);
162 for (
size_t i = 0; i < this->GetBatchSize(); i++) {
163 Architecture_t::Reshape(gradients_backward[i], this->GetActivationGradientsAt(i));
169template <
typename Architecture_t>
172 std::cout <<
" RESHAPE Layer \t ";
173 std::cout <<
"Input = ( " << this->GetInputDepth() <<
" , " << this->GetInputHeight() <<
" , " << this->GetInputWidth() <<
" ) ";
174 if (this->GetOutput().size() > 0) {
175 std::cout <<
"\tOutput = ( " << this->GetOutput().size() <<
" , " << this->GetOutput()[0].GetNrows() <<
" , " << this->GetOutput()[0].GetNcols() <<
" ) ";
177 std::cout << std::endl;
180template <
typename Architecture_t>
195template <
typename Architecture_t>
include TDocParser_001 C image html pict1_TDocParser_001 png width
TReshapeLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth, size_t Height, size_t Width, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols, bool Flattening)
Constructor.
void Backward(std::vector< Matrix_t > &gradients_backward, const std::vector< Matrix_t > &activations_backward, std::vector< Matrix_t > &inp1, std::vector< Matrix_t > &inp2)
Backpropagates the error.
virtual void AddWeightsXMLTo(void *parent)
Writes the information and the weights about the layer in an XML node.
bool isFlattening() const
TODO Add documentation Does this layer flatten? (necessary for DenseLayer) B x D1 x D2 --> 1 x B x (D...
virtual void ReadWeightsFromXML(void *parent)
Read the information and the weights about the layer from XML node.
bool fFlattening
Whether the layer performs flattening.
void Forward(std::vector< Matrix_t > &input, bool applyDropout=false)
The input must be in 3D tensor form with the different matrices corresponding to different events in ...
void Print() const
Prints the info about the layer.
~TReshapeLayer()
Destructor.
Generic General Layer class.
typename Architecture_t::Matrix_t Matrix_t
typename Architecture_t::Scalar_t Scalar_t
size_t GetInputDepth() const
size_t GetInputHeight() const
size_t GetInputWidth() const
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=0)
create new child element for parent node
UInt_t Depth(const Node< T > *node)
create variable transformations