#ifndef TMVA_DNN_GRU_LAYER
#define TMVA_DNN_GRU_LAYER
55template<
typename Architecture_t>
61 using Matrix_t =
typename Architecture_t::Matrix_t;
62 using Scalar_t =
typename Architecture_t::Scalar_t;
63 using Tensor_t =
typename Architecture_t::Tensor_t;
140 TBasicGRULayer(
size_t batchSize,
size_t stateSize,
size_t inputSize,
141 size_t timeSteps,
bool rememberState =
false,
bool returnSequence =
false,
142 bool resetGateAfter =
false,
167 const Tensor_t &activations_backward)
override;
175 const Matrix_t & precStateActivations,
191 void Print()
const override;
307template <
typename Architecture_t>
312 :
VGeneralLayer<Architecture_t>(batchSize, 1, timeSteps, inputSize, 1, (returnSequence) ? timeSteps : 1, stateSize,
313 6, {stateSize, stateSize, stateSize, stateSize, stateSize, stateSize},
314 {inputSize, inputSize, inputSize, stateSize, stateSize, stateSize}, 3,
315 {stateSize, stateSize, stateSize}, {1, 1, 1}, batchSize,
316 (returnSequence) ? timeSteps : 1, stateSize, fA),
317 fStateSize(stateSize), fTimeSteps(timeSteps), fRememberState(rememberState), fReturnSequence(returnSequence), fResetGateAfter(resetGateAfter),
318 fF1(
f1), fF2(f2), fResetValue(batchSize, stateSize), fUpdateValue(batchSize, stateSize),
319 fCandidateValue(batchSize, stateSize), fState(batchSize, stateSize), fWeightsResetGate(this->GetWeightsAt(0)),
320 fWeightsResetGateState(this->GetWeightsAt(3)), fResetGateBias(this->GetBiasesAt(0)),
321 fWeightsUpdateGate(this->GetWeightsAt(1)), fWeightsUpdateGateState(this->GetWeightsAt(4)),
322 fUpdateGateBias(this->GetBiasesAt(1)), fWeightsCandidate(this->GetWeightsAt(2)),
323 fWeightsCandidateState(this->GetWeightsAt(5)), fCandidateBias(this->GetBiasesAt(2)),
324 fWeightsResetGradients(this->GetWeightGradientsAt(0)), fWeightsResetStateGradients(this->GetWeightGradientsAt(3)),
325 fResetBiasGradients(this->GetBiasGradientsAt(0)), fWeightsUpdateGradients(this->GetWeightGradientsAt(1)),
326 fWeightsUpdateStateGradients(this->GetWeightGradientsAt(4)), fUpdateBiasGradients(this->GetBiasGradientsAt(1)),
327 fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
328 fWeightsCandidateStateGradients(this->GetWeightGradientsAt(5)),
329 fCandidateBiasGradients(this->GetBiasGradientsAt(2))
331 for (
size_t i = 0; i < timeSteps; ++i) {
332 fDerivativesReset.emplace_back(batchSize, stateSize);
333 fDerivativesUpdate.emplace_back(batchSize, stateSize);
334 fDerivativesCandidate.emplace_back(batchSize, stateSize);
335 reset_gate_value.emplace_back(batchSize, stateSize);
336 update_gate_value.emplace_back(batchSize, stateSize);
337 candidate_gate_value.emplace_back(batchSize, stateSize);
339 Architecture_t::InitializeGRUTensors(
this);
343template <
typename Architecture_t>
404 Architecture_t::InitializeGRUTensors(
this);
408template <
typename Architecture_t>
413 Architecture_t::InitializeGRUDescriptors(
fDescriptors,
this);
417 if (Architecture_t::IsCudnn())
422template <
typename Architecture_t>
440template <
typename Architecture_t>
458template <
typename Architecture_t>
478 Architecture_t::Hadamard(tmpState,
fState);
493template <
typename Architecture_t>
498 if (Architecture_t::IsCudnn()) {
501 assert(input.GetStrides()[1] == this->GetInputSize());
505 Architecture_t::Rearrange(
x, input);
511 auto &cx = this->
fCell;
514 auto &cy = this->
fCell;
519 Architecture_t::RNNForward(
x, hx, cx, weights,
y, hy, cy, rnnDesc, rnnWork, isTraining);
522 Architecture_t::Rearrange(this->
GetOutput(),
y);
525 Tensor_t tmp = (
y.At(
y.GetShape()[0] - 1)).Reshape({
y.GetShape()[1], 1,
y.GetShape()[2]});
526 Architecture_t::Copy(this->
GetOutput(), tmp);
541 Architecture_t::Rearrange(arrInput, input);
569 Matrix_t arrOutputMt = arrOutput[t];
570 Architecture_t::Copy(arrOutputMt,
fState);
574 Architecture_t::Rearrange(this->
GetOutput(), arrOutput);
580 tmp = tmp.Reshape({tmp.GetShape()[0], tmp.GetShape()[1], 1});
581 assert(tmp.GetSize() == this->GetOutput().GetSize());
582 assert(tmp.GetShape()[0] == this->GetOutput().GetShape()[2]);
583 Architecture_t::Rearrange(this->
GetOutput(), tmp);
590template <
typename Architecture_t>
594 Architecture_t::Hadamard(
fState, updateGateValues);
598 for (
size_t j = 0; j < (size_t) tmp.GetNcols(); j++) {
599 for (
size_t i = 0; i < (size_t) tmp.GetNrows(); i++) {
600 tmp(i,j) = 1 - tmp(i,j);
605 Architecture_t::Hadamard(candidateValues, tmp);
606 Architecture_t::ScaleAdd(
fState, candidateValues);
610template <
typename Architecture_t>
612 const Tensor_t &activations_backward)
616 if (Architecture_t::IsCudnn()) {
624 assert(activations_backward.GetStrides()[1] == this->GetInputSize());
627 Architecture_t::Rearrange(
x, activations_backward);
632 Architecture_t::InitializeZero(dy);
635 Tensor_t tmp2 = dy.At(dy.GetShape()[0] - 1).Reshape({dy.GetShape()[1], 1, dy.GetShape()[2]});
640 Architecture_t::Rearrange(
y, this->
GetOutput());
652 Architecture_t::InitializeZero(weightGradients);
666 Architecture_t::RNNBackward(
x, hx, cx,
y, dy, dhy, dcy, weights, dx, dhx, dcx, weightGradients, rnnDesc, rnnWork);
670 if (gradients_backward.GetSize() != 0)
671 Architecture_t::Rearrange(gradients_backward, dx);
684 if (gradients_backward.GetSize() == 0 || gradients_backward[0].GetNrows() == 0 || gradients_backward[0].GetNcols() == 0) {
695 Architecture_t::Rearrange(arr_activations_backward, activations_backward);
708 Architecture_t::Rearrange(arr_output, this->
GetOutput());
713 Architecture_t::InitializeZero(arr_actgradients);
716 assert(tmp_grad.GetSize() == this->GetActivationGradients().GetSize());
717 assert(tmp_grad.GetShape()[0] ==
718 this->GetActivationGradients().GetShape()[2]);
744 Architecture_t::ScaleAdd(state_gradients_backward, arr_actgradients[t-1]);
746 const Matrix_t &prevStateActivations = arr_output[t-2];
747 Matrix_t dx = arr_gradients_backward[t-1];
749 CellBackward(state_gradients_backward, prevStateActivations,
752 arr_activations_backward[t-1], dx ,
756 const Matrix_t &prevStateActivations = initState;
757 Matrix_t dx = arr_gradients_backward[t-1];
758 CellBackward(state_gradients_backward, prevStateActivations,
761 arr_activations_backward[t-1], dx ,
768 Architecture_t::Rearrange(gradients_backward, arr_gradients_backward );
775template <
typename Architecture_t>
777 const Matrix_t & precStateActivations,
786 return Architecture_t::GRULayerBackward(state_gradients_backward,
791 precStateActivations,
792 reset_gate, update_gate, candidate_gate,
800template <
typename Architecture_t>
808template<
typename Architecture_t>
812 std::cout <<
" GRU Layer: \t ";
815 std::cout <<
", NTime = " << this->
GetTimeSteps() <<
" )";
816 std::cout <<
"\tOutput = ( " << this->
GetOutput().GetFirstSize() <<
" , " << this->
GetOutput()[0].GetNrows() <<
" , " << this->
GetOutput()[0].GetNcols() <<
" )\n";
820template <
typename Architecture_t>
847template <
typename Architecture_t>
const Matrix_t & GetWeightsCandidate() const
Matrix_t & GetWeightsCandidateStateGradients()
typename Architecture_t::RecurrentDescriptor_t LayerDescriptor_t
void Forward(Tensor_t &input, bool isTraining=true) override
Computes the next hidden state with the given input matrix (a GRU has no separate cell state).
Matrix_t & GetWeightsResetGate()
Matrix_t & fResetBiasGradients
Gradients w.r.t the reset gate - bias weights.
std::vector< Matrix_t > & GetUpdateGateTensor()
typename Architecture_t::Tensor_t Tensor_t
std::vector< Matrix_t > reset_gate_value
Reset gate value for every time step.
Matrix_t & CellBackward(Matrix_t &state_gradients_backward, const Matrix_t &precStateActivations, const Matrix_t &reset_gate, const Matrix_t &update_gate, const Matrix_t &candidate_gate, const Matrix_t &input, Matrix_t &input_gradient, Matrix_t &dr, Matrix_t &du, Matrix_t &dc)
Backward for a single time unit at the corresponding call to Forward(...).
size_t fStateSize
Hidden state size for GRU.
const Matrix_t & GetWeightsResetGradients() const
const Matrix_t & GetUpdateBiasGradients() const
bool fReturnSequence
Return in output full sequence or just last element.
const Matrix_t & GetWeightsResetStateGradients() const
std::vector< Matrix_t > fDerivativesReset
First fDerivatives of the activations reset gate.
const Tensor_t & GetWeightsTensor() const
std::vector< Matrix_t > & GetResetGateTensor()
Matrix_t & GetWeightsUpdateGateState()
const std::vector< Matrix_t > & GetCandidateGateTensor() const
const Matrix_t & GetUpdateDerivativesAt(size_t i) const
Matrix_t & GetWeightsUpdateStateGradients()
void Print() const override
Prints the info about the layer.
size_t GetInputSize() const
Getters.
Matrix_t fState
Hidden state of GRU.
Matrix_t & GetWeightsResetGradients()
Tensor_t & GetWeightGradientsTensor()
const Matrix_t & GetCandidateBias() const
std::vector< Matrix_t > update_gate_value
Update gate value for every time step.
Tensor_t & GetWeightsTensor()
Tensor_t fX
cached input tensor as T x B x I
Matrix_t & GetCandidateGateTensorAt(size_t i)
Matrix_t & GetResetBiasGradients()
Matrix_t & GetCandidateValue()
void AddWeightsXMLTo(void *parent) override
Writes the information and the weights about the layer in an XML node.
Matrix_t & GetWeightsResetGateState()
DNN::EActivationFunction fF1
Activation function: sigmoid.
const Matrix_t & GetWeightsUpdateGate() const
const std::vector< Matrix_t > & GetDerivativesReset() const
const Matrix_t & GetUpdateGateBias() const
Matrix_t & fWeightsResetGradients
Gradients w.r.t the reset gate - input weights.
std::vector< Matrix_t > & GetDerivativesUpdate()
Matrix_t & fCandidateBiasGradients
Gradients w.r.t the candidate gate - bias weights.
Matrix_t & fCandidateBias
Candidate Gate bias.
Matrix_t & GetUpdateGateTensorAt(size_t i)
DNN::EActivationFunction fF2
Activation function: tanh.
const Matrix_t & GetWeightsUpdateGradients() const
Matrix_t & GetWeightsCandidateGradients()
Matrix_t & fWeightsUpdateStateGradients
Gradients w.r.t the update gate - hidden state weights.
Matrix_t & GetWeightsCandidate()
Matrix_t & fWeightsUpdateGradients
Gradients w.r.t the update gate - input weights.
Matrix_t & GetUpdateGateValue()
size_t fTimeSteps
Timesteps for GRU.
std::vector< Matrix_t > fDerivativesCandidate
First fDerivatives of the activations candidate gate.
const Tensor_t & GetWeightGradientsTensor() const
typename Architecture_t::FilterDescriptor_t WeightsDescriptor_t
Tensor_t fWeightGradientsTensor
Tensor for all weight gradients.
Matrix_t & fUpdateBiasGradients
Gradients w.r.t the update gate - bias weights.
Matrix_t & GetWeightsResetStateGradients()
std::vector< Matrix_t > & GetCandidateGateTensor()
Matrix_t & fWeightsResetGate
Reset Gate weights for input, fWeights[0].
const Matrix_t & GetResetDerivativesAt(size_t i) const
Matrix_t & GetWeightsUpdateGate()
typename Architecture_t::Matrix_t Matrix_t
const Matrix_t & GetCandidateGateTensorAt(size_t i) const
Matrix_t & GetWeightsCandidateState()
const Matrix_t & GetCandidateBiasGradients() const
Matrix_t & GetResetGateTensorAt(size_t i)
Matrix_t & fResetGateBias
Reset Gate bias.
const std::vector< Matrix_t > & GetResetGateTensor() const
Matrix_t fCell
Empty matrix for GRU.
std::vector< Matrix_t > candidate_gate_value
Candidate gate value for every time step.
typename Architecture_t::Scalar_t Scalar_t
const Matrix_t & GetWeigthsUpdateStateGradients() const
const Matrix_t & GetCandidateValue() const
Matrix_t & GetCandidateBiasGradients()
Matrix_t & fWeightsCandidateStateGradients
Gradients w.r.t the candidate gate - hidden state weights.
const std::vector< Matrix_t > & GetDerivativesUpdate() const
const Matrix_t & GetCell() const
void Initialize() override
Initialize the weights according to the given initialization method.
void UpdateGate(const Matrix_t &input, Matrix_t &df)
Computes the update gate values (NN with Sigmoid).
const Matrix_t & GetCandidateDerivativesAt(size_t i) const
Matrix_t fResetValue
Computed reset gate values.
DNN::EActivationFunction GetActivationFunctionF2() const
Matrix_t & GetResetGateBias()
typename Architecture_t::RNNWorkspace_t RNNWorkspace_t
Matrix_t fUpdateValue
Computed update gate values.
const Matrix_t & GetResetBiasGradients() const
bool fResetGateAfter
GRU variant to Apply the reset gate multiplication afterwards (used by cuDNN).
const Matrix_t & GetWeightsCandidateGradients() const
DNN::EActivationFunction GetActivationFunctionF1() const
bool DoesReturnSequence() const
Matrix_t & GetUpdateBiasGradients()
const Matrix_t & GetUpdateGateTensorAt(size_t i) const
Matrix_t & fWeightsResetGateState
Reset Gate weights for prev state, fWeights[3].
Matrix_t & fWeightsUpdateGateState
Update Gate weights for prev state, fWeights[4].
const std::vector< Matrix_t > & GetDerivativesCandidate() const
Tensor_t fWeightsTensor
Tensor for all weights.
typename Architecture_t::RNNDescriptors_t RNNDescriptors_t
const Matrix_t & GetResetGateBias() const
Matrix_t & GetResetDerivativesAt(size_t i)
const Matrix_t & GetUpdateGateValue() const
const Matrix_t & GetResetGateTensorAt(size_t i) const
TDescriptors * fDescriptors
Keeps all the RNN descriptors.
void CellForward(Matrix_t &updateGateValues, Matrix_t &candidateValues)
Forward for a single cell (time unit).
Matrix_t & GetWeightsUpdateGradients()
Matrix_t & fWeightsResetStateGradients
Gradients w.r.t the reset gate - hidden state weights.
Matrix_t & fWeightsCandidateState
Candidate Gate weights for prev state, fWeights[5].
size_t GetStateSize() const
std::vector< Matrix_t > & GetDerivativesReset()
Matrix_t & fUpdateGateBias
Update Gate bias.
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward) override
Backpropagates the error.
const Matrix_t & GetWeightsCandidateStateGradients() const
void ResetGate(const Matrix_t &input, Matrix_t &di)
Computes the reset gate values (NN with Sigmoid).
const Matrix_t & GetWeightsResetGate() const
Tensor_t fDx
cached gradient on the input (output of backward) as T x B x I
Matrix_t & GetCandidateBias()
typename Architecture_t::TensorDescriptor_t TensorDescriptor_t
bool fRememberState
Remember state in next pass.
Matrix_t & fWeightsCandidate
Candidate Gate weights for input, fWeights[2].
Matrix_t & fWeightsCandidateGradients
Gradients w.r.t the candidate gate - input weights.
const Matrix_t & GetWeightsCandidateState() const
void ReadWeightsFromXML(void *parent) override
Read the information and the weights about the layer from XML node.
const std::vector< Matrix_t > & GetUpdateGateTensor() const
const Matrix_t & GetResetGateValue() const
void Update(const Scalar_t learningRate)
Tensor_t fY
cached output tensor as T x B x S
Matrix_t fCandidateValue
Computed candidate values.
const Matrix_t & GetState() const
Matrix_t & GetUpdateGateBias()
void InitState(DNN::EInitialization m=DNN::EInitialization::kZero)
Initialize the hidden state with the given initialization method.
Tensor_t fDy
cached activation gradient (input of backward) as T x B x S
Matrix_t & GetCandidateDerivativesAt(size_t i)
std::vector< Matrix_t > fDerivativesUpdate
First fDerivatives of the activations update gate.
size_t GetTimeSteps() const
const Matrix_t & GetWeightsUpdateGateState() const
std::vector< Matrix_t > & GetDerivativesCandidate()
bool DoesRememberState() const
const Matrix_t & GetWeightsResetGateState() const
void CandidateValue(const Matrix_t &input, Matrix_t &dc)
Decides the new candidate values (NN with Tanh).
TBasicGRULayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps, bool rememberState=false, bool returnSequence=false, bool resetGateAfter=false, DNN::EActivationFunction f1=DNN::EActivationFunction::kSigmoid, DNN::EActivationFunction f2=DNN::EActivationFunction::kTanh, bool training=true, DNN::EInitialization fA=DNN::EInitialization::kZero)
Constructor.
Matrix_t & GetUpdateDerivativesAt(size_t i)
Matrix_t & fWeightsUpdateGate
Update Gate weights for input, fWeights[1].
typename Architecture_t::DropoutDescriptor_t HelperDescriptor_t
Matrix_t & GetResetGateValue()
const Matrix_t & GetWeightsAt(size_t i) const
virtual void Initialize()
Initialize the weights and biases according to the given initialization method.
const Tensor_t & GetOutput() const
void WriteMatrixToXML(void *node, const char *name, const Matrix_t &matrix)
const Tensor_t & GetActivationGradients() const
const Matrix_t & GetBiasesAt(size_t i) const
const Matrix_t & GetBiasGradientsAt(size_t i) const
size_t GetBatchSize() const
Getters.
void ReadMatrixXML(void *node, const char *name, Matrix_t &matrix)
const Matrix_t & GetWeightGradientsAt(size_t i) const
VGeneralLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth, size_t Height, size_t Width, size_t WeightsNSlices, size_t WeightsNRows, size_t WeightsNCols, size_t BiasesNSlices, size_t BiasesNRows, size_t BiasesNCols, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols, EInitialization Init)
Constructor.
size_t GetInputWidth() const
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
void evaluateDerivativeMatrix(typename Architecture_t::Matrix_t &B, EActivationFunction f, const typename Architecture_t::Matrix_t &A)
void evaluateMatrix(typename Architecture_t::Matrix_t &A, EActivationFunction f)
EActivationFunction
Enum that represents layer activation functions.
void initialize(typename Architecture_t::Matrix_t &A, EInitialization m)
create variable transformations