rsCpuIntrinsicBLAS.cpp revision 22cb808b0dfc9bd514d2e19b302a97f8455b5731
1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17
18#include "rsCpuIntrinsic.h"
19#include "rsCpuIntrinsicInlines.h"
20#include "cblas.h"
21
22using namespace android;
23using namespace android::renderscript;
24
25namespace android {
26namespace renderscript {
27
28
29class RsdCpuScriptIntrinsicBLAS : public RsdCpuScriptIntrinsic {
30public:
31    void invokeForEach(uint32_t slot,
32                       const Allocation ** ain,
33                       uint32_t inLen,
34                       Allocation * aout,
35                       const void * usr,
36                       uint32_t usrLen,
37                       const RsScriptCall *sc) override;
38
39    void populateScript(Script *) override;
40    ~RsdCpuScriptIntrinsicBLAS() override;
41    RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, const Script *s);
42
43protected:
44
45    uint8_t a_offset = 0;
46    uint8_t b_offset = 0;
47    uint8_t c_offset = 0;
48
49    static void kernelBNNM(size_t m, size_t n, size_t k,
50                           const uint8_t* a, uint8_t a_offset, size_t lda,
51                           const uint8_t* b, uint8_t b_offset, size_t ldb,
52                           uint8_t* c, int32_t c_offset, size_t ldc,
53                           int32_t c_mult_int);
54
55
56
57};
58
59}
60}
61
62void RsdCpuScriptIntrinsicBLAS::populateScript(Script *s) {
63    s->mHal.info.exportedVariableCount = 0;
64}
65
66static void initABC(const Allocation ** ain,
67                    size_t size,
68                    void** A,
69                    void** B,
70                    void** C,
71                    int* lda,
72                    int* ldb,
73                    int* ldc)
74{
75    if (ain[0]) {
76        *A = ain[0]->mHal.drvState.lod[0].mallocPtr;
77        *lda = (int)(ain[0]->mHal.drvState.lod[0].stride/size);
78    }
79    if (ain[1]) {
80        *B = ain[1]->mHal.drvState.lod[0].mallocPtr;
81        *ldb = (int)(ain[1]->mHal.drvState.lod[0].stride/size);
82    }
83    if (ain[2]) {
84        *C = ain[2]->mHal.drvState.lod[0].mallocPtr;
85        *ldc = (int)(ain[2]->mHal.drvState.lod[0].stride/size);
86    }
87
88
89}
90
91void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot,
92                                              const Allocation ** ain,
93                                              uint32_t inLen,
94                                              Allocation * aout,
95                                              const void * usr,
96                                              uint32_t usrLen,
97                                              const RsScriptCall *sc) {
98    RsBlasCall* call = (RsBlasCall*) usr;
99    // setup BLAS enum args
100    enum CBLAS_TRANSPOSE TransA = (enum CBLAS_TRANSPOSE)call->transA;
101    enum CBLAS_TRANSPOSE TransB = (enum CBLAS_TRANSPOSE)call->transB;
102    enum CBLAS_UPLO Uplo = (enum CBLAS_UPLO)call->uplo;
103    enum CBLAS_DIAG Diag = (enum CBLAS_DIAG)call->diag;
104    enum CBLAS_SIDE Side = (enum CBLAS_SIDE)call->side;
105
106    void *A = nullptr;
107    void *B = nullptr;
108    void *C = nullptr;
109    void *X = nullptr;
110    void *Y = nullptr;
111
112    int lda = 0, ldb = 0, ldc = 0;
113
114    switch (call->func) {
115
116    // Level 1 BLAS: returns into a 1D Allocation
117
118
119    // Level 2 BLAS
120    case (RsBlas_sgemv):
121        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
122        cblas_sgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.f, (float*)A,
123                    lda, (float*)X, call->incX, call->beta.f, (float*)Y, call->incY);
124        break;
125    case (RsBlas_sgbmv):
126        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
127        cblas_sgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
128                    call->alpha.f, (float*)A, lda, (float*)X, call->incX,
129                    call->beta.f, (float*)Y, call->incY);
130        break;
131    case (RsBlas_strmv):
132        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
133        cblas_strmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A,
134                    lda, (float*)X, call->incX);
135        break;
136    case (RsBlas_stbmv):
137        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
138        cblas_stbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A,
139                    lda, (float*)X, call->incX);
140        break;
141    // stpmv takes a packed 1D Allocation only
142    case (RsBlas_stpmv):
143        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
144        cblas_stpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A,
145                    (float*)X, call->incX);
146        break;
147    case (RsBlas_strsv):
148        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
149        cblas_strsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, lda,
150                    (float*)X, call->incX);
151        break;
152    case (RsBlas_stbsv):
153        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
154        cblas_stbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A,
155                    lda, (float*)X, call->incX);
156        break;
157    case (RsBlas_stpsv):
158        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
159        cblas_stpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A,
160                    (float*)X, call->incX);
161        break;
162    case (RsBlas_dgemv):
163        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
164        cblas_dgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.d, (double*)A,
165                    lda, (double*)X, call->incX, call->beta.d, (double*)Y, call->incY);
166        break;
167    case (RsBlas_dgbmv):
168        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
169        cblas_dgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
170                    call->alpha.d, (double*)A, lda, (double*)X, call->incX,
171                    call->beta.d, (double*)Y, call->incY);
172        break;
173    case (RsBlas_dtrmv):
174        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
175        cblas_dtrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A,
176                    lda, (double*)X, call->incX);
177        break;
178    case (RsBlas_dtbmv):
179        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
180        cblas_dtbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A,
181                    lda, (double*)X, call->incX);
182        break;
183    // stpmv takes a packed 1D Allocation only
184    case (RsBlas_dtpmv):
185        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
186        cblas_dtpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A,
187                    (double*)X, call->incX);
188        break;
189    case (RsBlas_dtrsv):
190        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
191        cblas_dtrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, lda,
192                    (double*)X, call->incX);
193        break;
194    case (RsBlas_dtbsv):
195        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
196        cblas_dtbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A,
197                    lda, (double*)X, call->incX);
198        break;
199    case (RsBlas_dtpsv):
200        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
201        cblas_dtpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A,
202                    (double*)X, call->incX);
203        break;
204    case (RsBlas_cgemv):
205        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
206        cblas_cgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.c, (void*)A,
207                    lda, (void*)X, call->incX, (void*)&call->beta.c, (void*)Y, call->incY);
208        break;
209    case (RsBlas_cgbmv):
210        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
211        cblas_cgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
212                    (void*)&call->alpha.c, (void*)A, lda, (void*)X, call->incX,
213                    (void*)&call->beta.c, (void*)Y, call->incY);
214        break;
215    case (RsBlas_ctrmv):
216        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
217        cblas_ctrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
218                    lda, (void*)X, call->incX);
219        break;
220    case (RsBlas_ctbmv):
221        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
222        cblas_ctbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
223                    lda, (void*)X, call->incX);
224        break;
225    // stpmv takes a packed 1D Allocation only
226    case (RsBlas_ctpmv):
227        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
228        cblas_ctpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
229                    (void*)X, call->incX);
230        break;
231    case (RsBlas_ctrsv):
232        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
233        cblas_ctrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda,
234                    (void*)X, call->incX);
235        break;
236    case (RsBlas_ctbsv):
237        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
238        cblas_ctbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
239                    lda, (void*)X, call->incX);
240        break;
241    case (RsBlas_ctpsv):
242        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
243        cblas_ctpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
244                    (void*)X, call->incX);
245        break;
246    case (RsBlas_zgemv):
247        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
248        cblas_zgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.z, (void*)A,
249                    lda, (void*)X, call->incX, (void*)&call->beta.z, (void*)Y, call->incY);
250        break;
251    case (RsBlas_zgbmv):
252        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
253        cblas_zgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
254                    (void*)&call->alpha.z, (void*)A, lda, (void*)X, call->incX,
255                    (void*)&call->beta.z, (void*)Y, call->incY);
256        break;
257    case (RsBlas_ztrmv):
258        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
259        cblas_ztrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
260                    lda, (void*)X, call->incX);
261        break;
262    case (RsBlas_ztbmv):
263        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
264        cblas_ztbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
265                    lda, (void*)X, call->incX);
266        break;
267    // stpmv takes a packed 1D Allocation only
268    case (RsBlas_ztpmv):
269        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
270        cblas_ztpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
271                    (void*)X, call->incX);
272        break;
273    case (RsBlas_ztrsv):
274        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
275        cblas_ztrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda,
276                    (void*)X, call->incX);
277        break;
278    case (RsBlas_ztbsv):
279        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
280        cblas_ztbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
281                    lda, (void*)X, call->incX);
282        break;
283    case (RsBlas_ztpsv):
284        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
285        cblas_ztpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
286                    (void*)X, call->incX);
287        break;
288
289
290    // S and D only
291    case (RsBlas_ssymv):
292        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
293        cblas_ssymv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A, lda,
294                    (float*)X, call->incX, call->beta.f, (float*)Y, call->incY);
295        break;
296    case (RsBlas_ssbmv):
297        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
298        cblas_ssbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.f,
299                    (float*)A, lda, (float*)X, call->incX, call->beta.f,
300                    (float*)Y, call->incY);
301        break;
302    //sspmv requires a packed 1D Allocation
303    case (RsBlas_sspmv):
304        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
305        cblas_sspmv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A,
306                    (float*)X, call->incX, call->beta.f, (float*)Y, call->incY);
307        break;
308    // following calls have init reordered because A is output matrix
309    case (RsBlas_sger):
310        initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda);
311        cblas_sger(CblasRowMajor, call->M, call->N, call->alpha.f, (float*)X,
312                   call->incX, (float*)Y, call->incY, (float*)A, lda);
313        break;
314    case (RsBlas_ssyr):
315        initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr);
316        cblas_ssyr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
317                   (float*)A, lda);
318        break;
319    // sspr is packed 1D Allocation A only
320    case (RsBlas_sspr):
321        initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr);
322        cblas_sspr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
323                   (float*)A);
324        break;
325    case (RsBlas_ssyr2):
326        initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda);
327        cblas_ssyr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
328                    (float*)Y, call->incY, (float*)A, lda);
329        break;
330    // sspr2 is packed 1D Allocation A only
331    case (RsBlas_sspr2):
332        initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda);
333        cblas_sspr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
334                    (float*)Y, call->incY, (float*)A);
335        break;
336    case (RsBlas_dsymv):
337        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
338        cblas_dsymv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A, lda,
339                    (double*)X, call->incX, call->beta.d, (double*)Y, call->incY);
340        break;
341    case (RsBlas_dsbmv):
342        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
343        cblas_dsbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.d,
344                    (double*)A, lda, (double*)X, call->incX, call->beta.d,
345                    (double*)Y, call->incY);
346        break;
347    // dspmv requires a packed 1D Allocation
348    case (RsBlas_dspmv):
349        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
350        cblas_dspmv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A,
351                    (double*)X, call->incX, call->beta.d, (double*)Y, call->incY);
352        break;
353    // following calls have init reordered because A is output matrix
354    case (RsBlas_dger):
355        initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda);
356        cblas_dger(CblasRowMajor, call->M, call->N, call->alpha.d, (double*)X,
357                   call->incX, (double*)Y, call->incY, (double*)A, lda);
358        break;
359    case (RsBlas_dsyr):
360        initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr);
361        cblas_dsyr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
362                   (double*)A, lda);
363        break;
364    // dspr is packed 1D Allocation A only
365    case (RsBlas_dspr):
366        initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr);
367        cblas_dspr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
368                   (double*)A);
369        break;
370    case (RsBlas_dsyr2):
371        initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda);
372        cblas_dsyr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
373                    (double*)Y, call->incY, (double*)A, lda);
374        break;
375    // dspr2 is packed 1D Allocation A only
376    case (RsBlas_dspr2):
377        initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda);
378        cblas_dspr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
379                    (double*)Y, call->incY, (double*)A);
380        break;
381
382    // C and Z only
383    case (RsBlas_chemv):
384        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
385        cblas_chemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A, lda,
386                    X, call->incX, (void*)&call->beta.c, Y, call->incY);
387        break;
388    case (RsBlas_chbmv):
389        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
390        cblas_chbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.c,
391                    A, lda, X, call->incX, (void*)&call->beta.c, Y, call->incY);
392        break;
393    case (RsBlas_chpmv):
394        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
395        cblas_chpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A,
396                    X, call->incX, (void*)&call->beta.c, Y, call->incY);
397        break;
398    case (RsBlas_cgeru):
399        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
400        cblas_cgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c,
401                    X, call->incX, Y, call->incY, A, lda);
402        break;
403    case (RsBlas_cgerc):
404        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
405        cblas_cgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c,
406                    X, call->incX, Y, call->incY, A, lda);
407        break;
408    case (RsBlas_cher):
409        initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
410        cblas_cher(CblasRowMajor, Uplo, call->N, call->alpha.f,
411                   X, call->incX, A, lda);
412        break;
413    // packed 1D Allocations only
414    case (RsBlas_chpr):
415        initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
416        cblas_chpr(CblasRowMajor, Uplo, call->N, call->alpha.f, X,
417                   call->incX, A);
418        break;
419    case (RsBlas_cher2):
420        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
421        cblas_cher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c,
422                   X, call->incX, Y, call->incY, A, lda);
423        break;
424    // packed 1D Allocations only
425    case (RsBlas_chpr2):
426        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
427        cblas_chpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, X,
428                   call->incX, Y, call->incY, A);
429        break;
430    case (RsBlas_zhemv):
431        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
432        cblas_zhemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A, lda,
433                    X, call->incX, (void*)&call->beta.z, Y, call->incY);
434        break;
435    case (RsBlas_zhbmv):
436        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
437        cblas_zhbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.z,
438                    A, lda, X, call->incX, (void*)&call->beta.z, Y, call->incY);
439        break;
440    case (RsBlas_zhpmv):
441        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
442        cblas_zhpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A,
443                    X, call->incX, (void*)&call->beta.z, Y, call->incY);
444        break;
445    case (RsBlas_zgeru):
446        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
447        cblas_zgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z,
448                    X, call->incX, Y, call->incY, A, lda);
449        break;
450    case (RsBlas_zgerc):
451        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
452        cblas_zgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z,
453                    X, call->incX, Y, call->incY, A, lda);
454        break;
455    case (RsBlas_zher):
456        initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
457        cblas_zher(CblasRowMajor, Uplo, call->N, call->alpha.d,
458                   X, call->incX, A, lda);
459        break;
460    // packed 1D Allocations only
461    case (RsBlas_zhpr):
462        initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
463        cblas_zhpr(CblasRowMajor, Uplo, call->N, call->alpha.d, X,
464                   call->incX, A);
465        break;
466    case (RsBlas_zher2):
467        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
468        cblas_zher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z,
469                   X, call->incX, Y, call->incY, A, lda);
470        break;
471    // packed 1D Allocations only
472    case (RsBlas_zhpr2):
473        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
474        cblas_zhpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, X,
475                   call->incX, Y, call->incY, A);
476        break;
477
478    // Level 3 BLAS
479    case (RsBlas_sgemm):
480        initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
481        cblas_sgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.f,
482                    (float*)A, lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
483        break;
484    case (RsBlas_ssymm):
485        initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
486        cblas_ssymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.f, (float*)A,
487                    lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
488        break;
489    case (RsBlas_ssyrk):
490        initABC(ain, sizeof(float), &A, nullptr, &C, &lda, nullptr, &ldc);
491        cblas_ssyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A,
492                    lda, call->beta.f, (float*)C, ldc);
493        break;
494    case (RsBlas_ssyr2k):
495        initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
496        cblas_ssyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A,
497                     lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
498        break;
499    case (RsBlas_strmm):
500        initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr);
501        cblas_strmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f,
502                    (float*)A, lda, (float*)B, ldb);
503        break;
504    case (RsBlas_strsm):
505        initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr);
506        cblas_strsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f,
507                    (float*)A, lda, (float*)B, ldb);
508        break;
509
510
511    case (RsBlas_dgemm):
512        initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
513        cblas_dgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.d,
514                    (double*)A, lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
515        break;
516    case (RsBlas_dsymm):
517        initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
518        cblas_dsymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.d, (double*)A,
519                    lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
520        break;
521    case (RsBlas_dsyrk):
522        initABC(ain, sizeof(double), &A, nullptr, &C, &lda, nullptr, &ldc);
523        cblas_dsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A,
524                    lda, call->beta.d, (double*)C, ldc);
525        break;
526    case (RsBlas_dsyr2k):
527        initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
528        cblas_dsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A,
529                     lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
530        break;
531    case (RsBlas_dtrmm):
532        initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr);
533        cblas_dtrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d,
534                    (double*)A, lda, (double*)B, ldb);
535        break;
536    case (RsBlas_dtrsm):
537        initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr);
538        cblas_dtrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d,
539                    (double*)A, lda, (double*)B, ldb);
540        break;
541
542    case (RsBlas_cgemm):
543        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
544        cblas_cgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.c,
545                    A, lda, B, ldb, (void*)&call->beta.c, C, ldc);
546        break;
547    case (RsBlas_csymm):
548        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
549        cblas_csymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A,
550                    lda, B, ldb, (void*)&call->beta.c, C, ldc);
551        break;
552    case (RsBlas_csyrk):
553        initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
554        cblas_csyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A,
555                    lda, (void*)&call->beta.c, C, ldc);
556        break;
557    case (RsBlas_csyr2k):
558        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
559        cblas_csyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A,
560                     lda, B, ldb, (void*)&call->beta.c, C, ldc);
561        break;
562    case (RsBlas_ctrmm):
563        initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
564        cblas_ctrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c,
565                    A, lda, B, ldb);
566        break;
567    case (RsBlas_ctrsm):
568        initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
569        cblas_ctrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c,
570                    A, lda, B, ldb);
571        break;
572
573    case (RsBlas_zgemm):
574        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
575        cblas_zgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.z,
576                    A, lda, B, ldb, (void*)&call->beta.z, C, ldc);
577        break;
578    case (RsBlas_zsymm):
579        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
580        cblas_zsymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A,
581                    lda, B, ldb, (void*)&call->beta.z, C, ldc);
582        break;
583    case (RsBlas_zsyrk):
584        initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
585        cblas_zsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A,
586                    lda, (void*)&call->beta.z, C, ldc);
587        break;
588    case (RsBlas_zsyr2k):
589        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
590        cblas_zsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A,
591                     lda, B, ldb, (void*)&call->beta.z, C, ldc);
592        break;
593    case (RsBlas_ztrmm):
594        initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
595        cblas_ztrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z,
596                    A, lda, B, ldb);
597        break;
598    case (RsBlas_ztrsm):
599        initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
600        cblas_ztrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z,
601                    A, lda, B, ldb);
602        break;
603
604    // Level 3 C and Z only
605    case (RsBlas_chemm):
606        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
607        cblas_chemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A, lda,
608                    B, ldb, (void*)&call->beta.c, C, ldc);
609        break;
610    case (RsBlas_cherk):
611        initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
612        cblas_cherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, A, lda,
613                    call->beta.f, C, ldc);
614        break;
615    case (RsBlas_cher2k):
616        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
617        cblas_cher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, lda,
618                     B, ldb, call->beta.f, C, ldc);
619        break;
620
621    case (RsBlas_zhemm):
622        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
623        cblas_zhemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A, lda,
624                    B, ldb, (void*)&call->beta.z, C, ldc);
625        break;
626    case (RsBlas_zherk):
627        initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
628        cblas_zherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, A, lda,
629                    call->beta.d, C, ldc);
630        break;
631    case (RsBlas_zher2k):
632        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
633        cblas_zher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, lda,
634                     B, ldb, call->beta.d, C, ldc);
635        break;
636
637
638    case (RsBlas_bnnm):
639        initABC(ain, sizeof(uint8_t), &A, &B, &C, &lda, &ldb, &ldc);
640        kernelBNNM(call->M, call->N, call->K,
641                    (const uint8_t*)A, call->a_offset, lda,
642                    (const uint8_t*)B, call->b_offset, ldb,
643                    (uint8_t*)C, call->c_offset, ldc,
644                    call->c_mult_int);
645
646        break;
647
648    default:
649        ALOGE("unimplemented\n");
650    }
651
652
653}
654
655void RsdCpuScriptIntrinsicBLAS::kernelBNNM(size_t m, size_t n, size_t k,
656                                           const uint8_t* a, uint8_t a_offset, size_t lda,
657                                           const uint8_t* b, uint8_t b_offset, size_t ldb,
658                                           uint8_t* c, int32_t c_offset, size_t ldc,
659                                           int32_t c_mult_int) {
660    // Calculations are done in 1.10.21 fixed-point format for the final output,
661    // just before there's a shift down to drop the fractional parts. The output
662    // values are gated to 0 to 255 to fit in a byte, but the 10-bit format
663    // gives some headroom to avoid wrapping around on small overflows.
664    const int c_shift = 21;
665    size_t i = 0, j = 0, l = 0;
666    for (j = 0; j < n; j++) {
667        for (i = 0; i < m; i++) {
668            int32_t total = 0;
669            for (l = 0; l < k; l++) {
670                const int a_index = ((i * lda) + l);
671                const uint8_t a_as_byte = a[a_index];
672                const int32_t a_as_int = (((int32_t)(a_as_byte)) - a_offset);
673                const int b_index = ((j * ldb) + l);
674                const uint8_t b_as_byte = b[b_index];
675                const int32_t b_as_int = (((int32_t)(b_as_byte)) - b_offset);
676                const int32_t mult_as_int = (a_as_int * b_as_int);
677                total += mult_as_int;
678            }
679            const int c_index = ((ldc * i) + j);
680            int32_t output =
681                ((((total + c_offset) * c_mult_int) + (1 << (c_shift - 1)))
682                 >> c_shift);
683            if (output > 255) {
684                output = 255;
685            }
686            if (output < 0) {
687                output = 0;
688            }
689            c[c_index] = (uint8_t)(output);
690        }
691    }
692}
693
694
695
696
697
698RsdCpuScriptIntrinsicBLAS::RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx,
699                                                   const Script *s)
700            : RsdCpuScriptIntrinsic(ctx, s, nullptr, RS_SCRIPT_INTRINSIC_ID_BLAS) {
701
702
703}
704
705RsdCpuScriptIntrinsicBLAS::~RsdCpuScriptIntrinsicBLAS() {
706}
707
708
709
710
711
712RsdCpuScriptImpl * rsdIntrinsic_BLAS(RsdCpuReferenceImpl *ctx,
713                                    const Script *s, const Element *e) {
714
715    return new RsdCpuScriptIntrinsicBLAS(ctx, s);
716}
717