18#ifndef TMVA_DNN_ARCHITECTURES_CPU_BLAS
19#define TMVA_DNN_ARCHITECTURES_CPU_BLAS
/// Fortran BLAS SAXPY (single precision): y := alpha * x + y.
/// Fortran-style symbol (trailing underscore); every argument, scalars included,
/// is passed by pointer.
/// @param n     number of elements to process
/// @param alpha scale factor applied to x
/// @param x     input vector, accessed with stride *incx
/// @param incx  stride of x
/// @param y     input/output vector, accessed with stride *incy
/// @param incy  stride of y
extern "C" void saxpy_(const int *n, const float *alpha, const float *x, const int *incx, float *y,
                       const int *incy);
/// Fortran BLAS DAXPY (double precision): y := alpha * x + y.
/// Fortran-style symbol (trailing underscore); every argument, scalars included,
/// is passed by pointer.
/// @param n     number of elements to process
/// @param alpha scale factor applied to x
/// @param x     input vector, accessed with stride *incx
/// @param incx  stride of x
/// @param y     input/output vector, accessed with stride *incy
/// @param incy  stride of y
extern "C" void daxpy_(const int *n, const double *alpha, const double *x, const int *incx, double *y,
                       const int *incy);
/// Fortran BLAS SGER (single precision) rank-1 update: A := alpha * x * y^T + A.
/// A is an m x n column-major matrix with leading dimension *lda; all arguments by pointer.
/// @param m, n       dimensions of A
/// @param alpha      scale factor of the outer product
/// @param x, incx    length-m vector and its stride
/// @param y, incy    length-n vector and its stride
/// @param A, lda     matrix updated in place and its leading dimension
extern "C" void sger_(const int *m, const int *n, const float *alpha, const float *x, const int *incx,
                      const float *y, const int *incy, float *A, const int *lda);
/// Fortran BLAS DGER (double precision) rank-1 update: A := alpha * x * y^T + A.
/// A is an m x n column-major matrix with leading dimension *lda; all arguments by pointer.
/// @param m, n       dimensions of A
/// @param alpha      scale factor of the outer product
/// @param x, incx    length-m vector and its stride
/// @param y, incy    length-n vector and its stride
/// @param A, lda     matrix updated in place and its leading dimension
extern "C" void dger_(const int *m, const int *n, const double *alpha, const double *x, const int *incx,
                      const double *y, const int *incy, double *A, const int *lda);
/// Fortran BLAS SGEMV (single precision): y := alpha * op(A) * x + beta * y.
/// op(A) is A or A^T as selected by *trans ('N'/'n', 'T'/'t', 'C'/'c').
/// A is an m x n column-major matrix with leading dimension *lda; all arguments by pointer.
/// @param trans            transpose selector for A
/// @param m, n             dimensions of A (not of op(A))
/// @param alpha, beta      scale factors
/// @param A, lda           matrix and its leading dimension
/// @param x, incx          input vector and its stride
/// @param y, incy          input/output vector and its stride
extern "C" void sgemv_(const char *trans, const int *m, const int *n, const float *alpha, const float *A,
                       const int *lda, const float *x, const int *incx, const float *beta, float *y,
                       const int *incy);
/// Fortran BLAS DGEMV (double precision): y := alpha * op(A) * x + beta * y.
/// op(A) is A or A^T as selected by *trans ('N'/'n', 'T'/'t', 'C'/'c').
/// A is an m x n column-major matrix with leading dimension *lda; all arguments by pointer.
/// @param trans            transpose selector for A
/// @param m, n             dimensions of A (not of op(A))
/// @param alpha, beta      scale factors
/// @param A, lda           matrix and its leading dimension
/// @param x, incx          input vector and its stride
/// @param y, incy          input/output vector and its stride
extern "C" void dgemv_(const char *trans, const int *m, const int *n, const double *alpha, const double *A,
                       const int *lda, const double *x, const int *incx, const double *beta, double *y,
                       const int *incy);
