160 size_t num_directions = input[1][0];
161 size_t hidden_size = input[1][1] / 3;
163 size_t seq_length = input[0][0];
164 size_t batch_size = input[0][1];
165 std::vector<std::vector<size_t>>
ret(
166 {{seq_length, num_directions, batch_size, hidden_size}, {num_directions, batch_size, hidden_size}});
169 size_t batch_size = input[0][0];
170 size_t seq_length = input[0][1];
171 std::vector<std::vector<size_t>>
ret(
172 {{batch_size, seq_length, num_directions, hidden_size}, {batch_size, num_directions, hidden_size}});
183 if (!model.CheckIfTensorAlreadyExist(
fNX)) {
184 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
fNX +
" is not found in model.");
188 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
fNX +
" is not of 3 dimensions.");
190 if (!model.CheckIfTensorAlreadyExist(
fNW)) {
191 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
fNW +
" is not found in model.");
195 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
fNW +
" is not of 3 dimensions.");
197 if (!model.CheckIfTensorAlreadyExist(
fNR)) {
198 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
fNR +
" is not found in model.");
202 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
fNR +
" is not of 3 dimensions.");
205 if (!model.CheckIfTensorAlreadyExist(
fNB)) {
206 throw std::runtime_error(
"TMVA SOFIE GRU op input tensor " +
fNB +
" is not found in model.");
210 throw std::runtime_error(
"TMVA SOFIE GRU op input tensor " +
fNB +
" is not of 2 or 4 dimensions.");
214 auto original_data = model.GetInitializedTensorData(
fNB);
215 size_t num_directions =
fShapeW[0];
218 if (
fType ==
"float") {
219 float *original_bias =
static_cast<float *
>(original_data.get());
220 float *new_bias =
new float[num_directions * 6 * seq_length * batch_size *
fAttrHiddenSize];
221 for (
size_t direction = 0; direction < num_directions; direction++) {
222 for (
size_t i = 0; i < 6; i++) {
223 for (
size_t seq = 0; seq < seq_length; seq++) {
224 for (
size_t batch = 0; batch < batch_size; batch++) {
226 size_t offset = direction * 6 * batch_size * seq_length *
fAttrHiddenSize +
229 std::copy(original_bias + bias_offset, original_bias + bias_offset +
fAttrHiddenSize,
236 std::vector<size_t> new_bias_shape = {num_directions, 6, seq_length, batch_size,
fAttrHiddenSize};
237 std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<
float[]>());
238 model.UpdateInitializedTensor(
fNB, model.GetTensorType(
fNB), new_bias_shape, new_bias_ptr);
245 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
fNSequence_lens +
"is not found in model.");
249 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
fNSequence_lens +
" is not of 1 dimension.");
253 if (!model.CheckIfTensorAlreadyExist(
fNInitial_h)) {
254 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
fNInitial_h +
" is not found in model.");
258 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
fNInitial_h +
" is not of 3 dimensions.");
263 if (!model.CheckIfTensorAlreadyExist(
fNY)) {
264 model.AddIntermediateTensor(
fNY, model.GetTensorType(
fNX),
fShapeY);
267 if (!
fNY_h.empty()) {
269 if (!model.CheckIfTensorAlreadyExist(
fNY_h)) {
275 if (activation !=
"Relu" && activation !=
"Tanh" && activation !=
"Sigmoid" && activation !=
"Affine" &&
276 activation !=
"LeakyRelu" && activation !=
"ThresholdRelu" && activation !=
"ScaledTanh" &&
277 activation !=
"HardSigmoid" && activation !=
"Elu" && activation !=
"Softsign" && activation !=
"Softplus") {
278 throw std::runtime_error(
"TMVA SOFIE - Activation function " + activation +
" not implemented");
285 throw std::runtime_error(
"TMVA SOFIE - Invalid GRU direction fAttrDirection = " +
fAttrDirection);
288 throw std::runtime_error(
"TMVA SOFIE - fAttrHiddenSize must be equal to " + std::to_string(
fShapeW[1] / 3));
291 throw std::runtime_error(
"TMVA SOFIE - Layout fAttrLayout = " + std::to_string(
fAttrLayout) +
292 " must be 0 (timewise) or 1 (batchwise)");
309 std::string opName =
"op_gru_" +
fNX;
311 size_t num_directions =
fShapeW[0];
314 size_t input_size =
fShapeX[2];
316 auto declareVector = [&](std::string
const &
name, std::size_t
n) {
317 std::string fullName = opName +
"_" +
name;
322 declareVector(
"input", seq_length * batch_size * input_size);
323 declareVector(
"initial_hidden_state", num_directions * batch_size *
fAttrHiddenSize);
324 declareVector(
"initial_cell_state", num_directions * batch_size *
fAttrHiddenSize);
328 declareVector(
"f_update_gate", ff_size);
329 declareVector(
"f_reset_gate", ff_size);
330 declareVector(
"f_hidden_gate", ff_size);
332 size_t hs_size = seq_length * num_directions * batch_size *
fAttrHiddenSize;
333 declareVector(
"update_gate", hs_size);
334 declareVector(
"reset_gate", hs_size);
335 declareVector(
"hidden_gate", hs_size);
342 declareVector(
"hidden_state", hs_size);
349 OpName =
"op_" + OpName;
350 std::stringstream out;
354 size_t input_size =
fShapeX[2];
355 size_t num_directions =
fShapeW[0];
357 auto getVec = [&](std::string
const &
name) {
return "tensor_op_gru_" +
fNX +
"_" +
name; };
361 out <<
SP <<
fType <<
" const* " << OpName <<
"_input = tensor_" <<
fNX <<
";\n";
364 out <<
SP <<
fType <<
" * " << OpName <<
"_input = " << getVec(
"input") <<
";\n";
366 out <<
SP <<
fType <<
" " << OpName <<
"_input[" << seq_length * batch_size * input_size <<
"];\n";
368 out <<
SP <<
"for(size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
369 out <<
SP <<
SP <<
"for(size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
370 out <<
SP <<
SP <<
SP <<
"for(size_t i = 0; i < " << input_size <<
"; i++) {\n";
371 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input[seq * " << batch_size * input_size <<
" + batch * " << input_size
372 <<
" + i] = " <<
"tensor_" <<
fNX <<
"[batch * " << seq_length * input_size <<
" + seq * " << input_size
374 out <<
SP <<
SP <<
SP <<
"}\n";
375 out <<
SP <<
SP <<
"}\n";
382 out <<
SP <<
fType <<
" *" << OpName <<
"_initial_hidden_state = " <<
" tensor_" <<
fNInitial_h <<
";\n";
385 out <<
SP <<
fType <<
" * " << OpName <<
"_initial_hidden_state = " << getVec(
"initial_hidden_state")
388 out <<
SP <<
fType <<
" " << OpName <<
"_initial_hidden_state["
391 for (
size_t direction = 0; direction < num_directions; direction++) {
392 out <<
SP <<
"for(size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
394 out <<
SP <<
SP <<
SP << OpName <<
"_initial_hidden_state[" << direction * batch_size *
fAttrHiddenSize
397 out <<
SP <<
SP <<
"}\n";
406 out <<
SP <<
fType <<
" * " << OpName <<
"_f_update_gate = " << getVec(
"f_update_gate") <<
";\n";
407 out <<
SP <<
fType <<
" * " << OpName <<
"_f_reset_gate = " << getVec(
"f_reset_gate") <<
";\n";
408 out <<
SP <<
fType <<
" * " << OpName <<
"_f_hidden_gate = " << getVec(
"f_hidden_gate") <<
";\n";
410 out <<
SP <<
fType <<
" " << OpName <<
"_f_update_gate[" << feedforward_size <<
"] = {0};\n";
411 out <<
SP <<
fType <<
" " << OpName <<
"_f_reset_gate[" << feedforward_size <<
"] = {0};\n";
412 out <<
SP <<
fType <<
" " << OpName <<
"_f_hidden_gate[" << feedforward_size <<
"] = {0};\n";
415 size_t hidden_state_size = seq_length * num_directions * batch_size *
fAttrHiddenSize;
417 out <<
SP <<
fType <<
" * " << OpName <<
"_update_gate = " << getVec(
"update_gate") <<
";\n";
418 out <<
SP <<
fType <<
" * " << OpName <<
"_reset_gate = " << getVec(
"reset_gate") <<
";\n";
419 out <<
SP <<
fType <<
" * " << OpName <<
"_hidden_gate = " << getVec(
"hidden_gate") <<
";\n";
421 out <<
SP <<
fType <<
" " << OpName <<
"_update_gate[" << hidden_state_size <<
"] = {0};\n";
422 out <<
SP <<
fType <<
" " << OpName <<
"_reset_gate[" << hidden_state_size <<
"] = {0};\n";
423 out <<
SP <<
fType <<
" " << OpName <<
"_hidden_gate[" << hidden_state_size <<
"] = {0};\n";
427 out <<
SP <<
fType <<
" *" << OpName <<
"_hidden_state = tensor_" <<
fNY <<
";\n";
430 out <<
SP <<
fType <<
" * " << OpName <<
"_hidden_state = " << getVec(
"hidden_state") <<
";\n";
432 out <<
SP <<
fType <<
" " << OpName <<
"_hidden_state[" << hidden_state_size <<
"] = {0};\n";
437 out <<
SP <<
fType <<
" * " << OpName <<
"_feedback = " << getVec(
"feedback") <<
";\n";
442 out <<
SP <<
"char " << OpName <<
"_transA = 'N';\n";
443 out <<
SP <<
"char " << OpName <<
"_transB = 'T';\n";
444 out <<
SP <<
"int " << OpName <<
"_m = " << seq_length * batch_size <<
";\n";
445 out <<
SP <<
"int " << OpName <<
"_m2 = " << batch_size <<
";\n";
447 out <<
SP <<
"int " << OpName <<
"_k = " << input_size <<
";\n";
448 if (
fType ==
"float") {
449 out <<
SP <<
"float " << OpName <<
"_alpha = 1.;\n";
450 out <<
SP <<
"float " << OpName <<
"_beta = 0.;\n";
453 out <<
SP <<
"int " << OpName <<
"_bias_size = " << seq_length * batch_size *
fAttrHiddenSize <<
";\n";
455 out <<
SP <<
"int " << OpName <<
"_incx = 1;\n";
456 out <<
SP <<
"int " << OpName <<
"_incy = 1;\n";
457 out <<
SP <<
"int " << OpName <<
"_feedback_size = " << batch_size *
fAttrHiddenSize <<
";\n";
459 for (
size_t direction = 0; direction < num_directions; direction++) {
460 if (direction == 0) {
461 if (
fType ==
"float") {
463 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName <<
"_n, &"
464 << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_" <<
fNW <<
", &" << OpName
465 <<
"_k, " << OpName <<
"_input, &" << OpName <<
"_k, &" << OpName <<
"_beta, " << OpName
466 <<
"_f_update_gate, &" << OpName <<
"_n);\n";
469 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName <<
"_n, &"
470 << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_" <<
fNW <<
" + " << wr_offset
471 <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &" << OpName <<
"_k, &" << OpName <<
"_beta, "
472 << OpName <<
"_f_reset_gate, &" << OpName <<
"_n);\n";
475 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName <<
"_n, &"
476 << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_" <<
fNW <<
" + " << wh_offset
477 <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &" << OpName <<
"_k, &" << OpName <<
"_beta, "
478 << OpName <<
"_f_hidden_gate, &" << OpName <<
"_n);\n";
481 if (
fType ==
"float") {
484 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName <<
"_n, &"
485 << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_" <<
fNW <<
" + " << wz_offset
486 <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &" << OpName <<
"_k, &" << OpName <<
"_beta, "
487 << OpName <<
"_f_update_gate, &" << OpName <<
"_n);\n";
490 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName <<
"_n, &"
491 << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_" <<
fNW <<
" + " << wr_offset
492 <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &" << OpName <<
"_k, &" << OpName <<
"_beta, "
493 << OpName <<
"_f_reset_gate, &" << OpName <<
"_n);\n";
496 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName <<
"_n, &"
497 << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_" <<
fNW <<
" + " << wh_offset
498 <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &" << OpName <<
"_k, &" << OpName <<
"_beta, "
499 << OpName <<
"_f_hidden_gate, &" << OpName <<
"_n);\n";
504 if (direction == 0) {
505 if (
fType ==
"float") {
507 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB <<
", &"
508 << OpName <<
"_incx, " << OpName <<
"_f_update_gate, &" << OpName <<
"_incy);\n";
511 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB <<
" + "
512 << rbz_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_update_gate, &" << OpName
516 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB <<
" + "
517 << wbr_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_reset_gate, &" << OpName
522 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB <<
" + "
523 << rbr_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_reset_gate, &" << OpName
527 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB <<
" + "
528 << wbh_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_hidden_gate, &" << OpName
533 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB
534 <<
" + " << rbh_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_hidden_gate, &" << OpName
539 if (
fType ==
"float") {
542 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB <<
" + "
543 << wbz_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_update_gate, &" << OpName
548 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB <<
" + "
549 << rbz_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_update_gate, &" << OpName
553 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB <<
" + "
554 << wbr_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_reset_gate, &" << OpName
558 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB <<
" + "
559 << rbr_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_reset_gate, &" << OpName
563 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB <<
" + "
564 << wbh_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_hidden_gate, &" << OpName
569 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_" <<
fNB
570 <<
" + " << rbh_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_hidden_gate, &" << OpName
578 out <<
SP <<
"for (size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
580 if (direction == 0) {
581 out <<
SP <<
SP <<
"size_t gate_offset = seq * " << num_directions * batch_size *
fAttrHiddenSize <<
";\n";
583 out <<
SP <<
SP <<
"size_t gate_offset = seq * " << num_directions * batch_size *
fAttrHiddenSize <<
" + "
587 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_f_update_gate + offset, " << OpName <<
"_f_update_gate + offset + "
588 << f_seq_size <<
", " << OpName <<
"_update_gate + gate_offset);\n";
589 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_f_reset_gate + offset, " << OpName <<
"_f_reset_gate + offset + "
590 << f_seq_size <<
", " << OpName <<
"_reset_gate + gate_offset);\n";
591 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_f_hidden_gate + offset, " << OpName <<
"_f_hidden_gate + offset + "
592 << f_seq_size <<
", " << OpName <<
"_hidden_gate + gate_offset);\n";
595 out <<
SP <<
"for (size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
597 out <<
SP <<
SP <<
"size_t index = " << seq_length - 1 <<
" - seq;\n";
599 out <<
SP <<
SP <<
"size_t index = seq;\n";
601 out <<
SP <<
SP <<
"int m2 = " << batch_size <<
";\n";
602 if (direction == 0) {
603 out <<
SP <<
SP <<
"size_t offset = index * " << num_directions * batch_size *
fAttrHiddenSize <<
";\n";
605 out <<
SP <<
SP <<
"size_t offset = index * " << num_directions * batch_size *
fAttrHiddenSize <<
" + "
610 out <<
SP <<
SP <<
"if (seq == 0) {\n";
612 if (direction == 0) {
613 if (
fType ==
"float") {
614 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
615 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
", &" << OpName
616 <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName <<
"_n, &" << OpName <<
"_alpha, "
617 << OpName <<
"_update_gate + offset, &" << OpName <<
"_n);\n";
619 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
620 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << rr_offset
621 <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName <<
"_n, &" << OpName
622 <<
"_alpha, " << OpName <<
"_reset_gate + offset, &" << OpName <<
"_n);\n";
625 if (
fType ==
"float") {
627 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
628 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << rz_offset
629 <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName <<
"_n, &" << OpName
630 <<
"_alpha, " << OpName <<
"_update_gate + offset, &" << OpName <<
"_n);\n";
632 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
633 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << rr_offset
634 <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName <<
"_n, &" << OpName
635 <<
"_alpha, " << OpName <<
"_reset_gate + offset, &" << OpName <<
"_n);\n";
639 out <<
SP <<
SP <<
"} else {\n";
641 if (direction == 0) {
643 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
646 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (seq - 1) * "
649 if (
fType ==
"float") {
650 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
651 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
", &" << OpName <<
"_n, "
652 << OpName <<
"_hidden_state + previous_offset, &" << OpName <<
"_n, &" << OpName <<
"_alpha, " << OpName
653 <<
"_update_gate + offset, &" << OpName <<
"_n);\n";
655 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
656 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << rr_offset
657 <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &" << OpName <<
"_n, &"
658 << OpName <<
"_alpha, " << OpName <<
"_reset_gate + offset, &" << OpName <<
"_n);\n";
661 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
663 if (
fType ==
"float") {
665 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
666 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << rz_offset
667 <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &" << OpName <<
"_n, &"
668 << OpName <<
"_alpha, " << OpName <<
"_update_gate + offset, &" << OpName <<
"_n);\n";
670 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
671 <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << rr_offset
672 <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &" << OpName <<
"_n, &"
673 << OpName <<
"_alpha, " << OpName <<
"_reset_gate + offset, &" << OpName <<
"_n);\n";
676 out <<
SP <<
SP <<
"}\n";
680 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
681 if (
fType ==
"float") {
682 out <<
SP <<
SP <<
SP <<
"float z = (" << OpName <<
"_update_gate[i] > " << -
fAttrClip <<
") ? " << OpName
683 <<
"_update_gate[i] : " << -
fAttrClip <<
";\n";
686 if (
fType ==
"float") {
687 out <<
SP <<
SP <<
SP <<
"float r = (" << OpName <<
"_reset_gate[i] > " << -
fAttrClip <<
") ? " << OpName
688 <<
"_reset_gate[i] : " << -
fAttrClip <<
";\n";
691 out <<
SP <<
SP <<
"}\n";
696 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
697 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_update_gate[i] < 0.)\n";
698 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = 0.;\n";
699 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_reset_gate[i] < 0.)\n";
700 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = 0.;\n";
701 out <<
SP <<
SP <<
"}\n";
703 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
704 if (
fType ==
"float") {
705 out <<
SP <<
SP <<
SP <<
"float z = exp(-2 * " << OpName <<
"_update_gate[i]);\n";
707 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = (1. - z) / (1. + z);\n";
708 if (
fType ==
"float") {
709 out <<
SP <<
SP <<
SP <<
"float r = exp(-2 * " << OpName <<
"_reset_gate[i]);\n";
711 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = (1. - r) / (1. + r);\n";
712 out <<
SP <<
SP <<
"}\n";
714 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
715 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = 1. / (1. + exp(-" << OpName
716 <<
"_update_gate[i]));\n";
717 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = 1. / (1. + exp(-" << OpName
718 <<
"_reset_gate[i]));\n";
719 out <<
SP <<
SP <<
"}\n";
721 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
726 out <<
SP <<
SP <<
"}\n";
728 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
729 if (
fType ==
"float") {
731 <<
"_update_gate[i]);\n";
734 <<
" * (1. - z) / (1. + z);\n";
735 if (
fType ==
"float") {
737 <<
"_reset_gate[i]);\n";
740 <<
" * (1. - r) / (1. + r);\n";
741 out <<
SP <<
SP <<
"}\n";
743 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
744 if (
fType ==
"float") {
747 out <<
SP <<
SP <<
SP <<
"float zb = (za > 0.) ? za : 0.;\n";
749 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = (zb < 1.) ? zb : 1.;\n";
750 if (
fType ==
"float") {
753 out <<
SP <<
SP <<
SP <<
"float rb = (ra > 0.) ? ra : 0.;\n";
755 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = (rb < 1.) ? rb : 1.;\n";
756 out <<
SP <<
SP <<
"}\n";
758 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
759 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_update_gate[i] < 0.)\n";
761 << OpName <<
"_update_gate[i];\n";
762 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_reset_gate[i] < 0.)\n";
764 << OpName <<
"_reset_gate[i];\n";
765 out <<
SP <<
SP <<
"}\n";
767 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
770 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = 0.;\n";
773 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = 0.;\n";
774 out <<
SP <<
SP <<
"}";
776 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
777 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_update_gate[i] < 0.)\n";
779 <<
" * exp(" << OpName <<
"_update_gate[i] - 1.);\n";
780 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_reset_gate[i] < 0.)\n";
782 <<
" * exp(" << OpName <<
"_reset_gate[i] - 1.);\n";
783 out <<
SP <<
SP <<
"}\n";
785 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
786 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = " << OpName <<
"_update_gate[i] / (1. + abs("
787 << OpName <<
"_update_gate[i]));\n";
788 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = " << OpName <<
"_reset_gate[i] / (1. + abs("
789 << OpName <<
"_reset_gate[i]));\n";
790 out <<
SP <<
SP <<
"}\n";
792 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
793 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = log(1. + exp(" << OpName <<
"_update_gate[i]));\n";
794 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = log(1. + exp(" << OpName <<
"_reset_gate[i]));\n";
795 out <<
SP <<
SP <<
"}\n";
799 out <<
SP <<
SP <<
"if (seq == 0) {\n";
802 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
803 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_feedback[i] = " << OpName <<
"_reset_gate[i + offset] * "
804 << OpName <<
"_initial_hidden_state[i];\n";
805 out <<
SP <<
SP <<
SP <<
"}\n";
807 out <<
SP <<
SP <<
"} else {\n";
809 if (direction == 0) {
811 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
814 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (seq - 1) * "
818 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
821 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
822 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_feedback[i] = " << OpName <<
"_reset_gate[i + offset] * " << OpName
823 <<
"_hidden_state[i + previous_offset];\n";
824 out <<
SP <<
SP <<
SP <<
"}\n";
825 out <<
SP <<
SP <<
"}\n";
827 size_t rh_offset = (direction == 0)
830 out <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName <<
"_n, &"
831 << OpName <<
"_m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + " << rh_offset
832 <<
", &" << OpName <<
"_n, " << OpName <<
"_feedback, &" << OpName <<
"_n, &" << OpName <<
"_beta, "
833 << OpName <<
"_feedback, &" << OpName <<
"_n);\n";
837 size_t rh_offset = (direction == 0)
840 out <<
SP <<
SP <<
"if (seq == 0) {\n";
843 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
844 <<
"_n, &" << OpName <<
"_m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
845 << rh_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName <<
"_n, &"
846 << OpName <<
"_beta, " << OpName <<
"_feedback, &" << OpName <<
"_n);\n";
848 out <<
SP <<
SP <<
"} else {\n";
850 if (direction == 0) {
852 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
855 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (seq - 1) * "
859 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
862 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
863 <<
"_n, &" << OpName <<
"_m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
864 << rh_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &" << OpName
865 <<
"_n, &" << OpName <<
"_beta, " << OpName <<
"_feedback, &" << OpName <<
"_n);\n";
867 out <<
SP <<
SP <<
"}\n";
870 size_t rbh_offset = (direction == 0) ? 5 * batch_size * seq_length *
fAttrHiddenSize
872 out <<
SP <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_feedback_size, &" << OpName <<
"_alpha, tensor_" <<
fNB
873 <<
" + " << rbh_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_feedback, &" << OpName
877 out <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
878 out <<
SP <<
SP <<
SP << OpName <<
"_feedback[i] *= " << OpName <<
"_reset_gate[i + offset];\n";
879 out <<
SP <<
SP <<
"}\n";
883 out <<
SP <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_feedback_size, &" << OpName <<
"_alpha, " << OpName
884 <<
"_feedback, &" << OpName <<
"_incx, " << OpName <<
"_hidden_gate + offset, &" << OpName <<
"_incy);\n";
888 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
889 if (
fType ==
"float") {
890 out <<
SP <<
SP <<
SP <<
"float x = (" << OpName <<
"_hidden_gate[i] > " << -
fAttrClip <<
") ? " << OpName
891 <<
"_hidden_gate[i] : " << -
fAttrClip <<
";\n";
894 out <<
SP <<
SP <<
"}\n";
899 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
900 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_hidden_gate[i] < 0.)\n";
901 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_gate[i] = 0.;\n";
902 out <<
SP <<
SP <<
"}\n";
904 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
905 if (
fType ==
"float") {
906 out <<
SP <<
SP <<
SP <<
"float ex = exp(-2 * " << OpName <<
"_hidden_gate[i]);\n";
908 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_gate[i] = (1. - ex) / (1. + ex);\n";
909 out <<
SP <<
SP <<
"}\n";
911 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
912 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_gate[i] = 1. / (1. + exp(-" << OpName
913 <<
"_hidden_gate[i]));\n";
914 out <<
SP <<
SP <<
"}\n";
916 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
918 <<
" * " << OpName <<
"_hidden_gate[i] + " <<
fAttrActivationBeta[direction * 2 + 1] <<
";\n";
919 out <<
SP <<
SP <<
"}\n";
921 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
922 if (
fType ==
"float") {
924 <<
"_hidden_gate[i]);\n";
927 <<
" * (1. - ex) / (1. + ex);\n";
928 out <<
SP <<
SP <<
"}\n";
930 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
931 if (
fType ==
"float") {
934 out <<
SP <<
SP <<
SP <<
"float b = (a > 0.) ? a : 0.;\n";
936 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_gate[i] = (b < 1.) ? b : 1.;\n";
937 out <<
SP <<
SP <<
"}\n";
939 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
940 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_hidden_gate[i] < 0.)\n";
942 <<
" * " << OpName <<
"_hidden_gate[i];\n";
943 out <<
SP <<
SP <<
"}\n";
945 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
948 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_gate[i] = 0.;\n";
949 out <<
SP <<
SP <<
"}";
951 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
952 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_hidden_gate[i] < 0.)\n";
954 <<
" * exp(" << OpName <<
"_hidden_gate[i] - 1.);\n";
955 out <<
SP <<
SP <<
"}\n";
957 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
958 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_gate[i] = " << OpName <<
"_hidden_gate[i] / (1. + abs("
959 << OpName <<
"_hidden_gate[i]));\n";
960 out <<
SP <<
SP <<
"}\n";
962 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
963 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_gate[i] = log(1. + exp(" << OpName <<
"_hidden_gate[i]));\n";
964 out <<
SP <<
SP <<
"}\n";
968 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
969 out <<
SP <<
SP <<
SP << OpName <<
"_hidden_state[i] = ( 1. - " << OpName <<
"_update_gate[i]) * " << OpName
970 <<
"_hidden_gate[i];\n";
971 out <<
SP <<
SP <<
"}\n";
973 out <<
SP <<
SP <<
"if (seq == 0) {\n";
976 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
977 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_state[i + offset] += " << OpName
978 <<
"_update_gate[i + offset] * " << OpName <<
"_initial_hidden_state[i];\n";
979 out <<
SP <<
SP <<
SP <<
"}\n";
981 out <<
SP <<
SP <<
"} else {\n";
983 if (direction == 0) {
985 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
988 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (seq - 1) * "
992 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
995 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
996 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_state[i + offset] += " << OpName
997 <<
"_update_gate[i + offset] * " << OpName <<
"_hidden_state[i + previous_offset];\n";
998 out <<
SP <<
SP <<
SP <<
"}\n";
999 out <<
SP <<
SP <<
"}\n";
1006 out <<
SP <<
"for (size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
1007 out <<
SP <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1009 for (
size_t direction = 0; direction < num_directions; direction++) {
1011 out <<
SP <<
SP <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_state[seq * "
1016 out <<
SP <<
SP <<
SP <<
"}\n";
1017 out <<
SP <<
SP <<
"}\n";
1023 if (!
fNY_h.empty()) {
1028 out <<
SP <<
"std::copy(" << OpName <<
"_hidden_state, " << OpName <<
"_hidden_state + " << yh_size
1029 <<
", tensor_" <<
fNY_h <<
");\n";
1031 size_t offset = (seq_length - 1) * num_directions * batch_size *
fAttrHiddenSize;
1032 out <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + " << offset <<
", " << OpName
1033 <<
"_hidden_state + " << offset <<
" + " << yh_size <<
", tensor_" <<
fNY_h <<
");\n";
1035 if (num_directions == 2) {
1036 out <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + " << yh_size <<
", " << OpName
1037 <<
"_hidden_state + " << 2 * yh_size <<
", tensor_" <<
fNY_h <<
" + " << yh_size <<
");\n";
1041 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1043 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1047 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1049 out <<
SP <<
SP <<
"size_t offset = seq * " << num_directions * batch_size *
fAttrHiddenSize
1052 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1053 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + yh_offset);\n";
1056 if (num_directions == 2) {
1057 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1062 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1063 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + yh_offset);\n";
1071 for (
size_t direction = 0; direction < num_directions; direction++) {
1072 out <<
SP <<
"for (size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
1073 out <<
SP <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1074 out <<
SP <<
SP <<
SP <<
"size_t offset = seq * " << num_directions * batch_size *
fAttrHiddenSize <<
" + "
1076 out <<
SP <<
SP <<
SP <<
"size_t y_offset = batch * " << seq_length * num_directions *
fAttrHiddenSize
1078 out <<
SP <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1079 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY <<
" + y_offset);\n";
1080 out <<
SP <<
SP <<
"}\n";
1084 if (!
fNY_h.empty()) {
1087 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1089 out <<
SP <<
SP <<
"size_t yh_offset = batch * " << num_directions *
fAttrHiddenSize <<
";\n";
1090 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1091 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + yh_offset);\n";
1094 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1096 out <<
SP <<
SP <<
"size_t seq = " << seq_length - 1 <<
";\n";
1100 out <<
SP <<
SP <<
"size_t offset = seq * " << num_directions * batch_size *
fAttrHiddenSize
1102 out <<
SP <<
SP <<
"size_t yh_offset = batch * " << num_directions *
fAttrHiddenSize <<
";\n";
1103 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1104 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + yh_offset);\n";
1107 if (num_directions == 2) {
1108 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1111 out <<
SP <<
SP <<
"size_t yh_offset = batch * " << num_directions *
fAttrHiddenSize <<
" + "
1113 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1114 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + yh_offset);\n";
ROperator_GRU(std::vector< float > activation_alpha, std::vector< float > activation_beta, std::vector< std::string > activations, float clip, std::string direction, size_t hidden_size, size_t layout, size_t linear_before_reset, std::string nameX, std::string nameW, std::string nameR, std::string nameB, std::string nameSequence_lens, std::string nameInitial_h, std::string nameY, std::string nameY_h)
Constructs an ROperator_GRU from the GRU attributes (activation parameters, clip threshold, direction, hidden size, layout, linear_before_reset) and the names of the input, weight, recurrence, bias, sequence-length, initial-hidden-state, and output tensors.