30#ifndef TMVA_DNN_GRU_LAYER
31#define TMVA_DNN_GRU_LAYER
55template<
typename Architecture_t>
61 using Matrix_t =
typename Architecture_t::Matrix_t;
62 using Scalar_t =
typename Architecture_t::Scalar_t;
63 using Tensor_t =
typename Architecture_t::Tensor_t;
307template <
typename Architecture_t>
319 fCandidateValue(batchSize,
stateSize), fState(batchSize,
stateSize), fWeightsResetGate(
this->GetWeightsAt(0)),
320 fWeightsResetGateState(
this->GetWeightsAt(3)), fResetGateBias(
this->GetBiasesAt(0)),
321 fWeightsUpdateGate(
this->GetWeightsAt(1)), fWeightsUpdateGateState(
this->GetWeightsAt(4)),
322 fUpdateGateBias(
this->GetBiasesAt(1)), fWeightsCandidate(
this->GetWeightsAt(2)),
323 fWeightsCandidateState(
this->GetWeightsAt(5)), fCandidateBias(
this->GetBiasesAt(2)),
324 fWeightsResetGradients(
this->GetWeightGradientsAt(0)), fWeightsResetStateGradients(
this->GetWeightGradientsAt(3)),
325 fResetBiasGradients(
this->GetBiasGradientsAt(0)), fWeightsUpdateGradients(
this->GetWeightGradientsAt(1)),
326 fWeightsUpdateStateGradients(
this->GetWeightGradientsAt(4)), fUpdateBiasGradients(
this->GetBiasGradientsAt(1)),
327 fWeightsCandidateGradients(
this->GetWeightGradientsAt(2)),
328 fWeightsCandidateStateGradients(
this->GetWeightGradientsAt(5)),
329 fCandidateBiasGradients(
this->GetBiasGradientsAt(2))
339 Architecture_t::InitializeGRUTensors(
this);
343template <
typename Architecture_t>
346 fStateSize(
layer.fStateSize),
347 fTimeSteps(
layer.fTimeSteps),
348 fRememberState(
layer.fRememberState),
349 fReturnSequence(
layer.fReturnSequence),
350 fResetGateAfter(
layer.fResetGateAfter),
351 fF1(
layer.GetActivationFunctionF1()),
352 fF2(
layer.GetActivationFunctionF2()),
353 fResetValue(
layer.GetBatchSize(),
layer.GetStateSize()),
354 fUpdateValue(
layer.GetBatchSize(),
layer.GetStateSize()),
355 fCandidateValue(
layer.GetBatchSize(),
layer.GetStateSize()),
356 fState(
layer.GetBatchSize(),
layer.GetStateSize()),
357 fWeightsResetGate(
this->GetWeightsAt(0)),
358 fWeightsResetGateState(
this->GetWeightsAt(3)),
359 fResetGateBias(
this->GetBiasesAt(0)),
360 fWeightsUpdateGate(
this->GetWeightsAt(1)),
361 fWeightsUpdateGateState(
this->GetWeightsAt(4)),
362 fUpdateGateBias(
this->GetBiasesAt(1)),
363 fWeightsCandidate(
this->GetWeightsAt(2)),
364 fWeightsCandidateState(
this->GetWeightsAt(5)),
365 fCandidateBias(
this->GetBiasesAt(2)),
366 fWeightsResetGradients(
this->GetWeightGradientsAt(0)),
367 fWeightsResetStateGradients(
this->GetWeightGradientsAt(3)),
368 fResetBiasGradients(
this->GetBiasGradientsAt(0)),
369 fWeightsUpdateGradients(
this->GetWeightGradientsAt(1)),
370 fWeightsUpdateStateGradients(
this->GetWeightGradientsAt(4)),
371 fUpdateBiasGradients(
this->GetBiasGradientsAt(1)),
372 fWeightsCandidateGradients(
this->GetWeightGradientsAt(2)),
373 fWeightsCandidateStateGradients(
this->GetWeightGradientsAt(5)),
374 fCandidateBiasGradients(
this->GetBiasGradientsAt(2))
404 Architecture_t::InitializeGRUTensors(
this);
408template <
typename Architecture_t>
413 Architecture_t::InitializeGRUDescriptors(fDescriptors,
this);
414 Architecture_t::InitializeGRUWorkspace(fWorkspace, fDescriptors,
this);
417 if (Architecture_t::IsCudnn())
418 fResetGateAfter =
true;
422template <
typename Architecture_t>
431 Architecture_t::MultiplyTranspose(
tmpState, fState, fWeightsResetGateState);
432 Architecture_t::MultiplyTranspose(fResetValue,
input, fWeightsResetGate);
433 Architecture_t::ScaleAdd(fResetValue,
tmpState);
434 Architecture_t::AddRowWise(fResetValue, fResetGateBias);
435 DNN::evaluateDerivativeMatrix<Architecture_t>(
dr,
fRst, fResetValue);
436 DNN::evaluateMatrix<Architecture_t>(fResetValue,
fRst);
440template <
typename Architecture_t>
449 Architecture_t::MultiplyTranspose(
tmpState, fState, fWeightsUpdateGateState);
450 Architecture_t::MultiplyTranspose(fUpdateValue,
input, fWeightsUpdateGate);
451 Architecture_t::ScaleAdd(fUpdateValue,
tmpState);
452 Architecture_t::AddRowWise(fUpdateValue, fUpdateGateBias);
453 DNN::evaluateDerivativeMatrix<Architecture_t>(
du,
fUpd, fUpdateValue);
454 DNN::evaluateMatrix<Architecture_t>(fUpdateValue,
fUpd);
458template <
typename Architecture_t>
475 Matrix_t tmp(fCandidateValue.GetNrows(), fCandidateValue.GetNcols());
476 if (!fResetGateAfter) {
478 Architecture_t::Hadamard(
tmpState, fState);
479 Architecture_t::MultiplyTranspose(tmp,
tmpState, fWeightsCandidateState);
482 Architecture_t::MultiplyTranspose(tmp, fState, fWeightsCandidateState);
483 Architecture_t::Hadamard(tmp, fResetValue);
485 Architecture_t::MultiplyTranspose(fCandidateValue,
input, fWeightsCandidate);
486 Architecture_t::ScaleAdd(fCandidateValue, tmp);
487 Architecture_t::AddRowWise(fCandidateValue, fCandidateBias);
488 DNN::evaluateDerivativeMatrix<Architecture_t>(
dc, fCan, fCandidateValue);
489 DNN::evaluateMatrix<Architecture_t>(fCandidateValue, fCan);
493template <
typename Architecture_t>
498 if (Architecture_t::IsCudnn()) {
505 Architecture_t::Rearrange(
x,
input);
508 const auto &weights = this->GetWeightsTensor();
510 auto &
hx = this->fState;
511 auto &
cx = this->fCell;
513 auto &
hy = this->fState;
514 auto &
cy = this->fCell;
521 if (fReturnSequence) {
522 Architecture_t::Rearrange(this->GetOutput(),
y);
525 Tensor_t tmp = (
y.At(
y.GetShape()[0] - 1)).Reshape({
y.GetShape()[1], 1,
y.GetShape()[2]});
526 Architecture_t::Copy(this->GetOutput(), tmp);
537 Tensor_t arrInput ( fTimeSteps, this->GetBatchSize(), this->GetInputWidth());
548 if (!this->fRememberState) {
554 for (
size_t t = 0; t < fTimeSteps; ++t) {
556 ResetGate(
arrInput[t], fDerivativesReset[t]);
557 Architecture_t::Copy(this->GetResetGateTensorAt(t), fResetValue);
558 UpdateGate(
arrInput[t], fDerivativesUpdate[t]);
559 Architecture_t::Copy(this->GetUpdateGateTensorAt(t), fUpdateValue);
561 CandidateValue(
arrInput[t], fDerivativesCandidate[t]);
562 Architecture_t::Copy(this->GetCandidateGateTensorAt(t), fCandidateValue);
565 CellForward(fUpdateValue, fCandidateValue);
574 Architecture_t::Rearrange(this->GetOutput(),
arrOutput);
580 tmp = tmp.Reshape({tmp.GetShape()[0], tmp.GetShape()[1], 1});
581 assert(tmp.GetSize() ==
this->GetOutput().GetSize());
582 assert(tmp.GetShape()[0] ==
this->GetOutput().GetShape()[2]);
583 Architecture_t::Rearrange(this->GetOutput(), tmp);
590template <
typename Architecture_t>
598 for (
size_t j = 0;
j < (size_t) tmp.GetNcols();
j++) {
599 for (
size_t i = 0; i < (size_t) tmp.GetNrows(); i++) {
600 tmp(i,
j) = 1 - tmp(i,
j);
610template <
typename Architecture_t>
616 if (Architecture_t::IsCudnn()) {
629 if (!fReturnSequence) {
632 Architecture_t::InitializeZero(
dy);
638 Architecture_t::Copy(
tmp2, this->GetActivationGradients());
640 Architecture_t::Rearrange(
y, this->GetOutput());
641 Architecture_t::Rearrange(
dy, this->GetActivationGradients());
647 const auto &weights = this->GetWeightsTensor();
655 auto &
hx = this->GetState();
656 auto &
cx = this->GetCell();
666 Architecture_t::RNNBackward(
x,
hx,
cx,
y,
dy,
dhy,
dcy, weights,
dx,
dhx,
dcx,
weightGradients,
rnnDesc,
rnnWork);
707 if (fReturnSequence) {
708 Architecture_t::Rearrange(
arr_output, this->GetOutput());
709 Architecture_t::Rearrange(
arr_actgradients, this->GetActivationGradients());
718 this->GetActivationGradients().GetShape()[2]);
720 Architecture_t::Rearrange(
tmp_grad, this->GetActivationGradients());
727 fWeightsResetGradients.Zero();
728 fWeightsResetStateGradients.Zero();
729 fResetBiasGradients.Zero();
732 fWeightsUpdateGradients.Zero();
733 fWeightsUpdateStateGradients.Zero();
734 fUpdateBiasGradients.Zero();
737 fWeightsCandidateGradients.Zero();
738 fWeightsCandidateStateGradients.Zero();
739 fCandidateBiasGradients.Zero();
742 for (
size_t t = fTimeSteps; t > 0; t--) {
750 this->GetResetGateTensorAt(t-1), this->GetUpdateGateTensorAt(t-1),
751 this->GetCandidateGateTensorAt(t-1),
753 fDerivativesReset[t-1], fDerivativesUpdate[t-1],
754 fDerivativesCandidate[t-1]);
759 this->GetResetGateTensorAt(t-1), this->GetUpdateGateTensorAt(t-1),
760 this->GetCandidateGateTensorAt(t-1),
762 fDerivativesReset[t-1], fDerivativesUpdate[t-1],
763 fDerivativesCandidate[t-1]);
775template <
typename Architecture_t>
787 fWeightsResetGradients, fWeightsUpdateGradients, fWeightsCandidateGradients,
788 fWeightsResetStateGradients, fWeightsUpdateStateGradients,
789 fWeightsCandidateStateGradients, fResetBiasGradients, fUpdateBiasGradients,
790 fCandidateBiasGradients,
dr,
du,
dc,
793 fWeightsResetGate, fWeightsUpdateGate, fWeightsCandidate,
794 fWeightsResetGateState, fWeightsUpdateGateState, fWeightsCandidateState,
800template <
typename Architecture_t>
808template<
typename Architecture_t>
812 std::cout <<
" GRU Layer: \t ";
813 std::cout <<
" (NInput = " << this->GetInputSize();
814 std::cout <<
", NState = " << this->GetStateSize();
815 std::cout <<
", NTime = " << this->GetTimeSteps() <<
" )";
816 std::cout <<
"\tOutput = ( " << this->GetOutput().GetFirstSize() <<
" , " << this->GetOutput()[0].GetNrows() <<
" , " << this->GetOutput()[0].GetNcols() <<
" )\n";
820template <
typename Architecture_t>
835 this->WriteMatrixToXML(
layerxml,
"ResetWeights", this->GetWeightsAt(0));
836 this->WriteMatrixToXML(
layerxml,
"ResetStateWeights", this->GetWeightsAt(1));
837 this->WriteMatrixToXML(
layerxml,
"ResetBiases", this->GetBiasesAt(0));
838 this->WriteMatrixToXML(
layerxml,
"UpdateWeights", this->GetWeightsAt(2));
839 this->WriteMatrixToXML(
layerxml,
"UpdateStateWeights", this->GetWeightsAt(3));
840 this->WriteMatrixToXML(
layerxml,
"UpdateBiases", this->GetBiasesAt(1));
841 this->WriteMatrixToXML(
layerxml,
"CandidateWeights", this->GetWeightsAt(4));
842 this->WriteMatrixToXML(
layerxml,
"CandidateStateWeights", this->GetWeightsAt(5));
843 this->WriteMatrixToXML(
layerxml,
"CandidateBiases", this->GetBiasesAt(2));
847template <
typename Architecture_t>
852 this->ReadMatrixXML(parent,
"ResetWeights", this->GetWeightsAt(0));
853 this->ReadMatrixXML(parent,
"ResetStateWeights", this->GetWeightsAt(1));
854 this->ReadMatrixXML(parent,
"ResetBiases", this->GetBiasesAt(0));
855 this->ReadMatrixXML(parent,
"UpdateWeights", this->GetWeightsAt(2));
856 this->ReadMatrixXML(parent,
"UpdateStateWeights", this->GetWeightsAt(3));
857 this->ReadMatrixXML(parent,
"UpdateBiases", this->GetBiasesAt(1));
858 this->ReadMatrixXML(parent,
"CandidateWeights", this->GetWeightsAt(4));
859 this->ReadMatrixXML(parent,
"CandidateStateWeights", this->GetWeightsAt(5));
860 this->ReadMatrixXML(parent,
"CandidateBiases", this->GetBiasesAt(2));
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
const Matrix_t & GetWeightsCandidate() const
Matrix_t & GetWeightsCandidateStateGradients()
typename Architecture_t::RecurrentDescriptor_t LayerDescriptor_t
Matrix_t & GetWeightsResetGate()
Matrix_t & fResetBiasGradients
Gradients w.r.t the reset gate - bias weights.
std::vector< Matrix_t > & GetUpdateGateTensor()
typename Architecture_t::Tensor_t Tensor_t
std::vector< Matrix_t > reset_gate_value
Reset gate value for every time step.
Matrix_t & CellBackward(Matrix_t &state_gradients_backward, const Matrix_t &precStateActivations, const Matrix_t &reset_gate, const Matrix_t &update_gate, const Matrix_t &candidate_gate, const Matrix_t &input, Matrix_t &input_gradient, Matrix_t &dr, Matrix_t &du, Matrix_t &dc)
Backward pass for a single time unit, corresponding to one call to Forward(...).
size_t fStateSize
Hidden state size for GRU.
const Matrix_t & GetWeightsResetGradients() const
const Matrix_t & GetUpdateBiasGradients() const
bool fReturnSequence
Return in output full sequence or just last element.
void Forward(Tensor_t &input, bool isTraining=true)
Computes the next hidden state and next cell state with given input matrix.
const Matrix_t & GetWeightsResetStateGradients() const
std::vector< Matrix_t > fDerivativesReset
First fDerivatives of the activations reset gate.
const Tensor_t & GetWeightsTensor() const
std::vector< Matrix_t > & GetResetGateTensor()
Matrix_t & GetWeightsUpdateGateState()
const std::vector< Matrix_t > & GetCandidateGateTensor() const
const Matrix_t & GetUpdateDerivativesAt(size_t i) const
Matrix_t & GetWeightsUpdateStateGradients()
size_t GetInputSize() const
Getters.
Matrix_t fState
Hidden state of GRU.
Matrix_t & GetWeightsResetGradients()
Tensor_t & GetWeightGradientsTensor()
const Matrix_t & GetCandidateBias() const
std::vector< Matrix_t > update_gate_value
Update gate value for every time step.
Tensor_t & GetWeightsTensor()
Tensor_t fX
cached input tensor as T x B x I
Matrix_t & GetCandidateGateTensorAt(size_t i)
Matrix_t & GetResetBiasGradients()
Matrix_t & GetCandidateValue()
Matrix_t & GetWeightsResetGateState()
DNN::EActivationFunction fF1
Activation function: sigmoid.
const Matrix_t & GetWeightsUpdateGate() const
const std::vector< Matrix_t > & GetDerivativesReset() const
const Matrix_t & GetUpdateGateBias() const
Matrix_t & fWeightsResetGradients
Gradients w.r.t the reset gate - input weights.
std::vector< Matrix_t > & GetDerivativesUpdate()
Matrix_t & fCandidateBiasGradients
Gradients w.r.t the candidate gate - bias weights.
Matrix_t & fCandidateBias
Candidate Gate bias.
Matrix_t & GetUpdateGateTensorAt(size_t i)
DNN::EActivationFunction fF2
Activation function: tanh.
const Matrix_t & GetWeightsUpdateGradients() const
Matrix_t & GetWeightsCandidateGradients()
Matrix_t & fWeightsUpdateStateGradients
Gradients w.r.t the update gate - hidden state weights.
void AddWeightsXMLTo(void *parent)
Writes the information and the weights about the layer in an XML node.
Matrix_t & GetWeightsCandidate()
Matrix_t & fWeightsUpdateGradients
Gradients w.r.t the update gate - input weights.
Matrix_t & GetUpdateGateValue()
size_t fTimeSteps
Timesteps for GRU.
std::vector< Matrix_t > fDerivativesCandidate
First fDerivatives of the activations candidate gate.
const Tensor_t & GetWeightGradientsTensor() const
typename Architecture_t::FilterDescriptor_t WeightsDescriptor_t
Tensor_t fWeightGradientsTensor
Tensor for all weight gradients.
Matrix_t & fUpdateBiasGradients
Gradients w.r.t the update gate - bias weights.
Matrix_t & GetWeightsResetStateGradients()
std::vector< Matrix_t > & GetCandidateGateTensor()
Matrix_t & fWeightsResetGate
Reset Gate weights for input, fWeights[0].
const Matrix_t & GetResetDerivativesAt(size_t i) const
Matrix_t & GetWeightsUpdateGate()
typename Architecture_t::Matrix_t Matrix_t
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward)
Backpropagates the error.
const Matrix_t & GetCandidateGateTensorAt(size_t i) const
Matrix_t & GetWeightsCandidateState()
const Matrix_t & GetCandidateBiasGradients() const
Matrix_t & GetResetGateTensorAt(size_t i)
Matrix_t & fResetGateBias
Reset Gate bias.
const std::vector< Matrix_t > & GetResetGateTensor() const
Matrix_t fCell
Empty matrix for GRU.
std::vector< Matrix_t > candidate_gate_value
Candidate gate value for every time step.
typename Architecture_t::Scalar_t Scalar_t
const Matrix_t & GetWeigthsUpdateStateGradients() const
const Matrix_t & GetCandidateValue() const
Matrix_t & GetCandidateBiasGradients()
Matrix_t & fWeightsCandidateStateGradients
Gradients w.r.t the candidate gate - hidden state weights.
const std::vector< Matrix_t > & GetDerivativesUpdate() const
const Matrix_t & GetCell() const
void UpdateGate(const Matrix_t &input, Matrix_t &df)
Computes the update gate values (NN with Sigmoid)
const Matrix_t & GetCandidateDerivativesAt(size_t i) const
Matrix_t fResetValue
Computed reset gate values.
DNN::EActivationFunction GetActivationFunctionF2() const
Matrix_t & GetResetGateBias()
typename Architecture_t::RNNWorkspace_t RNNWorkspace_t
Matrix_t fUpdateValue
Computed update gate values.
const Matrix_t & GetResetBiasGradients() const
bool fResetGateAfter
GRU variant to Apply the reset gate multiplication afterwards (used by cuDNN)
const Matrix_t & GetWeightsCandidateGradients() const
DNN::EActivationFunction GetActivationFunctionF1() const
bool DoesReturnSequence() const
Matrix_t & GetUpdateBiasGradients()
const Matrix_t & GetUpdateGateTensorAt(size_t i) const
Matrix_t & fWeightsResetGateState
Reset Gate weights for prev state, fWeights[3].
Matrix_t & fWeightsUpdateGateState
Update Gate weights for prev state, fWeights[4].
const std::vector< Matrix_t > & GetDerivativesCandidate() const
virtual void Initialize()
Initialize the weights according to the given initialization method.
Tensor_t fWeightsTensor
Tensor for all weights.
void ReadWeightsFromXML(void *parent)
Read the information and the weights about the layer from XML node.
typename Architecture_t::RNNDescriptors_t RNNDescriptors_t
const Matrix_t & GetResetGateBias() const
Matrix_t & GetResetDerivativesAt(size_t i)
const Matrix_t & GetUpdateGateValue() const
const Matrix_t & GetResetGateTensorAt(size_t i) const
TDescriptors * fDescriptors
Keeps all the RNN descriptors.
void CellForward(Matrix_t &updateGateValues, Matrix_t &candidateValues)
Forward for a single cell (time unit)
Matrix_t & GetWeightsUpdateGradients()
Matrix_t & fWeightsResetStateGradients
Gradients w.r.t the reset gate - hidden state weights.
Matrix_t & fWeightsCandidateState
Candidate Gate weights for prev state, fWeights[5].
void Print() const
Prints the info about the layer.
size_t GetStateSize() const
std::vector< Matrix_t > & GetDerivativesReset()
Matrix_t & fUpdateGateBias
Update Gate bias.
const Matrix_t & GetWeightsCandidateStateGradients() const
void ResetGate(const Matrix_t &input, Matrix_t &di)
Computes the reset gate values (NN with Sigmoid)
const Matrix_t & GetWeightsResetGate() const
Tensor_t fDx
cached gradient on the input (output of backward) as T x B x I
Matrix_t & GetCandidateBias()
typename Architecture_t::TensorDescriptor_t TensorDescriptor_t
bool fRememberState
Remember state in next pass.
Matrix_t & fWeightsCandidate
Candidate Gate weights for input, fWeights[2].
Matrix_t & fWeightsCandidateGradients
Gradients w.r.t the candidate gate - input weights.
const Matrix_t & GetWeightsCandidateState() const
const std::vector< Matrix_t > & GetUpdateGateTensor() const
const Matrix_t & GetResetGateValue() const
void Update(const Scalar_t learningRate)
Tensor_t fY
cached output tensor as T x B x S
Matrix_t fCandidateValue
Computed candidate values.
const Matrix_t & GetState() const
Matrix_t & GetUpdateGateBias()
void InitState(DNN::EInitialization m=DNN::EInitialization::kZero)
Initialize the hidden-state and cell-state matrices.
Tensor_t fDy
cached activation gradient (input of backward) as T x B x S
Matrix_t & GetCandidateDerivativesAt(size_t i)
std::vector< Matrix_t > fDerivativesUpdate
First fDerivatives of the activations update gate.
size_t GetTimeSteps() const
const Matrix_t & GetWeightsUpdateGateState() const
std::vector< Matrix_t > & GetDerivativesCandidate()
bool DoesRememberState() const
const Matrix_t & GetWeightsResetGateState() const
void CandidateValue(const Matrix_t &input, Matrix_t &dc)
Decides the new candidate values (NN with Tanh)
TBasicGRULayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps, bool rememberState=false, bool returnSequence=false, bool resetGateAfter=false, DNN::EActivationFunction f1=DNN::EActivationFunction::kSigmoid, DNN::EActivationFunction f2=DNN::EActivationFunction::kTanh, bool training=true, DNN::EInitialization fA=DNN::EInitialization::kZero)
Constructor.
Matrix_t & GetUpdateDerivativesAt(size_t i)
Matrix_t & fWeightsUpdateGate
Update Gate weights for input, fWeights[1].
typename Architecture_t::DropoutDescriptor_t HelperDescriptor_t
Matrix_t & GetResetGateValue()
Generic General Layer class.
virtual void Initialize()
Initialize the weights and biases according to the given initialization method.
size_t GetInputWidth() const
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
EActivationFunction
Enum that represents layer activation functions.
create variable transformations