29#ifndef TMVA_DNN_RNN_LAYER
30#define TMVA_DNN_RNN_LAYER
55template<
typename Architecture_t>
61 using Tensor_t =
typename Architecture_t::Tensor_t;
62 using Matrix_t =
typename Architecture_t::Matrix_t;
63 using Scalar_t =
typename Architecture_t::Scalar_t;
112 TBasicRNNLayer(
size_t batchSize,
size_t stateSize,
size_t inputSize,
113 size_t timeSteps,
bool rememberState =
false,
bool returnSequence =
false,
141 const Tensor_t &activations_backward)
override;
149 const Matrix_t & precStateActivations,
153 void Print()
const override;
212template <
typename Architecture_t>
217 :
VGeneralLayer<Architecture_t>(batchSize, 1, timeSteps, inputSize, 1, (returnSequence) ? timeSteps : 1 ,
218 stateSize, 2, {stateSize, stateSize}, {inputSize, stateSize}, 1, {stateSize}, {1},
219 batchSize, (returnSequence) ? timeSteps : 1, stateSize, fA),
220 fTimeSteps(timeSteps), fStateSize(stateSize), fRememberState(rememberState), fReturnSequence(returnSequence), fF(
f), fState(batchSize, stateSize),
221 fWeightsInput(this->GetWeightsAt(0)), fWeightsState(this->GetWeightsAt(1)),
222 fBiases(this->GetBiasesAt(0)), fDerivatives(timeSteps, batchSize, stateSize),
223 fWeightInputGradients(this->GetWeightGradientsAt(0)), fWeightStateGradients(this->GetWeightGradientsAt(1)),
224 fBiasGradients(this->GetBiasGradientsAt(0)), fWeightsTensor({0}), fWeightGradientsTensor({0})
230template <
typename Architecture_t>
241 Architecture_t::Copy(fDerivatives, layer.GetDerivatives() );
244 Architecture_t::Copy(fState, layer.GetState());
248template <
typename Architecture_t>
263template<
typename Architecture_t>
273 Architecture_t::InitializeRNNDescriptors(
fDescriptors,
this);
278template <
typename Architecture_t>
282 Architecture_t::InitializeRNNTensors(
this);
285template <
typename Architecture_t>
294template<
typename Architecture_t>
298 std::cout <<
" RECURRENT Layer: \t ";
301 std::cout <<
", NTime = " << this->
GetTimeSteps() <<
" )";
302 std::cout <<
"\tOutput = ( " << this->
GetOutput().GetFirstSize() <<
" , " << this->
GetOutput().GetHSize() <<
" , " << this->
GetOutput().GetWSize() <<
" )\n";
305template <
typename Architecture_t>
306auto debugMatrix(
const typename Architecture_t::Matrix_t &A,
const std::string
name =
"matrix")
309 std::cout <<
name <<
"\n";
310 for (
size_t i = 0; i < A.GetNrows(); ++i) {
311 for (
size_t j = 0; j < A.GetNcols(); ++j) {
312 std::cout << A(i, j) <<
" ";
316 std::cout <<
"********\n";
321template <
typename Architecture_t>
327 if (Architecture_t::IsCudnn()) {
332 Architecture_t::Rearrange(
x, input);
354 Architecture_t::RNNForward(
x, hx, cx, weights,
y, hy, cy, rnnDesc, rnnWork, isTraining);
357 Architecture_t::Rearrange(this->
GetOutput(),
y);
361 Tensor_t tmp = (
y.At(
y.GetShape()[0] - 1)).Reshape({
y.GetShape()[1], 1,
y.GetShape()[2]});
362 Architecture_t::Copy(this->
GetOutput(), tmp);
375 Architecture_t::Rearrange(arrInput, input);
382 Matrix_t arrInput_m = arrInput.At(t).GetMatrix();
385 Matrix_t arrOutput_m = arrOutput.At(t).GetMatrix();
386 Architecture_t::Copy(arrOutput_m,
fState);
390 Architecture_t::Rearrange(this->
GetOutput(), arrOutput);
397 tmp = tmp.Reshape({tmp.GetShape()[0], tmp.GetShape()[1], 1});
398 assert(tmp.GetSize() == this->GetOutput().GetSize());
399 assert(tmp.GetShape()[0] == this->GetOutput().GetShape()[2]);
400 Architecture_t::Rearrange(this->
GetOutput(), tmp);
407template <
typename Architecture_t>
416 Architecture_t::ScaleAdd(
fState, tmpState);
424 Architecture_t::Copy(inputActivFunc, tState);
425 Architecture_t::ActivationFunctionForward(tState, fAF,
fActivationDesc);
430template <
typename Architecture_t>
432 const Tensor_t &activations_backward) ->
void
437 if (Architecture_t::IsCudnn() ) {
445 assert(activations_backward.GetStrides()[1] == this->GetInputSize() );
447 Architecture_t::Rearrange(
x, activations_backward);
452 Architecture_t::InitializeZero(dy);
455 Tensor_t tmp2 = dy.At(dy.GetShape()[0] - 1).Reshape({dy.GetShape()[1], 1, dy.GetShape()[2]});
461 Architecture_t::Rearrange(
y, this->
GetOutput());
473 Architecture_t::InitializeZero(weightGradients);
489 Architecture_t::RNNBackward(
x, hx, cx,
y, dy, dhy, dcy, weights, dx, dhx, dcx, weightGradients, rnnDesc, rnnWork);
491 if (gradients_backward.GetSize() != 0)
492 Architecture_t::Rearrange(gradients_backward, dx);
505 if (gradients_backward.GetSize() == 0) {
517 Architecture_t::Rearrange(arr_activations_backward, activations_backward);
529 Architecture_t::Rearrange(arr_output, this->
GetOutput());
535 Architecture_t::InitializeZero(arr_actgradients);
538 assert(tmp_grad.GetSize() == this->GetActivationGradients().GetSize());
539 assert(tmp_grad.GetShape()[0] ==
540 this->GetActivationGradients().GetShape()[2]);
552 Matrix_t actgrad_m = arr_actgradients.At(t - 1).GetMatrix();
553 Architecture_t::ScaleAdd(state_gradients_backward, actgrad_m);
555 Matrix_t actbw_m = arr_activations_backward.At(t - 1).GetMatrix();
556 Matrix_t gradbw_m = arr_gradients_backward.At(t - 1).GetMatrix();
563 Architecture_t::ActivationFunctionBackward(df,
y,
571 Matrix_t precStateActivations = arr_output.At(t - 2).GetMatrix();
572 CellBackward(state_gradients_backward, precStateActivations, actbw_m, gradbw_m, df_m);
575 const Matrix_t & precStateActivations = initState;
576 CellBackward(state_gradients_backward, precStateActivations, actbw_m, gradbw_m, df_m);
581 Architecture_t::Rearrange(gradients_backward, arr_gradients_backward );
586template <
typename Architecture_t>
588 const Matrix_t & precStateActivations,
598template <
typename Architecture_t>
619template <
typename Architecture_t>
Tensor_t fDy
cached activation gradient (input of backward) as T x B x S
size_t GetStateSize() const
typename Architecture_t::RNNDescriptors_t RNNDescriptors_t
DNN::EActivationFunction GetActivationFunction() const
void InitState(DNN::EInitialization m=DNN::EInitialization::kZero)
Initialize the state method.
const Matrix_t & GetWeightInputGradients() const
const Tensor_t & GetWeightGradientsTensor() const
typename Architecture_t::RecurrentDescriptor_t LayerDescriptor_t
Tensor_t fY
cached output tensor as T x B x S
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward) override
Backpropagates the error.
Tensor_t fDerivatives
First fDerivatives of the activations.
const Matrix_t & GetWeightStateGradients() const
Matrix_t & fWeightsInput
Input weights, fWeights[0].
Matrix_t & fWeightsState
Prev state weights, fWeights[1].
virtual ~TBasicRNNLayer()
Destructor.
void Print() const override
Prints the info about the layer.
TDescriptors * fDescriptors
Keeps all the RNN descriptors.
Tensor_t fX
cached input tensor as T x B x I
Matrix_t & fBiases
Biases.
Architecture_t::ActivationDescriptor_t fActivationDesc
typename Architecture_t::TensorDescriptor_t TensorDescriptor_t
bool fReturnSequence
Return in output full sequence or just last element in time.
const Tensor_t & GetWeightsTensor() const
Matrix_t & GetBiasStateGradients()
Tensor_t fWeightGradientsTensor
size_t fStateSize
Hidden state size of RNN.
Matrix_t & GetWeightsState()
const Matrix_t & GetState() const
Tensor_t & GetDerivatives()
const Matrix_t & GetCell() const
Matrix_t & CellBackward(Matrix_t &state_gradients_backward, const Matrix_t &precStateActivations, const Matrix_t &input, Matrix_t &input_gradient, Matrix_t &dF)
Backward for a single time unit a the corresponding call to Forward(...).
typename Architecture_t::Matrix_t Matrix_t
void ReadWeightsFromXML(void *parent) override
Read the information and the weights about the layer from XML node.
typename Architecture_t::DropoutDescriptor_t HelperDescriptor_t
typename Architecture_t::RNNWorkspace_t RNNWorkspace_t
Matrix_t fState
Hidden State.
Matrix_t & fWeightInputGradients
Gradients w.r.t. the input weights.
DNN::EActivationFunction fF
Activation function of the hidden state.
TBasicRNNLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps, bool rememberState=false, bool returnSequence=false, DNN::EActivationFunction f=DNN::EActivationFunction::kTanh, bool training=true, DNN::EInitialization fA=DNN::EInitialization::kZero)
Constructor.
Tensor_t & GetWeightGradientsTensor()
Matrix_t & GetWeightsInput()
size_t GetTimeSteps() const
Getters.
bool fRememberState
Remember state in next pass.
Tensor_t & GetWeightsTensor()
Matrix_t & fWeightStateGradients
Gradients w.r.t. the recurring weights.
Matrix_t & GetWeightInputGradients()
const Matrix_t & GetBiasesState() const
void Update(const Scalar_t learningRate)
typename Architecture_t::Scalar_t Scalar_t
size_t fTimeSteps
Timesteps for RNN.
bool DoesRememberState() const
void CellForward(const Matrix_t &input, Matrix_t &dF)
Forward for a single cell (time unit).
Tensor_t fDx
cached gradient on the input (output of backward) as T x B x I
typename Architecture_t::Tensor_t Tensor_t
void AddWeightsXMLTo(void *parent) override
Writes the information and the weights about the layer in an XML node.
const Matrix_t & GetBiasStateGradients() const
size_t GetInputSize() const
Matrix_t & GetWeightStateGradients()
bool DoesReturnSequence() const
Matrix_t & fBiasGradients
Gradients w.r.t. the bias values.
const Matrix_t & GetWeightsInput() const
Matrix_t fCell
Empty matrix for RNN.
void Initialize() override
Initialize the weights according to the given initialization method.
const Tensor_t & GetDerivatives() const
Matrix_t & GetBiasesState()
const Matrix_t & GetWeightsState() const
typename Architecture_t::FilterDescriptor_t WeightsDescriptor_t
void Forward(Tensor_t &input, bool isTraining=true) override
Compute and return the next state with given input matrix.
const Matrix_t & GetWeightsAt(size_t i) const
virtual void Initialize()
Initialize the weights and biases according to the given initialization method.
const Tensor_t & GetOutput() const
void WriteMatrixToXML(void *node, const char *name, const Matrix_t &matrix)
const Tensor_t & GetActivationGradients() const
const Matrix_t & GetBiasesAt(size_t i) const
const Matrix_t & GetBiasGradientsAt(size_t i) const
size_t GetBatchSize() const
Getters.
void ReadMatrixXML(void *node, const char *name, Matrix_t &matrix)
const Matrix_t & GetWeightGradientsAt(size_t i) const
VGeneralLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth, size_t Height, size_t Width, size_t WeightsNSlices, size_t WeightsNRows, size_t WeightsNCols, size_t BiasesNSlices, size_t BiasesNRows, size_t BiasesNCols, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols, EInitialization Init)
Constructor.
size_t GetInputWidth() const
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
auto debugMatrix(const typename Architecture_t::Matrix_t &A, const std::string name="matrix") -> void
EActivationFunction
Enum that represents layer activation functions.
void initialize(typename Architecture_t::Matrix_t &A, EInitialization m)
create variable transformations