/// Fortran BLAS DGEMM (double precision): C := alpha * op(A) * op(B) + beta * C.
/// op(X) is X or X^T as selected by *transa / *transb; op(A) is m x k, op(B) is k x n,
/// C is m x n. All matrices are column-major; all arguments by pointer.
/// @param transa, transb   transpose selectors for A and B
/// @param m, n, k          dimensions described above
/// @param alpha, beta      scale factors
/// @param A, lda           first operand and its leading dimension
/// @param B, ldb           second operand and its leading dimension
/// @param C, ldc           input/output matrix and its leading dimension
extern "C" void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k,
                       const double *alpha, const double *A, const int *lda, const double *B,
                       const int *ldb, const double *beta, double *C, const int *ldc);
/// Fortran BLAS SGEMM (single precision): C := alpha * op(A) * op(B) + beta * C.
/// op(X) is X or X^T as selected by *transa / *transb; op(A) is m x k, op(B) is k x n,
/// C is m x n. All matrices are column-major; all arguments by pointer.
/// @param transa, transb   transpose selectors for A and B
/// @param m, n, k          dimensions described above
/// @param alpha, beta      scale factors
/// @param A, lda           first operand and its leading dimension
/// @param B, ldb           second operand and its leading dimension
/// @param C, ldc           input/output matrix and its leading dimension
extern "C" void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k,
                       const float *alpha, const float *A, const int *lda, const float *B, const int *ldb,
                       const float *beta, float *C, const int *ldc);
58#include "gsl/gsl_cblas.h"
/** Add the vector \p x scaled by \p alpha to \p y: y := alpha * x + y.
 *
 *  Type-generic front end. No generic definition exists in this file; only explicit
 *  specializations for float and double are provided. All arguments are passed by
 *  pointer, mirroring the Fortran BLAS interface.
 *
 *  @param n     number of elements to process
 *  @param alpha scale factor applied to x
 *  @param x     input vector, accessed with stride *incx
 *  @param incx  stride of x
 *  @param y     input/output vector, accessed with stride *incy
 *  @param incy  stride of y
 */
template <typename AReal>
inline void Axpy(const int *n, const AReal *alpha, const AReal *x, const int *incx, AReal *y, const int *incy);
/** Multiply op(\p A) with the vector \p x and accumulate into \p y:
 *  y := alpha * op(A) * x + beta * y, where op(A) is A or A^T depending on \p trans.
 *
 *  Type-generic front end. Only explicit specializations for float and double exist
 *  in this file. A is column-major with leading dimension *lda.
 *
 *  @param trans        transpose selector for A (BLAS character code)
 *  @param m, n         dimensions of A (not of op(A))
 *  @param alpha, beta  scale factors
 *  @param A, lda       matrix and its leading dimension
 *  @param x, incx      input vector and its stride
 *  @param y, incy      input/output vector and its stride
 */
template <typename AReal>
inline void Gemv(const char *trans, const int *m, const int *n, const AReal *alpha, const AReal *A,
                 const int *lda, const AReal *x, const int *incx, const AReal *beta, AReal *y,
                 const int *incy);
/** Multiply op(\p A) with op(\p B) and accumulate into \p C:
 *  C := alpha * op(A) * op(B) + beta * C, where op(X) is X or X^T depending on
 *  \p transa / \p transb. op(A) is m x k, op(B) is k x n, C is m x n (column-major).
 *
 *  Type-generic front end. Only explicit specializations for float and double exist
 *  in this file.
 *
 *  @param transa, transb  transpose selectors (BLAS character codes)
 *  @param m, n, k         dimensions described above
 *  @param alpha, beta     scale factors
 *  @param A, lda          first operand and its leading dimension
 *  @param B, ldb          second operand and its leading dimension
 *  @param C, ldc          input/output matrix and its leading dimension
 */
template <typename AReal>
inline void Gemm(const char *transa, const char *transb, const int *m, const int *n, const int *k,
                 const AReal *alpha, const AReal *A, const int *lda, const AReal *B, const int *ldb,
                 const AReal *beta, AReal *C, const int *ldc);
/** Add the outer product of \p x and \p y, scaled by \p alpha, to the matrix \p A:
 *  A := alpha * x * y^T + A, with A an m x n column-major matrix.
 *
 *  Type-generic front end. Only explicit specializations for float and double exist
 *  in this file.
 *
 *  @param m, n     dimensions of A
 *  @param alpha    scale factor of the outer product
 *  @param x, incx  length-m vector and its stride
 *  @param y, incy  length-n vector and its stride
 *  @param A, lda   matrix updated in place and its leading dimension
 */
