1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17
18#include "rsCpuIntrinsic.h"
19#include "rsCpuIntrinsicInlines.h"
20#include "rsCpuBLASDispatch.h"
21#include "eight_bit_int_gemm.h"
22
23namespace android {
24namespace renderscript {
25
26
27class RsdCpuScriptIntrinsicBLAS : public RsdCpuScriptIntrinsic {
28public:
29    void invokeForEach(uint32_t slot,
30                       const Allocation ** ain,
31                       uint32_t inLen,
32                       Allocation * aout,
33                       const void * usr,
34                       uint32_t usrLen,
35                       const RsScriptCall *sc) override;
36    void populateScript(Script *) override;
37    ~RsdCpuScriptIntrinsicBLAS() override;
38    RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, const Script *s);
39
40protected:
41
42    uint8_t a_offset = 0;
43    uint8_t b_offset = 0;
44    uint8_t c_offset = 0;
45
46#ifdef RS_COMPATIBILITY_LIB
47    bool isBlasLibInitialized = false;
48#endif
49    static void kernelBNNM(size_t m, size_t n, size_t k,
50                           const uint8_t* a, uint8_t a_offset, size_t lda,
51                           const uint8_t* b, uint8_t b_offset, size_t ldb,
52                           uint8_t* c, int32_t c_offset, size_t ldc,
53                           int32_t c_mult_int);
54
55
56
57};
58
59void RsdCpuScriptIntrinsicBLAS::populateScript(Script *s) {
60    s->mHal.info.exportedVariableCount = 0;
61}
62
63static void initABC(const Allocation ** ain,
64                    size_t size,
65                    void** A,
66                    void** B,
67                    void** C,
68                    int* lda,
69                    int* ldb,
70                    int* ldc)
71{
72    if (ain[0]) {
73        *A = ain[0]->mHal.drvState.lod[0].mallocPtr;
74        *lda = (int)(ain[0]->mHal.drvState.lod[0].stride/size);
75    }
76    if (ain[1]) {
77        *B = ain[1]->mHal.drvState.lod[0].mallocPtr;
78        *ldb = (int)(ain[1]->mHal.drvState.lod[0].stride/size);
79    }
80    if (ain[2]) {
81        *C = ain[2]->mHal.drvState.lod[0].mallocPtr;
82        *ldc = (int)(ain[2]->mHal.drvState.lod[0].stride/size);
83    }
84}
85
86// Routine to setup LaunchStruct for GEMM callback.
87static void setupGEMM(MTLaunchStructForEachBlas *mtls, const Allocation **ain, RsBlasCall* call,
88                      RsdCpuReferenceImpl *ctx) {
89    uint32_t mm, nn, kk;
90    mm = call->M;
91    nn = call->N;
92    kk = call->K;
93
94    memset(mtls, 0, sizeof(MTLaunchStructForEachBlas));
95    mtls->rs        = ctx;
96    mtls->sc        = call;
97    mtls->dimPtr    = &mtls->fep.dim;
98    mtls->fep.dim.x = nn;
99    mtls->fep.dim.y = mm;
100    mtls->fep.dim.z = kk;
101    if (ain) {
102        memcpy(mtls->ains, ain, 3 * sizeof(ain[0]));
103    }
104    uint32_t elementBytes = 4;
105    if (ain[0]) {
106        elementBytes = ain[0]->getType()->getElement()->getSizeBytes();
107    }
108    const uint32_t MIN_SIZE_TO_TILE = 64 * 1024 / elementBytes;
109    const uint32_t MAX_WORK_PER_THREAD = 512 / elementBytes;
110    const uint32_t THREAD_COUNT = ctx->getThreadCount();
111    uint32_t tileSizeN = 0;
112    uint32_t tileSizeM = 0;
113
114    // Do not tile the matrix if:
115    // 1. It is too small comparing to the other matrix.
116    // 2. It is too small comparing to MIN_SIZE_TO_TILE .
117    if (nn * kk > MIN_SIZE_TO_TILE && nn * THREAD_COUNT > mm) {
118        tileSizeN = rsMin(nn / THREAD_COUNT, MAX_WORK_PER_THREAD);
119    }
120    if (mm * kk > MIN_SIZE_TO_TILE && mm * THREAD_COUNT > nn) {
121        tileSizeM = rsMin(mm / THREAD_COUNT, MAX_WORK_PER_THREAD);
122    }
123    mtls->numTileM = 1;
124    mtls->numTileN = 1;
125    mtls->tileSizeM = mm;
126    mtls->tileSizeN = nn;
127
128    // If tiling is needed, compute the number of slices for A & B.
129    mtls->isThreadable = (tileSizeM > 0 || tileSizeN > 0);
130    if (tileSizeM) {
131        mtls->numTileM += (mm - 1) / tileSizeM;
132        mtls->tileSizeM = tileSizeM;
133    }
134    if (tileSizeN) {
135        mtls->numTileN += (nn - 1) / tileSizeN;
136        mtls->tileSizeN = tileSizeN;
137    }
138
139    mtls->mSliceNum  = 0;
140}
141
142// Generic GEMM callback routine.
143template <typename T_data, typename T_param, typename Func>
144static void walk_tiled_gemm(Func blasFunc, T_param alpha, T_param beta, int vecSize,
145                            RsBlasCall* call, const MTLaunchStructForEachBlas *mtls) {
146    // setup BLAS enum args
147    enum CBLAS_TRANSPOSE TransA = (enum CBLAS_TRANSPOSE)call->transA;
148    enum CBLAS_TRANSPOSE TransB = (enum CBLAS_TRANSPOSE)call->transB;
149
150    void *A = nullptr;
151    void *B = nullptr;
152    void *C = nullptr;
153
154    int lda = 0, ldb = 0, ldc = 0;
155
156    const Allocation *ain[RS_KERNEL_INPUT_LIMIT];
157    ain[0] = mtls->ains[0];
158    ain[1] = mtls->ains[1];
159    ain[2] = mtls->ains[2];
160
161    initABC(ain, sizeof(T_data) * vecSize, &A, &B, &C, &lda, &ldb, &ldc);
162
163    // Determin the stride of the tiled matrices.
164    int mStride = (TransA == CblasNoTrans) ? lda : 1;
165    int nStride = (TransB == CblasNoTrans) ? 1 : ldb;
166    while (1) {
167        uint32_t slice  = (uint32_t)__sync_fetch_and_add(&mtls->mSliceNum, 1);
168
169        uint32_t mStart = (slice % mtls->numTileM) * mtls->tileSizeM;
170        uint32_t mEnd   = mStart + mtls->tileSizeM;
171        mEnd = rsMin(mEnd, (uint32_t)call->M);
172        if (mEnd <= mStart) {
173            return;
174        }
175
176        uint32_t nStart = (slice / mtls->numTileM) * mtls->tileSizeN;
177        uint32_t nEnd   = nStart + mtls->tileSizeN;
178        nEnd = rsMin(nEnd, (uint32_t)call->N);
179        if (nEnd <= nStart) {
180            return;
181        }
182
183        blasFunc(CblasRowMajor, TransA, TransB,
184                 mEnd - mStart, nEnd - nStart, call->K, alpha,
185                 (T_data *)A + mStart * mStride * vecSize, lda,
186                 (T_data *)B + nStart * nStride * vecSize, ldb, beta,
187                 (T_data *)C + (mStart * ldc + nStart) * vecSize, ldc);
188    }
189}
190
191// SGEMM callback
192static void walk_2d_sgemm(void *usr, uint32_t idx) {
193    const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr;
194    RsBlasCall* call = (RsBlasCall*) mtls->sc;
195
196    float alpha = call->alpha.f;
197    float beta = call->beta.f;
198
199    walk_tiled_gemm<float, float, FnPtr_cblas_sgemm>(cblas_sgemm, alpha, beta, 1, call, mtls);
200}
201
202// DGEMM callback
203static void walk_2d_dgemm(void *usr, uint32_t idx) {
204    const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr;
205    RsBlasCall* call = (RsBlasCall*) mtls->sc;
206
207    double alpha = call->alpha.d;
208    double beta = call->beta.d;
209
210    walk_tiled_gemm<double, double, FnPtr_cblas_dgemm>(cblas_dgemm, alpha, beta, 1, call, mtls);
211}
212
213// CGEMM callback
214static void walk_2d_cgemm(void *usr, uint32_t idx) {
215    const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr;
216    RsBlasCall* call = (RsBlasCall*) mtls->sc;
217
218    void * alpha = (void *)&call->alpha.c;
219    void * beta = (void *)&call->beta.c;
220
221    walk_tiled_gemm<float, void *, FnPtr_cblas_cgemm>(cblas_cgemm, alpha, beta, 2, call, mtls);
222}
223
224// ZGEMM callback
225static void walk_2d_zgemm(void *usr, uint32_t idx) {
226    const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr;
227    RsBlasCall* call = (RsBlasCall*) mtls->sc;
228
229    void * alpha = (void *)&call->alpha.z;
230    void * beta = (void *)&call->beta.z;
231
232    walk_tiled_gemm<double, void *, FnPtr_cblas_zgemm>(cblas_zgemm, alpha, beta, 2, call, mtls);
233}
234
235
236void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot,
237                                              const Allocation ** ain,
238                                              uint32_t inLen,
239                                              Allocation * aout,
240                                              const void * usr,
241                                              uint32_t usrLen,
242                                              const RsScriptCall *sc) {
243    RsBlasCall* call = (RsBlasCall*) usr;
244    // setup BLAS enum args
245    enum CBLAS_TRANSPOSE TransA = (enum CBLAS_TRANSPOSE)call->transA;
246    enum CBLAS_TRANSPOSE TransB = (enum CBLAS_TRANSPOSE)call->transB;
247    enum CBLAS_UPLO Uplo = (enum CBLAS_UPLO)call->uplo;
248    enum CBLAS_DIAG Diag = (enum CBLAS_DIAG)call->diag;
249    enum CBLAS_SIDE Side = (enum CBLAS_SIDE)call->side;
250
251    void *A = nullptr;
252    void *B = nullptr;
253    void *C = nullptr;
254    void *X = nullptr;
255    void *Y = nullptr;
256
257    int lda = 0, ldb = 0, ldc = 0;
258
259    MTLaunchStructForEachBlas mtls;
260
261#ifdef RS_COMPATIBILITY_LIB
262    // Allow BNNM even without libblas
263    if (call->func != RsBlas_bnnm && !isBlasLibInitialized) {
264        if (!loadBLASLib()) {
265            ALOGE("Failed to load the BLAS lib, IntrinsicBLAS NOT supported!\n");
266            return;
267        }
268        isBlasLibInitialized = true;
269    }
270#endif
271
272    switch (call->func) {
273
274    // Level 1 BLAS: returns into a 1D Allocation
275
276
277    // Level 2 BLAS
278    case (RsBlas_sgemv):
279        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
280        cblas_sgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.f, (float*)A,
281                    lda, (float*)X, call->incX, call->beta.f, (float*)Y, call->incY);
282        break;
283    case (RsBlas_sgbmv):
284        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
285        cblas_sgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
286                    call->alpha.f, (float*)A, lda, (float*)X, call->incX,
287                    call->beta.f, (float*)Y, call->incY);
288        break;
289    case (RsBlas_strmv):
290        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
291        cblas_strmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A,
292                    lda, (float*)X, call->incX);
293        break;
294    case (RsBlas_stbmv):
295        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
296        cblas_stbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A,
297                    lda, (float*)X, call->incX);
298        break;
299    // stpmv takes a packed 1D Allocation only
300    case (RsBlas_stpmv):
301        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
302        cblas_stpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A,
303                    (float*)X, call->incX);
304        break;
305    case (RsBlas_strsv):
306        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
307        cblas_strsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, lda,
308                    (float*)X, call->incX);
309        break;
310    case (RsBlas_stbsv):
311        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
312        cblas_stbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A,
313                    lda, (float*)X, call->incX);
314        break;
315    case (RsBlas_stpsv):
316        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
317        cblas_stpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A,
318                    (float*)X, call->incX);
319        break;
320    case (RsBlas_dgemv):
321        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
322        cblas_dgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.d, (double*)A,
323                    lda, (double*)X, call->incX, call->beta.d, (double*)Y, call->incY);
324        break;
325    case (RsBlas_dgbmv):
326        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
327        cblas_dgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
328                    call->alpha.d, (double*)A, lda, (double*)X, call->incX,
329                    call->beta.d, (double*)Y, call->incY);
330        break;
331    case (RsBlas_dtrmv):
332        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
333        cblas_dtrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A,
334                    lda, (double*)X, call->incX);
335        break;
336    case (RsBlas_dtbmv):
337        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
338        cblas_dtbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A,
339                    lda, (double*)X, call->incX);
340        break;
    // dtpmv takes a packed 1D Allocation only
