rsCpuIntrinsicBLAS.cpp revision b75ba0fc7469d0bb4c1a6679664a846b3741792e
1/* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 18#include "rsCpuIntrinsic.h" 19#include "rsCpuIntrinsicInlines.h" 20#include "cblas.h" 21 22using namespace android; 23using namespace android::renderscript; 24 25namespace android { 26namespace renderscript { 27 28 29class RsdCpuScriptIntrinsicBLAS : public RsdCpuScriptIntrinsic { 30public: 31 virtual void invokeForEach(uint32_t slot, 32 const Allocation ** ain, 33 uint32_t inLen, 34 Allocation * aout, 35 const void * usr, 36 uint32_t usrLen, 37 const RsScriptCall *sc); 38 39 virtual void populateScript(Script *); 40 virtual ~RsdCpuScriptIntrinsicBLAS(); 41 RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, const Script *s); 42 43protected: 44 45 uint8_t a_offset = 0; 46 uint8_t b_offset = 0; 47 uint8_t c_offset = 0; 48 49 static void kernelBNNM(size_t m, size_t n, size_t k, 50 const uint8_t* a, uint32_t a_offset, size_t lda, 51 const uint8_t* b, uint32_t b_offset, size_t ldb, 52 uint8_t* c, uint32_t c_offset, size_t ldc, 53 uint32_t c_mult_int); 54 55 56 57}; 58 59} 60} 61 62void RsdCpuScriptIntrinsicBLAS::populateScript(Script *s) { 63 s->mHal.info.exportedVariableCount = 0; 64} 65 66static void initABC(const Allocation ** ain, 67 size_t size, 68 void** A, 69 void** B, 70 void** C, 71 int* lda, 72 int* ldb, 73 int* ldc) 74{ 75 if (ain[0]) { 76 *A = ain[0]->mHal.drvState.lod[0].mallocPtr; 77 *lda = (int)(ain[0]->mHal.drvState.lod[0].stride/size); 78 } 79 if (ain[1]) { 80 *B = ain[1]->mHal.drvState.lod[0].mallocPtr; 81 *ldb = (int)(ain[1]->mHal.drvState.lod[0].stride/size); 82 } 83 if (ain[2]) { 84 *C = ain[2]->mHal.drvState.lod[0].mallocPtr; 85 *ldc = (int)(ain[2]->mHal.drvState.lod[0].stride/size); 86 } 87 88 89} 90 91void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot, 92 const Allocation ** ain, 93 uint32_t inLen, 94 Allocation * aout, 95 const void * usr, 96 uint32_t usrLen, 97 const RsScriptCall *sc) { 98 RsBlasCall* call = (RsBlasCall*) usr; 99 // setup BLAS enum args 100 enum CBLAS_TRANSPOSE TransA = (enum CBLAS_TRANSPOSE)call->transA; 101 enum CBLAS_TRANSPOSE TransB = (enum CBLAS_TRANSPOSE)call->transB; 102 enum CBLAS_UPLO Uplo = (enum CBLAS_UPLO)call->uplo; 103 enum CBLAS_DIAG Diag = (enum CBLAS_DIAG)call->diag; 104 enum CBLAS_SIDE Side = (enum CBLAS_SIDE)call->side; 105 106 void *A = nullptr; 107 void *B = nullptr; 108 void *C = nullptr; 109 void *X = nullptr; 110 void *Y = nullptr; 111 112 int lda = 0, ldb = 0, ldc = 0; 113 114 switch (call->func) { 115 116 // Level 1 BLAS: returns into a 1D Allocation 117 118 119 // Level 2 BLAS 120 case (RsBlas_sgemv): 121 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 122 cblas_sgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.f, (float*)A, 123 lda, (float*)X, call->incX, call->beta.f, (float*)Y, call->incY); 124 break; 125 case (RsBlas_sgbmv): 126 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 127 cblas_sgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 128 call->alpha.f, (float*)A, lda, (float*)X, call->incX, 129 call->beta.f, (float*)Y, call->incY); 130 break; 131 case (RsBlas_strmv): 132 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 133 cblas_strmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, 134 lda, (float*)X, call->incX); 135 break; 136 case (RsBlas_stbmv): 137 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 138 cblas_stbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A, 139 lda, (float*)X, call->incX); 140 break; 141 // stpmv takes a packed 1D Allocation only 142 case (RsBlas_stpmv): 143 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 144 cblas_stpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, 145 (float*)X, call->incX); 146 break; 147 case (RsBlas_strsv): 148 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 149 cblas_strsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, lda, 150 (float*)X, call->incX); 151 break; 152 case (RsBlas_stbsv): 153 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 154 cblas_stbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A, 155 lda, (float*)X, call->incX); 156 break; 157 case (RsBlas_stpsv): 158 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 159 cblas_stpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, 160 (float*)X, call->incX); 161 break; 162 case (RsBlas_dgemv): 163 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 164 cblas_dgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.d, (double*)A, 165 lda, (double*)X, call->incX, call->beta.d, (double*)Y, call->incY); 166 break; 167 case (RsBlas_dgbmv): 168 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 169 cblas_dgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 170 call->alpha.d, (double*)A, lda, (double*)X, call->incX, 171 call->beta.d, (double*)Y, call->incY); 172 break; 173 case (RsBlas_dtrmv): 174 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 175 cblas_dtrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, 176 lda, (double*)X, call->incX); 177 break; 178 case (RsBlas_dtbmv): 179 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 180 cblas_dtbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A, 181 lda, (double*)X, call->incX); 182 break; 183 // stpmv takes a packed 1D Allocation only 184 case (RsBlas_dtpmv): 185 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 186 cblas_dtpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, 187 (double*)X, call->incX); 188 break; 189 case (RsBlas_dtrsv): 190 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 191 cblas_dtrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, lda, 192 (double*)X, call->incX); 193 break; 194 case (RsBlas_dtbsv): 195 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 196 cblas_dtbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A, 197 lda, (double*)X, call->incX); 198 break; 199 case (RsBlas_dtpsv): 200 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 201 cblas_dtpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, 202 (double*)X, call->incX); 203 break; 204 case (RsBlas_cgemv): 205 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 206 cblas_cgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.c, (void*)A, 207 lda, (void*)X, call->incX, (void*)&call->beta.c, (void*)Y, call->incY); 208 break; 209 case (RsBlas_cgbmv): 210 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 211 cblas_cgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 212 (void*)&call->alpha.c, (void*)A, lda, (void*)X, call->incX, 213 (void*)&call->beta.c, (void*)Y, call->incY); 214 break; 215 case (RsBlas_ctrmv): 216 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 217 cblas_ctrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 218 lda, (void*)X, call->incX); 219 break; 220 case (RsBlas_ctbmv): 221 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 222 cblas_ctbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 223 lda, (void*)X, call->incX); 224 break; 225 // stpmv takes a packed 1D Allocation only 226 case (RsBlas_ctpmv): 227 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 228 cblas_ctpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 229 (void*)X, call->incX); 230 break; 231 case (RsBlas_ctrsv): 232 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 233 cblas_ctrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda, 234 (void*)X, call->incX); 235 break; 236 case (RsBlas_ctbsv): 237 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 238 cblas_ctbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 239 lda, (void*)X, call->incX); 240 break; 241 case (RsBlas_ctpsv): 242 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 243 cblas_ctpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 244 (void*)X, call->incX); 245 break; 246 case (RsBlas_zgemv): 247 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 248 cblas_zgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.z, (void*)A, 249 lda, (void*)X, call->incX, (void*)&call->beta.z, (void*)Y, call->incY); 250 break; 251 case (RsBlas_zgbmv): 252 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 253 cblas_zgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 254 (void*)&call->alpha.z, (void*)A, lda, (void*)X, call->incX, 255 (void*)&call->beta.z, (void*)Y, call->incY); 256 break; 257 case (RsBlas_ztrmv): 258 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 259 cblas_ztrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 260 lda, (void*)X, call->incX); 261 break; 262 case (RsBlas_ztbmv): 263 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 264 cblas_ztbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 265 lda, (void*)X, call->incX); 266 break; 267 // stpmv takes a packed 1D Allocation only 268 case (RsBlas_ztpmv): 269 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 270 cblas_ztpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 271 (void*)X, call->incX); 272 break; 273 case (RsBlas_ztrsv): 274 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 275 cblas_ztrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda, 276 (void*)X, call->incX); 277 break; 278 case (RsBlas_ztbsv): 279 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 280 cblas_ztbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 281 lda, (void*)X, call->incX); 282 break; 283 case (RsBlas_ztpsv): 284 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 285 cblas_ztpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 286 (void*)X, call->incX); 287 break; 288 289 290 // S and D only 291 case (RsBlas_ssymv): 292 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 293 cblas_ssymv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A, lda, 294 (float*)X, call->incX, call->beta.f, (float*)Y, call->incY); 295 break; 296 case (RsBlas_ssbmv): 297 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 298 cblas_ssbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.f, 299 (float*)A, lda, (float*)X, call->incX, call->beta.f, 300 (float*)Y, call->incY); 301 break; 302 //sspmv requires a packed 1D Allocation 303 case (RsBlas_sspmv): 304 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 305 cblas_sspmv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A, 306 (float*)X, call->incX, call->beta.f, (float*)Y, call->incY); 307 break; 308 // following calls have init reordered because A is output matrix 309 case (RsBlas_sger): 310 initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda); 311 cblas_sger(CblasRowMajor, call->M, call->N, call->alpha.f, (float*)X, 312 call->incX, (float*)Y, call->incY, (float*)A, lda); 313 break; 314 case (RsBlas_ssyr): 315 initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr); 316 cblas_ssyr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 317 (float*)A, lda); 318 break; 319 // sspr is packed 1D Allocation A only 320 case (RsBlas_sspr): 321 initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr); 322 cblas_sspr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 323 (float*)A); 324 break; 325 case (RsBlas_ssyr2): 326 initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda); 327 cblas_ssyr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 328 (float*)Y, call->incY, (float*)A, lda); 329 break; 330 // sspr2 is packed 1D Allocation A only 331 case (RsBlas_sspr2): 332 initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda); 333 cblas_sspr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 334 (float*)Y, call->incY, (float*)A); 335 break; 336 case (RsBlas_dsymv): 337 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 338 cblas_dsymv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A, lda, 339 (double*)X, call->incX, call->beta.d, (double*)Y, call->incY); 340 break; 341 case (RsBlas_dsbmv): 342 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 343 cblas_dsbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.d, 344 (double*)A, lda, (double*)X, call->incX, call->beta.d, 345 (double*)Y, call->incY); 346 break; 347 // dspmv requires a packed 1D Allocation 348 case (RsBlas_dspmv): 349 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 350 cblas_dspmv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A, 351 (double*)X, call->incX, call->beta.d, (double*)Y, call->incY); 352 break; 353 // following calls have init reordered because A is output matrix 354 case (RsBlas_dger): 355 initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda); 356 cblas_dger(CblasRowMajor, call->M, call->N, call->alpha.d, (double*)X, 357 call->incX, (double*)Y, call->incY, (double*)A, lda); 358 break; 359 case (RsBlas_dsyr): 360 initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr); 361 cblas_dsyr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 362 (double*)A, lda); 363 break; 364 // dspr is packed 1D Allocation A only 365 case (RsBlas_dspr): 366 initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr); 367 cblas_dspr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 368 (double*)A); 369 break; 370 case (RsBlas_dsyr2): 371 initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda); 372 cblas_dsyr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 373 (double*)Y, call->incY, (double*)A, lda); 374 break; 375 // dspr2 is packed 1D Allocation A only 376 case (RsBlas_dspr2): 377 initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda); 378 cblas_dspr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 379 (double*)Y, call->incY, (double*)A); 380 break; 381 382 // C and Z only 383 case (RsBlas_chemv): 384 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 385 cblas_chemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A, lda, 386 X, call->incX, (void*)&call->beta.c, Y, call->incY); 387 break; 388 case (RsBlas_chbmv): 389 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 390 cblas_chbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.c, 391 A, lda, X, call->incX, (void*)&call->beta.c, Y, call->incY); 392 break; 393 case (RsBlas_chpmv): 394 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 395 cblas_chpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A, 396 X, call->incX, (void*)&call->beta.c, Y, call->incY); 397 break; 398 case (RsBlas_cgeru): 399 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 400 cblas_cgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c, 401 X, call->incX, Y, call->incY, A, lda); 402 break; 403 case (RsBlas_cgerc): 404 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 405 cblas_cgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c, 406 X, call->incX, Y, call->incY, A, lda); 407 break; 408 case (RsBlas_cher): 409 initABC(ain, sizeof(float)*2, &X, &A, nullptr, &ldb, &lda, nullptr); 410 cblas_cher(CblasRowMajor, Uplo, call->N, call->alpha.f, 411 X, call->incX, A, lda); 412 break; 413 // packed 1D Allocations only 414 case (RsBlas_chpr): 415 initABC(ain, sizeof(float)*2, &X, &A, nullptr, &ldb, &lda, nullptr); 416 cblas_chpr(CblasRowMajor, Uplo, call->N, call->alpha.f, X, 417 call->incX, A); 418 break; 419 case (RsBlas_cher2): 420 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 421 cblas_cher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, 422 X, call->incX, Y, call->incY, A, lda); 423 break; 424 // packed 1D Allocations only 425 case (RsBlas_chpr2): 426 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 427 cblas_chpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, X, 428 call->incX, Y, call->incY, A); 429 break; 430 case (RsBlas_zhemv): 431 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 432 cblas_zhemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A, lda, 433 X, call->incX, (void*)&call->beta.z, Y, call->incY); 434 break; 435 case (RsBlas_zhbmv): 436 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 437 cblas_zhbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.z, 438 A, lda, X, call->incX, (void*)&call->beta.z, Y, call->incY); 439 break; 440 case (RsBlas_zhpmv): 441 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 442 cblas_zhpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A, 443 X, call->incX, (void*)&call->beta.z, Y, call->incY); 444 break; 445 case (RsBlas_zgeru): 446 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 447 cblas_zgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z, 448 X, call->incX, Y, call->incY, A, lda); 449 break; 450 case (RsBlas_zgerc): 451 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 452 cblas_zgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z, 453 X, call->incX, Y, call->incY, A, lda); 454 break; 455 case (RsBlas_zher): 456 initABC(ain, sizeof(double)*2, &X, &A, nullptr, &ldb, &lda, nullptr); 457 cblas_zher(CblasRowMajor, Uplo, call->N, call->alpha.d, 458 X, call->incX, A, lda); 459 break; 460 // packed 1D Allocations only 461 case (RsBlas_zhpr): 462 initABC(ain, sizeof(double)*2, &X, &A, nullptr, &ldb, &lda, nullptr); 463 cblas_zhpr(CblasRowMajor, Uplo, call->N, call->alpha.d, X, 464 call->incX, A); 465 break; 466 case (RsBlas_zher2): 467 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 468 cblas_zher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, 469 X, call->incX, Y, call->incY, A, lda); 470 break; 471 // packed 1D Allocations only 472 case (RsBlas_zhpr2): 473 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 474 cblas_zhpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, X, 475 call->incX, Y, call->incY, A); 476 break; 477 478 // Level 3 BLAS 479 case (RsBlas_sgemm): 480 initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc); 481 ALOGE("call->M = %d, call->N = %d, call->K = %d, lda = %d, ldb = %d, ldc = %d", call->M, call->N, call->K, lda, ldb, ldc); 482 cblas_sgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.f, 483 (float*)A, lda, (float*)B, ldb, call->beta.f, (float*)C, ldc); 484 break; 485 case (RsBlas_ssymm): 486 initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc); 487 cblas_ssymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.f, (float*)A, 488 lda, (float*)B, ldb, call->beta.f, (float*)C, ldc); 489 break; 490 case (RsBlas_ssyrk): 491 initABC(ain, sizeof(float), &A, nullptr, &C, &lda, nullptr, &ldc); 492 cblas_ssyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A, 493 lda, call->beta.f, (float*)C, ldc); 494 break; 495 case (RsBlas_ssyr2k): 496 initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc); 497 cblas_ssyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A, 498 lda, (float*)B, ldb, call->beta.f, (float*)C, ldc); 499 break; 500 case (RsBlas_strmm): 501 initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr); 502 cblas_strmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f, 503 (float*)A, lda, (float*)B, ldb); 504 break; 505 case (RsBlas_strsm): 506 initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr); 507 cblas_strsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f, 508 (float*)A, lda, (float*)B, ldb); 509 break; 510 511 512 case (RsBlas_dgemm): 513 initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc); 514 cblas_dgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.d, 515 (double*)A, lda, (double*)B, ldb, call->beta.d, (double*)C, ldc); 516 break; 517 case (RsBlas_dsymm): 518 initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc); 519 cblas_dsymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.d, (double*)A, 520 lda, (double*)B, ldb, call->beta.d, (double*)C, ldc); 521 break; 522 case (RsBlas_dsyrk): 523 initABC(ain, sizeof(double), &A, nullptr, &C, &lda, nullptr, &ldc); 524 cblas_dsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A, 525 lda, call->beta.d, (double*)C, ldc); 526 break; 527 case (RsBlas_dsyr2k): 528 initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc); 529 cblas_dsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A, 530 lda, (double*)B, ldb, call->beta.d, (double*)C, ldc); 531 break; 532 case (RsBlas_dtrmm): 533 initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr); 534 cblas_dtrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d, 535 (double*)A, lda, (double*)B, ldb); 536 break; 537 case (RsBlas_dtrsm): 538 initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr); 539 cblas_dtrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d, 540 (double*)A, lda, (double*)B, ldb); 541 break; 542 543 case (RsBlas_cgemm): 544 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 545 cblas_cgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.c, 546 A, lda, B, ldb, (void*)&call->beta.c, C, ldc); 547 break; 548 case (RsBlas_csymm): 549 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 550 cblas_csymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A, 551 lda, B, ldb, (void*)&call->beta.c, C, ldc); 552 break; 553 case (RsBlas_csyrk): 554 initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 555 cblas_csyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, 556 lda, (void*)&call->beta.c, C, ldc); 557 break; 558 case (RsBlas_csyr2k): 559 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 560 cblas_csyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, 561 lda, B, ldb, (void*)&call->beta.c, C, ldc); 562 break; 563 case (RsBlas_ctrmm): 564 initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 565 cblas_ctrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c, 566 A, lda, B, ldb); 567 break; 568 case (RsBlas_ctrsm): 569 initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 570 cblas_ctrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c, 571 A, lda, B, ldb); 572 break; 573 574 case (RsBlas_zgemm): 575 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 576 cblas_zgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.z, 577 A, lda, B, ldb, (void*)&call->beta.z, C, ldc); 578 break; 579 case (RsBlas_zsymm): 580 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 581 cblas_zsymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A, 582 lda, B, ldb, (void*)&call->beta.z, C, ldc); 583 break; 584 case (RsBlas_zsyrk): 585 initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 586 cblas_zsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, 587 lda, (void*)&call->beta.z, C, ldc); 588 break; 589 case (RsBlas_zsyr2k): 590 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 591 cblas_zsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, 592 lda, B, ldb, (void*)&call->beta.z, C, ldc); 593 break; 594 case (RsBlas_ztrmm): 595 initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 596 cblas_ztrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z, 597 A, lda, B, ldb); 598 break; 599 case (RsBlas_ztrsm): 600 initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 601 cblas_ztrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z, 602 A, lda, B, ldb); 603 break; 604 605 // Level 3 C and Z only 606 case (RsBlas_chemm): 607 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 608 cblas_chemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A, lda, 609 B, ldb, (void*)&call->beta.c, C, ldc); 610 break; 611 case (RsBlas_cherk): 612 initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 613 cblas_cherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, A, lda, 614 call->beta.f, C, ldc); 615 break; 616 case (RsBlas_cher2k): 617 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 618 cblas_cher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, lda, 619 B, ldb, call->beta.f, C, ldc); 620 break; 621 622 case (RsBlas_zhemm): 623 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 624 cblas_zhemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A, lda, 625 B, ldb, (void*)&call->beta.z, C, ldc); 626 break; 627 case (RsBlas_zherk): 628 initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 629 cblas_zherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, A, lda, 630 call->beta.d, C, ldc); 631 break; 632 case (RsBlas_zher2k): 633 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 634 cblas_zher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, lda, 635 B, ldb, call->beta.d, C, ldc); 636 break; 637 638 639 case (RsBlas_bnnm): 640 initABC(ain, sizeof(uint8_t), &A, &B, &C, &lda, &ldb, &ldc); 641 kernelBNNM(call->M, call->N, call->K, 642 (const uint8_t*)A, call->a_offset, lda, 643 (const uint8_t*)B, call->b_offset, ldb, 644 (uint8_t*)C, call->c_offset, ldc, 645 call->c_mult_int); 646 647 break; 648 649 default: 650 ALOGE("unimplemented\n"); 651 } 652 653 654} 655 656void RsdCpuScriptIntrinsicBLAS::kernelBNNM(size_t m, size_t n, size_t k, 657 const uint8_t* a, uint32_t a_offset, size_t lda, 658 const uint8_t* b, uint32_t b_offset, size_t ldb, 659 uint8_t* c, uint32_t c_offset, size_t ldc, 660 uint32_t c_mult_int) { 661 // Calculations are done in 1.10.21 fixed-point format for the final output, 662 // just before there's a shift down to drop the fractional parts. The output 663 // values are gated to 0 to 255 to fit in a byte, but the 10-bit format 664 // gives some headroom to avoid wrapping around on small overflows. 665 const int c_shift = 21; 666 size_t i = 0, j = 0, l = 0; 667 for (j = 0; j < n; j++) { 668 for (i = 0; i < m; i++) { 669 int32_t total = 0; 670 for (l = 0; l < k; l++) { 671 const int a_index = ((i * lda) + l); 672 const uint8_t a_as_byte = a[a_index]; 673 const int32_t a_as_int = (((int32_t)(a_as_byte)) - a_offset); 674 const int b_index = ((j * ldb) + l); 675 const uint8_t b_as_byte = b[b_index]; 676 const int32_t b_as_int = (((int32_t)(b_as_byte)) - b_offset); 677 const int32_t mult_as_int = (a_as_int * b_as_int); 678 total += mult_as_int; 679 } 680 const int c_index = ((ldc * i) + j); 681 int32_t output = 682 ((((total + c_offset) * c_mult_int) + (1 << (c_shift - 1))) 683 >> c_shift); 684 if (output > 255) { 685 output = 255; 686 } 687 if (output < 0) { 688 output = 0; 689 } 690 c[c_index] = (uint8_t)(output); 691 } 692 } 693} 694 695 696 697 698 699RsdCpuScriptIntrinsicBLAS::RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, 700 const Script *s) 701 : RsdCpuScriptIntrinsic(ctx, s, nullptr, RS_SCRIPT_INTRINSIC_ID_BLAS) { 702 703 704} 705 706RsdCpuScriptIntrinsicBLAS::~RsdCpuScriptIntrinsicBLAS() { 707} 708 709 710 711 712 713RsdCpuScriptImpl * rsdIntrinsic_BLAS(RsdCpuReferenceImpl *ctx, 714 const Script *s, const Element *e) { 715 716 return new RsdCpuScriptIntrinsicBLAS(ctx, s); 717} 718