blas.h revision 9c3043ff3bf31a6a81810b4ce9e87ef936f1f529
1/* Copyright 2015 Google Inc. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Exposes the family of BLAS routines as pre-canned high performance calls for
17// use in conjunction with the StreamExecutor abstraction.
18//
19// Note that this interface is optionally supported by platforms; see
20// StreamExecutor::SupportsBlas() for details.
21//
22// This abstraction makes it simple to entrain BLAS operations on GPU data into
23// a Stream -- users typically will not use this API directly, but will use the
24// Stream builder methods to entrain these operations "under the hood". For
25// example:
26//
27//  DeviceMemory<float> x = stream_exec->AllocateArray<float>(1024);
28//  DeviceMemory<float> y = stream_exec->AllocateArray<float>(1024);
29//  // ... populate x and y ...
30//  Stream stream{stream_exec};
31//  stream
32//    .Init()
33//    .ThenBlasAxpy(1024, 5.5, x, 1, &y, 1)
34//    .BlockHostUntilDone();
35//
// By using stream operations in this manner, the user can easily intermix
// custom kernel launches (via Stream::ThenLaunch()) with these pre-canned BLAS
// routines.
39
40#ifndef TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
41#define TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
42
43#include <complex>
44#include "tensorflow/stream_executor/platform/port.h"
45
46#include "tensorflow/stream_executor/lib/array_slice.h"
47#include "tensorflow/stream_executor/platform/port.h"
48
49namespace perftools {
50namespace gputools {
51
// Forward declarations: this header only refers to Stream and DeviceMemory<T>
// through pointers and references, so their full definitions are not needed.
class Stream;

template <typename ElemT>
class DeviceMemory;
56
57namespace blas {
58
59// Specifies whether the input matrix will be transposed or
60// transposed+conjugated before any BLAS operations.
61enum class Transpose { kNoTranspose, kTranspose, kConjugateTranspose };
62
63// Returns a name for t.
64string TransposeString(Transpose t);
65
66// Specifies whether the upper or lower triangular part of a
67// symmetric/Hermitian matrix is used.
68enum class UpperLower { kUpper, kLower };
69
70// Returns a name for ul.
71string UpperLowerString(UpperLower ul);
72
73// Specifies whether a matrix is unit triangular.
74enum class Diagonal { kUnit, kNonUnit };
75
76// Returns a name for d.
77string DiagonalString(Diagonal d);
78
79// Specifies whether a Hermitian matrix appears on the left or right in
80// operation.
81enum class Side { kLeft, kRight };
82
83// Returns a name for s.
84string SideString(Side s);
85
// BLAS support interface -- this can be derived from a GPU executor when the
// underlying platform has a BLAS library implementation available. See
// StreamExecutor::AsBlas().
//
// Thread-hostile: CUDA associates a CUDA context with a particular thread in
// the system. Any operation that a user attempts to perform by enqueueing BLAS
// operations on a thread not associated with the CUDA context has unknown
// behavior at the current time; see b/13176597.
94class BlasSupport {
95 public:
96  virtual ~BlasSupport() {}
97
98  // Computes the sum of magnitudes of the vector elements.
99  // result <- |Re x(1)| + |Im x(1)| + |Re  x(2)| + |Im  x(2)|+ ... + |Re  x(n)|
100  // + |Im x(n)|.
101  // Note that Im x(i) = 0 for real types float/double.
102  virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
103                          const DeviceMemory<float> &x, int incx,
104                          DeviceMemory<float> *result) = 0;
105  virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
106                          const DeviceMemory<double> &x, int incx,
107                          DeviceMemory<double> *result) = 0;
108  virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
109                          const DeviceMemory<std::complex<float>> &x, int incx,
110                          DeviceMemory<float> *result) = 0;
111  virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
112                          const DeviceMemory<std::complex<double>> &x, int incx,
113                          DeviceMemory<double> *result) = 0;
114
115  // Performs a BLAS y <- ax+y operation.
116  virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
117                          const DeviceMemory<float> &x, int incx,
118                          DeviceMemory<float> *y, int incy) = 0;
119  virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
120                          const DeviceMemory<double> &x, int incx,
121                          DeviceMemory<double> *y, int incy) = 0;
122  virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count,
123                          std::complex<float> alpha,
124                          const DeviceMemory<std::complex<float>> &x, int incx,
125                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
126  virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count,
127                          std::complex<double> alpha,
128                          const DeviceMemory<std::complex<double>> &x, int incx,
129                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
130
131  // Copies vector to another vector: y <- x.
132  virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
133                          const DeviceMemory<float> &x, int incx,
134                          DeviceMemory<float> *y, int incy) = 0;
135  virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
136                          const DeviceMemory<double> &x, int incx,
137                          DeviceMemory<double> *y, int incy) = 0;
138  virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
139                          const DeviceMemory<std::complex<float>> &x, int incx,
140                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
141  virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
142                          const DeviceMemory<std::complex<double>> &x, int incx,
143                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
144
145  // Performs a BLAS dot product result <- x . y.
146  virtual bool DoBlasDot(Stream *stream, uint64 elem_count,
147                         const DeviceMemory<float> &x, int incx,
148                         const DeviceMemory<float> &y, int incy,
149                         DeviceMemory<float> *result) = 0;
150  virtual bool DoBlasDot(Stream *stream, uint64 elem_count,
151                         const DeviceMemory<double> &x, int incx,
152                         const DeviceMemory<double> &y, int incy,
153                         DeviceMemory<double> *result) = 0;
154
155  // Performs a BLAS dot product result <- conj(x) . y for complex types.
156  virtual bool DoBlasDotc(Stream *stream, uint64 elem_count,
157                          const DeviceMemory<std::complex<float>> &x, int incx,
158                          const DeviceMemory<std::complex<float>> &y, int incy,
159                          DeviceMemory<std::complex<float>> *result) = 0;
160  virtual bool DoBlasDotc(Stream *stream, uint64 elem_count,
161                          const DeviceMemory<std::complex<double>> &x, int incx,
162                          const DeviceMemory<std::complex<double>> &y, int incy,
163                          DeviceMemory<std::complex<double>> *result) = 0;
164
165  // Performs a BLAS dot product result <- x . y for complex types. Note that
166  // x is unconjugated in this routine.
167  virtual bool DoBlasDotu(Stream *stream, uint64 elem_count,
168                          const DeviceMemory<std::complex<float>> &x, int incx,
169                          const DeviceMemory<std::complex<float>> &y, int incy,
170                          DeviceMemory<std::complex<float>> *result) = 0;
171  virtual bool DoBlasDotu(Stream *stream, uint64 elem_count,
172                          const DeviceMemory<std::complex<double>> &x, int incx,
173                          const DeviceMemory<std::complex<double>> &y, int incy,
174                          DeviceMemory<std::complex<double>> *result) = 0;
175
176  // Computes the Euclidean norm of a vector: result <- ||x||.
177  // See the following link for more information of Euclidean norm:
178  // http://en.wikipedia.org/wiki/Norm_(mathematics)#Euclidean_norm
179  virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
180                          const DeviceMemory<float> &x, int incx,
181                          DeviceMemory<float> *result) = 0;
182  virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
183                          const DeviceMemory<double> &x, int incx,
184                          DeviceMemory<double> *result) = 0;
185  virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
186                          const DeviceMemory<std::complex<float>> &x, int incx,
187                          DeviceMemory<float> *result) = 0;
188  virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
189                          const DeviceMemory<std::complex<double>> &x, int incx,
190                          DeviceMemory<double> *result) = 0;
191
192  // Performs rotation of points in the plane:
193  // x(i) = c*x(i) + s*y(i)
194  // y(i) = c*y(i) - s*x(i).
195  virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
196                         DeviceMemory<float> *x, int incx,
197                         DeviceMemory<float> *y, int incy, float c,
198                         float s) = 0;
199  virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
200                         DeviceMemory<double> *x, int incx,
201                         DeviceMemory<double> *y, int incy, double c,
202                         double s) = 0;
203  virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
204                         DeviceMemory<std::complex<float>> *x, int incx,
205                         DeviceMemory<std::complex<float>> *y, int incy,
206                         float c, float s) = 0;
207  virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
208                         DeviceMemory<std::complex<double>> *x, int incx,
209                         DeviceMemory<std::complex<double>> *y, int incy,
210                         double c, double s) = 0;
211
212  // Computes the parameters for a Givens rotation.
213  // Given the Cartesian coordinates (a, b) of a point, these routines return
214  // the parameters c, s, r, and z associated with the Givens rotation. The
215  // parameters c and s define a unitary matrix such that:
216  //
217  //   |  c s |.| a | = | r |
218  //   | -s c | | b |   | 0 |
219  //
220  // The parameter z is defined such that if |a| > |b|, z is s; otherwise if
221  // c is not 0 z is 1/c; otherwise z is 1.
222  virtual bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
223                          DeviceMemory<float> *b, DeviceMemory<float> *c,
224                          DeviceMemory<float> *s) = 0;
225  virtual bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
226                          DeviceMemory<double> *b, DeviceMemory<double> *c,
227                          DeviceMemory<double> *s) = 0;
228  virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
229                          DeviceMemory<std::complex<float>> *b,
230                          DeviceMemory<float> *c,
231                          DeviceMemory<std::complex<float>> *s) = 0;
232  virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
233                          DeviceMemory<std::complex<double>> *b,
234                          DeviceMemory<double> *c,
235                          DeviceMemory<std::complex<double>> *s) = 0;
236
237  // Performs modified Givens rotation of points in the plane.
238  // Given two vectors x and y, each vector element of these vectors is replaced
239  // as follows:
240  //
241  //   | x(i) | =  H | x(i) |
242  //   | y(i) |      | y(i) |
243  //
244  // for i=1 to n, where H is a modified Givens transformation matrix whose
245  // values are stored in the param[1] through param[4] array.
246  // For more information please Google this routine.
247  virtual bool DoBlasRotm(Stream *stream, uint64 elem_count,
248                          DeviceMemory<float> *x, int incx,
249                          DeviceMemory<float> *y, int incy,
250                          const DeviceMemory<float> &param) = 0;
251  virtual bool DoBlasRotm(Stream *stream, uint64 elem_count,
252                          DeviceMemory<double> *x, int incx,
253                          DeviceMemory<double> *y, int incy,
254                          const DeviceMemory<double> &param) = 0;
255
256  // Computes the parameters for a modified Givens rotation.
257  // Given Cartesian coordinates (x1, y1) of an input vector, these routines
258  // compute the components of a modified Givens transformation matrix H that
259  // zeros the y-component of the resulting vector:
260  //
261  //   | x1 | =  H | x1 * sqrt(d1) |
262  //   |  0 |      | y1 * sqrt(d1) |
263  //
264  // For more information please Google this routine.
265  virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
266                           DeviceMemory<float> *d2, DeviceMemory<float> *x1,
267                           const DeviceMemory<float> &y1,
268                           DeviceMemory<float> *param) = 0;
269  virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
270                           DeviceMemory<double> *d2, DeviceMemory<double> *x1,
271                           const DeviceMemory<double> &y1,
272                           DeviceMemory<double> *param) = 0;
273
274  // Computes the product of a vector by a scalar: x <- a*x.
275  virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
276                          DeviceMemory<float> *x, int incx) = 0;
277  virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
278                          DeviceMemory<double> *x, int incx) = 0;
279  virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
280                          DeviceMemory<std::complex<float>> *x, int incx) = 0;
281  virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
282                          DeviceMemory<std::complex<double>> *x, int incx) = 0;
283  virtual bool DoBlasScal(Stream *stream, uint64 elem_count,
284                          std::complex<float> alpha,
285                          DeviceMemory<std::complex<float>> *x, int incx) = 0;
286  virtual bool DoBlasScal(Stream *stream, uint64 elem_count,
287                          std::complex<double> alpha,
288                          DeviceMemory<std::complex<double>> *x, int incx) = 0;
289
290  // Swaps a vector with another vector.
291  virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
292                          DeviceMemory<float> *x, int incx,
293                          DeviceMemory<float> *y, int incy) = 0;
294  virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
295                          DeviceMemory<double> *x, int incx,
296                          DeviceMemory<double> *y, int incy) = 0;
297  virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
298                          DeviceMemory<std::complex<float>> *x, int incx,
299                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
300  virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
301                          DeviceMemory<std::complex<double>> *x, int incx,
302                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
303
304  // Finds the index of the element with maximum absolute value.
305  virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
306                           const DeviceMemory<float> &x, int incx,
307                           DeviceMemory<int> *result) = 0;
308  virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
309                           const DeviceMemory<double> &x, int incx,
310                           DeviceMemory<int> *result) = 0;
311  virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
312                           const DeviceMemory<std::complex<float>> &x, int incx,
313                           DeviceMemory<int> *result) = 0;
314  virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
315                           const DeviceMemory<std::complex<double>> &x,
316                           int incx, DeviceMemory<int> *result) = 0;
317
318  // Finds the index of the element with minimum absolute value.
319  virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
320                           const DeviceMemory<float> &x, int incx,
321                           DeviceMemory<int> *result) = 0;
322  virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
323                           const DeviceMemory<double> &x, int incx,
324                           DeviceMemory<int> *result) = 0;
325  virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
326                           const DeviceMemory<std::complex<float>> &x, int incx,
327                           DeviceMemory<int> *result) = 0;
328  virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
329                           const DeviceMemory<std::complex<double>> &x,
330                           int incx, DeviceMemory<int> *result) = 0;
331
332  // Computes a matrix-vector product using a general band matrix:
333  //
334  //     y <- alpha * a * x + beta * y,
335  // or
336  //     y <- alpha * a' * x + beta * y,
337  // or
338  //     y <- alpha * conj(a') * x + beta * y,
339  //
340  // alpha and beta are scalars; a is an m-by-n general band matrix, with kl
341  // sub-diagonals and ku super-diagonals; x is a vector with
342  // n(trans==kNoTranspose)/m(otherwise) elements;
343  // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements.
344  virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
345                          uint64 n, uint64 kl, uint64 ku, float alpha,
346                          const DeviceMemory<float> &a, int lda,
347                          const DeviceMemory<float> &x, int incx, float beta,
348                          DeviceMemory<float> *y, int incy) = 0;
349  virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
350                          uint64 n, uint64 kl, uint64 ku, double alpha,
351                          const DeviceMemory<double> &a, int lda,
352                          const DeviceMemory<double> &x, int incx, double beta,
353                          DeviceMemory<double> *y, int incy) = 0;
354  virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
355                          uint64 n, uint64 kl, uint64 ku,
356                          std::complex<float> alpha,
357                          const DeviceMemory<std::complex<float>> &a, int lda,
358                          const DeviceMemory<std::complex<float>> &x, int incx,
359                          std::complex<float> beta,
360                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
361  virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
362                          uint64 n, uint64 kl, uint64 ku,
363                          std::complex<double> alpha,
364                          const DeviceMemory<std::complex<double>> &a, int lda,
365                          const DeviceMemory<std::complex<double>> &x, int incx,
366                          std::complex<double> beta,
367                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
368
369  // Computes a matrix-vector product using a general matrix.
370  //
371  //     y <- alpha * a * x + beta * y,
372  // or
373  //     y <- alpha * a' * x + beta * y,
374  // or
375  //     y <- alpha * conj(a') * x + beta * y,
376  //
377  // alpha and beta are scalars; a is an m-by-n general matrix; x is a vector
378  // with n(trans==kNoTranspose)/m(otherwise) elements;
379  // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements.
380  virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
381                          uint64 n, float alpha, const DeviceMemory<float> &a,
382                          int lda, const DeviceMemory<float> &x, int incx,
383                          float beta, DeviceMemory<float> *y, int incy) = 0;
384  virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
385                          uint64 n, double alpha, const DeviceMemory<double> &a,
386                          int lda, const DeviceMemory<double> &x, int incx,
387                          double beta, DeviceMemory<double> *y, int incy) = 0;
388  virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
389                          uint64 n, std::complex<float> alpha,
390                          const DeviceMemory<std::complex<float>> &a, int lda,
391                          const DeviceMemory<std::complex<float>> &x, int incx,
392                          std::complex<float> beta,
393                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
394  virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
395                          uint64 n, std::complex<double> alpha,
396                          const DeviceMemory<std::complex<double>> &a, int lda,
397                          const DeviceMemory<std::complex<double>> &x, int incx,
398                          std::complex<double> beta,
399                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
400
401  // Performs a rank-1 update of a general matrix.
402  //
403  //     a <- alpha * x * y' + a,
404  //
405  // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
406  // an m-by-n general matrix.
407  virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
408                         const DeviceMemory<float> &x, int incx,
409                         const DeviceMemory<float> &y, int incy,
410                         DeviceMemory<float> *a, int lda) = 0;
411  virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
412                         const DeviceMemory<double> &x, int incx,
413                         const DeviceMemory<double> &y, int incy,
414                         DeviceMemory<double> *a, int lda) = 0;
415
416  // Performs a rank-1 update (conjugated) of a general matrix.
417  //
418  //     a <- alpha * x * conj(y') + a,
419  //
420  // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
421  // an m-by-n general matrix.
422  virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,
423                          std::complex<float> alpha,
424                          const DeviceMemory<std::complex<float>> &x, int incx,
425                          const DeviceMemory<std::complex<float>> &y, int incy,
426                          DeviceMemory<std::complex<float>> *a, int lda) = 0;
427  virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,
428                          std::complex<double> alpha,
429                          const DeviceMemory<std::complex<double>> &x, int incx,
430                          const DeviceMemory<std::complex<double>> &y, int incy,
431                          DeviceMemory<std::complex<double>> *a, int lda) = 0;
432
433  // Performs a rank-1 update (unconjugated) of a general matrix.
434  //
435  //     a <- alpha * x * y' + a,
436  //
437  // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
438  // an m-by-n general matrix.
439  virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,
440                          std::complex<float> alpha,
441                          const DeviceMemory<std::complex<float>> &x, int incx,
442                          const DeviceMemory<std::complex<float>> &y, int incy,
443                          DeviceMemory<std::complex<float>> *a, int lda) = 0;
444  virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,
445                          std::complex<double> alpha,
446                          const DeviceMemory<std::complex<double>> &x, int incx,
447                          const DeviceMemory<std::complex<double>> &y, int incy,
448                          DeviceMemory<std::complex<double>> *a, int lda) = 0;
449
450  // Computes a matrix-vector product using a Hermitian band matrix.
451  //
452  //     y <- alpha * a * x + beta * y,
453  //
454  // alpha and beta are scalars; a is an n-by-n Hermitian band matrix, with k
455  // super-diagonals; x and y are n-element vectors.
456  virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
457                          uint64 k, std::complex<float> alpha,
458                          const DeviceMemory<std::complex<float>> &a, int lda,
459                          const DeviceMemory<std::complex<float>> &x, int incx,
460                          std::complex<float> beta,
461                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
462  virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
463                          uint64 k, std::complex<double> alpha,
464                          const DeviceMemory<std::complex<double>> &a, int lda,
465                          const DeviceMemory<std::complex<double>> &x, int incx,
466                          std::complex<double> beta,
467                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
468
469  // Computes a matrix-vector product using a Hermitian matrix.
470  //
471  //     y <- alpha * a * x + beta * y,
472  //
473  // alpha and beta are scalars; a is an n-by-n Hermitian matrix; x and y are
474  // n-element vectors.
475  virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
476                          std::complex<float> alpha,
477                          const DeviceMemory<std::complex<float>> &a, int lda,
478                          const DeviceMemory<std::complex<float>> &x, int incx,
479                          std::complex<float> beta,
480                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
481  virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
482                          std::complex<double> alpha,
483                          const DeviceMemory<std::complex<double>> &a, int lda,
484                          const DeviceMemory<std::complex<double>> &x, int incx,
485                          std::complex<double> beta,
486                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
487
488  // Performs a rank-1 update of a Hermitian matrix.
489  //
490  //     a <- alpha * x * conj(x') + a,
491  //
492  // alpha is a scalar; x is an n-element vector; a is an n-by-n Hermitian
493  // matrix.
494  virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
495                         float alpha,
496                         const DeviceMemory<std::complex<float>> &x, int incx,
497                         DeviceMemory<std::complex<float>> *a, int lda) = 0;
498  virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
499                         double alpha,
500                         const DeviceMemory<std::complex<double>> &x, int incx,
501                         DeviceMemory<std::complex<double>> *a, int lda) = 0;
502
503  // Performs a rank-2 update of a Hermitian matrix.
504  //
505  //     a <- alpha * x * conj(x') + conj(alpha) * y * conj(x') + a,
506  //
507  // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian
508  // matrix.
509  virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
510                          std::complex<float> alpha,
511                          const DeviceMemory<std::complex<float>> &x, int incx,
512                          const DeviceMemory<std::complex<float>> &y, int incy,
513                          DeviceMemory<std::complex<float>> *a, int lda) = 0;
514  virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
515                          std::complex<double> alpha,
516                          const DeviceMemory<std::complex<double>> &x, int incx,
517                          const DeviceMemory<std::complex<double>> &y, int incy,
518                          DeviceMemory<std::complex<double>> *a, int lda) = 0;
519
520  // Computes a matrix-vector product using a Hermitian packed matrix.
521  //
522  //     y <- alpha * a * x + beta * y,
523  //
524  // alpha and beta are scalars; a is an n-by-n Hermitian matrix, supplied in
525  // packed form; x and y are n-element vectors.
526  virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
527                          std::complex<float> alpha,
528                          const DeviceMemory<std::complex<float>> &ap,
529                          const DeviceMemory<std::complex<float>> &x, int incx,
530                          std::complex<float> beta,
531                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
532  virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
533                          std::complex<double> alpha,
534                          const DeviceMemory<std::complex<double>> &ap,
535                          const DeviceMemory<std::complex<double>> &x, int incx,
536                          std::complex<double> beta,
537                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
538
539  // Performs a rank-1 update of a Hermitian packed matrix.
540  //
541  //     a <- alpha * x * conj(x') + a,
542  //
543  // alpha is a scalar; x is an n-element vector; a is an n-by-n Hermitian
544  // matrix, supplied in packed form.
545  virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
546                         float alpha,
547                         const DeviceMemory<std::complex<float>> &x, int incx,
548                         DeviceMemory<std::complex<float>> *ap) = 0;
549  virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
550                         double alpha,
551                         const DeviceMemory<std::complex<double>> &x, int incx,
552                         DeviceMemory<std::complex<double>> *ap) = 0;
553
554  // Performs a rank-2 update of a Hermitian packed matrix.
555  //
556  //     a <- alpha * x * conj(x') + conj(alpha) * y * conj(x') + a,
557  //
558  // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian
559  // matrix, supplied in packed form.
560  virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
561                          std::complex<float> alpha,
562                          const DeviceMemory<std::complex<float>> &x, int incx,
563                          const DeviceMemory<std::complex<float>> &y, int incy,
564                          DeviceMemory<std::complex<float>> *ap) = 0;
565  virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
566                          std::complex<double> alpha,
567                          const DeviceMemory<std::complex<double>> &x, int incx,
568                          const DeviceMemory<std::complex<double>> &y, int incy,
569                          DeviceMemory<std::complex<double>> *ap) = 0;
570
571  // Computes a matrix-vector product using a symmetric band matrix.
572  //
573  //     y <- alpha * a * x + beta * y,
574  //
575  // alpha and beta are scalars; a is an n-by-n symmetric band matrix, with k
576  // super-diagonals; x and y are n-element vectors; y is updated in place.
577  virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
578                          uint64 k, float alpha, const DeviceMemory<float> &a,
579                          int lda, const DeviceMemory<float> &x, int incx,
580                          float beta, DeviceMemory<float> *y, int incy) = 0;
581  virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
582                          uint64 k, double alpha, const DeviceMemory<double> &a,
583                          int lda, const DeviceMemory<double> &x, int incx,
584                          double beta, DeviceMemory<double> *y, int incy) = 0;
585
586  // Computes a matrix-vector product using a symmetric packed matrix.
587  //
588  //     y <- alpha * a * x + beta * y,
589  //
590  // alpha and beta are scalars; a is an n-by-n symmetric matrix, supplied in
591  // packed form as ap; x and y are n-element vectors; y is updated in place.
592  virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
593                          float alpha, const DeviceMemory<float> &ap,
594                          const DeviceMemory<float> &x, int incx, float beta,
595                          DeviceMemory<float> *y, int incy) = 0;
596  virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
597                          double alpha, const DeviceMemory<double> &ap,
598                          const DeviceMemory<double> &x, int incx, double beta,
599                          DeviceMemory<double> *y, int incy) = 0;
600
601  // Performs a rank-1 update of a symmetric packed matrix.
602  //
603  //     a <- alpha * x * x' + a,
604  //
605  // alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric
606  // matrix, supplied in packed form as ap.
607  virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
608                         float alpha, const DeviceMemory<float> &x, int incx,
609                         DeviceMemory<float> *ap) = 0;
610  virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
611                         double alpha, const DeviceMemory<double> &x, int incx,
612                         DeviceMemory<double> *ap) = 0;
613
614  // Performs a rank-2 update of a symmetric packed matrix.
615  //
616  //     a <- alpha * x * y' + alpha * y * x' + a,
617  //
618  // alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric
619  // matrix, supplied in packed form as ap.
620  virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
621                          float alpha, const DeviceMemory<float> &x, int incx,
622                          const DeviceMemory<float> &y, int incy,
623                          DeviceMemory<float> *ap) = 0;
624  virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
625                          double alpha, const DeviceMemory<double> &x, int incx,
626                          const DeviceMemory<double> &y, int incy,
627                          DeviceMemory<double> *ap) = 0;
628
629  // Computes a matrix-vector product for a symmetric matrix.
630  //
631  //     y <- alpha * a * x + beta * y,
632  //
633  // alpha and beta are scalars; a is an n-by-n symmetric matrix; x and y are
634  // n-element vectors; y is updated in place.
635  virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
636                          float alpha, const DeviceMemory<float> &a, int lda,
637                          const DeviceMemory<float> &x, int incx, float beta,
638                          DeviceMemory<float> *y, int incy) = 0;
639  virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
640                          double alpha, const DeviceMemory<double> &a, int lda,
641                          const DeviceMemory<double> &x, int incx, double beta,
642                          DeviceMemory<double> *y, int incy) = 0;
643
644  // Performs a rank-1 update of a symmetric matrix.
645  //
646  //     a <- alpha * x * x' + a,
647  //
648  // alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric
649  // matrix, updated in place.
650  virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
651                         float alpha, const DeviceMemory<float> &x, int incx,
652                         DeviceMemory<float> *a, int lda) = 0;
653  virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
654                         double alpha, const DeviceMemory<double> &x, int incx,
655                         DeviceMemory<double> *a, int lda) = 0;
656
657  // Performs a rank-2 update of a symmetric matrix.
658  //
659  //     a <- alpha * x * y' + alpha * y * x' + a,
660  //
661  // alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric
662  // matrix, updated in place.
663  virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
664                          float alpha, const DeviceMemory<float> &x, int incx,
665                          const DeviceMemory<float> &y, int incy,
666                          DeviceMemory<float> *a, int lda) = 0;
667  virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
668                          double alpha, const DeviceMemory<double> &x, int incx,
669                          const DeviceMemory<double> &y, int incy,
670                          DeviceMemory<double> *a, int lda) = 0;
671
672  // Computes a matrix-vector product using a triangular band matrix.
673  //
674  //     x <- a * x,
675  // or
676  //     x <- a' * x,
677  // or
678  //     x <- conj(a') * x,
679  //
680  // a is an n-by-n unit, or non-unit, upper or lower triangular band matrix,
681  // with k+1 diagonals; x is an n-element vector, overwritten with the result.
682  virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
683                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
684                          uint64 k, const DeviceMemory<float> &a, int lda,
685                          DeviceMemory<float> *x, int incx) = 0;
686  virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
687                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
688                          uint64 k, const DeviceMemory<double> &a, int lda,
689                          DeviceMemory<double> *x, int incx) = 0;
690  virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
691                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
692                          uint64 k, const DeviceMemory<std::complex<float>> &a,
693                          int lda, DeviceMemory<std::complex<float>> *x,
694                          int incx) = 0;
695  virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
696                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
697                          uint64 k, const DeviceMemory<std::complex<double>> &a,
698                          int lda, DeviceMemory<std::complex<double>> *x,
699                          int incx) = 0;
700
701  // Solves a system of linear equations whose coefficients are in a triangular
702  // band matrix as below:
703  //
704  //     a * x = b,
705  // or
706  //     a' * x = b,
707  // or
708  //     conj(a') * x = b,
709  //
710  // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
711  // lower triangular band matrix, with k+1 diagonals. b is passed in x.
712  virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
713                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
714                          uint64 k, const DeviceMemory<float> &a, int lda,
715                          DeviceMemory<float> *x, int incx) = 0;
716  virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
717                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
718                          uint64 k, const DeviceMemory<double> &a, int lda,
719                          DeviceMemory<double> *x, int incx) = 0;
720  virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
721                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
722                          uint64 k, const DeviceMemory<std::complex<float>> &a,
723                          int lda, DeviceMemory<std::complex<float>> *x,
724                          int incx) = 0;
725  virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
726                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
727                          uint64 k, const DeviceMemory<std::complex<double>> &a,
728                          int lda, DeviceMemory<std::complex<double>> *x,
729                          int incx) = 0;
730
731  // Computes a matrix-vector product using a triangular packed matrix.
732  //
733  //     x <- a * x,
734  // or
735  //     x <- a' * x,
736  // or
737  //     x <- conj(a') * x,
738  //
739  // a is an n-by-n unit, or non-unit, upper or lower triangular matrix,
740  // supplied in packed form as ap; x is an n-element vector.
741  virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
742                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
743                          const DeviceMemory<float> &ap, DeviceMemory<float> *x,
744                          int incx) = 0;
745  virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
746                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
747                          const DeviceMemory<double> &ap,
748                          DeviceMemory<double> *x, int incx) = 0;
749  virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
750                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
751                          const DeviceMemory<std::complex<float>> &ap,
752                          DeviceMemory<std::complex<float>> *x, int incx) = 0;
753  virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
754                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
755                          const DeviceMemory<std::complex<double>> &ap,
756                          DeviceMemory<std::complex<double>> *x, int incx) = 0;
757
758  // Solves a system of linear equations whose coefficients are in a triangular
759  // packed matrix as below:
760  //
761  //     a * x = b,
762  // or
763  //     a' * x = b,
764  // or
765  //     conj(a') * x = b,
766  //
767  // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
768  // lower triangular matrix, supplied in packed form as ap. b is passed in x.
769  virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
770                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
771                          const DeviceMemory<float> &ap, DeviceMemory<float> *x,
772                          int incx) = 0;
773  virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
774                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
775                          const DeviceMemory<double> &ap,
776                          DeviceMemory<double> *x, int incx) = 0;
777  virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
778                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
779                          const DeviceMemory<std::complex<float>> &ap,
780                          DeviceMemory<std::complex<float>> *x, int incx) = 0;
781  virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
782                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
783                          const DeviceMemory<std::complex<double>> &ap,
784                          DeviceMemory<std::complex<double>> *x, int incx) = 0;
785
786  // Computes a matrix-vector product using a triangular matrix.
787  //
788  //     x <- a * x,
789  // or
790  //     x <- a' * x,
791  // or
792  //     x <- conj(a') * x,
793  //
794  // a is an n-by-n unit, or non-unit, upper or lower triangular matrix;
795  // x is an n-element vector.
796  virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
797                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
798                          const DeviceMemory<float> &a, int lda,
799                          DeviceMemory<float> *x, int incx) = 0;
800  virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
801                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
802                          const DeviceMemory<double> &a, int lda,
803                          DeviceMemory<double> *x, int incx) = 0;
804  virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
805                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
806                          const DeviceMemory<std::complex<float>> &a, int lda,
807                          DeviceMemory<std::complex<float>> *x, int incx) = 0;
808  virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
809                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
810                          const DeviceMemory<std::complex<double>> &a, int lda,
811                          DeviceMemory<std::complex<double>> *x, int incx) = 0;
812
813  // Solves a system of linear equations whose coefficients are in a triangular
814  // matrix as below:
815  //
816  //     a * x = b,
817  // or
818  //     a' * x = b,
819  // or
820  //     conj(a') * x = b,
821  //
822  // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
823  // lower triangular matrix. b is passed in x.
824  virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
825                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
826                          const DeviceMemory<float> &a, int lda,
827                          DeviceMemory<float> *x, int incx) = 0;
828  virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
829                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
830                          const DeviceMemory<double> &a, int lda,
831                          DeviceMemory<double> *x, int incx) = 0;
832  virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
833                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
834                          const DeviceMemory<std::complex<float>> &a, int lda,
835                          DeviceMemory<std::complex<float>> *x, int incx) = 0;
836  virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
837                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
838                          const DeviceMemory<std::complex<double>> &a, int lda,
839                          DeviceMemory<std::complex<double>> *x, int incx) = 0;
840
841  // Computes a matrix-matrix product with general matrices:
842  //
843  //     c <- alpha * op(a) * op(b) + beta * c,
844  //
845  // op(X) is one of op(X) = X, or op(X) = X', or op(X) = conj(X'), as chosen
846  // by transa/transb; alpha and beta are scalars; a, b, and c are matrices;
847  // op(a) is an m-by-k matrix; op(b) is a k-by-n matrix; c is an m-by-n matrix.
848  virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
849                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
850                          float alpha, const DeviceMemory<float> &a, int lda,
851                          const DeviceMemory<float> &b, int ldb, float beta,
852                          DeviceMemory<float> *c, int ldc) = 0;
853  virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
854                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
855                          double alpha, const DeviceMemory<double> &a, int lda,
856                          const DeviceMemory<double> &b, int ldb, double beta,
857                          DeviceMemory<double> *c, int ldc) = 0;
858  virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
859                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
860                          std::complex<float> alpha,
861                          const DeviceMemory<std::complex<float>> &a, int lda,
862                          const DeviceMemory<std::complex<float>> &b, int ldb,
863                          std::complex<float> beta,
864                          DeviceMemory<std::complex<float>> *c, int ldc) = 0;
865  virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
866                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
867                          std::complex<double> alpha,
868                          const DeviceMemory<std::complex<double>> &a, int lda,
869                          const DeviceMemory<std::complex<double>> &b, int ldb,
870                          std::complex<double> beta,
871                          DeviceMemory<std::complex<double>> *c, int ldc) = 0;
872
873  // Computes a batch of matrix-matrix products with general matrices.
874  // This is a batched version of DoBlasGemm.
875  // The batched GEMM computes the matrix product for each input/output in a,
876  // b, and c, which each hold batch_count pointers to DeviceMemory objects.
877  virtual bool DoBlasGemmBatched(
878      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
879      uint64 n, uint64 k, float alpha,
880      const port::ArraySlice<DeviceMemory<float> *> &a, int lda,
881      const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,
882      const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
883      int batch_count) = 0;
884  virtual bool DoBlasGemmBatched(
885      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
886      uint64 n, uint64 k, double alpha,
887      const port::ArraySlice<DeviceMemory<double> *> &a, int lda,
888      const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta,
889      const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
890      int batch_count) = 0;
891  virtual bool DoBlasGemmBatched(
892      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
893      uint64 n, uint64 k, std::complex<float> alpha,
894      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
895      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
896      std::complex<float> beta,
897      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
898      int batch_count) = 0;
899  virtual bool DoBlasGemmBatched(
900      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
901      uint64 n, uint64 k, std::complex<double> alpha,
902      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
903      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
904      std::complex<double> beta,
905      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
906      int batch_count) = 0;
907
908  // Computes a matrix-matrix product where one input matrix is Hermitian:
909  //
910  //     c <- alpha * a * b + beta * c,
911  // or
912  //     c <- alpha * b * a + beta * c,
913  //
914  // alpha and beta are scalars; a is a Hermitian matrix (m-by-m in the first
915  // case, n-by-n in the second); b and c are m-by-n matrices.
916  virtual bool DoBlasHemm(Stream *stream, blas::Side side,
917                          blas::UpperLower uplo, uint64 m, uint64 n,
918                          std::complex<float> alpha,
919                          const DeviceMemory<std::complex<float>> &a, int lda,
920                          const DeviceMemory<std::complex<float>> &b, int ldb,
921                          std::complex<float> beta,
922                          DeviceMemory<std::complex<float>> *c, int ldc) = 0;
923  virtual bool DoBlasHemm(Stream *stream, blas::Side side,
924                          blas::UpperLower uplo, uint64 m, uint64 n,
925                          std::complex<double> alpha,
926                          const DeviceMemory<std::complex<double>> &a, int lda,
927                          const DeviceMemory<std::complex<double>> &b, int ldb,
928                          std::complex<double> beta,
929                          DeviceMemory<std::complex<double>> *c, int ldc) = 0;
930
931  // Performs a Hermitian rank-k update.
932  //
933  //     c <- alpha * a * conj(a') + beta * c,
934  // or
935  //     c <- alpha * conj(a') * a + beta * c,
936  //
937  // alpha and beta are real scalars; c is an n-by-n Hermitian matrix; a is an
938  // n-by-k matrix in the first case and a k-by-n matrix in the second case.
939  virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,
940                          blas::Transpose trans, uint64 n, uint64 k,
941                          float alpha,
942                          const DeviceMemory<std::complex<float>> &a, int lda,
943                          float beta, DeviceMemory<std::complex<float>> *c,
944                          int ldc) = 0;
945  virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,
946                          blas::Transpose trans, uint64 n, uint64 k,
947                          double alpha,
948                          const DeviceMemory<std::complex<double>> &a, int lda,
949                          double beta, DeviceMemory<std::complex<double>> *c,
950                          int ldc) = 0;
951
952  // Performs a Hermitian rank-2k update.
953  //
954  //     c <- alpha * a * conj(b') + conj(alpha) * b * conj(a') + beta * c,
955  // or
956  //     c <- alpha * conj(a') * b + conj(alpha) * conj(b') * a + beta * c,
957  //
958  // alpha is a scalar and beta is a real scalar; c is an n-by-n Hermitian
959  // matrix; a and b are n-by-k (first case) or k-by-n (second case) matrices.
960  virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
961                           blas::Transpose trans, uint64 n, uint64 k,
962                           std::complex<float> alpha,
963                           const DeviceMemory<std::complex<float>> &a, int lda,
964                           const DeviceMemory<std::complex<float>> &b, int ldb,
965                           float beta, DeviceMemory<std::complex<float>> *c,
966                           int ldc) = 0;
967  virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
968                           blas::Transpose trans, uint64 n, uint64 k,
969                           std::complex<double> alpha,
970                           const DeviceMemory<std::complex<double>> &a, int lda,
971                           const DeviceMemory<std::complex<double>> &b, int ldb,
972                           double beta, DeviceMemory<std::complex<double>> *c,
973                           int ldc) = 0;
974
975  // Computes a matrix-matrix product where one input matrix is symmetric.
976  //
977  //     c <- alpha * a * b + beta * c,
978  // or
979  //     c <- alpha * b * a + beta * c,
980  //
981  // alpha and beta are scalars; a is a symmetric matrix (m-by-m in the first
982  // case, n-by-n in the second); b and c are m-by-n matrices.
983  virtual bool DoBlasSymm(Stream *stream, blas::Side side,
984                          blas::UpperLower uplo, uint64 m, uint64 n,
985                          float alpha, const DeviceMemory<float> &a, int lda,
986                          const DeviceMemory<float> &b, int ldb, float beta,
987                          DeviceMemory<float> *c, int ldc) = 0;
988  virtual bool DoBlasSymm(Stream *stream, blas::Side side,
989                          blas::UpperLower uplo, uint64 m, uint64 n,
990                          double alpha, const DeviceMemory<double> &a, int lda,
991                          const DeviceMemory<double> &b, int ldb, double beta,
992                          DeviceMemory<double> *c, int ldc) = 0;
993  virtual bool DoBlasSymm(Stream *stream, blas::Side side,
994                          blas::UpperLower uplo, uint64 m, uint64 n,
995                          std::complex<float> alpha,
996                          const DeviceMemory<std::complex<float>> &a, int lda,
997                          const DeviceMemory<std::complex<float>> &b, int ldb,
998                          std::complex<float> beta,
999                          DeviceMemory<std::complex<float>> *c, int ldc) = 0;
1000  virtual bool DoBlasSymm(Stream *stream, blas::Side side,
1001                          blas::UpperLower uplo, uint64 m, uint64 n,
1002                          std::complex<double> alpha,
1003                          const DeviceMemory<std::complex<double>> &a, int lda,
1004                          const DeviceMemory<std::complex<double>> &b, int ldb,
1005                          std::complex<double> beta,
1006                          DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1007
1008  // Performs a symmetric rank-k update.
1009  //
1010  //     c <- alpha * a * a' + beta * c,
1011  // or
1012  //     c <- alpha * a' * a + beta * c,
1013  //
1014  // alpha and beta are scalars; c is an n-by-n symmetric matrix; a is an n-by-k
1015  // matrix in the first case and a k-by-n matrix in the second case.
1016  virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
1017                          blas::Transpose trans, uint64 n, uint64 k,
1018                          float alpha, const DeviceMemory<float> &a, int lda,
1019                          float beta, DeviceMemory<float> *c, int ldc) = 0;
1020  virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
1021                          blas::Transpose trans, uint64 n, uint64 k,
1022                          double alpha, const DeviceMemory<double> &a, int lda,
1023                          double beta, DeviceMemory<double> *c, int ldc) = 0;
1024  virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
1025                          blas::Transpose trans, uint64 n, uint64 k,
1026                          std::complex<float> alpha,
1027                          const DeviceMemory<std::complex<float>> &a, int lda,
1028                          std::complex<float> beta,
1029                          DeviceMemory<std::complex<float>> *c, int ldc) = 0;
1030  virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
1031                          blas::Transpose trans, uint64 n, uint64 k,
1032                          std::complex<double> alpha,
1033                          const DeviceMemory<std::complex<double>> &a, int lda,
1034                          std::complex<double> beta,
1035                          DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1036
1037  // Performs a symmetric rank-2k update.
1038  //
1039  //     c <- alpha * a * b' + alpha * b * a' + beta * c,
1040  // or
1041  //     c <- alpha * b' * a + alpha * a' * b + beta * c,
1042  //
1043  // alpha and beta are scalars; c is an n-by-n symmetric matrix; a and b are
1044  // n-by-k matrices in the first case and k-by-n matrices in the second case.
1045  virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
1046                           blas::Transpose trans, uint64 n, uint64 k,
1047                           float alpha, const DeviceMemory<float> &a, int lda,
1048                           const DeviceMemory<float> &b, int ldb, float beta,
1049                           DeviceMemory<float> *c, int ldc) = 0;
1050  virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
1051                           blas::Transpose trans, uint64 n, uint64 k,
1052                           double alpha, const DeviceMemory<double> &a, int lda,
1053                           const DeviceMemory<double> &b, int ldb, double beta,
1054                           DeviceMemory<double> *c, int ldc) = 0;
1055  virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
1056                           blas::Transpose trans, uint64 n, uint64 k,
1057                           std::complex<float> alpha,
1058                           const DeviceMemory<std::complex<float>> &a, int lda,
1059                           const DeviceMemory<std::complex<float>> &b, int ldb,
1060                           std::complex<float> beta,
1061                           DeviceMemory<std::complex<float>> *c, int ldc) = 0;
1062  virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
1063                           blas::Transpose trans, uint64 n, uint64 k,
1064                           std::complex<double> alpha,
1065                           const DeviceMemory<std::complex<double>> &a, int lda,
1066                           const DeviceMemory<std::complex<double>> &b, int ldb,
1067                           std::complex<double> beta,
1068                           DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1069
1070  // Computes a matrix-matrix product where one input matrix is triangular.
1071  //
1072  //     b <- alpha * op(a) * b,
1073  // or
1074  //     b <- alpha * b * op(a),
1075  //
1076  // alpha is a scalar; b is an m-by-n matrix; a is a unit, or non-unit, upper
1077  // or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a', or
1078  // op(a) = conj(a'), as chosen by transa.
1079  virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
1080                          blas::UpperLower uplo, blas::Transpose transa,
1081                          blas::Diagonal diag, uint64 m, uint64 n, float alpha,
1082                          const DeviceMemory<float> &a, int lda,
1083                          DeviceMemory<float> *b, int ldb) = 0;
1084  virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
1085                          blas::UpperLower uplo, blas::Transpose transa,
1086                          blas::Diagonal diag, uint64 m, uint64 n, double alpha,
1087                          const DeviceMemory<double> &a, int lda,
1088                          DeviceMemory<double> *b, int ldb) = 0;
1089  virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
1090                          blas::UpperLower uplo, blas::Transpose transa,
1091                          blas::Diagonal diag, uint64 m, uint64 n,
1092                          std::complex<float> alpha,
1093                          const DeviceMemory<std::complex<float>> &a, int lda,
1094                          DeviceMemory<std::complex<float>> *b, int ldb) = 0;
1095  virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
1096                          blas::UpperLower uplo, blas::Transpose transa,
1097                          blas::Diagonal diag, uint64 m, uint64 n,
1098                          std::complex<double> alpha,
1099                          const DeviceMemory<std::complex<double>> &a, int lda,
1100                          DeviceMemory<std::complex<double>> *b, int ldb) = 0;
1101
1102  // Solves a triangular matrix equation.
1103  //
1104  //     op(a) * x = alpha * b,
1105  // or
1106  //     x * op(a) = alpha * b,
1107  //
1108  // alpha is a scalar; x and b are m-by-n matrices; a is a unit, or non-unit,
1109  // upper or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a',
1110  // or op(a) = conj(a'). There is no separate x argument: x is returned in b.
1111  virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
1112                          blas::UpperLower uplo, blas::Transpose transa,
1113                          blas::Diagonal diag, uint64 m, uint64 n, float alpha,
1114                          const DeviceMemory<float> &a, int lda,
1115                          DeviceMemory<float> *b, int ldb) = 0;
1116  virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
1117                          blas::UpperLower uplo, blas::Transpose transa,
1118                          blas::Diagonal diag, uint64 m, uint64 n, double alpha,
1119                          const DeviceMemory<double> &a, int lda,
1120                          DeviceMemory<double> *b, int ldb) = 0;
1121  virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
1122                          blas::UpperLower uplo, blas::Transpose transa,
1123                          blas::Diagonal diag, uint64 m, uint64 n,
1124                          std::complex<float> alpha,
1125                          const DeviceMemory<std::complex<float>> &a, int lda,
1126                          DeviceMemory<std::complex<float>> *b, int ldb) = 0;
1127  virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
1128                          blas::UpperLower uplo, blas::Transpose transa,
1129                          blas::Diagonal diag, uint64 m, uint64 n,
1130                          std::complex<double> alpha,
1131                          const DeviceMemory<std::complex<double>> &a, int lda,
1132                          DeviceMemory<std::complex<double>> *b, int ldb) = 0;
1133
1134 protected:
1135  BlasSupport() {}
1136
1137 private:
1138  SE_DISALLOW_COPY_AND_ASSIGN(BlasSupport);
1139};
1140
// Macro that expands to `override` declarations for the pure-virtual BLAS
// routines of the BlasSupport base class; expand it inside the body of a class
// that derives from BlasSupport.
1143#define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES                  \
1144  bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1145                  const DeviceMemory<float> &x, int incx,                      \
1146                  DeviceMemory<float> *result) override;                       \
1147  bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1148                  const DeviceMemory<double> &x, int incx,                     \
1149                  DeviceMemory<double> *result) override;                      \
1150  bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1151                  const DeviceMemory<std::complex<float>> &x, int incx,        \
1152                  DeviceMemory<float> *result) override;                       \
1153  bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1154                  const DeviceMemory<std::complex<double>> &x, int incx,       \
1155                  DeviceMemory<double> *result) override;                      \
1156  bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,              \
1157                  const DeviceMemory<float> &x, int incx,                      \
1158                  DeviceMemory<float> *y, int incy) override;                  \
1159  bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,             \
1160                  const DeviceMemory<double> &x, int incx,                     \
1161                  DeviceMemory<double> *y, int incy) override;                 \
1162  bool DoBlasAxpy(Stream *stream, uint64 elem_count,                           \
1163                  std::complex<float> alpha,                                   \
1164                  const DeviceMemory<std::complex<float>> &x, int incx,        \
1165                  DeviceMemory<std::complex<float>> *y, int incy) override;    \
1166  bool DoBlasAxpy(Stream *stream, uint64 elem_count,                           \
1167                  std::complex<double> alpha,                                  \
1168                  const DeviceMemory<std::complex<double>> &x, int incx,       \
1169                  DeviceMemory<std::complex<double>> *y, int incy) override;   \
1170  bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1171                  const DeviceMemory<float> &x, int incx,                      \
1172                  DeviceMemory<float> *y, int incy) override;                  \
1173  bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1174                  const DeviceMemory<double> &x, int incx,                     \
1175                  DeviceMemory<double> *y, int incy) override;                 \
1176  bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1177                  const DeviceMemory<std::complex<float>> &x, int incx,        \
1178                  DeviceMemory<std::complex<float>> *y, int incy) override;    \
1179  bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1180                  const DeviceMemory<std::complex<double>> &x, int incx,       \
1181                  DeviceMemory<std::complex<double>> *y, int incy) override;   \
1182  bool DoBlasDot(Stream *stream, uint64 elem_count,                            \
1183                 const DeviceMemory<float> &x, int incx,                       \
1184                 const DeviceMemory<float> &y, int incy,                       \
1185                 DeviceMemory<float> *result) override;                        \
1186  bool DoBlasDot(Stream *stream, uint64 elem_count,                            \
1187                 const DeviceMemory<double> &x, int incx,                      \
1188                 const DeviceMemory<double> &y, int incy,                      \
1189                 DeviceMemory<double> *result) override;                       \
1190  bool DoBlasDotc(Stream *stream, uint64 elem_count,                           \
1191                  const DeviceMemory<std::complex<float>> &x, int incx,        \
1192                  const DeviceMemory<std::complex<float>> &y, int incy,        \
1193                  DeviceMemory<std::complex<float>> *result) override;         \
1194  bool DoBlasDotc(Stream *stream, uint64 elem_count,                           \
1195                  const DeviceMemory<std::complex<double>> &x, int incx,       \
1196                  const DeviceMemory<std::complex<double>> &y, int incy,       \
1197                  DeviceMemory<std::complex<double>> *result) override;        \
1198  bool DoBlasDotu(Stream *stream, uint64 elem_count,                           \
1199                  const DeviceMemory<std::complex<float>> &x, int incx,        \
1200                  const DeviceMemory<std::complex<float>> &y, int incy,        \
1201                  DeviceMemory<std::complex<float>> *result) override;         \
1202  bool DoBlasDotu(Stream *stream, uint64 elem_count,                           \
1203                  const DeviceMemory<std::complex<double>> &x, int incx,       \
1204                  const DeviceMemory<std::complex<double>> &y, int incy,       \
1205                  DeviceMemory<std::complex<double>> *result) override;        \
1206  bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1207                  const DeviceMemory<float> &x, int incx,                      \
1208                  DeviceMemory<float> *result) override;                       \
1209  bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1210                  const DeviceMemory<double> &x, int incx,                     \
1211                  DeviceMemory<double> *result) override;                      \
1212  bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1213                  const DeviceMemory<std::complex<float>> &x, int incx,        \
1214                  DeviceMemory<float> *result) override;                       \
1215  bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1216                  const DeviceMemory<std::complex<double>> &x, int incx,       \
1217                  DeviceMemory<double> *result) override;                      \
1218  bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<float> *x,    \
1219                 int incx, DeviceMemory<float> *y, int incy, float c, float s) \
1220      override;                                                                \
1221  bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<double> *x,   \
1222                 int incx, DeviceMemory<double> *y, int incy, double c,        \
1223                 double s) override;                                           \
1224  bool DoBlasRot(Stream *stream, uint64 elem_count,                            \
1225                 DeviceMemory<std::complex<float>> *x, int incx,               \
1226                 DeviceMemory<std::complex<float>> *y, int incy, float c,      \
1227                 float s) override;                                            \
1228  bool DoBlasRot(Stream *stream, uint64 elem_count,                            \
1229                 DeviceMemory<std::complex<double>> *x, int incx,              \
1230                 DeviceMemory<std::complex<double>> *y, int incy, double c,    \
1231                 double s) override;                                           \
1232  bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a,                      \
1233                  DeviceMemory<float> *b, DeviceMemory<float> *c,              \
1234                  DeviceMemory<float> *s) override;                            \
1235  bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a,                     \
1236                  DeviceMemory<double> *b, DeviceMemory<double> *c,            \
1237                  DeviceMemory<double> *s) override;                           \
1238  bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,        \
1239                  DeviceMemory<std::complex<float>> *b,                        \
1240                  DeviceMemory<float> *c,                                      \
1241                  DeviceMemory<std::complex<float>> *s) override;              \
1242  bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,       \
1243                  DeviceMemory<std::complex<double>> *b,                       \
1244                  DeviceMemory<double> *c,                                     \
1245                  DeviceMemory<std::complex<double>> *s) override;             \
1246  bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<float> *x,   \
1247                  int incx, DeviceMemory<float> *y, int incy,                  \
1248                  const DeviceMemory<float> &param) override;                  \
1249  bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<double> *x,  \
1250                  int incx, DeviceMemory<double> *y, int incy,                 \
1251                  const DeviceMemory<double> &param) override;                 \
1252  bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,                    \
1253                   DeviceMemory<float> *d2, DeviceMemory<float> *x1,           \
1254                   const DeviceMemory<float> &y1, DeviceMemory<float> *param)  \
1255      override;                                                                \
1256  bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,                   \
1257                   DeviceMemory<double> *d2, DeviceMemory<double> *x1,         \
1258                   const DeviceMemory<double> &y1,                             \
1259                   DeviceMemory<double> *param) override;                      \
1260  bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,              \
1261                  DeviceMemory<float> *x, int incx) override;                  \
1262  bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,             \
1263                  DeviceMemory<double> *x, int incx) override;                 \
1264  bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,              \
1265                  DeviceMemory<std::complex<float>> *x, int incx) override;    \
1266  bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,             \
1267                  DeviceMemory<std::complex<double>> *x, int incx) override;   \
1268  bool DoBlasScal(Stream *stream, uint64 elem_count,                           \
1269                  std::complex<float> alpha,                                   \
1270                  DeviceMemory<std::complex<float>> *x, int incx) override;    \
1271  bool DoBlasScal(Stream *stream, uint64 elem_count,                           \
1272                  std::complex<double> alpha,                                  \
1273                  DeviceMemory<std::complex<double>> *x, int incx) override;   \
1274  bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<float> *x,   \
1275                  int incx, DeviceMemory<float> *y, int incy) override;        \
1276  bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<double> *x,  \
1277                  int incx, DeviceMemory<double> *y, int incy) override;       \
1278  bool DoBlasSwap(Stream *stream, uint64 elem_count,                           \
1279                  DeviceMemory<std::complex<float>> *x, int incx,              \
1280                  DeviceMemory<std::complex<float>> *y, int incy) override;    \
1281  bool DoBlasSwap(Stream *stream, uint64 elem_count,                           \
1282                  DeviceMemory<std::complex<double>> *x, int incx,             \
1283                  DeviceMemory<std::complex<double>> *y, int incy) override;   \
1284  bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
1285                   const DeviceMemory<float> &x, int incx,                     \
1286                   DeviceMemory<int> *result) override;                        \
1287  bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
1288                   const DeviceMemory<double> &x, int incx,                    \
1289                   DeviceMemory<int> *result) override;                        \
1290  bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
1291                   const DeviceMemory<std::complex<float>> &x, int incx,       \
1292                   DeviceMemory<int> *result) override;                        \
1293  bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
1294                   const DeviceMemory<std::complex<double>> &x, int incx,      \
1295                   DeviceMemory<int> *result) override;                        \
1296  bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
1297                   const DeviceMemory<float> &x, int incx,                     \
1298                   DeviceMemory<int> *result) override;                        \
1299  bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
1300                   const DeviceMemory<double> &x, int incx,                    \
1301                   DeviceMemory<int> *result) override;                        \
1302  bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
1303                   const DeviceMemory<std::complex<float>> &x, int incx,       \
1304                   DeviceMemory<int> *result) override;                        \
1305  bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
1306                   const DeviceMemory<std::complex<double>> &x, int incx,      \
1307                   DeviceMemory<int> *result) override;                        \
1308  bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1309                  uint64 kl, uint64 ku, float alpha,                           \
1310                  const DeviceMemory<float> &a, int lda,                       \
1311                  const DeviceMemory<float> &x, int incx, float beta,          \
1312                  DeviceMemory<float> *y, int incy) override;                  \
1313  bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1314                  uint64 kl, uint64 ku, double alpha,                          \
1315                  const DeviceMemory<double> &a, int lda,                      \
1316                  const DeviceMemory<double> &x, int incx, double beta,        \
1317                  DeviceMemory<double> *y, int incy) override;                 \
1318  bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1319                  uint64 kl, uint64 ku, std::complex<float> alpha,             \
1320                  const DeviceMemory<std::complex<float>> &a, int lda,         \
1321                  const DeviceMemory<std::complex<float>> &x, int incx,        \
1322                  std::complex<float> beta,                                    \
1323                  DeviceMemory<std::complex<float>> *y, int incy) override;    \
1324  bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1325                  uint64 kl, uint64 ku, std::complex<double> alpha,            \
1326                  const DeviceMemory<std::complex<double>> &a, int lda,        \
1327                  const DeviceMemory<std::complex<double>> &x, int incx,       \
1328                  std::complex<double> beta,                                   \
1329                  DeviceMemory<std::complex<double>> *y, int incy) override;   \
1330  bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1331                  float alpha, const DeviceMemory<float> &a, int lda,          \
1332                  const DeviceMemory<float> &x, int incx, float beta,          \
1333                  DeviceMemory<float> *y, int incy) override;                  \
1334  bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1335                  double alpha, const DeviceMemory<double> &a, int lda,        \
1336                  const DeviceMemory<double> &x, int incx, double beta,        \
1337                  DeviceMemory<double> *y, int incy) override;                 \
1338  bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1339                  std::complex<float> alpha,                                   \
1340                  const DeviceMemory<std::complex<float>> &a, int lda,         \
1341                  const DeviceMemory<std::complex<float>> &x, int incx,        \
1342                  std::complex<float> beta,                                    \
1343                  DeviceMemory<std::complex<float>> *y, int incy) override;    \
1344  bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1345                  std::complex<double> alpha,                                  \
1346                  const DeviceMemory<std::complex<double>> &a, int lda,        \
1347                  const DeviceMemory<std::complex<double>> &x, int incx,       \
1348                  std::complex<double> beta,                                   \
1349                  DeviceMemory<std::complex<double>> *y, int incy) override;   \
1350  bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,              \
1351                 const DeviceMemory<float> &x, int incx,                       \
1352                 const DeviceMemory<float> &y, int incy,                       \
1353                 DeviceMemory<float> *a, int lda) override;                    \
1354  bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,             \
1355                 const DeviceMemory<double> &x, int incx,                      \
1356                 const DeviceMemory<double> &y, int incy,                      \
1357                 DeviceMemory<double> *a, int lda) override;                   \
1358  bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,                          \
1359                  std::complex<float> alpha,                                   \
1360                  const DeviceMemory<std::complex<float>> &x, int incx,        \
1361                  const DeviceMemory<std::complex<float>> &y, int incy,        \
1362                  DeviceMemory<std::complex<float>> *a, int lda) override;     \
1363  bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,                          \
1364                  std::complex<double> alpha,                                  \
1365                  const DeviceMemory<std::complex<double>> &x, int incx,       \
1366                  const DeviceMemory<std::complex<double>> &y, int incy,       \
1367                  DeviceMemory<std::complex<double>> *a, int lda) override;    \
1368  bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,                          \
1369                  std::complex<float> alpha,                                   \
1370                  const DeviceMemory<std::complex<float>> &x, int incx,        \
1371                  const DeviceMemory<std::complex<float>> &y, int incy,        \
1372                  DeviceMemory<std::complex<float>> *a, int lda) override;     \
1373  bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,                          \
1374                  std::complex<double> alpha,                                  \
1375                  const DeviceMemory<std::complex<double>> &x, int incx,       \
1376                  const DeviceMemory<std::complex<double>> &y, int incy,       \
1377                  DeviceMemory<std::complex<double>> *a, int lda) override;    \
1378  bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
1379                  std::complex<float> alpha,                                   \
1380                  const DeviceMemory<std::complex<float>> &a, int lda,         \
1381                  const DeviceMemory<std::complex<float>> &x, int incx,        \
1382                  std::complex<float> beta,                                    \
1383                  DeviceMemory<std::complex<float>> *y, int incy) override;    \
1384  bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
1385                  std::complex<double> alpha,                                  \
1386                  const DeviceMemory<std::complex<double>> &a, int lda,        \
1387                  const DeviceMemory<std::complex<double>> &x, int incx,       \
1388                  std::complex<double> beta,                                   \
1389                  DeviceMemory<std::complex<double>> *y, int incy) override;   \
1390  bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1391                  std::complex<float> alpha,                                   \
1392                  const DeviceMemory<std::complex<float>> &a, int lda,         \
1393                  const DeviceMemory<std::complex<float>> &x, int incx,        \
1394                  std::complex<float> beta,                                    \
1395                  DeviceMemory<std::complex<float>> *y, int incy) override;    \
1396  bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1397                  std::complex<double> alpha,                                  \
1398                  const DeviceMemory<std::complex<double>> &a, int lda,        \
1399                  const DeviceMemory<std::complex<double>> &x, int incx,       \
1400                  std::complex<double> beta,                                   \
1401                  DeviceMemory<std::complex<double>> *y, int incy) override;   \
1402  bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
1403                 const DeviceMemory<std::complex<float>> &x, int incx,         \
1404                 DeviceMemory<std::complex<float>> *a, int lda) override;      \
1405  bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,              \
1406                 double alpha, const DeviceMemory<std::complex<double>> &x,    \
1407                 int incx, DeviceMemory<std::complex<double>> *a, int lda)     \
1408      override;                                                                \
1409  bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1410                  std::complex<float> alpha,                                   \
1411                  const DeviceMemory<std::complex<float>> &x, int incx,        \
1412                  const DeviceMemory<std::complex<float>> &y, int incy,        \
1413                  DeviceMemory<std::complex<float>> *a, int lda) override;     \
1414  bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1415                  std::complex<double> alpha,                                  \
1416                  const DeviceMemory<std::complex<double>> &x, int incx,       \
1417                  const DeviceMemory<std::complex<double>> &y, int incy,       \
1418                  DeviceMemory<std::complex<double>> *a, int lda) override;    \
1419  bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1420                  std::complex<float> alpha,                                   \
1421                  const DeviceMemory<std::complex<float>> &ap,                 \
1422                  const DeviceMemory<std::complex<float>> &x, int incx,        \
1423                  std::complex<float> beta,                                    \
1424                  DeviceMemory<std::complex<float>> *y, int incy) override;    \
1425  bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1426                  std::complex<double> alpha,                                  \
1427                  const DeviceMemory<std::complex<double>> &ap,                \
1428                  const DeviceMemory<std::complex<double>> &x, int incx,       \
1429                  std::complex<double> beta,                                   \
1430                  DeviceMemory<std::complex<double>> *y, int incy) override;   \
1431  bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
1432                 const DeviceMemory<std::complex<float>> &x, int incx,         \
1433                 DeviceMemory<std::complex<float>> *ap) override;              \
1434  bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,              \
1435                 double alpha, const DeviceMemory<std::complex<double>> &x,    \
1436                 int incx, DeviceMemory<std::complex<double>> *ap) override;   \
1437  bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1438                  std::complex<float> alpha,                                   \
1439                  const DeviceMemory<std::complex<float>> &x, int incx,        \
1440                  const DeviceMemory<std::complex<float>> &y, int incy,        \
1441                  DeviceMemory<std::complex<float>> *ap) override;             \
1442  bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1443                  std::complex<double> alpha,                                  \
1444                  const DeviceMemory<std::complex<double>> &x, int incx,       \
1445                  const DeviceMemory<std::complex<double>> &y, int incy,       \
1446                  DeviceMemory<std::complex<double>> *ap) override;            \
1447  bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
1448                  float alpha, const DeviceMemory<float> &a, int lda,          \
1449                  const DeviceMemory<float> &x, int incx, float beta,          \
1450                  DeviceMemory<float> *y, int incy) override;                  \
1451  bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
1452                  double alpha, const DeviceMemory<double> &a, int lda,        \
1453                  const DeviceMemory<double> &x, int incx, double beta,        \
1454                  DeviceMemory<double> *y, int incy) override;                 \
1455  bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1456                  float alpha, const DeviceMemory<float> &ap,                  \
1457                  const DeviceMemory<float> &x, int incx, float beta,          \
1458                  DeviceMemory<float> *y, int incy) override;                  \
1459  bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1460                  double alpha, const DeviceMemory<double> &ap,                \
1461                  const DeviceMemory<double> &x, int incx, double beta,        \
1462                  DeviceMemory<double> *y, int incy) override;                 \
1463  bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
1464                 const DeviceMemory<float> &x, int incx,                       \
1465                 DeviceMemory<float> *ap) override;                            \
1466  bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,              \
1467                 double alpha, const DeviceMemory<double> &x, int incx,        \
1468                 DeviceMemory<double> *ap) override;                           \
1469  bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1470                  float alpha, const DeviceMemory<float> &x, int incx,         \
1471                  const DeviceMemory<float> &y, int incy,                      \
1472                  DeviceMemory<float> *ap) override;                           \
1473  bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1474                  double alpha, const DeviceMemory<double> &x, int incx,       \
1475                  const DeviceMemory<double> &y, int incy,                     \
1476                  DeviceMemory<double> *ap) override;                          \
1477  bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1478                  float alpha, const DeviceMemory<float> &a, int lda,          \
1479                  const DeviceMemory<float> &x, int incx, float beta,          \
1480                  DeviceMemory<float> *y, int incy) override;                  \
1481  bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1482                  double alpha, const DeviceMemory<double> &a, int lda,        \
1483                  const DeviceMemory<double> &x, int incx, double beta,        \
1484                  DeviceMemory<double> *y, int incy) override;                 \
1485  bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
1486                 const DeviceMemory<float> &x, int incx,                       \
1487                 DeviceMemory<float> *a, int lda) override;                    \
1488  bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,              \
1489                 double alpha, const DeviceMemory<double> &x, int incx,        \
1490                 DeviceMemory<double> *a, int lda) override;                   \
1491  bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1492                  float alpha, const DeviceMemory<float> &x, int incx,         \
1493                  const DeviceMemory<float> &y, int incy,                      \
1494                  DeviceMemory<float> *a, int lda) override;                   \
1495  bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1496                  double alpha, const DeviceMemory<double> &x, int incx,       \
1497                  const DeviceMemory<double> &y, int incy,                     \
1498                  DeviceMemory<double> *a, int lda) override;                  \
1499  bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
1500                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1501                  uint64 k, const DeviceMemory<float> &a, int lda,             \
1502                  DeviceMemory<float> *x, int incx) override;                  \
1503  bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
1504                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1505                  uint64 k, const DeviceMemory<double> &a, int lda,            \
1506                  DeviceMemory<double> *x, int incx) override;                 \
1507  bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
1508                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1509                  uint64 k, const DeviceMemory<std::complex<float>> &a,        \
1510                  int lda, DeviceMemory<std::complex<float>> *x, int incx)     \
1511      override;                                                                \
1512  bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
1513                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1514                  uint64 k, const DeviceMemory<std::complex<double>> &a,       \
1515                  int lda, DeviceMemory<std::complex<double>> *x, int incx)    \
1516      override;                                                                \
1517  bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
1518                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1519                  uint64 k, const DeviceMemory<float> &a, int lda,             \
1520                  DeviceMemory<float> *x, int incx) override;                  \
1521  bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
1522                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1523                  uint64 k, const DeviceMemory<double> &a, int lda,            \
1524                  DeviceMemory<double> *x, int incx) override;                 \
1525  bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
1526                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1527                  uint64 k, const DeviceMemory<std::complex<float>> &a,        \
1528                  int lda, DeviceMemory<std::complex<float>> *x, int incx)     \
1529      override;                                                                \
1530  bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
1531                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1532                  uint64 k, const DeviceMemory<std::complex<double>> &a,       \
1533                  int lda, DeviceMemory<std::complex<double>> *x, int incx)    \
1534      override;                                                                \
1535  bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
1536                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1537                  const DeviceMemory<float> &ap, DeviceMemory<float> *x,       \
1538                  int incx) override;                                          \
1539  bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
1540                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1541                  const DeviceMemory<double> &ap, DeviceMemory<double> *x,     \
1542                  int incx) override;                                          \
1543  bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
1544                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1545                  const DeviceMemory<std::complex<float>> &ap,                 \
1546                  DeviceMemory<std::complex<float>> *x, int incx) override;    \
1547  bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
1548                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1549                  const DeviceMemory<std::complex<double>> &ap,                \
1550                  DeviceMemory<std::complex<double>> *x, int incx) override;   \
1551  bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
1552                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1553                  const DeviceMemory<float> &ap, DeviceMemory<float> *x,       \
1554                  int incx) override;                                          \
1555  bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
1556                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1557                  const DeviceMemory<double> &ap, DeviceMemory<double> *x,     \
1558                  int incx) override;                                          \
1559  bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
1560                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1561                  const DeviceMemory<std::complex<float>> &ap,                 \
1562                  DeviceMemory<std::complex<float>> *x, int incx) override;    \
1563  bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
1564                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1565                  const DeviceMemory<std::complex<double>> &ap,                \
1566                  DeviceMemory<std::complex<double>> *x, int incx) override;   \
1567  bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
1568                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1569                  const DeviceMemory<float> &a, int lda,                       \
1570                  DeviceMemory<float> *x, int incx) override;                  \
1571  bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
1572                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1573                  const DeviceMemory<double> &a, int lda,                      \
1574                  DeviceMemory<double> *x, int incx) override;                 \
1575  bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
1576                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1577                  const DeviceMemory<std::complex<float>> &a, int lda,         \
1578                  DeviceMemory<std::complex<float>> *x, int incx) override;    \
1579  bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
1580                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1581                  const DeviceMemory<std::complex<double>> &a, int lda,        \
1582                  DeviceMemory<std::complex<double>> *x, int incx) override;   \
1583  bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
1584                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1585                  const DeviceMemory<float> &a, int lda,                       \
1586                  DeviceMemory<float> *x, int incx) override;                  \
1587  bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
1588                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1589                  const DeviceMemory<double> &a, int lda,                      \
1590                  DeviceMemory<double> *x, int incx) override;                 \
1591  bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
1592                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1593                  const DeviceMemory<std::complex<float>> &a, int lda,         \
1594                  DeviceMemory<std::complex<float>> *x, int incx) override;    \
1595  bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
1596                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1597                  const DeviceMemory<std::complex<double>> &a, int lda,        \
1598                  DeviceMemory<std::complex<double>> *x, int incx) override;   \
1599  bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
1600                  blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
1601                  float alpha, const DeviceMemory<float> &a, int lda,          \
1602                  const DeviceMemory<float> &b, int ldb, float beta,           \
1603                  DeviceMemory<float> *c, int ldc) override;                   \
1604  bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
1605                  blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
1606                  double alpha, const DeviceMemory<double> &a, int lda,        \
1607                  const DeviceMemory<double> &b, int ldb, double beta,         \
1608                  DeviceMemory<double> *c, int ldc) override;                  \
1609  bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
1610                  blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
1611                  std::complex<float> alpha,                                   \
1612                  const DeviceMemory<std::complex<float>> &a, int lda,         \
1613                  const DeviceMemory<std::complex<float>> &b, int ldb,         \
1614                  std::complex<float> beta,                                    \
1615                  DeviceMemory<std::complex<float>> *c, int ldc) override;     \
1616  bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
1617                  blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
1618                  std::complex<double> alpha,                                  \
1619                  const DeviceMemory<std::complex<double>> &a, int lda,        \
1620                  const DeviceMemory<std::complex<double>> &b, int ldb,        \
1621                  std::complex<double> beta,                                   \
1622                  DeviceMemory<std::complex<double>> *c, int ldc) override;    \
1623  bool DoBlasGemmBatched(                                                      \
1624      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1625      uint64 m, uint64 n, uint64 k, float alpha,                               \
1626      const port::ArraySlice<DeviceMemory<float> *> &a, int lda,               \
1627      const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,   \
1628      const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,               \
1629      int batch_count) override;                                               \
1630  bool DoBlasGemmBatched(                                                      \
1631      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1632      uint64 m, uint64 n, uint64 k, double alpha,                              \
1633      const port::ArraySlice<DeviceMemory<double> *> &a, int lda,              \
1634      const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta, \
1635      const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,              \
1636      int batch_count) override;                                               \
1637  bool DoBlasGemmBatched(                                                      \
1638      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1639      uint64 m, uint64 n, uint64 k, std::complex<float> alpha,                 \
1640      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda, \
1641      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, \
1642      std::complex<float> beta,                                                \
1643      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, \
1644      int batch_count) override;                                               \
1645  bool DoBlasGemmBatched(                                                      \
1646      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1647      uint64 m, uint64 n, uint64 k, std::complex<double> alpha,                \
1648      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a,         \
1649      int lda,                                                                 \
1650      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b,         \
1651      int ldb, std::complex<double> beta,                                      \
1652      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c,         \
1653      int ldc, int batch_count) override;                                      \
1654  bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
1655                  uint64 m, uint64 n, std::complex<float> alpha,               \
1656                  const DeviceMemory<std::complex<float>> &a, int lda,         \
1657                  const DeviceMemory<std::complex<float>> &b, int ldb,         \
1658                  std::complex<float> beta,                                    \
1659                  DeviceMemory<std::complex<float>> *c, int ldc) override;     \
1660  bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
1661                  uint64 m, uint64 n, std::complex<double> alpha,              \
1662                  const DeviceMemory<std::complex<double>> &a, int lda,        \
1663                  const DeviceMemory<std::complex<double>> &b, int ldb,        \
1664                  std::complex<double> beta,                                   \
1665                  DeviceMemory<std::complex<double>> *c, int ldc) override;    \
1666  bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,                       \
1667                  blas::Transpose trans, uint64 n, uint64 k, float alpha,      \
1668                  const DeviceMemory<std::complex<float>> &a, int lda,         \
1669                  float beta, DeviceMemory<std::complex<float>> *c, int ldc)   \
1670      override;                                                                \
1671  bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,                       \
1672                  blas::Transpose trans, uint64 n, uint64 k, double alpha,     \
1673                  const DeviceMemory<std::complex<double>> &a, int lda,        \
1674                  double beta, DeviceMemory<std::complex<double>> *c, int ldc) \
1675      override;                                                                \
1676  bool DoBlasHer2k(                                                            \
1677      Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n,  \
1678      uint64 k, std::complex<float> alpha,                                     \
1679      const DeviceMemory<std::complex<float>> &a, int lda,                     \
1680      const DeviceMemory<std::complex<float>> &b, int ldb, float beta,         \
1681      DeviceMemory<std::complex<float>> *c, int ldc) override;                 \
1682  bool DoBlasHer2k(                                                            \
1683      Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n,  \
1684      uint64 k, std::complex<double> alpha,                                    \
1685      const DeviceMemory<std::complex<double>> &a, int lda,                    \
1686      const DeviceMemory<std::complex<double>> &b, int ldb, double beta,       \
1687      DeviceMemory<std::complex<double>> *c, int ldc) override;                \
1688  bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
1689                  uint64 m, uint64 n, float alpha,                             \
1690                  const DeviceMemory<float> &a, int lda,                       \
1691                  const DeviceMemory<float> &b, int ldb, float beta,           \
1692                  DeviceMemory<float> *c, int ldc) override;                   \
1693  bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
1694                  uint64 m, uint64 n, double alpha,                            \
1695                  const DeviceMemory<double> &a, int lda,                      \
1696                  const DeviceMemory<double> &b, int ldb, double beta,         \
1697                  DeviceMemory<double> *c, int ldc) override;                  \
1698  bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
1699                  uint64 m, uint64 n, std::complex<float> alpha,               \
1700                  const DeviceMemory<std::complex<float>> &a, int lda,         \
1701                  const DeviceMemory<std::complex<float>> &b, int ldb,         \
1702                  std::complex<float> beta,                                    \
1703                  DeviceMemory<std::complex<float>> *c, int ldc) override;     \
1704  bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
1705                  uint64 m, uint64 n, std::complex<double> alpha,              \
1706                  const DeviceMemory<std::complex<double>> &a, int lda,        \
1707                  const DeviceMemory<std::complex<double>> &b, int ldb,        \
1708                  std::complex<double> beta,                                   \
1709                  DeviceMemory<std::complex<double>> *c, int ldc) override;    \
1710  bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
1711                  blas::Transpose trans, uint64 n, uint64 k, float alpha,      \
1712                  const DeviceMemory<float> &a, int lda, float beta,           \
1713                  DeviceMemory<float> *c, int ldc) override;                   \
1714  bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
1715                  blas::Transpose trans, uint64 n, uint64 k, double alpha,     \
1716                  const DeviceMemory<double> &a, int lda, double beta,         \
1717                  DeviceMemory<double> *c, int ldc) override;                  \
1718  bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
1719                  blas::Transpose trans, uint64 n, uint64 k,                   \
1720                  std::complex<float> alpha,                                   \
1721                  const DeviceMemory<std::complex<float>> &a, int lda,         \
1722                  std::complex<float> beta,                                    \
1723                  DeviceMemory<std::complex<float>> *c, int ldc) override;     \
1724  bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
1725                  blas::Transpose trans, uint64 n, uint64 k,                   \
1726                  std::complex<double> alpha,                                  \
1727                  const DeviceMemory<std::complex<double>> &a, int lda,        \
1728                  std::complex<double> beta,                                   \
1729                  DeviceMemory<std::complex<double>> *c, int ldc) override;    \
1730  bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
1731                   blas::Transpose trans, uint64 n, uint64 k, float alpha,     \
1732                   const DeviceMemory<float> &a, int lda,                      \
1733                   const DeviceMemory<float> &b, int ldb, float beta,          \
1734                   DeviceMemory<float> *c, int ldc) override;                  \
1735  bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
1736                   blas::Transpose trans, uint64 n, uint64 k, double alpha,    \
1737                   const DeviceMemory<double> &a, int lda,                     \
1738                   const DeviceMemory<double> &b, int ldb, double beta,        \
1739                   DeviceMemory<double> *c, int ldc) override;                 \
1740  bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
1741                   blas::Transpose trans, uint64 n, uint64 k,                  \
1742                   std::complex<float> alpha,                                  \
1743                   const DeviceMemory<std::complex<float>> &a, int lda,        \
1744                   const DeviceMemory<std::complex<float>> &b, int ldb,        \
1745                   std::complex<float> beta,                                   \
1746                   DeviceMemory<std::complex<float>> *c, int ldc) override;    \
1747  bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
1748                   blas::Transpose trans, uint64 n, uint64 k,                  \
1749                   std::complex<double> alpha,                                 \
1750                   const DeviceMemory<std::complex<double>> &a, int lda,       \
1751                   const DeviceMemory<std::complex<double>> &b, int ldb,       \
1752                   std::complex<double> beta,                                  \
1753                   DeviceMemory<std::complex<double>> *c, int ldc) override;   \
1754  bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
1755                  blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
1756                  uint64 n, float alpha, const DeviceMemory<float> &a,         \
1757                  int lda, DeviceMemory<float> *b, int ldb) override;          \
1758  bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
1759                  blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
1760                  uint64 n, double alpha, const DeviceMemory<double> &a,       \
1761                  int lda, DeviceMemory<double> *b, int ldb) override;         \
1762  bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
1763                  blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
1764                  uint64 n, std::complex<float> alpha,                         \
1765                  const DeviceMemory<std::complex<float>> &a, int lda,         \
1766                  DeviceMemory<std::complex<float>> *b, int ldb) override;     \
1767  bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
1768                  blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
1769                  uint64 n, std::complex<double> alpha,                        \
1770                  const DeviceMemory<std::complex<double>> &a, int lda,        \
1771                  DeviceMemory<std::complex<double>> *b, int ldb) override;    \
1772  bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
1773                  blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
1774                  uint64 n, float alpha, const DeviceMemory<float> &a,         \
1775                  int lda, DeviceMemory<float> *b, int ldb) override;          \
1776  bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
1777                  blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
1778                  uint64 n, double alpha, const DeviceMemory<double> &a,       \
1779                  int lda, DeviceMemory<double> *b, int ldb) override;         \
1780  bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
1781                  blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
1782                  uint64 n, std::complex<float> alpha,                         \
1783                  const DeviceMemory<std::complex<float>> &a, int lda,         \
1784                  DeviceMemory<std::complex<float>> *b, int ldb) override;     \
1785  bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
1786                  blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
1787                  uint64 n, std::complex<double> alpha,                        \
1788                  const DeviceMemory<std::complex<double>> &a, int lda,        \
1789                  DeviceMemory<std::complex<double>> *b, int ldb) override;
1790
1791}  // namespace blas
1792}  // namespace gputools
1793}  // namespace perftools
1794
1795#endif  // TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
1796