1/*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17
18#include "RenderScript.h"
19#include "rsCppInternal.h"
20
21#define NELEM(m) (sizeof(m) / sizeof((m)[0]))
22
23using namespace android;
24using namespace RSC;
25
26// ScriptIntrinsicBLAS APIS
27ScriptIntrinsicBLAS::ScriptIntrinsicBLAS(sp<RS> rs, sp<const Element> e)
28    : ScriptIntrinsic(rs, RS_SCRIPT_INTRINSIC_ID_BLAS, e) {
29
30}
31
32sp<ScriptIntrinsicBLAS> ScriptIntrinsicBLAS::create(sp<RS> rs) {
33    return new ScriptIntrinsicBLAS(rs, Element::U32(rs));
34}
35
// Selects which alpha/beta representation setUpBLASCall() stores in the call
// descriptor. The numeric values are spelled out; they match the implicit ones.
enum RsBlasDataType {
    SINGLE = 0,
    DOUBLE = 1,
    SINGLE_COMPLEX = 2,
    DOUBLE_COMPLEX = 3
};
42
43static RsBlasCall
44setUpBLASCall(RsBlasDataType dataType, RsBlasFunction func,
45              int TransA, int TransB, int Side, int Uplo, int Diag,
46              int M, int N, int K, int incX, int incY, int KL, int KU,
47              float alphaF, float betaF, double alphaD, double betaD,
48              float alphaCX, float alphaCY, float betaCX, float betaCY,
49              double alphaZX, double alphaZY, double betaZX, double betaZY
50              ) {
51    RsBlasCall call;
52    memset(&call, 0, sizeof(call));
53    call.func = func;
54    call.transA = (RsBlasTranspose)TransA;
55    call.transB = (RsBlasTranspose)TransB;
56    call.side = (RsBlasSide)Side;
57    call.uplo = (RsBlasUplo)Uplo;
58    call.diag = (RsBlasDiag)Diag;
59    call.M = M;
60    call.N = N;
61    call.K = K;
62
63    switch (dataType) {
64        case SINGLE:
65            // For Single-precision BLAS.
66            call.alpha.f = alphaF;
67            call.beta.f = betaF;
68            break;
69        case DOUBLE:
70            // For Double-precision BLAS.
71            call.alpha.d = alphaD;
72            call.beta.d = betaD;
73            break;
74        case SINGLE_COMPLEX:
75            // For Single-precision complex BLAS.
76            call.alpha.c.r = alphaCX;
77            call.alpha.c.i = alphaCY;
78            call.beta.c.r = betaCX;
79            call.beta.c.i = betaCY;
80            break;
81        case DOUBLE_COMPLEX:
82            // For Double-precision complex BLAS.
83            call.alpha.z.r = alphaZX;
84            call.alpha.z.i = alphaZY;
85            call.beta.z.r = betaZX;
86            call.beta.z.i = betaZY;
87            break;
88        default:
89            break;
90    }
91
92    call.incX = incX;
93    call.incY = incY;
94    call.KL = KL;
95    call.KU = KU;
96
97    return call;
98}
99
100static void
101nScriptIntrinsicBLAS_Single(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
102                            int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
103                            float alpha, RsAllocation A, RsAllocation B,
104                            float beta, RsAllocation C, int incX, int incY, int KL, int KU) {
105    RsBlasCall call = setUpBLASCall(SINGLE, func, TransA, TransB, Side, Uplo, Diag,
106                                    M, N, K, incX, incY, KL, KU, alpha, beta, 0.0, 0.0,
107                                    0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0);
108    RsAllocation in_allocs[3] = {A, B, C};
109    tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
110                                                      &call, sizeof(call), nullptr, 0));
111}
112
113
114static void
115nScriptIntrinsicBLAS_Double(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
116                            int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
117                            double alpha, RsAllocation A, RsAllocation B,
118                            double beta, RsAllocation C, int incX, int incY, int KL, int KU) {
119    RsBlasCall call = setUpBLASCall(DOUBLE, func, TransA, TransB, Side, Uplo, Diag,
120                                    M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, alpha, beta,
121                                    0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0);
122    RsAllocation in_allocs[3] = {A, B, C};
123    tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
124                                                      &call, sizeof(call), nullptr, 0));
125}
126
127static void
128nScriptIntrinsicBLAS_Complex(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
129                             int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
130                             float alphaX, float alphaY, RsAllocation A, RsAllocation B,
131                             float betaX, float betaY, RsAllocation C, int incX, int incY, int KL, int KU) {
132    RsBlasCall call = setUpBLASCall(SINGLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag,
133                                    M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0,
134                                    alphaX, alphaY, betaX, betaY, 0.0, 0.0, 0.0, 0.0);
135    RsAllocation in_allocs[3] = {A, B, C};
136    tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
137                                                      &call, sizeof(call), nullptr, 0));
138}
139
140static void
141nScriptIntrinsicBLAS_Z(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
142                       int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
143                       double alphaX, double alphaY, RsAllocation A, RsAllocation B,
144                       double betaX, double betaY, RsAllocation C, int incX, int incY, int KL, int KU) {
145    RsBlasCall call = setUpBLASCall(DOUBLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag,
146                                    M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0,
147                                    0.0f, 0.0f, 0.0f, 0.0f, alphaX, alphaY, betaX, betaY);
148    RsAllocation in_allocs[3] = {A, B, C};
149    tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
150                                                      &call, sizeof(call), nullptr, 0));
151}
152
153
154static void
155nScriptIntrinsicBLAS_BNNM(RS* mRS, RsContext con, RsScript id, int M, int N, int K,
156                          RsAllocation A, int a_offset, RsAllocation B, int b_offset,
157                          RsAllocation C, int c_offset, int c_mult_int) {
158    RsBlasCall call;
159    memset(&call, 0, sizeof(call));
160    call.func = RsBlas_bnnm;
161    call.M = M;
162    call.N = N;
163    call.K = K;
164    call.a_offset = a_offset & 0xFF;
165    call.b_offset = b_offset & 0xFF;
166    call.c_offset = c_offset;
167    call.c_mult_int = c_mult_int;
168
169    RsAllocation in_allocs[3] = {A, B, C};
170    tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
171                                                      &call, sizeof(call), nullptr, 0));
172}
173
174/**
175 * Level 2 BLAS
176 */
177static void validateGEMV(RS* mRS, sp<const Element> e, RsBlasTranspose TransA, sp<Allocation> A,
178                         sp<Allocation> X, int incX, sp<Allocation> Y, int incY) {
179    int M = A->getType()->getY();
180    int N = A->getType()->getX();
181    if (!A->getType()->getElement()->isCompatible(e) ||
182        !X->getType()->getElement()->isCompatible(e) ||
183        !Y->getType()->getElement()->isCompatible(e)) {
184        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
185    }
186    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
187        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
188    }
189
190    if (incX <= 0 || incY <= 0) {
191        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
192    }
193    int expectedXDim = -1, expectedYDim = -1;
194    if (TransA == RsBlasNoTrans) {
195        expectedXDim = 1 + (N - 1) * incX;
196        expectedYDim = 1 + (M - 1) * incY;
197    } else {
198        expectedXDim = 1 + (M - 1) * incX;
199        expectedYDim = 1 + (N - 1) * incY;
200    }
201    if ((int)X->getType()->getX() != expectedXDim ||
202        (int)Y->getType()->getX() != expectedYDim) {
203        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GEMV");
204    }
205}
206
207void ScriptIntrinsicBLAS::SGEMV(RsBlasTranspose TransA, float alpha, sp<Allocation> A, sp<Allocation> X,
208                                int incX, float beta, sp<Allocation> Y, int incY) {
209    validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY);
210    int M = A->getType()->getY();
211    int N = A->getType()->getX();
212    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemv,
213                                TransA, 0, 0, 0, 0, M, N, 0,
214                                alpha, A->getID(), X->getID(),
215                                beta, Y->getID(), incX, incY, 0, 0);
216}
217
218void ScriptIntrinsicBLAS::DGEMV(RsBlasTranspose TransA, double alpha, sp<Allocation> A, sp<Allocation> X,
219                                int incX, double beta, sp<Allocation> Y, int incY) {
220    validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY);
221    int M = A->getType()->getY();
222    int N = A->getType()->getX();
223    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemv,
224                                TransA, 0, 0, 0, 0, M, N, 0,
225                                alpha, A->getID(), X->getID(),
226                                beta, Y->getID(), incX, incY, 0, 0);
227}
228
229void ScriptIntrinsicBLAS::CGEMV(RsBlasTranspose TransA, Float2 alpha, sp<Allocation> A, sp<Allocation> X,
230                                int incX, Float2 beta, sp<Allocation> Y, int incY) {
231    validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY);
232    int M = A->getType()->getY();
233    int N = A->getType()->getX();
234    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemv,
235                                 TransA, 0, 0, 0, 0, M, N, 0,
236                                 alpha.x, alpha.y, A->getID(), X->getID(),
237                                 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
238}
239
240void ScriptIntrinsicBLAS::ZGEMV(RsBlasTranspose TransA, Double2 alpha, sp<Allocation> A, sp<Allocation> X,
241                                int incX, Double2 beta, sp<Allocation> Y, int incY) {
242    validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY);
243    int M = A->getType()->getY();
244    int N = A->getType()->getX();
245    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemv,
246                           TransA, 0, 0, 0, 0, M, N, 0,
247                           alpha.x, alpha.y, A->getID(), X->getID(),
248                           beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
249}
250
251void ScriptIntrinsicBLAS::SGBMV(RsBlasTranspose TransA, int KL, int KU, float alpha, sp<Allocation> A,
252                                sp<Allocation> X, int incX, float beta, sp<Allocation> Y, int incY) {
253    // GBMV has the same validation requirements as GEMV + KL and KU >= 0
254    validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY);
255    if (KL < 0 || KU < 0) {
256        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
257    }
258    int M = A->getType()->getY();
259    int N = A->getType()->getX();
260
261    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgbmv,
262                                TransA, 0, 0, 0, 0, M, N, 0,
263                                alpha, A->getID(), X->getID(),
264                                beta, Y->getID(), incX, incY, KL, KU);
265}
266
267void ScriptIntrinsicBLAS::DGBMV(RsBlasTranspose TransA, int KL, int KU, double alpha, sp<Allocation> A,
268                                sp<Allocation> X, int incX, double beta, sp<Allocation> Y, int incY) {
269    // GBMV has the same validation requirements as GEMV + KL and KU >= 0
270    validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY);
271    if (KL < 0 || KU < 0) {
272        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
273    }
274    int M = A->getType()->getY();
275    int N = A->getType()->getX();
276
277    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgbmv,
278                                TransA, 0, 0, 0, 0, M, N, 0,
279                                alpha, A->getID(), X->getID(),
280                                beta, Y->getID(), incX, incY, KL, KU);
281}
282
283void ScriptIntrinsicBLAS::CGBMV(RsBlasTranspose TransA, int KL, int KU, Float2 alpha, sp<Allocation> A,
284                                sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
285    // GBMV has the same validation requirements as GEMV + KL and KU >= 0
286    validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY);
287    if (KL < 0 || KU < 0) {
288        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
289    }
290    int M = A->getType()->getY();
291    int N = A->getType()->getX();
292
293    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgbmv,
294                                 TransA, 0, 0, 0, 0, M, N, 0,
295                                 alpha.x, alpha.y, A->getID(), X->getID(),
296                                 beta.x, beta.y, Y->getID(), incX, incY, KL, KU);
297}
298
299void ScriptIntrinsicBLAS::ZGBMV(RsBlasTranspose TransA, int KL, int KU, Double2 alpha, sp<Allocation> A,
300                                sp<Allocation> X, int incX, Double2 beta, sp<Allocation> Y, int incY) {
301    // GBMV has the same validation requirements as GEMV + KL and KU >= 0
302    validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY);
303    if (KL < 0 || KU < 0) {
304        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
305    }
306    int M = A->getType()->getY();
307    int N = A->getType()->getX();
308
309    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgbmv,
310                           TransA, 0, 0, 0, 0, M, N, 0,
311                           alpha.x, alpha.y, A->getID(), X->getID(),
312                           beta.x, beta.y, Y->getID(), incX, incY, KL, KU);
313}
314
315static void validateTRMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, RsBlasTranspose TransA,
316                         RsBlasDiag Diag, sp<Allocation> A, sp<Allocation> X, int incX) {
317    int N = A->getType()->getY();
318    if ((int)A->getType()->getX() != N) {
319        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for TRMV");
320    }
321    if (!A->getType()->getElement()->isCompatible(e) ||
322        !X->getType()->getElement()->isCompatible(e)) {
323        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
324    }
325    if (X->getType()->getY() > 1) {
326        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
327    }
328
329    if (incX <= 0) {
330        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
331    }
332    int expectedXDim = 1 + (N - 1) * incX;
333    if ((int)X->getType()->getX() != expectedXDim) {
334        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TRMV");
335    }
336}
337
338static int validateTPMV(RS* mRS, sp<const Element> e,  RsBlasUplo Uplo, RsBlasTranspose TransA,
339                        RsBlasDiag Diag, sp<Allocation> Ap, sp<Allocation> X, int incX) {
340    if (!Ap->getType()->getElement()->isCompatible(e) ||
341        !X->getType()->getElement()->isCompatible(e)) {
342        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
343    }
344    if (X->getType()->getY() > 1) {
345        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
346    }
347
348    if (Ap->getType()->getY() > 1) {
349        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
350    }
351
352    int N = sqrt((double)Ap->getType()->getX() * 2);
353    if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
354        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
355    }
356    if (incX <= 0) {
357        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
358    }
359    int expectedXDim = 1 + (N - 1) * incX;
360    if ((int)X->getType()->getX() != expectedXDim) {
361        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TPMV");
362    }
363
364    return N;
365}
366
367
368void ScriptIntrinsicBLAS::STRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
369                                sp<Allocation> A, sp<Allocation> X, int incX) {
370    validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
371    int N = A->getType()->getY();
372    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmv,
373                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
374                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
375}
376
377void ScriptIntrinsicBLAS::DTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
378                                sp<Allocation> A, sp<Allocation> X, int incX) {
379    validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
380    int N = A->getType()->getY();
381    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmv,
382                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
383                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
384}
385
386void ScriptIntrinsicBLAS::CTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
387                                sp<Allocation> A, sp<Allocation> X, int incX) {
388    validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
389    int N = A->getType()->getY();
390    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmv,
391                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
392                                 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
393}
394
395void ScriptIntrinsicBLAS::ZTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
396                                sp<Allocation> A, sp<Allocation> X, int incX) {
397    validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
398    int N = A->getType()->getY();
399    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmv,
400                           TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
401                           A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
402}
403
404void ScriptIntrinsicBLAS::STBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
405                                int K, sp<Allocation> A, sp<Allocation> X, int incX) {
406    // TBMV has the same requirements as TRMV + K >= 0
407    if (K < 0) {
408        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
409    }
410    validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
411    int N = A->getType()->getY();
412    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbmv,
413                                TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
414                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
415}
416
417void ScriptIntrinsicBLAS::DTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
418                                int K, sp<Allocation> A, sp<Allocation> X, int incX) {
419    // TBMV has the same requirements as TRMV + K >= 0
420    if (K < 0) {
421        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
422    }
423    validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
424    int N = A->getType()->getY();
425    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbmv,
426                                TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
427                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
428}
429
430void ScriptIntrinsicBLAS::CTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
431                                int K, sp<Allocation> A, sp<Allocation> X, int incX) {
432    // TBMV has the same requirements as TRMV + K >= 0
433    if (K < 0) {
434        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
435    }
436    validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
437    int N = A->getType()->getY();
438    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbmv,
439                                 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
440                                 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
441}
442
443void ScriptIntrinsicBLAS::ZTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
444                                int K, sp<Allocation> A, sp<Allocation> X, int incX) {
445    // TBMV has the same requirements as TRMV + K >= 0
446    if (K < 0) {
447        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
448    }
449    validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
450    int N = A->getType()->getY();
451    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbmv,
452                           TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
453                           A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
454}
455
456void ScriptIntrinsicBLAS::STPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
457                                sp<Allocation> Ap, sp<Allocation> X, int incX) {
458    int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
459    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpmv,
460                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
461                                Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
462}
463
464void ScriptIntrinsicBLAS::DTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
465                                sp<Allocation> Ap, sp<Allocation> X, int incX) {
466    int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
467    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpmv,
468                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
469                                Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
470}
471
472void ScriptIntrinsicBLAS::CTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
473                                sp<Allocation> Ap,  sp<Allocation> X,  int incX) {
474    int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
475    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpmv,
476                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
477                                 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
478}
479
480void ScriptIntrinsicBLAS::ZTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
481                                sp<Allocation> Ap, sp<Allocation> X, int incX) {
482    int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
483    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpmv,
484                           TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
485                           Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
486}
487
488void ScriptIntrinsicBLAS::STRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
489                                sp<Allocation> A, sp<Allocation> X, int incX) {
490    // TRSV is the same as TRMV
491    validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
492    int N = A->getType()->getY();
493    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsv,
494                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
495                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
496}
497
498void ScriptIntrinsicBLAS::DTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
499                                sp<Allocation> A,  sp<Allocation> X,  int incX) {
500    // TRSV is the same as TRMV
501    validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
502    int N = A->getType()->getY();
503    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsv,
504                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
505                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
506
507}
508
509void ScriptIntrinsicBLAS::CTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
510                                sp<Allocation> A, sp<Allocation> X, int incX) {
511    // TRSV is the same as TRMV
512    validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
513    int N = A->getType()->getY();
514    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsv,
515                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
516                                 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
517
518}
519
520void ScriptIntrinsicBLAS::ZTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
521                                sp<Allocation> A, sp<Allocation> X, int incX) {
522    // TRSV is the same as TRMV
523    validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
524    int N = A->getType()->getY();
525    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsv,
526                           TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
527                           A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
528
529}
530
531void ScriptIntrinsicBLAS::STBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
532                                int K, sp<Allocation> A, sp<Allocation> X, int incX) {
533    // TBSV is the same as TRMV + K >= 0
534    validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
535    int N = A->getType()->getY();
536    if (K < 0) {
537        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
538    }
539    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbsv,
540                                TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
541                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
542}
543
544void ScriptIntrinsicBLAS::DTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
545                                int K, sp<Allocation> A, sp<Allocation> X, int incX) {
546    // TBSV is the same as TRMV + K >= 0
547    validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
548    int N = A->getType()->getY();
549    if (K < 0) {
550        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
551    }
552    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbsv,
553                                TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
554                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
555}
556
557void ScriptIntrinsicBLAS::CTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
558                                int K, sp<Allocation> A, sp<Allocation> X, int incX) {
559    // TBSV is the same as TRMV + K >= 0
560    validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
561    int N = A->getType()->getY();
562    if (K < 0) {
563        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
564    }
565    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbsv,
566                                 TransA, 0, 0, Uplo, Diag, 0, N, K,
567                                 0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
568}
569
570void ScriptIntrinsicBLAS::ZTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
571                                int K, sp<Allocation> A, sp<Allocation> X, int incX) {
572    // TBSV is the same as TRMV + K >= 0
573    validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
574    int N = A->getType()->getY();
575    if (K < 0) {
576        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
577    }
578    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbsv,
579                           TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
580                           A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
581}
582
583void ScriptIntrinsicBLAS::STPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
584                                sp<Allocation> Ap, sp<Allocation> X, int incX) {
585    // TPSV is same as TPMV
586    int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
587    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpsv,
588                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
589                                Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
590}
591
592void ScriptIntrinsicBLAS::DTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
593                                sp<Allocation> Ap, sp<Allocation> X, int incX) {
594    // TPSV is same as TPMV
595    int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
596    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpsv,
597                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
598                                Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
599}
600
601void ScriptIntrinsicBLAS::CTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
602                                sp<Allocation> Ap, sp<Allocation> X, int incX) {
603    // TPSV is same as TPMV
604    int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
605    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpsv,
606                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
607                                 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
608}
609
610void ScriptIntrinsicBLAS::ZTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
611                                sp<Allocation> Ap, sp<Allocation> X, int incX) {
612    // TPSV is same as TPMV
613    int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
614    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpsv,
615                           TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
616                           Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
617}
618
619/**
620 * Level 2, S and D only
621 */
622static int validateSYMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> A,
623                        sp<Allocation> X, sp<Allocation> Y, int incX, int incY) {
624    int N = A->getType()->getY();
625    if ((int)A->getType()->getX() != N) {
626        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for SYMV");
627    }
628    if (!A->getType()->getElement()->isCompatible(e) ||
629        !X->getType()->getElement()->isCompatible(e) ||
630        !Y->getType()->getElement()->isCompatible(e) ) {
631        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
632    }
633    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
634        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
635    }
636
637    if (incX <= 0 || incY <= 0) {
638        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
639    }
640    int expectedXDim = 1 + (N - 1) * incX;
641    if ((int)X->getType()->getX() != expectedXDim) {
642        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV");
643    }
644    int expectedYDim = 1 + (N - 1) * incY;
645    if ((int)Y->getType()->getX() != expectedYDim) {
646        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV");
647    }
648    return N;
649}
650static int validateSPMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> Ap,
651                        sp<Allocation> X, int incX, sp<Allocation> Y, int incY) {
652    if (!Ap->getType()->getElement()->isCompatible(e) ||
653        !X->getType()->getElement()->isCompatible(e) ||
654        !Y->getType()->getElement()->isCompatible(e)) {
655        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
656    }
657    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
658        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
659    }
660
661    if (Ap->getType()->getY() > 1) {
662        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
663    }
664
665    int N = sqrt((double)Ap->getType()->getX() * 2);
666    if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
667        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
668    }
669    if (incX <= 0 || incY <= 0) {
670        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
671    }
672    int expectedXDim = 1 + (N - 1) * incX;
673    if ((int)X->getType()->getX() != expectedXDim) {
674        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV");
675    }
676    int expectedYDim = 1 + (N - 1) * incY;
677    if ((int)Y->getType()->getX() != expectedYDim) {
678        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV");
679    }
680
681    return N;
682}
683static void validateGER(RS* mRS, sp<const Element> e, sp<Allocation> X, int incX,
684                        sp<Allocation> Y, int incY, sp<Allocation> A) {
685    if (!A->getType()->getElement()->isCompatible(e) ||
686        !X->getType()->getElement()->isCompatible(e) ||
687        !Y->getType()->getElement()->isCompatible(e) ) {
688        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
689    }
690
691    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
692        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
693    }
694
695    int M = A->getType()->getY();
696    int N = A->getType()->getX();
697
698    if (N < 1 || M < 1) {
699        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "M and N must be 1 or greater for GER");
700    }
701    if (incX <= 0 || incY <= 0) {
702        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
703    }
704    int expectedXDim = 1 + (M - 1) * incX;
705    if ((int)X->getType()->getX() != expectedXDim) {
706        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER");
707    }
708    int expectedYDim = 1 + (N - 1) * incY;
709    if ((int)Y->getType()->getX() != expectedYDim) {
710        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER");
711    }
712
713
714}
715static int validateSYR(RS* mRS, sp<const Element> e, RsBlasUplo Uplo,
716                       sp<Allocation> X, int incX, sp<Allocation> A) {
717    if (!A->getType()->getElement()->isCompatible(e) ||
718        !X->getType()->getElement()->isCompatible(e)) {
719        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
720    }
721
722    int N = A->getType()->getX();
723
724    if (X->getType()->getY() > 1) {
725        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
726    }
727    if (N != (int)A->getType()->getY()) {
728        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix");
729    }
730    if (incX <= 0) {
731        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
732    }
733    int expectedXDim = 1 + (N - 1) * incX;
734    if ((int)X->getType()->getX() != expectedXDim) {
735        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR");
736    }
737    return N;
738}
739static int validateSPR(RS* mRS, sp<const Element> e, RsBlasUplo Uplo,
740                       sp<Allocation> X, int incX, sp<Allocation> Ap) {
741    if (!Ap->getType()->getElement()->isCompatible(e) ||
742        !X->getType()->getElement()->isCompatible(e)) {
743        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
744    }
745    if (X->getType()->getY() > 1) {
746        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
747    }
748
749    if (Ap->getType()->getY() > 1) {
750        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
751    }
752
753    int N = sqrt((double)Ap->getType()->getX() * 2);
754    if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
755        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
756    }
757    if (incX <= 0) {
758        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
759    }
760    int expectedXDim = 1 + (N - 1) * incX;
761    if ((int)X->getType()->getX() != expectedXDim) {
762        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR");
763    }
764
765    return N;
766}
767
768static int validateSYR2(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> X,
769                        int incX, sp<Allocation> Y, int incY, sp<Allocation> A) {
770    if (!A->getType()->getElement()->isCompatible(e) ||
771        !X->getType()->getElement()->isCompatible(e) ||
772        !Y->getType()->getElement()->isCompatible(e)) {
773        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
774    }
775
776    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
777        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
778    }
779
780    int N = A->getType()->getX();
781
782    if (N != (int)A->getType()->getY()) {
783        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix");
784    }
785    if (incX <= 0 || incY <= 0) {
786        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
787    }
788    int expectedXDim = 1 + (N - 1) * incX;
789    int expectedYDim = 1 + (N - 1) * incY;
790    if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) {
791        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR");
792    }
793    return N;
794
795}
796static int validateSPR2(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> X,
797                        int incX, sp<Allocation> Y, int incY, sp<Allocation> Ap) {
798    if (!Ap->getType()->getElement()->isCompatible(e) ||
799        !X->getType()->getElement()->isCompatible(e) ||
800        !Y->getType()->getElement()->isCompatible(e)) {
801        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
802    }
803    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
804        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
805    }
806
807    if (Ap->getType()->getY() > 1) {
808        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
809    }
810
811    int N = sqrt((double)Ap->getType()->getX() * 2);
812    if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
813        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
814    }
815    if (incX <= 0 || incY <= 0) {
816        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
817    }
818    int expectedXDim = 1 + (N - 1) * incX;
819    int expectedYDim = 1 + (N - 1) * incY;
820    if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) {
821        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR2");
822    }
823
824    return N;
825}
826
827void ScriptIntrinsicBLAS::SSYMV(RsBlasUplo Uplo, float alpha, sp<Allocation> A, sp<Allocation> X,
828                                int incX, float beta, sp<Allocation> Y, int incY) {
829    int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY);
830    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymv,
831                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
832                                A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
833}
834
835void ScriptIntrinsicBLAS::SSBMV(RsBlasUplo Uplo, int K, float alpha, sp<Allocation> A, sp<Allocation> X,
836                                int incX, float beta, sp<Allocation> Y, int incY) {
837    // SBMV is the same as SYMV + K >= 0
838    if (K < 0) {
839        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
840    }
841    int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY);
842    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssbmv,
843                                0, 0, 0, Uplo, 0, 0, N, K, alpha,
844                                A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
845}
846
847void ScriptIntrinsicBLAS::SSPMV(RsBlasUplo Uplo, float alpha, sp<Allocation> Ap, sp<Allocation> X,
848                                int incX, float beta, sp<Allocation> Y, int incY) {
849    int N = validateSPMV(mRS, Element::F32(mRS), Uplo, Ap, X, incX, Y, incY);
850    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspmv,
851                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
852                                Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
853}
854
855void ScriptIntrinsicBLAS::SGER(float alpha, sp<Allocation> X, int incX,
856                               sp<Allocation> Y, int incY, sp<Allocation> A) {
857    int M = A->getType()->getY();
858    int N = A->getType()->getX();
859    validateGER(mRS, Element::F32(mRS), X, incX, Y, incY, A);
860    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sger,
861                                0, 0, 0, 0, 0, M, N, 0, alpha,
862                                X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0);
863}
864
865void ScriptIntrinsicBLAS::SSYR(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
866                               int incX, sp<Allocation> A) {
867    int N = validateSYR(mRS, Element::F32(mRS), Uplo, X, incX, A);
868    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr,
869                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
870                                X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0);
871}
872
873void ScriptIntrinsicBLAS::SSPR(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
874                               int incX, sp<Allocation> Ap) {
875    int N = validateSPR(mRS, Element::F32(mRS), Uplo, X, incX, Ap);
876    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr,
877                                0, 0, 0, Uplo, 0, 0, N, 0,
878                                alpha, X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0);
879}
880
881void ScriptIntrinsicBLAS::SSYR2(RsBlasUplo Uplo, float alpha, sp<Allocation> X, int incX,
882                                sp<Allocation> Y, int incY, sp<Allocation> A) {
883    int N = validateSYR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, A);
884    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2,
885                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
886                                X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0);
887}
888
889void ScriptIntrinsicBLAS::SSPR2(RsBlasUplo Uplo, float alpha, sp<Allocation> X, int incX,
890                                sp<Allocation> Y, int incY, sp<Allocation> Ap) {
891    int N = validateSPR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, Ap);
892    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr2,
893                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
894                                X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0);
895}
896
897void ScriptIntrinsicBLAS::DSYMV(RsBlasUplo Uplo, double alpha, sp<Allocation> A, sp<Allocation> X,
898                                int incX, double beta, sp<Allocation> Y, int incY) {
899    int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY);
900    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymv,
901                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
902                                A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
903}
904
905void ScriptIntrinsicBLAS::DSBMV(RsBlasUplo Uplo, int K, double alpha, sp<Allocation> A, sp<Allocation> X,
906                                int incX, double beta, sp<Allocation> Y, int incY) {
907    // SBMV is the same as SYMV + K >= 0
908    if (K < 0) {
909        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
910    }
911    int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY);
912    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsbmv,
913                                0, 0, 0, Uplo, 0, 0, N, K, alpha,
914                                A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
915}
916
917void ScriptIntrinsicBLAS::DSPMV(RsBlasUplo Uplo, double alpha, sp<Allocation> Ap, sp<Allocation> X,
918                                int incX, double beta, sp<Allocation> Y, int incY) {
919    int N = validateSPMV(mRS, Element::F64(mRS), Uplo, Ap, X, incX, Y, incY);
920    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspmv,
921                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
922                                Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
923}
924
925void ScriptIntrinsicBLAS::DGER(double alpha, sp<Allocation> X, int incX, sp<Allocation> Y,
926                               int incY, sp<Allocation> A) {
927    int M = A->getType()->getY();
928    int N = A->getType()->getX();
929    validateGER(mRS, Element::F64(mRS), X, incX, Y, incY, A);
930    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dger,
931                                0, 0, 0, 0, 0, M, N, 0, alpha,
932                                X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0);
933}
934
935void ScriptIntrinsicBLAS::DSYR(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
936                               int incX, sp<Allocation> A) {
937    int N = validateSYR(mRS, Element::F64(mRS), Uplo, X, incX, A);
938    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr,
939                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
940                                X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0);
941}
942
943void ScriptIntrinsicBLAS::DSPR(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
944                               int incX, sp<Allocation> Ap) {
945    int N = validateSPR(mRS, Element::F64(mRS), Uplo, X, incX, Ap);
946    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr,
947                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
948                                X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0);
949}
950
951void ScriptIntrinsicBLAS::DSYR2(RsBlasUplo Uplo, double alpha, sp<Allocation> X, int incX,
952                                sp<Allocation> Y, int incY, sp<Allocation> A) {
953    int N = validateSYR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, A);
954    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2,
955                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
956                                X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0);
957}
958
959void ScriptIntrinsicBLAS::DSPR2(RsBlasUplo Uplo, double alpha, sp<Allocation> X, int incX,
960                                sp<Allocation> Y, int incY, sp<Allocation> Ap) {
961    int N = validateSPR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, Ap);
962    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr2,
963                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
964                                X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0);
965}
966
967
968/**
969 * Level 2, C and Z only
970 */
971
972static void validateGERU(RS* mRS, sp<const Element> e, sp<Allocation> X, int incX,
973                         sp<Allocation> Y, int incY, sp<Allocation> A) {
974    if (!A->getType()->getElement()->isCompatible(e) ||
975        !X->getType()->getElement()->isCompatible(e) ||
976        !Y->getType()->getElement()->isCompatible(e)) {
977        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
978    }
979    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
980        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
981    }
982
983    int M = A->getType()->getY();
984    int N = A->getType()->getX();
985    if (incX <= 0 || incY <= 0) {
986        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
987    }
988    int expectedXDim = 1 + (M - 1) * incX;
989    if ((int)X->getType()->getX() != expectedXDim) {
990        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU");
991    }
992    int expectedYDim = 1 + (N - 1) * incY;
993    if ((int)Y->getType()->getX() != expectedYDim) {
994        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU");
995    }
996
997}
998
999void ScriptIntrinsicBLAS::CHEMV(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> A,
1000                                sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
1001    // HEMV is the same as SYR2 validation-wise
1002    int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1003    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemv,
1004                                 0, 0, 0, Uplo, 0, 0, N, 0,
1005                                 alpha.x, alpha.y, A->getID(), X->getID(),
1006                                 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1007}
1008
1009void ScriptIntrinsicBLAS::CHBMV(RsBlasUplo Uplo, int K, Float2 alpha, sp<Allocation> A,
1010                                sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
1011    // HBMV is the same as SYR2 validation-wise
1012    int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1013    if (K < 0) {
1014        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV");
1015    }
1016    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chbmv,
1017                                 0, 0, 0, Uplo, 0, 0, N, K,
1018                                 alpha.x, alpha.y, A->getID(), X->getID(),
1019                                 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1020}
1021
1022void ScriptIntrinsicBLAS::CHPMV(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> Ap,
1023                                sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
1024    // HPMV is the same as SPR2
1025    int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
1026    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpmv,
1027                                 0, 0, 0, Uplo, 0, 0, N, 0,
1028                                 alpha.x, alpha.y, Ap->getID(), X->getID(),
1029                                 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1030}
1031
1032void ScriptIntrinsicBLAS::CGERU(Float2 alpha, sp<Allocation> X, int incX,
1033                                sp<Allocation> Y, int incY, sp<Allocation> A) {
1034    validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A);
1035    int M = A->getType()->getY();
1036    int N = A->getType()->getX();
1037    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgeru,
1038                                 0, 0, 0, 0, 0, M, N, 0,
1039                                 alpha.x, alpha.y, X->getID(), Y->getID(),
1040                                 0, 0, A->getID(), incX, incY, 0, 0);
1041}
1042
1043void ScriptIntrinsicBLAS::CGERC(Float2 alpha, sp<Allocation> X, int incX,
1044                                sp<Allocation> Y, int incY, sp<Allocation> A) {
1045    // Same as GERU
1046    validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A);
1047    int M = A->getType()->getY();
1048    int N = A->getType()->getX();
1049    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgerc,
1050                                 0, 0, 0, 0, 0, M, N, 0,
1051                                 alpha.x, alpha.y, X->getID(), Y->getID(),
1052                                 0, 0, A->getID(), incX, incY, 0, 0);
1053}
1054
1055void ScriptIntrinsicBLAS::CHER(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
1056                               int incX, sp<Allocation> A) {
1057    // Same as SYR
1058    int N = validateSYR(mRS, Element::F32_2(mRS), Uplo, X, incX, A);
1059    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher,
1060                                 0, 0, 0, Uplo, 0, 0, N, 0,
1061                                 alpha, 0, X->getID(), 0,
1062                                 0, 0, A->getID(), incX, 0, 0, 0);
1063}
1064
1065void ScriptIntrinsicBLAS::CHPR(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
1066                               int incX, sp<Allocation> Ap) {
1067    // Equivalent to SPR for validation
1068    int N = validateSPR(mRS, Element::F32_2(mRS), Uplo, X, incX, Ap);
1069    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr,
1070                                 0, 0, 0, Uplo, 0, 0, N, 0,
1071                                 alpha, 0, X->getID(), 0,
1072                                 0, 0, Ap->getID(), incX, 0, 0, 0);
1073}
1074
1075void ScriptIntrinsicBLAS::CHER2(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> X, int incX,
1076                                sp<Allocation> Y, int incY, sp<Allocation> A) {
1077    // Same as SYR2
1078    int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1079    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2,
1080                                 0, 0, 0, Uplo, 0, 0, N, 0,
1081                                 alpha.x, alpha.y, X->getID(), Y->getID(),
1082                                 0, 0, A->getID(), incX, incY, 0, 0);
1083}
1084
1085void ScriptIntrinsicBLAS::CHPR2(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> X, int incX,
1086                                sp<Allocation> Y, int incY, sp<Allocation> Ap) {
1087    // Same as SPR2
1088    int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
1089    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr2,
1090                                 0, 0, 0, Uplo, 0, 0, N, 0,
1091                                 alpha.x, alpha.y, X->getID(), Y->getID(),
1092                                 0, 0, Ap->getID(), incX, incY, 0, 0);
1093}
1094
1095void ScriptIntrinsicBLAS::ZHEMV(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> A,
1096                                sp<Allocation> X, int incX, Double2 beta, sp<Allocation> Y, int incY) {
1097    // HEMV is the same as SYR2 validation-wise
1098    int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1099    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemv,
1100                           0, 0, 0, Uplo, 0, 0, N, 0,
1101                           alpha.x, alpha.y, A->getID(), X->getID(),
1102                           beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1103}
1104
1105void ScriptIntrinsicBLAS::ZHBMV(RsBlasUplo Uplo, int K, Double2 alpha, sp<Allocation> A, sp<Allocation> X,
1106                                int incX, Double2 beta, sp<Allocation> Y, int incY) {
1107    // HBMV is the same as SYR2 validation-wise
1108    int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1109    if (K < 0) {
1110        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV");
1111    }
1112    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhbmv,
1113                           0, 0, 0, Uplo, 0, 0, N, K,
1114                           alpha.x, alpha.y, A->getID(), X->getID(),
1115                           beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1116}
1117
1118void ScriptIntrinsicBLAS::ZHPMV(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> Ap, sp<Allocation> X,
1119                                int incX, Double2 beta, sp<Allocation> Y, int incY) {
1120    // HPMV is the same as SPR2
1121    int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
1122    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpmv,
1123                           0, 0, 0, Uplo, 0, 0, N, 0,
1124                           alpha.x, alpha.y, Ap->getID(), X->getID(),
1125                           beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1126}
1127
1128void ScriptIntrinsicBLAS::ZGERU(Double2 alpha, sp<Allocation> X, int incX,
1129                                sp<Allocation> Y, int incY, sp<Allocation> A) {
1130    validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A);
1131    int M = A->getType()->getY();
1132    int N = A->getType()->getX();
1133    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgeru,
1134                           0, 0, 0, 0, 0, M, N, 0,
1135                           alpha.x, alpha.y, X->getID(), Y->getID(),
1136                           0, 0, A->getID(), incX, incY, 0, 0);
1137}
1138
1139void ScriptIntrinsicBLAS::ZGERC(Double2 alpha, sp<Allocation> X, int incX,
1140                                sp<Allocation> Y, int incY, sp<Allocation> A) {
1141    // Same as GERU
1142    validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A);
1143    int M = A->getType()->getY();
1144    int N = A->getType()->getX();
1145    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgerc,
1146                           0, 0, 0, 0, 0, M, N, 0,
1147                           alpha.x, alpha.y, X->getID(), Y->getID(),
1148                           0, 0, A->getID(), incX, incY, 0, 0);
1149}
1150
1151void ScriptIntrinsicBLAS::ZHER(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
1152                               int incX, sp<Allocation> A) {
1153    // Same as SYR
1154    int N = validateSYR(mRS, Element::F64_2(mRS), Uplo, X, incX, A);
1155    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher,
1156                           0, 0, 0, Uplo, 0, 0, N, 0,
1157                           alpha, 0, X->getID(), 0,
1158                           0, 0, A->getID(), incX, 0, 0, 0);
1159}
1160
1161void ScriptIntrinsicBLAS::ZHPR(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
1162                               int incX, sp<Allocation> Ap) {
1163    // Equivalent to SPR for validation
1164    int N = validateSPR(mRS, Element::F64_2(mRS), Uplo, X, incX, Ap);
1165    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr,
1166                           0, 0, 0, Uplo, 0, 0, N, 0,
1167                           alpha, 0, X->getID(), 0,
1168                           0, 0, Ap->getID(), incX, 0, 0, 0);
1169}
1170
1171void ScriptIntrinsicBLAS::ZHER2(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> X, int incX,
1172                                sp<Allocation> Y, int incY, sp<Allocation> A) {
1173    // Same as SYR2
1174    int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1175    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2,
1176                           0, 0, 0, Uplo, 0, 0, N, 0,
1177                           alpha.x, alpha.y, X->getID(), Y->getID(),
1178                           0, 0, A->getID(), incX, incY, 0, 0);
1179}
1180
1181void ScriptIntrinsicBLAS::ZHPR2(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> X, int incX,
1182                                sp<Allocation> Y, int incY, sp<Allocation> Ap) {
1183    // Same as SPR2
1184    int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
1185    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr2,
1186                           0, 0, 0, Uplo, 0, 0, N, 0,
1187                           alpha.x, alpha.y, X->getID(), Y->getID(),
1188                           0, 0, Ap->getID(), incX, incY, 0, 0);
1189}
1190
1191
1192/**
1193 * Level 3 BLAS
1194 */
1195
1196static void validateL3(RS* mRS, sp<const Element> e, int TransA, int TransB, int Side,
1197                       sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
1198    int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1;
1199    if ((A != nullptr && !A->getType()->getElement()->isCompatible(e)) ||
1200        (B != nullptr && !B->getType()->getElement()->isCompatible(e)) ||
1201        (C != nullptr && !C->getType()->getElement()->isCompatible(e))) {
1202        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1203    }
1204    if (C == nullptr) {
1205        // Since matrix C is used to store the result, it cannot be null.
1206        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Allocation C cannot be null");
1207    }
1208    cM = C->getType()->getY();
1209    cN = C->getType()->getX();
1210
1211    if (Side == RsBlasRight) {
1212        if ((A == nullptr && B != nullptr) || (A != nullptr && B == nullptr)) {
1213            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Provided Matrix A without Matrix B, or vice versa");
1214        }
1215        if (B != nullptr) {
1216            bM = A->getType()->getY();
1217            bN = A->getType()->getX();
1218        }
1219        if (A != nullptr) {
1220            aM = B->getType()->getY();
1221            aN = B->getType()->getX();
1222        }
1223    } else {
1224        if (A != nullptr) {
1225            if (TransA == RsBlasTrans || TransA == RsBlasConjTrans) {
1226                aN = A->getType()->getY();
1227                aM = A->getType()->getX();
1228            } else {
1229                aM = A->getType()->getY();
1230                aN = A->getType()->getX();
1231            }
1232        }
1233        if (B != nullptr) {
1234            if (TransB == RsBlasTrans || TransB == RsBlasConjTrans) {
1235                bN = B->getType()->getY();
1236                bM = B->getType()->getX();
1237            } else {
1238                bM = B->getType()->getY();
1239                bN = B->getType()->getX();
1240            }
1241        }
1242    }
1243    if (A != nullptr && B != nullptr && C != nullptr) {
1244        if (aN != bM || aM != cM || bN != cN) {
1245            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1246        }
1247    } else if (A != nullptr && C != nullptr) {
1248        // A and C only, for SYRK
1249        if (cM != cN) {
1250            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix C is not symmetric");
1251        }
1252        if (aM != cM) {
1253            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1254        }
1255    } else if (A != nullptr && B != nullptr) {
1256        // A and B only
1257        if (aN != bM) {
1258            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1259        }
1260    }
1261
1262}
1263
1264void ScriptIntrinsicBLAS::SGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, float alpha,
1265                                sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
1266    validateL3(mRS, Element::F32(mRS), TransA, TransB, 0, A, B, C);
1267
1268    int M = -1, N = -1, K = -1;
1269    if (TransA != RsBlasNoTrans) {
1270        M = A->getType()->getX();
1271        K = A->getType()->getY();
1272    } else {
1273        M = A->getType()->getY();
1274        K = A->getType()->getX();
1275    }
1276    if (TransB != RsBlasNoTrans) {
1277        N = B->getType()->getY();
1278    } else {
1279        N = B->getType()->getX();
1280    }
1281    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemm,
1282                                TransA, TransB, 0, 0, 0, M, N, K,
1283                                alpha, A->getID(), B->getID(),
1284                                beta, C->getID(), 0, 0, 0, 0);
1285}
1286
1287void ScriptIntrinsicBLAS::DGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, double alpha,
1288                                sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
1289    validateL3(mRS, Element::F64(mRS), TransA, TransB, 0, A, B, C);
1290    int M = -1, N = -1, K = -1;
1291    if (TransA != RsBlasNoTrans) {
1292        M = A->getType()->getX();
1293        K = A->getType()->getY();
1294    } else {
1295        M = A->getType()->getY();
1296        K = A->getType()->getX();
1297    }
1298    if (TransB != RsBlasNoTrans) {
1299        N = B->getType()->getY();
1300    } else {
1301        N = B->getType()->getX();
1302    }
1303    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemm,
1304                                TransA, TransB, 0, 0, 0, M, N, K,
1305                                alpha, A->getID(), B->getID(),
1306                                beta, C->getID(), 0, 0, 0, 0);
1307}
1308
1309void ScriptIntrinsicBLAS::CGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Float2 alpha,
1310                                sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
1311    validateL3(mRS, Element::F32_2(mRS), TransA, TransB, 0, A, B, C);
1312    int M = -1, N = -1, K = -1;
1313    if (TransA != RsBlasNoTrans) {
1314        M = A->getType()->getX();
1315        K = A->getType()->getY();
1316    } else {
1317        M = A->getType()->getY();
1318        K = A->getType()->getX();
1319    }
1320    if (TransB != RsBlasNoTrans) {
1321        N = B->getType()->getY();
1322    } else {
1323        N = B->getType()->getX();
1324    }
1325    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemm,
1326                                 TransA, TransB, 0, 0, 0, M, N, K,
1327                                 alpha.x, alpha.y, A->getID(), B->getID(),
1328                                 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1329}
1330
1331void ScriptIntrinsicBLAS::ZGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Double2 alpha,
1332                                sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
1333    validateL3(mRS, Element::F64_2(mRS), TransA, TransB, 0, A, B, C);
1334    int M = -1, N = -1, K = -1;
1335    if (TransA != RsBlasNoTrans) {
1336        M = A->getType()->getX();
1337        K = A->getType()->getY();
1338    } else {
1339        M = A->getType()->getY();
1340        K = A->getType()->getX();
1341    }
1342    if (TransB != RsBlasNoTrans) {
1343        N = B->getType()->getY();
1344    } else {
1345        N = B->getType()->getX();
1346    }
1347    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemm,
1348                           TransA, TransB, 0, 0, 0, M, N, K,
1349                           alpha.x, alpha.y, A->getID(), B->getID(),
1350                           beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1351}
1352
1353void ScriptIntrinsicBLAS::SSYMM(RsBlasSide Side, RsBlasUplo Uplo, float alpha,
1354                                sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
1355    //For SYMM, Matrix A should be symmetric
1356    if (A->getType()->getX() != A->getType()->getY()) {
1357        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1358    }
1359    validateL3(mRS, Element::F32(mRS), 0, 0, Side, A, B, C);
1360    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymm,
1361                                0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1362                                alpha, A->getID(), B->getID(),
1363                                beta, C->getID(), 0, 0, 0, 0);
1364}
1365
1366void ScriptIntrinsicBLAS::DSYMM(RsBlasSide Side, RsBlasUplo Uplo, double alpha,
1367                                sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
1368    if (A->getType()->getX() != A->getType()->getY()) {
1369        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1370    }
1371    validateL3(mRS, Element::F64(mRS), 0, 0, Side, A, B, C);
1372    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymm,
1373                                0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1374                                alpha, A->getID(), B->getID(),
1375                                beta, C->getID(), 0, 0, 0, 0);
1376}
1377
1378void ScriptIntrinsicBLAS::CSYMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha,
1379                                sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
1380    if (A->getType()->getX() != A->getType()->getY()) {
1381        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1382    }
1383    validateL3(mRS, Element::F32_2(mRS), 0, 0, Side, A, B, C);
1384    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csymm,
1385                                 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1386                                 alpha.x, alpha.y, A->getID(), B->getID(),
1387                                 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1388}
1389
1390void ScriptIntrinsicBLAS::ZSYMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha,
1391                                sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
1392    if (A->getType()->getX() != A->getType()->getY()) {
1393        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1394    }
1395    validateL3(mRS, Element::F64_2(mRS), 0, 0, Side, A, B, C);
1396    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsymm,
1397                           0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1398                           alpha.x, alpha.y, A->getID(), B->getID(),
1399                           beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1400}
1401
1402void ScriptIntrinsicBLAS::SSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1403                                sp<Allocation> A, float beta, sp<Allocation> C) {
1404    validateL3(mRS, Element::F32(mRS), Trans, 0, 0, A, nullptr, C);
1405    int K = -1;
1406    if (Trans != RsBlasNoTrans) {
1407        K = A->getType()->getY();
1408    } else {
1409        K = A->getType()->getX();
1410    }
1411    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyrk,
1412                                Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1413                                alpha, A->getID(), 0,
1414                                beta, C->getID(), 0, 0, 0, 0);
1415}
1416
1417void ScriptIntrinsicBLAS::DSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1418                                sp<Allocation> A, double beta, sp<Allocation> C) {
1419    validateL3(mRS, Element::F64(mRS), Trans, 0, 0, A, nullptr, C);
1420    int K = -1;
1421    if (Trans != RsBlasNoTrans) {
1422        K = A->getType()->getY();
1423    } else {
1424        K = A->getType()->getX();
1425    }
1426    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyrk,
1427                                Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1428                                alpha, A->getID(), 0,
1429                                beta, C->getID(), 0, 0, 0, 0);
1430}
1431
1432void ScriptIntrinsicBLAS::CSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1433                                sp<Allocation> A, Float2 beta, sp<Allocation> C) {
1434    validateL3(mRS, Element::F32_2(mRS), Trans, 0, 0, A, nullptr, C);
1435    int K = -1;
1436    if (Trans != RsBlasNoTrans) {
1437        K = A->getType()->getY();
1438    } else {
1439        K = A->getType()->getX();
1440    }
1441    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyrk,
1442                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1443                                 alpha.x, alpha.y, A->getID(), 0,
1444                                 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1445}
1446
1447void ScriptIntrinsicBLAS::ZSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1448                                sp<Allocation> A, Double2 beta, sp<Allocation> C) {
1449    validateL3(mRS, Element::F64_2(mRS), Trans, 0, 0, A, nullptr, C);
1450    int K = -1;
1451    if (Trans != RsBlasNoTrans) {
1452        K = A->getType()->getY();
1453    } else {
1454        K = A->getType()->getX();
1455    }
1456    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyrk,
1457                           Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1458                           alpha.x, alpha.y, A->getID(), 0,
1459                           beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1460}
1461
1462static void validateSYR2K(RS* mRS, sp<const Element> e, RsBlasTranspose Trans,
1463                          sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
1464    if (!A->getType()->getElement()->isCompatible(e) ||
1465        !B->getType()->getElement()->isCompatible(e) ||
1466        !C->getType()->getElement()->isCompatible(e)) {
1467        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1468    }
1469    int Cdim = -1;
1470    // A is n x k if no transpose, k x n if transpose
1471    // C is n x n
1472    if (Trans == RsBlasTrans) {
1473        // check columns versus C
1474        Cdim = A->getType()->getX();
1475    } else {
1476        // check rows versus C
1477        Cdim = A->getType()->getY();
1478    }
1479    if ((int)C->getType()->getX() != Cdim || (int)C->getType()->getY() != Cdim) {
1480        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid symmetric matrix in SYR2K");
1481    }
1482    // A dims == B dims
1483    if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) {
1484        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid A and B in SYR2K");
1485    }
1486}
1487
1488void ScriptIntrinsicBLAS::SSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1489                                 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
1490    validateSYR2K(mRS, Element::F32(mRS), Trans, A, B, C);
1491    int K = -1;
1492    if (Trans != RsBlasNoTrans) {
1493        K = A->getType()->getY();
1494    } else {
1495        K = A->getType()->getX();
1496    }
1497    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2k,
1498                                Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1499                                alpha, A->getID(), B->getID(),
1500                                beta, C->getID(), 0, 0, 0, 0);
1501}
1502
1503void ScriptIntrinsicBLAS::DSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1504                                 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
1505    validateSYR2K(mRS, Element::F64(mRS), Trans, A, B, C);
1506    int K = -1;
1507    if (Trans != RsBlasNoTrans) {
1508        K = A->getType()->getY();
1509    } else {
1510        K = A->getType()->getX();
1511    }
1512    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2k,
1513                                Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1514                                alpha, A->getID(), B->getID(),
1515                                beta, C->getID(), 0, 0, 0, 0);
1516}
1517
1518void ScriptIntrinsicBLAS::CSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1519                                 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
1520    validateSYR2K(mRS, Element::F32_2(mRS), Trans, A, B, C);
1521    int K = -1;
1522    if (Trans != RsBlasNoTrans) {
1523        K = A->getType()->getY();
1524    } else {
1525        K = A->getType()->getX();
1526    }
1527    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyr2k,
1528                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1529                                 alpha.x, alpha.y, A->getID(), B->getID(),
1530                                 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1531}
1532
1533void ScriptIntrinsicBLAS::ZSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1534                                 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
1535    validateSYR2K(mRS, Element::F64_2(mRS), Trans, A, B, C);
1536    int K = -1;
1537    if (Trans != RsBlasNoTrans) {
1538        K = A->getType()->getY();
1539    } else {
1540        K = A->getType()->getX();
1541    }
1542    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyr2k,
1543                           Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1544                           alpha.x, alpha.y, A->getID(), B->getID(),
1545                           beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1546}
1547
1548static void validateTRMM(RS* mRS, sp<const Element> e, RsBlasSide Side, RsBlasTranspose TransA,
1549                         sp<Allocation> A, sp<Allocation> B) {
1550    int aM = -1, aN = -1, bM = -1, bN = -1;
1551    if (!A->getType()->getElement()->isCompatible(e) ||
1552        !B->getType()->getElement()->isCompatible(e)) {
1553        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1554    }
1555
1556    aM = A->getType()->getY();
1557    aN = A->getType()->getX();
1558    if (aM != aN) {
1559        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with a non-symmetric matrix A");
1560    }
1561
1562    bM = B->getType()->getY();
1563    bN = B->getType()->getX();
1564    if (Side == RsBlasLeft) {
1565        if (aN != bM) {
1566            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices");
1567        }
1568    } else {
1569        if (bN != aM) {
1570            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices");
1571        }
1572    }
1573}
1574
1575void ScriptIntrinsicBLAS::STRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1576                                float alpha, sp<Allocation> A, sp<Allocation> B) {
1577    validateTRMM(mRS, Element::F32(mRS), Side, TransA, A, B);
1578    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmm,
1579                                TransA, 0, Side, Uplo, Diag,\
1580                                B->getType()->getY(), B->getType()->getX(), 0,
1581                                alpha, A->getID(), B->getID(), 0.f, 0, 0, 0, 0, 0);
1582}
1583
1584void ScriptIntrinsicBLAS::DTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1585                                double alpha, sp<Allocation> A, sp<Allocation> B) {
1586    validateTRMM(mRS, Element::F64(mRS), Side, TransA, A, B);
1587    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmm,
1588                                TransA, 0, Side, Uplo, Diag,
1589                                B->getType()->getY(), B->getType()->getX(), 0,
1590                                alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1591}
1592
1593void ScriptIntrinsicBLAS::CTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1594                                Float2 alpha, sp<Allocation> A, sp<Allocation> B) {
1595    validateTRMM(mRS, Element::F32_2(mRS), Side, TransA, A, B);
1596    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmm,
1597                                 TransA, 0, Side, Uplo, Diag,
1598                                 B->getType()->getY(), B->getType()->getX(), 0,
1599                                 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1600}
1601
1602void ScriptIntrinsicBLAS::ZTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1603                                Double2 alpha, sp<Allocation> A, sp<Allocation> B) {
1604    validateTRMM(mRS, Element::F64_2(mRS), Side, TransA, A, B);
1605    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmm,
1606                           TransA, 0, Side, Uplo, Diag,
1607                           B->getType()->getY(), B->getType()->getX(), 0,
1608                           alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1609}
1610
1611static void validateTRSM(RS* mRS, sp<const Element> e, RsBlasSide Side, RsBlasTranspose TransA,
1612                         sp<Allocation> A, sp<Allocation> B) {
1613    int adim = -1, bM = -1, bN = -1;
1614    if (!A->getType()->getElement()->isCompatible(e) ||
1615        !B->getType()->getElement()->isCompatible(e)) {
1616        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1617    }
1618    adim = A->getType()->getX();
1619    if (adim != (int)A->getType()->getY()) {
1620        // This may be unnecessary, the restriction could potentially be relaxed.
1621        // Allocation A needs to contain at least that symmetric matrix but could theoretically
1622        // be larger for now we assume adapters are sufficient, will reevaluate in the future.
1623        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with a non-symmetric matrix A");
1624    }
1625    bM = B->getType()->getY();
1626    bN = B->getType()->getX();
1627    if (Side == RsBlasLeft) {
1628        // A is M*M
1629        if (adim != bM) {
1630            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions");
1631        }
1632    } else {
1633        // A is N*N
1634        if (adim != bN) {
1635            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions");
1636        }
1637    }
1638}
1639
1640void ScriptIntrinsicBLAS::STRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1641                                float alpha, sp<Allocation> A, sp<Allocation> B) {
1642    validateTRSM(mRS, Element::F32(mRS), Side, TransA, A, B);
1643    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsm,
1644                                TransA, 0, Side, Uplo, Diag,
1645                                B->getType()->getY(), B->getType()->getX(), 0,
1646                                alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1647}
1648
1649void ScriptIntrinsicBLAS::DTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1650                                double alpha, sp<Allocation> A, sp<Allocation> B) {
1651    validateTRSM(mRS, Element::F64(mRS), Side, TransA, A, B);
1652    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsm,
1653                                TransA, 0, Side, Uplo, Diag,
1654                                B->getType()->getY(), B->getType()->getX(), 0,
1655                                alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1656}
1657
1658void ScriptIntrinsicBLAS::CTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1659                                Float2 alpha, sp<Allocation> A, sp<Allocation> B) {
1660    validateTRSM(mRS, Element::F32_2(mRS), Side, TransA, A, B);
1661    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsm,
1662                                 TransA, 0, Side, Uplo, Diag,
1663                                 B->getType()->getY(), B->getType()->getX(), 0,
1664                                 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1665}
1666
1667void ScriptIntrinsicBLAS::ZTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1668                                Double2 alpha, sp<Allocation> A, sp<Allocation> B) {
1669    validateTRSM(mRS, Element::F64_2(mRS), Side, TransA, A, B);
1670    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsm,
1671                           TransA, 0, Side, Uplo, Diag,
1672                           B->getType()->getY(), B->getType()->getX(), 0,
1673                           alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1674}
1675
1676static void validateHEMM(RS* mRS, sp<const Element> e, RsBlasSide Side,
1677                         sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
1678    if (!A->getType()->getElement()->isCompatible(e) ||
1679        !B->getType()->getElement()->isCompatible(e) ||
1680        !C->getType()->getElement()->isCompatible(e)) {
1681        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1682    }
1683
1684    // A must be square; can potentially be relaxed similar to TRSM
1685    int adim = A->getType()->getX();
1686    if (adim != (int)A->getType()->getY()) {
1687        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with non-square A");
1688    }
1689    if ((Side == RsBlasLeft && adim != (int)B->getType()->getY()) ||
1690        (Side == RsBlasRight && adim != (int)B->getType()->getX())) {
1691        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with invalid B");
1692    }
1693    if (B->getType()->getX() != C->getType()->getX() ||
1694        B->getType()->getY() != C->getType()->getY()) {
1695        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with mismatched B and C");
1696    }
1697}
1698
1699void ScriptIntrinsicBLAS::CHEMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha,
1700                                sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
1701    validateHEMM(mRS, Element::F32_2(mRS), Side, A, B, C);
1702    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemm,
1703                                 0, 0, Side, Uplo, 0,
1704                                 C->getType()->getY(), C->getType()->getX(), 0,
1705                                 alpha.x, alpha.y, A->getID(), B->getID(),
1706                                 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1707}
1708
1709void ScriptIntrinsicBLAS::ZHEMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha,
1710                                sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
1711    validateHEMM(mRS, Element::F64_2(mRS), Side, A, B, C);
1712    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemm,
1713                           0, 0, Side, Uplo, 0,
1714                           C->getType()->getY(), C->getType()->getX(), 0,
1715                           alpha.x, alpha.y, A->getID(), B->getID(),
1716                           beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1717}
1718
1719static void validateHERK(RS* mRS, sp<const Element> e, RsBlasTranspose Trans,
1720                         sp<Allocation> A, sp<Allocation> C) {
1721    if (!A->getType()->getElement()->isCompatible(e) ||
1722        !C->getType()->getElement()->isCompatible(e)) {
1723        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1724    }
1725    if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) {
1726        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose");
1727    }
1728    int cdim = C->getType()->getX();
1729    if (cdim != (int)C->getType()->getY()) {
1730        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with non-square C");
1731    }
1732    if (Trans == RsBlasNoTrans) {
1733        if (cdim != (int)A->getType()->getY()) {
1734            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A");
1735        }
1736    } else {
1737        if (cdim != (int)A->getType()->getX()) {
1738            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A");
1739        }
1740    }
1741}
1742
1743void ScriptIntrinsicBLAS::CHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1744                                sp<Allocation> A, float beta, sp<Allocation> C) {
1745    validateHERK(mRS, Element::F32_2(mRS), Trans, A, C);
1746    int k = 0;
1747    if (Trans == RsBlasConjTrans) {
1748        k = A->getType()->getY();
1749    } else {
1750        k = A->getType()->getX();
1751    }
1752    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cherk,
1753                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1754                                 alpha, 0, A->getID(), 0,
1755                                 beta, 0, C->getID(), 0, 0, 0, 0);
1756}
1757
1758void ScriptIntrinsicBLAS::ZHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1759                                sp<Allocation> A, double beta, sp<Allocation> C) {
1760    validateHERK(mRS, Element::F64_2(mRS), Trans, A, C);
1761    int k = 0;
1762    if (Trans == RsBlasConjTrans) {
1763        k = A->getType()->getY();
1764    } else {
1765        k = A->getType()->getX();
1766    }
1767    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zherk,
1768                           Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1769                           alpha, 0, A->getID(), 0,
1770                           beta, 0, C->getID(), 0, 0, 0, 0);
1771}
1772
1773static void validateHER2K(RS* mRS, sp<const Element> e, RsBlasTranspose Trans,
1774                          sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
1775    if (!A->getType()->getElement()->isCompatible(e) ||
1776        !B->getType()->getElement()->isCompatible(e) ||
1777        !C->getType()->getElement()->isCompatible(e)) {
1778        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1779    }
1780    if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) {
1781        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose");
1782    }
1783    int cdim = C->getType()->getX();
1784    if (cdim != (int)C->getType()->getY()) {
1785        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with non-square C");
1786    }
1787    if (Trans == RsBlasNoTrans) {
1788        if ((int)A->getType()->getY() != cdim) {
1789            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices");
1790        }
1791    } else {
1792        if ((int)A->getType()->getX() != cdim) {
1793            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices");
1794        }
1795    }
1796    if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) {
1797        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid A and B matrices");
1798    }
1799}
1800
1801void ScriptIntrinsicBLAS::CHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1802                                 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
1803    validateHER2K(mRS, Element::F32_2(mRS), Trans, A, B, C);
1804    int k = 0;
1805    if (Trans == RsBlasNoTrans) {
1806        k = A->getType()->getX();
1807    } else {
1808        k = A->getType()->getY();
1809    }
1810    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2k,
1811                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1812                                 alpha.x, alpha.y, A->getID(), B->getID(),
1813                                 beta, 0, C->getID(), 0, 0, 0, 0);
1814}
1815
1816void ScriptIntrinsicBLAS::ZHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1817                                 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
1818    validateHER2K(mRS, Element::F64_2(mRS), Trans, A, B, C);
1819    int k = 0;
1820    if (Trans == RsBlasNoTrans) {
1821        k = A->getType()->getX();
1822    } else {
1823        k = A->getType()->getY();
1824    }
1825    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2k,
1826                           Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1827                           alpha.x, alpha.y, A->getID(), B->getID(),
1828                           beta, 0, C->getID(), 0, 0, 0, 0);
1829}
1830
1831
1832
1833void ScriptIntrinsicBLAS::BNNM(sp<Allocation> A, int a_offset, sp<Allocation> B, int b_offset,
1834                               sp<Allocation> C, int c_offset, int c_mult) {
1835    validateL3(mRS, Element::U8(mRS), RsBlasNoTrans, RsBlasTrans, 0, A, B, C);
1836
1837    if (a_offset < 0 || a_offset > 255) {
1838        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid a_offset passed to BNNM");
1839    }
1840    if (b_offset < 0 || b_offset > 255) {
1841        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid b_offset passed to BNNM");
1842    }
1843    int M = -1, N = -1, K = -1;
1844    M = A->getType()->getY();
1845    N = B->getType()->getY();
1846    K = A->getType()->getX();
1847
1848    nScriptIntrinsicBLAS_BNNM(mRS, mRS->getContext(), getID(), M, N, K, A->getID(), a_offset,
1849                              B->getID(), b_offset, C->getID(), c_offset, c_mult);
1850}
1851