29#ifndef TMVA_DNN_RNN_LAYER
30#define TMVA_DNN_RNN_LAYER
55template<
typename Architecture_t>
61 using Tensor_t =
typename Architecture_t::Tensor_t;
62 using Matrix_t =
typename Architecture_t::Matrix_t;
63 using Scalar_t =
typename Architecture_t::Scalar_t;
212template <
typename Architecture_t>
221 fWeightsInput(
this->GetWeightsAt(0)), fWeightsState(
this->GetWeightsAt(1)),
223 fWeightInputGradients(
this->GetWeightGradientsAt(0)), fWeightStateGradients(
this->GetWeightGradientsAt(1)),
224 fBiasGradients(
this->GetBiasGradientsAt(0)), fWeightsTensor({0}), fWeightGradientsTensor({0})
230template <
typename Architecture_t>
233 fRememberState(
layer.fRememberState), fReturnSequence(
layer.fReturnSequence), fF(
layer.GetActivationFunction()),
234 fState(
layer.GetBatchSize(),
layer.GetStateSize()),
235 fWeightsInput(
this->GetWeightsAt(0)), fWeightsState(
this->GetWeightsAt(1)), fBiases(
this->GetBiasesAt(0)),
236 fDerivatives(
layer.GetDerivatives().GetShape()), fWeightInputGradients(
this->GetWeightGradientsAt(0)),
237 fWeightStateGradients(
this->GetWeightGradientsAt(1)), fBiasGradients(
this->GetBiasGradientsAt(0)),
238 fWeightsTensor({0}), fWeightGradientsTensor({0})
241 Architecture_t::Copy(fDerivatives,
layer.GetDerivatives() );
244 Architecture_t::Copy(fState,
layer.GetState());
248template <
typename Architecture_t>
252 Architecture_t::ReleaseRNNDescriptors(fDescriptors);
257 Architecture_t::FreeRNNWorkspace(fWorkspace);
263template<
typename Architecture_t>
273 Architecture_t::InitializeRNNDescriptors(fDescriptors,
this);
274 Architecture_t::InitializeRNNWorkspace(fWorkspace, fDescriptors,
this);
278template <
typename Architecture_t>
282 Architecture_t::InitializeRNNTensors(
this);
285template <
typename Architecture_t>
290 Architecture_t::InitializeActivationDescriptor(fActivationDesc,this->GetActivationFunction());
294template<
typename Architecture_t>
298 std::cout <<
" RECURRENT Layer: \t ";
299 std::cout <<
" (NInput = " << this->GetInputSize();
300 std::cout <<
", NState = " << this->GetStateSize();
301 std::cout <<
", NTime = " << this->GetTimeSteps() <<
" )";
302 std::cout <<
"\tOutput = ( " << this->GetOutput().GetFirstSize() <<
" , " << this->GetOutput().GetHSize() <<
" , " << this->GetOutput().GetWSize() <<
" )\n";
305template <
typename Architecture_t>
306auto debugMatrix(
const typename Architecture_t::Matrix_t &A,
const std::string
name =
"matrix")
309 std::cout <<
name <<
"\n";
310 for (
size_t i = 0; i < A.GetNrows(); ++i) {
311 for (
size_t j = 0;
j < A.GetNcols(); ++
j) {
312 std::cout << A(i,
j) <<
" ";
316 std::cout <<
"********\n";
321template <
typename Architecture_t>
327 if (Architecture_t::IsCudnn()) {
332 Architecture_t::Rearrange(
x,
input);
337 const auto & weights = this->GetWeightsTensor();
343 auto &
hx = this->GetState();
344 auto &
cx = this->GetCell();
346 auto &
hy = this->GetState();
347 auto &
cy = this->GetCell();
356 if (fReturnSequence) {
357 Architecture_t::Rearrange(this->GetOutput(),
y);
361 Tensor_t tmp = (
y.At(
y.GetShape()[0] - 1)).Reshape({
y.GetShape()[1], 1,
y.GetShape()[2]});
362 Architecture_t::Copy(this->GetOutput(), tmp);
373 Tensor_t arrInput (fTimeSteps, this->GetBatchSize(), this->GetInputWidth() );
381 for (
size_t t = 0; t < fTimeSteps; ++t) {
390 Architecture_t::Rearrange(this->GetOutput(),
arrOutput);
397 tmp = tmp.Reshape({tmp.GetShape()[0], tmp.GetShape()[1], 1});
398 assert(tmp.GetSize() ==
this->GetOutput().GetSize());
399 assert(tmp.GetShape()[0] ==
this->GetOutput().GetShape()[2]);
400 Architecture_t::Rearrange(this->GetOutput(), tmp);
407template <
typename Architecture_t>
414 Architecture_t::MultiplyTranspose(
tmpState, fState, fWeightsState);
415 Architecture_t::MultiplyTranspose(fState,
input, fWeightsInput);
416 Architecture_t::ScaleAdd(fState,
tmpState);
417 Architecture_t::AddRowWise(fState, fBiases);
425 Architecture_t::ActivationFunctionForward(
tState,
fAF, fActivationDesc);
430template <
typename Architecture_t>
437 if (Architecture_t::IsCudnn() ) {
449 if (!fReturnSequence) {
452 Architecture_t::InitializeZero(
dy);
458 Architecture_t::Copy(
tmp2, this->GetActivationGradients());
461 Architecture_t::Rearrange(
y, this->GetOutput());
462 Architecture_t::Rearrange(
dy, this->GetActivationGradients());
469 auto &weights = this->GetWeightsTensor();
477 auto &
hx = this->GetState();
478 auto &
cx = this->GetCell();
489 Architecture_t::RNNBackward(
x,
hx,
cx,
y,
dy,
dhy,
dcy, weights,
dx,
dhx,
dcx,
weightGradients,
rnnDesc,
rnnWork);
528 if (fReturnSequence) {
529 Architecture_t::Rearrange(
arr_output, this->GetOutput());
530 Architecture_t::Rearrange(
arr_actgradients, this->GetActivationGradients());
540 this->GetActivationGradients().GetShape()[2]);
542 Architecture_t::Rearrange(
tmp_grad, this->GetActivationGradients());
546 fWeightInputGradients.Zero();
547 fWeightStateGradients.Zero();
548 fBiasGradients.Zero();
550 for (
size_t t = fTimeSteps; t > 0; t--) {
563 Architecture_t::ActivationFunctionBackward(df,
y,
565 this->GetActivationFunction(), fActivationDesc);
586template <
typename Architecture_t>
592 return Architecture_t::RecurrentLayerBackward(
state_gradients_backward, fWeightInputGradients, fWeightStateGradients,
598template <
typename Architecture_t>
611 this->WriteMatrixToXML(
layerxml,
"InputWeights",
this -> GetWeightsAt(0));
612 this->WriteMatrixToXML(
layerxml,
"StateWeights",
this -> GetWeightsAt(1));
613 this->WriteMatrixToXML(
layerxml,
"Biases",
this -> GetBiasesAt(0));
619template <
typename Architecture_t>
623 this->ReadMatrixXML(parent,
"InputWeights",
this -> GetWeightsAt(0));
624 this->ReadMatrixXML(parent,
"StateWeights",
this -> GetWeightsAt(1));
625 this->ReadMatrixXML(parent,
"Biases",
this -> GetBiasesAt(0));
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Tensor_t fDy
cached activation gradient (input of backward) as T x B x S
size_t GetStateSize() const
typename Architecture_t::RNNDescriptors_t RNNDescriptors_t
DNN::EActivationFunction GetActivationFunction() const
void InitState(DNN::EInitialization m=DNN::EInitialization::kZero)
Initialize the state.
const Matrix_t & GetWeightInputGradients() const
const Tensor_t & GetWeightGradientsTensor() const
void Print() const
Prints the info about the layer.
typename Architecture_t::RecurrentDescriptor_t LayerDescriptor_t
Tensor_t fY
cached output tensor as T x B x S
Tensor_t fDerivatives
First derivatives of the activations.
const Matrix_t & GetWeightStateGradients() const
Matrix_t & fWeightsInput
Input weights, fWeights[0].
Matrix_t & fWeightsState
Prev state weights, fWeights[1].
virtual ~TBasicRNNLayer()
Destructor.
TDescriptors * fDescriptors
Keeps all the RNN descriptors.
Tensor_t fX
cached input tensor as T x B x I
void Forward(Tensor_t &input, bool isTraining=true)
Compute the next state from the given input tensor.
Matrix_t & fBiases
Biases.
Architecture_t::ActivationDescriptor_t fActivationDesc
virtual void ReadWeightsFromXML(void *parent)
Read the information and the weights about the layer from XML node.
typename Architecture_t::TensorDescriptor_t TensorDescriptor_t
bool fReturnSequence
Return in output full sequence or just last element in time.
const Tensor_t & GetWeightsTensor() const
Matrix_t & GetBiasStateGradients()
Tensor_t fWeightGradientsTensor
size_t fStateSize
Hidden state size of RNN.
Matrix_t & GetWeightsState()
const Matrix_t & GetState() const
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward)
Backpropagates the error.
Tensor_t & GetDerivatives()
const Matrix_t & GetCell() const
Matrix_t & CellBackward(Matrix_t &state_gradients_backward, const Matrix_t &precStateActivations, const Matrix_t &input, Matrix_t &input_gradient, Matrix_t &dF)
Backward for a single time unit at the corresponding call to Forward(...).
typename Architecture_t::Matrix_t Matrix_t
typename Architecture_t::DropoutDescriptor_t HelperDescriptor_t
typename Architecture_t::RNNWorkspace_t RNNWorkspace_t
Matrix_t fState
Hidden State.
Matrix_t & fWeightInputGradients
Gradients w.r.t. the input weights.
DNN::EActivationFunction fF
Activation function of the hidden state.
TBasicRNNLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps, bool rememberState=false, bool returnSequence=false, DNN::EActivationFunction f=DNN::EActivationFunction::kTanh, bool training=true, DNN::EInitialization fA=DNN::EInitialization::kZero)
Constructor.
Tensor_t & GetWeightGradientsTensor()
Matrix_t & GetWeightsInput()
size_t GetTimeSteps() const
Getters.
bool fRememberState
Remember state in next pass.
Tensor_t & GetWeightsTensor()
Matrix_t & fWeightStateGradients
Gradients w.r.t. the recurring weights.
Matrix_t & GetWeightInputGradients()
const Matrix_t & GetBiasesState() const
void Update(const Scalar_t learningRate)
virtual void AddWeightsXMLTo(void *parent)
Writes the information and the weights about the layer in an XML node.
typename Architecture_t::Scalar_t Scalar_t
size_t fTimeSteps
Timesteps for RNN.
bool DoesRememberState() const
void CellForward(const Matrix_t &input, Matrix_t &dF)
Forward for a single cell (time unit)
Tensor_t fDx
cached gradient on the input (output of backward) as T x B x I
typename Architecture_t::Tensor_t Tensor_t
const Matrix_t & GetBiasStateGradients() const
size_t GetInputSize() const
Matrix_t & GetWeightStateGradients()
bool DoesReturnSequence() const
Matrix_t & fBiasGradients
Gradients w.r.t. the bias values.
const Matrix_t & GetWeightsInput() const
Matrix_t fCell
Empty matrix for RNN.
const Tensor_t & GetDerivatives() const
Matrix_t & GetBiasesState()
virtual void Initialize()
Initialize the weights according to the given initialization method.
const Matrix_t & GetWeightsState() const
typename Architecture_t::FilterDescriptor_t WeightsDescriptor_t
Generic General Layer class.
virtual void Initialize()
Initialize the weights and biases according to the given initialization method.
size_t GetInputWidth() const
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
auto debugMatrix(const typename Architecture_t::Matrix_t &A, const std::string name="matrix") -> void
EActivationFunction
Enum that represents layer activation functions.
create variable transformations