// rsCpuIntrinsicBLAS.cpp revision 06deda3751a4a7358a7c7e03fbf1e4325fafb807
1/* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 18#include "rsCpuIntrinsic.h" 19#include "rsCpuIntrinsicInlines.h" 20#include "cblas.h" 21 22using namespace android; 23using namespace android::renderscript; 24 25namespace android { 26namespace renderscript { 27 28 29class RsdCpuScriptIntrinsicBLAS : public RsdCpuScriptIntrinsic { 30public: 31 void invokeForEach(uint32_t slot, 32 const Allocation ** ain, 33 uint32_t inLen, 34 Allocation * aout, 35 const void * usr, 36 uint32_t usrLen, 37 const RsScriptCall *sc) override; 38 39 void populateScript(Script *) override; 40 ~RsdCpuScriptIntrinsicBLAS() override; 41 RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, const Script *s); 42 43protected: 44 45 uint8_t a_offset = 0; 46 uint8_t b_offset = 0; 47 uint8_t c_offset = 0; 48 49 static void kernelBNNM(size_t m, size_t n, size_t k, 50 const uint8_t* a, uint8_t a_offset, size_t lda, 51 const uint8_t* b, uint8_t b_offset, size_t ldb, 52 uint8_t* c, int32_t c_offset, size_t ldc, 53 int32_t c_mult_int); 54 55 56 57}; 58 59} 60} 61 62void RsdCpuScriptIntrinsicBLAS::populateScript(Script *s) { 63 s->mHal.info.exportedVariableCount = 0; 64} 65 66static void initABC(const Allocation ** ain, 67 size_t size, 68 void** A, 69 void** B, 70 void** C, 71 int* lda, 72 int* ldb, 73 int* ldc) 74{ 75 if (ain[0]) { 76 *A = ain[0]->mHal.drvState.lod[0].mallocPtr; 77 *lda = 
(int)(ain[0]->mHal.drvState.lod[0].stride/size); 78 } 79 if (ain[1]) { 80 *B = ain[1]->mHal.drvState.lod[0].mallocPtr; 81 *ldb = (int)(ain[1]->mHal.drvState.lod[0].stride/size); 82 } 83 if (ain[2]) { 84 *C = ain[2]->mHal.drvState.lod[0].mallocPtr; 85 *ldc = (int)(ain[2]->mHal.drvState.lod[0].stride/size); 86 } 87 88 89} 90 91void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot, 92 const Allocation ** ain, 93 uint32_t inLen, 94 Allocation * aout, 95 const void * usr, 96 uint32_t usrLen, 97 const RsScriptCall *sc) { 98 RsBlasCall* call = (RsBlasCall*) usr; 99 // setup BLAS enum args 100 enum CBLAS_TRANSPOSE TransA = (enum CBLAS_TRANSPOSE)call->transA; 101 enum CBLAS_TRANSPOSE TransB = (enum CBLAS_TRANSPOSE)call->transB; 102 enum CBLAS_UPLO Uplo = (enum CBLAS_UPLO)call->uplo; 103 enum CBLAS_DIAG Diag = (enum CBLAS_DIAG)call->diag; 104 enum CBLAS_SIDE Side = (enum CBLAS_SIDE)call->side; 105 106 void *A = nullptr; 107 void *B = nullptr; 108 void *C = nullptr; 109 void *X = nullptr; 110 void *Y = nullptr; 111 112 int lda = 0, ldb = 0, ldc = 0; 113 114 switch (call->func) { 115 116 // Level 1 BLAS: returns into a 1D Allocation 117 118 119 // Level 2 BLAS 120 case (RsBlas_sgemv): 121 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 122 cblas_sgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.f, (float*)A, 123 lda, (float*)X, call->incX, call->beta.f, (float*)Y, call->incY); 124 break; 125 case (RsBlas_sgbmv): 126 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 127 cblas_sgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 128 call->alpha.f, (float*)A, lda, (float*)X, call->incX, 129 call->beta.f, (float*)Y, call->incY); 130 break; 131 case (RsBlas_strmv): 132 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 133 cblas_strmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, 134 lda, (float*)X, call->incX); 135 break; 136 case (RsBlas_stbmv): 137 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, 
nullptr); 138 cblas_stbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A, 139 lda, (float*)X, call->incX); 140 break; 141 // stpmv takes a packed 1D Allocation only 142 case (RsBlas_stpmv): 143 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 144 cblas_stpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, 145 (float*)X, call->incX); 146 break; 147 case (RsBlas_strsv): 148 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 149 cblas_strsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, lda, 150 (float*)X, call->incX); 151 break; 152 case (RsBlas_stbsv): 153 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 154 cblas_stbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A, 155 lda, (float*)X, call->incX); 156 break; 157 case (RsBlas_stpsv): 158 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 159 cblas_stpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, 160 (float*)X, call->incX); 161 break; 162 case (RsBlas_dgemv): 163 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 164 cblas_dgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.d, (double*)A, 165 lda, (double*)X, call->incX, call->beta.d, (double*)Y, call->incY); 166 break; 167 case (RsBlas_dgbmv): 168 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 169 cblas_dgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 170 call->alpha.d, (double*)A, lda, (double*)X, call->incX, 171 call->beta.d, (double*)Y, call->incY); 172 break; 173 case (RsBlas_dtrmv): 174 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 175 cblas_dtrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, 176 lda, (double*)X, call->incX); 177 break; 178 case (RsBlas_dtbmv): 179 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 180 cblas_dtbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A, 181 lda, (double*)X, call->incX); 182 break; 
183 // stpmv takes a packed 1D Allocation only 184 case (RsBlas_dtpmv): 185 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 186 cblas_dtpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, 187 (double*)X, call->incX); 188 break; 189 case (RsBlas_dtrsv): 190 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 191 cblas_dtrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, lda, 192 (double*)X, call->incX); 193 break; 194 case (RsBlas_dtbsv): 195 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 196 cblas_dtbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A, 197 lda, (double*)X, call->incX); 198 break; 199 case (RsBlas_dtpsv): 200 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 201 cblas_dtpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, 202 (double*)X, call->incX); 203 break; 204 case (RsBlas_cgemv): 205 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 206 cblas_cgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.c, (void*)A, 207 lda, (void*)X, call->incX, (void*)&call->beta.c, (void*)Y, call->incY); 208 break; 209 case (RsBlas_cgbmv): 210 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 211 cblas_cgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 212 (void*)&call->alpha.c, (void*)A, lda, (void*)X, call->incX, 213 (void*)&call->beta.c, (void*)Y, call->incY); 214 break; 215 case (RsBlas_ctrmv): 216 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 217 cblas_ctrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 218 lda, (void*)X, call->incX); 219 break; 220 case (RsBlas_ctbmv): 221 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 222 cblas_ctbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 223 lda, (void*)X, call->incX); 224 break; 225 // stpmv takes a packed 1D Allocation only 226 case (RsBlas_ctpmv): 227 initABC(ain, sizeof(float)*2, 
&A, &X, nullptr, &lda, &ldb, nullptr); 228 cblas_ctpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 229 (void*)X, call->incX); 230 break; 231 case (RsBlas_ctrsv): 232 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 233 cblas_ctrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda, 234 (void*)X, call->incX); 235 break; 236 case (RsBlas_ctbsv): 237 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 238 cblas_ctbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 239 lda, (void*)X, call->incX); 240 break; 241 case (RsBlas_ctpsv): 242 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 243 cblas_ctpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 244 (void*)X, call->incX); 245 break; 246 case (RsBlas_zgemv): 247 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 248 cblas_zgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.z, (void*)A, 249 lda, (void*)X, call->incX, (void*)&call->beta.z, (void*)Y, call->incY); 250 break; 251 case (RsBlas_zgbmv): 252 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 253 cblas_zgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 254 (void*)&call->alpha.z, (void*)A, lda, (void*)X, call->incX, 255 (void*)&call->beta.z, (void*)Y, call->incY); 256 break; 257 case (RsBlas_ztrmv): 258 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 259 cblas_ztrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 260 lda, (void*)X, call->incX); 261 break; 262 case (RsBlas_ztbmv): 263 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 264 cblas_ztbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 265 lda, (void*)X, call->incX); 266 break; 267 // stpmv takes a packed 1D Allocation only 268 case (RsBlas_ztpmv): 269 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 270 cblas_ztpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 271 
(void*)X, call->incX); 272 break; 273 case (RsBlas_ztrsv): 274 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 275 cblas_ztrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda, 276 (void*)X, call->incX); 277 break; 278 case (RsBlas_ztbsv): 279 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 280 cblas_ztbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 281 lda, (void*)X, call->incX); 282 break; 283 case (RsBlas_ztpsv): 284 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 285 cblas_ztpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 286 (void*)X, call->incX); 287 break; 288 289 290 // S and D only 291 case (RsBlas_ssymv): 292 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 293 cblas_ssymv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A, lda, 294 (float*)X, call->incX, call->beta.f, (float*)Y, call->incY); 295 break; 296 case (RsBlas_ssbmv): 297 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 298 cblas_ssbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.f, 299 (float*)A, lda, (float*)X, call->incX, call->beta.f, 300 (float*)Y, call->incY); 301 break; 302 //sspmv requires a packed 1D Allocation 303 case (RsBlas_sspmv): 304 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 305 cblas_sspmv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A, 306 (float*)X, call->incX, call->beta.f, (float*)Y, call->incY); 307 break; 308 // following calls have init reordered because A is output matrix 309 case (RsBlas_sger): 310 initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda); 311 cblas_sger(CblasRowMajor, call->M, call->N, call->alpha.f, (float*)X, 312 call->incX, (float*)Y, call->incY, (float*)A, lda); 313 break; 314 case (RsBlas_ssyr): 315 initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr); 316 cblas_ssyr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 317 (float*)A, lda); 318 break; 319 // sspr is 
packed 1D Allocation A only 320 case (RsBlas_sspr): 321 initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr); 322 cblas_sspr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 323 (float*)A); 324 break; 325 case (RsBlas_ssyr2): 326 initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda); 327 cblas_ssyr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 328 (float*)Y, call->incY, (float*)A, lda); 329 break; 330 // sspr2 is packed 1D Allocation A only 331 case (RsBlas_sspr2): 332 initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda); 333 cblas_sspr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 334 (float*)Y, call->incY, (float*)A); 335 break; 336 case (RsBlas_dsymv): 337 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 338 cblas_dsymv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A, lda, 339 (double*)X, call->incX, call->beta.d, (double*)Y, call->incY); 340 break; 341 case (RsBlas_dsbmv): 342 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 343 cblas_dsbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.d, 344 (double*)A, lda, (double*)X, call->incX, call->beta.d, 345 (double*)Y, call->incY); 346 break; 347 // dspmv requires a packed 1D Allocation 348 case (RsBlas_dspmv): 349 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 350 cblas_dspmv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A, 351 (double*)X, call->incX, call->beta.d, (double*)Y, call->incY); 352 break; 353 // following calls have init reordered because A is output matrix 354 case (RsBlas_dger): 355 initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda); 356 cblas_dger(CblasRowMajor, call->M, call->N, call->alpha.d, (double*)X, 357 call->incX, (double*)Y, call->incY, (double*)A, lda); 358 break; 359 case (RsBlas_dsyr): 360 initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr); 361 cblas_dsyr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 362 
(double*)A, lda); 363 break; 364 // dspr is packed 1D Allocation A only 365 case (RsBlas_dspr): 366 initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr); 367 cblas_dspr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 368 (double*)A); 369 break; 370 case (RsBlas_dsyr2): 371 initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda); 372 cblas_dsyr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 373 (double*)Y, call->incY, (double*)A, lda); 374 break; 375 // dspr2 is packed 1D Allocation A only 376 case (RsBlas_dspr2): 377 initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda); 378 cblas_dspr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 379 (double*)Y, call->incY, (double*)A); 380 break; 381 382 // C and Z only 383 case (RsBlas_chemv): 384 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 385 cblas_chemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A, lda, 386 X, call->incX, (void*)&call->beta.c, Y, call->incY); 387 break; 388 case (RsBlas_chbmv): 389 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 390 cblas_chbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.c, 391 A, lda, X, call->incX, (void*)&call->beta.c, Y, call->incY); 392 break; 393 case (RsBlas_chpmv): 394 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 395 cblas_chpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A, 396 X, call->incX, (void*)&call->beta.c, Y, call->incY); 397 break; 398 case (RsBlas_cgeru): 399 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 400 cblas_cgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c, 401 X, call->incX, Y, call->incY, A, lda); 402 break; 403 case (RsBlas_cgerc): 404 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 405 cblas_cgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c, 406 X, call->incX, Y, call->incY, A, lda); 407 break; 408 case (RsBlas_cher): 409 initABC(ain, 
sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda); 410 cblas_cher(CblasRowMajor, Uplo, call->N, call->alpha.f, 411 X, call->incX, A, lda); 412 break; 413 // packed 1D Allocations only 414 case (RsBlas_chpr): 415 initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda); 416 cblas_chpr(CblasRowMajor, Uplo, call->N, call->alpha.f, X, 417 call->incX, A); 418 break; 419 case (RsBlas_cher2): 420 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 421 cblas_cher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, 422 X, call->incX, Y, call->incY, A, lda); 423 break; 424 // packed 1D Allocations only 425 case (RsBlas_chpr2): 426 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 427 cblas_chpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, X, 428 call->incX, Y, call->incY, A); 429 break; 430 case (RsBlas_zhemv): 431 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 432 cblas_zhemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A, lda, 433 X, call->incX, (void*)&call->beta.z, Y, call->incY); 434 break; 435 case (RsBlas_zhbmv): 436 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 437 cblas_zhbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.z, 438 A, lda, X, call->incX, (void*)&call->beta.z, Y, call->incY); 439 break; 440 case (RsBlas_zhpmv): 441 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 442 cblas_zhpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A, 443 X, call->incX, (void*)&call->beta.z, Y, call->incY); 444 break; 445 case (RsBlas_zgeru): 446 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 447 cblas_zgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z, 448 X, call->incX, Y, call->incY, A, lda); 449 break; 450 case (RsBlas_zgerc): 451 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 452 cblas_zgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z, 453 X, call->incX, Y, call->incY, A, lda); 454 break; 455 case 
(RsBlas_zher): 456 initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda); 457 cblas_zher(CblasRowMajor, Uplo, call->N, call->alpha.d, 458 X, call->incX, A, lda); 459 break; 460 // packed 1D Allocations only 461 case (RsBlas_zhpr): 462 initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda); 463 cblas_zhpr(CblasRowMajor, Uplo, call->N, call->alpha.d, X, 464 call->incX, A); 465 break; 466 case (RsBlas_zher2): 467 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 468 cblas_zher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, 469 X, call->incX, Y, call->incY, A, lda); 470 break; 471 // packed 1D Allocations only 472 case (RsBlas_zhpr2): 473 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 474 cblas_zhpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, X, 475 call->incX, Y, call->incY, A); 476 break; 477 478 // Level 3 BLAS 479 case (RsBlas_sgemm): 480 initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc); 481 cblas_sgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.f, 482 (float*)A, lda, (float*)B, ldb, call->beta.f, (float*)C, ldc); 483 break; 484 case (RsBlas_ssymm): 485 initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc); 486 cblas_ssymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.f, (float*)A, 487 lda, (float*)B, ldb, call->beta.f, (float*)C, ldc); 488 break; 489 case (RsBlas_ssyrk): 490 initABC(ain, sizeof(float), &A, nullptr, &C, &lda, nullptr, &ldc); 491 cblas_ssyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A, 492 lda, call->beta.f, (float*)C, ldc); 493 break; 494 case (RsBlas_ssyr2k): 495 initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc); 496 cblas_ssyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A, 497 lda, (float*)B, ldb, call->beta.f, (float*)C, ldc); 498 break; 499 case (RsBlas_strmm): 500 initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr); 501 
cblas_strmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f, 502 (float*)A, lda, (float*)B, ldb); 503 break; 504 case (RsBlas_strsm): 505 initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr); 506 cblas_strsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f, 507 (float*)A, lda, (float*)B, ldb); 508 break; 509 510 511 case (RsBlas_dgemm): 512 initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc); 513 cblas_dgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.d, 514 (double*)A, lda, (double*)B, ldb, call->beta.d, (double*)C, ldc); 515 break; 516 case (RsBlas_dsymm): 517 initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc); 518 cblas_dsymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.d, (double*)A, 519 lda, (double*)B, ldb, call->beta.d, (double*)C, ldc); 520 break; 521 case (RsBlas_dsyrk): 522 initABC(ain, sizeof(double), &A, nullptr, &C, &lda, nullptr, &ldc); 523 cblas_dsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A, 524 lda, call->beta.d, (double*)C, ldc); 525 break; 526 case (RsBlas_dsyr2k): 527 initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc); 528 cblas_dsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A, 529 lda, (double*)B, ldb, call->beta.d, (double*)C, ldc); 530 break; 531 case (RsBlas_dtrmm): 532 initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr); 533 cblas_dtrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d, 534 (double*)A, lda, (double*)B, ldb); 535 break; 536 case (RsBlas_dtrsm): 537 initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr); 538 cblas_dtrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d, 539 (double*)A, lda, (double*)B, ldb); 540 break; 541 542 case (RsBlas_cgemm): 543 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 544 cblas_cgemm(CblasRowMajor, TransA, TransB, call->M, 
call->N, call->K, (void*)&call->alpha.c, 545 A, lda, B, ldb, (void*)&call->beta.c, C, ldc); 546 break; 547 case (RsBlas_csymm): 548 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 549 cblas_csymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A, 550 lda, B, ldb, (void*)&call->beta.c, C, ldc); 551 break; 552 case (RsBlas_csyrk): 553 initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 554 cblas_csyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, 555 lda, (void*)&call->beta.c, C, ldc); 556 break; 557 case (RsBlas_csyr2k): 558 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 559 cblas_csyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, 560 lda, B, ldb, (void*)&call->beta.c, C, ldc); 561 break; 562 case (RsBlas_ctrmm): 563 initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 564 cblas_ctrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c, 565 A, lda, B, ldb); 566 break; 567 case (RsBlas_ctrsm): 568 initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 569 cblas_ctrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c, 570 A, lda, B, ldb); 571 break; 572 573 case (RsBlas_zgemm): 574 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 575 cblas_zgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.z, 576 A, lda, B, ldb, (void*)&call->beta.z, C, ldc); 577 break; 578 case (RsBlas_zsymm): 579 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 580 cblas_zsymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A, 581 lda, B, ldb, (void*)&call->beta.z, C, ldc); 582 break; 583 case (RsBlas_zsyrk): 584 initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 585 cblas_zsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, 586 lda, (void*)&call->beta.z, C, ldc); 587 
break; 588 case (RsBlas_zsyr2k): 589 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 590 cblas_zsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, 591 lda, B, ldb, (void*)&call->beta.z, C, ldc); 592 break; 593 case (RsBlas_ztrmm): 594 initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 595 cblas_ztrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z, 596 A, lda, B, ldb); 597 break; 598 case (RsBlas_ztrsm): 599 initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 600 cblas_ztrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z, 601 A, lda, B, ldb); 602 break; 603 604 // Level 3 C and Z only 605 case (RsBlas_chemm): 606 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 607 cblas_chemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A, lda, 608 B, ldb, (void*)&call->beta.c, C, ldc); 609 break; 610 case (RsBlas_cherk): 611 initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 612 cblas_cherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, A, lda, 613 call->beta.f, C, ldc); 614 break; 615 case (RsBlas_cher2k): 616 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 617 cblas_cher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, lda, 618 B, ldb, call->beta.f, C, ldc); 619 break; 620 621 case (RsBlas_zhemm): 622 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 623 cblas_zhemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A, lda, 624 B, ldb, (void*)&call->beta.z, C, ldc); 625 break; 626 case (RsBlas_zherk): 627 initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 628 cblas_zherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, A, lda, 629 call->beta.d, C, ldc); 630 break; 631 case (RsBlas_zher2k): 632 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 633 
cblas_zher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, lda, 634 B, ldb, call->beta.d, C, ldc); 635 break; 636 637 638 case (RsBlas_bnnm): 639 initABC(ain, sizeof(uint8_t), &A, &B, &C, &lda, &ldb, &ldc); 640 kernelBNNM(call->M, call->N, call->K, 641 (const uint8_t*)A, call->a_offset, lda, 642 (const uint8_t*)B, call->b_offset, ldb, 643 (uint8_t*)C, call->c_offset, ldc, 644 call->c_mult_int); 645 646 break; 647 648 default: 649 ALOGE("unimplemented\n"); 650 } 651 652 653} 654 655void RsdCpuScriptIntrinsicBLAS::kernelBNNM(size_t m, size_t n, size_t k, 656 const uint8_t* a, uint8_t a_offset, size_t lda, 657 const uint8_t* b, uint8_t b_offset, size_t ldb, 658 uint8_t* c, int32_t c_offset, size_t ldc, 659 int32_t c_mult_int) { 660 // Calculations are done in 1.10.21 fixed-point format for the final output, 661 // just before there's a shift down to drop the fractional parts. The output 662 // values are gated to 0 to 255 to fit in a byte, but the 10-bit format 663 // gives some headroom to avoid wrapping around on small overflows. 
664 const int c_shift = 21; 665 size_t i = 0, j = 0, l = 0; 666 for (j = 0; j < n; j++) { 667 for (i = 0; i < m; i++) { 668 int32_t total = 0; 669 for (l = 0; l < k; l++) { 670 const int a_index = ((i * lda) + l); 671 const uint8_t a_as_byte = a[a_index]; 672 const int32_t a_as_int = (((int32_t)(a_as_byte)) - a_offset); 673 const int b_index = ((j * ldb) + l); 674 const uint8_t b_as_byte = b[b_index]; 675 const int32_t b_as_int = (((int32_t)(b_as_byte)) - b_offset); 676 const int32_t mult_as_int = (a_as_int * b_as_int); 677 total += mult_as_int; 678 } 679 const int c_index = ((ldc * i) + j); 680 int32_t output = 681 ((((total + c_offset) * c_mult_int) + (1 << (c_shift - 1))) 682 >> c_shift); 683 if (output > 255) { 684 output = 255; 685 } 686 if (output < 0) { 687 output = 0; 688 } 689 c[c_index] = (uint8_t)(output); 690 } 691 } 692} 693 694 695 696 697 698RsdCpuScriptIntrinsicBLAS::RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, 699 const Script *s) 700 : RsdCpuScriptIntrinsic(ctx, s, nullptr, RS_SCRIPT_INTRINSIC_ID_BLAS) { 701 702 703} 704 705RsdCpuScriptIntrinsicBLAS::~RsdCpuScriptIntrinsicBLAS() { 706} 707 708 709 710 711 712RsdCpuScriptImpl * rsdIntrinsic_BLAS(RsdCpuReferenceImpl *ctx, 713 const Script *s, const Element *e) { 714 715 return new RsdCpuScriptIntrinsicBLAS(ctx, s); 716} 717