Logo ROOT   6.16/01
Reference Guide
Blas.h
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 20/07/16
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12///////////////////////////////////////////////////////////////////
13// Declarations of the BLAS functions used for the forward and //
14// backward propagation of activation through neural networks on //
15// CPUs. //
16///////////////////////////////////////////////////////////////////
17
18#ifndef TMVA_DNN_ARCHITECTURES_CPU_BLAS
19#define TMVA_DNN_ARCHITECTURES_CPU_BLAS
20
21#include <iostream>
22
23#ifndef DNN_USE_CBLAS
// External Library Routines
//____________________________________________________________________________
// Prototypes of the externally provided BLAS routines (trailing-underscore
// symbols with C linkage). Every argument, scalars included, is passed by
// pointer.
extern "C" {

// axpy: single and double precision.
void saxpy_(const int *n, const float *alpha, const float *x, const int *incx,
            float *y, const int *incy);
void daxpy_(const int *n, const double *alpha, const double *x, const int *incx,
            double *y, const int *incy);

// ger: single and double precision.
void sger_(const int *m, const int *n, const float *alpha,
           const float *x, const int *incx,
           const float *y, const int *incy,
           float *A, const int *lda);
void dger_(const int *m, const int *n, const double *alpha,
           const double *x, const int *incx,
           const double *y, const int *incy,
           double *A, const int *lda);

// gemv: single and double precision.
void sgemv_(const char *trans, const int *m, const int *n,
            const float *alpha, const float *A, const int *lda,
            const float *x, const int *incx,
            const float *beta, float *y, const int *incy);
void dgemv_(const char *trans, const int *m, const int *n,
            const double *alpha, const double *A, const int *lda,
            const double *x, const int *incx,
            const double *beta, double *y, const int *incy);

// gemm: double and single precision.
void dgemm_(const char *transa, const char *transb,
            const int *m, const int *n, const int *k,
            const double *alpha, const double *A, const int *lda,
            const double *B, const int *ldb, const double *beta,
            double *C, const int *ldc);
void sgemm_(const char *transa, const char *transb,
            const int *m, const int *n, const int *k,
            const float *alpha, const float *A, const int *lda,
            const float *B, const int *ldb, const float *beta,
            float *C, const int *ldc);

} // extern "C"
56
57#else
58#include "gsl/gsl_cblas.h"
59#endif
60
61namespace TMVA
62{
63namespace DNN
64{
65namespace Blas
66{
67
68// Type-Generic Wrappers
69//____________________________________________________________________________
/** Add the vector \p x scaled by \p alpha to the vector \p y
 *  (y := alpha * x + y). Only declared here; the float and double
 *  specializations provide the implementation. */
template <typename Real_t>
inline void Axpy(const int *n,
                 const Real_t *alpha,
                 const Real_t *x, const int *incx,
                 Real_t *y, const int *incy);
75
/** Multiply the vector \p x with the matrix \p A and store the result in
 *  \p y. Only declared here; the float and double specializations provide
 *  the implementation. */
template <typename Real_t>
inline void Gemv(const char *trans,
                 const int *m, const int *n,
                 const Real_t *alpha,
                 const Real_t *A, const int *lda,
                 const Real_t *x, const int *incx,
                 const Real_t *beta,
                 Real_t *y, const int *incy);
82
/** Multiply the matrix \p A with the matrix \p B and store the result in
 *  \p C. Only declared here; the float and double specializations provide
 *  the implementation. */
template <typename Real_t>
inline void Gemm(const char *transa, const char *transb,
                 const int *m, const int *n, const int *k,
                 const Real_t *alpha,
                 const Real_t *A, const int *lda,
                 const Real_t *B, const int *ldb,
                 const Real_t *beta,
                 Real_t *C, const int *ldc);
90
/** Add the outer product of \p x and \p y to the matrix \p A. Only declared
 *  here; the float and double specializations provide the implementation. */
template <typename Real_t>
inline void Ger(const int *m, const int *n,
                const Real_t *alpha,
                const Real_t *x, const int *incx,
                const Real_t *y, const int *incy,
                Real_t *A, const int *lda);
97
98// Specializations
99//____________________________________________________________________________
100#ifndef DNN_USE_CBLAS
101
102template<>
103inline void Axpy<double>(const int * n, const double * alpha,
104 const double * x, const int * incx,
105 double * y, const int * incy)
106{
107#ifdef DNN_USE_CBLAS
108 cblas_daxpy(*n, *alpha, x, *incx, y, *incy);
109#else
110 daxpy_(n, alpha, x, incx, y, incy);
111#endif
112}
113
114template<>
115inline void Axpy<float>(const int * n, const float * alpha,
116 const float * x, const int * incx,
117 float * y, const int * incy)
118{
119 saxpy_(n, alpha, x, incx, y, incy);
120}
121
122template<>
123inline void Gemv<double>(const char *trans, const int * m, const int * n,
124 const double * alpha, const double * A, const int * lda,
125 const double * x, const int * incx,
126 const double * beta, double * y, const int * incy)
127{
128 dgemv_(trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
129}
130
131template<>
132inline void Gemv<float>(const char *trans, const int * m, const int * n,
133 const float * alpha, const float * A, const int * lda,
134 const float * x, const int * incx,
135 const float * beta, float * y, const int * incy)
136{
137 sgemv_(trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
138}
139
140template<>
141inline void Gemm<double>(const char *transa, const char *transb,
142 const int * m, const int * n, const int* k,
143 const double * alpha, const double * A, const int * lda,
144 const double * B, const int * ldb, const double * beta,
145 double * C, const int * ldc)
146{
147 dgemm_(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
148}
149
150template<>
151inline void Gemm<float>(const char *transa, const char *transb,
152 const int * m, const int * n, const int* k,
153 const float * alpha, const float * A, const int * lda,
154 const float * B, const int * ldb, const float * beta,
155 float * C, const int * ldc)
156{
157 sgemm_(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
158}
159
160template <>
161inline void Ger<double>(const int * m, const int * n, const double * alpha,
162 const double * x, const int * incx,
163 const double * y, const int * incy,
164 double * A, const int * lda)
165{
166 dger_(m, n, alpha, x, incx, y, incy, A, lda);
167}
168
169template <>
170inline void Ger<float>(const int * m, const int * n, const float * alpha,
171 const float * x, const int * incx,
172 const float * y, const int * incy,
173 float * A, const int * lda)
174{
175 sger_(m, n, alpha, x, incx, y, incy, A, lda);
176}
177
178#else
179//--------------------------------------------------------
180// cblas implementation
181//-----------------------------------------------------------
182template<>
183inline void Axpy<double>(const int * n, const double * alpha,
184 const double * x, const int * incx,
185 double * y, const int * incy)
186{
187 cblas_daxpy(*n, *alpha, x, *incx, y, *incy);
188}
189
190template<>
191inline void Axpy<float>(const int * n, const float * alpha,
192 const float * x, const int * incx,
193 float * y, const int * incy)
194{
195 cblas_saxpy(*n, *alpha, x, *incx, y, *incy);
196}
197
198template<>
199inline void Gemv<double>(const char *trans, const int * m, const int * n,
200 const double * alpha, const double * A, const int * lda,
201 const double * x, const int * incx,
202 const double * beta, double * y, const int * incy)
203{
204 CBLAS_TRANSPOSE kTrans = (*trans == 'T') ? CblasTrans : CblasNoTrans;
205 cblas_dgemv(CblasColMajor, kTrans, *m, *n, *alpha, A, *lda, x, *incx, *beta, y, *incy);
206}
207
208template<>
209inline void Gemv<float>(const char *trans, const int * m, const int * n,
210 const float * alpha, const float * A, const int * lda,
211 const float * x, const int * incx,
212 const float * beta, float * y, const int * incy)
213{
214 CBLAS_TRANSPOSE kTrans = (*trans == 'T') ? CblasTrans : CblasNoTrans;
215 cblas_sgemv(CblasColMajor, kTrans, *m, *n, *alpha, A, *lda, x, *incx, *beta, y, *incy);
216}
217
218template<>
219inline void Gemm<double>(const char *transa, const char *transb,
220 const int * m, const int * n, const int* k,
221 const double * alpha, const double * A, const int * lda,
222 const double * B, const int * ldb, const double * beta,
223 double * C, const int * ldc)
224{
225 CBLAS_TRANSPOSE kTransA = (*transa == 'T') ? CblasTrans : CblasNoTrans;
226 CBLAS_TRANSPOSE kTransB = (*transb == 'T') ? CblasTrans : CblasNoTrans;
227 cblas_dgemm(CblasColMajor, kTransA, kTransB, *m, *n, *k, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
228}
229
230template<>
231inline void Gemm<float>(const char *transa, const char *transb,
232 const int * m, const int * n, const int* k,
233 const float * alpha, const float * A, const int * lda,
234 const float * B, const int * ldb, const float * beta,
235 float * C, const int * ldc)
236{
237 CBLAS_TRANSPOSE kTransA = (*transa == 'T') ? CblasTrans : CblasNoTrans;
238 CBLAS_TRANSPOSE kTransB = (*transb == 'T') ? CblasTrans : CblasNoTrans;
239 cblas_sgemm(CblasColMajor, kTransA, kTransB, *m, *n, *k, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
240}
241
242template <>
243inline void Ger<double>(const int * m, const int * n, const double * alpha,
244 const double * x, const int * incx,
245 const double * y, const int * incy,
246 double * A, const int * lda)
247{
248 cblas_dger(CblasColMajor, *m, *n, *alpha, x, *incx, y, *incy, A, *lda);
249}
250
251template <>
252inline void Ger<float>(const int * m, const int * n, const float * alpha,
253 const float * x, const int * incx,
254 const float * y, const int * incy,
255 float * A, const int * lda)
256{
257 cblas_sger(CblasColMajor, *m, *n, *alpha, x, *incx, y, *incy, A, *lda);
258}
259
260#endif
261
262} // namespace Blas
263} // namespace DNN
264} // namespace TMVA
265
266#endif
void dgemv_(const char *trans, const int *m, const int *n, const double *alpha, const double *A, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy)
void dger_(const int *m, const int *n, const double *alpha, const double *x, const int *incx, const double *y, const int *incy, double *A, const int *lda)
void sger_(const int *m, const int *n, const float *alpha, const float *x, const int *incx, const float *y, const int *incy, float *A, const int *lda)
void sgemv_(const char *trans, const int *m, const int *n, const float *alpha, const float *A, const int *lda, const float *x, const int *incx, const float *beta, float *y, const int *incy)
void saxpy_(const int *n, const float *alpha, const float *x, const int *incx, float *y, const int *incy)
void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *A, const int *lda, const double *B, const int *ldb, const double *beta, double *C, const int *ldc)
void daxpy_(const int *n, const double *alpha, const double *x, const int *incx, double *y, const int *incy)
void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *A, const int *lda, const float *B, const int *ldb, const float *beta, float *C, const int *ldc)
float Real_t
Definition: RtypesCore.h:64
double beta(double x, double y)
Calculates the beta function.
Double_t y[n]
Definition: legend1.C:17
Double_t x[n]
Definition: legend1.C:17
const Int_t n
Definition: legend1.C:16
static double B[]
static double A[]
static double C[]
void Gemm< float >(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *A, const int *lda, const float *B, const int *ldb, const float *beta, float *C, const int *ldc)
Definition: Blas.h:151
void Ger(const int *m, const int *n, const Real_t *alpha, const Real_t *x, const int *incx, const Real_t *y, const int *incy, Real_t *A, const int *lda)
Add the outer product of x and y to the matrix A.
void Gemv(const char *trans, const int *m, const int *n, const Real_t *alpha, const Real_t *A, const int *lda, const Real_t *x, const int *incx, const Real_t *beta, Real_t *y, const int *incy)
Multiply the vector x with the matrix A and store the result in y.
void Gemv< float >(const char *trans, const int *m, const int *n, const float *alpha, const float *A, const int *lda, const float *x, const int *incx, const float *beta, float *y, const int *incy)
Definition: Blas.h:132
void Axpy(const int *n, const Real_t *alpha, const Real_t *x, const int *incx, Real_t *y, const int *incy)
Add the vector x scaled by alpha to the vector y (y := alpha * x + y).
void Ger< float >(const int *m, const int *n, const float *alpha, const float *x, const int *incx, const float *y, const int *incy, float *A, const int *lda)
Definition: Blas.h:170
void Gemm< double >(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *A, const int *lda, const double *B, const int *ldb, const double *beta, double *C, const int *ldc)
Definition: Blas.h:141
void Ger< double >(const int *m, const int *n, const double *alpha, const double *x, const int *incx, const double *y, const int *incy, double *A, const int *lda)
Definition: Blas.h:161
void Axpy< float >(const int *n, const float *alpha, const float *x, const int *incx, float *y, const int *incy)
Definition: Blas.h:115
void Axpy< double >(const int *n, const double *alpha, const double *x, const int *incx, double *y, const int *incy)
Definition: Blas.h:103
void Gemv< double >(const char *trans, const int *m, const int *n, const double *alpha, const double *A, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy)
Definition: Blas.h:123
void Gemm(const char *transa, const char *transb, const int *m, const int *n, const int *k, const Real_t *alpha, const Real_t *A, const int *lda, const Real_t *B, const int *ldb, const Real_t *beta, Real_t *C, const int *ldc)
Multiply the matrix A with the matrix B and store the result in C.
Abstract ClassifierFactory template that handles arbitrary types.
auto * m
Definition: textangle.C:8