Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RecurrentPropagation.hxx
Go to the documentation of this file.
// @(#)root/tmva/tmva/dnn:$Id$
// Authors: Surya S Dwivedi 15/07/2019, Saurav Shekhar 23/06/17
/*************************************************************************
 * Copyright (C) 2019, Surya S Dwivedi, Saurav Shekhar *
 * All rights reserved. *
 * *
 * For the licensing terms see $ROOTSYS/LICENSE. *
 * For the list of contributors see $ROOTSYS/README/CREDITS. *
 *************************************************************************/

/////////////////////////////////////////////////////////////////////
// Implementation of the functions required for the forward and    //
// backward propagation of activations through a recurrent neural  //
// network in the reference implementation.                        //
/////////////////////////////////////////////////////////////////////

#include "TMVA/DNN/Architectures/Reference.h"

19namespace TMVA
20{
21namespace DNN
22{
23
24//______________________________________________________________________________
25template<typename Scalar_t>
27 TMatrixT<Scalar_t> & input_weight_gradients,
28 TMatrixT<Scalar_t> & state_weight_gradients,
29 TMatrixT<Scalar_t> & bias_gradients,
30 TMatrixT<Scalar_t> & df, //BxH
31 const TMatrixT<Scalar_t> & state, // BxH
32 const TMatrixT<Scalar_t> & weights_input, // HxD
33 const TMatrixT<Scalar_t> & weights_state, // HxH
34 const TMatrixT<Scalar_t> & input, // BxD
35 TMatrixT<Scalar_t> & input_gradient)
36-> Matrix_t &
37{
38 // Compute element-wise product.
39 for (size_t i = 0; i < (size_t) df.GetNrows(); i++) {
40 for (size_t j = 0; j < (size_t) df.GetNcols(); j++) {
41 df(i,j) *= state_gradients_backward(i,j); // B x H
42 }
43 }
44
45 // Input gradients.
46 if (input_gradient.GetNoElements() > 0) {
47 input_gradient.Mult(df, weights_input); // B x H . H x D = B x D
48 }
49
50 // State gradients
51 if (state_gradients_backward.GetNoElements() > 0) {
52 state_gradients_backward.Mult(df, weights_state); // B x H . H x H = B x H
53 }
54
55 // Weights gradients.
56 if (input_weight_gradients.GetNoElements() > 0) {
57 TMatrixT<Scalar_t> tmp(input_weight_gradients);
58 input_weight_gradients.TMult(df, input); // H x B . B x D
59 input_weight_gradients += tmp;
60 }
61 if (state_weight_gradients.GetNoElements() > 0) {
62 TMatrixT<Scalar_t> tmp(state_weight_gradients);
63 state_weight_gradients.TMult(df, state); // H x B . B x H
64 state_weight_gradients += tmp;
65 }
66
67 // Bias gradients. B x H -> H x 1
68 if (bias_gradients.GetNoElements() > 0) {
69 // this loops on state size
70 for (size_t j = 0; j < (size_t) df.GetNcols(); j++) {
71 Scalar_t sum = 0.0;
72 // this loops on batch size summing all gradient contributions in a batch
73 for (size_t i = 0; i < (size_t) df.GetNrows(); i++) {
74 sum += df(i,j);
75 }
76 bias_gradients(j,0) += sum;
77 }
78 }
79
80 return input_gradient;
81}
82
83
84//______________________________________________________________________________
85template <typename Scalar_t>
87 TMatrixT<Scalar_t> & cell_gradients_backward,
88 TMatrixT<Scalar_t> & input_weight_gradients,
89 TMatrixT<Scalar_t> & forget_weight_gradients,
90 TMatrixT<Scalar_t> & candidate_weight_gradients,
91 TMatrixT<Scalar_t> & output_weight_gradients,
92 TMatrixT<Scalar_t> & input_state_weight_gradients,
93 TMatrixT<Scalar_t> & forget_state_weight_gradients,
94 TMatrixT<Scalar_t> & candidate_state_weight_gradients,
95 TMatrixT<Scalar_t> & output_state_weight_gradients,
96 TMatrixT<Scalar_t> & input_bias_gradients,
97 TMatrixT<Scalar_t> & forget_bias_gradients,
98 TMatrixT<Scalar_t> & candidate_bias_gradients,
99 TMatrixT<Scalar_t> & output_bias_gradients,
103 TMatrixT<Scalar_t> & dout,
104 const TMatrixT<Scalar_t> & precStateActivations,
105 const TMatrixT<Scalar_t> & precCellActivations,
106 const TMatrixT<Scalar_t> & fInput,
107 const TMatrixT<Scalar_t> & fForget,
108 const TMatrixT<Scalar_t> & fCandidate,
109 const TMatrixT<Scalar_t> & fOutput,
110 const TMatrixT<Scalar_t> & weights_input,
111 const TMatrixT<Scalar_t> & weights_forget,
112 const TMatrixT<Scalar_t> & weights_candidate,
113 const TMatrixT<Scalar_t> & weights_output,
114 const TMatrixT<Scalar_t> & weights_input_state,
115 const TMatrixT<Scalar_t> & weights_forget_state,
116 const TMatrixT<Scalar_t> & weights_candidate_state,
117 const TMatrixT<Scalar_t> & weights_output_state,
119 TMatrixT<Scalar_t> & input_gradient,
120 TMatrixT<Scalar_t> & cell_gradient,
121 TMatrixT<Scalar_t> & cell_tanh)
122-> Matrix_t &
123{
124 // cell gradient
125 Hadamard(cell_gradient, fOutput);
126 Hadamard(cell_gradient, state_gradients_backward);
127 cell_gradient += cell_gradients_backward;
128 cell_gradients_backward = cell_gradient;
129 Hadamard(cell_gradients_backward, fForget);
130
131 // candidate gradient
132 TMatrixT<Scalar_t> candidate_gradient(cell_gradient);
133 Hadamard(candidate_gradient, fInput);
134 Hadamard(candidate_gradient, dc);
135
136 // input gate gradient
137 TMatrixT<Scalar_t> input_gate_gradient(cell_gradient);
138 Hadamard(input_gate_gradient, fCandidate);
139 Hadamard(input_gate_gradient, di);
140
141 // forget gradient
142 TMatrixT<Scalar_t> forget_gradient(cell_gradient);
143 Hadamard(forget_gradient, precCellActivations);
144 Hadamard(forget_gradient, df);
145
146 // output gradient
147 TMatrixT<Scalar_t> output_gradient(cell_tanh);
148 Hadamard(output_gradient, state_gradients_backward);
149 Hadamard(output_gradient, dout);
150
151 // input gradient
152 TMatrixT<Scalar_t> tmpInp(input_gradient);
153 tmpInp.Mult(input_gate_gradient, weights_input);
154 input_gradient = tmpInp;
155 tmpInp.Mult(forget_gradient, weights_forget);
156 input_gradient += tmpInp;
157 tmpInp.Mult(candidate_gradient, weights_candidate);
158 input_gradient += tmpInp;
159 tmpInp.Mult(output_gradient, weights_output);
160 input_gradient += tmpInp;
161
162 // state gradient backwards
163 TMatrixT<Scalar_t> tmpState(state_gradients_backward);
164 tmpState.Mult(input_gate_gradient, weights_input_state);
165 state_gradients_backward = tmpState;
166 tmpState.Mult(forget_gradient, weights_forget_state);
167 state_gradients_backward += tmpState;
168 tmpState.Mult(candidate_gradient, weights_candidate_state);
169 state_gradients_backward += tmpState;
170 tmpState.Mult(output_gradient, weights_output_state);
171 state_gradients_backward += tmpState;
172
173 //input weight gradients
174 TMatrixT<Scalar_t> tmp(input_weight_gradients);
175 input_weight_gradients.TMult(input_gate_gradient, input);
176 input_weight_gradients += tmp;
177 tmp = forget_weight_gradients;
178 forget_weight_gradients.TMult(forget_gradient, input);
179 forget_weight_gradients += tmp;
180 tmp = candidate_weight_gradients;
181 candidate_weight_gradients.TMult(candidate_gradient, input);
182 candidate_weight_gradients += tmp;
183 tmp = output_weight_gradients;
184 output_weight_gradients.TMult(output_gradient, input);
185 output_weight_gradients += tmp;
186
187 // state weight gradients
188 TMatrixT<Scalar_t> tmp1(input_state_weight_gradients);
189 input_state_weight_gradients.TMult(input_gate_gradient, precStateActivations);
190 input_state_weight_gradients += tmp1;
191 tmp1 = forget_state_weight_gradients;
192 forget_state_weight_gradients.TMult(forget_gradient, precStateActivations);
193 forget_state_weight_gradients += tmp1;
194 tmp1 = candidate_state_weight_gradients;
195 candidate_state_weight_gradients.TMult(candidate_gradient, precStateActivations);
196 candidate_state_weight_gradients += tmp1;
197 tmp1 = output_state_weight_gradients;
198 output_state_weight_gradients.TMult(output_gradient, precStateActivations);
199 output_state_weight_gradients += tmp1;
200
201 // bias gradients
202 for (size_t j = 0; j < (size_t) df.GetNcols(); j++) {
203 Scalar_t sum_inp = 0.0, sum_forget = 0.0, sum_candidate = 0.0, sum_out = 0.0;
204 // this loops on batch size summing all gradient contributions in a batch
205 for (size_t i = 0; i < (size_t) df.GetNrows(); i++) {
206 sum_inp += input_gate_gradient(i,j);
207 sum_forget += forget_gradient(i,j);
208 sum_candidate += candidate_gradient(i,j);
209 sum_out += output_gradient(i,j);
210 }
211 input_bias_gradients(j,0) += sum_inp;
212 forget_bias_gradients(j,0) += sum_forget;
213 candidate_bias_gradients(j,0) += sum_candidate;
214 output_bias_gradients(j,0) += sum_out;
215 }
216
217 return input_gradient;
218}
219
220
221
222//______________________________________________________________________________
223template <typename Scalar_t>
225 TMatrixT<Scalar_t> & reset_weight_gradients,
226 TMatrixT<Scalar_t> & update_weight_gradients,
227 TMatrixT<Scalar_t> & candidate_weight_gradients,
228 TMatrixT<Scalar_t> & reset_state_weight_gradients,
229 TMatrixT<Scalar_t> & update_state_weight_gradients,
230 TMatrixT<Scalar_t> & candidate_state_weight_gradients,
231 TMatrixT<Scalar_t> & reset_bias_gradients,
232 TMatrixT<Scalar_t> & update_bias_gradients,
233 TMatrixT<Scalar_t> & candidate_bias_gradients,
237 const TMatrixT<Scalar_t> & precStateActivations,
238 const TMatrixT<Scalar_t> & fReset,
239 const TMatrixT<Scalar_t> & fUpdate,
240 const TMatrixT<Scalar_t> & fCandidate,
241 const TMatrixT<Scalar_t> & weights_reset,
242 const TMatrixT<Scalar_t> & weights_update,
243 const TMatrixT<Scalar_t> & weights_candidate,
244 const TMatrixT<Scalar_t> & weights_reset_state,
245 const TMatrixT<Scalar_t> & weights_update_state,
246 const TMatrixT<Scalar_t> & weights_candidate_state,
248 TMatrixT<Scalar_t> & input_gradient)
249-> Matrix_t &
250{
251 // reset gradient
252 TMatrixT<Scalar_t> reset_gradient(fUpdate);
253 for (size_t j = 0; j < (size_t) reset_gradient.GetNcols(); j++) {
254 for (size_t i = 0; i < (size_t) reset_gradient.GetNrows(); i++) {
255 reset_gradient(i,j) = 1 - reset_gradient(i,j);
256 }
257 }
258 Hadamard(reset_gradient, dc);
259 Hadamard(reset_gradient, state_gradients_backward);
260 TMatrixT<Scalar_t> tmpMul(precStateActivations);
261 tmpMul.Mult(reset_gradient, weights_candidate_state);
262 Hadamard(tmpMul, precStateActivations);
263 Hadamard(tmpMul, dr);
264 reset_gradient = tmpMul;
265
266 // update gradient
267 TMatrixT<Scalar_t> update_gradient(precStateActivations); // H X 1
268 for (size_t j = 0; j < (size_t) update_gradient.GetNcols(); j++) {
269 for (size_t i = 0; i < (size_t) update_gradient.GetNrows(); i++) {
270 update_gradient(i,j) = update_gradient(i,j) - fCandidate(i,j);
271 }
272 }
273 Hadamard(update_gradient, du);
274 Hadamard(update_gradient, state_gradients_backward);
275
276 // candidate gradient
277 TMatrixT<Scalar_t> candidate_gradient(fUpdate);
278 for (size_t j = 0; j < (size_t) candidate_gradient.GetNcols(); j++) {
279 for (size_t i = 0; i < (size_t) candidate_gradient.GetNrows(); i++) {
280 candidate_gradient(i,j) = 1 - candidate_gradient(i,j);
281 }
282 }
283 Hadamard(candidate_gradient, dc);
284 Hadamard(candidate_gradient, state_gradients_backward);
285
286 // calculating state_gradient_backwards term by term
287 // term 1
288 TMatrixT<Scalar_t> temp(state_gradients_backward);
289 TMatrixT<Scalar_t> term(fUpdate); // H X 1
290 Hadamard(term, temp);
291 state_gradients_backward = term;
292
293 //term 2
294 term = precStateActivations;
295 Hadamard(term, du);
296 Hadamard(term, temp);
297 TMatrixT<Scalar_t> var(precStateActivations);
298 var.Mult(term, weights_update_state);
299 term = var;
300 state_gradients_backward += term;
301
302 // term 3
303 term = fCandidate;
304 for (size_t j = 0; j < (size_t) term.GetNcols(); j++) {
305 for (size_t i = 0; i < (size_t) term.GetNrows(); i++) {
306 term(i,j) = - term(i,j);
307 }
308 }
309 Hadamard(term, du);
310 Hadamard(term, temp);
311 var.Mult(term, weights_update_state);
312 term = var;
313 state_gradients_backward += term;
314
315 // term 4
316 term = fUpdate;
317 for (size_t j = 0; j < (size_t) term.GetNcols(); j++) {
318 for (size_t i = 0; i < (size_t) term.GetNrows(); i++) {
319 term(i,j) = 1 - term(i,j);
320 }
321 }
322 Hadamard(term, dc);
323 Hadamard(term, temp);
324 var.Mult(term, weights_candidate_state);
325 Hadamard(var, fReset);
326 term = var;
327 state_gradients_backward += term;
328
329 // term 5
330 term = fUpdate;
331 for (size_t j = 0; j < (size_t) term.GetNcols(); j++) {
332 for (size_t i = 0; i < (size_t) term.GetNrows(); i++) {
333 term(i,j) = 1 - term(i,j);
334 }
335 }
336 Hadamard(term, dc);
337 Hadamard(term, temp);
338 var.Mult(term, weights_candidate_state);
339 Hadamard(var, precStateActivations);
340 Hadamard(var, dr);
341 term.Mult(var, weights_reset_state);
342 state_gradients_backward += term;
343
344 // input gradients
345 TMatrixT<Scalar_t> tmpInp(input_gradient);
346 tmpInp.Mult(reset_gradient, weights_reset);
347 input_gradient = tmpInp;
348 tmpInp.Mult(update_gradient, weights_update);
349 input_gradient += tmpInp;
350 tmpInp.Mult(candidate_gradient, weights_candidate);
351 input_gradient += tmpInp;
352
353 //input weight gradients
354 TMatrixT<Scalar_t> tmp(reset_weight_gradients);
355 reset_weight_gradients.TMult(reset_gradient, input);
356 reset_weight_gradients += tmp;
357 tmp = update_weight_gradients;
358 update_weight_gradients.TMult(update_gradient, input);
359 update_weight_gradients += tmp;
360 tmp = candidate_weight_gradients;
361 candidate_weight_gradients.TMult(candidate_gradient, input);
362 candidate_weight_gradients += tmp;
363
364 // state weight gradients
365 TMatrixT<Scalar_t> tmp1(reset_state_weight_gradients);
366 reset_state_weight_gradients.TMult(reset_gradient, precStateActivations);
367 reset_state_weight_gradients += tmp1;
368 tmp1 = update_state_weight_gradients;
369 update_state_weight_gradients.TMult(update_gradient, precStateActivations);
370 update_state_weight_gradients += tmp1;
371 tmp1 = candidate_state_weight_gradients;
372 TMatrixT<Scalar_t> tmp2(fReset);
373 Hadamard(tmp2, precStateActivations);
374 candidate_state_weight_gradients.TMult(candidate_gradient, tmp2);
375 candidate_state_weight_gradients += tmp1;
376
377 // bias gradients
378 for (size_t j = 0; j < (size_t) du.GetNcols(); j++) {
379 Scalar_t sum_reset = 0.0, sum_update = 0.0, sum_candidate = 0.0;
380 // this loops on batch size summing all gradient contributions in a batch
381 for (size_t i = 0; i < (size_t) du.GetNrows(); i++) {
382 sum_reset += reset_gradient(i,j);
383 sum_update += update_gradient(i,j);
384 sum_candidate += candidate_gradient(i,j);
385 }
386 reset_bias_gradients(j,0) += sum_reset;
387 update_bias_gradients(j,0) += sum_update;
388 candidate_bias_gradients(j,0) += sum_candidate;
389 }
390
391 return input_gradient;
392}
393
394} // namespace DNN
395} // namespace TMVA
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
static Matrix_t & GRULayerBackward(TMatrixT< Scalar_t > &state_gradients_backward, TMatrixT< Scalar_t > &reset_weight_gradients, TMatrixT< Scalar_t > &update_weight_gradients, TMatrixT< Scalar_t > &candidate_weight_gradients, TMatrixT< Scalar_t > &reset_state_weight_gradients, TMatrixT< Scalar_t > &update_state_weight_gradients, TMatrixT< Scalar_t > &candidate_state_weight_gradients, TMatrixT< Scalar_t > &reset_bias_gradients, TMatrixT< Scalar_t > &update_bias_gradients, TMatrixT< Scalar_t > &candidate_bias_gradients, TMatrixT< Scalar_t > &dr, TMatrixT< Scalar_t > &du, TMatrixT< Scalar_t > &dc, const TMatrixT< Scalar_t > &precStateActivations, const TMatrixT< Scalar_t > &fReset, const TMatrixT< Scalar_t > &fUpdate, const TMatrixT< Scalar_t > &fCandidate, const TMatrixT< Scalar_t > &weights_reset, const TMatrixT< Scalar_t > &weights_update, const TMatrixT< Scalar_t > &weights_candidate, const TMatrixT< Scalar_t > &weights_reset_state, const TMatrixT< Scalar_t > &weights_update_state, const TMatrixT< Scalar_t > &weights_candidate_state, const TMatrixT< Scalar_t > &input, TMatrixT< Scalar_t > &input_gradient)
Backward pass for GRU Network.
static Matrix_t & LSTMLayerBackward(TMatrixT< Scalar_t > &state_gradients_backward, TMatrixT< Scalar_t > &cell_gradients_backward, TMatrixT< Scalar_t > &input_weight_gradients, TMatrixT< Scalar_t > &forget_weight_gradients, TMatrixT< Scalar_t > &candidate_weight_gradients, TMatrixT< Scalar_t > &output_weight_gradients, TMatrixT< Scalar_t > &input_state_weight_gradients, TMatrixT< Scalar_t > &forget_state_weight_gradients, TMatrixT< Scalar_t > &candidate_state_weight_gradients, TMatrixT< Scalar_t > &output_state_weight_gradients, TMatrixT< Scalar_t > &input_bias_gradients, TMatrixT< Scalar_t > &forget_bias_gradients, TMatrixT< Scalar_t > &candidate_bias_gradients, TMatrixT< Scalar_t > &output_bias_gradients, TMatrixT< Scalar_t > &di, TMatrixT< Scalar_t > &df, TMatrixT< Scalar_t > &dc, TMatrixT< Scalar_t > &dout, const TMatrixT< Scalar_t > &precStateActivations, const TMatrixT< Scalar_t > &precCellActivations, const TMatrixT< Scalar_t > &fInput, const TMatrixT< Scalar_t > &fForget, const TMatrixT< Scalar_t > &fCandidate, const TMatrixT< Scalar_t > &fOutput, const TMatrixT< Scalar_t > &weights_input, const TMatrixT< Scalar_t > &weights_forget, const TMatrixT< Scalar_t > &weights_candidate, const TMatrixT< Scalar_t > &weights_output, const TMatrixT< Scalar_t > &weights_input_state, const TMatrixT< Scalar_t > &weights_forget_state, const TMatrixT< Scalar_t > &weights_candidate_state, const TMatrixT< Scalar_t > &weights_output_state, const TMatrixT< Scalar_t > &input, TMatrixT< Scalar_t > &input_gradient, TMatrixT< Scalar_t > &cell_gradient, TMatrixT< Scalar_t > &cell_tanh)
Backward pass for LSTM Network.
static Matrix_t & RecurrentLayerBackward(TMatrixT< Scalar_t > &state_gradients_backward, TMatrixT< Scalar_t > &input_weight_gradients, TMatrixT< Scalar_t > &state_weight_gradients, TMatrixT< Scalar_t > &bias_gradients, TMatrixT< Scalar_t > &df, const TMatrixT< Scalar_t > &state, const TMatrixT< Scalar_t > &weights_input, const TMatrixT< Scalar_t > &weights_state, const TMatrixT< Scalar_t > &input, TMatrixT< Scalar_t > &input_gradient)
Backpropagation step for a Recurrent Neural Network.
Int_t GetNrows() const
Int_t GetNcols() const
TMatrixT.
Definition TMatrixT.h:40
void TMult(const TMatrixT< Element > &a, const TMatrixT< Element > &b)
Replace this matrix with C such that C = A' * B.
Definition TMatrixT.cxx:848
void Mult(const TMatrixT< Element > &a, const TMatrixT< Element > &b)
General matrix multiplication. Replace this matrix with C such that C = A * B.
Definition TMatrixT.cxx:644
create variable transformations
static uint64_t sum(uint64_t i)
Definition Factory.cxx:2345