template <typename AReal>
inline void Ger(const int *m, const int *n, const AReal *alpha, const AReal *x, const int *incx,
                const AReal *y, const int *incy, AReal *A, const int *lda);
104 const double *
x,
const int * incx,
105 double *
y,
const int * incy)
112 const float *
x,
const int * incx,
113 float *
y,
const int * incy)
120 const double * alpha,
const double * A,
const int * lda,
121 const double *
x,
const int * incx,
122 const double * beta,
double *
y,
const int * incy)
124 dgemv_(trans,
m,
n, alpha, A, lda,
x, incx, beta,
y, incy);
129 const float * alpha,
const float * A,
const int * lda,
130 const float *
x,
const int * incx,
131 const float * beta,
float *
y,
const int * incy)
133 sgemv_(trans,
m,
n, alpha, A, lda,
x, incx, beta,
y, incy);
138 const int *
m,
const int *
n,
const int* k,
139 const double * alpha,
const double * A,
const int * lda,
140 const double * B,
const int * ldb,
const double * beta,
141 double * C,
const int * ldc)
143 dgemm_(transa, transb,
m,
n, k, alpha, A, lda, B, ldb, beta, C, ldc);
148 const int *
m,
const int *
n,
const int* k,
149 const float * alpha,
const float * A,
const int * lda,
150 const float * B,
const int * ldb,
const float * beta,
151 float * C,
const int * ldc)
153 sgemm_(transa, transb,
m,
n, k, alpha, A, lda, B, ldb, beta, C, ldc);
158 const double *
x,
const int * incx,
159 const double *
y,
const int * incy,
160 double * A,
const int * lda)
162 dger_(
m,
n, alpha,
x, incx,
y, incy, A, lda);
166inline void Ger<float>(
const int *
m,
const int *
n,
const float * alpha,
167 const float *
x,
const int * incx,
168 const float *
y,
const int * incy,
169 float * A,
const int * lda)
171 sger_(
m,
n, alpha,
x, incx,
y, incy, A, lda);
180 const double *
x,
const int * incx,
181 double *
y,
const int * incy)
183 cblas_daxpy(*
n, *alpha,
x, *incx,
y, *incy);
187inline void Axpy<float>(
const int *
n,
const float * alpha,
188 const float *
x,
const int * incx,
189 float *
y,
const int * incy)
191 cblas_saxpy(*
n, *alpha,
x, *incx,
y, *incy);
195inline void Gemv<double>(
const char *trans,
const int *
m,
const int *
n,
196 const double * alpha,
const double * A,
const int * lda,
197 const double *
x,
const int * incx,
198 const double * beta,
double *
y,
const int * incy)
200 CBLAS_TRANSPOSE kTrans = (*trans ==
'T') ? CblasTrans : CblasNoTrans;
201 cblas_dgemv(CblasColMajor, kTrans, *
m, *
n, *alpha, A, *lda,
x, *incx, *beta,
y, *incy);
205inline void Gemv<float>(
const char *trans,
const int *
m,
const int *
n,
206 const float * alpha,
const float * A,
const int * lda,
207 const float *
x,
const int * incx,
208 const float * beta,
float *
y,
const int * incy)
210 CBLAS_TRANSPOSE kTrans = (*trans ==
'T') ? CblasTrans : CblasNoTrans;
211 cblas_sgemv(CblasColMajor, kTrans, *
m, *
n, *alpha, A, *lda,
x, *incx, *beta,
y, *incy);
215inline void Gemm<double>(
const char *transa,
const char *transb,
216 const int *
m,
const int *
n,
const int* k,
217 const double * alpha,
const double * A,
const int * lda,
218 const double * B,
const int * ldb,
const double * beta,
219 double * C,
const int * ldc)
221 CBLAS_TRANSPOSE kTransA = (*transa ==
'T') ? CblasTrans : CblasNoTrans;
222 CBLAS_TRANSPOSE kTransB = (*transb ==
'T') ? CblasTrans : CblasNoTrans;
223 cblas_dgemm(CblasColMajor, kTransA, kTransB, *
m, *
n, *k, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
227inline void Gemm<float>(
const char *transa,
const char *transb,
228 const int *
m,
const int *
n,
const int* k,
229 const float * alpha,
const float * A,
const int * lda,
230 const float * B,
const int * ldb,
const float * beta,
231 float * C,
const int * ldc)
233 CBLAS_TRANSPOSE kTransA = (*transa ==
'T') ? CblasTrans : CblasNoTrans;
234 CBLAS_TRANSPOSE kTransB = (*transb ==
'T') ? CblasTrans : CblasNoTrans;
235 cblas_sgemm(CblasColMajor, kTransA, kTransB, *
m, *
n, *k, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
239inline void Ger<double>(
const int *
m,
const int *
n,
const double * alpha,
240 const double *
x,
const int * incx,
241 const double *
y,
const int * incy,
242 double * A,
const int * lda)
244 cblas_dger(CblasColMajor, *
m, *
n, *alpha,
x, *incx,
y, *incy, A, *lda);
248inline void Ger<float>(
const int *
m,
const int *
n,
const float * alpha,
249 const float *
x,
const int * incx,
250 const float *
y,
const int * incy,
251 float * A,
const int * lda)
253 cblas_sger(CblasColMajor, *
m, *
n, *alpha,
x, *incx,
y, *incy, A, *lda);
void dgemv_(const char *trans, const int *m, const int *n, const double *alpha, const double *A, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy)
void dger_(const int *m, const int *n, const double *alpha, const double *x, const int *incx, const double *y, const int *incy, double *A, const int *lda)
void sger_(const int *m, const int *n, const float *alpha, const float *x, const int *incx, const float *y, const int *incy, float *A, const int *lda)
void sgemv_(const char *trans, const int *m, const int *n, const float *alpha, const float *A, const int *lda, const float *x, const int *incx, const float *beta, float *y, const int *incy)
void saxpy_(const int *n, const float *alpha, const float *x, const int *incx, float *y, const int *incy)
void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *A, const int *lda, const double *B, const int *ldb, const double *beta, double *C, const int *ldc)
void daxpy_(const int *n, const double *alpha, const double *x, const int *incx, double *y, const int *incy)
void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *A, const int *lda, const float *B, const int *ldb, const float *beta, float *C, const int *ldc)
void Gemm< float >(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *A, const int *lda, const float *B, const int *ldb, const float *beta, float *C, const int *ldc)
void Axpy(const int *n, const AReal *alpha, const AReal *x, const int *incx, AReal *y, const int *incy)
Add the vector x, scaled by alpha, to y: y := alpha * x + y.
void Gemv< float >(const char *trans, const int *m, const int *n, const float *alpha, const float *A, const int *lda, const float *x, const int *incx, const float *beta, float *y, const int *incy)
void Gemm(const char *transa, const char *transb, const int *m, const int *n, const int *k, const AReal *alpha, const AReal *A, const int *lda, const AReal *B, const int *ldb, const AReal *beta, AReal *C, const int *ldc)
Multiply op(A) with op(B) and accumulate into C: C := alpha * op(A) * op(B) + beta * C, where op(X) is X or X^T depending on transa/transb.
void Ger< float >(const int *m, const int *n, const float *alpha, const float *x, const int *incx, const float *y, const int *incy, float *A, const int *lda)
void Gemm< double >(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *A, const int *lda, const double *B, const int *ldb, const double *beta, double *C, const int *ldc)
void Gemv(const char *trans, const int *m, const int *n, const AReal *alpha, const AReal *A, const int *lda, const AReal *x, const int *incx, const AReal *beta, AReal *y, const int *incy)
Multiply op(A) with the vector x and accumulate into y: y := alpha * op(A) * x + beta * y, where op(A) is A or A^T depending on trans.
void Ger< double >(const int *m, const int *n, const double *alpha, const double *x, const int *incx, const double *y, const int *incy, double *A, const int *lda)
void Axpy< float >(const int *n, const float *alpha, const float *x, const int *incx, float *y, const int *incy)
void Axpy< double >(const int *n, const double *alpha, const double *x, const int *incx, double *y, const int *incy)
void Gemv< double >(const char *trans, const int *m, const int *n, const double *alpha, const double *A, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy)
void Ger(const int *m, const int *n, const AReal *alpha, const AReal *x, const int *incx, const AReal *y, const int *incy, AReal *A, const int *lda)
Add the outer product of x and y, scaled by alpha, to the matrix A: A := alpha * x * y^T + A.
create variable transformations