Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
Dropout.cu
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 14/07/16
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
13/*#include "TMVA/DNN/Architectures/Cuda/Device.h"
14#include "Kernels.cuh"*/
15
16/////////////////////////////////////////////////////////////////////
17// Implementation of the Dropout function for TCudnn architectures.//
18/////////////////////////////////////////////////////////////////////
19
20namespace TMVA {
21namespace DNN {
22
23// FIXME: Do testing!!!
24//____________________________________________________________________________
25template<typename AFloat>
26void TCudnn<AFloat>::DropoutForward(TCudaTensor<AFloat> &A,
27 TDescriptors * descriptors,
28 TWorkspace * workspace,
29 AFloat /*dropoutProbability*/)
30{
31 if (!workspace || !descriptors) return;
32 auto poolWorkspace = static_cast<ConvWorkspace_t *>(workspace);
33 auto poolDescriptors = static_cast<PoolingDescriptors_t *>(descriptors);
34
35 //TCudaTensor<AFloat> tmp (A);
36
37 // Write the output into A
38 CUDNNCHECK(cudnnDropoutForward(A.GetCudnnHandle(),
39 poolDescriptors->HelperDescriptor,
40 A.GetTensorDescriptor(),// use tmp, if inplace op fails
41 A.GetDataPointer(),
42 A.GetTensorDescriptor(),
43 A.GetDataPointer(),
44 poolWorkspace->HelperWorkspace,
45 poolWorkspace->HelperWorkspaceSize));
46}
47
48//____________________________________________________________________________
49template<typename AFloat>
50void TCudnn<AFloat>::DropoutBackward(TCudaTensor<AFloat> &A,
51 TDescriptors * descriptors,
52 TWorkspace * workspace)
53{
54 if (!workspace || !descriptors) return;
55 auto poolWorkspace = static_cast<ConvWorkspace_t *>(workspace);
56 auto poolDescriptors = static_cast<PoolingDescriptors_t *>(descriptors);
57
58 //TCudaTensor<AFloat> tmp (A);
59
60 // Write the output into A
61 CUDNNCHECK(cudnnDropoutBackward(A.GetCudnnHandle(),
62 poolDescriptors->HelperDescriptor,
63 A.GetTensorDescriptor(),// use tmp, if inplace op fails
64 A.GetDataPointer(),
65 A.GetTensorDescriptor(),
66 A.GetDataPointer(),
67 poolWorkspace->HelperWorkspace,
68 poolWorkspace->HelperWorkspaceSize));
69}
70
71} // namespace DNN
72} // namespace TMVA
create variable transformations