Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_LSTM.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_LSTM
2#define TMVA_SOFIE_ROPERATOR_LSTM
3
4#include "TMVA/RModel.hxx"
5#include "TMVA/ROperator.hxx"
7
8#include <memory>
9#include <sstream>
10#include <string>
11#include <vector>
12
13namespace TMVA {
14namespace Experimental {
15namespace SOFIE {
16
17/*! \brief Long Short-Term Memory operator
18 *
19 * Inference code generation for one-layer LSTM. Supports forward, reverse and bidirectional LSTM.
20 * See the <a href="https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM">ONNX documentation</a>
21 * for details about the supported LSTM architectures.
22 */
23template <typename T> class ROperator_LSTM final : public ROperator {
24 private:
25 std::vector<float> fAttrActivationAlpha; ///< Scaling values used by some activation functions
26 std::vector<float> fAttrActivationBeta; ///< Scaling values used by some activation functions
27 std::vector<std::string> fAttrActivations; ///< Activation functions
28 float fAttrClip; ///< Clip threshold
29 std::string fAttrDirection; ///< Direction of processing
30 size_t fAttrHiddenSize; ///< Number of neurons in the hidden layer (ONNX hidden_size attribute)
31 size_t fAttrInputForget; ///< ONNX input_forget attribute (couples the input and forget gates when set to 1)
32 size_t fAttrLayout; ///< Data layout
33
34 std::string fNX; ///< Name of the input
35 std::string fNW; ///< Name of the weights
36 std::string fNR; ///< Name of the recurrence
37 std::string fNB; ///< Name of the bias
38 std::string fNSequence_lens; ///< Name of the length of the sequences
39 std::string fNInitial_h; ///< Name of the initial value of the hidden states
40 std::string fNInitial_c; ///< Name of the initial value of the cell states
41 std::string fNP; ///< Name of the peepholes
42 std::string fNY; ///< Name of the output
43 std::string fNY_h; ///< Name of the last sequence of the output
44 std::string fNY_c; ///< Name of the last sequence of the cell states
45
46 std::vector<size_t> fShapeX; ///< Shape of the input
47 std::vector<size_t> fShapeW; ///< Shape of the weights
48 std::vector<size_t> fShapeR; ///< Shape of the recurrence
49 std::vector<size_t> fShapeB; ///< Shape of the bias
50 std::vector<size_t> fShapeSequence_lens; ///< Shape of the length of the sequences
51 std::vector<size_t> fShapeInitial_h; ///< Shape of the initial value of the hidden states
52 std::vector<size_t> fShapeInitial_c; ///< Shape of the initial value of the cell states
53 std::vector<size_t> fShapeP; ///< Shape of the peepholes
54 std::vector<size_t> fShapeY; ///< Shape of the output
55 std::vector<size_t> fShapeY_h; ///< Shape of the last sequence of the output
56 std::vector<size_t> fShapeY_c; ///< Shape of the last sequence of the cell states
57
58 std::string fType; ///< Type of the tensors
59
60 public:
61 /*! Default constructor of ROperator_LSTM */
63
64 /*! \brief Constructor of ROperator_LSTM from the attributes
65 *
66 * \param activation_alpha scaling values used by some activation functions
67 * \param activation_beta scaling values used by some activation functions
68 * \param activations activation functions
69 * \param clip clip threshold
70 * \param direction direction of processing of the sequences
71 * \param hidden_size number of neurons in the hidden layer
72 * \param input_forget ONNX input_forget attribute (couples the input and forget gates when set to 1)
73 * \param layout data layout
74 * \param nameX name of the input tensor
75 * \param nameW name of the weight tensor
76 * \param nameR name of the recurrence tensor
77 * \param nameB name of the bias tensor
78 * \param nameSequence_lens name of the length of the sequences
79 * \param nameInitial_h name of the initial value of the hidden states
80 * \param nameInitial_c name of the initial value of the cell states
81 * \param nameP name of the peepholes tensor
82 * \param nameY name of the output
83 * \param nameY_h name of the last sequence of the output
84 * \param nameY_c name of the last sequence of the cell states
85 */
86 ROperator_LSTM(std::vector<float> activation_alpha,
87 std::vector<float> activation_beta,
88 std::vector<std::string> activations, float clip,
89 std::string direction, size_t hidden_size,
90 size_t input_forget, size_t layout,
91 std::string nameX, std::string nameW, std::string nameR,
92 std::string nameB, std::string nameSequence_lens,
93 std::string nameInitial_h, std::string nameInitial_c, std::string nameP,
94 std::string nameY, std::string nameY_h, std::string nameY_c)
95 : fAttrActivationAlpha(activation_alpha),
96 fAttrActivationBeta(activation_beta), fAttrActivations(activations),
97 fAttrClip(clip), fAttrDirection(direction), fAttrHiddenSize(hidden_size),
98 fAttrInputForget(input_forget), fAttrLayout(layout),
99 fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)),
100 fNR(UTILITY::Clean_name(nameR)), fNB(UTILITY::Clean_name(nameB)),
101 fNSequence_lens(UTILITY::Clean_name(nameSequence_lens)),
102 fNInitial_h(UTILITY::Clean_name(nameInitial_h)),
103 fNInitial_c(UTILITY::Clean_name(nameInitial_c)), fNP(UTILITY::Clean_name(nameP)),
104 fNY(UTILITY::Clean_name(nameY)), fNY_h(UTILITY::Clean_name(nameY_h)),
105 fNY_c(UTILITY::Clean_name(nameY_c)) {
106 if (std::is_same<T, float>::value) { // only float is accepted; any other T throws std::runtime_error below
107 fType = "float";
108 } else {
109 throw std::runtime_error(
110 "TMVA SOFIE Encountered unsupported type parsing a LSTM operator");
111 }
112 }
113
114 /*! \brief Infers the type of the output tensors
115 *
116 * \param input type of the input tensors
117 */
118 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input);
119
120 /*! \brief Infers the shape of the output tensors
121 *
122 * \param input shape of the input tensors
123 */
124 std::vector<std::vector<size_t>>
125 ShapeInference(std::vector<std::vector<size_t>> input);
126
127 /*! \brief Initialize the model
128 *
129 * \param model Model
130 */
131 void Initialize(RModel &model);
132
133 /*! \brief Generate the inference code
134 *
135 * \param OpName name of the operator
136 */
137 std::string Generate(std::string OpName);
138
139 /*! \brief Generate the code for the Session internal data vectors
140 *
141 * \param opName name of the operator
142 */
143 std::string GenerateSessionMembersCode(std::string opName);
144
145 /*! \brief Returns the names of the BLAS routines ("Gemm" and "Axpy") needed to compile the generated code
146 */
147 std::vector<std::string> GetBlasRoutines() { return { std::string("Gemm"), std::string("Axpy") }; }
148};
149
150} // namespace SOFIE
151} // namespace Experimental
152} // namespace TMVA
153
154// Implementation of the ROperator_LSTM class
156
157#endif
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Long Short-Term Memory operator.
std::vector< size_t > fShapeR
Shape of the recurrence.
std::vector< size_t > fShapeInitial_c
Shape of the initial value of the cell states.
std::string fNY_c
Name of the last sequence of the cell states.
ROperator_LSTM(std::vector< float > activation_alpha, std::vector< float > activation_beta, std::vector< std::string > activations, float clip, std::string direction, size_t hidden_size, size_t input_forget, size_t layout, std::string nameX, std::string nameW, std::string nameR, std::string nameB, std::string nameSequence_lens, std::string nameInitial_h, std::string nameInitial_c, std::string nameP, std::string nameY, std::string nameY_h, std::string nameY_c)
Constructor of ROperator_LSTM from the attributes.
std::string fNR
Name of the recurrence.
std::vector< size_t > fShapeY_h
Shape of the last sequence of the output.
std::vector< float > fAttrActivationAlpha
Scaling values used by some activation functions.
std::vector< size_t > fShapeInitial_h
Shape of the initial value of the hidden states.
size_t fAttrHiddenSize
Number of neurons in the hidden layer.
std::string fNInitial_c
Name of the initial value of the cell states.
std::vector< size_t > fShapeB
Shape of the bias.
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input)
Infers the shape of the output tensors.
std::string fNW
Name of the weights.
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input)
Infers the type of the output tensors.
std::vector< size_t > fShapeY
Shape of the output.
std::vector< size_t > fShapeP
Shape of the peepholes.
std::string GenerateSessionMembersCode(std::string opName)
Generate the code for the Session internal data vectors.
std::string fType
Type of the tensors.
std::string fAttrDirection
Direction of processing.
std::vector< float > fAttrActivationBeta
Scaling values used by some activation functions.
std::vector< std::string > GetBlasRoutines()
Returns the blas routines needed to compile the generated code.
std::vector< size_t > fShapeX
Shape of the input.
std::vector< size_t > fShapeW
Shape of the weights.
std::string fNSequence_lens
Name of length of the sequences.
std::string fNY_h
Name of the last sequence of the output.
std::string fNY
Name of the output.
std::vector< std::string > fAttrActivations
Activation functions.
void Initialize(RModel &model)
Initialize the model.
std::vector< size_t > fShapeSequence_lens
Shape of the length of the sequences.
std::string Generate(std::string OpName)
Generate the inference code.
std::vector< size_t > fShapeY_c
Shape of the last sequence of the cell states.
std::string fNInitial_h
Name of the initial value of the hidden states.
ROperator_LSTM()
Default constructor of ROperator_LSTM.
create variable transformations