ROOT
v6-34
Reference Guide
Loading...
Searching...
No Matches
Dropout.cu
Go to the documentation of this file.
1
// @(#)root/tmva/tmva/dnn:$Id$
2
// Author: Simon Pfreundschuh 14/07/16
3
4
/*************************************************************************
5
* Copyright (C) 2016, Simon Pfreundschuh *
6
* All rights reserved. *
7
* *
8
* For the licensing terms see $ROOTSYS/LICENSE. *
9
* For the list of contributors see $ROOTSYS/README/CREDITS. *
10
*************************************************************************/
11
12
#include "
TMVA/DNN/Architectures/TCudnn.h
"
13
/*#include "TMVA/DNN/Architectures/Cuda/Device.h"
14
#include "Kernels.cuh"*/
15
16
/////////////////////////////////////////////////////////////////////
17
// Implementation of the Dropout function for TCudnn architectures.//
18
/////////////////////////////////////////////////////////////////////
19
20
namespace TMVA {
namespace DNN {

// FIXME: the cuDNN dropout wrappers in this file are not covered by tests yet.

//____________________________________________________________________________
/// Run cuDNN dropout on \p A in place (input and output are both \p A).
///
/// \param A            tensor whose data is read as the dropout input and
///                     overwritten with the dropout output.
/// \param descriptors  layer descriptors; cast (unchecked) to
///                     PoolingDescriptors_t, whose HelperDescriptor is handed
///                     to cuDNN as the dropout descriptor.
/// \param workspace    layer workspace; cast (unchecked) to ConvWorkspace_t,
///                     whose HelperWorkspace / HelperWorkspaceSize are handed
///                     to cuDNN as the dropout reserve space.
///
/// The dropout-probability argument is ignored here: the rate applied is
/// whatever the dropout descriptor carries (presumably set when the layer's
/// descriptors were initialized -- not visible in this file).
/// If either \p descriptors or \p workspace is null the call is a silent no-op.
template <typename AFloat>
void TCudnn<AFloat>::DropoutForward(TCudaTensor<AFloat> &A,
                                    TDescriptors *descriptors,
                                    TWorkspace *workspace,
                                    AFloat /*dropoutProbability*/)
{
   // Without the cuDNN state there is nothing to run; leave A untouched.
   if (!(workspace && descriptors))
      return;

   // NOTE(review): the workspace is viewed as ConvWorkspace_t while the
   // descriptors are viewed as PoolingDescriptors_t. static_cast performs no
   // runtime check, so callers must pass objects of exactly these types.
   auto *dropoutWorkspace   = static_cast<ConvWorkspace_t *>(workspace);
   auto *dropoutDescriptors = static_cast<PoolingDescriptors_t *>(descriptors);

   // x and y share A's descriptor and buffer, so the result overwrites A.
   // Should in-place execution ever fail, copy A into a temporary
   // (TCudaTensor<AFloat> tmp(A)) and pass tmp as the input instead.
   const auto &tensorDesc = A.GetTensorDescriptor();
   auto *data             = A.GetDataPointer();

   CUDNNCHECK(cudnnDropoutForward(A.GetCudnnHandle(),
                                  dropoutDescriptors->HelperDescriptor,
                                  tensorDesc, data,   // input  x
                                  tensorDesc, data,   // output y (same storage)
                                  dropoutWorkspace->HelperWorkspace,
                                  dropoutWorkspace->HelperWorkspaceSize));
}

//____________________________________________________________________________
/// Back-propagate through cuDNN dropout in place.
///
/// \param A            on entry, the gradient w.r.t. the dropout output (dy);
///                     on exit, overwritten with the gradient w.r.t. the
///                     dropout input (dx).
/// \param descriptors  as in DropoutForward: cast (unchecked) to
///                     PoolingDescriptors_t; HelperDescriptor is the dropout
///                     descriptor.
/// \param workspace    as in DropoutForward: cast (unchecked) to
///                     ConvWorkspace_t. cuDNN expects this reserve space to
///                     still hold what the matching forward call stored in it.
///
/// If either \p descriptors or \p workspace is null the call is a silent no-op.
template <typename AFloat>
void TCudnn<AFloat>::DropoutBackward(TCudaTensor<AFloat> &A,
                                     TDescriptors *descriptors,
                                     TWorkspace *workspace)
{
   // Mirror the forward pass: no cuDNN state means nothing to do.
   if (!(workspace && descriptors))
      return;

   // NOTE(review): same unchecked Conv-workspace / Pooling-descriptor views as
   // in DropoutForward; they must match whatever the caller allocated.
   auto *dropoutWorkspace   = static_cast<ConvWorkspace_t *>(workspace);
   auto *dropoutDescriptors = static_cast<PoolingDescriptors_t *>(descriptors);

   // dy and dx share A's descriptor and buffer, so the input gradient
   // replaces the output gradient in A. If in-place execution ever fails,
   // copy A into a temporary (TCudaTensor<AFloat> tmp(A)) and read from tmp.
   const auto &tensorDesc = A.GetTensorDescriptor();
   auto *grad             = A.GetDataPointer();

   CUDNNCHECK(cudnnDropoutBackward(A.GetCudnnHandle(),
                                   dropoutDescriptors->HelperDescriptor,
                                   tensorDesc, grad,   // incoming dy
                                   tensorDesc, grad,   // outgoing dx (same storage)
                                   dropoutWorkspace->HelperWorkspace,
                                   dropoutWorkspace->HelperWorkspaceSize));
}

} // namespace DNN
} // namespace TMVA
TRangeDynCast
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Definition
TCollection.h:358
TCudnn.h
ROOT::Detail::TRangeCast
Definition
TCollection.h:311
TMVA
create variable transformations
Definition
GeneticMinimizer.h:22
tmva
tmva
src
DNN
Architectures
Cudnn
Dropout.cu
ROOT v6-34 - Reference Guide Generated on Wed Jan 29 2025 04:46:28 (GVA Time) using Doxygen 1.10.0