Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
Device.h
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 13/07/16
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12////////////////////////////////////////////////////////////////
13// Defines the TDevice class which encapsulates device-specific //
14// settings for the launching of threads. //
15////////////////////////////////////////////////////////////////
16
17#ifndef TMVA_DNN_ARCHITECTURES_CUDA_DEVICE
18#define TMVA_DNN_ARCHITECTURES_CUDA_DEVICE
19
20#include "cuda.h"
21#include "vector_types.h" // definition of dim3
22#include "CudaMatrix.h"
23
24namespace TMVA
25{
26namespace DNN
27{
28
29/** TDevice
30 *
31 * The TDevice class provides static functions for the generation of CUDA
32 * grids for kernel launches and is used to encapsulate the distribution
33 * of threads and blocks over the data.
34 *
35 */
37{
38public:
/* Threads per block along the x (first) dimension of a block. */
static constexpr int BlockDimX{1};
/* Threads per block along the y (second) dimension of a block. */
static constexpr int BlockDimY{32};
/* Total threads in one block: BlockDimX * BlockDimY. */
static constexpr int BlockSize{BlockDimX * BlockDimY};
45
46 /* Return 1D block of size 1 along the x-dimension and BlockSize along
47 * the y-dimension. */
48 static dim3 BlockDims1D()
49 {
50 return dim3(1, BlockSize);
51 }
52
53 /* Return dim3 object representing a BlockDimX x BlockDimY 2D
54 * block */
55 static dim3 BlockDims2D()
56 {
57 return dim3(BlockDimX, BlockDimY);
58 }
59
60 /* Return 1D dim3 object representing the block grid covering the row-range
61 * of the matrix A along the y-dimension. */
62 template<typename AMatrix>
63 static dim3 GridDims1D(const AMatrix &A)
64 {
65 int gridDim = A.GetNrows() / TDevice::BlockSize;
66 if ((A.GetNrows() % TDevice::BlockSize) != 0) {
67 gridDim += 1;
68 }
69 return dim3(1, gridDim);
70 }
71
72 /* Return 2D dim3 object representing the block grid consisting of two-dimensional
73 * BlockDimX x BlockDimY blocks covering a 2D matrix with nrows x ncols (assume columnmajor storage */
74 static dim3 GridDims2D(int nrows, int ncols)
75 {
76 int gridDimX = ncols / TDevice::BlockDimX;
77 if ((ncols % TDevice::BlockDimX) != 0)
78 gridDimX += 1;
79 int gridDimY = nrows / TDevice::BlockDimY;
80 if ((nrows % TDevice::BlockDimY) != 0)
81 gridDimY += 1;
82 return dim3(gridDimX, gridDimY);
83 }
84
85 /* Return 2D dim3 object representing the block grid consisting of two-dimensional
86 * BlockDimX x BlockDimY blocks covering the matrix A */
87 template<typename AMatrix>
88 static dim3 GridDims2D(const AMatrix &A)
89 {
90 int gridDimX = A.GetNcols() / TDevice::BlockDimX;
91 if ((A.GetNcols() % TDevice::BlockDimX) != 0)
92 gridDimX += 1;
93 int gridDimY = A.GetNrows() / TDevice::BlockDimY;
94 if ((A.GetNrows() % TDevice::BlockDimY) != 0)
95 gridDimY += 1;
96 return dim3(gridDimX, gridDimY);
97 }
98
99 /* Return the number of threads that will be launched for a given matrix \p A */
100 template<typename AMatrix>
101 static int NThreads(const AMatrix &A)
102 {
103 int gridDimX = A.GetNcols() / TDevice::BlockDimX;
104 if ((A.GetNcols() % TDevice::BlockDimX) != 0) {
105 gridDimX += 1;
106 }
107 int gridDimY = A.GetNrows() / TDevice::BlockDimY;
108 if ((A.GetNrows() % TDevice::BlockDimY) != 0) {
109 gridDimY += 1;
110 }
111 return gridDimX * gridDimY * TDevice::BlockDimX * TDevice::BlockDimY;
112 }
113};
114
115} // namespace DNN
116} // namespace TMVA
117
118#endif
static constexpr int BlockDimY
Definition Device.h:42
static dim3 BlockDims2D()
Definition Device.h:55
static constexpr int BlockDimX
Definition Device.h:40
static dim3 GridDims2D(int nrows, int ncols)
Definition Device.h:74
static dim3 GridDims2D(const AMatrix &A)
Definition Device.h:88
static constexpr int BlockSize
Definition Device.h:44
static dim3 BlockDims1D()
Definition Device.h:48
static dim3 GridDims1D(const AMatrix &A)
Definition Device.h:63
static int NThreads(const AMatrix &A)
Definition Device.h:101
create variable transformations