#ifndef TMVA_DNN_RNN_LAYER
#define TMVA_DNN_RNN_LAYER

#include <cmath>
#include <iostream>
#include <string>
#include <vector>
#include <cassert>

#include "TMVA/DNN/Functions.h"
#include "TMVA/DNN/GeneralLayer.h"

namespace TMVA {
namespace DNN {
namespace RNN {
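/*! \class TBasicRNNLayer
    \brief Vanilla (Elman-type) recurrent layer.

    At each time step t the hidden state (of size \c stateSize) is updated from
    the current input (of size \c inputSize) as

        state_t = f( input_t . W_input^T + state_{t-1} . W_state^T + bias )

    where f is the activation function, kTanh by default (see CellForward).
    With \c returnSequence the layer outputs the state at every time step,
    otherwise only the state of the last one; with \c rememberState the final
    state is kept as the initial state of the next forward pass.
*/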
template <typename Architecture_t>
class TBasicRNNLayer : public VGeneralLayer<Architecture_t> {

public:
   using Tensor_t = typename Architecture_t::Tensor_t;
   using Matrix_t = typename Architecture_t::Matrix_t;
   using Scalar_t = typename Architecture_t::Scalar_t;

   using LayerDescriptor_t   = typename Architecture_t::RecurrentDescriptor_t;
   using WeightsDescriptor_t = typename Architecture_t::FilterDescriptor_t;
   using TensorDescriptor_t  = typename Architecture_t::TensorDescriptor_t;
   using HelperDescriptor_t  = typename Architecture_t::DropoutDescriptor_t;

   using RNNWorkspace_t   = typename Architecture_t::RNNWorkspace_t;
   using RNNDescriptors_t = typename Architecture_t::RNNDescriptors_t;
   // ... (data members elided in this excerpt) ...

   /*! Constructor */
   TBasicRNNLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps,
                  bool rememberState = false, bool returnSequence = false,
                  DNN::EActivationFunction f = DNN::EActivationFunction::kTanh,
                  bool training = true, DNN::EInitialization fA = DNN::EInitialization::kZero);

   // ...

   /*! Backpropagates the error. */
   void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);

   /*! Backward for a single time unit at the corresponding call to Forward(...). */
   Matrix_t &CellBackward(Matrix_t &state_gradients_backward, const Matrix_t &precStateActivations,
                          const Matrix_t &input, Matrix_t &input_gradient, Matrix_t &dF);

   // ...
};
template <typename Architecture_t>
TBasicRNNLayer<Architecture_t>::TBasicRNNLayer(size_t batchSize, size_t stateSize, size_t inputSize,
                                               size_t timeSteps, bool rememberState, bool returnSequence,
                                               DNN::EActivationFunction f, bool /*training*/,
                                               DNN::EInitialization fA)
   : VGeneralLayer<Architecture_t>(batchSize, 1, timeSteps, inputSize, 1, (returnSequence) ? timeSteps : 1,
                                   stateSize, 2, {stateSize, stateSize}, {inputSize, stateSize}, 1, {stateSize},
                                   {1}, batchSize, (returnSequence) ? timeSteps : 1, stateSize, fA),
     fTimeSteps(timeSteps), fStateSize(stateSize), fRememberState(rememberState), fReturnSequence(returnSequence),
     fF(f), fState(batchSize, stateSize),
     fWeightsInput(this->GetWeightsAt(0)), fWeightsState(this->GetWeightsAt(1)),
     fBiases(this->GetBiasesAt(0)), fDerivatives(timeSteps, batchSize, stateSize),
     fWeightInputGradients(this->GetWeightGradientsAt(0)), fWeightStateGradients(this->GetWeightGradientsAt(1)),
     fBiasGradients(this->GetBiasGradientsAt(0)), fWeightsTensor({0}), fWeightGradientsTensor({0})
{
}
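// Illustrative construction sketch (not from the original source). The
// architecture backend TReference<double> is an assumption here; any
// Architecture_t that provides the required Matrix_t/Tensor_t types and
// kernels can be used in the same way.
//
// \code
// #include "TMVA/DNN/Architectures/Reference.h"
// using Arch_t = TMVA::DNN::TReference<double>;
// // batch of 10 samples, state size 8, input size 4, 5 time steps,
// // returning the full sequence of hidden states
// TMVA::DNN::RNN::TBasicRNNLayer<Arch_t> rnn(10, 8, 4, 5,
//                                            /*rememberState=*/false,
//                                            /*returnSequence=*/true);
// rnn.Initialize(); // create descriptors and workspace before the first Forward
// \endcode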
template <typename Architecture_t>
TBasicRNNLayer<Architecture_t>::TBasicRNNLayer(const TBasicRNNLayer &layer)
   : VGeneralLayer<Architecture_t>(layer), fTimeSteps(layer.fTimeSteps), fStateSize(layer.fStateSize),
     fRememberState(layer.fRememberState), fReturnSequence(layer.fReturnSequence),
     fF(layer.GetActivationFunction()), fState(layer.GetBatchSize(), layer.GetStateSize()),
     fWeightsInput(this->GetWeightsAt(0)), fWeightsState(this->GetWeightsAt(1)), fBiases(this->GetBiasesAt(0)),
     fDerivatives(layer.GetDerivatives().GetShape()), fWeightInputGradients(this->GetWeightGradientsAt(0)),
     fWeightStateGradients(this->GetWeightGradientsAt(1)), fBiasGradients(this->GetBiasGradientsAt(0)),
     fWeightsTensor({0}), fWeightGradientsTensor({0})
{
}
template <typename Architecture_t>
TBasicRNNLayer<Architecture_t>::~TBasicRNNLayer()
{
   if (fDescriptors) {
      Architecture_t::ReleaseRNNDescriptors(fDescriptors);
      delete fDescriptors;
   }
   if (fWorkspace) {
      Architecture_t::FreeRNNWorkspace(fWorkspace);
      delete fWorkspace;
   }
}
template <typename Architecture_t>
void TBasicRNNLayer<Architecture_t>::Initialize()
{
   // initialize the weights and biases in the base class, then create the
   // architecture-specific RNN descriptors and workspace
   VGeneralLayer<Architecture_t>::Initialize();

   Architecture_t::InitializeRNNDescriptors(fDescriptors, this);
   Architecture_t::InitializeRNNWorkspace(fWorkspace, fDescriptors, this);
}
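// Note: Initialize() is expected to be called once after construction and
// before the first call to Forward(...), as in the construction sketch above;
// the descriptors and workspace created here are the ones released in the
// destructor.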
template <typename Architecture_t>
void TBasicRNNLayer<Architecture_t>::InitTensors()
{
   Architecture_t::InitializeRNNTensors(this);
}
template <typename Architecture_t>
auto TBasicRNNLayer<Architecture_t>::InitState(DNN::EInitialization /*m*/) -> void
{
   DNN::initialize<Architecture_t>(this->GetState(), DNN::EInitialization::kZero);
   Architecture_t::InitializeActivationDescriptor(fActivationDesc, this->GetActivationFunction());
}
template <typename Architecture_t>
auto TBasicRNNLayer<Architecture_t>::Print() const -> void
{
   std::cout << " RECURRENT Layer: \t ";
   std::cout << " (NInput = " << this->GetInputSize();
   std::cout << ", NState = " << this->GetStateSize();
   std::cout << ", NTime = " << this->GetTimeSteps() << " )";
   std::cout << "\tOutput = ( " << this->GetOutput().GetFirstSize() << " , " << this->GetOutput().GetHSize() << " , "
             << this->GetOutput().GetWSize() << " )\n";
}
template <typename Architecture_t>
auto debugMatrix(const typename Architecture_t::Matrix_t &A, const std::string name = "matrix") -> void
{
   std::cout << name << "\n";
   for (size_t i = 0; i < A.GetNrows(); ++i) {
      for (size_t j = 0; j < A.GetNcols(); ++j) {
         std::cout << A(i, j) << " ";
      }
      std::cout << "\n";
   }
   std::cout << "********\n";
}
template <typename Architecture_t>
auto inline TBasicRNNLayer<Architecture_t>::Forward(Tensor_t &input, bool isTraining) -> void
{
   // cuDNN implementation: the backend runs the full time loop internally
   if (Architecture_t::IsCudnn()) {

      // the input tensor is cached as B x T x inputSize
      Tensor_t &x = this->fX;
      Tensor_t &y = this->fY;
      Architecture_t::Rearrange(x, input);

      const auto &weights = this->GetWeightsAt(0);

      // initial hidden and cell state (the cell is an empty matrix for a plain RNN)
      auto &hx = this->GetState();
      auto &cx = this->GetCell();
      // use the same tensors for the final state
      auto &hy = this->GetState();
      auto &cy = this->GetCell();

      auto &rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
      auto &rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);

      Architecture_t::RNNForward(x, hx, cx, weights, y, hy, cy, rnnDesc, rnnWork, isTraining);

      if (fReturnSequence) {
         Architecture_t::Rearrange(this->GetOutput(), y); // swap B and T from y to the output
      } else {
         // take only the last time step of y
         Tensor_t tmp = (y.At(y.GetShape()[0] - 1)).Reshape({y.GetShape()[1], 1, y.GetShape()[2]});
         Architecture_t::Copy(this->GetOutput(), tmp);
      }
      return;
   }

   // standard implementation: loop explicitly over the time steps
   Tensor_t arrInput(fTimeSteps, this->GetBatchSize(), this->GetInputWidth());
   Architecture_t::Rearrange(arrInput, input); // B x T x D --> T x B x D
   Tensor_t arrOutput(fTimeSteps, this->GetBatchSize(), fStateSize);

   if (!fRememberState)
      InitState(DNN::EInitialization::kZero);

   for (size_t t = 0; t < fTimeSteps; ++t) {
      Matrix_t arrInput_m = arrInput.At(t).GetMatrix();
      Matrix_t df_m = fDerivatives.At(t).GetMatrix();
      CellForward(arrInput_m, df_m);
      Matrix_t arrOutput_m = arrOutput.At(t).GetMatrix();
      Architecture_t::Copy(arrOutput_m, fState);
   }

   if (fReturnSequence) {
      Architecture_t::Rearrange(this->GetOutput(), arrOutput); // T x B x D --> B x T x D
   } else {
      // return only the last time step: reshape the B x D slice to B x D x 1,
      // then rearrange into the output layout (B is the last output dimension)
      Tensor_t tmp = arrOutput.At(fTimeSteps - 1);
      tmp = tmp.Reshape({tmp.GetShape()[0], tmp.GetShape()[1], 1});
      assert(tmp.GetSize() == this->GetOutput().GetSize());
      assert(tmp.GetShape()[0] == this->GetOutput().GetShape()[2]);
      Architecture_t::Rearrange(this->GetOutput(), tmp);
      // keep the full sequence, it is needed in the backward pass
      fY = arrOutput;
   }
}
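// Shape conventions in Forward, as can be read off the code above: the input
// tensor arrives as B x T x inputSize and is rearranged to T x B x inputSize,
// so that each time slice is a B x inputSize matrix fed to CellForward. With
// fReturnSequence the output is B x T x stateSize; otherwise only the last
// hidden state is kept, with the batch size as the last output dimension.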
template <typename Architecture_t>
auto inline TBasicRNNLayer<Architecture_t>::CellForward(const Matrix_t &input, Matrix_t &dF) -> void
{
   // state_t = act( input_t . W_input^T + state_{t-1} . W_state^T + bias )
   const DNN::EActivationFunction fAF = this->GetActivationFunction();
   Matrix_t tmpState(fState.GetNrows(), fState.GetNcols());
   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsState);
   Architecture_t::MultiplyTranspose(fState, input, fWeightsInput);
   Architecture_t::ScaleAdd(fState, tmpState);
   Architecture_t::AddRowWise(fState, fBiases);
   // keep the pre-activation values in dF for the backward pass, then apply
   // the activation function in place on the state
   Tensor_t inputActivFunc(dF);
   Tensor_t tState(fState);
   Architecture_t::Copy(inputActivFunc, tState);
   Architecture_t::ActivationFunctionForward(tState, fAF, fActivationDesc);
}
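// Step-by-step algebra of CellForward, mapping the backend calls above to the
// update rule (all matrices are B x S unless noted):
//   tmpState = state_{t-1} . W_state^T        (MultiplyTranspose)
//   fState   = input_t . W_input^T            (MultiplyTranspose, input is B x I)
//   fState  += tmpState                       (ScaleAdd)
//   fState  += bias, broadcast over the rows  (AddRowWise)
//   dF       = fState, the pre-activations    (Copy, consumed by the backward pass)
//   fState   = f(fState)                      (ActivationFunctionForward, in place)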
template <typename Architecture_t>
auto inline TBasicRNNLayer<Architecture_t>::Backward(Tensor_t &gradients_backward,          // B x T x D
                                                     const Tensor_t &activations_backward)  // B x T x D
   -> void
{
   // cuDNN implementation
   if (Architecture_t::IsCudnn()) {

      Tensor_t &x = this->fX;
      Tensor_t &y = this->fY;
      Tensor_t &dx = this->fDx;
      Tensor_t &dy = this->fDy;

      // the input size is stride[1] of the input tensor, which is B x T x inputSize
      assert(activations_backward.GetStrides()[1] == this->GetInputSize());

      Architecture_t::Rearrange(x, activations_backward);

      if (!fReturnSequence) {
         // only the last time step carries an output gradient
         Architecture_t::InitializeZero(dy);
         Tensor_t tmp2 = dy.At(dy.GetShape()[0] - 1).Reshape({dy.GetShape()[1], 1, dy.GetShape()[2]});
         Architecture_t::Copy(tmp2, this->GetActivationGradients());
      } else {
         Architecture_t::Rearrange(y, this->GetOutput());
         Architecture_t::Rearrange(dy, this->GetActivationGradients());
      }

      // for cuDNN, Matrix_t and Tensor_t are the same type
      const auto &weights = this->GetWeightsTensor();
      auto &weightGradients = this->GetWeightGradientsTensor();
      // cuDNN accumulates the weight gradients, so they must be zeroed first
      Architecture_t::InitializeZero(weightGradients);

      // hx is fState; reuse the state tensors for the state gradients
      auto &hx = this->GetState();
      auto cx = this->GetCell();
      auto &dhy = hx;
      auto &dcy = cx;
      auto &dhx = hx;
      auto &dcx = cx;

      auto &rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
      auto &rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);

      Architecture_t::RNNBackward(x, hx, cx, y, dy, dhy, dcy, weights, dx, dhx, dcx, weightGradients,
                                  rnnDesc, rnnWork);

      if (gradients_backward.GetSize() != 0)
         Architecture_t::Rearrange(gradients_backward, dx);
      return;
   }

   // standard implementation: backpropagation through time.
   // When gradients_backward is empty (this is the first layer of the network)
   // the input gradients are still computed into the temporary tensor below,
   // but they are not copied back at the end.
   Tensor_t arr_gradients_backward(fTimeSteps, this->GetBatchSize(), this->GetInputSize());
   Tensor_t arr_activations_backward(fTimeSteps, this->GetBatchSize(), this->GetInputSize());
   Architecture_t::Rearrange(arr_activations_backward, activations_backward); // B x T x D --> T x B x D

   Matrix_t state_gradients_backward(this->GetBatchSize(), fStateSize); // B x H
   DNN::initialize<Architecture_t>(state_gradients_backward, DNN::EInitialization::kZero);

   Matrix_t initState(this->GetBatchSize(), fStateSize); // B x H
   DNN::initialize<Architecture_t>(initState, DNN::EInitialization::kZero);

   Tensor_t arr_output(fTimeSteps, this->GetBatchSize(), fStateSize);
   Tensor_t arr_actgradients(fTimeSteps, this->GetBatchSize(), fStateSize);

   if (fReturnSequence) {
      Architecture_t::Rearrange(arr_output, this->GetOutput());
      Architecture_t::Rearrange(arr_actgradients, this->GetActivationGradients());
   } else {
      // the full output sequence was cached in the forward pass
      arr_output = fY;
      // only the last time step carries an output gradient
      Architecture_t::InitializeZero(arr_actgradients);
      Tensor_t tmp_grad = arr_actgradients.At(fTimeSteps - 1).Reshape({this->GetBatchSize(), fStateSize, 1});
      assert(tmp_grad.GetSize() == this->GetActivationGradients().GetSize());
      assert(tmp_grad.GetShape()[0] == this->GetActivationGradients().GetShape()[2]);
      Architecture_t::Rearrange(tmp_grad, this->GetActivationGradients());
   }

   // the weight and bias gradients are accumulated over the time steps
   fWeightInputGradients.Zero();
   fWeightStateGradients.Zero();
   fBiasGradients.Zero();

   for (size_t t = fTimeSteps; t > 0; t--) {
      // add the gradient coming from the output at this time step
      Matrix_t actgrad_m = arr_actgradients.At(t - 1).GetMatrix();
      Architecture_t::ScaleAdd(state_gradients_backward, actgrad_m);

      Matrix_t actbw_m = arr_activations_backward.At(t - 1).GetMatrix();
      Matrix_t gradbw_m = arr_gradients_backward.At(t - 1).GetMatrix();

      // compute the derivative of the activation function on the cached values
      Tensor_t df = fDerivatives.At(t - 1);
      Tensor_t dy = Tensor_t(state_gradients_backward);
      Tensor_t y = arr_output.At(t - 1);
      Architecture_t::ActivationFunctionBackward(df, y, dy, df, // do it in place on df
                                                 this->GetActivationFunction(), fActivationDesc);
      Matrix_t df_m = df.GetMatrix();

      if (t > 1) {
         Matrix_t precStateActivations = arr_output.At(t - 2).GetMatrix();
         CellBackward(state_gradients_backward, precStateActivations, actbw_m, gradbw_m, df_m);
      } else {
         const Matrix_t &precStateActivations = initState;
         CellBackward(state_gradients_backward, precStateActivations, actbw_m, gradbw_m, df_m);
      }
   }

   if (gradients_backward.GetSize() != 0)
      Architecture_t::Rearrange(gradients_backward, arr_gradients_backward);
}
template <typename Architecture_t>
auto inline TBasicRNNLayer<Architecture_t>::CellBackward(Matrix_t &state_gradients_backward,
                                                         const Matrix_t &precStateActivations, const Matrix_t &input,
                                                         Matrix_t &input_gradient, Matrix_t &dF) -> Matrix_t &
{
   return Architecture_t::RecurrentLayerBackward(state_gradients_backward, fWeightInputGradients,
                                                 fWeightStateGradients, fBiasGradients, dF, precStateActivations,
                                                 fWeightsInput, fWeightsState, input, input_gradient);
}
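// For reference, the standard vanilla-RNN BPTT step that the backend's
// RecurrentLayerBackward is expected to implement (dF is B x S and already
// combines f'(a_t) with the incoming state gradient; W_input is S x I and
// W_state is S x S):
//   input_gradient           = dF . W_input          (B x I)
//   fWeightInputGradients   += dF^T . input_t        (S x I)
//   fWeightStateGradients   += dF^T . state_{t-1}    (S x S)
//   fBiasGradients          += column-wise sum of dF
//   state_gradients_backward = dF . W_state          (B x S, passed on to step t-1)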
template <typename Architecture_t>
void TBasicRNNLayer<Architecture_t>::AddWeightsXMLTo(void *parent)
{
   auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "RNNLayer");

   // ... (layer attributes such as time steps and state size are written here) ...

   // write the weight and bias matrices
   this->WriteMatrixToXML(layerxml, "InputWeights", this->GetWeightsAt(0));
   this->WriteMatrixToXML(layerxml, "StateWeights", this->GetWeightsAt(1));
   this->WriteMatrixToXML(layerxml, "Biases", this->GetBiasesAt(0));
}
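// Sketch of the resulting XML fragment (illustrative; the layer attributes
// written between the node creation and the matrices are elided in this
// excerpt):
//
// \code{.xml}
// <RNNLayer ...>
//    <InputWeights ...> ... </InputWeights>
//    <StateWeights ...> ... </StateWeights>
//    <Biases ...> ... </Biases>
// </RNNLayer>
// \endcode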
template <typename Architecture_t>
void TBasicRNNLayer<Architecture_t>::ReadWeightsFromXML(void *parent)
{
   this->ReadMatrixXML(parent, "InputWeights", this->GetWeightsAt(0));
   this->ReadMatrixXML(parent, "StateWeights", this->GetWeightsAt(1));
   this->ReadMatrixXML(parent, "Biases", this->GetBiasesAt(0));
}

} // namespace RNN
} // namespace DNN
} // namespace TMVA

#endif // TMVA_DNN_RNN_LAYER