// rsCpuIntrinsicBLAS.cpp revision 99d0e8130f5b4bb83d1a68d96496fa558e35193a
1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17
18#include "rsCpuIntrinsic.h"
19#include "rsCpuIntrinsicInlines.h"
20#include "cblas.h"
21#include "eight_bit_int_gemm.h"
22
23using namespace android;
24using namespace android::renderscript;
25
26namespace android {
27namespace renderscript {
28
29
30class RsdCpuScriptIntrinsicBLAS : public RsdCpuScriptIntrinsic {
31public:
32    void invokeForEach(uint32_t slot,
33                       const Allocation ** ain,
34                       uint32_t inLen,
35                       Allocation * aout,
36                       const void * usr,
37                       uint32_t usrLen,
38                       const RsScriptCall *sc) override;
39
40    void populateScript(Script *) override;
41    ~RsdCpuScriptIntrinsicBLAS() override;
42    RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, const Script *s);
43
44protected:
45
46    uint8_t a_offset = 0;
47    uint8_t b_offset = 0;
48    uint8_t c_offset = 0;
49
50    static void kernelBNNM(size_t m, size_t n, size_t k,
51                           const uint8_t* a, uint8_t a_offset, size_t lda,
52                           const uint8_t* b, uint8_t b_offset, size_t ldb,
53                           uint8_t* c, int32_t c_offset, size_t ldc,
54                           int32_t c_mult_int);
55
56
57
58};
59
60}
61}
62
63void RsdCpuScriptIntrinsicBLAS::populateScript(Script *s) {
64    s->mHal.info.exportedVariableCount = 0;
65}
66
67static void initABC(const Allocation ** ain,
68                    size_t size,
69                    void** A,
70                    void** B,
71                    void** C,
72                    int* lda,
73                    int* ldb,
74                    int* ldc)
75{
76    if (ain[0]) {
77        *A = ain[0]->mHal.drvState.lod[0].mallocPtr;
78        *lda = (int)(ain[0]->mHal.drvState.lod[0].stride/size);
79    }
80    if (ain[1]) {
81        *B = ain[1]->mHal.drvState.lod[0].mallocPtr;
82        *ldb = (int)(ain[1]->mHal.drvState.lod[0].stride/size);
83    }
84    if (ain[2]) {
85        *C = ain[2]->mHal.drvState.lod[0].mallocPtr;
86        *ldc = (int)(ain[2]->mHal.drvState.lod[0].stride/size);
87    }
88
89
90}
91
92void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot,
93                                              const Allocation ** ain,
94                                              uint32_t inLen,
95                                              Allocation * aout,
96                                              const void * usr,
97                                              uint32_t usrLen,
98                                              const RsScriptCall *sc) {
99    RsBlasCall* call = (RsBlasCall*) usr;
100    // setup BLAS enum args
101    enum CBLAS_TRANSPOSE TransA = (enum CBLAS_TRANSPOSE)call->transA;
102    enum CBLAS_TRANSPOSE TransB = (enum CBLAS_TRANSPOSE)call->transB;
103    enum CBLAS_UPLO Uplo = (enum CBLAS_UPLO)call->uplo;
104    enum CBLAS_DIAG Diag = (enum CBLAS_DIAG)call->diag;
105    enum CBLAS_SIDE Side = (enum CBLAS_SIDE)call->side;
106
107    void *A = nullptr;
108    void *B = nullptr;
109    void *C = nullptr;
110    void *X = nullptr;
111    void *Y = nullptr;
112
113    int lda = 0, ldb = 0, ldc = 0;
114
115    switch (call->func) {
116
117    // Level 1 BLAS: returns into a 1D Allocation
118
119
120    // Level 2 BLAS
121    case (RsBlas_sgemv):
122        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
123        cblas_sgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.f, (float*)A,
124                    lda, (float*)X, call->incX, call->beta.f, (float*)Y, call->incY);
125        break;
126    case (RsBlas_sgbmv):
127        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
128        cblas_sgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
129                    call->alpha.f, (float*)A, lda, (float*)X, call->incX,
130                    call->beta.f, (float*)Y, call->incY);
131        break;
132    case (RsBlas_strmv):
133        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
134        cblas_strmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A,
135                    lda, (float*)X, call->incX);
136        break;
137    case (RsBlas_stbmv):
138        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
139        cblas_stbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A,
140                    lda, (float*)X, call->incX);
141        break;
142    // stpmv takes a packed 1D Allocation only
143    case (RsBlas_stpmv):
144        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
145        cblas_stpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A,
146                    (float*)X, call->incX);
147        break;
148    case (RsBlas_strsv):
149        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
150        cblas_strsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, lda,
151                    (float*)X, call->incX);
152        break;
153    case (RsBlas_stbsv):
154        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
155        cblas_stbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A,
156                    lda, (float*)X, call->incX);
157        break;
158    case (RsBlas_stpsv):
159        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
160        cblas_stpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A,
161                    (float*)X, call->incX);
162        break;
163    case (RsBlas_dgemv):
164        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
165        cblas_dgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.d, (double*)A,
166                    lda, (double*)X, call->incX, call->beta.d, (double*)Y, call->incY);
167        break;
168    case (RsBlas_dgbmv):
169        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
170        cblas_dgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
171                    call->alpha.d, (double*)A, lda, (double*)X, call->incX,
172                    call->beta.d, (double*)Y, call->incY);
173        break;
174    case (RsBlas_dtrmv):
175        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
176        cblas_dtrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A,
177                    lda, (double*)X, call->incX);
178        break;
179    case (RsBlas_dtbmv):
180        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
181        cblas_dtbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A,
182                    lda, (double*)X, call->incX);
183        break;
184    // stpmv takes a packed 1D Allocation only
185    case (RsBlas_dtpmv):
186        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
187        cblas_dtpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A,
188                    (double*)X, call->incX);
189        break;
190    case (RsBlas_dtrsv):
191        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
192        cblas_dtrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, lda,
193                    (double*)X, call->incX);
194        break;
195    case (RsBlas_dtbsv):
196        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
197        cblas_dtbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A,
198                    lda, (double*)X, call->incX);
199        break;
200    case (RsBlas_dtpsv):
201        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
202        cblas_dtpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A,
203                    (double*)X, call->incX);
204        break;
205    case (RsBlas_cgemv):
206        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
207        cblas_cgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.c, (void*)A,
208                    lda, (void*)X, call->incX, (void*)&call->beta.c, (void*)Y, call->incY);
209        break;
210    case (RsBlas_cgbmv):
211        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
212        cblas_cgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
213                    (void*)&call->alpha.c, (void*)A, lda, (void*)X, call->incX,
214                    (void*)&call->beta.c, (void*)Y, call->incY);
215        break;
216    case (RsBlas_ctrmv):
217        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
218        cblas_ctrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
219                    lda, (void*)X, call->incX);
220        break;
221    case (RsBlas_ctbmv):
222        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
223        cblas_ctbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
224                    lda, (void*)X, call->incX);
225        break;
226    // stpmv takes a packed 1D Allocation only
227    case (RsBlas_ctpmv):
228        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
229        cblas_ctpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
230                    (void*)X, call->incX);
231        break;
232    case (RsBlas_ctrsv):
233        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
234        cblas_ctrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda,
235                    (void*)X, call->incX);
236        break;
237    case (RsBlas_ctbsv):
238        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
239        cblas_ctbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
240                    lda, (void*)X, call->incX);
241        break;
242    case (RsBlas_ctpsv):
243        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
244        cblas_ctpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
245                    (void*)X, call->incX);
246        break;
247    case (RsBlas_zgemv):
248        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
249        cblas_zgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.z, (void*)A,
250                    lda, (void*)X, call->incX, (void*)&call->beta.z, (void*)Y, call->incY);
251        break;
252    case (RsBlas_zgbmv):
253        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
254        cblas_zgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
255                    (void*)&call->alpha.z, (void*)A, lda, (void*)X, call->incX,
256                    (void*)&call->beta.z, (void*)Y, call->incY);
257        break;
258    case (RsBlas_ztrmv):
259        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
260        cblas_ztrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
261                    lda, (void*)X, call->incX);
262        break;
263    case (RsBlas_ztbmv):
264        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
265        cblas_ztbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
266                    lda, (void*)X, call->incX);
267        break;
268    // stpmv takes a packed 1D Allocation only
269    case (RsBlas_ztpmv):
270        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
271        cblas_ztpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
272                    (void*)X, call->incX);
273        break;
274    case (RsBlas_ztrsv):
275        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
276        cblas_ztrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda,
277                    (void*)X, call->incX);
278        break;
279    case (RsBlas_ztbsv):
280        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
281        cblas_ztbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
282                    lda, (void*)X, call->incX);
283        break;
284    case (RsBlas_ztpsv):
285        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
286        cblas_ztpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
287                    (void*)X, call->incX);
288        break;
289
290
291    // S and D only
292    case (RsBlas_ssymv):
293        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
294        cblas_ssymv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A, lda,
295                    (float*)X, call->incX, call->beta.f, (float*)Y, call->incY);
296        break;
297    case (RsBlas_ssbmv):
298        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
299        cblas_ssbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.f,
300                    (float*)A, lda, (float*)X, call->incX, call->beta.f,
301                    (float*)Y, call->incY);
302        break;
303    //sspmv requires a packed 1D Allocation
304    case (RsBlas_sspmv):
305        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
306        cblas_sspmv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A,
307                    (float*)X, call->incX, call->beta.f, (float*)Y, call->incY);
308        break;
309    // following calls have init reordered because A is output matrix
310    case (RsBlas_sger):
311        initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda);
312        cblas_sger(CblasRowMajor, call->M, call->N, call->alpha.f, (float*)X,
313                   call->incX, (float*)Y, call->incY, (float*)A, lda);
314        break;
315    case (RsBlas_ssyr):
316        initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr);
317        cblas_ssyr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
318                   (float*)A, lda);
319        break;
320    // sspr is packed 1D Allocation A only
321    case (RsBlas_sspr):
322        initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr);
323        cblas_sspr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
324                   (float*)A);
325        break;
326    case (RsBlas_ssyr2):
327        initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda);
328        cblas_ssyr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
329                    (float*)Y, call->incY, (float*)A, lda);
330        break;
331    // sspr2 is packed 1D Allocation A only
332    case (RsBlas_sspr2):
333        initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda);
334        cblas_sspr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
335                    (float*)Y, call->incY, (float*)A);
336        break;
337    case (RsBlas_dsymv):
338        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
339        cblas_dsymv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A, lda,
340                    (double*)X, call->incX, call->beta.d, (double*)Y, call->incY);
341        break;
342    case (RsBlas_dsbmv):
343        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
344        cblas_dsbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.d,
345                    (double*)A, lda, (double*)X, call->incX, call->beta.d,
346                    (double*)Y, call->incY);
347        break;
348    // dspmv requires a packed 1D Allocation
349    case (RsBlas_dspmv):
350        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
351        cblas_dspmv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A,
352                    (double*)X, call->incX, call->beta.d, (double*)Y, call->incY);
353        break;
354    // following calls have init reordered because A is output matrix
355    case (RsBlas_dger):
356        initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda);
357        cblas_dger(CblasRowMajor, call->M, call->N, call->alpha.d, (double*)X,
358                   call->incX, (double*)Y, call->incY, (double*)A, lda);
359        break;
360    case (RsBlas_dsyr):
361        initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr);
362        cblas_dsyr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
363                   (double*)A, lda);
364        break;
365    // dspr is packed 1D Allocation A only
366    case (RsBlas_dspr):
367        initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr);
368        cblas_dspr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
369                   (double*)A);
370        break;
371    case (RsBlas_dsyr2):
372        initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda);
373        cblas_dsyr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
374                    (double*)Y, call->incY, (double*)A, lda);
375        break;
376    // dspr2 is packed 1D Allocation A only
377    case (RsBlas_dspr2):
378        initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda);
379        cblas_dspr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
380                    (double*)Y, call->incY, (double*)A);
381        break;
382
383    // C and Z only
384    case (RsBlas_chemv):
385        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
386        cblas_chemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A, lda,
387                    X, call->incX, (void*)&call->beta.c, Y, call->incY);
388        break;
389    case (RsBlas_chbmv):
390        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
391        cblas_chbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.c,
392                    A, lda, X, call->incX, (void*)&call->beta.c, Y, call->incY);
393        break;
394    case (RsBlas_chpmv):
395        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
396        cblas_chpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A,
397                    X, call->incX, (void*)&call->beta.c, Y, call->incY);
398        break;
399    case (RsBlas_cgeru):
400        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
401        cblas_cgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c,
402                    X, call->incX, Y, call->incY, A, lda);
403        break;
404    case (RsBlas_cgerc):
405        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
406        cblas_cgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c,
407                    X, call->incX, Y, call->incY, A, lda);
408        break;
409    case (RsBlas_cher):
410        initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
411        cblas_cher(CblasRowMajor, Uplo, call->N, call->alpha.f,
412                   X, call->incX, A, lda);
413        break;
414    // packed 1D Allocations only
415    case (RsBlas_chpr):
416        initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
417        cblas_chpr(CblasRowMajor, Uplo, call->N, call->alpha.f, X,
418                   call->incX, A);
419        break;
420    case (RsBlas_cher2):
421        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
422        cblas_cher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c,
423                   X, call->incX, Y, call->incY, A, lda);
424        break;
425    // packed 1D Allocations only
426    case (RsBlas_chpr2):
427        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
428        cblas_chpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, X,
429                   call->incX, Y, call->incY, A);
430        break;
431    case (RsBlas_zhemv):
432        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
433        cblas_zhemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A, lda,
434                    X, call->incX, (void*)&call->beta.z, Y, call->incY);
435        break;
436    case (RsBlas_zhbmv):
437        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
438        cblas_zhbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.z,
439                    A, lda, X, call->incX, (void*)&call->beta.z, Y, call->incY);
440        break;
441    case (RsBlas_zhpmv):
442        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
443        cblas_zhpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A,
444                    X, call->incX, (void*)&call->beta.z, Y, call->incY);
445        break;
446    case (RsBlas_zgeru):
447        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
448        cblas_zgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z,
449                    X, call->incX, Y, call->incY, A, lda);
450        break;
451    case (RsBlas_zgerc):
452        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
453        cblas_zgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z,
454                    X, call->incX, Y, call->incY, A, lda);
455        break;
456    case (RsBlas_zher):
457        initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
458        cblas_zher(CblasRowMajor, Uplo, call->N, call->alpha.d,
459                   X, call->incX, A, lda);
460        break;
461    // packed 1D Allocations only
462    case (RsBlas_zhpr):
463        initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
464        cblas_zhpr(CblasRowMajor, Uplo, call->N, call->alpha.d, X,
465                   call->incX, A);
466        break;
467    case (RsBlas_zher2):
468        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
469        cblas_zher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z,
470                   X, call->incX, Y, call->incY, A, lda);
471        break;
472    // packed 1D Allocations only
473    case (RsBlas_zhpr2):
474        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
475        cblas_zhpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, X,
476                   call->incX, Y, call->incY, A);
477        break;
478
479    // Level 3 BLAS
480    case (RsBlas_sgemm):
481        initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
482        cblas_sgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.f,
483                    (float*)A, lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
484        break;
485    case (RsBlas_ssymm):
486        initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
487        cblas_ssymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.f, (float*)A,
488                    lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
489        break;
490    case (RsBlas_ssyrk):
491        initABC(ain, sizeof(float), &A, nullptr, &C, &lda, nullptr, &ldc);
492        cblas_ssyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A,
493                    lda, call->beta.f, (float*)C, ldc);
494        break;
495    case (RsBlas_ssyr2k):
496        initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
497        cblas_ssyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A,
498                     lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
499        break;
500    case (RsBlas_strmm):
501        initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr);
502        cblas_strmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f,
503                    (float*)A, lda, (float*)B, ldb);
504        break;
505    case (RsBlas_strsm):
506        initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr);
507        cblas_strsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f,
508                    (float*)A, lda, (float*)B, ldb);
509        break;
510
511
512    case (RsBlas_dgemm):
513        initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
514        cblas_dgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.d,
515                    (double*)A, lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
516        break;
517    case (RsBlas_dsymm):
518        initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
519        cblas_dsymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.d, (double*)A,
520                    lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
521        break;
522    case (RsBlas_dsyrk):
523        initABC(ain, sizeof(double), &A, nullptr, &C, &lda, nullptr, &ldc);
524        cblas_dsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A,
525                    lda, call->beta.d, (double*)C, ldc);
526        break;
527    case (RsBlas_dsyr2k):
528        initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
529        cblas_dsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A,
530                     lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
531        break;
532    case (RsBlas_dtrmm):
533        initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr);
534        cblas_dtrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d,
535                    (double*)A, lda, (double*)B, ldb);
536        break;
537    case (RsBlas_dtrsm):
538        initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr);
539        cblas_dtrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d,
540                    (double*)A, lda, (double*)B, ldb);
541        break;
542
543    case (RsBlas_cgemm):
544        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
545        cblas_cgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.c,
546                    A, lda, B, ldb, (void*)&call->beta.c, C, ldc);
547        break;
548    case (RsBlas_csymm):
549        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
550        cblas_csymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A,
551                    lda, B, ldb, (void*)&call->beta.c, C, ldc);
552        break;
553    case (RsBlas_csyrk):
554        initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
555        cblas_csyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A,
556                    lda, (void*)&call->beta.c, C, ldc);
557        break;
558    case (RsBlas_csyr2k):
559        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
560        cblas_csyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A,
561                     lda, B, ldb, (void*)&call->beta.c, C, ldc);
562        break;
563    case (RsBlas_ctrmm):
564        initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
565        cblas_ctrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c,
566                    A, lda, B, ldb);
567        break;
568    case (RsBlas_ctrsm):
569        initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
570        cblas_ctrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c,
571                    A, lda, B, ldb);
572        break;
573
574    case (RsBlas_zgemm):
575        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
576        cblas_zgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.z,
577                    A, lda, B, ldb, (void*)&call->beta.z, C, ldc);
578        break;
579    case (RsBlas_zsymm):
580        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
581        cblas_zsymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A,
582                    lda, B, ldb, (void*)&call->beta.z, C, ldc);
583        break;
584    case (RsBlas_zsyrk):
585        initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
586        cblas_zsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A,
587                    lda, (void*)&call->beta.z, C, ldc);
588        break;
589    case (RsBlas_zsyr2k):
590        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
591        cblas_zsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A,
592                     lda, B, ldb, (void*)&call->beta.z, C, ldc);
593        break;
594    case (RsBlas_ztrmm):
595        initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
596        cblas_ztrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z,
597                    A, lda, B, ldb);
598        break;
599    case (RsBlas_ztrsm):
600        initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
601        cblas_ztrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z,
602                    A, lda, B, ldb);
603        break;
604
605    // Level 3 C and Z only
606    case (RsBlas_chemm):
607        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
608        cblas_chemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A, lda,
609                    B, ldb, (void*)&call->beta.c, C, ldc);
610        break;
611    case (RsBlas_cherk):
612        initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
613        cblas_cherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, A, lda,
614                    call->beta.f, C, ldc);
615        break;
616    case (RsBlas_cher2k):
617        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
618        cblas_cher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, lda,
619                     B, ldb, call->beta.f, C, ldc);
620        break;
621
622    case (RsBlas_zhemm):
623        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
624        cblas_zhemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A, lda,
625                    B, ldb, (void*)&call->beta.z, C, ldc);
626        break;
627    case (RsBlas_zherk):
628        initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
629        cblas_zherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, A, lda,
630                    call->beta.d, C, ldc);
631        break;
632    case (RsBlas_zher2k):
633        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
634        cblas_zher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, lda,
635                     B, ldb, call->beta.d, C, ldc);
636        break;
637
638
639    case (RsBlas_bnnm):
640        initABC(ain, sizeof(uint8_t), &A, &B, &C, &lda, &ldb, &ldc);
641        kernelBNNM(call->M, call->N, call->K,
642                    (const uint8_t*)A, call->a_offset, lda,
643                    (const uint8_t*)B, call->b_offset, ldb,
644                    (uint8_t*)C, call->c_offset, ldc,
645                    call->c_mult_int);
646
647        break;
648
649    default:
650        ALOGE("unimplemented\n");
651    }
652
653
654}
655
656void RsdCpuScriptIntrinsicBLAS::kernelBNNM(size_t m, size_t n, size_t k,
657                                           const uint8_t* a, uint8_t a_offset, size_t lda,
658                                           const uint8_t* b, uint8_t b_offset, size_t ldb,
659                                           uint8_t* c, int32_t c_offset, size_t ldc,
660                                           int32_t c_mult_int) {
661    const int c_shift = 21;
662    // Using gemmlowp to calculate the low precision 8 bit GEMM.
663    gemmlowp::eight_bit_int_gemm::EightBitIntGemm(m, n, k, a, -a_offset, lda,
664                                                  b, -b_offset, ldb, c, c_offset,
665                                                  c_mult_int, c_shift, ldc);
666}
667
668
669
670
671
672RsdCpuScriptIntrinsicBLAS::RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx,
673                                                   const Script *s)
674            : RsdCpuScriptIntrinsic(ctx, s, nullptr, RS_SCRIPT_INTRINSIC_ID_BLAS) {
675
676
677}
678
679RsdCpuScriptIntrinsicBLAS::~RsdCpuScriptIntrinsicBLAS() {
680}
681
682
683
684
685
// Factory for the BLAS intrinsic. The returned object is heap-allocated and
// owned by the caller. The element argument `e` is not used.
RsdCpuScriptImpl * rsdIntrinsic_BLAS(RsdCpuReferenceImpl *ctx,
                                    const Script *s, const Element *e) {
    RsdCpuScriptIntrinsicBLAS *impl = new RsdCpuScriptIntrinsicBLAS(ctx, s);
    return impl;
}
691