Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ReshapeLayer.h
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Vladimir Ilievski
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : TReshapeLayer *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Reshape Deep Neural Network Layer *
12 * *
13 * Authors (alphabetical): *
14 * Vladimir Ilievski <ilievski.vladimir@live.com> - CERN, Switzerland *
15 * *
16 * Copyright (c) 2005-2015: *
17 * CERN, Switzerland *
18 * U. of Victoria, Canada *
19 * MPI-K Heidelberg, Germany *
20 * U. of Bonn, Germany *
21 * *
22 * Redistribution and use in source and binary forms, with or without *
23 * modification, are permitted according to the terms listed in LICENSE *
24 * (http://tmva.sourceforge.net/LICENSE) *
25 **********************************************************************************/
26
27#ifndef TMVA_DNN_RESHAPELAYER
28#define TMVA_DNN_RESHAPELAYER
29
30#include "TMatrix.h"
31
33#include "TMVA/DNN/Functions.h"
34
35#include <iostream>
36
37namespace TMVA {
38namespace DNN {
39
40template <typename Architecture_t>
41class TReshapeLayer : public VGeneralLayer<Architecture_t> {
42public:
43 using Tensor_t = typename Architecture_t::Tensor_t;
44 using Matrix_t = typename Architecture_t::Matrix_t;
45 using Scalar_t = typename Architecture_t::Scalar_t;
46
47private:
48 bool fFlattening; ///< Whather the layer is doing flattening
49
50public:
51 /*! Constructor */
52 TReshapeLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth,
53 size_t Height, size_t Width, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols,
54 bool Flattening);
55
56 /*! Copy the reshape layer provided as a pointer */
58
59 /*! Copy Constructor */
61
62 /*! Destructor. */
64
65 /*! The input must be in 3D tensor form with the different matrices
66 * corresponding to different events in the batch. It transforms the
67 * input matrices. */
68 void Forward(Tensor_t &input, bool applyDropout = false);
69
70 void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);
71 // Tensor_t &inp1, Tensor_t &inp2);
72
73 /*! Prints the info about the layer. */
74 void Print() const;
75
76 /*! Writes the information and the weights about the layer in an XML node. */
77 virtual void AddWeightsXMLTo(void *parent);
78
79 /*! Read the information and the weights about the layer from XML node. */
80 virtual void ReadWeightsFromXML(void *parent);
81
82
83 /*! TODO Add documentation
84 * Does this layer flatten? (necessary for DenseLayer)
85 * B x D1 x D2 --> 1 x B x (D1 * D2) */
86 bool isFlattening() const { return fFlattening; }
87};
88
89//
90//
91// The Reshape Layer Class - Implementation
92//_________________________________________________________________________________________________
93template <typename Architecture_t>
94TReshapeLayer<Architecture_t>::TReshapeLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth,
95 size_t depth, size_t height, size_t width, size_t outputNSlices,
96 size_t outputNRows, size_t outputNCols, bool flattening)
97 : VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, depth, height, width, 0, 0, 0, 0, 0,
98 0, outputNSlices, outputNRows, outputNCols, EInitialization::kZero),
99 fFlattening(flattening)
100{
101 if (this->GetInputDepth() * this->GetInputHeight() * this->GetInputWidth() !=
102 this->GetDepth() * this->GetHeight() * this->GetWidth()) {
103 std::cout << "Reshape Dimensions not compatible \n"
104 << this->GetInputDepth() << " x " << this->GetInputHeight() << " x " << this->GetInputWidth() << " --> "
105 << this->GetDepth() << " x " << this->GetHeight() << " x " << this->GetWidth() << std::endl;
106 return;
107 }
108}
109
110//_________________________________________________________________________________________________
111template <typename Architecture_t>
113 : VGeneralLayer<Architecture_t>(layer), fFlattening(layer->isFlattening())
114{
115}
116
117//_________________________________________________________________________________________________
118template <typename Architecture_t>
120 : VGeneralLayer<Architecture_t>(layer), fFlattening(layer.fFlattening)
121{
122 // Nothing to do here.
123}
124
125//_________________________________________________________________________________________________
126template <typename Architecture_t>
128{
129 // Nothing to do here.
130}
131
132//_________________________________________________________________________________________________
133template <typename Architecture_t>
134auto TReshapeLayer<Architecture_t>::Forward(Tensor_t &input, bool /*applyDropout*/) -> void
135{
136 if (fFlattening) {
137
138 Architecture_t::Flatten(this->GetOutput(), input);
139
140 return;
141 } else {
142
143 Architecture_t::Deflatten(this->GetOutput(), input); //, out_size, nRows, nCols);
144 return;
145 }
146}
147//_________________________________________________________________________________________________
148template <typename Architecture_t>
150 /*activations_backward*/) -> void
151// Tensor_t & /*inp1*/, Tensor_t &
152// /*inp2*/) -> void
153{
154 size_t size = gradients_backward.GetSize();
155 // in case of first layer size is zero - do nothing
156 if (size == 0) return;
157 if (fFlattening) {
158 // deflatten in backprop
159 Architecture_t::Deflatten(gradients_backward, this->GetActivationGradients());
160 return;
161 } else {
162 Architecture_t::Flatten(gradients_backward, this->GetActivationGradients() );
163 return;
164 }
165}
166
167//_________________________________________________________________________________________________
168template <typename Architecture_t>
170{
171 std::cout << " RESHAPE Layer \t ";
172 std::cout << "Input = ( " << this->GetInputDepth() << " , " << this->GetInputHeight() << " , " << this->GetInputWidth() << " ) ";
173 if (this->GetOutput().GetSize() > 0) {
174 std::cout << "\tOutput = ( " << this->GetOutput().GetFirstSize() << " , " << this->GetOutput().GetHSize() << " , " << this->GetOutput().GetWSize() << " ) ";
175 }
176 std::cout << std::endl;
177}
178
179template <typename Architecture_t>
181{
182 auto layerxml = gTools().xmlengine().NewChild(parent, 0, "ReshapeLayer");
183
184 // write info for reshapelayer
185 gTools().xmlengine().NewAttr(layerxml, 0, "Depth", gTools().StringFromInt(this->GetDepth()));
186 gTools().xmlengine().NewAttr(layerxml, 0, "Height", gTools().StringFromInt(this->GetHeight()));
187 gTools().xmlengine().NewAttr(layerxml, 0, "Width", gTools().StringFromInt(this->GetWidth()));
188 gTools().xmlengine().NewAttr(layerxml, 0, "Flattening", gTools().StringFromInt(this->isFlattening()));
189
190
191}
192
193//______________________________________________________________________________
194template <typename Architecture_t>
196{
197 // no info to read
198}
199
200
201
202} // namespace DNN
203} // namespace TMVA
204
205#endif
include TDocParser_001 C image html pict1_TDocParser_001 png width
typename Architecture_t::Scalar_t Scalar_t
typename Architecture_t::Tensor_t Tensor_t
TReshapeLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth, size_t Height, size_t Width, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols, bool Flattening)
Constructor.
typename Architecture_t::Matrix_t Matrix_t
virtual void AddWeightsXMLTo(void *parent)
Writes the information and the weights about the layer in an XML node.
bool isFlattening() const
TODO Add documentation Does this layer flatten? (necessary for DenseLayer) B x D1 x D2 --> 1 x B x (D...
virtual void ReadWeightsFromXML(void *parent)
Read the information and the weights about the layer from XML node.
bool fFlattening
Whether the layer is doing flattening.
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward)
Backpropagates the error.
void Forward(Tensor_t &input, bool applyDropout=false)
The input must be in 3D tensor form with the different matrices corresponding to different events in ...
void Print() const
Prints the info about the layer.
Generic General Layer class.
size_t GetInputDepth() const
size_t GetInputHeight() const
size_t GetInputWidth() const
TXMLEngine & xmlengine()
Definition Tools.h:268
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
create variable transformations
Tools & gTools()