Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_RNN.icc
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_RNN_I
2#define TMVA_SOFIE_ROPERATOR_RNN_I
3
4namespace TMVA {
5namespace Experimental {
6namespace SOFIE {
7
8template <typename T>
9auto ROperator_RNN<T>::TypeInference(std::vector<ETensorType> input)
10-> std::vector<ETensorType> {
11 ETensorType out = input[0];
12 return {out, out};
13}
14
15template <typename T>
16auto ROperator_RNN<T>::ShapeInference(std::vector<std::vector<size_t>> input)
17-> std::vector<std::vector<size_t>> {
18 size_t num_directions = input[1][0];
19 size_t hidden_size = input[1][1];
20 if (fAttrLayout == 0) {
21 size_t seq_length = input[0][0];
22 size_t batch_size = input[0][1];
23 std::vector<std::vector<size_t>> ret(
24 {{seq_length, num_directions, batch_size, hidden_size},
25 {num_directions, batch_size, hidden_size}});
26 return ret;
27 } else {
28 size_t batch_size = input[0][0];
29 size_t seq_length = input[0][1];
30 std::vector<std::vector<size_t>> ret(
31 {{batch_size, seq_length, num_directions, hidden_size},
32 {batch_size, num_directions, hidden_size}});
33 return ret;
34 }
35}
36
37template <typename T>
39-> void {
40 fUseSession = model.UseSession();
41 // Check the input and output tensors
42 if (!model.CheckIfTensorAlreadyExist(fNX)) {
43 throw std::runtime_error("TMVA SOFIE RNN Op input tensor " + fNX +
44 " is not found in model.");
45 }
46 fShapeX = model.GetTensorShape(fNX);
47 if (fShapeX.size() != 3) {
48 throw std::runtime_error("TMVA SOFIE RNN Op input tensor " + fNX +
49 " is not of 3 dimensions.");
50 }
51 if (!model.CheckIfTensorAlreadyExist(fNW)) {
52 throw std::runtime_error("TMVA SOFIE RNN Op input tensor " + fNW +
53 " is not found in model.");
54 }
55 fShapeW = model.GetTensorShape(fNW);
56 if (fShapeW.size() != 3) {
57 throw std::runtime_error("TMVA SOFIE RNN Op input tensor " + fNW +
58 " is not of 3 dimensions.");
59 }
60 if (!model.CheckIfTensorAlreadyExist(fNR)) {
61 throw std::runtime_error("TMVA SOFIE RNN Op input tensor " + fNR +
62 " is not found in model.");
63 }
64 fShapeR = model.GetTensorShape(fNR);
65 if (fShapeR.size() != 3) {
66 throw std::runtime_error("TMVA SOFIE RNN Op input tensor " + fNR +
67 " is not of 3 dimensions.");
68 }
69 if (!fNB.empty()) {
70 if (!model.CheckIfTensorAlreadyExist(fNB)) {
71 throw std::runtime_error("TMVA SOFIE RNN op input tensor " + fNB +
72 " is not found in model.");
73 }
74 fShapeB = model.GetTensorShape(fNB);
75 if (fShapeB.size() != 2 && fShapeB.size() != 4) {
76 throw std::runtime_error("TMVA SOFIE RNN op input tensor " + fNB +
77 " is not of 2 or 4 dimensions.");
78 }
79 if (fShapeB.size() == 2) {
80 // Broadcasting the bias
81 auto original_data = model.GetInitializedTensorData(fNB);
82 size_t num_directions = fShapeW[0];
83 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
84 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
85 if (fType == "float") {
86 float *original_bias = static_cast<float *>(original_data.get());
87 float *new_bias = new float[num_directions * seq_length *
88 batch_size * fAttrHiddenSize];
89 float sum[fAttrHiddenSize];
90 for (size_t direction = 0; direction < num_directions;
91 direction++) {
92 for (size_t h = 0; h < fAttrHiddenSize; h++) {
93 sum[h] = original_bias[direction * 2 * fAttrHiddenSize + h] +
94 original_bias[(2 * direction + 1) * fAttrHiddenSize + h];
95 }
96 for (size_t seq = 0; seq < seq_length; seq++) {
97 for (size_t batch = 0; batch < batch_size; batch++) {
98 size_t bias_offset =
99 direction * seq_length * batch_size * fAttrHiddenSize +
100 seq * batch_size * fAttrHiddenSize + batch * fAttrHiddenSize;
101 std::copy(sum, sum + fAttrHiddenSize, new_bias + bias_offset);
102 }
103 }
104 }
105 std::vector<size_t> new_bias_shape = {num_directions, seq_length,
106 batch_size, fAttrHiddenSize};
107 std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<float[]>());
108 model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB),
109 new_bias_shape, new_bias_ptr);
110 fShapeB = model.GetTensorShape(fNB);
111 }
112 }
113 }
114 if (!fNSequence_lens.empty()) {
115 if (!model.CheckIfTensorAlreadyExist(fNSequence_lens)) {
116 throw std::runtime_error("TMVA SOFIE RNN Op input tensor " +
117 fNSequence_lens + "is not found in model.");
118 }
119 fShapeSequence_lens = model.GetTensorShape(fNSequence_lens);
120 if (fShapeSequence_lens.size() != 1) {
121 throw std::runtime_error("TMVA SOFIE RNN Op input tensor " +
122 fNSequence_lens + " is not of 1 dimension.");
123 }
124 }
125 if (!fNInitial_h.empty()) {
126 if (!model.CheckIfTensorAlreadyExist(fNInitial_h)) {
127 throw std::runtime_error("TMVA SOFIE RNN Op input tensor " +
128 fNInitial_h + " is not found in model.");
129 }
130 fShapeInitial_h = model.GetTensorShape(fNInitial_h);
131 if (fShapeInitial_h.size() != 3) {
132 throw std::runtime_error("TMVA SOFIE RNN Op input tensor " +
133 fNInitial_h + " is not of 3 dimensions.");
134 }
135 }
136 if (!fNY.empty()) {
137 fShapeY = ShapeInference({fShapeX, fShapeW})[0];
138 if (!model.CheckIfTensorAlreadyExist(fNY)) {
139 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
140 }
141 }
142 if (!fNY_h.empty()) {
143 fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
144 if (!model.CheckIfTensorAlreadyExist(fNY_h)) {
145 model.AddIntermediateTensor(fNY_h, model.GetTensorType(fNX),
146 fShapeY_h);
147 }
148 }
149 // Check the attributes
150 for (auto &activation : fAttrActivations) {
151 if (activation != "Relu" && activation != "Tanh" &&
152 activation != "Sigmoid" && activation != "Affine" &&
153 activation != "LeakyRelu" && activation != "ThresholdRelu" &&
154 activation != "ScaledTanh" && activation != "HardSigmoid" &&
155 activation != "Elu" && activation != "Softsign" &&
156 activation != "Softplus") {
157 throw std::runtime_error("TMVA SOFIE - Activation function " +
158 activation + " not implemented");
159 }
160 }
161 if (fAttrDirection != "forward" && fAttrDirection != "backward" &&
162 fAttrDirection != "bidirectional") {
163 throw std::runtime_error(
164 "TMVA SOFIE - Invalid RNN direction fAttrDirection = " +
165 fAttrDirection);
166 }
167 if (fAttrHiddenSize != fShapeW[1]) {
168 throw std::runtime_error(
169 "TMVA SOFIE - fAttrHiddenSize must be equal to " +
170 std::to_string(fShapeW[1]));
171 }
172 if (fAttrLayout > 1) {
173 throw std::runtime_error(
174 "TMVA SOFIE - Layout fAttrLayout = " + std::to_string(fAttrLayout) +
175 " must be 0 (timewise) or 1 (batchwise)");
176 }
177 if (fAttrActivations.empty()) {
178 if (fAttrDirection == "bidirectional") {
179 fAttrActivations = {"Tanh", "Tanh"};
180 } else {
181 fAttrActivations = {"Tanh"};
182 }
183 }
184 // Add needed standard library headers
185 model.AddNeededStdLib("cmath");
186}
187
188// generate code for Session data members (e.g. internal vectors)
189template <typename T>
190std::string ROperator_RNN<T>::GenerateSessionMembersCode(std::string opName)
191{
192 opName = "op_" + opName;
193 std::stringstream out;
194
195 size_t num_directions = fShapeW[0];
196 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
197 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
198 size_t input_size = fShapeX[2];
199
200 if (fAttrLayout != 0) {
201 out << "std::vector<" << fType << "> fVec_" << opName << "_input = std::vector<" << fType << ">("
202 << seq_length * batch_size * input_size << ");\n";
203 out << "std::vector<" << fType << "> fVec_" << opName << "_initial_hidden_state = std::vector<" << fType << ">("
204 << num_directions * batch_size * fAttrHiddenSize << ");\n";
205 }
206 out << "std::vector<" << fType << "> fVec_" << opName << "_feedforward = std::vector<" << fType << ">("
207 << seq_length * batch_size * fAttrHiddenSize << ");\n";
208
209 if (fAttrLayout != 0 || fNY.empty()) {
210 out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_state = std::vector<" << fType << ">("
211 << seq_length * num_directions * batch_size * fAttrHiddenSize << ");\n";
212 }
213
214 out << "\n";
215
216 return out.str();
217}
218
219//////////////////////////////////////////////////////////////////////////////////////////////////
220template<typename T>
221auto ROperator_RNN<T>::Generate(std::string OpName)
222-> std::string {
223 OpName = "op_" + OpName;
224 std::stringstream out;
225
226 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
227 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
228 size_t input_size = fShapeX[2];
229 size_t num_directions = fShapeW[0];
230
231 // set the input
232 if (fAttrLayout == 0) {
233 if (fType == "float") {
234 out << SP << "float *" << OpName << "_input = tensor_" << fNX << ";\n";
235 }
236 } else {
237 if (fUseSession)
238 out << SP << fType << " * " << OpName << "_input = fVec_" << OpName << "_input.data();\n";
239 else
240 out << SP << fType << " " << OpName << "_input[" << seq_length * batch_size * input_size << "];\n";
241 out << SP << "for(size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
242 out << SP << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
243 out << SP << SP << SP << "for(size_t i = 0; i < " << input_size << "; i++) {\n";
244 out << SP << SP << SP << SP << OpName << "_input[seq * " << batch_size * input_size
245 << " + batch * " << input_size << " + i] = " << "tensor_" << fNX << "[batch * "
246 << seq_length * input_size << " + seq * " << input_size << " + i];\n";
247 out << SP << SP << SP << "}\n";
248 out << SP << SP << "}\n";
249 out << SP << "}\n";
250 }
251
252 // Set the initial hidden state
253 if (!fNInitial_h.empty()) {
254 if (fAttrLayout == 0) {
255 out << SP << fType << " *" << OpName << "_initial_hidden_state = " << " tensor_"
256 << fNInitial_h << ";\n";
257 } else {
258 if (fUseSession)
259 out << SP << fType << " * " << OpName << "_initial_hidden_state = fVec_" << OpName
260 << "_initial_hidden_state.data();\n";
261 else
262 out << fType << " " << OpName << "_initial_hidden_state[" << num_directions * batch_size *
263 fAttrHiddenSize << "] = {0};\n";
264
265 for (size_t direction = 0; direction < num_directions; direction++) {
266 out << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
267 out << SP << SP << "for(size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
268 out << SP << SP << SP << OpName << "_initial_hidden_state["
269 << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
270 << " + h] = tensor_" << fNInitial_h << "[batch * " << num_directions * fAttrHiddenSize
271 << " + " << direction * fAttrHiddenSize << " + h];\n";
272 out << SP << SP << "}\n";
273 out << SP << "}\n";
274 }
275 }
276 }
277
278 if (fUseSession)
279 out << SP << fType << " * " << OpName << "_feedforward = fVec_" << OpName
280 << "_feedforward.data();\n";
281 else
282 out << SP << fType << " " << OpName << "_feedforward[" << seq_length * batch_size * fAttrHiddenSize << "] = {0};\n";
283
284 // Set the hidden state
285 if (fAttrLayout == 0 && !fNY.empty()) {
286 out << SP << fType << " *" << OpName << "_hidden_state = tensor_" << fNY << ";\n";
287 } else {
288 if (fUseSession)
289 out << SP << fType << " * " << OpName << "_hidden_state = fVec_" << OpName << "_hidden_state.data();\n";
290 else
291 out << SP << fType << " " << OpName << "_hidden_state[" << seq_length * num_directions *
292 batch_size * fAttrHiddenSize << "] = {0};\n";
293 }
294
295 out << SP << "char " << OpName << "_transA = 'N';\n";
296 out << SP << "char " << OpName << "_transB = 'T';\n";
297 out << SP << "int " << OpName << "_m = " << seq_length * batch_size << ";\n";
298 out << SP << "int " << OpName << "_n = " << fAttrHiddenSize << ";\n";
299 out << SP << "int " << OpName << "_k = " << input_size << ";\n";
300 if (fType == "float") {
301 out << SP << "float " << OpName << "_alpha = 1.;\n";
302 out << SP << "float " << OpName << "_beta = .0;\n";
303 }
304 if (!fNB.empty()) {
305 out << SP << "int " << OpName << "_bias_size = " << seq_length * batch_size * fAttrHiddenSize << ";\n";
306 out << SP << "int " << OpName << "_incx = 1;\n";
307 out << SP << "int " << OpName << "_incy = 1;\n";
308 }
309
310 for (size_t direction = 0; direction < num_directions; direction++) {
311 // feedforward = input * W^T + bias
312 if (fType == "float") {
313 if (direction == 0) {
314 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
315 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName
316 << "_alpha, tensor_" << fNW << ", &" << OpName << "_k, " << OpName
317 << "_input, &" << OpName << "_k, &" << OpName << "_beta, " << OpName
318 << "_feedforward, &" << OpName << "_n);\n";
319 } else {
320 out << SP << "size_t " << OpName << "_w_offset = " << fAttrHiddenSize * input_size
321 << ";\n";
322 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
323 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName
324 << "_alpha, tensor_" << fNW << " + " << OpName << "_w_offset, &" << OpName
325 << "_k, " << OpName << "_input, &" << OpName << "_k, &" << OpName << "_beta, "
326 << OpName << "_feedforward, &" << OpName << "_n);\n";
327 }
328 }
329 // Add the bias
330 if (!fNB.empty()) {
331 if (fType == "float") {
332 if (direction == 0) {
333 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
334 << fNB << ", &" << OpName << "_incx, " << OpName << "_feedforward, &" << OpName << "_incy);\n";
335 } else {
336 out << SP << "size_t " << OpName << "_bias_offset = "
337 << seq_length * batch_size * fAttrHiddenSize << ";\n";
338 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
339 << fNB << " + " << OpName << "_bias_offset, &" << OpName << "_incx, " << OpName
340 << "_feedforward, &" << OpName << "_incy);\n";
341 }
342 }
343 }
344
345 // Copy feedforward into hidden state
346 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
347 out << SP << SP << "size_t offset = seq * " << batch_size * fAttrHiddenSize << ";\n";
348 out << SP << SP << "size_t size = " << batch_size * fAttrHiddenSize << ";\n";
349 out << SP << SP << "size_t h_offset = seq * "
350 << num_directions * batch_size * fAttrHiddenSize << " + "
351 << direction * batch_size * fAttrHiddenSize << ";\n";
352 out << SP << SP << "std::copy(" << OpName << "_feedforward + offset, " << OpName
353 << "_feedforward + offset + size, " << OpName << "_hidden_state + h_offset);\n";
354 out << SP << "}\n";
355
356
357 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
358 if (fAttrDirection == "backward" || direction == 1) {
359 out << SP << SP << "size_t index = " << seq_length - 1 << " - seq;\n";
360 } else {
361 out << SP << SP << "size_t index = seq;\n";
362 }
363
364 out << SP << SP << "int m2 = " << batch_size << ";\n";
365 out << SP << SP << "size_t offset = index * "
366 << num_directions * batch_size * fAttrHiddenSize << " + "
367 << direction * batch_size * fAttrHiddenSize << ";\n";
368 out << SP << SP << "size_t size = " << batch_size * fAttrHiddenSize << ";\n";
369 out << SP << SP << "if (seq == 0) {\n";
370 if (!fNInitial_h.empty()) {
371 // hidden_state = hidden_state + initial_hidden_state * R^T
372 out << SP << SP << SP << "size_t r_offset = "
373 << direction * fAttrHiddenSize * fAttrHiddenSize << ";\n";
374 out << SP << SP << SP << "size_t initial_hidden_state_offset = "
375 << direction * batch_size * fAttrHiddenSize << ";\n";
376 if (fType == "float") {
377 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName
378 << "_transA, &" << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName
379 << "_alpha, tensor_" << fNR << " + r_offset, &" << OpName << "_n, " << OpName
380 << "_initial_hidden_state + initial_hidden_state_offset, &" << OpName << "_n, &"
381 << OpName << "_alpha, " << OpName << "_hidden_state + offset, &" << OpName << "_n);\n";
382 }
383 }
384 out << SP << SP << "} else {\n";
385 // hidden_state = hidden_state + previous_hidden_state * R^T
386 out << SP << SP << SP << "size_t r_offset = "
387 << direction * fAttrHiddenSize * fAttrHiddenSize << ";\n";
388 if (fAttrDirection == "backward" || direction == 1) {
389 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
390 << num_directions * batch_size * fAttrHiddenSize
391 << " + " << direction * batch_size * fAttrHiddenSize << ";\n";
392 } else {
393 out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
394 << num_directions * batch_size * fAttrHiddenSize
395 << " + " << direction * batch_size * fAttrHiddenSize << ";\n";
396 }
397 if (fType == "float") {
398 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
399 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR
400 << " + r_offset, &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
401 << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_hidden_state + offset, &"
402 << OpName << "_n);\n";
403 }
404 out << SP << SP << "}\n";
405
406 // Clip the elements of the hidden state into the range [-fAttrClip, fAttrClip]
407 if (fAttrClip > .0) {
408 out << SP << SP << "for (size_t i = offset; i < offset + size; i++) {\n";
409 if (fType == "float") {
410 out << SP << SP << SP << "float x = (" << OpName << "_hidden_state[i] > " << -fAttrClip
411 << ") ? " << OpName << "_hidden_state[i] : " << -fAttrClip << ";\n";
412 }
413 out << SP << SP << SP << OpName << "_hidden_state[i] = (x < " << fAttrClip
414 << ") ? x : " << fAttrClip << ";\n";
415 out << SP << SP << "}\n";
416 }
417
418 // Apply the activation function to the hidden state
419 if (fAttrActivations[direction] == "Relu") {
420 out << SP << SP << "for (size_t i = offset; i < offset + size; i++) {\n";
421 out << SP << SP << SP << "if (" << OpName << "_hidden_state[i] < 0.)\n";
422 out << SP << SP << SP << SP << OpName << "_hidden_state[i] = 0.;\n";
423 out << SP << SP << "}\n";
424 } else if (fAttrActivations[direction] == "Tanh") {
425 out << SP << SP << "for (size_t i = offset; i < offset + size; i++) {\n";
426 if (fType == "float") {
427 out << SP << SP << SP << "float ex = std::exp(-2 * " << OpName << "_hidden_state[i]);\n";
428 }
429 out << SP << SP << SP << SP << OpName << "_hidden_state[i] = (1. - ex) / (1. + ex);\n";
430 out << SP << SP << "}\n";
431 } else if (fAttrActivations[direction] == "Sigmoid") {
432 out << SP << SP << "for (size_t i = offset; i < offset + size; i++) {\n";
433 out << SP << SP << SP << SP << OpName << "_hidden_state[i] = 1. / (1. + std::exp(-" << OpName
434 << "_hidden_state[i]));\n";
435 out << SP << SP << "}\n";
436 } else if (fAttrActivations[direction] == "Affine") {
437 out << SP << SP << "for (size_t i = offset; i < offset + size; i++) {\n";
438 out << SP << SP << SP << SP << OpName << "_hidden_state[i] = " << fAttrActivationAlpha[direction]
439 << " * " << OpName << "_hidden_state[i] + " << fAttrActivationBeta[direction] << ";\n";
440 out << SP << SP << "}\n";
441 } else if (fAttrActivations[direction] == "ScaledTanh") {
442 out << SP << SP << "for (size_t i = offset; i < offset + size; i++) {\n";
443 if (fType == "float") {
444 out << SP << SP << SP << "float ex = std::exp(-2 * " << fAttrActivationBeta[direction]
445 << " * "<< OpName << "_hidden_state[i]);\n";
446 }
447 out << SP << SP << SP << SP << OpName << "_hidden_state[i] = " << fAttrActivationAlpha[direction]
448 << " * (1. - ex) / (1. + ex);\n";
449 out << SP << SP << "}\n";
450 } else if (fAttrActivations[direction] == "HardSigmoid") {
451 out << SP << SP << "for (size_t i = offset; i < offset + size; i++) {\n";
452 if (fType == "float") {
453 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction] << " * "
454 << OpName << "_hidden_state[i] + " << fAttrActivationBeta[direction] << ";\n";
455 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
456 }
457 out << SP << SP << SP << SP << OpName << "_hidden_state[i] = (b < 1.) ? b : 1.;\n";
458 out << SP << SP << "}\n";
459 } else if (fAttrActivations[direction] == "LeakyRelu") {
460 out << SP << SP << "for (size_t i = offset; i < offset + size; i++) {\n";
461 out << SP << SP << SP << "if (" << OpName << "_hidden_state[i] < 0.)\n";
462 out << SP << SP << SP << SP << OpName << "_hidden_state[i] = " << fAttrActivationAlpha[direction]
463 << " * " << OpName << "_hidden_state[i];\n";
464 out << SP << SP << "}\n";
465 } else if (fAttrActivations[direction] == "ThresholdRelu") {
466 out << SP << SP << "for (size_t i = offset; i < offset + size; i++) {\n";
467 out << SP << SP << SP << "if (" << OpName << "_hidden_state[i] < "
468 << fAttrActivationAlpha[direction] << ")\n";
469 out << SP << SP << SP << SP << OpName << "_hidden_state[i] = 0.;\n";
470 out << SP << SP << "}";
471 } else if (fAttrActivations[direction] == "Elu") {
472 out << SP << SP << "for (size_t i = offset; i < offset + size; i++) {\n";
473 out << SP << SP << SP << "if (" << OpName << "_hidden_state[i] < 0.)\n";
474 out << SP << SP << SP << SP << OpName << "_hidden_state[i] = " << fAttrActivationAlpha[direction]
475 << " * std::exp(" << OpName << "_hidden_state[i] - 1.);\n";
476 out << SP << SP << "}\n";
477 } else if (fAttrActivations[direction] == "Softsign") {
478 out << SP << SP << "for (size_t i = offset; i < offset + size; i++) {\n";
479 out << SP << SP << SP << SP << OpName << "_hidden_state[i] = " << OpName
480 << "_hidden_state[i] / (1. + abs(" << OpName << "_hidden_state[i]));\n";
481 out << SP << SP << "}\n";
482 } else { // fAttrActivations[direction] = Softplus
483 out << SP << SP << "for (size_t i = offset; i < offset + size; i++) {\n";
484 out << SP << SP << SP << SP << OpName << "_hidden_state[i] = log(1. + std::exp("
485 << OpName << "_hidden_state[i]));\n";
486 out << SP << SP << "}\n";
487 out << SP << "}\n";
488 }
489 out << SP << "}\n";
490 }
491
492 // Padding the hidden state for RNN with different sequence lengths
493 if (!fNSequence_lens.empty()) {
494 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
495 out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
496 out << SP << SP << SP << "if (seq >= tensor_" << fNSequence_lens << "[batch]) {\n";
497 out << SP << SP << SP << SP << "for (size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
498 if (num_directions == 1) {
499 out << SP << SP << SP << SP << SP << OpName << "_hidden_state[seq * "
500 << num_directions * batch_size * fAttrHiddenSize << " + batch * "
501 << fAttrHiddenSize << " + h] = 0.;\n";
502 } else {
503 out << SP << SP << SP << SP << SP << OpName << "_hidden_state[seq * "
504 << num_directions * batch_size * fAttrHiddenSize << " + batch * "
505 << fAttrHiddenSize << " + h] = 0.;\n";
506 out << SP << SP << SP << SP << SP << OpName << "_hidden_state[seq * "
507 << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize
508 << " + batch * " << fAttrHiddenSize << " + h] = 0.;\n";
509 }
510 out << SP << SP << SP << SP << "}\n";
511 out << SP << SP << SP << "}\n";
512 out << SP << SP << "}\n";
513 out << SP << "}\n";
514 }
515
516 // Copy the hidden state into y and y_h
517 if (fAttrLayout == 0) {
518 if (!fNY_h.empty()) {
519 if (fNSequence_lens.empty()) {
520 size_t yh_size = batch_size * fAttrHiddenSize;
521 if (fAttrDirection == "backward") {
522 out << SP << "std::copy(" << OpName << "_hidden_state, " << OpName << "_hidden_state + "
523 << yh_size << ", tensor_" << fNY_h << ");\n";
524 } else {
525 size_t offset = (seq_length - 1) * num_directions * batch_size * fAttrHiddenSize;
526 out << SP << "std::copy(" << OpName << "_hidden_state + " << offset << ", " << OpName
527 << "_hidden_state + " << offset << " + " << yh_size << ", tensor_" << fNY_h << ");\n";
528 }
529 if (num_directions == 2) {
530 out << SP << "std::copy(" << OpName << "_hidden_state + " << yh_size << ", " << OpName
531 << "_hidden_state + " << 2 * yh_size << ", tensor_" << fNY_h << " + " << yh_size << ");\n";
532 }
533 } else { // RNN with different sequence lengths
534 if (fAttrDirection == "backward") {
535 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
536 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
537 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
538 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + offset);\n";
539 out << SP << "}\n";
540 } else {
541 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
542 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
543 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
544 << " + batch * " << fAttrHiddenSize << ";\n";
545 out << SP << SP << "size_t yh_offset = batch * " << fAttrHiddenSize << ";\n";
546 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
547 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
548 out << SP << "}\n";
549 }
550 if (num_directions == 2) {
551 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
552 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize
553 << " + batch * " << fAttrHiddenSize << ";\n";
554 out << SP << SP << "size_t yh_offset = " << batch_size * fAttrHiddenSize
555 << " + batch * " << fAttrHiddenSize << ";\n";
556 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
557 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
558 out << SP << "}\n";
559 }
560 }
561 }
562 } else { // fAttrLayout=1
563 if (!fNY.empty()) {
564 for (size_t direction = 0; direction < num_directions; direction++) {
565 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
566 out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
567 out << SP << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
568 << " + " << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize << ";\n";
569 out << SP << SP << SP << "size_t y_offset = batch * " << seq_length * num_directions * fAttrHiddenSize
570 << " + seq * " << num_directions * fAttrHiddenSize << " + " << direction * fAttrHiddenSize << ";\n";
571 out << SP << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
572 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY << " + y_offset);\n";
573 out << SP << SP << "}\n";
574 out << SP << "}\n";
575 }
576 }
577 if (!fNY_h.empty()) {
578 if (fAttrDirection == "backward") {
579 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
580 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
581 out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
582 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
583 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
584 out << SP << "}\n";
585 } else {
586 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
587 if (fNSequence_lens.empty()) {
588 out << SP << SP << "size_t seq = " << seq_length - 1 << ";\n";
589 } else {
590 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
591 }
592 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
593 << " + batch * " << fAttrHiddenSize << ";\n";
594 out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
595 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
596 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
597 out << SP << "}\n";
598 }
599 if (num_directions == 2) {
600 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
601 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * "
602 << fAttrHiddenSize << ";\n";
603 out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << " + "
604 << fAttrHiddenSize << ";\n";
605 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
606 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
607 out << SP << "}\n";
608 }
609 }
610 }
611
612 return out.str();
613}
614
615} // namespace SOFIE
616} // namespace Experimental
617} // namespace TMVA
618
619#endif
#define h(i)
Definition RSha256.hxx:106
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input)
Infers the type of the output tensors.
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input)
Infers the shape of the output tensors.
std::string Generate(std::string OpName)
Generates the inference code.
void Initialize(RModel &model)
Initialize the model.
std::string GenerateSessionMembersCode(std::string opName)
Generates the declarations of the session data members (internal vectors) used by the operator.
static uint64_t sum(uint64_t i)
Definition Factory.cxx:2345