19#ifndef TMVA_DNN_ARCHITECTURES_CUDA_CUDAMATRIX
20#define TMVA_DNN_ARCHITECTURES_CUDA_CUDAMATRIX
24#include "RConfigure.h"
25#ifdef R__HAS_STD_STRING_VIEW
26#ifndef R__CUDA_HAS_STD_STRING_VIEW
27#undef R__HAS_STD_STRING_VIEW
28#define R__HAS_STD_EXPERIMENTAL_STRING_VIEW
33#include "cuda_runtime.h"
35#include "curand_kernel.h"
40#define CUDACHECK(ans) {cudaError((ans), __FILE__, __LINE__); }
48inline void cudaError(cudaError_t code,
const char *
file,
int line,
bool abort=
true);
63template<
typename AFloat>
107template<
typename AFloat>
206 if (code != cudaSuccess)
208 fprintf(stderr,
"CUDA Error: %s %s %d\n", cudaGetErrorString(code),
file,
line);
209 if (abort) exit(code);
214template<
typename AFloat>
216 : fDevicePointer(devicePointer)
222template<
typename AFloat>
226 cudaMemcpy(& buffer, fDevicePointer,
sizeof(AFloat),
227 cudaMemcpyDeviceToHost);
232template<
typename AFloat>
236 cudaMemcpyDeviceToDevice);
240template<
typename AFloat>
243 AFloat buffer = value;
244 cudaMemcpy(fDevicePointer, & buffer,
sizeof(AFloat),
245 cudaMemcpyHostToDevice);
249template<
typename AFloat>
253 cudaMemcpy(& buffer, fDevicePointer,
sizeof(AFloat),
254 cudaMemcpyDeviceToHost);
256 cudaMemcpy(fDevicePointer, & buffer,
sizeof(AFloat),
257 cudaMemcpyHostToDevice);
261template<
typename AFloat>
265 cudaMemcpy(& buffer, fDevicePointer,
sizeof(AFloat),
266 cudaMemcpyDeviceToHost);
268 cudaMemcpy(fDevicePointer, & buffer,
sizeof(AFloat),
269 cudaMemcpyHostToDevice);
273template<
typename AFloat>
276 return fElementBuffer.GetComputeStream();
280template<
typename AFloat>
283 return fElementBuffer.SetComputeStream(stream);
287template<
typename AFloat>
291 cudaEventCreateWithFlags(&
event, cudaEventDisableTiming);
293 cudaStreamWaitEvent(fElementBuffer.GetComputeStream(),
event, 0);
294 cudaEventDestroy(
event);
298template<
typename AFloat>
301 AFloat buffer = value;
302 cudaMemcpy(fDeviceReturn, & buffer,
sizeof(AFloat), cudaMemcpyHostToDevice);
306template<
typename AFloat>
310 cudaMemcpy(& buffer, fDeviceReturn,
sizeof(AFloat), cudaMemcpyDeviceToHost);
315template<
typename AFloat>
318 AFloat * elementPointer = fElementBuffer;
319 elementPointer += j * fNRows + i;
void operator-=(AFloat value)
TCudaDeviceReference(AFloat *devicePointer)
void operator=(const TCudaDeviceReference &other)
void operator+=(AFloat value)
TCudaDeviceBuffer< AFloat > fElementBuffer
TCudaMatrix & operator=(const TCudaMatrix &)=default
static curandState_t * fCurandStates
TCudaMatrix(const TMatrixT< AFloat > &)
static AFloat GetDeviceReturn()
Transfer the value in the device return buffer to the host.
void SetComputeStream(cudaStream_t stream)
static AFloat * fDeviceReturn
Buffer for kernel return values.
cudaStream_t GetComputeStream() const
size_t GetNoElements() const
void InitializeCuda()
Initializes all shared devices resource and makes sure that a sufficient number of curand states are ...
static Bool_t gInitializeCurand
TCudaDeviceReference< AFloat > operator()(size_t i, size_t j) const
Access to elements of device matrices provided through TCudaDeviceReference class.
static AFloat * GetDeviceReturnPointer()
Return device pointer to the device return buffer.
const cublasHandle_t & GetCublasHandle() const
static AFloat * fOnes
Vector used for summations of columns.
static void ResetDeviceReturn(AFloat value=0.0)
Set the return buffer on the device to the specified value.
const AFloat * GetDataPointer() const
TCudaMatrix(TCudaDeviceBuffer< AFloat > buffer, size_t m, size_t n)
static size_t fNCurandStates
TCudaMatrix(const TCudaMatrix &)=default
TCudaDeviceBuffer< AFloat > GetDeviceBuffer() const
void Synchronize(const TCudaMatrix &) const
Blocking synchronization with the associated compute stream, if it's not the default stream.
static AFloat * GetOnes()
static cublasHandle_t fCublasHandle
static size_t fInstances
Current number of matrix instances.
TCudaMatrix(size_t i, size_t j)
TCudaMatrix & operator=(TCudaMatrix &&)=default
void InitializeCurandStates()
AFloat * GetDataPointer()
static size_t fNOnes
Current length of the one vector.
TCudaMatrix(TCudaMatrix &&)=default
static curandState_t * GetCurandStatesPointer()
void Print(Option_t *name="") const
Print the matrix as a table of elements.
void cudaError(cudaError_t code, const char *file, int line, bool abort=true)
Function to check cuda return code.
create variable transformations