1/* 2 * Copyright (C) 2015 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 18#include "RenderScript.h" 19#include "rsCppInternal.h" 20 21#define NELEM(m) (sizeof(m) / sizeof((m)[0])) 22 23using android::RSC::Allocation; 24using android::RSC::Element; 25using android::RSC::RS; 26using android::RSC::RS_ERROR_INVALID_ELEMENT; 27using android::RSC::RS_ERROR_INVALID_PARAMETER; 28using android::RSC::RS_SUCCESS; 29using android::RSC::ScriptIntrinsicBLAS; 30using android::RSC::sp; 31 32// ScriptIntrinsicBLAS APIS 33ScriptIntrinsicBLAS::ScriptIntrinsicBLAS(sp<RS> rs, sp<const Element> e) 34 : ScriptIntrinsic(rs, RS_SCRIPT_INTRINSIC_ID_BLAS, e) { 35 36} 37 38sp<ScriptIntrinsicBLAS> ScriptIntrinsicBLAS::create(const sp<RS>& rs) { 39 return new ScriptIntrinsicBLAS(rs, Element::U32(rs)); 40} 41 42enum RsBlasDataType { 43 SINGLE, 44 DOUBLE, 45 SINGLE_COMPLEX, 46 DOUBLE_COMPLEX 47}; 48 49static RsBlasCall 50setUpBLASCall(RsBlasDataType dataType, RsBlasFunction func, 51 int TransA, int TransB, int Side, int Uplo, int Diag, 52 int M, int N, int K, int incX, int incY, int KL, int KU, 53 float alphaF, float betaF, double alphaD, double betaD, 54 float alphaCX, float alphaCY, float betaCX, float betaCY, 55 double alphaZX, double alphaZY, double betaZX, double betaZY 56 ) { 57 RsBlasCall call; 58 memset(&call, 0, sizeof(call)); 59 call.func = func; 60 call.transA = (RsBlasTranspose)TransA; 61 call.transB = (RsBlasTranspose)TransB; 62 call.side = (RsBlasSide)Side; 63 call.uplo = (RsBlasUplo)Uplo; 64 call.diag = (RsBlasDiag)Diag; 65 call.M = M; 66 call.N = N; 67 call.K = K; 68 69 switch (dataType) { 70 case SINGLE: 71 // For Single-precision BLAS. 72 call.alpha.f = alphaF; 73 call.beta.f = betaF; 74 break; 75 case DOUBLE: 76 // For Double-precision BLAS. 77 call.alpha.d = alphaD; 78 call.beta.d = betaD; 79 break; 80 case SINGLE_COMPLEX: 81 // For Single-precision complex BLAS. 82 call.alpha.c.r = alphaCX; 83 call.alpha.c.i = alphaCY; 84 call.beta.c.r = betaCX; 85 call.beta.c.i = betaCY; 86 break; 87 case DOUBLE_COMPLEX: 88 // For Double-precision complex BLAS. 89 call.alpha.z.r = alphaZX; 90 call.alpha.z.i = alphaZY; 91 call.beta.z.r = betaZX; 92 call.beta.z.i = betaZY; 93 break; 94 default: 95 break; 96 } 97 98 call.incX = incX; 99 call.incY = incY; 100 call.KL = KL; 101 call.KU = KU; 102 103 return call; 104} 105 106static void 107nScriptIntrinsicBLAS_Single(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, 108 int TransB, int Side, int Uplo, int Diag, int M, int N, int K, 109 float alpha, RsAllocation A, RsAllocation B, 110 float beta, RsAllocation C, int incX, int incY, int KL, int KU) { 111 RsBlasCall call = setUpBLASCall(SINGLE, func, TransA, TransB, Side, Uplo, Diag, 112 M, N, K, incX, incY, KL, KU, alpha, beta, 0.0, 0.0, 113 0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0); 114 RsAllocation in_allocs[3] = {A, B, C}; 115 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr, 116 &call, sizeof(call), nullptr, 0)); 117} 118 119 120static void 121nScriptIntrinsicBLAS_Double(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, 122 int TransB, int Side, int Uplo, int Diag, int M, int N, int K, 123 double alpha, RsAllocation A, RsAllocation B, 124 double beta, RsAllocation C, int incX, int incY, int KL, int KU) { 125 RsBlasCall call = setUpBLASCall(DOUBLE, func, TransA, TransB, Side, Uplo, Diag, 126 M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, alpha, beta, 127 0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0); 128 RsAllocation in_allocs[3] = {A, B, C}; 129 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr, 130 &call, sizeof(call), nullptr, 0)); 131} 132 133static void 134nScriptIntrinsicBLAS_Complex(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, 135 int TransB, int Side, int Uplo, int Diag, int M, int N, int K, 136 float alphaX, float alphaY, RsAllocation A, RsAllocation B, 137 float betaX, float betaY, RsAllocation C, int incX, int incY, int KL, int KU) { 138 RsBlasCall call = setUpBLASCall(SINGLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag, 139 M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0, 140 alphaX, alphaY, betaX, betaY, 0.0, 0.0, 0.0, 0.0); 141 RsAllocation in_allocs[3] = {A, B, C}; 142 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr, 143 &call, sizeof(call), nullptr, 0)); 144} 145 146static void 147nScriptIntrinsicBLAS_Z(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, 148 int TransB, int Side, int Uplo, int Diag, int M, int N, int K, 149 double alphaX, double alphaY, RsAllocation A, RsAllocation B, 150 double betaX, double betaY, RsAllocation C, int incX, int incY, int KL, int KU) { 151 RsBlasCall call = setUpBLASCall(DOUBLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag, 152 M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0, 153 0.0f, 0.0f, 0.0f, 0.0f, alphaX, alphaY, betaX, betaY); 154 RsAllocation in_allocs[3] = {A, B, C}; 155 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr, 156 &call, sizeof(call), nullptr, 0)); 157} 158 159 160static void 161nScriptIntrinsicBLAS_BNNM(RS* mRS, RsContext con, RsScript id, int M, int N, int K, 162 RsAllocation A, int a_offset, RsAllocation B, int b_offset, 163 RsAllocation C, int c_offset, int c_mult_int) { 164 RsBlasCall call; 165 memset(&call, 0, sizeof(call)); 166 call.func = RsBlas_bnnm; 167 call.M = M; 168 call.N = N; 169 call.K = K; 170 call.a_offset = a_offset & 0xFF; 171 call.b_offset = b_offset & 0xFF; 172 call.c_offset = c_offset; 173 call.c_mult_int = c_mult_int; 174 175 RsAllocation in_allocs[3] = {A, B, C}; 176 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr, 177 &call, sizeof(call), nullptr, 0)); 178} 179 180/** 181 * Level 2 BLAS 182 */ 183static void validateGEMV(RS* mRS, const sp<const Element>& e, RsBlasTranspose TransA, const sp<Allocation>& A, 184 const sp<Allocation>& X, int incX, const sp<Allocation>& Y, int incY) { 185 int M = A->getType()->getY(); 186 int N = A->getType()->getX(); 187 if (!A->getType()->getElement()->isCompatible(e) || 188 !X->getType()->getElement()->isCompatible(e) || 189 !Y->getType()->getElement()->isCompatible(e)) { 190 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 191 } 192 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 193 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 194 } 195 196 if (incX <= 0 || incY <= 0) { 197 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 198 } 199 int expectedXDim = -1, expectedYDim = -1; 200 if (TransA == RsBlasNoTrans) { 201 expectedXDim = 1 + (N - 1) * incX; 202 expectedYDim = 1 + (M - 1) * incY; 203 } else { 204 expectedXDim = 1 + (M - 1) * incX; 205 expectedYDim = 1 + (N - 1) * incY; 206 } 207 if ((int)X->getType()->getX() != expectedXDim || 208 (int)Y->getType()->getX() != expectedYDim) { 209 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GEMV"); 210 } 211} 212 213void ScriptIntrinsicBLAS::SGEMV(RsBlasTranspose TransA, float alpha, const sp<Allocation>& A, const sp<Allocation>& X, 214 int incX, float beta, const sp<Allocation>& Y, int incY) { 215 validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY); 216 int M = A->getType()->getY(); 217 int N = A->getType()->getX(); 218 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemv, 219 TransA, 0, 0, 0, 0, M, N, 0, 220 alpha, A->getID(), X->getID(), 221 beta, Y->getID(), incX, incY, 0, 0); 222} 223 224void ScriptIntrinsicBLAS::DGEMV(RsBlasTranspose TransA, double alpha, const sp<Allocation>& A, const sp<Allocation>& X, 225 int incX, double beta, const sp<Allocation>& Y, int incY) { 226 validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY); 227 int M = A->getType()->getY(); 228 int N = A->getType()->getX(); 229 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemv, 230 TransA, 0, 0, 0, 0, M, N, 0, 231 alpha, A->getID(), X->getID(), 232 beta, Y->getID(), incX, incY, 0, 0); 233} 234 235void ScriptIntrinsicBLAS::CGEMV(RsBlasTranspose TransA, Float2 alpha, const sp<Allocation>& A, const sp<Allocation>& X, 236 int incX, Float2 beta, const sp<Allocation>& Y, int incY) { 237 validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY); 238 int M = A->getType()->getY(); 239 int N = A->getType()->getX(); 240 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemv, 241 TransA, 0, 0, 0, 0, M, N, 0, 242 alpha.x, alpha.y, A->getID(), X->getID(), 243 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 244} 245 246void ScriptIntrinsicBLAS::ZGEMV(RsBlasTranspose TransA, Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& X, 247 int incX, Double2 beta, const sp<Allocation>& Y, int incY) { 248 validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY); 249 int M = A->getType()->getY(); 250 int N = A->getType()->getX(); 251 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemv, 252 TransA, 0, 0, 0, 0, M, N, 0, 253 alpha.x, alpha.y, A->getID(), X->getID(), 254 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 255} 256 257void ScriptIntrinsicBLAS::SGBMV(RsBlasTranspose TransA, int KL, int KU, float alpha, const sp<Allocation>& A, 258 const sp<Allocation>& X, int incX, float beta, const sp<Allocation>& Y, int incY) { 259 // GBMV has the same validation requirements as GEMV + KL and KU >= 0 260 validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY); 261 if (KL < 0 || KU < 0) { 262 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); 263 } 264 int M = A->getType()->getY(); 265 int N = A->getType()->getX(); 266 267 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgbmv, 268 TransA, 0, 0, 0, 0, M, N, 0, 269 alpha, A->getID(), X->getID(), 270 beta, Y->getID(), incX, incY, KL, KU); 271} 272 273void ScriptIntrinsicBLAS::DGBMV(RsBlasTranspose TransA, int KL, int KU, double alpha, const sp<Allocation>& A, 274 const sp<Allocation>& X, int incX, double beta, const sp<Allocation>& Y, int incY) { 275 // GBMV has the same validation requirements as GEMV + KL and KU >= 0 276 validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY); 277 if (KL < 0 || KU < 0) { 278 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); 279 } 280 int M = A->getType()->getY(); 281 int N = A->getType()->getX(); 282 283 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgbmv, 284 TransA, 0, 0, 0, 0, M, N, 0, 285 alpha, A->getID(), X->getID(), 286 beta, Y->getID(), incX, incY, KL, KU); 287} 288 289void ScriptIntrinsicBLAS::CGBMV(RsBlasTranspose TransA, int KL, int KU, Float2 alpha, const sp<Allocation>& A, 290 const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) { 291 // GBMV has the same validation requirements as GEMV + KL and KU >= 0 292 validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY); 293 if (KL < 0 || KU < 0) { 294 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); 295 } 296 int M = A->getType()->getY(); 297 int N = A->getType()->getX(); 298 299 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgbmv, 300 TransA, 0, 0, 0, 0, M, N, 0, 301 alpha.x, alpha.y, A->getID(), X->getID(), 302 beta.x, beta.y, Y->getID(), incX, incY, KL, KU); 303} 304 305void ScriptIntrinsicBLAS::ZGBMV(RsBlasTranspose TransA, int KL, int KU, Double2 alpha, const sp<Allocation>& A, 306 const sp<Allocation>& X, int incX, Double2 beta, const sp<Allocation>& Y, int incY) { 307 // GBMV has the same validation requirements as GEMV + KL and KU >= 0 308 validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY); 309 if (KL < 0 || KU < 0) { 310 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); 311 } 312 int M = A->getType()->getY(); 313 int N = A->getType()->getX(); 314 315 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgbmv, 316 TransA, 0, 0, 0, 0, M, N, 0, 317 alpha.x, alpha.y, A->getID(), X->getID(), 318 beta.x, beta.y, Y->getID(), incX, incY, KL, KU); 319} 320 321static void validateTRMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, RsBlasTranspose TransA, 322 RsBlasDiag Diag, const sp<Allocation>& A, const sp<Allocation>& X, int incX) { 323 int N = A->getType()->getY(); 324 if ((int)A->getType()->getX() != N) { 325 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for TRMV"); 326 } 327 if (!A->getType()->getElement()->isCompatible(e) || 328 !X->getType()->getElement()->isCompatible(e)) { 329 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 330 } 331 if (X->getType()->getY() > 1) { 332 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 333 } 334 335 if (incX <= 0) { 336 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 337 } 338 int expectedXDim = 1 + (N - 1) * incX; 339 if ((int)X->getType()->getX() != expectedXDim) { 340 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TRMV"); 341 } 342} 343 344static int validateTPMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, RsBlasTranspose TransA, 345 RsBlasDiag Diag, const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) { 346 if (!Ap->getType()->getElement()->isCompatible(e) || 347 !X->getType()->getElement()->isCompatible(e)) { 348 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 349 } 350 if (X->getType()->getY() > 1) { 351 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 352 } 353 354 if (Ap->getType()->getY() > 1) { 355 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); 356 } 357 358 int N = sqrt((double)Ap->getType()->getX() * 2); 359 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { 360 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); 361 } 362 if (incX <= 0) { 363 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 364 } 365 int expectedXDim = 1 + (N - 1) * incX; 366 if ((int)X->getType()->getX() != expectedXDim) { 367 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TPMV"); 368 } 369 370 return N; 371} 372 373 374void ScriptIntrinsicBLAS::STRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 375 const sp<Allocation>& A, const sp<Allocation>& X, int incX) { 376 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); 377 int N = A->getType()->getY(); 378 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmv, 379 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 380 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 381} 382 383void ScriptIntrinsicBLAS::DTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 384 const sp<Allocation>& A, const sp<Allocation>& X, int incX) { 385 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); 386 int N = A->getType()->getY(); 387 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmv, 388 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 389 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 390} 391 392void ScriptIntrinsicBLAS::CTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 393 const sp<Allocation>& A, const sp<Allocation>& X, int incX) { 394 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); 395 int N = A->getType()->getY(); 396 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmv, 397 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 398 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 399} 400 401void ScriptIntrinsicBLAS::ZTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 402 const sp<Allocation>& A, const sp<Allocation>& X, int incX) { 403 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); 404 int N = A->getType()->getY(); 405 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmv, 406 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 407 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 408} 409 410void ScriptIntrinsicBLAS::STBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 411 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) { 412 // TBMV has the same requirements as TRMV + K >= 0 413 if (K < 0) { 414 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); 415 } 416 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); 417 int N = A->getType()->getY(); 418 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbmv, 419 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 420 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 421} 422 423void ScriptIntrinsicBLAS::DTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 424 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) { 425 // TBMV has the same requirements as TRMV + K >= 0 426 if (K < 0) { 427 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); 428 } 429 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); 430 int N = A->getType()->getY(); 431 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbmv, 432 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 433 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 434} 435 436void ScriptIntrinsicBLAS::CTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 437 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) { 438 // TBMV has the same requirements as TRMV + K >= 0 439 if (K < 0) { 440 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); 441 } 442 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); 443 int N = A->getType()->getY(); 444 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbmv, 445 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, 446 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 447} 448 449void ScriptIntrinsicBLAS::ZTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 450 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) { 451 // TBMV has the same requirements as TRMV + K >= 0 452 if (K < 0) { 453 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); 454 } 455 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); 456 int N = A->getType()->getY(); 457 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbmv, 458 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, 459 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 460} 461 462void ScriptIntrinsicBLAS::STPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 463 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) { 464 int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX); 465 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpmv, 466 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 467 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 468} 469 470void ScriptIntrinsicBLAS::DTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 471 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) { 472 int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX); 473 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpmv, 474 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 475 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 476} 477 478void ScriptIntrinsicBLAS::CTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 479 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) { 480 int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); 481 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpmv, 482 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 483 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 484} 485 486void ScriptIntrinsicBLAS::ZTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 487 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) { 488 int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); 489 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpmv, 490 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 491 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 492} 493 494void ScriptIntrinsicBLAS::STRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 495 const sp<Allocation>& A, const sp<Allocation>& X, int incX) { 496 // TRSV is the same as TRMV 497 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); 498 int N = A->getType()->getY(); 499 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsv, 500 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 501 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 502} 503 504void ScriptIntrinsicBLAS::DTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 505 const sp<Allocation>& A, const sp<Allocation>& X, int incX) { 506 // TRSV is the same as TRMV 507 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); 508 int N = A->getType()->getY(); 509 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsv, 510 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 511 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 512 513} 514 515void ScriptIntrinsicBLAS::CTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 516 const sp<Allocation>& A, const sp<Allocation>& X, int incX) { 517 // TRSV is the same as TRMV 518 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); 519 int N = A->getType()->getY(); 520 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsv, 521 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 522 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 523 524} 525 526void ScriptIntrinsicBLAS::ZTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 527 const sp<Allocation>& A, const sp<Allocation>& X, int incX) { 528 // TRSV is the same as TRMV 529 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); 530 int N = A->getType()->getY(); 531 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsv, 532 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 533 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 534 535} 536 537void ScriptIntrinsicBLAS::STBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 538 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) { 539 // TBSV is the same as TRMV + K >= 0 540 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); 541 int N = A->getType()->getY(); 542 if (K < 0) { 543 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); 544 } 545 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbsv, 546 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 547 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 548} 549 550void ScriptIntrinsicBLAS::DTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 551 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) { 552 // TBSV is the same as TRMV + K >= 0 553 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); 554 int N = A->getType()->getY(); 555 if (K < 0) { 556 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); 557 } 558 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbsv, 559 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 560 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 561} 562 563void ScriptIntrinsicBLAS::CTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 564 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) { 565 // TBSV is the same as TRMV + K >= 0 566 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); 567 int N = A->getType()->getY(); 568 if (K < 0) { 569 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); 570 } 571 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbsv, 572 TransA, 0, 0, Uplo, Diag, 0, N, K, 573 0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 574} 575 576void ScriptIntrinsicBLAS::ZTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 577 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) { 578 // TBSV is the same as TRMV + K >= 0 579 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); 580 int N = A->getType()->getY(); 581 if (K < 0) { 582 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); 583 } 584 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbsv, 585 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, 586 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 587} 588 589void ScriptIntrinsicBLAS::STPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 590 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) { 591 // TPSV is same as TPMV 592 int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX); 593 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpsv, 594 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 595 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 596} 597 598void ScriptIntrinsicBLAS::DTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 599 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) { 600 // TPSV is same as TPMV 601 int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX); 602 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpsv, 603 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 604 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 605} 606 607void ScriptIntrinsicBLAS::CTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 608 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) { 609 // TPSV is same as TPMV 610 int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); 611 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpsv, 612 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 613 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 614} 615 616void ScriptIntrinsicBLAS::ZTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 617 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) { 618 // TPSV is same as TPMV 619 int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); 620 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpsv, 621 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 622 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 623} 624 625/** 626 * Level 2, S and D only 627 */ 628static int validateSYMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& A, 629 const sp<Allocation>& X, const sp<Allocation>& Y, int incX, int incY) { 630 int N = A->getType()->getY(); 631 if ((int)A->getType()->getX() != N) { 632 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for SYMV"); 633 } 634 if (!A->getType()->getElement()->isCompatible(e) || 635 !X->getType()->getElement()->isCompatible(e) || 636 !Y->getType()->getElement()->isCompatible(e) ) { 637 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 638 } 639 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 640 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 641 } 642 643 if (incX <= 0 || incY <= 0) { 644 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 645 } 646 int expectedXDim = 1 + (N - 1) * incX; 647 if ((int)X->getType()->getX() != expectedXDim) { 648 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV"); 649 } 650 int expectedYDim = 1 + (N - 1) * incY; 651 if ((int)Y->getType()->getX() != expectedYDim) { 652 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV"); 653 } 654 return N; 655} 656static int validateSPMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& Ap, 657 const sp<Allocation>& X, int incX, const sp<Allocation>& Y, int incY) { 658 if (!Ap->getType()->getElement()->isCompatible(e) || 659 !X->getType()->getElement()->isCompatible(e) || 660 !Y->getType()->getElement()->isCompatible(e)) { 661 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 662 } 663 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 664 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 665 } 666 667 if (Ap->getType()->getY() > 1) { 668 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); 669 } 670 671 int N = sqrt((double)Ap->getType()->getX() * 2); 672 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { 673 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); 674 } 675 if (incX <= 0 || incY <= 0) { 676 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 677 } 678 int expectedXDim = 1 + (N - 1) * incX; 679 if ((int)X->getType()->getX() != expectedXDim) { 680 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV"); 681 } 682 int expectedYDim = 1 + (N - 1) * incY; 683 if ((int)Y->getType()->getX() != expectedYDim) { 684 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV"); 685 } 686 687 return N; 688} 689static void validateGER(RS* mRS, const sp<const Element>& e, const sp<Allocation>& X, int incX, 690 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { 691 if (!A->getType()->getElement()->isCompatible(e) || 692 !X->getType()->getElement()->isCompatible(e) || 693 !Y->getType()->getElement()->isCompatible(e) ) { 694 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 695 } 696 697 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 698 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 699 } 700 701 int M = A->getType()->getY(); 702 int N = A->getType()->getX(); 703 704 if (N < 1 || M < 1) { 705 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "M and N must be 1 or greater for GER"); 706 } 707 if (incX <= 0 || incY <= 0) { 708 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 709 } 710 int expectedXDim = 1 + (M - 1) * incX; 711 if ((int)X->getType()->getX() != expectedXDim) { 712 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER"); 713 } 714 int expectedYDim = 1 + (N - 1) * incY; 715 if ((int)Y->getType()->getX() != expectedYDim) { 716 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER"); 717 } 718 719 720} 721static int validateSYR(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, 722 const sp<Allocation>& X, int incX, const sp<Allocation>& A) { 723 if (!A->getType()->getElement()->isCompatible(e) || 724 !X->getType()->getElement()->isCompatible(e)) { 725 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 726 } 727 728 int N = A->getType()->getX(); 729 730 if (X->getType()->getY() > 1) { 731 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 732 } 733 if (N != (int)A->getType()->getY()) { 734 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix"); 735 } 736 if (incX <= 0) { 737 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 738 } 739 int expectedXDim = 1 + (N - 1) * incX; 740 if ((int)X->getType()->getX() != expectedXDim) { 741 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR"); 742 } 743 return N; 744} 745static int validateSPR(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, 746 const sp<Allocation>& X, int incX, const sp<Allocation>& Ap) { 747 if (!Ap->getType()->getElement()->isCompatible(e) || 748 !X->getType()->getElement()->isCompatible(e)) { 749 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 750 } 751 if (X->getType()->getY() > 1) { 752 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 753 } 754 755 if (Ap->getType()->getY() > 1) { 756 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); 757 } 758 759 int N = sqrt((double)Ap->getType()->getX() * 2); 760 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { 761 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); 762 } 763 if (incX <= 0) { 764 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 765 } 766 int expectedXDim = 1 + (N - 1) * incX; 767 if ((int)X->getType()->getX() != expectedXDim) { 768 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR"); 769 } 770 771 return N; 772} 773 774static int validateSYR2(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& X, 775 int incX, const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { 776 if (!A->getType()->getElement()->isCompatible(e) || 777 !X->getType()->getElement()->isCompatible(e) || 778 !Y->getType()->getElement()->isCompatible(e)) { 779 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 780 } 781 782 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 783 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 784 } 785 786 int N = A->getType()->getX(); 787 788 if (N != (int)A->getType()->getY()) { 789 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix"); 790 } 791 if (incX <= 0 || incY <= 0) { 792 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 793 } 794 int expectedXDim = 1 + (N - 1) * incX; 795 int expectedYDim = 1 + (N - 1) * incY; 796 if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) { 797 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR"); 798 } 799 return N; 800 801} 802static int validateSPR2(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& X, 803 int incX, const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) { 804 if (!Ap->getType()->getElement()->isCompatible(e) || 805 !X->getType()->getElement()->isCompatible(e) || 806 !Y->getType()->getElement()->isCompatible(e)) { 807 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 808 } 809 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 810 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 811 } 812 813 if (Ap->getType()->getY() > 1) { 814 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); 815 } 816 817 int N = sqrt((double)Ap->getType()->getX() * 2); 818 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { 819 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); 820 } 821 if (incX <= 0 || incY <= 0) { 822 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 823 } 824 int expectedXDim = 1 + (N - 1) * incX; 825 int expectedYDim = 1 + (N - 1) * incY; 826 if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) { 827 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR2"); 828 } 829 830 return N; 831} 832 833void ScriptIntrinsicBLAS::SSYMV(RsBlasUplo Uplo, float alpha, const sp<Allocation>& A, const sp<Allocation>& X, 834 int incX, float beta, const sp<Allocation>& Y, int incY) { 835 int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY); 836 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymv, 837 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 838 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); 839} 840 841void ScriptIntrinsicBLAS::SSBMV(RsBlasUplo Uplo, int K, float alpha, const sp<Allocation>& A, const sp<Allocation>& X, 842 int incX, float beta, const sp<Allocation>& Y, int incY) { 843 // SBMV is the same as SYMV + K >= 0 844 if (K < 0) { 845 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); 846 } 847 int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY); 848 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssbmv, 849 0, 0, 0, Uplo, 0, 0, N, K, alpha, 850 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); 851} 852 853void ScriptIntrinsicBLAS::SSPMV(RsBlasUplo Uplo, float alpha, const sp<Allocation>& Ap, const sp<Allocation>& X, 854 int incX, float beta, const sp<Allocation>& Y, int incY) { 855 int N = validateSPMV(mRS, Element::F32(mRS), Uplo, Ap, X, incX, Y, incY); 856 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspmv, 857 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 858 Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); 859} 860 861void ScriptIntrinsicBLAS::SGER(float alpha, const sp<Allocation>& X, int incX, 862 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { 863 int M = A->getType()->getY(); 864 int N = A->getType()->getX(); 865 validateGER(mRS, Element::F32(mRS), X, incX, Y, incY, A); 866 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sger, 867 0, 0, 0, 0, 0, M, N, 0, alpha, 868 X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0); 869} 870 871void ScriptIntrinsicBLAS::SSYR(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, 872 int incX, const sp<Allocation>& A) { 873 int N = validateSYR(mRS, Element::F32(mRS), Uplo, X, incX, A); 874 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr, 875 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 876 X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0); 877} 878 879void ScriptIntrinsicBLAS::SSPR(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, 880 int incX, const sp<Allocation>& Ap) { 881 int N = validateSPR(mRS, Element::F32(mRS), Uplo, X, incX, Ap); 882 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr, 883 0, 0, 0, Uplo, 0, 0, N, 0, 884 alpha, X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0); 885} 886 887void ScriptIntrinsicBLAS::SSYR2(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, int incX, 888 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { 889 int N = validateSYR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, A); 890 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2, 891 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 892 X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0); 893} 894 895void ScriptIntrinsicBLAS::SSPR2(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, int incX, 896 const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) { 897 int N = validateSPR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, Ap); 898 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr2, 899 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 900 X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0); 901} 902 903void ScriptIntrinsicBLAS::DSYMV(RsBlasUplo Uplo, double alpha, const sp<Allocation>& A, const sp<Allocation>& X, 904 int incX, double beta, const sp<Allocation>& Y, int incY) { 905 int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY); 906 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymv, 907 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 908 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); 909} 910 911void ScriptIntrinsicBLAS::DSBMV(RsBlasUplo Uplo, int K, double alpha, const sp<Allocation>& A, const sp<Allocation>& X, 912 int incX, double beta, const sp<Allocation>& Y, int incY) { 913 // SBMV is the same as SYMV + K >= 0 914 if (K < 0) { 915 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); 916 } 917 int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY); 918 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsbmv, 919 0, 0, 0, Uplo, 0, 0, N, K, alpha, 920 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); 921} 922 923void ScriptIntrinsicBLAS::DSPMV(RsBlasUplo Uplo, double alpha, const sp<Allocation>& Ap, const sp<Allocation>& X, 924 int incX, double beta, const sp<Allocation>& Y, int incY) { 925 int N = validateSPMV(mRS, Element::F64(mRS), Uplo, Ap, X, incX, Y, incY); 926 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspmv, 927 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 928 Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); 929} 930 931void ScriptIntrinsicBLAS::DGER(double alpha, const sp<Allocation>& X, int incX, const sp<Allocation>& Y, 932 int incY, const sp<Allocation>& A) { 933 int M = A->getType()->getY(); 934 int N = A->getType()->getX(); 935 validateGER(mRS, Element::F64(mRS), X, incX, Y, incY, A); 936 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dger, 937 0, 0, 0, 0, 0, M, N, 0, alpha, 938 X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0); 939} 940 941void ScriptIntrinsicBLAS::DSYR(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X, 942 int incX, const sp<Allocation>& A) { 943 int N = validateSYR(mRS, Element::F64(mRS), Uplo, X, incX, A); 944 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr, 945 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 946 X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0); 947} 948 949void ScriptIntrinsicBLAS::DSPR(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X, 950 int incX, const sp<Allocation>& Ap) { 951 int N = validateSPR(mRS, Element::F64(mRS), Uplo, X, incX, Ap); 952 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr, 953 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 954 X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0); 955} 956 957void ScriptIntrinsicBLAS::DSYR2(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X, int incX, 958 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { 959 int N = validateSYR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, A); 960 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2, 961 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 962 X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0); 963} 964 965void ScriptIntrinsicBLAS::DSPR2(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X, int incX, 966 const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) { 967 int N = validateSPR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, Ap); 968 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr2, 969 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 970 X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0); 971} 972 973 974/** 975 * Level 2, C and Z only 976 */ 977 978static void validateGERU(RS* mRS, const sp<const Element>& e, const sp<Allocation>& X, int incX, 979 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { 980 if (!A->getType()->getElement()->isCompatible(e) || 981 !X->getType()->getElement()->isCompatible(e) || 982 !Y->getType()->getElement()->isCompatible(e)) { 983 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 984 } 985 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 986 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 987 } 988 989 int M = A->getType()->getY(); 990 int N = A->getType()->getX(); 991 if (incX <= 0 || incY <= 0) { 992 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 993 } 994 int expectedXDim = 1 + (M - 1) * incX; 995 if ((int)X->getType()->getX() != expectedXDim) { 996 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU"); 997 } 998 int expectedYDim = 1 + (N - 1) * incY; 999 if ((int)Y->getType()->getX() != expectedYDim) { 1000 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU"); 1001 } 1002 1003} 1004 1005void ScriptIntrinsicBLAS::CHEMV(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& A, 1006 const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) { 1007 // HEMV is the same as SYR2 validation-wise 1008 int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A); 1009 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemv, 1010 0, 0, 0, Uplo, 0, 0, N, 0, 1011 alpha.x, alpha.y, A->getID(), X->getID(), 1012 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 1013} 1014 1015void ScriptIntrinsicBLAS::CHBMV(RsBlasUplo Uplo, int K, Float2 alpha, const sp<Allocation>& A, 1016 const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) { 1017 // HBMV is the same as SYR2 validation-wise 1018 int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A); 1019 if (K < 0) { 1020 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV"); 1021 } 1022 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chbmv, 1023 0, 0, 0, Uplo, 0, 0, N, K, 1024 alpha.x, alpha.y, A->getID(), X->getID(), 1025 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 1026} 1027 1028void ScriptIntrinsicBLAS::CHPMV(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& Ap, 1029 const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) { 1030 // HPMV is the same as SPR2 1031 int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap); 1032 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpmv, 1033 0, 0, 0, Uplo, 0, 0, N, 0, 1034 alpha.x, alpha.y, Ap->getID(), X->getID(), 1035 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 1036} 1037 1038void ScriptIntrinsicBLAS::CGERU(Float2 alpha, const sp<Allocation>& X, int incX, 1039 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { 1040 validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A); 1041 int M = A->getType()->getY(); 1042 int N = A->getType()->getX(); 1043 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgeru, 1044 0, 0, 0, 0, 0, M, N, 0, 1045 alpha.x, alpha.y, X->getID(), Y->getID(), 1046 0, 0, A->getID(), incX, incY, 0, 0); 1047} 1048 1049void ScriptIntrinsicBLAS::CGERC(Float2 alpha, const sp<Allocation>& X, int incX, 1050 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { 1051 // Same as GERU 1052 validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A); 1053 int M = A->getType()->getY(); 1054 int N = A->getType()->getX(); 1055 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgerc, 1056 0, 0, 0, 0, 0, M, N, 0, 1057 alpha.x, alpha.y, X->getID(), Y->getID(), 1058 0, 0, A->getID(), incX, incY, 0, 0); 1059} 1060 1061void ScriptIntrinsicBLAS::CHER(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, 1062 int incX, const sp<Allocation>& A) { 1063 // Same as SYR 1064 int N = validateSYR(mRS, Element::F32_2(mRS), Uplo, X, incX, A); 1065 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher, 1066 0, 0, 0, Uplo, 0, 0, N, 0, 1067 alpha, 0, X->getID(), 0, 1068 0, 0, A->getID(), incX, 0, 0, 0); 1069} 1070 1071void ScriptIntrinsicBLAS::CHPR(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, 1072 int incX, const sp<Allocation>& Ap) { 1073 // Equivalent to SPR for validation 1074 int N = validateSPR(mRS, Element::F32_2(mRS), Uplo, X, incX, Ap); 1075 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr, 1076 0, 0, 0, Uplo, 0, 0, N, 0, 1077 alpha, 0, X->getID(), 0, 1078 0, 0, Ap->getID(), incX, 0, 0, 0); 1079} 1080 1081void ScriptIntrinsicBLAS::CHER2(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& X, int incX, 1082 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { 1083 // Same as SYR2 1084 int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A); 1085 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2, 1086 0, 0, 0, Uplo, 0, 0, N, 0, 1087 alpha.x, alpha.y, X->getID(), Y->getID(), 1088 0, 0, A->getID(), incX, incY, 0, 0); 1089} 1090 1091void ScriptIntrinsicBLAS::CHPR2(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& X, int incX, 1092 const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) { 1093 // Same as SPR2 1094 int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap); 1095 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr2, 1096 0, 0, 0, Uplo, 0, 0, N, 0, 1097 alpha.x, alpha.y, X->getID(), Y->getID(), 1098 0, 0, Ap->getID(), incX, incY, 0, 0); 1099} 1100 1101void ScriptIntrinsicBLAS::ZHEMV(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& A, 1102 const sp<Allocation>& X, int incX, Double2 beta, const sp<Allocation>& Y, int incY) { 1103 // HEMV is the same as SYR2 validation-wise 1104 int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A); 1105 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemv, 1106 0, 0, 0, Uplo, 0, 0, N, 0, 1107 alpha.x, alpha.y, A->getID(), X->getID(), 1108 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 1109} 1110 1111void ScriptIntrinsicBLAS::ZHBMV(RsBlasUplo Uplo, int K, Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& X, 1112 int incX, Double2 beta, const sp<Allocation>& Y, int incY) { 1113 // HBMV is the same as SYR2 validation-wise 1114 int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A); 1115 if (K < 0) { 1116 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV"); 1117 } 1118 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhbmv, 1119 0, 0, 0, Uplo, 0, 0, N, K, 1120 alpha.x, alpha.y, A->getID(), X->getID(), 1121 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 1122} 1123 1124void ScriptIntrinsicBLAS::ZHPMV(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& Ap, const sp<Allocation>& X, 1125 int incX, Double2 beta, const sp<Allocation>& Y, int incY) { 1126 // HPMV is the same as SPR2 1127 int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap); 1128 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpmv, 1129 0, 0, 0, Uplo, 0, 0, N, 0, 1130 alpha.x, alpha.y, Ap->getID(), X->getID(), 1131 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 1132} 1133 1134void ScriptIntrinsicBLAS::ZGERU(Double2 alpha, const sp<Allocation>& X, int incX, 1135 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { 1136 validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A); 1137 int M = A->getType()->getY(); 1138 int N = A->getType()->getX(); 1139 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgeru, 1140 0, 0, 0, 0, 0, M, N, 0, 1141 alpha.x, alpha.y, X->getID(), Y->getID(), 1142 0, 0, A->getID(), incX, incY, 0, 0); 1143} 1144 1145void ScriptIntrinsicBLAS::ZGERC(Double2 alpha, const sp<Allocation>& X, int incX, 1146 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { 1147 // Same as GERU 1148 validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A); 1149 int M = A->getType()->getY(); 1150 int N = A->getType()->getX(); 1151 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgerc, 1152 0, 0, 0, 0, 0, M, N, 0, 1153 alpha.x, alpha.y, X->getID(), Y->getID(), 1154 0, 0, A->getID(), incX, incY, 0, 0); 1155} 1156 1157void ScriptIntrinsicBLAS::ZHER(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X, 1158 int incX, const sp<Allocation>& A) { 1159 // Same as SYR 1160 int N = validateSYR(mRS, Element::F64_2(mRS), Uplo, X, incX, A); 1161 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher, 1162 0, 0, 0, Uplo, 0, 0, N, 0, 1163 alpha, 0, X->getID(), 0, 1164 0, 0, A->getID(), incX, 0, 0, 0); 1165} 1166 1167void ScriptIntrinsicBLAS::ZHPR(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X, 1168 int incX, const sp<Allocation>& Ap) { 1169 // Equivalent to SPR for validation 1170 int N = validateSPR(mRS, Element::F64_2(mRS), Uplo, X, incX, Ap); 1171 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr, 1172 0, 0, 0, Uplo, 0, 0, N, 0, 1173 alpha, 0, X->getID(), 0, 1174 0, 0, Ap->getID(), incX, 0, 0, 0); 1175} 1176 1177void ScriptIntrinsicBLAS::ZHER2(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& X, int incX, 1178 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { 1179 // Same as SYR2 1180 int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A); 1181 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2, 1182 0, 0, 0, Uplo, 0, 0, N, 0, 1183 alpha.x, alpha.y, X->getID(), Y->getID(), 1184 0, 0, A->getID(), incX, incY, 0, 0); 1185} 1186 1187void ScriptIntrinsicBLAS::ZHPR2(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& X, int incX, 1188 const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) { 1189 // Same as SPR2 1190 int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap); 1191 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr2, 1192 0, 0, 0, Uplo, 0, 0, N, 0, 1193 alpha.x, alpha.y, X->getID(), Y->getID(), 1194 0, 0, Ap->getID(), incX, incY, 0, 0); 1195} 1196 1197 1198/** 1199 * Level 3 BLAS 1200 */ 1201 1202static void validateL3(RS* mRS, const sp<const Element>& e, int TransA, int TransB, int Side, 1203 const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) { 1204 int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1; 1205 if ((A != nullptr && !A->getType()->getElement()->isCompatible(e)) || 1206 (B != nullptr && !B->getType()->getElement()->isCompatible(e)) || 1207 (C != nullptr && !C->getType()->getElement()->isCompatible(e))) { 1208 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1209 } 1210 if (C == nullptr) { 1211 // Since matrix C is used to store the result, it cannot be null. 1212 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Allocation C cannot be null"); 1213 } 1214 cM = C->getType()->getY(); 1215 cN = C->getType()->getX(); 1216 1217 if (Side == RsBlasRight) { 1218 if ((A == nullptr && B != nullptr) || (A != nullptr && B == nullptr)) { 1219 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Provided Matrix A without Matrix B, or vice versa"); 1220 } 1221 if (B != nullptr) { 1222 bM = A->getType()->getY(); 1223 bN = A->getType()->getX(); 1224 } 1225 if (A != nullptr) { 1226 aM = B->getType()->getY(); 1227 aN = B->getType()->getX(); 1228 } 1229 } else { 1230 if (A != nullptr) { 1231 if (TransA == RsBlasTrans || TransA == RsBlasConjTrans) { 1232 aN = A->getType()->getY(); 1233 aM = A->getType()->getX(); 1234 } else { 1235 aM = A->getType()->getY(); 1236 aN = A->getType()->getX(); 1237 } 1238 } 1239 if (B != nullptr) { 1240 if (TransB == RsBlasTrans || TransB == RsBlasConjTrans) { 1241 bN = B->getType()->getY(); 1242 bM = B->getType()->getX(); 1243 } else { 1244 bM = B->getType()->getY(); 1245 bN = B->getType()->getX(); 1246 } 1247 } 1248 } 1249 if (A != nullptr && B != nullptr && C != nullptr) { 1250 if (aN != bM || aM != cM || bN != cN) { 1251 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions"); 1252 } 1253 } else if (A != nullptr && C != nullptr) { 1254 // A and C only, for SYRK 1255 if (cM != cN) { 1256 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix C is not symmetric"); 1257 } 1258 if (aM != cM) { 1259 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions"); 1260 } 1261 } else if (A != nullptr && B != nullptr) { 1262 // A and B only 1263 if (aN != bM) { 1264 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions"); 1265 } 1266 } 1267 1268} 1269 1270void ScriptIntrinsicBLAS::SGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, float alpha, 1271 const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) { 1272 validateL3(mRS, Element::F32(mRS), TransA, TransB, 0, A, B, C); 1273 1274 int M = -1, N = -1, K = -1; 1275 if (TransA != RsBlasNoTrans) { 1276 M = A->getType()->getX(); 1277 K = A->getType()->getY(); 1278 } else { 1279 M = A->getType()->getY(); 1280 K = A->getType()->getX(); 1281 } 1282 if (TransB != RsBlasNoTrans) { 1283 N = B->getType()->getY(); 1284 } else { 1285 N = B->getType()->getX(); 1286 } 1287 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemm, 1288 TransA, TransB, 0, 0, 0, M, N, K, 1289 alpha, A->getID(), B->getID(), 1290 beta, C->getID(), 0, 0, 0, 0); 1291} 1292 1293void ScriptIntrinsicBLAS::DGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, double alpha, 1294 const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) { 1295 validateL3(mRS, Element::F64(mRS), TransA, TransB, 0, A, B, C); 1296 int M = -1, N = -1, K = -1; 1297 if (TransA != RsBlasNoTrans) { 1298 M = A->getType()->getX(); 1299 K = A->getType()->getY(); 1300 } else { 1301 M = A->getType()->getY(); 1302 K = A->getType()->getX(); 1303 } 1304 if (TransB != RsBlasNoTrans) { 1305 N = B->getType()->getY(); 1306 } else { 1307 N = B->getType()->getX(); 1308 } 1309 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemm, 1310 TransA, TransB, 0, 0, 0, M, N, K, 1311 alpha, A->getID(), B->getID(), 1312 beta, C->getID(), 0, 0, 0, 0); 1313} 1314 1315void ScriptIntrinsicBLAS::CGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Float2 alpha, 1316 const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) { 1317 validateL3(mRS, Element::F32_2(mRS), TransA, TransB, 0, A, B, C); 1318 int M = -1, N = -1, K = -1; 1319 if (TransA != RsBlasNoTrans) { 1320 M = A->getType()->getX(); 1321 K = A->getType()->getY(); 1322 } else { 1323 M = A->getType()->getY(); 1324 K = A->getType()->getX(); 1325 } 1326 if (TransB != RsBlasNoTrans) { 1327 N = B->getType()->getY(); 1328 } else { 1329 N = B->getType()->getX(); 1330 } 1331 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemm, 1332 TransA, TransB, 0, 0, 0, M, N, K, 1333 alpha.x, alpha.y, A->getID(), B->getID(), 1334 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1335} 1336 1337void ScriptIntrinsicBLAS::ZGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Double2 alpha, 1338 const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) { 1339 validateL3(mRS, Element::F64_2(mRS), TransA, TransB, 0, A, B, C); 1340 int M = -1, N = -1, K = -1; 1341 if (TransA != RsBlasNoTrans) { 1342 M = A->getType()->getX(); 1343 K = A->getType()->getY(); 1344 } else { 1345 M = A->getType()->getY(); 1346 K = A->getType()->getX(); 1347 } 1348 if (TransB != RsBlasNoTrans) { 1349 N = B->getType()->getY(); 1350 } else { 1351 N = B->getType()->getX(); 1352 } 1353 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemm, 1354 TransA, TransB, 0, 0, 0, M, N, K, 1355 alpha.x, alpha.y, A->getID(), B->getID(), 1356 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1357} 1358 1359void ScriptIntrinsicBLAS::SSYMM(RsBlasSide Side, RsBlasUplo Uplo, float alpha, 1360 const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) { 1361 //For SYMM, Matrix A should be symmetric 1362 if (A->getType()->getX() != A->getType()->getY()) { 1363 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); 1364 } 1365 validateL3(mRS, Element::F32(mRS), 0, 0, Side, A, B, C); 1366 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymm, 1367 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, 1368 alpha, A->getID(), B->getID(), 1369 beta, C->getID(), 0, 0, 0, 0); 1370} 1371 1372void ScriptIntrinsicBLAS::DSYMM(RsBlasSide Side, RsBlasUplo Uplo, double alpha, 1373 const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) { 1374 if (A->getType()->getX() != A->getType()->getY()) { 1375 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); 1376 } 1377 validateL3(mRS, Element::F64(mRS), 0, 0, Side, A, B, C); 1378 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymm, 1379 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, 1380 alpha, A->getID(), B->getID(), 1381 beta, C->getID(), 0, 0, 0, 0); 1382} 1383 1384void ScriptIntrinsicBLAS::CSYMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha, 1385 const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) { 1386 if (A->getType()->getX() != A->getType()->getY()) { 1387 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); 1388 } 1389 validateL3(mRS, Element::F32_2(mRS), 0, 0, Side, A, B, C); 1390 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csymm, 1391 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, 1392 alpha.x, alpha.y, A->getID(), B->getID(), 1393 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1394} 1395 1396void ScriptIntrinsicBLAS::ZSYMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha, 1397 const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) { 1398 if (A->getType()->getX() != A->getType()->getY()) { 1399 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); 1400 } 1401 validateL3(mRS, Element::F64_2(mRS), 0, 0, Side, A, B, C); 1402 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsymm, 1403 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, 1404 alpha.x, alpha.y, A->getID(), B->getID(), 1405 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1406} 1407 1408void ScriptIntrinsicBLAS::SSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha, 1409 const sp<Allocation>& A, float beta, const sp<Allocation>& C) { 1410 validateL3(mRS, Element::F32(mRS), Trans, 0, 0, A, nullptr, C); 1411 int K = -1; 1412 if (Trans != RsBlasNoTrans) { 1413 K = A->getType()->getY(); 1414 } else { 1415 K = A->getType()->getX(); 1416 } 1417 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyrk, 1418 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1419 alpha, A->getID(), 0, 1420 beta, C->getID(), 0, 0, 0, 0); 1421} 1422 1423void ScriptIntrinsicBLAS::DSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha, 1424 const sp<Allocation>& A, double beta, const sp<Allocation>& C) { 1425 validateL3(mRS, Element::F64(mRS), Trans, 0, 0, A, nullptr, C); 1426 int K = -1; 1427 if (Trans != RsBlasNoTrans) { 1428 K = A->getType()->getY(); 1429 } else { 1430 K = A->getType()->getX(); 1431 } 1432 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyrk, 1433 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1434 alpha, A->getID(), 0, 1435 beta, C->getID(), 0, 0, 0, 0); 1436} 1437 1438void ScriptIntrinsicBLAS::CSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha, 1439 const sp<Allocation>& A, Float2 beta, const sp<Allocation>& C) { 1440 validateL3(mRS, Element::F32_2(mRS), Trans, 0, 0, A, nullptr, C); 1441 int K = -1; 1442 if (Trans != RsBlasNoTrans) { 1443 K = A->getType()->getY(); 1444 } else { 1445 K = A->getType()->getX(); 1446 } 1447 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyrk, 1448 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1449 alpha.x, alpha.y, A->getID(), 0, 1450 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1451} 1452 1453void ScriptIntrinsicBLAS::ZSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha, 1454 const sp<Allocation>& A, Double2 beta, const sp<Allocation>& C) { 1455 validateL3(mRS, Element::F64_2(mRS), Trans, 0, 0, A, nullptr, C); 1456 int K = -1; 1457 if (Trans != RsBlasNoTrans) { 1458 K = A->getType()->getY(); 1459 } else { 1460 K = A->getType()->getX(); 1461 } 1462 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyrk, 1463 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1464 alpha.x, alpha.y, A->getID(), 0, 1465 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1466} 1467 1468static void validateSYR2K(RS* mRS, const sp<const Element>& e, RsBlasTranspose Trans, 1469 const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) { 1470 if (!A->getType()->getElement()->isCompatible(e) || 1471 !B->getType()->getElement()->isCompatible(e) || 1472 !C->getType()->getElement()->isCompatible(e)) { 1473 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1474 } 1475 int Cdim = -1; 1476 // A is n x k if no transpose, k x n if transpose 1477 // C is n x n 1478 if (Trans == RsBlasTrans) { 1479 // check columns versus C 1480 Cdim = A->getType()->getX(); 1481 } else { 1482 // check rows versus C 1483 Cdim = A->getType()->getY(); 1484 } 1485 if ((int)C->getType()->getX() != Cdim || (int)C->getType()->getY() != Cdim) { 1486 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid symmetric matrix in SYR2K"); 1487 } 1488 // A dims == B dims 1489 if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) { 1490 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid A and B in SYR2K"); 1491 } 1492} 1493 1494void ScriptIntrinsicBLAS::SSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha, 1495 const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) { 1496 validateSYR2K(mRS, Element::F32(mRS), Trans, A, B, C); 1497 int K = -1; 1498 if (Trans != RsBlasNoTrans) { 1499 K = A->getType()->getY(); 1500 } else { 1501 K = A->getType()->getX(); 1502 } 1503 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2k, 1504 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1505 alpha, A->getID(), B->getID(), 1506 beta, C->getID(), 0, 0, 0, 0); 1507} 1508 1509void ScriptIntrinsicBLAS::DSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha, 1510 const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) { 1511 validateSYR2K(mRS, Element::F64(mRS), Trans, A, B, C); 1512 int K = -1; 1513 if (Trans != RsBlasNoTrans) { 1514 K = A->getType()->getY(); 1515 } else { 1516 K = A->getType()->getX(); 1517 } 1518 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2k, 1519 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1520 alpha, A->getID(), B->getID(), 1521 beta, C->getID(), 0, 0, 0, 0); 1522} 1523 1524void ScriptIntrinsicBLAS::CSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha, 1525 const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) { 1526 validateSYR2K(mRS, Element::F32_2(mRS), Trans, A, B, C); 1527 int K = -1; 1528 if (Trans != RsBlasNoTrans) { 1529 K = A->getType()->getY(); 1530 } else { 1531 K = A->getType()->getX(); 1532 } 1533 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyr2k, 1534 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1535 alpha.x, alpha.y, A->getID(), B->getID(), 1536 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1537} 1538 1539void ScriptIntrinsicBLAS::ZSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha, 1540 const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) { 1541 validateSYR2K(mRS, Element::F64_2(mRS), Trans, A, B, C); 1542 int K = -1; 1543 if (Trans != RsBlasNoTrans) { 1544 K = A->getType()->getY(); 1545 } else { 1546 K = A->getType()->getX(); 1547 } 1548 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyr2k, 1549 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1550 alpha.x, alpha.y, A->getID(), B->getID(), 1551 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1552} 1553 1554static void validateTRMM(RS* mRS, const sp<const Element>& e, RsBlasSide Side, RsBlasTranspose TransA, 1555 const sp<Allocation>& A, const sp<Allocation>& B) { 1556 int aM = -1, aN = -1, bM = -1, bN = -1; 1557 if (!A->getType()->getElement()->isCompatible(e) || 1558 !B->getType()->getElement()->isCompatible(e)) { 1559 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1560 } 1561 1562 aM = A->getType()->getY(); 1563 aN = A->getType()->getX(); 1564 if (aM != aN) { 1565 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with a non-symmetric matrix A"); 1566 } 1567 1568 bM = B->getType()->getY(); 1569 bN = B->getType()->getX(); 1570 if (Side == RsBlasLeft) { 1571 if (aN != bM) { 1572 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices"); 1573 } 1574 } else { 1575 if (bN != aM) { 1576 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices"); 1577 } 1578 } 1579} 1580 1581void ScriptIntrinsicBLAS::STRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1582 float alpha, const sp<Allocation>& A, const sp<Allocation>& B) { 1583 validateTRMM(mRS, Element::F32(mRS), Side, TransA, A, B); 1584 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmm, 1585 TransA, 0, Side, Uplo, Diag,\ 1586 B->getType()->getY(), B->getType()->getX(), 0, 1587 alpha, A->getID(), B->getID(), 0.f, 0, 0, 0, 0, 0); 1588} 1589 1590void ScriptIntrinsicBLAS::DTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1591 double alpha, const sp<Allocation>& A, const sp<Allocation>& B) { 1592 validateTRMM(mRS, Element::F64(mRS), Side, TransA, A, B); 1593 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmm, 1594 TransA, 0, Side, Uplo, Diag, 1595 B->getType()->getY(), B->getType()->getX(), 0, 1596 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0); 1597} 1598 1599void ScriptIntrinsicBLAS::CTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1600 Float2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) { 1601 validateTRMM(mRS, Element::F32_2(mRS), Side, TransA, A, B); 1602 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmm, 1603 TransA, 0, Side, Uplo, Diag, 1604 B->getType()->getY(), B->getType()->getX(), 0, 1605 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); 1606} 1607 1608void ScriptIntrinsicBLAS::ZTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1609 Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) { 1610 validateTRMM(mRS, Element::F64_2(mRS), Side, TransA, A, B); 1611 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmm, 1612 TransA, 0, Side, Uplo, Diag, 1613 B->getType()->getY(), B->getType()->getX(), 0, 1614 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); 1615} 1616 1617static void validateTRSM(RS* mRS, const sp<const Element>& e, RsBlasSide Side, RsBlasTranspose TransA, 1618 const sp<Allocation>& A, const sp<Allocation>& B) { 1619 int adim = -1, bM = -1, bN = -1; 1620 if (!A->getType()->getElement()->isCompatible(e) || 1621 !B->getType()->getElement()->isCompatible(e)) { 1622 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1623 } 1624 adim = A->getType()->getX(); 1625 if (adim != (int)A->getType()->getY()) { 1626 // This may be unnecessary, the restriction could potentially be relaxed. 1627 // Allocation A needs to contain at least that symmetric matrix but could theoretically 1628 // be larger for now we assume adapters are sufficient, will reevaluate in the future. 1629 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with a non-symmetric matrix A"); 1630 } 1631 bM = B->getType()->getY(); 1632 bN = B->getType()->getX(); 1633 if (Side == RsBlasLeft) { 1634 // A is M*M 1635 if (adim != bM) { 1636 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions"); 1637 } 1638 } else { 1639 // A is N*N 1640 if (adim != bN) { 1641 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions"); 1642 } 1643 } 1644} 1645 1646void ScriptIntrinsicBLAS::STRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1647 float alpha, const sp<Allocation>& A, const sp<Allocation>& B) { 1648 validateTRSM(mRS, Element::F32(mRS), Side, TransA, A, B); 1649 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsm, 1650 TransA, 0, Side, Uplo, Diag, 1651 B->getType()->getY(), B->getType()->getX(), 0, 1652 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0); 1653} 1654 1655void ScriptIntrinsicBLAS::DTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1656 double alpha, const sp<Allocation>& A, const sp<Allocation>& B) { 1657 validateTRSM(mRS, Element::F64(mRS), Side, TransA, A, B); 1658 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsm, 1659 TransA, 0, Side, Uplo, Diag, 1660 B->getType()->getY(), B->getType()->getX(), 0, 1661 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0); 1662} 1663 1664void ScriptIntrinsicBLAS::CTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1665 Float2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) { 1666 validateTRSM(mRS, Element::F32_2(mRS), Side, TransA, A, B); 1667 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsm, 1668 TransA, 0, Side, Uplo, Diag, 1669 B->getType()->getY(), B->getType()->getX(), 0, 1670 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); 1671} 1672 1673void ScriptIntrinsicBLAS::ZTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1674 Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) { 1675 validateTRSM(mRS, Element::F64_2(mRS), Side, TransA, A, B); 1676 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsm, 1677 TransA, 0, Side, Uplo, Diag, 1678 B->getType()->getY(), B->getType()->getX(), 0, 1679 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); 1680} 1681 1682static void validateHEMM(RS* mRS, const sp<const Element>& e, RsBlasSide Side, 1683 const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) { 1684 if (!A->getType()->getElement()->isCompatible(e) || 1685 !B->getType()->getElement()->isCompatible(e) || 1686 !C->getType()->getElement()->isCompatible(e)) { 1687 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1688 } 1689 1690 // A must be square; can potentially be relaxed similar to TRSM 1691 int adim = A->getType()->getX(); 1692 if (adim != (int)A->getType()->getY()) { 1693 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with non-square A"); 1694 } 1695 if ((Side == RsBlasLeft && adim != (int)B->getType()->getY()) || 1696 (Side == RsBlasRight && adim != (int)B->getType()->getX())) { 1697 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with invalid B"); 1698 } 1699 if (B->getType()->getX() != C->getType()->getX() || 1700 B->getType()->getY() != C->getType()->getY()) { 1701 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with mismatched B and C"); 1702 } 1703} 1704 1705void ScriptIntrinsicBLAS::CHEMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha, 1706 const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) { 1707 validateHEMM(mRS, Element::F32_2(mRS), Side, A, B, C); 1708 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemm, 1709 0, 0, Side, Uplo, 0, 1710 C->getType()->getY(), C->getType()->getX(), 0, 1711 alpha.x, alpha.y, A->getID(), B->getID(), 1712 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1713} 1714 1715void ScriptIntrinsicBLAS::ZHEMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha, 1716 const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) { 1717 validateHEMM(mRS, Element::F64_2(mRS), Side, A, B, C); 1718 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemm, 1719 0, 0, Side, Uplo, 0, 1720 C->getType()->getY(), C->getType()->getX(), 0, 1721 alpha.x, alpha.y, A->getID(), B->getID(), 1722 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1723} 1724 1725static void validateHERK(RS* mRS, const sp<const Element>& e, RsBlasTranspose Trans, 1726 const sp<Allocation>& A, const sp<Allocation>& C) { 1727 if (!A->getType()->getElement()->isCompatible(e) || 1728 !C->getType()->getElement()->isCompatible(e)) { 1729 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1730 } 1731 if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) { 1732 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose"); 1733 } 1734 int cdim = C->getType()->getX(); 1735 if (cdim != (int)C->getType()->getY()) { 1736 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with non-square C"); 1737 } 1738 if (Trans == RsBlasNoTrans) { 1739 if (cdim != (int)A->getType()->getY()) { 1740 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A"); 1741 } 1742 } else { 1743 if (cdim != (int)A->getType()->getX()) { 1744 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A"); 1745 } 1746 } 1747} 1748 1749void ScriptIntrinsicBLAS::CHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha, 1750 const sp<Allocation>& A, float beta, const sp<Allocation>& C) { 1751 validateHERK(mRS, Element::F32_2(mRS), Trans, A, C); 1752 int k = 0; 1753 if (Trans == RsBlasConjTrans) { 1754 k = A->getType()->getY(); 1755 } else { 1756 k = A->getType()->getX(); 1757 } 1758 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cherk, 1759 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, 1760 alpha, 0, A->getID(), 0, 1761 beta, 0, C->getID(), 0, 0, 0, 0); 1762} 1763 1764void ScriptIntrinsicBLAS::ZHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha, 1765 const sp<Allocation>& A, double beta, const sp<Allocation>& C) { 1766 validateHERK(mRS, Element::F64_2(mRS), Trans, A, C); 1767 int k = 0; 1768 if (Trans == RsBlasConjTrans) { 1769 k = A->getType()->getY(); 1770 } else { 1771 k = A->getType()->getX(); 1772 } 1773 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zherk, 1774 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, 1775 alpha, 0, A->getID(), 0, 1776 beta, 0, C->getID(), 0, 0, 0, 0); 1777} 1778 1779static void validateHER2K(RS* mRS, const sp<const Element>& e, RsBlasTranspose Trans, 1780 const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) { 1781 if (!A->getType()->getElement()->isCompatible(e) || 1782 !B->getType()->getElement()->isCompatible(e) || 1783 !C->getType()->getElement()->isCompatible(e)) { 1784 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1785 } 1786 if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) { 1787 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose"); 1788 } 1789 int cdim = C->getType()->getX(); 1790 if (cdim != (int)C->getType()->getY()) { 1791 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with non-square C"); 1792 } 1793 if (Trans == RsBlasNoTrans) { 1794 if ((int)A->getType()->getY() != cdim) { 1795 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices"); 1796 } 1797 } else { 1798 if ((int)A->getType()->getX() != cdim) { 1799 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices"); 1800 } 1801 } 1802 if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) { 1803 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid A and B matrices"); 1804 } 1805} 1806 1807void ScriptIntrinsicBLAS::CHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha, 1808 const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) { 1809 validateHER2K(mRS, Element::F32_2(mRS), Trans, A, B, C); 1810 int k = 0; 1811 if (Trans == RsBlasNoTrans) { 1812 k = A->getType()->getX(); 1813 } else { 1814 k = A->getType()->getY(); 1815 } 1816 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2k, 1817 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, 1818 alpha.x, alpha.y, A->getID(), B->getID(), 1819 beta, 0, C->getID(), 0, 0, 0, 0); 1820} 1821 1822void ScriptIntrinsicBLAS::ZHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha, 1823 const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) { 1824 validateHER2K(mRS, Element::F64_2(mRS), Trans, A, B, C); 1825 int k = 0; 1826 if (Trans == RsBlasNoTrans) { 1827 k = A->getType()->getX(); 1828 } else { 1829 k = A->getType()->getY(); 1830 } 1831 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2k, 1832 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, 1833 alpha.x, alpha.y, A->getID(), B->getID(), 1834 beta, 0, C->getID(), 0, 0, 0, 0); 1835} 1836 1837 1838 1839void ScriptIntrinsicBLAS::BNNM(const sp<Allocation>& A, int a_offset, const sp<Allocation>& B, int b_offset, 1840 const sp<Allocation>& C, int c_offset, int c_mult) { 1841 validateL3(mRS, Element::U8(mRS), RsBlasNoTrans, RsBlasTrans, 0, A, B, C); 1842 1843 if (a_offset < 0 || a_offset > 255) { 1844 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid a_offset passed to BNNM"); 1845 } 1846 if (b_offset < 0 || b_offset > 255) { 1847 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid b_offset passed to BNNM"); 1848 } 1849 int M = -1, N = -1, K = -1; 1850 M = A->getType()->getY(); 1851 N = B->getType()->getY(); 1852 K = A->getType()->getX(); 1853 1854 nScriptIntrinsicBLAS_BNNM(mRS, mRS->getContext(), getID(), M, N, K, A->getID(), a_offset, 1855 B->getID(), b_offset, C->getID(), c_offset, c_mult); 1856} 1857