342    case (RsBlas_dtpmv):
343        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
344        cblas_dtpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A,
345                    (double*)X, call->incX);
346        break;
347    case (RsBlas_dtrsv):
348        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
349        cblas_dtrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, lda,
350                    (double*)X, call->incX);
351        break;
352    case (RsBlas_dtbsv):
353        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
354        cblas_dtbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A,
355                    lda, (double*)X, call->incX);
356        break;
357    case (RsBlas_dtpsv):
358        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
359        cblas_dtpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A,
360                    (double*)X, call->incX);
361        break;
362    case (RsBlas_cgemv):
363        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
364        cblas_cgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.c, (void*)A,
365                    lda, (void*)X, call->incX, (void*)&call->beta.c, (void*)Y, call->incY);
366        break;
367    case (RsBlas_cgbmv):
368        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
369        cblas_cgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
370                    (void*)&call->alpha.c, (void*)A, lda, (void*)X, call->incX,
371                    (void*)&call->beta.c, (void*)Y, call->incY);
372        break;
373    case (RsBlas_ctrmv):
374        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
375        cblas_ctrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
376                    lda, (void*)X, call->incX);
377        break;
378    case (RsBlas_ctbmv):
379        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
380        cblas_ctbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
381                    lda, (void*)X, call->incX);
382        break;
    // ctpmv takes a packed 1D Allocation only
