#ifndef TMVA_SOFIE_ROPERATOR_GRU
#define TMVA_SOFIE_ROPERATOR_GRU
   if (std::is_same<T, float>::value) {
      fType = "float";
   } else {
      throw std::runtime_error(
         "TMVA SOFIE Encountered unsupported type parsing a GRU operator");
   }
   std::vector<ETensorType> TypeInference(std::vector<ETensorType>) override;
   std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>>) override;
   std::string Generate(std::string) override;
   std::vector<std::string> GetBlasRoutines() override { return { std::string("Gemm"), std::string("Axpy") }; }
   // ... (shape parameters derived from the input shapes)
   if (fAttrLayout == 0) {
      // Timewise layout: Y is {seq_length, num_directions, batch_size, hidden_size},
      // Y_h is {num_directions, batch_size, hidden_size} (ONNX GRU convention).
      std::vector<std::vector<size_t>> ret(
         {{seq_length, num_directions, batch_size, hidden_size},
          {num_directions, batch_size, hidden_size}});
      return ret;
   } else {
      // Batchwise layout: the batch dimension comes first.
      std::vector<std::vector<size_t>> ret(
         {{batch_size, seq_length, num_directions, hidden_size},
          {batch_size, num_directions, hidden_size}});
      return ret;
   }
   if (!model.CheckIfTensorAlreadyExist(fNX)) {
      throw std::runtime_error(
         "TMVA SOFIE GRU Op input tensor " + fNX + " is not found in model.");
   }
   fShapeX = model.GetTensorShape(fNX);
   if (fShapeX.size() != 3) {
      throw std::runtime_error(
         "TMVA SOFIE GRU Op input tensor " + fNX + " is not of 3 dimensions.");
   }
   if (!model.CheckIfTensorAlreadyExist(fNW)) {
      throw std::runtime_error(
         "TMVA SOFIE GRU Op input tensor " + fNW + " is not found in model.");
   }
   fShapeW = model.GetTensorShape(fNW);
   if (fShapeW.size() != 3) {
      throw std::runtime_error(
         "TMVA SOFIE GRU Op input tensor " + fNW + " is not of 3 dimensions.");
   }
   if (!model.CheckIfTensorAlreadyExist(fNR)) {
      throw std::runtime_error(
         "TMVA SOFIE GRU Op input tensor " + fNR + " is not found in model.");
   }
   fShapeR = model.GetTensorShape(fNR);
   if (fShapeR.size() != 3) {
      throw std::runtime_error(
         "TMVA SOFIE GRU Op input tensor " + fNR + " is not of 3 dimensions.");
   }
   if (!fNB.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNB)) {
         throw std::runtime_error(
            "TMVA SOFIE GRU op input tensor " + fNB + " is not found in model.");
      }
      fShapeB = model.GetTensorShape(fNB);
      if (fShapeB.size() != 2 && fShapeB.size() != 4) {
         throw std::runtime_error(
            "TMVA SOFIE GRU op input tensor " + fNB + " is not of 2 or 4 dimensions.");
      }
      if (fShapeB.size() == 2) {
         // Broadcast the bias from {num_directions, 6*hidden_size} so that the
         // generated code can add each of the 6 slices with a single saxpy per gate.
         size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
         size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
         if (fType == "float") {
            // ...
            for (size_t i = 0; i < 6; i++) {
               // ... (copy each bias slice batch_size times)
            }
            // ...
         }
      }
   }
   if (!fNSequence_lens.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNSequence_lens)) {
         throw std::runtime_error(
            "TMVA SOFIE GRU Op input tensor " + fNSequence_lens + " is not found in model.");
      }
      fShapeSequence_lens = model.GetTensorShape(fNSequence_lens);
      if (fShapeSequence_lens.size() != 1) {
         throw std::runtime_error(
            "TMVA SOFIE GRU Op input tensor " + fNSequence_lens + " is not of 1 dimension.");
      }
   }
   if (!fNInitial_h.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNInitial_h)) {
         throw std::runtime_error(
            "TMVA SOFIE GRU Op input tensor " + fNInitial_h + " is not found in model.");
      }
      fShapeInitial_h = model.GetTensorShape(fNInitial_h);
      if (fShapeInitial_h.size() != 3) {
         throw std::runtime_error(
            "TMVA SOFIE GRU Op input tensor " + fNInitial_h + " is not of 3 dimensions.");
      }
   }
   fShapeY = ShapeInference({fShapeX, fShapeW})[0];
   // Register the output tensors in the model.
   if (!fNY.empty() && !model.CheckIfTensorAlreadyExist(fNY)) {
      model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
   }
   if (!fNY_h.empty()) {
      fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
      if (!model.CheckIfTensorAlreadyExist(fNY_h)) {
         model.AddIntermediateTensor(fNY_h, model.GetTensorType(fNX), fShapeY_h);
      }
   }
   // Validate the requested activation functions.
   for (auto &activation : fAttrActivations) {
      // ... (if the name is not one of the activations handled in Generate())
         throw std::runtime_error(
            "TMVA SOFIE - Activation function " + activation + " not implemented");
   }
   if (fAttrDirection == "reverse")
      fAttrDirection = "backward";
   if (fAttrDirection != "forward" && fAttrDirection != "backward" &&
       fAttrDirection != "bidirectional") {
      throw std::runtime_error(
         "TMVA SOFIE - Invalid GRU direction fAttrDirection = " + fAttrDirection);
   }
   if (3 * fAttrHiddenSize != fShapeW[1]) {
      throw std::runtime_error(
         "TMVA SOFIE - fAttrHiddenSize must be equal to " + std::to_string(fShapeW[1] / 3));
   }
   if (fAttrLayout > 1) {
      throw std::runtime_error(
         "TMVA SOFIE - Layout fAttrLayout = " + std::to_string(fAttrLayout) +
         " must be 0 (timewise) or 1 (batchwise)");
   }
   if (fAttrLinearBeforeReset > 1) {
      throw std::runtime_error(
         "TMVA SOFIE - fAttrLinearBeforeReset = " + std::to_string(fAttrLinearBeforeReset) +
         " must be 0 or 1");
   }
   if (fAttrActivations.empty()) {
      if (fAttrDirection == "bidirectional") {
         fAttrActivations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"};
      } else {
         fAttrActivations = {"Sigmoid", "Tanh"};
      }
   }
   std::string opName = "op_gru_" + fNX;
   // Register the intermediate work tensors needed by the generated code.
   size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
   size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
   // ...
   if (fAttrLayout != 0) {
      // ... (layout == 1 additionally needs transposed copies of the input/output)
   }
   if (fAttrLayout != 0 || fNY.empty()) {
      // ... (a separate hidden-state buffer is needed when Y cannot be written directly)
   }
   std::stringstream out;
   // ...
   size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
   size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
   // ...
   auto getVec = [&](std::string const &name) { return "tensor_op_gru_" + fNX + "_" + name; };
   if (fAttrLayout == 0) {
      out << SP << fType << " const* " << OpName << "_input = tensor_" << fNX << ";\n";
   } else {
      // Layout == 1: transpose the input into {seq, batch, input} order.
      out << SP << fType << " * " << OpName << "_input = " << getVec("input") << ";\n";
      out << SP << "for(size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
      out << SP << SP << SP << "for(size_t i = 0; i < " << input_size << "; i++) {\n";
      // ... (emit the element copy)
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";
      out << SP << "}\n";
   }
   if (!fNInitial_h.empty()) {
      if (fAttrLayout == 0) {
         out << SP << fType << " *" << OpName << "_initial_hidden_state = tensor_" << fNInitial_h << ";\n";
      } else {
         out << SP << fType << " * " << OpName << "_initial_hidden_state = " << getVec("initial_hidden_state")
             << ";\n";
         // ...
         out << SP << fType << " " << OpName << "_initial_hidden_state["
             // ...
         // Reorder initial_h into per-direction, per-batch layout.
         out << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
         out << SP << SP << "for(size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
         out // ...
             << " + batch * " << fAttrHiddenSize << " + h] = tensor_" << fNInitial_h << "[batch * "
             // ...
         out << SP << SP << "}\n";
         out << SP << "}\n";
      }
   }
   out << SP << fType << " * " << OpName << "_f_update_gate = " << getVec("f_update_gate") << ";\n";
   out << SP << fType << " * " << OpName << "_f_reset_gate = " << getVec("f_reset_gate") << ";\n";
   out << SP << fType << " * " << OpName << "_f_hidden_gate = " << getVec("f_hidden_gate") << ";\n";
   // ...
   out << SP << fType << " * " << OpName << "_update_gate = " << getVec("update_gate") << ";\n";
   out << SP << fType << " * " << OpName << "_reset_gate = " << getVec("reset_gate") << ";\n";
   out << SP << fType << " * " << OpName << "_hidden_gate = " << getVec("hidden_gate") << ";\n";
   if (fAttrLayout == 0 && !fNY.empty()) {
      out << SP << fType << " *" << OpName << "_hidden_state = tensor_" << fNY << ";\n";
   } else {
      out << SP << fType << " * " << OpName << "_hidden_state = " << getVec("hidden_state") << ";\n";
   }
   // ...
   out << SP << fType << " * " << OpName << "_feedback = " << getVec("feedback") << ";\n";
   // ...
   out << SP << fType << " " << OpName << "_feedback[" << batch_size * fAttrHiddenSize << "] = {0};\n";
   out << SP << "char " << OpName << "_transA = 'N';\n";
   out << SP << "char " << OpName << "_transB = 'T';\n";
   // ...
   out << SP << "int " << OpName << "_n = " << fAttrHiddenSize << ";\n";
   // ...
   if (fType == "float") {
      out << SP << "float " << OpName << "_alpha = 1.;\n";
      out << SP << "float " << OpName << "_beta = 0.;\n";
   }
   // ...
   out << SP << "int " << OpName << "_incx = 1;\n";
   out << SP << "int " << OpName << "_incy = 1;\n";
   out << SP << "int " << OpName << "_feedback_size = " << batch_size * fAttrHiddenSize << ";\n";
   if (fType == "float") {
      // ...
      out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
          // ...
          << "_f_update_gate, &" << OpName << "_n);\n";
      out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
          // ... (f_reset_gate)
      out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
          // ... (f_hidden_gate)
   }
   if (fType == "float") {
      // ... (second set of gate GEMMs, for the second direction)
      out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
          // ...
      out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
          // ...
      out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
          // ...
   }
   if (fType == "float") {
      // Add the bias slices onto the precomputed gates (one saxpy_ per slice).
      out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << ", &"
          // ...
      out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
          // ...
      out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
          // ...
      out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
          // ...
      out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
          // ...
      if (fAttrLinearBeforeReset == 0) {
         // ...
         out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB
             // ...
      }
   }
   if (fType == "float") {
      // ... (same bias additions for the second direction)
      out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
          // ...
      out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
          // ...
      out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
          // ...
      out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
          // ...
      out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
          // ...
      if (fAttrLinearBeforeReset == 0) {
         // ...
         out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB
             // ...
      }
   }
   out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
   out << SP << SP << "size_t offset = seq * " << batch_size * fAttrHiddenSize << ";\n";
   // ...
   out << SP << SP << "std::copy(" << OpName << "_f_update_gate + offset, " << OpName << "_f_update_gate + offset + "
       // ...
   out << SP << SP << "std::copy(" << OpName << "_f_reset_gate + offset, " << OpName << "_f_reset_gate + offset + "
       // ...
   out << SP << SP << "std::copy(" << OpName << "_f_hidden_gate + offset, " << OpName << "_f_hidden_gate + offset + "
       // ...
   out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
   if (fAttrDirection == "backward" || direction == 1) {
      out << SP << SP << "size_t index = " << seq_length - 1 << " - seq;\n";
   } else {
      out << SP << SP << "size_t index = seq;\n";
   }
   out << SP << SP << "if (seq == 0) {\n";
   if (!fNInitial_h.empty()) {
      // Recurrent contribution from the initial hidden state: gate += H_0·R^T
      if (fType == "float") {
         out // ... (sgemm_ for the update gate)
             << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &" << OpName
             << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName << "_alpha, "
             << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
         size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
         out // ... (sgemm_ for the reset gate)
             << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
      }
      // ... (second direction)
      if (fType == "float") {
         size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
         out // ...
             << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
         size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
         out // ...
             << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
      }
   }
   out << SP << SP << "} else {\n";
   if (fAttrDirection == "backward") {
      out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
          // ...
   } else {
      out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
          // ...
   }
   // Recurrent contribution from the previous step: gate += H_{t-1}·R^T
   if (fType == "float") {
      out // ... (sgemm_ for the update gate)
          << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &" << OpName << "_n, "
          // ...
          << "_update_gate + offset, &" << OpName << "_n);\n";
      size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
      out // ... (sgemm_ for the reset gate)
          << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
          // ...
   }
   // ... (second direction)
   out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
       // ...
   if (fType == "float") {
      size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
      out // ...
          << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
          // ...
          << OpName << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
      size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
      out // ...
          << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
          // ...
   }
   out << SP << SP << "}\n";
   if (fAttrClip > .0) {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      if (fType == "float") {
         out << SP << SP << SP << "float z = (" << OpName << "_update_gate[i] > " << -fAttrClip << ") ? " << OpName
             << "_update_gate[i] : " << -fAttrClip << ";\n";
      }
      out << SP << SP << SP << OpName << "_update_gate[i] = (z < " << fAttrClip << ") ? z : " << fAttrClip << ";\n";
      if (fType == "float") {
         out << SP << SP << SP << "float r = (" << OpName << "_reset_gate[i] > " << -fAttrClip << ") ? " << OpName
             << "_reset_gate[i] : " << -fAttrClip << ";\n";
      }
      out << SP << SP << SP << OpName << "_reset_gate[i] = (r < " << fAttrClip << ") ? r : " << fAttrClip << ";\n";
      out << SP << SP << "}\n";
   }
   if (fAttrActivations[direction * 2] == "Relu") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
      out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
      out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
      out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
      out << SP << SP << "}\n";
   } else if (fAttrActivations[direction * 2] == "Tanh") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      if (fType == "float") {
         out << SP << SP << SP << "float z = exp(-2 * " << OpName << "_update_gate[i]);\n";
      }
      out << SP << SP << SP << SP << OpName << "_update_gate[i] = (1. - z) / (1. + z);\n";
      if (fType == "float") {
         out << SP << SP << SP << "float r = exp(-2 * " << OpName << "_reset_gate[i]);\n";
      }
      out << SP << SP << SP << SP << OpName << "_reset_gate[i] = (1. - r) / (1. + r);\n";
      out << SP << SP << "}\n";
   } else if (fAttrActivations[direction * 2] == "Sigmoid") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_update_gate[i] = 1. / (1. + exp(-" << OpName
          << "_update_gate[i]));\n";
      out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 1. / (1. + exp(-" << OpName
          << "_reset_gate[i]));\n";
      out << SP << SP << "}\n";
   } else if (fAttrActivations[direction * 2] == "Affine") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_update_gate[i] = " << fAttrActivationAlpha[direction * 2] << " * "
          << OpName << "_update_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
      out << SP << SP << SP << SP << OpName << "_reset_gate[i] = " << fAttrActivationAlpha[direction * 2] << " * "
          << OpName << "_reset_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
      out << SP << SP << "}\n";
   } else if (fAttrActivations[direction * 2] == "ScaledTanh") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      if (fType == "float") {
         out << SP << SP << SP << "float z = exp(-2 * " << fAttrActivationBeta[direction * 2] << " * " << OpName
             << "_update_gate[i]);\n";
      }
      out << SP << SP << SP << SP << OpName << "_update_gate[i] = " << fAttrActivationAlpha[direction * 2]
          << " * (1. - z) / (1. + z);\n";
      if (fType == "float") {
         out << SP << SP << SP << "float r = exp(-2 * " << fAttrActivationBeta[direction * 2] << " * " << OpName
             << "_reset_gate[i]);\n";
      }
      out << SP << SP << SP << SP << OpName << "_reset_gate[i] = " << fAttrActivationAlpha[direction * 2]
          << " * (1. - r) / (1. + r);\n";
      out << SP << SP << "}\n";
   } else if (fAttrActivations[direction * 2] == "HardSigmoid") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      if (fType == "float") {
         out << SP << SP << SP << "float za = " << fAttrActivationAlpha[direction * 2] << " * " << OpName
             << "_update_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
         out << SP << SP << SP << "float zb = (za > 0.) ? za : 0.;\n";
      }
      out << SP << SP << SP << SP << OpName << "_update_gate[i] = (zb < 1.) ? zb : 1.;\n";
      if (fType == "float") {
         out << SP << SP << SP << "float ra = " << fAttrActivationAlpha[direction * 2] << " * " << OpName
             << "_reset_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
         out << SP << SP << SP << "float rb = (ra > 0.) ? ra : 0.;\n";
      }
      out << SP << SP << SP << SP << OpName << "_reset_gate[i] = (rb < 1.) ? rb : 1.;\n";
      out << SP << SP << "}\n";
   } else if (fAttrActivations[direction * 2] == "LeakyRelu") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
      out << SP << SP << SP << SP << OpName << "_update_gate[i] = " << fAttrActivationAlpha[direction * 2] << " * "
          << OpName << "_update_gate[i];\n";
      out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
      out << SP << SP << SP << SP << OpName << "_reset_gate[i] = " << fAttrActivationAlpha[direction * 2] << " * "
          << OpName << "_reset_gate[i];\n";
      out << SP << SP << "}\n";
   } else if (fAttrActivations[direction * 2] == "ThresholdRelu") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < " << fAttrActivationAlpha[direction * 2]
          << ")\n";
      out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
      out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < " << fAttrActivationAlpha[direction * 2]
          << ")\n";
      out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
      out << SP << SP << "}\n";
   } else if (fAttrActivations[direction * 2] == "Elu") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
      out << SP << SP << SP << SP << OpName << "_update_gate[i] = " << fAttrActivationAlpha[direction * 2]
          << " * (exp(" << OpName << "_update_gate[i]) - 1.);\n";
      out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
      out << SP << SP << SP << SP << OpName << "_reset_gate[i] = " << fAttrActivationAlpha[direction * 2]
          << " * (exp(" << OpName << "_reset_gate[i]) - 1.);\n";
      out << SP << SP << "}\n";
   } else if (fAttrActivations[direction * 2] == "Softsign") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_update_gate[i] = " << OpName << "_update_gate[i] / (1. + abs("
          << OpName << "_update_gate[i]));\n";
      out << SP << SP << SP << SP << OpName << "_reset_gate[i] = " << OpName << "_reset_gate[i] / (1. + abs("
          << OpName << "_reset_gate[i]));\n";
      out << SP << SP << "}\n";
   } else { // Softplus
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_update_gate[i] = log(1. + exp(" << OpName << "_update_gate[i]));\n";
      out << SP << SP << SP << SP << OpName << "_reset_gate[i] = log(1. + exp(" << OpName << "_reset_gate[i]));\n";
      out << SP << SP << "}\n";
   }
   if (fAttrLinearBeforeReset == 0) {
      // feedback = r_t ⊙ H_{t-1}; the product with R_h^T is added to the hidden gate.
      out << SP << SP << "if (seq == 0) {\n";
      if (!fNInitial_h.empty()) {
         out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName << "_reset_gate[i + offset] * "
             << OpName << "_initial_hidden_state[i];\n";
         out << SP << SP << SP << "}\n";
      }
      out << SP << SP << "} else {\n";
      // ...
      if (fAttrDirection == "backward") {
         out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
             // ...
      } else {
         out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
             // ...
      }
      // ... (second direction)
      out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
          // ...
      out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
      out // ... (feedback[i] = reset_gate * previous hidden state)
          << "_hidden_state[i + previous_offset];\n";
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";
      // Offset of R_h inside tensor_R (first vs. second direction):
      size_t rh_offset = (direction == 0)
         ? 2 * fAttrHiddenSize * fAttrHiddenSize
         : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
      out << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
          // ...
   } else {
      // linear_before_reset != 0: feedback = H_{t-1}·R_h^T (+ Rb_h), scaled by r_t afterwards.
      size_t rh_offset = (direction == 0)
         ? 2 * fAttrHiddenSize * fAttrHiddenSize
         : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
      out << SP << SP << "if (seq == 0) {\n";
      if (!fNInitial_h.empty()) {
         out // ... (sgemm_ with the initial hidden state)
             << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
             // ...
      }
      out << SP << SP << "} else {\n";
      // ...
      if (fAttrDirection == "backward") {
         out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
             // ...
      } else {
         out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
             // ...
      }
      // ... (second direction)
      out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
          // ...
      out // ... (sgemm_ with the previous hidden state)
          << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
          // ...
      out << SP << SP << "}\n";
      // Add the recurrence bias Rb_h to the feedback.
      out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName << "_alpha, tensor_" << fNB
          // ...
      // feedback *= r_t
      out << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
      out << SP << SP << SP << OpName << "_feedback[i] *= " << OpName << "_reset_gate[i + offset];\n";
      out << SP << SP << "}\n";
      // hidden_gate += feedback
      out // ... (saxpy_)
          << "_feedback, &" << OpName << "_incx, " << OpName << "_hidden_gate + offset, &" << OpName << "_incy);\n";
   }
   if (fAttrClip > .0) {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      if (fType == "float") {
         out << SP << SP << SP << "float x = (" << OpName << "_hidden_gate[i] > " << -fAttrClip << ") ? " << OpName
             << "_hidden_gate[i] : " << -fAttrClip << ";\n";
      }
      out << SP << SP << SP << OpName << "_hidden_gate[i] = (x < " << fAttrClip << ") ? x : " << fAttrClip << ";\n";
      out << SP << SP << "}\n";
   }
   if (fAttrActivations[direction * 2 + 1] == "Relu") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
      out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
      out << SP << SP << "}\n";
   } else if (fAttrActivations[direction * 2 + 1] == "Tanh") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      if (fType == "float") {
         out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_hidden_gate[i]);\n";
      }
      out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = (1. - ex) / (1. + ex);\n";
      out << SP << SP << "}\n";
   } else if (fAttrActivations[direction * 2 + 1] == "Sigmoid") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 1. / (1. + exp(-" << OpName
          << "_hidden_gate[i]));\n";
      out << SP << SP << "}\n";
   } else if (fAttrActivations[direction * 2 + 1] == "Affine") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = " << fAttrActivationAlpha[direction * 2 + 1]
          << " * " << OpName << "_hidden_gate[i] + " << fAttrActivationBeta[direction * 2 + 1] << ";\n";
      out << SP << SP << "}\n";
   } else if (fAttrActivations[direction * 2 + 1] == "ScaledTanh") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      if (fType == "float") {
         out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 2 + 1] << " * " << OpName
             << "_hidden_gate[i]);\n";
      }
      out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = " << fAttrActivationAlpha[direction * 2 + 1]
          << " * (1. - ex) / (1. + ex);\n";
      out << SP << SP << "}\n";
   } else if (fAttrActivations[direction * 2 + 1] == "HardSigmoid") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      if (fType == "float") {
         out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 2 + 1] << " * " << OpName
             << "_hidden_gate[i] + " << fAttrActivationBeta[direction * 2 + 1] << ";\n";
         out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
      }
      out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = (b < 1.) ? b : 1.;\n";
      out << SP << SP << "}\n";
   } else if (fAttrActivations[direction * 2 + 1] == "LeakyRelu") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
      out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = " << fAttrActivationAlpha[direction * 2 + 1]
          << " * " << OpName << "_hidden_gate[i];\n";
      out << SP << SP << "}\n";
   } else if (fAttrActivations[direction * 2 + 1] == "ThresholdRelu") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < " << fAttrActivationAlpha[direction * 2 + 1]
          << ")\n";
      out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
      out << SP << SP << "}\n";
   } else if (fAttrActivations[direction * 2 + 1] == "Elu") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
      out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = " << fAttrActivationAlpha[direction * 2 + 1]
          << " * (exp(" << OpName << "_hidden_gate[i]) - 1.);\n";
      out << SP << SP << "}\n";
   } else if (fAttrActivations[direction * 2 + 1] == "Softsign") {
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = " << OpName << "_hidden_gate[i] / (1. + abs("
          << OpName << "_hidden_gate[i]));\n";
      out << SP << SP << "}\n";
   } else { // Softplus
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = log(1. + exp(" << OpName << "_hidden_gate[i]));\n";
      out << SP << SP << "}\n";
   }
   // H_t = (1 - z_t) ⊙ h~_t + z_t ⊙ H_{t-1}
   out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
   out << SP << SP << SP << OpName << "_hidden_state[i] = (1. - " << OpName << "_update_gate[i]) * " << OpName
       << "_hidden_gate[i];\n";
   out << SP << SP << "}\n";
   out << SP << SP << "if (seq == 0) {\n";
   if (!fNInitial_h.empty()) {
      out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
          << "_update_gate[i + offset] * " << OpName << "_initial_hidden_state[i];\n";
      out << SP << SP << SP << "}\n";
   }
   out << SP << SP << "} else {\n";
   // ...
   if (fAttrDirection == "backward") {
      out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
          // ...
   } else {
      out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
          // ...
   }
   // ... (second direction)
   out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
       // ...
   out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
   out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
       << "_update_gate[i + offset] * " << OpName << "_hidden_state[i + previous_offset];\n";
   out << SP << SP << SP << "}\n";
   out << SP << SP << "}\n";
   if (!fNSequence_lens.empty()) {
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
      out << SP << SP << SP << "if (seq >= tensor_" << fNSequence_lens << "[batch]) {\n";
      // ...
      out << SP << SP << SP << SP << SP << "for (size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
      out // ...
          << " + batch * " << fAttrHiddenSize << " + h] = 0.;\n";
      // ...
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";
      out << SP << "}\n";
   }
   if (fAttrLayout == 0) {
      if (!fNY_h.empty()) {
         // Copy the last hidden state of each direction into tensor_Y_h.
         if (fNSequence_lens.empty()) {
            // Whole-block copies; the backward pass stores its final state at seq 0.
            if (fAttrDirection == "backward") {
               out // ... (std::copy of the first timestep)
                   << ", tensor_" << fNY_h << ");\n";
            } else {
               out // ... (std::copy of the last timestep)
                   << "_hidden_state + " << offset << " + " << yh_size << ", tensor_" << fNY_h << ");\n";
            }
            // ... (bidirectional: second direction)
            out // ...
                << "_hidden_state + " << 2 * yh_size << ", tensor_" << fNY_h << " + " << yh_size << ");\n";
         } else if (fAttrDirection == "backward") {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + offset);\n";
            out << SP << "}\n";
         } else {
            // With sequence_lens, take each batch entry's state at its last valid step.
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
            out // ... (offset = seq * stride)
                << " + batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         }
         // ... (second direction of a bidirectional GRU)
         out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
         out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
             // ...
         out << SP << SP << "size_t yh_offset = " << batch_size * fAttrHiddenSize << " + batch * "
             << fAttrHiddenSize << ";\n";
         out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
             << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
         out << SP << "}\n";
      }
   } else {
      // Layout == 1: transpose the results back into tensor_Y / tensor_Y_h.
      // ...
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
      // ... (compute offset into hidden_state and y_offset into tensor_Y)
      out << SP << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
          << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY << " + y_offset);\n";
      out << SP << SP << "}\n";
      out << SP << "}\n";
      if (!fNY_h.empty()) {
         // ...
         if (fAttrDirection == "backward") {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         } else {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            if (fNSequence_lens.empty()) {
               // ... (seq = last timestep)
            } else {
               out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
            }
            out // ... (offset = seq * stride)
                << " + batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         }
         // ... (second direction of a bidirectional GRU)
         out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
         out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
             // ...
         out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << " + "
             << fAttrHiddenSize << ";\n";
         out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
             << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
         out << SP << "}\n";
      }
   }
   // ...
   return out.str();
}
// Reference summary of ROperator_GRU<T> (Gated Recurrent Unit operator),
// collected from the class documentation:
//
//    ROperator_GRU();                       Default constructor of ROperator_GRU
//    ROperator_GRU(std::vector<float> activation_alpha, std::vector<float> activation_beta,
//                  std::vector<std::string> activations, float clip, std::string direction,
//                  size_t hidden_size, size_t layout, size_t linear_before_reset,
//                  std::string nameX, std::string nameW, std::string nameR, std::string nameB,
//                  std::string nameSequence_lens, std::string nameInitial_h,
//                  std::string nameY, std::string nameY_h);
//                                           Constructor of ROperator_GRU from the attributes
//    std::vector<ETensorType> TypeInference(std::vector<ETensorType>) override;
//                                           Infers the type of the output tensors
//    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>>) override;
//                                           Infers the shape of the output tensors
//    void Initialize(RModel &) override;    Initialize the model
//    std::string Generate(std::string) override;
//                                           Generate the inference code
//    std::vector<std::string> GetBlasRoutines() override;
//                                           Returns the BLAS routines needed to compile the generated code
//
//    std::vector<float> fAttrActivationAlpha;   Scaling values used by some activation functions
//    std::vector<float> fAttrActivationBeta;    Scaling values used by some activation functions
//    std::vector<std::string> fAttrActivations; Activation functions
//    float fAttrClip;                           Clip threshold
//    std::string fAttrDirection;                Direction of processing
//    size_t fAttrHiddenSize;                    Number of the hidden layers
//    size_t fAttrLayout;                        Data layout
//    size_t fAttrLinearBeforeReset;             Linear layer before the reset gate
//    std::string fNX;                           Name of the input
//    std::string fNW;                           Name of the weights
//    std::string fNR;                           Name of the recurrence
//    std::string fNB;                           Name of the bias
//    std::string fNSequence_lens;               Name of the length of the sequences
//    std::string fNInitial_h;                   Name of the initial value of the hidden states
//    std::string fNY;                           Name of the output
//    std::string fNY_h;                         Name of the last sequence of the output
//    std::string fType;                         Type of the tensors
//    std::vector<size_t> fShapeX;               Shape of the input
//    std::vector<size_t> fShapeW;               Shape of the weights
//    std::vector<size_t> fShapeR;               Shape of the recurrence
//    std::vector<size_t> fShapeB;               Shape of the bias
//    std::vector<size_t> fShapeSequence_lens;   Shape of the length of the sequences
//    std::vector<size_t> fShapeInitial_h;       Shape of the initial value of hidden states
//    std::vector<size_t> fShapeY;               Shape of the output
//    std::vector<size_t> fShapeY_h;             Shape of the last sequence of the output

#endif // TMVA_SOFIE_ROPERATOR_GRU