42 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNX +
" is not found in model.");
45 if (fShapeX.size() != 3) {
46 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNX +
" is not of 3 dimensions.");
49 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNW +
" is not found in model.");
52 if (fShapeW.size() != 3) {
53 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNW +
" is not of 3 dimensions.");
56 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNR +
" is not found in model.");
59 if (fShapeR.size() != 3) {
60 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNR +
" is not of 3 dimensions.");
64 throw std::runtime_error(
"TMVA SOFIE GRU op input tensor " + fNB +
" is not found in model.");
67 if (fShapeB.size() != 2 && fShapeB.size() != 4) {
68 throw std::runtime_error(
"TMVA SOFIE GRU op input tensor " + fNB +
" is not of 2 or 4 dimensions.");
70 if (fShapeB.size() == 2) {
74 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
75 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
76 if (fType ==
"float") {
80 for (
size_t i = 0; i < 6; i++) {
85 i * batch_size *
seq_length * fAttrHiddenSize +
86 +
seq *batch_size *fAttrHiddenSize +
batch *fAttrHiddenSize;
101 if (!fNSequence_lens.empty()) {
103 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
105 "is not found in model.");
108 if (fShapeSequence_lens.size() != 1) {
109 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
111 " is not of 1 dimension.");
114 if (!fNInitial_h.empty()) {
116 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
117 fNInitial_h +
" is not found in model.");
120 if (fShapeInitial_h.size() != 3) {
121 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
122 fNInitial_h +
" is not of 3 dimensions.");
126 fShapeY = ShapeInference({fShapeX, fShapeW})[0];
131 if (!fNY_h.empty()) {
132 fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
145 throw std::runtime_error(
"TMVA SOFIE - Activation function " +
149 if (fAttrDirection ==
"reverse") fAttrDirection =
"backward";
150 if (fAttrDirection !=
"forward" && fAttrDirection !=
"backward" &&
151 fAttrDirection !=
"reverse" &&
152 fAttrDirection !=
"bidirectional") {
153 throw std::runtime_error(
154 "TMVA SOFIE - Invalid GRU direction fAttrDirection = " +
157 if (3 * fAttrHiddenSize != fShapeW[1]) {
158 throw std::runtime_error(
159 "TMVA SOFIE - fAttrHiddenSize must be equal to " +
160 std::to_string(fShapeW[1] / 3));
162 if (fAttrLayout > 1) {
163 throw std::runtime_error(
"TMVA SOFIE - Layout fAttrLayout = " +
164 std::to_string(fAttrLayout) +
165 " must be 0 (timewise) or 1 (batchwise)");
167 if (fAttrLinearBeforeReset > 1) {
168 throw std::runtime_error(
169 "TMVA SOFIE - fAttrInputForget = " + std::to_string(fAttrLinearBeforeReset)
170 +
" must be 0 or 1.");
172 if (fAttrActivations.empty()) {
173 if (fAttrDirection ==
"bidirectional") {
174 fAttrActivations = {
"Sigmoid",
"Tanh",
"Sigmoid",
"Tanh"};
176 fAttrActivations = {
"Sigmoid",
"Tanh"};
231 std::stringstream out;
233 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
234 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
239 if (fAttrLayout == 0) {
240 out << SP << fType <<
" *" <<
OpName <<
"_input = tensor_" << fNX <<
";\n";
243 out << SP << fType <<
" * " <<
OpName <<
"_input = fVec_" <<
OpName <<
"_input.data();\n";
247 out << SP <<
"for(size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
248 out << SP << SP <<
"for(size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
249 out << SP << SP << SP <<
"for(size_t i = 0; i < " <<
input_size <<
"; i++) {\n";
250 out << SP << SP << SP << SP <<
OpName <<
"_input[seq * " << batch_size *
input_size
251 <<
" + batch * " <<
input_size <<
" + i] = " <<
"tensor_" << fNX <<
"[batch * "
253 out << SP << SP << SP <<
"}\n";
254 out << SP << SP <<
"}\n";
259 if (!fNInitial_h.empty()) {
260 if (fAttrLayout == 0) {
261 out << SP << fType <<
" *" <<
OpName <<
"_initial_hidden_state = " <<
" tensor_"
262 << fNInitial_h <<
";\n";
265 out << SP << fType <<
" * " <<
OpName <<
"_initial_hidden_state = fVec_" <<
OpName
266 <<
"_initial_hidden_state.data();\n";
268 out << SP << fType <<
" " <<
OpName <<
"_initial_hidden_state[" <<
num_directions * batch_size *
269 fAttrHiddenSize <<
"];\n";
272 out << SP <<
"for(size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
273 out << SP << SP <<
"for(size_t h = 0; h < " << fAttrHiddenSize <<
"; h++) {\n";
274 out << SP << SP << SP <<
OpName <<
"_initial_hidden_state["
275 <<
direction * batch_size * fAttrHiddenSize <<
" + batch * " << fAttrHiddenSize
276 <<
" + h] = tensor_" << fNInitial_h <<
"[batch * " <<
num_directions * fAttrHiddenSize
277 <<
" + " <<
direction * fAttrHiddenSize <<
" + h];\n";
278 out << SP << SP <<
"}\n";
287 out << SP << fType <<
" * " <<
OpName <<
"_f_update_gate = fVec_" <<
OpName <<
"_f_update_gate.data();\n";
288 out << SP << fType <<
" * " <<
OpName <<
"_f_reset_gate = fVec_" <<
OpName <<
"_f_reset_gate.data();\n";
289 out << SP << fType <<
" * " <<
OpName <<
"_f_hidden_gate = fVec_" <<
OpName <<
"_f_hidden_gate.data();\n";
298 out << SP << fType <<
" * " <<
OpName <<
"_update_gate = fVec_" <<
OpName <<
"_update_gate.data();\n";
299 out << SP << fType <<
" * " <<
OpName <<
"_reset_gate = fVec_" <<
OpName <<
"_reset_gate.data();\n";
300 out << SP << fType <<
" * " <<
OpName <<
"_hidden_gate = fVec_" <<
OpName <<
"_hidden_gate.data();\n";
307 if (fAttrLayout == 0 && !fNY.empty()) {
308 out << SP << fType <<
" *" <<
OpName <<
"_hidden_state = tensor_" << fNY <<
";\n";
311 out << SP << fType <<
" * " <<
OpName <<
"_hidden_state = fVec_" <<
OpName <<
"_hidden_state.data();\n";
318 out << SP << fType <<
" * " <<
OpName <<
"_feedback = fVec_" <<
OpName <<
"_feedback.data();\n";
320 out << SP << fType <<
" " <<
OpName <<
"_feedback[" << batch_size * fAttrHiddenSize <<
"] = {0};\n";
323 out << SP <<
"char " <<
OpName <<
"_transA = 'N';\n";
324 out << SP <<
"char " <<
OpName <<
"_transB = 'T';\n";
325 out << SP <<
"int " <<
OpName <<
"_m = " <<
seq_length * batch_size <<
";\n";
326 out << SP <<
"int " <<
OpName <<
"_m2 = " << batch_size <<
";\n";
327 out << SP <<
"int " <<
OpName <<
"_n = " << fAttrHiddenSize <<
";\n";
329 if (fType ==
"float") {
330 out << SP <<
"float " <<
OpName <<
"_alpha = 1.;\n";
331 out << SP <<
"float " <<
OpName <<
"_beta = 0.;\n";
334 out << SP <<
"int " <<
OpName <<
"_bias_size = " <<
seq_length * batch_size * fAttrHiddenSize <<
";\n";
336 out << SP <<
"int " <<
OpName <<
"_incx = 1;\n";
337 out << SP <<
"int " <<
OpName <<
"_incy = 1;\n";
338 out << SP <<
"int " <<
OpName <<
"_feedback_size = " << batch_size * fAttrHiddenSize <<
";\n";
342 if (fType ==
"float") {
344 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
350 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
356 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
362 if (fType ==
"float") {
365 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
371 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
377 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
386 if (fType ==
"float") {
388 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
389 << fNB <<
", &" <<
OpName <<
"_incx, " <<
OpName <<
"_f_update_gate, &" <<
OpName <<
"_incy);\n";
392 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
397 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
403 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
408 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
411 if (fAttrLinearBeforeReset == 0) {
414 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
416 <<
"_f_hidden_gate, &" <<
OpName <<
"_incy);\n";
420 if (fType ==
"float") {
423 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
429 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
434 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
439 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
444 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
447 if (fAttrLinearBeforeReset == 0) {
450 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
452 <<
"_f_hidden_gate, &" <<
OpName <<
"_incy);\n";
459 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
460 out << SP << SP <<
"size_t offset = seq * " << batch_size * fAttrHiddenSize <<
";\n";
462 out << SP << SP <<
"size_t gate_offset = seq * " <<
num_directions * batch_size * fAttrHiddenSize
465 out << SP << SP <<
"size_t gate_offset = seq * " <<
num_directions * batch_size * fAttrHiddenSize
466 <<
" + " << batch_size * fAttrHiddenSize <<
";\n";
468 size_t f_seq_size = batch_size * fAttrHiddenSize;
469 out << SP << SP <<
"std::copy(" <<
OpName <<
"_f_update_gate + offset, " <<
OpName
470 <<
"_f_update_gate + offset + " <<
f_seq_size <<
", " <<
OpName <<
"_update_gate + gate_offset);\n";
471 out << SP << SP <<
"std::copy(" <<
OpName <<
"_f_reset_gate + offset, " <<
OpName
472 <<
"_f_reset_gate + offset + " <<
f_seq_size <<
", " <<
OpName <<
"_reset_gate + gate_offset);\n";
473 out << SP << SP <<
"std::copy(" <<
OpName <<
"_f_hidden_gate + offset, " <<
OpName
474 <<
"_f_hidden_gate + offset + " <<
f_seq_size <<
", " <<
OpName <<
"_hidden_gate + gate_offset);\n";
477 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
478 if (fAttrDirection ==
"backward" ||
direction == 1) {
479 out << SP << SP <<
"size_t index = " <<
seq_length - 1 <<
" - seq;\n";
481 out << SP << SP <<
"size_t index = seq;\n";
483 out << SP << SP <<
"int m2 = " << batch_size <<
";\n";
485 out << SP << SP <<
"size_t offset = index * " <<
num_directions * batch_size * fAttrHiddenSize
488 out << SP << SP <<
"size_t offset = index * " <<
num_directions * batch_size * fAttrHiddenSize
489 <<
" + " << batch_size * fAttrHiddenSize <<
";\n";
491 size_t size = batch_size * fAttrHiddenSize;
493 out << SP << SP <<
"if (seq == 0) {\n";
494 if (!fNInitial_h.empty()) {
496 if (fType ==
"float") {
497 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
498 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
", &"
500 <<
"_alpha, " <<
OpName <<
"_update_gate + offset, &" <<
OpName <<
"_n);\n";
501 size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
502 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
503 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
505 <<
"_n, &" <<
OpName <<
"_alpha, " <<
OpName <<
"_reset_gate + offset, &" <<
OpName <<
"_n);\n";
508 if (fType ==
"float") {
509 size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
510 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
511 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
513 <<
"_n, &" <<
OpName <<
"_alpha, " <<
OpName <<
"_update_gate + offset, &" <<
OpName <<
"_n);\n";
514 size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
515 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
516 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
518 <<
"_n, &" <<
OpName <<
"_alpha, " <<
OpName <<
"_reset_gate + offset, &" <<
OpName <<
"_n);\n";
522 out << SP << SP <<
"} else {\n";
525 if (fAttrDirection ==
"backward") {
526 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
529 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
532 if (fType ==
"float") {
533 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
534 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
", &"
535 <<
OpName <<
"_n, " <<
OpName <<
"_hidden_state + previous_offset, &" <<
OpName <<
"_n, &"
536 <<
OpName <<
"_alpha, " <<
OpName <<
"_update_gate + offset, &" <<
OpName <<
"_n);\n";
537 size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
538 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
539 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
545 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
546 <<
num_directions * batch_size * fAttrHiddenSize <<
" + " << batch_size * fAttrHiddenSize <<
";\n";
547 if (fType ==
"float") {
548 size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
549 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
550 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
554 size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
555 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
556 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
562 out << SP << SP <<
"}\n";
565 if (fAttrClip > .0) {
566 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
567 if (fType ==
"float") {
568 out << SP << SP << SP <<
"float z = (" <<
OpName <<
"_update_gate[i] > " << -fAttrClip
569 <<
") ? " <<
OpName <<
"_update_gate[i] : " << -fAttrClip <<
";\n";
571 out << SP << SP << SP <<
OpName <<
"_update_gate[i] = (z < " << fAttrClip
572 <<
") ? z : " << fAttrClip <<
";\n";
573 if (fType ==
"float") {
574 out << SP << SP << SP <<
"float r = (" <<
OpName <<
"_reset_gate[i] > " << -fAttrClip
575 <<
") ? " <<
OpName <<
"_reset_gate[i] : " << -fAttrClip <<
";\n";
577 out << SP << SP << SP <<
OpName <<
"_reset_gate[i] = (r < " << fAttrClip
578 <<
") ? r : " << fAttrClip <<
";\n";
579 out << SP << SP <<
"}\n";
583 if (fAttrActivations[
direction * 2] ==
"Relu") {
584 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
585 out << SP << SP << SP <<
"if (" <<
OpName <<
"_update_gate[i] < 0.)\n";
586 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = 0.;\n";
587 out << SP << SP << SP <<
"if (" <<
OpName <<
"_reset_gate[i] < 0.)\n";
588 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = 0.;\n";
589 out << SP << SP <<
"}\n";
590 }
else if (fAttrActivations[
direction * 2] ==
"Tanh") {
591 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
592 if (fType ==
"float") {
593 out << SP << SP << SP <<
"float z = exp(-2 * " <<
OpName <<
"_update_gate[i]);\n";
595 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = (1. - z) / (1. + z);\n";
596 if (fType ==
"float") {
597 out << SP << SP << SP <<
"float r = exp(-2 * " <<
OpName <<
"_reset_gate[i]);\n";
599 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = (1. - r) / (1. + r);\n";
600 out << SP << SP <<
"}\n";
601 }
else if (fAttrActivations[
direction * 2] ==
"Sigmoid") {
602 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
603 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = 1. / (1. + exp(-"
604 <<
OpName <<
"_update_gate[i]));\n";
605 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = 1. / (1. + exp(-"
606 <<
OpName <<
"_reset_gate[i]));\n";
607 out << SP << SP <<
"}\n";
608 }
else if (fAttrActivations[
direction * 2] ==
"Affine") {
609 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
610 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = "
611 << fAttrActivationAlpha[
direction * 2] <<
" * " <<
OpName <<
"_update_gate[i] + "
612 << fAttrActivationBeta[
direction * 2] <<
";\n";
613 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = "
614 << fAttrActivationAlpha[
direction * 2] <<
" * " <<
OpName <<
"_reset_gate[i] + "
615 << fAttrActivationBeta[
direction * 2] <<
";\n";
616 out << SP << SP <<
"}\n";
617 }
else if (fAttrActivations[
direction * 2] ==
"ScaledTanh") {
618 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
619 if (fType ==
"float") {
620 out << SP << SP << SP <<
"float z = exp(-2 * " << fAttrActivationBeta[
direction * 2]
621 <<
" * "<<
OpName <<
"_update_gate[i]);\n";
623 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = "
624 << fAttrActivationAlpha[
direction * 2] <<
" * (1. - z) / (1. + z);\n";
625 if (fType ==
"float") {
626 out << SP << SP << SP <<
"float r = exp(-2 * " << fAttrActivationBeta[
direction * 2]
627 <<
" * "<<
OpName <<
"_reset_gate[i]);\n";
629 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = "
630 << fAttrActivationAlpha[
direction * 2] <<
" * (1. - r) / (1. + r);\n";
631 out << SP << SP <<
"}\n";
632 }
else if (fAttrActivations[
direction * 2] ==
"HardSigmoid") {
633 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
634 if (fType ==
"float") {
635 out << SP << SP << SP <<
"float za = " << fAttrActivationAlpha[
direction * 2] <<
" * "
636 <<
OpName <<
"_update_gate[i] + " << fAttrActivationBeta[
direction * 2] <<
";\n";
637 out << SP << SP << SP <<
"float zb = (za > 0.) ? za : 0.;\n";
639 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = (zb < 1.) ? zb : 1.;\n";
640 if (fType ==
"float") {
641 out << SP << SP << SP <<
"float ra = " << fAttrActivationAlpha[
direction * 2] <<
" * "
642 <<
OpName <<
"_reset_gate[i] + " << fAttrActivationBeta[
direction * 2] <<
";\n";
643 out << SP << SP << SP <<
"float rb = (ra > 0.) ? ra : 0.;\n";
645 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = (rb < 1.) ? rb : 1.;\n";
646 out << SP << SP <<
"}\n";
647 }
else if (fAttrActivations[
direction * 2] ==
"LeakyRelu") {
648 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
649 out << SP << SP << SP <<
"if (" <<
OpName <<
"_update_gate[i] < 0.)\n";
650 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = "
651 << fAttrActivationAlpha[
direction * 2] <<
" * " <<
OpName <<
"_update_gate[i];\n";
652 out << SP << SP << SP <<
"if (" <<
OpName <<
"_reset_gate[i] < 0.)\n";
653 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = "
654 << fAttrActivationAlpha[
direction * 2] <<
" * " <<
OpName <<
"_reset_gate[i];\n";
655 out << SP << SP <<
"}\n";
656 }
else if (fAttrActivations[
direction * 2] ==
"ThresholdRelu") {
657 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
658 out << SP << SP << SP <<
"if (" <<
OpName <<
"_update_gate[i] < "
659 << fAttrActivationAlpha[
direction * 2] <<
")\n";
660 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = 0.;\n";
661 out << SP << SP << SP <<
"if (" <<
OpName <<
"_reset_gate[i] < "
662 << fAttrActivationAlpha[
direction * 2] <<
")\n";
663 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = 0.;\n";
664 out << SP << SP <<
"}";
665 }
else if (fAttrActivations[
direction * 2] ==
"Elu") {
666 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
667 out << SP << SP << SP <<
"if (" <<
OpName <<
"_update_gate[i] < 0.)\n";
668 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = "
669 << fAttrActivationAlpha[
direction * 2] <<
" * exp(" <<
OpName <<
"_update_gate[i] - 1.);\n";
670 out << SP << SP << SP <<
"if (" <<
OpName <<
"_reset_gate[i] < 0.)\n";
671 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = "
672 << fAttrActivationAlpha[
direction * 2] <<
" * exp(" <<
OpName <<
"_reset_gate[i] - 1.);\n";
673 out << SP << SP <<
"}\n";
674 }
else if (fAttrActivations[
direction * 2] ==
"Softsign") {
675 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
676 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = " <<
OpName
677 <<
"_update_gate[i] / (1. + abs(" <<
OpName <<
"_update_gate[i]));\n";
678 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = " <<
OpName
679 <<
"_reset_gate[i] / (1. + abs(" <<
OpName <<
"_reset_gate[i]));\n";
680 out << SP << SP <<
"}\n";
682 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
683 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = log(1. + exp("
684 <<
OpName <<
"_update_gate[i]));\n";
685 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = log(1. + exp("
686 <<
OpName <<
"_reset_gate[i]));\n";
687 out << SP << SP <<
"}\n";
690 if (fAttrLinearBeforeReset == 0) {
691 out << SP << SP <<
"if (seq == 0) {\n";
692 if (!fNInitial_h.empty()) {
694 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
695 out << SP << SP << SP << SP <<
OpName <<
"_feedback[i] = " <<
OpName
696 <<
"_reset_gate[i + offset] * " <<
OpName <<
"_initial_hidden_state[i];\n";
697 out << SP << SP << SP <<
"}\n";
699 out << SP << SP <<
"} else {\n";
702 if (fAttrDirection ==
"backward") {
703 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
706 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
710 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * " <<
num_directions
711 * batch_size * fAttrHiddenSize <<
" + " << batch_size * fAttrHiddenSize <<
";\n";
713 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
714 out << SP << SP << SP << SP <<
OpName <<
"_feedback[i] = " <<
OpName
715 <<
"_reset_gate[i + offset] * " <<
OpName <<
"_hidden_state[i + previous_offset];\n";
716 out << SP << SP << SP <<
"}\n";
717 out << SP << SP <<
"}\n";
720 2 * fAttrHiddenSize * fAttrHiddenSize : 3 * fAttrHiddenSize * fAttrHiddenSize
721 + 2 * fAttrHiddenSize * fAttrHiddenSize;
722 out << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
730 ? 2 * fAttrHiddenSize * fAttrHiddenSize
731 : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
732 out << SP << SP <<
"if (seq == 0) {\n";
733 if (!fNInitial_h.empty()) {
735 out << SP << SP << SP
736 <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName <<
"_n, &"
737 <<
OpName <<
"_m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
741 out << SP << SP <<
"} else {\n";
744 if (fAttrDirection ==
"backward") {
745 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
748 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
752 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * " <<
num_directions
753 * batch_size * fAttrHiddenSize <<
" + " << batch_size * fAttrHiddenSize <<
";\n";
755 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
760 out << SP << SP <<
"}\n";
764 : 11 * batch_size *
seq_length * fAttrHiddenSize;
765 out << SP << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_feedback_size, &" <<
OpName
766 <<
"_alpha, tensor_" << fNB <<
" + " <<
rbh_offset <<
", &" <<
OpName <<
"_incx, "
770 out << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
771 out << SP << SP << SP <<
OpName <<
"_feedback[i] *= " <<
OpName <<
"_reset_gate[i + offset];\n";
772 out << SP << SP <<
"}\n";
776 out << SP << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_feedback_size, &" <<
OpName <<
"_alpha, "
777 <<
OpName <<
"_feedback, &" <<
OpName <<
"_incx, " <<
OpName <<
"_hidden_gate + offset, &"
781 if (fAttrClip > .0) {
782 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
783 if (fType ==
"float") {
784 out << SP << SP << SP <<
"float x = (" <<
OpName <<
"_hidden_gate[i] > " << -fAttrClip
785 <<
") ? " <<
OpName <<
"_hidden_gate[i] : " << -fAttrClip <<
";\n";
787 out << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = (x < " << fAttrClip <<
") ? x : "
788 << fAttrClip <<
";\n";
789 out << SP << SP <<
"}\n";
793 if (fAttrActivations[
direction * 2 + 1] ==
"Relu") {
794 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
795 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_gate[i] < 0.)\n";
796 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = 0.;\n";
797 out << SP << SP <<
"}\n";
798 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Tanh") {
799 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
800 if (fType ==
"float") {
801 out << SP << SP << SP <<
"float ex = exp(-2 * " <<
OpName <<
"_hidden_gate[i]);\n";
803 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = (1. - ex) / (1. + ex);\n";
804 out << SP << SP <<
"}\n";
805 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Sigmoid") {
806 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
807 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = 1. / (1. + exp(-" <<
OpName
808 <<
"_hidden_gate[i]));\n";
809 out << SP << SP <<
"}\n";
810 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Affine") {
811 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
812 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = "
813 << fAttrActivationAlpha[
direction * 2 + 1] <<
" * " <<
OpName <<
"_hidden_gate[i] + "
814 << fAttrActivationBeta[
direction * 2 + 1] <<
";\n";
815 out << SP << SP <<
"}\n";
816 }
else if (fAttrActivations[
direction * 2 + 1] ==
"ScaledTanh") {
817 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
818 if (fType ==
"float") {
819 out << SP << SP << SP <<
"float ex = exp(-2 * " << fAttrActivationBeta[
direction * 2 + 1]
820 <<
" * "<<
OpName <<
"_hidden_gate[i]);\n";
822 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = "
823 << fAttrActivationAlpha[
direction * 2 + 1] <<
" * (1. - ex) / (1. + ex);\n";
824 out << SP << SP <<
"}\n";
825 }
else if (fAttrActivations[
direction * 2 + 1] ==
"HardSigmoid") {
826 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
827 if (fType ==
"float") {
828 out << SP << SP << SP <<
"float a = " << fAttrActivationAlpha[
direction * 2 + 1] <<
" * "
829 <<
OpName <<
"_hidden_gate[i] + " << fAttrActivationBeta[
direction * 2 + 1] <<
";\n";
830 out << SP << SP << SP <<
"float b = (a > 0.) ? a : 0.;\n";
832 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = (b < 1.) ? b : 1.;\n";
833 out << SP << SP <<
"}\n";
834 }
else if (fAttrActivations[
direction * 2 + 1] ==
"LeakyRelu") {
835 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
836 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_gate[i] < 0.)\n";
837 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = "
838 << fAttrActivationAlpha[
direction * 2 + 1] <<
" * " <<
OpName <<
"_hidden_gate[i];\n";
839 out << SP << SP <<
"}\n";
840 }
else if (fAttrActivations[
direction * 2 + 1] ==
"ThresholdRelu") {
841 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
842 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_gate[i] < "
843 << fAttrActivationAlpha[
direction * 2 + 1] <<
")\n";
844 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = 0.;\n";
845 out << SP << SP <<
"}";
846 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Elu") {
847 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
848 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_gate[i] < 0.)\n";
849 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = "
850 << fAttrActivationAlpha[
direction * 2 + 1] <<
" * exp(" <<
OpName <<
"_hidden_gate[i] - 1.);\n";
851 out << SP << SP <<
"}\n";
852 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Softsign") {
853 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
854 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = " <<
OpName
855 <<
"_hidden_gate[i] / (1. + abs(" <<
OpName <<
"_hidden_gate[i]));\n";
856 out << SP << SP <<
"}\n";
858 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
859 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = log(1. + exp("
860 <<
OpName <<
"_hidden_gate[i]));\n";
861 out << SP << SP <<
"}\n";
865 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
866 out << SP << SP << SP <<
OpName <<
"_hidden_state[i] = ( 1. - " <<
OpName
867 <<
"_update_gate[i]) * " <<
OpName <<
"_hidden_gate[i];\n";
868 out << SP << SP <<
"}\n";
870 out << SP << SP <<
"if (seq == 0) {\n";
871 if (!fNInitial_h.empty()) {
873 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
874 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i + offset] += " <<
OpName
875 <<
"_update_gate[i + offset] * " <<
OpName <<
"_initial_hidden_state[i];\n";
876 out << SP << SP << SP <<
"}\n";
878 out << SP << SP <<
"} else {\n";
881 if (fAttrDirection ==
"backward") {
882 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
885 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
889 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
890 <<
num_directions * batch_size * fAttrHiddenSize <<
" + " << batch_size * fAttrHiddenSize <<
";\n";
892 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
893 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i + offset] += " <<
OpName
894 <<
"_update_gate[i + offset] * " <<
OpName <<
"_hidden_state[i + previous_offset];\n";
895 out << SP << SP << SP <<
"}\n";
896 out << SP << SP <<
"}\n";
902 if (!fNSequence_lens.empty()) {
903 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
904 out << SP << SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
905 out << SP << SP << SP <<
"if (seq >= tensor_" << fNSequence_lens <<
"[batch]) {\n";
907 out << SP << SP << SP << SP << SP <<
"for (size_t h = 0; h < " << fAttrHiddenSize <<
"; h++) {\n";
908 out << SP << SP << SP << SP << SP << SP <<
OpName <<
"_hidden_state[seq * "
910 <<
" + batch * " << fAttrHiddenSize <<
" + h] = 0.;\n";
911 out << SP << SP << SP << SP << SP <<
"}\n";
913 out << SP << SP << SP <<
"}\n";
914 out << SP << SP <<
"}\n";
919 if (fAttrLayout == 0) {
920 if (!fNY_h.empty()) {
922 if (fNSequence_lens.empty()) {
923 size_t yh_size = batch_size * fAttrHiddenSize;
924 if (fAttrDirection ==
"backward") {
925 out << SP <<
"std::copy(" <<
OpName <<
"_hidden_state, " <<
OpName <<
"_hidden_state + "
926 <<
yh_size <<
", tensor_" << fNY_h <<
");\n";
930 <<
"_hidden_state + " <<
offset <<
" + " <<
yh_size <<
", tensor_" << fNY_h <<
");\n";
934 <<
"_hidden_state + " << 2 *
yh_size <<
", tensor_" << fNY_h <<
" + " <<
yh_size <<
");\n";
937 if (fAttrDirection ==
"backward") {
938 out << SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
939 out << SP << SP <<
"size_t offset = batch * " << fAttrHiddenSize <<
";\n";
940 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
941 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + offset);\n";
944 out << SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
945 out << SP << SP <<
"size_t seq = " <<
"tensor_" << fNSequence_lens <<
"[batch] - 1;\n";
946 out << SP << SP <<
"size_t offset = seq * " <<
num_directions * batch_size * fAttrHiddenSize
947 <<
" + batch * " << fAttrHiddenSize <<
";\n";
948 out << SP << SP <<
"size_t yh_offset = batch * " << fAttrHiddenSize <<
";\n";
949 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
950 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
954 out << SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
955 out << SP << SP <<
"size_t offset = " << batch_size * fAttrHiddenSize
956 <<
" + batch * " << fAttrHiddenSize <<
";\n";
957 out << SP << SP <<
"size_t yh_offset = " << batch_size * fAttrHiddenSize
958 <<
" + batch * " << fAttrHiddenSize <<
";\n";
959 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
960 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
969 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
970 out << SP << SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
971 out << SP << SP << SP <<
"size_t offset = seq * " <<
num_directions * batch_size * fAttrHiddenSize
972 <<
" + " <<
direction * batch_size * fAttrHiddenSize <<
" + batch * " << fAttrHiddenSize <<
";\n";
975 out << SP << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
976 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY <<
" + y_offset);\n";
977 out << SP << SP <<
"}\n";
981 if (!fNY_h.empty()) {
983 if (fAttrDirection ==
"backward") {
984 out << SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
985 out << SP << SP <<
"size_t offset = batch * " << fAttrHiddenSize <<
";\n";
986 out << SP << SP <<
"size_t yh_offset = batch * " <<
num_directions * fAttrHiddenSize <<
";\n";
987 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
988 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
991 out << SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
992 if (fNSequence_lens.empty()) {
993 out << SP << SP <<
"size_t seq = " <<
seq_length - 1 <<
";\n";
995 out << SP << SP <<
"size_t seq = " <<
"tensor_" << fNSequence_lens <<
"[batch] - 1;\n";
997 out << SP << SP <<
"size_t offset = seq * " <<
num_directions * batch_size * fAttrHiddenSize
998 <<
" + batch * " << fAttrHiddenSize <<
";\n";
999 out << SP << SP <<
"size_t yh_offset = batch * " <<
num_directions * fAttrHiddenSize <<
";\n";
1000 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1001 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
1005 out << SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1006 out << SP << SP <<
"size_t offset = " << batch_size * fAttrHiddenSize <<
" + batch * "
1007 << fAttrHiddenSize <<
";\n";
1008 out << SP << SP <<
"size_t yh_offset = batch * " <<
num_directions * fAttrHiddenSize <<
" + "
1009 << fAttrHiddenSize <<
";\n";
1010 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1011 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";