Logo ROOT  
Reference Guide
ReshapeLayer.h
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Vladimir Ilievski
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : TReshapeLayer *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Reshape Deep Neural Network Layer *
12 * *
13 * Authors (alphabetical): *
14 * Vladimir Ilievski <ilievski.vladimir@live.com> - CERN, Switzerland *
15 * *
16 * Copyright (c) 2005-2015: *
17 * CERN, Switzerland *
18 * U. of Victoria, Canada *
19 * MPI-K Heidelberg, Germany *
20 * U. of Bonn, Germany *
21 * *
22 * Redistribution and use in source and binary forms, with or without *
23 * modification, are permitted according to the terms listed in LICENSE *
24 * (http://tmva.sourceforge.net/LICENSE) *
25 **********************************************************************************/
26
27#ifndef TMVA_DNN_RESHAPELAYER
28#define TMVA_DNN_RESHAPELAYER
29
#include "TMatrix.h"

#include "TMVA/DNN/GeneralLayer.h"
#include "TMVA/DNN/Functions.h"

#include <iostream>
36
37namespace TMVA {
38namespace DNN {
39
40template <typename Architecture_t>
41class TReshapeLayer : public VGeneralLayer<Architecture_t> {
42public:
43 using Tensor_t = typename Architecture_t::Tensor_t;
44 using Matrix_t = typename Architecture_t::Matrix_t;
45 using Scalar_t = typename Architecture_t::Scalar_t;
46
47private:
48 bool fFlattening; ///< Whather the layer is doing flattening
49
50public:
51 /*! Constructor */
52 TReshapeLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth,
53 size_t Height, size_t Width, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols,
54 bool Flattening);
55
56 /*! Copy the reshape layer provided as a pointer */
58
59 /*! Copy Constructor */
61
62 /*! Destructor. */
64
65 /*! The input must be in 3D tensor form with the different matrices
66 * corresponding to different events in the batch. It transforms the
67 * input matrices. */
68 void Forward(Tensor_t &input, bool applyDropout = false);
69
70 void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);
71 // Tensor_t &inp1, Tensor_t &inp2);
72
73 /*! Prints the info about the layer. */
74 void Print() const;
75
76 /*! Writes the information and the weights about the layer in an XML node. */
77 virtual void AddWeightsXMLTo(void *parent);
78
79 /*! Read the information and the weights about the layer from XML node. */
80 virtual void ReadWeightsFromXML(void *parent);
81
82
83 /*! TODO Add documentation
84 * Does this layer flatten? (necessary for DenseLayer)
85 * B x D1 x D2 --> 1 x B x (D1 * D2) */
86 bool isFlattening() const { return fFlattening; }
87};
88
89//
90//
91// The Reshape Layer Class - Implementation
92//_________________________________________________________________________________________________
93template <typename Architecture_t>
94TReshapeLayer<Architecture_t>::TReshapeLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth,
95 size_t depth, size_t height, size_t width, size_t outputNSlices,
96 size_t outputNRows, size_t outputNCols, bool flattening)
97 : VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, depth, height, width, 0, 0, 0, 0, 0,
98 0, outputNSlices, outputNRows, outputNCols, EInitialization::kZero),
99 fFlattening(flattening)
100{
101 if (this->GetInputDepth() * this->GetInputHeight() * this->GetInputWidth() !=
102 this->GetDepth() * this->GetHeight() * this->GetWidth()) {
103 std::cout << "Reshape Dimensions not compatible \n"
104 << this->GetInputDepth() << " x " << this->GetInputHeight() << " x " << this->GetInputWidth() << " --> "
105 << this->GetDepth() << " x " << this->GetHeight() << " x " << this->GetWidth() << std::endl;
106 return;
107 }
108}
109
110//_________________________________________________________________________________________________
111template <typename Architecture_t>
113 : VGeneralLayer<Architecture_t>(layer), fFlattening(layer->isFlattening())
114{
115}
116
117//_________________________________________________________________________________________________
118template <typename Architecture_t>
120 : VGeneralLayer<Architecture_t>(layer), fFlattening(layer.fFlattening)
121{
122 // Nothing to do here.
123}
124
125//_________________________________________________________________________________________________
126template <typename Architecture_t>
128{
129 // Nothing to do here.
130}
131
132//_________________________________________________________________________________________________
133template <typename Architecture_t>
134auto TReshapeLayer<Architecture_t>::Forward(Tensor_t &input, bool /*applyDropout*/) -> void
135{
136 if (fFlattening) {
137
138 Architecture_t::Flatten(this->GetOutput(), input);
139
140 return;
141 } else {
142
143 Architecture_t::Deflatten(this->GetOutput(), input); //, out_size, nRows, nCols);
144 return;
145 }
146}
147//_________________________________________________________________________________________________
148template <typename Architecture_t>
150 /*activations_backward*/) -> void
151// Tensor_t & /*inp1*/, Tensor_t &
152// /*inp2*/) -> void
153{
154 size_t size = gradients_backward.GetSize();
155 // in case of first layer size is zero - do nothing
156 if (size == 0) return;
157 if (fFlattening) {
158 // deflatten in backprop
159 Architecture_t::Deflatten(gradients_backward, this->GetActivationGradients());
160 return;
161 } else {
162 Architecture_t::Flatten(gradients_backward, this->GetActivationGradients() );
163 return;
164 }
165}
166
167//_________________________________________________________________________________________________
168template <typename Architecture_t>
170{
171 std::cout << " RESHAPE Layer \t ";
172 std::cout << "Input = ( " << this->GetInputDepth() << " , " << this->GetInputHeight() << " , " << this->GetInputWidth() << " ) ";
173 if (this->GetOutput().GetSize() > 0) {
174 std::cout << "\tOutput = ( " << this->GetOutput().GetFirstSize() << " , " << this->GetOutput().GetHSize() << " , " << this->GetOutput().GetWSize() << " ) ";
175 }
176 std::cout << std::endl;
177}
178
179template <typename Architecture_t>
181{
182 auto layerxml = gTools().xmlengine().NewChild(parent, 0, "ReshapeLayer");
183
184 // write info for reshapelayer
185 gTools().xmlengine().NewAttr(layerxml, 0, "Depth", gTools().StringFromInt(this->GetDepth()));
186 gTools().xmlengine().NewAttr(layerxml, 0, "Height", gTools().StringFromInt(this->GetHeight()));
187 gTools().xmlengine().NewAttr(layerxml, 0, "Width", gTools().StringFromInt(this->GetWidth()));
188 gTools().xmlengine().NewAttr(layerxml, 0, "Flattening", gTools().StringFromInt(this->isFlattening()));
189
190
191}
192
193//______________________________________________________________________________
194template <typename Architecture_t>
196{
197 // no info to read
198}
199
200
201
202} // namespace DNN
203} // namespace TMVA
204
205#endif
include TDocParser_001 C image html pict1_TDocParser_001 png width
Definition: TDocParser.cxx:121
TReshapeLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth, size_t Height, size_t Width, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols, bool Flattening)
Constructor.
Definition: ReshapeLayer.h:94
virtual void AddWeightsXMLTo(void *parent)
Writes the information and the weights about the layer in an XML node.
Definition: ReshapeLayer.h:180
bool isFlattening() const
TODO Add documentation Does this layer flatten? (necessary for DenseLayer) B x D1 x D2 --> 1 x B x (D...
Definition: ReshapeLayer.h:86
virtual void ReadWeightsFromXML(void *parent)
Read the information and the weights about the layer from XML node.
Definition: ReshapeLayer.h:195
bool fFlattening
Whether the layer is doing flattening.
Definition: ReshapeLayer.h:48
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward)
Backpropagates the error.
Definition: ReshapeLayer.h:149
void Forward(Tensor_t &input, bool applyDropout=false)
The input must be in 3D tensor form with the different matrices corresponding to different events in ...
Definition: ReshapeLayer.h:134
void Print() const
Prints the info about the layer.
Definition: ReshapeLayer.h:169
~TReshapeLayer()
Destructor.
Definition: ReshapeLayer.h:127
Generic General Layer class.
Definition: GeneralLayer.h:49
typename Architecture_t::Matrix_t Matrix_t
Definition: GeneralLayer.h:52
size_t GetDepth() const
Definition: GeneralLayer.h:164
typename Architecture_t::Scalar_t Scalar_t
Definition: GeneralLayer.h:53
size_t GetInputDepth() const
Definition: GeneralLayer.h:161
size_t GetInputHeight() const
Definition: GeneralLayer.h:162
size_t GetWidth() const
Definition: GeneralLayer.h:166
size_t GetHeight() const
Definition: GeneralLayer.h:165
typename Architecture_t::Tensor_t Tensor_t
Definition: GeneralLayer.h:51
size_t GetInputWidth() const
Definition: GeneralLayer.h:163
TXMLEngine & xmlengine()
Definition: Tools.h:270
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
Definition: TXMLEngine.cxx:709
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
Definition: TXMLEngine.cxx:580
EInitialization
Definition: Functions.h:70
UInt_t Depth(const Node< T > *node)
Definition: NodekNN.h:213
create variable transformations
Tools & gTools()