384    case (RsBlas_ctpmv):
385        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
386        cblas_ctpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
387                    (void*)X, call->incX);
388        break;
389    case (RsBlas_ctrsv):
390        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
391        cblas_ctrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda,
392                    (void*)X, call->incX);
393        break;
394    case (RsBlas_ctbsv):
395        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
396        cblas_ctbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
397                    lda, (void*)X, call->incX);
398        break;
399    case (RsBlas_ctpsv):
400        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
401        cblas_ctpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
402                    (void*)X, call->incX);
403        break;
404    case (RsBlas_zgemv):
405        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
406        cblas_zgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.z, (void*)A,
407                    lda, (void*)X, call->incX, (void*)&call->beta.z, (void*)Y, call->incY);
408        break;
409    case (RsBlas_zgbmv):
410        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
411        cblas_zgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
412                    (void*)&call->alpha.z, (void*)A, lda, (void*)X, call->incX,
413                    (void*)&call->beta.z, (void*)Y, call->incY);
414        break;
415    case (RsBlas_ztrmv):
416        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
417        cblas_ztrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
418                    lda, (void*)X, call->incX);
419        break;
420    case (RsBlas_ztbmv):
421        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
422        cblas_ztbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
423                    lda, (void*)X, call->incX);
424        break;
    // ztpmv takes a packed 1D Allocation only
