1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17
18#include "rsCpuIntrinsic.h"
19#include "rsCpuIntrinsicInlines.h"
20#include "rsCpuBLASDispatch.h"
21#include "eight_bit_int_gemm.h"
22
23using namespace android;
24using namespace android::renderscript;
25
26namespace android {
27namespace renderscript {
28
29
30class RsdCpuScriptIntrinsicBLAS : public RsdCpuScriptIntrinsic {
31public:
32    void invokeForEach(uint32_t slot,
33                       const Allocation ** ain,
34                       uint32_t inLen,
35                       Allocation * aout,
36                       const void * usr,
37                       uint32_t usrLen,
38                       const RsScriptCall *sc) override;
39
40    void populateScript(Script *) override;
41    ~RsdCpuScriptIntrinsicBLAS() override;
42    RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, const Script *s);
43
44protected:
45
46    uint8_t a_offset = 0;
47    uint8_t b_offset = 0;
48    uint8_t c_offset = 0;
49
50#ifdef RS_COMPATIBILITY_LIB
51    bool isBlasLibInitialized = false;
52#endif
53    static void kernelBNNM(size_t m, size_t n, size_t k,
54                           const uint8_t* a, uint8_t a_offset, size_t lda,
55                           const uint8_t* b, uint8_t b_offset, size_t ldb,
56                           uint8_t* c, int32_t c_offset, size_t ldc,
57                           int32_t c_mult_int);
58
59
60
61};
62
63}
64}
65
66void RsdCpuScriptIntrinsicBLAS::populateScript(Script *s) {
67    s->mHal.info.exportedVariableCount = 0;
68}
69
70static void initABC(const Allocation ** ain,
71                    size_t size,
72                    void** A,
73                    void** B,
74                    void** C,
75                    int* lda,
76                    int* ldb,
77                    int* ldc)
78{
79    if (ain[0]) {
80        *A = ain[0]->mHal.drvState.lod[0].mallocPtr;
81        *lda = (int)(ain[0]->mHal.drvState.lod[0].stride/size);
82    }
83    if (ain[1]) {
84        *B = ain[1]->mHal.drvState.lod[0].mallocPtr;
85        *ldb = (int)(ain[1]->mHal.drvState.lod[0].stride/size);
86    }
87    if (ain[2]) {
88        *C = ain[2]->mHal.drvState.lod[0].mallocPtr;
89        *ldc = (int)(ain[2]->mHal.drvState.lod[0].stride/size);
90    }
91
92
93}
94
95void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot,
96                                              const Allocation ** ain,
97                                              uint32_t inLen,
98                                              Allocation * aout,
99                                              const void * usr,
100                                              uint32_t usrLen,
101                                              const RsScriptCall *sc) {
102    RsBlasCall* call = (RsBlasCall*) usr;
103    // setup BLAS enum args
104    enum CBLAS_TRANSPOSE TransA = (enum CBLAS_TRANSPOSE)call->transA;
105    enum CBLAS_TRANSPOSE TransB = (enum CBLAS_TRANSPOSE)call->transB;
106    enum CBLAS_UPLO Uplo = (enum CBLAS_UPLO)call->uplo;
107    enum CBLAS_DIAG Diag = (enum CBLAS_DIAG)call->diag;
108    enum CBLAS_SIDE Side = (enum CBLAS_SIDE)call->side;
109
110    void *A = nullptr;
111    void *B = nullptr;
112    void *C = nullptr;
113    void *X = nullptr;
114    void *Y = nullptr;
115
116    int lda = 0, ldb = 0, ldc = 0;
117
118#ifdef RS_COMPATIBILITY_LIB
119    // Allow BNNM even without libblas
120    if (call->func != RsBlas_bnnm && !isBlasLibInitialized) {
121        if (!loadBLASLib()) {
122            ALOGE("Failed to load the BLAS lib, IntrinsicBLAS NOT supported!\n");
123            return;
124        }
125        isBlasLibInitialized = true;
126    }
127#endif
128
129    switch (call->func) {
130
131    // Level 1 BLAS: returns into a 1D Allocation
132
133
134    // Level 2 BLAS
135    case (RsBlas_sgemv):
136        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
137        cblas_sgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.f, (float*)A,
138                    lda, (float*)X, call->incX, call->beta.f, (float*)Y, call->incY);
139        break;
140    case (RsBlas_sgbmv):
141        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
142        cblas_sgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
143                    call->alpha.f, (float*)A, lda, (float*)X, call->incX,
144                    call->beta.f, (float*)Y, call->incY);
145        break;
146    case (RsBlas_strmv):
147        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
148        cblas_strmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A,
149                    lda, (float*)X, call->incX);
150        break;
151    case (RsBlas_stbmv):
152        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
153        cblas_stbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A,
154                    lda, (float*)X, call->incX);
155        break;
156    // stpmv takes a packed 1D Allocation only
157    case (RsBlas_stpmv):
158        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
159        cblas_stpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A,
160                    (float*)X, call->incX);
161        break;
162    case (RsBlas_strsv):
163        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
164        cblas_strsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, lda,
165                    (float*)X, call->incX);
166        break;
167    case (RsBlas_stbsv):
168        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
169        cblas_stbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A,
170                    lda, (float*)X, call->incX);
171        break;
172    case (RsBlas_stpsv):
173        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
174        cblas_stpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A,
175                    (float*)X, call->incX);
176        break;
177    case (RsBlas_dgemv):
178        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
179        cblas_dgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.d, (double*)A,
180                    lda, (double*)X, call->incX, call->beta.d, (double*)Y, call->incY);
181        break;
182    case (RsBlas_dgbmv):
183        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
184        cblas_dgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
185                    call->alpha.d, (double*)A, lda, (double*)X, call->incX,
186                    call->beta.d, (double*)Y, call->incY);
187        break;
188    case (RsBlas_dtrmv):
189        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
190        cblas_dtrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A,
191                    lda, (double*)X, call->incX);
192        break;
193    case (RsBlas_dtbmv):
194        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
195        cblas_dtbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A,
196                    lda, (double*)X, call->incX);
197        break;
198    // stpmv takes a packed 1D Allocation only
199    case (RsBlas_dtpmv):
200        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
201        cblas_dtpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A,
202                    (double*)X, call->incX);
203        break;
204    case (RsBlas_dtrsv):
205        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
206        cblas_dtrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, lda,
207                    (double*)X, call->incX);
208        break;
209    case (RsBlas_dtbsv):
210        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
211        cblas_dtbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A,
212                    lda, (double*)X, call->incX);
213        break;
214    case (RsBlas_dtpsv):
215        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
216        cblas_dtpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A,
217                    (double*)X, call->incX);
218        break;
219    case (RsBlas_cgemv):
220        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
221        cblas_cgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.c, (void*)A,
222                    lda, (void*)X, call->incX, (void*)&call->beta.c, (void*)Y, call->incY);
223        break;
224    case (RsBlas_cgbmv):
225        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
226        cblas_cgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
227                    (void*)&call->alpha.c, (void*)A, lda, (void*)X, call->incX,
228                    (void*)&call->beta.c, (void*)Y, call->incY);
229        break;
230    case (RsBlas_ctrmv):
231        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
232        cblas_ctrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
233                    lda, (void*)X, call->incX);
234        break;
235    case (RsBlas_ctbmv):
236        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
237        cblas_ctbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
238                    lda, (void*)X, call->incX);
239        break;
240    // stpmv takes a packed 1D Allocation only
241    case (RsBlas_ctpmv):
242        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
243        cblas_ctpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
244                    (void*)X, call->incX);
245        break;
246    case (RsBlas_ctrsv):
247        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
248        cblas_ctrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda,
249                    (void*)X, call->incX);
250        break;
251    case (RsBlas_ctbsv):
252        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
253        cblas_ctbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
254                    lda, (void*)X, call->incX);
255        break;
256    case (RsBlas_ctpsv):
257        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
258        cblas_ctpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
259                    (void*)X, call->incX);
260        break;
261    case (RsBlas_zgemv):
262        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
263        cblas_zgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.z, (void*)A,
264                    lda, (void*)X, call->incX, (void*)&call->beta.z, (void*)Y, call->incY);
265        break;
266    case (RsBlas_zgbmv):
267        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
268        cblas_zgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
269                    (void*)&call->alpha.z, (void*)A, lda, (void*)X, call->incX,
270                    (void*)&call->beta.z, (void*)Y, call->incY);
271        break;
272    case (RsBlas_ztrmv):
273        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
274        cblas_ztrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
275                    lda, (void*)X, call->incX);
276        break;
277    case (RsBlas_ztbmv):
278        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
279        cblas_ztbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
280                    lda, (void*)X, call->incX);
281        break;
282    // stpmv takes a packed 1D Allocation only
283    case (RsBlas_ztpmv):
284        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
285        cblas_ztpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
286                    (void*)X, call->incX);
287        break;
288    case (RsBlas_ztrsv):
289        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
290        cblas_ztrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda,
291                    (void*)X, call->incX);
292        break;
293    case (RsBlas_ztbsv):
294        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
295        cblas_ztbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
296                    lda, (void*)X, call->incX);
297        break;
298    case (RsBlas_ztpsv):
299        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
300        cblas_ztpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
301                    (void*)X, call->incX);
302        break;
303
304
305    // S and D only
306    case (RsBlas_ssymv):
307        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
308        cblas_ssymv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A, lda,
309                    (float*)X, call->incX, call->beta.f, (float*)Y, call->incY);
310        break;
311    case (RsBlas_ssbmv):
312        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
313        cblas_ssbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.f,
314                    (float*)A, lda, (float*)X, call->incX, call->beta.f,
315                    (float*)Y, call->incY);
316        break;
317    //sspmv requires a packed 1D Allocation
318    case (RsBlas_sspmv):
319        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
320        cblas_sspmv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A,
321                    (float*)X, call->incX, call->beta.f, (float*)Y, call->incY);
322        break;
323    // following calls have init reordered because A is output matrix
324    case (RsBlas_sger):
325        initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda);
326        cblas_sger(CblasRowMajor, call->M, call->N, call->alpha.f, (float*)X,
327                   call->incX, (float*)Y, call->incY, (float*)A, lda);
328        break;
329    case (RsBlas_ssyr):
330        initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr);
331        cblas_ssyr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
332                   (float*)A, lda);
333        break;
334    // sspr is packed 1D Allocation A only
335    case (RsBlas_sspr):
336        initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr);
337        cblas_sspr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
338                   (float*)A);
339        break;
340    case (RsBlas_ssyr2):
341        initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda);
342        cblas_ssyr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
343                    (float*)Y, call->incY, (float*)A, lda);
344        break;
345    // sspr2 is packed 1D Allocation A only
346    case (RsBlas_sspr2):
347        initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda);
348        cblas_sspr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
349                    (float*)Y, call->incY, (float*)A);
350        break;
351    case (RsBlas_dsymv):
352        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
353        cblas_dsymv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A, lda,
354                    (double*)X, call->incX, call->beta.d, (double*)Y, call->incY);
355        break;
356    case (RsBlas_dsbmv):
357        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
358        cblas_dsbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.d,
359                    (double*)A, lda, (double*)X, call->incX, call->beta.d,
360                    (double*)Y, call->incY);
361        break;
362    // dspmv requires a packed 1D Allocation
363    case (RsBlas_dspmv):
364        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
365        cblas_dspmv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A,
366                    (double*)X, call->incX, call->beta.d, (double*)Y, call->incY);
367        break;
368    // following calls have init reordered because A is output matrix
369    case (RsBlas_dger):
370        initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda);
371        cblas_dger(CblasRowMajor, call->M, call->N, call->alpha.d, (double*)X,
372                   call->incX, (double*)Y, call->incY, (double*)A, lda);
373        break;
374    case (RsBlas_dsyr):
375        initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr);
376        cblas_dsyr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
377                   (double*)A, lda);
378        break;
379    // dspr is packed 1D Allocation A only
380    case (RsBlas_dspr):
381        initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr);
382        cblas_dspr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
383                   (double*)A);
384        break;
385    case (RsBlas_dsyr2):
386        initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda);
387        cblas_dsyr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
388                    (double*)Y, call->incY, (double*)A, lda);
389        break;
390    // dspr2 is packed 1D Allocation A only
391    case (RsBlas_dspr2):
392        initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda);
393        cblas_dspr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
394                    (double*)Y, call->incY, (double*)A);
395        break;
396
397    // C and Z only
398    case (RsBlas_chemv):
399        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
400        cblas_chemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A, lda,
401                    X, call->incX, (void*)&call->beta.c, Y, call->incY);
402        break;
403    case (RsBlas_chbmv):
404        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
405        cblas_chbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.c,
406                    A, lda, X, call->incX, (void*)&call->beta.c, Y, call->incY);
407        break;
408    case (RsBlas_chpmv):
409        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
410        cblas_chpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A,
411                    X, call->incX, (void*)&call->beta.c, Y, call->incY);
412        break;
413    case (RsBlas_cgeru):
414        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
415        cblas_cgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c,
416                    X, call->incX, Y, call->incY, A, lda);
417        break;
418    case (RsBlas_cgerc):
419        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
420        cblas_cgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c,
421                    X, call->incX, Y, call->incY, A, lda);
422        break;
423    case (RsBlas_cher):
424        initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
425        cblas_cher(CblasRowMajor, Uplo, call->N, call->alpha.f,
426                   X, call->incX, A, lda);
427        break;
428    // packed 1D Allocations only
429    case (RsBlas_chpr):
430        initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
431        cblas_chpr(CblasRowMajor, Uplo, call->N, call->alpha.f, X,
432                   call->incX, A);
433        break;
434    case (RsBlas_cher2):
435        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
436        cblas_cher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c,
437                   X, call->incX, Y, call->incY, A, lda);
438        break;
439    // packed 1D Allocations only
440    case (RsBlas_chpr2):
441        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
442        cblas_chpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, X,
443                   call->incX, Y, call->incY, A);
444        break;
445    case (RsBlas_zhemv):
446        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
447        cblas_zhemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A, lda,
448                    X, call->incX, (void*)&call->beta.z, Y, call->incY);
449        break;
450    case (RsBlas_zhbmv):
451        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
452        cblas_zhbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.z,
453                    A, lda, X, call->incX, (void*)&call->beta.z, Y, call->incY);
454        break;
455    case (RsBlas_zhpmv):
456        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
457        cblas_zhpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A,
458                    X, call->incX, (void*)&call->beta.z, Y, call->incY);
459        break;
460    case (RsBlas_zgeru):
461        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
462        cblas_zgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z,
463                    X, call->incX, Y, call->incY, A, lda);
464        break;
465    case (RsBlas_zgerc):
466        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
467        cblas_zgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z,
468                    X, call->incX, Y, call->incY, A, lda);
469        break;
470    case (RsBlas_zher):
471        initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
472        cblas_zher(CblasRowMajor, Uplo, call->N, call->alpha.d,
473                   X, call->incX, A, lda);
474        break;
475    // packed 1D Allocations only
476    case (RsBlas_zhpr):
477        initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
478        cblas_zhpr(CblasRowMajor, Uplo, call->N, call->alpha.d, X,
479                   call->incX, A);
480        break;
481    case (RsBlas_zher2):
482        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
483        cblas_zher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z,
484                   X, call->incX, Y, call->incY, A, lda);
485        break;
486    // packed 1D Allocations only
487    case (RsBlas_zhpr2):
488        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
489        cblas_zhpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, X,
490                   call->incX, Y, call->incY, A);
491        break;
492
493    // Level 3 BLAS
494    case (RsBlas_sgemm):
495        initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
496        cblas_sgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.f,
497                    (float*)A, lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
498        break;
499    case (RsBlas_ssymm):
500        initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
501        cblas_ssymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.f, (float*)A,
502                    lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
503        break;
504    case (RsBlas_ssyrk):
505        initABC(ain, sizeof(float), &A, nullptr, &C, &lda, nullptr, &ldc);
506        cblas_ssyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A,
507                    lda, call->beta.f, (float*)C, ldc);
508        break;
509    case (RsBlas_ssyr2k):
510        initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
511        cblas_ssyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A,
512                     lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
513        break;
514    case (RsBlas_strmm):
515        initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr);
516        cblas_strmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f,
517                    (float*)A, lda, (float*)B, ldb);
518        break;
519    case (RsBlas_strsm):
520        initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr);
521        cblas_strsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f,
522                    (float*)A, lda, (float*)B, ldb);
523        break;
524
525
526    case (RsBlas_dgemm):
527        initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
528        cblas_dgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.d,
529                    (double*)A, lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
530        break;
531    case (RsBlas_dsymm):
532        initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
533        cblas_dsymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.d, (double*)A,
534                    lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
535        break;
536    case (RsBlas_dsyrk):
537        initABC(ain, sizeof(double), &A, nullptr, &C, &lda, nullptr, &ldc);
538        cblas_dsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A,
539                    lda, call->beta.d, (double*)C, ldc);
540        break;
541    case (RsBlas_dsyr2k):
542        initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
543        cblas_dsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A,
544                     lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
545        break;
546    case (RsBlas_dtrmm):
547        initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr);
548        cblas_dtrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d,
549                    (double*)A, lda, (double*)B, ldb);
550        break;
551    case (RsBlas_dtrsm):
552        initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr);
553        cblas_dtrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d,
554                    (double*)A, lda, (double*)B, ldb);
555        break;
556
557    case (RsBlas_cgemm):
558        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
559        cblas_cgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.c,
560                    A, lda, B, ldb, (void*)&call->beta.c, C, ldc);
561        break;
562    case (RsBlas_csymm):
563        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
564        cblas_csymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A,
565                    lda, B, ldb, (void*)&call->beta.c, C, ldc);
566        break;
567    case (RsBlas_csyrk):
568        initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
569        cblas_csyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A,
570                    lda, (void*)&call->beta.c, C, ldc);
571        break;
572    case (RsBlas_csyr2k):
573        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
574        cblas_csyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A,
575                     lda, B, ldb, (void*)&call->beta.c, C, ldc);
576        break;
577    case (RsBlas_ctrmm):
578        initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
579        cblas_ctrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c,
580                    A, lda, B, ldb);
581        break;
582    case (RsBlas_ctrsm):
583        initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
584        cblas_ctrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c,
585                    A, lda, B, ldb);
586        break;
587
588    case (RsBlas_zgemm):
589        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
590        cblas_zgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.z,
591                    A, lda, B, ldb, (void*)&call->beta.z, C, ldc);
592        break;
593    case (RsBlas_zsymm):
594        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
595        cblas_zsymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A,
596                    lda, B, ldb, (void*)&call->beta.z, C, ldc);
597        break;
598    case (RsBlas_zsyrk):
599        initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
600        cblas_zsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A,
601                    lda, (void*)&call->beta.z, C, ldc);
602        break;
603    case (RsBlas_zsyr2k):
604        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
605        cblas_zsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A,
606                     lda, B, ldb, (void*)&call->beta.z, C, ldc);
607        break;
608    case (RsBlas_ztrmm):
609        initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
610        cblas_ztrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z,
611                    A, lda, B, ldb);
612        break;
613    case (RsBlas_ztrsm):
614        initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
615        cblas_ztrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z,
616                    A, lda, B, ldb);
617        break;
618
619    // Level 3 C and Z only
620    case (RsBlas_chemm):
621        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
622        cblas_chemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A, lda,
623                    B, ldb, (void*)&call->beta.c, C, ldc);
624        break;
625    case (RsBlas_cherk):
626        initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
627        cblas_cherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, A, lda,
628                    call->beta.f, C, ldc);
629        break;
630    case (RsBlas_cher2k):
631        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
632        cblas_cher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, lda,
633                     B, ldb, call->beta.f, C, ldc);
634        break;
635
636    case (RsBlas_zhemm):
637        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
638        cblas_zhemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A, lda,
639                    B, ldb, (void*)&call->beta.z, C, ldc);
640        break;
641    case (RsBlas_zherk):
642        initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
643        cblas_zherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, A, lda,
644                    call->beta.d, C, ldc);
645        break;
646    case (RsBlas_zher2k):
647        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
648        cblas_zher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, lda,
649                     B, ldb, call->beta.d, C, ldc);
650        break;
651
652
653    case (RsBlas_bnnm):
654        initABC(ain, sizeof(uint8_t), &A, &B, &C, &lda, &ldb, &ldc);
655        kernelBNNM(call->M, call->N, call->K,
656                    (const uint8_t*)A, call->a_offset, lda,
657                    (const uint8_t*)B, call->b_offset, ldb,
658                    (uint8_t*)C, call->c_offset, ldc,
659                    call->c_mult_int);
660
661        break;
662
663    default:
664        ALOGE("unimplemented\n");
665    }
666
667
668}
669
670void RsdCpuScriptIntrinsicBLAS::kernelBNNM(size_t m, size_t n, size_t k,
671                                           const uint8_t* a, uint8_t a_offset, size_t lda,
672                                           const uint8_t* b, uint8_t b_offset, size_t ldb,
673                                           uint8_t* c, int32_t c_offset, size_t ldc,
674                                           int32_t c_mult_int) {
675    const int c_shift = 21;
676#if defined(ARCH_ARM_HAVE_VFP) || defined(ARCH_ARM_USE_INTRINSICS)
677    // Non-optimized path for ARMv7 devices without SIMD instructions.
678    if (!gArchUseSIMD) {
679        /*
680         * Calculations are done in 1.10.21 fixed-point format for the final output,
681         * just before there's a shift down to drop the fractional parts. The output
682         * values are gated to 0 to 255 to fit in a byte, but the 10-bit format
683         * gives some headroom to avoid wrapping around on small overflows.
684         */
685        size_t i = 0, j = 0, l = 0;
686        for (j = 0; j < n; j++) {
687            for (i = 0; i < m; i++) {
688                int32_t total = 0;
689                for (l = 0; l < k; l++) {
690                    const int a_index = ((i * lda) + l);
691                    const uint8_t a_as_byte = a[a_index];
692                    const int32_t a_as_int = (((int32_t)(a_as_byte)) - a_offset);
693                    const int b_index = ((j * ldb) + l);
694                    const uint8_t b_as_byte = b[b_index];
695                    const int32_t b_as_int = (((int32_t)(b_as_byte)) - b_offset);
696                    const int32_t mult_as_int = (a_as_int * b_as_int);
697                    total += mult_as_int;
698                }
699                const int c_index = ((ldc * i) + j);
700                int32_t output =
701                    ((((total + c_offset) * c_mult_int) + (1 << (c_shift - 1)))
702                     >> c_shift);
703                if (output > 255) {
704                    output = 255;
705                }
706                if (output < 0) {
707                    output = 0;
708                }
709                c[c_index] = (uint8_t)(output);
710            }
711        }
712        return;
713    }
714#endif
715
716    // Using gemmlowp to calculate the low precision 8 bit GEMM.
717    bool transpose_a = true;
718    bool transpose_b = false;
719    bool transpose_c = true;
720    gemmlowp::eight_bit_int_gemm::EightBitIntGemm(transpose_a, transpose_b, transpose_c,
721                                                  m, n, k, a, -a_offset, lda,
722                                                  b, -b_offset, ldb, c, c_offset,
723                                                  c_mult_int, c_shift, ldc,
724                                                  gemmlowp::eight_bit_int_gemm::BitDepthSetting::A8B8);
725
726}
727
728
729
730
731
732RsdCpuScriptIntrinsicBLAS::RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx,
733                                                   const Script *s)
734            : RsdCpuScriptIntrinsic(ctx, s, nullptr, RS_SCRIPT_INTRINSIC_ID_BLAS) {
735
736
737}
738
739RsdCpuScriptIntrinsicBLAS::~RsdCpuScriptIntrinsicBLAS() {
740}
741
742
743
744
745
746RsdCpuScriptImpl * rsdIntrinsic_BLAS(RsdCpuReferenceImpl *ctx,
747                                    const Script *s, const Element *e) {
748
749    return new RsdCpuScriptIntrinsicBLAS(ctx, s);
750}
751