30#ifndef TMVA_DNN_LSTM_LAYER
31#define TMVA_DNN_LSTM_LAYER
55template<
typename Architecture_t>
61 using Matrix_t =
typename Architecture_t::Matrix_t;
62 using Scalar_t =
typename Architecture_t::Scalar_t;
63 using Tensor_t =
typename Architecture_t::Tensor_t;
147 TBasicLSTMLayer(
size_t batchSize,
size_t stateSize,
size_t inputSize,
size_t timeSteps,
bool rememberState =
false,
148 bool returnSequence =
false,
174 const Tensor_t &activations_backward);
183 const Matrix_t & precStateActivations,
const Matrix_t & precCellActivations,
340template <
typename Architecture_t>
346 batchSize, 1, timeSteps, inputSize, 1, (returnSequence) ? timeSteps : 1, stateSize, 8,
347 {stateSize, stateSize, stateSize, stateSize, stateSize, stateSize, stateSize, stateSize},
348 {inputSize, inputSize, inputSize, inputSize, stateSize, stateSize, stateSize, stateSize}, 4,
349 {stateSize, stateSize, stateSize, stateSize}, {1, 1, 1, 1}, batchSize, (returnSequence) ? timeSteps : 1,
351 fStateSize(stateSize), fCellSize(stateSize), fTimeSteps(timeSteps), fRememberState(rememberState),
352 fReturnSequence(returnSequence), fF1(
f1), fF2(f2), fInputValue(batchSize, stateSize),
353 fCandidateValue(batchSize, stateSize), fForgetValue(batchSize, stateSize), fOutputValue(batchSize, stateSize),
354 fState(batchSize, stateSize), fCell(batchSize, stateSize), fWeightsInputGate(this->GetWeightsAt(0)),
355 fWeightsInputGateState(this->GetWeightsAt(4)), fInputGateBias(this->GetBiasesAt(0)),
356 fWeightsForgetGate(this->GetWeightsAt(1)), fWeightsForgetGateState(this->GetWeightsAt(5)),
357 fForgetGateBias(this->GetBiasesAt(1)), fWeightsCandidate(this->GetWeightsAt(2)),
358 fWeightsCandidateState(this->GetWeightsAt(6)), fCandidateBias(this->GetBiasesAt(2)),
359 fWeightsOutputGate(this->GetWeightsAt(3)), fWeightsOutputGateState(this->GetWeightsAt(7)),
360 fOutputGateBias(this->GetBiasesAt(3)), fWeightsInputGradients(this->GetWeightGradientsAt(0)),
361 fWeightsInputStateGradients(this->GetWeightGradientsAt(4)), fInputBiasGradients(this->GetBiasGradientsAt(0)),
362 fWeightsForgetGradients(this->GetWeightGradientsAt(1)),
363 fWeightsForgetStateGradients(this->GetWeightGradientsAt(5)), fForgetBiasGradients(this->GetBiasGradientsAt(1)),
364 fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
365 fWeightsCandidateStateGradients(this->GetWeightGradientsAt(6)),
366 fCandidateBiasGradients(this->GetBiasGradientsAt(2)), fWeightsOutputGradients(this->GetWeightGradientsAt(3)),
367 fWeightsOutputStateGradients(this->GetWeightGradientsAt(7)), fOutputBiasGradients(this->GetBiasGradientsAt(3))
369 for (
size_t i = 0; i < timeSteps; ++i) {
378 cell_value.emplace_back(batchSize, stateSize);
380 Architecture_t::InitializeLSTMTensors(
this);
384template <
typename Architecture_t>
387 fStateSize(layer.fStateSize),
388 fCellSize(layer.fCellSize),
389 fTimeSteps(layer.fTimeSteps),
390 fRememberState(layer.fRememberState),
391 fReturnSequence(layer.fReturnSequence),
392 fF1(layer.GetActivationFunctionF1()),
393 fF2(layer.GetActivationFunctionF2()),
394 fInputValue(layer.GetBatchSize(), layer.GetStateSize()),
395 fCandidateValue(layer.GetBatchSize(), layer.GetStateSize()),
396 fForgetValue(layer.GetBatchSize(), layer.GetStateSize()),
397 fOutputValue(layer.GetBatchSize(), layer.GetStateSize()),
398 fState(layer.GetBatchSize(), layer.GetStateSize()),
399 fCell(layer.GetBatchSize(), layer.GetCellSize()),
400 fWeightsInputGate(this->GetWeightsAt(0)),
401 fWeightsInputGateState(this->GetWeightsAt(4)),
402 fInputGateBias(this->GetBiasesAt(0)),
403 fWeightsForgetGate(this->GetWeightsAt(1)),
404 fWeightsForgetGateState(this->GetWeightsAt(5)),
405 fForgetGateBias(this->GetBiasesAt(1)),
406 fWeightsCandidate(this->GetWeightsAt(2)),
407 fWeightsCandidateState(this->GetWeightsAt(6)),
408 fCandidateBias(this->GetBiasesAt(2)),
409 fWeightsOutputGate(this->GetWeightsAt(3)),
410 fWeightsOutputGateState(this->GetWeightsAt(7)),
411 fOutputGateBias(this->GetBiasesAt(3)),
412 fWeightsInputGradients(this->GetWeightGradientsAt(0)),
413 fWeightsInputStateGradients(this->GetWeightGradientsAt(4)),
414 fInputBiasGradients(this->GetBiasGradientsAt(0)),
415 fWeightsForgetGradients(this->GetWeightGradientsAt(1)),
416 fWeightsForgetStateGradients(this->GetWeightGradientsAt(5)),
417 fForgetBiasGradients(this->GetBiasGradientsAt(1)),
418 fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
419 fWeightsCandidateStateGradients(this->GetWeightGradientsAt(6)),
420 fCandidateBiasGradients(this->GetBiasGradientsAt(2)),
421 fWeightsOutputGradients(this->GetWeightGradientsAt(3)),
422 fWeightsOutputStateGradients(this->GetWeightGradientsAt(7)),
423 fOutputBiasGradients(this->GetBiasGradientsAt(3))
464 Architecture_t::InitializeLSTMTensors(
this);
468template <
typename Architecture_t>
473 Architecture_t::InitializeLSTMDescriptors(fDescriptors,
this);
474 Architecture_t::InitializeLSTMWorkspace(fWorkspace, fDescriptors,
this);
478template <
typename Architecture_t>
486 Matrix_t tmpState(fInputValue.GetNrows(), fInputValue.GetNcols());
487 Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsInputGateState);
488 Architecture_t::MultiplyTranspose(fInputValue, input, fWeightsInputGate);
489 Architecture_t::ScaleAdd(fInputValue, tmpState);
490 Architecture_t::AddRowWise(fInputValue, fInputGateBias);
491 DNN::evaluateDerivativeMatrix<Architecture_t>(di, fInp, fInputValue);
492 DNN::evaluateMatrix<Architecture_t>(fInputValue, fInp);
496template <
typename Architecture_t>
504 Matrix_t tmpState(fForgetValue.GetNrows(), fForgetValue.GetNcols());
505 Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsForgetGateState);
506 Architecture_t::MultiplyTranspose(fForgetValue, input, fWeightsForgetGate);
507 Architecture_t::ScaleAdd(fForgetValue, tmpState);
508 Architecture_t::AddRowWise(fForgetValue, fForgetGateBias);
509 DNN::evaluateDerivativeMatrix<Architecture_t>(df, fFor, fForgetValue);
510 DNN::evaluateMatrix<Architecture_t>(fForgetValue, fFor);
514template <
typename Architecture_t>
522 Matrix_t tmpState(fCandidateValue.GetNrows(), fCandidateValue.GetNcols());
523 Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsCandidateState);
524 Architecture_t::MultiplyTranspose(fCandidateValue, input, fWeightsCandidate);
525 Architecture_t::ScaleAdd(fCandidateValue, tmpState);
526 Architecture_t::AddRowWise(fCandidateValue, fCandidateBias);
527 DNN::evaluateDerivativeMatrix<Architecture_t>(dc, fCan, fCandidateValue);
528 DNN::evaluateMatrix<Architecture_t>(fCandidateValue, fCan);
532template <
typename Architecture_t>
540 Matrix_t tmpState(fOutputValue.GetNrows(), fOutputValue.GetNcols());
541 Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsOutputGateState);
542 Architecture_t::MultiplyTranspose(fOutputValue, input, fWeightsOutputGate);
543 Architecture_t::ScaleAdd(fOutputValue, tmpState);
544 Architecture_t::AddRowWise(fOutputValue, fOutputGateBias);
545 DNN::evaluateDerivativeMatrix<Architecture_t>(dout, fOut, fOutputValue);
546 DNN::evaluateMatrix<Architecture_t>(fOutputValue, fOut);
552template <
typename Architecture_t>
558 if (Architecture_t::IsCudnn()) {
561 assert(input.GetStrides()[1] == this->GetInputSize());
565 Architecture_t::Rearrange(
x, input);
567 const auto &weights = this->GetWeightsAt(0);
572 auto &hx = this->fState;
574 auto &cx = this->fCell;
576 auto &hy = this->fState;
577 auto &cy = this->fCell;
582 Architecture_t::RNNForward(
x, hx, cx, weights,
y, hy, cy, rnnDesc, rnnWork, isTraining);
584 if (fReturnSequence) {
585 Architecture_t::Rearrange(this->GetOutput(),
y);
588 Tensor_t tmp = (
y.At(
y.GetShape()[0] - 1)).Reshape({
y.GetShape()[1], 1,
y.GetShape()[2]});
589 Architecture_t::Copy(this->GetOutput(), tmp);
602 Tensor_t arrInput( fTimeSteps, this->GetBatchSize(), this->GetInputWidth());
605 Architecture_t::Rearrange(arrInput, input);
607 Tensor_t arrOutput ( fTimeSteps, this->GetBatchSize(), fStateSize);
610 if (!this->fRememberState) {
616 for (
size_t t = 0; t < fTimeSteps; ++t) {
619 InputGate(arrInputMt, fDerivativesInput[t]);
620 ForgetGate(arrInputMt, fDerivativesForget[t]);
621 CandidateValue(arrInputMt, fDerivativesCandidate[t]);
622 OutputGate(arrInputMt, fDerivativesOutput[t]);
624 Architecture_t::Copy(this->GetInputGateTensorAt(t), fInputValue);
625 Architecture_t::Copy(this->GetForgetGateTensorAt(t), fForgetValue);
626 Architecture_t::Copy(this->GetCandidateGateTensorAt(t), fCandidateValue);
627 Architecture_t::Copy(this->GetOutputGateTensorAt(t), fOutputValue);
629 CellForward(fInputValue, fForgetValue, fCandidateValue, fOutputValue);
630 Matrix_t arrOutputMt = arrOutput[t];
631 Architecture_t::Copy(arrOutputMt, fState);
632 Architecture_t::Copy(this->GetCellTensorAt(t), fCell);
637 Architecture_t::Rearrange(this->GetOutput(), arrOutput);
640 Tensor_t tmp = arrOutput.At(fTimeSteps - 1);
643 tmp = tmp.Reshape( {tmp.GetShape()[0], tmp.GetShape()[1], 1});
644 assert(tmp.GetSize() == this->GetOutput().GetSize());
645 assert( tmp.GetShape()[0] == this->GetOutput().GetShape()[2]);
646 Architecture_t::Rearrange(this->GetOutput(), tmp);
653template <
typename Architecture_t>
660 Architecture_t::Hadamard(fCell, forgetGateValues);
661 Architecture_t::Hadamard(inputGateValues, candidateValues);
662 Architecture_t::ScaleAdd(fCell, inputGateValues);
664 Matrix_t cache(fCell.GetNrows(), fCell.GetNcols());
665 Architecture_t::Copy(cache, fCell);
669 DNN::evaluateMatrix<Architecture_t>(cache, fAT);
674 Architecture_t::Copy(fState, cache);
675 Architecture_t::Hadamard(fState, outputGateValues);
679template <
typename Architecture_t>
681 const Tensor_t &activations_backward)
686 if (Architecture_t::IsCudnn()) {
694 assert(activations_backward.GetStrides()[1] == this->GetInputSize());
696 Architecture_t::Rearrange(
x, activations_backward);
698 if (!fReturnSequence) {
701 Architecture_t::InitializeZero(dy);
706 Tensor_t tmp2 = dy.At(dy.GetShape()[0] - 1).Reshape({dy.GetShape()[1], 1, dy.GetShape()[2]});
709 Architecture_t::Copy(tmp2, this->GetActivationGradients());
711 Architecture_t::Rearrange(
y, this->GetOutput());
712 Architecture_t::Rearrange(dy, this->GetActivationGradients());
718 const auto &weights = this->GetWeightsTensor();
719 auto &weightGradients = this->GetWeightGradientsTensor();
722 Architecture_t::InitializeZero(weightGradients);
725 auto &hx = this->GetState();
726 auto &cx = this->GetCell();
737 Architecture_t::RNNBackward(
x, hx, cx,
y, dy, dhy, dcy, weights, dx, dhx, dcx, weightGradients, rnnDesc, rnnWork);
741 if (gradients_backward.GetSize() != 0)
742 Architecture_t::Rearrange(gradients_backward, dx);
751 Matrix_t state_gradients_backward(this->GetBatchSize(), fStateSize);
755 Matrix_t cell_gradients_backward(this->GetBatchSize(), fStateSize);
760 if (gradients_backward.GetSize() == 0 || gradients_backward[0].GetNrows() == 0 || gradients_backward[0].GetNcols() == 0) {
765 Tensor_t arr_gradients_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());
770 Tensor_t arr_activations_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());
772 Architecture_t::Rearrange(arr_activations_backward, activations_backward);
776 Tensor_t arr_output ( fTimeSteps, this->GetBatchSize(), fStateSize);
778 Matrix_t initState(this->GetBatchSize(), fCellSize);
783 Tensor_t arr_actgradients(fTimeSteps, this->GetBatchSize(), fStateSize);
785 if (fReturnSequence) {
786 Architecture_t::Rearrange(arr_output, this->GetOutput());
787 Architecture_t::Rearrange(arr_actgradients, this->GetActivationGradients());
791 Architecture_t::InitializeZero(arr_actgradients);
793 Tensor_t tmp_grad = arr_actgradients.At(fTimeSteps - 1).Reshape( {this->GetBatchSize(), fStateSize, 1});
794 assert(tmp_grad.GetSize() == this->GetActivationGradients().GetSize());
795 assert(tmp_grad.GetShape()[0] == this->GetActivationGradients().GetShape()[2]);
797 Architecture_t::Rearrange(tmp_grad, this->GetActivationGradients());
804 fWeightsInputGradients.Zero();
805 fWeightsInputStateGradients.Zero();
806 fInputBiasGradients.Zero();
809 fWeightsForgetGradients.Zero();
810 fWeightsForgetStateGradients.Zero();
811 fForgetBiasGradients.Zero();
814 fWeightsCandidateGradients.Zero();
815 fWeightsCandidateStateGradients.Zero();
816 fCandidateBiasGradients.Zero();
819 fWeightsOutputGradients.Zero();
820 fWeightsOutputStateGradients.Zero();
821 fOutputBiasGradients.Zero();
824 for (
size_t t = fTimeSteps; t > 0; t--) {
826 Architecture_t::ScaleAdd(state_gradients_backward, arr_actgradients[t-1]);
828 const Matrix_t &prevStateActivations = arr_output[t-2];
829 const Matrix_t &prevCellActivations = this->GetCellTensorAt(t-2);
831 Matrix_t dx = arr_gradients_backward[t-1];
832 CellBackward(state_gradients_backward, cell_gradients_backward,
833 prevStateActivations, prevCellActivations,
834 this->GetInputGateTensorAt(t-1), this->GetForgetGateTensorAt(t-1),
835 this->GetCandidateGateTensorAt(t-1), this->GetOutputGateTensorAt(t-1),
836 arr_activations_backward[t-1], dx,
837 fDerivativesInput[t-1], fDerivativesForget[t-1],
838 fDerivativesCandidate[t-1], fDerivativesOutput[t-1], t-1);
840 const Matrix_t &prevStateActivations = initState;
841 const Matrix_t &prevCellActivations = initState;
842 Matrix_t dx = arr_gradients_backward[t-1];
843 CellBackward(state_gradients_backward, cell_gradients_backward,
844 prevStateActivations, prevCellActivations,
845 this->GetInputGateTensorAt(t-1), this->GetForgetGateTensorAt(t-1),
846 this->GetCandidateGateTensorAt(t-1), this->GetOutputGateTensorAt(t-1),
847 arr_activations_backward[t-1], dx,
848 fDerivativesInput[t-1], fDerivativesForget[t-1],
849 fDerivativesCandidate[t-1], fDerivativesOutput[t-1], t-1);
854 Architecture_t::Rearrange(gradients_backward, arr_gradients_backward );
861template <
typename Architecture_t>
864 const Matrix_t & precStateActivations,
const Matrix_t & precCellActivations,
878 Matrix_t cell_gradient(this->GetCellTensorAt(t).GetNrows(), this->GetCellTensorAt(t).GetNcols());
879 DNN::evaluateDerivativeMatrix<Architecture_t>(cell_gradient, fAT, this->GetCellTensorAt(t));
882 Matrix_t cell_tanh(this->GetCellTensorAt(t).GetNrows(), this->GetCellTensorAt(t).GetNcols());
883 Architecture_t::Copy(cell_tanh, this->GetCellTensorAt(t));
884 DNN::evaluateMatrix<Architecture_t>(cell_tanh, fAT);
886 return Architecture_t::LSTMLayerBackward(state_gradients_backward, cell_gradients_backward,
887 fWeightsInputGradients, fWeightsForgetGradients, fWeightsCandidateGradients,
888 fWeightsOutputGradients, fWeightsInputStateGradients, fWeightsForgetStateGradients,
889 fWeightsCandidateStateGradients, fWeightsOutputStateGradients, fInputBiasGradients, fForgetBiasGradients,
890 fCandidateBiasGradients, fOutputBiasGradients, di, df, dc, dout,
891 precStateActivations, precCellActivations,
892 input_gate, forget_gate, candidate_gate, output_gate,
893 fWeightsInputGate, fWeightsForgetGate, fWeightsCandidate, fWeightsOutputGate,
894 fWeightsInputGateState, fWeightsForgetGateState, fWeightsCandidateState,
895 fWeightsOutputGateState, input, input_gradient,
896 cell_gradient, cell_tanh);
900template <
typename Architecture_t>
909template<
typename Architecture_t>
913 std::cout <<
" LSTM Layer: \t ";
914 std::cout <<
" (NInput = " << this->GetInputSize();
915 std::cout <<
", NState = " << this->GetStateSize();
916 std::cout <<
", NTime = " << this->GetTimeSteps() <<
" )";
917 std::cout <<
"\tOutput = ( " << this->GetOutput().GetFirstSize() <<
" , " << this->GetOutput()[0].GetNrows() <<
" , " << this->GetOutput()[0].GetNcols() <<
" )\n";
921template <
typename Architecture_t>
936 this->WriteMatrixToXML(layerxml,
"InputWeights", this->GetWeightsAt(0));
937 this->WriteMatrixToXML(layerxml,
"InputStateWeights", this->GetWeightsAt(1));
938 this->WriteMatrixToXML(layerxml,
"InputBiases", this->GetBiasesAt(0));
939 this->WriteMatrixToXML(layerxml,
"ForgetWeights", this->GetWeightsAt(2));
940 this->WriteMatrixToXML(layerxml,
"ForgetStateWeights", this->GetWeightsAt(3));
941 this->WriteMatrixToXML(layerxml,
"ForgetBiases", this->GetBiasesAt(1));
942 this->WriteMatrixToXML(layerxml,
"CandidateWeights", this->GetWeightsAt(4));
943 this->WriteMatrixToXML(layerxml,
"CandidateStateWeights", this->GetWeightsAt(5));
944 this->WriteMatrixToXML(layerxml,
"CandidateBiases", this->GetBiasesAt(2));
945 this->WriteMatrixToXML(layerxml,
"OuputWeights", this->GetWeightsAt(6));
946 this->WriteMatrixToXML(layerxml,
"OutputStateWeights", this->GetWeightsAt(7));
947 this->WriteMatrixToXML(layerxml,
"OutputBiases", this->GetBiasesAt(3));
951template <
typename Architecture_t>
956 this->ReadMatrixXML(parent,
"InputWeights", this->GetWeightsAt(0));
957 this->ReadMatrixXML(parent,
"InputStateWeights", this->GetWeightsAt(1));
958 this->ReadMatrixXML(parent,
"InputBiases", this->GetBiasesAt(0));
959 this->ReadMatrixXML(parent,
"ForgetWeights", this->GetWeightsAt(2));
960 this->ReadMatrixXML(parent,
"ForgetStateWeights", this->GetWeightsAt(3));
961 this->ReadMatrixXML(parent,
"ForgetBiases", this->GetBiasesAt(1));
962 this->ReadMatrixXML(parent,
"CandidateWeights", this->GetWeightsAt(4));
963 this->ReadMatrixXML(parent,
"CandidateStateWeights", this->GetWeightsAt(5));
964 this->ReadMatrixXML(parent,
"CandidateBiases", this->GetBiasesAt(2));
965 this->ReadMatrixXML(parent,
"OuputWeights", this->GetWeightsAt(6));
966 this->ReadMatrixXML(parent,
"OutputStateWeights", this->GetWeightsAt(7));
967 this->ReadMatrixXML(parent,
"OutputBiases", this->GetBiasesAt(3));
void InputGate(const Matrix_t &input, Matrix_t &di)
Decides the values we'll update (NN with Sigmoid)
const Matrix_t & GetForgetGateTensorAt(size_t i) const
Matrix_t & GetWeightsOutputGateState()
const std::vector< Matrix_t > & GetOutputGateTensor() const
Tensor_t fWeightsTensor
Tensor for all weights.
const std::vector< Matrix_t > & GetInputGateTensor() const
std::vector< Matrix_t > & GetDerivativesOutput()
const Matrix_t & GetWeigthsForgetStateGradients() const
Matrix_t & GetWeightsForgetGate()
typename Architecture_t::Matrix_t Matrix_t
Matrix_t & GetCandidateGateTensorAt(size_t i)
void InitState(DNN::EInitialization m=DNN::EInitialization::kZero)
Initializes the hidden state and the cell state.
Matrix_t & fWeightsCandidateGradients
Gradients w.r.t the candidate gate - input weights.
const Matrix_t & GetOutputGateBias() const
Matrix_t & GetWeightsCandidateStateGradients()
Matrix_t & GetWeightsInputGate()
Matrix_t & GetWeightsInputGateState()
const std::vector< Matrix_t > & GetCandidateGateTensor() const
const Matrix_t & GetInputGateTensorAt(size_t i) const
std::vector< Matrix_t > & GetForgetGateTensor()
std::vector< Matrix_t > cell_value
cell value for every time step
Matrix_t & fWeightsOutputGradients
Gradients w.r.t the output gate - input weights.
Matrix_t & GetOutputGateBias()
Matrix_t & fOutputBiasGradients
Gradients w.r.t the output gate - bias weights.
DNN::EActivationFunction fF1
Activation function: sigmoid.
virtual void Initialize()
Initialize the weights according to the given initialization method.
Tensor_t fDy
cached activation gradient (input of backward) as T x B x S
Matrix_t & fWeightsOutputGate
Output Gate weights for input, fWeights[3].
Matrix_t & GetForgetGateBias()
Matrix_t & fWeightsCandidateStateGradients
Gradients w.r.t the candidate gate - hidden state weights.
void Forward(Tensor_t &input, bool isTraining=true)
Computes the next hidden state and next cell state with given input matrix.
const Matrix_t & GetInputGateBias() const
typename Architecture_t::Scalar_t Scalar_t
size_t GetInputSize() const
Getters.
Matrix_t & GetForgetGateTensorAt(size_t i)
const Matrix_t & GetOutputGateTensorAt(size_t i) const
const Matrix_t & GetCellTensorAt(size_t i) const
Tensor_t fX
cached input tensor as T x B x I
DNN::EActivationFunction GetActivationFunctionF2() const
Matrix_t & GetCellTensorAt(size_t i)
Matrix_t & fWeightsInputStateGradients
Gradients w.r.t the input gate - hidden state weights.
void CellForward(Matrix_t &inputGateValues, const Matrix_t &forgetGateValues, const Matrix_t &candidateValues, const Matrix_t &outputGateValues)
Forward for a single cell (time unit)
Matrix_t & CellBackward(Matrix_t &state_gradients_backward, Matrix_t &cell_gradients_backward, const Matrix_t &precStateActivations, const Matrix_t &precCellActivations, const Matrix_t &input_gate, const Matrix_t &forget_gate, const Matrix_t &candidate_gate, const Matrix_t &output_gate, const Matrix_t &input, Matrix_t &input_gradient, Matrix_t &di, Matrix_t &df, Matrix_t &dc, Matrix_t &dout, size_t t)
Backward pass for a single time unit, corresponding to a call to Forward(...).
const Matrix_t & GetWeightsInputStateGradients() const
std::vector< Matrix_t > fDerivativesOutput
First derivatives of the output gate activations.
size_t GetStateSize() const
Matrix_t & fWeightsForgetGateState
Forget Gate weights for prev state, fWeights[5].
Matrix_t & fOutputGateBias
Output Gate bias.
std::vector< Matrix_t > fDerivativesCandidate
First derivatives of the candidate gate activations.
const Matrix_t & GetInputDerivativesAt(size_t i) const
Matrix_t & fWeightsForgetGate
Forget Gate weights for input, fWeights[1].
Matrix_t & fWeightsInputGradients
Gradients w.r.t the input gate - input weights.
typename Architecture_t::Tensor_t Tensor_t
const std::vector< Matrix_t > & GetDerivativesInput() const
Matrix_t & GetWeightsCandidate()
Matrix_t & fForgetGateBias
Forget Gate bias.
Matrix_t & GetWeightsInputGradients()
Matrix_t & GetCandidateBiasGradients()
Matrix_t & GetWeightsOutputGradients()
Matrix_t & fCandidateBias
Candidate Gate bias.
Matrix_t fCandidateValue
Computed candidate values.
Tensor_t & GetWeightGradientsTensor()
bool DoesRememberState() const
const Matrix_t & GetWeightsOutputGradients() const
typename Architecture_t::RecurrentDescriptor_t LayerDescriptor_t
const Matrix_t & GetWeightsInputGradients() const
Matrix_t & GetWeightsCandidateState()
Matrix_t & GetInputBiasGradients()
const Matrix_t & GetInputBiasGradients() const
size_t GetTimeSteps() const
DNN::EActivationFunction fF2
Activation function: tanh.
Matrix_t & fInputBiasGradients
Gradients w.r.t the input gate - bias weights.
Matrix_t & GetWeightsOutputStateGradients()
Matrix_t & fWeightsCandidateState
Candidate Gate weights for prev state, fWeights[6].
void AddWeightsXMLTo(void *parent)
Writes the information and the weights about the layer in an XML node.
Matrix_t & GetForgetGateValue()
std::vector< Matrix_t > fDerivativesForget
First derivatives of the forget gate activations.
const Tensor_t & GetWeightGradientsTensor() const
Matrix_t & GetForgetDerivativesAt(size_t i)
const Matrix_t & GetWeightsInputGateState() const
Matrix_t & GetWeightsInputStateGradients()
typename Architecture_t::DropoutDescriptor_t HelperDescriptor_t
Matrix_t & fForgetBiasGradients
Gradients w.r.t the forget gate - bias weights.
const Matrix_t & GetCandidateBias() const
std::vector< Matrix_t > output_gate_value
output gate value for every time step
const std::vector< Matrix_t > & GetDerivativesCandidate() const
size_t fStateSize
Hidden state size for LSTM.
void CandidateValue(const Matrix_t &input, Matrix_t &dc)
Decides the new candidate values (NN with Tanh)
std::vector< Matrix_t > fDerivativesInput
First derivatives of the input gate activations.
const Matrix_t & GetWeightsForgetGateState() const
Matrix_t & GetWeightsForgetGateState()
const Matrix_t & GetWeightsInputGate() const
const Matrix_t & GetInputGateValue() const
void Update(const Scalar_t learningRate)
bool DoesReturnSequence() const
Tensor_t fDx
cached gradient on the input (output of backward) as T x B x I
typename Architecture_t::RNNWorkspace_t RNNWorkspace_t
Matrix_t & GetOutputGateValue()
TBasicLSTMLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps, bool rememberState=false, bool returnSequence=false, DNN::EActivationFunction f1=DNN::EActivationFunction::kSigmoid, DNN::EActivationFunction f2=DNN::EActivationFunction::kTanh, bool training=true, DNN::EInitialization fA=DNN::EInitialization::kZero)
Constructor.
Matrix_t & GetWeightsForgetStateGradients()
const Matrix_t & GetOutputBiasGradients() const
typename Architecture_t::TensorDescriptor_t TensorDescriptor_t
const Matrix_t & GetWeightsOutputStateGradients() const
Matrix_t & fWeightsOutputStateGradients
Gradients w.r.t the output gate - hidden state weights.
bool fReturnSequence
Return in output full sequence or just last element.
Matrix_t & GetWeightsForgetGradients()
Matrix_t & GetWeightsCandidateGradients()
const Matrix_t & GetWeightsForgetGradients() const
Matrix_t fCell
Cell state of LSTM.
std::vector< Matrix_t > & GetDerivativesCandidate()
const Matrix_t & GetForgetBiasGradients() const
std::vector< Matrix_t > & GetOutputGateTensor()
Matrix_t & GetCandidateValue()
const Matrix_t & GetForgetDerivativesAt(size_t i) const
Matrix_t fState
Hidden state of LSTM.
void OutputGate(const Matrix_t &input, Matrix_t &dout)
Computes output values (NN with Sigmoid)
const Matrix_t & GetForgetGateValue() const
std::vector< Matrix_t > candidate_gate_value
candidate gate value for every time step
Matrix_t & GetInputGateValue()
const Matrix_t & GetState() const
const Matrix_t & GetWeightsCandidateState() const
Matrix_t & GetCandidateBias()
const std::vector< Matrix_t > & GetForgetGateTensor() const
const std::vector< Matrix_t > & GetDerivativesOutput() const
const std::vector< Matrix_t > & GetCellTensor() const
const Tensor_t & GetWeightsTensor() const
Matrix_t & fWeightsInputGate
Input Gate weights for input, fWeights[0].
std::vector< Matrix_t > & GetCandidateGateTensor()
const Matrix_t & GetOutputDerivativesAt(size_t i) const
void ReadWeightsFromXML(void *parent)
Read the information and the weights about the layer from XML node.
const Matrix_t & GetCell() const
Matrix_t & fWeightsForgetStateGradients
Gradients w.r.t the forget gate - hidden state weights.
const Matrix_t & GetCandidateGateTensorAt(size_t i) const
Matrix_t fOutputValue
Computed output gate values.
size_t fCellSize
Cell state size of LSTM.
Matrix_t & GetOutputDerivativesAt(size_t i)
Matrix_t & GetInputGateTensorAt(size_t i)
std::vector< Matrix_t > & GetDerivativesInput()
Matrix_t & fWeightsOutputGateState
Output Gate weights for prev state, fWeights[7].
const std::vector< Matrix_t > & GetDerivativesForget() const
Matrix_t & GetForgetBiasGradients()
const Matrix_t & GetForgetGateBias() const
const Matrix_t & GetCandidateDerivativesAt(size_t i) const
Matrix_t & GetInputGateBias()
Matrix_t & GetOutputGateTensorAt(size_t i)
size_t fTimeSteps
Timesteps for LSTM.
const Matrix_t & GetCandidateBiasGradients() const
const Matrix_t & GetCandidateValue() const
typename Architecture_t::FilterDescriptor_t WeightsDescriptor_t
Matrix_t & fInputGateBias
Input Gate bias.
const Matrix_t & GetWeightsForgetGate() const
std::vector< Matrix_t > input_gate_value
input gate value for every time step
const Matrix_t & GetWeightsCandidateStateGradients() const
Tensor_t & GetWeightsTensor()
Matrix_t & fWeightsForgetGradients
Gradients w.r.t the forget gate - input weights.
std::vector< Matrix_t > & GetDerivativesForget()
const Matrix_t & GetWeightsOutputGate() const
void ForgetGate(const Matrix_t &input, Matrix_t &df)
Forgets the past values (NN with Sigmoid)
std::vector< Matrix_t > & GetInputGateTensor()
Matrix_t & GetOutputBiasGradients()
const Matrix_t & GetOutputGateValue() const
const Matrix_t & GetWeightsOutputGateState() const
Matrix_t & GetCandidateDerivativesAt(size_t i)
Matrix_t fInputValue
Computed input gate values.
Matrix_t & GetWeightsOutputGate()
const Matrix_t & GetWeightsCandidate() const
void Print() const
Prints the info about the layer.
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward)
Backpropagates the error.
const Matrix_t & GetWeightsCandidateGradients() const
Tensor_t fWeightGradientsTensor
Tensor for all weight gradients.
Matrix_t & GetInputDerivativesAt(size_t i)
typename Architecture_t::RNNDescriptors_t RNNDescriptors_t
DNN::EActivationFunction GetActivationFunctionF1() const
Tensor_t fY
cached output tensor as T x B x S
std::vector< Matrix_t > forget_gate_value
forget gate value for every time step
Matrix_t & fWeightsCandidate
Candidate Gate weights for input, fWeights[2].
bool fRememberState
Remember state in next pass.
Matrix_t & fWeightsInputGateState
Input Gate weights for prev state, fWeights[4].
TDescriptors * fDescriptors
Keeps all the RNN descriptors.
std::vector< Matrix_t > & GetCellTensor()
size_t GetCellSize() const
Matrix_t & fCandidateBiasGradients
Gradients w.r.t the candidate gate - bias weights.
Matrix_t fForgetValue
Computed forget gate values.
Generic General Layer class.
virtual void Initialize()
Initialize the weights and biases according to the given initialization method.
size_t GetBatchSize() const
Getters.
size_t GetInputWidth() const
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
EActivationFunction
Enum that represents layer activation functions.
create variable transformations