#ifndef TMVA_DNN_RNN_LAYER
#define TMVA_DNN_RNN_LAYER
template <typename Architecture_t>
class TBasicRNNLayer : public VGeneralLayer<Architecture_t> {

public:

   using Tensor_t = typename Architecture_t::Tensor_t;
   using Matrix_t = typename Architecture_t::Matrix_t;
   using Scalar_t = typename Architecture_t::Scalar_t;

   // ... (member data and most of the interface elided in this excerpt)

   /** Constructor */
   TBasicRNNLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps,
                  bool rememberState = false,
                  DNN::EActivationFunction f = DNN::EActivationFunction::kTanh,
                  bool training = true, DNN::EInitialization fA = DNN::EInitialization::kZero);

   /** Backpropagates the error. */
   void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);

   /** Backward for a single time unit, matching the corresponding call to Forward(...). */
   Matrix_t &CellBackward(Matrix_t &state_gradients_backward, const Matrix_t &precStateActivations,
                          const Matrix_t &input, Matrix_t &input_gradient, Matrix_t &dF);

   // ... (remaining declarations elided)
};
template <typename Architecture_t>
TBasicRNNLayer<Architecture_t>::TBasicRNNLayer(size_t batchSize, size_t stateSize, size_t inputSize,
                                               size_t timeSteps, bool rememberState,
                                               DNN::EActivationFunction f, bool /*training*/,
                                               DNN::EInitialization fA)
   : VGeneralLayer<Architecture_t>(batchSize, 1, timeSteps, inputSize, 1, timeSteps, stateSize, 2,
                                   {stateSize, stateSize}, {inputSize, stateSize}, 1, {stateSize}, {1}, batchSize,
                                   timeSteps, stateSize, fA),
     fTimeSteps(timeSteps), fStateSize(stateSize), fRememberState(rememberState), fF(f),
     fState(batchSize, stateSize), fWeightsInput(this->GetWeightsAt(0)), fWeightsState(this->GetWeightsAt(1)),
     fBiases(this->GetBiasesAt(0)), fDerivatives(timeSteps, batchSize, stateSize),
     fWeightInputGradients(this->GetWeightGradientsAt(0)), fWeightStateGradients(this->GetWeightGradientsAt(1)),
     fBiasGradients(this->GetBiasGradientsAt(0))
{
}
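// Usage sketch (illustrative, not part of the original header): constructing the
// layer and running one forward pass. The backend type TReference<double> and the
// exact Tensor_t shape constructor are assumptions of this example.
//
//   using Arch_t = TMVA::DNN::TReference<double>;
//   TBasicRNNLayer<Arch_t> rnn(/*batchSize=*/4, /*stateSize=*/8,
//                              /*inputSize=*/5, /*timeSteps=*/3);
//   rnn.Initialize();
//   typename Arch_t::Tensor_t x(4, 3, 5); // B x T x D
//   rnn.Forward(x);                       // rnn.GetOutput() is B x T x H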
template <typename Architecture_t>
TBasicRNNLayer<Architecture_t>::TBasicRNNLayer(const TBasicRNNLayer &layer)
   : VGeneralLayer<Architecture_t>(layer), fTimeSteps(layer.fTimeSteps), fStateSize(layer.fStateSize),
     fRememberState(layer.fRememberState), fF(layer.GetActivationFunction()),
     fState(layer.GetBatchSize(), layer.GetStateSize()), fWeightsInput(this->GetWeightsAt(0)),
     fWeightsState(this->GetWeightsAt(1)), fBiases(this->GetBiasesAt(0)),
     fDerivatives(layer.GetDerivatives().GetShape()), fWeightInputGradients(this->GetWeightGradientsAt(0)),
     fWeightStateGradients(this->GetWeightGradientsAt(1)), fBiasGradients(this->GetBiasGradientsAt(0))
{
   // ... (copying of state and gradient contents elided in this excerpt)
}
template <typename Architecture_t>
void TBasicRNNLayer<Architecture_t>::Initialize()
{
   VGeneralLayer<Architecture_t>::Initialize();

   Architecture_t::InitializeActivationDescriptor(fActivationDesc, this->GetActivationFunction());
}
template <typename Architecture_t>
auto TBasicRNNLayer<Architecture_t>::Print() const -> void
{
   std::cout << " RECURRENT Layer: \t ";
   std::cout << " (NInput = " << this->GetInputSize();
   std::cout << ", NState = " << this->GetStateSize();
   std::cout << ", NTime = " << this->GetTimeSteps() << " )";
   std::cout << "\tOutput = ( " << this->GetOutput().GetFirstSize() << " , " << this->GetOutput().GetHSize()
             << " , " << this->GetOutput().GetWSize() << " )\n";
}
template <typename Architecture_t>
auto debugMatrix(const typename Architecture_t::Matrix_t &A, const std::string name = "matrix") -> void
{
   std::cout << name << "\n";
   for (size_t i = 0; i < A.GetNrows(); ++i) {
      for (size_t j = 0; j < A.GetNcols(); ++j) {
         std::cout << A(i, j) << " ";
      }
      std::cout << "\n";
   }
   std::cout << "********\n";
}
template <typename Architecture_t>
auto TBasicRNNLayer<Architecture_t>::Forward(Tensor_t &input, bool /*isTraining*/) -> void
{
   // D : input size, H : state size, T : time steps, B : batch size

   Tensor_t arrInput(fTimeSteps, this->GetBatchSize(), this->GetInputWidth());
   Architecture_t::Rearrange(arrInput, input); // B x T x D --> T x B x D
   Tensor_t arrOutput(fTimeSteps, this->GetBatchSize(), fStateSize);

   if (!this->fRememberState)
      InitState(DNN::EInitialization::kZero);

   for (size_t t = 0; t < fTimeSteps; ++t) {
      Matrix_t arrInput_m = arrInput.At(t).GetMatrix();
      Matrix_t df_m = fDerivatives.At(t).GetMatrix();
      CellForward(arrInput_m, df_m);
      Matrix_t arrOutput_m = arrOutput.At(t).GetMatrix();
      Architecture_t::Copy(arrOutput_m, fState);
   }

   Architecture_t::Rearrange(this->GetOutput(), arrOutput); // T x B x H --> B x T x H
}
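// Shape note (summary of the code above): the layer's public layout is
// B x T x D, while the time loop wants T x B x D so that .At(t) yields the
// B x D batch for time step t. Rearrange maps element (b, t, d) to (t, b, d);
// applied to the swapped shapes it converts back, which is why the same call
// turns the T x B x H output into B x T x H at the end of Forward.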
template <typename Architecture_t>
auto inline TBasicRNNLayer<Architecture_t>::CellForward(const Matrix_t &input, Matrix_t &dF) -> void
{
   // State = act(W_input . input + W_state . state + bias)
   const DNN::EActivationFunction fAF = this->GetActivationFunction();
   Matrix_t tmpState(fState.GetNrows(), fState.GetNcols());
   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsState);
   Architecture_t::MultiplyTranspose(fState, input, fWeightsInput);
   Architecture_t::ScaleAdd(fState, tmpState);
   Architecture_t::AddRowWise(fState, fBiases);

   // keep a copy of the pre-activation values in dF for the backward pass,
   // then apply the activation function in place
   Tensor_t inputActivFunc(dF);
   Tensor_t tState(fState);
   Architecture_t::Copy(inputActivFunc, tState);
   Architecture_t::ActivationFunctionForward(tState, fAF, fActivationDesc);
}
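// The cell update above computes, in matrix form (B = batch, D = input, H = state):
//
//   s_t = f( x_t W_x^T + s_{t-1} W_s^T + b ),   x_t : B x D,  W_x : H x D,
//                                               s   : B x H,  W_s : H x H
//
// MultiplyTranspose performs the products with the transposed weight matrices,
// AddRowWise broadcasts the bias row over the batch, and dF ends up holding the
// pre-activation values that Backward needs to evaluate f'.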
template <typename Architecture_t>
auto inline TBasicRNNLayer<Architecture_t>::Backward(Tensor_t &gradients_backward,           // B x T x D
                                                     const Tensor_t &activations_backward) -> void // B x T x D
{
   // gradients_backward are the activation gradients of the previous layer;
   // an empty tensor (e.g. for the first layer) means no input gradient is needed
   bool dummy = false;
   if (gradients_backward.GetSize() == 0) {
      dummy = true;
   }
   Tensor_t arr_gradients_backward(fTimeSteps, this->GetBatchSize(), this->GetInputSize());

   Tensor_t arr_activations_backward(fTimeSteps, this->GetBatchSize(), this->GetInputSize());
   Architecture_t::Rearrange(arr_activations_backward, activations_backward); // B x T x D --> T x B x D

   Matrix_t state_gradients_backward(this->GetBatchSize(), fStateSize); // B x H
   DNN::initialize<Architecture_t>(state_gradients_backward, DNN::EInitialization::kZero);

   Matrix_t initState(this->GetBatchSize(), fStateSize); // B x H
   DNN::initialize<Architecture_t>(initState, DNN::EInitialization::kZero);

   Tensor_t arr_output(fTimeSteps, this->GetBatchSize(), fStateSize);
   Architecture_t::Rearrange(arr_output, this->GetOutput());

   Tensor_t arr_actgradients(fTimeSteps, this->GetBatchSize(), fStateSize);
   Architecture_t::Rearrange(arr_actgradients, this->GetActivationGradients());

   // reset the weight and bias gradients, which are accumulated over time steps
   fWeightInputGradients.Zero();
   fWeightStateGradients.Zero();
   fBiasGradients.Zero();

   for (size_t t = fTimeSteps; t > 0; t--) {
      Matrix_t actgrad_m = arr_actgradients.At(t - 1).GetMatrix();
      Architecture_t::ScaleAdd(state_gradients_backward, actgrad_m);

      Matrix_t actbw_m = arr_activations_backward.At(t - 1).GetMatrix();
      Matrix_t gradbw_m = arr_gradients_backward.At(t - 1).GetMatrix();

      // propagate the state gradient through the activation function, using the
      // pre-activation values cached in fDerivatives during the forward pass
      Tensor_t df = fDerivatives.At(t - 1);
      Tensor_t dy = Tensor_t(state_gradients_backward);
      Tensor_t y = arr_output.At(t - 1);
      Architecture_t::ActivationFunctionBackward(df, y,
                                                 dy, df, // in place
                                                 this->GetActivationFunction(), fActivationDesc);
      Matrix_t df_m = df.GetMatrix();

      if (t > 1) {
         Matrix_t precStateActivations = arr_output.At(t - 2).GetMatrix();
         CellBackward(state_gradients_backward, precStateActivations, actbw_m, gradbw_m, df_m);
      } else {
         const Matrix_t &precStateActivations = initState;
         CellBackward(state_gradients_backward, precStateActivations, actbw_m, gradbw_m, df_m);
      }
   }

   if (!dummy) {
      Architecture_t::Rearrange(gradients_backward, arr_gradients_backward); // T x B x D --> B x T x D
   }
}
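// Backpropagation through time, as implemented above: walking t = T..1, the
// running state gradient first absorbs the output gradient of step t, is then
// pushed through the activation, and CellBackward both accumulates the
// weight/bias gradients and propagates the gradient to step t-1 through
// fWeightsState. The weight gradients are therefore sums over all T time
// steps, which is why they are zeroed before the loop; at t = 1 the zero
// initState stands in for the (non-existent) previous state activations.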
template <typename Architecture_t>
auto inline TBasicRNNLayer<Architecture_t>::CellBackward(Matrix_t &state_gradients_backward,
                                                         const Matrix_t &precStateActivations,
                                                         const Matrix_t &input, Matrix_t &input_gradient,
                                                         Matrix_t &dF) -> Matrix_t &
{
   return Architecture_t::RecurrentLayerBackward(state_gradients_backward, fWeightInputGradients, fWeightStateGradients,
                                                 fBiasGradients, dF, precStateActivations, fWeightsInput,
                                                 fWeightsState, input, input_gradient);
}
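// A sketch of what RecurrentLayerBackward is expected to compute, per time step
// (this follows the reference backend and is an assumption of this note; dF is
// the gradient at the pre-activation of step t):
//
//   dW_x += dF^T x_t                input weight gradients
//   dW_s += dF^T s_{t-1}            state weight gradients
//   db   += column-wise sum of dF   bias gradients
//   dx_t  = dF W_x                  input_gradient
//   ds    = dF W_s                  updated state_gradients_backward (returned)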
template <typename Architecture_t>
void TBasicRNNLayer<Architecture_t>::AddWeightsXMLTo(void *parent)
{
   auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "RNNLayer");

   // ... (writing of layer attributes elided in this excerpt)

   // write weight and bias matrices
   this->WriteMatrixToXML(layerxml, "InputWeights", this->GetWeightsAt(0));
   this->WriteMatrixToXML(layerxml, "StateWeights", this->GetWeightsAt(1));
   this->WriteMatrixToXML(layerxml, "Biases", this->GetBiasesAt(0));
}
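// The serialized node written above has roughly this shape (the "RNNLayer"
// element and the three child names come from the code; the attribute set and
// the matrix encoding produced by WriteMatrixToXML are assumptions of this
// sketch):
//
//   <RNNLayer ...>
//      <InputWeights> ... </InputWeights>
//      <StateWeights> ... </StateWeights>
//      <Biases> ... </Biases>
//   </RNNLayer>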
template <typename Architecture_t>
void TBasicRNNLayer<Architecture_t>::ReadWeightsFromXML(void *parent)
{
   this->ReadMatrixXML(parent, "InputWeights", this->GetWeightsAt(0));
   this->ReadMatrixXML(parent, "StateWeights", this->GetWeightsAt(1));
   this->ReadMatrixXML(parent, "Biases", this->GetBiasesAt(0));
}

#endif // TMVA_DNN_RNN_LAYER