// InitializeRecurrentTensors: re-wraps the recurrent layer's tensors with the tensor layout
// returned by GetTensorLayout(), reusing the existing device buffers where shown:
//  - the output buffer (assignment target not visible in this listing) and the activation
//    gradients become {batchSize, timeSteps, stateSize}; timeSteps is 1 unless the layer
//    returns the full sequence;
//  - each weight becomes a 2-D {nrows, ncols} tensor, each bias a {stateSize, 1} tensor;
//  - X and DX are constructed as {timeSteps, batchSize, inputSize}, Y and DY as
//    {timeSteps, batchSize, stateSize} (time index first).
// NOTE(review): several original source lines are missing from this listing (e.g. the
// `layer->GetOutput() =` target, trailing constructor arguments, closing braces); these
// comments describe only the visible code.
24template <
typename AFloat>
25template <
typename RNNLayer>
26void TCudnn<AFloat>::InitializeRecurrentTensors(RNNLayer *layer)
// Only a single time step is kept when the layer does not return the full sequence.
29 size_t timeSteps = (layer->DoesReturnSequence()) ? layer->GetTimeSteps() : 1;
31 Tensor_t(layer->GetOutput().GetDeviceBuffer(),
32 {layer->GetBatchSize(), timeSteps, layer->GetStateSize()}, GetTensorLayout());
33 layer->GetActivationGradients() =
34 Tensor_t(layer->GetActivationGradients().GetDeviceBuffer(), {layer->GetBatchSize(), timeSteps, layer->GetStateSize()},
// Re-wrap every weight matrix as a 2-D tensor over its existing device buffer.
38 for (
size_t i = 0; i < layer->GetWeights().
size(); ++i) {
39 auto &
w = layer->GetWeightsAt(i);
41 w = Tensor_t(layer->GetWeightsAt(i).GetDeviceBuffer(), {layer->GetWeightsAt(i).GetNrows(), layer->GetWeightsAt(i).GetNcols()},
// Re-wrap every bias as a {stateSize, 1} column over its existing device buffer.
45 for (
size_t i = 0; i < layer->GetBiases().
size(); ++i) {
48 auto &
b = layer->GetBiasesAt(i);
49 b = Tensor_t(layer->GetBiasesAt(i).GetDeviceBuffer(), {layer->GetStateSize(), 1}, GetTensorLayout(), 0, 0);
// Freshly constructed (not aliasing existing buffers) input/output tensors and their
// gradients, with the time index as the leading dimension.
60 layer->GetX() = Tensor_t({layer->GetTimeSteps(), layer->GetBatchSize(), layer->GetInputSize() }, GetTensorLayout());
61 layer->GetY() = Tensor_t({layer->GetTimeSteps(), layer->GetBatchSize(), layer->GetStateSize() }, GetTensorLayout());
63 layer->GetDX() = Tensor_t({layer->GetTimeSteps(), layer->GetBatchSize(), layer->GetInputSize() }, GetTensorLayout());
64 layer->GetDY() = Tensor_t({layer->GetTimeSteps(), layer->GetBatchSize(), layer->GetStateSize() }, GetTensorLayout());
// InitializeRecurrentDescriptors: creates and configures the cuDNN descriptors for a
// recurrent layer and returns them through `descriptors`.
// Visible steps:
//  1. create the RNN descriptor and a dropout descriptor (dropout probability 0, seed taken
//     from GetRandomGenerator());
//  2. pick the cuDNN cell mode from the layer type (LSTMLayer_t -> CUDNN_LSTM,
//     GRULayer_t -> CUDNN_GRU, anything else -> CUDNN_RNN_TANH);
//  3. configure the RNN descriptor: cudnnSetRNNDescriptor_v8 plus RNN data descriptors for
//     cuDNN >= 8, cudnnSetRNNDescriptor + cudnnSetRNNBiasMode otherwise; per-time-step
//     tensor descriptors for x/y/dx/dy are also created;
//  4. replace the layer's weight and weight-gradient tensors by flat buffers sized by
//     cuDNN, copy the current weights/biases into the slots cuDNN reports, and re-point the
//     layer's weight/bias tensors (and their gradients) into those flat buffers.
// @param descriptors  out: set to the RNNDescriptors_t allocated with `new` below
//                     (presumably released via ReleaseRNNDescriptors — TODO confirm owner).
// @param layer        recurrent layer supplying sizes, weights and biases.
// NOTE(review): numLayers, dimA/strideA, dimW, nbDims, offset and the #else/#endif lines are
// on original source lines missing from this listing.
67template <
typename AFloat>
68template <
typename RNNLayer>
69void TCudnn<AFloat>::InitializeRecurrentDescriptors(TDescriptors *&descriptors, RNNLayer *layer)
72 auto rnnDescriptors =
new RNNDescriptors_t ();
73 CUDNNCHECK(cudnnCreateRNNDescriptor(&rnnDescriptors->LayerDescriptor));
75 CUDNNCHECK(cudnnCreateDropoutDescriptor(&rnnDescriptors->HelperDescriptor));
// Map the C++ layer type onto the cuDNN cell mode chosen further below.
77 enum RNNType {kRNN, kLSTM, kGRU};
78 RNNType rnn_type = kRNN;
79 if ( std::is_same<RNNLayer, LSTMLayer_t>::value ) rnn_type = kLSTM;
80 if ( std::is_same<RNNLayer, GRULayer_t>::value ) rnn_type = kGRU;
82 cudnnHandle_t handle = layer->GetOutput().GetCudnnHandle();
// Dropout is configured with probability 0; on the visible lines dropoutStates stays nullptr.
83 float dropoutProb = 0.0;
85 void *dropoutStates =
nullptr;
86 size_t dropoutStateSize = 0;
89 CUDNNCHECK(cudnnDropoutGetStatesSize(handle, &dropoutStateSize));
93 unsigned long long seed = GetRandomGenerator().GetSeed();
95 CUDNNCHECK(cudnnSetDropoutDescriptor(rnnDescriptors->HelperDescriptor, handle, dropoutProb, dropoutStates,
96 dropoutStateSize, seed));
104 int inputSize = layer->GetInputSize();
105 int hiddenSize = layer->GetStateSize();
108 cudnnRNNInputMode_t inputMode = CUDNN_LINEAR_INPUT;
// Only unidirectional RNNs are configured, so `bidirectional` is always false here.
110 cudnnDirectionMode_t direction = CUDNN_UNIDIRECTIONAL;
111 bool bidirectional = (direction == CUDNN_BIDIRECTIONAL);
113 cudnnRNNMode_t
mode = CUDNN_RNN_TANH;
114 if (rnn_type == kLSTM)
mode = CUDNN_LSTM;
115 if (rnn_type == kGRU)
mode = CUDNN_GRU;
117 cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD;
// numLinearLayers is assigned per mode on lines missing from this listing; the assert
// requires it to match the number of weight tensors held by the layer.
120 int numLinearLayers = 0;
121 if (
mode == CUDNN_RNN_RELU ||
mode == CUDNN_RNN_TANH) {
124 if (
mode == CUDNN_GRU ) {
127 if (
mode == CUDNN_LSTM) {
131 assert(numLinearLayers == layer->GetWeights().size());
// Compute precision follows AFloat (float or double).
133 cudnnDataType_t mathPrec = CUDNN_DATA_FLOAT;
134 if (std::is_same<AFloat, double>::value) { mathPrec = CUDNN_DATA_DOUBLE;}
// A single input bias is requested only when the layer has biases; otherwise no bias.
137 cudnnRNNBiasMode_t biasMode = CUDNN_RNN_NO_BIAS;
138 if (layer->GetBiases().size() > 0)
139 biasMode = CUDNN_RNN_SINGLE_INP_BIAS;
143 cudnnDataType_t dataType = mathPrec;
144 int projSize = hiddenSize;
147 int seqLength = layer->GetTimeSteps();
150#if (CUDNN_VERSION >= 8000)
151 unsigned int auxFlags = CUDNN_RNN_PADDED_IO_ENABLED;
152 cudnnMathType_t mathType = CUDNN_DEFAULT_MATH;
154 CUDNNCHECK(cudnnSetRNNDescriptor_v8(rnnDescriptors->LayerDescriptor, algo,
mode, biasMode, direction,
155 inputMode, dataType, mathPrec, mathType, inputSize, hiddenSize, projSize, numLayers,
156 rnnDescriptors->HelperDescriptor, auxFlags));
// cuDNN >= 8 data descriptors: sequence-major, unpacked; every batch entry is declared to
// have the full length seqLength.
158 CUDNNCHECK(cudnnCreateRNNDataDescriptor(&rnnDescriptors->xDataDesc));
159 CUDNNCHECK(cudnnCreateRNNDataDescriptor(&rnnDescriptors->yDataDesc));
161 std::vector<int> seqLengthArray(layer->GetBatchSize(), seqLength);
162 int vectorSize = inputSize;
164 cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
165 AFloat paddingFill = 0;
166 CUDNNCHECK(cudnnSetRNNDataDescriptor(rnnDescriptors->xDataDesc, dataType, layout, seqLength,
167 layer->GetBatchSize(), vectorSize, seqLengthArray.data(), &paddingFill));
169 vectorSize = bidirectional ? hiddenSize * 2 : hiddenSize;
170 CUDNNCHECK(cudnnSetRNNDataDescriptor(rnnDescriptors->yDataDesc, dataType, layout, seqLength,
171 layer->GetBatchSize(), vectorSize, seqLengthArray.data(), &paddingFill));
// Pre-v8 descriptor setup (the #else separating it from the v8 branch is not visible here).
174 CUDNNCHECK(cudnnSetRNNDescriptor(handle, rnnDescriptors->LayerDescriptor, hiddenSize, numLayers, rnnDescriptors->HelperDescriptor, inputMode, direction,
mode, algo, mathPrec) );
176 CUDNNCHECK(cudnnSetRNNBiasMode(rnnDescriptors->LayerDescriptor, biasMode));
// One 3-D tensor descriptor per time step: x/dx are {batch, inputSize, ...} and y/dy are
// {batch, hiddenSize (x2 if bidirectional), ...}; dimA[2] is set on lines not shown here.
// These are used by the pre-v8 calls (e.g. cudnnGetRNNParamsSize below).
185 rnnDescriptors->xDesc.resize(seqLength);
186 rnnDescriptors->yDesc.resize(seqLength);
187 rnnDescriptors->dxDesc.resize(seqLength);
188 rnnDescriptors->dyDesc.resize(seqLength);
189 TensorDescriptor_t *xDesc = rnnDescriptors->xDesc.data();
190 TensorDescriptor_t *yDesc = rnnDescriptors->yDesc.data();
191 TensorDescriptor_t *dxDesc = rnnDescriptors->dxDesc.data();
192 TensorDescriptor_t *dyDesc = rnnDescriptors->dyDesc.data();
194 for (
int i = 0; i < seqLength; i++) {
195 CUDNNCHECK(cudnnCreateTensorDescriptor(&xDesc[i]));
196 CUDNNCHECK(cudnnCreateTensorDescriptor(&yDesc[i]));
197 CUDNNCHECK(cudnnCreateTensorDescriptor(&dxDesc[i]));
198 CUDNNCHECK(cudnnCreateTensorDescriptor(&dyDesc[i]));
200 dimA[0] = layer->GetBatchSize();
201 dimA[1] = layer->GetInputSize();
204 strideA[0] = dimA[2] * dimA[1];
205 strideA[1] = dimA[2];
208 CUDNNCHECK(cudnnSetTensorNdDescriptor(xDesc[i], mathPrec, 3, dimA, strideA));
209 CUDNNCHECK(cudnnSetTensorNdDescriptor(dxDesc[i], mathPrec, 3, dimA, strideA));
211 dimA[0] = layer->GetBatchSize();
212 dimA[1] = bidirectional ? hiddenSize * 2 : hiddenSize;
215 strideA[0] = dimA[2] * dimA[1];
216 strideA[1] = dimA[2];
219 CUDNNCHECK(cudnnSetTensorNdDescriptor(yDesc[i], mathPrec, 3, dimA, strideA));
220 CUDNNCHECK(cudnnSetTensorNdDescriptor(dyDesc[i], mathPrec, 3, dimA, strideA));
// Size of the flat weight buffer as reported by cuDNN (bytes); converted below into an
// element count stored in dimW[0].
229 size_t weightsSize = 0;
230#if (CUDNN_VERSION >= 8000)
231 size_t weightSpaceSize = 0;
232 CUDNNCHECK(cudnnGetRNNWeightSpaceSize(handle, rnnDescriptors->LayerDescriptor, &weightSpaceSize));
234 weightsSize = weightSpaceSize;
238 CUDNNCHECK(cudnnCreateFilterDescriptor(&rnnDescriptors->WeightsDescriptor));
239 CUDNNCHECK(cudnnCreateFilterDescriptor(&rnnDescriptors->WeightsGradDescriptor));
241 CUDNNCHECK(cudnnGetRNNParamsSize(handle, rnnDescriptors->LayerDescriptor, xDesc[0], &weightsSize, mathPrec));
245 dimW[0] = (mathPrec == CUDNN_DATA_DOUBLE) ? weightsSize /
sizeof(
double) : weightsSize /
sizeof(float);
// The layer's weight / weight-gradient tensors are replaced by flat buffers of dimW[0]
// elements: 1-D for cuDNN >= 8, {n, 1, 1} with matching filter descriptors otherwise.
249 auto &weightTensor = layer->GetWeightsTensor();
250 auto &weightGradTensor = layer->GetWeightGradientsTensor();
252#if (CUDNN_VERSION >= 8000)
255 weightTensor = Tensor_t( { (size_t) dimW[0]}, GetTensorLayout(), 0, 0);
256 weightGradTensor = Tensor_t({(size_t) dimW[0]}, GetTensorLayout(), 0, 0);
261 weightTensor = Tensor_t( { (size_t) dimW[0], 1, 1}, GetTensorLayout(), 0, 0);
262 weightGradTensor = Tensor_t({(size_t) dimW[0], 1, 1}, GetTensorLayout(), 0, 0);
264 CUDNNCHECK(cudnnSetFilterNdDescriptor(rnnDescriptors->WeightsDescriptor, mathPrec, CUDNN_TENSOR_NCHW, 3, dimW));
265 CUDNNCHECK(cudnnSetFilterNdDescriptor(rnnDescriptors->WeightsGradDescriptor, mathPrec, CUDNN_TENSOR_NCHW, 3, dimW));
// For every layer and linear layer, ask cuDNN where the matrix and bias live inside the
// flat weight buffer, then copy the layer's current values there.
273 int nL = (!bidirectional) ? numLayers : 2 * numLayers;
274 for (
int ilayer = 0; ilayer < nL; ilayer++) {
275 for (
int linLayerID = 0; linLayerID < numLinearLayers; linLayerID++) {
277 AFloat *linLayerMat =
nullptr;
278 AFloat *linLayerBias =
nullptr;
281#if (CUDNN_VERSION >= 8000)
283 cudnnTensorDescriptor_t linLayerMatDesc;
284 CUDNNCHECK(cudnnCreateTensorDescriptor(&linLayerMatDesc));
285 cudnnTensorDescriptor_t linLayerBiasDesc;
286 CUDNNCHECK(cudnnCreateTensorDescriptor(&linLayerBiasDesc));
287 CUDNNCHECK(cudnnGetRNNWeightParams(handle, rnnDescriptors->LayerDescriptor, ilayer, weightSpaceSize, weightTensor.GetDataPointer(),
288 linLayerID, linLayerMatDesc, (
void **)&linLayerMat, linLayerBiasDesc, (
void **)&linLayerBias));
293 cudnnFilterDescriptor_t linLayerMatDesc;
294 CUDNNCHECK(cudnnCreateFilterDescriptor(&linLayerMatDesc));
295 cudnnFilterDescriptor_t linLayerBiasDesc;
296 CUDNNCHECK(cudnnCreateFilterDescriptor(&linLayerBiasDesc));
298 CUDNNCHECK(cudnnGetRNNLinLayerMatrixParams(handle, rnnDescriptors->LayerDescriptor, ilayer, rnnDescriptors->xDesc.data()[0],
299 rnnDescriptors->WeightsDescriptor, weightTensor.GetDataPointer(),
300 linLayerID, linLayerMatDesc, (
void **)&linLayerMat));
302 CUDNNCHECK(cudnnGetRNNLinLayerBiasParams(handle, rnnDescriptors->LayerDescriptor, ilayer,
303 rnnDescriptors->xDesc.data()[0], rnnDescriptors->WeightsDescriptor,
304 weightTensor.GetDataPointer(), linLayerID, linLayerBiasDesc,
305 (
void **)&linLayerBias));
310 cudnnDataType_t dataType;
312 int filterDimA[3] = {0,0,0};
314#if (CUDNN_VERSION >= 8000)
316 CUDNNCHECK(cudnnGetTensorNdDescriptor(linLayerMatDesc, 3, &dataType, &nbDims, filterDimA, strideA));
318 cudnnTensorFormat_t
format;
319 CUDNNCHECK(cudnnGetFilterNdDescriptor(linLayerMatDesc, 3, &dataType, &
format, &nbDims, filterDimA));
// Copy this linear layer's current weights into the slot cuDNN reserved for it
// (asynchronous device-to-device copy on the weight's compute stream).
335 int wsize = layer->GetWeightsAt(linLayerID).GetSize();
346 cudaMemcpyAsync(linLayerMat, layer->GetWeightsAt(linLayerID).GetDataPointer(), wsize *
sizeof(AFloat),
347 cudaMemcpyDeviceToDevice, layer->GetWeightsAt(linLayerID).GetComputeStream());
// NOTE(review): biasMode is only ever CUDNN_RNN_NO_BIAS or CUDNN_RNN_SINGLE_INP_BIAS on the
// visible lines, so this SINGLE_REC_BIAS index remapping is never taken here.
355 int biasID = linLayerID;
356 if (biasMode == CUDNN_RNN_SINGLE_REC_BIAS) {
359 biasID = linLayerID - 1;
360 if (
mode == CUDNN_LSTM) biasID = linLayerID - 4;
361 if (
mode == CUDNN_GRU) biasID = linLayerID - 3;
365#if (CUDNN_VERSION >= 8000)
367 CUDNNCHECK(cudnnGetTensorNdDescriptor(linLayerBiasDesc, 3, &dataType, &nbDims, filterDimA, strideA));
369 CUDNNCHECK(cudnnGetFilterNdDescriptor(linLayerBiasDesc, 3, &dataType, &
format, &nbDims, filterDimA));
// Only copy a bias when cuDNN reports a non-empty bias slot for this linear layer.
373 if (filterDimA[0] > 0) {
379 int wsize = layer->GetBiasesAt(biasID).GetSize();
390 assert(wsize == filterDimA[1]);
391 cudaMemcpyAsync(linLayerBias, layer->GetBiasesAt(biasID).GetDataPointer(), wsize *
sizeof(AFloat),
392 cudaMemcpyDeviceToDevice, layer->GetBiasesAt(biasID).GetComputeStream());
// cuDNN >= 8: look up the matching slots in the weight-gradient buffer, then re-point the
// layer's weight/bias tensors and their gradients at those slices of the flat buffers.
401#if (CUDNN_VERSION >= 8000)
405 AFloat *bGradOffset =
nullptr;
406 AFloat *wGradOffset =
nullptr;
407 CUDNNCHECK(cudnnGetRNNWeightParams(handle, rnnDescriptors->LayerDescriptor, ilayer, weightSpaceSize, weightGradTensor.GetDataPointer(),
408 linLayerID, linLayerMatDesc, (
void **)&wGradOffset, linLayerBiasDesc, (
void **)&bGradOffset));
413 if (linLayerMat && wGradOffset) {
414 auto &
w = layer->GetWeightsAt(linLayerID);
415 auto & dw = layer->GetWeightGradientsAt(linLayerID);
416 w = Tensor_t( TCudaDeviceBuffer<AFloat>(linLayerMat,
w.GetSize(),
w.GetComputeStream()),
w.GetShape(), GetTensorLayout(), 0, 0);
417 dw = Tensor_t(TCudaDeviceBuffer<AFloat>(wGradOffset, dw.GetSize(), dw.GetComputeStream()), dw.GetShape(), GetTensorLayout(), 0, 0);
419 if (linLayerBias && bGradOffset) {
420 auto &
b = layer->GetBiasesAt(biasID);
421 auto &db = layer->GetBiasGradientsAt(biasID);
422 b = Tensor_t(TCudaDeviceBuffer<AFloat>(linLayerBias,
b.GetSize(),
b.GetComputeStream()),
b.GetShape(), GetTensorLayout(), 0, 0);
423 db = Tensor_t(TCudaDeviceBuffer<AFloat>(bGradOffset, db.GetSize(), db.GetComputeStream()), db.GetShape(), GetTensorLayout(), 0, 0);
432#if (CUDNN_VERSION >= 8000)
435 CUDNNCHECK(cudnnDestroyFilterDescriptor(linLayerMatDesc));
436 CUDNNCHECK(cudnnDestroyFilterDescriptor(linLayerBiasDesc));
// Pre-v8 path: check that each weight/bias value sits at the expected `offset` in the flat
// buffer (printing an error otherwise), then re-point the layer's tensors and gradients at
// sub-buffers of the flat weight / weight-gradient buffers. `offset` is declared and
// advanced on lines missing from this listing.
451#if (CUDNN_VERSION < 8000)
453 for (
size_t i = 0; i < layer->GetWeights().
size(); ++i) {
454 auto &
w = layer->GetWeightsAt(i);
455 auto & dw = layer->GetWeightGradientsAt(i);
456 if (weightTensor(
offset, 0, 0) !=
w(0, 0))
457 std::cerr <<
"Error - different offset for weight " << i << std::endl;
460 w = Tensor_t(weightTensor.GetDeviceBuffer().GetSubBuffer(
offset,
w.GetSize()),
w.GetShape(),
461 GetTensorLayout(), 0, 0);
462 dw = Tensor_t(weightGradTensor.GetDeviceBuffer().GetSubBuffer(
offset,
w.GetSize()),
w.GetShape(), GetTensorLayout(), 0, 0);
467 for (
size_t i = 0; i < layer->GetBiases().
size(); ++i) {
468 auto &
b = layer->GetBiasesAt(i);
469 auto &db = layer->GetBiasGradientsAt(i);
470 if (weightTensor(
offset, 0, 0) !=
b(0, 0))
471 std::cerr <<
"Error - different offset for bias " << i << std::endl;
474 b = Tensor_t(weightTensor.GetDeviceBuffer().GetSubBuffer(
offset,
b.GetSize()),
b.GetShape(), GetTensorLayout(), 0, 0);
475 db = Tensor_t(weightGradTensor.GetDeviceBuffer().GetSubBuffer(
offset,
b.GetSize()),
b.GetShape(), GetTensorLayout(), 0,
// Hand the configured descriptors back to the caller.
483 descriptors = rnnDescriptors;
// ReleaseRNNDescriptors: destroys the cuDNN objects held by an RNNDescriptors_t: the RNN
// descriptor, the dropout (helper) descriptor, the cuDNN >= 8 RNN data descriptors, the
// weight / weight-gradient filter descriptors and the per-time-step tensor descriptors.
// NOTE(review): whether `descriptors` itself is deleted, and which #if branch the filter
// descriptor releases belong to, is on lines missing from this listing — TODO confirm.
487template<
typename AFloat>
488void TCudnn<AFloat>::ReleaseRNNDescriptors(TDescriptors * descriptors)
490 auto & rnnDescriptors =
static_cast<RNNDescriptors_t &
>(*descriptors);
491 CUDNNCHECK(cudnnDestroyRNNDescriptor(rnnDescriptors.LayerDescriptor));
493 ReleaseDescriptor(rnnDescriptors.HelperDescriptor);
494#if (CUDNN_VERSION >= 8000)
495 CUDNNCHECK(cudnnDestroyRNNDataDescriptor(rnnDescriptors.xDataDesc));
496 CUDNNCHECK(cudnnDestroyRNNDataDescriptor(rnnDescriptors.yDataDesc));
498 ReleaseDescriptor(rnnDescriptors.WeightsDescriptor);
499 ReleaseDescriptor(rnnDescriptors.WeightsGradDescriptor);
// Per-time-step descriptors; unlike the calls above, these return codes are not checked
// with CUDNNCHECK.
502 for (
size_t i = 0; i < rnnDescriptors.xDesc.size(); i++) {
503 cudnnDestroyTensorDescriptor(rnnDescriptors.xDesc.data()[i]);
504 cudnnDestroyTensorDescriptor(rnnDescriptors.yDesc.data()[i]);
506 cudnnDestroyTensorDescriptor(rnnDescriptors.dxDesc.data()[i]);
507 cudnnDestroyTensorDescriptor(rnnDescriptors.dyDesc.data()[i]);
// InitializeRecurrentWorkspace: allocates the device scratch buffers cuDNN needs for a
// recurrent layer and returns them through `workspace` (an RNNWorkspace_t created with new).
//  - Reshapes the layer's state tensor (and the cell tensor, when non-empty) to
//    {numLayers, batchSize, stateSize}; numLayers is 1 here since `bidirectional` is false.
//  - Queries the training/inference/reserve sizes (cudnnGetRNNTempSpaceSizes for
//    cuDNN >= 8, cudnnGetRNNWorkspaceSize / cudnnGetRNNTrainingReserveSize otherwise) and
//    allocates each non-empty buffer with cudaMalloc.
// @param descriptors  descriptors from InitializeRecurrentDescriptors (cast to RNNDescriptors_t).
// NOTE(review): cuDNN documents these sizes in bytes, but the allocations multiply them by
// sizeof(AFloat), over-allocating by that factor. The inference allocation is not checked
// for failure, and cudaMalloc's return code is ignored (failure is detected only through
// a still-null pointer — assumes RNNWorkspace_t zero-initialises them; TODO confirm).
514template <
typename AFloat>
515template <
typename RNNLayer>
516void TCudnn<AFloat>::InitializeRecurrentWorkspace(TWorkspace *&workspace, TDescriptors *&descriptors, RNNLayer *layer)
518 auto rnnWorkspace =
new RNNWorkspace_t ();
519 auto rnnDescriptors =
static_cast<RNNDescriptors_t *
>(descriptors);
521 cudnnHandle_t handle = layer->GetOutput().GetCudnnHandle();
523 bool bidirectional =
false;
527 size_t numLayers = 1;
528 if (bidirectional) numLayers *= 2;
// Hidden (and, for LSTM-like layers, cell) state reshaped to {numLayers, batch, stateSize}.
531 Tensor_t &stateTensor = layer->GetState();
532 stateTensor = Tensor_t(stateTensor.GetDeviceBuffer(), { numLayers, layer->GetBatchSize(), layer->GetStateSize()},
533 GetTensorLayout(), 0, 0 );
535 if (layer->GetCell().GetSize() > 0) {
536 Tensor_t & cellStateTensor = layer->GetCell();
537 cellStateTensor = Tensor_t(cellStateTensor.GetDeviceBuffer(), {numLayers, layer->GetBatchSize(), layer->GetStateSize()}, GetTensorLayout(), 0, 0 );
542#if (CUDNN_VERSION >= 8000)
545 CUDNNCHECK(cudnnGetRNNTempSpaceSizes(handle, rnnDescriptors->LayerDescriptor, CUDNN_FWD_MODE_TRAINING,
546 rnnDescriptors->xDataDesc, &rnnWorkspace->ForwardWorkspaceSize,
547 &rnnWorkspace->HelperWorkspaceSize));
549 CUDNNCHECK(cudnnGetRNNTempSpaceSizes(handle, rnnDescriptors->LayerDescriptor, CUDNN_FWD_MODE_INFERENCE,
550 rnnDescriptors->xDataDesc, &rnnWorkspace->InferenceWorkspaceSize,
554 CUDNNCHECK(cudnnGetRNNWorkspaceSize(handle, rnnDescriptors->LayerDescriptor, layer->GetTimeSteps(),
555 rnnDescriptors->xDesc.data(), &rnnWorkspace->ForwardWorkspaceSize));
557 CUDNNCHECK(cudnnGetRNNTrainingReserveSize(handle, rnnDescriptors->LayerDescriptor, layer->GetTimeSteps(),
558 rnnDescriptors->xDesc.data(), &rnnWorkspace->HelperWorkspaceSize));
// Forward (training) workspace. NOTE(review): the diagnostic below labels
// {batch, timeSteps, stateSize} as the "input shape", but the last entry is the state size.
561 if (rnnWorkspace->ForwardWorkspaceSize > 0) cudaMalloc(&rnnWorkspace->ForwardWorkspace, rnnWorkspace->ForwardWorkspaceSize*
sizeof(AFloat));
562 if (rnnWorkspace->ForwardWorkspaceSize > 0 && rnnWorkspace->ForwardWorkspace ==
nullptr ) {
563 std::cerr <<
"Error allocating RNN workspace of size " << rnnWorkspace->ForwardWorkspaceSize <<
" - probably running out of memory on the GPU"
565 std::cout <<
" layer input shape is { " << layer->GetBatchSize() <<
" , " << layer->GetTimeSteps() <<
" , "
566 <<layer->GetStateSize() <<
" } " << std::endl;
// Inference workspace (allocation result is not checked on the visible lines).
571 if (rnnWorkspace->InferenceWorkspaceSize > 0)
572 cudaMalloc(&rnnWorkspace->InferenceWorkspace, rnnWorkspace->InferenceWorkspaceSize*
sizeof(AFloat));
// Reserve ("helper") space shared between forward training and backward passes.
574 if (rnnWorkspace->HelperWorkspaceSize > 0) cudaMalloc(&rnnWorkspace->HelperWorkspace, rnnWorkspace->HelperWorkspaceSize*
sizeof(AFloat));
575 if (rnnWorkspace->HelperWorkspaceSize > 0 && rnnWorkspace->HelperWorkspace ==
nullptr ) {
576 std::cerr <<
"Error allocating RNN reserved workspace of size " << rnnWorkspace->HelperWorkspaceSize <<
" - probably running out of memory on the GPU"
578 std::cout <<
" layer input shape is { " << layer->GetBatchSize() <<
" , " << layer->GetTimeSteps() <<
" , "
579 <<layer->GetStateSize() <<
" } " << std::endl;
584 workspace = rnnWorkspace;
// FreeRNNWorkspace: releases the device buffers of an RNNWorkspace_t (forward, inference
// and reserve/helper workspaces), skipping null pointers. A null `workspace` is a no-op.
// NOTE(review): whether the RNNWorkspace_t object itself is deleted is on lines missing
// from this listing — TODO confirm.
589template <
typename AFloat>
590void TCudnn<AFloat>::FreeRNNWorkspace(TWorkspace * workspace) {
591 if (!workspace)
return;
592 auto rnnWorkspace =
static_cast<RNNWorkspace_t *
>(workspace);
594 if(rnnWorkspace->ForwardWorkspace) cudaFree(rnnWorkspace->ForwardWorkspace);
595 if(rnnWorkspace->InferenceWorkspace) cudaFree(rnnWorkspace->InferenceWorkspace);
596 if(rnnWorkspace->HelperWorkspace) cudaFree(rnnWorkspace->HelperWorkspace);
// RNNForward: runs the cuDNN RNN forward pass on input `x` with the flat `weights` buffer,
// writing the output sequence to `y` (and final states to hy/cy where pointers are passed).
// Uses cudnnRNNForward for cuDNN >= 8, and cudnnRNNForwardTraining /
// cudnnRNNForwardInference on older versions.
// @param isTraining  selects training vs inference mode (and, for cuDNN >= 8, the workspace).
// NOTE(review): rememberState is hard-coded to false, so the initial hidden state is passed
// as nullptr and isLSTM is always false, meaning the cell-state pointers are always nullptr.
// The cuDNN status is checked only via assert, which compiles out under NDEBUG.
602template <
typename AFloat>
603void TCudnn<AFloat>::RNNForward(
const Tensor_t &
x,
const Tensor_t &hx,
const Tensor_t &cx,
const Tensor_t & weights, Tensor_t &
y,
604 Tensor_t &hy, Tensor_t &cy,
const RNNDescriptors_t & desc, RNNWorkspace_t &workspace,
bool isTraining)
611 bool rememberState =
false;
612 cudnnHandle_t cudnnHandle =
x.GetCudnnHandle();
// Sequence length is taken from the leading dimension of x (presumably time-major — TODO confirm).
614 int seqLength =
x.GetShape()[0];
615 cudnnRNNDescriptor_t rnnDesc = desc.LayerDescriptor;
618 bool isLSTM = (cx.GetSize() > 0) && rememberState;
620#if (CUDNN_VERSION >= 8000)
622 cudnnForwardMode_t fwdMode = (isTraining) ? CUDNN_FWD_MODE_TRAINING : CUDNN_FWD_MODE_INFERENCE;
623 const int * devSeqLength =
nullptr;
// Weight space size in bytes = number of elements * element size.
624 size_t weightSpaceSize = (std::is_same<AFloat, double>::value) ? weights.GetSize()*
sizeof(
double) :
625 weights.GetSize()* sizeof(float);
626 size_t workspaceSize = (isTraining) ? workspace.ForwardWorkspaceSize : workspace.InferenceWorkspaceSize;
627 void * workspacePtr = (isTraining) ? workspace.ForwardWorkspace : workspace.InferenceWorkspace;
628 cudnnStatus_t status = cudnnRNNForward(
629 cudnnHandle, rnnDesc, fwdMode, devSeqLength,
631 desc.xDataDesc,
x.GetDataPointer(), desc.yDataDesc,
y.GetDataPointer(),
632 hx.GetTensorDescriptor(), (rememberState) ? hx.GetDataPointer(): nullptr,
633 (rememberState) ? hy.GetDataPointer() : nullptr,
634 (isLSTM) ? cx.GetTensorDescriptor() : hx.GetTensorDescriptor(), (isLSTM) ? cx.GetDataPointer() : nullptr,
635 (isLSTM) ? cy.GetDataPointer() : nullptr,
636 weightSpaceSize, weights.GetDataPointer(), workspaceSize, workspacePtr,
637 workspace.HelperWorkspaceSize, workspace.HelperWorkspace);
639 assert(status == CUDNN_STATUS_SUCCESS);
// Pre-v8 training call. NOTE(review): unlike the v8 path, hy.GetDataPointer() is passed
// unconditionally here.
645 cudnnStatus_t status = cudnnRNNForwardTraining(
646 cudnnHandle, rnnDesc, seqLength, desc.xDesc.data(),
x.GetDataPointer(), hx.GetTensorDescriptor(), (rememberState) ?
647 hx.GetDataPointer() : nullptr, (isLSTM) ? cx.GetTensorDescriptor() : hx.GetTensorDescriptor(), (isLSTM) ? cx.GetDataPointer() : nullptr, desc.WeightsDescriptor,
648 weights.GetDataPointer(), desc.yDesc.
data(),
y.GetDataPointer(), hy.GetTensorDescriptor(), hy.GetDataPointer(),
649 (isLSTM) ? cy.GetTensorDescriptor() : hy.GetTensorDescriptor(), (isLSTM) ? cy.GetDataPointer() : nullptr, workspace.ForwardWorkspace, workspace.ForwardWorkspaceSize,
650 workspace.HelperWorkspace, workspace.HelperWorkspaceSize);
652 assert(status == CUDNN_STATUS_SUCCESS);
// Pre-v8 inference call. NOTE(review): this uses ForwardWorkspace/ForwardWorkspaceSize,
// whereas the cuDNN >= 8 path uses InferenceWorkspace for inference.
658 cudnnStatus_t status = cudnnRNNForwardInference(
659 cudnnHandle, rnnDesc, seqLength, desc.xDesc.data(),
x.GetDataPointer(), hx.GetTensorDescriptor(),
660 (rememberState) ? hx.GetDataPointer() : nullptr,
661 (isLSTM) ? cx.GetTensorDescriptor() : hx.GetTensorDescriptor(), (isLSTM) ? cx.GetDataPointer() : nullptr,
662 desc.WeightsDescriptor, weights.GetDataPointer(), desc.yDesc.
data(),
y.GetDataPointer(),
663 hy.GetTensorDescriptor(), hy.GetDataPointer(), (isLSTM) ? cy.GetTensorDescriptor() : hy.GetTensorDescriptor(),
664 (isLSTM) ? cy.GetDataPointer() : nullptr, workspace.ForwardWorkspace, workspace.ForwardWorkspaceSize);
666 assert(status == CUDNN_STATUS_SUCCESS);
// RNNBackward: runs the cuDNN RNN backward pass. It first computes the data gradient dx
// from dy, then the weight gradient dw.
// - cuDNN >= 8: cudnnRNNBackwardData_v8 + cudnnRNNBackwardWeights_v8, using
//   CUDNN_WGRAD_MODE_ADD, i.e. accumulating into dw; presumably the caller zeroes dw first
//   — TODO confirm.
// - Older versions: cudnnRNNBackwardData + cudnnRNNBackwardWeights.
// Both passes reuse the forward workspace and the reserve (helper) space filled by the
// training forward pass.
// NOTE(review): rememberState / rememberStateGrad are hard-coded to false, so the state
// pointers are nullptr and isLSTM is always false. batchSize is not used on the visible
// lines. The status is checked only via assert.
678template <
typename AFloat>
679void TCudnn<AFloat>::RNNBackward(
const Tensor_t &
x,
const Tensor_t &hx,
const Tensor_t &cx,
const Tensor_t &
y,
680 const Tensor_t &dy,
const Tensor_t &dhy,
const Tensor_t &dcy,
const Tensor_t &weights,
681 Tensor_t &dx, Tensor_t &dhx, Tensor_t &dcx, Tensor_t &dw,
const RNNDescriptors_t &desc,
682 RNNWorkspace_t &workspace)
685 bool rememberState =
false;
686 bool rememberStateGrad =
false;
687 bool isLSTM = (cx.GetSize() > 0) && rememberState;
688 int seqLength =
x.GetShape()[0];
689 int batchSize =
x.GetShape()[1];
690 cudnnRNNDescriptor_t rnnDesc = desc.LayerDescriptor;
691 cudnnHandle_t cudnnHandle =
x.GetCudnnHandle();
697#if (CUDNN_VERSION >= 8000)
// Weight space size in bytes = number of elements * element size.
707 size_t weightSpaceSize = (std::is_same<AFloat, double>::value) ? weights.GetSize()*
sizeof(
double) :
708 weights.GetSize()* sizeof(float);
709 cudnnStatus_t status = cudnnRNNBackwardData_v8(
710 cudnnHandle, rnnDesc, NULL,
711 desc.yDataDesc,
y.GetDataPointer(), dy.GetDataPointer(),
712 desc.xDataDesc, dx.GetDataPointer(),
713 hx.GetTensorDescriptor(), (rememberState) ? hx.GetDataPointer() : nullptr,
714 (rememberStateGrad) ? dhy.GetDataPointer() : nullptr, (rememberStateGrad) ? dhx.GetDataPointer() : nullptr,
715 (isLSTM) ? cx.GetTensorDescriptor() : hx.GetTensorDescriptor(),
716 (isLSTM) ? cx.GetDataPointer() : nullptr, (isLSTM) ? dcy.GetDataPointer() : nullptr, (isLSTM) ? dcx.GetDataPointer() : nullptr,
717 weightSpaceSize, weights.GetDataPointer(),
718 workspace.ForwardWorkspaceSize, workspace.ForwardWorkspace, workspace.HelperWorkspaceSize, workspace.HelperWorkspace);
721 assert(status == CUDNN_STATUS_SUCCESS);
// Weight gradient, accumulated into dw (CUDNN_WGRAD_MODE_ADD).
733 status = cudnnRNNBackwardWeights_v8(cudnnHandle, rnnDesc,CUDNN_WGRAD_MODE_ADD, NULL,
734 desc.xDataDesc,
x.GetDataPointer(),
735 hx.GetTensorDescriptor(), (rememberState) ? hx.GetDataPointer() : nullptr,
736 desc.yDataDesc,
y.GetDataPointer(),
737 weightSpaceSize, dw.GetDataPointer(),
738 workspace.ForwardWorkspaceSize, workspace.ForwardWorkspace, workspace.HelperWorkspaceSize, workspace.HelperWorkspace);
// Pre-v8 data gradient. NOTE(review): here the dhx pointer is gated on rememberState, while
// the v8 path gates it on rememberStateGrad; both flags are false, so behaviour currently matches.
746 cudnnStatus_t status = cudnnRNNBackwardData(
747 cudnnHandle, rnnDesc, seqLength, desc.yDesc.data(),
y.GetDataPointer(), desc.dyDesc.data(), dy.GetDataPointer(),
748 dhy.GetTensorDescriptor(), (rememberStateGrad) ? dhy.GetDataPointer() : nullptr,
749 (isLSTM) ? dcy.GetTensorDescriptor() : dhy.GetTensorDescriptor(), (isLSTM) ? dcy.GetDataPointer() : nullptr,
750 desc.WeightsDescriptor, weights.GetDataPointer(), hx.GetTensorDescriptor(),
751 (rememberState) ? hx.GetDataPointer() : nullptr, (isLSTM) ? cx.GetTensorDescriptor() : hx.GetTensorDescriptor(),
752 (isLSTM) ? cx.GetDataPointer() : nullptr,
753 desc.dxDesc.
data(), dx.GetDataPointer(), dhx.GetTensorDescriptor(),
754 (rememberState) ? dhx.GetDataPointer() : nullptr,
755 (isLSTM) ? dcx.GetTensorDescriptor() : dhx.GetTensorDescriptor(),
756 (isLSTM) ? dcx.GetDataPointer() : nullptr,
757 workspace.ForwardWorkspace, workspace.ForwardWorkspaceSize, workspace.HelperWorkspace,
758 workspace.HelperWorkspaceSize);
760 assert(status == CUDNN_STATUS_SUCCESS);
// Pre-v8 weight gradient, written through the weight-gradient filter descriptor.
764 status = cudnnRNNBackwardWeights(cudnnHandle, rnnDesc, seqLength, desc.xDesc.data(),
x.GetDataPointer(),
765 hx.GetTensorDescriptor(), (rememberState) ? hx.GetDataPointer() : nullptr,
766 desc.yDesc.
data(),
y.GetDataPointer(), workspace.ForwardWorkspace,
767 workspace.ForwardWorkspaceSize, desc.WeightsGradDescriptor, dw.GetDataPointer(),
768 workspace.HelperWorkspace, workspace.HelperWorkspaceSize);
770 assert(status == CUDNN_STATUS_SUCCESS);
// Rearrange: copies x into y with the first two dimensions swapped. y is expected to have
// shape {dims[1], dims[0], last}, where dims are x's dimensions and `last` is x's final one.
// It works by temporarily re-describing x's cuDNN descriptor `d` with y's dims and strides
// {y[2], y[2]*y[0], 1}, so that cudnnTransformTensor reads x in y's index order. It then
// restores the original dims/strides.
// NOTE(review): the descriptor obtained from `tmp` (declared on a missing line, presumably
// aliasing x) is mutated during the call and is not restored if an assert-free failure
// occurs before line 819. alpha, beta, n, dims, strides, xNdim and xDims are declared on
// lines missing from this listing.
777template<
typename AFloat>
778void TCudnn<AFloat>::Rearrange(Tensor_t &
y,
const Tensor_t &
x) {
782 cudnnHandle_t cudnnHandle =
x.GetCudnnHandle();
785 TensorDescriptor_t
d =
tmp.GetTensorDescriptor();
// Read back x's current data type, rank, dims and strides so they can be restored later.
789 cudnnDataType_t dataType;
790 cudnnGetTensorNdDescriptor(
d,
tmp.GetNDim() , &dataType, &
n, dims, strides);
795 auto outputShape =
y.GetShape();
796 assert(xNdim ==
y.GetNDim());
// NOTE(review): BUG — `=` instead of `==`. In assert-enabled builds this assigns dims[1]
// to the local outputShape[0], which later feeds xDims/xStrides, and it never checks
// anything. Intended: assert(outputShape[0] == dims[1]);
798 assert(outputShape[0] = dims[1]);
799 assert(outputShape[1] == dims[0]);
// NOTE(review): BUG — `==` binds tighter than `?:`, so this is
// `(outputShape[2] == (n == 4)) ? dims[3] : dims[2]`, which passes whenever the selected
// dim is non-zero. Intended: assert(outputShape[2] == ((n == 4) ? dims[3] : dims[2]));
800 assert(outputShape[2] == (
n ==4) ? dims[3] : dims[2]);
801 if (
n==4) assert(dims[2] == 1);
// Strides that map y's index order onto x's memory layout (first two axes swapped).
805 int xStrides[xNdim] = { (
int) outputShape[2], (
int)(outputShape[2] * outputShape[0]), 1 };
807 for (
int i = 0; i < xNdim; ++i)
808 xDims[i] = outputShape[i];
810 cudnnStatus_t status = cudnnSetTensorNdDescriptor(
d, dataType, xNdim, xDims, xStrides);
811 assert(status == CUDNN_STATUS_SUCCESS);
813 status = cudnnTransformTensor(cudnnHandle, &alpha,
d,
x.GetDataPointer() , &beta,
814 y.GetTensorDescriptor(),
y.GetDataPointer());
815 assert(status == CUDNN_STATUS_SUCCESS);
// Restore x's original descriptor.
819 status = cudnnSetTensorNdDescriptor(
d, dataType,
n, dims, strides);
820 assert(status == CUDNN_STATUS_SUCCESS);
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char mode
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t format
double beta(double x, double y)
Calculates the beta function.
create variable transformations