ROperator_GRU.icc
#ifndef TMVA_SOFIE_ROPERATOR_GRU_I
#define TMVA_SOFIE_ROPERATOR_GRU_I

namespace TMVA {
namespace Experimental {
namespace SOFIE {

template <typename T>
auto ROperator_GRU<T>::TypeInference(std::vector<ETensorType> input)
-> std::vector<ETensorType> {
   ETensorType out = input[0];
   return {out, out};
}

template<typename T>
auto ROperator_GRU<T>::ShapeInference(std::vector<std::vector<size_t>> input)
-> std::vector<std::vector<size_t>> {
   size_t num_directions = input[1][0];
   size_t hidden_size = input[1][1] / 3;
   if (fAttrLayout == 0) {
      size_t seq_length = input[0][0];
      size_t batch_size = input[0][1];
      std::vector<std::vector<size_t>> ret(
         {{seq_length, num_directions, batch_size, hidden_size},
          {num_directions, batch_size, hidden_size}});
      return ret;
   } else {
      size_t batch_size = input[0][0];
      size_t seq_length = input[0][1];
      std::vector<std::vector<size_t>> ret(
         {{batch_size, seq_length, num_directions, hidden_size},
          {batch_size, num_directions, hidden_size}});
      return ret;
   }
}
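
// Worked example (illustrative, not part of the original source): for an ONNX GRU
// with X of shape {seq_length = 10, batch_size = 4, input_size = 8} and W of shape
// {num_directions = 1, 3 * hidden_size = 48, input_size = 8}, layout 0 gives
//   ShapeInference({{10, 4, 8}, {1, 48, 8}}) == {{10, 1, 4, 16}, {1, 4, 16}},
// i.e. the shapes of the outputs Y and Y_h.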

template<typename T>
auto ROperator_GRU<T>::Initialize(RModel &model)
-> void {
   fUseSession = model.UseSession();
   // Check the input and output tensors
   if (!model.CheckIfTensorAlreadyExist(fNX)) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNX + " is not found in model.");
   }
   fShapeX = model.GetTensorShape(fNX);
   if (fShapeX.size() != 3) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNX + " is not of 3 dimensions.");
   }
   if (!model.CheckIfTensorAlreadyExist(fNW)) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNW + " is not found in model.");
   }
   fShapeW = model.GetTensorShape(fNW);
   if (fShapeW.size() != 3) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNW + " is not of 3 dimensions.");
   }
   if (!model.CheckIfTensorAlreadyExist(fNR)) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNR + " is not found in model.");
   }
   fShapeR = model.GetTensorShape(fNR);
   if (fShapeR.size() != 3) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNR + " is not of 3 dimensions.");
   }
   if (!fNB.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNB)) {
         throw std::runtime_error("TMVA SOFIE GRU op input tensor " + fNB + " is not found in model.");
      }
      fShapeB = model.GetTensorShape(fNB);
      if (fShapeB.size() != 2 && fShapeB.size() != 4) {
         throw std::runtime_error("TMVA SOFIE GRU op input tensor " + fNB + " is not of 2 or 4 dimensions.");
      }
      if (fShapeB.size() == 2) {
         // Broadcasting the bias
         auto original_data = model.GetInitializedTensorData(fNB);
         size_t num_directions = fShapeW[0];
         size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
         size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
         if (fType == "float") {
            float *original_bias = static_cast<float*>(original_data.get());
            float *new_bias = new float[num_directions * 6 * seq_length * batch_size * fAttrHiddenSize];
            for (size_t direction = 0; direction < num_directions; direction++) {
               for (size_t i = 0; i < 6; i++) {
                  for (size_t seq = 0; seq < seq_length; seq++) {
                     for (size_t batch = 0; batch < batch_size; batch++) {
                        size_t bias_offset = direction * 6 * fAttrHiddenSize + i * fAttrHiddenSize;
                        size_t offset = direction * 6 * batch_size * seq_length * fAttrHiddenSize +
                           i * batch_size * seq_length * fAttrHiddenSize +
                           seq * batch_size * fAttrHiddenSize + batch * fAttrHiddenSize;
                        std::copy(original_bias + bias_offset, original_bias + bias_offset + fAttrHiddenSize,
                           new_bias + offset);
                     }
                  }
               }
            }

            std::vector<size_t> new_bias_shape = {num_directions, 6, seq_length, batch_size, fAttrHiddenSize};
            std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<float[]>());
            model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), new_bias_shape, new_bias_ptr);
            fShapeB = model.GetTensorShape(fNB);
         }
      }
   }
   if (!fNSequence_lens.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNSequence_lens)) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNSequence_lens +
                                  " is not found in model.");
      }
      fShapeSequence_lens = model.GetTensorShape(fNSequence_lens);
      if (fShapeSequence_lens.size() != 1) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNSequence_lens +
                                  " is not of 1 dimension.");
      }
   }
   if (!fNInitial_h.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNInitial_h)) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNInitial_h + " is not found in model.");
      }
      fShapeInitial_h = model.GetTensorShape(fNInitial_h);
      if (fShapeInitial_h.size() != 3) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNInitial_h + " is not of 3 dimensions.");
      }
   }
   if (!fNY.empty()) {
      fShapeY = ShapeInference({fShapeX, fShapeW})[0];
      if (!model.CheckIfTensorAlreadyExist(fNY)) {
         model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
      }
   }
   if (!fNY_h.empty()) {
      fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
      if (!model.CheckIfTensorAlreadyExist(fNY_h)) {
         model.AddIntermediateTensor(fNY_h, model.GetTensorType(fNX), fShapeY_h);
      }
   }
   // Check the attributes
   for (auto &activation : fAttrActivations) {
      if (activation != "Relu" && activation != "Tanh" &&
          activation != "Sigmoid" && activation != "Affine" &&
          activation != "LeakyRelu" && activation != "ThresholdRelu" &&
          activation != "ScaledTanh" && activation != "HardSigmoid" &&
          activation != "Elu" && activation != "Softsign" &&
          activation != "Softplus") {
         throw std::runtime_error("TMVA SOFIE - Activation function " +
                                  activation + " not implemented");
      }
   }
   if (fAttrDirection == "reverse") fAttrDirection = "backward";
   if (fAttrDirection != "forward" && fAttrDirection != "backward" &&
       fAttrDirection != "bidirectional") {
      throw std::runtime_error(
         "TMVA SOFIE - Invalid GRU direction fAttrDirection = " +
         fAttrDirection);
   }
   if (3 * fAttrHiddenSize != fShapeW[1]) {
      throw std::runtime_error(
         "TMVA SOFIE - fAttrHiddenSize must be equal to " +
         std::to_string(fShapeW[1] / 3));
   }
   if (fAttrLayout > 1) {
      throw std::runtime_error("TMVA SOFIE - Layout fAttrLayout = " +
                               std::to_string(fAttrLayout) +
                               " must be 0 (timewise) or 1 (batchwise)");
   }
   if (fAttrLinearBeforeReset > 1) {
      throw std::runtime_error(
         "TMVA SOFIE - fAttrLinearBeforeReset = " + std::to_string(fAttrLinearBeforeReset)
         + " must be 0 or 1.");
   }
   if (fAttrActivations.empty()) {
      if (fAttrDirection == "bidirectional") {
         fAttrActivations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"};
      } else {
         fAttrActivations = {"Sigmoid", "Tanh"};
      }
   }
}
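
// Note (illustrative, not part of the original source): an ONNX GRU bias of shape
// {num_directions, 6 * hidden_size}, laid out as [Wb_z, Wb_r, Wb_h, Rb_z, Rb_r, Rb_h],
// is broadcast above to {num_directions, 6, seq_length, batch_size, hidden_size} so
// that each of the six bias slabs can later be added to an entire
// (seq_length * batch_size, hidden_size) gate buffer with a single BLAS saxpy call.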

// generate code for Session data members (e.g. internal vectors)
template <typename T>
std::string ROperator_GRU<T>::GenerateSessionMembersCode(std::string opName)
{
   opName = "op_" + opName;
   std::stringstream out;

   size_t num_directions = fShapeW[0];
   size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
   size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
   size_t input_size = fShapeX[2];

   if (fAttrLayout != 0) {
      out << "std::vector<" << fType << "> fVec_" << opName << "_input = std::vector<" << fType << ">("
          << seq_length * batch_size * input_size << ");\n";
      out << "std::vector<" << fType << "> fVec_" << opName << "_initial_hidden_state = std::vector<" << fType << ">("
          << num_directions * batch_size * fAttrHiddenSize << ");\n";
      out << "std::vector<" << fType << "> fVec_" << opName << "_initial_cell_state = std::vector<" << fType << ">("
          << num_directions * batch_size * fAttrHiddenSize << ");\n";
   }
   // Set the feedforward
   size_t ff_size = seq_length * batch_size * fAttrHiddenSize;
   out << "std::vector<" << fType << "> fVec_" << opName << "_f_update_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_f_reset_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_f_hidden_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
   // gate results
   size_t hs_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
   out << "std::vector<" << fType << "> fVec_" << opName << "_update_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_reset_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_gate = std::vector<" << fType << ">(" << hs_size << ");\n";

   // feedback
   out << "std::vector<" << fType << "> fVec_" << opName << "_feedback = std::vector<" << fType << ">("
       << batch_size * fAttrHiddenSize << ");\n";

   // hidden state
   if (fAttrLayout != 0 || fNY.empty()) {
      out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_state = std::vector<" << fType << ">(" << hs_size << ");\n";
   }

   out << "\n";

   return out.str();
}
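
// Illustrative sample of the emitted members (assuming opName "gru", fType "float",
// seq_length = 2, batch_size = 1, hidden_size = 3 and one direction):
//   std::vector<float> fVec_op_gru_f_update_gate = std::vector<float>(6);
//   std::vector<float> fVec_op_gru_update_gate = std::vector<float>(6);
//   std::vector<float> fVec_op_gru_feedback = std::vector<float>(3);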


template<typename T>
auto ROperator_GRU<T>::Generate(std::string OpName)
-> std::string {
   OpName = "op_" + OpName;
   std::stringstream out;

   size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
   size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
   size_t input_size = fShapeX[2];
   size_t num_directions = fShapeW[0];

   // set the input
   if (fAttrLayout == 0) {
      out << SP << fType << " *" << OpName << "_input = tensor_" << fNX << ";\n";
   } else {
      if (fUseSession) {
         out << SP << fType << " * " << OpName << "_input = fVec_" << OpName << "_input.data();\n";
      } else {
         out << SP << fType << " " << OpName << "_input[" << seq_length * batch_size * input_size << "];\n";
      }
      out << SP << "for(size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
      out << SP << SP << SP << "for(size_t i = 0; i < " << input_size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_input[seq * " << batch_size * input_size
          << " + batch * " << input_size << " + i] = " << "tensor_" << fNX << "[batch * "
          << seq_length * input_size << " + seq * " << input_size << " + i];\n";
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";
      out << SP << "}\n";
   }
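
   // Note (illustrative): for fAttrLayout == 1 the loop generated above transposes the
   // input from [batch_size, seq_length, input_size] to [seq_length, batch_size, input_size],
   // which is the layout assumed by the GEMM calls emitted below.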

   // Set the initial hidden state
   if (!fNInitial_h.empty()) {
      if (fAttrLayout == 0) {
         out << SP << fType << " *" << OpName << "_initial_hidden_state = tensor_"
             << fNInitial_h << ";\n";
      } else {
         if (fUseSession) {
            out << SP << fType << " * " << OpName << "_initial_hidden_state = fVec_" << OpName
                << "_initial_hidden_state.data();\n";
         } else {
            out << SP << fType << " " << OpName << "_initial_hidden_state[" << num_directions * batch_size *
               fAttrHiddenSize << "];\n";
         }
         for (size_t direction = 0; direction < num_directions; direction++) {
            out << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "for(size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
            out << SP << SP << SP << OpName << "_initial_hidden_state["
                << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
                << " + h] = tensor_" << fNInitial_h << "[batch * " << num_directions * fAttrHiddenSize
                << " + " << direction * fAttrHiddenSize << " + h];\n";
            out << SP << SP << "}\n";
            out << SP << "}\n";
         }
      }
   }

   // Set the feedforward
   size_t feedforward_size = seq_length * batch_size * fAttrHiddenSize;
   if (fUseSession) {
      out << SP << fType << " * " << OpName << "_f_update_gate = fVec_" << OpName << "_f_update_gate.data();\n";
      out << SP << fType << " * " << OpName << "_f_reset_gate = fVec_" << OpName << "_f_reset_gate.data();\n";
      out << SP << fType << " * " << OpName << "_f_hidden_gate = fVec_" << OpName << "_f_hidden_gate.data();\n";
   } else {
      out << SP << fType << " " << OpName << "_f_update_gate[" << feedforward_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_f_reset_gate[" << feedforward_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_f_hidden_gate[" << feedforward_size << "] = {0};\n";
   }
   // Set the gates
   size_t hidden_state_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
   if (fUseSession) {
      out << SP << fType << " * " << OpName << "_update_gate = fVec_" << OpName << "_update_gate.data();\n";
      out << SP << fType << " * " << OpName << "_reset_gate = fVec_" << OpName << "_reset_gate.data();\n";
      out << SP << fType << " * " << OpName << "_hidden_gate = fVec_" << OpName << "_hidden_gate.data();\n";
   } else {
      out << SP << fType << " " << OpName << "_update_gate[" << hidden_state_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_reset_gate[" << hidden_state_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_hidden_gate[" << hidden_state_size << "] = {0};\n";
   }
   // Set the hidden state
   if (fAttrLayout == 0 && !fNY.empty()) {
      out << SP << fType << " *" << OpName << "_hidden_state = tensor_" << fNY << ";\n";
   } else {
      if (fUseSession) {
         out << SP << fType << " * " << OpName << "_hidden_state = fVec_" << OpName << "_hidden_state.data();\n";
      } else {
         out << SP << fType << " " << OpName << "_hidden_state[" << hidden_state_size << "] = {0};\n";
      }
   }

   if (fUseSession) {
      out << SP << fType << " * " << OpName << "_feedback = fVec_" << OpName << "_feedback.data();\n";
   } else {
      out << SP << fType << " " << OpName << "_feedback[" << batch_size * fAttrHiddenSize << "] = {0};\n";
   }

   out << SP << "char " << OpName << "_transA = 'N';\n";
   out << SP << "char " << OpName << "_transB = 'T';\n";
   out << SP << "int " << OpName << "_m = " << seq_length * batch_size << ";\n";
   out << SP << "int " << OpName << "_m2 = " << batch_size << ";\n";
   out << SP << "int " << OpName << "_n = " << fAttrHiddenSize << ";\n";
   out << SP << "int " << OpName << "_k = " << input_size << ";\n";
   if (fType == "float") {
      out << SP << "float " << OpName << "_alpha = 1.;\n";
      out << SP << "float " << OpName << "_beta = 0.;\n";
   }
   if (!fNB.empty()) {
      out << SP << "int " << OpName << "_bias_size = " << seq_length * batch_size * fAttrHiddenSize << ";\n";
   }
   out << SP << "int " << OpName << "_incx = 1;\n";
   out << SP << "int " << OpName << "_incy = 1;\n";
   out << SP << "int " << OpName << "_feedback_size = " << batch_size * fAttrHiddenSize << ";\n";

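   // ONNX GRU recurrence implemented by the generated code below (f and g are the
   // chosen activations, (.) is the element-wise product; comment added for clarity):
   //   z_t = f(X_t W_z^T + H_{t-1} R_z^T + Wb_z + Rb_z)
   //   r_t = f(X_t W_r^T + H_{t-1} R_r^T + Wb_r + Rb_r)
   //   h~_t = g(X_t W_h^T + (r_t (.) H_{t-1}) R_h^T + Wb_h + Rb_h)    (linear_before_reset == 0)
   //   h~_t = g(X_t W_h^T + r_t (.) (H_{t-1} R_h^T + Rb_h) + Wb_h)    (linear_before_reset != 0)
   //   H_t = (1 - z_t) (.) h~_t + z_t (.) H_{t-1}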
   for (size_t direction = 0; direction < num_directions; direction++) {
      if (direction == 0) {
         if (fType == "float") {
            // f_update_gate = input * weight_z^T
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &"
                << OpName << "_beta, " << OpName << "_f_update_gate, &" << OpName << "_n);\n";
            // f_reset_gate = input * weight_r^T
            size_t wr_offset = fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wr_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_reset_gate, &" << OpName << "_n);\n";
            // f_hidden_gate = input * weight_h^T
            size_t wh_offset = 2 * fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wh_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_hidden_gate, &" << OpName << "_n);\n";
         }
      } else {
         if (fType == "float") {
            // f_update_gate = input * weight_z^T
            size_t wz_offset = 3 * fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wz_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_update_gate, &" << OpName << "_n);\n";
            // f_reset_gate = input * weight_r^T
            size_t wr_offset = 3 * fAttrHiddenSize * input_size + fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wr_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_reset_gate, &" << OpName << "_n);\n";
            // f_hidden_gate = input * weight_h^T
            size_t wh_offset = 3 * fAttrHiddenSize * input_size + 2 * fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wh_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_hidden_gate, &" << OpName << "_n);\n";
         }
      }
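
      // Illustrative expansion (assuming a hypothetical operator name "gru" and weight
      // tensor name "W"): the first call emitted above reads
      //   BLAS::sgemm_(&op_gru_transB, &op_gru_transA, &op_gru_n, &op_gru_m, &op_gru_k,
      //                &op_gru_alpha, tensor_W, &op_gru_k, op_gru_input, &op_gru_k,
      //                &op_gru_beta, op_gru_f_update_gate, &op_gru_n);
      // i.e. a row-major X * W_z^T product expressed through column-major sgemm.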
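
      // Layout reminder (illustrative): after the broadcast in Initialize, tensor_B holds,
      // per direction, six contiguous slabs of size seq_length * batch_size * hidden_size
      // in the order [Wb_z, Wb_r, Wb_h, Rb_z, Rb_r, Rb_h]; slab indices 0..5 (direction 0)
      // and 6..11 (direction 1) account for the offsets used below.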
      if (!fNB.empty()) {
         if (direction == 0) {
            if (fType == "float") {
               // Add the bias of the weight to f_update_gate
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &" << OpName << "_incy);\n";
               // Add the bias of the recurrence to f_update_gate
               size_t rbz_offset = 3 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the weight to f_reset_gate
               size_t wbr_offset = batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the recurrence to f_reset_gate
               size_t rbr_offset = 4 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the weight to f_hidden_gate
               size_t wbh_offset = 2 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &"
                   << OpName << "_incy);\n";
               if (fAttrLinearBeforeReset == 0) {
                  // Add the bias of the recurrence to f_hidden_gate
                  size_t rbh_offset = 5 * batch_size * seq_length * fAttrHiddenSize;
                  out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                      << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName
                      << "_f_hidden_gate, &" << OpName << "_incy);\n";
               }
            }
         } else {
            if (fType == "float") {
               // Add the bias of the weight to f_update_gate
               size_t wbz_offset = 6 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the recurrence to f_update_gate
               size_t rbz_offset = 9 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the weight to f_reset_gate
               size_t wbr_offset = 7 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the recurrence to f_reset_gate
               size_t rbr_offset = 10 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the weight to f_hidden_gate
               size_t wbh_offset = 8 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &"
                   << OpName << "_incy);\n";
               if (fAttrLinearBeforeReset == 0) {
                  // Add the bias of the recurrence to f_hidden_gate
                  size_t rbh_offset = 11 * batch_size * seq_length * fAttrHiddenSize;
                  out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                      << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName
                      << "_f_hidden_gate, &" << OpName << "_incy);\n";
               }
            }
         }
      }

      // Copy the feedforward into the gates
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "size_t offset = seq * " << batch_size * fAttrHiddenSize << ";\n";
      if (direction == 0) {
         out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize
             << ";\n";
      } else {
         out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize
             << " + " << batch_size * fAttrHiddenSize << ";\n";
      }
      size_t f_seq_size = batch_size * fAttrHiddenSize;
      out << SP << SP << "std::copy(" << OpName << "_f_update_gate + offset, " << OpName
          << "_f_update_gate + offset + " << f_seq_size << ", " << OpName << "_update_gate + gate_offset);\n";
      out << SP << SP << "std::copy(" << OpName << "_f_reset_gate + offset, " << OpName
          << "_f_reset_gate + offset + " << f_seq_size << ", " << OpName << "_reset_gate + gate_offset);\n";
      out << SP << SP << "std::copy(" << OpName << "_f_hidden_gate + offset, " << OpName
          << "_f_hidden_gate + offset + " << f_seq_size << ", " << OpName << "_hidden_gate + gate_offset);\n";
      out << SP << "}\n";

      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      if (fAttrDirection == "backward" || direction == 1) {
         out << SP << SP << "size_t index = " << seq_length - 1 << " - seq;\n";
      } else {
         out << SP << SP << "size_t index = seq;\n";
      }
      out << SP << SP << "int m2 = " << batch_size << ";\n";
      if (direction == 0) {
         out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize
             << ";\n";
      } else {
         out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize
             << " + " << batch_size * fAttrHiddenSize << ";\n";
      }
      size_t size = batch_size * fAttrHiddenSize;
      // gate = gate + initial_hidden_state * Recurrence^T
      out << SP << SP << "if (seq == 0) {\n";
      if (!fNInitial_h.empty()) {
         if (direction == 0) {
            if (fType == "float") {
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &"
                   << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
                   << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
               size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                   << rr_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                   << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
            }
         } else { // direction=1
            if (fType == "float") {
               size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                   << rz_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                   << "_n, &" << OpName << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
               size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                   << rr_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                   << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
            }
         }
      }
      out << SP << SP << "} else {\n";
      // gate = gate + previous_hidden_state * Recurrence^T
      if (direction == 0) {
         if (fAttrDirection == "backward") {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         } else {
            out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         }
         if (fType == "float") {
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &"
                << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
                << OpName << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
            size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rr_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
                << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &"
                << OpName << "_n);\n";
         }
      } else {
         out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
             << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         if (fType == "float") {
            size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rz_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
                << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_update_gate + offset, &"
                << OpName << "_n);\n";
            size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rr_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
                << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &"
                << OpName << "_n);\n";
         }
      }
      out << SP << SP << "}\n";

      // Clip the elements of the update gate and the reset gate into the range [-fClip, fClip]
      if (fAttrClip > .0) {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float z = (" << OpName << "_update_gate[i] > " << -fAttrClip
                << ") ? " << OpName << "_update_gate[i] : " << -fAttrClip << ";\n";
         }
         out << SP << SP << SP << OpName << "_update_gate[i] = (z < " << fAttrClip
             << ") ? z : " << fAttrClip << ";\n";
         if (fType == "float") {
            out << SP << SP << SP << "float r = (" << OpName << "_reset_gate[i] > " << -fAttrClip
                << ") ? " << OpName << "_reset_gate[i] : " << -fAttrClip << ";\n";
         }
         out << SP << SP << SP << OpName << "_reset_gate[i] = (r < " << fAttrClip
             << ") ? r : " << fAttrClip << ";\n";
         out << SP << SP << "}\n";
      }

      // Apply the activation function to the update gate and the reset gate
      if (fAttrActivations[direction * 2] == "Relu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Tanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float z = exp(-2 * " << OpName << "_update_gate[i]);\n";
         }
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = (1. - z) / (1. + z);\n";
         if (fType == "float") {
            out << SP << SP << SP << "float r = exp(-2 * " << OpName << "_reset_gate[i]);\n";
         }
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = (1. - r) / (1. + r);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Sigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = 1. / (1. + exp(-"
             << OpName << "_update_gate[i]));\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 1. / (1. + exp(-"
             << OpName << "_reset_gate[i]));\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Affine") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_update_gate[i] + "
             << fAttrActivationBeta[direction * 2] << ";\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_reset_gate[i] + "
             << fAttrActivationBeta[direction * 2] << ";\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "ScaledTanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float z = exp(-2 * " << fAttrActivationBeta[direction * 2]
                << " * " << OpName << "_update_gate[i]);\n";
         }
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * (1. - z) / (1. + z);\n";
         if (fType == "float") {
            out << SP << SP << SP << "float r = exp(-2 * " << fAttrActivationBeta[direction * 2]
                << " * " << OpName << "_reset_gate[i]);\n";
         }
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * (1. - r) / (1. + r);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "HardSigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float za = " << fAttrActivationAlpha[direction * 2] << " * "
                << OpName << "_update_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
            out << SP << SP << SP << "float zb = (za > 0.) ? za : 0.;\n";
         }
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = (zb < 1.) ? zb : 1.;\n";
         if (fType == "float") {
            out << SP << SP << SP << "float ra = " << fAttrActivationAlpha[direction * 2] << " * "
                << OpName << "_reset_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
            out << SP << SP << SP << "float rb = (ra > 0.) ? ra : 0.;\n";
         }
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = (rb < 1.) ? rb : 1.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "LeakyRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_update_gate[i];\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_reset_gate[i];\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "ThresholdRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < "
             << fAttrActivationAlpha[direction * 2] << ")\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < "
             << fAttrActivationAlpha[direction * 2] << ")\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Elu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * (exp(" << OpName << "_update_gate[i]) - 1.);\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * (exp(" << OpName << "_reset_gate[i]) - 1.);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Softsign") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = " << OpName
             << "_update_gate[i] / (1. + abs(" << OpName << "_update_gate[i]));\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = " << OpName
             << "_reset_gate[i] / (1. + abs(" << OpName << "_reset_gate[i]));\n";
         out << SP << SP << "}\n";
      } else { // fAttrActivations[direction * 2] = Softplus
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = log(1. + exp("
             << OpName << "_update_gate[i]));\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = log(1. + exp("
             << OpName << "_reset_gate[i]));\n";
         out << SP << SP << "}\n";
      }

      if (fAttrLinearBeforeReset == 0) {
         out << SP << SP << "if (seq == 0) {\n";
         if (!fNInitial_h.empty()) {
            // feedback = reset_gate o initial_hidden_state
            out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
            out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName
                << "_reset_gate[i + offset] * " << OpName << "_initial_hidden_state[i];\n";
            out << SP << SP << SP << "}\n";
         }
         out << SP << SP << "} else {\n";
         // feedback = reset_gate o previous_hidden_state
         if (direction == 0) {
            if (fAttrDirection == "backward") {
               out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            } else {
               out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            }
         } else {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         }
         out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName
             << "_reset_gate[i + offset] * " << OpName << "_hidden_state[i + previous_offset];\n";
         out << SP << SP << SP << "}\n";
         out << SP << SP << "}\n";
         // feedback = feedback * R_h^T
         size_t rh_offset = (direction == 0)
            ? 2 * fAttrHiddenSize * fAttrHiddenSize
            : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
         out << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
             << OpName << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_"
             << fNR << " + " << rh_offset << ", &" << OpName << "_n, " << OpName << "_feedback, &" << OpName
             << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
      } else { // fAttrLinearBeforeReset=1
         // feedback = previous_hidden_state * R_h^T
         size_t rh_offset = (direction == 0)
            ? 2 * fAttrHiddenSize * fAttrHiddenSize
            : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
         out << SP << SP << "if (seq == 0) {\n";
         if (!fNInitial_h.empty()) {
            // feedback = initial_hidden_state * R_h^T
            out << SP << SP << SP
                << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
                << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rh_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &"
                << OpName << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
         }
         out << SP << SP << "} else {\n";
         // case for seq > 0
         if (direction == 0) {
            if (fAttrDirection == "backward") {
               out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            } else {
               out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            }
         } else {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         }
         out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
             << OpName << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR
             << " + " << rh_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
             << OpName << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
         // endif on seq 0 or not
         out << SP << SP << "}\n";
         // Add the bias of the recurrence to feedback
         if (!fNB.empty()) {
            size_t rbh_offset = (direction == 0) ? 5 * batch_size * seq_length * fAttrHiddenSize
                                                 : 11 * batch_size * seq_length * fAttrHiddenSize;
            out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName
                << "_alpha, tensor_" << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, "
                << OpName << "_feedback, &" << OpName << "_incy);\n";
         }
         // feedback = reset_gate o feedback
         out << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_feedback[i] *= " << OpName << "_reset_gate[i + offset];\n";
         out << SP << SP << "}\n";
      }

      // hidden_gate = hidden_gate + feedback
      out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName << "_alpha, "
          << OpName << "_feedback, &" << OpName << "_incx, " << OpName << "_hidden_gate + offset, &"
          << OpName << "_incy);\n";

      // Clip the elements of the hidden gate into the range [-fClip, fClip]
      if (fAttrClip > .0) {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float x = (" << OpName << "_hidden_gate[i] > " << -fAttrClip
                << ") ? " << OpName << "_hidden_gate[i] : " << -fAttrClip << ";\n";
         }
         out << SP << SP << SP << OpName << "_hidden_gate[i] = (x < " << fAttrClip << ") ? x : "
             << fAttrClip << ";\n";
         out << SP << SP << "}\n";
      }

      // Apply the activation function to the hidden gate
      if (fAttrActivations[direction * 2 + 1] == "Relu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Tanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_hidden_gate[i]);\n";
         }
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = (1. - ex) / (1. + ex);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Sigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 1. / (1. + exp(-" << OpName
             << "_hidden_gate[i]));\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Affine") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * " << OpName << "_hidden_gate[i] + "
             << fAttrActivationBeta[direction * 2 + 1] << ";\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "ScaledTanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 2 + 1]
                << " * " << OpName << "_hidden_gate[i]);\n";
         }
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * (1. - ex) / (1. + ex);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "HardSigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 2 + 1] << " * "
                << OpName << "_hidden_gate[i] + " << fAttrActivationBeta[direction * 2 + 1] << ";\n";
            out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
         }
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = (b < 1.) ? b : 1.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "LeakyRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * " << OpName << "_hidden_gate[i];\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "ThresholdRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < "
             << fAttrActivationAlpha[direction * 2 + 1] << ")\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Elu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * (exp(" << OpName << "_hidden_gate[i]) - 1.);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Softsign") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = " << OpName
             << "_hidden_gate[i] / (1. + abs(" << OpName << "_hidden_gate[i]));\n";
         out << SP << SP << "}\n";
      } else { // fAttrActivations[direction * 2 + 1] = Softplus
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = log(1. + exp("
             << OpName << "_hidden_gate[i]));\n";
         out << SP << SP << "}\n";
      }

      // hidden_state = (1 - update_gate) o hidden_gate
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << OpName << "_hidden_state[i] = (1. - " << OpName
          << "_update_gate[i]) * " << OpName << "_hidden_gate[i];\n";
      out << SP << SP << "}\n";

      out << SP << SP << "if (seq == 0) {\n";
      if (!fNInitial_h.empty()) {
         // hidden_state += update_gate o initial_hidden_state
         out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
             << "_update_gate[i + offset] * " << OpName << "_initial_hidden_state[i];\n";
         out << SP << SP << SP << "}\n";
      }
      out << SP << SP << "} else {\n";
      // hidden_state += update_gate o previous_hidden_state
      if (direction == 0) {
         if (fAttrDirection == "backward") {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         } else {
            out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         }
      } else {
         out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
             << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
      }
      out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
          << "_update_gate[i + offset] * " << OpName << "_hidden_state[i + previous_offset];\n";
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";

      out << SP << "}\n";
   }

   // Padding the hidden state for GRU with different sequence lengths
   if (!fNSequence_lens.empty()) {
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
      out << SP << SP << SP << "if (seq >= tensor_" << fNSequence_lens << "[batch]) {\n";
      for (size_t direction = 0; direction < num_directions; direction++) {
         out << SP << SP << SP << SP << SP << "for (size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
         out << SP << SP << SP << SP << SP << SP << OpName << "_hidden_state[seq * "
             << num_directions * batch_size * fAttrHiddenSize << " + "
             << direction * batch_size * fAttrHiddenSize
             << " + batch * " << fAttrHiddenSize << " + h] = 0.;\n";
         out << SP << SP << SP << SP << SP << "}\n";
      }
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";
      out << SP << "}\n";
   }
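
   // Note (illustrative): when sequence_lens is given, positions beyond each batch
   // element's own length are zeroed so that the padded region of Y carries no stale state.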

   // Copy the hidden state into y and y_h
   if (fAttrLayout == 0) {
      if (!fNY_h.empty()) {
         // Copy hidden_state into Y_h
         if (fNSequence_lens.empty()) {
            size_t yh_size = batch_size * fAttrHiddenSize;
            if (fAttrDirection == "backward") {
               out << SP << "std::copy(" << OpName << "_hidden_state, " << OpName << "_hidden_state + "
                   << yh_size << ", tensor_" << fNY_h << ");\n";
            } else {
               size_t offset = (seq_length - 1) * num_directions * batch_size * fAttrHiddenSize;
               out << SP << "std::copy(" << OpName << "_hidden_state + " << offset << ", " << OpName
                   << "_hidden_state + " << offset << " + " << yh_size << ", tensor_" << fNY_h << ");\n";
            }
            if (num_directions == 2) {
               out << SP << "std::copy(" << OpName << "_hidden_state + " << yh_size << ", " << OpName
                   << "_hidden_state + " << 2 * yh_size << ", tensor_" << fNY_h << " + " << yh_size << ");\n";
            }
         } else { // GRU with different sequence lengths
            if (fAttrDirection == "backward") {
               out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
               out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                   << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + offset);\n";
               out << SP << "}\n";
            } else {
               out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
               out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
               out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
                   << " + batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "size_t yh_offset = batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                   << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
               out << SP << "}\n";
            }
            if (num_directions == 2) {
               out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
               out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize
                   << " + batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "size_t yh_offset = " << batch_size * fAttrHiddenSize
                   << " + batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                   << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
               out << SP << "}\n";
            }
         }
      }
   } else { // fAttrLayout=1
      if (!fNY.empty()) {
         // Copy hidden_state into Y
         for (size_t direction = 0; direction < num_directions; direction++) {
            out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
            out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
                << " + " << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << SP << "size_t y_offset = batch * " << seq_length * num_directions * fAttrHiddenSize
                << " + seq * " << num_directions * fAttrHiddenSize << " + " << direction * fAttrHiddenSize << ";\n";
            out << SP << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY << " + y_offset);\n";
            out << SP << SP << "}\n";
            out << SP << "}\n";
         }
      }
      if (!fNY_h.empty()) {
         // Copy the hidden_state into Y_h
         if (fAttrDirection == "backward") {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         } else {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            if (fNSequence_lens.empty()) {
               out << SP << SP << "size_t seq = " << seq_length - 1 << ";\n";
            } else {
               out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
            }
            out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
                << " + batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         }
         if (num_directions == 2) {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * "
                << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << " + "
                << fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         }
      }
   }

   return out.str();
}

} // namespace SOFIE
} // namespace Experimental
} // namespace TMVA

#endif
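
A minimal usage sketch (illustrative, assuming a SOFIE build with ONNX parsing enabled; the
model file name is hypothetical — this operator is instantiated by the parser, not directly):

   #include "TMVA/RModel.hxx"
   #include "TMVA/RModelParser_ONNX.hxx"

   using namespace TMVA::Experimental::SOFIE;

   int main() {
      RModelParser_ONNX parser;
      RModel model = parser.Parse("gru_model.onnx"); // creates an ROperator_GRU for the GRU node
      model.Generate();                              // runs Initialize() and Generate() above
      model.OutputGenerated("gru_model.hxx");        // writes the generated inference code
      return 0;
   }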