ROOT
git-r3/HEAD
Reference Guide
Loading...
Searching...
No Matches
Dropout.cu
Go to the documentation of this file.
// @(#)root/tmva/tmva/dnn:$Id$
// Author: Simon Pfreundschuh 14/07/16

/*************************************************************************
 * Copyright (C) 2016, Simon Pfreundschuh                                *
 * All rights reserved.                                                  *
 *                                                                       *
 * For the licensing terms see $ROOTSYS/LICENSE.                         *
 * For the list of contributors see $ROOTSYS/README/CREDITS.             *
 *************************************************************************/
12
#include "
TMVA/DNN/Architectures/Cuda.h
"
13
#include "
TMVA/DNN/Architectures/Cuda/Device.h
"
14
#include "
Kernels.cuh
"
15
16
/////////////////////////////////////////////////////////////////////
// Implementation of the Dropout function for TCuda architectures. //
/////////////////////////////////////////////////////////////////////
19
20
namespace TMVA
{
namespace DNN
{

//____________________________________________________________________________
/// Apply dropout in place to the elements of tensor \p A by launching the
/// `Cuda::Dropout` kernel on the tensor's compute stream.
///
/// \param A                  Tensor modified in place; its device data pointer,
///                           row/column extents and compute stream are used.
/// \param descriptors        Unused by this implementation.
/// \param workspace          Unused by this implementation.
/// \param dropoutProbability Forwarded unchanged to the kernel. Per the
///                           declaration's documentation this is the
///                           *activation* (keep) probability p, and the result
///                           is scaled by a reciprocal (presumably 1/p). The
///                           exact semantics live in Cuda::Dropout in
///                           Kernels.cuh; confirm there.
///
/// The launch is asynchronous on A.GetComputeStream(). This function does not
/// synchronize and performs no error check after the launch.
/// NOTE(review): consider checking cudaGetLastError() per project convention.
template <typename AFloat>
void TCuda<AFloat>::DropoutForward(TCudaTensor<AFloat> &A,
                                   TDescriptors * /*descriptors*/,
                                   TWorkspace * /*workspace*/,
                                   AFloat dropoutProbability)
{
   // Kernel arguments. The extents are narrowed to int because the kernel
   // takes int m, n.
   // NOTE(review): this narrowing would truncate extents above INT_MAX.
   AFloat *devData = A.GetDataPointer();
   const int nRows = static_cast<int>(A.GetNrows());
   const int nCols = static_cast<int>(A.GetNcols());

   // Class-level (static) cuRAND state pointer shared by TCudaMatrix<AFloat>.
   curandState_t *rngStates = TCudaMatrix<AFloat>::GetCurandStatesPointer();

   // 2D launch configuration and the stream the tensor's work is ordered on.
   // Assumes GridDims2D(A) covers the nRows x nCols extents handed to the
   // kernel. TODO confirm against Device.h.
   const cudaStream_t stream = A.GetComputeStream();
   const dim3 threadsPerBlock = TDevice::BlockDims2D();
   const dim3 blocksPerGrid = TDevice::GridDims2D(A);

   ::TMVA::DNN::Cuda::Dropout<<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
      devData, nRows, nCols, dropoutProbability, rngStates);
}

} // namespace DNN
} // namespace TMVA
Cuda.h
Device.h
Kernels.cuh
TMVA::DNN::TCudaMatrix::GetCurandStatesPointer
static curandState_t * GetCurandStatesPointer()
Definition
CudaMatrix.h:152
TMVA::DNN::TCudaTensor
TCudaTensor Class.
Definition
CudaTensor.h:84
TMVA::DNN::TCuda::DropoutForward
static void DropoutForward(Tensor_t &A, TDescriptors *descriptors, TWorkspace *workspace, Scalar_t p)
Apply dropout with activation probability p to the given tensor A and scale the result by reciprocal ...
TMVA::DNN::TDevice::BlockDims2D
static dim3 BlockDims2D()
Definition
Device.h:55
TMVA::DNN::TDevice::GridDims2D
static dim3 GridDims2D(int nrows, int ncols)
Definition
Device.h:74
TMVA::DNN::Cuda::Dropout
__global__ void Dropout(AFloat *A, int m, int n, AFloat dropoutProbability, curandState_t *state)
Definition
Kernels.cuh:964
TMVA::DNN
Definition
Adadelta.h:36
TMVA
create variable transformations
Definition
GeneticMinimizer.h:22
TMVA::DNN::TDescriptors
Definition
ContextHandles.h:29
TMVA::DNN::TWorkspace
Definition
ContextHandles.h:32
tmva
tmva
src
DNN
Architectures
Cuda
Dropout.cu
ROOT git-r3/HEAD - Reference Guide Generated on
(GVA Time) using Doxygen 1.16.1