426    case (RsBlas_ztpmv):
427        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
428        cblas_ztpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
429                    (void*)X, call->incX);
430        break;
431    case (RsBlas_ztrsv):
432        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
433        cblas_ztrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda,
434                    (void*)X, call->incX);
435        break;
436    case (RsBlas_ztbsv):
437        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
438        cblas_ztbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
439                    lda, (void*)X, call->incX);
440        break;
441    case (RsBlas_ztpsv):
442        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
443        cblas_ztpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
444                    (void*)X, call->incX);
445        break;
446
447
448    // S and D only
449    case (RsBlas_ssymv):
450        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
451        cblas_ssymv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A, lda,
452                    (float*)X, call->incX, call->beta.f, (float*)Y, call->incY);
453        break;
454    case (RsBlas_ssbmv):
455        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
456        cblas_ssbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.f,
457                    (float*)A, lda, (float*)X, call->incX, call->beta.f,
458                    (float*)Y, call->incY);
459        break;
460    //sspmv requires a packed 1D Allocation
461    case (RsBlas_sspmv):
462        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
463        cblas_sspmv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A,
464                    (float*)X, call->incX, call->beta.f, (float*)Y, call->incY);
465        break;
466    // following calls have init reordered because A is output matrix
467    case (RsBlas_sger):
468        initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda);
469        cblas_sger(CblasRowMajor, call->M, call->N, call->alpha.f, (float*)X,
470                   call->incX, (float*)Y, call->incY, (float*)A, lda);
471        break;
472    case (RsBlas_ssyr):
473        initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr);
474        cblas_ssyr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
475                   (float*)A, lda);
476        break;
477    // sspr is packed 1D Allocation A only
478    case (RsBlas_sspr):
479        initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr);
480        cblas_sspr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
481                   (float*)A);
482        break;
483    case (RsBlas_ssyr2):
484        initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda);
485        cblas_ssyr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
486                    (float*)Y, call->incY, (float*)A, lda);
487        break;
488    // sspr2 is packed 1D Allocation A only
489    case (RsBlas_sspr2):
490        initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda);
491        cblas_sspr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
492                    (float*)Y, call->incY, (float*)A);
493        break;
494    case (RsBlas_dsymv):
495        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
496        cblas_dsymv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A, lda,
497                    (double*)X, call->incX, call->beta.d, (double*)Y, call->incY);
498        break;
499    case (RsBlas_dsbmv):
500        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
501        cblas_dsbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.d,
502                    (double*)A, lda, (double*)X, call->incX, call->beta.d,
503                    (double*)Y, call->incY);
504        break;
505    // dspmv requires a packed 1D Allocation
506    case (RsBlas_dspmv):
507        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
508        cblas_dspmv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A,
509                    (double*)X, call->incX, call->beta.d, (double*)Y, call->incY);
510        break;
511    // following calls have init reordered because A is output matrix
512    case (RsBlas_dger):
513        initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda);
514        cblas_dger(CblasRowMajor, call->M, call->N, call->alpha.d, (double*)X,
515                   call->incX, (double*)Y, call->incY, (double*)A, lda);
516        break;
517    case (RsBlas_dsyr):
518        initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr);
519        cblas_dsyr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
520                   (double*)A, lda);
521        break;
522    // dspr is packed 1D Allocation A only
523    case (RsBlas_dspr):
524        initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr);
525        cblas_dspr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
526                   (double*)A);
527        break;
528    case (RsBlas_dsyr2):
529        initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda);
530        cblas_dsyr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
531                    (double*)Y, call->incY, (double*)A, lda);
532        break;
533    // dspr2 is packed 1D Allocation A only
534    case (RsBlas_dspr2):
535        initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda);
536        cblas_dspr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
537                    (double*)Y, call->incY, (double*)A);
538        break;
539
540    // C and Z only
541    case (RsBlas_chemv):
542        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
543        cblas_chemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A, lda,
544                    X, call->incX, (void*)&call->beta.c, Y, call->incY);
545        break;
546    case (RsBlas_chbmv):
547        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
548        cblas_chbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.c,
549                    A, lda, X, call->incX, (void*)&call->beta.c, Y, call->incY);
550        break;
551    case (RsBlas_chpmv):
552        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
553        cblas_chpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A,
554                    X, call->incX, (void*)&call->beta.c, Y, call->incY);
555        break;
556    case (RsBlas_cgeru):
557        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
558        cblas_cgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c,
559                    X, call->incX, Y, call->incY, A, lda);
560        break;
561    case (RsBlas_cgerc):
562        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
563        cblas_cgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c,
564                    X, call->incX, Y, call->incY, A, lda);
565        break;
566    case (RsBlas_cher):
567        initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
568        cblas_cher(CblasRowMajor, Uplo, call->N, call->alpha.f,
569                   X, call->incX, A, lda);
570        break;
571    // packed 1D Allocations only
572    case (RsBlas_chpr):
573        initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
574        cblas_chpr(CblasRowMajor, Uplo, call->N, call->alpha.f, X,
575                   call->incX, A);
576        break;
577    case (RsBlas_cher2):
578        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
579        cblas_cher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c,
580                   X, call->incX, Y, call->incY, A, lda);
581        break;
582    // packed 1D Allocations only
583    case (RsBlas_chpr2):
584        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
585        cblas_chpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, X,
586                   call->incX, Y, call->incY, A);
587        break;
588    case (RsBlas_zhemv):
589        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
590        cblas_zhemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A, lda,
591                    X, call->incX, (void*)&call->beta.z, Y, call->incY);
592        break;
593    case (RsBlas_zhbmv):
594        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
595        cblas_zhbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.z,
596                    A, lda, X, call->incX, (void*)&call->beta.z, Y, call->incY);
597        break;
598    case (RsBlas_zhpmv):
599        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
600        cblas_zhpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A,
601                    X, call->incX, (void*)&call->beta.z, Y, call->incY);
602        break;
603    case (RsBlas_zgeru):
604        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
605        cblas_zgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z,
606                    X, call->incX, Y, call->incY, A, lda);
607        break;
608    case (RsBlas_zgerc):
609        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
610        cblas_zgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z,
611                    X, call->incX, Y, call->incY, A, lda);
612        break;
613    case (RsBlas_zher):
614        initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
615        cblas_zher(CblasRowMajor, Uplo, call->N, call->alpha.d,
616                   X, call->incX, A, lda);
617        break;
618    // packed 1D Allocations only
619    case (RsBlas_zhpr):
620        initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
621        cblas_zhpr(CblasRowMajor, Uplo, call->N, call->alpha.d, X,
622                   call->incX, A);
623        break;
624    case (RsBlas_zher2):
625        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
626        cblas_zher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z,
627                   X, call->incX, Y, call->incY, A, lda);
628        break;
629    // packed 1D Allocations only
630    case (RsBlas_zhpr2):
631        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
632        cblas_zhpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, X,
633                   call->incX, Y, call->incY, A);
634        break;
635
636    // Level 3 BLAS
637    case (RsBlas_sgemm):
638        setupGEMM(&mtls, ain, call, mCtx);
639        if (mtls.isThreadable) {
640            mCtx->launchThreads(walk_2d_sgemm, &mtls);
641        } else {
642            initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
643            cblas_sgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.f,
644                        (float*)A, lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
645        }
646        break;
647    case (RsBlas_ssymm):
648        initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
649        cblas_ssymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.f, (float*)A,
650                    lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
651        break;
652    case (RsBlas_ssyrk):
653        initABC(ain, sizeof(float), &A, nullptr, &C, &lda, nullptr, &ldc);
654        cblas_ssyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A,
655                    lda, call->beta.f, (float*)C, ldc);
656        break;
657    case (RsBlas_ssyr2k):
658        initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
659        cblas_ssyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A,
660                     lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
661        break;
662    case (RsBlas_strmm):
663        initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr);
664        cblas_strmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f,
665                    (float*)A, lda, (float*)B, ldb);
666        break;
667    case (RsBlas_strsm):
668        initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr);
669        cblas_strsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f,
670                    (float*)A, lda, (float*)B, ldb);
671        break;
672
673
674    case (RsBlas_dgemm):
675        setupGEMM(&mtls, ain, call, mCtx);
676        if (mtls.isThreadable) {
677            mCtx->launchThreads(walk_2d_dgemm, &mtls);
678        } else {
679            initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
680            cblas_dgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.d,
681                        (double*)A, lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
682        }
683        break;
684    case (RsBlas_dsymm):
685        initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
686        cblas_dsymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.d, (double*)A,
687                    lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
688        break;
689    case (RsBlas_dsyrk):
690        initABC(ain, sizeof(double), &A, nullptr, &C, &lda, nullptr, &ldc);
691        cblas_dsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A,
692                    lda, call->beta.d, (double*)C, ldc);
693        break;
694    case (RsBlas_dsyr2k):
695        initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
696        cblas_dsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A,
697                     lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
698        break;
699    case (RsBlas_dtrmm):
700        initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr);
701        cblas_dtrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d,
702                    (double*)A, lda, (double*)B, ldb);
703        break;
704    case (RsBlas_dtrsm):
705        initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr);
706        cblas_dtrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d,
707                    (double*)A, lda, (double*)B, ldb);
708        break;
709
710    case (RsBlas_cgemm):
711        setupGEMM(&mtls, ain, call, mCtx);
712        if (mtls.isThreadable) {
713            mCtx->launchThreads(walk_2d_cgemm, &mtls);
714        } else {
715            initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
716            cblas_cgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.c,
717                        A, lda, B, ldb, (void*)&call->beta.c, C, ldc);
718        }
719        break;
720    case (RsBlas_csymm):
721        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
722        cblas_csymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A,
723                    lda, B, ldb, (void*)&call->beta.c, C, ldc);
724        break;
725    case (RsBlas_csyrk):
726        initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
727        cblas_csyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A,
728                    lda, (void*)&call->beta.c, C, ldc);
729        break;
730    case (RsBlas_csyr2k):
731        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
732        cblas_csyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A,
733                     lda, B, ldb, (void*)&call->beta.c, C, ldc);
734        break;
735    case (RsBlas_ctrmm):
736        initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
737        cblas_ctrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c,
738                    A, lda, B, ldb);
739        break;
740    case (RsBlas_ctrsm):
741        initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
742        cblas_ctrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c,
743                    A, lda, B, ldb);
744        break;
745
746    case (RsBlas_zgemm):
747        setupGEMM(&mtls, ain, call, mCtx);
748        if (mtls.isThreadable) {
749            mCtx->launchThreads(walk_2d_zgemm, &mtls);
750        } else {
751            initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
752            cblas_zgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.z,
753                        A, lda, B, ldb, (void*)&call->beta.z, C, ldc);
754        }
755        break;
756    case (RsBlas_zsymm):
757        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
758        cblas_zsymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A,
759                    lda, B, ldb, (void*)&call->beta.z, C, ldc);
760        break;
761    case (RsBlas_zsyrk):
762        initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
763        cblas_zsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A,
764                    lda, (void*)&call->beta.z, C, ldc);
765        break;
766    case (RsBlas_zsyr2k):
767        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
768        cblas_zsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A,
769                     lda, B, ldb, (void*)&call->beta.z, C, ldc);
770        break;
771    case (RsBlas_ztrmm):
772        initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
773        cblas_ztrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z,
774                    A, lda, B, ldb);
775        break;
776    case (RsBlas_ztrsm):
777        initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
778        cblas_ztrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z,
779                    A, lda, B, ldb);
780        break;
781
782    // Level 3 C and Z only
783    case (RsBlas_chemm):
784        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
785        cblas_chemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A, lda,
786                    B, ldb, (void*)&call->beta.c, C, ldc);
787        break;
788    case (RsBlas_cherk):
789        initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
790        cblas_cherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, A, lda,
791                    call->beta.f, C, ldc);
792        break;
793    case (RsBlas_cher2k):
794        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
795        cblas_cher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, lda,
796                     B, ldb, call->beta.f, C, ldc);
797        break;
798
799    case (RsBlas_zhemm):
800        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
801        cblas_zhemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A, lda,
802                    B, ldb, (void*)&call->beta.z, C, ldc);
803        break;
804    case (RsBlas_zherk):
805        initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
806        cblas_zherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, A, lda,
807                    call->beta.d, C, ldc);
808        break;
809    case (RsBlas_zher2k):
810        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
811        cblas_zher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, lda,
812                     B, ldb, call->beta.d, C, ldc);
813        break;
814
815
816    case (RsBlas_bnnm):
817        initABC(ain, sizeof(uint8_t), &A, &B, &C, &lda, &ldb, &ldc);
818        kernelBNNM(call->M, call->N, call->K,
819                    (const uint8_t*)A, call->a_offset, lda,
820                    (const uint8_t*)B, call->b_offset, ldb,
821                    (uint8_t*)C, call->c_offset, ldc,
822                    call->c_mult_int);
823
824        break;
825
826    default:
827        ALOGE("unimplemented\n");
828    }
829
830
831}
832
833void RsdCpuScriptIntrinsicBLAS::kernelBNNM(size_t m, size_t n, size_t k,
834                                           const uint8_t* a, uint8_t a_offset, size_t lda,
835                                           const uint8_t* b, uint8_t b_offset, size_t ldb,
836                                           uint8_t* c, int32_t c_offset, size_t ldc,
837                                           int32_t c_mult_int) {
838    const int c_shift = 21;
839#if defined(ARCH_ARM_HAVE_VFP) || defined(ARCH_ARM_USE_INTRINSICS)
840    // Non-optimized path for ARMv7 devices without SIMD instructions.
841    if (!gArchUseSIMD) {
842        /*
843         * Calculations are done in 1.10.21 fixed-point format for the final output,
844         * just before there's a shift down to drop the fractional parts. The output
845         * values are gated to 0 to 255 to fit in a byte, but the 10-bit format
846         * gives some headroom to avoid wrapping around on small overflows.
847         */
848        size_t i = 0, j = 0, l = 0;
849        for (j = 0; j < n; j++) {
850            for (i = 0; i < m; i++) {
851                int32_t total = 0;
852                for (l = 0; l < k; l++) {
853                    const int a_index = ((i * lda) + l);
854                    const uint8_t a_as_byte = a[a_index];
855                    const int32_t a_as_int = (((int32_t)(a_as_byte)) - a_offset);
856                    const int b_index = ((j * ldb) + l);
857                    const uint8_t b_as_byte = b[b_index];
858                    const int32_t b_as_int = (((int32_t)(b_as_byte)) - b_offset);
859                    const int32_t mult_as_int = (a_as_int * b_as_int);
860                    total += mult_as_int;
861                }
862                const int c_index = ((ldc * i) + j);
863                int32_t output =
864                    ((((total + c_offset) * c_mult_int) + (1 << (c_shift - 1)))
865                     >> c_shift);
866                if (output > 255) {
867                    output = 255;
868                }
869                if (output < 0) {
870                    output = 0;
871                }
872                c[c_index] = (uint8_t)(output);
873            }
874        }
875        return;
876    }
877#endif
878
879    // Using gemmlowp to calculate the low precision 8 bit GEMM.
880    bool transpose_a = true;
881    bool transpose_b = false;
882    bool transpose_c = true;
883    gemmlowp::eight_bit_int_gemm::EightBitIntGemm(transpose_a, transpose_b, transpose_c,
884                                                  m, n, k, a, -a_offset, lda,
885                                                  b, -b_offset, ldb, c, c_offset,
886                                                  c_mult_int, c_shift, ldc,
887                                                  gemmlowp::eight_bit_int_gemm::BitDepthSetting::A8B8);
888
889}
890
891
892
893
894
895RsdCpuScriptIntrinsicBLAS::RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx,
896                                                   const Script *s)
897            : RsdCpuScriptIntrinsic(ctx, s, nullptr, RS_SCRIPT_INTRINSIC_ID_BLAS) {
898
899
900}
901
902RsdCpuScriptIntrinsicBLAS::~RsdCpuScriptIntrinsicBLAS() {
903}
904
905RsdCpuScriptImpl * rsdIntrinsic_BLAS(RsdCpuReferenceImpl *ctx,
906                                    const Script *s, const Element *e) {
907
908    return new RsdCpuScriptIntrinsicBLAS(ctx, s);
909}
910
911} // namespace renderscript
912} // namespace android
913