27#ifndef TMVA_DNN_RESHAPELAYER
28#define TMVA_DNN_RESHAPELAYER
40template <
typename Architecture_t>
43 using Tensor_t =
typename Architecture_t::Tensor_t;
44 using Matrix_t =
typename Architecture_t::Matrix_t;
45 using Scalar_t =
typename Architecture_t::Scalar_t;
52 TReshapeLayer(
size_t BatchSize,
size_t InputDepth,
size_t InputHeight,
size_t InputWidth,
size_t Depth,
53 size_t Height,
size_t Width,
size_t OutputNSlices,
size_t OutputNRows,
size_t OutputNCols,
74 void Print()
const override;
93template <
typename Architecture_t>
95 size_t depth,
size_t height,
size_t width,
size_t outputNSlices,
96 size_t outputNRows,
size_t outputNCols,
bool flattening)
97 :
VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, depth, height, width, 0, 0, 0, 0, 0,
103 std::cout <<
"Reshape Dimensions not compatible \n"
111template <
typename Architecture_t>
118template <
typename Architecture_t>
126template <
typename Architecture_t>
133template <
typename Architecture_t>
138 Architecture_t::Flatten(this->
GetOutput(), input);
143 Architecture_t::Deflatten(this->
GetOutput(), input);
148template <
typename Architecture_t>
154 size_t size = gradients_backward.GetSize();
156 if (
size == 0)
return;
168template <
typename Architecture_t>
171 std::cout <<
" RESHAPE Layer \t ";
174 std::cout <<
"\tOutput = ( " << this->
GetOutput().GetFirstSize() <<
" , " << this->
GetOutput().GetHSize() <<
" , " << this->
GetOutput().GetWSize() <<
" ) ";
176 std::cout << std::endl;
179template <
typename Architecture_t>
194template <
typename Architecture_t>
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
typename Architecture_t::Scalar_t Scalar_t
typename Architecture_t::Tensor_t Tensor_t
void ReadWeightsFromXML(void *parent) override
Read the information and the weights about the layer from XML node.
TReshapeLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth, size_t Height, size_t Width, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols, bool Flattening)
Constructor.
typename Architecture_t::Matrix_t Matrix_t
void Forward(Tensor_t &input, bool applyDropout=false) override
The input must be in 3D tensor form with the different matrices corresponding to different events in ...
bool isFlattening() const
TODO Add documentation Does this layer flatten?
bool fFlattening
Whether the layer is doing flattening.
void AddWeightsXMLTo(void *parent) override
Writes the information and the weights about the layer in an XML node.
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward) override
Backpropagates the error.
void Print() const override
Prints the info about the layer.
~TReshapeLayer()
Destructor.
const Tensor_t & GetOutput() const
size_t GetInputDepth() const
const Tensor_t & GetActivationGradients() const
size_t GetInputHeight() const
VGeneralLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth, size_t Height, size_t Width, size_t WeightsNSlices, size_t WeightsNRows, size_t WeightsNCols, size_t BiasesNSlices, size_t BiasesNRows, size_t BiasesNCols, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols, EInitialization Init)
Constructor.
size_t GetInputWidth() const
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
create variable transformations