Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RecurrentPropagation.hxx
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Authors: Surya S Dwivedi 01/08/2019, Saurav Shekhar 23/06/17
3/*************************************************************************
4 * Copyright (C) 2019, Surya S Dwivedi, Saurav Shekhar *
5 * All rights reserved. *
6 * *
7 * For the licensing terms see $ROOTSYS/LICENSE. *
8 * For the list of contributors see $ROOTSYS/README/CREDITS. *
9 *************************************************************************/
10
11/////////////////////////////////////////////////////////////////////
12// Implementation of the functions required for the forward and //
13// backward propagation of activations through a recurrent neural //
14// network in the TCpu architecture //
15/////////////////////////////////////////////////////////////////////
16
19
20namespace TMVA
21{
22namespace DNN
23{
24
25template<typename AFloat>
26auto TCpu<AFloat>::RecurrentLayerBackward(TCpuMatrix<AFloat> & state_gradients_backward, // BxH
27 TCpuMatrix<AFloat> & input_weight_gradients,
28 TCpuMatrix<AFloat> & state_weight_gradients,
29 TCpuMatrix<AFloat> & bias_gradients,
30 TCpuMatrix<AFloat> & df, //BxH
31 const TCpuMatrix<AFloat> & state, // BxH
32 const TCpuMatrix<AFloat> & weights_input, // HxD
33 const TCpuMatrix<AFloat> & weights_state, // HxH
34 const TCpuMatrix<AFloat> & input, // BxD
35 TCpuMatrix<AFloat> & input_gradient)
37{
38 // std::cout << "Recurrent Propo" << std::endl;
39 // TMVA_DNN_PrintTCpuMatrix(df,"DF");
40 // TMVA_DNN_PrintTCpuMatrix(state_gradients_backward,"State grad");
41 // TMVA_DNN_PrintTCpuMatrix(input_weight_gradients,"input w grad");
42 // TMVA_DNN_PrintTCpuMatrix(state,"state");
43 // TMVA_DNN_PrintTCpuMatrix(input,"input");
44
45 // Compute element-wise product.
46 //Hadamard(df, state_gradients_backward); // B x H
47
48 // Input gradients.
49 if (input_gradient.GetNoElements() > 0) {
50 Multiply(input_gradient, df, weights_input);
51 }
52
53 // State gradients.
54 if (state_gradients_backward.GetNoElements() > 0) {
55 Multiply(state_gradients_backward, df, weights_state);
56 }
57
58 // compute the gradients
59 // Perform the operation in place by readding the result on the same gradient matrix
60 // e.g. W += D * X
61
62 // Weights gradients
63 if (input_weight_gradients.GetNoElements() > 0) {
64 TransposeMultiply(input_weight_gradients, df, input, 1. , 1.); // H x B . B x D
65 }
66
67 if (state_weight_gradients.GetNoElements() > 0) {
68 TransposeMultiply(state_weight_gradients, df, state, 1. , 1. ); // H x B . B x H
69 }
70
71 // Bias gradients.
72 if (bias_gradients.GetNoElements() > 0) {
73 SumColumns(bias_gradients, df, 1., 1.); // could be probably do all here
74 }
75
76 return input_gradient;
77}
78
79//______________________________________________________________________________
80template <typename Scalar_t>
81auto inline TCpu<Scalar_t>::LSTMLayerBackward(TCpuMatrix<Scalar_t> & state_gradients_backward,
82 TCpuMatrix<Scalar_t> & cell_gradients_backward,
83 TCpuMatrix<Scalar_t> & input_weight_gradients,
84 TCpuMatrix<Scalar_t> & forget_weight_gradients,
85 TCpuMatrix<Scalar_t> & candidate_weight_gradients,
86 TCpuMatrix<Scalar_t> & output_weight_gradients,
87 TCpuMatrix<Scalar_t> & input_state_weight_gradients,
88 TCpuMatrix<Scalar_t> & forget_state_weight_gradients,
89 TCpuMatrix<Scalar_t> & candidate_state_weight_gradients,
90 TCpuMatrix<Scalar_t> & output_state_weight_gradients,
91 TCpuMatrix<Scalar_t> & input_bias_gradients,
92 TCpuMatrix<Scalar_t> & forget_bias_gradients,
93 TCpuMatrix<Scalar_t> & candidate_bias_gradients,
94 TCpuMatrix<Scalar_t> & output_bias_gradients,
99 const TCpuMatrix<Scalar_t> & precStateActivations,
100 const TCpuMatrix<Scalar_t> & precCellActivations,
101 const TCpuMatrix<Scalar_t> & fInput,
102 const TCpuMatrix<Scalar_t> & fForget,
103 const TCpuMatrix<Scalar_t> & fCandidate,
104 const TCpuMatrix<Scalar_t> & fOutput,
105 const TCpuMatrix<Scalar_t> & weights_input,
106 const TCpuMatrix<Scalar_t> & weights_forget,
107 const TCpuMatrix<Scalar_t> & weights_candidate,
108 const TCpuMatrix<Scalar_t> & weights_output,
109 const TCpuMatrix<Scalar_t> & weights_input_state,
110 const TCpuMatrix<Scalar_t> & weights_forget_state,
111 const TCpuMatrix<Scalar_t> & weights_candidate_state,
112 const TCpuMatrix<Scalar_t> & weights_output_state,
113 const TCpuMatrix<Scalar_t> & input,
114 TCpuMatrix<Scalar_t> & input_gradient,
115 TCpuMatrix<Scalar_t> & cell_gradient,
116 TCpuMatrix<Scalar_t> & cell_tanh)
118{
119 //some temporary varibales used later
120 TCpuMatrix<Scalar_t> tmpInp(input_gradient.GetNrows(), input_gradient.GetNcols());
121 TCpuMatrix<Scalar_t> tmpState(state_gradients_backward.GetNrows(), state_gradients_backward.GetNcols());
122
123 TCpuMatrix<Scalar_t> input_gate_gradient(fInput.GetNrows(), fInput.GetNcols());
124 TCpuMatrix<Scalar_t> forget_gradient(fForget.GetNrows(), fForget.GetNcols());
125 TCpuMatrix<Scalar_t> candidate_gradient(fCandidate.GetNrows(), fCandidate.GetNcols());
126 TCpuMatrix<Scalar_t> output_gradient(fOutput.GetNrows(), fOutput.GetNcols());
127
128 // cell gradient
129 Hadamard(cell_gradient, fOutput);
130 Hadamard(cell_gradient, state_gradients_backward);
131 ScaleAdd(cell_gradient, cell_gradients_backward);
132 Copy(cell_gradients_backward, cell_gradient);
133 Hadamard(cell_gradients_backward, fForget);
134
135 // candidate gradient
136 Copy(candidate_gradient, cell_gradient);
137 Hadamard(candidate_gradient, fInput);
138 Hadamard(candidate_gradient, dc);
139
140 // input gate gradient
141 Copy(input_gate_gradient, cell_gradient);
142 Hadamard(input_gate_gradient, fCandidate);
143 Hadamard(input_gate_gradient, di);
144
145 // forget gradient
146 Copy(forget_gradient, cell_gradient);
147 Hadamard(forget_gradient, precCellActivations);
148 Hadamard(forget_gradient, df);
149
150 // output grdient
151 Copy(output_gradient, cell_tanh);
152 Hadamard(output_gradient, state_gradients_backward);
153 Hadamard(output_gradient, dout);
154
155 // input gradient
156 Multiply(tmpInp, input_gate_gradient, weights_input);
157 Copy(input_gradient, tmpInp);
158 Multiply(tmpInp, forget_gradient, weights_forget);
159 ScaleAdd(input_gradient, tmpInp);
160 Multiply(tmpInp, candidate_gradient, weights_candidate);
161 ScaleAdd(input_gradient, tmpInp);
162 Multiply(tmpInp, output_gradient, weights_output);
163 ScaleAdd(input_gradient, tmpInp);
164
165 // state gradient backwards
166 Multiply(tmpState, input_gate_gradient, weights_input_state);
167 Copy(state_gradients_backward, tmpState);
168 Multiply(tmpState, forget_gradient, weights_forget_state);
169 ScaleAdd(state_gradients_backward, tmpState);
170 Multiply(tmpState, candidate_gradient, weights_candidate_state);
171 ScaleAdd(state_gradients_backward, tmpState);
172 Multiply(tmpState, output_gradient, weights_output_state);
173 ScaleAdd(state_gradients_backward, tmpState);
174
175 // input weight gradient
176 TransposeMultiply(input_weight_gradients, input_gate_gradient, input, 1. , 1.); // H x B . B x D
177 TransposeMultiply(forget_weight_gradients, forget_gradient, input, 1. , 1.);
178 TransposeMultiply(candidate_weight_gradients, candidate_gradient, input, 1. , 1.);
179 TransposeMultiply(output_weight_gradients, output_gradient, input, 1. , 1.);
180
181 // state weight gradients
182 TransposeMultiply(input_state_weight_gradients, input_gate_gradient, precStateActivations, 1. , 1. ); // H x B . B x H
183 TransposeMultiply(forget_state_weight_gradients, forget_gradient, precStateActivations, 1. , 1. );
184 TransposeMultiply(candidate_state_weight_gradients, candidate_gradient, precStateActivations, 1. , 1. );
185 TransposeMultiply(output_state_weight_gradients, output_gradient, precStateActivations, 1. , 1. );
186
187 // bias gradient
188 SumColumns(input_bias_gradients, input_gate_gradient, 1., 1.);
189 SumColumns(forget_bias_gradients, forget_gradient, 1., 1.);
190 SumColumns(candidate_bias_gradients, candidate_gradient, 1., 1.);
191 SumColumns(output_bias_gradients, output_gradient, 1., 1.);
192
193 return input_gradient;
194}
195
196
197//______________________________________________________________________________
198template <typename Scalar_t>
199auto inline TCpu<Scalar_t>::GRULayerBackward(TCpuMatrix<Scalar_t> & state_gradients_backward,
200 TCpuMatrix<Scalar_t> & reset_weight_gradients,
201 TCpuMatrix<Scalar_t> & update_weight_gradients,
202 TCpuMatrix<Scalar_t> & candidate_weight_gradients,
203 TCpuMatrix<Scalar_t> & reset_state_weight_gradients,
204 TCpuMatrix<Scalar_t> & update_state_weight_gradients,
205 TCpuMatrix<Scalar_t> & candidate_state_weight_gradients,
206 TCpuMatrix<Scalar_t> & reset_bias_gradients,
207 TCpuMatrix<Scalar_t> & update_bias_gradients,
208 TCpuMatrix<Scalar_t> & candidate_bias_gradients,
212 const TCpuMatrix<Scalar_t> & precStateActivations,
213 const TCpuMatrix<Scalar_t> & fReset,
214 const TCpuMatrix<Scalar_t> & fUpdate,
215 const TCpuMatrix<Scalar_t> & fCandidate,
216 const TCpuMatrix<Scalar_t> & weights_reset,
217 const TCpuMatrix<Scalar_t> & weights_update,
218 const TCpuMatrix<Scalar_t> & weights_candidate,
219 const TCpuMatrix<Scalar_t> & weights_reset_state,
220 const TCpuMatrix<Scalar_t> & weights_update_state,
221 const TCpuMatrix<Scalar_t> & weights_candidate_state,
222 const TCpuMatrix<Scalar_t> & input,
223 TCpuMatrix<Scalar_t> & input_gradient,
224 bool resetGateAfter)
226{
227 // reset gradient
228 int r = fUpdate.GetNrows(), c = fUpdate.GetNcols();
229 TCpuMatrix<Scalar_t> reset_gradient(r, c);
230 Copy(reset_gradient, fUpdate);
231 for (size_t j = 0; j < (size_t)reset_gradient.GetNcols(); j++) {
232 for (size_t i = 0; i < (size_t)reset_gradient.GetNrows(); i++) {
233 reset_gradient(i, j) = 1 - reset_gradient(i, j);
234 }
235 }
236 Hadamard(reset_gradient, dc);
237 Hadamard(reset_gradient, state_gradients_backward);
238 TCpuMatrix<Scalar_t> tmpMul(r, c);
239
240 if (!resetGateAfter) {
241 // case resetGateAfter is false U * ( r * h)
242 // dr = h * (UT * dy)
243 Multiply(tmpMul, reset_gradient, weights_candidate_state);
244 Hadamard(tmpMul, precStateActivations);
245 } else {
246 // case true : r * ( U * h) --> dr = dy * (U * h)
247 MultiplyTranspose(tmpMul, precStateActivations, weights_candidate_state);
248 Hadamard(tmpMul, reset_gradient);
249 }
250 Hadamard(tmpMul, dr);
251 Copy(reset_gradient, tmpMul);
252
253 // update gradient
254 TCpuMatrix<Scalar_t> update_gradient(r, c); // H X 1
255 Copy(update_gradient, precStateActivations);
256 for (size_t j = 0; j < (size_t)update_gradient.GetNcols(); j++) {
257 for (size_t i = 0; i < (size_t)update_gradient.GetNrows(); i++) {
258 update_gradient(i, j) = update_gradient(i, j) - fCandidate(i, j);
259 }
260 }
261 Hadamard(update_gradient, du);
262 Hadamard(update_gradient, state_gradients_backward);
263
264 // candidate gradient
265 TCpuMatrix<Scalar_t> candidate_gradient(r, c);
266 Copy(candidate_gradient, fUpdate);
267 for (size_t j = 0; j < (size_t)candidate_gradient.GetNcols(); j++) {
268 for (size_t i = 0; i < (size_t)candidate_gradient.GetNrows(); i++) {
269 candidate_gradient(i, j) = 1 - candidate_gradient(i, j);
270 }
271 }
272 Hadamard(candidate_gradient, dc);
273 Hadamard(candidate_gradient, state_gradients_backward);
274
275 // calculating state gradient backwards term by term
276 // term 1
277 TCpuMatrix<Scalar_t> temp(r, c);
278 Copy(temp, state_gradients_backward);
279 TCpuMatrix<Scalar_t> term(r, c); // H X 1
280 Copy(term, fUpdate);
281 Hadamard(term, temp);
282 Copy(state_gradients_backward, term);
283
284 // term 2
285 Copy(term, precStateActivations);
286 Hadamard(term, du);
287 Hadamard(term, temp);
289 Multiply(var, term, weights_update_state);
290 Copy(term, var);
291 ScaleAdd(state_gradients_backward, term);
292
293 // term 3
294 Copy(term, fCandidate);
295 for (size_t j = 0; j < (size_t)term.GetNcols(); j++) {
296 for (size_t i = 0; i < (size_t)term.GetNrows(); i++) {
297 term(i, j) = -term(i, j);
298 }
299 }
300 Hadamard(term, du);
301 Hadamard(term, temp);
302 Multiply(var, term, weights_update_state);
303 Copy(term, var);
304 ScaleAdd(state_gradients_backward, term);
305
306 // term 4
307 Copy(term, fUpdate);
308 for (size_t j = 0; j < (size_t)term.GetNcols(); j++) {
309 for (size_t i = 0; i < (size_t)term.GetNrows(); i++) {
310 term(i, j) = 1 - term(i, j);
311 }
312 }
313 Hadamard(term, dc);
314 Hadamard(term, temp);
315
316 if (!resetGateAfter) {
317 // case resetGateAfter is false : U * ( r * h)
318 // dh = r * (UT * dy)
319 Multiply(var, term, weights_candidate_state);
320 Hadamard(var, fReset);
321 } else {
322 // case resetGateAfter = true
323 // dh = UT * ( r * dy )
324 Hadamard(term, fReset);
325 Multiply(var, term, weights_candidate_state);
326 }
327 //
328 Copy(term, var);
329 ScaleAdd(state_gradients_backward, term);
330
331 // term 5
332 Copy(term, fUpdate);
333 for (size_t j = 0; j < (size_t)term.GetNcols(); j++) {
334 for (size_t i = 0; i < (size_t)term.GetNrows(); i++) {
335 term(i, j) = 1 - term(i, j);
336 }
337 }
338 // here we re-compute dr (probably we could be more eficient)
339 Hadamard(term, dc);
340 Hadamard(term, temp);
341 if (!resetGateAfter) {
342 // case reset gate after = false
343 // recompute dr/dh (as above for dr): // dr = h * (UT * dy)
344 Multiply(var, term, weights_candidate_state);
345 Hadamard(var, precStateActivations);
346 } else {
347 // case = true dr = dy * (U * h)
348 MultiplyTranspose(var, precStateActivations, weights_candidate_state);
349 Hadamard(var, term);
350 }
351 Hadamard(var, dr);
352 Multiply(term, var, weights_reset_state);
353 ScaleAdd(state_gradients_backward, term);
354
355 // input gradients
356 TCpuMatrix<Scalar_t> tmpInp(input_gradient.GetNrows(), input_gradient.GetNcols());
357 Multiply(tmpInp, reset_gradient, weights_reset);
358 Copy(input_gradient, tmpInp);
359 Multiply(tmpInp, update_gradient, weights_update);
360 ScaleAdd(input_gradient, tmpInp);
361 Multiply(tmpInp, candidate_gradient, weights_candidate);
362 ScaleAdd(input_gradient, tmpInp);
363
364 // input weight gradients
365 TransposeMultiply(reset_weight_gradients, reset_gradient, input, 1., 1.); // H x B . B x D
366 TransposeMultiply(update_weight_gradients, update_gradient, input, 1., 1.);
367 TransposeMultiply(candidate_weight_gradients, candidate_gradient, input, 1., 1.);
368
369 // state weight gradients
370 TransposeMultiply(reset_state_weight_gradients, reset_gradient, precStateActivations, 1., 1.); // H x B . B x H
371 TransposeMultiply(update_state_weight_gradients, update_gradient, precStateActivations, 1., 1.);
372 TCpuMatrix<Scalar_t> tempvar(r, c);
373
374 // candidate weight gradients
375 // impl case reseyGateAfter = false
376 if (!resetGateAfter) {
377 // dU = ( h * r) * dy
378 Copy(tempvar, precStateActivations);
379 Hadamard(tempvar, fReset);
380 TransposeMultiply(candidate_state_weight_gradients, candidate_gradient, tempvar, 1., 1.);
381 } else {
382 // case resetAfter=true
383 // dU = h * ( r * dy)
384 Copy(tempvar, candidate_gradient);
385 Hadamard(tempvar, fReset);
386 TransposeMultiply(candidate_state_weight_gradients, tempvar, precStateActivations, 1., 1.);
387 }
388
389 // bias gradients
390 SumColumns(reset_bias_gradients, reset_gradient, 1., 1.); // could be probably do all here
391 SumColumns(update_bias_gradients, update_gradient, 1., 1.);
392 SumColumns(candidate_bias_gradients, candidate_gradient, 1., 1.);
393
394 return input_gradient;
395}
396
397} // namespace DNN
398} // namespace TMVA
ROOT::R::TRInterface & r
Definition Object.C:4
#define c(i)
Definition RSha256.hxx:101
The TCpuMatrix class.
Definition CpuMatrix.h:86
size_t GetNcols() const
Definition CpuMatrix.h:156
size_t GetNrows() const
Definition CpuMatrix.h:155
static Matrix_t & LSTMLayerBackward(TCpuMatrix< Scalar_t > &state_gradients_backward, TCpuMatrix< Scalar_t > &cell_gradients_backward, TCpuMatrix< Scalar_t > &input_weight_gradients, TCpuMatrix< Scalar_t > &forget_weight_gradients, TCpuMatrix< Scalar_t > &candidate_weight_gradients, TCpuMatrix< Scalar_t > &output_weight_gradients, TCpuMatrix< Scalar_t > &input_state_weight_gradients, TCpuMatrix< Scalar_t > &forget_state_weight_gradients, TCpuMatrix< Scalar_t > &candidate_state_weight_gradients, TCpuMatrix< Scalar_t > &output_state_weight_gradients, TCpuMatrix< Scalar_t > &input_bias_gradients, TCpuMatrix< Scalar_t > &forget_bias_gradients, TCpuMatrix< Scalar_t > &candidate_bias_gradients, TCpuMatrix< Scalar_t > &output_bias_gradients, TCpuMatrix< Scalar_t > &di, TCpuMatrix< Scalar_t > &df, TCpuMatrix< Scalar_t > &dc, TCpuMatrix< Scalar_t > &dout, const TCpuMatrix< Scalar_t > &precStateActivations, const TCpuMatrix< Scalar_t > &precCellActivations, const TCpuMatrix< Scalar_t > &fInput, const TCpuMatrix< Scalar_t > &fForget, const TCpuMatrix< Scalar_t > &fCandidate, const TCpuMatrix< Scalar_t > &fOutput, const TCpuMatrix< Scalar_t > &weights_input, const TCpuMatrix< Scalar_t > &weights_forget, const TCpuMatrix< Scalar_t > &weights_candidate, const TCpuMatrix< Scalar_t > &weights_output, const TCpuMatrix< Scalar_t > &weights_input_state, const TCpuMatrix< Scalar_t > &weights_forget_state, const TCpuMatrix< Scalar_t > &weights_candidate_state, const TCpuMatrix< Scalar_t > &weights_output_state, const TCpuMatrix< Scalar_t > &input, TCpuMatrix< Scalar_t > &input_gradient, TCpuMatrix< Scalar_t > &cell_gradient, TCpuMatrix< Scalar_t > &cell_tanh)
Backward pass for LSTM Network.
static Matrix_t & RecurrentLayerBackward(Matrix_t &state_gradients_backward, Matrix_t &input_weight_gradients, Matrix_t &state_weight_gradients, Matrix_t &bias_gradients, Matrix_t &df, const Matrix_t &state, const Matrix_t &weights_input, const Matrix_t &weights_state, const Matrix_t &input, Matrix_t &input_gradient)
Backward pass for Recurrent Networks.
static Matrix_t & GRULayerBackward(TCpuMatrix< Scalar_t > &state_gradients_backward, TCpuMatrix< Scalar_t > &reset_weight_gradients, TCpuMatrix< Scalar_t > &update_weight_gradients, TCpuMatrix< Scalar_t > &candidate_weight_gradients, TCpuMatrix< Scalar_t > &reset_state_weight_gradients, TCpuMatrix< Scalar_t > &update_state_weight_gradients, TCpuMatrix< Scalar_t > &candidate_state_weight_gradients, TCpuMatrix< Scalar_t > &reset_bias_gradients, TCpuMatrix< Scalar_t > &update_bias_gradients, TCpuMatrix< Scalar_t > &candidate_bias_gradients, TCpuMatrix< Scalar_t > &dr, TCpuMatrix< Scalar_t > &du, TCpuMatrix< Scalar_t > &dc, const TCpuMatrix< Scalar_t > &precStateActivations, const TCpuMatrix< Scalar_t > &fReset, const TCpuMatrix< Scalar_t > &fUpdate, const TCpuMatrix< Scalar_t > &fCandidate, const TCpuMatrix< Scalar_t > &weights_reset, const TCpuMatrix< Scalar_t > &weights_update, const TCpuMatrix< Scalar_t > &weights_candidate, const TCpuMatrix< Scalar_t > &weights_reset_state, const TCpuMatrix< Scalar_t > &weights_update_state, const TCpuMatrix< Scalar_t > &weights_candidate_state, const TCpuMatrix< Scalar_t > &input, TCpuMatrix< Scalar_t > &input_gradient, bool resetGateAfter)
Backward pass for GRU Network.
create variable transformations