43 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNX +
" is not found in model.");
46 if (fShapeX.size() != 3) {
47 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNX +
" is not of 3 dimensions.");
50 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNW +
" is not found in model.");
53 if (fShapeW.size() != 3) {
54 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNW +
" is not of 3 dimensions.");
57 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNR +
" is not found in model.");
60 if (fShapeR.size() != 3) {
61 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNR +
" is not of 3 dimensions.");
65 throw std::runtime_error(
"TMVA SOFIE GRU op input tensor " + fNB +
" is not found in model.");
68 if (fShapeB.size() != 2 && fShapeB.size() != 4) {
69 throw std::runtime_error(
"TMVA SOFIE GRU op input tensor " + fNB +
" is not of 2 or 4 dimensions.");
71 if (fShapeB.size() == 2) {
75 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
76 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
77 if (fType ==
"float") {
81 for (
size_t i = 0; i < 6; i++) {
102 if (!fNSequence_lens.empty()) {
104 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
106 "is not found in model.");
109 if (fShapeSequence_lens.size() != 1) {
110 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
112 " is not of 1 dimension.");
115 if (!fNInitial_h.empty()) {
117 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
118 fNInitial_h +
" is not found in model.");
121 if (fShapeInitial_h.size() != 3) {
122 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
123 fNInitial_h +
" is not of 3 dimensions.");
127 fShapeY = ShapeInference({fShapeX, fShapeW})[0];
132 if (!fNY_h.empty()) {
133 fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
146 throw std::runtime_error(
"TMVA SOFIE - Activation function " +
150 if (fAttrDirection ==
"reverse") fAttrDirection =
"backward";
151 if (fAttrDirection !=
"forward" && fAttrDirection !=
"backward" &&
152 fAttrDirection !=
"reverse" &&
153 fAttrDirection !=
"bidirectional") {
154 throw std::runtime_error(
155 "TMVA SOFIE - Invalid GRU direction fAttrDirection = " +
158 if (3 * fAttrHiddenSize != fShapeW[1]) {
159 throw std::runtime_error(
160 "TMVA SOFIE - fAttrHiddenSize must be equal to " +
161 std::to_string(fShapeW[1] / 3));
163 if (fAttrLayout > 1) {
164 throw std::runtime_error(
"TMVA SOFIE - Layout fAttrLayout = " +
165 std::to_string(fAttrLayout) +
166 " must be 0 (timewise) or 1 (batchwise)");
168 if (fAttrLinearBeforeReset > 1) {
169 throw std::runtime_error(
170 "TMVA SOFIE - fAttrInputForget = " + std::to_string(fAttrLinearBeforeReset)
171 +
" must be 0 or 1.");
173 if (fAttrActivations.empty()) {
174 if (fAttrDirection ==
"bidirectional") {
175 fAttrActivations = {
"Sigmoid",
"Tanh",
"Sigmoid",
"Tanh"};
177 fAttrActivations = {
"Sigmoid",
"Tanh"};
232 std::stringstream out;
234 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
235 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
240 if (fAttrLayout == 0) {
241 out << SP << fType <<
" *" <<
OpName <<
"_input = tensor_" << fNX <<
";\n";
244 out << SP << fType <<
" * " <<
OpName <<
"_input = fVec_" <<
OpName <<
"_input.data();\n";
248 out << SP <<
"for(size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
249 out << SP << SP <<
"for(size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
250 out << SP << SP << SP <<
"for(size_t i = 0; i < " <<
input_size <<
"; i++) {\n";
252 <<
" + batch * " <<
input_size <<
" + i] = " <<
"tensor_" << fNX <<
"[batch * "
254 out << SP << SP << SP <<
"}\n";
255 out << SP << SP <<
"}\n";
260 if (!fNInitial_h.empty()) {
261 if (fAttrLayout == 0) {
262 out << SP << fType <<
" *" <<
OpName <<
"_initial_hidden_state = " <<
" tensor_"
263 << fNInitial_h <<
";\n";
266 out << SP << fType <<
" * " <<
OpName <<
"_initial_hidden_state = fVec_" <<
OpName
267 <<
"_initial_hidden_state.data();\n";
270 fAttrHiddenSize <<
"];\n";
273 out << SP <<
"for(size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
274 out << SP << SP <<
"for(size_t h = 0; h < " << fAttrHiddenSize <<
"; h++) {\n";
275 out << SP << SP << SP <<
OpName <<
"_initial_hidden_state["
277 <<
" + h] = tensor_" << fNInitial_h <<
"[batch * " <<
num_directions * fAttrHiddenSize
278 <<
" + " <<
direction * fAttrHiddenSize <<
" + h];\n";
279 out << SP << SP <<
"}\n";
288 out << SP << fType <<
" * " <<
OpName <<
"_f_update_gate = fVec_" <<
OpName <<
"_f_update_gate.data();\n";
289 out << SP << fType <<
" * " <<
OpName <<
"_f_reset_gate = fVec_" <<
OpName <<
"_f_reset_gate.data();\n";
290 out << SP << fType <<
" * " <<
OpName <<
"_f_hidden_gate = fVec_" <<
OpName <<
"_f_hidden_gate.data();\n";
299 out << SP << fType <<
" * " <<
OpName <<
"_update_gate = fVec_" <<
OpName <<
"_update_gate.data();\n";
300 out << SP << fType <<
" * " <<
OpName <<
"_reset_gate = fVec_" <<
OpName <<
"_reset_gate.data();\n";
301 out << SP << fType <<
" * " <<
OpName <<
"_hidden_gate = fVec_" <<
OpName <<
"_hidden_gate.data();\n";
308 if (fAttrLayout == 0 && !fNY.empty()) {
309 out << SP << fType <<
" *" <<
OpName <<
"_hidden_state = tensor_" << fNY <<
";\n";
312 out << SP << fType <<
" * " <<
OpName <<
"_hidden_state = fVec_" <<
OpName <<
"_hidden_state.data();\n";
319 out << SP << fType <<
" * " <<
OpName <<
"_feedback = fVec_" <<
OpName <<
"_feedback.data();\n";
321 out << SP << fType <<
" " <<
OpName <<
"_feedback[" <<
batch_size * fAttrHiddenSize <<
"] = {0};\n";
324 out << SP <<
"char " <<
OpName <<
"_transA = 'N';\n";
325 out << SP <<
"char " <<
OpName <<
"_transB = 'T';\n";
328 out << SP <<
"int " <<
OpName <<
"_n = " << fAttrHiddenSize <<
";\n";
330 if (fType ==
"float") {
331 out << SP <<
"float " <<
OpName <<
"_alpha = 1.;\n";
332 out << SP <<
"float " <<
OpName <<
"_beta = 0.;\n";
337 out << SP <<
"int " <<
OpName <<
"_incx = 1;\n";
338 out << SP <<
"int " <<
OpName <<
"_incy = 1;\n";
339 out << SP <<
"int " <<
OpName <<
"_feedback_size = " <<
batch_size * fAttrHiddenSize <<
";\n";
343 if (fType ==
"float") {
345 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
351 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
357 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
363 if (fType ==
"float") {
366 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
372 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
378 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
387 if (fType ==
"float") {
389 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
390 << fNB <<
", &" <<
OpName <<
"_incx, " <<
OpName <<
"_f_update_gate, &" <<
OpName <<
"_incy);\n";
393 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
398 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
404 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
409 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
412 if (fAttrLinearBeforeReset == 0) {
415 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
417 <<
"_f_hidden_gate, &" <<
OpName <<
"_incy);\n";
421 if (fType ==
"float") {
424 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
430 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
435 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
440 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
445 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
448 if (fAttrLinearBeforeReset == 0) {
451 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
453 <<
"_f_hidden_gate, &" <<
OpName <<
"_incy);\n";
460 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
461 out << SP << SP <<
"size_t offset = seq * " <<
batch_size * fAttrHiddenSize <<
";\n";
467 <<
" + " <<
batch_size * fAttrHiddenSize <<
";\n";
470 out << SP << SP <<
"std::copy(" <<
OpName <<
"_f_update_gate + offset, " <<
OpName
471 <<
"_f_update_gate + offset + " <<
f_seq_size <<
", " <<
OpName <<
"_update_gate + gate_offset);\n";
472 out << SP << SP <<
"std::copy(" <<
OpName <<
"_f_reset_gate + offset, " <<
OpName
473 <<
"_f_reset_gate + offset + " <<
f_seq_size <<
", " <<
OpName <<
"_reset_gate + gate_offset);\n";
474 out << SP << SP <<
"std::copy(" <<
OpName <<
"_f_hidden_gate + offset, " <<
OpName
475 <<
"_f_hidden_gate + offset + " <<
f_seq_size <<
", " <<
OpName <<
"_hidden_gate + gate_offset);\n";
478 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
479 if (fAttrDirection ==
"backward" ||
direction == 1) {
480 out << SP << SP <<
"size_t index = " <<
seq_length - 1 <<
" - seq;\n";
482 out << SP << SP <<
"size_t index = seq;\n";
484 out << SP << SP <<
"int m2 = " <<
batch_size <<
";\n";
490 <<
" + " <<
batch_size * fAttrHiddenSize <<
";\n";
494 out << SP << SP <<
"if (seq == 0) {\n";
495 if (!fNInitial_h.empty()) {
497 if (fType ==
"float") {
498 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
499 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
", &"
501 <<
"_alpha, " <<
OpName <<
"_update_gate + offset, &" <<
OpName <<
"_n);\n";
502 size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
503 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
504 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
506 <<
"_n, &" <<
OpName <<
"_alpha, " <<
OpName <<
"_reset_gate + offset, &" <<
OpName <<
"_n);\n";
509 if (fType ==
"float") {
510 size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
511 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
512 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
514 <<
"_n, &" <<
OpName <<
"_alpha, " <<
OpName <<
"_update_gate + offset, &" <<
OpName <<
"_n);\n";
515 size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
516 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
517 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
519 <<
"_n, &" <<
OpName <<
"_alpha, " <<
OpName <<
"_reset_gate + offset, &" <<
OpName <<
"_n);\n";
523 out << SP << SP <<
"} else {\n";
526 if (fAttrDirection ==
"backward") {
527 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
530 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
533 if (fType ==
"float") {
534 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
535 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
", &"
536 <<
OpName <<
"_n, " <<
OpName <<
"_hidden_state + previous_offset, &" <<
OpName <<
"_n, &"
537 <<
OpName <<
"_alpha, " <<
OpName <<
"_update_gate + offset, &" <<
OpName <<
"_n);\n";
538 size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
539 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
540 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
546 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
548 if (fType ==
"float") {
549 size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
550 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
551 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
555 size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
556 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
557 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
563 out << SP << SP <<
"}\n";
566 if (fAttrClip > .0) {
567 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
568 if (fType ==
"float") {
569 out << SP << SP << SP <<
"float z = (" <<
OpName <<
"_update_gate[i] > " << -fAttrClip
570 <<
") ? " <<
OpName <<
"_update_gate[i] : " << -fAttrClip <<
";\n";
572 out << SP << SP << SP <<
OpName <<
"_update_gate[i] = (z < " << fAttrClip
573 <<
") ? z : " << fAttrClip <<
";\n";
574 if (fType ==
"float") {
575 out << SP << SP << SP <<
"float r = (" <<
OpName <<
"_reset_gate[i] > " << -fAttrClip
576 <<
") ? " <<
OpName <<
"_reset_gate[i] : " << -fAttrClip <<
";\n";
578 out << SP << SP << SP <<
OpName <<
"_reset_gate[i] = (r < " << fAttrClip
579 <<
") ? r : " << fAttrClip <<
";\n";
580 out << SP << SP <<
"}\n";
584 if (fAttrActivations[
direction * 2] ==
"Relu") {
585 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
586 out << SP << SP << SP <<
"if (" <<
OpName <<
"_update_gate[i] < 0.)\n";
587 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = 0.;\n";
588 out << SP << SP << SP <<
"if (" <<
OpName <<
"_reset_gate[i] < 0.)\n";
589 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = 0.;\n";
590 out << SP << SP <<
"}\n";
591 }
else if (fAttrActivations[
direction * 2] ==
"Tanh") {
592 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
593 if (fType ==
"float") {
594 out << SP << SP << SP <<
"float z = exp(-2 * " <<
OpName <<
"_update_gate[i]);\n";
596 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = (1. - z) / (1. + z);\n";
597 if (fType ==
"float") {
598 out << SP << SP << SP <<
"float r = exp(-2 * " <<
OpName <<
"_reset_gate[i]);\n";
600 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = (1. - r) / (1. + r);\n";
601 out << SP << SP <<
"}\n";
602 }
else if (fAttrActivations[
direction * 2] ==
"Sigmoid") {
603 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
604 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = 1. / (1. + exp(-"
605 <<
OpName <<
"_update_gate[i]));\n";
606 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = 1. / (1. + exp(-"
607 <<
OpName <<
"_reset_gate[i]));\n";
608 out << SP << SP <<
"}\n";
609 }
else if (fAttrActivations[
direction * 2] ==
"Affine") {
610 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
611 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = "
612 << fAttrActivationAlpha[
direction * 2] <<
" * " <<
OpName <<
"_update_gate[i] + "
613 << fAttrActivationBeta[
direction * 2] <<
";\n";
614 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = "
615 << fAttrActivationAlpha[
direction * 2] <<
" * " <<
OpName <<
"_reset_gate[i] + "
616 << fAttrActivationBeta[
direction * 2] <<
";\n";
617 out << SP << SP <<
"}\n";
618 }
else if (fAttrActivations[
direction * 2] ==
"ScaledTanh") {
619 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
620 if (fType ==
"float") {
621 out << SP << SP << SP <<
"float z = exp(-2 * " << fAttrActivationBeta[
direction * 2]
622 <<
" * "<<
OpName <<
"_update_gate[i]);\n";
624 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = "
625 << fAttrActivationAlpha[
direction * 2] <<
" * (1. - z) / (1. + z);\n";
626 if (fType ==
"float") {
627 out << SP << SP << SP <<
"float r = exp(-2 * " << fAttrActivationBeta[
direction * 2]
628 <<
" * "<<
OpName <<
"_reset_gate[i]);\n";
630 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = "
631 << fAttrActivationAlpha[
direction * 2] <<
" * (1. - r) / (1. + r);\n";
632 out << SP << SP <<
"}\n";
633 }
else if (fAttrActivations[
direction * 2] ==
"HardSigmoid") {
634 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
635 if (fType ==
"float") {
636 out << SP << SP << SP <<
"float za = " << fAttrActivationAlpha[
direction * 2] <<
" * "
637 <<
OpName <<
"_update_gate[i] + " << fAttrActivationBeta[
direction * 2] <<
";\n";
638 out << SP << SP << SP <<
"float zb = (za > 0.) ? za : 0.;\n";
640 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = (zb < 1.) ? zb : 1.;\n";
641 if (fType ==
"float") {
642 out << SP << SP << SP <<
"float ra = " << fAttrActivationAlpha[
direction * 2] <<
" * "
643 <<
OpName <<
"_reset_gate[i] + " << fAttrActivationBeta[
direction * 2] <<
";\n";
644 out << SP << SP << SP <<
"float rb = (ra > 0.) ? ra : 0.;\n";
646 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = (rb < 1.) ? rb : 1.;\n";
647 out << SP << SP <<
"}\n";
648 }
else if (fAttrActivations[
direction * 2] ==
"LeakyRelu") {
649 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
650 out << SP << SP << SP <<
"if (" <<
OpName <<
"_update_gate[i] < 0.)\n";
651 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = "
652 << fAttrActivationAlpha[
direction * 2] <<
" * " <<
OpName <<
"_update_gate[i];\n";
653 out << SP << SP << SP <<
"if (" <<
OpName <<
"_reset_gate[i] < 0.)\n";
654 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = "
655 << fAttrActivationAlpha[
direction * 2] <<
" * " <<
OpName <<
"_reset_gate[i];\n";
656 out << SP << SP <<
"}\n";
657 }
else if (fAttrActivations[
direction * 2] ==
"ThresholdRelu") {
658 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
659 out << SP << SP << SP <<
"if (" <<
OpName <<
"_update_gate[i] < "
660 << fAttrActivationAlpha[
direction * 2] <<
")\n";
661 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = 0.;\n";
662 out << SP << SP << SP <<
"if (" <<
OpName <<
"_reset_gate[i] < "
663 << fAttrActivationAlpha[
direction * 2] <<
")\n";
664 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = 0.;\n";
665 out << SP << SP <<
"}";
666 }
else if (fAttrActivations[
direction * 2] ==
"Elu") {
667 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
668 out << SP << SP << SP <<
"if (" <<
OpName <<
"_update_gate[i] < 0.)\n";
669 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = "
670 << fAttrActivationAlpha[
direction * 2] <<
" * exp(" <<
OpName <<
"_update_gate[i] - 1.);\n";
671 out << SP << SP << SP <<
"if (" <<
OpName <<
"_reset_gate[i] < 0.)\n";
672 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = "
673 << fAttrActivationAlpha[
direction * 2] <<
" * exp(" <<
OpName <<
"_reset_gate[i] - 1.);\n";
674 out << SP << SP <<
"}\n";
675 }
else if (fAttrActivations[
direction * 2] ==
"Softsign") {
676 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
677 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = " <<
OpName
678 <<
"_update_gate[i] / (1. + abs(" <<
OpName <<
"_update_gate[i]));\n";
679 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = " <<
OpName
680 <<
"_reset_gate[i] / (1. + abs(" <<
OpName <<
"_reset_gate[i]));\n";
681 out << SP << SP <<
"}\n";
683 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
684 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = log(1. + exp("
685 <<
OpName <<
"_update_gate[i]));\n";
686 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = log(1. + exp("
687 <<
OpName <<
"_reset_gate[i]));\n";
688 out << SP << SP <<
"}\n";
691 if (fAttrLinearBeforeReset == 0) {
692 out << SP << SP <<
"if (seq == 0) {\n";
693 if (!fNInitial_h.empty()) {
695 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
696 out << SP << SP << SP << SP <<
OpName <<
"_feedback[i] = " <<
OpName
697 <<
"_reset_gate[i + offset] * " <<
OpName <<
"_initial_hidden_state[i];\n";
698 out << SP << SP << SP <<
"}\n";
700 out << SP << SP <<
"} else {\n";
703 if (fAttrDirection ==
"backward") {
704 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
707 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
711 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * " <<
num_directions
714 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
715 out << SP << SP << SP << SP <<
OpName <<
"_feedback[i] = " <<
OpName
716 <<
"_reset_gate[i + offset] * " <<
OpName <<
"_hidden_state[i + previous_offset];\n";
717 out << SP << SP << SP <<
"}\n";
718 out << SP << SP <<
"}\n";
721 2 * fAttrHiddenSize * fAttrHiddenSize : 3 * fAttrHiddenSize * fAttrHiddenSize
722 + 2 * fAttrHiddenSize * fAttrHiddenSize;
723 out << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
731 ? 2 * fAttrHiddenSize * fAttrHiddenSize
732 : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
733 out << SP << SP <<
"if (seq == 0) {\n";
734 if (!fNInitial_h.empty()) {
736 out << SP << SP << SP
737 <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName <<
"_n, &"
738 <<
OpName <<
"_m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
742 out << SP << SP <<
"} else {\n";
745 if (fAttrDirection ==
"backward") {
746 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
749 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
753 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * " <<
num_directions
756 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
761 out << SP << SP <<
"}\n";
766 out << SP << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_feedback_size, &" <<
OpName
767 <<
"_alpha, tensor_" << fNB <<
" + " <<
rbh_offset <<
", &" <<
OpName <<
"_incx, "
771 out << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
772 out << SP << SP << SP <<
OpName <<
"_feedback[i] *= " <<
OpName <<
"_reset_gate[i + offset];\n";
773 out << SP << SP <<
"}\n";
777 out << SP << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_feedback_size, &" <<
OpName <<
"_alpha, "
778 <<
OpName <<
"_feedback, &" <<
OpName <<
"_incx, " <<
OpName <<
"_hidden_gate + offset, &"
782 if (fAttrClip > .0) {
783 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
784 if (fType ==
"float") {
785 out << SP << SP << SP <<
"float x = (" <<
OpName <<
"_hidden_gate[i] > " << -fAttrClip
786 <<
") ? " <<
OpName <<
"_hidden_gate[i] : " << -fAttrClip <<
";\n";
788 out << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = (x < " << fAttrClip <<
") ? x : "
789 << fAttrClip <<
";\n";
790 out << SP << SP <<
"}\n";
794 if (fAttrActivations[
direction * 2 + 1] ==
"Relu") {
795 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
796 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_gate[i] < 0.)\n";
797 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = 0.;\n";
798 out << SP << SP <<
"}\n";
799 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Tanh") {
800 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
801 if (fType ==
"float") {
802 out << SP << SP << SP <<
"float ex = exp(-2 * " <<
OpName <<
"_hidden_gate[i]);\n";
804 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = (1. - ex) / (1. + ex);\n";
805 out << SP << SP <<
"}\n";
806 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Sigmoid") {
807 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
808 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = 1. / (1. + exp(-" <<
OpName
809 <<
"_hidden_gate[i]));\n";
810 out << SP << SP <<
"}\n";
811 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Affine") {
812 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
813 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = "
814 << fAttrActivationAlpha[
direction * 2 + 1] <<
" * " <<
OpName <<
"_hidden_gate[i] + "
815 << fAttrActivationBeta[
direction * 2 + 1] <<
";\n";
816 out << SP << SP <<
"}\n";
817 }
else if (fAttrActivations[
direction * 2 + 1] ==
"ScaledTanh") {
818 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
819 if (fType ==
"float") {
820 out << SP << SP << SP <<
"float ex = exp(-2 * " << fAttrActivationBeta[
direction * 2 + 1]
821 <<
" * "<<
OpName <<
"_hidden_gate[i]);\n";
823 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = "
824 << fAttrActivationAlpha[
direction * 2 + 1] <<
" * (1. - ex) / (1. + ex);\n";
825 out << SP << SP <<
"}\n";
826 }
else if (fAttrActivations[
direction * 2 + 1] ==
"HardSigmoid") {
827 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
828 if (fType ==
"float") {
829 out << SP << SP << SP <<
"float a = " << fAttrActivationAlpha[
direction * 2 + 1] <<
" * "
830 <<
OpName <<
"_hidden_gate[i] + " << fAttrActivationBeta[
direction * 2 + 1] <<
";\n";
831 out << SP << SP << SP <<
"float b = (a > 0.) ? a : 0.;\n";
833 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = (b < 1.) ? b : 1.;\n";
834 out << SP << SP <<
"}\n";
835 }
else if (fAttrActivations[
direction * 2 + 1] ==
"LeakyRelu") {
836 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
837 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_gate[i] < 0.)\n";
838 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = "
839 << fAttrActivationAlpha[
direction * 2 + 1] <<
" * " <<
OpName <<
"_hidden_gate[i];\n";
840 out << SP << SP <<
"}\n";
841 }
else if (fAttrActivations[
direction * 2 + 1] ==
"ThresholdRelu") {
842 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
843 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_gate[i] < "
844 << fAttrActivationAlpha[
direction * 2 + 1] <<
")\n";
845 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = 0.;\n";
846 out << SP << SP <<
"}";
847 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Elu") {
848 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
849 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_gate[i] < 0.)\n";
850 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = "
851 << fAttrActivationAlpha[
direction * 2 + 1] <<
" * exp(" <<
OpName <<
"_hidden_gate[i] - 1.);\n";
852 out << SP << SP <<
"}\n";
853 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Softsign") {
854 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
855 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = " <<
OpName
856 <<
"_hidden_gate[i] / (1. + abs(" <<
OpName <<
"_hidden_gate[i]));\n";
857 out << SP << SP <<
"}\n";
859 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
860 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = log(1. + exp("
861 <<
OpName <<
"_hidden_gate[i]));\n";
862 out << SP << SP <<
"}\n";
866 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
867 out << SP << SP << SP <<
OpName <<
"_hidden_state[i] = ( 1. - " <<
OpName
868 <<
"_update_gate[i]) * " <<
OpName <<
"_hidden_gate[i];\n";
869 out << SP << SP <<
"}\n";
871 out << SP << SP <<
"if (seq == 0) {\n";
872 if (!fNInitial_h.empty()) {
874 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
875 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i + offset] += " <<
OpName
876 <<
"_update_gate[i + offset] * " <<
OpName <<
"_initial_hidden_state[i];\n";
877 out << SP << SP << SP <<
"}\n";
879 out << SP << SP <<
"} else {\n";
882 if (fAttrDirection ==
"backward") {
883 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
886 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
890 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
893 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
894 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i + offset] += " <<
OpName
895 <<
"_update_gate[i + offset] * " <<
OpName <<
"_hidden_state[i + previous_offset];\n";
896 out << SP << SP << SP <<
"}\n";
897 out << SP << SP <<
"}\n";
903 if (!fNSequence_lens.empty()) {
904 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
905 out << SP << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
906 out << SP << SP << SP <<
"if (seq >= tensor_" << fNSequence_lens <<
"[batch]) {\n";
908 out << SP << SP << SP << SP << SP <<
"for (size_t h = 0; h < " << fAttrHiddenSize <<
"; h++) {\n";
909 out << SP << SP << SP << SP << SP << SP <<
OpName <<
"_hidden_state[seq * "
911 <<
" + batch * " << fAttrHiddenSize <<
" + h] = 0.;\n";
912 out << SP << SP << SP << SP << SP <<
"}\n";
914 out << SP << SP << SP <<
"}\n";
915 out << SP << SP <<
"}\n";
920 if (fAttrLayout == 0) {
921 if (!fNY_h.empty()) {
923 if (fNSequence_lens.empty()) {
925 if (fAttrDirection ==
"backward") {
926 out << SP <<
"std::copy(" <<
OpName <<
"_hidden_state, " <<
OpName <<
"_hidden_state + "
927 <<
yh_size <<
", tensor_" << fNY_h <<
");\n";
931 <<
"_hidden_state + " <<
offset <<
" + " <<
yh_size <<
", tensor_" << fNY_h <<
");\n";
935 <<
"_hidden_state + " << 2 *
yh_size <<
", tensor_" << fNY_h <<
" + " <<
yh_size <<
");\n";
938 if (fAttrDirection ==
"backward") {
939 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
940 out << SP << SP <<
"size_t offset = batch * " << fAttrHiddenSize <<
";\n";
941 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
942 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + offset);\n";
945 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
946 out << SP << SP <<
"size_t seq = " <<
"tensor_" << fNSequence_lens <<
"[batch] - 1;\n";
948 <<
" + batch * " << fAttrHiddenSize <<
";\n";
949 out << SP << SP <<
"size_t yh_offset = batch * " << fAttrHiddenSize <<
";\n";
950 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
951 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
955 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
956 out << SP << SP <<
"size_t offset = " <<
batch_size * fAttrHiddenSize
957 <<
" + batch * " << fAttrHiddenSize <<
";\n";
958 out << SP << SP <<
"size_t yh_offset = " <<
batch_size * fAttrHiddenSize
959 <<
" + batch * " << fAttrHiddenSize <<
";\n";
960 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
961 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
970 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
971 out << SP << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
973 <<
" + " <<
direction *
batch_size * fAttrHiddenSize <<
" + batch * " << fAttrHiddenSize <<
";\n";
976 out << SP << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
977 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY <<
" + y_offset);\n";
978 out << SP << SP <<
"}\n";
982 if (!fNY_h.empty()) {
984 if (fAttrDirection ==
"backward") {
985 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
986 out << SP << SP <<
"size_t offset = batch * " << fAttrHiddenSize <<
";\n";
987 out << SP << SP <<
"size_t yh_offset = batch * " <<
num_directions * fAttrHiddenSize <<
";\n";
988 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
989 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
992 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
993 if (fNSequence_lens.empty()) {
994 out << SP << SP <<
"size_t seq = " <<
seq_length - 1 <<
";\n";
996 out << SP << SP <<
"size_t seq = " <<
"tensor_" << fNSequence_lens <<
"[batch] - 1;\n";
999 <<
" + batch * " << fAttrHiddenSize <<
";\n";
1000 out << SP << SP <<
"size_t yh_offset = batch * " <<
num_directions * fAttrHiddenSize <<
";\n";
1001 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1002 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
1006 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1007 out << SP << SP <<
"size_t offset = " <<
batch_size * fAttrHiddenSize <<
" + batch * "
1008 << fAttrHiddenSize <<
";\n";
1009 out << SP << SP <<
"size_t yh_offset = batch * " <<
num_directions * fAttrHiddenSize <<
" + "
1010 << fAttrHiddenSize <<
";\n";
1011 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1012 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";