1/*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17
18#include "RenderScript.h"
19#include "rsCppInternal.h"
20
21using namespace android;
22using namespace RSC;
23
// ScriptIntrinsicBLAS APIs
25ScriptIntrinsicBLAS::ScriptIntrinsicBLAS(sp<RS> rs, sp<const Element> e)
26    : ScriptIntrinsic(rs, RS_SCRIPT_INTRINSIC_ID_BLAS, e) {
27
28}
29
30sp<ScriptIntrinsicBLAS> ScriptIntrinsicBLAS::create(sp<RS> rs) {
31    return new ScriptIntrinsicBLAS(rs, Element::U32(rs));
32}
33
// Tells setUpBLASCall which member of the RsBlasCall alpha/beta unions to
// fill: .f for SINGLE, .d for DOUBLE, .c for SINGLE_COMPLEX, .z for DOUBLE_COMPLEX.
enum RsBlasDataType {
    SINGLE = 0,          // real, float scalars
    DOUBLE = 1,          // real, double scalars
    SINGLE_COMPLEX = 2,  // complex, (float real, float imaginary)
    DOUBLE_COMPLEX = 3   // complex, (double real, double imaginary)
};
40
41static RsBlasCall
42setUpBLASCall(RsBlasDataType dataType, RsBlasFunction func,
43              int TransA, int TransB, int Side, int Uplo, int Diag,
44              int M, int N, int K, int incX, int incY, int KL, int KU,
45              float alphaF, float betaF, double alphaD, double betaD,
46              float alphaCX, float alphaCY, float betaCX, float betaCY,
47              double alphaZX, double alphaZY, double betaZX, double betaZY
48              ) {
49    RsBlasCall call;
50    memset(&call, 0, sizeof(call));
51    call.func = func;
52    call.transA = (RsBlasTranspose)TransA;
53    call.transB = (RsBlasTranspose)TransB;
54    call.side = (RsBlasSide)Side;
55    call.uplo = (RsBlasUplo)Uplo;
56    call.diag = (RsBlasDiag)Diag;
57    call.M = M;
58    call.N = N;
59    call.K = K;
60
61    switch (dataType) {
62        case SINGLE:
63            // For Single-precision BLAS.
64            call.alpha.f = alphaF;
65            call.beta.f = betaF;
66            break;
67        case DOUBLE:
68            // For Double-precision BLAS.
69            call.alpha.d = alphaD;
70            call.beta.d = betaD;
71            break;
72        case SINGLE_COMPLEX:
73            // For Single-precision complex BLAS.
74            call.alpha.c.r = alphaCX;
75            call.alpha.c.i = alphaCY;
76            call.beta.c.r = betaCX;
77            call.beta.c.i = betaCY;
78            break;
79        case DOUBLE_COMPLEX:
80            // For Double-precision complex BLAS.
81            call.alpha.z.r = alphaZX;
82            call.alpha.z.i = alphaZY;
83            call.beta.z.r = betaZX;
84            call.beta.z.i = betaZY;
85            break;
86        default:
87            break;
88    }
89
90    call.incX = incX;
91    call.incY = incY;
92    call.KL = KL;
93    call.KU = KU;
94
95    return call;
96}
97
98static void
99nScriptIntrinsicBLAS_Single(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
100                            int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
101                            float alpha, RsAllocation A, RsAllocation B,
102                            float beta, RsAllocation C, int incX, int incY, int KL, int KU) {
103    RsBlasCall call = setUpBLASCall(SINGLE, func, TransA, TransB, Side, Uplo, Diag,
104                                    M, N, K, incX, incY, KL, KU, alpha, beta, 0.0, 0.0,
105                                    0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0);
106    RsAllocation in_allocs[3] = {A, B, C};
107    tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
108                                                      &call, sizeof(call), nullptr, 0));
109}
110
111
112static void
113nScriptIntrinsicBLAS_Double(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
114                            int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
115                            double alpha, RsAllocation A, RsAllocation B,
116                            double beta, RsAllocation C, int incX, int incY, int KL, int KU) {
117    RsBlasCall call = setUpBLASCall(DOUBLE, func, TransA, TransB, Side, Uplo, Diag,
118                                    M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, alpha, beta,
119                                    0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0);
120    RsAllocation in_allocs[3] = {A, B, C};
121    tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
122                                                      &call, sizeof(call), nullptr, 0));
123}
124
125static void
126nScriptIntrinsicBLAS_Complex(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
127                             int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
128                             float alphaX, float alphaY, RsAllocation A, RsAllocation B,
129                             float betaX, float betaY, RsAllocation C, int incX, int incY, int KL, int KU) {
130    RsBlasCall call = setUpBLASCall(SINGLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag,
131                                    M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0,
132                                    alphaX, alphaY, betaX, betaY, 0.0, 0.0, 0.0, 0.0);
133    RsAllocation in_allocs[3] = {A, B, C};
134    tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
135                                                      &call, sizeof(call), nullptr, 0));
136}
137
138static void
139nScriptIntrinsicBLAS_Z(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
140                       int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
141                       double alphaX, double alphaY, RsAllocation A, RsAllocation B,
142                       double betaX, double betaY, RsAllocation C, int incX, int incY, int KL, int KU) {
143    RsBlasCall call = setUpBLASCall(DOUBLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag,
144                                    M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0,
145                                    0.0f, 0.0f, 0.0f, 0.0f, alphaX, alphaY, betaX, betaY);
146    RsAllocation in_allocs[3] = {A, B, C};
147    tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
148                                                      &call, sizeof(call), nullptr, 0));
149}
150
151
152static void
153nScriptIntrinsicBLAS_BNNM(RS* mRS, RsContext con, RsScript id, int M, int N, int K,
154                          RsAllocation A, int a_offset, RsAllocation B, int b_offset,
155                          RsAllocation C, int c_offset, int c_mult_int) {
156    RsBlasCall call;
157    memset(&call, 0, sizeof(call));
158    call.func = RsBlas_bnnm;
159    call.M = M;
160    call.N = N;
161    call.K = K;
162    call.a_offset = a_offset & 0xFF;
163    call.b_offset = b_offset & 0xFF;
164    call.c_offset = c_offset;
165    call.c_mult_int = c_mult_int;
166
167    RsAllocation in_allocs[3] = {A, B, C};
168    tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
169                                                      &call, sizeof(call), nullptr, 0));
170}
171
172/**
173 * Level 2 BLAS
174 */
175static void validateGEMV(RS* mRS, sp<const Element> e, RsBlasTranspose TransA, sp<Allocation> A,
176                         sp<Allocation> X, int incX, sp<Allocation> Y, int incY) {
177    int M = A->getType()->getY();
178    int N = A->getType()->getX();
179    if (!A->getType()->getElement()->isCompatible(e) ||
180        !X->getType()->getElement()->isCompatible(e) ||
181        !Y->getType()->getElement()->isCompatible(e)) {
182        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
183    }
184    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
185        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
186    }
187
188    if (incX <= 0 || incY <= 0) {
189        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
190    }
191    int expectedXDim = -1, expectedYDim = -1;
192    if (TransA == RsBlasNoTrans) {
193        expectedXDim = 1 + (N - 1) * incX;
194        expectedYDim = 1 + (M - 1) * incY;
195    } else {
196        expectedXDim = 1 + (M - 1) * incX;
197        expectedYDim = 1 + (N - 1) * incY;
198    }
199    if ((int)X->getType()->getX() != expectedXDim ||
200        (int)Y->getType()->getX() != expectedYDim) {
201        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GEMV");
202    }
203}
204
205void ScriptIntrinsicBLAS::SGEMV(RsBlasTranspose TransA, float alpha, sp<Allocation> A, sp<Allocation> X,
206                                int incX, float beta, sp<Allocation> Y, int incY) {
207    validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY);
208    int M = A->getType()->getY();
209    int N = A->getType()->getX();
210    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemv,
211                                TransA, 0, 0, 0, 0, M, N, 0,
212                                alpha, A->getID(), X->getID(),
213                                beta, Y->getID(), incX, incY, 0, 0);
214}
215
216void ScriptIntrinsicBLAS::DGEMV(RsBlasTranspose TransA, double alpha, sp<Allocation> A, sp<Allocation> X,
217                                int incX, double beta, sp<Allocation> Y, int incY) {
218    validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY);
219    int M = A->getType()->getY();
220    int N = A->getType()->getX();
221    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemv,
222                                TransA, 0, 0, 0, 0, M, N, 0,
223                                alpha, A->getID(), X->getID(),
224                                beta, Y->getID(), incX, incY, 0, 0);
225}
226
227void ScriptIntrinsicBLAS::CGEMV(RsBlasTranspose TransA, Float2 alpha, sp<Allocation> A, sp<Allocation> X,
228                                int incX, Float2 beta, sp<Allocation> Y, int incY) {
229    validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY);
230    int M = A->getType()->getY();
231    int N = A->getType()->getX();
232    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemv,
233                                 TransA, 0, 0, 0, 0, M, N, 0,
234                                 alpha.x, alpha.y, A->getID(), X->getID(),
235                                 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
236}
237
238void ScriptIntrinsicBLAS::ZGEMV(RsBlasTranspose TransA, Double2 alpha, sp<Allocation> A, sp<Allocation> X,
239                                int incX, Double2 beta, sp<Allocation> Y, int incY) {
240    validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY);
241    int M = A->getType()->getY();
242    int N = A->getType()->getX();
243    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemv,
244                           TransA, 0, 0, 0, 0, M, N, 0,
245                           alpha.x, alpha.y, A->getID(), X->getID(),
246                           beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
247}
248
249void ScriptIntrinsicBLAS::SGBMV(RsBlasTranspose TransA, int KL, int KU, float alpha, sp<Allocation> A,
250                                sp<Allocation> X, int incX, float beta, sp<Allocation> Y, int incY) {
251    // GBMV has the same validation requirements as GEMV + KL and KU >= 0
252    validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY);
253    if (KL < 0 || KU < 0) {
254        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
255    }
256    int M = A->getType()->getY();
257    int N = A->getType()->getX();
258
259    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgbmv,
260                                TransA, 0, 0, 0, 0, M, N, 0,
261                                alpha, A->getID(), X->getID(),
262                                beta, Y->getID(), incX, incY, KL, KU);
263}
264
265void ScriptIntrinsicBLAS::DGBMV(RsBlasTranspose TransA, int KL, int KU, double alpha, sp<Allocation> A,
266                                sp<Allocation> X, int incX, double beta, sp<Allocation> Y, int incY) {
267    // GBMV has the same validation requirements as GEMV + KL and KU >= 0
268    validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY);
269    if (KL < 0 || KU < 0) {
270        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
271    }
272    int M = A->getType()->getY();
273    int N = A->getType()->getX();
274
275    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgbmv,
276                                TransA, 0, 0, 0, 0, M, N, 0,
277                                alpha, A->getID(), X->getID(),
278                                beta, Y->getID(), incX, incY, KL, KU);
279}
280
281void ScriptIntrinsicBLAS::CGBMV(RsBlasTranspose TransA, int KL, int KU, Float2 alpha, sp<Allocation> A,
282                                sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
283    // GBMV has the same validation requirements as GEMV + KL and KU >= 0
284    validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY);
285    if (KL < 0 || KU < 0) {
286        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
287    }
288    int M = A->getType()->getY();
289    int N = A->getType()->getX();
290
291    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgbmv,
292                                 TransA, 0, 0, 0, 0, M, N, 0,
293                                 alpha.x, alpha.y, A->getID(), X->getID(),
294                                 beta.x, beta.y, Y->getID(), incX, incY, KL, KU);
295}
296
297void ScriptIntrinsicBLAS::ZGBMV(RsBlasTranspose TransA, int KL, int KU, Double2 alpha, sp<Allocation> A,
298                                sp<Allocation> X, int incX, Double2 beta, sp<Allocation> Y, int incY) {
299    // GBMV has the same validation requirements as GEMV + KL and KU >= 0
300    validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY);
301    if (KL < 0 || KU < 0) {
302        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
303    }
304    int M = A->getType()->getY();
305    int N = A->getType()->getX();
306
307    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgbmv,
308                           TransA, 0, 0, 0, 0, M, N, 0,
309                           alpha.x, alpha.y, A->getID(), X->getID(),
310                           beta.x, beta.y, Y->getID(), incX, incY, KL, KU);
311}
312
313static void validateTRMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, RsBlasTranspose TransA,
314                         RsBlasDiag Diag, sp<Allocation> A, sp<Allocation> X, int incX) {
315    int N = A->getType()->getY();
316    if ((int)A->getType()->getX() != N) {
317        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for TRMV");
318    }
319    if (!A->getType()->getElement()->isCompatible(e) ||
320        !X->getType()->getElement()->isCompatible(e)) {
321        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
322    }
323    if (X->getType()->getY() > 1) {
324        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
325    }
326
327    if (incX <= 0) {
328        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
329    }
330    int expectedXDim = 1 + (N - 1) * incX;
331    if ((int)X->getType()->getX() != expectedXDim) {
332        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TRMV");
333    }
334}
335
336static int validateTPMV(RS* mRS, sp<const Element> e,  RsBlasUplo Uplo, RsBlasTranspose TransA,
337                        RsBlasDiag Diag, sp<Allocation> Ap, sp<Allocation> X, int incX) {
338    if (!Ap->getType()->getElement()->isCompatible(e) ||
339        !X->getType()->getElement()->isCompatible(e)) {
340        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
341    }
342    if (X->getType()->getY() > 1) {
343        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
344    }
345
346    if (Ap->getType()->getY() > 1) {
347        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
348    }
349
350    int N = sqrt((double)Ap->getType()->getX() * 2);
351    if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
352        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
353    }
354    if (incX <= 0) {
355        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
356    }
357    int expectedXDim = 1 + (N - 1) * incX;
358    if ((int)X->getType()->getX() != expectedXDim) {
359        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TPMV");
360    }
361
362    return N;
363}
364
365
366void ScriptIntrinsicBLAS::STRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
367                                sp<Allocation> A, sp<Allocation> X, int incX) {
368    validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
369    int N = A->getType()->getY();
370    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmv,
371                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
372                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
373}
374
375void ScriptIntrinsicBLAS::DTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
376                                sp<Allocation> A, sp<Allocation> X, int incX) {
377    validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
378    int N = A->getType()->getY();
379    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmv,
380                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
381                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
382}
383
384void ScriptIntrinsicBLAS::CTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
385                                sp<Allocation> A, sp<Allocation> X, int incX) {
386    validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
387    int N = A->getType()->getY();
388    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmv,
389                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
390                                 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
391}
392
393void ScriptIntrinsicBLAS::ZTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
394                                sp<Allocation> A, sp<Allocation> X, int incX) {
395    validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
396    int N = A->getType()->getY();
397    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmv,
398                           TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
399                           A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
400}
401
402void ScriptIntrinsicBLAS::STBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
403                                int K, sp<Allocation> A, sp<Allocation> X, int incX) {
404    // TBMV has the same requirements as TRMV + K >= 0
405    if (K < 0) {
406        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
407    }
408    validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
409    int N = A->getType()->getY();
410    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbmv,
411                                TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
412                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
413}
414
415void ScriptIntrinsicBLAS::DTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
416                                int K, sp<Allocation> A, sp<Allocation> X, int incX) {
417    // TBMV has the same requirements as TRMV + K >= 0
418    if (K < 0) {
419        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
420    }
421    validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
422    int N = A->getType()->getY();
423    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbmv,
424                                TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
425                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
426}
427
428void ScriptIntrinsicBLAS::CTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
429                                int K, sp<Allocation> A, sp<Allocation> X, int incX) {
430    // TBMV has the same requirements as TRMV + K >= 0
431    if (K < 0) {
432        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
433    }
434    validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
435    int N = A->getType()->getY();
436    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbmv,
437                                 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
438                                 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
439}
440
441void ScriptIntrinsicBLAS::ZTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
442                                int K, sp<Allocation> A, sp<Allocation> X, int incX) {
443    // TBMV has the same requirements as TRMV + K >= 0
444    if (K < 0) {
445        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
446    }
447    validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
448    int N = A->getType()->getY();
449    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbmv,
450                           TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
451                           A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
452}
453
454void ScriptIntrinsicBLAS::STPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
455                                sp<Allocation> Ap, sp<Allocation> X, int incX) {
456    int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
457    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpmv,
458                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
459                                Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
460}
461
462void ScriptIntrinsicBLAS::DTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
463                                sp<Allocation> Ap, sp<Allocation> X, int incX) {
464    int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
465    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpmv,
466                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
467                                Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
468}
469
470void ScriptIntrinsicBLAS::CTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
471                                sp<Allocation> Ap,  sp<Allocation> X,  int incX) {
472    int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
473    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpmv,
474                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
475                                 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
476}
477
478void ScriptIntrinsicBLAS::ZTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
479                                sp<Allocation> Ap, sp<Allocation> X, int incX) {
480    int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
481    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpmv,
482                           TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
483                           Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
484}
485
486void ScriptIntrinsicBLAS::STRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
487                                sp<Allocation> A, sp<Allocation> X, int incX) {
488    // TRSV is the same as TRMV
489    validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
490    int N = A->getType()->getY();
491    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsv,
492                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
493                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
494}
495
496void ScriptIntrinsicBLAS::DTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
497                                sp<Allocation> A,  sp<Allocation> X,  int incX) {
498    // TRSV is the same as TRMV
499    validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
500    int N = A->getType()->getY();
501    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsv,
502                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
503                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
504
505}
506
507void ScriptIntrinsicBLAS::CTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
508                                sp<Allocation> A, sp<Allocation> X, int incX) {
509    // TRSV is the same as TRMV
510    validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
511    int N = A->getType()->getY();
512    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsv,
513                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
514                                 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
515
516}
517
518void ScriptIntrinsicBLAS::ZTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
519                                sp<Allocation> A, sp<Allocation> X, int incX) {
520    // TRSV is the same as TRMV
521    validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
522    int N = A->getType()->getY();
523    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsv,
524                           TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
525                           A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
526
527}
528
529void ScriptIntrinsicBLAS::STBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
530                                int K, sp<Allocation> A, sp<Allocation> X, int incX) {
531    // TBSV is the same as TRMV + K >= 0
532    validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
533    int N = A->getType()->getY();
534    if (K < 0) {
535        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
536    }
537    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbsv,
538                                TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
539                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
540}
541
542void ScriptIntrinsicBLAS::DTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
543                                int K, sp<Allocation> A, sp<Allocation> X, int incX) {
544    // TBSV is the same as TRMV + K >= 0
545    validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
546    int N = A->getType()->getY();
547    if (K < 0) {
548        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
549    }
550    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbsv,
551                                TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
552                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
553}
554
555void ScriptIntrinsicBLAS::CTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
556                                int K, sp<Allocation> A, sp<Allocation> X, int incX) {
557    // TBSV is the same as TRMV + K >= 0
558    validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
559    int N = A->getType()->getY();
560    if (K < 0) {
561        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
562    }
563    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbsv,
564                                 TransA, 0, 0, Uplo, Diag, 0, N, K,
565                                 0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
566}
567
568void ScriptIntrinsicBLAS::ZTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
569                                int K, sp<Allocation> A, sp<Allocation> X, int incX) {
570    // TBSV is the same as TRMV + K >= 0
571    validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
572    int N = A->getType()->getY();
573    if (K < 0) {
574        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
575    }
576    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbsv,
577                           TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
578                           A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
579}
580
581void ScriptIntrinsicBLAS::STPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
582                                sp<Allocation> Ap, sp<Allocation> X, int incX) {
583    // TPSV is same as TPMV
584    int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
585    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpsv,
586                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
587                                Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
588}
589
590void ScriptIntrinsicBLAS::DTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
591                                sp<Allocation> Ap, sp<Allocation> X, int incX) {
592    // TPSV is same as TPMV
593    int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
594    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpsv,
595                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
596                                Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
597}
598
599void ScriptIntrinsicBLAS::CTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
600                                sp<Allocation> Ap, sp<Allocation> X, int incX) {
601    // TPSV is same as TPMV
602    int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
603    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpsv,
604                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
605                                 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
606}
607
608void ScriptIntrinsicBLAS::ZTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
609                                sp<Allocation> Ap, sp<Allocation> X, int incX) {
610    // TPSV is same as TPMV
611    int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
612    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpsv,
613                           TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
614                           Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
615}
616
617/**
618 * Level 2, S and D only
619 */
620static int validateSYMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> A,
621                        sp<Allocation> X, sp<Allocation> Y, int incX, int incY) {
622    int N = A->getType()->getY();
623    if ((int)A->getType()->getX() != N) {
624        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for SYMV");
625    }
626    if (!A->getType()->getElement()->isCompatible(e) ||
627        !X->getType()->getElement()->isCompatible(e) ||
628        !Y->getType()->getElement()->isCompatible(e) ) {
629        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
630    }
631    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
632        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
633    }
634
635    if (incX <= 0 || incY <= 0) {
636        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
637    }
638    int expectedXDim = 1 + (N - 1) * incX;
639    if ((int)X->getType()->getX() != expectedXDim) {
640        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV");
641    }
642    int expectedYDim = 1 + (N - 1) * incY;
643    if ((int)Y->getType()->getX() != expectedYDim) {
644        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV");
645    }
646    return N;
647}
648static int validateSPMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> Ap,
649                        sp<Allocation> X, int incX, sp<Allocation> Y, int incY) {
650    if (!Ap->getType()->getElement()->isCompatible(e) ||
651        !X->getType()->getElement()->isCompatible(e) ||
652        !Y->getType()->getElement()->isCompatible(e)) {
653        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
654    }
655    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
656        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
657    }
658
659    if (Ap->getType()->getY() > 1) {
660        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
661    }
662
663    int N = sqrt((double)Ap->getType()->getX() * 2);
664    if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
665        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
666    }
667    if (incX <= 0 || incY <= 0) {
668        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
669    }
670    int expectedXDim = 1 + (N - 1) * incX;
671    if ((int)X->getType()->getX() != expectedXDim) {
672        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV");
673    }
674    int expectedYDim = 1 + (N - 1) * incY;
675    if ((int)Y->getType()->getX() != expectedYDim) {
676        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV");
677    }
678
679    return N;
680}
681static void validateGER(RS* mRS, sp<const Element> e, sp<Allocation> X, int incX,
682                        sp<Allocation> Y, int incY, sp<Allocation> A) {
683    if (!A->getType()->getElement()->isCompatible(e) ||
684        !X->getType()->getElement()->isCompatible(e) ||
685        !Y->getType()->getElement()->isCompatible(e) ) {
686        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
687    }
688
689    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
690        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
691    }
692
693    int M = A->getType()->getY();
694    int N = A->getType()->getX();
695
696    if (N < 1 || M < 1) {
697        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "M and N must be 1 or greater for GER");
698    }
699    if (incX <= 0 || incY <= 0) {
700        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
701    }
702    int expectedXDim = 1 + (M - 1) * incX;
703    if ((int)X->getType()->getX() != expectedXDim) {
704        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER");
705    }
706    int expectedYDim = 1 + (N - 1) * incY;
707    if ((int)Y->getType()->getX() != expectedYDim) {
708        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER");
709    }
710
711
712}
713static int validateSYR(RS* mRS, sp<const Element> e, RsBlasUplo Uplo,
714                       sp<Allocation> X, int incX, sp<Allocation> A) {
715    if (!A->getType()->getElement()->isCompatible(e) ||
716        !X->getType()->getElement()->isCompatible(e)) {
717        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
718    }
719
720    int N = A->getType()->getX();
721
722    if (X->getType()->getY() > 1) {
723        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
724    }
725    if (N != (int)A->getType()->getY()) {
726        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix");
727    }
728    if (incX <= 0) {
729        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
730    }
731    int expectedXDim = 1 + (N - 1) * incX;
732    if ((int)X->getType()->getX() != expectedXDim) {
733        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR");
734    }
735    return N;
736}
737static int validateSPR(RS* mRS, sp<const Element> e, RsBlasUplo Uplo,
738                       sp<Allocation> X, int incX, sp<Allocation> Ap) {
739    if (!Ap->getType()->getElement()->isCompatible(e) ||
740        !X->getType()->getElement()->isCompatible(e)) {
741        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
742    }
743    if (X->getType()->getY() > 1) {
744        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
745    }
746
747    if (Ap->getType()->getY() > 1) {
748        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
749    }
750
751    int N = sqrt((double)Ap->getType()->getX() * 2);
752    if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
753        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
754    }
755    if (incX <= 0) {
756        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
757    }
758    int expectedXDim = 1 + (N - 1) * incX;
759    if ((int)X->getType()->getX() != expectedXDim) {
760        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR");
761    }
762
763    return N;
764}
765
766static int validateSYR2(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> X,
767                        int incX, sp<Allocation> Y, int incY, sp<Allocation> A) {
768    if (!A->getType()->getElement()->isCompatible(e) ||
769        !X->getType()->getElement()->isCompatible(e) ||
770        !Y->getType()->getElement()->isCompatible(e)) {
771        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
772    }
773
774    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
775        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
776    }
777
778    int N = A->getType()->getX();
779
780    if (N != (int)A->getType()->getY()) {
781        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix");
782    }
783    if (incX <= 0 || incY <= 0) {
784        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
785    }
786    int expectedXDim = 1 + (N - 1) * incX;
787    int expectedYDim = 1 + (N - 1) * incY;
788    if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) {
789        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR");
790    }
791    return N;
792
793}
794static int validateSPR2(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> X,
795                        int incX, sp<Allocation> Y, int incY, sp<Allocation> Ap) {
796    if (!Ap->getType()->getElement()->isCompatible(e) ||
797        !X->getType()->getElement()->isCompatible(e) ||
798        !Y->getType()->getElement()->isCompatible(e)) {
799        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
800    }
801    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
802        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
803    }
804
805    if (Ap->getType()->getY() > 1) {
806        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
807    }
808
809    int N = sqrt((double)Ap->getType()->getX() * 2);
810    if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
811        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
812    }
813    if (incX <= 0 || incY <= 0) {
814        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
815    }
816    int expectedXDim = 1 + (N - 1) * incX;
817    int expectedYDim = 1 + (N - 1) * incY;
818    if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) {
819        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR2");
820    }
821
822    return N;
823}
824
825void ScriptIntrinsicBLAS::SSYMV(RsBlasUplo Uplo, float alpha, sp<Allocation> A, sp<Allocation> X,
826                                int incX, float beta, sp<Allocation> Y, int incY) {
827    int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY);
828    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymv,
829                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
830                                A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
831}
832
833void ScriptIntrinsicBLAS::SSBMV(RsBlasUplo Uplo, int K, float alpha, sp<Allocation> A, sp<Allocation> X,
834                                int incX, float beta, sp<Allocation> Y, int incY) {
835    // SBMV is the same as SYMV + K >= 0
836    if (K < 0) {
837        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
838    }
839    int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY);
840    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssbmv,
841                                0, 0, 0, Uplo, 0, 0, N, K, alpha,
842                                A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
843}
844
845void ScriptIntrinsicBLAS::SSPMV(RsBlasUplo Uplo, float alpha, sp<Allocation> Ap, sp<Allocation> X,
846                                int incX, float beta, sp<Allocation> Y, int incY) {
847    int N = validateSPMV(mRS, Element::F32(mRS), Uplo, Ap, X, incX, Y, incY);
848    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspmv,
849                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
850                                Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
851}
852
853void ScriptIntrinsicBLAS::SGER(float alpha, sp<Allocation> X, int incX,
854                               sp<Allocation> Y, int incY, sp<Allocation> A) {
855    int M = A->getType()->getY();
856    int N = A->getType()->getX();
857    validateGER(mRS, Element::F32(mRS), X, incX, Y, incY, A);
858    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sger,
859                                0, 0, 0, 0, 0, M, N, 0, alpha,
860                                X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0);
861}
862
863void ScriptIntrinsicBLAS::SSYR(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
864                               int incX, sp<Allocation> A) {
865    int N = validateSYR(mRS, Element::F32(mRS), Uplo, X, incX, A);
866    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr,
867                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
868                                X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0);
869}
870
871void ScriptIntrinsicBLAS::SSPR(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
872                               int incX, sp<Allocation> Ap) {
873    int N = validateSPR(mRS, Element::F32(mRS), Uplo, X, incX, Ap);
874    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr,
875                                0, 0, 0, Uplo, 0, 0, N, 0,
876                                alpha, X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0);
877}
878
879void ScriptIntrinsicBLAS::SSYR2(RsBlasUplo Uplo, float alpha, sp<Allocation> X, int incX,
880                                sp<Allocation> Y, int incY, sp<Allocation> A) {
881    int N = validateSYR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, A);
882    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2,
883                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
884                                X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0);
885}
886
887void ScriptIntrinsicBLAS::SSPR2(RsBlasUplo Uplo, float alpha, sp<Allocation> X, int incX,
888                                sp<Allocation> Y, int incY, sp<Allocation> Ap) {
889    int N = validateSPR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, Ap);
890    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr2,
891                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
892                                X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0);
893}
894
895void ScriptIntrinsicBLAS::DSYMV(RsBlasUplo Uplo, double alpha, sp<Allocation> A, sp<Allocation> X,
896                                int incX, double beta, sp<Allocation> Y, int incY) {
897    int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY);
898    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymv,
899                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
900                                A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
901}
902
903void ScriptIntrinsicBLAS::DSBMV(RsBlasUplo Uplo, int K, double alpha, sp<Allocation> A, sp<Allocation> X,
904                                int incX, double beta, sp<Allocation> Y, int incY) {
905    // SBMV is the same as SYMV + K >= 0
906    if (K < 0) {
907        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
908    }
909    int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY);
910    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsbmv,
911                                0, 0, 0, Uplo, 0, 0, N, K, alpha,
912                                A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
913}
914
915void ScriptIntrinsicBLAS::DSPMV(RsBlasUplo Uplo, double alpha, sp<Allocation> Ap, sp<Allocation> X,
916                                int incX, double beta, sp<Allocation> Y, int incY) {
917    int N = validateSPMV(mRS, Element::F64(mRS), Uplo, Ap, X, incX, Y, incY);
918    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspmv,
919                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
920                                Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
921}
922
923void ScriptIntrinsicBLAS::DGER(double alpha, sp<Allocation> X, int incX, sp<Allocation> Y,
924                               int incY, sp<Allocation> A) {
925    int M = A->getType()->getY();
926    int N = A->getType()->getX();
927    validateGER(mRS, Element::F64(mRS), X, incX, Y, incY, A);
928    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dger,
929                                0, 0, 0, 0, 0, M, N, 0, alpha,
930                                X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0);
931}
932
933void ScriptIntrinsicBLAS::DSYR(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
934                               int incX, sp<Allocation> A) {
935    int N = validateSYR(mRS, Element::F64(mRS), Uplo, X, incX, A);
936    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr,
937                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
938                                X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0);
939}
940
941void ScriptIntrinsicBLAS::DSPR(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
942                               int incX, sp<Allocation> Ap) {
943    int N = validateSPR(mRS, Element::F64(mRS), Uplo, X, incX, Ap);
944    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr,
945                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
946                                X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0);
947}
948
949void ScriptIntrinsicBLAS::DSYR2(RsBlasUplo Uplo, double alpha, sp<Allocation> X, int incX,
950                                sp<Allocation> Y, int incY, sp<Allocation> A) {
951    int N = validateSYR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, A);
952    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2,
953                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
954                                X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0);
955}
956
957void ScriptIntrinsicBLAS::DSPR2(RsBlasUplo Uplo, double alpha, sp<Allocation> X, int incX,
958                                sp<Allocation> Y, int incY, sp<Allocation> Ap) {
959    int N = validateSPR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, Ap);
960    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr2,
961                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
962                                X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0);
963}
964
965
966/**
967 * Level 2, C and Z only
968 */
969
970static void validateGERU(RS* mRS, sp<const Element> e, sp<Allocation> X, int incX,
971                         sp<Allocation> Y, int incY, sp<Allocation> A) {
972    if (!A->getType()->getElement()->isCompatible(e) ||
973        !X->getType()->getElement()->isCompatible(e) ||
974        !Y->getType()->getElement()->isCompatible(e)) {
975        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
976    }
977    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
978        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
979    }
980
981    int M = A->getType()->getY();
982    int N = A->getType()->getX();
983    if (incX <= 0 || incY <= 0) {
984        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
985    }
986    int expectedXDim = 1 + (M - 1) * incX;
987    if ((int)X->getType()->getX() != expectedXDim) {
988        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU");
989    }
990    int expectedYDim = 1 + (N - 1) * incY;
991    if ((int)Y->getType()->getX() != expectedYDim) {
992        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU");
993    }
994
995}
996
997void ScriptIntrinsicBLAS::CHEMV(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> A,
998                                sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
999    // HEMV is the same as SYR2 validation-wise
1000    int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1001    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemv,
1002                                 0, 0, 0, Uplo, 0, 0, N, 0,
1003                                 alpha.x, alpha.y, A->getID(), X->getID(),
1004                                 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1005}
1006
1007void ScriptIntrinsicBLAS::CHBMV(RsBlasUplo Uplo, int K, Float2 alpha, sp<Allocation> A,
1008                                sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
1009    // HBMV is the same as SYR2 validation-wise
1010    int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1011    if (K < 0) {
1012        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV");
1013    }
1014    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chbmv,
1015                                 0, 0, 0, Uplo, 0, 0, N, K,
1016                                 alpha.x, alpha.y, A->getID(), X->getID(),
1017                                 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1018}
1019
1020void ScriptIntrinsicBLAS::CHPMV(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> Ap,
1021                                sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
1022    // HPMV is the same as SPR2
1023    int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
1024    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpmv,
1025                                 0, 0, 0, Uplo, 0, 0, N, 0,
1026                                 alpha.x, alpha.y, Ap->getID(), X->getID(),
1027                                 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1028}
1029
1030void ScriptIntrinsicBLAS::CGERU(Float2 alpha, sp<Allocation> X, int incX,
1031                                sp<Allocation> Y, int incY, sp<Allocation> A) {
1032    validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A);
1033    int M = A->getType()->getY();
1034    int N = A->getType()->getX();
1035    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgeru,
1036                                 0, 0, 0, 0, 0, M, N, 0,
1037                                 alpha.x, alpha.y, X->getID(), Y->getID(),
1038                                 0, 0, A->getID(), incX, incY, 0, 0);
1039}
1040
1041void ScriptIntrinsicBLAS::CGERC(Float2 alpha, sp<Allocation> X, int incX,
1042                                sp<Allocation> Y, int incY, sp<Allocation> A) {
1043    // Same as GERU
1044    validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A);
1045    int M = A->getType()->getY();
1046    int N = A->getType()->getX();
1047    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgerc,
1048                                 0, 0, 0, 0, 0, M, N, 0,
1049                                 alpha.x, alpha.y, X->getID(), Y->getID(),
1050                                 0, 0, A->getID(), incX, incY, 0, 0);
1051}
1052
1053void ScriptIntrinsicBLAS::CHER(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
1054                               int incX, sp<Allocation> A) {
1055    // Same as SYR
1056    int N = validateSYR(mRS, Element::F32_2(mRS), Uplo, X, incX, A);
1057    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher,
1058                                 0, 0, 0, Uplo, 0, 0, N, 0,
1059                                 alpha, 0, X->getID(), 0,
1060                                 0, 0, A->getID(), incX, 0, 0, 0);
1061}
1062
1063void ScriptIntrinsicBLAS::CHPR(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
1064                               int incX, sp<Allocation> Ap) {
1065    // Equivalent to SPR for validation
1066    int N = validateSPR(mRS, Element::F32_2(mRS), Uplo, X, incX, Ap);
1067    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr,
1068                                 0, 0, 0, Uplo, 0, 0, N, 0,
1069                                 alpha, 0, X->getID(), 0,
1070                                 0, 0, Ap->getID(), incX, 0, 0, 0);
1071}
1072
1073void ScriptIntrinsicBLAS::CHER2(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> X, int incX,
1074                                sp<Allocation> Y, int incY, sp<Allocation> A) {
1075    // Same as SYR2
1076    int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1077    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2,
1078                                 0, 0, 0, Uplo, 0, 0, N, 0,
1079                                 alpha.x, alpha.y, X->getID(), Y->getID(),
1080                                 0, 0, A->getID(), incX, incY, 0, 0);
1081}
1082
1083void ScriptIntrinsicBLAS::CHPR2(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> X, int incX,
1084                                sp<Allocation> Y, int incY, sp<Allocation> Ap) {
1085    // Same as SPR2
1086    int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
1087    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr2,
1088                                 0, 0, 0, Uplo, 0, 0, N, 0,
1089                                 alpha.x, alpha.y, X->getID(), Y->getID(),
1090                                 0, 0, Ap->getID(), incX, incY, 0, 0);
1091}
1092
1093void ScriptIntrinsicBLAS::ZHEMV(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> A,
1094                                sp<Allocation> X, int incX, Double2 beta, sp<Allocation> Y, int incY) {
1095    // HEMV is the same as SYR2 validation-wise
1096    int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1097    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemv,
1098                           0, 0, 0, Uplo, 0, 0, N, 0,
1099                           alpha.x, alpha.y, A->getID(), X->getID(),
1100                           beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1101}
1102
1103void ScriptIntrinsicBLAS::ZHBMV(RsBlasUplo Uplo, int K, Double2 alpha, sp<Allocation> A, sp<Allocation> X,
1104                                int incX, Double2 beta, sp<Allocation> Y, int incY) {
1105    // HBMV is the same as SYR2 validation-wise
1106    int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1107    if (K < 0) {
1108        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV");
1109    }
1110    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhbmv,
1111                           0, 0, 0, Uplo, 0, 0, N, K,
1112                           alpha.x, alpha.y, A->getID(), X->getID(),
1113                           beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1114}
1115
1116void ScriptIntrinsicBLAS::ZHPMV(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> Ap, sp<Allocation> X,
1117                                int incX, Double2 beta, sp<Allocation> Y, int incY) {
1118    // HPMV is the same as SPR2
1119    int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
1120    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpmv,
1121                           0, 0, 0, Uplo, 0, 0, N, 0,
1122                           alpha.x, alpha.y, Ap->getID(), X->getID(),
1123                           beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1124}
1125
1126void ScriptIntrinsicBLAS::ZGERU(Double2 alpha, sp<Allocation> X, int incX,
1127                                sp<Allocation> Y, int incY, sp<Allocation> A) {
1128    validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A);
1129    int M = A->getType()->getY();
1130    int N = A->getType()->getX();
1131    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgeru,
1132                           0, 0, 0, 0, 0, M, N, 0,
1133                           alpha.x, alpha.y, X->getID(), Y->getID(),
1134                           0, 0, A->getID(), incX, incY, 0, 0);
1135}
1136
1137void ScriptIntrinsicBLAS::ZGERC(Double2 alpha, sp<Allocation> X, int incX,
1138                                sp<Allocation> Y, int incY, sp<Allocation> A) {
1139    // Same as GERU
1140    validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A);
1141    int M = A->getType()->getY();
1142    int N = A->getType()->getX();
1143    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgerc,
1144                           0, 0, 0, 0, 0, M, N, 0,
1145                           alpha.x, alpha.y, X->getID(), Y->getID(),
1146                           0, 0, A->getID(), incX, incY, 0, 0);
1147}
1148
1149void ScriptIntrinsicBLAS::ZHER(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
1150                               int incX, sp<Allocation> A) {
1151    // Same as SYR
1152    int N = validateSYR(mRS, Element::F64_2(mRS), Uplo, X, incX, A);
1153    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher,
1154                           0, 0, 0, Uplo, 0, 0, N, 0,
1155                           alpha, 0, X->getID(), 0,
1156                           0, 0, A->getID(), incX, 0, 0, 0);
1157}
1158
1159void ScriptIntrinsicBLAS::ZHPR(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
1160                               int incX, sp<Allocation> Ap) {
1161    // Equivalent to SPR for validation
1162    int N = validateSPR(mRS, Element::F64_2(mRS), Uplo, X, incX, Ap);
1163    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr,
1164                           0, 0, 0, Uplo, 0, 0, N, 0,
1165                           alpha, 0, X->getID(), 0,
1166                           0, 0, Ap->getID(), incX, 0, 0, 0);
1167}
1168
1169void ScriptIntrinsicBLAS::ZHER2(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> X, int incX,
1170                                sp<Allocation> Y, int incY, sp<Allocation> A) {
1171    // Same as SYR2
1172    int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1173    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2,
1174                           0, 0, 0, Uplo, 0, 0, N, 0,
1175                           alpha.x, alpha.y, X->getID(), Y->getID(),
1176                           0, 0, A->getID(), incX, incY, 0, 0);
1177}
1178
1179void ScriptIntrinsicBLAS::ZHPR2(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> X, int incX,
1180                                sp<Allocation> Y, int incY, sp<Allocation> Ap) {
1181    // Same as SPR2
1182    int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
1183    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr2,
1184                           0, 0, 0, Uplo, 0, 0, N, 0,
1185                           alpha.x, alpha.y, X->getID(), Y->getID(),
1186                           0, 0, Ap->getID(), incX, incY, 0, 0);
1187}
1188
1189
1190/**
1191 * Level 3 BLAS
1192 */
1193
1194static void validateL3(RS* mRS, sp<const Element> e, int TransA, int TransB, int Side,
1195                       sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
1196    int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1;
1197    if ((A != nullptr && !A->getType()->getElement()->isCompatible(e)) ||
1198        (B != nullptr && !B->getType()->getElement()->isCompatible(e)) ||
1199        (C != nullptr && !C->getType()->getElement()->isCompatible(e))) {
1200        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1201    }
1202    if (C == nullptr) {
1203        // Since matrix C is used to store the result, it cannot be null.
1204        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Allocation C cannot be null");
1205    }
1206    cM = C->getType()->getY();
1207    cN = C->getType()->getX();
1208
1209    if (Side == RsBlasRight) {
1210        if ((A == nullptr && B != nullptr) || (A != nullptr && B == nullptr)) {
1211            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Provided Matrix A without Matrix B, or vice versa");
1212        }
1213        if (B != nullptr) {
1214            bM = A->getType()->getY();
1215            bN = A->getType()->getX();
1216        }
1217        if (A != nullptr) {
1218            aM = B->getType()->getY();
1219            aN = B->getType()->getX();
1220        }
1221    } else {
1222        if (A != nullptr) {
1223            if (TransA == RsBlasTrans || TransA == RsBlasConjTrans) {
1224                aN = A->getType()->getY();
1225                aM = A->getType()->getX();
1226            } else {
1227                aM = A->getType()->getY();
1228                aN = A->getType()->getX();
1229            }
1230        }
1231        if (B != nullptr) {
1232            if (TransB == RsBlasTrans || TransB == RsBlasConjTrans) {
1233                bN = B->getType()->getY();
1234                bM = B->getType()->getX();
1235            } else {
1236                bM = B->getType()->getY();
1237                bN = B->getType()->getX();
1238            }
1239        }
1240    }
1241    if (A != nullptr && B != nullptr && C != nullptr) {
1242        if (aN != bM || aM != cM || bN != cN) {
1243            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1244        }
1245    } else if (A != nullptr && C != nullptr) {
1246        // A and C only, for SYRK
1247        if (cM != cN) {
1248            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix C is not symmetric");
1249        }
1250        if (aM != cM) {
1251            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1252        }
1253    } else if (A != nullptr && B != nullptr) {
1254        // A and B only
1255        if (aN != bM) {
1256            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1257        }
1258    }
1259
1260}
1261
1262void ScriptIntrinsicBLAS::SGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, float alpha,
1263                                sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
1264    validateL3(mRS, Element::F32(mRS), TransA, TransB, 0, A, B, C);
1265
1266    int M = -1, N = -1, K = -1;
1267    if (TransA != RsBlasNoTrans) {
1268        M = A->getType()->getX();
1269        K = A->getType()->getY();
1270    } else {
1271        M = A->getType()->getY();
1272        K = A->getType()->getX();
1273    }
1274    if (TransB != RsBlasNoTrans) {
1275        N = B->getType()->getY();
1276    } else {
1277        N = B->getType()->getX();
1278    }
1279    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemm,
1280                                TransA, TransB, 0, 0, 0, M, N, K,
1281                                alpha, A->getID(), B->getID(),
1282                                beta, C->getID(), 0, 0, 0, 0);
1283}
1284
1285void ScriptIntrinsicBLAS::DGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, double alpha,
1286                                sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
1287    validateL3(mRS, Element::F64(mRS), TransA, TransB, 0, A, B, C);
1288    int M = -1, N = -1, K = -1;
1289    if (TransA != RsBlasNoTrans) {
1290        M = A->getType()->getX();
1291        K = A->getType()->getY();
1292    } else {
1293        M = A->getType()->getY();
1294        K = A->getType()->getX();
1295    }
1296    if (TransB != RsBlasNoTrans) {
1297        N = B->getType()->getY();
1298    } else {
1299        N = B->getType()->getX();
1300    }
1301    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemm,
1302                                TransA, TransB, 0, 0, 0, M, N, K,
1303                                alpha, A->getID(), B->getID(),
1304                                beta, C->getID(), 0, 0, 0, 0);
1305}
1306
1307void ScriptIntrinsicBLAS::CGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Float2 alpha,
1308                                sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
1309    validateL3(mRS, Element::F32_2(mRS), TransA, TransB, 0, A, B, C);
1310    int M = -1, N = -1, K = -1;
1311    if (TransA != RsBlasNoTrans) {
1312        M = A->getType()->getX();
1313        K = A->getType()->getY();
1314    } else {
1315        M = A->getType()->getY();
1316        K = A->getType()->getX();
1317    }
1318    if (TransB != RsBlasNoTrans) {
1319        N = B->getType()->getY();
1320    } else {
1321        N = B->getType()->getX();
1322    }
1323    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemm,
1324                                 TransA, TransB, 0, 0, 0, M, N, K,
1325                                 alpha.x, alpha.y, A->getID(), B->getID(),
1326                                 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1327}
1328
1329void ScriptIntrinsicBLAS::ZGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Double2 alpha,
1330                                sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
1331    validateL3(mRS, Element::F64_2(mRS), TransA, TransB, 0, A, B, C);
1332    int M = -1, N = -1, K = -1;
1333    if (TransA != RsBlasNoTrans) {
1334        M = A->getType()->getX();
1335        K = A->getType()->getY();
1336    } else {
1337        M = A->getType()->getY();
1338        K = A->getType()->getX();
1339    }
1340    if (TransB != RsBlasNoTrans) {
1341        N = B->getType()->getY();
1342    } else {
1343        N = B->getType()->getX();
1344    }
1345    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemm,
1346                           TransA, TransB, 0, 0, 0, M, N, K,
1347                           alpha.x, alpha.y, A->getID(), B->getID(),
1348                           beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1349}
1350
1351void ScriptIntrinsicBLAS::SSYMM(RsBlasSide Side, RsBlasUplo Uplo, float alpha,
1352                                sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
1353    //For SYMM, Matrix A should be symmetric
1354    if (A->getType()->getX() != A->getType()->getY()) {
1355        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1356    }
1357    validateL3(mRS, Element::F32(mRS), 0, 0, Side, A, B, C);
1358    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymm,
1359                                0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1360                                alpha, A->getID(), B->getID(),
1361                                beta, C->getID(), 0, 0, 0, 0);
1362}
1363
1364void ScriptIntrinsicBLAS::DSYMM(RsBlasSide Side, RsBlasUplo Uplo, double alpha,
1365                                sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
1366    if (A->getType()->getX() != A->getType()->getY()) {
1367        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1368    }
1369    validateL3(mRS, Element::F64(mRS), 0, 0, Side, A, B, C);
1370    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymm,
1371                                0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1372                                alpha, A->getID(), B->getID(),
1373                                beta, C->getID(), 0, 0, 0, 0);
1374}
1375
1376void ScriptIntrinsicBLAS::CSYMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha,
1377                                sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
1378    if (A->getType()->getX() != A->getType()->getY()) {
1379        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1380    }
1381    validateL3(mRS, Element::F32_2(mRS), 0, 0, Side, A, B, C);
1382    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csymm,
1383                                 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1384                                 alpha.x, alpha.y, A->getID(), B->getID(),
1385                                 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1386}
1387
1388void ScriptIntrinsicBLAS::ZSYMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha,
1389                                sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
1390    if (A->getType()->getX() != A->getType()->getY()) {
1391        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1392    }
1393    validateL3(mRS, Element::F64_2(mRS), 0, 0, Side, A, B, C);
1394    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsymm,
1395                           0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1396                           alpha.x, alpha.y, A->getID(), B->getID(),
1397                           beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1398}
1399
1400void ScriptIntrinsicBLAS::SSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1401                                sp<Allocation> A, float beta, sp<Allocation> C) {
1402    validateL3(mRS, Element::F32(mRS), Trans, 0, 0, A, nullptr, C);
1403    int K = -1;
1404    if (Trans != RsBlasNoTrans) {
1405        K = A->getType()->getY();
1406    } else {
1407        K = A->getType()->getX();
1408    }
1409    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyrk,
1410                                Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1411                                alpha, A->getID(), 0,
1412                                beta, C->getID(), 0, 0, 0, 0);
1413}
1414
1415void ScriptIntrinsicBLAS::DSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1416                                sp<Allocation> A, double beta, sp<Allocation> C) {
1417    validateL3(mRS, Element::F64(mRS), Trans, 0, 0, A, nullptr, C);
1418    int K = -1;
1419    if (Trans != RsBlasNoTrans) {
1420        K = A->getType()->getY();
1421    } else {
1422        K = A->getType()->getX();
1423    }
1424    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyrk,
1425                                Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1426                                alpha, A->getID(), 0,
1427                                beta, C->getID(), 0, 0, 0, 0);
1428}
1429
1430void ScriptIntrinsicBLAS::CSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1431                                sp<Allocation> A, Float2 beta, sp<Allocation> C) {
1432    validateL3(mRS, Element::F32_2(mRS), Trans, 0, 0, A, nullptr, C);
1433    int K = -1;
1434    if (Trans != RsBlasNoTrans) {
1435        K = A->getType()->getY();
1436    } else {
1437        K = A->getType()->getX();
1438    }
1439    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyrk,
1440                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1441                                 alpha.x, alpha.y, A->getID(), 0,
1442                                 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1443}
1444
1445void ScriptIntrinsicBLAS::ZSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1446                                sp<Allocation> A, Double2 beta, sp<Allocation> C) {
1447    validateL3(mRS, Element::F64_2(mRS), Trans, 0, 0, A, nullptr, C);
1448    int K = -1;
1449    if (Trans != RsBlasNoTrans) {
1450        K = A->getType()->getY();
1451    } else {
1452        K = A->getType()->getX();
1453    }
1454    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyrk,
1455                           Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1456                           alpha.x, alpha.y, A->getID(), 0,
1457                           beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1458}
1459
1460static void validateSYR2K(RS* mRS, sp<const Element> e, RsBlasTranspose Trans,
1461                          sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
1462    if (!A->getType()->getElement()->isCompatible(e) ||
1463        !B->getType()->getElement()->isCompatible(e) ||
1464        !C->getType()->getElement()->isCompatible(e)) {
1465        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1466    }
1467    int Cdim = -1;
1468    // A is n x k if no transpose, k x n if transpose
1469    // C is n x n
1470    if (Trans == RsBlasTrans) {
1471        // check columns versus C
1472        Cdim = A->getType()->getX();
1473    } else {
1474        // check rows versus C
1475        Cdim = A->getType()->getY();
1476    }
1477    if ((int)C->getType()->getX() != Cdim || (int)C->getType()->getY() != Cdim) {
1478        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid symmetric matrix in SYR2K");
1479    }
1480    // A dims == B dims
1481    if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) {
1482        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid A and B in SYR2K");
1483    }
1484}
1485
1486void ScriptIntrinsicBLAS::SSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1487                                 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
1488    validateSYR2K(mRS, Element::F32(mRS), Trans, A, B, C);
1489    int K = -1;
1490    if (Trans != RsBlasNoTrans) {
1491        K = A->getType()->getY();
1492    } else {
1493        K = A->getType()->getX();
1494    }
1495    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2k,
1496                                Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1497                                alpha, A->getID(), B->getID(),
1498                                beta, C->getID(), 0, 0, 0, 0);
1499}
1500
1501void ScriptIntrinsicBLAS::DSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1502                                 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
1503    validateSYR2K(mRS, Element::F64(mRS), Trans, A, B, C);
1504    int K = -1;
1505    if (Trans != RsBlasNoTrans) {
1506        K = A->getType()->getY();
1507    } else {
1508        K = A->getType()->getX();
1509    }
1510    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2k,
1511                                Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1512                                alpha, A->getID(), B->getID(),
1513                                beta, C->getID(), 0, 0, 0, 0);
1514}
1515
1516void ScriptIntrinsicBLAS::CSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1517                                 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
1518    validateSYR2K(mRS, Element::F32_2(mRS), Trans, A, B, C);
1519    int K = -1;
1520    if (Trans != RsBlasNoTrans) {
1521        K = A->getType()->getY();
1522    } else {
1523        K = A->getType()->getX();
1524    }
1525    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyr2k,
1526                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1527                                 alpha.x, alpha.y, A->getID(), B->getID(),
1528                                 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1529}
1530
1531void ScriptIntrinsicBLAS::ZSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1532                                 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
1533    validateSYR2K(mRS, Element::F64_2(mRS), Trans, A, B, C);
1534    int K = -1;
1535    if (Trans != RsBlasNoTrans) {
1536        K = A->getType()->getY();
1537    } else {
1538        K = A->getType()->getX();
1539    }
1540    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyr2k,
1541                           Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1542                           alpha.x, alpha.y, A->getID(), B->getID(),
1543                           beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1544}
1545
1546static void validateTRMM(RS* mRS, sp<const Element> e, RsBlasSide Side, RsBlasTranspose TransA,
1547                         sp<Allocation> A, sp<Allocation> B) {
1548    int aM = -1, aN = -1, bM = -1, bN = -1;
1549    if (!A->getType()->getElement()->isCompatible(e) ||
1550        !B->getType()->getElement()->isCompatible(e)) {
1551        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1552    }
1553
1554    aM = A->getType()->getY();
1555    aN = A->getType()->getX();
1556    if (aM != aN) {
1557        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with a non-symmetric matrix A");
1558    }
1559
1560    bM = B->getType()->getY();
1561    bN = B->getType()->getX();
1562    if (Side == RsBlasLeft) {
1563        if (aN != bM) {
1564            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices");
1565        }
1566    } else {
1567        if (bN != aM) {
1568            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices");
1569        }
1570    }
1571}
1572
1573void ScriptIntrinsicBLAS::STRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1574                                float alpha, sp<Allocation> A, sp<Allocation> B) {
1575    validateTRMM(mRS, Element::F32(mRS), Side, TransA, A, B);
1576    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmm,
1577                                TransA, 0, Side, Uplo, Diag,\
1578                                B->getType()->getY(), B->getType()->getX(), 0,
1579                                alpha, A->getID(), B->getID(), 0.f, 0, 0, 0, 0, 0);
1580}
1581
1582void ScriptIntrinsicBLAS::DTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1583                                double alpha, sp<Allocation> A, sp<Allocation> B) {
1584    validateTRMM(mRS, Element::F64(mRS), Side, TransA, A, B);
1585    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmm,
1586                                TransA, 0, Side, Uplo, Diag,
1587                                B->getType()->getY(), B->getType()->getX(), 0,
1588                                alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1589}
1590
1591void ScriptIntrinsicBLAS::CTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1592                                Float2 alpha, sp<Allocation> A, sp<Allocation> B) {
1593    validateTRMM(mRS, Element::F32_2(mRS), Side, TransA, A, B);
1594    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmm,
1595                                 TransA, 0, Side, Uplo, Diag,
1596                                 B->getType()->getY(), B->getType()->getX(), 0,
1597                                 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1598}
1599
1600void ScriptIntrinsicBLAS::ZTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1601                                Double2 alpha, sp<Allocation> A, sp<Allocation> B) {
1602    validateTRMM(mRS, Element::F64_2(mRS), Side, TransA, A, B);
1603    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmm,
1604                           TransA, 0, Side, Uplo, Diag,
1605                           B->getType()->getY(), B->getType()->getX(), 0,
1606                           alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1607}
1608
1609static void validateTRSM(RS* mRS, sp<const Element> e, RsBlasSide Side, RsBlasTranspose TransA,
1610                         sp<Allocation> A, sp<Allocation> B) {
1611    int adim = -1, bM = -1, bN = -1;
1612    if (!A->getType()->getElement()->isCompatible(e) ||
1613        !B->getType()->getElement()->isCompatible(e)) {
1614        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1615    }
1616    adim = A->getType()->getX();
1617    if (adim != (int)A->getType()->getY()) {
1618        // This may be unnecessary, the restriction could potentially be relaxed.
1619        // Allocation A needs to contain at least that symmetric matrix but could theoretically
1620        // be larger for now we assume adapters are sufficient, will reevaluate in the future.
1621        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with a non-symmetric matrix A");
1622    }
1623    bM = B->getType()->getY();
1624    bN = B->getType()->getX();
1625    if (Side == RsBlasLeft) {
1626        // A is M*M
1627        if (adim != bM) {
1628            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions");
1629        }
1630    } else {
1631        // A is N*N
1632        if (adim != bN) {
1633            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions");
1634        }
1635    }
1636}
1637
1638void ScriptIntrinsicBLAS::STRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1639                                float alpha, sp<Allocation> A, sp<Allocation> B) {
1640    validateTRSM(mRS, Element::F32(mRS), Side, TransA, A, B);
1641    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsm,
1642                                TransA, 0, Side, Uplo, Diag,
1643                                B->getType()->getY(), B->getType()->getX(), 0,
1644                                alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1645}
1646
1647void ScriptIntrinsicBLAS::DTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1648                                double alpha, sp<Allocation> A, sp<Allocation> B) {
1649    validateTRSM(mRS, Element::F64(mRS), Side, TransA, A, B);
1650    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsm,
1651                                TransA, 0, Side, Uplo, Diag,
1652                                B->getType()->getY(), B->getType()->getX(), 0,
1653                                alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1654}
1655
1656void ScriptIntrinsicBLAS::CTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1657                                Float2 alpha, sp<Allocation> A, sp<Allocation> B) {
1658    validateTRSM(mRS, Element::F32_2(mRS), Side, TransA, A, B);
1659    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsm,
1660                                 TransA, 0, Side, Uplo, Diag,
1661                                 B->getType()->getY(), B->getType()->getX(), 0,
1662                                 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1663}
1664
1665void ScriptIntrinsicBLAS::ZTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1666                                Double2 alpha, sp<Allocation> A, sp<Allocation> B) {
1667    validateTRSM(mRS, Element::F64_2(mRS), Side, TransA, A, B);
1668    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsm,
1669                           TransA, 0, Side, Uplo, Diag,
1670                           B->getType()->getY(), B->getType()->getX(), 0,
1671                           alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1672}
1673
1674static void validateHEMM(RS* mRS, sp<const Element> e, RsBlasSide Side,
1675                         sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
1676    if (!A->getType()->getElement()->isCompatible(e) ||
1677        !B->getType()->getElement()->isCompatible(e) ||
1678        !C->getType()->getElement()->isCompatible(e)) {
1679        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1680    }
1681
1682    // A must be square; can potentially be relaxed similar to TRSM
1683    int adim = A->getType()->getX();
1684    if (adim != (int)A->getType()->getY()) {
1685        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with non-square A");
1686    }
1687    if ((Side == RsBlasLeft && adim != (int)B->getType()->getY()) ||
1688        (Side == RsBlasRight && adim != (int)B->getType()->getX())) {
1689        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with invalid B");
1690    }
1691    if (B->getType()->getX() != C->getType()->getX() ||
1692        B->getType()->getY() != C->getType()->getY()) {
1693        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with mismatched B and C");
1694    }
1695}
1696
1697void ScriptIntrinsicBLAS::CHEMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha,
1698                                sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
1699    validateHEMM(mRS, Element::F32_2(mRS), Side, A, B, C);
1700    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemm,
1701                                 0, 0, Side, Uplo, 0,
1702                                 C->getType()->getY(), C->getType()->getX(), 0,
1703                                 alpha.x, alpha.y, A->getID(), B->getID(),
1704                                 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1705}
1706
1707void ScriptIntrinsicBLAS::ZHEMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha,
1708                                sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
1709    validateHEMM(mRS, Element::F64_2(mRS), Side, A, B, C);
1710    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemm,
1711                           0, 0, Side, Uplo, 0,
1712                           C->getType()->getY(), C->getType()->getX(), 0,
1713                           alpha.x, alpha.y, A->getID(), B->getID(),
1714                           beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1715}
1716
1717static void validateHERK(RS* mRS, sp<const Element> e, RsBlasTranspose Trans,
1718                         sp<Allocation> A, sp<Allocation> C) {
1719    if (!A->getType()->getElement()->isCompatible(e) ||
1720        !C->getType()->getElement()->isCompatible(e)) {
1721        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1722    }
1723    if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) {
1724        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose");
1725    }
1726    int cdim = C->getType()->getX();
1727    if (cdim != (int)C->getType()->getY()) {
1728        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with non-square C");
1729    }
1730    if (Trans == RsBlasNoTrans) {
1731        if (cdim != (int)A->getType()->getY()) {
1732            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A");
1733        }
1734    } else {
1735        if (cdim != (int)A->getType()->getX()) {
1736            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A");
1737        }
1738    }
1739}
1740
1741void ScriptIntrinsicBLAS::CHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1742                                sp<Allocation> A, float beta, sp<Allocation> C) {
1743    validateHERK(mRS, Element::F32_2(mRS), Trans, A, C);
1744    int k = 0;
1745    if (Trans == RsBlasConjTrans) {
1746        k = A->getType()->getY();
1747    } else {
1748        k = A->getType()->getX();
1749    }
1750    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cherk,
1751                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1752                                 alpha, 0, A->getID(), 0,
1753                                 beta, 0, C->getID(), 0, 0, 0, 0);
1754}
1755
1756void ScriptIntrinsicBLAS::ZHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1757                                sp<Allocation> A, double beta, sp<Allocation> C) {
1758    validateHERK(mRS, Element::F64_2(mRS), Trans, A, C);
1759    int k = 0;
1760    if (Trans == RsBlasConjTrans) {
1761        k = A->getType()->getY();
1762    } else {
1763        k = A->getType()->getX();
1764    }
1765    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zherk,
1766                           Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1767                           alpha, 0, A->getID(), 0,
1768                           beta, 0, C->getID(), 0, 0, 0, 0);
1769}
1770
1771static void validateHER2K(RS* mRS, sp<const Element> e, RsBlasTranspose Trans,
1772                          sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
1773    if (!A->getType()->getElement()->isCompatible(e) ||
1774        !B->getType()->getElement()->isCompatible(e) ||
1775        !C->getType()->getElement()->isCompatible(e)) {
1776        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1777    }
1778    if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) {
1779        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose");
1780    }
1781    int cdim = C->getType()->getX();
1782    if (cdim != (int)C->getType()->getY()) {
1783        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with non-square C");
1784    }
1785    if (Trans == RsBlasNoTrans) {
1786        if ((int)A->getType()->getY() != cdim) {
1787            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices");
1788        }
1789    } else {
1790        if ((int)A->getType()->getX() != cdim) {
1791            mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices");
1792        }
1793    }
1794    if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) {
1795        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid A and B matrices");
1796    }
1797}
1798
1799void ScriptIntrinsicBLAS::CHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1800                                 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
1801    validateHER2K(mRS, Element::F32_2(mRS), Trans, A, B, C);
1802    int k = 0;
1803    if (Trans == RsBlasNoTrans) {
1804        k = A->getType()->getX();
1805    } else {
1806        k = A->getType()->getY();
1807    }
1808    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2k,
1809                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1810                                 alpha.x, alpha.y, A->getID(), B->getID(),
1811                                 beta, 0, C->getID(), 0, 0, 0, 0);
1812}
1813
1814void ScriptIntrinsicBLAS::ZHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1815                                 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
1816    validateHER2K(mRS, Element::F64_2(mRS), Trans, A, B, C);
1817    int k = 0;
1818    if (Trans == RsBlasNoTrans) {
1819        k = A->getType()->getX();
1820    } else {
1821        k = A->getType()->getY();
1822    }
1823    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2k,
1824                           Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1825                           alpha.x, alpha.y, A->getID(), B->getID(),
1826                           beta, 0, C->getID(), 0, 0, 0, 0);
1827}
1828
1829
1830
1831void ScriptIntrinsicBLAS::BNNM(sp<Allocation> A, int a_offset, sp<Allocation> B, int b_offset,
1832                               sp<Allocation> C, int c_offset, int c_mult) {
1833    validateL3(mRS, Element::U8(mRS), RsBlasNoTrans, RsBlasTrans, 0, A, B, C);
1834
1835    if (a_offset < 0 || a_offset > 255) {
1836        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid a_offset passed to BNNM");
1837    }
1838    if (b_offset < 0 || b_offset > 255) {
1839        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid b_offset passed to BNNM");
1840    }
1841    int M = -1, N = -1, K = -1;
1842    M = A->getType()->getY();
1843    N = B->getType()->getY();
1844    K = A->getType()->getX();
1845
1846    nScriptIntrinsicBLAS_BNNM(mRS, mRS->getContext(), getID(), M, N, K, A->getID(), a_offset,
1847                              B->getID(), b_offset, C->getID(), c_offset, c_mult);
1848}
1849