#ifndef TMVA_DNN_LSTM_LAYER
#define TMVA_DNN_LSTM_LAYER
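/** \class TBasicLSTMLayer
    Generic implementation of a long short-term memory (LSTM) recurrent layer.

    As a reading aid, a sketch of the standard LSTM equations this layer
    follows (sigma is the logistic function, * the element-wise product);
    per time step t:

       i_t = sigma(W_i x_t + U_i h_{t-1} + b_i)    input gate
       f_t = sigma(W_f x_t + U_f h_{t-1} + b_f)    forget gate
       g_t = tanh (W_g x_t + U_g h_{t-1} + b_g)    candidate value
       o_t = sigma(W_o x_t + U_o h_{t-1} + b_o)    output gate
       C_t = f_t * C_{t-1} + i_t * g_t             cell state
       h_t = o_t * tanh(C_t)                       hidden state
*/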
template <typename Architecture_t>
class TBasicLSTMLayer : public VGeneralLayer<Architecture_t> {

public:
   using Matrix_t = typename Architecture_t::Matrix_t;
   using Scalar_t = typename Architecture_t::Scalar_t;
   using Tensor_t = typename Architecture_t::Tensor_t;

   using LayerDescriptor_t = typename Architecture_t::RecurrentDescriptor_t;
   using WeightsDescriptor_t = typename Architecture_t::FilterDescriptor_t;
   using TensorDescriptor_t = typename Architecture_t::TensorDescriptor_t;
   using HelperDescriptor_t = typename Architecture_t::DropoutDescriptor_t;

   using RNNWorkspace_t = typename Architecture_t::RNNWorkspace_t;
   using RNNDescriptors_t = typename Architecture_t::RNNDescriptors_t;

private:
   size_t fStateSize; ///< Hidden state size for LSTM
   size_t fCellSize;  ///< Cell state size of LSTM
   size_t fTimeSteps; ///< Timesteps for LSTM

   bool fRememberState;  ///< Remember state in next pass
   bool fReturnSequence; ///< Return in output full sequence or just last element in time

   DNN::EActivationFunction fF1; ///< Activation function: sigmoid
   DNN::EActivationFunction fF2; ///< Activation function: tanh

   Matrix_t fInputValue;     ///< Computed input gate values
   Matrix_t fCandidateValue; ///< Computed candidate values
   Matrix_t fForgetValue;    ///< Computed forget gate values
   Matrix_t fOutputValue;    ///< Computed output gate values
   Matrix_t fState;          ///< Hidden state of LSTM
   Matrix_t fCell;           ///< Cell state of LSTM

   std::vector<Matrix_t> fDerivativesInput;     ///< First derivatives of the activations, input gate
   std::vector<Matrix_t> fDerivativesForget;    ///< First derivatives of the activations, forget gate
   std::vector<Matrix_t> fDerivativesCandidate; ///< First derivatives of the activations, candidate gate
   std::vector<Matrix_t> fDerivativesOutput;    ///< First derivatives of the activations, output gate

   std::vector<Matrix_t> input_gate_value;     ///< Input gate value for every time step
   std::vector<Matrix_t> forget_gate_value;    ///< Forget gate value for every time step
   std::vector<Matrix_t> candidate_gate_value; ///< Candidate gate value for every time step
   std::vector<Matrix_t> output_gate_value;    ///< Output gate value for every time step
   std::vector<Matrix_t> cell_value;           ///< Cell value for every time step

   // Gate weights and biases, bound to entries of the weight/bias tensors of
   // VGeneralLayer; the indices follow the bindings in the constructor below.
   Matrix_t &fWeightsInputGate;       ///< Input Gate weights for input, fWeights[0]
   Matrix_t &fWeightsInputGateState;  ///< Input Gate weights for prev state, fWeights[4]
   Matrix_t &fInputGateBias;          ///< Input Gate bias, fBiases[0]
   Matrix_t &fWeightsForgetGate;      ///< Forget Gate weights for input, fWeights[1]
   Matrix_t &fWeightsForgetGateState; ///< Forget Gate weights for prev state, fWeights[5]
   Matrix_t &fForgetGateBias;         ///< Forget Gate bias, fBiases[1]
   Matrix_t &fWeightsCandidate;       ///< Candidate Gate weights for input, fWeights[2]
   Matrix_t &fWeightsCandidateState;  ///< Candidate Gate weights for prev state, fWeights[6]
   Matrix_t &fCandidateBias;          ///< Candidate Gate bias, fBiases[2]
   Matrix_t &fWeightsOutputGate;      ///< Output Gate weights for input, fWeights[3]
   Matrix_t &fWeightsOutputGateState; ///< Output Gate weights for prev state, fWeights[7]
   Matrix_t &fOutputGateBias;         ///< Output Gate bias, fBiases[3]

   // Gradients, bound analogously to the weight/bias gradient tensors.
   Matrix_t &fWeightsInputGradients;          ///< Gradients w.r.t. the input gate - input weights
   Matrix_t &fWeightsInputStateGradients;     ///< Gradients w.r.t. the input gate - hidden state weights
   Matrix_t &fInputBiasGradients;             ///< Gradients w.r.t. the input gate - bias weights
   Matrix_t &fWeightsForgetGradients;         ///< Gradients w.r.t. the forget gate - input weights
   Matrix_t &fWeightsForgetStateGradients;    ///< Gradients w.r.t. the forget gate - hidden state weights
   Matrix_t &fForgetBiasGradients;            ///< Gradients w.r.t. the forget gate - bias weights
   Matrix_t &fWeightsCandidateGradients;      ///< Gradients w.r.t. the candidate gate - input weights
   Matrix_t &fWeightsCandidateStateGradients; ///< Gradients w.r.t. the candidate gate - hidden state weights
   Matrix_t &fCandidateBiasGradients;         ///< Gradients w.r.t. the candidate gate - bias weights
   Matrix_t &fWeightsOutputGradients;         ///< Gradients w.r.t. the output gate - input weights
   Matrix_t &fWeightsOutputStateGradients;    ///< Gradients w.r.t. the output gate - hidden state weights
   Matrix_t &fOutputBiasGradients;            ///< Gradients w.r.t. the output gate - bias weights

   Tensor_t fWeightsTensor;         ///< Tensor for all weights
   Tensor_t fWeightGradientsTensor; ///< Tensor for all weight gradients

   Tensor_t fX;  ///< Cached input tensor as T x B x I
   Tensor_t fY;  ///< Cached output tensor as T x B x S
   Tensor_t fDx; ///< Cached gradient on the input (output of backward) as T x B x I
   Tensor_t fDy; ///< Cached activation gradient (input of backward) as T x B x S

   TDescriptors *fDescriptors = nullptr; ///< Keeps all the RNN descriptors
   TWorkspace *fWorkspace = nullptr;     ///< Workspace for the backend RNN kernels
public:
   /** Constructor */
   TBasicLSTMLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps,
                   bool rememberState = false, bool returnSequence = false,
                   DNN::EActivationFunction f1 = DNN::EActivationFunction::kSigmoid,
                   DNN::EActivationFunction f2 = DNN::EActivationFunction::kTanh,
                   bool training = true, DNN::EInitialization fA = DNN::EInitialization::kZero);

   /** Copy Constructor */
   TBasicLSTMLayer(const TBasicLSTMLayer &layer);

   /** Initialize the weights according to the given initialization method;
       also creates the backend descriptors and workspace. */
   virtual void Initialize();

   /** Initialize the hidden state and cell state method. */
   void InitState(DNN::EInitialization m = DNN::EInitialization::kZero);

   /** Decides the values we'll update (NN with Sigmoid) */
   void InputGate(const Matrix_t &input, Matrix_t &di);

   /** Forgets the past values (NN with Sigmoid) */
   void ForgetGate(const Matrix_t &input, Matrix_t &df);

   /** Decides the new candidate values (NN with Tanh) */
   void CandidateValue(const Matrix_t &input, Matrix_t &dc);

   /** Computes output values (NN with Sigmoid) */
   void OutputGate(const Matrix_t &input, Matrix_t &dout);

   /** Computes the next hidden state and next cell state with given input matrix. */
   void Forward(Tensor_t &input, bool isTraining = true);

   /** Forward for a single cell (time unit) */
   void CellForward(Matrix_t &inputGateValues, const Matrix_t &forgetGateValues,
                    const Matrix_t &candidateValues, const Matrix_t &outputGateValues);
   /** Backpropagates the error. */
   void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);

   /** Backward for a single time unit, matching the corresponding call to Forward(...). */
   Matrix_t &CellBackward(Matrix_t &state_gradients_backward, Matrix_t &cell_gradients_backward,
                          const Matrix_t &precStateActivations, const Matrix_t &precCellActivations,
                          const Matrix_t &input_gate, const Matrix_t &forget_gate,
                          const Matrix_t &candidate_gate, const Matrix_t &output_gate,
                          const Matrix_t &input, Matrix_t &input_gradient,
                          Matrix_t &di, Matrix_t &df, Matrix_t &dc, Matrix_t &dout, size_t t);

   void Update(const Scalar_t learningRate);

   /** Prints the info about the layer. */
   void Print() const;

   /** Writes the information and the weights about the layer in an XML node. */
   void AddWeightsXMLTo(void *parent);

   /** Read the information and the weights about the layer from XML node. */
   void ReadWeightsFromXML(void *parent);

   /** Getters */
   size_t GetInputSize() const { return this->GetInputWidth(); }
   size_t GetTimeSteps() const { return fTimeSteps; }
   size_t GetStateSize() const { return fStateSize; }
   size_t GetCellSize() const { return fCellSize; }
   bool DoesRememberState() const { return fRememberState; }
   bool DoesReturnSequence() const { return fReturnSequence; }
   DNN::EActivationFunction GetActivationFunctionF1() const { return fF1; }
   DNN::EActivationFunction GetActivationFunctionF2() const { return fF2; }
   const Matrix_t &GetState() const { return fState; }
   Matrix_t &GetState() { return fState; }
   const Matrix_t &GetCell() const { return fCell; }
   Matrix_t &GetCell() { return fCell; }
   Matrix_t &GetCellTensorAt(size_t i) { return cell_value[i]; }
   const Matrix_t &GetCellTensorAt(size_t i) const { return cell_value[i]; }
   Matrix_t &GetInputGateTensorAt(size_t i) { return input_gate_value[i]; }
   Matrix_t &GetForgetGateTensorAt(size_t i) { return forget_gate_value[i]; }
   Matrix_t &GetCandidateGateTensorAt(size_t i) { return candidate_gate_value[i]; }
   Matrix_t &GetOutputGateTensorAt(size_t i) { return output_gate_value[i]; }
   Tensor_t &GetWeightsTensor() { return fWeightsTensor; }
   Tensor_t &GetWeightGradientsTensor() { return fWeightGradientsTensor; }
   // ... (the remaining const and non-const getters for gate values, weights,
   //      biases, gradients and per-time-step derivatives follow the same pattern)
};
template <typename Architecture_t>
TBasicLSTMLayer<Architecture_t>::TBasicLSTMLayer(size_t batchSize, size_t stateSize, size_t inputSize,
                                                 size_t timeSteps, bool rememberState, bool returnSequence,
                                                 DNN::EActivationFunction f1, DNN::EActivationFunction f2,
                                                 bool /* training */, DNN::EInitialization fA)
   : VGeneralLayer<Architecture_t>(
        batchSize, 1, timeSteps, inputSize, 1, (returnSequence) ? timeSteps : 1, stateSize, 8,
        {stateSize, stateSize, stateSize, stateSize, stateSize, stateSize, stateSize, stateSize},
        {inputSize, inputSize, inputSize, inputSize, stateSize, stateSize, stateSize, stateSize}, 4,
        {stateSize, stateSize, stateSize, stateSize}, {1, 1, 1, 1}, batchSize, (returnSequence) ? timeSteps : 1,
        fA),
     fStateSize(stateSize), fCellSize(stateSize), fTimeSteps(timeSteps), fRememberState(rememberState),
     fReturnSequence(returnSequence), fF1(f1), fF2(f2), fInputValue(batchSize, stateSize),
     fCandidateValue(batchSize, stateSize), fForgetValue(batchSize, stateSize), fOutputValue(batchSize, stateSize),
     fState(batchSize, stateSize), fCell(batchSize, stateSize), fWeightsInputGate(this->GetWeightsAt(0)),
     fWeightsInputGateState(this->GetWeightsAt(4)), fInputGateBias(this->GetBiasesAt(0)),
     fWeightsForgetGate(this->GetWeightsAt(1)), fWeightsForgetGateState(this->GetWeightsAt(5)),
     fForgetGateBias(this->GetBiasesAt(1)), fWeightsCandidate(this->GetWeightsAt(2)),
     fWeightsCandidateState(this->GetWeightsAt(6)), fCandidateBias(this->GetBiasesAt(2)),
     fWeightsOutputGate(this->GetWeightsAt(3)), fWeightsOutputGateState(this->GetWeightsAt(7)),
     fOutputGateBias(this->GetBiasesAt(3)), fWeightsInputGradients(this->GetWeightGradientsAt(0)),
     fWeightsInputStateGradients(this->GetWeightGradientsAt(4)), fInputBiasGradients(this->GetBiasGradientsAt(0)),
     fWeightsForgetGradients(this->GetWeightGradientsAt(1)),
     fWeightsForgetStateGradients(this->GetWeightGradientsAt(5)), fForgetBiasGradients(this->GetBiasGradientsAt(1)),
     fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
     fWeightsCandidateStateGradients(this->GetWeightGradientsAt(6)),
     fCandidateBiasGradients(this->GetBiasGradientsAt(2)), fWeightsOutputGradients(this->GetWeightGradientsAt(3)),
     fWeightsOutputStateGradients(this->GetWeightGradientsAt(7)), fOutputBiasGradients(this->GetBiasGradientsAt(3))
{
   // allocate the per-time-step buffers: gate values, cell values and the
   // first derivatives of the gate activations, one B x H matrix per step
   for (size_t i = 0; i < timeSteps; ++i) {
      fDerivativesInput.emplace_back(batchSize, stateSize);
      fDerivativesForget.emplace_back(batchSize, stateSize);
      fDerivativesCandidate.emplace_back(batchSize, stateSize);
      fDerivativesOutput.emplace_back(batchSize, stateSize);
      input_gate_value.emplace_back(batchSize, stateSize);
      forget_gate_value.emplace_back(batchSize, stateSize);
      candidate_gate_value.emplace_back(batchSize, stateSize);
      output_gate_value.emplace_back(batchSize, stateSize);
      cell_value.emplace_back(batchSize, stateSize);
   }
   Architecture_t::InitializeLSTMTensors(this);
}
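// Parameter bookkeeping sketch (helper name is ours, for illustration): the
// shape lists passed to VGeneralLayer above create, per gate, one H x I input
// weight matrix, one H x H state weight matrix and one bias vector of size H,
// so the number of trainable parameters of the layer is:
inline size_t LSTMParameterCount(size_t inputSize, size_t stateSize)
{
   // 4 gates x (input weights + recurrent weights + bias)
   return 4 * (stateSize * inputSize + stateSize * stateSize + stateSize);
}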
template <typename Architecture_t>
TBasicLSTMLayer<Architecture_t>::TBasicLSTMLayer(const TBasicLSTMLayer &layer)
   : VGeneralLayer<Architecture_t>(layer),
     fStateSize(layer.fStateSize),
     fCellSize(layer.fCellSize),
     fTimeSteps(layer.fTimeSteps),
     fRememberState(layer.fRememberState),
     fReturnSequence(layer.fReturnSequence),
     fF1(layer.GetActivationFunctionF1()),
     fF2(layer.GetActivationFunctionF2()),
     fInputValue(layer.GetBatchSize(), layer.GetStateSize()),
     fCandidateValue(layer.GetBatchSize(), layer.GetStateSize()),
     fForgetValue(layer.GetBatchSize(), layer.GetStateSize()),
     fOutputValue(layer.GetBatchSize(), layer.GetStateSize()),
     fState(layer.GetBatchSize(), layer.GetStateSize()),
     fCell(layer.GetBatchSize(), layer.GetCellSize()),
     fWeightsInputGate(this->GetWeightsAt(0)),
     fWeightsInputGateState(this->GetWeightsAt(4)),
     fInputGateBias(this->GetBiasesAt(0)),
     fWeightsForgetGate(this->GetWeightsAt(1)),
     fWeightsForgetGateState(this->GetWeightsAt(5)),
     fForgetGateBias(this->GetBiasesAt(1)),
     fWeightsCandidate(this->GetWeightsAt(2)),
     fWeightsCandidateState(this->GetWeightsAt(6)),
     fCandidateBias(this->GetBiasesAt(2)),
     fWeightsOutputGate(this->GetWeightsAt(3)),
     fWeightsOutputGateState(this->GetWeightsAt(7)),
     fOutputGateBias(this->GetBiasesAt(3)),
     fWeightsInputGradients(this->GetWeightGradientsAt(0)),
     fWeightsInputStateGradients(this->GetWeightGradientsAt(4)),
     fInputBiasGradients(this->GetBiasGradientsAt(0)),
     fWeightsForgetGradients(this->GetWeightGradientsAt(1)),
     fWeightsForgetStateGradients(this->GetWeightGradientsAt(5)),
     fForgetBiasGradients(this->GetBiasGradientsAt(1)),
     fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
     fWeightsCandidateStateGradients(this->GetWeightGradientsAt(6)),
     fCandidateBiasGradients(this->GetBiasGradientsAt(2)),
     fWeightsOutputGradients(this->GetWeightGradientsAt(3)),
     fWeightsOutputStateGradients(this->GetWeightGradientsAt(7)),
     fOutputBiasGradients(this->GetBiasGradientsAt(3))
{
   // (copying of the per-time-step gate-value, cell-value and derivative
   //  buffers and of the state/cell contents of "layer" is elided here)
   Architecture_t::InitializeLSTMTensors(this);
}
template <typename Architecture_t>
void TBasicLSTMLayer<Architecture_t>::Initialize()
{
   VGeneralLayer<Architecture_t>::Initialize();

   Architecture_t::InitializeLSTMDescriptors(fDescriptors, this);
   Architecture_t::InitializeLSTMWorkspace(fWorkspace, fDescriptors, this);
}
template <typename Architecture_t>
inline void TBasicLSTMLayer<Architecture_t>::InputGate(const Matrix_t &input, Matrix_t &di)
{
   // i = sigma(W_i . input + U_i . state + bias); di receives the activation
   // derivative, evaluated on the pre-activation values
   const DNN::EActivationFunction fInp = this->GetActivationFunctionF1();
   Matrix_t tmpState(fInputValue.GetNrows(), fInputValue.GetNcols());
   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsInputGateState);
   Architecture_t::MultiplyTranspose(fInputValue, input, fWeightsInputGate);
   Architecture_t::ScaleAdd(fInputValue, tmpState);
   Architecture_t::AddRowWise(fInputValue, fInputGateBias);
   DNN::evaluateDerivativeMatrix<Architecture_t>(di, fInp, fInputValue);
   DNN::evaluateMatrix<Architecture_t>(fInputValue, fInp);
}
template <typename Architecture_t>
inline void TBasicLSTMLayer<Architecture_t>::ForgetGate(const Matrix_t &input, Matrix_t &df)
{
   // f = sigma(W_f . input + U_f . state + bias); df receives the activation
   // derivative, evaluated on the pre-activation values
   const DNN::EActivationFunction fFor = this->GetActivationFunctionF1();
   Matrix_t tmpState(fForgetValue.GetNrows(), fForgetValue.GetNcols());
   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsForgetGateState);
   Architecture_t::MultiplyTranspose(fForgetValue, input, fWeightsForgetGate);
   Architecture_t::ScaleAdd(fForgetValue, tmpState);
   Architecture_t::AddRowWise(fForgetValue, fForgetGateBias);
   DNN::evaluateDerivativeMatrix<Architecture_t>(df, fFor, fForgetValue);
   DNN::evaluateMatrix<Architecture_t>(fForgetValue, fFor);
}
template <typename Architecture_t>
inline void TBasicLSTMLayer<Architecture_t>::CandidateValue(const Matrix_t &input, Matrix_t &dc)
{
   // g = tanh(W_g . input + U_g . state + bias); dc receives the activation
   // derivative, evaluated on the pre-activation values
   const DNN::EActivationFunction fCan = this->GetActivationFunctionF2();
   Matrix_t tmpState(fCandidateValue.GetNrows(), fCandidateValue.GetNcols());
   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsCandidateState);
   Architecture_t::MultiplyTranspose(fCandidateValue, input, fWeightsCandidate);
   Architecture_t::ScaleAdd(fCandidateValue, tmpState);
   Architecture_t::AddRowWise(fCandidateValue, fCandidateBias);
   DNN::evaluateDerivativeMatrix<Architecture_t>(dc, fCan, fCandidateValue);
   DNN::evaluateMatrix<Architecture_t>(fCandidateValue, fCan);
}
template <typename Architecture_t>
inline void TBasicLSTMLayer<Architecture_t>::OutputGate(const Matrix_t &input, Matrix_t &dout)
{
   // o = sigma(W_o . input + U_o . state + bias); dout receives the activation
   // derivative, evaluated on the pre-activation values
   const DNN::EActivationFunction fOut = this->GetActivationFunctionF1();
   Matrix_t tmpState(fOutputValue.GetNrows(), fOutputValue.GetNcols());
   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsOutputGateState);
   Architecture_t::MultiplyTranspose(fOutputValue, input, fWeightsOutputGate);
   Architecture_t::ScaleAdd(fOutputValue, tmpState);
   Architecture_t::AddRowWise(fOutputValue, fOutputGateBias);
   DNN::evaluateDerivativeMatrix<Architecture_t>(dout, fOut, fOutputValue);
   DNN::evaluateMatrix<Architecture_t>(fOutputValue, fOut);
}
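// All four gate methods above share the same affine-plus-activation pattern.
// A self-contained scalar sketch of that pattern (illustration only, not part
// of TMVA; the class applies it to whole B x H batch matrices):
#include <cmath>

inline double LSTMGateScalar(double wInput, double x, double wState, double hPrev, double bias,
                             bool useTanh = false)
{
   const double z = wInput * x + wState * hPrev + bias; // pre-activation value
   return useTanh ? std::tanh(z)                        // candidate value (fF2)
                  : 1.0 / (1.0 + std::exp(-z));         // input/forget/output gates (fF1)
}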
template <typename Architecture_t>
inline void TBasicLSTMLayer<Architecture_t>::Forward(Tensor_t &input, bool isTraining)
{
   // D : input size, H : state size, T : time size, B : batch size

   // cuDNN path: the whole sequence is evaluated by the backend RNN kernels
   if (Architecture_t::IsCudnn()) {
      // input tensor must have shape B x T x D with stride[1] equal to the input size
      assert(input.GetStrides()[1] == this->GetInputSize());

      Tensor_t &x = this->fX;
      Tensor_t &y = this->fY;
      Architecture_t::Rearrange(x, input); // B x T x D --> T x B x D

      const auto &weights = this->GetWeightsTensor();

      // hx, cx are the input state and cell; the same tensors are reused for
      // the output state hy and cell cy
      auto &hx = this->fState;
      auto &cx = this->fCell;
      auto &hy = this->fState;
      auto &cy = this->fCell;

      auto &rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
      auto &rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);

      Architecture_t::RNNForward(x, hx, cx, weights, y, hy, cy, rnnDesc, rnnWork, isTraining);

      if (fReturnSequence) {
         Architecture_t::Rearrange(this->GetOutput(), y); // T x B x S --> B x T x S
      } else {
         // keep only the output of the last time step: wanted output is B x 1 x S
         Tensor_t tmp = (y.At(y.GetShape()[0] - 1)).Reshape({y.GetShape()[1], 1, y.GetShape()[2]});
         Architecture_t::Copy(this->GetOutput(), tmp);
      }
      return;
   }

   // CPU path: loop over the time steps, applying the four gates and the cell update
   Tensor_t arrInput(fTimeSteps, this->GetBatchSize(), this->GetInputWidth());
   Architecture_t::Rearrange(arrInput, input); // B x T x D --> T x B x D
   Tensor_t arrOutput(fTimeSteps, this->GetBatchSize(), fStateSize);

   if (!this->fRememberState) {
      InitState(DNN::EInitialization::kZero);
   }

   for (size_t t = 0; t < fTimeSteps; ++t) {
      Matrix_t arrInputMt = arrInput[t];
      InputGate(arrInputMt, fDerivativesInput[t]);
      ForgetGate(arrInputMt, fDerivativesForget[t]);
      CandidateValue(arrInputMt, fDerivativesCandidate[t]);
      OutputGate(arrInputMt, fDerivativesOutput[t]);

      // store the gate values of this time step for the backward pass
      Architecture_t::Copy(this->GetInputGateTensorAt(t), fInputValue);
      Architecture_t::Copy(this->GetForgetGateTensorAt(t), fForgetValue);
      Architecture_t::Copy(this->GetCandidateGateTensorAt(t), fCandidateValue);
      Architecture_t::Copy(this->GetOutputGateTensorAt(t), fOutputValue);

      CellForward(fInputValue, fForgetValue, fCandidateValue, fOutputValue);
      Matrix_t arrOutputMt = arrOutput[t];
      Architecture_t::Copy(arrOutputMt, fState);
      Architecture_t::Copy(this->GetCellTensorAt(t), fCell);
   }

   if (fReturnSequence) {
      Architecture_t::Rearrange(this->GetOutput(), arrOutput); // T x B x S --> B x T x S
   } else {
      // keep only the last time step: the output is B x 1 x S
      Tensor_t tmp = arrOutput.At(fTimeSteps - 1); // take last time step, shape is B x S
      tmp = tmp.Reshape({tmp.GetShape()[0], tmp.GetShape()[1], 1});
      assert(tmp.GetSize() == this->GetOutput().GetSize());
      assert(tmp.GetShape()[0] == this->GetOutput().GetShape()[2]); // B is last dim of output tensor
      Architecture_t::Rearrange(this->GetOutput(), tmp);
   }
}
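// Shape rule illustrated by the two branches above: the layer emits the full
// sequence (B x T x S) when returnSequence is set and only the last time step
// (B x 1 x S) otherwise. A minimal helper capturing that rule (helper name is
// ours, for illustration):
#include <array>

inline std::array<size_t, 3> LSTMOutputShape(size_t batchSize, size_t timeSteps, size_t stateSize,
                                             bool returnSequence)
{
   return {batchSize, returnSequence ? timeSteps : size_t(1), stateSize};
}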
template <typename Architecture_t>
inline void TBasicLSTMLayer<Architecture_t>::CellForward(Matrix_t &inputGateValues, const Matrix_t &forgetGateValues,
                                                         const Matrix_t &candidateValues,
                                                         const Matrix_t &outputGateValues)
{
   // cell state update: C_t = f_t * C_{t-1} + i_t * g_t
   Architecture_t::Hadamard(fCell, forgetGateValues);
   Architecture_t::Hadamard(inputGateValues, candidateValues);
   Architecture_t::ScaleAdd(fCell, inputGateValues);

   Matrix_t cache(fCell.GetNrows(), fCell.GetNcols());
   Architecture_t::Copy(cache, fCell);

   // hidden state: h_t = o_t * tanh(C_t)
   const DNN::EActivationFunction fAT = this->GetActivationFunctionF2();
   DNN::evaluateMatrix<Architecture_t>(cache, fAT);
   Architecture_t::Copy(fState, cache);
   Architecture_t::Hadamard(fState, outputGateValues);
}
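// Scalar sketch of the update performed by CellForward (illustration only,
// not part of TMVA): per matrix element, the Hadamard/ScaleAdd calls above
// compute C_t = f * C_{t-1} + i * g and h_t = o * tanh(C_t).
#include <cmath>

inline void LSTMCellForwardScalar(double i, double f, double g, double o, double &c, double &h)
{
   c = f * c + i * g;    // cell state update
   h = o * std::tanh(c); // hidden state
}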
template <typename Architecture_t>
inline void TBasicLSTMLayer<Architecture_t>::Backward(Tensor_t &gradients_backward,
                                                      const Tensor_t &activations_backward)
{
   // cuDNN path: the whole backward pass is delegated to the backend RNN kernels
   if (Architecture_t::IsCudnn()) {
      Tensor_t &x = this->fX;
      Tensor_t &y = this->fY;
      Tensor_t &dx = this->fDx;
      Tensor_t &dy = this->fDy;

      // input tensor must have shape B x T x D with stride[1] equal to the input size
      assert(activations_backward.GetStrides()[1] == this->GetInputSize());
      Architecture_t::Rearrange(x, activations_backward); // B x T x D --> T x B x D

      if (!fReturnSequence) {
         // only the last time step receives a non-zero activation gradient
         Architecture_t::InitializeZero(dy);
         Tensor_t tmp2 = dy.At(dy.GetShape()[0] - 1).Reshape({dy.GetShape()[1], 1, dy.GetShape()[2]});
         Architecture_t::Copy(tmp2, this->GetActivationGradients());
      } else {
         Architecture_t::Rearrange(y, this->GetOutput());
         Architecture_t::Rearrange(dy, this->GetActivationGradients());
      }

      const auto &weights = this->GetWeightsTensor();
      auto &weightGradients = this->GetWeightGradientsTensor();
      // the backend accumulates the weight gradients, so they must be reset first
      Architecture_t::InitializeZero(weightGradients);

      // state/cell tensors and their gradients share storage here
      auto &hx = this->GetState();
      auto &cx = this->GetCell();
      auto &dhy = hx;
      auto &dcy = cx;
      auto &dhx = hx;
      auto &dcx = cx;

      auto &rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
      auto &rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);

      Architecture_t::RNNBackward(x, hx, cx, y, dy, dhy, dcy, weights, dx, dhx, dcx, weightGradients,
                                  rnnDesc, rnnWork);

      if (gradients_backward.GetSize() != 0)
         Architecture_t::Rearrange(gradients_backward, dx); // T x B x D --> B x T x D
      return;
   }

   // CPU path: backpropagation through time, from the last step to the first
   Matrix_t state_gradients_backward(this->GetBatchSize(), fStateSize); // B x H
   Matrix_t cell_gradients_backward(this->GetBatchSize(), fStateSize);  // B x H

   // if an empty tensor is passed, gradients w.r.t. the input are not computed
   bool dummyInputGradient = false;
   if (gradients_backward.GetSize() == 0 || gradients_backward[0].GetNrows() == 0 ||
       gradients_backward[0].GetNcols() == 0) {
      dummyInputGradient = true;
   }

   Tensor_t arr_gradients_backward(fTimeSteps, this->GetBatchSize(), this->GetInputSize());
   Tensor_t arr_activations_backward(fTimeSteps, this->GetBatchSize(), this->GetInputSize());
   Architecture_t::Rearrange(arr_activations_backward, activations_backward); // B x T x D --> T x B x D

   Tensor_t arr_output(fTimeSteps, this->GetBatchSize(), fStateSize);
   Matrix_t initState(this->GetBatchSize(), fCellSize); // B x H
   initState.Zero();                                    // initial state and cell are zero
   Tensor_t arr_actgradients(fTimeSteps, this->GetBatchSize(), fStateSize);

   if (fReturnSequence) {
      Architecture_t::Rearrange(arr_output, this->GetOutput());
      Architecture_t::Rearrange(arr_actgradients, this->GetActivationGradients());
   } else {
      // only the last time step receives a non-zero activation gradient
      Architecture_t::InitializeZero(arr_actgradients);
      Tensor_t tmp_grad = arr_actgradients.At(fTimeSteps - 1).Reshape({this->GetBatchSize(), fStateSize, 1});
      assert(tmp_grad.GetSize() == this->GetActivationGradients().GetSize());
      assert(tmp_grad.GetShape()[0] == this->GetActivationGradients().GetShape()[2]); // B is last dim
      Architecture_t::Rearrange(tmp_grad, this->GetActivationGradients());
   }

   // reset the gradient matrices before accumulating over the time steps
   fWeightsInputGradients.Zero();
   fWeightsInputStateGradients.Zero();
   fInputBiasGradients.Zero();
   fWeightsForgetGradients.Zero();
   fWeightsForgetStateGradients.Zero();
   fForgetBiasGradients.Zero();
   fWeightsCandidateGradients.Zero();
   fWeightsCandidateStateGradients.Zero();
   fCandidateBiasGradients.Zero();
   fWeightsOutputGradients.Zero();
   fWeightsOutputStateGradients.Zero();
   fOutputBiasGradients.Zero();

   for (size_t t = fTimeSteps; t > 0; t--) {
      // add the gradient coming from the output of this time step
      Architecture_t::ScaleAdd(state_gradients_backward, arr_actgradients[t - 1]);
      if (t > 1) {
         const Matrix_t &prevStateActivations = arr_output[t - 2];
         const Matrix_t &prevCellActivations = this->GetCellTensorAt(t - 2);
         Matrix_t dx = arr_gradients_backward[t - 1];
         CellBackward(state_gradients_backward, cell_gradients_backward,
                      prevStateActivations, prevCellActivations,
                      this->GetInputGateTensorAt(t - 1), this->GetForgetGateTensorAt(t - 1),
                      this->GetCandidateGateTensorAt(t - 1), this->GetOutputGateTensorAt(t - 1),
                      arr_activations_backward[t - 1], dx,
                      fDerivativesInput[t - 1], fDerivativesForget[t - 1],
                      fDerivativesCandidate[t - 1], fDerivativesOutput[t - 1], t - 1);
      } else {
         // first time step: the previous state and cell are the initial (zero) state
         const Matrix_t &prevStateActivations = initState;
         const Matrix_t &prevCellActivations = initState;
         Matrix_t dx = arr_gradients_backward[t - 1];
         CellBackward(state_gradients_backward, cell_gradients_backward,
                      prevStateActivations, prevCellActivations,
                      this->GetInputGateTensorAt(t - 1), this->GetForgetGateTensorAt(t - 1),
                      this->GetCandidateGateTensorAt(t - 1), this->GetOutputGateTensorAt(t - 1),
                      arr_activations_backward[t - 1], dx,
                      fDerivativesInput[t - 1], fDerivativesForget[t - 1],
                      fDerivativesCandidate[t - 1], fDerivativesOutput[t - 1], t - 1);
      }
   }

   if (!dummyInputGradient) {
      Architecture_t::Rearrange(gradients_backward, arr_gradients_backward); // T x B x D --> B x T x D
   }
}
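// When the layer returns only the last element of the sequence, the branches
// above seed a zeroed gradient tensor with the incoming gradient at the last
// time step only. A self-contained sketch of that seeding (helper name is
// ours, for illustration):
#include <algorithm>
#include <vector>

inline void SeedLastStepGradient(std::vector<std::vector<double>> &dyPerStep,
                                 const std::vector<double> &lastStepGradient)
{
   for (auto &stepGradient : dyPerStep)
      std::fill(stepGradient.begin(), stepGradient.end(), 0.0); // zero every time step
   dyPerStep.back() = lastStepGradient;                         // non-zero only at t = T-1
}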
template <typename Architecture_t>
inline auto TBasicLSTMLayer<Architecture_t>::CellBackward(
   Matrix_t &state_gradients_backward, Matrix_t &cell_gradients_backward, const Matrix_t &precStateActivations,
   const Matrix_t &precCellActivations, const Matrix_t &input_gate, const Matrix_t &forget_gate,
   const Matrix_t &candidate_gate, const Matrix_t &output_gate, const Matrix_t &input, Matrix_t &input_gradient,
   Matrix_t &di, Matrix_t &df, Matrix_t &dc, Matrix_t &dout, size_t t) -> Matrix_t &
{
   const DNN::EActivationFunction fAT = this->GetActivationFunctionF2();

   // derivative of tanh, evaluated at the stored cell state of time step t
   Matrix_t cell_gradient(this->GetCellTensorAt(t).GetNrows(), this->GetCellTensorAt(t).GetNcols());
   DNN::evaluateDerivativeMatrix<Architecture_t>(cell_gradient, fAT, this->GetCellTensorAt(t));

   // tanh of the stored cell state
   Matrix_t cell_tanh(this->GetCellTensorAt(t).GetNrows(), this->GetCellTensorAt(t).GetNcols());
   Architecture_t::Copy(cell_tanh, this->GetCellTensorAt(t));
   DNN::evaluateMatrix<Architecture_t>(cell_tanh, fAT);

   return Architecture_t::LSTMLayerBackward(
      state_gradients_backward, cell_gradients_backward, fWeightsInputGradients, fWeightsForgetGradients,
      fWeightsCandidateGradients, fWeightsOutputGradients, fWeightsInputStateGradients, fWeightsForgetStateGradients,
      fWeightsCandidateStateGradients, fWeightsOutputStateGradients, fInputBiasGradients, fForgetBiasGradients,
      fCandidateBiasGradients, fOutputBiasGradients, di, df, dc, dout, precStateActivations, precCellActivations,
      input_gate, forget_gate, candidate_gate, output_gate, fWeightsInputGate, fWeightsForgetGate, fWeightsCandidate,
      fWeightsOutputGate, fWeightsInputGateState, fWeightsForgetGateState, fWeightsCandidateState,
      fWeightsOutputGateState, input, input_gradient, cell_gradient, cell_tanh);
}
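// Scalar sketch of the gradients that the backend kernel
// Architecture_t::LSTMLayerBackward derives from cell_gradient and cell_tanh
// (illustration only; these are gradients w.r.t. the post-activation gate
// values, the stored derivatives di/df/dc/dout are applied inside the kernel):
#include <cmath>

inline void LSTMCellBackwardScalar(double i, double f, double g, double o, double cPrev, double c,
                                   double dh, double &dc, double &dInput, double &dForget,
                                   double &dCandidate, double &dOutput)
{
   const double tc = std::tanh(c);
   dOutput    = dh * tc;                  // from h = o * tanh(C)
   dc        += dh * o * (1.0 - tc * tc); // total gradient reaching the cell state
   dInput     = dc * g;                   // from C = f * C_prev + i * g
   dCandidate = dc * i;
   dForget    = dc * cPrev;
   dc        *= f;                        // gradient passed on to C_{t-1}
}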
template <typename Architecture_t>
inline void TBasicLSTMLayer<Architecture_t>::InitState(DNN::EInitialization /* m */)
{
   // the hidden state and the cell state are (re)initialized to zero
   DNN::initialize<Architecture_t>(this->GetState(), DNN::EInitialization::kZero);
   DNN::initialize<Architecture_t>(this->GetCell(), DNN::EInitialization::kZero);
}
template <typename Architecture_t>
void TBasicLSTMLayer<Architecture_t>::Print() const
{
   std::cout << " LSTM Layer: \t ";
   std::cout << " (NInput = " << this->GetInputSize();
   std::cout << ", NState = " << this->GetStateSize();
   std::cout << ", NTime = " << this->GetTimeSteps() << " )";
   std::cout << "\tOutput = ( " << this->GetOutput().GetFirstSize() << " , " << this->GetOutput()[0].GetNrows()
             << " , " << this->GetOutput()[0].GetNcols() << " )\n";
}
template <typename Architecture_t>
void TBasicLSTMLayer<Architecture_t>::AddWeightsXMLTo(void *parent)
{
   auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "LSTMLayer");

   // (writing of the layer attributes -- sizes and flags -- is elided here)

   // write the weight and bias matrices; they are serialized in storage order
   // (weights 0-7, biases 0-3). The element names are historical: they do not
   // all match the gate layout bound in the constructor, and "OuputWeights"
   // keeps its original spelling so that previously written files stay readable.
   this->WriteMatrixToXML(layerxml, "InputWeights", this->GetWeightsAt(0));
   this->WriteMatrixToXML(layerxml, "InputStateWeights", this->GetWeightsAt(1));
   this->WriteMatrixToXML(layerxml, "InputBiases", this->GetBiasesAt(0));
   this->WriteMatrixToXML(layerxml, "ForgetWeights", this->GetWeightsAt(2));
   this->WriteMatrixToXML(layerxml, "ForgetStateWeights", this->GetWeightsAt(3));
   this->WriteMatrixToXML(layerxml, "ForgetBiases", this->GetBiasesAt(1));
   this->WriteMatrixToXML(layerxml, "CandidateWeights", this->GetWeightsAt(4));
   this->WriteMatrixToXML(layerxml, "CandidateStateWeights", this->GetWeightsAt(5));
   this->WriteMatrixToXML(layerxml, "CandidateBiases", this->GetBiasesAt(2));
   this->WriteMatrixToXML(layerxml, "OuputWeights", this->GetWeightsAt(6));
   this->WriteMatrixToXML(layerxml, "OutputStateWeights", this->GetWeightsAt(7));
   this->WriteMatrixToXML(layerxml, "OutputBiases", this->GetBiasesAt(3));
}
template <typename Architecture_t>
void TBasicLSTMLayer<Architecture_t>::ReadWeightsFromXML(void *parent)
{
   // read the weight and bias matrices, mirroring the order and the element
   // names used in AddWeightsXMLTo above (including the historical "OuputWeights" key)
   this->ReadMatrixXML(parent, "InputWeights", this->GetWeightsAt(0));
   this->ReadMatrixXML(parent, "InputStateWeights", this->GetWeightsAt(1));
   this->ReadMatrixXML(parent, "InputBiases", this->GetBiasesAt(0));
   this->ReadMatrixXML(parent, "ForgetWeights", this->GetWeightsAt(2));
   this->ReadMatrixXML(parent, "ForgetStateWeights", this->GetWeightsAt(3));
   this->ReadMatrixXML(parent, "ForgetBiases", this->GetBiasesAt(1));
   this->ReadMatrixXML(parent, "CandidateWeights", this->GetWeightsAt(4));
   this->ReadMatrixXML(parent, "CandidateStateWeights", this->GetWeightsAt(5));
   this->ReadMatrixXML(parent, "CandidateBiases", this->GetBiasesAt(2));
   this->ReadMatrixXML(parent, "OuputWeights", this->GetWeightsAt(6));
   this->ReadMatrixXML(parent, "OutputStateWeights", this->GetWeightsAt(7));
   this->ReadMatrixXML(parent, "OutputBiases", this->GetBiasesAt(3));
}
#endif // TMVA_DNN_LSTM_LAYER