30#ifndef TMVA_DNN_GRU_LAYER
31#define TMVA_DNN_GRU_LAYER
55template<
typename Architecture_t>
61 using Matrix_t =
typename Architecture_t::Matrix_t;
62 using Scalar_t =
typename Architecture_t::Scalar_t;
63 using Tensor_t =
typename Architecture_t::Tensor_t;
140 TBasicGRULayer(
size_t batchSize,
size_t stateSize,
size_t inputSize,
141 size_t timeSteps,
bool rememberState =
false,
bool returnSequence =
false,
142 bool resetGateAfter =
false,
167 const Tensor_t &activations_backward);
175 const Matrix_t & precStateActivations,
307template <
typename Architecture_t>
312 :
VGeneralLayer<Architecture_t>(batchSize, 1, timeSteps, inputSize, 1, (returnSequence) ? timeSteps : 1, stateSize,
313 6, {stateSize, stateSize, stateSize, stateSize, stateSize, stateSize},
314 {inputSize, inputSize, inputSize, stateSize, stateSize, stateSize}, 3,
315 {stateSize, stateSize, stateSize}, {1, 1, 1}, batchSize,
316 (returnSequence) ? timeSteps : 1, stateSize, fA),
317 fStateSize(stateSize), fTimeSteps(timeSteps), fRememberState(rememberState), fReturnSequence(returnSequence), fResetGateAfter(resetGateAfter),
318 fF1(
f1), fF2(f2), fResetValue(batchSize, stateSize), fUpdateValue(batchSize, stateSize),
319 fCandidateValue(batchSize, stateSize), fState(batchSize, stateSize), fWeightsResetGate(this->GetWeightsAt(0)),
320 fWeightsResetGateState(this->GetWeightsAt(3)), fResetGateBias(this->GetBiasesAt(0)),
321 fWeightsUpdateGate(this->GetWeightsAt(1)), fWeightsUpdateGateState(this->GetWeightsAt(4)),
322 fUpdateGateBias(this->GetBiasesAt(1)), fWeightsCandidate(this->GetWeightsAt(2)),
323 fWeightsCandidateState(this->GetWeightsAt(5)), fCandidateBias(this->GetBiasesAt(2)),
324 fWeightsResetGradients(this->GetWeightGradientsAt(0)), fWeightsResetStateGradients(this->GetWeightGradientsAt(3)),
325 fResetBiasGradients(this->GetBiasGradientsAt(0)), fWeightsUpdateGradients(this->GetWeightGradientsAt(1)),
326 fWeightsUpdateStateGradients(this->GetWeightGradientsAt(4)), fUpdateBiasGradients(this->GetBiasGradientsAt(1)),
327 fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
328 fWeightsCandidateStateGradients(this->GetWeightGradientsAt(5)),
329 fCandidateBiasGradients(this->GetBiasGradientsAt(2))
331 for (
size_t i = 0; i < timeSteps; ++i) {
339 Architecture_t::InitializeGRUTensors(
this);
343template <
typename Architecture_t>
346 fStateSize(layer.fStateSize),
347 fTimeSteps(layer.fTimeSteps),
348 fRememberState(layer.fRememberState),
349 fReturnSequence(layer.fReturnSequence),
350 fResetGateAfter(layer.fResetGateAfter),
351 fF1(layer.GetActivationFunctionF1()),
352 fF2(layer.GetActivationFunctionF2()),
353 fResetValue(layer.GetBatchSize(), layer.GetStateSize()),
354 fUpdateValue(layer.GetBatchSize(), layer.GetStateSize()),
355 fCandidateValue(layer.GetBatchSize(), layer.GetStateSize()),
356 fState(layer.GetBatchSize(), layer.GetStateSize()),
357 fWeightsResetGate(this->GetWeightsAt(0)),
358 fWeightsResetGateState(this->GetWeightsAt(3)),
359 fResetGateBias(this->GetBiasesAt(0)),
360 fWeightsUpdateGate(this->GetWeightsAt(1)),
361 fWeightsUpdateGateState(this->GetWeightsAt(4)),
362 fUpdateGateBias(this->GetBiasesAt(1)),
363 fWeightsCandidate(this->GetWeightsAt(2)),
364 fWeightsCandidateState(this->GetWeightsAt(5)),
365 fCandidateBias(this->GetBiasesAt(2)),
366 fWeightsResetGradients(this->GetWeightGradientsAt(0)),
367 fWeightsResetStateGradients(this->GetWeightGradientsAt(3)),
368 fResetBiasGradients(this->GetBiasGradientsAt(0)),
369 fWeightsUpdateGradients(this->GetWeightGradientsAt(1)),
370 fWeightsUpdateStateGradients(this->GetWeightGradientsAt(4)),
371 fUpdateBiasGradients(this->GetBiasGradientsAt(1)),
372 fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
373 fWeightsCandidateStateGradients(this->GetWeightGradientsAt(5)),
374 fCandidateBiasGradients(this->GetBiasGradientsAt(2))
404 Architecture_t::InitializeGRUTensors(
this);
408template <
typename Architecture_t>
413 Architecture_t::InitializeGRUDescriptors(fDescriptors,
this);
414 Architecture_t::InitializeGRUWorkspace(fWorkspace, fDescriptors,
this);
417 if (Architecture_t::IsCudnn())
418 fResetGateAfter =
true;
422template <
typename Architecture_t>
430 Matrix_t tmpState(fResetValue.GetNrows(), fResetValue.GetNcols());
431 Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsResetGateState);
432 Architecture_t::MultiplyTranspose(fResetValue,
input, fWeightsResetGate);
433 Architecture_t::ScaleAdd(fResetValue, tmpState);
434 Architecture_t::AddRowWise(fResetValue, fResetGateBias);
435 DNN::evaluateDerivativeMatrix<Architecture_t>(dr, fRst, fResetValue);
436 DNN::evaluateMatrix<Architecture_t>(fResetValue, fRst);
440template <
typename Architecture_t>
448 Matrix_t tmpState(fUpdateValue.GetNrows(), fUpdateValue.GetNcols());
449 Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsUpdateGateState);
450 Architecture_t::MultiplyTranspose(fUpdateValue,
input, fWeightsUpdateGate);
451 Architecture_t::ScaleAdd(fUpdateValue, tmpState);
452 Architecture_t::AddRowWise(fUpdateValue, fUpdateGateBias);
453 DNN::evaluateDerivativeMatrix<Architecture_t>(du, fUpd, fUpdateValue);
454 DNN::evaluateMatrix<Architecture_t>(fUpdateValue, fUpd);
458template <
typename Architecture_t>
475 Matrix_t tmp(fCandidateValue.GetNrows(), fCandidateValue.GetNcols());
476 if (!fResetGateAfter) {
478 Architecture_t::Hadamard(tmpState, fState);
479 Architecture_t::MultiplyTranspose(tmp, tmpState, fWeightsCandidateState);
482 Architecture_t::MultiplyTranspose(tmp, fState, fWeightsCandidateState);
483 Architecture_t::Hadamard(tmp, fResetValue);
485 Architecture_t::MultiplyTranspose(fCandidateValue,
input, fWeightsCandidate);
486 Architecture_t::ScaleAdd(fCandidateValue, tmp);
487 Architecture_t::AddRowWise(fCandidateValue, fCandidateBias);
488 DNN::evaluateDerivativeMatrix<Architecture_t>(dc, fCan, fCandidateValue);
489 DNN::evaluateMatrix<Architecture_t>(fCandidateValue, fCan);
493template <
typename Architecture_t>
498 if (Architecture_t::IsCudnn()) {
501 assert(
input.GetStrides()[1] == this->GetInputSize());
505 Architecture_t::Rearrange(
x,
input);
507 const auto &weights = this->GetWeightsAt(0);
509 auto &hx = this->fState;
510 auto &cx = this->fCell;
512 auto &hy = this->fState;
513 auto &cy = this->fCell;
518 Architecture_t::RNNForward(
x, hx, cx, weights,
y, hy, cy, rnnDesc, rnnWork, isTraining);
520 if (fReturnSequence) {
521 Architecture_t::Rearrange(this->GetOutput(),
y);
524 Tensor_t tmp = (
y.At(
y.GetShape()[0] - 1)).Reshape({
y.GetShape()[1], 1,
y.GetShape()[2]});
525 Architecture_t::Copy(this->GetOutput(), tmp);
536 Tensor_t arrInput ( fTimeSteps, this->GetBatchSize(), this->GetInputWidth());
540 Architecture_t::Rearrange(arrInput,
input);
542 Tensor_t arrOutput ( fTimeSteps, this->GetBatchSize(), fStateSize );
547 if (!this->fRememberState) {
553 for (
size_t t = 0; t < fTimeSteps; ++t) {
555 ResetGate(arrInput[t], fDerivativesReset[t]);
556 Architecture_t::Copy(this->GetResetGateTensorAt(t), fResetValue);
557 UpdateGate(arrInput[t], fDerivativesUpdate[t]);
558 Architecture_t::Copy(this->GetUpdateGateTensorAt(t), fUpdateValue);
560 CandidateValue(arrInput[t], fDerivativesCandidate[t]);
561 Architecture_t::Copy(this->GetCandidateGateTensorAt(t), fCandidateValue);
564 CellForward(fUpdateValue, fCandidateValue);
568 Matrix_t arrOutputMt = arrOutput[t];
569 Architecture_t::Copy(arrOutputMt, fState);
573 Architecture_t::Rearrange(this->GetOutput(), arrOutput);
576 Tensor_t tmp = arrOutput.At(fTimeSteps - 1);
579 tmp = tmp.Reshape({tmp.GetShape()[0], tmp.GetShape()[1], 1});
580 assert(tmp.GetSize() == this->GetOutput().GetSize());
581 assert(tmp.GetShape()[0] == this->GetOutput().GetShape()[2]);
582 Architecture_t::Rearrange(this->GetOutput(), tmp);
589template <
typename Architecture_t>
593 Architecture_t::Hadamard(fState, updateGateValues);
597 for (
size_t j = 0; j < (size_t) tmp.GetNcols(); j++) {
598 for (
size_t i = 0; i < (size_t) tmp.GetNrows(); i++) {
599 tmp(i,j) = 1 - tmp(i,j);
604 Architecture_t::Hadamard(candidateValues, tmp);
605 Architecture_t::ScaleAdd(fState, candidateValues);
609template <
typename Architecture_t>
611 const Tensor_t &activations_backward)
615 if (Architecture_t::IsCudnn()) {
623 assert(activations_backward.GetStrides()[1] == this->GetInputSize());
626 Architecture_t::Rearrange(
x, activations_backward);
628 if (!fReturnSequence) {
631 Architecture_t::InitializeZero(dy);
634 Tensor_t tmp2 = dy.At(dy.GetShape()[0] - 1).Reshape({dy.GetShape()[1], 1, dy.GetShape()[2]});
637 Architecture_t::Copy(tmp2, this->GetActivationGradients());
639 Architecture_t::Rearrange(
y, this->GetOutput());
640 Architecture_t::Rearrange(dy, this->GetActivationGradients());
646 const auto &weights = this->GetWeightsTensor();
647 auto &weightGradients = this->GetWeightGradientsTensor();
651 Architecture_t::InitializeZero(weightGradients);
654 auto &hx = this->GetState();
655 auto &cx = this->GetCell();
665 Architecture_t::RNNBackward(
x, hx, cx,
y, dy, dhy, dcy, weights, dx, dhx, dcx, weightGradients, rnnDesc, rnnWork);
669 if (gradients_backward.GetSize() != 0)
670 Architecture_t::Rearrange(gradients_backward, dx);
678 Matrix_t state_gradients_backward(this->GetBatchSize(), fStateSize);
683 if (gradients_backward.GetSize() == 0 || gradients_backward[0].GetNrows() == 0 || gradients_backward[0].GetNcols() == 0) {
687 Tensor_t arr_gradients_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());
692 Tensor_t arr_activations_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());
694 Architecture_t::Rearrange(arr_activations_backward, activations_backward);
698 Tensor_t arr_output ( fTimeSteps, this->GetBatchSize(), fStateSize);
700 Matrix_t initState(this->GetBatchSize(), fStateSize);
704 Tensor_t arr_actgradients ( fTimeSteps, this->GetBatchSize(), fStateSize);
706 if (fReturnSequence) {
707 Architecture_t::Rearrange(arr_output, this->GetOutput());
708 Architecture_t::Rearrange(arr_actgradients, this->GetActivationGradients());
712 Architecture_t::InitializeZero(arr_actgradients);
714 Tensor_t tmp_grad = arr_actgradients.At(fTimeSteps - 1).Reshape({this->GetBatchSize(), fStateSize, 1});
715 assert(tmp_grad.GetSize() == this->GetActivationGradients().GetSize());
716 assert(tmp_grad.GetShape()[0] ==
717 this->GetActivationGradients().GetShape()[2]);
719 Architecture_t::Rearrange(tmp_grad, this->GetActivationGradients());
726 fWeightsResetGradients.Zero();
727 fWeightsResetStateGradients.Zero();
728 fResetBiasGradients.Zero();
731 fWeightsUpdateGradients.Zero();
732 fWeightsUpdateStateGradients.Zero();
733 fUpdateBiasGradients.Zero();
736 fWeightsCandidateGradients.Zero();
737 fWeightsCandidateStateGradients.Zero();
738 fCandidateBiasGradients.Zero();
741 for (
size_t t = fTimeSteps; t > 0; t--) {
743 Architecture_t::ScaleAdd(state_gradients_backward, arr_actgradients[t-1]);
745 const Matrix_t &prevStateActivations = arr_output[t-2];
746 Matrix_t dx = arr_gradients_backward[t-1];
748 CellBackward(state_gradients_backward, prevStateActivations,
749 this->GetResetGateTensorAt(t-1), this->GetUpdateGateTensorAt(t-1),
750 this->GetCandidateGateTensorAt(t-1),
751 arr_activations_backward[t-1], dx ,
752 fDerivativesReset[t-1], fDerivativesUpdate[t-1],
753 fDerivativesCandidate[t-1]);
755 const Matrix_t &prevStateActivations = initState;
756 Matrix_t dx = arr_gradients_backward[t-1];
757 CellBackward(state_gradients_backward, prevStateActivations,
758 this->GetResetGateTensorAt(t-1), this->GetUpdateGateTensorAt(t-1),
759 this->GetCandidateGateTensorAt(t-1),
760 arr_activations_backward[t-1], dx ,
761 fDerivativesReset[t-1], fDerivativesUpdate[t-1],
762 fDerivativesCandidate[t-1]);
767 Architecture_t::Rearrange(gradients_backward, arr_gradients_backward );
774template <
typename Architecture_t>
776 const Matrix_t & precStateActivations,
785 return Architecture_t::GRULayerBackward(state_gradients_backward,
786 fWeightsResetGradients, fWeightsUpdateGradients, fWeightsCandidateGradients,
787 fWeightsResetStateGradients, fWeightsUpdateStateGradients,
788 fWeightsCandidateStateGradients, fResetBiasGradients, fUpdateBiasGradients,
789 fCandidateBiasGradients, dr, du, dc,
790 precStateActivations,
791 reset_gate, update_gate, candidate_gate,
792 fWeightsResetGate, fWeightsUpdateGate, fWeightsCandidate,
793 fWeightsResetGateState, fWeightsUpdateGateState, fWeightsCandidateState,
794 input, input_gradient, fResetGateAfter);
799template <
typename Architecture_t>
807template<
typename Architecture_t>
811 std::cout <<
" GRU Layer: \t ";
812 std::cout <<
" (NInput = " << this->GetInputSize();
813 std::cout <<
", NState = " << this->GetStateSize();
814 std::cout <<
", NTime = " << this->GetTimeSteps() <<
" )";
815 std::cout <<
"\tOutput = ( " << this->GetOutput().GetFirstSize() <<
" , " << this->GetOutput()[0].GetNrows() <<
" , " << this->GetOutput()[0].GetNcols() <<
" )\n";
819template <
typename Architecture_t>
834 this->WriteMatrixToXML(layerxml,
"ResetWeights", this->GetWeightsAt(0));
835 this->WriteMatrixToXML(layerxml,
"ResetStateWeights", this->GetWeightsAt(1));
836 this->WriteMatrixToXML(layerxml,
"ResetBiases", this->GetBiasesAt(0));
837 this->WriteMatrixToXML(layerxml,
"UpdateWeights", this->GetWeightsAt(2));
838 this->WriteMatrixToXML(layerxml,
"UpdateStateWeights", this->GetWeightsAt(3));
839 this->WriteMatrixToXML(layerxml,
"UpdateBiases", this->GetBiasesAt(1));
840 this->WriteMatrixToXML(layerxml,
"CandidateWeights", this->GetWeightsAt(4));
841 this->WriteMatrixToXML(layerxml,
"CandidateStateWeights", this->GetWeightsAt(5));
842 this->WriteMatrixToXML(layerxml,
"CandidateBiases", this->GetBiasesAt(2));
846template <
typename Architecture_t>
851 this->ReadMatrixXML(parent,
"ResetWeights", this->GetWeightsAt(0));
852 this->ReadMatrixXML(parent,
"ResetStateWeights", this->GetWeightsAt(1));
853 this->ReadMatrixXML(parent,
"ResetBiases", this->GetBiasesAt(0));
854 this->ReadMatrixXML(parent,
"UpdateWeights", this->GetWeightsAt(2));
855 this->ReadMatrixXML(parent,
"UpdateStateWeights", this->GetWeightsAt(3));
856 this->ReadMatrixXML(parent,
"UpdateBiases", this->GetBiasesAt(1));
857 this->ReadMatrixXML(parent,
"CandidateWeights", this->GetWeightsAt(4));
858 this->ReadMatrixXML(parent,
"CandidateStateWeights", this->GetWeightsAt(5));
859 this->ReadMatrixXML(parent,
"CandidateBiases", this->GetBiasesAt(2));
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
const Matrix_t & GetWeightsCandidate() const
Matrix_t & GetWeightsCandidateStateGradients()
typename Architecture_t::RecurrentDescriptor_t LayerDescriptor_t
Matrix_t & GetWeightsResetGate()
Matrix_t & fResetBiasGradients
Gradients w.r.t the reset gate - bias weights.
std::vector< Matrix_t > & GetUpdateGateTensor()
typename Architecture_t::Tensor_t Tensor_t
std::vector< Matrix_t > reset_gate_value
Reset gate value for every time step.
Matrix_t & CellBackward(Matrix_t &state_gradients_backward, const Matrix_t &precStateActivations, const Matrix_t &reset_gate, const Matrix_t &update_gate, const Matrix_t &candidate_gate, const Matrix_t &input, Matrix_t &input_gradient, Matrix_t &dr, Matrix_t &du, Matrix_t &dc)
Backward pass for a single time step, corresponding to a call to Forward(...).
size_t fStateSize
Hidden state size for GRU.
const Matrix_t & GetWeightsResetGradients() const
const Matrix_t & GetUpdateBiasGradients() const
bool fReturnSequence
Whether the output holds the full sequence or just the last element.
void Forward(Tensor_t &input, bool isTraining=true)
Computes the hidden state for each time step from the given input tensor.
const Matrix_t & GetWeightsResetStateGradients() const
std::vector< Matrix_t > fDerivativesReset
First derivatives of the reset gate activations.
const Tensor_t & GetWeightsTensor() const
std::vector< Matrix_t > & GetResetGateTensor()
Matrix_t & GetWeightsUpdateGateState()
const std::vector< Matrix_t > & GetCandidateGateTensor() const
const Matrix_t & GetUpdateDerivativesAt(size_t i) const
Matrix_t & GetWeightsUpdateStateGradients()
size_t GetInputSize() const
Getters.
Matrix_t fState
Hidden state of GRU.
Matrix_t & GetWeightsResetGradients()
Tensor_t & GetWeightGradientsTensor()
const Matrix_t & GetCandidateBias() const
std::vector< Matrix_t > update_gate_value
Update gate value for every time step.
Tensor_t & GetWeightsTensor()
Tensor_t fX
cached input tensor as T x B x I
Matrix_t & GetCandidateGateTensorAt(size_t i)
Matrix_t & GetResetBiasGradients()
Matrix_t & GetCandidateValue()
Matrix_t & GetWeightsResetGateState()
DNN::EActivationFunction fF1
Activation function: sigmoid.
const Matrix_t & GetWeightsUpdateGate() const
const std::vector< Matrix_t > & GetDerivativesReset() const
const Matrix_t & GetUpdateGateBias() const
Matrix_t & fWeightsResetGradients
Gradients w.r.t the reset gate - input weights.
std::vector< Matrix_t > & GetDerivativesUpdate()
Matrix_t & fCandidateBiasGradients
Gradients w.r.t the candidate gate - bias weights.
Matrix_t & fCandidateBias
Candidate Gate bias.
Matrix_t & GetUpdateGateTensorAt(size_t i)
DNN::EActivationFunction fF2
Activation function: tanh.
const Matrix_t & GetWeightsUpdateGradients() const
Matrix_t & GetWeightsCandidateGradients()
Matrix_t & fWeightsUpdateStateGradients
Gradients w.r.t the update gate - hidden state weights.
void AddWeightsXMLTo(void *parent)
Writes the layer information and its weights to an XML node.
Matrix_t & GetWeightsCandidate()
Matrix_t & fWeightsUpdateGradients
Gradients w.r.t the update gate - input weights.
Matrix_t & GetUpdateGateValue()
size_t fTimeSteps
Timesteps for GRU.
std::vector< Matrix_t > fDerivativesCandidate
First derivatives of the candidate gate activations.
const Tensor_t & GetWeightGradientsTensor() const
typename Architecture_t::FilterDescriptor_t WeightsDescriptor_t
Tensor_t fWeightGradientsTensor
Tensor for all weight gradients.
Matrix_t & fUpdateBiasGradients
Gradients w.r.t the update gate - bias weights.
Matrix_t & GetWeightsResetStateGradients()
std::vector< Matrix_t > & GetCandidateGateTensor()
Matrix_t & fWeightsResetGate
Reset Gate weights for input, fWeights[0].
const Matrix_t & GetResetDerivativesAt(size_t i) const
Matrix_t & GetWeightsUpdateGate()
typename Architecture_t::Matrix_t Matrix_t
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward)
Backpropagates the error.
const Matrix_t & GetCandidateGateTensorAt(size_t i) const
Matrix_t & GetWeightsCandidateState()
const Matrix_t & GetCandidateBiasGradients() const
Matrix_t & GetResetGateTensorAt(size_t i)
Matrix_t & fResetGateBias
Reset Gate bias.
const std::vector< Matrix_t > & GetResetGateTensor() const
Matrix_t fCell
Empty matrix for GRU.
std::vector< Matrix_t > candidate_gate_value
Candidate gate value for every time step.
typename Architecture_t::Scalar_t Scalar_t
const Matrix_t & GetWeigthsUpdateStateGradients() const
const Matrix_t & GetCandidateValue() const
Matrix_t & GetCandidateBiasGradients()
Matrix_t & fWeightsCandidateStateGradients
Gradients w.r.t the candidate gate - hidden state weights.
const std::vector< Matrix_t > & GetDerivativesUpdate() const
const Matrix_t & GetCell() const
void UpdateGate(const Matrix_t &input, Matrix_t &df)
Computes the update gate values (NN with Sigmoid)
const Matrix_t & GetCandidateDerivativesAt(size_t i) const
Matrix_t fResetValue
Computed reset gate values.
DNN::EActivationFunction GetActivationFunctionF2() const
Matrix_t & GetResetGateBias()
typename Architecture_t::RNNWorkspace_t RNNWorkspace_t
Matrix_t fUpdateValue
Computed update gate values.
const Matrix_t & GetResetBiasGradients() const
bool fResetGateAfter
GRU variant that applies the reset gate after the hidden-state matrix multiplication (used by cuDNN)
const Matrix_t & GetWeightsCandidateGradients() const
DNN::EActivationFunction GetActivationFunctionF1() const
bool DoesReturnSequence() const
Matrix_t & GetUpdateBiasGradients()
const Matrix_t & GetUpdateGateTensorAt(size_t i) const
Matrix_t & fWeightsResetGateState
Reset Gate weights for prev state, fWeights[3].
Matrix_t & fWeightsUpdateGateState
Update Gate weights for prev state, fWeights[4].
const std::vector< Matrix_t > & GetDerivativesCandidate() const
virtual void Initialize()
Initialize the weights according to the given initialization method.
Tensor_t fWeightsTensor
Tensor for all weights.
void ReadWeightsFromXML(void *parent)
Reads the layer information and its weights from an XML node.
typename Architecture_t::RNNDescriptors_t RNNDescriptors_t
const Matrix_t & GetResetGateBias() const
Matrix_t & GetResetDerivativesAt(size_t i)
const Matrix_t & GetUpdateGateValue() const
const Matrix_t & GetResetGateTensorAt(size_t i) const
TDescriptors * fDescriptors
Keeps all the RNN descriptors.
void CellForward(Matrix_t &updateGateValues, Matrix_t &candidateValues)
Forward for a single cell (time unit)
Matrix_t & GetWeightsUpdateGradients()
Matrix_t & fWeightsResetStateGradients
Gradients w.r.t the reset gate - hidden state weights.
Matrix_t & fWeightsCandidateState
Candidate Gate weights for prev state, fWeights[5].
void Print() const
Prints the info about the layer.
size_t GetStateSize() const
std::vector< Matrix_t > & GetDerivativesReset()
Matrix_t & fUpdateGateBias
Update Gate bias.
const Matrix_t & GetWeightsCandidateStateGradients() const
void ResetGate(const Matrix_t &input, Matrix_t &di)
Computes the reset gate values (NN with Sigmoid)
const Matrix_t & GetWeightsResetGate() const
Tensor_t fDx
cached gradient on the input (output of backward) as T x B x I
Matrix_t & GetCandidateBias()
typename Architecture_t::TensorDescriptor_t TensorDescriptor_t
bool fRememberState
Remember state in next pass.
Matrix_t & fWeightsCandidate
Candidate Gate weights for input, fWeights[2].
Matrix_t & fWeightsCandidateGradients
Gradients w.r.t the candidate gate - input weights.
const Matrix_t & GetWeightsCandidateState() const
const std::vector< Matrix_t > & GetUpdateGateTensor() const
const Matrix_t & GetResetGateValue() const
void Update(const Scalar_t learningRate)
Tensor_t fY
cached output tensor as T x B x S
Matrix_t fCandidateValue
Computed candidate values.
const Matrix_t & GetState() const
Matrix_t & GetUpdateGateBias()
void InitState(DNN::EInitialization m=DNN::EInitialization::kZero)
Initializes the hidden state and cell state.
Tensor_t fDy
cached activation gradient (input of backward) as T x B x S
Matrix_t & GetCandidateDerivativesAt(size_t i)
std::vector< Matrix_t > fDerivativesUpdate
First derivatives of the update gate activations.
size_t GetTimeSteps() const
const Matrix_t & GetWeightsUpdateGateState() const
std::vector< Matrix_t > & GetDerivativesCandidate()
bool DoesRememberState() const
const Matrix_t & GetWeightsResetGateState() const
void CandidateValue(const Matrix_t &input, Matrix_t &dc)
Decides the new candidate values (NN with Tanh)
TBasicGRULayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps, bool rememberState=false, bool returnSequence=false, bool resetGateAfter=false, DNN::EActivationFunction f1=DNN::EActivationFunction::kSigmoid, DNN::EActivationFunction f2=DNN::EActivationFunction::kTanh, bool training=true, DNN::EInitialization fA=DNN::EInitialization::kZero)
Constructor.
Matrix_t & GetUpdateDerivativesAt(size_t i)
Matrix_t & fWeightsUpdateGate
Update Gate weights for input, fWeights[1].
typename Architecture_t::DropoutDescriptor_t HelperDescriptor_t
Matrix_t & GetResetGateValue()
Generic General Layer class.
virtual void Initialize()
Initialize the weights and biases according to the given initialization method.
size_t GetBatchSize() const
Getters.
size_t GetInputWidth() const
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
EActivationFunction
Enum that represents layer activation functions.
create variable transformations