RecurrentPropagation.hxx
// @(#)root/tmva/tmva/dnn:$Id$
// Author: Saurav Shekhar 23/06/17

/*************************************************************************
 * Copyright (C) 2017, Saurav Shekhar                                    *
 * All rights reserved.                                                  *
 *                                                                       *
 * For the licensing terms see $ROOTSYS/LICENSE.                         *
 * For the list of contributors see $ROOTSYS/README/CREDITS.             *
 *************************************************************************/

/////////////////////////////////////////////////////////////////////
// Implementation of the functions required for the forward and    //
// backward propagation of activations through a recurrent neural  //
// network in the reference implementation.                         //
/////////////////////////////////////////////////////////////////////

#include "TMVA/DNN/Architectures/Reference.h"

namespace TMVA {
namespace DNN {


//______________________________________________________________________________
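// Backward pass for a single recurrent layer (reference implementation).
// A sketch of what the code below computes, with B = batch size, D = input size
// and H = hidden-state size, assuming df holds the activation derivative
// f'(pre-activation) and state_gradients_backward the gradient flowing back into
// the hidden state:
//
//    df                       <-  df (.) state_gradients_backward       (B x H)
//    input_gradient           <-  df * weights_input                    (B x D)
//    state_gradients_backward <-  df * weights_state                    (B x H)
//    input_weight_gradients   +=  df^T * input                          (H x D)
//    state_weight_gradients   +=  df^T * state                          (H x H)
//    bias_gradients           +=  column-wise sum of df over the batch  (H x 1)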
template<typename Scalar_t>
auto TReference<Scalar_t>::RecurrentLayerBackward(TMatrixT<Scalar_t> & state_gradients_backward,
                                                  TMatrixT<Scalar_t> & input_weight_gradients,
                                                  TMatrixT<Scalar_t> & state_weight_gradients,
                                                  TMatrixT<Scalar_t> & bias_gradients,
                                                  TMatrixT<Scalar_t> & df,                    // BxH
                                                  const TMatrixT<Scalar_t> & state,           // BxH
                                                  const TMatrixT<Scalar_t> & weights_input,   // HxD
                                                  const TMatrixT<Scalar_t> & weights_state,   // HxH
                                                  const TMatrixT<Scalar_t> & input,           // BxD
                                                  TMatrixT<Scalar_t> & input_gradient)
-> Matrix_t &
{

   // std::cout << "Reference Recurrent Propo" << std::endl;
   // std::cout << "df\n";
   // df.Print();
   // std::cout << "state gradient\n";
   // state_gradients_backward.Print();
   // std::cout << "inputw gradient\n";
   // input_weight_gradients.Print();
   // std::cout << "state\n";
   // state.Print();
   // std::cout << "input\n";
   // input.Print();

   // Compute element-wise product.
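   // Chain rule: combine the activation derivative f'(pre-activation), assumed to be
   // stored in df, with the gradient flowing back into the hidden state, so that
   // afterwards df(i,j) holds dL/d(pre-activation) for sample i and hidden unit j.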
   for (size_t i = 0; i < (size_t) df.GetNrows(); i++) {
      for (size_t j = 0; j < (size_t) df.GetNcols(); j++) {
         df(i,j) *= state_gradients_backward(i,j);        // B x H
      }
   }

   // Input gradients.
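   // dL/dinput = df * weights_input; skipped when the caller passes an empty matrix
   // and no input gradient is required.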
   if (input_gradient.GetNoElements() > 0) {
      input_gradient.Mult(df, weights_input);             // B x H . H x D = B x D
   }
   // State gradients
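   // Gradient with respect to the previous hidden state, overwritten in place so it can
   // be passed on to the backward step of the preceding time step.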
   if (state_gradients_backward.GetNoElements() > 0) {
      state_gradients_backward.Mult(df, weights_state);   // B x H . H x H = B x H
   }

   // Weights gradients.
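   // TMult computes the transposed product df^T * input (resp. df^T * state); the
   // temporary copy plus "+=" preserves previously accumulated contributions, so the
   // weight gradients add up when this function is called once per time step.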
   if (input_weight_gradients.GetNoElements() > 0) {
      TMatrixT<Scalar_t> tmp(input_weight_gradients);
      input_weight_gradients.TMult(df, input);            // H x B . B x D
      input_weight_gradients += tmp;
   }
   if (state_weight_gradients.GetNoElements() > 0) {
      TMatrixT<Scalar_t> tmp(state_weight_gradients);
      state_weight_gradients.TMult(df, state);            // H x B . B x H
      state_weight_gradients += tmp;
   }

   // Bias gradients. B x H -> H x 1
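   // dL/db_j = sum over the batch of df(i,j), accumulated into the H x 1 bias_gradients.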
   if (bias_gradients.GetNoElements() > 0) {
      // this loops on state size
      for (size_t j = 0; j < (size_t) df.GetNcols(); j++) {
         Scalar_t sum = 0.0;
         // this loops on batch size summing all gradient contributions in a batch
         for (size_t i = 0; i < (size_t) df.GetNrows(); i++) {
            sum += df(i,j);
         }
         bias_gradients(j,0) += sum;
      }
   }

   // std::cout << "RecurrentPropo: end " << std::endl;

   // std::cout << "state gradient\n";
   // state_gradients_backward.Print();
   // std::cout << "inputw gradient\n";
   // input_weight_gradients.Print();
   // std::cout << "bias gradient\n";
   // bias_gradients.Print();
   // std::cout << "input gradient\n";
   // input_gradient.Print();


   return input_gradient;
}
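
// Illustrative usage sketch: in backpropagation through time this routine would
// typically be called once per time step, from the last step back to the first.
// The buffer names below (state_grad, dWin, dWstate, dBias, df_t, state_t, input_t,
// input_grad) are hypothetical placeholders:
//
//    TMatrixT<double> state_grad(batchSize, stateSize); // gradient w.r.t. the hidden state
//    for (int t = timeSteps - 1; t >= 0; --t) {
//       TReference<double>::RecurrentLayerBackward(state_grad, dWin, dWstate, dBias,
//                                                  df_t[t], state_t[t], weights_input,
//                                                  weights_state, input_t[t], input_grad[t]);
//    }
//
// The weight and bias gradients accumulate across the loop, matching the "+=" logic above.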


} // namespace DNN
} // namespace TMVA