1/* 2 * Copyright (C) 2015 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 18#include "RenderScript.h" 19#include "rsCppInternal.h" 20 21#define NELEM(m) (sizeof(m) / sizeof((m)[0])) 22 23using namespace android; 24using namespace RSC; 25 26// ScriptIntrinsicBLAS APIS 27ScriptIntrinsicBLAS::ScriptIntrinsicBLAS(sp<RS> rs, sp<const Element> e) 28 : ScriptIntrinsic(rs, RS_SCRIPT_INTRINSIC_ID_BLAS, e) { 29 30} 31 32sp<ScriptIntrinsicBLAS> ScriptIntrinsicBLAS::create(sp<RS> rs) { 33 return new ScriptIntrinsicBLAS(rs, Element::U32(rs)); 34} 35 36enum RsBlasDataType { 37 SINGLE, 38 DOUBLE, 39 SINGLE_COMPLEX, 40 DOUBLE_COMPLEX 41}; 42 43static RsBlasCall 44setUpBLASCall(RsBlasDataType dataType, RsBlasFunction func, 45 int TransA, int TransB, int Side, int Uplo, int Diag, 46 int M, int N, int K, int incX, int incY, int KL, int KU, 47 float alphaF, float betaF, double alphaD, double betaD, 48 float alphaCX, float alphaCY, float betaCX, float betaCY, 49 double alphaZX, double alphaZY, double betaZX, double betaZY 50 ) { 51 RsBlasCall call; 52 memset(&call, 0, sizeof(call)); 53 call.func = func; 54 call.transA = (RsBlasTranspose)TransA; 55 call.transB = (RsBlasTranspose)TransB; 56 call.side = (RsBlasSide)Side; 57 call.uplo = (RsBlasUplo)Uplo; 58 call.diag = (RsBlasDiag)Diag; 59 call.M = M; 60 call.N = N; 61 call.K = K; 62 63 switch (dataType) { 64 case SINGLE: 65 // For Single-precision BLAS. 66 call.alpha.f = alphaF; 67 call.beta.f = betaF; 68 break; 69 case DOUBLE: 70 // For Double-precision BLAS. 71 call.alpha.d = alphaD; 72 call.beta.d = betaD; 73 break; 74 case SINGLE_COMPLEX: 75 // For Single-precision complex BLAS. 76 call.alpha.c.r = alphaCX; 77 call.alpha.c.i = alphaCY; 78 call.beta.c.r = betaCX; 79 call.beta.c.i = betaCY; 80 break; 81 case DOUBLE_COMPLEX: 82 // For Double-precision complex BLAS. 83 call.alpha.z.r = alphaZX; 84 call.alpha.z.i = alphaZY; 85 call.beta.z.r = betaZX; 86 call.beta.z.i = betaZY; 87 break; 88 default: 89 break; 90 } 91 92 call.incX = incX; 93 call.incY = incY; 94 call.KL = KL; 95 call.KU = KU; 96 97 return call; 98} 99 100static void 101nScriptIntrinsicBLAS_Single(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, 102 int TransB, int Side, int Uplo, int Diag, int M, int N, int K, 103 float alpha, RsAllocation A, RsAllocation B, 104 float beta, RsAllocation C, int incX, int incY, int KL, int KU) { 105 RsBlasCall call = setUpBLASCall(SINGLE, func, TransA, TransB, Side, Uplo, Diag, 106 M, N, K, incX, incY, KL, KU, alpha, beta, 0.0, 0.0, 107 0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0); 108 RsAllocation in_allocs[3] = {A, B, C}; 109 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr, 110 &call, sizeof(call), nullptr, 0)); 111} 112 113 114static void 115nScriptIntrinsicBLAS_Double(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, 116 int TransB, int Side, int Uplo, int Diag, int M, int N, int K, 117 double alpha, RsAllocation A, RsAllocation B, 118 double beta, RsAllocation C, int incX, int incY, int KL, int KU) { 119 RsBlasCall call = setUpBLASCall(DOUBLE, func, TransA, TransB, Side, Uplo, Diag, 120 M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, alpha, beta, 121 0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0); 122 RsAllocation in_allocs[3] = {A, B, C}; 123 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr, 124 &call, sizeof(call), nullptr, 0)); 125} 126 127static void 128nScriptIntrinsicBLAS_Complex(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, 129 int TransB, int Side, int Uplo, int Diag, int M, int N, int K, 130 float alphaX, float alphaY, RsAllocation A, RsAllocation B, 131 float betaX, float betaY, RsAllocation C, int incX, int incY, int KL, int KU) { 132 RsBlasCall call = setUpBLASCall(SINGLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag, 133 M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0, 134 alphaX, alphaY, betaX, betaY, 0.0, 0.0, 0.0, 0.0); 135 RsAllocation in_allocs[3] = {A, B, C}; 136 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr, 137 &call, sizeof(call), nullptr, 0)); 138} 139 140static void 141nScriptIntrinsicBLAS_Z(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, 142 int TransB, int Side, int Uplo, int Diag, int M, int N, int K, 143 double alphaX, double alphaY, RsAllocation A, RsAllocation B, 144 double betaX, double betaY, RsAllocation C, int incX, int incY, int KL, int KU) { 145 RsBlasCall call = setUpBLASCall(DOUBLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag, 146 M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0, 147 0.0f, 0.0f, 0.0f, 0.0f, alphaX, alphaY, betaX, betaY); 148 RsAllocation in_allocs[3] = {A, B, C}; 149 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr, 150 &call, sizeof(call), nullptr, 0)); 151} 152 153 154static void 155nScriptIntrinsicBLAS_BNNM(RS* mRS, RsContext con, RsScript id, int M, int N, int K, 156 RsAllocation A, int a_offset, RsAllocation B, int b_offset, 157 RsAllocation C, int c_offset, int c_mult_int) { 158 RsBlasCall call; 159 memset(&call, 0, sizeof(call)); 160 call.func = RsBlas_bnnm; 161 call.M = M; 162 call.N = N; 163 call.K = K; 164 call.a_offset = a_offset & 0xFF; 165 call.b_offset = b_offset & 0xFF; 166 call.c_offset = c_offset; 167 call.c_mult_int = c_mult_int; 168 169 RsAllocation in_allocs[3] = {A, B, C}; 170 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr, 171 &call, sizeof(call), nullptr, 0)); 172} 173 174/** 175 * Level 2 BLAS 176 */ 177static void validateGEMV(RS* mRS, sp<const Element> e, RsBlasTranspose TransA, sp<Allocation> A, 178 sp<Allocation> X, int incX, sp<Allocation> Y, int incY) { 179 int M = A->getType()->getY(); 180 int N = A->getType()->getX(); 181 if (!A->getType()->getElement()->isCompatible(e) || 182 !X->getType()->getElement()->isCompatible(e) || 183 !Y->getType()->getElement()->isCompatible(e)) { 184 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 185 } 186 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 187 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 188 } 189 190 if (incX <= 0 || incY <= 0) { 191 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 192 } 193 int expectedXDim = -1, expectedYDim = -1; 194 if (TransA == RsBlasNoTrans) { 195 expectedXDim = 1 + (N - 1) * incX; 196 expectedYDim = 1 + (M - 1) * incY; 197 } else { 198 expectedXDim = 1 + (M - 1) * incX; 199 expectedYDim = 1 + (N - 1) * incY; 200 } 201 if ((int)X->getType()->getX() != expectedXDim || 202 (int)Y->getType()->getX() != expectedYDim) { 203 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GEMV"); 204 } 205} 206 207void ScriptIntrinsicBLAS::SGEMV(RsBlasTranspose TransA, float alpha, sp<Allocation> A, sp<Allocation> X, 208 int incX, float beta, sp<Allocation> Y, int incY) { 209 validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY); 210 int M = A->getType()->getY(); 211 int N = A->getType()->getX(); 212 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemv, 213 TransA, 0, 0, 0, 0, M, N, 0, 214 alpha, A->getID(), X->getID(), 215 beta, Y->getID(), incX, incY, 0, 0); 216} 217 218void ScriptIntrinsicBLAS::DGEMV(RsBlasTranspose TransA, double alpha, sp<Allocation> A, sp<Allocation> X, 219 int incX, double beta, sp<Allocation> Y, int incY) { 220 validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY); 221 int M = A->getType()->getY(); 222 int N = A->getType()->getX(); 223 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemv, 224 TransA, 0, 0, 0, 0, M, N, 0, 225 alpha, A->getID(), X->getID(), 226 beta, Y->getID(), incX, incY, 0, 0); 227} 228 229void ScriptIntrinsicBLAS::CGEMV(RsBlasTranspose TransA, Float2 alpha, sp<Allocation> A, sp<Allocation> X, 230 int incX, Float2 beta, sp<Allocation> Y, int incY) { 231 validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY); 232 int M = A->getType()->getY(); 233 int N = A->getType()->getX(); 234 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemv, 235 TransA, 0, 0, 0, 0, M, N, 0, 236 alpha.x, alpha.y, A->getID(), X->getID(), 237 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 238} 239 240void ScriptIntrinsicBLAS::ZGEMV(RsBlasTranspose TransA, Double2 alpha, sp<Allocation> A, sp<Allocation> X, 241 int incX, Double2 beta, sp<Allocation> Y, int incY) { 242 validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY); 243 int M = A->getType()->getY(); 244 int N = A->getType()->getX(); 245 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemv, 246 TransA, 0, 0, 0, 0, M, N, 0, 247 alpha.x, alpha.y, A->getID(), X->getID(), 248 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 249} 250 251void ScriptIntrinsicBLAS::SGBMV(RsBlasTranspose TransA, int KL, int KU, float alpha, sp<Allocation> A, 252 sp<Allocation> X, int incX, float beta, sp<Allocation> Y, int incY) { 253 // GBMV has the same validation requirements as GEMV + KL and KU >= 0 254 validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY); 255 if (KL < 0 || KU < 0) { 256 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); 257 } 258 int M = A->getType()->getY(); 259 int N = A->getType()->getX(); 260 261 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgbmv, 262 TransA, 0, 0, 0, 0, M, N, 0, 263 alpha, A->getID(), X->getID(), 264 beta, Y->getID(), incX, incY, KL, KU); 265} 266 267void ScriptIntrinsicBLAS::DGBMV(RsBlasTranspose TransA, int KL, int KU, double alpha, sp<Allocation> A, 268 sp<Allocation> X, int incX, double beta, sp<Allocation> Y, int incY) { 269 // GBMV has the same validation requirements as GEMV + KL and KU >= 0 270 validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY); 271 if (KL < 0 || KU < 0) { 272 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); 273 } 274 int M = A->getType()->getY(); 275 int N = A->getType()->getX(); 276 277 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgbmv, 278 TransA, 0, 0, 0, 0, M, N, 0, 279 alpha, A->getID(), X->getID(), 280 beta, Y->getID(), incX, incY, KL, KU); 281} 282 283void ScriptIntrinsicBLAS::CGBMV(RsBlasTranspose TransA, int KL, int KU, Float2 alpha, sp<Allocation> A, 284 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) { 285 // GBMV has the same validation requirements as GEMV + KL and KU >= 0 286 validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY); 287 if (KL < 0 || KU < 0) { 288 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); 289 } 290 int M = A->getType()->getY(); 291 int N = A->getType()->getX(); 292 293 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgbmv, 294 TransA, 0, 0, 0, 0, M, N, 0, 295 alpha.x, alpha.y, A->getID(), X->getID(), 296 beta.x, beta.y, Y->getID(), incX, incY, KL, KU); 297} 298 299void ScriptIntrinsicBLAS::ZGBMV(RsBlasTranspose TransA, int KL, int KU, Double2 alpha, sp<Allocation> A, 300 sp<Allocation> X, int incX, Double2 beta, sp<Allocation> Y, int incY) { 301 // GBMV has the same validation requirements as GEMV + KL and KU >= 0 302 validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY); 303 if (KL < 0 || KU < 0) { 304 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); 305 } 306 int M = A->getType()->getY(); 307 int N = A->getType()->getX(); 308 309 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgbmv, 310 TransA, 0, 0, 0, 0, M, N, 0, 311 alpha.x, alpha.y, A->getID(), X->getID(), 312 beta.x, beta.y, Y->getID(), incX, incY, KL, KU); 313} 314 315static void validateTRMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, RsBlasTranspose TransA, 316 RsBlasDiag Diag, sp<Allocation> A, sp<Allocation> X, int incX) { 317 int N = A->getType()->getY(); 318 if ((int)A->getType()->getX() != N) { 319 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for TRMV"); 320 } 321 if (!A->getType()->getElement()->isCompatible(e) || 322 !X->getType()->getElement()->isCompatible(e)) { 323 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 324 } 325 if (X->getType()->getY() > 1) { 326 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 327 } 328 329 if (incX <= 0) { 330 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 331 } 332 int expectedXDim = 1 + (N - 1) * incX; 333 if ((int)X->getType()->getX() != expectedXDim) { 334 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TRMV"); 335 } 336} 337 338static int validateTPMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, RsBlasTranspose TransA, 339 RsBlasDiag Diag, sp<Allocation> Ap, sp<Allocation> X, int incX) { 340 if (!Ap->getType()->getElement()->isCompatible(e) || 341 !X->getType()->getElement()->isCompatible(e)) { 342 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 343 } 344 if (X->getType()->getY() > 1) { 345 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 346 } 347 348 if (Ap->getType()->getY() > 1) { 349 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); 350 } 351 352 int N = sqrt((double)Ap->getType()->getX() * 2); 353 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { 354 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); 355 } 356 if (incX <= 0) { 357 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 358 } 359 int expectedXDim = 1 + (N - 1) * incX; 360 if ((int)X->getType()->getX() != expectedXDim) { 361 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TPMV"); 362 } 363 364 return N; 365} 366 367 368void ScriptIntrinsicBLAS::STRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 369 sp<Allocation> A, sp<Allocation> X, int incX) { 370 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); 371 int N = A->getType()->getY(); 372 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmv, 373 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 374 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 375} 376 377void ScriptIntrinsicBLAS::DTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 378 sp<Allocation> A, sp<Allocation> X, int incX) { 379 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); 380 int N = A->getType()->getY(); 381 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmv, 382 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 383 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 384} 385 386void ScriptIntrinsicBLAS::CTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 387 sp<Allocation> A, sp<Allocation> X, int incX) { 388 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); 389 int N = A->getType()->getY(); 390 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmv, 391 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 392 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 393} 394 395void ScriptIntrinsicBLAS::ZTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 396 sp<Allocation> A, sp<Allocation> X, int incX) { 397 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); 398 int N = A->getType()->getY(); 399 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmv, 400 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 401 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 402} 403 404void ScriptIntrinsicBLAS::STBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 405 int K, sp<Allocation> A, sp<Allocation> X, int incX) { 406 // TBMV has the same requirements as TRMV + K >= 0 407 if (K < 0) { 408 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); 409 } 410 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); 411 int N = A->getType()->getY(); 412 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbmv, 413 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 414 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 415} 416 417void ScriptIntrinsicBLAS::DTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 418 int K, sp<Allocation> A, sp<Allocation> X, int incX) { 419 // TBMV has the same requirements as TRMV + K >= 0 420 if (K < 0) { 421 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); 422 } 423 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); 424 int N = A->getType()->getY(); 425 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbmv, 426 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 427 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 428} 429 430void ScriptIntrinsicBLAS::CTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 431 int K, sp<Allocation> A, sp<Allocation> X, int incX) { 432 // TBMV has the same requirements as TRMV + K >= 0 433 if (K < 0) { 434 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); 435 } 436 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); 437 int N = A->getType()->getY(); 438 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbmv, 439 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, 440 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 441} 442 443void ScriptIntrinsicBLAS::ZTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 444 int K, sp<Allocation> A, sp<Allocation> X, int incX) { 445 // TBMV has the same requirements as TRMV + K >= 0 446 if (K < 0) { 447 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); 448 } 449 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); 450 int N = A->getType()->getY(); 451 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbmv, 452 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, 453 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 454} 455 456void ScriptIntrinsicBLAS::STPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 457 sp<Allocation> Ap, sp<Allocation> X, int incX) { 458 int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX); 459 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpmv, 460 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 461 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 462} 463 464void ScriptIntrinsicBLAS::DTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 465 sp<Allocation> Ap, sp<Allocation> X, int incX) { 466 int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX); 467 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpmv, 468 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 469 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 470} 471 472void ScriptIntrinsicBLAS::CTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 473 sp<Allocation> Ap, sp<Allocation> X, int incX) { 474 int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); 475 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpmv, 476 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 477 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 478} 479 480void ScriptIntrinsicBLAS::ZTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 481 sp<Allocation> Ap, sp<Allocation> X, int incX) { 482 int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); 483 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpmv, 484 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 485 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 486} 487 488void ScriptIntrinsicBLAS::STRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 489 sp<Allocation> A, sp<Allocation> X, int incX) { 490 // TRSV is the same as TRMV 491 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); 492 int N = A->getType()->getY(); 493 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsv, 494 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 495 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 496} 497 498void ScriptIntrinsicBLAS::DTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 499 sp<Allocation> A, sp<Allocation> X, int incX) { 500 // TRSV is the same as TRMV 501 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); 502 int N = A->getType()->getY(); 503 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsv, 504 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 505 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 506 507} 508 509void ScriptIntrinsicBLAS::CTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 510 sp<Allocation> A, sp<Allocation> X, int incX) { 511 // TRSV is the same as TRMV 512 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); 513 int N = A->getType()->getY(); 514 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsv, 515 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 516 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 517 518} 519 520void ScriptIntrinsicBLAS::ZTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 521 sp<Allocation> A, sp<Allocation> X, int incX) { 522 // TRSV is the same as TRMV 523 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); 524 int N = A->getType()->getY(); 525 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsv, 526 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 527 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 528 529} 530 531void ScriptIntrinsicBLAS::STBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 532 int K, sp<Allocation> A, sp<Allocation> X, int incX) { 533 // TBSV is the same as TRMV + K >= 0 534 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); 535 int N = A->getType()->getY(); 536 if (K < 0) { 537 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); 538 } 539 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbsv, 540 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 541 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 542} 543 544void ScriptIntrinsicBLAS::DTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 545 int K, sp<Allocation> A, sp<Allocation> X, int incX) { 546 // TBSV is the same as TRMV + K >= 0 547 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); 548 int N = A->getType()->getY(); 549 if (K < 0) { 550 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); 551 } 552 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbsv, 553 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 554 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 555} 556 557void ScriptIntrinsicBLAS::CTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 558 int K, sp<Allocation> A, sp<Allocation> X, int incX) { 559 // TBSV is the same as TRMV + K >= 0 560 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); 561 int N = A->getType()->getY(); 562 if (K < 0) { 563 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); 564 } 565 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbsv, 566 TransA, 0, 0, Uplo, Diag, 0, N, K, 567 0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 568} 569 570void ScriptIntrinsicBLAS::ZTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 571 int K, sp<Allocation> A, sp<Allocation> X, int incX) { 572 // TBSV is the same as TRMV + K >= 0 573 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); 574 int N = A->getType()->getY(); 575 if (K < 0) { 576 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); 577 } 578 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbsv, 579 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, 580 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 581} 582 583void ScriptIntrinsicBLAS::STPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 584 sp<Allocation> Ap, sp<Allocation> X, int incX) { 585 // TPSV is same as TPMV 586 int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX); 587 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpsv, 588 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 589 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 590} 591 592void ScriptIntrinsicBLAS::DTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 593 sp<Allocation> Ap, sp<Allocation> X, int incX) { 594 // TPSV is same as TPMV 595 int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX); 596 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpsv, 597 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 598 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 599} 600 601void ScriptIntrinsicBLAS::CTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 602 sp<Allocation> Ap, sp<Allocation> X, int incX) { 603 // TPSV is same as TPMV 604 int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); 605 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpsv, 606 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 607 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 608} 609 610void ScriptIntrinsicBLAS::ZTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 611 sp<Allocation> Ap, sp<Allocation> X, int incX) { 612 // TPSV is same as TPMV 613 int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); 614 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpsv, 615 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 616 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 617} 618 619/** 620 * Level 2, S and D only 621 */ 622static int validateSYMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> A, 623 sp<Allocation> X, sp<Allocation> Y, int incX, int incY) { 624 int N = A->getType()->getY(); 625 if ((int)A->getType()->getX() != N) { 626 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for SYMV"); 627 } 628 if (!A->getType()->getElement()->isCompatible(e) || 629 !X->getType()->getElement()->isCompatible(e) || 630 !Y->getType()->getElement()->isCompatible(e) ) { 631 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 632 } 633 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 634 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 635 } 636 637 if (incX <= 0 || incY <= 0) { 638 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 639 } 640 int expectedXDim = 1 + (N - 1) * incX; 641 if ((int)X->getType()->getX() != expectedXDim) { 642 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV"); 643 } 644 int expectedYDim = 1 + (N - 1) * incY; 645 if ((int)Y->getType()->getX() != expectedYDim) { 646 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV"); 647 } 648 return N; 649} 650static int validateSPMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> Ap, 651 sp<Allocation> X, int incX, sp<Allocation> Y, int incY) { 652 if (!Ap->getType()->getElement()->isCompatible(e) || 653 !X->getType()->getElement()->isCompatible(e) || 654 !Y->getType()->getElement()->isCompatible(e)) { 655 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 656 } 657 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 658 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 659 } 660 661 if (Ap->getType()->getY() > 1) { 662 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); 663 } 664 665 int N = sqrt((double)Ap->getType()->getX() * 2); 666 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { 667 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); 668 } 669 if (incX <= 0 || incY <= 0) { 670 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 671 } 672 int expectedXDim = 1 + (N - 1) * incX; 673 if ((int)X->getType()->getX() != expectedXDim) { 674 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV"); 675 } 676 int expectedYDim = 1 + (N - 1) * incY; 677 if ((int)Y->getType()->getX() != expectedYDim) { 678 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV"); 679 } 680 681 return N; 682} 683static void validateGER(RS* mRS, sp<const Element> e, sp<Allocation> X, int incX, 684 sp<Allocation> Y, int incY, sp<Allocation> A) { 685 if (!A->getType()->getElement()->isCompatible(e) || 686 !X->getType()->getElement()->isCompatible(e) || 687 !Y->getType()->getElement()->isCompatible(e) ) { 688 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 689 } 690 691 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 692 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 693 } 694 695 int M = A->getType()->getY(); 696 int N = A->getType()->getX(); 697 698 if (N < 1 || M < 1) { 699 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "M and N must be 1 or greater for GER"); 700 } 701 if (incX <= 0 || incY <= 0) { 702 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 703 } 704 int expectedXDim = 1 + (M - 1) * incX; 705 if ((int)X->getType()->getX() != expectedXDim) { 706 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER"); 707 } 708 int expectedYDim = 1 + (N - 1) * incY; 709 if ((int)Y->getType()->getX() != expectedYDim) { 710 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER"); 711 } 712 713 714} 715static int validateSYR(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, 716 sp<Allocation> X, int incX, sp<Allocation> A) { 717 if (!A->getType()->getElement()->isCompatible(e) || 718 !X->getType()->getElement()->isCompatible(e)) { 719 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 720 } 721 722 int N = A->getType()->getX(); 723 724 if (X->getType()->getY() > 1) { 725 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 726 } 727 if (N != (int)A->getType()->getY()) { 728 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix"); 729 } 730 if (incX <= 0) { 731 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 732 } 733 int expectedXDim = 1 + (N - 1) * incX; 734 if ((int)X->getType()->getX() != expectedXDim) { 735 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR"); 736 } 737 return N; 738} 739static int validateSPR(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, 740 sp<Allocation> X, int incX, sp<Allocation> Ap) { 741 if (!Ap->getType()->getElement()->isCompatible(e) || 742 !X->getType()->getElement()->isCompatible(e)) { 743 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 744 } 745 if (X->getType()->getY() > 1) { 746 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 747 } 748 749 if (Ap->getType()->getY() > 1) { 750 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); 751 } 752 753 int N = sqrt((double)Ap->getType()->getX() * 2); 754 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { 755 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); 756 } 757 if (incX <= 0) { 758 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 759 } 760 int expectedXDim = 1 + (N - 1) * incX; 761 if ((int)X->getType()->getX() != expectedXDim) { 762 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR"); 763 } 764 765 return N; 766} 767 768static int validateSYR2(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> X, 769 int incX, sp<Allocation> Y, int incY, sp<Allocation> A) { 770 if (!A->getType()->getElement()->isCompatible(e) || 771 !X->getType()->getElement()->isCompatible(e) || 772 !Y->getType()->getElement()->isCompatible(e)) { 773 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 774 } 775 776 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 777 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 778 } 779 780 int N = A->getType()->getX(); 781 782 if (N != (int)A->getType()->getY()) { 783 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix"); 784 } 785 if (incX <= 0 || incY <= 0) { 786 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 787 } 788 int expectedXDim = 1 + (N - 1) * incX; 789 int expectedYDim = 1 + (N - 1) * incY; 790 if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) { 791 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR"); 792 } 793 return N; 794 795} 796static int validateSPR2(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> X, 797 int incX, sp<Allocation> Y, int incY, sp<Allocation> Ap) { 798 if (!Ap->getType()->getElement()->isCompatible(e) || 799 !X->getType()->getElement()->isCompatible(e) || 800 !Y->getType()->getElement()->isCompatible(e)) { 801 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 802 } 803 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 804 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 805 } 806 807 if (Ap->getType()->getY() > 1) { 808 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); 809 } 810 811 int N = sqrt((double)Ap->getType()->getX() * 2); 812 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { 813 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); 814 } 815 if (incX <= 0 || incY <= 0) { 816 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 817 } 818 int expectedXDim = 1 + (N - 1) * incX; 819 int expectedYDim = 1 + (N - 1) * incY; 820 if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) { 821 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR2"); 822 } 823 824 return N; 825} 826 827void ScriptIntrinsicBLAS::SSYMV(RsBlasUplo Uplo, float alpha, sp<Allocation> A, sp<Allocation> X, 828 int incX, float beta, sp<Allocation> Y, int incY) { 829 int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY); 830 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymv, 831 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 832 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); 833} 834 835void ScriptIntrinsicBLAS::SSBMV(RsBlasUplo Uplo, int K, float alpha, sp<Allocation> A, sp<Allocation> X, 836 int incX, float beta, sp<Allocation> Y, int incY) { 837 // SBMV is the same as SYMV + K >= 0 838 if (K < 0) { 839 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); 840 } 841 int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY); 842 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssbmv, 843 0, 0, 0, Uplo, 0, 0, N, K, alpha, 844 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); 845} 846 847void ScriptIntrinsicBLAS::SSPMV(RsBlasUplo Uplo, float alpha, sp<Allocation> Ap, sp<Allocation> X, 848 int incX, float beta, sp<Allocation> Y, int incY) { 849 int N = validateSPMV(mRS, Element::F32(mRS), Uplo, Ap, X, incX, Y, incY); 850 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspmv, 851 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 852 Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); 853} 854 855void ScriptIntrinsicBLAS::SGER(float alpha, sp<Allocation> X, int incX, 856 sp<Allocation> Y, int incY, sp<Allocation> A) { 857 int M = A->getType()->getY(); 858 int N = A->getType()->getX(); 859 validateGER(mRS, Element::F32(mRS), X, incX, Y, incY, A); 860 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sger, 861 0, 0, 0, 0, 0, M, N, 0, alpha, 862 X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0); 863} 864 865void ScriptIntrinsicBLAS::SSYR(RsBlasUplo Uplo, float alpha, sp<Allocation> X, 866 int incX, sp<Allocation> A) { 867 int N = validateSYR(mRS, Element::F32(mRS), Uplo, X, incX, A); 868 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr, 869 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 870 X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0); 871} 872 873void ScriptIntrinsicBLAS::SSPR(RsBlasUplo Uplo, float alpha, sp<Allocation> X, 874 int incX, sp<Allocation> Ap) { 875 int N = validateSPR(mRS, Element::F32(mRS), Uplo, X, incX, Ap); 876 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr, 877 0, 0, 0, Uplo, 0, 0, N, 0, 878 alpha, X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0); 879} 880 881void ScriptIntrinsicBLAS::SSYR2(RsBlasUplo Uplo, float alpha, sp<Allocation> X, int incX, 882 sp<Allocation> Y, int incY, sp<Allocation> A) { 883 int N = validateSYR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, A); 884 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2, 885 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 886 X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0); 887} 888 889void ScriptIntrinsicBLAS::SSPR2(RsBlasUplo Uplo, float alpha, sp<Allocation> X, int incX, 890 sp<Allocation> Y, int incY, sp<Allocation> Ap) { 891 int N = validateSPR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, Ap); 892 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr2, 893 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 894 X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0); 895} 896 897void ScriptIntrinsicBLAS::DSYMV(RsBlasUplo Uplo, double alpha, sp<Allocation> A, sp<Allocation> X, 898 int incX, double beta, sp<Allocation> Y, int incY) { 899 int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY); 900 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymv, 901 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 902 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); 903} 904 905void ScriptIntrinsicBLAS::DSBMV(RsBlasUplo Uplo, int K, double alpha, sp<Allocation> A, sp<Allocation> X, 906 int incX, double beta, sp<Allocation> Y, int incY) { 907 // SBMV is the same as SYMV + K >= 0 908 if (K < 0) { 909 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); 910 } 911 int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY); 912 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsbmv, 913 0, 0, 0, Uplo, 0, 0, N, K, alpha, 914 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); 915} 916 917void ScriptIntrinsicBLAS::DSPMV(RsBlasUplo Uplo, double alpha, sp<Allocation> Ap, sp<Allocation> X, 918 int incX, double beta, sp<Allocation> Y, int incY) { 919 int N = validateSPMV(mRS, Element::F64(mRS), Uplo, Ap, X, incX, Y, incY); 920 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspmv, 921 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 922 Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); 923} 924 925void ScriptIntrinsicBLAS::DGER(double alpha, sp<Allocation> X, int incX, sp<Allocation> Y, 926 int incY, sp<Allocation> A) { 927 int M = A->getType()->getY(); 928 int N = A->getType()->getX(); 929 validateGER(mRS, Element::F64(mRS), X, incX, Y, incY, A); 930 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dger, 931 0, 0, 0, 0, 0, M, N, 0, alpha, 932 X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0); 933} 934 935void ScriptIntrinsicBLAS::DSYR(RsBlasUplo Uplo, double alpha, sp<Allocation> X, 936 int incX, sp<Allocation> A) { 937 int N = validateSYR(mRS, Element::F64(mRS), Uplo, X, incX, A); 938 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr, 939 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 940 X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0); 941} 942 943void ScriptIntrinsicBLAS::DSPR(RsBlasUplo Uplo, double alpha, sp<Allocation> X, 944 int incX, sp<Allocation> Ap) { 945 int N = validateSPR(mRS, Element::F64(mRS), Uplo, X, incX, Ap); 946 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr, 947 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 948 X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0); 949} 950 951void ScriptIntrinsicBLAS::DSYR2(RsBlasUplo Uplo, double alpha, sp<Allocation> X, int incX, 952 sp<Allocation> Y, int incY, sp<Allocation> A) { 953 int N = validateSYR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, A); 954 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2, 955 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 956 X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0); 957} 958 959void ScriptIntrinsicBLAS::DSPR2(RsBlasUplo Uplo, double alpha, sp<Allocation> X, int incX, 960 sp<Allocation> Y, int incY, sp<Allocation> Ap) { 961 int N = validateSPR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, Ap); 962 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr2, 963 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 964 X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0); 965} 966 967 968/** 969 * Level 2, C and Z only 970 */ 971 972static void validateGERU(RS* mRS, sp<const Element> e, sp<Allocation> X, int incX, 973 sp<Allocation> Y, int incY, sp<Allocation> A) { 974 if (!A->getType()->getElement()->isCompatible(e) || 975 !X->getType()->getElement()->isCompatible(e) || 976 !Y->getType()->getElement()->isCompatible(e)) { 977 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 978 } 979 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 980 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 981 } 982 983 int M = A->getType()->getY(); 984 int N = A->getType()->getX(); 985 if (incX <= 0 || incY <= 0) { 986 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 987 } 988 int expectedXDim = 1 + (M - 1) * incX; 989 if ((int)X->getType()->getX() != expectedXDim) { 990 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU"); 991 } 992 int expectedYDim = 1 + (N - 1) * incY; 993 if ((int)Y->getType()->getX() != expectedYDim) { 994 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU"); 995 } 996 997} 998 999void ScriptIntrinsicBLAS::CHEMV(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> A, 1000 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) { 1001 // HEMV is the same as SYR2 validation-wise 1002 int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A); 1003 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemv, 1004 0, 0, 0, Uplo, 0, 0, N, 0, 1005 alpha.x, alpha.y, A->getID(), X->getID(), 1006 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 1007} 1008 1009void ScriptIntrinsicBLAS::CHBMV(RsBlasUplo Uplo, int K, Float2 alpha, sp<Allocation> A, 1010 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) { 1011 // HBMV is the same as SYR2 validation-wise 1012 int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A); 1013 if (K < 0) { 1014 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV"); 1015 } 1016 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chbmv, 1017 0, 0, 0, Uplo, 0, 0, N, K, 1018 alpha.x, alpha.y, A->getID(), X->getID(), 1019 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 1020} 1021 1022void ScriptIntrinsicBLAS::CHPMV(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> Ap, 1023 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) { 1024 // HPMV is the same as SPR2 1025 int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap); 1026 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpmv, 1027 0, 0, 0, Uplo, 0, 0, N, 0, 1028 alpha.x, alpha.y, Ap->getID(), X->getID(), 1029 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 1030} 1031 1032void ScriptIntrinsicBLAS::CGERU(Float2 alpha, sp<Allocation> X, int incX, 1033 sp<Allocation> Y, int incY, sp<Allocation> A) { 1034 validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A); 1035 int M = A->getType()->getY(); 1036 int N = A->getType()->getX(); 1037 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgeru, 1038 0, 0, 0, 0, 0, M, N, 0, 1039 alpha.x, alpha.y, X->getID(), Y->getID(), 1040 0, 0, A->getID(), incX, incY, 0, 0); 1041} 1042 1043void ScriptIntrinsicBLAS::CGERC(Float2 alpha, sp<Allocation> X, int incX, 1044 sp<Allocation> Y, int incY, sp<Allocation> A) { 1045 // Same as GERU 1046 validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A); 1047 int M = A->getType()->getY(); 1048 int N = A->getType()->getX(); 1049 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgerc, 1050 0, 0, 0, 0, 0, M, N, 0, 1051 alpha.x, alpha.y, X->getID(), Y->getID(), 1052 0, 0, A->getID(), incX, incY, 0, 0); 1053} 1054 1055void ScriptIntrinsicBLAS::CHER(RsBlasUplo Uplo, float alpha, sp<Allocation> X, 1056 int incX, sp<Allocation> A) { 1057 // Same as SYR 1058 int N = validateSYR(mRS, Element::F32_2(mRS), Uplo, X, incX, A); 1059 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher, 1060 0, 0, 0, Uplo, 0, 0, N, 0, 1061 alpha, 0, X->getID(), 0, 1062 0, 0, A->getID(), incX, 0, 0, 0); 1063} 1064 1065void ScriptIntrinsicBLAS::CHPR(RsBlasUplo Uplo, float alpha, sp<Allocation> X, 1066 int incX, sp<Allocation> Ap) { 1067 // Equivalent to SPR for validation 1068 int N = validateSPR(mRS, Element::F32_2(mRS), Uplo, X, incX, Ap); 1069 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr, 1070 0, 0, 0, Uplo, 0, 0, N, 0, 1071 alpha, 0, X->getID(), 0, 1072 0, 0, Ap->getID(), incX, 0, 0, 0); 1073} 1074 1075void ScriptIntrinsicBLAS::CHER2(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> X, int incX, 1076 sp<Allocation> Y, int incY, sp<Allocation> A) { 1077 // Same as SYR2 1078 int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A); 1079 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2, 1080 0, 0, 0, Uplo, 0, 0, N, 0, 1081 alpha.x, alpha.y, X->getID(), Y->getID(), 1082 0, 0, A->getID(), incX, incY, 0, 0); 1083} 1084 1085void ScriptIntrinsicBLAS::CHPR2(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> X, int incX, 1086 sp<Allocation> Y, int incY, sp<Allocation> Ap) { 1087 // Same as SPR2 1088 int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap); 1089 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr2, 1090 0, 0, 0, Uplo, 0, 0, N, 0, 1091 alpha.x, alpha.y, X->getID(), Y->getID(), 1092 0, 0, Ap->getID(), incX, incY, 0, 0); 1093} 1094 1095void ScriptIntrinsicBLAS::ZHEMV(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> A, 1096 sp<Allocation> X, int incX, Double2 beta, sp<Allocation> Y, int incY) { 1097 // HEMV is the same as SYR2 validation-wise 1098 int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A); 1099 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemv, 1100 0, 0, 0, Uplo, 0, 0, N, 0, 1101 alpha.x, alpha.y, A->getID(), X->getID(), 1102 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 1103} 1104 1105void ScriptIntrinsicBLAS::ZHBMV(RsBlasUplo Uplo, int K, Double2 alpha, sp<Allocation> A, sp<Allocation> X, 1106 int incX, Double2 beta, sp<Allocation> Y, int incY) { 1107 // HBMV is the same as SYR2 validation-wise 1108 int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A); 1109 if (K < 0) { 1110 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV"); 1111 } 1112 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhbmv, 1113 0, 0, 0, Uplo, 0, 0, N, K, 1114 alpha.x, alpha.y, A->getID(), X->getID(), 1115 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 1116} 1117 1118void ScriptIntrinsicBLAS::ZHPMV(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> Ap, sp<Allocation> X, 1119 int incX, Double2 beta, sp<Allocation> Y, int incY) { 1120 // HPMV is the same as SPR2 1121 int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap); 1122 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpmv, 1123 0, 0, 0, Uplo, 0, 0, N, 0, 1124 alpha.x, alpha.y, Ap->getID(), X->getID(), 1125 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 1126} 1127 1128void ScriptIntrinsicBLAS::ZGERU(Double2 alpha, sp<Allocation> X, int incX, 1129 sp<Allocation> Y, int incY, sp<Allocation> A) { 1130 validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A); 1131 int M = A->getType()->getY(); 1132 int N = A->getType()->getX(); 1133 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgeru, 1134 0, 0, 0, 0, 0, M, N, 0, 1135 alpha.x, alpha.y, X->getID(), Y->getID(), 1136 0, 0, A->getID(), incX, incY, 0, 0); 1137} 1138 1139void ScriptIntrinsicBLAS::ZGERC(Double2 alpha, sp<Allocation> X, int incX, 1140 sp<Allocation> Y, int incY, sp<Allocation> A) { 1141 // Same as GERU 1142 validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A); 1143 int M = A->getType()->getY(); 1144 int N = A->getType()->getX(); 1145 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgerc, 1146 0, 0, 0, 0, 0, M, N, 0, 1147 alpha.x, alpha.y, X->getID(), Y->getID(), 1148 0, 0, A->getID(), incX, incY, 0, 0); 1149} 1150 1151void ScriptIntrinsicBLAS::ZHER(RsBlasUplo Uplo, double alpha, sp<Allocation> X, 1152 int incX, sp<Allocation> A) { 1153 // Same as SYR 1154 int N = validateSYR(mRS, Element::F64_2(mRS), Uplo, X, incX, A); 1155 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher, 1156 0, 0, 0, Uplo, 0, 0, N, 0, 1157 alpha, 0, X->getID(), 0, 1158 0, 0, A->getID(), incX, 0, 0, 0); 1159} 1160 1161void ScriptIntrinsicBLAS::ZHPR(RsBlasUplo Uplo, double alpha, sp<Allocation> X, 1162 int incX, sp<Allocation> Ap) { 1163 // Equivalent to SPR for validation 1164 int N = validateSPR(mRS, Element::F64_2(mRS), Uplo, X, incX, Ap); 1165 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr, 1166 0, 0, 0, Uplo, 0, 0, N, 0, 1167 alpha, 0, X->getID(), 0, 1168 0, 0, Ap->getID(), incX, 0, 0, 0); 1169} 1170 1171void ScriptIntrinsicBLAS::ZHER2(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> X, int incX, 1172 sp<Allocation> Y, int incY, sp<Allocation> A) { 1173 // Same as SYR2 1174 int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A); 1175 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2, 1176 0, 0, 0, Uplo, 0, 0, N, 0, 1177 alpha.x, alpha.y, X->getID(), Y->getID(), 1178 0, 0, A->getID(), incX, incY, 0, 0); 1179} 1180 1181void ScriptIntrinsicBLAS::ZHPR2(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> X, int incX, 1182 sp<Allocation> Y, int incY, sp<Allocation> Ap) { 1183 // Same as SPR2 1184 int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap); 1185 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr2, 1186 0, 0, 0, Uplo, 0, 0, N, 0, 1187 alpha.x, alpha.y, X->getID(), Y->getID(), 1188 0, 0, Ap->getID(), incX, incY, 0, 0); 1189} 1190 1191 1192/** 1193 * Level 3 BLAS 1194 */ 1195 1196static void validateL3(RS* mRS, sp<const Element> e, int TransA, int TransB, int Side, 1197 sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) { 1198 int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1; 1199 if ((A != nullptr && !A->getType()->getElement()->isCompatible(e)) || 1200 (B != nullptr && !B->getType()->getElement()->isCompatible(e)) || 1201 (C != nullptr && !C->getType()->getElement()->isCompatible(e))) { 1202 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1203 } 1204 if (C == nullptr) { 1205 // Since matrix C is used to store the result, it cannot be null. 1206 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Allocation C cannot be null"); 1207 } 1208 cM = C->getType()->getY(); 1209 cN = C->getType()->getX(); 1210 1211 if (Side == RsBlasRight) { 1212 if ((A == nullptr && B != nullptr) || (A != nullptr && B == nullptr)) { 1213 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Provided Matrix A without Matrix B, or vice versa"); 1214 } 1215 if (B != nullptr) { 1216 bM = A->getType()->getY(); 1217 bN = A->getType()->getX(); 1218 } 1219 if (A != nullptr) { 1220 aM = B->getType()->getY(); 1221 aN = B->getType()->getX(); 1222 } 1223 } else { 1224 if (A != nullptr) { 1225 if (TransA == RsBlasTrans || TransA == RsBlasConjTrans) { 1226 aN = A->getType()->getY(); 1227 aM = A->getType()->getX(); 1228 } else { 1229 aM = A->getType()->getY(); 1230 aN = A->getType()->getX(); 1231 } 1232 } 1233 if (B != nullptr) { 1234 if (TransB == RsBlasTrans || TransB == RsBlasConjTrans) { 1235 bN = B->getType()->getY(); 1236 bM = B->getType()->getX(); 1237 } else { 1238 bM = B->getType()->getY(); 1239 bN = B->getType()->getX(); 1240 } 1241 } 1242 } 1243 if (A != nullptr && B != nullptr && C != nullptr) { 1244 if (aN != bM || aM != cM || bN != cN) { 1245 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions"); 1246 } 1247 } else if (A != nullptr && C != nullptr) { 1248 // A and C only, for SYRK 1249 if (cM != cN) { 1250 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix C is not symmetric"); 1251 } 1252 if (aM != cM) { 1253 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions"); 1254 } 1255 } else if (A != nullptr && B != nullptr) { 1256 // A and B only 1257 if (aN != bM) { 1258 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions"); 1259 } 1260 } 1261 1262} 1263 1264void ScriptIntrinsicBLAS::SGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, float alpha, 1265 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) { 1266 validateL3(mRS, Element::F32(mRS), TransA, TransB, 0, A, B, C); 1267 1268 int M = -1, N = -1, K = -1; 1269 if (TransA != RsBlasNoTrans) { 1270 M = A->getType()->getX(); 1271 K = A->getType()->getY(); 1272 } else { 1273 M = A->getType()->getY(); 1274 K = A->getType()->getX(); 1275 } 1276 if (TransB != RsBlasNoTrans) { 1277 N = B->getType()->getY(); 1278 } else { 1279 N = B->getType()->getX(); 1280 } 1281 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemm, 1282 TransA, TransB, 0, 0, 0, M, N, K, 1283 alpha, A->getID(), B->getID(), 1284 beta, C->getID(), 0, 0, 0, 0); 1285} 1286 1287void ScriptIntrinsicBLAS::DGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, double alpha, 1288 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) { 1289 validateL3(mRS, Element::F64(mRS), TransA, TransB, 0, A, B, C); 1290 int M = -1, N = -1, K = -1; 1291 if (TransA != RsBlasNoTrans) { 1292 M = A->getType()->getX(); 1293 K = A->getType()->getY(); 1294 } else { 1295 M = A->getType()->getY(); 1296 K = A->getType()->getX(); 1297 } 1298 if (TransB != RsBlasNoTrans) { 1299 N = B->getType()->getY(); 1300 } else { 1301 N = B->getType()->getX(); 1302 } 1303 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemm, 1304 TransA, TransB, 0, 0, 0, M, N, K, 1305 alpha, A->getID(), B->getID(), 1306 beta, C->getID(), 0, 0, 0, 0); 1307} 1308 1309void ScriptIntrinsicBLAS::CGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Float2 alpha, 1310 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) { 1311 validateL3(mRS, Element::F32_2(mRS), TransA, TransB, 0, A, B, C); 1312 int M = -1, N = -1, K = -1; 1313 if (TransA != RsBlasNoTrans) { 1314 M = A->getType()->getX(); 1315 K = A->getType()->getY(); 1316 } else { 1317 M = A->getType()->getY(); 1318 K = A->getType()->getX(); 1319 } 1320 if (TransB != RsBlasNoTrans) { 1321 N = B->getType()->getY(); 1322 } else { 1323 N = B->getType()->getX(); 1324 } 1325 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemm, 1326 TransA, TransB, 0, 0, 0, M, N, K, 1327 alpha.x, alpha.y, A->getID(), B->getID(), 1328 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1329} 1330 1331void ScriptIntrinsicBLAS::ZGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Double2 alpha, 1332 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) { 1333 validateL3(mRS, Element::F64_2(mRS), TransA, TransB, 0, A, B, C); 1334 int M = -1, N = -1, K = -1; 1335 if (TransA != RsBlasNoTrans) { 1336 M = A->getType()->getX(); 1337 K = A->getType()->getY(); 1338 } else { 1339 M = A->getType()->getY(); 1340 K = A->getType()->getX(); 1341 } 1342 if (TransB != RsBlasNoTrans) { 1343 N = B->getType()->getY(); 1344 } else { 1345 N = B->getType()->getX(); 1346 } 1347 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemm, 1348 TransA, TransB, 0, 0, 0, M, N, K, 1349 alpha.x, alpha.y, A->getID(), B->getID(), 1350 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1351} 1352 1353void ScriptIntrinsicBLAS::SSYMM(RsBlasSide Side, RsBlasUplo Uplo, float alpha, 1354 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) { 1355 //For SYMM, Matrix A should be symmetric 1356 if (A->getType()->getX() != A->getType()->getY()) { 1357 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); 1358 } 1359 validateL3(mRS, Element::F32(mRS), 0, 0, Side, A, B, C); 1360 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymm, 1361 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, 1362 alpha, A->getID(), B->getID(), 1363 beta, C->getID(), 0, 0, 0, 0); 1364} 1365 1366void ScriptIntrinsicBLAS::DSYMM(RsBlasSide Side, RsBlasUplo Uplo, double alpha, 1367 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) { 1368 if (A->getType()->getX() != A->getType()->getY()) { 1369 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); 1370 } 1371 validateL3(mRS, Element::F64(mRS), 0, 0, Side, A, B, C); 1372 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymm, 1373 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, 1374 alpha, A->getID(), B->getID(), 1375 beta, C->getID(), 0, 0, 0, 0); 1376} 1377 1378void ScriptIntrinsicBLAS::CSYMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha, 1379 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) { 1380 if (A->getType()->getX() != A->getType()->getY()) { 1381 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); 1382 } 1383 validateL3(mRS, Element::F32_2(mRS), 0, 0, Side, A, B, C); 1384 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csymm, 1385 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, 1386 alpha.x, alpha.y, A->getID(), B->getID(), 1387 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1388} 1389 1390void ScriptIntrinsicBLAS::ZSYMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha, 1391 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) { 1392 if (A->getType()->getX() != A->getType()->getY()) { 1393 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); 1394 } 1395 validateL3(mRS, Element::F64_2(mRS), 0, 0, Side, A, B, C); 1396 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsymm, 1397 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, 1398 alpha.x, alpha.y, A->getID(), B->getID(), 1399 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1400} 1401 1402void ScriptIntrinsicBLAS::SSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha, 1403 sp<Allocation> A, float beta, sp<Allocation> C) { 1404 validateL3(mRS, Element::F32(mRS), Trans, 0, 0, A, nullptr, C); 1405 int K = -1; 1406 if (Trans != RsBlasNoTrans) { 1407 K = A->getType()->getY(); 1408 } else { 1409 K = A->getType()->getX(); 1410 } 1411 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyrk, 1412 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1413 alpha, A->getID(), 0, 1414 beta, C->getID(), 0, 0, 0, 0); 1415} 1416 1417void ScriptIntrinsicBLAS::DSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha, 1418 sp<Allocation> A, double beta, sp<Allocation> C) { 1419 validateL3(mRS, Element::F64(mRS), Trans, 0, 0, A, nullptr, C); 1420 int K = -1; 1421 if (Trans != RsBlasNoTrans) { 1422 K = A->getType()->getY(); 1423 } else { 1424 K = A->getType()->getX(); 1425 } 1426 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyrk, 1427 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1428 alpha, A->getID(), 0, 1429 beta, C->getID(), 0, 0, 0, 0); 1430} 1431 1432void ScriptIntrinsicBLAS::CSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha, 1433 sp<Allocation> A, Float2 beta, sp<Allocation> C) { 1434 validateL3(mRS, Element::F32_2(mRS), Trans, 0, 0, A, nullptr, C); 1435 int K = -1; 1436 if (Trans != RsBlasNoTrans) { 1437 K = A->getType()->getY(); 1438 } else { 1439 K = A->getType()->getX(); 1440 } 1441 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyrk, 1442 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1443 alpha.x, alpha.y, A->getID(), 0, 1444 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1445} 1446 1447void ScriptIntrinsicBLAS::ZSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha, 1448 sp<Allocation> A, Double2 beta, sp<Allocation> C) { 1449 validateL3(mRS, Element::F64_2(mRS), Trans, 0, 0, A, nullptr, C); 1450 int K = -1; 1451 if (Trans != RsBlasNoTrans) { 1452 K = A->getType()->getY(); 1453 } else { 1454 K = A->getType()->getX(); 1455 } 1456 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyrk, 1457 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1458 alpha.x, alpha.y, A->getID(), 0, 1459 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1460} 1461 1462static void validateSYR2K(RS* mRS, sp<const Element> e, RsBlasTranspose Trans, 1463 sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) { 1464 if (!A->getType()->getElement()->isCompatible(e) || 1465 !B->getType()->getElement()->isCompatible(e) || 1466 !C->getType()->getElement()->isCompatible(e)) { 1467 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1468 } 1469 int Cdim = -1; 1470 // A is n x k if no transpose, k x n if transpose 1471 // C is n x n 1472 if (Trans == RsBlasTrans) { 1473 // check columns versus C 1474 Cdim = A->getType()->getX(); 1475 } else { 1476 // check rows versus C 1477 Cdim = A->getType()->getY(); 1478 } 1479 if ((int)C->getType()->getX() != Cdim || (int)C->getType()->getY() != Cdim) { 1480 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid symmetric matrix in SYR2K"); 1481 } 1482 // A dims == B dims 1483 if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) { 1484 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid A and B in SYR2K"); 1485 } 1486} 1487 1488void ScriptIntrinsicBLAS::SSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha, 1489 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) { 1490 validateSYR2K(mRS, Element::F32(mRS), Trans, A, B, C); 1491 int K = -1; 1492 if (Trans != RsBlasNoTrans) { 1493 K = A->getType()->getY(); 1494 } else { 1495 K = A->getType()->getX(); 1496 } 1497 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2k, 1498 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1499 alpha, A->getID(), B->getID(), 1500 beta, C->getID(), 0, 0, 0, 0); 1501} 1502 1503void ScriptIntrinsicBLAS::DSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha, 1504 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) { 1505 validateSYR2K(mRS, Element::F64(mRS), Trans, A, B, C); 1506 int K = -1; 1507 if (Trans != RsBlasNoTrans) { 1508 K = A->getType()->getY(); 1509 } else { 1510 K = A->getType()->getX(); 1511 } 1512 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2k, 1513 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1514 alpha, A->getID(), B->getID(), 1515 beta, C->getID(), 0, 0, 0, 0); 1516} 1517 1518void ScriptIntrinsicBLAS::CSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha, 1519 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) { 1520 validateSYR2K(mRS, Element::F32_2(mRS), Trans, A, B, C); 1521 int K = -1; 1522 if (Trans != RsBlasNoTrans) { 1523 K = A->getType()->getY(); 1524 } else { 1525 K = A->getType()->getX(); 1526 } 1527 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyr2k, 1528 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1529 alpha.x, alpha.y, A->getID(), B->getID(), 1530 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1531} 1532 1533void ScriptIntrinsicBLAS::ZSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha, 1534 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) { 1535 validateSYR2K(mRS, Element::F64_2(mRS), Trans, A, B, C); 1536 int K = -1; 1537 if (Trans != RsBlasNoTrans) { 1538 K = A->getType()->getY(); 1539 } else { 1540 K = A->getType()->getX(); 1541 } 1542 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyr2k, 1543 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1544 alpha.x, alpha.y, A->getID(), B->getID(), 1545 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1546} 1547 1548static void validateTRMM(RS* mRS, sp<const Element> e, RsBlasSide Side, RsBlasTranspose TransA, 1549 sp<Allocation> A, sp<Allocation> B) { 1550 int aM = -1, aN = -1, bM = -1, bN = -1; 1551 if (!A->getType()->getElement()->isCompatible(e) || 1552 !B->getType()->getElement()->isCompatible(e)) { 1553 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1554 } 1555 1556 aM = A->getType()->getY(); 1557 aN = A->getType()->getX(); 1558 if (aM != aN) { 1559 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with a non-symmetric matrix A"); 1560 } 1561 1562 bM = B->getType()->getY(); 1563 bN = B->getType()->getX(); 1564 if (Side == RsBlasLeft) { 1565 if (aN != bM) { 1566 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices"); 1567 } 1568 } else { 1569 if (bN != aM) { 1570 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices"); 1571 } 1572 } 1573} 1574 1575void ScriptIntrinsicBLAS::STRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1576 float alpha, sp<Allocation> A, sp<Allocation> B) { 1577 validateTRMM(mRS, Element::F32(mRS), Side, TransA, A, B); 1578 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmm, 1579 TransA, 0, Side, Uplo, Diag,\ 1580 B->getType()->getY(), B->getType()->getX(), 0, 1581 alpha, A->getID(), B->getID(), 0.f, 0, 0, 0, 0, 0); 1582} 1583 1584void ScriptIntrinsicBLAS::DTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1585 double alpha, sp<Allocation> A, sp<Allocation> B) { 1586 validateTRMM(mRS, Element::F64(mRS), Side, TransA, A, B); 1587 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmm, 1588 TransA, 0, Side, Uplo, Diag, 1589 B->getType()->getY(), B->getType()->getX(), 0, 1590 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0); 1591} 1592 1593void ScriptIntrinsicBLAS::CTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1594 Float2 alpha, sp<Allocation> A, sp<Allocation> B) { 1595 validateTRMM(mRS, Element::F32_2(mRS), Side, TransA, A, B); 1596 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmm, 1597 TransA, 0, Side, Uplo, Diag, 1598 B->getType()->getY(), B->getType()->getX(), 0, 1599 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); 1600} 1601 1602void ScriptIntrinsicBLAS::ZTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1603 Double2 alpha, sp<Allocation> A, sp<Allocation> B) { 1604 validateTRMM(mRS, Element::F64_2(mRS), Side, TransA, A, B); 1605 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmm, 1606 TransA, 0, Side, Uplo, Diag, 1607 B->getType()->getY(), B->getType()->getX(), 0, 1608 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); 1609} 1610 1611static void validateTRSM(RS* mRS, sp<const Element> e, RsBlasSide Side, RsBlasTranspose TransA, 1612 sp<Allocation> A, sp<Allocation> B) { 1613 int adim = -1, bM = -1, bN = -1; 1614 if (!A->getType()->getElement()->isCompatible(e) || 1615 !B->getType()->getElement()->isCompatible(e)) { 1616 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1617 } 1618 adim = A->getType()->getX(); 1619 if (adim != (int)A->getType()->getY()) { 1620 // This may be unnecessary, the restriction could potentially be relaxed. 1621 // Allocation A needs to contain at least that symmetric matrix but could theoretically 1622 // be larger for now we assume adapters are sufficient, will reevaluate in the future. 1623 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with a non-symmetric matrix A"); 1624 } 1625 bM = B->getType()->getY(); 1626 bN = B->getType()->getX(); 1627 if (Side == RsBlasLeft) { 1628 // A is M*M 1629 if (adim != bM) { 1630 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions"); 1631 } 1632 } else { 1633 // A is N*N 1634 if (adim != bN) { 1635 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions"); 1636 } 1637 } 1638} 1639 1640void ScriptIntrinsicBLAS::STRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1641 float alpha, sp<Allocation> A, sp<Allocation> B) { 1642 validateTRSM(mRS, Element::F32(mRS), Side, TransA, A, B); 1643 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsm, 1644 TransA, 0, Side, Uplo, Diag, 1645 B->getType()->getY(), B->getType()->getX(), 0, 1646 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0); 1647} 1648 1649void ScriptIntrinsicBLAS::DTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1650 double alpha, sp<Allocation> A, sp<Allocation> B) { 1651 validateTRSM(mRS, Element::F64(mRS), Side, TransA, A, B); 1652 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsm, 1653 TransA, 0, Side, Uplo, Diag, 1654 B->getType()->getY(), B->getType()->getX(), 0, 1655 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0); 1656} 1657 1658void ScriptIntrinsicBLAS::CTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1659 Float2 alpha, sp<Allocation> A, sp<Allocation> B) { 1660 validateTRSM(mRS, Element::F32_2(mRS), Side, TransA, A, B); 1661 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsm, 1662 TransA, 0, Side, Uplo, Diag, 1663 B->getType()->getY(), B->getType()->getX(), 0, 1664 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); 1665} 1666 1667void ScriptIntrinsicBLAS::ZTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1668 Double2 alpha, sp<Allocation> A, sp<Allocation> B) { 1669 validateTRSM(mRS, Element::F64_2(mRS), Side, TransA, A, B); 1670 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsm, 1671 TransA, 0, Side, Uplo, Diag, 1672 B->getType()->getY(), B->getType()->getX(), 0, 1673 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); 1674} 1675 1676static void validateHEMM(RS* mRS, sp<const Element> e, RsBlasSide Side, 1677 sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) { 1678 if (!A->getType()->getElement()->isCompatible(e) || 1679 !B->getType()->getElement()->isCompatible(e) || 1680 !C->getType()->getElement()->isCompatible(e)) { 1681 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1682 } 1683 1684 // A must be square; can potentially be relaxed similar to TRSM 1685 int adim = A->getType()->getX(); 1686 if (adim != (int)A->getType()->getY()) { 1687 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with non-square A"); 1688 } 1689 if ((Side == RsBlasLeft && adim != (int)B->getType()->getY()) || 1690 (Side == RsBlasRight && adim != (int)B->getType()->getX())) { 1691 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with invalid B"); 1692 } 1693 if (B->getType()->getX() != C->getType()->getX() || 1694 B->getType()->getY() != C->getType()->getY()) { 1695 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with mismatched B and C"); 1696 } 1697} 1698 1699void ScriptIntrinsicBLAS::CHEMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha, 1700 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) { 1701 validateHEMM(mRS, Element::F32_2(mRS), Side, A, B, C); 1702 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemm, 1703 0, 0, Side, Uplo, 0, 1704 C->getType()->getY(), C->getType()->getX(), 0, 1705 alpha.x, alpha.y, A->getID(), B->getID(), 1706 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1707} 1708 1709void ScriptIntrinsicBLAS::ZHEMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha, 1710 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) { 1711 validateHEMM(mRS, Element::F64_2(mRS), Side, A, B, C); 1712 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemm, 1713 0, 0, Side, Uplo, 0, 1714 C->getType()->getY(), C->getType()->getX(), 0, 1715 alpha.x, alpha.y, A->getID(), B->getID(), 1716 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1717} 1718 1719static void validateHERK(RS* mRS, sp<const Element> e, RsBlasTranspose Trans, 1720 sp<Allocation> A, sp<Allocation> C) { 1721 if (!A->getType()->getElement()->isCompatible(e) || 1722 !C->getType()->getElement()->isCompatible(e)) { 1723 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1724 } 1725 if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) { 1726 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose"); 1727 } 1728 int cdim = C->getType()->getX(); 1729 if (cdim != (int)C->getType()->getY()) { 1730 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with non-square C"); 1731 } 1732 if (Trans == RsBlasNoTrans) { 1733 if (cdim != (int)A->getType()->getY()) { 1734 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A"); 1735 } 1736 } else { 1737 if (cdim != (int)A->getType()->getX()) { 1738 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A"); 1739 } 1740 } 1741} 1742 1743void ScriptIntrinsicBLAS::CHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha, 1744 sp<Allocation> A, float beta, sp<Allocation> C) { 1745 validateHERK(mRS, Element::F32_2(mRS), Trans, A, C); 1746 int k = 0; 1747 if (Trans == RsBlasConjTrans) { 1748 k = A->getType()->getY(); 1749 } else { 1750 k = A->getType()->getX(); 1751 } 1752 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cherk, 1753 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, 1754 alpha, 0, A->getID(), 0, 1755 beta, 0, C->getID(), 0, 0, 0, 0); 1756} 1757 1758void ScriptIntrinsicBLAS::ZHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha, 1759 sp<Allocation> A, double beta, sp<Allocation> C) { 1760 validateHERK(mRS, Element::F64_2(mRS), Trans, A, C); 1761 int k = 0; 1762 if (Trans == RsBlasConjTrans) { 1763 k = A->getType()->getY(); 1764 } else { 1765 k = A->getType()->getX(); 1766 } 1767 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zherk, 1768 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, 1769 alpha, 0, A->getID(), 0, 1770 beta, 0, C->getID(), 0, 0, 0, 0); 1771} 1772 1773static void validateHER2K(RS* mRS, sp<const Element> e, RsBlasTranspose Trans, 1774 sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) { 1775 if (!A->getType()->getElement()->isCompatible(e) || 1776 !B->getType()->getElement()->isCompatible(e) || 1777 !C->getType()->getElement()->isCompatible(e)) { 1778 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1779 } 1780 if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) { 1781 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose"); 1782 } 1783 int cdim = C->getType()->getX(); 1784 if (cdim != (int)C->getType()->getY()) { 1785 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with non-square C"); 1786 } 1787 if (Trans == RsBlasNoTrans) { 1788 if ((int)A->getType()->getY() != cdim) { 1789 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices"); 1790 } 1791 } else { 1792 if ((int)A->getType()->getX() != cdim) { 1793 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices"); 1794 } 1795 } 1796 if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) { 1797 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid A and B matrices"); 1798 } 1799} 1800 1801void ScriptIntrinsicBLAS::CHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha, 1802 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) { 1803 validateHER2K(mRS, Element::F32_2(mRS), Trans, A, B, C); 1804 int k = 0; 1805 if (Trans == RsBlasNoTrans) { 1806 k = A->getType()->getX(); 1807 } else { 1808 k = A->getType()->getY(); 1809 } 1810 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2k, 1811 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, 1812 alpha.x, alpha.y, A->getID(), B->getID(), 1813 beta, 0, C->getID(), 0, 0, 0, 0); 1814} 1815 1816void ScriptIntrinsicBLAS::ZHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha, 1817 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) { 1818 validateHER2K(mRS, Element::F64_2(mRS), Trans, A, B, C); 1819 int k = 0; 1820 if (Trans == RsBlasNoTrans) { 1821 k = A->getType()->getX(); 1822 } else { 1823 k = A->getType()->getY(); 1824 } 1825 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2k, 1826 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, 1827 alpha.x, alpha.y, A->getID(), B->getID(), 1828 beta, 0, C->getID(), 0, 0, 0, 0); 1829} 1830 1831 1832 1833void ScriptIntrinsicBLAS::BNNM(sp<Allocation> A, int a_offset, sp<Allocation> B, int b_offset, 1834 sp<Allocation> C, int c_offset, int c_mult) { 1835 validateL3(mRS, Element::U8(mRS), RsBlasNoTrans, RsBlasTrans, 0, A, B, C); 1836 1837 if (a_offset < 0 || a_offset > 255) { 1838 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid a_offset passed to BNNM"); 1839 } 1840 if (b_offset < 0 || b_offset > 255) { 1841 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid b_offset passed to BNNM"); 1842 } 1843 int M = -1, N = -1, K = -1; 1844 M = A->getType()->getY(); 1845 N = B->getType()->getY(); 1846 K = A->getType()->getX(); 1847 1848 nScriptIntrinsicBLAS_BNNM(mRS, mRS->getContext(), getID(), M, N, K, A->getID(), a_offset, 1849 B->getID(), b_offset, C->getID(), c_offset, c_mult); 1850} 1851