1#ifndef TMVA_SOFIE_ROPERATOR_LSTM
2#define TMVA_SOFIE_ROPERATOR_LSTM
14namespace Experimental {
87 std::vector<float> activation_beta,
88 std::vector<std::string> activations,
float clip,
89 std::string direction,
size_t hidden_size,
90 size_t input_forget,
size_t layout,
91 std::string nameX, std::string nameW, std::string nameR,
92 std::string nameB, std::string nameSequence_lens,
93 std::string nameInitial_h, std::string nameInitial_c, std::string nameP,
94 std::string nameY, std::string nameY_h, std::string nameY_c)
99 fNX(UTILITY::Clean_name(nameX)),
fNW(UTILITY::Clean_name(nameW)),
100 fNR(UTILITY::Clean_name(nameR)),
fNB(UTILITY::Clean_name(nameB)),
103 fNInitial_c(UTILITY::Clean_name(nameInitial_c)),
fNP(UTILITY::Clean_name(nameP)),
104 fNY(UTILITY::Clean_name(nameY)),
fNY_h(UTILITY::Clean_name(nameY_h)),
105 fNY_c(UTILITY::Clean_name(nameY_c)) {
106 if (std::is_same<T, float>::value) {
109 throw std::runtime_error(
110 "TMVA SOFIE Encountered unsupported type parsing a LSTM operator");
124 std::vector<std::vector<size_t>>
137 std::string
Generate(std::string OpName);
/// Returns the names of the BLAS routines the generated inference code
/// depends on, so they can be declared/linked when compiling it.
/// @return {"Gemm", "Axpy"}, in that order.
std::vector<std::string> GetBlasRoutines() {
   return { std::string("Gemm"), std::string("Axpy") };
}
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Long Short-Term Memory operator.
std::vector< size_t > fShapeR
Shape of the recurrence.
std::vector< size_t > fShapeInitial_c
Shape of the initial value of the cell states.
std::string fNP
Name of peepholes.
std::string fNY_c
Name of the last sequence of the cell states.
std::string fNX
Name of the input.
ROperator_LSTM(std::vector< float > activation_alpha, std::vector< float > activation_beta, std::vector< std::string > activations, float clip, std::string direction, size_t hidden_size, size_t input_forget, size_t layout, std::string nameX, std::string nameW, std::string nameR, std::string nameB, std::string nameSequence_lens, std::string nameInitial_h, std::string nameInitial_c, std::string nameP, std::string nameY, std::string nameY_h, std::string nameY_c)
Constructor of ROperator_LSTM from the attributes.
std::string fNR
Name of the recurrence.
std::vector< size_t > fShapeY_h
Shape of the last sequence of the output.
std::vector< float > fAttrActivationAlpha
Scaling values used by some activation functions.
std::vector< size_t > fShapeInitial_h
Shape of the initial value of the hidden states.
size_t fAttrLayout
Data layout.
size_t fAttrHiddenSize
Number of neurons in the hidden layer (hidden size).
std::string fNInitial_c
Name of the initial value of the cell states.
std::vector< size_t > fShapeB
Shape of the bias.
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input)
Infers the shape of the output tensors.
std::string fNW
Name of the weights.
float fAttrClip
Clip threshold.
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input)
Infers the type of the output tensors.
std::vector< size_t > fShapeY
Shape of the output.
std::vector< size_t > fShapeP
Shape of the peepholes.
std::string GenerateSessionMembersCode(std::string opName)
Generate the code for the Session internal data vectors.
std::string fType
Type of the tensors.
std::string fAttrDirection
Direction of processing.
std::vector< float > fAttrActivationBeta
Scaling values used by some activation functions.
std::vector< std::string > GetBlasRoutines()
Returns the blas routines needed to compile the generated code.
std::vector< size_t > fShapeX
Shape of the input.
std::vector< size_t > fShapeW
Shape of the weights.
std::string fNSequence_lens
Name of the tensor holding the lengths of the sequences.
std::string fNY_h
Name of the last sequence of the output.
std::string fNY
Name of the output.
std::vector< std::string > fAttrActivations
Activation functions.
void Initialize(RModel &model)
Initialize the model.
size_t fAttrInputForget
Whether the input and forget gates are coupled (coupled if set to 1).
std::vector< size_t > fShapeSequence_lens
Shape of the length of the sequences.
std::string Generate(std::string OpName)
Generate the inference code.
std::vector< size_t > fShapeY_c
Shape of the last sequence of the cell states.
std::string fNInitial_h
Name of the initial value of the hidden states.
std::string fNB
Name of the bias.
ROperator_LSTM()
Default constructor of ROperator_LSTM.
create variable transformations