#ifndef TMVA_DNN_RNN_LAYER
#define TMVA_DNN_RNN_LAYER
/*! Generic implementation of a vanilla recurrent (RNN) layer. */
template <typename Architecture_t>
class TBasicRNNLayer : public VGeneralLayer<Architecture_t> {

public:
   using Tensor_t = typename Architecture_t::Tensor_t;
   using Matrix_t = typename Architecture_t::Matrix_t;
   using Scalar_t = typename Architecture_t::Scalar_t;

   /*! Constructor */
   TBasicRNNLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps,
                  bool rememberState = false, bool returnSequence = false,
                  DNN::EActivationFunction f = DNN::EActivationFunction::kTanh, bool training = true,
                  DNN::EInitialization fA = DNN::EInitialization::kZero);

   /*! Backpropagates the error through the full time sequence. */
   void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);

   /*! Backward for a single time unit, matching the corresponding call to Forward(...). */
   Matrix_t &CellBackward(Matrix_t &state_gradients_backward, const Matrix_t &precStateActivations,
                          const Matrix_t &input, Matrix_t &input_gradient, Matrix_t &dF);

   // ... (copy constructor, destructor, Forward, Print, Initialize, getters and data members)
};
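/* The recurrence implemented by this layer (see CellForward below) is

      h_t = f(W_input * x_t + W_state * h_{t-1} + b),

   where x_t is the input at time t (B x D), h_t the hidden state (B x H),
   W_input is H x D, W_state is H x H and b is a bias vector of size H.
   With returnSequence = true the layer outputs the full sequence
   {h_1, ..., h_T}; otherwise only the last state h_T. */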
template <typename Architecture_t>
TBasicRNNLayer<Architecture_t>::TBasicRNNLayer(size_t batchSize, size_t stateSize, size_t inputSize,
                                               size_t timeSteps, bool rememberState, bool returnSequence,
                                               DNN::EActivationFunction f, bool /* training */,
                                               DNN::EInitialization fA)
   : VGeneralLayer<Architecture_t>(batchSize, 1, timeSteps, inputSize, 1, (returnSequence) ? timeSteps : 1,
                                   stateSize, 2, {stateSize, stateSize}, {inputSize, stateSize}, 1, {stateSize},
                                   {1}, batchSize, (returnSequence) ? timeSteps : 1, stateSize, fA),
     fTimeSteps(timeSteps), fStateSize(stateSize), fRememberState(rememberState), fReturnSequence(returnSequence),
     fF(f), fState(batchSize, stateSize), fWeightsInput(this->GetWeightsAt(0)),
     fWeightsState(this->GetWeightsAt(1)), fBiases(this->GetBiasesAt(0)),
     fDerivatives(timeSteps, batchSize, stateSize), fWeightInputGradients(this->GetWeightGradientsAt(0)),
     fWeightStateGradients(this->GetWeightGradientsAt(1)), fBiasGradients(this->GetBiasGradientsAt(0)),
     fWeightsTensor({0}), fWeightGradientsTensor({0})
{
}
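/* Construction sketch (illustrative only, assuming the TMVA::DNN::TCpu<float>
   architecture backend): a layer mapping sequences of 10 time steps with 5
   input features each onto an 8-dimensional hidden state, outputting only the
   last state.

      using Arch_t = TMVA::DNN::TCpu<float>;
      TBasicRNNLayer<Arch_t> rnn(32,     // batchSize
                                 8,      // stateSize
                                 5,      // inputSize
                                 10,     // timeSteps
                                 false,  // rememberState
                                 false); // returnSequence: output only h_T
      rnn.Initialize();
      Arch_t::Tensor_t input({32, 10, 5}); // B x T x D layout
      rnn.Forward(input, true);            // isTraining = true
*/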
template <typename Architecture_t>
TBasicRNNLayer<Architecture_t>::TBasicRNNLayer(const TBasicRNNLayer &layer)
   : VGeneralLayer<Architecture_t>(layer), fTimeSteps(layer.fTimeSteps), fStateSize(layer.fStateSize),
     fRememberState(layer.fRememberState), fReturnSequence(layer.fReturnSequence),
     fF(layer.GetActivationFunction()), fState(layer.GetBatchSize(), layer.GetStateSize()),
     fWeightsInput(this->GetWeightsAt(0)), fWeightsState(this->GetWeightsAt(1)), fBiases(this->GetBiasesAt(0)),
     fDerivatives(layer.GetDerivatives().GetShape()), fWeightInputGradients(this->GetWeightGradientsAt(0)),
     fWeightStateGradients(this->GetWeightGradientsAt(1)), fBiasGradients(this->GetBiasGradientsAt(0)),
     fWeightsTensor({0}), fWeightGradientsTensor({0})
{
   // deep-copy the cached activation derivatives and the hidden state
   Architecture_t::Copy(fDerivatives, layer.GetDerivatives());
   Architecture_t::Copy(fState, layer.GetState());
}
template <typename Architecture_t>
TBasicRNNLayer<Architecture_t>::~TBasicRNNLayer()
{
   // release backend descriptors and workspace if they were allocated
   if (fDescriptors) {
      Architecture_t::ReleaseRNNDescriptors(fDescriptors);
      delete fDescriptors;
   }
   if (fWorkspace) {
      Architecture_t::FreeRNNWorkspace(fWorkspace);
      delete fWorkspace;
   }
}
template <typename Architecture_t>
auto TBasicRNNLayer<Architecture_t>::Initialize() -> void
{
   // initialize weights and biases in the base class, then the
   // backend-specific RNN descriptors and workspace
   VGeneralLayer<Architecture_t>::Initialize();

   Architecture_t::InitializeRNNDescriptors(fDescriptors, this);
   Architecture_t::InitializeRNNWorkspace(fWorkspace, fDescriptors, this);
}
template <typename Architecture_t>
auto TBasicRNNLayer<Architecture_t>::InitState(DNN::EInitialization /* m */) -> void
{
   fState.Zero();

   Architecture_t::InitializeRNNTensors(this);
}
template <typename Architecture_t>
auto TBasicRNNLayer<Architecture_t>::InitTensors() -> void
{
   // set up the descriptor used by the activation function
   Architecture_t::InitializeActivationDescriptor(fActivationDesc, this->GetActivationFunction());
}
template <typename Architecture_t>
auto TBasicRNNLayer<Architecture_t>::Print() const -> void
{
   std::cout << " RECURRENT Layer: \t ";
   std::cout << " (NInput = " << this->GetInputSize();
   std::cout << ", NState = " << this->GetStateSize();
   std::cout << ", NTime = " << this->GetTimeSteps() << " )";
   std::cout << "\tOutput = ( " << this->GetOutput().GetFirstSize() << " , " << this->GetOutput().GetHSize()
             << " , " << this->GetOutput().GetWSize() << " )\n";
}
template <typename Architecture_t>
auto debugMatrix(const typename Architecture_t::Matrix_t &A, const std::string name = "matrix") -> void
{
   std::cout << name << "\n";
   for (size_t i = 0; i < A.GetNrows(); ++i) {
      for (size_t j = 0; j < A.GetNcols(); ++j) {
         std::cout << A(i, j) << " ";
      }
      std::cout << "\n";
   }
   std::cout << "********\n";
}
template <typename Architecture_t>
auto TBasicRNNLayer<Architecture_t>::Forward(Tensor_t &input, bool isTraining) -> void
{
   // cuDNN-based backends consume the cached tensors fX / fY in T x B x S
   // layout, while the layer input/output tensors are B x T x S
   if (Architecture_t::IsCudnn()) {

      Tensor_t &x = this->fX;
      Tensor_t &y = this->fY;
      Architecture_t::Rearrange(x, input); // B x T x D to T x B x D

      const auto &weights = this->GetWeightsAt(0);

      auto &hx = this->GetState();
      auto &cx = this->GetCell();
      // use the same tensors for the final state (hy, cy)
      auto &hy = this->GetState();
      auto &cy = this->GetCell();

      auto &rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
      auto &rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);

      Architecture_t::RNNForward(x, hx, cx, weights, y, hy, cy, rnnDesc, rnnWork, isTraining);

      if (fReturnSequence) {
         Architecture_t::Rearrange(this->GetOutput(), y); // swap B and T
      } else {
         // keep only the last time step: reshape y[T-1] from B x S to B x 1 x S
         Tensor_t tmp = (y.At(y.GetShape()[0] - 1)).Reshape({y.GetShape()[1], 1, y.GetShape()[2]});
         Architecture_t::Copy(this->GetOutput(), tmp);
      }
      return;
   }

   // generic implementation: loop over the time steps calling CellForward
   Tensor_t arrInput(fTimeSteps, this->GetBatchSize(), this->GetInputWidth());
   Architecture_t::Rearrange(arrInput, input); // B x T x D to T x B x D
   Tensor_t arrOutput(fTimeSteps, this->GetBatchSize(), fStateSize);

   if (!this->fRememberState)
      InitState(DNN::EInitialization::kZero);

   for (size_t t = 0; t < fTimeSteps; ++t) {
      Matrix_t arrInput_m = arrInput.At(t).GetMatrix();
      Matrix_t df_m = fDerivatives.At(t).GetMatrix();
      CellForward(arrInput_m, df_m);
      Matrix_t arrOutput_m = arrOutput.At(t).GetMatrix();
      Architecture_t::Copy(arrOutput_m, fState);
   }

   if (fReturnSequence) {
      Architecture_t::Rearrange(this->GetOutput(), arrOutput); // T x B x S to B x T x S
   } else {
      // keep only the last time step: arrOutput[T-1] is B x S, reshape to B x S x 1
      Tensor_t tmp = arrOutput.At(fTimeSteps - 1);
      tmp = tmp.Reshape({tmp.GetShape()[0], tmp.GetShape()[1], 1});
      assert(tmp.GetSize() == this->GetOutput().GetSize());
      assert(tmp.GetShape()[0] == this->GetOutput().GetShape()[2]); // B is the last output dimension
      Architecture_t::Rearrange(this->GetOutput(), tmp);
   }
}
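/* A minimal sketch of the layout change Architecture_t::Rearrange performs
   between the B x T x S and T x B x S views, assuming row-major contiguous
   storage (plain C++, independent of the TMVA tensor classes):

      #include <cstddef>
      #include <vector>

      // copy src (B x T x S) into dst (T x B x S)
      void rearrangeBTStoTBS(const std::vector<float> &src, std::vector<float> &dst,
                             std::size_t B, std::size_t T, std::size_t S)
      {
         for (std::size_t b = 0; b < B; ++b)
            for (std::size_t t = 0; t < T; ++t)
               for (std::size_t s = 0; s < S; ++s)
                  dst[(t * B + b) * S + s] = src[(b * T + t) * S + s];
      }
*/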
template <typename Architecture_t>
auto TBasicRNNLayer<Architecture_t>::CellForward(const Matrix_t &input, Matrix_t &dF) -> void
{
   // State = act(W_input . input + W_state . state + bias)
   const DNN::EActivationFunction fAF = this->GetActivationFunction();
   Matrix_t tmpState(fState.GetNrows(), fState.GetNcols());
   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsState);
   Architecture_t::MultiplyTranspose(fState, input, fWeightsInput);
   Architecture_t::ScaleAdd(fState, tmpState);
   Architecture_t::AddRowWise(fState, fBiases);

   // cache the pre-activation values in dF (via inputActivFunc) so the
   // derivative of the activation can be computed in the backward pass
   Tensor_t inputActivFunc(dF);
   Tensor_t tState(fState);
   Architecture_t::Copy(inputActivFunc, tState);
   Architecture_t::ActivationFunctionForward(tState, fAF, fActivationDesc);
}
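/* The same update written out for a single batch element with plain arrays
   (illustrative only; tanh activation assumed, dimensions as above: D inputs,
   H state units):

      #include <cmath>
      #include <cstddef>
      #include <vector>

      // h (size H) is updated in place from input x (size D);
      // wIn is H x D, wState is H x H, b has size H
      void cellForward(float *h, const float *x, const float *wIn,
                       const float *wState, const float *b,
                       std::size_t D, std::size_t H)
      {
         std::vector<float> hNew(H);
         for (std::size_t i = 0; i < H; ++i) {
            float acc = b[i];
            for (std::size_t j = 0; j < D; ++j) acc += wIn[i * D + j] * x[j];
            for (std::size_t j = 0; j < H; ++j) acc += wState[i * H + j] * h[j];
            hNew[i] = std::tanh(acc);
         }
         for (std::size_t i = 0; i < H; ++i) h[i] = hNew[i];
      }
*/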
template <typename Architecture_t>
auto TBasicRNNLayer<Architecture_t>::Backward(Tensor_t &gradients_backward,          // B x T x D
                                              const Tensor_t &activations_backward) -> void // B x T x D
{
   if (Architecture_t::IsCudnn()) {

      Tensor_t &x = this->fX;
      Tensor_t &y = this->fY;
      Tensor_t &dx = this->fDx;
      Tensor_t &dy = this->fDy;

      // input size is stride[1] of the input tensor, which is B x T x inputSize
      assert(activations_backward.GetStrides()[1] == this->GetInputSize());

      Architecture_t::Rearrange(x, activations_backward);

      if (!fReturnSequence) {
         // only the last time step was returned: zero dy and copy the
         // activation gradients into its last time slice
         Architecture_t::InitializeZero(dy);

         Tensor_t tmp2 = dy.At(dy.GetShape()[0] - 1).Reshape({dy.GetShape()[1], 1, dy.GetShape()[2]});
         Architecture_t::Copy(tmp2, this->GetActivationGradients());
      } else {
         Architecture_t::Rearrange(y, this->GetOutput());
         Architecture_t::Rearrange(dy, this->GetActivationGradients());
      }

      // for cuDNN the weights and their gradients are cached as single tensors
      const auto &weights = this->GetWeightsTensor();
      auto &weightGradients = this->GetWeightGradientsTensor();
      // the backward-weights pass accumulates, so reset the gradients first
      Architecture_t::InitializeZero(weightGradients);

      auto &hx = this->GetState();
      auto cx = this->GetCell();
      // use the same tensors for the state-gradient arguments
      auto &dhy = hx;
      auto &dcy = cx;
      auto &dhx = hx;
      auto &dcx = cx;

      auto &rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
      auto &rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);

      Architecture_t::RNNBackward(x, hx, cx, y, dy, dhy, dcy, weights, dx, dhx, dcx, weightGradients, rnnDesc,
                                  rnnWork);

      if (gradients_backward.GetSize() != 0)
         Architecture_t::Rearrange(gradients_backward, dx);

      return;
   }

   // generic implementation: backpropagation through time (BPTT)
   if (gradients_backward.GetSize() == 0) {
      // first layer in the network: input gradients need not be propagated
   }
   Tensor_t arr_gradients_backward(fTimeSteps, this->GetBatchSize(), this->GetInputSize());

   Tensor_t arr_activations_backward(fTimeSteps, this->GetBatchSize(), this->GetInputSize());
   Architecture_t::Rearrange(arr_activations_backward, activations_backward); // B x T x D to T x B x D

   Matrix_t state_gradients_backward(this->GetBatchSize(), fStateSize); // B x H
   DNN::initialize<Architecture_t>(state_gradients_backward, DNN::EInitialization::kZero);

   Matrix_t initState(this->GetBatchSize(), fStateSize); // B x H
   DNN::initialize<Architecture_t>(initState, DNN::EInitialization::kZero);

   Tensor_t arr_output(fTimeSteps, this->GetBatchSize(), fStateSize);
   Tensor_t arr_actgradients(fTimeSteps, this->GetBatchSize(), fStateSize);

   if (fReturnSequence) {
      Architecture_t::Rearrange(arr_output, this->GetOutput());
      Architecture_t::Rearrange(arr_actgradients, this->GetActivationGradients());
   } else {
      // only the last time step carries activation gradients
      Architecture_t::InitializeZero(arr_actgradients);
      Tensor_t tmp_grad = arr_actgradients.At(fTimeSteps - 1).Reshape({this->GetBatchSize(), fStateSize, 1});
      assert(tmp_grad.GetSize() == this->GetActivationGradients().GetSize());
      assert(tmp_grad.GetShape()[0] == this->GetActivationGradients().GetShape()[2]); // B is the last dimension
      Architecture_t::Rearrange(tmp_grad, this->GetActivationGradients());
   }

   // reset weight and bias gradients: they are accumulated over the time steps
   fWeightInputGradients.Zero();
   fWeightStateGradients.Zero();
   fBiasGradients.Zero();

   for (size_t t = fTimeSteps; t > 0; t--) {
      // add the gradient coming from the output at this time step
      Matrix_t actgrad_m = arr_actgradients.At(t - 1).GetMatrix();
      Architecture_t::ScaleAdd(state_gradients_backward, actgrad_m);

      Matrix_t actbw_m = arr_activations_backward.At(t - 1).GetMatrix();
      Matrix_t gradbw_m = arr_gradients_backward.At(t - 1).GetMatrix();

      // compute the derivative of the activation, in place in df
      Tensor_t df = fDerivatives.At(t - 1);
      Tensor_t dy = Tensor_t(state_gradients_backward);
      Tensor_t y = arr_output.At(t - 1);
      Architecture_t::ActivationFunctionBackward(df, y, dy, df,
                                                 this->GetActivationFunction(), fActivationDesc);
      Matrix_t df_m = df.GetMatrix();

      if (t > 1) {
         Matrix_t precStateActivations = arr_output.At(t - 2).GetMatrix();
         CellBackward(state_gradients_backward, precStateActivations, actbw_m, gradbw_m, df_m);
      } else {
         const Matrix_t &precStateActivations = initState;
         CellBackward(state_gradients_backward, precStateActivations, actbw_m, gradbw_m, df_m);
      }
   }

   if (gradients_backward.GetSize() != 0)
      Architecture_t::Rearrange(gradients_backward, arr_gradients_backward); // T x B x D to B x T x D
}
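/* In BPTT the state gradient entering time step t is the sum of the loss
   gradient w.r.t. the output at t and the gradient propagated back from step
   t+1. With s_t = W_input x_t + W_state h_{t-1} + b and h_t = f(s_t), each
   iteration of the loop above computes

       dL/ds_t       = (dL/dh_t) * f'(s_t)          (element-wise)
       dL/dW_input  += (dL/ds_t)^T x_t
       dL/dW_state  += (dL/ds_t)^T h_{t-1}
       dL/db        += column sums of dL/ds_t
       dL/dh_{t-1}   = (dL/ds_t) W_state
       dL/dx_t       = (dL/ds_t) W_input
*/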
template <typename Architecture_t>
auto TBasicRNNLayer<Architecture_t>::CellBackward(Matrix_t &state_gradients_backward,
                                                  const Matrix_t &precStateActivations, const Matrix_t &input,
                                                  Matrix_t &input_gradient, Matrix_t &dF) -> Matrix_t &
{
   return Architecture_t::RecurrentLayerBackward(state_gradients_backward, fWeightInputGradients,
                                                 fWeightStateGradients, fBiasGradients, dF, precStateActivations,
                                                 fWeightsInput, fWeightsState, input, input_gradient);
}
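/* One cell-backward step for a single batch element, written with plain
   arrays (illustrative only; dF is assumed to already hold
   dL/ds_t = dL/dh_t * f'(s_t)):

      #include <cstddef>

      void cellBackward(float *stateGrad,       // out: dL/dh_{t-1}, size H
                        float *dWIn,            // accumulated, H x D
                        float *dWState,         // accumulated, H x H
                        float *dB,              // accumulated, size H
                        const float *dF,        // dL/ds_t, size H
                        const float *prevState, // h_{t-1}, size H
                        const float *wIn,       // H x D
                        const float *wState,    // H x H
                        const float *x,         // x_t, size D
                        float *inputGrad,       // out: dL/dx_t, size D
                        std::size_t D, std::size_t H)
      {
         for (std::size_t i = 0; i < H; ++i) {
            dB[i] += dF[i];
            for (std::size_t j = 0; j < D; ++j) dWIn[i * D + j] += dF[i] * x[j];
            for (std::size_t j = 0; j < H; ++j) dWState[i * H + j] += dF[i] * prevState[j];
         }
         for (std::size_t j = 0; j < D; ++j) {
            float acc = 0;
            for (std::size_t i = 0; i < H; ++i) acc += dF[i] * wIn[i * D + j];
            inputGrad[j] = acc;
         }
         for (std::size_t j = 0; j < H; ++j) {
            float acc = 0;
            for (std::size_t i = 0; i < H; ++i) acc += dF[i] * wState[i * H + j];
            stateGrad[j] = acc;
         }
      }
*/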
template <typename Architecture_t>
auto TBasicRNNLayer<Architecture_t>::AddWeightsXMLTo(void *parent) -> void
{
   auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "RNNLayer");

   // write the layer attributes
   gTools().xmlengine().NewAttr(layerxml, nullptr, "StateSize", gTools().StringFromInt(this->GetStateSize()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "InputSize", gTools().StringFromInt(this->GetInputSize()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "TimeSteps", gTools().StringFromInt(this->GetTimeSteps()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "RememberState", gTools().StringFromInt(this->DoesRememberState()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "ReturnSequence", gTools().StringFromInt(this->DoesReturnSequence()));

   // write the weight and bias matrices
   this->WriteMatrixToXML(layerxml, "InputWeights", this->GetWeightsAt(0));
   this->WriteMatrixToXML(layerxml, "StateWeights", this->GetWeightsAt(1));
   this->WriteMatrixToXML(layerxml, "Biases", this->GetBiasesAt(0));
}
template <typename Architecture_t>
auto TBasicRNNLayer<Architecture_t>::ReadWeightsFromXML(void *parent) -> void
{
   this->ReadMatrixXML(parent, "InputWeights", this->GetWeightsAt(0));
   this->ReadMatrixXML(parent, "StateWeights", this->GetWeightsAt(1));
   this->ReadMatrixXML(parent, "Biases", this->GetBiasesAt(0));
}
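/* Serialization round-trip sketch (illustrative; `rnn`, `restored`,
   `parentNode` and `layerNode` are hypothetical):

      rnn.AddWeightsXMLTo(parentNode);        // appends an "RNNLayer" node with
                                              // InputWeights/StateWeights/Biases
      restored.ReadWeightsFromXML(layerNode); // layerNode: the "RNNLayer" node
                                              // written above
*/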
#endif // TMVA_DNN_RNN_LAYER