LSTMLayer.h
// @(#)root/tmva/tmva/dnn/lstm:$Id$
// Author: Surya S Dwivedi 27/05/19

/**********************************************************************************
 * Project: TMVA - a ROOT-integrated toolkit for multivariate data analysis      *
 * Package: TMVA                                                                  *
 * Class  : BasicLSTMLayer                                                        *
 *                                                                                *
 * Description:                                                                   *
 *      Long Short-Term Memory (LSTM) recurrent layer for deep neural networks   *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Surya S Dwivedi <surya2191997@gmail.com> - IIT Kharagpur, India           *
 *                                                                                *
 * Copyright (c) 2005-2019:                                                       *
 *      CERN, Switzerland                                                         *
 *      All rights reserved.                                                      *
 *                                                                                *
 * For the licensing terms see $ROOTSYS/LICENSE.                                  *
 * For the list of contributors see $ROOTSYS/README/CREDITS.                      *
 **********************************************************************************/

//#pragma once

//////////////////////////////////////////////////////////////////////
// This class implements the LSTM layer. LSTM is a variant of the
// vanilla RNN which is capable of learning long-range dependencies.
//////////////////////////////////////////////////////////////////////
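//
// For reference, the standard LSTM update equations implemented by the
// gate methods and CellForward() below: at time step t, with input x_t,
// hidden state h_{t-1} and cell state C_{t-1},
//
//    i_t = sigmoid(W_i . x_t + U_i . h_{t-1} + b_i)   (input gate)
//    f_t = sigmoid(W_f . x_t + U_f . h_{t-1} + b_f)   (forget gate)
//    g_t = tanh(W_g . x_t + U_g . h_{t-1} + b_g)      (candidate value)
//    o_t = sigmoid(W_o . x_t + U_o . h_{t-1} + b_o)   (output gate)
//    C_t = f_t o C_{t-1} + i_t o g_t                  (cell state, o = Hadamard product)
//    h_t = o_t o tanh(C_t)                            (hidden state)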

#ifndef TMVA_DNN_LSTM_LAYER
#define TMVA_DNN_LSTM_LAYER

#include <cmath>
#include <iostream>
#include <vector>

#include "TMatrix.h"
#include "TMVA/DNN/Functions.h"

namespace TMVA
{
namespace DNN
{
namespace RNN
{

//______________________________________________________________________________
//
// Basic LSTM Layer
//______________________________________________________________________________

/** \class BasicLSTMLayer
    Generic implementation
*/
template<typename Architecture_t>
class TBasicLSTMLayer : public VGeneralLayer<Architecture_t>
{

public:

   using Matrix_t = typename Architecture_t::Matrix_t;
   using Scalar_t = typename Architecture_t::Scalar_t;
   using Tensor_t = typename Architecture_t::Tensor_t;

   using LayerDescriptor_t   = typename Architecture_t::RecurrentDescriptor_t;
   using WeightsDescriptor_t = typename Architecture_t::FilterDescriptor_t;
   using TensorDescriptor_t  = typename Architecture_t::TensorDescriptor_t;
   using HelperDescriptor_t  = typename Architecture_t::DropoutDescriptor_t;

   using RNNWorkspace_t   = typename Architecture_t::RNNWorkspace_t;
   using RNNDescriptors_t = typename Architecture_t::RNNDescriptors_t;

private:

   size_t fStateSize;            ///< Hidden state size for LSTM
   size_t fCellSize;             ///< Cell state size of LSTM
   size_t fTimeSteps;            ///< Timesteps for LSTM

   bool fRememberState;          ///< Remember state in next pass
   bool fReturnSequence = false; ///< Return in output full sequence or just last element

   DNN::EActivationFunction fF1; ///< Activation function: sigmoid
   DNN::EActivationFunction fF2; ///< Activation function: tanh

   Matrix_t fInputValue;         ///< Computed input gate values
   Matrix_t fCandidateValue;     ///< Computed candidate values
   Matrix_t fForgetValue;        ///< Computed forget gate values
   Matrix_t fOutputValue;        ///< Computed output gate values
   Matrix_t fState;              ///< Hidden state of LSTM
   Matrix_t fCell;               ///< Cell state of LSTM

   Matrix_t &fWeightsInputGate;       ///< Input Gate weights for input, fWeights[0]
   Matrix_t &fWeightsInputGateState;  ///< Input Gate weights for prev state, fWeights[4]
   Matrix_t &fInputGateBias;          ///< Input Gate bias

   Matrix_t &fWeightsForgetGate;      ///< Forget Gate weights for input, fWeights[1]
   Matrix_t &fWeightsForgetGateState; ///< Forget Gate weights for prev state, fWeights[5]
   Matrix_t &fForgetGateBias;         ///< Forget Gate bias

   Matrix_t &fWeightsCandidate;       ///< Candidate Gate weights for input, fWeights[2]
   Matrix_t &fWeightsCandidateState;  ///< Candidate Gate weights for prev state, fWeights[6]
   Matrix_t &fCandidateBias;          ///< Candidate Gate bias

   Matrix_t &fWeightsOutputGate;      ///< Output Gate weights for input, fWeights[3]
   Matrix_t &fWeightsOutputGateState; ///< Output Gate weights for prev state, fWeights[7]
   Matrix_t &fOutputGateBias;         ///< Output Gate bias

   std::vector<Matrix_t> input_gate_value;      ///< input gate value for every time step
   std::vector<Matrix_t> forget_gate_value;     ///< forget gate value for every time step
   std::vector<Matrix_t> candidate_gate_value;  ///< candidate gate value for every time step
   std::vector<Matrix_t> output_gate_value;     ///< output gate value for every time step
   std::vector<Matrix_t> cell_value;            ///< cell value for every time step
   std::vector<Matrix_t> fDerivativesInput;     ///< First derivatives of the input gate activations
   std::vector<Matrix_t> fDerivativesForget;    ///< First derivatives of the forget gate activations
   std::vector<Matrix_t> fDerivativesCandidate; ///< First derivatives of the candidate gate activations
   std::vector<Matrix_t> fDerivativesOutput;    ///< First derivatives of the output gate activations

   Matrix_t &fWeightsInputGradients;          ///< Gradients w.r.t. the input gate - input weights
   Matrix_t &fWeightsInputStateGradients;     ///< Gradients w.r.t. the input gate - hidden state weights
   Matrix_t &fInputBiasGradients;             ///< Gradients w.r.t. the input gate - bias weights
   Matrix_t &fWeightsForgetGradients;         ///< Gradients w.r.t. the forget gate - input weights
   Matrix_t &fWeightsForgetStateGradients;    ///< Gradients w.r.t. the forget gate - hidden state weights
   Matrix_t &fForgetBiasGradients;            ///< Gradients w.r.t. the forget gate - bias weights
   Matrix_t &fWeightsCandidateGradients;      ///< Gradients w.r.t. the candidate gate - input weights
   Matrix_t &fWeightsCandidateStateGradients; ///< Gradients w.r.t. the candidate gate - hidden state weights
   Matrix_t &fCandidateBiasGradients;         ///< Gradients w.r.t. the candidate gate - bias weights
   Matrix_t &fWeightsOutputGradients;         ///< Gradients w.r.t. the output gate - input weights
   Matrix_t &fWeightsOutputStateGradients;    ///< Gradients w.r.t. the output gate - hidden state weights
   Matrix_t &fOutputBiasGradients;            ///< Gradients w.r.t. the output gate - bias weights

   // Tensors representing all weights (used by cuDNN)
   Tensor_t fWeightsTensor;         ///< Tensor for all weights
   Tensor_t fWeightGradientsTensor; ///< Tensor for all weight gradients

   // tensors used internally for the forward and backward pass
   Tensor_t fX;  ///< cached input tensor as T x B x I
   Tensor_t fY;  ///< cached output tensor as T x B x S
   Tensor_t fDx; ///< cached gradient on the input (output of backward) as T x B x I
   Tensor_t fDy; ///< cached activation gradient (input of backward) as T x B x S

   TDescriptors *fDescriptors = nullptr; ///< Keeps all the RNN descriptors
   TWorkspace *fWorkspace = nullptr;     ///< Workspace needed for GPU computation (cuDNN)

public:

   /*! Constructor */
   TBasicLSTMLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps, bool rememberState = false,
                   bool returnSequence = false,
                   DNN::EActivationFunction f1 = DNN::EActivationFunction::kSigmoid,
                   DNN::EActivationFunction f2 = DNN::EActivationFunction::kTanh, bool training = true,
                   DNN::EInitialization fA = DNN::EInitialization::kZero);

   /*! Copy Constructor */
   TBasicLSTMLayer(const TBasicLSTMLayer &);

   /*! Initialize the weights according to the given initialization
    ** method. */
   virtual void Initialize();

   /*! Initialize the hidden state and cell state method. */
   void InitState(DNN::EInitialization m = DNN::EInitialization::kZero);

   /*! Computes the next hidden state
    * and next cell state with given input matrix. */
   void Forward(Tensor_t &input, bool isTraining = true);

   /*! Forward for a single cell (time unit) */
   void CellForward(Matrix_t &inputGateValues, const Matrix_t &forgetGateValues,
                    const Matrix_t &candidateValues, const Matrix_t &outputGateValues);

   /*! Backpropagates the error. Must only be called directly after the corresponding
    * call to Forward(...). */
   void Backward(Tensor_t &gradients_backward,
                 const Tensor_t &activations_backward);

   /* Updates weights and biases, given the learning rate */
   void Update(const Scalar_t learningRate);

   /*! Backward for a single time unit,
    * matching the corresponding call to Forward(...). */
   Matrix_t & CellBackward(Matrix_t & state_gradients_backward,
                           Matrix_t & cell_gradients_backward,
                           const Matrix_t & precStateActivations, const Matrix_t & precCellActivations,
                           const Matrix_t & input_gate, const Matrix_t & forget_gate,
                           const Matrix_t & candidate_gate, const Matrix_t & output_gate,
                           const Matrix_t & input, Matrix_t & input_gradient,
                           Matrix_t &di, Matrix_t &df, Matrix_t &dc, Matrix_t &dout, size_t t);

   /*! Decides the values we'll update (NN with Sigmoid) */
   void InputGate(const Matrix_t &input, Matrix_t &di);

   /*! Forgets the past values (NN with Sigmoid) */
   void ForgetGate(const Matrix_t &input, Matrix_t &df);

   /*! Decides the new candidate values (NN with Tanh) */
   void CandidateValue(const Matrix_t &input, Matrix_t &dc);

   /*! Computes output values (NN with Sigmoid) */
   void OutputGate(const Matrix_t &input, Matrix_t &dout);

   /*! Prints the info about the layer */
   void Print() const;

   /*! Writes the information and the weights about the layer in an XML node. */
   void AddWeightsXMLTo(void *parent);

   /*! Read the information and the weights about the layer from XML node. */
   void ReadWeightsFromXML(void *parent);

   /*! Getters */
   size_t GetInputSize() const { return this->GetInputWidth(); }
   size_t GetTimeSteps() const { return fTimeSteps; }
   size_t GetStateSize() const { return fStateSize; }
   size_t GetCellSize() const { return fCellSize; }

   inline bool DoesRememberState() const { return fRememberState; }
   inline bool DoesReturnSequence() const { return fReturnSequence; }

   inline DNN::EActivationFunction GetActivationFunctionF1() const { return fF1; }
   inline DNN::EActivationFunction GetActivationFunctionF2() const { return fF2; }

   const Matrix_t & GetInputGateValue() const { return fInputValue; }
   Matrix_t & GetInputGateValue() { return fInputValue; }
   const Matrix_t & GetCandidateValue() const { return fCandidateValue; }
   Matrix_t & GetCandidateValue() { return fCandidateValue; }
   const Matrix_t & GetForgetGateValue() const { return fForgetValue; }
   Matrix_t & GetForgetGateValue() { return fForgetValue; }
   const Matrix_t & GetOutputGateValue() const { return fOutputValue; }
   Matrix_t & GetOutputGateValue() { return fOutputValue; }

   const Matrix_t & GetState() const { return fState; }
   Matrix_t & GetState() { return fState; }
   const Matrix_t & GetCell() const { return fCell; }
   Matrix_t & GetCell() { return fCell; }

   const Matrix_t & GetWeightsInputGate() const { return fWeightsInputGate; }
   Matrix_t & GetWeightsInputGate() { return fWeightsInputGate; }
   const Matrix_t & GetWeightsCandidate() const { return fWeightsCandidate; }
   Matrix_t & GetWeightsCandidate() { return fWeightsCandidate; }
   const Matrix_t & GetWeightsForgetGate() const { return fWeightsForgetGate; }
   Matrix_t & GetWeightsForgetGate() { return fWeightsForgetGate; }
   const Matrix_t & GetWeightsOutputGate() const { return fWeightsOutputGate; }
   Matrix_t & GetWeightsOutputGate() { return fWeightsOutputGate; }
   const Matrix_t & GetWeightsInputGateState() const { return fWeightsInputGateState; }
   Matrix_t & GetWeightsInputGateState() { return fWeightsInputGateState; }
   const Matrix_t & GetWeightsForgetGateState() const { return fWeightsForgetGateState; }
   Matrix_t & GetWeightsForgetGateState() { return fWeightsForgetGateState; }
   const Matrix_t & GetWeightsCandidateState() const { return fWeightsCandidateState; }
   Matrix_t & GetWeightsCandidateState() { return fWeightsCandidateState; }
   const Matrix_t & GetWeightsOutputGateState() const { return fWeightsOutputGateState; }
   Matrix_t & GetWeightsOutputGateState() { return fWeightsOutputGateState; }

   const std::vector<Matrix_t> & GetDerivativesInput() const { return fDerivativesInput; }
   std::vector<Matrix_t> & GetDerivativesInput() { return fDerivativesInput; }
   const Matrix_t & GetInputDerivativesAt(size_t i) const { return fDerivativesInput[i]; }
   Matrix_t & GetInputDerivativesAt(size_t i) { return fDerivativesInput[i]; }
   const std::vector<Matrix_t> & GetDerivativesForget() const { return fDerivativesForget; }
   std::vector<Matrix_t> & GetDerivativesForget() { return fDerivativesForget; }
   const Matrix_t & GetForgetDerivativesAt(size_t i) const { return fDerivativesForget[i]; }
   Matrix_t & GetForgetDerivativesAt(size_t i) { return fDerivativesForget[i]; }
   const std::vector<Matrix_t> & GetDerivativesCandidate() const { return fDerivativesCandidate; }
   std::vector<Matrix_t> & GetDerivativesCandidate() { return fDerivativesCandidate; }
   const Matrix_t & GetCandidateDerivativesAt(size_t i) const { return fDerivativesCandidate[i]; }
   Matrix_t & GetCandidateDerivativesAt(size_t i) { return fDerivativesCandidate[i]; }
   const std::vector<Matrix_t> & GetDerivativesOutput() const { return fDerivativesOutput; }
   std::vector<Matrix_t> & GetDerivativesOutput() { return fDerivativesOutput; }
   const Matrix_t & GetOutputDerivativesAt(size_t i) const { return fDerivativesOutput[i]; }
   Matrix_t & GetOutputDerivativesAt(size_t i) { return fDerivativesOutput[i]; }

   const std::vector<Matrix_t> & GetInputGateTensor() const { return input_gate_value; }
   std::vector<Matrix_t> & GetInputGateTensor() { return input_gate_value; }
   const Matrix_t & GetInputGateTensorAt(size_t i) const { return input_gate_value[i]; }
   Matrix_t & GetInputGateTensorAt(size_t i) { return input_gate_value[i]; }
   const std::vector<Matrix_t> & GetForgetGateTensor() const { return forget_gate_value; }
   std::vector<Matrix_t> & GetForgetGateTensor() { return forget_gate_value; }
   const Matrix_t & GetForgetGateTensorAt(size_t i) const { return forget_gate_value[i]; }
   Matrix_t & GetForgetGateTensorAt(size_t i) { return forget_gate_value[i]; }
   const std::vector<Matrix_t> & GetCandidateGateTensor() const { return candidate_gate_value; }
   std::vector<Matrix_t> & GetCandidateGateTensor() { return candidate_gate_value; }
   const Matrix_t & GetCandidateGateTensorAt(size_t i) const { return candidate_gate_value[i]; }
   Matrix_t & GetCandidateGateTensorAt(size_t i) { return candidate_gate_value[i]; }
   const std::vector<Matrix_t> & GetOutputGateTensor() const { return output_gate_value; }
   std::vector<Matrix_t> & GetOutputGateTensor() { return output_gate_value; }
   const Matrix_t & GetOutputGateTensorAt(size_t i) const { return output_gate_value[i]; }
   Matrix_t & GetOutputGateTensorAt(size_t i) { return output_gate_value[i]; }
   const std::vector<Matrix_t> & GetCellTensor() const { return cell_value; }
   std::vector<Matrix_t> & GetCellTensor() { return cell_value; }
   const Matrix_t & GetCellTensorAt(size_t i) const { return cell_value[i]; }
   Matrix_t & GetCellTensorAt(size_t i) { return cell_value[i]; }

   const Matrix_t & GetInputGateBias() const { return fInputGateBias; }
   Matrix_t & GetInputGateBias() { return fInputGateBias; }
   const Matrix_t & GetForgetGateBias() const { return fForgetGateBias; }
   Matrix_t & GetForgetGateBias() { return fForgetGateBias; }
   const Matrix_t & GetCandidateBias() const { return fCandidateBias; }
   Matrix_t & GetCandidateBias() { return fCandidateBias; }
   const Matrix_t & GetOutputGateBias() const { return fOutputGateBias; }
   Matrix_t & GetOutputGateBias() { return fOutputGateBias; }

   const Matrix_t & GetWeightsInputGradients() const { return fWeightsInputGradients; }
   Matrix_t & GetWeightsInputGradients() { return fWeightsInputGradients; }
   const Matrix_t & GetWeightsInputStateGradients() const { return fWeightsInputStateGradients; }
   Matrix_t & GetWeightsInputStateGradients() { return fWeightsInputStateGradients; }
   const Matrix_t & GetInputBiasGradients() const { return fInputBiasGradients; }
   Matrix_t & GetInputBiasGradients() { return fInputBiasGradients; }
   const Matrix_t & GetWeightsForgetGradients() const { return fWeightsForgetGradients; }
   Matrix_t & GetWeightsForgetGradients() { return fWeightsForgetGradients; }
   const Matrix_t & GetWeightsForgetStateGradients() const { return fWeightsForgetStateGradients; }
   Matrix_t & GetWeightsForgetStateGradients() { return fWeightsForgetStateGradients; }
   const Matrix_t & GetForgetBiasGradients() const { return fForgetBiasGradients; }
   Matrix_t & GetForgetBiasGradients() { return fForgetBiasGradients; }
   const Matrix_t & GetWeightsCandidateGradients() const { return fWeightsCandidateGradients; }
   Matrix_t & GetWeightsCandidateGradients() { return fWeightsCandidateGradients; }
   const Matrix_t & GetWeightsCandidateStateGradients() const { return fWeightsCandidateStateGradients; }
   Matrix_t & GetWeightsCandidateStateGradients() { return fWeightsCandidateStateGradients; }
   const Matrix_t & GetCandidateBiasGradients() const { return fCandidateBiasGradients; }
   Matrix_t & GetCandidateBiasGradients() { return fCandidateBiasGradients; }
   const Matrix_t & GetWeightsOutputGradients() const { return fWeightsOutputGradients; }
   Matrix_t & GetWeightsOutputGradients() { return fWeightsOutputGradients; }
   const Matrix_t & GetWeightsOutputStateGradients() const { return fWeightsOutputStateGradients; }
   Matrix_t & GetWeightsOutputStateGradients() { return fWeightsOutputStateGradients; }
   const Matrix_t & GetOutputBiasGradients() const { return fOutputBiasGradients; }
   Matrix_t & GetOutputBiasGradients() { return fOutputBiasGradients; }

   Tensor_t &GetWeightsTensor() { return fWeightsTensor; }
   const Tensor_t &GetWeightsTensor() const { return fWeightsTensor; }
   Tensor_t &GetWeightGradientsTensor() { return fWeightGradientsTensor; }
   const Tensor_t &GetWeightGradientsTensor() const { return fWeightGradientsTensor; }

   Tensor_t &GetX() { return fX; }
   Tensor_t &GetY() { return fY; }
   Tensor_t &GetDX() { return fDx; }
   Tensor_t &GetDY() { return fDy; }
};
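
// A minimal usage sketch (illustrative only; the sizes are hypothetical, and
// in practice recurrent layers are usually created through TDeepNet rather
// than instantiated directly):
//
//    using Arch_t = TMVA::DNN::TCpu<Double_t>;
//    TMVA::DNN::RNN::TBasicLSTMLayer<Arch_t> lstm(/*batchSize=*/4, /*stateSize=*/8,
//                                                 /*inputSize=*/3, /*timeSteps=*/5);
//    lstm.Initialize();
//    Arch_t::Tensor_t x(4, 5, 3);        // input, B x T x D
//    lstm.Forward(x, /*isTraining=*/false);
//    const auto &out = lstm.GetOutput(); // last step only, or the full sequence
//                                        // if returnSequence is set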

//______________________________________________________________________________
//
// Basic LSTM-Layer Implementation
//______________________________________________________________________________

template <typename Architecture_t>
TBasicLSTMLayer<Architecture_t>::TBasicLSTMLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps,
                                                 bool rememberState, bool returnSequence, DNN::EActivationFunction f1,
                                                 DNN::EActivationFunction f2, bool /* training */,
                                                 DNN::EInitialization fA)
   : VGeneralLayer<Architecture_t>(
        batchSize, 1, timeSteps, inputSize, 1, (returnSequence) ? timeSteps : 1, stateSize, 8,
        {stateSize, stateSize, stateSize, stateSize, stateSize, stateSize, stateSize, stateSize},
        {inputSize, inputSize, inputSize, inputSize, stateSize, stateSize, stateSize, stateSize}, 4,
        {stateSize, stateSize, stateSize, stateSize}, {1, 1, 1, 1}, batchSize, (returnSequence) ? timeSteps : 1,
        stateSize, fA),
     fStateSize(stateSize), fCellSize(stateSize), fTimeSteps(timeSteps), fRememberState(rememberState),
     fReturnSequence(returnSequence), fF1(f1), fF2(f2), fInputValue(batchSize, stateSize),
     fCandidateValue(batchSize, stateSize), fForgetValue(batchSize, stateSize), fOutputValue(batchSize, stateSize),
     fState(batchSize, stateSize), fCell(batchSize, stateSize), fWeightsInputGate(this->GetWeightsAt(0)),
     fWeightsInputGateState(this->GetWeightsAt(4)), fInputGateBias(this->GetBiasesAt(0)),
     fWeightsForgetGate(this->GetWeightsAt(1)), fWeightsForgetGateState(this->GetWeightsAt(5)),
     fForgetGateBias(this->GetBiasesAt(1)), fWeightsCandidate(this->GetWeightsAt(2)),
     fWeightsCandidateState(this->GetWeightsAt(6)), fCandidateBias(this->GetBiasesAt(2)),
     fWeightsOutputGate(this->GetWeightsAt(3)), fWeightsOutputGateState(this->GetWeightsAt(7)),
     fOutputGateBias(this->GetBiasesAt(3)), fWeightsInputGradients(this->GetWeightGradientsAt(0)),
     fWeightsInputStateGradients(this->GetWeightGradientsAt(4)), fInputBiasGradients(this->GetBiasGradientsAt(0)),
     fWeightsForgetGradients(this->GetWeightGradientsAt(1)),
     fWeightsForgetStateGradients(this->GetWeightGradientsAt(5)), fForgetBiasGradients(this->GetBiasGradientsAt(1)),
     fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
     fWeightsCandidateStateGradients(this->GetWeightGradientsAt(6)),
     fCandidateBiasGradients(this->GetBiasGradientsAt(2)), fWeightsOutputGradients(this->GetWeightGradientsAt(3)),
     fWeightsOutputStateGradients(this->GetWeightGradientsAt(7)), fOutputBiasGradients(this->GetBiasGradientsAt(3))
{
   for (size_t i = 0; i < timeSteps; ++i) {
      fDerivativesInput.emplace_back(batchSize, stateSize);
      fDerivativesForget.emplace_back(batchSize, stateSize);
      fDerivativesCandidate.emplace_back(batchSize, stateSize);
      fDerivativesOutput.emplace_back(batchSize, stateSize);
      input_gate_value.emplace_back(batchSize, stateSize);
      forget_gate_value.emplace_back(batchSize, stateSize);
      candidate_gate_value.emplace_back(batchSize, stateSize);
      output_gate_value.emplace_back(batchSize, stateSize);
      cell_value.emplace_back(batchSize, stateSize);
   }
   Architecture_t::InitializeLSTMTensors(this);
}
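
// Note on the weight layout set up above: fWeights[0..3] hold the input-side
// weight matrices of the input, forget, candidate and output gates (each
// stateSize x inputSize), fWeights[4..7] the corresponding recurrent
// state-side matrices (each stateSize x stateSize), and fBiases[0..3] the
// four gate bias vectors.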

//______________________________________________________________________________
template <typename Architecture_t>
TBasicLSTMLayer<Architecture_t>::TBasicLSTMLayer(const TBasicLSTMLayer &layer)
   : VGeneralLayer<Architecture_t>(layer),
     fStateSize(layer.fStateSize),
     fCellSize(layer.fCellSize),
     fTimeSteps(layer.fTimeSteps),
     fRememberState(layer.fRememberState),
     fReturnSequence(layer.fReturnSequence),
     fF1(layer.GetActivationFunctionF1()),
     fF2(layer.GetActivationFunctionF2()),
     fInputValue(layer.GetBatchSize(), layer.GetStateSize()),
     fCandidateValue(layer.GetBatchSize(), layer.GetStateSize()),
     fForgetValue(layer.GetBatchSize(), layer.GetStateSize()),
     fOutputValue(layer.GetBatchSize(), layer.GetStateSize()),
     fState(layer.GetBatchSize(), layer.GetStateSize()),
     fCell(layer.GetBatchSize(), layer.GetCellSize()),
     fWeightsInputGate(this->GetWeightsAt(0)),
     fWeightsInputGateState(this->GetWeightsAt(4)),
     fInputGateBias(this->GetBiasesAt(0)),
     fWeightsForgetGate(this->GetWeightsAt(1)),
     fWeightsForgetGateState(this->GetWeightsAt(5)),
     fForgetGateBias(this->GetBiasesAt(1)),
     fWeightsCandidate(this->GetWeightsAt(2)),
     fWeightsCandidateState(this->GetWeightsAt(6)),
     fCandidateBias(this->GetBiasesAt(2)),
     fWeightsOutputGate(this->GetWeightsAt(3)),
     fWeightsOutputGateState(this->GetWeightsAt(7)),
     fOutputGateBias(this->GetBiasesAt(3)),
     fWeightsInputGradients(this->GetWeightGradientsAt(0)),
     fWeightsInputStateGradients(this->GetWeightGradientsAt(4)),
     fInputBiasGradients(this->GetBiasGradientsAt(0)),
     fWeightsForgetGradients(this->GetWeightGradientsAt(1)),
     fWeightsForgetStateGradients(this->GetWeightGradientsAt(5)),
     fForgetBiasGradients(this->GetBiasGradientsAt(1)),
     fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
     fWeightsCandidateStateGradients(this->GetWeightGradientsAt(6)),
     fCandidateBiasGradients(this->GetBiasGradientsAt(2)),
     fWeightsOutputGradients(this->GetWeightGradientsAt(3)),
     fWeightsOutputStateGradients(this->GetWeightGradientsAt(7)),
     fOutputBiasGradients(this->GetBiasGradientsAt(3))
{
   for (size_t i = 0; i < fTimeSteps; ++i) {
      fDerivativesInput.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
      Architecture_t::Copy(fDerivativesInput[i], layer.GetInputDerivativesAt(i));

      fDerivativesForget.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
      Architecture_t::Copy(fDerivativesForget[i], layer.GetForgetDerivativesAt(i));

      fDerivativesCandidate.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
      Architecture_t::Copy(fDerivativesCandidate[i], layer.GetCandidateDerivativesAt(i));

      fDerivativesOutput.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
      Architecture_t::Copy(fDerivativesOutput[i], layer.GetOutputDerivativesAt(i));

      input_gate_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
      Architecture_t::Copy(input_gate_value[i], layer.GetInputGateTensorAt(i));

      forget_gate_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
      Architecture_t::Copy(forget_gate_value[i], layer.GetForgetGateTensorAt(i));

      candidate_gate_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
      Architecture_t::Copy(candidate_gate_value[i], layer.GetCandidateGateTensorAt(i));

      output_gate_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
      Architecture_t::Copy(output_gate_value[i], layer.GetOutputGateTensorAt(i));

      cell_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
      Architecture_t::Copy(cell_value[i], layer.GetCellTensorAt(i));
   }

   // Gradient matrices are not copied
   Architecture_t::Copy(fState, layer.GetState());
   Architecture_t::Copy(fCell, layer.GetCell());

   // Copy each gate's values.
   Architecture_t::Copy(fInputValue, layer.GetInputGateValue());
   Architecture_t::Copy(fCandidateValue, layer.GetCandidateValue());
   Architecture_t::Copy(fForgetValue, layer.GetForgetGateValue());
   Architecture_t::Copy(fOutputValue, layer.GetOutputGateValue());

   Architecture_t::InitializeLSTMTensors(this);
}

//______________________________________________________________________________
template <typename Architecture_t>
void TBasicLSTMLayer<Architecture_t>::Initialize()
{
   VGeneralLayer<Architecture_t>::Initialize();

   Architecture_t::InitializeLSTMDescriptors(fDescriptors, this);
   Architecture_t::InitializeLSTMWorkspace(fWorkspace, fDescriptors, this);
}
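
// The descriptors and workspace initialized here are consumed by the GPU
// (cuDNN) code paths in Forward() and Backward() below.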

//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicLSTMLayer<Architecture_t>::InputGate(const Matrix_t &input, Matrix_t &di)
-> void
{
   /*! Computes input gate values according to equation:
    *  input = act(W_input . input + W_state . state + bias)
    *  activation function: sigmoid. */
   const DNN::EActivationFunction fInp = this->GetActivationFunctionF1();
   Matrix_t tmpState(fInputValue.GetNrows(), fInputValue.GetNcols());
   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsInputGateState);
   Architecture_t::MultiplyTranspose(fInputValue, input, fWeightsInputGate);
   Architecture_t::ScaleAdd(fInputValue, tmpState);
   Architecture_t::AddRowWise(fInputValue, fInputGateBias);
   DNN::evaluateDerivativeMatrix<Architecture_t>(di, fInp, fInputValue);
   DNN::evaluateMatrix<Architecture_t>(fInputValue, fInp);
}

//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicLSTMLayer<Architecture_t>::ForgetGate(const Matrix_t &input, Matrix_t &df)
-> void
{
   /*! Computes forget gate values according to equation:
    *  forget = act(W_input . input + W_state . state + bias)
    *  activation function: sigmoid. */
   const DNN::EActivationFunction fFor = this->GetActivationFunctionF1();
   Matrix_t tmpState(fForgetValue.GetNrows(), fForgetValue.GetNcols());
   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsForgetGateState);
   Architecture_t::MultiplyTranspose(fForgetValue, input, fWeightsForgetGate);
   Architecture_t::ScaleAdd(fForgetValue, tmpState);
   Architecture_t::AddRowWise(fForgetValue, fForgetGateBias);
   DNN::evaluateDerivativeMatrix<Architecture_t>(df, fFor, fForgetValue);
   DNN::evaluateMatrix<Architecture_t>(fForgetValue, fFor);
}

//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicLSTMLayer<Architecture_t>::CandidateValue(const Matrix_t &input, Matrix_t &dc)
-> void
{
   /*! The candidate value will be used to scale the input gate values, followed by a Hadamard product.
    *  candidate_value = act(W_input . input + W_state . state + bias)
    *  activation function = tanh. */
   const DNN::EActivationFunction fCan = this->GetActivationFunctionF2();
   Matrix_t tmpState(fCandidateValue.GetNrows(), fCandidateValue.GetNcols());
   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsCandidateState);
   Architecture_t::MultiplyTranspose(fCandidateValue, input, fWeightsCandidate);
   Architecture_t::ScaleAdd(fCandidateValue, tmpState);
   Architecture_t::AddRowWise(fCandidateValue, fCandidateBias);
   DNN::evaluateDerivativeMatrix<Architecture_t>(dc, fCan, fCandidateValue);
   DNN::evaluateMatrix<Architecture_t>(fCandidateValue, fCan);
}

//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicLSTMLayer<Architecture_t>::OutputGate(const Matrix_t &input, Matrix_t &dout)
-> void
{
   /*! Output gate values will be used to calculate the next hidden state and output values.
    *  output = act(W_input . input + W_state . state + bias)
    *  activation function = sigmoid. */
   const DNN::EActivationFunction fOut = this->GetActivationFunctionF1();
   Matrix_t tmpState(fOutputValue.GetNrows(), fOutputValue.GetNcols());
   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsOutputGateState);
   Architecture_t::MultiplyTranspose(fOutputValue, input, fWeightsOutputGate);
   Architecture_t::ScaleAdd(fOutputValue, tmpState);
   Architecture_t::AddRowWise(fOutputValue, fOutputGateBias);
   DNN::evaluateDerivativeMatrix<Architecture_t>(dout, fOut, fOutputValue);
   DNN::evaluateMatrix<Architecture_t>(fOutputValue, fOut);
}

//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicLSTMLayer<Architecture_t>::Forward(Tensor_t &input, bool isTraining)
-> void
{

   // for cuDNN
   if (Architecture_t::IsCudnn()) {

      // input size is stride[1] of input tensor that is B x T x inputSize
      assert(input.GetStrides()[1] == this->GetInputSize());

      Tensor_t &x = this->fX;
      Tensor_t &y = this->fY;
      Architecture_t::Rearrange(x, input);

      const auto &weights = this->GetWeightsTensor();

      // hx is fState and cx is fCell - the tensors have the right shape;
      // use the same tensors for hy and cy
      auto &hx = this->fState;
      auto &cx = this->fCell;
      auto &hy = this->fState;
      auto &cy = this->fCell;

      auto & rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
      auto & rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);

      Architecture_t::RNNForward(x, hx, cx, weights, y, hy, cy, rnnDesc, rnnWork, isTraining);

      if (fReturnSequence) {
         Architecture_t::Rearrange(this->GetOutput(), y); // swap B and T from y to Output
      } else {
         // tmp is a reference into y (the full cuDNN output)
         Tensor_t tmp = (y.At(y.GetShape()[0] - 1)).Reshape({y.GetShape()[1], 1, y.GetShape()[2]});
         Architecture_t::Copy(this->GetOutput(), tmp);
      }

      return;
   }

   // Standard CPU implementation

   // D : input size
   // H : state size
   // T : time size
   // B : batch size

   Tensor_t arrInput(fTimeSteps, this->GetBatchSize(), this->GetInputWidth());

   Architecture_t::Rearrange(arrInput, input); // B x T x D

   Tensor_t arrOutput(fTimeSteps, this->GetBatchSize(), fStateSize);

   if (!this->fRememberState) {
      InitState(DNN::EInitialization::kZero);
   }

   /*! Pass the gate values to CellForward() to calculate the
    *  next hidden state and next cell state. */
   for (size_t t = 0; t < fTimeSteps; ++t) {
      /* Feed forward network: value of each gate being computed at each timestep t. */
      Matrix_t arrInputMt = arrInput[t];
      InputGate(arrInputMt, fDerivativesInput[t]);
      ForgetGate(arrInputMt, fDerivativesForget[t]);
      CandidateValue(arrInputMt, fDerivativesCandidate[t]);
      OutputGate(arrInputMt, fDerivativesOutput[t]);

      Architecture_t::Copy(this->GetInputGateTensorAt(t), fInputValue);
      Architecture_t::Copy(this->GetForgetGateTensorAt(t), fForgetValue);
      Architecture_t::Copy(this->GetCandidateGateTensorAt(t), fCandidateValue);
      Architecture_t::Copy(this->GetOutputGateTensorAt(t), fOutputValue);

      CellForward(fInputValue, fForgetValue, fCandidateValue, fOutputValue);
      Matrix_t arrOutputMt = arrOutput[t];
      Architecture_t::Copy(arrOutputMt, fState);
      Architecture_t::Copy(this->GetCellTensorAt(t), fCell);
   }

   // check if the full output sequence needs to be returned
   if (fReturnSequence)
      Architecture_t::Rearrange(this->GetOutput(), arrOutput); // B x T x D
   else {
      // take the last time step
      Tensor_t tmp = arrOutput.At(fTimeSteps - 1);
      // the shape of tmp for CPU (column-wise) is B x D; reshape it to B x D x 1
      // and transpose it to 1 x D x B (this is how the output is expected in column-major format)
      tmp = tmp.Reshape({tmp.GetShape()[0], tmp.GetShape()[1], 1});
      assert(tmp.GetSize() == this->GetOutput().GetSize());
      assert(tmp.GetShape()[0] == this->GetOutput().GetShape()[2]); // B is last dim in output and first in tmp
      Architecture_t::Rearrange(this->GetOutput(), tmp);
      // keep the full per-time-step output array
      fY = arrOutput;
   }
}

//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicLSTMLayer<Architecture_t>::CellForward(Matrix_t &inputGateValues, const Matrix_t &forgetGateValues,
                                                         const Matrix_t &candidateValues, const Matrix_t &outputGateValues)
-> void
{

   // Update cell state.
   Architecture_t::Hadamard(fCell, forgetGateValues);
   Architecture_t::Hadamard(inputGateValues, candidateValues);
   Architecture_t::ScaleAdd(fCell, inputGateValues);

   Matrix_t cache(fCell.GetNrows(), fCell.GetNcols());
   Architecture_t::Copy(cache, fCell);

   // Update hidden state.
   const DNN::EActivationFunction fAT = this->GetActivationFunctionF2();
   DNN::evaluateMatrix<Architecture_t>(cache, fAT);

   /*! The Hadamard product of output_gate_value . tanh(cell_state)
    *  is copied to the next hidden state (passed to the next LSTM cell). */
   Architecture_t::Copy(fState, cache);
   Architecture_t::Hadamard(fState, outputGateValues);
}
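
// In the notation of the header comment, CellForward() computes
//    C_t = f_t o C_{t-1} + i_t o g_t   and   h_t = o_t o tanh(C_t),
// reusing inputGateValues as scratch storage for the product i_t o g_t.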

//____________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicLSTMLayer<Architecture_t>::Backward(Tensor_t &gradients_backward,         // B x T x D
                                                      const Tensor_t &activations_backward) // B x T x D
-> void
{

   // BACKWARD for cuDNN
   if (Architecture_t::IsCudnn()) {

      Tensor_t &x = this->fX;
      Tensor_t &y = this->fY;
      Tensor_t &dx = this->fDx;
      Tensor_t &dy = this->fDy;

      // input size is stride[1] of input tensor that is B x T x inputSize
      assert(activations_backward.GetStrides()[1] == this->GetInputSize());

      Architecture_t::Rearrange(x, activations_backward);

      if (!fReturnSequence) {

         Architecture_t::InitializeZero(dy);

         // dy is a tensor of shape (row-major for cuDNN): T x B x S
         // and this->GetActivationGradients() is B x (T=1) x S
         Tensor_t tmp2 = dy.At(dy.GetShape()[0] - 1).Reshape({dy.GetShape()[1], 1, dy.GetShape()[2]});

         Architecture_t::Copy(tmp2, this->GetActivationGradients());
      } else {
         Architecture_t::Rearrange(y, this->GetOutput());
         Architecture_t::Rearrange(dy, this->GetActivationGradients());
      }

      // for cuDNN, Matrix_t and Tensor_t are the same type
      const auto &weights = this->GetWeightsTensor();
      auto &weightGradients = this->GetWeightGradientsTensor();
      // note that cudnnRNNBackwardWeights accumulates the weight gradients,
      // so the tensor needs to be initialized to zero every time
      Architecture_t::InitializeZero(weightGradients);

      // hx is fState, cx is fCell; use the same tensors for the gradients
      auto &hx = this->GetState();
      auto &cx = this->GetCell();
      auto &dhy = hx;
      auto &dcy = cx;
      auto &dhx = hx;
      auto &dcx = cx;

      auto & rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
      auto & rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);

      Architecture_t::RNNBackward(x, hx, cx, y, dy, dhy, dcy, weights, dx, dhx, dcx, weightGradients, rnnDesc, rnnWork);

      if (gradients_backward.GetSize() != 0)
         Architecture_t::Rearrange(gradients_backward, dx);

      return;
   }
   // CPU implementation

   // gradients_backward is the activationGradients of the layer before this one (the input layer).
   // Currently, gradients_backward is for the input (x) and not for the state.
   // For the state it can be:
   Matrix_t state_gradients_backward(this->GetBatchSize(), fStateSize);                    // B x H
   DNN::initialize<Architecture_t>(state_gradients_backward, DNN::EInitialization::kZero); // B x H

   Matrix_t cell_gradients_backward(this->GetBatchSize(), fStateSize);                    // B x H
   DNN::initialize<Architecture_t>(cell_gradients_backward, DNN::EInitialization::kZero); // B x H

   // if dummy is false, gradients_backward will be written back to the matrix
   bool dummy = false;
   if (gradients_backward.GetSize() == 0 || gradients_backward[0].GetNrows() == 0 || gradients_backward[0].GetNcols() == 0) {
      dummy = true;
   }

   Tensor_t arr_gradients_backward(fTimeSteps, this->GetBatchSize(), this->GetInputSize());

   // activations_backward is the input.
   Tensor_t arr_activations_backward(fTimeSteps, this->GetBatchSize(), this->GetInputSize());

   Architecture_t::Rearrange(arr_activations_backward, activations_backward); // B x T x D

   /*! For backpropagation we need the layer outputs, which were obtained
    *  during forward propagation; they are placed in the arr_output tensor. */
   Tensor_t arr_output(fTimeSteps, this->GetBatchSize(), fStateSize);

   Matrix_t initState(this->GetBatchSize(), fCellSize);                    // B x H
   DNN::initialize<Architecture_t>(initState, DNN::EInitialization::kZero); // B x H

   // This will take the partial derivative of state[t] w.r.t. state[t-1]

   Tensor_t arr_actgradients(fTimeSteps, this->GetBatchSize(), fStateSize);

   if (fReturnSequence) {
      Architecture_t::Rearrange(arr_output, this->GetOutput());
      Architecture_t::Rearrange(arr_actgradients, this->GetActivationGradients());
   } else {
      // here for the CPU we need to transpose the input activation gradients into the right format
      arr_output = fY;
      Architecture_t::InitializeZero(arr_actgradients);
      // need to reshape to pad a time dimension = 1 (note: these are column-major tensors)
      Tensor_t tmp_grad = arr_actgradients.At(fTimeSteps - 1).Reshape({this->GetBatchSize(), fStateSize, 1});
      assert(tmp_grad.GetSize() == this->GetActivationGradients().GetSize());
      assert(tmp_grad.GetShape()[0] == this->GetActivationGradients().GetShape()[2]); // B is [0] in tmp and [2] in the input act. gradients

      Architecture_t::Rearrange(tmp_grad, this->GetActivationGradients());
   }

   /*! There are in total 8 weight matrices and 4 bias vectors.
    *  Their gradients are re-initialized to zero so that no stale values are accumulated. */

   // Input Gate.
   fWeightsInputGradients.Zero();
   fWeightsInputStateGradients.Zero();
   fInputBiasGradients.Zero();

   // Forget Gate.
   fWeightsForgetGradients.Zero();
   fWeightsForgetStateGradients.Zero();
   fForgetBiasGradients.Zero();

   // Candidate Gate.
   fWeightsCandidateGradients.Zero();
   fWeightsCandidateStateGradients.Zero();
   fCandidateBiasGradients.Zero();

   // Output Gate.
   fWeightsOutputGradients.Zero();
   fWeightsOutputStateGradients.Zero();
   fOutputBiasGradients.Zero();

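   // Backpropagation through time: iterate the time steps in reverse order,
   // accumulating into the weight, bias and input gradients at each step.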
   for (size_t t = fTimeSteps; t > 0; t--) {
      // Store the sum of gradients obtained at each timestep during the backward pass.
      Architecture_t::ScaleAdd(state_gradients_backward, arr_actgradients[t-1]);
      if (t > 1) {
         const Matrix_t &prevStateActivations = arr_output[t-2];
         const Matrix_t &prevCellActivations = this->GetCellTensorAt(t-2);
         // The gate derivatives were computed and stored during forward propagation.
         Matrix_t dx = arr_gradients_backward[t-1];
         CellBackward(state_gradients_backward, cell_gradients_backward,
                      prevStateActivations, prevCellActivations,
                      this->GetInputGateTensorAt(t-1), this->GetForgetGateTensorAt(t-1),
                      this->GetCandidateGateTensorAt(t-1), this->GetOutputGateTensorAt(t-1),
                      arr_activations_backward[t-1], dx,
                      fDerivativesInput[t-1], fDerivativesForget[t-1],
                      fDerivativesCandidate[t-1], fDerivativesOutput[t-1], t-1);
      } else {
         // at the first time step the previous state and cell are the (zero) initial state
         const Matrix_t &prevStateActivations = initState;
         const Matrix_t &prevCellActivations = initState;
         Matrix_t dx = arr_gradients_backward[t-1];
         CellBackward(state_gradients_backward, cell_gradients_backward,
                      prevStateActivations, prevCellActivations,
                      this->GetInputGateTensorAt(t-1), this->GetForgetGateTensorAt(t-1),
                      this->GetCandidateGateTensorAt(t-1), this->GetOutputGateTensorAt(t-1),
                      arr_activations_backward[t-1], dx,
                      fDerivativesInput[t-1], fDerivativesForget[t-1],
                      fDerivativesCandidate[t-1], fDerivativesOutput[t-1], t-1);
      }
   }

   if (!dummy) {
      Architecture_t::Rearrange(gradients_backward, arr_gradients_backward);
   }

}

//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicLSTMLayer<Architecture_t>::CellBackward(Matrix_t & state_gradients_backward,
                                                          Matrix_t & cell_gradients_backward,
                                                          const Matrix_t & precStateActivations, const Matrix_t & precCellActivations,
                                                          const Matrix_t & input_gate, const Matrix_t & forget_gate,
                                                          const Matrix_t & candidate_gate, const Matrix_t & output_gate,
                                                          const Matrix_t & input, Matrix_t & input_gradient,
                                                          Matrix_t &di, Matrix_t &df, Matrix_t &dc, Matrix_t &dout,
                                                          size_t t)
-> Matrix_t &
{
   /*! Delegate to Architecture_t::LSTMLayerBackward(), passing the gate
    *  values and derivatives obtained during forward propagation. */

   // derivative of the tanh of the cell state at the current time step
   const DNN::EActivationFunction fAT = this->GetActivationFunctionF2();
   Matrix_t cell_gradient(this->GetCellTensorAt(t).GetNrows(), this->GetCellTensorAt(t).GetNcols());
   DNN::evaluateDerivativeMatrix<Architecture_t>(cell_gradient, fAT, this->GetCellTensorAt(t));

   // tanh of the cell state at the current time step
   Matrix_t cell_tanh(this->GetCellTensorAt(t).GetNrows(), this->GetCellTensorAt(t).GetNcols());
   Architecture_t::Copy(cell_tanh, this->GetCellTensorAt(t));
   DNN::evaluateMatrix<Architecture_t>(cell_tanh, fAT);

   return Architecture_t::LSTMLayerBackward(state_gradients_backward, cell_gradients_backward,
                                            fWeightsInputGradients, fWeightsForgetGradients, fWeightsCandidateGradients,
                                            fWeightsOutputGradients, fWeightsInputStateGradients, fWeightsForgetStateGradients,
                                            fWeightsCandidateStateGradients, fWeightsOutputStateGradients, fInputBiasGradients, fForgetBiasGradients,
                                            fCandidateBiasGradients, fOutputBiasGradients, di, df, dc, dout,
                                            precStateActivations, precCellActivations,
                                            input_gate, forget_gate, candidate_gate, output_gate,
                                            fWeightsInputGate, fWeightsForgetGate, fWeightsCandidate, fWeightsOutputGate,
                                            fWeightsInputGateState, fWeightsForgetGateState, fWeightsCandidateState,
                                            fWeightsOutputGateState, input, input_gradient,
                                            cell_gradient, cell_tanh);
}

//______________________________________________________________________________
template <typename Architecture_t>
auto TBasicLSTMLayer<Architecture_t>::InitState(DNN::EInitialization /* m */)
-> void
{
   DNN::initialize<Architecture_t>(this->GetState(), DNN::EInitialization::kZero);
   DNN::initialize<Architecture_t>(this->GetCell(), DNN::EInitialization::kZero);
}

//______________________________________________________________________________
template<typename Architecture_t>
auto TBasicLSTMLayer<Architecture_t>::Print() const
-> void
{
   std::cout << " LSTM Layer: \t ";
   std::cout << " (NInput = " << this->GetInputSize();  // input size
   std::cout << ", NState = " << this->GetStateSize();  // hidden state size
   std::cout << ", NTime = " << this->GetTimeSteps() << " )";  // time size
   std::cout << "\tOutput = ( " << this->GetOutput().GetFirstSize() << " , " << this->GetOutput()[0].GetNrows() << " , " << this->GetOutput()[0].GetNcols() << " )\n";
}

//______________________________________________________________________________
template <typename Architecture_t>
auto TBasicLSTMLayer<Architecture_t>::AddWeightsXMLTo(void *parent)
-> void
{
   auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "LSTMLayer");

   // Write all other info like stateSize, cellSize, inputSize, timeSteps, rememberState
   gTools().xmlengine().NewAttr(layerxml, nullptr, "StateSize", gTools().StringFromInt(this->GetStateSize()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "CellSize", gTools().StringFromInt(this->GetCellSize()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "InputSize", gTools().StringFromInt(this->GetInputSize()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "TimeSteps", gTools().StringFromInt(this->GetTimeSteps()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "RememberState", gTools().StringFromInt(this->DoesRememberState()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "ReturnSequence", gTools().StringFromInt(this->DoesReturnSequence()));

   // write weight and bias matrices
   this->WriteMatrixToXML(layerxml, "InputWeights", this->GetWeightsAt(0));
   this->WriteMatrixToXML(layerxml, "InputStateWeights", this->GetWeightsAt(1));
   this->WriteMatrixToXML(layerxml, "InputBiases", this->GetBiasesAt(0));
   this->WriteMatrixToXML(layerxml, "ForgetWeights", this->GetWeightsAt(2));
   this->WriteMatrixToXML(layerxml, "ForgetStateWeights", this->GetWeightsAt(3));
   this->WriteMatrixToXML(layerxml, "ForgetBiases", this->GetBiasesAt(1));
   this->WriteMatrixToXML(layerxml, "CandidateWeights", this->GetWeightsAt(4));
   this->WriteMatrixToXML(layerxml, "CandidateStateWeights", this->GetWeightsAt(5));
   this->WriteMatrixToXML(layerxml, "CandidateBiases", this->GetBiasesAt(2));
   this->WriteMatrixToXML(layerxml, "OutputWeights", this->GetWeightsAt(6));
   this->WriteMatrixToXML(layerxml, "OutputStateWeights", this->GetWeightsAt(7));
   this->WriteMatrixToXML(layerxml, "OutputBiases", this->GetBiasesAt(3));
}

//______________________________________________________________________________
template <typename Architecture_t>
auto TBasicLSTMLayer<Architecture_t>::ReadWeightsFromXML(void *parent)
-> void
{
   // Read weights and biases
   this->ReadMatrixXML(parent, "InputWeights", this->GetWeightsAt(0));
   this->ReadMatrixXML(parent, "InputStateWeights", this->GetWeightsAt(1));
   this->ReadMatrixXML(parent, "InputBiases", this->GetBiasesAt(0));
   this->ReadMatrixXML(parent, "ForgetWeights", this->GetWeightsAt(2));
   this->ReadMatrixXML(parent, "ForgetStateWeights", this->GetWeightsAt(3));
   this->ReadMatrixXML(parent, "ForgetBiases", this->GetBiasesAt(1));
   this->ReadMatrixXML(parent, "CandidateWeights", this->GetWeightsAt(4));
   this->ReadMatrixXML(parent, "CandidateStateWeights", this->GetWeightsAt(5));
   this->ReadMatrixXML(parent, "CandidateBiases", this->GetBiasesAt(2));
   this->ReadMatrixXML(parent, "OutputWeights", this->GetWeightsAt(6));
   this->ReadMatrixXML(parent, "OutputStateWeights", this->GetWeightsAt(7));
   this->ReadMatrixXML(parent, "OutputBiases", this->GetBiasesAt(3));
}

} // namespace RNN
} // namespace DNN
} // namespace TMVA

#endif // TMVA_DNN_LSTM_LAYER