1/*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17
18#include "RenderScript.h"
19#include "rsCppInternal.h"
20
21#define NELEM(m) (sizeof(m) / sizeof((m)[0]))
22
23using android::RSC::Allocation;
24using android::RSC::Element;
25using android::RSC::RS;
26using android::RSC::RS_ERROR_INVALID_ELEMENT;
27using android::RSC::RS_ERROR_INVALID_PARAMETER;
28using android::RSC::RS_SUCCESS;
29using android::RSC::ScriptIntrinsicBLAS;
30using android::RSC::sp;
31
// ScriptIntrinsicBLAS APIs
33ScriptIntrinsicBLAS::ScriptIntrinsicBLAS(sp<RS> rs, sp<const Element> e)
34    : ScriptIntrinsic(rs, RS_SCRIPT_INTRINSIC_ID_BLAS, e) {
35
36}
37
38sp<ScriptIntrinsicBLAS> ScriptIntrinsicBLAS::create(const sp<RS>& rs) {
39    return new ScriptIntrinsicBLAS(rs, Element::U32(rs));
40}
41
// Selects which alpha/beta representation setUpBLASCall() copies into the
// call descriptor.
enum RsBlasDataType {
    SINGLE = 0,          // float alpha/beta
    DOUBLE = 1,          // double alpha/beta
    SINGLE_COMPLEX = 2,  // float (r, i) alpha/beta
    DOUBLE_COMPLEX = 3   // double (r, i) alpha/beta
};
48
49static RsBlasCall
50setUpBLASCall(RsBlasDataType dataType, RsBlasFunction func,
51              int TransA, int TransB, int Side, int Uplo, int Diag,
52              int M, int N, int K, int incX, int incY, int KL, int KU,
53              float alphaF, float betaF, double alphaD, double betaD,
54              float alphaCX, float alphaCY, float betaCX, float betaCY,
55              double alphaZX, double alphaZY, double betaZX, double betaZY
56              ) {
57    RsBlasCall call;
58    memset(&call, 0, sizeof(call));
59    call.func = func;
60    call.transA = (RsBlasTranspose)TransA;
61    call.transB = (RsBlasTranspose)TransB;
62    call.side = (RsBlasSide)Side;
63    call.uplo = (RsBlasUplo)Uplo;
64    call.diag = (RsBlasDiag)Diag;
65    call.M = M;
66    call.N = N;
67    call.K = K;
68
69    switch (dataType) {
70        case SINGLE:
71            // For Single-precision BLAS.
72            call.alpha.f = alphaF;
73            call.beta.f = betaF;
74            break;
75        case DOUBLE:
76            // For Double-precision BLAS.
77            call.alpha.d = alphaD;
78            call.beta.d = betaD;
79            break;
80        case SINGLE_COMPLEX:
81            // For Single-precision complex BLAS.
82            call.alpha.c.r = alphaCX;
83            call.alpha.c.i = alphaCY;
84            call.beta.c.r = betaCX;
85            call.beta.c.i = betaCY;
86            break;
87        case DOUBLE_COMPLEX:
88            // For Double-precision complex BLAS.
89            call.alpha.z.r = alphaZX;
90            call.alpha.z.i = alphaZY;
91            call.beta.z.r = betaZX;
92            call.beta.z.i = betaZY;
93            break;
94        default:
95            break;
96    }
97
98    call.incX = incX;
99    call.incY = incY;
100    call.KL = KL;
101    call.KU = KU;
102
103    return call;
104}
105
106static void
107nScriptIntrinsicBLAS_Single(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
108                            int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
109                            float alpha, RsAllocation A, RsAllocation B,
110                            float beta, RsAllocation C, int incX, int incY, int KL, int KU) {
111    RsBlasCall call = setUpBLASCall(SINGLE, func, TransA, TransB, Side, Uplo, Diag,
112                                    M, N, K, incX, incY, KL, KU, alpha, beta, 0.0, 0.0,
113                                    0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0);
114    RsAllocation in_allocs[3] = {A, B, C};
115    tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
116                                                      &call, sizeof(call), nullptr, 0));
117}
118
119
120static void
121nScriptIntrinsicBLAS_Double(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
122                            int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
123                            double alpha, RsAllocation A, RsAllocation B,
124                            double beta, RsAllocation C, int incX, int incY, int KL, int KU) {
125    RsBlasCall call = setUpBLASCall(DOUBLE, func, TransA, TransB, Side, Uplo, Diag,
126                                    M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, alpha, beta,
127                                    0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0);
128    RsAllocation in_allocs[3] = {A, B, C};
129    tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
130                                                      &call, sizeof(call), nullptr, 0));
131}
132
133static void
134nScriptIntrinsicBLAS_Complex(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
135                             int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
136                             float alphaX, float alphaY, RsAllocation A, RsAllocation B,
137                             float betaX, float betaY, RsAllocation C, int incX, int incY, int KL, int KU) {
138    RsBlasCall call = setUpBLASCall(SINGLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag,
139                                    M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0,
140                                    alphaX, alphaY, betaX, betaY, 0.0, 0.0, 0.0, 0.0);
141    RsAllocation in_allocs[3] = {A, B, C};
142    tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
143                                                      &call, sizeof(call), nullptr, 0));
144}
145
146static void
147nScriptIntrinsicBLAS_Z(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
148                       int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
149                       double alphaX, double alphaY, RsAllocation A, RsAllocation B,
150                       double betaX, double betaY, RsAllocation C, int incX, int incY, int KL, int KU) {
151    RsBlasCall call = setUpBLASCall(DOUBLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag,
152                                    M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0,
153                                    0.0f, 0.0f, 0.0f, 0.0f, alphaX, alphaY, betaX, betaY);
154    RsAllocation in_allocs[3] = {A, B, C};
155    tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
156                                                      &call, sizeof(call), nullptr, 0));
157}
158
159
160static void
161nScriptIntrinsicBLAS_BNNM(RS* mRS, RsContext con, RsScript id, int M, int N, int K,
162                          RsAllocation A, int a_offset, RsAllocation B, int b_offset,
163                          RsAllocation C, int c_offset, int c_mult_int) {
164    RsBlasCall call;
165    memset(&call, 0, sizeof(call));
166    call.func = RsBlas_bnnm;
167    call.M = M;
168    call.N = N;
169    call.K = K;
170    call.a_offset = a_offset & 0xFF;
171    call.b_offset = b_offset & 0xFF;
172    call.c_offset = c_offset;
173    call.c_mult_int = c_mult_int;
174
175    RsAllocation in_allocs[3] = {A, B, C};
176    tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
177                                                      &call, sizeof(call), nullptr, 0));
178}
179
180/**
181 * Level 2 BLAS
182 */
183static void validateGEMV(RS* mRS, const sp<const Element>& e, RsBlasTranspose TransA, const sp<Allocation>& A,
184                         const sp<Allocation>& X, int incX, const sp<Allocation>& Y, int incY) {
185    int M = A->getType()->getY();
186    int N = A->getType()->getX();
187    if (!A->getType()->getElement()->isCompatible(e) ||
188        !X->getType()->getElement()->isCompatible(e) ||
189        !Y->getType()->getElement()->isCompatible(e)) {
190        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
191    }
192    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
193        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
194    }
195
196    if (incX <= 0 || incY <= 0) {
197        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
198    }
199    int expectedXDim = -1, expectedYDim = -1;
200    if (TransA == RsBlasNoTrans) {
201        expectedXDim = 1 + (N - 1) * incX;
202        expectedYDim = 1 + (M - 1) * incY;
203    } else {
204        expectedXDim = 1 + (M - 1) * incX;
205        expectedYDim = 1 + (N - 1) * incY;
206    }
207    if ((int)X->getType()->getX() != expectedXDim ||
208        (int)Y->getType()->getX() != expectedYDim) {
209        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GEMV");
210    }
211}
212
213void ScriptIntrinsicBLAS::SGEMV(RsBlasTranspose TransA, float alpha, const sp<Allocation>& A, const sp<Allocation>& X,
214                                int incX, float beta, const sp<Allocation>& Y, int incY) {
215    validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY);
216    int M = A->getType()->getY();
217    int N = A->getType()->getX();
218    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemv,
219                                TransA, 0, 0, 0, 0, M, N, 0,
220                                alpha, A->getID(), X->getID(),
221                                beta, Y->getID(), incX, incY, 0, 0);
222}
223
224void ScriptIntrinsicBLAS::DGEMV(RsBlasTranspose TransA, double alpha, const sp<Allocation>& A, const sp<Allocation>& X,
225                                int incX, double beta, const sp<Allocation>& Y, int incY) {
226    validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY);
227    int M = A->getType()->getY();
228    int N = A->getType()->getX();
229    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemv,
230                                TransA, 0, 0, 0, 0, M, N, 0,
231                                alpha, A->getID(), X->getID(),
232                                beta, Y->getID(), incX, incY, 0, 0);
233}
234
235void ScriptIntrinsicBLAS::CGEMV(RsBlasTranspose TransA, Float2 alpha, const sp<Allocation>& A, const sp<Allocation>& X,
236                                int incX, Float2 beta, const sp<Allocation>& Y, int incY) {
237    validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY);
238    int M = A->getType()->getY();
239    int N = A->getType()->getX();
240    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemv,
241                                 TransA, 0, 0, 0, 0, M, N, 0,
242                                 alpha.x, alpha.y, A->getID(), X->getID(),
243                                 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
244}
245
246void ScriptIntrinsicBLAS::ZGEMV(RsBlasTranspose TransA, Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& X,
247                                int incX, Double2 beta, const sp<Allocation>& Y, int incY) {
248    validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY);
249    int M = A->getType()->getY();
250    int N = A->getType()->getX();
251    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemv,
252                           TransA, 0, 0, 0, 0, M, N, 0,
253                           alpha.x, alpha.y, A->getID(), X->getID(),
254                           beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
255}
256
257void ScriptIntrinsicBLAS::SGBMV(RsBlasTranspose TransA, int KL, int KU, float alpha, const sp<Allocation>& A,
258                                const sp<Allocation>& X, int incX, float beta, const sp<Allocation>& Y, int incY) {
259    // GBMV has the same validation requirements as GEMV + KL and KU >= 0
260    validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY);
261    if (KL < 0 || KU < 0) {
262        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
263    }
264    int M = A->getType()->getY();
265    int N = A->getType()->getX();
266
267    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgbmv,
268                                TransA, 0, 0, 0, 0, M, N, 0,
269                                alpha, A->getID(), X->getID(),
270                                beta, Y->getID(), incX, incY, KL, KU);
271}
272
273void ScriptIntrinsicBLAS::DGBMV(RsBlasTranspose TransA, int KL, int KU, double alpha, const sp<Allocation>& A,
274                                const sp<Allocation>& X, int incX, double beta, const sp<Allocation>& Y, int incY) {
275    // GBMV has the same validation requirements as GEMV + KL and KU >= 0
276    validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY);
277    if (KL < 0 || KU < 0) {
278        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
279    }
280    int M = A->getType()->getY();
281    int N = A->getType()->getX();
282
283    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgbmv,
284                                TransA, 0, 0, 0, 0, M, N, 0,
285                                alpha, A->getID(), X->getID(),
286                                beta, Y->getID(), incX, incY, KL, KU);
287}
288
289void ScriptIntrinsicBLAS::CGBMV(RsBlasTranspose TransA, int KL, int KU, Float2 alpha, const sp<Allocation>& A,
290                                const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) {
291    // GBMV has the same validation requirements as GEMV + KL and KU >= 0
292    validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY);
293    if (KL < 0 || KU < 0) {
294        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
295    }
296    int M = A->getType()->getY();
297    int N = A->getType()->getX();
298
299    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgbmv,
300                                 TransA, 0, 0, 0, 0, M, N, 0,
301                                 alpha.x, alpha.y, A->getID(), X->getID(),
302                                 beta.x, beta.y, Y->getID(), incX, incY, KL, KU);
303}
304
305void ScriptIntrinsicBLAS::ZGBMV(RsBlasTranspose TransA, int KL, int KU, Double2 alpha, const sp<Allocation>& A,
306                                const sp<Allocation>& X, int incX, Double2 beta, const sp<Allocation>& Y, int incY) {
307    // GBMV has the same validation requirements as GEMV + KL and KU >= 0
308    validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY);
309    if (KL < 0 || KU < 0) {
310        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
311    }
312    int M = A->getType()->getY();
313    int N = A->getType()->getX();
314
315    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgbmv,
316                           TransA, 0, 0, 0, 0, M, N, 0,
317                           alpha.x, alpha.y, A->getID(), X->getID(),
318                           beta.x, beta.y, Y->getID(), incX, incY, KL, KU);
319}
320
321static void validateTRMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, RsBlasTranspose TransA,
322                         RsBlasDiag Diag, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
323    int N = A->getType()->getY();
324    if ((int)A->getType()->getX() != N) {
325        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for TRMV");
326    }
327    if (!A->getType()->getElement()->isCompatible(e) ||
328        !X->getType()->getElement()->isCompatible(e)) {
329        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
330    }
331    if (X->getType()->getY() > 1) {
332        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
333    }
334
335    if (incX <= 0) {
336        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
337    }
338    int expectedXDim = 1 + (N - 1) * incX;
339    if ((int)X->getType()->getX() != expectedXDim) {
340        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TRMV");
341    }
342}
343
344static int validateTPMV(RS* mRS, const sp<const Element>& e,  RsBlasUplo Uplo, RsBlasTranspose TransA,
345                        RsBlasDiag Diag, const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
346    if (!Ap->getType()->getElement()->isCompatible(e) ||
347        !X->getType()->getElement()->isCompatible(e)) {
348        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
349    }
350    if (X->getType()->getY() > 1) {
351        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
352    }
353
354    if (Ap->getType()->getY() > 1) {
355        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
356    }
357
358    int N = sqrt((double)Ap->getType()->getX() * 2);
359    if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
360        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
361    }
362    if (incX <= 0) {
363        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
364    }
365    int expectedXDim = 1 + (N - 1) * incX;
366    if ((int)X->getType()->getX() != expectedXDim) {
367        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TPMV");
368    }
369
370    return N;
371}
372
373
374void ScriptIntrinsicBLAS::STRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
375                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
376    validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
377    int N = A->getType()->getY();
378    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmv,
379                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
380                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
381}
382
383void ScriptIntrinsicBLAS::DTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
384                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
385    validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
386    int N = A->getType()->getY();
387    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmv,
388                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
389                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
390}
391
392void ScriptIntrinsicBLAS::CTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
393                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
394    validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
395    int N = A->getType()->getY();
396    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmv,
397                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
398                                 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
399}
400
401void ScriptIntrinsicBLAS::ZTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
402                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
403    validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
404    int N = A->getType()->getY();
405    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmv,
406                           TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
407                           A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
408}
409
410void ScriptIntrinsicBLAS::STBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
411                                int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
412    // TBMV has the same requirements as TRMV + K >= 0
413    if (K < 0) {
414        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
415    }
416    validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
417    int N = A->getType()->getY();
418    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbmv,
419                                TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
420                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
421}
422
423void ScriptIntrinsicBLAS::DTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
424                                int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
425    // TBMV has the same requirements as TRMV + K >= 0
426    if (K < 0) {
427        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
428    }
429    validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
430    int N = A->getType()->getY();
431    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbmv,
432                                TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
433                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
434}
435
436void ScriptIntrinsicBLAS::CTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
437                                int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
438    // TBMV has the same requirements as TRMV + K >= 0
439    if (K < 0) {
440        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
441    }
442    validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
443    int N = A->getType()->getY();
444    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbmv,
445                                 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
446                                 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
447}
448
449void ScriptIntrinsicBLAS::ZTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
450                                int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
451    // TBMV has the same requirements as TRMV + K >= 0
452    if (K < 0) {
453        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
454    }
455    validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
456    int N = A->getType()->getY();
457    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbmv,
458                           TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
459                           A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
460}
461
462void ScriptIntrinsicBLAS::STPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
463                                const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
464    int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
465    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpmv,
466                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
467                                Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
468}
469
470void ScriptIntrinsicBLAS::DTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
471                                const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
472    int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
473    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpmv,
474                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
475                                Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
476}
477
478void ScriptIntrinsicBLAS::CTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
479                                const sp<Allocation>& Ap,  const sp<Allocation>& X,  int incX) {
480    int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
481    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpmv,
482                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
483                                 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
484}
485
486void ScriptIntrinsicBLAS::ZTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
487                                const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
488    int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
489    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpmv,
490                           TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
491                           Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
492}
493
494void ScriptIntrinsicBLAS::STRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
495                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
496    // TRSV is the same as TRMV
497    validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
498    int N = A->getType()->getY();
499    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsv,
500                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
501                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
502}
503
504void ScriptIntrinsicBLAS::DTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
505                                const sp<Allocation>& A,  const sp<Allocation>& X,  int incX) {
506    // TRSV is the same as TRMV
507    validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
508    int N = A->getType()->getY();
509    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsv,
510                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
511                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
512
513}
514
515void ScriptIntrinsicBLAS::CTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
516                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
517    // TRSV is the same as TRMV
518    validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
519    int N = A->getType()->getY();
520    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsv,
521                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
522                                 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
523
524}
525
526void ScriptIntrinsicBLAS::ZTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
527                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
528    // TRSV is the same as TRMV
529    validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
530    int N = A->getType()->getY();
531    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsv,
532                           TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
533                           A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
534
535}
536
537void ScriptIntrinsicBLAS::STBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
538                                int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
539    // TBSV is the same as TRMV + K >= 0
540    validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
541    int N = A->getType()->getY();
542    if (K < 0) {
543        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
544    }
545    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbsv,
546                                TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
547                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
548}
549
550void ScriptIntrinsicBLAS::DTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
551                                int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
552    // TBSV is the same as TRMV + K >= 0
553    validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
554    int N = A->getType()->getY();
555    if (K < 0) {
556        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
557    }
558    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbsv,
559                                TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
560                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
561}
562
563void ScriptIntrinsicBLAS::CTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
564                                int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
565    // TBSV is the same as TRMV + K >= 0
566    validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
567    int N = A->getType()->getY();
568    if (K < 0) {
569        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
570    }
571    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbsv,
572                                 TransA, 0, 0, Uplo, Diag, 0, N, K,
573                                 0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
574}
575
576void ScriptIntrinsicBLAS::ZTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
577                                int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
578    // TBSV is the same as TRMV + K >= 0
579    validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
580    int N = A->getType()->getY();
581    if (K < 0) {
582        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
583    }
584    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbsv,
585                           TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
586                           A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
587}
588
589void ScriptIntrinsicBLAS::STPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
590                                const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
591    // TPSV is same as TPMV
592    int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
593    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpsv,
594                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
595                                Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
596}
597
598void ScriptIntrinsicBLAS::DTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
599                                const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
600    // TPSV is same as TPMV
601    int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
602    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpsv,
603                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
604                                Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
605}
606
607void ScriptIntrinsicBLAS::CTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
608                                const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
609    // TPSV is same as TPMV
610    int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
611    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpsv,
612                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
613                                 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
614}
615
616void ScriptIntrinsicBLAS::ZTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
617                                const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
618    // TPSV is same as TPMV
619    int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
620    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpsv,
621                           TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
622                           Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
623}
624
625/**
626 * Level 2, S and D only
627 */
628static int validateSYMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& A,
629                        const sp<Allocation>& X, const sp<Allocation>& Y, int incX, int incY) {
630    int N = A->getType()->getY();
631    if ((int)A->getType()->getX() != N) {
632        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for SYMV");
633    }
634    if (!A->getType()->getElement()->isCompatible(e) ||
635        !X->getType()->getElement()->isCompatible(e) ||
636        !Y->getType()->getElement()->isCompatible(e) ) {
637        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
638    }
639    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
640        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
641    }
642
643    if (incX <= 0 || incY <= 0) {
644        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
645    }
646    int expectedXDim = 1 + (N - 1) * incX;
647    if ((int)X->getType()->getX() != expectedXDim) {
648        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV");
649    }
650    int expectedYDim = 1 + (N - 1) * incY;
651    if ((int)Y->getType()->getX() != expectedYDim) {
652        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV");
653    }
654    return N;
655}
656static int validateSPMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& Ap,
657                        const sp<Allocation>& X, int incX, const sp<Allocation>& Y, int incY) {
658    if (!Ap->getType()->getElement()->isCompatible(e) ||
659        !X->getType()->getElement()->isCompatible(e) ||
660        !Y->getType()->getElement()->isCompatible(e)) {
661        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
662    }
663    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
664        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
665    }
666
667    if (Ap->getType()->getY() > 1) {
668        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
669    }
670
671    int N = sqrt((double)Ap->getType()->getX() * 2);
672    if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
673        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
674    }
675    if (incX <= 0 || incY <= 0) {
676        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
677    }
678    int expectedXDim = 1 + (N - 1) * incX;
679    if ((int)X->getType()->getX() != expectedXDim) {
680        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV");
681    }
682    int expectedYDim = 1 + (N - 1) * incY;
683    if ((int)Y->getType()->getX() != expectedYDim) {
684        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV");
685    }
686
687    return N;
688}
689static void validateGER(RS* mRS, const sp<const Element>& e, const sp<Allocation>& X, int incX,
690                        const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
691    if (!A->getType()->getElement()->isCompatible(e) ||
692        !X->getType()->getElement()->isCompatible(e) ||
693        !Y->getType()->getElement()->isCompatible(e) ) {
694        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
695    }
696
697    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
698        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
699    }
700
701    int M = A->getType()->getY();
702    int N = A->getType()->getX();
703
704    if (N < 1 || M < 1) {
705        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "M and N must be 1 or greater for GER");
706    }
707    if (incX <= 0 || incY <= 0) {
708        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
709    }
710    int expectedXDim = 1 + (M - 1) * incX;
711    if ((int)X->getType()->getX() != expectedXDim) {
712        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER");
713    }
714    int expectedYDim = 1 + (N - 1) * incY;
715    if ((int)Y->getType()->getX() != expectedYDim) {
716        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER");
717    }
718
719
720}
721static int validateSYR(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo,
722                       const sp<Allocation>& X, int incX, const sp<Allocation>& A) {
723    if (!A->getType()->getElement()->isCompatible(e) ||
724        !X->getType()->getElement()->isCompatible(e)) {
725        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
726    }
727
728    int N = A->getType()->getX();
729
730    if (X->getType()->getY() > 1) {
731        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
732    }
733    if (N != (int)A->getType()->getY()) {
734        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix");
735    }
736    if (incX <= 0) {
737        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
738    }
739    int expectedXDim = 1 + (N - 1) * incX;
740    if ((int)X->getType()->getX() != expectedXDim) {
741        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR");
742    }
743    return N;
744}
745static int validateSPR(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo,
746                       const sp<Allocation>& X, int incX, const sp<Allocation>& Ap) {
747    if (!Ap->getType()->getElement()->isCompatible(e) ||
748        !X->getType()->getElement()->isCompatible(e)) {
749        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
750    }
751    if (X->getType()->getY() > 1) {
752        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
753    }
754
755    if (Ap->getType()->getY() > 1) {
756        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
757    }
758
759    int N = sqrt((double)Ap->getType()->getX() * 2);
760    if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
761        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
762    }
763    if (incX <= 0) {
764        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
765    }
766    int expectedXDim = 1 + (N - 1) * incX;
767    if ((int)X->getType()->getX() != expectedXDim) {
768        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR");
769    }
770
771    return N;
772}
773
774static int validateSYR2(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& X,
775                        int incX, const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
776    if (!A->getType()->getElement()->isCompatible(e) ||
777        !X->getType()->getElement()->isCompatible(e) ||
778        !Y->getType()->getElement()->isCompatible(e)) {
779        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
780    }
781
782    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
783        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
784    }
785
786    int N = A->getType()->getX();
787
788    if (N != (int)A->getType()->getY()) {
789        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix");
790    }
791    if (incX <= 0 || incY <= 0) {
792        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
793    }
794    int expectedXDim = 1 + (N - 1) * incX;
795    int expectedYDim = 1 + (N - 1) * incY;
796    if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) {
797        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR");
798    }
799    return N;
800
801}
802static int validateSPR2(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& X,
803                        int incX, const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) {
804    if (!Ap->getType()->getElement()->isCompatible(e) ||
805        !X->getType()->getElement()->isCompatible(e) ||
806        !Y->getType()->getElement()->isCompatible(e)) {
807        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
808    }
809    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
810        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
811    }
812
813    if (Ap->getType()->getY() > 1) {
814        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
815    }
816
817    int N = sqrt((double)Ap->getType()->getX() * 2);
818    if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
819        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
820    }
821    if (incX <= 0 || incY <= 0) {
822        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
823    }
824    int expectedXDim = 1 + (N - 1) * incX;
825    int expectedYDim = 1 + (N - 1) * incY;
826    if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) {
827        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR2");
828    }
829
830    return N;
831}
832
833void ScriptIntrinsicBLAS::SSYMV(RsBlasUplo Uplo, float alpha, const sp<Allocation>& A, const sp<Allocation>& X,
834                                int incX, float beta, const sp<Allocation>& Y, int incY) {
835    int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY);
836    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymv,
837                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
838                                A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
839}
840
841void ScriptIntrinsicBLAS::SSBMV(RsBlasUplo Uplo, int K, float alpha, const sp<Allocation>& A, const sp<Allocation>& X,
842                                int incX, float beta, const sp<Allocation>& Y, int incY) {
843    // SBMV is the same as SYMV + K >= 0
844    if (K < 0) {
845        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
846    }
847    int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY);
848    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssbmv,
849                                0, 0, 0, Uplo, 0, 0, N, K, alpha,
850                                A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
851}
852
853void ScriptIntrinsicBLAS::SSPMV(RsBlasUplo Uplo, float alpha, const sp<Allocation>& Ap, const sp<Allocation>& X,
854                                int incX, float beta, const sp<Allocation>& Y, int incY) {
855    int N = validateSPMV(mRS, Element::F32(mRS), Uplo, Ap, X, incX, Y, incY);
856    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspmv,
857                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
858                                Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
859}
860
861void ScriptIntrinsicBLAS::SGER(float alpha, const sp<Allocation>& X, int incX,
862                               const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
863    int M = A->getType()->getY();
864    int N = A->getType()->getX();
865    validateGER(mRS, Element::F32(mRS), X, incX, Y, incY, A);
866    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sger,
867                                0, 0, 0, 0, 0, M, N, 0, alpha,
868                                X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0);
869}
870
871void ScriptIntrinsicBLAS::SSYR(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X,
872                               int incX, const sp<Allocation>& A) {
873    int N = validateSYR(mRS, Element::F32(mRS), Uplo, X, incX, A);
874    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr,
875                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
876                                X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0);
877}
878
879void ScriptIntrinsicBLAS::SSPR(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X,
880                               int incX, const sp<Allocation>& Ap) {
881    int N = validateSPR(mRS, Element::F32(mRS), Uplo, X, incX, Ap);
882    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr,
883                                0, 0, 0, Uplo, 0, 0, N, 0,
884                                alpha, X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0);
885}
886
887void ScriptIntrinsicBLAS::SSYR2(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, int incX,
888                                const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
889    int N = validateSYR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, A);
890    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2,
891                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
892                                X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0);
893}
894
895void ScriptIntrinsicBLAS::SSPR2(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, int incX,
896                                const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) {
897    int N = validateSPR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, Ap);
898    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr2,
899                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
900                                X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0);
901}
902
903void ScriptIntrinsicBLAS::DSYMV(RsBlasUplo Uplo, double alpha, const sp<Allocation>& A, const sp<Allocation>& X,
904                                int incX, double beta, const sp<Allocation>& Y, int incY) {
905    int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY);
906    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymv,
907                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
908                                A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
909}
910
911void ScriptIntrinsicBLAS::DSBMV(RsBlasUplo Uplo, int K, double alpha, const sp<Allocation>& A, const sp<Allocation>& X,
912                                int incX, double beta, const sp<Allocation>& Y, int incY) {
913    // SBMV is the same as SYMV + K >= 0
914    if (K < 0) {
915        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
916    }
917    int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY);
918    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsbmv,
919                                0, 0, 0, Uplo, 0, 0, N, K, alpha,
920                                A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
921}
922
923void ScriptIntrinsicBLAS::DSPMV(RsBlasUplo Uplo, double alpha, const sp<Allocation>& Ap, const sp<Allocation>& X,
924                                int incX, double beta, const sp<Allocation>& Y, int incY) {
925    int N = validateSPMV(mRS, Element::F64(mRS), Uplo, Ap, X, incX, Y, incY);
926    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspmv,
927                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
928                                Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
929}
930
931void ScriptIntrinsicBLAS::DGER(double alpha, const sp<Allocation>& X, int incX, const sp<Allocation>& Y,
932                               int incY, const sp<Allocation>& A) {
933    int M = A->getType()->getY();
934    int N = A->getType()->getX();
935    validateGER(mRS, Element::F64(mRS), X, incX, Y, incY, A);
936    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dger,
937                                0, 0, 0, 0, 0, M, N, 0, alpha,
938                                X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0);
939}
940
941void ScriptIntrinsicBLAS::DSYR(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X,
942                               int incX, const sp<Allocation>& A) {
943    int N = validateSYR(mRS, Element::F64(mRS), Uplo, X, incX, A);
944    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr,
945                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
946                                X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0);
947}
948
949void ScriptIntrinsicBLAS::DSPR(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X,
950                               int incX, const sp<Allocation>& Ap) {
951    int N = validateSPR(mRS, Element::F64(mRS), Uplo, X, incX, Ap);
952    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr,
953                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
954                                X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0);
955}
956
957void ScriptIntrinsicBLAS::DSYR2(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X, int incX,
958                                const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
959    int N = validateSYR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, A);
960    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2,
961                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
962                                X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0);
963}
964
965void ScriptIntrinsicBLAS::DSPR2(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X, int incX,
966                                const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) {
967    int N = validateSPR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, Ap);
968    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr2,
969                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
970                                X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0);
971}
972
973
974/**
975 * Level 2, C and Z only
976 */
977
978static void validateGERU(RS* mRS, const sp<const Element>& e, const sp<Allocation>& X, int incX,
979                         const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
980    if (!A->getType()->getElement()->isCompatible(e) ||
981        !X->getType()->getElement()->isCompatible(e) ||
982        !Y->getType()->getElement()->isCompatible(e)) {
983        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
984    }
985    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
986        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
987    }
988
989    int M = A->getType()->getY();
990    int N = A->getType()->getX();
991    if (incX <= 0 || incY <= 0) {
992        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
993    }
994    int expectedXDim = 1 + (M - 1) * incX;
995    if ((int)X->getType()->getX() != expectedXDim) {
996        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU");
997    }
998    int expectedYDim = 1 + (N - 1) * incY;
999    if ((int)Y->getType()->getX() != expectedYDim) {
1000        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU");
1001    }
1002
1003}
1004
1005void ScriptIntrinsicBLAS::CHEMV(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& A,
1006                                const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) {
1007    // HEMV is the same as SYR2 validation-wise
1008    int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1009    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemv,
1010                                 0, 0, 0, Uplo, 0, 0, N, 0,
1011                                 alpha.x, alpha.y, A->getID(), X->getID(),
1012                                 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1013}
1014
1015void ScriptIntrinsicBLAS::CHBMV(RsBlasUplo Uplo, int K, Float2 alpha, const sp<Allocation>& A,
1016                                const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) {
1017    // HBMV is the same as SYR2 validation-wise
1018    int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1019    if (K < 0) {
1020        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV");
1021    }
1022    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chbmv,
1023                                 0, 0, 0, Uplo, 0, 0, N, K,
1024                                 alpha.x, alpha.y, A->getID(), X->getID(),
1025                                 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1026}
1027
1028void ScriptIntrinsicBLAS::CHPMV(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& Ap,
1029                                const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) {
1030    // HPMV is the same as SPR2
1031    int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
1032    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpmv,
1033                                 0, 0, 0, Uplo, 0, 0, N, 0,
1034                                 alpha.x, alpha.y, Ap->getID(), X->getID(),
1035                                 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1036}
1037
1038void ScriptIntrinsicBLAS::CGERU(Float2 alpha, const sp<Allocation>& X, int incX,
1039                                const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
1040    validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A);
1041    int M = A->getType()->getY();
1042    int N = A->getType()->getX();
1043    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgeru,
1044                                 0, 0, 0, 0, 0, M, N, 0,
1045                                 alpha.x, alpha.y, X->getID(), Y->getID(),
1046                                 0, 0, A->getID(), incX, incY, 0, 0);
1047}
1048
1049void ScriptIntrinsicBLAS::CGERC(Float2 alpha, const sp<Allocation>& X, int incX,
1050                                const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
1051    // Same as GERU
1052    validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A);
1053    int M = A->getType()->getY();
1054    int N = A->getType()->getX();
1055    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgerc,
1056                                 0, 0, 0, 0, 0, M, N, 0,
1057                                 alpha.x, alpha.y, X->getID(), Y->getID(),
1058                                 0, 0, A->getID(), incX, incY, 0, 0);
1059}
1060
1061void ScriptIntrinsicBLAS::CHER(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X,
1062                               int incX, const sp<Allocation>& A) {
1063    // Same as SYR
1064    int N = validateSYR(mRS, Element::F32_2(mRS), Uplo, X, incX, A);
1065    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher,
1066                                 0, 0, 0, Uplo, 0, 0, N, 0,
1067                                 alpha, 0, X->getID(), 0,
1068                                 0, 0, A->getID(), incX, 0, 0, 0);
1069}
1070
1071void ScriptIntrinsicBLAS::CHPR(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X,
1072                               int incX, const sp<Allocation>& Ap) {
1073    // Equivalent to SPR for validation
1074    int N = validateSPR(mRS, Element::F32_2(mRS), Uplo, X, incX, Ap);
1075    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr,
1076                                 0, 0, 0, Uplo, 0, 0, N, 0,
1077                                 alpha, 0, X->getID(), 0,
1078                                 0, 0, Ap->getID(), incX, 0, 0, 0);
1079}
1080
1081void ScriptIntrinsicBLAS::CHER2(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& X, int incX,
1082                                const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
1083    // Same as SYR2
1084    int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1085    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2,
1086                                 0, 0, 0, Uplo, 0, 0, N, 0,
1087                                 alpha.x, alpha.y, X->getID(), Y->getID(),
1088                                 0, 0, A->getID(), incX, incY, 0, 0);
1089}
1090
1091void ScriptIntrinsicBLAS::CHPR2(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& X, int incX,
1092                                const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) {
1093    // Same as SPR2
1094    int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
1095    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr2,
1096                                 0, 0, 0, Uplo, 0, 0, N, 0,
1097                                 alpha.x, alpha.y, X->getID(), Y->getID(),
1098                                 0, 0, Ap->getID(), incX, incY, 0, 0);
1099}
1100
1101void ScriptIntrinsicBLAS::ZHEMV(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& A,
1102                                const sp<Allocation>& X, int incX, Double2 beta, const sp<Allocation>& Y, int incY) {
1103    // HEMV is the same as SYR2 validation-wise
1104    int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1105    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemv,
1106                           0, 0, 0, Uplo, 0, 0, N, 0,
1107                           alpha.x, alpha.y, A->getID(), X->getID(),
1108                           beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1109}
1110
1111void ScriptIntrinsicBLAS::ZHBMV(RsBlasUplo Uplo, int K, Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& X,
1112                                int incX, Double2 beta, const sp<Allocation>& Y, int incY) {
1113    // HBMV is the same as SYR2 validation-wise
1114    int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1115    if (K < 0) {
1116        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV");
1117    }
1118    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhbmv,
1119                           0, 0, 0, Uplo, 0, 0, N, K,
1120                           alpha.x, alpha.y, A->getID(), X->getID(),
1121                           beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1122}
1123
1124void ScriptIntrinsicBLAS::ZHPMV(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& Ap, const sp<Allocation>& X,
1125                                int incX, Double2 beta, const sp<Allocation>& Y, int incY) {
1126    // HPMV is the same as SPR2
1127    int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
1128    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpmv,
1129                           0, 0, 0, Uplo, 0, 0, N, 0,
1130                           alpha.x, alpha.y, Ap->getID(), X->getID(),
1131                           beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1132}
1133
1134void ScriptIntrinsicBLAS::ZGERU(Double2 alpha, const sp<Allocation>& X, int incX,
1135                                const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
1136    validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A);
1137    int M = A->getType()->getY();
1138    int N = A->getType()->getX();
1139    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgeru,
1140                           0, 0, 0, 0, 0, M, N, 0,
1141                           alpha.x, alpha.y, X->getID(), Y->getID(),
1142                           0, 0, A->getID(), incX, incY, 0, 0);
1143}
1144
1145void ScriptIntrinsicBLAS::ZGERC(Double2 alpha, const sp<Allocation>& X, int incX,
1146                                const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
1147    // Same as GERU
1148    validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A);
1149    int M = A->getType()->getY();
1150    int N = A->getType()->getX();
1151    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgerc,
1152                           0, 0, 0, 0, 0, M, N, 0,
1153                           alpha.x, alpha.y, X->getID(), Y->getID(),
1154                           0, 0, A->getID(), incX, incY, 0, 0);
1155}
1156
1157void ScriptIntrinsicBLAS::ZHER(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X,
1158                               int incX, const sp<Allocation>& A) {
1159    // Same as SYR
1160    int N = validateSYR(mRS, Element::F64_2(mRS), Uplo, X, incX, A);
1161    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher,
1162                           0, 0, 0, Uplo, 0, 0, N, 0,
1163                           alpha, 0, X->getID(), 0,
1164                           0, 0, A->getID(), incX, 0, 0, 0);
1165}
1166
1167void ScriptIntrinsicBLAS::ZHPR(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X,
1168                               int incX, const sp<Allocation>& Ap) {
1169    // Equivalent to SPR for validation
1170    int N = validateSPR(mRS, Element::F64_2(mRS), Uplo, X, incX, Ap);
1171    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr,
1172                           0, 0, 0, Uplo, 0, 0, N, 0,
1173                           alpha, 0, X->getID(), 0,
1174                           0, 0, Ap->getID(), incX, 0, 0, 0);
1175}
1176
1177void ScriptIntrinsicBLAS::ZHER2(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& X, int incX,
1178                                const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
1179    // Same as SYR2
1180    int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1181    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2,
1182                           0, 0, 0, Uplo, 0, 0, N, 0,
1183                           alpha.x, alpha.y, X->getID(), Y->getID(),
1184                           0, 0, A->getID(), incX, incY, 0, 0);
1185}
1186
1187void ScriptIntrinsicBLAS::ZHPR2(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& X, int incX,
1188                                const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) {
1189    // Same as SPR2
1190    int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
1191    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr2,
1192                           0, 0, 0, Uplo, 0, 0, N, 0,
1193                           alpha.x, alpha.y, X->getID(), Y->getID(),
1194                           0, 0, Ap->getID(), incX, incY, 0, 0);
1195}
1196
1197
1198/**
1199 * Level 3 BLAS
1200 */
1201
1202static void validateL3(RS* mRS, const sp<const Element>& e, int TransA, int TransB, int Side,
1203                       const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) {
1204    int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1;
1205    if ((A != nullptr && !A->getType()->getElement()->isCompatible(e)) ||
1206        (B != nullptr && !B->getType()->getElement()->isCompatible(e)) ||
1207        (C != nullptr && !C->getType()->getElement()->isCompatible(e))) {
1208        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1209    }
1210    if (C == nullptr) {
1211        // Since matrix C is used to store the result, it cannot be null.
1212        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Allocation C cannot be null");
1213    }
1214    cM = C->getType()->getY();
1215    cN = C->getType()->getX();
1216
1217    if (Side == RsBlasRight) {
1218        if ((A == nullptr && B != nullptr) || (A != nullptr && B == nullptr)) {
1219            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Provided Matrix A without Matrix B, or vice versa");
1220        }
1221        if (B != nullptr) {
1222            bM = A->getType()->getY();
1223            bN = A->getType()->getX();
1224        }
1225        if (A != nullptr) {
1226            aM = B->getType()->getY();
1227            aN = B->getType()->getX();
1228        }
1229    } else {
1230        if (A != nullptr) {
1231            if (TransA == RsBlasTrans || TransA == RsBlasConjTrans) {
1232                aN = A->getType()->getY();
1233                aM = A->getType()->getX();
1234            } else {
1235                aM = A->getType()->getY();
1236                aN = A->getType()->getX();
1237            }
1238        }
1239        if (B != nullptr) {
1240            if (TransB == RsBlasTrans || TransB == RsBlasConjTrans) {
1241                bN = B->getType()->getY();
1242                bM = B->getType()->getX();
1243            } else {
1244                bM = B->getType()->getY();
1245                bN = B->getType()->getX();
1246            }
1247        }
1248    }
1249    if (A != nullptr && B != nullptr && C != nullptr) {
1250        if (aN != bM || aM != cM || bN != cN) {
1251            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1252        }
1253    } else if (A != nullptr && C != nullptr) {
1254        // A and C only, for SYRK
1255        if (cM != cN) {
1256            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix C is not symmetric");
1257        }
1258        if (aM != cM) {
1259            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1260        }
1261    } else if (A != nullptr && B != nullptr) {
1262        // A and B only
1263        if (aN != bM) {
1264            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1265        }
1266    }
1267
1268}
1269
1270void ScriptIntrinsicBLAS::SGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, float alpha,
1271                                const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) {
1272    validateL3(mRS, Element::F32(mRS), TransA, TransB, 0, A, B, C);
1273
1274    int M = -1, N = -1, K = -1;
1275    if (TransA != RsBlasNoTrans) {
1276        M = A->getType()->getX();
1277        K = A->getType()->getY();
1278    } else {
1279        M = A->getType()->getY();
1280        K = A->getType()->getX();
1281    }
1282    if (TransB != RsBlasNoTrans) {
1283        N = B->getType()->getY();
1284    } else {
1285        N = B->getType()->getX();
1286    }
1287    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemm,
1288                                TransA, TransB, 0, 0, 0, M, N, K,
1289                                alpha, A->getID(), B->getID(),
1290                                beta, C->getID(), 0, 0, 0, 0);
1291}
1292
1293void ScriptIntrinsicBLAS::DGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, double alpha,
1294                                const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) {
1295    validateL3(mRS, Element::F64(mRS), TransA, TransB, 0, A, B, C);
1296    int M = -1, N = -1, K = -1;
1297    if (TransA != RsBlasNoTrans) {
1298        M = A->getType()->getX();
1299        K = A->getType()->getY();
1300    } else {
1301        M = A->getType()->getY();
1302        K = A->getType()->getX();
1303    }
1304    if (TransB != RsBlasNoTrans) {
1305        N = B->getType()->getY();
1306    } else {
1307        N = B->getType()->getX();
1308    }
1309    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemm,
1310                                TransA, TransB, 0, 0, 0, M, N, K,
1311                                alpha, A->getID(), B->getID(),
1312                                beta, C->getID(), 0, 0, 0, 0);
1313}
1314
1315void ScriptIntrinsicBLAS::CGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Float2 alpha,
1316                                const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) {
1317    validateL3(mRS, Element::F32_2(mRS), TransA, TransB, 0, A, B, C);
1318    int M = -1, N = -1, K = -1;
1319    if (TransA != RsBlasNoTrans) {
1320        M = A->getType()->getX();
1321        K = A->getType()->getY();
1322    } else {
1323        M = A->getType()->getY();
1324        K = A->getType()->getX();
1325    }
1326    if (TransB != RsBlasNoTrans) {
1327        N = B->getType()->getY();
1328    } else {
1329        N = B->getType()->getX();
1330    }
1331    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemm,
1332                                 TransA, TransB, 0, 0, 0, M, N, K,
1333                                 alpha.x, alpha.y, A->getID(), B->getID(),
1334                                 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1335}
1336
1337void ScriptIntrinsicBLAS::ZGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Double2 alpha,
1338                                const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) {
1339    validateL3(mRS, Element::F64_2(mRS), TransA, TransB, 0, A, B, C);
1340    int M = -1, N = -1, K = -1;
1341    if (TransA != RsBlasNoTrans) {
1342        M = A->getType()->getX();
1343        K = A->getType()->getY();
1344    } else {
1345        M = A->getType()->getY();
1346        K = A->getType()->getX();
1347    }
1348    if (TransB != RsBlasNoTrans) {
1349        N = B->getType()->getY();
1350    } else {
1351        N = B->getType()->getX();
1352    }
1353    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemm,
1354                           TransA, TransB, 0, 0, 0, M, N, K,
1355                           alpha.x, alpha.y, A->getID(), B->getID(),
1356                           beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1357}
1358
1359void ScriptIntrinsicBLAS::SSYMM(RsBlasSide Side, RsBlasUplo Uplo, float alpha,
1360                                const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) {
1361    //For SYMM, Matrix A should be symmetric
1362    if (A->getType()->getX() != A->getType()->getY()) {
1363        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1364    }
1365    validateL3(mRS, Element::F32(mRS), 0, 0, Side, A, B, C);
1366    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymm,
1367                                0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1368                                alpha, A->getID(), B->getID(),
1369                                beta, C->getID(), 0, 0, 0, 0);
1370}
1371
1372void ScriptIntrinsicBLAS::DSYMM(RsBlasSide Side, RsBlasUplo Uplo, double alpha,
1373                                const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) {
1374    if (A->getType()->getX() != A->getType()->getY()) {
1375        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1376    }
1377    validateL3(mRS, Element::F64(mRS), 0, 0, Side, A, B, C);
1378    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymm,
1379                                0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1380                                alpha, A->getID(), B->getID(),
1381                                beta, C->getID(), 0, 0, 0, 0);
1382}
1383
1384void ScriptIntrinsicBLAS::CSYMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha,
1385                                const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) {
1386    if (A->getType()->getX() != A->getType()->getY()) {
1387        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1388    }
1389    validateL3(mRS, Element::F32_2(mRS), 0, 0, Side, A, B, C);
1390    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csymm,
1391                                 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1392                                 alpha.x, alpha.y, A->getID(), B->getID(),
1393                                 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1394}
1395
1396void ScriptIntrinsicBLAS::ZSYMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha,
1397                                const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) {
1398    if (A->getType()->getX() != A->getType()->getY()) {
1399        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1400    }
1401    validateL3(mRS, Element::F64_2(mRS), 0, 0, Side, A, B, C);
1402    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsymm,
1403                           0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1404                           alpha.x, alpha.y, A->getID(), B->getID(),
1405                           beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1406}
1407
1408void ScriptIntrinsicBLAS::SSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1409                                const sp<Allocation>& A, float beta, const sp<Allocation>& C) {
1410    validateL3(mRS, Element::F32(mRS), Trans, 0, 0, A, nullptr, C);
1411    int K = -1;
1412    if (Trans != RsBlasNoTrans) {
1413        K = A->getType()->getY();
1414    } else {
1415        K = A->getType()->getX();
1416    }
1417    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyrk,
1418                                Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1419                                alpha, A->getID(), 0,
1420                                beta, C->getID(), 0, 0, 0, 0);
1421}
1422
1423void ScriptIntrinsicBLAS::DSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1424                                const sp<Allocation>& A, double beta, const sp<Allocation>& C) {
1425    validateL3(mRS, Element::F64(mRS), Trans, 0, 0, A, nullptr, C);
1426    int K = -1;
1427    if (Trans != RsBlasNoTrans) {
1428        K = A->getType()->getY();
1429    } else {
1430        K = A->getType()->getX();
1431    }
1432    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyrk,
1433                                Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1434                                alpha, A->getID(), 0,
1435                                beta, C->getID(), 0, 0, 0, 0);
1436}
1437
1438void ScriptIntrinsicBLAS::CSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1439                                const sp<Allocation>& A, Float2 beta, const sp<Allocation>& C) {
1440    validateL3(mRS, Element::F32_2(mRS), Trans, 0, 0, A, nullptr, C);
1441    int K = -1;
1442    if (Trans != RsBlasNoTrans) {
1443        K = A->getType()->getY();
1444    } else {
1445        K = A->getType()->getX();
1446    }
1447    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyrk,
1448                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1449                                 alpha.x, alpha.y, A->getID(), 0,
1450                                 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1451}
1452
1453void ScriptIntrinsicBLAS::ZSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1454                                const sp<Allocation>& A, Double2 beta, const sp<Allocation>& C) {
1455    validateL3(mRS, Element::F64_2(mRS), Trans, 0, 0, A, nullptr, C);
1456    int K = -1;
1457    if (Trans != RsBlasNoTrans) {
1458        K = A->getType()->getY();
1459    } else {
1460        K = A->getType()->getX();
1461    }
1462    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyrk,
1463                           Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1464                           alpha.x, alpha.y, A->getID(), 0,
1465                           beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1466}
1467
1468static void validateSYR2K(RS* mRS, const sp<const Element>& e, RsBlasTranspose Trans,
1469                          const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) {
1470    if (!A->getType()->getElement()->isCompatible(e) ||
1471        !B->getType()->getElement()->isCompatible(e) ||
1472        !C->getType()->getElement()->isCompatible(e)) {
1473        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1474    }
1475    int Cdim = -1;
1476    // A is n x k if no transpose, k x n if transpose
1477    // C is n x n
1478    if (Trans == RsBlasTrans) {
1479        // check columns versus C
1480        Cdim = A->getType()->getX();
1481    } else {
1482        // check rows versus C
1483        Cdim = A->getType()->getY();
1484    }
1485    if ((int)C->getType()->getX() != Cdim || (int)C->getType()->getY() != Cdim) {
1486        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid symmetric matrix in SYR2K");
1487    }
1488    // A dims == B dims
1489    if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) {
1490        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid A and B in SYR2K");
1491    }
1492}
1493
1494void ScriptIntrinsicBLAS::SSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1495                                 const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) {
1496    validateSYR2K(mRS, Element::F32(mRS), Trans, A, B, C);
1497    int K = -1;
1498    if (Trans != RsBlasNoTrans) {
1499        K = A->getType()->getY();
1500    } else {
1501        K = A->getType()->getX();
1502    }
1503    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2k,
1504                                Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1505                                alpha, A->getID(), B->getID(),
1506                                beta, C->getID(), 0, 0, 0, 0);
1507}
1508
1509void ScriptIntrinsicBLAS::DSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1510                                 const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) {
1511    validateSYR2K(mRS, Element::F64(mRS), Trans, A, B, C);
1512    int K = -1;
1513    if (Trans != RsBlasNoTrans) {
1514        K = A->getType()->getY();
1515    } else {
1516        K = A->getType()->getX();
1517    }
1518    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2k,
1519                                Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1520                                alpha, A->getID(), B->getID(),
1521                                beta, C->getID(), 0, 0, 0, 0);
1522}
1523
1524void ScriptIntrinsicBLAS::CSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1525                                 const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) {
1526    validateSYR2K(mRS, Element::F32_2(mRS), Trans, A, B, C);
1527    int K = -1;
1528    if (Trans != RsBlasNoTrans) {
1529        K = A->getType()->getY();
1530    } else {
1531        K = A->getType()->getX();
1532    }
1533    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyr2k,
1534                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1535                                 alpha.x, alpha.y, A->getID(), B->getID(),
1536                                 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1537}
1538
1539void ScriptIntrinsicBLAS::ZSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1540                                 const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) {
1541    validateSYR2K(mRS, Element::F64_2(mRS), Trans, A, B, C);
1542    int K = -1;
1543    if (Trans != RsBlasNoTrans) {
1544        K = A->getType()->getY();
1545    } else {
1546        K = A->getType()->getX();
1547    }
1548    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyr2k,
1549                           Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1550                           alpha.x, alpha.y, A->getID(), B->getID(),
1551                           beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1552}
1553
1554static void validateTRMM(RS* mRS, const sp<const Element>& e, RsBlasSide Side, RsBlasTranspose TransA,
1555                         const sp<Allocation>& A, const sp<Allocation>& B) {
1556    int aM = -1, aN = -1, bM = -1, bN = -1;
1557    if (!A->getType()->getElement()->isCompatible(e) ||
1558        !B->getType()->getElement()->isCompatible(e)) {
1559        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1560    }
1561
1562    aM = A->getType()->getY();
1563    aN = A->getType()->getX();
1564    if (aM != aN) {
1565        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with a non-symmetric matrix A");
1566    }
1567
1568    bM = B->getType()->getY();
1569    bN = B->getType()->getX();
1570    if (Side == RsBlasLeft) {
1571        if (aN != bM) {
1572            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices");
1573        }
1574    } else {
1575        if (bN != aM) {
1576            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices");
1577        }
1578    }
1579}
1580
1581void ScriptIntrinsicBLAS::STRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1582                                float alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1583    validateTRMM(mRS, Element::F32(mRS), Side, TransA, A, B);
1584    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmm,
1585                                TransA, 0, Side, Uplo, Diag,\
1586                                B->getType()->getY(), B->getType()->getX(), 0,
1587                                alpha, A->getID(), B->getID(), 0.f, 0, 0, 0, 0, 0);
1588}
1589
1590void ScriptIntrinsicBLAS::DTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1591                                double alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1592    validateTRMM(mRS, Element::F64(mRS), Side, TransA, A, B);
1593    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmm,
1594                                TransA, 0, Side, Uplo, Diag,
1595                                B->getType()->getY(), B->getType()->getX(), 0,
1596                                alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1597}
1598
1599void ScriptIntrinsicBLAS::CTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1600                                Float2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1601    validateTRMM(mRS, Element::F32_2(mRS), Side, TransA, A, B);
1602    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmm,
1603                                 TransA, 0, Side, Uplo, Diag,
1604                                 B->getType()->getY(), B->getType()->getX(), 0,
1605                                 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1606}
1607
1608void ScriptIntrinsicBLAS::ZTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1609                                Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1610    validateTRMM(mRS, Element::F64_2(mRS), Side, TransA, A, B);
1611    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmm,
1612                           TransA, 0, Side, Uplo, Diag,
1613                           B->getType()->getY(), B->getType()->getX(), 0,
1614                           alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1615}
1616
1617static void validateTRSM(RS* mRS, const sp<const Element>& e, RsBlasSide Side, RsBlasTranspose TransA,
1618                         const sp<Allocation>& A, const sp<Allocation>& B) {
1619    int adim = -1, bM = -1, bN = -1;
1620    if (!A->getType()->getElement()->isCompatible(e) ||
1621        !B->getType()->getElement()->isCompatible(e)) {
1622        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1623    }
1624    adim = A->getType()->getX();
1625    if (adim != (int)A->getType()->getY()) {
1626        // This may be unnecessary, the restriction could potentially be relaxed.
1627        // Allocation A needs to contain at least that symmetric matrix but could theoretically
1628        // be larger for now we assume adapters are sufficient, will reevaluate in the future.
1629        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with a non-symmetric matrix A");
1630    }
1631    bM = B->getType()->getY();
1632    bN = B->getType()->getX();
1633    if (Side == RsBlasLeft) {
1634        // A is M*M
1635        if (adim != bM) {
1636            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions");
1637        }
1638    } else {
1639        // A is N*N
1640        if (adim != bN) {
1641            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions");
1642        }
1643    }
1644}
1645
1646void ScriptIntrinsicBLAS::STRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1647                                float alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1648    validateTRSM(mRS, Element::F32(mRS), Side, TransA, A, B);
1649    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsm,
1650                                TransA, 0, Side, Uplo, Diag,
1651                                B->getType()->getY(), B->getType()->getX(), 0,
1652                                alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1653}
1654
1655void ScriptIntrinsicBLAS::DTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1656                                double alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1657    validateTRSM(mRS, Element::F64(mRS), Side, TransA, A, B);
1658    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsm,
1659                                TransA, 0, Side, Uplo, Diag,
1660                                B->getType()->getY(), B->getType()->getX(), 0,
1661                                alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1662}
1663
1664void ScriptIntrinsicBLAS::CTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1665                                Float2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1666    validateTRSM(mRS, Element::F32_2(mRS), Side, TransA, A, B);
1667    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsm,
1668                                 TransA, 0, Side, Uplo, Diag,
1669                                 B->getType()->getY(), B->getType()->getX(), 0,
1670                                 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1671}
1672
1673void ScriptIntrinsicBLAS::ZTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1674                                Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1675    validateTRSM(mRS, Element::F64_2(mRS), Side, TransA, A, B);
1676    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsm,
1677                           TransA, 0, Side, Uplo, Diag,
1678                           B->getType()->getY(), B->getType()->getX(), 0,
1679                           alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1680}
1681
1682static void validateHEMM(RS* mRS, const sp<const Element>& e, RsBlasSide Side,
1683                         const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) {
1684    if (!A->getType()->getElement()->isCompatible(e) ||
1685        !B->getType()->getElement()->isCompatible(e) ||
1686        !C->getType()->getElement()->isCompatible(e)) {
1687        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1688    }
1689
1690    // A must be square; can potentially be relaxed similar to TRSM
1691    int adim = A->getType()->getX();
1692    if (adim != (int)A->getType()->getY()) {
1693        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with non-square A");
1694    }
1695    if ((Side == RsBlasLeft && adim != (int)B->getType()->getY()) ||
1696        (Side == RsBlasRight && adim != (int)B->getType()->getX())) {
1697        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with invalid B");
1698    }
1699    if (B->getType()->getX() != C->getType()->getX() ||
1700        B->getType()->getY() != C->getType()->getY()) {
1701        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with mismatched B and C");
1702    }
1703}
1704
1705void ScriptIntrinsicBLAS::CHEMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha,
1706                                const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) {
1707    validateHEMM(mRS, Element::F32_2(mRS), Side, A, B, C);
1708    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemm,
1709                                 0, 0, Side, Uplo, 0,
1710                                 C->getType()->getY(), C->getType()->getX(), 0,
1711                                 alpha.x, alpha.y, A->getID(), B->getID(),
1712                                 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1713}
1714
1715void ScriptIntrinsicBLAS::ZHEMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha,
1716                                const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) {
1717    validateHEMM(mRS, Element::F64_2(mRS), Side, A, B, C);
1718    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemm,
1719                           0, 0, Side, Uplo, 0,
1720                           C->getType()->getY(), C->getType()->getX(), 0,
1721                           alpha.x, alpha.y, A->getID(), B->getID(),
1722                           beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1723}
1724
1725static void validateHERK(RS* mRS, const sp<const Element>& e, RsBlasTranspose Trans,
1726                         const sp<Allocation>& A, const sp<Allocation>& C) {
1727    if (!A->getType()->getElement()->isCompatible(e) ||
1728        !C->getType()->getElement()->isCompatible(e)) {
1729        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1730    }
1731    if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) {
1732        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose");
1733    }
1734    int cdim = C->getType()->getX();
1735    if (cdim != (int)C->getType()->getY()) {
1736        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with non-square C");
1737    }
1738    if (Trans == RsBlasNoTrans) {
1739        if (cdim != (int)A->getType()->getY()) {
1740            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A");
1741        }
1742    } else {
1743        if (cdim != (int)A->getType()->getX()) {
1744            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A");
1745        }
1746    }
1747}
1748
1749void ScriptIntrinsicBLAS::CHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1750                                const sp<Allocation>& A, float beta, const sp<Allocation>& C) {
1751    validateHERK(mRS, Element::F32_2(mRS), Trans, A, C);
1752    int k = 0;
1753    if (Trans == RsBlasConjTrans) {
1754        k = A->getType()->getY();
1755    } else {
1756        k = A->getType()->getX();
1757    }
1758    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cherk,
1759                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1760                                 alpha, 0, A->getID(), 0,
1761                                 beta, 0, C->getID(), 0, 0, 0, 0);
1762}
1763
1764void ScriptIntrinsicBLAS::ZHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1765                                const sp<Allocation>& A, double beta, const sp<Allocation>& C) {
1766    validateHERK(mRS, Element::F64_2(mRS), Trans, A, C);
1767    int k = 0;
1768    if (Trans == RsBlasConjTrans) {
1769        k = A->getType()->getY();
1770    } else {
1771        k = A->getType()->getX();
1772    }
1773    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zherk,
1774                           Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1775                           alpha, 0, A->getID(), 0,
1776                           beta, 0, C->getID(), 0, 0, 0, 0);
1777}
1778
1779static void validateHER2K(RS* mRS, const sp<const Element>& e, RsBlasTranspose Trans,
1780                          const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) {
1781    if (!A->getType()->getElement()->isCompatible(e) ||
1782        !B->getType()->getElement()->isCompatible(e) ||
1783        !C->getType()->getElement()->isCompatible(e)) {
1784        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1785    }
1786    if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) {
1787        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose");
1788    }
1789    int cdim = C->getType()->getX();
1790    if (cdim != (int)C->getType()->getY()) {
1791        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with non-square C");
1792    }
1793    if (Trans == RsBlasNoTrans) {
1794        if ((int)A->getType()->getY() != cdim) {
1795            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices");
1796        }
1797    } else {
1798        if ((int)A->getType()->getX() != cdim) {
1799            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices");
1800        }
1801    }
1802    if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) {
1803        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid A and B matrices");
1804    }
1805}
1806
1807void ScriptIntrinsicBLAS::CHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1808                                 const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) {
1809    validateHER2K(mRS, Element::F32_2(mRS), Trans, A, B, C);
1810    int k = 0;
1811    if (Trans == RsBlasNoTrans) {
1812        k = A->getType()->getX();
1813    } else {
1814        k = A->getType()->getY();
1815    }
1816    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2k,
1817                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1818                                 alpha.x, alpha.y, A->getID(), B->getID(),
1819                                 beta, 0, C->getID(), 0, 0, 0, 0);
1820}
1821
1822void ScriptIntrinsicBLAS::ZHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1823                                 const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) {
1824    validateHER2K(mRS, Element::F64_2(mRS), Trans, A, B, C);
1825    int k = 0;
1826    if (Trans == RsBlasNoTrans) {
1827        k = A->getType()->getX();
1828    } else {
1829        k = A->getType()->getY();
1830    }
1831    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2k,
1832                           Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1833                           alpha.x, alpha.y, A->getID(), B->getID(),
1834                           beta, 0, C->getID(), 0, 0, 0, 0);
1835}
1836
1837
1838
1839void ScriptIntrinsicBLAS::BNNM(const sp<Allocation>& A, int a_offset, const sp<Allocation>& B, int b_offset,
1840                               const sp<Allocation>& C, int c_offset, int c_mult) {
1841    validateL3(mRS, Element::U8(mRS), RsBlasNoTrans, RsBlasTrans, 0, A, B, C);
1842
1843    if (a_offset < 0 || a_offset > 255) {
1844        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid a_offset passed to BNNM");
1845    }
1846    if (b_offset < 0 || b_offset > 255) {
1847        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid b_offset passed to BNNM");
1848    }
1849    int M = -1, N = -1, K = -1;
1850    M = A->getType()->getY();
1851    N = B->getType()->getY();
1852    K = A->getType()->getX();
1853
1854    nScriptIntrinsicBLAS_BNNM(mRS, mRS->getContext(), getID(), M, N, K, A->getID(), a_offset,
1855                              B->getID(), b_offset, C->getID(), c_offset, c_mult);
1856}
1857