Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_GRU.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_GRU
2#define TMVA_SOFIE_ROPERATOR_GRU
3
#include "TMVA/RModel.hxx"
#include "TMVA/ROperator.hxx"

#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
13
14namespace TMVA {
15namespace Experimental {
16namespace SOFIE {
17
18/*! \brief Gated Recurrent Unit operator
19 *
20 * Inference code generation for one-layer GRU. Supports forward, reverse and bidirectional GRU.
21 * See the <a href="https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU">ONNX documentation</a>
22 * for details about the supported GRU architectures.
23 */
24template <typename T> class ROperator_GRU final : public ROperator {
25 private:
26 std::vector<float> fAttrActivationAlpha; ///< Scaling values used by some activation functions
27 std::vector<float> fAttrActivationBeta; ///< Scaling values used by some activation functions
28 std::vector<std::string> fAttrActivations; ///< Activation functions
29 float fAttrClip; ///< Clip threshold
30 std::string fAttrDirection; ///< Direction of processing
31 size_t fAttrHiddenSize; ///< Number of the hidden layers
32 size_t fAttrLayout; ///< Data layout
33 size_t fAttrLinearBeforeReset; ///< Linear layer before the reset gate
34
35 std::string fNX; ///< Name of the input
36 std::string fNW; ///< Name of the weights
37 std::string fNR; ///< Name of the recurrence
38 std::string fNB; ///< Name of the bias
39 std::string fNSequence_lens; ///< Name of the length of the sequences
40 std::string fNInitial_h; ///< Name of the initial value of the hidden states
41 std::string fNY; ///< Name of the output
42 std::string fNY_h; ///< Name of the last sequence of the output
43
44 std::vector<size_t> fShapeX; ///< Shape of the input
45 std::vector<size_t> fShapeW; ///< Shape of the weights
46 std::vector<size_t> fShapeR; ///< Shape of the recurrence
47 std::vector<size_t> fShapeB; ///< Shape of the bias
48 std::vector<size_t> fShapeSequence_lens; ///< Shape of the length of the sequences
49 std::vector<size_t> fShapeInitial_h; ///< Shape of the initial value of hidden states
50 std::vector<size_t> fShapeY; ///< Shape of the output
51 std::vector<size_t> fShapeY_h; ///< Shape of the last sequence of the output
52
53 std::string fType; ///< Type of the tensors
54
55 public:
56 /*! Default constructor of ROperator_GRU */
58
59 /*! \brief Constructor of ROperator_GRU from the attributes
60 *
61 * \param activation_alpha scaling values used by some activation functions
62 * \param activation_beta scaling values used by some activation functions
63 * \param activations activation functions
64 * \param clip clip threshold
65 * \param direction direction of processing of the sequneces
66 * \param hidden_size number of hidden layers
67 * \param layout data layout
68 * \param linear_before_reset Linear layer before the reset gate
69 * \param nameX name of the input tensor
70 * \param nameW name of the weight tensor
71 * \param nameR name of the recurrence tensor
72 * \param nameB name of the bias tensor
73 * \param nameSequence_lens name of the length of the sequences
74 * \param nameInitial_h name of the initial value of the hidden states
75 * \param nameY name of the output
76 * \param nameY_h name of the last sequence of the output
77 */
78 ROperator_GRU(std::vector<float> activation_alpha,
79 std::vector<float> activation_beta,
80 std::vector<std::string> activations, float clip,
81 std::string direction, size_t hidden_size,
82 size_t layout, size_t linear_before_reset,
83 std::string nameX, std::string nameW, std::string nameR,
84 std::string nameB, std::string nameSequence_lens,
85 std::string nameInitial_h, std::string nameY, std::string nameY_h)
86 : fAttrActivationAlpha(activation_alpha),
87 fAttrActivationBeta(activation_beta), fAttrActivations(activations),
88 fAttrClip(clip), fAttrDirection(direction), fAttrHiddenSize(hidden_size),
89 fAttrLayout(layout), fAttrLinearBeforeReset(linear_before_reset),
90 fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)),
91 fNR(UTILITY::Clean_name(nameR)), fNB(UTILITY::Clean_name(nameB)),
92 fNSequence_lens(UTILITY::Clean_name(nameSequence_lens)),
93 fNInitial_h(UTILITY::Clean_name(nameInitial_h)),
94 fNY(UTILITY::Clean_name(nameY)), fNY_h(UTILITY::Clean_name(nameY_h)) {
95 if (std::is_same<T, float>::value) {
96 fType = "float";
97 } else {
98 throw std::runtime_error(
99 "TMVA SOFIE Encountered unsupported type parsing a GRU operator");
100 }
101 }
102
103 /*! \brief Infers the type of the output tensors
104 *
105 * \param input type of the input tensors
106 */
107 std::vector<ETensorType> TypeInference(std::vector<ETensorType> /*input*/);
108
109 /*! \brief Infers the shape of the output tensors
110 *
111 * \param input shape of the input tensors
112 */
113 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> /*input*/);
114
115 /*! \brief Initialize the model
116 *
117 * \param model Model
118 */
119 void Initialize(RModel & /*model*/);
120
121 /*! \brief Generate the inference code
122 *
123 * \param OpName name of the operator
124 */
125 std::string Generate(std::string /*OpName*/);
126
127 /*! \brief Generate the code for the Session internal data vectors
128 *
129 * \param opName name of the operator
130 */
131 std::string GenerateSessionMembersCode(std::string opName);
132
133 /*! \brief Returns the blas routines needed to compile the generated code
134 */
135 std::vector<std::string> GetBlasRoutines() { return { std::string("Gemm"), std::string("Axpy") }; }
136};
137
138} // namespace SOFIE
139} // namespace Experimental
140} // namespace TMVA
141
142// Implementation of the ROperator_GRU class
143#include "TMVA/ROperator_GRU.icc"
144
145#endif
Gated Recurrent Unit operator.
std::string GenerateSessionMembersCode(std::string opName)
Generate the code for the Session internal data vectors.
std::vector< size_t > fShapeY
Shape of the output.
std::string fNX
Name of the input.
std::string fType
Type of the tensors.
std::string fAttrDirection
Direction of processing.
std::string fNR
Name of the recurrence.
std::vector< float > fAttrActivationBeta
Scaling values used by some activation functions.
std::vector< std::string > GetBlasRoutines()
Returns the blas routines needed to compile the generated code.
std::string fNY
Name of the output.
std::string fNY_h
Name of the last sequence of the output.
std::string fNSequence_lens
Name of the length of the sequences.
std::vector< std::string > fAttrActivations
Activation functions.
ROperator_GRU(std::vector< float > activation_alpha, std::vector< float > activation_beta, std::vector< std::string > activations, float clip, std::string direction, size_t hidden_size, size_t layout, size_t linear_before_reset, std::string nameX, std::string nameW, std::string nameR, std::string nameB, std::string nameSequence_lens, std::string nameInitial_h, std::string nameY, std::string nameY_h)
Constructor of ROperator_GRU from the attributes.
size_t fAttrHiddenSize
Number of neurons in the hidden layer.
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > >)
Infers the shape of the output tensors.
std::vector< float > fAttrActivationAlpha
Scaling values used by some activation functions.
std::vector< size_t > fShapeR
Shape of the recurrence.
void Initialize(RModel &)
Initialize the model.
std::string fNW
Name of the weights.
std::vector< size_t > fShapeX
Shape of the input.
std::string Generate(std::string)
Generate the inference code.
std::vector< size_t > fShapeInitial_h
Shape of the initial value of hidden states.
std::vector< size_t > fShapeSequence_lens
Shape of the length of the sequences.
std::vector< size_t > fShapeY_h
Shape of the last sequence of the output.
size_t fAttrLinearBeforeReset
Linear layer before the reset gate.
std::vector< size_t > fShapeB
Shape of the bias.
std::vector< ETensorType > TypeInference(std::vector< ETensorType >)
Infers the type of the output tensors.
std::string fNInitial_h
Name of the initial value of the hidden states.
std::vector< size_t > fShapeW
Shape of the weights.
ROperator_GRU()
Default constructor of ROperator_GRU.
create variable transformations