Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_LSTM.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_LSTM
2#define TMVA_SOFIE_ROPERATOR_LSTM
3
#include "TMVA/RModel.hxx"
#include "TMVA/ROperator.hxx"

#include <algorithm>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>
12
14
15/*! \brief Long Short-Term Memory operator
16 *
17 * Inference code generation for one-layer LSTM. Supports forward, reverse and bidirectional LSTM.
18 * See the <a href="https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM">ONNX documentation</a>
19 * for details about the supported LSTM architectures.
20 */
21template <typename T> class ROperator_LSTM final : public ROperator {
22 private:
23 std::vector<float> fAttrActivationAlpha; ///< Sacling values used by some activation functions
24 std::vector<float> fAttrActivationBeta; ///< Scaling values used by some activation functions
25 std::vector<std::string> fAttrActivations; ///< Activation functions
26 float fAttrClip; ///< Clip threshold
27 std::string fAttrDirection; ///< Direction of processing
28 size_t fAttrHiddenSize; ///< Number of the hidden layers
29 size_t fAttrInputForget; ///< Forget gate
30 size_t fAttrLayout; ///< Data layout
31
32 std::string fNX; ///< Name of the input
33 std::string fNW; ///< Name of the weights
34 std::string fNR; ///< Name of the recurrence
35 std::string fNB; ///< Name of the bias
36 std::string fNSequence_lens; ///< Name of length of the sequences
37 std::string fNInitial_h; ///< Name of the initial value of the hidden states
38 std::string fNInitial_c; ///< Name of the initial value of the cell states
39 std::string fNP; ///< Name of peepholes
40 std::string fNY; ///< Name of the output
41 std::string fNY_h; ///< Name of the last sequence of the output
42 std::string fNY_c; ///< Name of the last sequence of the cell states
43
44 std::vector<size_t> fShapeX; ///< Shape of the input
45 std::vector<size_t> fShapeW; ///< Shape of the weights
46 std::vector<size_t> fShapeR; ///< Shape of the recurrence
47 std::vector<size_t> fShapeB; ///< Shape of the bias
48 std::vector<size_t> fShapeSequence_lens; ///< Shape of the length of the sequences
49 std::vector<size_t> fShapeInitial_h; ///< Shape of the initial value of the hidden states
50 std::vector<size_t> fShapeInitial_c; ///< Shape of the initial value of the cell states
51 std::vector<size_t> fShapeP; ///< Shape of the peepholes
52 std::vector<size_t> fShapeY; ///< Shape of the output
53 std::vector<size_t> fShapeY_h; ///< Shape of the last sequence of the output
54 std::vector<size_t> fShapeY_c; ///< Shape of the last sequence of the cell states
55
56 std::string fType; ///< Type of the tensors
57
58 public:
59 /*! Default constructor of ROperator_LSTM */
61
62 /*! \brief Constructor of ROperator_LSTM from the attributes
63 *
64 * \param activation_alpha scaling values used by some activation functions
65 * \param activation_beta scaling values used by some activation functions
66 * \param activations activation functions
67 * \param clip clip threshold
68 * \param direction direction of processing of the sequneces
69 * \param hidden_size number of hidden layers
70 * \param input_forget forget gate
71 * \param layout data layout
72 * \param nameX name of the input tensor
73 * \param nameW name of the weight tensor
74 * \param nameR name of the recurrence tensor
75 * \param nameB name of the bias tensor
76 * \param nameSequence_lens name of the length of the sequences
77 * \param nameInitial_h name of the initial value of the hidden states
78 * \param nameInitial_c name of the initial value of the cell states
79 * \param nameP name of the peepholes tensor
80 * \param nameY name of the output
81 * \param nameY_h name of the last sequence of the output
82 * \param nameY_c name of the last sequence of the cell states
83 */
84 ROperator_LSTM(std::vector<float> activation_alpha,
85 std::vector<float> activation_beta,
86 std::vector<std::string> activations, float clip,
87 std::string direction, size_t hidden_size,
88 size_t input_forget, size_t layout,
89 std::string nameX, std::string nameW, std::string nameR,
90 std::string nameB, std::string nameSequence_lens,
91 std::string nameInitial_h, std::string nameInitial_c, std::string nameP,
92 std::string nameY, std::string nameY_h, std::string nameY_c)
97 fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)),
98 fNR(UTILITY::Clean_name(nameR)), fNB(UTILITY::Clean_name(nameB)),
99 fNSequence_lens(UTILITY::Clean_name(nameSequence_lens)),
100 fNInitial_h(UTILITY::Clean_name(nameInitial_h)),
101 fNInitial_c(UTILITY::Clean_name(nameInitial_c)), fNP(UTILITY::Clean_name(nameP)),
102 fNY(UTILITY::Clean_name(nameY)), fNY_h(UTILITY::Clean_name(nameY_h)),
103 fNY_c(UTILITY::Clean_name(nameY_c)) {
104 if (std::is_same<T, float>::value) {
105 fType = "float";
106 } else {
107 throw std::runtime_error(
108 "TMVA SOFIE Encountered unsupported type parsing a LSTM operator");
109 }
110
112 if (!fNB.empty()){
113 fInputTensorNames.emplace_back(fNB);
114 }
115 if (!fNSequence_lens.empty()){
117 }
118 if (!fNInitial_h.empty()){
119 fInputTensorNames.emplace_back(fNInitial_h);
120 }
121 if (!fNInitial_c.empty()){
122 fInputTensorNames.emplace_back(fNInitial_c);
123 }
124 if (!fNP.empty()){
125 fInputTensorNames.emplace_back(fNP);
126 }
127
128 fOutputTensorNames = { };
129 if (!fNY.empty()){
130 fOutputTensorNames.emplace_back(fNY);
131 }
132 if (!fNY_h.empty()){
133 fOutputTensorNames.emplace_back(fNY_h);
134 }
135 if (!fNY_c.empty()){
136 fOutputTensorNames.emplace_back(fNY_c);
137 }
138 }
139
140 /*! \brief Infers the type of the output tensors
141 *
142 * \param input type of the input tensors
143 */
144 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override;
145
146 /*! \brief Infers the shape of the output tensors
147 *
148 * \param input shape of the input tensors
149 */
150 std::vector<std::vector<size_t>>
151 ShapeInference(std::vector<std::vector<size_t>> input) override;
152
153 /*! \brief Initialize the model
154 *
155 * \param model Model
156 */
157 void Initialize(RModel &) override;
158
159 /*! \brief Generate the inference code
160 *
161 * \param OpName name of the operator
162 */
163 std::string Generate(std::string OpName) override;
164
165 /*! \brief Generate the code for the Session internal data vectors
166 *
167 * \param opName name of the operator
168 */
169 std::string GenerateSessionMembersCode(std::string opName) override;
170
171 /*! \brief Returns the blas routines needed to compile the generated code
172 */
173 std::vector<std::string> GetBlasRoutines() override { return { std::string("Gemm"), std::string("Axpy") }; }
174};
175
176template <typename T>
177auto ROperator_LSTM<T>::TypeInference(std::vector<ETensorType> input) -> std::vector<ETensorType>
178{
179 ETensorType out = input[0];
180 return {out, out};
181}
182
183template <typename T>
184auto ROperator_LSTM<T>::ShapeInference(std::vector<std::vector<size_t>> input) -> std::vector<std::vector<size_t>>
185{
186 size_t num_directions = input[1][0];
187 size_t hidden_size = input[1][1] / 4;
188 if (fAttrLayout == 0) {
189 size_t seq_length = input[0][0];
190 size_t batch_size = input[0][1];
191 std::vector<std::vector<size_t>> ret({{seq_length, num_directions, batch_size, hidden_size},
194 return ret;
195 } else {
196 size_t batch_size = input[0][0];
197 size_t seq_length = input[0][1];
198 std::vector<std::vector<size_t>> ret({{batch_size, seq_length, num_directions, hidden_size},
201 return ret;
202 }
203}
204
205template <typename T>
207{
208 fUseSession = model.UseSession();
209 // Check the input and output tensors
210 if (!model.CheckIfTensorAlreadyExist(fNX)) {
211 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNX + " is not found in model.");
212 }
213 fShapeX = model.GetTensorShape(fNX);
214 if (fShapeX.size() != 3) {
215 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNX + " is not of 3 dimensions.");
216 }
217 if (!model.CheckIfTensorAlreadyExist(fNW)) {
218 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNW + " is not found in model.");
219 }
220 fShapeW = model.GetTensorShape(fNW);
221 if (fShapeW.size() != 3) {
222 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNW + " is not of 3 dimensions.");
223 }
224 if (!model.CheckIfTensorAlreadyExist(fNR)) {
225 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNR + " is not found in model.");
226 }
227 fShapeR = model.GetTensorShape(fNR);
228 if (fShapeR.size() != 3) {
229 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNR + " is not of 3 dimensions.");
230 }
231 if (!fNB.empty()) {
232 if (!model.CheckIfTensorAlreadyExist(fNB)) {
233 throw std::runtime_error("TMVA SOFIE LSTM op input tensor " + fNB + " is not found in model.");
234 }
235 fShapeB = model.GetTensorShape(fNB);
236 if (fShapeB.size() != 2 && fShapeB.size() != 5) {
237 throw std::runtime_error("TMVA SOFIE LSTM op input tensor " + fNB + " is not of 2 or 5 dimensions.");
238 }
239 if (fShapeB.size() == 2) {
240 // Broadcasting the bias
241 auto original_data = model.GetInitializedTensorData(fNB);
242 size_t num_directions = fShapeW[0];
243 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
244 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
245 if (fType == "float") {
246 float *original_bias = static_cast<float *>(original_data.get());
247 float *new_bias = new float[4 * num_directions * seq_length * batch_size * fAttrHiddenSize];
248 for (size_t gate = 0; gate < 4; gate++) {
249 std::vector<float> sum(fAttrHiddenSize);
250 for (size_t direction = 0; direction < num_directions; direction++) {
251 size_t offset = direction * 8 * fAttrHiddenSize + gate * fAttrHiddenSize;
252 for (size_t h = 0; h < fAttrHiddenSize; h++) {
253 sum[h] = original_bias[offset + h] + original_bias[offset + h + 4 * fAttrHiddenSize];
254 }
255 for (size_t seq = 0; seq < seq_length; seq++) {
256 for (size_t batch = 0; batch < batch_size; batch++) {
257 size_t bias_offset = gate * num_directions * seq_length * batch_size * fAttrHiddenSize +
258 direction * seq_length * batch_size * fAttrHiddenSize +
259 seq * batch_size * fAttrHiddenSize + batch * fAttrHiddenSize;
260 std::copy(sum.begin(), sum.end(), new_bias + bias_offset);
261 }
262 }
263 }
264 }
265 std::vector<size_t> new_bias_shape = {4, num_directions, seq_length, batch_size, fAttrHiddenSize};
266 std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<float[]>());
267 model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), new_bias_shape, new_bias_ptr);
268 fShapeB = model.GetTensorShape(fNB);
269 }
270 }
271 }
272 if (!fNSequence_lens.empty()) {
273 if (!model.CheckIfTensorAlreadyExist(fNSequence_lens)) {
274 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNSequence_lens + "is not found in model.");
275 }
276 fShapeSequence_lens = model.GetTensorShape(fNSequence_lens);
277 if (fShapeSequence_lens.size() != 1) {
278 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNSequence_lens + " is not of 1 dimension.");
279 }
280 }
281 if (!fNInitial_h.empty()) {
282 if (!model.CheckIfTensorAlreadyExist(fNInitial_h)) {
283 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNInitial_h + " is not found in model.");
284 }
285 fShapeInitial_h = model.GetTensorShape(fNInitial_h);
286 if (fShapeInitial_h.size() != 3) {
287 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNInitial_h + " is not of 3 dimensions.");
288 }
289 }
290 if (!fNInitial_c.empty()) {
291 if (!model.CheckIfTensorAlreadyExist(fNInitial_c)) {
292 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNInitial_c + " is not found in model.");
293 }
294 fShapeInitial_c = model.GetTensorShape(fNInitial_c);
295 if (fShapeInitial_c.size() != 3) {
296 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNInitial_c + " is not of 3 dimensions.");
297 }
298 }
299 if (!fNP.empty()) {
300 if (!model.CheckIfTensorAlreadyExist(fNP)) {
301 throw std::runtime_error("TMVA SOFIE LSTM op input tensor " + fNP + " is not found in model.");
302 }
303 fShapeP = model.GetTensorShape(fNP);
304 if (fShapeP.size() != 2 && fShapeP.size() != 4) {
305 throw std::runtime_error("TMVA SOFIE LSTM op input tensor " + fNP + " is not of 2 or 4 dimensions.");
306 }
307 if (fShapeP.size() == 2) {
308 // Broadcasting the weight for peepholes
309 auto original_data = model.GetInitializedTensorData(fNP);
310 size_t num_directions = fShapeW[0];
311 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
312 if (fType == "float") {
313 float *original_p = static_cast<float *>(original_data.get());
314 float *new_p = new float[num_directions * 3 * batch_size * fAttrHiddenSize];
315 for (size_t direction = 0; direction < num_directions; direction++) {
316 for (size_t gate = 0; gate < 3; gate++) {
317 size_t p_offset = direction * 3 * fAttrHiddenSize + gate * fAttrHiddenSize;
318 for (size_t batch = 0; batch < batch_size; batch++) {
319 size_t offset = direction * 3 * batch_size * fAttrHiddenSize +
320 gate * batch_size * fAttrHiddenSize + batch * fAttrHiddenSize;
321 std::copy(original_p + p_offset, original_p + p_offset + fAttrHiddenSize, new_p + offset);
322 }
323 }
324 }
325 std::vector<size_t> new_p_shape = {num_directions, 3, batch_size, fAttrHiddenSize};
326 std::shared_ptr<void> new_p_ptr(new_p, std::default_delete<float[]>());
327 model.UpdateInitializedTensor(fNP, model.GetTensorType(fNP), new_p_shape, new_p_ptr);
328 fShapeP = model.GetTensorShape(fNP);
329 }
330 }
331 }
332 if (!fNY.empty()) {
333 fShapeY = ShapeInference({fShapeX, fShapeW})[0];
334 if (!model.CheckIfTensorAlreadyExist(fNY)) {
335 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
336 }
337 }
338 if (!fNY_h.empty()) {
339 fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
340 if (!model.CheckIfTensorAlreadyExist(fNY_h)) {
341 model.AddIntermediateTensor(fNY_h, model.GetTensorType(fNX), fShapeY_h);
342 }
343 }
344 if (!fNY_c.empty()) {
345 fShapeY_c = ShapeInference({fShapeX, fShapeW})[2];
346 if (!model.CheckIfTensorAlreadyExist(fNY_c)) {
347 model.AddIntermediateTensor(fNY_c, model.GetTensorType(fNX), fShapeY_c);
348 }
349 }
350 // Check the attributes
351 for (auto &activation : fAttrActivations) {
352 if (activation != "Relu" && activation != "Tanh" && activation != "Sigmoid" && activation != "Affine" &&
353 activation != "LeakyRelu" && activation != "ThresholdRelu" && activation != "ScaledTanh" &&
354 activation != "HardSigmoid" && activation != "Elu" && activation != "Softsign" && activation != "Softplus") {
355 throw std::runtime_error("TMVA SOFIE - Activation function " + activation + " not implemented");
356 }
357 }
358 if (fAttrDirection != "forward" && fAttrDirection != "backward" && fAttrDirection != "bidirectional") {
359 throw std::runtime_error("TMVA SOFIE - Invalid LSTM direction fAttrDirection = " + fAttrDirection);
360 }
361 if (4 * fAttrHiddenSize != fShapeW[1]) {
362 throw std::runtime_error("TMVA SOFIE - fAttrHiddenSize must be equal to " + std::to_string(fShapeW[1] / 4));
363 }
364 if (fAttrInputForget > 1) {
365 throw std::runtime_error("TMVA SOFIE - fAttrInputForget = " + std::to_string(fAttrInputForget) +
366 " must be 0 or 1.");
367 }
368 if (fAttrLayout > 1) {
369 throw std::runtime_error("TMVA SOFIE - Layout fAttrLayout = " + std::to_string(fAttrLayout) +
370 " must be 0 (timewise) or 1 (batchwise)");
371 }
372 if (fAttrActivations.empty()) {
373 if (fAttrDirection == "bidirectional") {
374 fAttrActivations = {"Sigmoid", "Tanh", "Tanh", "Sigmoid", "Tanh", "Tanh"};
375 } else {
376 fAttrActivations = {"Sigmoid", "Tanh", "Tanh"};
377 }
378 }
379}
380
381// generate code for Session data members (e.g. internal vectors)
382template <typename T>
384{
385 opName = "op_" + opName;
386 std::stringstream out;
387
388 size_t num_directions = fShapeW[0];
389 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
390 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
391 size_t input_size = fShapeX[2];
392
393 if (fAttrLayout != 0) {
394 out << "std::vector<" << fType << "> fVec_" << opName << "_input = std::vector<" << fType << ">("
395 << seq_length * batch_size * input_size << ");\n";
396 out << "std::vector<" << fType << "> fVec_" << opName << "_initial_hidden_state = std::vector<" << fType << ">("
397 << num_directions * batch_size * fAttrHiddenSize << ");\n";
398 out << "std::vector<" << fType << "> fVec_" << opName << "_initial_cell_state = std::vector<" << fType << ">("
399 << num_directions * batch_size * fAttrHiddenSize << ");\n";
400 }
401 // Set the feedforward
402 size_t ff_size = seq_length * batch_size * fAttrHiddenSize;
403 out << "std::vector<" << fType << "> fVec_" << opName << "_ff_input_gate = std::vector<" << fType << ">(" << ff_size
404 << ");\n";
405 out << "std::vector<" << fType << "> fVec_" << opName << "_ff_output_gate = std::vector<" << fType << ">(" << ff_size
406 << ");\n";
407 out << "std::vector<" << fType << "> fVec_" << opName << "_ff_cell_gate = std::vector<" << fType << ">(" << ff_size
408 << ");\n";
409 if (fAttrInputForget == 0)
410 out << "std::vector<" << fType << "> fVec_" << opName << "_ff_forget_gate = std::vector<" << fType << ">("
411 << ff_size << ");\n";
412 // gate results
413 size_t hs_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
414 out << "std::vector<" << fType << "> fVec_" << opName << "_input_gate = std::vector<" << fType << ">(" << hs_size
415 << ");\n";
416 out << "std::vector<" << fType << "> fVec_" << opName << "_output_gate = std::vector<" << fType << ">(" << hs_size
417 << ");\n";
418 out << "std::vector<" << fType << "> fVec_" << opName << "_cell_gate = std::vector<" << fType << ">(" << hs_size
419 << ");\n";
420 if (fAttrInputForget == 0)
421 out << "std::vector<" << fType << "> fVec_" << opName << "_forget_gate = std::vector<" << fType << ">(" << hs_size
422 << ");\n";
423 // cell state
424 out << "std::vector<" << fType << "> fVec_" << opName << "_cell_state = std::vector<" << fType << ">(" << hs_size
425 << ");\n";
426 out << "std::vector<" << fType << "> fVec_" << opName << "_new_cell_state = std::vector<" << fType << ">(" << hs_size
427 << ");\n";
428 // hiddden state
429 if (fAttrLayout != 0 || fNY.empty()) {
430 out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_state = std::vector<" << fType << ">("
431 << hs_size << ");\n";
432 }
433
434 out << "\n";
435
436 return out.str();
437}
438
439template <typename T>
440auto ROperator_LSTM<T>::Generate(std::string OpName) -> std::string
441{
442 OpName = "op_" + OpName;
443 std::stringstream out;
444
445 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
446 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
447 size_t input_size = fShapeX[2];
448 size_t num_directions = fShapeW[0];
449
450 // set the input
451 if (fAttrLayout == 0) {
452 out << SP << fType << " const *" << OpName << "_input = tensor_" << fNX << ";\n";
453 } else {
454 if (fUseSession)
455 out << SP << fType << " * " << OpName << "_input = this->fVec_" << OpName << "_input.data();\n";
456 else
457 out << SP << fType << " " << OpName << "_input[" << seq_length * batch_size * input_size << "] = {0};\n";
458
459 out << SP << "for(size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
460 out << SP << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
461 out << SP << SP << SP << "for(size_t i = 0; i < " << input_size << "; i++) {\n";
462 out << SP << SP << SP << SP << OpName << "_input[seq * " << batch_size * input_size << " + batch * " << input_size
463 << " + i] = " << "tensor_" << fNX << "[batch * " << seq_length * input_size << " + seq * " << input_size
464 << " + i];\n";
465 out << SP << SP << SP << "}\n";
466 out << SP << SP << "}\n";
467 out << SP << "}\n";
468 }
469
470 // Set the initial hidden state
471 if (!fNInitial_h.empty()) {
472 if (fAttrLayout == 0) {
473 out << SP << fType << " *" << OpName << "_initial_hidden_state = " << " tensor_" << fNInitial_h << ";\n";
474 } else {
475 if (fUseSession)
476 out << SP << fType << " * " << OpName << "_initial_hidden_state = this->fVec_" << OpName
477 << "_initial_hidden_state.data();\n";
478 else
479 out << SP << fType << " " << OpName << "_initial_hidden_state["
480 << num_directions * batch_size * fAttrHiddenSize << "] = {0};\n";
481
482 for (size_t direction = 0; direction < num_directions; direction++) {
483 out << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
484 out << SP << SP << "for(size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
485 out << SP << SP << SP << OpName << "_initial_hidden_state[" << direction * batch_size * fAttrHiddenSize
486 << " + batch * " << fAttrHiddenSize << " + h] = tensor_" << fNInitial_h << "[batch * "
487 << num_directions * fAttrHiddenSize << " + " << direction * fAttrHiddenSize << " + h];\n";
488 out << SP << SP << "}\n";
489 out << SP << "}\n";
490 }
491 }
492 }
493
494 // Set the initial cell state
495 if (!fNInitial_c.empty()) {
496 if (fAttrLayout == 0) {
497 out << SP << fType << " *" << OpName << "_initial_cell_state = " << " tensor_" << fNInitial_c << ";\n";
498 } else {
499 if (fUseSession)
500 out << SP << fType << " * " << OpName << "_initial_cell_state = this->fVec_" << OpName
501 << "_initial_cell_state.data();\n";
502 else
503 out << SP << fType << " " << OpName << "_initial_cell_state["
504 << num_directions * batch_size * fAttrHiddenSize << "] = {0};\n";
505
506 for (size_t direction = 0; direction < num_directions; direction++) {
507 out << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
508 out << SP << SP << "for(size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
509 out << SP << SP << SP << OpName << "_initial_cell_state[" << direction * batch_size * fAttrHiddenSize
510 << " + batch * " << fAttrHiddenSize << " + h] = tensor_" << fNInitial_c << "[batch * "
511 << num_directions * fAttrHiddenSize << " + " << direction * fAttrHiddenSize << " + h];\n";
512 out << SP << SP << "}\n";
513 out << SP << "}\n";
514 }
515 }
516 }
517
518 // Set the feedforward
519 size_t ff_size = seq_length * batch_size * fAttrHiddenSize;
520 if (fUseSession) {
521 out << SP << fType << " * " << OpName << "_ff_input_gate = this->fVec_" << OpName << "_ff_input_gate.data();\n";
522 out << SP << fType << " * " << OpName << "_ff_output_gate = this->fVec_" << OpName << "_ff_output_gate.data();\n";
523 out << SP << fType << " * " << OpName << "_ff_cell_gate = this->fVec_" << OpName << "_ff_cell_gate.data();\n";
524 if (fAttrInputForget == 0) {
525 out << SP << fType << " * " << OpName << "_ff_forget_gate = this->fVec_" << OpName
526 << "_ff_forget_gate.data();\n";
527 }
528 } else {
529 out << SP << fType << " " << OpName << "_ff_input_gate[" << ff_size << "] = {0};\n";
530 out << SP << fType << " " << OpName << "_ff_output_gate[" << ff_size << "] = {0};\n";
531 out << SP << fType << " " << OpName << "_ff_cell_gate[" << ff_size << "] = {0};\n";
532 if (fAttrInputForget == 0) {
533 out << SP << fType << " " << OpName << "_ff_forget_gate[" << ff_size << "] = {0};\n";
534 }
535 }
536 // Set the gates
537 size_t hidden_state_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
538 if (fUseSession) {
539 out << SP << fType << " * " << OpName << "_input_gate = this->fVec_" << OpName << "_input_gate.data();\n";
540 out << SP << fType << " * " << OpName << "_output_gate = this->fVec_" << OpName << "_output_gate.data();\n";
541 out << SP << fType << " * " << OpName << "_cell_gate = this->fVec_" << OpName << "_cell_gate.data();\n";
542 if (fAttrInputForget == 0) {
543 out << SP << fType << " * " << OpName << "_forget_gate = this->fVec_" << OpName << "_forget_gate.data();\n";
544 }
545 } else {
546 out << SP << fType << " " << OpName << "_input_gate[" << hidden_state_size << "] = {0};\n";
547 out << SP << fType << " " << OpName << "_output_gate[" << hidden_state_size << "] = {0};\n";
548 out << SP << fType << " " << OpName << "_cell_gate[" << hidden_state_size << "] = {0};\n";
549 if (fAttrInputForget == 0) {
550 out << SP << fType << " " << OpName << "_forget_gate[" << hidden_state_size << "] = {0};\n";
551 }
552 }
553 // Set the cell state and the new cell state = h(cell state)
554 if (fUseSession) {
555 out << SP << fType << " * " << OpName << "_cell_state = this->fVec_" << OpName << "_cell_state.data();\n";
556 out << SP << fType << " * " << OpName << "_new_cell_state = this->fVec_" << OpName << "_new_cell_state.data();\n";
557 } else {
558 out << SP << fType << " " << OpName << "_cell_state[" << hidden_state_size << "] = {0};\n";
559 out << SP << fType << " " << OpName << "_new_cell_state[" << hidden_state_size << "] = {0};\n";
560 }
561
562 // Set the hidden state
563 if (fAttrLayout == 0 && !fNY.empty()) {
564 out << SP << fType << " *" << OpName << "_hidden_state = tensor_" << fNY << ";\n";
565 } else {
566 if (fUseSession) {
567 out << SP << fType << " * " << OpName << "_hidden_state = this->fVec_" << OpName << "_hidden_state.data();\n";
568 } else {
569 out << SP << fType << " " << OpName << "_hidden_state[" << hidden_state_size << "] = {0};\n";
570 }
571 }
572
573 out << SP << "char " << OpName << "_transA = 'N';\n";
574 out << SP << "char " << OpName << "_transB = 'T';\n";
575 out << SP << "int " << OpName << "_m = " << seq_length * batch_size << ";\n";
576 out << SP << "int " << OpName << "_n = " << fAttrHiddenSize << ";\n";
577 out << SP << "int " << OpName << "_k = " << input_size << ";\n";
578 if (fType == "float") {
579 out << SP << fType << " " << OpName << "_alpha = 1.;\n";
580 out << SP << fType << " " << OpName << "_beta = 0.;\n";
581 }
582 if (!fNB.empty()) {
583 out << SP << "int " << OpName << "_bias_size = " << seq_length * batch_size * fAttrHiddenSize << ";\n";
584 out << SP << "int " << OpName << "_incx = 1;\n";
585 out << SP << "int " << OpName << "_incy = 1;\n";
586 }
587
588 for (size_t direction = 0; direction < num_directions; direction++) {
589 if (direction == 0) {
590 if (fType == "float") {
591 // input_gate = input * weight_i^T
592 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
593 << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_" << fNW << ", &" << OpName
594 << "_k, " << OpName << "_input, &" << OpName << "_k, &" << OpName << "_beta, " << OpName
595 << "_ff_input_gate, &" << OpName << "_n);\n";
596 // output_gate = input * weight_o^T
597 size_t wo_offset = fAttrHiddenSize * input_size;
598 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
599 << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_" << fNW << " + " << wo_offset
600 << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &" << OpName << "_beta, "
601 << OpName << "_ff_output_gate, &" << OpName << "_n);\n";
602 // cell_gate = input * weight_c^T
603 size_t wc_offset = 3 * fAttrHiddenSize * input_size;
604 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
605 << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_" << fNW << " + " << wc_offset
606 << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &" << OpName << "_beta, "
607 << OpName << "_ff_cell_gate, &" << OpName << "_n);\n";
608 }
609 } else {
610 if (fType == "float") {
611 // input_gate = input * weight_i^T
612 size_t wi_offset = 4 * fAttrHiddenSize * input_size;
613 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
614 << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_" << fNW << " + " << wi_offset
615 << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &" << OpName << "_beta, "
616 << OpName << "_ff_input_gate, &" << OpName << "_n);\n";
617 // output_gate = input * weight_o^T
618 size_t wo_offset = 4 * fAttrHiddenSize * input_size + 1 * fAttrHiddenSize * input_size;
619 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
620 << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_" << fNW << " + " << wo_offset
621 << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &" << OpName << "_beta, "
622 << OpName << "_ff_output_gate, &" << OpName << "_n);\n";
623 // cell_gate = input * weight_c^T
624 size_t wc_offset = 4 * fAttrHiddenSize * input_size + 3 * fAttrHiddenSize * input_size;
625 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
626 << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_" << fNW << " + " << wc_offset
627 << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &" << OpName << "_beta, "
628 << OpName << "_ff_cell_gate, &" << OpName << "_n);\n";
629 }
630 }
631 if (fAttrInputForget == 0) {
632 // forget_gate = input * weight_f^T
633 if (direction == 0) {
634 if (fType == "float") {
635 size_t wf_offset = 2 * fAttrHiddenSize * input_size;
636 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
637 << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_" << fNW << " + " << wf_offset
638 << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &" << OpName << "_beta, "
639 << OpName << "_ff_forget_gate, &" << OpName << "_n);\n";
640 }
641 } else {
642 if (fType == "float") {
643 size_t wf_offset = 4 * fAttrHiddenSize * input_size + 2 * fAttrHiddenSize * input_size;
644 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
645 << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_" << fNW << " + " << wf_offset
646 << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &" << OpName << "_beta, "
647 << OpName << "_ff_forget_gate, &" << OpName << "_n);\n";
648 }
649 }
650 }
651
652 // Add the bias
653 if (!fNB.empty()) {
654 if (direction == 0) {
655 if (fType == "float") {
656 // ff_input_gate += bias_i
657 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << ", &"
658 << OpName << "_incx, " << OpName << "_ff_input_gate, &" << OpName << "_incy);\n";
659 // ff_output_gate += bias_o
660 size_t bo_offset = seq_length * batch_size * fAttrHiddenSize;
661 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
662 << bo_offset << ", &" << OpName << "_incx, " << OpName << "_ff_output_gate, &" << OpName
663 << "_incy);\n";
664 // ff_cell_gate += bias_c
665 size_t bc_offset = 3 * seq_length * batch_size * fAttrHiddenSize;
666 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
667 << bc_offset << ", &" << OpName << "_incx, " << OpName << "_ff_cell_gate, &" << OpName
668 << "_incy);\n";
669 }
670 } else {
671 if (fType == "float") {
672 // ff_input_gate += bias_i
673 size_t bi_offset = 4 * seq_length * batch_size * fAttrHiddenSize;
674 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
675 << bi_offset << ", &" << OpName << "_incx, " << OpName << "_ff_input_gate, &" << OpName
676 << "_incy);\n";
677 // ff_output_gate += bias_o
678 size_t bo_offset =
679 4 * seq_length * batch_size * fAttrHiddenSize + seq_length * batch_size * fAttrHiddenSize;
680 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
681 << bo_offset << ", &" << OpName << "_incx, " << OpName << "_ff_output_gate, &" << OpName
682 << "_incy);\n";
683 // ff_cell_gate += bias_c
684 size_t bc_offset = 4 * num_directions * seq_length * batch_size * fAttrHiddenSize +
685 3 * seq_length * batch_size * fAttrHiddenSize;
686 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
687 << bc_offset << ", &" << OpName << "_incx, " << OpName << "_ff_cell_gate, &" << OpName
688 << "_incy);\n";
689 }
690 }
691 if (fAttrInputForget == 0) {
692 // ff_forget_gate += bias_f
693 if (direction == 0) {
694 if (fType == "float") {
695 size_t bo_offset = 2 * seq_length * batch_size * fAttrHiddenSize;
696 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB
697 << " + " << bo_offset << ", &" << OpName << "_incx, " << OpName << "_ff_forget_gate, &" << OpName
698 << "_incy);\n";
699 }
700 } else {
701 if (fType == "float") {
702 size_t bo_offset =
703 4 * seq_length * batch_size * fAttrHiddenSize + 2 * seq_length * batch_size * fAttrHiddenSize;
704 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB
705 << " + " << bo_offset << ", &" << OpName << "_incx, " << OpName << "_ff_forget_gate, &" << OpName
706 << "_incy);\n";
707 }
708 }
709 }
710 }
711
712 // Copy ff_input_gate, ff_output_gate, ff_cell_gate and ff_forget_gate into input_gate, output_gate,
713 // cell_gate and forget_gate
714 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
715 out << SP << SP << "size_t ff_offset = seq * " << batch_size * fAttrHiddenSize << ";\n";
716 if (direction == 0) {
717 out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize << ";\n";
718 } else {
719 out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize << " + "
720 << batch_size * fAttrHiddenSize << ";\n";
721 }
722 size_t ff_seq_size = batch_size * fAttrHiddenSize;
723 out << SP << SP << "std::copy(" << OpName << "_ff_input_gate + ff_offset, " << OpName
724 << "_ff_input_gate + ff_offset + " << ff_seq_size << ", " << OpName << "_input_gate + gate_offset);\n";
725 out << SP << SP << "std::copy(" << OpName << "_ff_output_gate + ff_offset, " << OpName
726 << "_ff_output_gate + ff_offset + " << ff_seq_size << ", " << OpName << "_output_gate + gate_offset);\n";
727 out << SP << SP << "std::copy(" << OpName << "_ff_cell_gate + ff_offset, " << OpName
728 << "_ff_cell_gate + ff_offset + " << ff_seq_size << ", " << OpName << "_cell_gate + gate_offset);\n";
729 if (fAttrInputForget == 0) {
730 out << SP << SP << "std::copy(" << OpName << "_ff_forget_gate + ff_offset, " << OpName
731 << "_ff_forget_gate + ff_offset + " << ff_seq_size << ", " << OpName << "_forget_gate + gate_offset);\n";
732 }
733 out << SP << "}\n";
734
735 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
736 if (fAttrDirection == "backward" || direction == 1) {
737 out << SP << SP << "size_t index = " << seq_length - 1 << " - seq;\n";
738 } else {
739 out << SP << SP << "size_t index = seq;\n";
740 }
741 out << SP << SP << "int m2 = " << batch_size << ";\n";
742 if (direction == 0) {
743 out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize << ";\n";
744 } else {
745 out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize << " + "
746 << batch_size * fAttrHiddenSize << ";\n";
747 }
748 size_t size = batch_size * fAttrHiddenSize;
749 // gate = gate + initial_hidden_state * Recurrence^T
750 out << SP << SP << "if (seq == 0) {\n";
751 if (!fNInitial_h.empty()) {
752 if (direction == 0) {
753 if (fType == "float") {
754 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
755 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &" << OpName
756 << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName << "_alpha, "
757 << OpName << "_input_gate + offset, &" << OpName << "_n);\n";
758 size_t ro_offset = fAttrHiddenSize * fAttrHiddenSize;
759 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
760 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << ro_offset
761 << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
762 << "_alpha, " << OpName << "_output_gate + offset, &" << OpName << "_n);\n";
763 size_t rc_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
764 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
765 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rc_offset
766 << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
767 << "_alpha, " << OpName << "_cell_gate + offset, &" << OpName << "_n);\n";
768 if (fAttrInputForget == 0) {
769 size_t rf_offset = 2 * fAttrHiddenSize * fAttrHiddenSize;
770 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
771 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
772 << rf_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
773 << "_n, &" << OpName << "_alpha, " << OpName << "_forget_gate + offset, &" << OpName << "_n);\n";
774 }
775 }
776 } else { // direction=1
777 if (fType == "float") {
778 size_t ri_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
779 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
780 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << ri_offset
781 << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
782 << "_alpha, " << OpName << "_input_gate + offset, &" << OpName << "_n);\n";
783 size_t ro_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 1 * fAttrHiddenSize * fAttrHiddenSize;
784 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
785 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << ro_offset
786 << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
787 << "_alpha, " << OpName << "_output_gate + offset, &" << OpName << "_n);\n";
788 size_t rc_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 3 * fAttrHiddenSize * fAttrHiddenSize;
789 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
790 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rc_offset
791 << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
792 << "_alpha, " << OpName << "_cell_gate + offset, &" << OpName << "_n);\n";
793 if (fAttrInputForget == 0) {
794 size_t rf_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
795 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
796 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
797 << rf_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
798 << "_n, &" << OpName << "_alpha, " << OpName << "_forget_gate + offset, &" << OpName << "_n);\n";
799 }
800 }
801 }
802 }
803 out << SP << SP << "} else {\n";
804 // gate = gate + previous_hidden_state * Recurrence^T
805 if (direction == 0) {
806 if (fAttrDirection == "backward") {
807 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
808 << num_directions * batch_size * fAttrHiddenSize << ";\n";
809 } else {
810 out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
811 << num_directions * batch_size * fAttrHiddenSize << ";\n";
812 }
813 if (fType == "float") {
814 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
815 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &" << OpName << "_n, "
816 << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &" << OpName << "_alpha, " << OpName
817 << "_input_gate + offset, &" << OpName << "_n);\n";
818 size_t ro_offset = 1 * fAttrHiddenSize * fAttrHiddenSize;
819 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
820 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << ro_offset
821 << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
822 << OpName << "_alpha, " << OpName << "_output_gate + offset, &" << OpName << "_n);\n";
823 size_t rc_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
824 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
825 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rc_offset
826 << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
827 << OpName << "_alpha, " << OpName << "_cell_gate + offset, &" << OpName << "_n);\n";
828 if (fAttrInputForget == 0) {
829 size_t rf_offset = 2 * fAttrHiddenSize * fAttrHiddenSize;
830 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
831 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rf_offset
832 << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
833 << OpName << "_alpha, " << OpName << "_forget_gate + offset, &" << OpName << "_n);\n";
834 }
835 }
836 } else {
837 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
838 << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
839 if (fType == "float") {
840 size_t ri_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
841 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
842 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << ri_offset
843 << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
844 << OpName << "_alpha, " << OpName << "_input_gate + offset, &" << OpName << "_n);\n";
845 size_t ro_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + fAttrHiddenSize * fAttrHiddenSize;
846 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
847 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << ro_offset
848 << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
849 << OpName << "_alpha, " << OpName << "_output_gate + offset, &" << OpName << "_n);\n";
850 size_t rc_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 3 * fAttrHiddenSize * fAttrHiddenSize;
851 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
852 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rc_offset
853 << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
854 << OpName << "_alpha, " << OpName << "_cell_gate + offset, &" << OpName << "_n);\n";
855 if (fAttrInputForget == 0) {
856 size_t rf_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
857 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
858 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rf_offset
859 << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
860 << OpName << "_alpha, " << OpName << "_forget_gate + offset, &" << OpName << "_n);\n";
861 }
862 }
863 }
864 out << SP << SP << "}\n";
865
866 // Clip the elements of the cell gate into the range [-fAttrClip, fAttrClip]
867 if (fAttrClip > .0) {
868 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
869 if (fType == "float") {
870 out << SP << SP << SP << "float x = (" << OpName << "_cell_gate[i] > " << -fAttrClip << ") ? " << OpName
871 << "_cell_gate[i] : " << -fAttrClip << ";\n";
872 }
873 out << SP << SP << SP << OpName << "_cell_gate[i] = (x < " << fAttrClip << ") ? x : " << fAttrClip << ";\n";
874 out << SP << SP << "}\n";
875 }
876 // Apply the activation function to the cell gate, cell_gate = g(cell_gate)
877 if (fAttrActivations[direction * 3 + 1] == "Relu") {
878 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
879 out << SP << SP << SP << "if (" << OpName << "_cell_gate[i] < 0.)\n";
880 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = 0.;\n";
881 out << SP << SP << "}\n";
882 } else if (fAttrActivations[direction * 3 + 1] == "Tanh") {
883 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
884 if (fType == "float") {
885 out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_cell_gate[i]);\n";
886 }
887 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = (1. - ex) / (1. + ex);\n";
888 out << SP << SP << "}\n";
889 } else if (fAttrActivations[direction * 3 + 1] == "Sigmoid") {
890 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
891 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = 1. / (1. + exp(-" << OpName << "_cell_gate[i]));\n";
892 out << SP << SP << "}\n";
893 } else if (fAttrActivations[direction * 3 + 1] == "Affine") {
894 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
895 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = " << fAttrActivationAlpha[direction * 3 + 1] << " * "
896 << OpName << "_cell_gate[i] + " << fAttrActivationBeta[direction * 3 + 1] << ";\n";
897 out << SP << SP << "}\n";
898 } else if (fAttrActivations[direction * 3 + 1] == "ScaledTanh") {
899 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
900 if (fType == "float") {
901 out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 3 + 1] << " * " << OpName
902 << "_cell_gate[i]);\n";
903 }
904 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = " << fAttrActivationAlpha[direction * 3 + 1]
905 << " * (1. - ex) / (1. + ex);\n";
906 out << SP << SP << "}\n";
907 } else if (fAttrActivations[direction * 3 + 1] == "HardSigmoid") {
908 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
909 if (fType == "float") {
910 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 3 + 1] << " * " << OpName
911 << "_cell_gate[i] + " << fAttrActivationBeta[direction * 3 + 1] << ";\n";
912 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
913 }
914 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = (b < 1.) ? b : 1.;\n";
915 out << SP << SP << "}\n";
916 } else if (fAttrActivations[direction * 3 + 1] == "LeakyRelu") {
917 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
918 out << SP << SP << SP << "if (" << OpName << "_cell_gate[i] < 0.)\n";
919 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = " << fAttrActivationAlpha[direction * 3 + 1] << " * "
920 << OpName << "_cell_gate[i];\n";
921 out << SP << SP << "}\n";
922 } else if (fAttrActivations[direction * 3 + 1] == "ThresholdRelu") {
923 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
924 out << SP << SP << SP << "if (" << OpName << "_cell_gate[i] < " << fAttrActivationAlpha[direction * 3 + 1]
925 << ")\n";
926 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = 0.;\n";
927 out << SP << SP << "}";
928 } else if (fAttrActivations[direction * 3 + 1] == "Elu") {
929 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
930 out << SP << SP << SP << "if (" << OpName << "_cell_gate[i] < 0.)\n";
931 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = " << fAttrActivationAlpha[direction * 3 + 1]
932 << " * exp(" << OpName << "_cell_gate[i] - 1.);\n";
933 out << SP << SP << "}\n";
934 } else if (fAttrActivations[direction * 3 + 1] == "Softsign") {
935 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
936 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = " << OpName << "_cell_gate[i] / (1. + abs(" << OpName
937 << "_cell_gate[i]));\n";
938 out << SP << SP << "}\n";
939 } else { // fAttrActivations[direction * 3 + 1] = Softplus
940 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
941 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = log(1. + exp(" << OpName << "_cell_gate[i]));\n";
942 out << SP << SP << "}\n";
943 }
944
945 // Peephole connections for the input gate and the forget gate
946 if (!fNP.empty()) {
947 // gate = 1.0 * gate + previous_cell_state * P^T
948 out << SP << SP << "if (seq == 0) {\n";
949 if (!fNInitial_c.empty()) {
950 if (direction == 0) {
951 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
952 out << SP << SP << SP << SP << OpName << "_input_gate[i + offset] += tensor_" << fNP << "[i] * "
953 << OpName << "_initial_cell_state[i];\n";
954 out << SP << SP << SP << "}\n";
955 if (fAttrInputForget == 0) {
956 size_t pf_offset = batch_size * fAttrHiddenSize;
957 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
958 out << SP << SP << SP << SP << OpName << "_forget_gate[i + offset] += tensor_" << fNP << "[i + "
959 << pf_offset << "] * " << OpName << "_initial_cell_state[i];\n";
960 out << SP << SP << SP << "}\n";
961 }
962 } else {
963 size_t pi_offset = 3 * batch_size * fAttrHiddenSize;
964 size_t initial_c_offset = batch_size * fAttrHiddenSize;
965 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
966 out << SP << SP << SP << SP << OpName << "_input_gate[i + offset] += tensor_" << fNP << "[i + "
967 << pi_offset << "] * " << OpName << "_initial_cell_state[i + " << initial_c_offset << "];\n";
968 out << SP << SP << SP << "}\n";
969 if (fAttrInputForget == 0) {
970 size_t pf_offset = 3 * batch_size * fAttrHiddenSize + batch_size * fAttrHiddenSize;
971 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
972 out << SP << SP << SP << SP << OpName << "_forget_gate[i + offset] += tensor_" << fNP << "[i + "
973 << pf_offset << "] * " << OpName << "_initial_cell_state[i + " << initial_c_offset << "];\n";
974 out << SP << SP << SP << "}\n";
975 }
976 }
977 }
978 out << SP << SP << "} else {\n";
979 if (direction == 0) {
980 if (fAttrDirection == "backward") {
981 out << SP << SP << SP << "size_t c_offset = (index + 1) * "
982 << num_directions * batch_size * fAttrHiddenSize << ";\n";
983 } else {
984 out << SP << SP << SP << "size_t c_offset = (seq - 1) * "
985 << num_directions * batch_size * fAttrHiddenSize << ";\n";
986 }
987 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
988 out << SP << SP << SP << SP << OpName << "_input_gate[i + offset] += tensor_" << fNP << "[i] * " << OpName
989 << "_cell_state[i + c_offset];\n";
990 out << SP << SP << SP << "}\n";
991 if (fAttrInputForget == 0) {
992 size_t pf_offset = batch_size * fAttrHiddenSize;
993 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
994 out << SP << SP << SP << SP << OpName << "_forget_gate[i + offset] += tensor_" << fNP << "[i + "
995 << pf_offset << "] * " << OpName << "_cell_state[i + c_offset];\n";
996 out << SP << SP << SP << "}\n";
997 }
998 } else { // direction=1
999 size_t pi_offset = 3 * batch_size * fAttrHiddenSize;
1000 out << SP << SP << SP << "size_t c_offset = (index + 1) * " << num_directions * batch_size * fAttrHiddenSize
1001 << " + " << batch_size * fAttrHiddenSize << ";\n";
1002 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
1003 out << SP << SP << SP << SP << OpName << "_input_gate[i + offset] += tensor_" << fNP << "[i + " << pi_offset
1004 << "] * " << OpName << "_cell_state[i + c_offset];\n";
1005 out << SP << SP << SP << "}\n";
1006 if (fAttrInputForget == 0) {
1007 size_t pf_offset = 3 * batch_size * fAttrHiddenSize + batch_size * fAttrHiddenSize;
1008 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
1009 out << SP << SP << SP << SP << OpName << "_forget_gate[i + offset] += tensor_" << fNP << "[i + "
1010 << pf_offset << "] * " << OpName << "_cell_state[i + c_offset];\n";
1011 out << SP << SP << SP << "}\n";
1012 }
1013 }
1014 out << SP << SP << "}\n";
1015 }
1016
1017 // Clip the elements of the input gate into the range [-fAttrClip, fAttrClip]
1018 if (fAttrClip > .0) {
1019 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1020 if (fType == "float") {
1021 out << SP << SP << SP << "float x = (" << OpName << "_input_gate[i] > " << -fAttrClip << ") ? " << OpName
1022 << "_input_gate[i] : " << -fAttrClip << ";\n";
1023 }
1024 out << SP << SP << SP << OpName << "_input_gate[i] = (x < " << fAttrClip << ") ? x : " << fAttrClip << ";\n";
1025 out << SP << SP << "}\n";
1026 }
1027 // Apply the activation function to the input gate
1028 if (fAttrActivations[direction * 3] == "Relu") {
1029 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1030 out << SP << SP << SP << "if (" << OpName << "_input_gate[i] < 0.)\n";
1031 out << SP << SP << SP << SP << OpName << "_input_gate[i] = 0.;\n";
1032 out << SP << SP << "}\n";
1033 } else if (fAttrActivations[direction * 3] == "Tanh") {
1034 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1035 if (fType == "float") {
1036 out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_input_gate[i]);\n";
1037 }
1038 out << SP << SP << SP << SP << OpName << "_input_gate[i] = (1. - ex) / (1. + ex);\n";
1039 out << SP << SP << "}\n";
1040 } else if (fAttrActivations[direction * 3] == "Sigmoid") {
1041 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1042 out << SP << SP << SP << SP << OpName << "_input_gate[i] = 1. / (1. + exp(-" << OpName
1043 << "_input_gate[i]));\n";
1044 out << SP << SP << "}\n";
1045 } else if (fAttrActivations[direction * 3] == "Affine") {
1046 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1047 out << SP << SP << SP << SP << OpName << "_input_gate[i] = " << fAttrActivationAlpha[direction * 3] << " * "
1048 << OpName << "_input_gate[i] + " << fAttrActivationBeta[direction * 3] << ";\n";
1049 out << SP << SP << "}\n";
1050 } else if (fAttrActivations[direction * 3] == "ScaledTanh") {
1051 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1052 if (fType == "float") {
1053 out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 3] << " * " << OpName
1054 << "_input_gate[i]);\n";
1055 }
1056 out << SP << SP << SP << SP << OpName << "_input_gate[i] = " << fAttrActivationAlpha[direction * 3]
1057 << " * (1. - ex) / (1. + ex);\n";
1058 out << SP << SP << "}\n";
1059 } else if (fAttrActivations[direction * 3] == "HardSigmoid") {
1060 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1061 if (fType == "float") {
1062 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 3] << " * " << OpName
1063 << "_input_gate[i] + " << fAttrActivationBeta[direction * 3] << ";\n";
1064 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
1065 }
1066 out << SP << SP << SP << SP << OpName << "_input_gate[i] = (b < 1.) ? b : 1.;\n";
1067 out << SP << SP << "}\n";
1068 } else if (fAttrActivations[direction * 3] == "LeakyRelu") {
1069 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1070 out << SP << SP << SP << "if (" << OpName << "_input_gate[i] < 0.)\n";
1071 out << SP << SP << SP << SP << OpName << "_input_gate[i] = " << fAttrActivationAlpha[direction * 3] << " * "
1072 << OpName << "_input_gate[i];\n";
1073 out << SP << SP << "}\n";
1074 } else if (fAttrActivations[direction * 3] == "ThresholdRelu") {
1075 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1076 out << SP << SP << SP << "if (" << OpName << "_input_gate[i] < " << fAttrActivationAlpha[direction * 3]
1077 << ")\n";
1078 out << SP << SP << SP << SP << OpName << "_input_gate[i] = 0.;\n";
1079 out << SP << SP << "}";
1080 } else if (fAttrActivations[direction * 3] == "Elu") {
1081 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1082 out << SP << SP << SP << "if (" << OpName << "_input_gate[i] < 0.)\n";
1083 out << SP << SP << SP << SP << OpName << "_input_gate[i] = " << fAttrActivationAlpha[direction * 3]
1084 << " * exp(" << OpName << "_input_gate[i] - 1.);\n";
1085 out << SP << SP << "}\n";
1086 } else if (fAttrActivations[direction * 3] == "Softsign") {
1087 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1088 out << SP << SP << SP << SP << OpName << "_input_gate[i] = " << OpName << "_input_gate[i] / (1. + abs("
1089 << OpName << "_input_gate[i]));\n";
1090 out << SP << SP << "}\n";
1091 } else { // fAttrActivations[direction * 3] = Softplus
1092 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1093 out << SP << SP << SP << SP << OpName << "_input_gate[i] = log(1. + exp(" << OpName << "_input_gate[i]));\n";
1094 out << SP << SP << "}\n";
1095 }
1096
1097 if (fAttrInputForget == 0) {
1098 // Clip the elements of the forget gate into the range [-fAttrClip, fAttrClip]
1099 if (fAttrClip > .0) {
1100 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1101 if (fType == "float") {
1102 out << SP << SP << SP << "float x = (" << OpName << "_forget_gate[i] > " << -fAttrClip << ") ? "
1103 << OpName << "_forget_gate[i] : " << -fAttrClip << ";\n";
1104 }
1105 out << SP << SP << SP << OpName << "_forget_gate[i] = (x < " << fAttrClip << ") ? x : " << fAttrClip
1106 << ";\n";
1107 out << SP << SP << "}\n";
1108 }
1109 // Apply the activation function to the forget gate, cell_gate = g(cell_gate)
1110 if (fAttrActivations[direction * 3] == "Relu") {
1111 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1112 out << SP << SP << SP << "if (" << OpName << "_forget_gate[i] < 0.)\n";
1113 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = 0.;\n";
1114 out << SP << SP << "}\n";
1115 } else if (fAttrActivations[direction * 3] == "Tanh") {
1116 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1117 if (fType == "float") {
1118 out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_forget_gate[i]);\n";
1119 }
1120 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = (1. - ex) / (1. + ex);\n";
1121 out << SP << SP << "}\n";
1122 } else if (fAttrActivations[direction * 3] == "Sigmoid") {
1123 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1124 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = 1. / (1. + exp(-" << OpName
1125 << "_forget_gate[i]));\n";
1126 out << SP << SP << "}\n";
1127 } else if (fAttrActivations[direction * 3] == "Affine") {
1128 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1129 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = " << fAttrActivationAlpha[direction * 3]
1130 << " * " << OpName << "_forget_gate[i] + " << fAttrActivationBeta[direction * 3] << ";\n";
1131 out << SP << SP << "}\n";
1132 } else if (fAttrActivations[direction * 3] == "ScaledTanh") {
1133 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1134 if (fType == "float") {
1135 out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 3] << " * " << OpName
1136 << "_forget_gate[i]);\n";
1137 }
1138 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = " << fAttrActivationAlpha[direction * 3]
1139 << " * (1. - ex) / (1. + ex);\n";
1140 out << SP << SP << "}\n";
1141 } else if (fAttrActivations[direction * 3] == "HardSigmoid") {
1142 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1143 if (fType == "float") {
1144 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 3] << " * " << OpName
1145 << "_forget_gate[i] + " << fAttrActivationBeta[direction * 3] << ";\n";
1146 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
1147 }
1148 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = (b < 1.) ? b : 1.;\n";
1149 out << SP << SP << "}\n";
1150 } else if (fAttrActivations[direction * 3] == "LeakyRelu") {
1151 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1152 out << SP << SP << SP << "if (" << OpName << "_forget_gate[i] < 0.)\n";
1153 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = " << fAttrActivationAlpha[direction * 3]
1154 << " * " << OpName << "_forget_gate[i];\n";
1155 out << SP << SP << "}\n";
1156 } else if (fAttrActivations[direction * 3] == "ThresholdRelu") {
1157 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1158 out << SP << SP << SP << "if (" << OpName << "_forget_gate[i] < " << fAttrActivationAlpha[direction * 3]
1159 << ")\n";
1160 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = 0.;\n";
1161 out << SP << SP << "}";
1162 } else if (fAttrActivations[direction * 3] == "Elu") {
1163 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1164 out << SP << SP << SP << "if (" << OpName << "_forget_gate[i] < 0.)\n";
1165 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = " << fAttrActivationAlpha[direction * 3]
1166 << " * exp(" << OpName << "_forget_gate[i] - 1.);\n";
1167 out << SP << SP << "}\n";
1168 } else if (fAttrActivations[direction * 3] == "Softsign") {
1169 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1170 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = " << OpName << "_forget_gate[i] / (1. + abs("
1171 << OpName << "_forget_gate[i]));\n";
1172 out << SP << SP << "}\n";
1173 } else { // fAttrActivations[direction * 3] = Softplus
1174 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1175 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = log(1. + exp(" << OpName
1176 << "_forget_gate[i]));\n";
1177 out << SP << SP << "}\n";
1178 }
1179 }
1180
1181 // cell_state = input_gate o cell_gate
1182 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1183 out << SP << SP << SP << OpName << "_cell_state[i] = " << OpName << "_input_gate[i] * " << OpName
1184 << "_cell_gate[i];\n";
1185 out << SP << SP << "}\n";
1186
1187 if (fAttrInputForget == 0) {
1188 out << SP << SP << "if (seq == 0) {\n";
1189 if (!fNInitial_c.empty()) {
1190 // cell_state += forget_gate o initial_cell_state
1191 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
1192 out << SP << SP << SP << SP << OpName << "_cell_state[i + offset] += " << OpName
1193 << "_forget_gate[i + offset] * " << OpName << "_initial_cell_state[i];\n";
1194 out << SP << SP << SP << "}\n";
1195 }
1196 out << SP << SP << "} else {\n";
1197 // cell_state += forget_gate o previous_cell_state
1198 if (direction == 0) {
1199 if (fAttrDirection == "backward") {
1200 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
1201 << num_directions * batch_size * fAttrHiddenSize << ";\n";
1202 } else {
1203 out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
1204 << num_directions * batch_size * fAttrHiddenSize << ";\n";
1205 }
1206 } else { // direction=1
1207 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
1208 << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
1209 }
1210 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
1211 out << SP << SP << SP << SP << OpName << "_cell_state[i + offset] += " << OpName
1212 << "_forget_gate[i + offset] * " << OpName << "_cell_state[i + previous_offset];\n";
1213 out << SP << SP << SP << "}\n";
1214 out << SP << SP << "}\n";
1215 }
1216
1217 if (!fNP.empty()) {
1218 // Peephole connection for the output gate
1219 if (direction == 0) {
1220 size_t p_offset = 2 * batch_size * fAttrHiddenSize;
1221 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
1222 out << SP << SP << SP << SP << OpName << "_output_gate[i + offset] += tensor_" << fNP << "[i + " << p_offset
1223 << "] * " << OpName << "_cell_state[i + offset];\n";
1224 out << SP << SP << SP << "}\n";
1225 } else { // direction=1
1226 size_t p_offset = 3 * batch_size * fAttrHiddenSize + 2 * batch_size * fAttrHiddenSize;
1227 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
1228 out << SP << SP << SP << SP << OpName << "_output_gate[i + offset] += tensor_" << fNP << "[i + " << p_offset
1229 << "] * " << OpName << "_cell_state[i + offset];\n";
1230 out << SP << SP << SP << "}\n";
1231 }
1232 }
1233
1234 // Clip the elements of the output gate into the range [-fAttrClip, fAttrClip]
1235 if (fAttrClip > .0) {
1236 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1237 if (fType == "float") {
1238 out << SP << SP << SP << "float x = (" << OpName << "_output_gate[i] > " << -fAttrClip << ") ? " << OpName
1239 << "_output_gate[i] : " << -fAttrClip << ";\n";
1240 }
1241 out << SP << SP << SP << OpName << "_output_gate[i] = (x < " << fAttrClip << ") ? x : " << fAttrClip << ";\n";
1242 out << SP << SP << "}\n";
1243 }
1244 // Apply the activation function to the output gate
1245 if (fAttrActivations[direction * 3] == "Relu") {
1246 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1247 out << SP << SP << SP << "if (" << OpName << "_output_gate[i] < 0.)\n";
1248 out << SP << SP << SP << SP << OpName << "_output_gate[i] = 0.;\n";
1249 out << SP << SP << "}\n";
1250 } else if (fAttrActivations[direction * 3] == "Tanh") {
1251 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1252 if (fType == "float") {
1253 out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_output_gate[i]);\n";
1254 }
1255 out << SP << SP << SP << SP << OpName << "_output_gate[i] = (1. - ex) / (1. + ex);\n";
1256 out << SP << SP << "}\n";
1257 } else if (fAttrActivations[direction * 3] == "Sigmoid") {
1258 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1259 out << SP << SP << SP << SP << OpName << "_output_gate[i] = 1. / (1. + exp(-" << OpName
1260 << "_output_gate[i]));\n";
1261 out << SP << SP << "}\n";
1262 } else if (fAttrActivations[direction * 3] == "Affine") {
1263 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1264 out << SP << SP << SP << SP << OpName << "_output_gate[i] = " << fAttrActivationAlpha[direction * 3] << " * "
1265 << OpName << "_output_gate[i] + " << fAttrActivationBeta[direction * 3] << ";\n";
1266 out << SP << SP << "}\n";
1267 } else if (fAttrActivations[direction * 3] == "ScaledTanh") {
1268 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1269 if (fType == "float") {
1270 out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 3] << " * " << OpName
1271 << "_output_gate[i]);\n";
1272 }
1273 out << SP << SP << SP << SP << OpName << "_output_gate[i] = " << fAttrActivationAlpha[direction * 3]
1274 << " * (1. - ex) / (1. + ex);\n";
1275 out << SP << SP << "}\n";
1276 } else if (fAttrActivations[direction * 3] == "HardSigmoid") {
1277 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1278 if (fType == "float") {
1279 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 3] << " * " << OpName
1280 << "_output_gate[i] + " << fAttrActivationBeta[direction * 3] << ";\n";
1281 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
1282 }
1283 out << SP << SP << SP << SP << OpName << "_output_gate[i] = (b < 1.) ? b : 1.;\n";
1284 out << SP << SP << "}\n";
1285 } else if (fAttrActivations[direction * 3] == "LeakyRelu") {
1286 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1287 out << SP << SP << SP << "if (" << OpName << "_output_gate[i] < 0.)\n";
1288 out << SP << SP << SP << SP << OpName << "_output_gate[i] = " << fAttrActivationAlpha[direction * 3] << " * "
1289 << OpName << "_output_gate[i];\n";
1290 out << SP << SP << "}\n";
1291 } else if (fAttrActivations[direction * 3] == "ThresholdRelu") {
1292 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1293 out << SP << SP << SP << "if (" << OpName << "_output_gate[i] < " << fAttrActivationAlpha[direction * 3]
1294 << ")\n";
1295 out << SP << SP << SP << SP << OpName << "_output_gate[i] = 0.;\n";
1296 out << SP << SP << "}";
1297 } else if (fAttrActivations[direction * 3] == "Elu") {
1298 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1299 out << SP << SP << SP << "if (" << OpName << "_output_gate[i] < 0.)\n";
1300 out << SP << SP << SP << SP << OpName << "_output_gate[i] = " << fAttrActivationAlpha[direction * 3]
1301 << " * exp(" << OpName << "_output_gate[i] - 1.);\n";
1302 out << SP << SP << "}\n";
1303 } else if (fAttrActivations[direction * 3] == "Softsign") {
1304 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1305 out << SP << SP << SP << SP << OpName << "_output_gate[i] = " << OpName << "_output_gate[i] / (1. + abs("
1306 << OpName << "_output_gate[i]));\n";
1307 out << SP << SP << "}\n";
1308 } else { // fAttrActivations[direction * 3] = Softplus
1309 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1310 out << SP << SP << SP << SP << OpName << "_output_gate[i] = log(1. + exp(" << OpName << "_output_gate[i]));\n";
1311 out << SP << SP << "}\n";
1312 }
1313
1314 // copy cell_state into new_cell_state
1315 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName << "_cell_state + offset + "
1316 << size << ", " << OpName << "_new_cell_state + offset);\n";
1317 // Clip the elements of the new_cell_state into the range [-fAttrClip, fAttrClip]
1318 if (fAttrClip > .0) {
1319 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1320 if (fType == "float") {
1321 out << SP << SP << SP << "float x = (" << OpName << "_new_cell_state[i] > " << -fAttrClip << ") ? "
1322 << OpName << "_new_cell_state[i] : " << -fAttrClip << ";\n";
1323 }
1324 out << SP << SP << SP << OpName << "_new_cell_state[i] = (x < " << fAttrClip << ") ? x : " << fAttrClip
1325 << ";\n";
1326 out << SP << SP << "}\n";
1327 }
1328 // Apply the activation function to the new cell state
1329 if (fAttrActivations[direction * 3 + 2] == "Relu") {
1330 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1331 out << SP << SP << SP << "if (" << OpName << "_new_cell_state[i] < 0.)\n";
1332 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = 0.;\n";
1333 out << SP << SP << "}\n";
1334 } else if (fAttrActivations[direction * 3 + 2] == "Tanh") {
1335 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1336 if (fType == "float") {
1337 out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_new_cell_state[i]);\n";
1338 }
1339 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = (1. - ex) / (1. + ex);\n";
1340 out << SP << SP << "}\n";
1341 } else if (fAttrActivations[direction * 3 + 2] == "Sigmoid") {
1342 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1343 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = 1. / (1. + exp(-" << OpName
1344 << "_new_cell_state[i]));\n";
1345 out << SP << SP << "}\n";
1346 } else if (fAttrActivations[direction * 3 + 2] == "Affine") {
1347 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1348 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = " << fAttrActivationAlpha[direction * 3 + 2]
1349 << " * " << OpName << "_new_cell_state[i] + " << fAttrActivationBeta[direction * 3 + 2] << ";\n";
1350 out << SP << SP << "}\n";
1351 } else if (fAttrActivations[direction * 3 + 2] == "ScaledTanh") {
1352 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1353 if (fType == "float") {
1354 out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 3 + 2] << " * " << OpName
1355 << "_new_cell_state[i]);\n";
1356 }
1357 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = " << fAttrActivationAlpha[direction * 3 + 2]
1358 << " * (1. - ex) / (1. + ex);\n";
1359 out << SP << SP << "}\n";
1360 } else if (fAttrActivations[direction * 3 + 2] == "HardSigmoid") {
1361 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1362 if (fType == "float") {
1363 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 3 + 2] << " * " << OpName
1364 << "_new_cell_state[i] + " << fAttrActivationBeta[direction * 3 + 2] << ";\n";
1365 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
1366 }
1367 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = (b < 1.) ? b : 1.;\n";
1368 out << SP << SP << "}\n";
1369 } else if (fAttrActivations[direction * 3 + 2] == "LeakyRelu") {
1370 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1371 out << SP << SP << SP << "if (" << OpName << "_new_cell_state[i] < 0.)\n";
1372 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = " << fAttrActivationAlpha[direction * 3 + 2]
1373 << " * " << OpName << "_new_cell_state[i];\n";
1374 out << SP << SP << "}\n";
1375 } else if (fAttrActivations[direction * 3 + 2] == "ThresholdRelu") {
1376 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1377 out << SP << SP << SP << "if (" << OpName << "_new_cell_state[i] < " << fAttrActivationAlpha[direction * 3 + 2]
1378 << ")\n";
1379 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = 0.;\n";
1380 out << SP << SP << "}";
1381 } else if (fAttrActivations[direction * 3 + 2] == "Elu") {
1382 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1383 out << SP << SP << SP << "if (" << OpName << "_new_cell_state[i] < 0.)\n";
1384 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = " << fAttrActivationAlpha[direction * 3 + 2]
1385 << " * exp(" << OpName << "_new_cell_state[i] - 1.);\n";
1386 out << SP << SP << "}\n";
1387 } else if (fAttrActivations[direction * 3 + 2] == "Softsign") {
1388 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1389 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = " << OpName << "_new_cell_state[i] / (1. + abs("
1390 << OpName << "_new_cell_state[i]));\n";
1391 out << SP << SP << "}\n";
1392 } else { // fAttrActivations[direction * 3 + 2] = Softplus
1393 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1394 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = log(1. + exp(" << OpName
1395 << "_new_cell_state[i]));\n";
1396 out << SP << SP << "}\n";
1397 }
1398
1399 // hidden_state = output_gate o new_cell_state
1400 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1401 out << SP << SP << SP << OpName << "_hidden_state[i] = " << OpName << "_output_gate[i] * " << OpName
1402 << "_new_cell_state[i];\n";
1403 out << SP << SP << "}\n";
1404 out << SP << "}\n";
1405 }
1406
1407 // Padding the hidden state for LSTM with different sequence lengths
1408 if (!fNSequence_lens.empty()) {
1409 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
1410 out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1411 out << SP << SP << SP << "if (seq >= tensor_" << fNSequence_lens << "[batch]) {\n";
1412 for (size_t direction = 0; direction < num_directions; direction++) {
1413 out << SP << SP << SP << SP << SP << "for (size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
1414 out << SP << SP << SP << SP << SP << SP << "size_t idx = seq * "
1415 << num_directions * batch_size * fAttrHiddenSize + direction * batch_size * fAttrHiddenSize
1416 << " + batch * " << fAttrHiddenSize << " + h;\n";
1417 out << SP << SP << SP << SP << SP << SP << OpName << "_cell_state[idx] = 0.;\n";
1418 out << SP << SP << SP << SP << SP << SP << OpName << "_hidden_state[idx] = 0.;\n";
1419 out << SP << SP << SP << SP << SP << "}\n";
1420 }
1421 out << SP << SP << SP << "}\n";
1422 out << SP << SP << "}\n";
1423 out << SP << "}\n";
1424 }
1425
1426 // Copy the hidden state into y and y_h and copy cell_state into y_c
1427 if (fAttrLayout == 0) {
1428 if (!fNY_h.empty()) {
1429 // Copy hidden_state into Y_h
1430 if (fNSequence_lens.empty()) {
1431 size_t y_h_size = batch_size * fAttrHiddenSize;
1432 if (fAttrDirection == "backward") {
1433 out << SP << "std::copy(" << OpName << "_hidden_state, " << OpName << "_hidden_state + " << y_h_size
1434 << ", tensor_" << fNY_h << ");\n";
1435 } else {
1436 size_t offset = (seq_length - 1) * num_directions * batch_size * fAttrHiddenSize;
1437 out << SP << "std::copy(" << OpName << "_hidden_state + " << offset << ", " << OpName
1438 << "_hidden_state + " << offset << " + " << y_h_size << ", tensor_" << fNY_h << ");\n";
1439 }
1440 if (num_directions == 2) {
1441 out << SP << "std::copy(" << OpName << "_hidden_state + " << y_h_size << ", " << OpName
1442 << "_hidden_state + " << 2 * y_h_size << ", tensor_" << fNY_h << " + " << y_h_size << ");\n";
1443 }
1444 } else { // LSTM with different sequence lengths
1445 if (fAttrDirection == "backward") {
1446 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1447 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
1448 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1449 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + offset);\n";
1450 out << SP << "}\n";
1451 } else {
1452 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1453 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
1454 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
1455 << " + batch * " << fAttrHiddenSize << ";\n";
1456 out << SP << SP << "size_t y_h_offset = batch * " << fAttrHiddenSize << ";\n";
1457 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1458 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + y_h_offset);\n";
1459 out << SP << "}\n";
1460 }
1461 if (num_directions == 2) {
1462 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1463 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
1464 << ";\n";
1465 out << SP << SP << "size_t y_h_offset = " << batch_size * fAttrHiddenSize << " + batch * "
1466 << fAttrHiddenSize << ";\n";
1467 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1468 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + y_h_offset);\n";
1469 out << SP << "}\n";
1470 }
1471 }
1472 }
1473 if (!fNY_c.empty()) {
1474 // Copy cell_state into Y_c
1475 if (fNSequence_lens.empty()) {
1476 size_t y_h_size = batch_size * fAttrHiddenSize;
1477 if (fAttrDirection == "backward") {
1478 out << SP << "std::copy(" << OpName << "_cell_state, " << OpName << "_hidden_state + " << y_h_size
1479 << ", tensor_" << fNY_c << ");\n";
1480 } else {
1481 size_t offset = (seq_length - 1) * num_directions * batch_size * fAttrHiddenSize;
1482 out << SP << "std::copy(" << OpName << "_cell_state + " << offset << ", " << OpName << "_cell_state + "
1483 << offset << " + " << y_h_size << ", tensor_" << fNY_c << ");\n";
1484 }
1485 if (num_directions == 2) {
1486 out << SP << "std::copy(" << OpName << "_cell_state + " << y_h_size << ", " << OpName << "_cell_state + "
1487 << 2 * y_h_size << ", tensor_" << fNY_c << " + " << y_h_size << ");\n";
1488 }
1489 } else { // LSTM with different sequence lengths
1490 if (fAttrDirection == "backward") {
1491 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1492 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
1493 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName
1494 << "_cell_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_c << " + offset);\n";
1495 out << SP << "}\n";
1496 } else {
1497 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1498 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
1499 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
1500 << " + batch * " << fAttrHiddenSize << ";\n";
1501 out << SP << SP << "size_t y_h_offset = batch * " << fAttrHiddenSize << ";\n";
1502 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName
1503 << "_cell_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_c << " + y_h_offset);\n";
1504 out << SP << "}\n";
1505 }
1506 if (num_directions == 2) {
1507 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1508 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
1509 << ";\n";
1510 out << SP << SP << "size_t y_h_offset = " << batch_size * fAttrHiddenSize << " + batch * "
1511 << fAttrHiddenSize << ";\n";
1512 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName
1513 << "_cell_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_c << " + y_h_offset);\n";
1514 out << SP << "}\n";
1515 }
1516 }
1517 }
1518 } else { // fAttrLayout=1
1519 if (!fNY.empty()) {
1520 // Copy hidden_state into Y
1521 for (size_t direction = 0; direction < num_directions; direction++) {
1522 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
1523 out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1524 out << SP << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize << " + "
1525 << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize << ";\n";
1526 out << SP << SP << SP << "size_t y_offset = batch * " << seq_length * num_directions * fAttrHiddenSize
1527 << " + seq * " << num_directions * fAttrHiddenSize << " + " << direction * fAttrHiddenSize << ";\n";
1528 out << SP << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1529 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY << " + y_offset);\n";
1530 out << SP << SP << "}\n";
1531 out << SP << "}\n";
1532 }
1533 }
1534 if (!fNY_h.empty()) {
1535 // Copy the hidden_state into Y_h
1536 if (fAttrDirection == "backward") {
1537 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1538 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
1539 out << SP << SP << "size_t y_h_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
1540 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1541 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + y_h_offset);\n";
1542 out << SP << "}\n";
1543 } else {
1544 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1545 if (fNSequence_lens.empty()) {
1546 out << SP << SP << "size_t seq = " << seq_length - 1 << ";\n";
1547 } else {
1548 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
1549 }
1550 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
1551 << " + batch * " << fAttrHiddenSize << ";\n";
1552 out << SP << SP << "size_t y_h_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
1553 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1554 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + y_h_offset);\n";
1555 out << SP << "}\n";
1556 }
1557 if (num_directions == 2) {
1558 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1559 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
1560 << ";\n";
1561 out << SP << SP << "size_t y_h_offset = batch * " << num_directions * fAttrHiddenSize << " + "
1562 << fAttrHiddenSize << ";\n";
1563 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1564 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + y_h_offset);\n";
1565 out << SP << "}\n";
1566 }
1567 }
1568
1569 if (!fNY_c.empty()) {
1570 // copy the cell_state into Y_c
1571 if (fAttrDirection == "backward") {
1572 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1573 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
1574 out << SP << SP << "size_t y_h_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
1575 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName << "_cell_state + offset + "
1576 << fAttrHiddenSize << ", tensor_" << fNY_c << " + y_h_offset);\n";
1577 out << SP << "}\n";
1578 } else {
1579 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1580 if (fNSequence_lens.empty()) {
1581 out << SP << SP << "size_t seq = " << seq_length - 1 << ";\n";
1582 } else {
1583 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
1584 }
1585 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
1586 << " + batch * " << fAttrHiddenSize << ";\n";
1587 out << SP << SP << "size_t y_h_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
1588 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName << "_cell_state + offset + "
1589 << fAttrHiddenSize << ", tensor_" << fNY_c << " + y_h_offset);\n";
1590 out << SP << "}\n";
1591 }
1592 if (num_directions == 2) {
1593 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1594 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
1595 << ";\n";
1596 out << SP << SP << "size_t y_h_offset = batch * " << num_directions * fAttrHiddenSize << " + "
1597 << fAttrHiddenSize << ";\n";
1598 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName << "_cell_state + offset + "
1599 << fAttrHiddenSize << ", tensor_" << fNY_c << " + y_h_offset);\n";
1600 out << SP << "}\n";
1601 }
1602 }
1603 }
1604
1605 return out.str();
1606}
1607
1608} // namespace TMVA::Experimental::SOFIE
1609
1610#endif
#define h(i)
Definition RSha256.hxx:106
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
const_iterator begin() const
const_iterator end() const
Long Short-Term Memory operator.
std::string GenerateSessionMembersCode(std::string opName) override
Generate the code for the Session internal data vectors.
std::vector< size_t > fShapeR
Shape of the recurrence.
std::vector< size_t > fShapeInitial_c
Shape of the initial value of the cell states.
std::string fNY_c
Name of the last sequence of the cell states.
ROperator_LSTM(std::vector< float > activation_alpha, std::vector< float > activation_beta, std::vector< std::string > activations, float clip, std::string direction, size_t hidden_size, size_t input_forget, size_t layout, std::string nameX, std::string nameW, std::string nameR, std::string nameB, std::string nameSequence_lens, std::string nameInitial_h, std::string nameInitial_c, std::string nameP, std::string nameY, std::string nameY_h, std::string nameY_c)
Constructor of ROperator_LSTM from the attributes.
std::string fNR
Name of the recurrence.
std::vector< size_t > fShapeY_h
Shape of the last sequence of the output.
std::vector< float > fAttrActivationAlpha
Scaling values used by some activation functions.
std::vector< size_t > fShapeInitial_h
Shape of the initial value of the hidden states.
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input) override
Infers the shape of the output tensors.
size_t fAttrHiddenSize
Number of hidden units (the ONNX LSTM hidden_size attribute).
std::vector< std::string > GetBlasRoutines() override
Returns the blas routines needed to compile the generated code.
std::string fNInitial_c
Name of the initial value of the cell states.
std::vector< size_t > fShapeB
Shape of the bias.
std::string fNW
Name of the weights.
std::vector< size_t > fShapeY
Shape of the output.
std::vector< size_t > fShapeP
Shape of the peepholes.
std::string fType
Type of the tensors.
std::string fAttrDirection
Direction of processing.
std::vector< float > fAttrActivationBeta
Scaling values used by some activation functions.
std::vector< size_t > fShapeX
Shape of the input.
std::vector< size_t > fShapeW
Shape of the weights.
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input) override
Infers the type of the output tensors.
std::string fNSequence_lens
Name of length of the sequences.
std::string fNY_h
Name of the last sequence of the output.
std::string fNY
Name of the output.
void Initialize(RModel &) override
Initialize the model.
std::vector< std::string > fAttrActivations
Activation functions.
std::string Generate(std::string OpName) override
Generate the inference code.
std::vector< size_t > fShapeSequence_lens
Shape of the length of the sequences.
std::vector< size_t > fShapeY_c
Shape of the last sequence of the cell states.
std::string fNInitial_h
Name of the initial value of the hidden states.
ROperator_LSTM()
Default constructor of ROperator_LSTM.
std::vector< std::string_view > fInputTensorNames
Definition ROperator.hxx:49
std::vector< std::string_view > fOutputTensorNames
Definition ROperator.hxx:50
static uint64_t sum(uint64_t i)
Definition Factory.cxx:2338