Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_LSTM.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_LSTM
2#define TMVA_SOFIE_ROPERATOR_LSTM
3
4#include "TMVA/RModel.hxx"
5#include "TMVA/ROperator.hxx"
7
8#include <memory>
9#include <sstream>
10#include <string>
11#include <vector>
12
14
15/*! \brief Long Short-Term Memory operator
16 *
17 * Inference code generation for one-layer LSTM. Supports forward, reverse and bidirectional LSTM.
18 * See the <a href="https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM">ONNX documentation</a>
19 * for details about the supported LSTM architectures.
20 */
21template <typename T> class ROperator_LSTM final : public ROperator {
22 private:
23 std::vector<float> fAttrActivationAlpha; ///< Scaling values used by some activation functions
24 std::vector<float> fAttrActivationBeta; ///< Scaling values used by some activation functions
25 std::vector<std::string> fAttrActivations; ///< Activation functions
26 float fAttrClip; ///< Clip threshold
27 std::string fAttrDirection; ///< Direction of processing ("forward", "backward" or "bidirectional")
28 size_t fAttrHiddenSize; ///< Size of the hidden state (ONNX hidden_size attribute; W's second dim is 4 * this)
29 size_t fAttrInputForget; ///< ONNX input_forget attribute; validated to be 0 or 1, nonzero omits the separate forget gate
30 size_t fAttrLayout; ///< Data layout (0 = timewise, 1 = batchwise)
31
32 std::string fNX; ///< Name of the input
33 std::string fNW; ///< Name of the weights
34 std::string fNR; ///< Name of the recurrence
35 std::string fNB; ///< Name of the bias
36 std::string fNSequence_lens; ///< Name of length of the sequences
37 std::string fNInitial_h; ///< Name of the initial value of the hidden states
38 std::string fNInitial_c; ///< Name of the initial value of the cell states
39 std::string fNP; ///< Name of peepholes
40 std::string fNY; ///< Name of the output
41 std::string fNY_h; ///< Name of the last sequence of the output
42 std::string fNY_c; ///< Name of the last sequence of the cell states
43
44 std::vector<size_t> fShapeX; ///< Shape of the input
45 std::vector<size_t> fShapeW; ///< Shape of the weights
46 std::vector<size_t> fShapeR; ///< Shape of the recurrence
47 std::vector<size_t> fShapeB; ///< Shape of the bias
48 std::vector<size_t> fShapeSequence_lens; ///< Shape of the length of the sequences
49 std::vector<size_t> fShapeInitial_h; ///< Shape of the initial value of the hidden states
50 std::vector<size_t> fShapeInitial_c; ///< Shape of the initial value of the cell states
51 std::vector<size_t> fShapeP; ///< Shape of the peepholes
52 std::vector<size_t> fShapeY; ///< Shape of the output
53 std::vector<size_t> fShapeY_h; ///< Shape of the last sequence of the output
54 std::vector<size_t> fShapeY_c; ///< Shape of the last sequence of the cell states
55
56 std::string fType; ///< Type of the tensors (only "float" is supported, see the constructor)
57
58 public:
59 /*! Default constructor of ROperator_LSTM */
// NOTE(review): extraction artifact — original line 60 (the default-constructor
// definition itself) is missing from this dump; restore it from the repository.
61
62 /*! \brief Constructor of ROperator_LSTM from the attributes
63 *
64 * \param activation_alpha scaling values used by some activation functions
65 * \param activation_beta scaling values used by some activation functions
66 * \param activations activation functions
67 * \param clip clip threshold
68 * \param direction direction of processing of the sequences
69 * \param hidden_size number of hidden layers
70 * \param input_forget forget gate
71 * \param layout data layout
72 * \param nameX name of the input tensor
73 * \param nameW name of the weight tensor
74 * \param nameR name of the recurrence tensor
75 * \param nameB name of the bias tensor
76 * \param nameSequence_lens name of the length of the sequences
77 * \param nameInitial_h name of the initial value of the hidden states
78 * \param nameInitial_c name of the initial value of the cell states
79 * \param nameP name of the peepholes tensor
80 * \param nameY name of the output
81 * \param nameY_h name of the last sequence of the output
82 * \param nameY_c name of the last sequence of the cell states
83 */
84 ROperator_LSTM(std::vector<float> activation_alpha,
85 std::vector<float> activation_beta,
86 std::vector<std::string> activations, float clip,
87 std::string direction, size_t hidden_size,
88 size_t input_forget, size_t layout,
89 std::string nameX, std::string nameW, std::string nameR,
90 std::string nameB, std::string nameSequence_lens,
91 std::string nameInitial_h, std::string nameInitial_c, std::string nameP,
92 std::string nameY, std::string nameY_h, std::string nameY_c)
// NOTE(review): extraction artifact — original lines 93-96 (the start of the
// member-initializer list, presumably initializing the fAttr* members from the
// attribute parameters above) are missing from this dump; restore before compiling.
97 fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)),
98 fNR(UTILITY::Clean_name(nameR)), fNB(UTILITY::Clean_name(nameB)),
99 fNSequence_lens(UTILITY::Clean_name(nameSequence_lens)),
100 fNInitial_h(UTILITY::Clean_name(nameInitial_h)),
101 fNInitial_c(UTILITY::Clean_name(nameInitial_c)), fNP(UTILITY::Clean_name(nameP)),
102 fNY(UTILITY::Clean_name(nameY)), fNY_h(UTILITY::Clean_name(nameY_h)),
103 fNY_c(UTILITY::Clean_name(nameY_c)) {
// Only single-precision LSTMs are supported; any other template type is rejected here.
104 if (std::is_same<T, float>::value) {
105 fType = "float";
106 } else {
107 throw std::runtime_error(
108 "TMVA SOFIE Encountered unsupported type parsing a LSTM operator");
109 }
110
// NOTE(review): extraction artifact — original line 111 is missing; it presumably
// seeds fInputTensorNames with the mandatory inputs (fNX, fNW, fNR). Verify
// against the repository.
// Optional inputs are registered only when a name was supplied.
112 if (!fNB.empty()){
113 fInputTensorNames.emplace_back(fNB);
114 }
115 if (!fNSequence_lens.empty()){
// NOTE(review): extraction artifact — original line 116 is missing; by symmetry
// with the surrounding branches it should be
// fInputTensorNames.emplace_back(fNSequence_lens); — confirm and restore.
117 }
118 if (!fNInitial_h.empty()){
119 fInputTensorNames.emplace_back(fNInitial_h);
120 }
121 if (!fNInitial_c.empty()){
122 fInputTensorNames.emplace_back(fNInitial_c);
123 }
124 if (!fNP.empty()){
125 fInputTensorNames.emplace_back(fNP);
126 }
127
// All three outputs (Y, Y_h, Y_c) are optional in ONNX; register only the named ones.
128 fOutputTensorNames = { };
129 if (!fNY.empty()){
130 fOutputTensorNames.emplace_back(fNY);
131 }
132 if (!fNY_h.empty()){
133 fOutputTensorNames.emplace_back(fNY_h);
134 }
135 if (!fNY_c.empty()){
136 fOutputTensorNames.emplace_back(fNY_c);
137 }
138 }
139
140 /*! \brief Infers the type of the output tensors
141 *
142 * \param input type of the input tensors
143 */
144 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override;
145
146 /*! \brief Infers the shape of the output tensors
147 *
148 * \param input shape of the input tensors
149 */
150 std::vector<std::vector<size_t>>
151 ShapeInference(std::vector<std::vector<size_t>> input) override;
152
153 /*! \brief Initialize the model
154 *
155 * \param model Model
156 */
157 void Initialize(RModel &) override;
158
159 /*! \brief Generate the inference code
160 *
161 * \param OpName name of the operator
162 */
163 std::string Generate(std::string OpName) override;
164
165 /*! \brief Generate the code for the Session internal data vectors
166 *
167 * \param opName name of the operator
168 */
169 std::string GenerateSessionMembersCode(std::string opName) override;
170
171 /*! \brief Returns the blas routines needed to compile the generated code
172 */
173 std::vector<std::string> GetBlasRoutines() override { return { std::string("Gemm"), std::string("Axpy") }; }
174};
175
176template <typename T>
177auto ROperator_LSTM<T>::TypeInference(std::vector<ETensorType> input) -> std::vector<ETensorType>
178{
179 ETensorType out = input[0];
180 return {out, out};
181}
182
// Derives the output shapes from the shapes of X (input[0]) and W (input[1]).
// W is {num_directions, 4*hidden_size, input_size}, hence the division by 4
// (Initialize() later enforces 4 * fAttrHiddenSize == fShapeW[1]).
183template <typename T>
184auto ROperator_LSTM<T>::ShapeInference(std::vector<std::vector<size_t>> input) -> std::vector<std::vector<size_t>>
185{
186 size_t num_directions = input[1][0];
187 size_t hidden_size = input[1][1] / 4;
// Layout 0 (timewise): X is {seq_length, batch_size, input_size}.
188 if (fAttrLayout == 0) {
189 size_t seq_length = input[0][0];
190 size_t batch_size = input[0][1];
191 std::vector<std::vector<size_t>> ret({{seq_length, num_directions, batch_size, hidden_size},
// NOTE(review): extraction artifact — original lines 192-193 are missing here;
// they presumably hold the remaining two initializer entries (the Y_h and Y_c
// shapes, each {num_directions, batch_size, hidden_size}) — confirm and restore.
194 return ret;
// Layout 1 (batchwise): X is {batch_size, seq_length, input_size}.
195 } else {
196 size_t batch_size = input[0][0];
197 size_t seq_length = input[0][1];
198 std::vector<std::vector<size_t>> ret({{batch_size, seq_length, num_directions, hidden_size},
// NOTE(review): extraction artifact — original lines 199-200 are missing here;
// presumably the batchwise Y_h and Y_c shape entries — confirm and restore.
201 return ret;
202 }
203}
204
// Validates all named input/output tensors against the model, broadcasts the
// optional bias and peephole tensors to their full per-sequence layout in place,
// registers the requested output tensors, and checks/defaults the attributes.
205template <typename T>
// NOTE(review): extraction artifact — original line 206 (the function signature,
// per the in-class declaration: void ROperator_LSTM<T>::Initialize(RModel& model))
// is missing from this dump; restore before compiling.
207{
208 fUseSession = model.UseSession();
209 // Check the input and output tensors
// X, W and R are mandatory and must all be rank-3.
210 if (!model.CheckIfTensorAlreadyExist(fNX)) {
211 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNX + " is not found in model.");
212 }
213 fShapeX = model.GetTensorShape(fNX);
214 if (fShapeX.size() != 3) {
215 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNX + " is not of 3 dimensions.");
216 }
217 if (!model.CheckIfTensorAlreadyExist(fNW)) {
218 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNW + " is not found in model.");
219 }
220 fShapeW = model.GetTensorShape(fNW);
221 if (fShapeW.size() != 3) {
222 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNW + " is not of 3 dimensions.");
223 }
224 if (!model.CheckIfTensorAlreadyExist(fNR)) {
225 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNR + " is not found in model.");
226 }
227 fShapeR = model.GetTensorShape(fNR);
228 if (fShapeR.size() != 3) {
229 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNR + " is not of 3 dimensions.");
230 }
// Optional bias: a rank-2 ONNX bias {num_directions, 8*hidden_size} is expanded
// once here into a rank-5 tensor {4, num_directions, seq_length, batch_size,
// hidden_size}, so that Generate() can add it with a single saxpy per gate.
231 if (!fNB.empty()) {
232 if (!model.CheckIfTensorAlreadyExist(fNB)) {
233 throw std::runtime_error("TMVA SOFIE LSTM op input tensor " + fNB + " is not found in model.");
234 }
235 fShapeB = model.GetTensorShape(fNB);
236 if (fShapeB.size() != 2 && fShapeB.size() != 5) {
237 throw std::runtime_error("TMVA SOFIE LSTM op input tensor " + fNB + " is not of 2 or 5 dimensions.");
238 }
239 if (fShapeB.size() == 2) {
240 // Broadcasting the bias
241 auto original_data = model.GetInitializedTensorData(fNB);
242 size_t num_directions = fShapeW[0];
243 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
244 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
245 if (fType == "float") {
246 float *original_bias = static_cast<float *>(original_data.get());
247 float *new_bias = new float[4 * num_directions * seq_length * batch_size * fAttrHiddenSize];
248 for (size_t gate = 0; gate < 4; gate++) {
249 std::vector<float> sum(fAttrHiddenSize);
250 for (size_t direction = 0; direction < num_directions; direction++) {
// The ONNX bias concatenates Wb and Rb (offset 4*hidden within a direction's
// 8*hidden slab); the two are pre-summed so inference adds only one vector.
251 size_t offset = direction * 8 * fAttrHiddenSize + gate * fAttrHiddenSize;
252 for (size_t h = 0; h < fAttrHiddenSize; h++) {
253 sum[h] = original_bias[offset + h] + original_bias[offset + h + 4 * fAttrHiddenSize];
254 }
// Replicate the summed per-gate bias across every (seq, batch) position.
255 for (size_t seq = 0; seq < seq_length; seq++) {
256 for (size_t batch = 0; batch < batch_size; batch++) {
257 size_t bias_offset = gate * num_directions * seq_length * batch_size * fAttrHiddenSize +
258 direction * seq_length * batch_size * fAttrHiddenSize +
259 seq * batch_size * fAttrHiddenSize + batch * fAttrHiddenSize;
260 std::copy(sum.begin(), sum.end(), new_bias + bias_offset);
261 }
262 }
263 }
264 }
265 std::vector<size_t> new_bias_shape = {4, num_directions, seq_length, batch_size, fAttrHiddenSize};
// Ownership of new_bias passes to the model via the shared_ptr's array deleter.
266 std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<float[]>());
267 model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), new_bias_shape, new_bias_ptr);
268 fShapeB = model.GetTensorShape(fNB);
269 }
270 }
271 }
272 if (!fNSequence_lens.empty()) {
273 if (!model.CheckIfTensorAlreadyExist(fNSequence_lens)) {
// NOTE(review): this message lacks a space before "is not found" ("...lens" +
// "is not found...") — fix the literal when the code is next edited.
274 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNSequence_lens + "is not found in model.");
275 }
276 fShapeSequence_lens = model.GetTensorShape(fNSequence_lens);
277 if (fShapeSequence_lens.size() != 1) {
278 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNSequence_lens + " is not of 1 dimension.");
279 }
280 }
281 if (!fNInitial_h.empty()) {
282 if (!model.CheckIfTensorAlreadyExist(fNInitial_h)) {
283 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNInitial_h + " is not found in model.");
284 }
285 fShapeInitial_h = model.GetTensorShape(fNInitial_h);
286 if (fShapeInitial_h.size() != 3) {
287 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNInitial_h + " is not of 3 dimensions.");
288 }
289 }
290 if (!fNInitial_c.empty()) {
291 if (!model.CheckIfTensorAlreadyExist(fNInitial_c)) {
292 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNInitial_c + " is not found in model.");
293 }
294 fShapeInitial_c = model.GetTensorShape(fNInitial_c);
295 if (fShapeInitial_c.size() != 3) {
296 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNInitial_c + " is not of 3 dimensions.");
297 }
298 }
// Optional peepholes: a rank-2 tensor {num_directions, 3*hidden_size} is
// expanded to rank-4 {num_directions, 3, batch_size, hidden_size}, replicating
// each of the 3 peephole vectors across the batch.
299 if (!fNP.empty()) {
300 if (!model.CheckIfTensorAlreadyExist(fNP)) {
301 throw std::runtime_error("TMVA SOFIE LSTM op input tensor " + fNP + " is not found in model.");
302 }
303 fShapeP = model.GetTensorShape(fNP);
304 if (fShapeP.size() != 2 && fShapeP.size() != 4) {
305 throw std::runtime_error("TMVA SOFIE LSTM op input tensor " + fNP + " is not of 2 or 4 dimensions.");
306 }
307 if (fShapeP.size() == 2) {
308 // Broadcasting the weight for peepholes
309 auto original_data = model.GetInitializedTensorData(fNP);
310 size_t num_directions = fShapeW[0];
311 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
312 if (fType == "float") {
313 float *original_p = static_cast<float *>(original_data.get());
314 float *new_p = new float[num_directions * 3 * batch_size * fAttrHiddenSize];
315 for (size_t direction = 0; direction < num_directions; direction++) {
316 for (size_t gate = 0; gate < 3; gate++) {
317 size_t p_offset = direction * 3 * fAttrHiddenSize + gate * fAttrHiddenSize;
318 for (size_t batch = 0; batch < batch_size; batch++) {
319 size_t offset = direction * 3 * batch_size * fAttrHiddenSize +
320 gate * batch_size * fAttrHiddenSize + batch * fAttrHiddenSize;
321 std::copy(original_p + p_offset, original_p + p_offset + fAttrHiddenSize, new_p + offset);
322 }
323 }
324 }
325 std::vector<size_t> new_p_shape = {num_directions, 3, batch_size, fAttrHiddenSize};
326 std::shared_ptr<void> new_p_ptr(new_p, std::default_delete<float[]>());
327 model.UpdateInitializedTensor(fNP, model.GetTensorType(fNP), new_p_shape, new_p_ptr);
328 fShapeP = model.GetTensorShape(fNP);
329 }
330 }
331 }
// Register each requested output as an intermediate tensor, with shapes taken
// from ShapeInference on {X, W}.
332 if (!fNY.empty()) {
333 fShapeY = ShapeInference({fShapeX, fShapeW})[0];
334 if (!model.CheckIfTensorAlreadyExist(fNY)) {
335 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
336 }
337 }
338 if (!fNY_h.empty()) {
339 fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
340 if (!model.CheckIfTensorAlreadyExist(fNY_h)) {
341 model.AddIntermediateTensor(fNY_h, model.GetTensorType(fNX), fShapeY_h);
342 }
343 }
344 if (!fNY_c.empty()) {
345 fShapeY_c = ShapeInference({fShapeX, fShapeW})[2];
346 if (!model.CheckIfTensorAlreadyExist(fNY_c)) {
347 model.AddIntermediateTensor(fNY_c, model.GetTensorType(fNX), fShapeY_c);
348 }
349 }
350 // Check the attributes
351 for (auto &activation : fAttrActivations) {
352 if (activation != "Relu" && activation != "Tanh" && activation != "Sigmoid" && activation != "Affine" &&
353 activation != "LeakyRelu" && activation != "ThresholdRelu" && activation != "ScaledTanh" &&
354 activation != "HardSigmoid" && activation != "Elu" && activation != "Softsign" && activation != "Softplus") {
355 throw std::runtime_error("TMVA SOFIE - Activation function " + activation + " not implemented");
356 }
357 }
358 if (fAttrDirection != "forward" && fAttrDirection != "backward" && fAttrDirection != "bidirectional") {
359 throw std::runtime_error("TMVA SOFIE - Invalid LSTM direction fAttrDirection = " + fAttrDirection);
360 }
361 if (4 * fAttrHiddenSize != fShapeW[1]) {
362 throw std::runtime_error("TMVA SOFIE - fAttrHiddenSize must be equal to " + std::to_string(fShapeW[1] / 4));
363 }
364 if (fAttrInputForget > 1) {
365 throw std::runtime_error("TMVA SOFIE - fAttrInputForget = " + std::to_string(fAttrInputForget) +
366 " must be 0 or 1.");
367 }
368 if (fAttrLayout > 1) {
369 throw std::runtime_error("TMVA SOFIE - Layout fAttrLayout = " + std::to_string(fAttrLayout) +
370 " must be 0 (timewise) or 1 (batchwise)");
371 }
// ONNX default activations per direction: f = Sigmoid, g = Tanh, h = Tanh.
372 if (fAttrActivations.empty()) {
373 if (fAttrDirection == "bidirectional") {
374 fAttrActivations = {"Sigmoid", "Tanh", "Tanh", "Sigmoid", "Tanh", "Tanh"};
375 } else {
376 fAttrActivations = {"Sigmoid", "Tanh", "Tanh"};
377 }
378 }
379}
380
381// generate code for Session data members (e.g. internal vectors)
// Emits one contiguous backing std::vector plus a named raw pointer per scratch
// buffer (fVec_<op>_<name>) pointing at successive offsets into that vector.
// The buffer set mirrors exactly what Generate() dereferences when
// fUseSession is true, so the two member functions must stay in sync.
382template <typename T>
// NOTE(review): extraction artifact — original line 383 (the function signature,
// per the in-class declaration: auto ROperator_LSTM<T>::GenerateSessionMembersCode(
// std::string opName) -> std::string) is missing from this dump; restore it.
384{
385 opName = "op_" + opName;
386 std::stringstream out;
387
388 size_t num_directions = fShapeW[0];
389 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
390 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
391 size_t input_size = fShapeX[2];
392
// Name/size pair for each scratch region carved out of the single allocation.
393 struct Block {
394 std::string name;
395 size_t size;
396 };
397
398 std::vector<Block> blocks;
399
// ff_size: one direction's feedforward result; hs_size: all directions.
400 size_t ff_size = seq_length * batch_size * fAttrHiddenSize;
401 size_t hs_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
402
403 // Layout-dependent buffers
// Batchwise layout needs transposed staging copies of X and the initial states.
404 if (fAttrLayout != 0) {
405 blocks.push_back({"input", seq_length * batch_size * input_size});
406 blocks.push_back({"initial_hidden_state", num_directions * batch_size * fAttrHiddenSize});
407 blocks.push_back({"initial_cell_state", num_directions * batch_size * fAttrHiddenSize});
408 }
409
410 // Feedforward gates
411 blocks.push_back({"ff_input_gate", ff_size});
412 blocks.push_back({"ff_output_gate", ff_size});
413 blocks.push_back({"ff_cell_gate", ff_size});
// With input_forget == 0 the forget gate has its own buffers.
414 if (fAttrInputForget == 0)
415 blocks.push_back({"ff_forget_gate", ff_size});
416
417 // Gate outputs
418 blocks.push_back({"input_gate", hs_size});
419 blocks.push_back({"output_gate", hs_size});
420 blocks.push_back({"cell_gate", hs_size});
421 if (fAttrInputForget == 0)
422 blocks.push_back({"forget_gate", hs_size});
423
424 // Cell state
425 blocks.push_back({"cell_state", hs_size});
426 blocks.push_back({"new_cell_state", hs_size});
427
428 // Hidden state (conditional)
// When layout is timewise and Y is requested, Generate() writes the hidden
// state straight into tensor_Y, so no scratch buffer is needed in that case.
429 if (fAttrLayout != 0 || fNY.empty()) {
430 blocks.push_back({"hidden_state", hs_size});
431 }
432
433 // Compute total size
434 size_t total_size = 0;
435 for (const auto &b : blocks) {
436 total_size += b.size;
437 }
438
439 // Backing storage
440 out << "std::vector<" << fType << "> fVec_" << opName << "_buffer = std::vector<" << fType << ">(" << total_size
441 << ");\n";
442
443 // Emit pointers
444 std::size_t offset = 0;
445 for (const auto &b : blocks) {
446 out << fType << "* fVec_" << opName << "_" << b.name << " = fVec_" << opName << "_buffer.data() + " << offset
447 << ";\n";
448 offset += b.size;
449 }
450
451 out << "\n";
452
453 return out.str();
454}
455
456template <typename T>
457auto ROperator_LSTM<T>::Generate(std::string OpName) -> std::string
458{
459 OpName = "op_" + OpName;
460 std::stringstream out;
461
462 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
463 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
464 size_t input_size = fShapeX[2];
465 size_t num_directions = fShapeW[0];
466
467 // set the input
468 if (fAttrLayout == 0) {
469 out << SP << fType << " const *" << OpName << "_input = tensor_" << fNX << ";\n";
470 } else {
471 if (fUseSession)
472 out << SP << fType << " * " << OpName << "_input = this->fVec_" << OpName << "_input;\n";
473 else
474 out << SP << fType << " " << OpName << "_input[" << seq_length * batch_size * input_size << "] = {0};\n";
475
476 out << SP << "for(size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
477 out << SP << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
478 out << SP << SP << SP << "for(size_t i = 0; i < " << input_size << "; i++) {\n";
479 out << SP << SP << SP << SP << OpName << "_input[seq * " << batch_size * input_size << " + batch * " << input_size
480 << " + i] = " << "tensor_" << fNX << "[batch * " << seq_length * input_size << " + seq * " << input_size
481 << " + i];\n";
482 out << SP << SP << SP << "}\n";
483 out << SP << SP << "}\n";
484 out << SP << "}\n";
485 }
486
487 // Set the initial hidden state
488 if (!fNInitial_h.empty()) {
489 if (fAttrLayout == 0) {
490 out << SP << fType << " const*" << OpName << "_initial_hidden_state = " << " tensor_" << fNInitial_h << ";\n";
491 } else {
492 if (fUseSession)
493 out << SP << fType << " const* " << OpName << "_initial_hidden_state = this->fVec_" << OpName
494 << "_initial_hidden_state;\n";
495 else
496 out << SP << fType << " " << OpName << "_initial_hidden_state["
497 << num_directions * batch_size * fAttrHiddenSize << "] = {0};\n";
498
499 for (size_t direction = 0; direction < num_directions; direction++) {
500 out << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
501 out << SP << SP << "for(size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
502 out << SP << SP << SP << OpName << "_initial_hidden_state[" << direction * batch_size * fAttrHiddenSize
503 << " + batch * " << fAttrHiddenSize << " + h] = tensor_" << fNInitial_h << "[batch * "
504 << num_directions * fAttrHiddenSize << " + " << direction * fAttrHiddenSize << " + h];\n";
505 out << SP << SP << "}\n";
506 out << SP << "}\n";
507 }
508 }
509 }
510
511 // Set the initial cell state
512 if (!fNInitial_c.empty()) {
513 if (fAttrLayout == 0) {
514 out << SP << fType << " const*" << OpName << "_initial_cell_state = " << " tensor_" << fNInitial_c << ";\n";
515 } else {
516 if (fUseSession)
517 out << SP << fType << " const* " << OpName << "_initial_cell_state = this->fVec_" << OpName
518 << "_initial_cell_state;\n";
519 else
520 out << SP << fType << " " << OpName << "_initial_cell_state["
521 << num_directions * batch_size * fAttrHiddenSize << "] = {0};\n";
522
523 for (size_t direction = 0; direction < num_directions; direction++) {
524 out << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
525 out << SP << SP << "for(size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
526 out << SP << SP << SP << OpName << "_initial_cell_state[" << direction * batch_size * fAttrHiddenSize
527 << " + batch * " << fAttrHiddenSize << " + h] = tensor_" << fNInitial_c << "[batch * "
528 << num_directions * fAttrHiddenSize << " + " << direction * fAttrHiddenSize << " + h];\n";
529 out << SP << SP << "}\n";
530 out << SP << "}\n";
531 }
532 }
533 }
534
535 // Set the feedforward
536 size_t ff_size = seq_length * batch_size * fAttrHiddenSize;
537 if (fUseSession) {
538 out << SP << fType << " * " << OpName << "_ff_input_gate = this->fVec_" << OpName << "_ff_input_gate;\n";
539 out << SP << fType << " * " << OpName << "_ff_output_gate = this->fVec_" << OpName << "_ff_output_gate;\n";
540 out << SP << fType << " * " << OpName << "_ff_cell_gate = this->fVec_" << OpName << "_ff_cell_gate;\n";
541 if (fAttrInputForget == 0) {
542 out << SP << fType << " * " << OpName << "_ff_forget_gate = this->fVec_" << OpName
543 << "_ff_forget_gate;\n";
544 }
545 } else {
546 out << SP << fType << " " << OpName << "_ff_input_gate[" << ff_size << "] = {0};\n";
547 out << SP << fType << " " << OpName << "_ff_output_gate[" << ff_size << "] = {0};\n";
548 out << SP << fType << " " << OpName << "_ff_cell_gate[" << ff_size << "] = {0};\n";
549 if (fAttrInputForget == 0) {
550 out << SP << fType << " " << OpName << "_ff_forget_gate[" << ff_size << "] = {0};\n";
551 }
552 }
553 // Set the gates
554 size_t hidden_state_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
555 if (fUseSession) {
556 out << SP << fType << " * " << OpName << "_input_gate = this->fVec_" << OpName << "_input_gate;\n";
557 out << SP << fType << " * " << OpName << "_output_gate = this->fVec_" << OpName << "_output_gate;\n";
558 out << SP << fType << " * " << OpName << "_cell_gate = this->fVec_" << OpName << "_cell_gate;\n";
559 if (fAttrInputForget == 0) {
560 out << SP << fType << " * " << OpName << "_forget_gate = this->fVec_" << OpName << "_forget_gate;\n";
561 }
562 } else {
563 out << SP << fType << " " << OpName << "_input_gate[" << hidden_state_size << "] = {0};\n";
564 out << SP << fType << " " << OpName << "_output_gate[" << hidden_state_size << "] = {0};\n";
565 out << SP << fType << " " << OpName << "_cell_gate[" << hidden_state_size << "] = {0};\n";
566 if (fAttrInputForget == 0) {
567 out << SP << fType << " " << OpName << "_forget_gate[" << hidden_state_size << "] = {0};\n";
568 }
569 }
570 // Set the cell state and the new cell state = h(cell state)
571 if (fUseSession) {
572 out << SP << fType << " * " << OpName << "_cell_state = this->fVec_" << OpName << "_cell_state;\n";
573 out << SP << fType << " * " << OpName << "_new_cell_state = this->fVec_" << OpName << "_new_cell_state;\n";
574 } else {
575 out << SP << fType << " " << OpName << "_cell_state[" << hidden_state_size << "] = {0};\n";
576 out << SP << fType << " " << OpName << "_new_cell_state[" << hidden_state_size << "] = {0};\n";
577 }
578
579 // Set the hidden state
580 if (fAttrLayout == 0 && !fNY.empty()) {
581 out << SP << fType << " *" << OpName << "_hidden_state = tensor_" << fNY << ";\n";
582 } else {
583 if (fUseSession) {
584 out << SP << fType << " * " << OpName << "_hidden_state = this->fVec_" << OpName << "_hidden_state;\n";
585 } else {
586 out << SP << fType << " " << OpName << "_hidden_state[" << hidden_state_size << "] = {0};\n";
587 }
588 }
589
590 out << SP << "char " << OpName << "_transA = 'N';\n";
591 out << SP << "char " << OpName << "_transB = 'T';\n";
592 out << SP << "int " << OpName << "_m = " << seq_length * batch_size << ";\n";
593 out << SP << "int " << OpName << "_n = " << fAttrHiddenSize << ";\n";
594 out << SP << "int " << OpName << "_k = " << input_size << ";\n";
595 if (fType == "float") {
596 out << SP << fType << " " << OpName << "_alpha = 1.;\n";
597 out << SP << fType << " " << OpName << "_beta = 0.;\n";
598 }
599 if (!fNB.empty()) {
600 out << SP << "int " << OpName << "_bias_size = " << seq_length * batch_size * fAttrHiddenSize << ";\n";
601 out << SP << "int " << OpName << "_incx = 1;\n";
602 out << SP << "int " << OpName << "_incy = 1;\n";
603 }
604
605 auto emit_sgemm = [&](const std::string &out_name, size_t offset) -> std::string {
606 std::stringstream ss;
607 ss << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &" << OpName
608 << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_" << fNW;
609
610 if (offset != 0)
611 ss << " + " << offset;
612
613 ss << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &" << OpName << "_beta, " << OpName
614 << "_" << out_name << ", &" << OpName << "_n);\n";
615 return ss.str();
616 };
617
618 for (size_t direction = 0; direction < num_directions; direction++) {
619 if (direction == 0) {
620 if (fType == "float") {
621 // input_gate = input * weight_i^T
622 out << SP << emit_sgemm("ff_input_gate", 0);
623 // output_gate = input * weight_o^T
624 size_t wo_offset = fAttrHiddenSize * input_size;
625 out << SP << emit_sgemm("ff_output_gate", wo_offset);
626 // cell_gate = input * weight_c^T
627 size_t wc_offset = 3 * fAttrHiddenSize * input_size;
628 out << SP << emit_sgemm("ff_cell_gate", wc_offset);
629 }
630 } else {
631 if (fType == "float") {
632 // input_gate = input * weight_i^T
633 out << SP << emit_sgemm("ff_input_gate", 4 * fAttrHiddenSize * input_size);
634 // output_gate = input * weight_o^T
635 size_t wo_offset = 4 * fAttrHiddenSize * input_size + 1 * fAttrHiddenSize * input_size;
636 out << SP << emit_sgemm("ff_output_gate", wo_offset);
637 // cell_gate = input * weight_c^T
638 size_t wc_offset = 4 * fAttrHiddenSize * input_size + 3 * fAttrHiddenSize * input_size;
639 out << SP << emit_sgemm("ff_cell_gate", wc_offset);
640 }
641 }
642 if (fAttrInputForget == 0) {
643 // forget_gate = input * weight_f^T
644 if (direction == 0) {
645 if (fType == "float") {
646 size_t wf_offset = 2 * fAttrHiddenSize * input_size;
647 out << SP << emit_sgemm("ff_forget_gate", wf_offset);
648 }
649 } else {
650 if (fType == "float") {
651 size_t wf_offset = 4 * fAttrHiddenSize * input_size + 2 * fAttrHiddenSize * input_size;
652 out << SP << emit_sgemm("ff_forget_gate", wf_offset);
653 }
654 }
655 }
656
657 // Add the bias
658 if (!fNB.empty()) {
659 if (direction == 0) {
660 if (fType == "float") {
661 // ff_input_gate += bias_i
662 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << ", &"
663 << OpName << "_incx, " << OpName << "_ff_input_gate, &" << OpName << "_incy);\n";
664 // ff_output_gate += bias_o
665 size_t bo_offset = seq_length * batch_size * fAttrHiddenSize;
666 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
667 << bo_offset << ", &" << OpName << "_incx, " << OpName << "_ff_output_gate, &" << OpName
668 << "_incy);\n";
669 // ff_cell_gate += bias_c
670 size_t bc_offset = 3 * seq_length * batch_size * fAttrHiddenSize;
671 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
672 << bc_offset << ", &" << OpName << "_incx, " << OpName << "_ff_cell_gate, &" << OpName
673 << "_incy);\n";
674 }
675 } else {
676 if (fType == "float") {
677 // ff_input_gate += bias_i
678 size_t bi_offset = 4 * seq_length * batch_size * fAttrHiddenSize;
679 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
680 << bi_offset << ", &" << OpName << "_incx, " << OpName << "_ff_input_gate, &" << OpName
681 << "_incy);\n";
682 // ff_output_gate += bias_o
683 size_t bo_offset =
684 4 * seq_length * batch_size * fAttrHiddenSize + seq_length * batch_size * fAttrHiddenSize;
685 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
686 << bo_offset << ", &" << OpName << "_incx, " << OpName << "_ff_output_gate, &" << OpName
687 << "_incy);\n";
688 // ff_cell_gate += bias_c
689 size_t bc_offset = 4 * num_directions * seq_length * batch_size * fAttrHiddenSize +
690 3 * seq_length * batch_size * fAttrHiddenSize;
691 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
692 << bc_offset << ", &" << OpName << "_incx, " << OpName << "_ff_cell_gate, &" << OpName
693 << "_incy);\n";
694 }
695 }
696 if (fAttrInputForget == 0) {
697 // ff_forget_gate += bias_f
698 if (direction == 0) {
699 if (fType == "float") {
700 size_t bo_offset = 2 * seq_length * batch_size * fAttrHiddenSize;
701 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB
702 << " + " << bo_offset << ", &" << OpName << "_incx, " << OpName << "_ff_forget_gate, &" << OpName
703 << "_incy);\n";
704 }
705 } else {
706 if (fType == "float") {
707 size_t bo_offset =
708 4 * seq_length * batch_size * fAttrHiddenSize + 2 * seq_length * batch_size * fAttrHiddenSize;
709 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB
710 << " + " << bo_offset << ", &" << OpName << "_incx, " << OpName << "_ff_forget_gate, &" << OpName
711 << "_incy);\n";
712 }
713 }
714 }
715 }
716
      // Copy ff_input_gate, ff_output_gate, ff_cell_gate and ff_forget_gate into input_gate, output_gate,
      //    cell_gate and forget_gate
      // The ff_* buffers are laid out per direction ([seq, batch, hidden] -> stride batch*hidden per step),
      // while the gate buffers interleave the directions ([seq, num_directions, batch, hidden] ->
      // stride num_directions*batch*hidden per step); hence the two different offsets below.
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "size_t ff_offset = seq * " << batch_size * fAttrHiddenSize << ";\n";
      if (direction == 0) {
         out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize << ";\n";
      } else {
         // Second (backward) direction: skip the forward slice of each timestep.
         out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize << " + "
             << batch_size * fAttrHiddenSize << ";\n";
      }
      size_t ff_seq_size = batch_size * fAttrHiddenSize;
      out << SP << SP << "std::copy(" << OpName << "_ff_input_gate + ff_offset, " << OpName
          << "_ff_input_gate + ff_offset + " << ff_seq_size << ", " << OpName << "_input_gate + gate_offset);\n";
      out << SP << SP << "std::copy(" << OpName << "_ff_output_gate + ff_offset, " << OpName
          << "_ff_output_gate + ff_offset + " << ff_seq_size << ", " << OpName << "_output_gate + gate_offset);\n";
      out << SP << SP << "std::copy(" << OpName << "_ff_cell_gate + ff_offset, " << OpName
          << "_ff_cell_gate + ff_offset + " << ff_seq_size << ", " << OpName << "_cell_gate + gate_offset);\n";
      if (fAttrInputForget == 0) {
         // A separate forget gate only exists when the ONNX input_forget attribute is 0.
         out << SP << SP << "std::copy(" << OpName << "_ff_forget_gate + ff_offset, " << OpName
             << "_ff_forget_gate + ff_offset + " << ff_seq_size << ", " << OpName << "_forget_gate + gate_offset);\n";
      }
      out << SP << "}\n";
739
      // Recurrent loop over the timesteps of this direction.
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      if (fAttrDirection == "backward" || direction == 1) {
         // Reverse processing: `index` walks the stored timesteps from last to first,
         // while `seq` still counts how many steps have been processed.
         out << SP << SP << "size_t index = " << seq_length - 1 << " - seq;\n";
      } else {
         out << SP << SP << "size_t index = seq;\n";
      }
      // m2 is the row count (batch size) passed to the BLAS sgemm calls below.
      out << SP << SP << "int m2 = " << batch_size << ";\n";
      if (direction == 0) {
         out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize << ";\n";
      } else {
         // Backward half of a bidirectional LSTM: offset into the second direction slice.
         out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize << " + "
             << batch_size * fAttrHiddenSize << ";\n";
      }
      // Number of elements in one (batch, hidden) gate slice of a single timestep/direction.
      size_t size = batch_size * fAttrHiddenSize;
      // gate = gate + initial_hidden_state * Recurrence^T
      // At the first processed timestep the recurrent contribution comes from the
      // user-supplied initial hidden state (when fNInitial_h is empty the gates keep
      // only the feed-forward contribution already accumulated above).
      // Tensor R is packed per direction as four [hidden, hidden] blocks in gate order
      // {input, output, forget, cell}: offsets 0, 1*h*h, 2*h*h, 3*h*h (plus 4*h*h for
      // the backward direction), as the ro/rf/rc/ri offsets below show.
      out << SP << SP << "if (seq == 0) {\n";
      if (!fNInitial_h.empty()) {
         if (direction == 0) {
            if (fType == "float") {
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
                   << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &" << OpName
                   << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName << "_alpha, "
                   << OpName << "_input_gate + offset, &" << OpName << "_n);\n";
               size_t ro_offset = fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
                   << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << ro_offset
                   << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
                   << "_alpha, " << OpName << "_output_gate + offset, &" << OpName << "_n);\n";
               size_t rc_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
                   << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rc_offset
                   << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
                   << "_alpha, " << OpName << "_cell_gate + offset, &" << OpName << "_n);\n";
               if (fAttrInputForget == 0) {
                  size_t rf_offset = 2 * fAttrHiddenSize * fAttrHiddenSize;
                  out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                      << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                      << rf_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                      << "_n, &" << OpName << "_alpha, " << OpName << "_forget_gate + offset, &" << OpName << "_n);\n";
               }
            }
         } else { // direction=1
            if (fType == "float") {
               // NOTE(review): the R offsets skip the 4*h*h forward blocks, but
               // `_initial_hidden_state` is passed without a direction offset here,
               // whereas the peephole code below offsets `_initial_cell_state` by
               // batch*hidden for direction 1 - confirm how the initial hidden state
               // buffer is laid out / filled for bidirectional models.
               size_t ri_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
                   << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << ri_offset
                   << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
                   << "_alpha, " << OpName << "_input_gate + offset, &" << OpName << "_n);\n";
               size_t ro_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 1 * fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
                   << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << ro_offset
                   << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
                   << "_alpha, " << OpName << "_output_gate + offset, &" << OpName << "_n);\n";
               size_t rc_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 3 * fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
                   << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rc_offset
                   << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
                   << "_alpha, " << OpName << "_cell_gate + offset, &" << OpName << "_n);\n";
               if (fAttrInputForget == 0) {
                  size_t rf_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
                  out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                      << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                      << rf_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                      << "_n, &" << OpName << "_alpha, " << OpName << "_forget_gate + offset, &" << OpName << "_n);\n";
               }
            }
         }
      }
      out << SP << SP << "} else {\n";
      // gate = gate + previous_hidden_state * Recurrence^T
      // For seq > 0 the recurrent contribution comes from the hidden state computed at
      // the previously processed timestep. For the forward pass that is stored slot
      // (seq - 1); for a reverse pass (backward attribute or direction 1) the previously
      // processed timestep lives at storage slot (index + 1).
      if (direction == 0) {
         if (fAttrDirection == "backward") {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         } else {
            out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         }
         if (fType == "float") {
            // Same R block layout as above: offsets 0 / 1*h*h / 2*h*h / 3*h*h for i/o/f/c.
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
                << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &" << OpName << "_n, "
                << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &" << OpName << "_alpha, " << OpName
                << "_input_gate + offset, &" << OpName << "_n);\n";
            size_t ro_offset = 1 * fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
                << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << ro_offset
                << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
                << OpName << "_alpha, " << OpName << "_output_gate + offset, &" << OpName << "_n);\n";
            size_t rc_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
                << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rc_offset
                << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
                << OpName << "_alpha, " << OpName << "_cell_gate + offset, &" << OpName << "_n);\n";
            if (fAttrInputForget == 0) {
               size_t rf_offset = 2 * fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
                   << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rf_offset
                   << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
                   << OpName << "_alpha, " << OpName << "_forget_gate + offset, &" << OpName << "_n);\n";
            }
         }
      } else {
         // direction=1 is always processed in reverse, so the previous step is (index + 1),
         // shifted into the backward slice of the hidden-state buffer.
         out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
             << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         if (fType == "float") {
            size_t ri_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
                << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << ri_offset
                << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
                << OpName << "_alpha, " << OpName << "_input_gate + offset, &" << OpName << "_n);\n";
            size_t ro_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
                << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << ro_offset
                << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
                << OpName << "_alpha, " << OpName << "_output_gate + offset, &" << OpName << "_n);\n";
            size_t rc_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 3 * fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
                << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rc_offset
                << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
                << OpName << "_alpha, " << OpName << "_cell_gate + offset, &" << OpName << "_n);\n";
            if (fAttrInputForget == 0) {
               size_t rf_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
                   << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rf_offset
                   << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
                   << OpName << "_alpha, " << OpName << "_forget_gate + offset, &" << OpName << "_n);\n";
            }
         }
      }
      out << SP << SP << "}\n";
870
      // Clip the elements of the cell gate into the range [-fAttrClip, fAttrClip]
      // (ONNX `clip` attribute; applied before the activation function).
      if (fAttrClip > .0) {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float x = (" << OpName << "_cell_gate[i] > " << -fAttrClip << ") ? " << OpName
                << "_cell_gate[i] : " << -fAttrClip << ";\n";
         }
         // NOTE(review): `x` is only emitted when fType == "float", yet the line below
         // uses it unconditionally - fine while float is the only supported type.
         out << SP << SP << SP << OpName << "_cell_gate[i] = (x < " << fAttrClip << ") ? x : " << fAttrClip << ";\n";
         out << SP << SP << "}\n";
      }
881 // Apply the activation function to the cell gate, cell_gate = g(cell_gate)
882 if (fAttrActivations[direction * 3 + 1] == "Relu") {
883 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
884 out << SP << SP << SP << "if (" << OpName << "_cell_gate[i] < 0.)\n";
885 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = 0.;\n";
886 out << SP << SP << "}\n";
887 } else if (fAttrActivations[direction * 3 + 1] == "Tanh") {
888 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
889 if (fType == "float") {
890 out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_cell_gate[i]);\n";
891 }
892 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = (1. - ex) / (1. + ex);\n";
893 out << SP << SP << "}\n";
894 } else if (fAttrActivations[direction * 3 + 1] == "Sigmoid") {
895 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
896 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = 1. / (1. + exp(-" << OpName << "_cell_gate[i]));\n";
897 out << SP << SP << "}\n";
898 } else if (fAttrActivations[direction * 3 + 1] == "Affine") {
899 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
900 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = " << fAttrActivationAlpha[direction * 3 + 1] << " * "
901 << OpName << "_cell_gate[i] + " << fAttrActivationBeta[direction * 3 + 1] << ";\n";
902 out << SP << SP << "}\n";
903 } else if (fAttrActivations[direction * 3 + 1] == "ScaledTanh") {
904 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
905 if (fType == "float") {
906 out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 3 + 1] << " * " << OpName
907 << "_cell_gate[i]);\n";
908 }
909 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = " << fAttrActivationAlpha[direction * 3 + 1]
910 << " * (1. - ex) / (1. + ex);\n";
911 out << SP << SP << "}\n";
912 } else if (fAttrActivations[direction * 3 + 1] == "HardSigmoid") {
913 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
914 if (fType == "float") {
915 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 3 + 1] << " * " << OpName
916 << "_cell_gate[i] + " << fAttrActivationBeta[direction * 3 + 1] << ";\n";
917 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
918 }
919 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = (b < 1.) ? b : 1.;\n";
920 out << SP << SP << "}\n";
921 } else if (fAttrActivations[direction * 3 + 1] == "LeakyRelu") {
922 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
923 out << SP << SP << SP << "if (" << OpName << "_cell_gate[i] < 0.)\n";
924 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = " << fAttrActivationAlpha[direction * 3 + 1] << " * "
925 << OpName << "_cell_gate[i];\n";
926 out << SP << SP << "}\n";
927 } else if (fAttrActivations[direction * 3 + 1] == "ThresholdRelu") {
928 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
929 out << SP << SP << SP << "if (" << OpName << "_cell_gate[i] < " << fAttrActivationAlpha[direction * 3 + 1]
930 << ")\n";
931 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = 0.;\n";
932 out << SP << SP << "}";
933 } else if (fAttrActivations[direction * 3 + 1] == "Elu") {
934 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
935 out << SP << SP << SP << "if (" << OpName << "_cell_gate[i] < 0.)\n";
936 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = " << fAttrActivationAlpha[direction * 3 + 1]
937 << " * exp(" << OpName << "_cell_gate[i] - 1.);\n";
938 out << SP << SP << "}\n";
939 } else if (fAttrActivations[direction * 3 + 1] == "Softsign") {
940 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
941 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = " << OpName << "_cell_gate[i] / (1. + abs(" << OpName
942 << "_cell_gate[i]));\n";
943 out << SP << SP << "}\n";
944 } else { // fAttrActivations[direction * 3 + 1] = Softplus
945 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
946 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = log(1. + exp(" << OpName << "_cell_gate[i]));\n";
947 out << SP << SP << "}\n";
948 }
949
      // Peephole connections for the input gate and the forget gate
      if (!fNP.empty()) {
         // gate = 1.0 * gate + previous_cell_state * P^T
         // The peephole contribution is elementwise. The offsets into tensor_P are in
         // units of batch_size * hidden, so P is presumably pre-broadcast over the batch
         // with layout {i, o, f} per direction - TODO confirm against the code that
         // prepares tensor_P.
         out << SP << SP << "if (seq == 0) {\n";
         if (!fNInitial_c.empty()) {
            if (direction == 0) {
               out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
               out << SP << SP << SP << SP << OpName << "_input_gate[i + offset] += tensor_" << fNP << "[i] * "
                   << OpName << "_initial_cell_state[i];\n";
               out << SP << SP << SP << "}\n";
               if (fAttrInputForget == 0) {
                  size_t pf_offset = batch_size * fAttrHiddenSize;
                  out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
                  out << SP << SP << SP << SP << OpName << "_forget_gate[i + offset] += tensor_" << fNP << "[i + "
                      << pf_offset << "] * " << OpName << "_initial_cell_state[i];\n";
                  out << SP << SP << SP << "}\n";
               }
            } else {
               // Backward direction: skip the three forward peephole blocks of P and the
               // forward slice of the initial cell state.
               size_t pi_offset = 3 * batch_size * fAttrHiddenSize;
               size_t initial_c_offset = batch_size * fAttrHiddenSize;
               out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
               out << SP << SP << SP << SP << OpName << "_input_gate[i + offset] += tensor_" << fNP << "[i + "
                   << pi_offset << "] * " << OpName << "_initial_cell_state[i + " << initial_c_offset << "];\n";
               out << SP << SP << SP << "}\n";
               if (fAttrInputForget == 0) {
                  size_t pf_offset = 3 * batch_size * fAttrHiddenSize + batch_size * fAttrHiddenSize;
                  out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
                  out << SP << SP << SP << SP << OpName << "_forget_gate[i + offset] += tensor_" << fNP << "[i + "
                      << pf_offset << "] * " << OpName << "_initial_cell_state[i + " << initial_c_offset << "];\n";
                  out << SP << SP << SP << "}\n";
               }
            }
         }
         out << SP << SP << "} else {\n";
         // Later timesteps read the cell state computed at the previously processed step.
         if (direction == 0) {
            if (fAttrDirection == "backward") {
               out << SP << SP << SP << "size_t c_offset = (index + 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            } else {
               out << SP << SP << SP << "size_t c_offset = (seq - 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            }
            out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
            out << SP << SP << SP << SP << OpName << "_input_gate[i + offset] += tensor_" << fNP << "[i] * " << OpName
                << "_cell_state[i + c_offset];\n";
            out << SP << SP << SP << "}\n";
            if (fAttrInputForget == 0) {
               size_t pf_offset = batch_size * fAttrHiddenSize;
               out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
               out << SP << SP << SP << SP << OpName << "_forget_gate[i + offset] += tensor_" << fNP << "[i + "
                   << pf_offset << "] * " << OpName << "_cell_state[i + c_offset];\n";
               out << SP << SP << SP << "}\n";
            }
         } else { // direction=1
            size_t pi_offset = 3 * batch_size * fAttrHiddenSize;
            out << SP << SP << SP << "size_t c_offset = (index + 1) * " << num_directions * batch_size * fAttrHiddenSize
                << " + " << batch_size * fAttrHiddenSize << ";\n";
            out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
            out << SP << SP << SP << SP << OpName << "_input_gate[i + offset] += tensor_" << fNP << "[i + " << pi_offset
                << "] * " << OpName << "_cell_state[i + c_offset];\n";
            out << SP << SP << SP << "}\n";
            if (fAttrInputForget == 0) {
               size_t pf_offset = 3 * batch_size * fAttrHiddenSize + batch_size * fAttrHiddenSize;
               out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
               out << SP << SP << SP << SP << OpName << "_forget_gate[i + offset] += tensor_" << fNP << "[i + "
                   << pf_offset << "] * " << OpName << "_cell_state[i + c_offset];\n";
               out << SP << SP << SP << "}\n";
            }
         }
         out << SP << SP << "}\n";
      }
1021
      // Clip the elements of the input gate into the range [-fAttrClip, fAttrClip]
      // (ONNX `clip` attribute; applied before the activation function).
      if (fAttrClip > .0) {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float x = (" << OpName << "_input_gate[i] > " << -fAttrClip << ") ? " << OpName
                << "_input_gate[i] : " << -fAttrClip << ";\n";
         }
         // NOTE(review): `x` is only emitted when fType == "float", yet the line below
         // uses it unconditionally - fine while float is the only supported type.
         out << SP << SP << SP << OpName << "_input_gate[i] = (x < " << fAttrClip << ") ? x : " << fAttrClip << ";\n";
         out << SP << SP << "}\n";
      }
1032 // Apply the activation function to the input gate
1033 if (fAttrActivations[direction * 3] == "Relu") {
1034 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1035 out << SP << SP << SP << "if (" << OpName << "_input_gate[i] < 0.)\n";
1036 out << SP << SP << SP << SP << OpName << "_input_gate[i] = 0.;\n";
1037 out << SP << SP << "}\n";
1038 } else if (fAttrActivations[direction * 3] == "Tanh") {
1039 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1040 if (fType == "float") {
1041 out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_input_gate[i]);\n";
1042 }
1043 out << SP << SP << SP << SP << OpName << "_input_gate[i] = (1. - ex) / (1. + ex);\n";
1044 out << SP << SP << "}\n";
1045 } else if (fAttrActivations[direction * 3] == "Sigmoid") {
1046 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1047 out << SP << SP << SP << SP << OpName << "_input_gate[i] = 1. / (1. + exp(-" << OpName
1048 << "_input_gate[i]));\n";
1049 out << SP << SP << "}\n";
1050 } else if (fAttrActivations[direction * 3] == "Affine") {
1051 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1052 out << SP << SP << SP << SP << OpName << "_input_gate[i] = " << fAttrActivationAlpha[direction * 3] << " * "
1053 << OpName << "_input_gate[i] + " << fAttrActivationBeta[direction * 3] << ";\n";
1054 out << SP << SP << "}\n";
1055 } else if (fAttrActivations[direction * 3] == "ScaledTanh") {
1056 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1057 if (fType == "float") {
1058 out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 3] << " * " << OpName
1059 << "_input_gate[i]);\n";
1060 }
1061 out << SP << SP << SP << SP << OpName << "_input_gate[i] = " << fAttrActivationAlpha[direction * 3]
1062 << " * (1. - ex) / (1. + ex);\n";
1063 out << SP << SP << "}\n";
1064 } else if (fAttrActivations[direction * 3] == "HardSigmoid") {
1065 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1066 if (fType == "float") {
1067 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 3] << " * " << OpName
1068 << "_input_gate[i] + " << fAttrActivationBeta[direction * 3] << ";\n";
1069 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
1070 }
1071 out << SP << SP << SP << SP << OpName << "_input_gate[i] = (b < 1.) ? b : 1.;\n";
1072 out << SP << SP << "}\n";
1073 } else if (fAttrActivations[direction * 3] == "LeakyRelu") {
1074 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1075 out << SP << SP << SP << "if (" << OpName << "_input_gate[i] < 0.)\n";
1076 out << SP << SP << SP << SP << OpName << "_input_gate[i] = " << fAttrActivationAlpha[direction * 3] << " * "
1077 << OpName << "_input_gate[i];\n";
1078 out << SP << SP << "}\n";
1079 } else if (fAttrActivations[direction * 3] == "ThresholdRelu") {
1080 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1081 out << SP << SP << SP << "if (" << OpName << "_input_gate[i] < " << fAttrActivationAlpha[direction * 3]
1082 << ")\n";
1083 out << SP << SP << SP << SP << OpName << "_input_gate[i] = 0.;\n";
1084 out << SP << SP << "}";
1085 } else if (fAttrActivations[direction * 3] == "Elu") {
1086 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1087 out << SP << SP << SP << "if (" << OpName << "_input_gate[i] < 0.)\n";
1088 out << SP << SP << SP << SP << OpName << "_input_gate[i] = " << fAttrActivationAlpha[direction * 3]
1089 << " * exp(" << OpName << "_input_gate[i] - 1.);\n";
1090 out << SP << SP << "}\n";
1091 } else if (fAttrActivations[direction * 3] == "Softsign") {
1092 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1093 out << SP << SP << SP << SP << OpName << "_input_gate[i] = " << OpName << "_input_gate[i] / (1. + abs("
1094 << OpName << "_input_gate[i]));\n";
1095 out << SP << SP << "}\n";
1096 } else { // fAttrActivations[direction * 3] = Softplus
1097 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1098 out << SP << SP << SP << SP << OpName << "_input_gate[i] = log(1. + exp(" << OpName << "_input_gate[i]));\n";
1099 out << SP << SP << "}\n";
1100 }
1101
1102 if (fAttrInputForget == 0) {
1103 // Clip the elements of the forget gate into the range [-fAttrClip, fAttrClip]
1104 if (fAttrClip > .0) {
1105 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1106 if (fType == "float") {
1107 out << SP << SP << SP << "float x = (" << OpName << "_forget_gate[i] > " << -fAttrClip << ") ? "
1108 << OpName << "_forget_gate[i] : " << -fAttrClip << ";\n";
1109 }
1110 out << SP << SP << SP << OpName << "_forget_gate[i] = (x < " << fAttrClip << ") ? x : " << fAttrClip
1111 << ";\n";
1112 out << SP << SP << "}\n";
1113 }
1114 // Apply the activation function to the forget gate, cell_gate = g(cell_gate)
1115 if (fAttrActivations[direction * 3] == "Relu") {
1116 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1117 out << SP << SP << SP << "if (" << OpName << "_forget_gate[i] < 0.)\n";
1118 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = 0.;\n";
1119 out << SP << SP << "}\n";
1120 } else if (fAttrActivations[direction * 3] == "Tanh") {
1121 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1122 if (fType == "float") {
1123 out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_forget_gate[i]);\n";
1124 }
1125 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = (1. - ex) / (1. + ex);\n";
1126 out << SP << SP << "}\n";
1127 } else if (fAttrActivations[direction * 3] == "Sigmoid") {
1128 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1129 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = 1. / (1. + exp(-" << OpName
1130 << "_forget_gate[i]));\n";
1131 out << SP << SP << "}\n";
1132 } else if (fAttrActivations[direction * 3] == "Affine") {
1133 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1134 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = " << fAttrActivationAlpha[direction * 3]
1135 << " * " << OpName << "_forget_gate[i] + " << fAttrActivationBeta[direction * 3] << ";\n";
1136 out << SP << SP << "}\n";
1137 } else if (fAttrActivations[direction * 3] == "ScaledTanh") {
1138 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1139 if (fType == "float") {
1140 out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 3] << " * " << OpName
1141 << "_forget_gate[i]);\n";
1142 }
1143 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = " << fAttrActivationAlpha[direction * 3]
1144 << " * (1. - ex) / (1. + ex);\n";
1145 out << SP << SP << "}\n";
1146 } else if (fAttrActivations[direction * 3] == "HardSigmoid") {
1147 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1148 if (fType == "float") {
1149 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 3] << " * " << OpName
1150 << "_forget_gate[i] + " << fAttrActivationBeta[direction * 3] << ";\n";
1151 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
1152 }
1153 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = (b < 1.) ? b : 1.;\n";
1154 out << SP << SP << "}\n";
1155 } else if (fAttrActivations[direction * 3] == "LeakyRelu") {
1156 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1157 out << SP << SP << SP << "if (" << OpName << "_forget_gate[i] < 0.)\n";
1158 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = " << fAttrActivationAlpha[direction * 3]
1159 << " * " << OpName << "_forget_gate[i];\n";
1160 out << SP << SP << "}\n";
1161 } else if (fAttrActivations[direction * 3] == "ThresholdRelu") {
1162 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1163 out << SP << SP << SP << "if (" << OpName << "_forget_gate[i] < " << fAttrActivationAlpha[direction * 3]
1164 << ")\n";
1165 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = 0.;\n";
1166 out << SP << SP << "}";
1167 } else if (fAttrActivations[direction * 3] == "Elu") {
1168 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1169 out << SP << SP << SP << "if (" << OpName << "_forget_gate[i] < 0.)\n";
1170 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = " << fAttrActivationAlpha[direction * 3]
1171 << " * exp(" << OpName << "_forget_gate[i] - 1.);\n";
1172 out << SP << SP << "}\n";
1173 } else if (fAttrActivations[direction * 3] == "Softsign") {
1174 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1175 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = " << OpName << "_forget_gate[i] / (1. + abs("
1176 << OpName << "_forget_gate[i]));\n";
1177 out << SP << SP << "}\n";
1178 } else { // fAttrActivations[direction * 3] = Softplus
1179 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1180 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = log(1. + exp(" << OpName
1181 << "_forget_gate[i]));\n";
1182 out << SP << SP << "}\n";
1183 }
1184 }
1185
      // cell_state = input_gate o cell_gate  (elementwise / Hadamard product);
      // the forget-gate term is added separately below when input_forget == 0.
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << OpName << "_cell_state[i] = " << OpName << "_input_gate[i] * " << OpName
          << "_cell_gate[i];\n";
      out << SP << SP << "}\n";
1191
1192 if (fAttrInputForget == 0) {
1193 out << SP << SP << "if (seq == 0) {\n";
1194 if (!fNInitial_c.empty()) {
1195 // cell_state += forget_gate o initial_cell_state
1196 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
1197 out << SP << SP << SP << SP << OpName << "_cell_state[i + offset] += " << OpName
1198 << "_forget_gate[i + offset] * " << OpName << "_initial_cell_state[i];\n";
1199 out << SP << SP << SP << "}\n";
1200 }
1201 out << SP << SP << "} else {\n";
1202 // cell_state += forget_gate o previous_cell_state
1203 if (direction == 0) {
1204 if (fAttrDirection == "backward") {
1205 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
1206 << num_directions * batch_size * fAttrHiddenSize << ";\n";
1207 } else {
1208 out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
1209 << num_directions * batch_size * fAttrHiddenSize << ";\n";
1210 }
1211 } else { // direction=1
1212 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
1213 << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
1214 }
1215 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
1216 out << SP << SP << SP << SP << OpName << "_cell_state[i + offset] += " << OpName
1217 << "_forget_gate[i + offset] * " << OpName << "_cell_state[i + previous_offset];\n";
1218 out << SP << SP << SP << "}\n";
1219 out << SP << SP << "}\n";
1220 }
1221
      if (!fNP.empty()) {
         // Peephole connection for the output gate: it uses the cell state of the
         // CURRENT timestep (just computed above), unlike the input/forget peepholes
         // which use the previous cell state. The P offsets (2 * batch*hidden, plus
         // 3 * batch*hidden for the backward direction) select the third peephole
         // block {i, o, f} of the corresponding direction.
         if (direction == 0) {
            size_t p_offset = 2 * batch_size * fAttrHiddenSize;
            out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
            out << SP << SP << SP << SP << OpName << "_output_gate[i + offset] += tensor_" << fNP << "[i + " << p_offset
                << "] * " << OpName << "_cell_state[i + offset];\n";
            out << SP << SP << SP << "}\n";
         } else { // direction=1
            size_t p_offset = 3 * batch_size * fAttrHiddenSize + 2 * batch_size * fAttrHiddenSize;
            out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
            out << SP << SP << SP << SP << OpName << "_output_gate[i + offset] += tensor_" << fNP << "[i + " << p_offset
                << "] * " << OpName << "_cell_state[i + offset];\n";
            out << SP << SP << SP << "}\n";
         }
      }
1238
      // Clip the elements of the output gate into the range [-fAttrClip, fAttrClip]
      // (ONNX `clip` attribute; applied before the activation function).
      if (fAttrClip > .0) {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float x = (" << OpName << "_output_gate[i] > " << -fAttrClip << ") ? " << OpName
                << "_output_gate[i] : " << -fAttrClip << ";\n";
         }
         // NOTE(review): `x` is only emitted when fType == "float", yet the line below
         // uses it unconditionally - fine while float is the only supported type.
         out << SP << SP << SP << OpName << "_output_gate[i] = (x < " << fAttrClip << ") ? x : " << fAttrClip << ";\n";
         out << SP << SP << "}\n";
      }
1249 // Apply the activation function to the output gate
1250 if (fAttrActivations[direction * 3] == "Relu") {
1251 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1252 out << SP << SP << SP << "if (" << OpName << "_output_gate[i] < 0.)\n";
1253 out << SP << SP << SP << SP << OpName << "_output_gate[i] = 0.;\n";
1254 out << SP << SP << "}\n";
1255 } else if (fAttrActivations[direction * 3] == "Tanh") {
1256 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1257 if (fType == "float") {
1258 out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_output_gate[i]);\n";
1259 }
1260 out << SP << SP << SP << SP << OpName << "_output_gate[i] = (1. - ex) / (1. + ex);\n";
1261 out << SP << SP << "}\n";
1262 } else if (fAttrActivations[direction * 3] == "Sigmoid") {
1263 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1264 out << SP << SP << SP << SP << OpName << "_output_gate[i] = 1. / (1. + exp(-" << OpName
1265 << "_output_gate[i]));\n";
1266 out << SP << SP << "}\n";
1267 } else if (fAttrActivations[direction * 3] == "Affine") {
1268 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1269 out << SP << SP << SP << SP << OpName << "_output_gate[i] = " << fAttrActivationAlpha[direction * 3] << " * "
1270 << OpName << "_output_gate[i] + " << fAttrActivationBeta[direction * 3] << ";\n";
1271 out << SP << SP << "}\n";
1272 } else if (fAttrActivations[direction * 3] == "ScaledTanh") {
1273 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1274 if (fType == "float") {
1275 out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 3] << " * " << OpName
1276 << "_output_gate[i]);\n";
1277 }
1278 out << SP << SP << SP << SP << OpName << "_output_gate[i] = " << fAttrActivationAlpha[direction * 3]
1279 << " * (1. - ex) / (1. + ex);\n";
1280 out << SP << SP << "}\n";
1281 } else if (fAttrActivations[direction * 3] == "HardSigmoid") {
1282 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1283 if (fType == "float") {
1284 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 3] << " * " << OpName
1285 << "_output_gate[i] + " << fAttrActivationBeta[direction * 3] << ";\n";
1286 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
1287 }
1288 out << SP << SP << SP << SP << OpName << "_output_gate[i] = (b < 1.) ? b : 1.;\n";
1289 out << SP << SP << "}\n";
1290 } else if (fAttrActivations[direction * 3] == "LeakyRelu") {
1291 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1292 out << SP << SP << SP << "if (" << OpName << "_output_gate[i] < 0.)\n";
1293 out << SP << SP << SP << SP << OpName << "_output_gate[i] = " << fAttrActivationAlpha[direction * 3] << " * "
1294 << OpName << "_output_gate[i];\n";
1295 out << SP << SP << "}\n";
1296 } else if (fAttrActivations[direction * 3] == "ThresholdRelu") {
1297 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1298 out << SP << SP << SP << "if (" << OpName << "_output_gate[i] < " << fAttrActivationAlpha[direction * 3]
1299 << ")\n";
1300 out << SP << SP << SP << SP << OpName << "_output_gate[i] = 0.;\n";
1301 out << SP << SP << "}";
1302 } else if (fAttrActivations[direction * 3] == "Elu") {
1303 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1304 out << SP << SP << SP << "if (" << OpName << "_output_gate[i] < 0.)\n";
1305 out << SP << SP << SP << SP << OpName << "_output_gate[i] = " << fAttrActivationAlpha[direction * 3]
1306 << " * exp(" << OpName << "_output_gate[i] - 1.);\n";
1307 out << SP << SP << "}\n";
1308 } else if (fAttrActivations[direction * 3] == "Softsign") {
1309 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1310 out << SP << SP << SP << SP << OpName << "_output_gate[i] = " << OpName << "_output_gate[i] / (1. + abs("
1311 << OpName << "_output_gate[i]));\n";
1312 out << SP << SP << "}\n";
1313 } else { // fAttrActivations[direction * 3] = Softplus
1314 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1315 out << SP << SP << SP << SP << OpName << "_output_gate[i] = log(1. + exp(" << OpName << "_output_gate[i]));\n";
1316 out << SP << SP << "}\n";
1317 }
1318
1319 // copy cell_state into new_cell_state
1320 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName << "_cell_state + offset + "
1321 << size << ", " << OpName << "_new_cell_state + offset);\n";
1322 // Clip the elements of the new_cell_state into the range [-fAttrClip, fAttrClip]
1323 if (fAttrClip > .0) {
1324 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1325 if (fType == "float") {
1326 out << SP << SP << SP << "float x = (" << OpName << "_new_cell_state[i] > " << -fAttrClip << ") ? "
1327 << OpName << "_new_cell_state[i] : " << -fAttrClip << ";\n";
1328 }
1329 out << SP << SP << SP << OpName << "_new_cell_state[i] = (x < " << fAttrClip << ") ? x : " << fAttrClip
1330 << ";\n";
1331 out << SP << SP << "}\n";
1332 }
1333 // Apply the activation function to the new cell state
1334 if (fAttrActivations[direction * 3 + 2] == "Relu") {
1335 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1336 out << SP << SP << SP << "if (" << OpName << "_new_cell_state[i] < 0.)\n";
1337 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = 0.;\n";
1338 out << SP << SP << "}\n";
1339 } else if (fAttrActivations[direction * 3 + 2] == "Tanh") {
1340 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1341 if (fType == "float") {
1342 out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_new_cell_state[i]);\n";
1343 }
1344 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = (1. - ex) / (1. + ex);\n";
1345 out << SP << SP << "}\n";
1346 } else if (fAttrActivations[direction * 3 + 2] == "Sigmoid") {
1347 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1348 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = 1. / (1. + exp(-" << OpName
1349 << "_new_cell_state[i]));\n";
1350 out << SP << SP << "}\n";
1351 } else if (fAttrActivations[direction * 3 + 2] == "Affine") {
1352 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1353 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = " << fAttrActivationAlpha[direction * 3 + 2]
1354 << " * " << OpName << "_new_cell_state[i] + " << fAttrActivationBeta[direction * 3 + 2] << ";\n";
1355 out << SP << SP << "}\n";
1356 } else if (fAttrActivations[direction * 3 + 2] == "ScaledTanh") {
1357 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1358 if (fType == "float") {
1359 out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 3 + 2] << " * " << OpName
1360 << "_new_cell_state[i]);\n";
1361 }
1362 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = " << fAttrActivationAlpha[direction * 3 + 2]
1363 << " * (1. - ex) / (1. + ex);\n";
1364 out << SP << SP << "}\n";
1365 } else if (fAttrActivations[direction * 3 + 2] == "HardSigmoid") {
1366 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1367 if (fType == "float") {
1368 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 3 + 2] << " * " << OpName
1369 << "_new_cell_state[i] + " << fAttrActivationBeta[direction * 3 + 2] << ";\n";
1370 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
1371 }
1372 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = (b < 1.) ? b : 1.;\n";
1373 out << SP << SP << "}\n";
1374 } else if (fAttrActivations[direction * 3 + 2] == "LeakyRelu") {
1375 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1376 out << SP << SP << SP << "if (" << OpName << "_new_cell_state[i] < 0.)\n";
1377 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = " << fAttrActivationAlpha[direction * 3 + 2]
1378 << " * " << OpName << "_new_cell_state[i];\n";
1379 out << SP << SP << "}\n";
1380 } else if (fAttrActivations[direction * 3 + 2] == "ThresholdRelu") {
1381 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1382 out << SP << SP << SP << "if (" << OpName << "_new_cell_state[i] < " << fAttrActivationAlpha[direction * 3 + 2]
1383 << ")\n";
1384 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = 0.;\n";
1385 out << SP << SP << "}";
1386 } else if (fAttrActivations[direction * 3 + 2] == "Elu") {
1387 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1388 out << SP << SP << SP << "if (" << OpName << "_new_cell_state[i] < 0.)\n";
1389 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = " << fAttrActivationAlpha[direction * 3 + 2]
1390 << " * exp(" << OpName << "_new_cell_state[i] - 1.);\n";
1391 out << SP << SP << "}\n";
1392 } else if (fAttrActivations[direction * 3 + 2] == "Softsign") {
1393 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1394 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = " << OpName << "_new_cell_state[i] / (1. + abs("
1395 << OpName << "_new_cell_state[i]));\n";
1396 out << SP << SP << "}\n";
1397 } else { // fAttrActivations[direction * 3 + 2] = Softplus
1398 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1399 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = log(1. + exp(" << OpName
1400 << "_new_cell_state[i]));\n";
1401 out << SP << SP << "}\n";
1402 }
1403
1404 // hidden_state = output_gate o new_cell_state
1405 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1406 out << SP << SP << SP << OpName << "_hidden_state[i] = " << OpName << "_output_gate[i] * " << OpName
1407 << "_new_cell_state[i];\n";
1408 out << SP << SP << "}\n";
1409 out << SP << "}\n";
1410 }
1411
1412 // Padding the hidden state for LSTM with different sequence lengths
1413 if (!fNSequence_lens.empty()) {
1414 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
1415 out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1416 out << SP << SP << SP << "if (seq >= tensor_" << fNSequence_lens << "[batch]) {\n";
1417 for (size_t direction = 0; direction < num_directions; direction++) {
1418 out << SP << SP << SP << SP << SP << "for (size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
1419 out << SP << SP << SP << SP << SP << SP << "size_t idx = seq * "
1420 << num_directions * batch_size * fAttrHiddenSize + direction * batch_size * fAttrHiddenSize
1421 << " + batch * " << fAttrHiddenSize << " + h;\n";
1422 out << SP << SP << SP << SP << SP << SP << OpName << "_cell_state[idx] = 0.;\n";
1423 out << SP << SP << SP << SP << SP << SP << OpName << "_hidden_state[idx] = 0.;\n";
1424 out << SP << SP << SP << SP << SP << "}\n";
1425 }
1426 out << SP << SP << SP << "}\n";
1427 out << SP << SP << "}\n";
1428 out << SP << "}\n";
1429 }
1430
1431 // Copy the hidden state into y and y_h and copy cell_state into y_c
1432 if (fAttrLayout == 0) {
1433 if (!fNY_h.empty()) {
1434 // Copy hidden_state into Y_h
1435 if (fNSequence_lens.empty()) {
1436 size_t y_h_size = batch_size * fAttrHiddenSize;
1437 if (fAttrDirection == "backward") {
1438 out << SP << "std::copy(" << OpName << "_hidden_state, " << OpName << "_hidden_state + " << y_h_size
1439 << ", tensor_" << fNY_h << ");\n";
1440 } else {
1441 size_t offset = (seq_length - 1) * num_directions * batch_size * fAttrHiddenSize;
1442 out << SP << "std::copy(" << OpName << "_hidden_state + " << offset << ", " << OpName
1443 << "_hidden_state + " << offset << " + " << y_h_size << ", tensor_" << fNY_h << ");\n";
1444 }
1445 if (num_directions == 2) {
1446 out << SP << "std::copy(" << OpName << "_hidden_state + " << y_h_size << ", " << OpName
1447 << "_hidden_state + " << 2 * y_h_size << ", tensor_" << fNY_h << " + " << y_h_size << ");\n";
1448 }
1449 } else { // LSTM with different sequence lengths
1450 if (fAttrDirection == "backward") {
1451 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1452 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
1453 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1454 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + offset);\n";
1455 out << SP << "}\n";
1456 } else {
1457 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1458 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
1459 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
1460 << " + batch * " << fAttrHiddenSize << ";\n";
1461 out << SP << SP << "size_t y_h_offset = batch * " << fAttrHiddenSize << ";\n";
1462 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1463 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + y_h_offset);\n";
1464 out << SP << "}\n";
1465 }
1466 if (num_directions == 2) {
1467 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1468 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
1469 << ";\n";
1470 out << SP << SP << "size_t y_h_offset = " << batch_size * fAttrHiddenSize << " + batch * "
1471 << fAttrHiddenSize << ";\n";
1472 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1473 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + y_h_offset);\n";
1474 out << SP << "}\n";
1475 }
1476 }
1477 }
1478 if (!fNY_c.empty()) {
1479 // Copy cell_state into Y_c
1480 if (fNSequence_lens.empty()) {
1481 size_t y_h_size = batch_size * fAttrHiddenSize;
1482 if (fAttrDirection == "backward") {
1483 out << SP << "std::copy(" << OpName << "_cell_state, " << OpName << "_hidden_state + " << y_h_size
1484 << ", tensor_" << fNY_c << ");\n";
1485 } else {
1486 size_t offset = (seq_length - 1) * num_directions * batch_size * fAttrHiddenSize;
1487 out << SP << "std::copy(" << OpName << "_cell_state + " << offset << ", " << OpName << "_cell_state + "
1488 << offset << " + " << y_h_size << ", tensor_" << fNY_c << ");\n";
1489 }
1490 if (num_directions == 2) {
1491 out << SP << "std::copy(" << OpName << "_cell_state + " << y_h_size << ", " << OpName << "_cell_state + "
1492 << 2 * y_h_size << ", tensor_" << fNY_c << " + " << y_h_size << ");\n";
1493 }
1494 } else { // LSTM with different sequence lengths
1495 if (fAttrDirection == "backward") {
1496 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1497 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
1498 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName
1499 << "_cell_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_c << " + offset);\n";
1500 out << SP << "}\n";
1501 } else {
1502 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1503 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
1504 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
1505 << " + batch * " << fAttrHiddenSize << ";\n";
1506 out << SP << SP << "size_t y_h_offset = batch * " << fAttrHiddenSize << ";\n";
1507 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName
1508 << "_cell_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_c << " + y_h_offset);\n";
1509 out << SP << "}\n";
1510 }
1511 if (num_directions == 2) {
1512 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1513 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
1514 << ";\n";
1515 out << SP << SP << "size_t y_h_offset = " << batch_size * fAttrHiddenSize << " + batch * "
1516 << fAttrHiddenSize << ";\n";
1517 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName
1518 << "_cell_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_c << " + y_h_offset);\n";
1519 out << SP << "}\n";
1520 }
1521 }
1522 }
1523 } else { // fAttrLayout=1
1524 if (!fNY.empty()) {
1525 // Copy hidden_state into Y
1526 for (size_t direction = 0; direction < num_directions; direction++) {
1527 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
1528 out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1529 out << SP << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize << " + "
1530 << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize << ";\n";
1531 out << SP << SP << SP << "size_t y_offset = batch * " << seq_length * num_directions * fAttrHiddenSize
1532 << " + seq * " << num_directions * fAttrHiddenSize << " + " << direction * fAttrHiddenSize << ";\n";
1533 out << SP << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1534 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY << " + y_offset);\n";
1535 out << SP << SP << "}\n";
1536 out << SP << "}\n";
1537 }
1538 }
1539 if (!fNY_h.empty()) {
1540 // Copy the hidden_state into Y_h
1541 if (fAttrDirection == "backward") {
1542 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1543 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
1544 out << SP << SP << "size_t y_h_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
1545 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1546 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + y_h_offset);\n";
1547 out << SP << "}\n";
1548 } else {
1549 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1550 if (fNSequence_lens.empty()) {
1551 out << SP << SP << "size_t seq = " << seq_length - 1 << ";\n";
1552 } else {
1553 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
1554 }
1555 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
1556 << " + batch * " << fAttrHiddenSize << ";\n";
1557 out << SP << SP << "size_t y_h_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
1558 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1559 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + y_h_offset);\n";
1560 out << SP << "}\n";
1561 }
1562 if (num_directions == 2) {
1563 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1564 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
1565 << ";\n";
1566 out << SP << SP << "size_t y_h_offset = batch * " << num_directions * fAttrHiddenSize << " + "
1567 << fAttrHiddenSize << ";\n";
1568 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1569 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + y_h_offset);\n";
1570 out << SP << "}\n";
1571 }
1572 }
1573
1574 if (!fNY_c.empty()) {
1575 // copy the cell_state into Y_c
1576 if (fAttrDirection == "backward") {
1577 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1578 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
1579 out << SP << SP << "size_t y_h_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
1580 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName << "_cell_state + offset + "
1581 << fAttrHiddenSize << ", tensor_" << fNY_c << " + y_h_offset);\n";
1582 out << SP << "}\n";
1583 } else {
1584 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1585 if (fNSequence_lens.empty()) {
1586 out << SP << SP << "size_t seq = " << seq_length - 1 << ";\n";
1587 } else {
1588 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
1589 }
1590 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
1591 << " + batch * " << fAttrHiddenSize << ";\n";
1592 out << SP << SP << "size_t y_h_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
1593 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName << "_cell_state + offset + "
1594 << fAttrHiddenSize << ", tensor_" << fNY_c << " + y_h_offset);\n";
1595 out << SP << "}\n";
1596 }
1597 if (num_directions == 2) {
1598 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1599 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
1600 << ";\n";
1601 out << SP << SP << "size_t y_h_offset = batch * " << num_directions * fAttrHiddenSize << " + "
1602 << fAttrHiddenSize << ";\n";
1603 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName << "_cell_state + offset + "
1604 << fAttrHiddenSize << ", tensor_" << fNY_c << " + y_h_offset);\n";
1605 out << SP << "}\n";
1606 }
1607 }
1608 }
1609
1610 return out.str();
1611}
1612
1613} // namespace TMVA::Experimental::SOFIE
1614
1615#endif
#define b(i)
Definition RSha256.hxx:100
#define h(i)
Definition RSha256.hxx:106
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
char name[80]
Definition TGX11.cxx:146
const_iterator begin() const
const_iterator end() const
Long Short-Term Memory operator.
std::string GenerateSessionMembersCode(std::string opName) override
Generate the code for the Session internal data vectors.
std::vector< size_t > fShapeR
Shape of the recurrence.
std::vector< size_t > fShapeInitial_c
Shape of the initial value of the cell states.
std::string fNY_c
Name of the last sequence of the cell states.
ROperator_LSTM(std::vector< float > activation_alpha, std::vector< float > activation_beta, std::vector< std::string > activations, float clip, std::string direction, size_t hidden_size, size_t input_forget, size_t layout, std::string nameX, std::string nameW, std::string nameR, std::string nameB, std::string nameSequence_lens, std::string nameInitial_h, std::string nameInitial_c, std::string nameP, std::string nameY, std::string nameY_h, std::string nameY_c)
Constructor of ROperator_LSTM from the attributes.
std::string fNR
Name of the recurrence.
std::vector< size_t > fShapeY_h
Shape of the last sequence of the output.
std::vector< float > fAttrActivationAlpha
Scaling values used by some activation functions.
std::vector< size_t > fShapeInitial_h
Shape of the initial value of the hidden states.
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input) override
Infers the shape of the output tensors.
size_t fAttrHiddenSize
Number of the hidden layers.
std::vector< std::string > GetBlasRoutines() override
Returns the blas routines needed to compile the generated code.
std::string fNInitial_c
Name of the initial value of the cell states.
std::vector< size_t > fShapeB
Shape of the bias.
std::string fNW
Name of the weights.
std::vector< size_t > fShapeY
Shape of the output.
std::vector< size_t > fShapeP
Shape of the peepholes.
std::string fType
Type of the tensors.
std::string fAttrDirection
Direction of processing.
std::vector< float > fAttrActivationBeta
Scaling values used by some activation functions.
std::vector< size_t > fShapeX
Shape of the input.
std::vector< size_t > fShapeW
Shape of the weights.
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input) override
Infers the type of the output tensors.
std::string fNSequence_lens
Name of the length of the sequences.
std::string fNY_h
Name of the last sequence of the output.
std::string fNY
Name of the output.
void Initialize(RModel &) override
Initialize the model.
std::vector< std::string > fAttrActivations
Activation functions.
std::string Generate(std::string OpName) override
Generate the inference code.
std::vector< size_t > fShapeSequence_lens
Shape of the length of the sequences.
std::vector< size_t > fShapeY_c
Shape of the last sequence of the cell states.
std::string fNInitial_h
Name of the initial value of the hidden states.
ROperator_LSTM()
Default constructor of ROperator_LSTM.
std::vector< std::string_view > fInputTensorNames
Definition ROperator.hxx:47
std::vector< std::string_view > fOutputTensorNames
Definition ROperator.hxx:48
static uint64_t sum(uint64_t i)
Definition Factory.cxx:2338