Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
Blas.h
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 20/07/16
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12///////////////////////////////////////////////////////////////////
13// Declarations of the BLAS functions used for the forward and //
14// backward propagation of activation through neural networks on //
15// CPUs. //
16///////////////////////////////////////////////////////////////////
17
18#ifndef TMVA_DNN_ARCHITECTURES_CPU_BLAS
19#define TMVA_DNN_ARCHITECTURES_CPU_BLAS
20
21#include <iostream>
22
23#ifndef DNN_USE_CBLAS
24// External Library Routines
25//____________________________________________________________________________
// Prototypes of the Fortran BLAS routines that back the wrappers in
// TMVA::DNN::Blas. As the signatures show, every argument -- scalars
// included -- is handed over by pointer.
extern "C" {

// Routines behind Axpy (single / double precision).
void saxpy_(const int *n, const float *alpha, const float *x, const int *incx,
            float *y, const int *incy);
void daxpy_(const int *n, const double *alpha, const double *x, const int *incx,
            double *y, const int *incy);

// Routines behind Ger (single / double precision).
void sger_(const int *m, const int *n, const float *alpha,
           const float *x, const int *incx,
           const float *y, const int *incy,
           float *A, const int *lda);
void dger_(const int *m, const int *n, const double *alpha,
           const double *x, const int *incx,
           const double *y, const int *incy,
           double *A, const int *lda);

// Routines behind Gemv (single / double precision).
void sgemv_(const char *trans, const int *m, const int *n,
            const float *alpha, const float *A, const int *lda,
            const float *x, const int *incx,
            const float *beta, float *y, const int *incy);
void dgemv_(const char *trans, const int *m, const int *n,
            const double *alpha, const double *A, const int *lda,
            const double *x, const int *incx,
            const double *beta, double *y, const int *incy);

// Routines behind Gemm (double / single precision).
void dgemm_(const char *transa, const char *transb,
            const int *m, const int *n, const int *k,
            const double *alpha, const double *A, const int *lda,
            const double *B, const int *ldb, const double *beta,
            double *C, const int *ldc);
void sgemm_(const char *transa, const char *transb,
            const int *m, const int *n, const int *k,
            const float *alpha, const float *A, const int *lda,
            const float *B, const int *ldb, const float *beta,
            float *C, const int *ldc);

} // extern "C"
56
57#else
58#include "gsl/gsl_cblas.h"
59#endif
60
61namespace TMVA
62{
63namespace DNN
64{
65namespace Blas
66{
67
68// Type-Generic Wrappers
69//____________________________________________________________________________
/** Add the vector \p x scaled by \p alpha to the vector \p y, in place
 *  (y := alpha * x + y). Unlike Gemv and Gemm there is no \c beta argument:
 *  \p y itself is not rescaled.
 *
 *  All arguments are forwarded to the BLAS \c axpy routine of the matching
 *  precision; this header defines the \c float and \c double specializations.
 *
 *  \param n     Pointer to the number of vector elements.
 *  \param alpha Pointer to the factor applied to \p x.
 *  \param x     Input vector.
 *  \param incx  Pointer to the increment between elements of \p x.
 *  \param y     Vector that is updated in place.
 *  \param incy  Pointer to the increment between elements of \p y.
 */
template <typename AReal>
inline void Axpy(const int *n, const AReal *alpha,
                 const AReal *x, const int *incx,
                 AReal *y, const int *incy);
75
/** Matrix-vector product: multiply the vector \p x with the matrix \p A and
 *  store the result in \p y.
 *
 *  The transpose flag \p trans, the dimensions \p m and \p n, the
 *  coefficients \p alpha and \p beta, the leading dimension \p lda and the
 *  increments \p incx and \p incy are passed on to the BLAS \c gemv routine of
 *  the matching precision. This header defines the \c float and \c double
 *  specializations.
 */
template <class AReal>
inline void Gemv(const char *trans,
                 const int *m, const int *n,
                 const AReal *alpha,
                 const AReal *A, const int *lda,
                 const AReal *x, const int *incx,
                 const AReal *beta,
                 AReal *y, const int *incy);
82
/** Matrix-matrix product: multiply the matrix \p A with the matrix \p B and
 *  store the result in \p C.
 *
 *  The transpose flags \p transa and \p transb, the dimensions \p m, \p n and
 *  \p k, the coefficients \p alpha and \p beta and the leading dimensions
 *  \p lda, \p ldb and \p ldc are passed on to the BLAS \c gemm routine of the
 *  matching precision. This header defines the \c float and \c double
 *  specializations.
 */
template <class AReal>
inline void Gemm(const char *transa, const char *transb,
                 const int *m, const int *n, const int *k,
                 const AReal *alpha,
                 const AReal *A, const int *lda,
                 const AReal *B, const int *ldb,
                 const AReal *beta,
                 AReal *C, const int *ldc);
90
/** Rank-one update: add the outer product of the vectors \p x and \p y to the
 *  matrix \p A.
 *
 *  The dimensions \p m and \p n, the coefficient \p alpha, the increments
 *  \p incx and \p incy and the leading dimension \p lda are passed on to the
 *  BLAS \c ger routine of the matching precision. This header defines the
 *  \c float and \c double specializations.
 */
template <class AReal>
inline void Ger(const int *m, const int *n,
                const AReal *alpha,
                const AReal *x, const int *incx,
                const AReal *y, const int *incy,
                AReal *A, const int *lda);
97
98// Specializations
99//____________________________________________________________________________
100#ifndef DNN_USE_CBLAS
101
102template<>
103inline void Axpy<double>(const int * n, const double * alpha,
104 const double * x, const int * incx,
105 double * y, const int * incy)
106{
107 daxpy_(n, alpha, x, incx, y, incy);
108}
109
110template<>
111inline void Axpy<float>(const int * n, const float * alpha,
112 const float * x, const int * incx,
113 float * y, const int * incy)
114{
115 saxpy_(n, alpha, x, incx, y, incy);
116}
117
118template<>
119inline void Gemv<double>(const char *trans, const int * m, const int * n,
120 const double * alpha, const double * A, const int * lda,
121 const double * x, const int * incx,
122 const double * beta, double * y, const int * incy)
123{
124 dgemv_(trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
125}
126
127template<>
128inline void Gemv<float>(const char *trans, const int * m, const int * n,
129 const float * alpha, const float * A, const int * lda,
130 const float * x, const int * incx,
131 const float * beta, float * y, const int * incy)
132{
133 sgemv_(trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
134}
135
136template<>
137inline void Gemm<double>(const char *transa, const char *transb,
138 const int * m, const int * n, const int* k,
139 const double * alpha, const double * A, const int * lda,
140 const double * B, const int * ldb, const double * beta,
141 double * C, const int * ldc)
142{
143 dgemm_(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
144}
145
146template<>
147inline void Gemm<float>(const char *transa, const char *transb,
148 const int * m, const int * n, const int* k,
149 const float * alpha, const float * A, const int * lda,
150 const float * B, const int * ldb, const float * beta,
151 float * C, const int * ldc)
152{
153 sgemm_(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
154}
155
156template <>
157inline void Ger<double>(const int * m, const int * n, const double * alpha,
158 const double * x, const int * incx,
159 const double * y, const int * incy,
160 double * A, const int * lda)
161{
162 dger_(m, n, alpha, x, incx, y, incy, A, lda);
163}
164
165template <>
166inline void Ger<float>(const int * m, const int * n, const float * alpha,
167 const float * x, const int * incx,
168 const float * y, const int * incy,
169 float * A, const int * lda)
170{
171 sger_(m, n, alpha, x, incx, y, incy, A, lda);
172}
173
174#else // use cblas
175//--------------------------------------------------------
176// cblas implementation
177//-----------------------------------------------------------
178template<>
179inline void Axpy<double>(const int * n, const double * alpha,
180 const double * x, const int * incx,
181 double * y, const int * incy)
182{
183 cblas_daxpy(*n, *alpha, x, *incx, y, *incy);
184}
185
186template<>
187inline void Axpy<float>(const int * n, const float * alpha,
188 const float * x, const int * incx,
189 float * y, const int * incy)
190{
191 cblas_saxpy(*n, *alpha, x, *incx, y, *incy);
192}
193
194template<>
195inline void Gemv<double>(const char *trans, const int * m, const int * n,
196 const double * alpha, const double * A, const int * lda,
197 const double * x, const int * incx,
198 const double * beta, double * y, const int * incy)
199{
200 CBLAS_TRANSPOSE kTrans = (*trans == 'T') ? CblasTrans : CblasNoTrans;
201 cblas_dgemv(CblasColMajor, kTrans, *m, *n, *alpha, A, *lda, x, *incx, *beta, y, *incy);
202}
203
204template<>
205inline void Gemv<float>(const char *trans, const int * m, const int * n,
206 const float * alpha, const float * A, const int * lda,
207 const float * x, const int * incx,
208 const float * beta, float * y, const int * incy)
209{
210 CBLAS_TRANSPOSE kTrans = (*trans == 'T') ? CblasTrans : CblasNoTrans;
211 cblas_sgemv(CblasColMajor, kTrans, *m, *n, *alpha, A, *lda, x, *incx, *beta, y, *incy);
212}
213
214template<>
215inline void Gemm<double>(const char *transa, const char *transb,
216 const int * m, const int * n, const int* k,
217 const double * alpha, const double * A, const int * lda,
218 const double * B, const int * ldb, const double * beta,
219 double * C, const int * ldc)
220{
221 CBLAS_TRANSPOSE kTransA = (*transa == 'T') ? CblasTrans : CblasNoTrans;
222 CBLAS_TRANSPOSE kTransB = (*transb == 'T') ? CblasTrans : CblasNoTrans;
223 cblas_dgemm(CblasColMajor, kTransA, kTransB, *m, *n, *k, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
224}
225
226template<>
227inline void Gemm<float>(const char *transa, const char *transb,
228 const int * m, const int * n, const int* k,
229 const float * alpha, const float * A, const int * lda,
230 const float * B, const int * ldb, const float * beta,
231 float * C, const int * ldc)
232{
233 CBLAS_TRANSPOSE kTransA = (*transa == 'T') ? CblasTrans : CblasNoTrans;
234 CBLAS_TRANSPOSE kTransB = (*transb == 'T') ? CblasTrans : CblasNoTrans;
235 cblas_sgemm(CblasColMajor, kTransA, kTransB, *m, *n, *k, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
236}
237
238template <>
239inline void Ger<double>(const int * m, const int * n, const double * alpha,
240 const double * x, const int * incx,
241 const double * y, const int * incy,
242 double * A, const int * lda)
243{
244 cblas_dger(CblasColMajor, *m, *n, *alpha, x, *incx, y, *incy, A, *lda);
245}
246
247template <>
248inline void Ger<float>(const int * m, const int * n, const float * alpha,
249 const float * x, const int * incx,
250 const float * y, const int * incy,
251 float * A, const int * lda)
252{
253 cblas_sger(CblasColMajor, *m, *n, *alpha, x, *incx, y, *incy, A, *lda);
254}
255
256#endif
257
258} // namespace Blas
259} // namespace DNN
260} // namespace TMVA
261
262#endif
void dgemv_(const char *trans, const int *m, const int *n, const double *alpha, const double *A, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy)
void dger_(const int *m, const int *n, const double *alpha, const double *x, const int *incx, const double *y, const int *incy, double *A, const int *lda)
void sger_(const int *m, const int *n, const float *alpha, const float *x, const int *incx, const float *y, const int *incy, float *A, const int *lda)
void sgemv_(const char *trans, const int *m, const int *n, const float *alpha, const float *A, const int *lda, const float *x, const int *incx, const float *beta, float *y, const int *incy)
void saxpy_(const int *n, const float *alpha, const float *x, const int *incx, float *y, const int *incy)
void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *A, const int *lda, const double *B, const int *ldb, const double *beta, double *C, const int *ldc)
void daxpy_(const int *n, const double *alpha, const double *x, const int *incx, double *y, const int *incy)
void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *A, const int *lda, const float *B, const int *ldb, const float *beta, float *C, const int *ldc)
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16
void Gemm< float >(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *A, const int *lda, const float *B, const int *ldb, const float *beta, float *C, const int *ldc)
Definition Blas.h:147
void Axpy(const int *n, const AReal *alpha, const AReal *x, const int *incx, AReal *y, const int *incy)
Add the vector x scaled by alpha to the vector y, in place (y := alpha * x + y).
void Gemv< float >(const char *trans, const int *m, const int *n, const float *alpha, const float *A, const int *lda, const float *x, const int *incx, const float *beta, float *y, const int *incy)
Definition Blas.h:128
void Gemm(const char *transa, const char *transb, const int *m, const int *n, const int *k, const AReal *alpha, const AReal *A, const int *lda, const AReal *B, const int *ldb, const AReal *beta, AReal *C, const int *ldc)
Multiply the matrix A with the matrix B and store the result in C.
void Ger< float >(const int *m, const int *n, const float *alpha, const float *x, const int *incx, const float *y, const int *incy, float *A, const int *lda)
Definition Blas.h:166
void Gemm< double >(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *A, const int *lda, const double *B, const int *ldb, const double *beta, double *C, const int *ldc)
Definition Blas.h:137
void Gemv(const char *trans, const int *m, const int *n, const AReal *alpha, const AReal *A, const int *lda, const AReal *x, const int *incx, const AReal *beta, AReal *y, const int *incy)
Multiply the vector x with the matrix A and store the result in y.
void Ger< double >(const int *m, const int *n, const double *alpha, const double *x, const int *incx, const double *y, const int *incy, double *A, const int *lda)
Definition Blas.h:157
void Axpy< float >(const int *n, const float *alpha, const float *x, const int *incx, float *y, const int *incy)
Definition Blas.h:111
void Axpy< double >(const int *n, const double *alpha, const double *x, const int *incx, double *y, const int *incy)
Definition Blas.h:103
void Gemv< double >(const char *trans, const int *m, const int *n, const double *alpha, const double *A, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy)
Definition Blas.h:119
void Ger(const int *m, const int *n, const AReal *alpha, const AReal *x, const int *incx, const AReal *y, const int *incy, AReal *A, const int *lda)
Add the outer product of x and y to the matrix A.
create variable transformations
auto * m
Definition textangle.C:8