Logo ROOT   6.14/05
Reference Guide
Device.h
Go to the documentation of this file.
1 // @(#)root/tmva/tmva/dnn:$Id$
2 // Author: Simon Pfreundschuh 13/07/16
3 
4 /*************************************************************************
5  * Copyright (C) 2016, Simon Pfreundschuh *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 ////////////////////////////////////////////////////////////////
13 // Defines the TDevice class which encapsulates device-specific //
14 // settings for launching threads. //
15 ////////////////////////////////////////////////////////////////
16 
17 #ifndef TMVA_DNN_ARCHITECTURES_CUDA_DEVICE
18 #define TMVA_DNN_ARCHITECTURES_CUDA_DEVICE
19 
20 #include "cuda.h"
21 #include "vector_types.h" // definition of dim3
22 #include "CudaMatrix.h"
23 
24 namespace TMVA
25 {
26 namespace DNN
27 {
28 
29 /** TDevice
30  *
31  * The TDevice class provides static functions for the generation of CUDA
32  * grids for kernel launches and is used to encapsulate the distribution
33  * of threads and blocks over the data.
34  *
35  */
36 class TDevice
37 {
38 public:
39  /* Number of threads per block along first dimensions. */
40  static constexpr int BlockDimX = 1;
41  /* Number of threads per block along second dimensions. */
42  static constexpr int BlockDimY = 32;
43  /* Resulting block size. */
44  static constexpr int BlockSize = BlockDimX * BlockDimY;
45 
46  /* Return 1D block of size 1 along the x-dimension and BlockSize along
47  * the y-dimension. */
48  static dim3 BlockDims1D()
49  {
50  return dim3(1, BlockSize);
51  }
52 
53  /* Return dim3 object representing a BlockDimX x BlockDimY 2D
54  * block */
55  static dim3 BlockDims2D()
56  {
57  return dim3(BlockDimX, BlockDimY);
58  }
59 
60  /* Return 1D dim3 object representing the block grid covering the row-range
61  * of the matrix A along the y-dimension. */
62  template<typename AFloat>
63  static dim3 GridDims1D(const TCudaMatrix<AFloat> &A)
64  {
65  int gridDim = A.GetNrows() / TDevice::BlockSize;
66  if ((A.GetNrows() % TDevice::BlockSize) != 0) {
67  gridDim += 1;
68  }
69  return dim3(1, gridDim);
70  }
71 
72  /* Return 2D dim3 object representing the block grid consisting of two-dimensional
73  * BlockDimX x BlockDimY blocks covering the matrix A */
74  template<typename AFloat>
75  static dim3 GridDims2D(const TCudaMatrix<AFloat> &A)
76  {
77  int gridDimX = A.GetNcols() / TDevice::BlockDimX;
78  if ((A.GetNcols() % TDevice::BlockDimX) != 0)
79  gridDimX += 1;
80  int gridDimY = A.GetNrows() / TDevice::BlockDimY;
81  if ((A.GetNrows() % TDevice::BlockDimY) != 0)
82  gridDimY += 1;
83  return dim3(gridDimX, gridDimY);
84  }
85 
86  /* Return the number of threads that will be launched for a given matrix \p A */
87  template<typename AFloat>
88  static int NThreads(const TCudaMatrix<AFloat> &A)
89  {
90  int gridDimX = A.GetNcols() / TDevice::BlockDimX;
91  if ((A.GetNcols() % TDevice::BlockDimX) != 0) {
92  gridDimX += 1;
93  }
94  int gridDimY = A.GetNrows() / TDevice::BlockDimY;
95  if ((A.GetNrows() % TDevice::BlockDimY) != 0) {
96  gridDimY += 1;
97  }
98  return gridDimX * gridDimY * TDevice::BlockDimX * TDevice::BlockDimY;
99  }
100 };
101 
102 } // namespace DNN
103 } // namespace TMVA
104 
105 #endif
static constexpr int BlockSize
Definition: Device.h:44
static dim3 GridDims1D(const TCudaMatrix< AFloat > &A)
Definition: Device.h:63
static double A[]
size_t GetNcols() const
Definition: CudaMatrix.h:152
static int NThreads(const TCudaMatrix< AFloat > &A)
Definition: Device.h:88
size_t GetNrows() const
Definition: CudaMatrix.h:151
TDevice: static helpers that compute CUDA block and grid dimensions for kernel launches.
Definition: Device.h:36
static constexpr int BlockDimY
Definition: Device.h:42
static dim3 BlockDims2D()
Definition: Device.h:55
static dim3 GridDims2D(const TCudaMatrix< AFloat > &A)
Definition: Device.h:75
Abstract ClassifierFactory template that handles arbitrary types.
static dim3 BlockDims1D()
Definition: Device.h:48
TCudaMatrix Class.
Definition: CudaMatrix.h:98
static constexpr int BlockDimX
Definition: Device.h:40