rsCpuIntrinsicBLAS.cpp revision 99d0e8130f5b4bb83d1a68d96496fa558e35193a
1/* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 18#include "rsCpuIntrinsic.h" 19#include "rsCpuIntrinsicInlines.h" 20#include "cblas.h" 21#include "eight_bit_int_gemm.h" 22 23using namespace android; 24using namespace android::renderscript; 25 26namespace android { 27namespace renderscript { 28 29 30class RsdCpuScriptIntrinsicBLAS : public RsdCpuScriptIntrinsic { 31public: 32 void invokeForEach(uint32_t slot, 33 const Allocation ** ain, 34 uint32_t inLen, 35 Allocation * aout, 36 const void * usr, 37 uint32_t usrLen, 38 const RsScriptCall *sc) override; 39 40 void populateScript(Script *) override; 41 ~RsdCpuScriptIntrinsicBLAS() override; 42 RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, const Script *s); 43 44protected: 45 46 uint8_t a_offset = 0; 47 uint8_t b_offset = 0; 48 uint8_t c_offset = 0; 49 50 static void kernelBNNM(size_t m, size_t n, size_t k, 51 const uint8_t* a, uint8_t a_offset, size_t lda, 52 const uint8_t* b, uint8_t b_offset, size_t ldb, 53 uint8_t* c, int32_t c_offset, size_t ldc, 54 int32_t c_mult_int); 55 56 57 58}; 59 60} 61} 62 63void RsdCpuScriptIntrinsicBLAS::populateScript(Script *s) { 64 s->mHal.info.exportedVariableCount = 0; 65} 66 67static void initABC(const Allocation ** ain, 68 size_t size, 69 void** A, 70 void** B, 71 void** C, 72 int* lda, 73 int* ldb, 74 int* ldc) 75{ 76 if (ain[0]) { 77 *A = ain[0]->mHal.drvState.lod[0].mallocPtr; 78 *lda = (int)(ain[0]->mHal.drvState.lod[0].stride/size); 79 } 80 if (ain[1]) { 81 *B = ain[1]->mHal.drvState.lod[0].mallocPtr; 82 *ldb = (int)(ain[1]->mHal.drvState.lod[0].stride/size); 83 } 84 if (ain[2]) { 85 *C = ain[2]->mHal.drvState.lod[0].mallocPtr; 86 *ldc = (int)(ain[2]->mHal.drvState.lod[0].stride/size); 87 } 88 89 90} 91 92void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot, 93 const Allocation ** ain, 94 uint32_t inLen, 95 Allocation * aout, 96 const void * usr, 97 uint32_t usrLen, 98 const RsScriptCall *sc) { 99 RsBlasCall* call = (RsBlasCall*) usr; 100 // setup BLAS enum args 101 enum CBLAS_TRANSPOSE TransA = (enum CBLAS_TRANSPOSE)call->transA; 102 enum CBLAS_TRANSPOSE TransB = (enum CBLAS_TRANSPOSE)call->transB; 103 enum CBLAS_UPLO Uplo = (enum CBLAS_UPLO)call->uplo; 104 enum CBLAS_DIAG Diag = (enum CBLAS_DIAG)call->diag; 105 enum CBLAS_SIDE Side = (enum CBLAS_SIDE)call->side; 106 107 void *A = nullptr; 108 void *B = nullptr; 109 void *C = nullptr; 110 void *X = nullptr; 111 void *Y = nullptr; 112 113 int lda = 0, ldb = 0, ldc = 0; 114 115 switch (call->func) { 116 117 // Level 1 BLAS: returns into a 1D Allocation 118 119 120 // Level 2 BLAS 121 case (RsBlas_sgemv): 122 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 123 cblas_sgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.f, (float*)A, 124 lda, (float*)X, call->incX, call->beta.f, (float*)Y, call->incY); 125 break; 126 case (RsBlas_sgbmv): 127 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 128 cblas_sgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 129 call->alpha.f, (float*)A, lda, (float*)X, call->incX, 130 call->beta.f, (float*)Y, call->incY); 131 break; 132 case (RsBlas_strmv): 133 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 134 cblas_strmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, 135 lda, (float*)X, call->incX); 136 break; 137 case (RsBlas_stbmv): 138 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 139 cblas_stbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A, 140 lda, (float*)X, call->incX); 141 break; 142 // stpmv takes a packed 1D Allocation only 143 case (RsBlas_stpmv): 144 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 145 cblas_stpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, 146 (float*)X, call->incX); 147 break; 148 case (RsBlas_strsv): 149 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 150 cblas_strsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, lda, 151 (float*)X, call->incX); 152 break; 153 case (RsBlas_stbsv): 154 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 155 cblas_stbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A, 156 lda, (float*)X, call->incX); 157 break; 158 case (RsBlas_stpsv): 159 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 160 cblas_stpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, 161 (float*)X, call->incX); 162 break; 163 case (RsBlas_dgemv): 164 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 165 cblas_dgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.d, (double*)A, 166 lda, (double*)X, call->incX, call->beta.d, (double*)Y, call->incY); 167 break; 168 case (RsBlas_dgbmv): 169 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 170 cblas_dgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 171 call->alpha.d, (double*)A, lda, (double*)X, call->incX, 172 call->beta.d, (double*)Y, call->incY); 173 break; 174 case (RsBlas_dtrmv): 175 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 176 cblas_dtrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, 177 lda, (double*)X, call->incX); 178 break; 179 case (RsBlas_dtbmv): 180 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 181 cblas_dtbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A, 182 lda, (double*)X, call->incX); 183 break; 184 // stpmv takes a packed 1D Allocation only 185 case (RsBlas_dtpmv): 186 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 187 cblas_dtpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, 188 (double*)X, call->incX); 189 break; 190 case (RsBlas_dtrsv): 191 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 192 cblas_dtrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, lda, 193 (double*)X, call->incX); 194 break; 195 case (RsBlas_dtbsv): 196 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 197 cblas_dtbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A, 198 lda, (double*)X, call->incX); 199 break; 200 case (RsBlas_dtpsv): 201 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 202 cblas_dtpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, 203 (double*)X, call->incX); 204 break; 205 case (RsBlas_cgemv): 206 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 207 cblas_cgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.c, (void*)A, 208 lda, (void*)X, call->incX, (void*)&call->beta.c, (void*)Y, call->incY); 209 break; 210 case (RsBlas_cgbmv): 211 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 212 cblas_cgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 213 (void*)&call->alpha.c, (void*)A, lda, (void*)X, call->incX, 214 (void*)&call->beta.c, (void*)Y, call->incY); 215 break; 216 case (RsBlas_ctrmv): 217 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 218 cblas_ctrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 219 lda, (void*)X, call->incX); 220 break; 221 case (RsBlas_ctbmv): 222 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 223 cblas_ctbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 224 lda, (void*)X, call->incX); 225 break; 226 // stpmv takes a packed 1D Allocation only 227 case (RsBlas_ctpmv): 228 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 229 cblas_ctpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 230 (void*)X, call->incX); 231 break; 232 case (RsBlas_ctrsv): 233 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 234 cblas_ctrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda, 235 (void*)X, call->incX); 236 break; 237 case (RsBlas_ctbsv): 238 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 239 cblas_ctbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 240 lda, (void*)X, call->incX); 241 break; 242 case (RsBlas_ctpsv): 243 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 244 cblas_ctpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 245 (void*)X, call->incX); 246 break; 247 case (RsBlas_zgemv): 248 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 249 cblas_zgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.z, (void*)A, 250 lda, (void*)X, call->incX, (void*)&call->beta.z, (void*)Y, call->incY); 251 break; 252 case (RsBlas_zgbmv): 253 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 254 cblas_zgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 255 (void*)&call->alpha.z, (void*)A, lda, (void*)X, call->incX, 256 (void*)&call->beta.z, (void*)Y, call->incY); 257 break; 258 case (RsBlas_ztrmv): 259 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 260 cblas_ztrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 261 lda, (void*)X, call->incX); 262 break; 263 case (RsBlas_ztbmv): 264 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 265 cblas_ztbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 266 lda, (void*)X, call->incX); 267 break; 268 // stpmv takes a packed 1D Allocation only 269 case (RsBlas_ztpmv): 270 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 271 cblas_ztpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 272 (void*)X, call->incX); 273 break; 274 case (RsBlas_ztrsv): 275 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 276 cblas_ztrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda, 277 (void*)X, call->incX); 278 break; 279 case (RsBlas_ztbsv): 280 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 281 cblas_ztbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 282 lda, (void*)X, call->incX); 283 break; 284 case (RsBlas_ztpsv): 285 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 286 cblas_ztpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 287 (void*)X, call->incX); 288 break; 289 290 291 // S and D only 292 case (RsBlas_ssymv): 293 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 294 cblas_ssymv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A, lda, 295 (float*)X, call->incX, call->beta.f, (float*)Y, call->incY); 296 break; 297 case (RsBlas_ssbmv): 298 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 299 cblas_ssbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.f, 300 (float*)A, lda, (float*)X, call->incX, call->beta.f, 301 (float*)Y, call->incY); 302 break; 303 //sspmv requires a packed 1D Allocation 304 case (RsBlas_sspmv): 305 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 306 cblas_sspmv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A, 307 (float*)X, call->incX, call->beta.f, (float*)Y, call->incY); 308 break; 309 // following calls have init reordered because A is output matrix 310 case (RsBlas_sger): 311 initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda); 312 cblas_sger(CblasRowMajor, call->M, call->N, call->alpha.f, (float*)X, 313 call->incX, (float*)Y, call->incY, (float*)A, lda); 314 break; 315 case (RsBlas_ssyr): 316 initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr); 317 cblas_ssyr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 318 (float*)A, lda); 319 break; 320 // sspr is packed 1D Allocation A only 321 case (RsBlas_sspr): 322 initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr); 323 cblas_sspr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 324 (float*)A); 325 break; 326 case (RsBlas_ssyr2): 327 initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda); 328 cblas_ssyr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 329 (float*)Y, call->incY, (float*)A, lda); 330 break; 331 // sspr2 is packed 1D Allocation A only 332 case (RsBlas_sspr2): 333 initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda); 334 cblas_sspr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 335 (float*)Y, call->incY, (float*)A); 336 break; 337 case (RsBlas_dsymv): 338 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 339 cblas_dsymv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A, lda, 340 (double*)X, call->incX, call->beta.d, (double*)Y, call->incY); 341 break; 342 case (RsBlas_dsbmv): 343 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 344 cblas_dsbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.d, 345 (double*)A, lda, (double*)X, call->incX, call->beta.d, 346 (double*)Y, call->incY); 347 break; 348 // dspmv requires a packed 1D Allocation 349 case (RsBlas_dspmv): 350 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 351 cblas_dspmv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A, 352 (double*)X, call->incX, call->beta.d, (double*)Y, call->incY); 353 break; 354 // following calls have init reordered because A is output matrix 355 case (RsBlas_dger): 356 initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda); 357 cblas_dger(CblasRowMajor, call->M, call->N, call->alpha.d, (double*)X, 358 call->incX, (double*)Y, call->incY, (double*)A, lda); 359 break; 360 case (RsBlas_dsyr): 361 initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr); 362 cblas_dsyr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 363 (double*)A, lda); 364 break; 365 // dspr is packed 1D Allocation A only 366 case (RsBlas_dspr): 367 initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr); 368 cblas_dspr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 369 (double*)A); 370 break; 371 case (RsBlas_dsyr2): 372 initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda); 373 cblas_dsyr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 374 (double*)Y, call->incY, (double*)A, lda); 375 break; 376 // dspr2 is packed 1D Allocation A only 377 case (RsBlas_dspr2): 378 initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda); 379 cblas_dspr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 380 (double*)Y, call->incY, (double*)A); 381 break; 382 383 // C and Z only 384 case (RsBlas_chemv): 385 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 386 cblas_chemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A, lda, 387 X, call->incX, (void*)&call->beta.c, Y, call->incY); 388 break; 389 case (RsBlas_chbmv): 390 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 391 cblas_chbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.c, 392 A, lda, X, call->incX, (void*)&call->beta.c, Y, call->incY); 393 break; 394 case (RsBlas_chpmv): 395 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 396 cblas_chpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A, 397 X, call->incX, (void*)&call->beta.c, Y, call->incY); 398 break; 399 case (RsBlas_cgeru): 400 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 401 cblas_cgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c, 402 X, call->incX, Y, call->incY, A, lda); 403 break; 404 case (RsBlas_cgerc): 405 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 406 cblas_cgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c, 407 X, call->incX, Y, call->incY, A, lda); 408 break; 409 case (RsBlas_cher): 410 initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda); 411 cblas_cher(CblasRowMajor, Uplo, call->N, call->alpha.f, 412 X, call->incX, A, lda); 413 break; 414 // packed 1D Allocations only 415 case (RsBlas_chpr): 416 initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda); 417 cblas_chpr(CblasRowMajor, Uplo, call->N, call->alpha.f, X, 418 call->incX, A); 419 break; 420 case (RsBlas_cher2): 421 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 422 cblas_cher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, 423 X, call->incX, Y, call->incY, A, lda); 424 break; 425 // packed 1D Allocations only 426 case (RsBlas_chpr2): 427 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 428 cblas_chpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, X, 429 call->incX, Y, call->incY, A); 430 break; 431 case (RsBlas_zhemv): 432 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 433 cblas_zhemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A, lda, 434 X, call->incX, (void*)&call->beta.z, Y, call->incY); 435 break; 436 case (RsBlas_zhbmv): 437 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 438 cblas_zhbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.z, 439 A, lda, X, call->incX, (void*)&call->beta.z, Y, call->incY); 440 break; 441 case (RsBlas_zhpmv): 442 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 443 cblas_zhpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A, 444 X, call->incX, (void*)&call->beta.z, Y, call->incY); 445 break; 446 case (RsBlas_zgeru): 447 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 448 cblas_zgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z, 449 X, call->incX, Y, call->incY, A, lda); 450 break; 451 case (RsBlas_zgerc): 452 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 453 cblas_zgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z, 454 X, call->incX, Y, call->incY, A, lda); 455 break; 456 case (RsBlas_zher): 457 initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda); 458 cblas_zher(CblasRowMajor, Uplo, call->N, call->alpha.d, 459 X, call->incX, A, lda); 460 break; 461 // packed 1D Allocations only 462 case (RsBlas_zhpr): 463 initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda); 464 cblas_zhpr(CblasRowMajor, Uplo, call->N, call->alpha.d, X, 465 call->incX, A); 466 break; 467 case (RsBlas_zher2): 468 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 469 cblas_zher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, 470 X, call->incX, Y, call->incY, A, lda); 471 break; 472 // packed 1D Allocations only 473 case (RsBlas_zhpr2): 474 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 475 cblas_zhpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, X, 476 call->incX, Y, call->incY, A); 477 break; 478 479 // Level 3 BLAS 480 case (RsBlas_sgemm): 481 initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc); 482 cblas_sgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.f, 483 (float*)A, lda, (float*)B, ldb, call->beta.f, (float*)C, ldc); 484 break; 485 case (RsBlas_ssymm): 486 initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc); 487 cblas_ssymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.f, (float*)A, 488 lda, (float*)B, ldb, call->beta.f, (float*)C, ldc); 489 break; 490 case (RsBlas_ssyrk): 491 initABC(ain, sizeof(float), &A, nullptr, &C, &lda, nullptr, &ldc); 492 cblas_ssyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A, 493 lda, call->beta.f, (float*)C, ldc); 494 break; 495 case (RsBlas_ssyr2k): 496 initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc); 497 cblas_ssyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A, 498 lda, (float*)B, ldb, call->beta.f, (float*)C, ldc); 499 break; 500 case (RsBlas_strmm): 501 initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr); 502 cblas_strmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f, 503 (float*)A, lda, (float*)B, ldb); 504 break; 505 case (RsBlas_strsm): 506 initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr); 507 cblas_strsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f, 508 (float*)A, lda, (float*)B, ldb); 509 break; 510 511 512 case (RsBlas_dgemm): 513 initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc); 514 cblas_dgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.d, 515 (double*)A, lda, (double*)B, ldb, call->beta.d, (double*)C, ldc); 516 break; 517 case (RsBlas_dsymm): 518 initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc); 519 cblas_dsymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.d, (double*)A, 520 lda, (double*)B, ldb, call->beta.d, (double*)C, ldc); 521 break; 522 case (RsBlas_dsyrk): 523 initABC(ain, sizeof(double), &A, nullptr, &C, &lda, nullptr, &ldc); 524 cblas_dsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A, 525 lda, call->beta.d, (double*)C, ldc); 526 break; 527 case (RsBlas_dsyr2k): 528 initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc); 529 cblas_dsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A, 530 lda, (double*)B, ldb, call->beta.d, (double*)C, ldc); 531 break; 532 case (RsBlas_dtrmm): 533 initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr); 534 cblas_dtrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d, 535 (double*)A, lda, (double*)B, ldb); 536 break; 537 case (RsBlas_dtrsm): 538 initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr); 539 cblas_dtrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d, 540 (double*)A, lda, (double*)B, ldb); 541 break; 542 543 case (RsBlas_cgemm): 544 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 545 cblas_cgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.c, 546 A, lda, B, ldb, (void*)&call->beta.c, C, ldc); 547 break; 548 case (RsBlas_csymm): 549 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 550 cblas_csymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A, 551 lda, B, ldb, (void*)&call->beta.c, C, ldc); 552 break; 553 case (RsBlas_csyrk): 554 initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 555 cblas_csyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, 556 lda, (void*)&call->beta.c, C, ldc); 557 break; 558 case (RsBlas_csyr2k): 559 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 560 cblas_csyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, 561 lda, B, ldb, (void*)&call->beta.c, C, ldc); 562 break; 563 case (RsBlas_ctrmm): 564 initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 565 cblas_ctrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c, 566 A, lda, B, ldb); 567 break; 568 case (RsBlas_ctrsm): 569 initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 570 cblas_ctrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c, 571 A, lda, B, ldb); 572 break; 573 574 case (RsBlas_zgemm): 575 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 576 cblas_zgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.z, 577 A, lda, B, ldb, (void*)&call->beta.z, C, ldc); 578 break; 579 case (RsBlas_zsymm): 580 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 581 cblas_zsymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A, 582 lda, B, ldb, (void*)&call->beta.z, C, ldc); 583 break; 584 case (RsBlas_zsyrk): 585 initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 586 cblas_zsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, 587 lda, (void*)&call->beta.z, C, ldc); 588 break; 589 case (RsBlas_zsyr2k): 590 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 591 cblas_zsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, 592 lda, B, ldb, (void*)&call->beta.z, C, ldc); 593 break; 594 case (RsBlas_ztrmm): 595 initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 596 cblas_ztrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z, 597 A, lda, B, ldb); 598 break; 599 case (RsBlas_ztrsm): 600 initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 601 cblas_ztrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z, 602 A, lda, B, ldb); 603 break; 604 605 // Level 3 C and Z only 606 case (RsBlas_chemm): 607 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 608 cblas_chemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A, lda, 609 B, ldb, (void*)&call->beta.c, C, ldc); 610 break; 611 case (RsBlas_cherk): 612 initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 613 cblas_cherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, A, lda, 614 call->beta.f, C, ldc); 615 break; 616 case (RsBlas_cher2k): 617 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 618 cblas_cher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, lda, 619 B, ldb, call->beta.f, C, ldc); 620 break; 621 622 case (RsBlas_zhemm): 623 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 624 cblas_zhemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A, lda, 625 B, ldb, (void*)&call->beta.z, C, ldc); 626 break; 627 case (RsBlas_zherk): 628 initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 629 cblas_zherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, A, lda, 630 call->beta.d, C, ldc); 631 break; 632 case (RsBlas_zher2k): 633 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 634 cblas_zher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, lda, 635 B, ldb, call->beta.d, C, ldc); 636 break; 637 638 639 case (RsBlas_bnnm): 640 initABC(ain, sizeof(uint8_t), &A, &B, &C, &lda, &ldb, &ldc); 641 kernelBNNM(call->M, call->N, call->K, 642 (const uint8_t*)A, call->a_offset, lda, 643 (const uint8_t*)B, call->b_offset, ldb, 644 (uint8_t*)C, call->c_offset, ldc, 645 call->c_mult_int); 646 647 break; 648 649 default: 650 ALOGE("unimplemented\n"); 651 } 652 653 654} 655 656void RsdCpuScriptIntrinsicBLAS::kernelBNNM(size_t m, size_t n, size_t k, 657 const uint8_t* a, uint8_t a_offset, size_t lda, 658 const uint8_t* b, uint8_t b_offset, size_t ldb, 659 uint8_t* c, int32_t c_offset, size_t ldc, 660 int32_t c_mult_int) { 661 const int c_shift = 21; 662 // Using gemmlowp to calculate the low precision 8 bit GEMM. 663 gemmlowp::eight_bit_int_gemm::EightBitIntGemm(m, n, k, a, -a_offset, lda, 664 b, -b_offset, ldb, c, c_offset, 665 c_mult_int, c_shift, ldc); 666} 667 668 669 670 671 672RsdCpuScriptIntrinsicBLAS::RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, 673 const Script *s) 674 : RsdCpuScriptIntrinsic(ctx, s, nullptr, RS_SCRIPT_INTRINSIC_ID_BLAS) { 675 676 677} 678 679RsdCpuScriptIntrinsicBLAS::~RsdCpuScriptIntrinsicBLAS() { 680} 681 682 683 684 685 686RsdCpuScriptImpl * rsdIntrinsic_BLAS(RsdCpuReferenceImpl *ctx, 687 const Script *s, const Element *e) { 688 689 return new RsdCpuScriptIntrinsicBLAS(ctx, s); 690} 691