Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_GRU.icc
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_GRU_I
2#define TMVA_SOFIE_ROPERATOR_GRU_I
3
4namespace TMVA {
5namespace Experimental {
6namespace SOFIE {
7
8template <typename T>
9auto ROperator_GRU<T>::TypeInference(std::vector<ETensorType> input)
10-> std::vector<ETensorType> {
11 ETensorType out = input[0];
12 return {out, out};
13}
14
15template<typename T>
16auto ROperator_GRU<T>::ShapeInference(std::vector<std::vector<size_t>> input)
17-> std::vector<std::vector<size_t>> {
18 size_t num_directions = input[1][0];
19 size_t hidden_size = input[1][1] / 3;
20 if (fAttrLayout == 0) {
21 size_t seq_length = input[0][0];
22 size_t batch_size = input[0][1];
23 std::vector<std::vector<size_t>> ret(
24 {{seq_length, num_directions, batch_size, hidden_size},
25 {num_directions, batch_size, hidden_size}});
26 return ret;
27 } else {
28 size_t batch_size = input[0][0];
29 size_t seq_length = input[0][1];
30 std::vector<std::vector<size_t>> ret(
31 {{batch_size, seq_length, num_directions, hidden_size},
32 {batch_size, num_directions, hidden_size}});
33 return ret;
34 }
35}
36
37template<typename T>
39 fUseSession = model.UseSession();
40 // Check the input and output tensors
41 if (!model.CheckIfTensorAlreadyExist(fNX)) {
42 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNX + " is not found in model.");
43 }
44 fShapeX = model.GetTensorShape(fNX);
45 if (fShapeX.size() != 3) {
46 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNX + " is not of 3 dimensions.");
47 }
48 if (!model.CheckIfTensorAlreadyExist(fNW)) {
49 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNW + " is not found in model.");
50 }
51 fShapeW = model.GetTensorShape(fNW);
52 if (fShapeW.size() != 3) {
53 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNW + " is not of 3 dimensions.");
54 }
55 if (!model.CheckIfTensorAlreadyExist(fNR)) {
56 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNR + " is not found in model.");
57 }
58 fShapeR = model.GetTensorShape(fNR);
59 if (fShapeR.size() != 3) {
60 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNR + " is not of 3 dimensions.");
61 }
62 if (!fNB.empty()) {
63 if (!model.CheckIfTensorAlreadyExist(fNB)) {
64 throw std::runtime_error("TMVA SOFIE GRU op input tensor " + fNB + " is not found in model.");
65 }
66 fShapeB = model.GetTensorShape(fNB);
67 if (fShapeB.size() != 2 && fShapeB.size() != 4) {
68 throw std::runtime_error("TMVA SOFIE GRU op input tensor " + fNB + " is not of 2 or 4 dimensions.");
69 }
70 if (fShapeB.size() == 2) {
71 // Broadcasting the bias
72 auto original_data = model.GetInitializedTensorData(fNB);
73 size_t num_directions = fShapeW[0];
74 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
75 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
76 if (fType == "float") {
77 float *original_bias = static_cast<float*>(original_data.get());
78 float *new_bias = new float[num_directions * 6 * seq_length * batch_size * fAttrHiddenSize];
79 for (size_t direction = 0; direction < num_directions; direction++) {
80 for (size_t i = 0; i < 6; i++) {
81 for (size_t seq = 0; seq < seq_length; seq++) {
82 for (size_t batch = 0; batch < batch_size; batch++) {
83 size_t bias_offset = direction * 6 * fAttrHiddenSize + i * fAttrHiddenSize;
84 size_t offset = direction * 6 * batch_size * seq_length * fAttrHiddenSize +
85 i * batch_size * seq_length * fAttrHiddenSize +
86 + seq *batch_size *fAttrHiddenSize + batch *fAttrHiddenSize;
87 std::copy(original_bias + bias_offset, original_bias + bias_offset + fAttrHiddenSize,
88 new_bias + offset);
89 }
90 }
91 }
92 }
93
94 std::vector<size_t> new_bias_shape = {num_directions, 6, seq_length, batch_size, fAttrHiddenSize};
95 std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<float[]>());
96 model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), new_bias_shape, new_bias_ptr);
97 fShapeB = model.GetTensorShape(fNB);
98 }
99 }
100 }
101 if (!fNSequence_lens.empty()) {
102 if (!model.CheckIfTensorAlreadyExist(fNSequence_lens)) {
103 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
104 fNSequence_lens +
105 "is not found in model.");
106 }
107 fShapeSequence_lens = model.GetTensorShape(fNSequence_lens);
108 if (fShapeSequence_lens.size() != 1) {
109 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
110 fNSequence_lens +
111 " is not of 1 dimension.");
112 }
113 }
114 if (!fNInitial_h.empty()) {
115 if (!model.CheckIfTensorAlreadyExist(fNInitial_h)) {
116 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
117 fNInitial_h + " is not found in model.");
118 }
119 fShapeInitial_h = model.GetTensorShape(fNInitial_h);
120 if (fShapeInitial_h.size() != 3) {
121 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
122 fNInitial_h + " is not of 3 dimensions.");
123 }
124 }
125 if (!fNY.empty()) {
126 fShapeY = ShapeInference({fShapeX, fShapeW})[0];
127 if (!model.CheckIfTensorAlreadyExist(fNY)) {
128 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
129 }
130 }
131 if (!fNY_h.empty()) {
132 fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
133 if (!model.CheckIfTensorAlreadyExist(fNY_h)) {
134 model.AddIntermediateTensor(fNY_h, model.GetTensorType(fNX), fShapeY_h);
135 }
136 }
137 // Check the attributes
138 for (auto &activation : fAttrActivations) {
139 if (activation != "Relu" && activation != "Tanh" &&
140 activation != "Sigmoid" && activation != "Affine" &&
141 activation != "LeakyRelu" && activation != "ThresholdRelu" &&
142 activation != "ScaledTanh" && activation != "HardSigmoid" &&
143 activation != "Elu" && activation != "Softsign" &&
144 activation != "Softplus") {
145 throw std::runtime_error("TMVA SOFIE - Activation function " +
146 activation + " not implemented");
147 }
148 }
149 if (fAttrDirection == "reverse") fAttrDirection = "backward";
150 if (fAttrDirection != "forward" && fAttrDirection != "backward" &&
151 fAttrDirection != "reverse" &&
152 fAttrDirection != "bidirectional") {
153 throw std::runtime_error(
154 "TMVA SOFIE - Invalid GRU direction fAttrDirection = " +
155 fAttrDirection);
156 }
157 if (3 * fAttrHiddenSize != fShapeW[1]) {
158 throw std::runtime_error(
159 "TMVA SOFIE - fAttrHiddenSize must be equal to " +
160 std::to_string(fShapeW[1] / 3));
161 }
162 if (fAttrLayout > 1) {
163 throw std::runtime_error("TMVA SOFIE - Layout fAttrLayout = " +
164 std::to_string(fAttrLayout) +
165 " must be 0 (timewise) or 1 (batchwise)");
166 }
167 if (fAttrLinearBeforeReset > 1) {
168 throw std::runtime_error(
169 "TMVA SOFIE - fAttrInputForget = " + std::to_string(fAttrLinearBeforeReset)
170 + " must be 0 or 1.");
171 }
172 if (fAttrActivations.empty()) {
173 if (fAttrDirection == "bidirectional") {
174 fAttrActivations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"};
175 } else {
176 fAttrActivations = {"Sigmoid", "Tanh"};
177 }
178 }
179}
180
181// generate code for Session data members (e.g. internal vectors)
182template <typename T>
183std::string ROperator_GRU<T>::GenerateSessionMembersCode(std::string opName)
184{
185 opName = "op_" + opName;
186 std::stringstream out;
187
188 size_t num_directions = fShapeW[0];
189 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
190 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
191 size_t input_size = fShapeX[2];
192
193 if (fAttrLayout != 0) {
194 out << "std::vector<" << fType << "> fVec_" << opName << "_input = std::vector<" << fType << ">("
195 << seq_length * batch_size * input_size << ");\n";
196 out << "std::vector<" << fType << "> fVec_" << opName << "_initial_hidden_state = std::vector<" << fType << ">("
197 << num_directions * batch_size * fAttrHiddenSize << ");\n";
198 out << "std::vector<" << fType << "> fVec_" << opName << "_initial_cell_state = std::vector<" << fType << ">("
199 << num_directions * batch_size * fAttrHiddenSize << ");\n";
200 }
201 // Set the feedforward
202 size_t ff_size = seq_length * batch_size * fAttrHiddenSize;
203 out << "std::vector<" << fType << "> fVec_" << opName << "_f_update_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
204 out << "std::vector<" << fType << "> fVec_" << opName << "_f_reset_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
205 out << "std::vector<" << fType << "> fVec_" << opName << "_f_hidden_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
206 // gate results
207 size_t hs_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
208 out << "std::vector<" << fType << "> fVec_" << opName << "_update_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
209 out << "std::vector<" << fType << "> fVec_" << opName << "_reset_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
210 out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
211
212 // feedback
213 out << "std::vector<" << fType << "> fVec_" << opName << "_feedback = std::vector<" << fType << ">("
214 << batch_size * fAttrHiddenSize << ");\n";
215
216 // hiddden state
217 if (fAttrLayout != 0 || fNY.empty()) {
218 out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_state = std::vector<" << fType << ">(" << hs_size << ");\n";
219 }
220
221 out << "\n";
222
223 return out.str();
224}
225
226
227template<typename T>
228auto ROperator_GRU<T>::Generate(std::string OpName)
229-> std::string {
230 OpName = "op_" + OpName;
231 std::stringstream out;
232
233 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
234 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
235 size_t input_size = fShapeX[2];
236 size_t num_directions = fShapeW[0];
237
238 // set the input
239 if (fAttrLayout == 0) {
240 out << SP << fType << " *" << OpName << "_input = tensor_" << fNX << ";\n";
241 } else {
242 if (fUseSession) {
243 out << SP << fType << " * " << OpName << "_input = fVec_" << OpName << "_input.data();\n";
244 } else {
245 out << SP << fType << " " << OpName << "_input[" << seq_length * batch_size * input_size << "];\n";
246 }
247 out << SP << "for(size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
248 out << SP << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
249 out << SP << SP << SP << "for(size_t i = 0; i < " << input_size << "; i++) {\n";
250 out << SP << SP << SP << SP << OpName << "_input[seq * " << batch_size * input_size
251 << " + batch * " << input_size << " + i] = " << "tensor_" << fNX << "[batch * "
252 << seq_length * input_size << " + seq * " << input_size << " + i];\n";
253 out << SP << SP << SP << "}\n";
254 out << SP << SP << "}\n";
255 out << SP << "}\n";
256 }
257
258 // Set the initial hidden state
259 if (!fNInitial_h.empty()) {
260 if (fAttrLayout == 0) {
261 out << SP << fType << " *" << OpName << "_initial_hidden_state = " << " tensor_"
262 << fNInitial_h << ";\n";
263 } else {
264 if (fUseSession) {
265 out << SP << fType << " * " << OpName << "_initial_hidden_state = fVec_" << OpName
266 << "_initial_hidden_state.data();\n";
267 } else {
268 out << SP << fType << " " << OpName << "_initial_hidden_state[" << num_directions * batch_size *
269 fAttrHiddenSize << "];\n";
270 }
271 for (size_t direction = 0; direction < num_directions; direction++) {
272 out << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
273 out << SP << SP << "for(size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
274 out << SP << SP << SP << OpName << "_initial_hidden_state["
275 << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
276 << " + h] = tensor_" << fNInitial_h << "[batch * " << num_directions * fAttrHiddenSize
277 << " + " << direction * fAttrHiddenSize << " + h];\n";
278 out << SP << SP << "}\n";
279 out << SP << "}\n";
280 }
281 }
282 }
283
284 // Set the feedforward
285 size_t feedforward_size = seq_length * batch_size * fAttrHiddenSize;
286 if (fUseSession) {
287 out << SP << fType << " * " << OpName << "_f_update_gate = fVec_" << OpName << "_f_update_gate.data();\n";
288 out << SP << fType << " * " << OpName << "_f_reset_gate = fVec_" << OpName << "_f_reset_gate.data();\n";
289 out << SP << fType << " * " << OpName << "_f_hidden_gate = fVec_" << OpName << "_f_hidden_gate.data();\n";
290 } else {
291 out << SP << fType << " " << OpName << "_f_update_gate[" << feedforward_size << "] = {0};\n";
292 out << SP << fType << " " << OpName << "_f_reset_gate[" << feedforward_size << "] = {0};\n";
293 out << SP << fType << " " << OpName << "_f_hidden_gate[" << feedforward_size << "] = {0};\n";
294 }
295 // Set the gates
296 size_t hidden_state_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
297 if (fUseSession) {
298 out << SP << fType << " * " << OpName << "_update_gate = fVec_" << OpName << "_update_gate.data();\n";
299 out << SP << fType << " * " << OpName << "_reset_gate = fVec_" << OpName << "_reset_gate.data();\n";
300 out << SP << fType << " * " << OpName << "_hidden_gate = fVec_" << OpName << "_hidden_gate.data();\n";
301 } else {
302 out << SP << fType << " " << OpName << "_update_gate[" << hidden_state_size << "] = {0};\n";
303 out << SP << fType << " " << OpName << "_reset_gate[" << hidden_state_size << "] = {0};\n";
304 out << SP << fType << " " << OpName << "_hidden_gate[" << hidden_state_size << "] = {0};\n";
305 }
306 // Set the hidden state
307 if (fAttrLayout == 0 && !fNY.empty()) {
308 out << SP << fType << " *" << OpName << "_hidden_state = tensor_" << fNY << ";\n";
309 } else {
310 if (fUseSession) {
311 out << SP << fType << " * " << OpName << "_hidden_state = fVec_" << OpName << "_hidden_state.data();\n";
312 } else {
313 out << SP << fType << " " << OpName << "_hidden_state[" << hidden_state_size << "] = {0};\n";
314 }
315 }
316
317 if (fUseSession) {
318 out << SP << fType << " * " << OpName << "_feedback = fVec_" << OpName << "_feedback.data();\n";
319 } else {
320 out << SP << fType << " " << OpName << "_feedback[" << batch_size * fAttrHiddenSize << "] = {0};\n";
321 }
322
323 out << SP << "char " << OpName << "_transA = 'N';\n";
324 out << SP << "char " << OpName << "_transB = 'T';\n";
325 out << SP << "int " << OpName << "_m = " << seq_length * batch_size << ";\n";
326 out << SP << "int " << OpName << "_m2 = " << batch_size << ";\n";
327 out << SP << "int " << OpName << "_n = " << fAttrHiddenSize << ";\n";
328 out << SP << "int " << OpName << "_k = " << input_size << ";\n";
329 if (fType == "float") {
330 out << SP << "float " << OpName << "_alpha = 1.;\n";
331 out << SP << "float " << OpName << "_beta = 0.;\n";
332 }
333 if (!fNB.empty()) {
334 out << SP << "int " << OpName << "_bias_size = " << seq_length * batch_size * fAttrHiddenSize << ";\n";
335 }
336 out << SP << "int " << OpName << "_incx = 1;\n";
337 out << SP << "int " << OpName << "_incy = 1;\n";
338 out << SP << "int " << OpName << "_feedback_size = " << batch_size * fAttrHiddenSize << ";\n";
339
340 for (size_t direction = 0; direction < num_directions; direction++) {
341 if (direction == 0) {
342 if (fType == "float") {
343 // f_update_gate = input * weight_z^T
344 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
345 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
346 << fNW << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &"
347 << OpName << "_beta, " << OpName << "_f_update_gate, &" << OpName << "_n);\n";
348 // f_reset_gate = input * weight_r^T
349 size_t wr_offset = fAttrHiddenSize * input_size;
350 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
351 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
352 << fNW << " + " << wr_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
353 << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_reset_gate, &" << OpName << "_n);\n";
354 // f_hidden_gate = input * weight_h^T
355 size_t wh_offset = 2 * fAttrHiddenSize * input_size;
356 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
357 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
358 << fNW << " + " << wh_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
359 << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_hidden_gate, &" << OpName << "_n);\n";
360 }
361 } else {
362 if (fType == "float") {
363 // f_update_gate = input * weight_z^T
364 size_t wz_offset = 3 * fAttrHiddenSize * input_size;
365 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
366 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
367 << fNW << " + " << wz_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
368 << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_update_gate, &" << OpName << "_n);\n";
369 // f_reset_gate = input * weight_r^T
370 size_t wr_offset = 3 * fAttrHiddenSize * input_size + fAttrHiddenSize * input_size;
371 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
372 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
373 << fNW << " + " << wr_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
374 << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_reset_gate, &" << OpName << "_n);\n";
375 // f_hidden_gate = input * weight_h^T
376 size_t wh_offset = 3 * fAttrHiddenSize * input_size + 2 * fAttrHiddenSize * input_size;
377 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
378 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
379 << fNW << " + " << wh_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
380 << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_hidden_gate, &" << OpName << "_n);\n";
381 }
382 }
383
384 if (!fNB.empty()) {
385 if (direction == 0) {
386 if (fType == "float") {
387 // Add the bias of the weight to f_update_gate
388 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
389 << fNB << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &" << OpName << "_incy);\n";
390 // Add the bias of the recurrence to f_update_gate
391 size_t rbz_offset = 3 * batch_size * seq_length * fAttrHiddenSize;
392 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
393 << fNB << " + " << rbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
394 << OpName << "_incy);\n";
395 // Add the bias of the weight to f_reset_gate
396 size_t wbr_offset = batch_size * seq_length * fAttrHiddenSize;
397 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
398 << fNB << " + " << wbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
399 << OpName << "_incy);\n";
400 // Add the bias of the recurrence to f_reset_gate
401 //size_t rbr_offset = fAttrHiddenSize * fAttrHiddenSize + 3 * batch_size * fAttrHiddenSize;
402 size_t rbr_offset = 4 * batch_size * seq_length * fAttrHiddenSize;
403 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
404 << fNB << " + " << rbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
405 << OpName << "_incy);\n";
406 // Add the bias of the weight to f_hidden_gate
407 size_t wbh_offset = 2 * batch_size * seq_length * fAttrHiddenSize;
408 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
409 << fNB << " + " << wbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &"
410 << OpName << "_incy);\n";
411 if (fAttrLinearBeforeReset == 0) {
412 // Add the bias of the recurrence to f_hidden_gate
413 size_t rbh_offset = 5 * batch_size * seq_length * fAttrHiddenSize;
414 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
415 << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName
416 << "_f_hidden_gate, &" << OpName << "_incy);\n";
417 }
418 }
419 } else {
420 if (fType == "float") {
421 // Add the bias of the weight to f_update_gate
422 size_t wbz_offset = 6 * batch_size * seq_length * fAttrHiddenSize;
423 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
424 << fNB << " + " << wbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
425 << OpName << "_incy);\n";
426 // Add the bias of the recurrence to f_update_gate
427 // size_t rbz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize + 3 * batch_size * fAttrHiddenSize;
428 size_t rbz_offset = 9 * batch_size * seq_length * fAttrHiddenSize;
429 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
430 << fNB << " + " << rbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
431 << OpName << "_incy);\n";
432 // Add the bias of the weight to f_reset_gate
433 size_t wbr_offset = 7 * batch_size * seq_length * fAttrHiddenSize;
434 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
435 << fNB << " + " << wbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
436 << OpName << "_incy);\n";
437 // Add the bias of the recurrence to f_reset_gate
438 size_t rbr_offset = 10 * batch_size * seq_length * fAttrHiddenSize;
439 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
440 << fNB << " + " << rbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
441 << OpName << "_incy);\n";
442 // Add the bias of the weight to f_hidden_gate
443 size_t wbh_offset = 8 * batch_size * seq_length * fAttrHiddenSize;
444 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
445 << fNB << " + " << wbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &"
446 << OpName << "_incy);\n";
447 if (fAttrLinearBeforeReset == 0) {
448 // Add the bias of the recurrence to f_hidden_gate
449 size_t rbh_offset = 11 * batch_size * seq_length * fAttrHiddenSize;
450 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
451 << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName
452 << "_f_hidden_gate, &" << OpName << "_incy);\n";
453 }
454 }
455 }
456 }
457
458 // Copy the feedforward into the gates
459 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
460 out << SP << SP << "size_t offset = seq * " << batch_size * fAttrHiddenSize << ";\n";
461 if (direction == 0) {
462 out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize
463 << ";\n";
464 } else {
465 out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize
466 << " + " << batch_size * fAttrHiddenSize << ";\n";
467 }
468 size_t f_seq_size = batch_size * fAttrHiddenSize;
469 out << SP << SP << "std::copy(" << OpName << "_f_update_gate + offset, " << OpName
470 << "_f_update_gate + offset + " << f_seq_size << ", " << OpName << "_update_gate + gate_offset);\n";
471 out << SP << SP << "std::copy(" << OpName << "_f_reset_gate + offset, " << OpName
472 << "_f_reset_gate + offset + " << f_seq_size << ", " << OpName << "_reset_gate + gate_offset);\n";
473 out << SP << SP << "std::copy(" << OpName << "_f_hidden_gate + offset, " << OpName
474 << "_f_hidden_gate + offset + " << f_seq_size << ", " << OpName << "_hidden_gate + gate_offset);\n";
475 out << SP << "}\n";
476
477 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
478 if (fAttrDirection == "backward" || direction == 1) {
479 out << SP << SP << "size_t index = " << seq_length - 1 << " - seq;\n";
480 } else {
481 out << SP << SP << "size_t index = seq;\n";
482 }
483 out << SP << SP << "int m2 = " << batch_size << ";\n";
484 if (direction == 0) {
485 out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize
486 << ";\n";
487 } else {
488 out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize
489 << " + " << batch_size * fAttrHiddenSize << ";\n";
490 }
491 size_t size = batch_size * fAttrHiddenSize;
492 // gate = gate + initial_hidden_state * Recurrence^T
493 out << SP << SP << "if (seq == 0) {\n";
494 if (!fNInitial_h.empty()) {
495 if (direction == 0) {
496 if (fType == "float") {
497 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
498 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &"
499 << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
500 << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
501 size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
502 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
503 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
504 << rr_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
505 << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
506 }
507 } else { // direction=1
508 if (fType == "float") {
509 size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
510 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
511 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
512 << rz_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
513 << "_n, &" << OpName << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
514 size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
515 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
516 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
517 << rr_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
518 << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
519 }
520 }
521 }
522 out << SP << SP << "} else {\n";
523 // gate = gate + previous_hidden_state * Recurrence^T
524 if (direction == 0) {
525 if (fAttrDirection == "backward") {
526 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
527 << num_directions * batch_size * fAttrHiddenSize << ";\n";
528 } else {
529 out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
530 << num_directions * batch_size * fAttrHiddenSize << ";\n";
531 }
532 if (fType == "float") {
533 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
534 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &"
535 << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
536 << OpName << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
537 size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
538 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
539 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
540 << rr_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
541 << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &"
542 << OpName << "_n);\n";
543 }
544 } else {
545 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
546 << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
547 if (fType == "float") {
548 size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
549 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
550 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
551 << rz_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
552 << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_update_gate + offset, &"
553 << OpName << "_n);\n";
554 size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
555 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
556 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
557 << rr_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
558 << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &"
559 << OpName << "_n);\n";
560 }
561 }
562 out << SP << SP << "}\n";
563
564 // Clip the elements of the update gate and the reset gate into the range [-fClip, fClip]
565 if (fAttrClip > .0) {
566 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
567 if (fType == "float") {
568 out << SP << SP << SP << "float z = (" << OpName << "_update_gate[i] > " << -fAttrClip
569 << ") ? " << OpName << "_update_gate[i] : " << -fAttrClip << ";\n";
570 }
571 out << SP << SP << SP << OpName << "_update_gate[i] = (z < " << fAttrClip
572 << ") ? z : " << fAttrClip << ";\n";
573 if (fType == "float") {
574 out << SP << SP << SP << "float r = (" << OpName << "_reset_gate[i] > " << -fAttrClip
575 << ") ? " << OpName << "_reset_gate[i] : " << -fAttrClip << ";\n";
576 }
577 out << SP << SP << SP << OpName << "_reset_gate[i] = (r < " << fAttrClip
578 << ") ? r : " << fAttrClip << ";\n";
579 out << SP << SP << "}\n";
580 }
581
582 // Apply the activation function to the update gate and the reset gate
583 if (fAttrActivations[direction * 2] == "Relu") {
584 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
585 out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
586 out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
587 out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
588 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
589 out << SP << SP << "}\n";
590 } else if (fAttrActivations[direction * 2] == "Tanh") {
591 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
592 if (fType == "float") {
593 out << SP << SP << SP << "float z = exp(-2 * " << OpName << "_update_gate[i]);\n";
594 }
595 out << SP << SP << SP << SP << OpName << "_update_gate[i] = (1. - z) / (1. + z);\n";
596 if (fType == "float") {
597 out << SP << SP << SP << "float r = exp(-2 * " << OpName << "_reset_gate[i]);\n";
598 }
599 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = (1. - r) / (1. + r);\n";
600 out << SP << SP << "}\n";
601 } else if (fAttrActivations[direction * 2] == "Sigmoid") {
602 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
603 out << SP << SP << SP << SP << OpName << "_update_gate[i] = 1. / (1. + exp(-"
604 << OpName << "_update_gate[i]));\n";
605 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 1. / (1. + exp(-"
606 << OpName << "_reset_gate[i]));\n";
607 out << SP << SP << "}\n";
608 } else if (fAttrActivations[direction * 2] == "Affine") {
609 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
610 out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
611 << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_update_gate[i] + "
612 << fAttrActivationBeta[direction * 2] << ";\n";
613 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
614 << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_reset_gate[i] + "
615 << fAttrActivationBeta[direction * 2] << ";\n";
616 out << SP << SP << "}\n";
617 } else if (fAttrActivations[direction * 2] == "ScaledTanh") {
618 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
619 if (fType == "float") {
620 out << SP << SP << SP << "float z = exp(-2 * " << fAttrActivationBeta[direction * 2]
621 << " * "<< OpName << "_update_gate[i]);\n";
622 }
623 out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
624 << fAttrActivationAlpha[direction * 2] << " * (1. - z) / (1. + z);\n";
625 if (fType == "float") {
626 out << SP << SP << SP << "float r = exp(-2 * " << fAttrActivationBeta[direction * 2]
627 << " * "<< OpName << "_reset_gate[i]);\n";
628 }
629 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
630 << fAttrActivationAlpha[direction * 2] << " * (1. - r) / (1. + r);\n";
631 out << SP << SP << "}\n";
632 } else if (fAttrActivations[direction * 2] == "HardSigmoid") {
633 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
634 if (fType == "float") {
635 out << SP << SP << SP << "float za = " << fAttrActivationAlpha[direction * 2] << " * "
636 << OpName << "_update_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
637 out << SP << SP << SP << "float zb = (za > 0.) ? za : 0.;\n";
638 }
639 out << SP << SP << SP << SP << OpName << "_update_gate[i] = (zb < 1.) ? zb : 1.;\n";
640 if (fType == "float") {
641 out << SP << SP << SP << "float ra = " << fAttrActivationAlpha[direction * 2] << " * "
642 << OpName << "_reset_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
643 out << SP << SP << SP << "float rb = (ra > 0.) ? ra : 0.;\n";
644 }
645 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = (rb < 1.) ? rb : 1.;\n";
646 out << SP << SP << "}\n";
647 } else if (fAttrActivations[direction * 2] == "LeakyRelu") {
648 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
649 out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
650 out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
651 << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_update_gate[i];\n";
652 out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
653 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
654 << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_reset_gate[i];\n";
655 out << SP << SP << "}\n";
656 } else if (fAttrActivations[direction * 2] == "ThresholdRelu") {
657 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
658 out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < "
659 << fAttrActivationAlpha[direction * 2] << ")\n";
660 out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
661 out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < "
662 << fAttrActivationAlpha[direction * 2] << ")\n";
663 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
664 out << SP << SP << "}";
665 } else if (fAttrActivations[direction * 2] == "Elu") {
666 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
667 out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
668 out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
669 << fAttrActivationAlpha[direction * 2] << " * exp(" << OpName << "_update_gate[i] - 1.);\n";
670 out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
671 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
672 << fAttrActivationAlpha[direction * 2] << " * exp(" << OpName << "_reset_gate[i] - 1.);\n";
673 out << SP << SP << "}\n";
674 } else if (fAttrActivations[direction * 2] == "Softsign") {
675 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
676 out << SP << SP << SP << SP << OpName << "_update_gate[i] = " << OpName
677 << "_update_gate[i] / (1. + abs(" << OpName << "_update_gate[i]));\n";
678 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = " << OpName
679 << "_reset_gate[i] / (1. + abs(" << OpName << "_reset_gate[i]));\n";
680 out << SP << SP << "}\n";
681 } else { // fAttrActivations[direction * 2] = Softplus
682 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
683 out << SP << SP << SP << SP << OpName << "_update_gate[i] = log(1. + exp("
684 << OpName << "_update_gate[i]));\n";
685 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = log(1. + exp("
686 << OpName << "_reset_gate[i]));\n";
687 out << SP << SP << "}\n";
688 }
689
690 if (fAttrLinearBeforeReset == 0) {
691 out << SP << SP << "if (seq == 0) {\n";
692 if (!fNInitial_h.empty()) {
693 // feedback = reset_gate o initial_hidden_state
694 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
695 out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName
696 << "_reset_gate[i + offset] * " << OpName << "_initial_hidden_state[i];\n";
697 out << SP << SP << SP << "}\n";
698 }
699 out << SP << SP << "} else {\n";
700 // feedback = reset_gate o previous_hidden_state
701 if (direction == 0) {
702 if (fAttrDirection == "backward") {
703 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
704 << num_directions * batch_size * fAttrHiddenSize << ";\n";
705 } else {
706 out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
707 << num_directions * batch_size * fAttrHiddenSize << ";\n";
708 }
709 } else {
710 out << SP << SP << SP << "size_t previous_offset = (index + 1) * " << num_directions
711 * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
712 }
713 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
714 out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName
715 << "_reset_gate[i + offset] * " << OpName << "_hidden_state[i + previous_offset];\n";
716 out << SP << SP << SP << "}\n";
717 out << SP << SP << "}\n";
718 // feedback = feedback * R_h^T
719 size_t rh_offset = (direction == 0) ?
720 2 * fAttrHiddenSize * fAttrHiddenSize : 3 * fAttrHiddenSize * fAttrHiddenSize
721 + 2 * fAttrHiddenSize * fAttrHiddenSize;
722 out << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
723 << OpName << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_"
724 << fNR << " + " << rh_offset << ", &" << OpName << "_n, " << OpName << "_feedback, &" << OpName
725 << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
726 } else { // fAttrLinearBeforeReset=1
727 // feedback = previous_hidden_state * R_h^T
728 //LM fixes
729 size_t rh_offset = (direction == 0)
730 ? 2 * fAttrHiddenSize * fAttrHiddenSize
731 : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
732 out << SP << SP << "if (seq == 0) {\n";
733 if (!fNInitial_h.empty()) {
734 // feedback = W * initial_hidden_state + bias
735 out << SP << SP << SP
736 << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
737 << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
738 << rh_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &"
739 << OpName << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
740 }
741 out << SP << SP << "} else {\n";
742 // case for seq > 0
743 if (direction == 0) {
744 if (fAttrDirection == "backward") {
745 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
746 << num_directions * batch_size * fAttrHiddenSize << ";\n";
747 } else {
748 out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
749 << num_directions * batch_size * fAttrHiddenSize << ";\n";
750 }
751 } else {
752 out << SP << SP << SP << "size_t previous_offset = (index + 1) * " << num_directions
753 * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
754 }
755 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
756 << OpName << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR
757 << " + " << rh_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
758 << OpName << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
759 // endif on seq 0 or not
760 out << SP << SP << "}\n";
761 // Add the bias of the recurrence to feedback
762 if (!fNB.empty()) {
763 size_t rbh_offset = (direction == 0) ? 5 * batch_size * seq_length * fAttrHiddenSize
764 : 11 * batch_size * seq_length * fAttrHiddenSize;
765 out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName
766 << "_alpha, tensor_" << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, "
767 << OpName << "_feedback, &" << OpName << "_incy);\n";
768 }
769 // feedback = reset_gate o feedback
770 out << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
771 out << SP << SP << SP << OpName << "_feedback[i] *= " << OpName << "_reset_gate[i + offset];\n";
772 out << SP << SP << "}\n";
773 }
774
775 // hidden_gate = hidden_gate + feedback
776 out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName << "_alpha, "
777 << OpName << "_feedback, &" << OpName << "_incx, " << OpName << "_hidden_gate + offset, &"
778 << OpName << "_incy);\n";
779
780 // Clip the elements of the hidden gate into the range [-fClip, fClip]
781 if (fAttrClip > .0) {
782 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
783 if (fType == "float") {
784 out << SP << SP << SP << "float x = (" << OpName << "_hidden_gate[i] > " << -fAttrClip
785 << ") ? " << OpName << "_hidden_gate[i] : " << -fAttrClip << ";\n";
786 }
787 out << SP << SP << SP << OpName << "_hidden_gate[i] = (x < " << fAttrClip << ") ? x : "
788 << fAttrClip << ";\n";
789 out << SP << SP << "}\n";
790 }
791
792 // Apply the activation function to the hidden gate
793 if (fAttrActivations[direction * 2 + 1] == "Relu") {
794 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
795 out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
796 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
797 out << SP << SP << "}\n";
798 } else if (fAttrActivations[direction * 2 + 1] == "Tanh") {
799 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
800 if (fType == "float") {
801 out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_hidden_gate[i]);\n";
802 }
803 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = (1. - ex) / (1. + ex);\n";
804 out << SP << SP << "}\n";
805 } else if (fAttrActivations[direction * 2 + 1] == "Sigmoid") {
806 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
807 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 1. / (1. + exp(-" << OpName
808 << "_hidden_gate[i]));\n";
809 out << SP << SP << "}\n";
810 } else if (fAttrActivations[direction * 2 + 1] == "Affine") {
811 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
812 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
813 << fAttrActivationAlpha[direction * 2 + 1] << " * " << OpName << "_hidden_gate[i] + "
814 << fAttrActivationBeta[direction * 2 + 1] << ";\n";
815 out << SP << SP << "}\n";
816 } else if (fAttrActivations[direction * 2 + 1] == "ScaledTanh") {
817 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
818 if (fType == "float") {
819 out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 2 + 1]
820 << " * "<< OpName << "_hidden_gate[i]);\n";
821 }
822 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
823 << fAttrActivationAlpha[direction * 2 + 1] << " * (1. - ex) / (1. + ex);\n";
824 out << SP << SP << "}\n";
825 } else if (fAttrActivations[direction * 2 + 1] == "HardSigmoid") {
826 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
827 if (fType == "float") {
828 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 2 + 1] << " * "
829 << OpName << "_hidden_gate[i] + " << fAttrActivationBeta[direction * 2 + 1] << ";\n";
830 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
831 }
832 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = (b < 1.) ? b : 1.;\n";
833 out << SP << SP << "}\n";
834 } else if (fAttrActivations[direction * 2 + 1] == "LeakyRelu") {
835 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
836 out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
837 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
838 << fAttrActivationAlpha[direction * 2 + 1] << " * " << OpName << "_hidden_gate[i];\n";
839 out << SP << SP << "}\n";
840 } else if (fAttrActivations[direction * 2 + 1] == "ThresholdRelu") {
841 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
842 out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < "
843 << fAttrActivationAlpha[direction * 2 + 1] << ")\n";
844 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
845 out << SP << SP << "}";
846 } else if (fAttrActivations[direction * 2 + 1] == "Elu") {
847 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
848 out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
849 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
850 << fAttrActivationAlpha[direction * 2 + 1] << " * exp(" << OpName << "_hidden_gate[i] - 1.);\n";
851 out << SP << SP << "}\n";
852 } else if (fAttrActivations[direction * 2 + 1] == "Softsign") {
853 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
854 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = " << OpName
855 << "_hidden_gate[i] / (1. + abs(" << OpName << "_hidden_gate[i]));\n";
856 out << SP << SP << "}\n";
857 } else { // fAttrActivations[direction * 2 + 1] = Softplus
858 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
859 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = log(1. + exp("
860 << OpName << "_hidden_gate[i]));\n";
861 out << SP << SP << "}\n";
862 }
863
864 // hidden_state = (1 - update_gate) o hidden_gate
865 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
866 out << SP << SP << SP << OpName << "_hidden_state[i] = ( 1. - " << OpName
867 << "_update_gate[i]) * " << OpName << "_hidden_gate[i];\n";
868 out << SP << SP << "}\n";
869
870 out << SP << SP << "if (seq == 0) {\n";
871 if (!fNInitial_h.empty()) {
872 // hidden_state += update_gate o initial_hidden_state
873 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
874 out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
875 << "_update_gate[i + offset] * " << OpName << "_initial_hidden_state[i];\n";
876 out << SP << SP << SP << "}\n";
877 }
878 out << SP << SP << "} else {\n";
879 // hidden_state += update_gate o previous_hidden_state
880 if (direction == 0) {
881 if (fAttrDirection == "backward") {
882 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
883 << num_directions * batch_size * fAttrHiddenSize << ";\n";
884 } else {
885 out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
886 << num_directions * batch_size * fAttrHiddenSize << ";\n";
887 }
888 } else {
889 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
890 << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
891 }
892 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
893 out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
894 << "_update_gate[i + offset] * " << OpName << "_hidden_state[i + previous_offset];\n";
895 out << SP << SP << SP << "}\n";
896 out << SP << SP << "}\n";
897
898 out << SP << "}\n";
899 }
900
901 // Padding the hidden state for GRU with different sequence lengths
902 if (!fNSequence_lens.empty()) {
903 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
904 out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
905 out << SP << SP << SP << "if (seq >= tensor_" << fNSequence_lens << "[batch]) {\n";
906 for (size_t direction = 0; direction < num_directions; direction++) {
907 out << SP << SP << SP << SP << SP << "for (size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
908 out << SP << SP << SP << SP << SP << SP << OpName << "_hidden_state[seq * "
909 << num_directions * batch_size * fAttrHiddenSize + direction * batch_size * fAttrHiddenSize
910 << " + batch * " << fAttrHiddenSize << " + h] = 0.;\n";
911 out << SP << SP << SP << SP << SP << "}\n";
912 }
913 out << SP << SP << SP << "}\n";
914 out << SP << SP << "}\n";
915 out << SP << "}\n";
916 }
917
918 // Copy the hidden state into y and y_h
919 if (fAttrLayout == 0) {
920 if (!fNY_h.empty()) {
921 // Copy hidden_state into Y_h
922 if (fNSequence_lens.empty()) {
923 size_t yh_size = batch_size * fAttrHiddenSize;
924 if (fAttrDirection == "backward") {
925 out << SP << "std::copy(" << OpName << "_hidden_state, " << OpName << "_hidden_state + "
926 << yh_size << ", tensor_" << fNY_h << ");\n";
927 } else {
928 size_t offset = (seq_length - 1) * num_directions * batch_size * fAttrHiddenSize;
929 out << SP << "std::copy(" << OpName << "_hidden_state + " << offset << ", " << OpName
930 << "_hidden_state + " << offset << " + " << yh_size << ", tensor_" << fNY_h << ");\n";
931 }
932 if (num_directions == 2) {
933 out << SP << "std::copy(" << OpName << "_hidden_state + " << yh_size << ", " << OpName
934 << "_hidden_state + " << 2 * yh_size << ", tensor_" << fNY_h << " + " << yh_size << ");\n";
935 }
936 } else { // GRU with different sequence lengths
937 if (fAttrDirection == "backward") {
938 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
939 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
940 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
941 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + offset);\n";
942 out << SP << "}\n";
943 } else {
944 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
945 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
946 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
947 << " + batch * " << fAttrHiddenSize << ";\n";
948 out << SP << SP << "size_t yh_offset = batch * " << fAttrHiddenSize << ";\n";
949 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
950 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
951 out << SP << "}\n";
952 }
953 if (num_directions == 2) {
954 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
955 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize
956 << " + batch * " << fAttrHiddenSize << ";\n";
957 out << SP << SP << "size_t yh_offset = " << batch_size * fAttrHiddenSize
958 << " + batch * " << fAttrHiddenSize << ";\n";
959 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
960 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
961 out << SP << "}\n";
962 }
963 }
964 }
965 } else { // fAttrLayout=1
966 if (!fNY.empty()) {
967 // Copy hidden_state into Y
968 for (size_t direction = 0; direction < num_directions; direction++) {
969 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
970 out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
971 out << SP << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
972 << " + " << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize << ";\n";
973 out << SP << SP << SP << "size_t y_offset = batch * " << seq_length * num_directions * fAttrHiddenSize
974 << " + seq * " << num_directions * fAttrHiddenSize << " + " << direction * fAttrHiddenSize << ";\n";
975 out << SP << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
976 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY << " + y_offset);\n";
977 out << SP << SP << "}\n";
978 out << SP << "}\n";
979 }
980 }
981 if (!fNY_h.empty()) {
982 // Copy the hidden_state into Y_h
983 if (fAttrDirection == "backward") {
984 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
985 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
986 out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
987 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
988 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
989 out << SP << "}\n";
990 } else {
991 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
992 if (fNSequence_lens.empty()) {
993 out << SP << SP << "size_t seq = " << seq_length - 1 << ";\n";
994 } else {
995 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
996 }
997 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
998 << " + batch * " << fAttrHiddenSize << ";\n";
999 out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
1000 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1001 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
1002 out << SP << "}\n";
1003 }
1004 if (num_directions == 2) {
1005 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1006 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * "
1007 << fAttrHiddenSize << ";\n";
1008 out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << " + "
1009 << fAttrHiddenSize << ";\n";
1010 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1011 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
1012 out << SP << "}\n";
1013 }
1014 }
1015 }
1016
1017 return out.str();
1018}
1019
1020} // namespace SOFIE
1021} // namespace Experimental
1022} // namespace TMVA
1023
1024#endif
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
const ETensorType & GetTensorType(std::string name)
Definition RModel.cxx:94
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:203
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:122
const std::vector< size_t > & GetTensorShape(std::string name)
Definition RModel.cxx:56
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
Definition RModel.cxx:264
void UpdateInitializedTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Definition RModel.cxx:255
std::string GenerateSessionMembersCode(std::string opName)
Generate the code for the Session internal data vectors.
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > >)
Infers the shape of the output tensors.
void Initialize(RModel &)
Initialize the model.
std::string Generate(std::string)
Generate the inference code.
std::vector< ETensorType > TypeInference(std::vector< ETensorType >)
Infers the type of the output tensors.
create variable transformations