ROperator_GRU.icc
#ifndef TMVA_SOFIE_ROPERATOR_GRU_I
#define TMVA_SOFIE_ROPERATOR_GRU_I

namespace TMVA {
namespace Experimental {
namespace SOFIE {

template <typename T>
auto ROperator_GRU<T>::TypeInference(std::vector<ETensorType> input)
-> std::vector<ETensorType> {
   ETensorType out = input[0];
   return {out, out};
}
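// Example (illustrative): TypeInference({ETensorType::FLOAT, ETensorType::FLOAT,
// ETensorType::FLOAT}) returns {FLOAT, FLOAT}; both outputs Y and Y_h inherit
// the element type of the input X.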

template<typename T>
auto ROperator_GRU<T>::ShapeInference(std::vector<std::vector<size_t>> input)
-> std::vector<std::vector<size_t>> {
   size_t num_directions = input[1][0];
   size_t hidden_size = input[1][1] / 3;
   if (fAttrLayout == 0) {
      size_t seq_length = input[0][0];
      size_t batch_size = input[0][1];
      std::vector<std::vector<size_t>> ret(
         {{seq_length, num_directions, batch_size, hidden_size},
          {num_directions, batch_size, hidden_size}});
      return ret;
   } else {
      size_t batch_size = input[0][0];
      size_t seq_length = input[0][1];
      std::vector<std::vector<size_t>> ret(
         {{batch_size, seq_length, num_directions, hidden_size},
          {batch_size, num_directions, hidden_size}});
      return ret;
   }
}
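// Example (illustrative): with X of shape {5, 2, 8} (seq_length = 5,
// batch_size = 2, input_size = 8) and W of shape {1, 3 * 4, 8}
// (num_directions = 1, hidden_size = 4), ShapeInference({shapeX, shapeW})
// returns {{5, 1, 2, 4}, {1, 2, 4}} when fAttrLayout == 0.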

template<typename T>
auto ROperator_GRU<T>::Initialize(RModel& model)
-> void {
   fUseSession = model.UseSession();
   // Check the input and output tensors
   if (!model.CheckIfTensorAlreadyExist(fNX)) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNX + " is not found in model.");
   }
   fShapeX = model.GetTensorShape(fNX);
   if (fShapeX.size() != 3) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNX + " is not of 3 dimensions.");
   }
   if (!model.CheckIfTensorAlreadyExist(fNW)) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNW + " is not found in model.");
   }
   fShapeW = model.GetTensorShape(fNW);
   if (fShapeW.size() != 3) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNW + " is not of 3 dimensions.");
   }
   if (!model.CheckIfTensorAlreadyExist(fNR)) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNR + " is not found in model.");
   }
   fShapeR = model.GetTensorShape(fNR);
   if (fShapeR.size() != 3) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNR + " is not of 3 dimensions.");
   }
   if (!fNB.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNB)) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNB + " is not found in model.");
      }
      fShapeB = model.GetTensorShape(fNB);
      if (fShapeB.size() != 2 && fShapeB.size() != 4) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNB + " is not of 2 or 4 dimensions.");
      }
      if (fShapeB.size() == 2) {
         // Broadcasting the bias
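         // The ONNX bias B has shape {num_directions, 6*hidden_size}, packing
         // [Wb_z, Wb_r, Wb_h, Rb_z, Rb_r, Rb_h]. It is broadcast here to
         // {num_directions, 6, seq_length, batch_size, hidden_size} so that each
         // sub-bias can later be added to a whole gate buffer with one saxpy call.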
         auto original_data = model.GetInitializedTensorData(fNB);
         size_t num_directions = fShapeW[0];
         size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
         size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
         if (fType == "float") {
            float *original_bias = static_cast<float*>(original_data.get());
            float *new_bias = new float[num_directions * 6 * seq_length * batch_size * fAttrHiddenSize];
            for (size_t direction = 0; direction < num_directions; direction++) {
               for (size_t i = 0; i < 6; i++) {
                  for (size_t seq = 0; seq < seq_length; seq++) {
                     for (size_t batch = 0; batch < batch_size; batch++) {
                        size_t bias_offset = direction * 6 * fAttrHiddenSize + i * fAttrHiddenSize;
                        size_t offset = direction * 6 * batch_size * seq_length * fAttrHiddenSize +
                                        i * batch_size * seq_length * fAttrHiddenSize +
                                        seq * batch_size * fAttrHiddenSize + batch * fAttrHiddenSize;
                        std::copy(original_bias + bias_offset, original_bias + bias_offset + fAttrHiddenSize,
                                  new_bias + offset);
                     }
                  }
               }
            }

            std::vector<size_t> new_bias_shape = {num_directions, 6, seq_length, batch_size, fAttrHiddenSize};
            std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<float[]>());
            model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), new_bias_shape, new_bias_ptr);
            fShapeB = model.GetTensorShape(fNB);
         }
      }
   }
   if (!fNSequence_lens.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNSequence_lens)) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNSequence_lens +
                                  " is not found in model.");
      }
      fShapeSequence_lens = model.GetTensorShape(fNSequence_lens);
      if (fShapeSequence_lens.size() != 1) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNSequence_lens +
                                  " is not of 1 dimension.");
      }
   }
   if (!fNInitial_h.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNInitial_h)) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNInitial_h + " is not found in model.");
      }
      fShapeInitial_h = model.GetTensorShape(fNInitial_h);
      if (fShapeInitial_h.size() != 3) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNInitial_h + " is not of 3 dimensions.");
      }
   }
   if (!fNY.empty()) {
      fShapeY = ShapeInference({fShapeX, fShapeW})[0];
      if (!model.CheckIfTensorAlreadyExist(fNY)) {
         model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
      }
   }
   if (!fNY_h.empty()) {
      fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
      if (!model.CheckIfTensorAlreadyExist(fNY_h)) {
         model.AddIntermediateTensor(fNY_h, model.GetTensorType(fNX), fShapeY_h);
      }
   }
   // Check the attributes
   for (auto &activation : fAttrActivations) {
      if (activation != "Relu" && activation != "Tanh" &&
          activation != "Sigmoid" && activation != "Affine" &&
          activation != "LeakyRelu" && activation != "ThresholdRelu" &&
          activation != "ScaledTanh" && activation != "HardSigmoid" &&
          activation != "Elu" && activation != "Softsign" &&
          activation != "Softplus") {
         throw std::runtime_error("TMVA SOFIE - Activation function " +
                                  activation + " not implemented");
      }
   }
   if (fAttrDirection != "forward" && fAttrDirection != "backward" &&
       fAttrDirection != "bidirectional") {
      throw std::runtime_error(
         "TMVA SOFIE - Invalid GRU direction fAttrDirection = " +
         fAttrDirection);
   }
   if (3 * fAttrHiddenSize != fShapeW[1]) {
      throw std::runtime_error(
         "TMVA SOFIE - fAttrHiddenSize must be equal to " +
         std::to_string(fShapeW[1] / 3));
   }
   if (fAttrLayout > 1) {
      throw std::runtime_error("TMVA SOFIE - Layout fAttrLayout = " +
                               std::to_string(fAttrLayout) +
                               " must be 0 (timewise) or 1 (batchwise)");
   }
   if (fAttrLinearBeforeReset > 1) {
      throw std::runtime_error(
         "TMVA SOFIE - fAttrLinearBeforeReset = " + std::to_string(fAttrLinearBeforeReset)
         + " must be 0 or 1.");
   }
   if (fAttrActivations.empty()) {
      if (fAttrDirection == "bidirectional") {
         fAttrActivations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"};
      } else {
         fAttrActivations = {"Sigmoid", "Tanh"};
      }
   }
}

// generate code for Session data members (e.g. internal vectors)
template <typename T>
std::string ROperator_GRU<T>::GenerateSessionMembersCode(std::string opName)
{
   opName = "op_" + opName;
   std::stringstream out;

   size_t num_directions = fShapeW[0];
   size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
   size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
   size_t input_size = fShapeX[2];

   if (fAttrLayout != 0) {
      out << "std::vector<" << fType << "> fVec_" << opName << "_input = std::vector<" << fType << ">("
          << seq_length * batch_size * input_size << ");\n";
      out << "std::vector<" << fType << "> fVec_" << opName << "_initial_hidden_state = std::vector<" << fType << ">("
          << num_directions * batch_size * fAttrHiddenSize << ");\n";
      out << "std::vector<" << fType << "> fVec_" << opName << "_initial_cell_state = std::vector<" << fType << ">("
          << num_directions * batch_size * fAttrHiddenSize << ");\n";
   }
   // Set the feedforward
   size_t ff_size = seq_length * batch_size * fAttrHiddenSize;
   out << "std::vector<" << fType << "> fVec_" << opName << "_f_update_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_f_reset_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_f_hidden_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
   // gate results
   size_t hs_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
   out << "std::vector<" << fType << "> fVec_" << opName << "_update_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_reset_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_gate = std::vector<" << fType << ">(" << hs_size << ");\n";

   // feedback
   out << "std::vector<" << fType << "> fVec_" << opName << "_feedback = std::vector<" << fType << ">("
       << batch_size * fAttrHiddenSize << ");\n";

   // hidden state
   if (fAttrLayout != 0 || fNY.empty()) {
      out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_state = std::vector<" << fType << ">(" << hs_size << ");\n";
   }

   out << "\n";

   return out.str();
}
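// Example of the emitted members (illustrative): for opName = "0",
// fType = "float", seq_length = 5, batch_size = 2, hidden_size = 4 and one
// direction, the generated Session data members include
//   std::vector<float> fVec_op_0_f_update_gate = std::vector<float>(40);
//   std::vector<float> fVec_op_0_update_gate = std::vector<float>(40);
//   std::vector<float> fVec_op_0_feedback = std::vector<float>(8);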
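// GRU cell equations implemented by the generated code (ONNX conventions,
// with o denoting the element-wise product, f/g the gate activations):
//   z_t = f(X_t W_z^T + H_{t-1} R_z^T + Wb_z + Rb_z)            (update gate)
//   r_t = f(X_t W_r^T + H_{t-1} R_r^T + Wb_r + Rb_r)            (reset gate)
//   h_t = g(X_t W_h^T + (r_t o H_{t-1}) R_h^T + Rb_h + Wb_h)    (linear_before_reset = 0)
//   h_t = g(X_t W_h^T + r_t o (H_{t-1} R_h^T + Rb_h) + Wb_h)    (linear_before_reset = 1)
//   H_t = (1 - z_t) o h_t + z_t o H_{t-1}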
template<typename T>
auto ROperator_GRU<T>::Generate(std::string OpName)
-> std::string {
   OpName = "op_" + OpName;
   std::stringstream out;

   size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
   size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
   size_t input_size = fShapeX[2];
   size_t num_directions = fShapeW[0];

   // set the input
   if (fAttrLayout == 0) {
      out << SP << fType << " *" << OpName << "_input = tensor_" << fNX << ";\n";
   } else {
      if (fUseSession) {
         out << SP << fType << " * " << OpName << "_input = fVec_" << OpName << "_input.data();\n";
      } else {
         out << SP << fType << " " << OpName << "_input[" << seq_length * batch_size * input_size << "];\n";
      }
      out << SP << "for(size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
      out << SP << SP << SP << "for(size_t i = 0; i < " << input_size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_input[seq * " << batch_size * input_size
          << " + batch * " << input_size << " + i] = " << "tensor_" << fNX << "[batch * "
          << seq_length * input_size << " + seq * " << input_size << " + i];\n";
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";
      out << SP << "}\n";
   }
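   // For fAttrLayout == 1 the input arrives batch-major, {batch_size,
   // seq_length, input_size}; the loop emitted above repacks it time-major,
   // {seq_length, batch_size, input_size}, so that a single GEMM per gate can
   // process all timesteps at once.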

   // Set the initial hidden state
   if (!fNInitial_h.empty()) {
      if (fAttrLayout == 0) {
         out << SP << fType << " *" << OpName << "_initial_hidden_state = tensor_"
             << fNInitial_h << ";\n";
      } else {
         if (fUseSession) {
            out << SP << fType << " * " << OpName << "_initial_hidden_state = fVec_" << OpName
                << "_initial_hidden_state.data();\n";
         } else {
            out << SP << fType << " " << OpName << "_initial_hidden_state[" << num_directions * batch_size *
               fAttrHiddenSize << "];\n";
         }
         for (size_t direction = 0; direction < num_directions; direction++) {
            out << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "for(size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
            out << SP << SP << SP << OpName << "_initial_hidden_state["
                << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
                << " + h] = tensor_" << fNInitial_h << "[batch * " << num_directions * fAttrHiddenSize
                << " + " << direction * fAttrHiddenSize << " + h];\n";
            out << SP << SP << "}\n";
            out << SP << "}\n";
         }
      }
   }

   // Set the feedforward
   size_t feedforward_size = seq_length * batch_size * fAttrHiddenSize;
   if (fUseSession) {
      out << SP << fType << " * " << OpName << "_f_update_gate = fVec_" << OpName << "_f_update_gate.data();\n";
      out << SP << fType << " * " << OpName << "_f_reset_gate = fVec_" << OpName << "_f_reset_gate.data();\n";
      out << SP << fType << " * " << OpName << "_f_hidden_gate = fVec_" << OpName << "_f_hidden_gate.data();\n";
   } else {
      out << SP << fType << " " << OpName << "_f_update_gate[" << feedforward_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_f_reset_gate[" << feedforward_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_f_hidden_gate[" << feedforward_size << "] = {0};\n";
   }
   // Set the gates
   size_t hidden_state_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
   if (fUseSession) {
      out << SP << fType << " * " << OpName << "_update_gate = fVec_" << OpName << "_update_gate.data();\n";
      out << SP << fType << " * " << OpName << "_reset_gate = fVec_" << OpName << "_reset_gate.data();\n";
      out << SP << fType << " * " << OpName << "_hidden_gate = fVec_" << OpName << "_hidden_gate.data();\n";
   } else {
      out << SP << fType << " " << OpName << "_update_gate[" << hidden_state_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_reset_gate[" << hidden_state_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_hidden_gate[" << hidden_state_size << "] = {0};\n";
   }
   // Set the hidden state
   if (fAttrLayout == 0 && !fNY.empty()) {
      out << SP << fType << " *" << OpName << "_hidden_state = tensor_" << fNY << ";\n";
   } else {
      if (fUseSession) {
         out << SP << fType << " * " << OpName << "_hidden_state = fVec_" << OpName << "_hidden_state.data();\n";
      } else {
         out << SP << fType << " " << OpName << "_hidden_state[" << hidden_state_size << "] = {0};\n";
      }
   }

   if (fUseSession) {
      out << SP << fType << " * " << OpName << "_feedback = fVec_" << OpName << "_feedback.data();\n";
   } else {
      out << SP << fType << " " << OpName << "_feedback[" << batch_size * fAttrHiddenSize << "] = {0};\n";
   }

   out << SP << "char " << OpName << "_transA = 'N';\n";
   out << SP << "char " << OpName << "_transB = 'T';\n";
   out << SP << "int " << OpName << "_m = " << seq_length * batch_size << ";\n";
   out << SP << "int " << OpName << "_m2 = " << batch_size << ";\n";
   out << SP << "int " << OpName << "_n = " << fAttrHiddenSize << ";\n";
   out << SP << "int " << OpName << "_k = " << input_size << ";\n";
   if (fType == "float") {
      out << SP << "float " << OpName << "_alpha = 1.;\n";
      out << SP << "float " << OpName << "_beta = 0.;\n";
   }
   if (!fNB.empty()) {
      out << SP << "int " << OpName << "_bias_size = " << seq_length * batch_size * fAttrHiddenSize << ";\n";
   }
   out << SP << "int " << OpName << "_incx = 1;\n";
   out << SP << "int " << OpName << "_incy = 1;\n";
   out << SP << "int " << OpName << "_feedback_size = " << batch_size * fAttrHiddenSize << ";\n";

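   // The weight tensor W has shape {num_directions, 3*hidden_size, input_size},
   // with the z, r and h blocks stored consecutively; hence the per-gate
   // offsets below are multiples of hidden_size * input_size, shifted by
   // 3 * hidden_size * input_size for the backward direction.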
   for (size_t direction = 0; direction < num_directions; direction++) {
      if (direction == 0) {
         if (fType == "float") {
            // f_update_gate = input * weight_z^T
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &"
                << OpName << "_beta, " << OpName << "_f_update_gate, &" << OpName << "_n);\n";
            // f_reset_gate = input * weight_r^T
            size_t wr_offset = fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wr_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_reset_gate, &" << OpName << "_n);\n";
            // f_hidden_gate = input * weight_h^T
            size_t wh_offset = 2 * fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wh_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_hidden_gate, &" << OpName << "_n);\n";
         }
      } else {
         if (fType == "float") {
            // f_update_gate = input * weight_z^T
            size_t wz_offset = 3 * fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wz_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_update_gate, &" << OpName << "_n);\n";
            // f_reset_gate = input * weight_r^T
            size_t wr_offset = 3 * fAttrHiddenSize * input_size + fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wr_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_reset_gate, &" << OpName << "_n);\n";
            // f_hidden_gate = input * weight_h^T
            size_t wh_offset = 3 * fAttrHiddenSize * input_size + 2 * fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wh_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_hidden_gate, &" << OpName << "_n);\n";
         }
      }

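      // After the broadcast done in Initialize, the bias tensor holds the six
      // sub-biases [Wb_z, Wb_r, Wb_h, Rb_z, Rb_r, Rb_h] per direction, each of
      // size seq_length * batch_size * hidden_size, so the offsets below
      // advance in multiples of that block size.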
      if (!fNB.empty()) {
         if (direction == 0) {
            if (fType == "float") {
               // Add the bias of the weight to f_update_gate
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &" << OpName << "_incy);\n";
               // Add the bias of the recurrence to f_update_gate
               size_t rbz_offset = 3 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the weight to f_reset_gate
               size_t wbr_offset = batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the recurrence to f_reset_gate
               size_t rbr_offset = 4 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the weight to f_hidden_gate
               size_t wbh_offset = 2 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &"
                   << OpName << "_incy);\n";
               if (fAttrLinearBeforeReset == 0) {
                  // Add the bias of the recurrence to f_hidden_gate
                  size_t rbh_offset = 5 * batch_size * seq_length * fAttrHiddenSize;
                  out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                      << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName
                      << "_f_hidden_gate, &" << OpName << "_incy);\n";
               }
            }
         } else {
            if (fType == "float") {
               // Add the bias of the weight to f_update_gate
               size_t wbz_offset = 6 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the recurrence to f_update_gate
               size_t rbz_offset = 9 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the weight to f_reset_gate
               size_t wbr_offset = 7 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the recurrence to f_reset_gate
               size_t rbr_offset = 10 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the weight to f_hidden_gate
               size_t wbh_offset = 8 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &"
                   << OpName << "_incy);\n";
               if (fAttrLinearBeforeReset == 0) {
                  // Add the bias of the recurrence to f_hidden_gate
                  size_t rbh_offset = 11 * batch_size * seq_length * fAttrHiddenSize;
                  out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                      << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName
                      << "_f_hidden_gate, &" << OpName << "_incy);\n";
               }
            }
         }
      }

      // Copy the feedforward into the gates
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "size_t offset = seq * " << batch_size * fAttrHiddenSize << ";\n";
      if (direction == 0) {
         out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize
             << ";\n";
      } else {
         out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize
             << " + " << batch_size * fAttrHiddenSize << ";\n";
      }
      size_t f_seq_size = batch_size * fAttrHiddenSize;
      out << SP << SP << "std::copy(" << OpName << "_f_update_gate + offset, " << OpName
          << "_f_update_gate + offset + " << f_seq_size << ", " << OpName << "_update_gate + gate_offset);\n";
      out << SP << SP << "std::copy(" << OpName << "_f_reset_gate + offset, " << OpName
          << "_f_reset_gate + offset + " << f_seq_size << ", " << OpName << "_reset_gate + gate_offset);\n";
      out << SP << SP << "std::copy(" << OpName << "_f_hidden_gate + offset, " << OpName
          << "_f_hidden_gate + offset + " << f_seq_size << ", " << OpName << "_hidden_gate + gate_offset);\n";
      out << SP << "}\n";

      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      if (fAttrDirection == "backward" || direction == 1) {
         out << SP << SP << "size_t index = " << seq_length - 1 << " - seq;\n";
      } else {
         out << SP << SP << "size_t index = seq;\n";
      }
      out << SP << SP << "int m2 = " << batch_size << ";\n";
      if (direction == 0) {
         out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize
             << ";\n";
      } else {
         out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize
             << " + " << batch_size * fAttrHiddenSize << ";\n";
      }
      size_t size = batch_size * fAttrHiddenSize;
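      // The recurrence tensor R has shape {num_directions, 3*hidden_size,
      // hidden_size}; the z, r and h blocks give offsets that are multiples of
      // hidden_size * hidden_size, shifted by 3 * hidden_size * hidden_size
      // for the backward direction.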
      // gate = gate + initial_hidden_state * Recurrence^T
      out << SP << SP << "if (seq == 0) {\n";
      if (!fNInitial_h.empty()) {
         if (direction == 0) {
            if (fType == "float") {
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &"
                   << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
                   << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
               size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                   << rr_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                   << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
            }
         } else { // direction=1
            if (fType == "float") {
               size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                   << rz_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                   << "_n, &" << OpName << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
               size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                   << rr_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                   << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
            }
         }
      }
      out << SP << SP << "} else {\n";
      // gate = gate + previous_hidden_state * Recurrence^T
      if (direction == 0) {
         if (fAttrDirection == "backward") {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         } else {
            out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         }
         if (fType == "float") {
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &"
                << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
                << OpName << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
            size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rr_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
                << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &"
                << OpName << "_n);\n";
         }
      } else {
         out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
             << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         if (fType == "float") {
            size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rz_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
                << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_update_gate + offset, &"
                << OpName << "_n);\n";
            size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rr_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
                << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &"
                << OpName << "_n);\n";
         }
      }
      out << SP << SP << "}\n";

      // Clip the elements of the update gate and the reset gate into the range [-fAttrClip, fAttrClip]
      if (fAttrClip > .0) {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float z = (" << OpName << "_update_gate[i] > " << -fAttrClip
                << ") ? " << OpName << "_update_gate[i] : " << -fAttrClip << ";\n";
         }
         out << SP << SP << SP << OpName << "_update_gate[i] = (z < " << fAttrClip
             << ") ? z : " << fAttrClip << ";\n";
         if (fType == "float") {
            out << SP << SP << SP << "float r = (" << OpName << "_reset_gate[i] > " << -fAttrClip
                << ") ? " << OpName << "_reset_gate[i] : " << -fAttrClip << ";\n";
         }
         out << SP << SP << SP << OpName << "_reset_gate[i] = (r < " << fAttrClip
             << ") ? r : " << fAttrClip << ";\n";
         out << SP << SP << "}\n";
      }

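      // fAttrActivations stores the activation pair (f, g) per direction:
      // index direction * 2 is the gate activation f applied to the update and
      // reset gates, index direction * 2 + 1 is the activation g applied to the
      // hidden gate (defaults: f = Sigmoid, g = Tanh).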
      // Apply the activation function to the update gate and the reset gate
      if (fAttrActivations[direction * 2] == "Relu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Tanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float z = exp(-2 * " << OpName << "_update_gate[i]);\n";
         }
         out << SP << SP << SP << OpName << "_update_gate[i] = (1. - z) / (1. + z);\n";
         if (fType == "float") {
            out << SP << SP << SP << "float r = exp(-2 * " << OpName << "_reset_gate[i]);\n";
         }
         out << SP << SP << SP << OpName << "_reset_gate[i] = (1. - r) / (1. + r);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Sigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_update_gate[i] = 1. / (1. + exp(-"
             << OpName << "_update_gate[i]));\n";
         out << SP << SP << SP << OpName << "_reset_gate[i] = 1. / (1. + exp(-"
             << OpName << "_reset_gate[i]));\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Affine") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_update_gate[i] + "
             << fAttrActivationBeta[direction * 2] << ";\n";
         out << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_reset_gate[i] + "
             << fAttrActivationBeta[direction * 2] << ";\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "ScaledTanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float z = exp(-2 * " << fAttrActivationBeta[direction * 2]
                << " * " << OpName << "_update_gate[i]);\n";
         }
         out << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * (1. - z) / (1. + z);\n";
         if (fType == "float") {
            out << SP << SP << SP << "float r = exp(-2 * " << fAttrActivationBeta[direction * 2]
                << " * " << OpName << "_reset_gate[i]);\n";
         }
         out << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * (1. - r) / (1. + r);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "HardSigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float za = " << fAttrActivationAlpha[direction * 2] << " * "
                << OpName << "_update_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
            out << SP << SP << SP << "float zb = (za > 0.) ? za : 0.;\n";
         }
         out << SP << SP << SP << OpName << "_update_gate[i] = (zb < 1.) ? zb : 1.;\n";
         if (fType == "float") {
            out << SP << SP << SP << "float ra = " << fAttrActivationAlpha[direction * 2] << " * "
                << OpName << "_reset_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
            out << SP << SP << SP << "float rb = (ra > 0.) ? ra : 0.;\n";
         }
         out << SP << SP << SP << OpName << "_reset_gate[i] = (rb < 1.) ? rb : 1.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "LeakyRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_update_gate[i];\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_reset_gate[i];\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "ThresholdRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < "
             << fAttrActivationAlpha[direction * 2] << ")\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < "
             << fAttrActivationAlpha[direction * 2] << ")\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Elu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * (exp(" << OpName << "_update_gate[i]) - 1.);\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * (exp(" << OpName << "_reset_gate[i]) - 1.);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Softsign") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_update_gate[i] = " << OpName
             << "_update_gate[i] / (1. + abs(" << OpName << "_update_gate[i]));\n";
         out << SP << SP << SP << OpName << "_reset_gate[i] = " << OpName
             << "_reset_gate[i] / (1. + abs(" << OpName << "_reset_gate[i]));\n";
         out << SP << SP << "}\n";
      } else { // fAttrActivations[direction * 2] = Softplus
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_update_gate[i] = log(1. + exp("
             << OpName << "_update_gate[i]));\n";
         out << SP << SP << SP << OpName << "_reset_gate[i] = log(1. + exp("
             << OpName << "_reset_gate[i]));\n";
         out << SP << SP << "}\n";
      }

      if (fAttrLinearBeforeReset == 0) {
         out << SP << SP << "if (seq == 0) {\n";
         if (!fNInitial_h.empty()) {
            // feedback = reset_gate o initial_hidden_state
            out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
            out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName
                << "_reset_gate[i + offset] * " << OpName << "_initial_hidden_state[i];\n";
            out << SP << SP << SP << "}\n";
         }
         out << SP << SP << "} else {\n";
         // feedback = reset_gate o previous_hidden_state
         if (direction == 0) {
            if (fAttrDirection == "backward") {
               out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            } else {
               out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            }
         } else {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         }
         out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName
             << "_reset_gate[i + offset] * " << OpName << "_hidden_state[i + previous_offset];\n";
         out << SP << SP << SP << "}\n";
         out << SP << SP << "}\n";
         // feedback = feedback * R_h^T
         size_t rh_offset = (direction == 0)
            ? 2 * fAttrHiddenSize * fAttrHiddenSize
            : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
         out << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
             << OpName << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_"
             << fNR << " + " << rh_offset << ", &" << OpName << "_n, " << OpName << "_feedback, &" << OpName
             << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
      } else { // fAttrLinearBeforeReset=1
         // feedback = previous_hidden_state * R_h^T
         size_t rh_offset = (direction == 0)
            ? 2 * fAttrHiddenSize * fAttrHiddenSize
            : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
         out << SP << SP << "if (seq == 0) {\n";
         if (!fNInitial_h.empty()) {
            // feedback = initial_hidden_state * R_h^T
            out << SP << SP << SP
                << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
                << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rh_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &"
                << OpName << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
         }
         out << SP << SP << "} else {\n";
         // case for seq > 0
         if (direction == 0) {
            if (fAttrDirection == "backward") {
               out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            } else {
               out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            }
         } else {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         }
         out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
             << OpName << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR
             << " + " << rh_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
             << OpName << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
         // end of the seq == 0 branch
         out << SP << SP << "}\n";
         // Add the bias of the recurrence to feedback
         if (!fNB.empty()) {
            size_t rbh_offset = (direction == 0) ? 5 * batch_size * seq_length * fAttrHiddenSize
                                                 : 11 * batch_size * seq_length * fAttrHiddenSize;
            out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName
                << "_alpha, tensor_" << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, "
                << OpName << "_feedback, &" << OpName << "_incy);\n";
         }
         // feedback = reset_gate o feedback
         out << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_feedback[i] *= " << OpName << "_reset_gate[i + offset];\n";
         out << SP << SP << "}\n";
      }

      // hidden_gate = hidden_gate + feedback
      out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName << "_alpha, "
          << OpName << "_feedback, &" << OpName << "_incx, " << OpName << "_hidden_gate + offset, &"
          << OpName << "_incy);\n";

      // Clip the elements of the hidden gate into the range [-fAttrClip, fAttrClip]
      if (fAttrClip > .0) {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float x = (" << OpName << "_hidden_gate[i] > " << -fAttrClip
                << ") ? " << OpName << "_hidden_gate[i] : " << -fAttrClip << ";\n";
         }
         out << SP << SP << SP << OpName << "_hidden_gate[i] = (x < " << fAttrClip << ") ? x : "
             << fAttrClip << ";\n";
         out << SP << SP << "}\n";
      }

      // Apply the activation function to the hidden gate
      if (fAttrActivations[direction * 2 + 1] == "Relu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Tanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_hidden_gate[i]);\n";
         }
         out << SP << SP << SP << OpName << "_hidden_gate[i] = (1. - ex) / (1. + ex);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Sigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_hidden_gate[i] = 1. / (1. + exp(-" << OpName
             << "_hidden_gate[i]));\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Affine") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * " << OpName << "_hidden_gate[i] + "
             << fAttrActivationBeta[direction * 2 + 1] << ";\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "ScaledTanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 2 + 1]
                << " * " << OpName << "_hidden_gate[i]);\n";
         }
         out << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * (1. - ex) / (1. + ex);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "HardSigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 2 + 1] << " * "
                << OpName << "_hidden_gate[i] + " << fAttrActivationBeta[direction * 2 + 1] << ";\n";
            out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
         }
         out << SP << SP << SP << OpName << "_hidden_gate[i] = (b < 1.) ? b : 1.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "LeakyRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * " << OpName << "_hidden_gate[i];\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "ThresholdRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < "
             << fAttrActivationAlpha[direction * 2 + 1] << ")\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Elu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * (exp(" << OpName << "_hidden_gate[i]) - 1.);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Softsign") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_hidden_gate[i] = " << OpName
             << "_hidden_gate[i] / (1. + abs(" << OpName << "_hidden_gate[i]));\n";
         out << SP << SP << "}\n";
      } else { // fAttrActivations[direction * 2 + 1] = Softplus
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_hidden_gate[i] = log(1. + exp("
             << OpName << "_hidden_gate[i]));\n";
         out << SP << SP << "}\n";
      }

      // hidden_state = (1 - update_gate) o hidden_gate
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << OpName << "_hidden_state[i] = (1. - " << OpName
          << "_update_gate[i]) * " << OpName << "_hidden_gate[i];\n";
      out << SP << SP << "}\n";

      out << SP << SP << "if (seq == 0) {\n";
      if (!fNInitial_h.empty()) {
         // hidden_state += update_gate o initial_hidden_state
         out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
             << "_update_gate[i + offset] * " << OpName << "_initial_hidden_state[i];\n";
         out << SP << SP << SP << "}\n";
      }
      out << SP << SP << "} else {\n";
      // hidden_state += update_gate o previous_hidden_state
      if (direction == 0) {
         if (fAttrDirection == "backward") {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         } else {
            out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         }
      } else {
         out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
             << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
      }
      out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
          << "_update_gate[i + offset] * " << OpName << "_hidden_state[i + previous_offset];\n";
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";

      out << SP << "}\n";
   }

   // Padding the hidden state for GRU with different sequence lengths
   if (!fNSequence_lens.empty()) {
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
      out << SP << SP << SP << "if (seq >= tensor_" << fNSequence_lens << "[batch]) {\n";
      for (size_t direction = 0; direction < num_directions; direction++) {
         out << SP << SP << SP << SP << "for (size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
         out << SP << SP << SP << SP << SP << OpName << "_hidden_state[seq * "
             << num_directions * batch_size * fAttrHiddenSize + direction * batch_size * fAttrHiddenSize
             << " + batch * " << fAttrHiddenSize << " + h] = 0.;\n";
         out << SP << SP << SP << SP << "}\n";
      }
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";
      out << SP << "}\n";
   }

   // Copy the hidden state into y and y_h
   if (fAttrLayout == 0) {
      if (!fNY_h.empty()) {
         // Copy hidden_state into Y_h
         if (fNSequence_lens.empty()) {
            size_t yh_size = batch_size * fAttrHiddenSize;
            if (fAttrDirection == "backward") {
               out << SP << "std::copy(" << OpName << "_hidden_state, " << OpName << "_hidden_state + "
                   << yh_size << ", tensor_" << fNY_h << ");\n";
            } else {
               size_t offset = (seq_length - 1) * num_directions * batch_size * fAttrHiddenSize;
               out << SP << "std::copy(" << OpName << "_hidden_state + " << offset << ", " << OpName
                   << "_hidden_state + " << offset << " + " << yh_size << ", tensor_" << fNY_h << ");\n";
            }
            if (num_directions == 2) {
               out << SP << "std::copy(" << OpName << "_hidden_state + " << yh_size << ", " << OpName
                   << "_hidden_state + " << 2 * yh_size << ", tensor_" << fNY_h << " + " << yh_size << ");\n";
            }
         } else { // GRU with different sequence lengths
            if (fAttrDirection == "backward") {
               out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
               out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                   << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + offset);\n";
               out << SP << "}\n";
            } else {
               out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
               out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
               out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
                   << " + batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "size_t yh_offset = batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                   << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
               out << SP << "}\n";
            }
            if (num_directions == 2) {
               out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
               out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize
                   << " + batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "size_t yh_offset = " << batch_size * fAttrHiddenSize
                   << " + batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                   << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
               out << SP << "}\n";
            }
         }
      }
   } else { // fAttrLayout=1
      if (!fNY.empty()) {
         // Copy hidden_state into Y
         for (size_t direction = 0; direction < num_directions; direction++) {
            out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
            out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
                << " + " << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << SP << "size_t y_offset = batch * " << seq_length * num_directions * fAttrHiddenSize
                << " + seq * " << num_directions * fAttrHiddenSize << " + " << direction * fAttrHiddenSize << ";\n";
            out << SP << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY << " + y_offset);\n";
            out << SP << SP << "}\n";
            out << SP << "}\n";
         }
      }
      if (!fNY_h.empty()) {
         // Copy the hidden_state into Y_h
         if (fAttrDirection == "backward") {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         } else {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            if (fNSequence_lens.empty()) {
               out << SP << SP << "size_t seq = " << seq_length - 1 << ";\n";
            } else {
               out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
            }
            out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
                << " + batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         }
         if (num_directions == 2) {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * "
                << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << " + "
                << fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         }
      }
   }

   return out.str();
}
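// Minimal usage sketch (hedged: the constructor arguments are illustrative
// placeholders, not the actual signature, which takes the GRU attributes and
// the input/output tensor names parsed from the ONNX node):
//   ROperator_GRU<float> op(/* attributes, tensor names */);
//   op.Initialize(model);                  // validate shapes, register Y and Y_h
//   std::string code = op.Generate("0");   // emit inference code for op_0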

} // namespace SOFIE
} // namespace Experimental
} // namespace TMVA

#endif