Logo ROOT   6.14/05
Reference Guide
Blas.h
Go to the documentation of this file.
1 // @(#)root/tmva/tmva/dnn:$Id$
2 // Author: Simon Pfreundschuh 20/07/16
3 
4 /*************************************************************************
5  * Copyright (C) 2016, Simon Pfreundschuh *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 ///////////////////////////////////////////////////////////////////
13 // Declarations of the BLAS functions used for the forward and //
14 // backward propagation of activation through neural networks on //
15 // CPUs. //
16 ///////////////////////////////////////////////////////////////////
17 
18 #ifndef TMVA_DNN_ARCHITECTURES_CPU_BLAS
19 #define TMVA_DNN_ARCHITECTURES_CPU_BLAS
20 
21 #include <iostream>
22 
#ifndef DNN_USE_CBLAS
// External Library Routines
//____________________________________________________________________________
// Reference (Fortran) BLAS entry points. Fortran passes every argument by
// reference, which is why scalars and dimensions are pointers here; the
// trailing underscore matches the symbol names Fortran compilers commonly emit.
extern "C" {

// Level 1 -- y := alpha * x + y
void saxpy_(const int * n, const float * alpha, const float * x,
            const int * incx, float * y, const int * incy);
void daxpy_(const int * n, const double * alpha, const double * x,
            const int * incx, double * y, const int * incy);

// Level 2 -- rank-1 update A := alpha * x * y^T + A
void sger_(const int * m, const int * n, const float * alpha,
           const float * x, const int * incx,
           const float * y, const int * incy,
           float * A, const int * lda);
void dger_(const int * m, const int * n, const double * alpha,
           const double * x, const int * incx,
           const double * y, const int * incy,
           double * A, const int * lda);

// Level 2 -- y := alpha * op(A) * x + beta * y
void sgemv_(const char * trans, const int * m, const int * n,
            const float * alpha, const float * A, const int * lda,
            const float * x, const int * incx,
            const float * beta, float * y, const int * incy);
void dgemv_(const char * trans, const int * m, const int * n,
            const double * alpha, const double * A, const int * lda,
            const double * x, const int * incx,
            const double * beta, double * y, const int * incy);

// Level 3 -- C := alpha * op(A) * op(B) + beta * C
void sgemm_(const char * transa, const char * transb,
            const int * m, const int * n, const int * k,
            const float * alpha, const float * A, const int * lda,
            const float * B, const int * ldb, const float * beta,
            float * C, const int * ldc);
void dgemm_(const char * transa, const char * transb,
            const int * m, const int * n, const int * k,
            const double * alpha, const double * A, const int * lda,
            const double * B, const int * ldb, const double * beta,
            double * C, const int * ldc);

} // extern "C"

#else
// The CBLAS interface (cblas_* functions, CBLAS_* enums) is taken from GSL.
#include "gsl/gsl_cblas.h"
#endif
60 
61 namespace TMVA
62 {
63 namespace DNN
64 {
65 namespace Blas
66 {
67 
68 // Type-Generic Wrappers
69 //____________________________________________________________________________
70 /** Add the vector \p x scaled by \p alpha to \p y scaled by \beta */
71 template <typename Real_t>
72 inline void Axpy(const int * n, const Real_t * alpha,
73  const Real_t * x, const int * incx,
74  Real_t * y, const int * incy);
75 
76 /** Multiply the vector \p x with the matrix \p A and store the result in \p y. */
77 template <typename Real_t>
78 inline void Gemv(const char *trans, const int * m, const int * n,
79  const Real_t * alpha, const Real_t * A, const int * lda,
80  const Real_t * x, const int * incx,
81  const Real_t * beta, Real_t * y, const int * incy);
82 
83 /** Multiply the matrix \p A with the matrix \p B and store the result in \p C. */
84 template <typename Real_t>
85 inline void Gemm(const char *transa, const char *transb,
86  const int * m, const int * n, const int* k,
87  const Real_t * alpha, const Real_t * A, const int * lda,
88  const Real_t * B, const int * ldb, const Real_t * beta,
89  Real_t * C, const int * ldc);
90 
91 /** Add the outer product of \p x and \p y to the matrix \p A. */
92 template <typename Real_t>
93 inline void Ger(const int * m, const int * n, const Real_t * alpha,
94  const Real_t * x, const int * incx,
95  const Real_t * y, const int * incy,
96  Real_t * A, const int * lda);
97 
98 // Specializations
99 //____________________________________________________________________________
100 #ifndef DNN_USE_CBLAS
101 
102 template<>
103 inline void Axpy<double>(const int * n, const double * alpha,
104  const double * x, const int * incx,
105  double * y, const int * incy)
106 {
107 #ifdef DNN_USE_CBLAS
108  cblas_daxpy(*n, *alpha, x, *incx, y, *incy);
109 #else
110  daxpy_(n, alpha, x, incx, y, incy);
111 #endif
112 }
113 
114 template<>
115 inline void Axpy<float>(const int * n, const float * alpha,
116  const float * x, const int * incx,
117  float * y, const int * incy)
118 {
119  saxpy_(n, alpha, x, incx, y, incy);
120 }
121 
122 template<>
123 inline void Gemv<double>(const char *trans, const int * m, const int * n,
124  const double * alpha, const double * A, const int * lda,
125  const double * x, const int * incx,
126  const double * beta, double * y, const int * incy)
127 {
128  dgemv_(trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
129 }
130 
131 template<>
132 inline void Gemv<float>(const char *trans, const int * m, const int * n,
133  const float * alpha, const float * A, const int * lda,
134  const float * x, const int * incx,
135  const float * beta, float * y, const int * incy)
136 {
137  sgemv_(trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
138 }
139 
140 template<>
141 inline void Gemm<double>(const char *transa, const char *transb,
142  const int * m, const int * n, const int* k,
143  const double * alpha, const double * A, const int * lda,
144  const double * B, const int * ldb, const double * beta,
145  double * C, const int * ldc)
146 {
147  dgemm_(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
148 }
149 
150 template<>
151 inline void Gemm<float>(const char *transa, const char *transb,
152  const int * m, const int * n, const int* k,
153  const float * alpha, const float * A, const int * lda,
154  const float * B, const int * ldb, const float * beta,
155  float * C, const int * ldc)
156 {
157  sgemm_(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
158 }
159 
160 template <>
161 inline void Ger<double>(const int * m, const int * n, const double * alpha,
162  const double * x, const int * incx,
163  const double * y, const int * incy,
164  double * A, const int * lda)
165 {
166  dger_(m, n, alpha, x, incx, y, incy, A, lda);
167 }
168 
169 template <>
170 inline void Ger<float>(const int * m, const int * n, const float * alpha,
171  const float * x, const int * incx,
172  const float * y, const int * incy,
173  float * A, const int * lda)
174 {
175  sger_(m, n, alpha, x, incx, y, incy, A, lda);
176 }
177 
178 #else
179 //--------------------------------------------------------
180 // cblas implementation
181 //-----------------------------------------------------------
182 template<>
183 inline void Axpy<double>(const int * n, const double * alpha,
184  const double * x, const int * incx,
185  double * y, const int * incy)
186 {
187  cblas_daxpy(*n, *alpha, x, *incx, y, *incy);
188 }
189 
190 template<>
191 inline void Axpy<float>(const int * n, const float * alpha,
192  const float * x, const int * incx,
193  float * y, const int * incy)
194 {
195  cblas_saxpy(*n, *alpha, x, *incx, y, *incy);
196 }
197 
198 template<>
199 inline void Gemv<double>(const char *trans, const int * m, const int * n,
200  const double * alpha, const double * A, const int * lda,
201  const double * x, const int * incx,
202  const double * beta, double * y, const int * incy)
203 {
204  CBLAS_TRANSPOSE kTrans = (*trans == 'T') ? CblasTrans : CblasNoTrans;
205  cblas_dgemv(CblasColMajor, kTrans, *m, *n, *alpha, A, *lda, x, *incx, *beta, y, *incy);
206 }
207 
208 template<>
209 inline void Gemv<float>(const char *trans, const int * m, const int * n,
210  const float * alpha, const float * A, const int * lda,
211  const float * x, const int * incx,
212  const float * beta, float * y, const int * incy)
213 {
214  CBLAS_TRANSPOSE kTrans = (*trans == 'T') ? CblasTrans : CblasNoTrans;
215  cblas_sgemv(CblasColMajor, kTrans, *m, *n, *alpha, A, *lda, x, *incx, *beta, y, *incy);
216 }
217 
218 template<>
219 inline void Gemm<double>(const char *transa, const char *transb,
220  const int * m, const int * n, const int* k,
221  const double * alpha, const double * A, const int * lda,
222  const double * B, const int * ldb, const double * beta,
223  double * C, const int * ldc)
224 {
225  CBLAS_TRANSPOSE kTransA = (*transa == 'T') ? CblasTrans : CblasNoTrans;
226  CBLAS_TRANSPOSE kTransB = (*transb == 'T') ? CblasTrans : CblasNoTrans;
227  cblas_dgemm(CblasColMajor, kTransA, kTransB, *m, *n, *k, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
228 }
229 
230 template<>
231 inline void Gemm<float>(const char *transa, const char *transb,
232  const int * m, const int * n, const int* k,
233  const float * alpha, const float * A, const int * lda,
234  const float * B, const int * ldb, const float * beta,
235  float * C, const int * ldc)
236 {
237  CBLAS_TRANSPOSE kTransA = (*transa == 'T') ? CblasTrans : CblasNoTrans;
238  CBLAS_TRANSPOSE kTransB = (*transb == 'T') ? CblasTrans : CblasNoTrans;
239  cblas_sgemm(CblasColMajor, kTransA, kTransB, *m, *n, *k, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
240 }
241 
242 template <>
243 inline void Ger<double>(const int * m, const int * n, const double * alpha,
244  const double * x, const int * incx,
245  const double * y, const int * incy,
246  double * A, const int * lda)
247 {
248  cblas_dger(CblasColMajor, *m, *n, *alpha, x, *incx, y, *incy, A, *lda);
249 }
250 
251 template <>
252 inline void Ger<float>(const int * m, const int * n, const float * alpha,
253  const float * x, const int * incx,
254  const float * y, const int * incy,
255  float * A, const int * lda)
256 {
257  cblas_sger(CblasColMajor, *m, *n, *alpha, x, *incx, y, *incy, A, *lda);
258 }
259 
260 #endif
261 
262 } // namespace Blas
263 } // namespace DNN
264 } // namespace TMVA
265 
266 #endif
static double B[]
void sgemv_(const char *trans, const int *m, const int *n, const float *alpha, const float *A, const int *lda, const float *x, const int *incx, const float *beta, float *y, const int *incy)
auto * m
Definition: textangle.C:8
void Axpy< double >(const int *n, const double *alpha, const double *x, const int *incx, double *y, const int *incy)
Definition: Blas.h:103
void Ger(const int *m, const int *n, const Real_t *alpha, const Real_t *x, const int *incx, const Real_t *y, const int *incy, Real_t *A, const int *lda)
Add the outer product of x and y to the matrix A.
void Ger< float >(const int *m, const int *n, const float *alpha, const float *x, const int *incx, const float *y, const int *incy, float *A, const int *lda)
Definition: Blas.h:170
void Axpy< float >(const int *n, const float *alpha, const float *x, const int *incx, float *y, const int *incy)
Definition: Blas.h:115
static double A[]
void saxpy_(const int *n, const float *alpha, const float *x, const int *incx, float *y, const int *incy)
double beta(double x, double y)
Calculates the beta function.
void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *A, const int *lda, const double *B, const int *ldb, const double *beta, double *C, const int *ldc)
void Gemm(const char *transa, const char *transb, const int *m, const int *n, const int *k, const Real_t *alpha, const Real_t *A, const int *lda, const Real_t *B, const int *ldb, const Real_t *beta, Real_t *C, const int *ldc)
Multiply the matrix A with the matrix B and store the result in C.
void Ger< double >(const int *m, const int *n, const double *alpha, const double *x, const int *incx, const double *y, const int *incy, double *A, const int *lda)
Definition: Blas.h:161
Double_t x[n]
Definition: legend1.C:17
void dger_(const int *m, const int *n, const double *alpha, const double *x, const int *incx, const double *y, const int *incy, double *A, const int *lda)
void Gemm< float >(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *A, const int *lda, const float *B, const int *ldb, const float *beta, float *C, const int *ldc)
Definition: Blas.h:151
void Gemv< float >(const char *trans, const int *m, const int *n, const float *alpha, const float *A, const int *lda, const float *x, const int *incx, const float *beta, float *y, const int *incy)
Definition: Blas.h:132
void Axpy(const int *n, const Real_t *alpha, const Real_t *x, const int *incx, Real_t *y, const int *incy)
Add the vector x scaled by alpha to y, i.e. y := alpha * x + y.
static double C[]
void Gemv(const char *trans, const int *m, const int *n, const Real_t *alpha, const Real_t *A, const int *lda, const Real_t *x, const int *incx, const Real_t *beta, Real_t *y, const int *incy)
Multiply the vector x with the matrix A and store the result in y.
Double_t y[n]
Definition: legend1.C:17
void Gemm< double >(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *A, const int *lda, const double *B, const int *ldb, const double *beta, double *C, const int *ldc)
Definition: Blas.h:141
void Gemv< double >(const char *trans, const int *m, const int *n, const double *alpha, const double *A, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy)
Definition: Blas.h:123
float Real_t
Definition: RtypesCore.h:64
Abstract ClassifierFactory template that handles arbitrary types.
void dgemv_(const char *trans, const int *m, const int *n, const double *alpha, const double *A, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy)
void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *A, const int *lda, const float *B, const int *ldb, const float *beta, float *C, const int *ldc)
void daxpy_(const int *n, const double *alpha, const double *x, const int *incx, double *y, const int *incy)
void sger_(const int *m, const int *n, const float *alpha, const float *x, const int *incx, const float *y, const int *incy, float *A, const int *lda)
const Int_t n
Definition: legend1.C:16