/// Dropout forward pass: issues a cudnnDropoutForward call for tensor A.
///
/// NOTE(review): this excerpt looks incomplete -- the end of the parameter
/// list, the braces and some of the call arguments are not visible here, so
/// the comments below only describe what the visible lines show.
///
/// \param A            tensor that supplies the cuDNN handle and the tensor descriptor
/// \param descriptors  type-erased descriptors, cast to PoolingDescriptors_t below
/// \param workspace    type-erased workspace, cast to ConvWorkspace_t below
25template<
typename AFloat>
26void TCudnn<AFloat>::DropoutForward(TCudaTensor<AFloat> &A,
27 TDescriptors * descriptors,
28 TWorkspace * workspace,
// Silently do nothing when either the workspace or the descriptors are missing.
31 if (!workspace || !descriptors)
return;
// NOTE(review): the local is named pool* but the cast target is
// ConvWorkspace_t, whereas the descriptors use PoolingDescriptors_t --
// confirm this is the intended workspace type.
32 auto poolWorkspace =
static_cast<ConvWorkspace_t *
>(workspace);
33 auto poolDescriptors =
static_cast<PoolingDescriptors_t *
>(descriptors);
// CUDNNCHECK wraps the cuDNN call (presumably to check its status -- the
// macro is defined elsewhere).
// A's tensor descriptor is passed twice; presumably these are the input and
// output descriptors of an in-place operation on A -- the data arguments are
// not visible in this excerpt, TODO confirm.
38 CUDNNCHECK(cudnnDropoutForward(A.GetCudnnHandle(),
39 poolDescriptors->HelperDescriptor,
40 A.GetTensorDescriptor(),
42 A.GetTensorDescriptor(),
// The helper workspace buffer and its size are forwarded as the last arguments.
44 poolWorkspace->HelperWorkspace,
45 poolWorkspace->HelperWorkspaceSize));
/// Dropout backward pass: issues a cudnnDropoutBackward call for tensor A.
///
/// NOTE(review): this excerpt looks incomplete -- the braces and some of the
/// call arguments are not visible here, so the comments below only describe
/// what the visible lines show.
///
/// \param A            tensor that supplies the cuDNN handle and the tensor descriptor
/// \param descriptors  type-erased descriptors, cast to PoolingDescriptors_t below
/// \param workspace    type-erased workspace, cast to ConvWorkspace_t below
49template<
typename AFloat>
50void TCudnn<AFloat>::DropoutBackward(TCudaTensor<AFloat> &A,
51 TDescriptors * descriptors,
52 TWorkspace * workspace)
// Silently do nothing when either the workspace or the descriptors are missing.
54 if (!workspace || !descriptors)
return;
// NOTE(review): the local is named pool* but the cast target is
// ConvWorkspace_t, whereas the descriptors use PoolingDescriptors_t --
// confirm this is the intended workspace type.
55 auto poolWorkspace =
static_cast<ConvWorkspace_t *
>(workspace);
56 auto poolDescriptors =
static_cast<PoolingDescriptors_t *
>(descriptors);
// CUDNNCHECK wraps the cuDNN call (presumably to check its status -- the
// macro is defined elsewhere).
// A's tensor descriptor is passed twice; presumably these describe the
// incoming and outgoing gradients of an in-place operation on A -- the data
// arguments are not visible in this excerpt, TODO confirm.
61 CUDNNCHECK(cudnnDropoutBackward(A.GetCudnnHandle(),
62 poolDescriptors->HelperDescriptor,
63 A.GetTensorDescriptor(),
65 A.GetTensorDescriptor(),
// The same helper workspace buffer and size as in the forward call are
// forwarded here.
67 poolWorkspace->HelperWorkspace,
68 poolWorkspace->HelperWorkspaceSize));
create variable transformations