186 size_t num_directions = input[1][0];
187 size_t hidden_size = input[1][1] / 4;
189 size_t seq_length = input[0][0];
190 size_t batch_size = input[0][1];
191 std::vector<std::vector<size_t>>
ret({{seq_length, num_directions, batch_size, hidden_size},
192 {num_directions, batch_size, hidden_size},
193 {num_directions, batch_size, hidden_size}});
196 size_t batch_size = input[0][0];
197 size_t seq_length = input[0][1];
198 std::vector<std::vector<size_t>>
ret({{batch_size, seq_length, num_directions, hidden_size},
199 {batch_size, num_directions, hidden_size},
200 {batch_size, num_directions, hidden_size}});
210 if (!model.CheckIfTensorAlreadyExist(
fNX)) {
211 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
fNX +
" is not found in model.");
215 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
fNX +
" is not of 3 dimensions.");
217 if (!model.CheckIfTensorAlreadyExist(
fNW)) {
218 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
fNW +
" is not found in model.");
222 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
fNW +
" is not of 3 dimensions.");
224 if (!model.CheckIfTensorAlreadyExist(
fNR)) {
225 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
fNR +
" is not found in model.");
229 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
fNR +
" is not of 3 dimensions.");
232 if (!model.CheckIfTensorAlreadyExist(
fNB)) {
233 throw std::runtime_error(
"TMVA SOFIE LSTM op input tensor " +
fNB +
" is not found in model.");
237 throw std::runtime_error(
"TMVA SOFIE LSTM op input tensor " +
fNB +
" is not of 2 or 5 dimensions.");
241 auto original_data = model.GetInitializedTensorData(
fNB);
242 size_t num_directions =
fShapeW[0];
245 if (
fType ==
"float") {
246 float *original_bias =
static_cast<float *
>(original_data.get());
247 float *new_bias =
new float[4 * num_directions * seq_length * batch_size *
fAttrHiddenSize];
248 for (
size_t gate = 0; gate < 4; gate++) {
250 for (
size_t direction = 0; direction < num_directions; direction++) {
255 for (
size_t seq = 0; seq < seq_length; seq++) {
256 for (
size_t batch = 0; batch < batch_size; batch++) {
257 size_t bias_offset = gate * num_directions * seq_length * batch_size *
fAttrHiddenSize +
260 std::copy(
sum.begin(),
sum.end(), new_bias + bias_offset);
265 std::vector<size_t> new_bias_shape = {4, num_directions, seq_length, batch_size,
fAttrHiddenSize};
266 std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<
float[]>());
267 model.UpdateInitializedTensor(
fNB, model.GetTensorType(
fNB), new_bias_shape, new_bias_ptr);
// The trailing literal must start with a space; without it the tensor name runs straight into "is not found".
274 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
fNSequence_lens +
" is not found in model.");
278 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
fNSequence_lens +
" is not of 1 dimension.");
282 if (!model.CheckIfTensorAlreadyExist(
fNInitial_h)) {
283 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
fNInitial_h +
" is not found in model.");
287 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
fNInitial_h +
" is not of 3 dimensions.");
291 if (!model.CheckIfTensorAlreadyExist(
fNInitial_c)) {
292 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
fNInitial_c +
" is not found in model.");
296 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
fNInitial_c +
" is not of 3 dimensions.");
300 if (!model.CheckIfTensorAlreadyExist(
fNP)) {
301 throw std::runtime_error(
"TMVA SOFIE LSTM op input tensor " +
fNP +
" is not found in model.");
305 throw std::runtime_error(
"TMVA SOFIE LSTM op input tensor " +
fNP +
" is not of 2 or 4 dimensions.");
309 auto original_data = model.GetInitializedTensorData(
fNP);
310 size_t num_directions =
fShapeW[0];
312 if (
fType ==
"float") {
313 float *original_p =
static_cast<float *
>(original_data.get());
314 float *new_p =
new float[num_directions * 3 * batch_size *
fAttrHiddenSize];
315 for (
size_t direction = 0; direction < num_directions; direction++) {
316 for (
size_t gate = 0; gate < 3; gate++) {
318 for (
size_t batch = 0; batch < batch_size; batch++) {
321 std::copy(original_p + p_offset, original_p + p_offset +
fAttrHiddenSize, new_p + offset);
325 std::vector<size_t> new_p_shape = {num_directions, 3, batch_size,
fAttrHiddenSize};
326 std::shared_ptr<void> new_p_ptr(new_p, std::default_delete<
float[]>());
327 model.UpdateInitializedTensor(
fNP, model.GetTensorType(
fNP), new_p_shape, new_p_ptr);
334 if (!model.CheckIfTensorAlreadyExist(
fNY)) {
335 model.AddIntermediateTensor(
fNY, model.GetTensorType(
fNX),
fShapeY);
338 if (!
fNY_h.empty()) {
340 if (!model.CheckIfTensorAlreadyExist(
fNY_h)) {
344 if (!
fNY_c.empty()) {
346 if (!model.CheckIfTensorAlreadyExist(
fNY_c)) {
352 if (activation !=
"Relu" && activation !=
"Tanh" && activation !=
"Sigmoid" && activation !=
"Affine" &&
353 activation !=
"LeakyRelu" && activation !=
"ThresholdRelu" && activation !=
"ScaledTanh" &&
354 activation !=
"HardSigmoid" && activation !=
"Elu" && activation !=
"Softsign" && activation !=
"Softplus") {
355 throw std::runtime_error(
"TMVA SOFIE - Activation function " + activation +
" not implemented");
359 throw std::runtime_error(
"TMVA SOFIE - Invalid LSTM direction fAttrDirection = " +
fAttrDirection);
362 throw std::runtime_error(
"TMVA SOFIE - fAttrHiddenSize must be equal to " + std::to_string(
fShapeW[1] / 4));
365 throw std::runtime_error(
"TMVA SOFIE - fAttrInputForget = " + std::to_string(
fAttrInputForget) +
369 throw std::runtime_error(
"TMVA SOFIE - Layout fAttrLayout = " + std::to_string(
fAttrLayout) +
370 " must be 0 (timewise) or 1 (batchwise)");
459 OpName =
"op_" + OpName;
460 std::stringstream out;
464 size_t input_size =
fShapeX[2];
465 size_t num_directions =
fShapeW[0];
469 out <<
SP <<
fType <<
" const *" << OpName <<
"_input = tensor_" <<
fNX <<
";\n";
472 out <<
SP <<
fType <<
" * " << OpName <<
"_input = this->fVec_" << OpName <<
"_input;\n";
474 out <<
SP <<
fType <<
" " << OpName <<
"_input[" << seq_length * batch_size * input_size <<
"] = {0};\n";
476 out <<
SP <<
"for(size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
477 out <<
SP <<
SP <<
"for(size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
478 out <<
SP <<
SP <<
SP <<
"for(size_t i = 0; i < " << input_size <<
"; i++) {\n";
479 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input[seq * " << batch_size * input_size <<
" + batch * " << input_size
480 <<
" + i] = " <<
"tensor_" <<
fNX <<
"[batch * " << seq_length * input_size <<
" + seq * " << input_size
482 out <<
SP <<
SP <<
SP <<
"}\n";
483 out <<
SP <<
SP <<
"}\n";
490 out <<
SP <<
fType <<
" const*" << OpName <<
"_initial_hidden_state = " <<
" tensor_" <<
fNInitial_h <<
";\n";
493 out <<
SP <<
fType <<
" const* " << OpName <<
"_initial_hidden_state = this->fVec_" << OpName
494 <<
"_initial_hidden_state;\n";
496 out <<
SP <<
fType <<
" " << OpName <<
"_initial_hidden_state["
499 for (
size_t direction = 0; direction < num_directions; direction++) {
500 out <<
SP <<
"for(size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
502 out <<
SP <<
SP <<
SP << OpName <<
"_initial_hidden_state[" << direction * batch_size *
fAttrHiddenSize
505 out <<
SP <<
SP <<
"}\n";
514 out <<
SP <<
fType <<
" const*" << OpName <<
"_initial_cell_state = " <<
" tensor_" <<
fNInitial_c <<
";\n";
517 out <<
SP <<
fType <<
" const* " << OpName <<
"_initial_cell_state = this->fVec_" << OpName
518 <<
"_initial_cell_state;\n";
520 out <<
SP <<
fType <<
" " << OpName <<
"_initial_cell_state["
523 for (
size_t direction = 0; direction < num_directions; direction++) {
524 out <<
SP <<
"for(size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
526 out <<
SP <<
SP <<
SP << OpName <<
"_initial_cell_state[" << direction * batch_size *
fAttrHiddenSize
529 out <<
SP <<
SP <<
"}\n";
538 out <<
SP <<
fType <<
" * " << OpName <<
"_ff_input_gate = this->fVec_" << OpName <<
"_ff_input_gate;\n";
539 out <<
SP <<
fType <<
" * " << OpName <<
"_ff_output_gate = this->fVec_" << OpName <<
"_ff_output_gate;\n";
540 out <<
SP <<
fType <<
" * " << OpName <<
"_ff_cell_gate = this->fVec_" << OpName <<
"_ff_cell_gate;\n";
542 out <<
SP <<
fType <<
" * " << OpName <<
"_ff_forget_gate = this->fVec_" << OpName
543 <<
"_ff_forget_gate;\n";
546 out <<
SP <<
fType <<
" " << OpName <<
"_ff_input_gate[" << ff_size <<
"] = {0};\n";
547 out <<
SP <<
fType <<
" " << OpName <<
"_ff_output_gate[" << ff_size <<
"] = {0};\n";
548 out <<
SP <<
fType <<
" " << OpName <<
"_ff_cell_gate[" << ff_size <<
"] = {0};\n";
550 out <<
SP <<
fType <<
" " << OpName <<
"_ff_forget_gate[" << ff_size <<
"] = {0};\n";
554 size_t hidden_state_size = seq_length * num_directions * batch_size *
fAttrHiddenSize;
556 out <<
SP <<
fType <<
" * " << OpName <<
"_input_gate = this->fVec_" << OpName <<
"_input_gate;\n";
557 out <<
SP <<
fType <<
" * " << OpName <<
"_output_gate = this->fVec_" << OpName <<
"_output_gate;\n";
558 out <<
SP <<
fType <<
" * " << OpName <<
"_cell_gate = this->fVec_" << OpName <<
"_cell_gate;\n";
560 out <<
SP <<
fType <<
" * " << OpName <<
"_forget_gate = this->fVec_" << OpName <<
"_forget_gate;\n";
563 out <<
SP <<
fType <<
" " << OpName <<
"_input_gate[" << hidden_state_size <<
"] = {0};\n";
564 out <<
SP <<
fType <<
" " << OpName <<
"_output_gate[" << hidden_state_size <<
"] = {0};\n";
565 out <<
SP <<
fType <<
" " << OpName <<
"_cell_gate[" << hidden_state_size <<
"] = {0};\n";
567 out <<
SP <<
fType <<
" " << OpName <<
"_forget_gate[" << hidden_state_size <<
"] = {0};\n";
572 out <<
SP <<
fType <<
" * " << OpName <<
"_cell_state = this->fVec_" << OpName <<
"_cell_state;\n";
573 out <<
SP <<
fType <<
" * " << OpName <<
"_new_cell_state = this->fVec_" << OpName <<
"_new_cell_state;\n";
575 out <<
SP <<
fType <<
" " << OpName <<
"_cell_state[" << hidden_state_size <<
"] = {0};\n";
576 out <<
SP <<
fType <<
" " << OpName <<
"_new_cell_state[" << hidden_state_size <<
"] = {0};\n";
581 out <<
SP <<
fType <<
" *" << OpName <<
"_hidden_state = tensor_" <<
fNY <<
";\n";
584 out <<
SP <<
fType <<
" * " << OpName <<
"_hidden_state = this->fVec_" << OpName <<
"_hidden_state;\n";
586 out <<
SP <<
fType <<
" " << OpName <<
"_hidden_state[" << hidden_state_size <<
"] = {0};\n";
590 out <<
SP <<
"char " << OpName <<
"_transA = 'N';\n";
591 out <<
SP <<
"char " << OpName <<
"_transB = 'T';\n";
592 out <<
SP <<
"int " << OpName <<
"_m = " << seq_length * batch_size <<
";\n";
594 out <<
SP <<
"int " << OpName <<
"_k = " << input_size <<
";\n";
595 if (
fType ==
"float") {
596 out <<
SP <<
fType <<
" " << OpName <<
"_alpha = 1.;\n";
597 out <<
SP <<
fType <<
" " << OpName <<
"_beta = 0.;\n";
600 out <<
SP <<
"int " << OpName <<
"_bias_size = " << seq_length * batch_size *
fAttrHiddenSize <<
";\n";
601 out <<
SP <<
"int " << OpName <<
"_incx = 1;\n";
602 out <<
SP <<
"int " << OpName <<
"_incy = 1;\n";
605 auto emit_sgemm = [&](
const std::string &out_name,
size_t offset) -> std::string {
606 std::stringstream ss;
607 ss <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName <<
"_n, &" << OpName
608 <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_" <<
fNW;
611 ss <<
" + " << offset;
613 ss <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &" << OpName <<
"_k, &" << OpName <<
"_beta, " << OpName
614 <<
"_" << out_name <<
", &" << OpName <<
"_n);\n";
618 for (
size_t direction = 0; direction < num_directions; direction++) {
619 if (direction == 0) {
620 if (
fType ==
"float") {
622 out <<
SP << emit_sgemm(
"ff_input_gate", 0);
625 out <<
SP << emit_sgemm(
"ff_output_gate", wo_offset);
628 out <<
SP << emit_sgemm(
"ff_cell_gate", wc_offset);
631 if (
fType ==
"float") {
636 out <<
SP << emit_sgemm(
"ff_output_gate", wo_offset);
639 out <<
SP << emit_sgemm(
"ff_cell_gate", wc_offset);
644 if (direction == 0) {
645 if (
fType ==
"float") {
647 out <<
SP << emit_sgemm(
"ff_forget_gate", wf_offset);
650 if (
fType ==
"float") {
652 out <<
SP << emit_sgemm(
"ff_forget_gate", wf_offset);
659 if (direction == 0) {
660 if (
fType ==
"float") {
662 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB <<
", &"
663 << OpName <<
"_incx, " << OpName <<
"_ff_input_gate, &" << OpName <<
"_incy);\n";
666 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB <<
" + "
667 << bo_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_ff_output_gate, &" << OpName
671 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB <<
" + "
672 << bc_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_ff_cell_gate, &" << OpName
676 if (
fType ==
"float") {
679 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB <<
" + "
680 << bi_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_ff_input_gate, &" << OpName
685 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB <<
" + "
686 << bo_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_ff_output_gate, &" << OpName
689 size_t bc_offset = 4 * num_directions * seq_length * batch_size *
fAttrHiddenSize +
691 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB <<
" + "
692 << bc_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_ff_cell_gate, &" << OpName
698 if (direction == 0) {
699 if (
fType ==
"float") {
701 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB
702 <<
" + " << bo_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_ff_forget_gate, &" << OpName
706 if (
fType ==
"float") {
709 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB
710 <<
" + " << bo_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_ff_forget_gate, &" << OpName
719 out <<
SP <<
"for (size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
721 if (direction == 0) {
722 out <<
SP <<
SP <<
"size_t gate_offset = seq * " << num_directions * batch_size *
fAttrHiddenSize <<
";\n";
724 out <<
SP <<
SP <<
"size_t gate_offset = seq * " << num_directions * batch_size *
fAttrHiddenSize <<
" + "
728 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_ff_input_gate + ff_offset, " << OpName
729 <<
"_ff_input_gate + ff_offset + " << ff_seq_size <<
", " << OpName <<
"_input_gate + gate_offset);\n";
730 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_ff_output_gate + ff_offset, " << OpName
731 <<
"_ff_output_gate + ff_offset + " << ff_seq_size <<
", " << OpName <<
"_output_gate + gate_offset);\n";
732 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_ff_cell_gate + ff_offset, " << OpName
733 <<
"_ff_cell_gate + ff_offset + " << ff_seq_size <<
", " << OpName <<
"_cell_gate + gate_offset);\n";
735 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_ff_forget_gate + ff_offset, " << OpName
736 <<
"_ff_forget_gate + ff_offset + " << ff_seq_size <<
", " << OpName <<
"_forget_gate + gate_offset);\n";
740 out <<
SP <<
"for (size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
742 out <<
SP <<
SP <<
"size_t index = " << seq_length - 1 <<
" - seq;\n";
744 out <<
SP <<
SP <<
"size_t index = seq;\n";
746 out <<
SP <<
SP <<
"int m2 = " << batch_size <<
";\n";
747 if (direction == 0) {
748 out <<
SP <<
SP <<
"size_t offset = index * " << num_directions * batch_size *
fAttrHiddenSize <<
";\n";
750 out <<
SP <<
SP <<
"size_t offset = index * " << num_directions * batch_size *
fAttrHiddenSize <<
" + "
755 out <<
SP <<
SP <<
"if (seq == 0) {\n";
757 if (direction == 0) {
758 if (
fType ==
"float") {
759 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
760 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
", &" << OpName
761 <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName <<
"_n, &" << OpName <<
"_alpha, "
762 << OpName <<
"_input_gate + offset, &" << OpName <<
"_n);\n";
764 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
765 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << ro_offset
766 <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName <<
"_n, &" << OpName
767 <<
"_alpha, " << OpName <<
"_output_gate + offset, &" << OpName <<
"_n);\n";
769 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
770 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << rc_offset
771 <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName <<
"_n, &" << OpName
772 <<
"_alpha, " << OpName <<
"_cell_gate + offset, &" << OpName <<
"_n);\n";
775 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
776 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
777 << rf_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName
778 <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_forget_gate + offset, &" << OpName <<
"_n);\n";
782 if (
fType ==
"float") {
784 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
785 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << ri_offset
786 <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName <<
"_n, &" << OpName
787 <<
"_alpha, " << OpName <<
"_input_gate + offset, &" << OpName <<
"_n);\n";
789 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
790 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << ro_offset
791 <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName <<
"_n, &" << OpName
792 <<
"_alpha, " << OpName <<
"_output_gate + offset, &" << OpName <<
"_n);\n";
794 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
795 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << rc_offset
796 <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName <<
"_n, &" << OpName
797 <<
"_alpha, " << OpName <<
"_cell_gate + offset, &" << OpName <<
"_n);\n";
800 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
801 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
802 << rf_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName
803 <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_forget_gate + offset, &" << OpName <<
"_n);\n";
808 out <<
SP <<
SP <<
"} else {\n";
810 if (direction == 0) {
812 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
815 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (seq - 1) * "
818 if (
fType ==
"float") {
819 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
820 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
", &" << OpName <<
"_n, "
821 << OpName <<
"_hidden_state + previous_offset, &" << OpName <<
"_n, &" << OpName <<
"_alpha, " << OpName
822 <<
"_input_gate + offset, &" << OpName <<
"_n);\n";
824 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
825 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << ro_offset
826 <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &" << OpName <<
"_n, &"
827 << OpName <<
"_alpha, " << OpName <<
"_output_gate + offset, &" << OpName <<
"_n);\n";
829 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
830 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << rc_offset
831 <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &" << OpName <<
"_n, &"
832 << OpName <<
"_alpha, " << OpName <<
"_cell_gate + offset, &" << OpName <<
"_n);\n";
835 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
836 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << rf_offset
837 <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &" << OpName <<
"_n, &"
838 << OpName <<
"_alpha, " << OpName <<
"_forget_gate + offset, &" << OpName <<
"_n);\n";
842 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
844 if (
fType ==
"float") {
846 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
847 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << ri_offset
848 <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &" << OpName <<
"_n, &"
849 << OpName <<
"_alpha, " << OpName <<
"_input_gate + offset, &" << OpName <<
"_n);\n";
851 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
852 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << ro_offset
853 <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &" << OpName <<
"_n, &"
854 << OpName <<
"_alpha, " << OpName <<
"_output_gate + offset, &" << OpName <<
"_n);\n";
856 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
857 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << rc_offset
858 <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &" << OpName <<
"_n, &"
859 << OpName <<
"_alpha, " << OpName <<
"_cell_gate + offset, &" << OpName <<
"_n);\n";
862 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
863 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << rf_offset
864 <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &" << OpName <<
"_n, &"
865 << OpName <<
"_alpha, " << OpName <<
"_forget_gate + offset, &" << OpName <<
"_n);\n";
869 out <<
SP <<
SP <<
"}\n";
873 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
874 if (
fType ==
"float") {
875 out <<
SP <<
SP <<
SP <<
"float x = (" << OpName <<
"_cell_gate[i] > " << -
fAttrClip <<
") ? " << OpName
876 <<
"_cell_gate[i] : " << -
fAttrClip <<
";\n";
879 out <<
SP <<
SP <<
"}\n";
883 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
884 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_cell_gate[i] < 0.)\n";
885 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = 0.;\n";
886 out <<
SP <<
SP <<
"}\n";
888 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
889 if (
fType ==
"float") {
890 out <<
SP <<
SP <<
SP <<
"float ex = exp(-2 * " << OpName <<
"_cell_gate[i]);\n";
892 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = (1. - ex) / (1. + ex);\n";
893 out <<
SP <<
SP <<
"}\n";
895 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
896 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = 1. / (1. + exp(-" << OpName <<
"_cell_gate[i]));\n";
897 out <<
SP <<
SP <<
"}\n";
899 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
902 out <<
SP <<
SP <<
"}\n";
904 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
905 if (
fType ==
"float") {
907 <<
"_cell_gate[i]);\n";
910 <<
" * (1. - ex) / (1. + ex);\n";
911 out <<
SP <<
SP <<
"}\n";
913 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
914 if (
fType ==
"float") {
917 out <<
SP <<
SP <<
SP <<
"float b = (a > 0.) ? a : 0.;\n";
919 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = (b < 1.) ? b : 1.;\n";
920 out <<
SP <<
SP <<
"}\n";
922 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
923 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_cell_gate[i] < 0.)\n";
925 << OpName <<
"_cell_gate[i];\n";
926 out <<
SP <<
SP <<
"}\n";
928 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
931 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = 0.;\n";
932 out <<
SP <<
SP <<
"}";
934 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
935 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_cell_gate[i] < 0.)\n";
937 <<
" * exp(" << OpName <<
"_cell_gate[i] - 1.);\n";
938 out <<
SP <<
SP <<
"}\n";
940 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
941 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = " << OpName <<
"_cell_gate[i] / (1. + abs(" << OpName
942 <<
"_cell_gate[i]));\n";
943 out <<
SP <<
SP <<
"}\n";
945 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
946 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = log(1. + exp(" << OpName <<
"_cell_gate[i]));\n";
947 out <<
SP <<
SP <<
"}\n";
953 out <<
SP <<
SP <<
"if (seq == 0) {\n";
955 if (direction == 0) {
956 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
957 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i + offset] += tensor_" <<
fNP <<
"[i] * "
958 << OpName <<
"_initial_cell_state[i];\n";
959 out <<
SP <<
SP <<
SP <<
"}\n";
962 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
963 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i + offset] += tensor_" <<
fNP <<
"[i + "
964 << pf_offset <<
"] * " << OpName <<
"_initial_cell_state[i];\n";
965 out <<
SP <<
SP <<
SP <<
"}\n";
970 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
971 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i + offset] += tensor_" <<
fNP <<
"[i + "
972 << pi_offset <<
"] * " << OpName <<
"_initial_cell_state[i + " << initial_c_offset <<
"];\n";
973 out <<
SP <<
SP <<
SP <<
"}\n";
976 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
977 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i + offset] += tensor_" <<
fNP <<
"[i + "
978 << pf_offset <<
"] * " << OpName <<
"_initial_cell_state[i + " << initial_c_offset <<
"];\n";
979 out <<
SP <<
SP <<
SP <<
"}\n";
983 out <<
SP <<
SP <<
"} else {\n";
984 if (direction == 0) {
986 out <<
SP <<
SP <<
SP <<
"size_t c_offset = (index + 1) * "
989 out <<
SP <<
SP <<
SP <<
"size_t c_offset = (seq - 1) * "
992 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
993 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i + offset] += tensor_" <<
fNP <<
"[i] * " << OpName
994 <<
"_cell_state[i + c_offset];\n";
995 out <<
SP <<
SP <<
SP <<
"}\n";
998 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
999 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i + offset] += tensor_" <<
fNP <<
"[i + "
1000 << pf_offset <<
"] * " << OpName <<
"_cell_state[i + c_offset];\n";
1001 out <<
SP <<
SP <<
SP <<
"}\n";
1005 out <<
SP <<
SP <<
SP <<
"size_t c_offset = (index + 1) * " << num_directions * batch_size *
fAttrHiddenSize
1007 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
1008 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i + offset] += tensor_" <<
fNP <<
"[i + " << pi_offset
1009 <<
"] * " << OpName <<
"_cell_state[i + c_offset];\n";
1010 out <<
SP <<
SP <<
SP <<
"}\n";
1013 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
1014 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i + offset] += tensor_" <<
fNP <<
"[i + "
1015 << pf_offset <<
"] * " << OpName <<
"_cell_state[i + c_offset];\n";
1016 out <<
SP <<
SP <<
SP <<
"}\n";
1019 out <<
SP <<
SP <<
"}\n";
1024 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1025 if (
fType ==
"float") {
1026 out <<
SP <<
SP <<
SP <<
"float x = (" << OpName <<
"_input_gate[i] > " << -
fAttrClip <<
") ? " << OpName
1027 <<
"_input_gate[i] : " << -
fAttrClip <<
";\n";
1030 out <<
SP <<
SP <<
"}\n";
1034 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1035 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_input_gate[i] < 0.)\n";
1036 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = 0.;\n";
1037 out <<
SP <<
SP <<
"}\n";
1039 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1040 if (
fType ==
"float") {
1041 out <<
SP <<
SP <<
SP <<
"float ex = exp(-2 * " << OpName <<
"_input_gate[i]);\n";
1043 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = (1. - ex) / (1. + ex);\n";
1044 out <<
SP <<
SP <<
"}\n";
1046 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1047 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = 1. / (1. + exp(-" << OpName
1048 <<
"_input_gate[i]));\n";
1049 out <<
SP <<
SP <<
"}\n";
1051 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1054 out <<
SP <<
SP <<
"}\n";
1056 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1057 if (
fType ==
"float") {
1059 <<
"_input_gate[i]);\n";
1062 <<
" * (1. - ex) / (1. + ex);\n";
1063 out <<
SP <<
SP <<
"}\n";
1065 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1066 if (
fType ==
"float") {
1069 out <<
SP <<
SP <<
SP <<
"float b = (a > 0.) ? a : 0.;\n";
1071 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = (b < 1.) ? b : 1.;\n";
1072 out <<
SP <<
SP <<
"}\n";
1074 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1075 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_input_gate[i] < 0.)\n";
1077 << OpName <<
"_input_gate[i];\n";
1078 out <<
SP <<
SP <<
"}\n";
1080 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1083 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = 0.;\n";
1084 out <<
SP <<
SP <<
"}";
1086 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1087 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_input_gate[i] < 0.)\n";
1089 <<
" * exp(" << OpName <<
"_input_gate[i] - 1.);\n";
1090 out <<
SP <<
SP <<
"}\n";
1092 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1093 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = " << OpName <<
"_input_gate[i] / (1. + abs("
1094 << OpName <<
"_input_gate[i]));\n";
1095 out <<
SP <<
SP <<
"}\n";
1097 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1098 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = log(1. + exp(" << OpName <<
"_input_gate[i]));\n";
1099 out <<
SP <<
SP <<
"}\n";
1105 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1106 if (
fType ==
"float") {
1107 out <<
SP <<
SP <<
SP <<
"float x = (" << OpName <<
"_forget_gate[i] > " << -
fAttrClip <<
") ? "
1108 << OpName <<
"_forget_gate[i] : " << -
fAttrClip <<
";\n";
1112 out <<
SP <<
SP <<
"}\n";
1116 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1117 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_forget_gate[i] < 0.)\n";
1118 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = 0.;\n";
1119 out <<
SP <<
SP <<
"}\n";
1121 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1122 if (
fType ==
"float") {
1123 out <<
SP <<
SP <<
SP <<
"float ex = exp(-2 * " << OpName <<
"_forget_gate[i]);\n";
1125 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = (1. - ex) / (1. + ex);\n";
1126 out <<
SP <<
SP <<
"}\n";
1128 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1129 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = 1. / (1. + exp(-" << OpName
1130 <<
"_forget_gate[i]));\n";
1131 out <<
SP <<
SP <<
"}\n";
1133 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1135 <<
" * " << OpName <<
"_forget_gate[i] + " <<
fAttrActivationBeta[direction * 3] <<
";\n";
1136 out <<
SP <<
SP <<
"}\n";
1138 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1139 if (
fType ==
"float") {
1141 <<
"_forget_gate[i]);\n";
1144 <<
" * (1. - ex) / (1. + ex);\n";
1145 out <<
SP <<
SP <<
"}\n";
1147 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1148 if (
fType ==
"float") {
1151 out <<
SP <<
SP <<
SP <<
"float b = (a > 0.) ? a : 0.;\n";
1153 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = (b < 1.) ? b : 1.;\n";
1154 out <<
SP <<
SP <<
"}\n";
1156 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1157 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_forget_gate[i] < 0.)\n";
1159 <<
" * " << OpName <<
"_forget_gate[i];\n";
1160 out <<
SP <<
SP <<
"}\n";
1162 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1165 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = 0.;\n";
1166 out <<
SP <<
SP <<
"}";
1168 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1169 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_forget_gate[i] < 0.)\n";
1171 <<
" * exp(" << OpName <<
"_forget_gate[i] - 1.);\n";
1172 out <<
SP <<
SP <<
"}\n";
1174 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1175 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = " << OpName <<
"_forget_gate[i] / (1. + abs("
1176 << OpName <<
"_forget_gate[i]));\n";
1177 out <<
SP <<
SP <<
"}\n";
1179 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1180 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = log(1. + exp(" << OpName
1181 <<
"_forget_gate[i]));\n";
1182 out <<
SP <<
SP <<
"}\n";
1187 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1188 out <<
SP <<
SP <<
SP << OpName <<
"_cell_state[i] = " << OpName <<
"_input_gate[i] * " << OpName
1189 <<
"_cell_gate[i];\n";
1190 out <<
SP <<
SP <<
"}\n";
1193 out <<
SP <<
SP <<
"if (seq == 0) {\n";
1196 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
1197 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_state[i + offset] += " << OpName
1198 <<
"_forget_gate[i + offset] * " << OpName <<
"_initial_cell_state[i];\n";
1199 out <<
SP <<
SP <<
SP <<
"}\n";
1201 out <<
SP <<
SP <<
"} else {\n";
1203 if (direction == 0) {
1205 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
1208 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (seq - 1) * "
1212 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
1215 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
1216 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_state[i + offset] += " << OpName
1217 <<
"_forget_gate[i + offset] * " << OpName <<
"_cell_state[i + previous_offset];\n";
1218 out <<
SP <<
SP <<
SP <<
"}\n";
1219 out <<
SP <<
SP <<
"}\n";
1224 if (direction == 0) {
1226 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
1227 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i + offset] += tensor_" <<
fNP <<
"[i + " << p_offset
1228 <<
"] * " << OpName <<
"_cell_state[i + offset];\n";
1229 out <<
SP <<
SP <<
SP <<
"}\n";
1232 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
1233 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i + offset] += tensor_" <<
fNP <<
"[i + " << p_offset
1234 <<
"] * " << OpName <<
"_cell_state[i + offset];\n";
1235 out <<
SP <<
SP <<
SP <<
"}\n";
1241 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1242 if (
fType ==
"float") {
1243 out <<
SP <<
SP <<
SP <<
"float x = (" << OpName <<
"_output_gate[i] > " << -
fAttrClip <<
") ? " << OpName
1244 <<
"_output_gate[i] : " << -
fAttrClip <<
";\n";
1247 out <<
SP <<
SP <<
"}\n";
1251 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1252 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_output_gate[i] < 0.)\n";
1253 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = 0.;\n";
1254 out <<
SP <<
SP <<
"}\n";
1256 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1257 if (
fType ==
"float") {
1258 out <<
SP <<
SP <<
SP <<
"float ex = exp(-2 * " << OpName <<
"_output_gate[i]);\n";
1260 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = (1. - ex) / (1. + ex);\n";
1261 out <<
SP <<
SP <<
"}\n";
1263 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1264 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = 1. / (1. + exp(-" << OpName
1265 <<
"_output_gate[i]));\n";
1266 out <<
SP <<
SP <<
"}\n";
1268 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1271 out <<
SP <<
SP <<
"}\n";
1273 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1274 if (
fType ==
"float") {
1276 <<
"_output_gate[i]);\n";
1279 <<
" * (1. - ex) / (1. + ex);\n";
1280 out <<
SP <<
SP <<
"}\n";
1282 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1283 if (
fType ==
"float") {
1286 out <<
SP <<
SP <<
SP <<
"float b = (a > 0.) ? a : 0.;\n";
1288 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = (b < 1.) ? b : 1.;\n";
1289 out <<
SP <<
SP <<
"}\n";
1291 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1292 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_output_gate[i] < 0.)\n";
1294 << OpName <<
"_output_gate[i];\n";
1295 out <<
SP <<
SP <<
"}\n";
1297 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1300 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = 0.;\n";
1301 out <<
SP <<
SP <<
"}";
1303 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1304 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_output_gate[i] < 0.)\n";
1306 <<
" * exp(" << OpName <<
"_output_gate[i] - 1.);\n";
1307 out <<
SP <<
SP <<
"}\n";
1309 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1310 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = " << OpName <<
"_output_gate[i] / (1. + abs("
1311 << OpName <<
"_output_gate[i]));\n";
1312 out <<
SP <<
SP <<
"}\n";
1314 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1315 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = log(1. + exp(" << OpName <<
"_output_gate[i]));\n";
1316 out <<
SP <<
SP <<
"}\n";
1320 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_cell_state + offset, " << OpName <<
"_cell_state + offset + "
1321 <<
size <<
", " << OpName <<
"_new_cell_state + offset);\n";
1324 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1325 if (
fType ==
"float") {
1326 out <<
SP <<
SP <<
SP <<
"float x = (" << OpName <<
"_new_cell_state[i] > " << -
fAttrClip <<
") ? "
1327 << OpName <<
"_new_cell_state[i] : " << -
fAttrClip <<
";\n";
1331 out <<
SP <<
SP <<
"}\n";
1335 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1336 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_new_cell_state[i] < 0.)\n";
1337 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = 0.;\n";
1338 out <<
SP <<
SP <<
"}\n";
1340 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1341 if (
fType ==
"float") {
1342 out <<
SP <<
SP <<
SP <<
"float ex = exp(-2 * " << OpName <<
"_new_cell_state[i]);\n";
1344 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = (1. - ex) / (1. + ex);\n";
1345 out <<
SP <<
SP <<
"}\n";
1347 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1348 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = 1. / (1. + exp(-" << OpName
1349 <<
"_new_cell_state[i]));\n";
1350 out <<
SP <<
SP <<
"}\n";
1352 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1354 <<
" * " << OpName <<
"_new_cell_state[i] + " <<
fAttrActivationBeta[direction * 3 + 2] <<
";\n";
1355 out <<
SP <<
SP <<
"}\n";
1357 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1358 if (
fType ==
"float") {
1360 <<
"_new_cell_state[i]);\n";
1363 <<
" * (1. - ex) / (1. + ex);\n";
1364 out <<
SP <<
SP <<
"}\n";
1366 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1367 if (
fType ==
"float") {
1370 out <<
SP <<
SP <<
SP <<
"float b = (a > 0.) ? a : 0.;\n";
1372 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = (b < 1.) ? b : 1.;\n";
1373 out <<
SP <<
SP <<
"}\n";
1375 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1376 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_new_cell_state[i] < 0.)\n";
1378 <<
" * " << OpName <<
"_new_cell_state[i];\n";
1379 out <<
SP <<
SP <<
"}\n";
1381 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1384 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = 0.;\n";
1385 out <<
SP <<
SP <<
"}";
1387 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1388 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_new_cell_state[i] < 0.)\n";
1390 <<
" * exp(" << OpName <<
"_new_cell_state[i] - 1.);\n";
1391 out <<
SP <<
SP <<
"}\n";
1393 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1394 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = " << OpName <<
"_new_cell_state[i] / (1. + abs("
1395 << OpName <<
"_new_cell_state[i]));\n";
1396 out <<
SP <<
SP <<
"}\n";
1398 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1399 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = log(1. + exp(" << OpName
1400 <<
"_new_cell_state[i]));\n";
1401 out <<
SP <<
SP <<
"}\n";
1405 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1406 out <<
SP <<
SP <<
SP << OpName <<
"_hidden_state[i] = " << OpName <<
"_output_gate[i] * " << OpName
1407 <<
"_new_cell_state[i];\n";
1408 out <<
SP <<
SP <<
"}\n";
1414 out <<
SP <<
"for (size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
1415 out <<
SP <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1417 for (
size_t direction = 0; direction < num_directions; direction++) {
1419 out <<
SP <<
SP <<
SP <<
SP <<
SP <<
SP <<
"size_t idx = seq * "
1422 out <<
SP <<
SP <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_state[idx] = 0.;\n";
1423 out <<
SP <<
SP <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_state[idx] = 0.;\n";
1426 out <<
SP <<
SP <<
SP <<
"}\n";
1427 out <<
SP <<
SP <<
"}\n";
1433 if (!
fNY_h.empty()) {
1438 out <<
SP <<
"std::copy(" << OpName <<
"_hidden_state, " << OpName <<
"_hidden_state + " << y_h_size
1439 <<
", tensor_" <<
fNY_h <<
");\n";
1441 size_t offset = (seq_length - 1) * num_directions * batch_size *
fAttrHiddenSize;
1442 out <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + " << offset <<
", " << OpName
1443 <<
"_hidden_state + " << offset <<
" + " << y_h_size <<
", tensor_" <<
fNY_h <<
");\n";
1445 if (num_directions == 2) {
1446 out <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + " << y_h_size <<
", " << OpName
1447 <<
"_hidden_state + " << 2 * y_h_size <<
", tensor_" <<
fNY_h <<
" + " << y_h_size <<
");\n";
1451 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1453 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1457 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1459 out <<
SP <<
SP <<
"size_t offset = seq * " << num_directions * batch_size *
fAttrHiddenSize
1462 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1463 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + y_h_offset);\n";
1466 if (num_directions == 2) {
1467 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1470 out <<
SP <<
SP <<
"size_t y_h_offset = " << batch_size *
fAttrHiddenSize <<
" + batch * "
1472 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1473 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + y_h_offset);\n";
1478 if (!
fNY_c.empty()) {
1483 out <<
SP <<
"std::copy(" << OpName <<
"_cell_state, " << OpName <<
"_hidden_state + " << y_h_size
1484 <<
", tensor_" <<
fNY_c <<
");\n";
1486 size_t offset = (seq_length - 1) * num_directions * batch_size *
fAttrHiddenSize;
1487 out <<
SP <<
"std::copy(" << OpName <<
"_cell_state + " << offset <<
", " << OpName <<
"_cell_state + "
1488 << offset <<
" + " << y_h_size <<
", tensor_" <<
fNY_c <<
");\n";
1490 if (num_directions == 2) {
1491 out <<
SP <<
"std::copy(" << OpName <<
"_cell_state + " << y_h_size <<
", " << OpName <<
"_cell_state + "
1492 << 2 * y_h_size <<
", tensor_" <<
fNY_c <<
" + " << y_h_size <<
");\n";
1496 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1498 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_cell_state + offset, " << OpName
1502 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1504 out <<
SP <<
SP <<
"size_t offset = seq * " << num_directions * batch_size *
fAttrHiddenSize
1507 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_cell_state + offset, " << OpName
1508 <<
"_cell_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_c <<
" + y_h_offset);\n";
1511 if (num_directions == 2) {
1512 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1515 out <<
SP <<
SP <<
"size_t y_h_offset = " << batch_size *
fAttrHiddenSize <<
" + batch * "
1517 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_cell_state + offset, " << OpName
1518 <<
"_cell_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_c <<
" + y_h_offset);\n";
1526 for (
size_t direction = 0; direction < num_directions; direction++) {
1527 out <<
SP <<
"for (size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
1528 out <<
SP <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1529 out <<
SP <<
SP <<
SP <<
"size_t offset = seq * " << num_directions * batch_size *
fAttrHiddenSize <<
" + "
1531 out <<
SP <<
SP <<
SP <<
"size_t y_offset = batch * " << seq_length * num_directions *
fAttrHiddenSize
1533 out <<
SP <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1534 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY <<
" + y_offset);\n";
1535 out <<
SP <<
SP <<
"}\n";
1539 if (!
fNY_h.empty()) {
1542 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1544 out <<
SP <<
SP <<
"size_t y_h_offset = batch * " << num_directions *
fAttrHiddenSize <<
";\n";
1545 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1546 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + y_h_offset);\n";
1549 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1551 out <<
SP <<
SP <<
"size_t seq = " << seq_length - 1 <<
";\n";
1555 out <<
SP <<
SP <<
"size_t offset = seq * " << num_directions * batch_size *
fAttrHiddenSize
1557 out <<
SP <<
SP <<
"size_t y_h_offset = batch * " << num_directions *
fAttrHiddenSize <<
";\n";
1558 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1559 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + y_h_offset);\n";
1562 if (num_directions == 2) {
1563 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1566 out <<
SP <<
SP <<
"size_t y_h_offset = batch * " << num_directions *
fAttrHiddenSize <<
" + "
1568 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1569 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + y_h_offset);\n";
1574 if (!
fNY_c.empty()) {
1577 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1579 out <<
SP <<
SP <<
"size_t y_h_offset = batch * " << num_directions *
fAttrHiddenSize <<
";\n";
1580 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_cell_state + offset, " << OpName <<
"_cell_state + offset + "
1584 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1586 out <<
SP <<
SP <<
"size_t seq = " << seq_length - 1 <<
";\n";
1590 out <<
SP <<
SP <<
"size_t offset = seq * " << num_directions * batch_size *
fAttrHiddenSize
1592 out <<
SP <<
SP <<
"size_t y_h_offset = batch * " << num_directions *
fAttrHiddenSize <<
";\n";
1593 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_cell_state + offset, " << OpName <<
"_cell_state + offset + "
1597 if (num_directions == 2) {
1598 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1601 out <<
SP <<
SP <<
"size_t y_h_offset = batch * " << num_directions *
fAttrHiddenSize <<
" + "
1603 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_cell_state + offset, " << OpName <<
"_cell_state + offset + "
ROperator_LSTM(std::vector< float > activation_alpha, std::vector< float > activation_beta, std::vector< std::string > activations, float clip, std::string direction, size_t hidden_size, size_t input_forget, size_t layout, std::string nameX, std::string nameW, std::string nameR, std::string nameB, std::string nameSequence_lens, std::string nameInitial_h, std::string nameInitial_c, std::string nameP, std::string nameY, std::string nameY_h, std::string nameY_c)
Constructor of ROperator_LSTM from the attributes.