ScriptIntrinsicBLAS.java revision cecc00aba1012d4f19bc78fcd12ddcbccdd49182
1/*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.renderscript;
18
19import android.annotation.IntDef;
20import java.lang.annotation.Retention;
21import java.lang.annotation.RetentionPolicy;
22
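/**
 * Intrinsic that exposes BLAS (Basic Linear Algebra Subprograms) routines operating on
 * RenderScript {@link Allocation} objects.
 */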
28public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
29    private Allocation mLUT;
30
31    private ScriptIntrinsicBLAS(long id, RenderScript rs) {
32        super(id, rs);
33    }
34
35    private static final int RsBlas_sdsdot = 1;
36    private static final int RsBlas_dsdot = 2;
37    private static final int RsBlas_sdot = 3;
38    private static final int RsBlas_ddot = 4;
39    private static final int RsBlas_cdotu_sub = 5;
40    private static final int RsBlas_cdotc_sub = 6;
41    private static final int RsBlas_zdotu_sub = 7;
42    private static final int RsBlas_zdotc_sub = 8;
43    private static final int RsBlas_snrm2 = 9;
44    private static final int RsBlas_sasum = 10;
45    private static final int RsBlas_dnrm2 = 11;
46    private static final int RsBlas_dasum = 12;
47    private static final int RsBlas_scnrm2 = 13;
48    private static final int RsBlas_scasum = 14;
49    private static final int RsBlas_dznrm2 = 15;
50    private static final int RsBlas_dzasum = 16;
51    private static final int RsBlas_isamax = 17;
52    private static final int RsBlas_idamax = 18;
53    private static final int RsBlas_icamax = 19;
54    private static final int RsBlas_izamax = 20;
55    private static final int RsBlas_sswap = 21;
56    private static final int RsBlas_scopy = 22;
57    private static final int RsBlas_saxpy = 23;
58    private static final int RsBlas_dswap = 24;
59    private static final int RsBlas_dcopy = 25;
60    private static final int RsBlas_daxpy = 26;
61    private static final int RsBlas_cswap = 27;
62    private static final int RsBlas_ccopy = 28;
63    private static final int RsBlas_caxpy = 29;
64    private static final int RsBlas_zswap = 30;
65    private static final int RsBlas_zcopy = 31;
66    private static final int RsBlas_zaxpy = 32;
67    private static final int RsBlas_srotg = 33;
68    private static final int RsBlas_srotmg = 34;
69    private static final int RsBlas_srot = 35;
70    private static final int RsBlas_srotm = 36;
71    private static final int RsBlas_drotg = 37;
72    private static final int RsBlas_drotmg = 38;
73    private static final int RsBlas_drot = 39;
74    private static final int RsBlas_drotm = 40;
75    private static final int RsBlas_sscal = 41;
76    private static final int RsBlas_dscal = 42;
77    private static final int RsBlas_cscal = 43;
78    private static final int RsBlas_zscal = 44;
79    private static final int RsBlas_csscal = 45;
80    private static final int RsBlas_zdscal = 46;
81    private static final int RsBlas_sgemv = 47;
82    private static final int RsBlas_sgbmv = 48;
83    private static final int RsBlas_strmv = 49;
84    private static final int RsBlas_stbmv = 50;
85    private static final int RsBlas_stpmv = 51;
86    private static final int RsBlas_strsv = 52;
87    private static final int RsBlas_stbsv = 53;
88    private static final int RsBlas_stpsv = 54;
89    private static final int RsBlas_dgemv = 55;
90    private static final int RsBlas_dgbmv = 56;
91    private static final int RsBlas_dtrmv = 57;
92    private static final int RsBlas_dtbmv = 58;
93    private static final int RsBlas_dtpmv = 59;
94    private static final int RsBlas_dtrsv = 60;
95    private static final int RsBlas_dtbsv = 61;
96    private static final int RsBlas_dtpsv = 62;
97    private static final int RsBlas_cgemv = 63;
98    private static final int RsBlas_cgbmv = 64;
99    private static final int RsBlas_ctrmv = 65;
100    private static final int RsBlas_ctbmv = 66;
101    private static final int RsBlas_ctpmv = 67;
102    private static final int RsBlas_ctrsv = 68;
103    private static final int RsBlas_ctbsv = 69;
104    private static final int RsBlas_ctpsv = 70;
105    private static final int RsBlas_zgemv = 71;
106    private static final int RsBlas_zgbmv = 72;
107    private static final int RsBlas_ztrmv = 73;
108    private static final int RsBlas_ztbmv = 74;
109    private static final int RsBlas_ztpmv = 75;
110    private static final int RsBlas_ztrsv = 76;
111    private static final int RsBlas_ztbsv = 77;
112    private static final int RsBlas_ztpsv = 78;
113    private static final int RsBlas_ssymv = 79;
114    private static final int RsBlas_ssbmv = 80;
115    private static final int RsBlas_sspmv = 81;
116    private static final int RsBlas_sger = 82;
117    private static final int RsBlas_ssyr = 83;
118    private static final int RsBlas_sspr = 84;
119    private static final int RsBlas_ssyr2 = 85;
120    private static final int RsBlas_sspr2 = 86;
121    private static final int RsBlas_dsymv = 87;
122    private static final int RsBlas_dsbmv = 88;
123    private static final int RsBlas_dspmv = 89;
124    private static final int RsBlas_dger = 90;
125    private static final int RsBlas_dsyr = 91;
126    private static final int RsBlas_dspr = 92;
127    private static final int RsBlas_dsyr2 = 93;
128    private static final int RsBlas_dspr2 = 94;
129    private static final int RsBlas_chemv = 95;
130    private static final int RsBlas_chbmv = 96;
131    private static final int RsBlas_chpmv = 97;
132    private static final int RsBlas_cgeru = 98;
133    private static final int RsBlas_cgerc = 99;
134    private static final int RsBlas_cher = 100;
135    private static final int RsBlas_chpr = 101;
136    private static final int RsBlas_cher2 = 102;
137    private static final int RsBlas_chpr2 = 103;
138    private static final int RsBlas_zhemv = 104;
139    private static final int RsBlas_zhbmv = 105;
140    private static final int RsBlas_zhpmv = 106;
141    private static final int RsBlas_zgeru = 107;
142    private static final int RsBlas_zgerc = 108;
143    private static final int RsBlas_zher = 109;
144    private static final int RsBlas_zhpr = 110;
145    private static final int RsBlas_zher2 = 111;
146    private static final int RsBlas_zhpr2 = 112;
147    private static final int RsBlas_sgemm = 113;
148    private static final int RsBlas_ssymm = 114;
149    private static final int RsBlas_ssyrk = 115;
150    private static final int RsBlas_ssyr2k = 116;
151    private static final int RsBlas_strmm = 117;
152    private static final int RsBlas_strsm = 118;
153    private static final int RsBlas_dgemm = 119;
154    private static final int RsBlas_dsymm = 120;
155    private static final int RsBlas_dsyrk = 121;
156    private static final int RsBlas_dsyr2k = 122;
157    private static final int RsBlas_dtrmm = 123;
158    private static final int RsBlas_dtrsm = 124;
159    private static final int RsBlas_cgemm = 125;
160    private static final int RsBlas_csymm = 126;
161    private static final int RsBlas_csyrk = 127;
162    private static final int RsBlas_csyr2k = 128;
163    private static final int RsBlas_ctrmm = 129;
164    private static final int RsBlas_ctrsm = 130;
165    private static final int RsBlas_zgemm = 131;
166    private static final int RsBlas_zsymm = 132;
167    private static final int RsBlas_zsyrk = 133;
168    private static final int RsBlas_zsyr2k = 134;
169    private static final int RsBlas_ztrmm = 135;
170    private static final int RsBlas_ztrsm = 136;
171    private static final int RsBlas_chemm = 137;
172    private static final int RsBlas_cherk = 138;
173    private static final int RsBlas_cher2k = 139;
174    private static final int RsBlas_zhemm = 140;
175    private static final int RsBlas_zherk = 141;
176    private static final int RsBlas_zher2k = 142;
177
178    // BLAS extensions start here
179    private static final int RsBlas_bnnm = 1000;
180
181    /**
182     */
183    public static ScriptIntrinsicBLAS create(RenderScript rs) {
184        long id = rs.nScriptIntrinsicCreate(13, Element.U32(rs).getID(rs));
185        return new ScriptIntrinsicBLAS(id, rs);
186    }
187
188    @IntDef({NO_TRANSPOSE, TRANSPOSE, CONJ_TRANSPOSE})
189    @Retention(RetentionPolicy.SOURCE)
190    public @interface Transpose {}
191
192    @IntDef({UPPER, LOWER})
193    @Retention(RetentionPolicy.SOURCE)
194    public @interface Uplo {}
195
196    @IntDef({NON_UNIT, UNIT})
197    @Retention(RetentionPolicy.SOURCE)
198    public @interface Diag {}
199
200    @IntDef({LEFT, RIGHT})
201    @Retention(RetentionPolicy.SOURCE)
202    public @interface Side {}
203
204    public static final int NO_TRANSPOSE = 111;
205    public static final int TRANSPOSE = 112;
206    public static final int CONJ_TRANSPOSE = 113;
207
208    public static final int UPPER = 121;
209    public static final int LOWER = 122;
210
211    public static final int NON_UNIT = 131;
212    public static final int UNIT = 132;
213
214    public static final int LEFT = 141;
215    public static final int RIGHT = 142;
216
217    static void validateSide(@Side int Side) {
218        if (Side != LEFT && Side != RIGHT) {
219            throw new RSRuntimeException("Invalid side passed to BLAS");
220        }
221    }
222
223    static void validateTranspose(@Transpose int Trans) {
224        if (Trans != NO_TRANSPOSE && Trans != TRANSPOSE &&
225            Trans != CONJ_TRANSPOSE) {
226            throw new RSRuntimeException("Invalid transpose passed to BLAS");
227        }
228    }
229
230    static void validateConjTranspose(@Transpose int Trans) {
231        if (Trans != NO_TRANSPOSE &&
232            Trans != CONJ_TRANSPOSE) {
233            throw new RSRuntimeException("Invalid transpose passed to BLAS");
234        }
235    }
236
237    static void validateDiag(@Diag int Diag) {
238        if (Diag != NON_UNIT && Diag != UNIT) {
239            throw new RSRuntimeException("Invalid diag passed to BLAS");
240        }
241    }
242
243    static void validateUplo(@Uplo int Uplo) {
244        if (Uplo != UPPER && Uplo != LOWER) {
245            throw new RSRuntimeException("Invalid uplo passed to BLAS");
246        }
247    }
248
249
250    /**
251     * Level 2 BLAS
252     */
253
254    static void validateGEMV(Element e, int TransA, Allocation A, Allocation X, int incX, Allocation Y, int incY) {
255        validateTranspose(TransA);
256        int M = A.getType().getY();
257        int N = A.getType().getX();
258        if (!A.getType().getElement().isCompatible(e) ||
259            !X.getType().getElement().isCompatible(e) ||
260            !Y.getType().getElement().isCompatible(e)) {
261            throw new RSRuntimeException("Called BLAS with wrong Element type");
262        }
263        if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
264            throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
265        }
266
267        if (incX <= 0 || incY <= 0) {
268            throw new RSRuntimeException("Vector increments must be greater than 0");
269        }
270        int expectedXDim = -1, expectedYDim = -1;
271        if (TransA == NO_TRANSPOSE) {
272            expectedXDim = 1 + (N - 1) * incX;
273            expectedYDim = 1 + (M - 1) * incY;
274        } else {
275            expectedXDim = 1 + (M - 1) * incX;
276            expectedYDim = 1 + (N - 1) * incY;
277        }
278        if (X.getType().getX() != expectedXDim ||
279            Y.getType().getX() != expectedYDim) {
280            throw new RSRuntimeException("Incorrect vector dimensions for GEMV");
281        }
282    }
283    public void SGEMV(@Transpose int TransA, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
284        validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY);
285        int M = A.getType().getY();
286        int N = A.getType().getX();
287        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
288    }
289    public void DGEMV(@Transpose int TransA, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
290        validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY);
291        int M = A.getType().getY();
292        int N = A.getType().getX();
293        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
294    }
295    public void CGEMV(@Transpose int TransA, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
296        validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY);
297        int M = A.getType().getY();
298        int N = A.getType().getX();
299        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
300    }
301    public void ZGEMV(@Transpose int TransA, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
302        validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY);
303        int M = A.getType().getY();
304        int N = A.getType().getX();
305        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
306    }
307
308    public void SGBMV(@Transpose int TransA, int KL, int KU, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
309        // GBMV has the same validation requirements as GEMV + KL and KU >= 0
310        validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY);
311        if (KL < 0 || KU < 0) {
312            throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
313        }
314        int M = A.getType().getY();
315        int N = A.getType().getX();
316        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU);
317    }
318    public void DGBMV(@Transpose int TransA, int KL, int KU, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
319        // GBMV has the same validation requirements as GEMV + KL and KU >= 0
320        validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY);
321        if (KL < 0 || KU < 0) {
322            throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
323        }
324        int M = A.getType().getY();
325        int N = A.getType().getX();
326        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU);
327    }
328    public void CGBMV(@Transpose int TransA, int KL, int KU, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
329        // GBMV has the same validation requirements as GEMV + KL and KU >= 0
330        validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY);
331        if (KL < 0 || KU < 0) {
332            throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
333        }
334        int M = A.getType().getY();
335        int N = A.getType().getX();
336        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU);
337    }
338    public void ZGBMV(@Transpose int TransA, int KL, int KU, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
339        // GBMV has the same validation requirements as GEMV + KL and KU >= 0
340        validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY);
341        if (KL < 0 || KU < 0) {
342            throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
343        }
344        int M = A.getType().getY();
345        int N = A.getType().getX();
346        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU);
347    }
348
349    static void validateTRMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
350        validateTranspose(TransA);
351        validateUplo(Uplo);
352        validateDiag(Diag);
353        int N = A.getType().getY();
354        if (A.getType().getX() != N) {
355            throw new RSRuntimeException("A must be a square matrix for TRMV");
356        }
357        if (!A.getType().getElement().isCompatible(e) ||
358            !X.getType().getElement().isCompatible(e)) {
359            throw new RSRuntimeException("Called BLAS with wrong Element type");
360        }
361        if (X.getType().getY() > 1) {
362            throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
363        }
364
365        if (incX <= 0) {
366            throw new RSRuntimeException("Vector increments must be greater than 0");
367        }
368        int expectedXDim = 1 + (N - 1) * incX;
369        if (X.getType().getX() != expectedXDim) {
370            throw new RSRuntimeException("Incorrect vector dimensions for TRMV");
371        }
372    }
373
374    static int validateTPMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
375        validateTranspose(TransA);
376        validateUplo(Uplo);
377        validateDiag(Diag);
378        if (!Ap.getType().getElement().isCompatible(e) ||
379            !X.getType().getElement().isCompatible(e)) {
380            throw new RSRuntimeException("Called BLAS with wrong Element type");
381        }
382        if (X.getType().getY() > 1) {
383            throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
384        }
385
386        if (Ap.getType().getY() > 1) {
387            throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
388        }
389
390        int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
391        //is it really doing anything?
392        if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
393            throw new RSRuntimeException("Invalid dimension for Ap");
394        }
395        if (incX <= 0) {
396            throw new RSRuntimeException("Vector increments must be greater than 0");
397        }
398        int expectedXDim = 1 + (N - 1) * incX;
399        if (X.getType().getX() != expectedXDim) {
400            throw new RSRuntimeException("Incorrect vector dimensions for TPMV");
401        }
402
403        return N;
404    }
405
406    public void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
407        validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
408        int N = A.getType().getY();
409        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
410    }
411    public void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
412        validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
413        int N = A.getType().getY();
414        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
415    }
416    public void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
417        validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
418        int N = A.getType().getY();
419        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
420    }
421    public void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
422        validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
423        int N = A.getType().getY();
424        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
425    }
426
427    public void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  int K, Allocation A,  Allocation X,  int incX) {
428        // TBMV has the same requirements as TRMV + K >= 0
429        if (K < 0) {
430            throw new RSRuntimeException("K must be greater than or equal to 0");
431        }
432        validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
433        int N = A.getType().getY();
434        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
435    }
436    public void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  int K, Allocation A,  Allocation X,  int incX) {
437        // TBMV has the same requirements as TRMV + K >= 0
438        if (K < 0) {
439            throw new RSRuntimeException("K must be greater than or equal to 0");
440        }
441        validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
442        int N = A.getType().getY();
443        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
444    }
445    public void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  int K, Allocation A,  Allocation X,  int incX) {
446        // TBMV has the same requirements as TRMV + K >= 0
447        if (K < 0) {
448            throw new RSRuntimeException("K must be greater than or equal to 0");
449        }
450        validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
451        int N = A.getType().getY();
452        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
453    }
454    public void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  int K, Allocation A,  Allocation X,  int incX) {
455        // TBMV has the same requirements as TRMV + K >= 0
456        if (K < 0) {
457            throw new RSRuntimeException("K must be greater than or equal to 0");
458        }
459        validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
460        int N = A.getType().getY();
461        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
462    }
463    public void STPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  Allocation Ap,  Allocation X,  int incX) {
464        int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
465        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
466    }
467    public void DTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  Allocation Ap,  Allocation X,  int incX) {
468        int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
469        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
470    }
471    public void CTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  Allocation Ap,  Allocation X,  int incX) {
472        int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
473        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
474    }
475    public void ZTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  Allocation Ap,  Allocation X,  int incX) {
476        int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
477        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
478    }
479    public void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  Allocation A,  Allocation X,  int incX) {
480        // TRSV is the same as TRMV
481        validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
482        int N = A.getType().getY();
483        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
484
485    }
486    public void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  Allocation A,  Allocation X,  int incX) {
487        // TRSV is the same as TRMV
488        validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
489        int N = A.getType().getY();
490        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
491
492    }
493    public void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  Allocation A,  Allocation X,  int incX) {
494        // TRSV is the same as TRMV
495        validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
496        int N = A.getType().getY();
497        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
498
499    }
500    public void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  Allocation A,  Allocation X,  int incX) {
501        // TRSV is the same as TRMV
502        validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
503        int N = A.getType().getY();
504        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
505
506    }
507    public void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  int K, Allocation A,  Allocation X,  int incX) {
508        // TBSV is the same as TRMV + K >= 0
509        validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
510        int N = A.getType().getY();
511        if (K < 0) {
512            throw new RSRuntimeException("Number of diagonals must be positive");
513        }
514        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
515    }
516    public void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  int K, Allocation A,  Allocation X,  int incX) {
517        // TBSV is the same as TRMV + K >= 0
518        validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
519        int N = A.getType().getY();
520        if (K < 0) {
521            throw new RSRuntimeException("Number of diagonals must be positive");
522        }
523        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
524    }
525    public void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  int K, Allocation A,  Allocation X,  int incX) {
526        // TBSV is the same as TRMV + K >= 0
527        validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
528        int N = A.getType().getY();
529        if (K < 0) {
530            throw new RSRuntimeException("Number of diagonals must be positive");
531        }
532        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
533    }
534    public void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  int K, Allocation A,  Allocation X,  int incX) {
535        // TBSV is the same as TRMV + K >= 0
536        validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
537        int N = A.getType().getY();
538        if (K < 0) {
539            throw new RSRuntimeException("Number of diagonals must be positive");
540        }
541        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
542    }
543    public void STPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  Allocation Ap,  Allocation X,  int incX) {
544        // TPSV is same as TPMV
545        int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
546        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
547    }
548    public void DTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  Allocation Ap,  Allocation X,  int incX) {
549        // TPSV is same as TPMV
550        int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
551        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
552    }
553    public void CTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  Allocation Ap,  Allocation X,  int incX) {
554        // TPSV is same as TPMV
555        int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
556        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
557    }
558    public void ZTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  Allocation Ap,  Allocation X,  int incX) {
559        // TPSV is same as TPMV
560        int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
561        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
562    }
563
564    /**
565     * Level 2, S and D only
566     */
567    static int validateSYMV(Element e, @Uplo int Uplo, Allocation A, Allocation X, Allocation Y, int incX, int incY) {
568        validateUplo(Uplo);
569        int N = A.getType().getY();
570        if (A.getType().getX() != N) {
571            throw new RSRuntimeException("A must be a square matrix for SYMV");
572        }
573        if (!A.getType().getElement().isCompatible(e) ||
574            !X.getType().getElement().isCompatible(e) ||
575            !Y.getType().getElement().isCompatible(e) ) {
576            throw new RSRuntimeException("Called BLAS with wrong Element type");
577        }
578        if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
579            throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
580        }
581
582        if (incX <= 0 || incY <= 0) {
583            throw new RSRuntimeException("Vector increments must be greater than 0");
584        }
585        int expectedXDim = 1 + (N - 1) * incX;
586        if (X.getType().getX() != expectedXDim) {
587            throw new RSRuntimeException("Incorrect vector dimensions for SYMV");
588        }
589        int expectedYDim = 1 + (N - 1) * incY;
590        if (Y.getType().getX() != expectedYDim) {
591            throw new RSRuntimeException("Incorrect vector dimensions for SYMV");
592        }
593        return N;
594    }
595    static int validateSPMV(Element e, @Uplo int Uplo, Allocation Ap, Allocation X, int incX, Allocation Y, int incY) {
596        validateUplo(Uplo);
597        if (!Ap.getType().getElement().isCompatible(e) ||
598            !X.getType().getElement().isCompatible(e) ||
599            !Y.getType().getElement().isCompatible(e)) {
600            throw new RSRuntimeException("Called BLAS with wrong Element type");
601        }
602        if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
603            throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
604        }
605
606        if (Ap.getType().getY() > 1) {
607            throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
608        }
609
610        int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
611        if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
612            throw new RSRuntimeException("Invalid dimension for Ap");
613        }
614        if (incX <= 0 || incY <= 0) {
615            throw new RSRuntimeException("Vector increments must be greater than 0");
616        }
617        int expectedXDim = 1 + (N - 1) * incX;
618        if (X.getType().getX() != expectedXDim) {
619            throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
620        }
621        int expectedYDim = 1 + (N - 1) * incY;
622        if (Y.getType().getX() != expectedYDim) {
623            throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
624        }
625
626        return N;
627    }
628    static void validateGER(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
629        if (!A.getType().getElement().isCompatible(e) ||
630            !X.getType().getElement().isCompatible(e) ||
631            !Y.getType().getElement().isCompatible(e) ) {
632            throw new RSRuntimeException("Called BLAS with wrong Element type");
633        }
634
635        if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
636            throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
637        }
638
639        int M = A.getType().getY();
640        int N = A.getType().getX();
641
642        if (N < 1 || M < 1) {
643            throw new RSRuntimeException("M and N must be 1 or greater for GER");
644        }
645        if (incX <= 0 || incY <= 0) {
646            throw new RSRuntimeException("Vector increments must be greater than 0");
647        }
648        int expectedXDim = 1 + (M - 1) * incX;
649        if (X.getType().getX() != expectedXDim) {
650            throw new RSRuntimeException("Incorrect vector dimensions for GER");
651        }
652        int expectedYDim = 1 + (N - 1) * incY;
653        if (Y.getType().getX() != expectedYDim) {
654            throw new RSRuntimeException("Incorrect vector dimensions for GER");
655        }
656
657
658    }
659    static int validateSYR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation A) {
660        validateUplo(Uplo);
661        if (!A.getType().getElement().isCompatible(e) ||
662            !X.getType().getElement().isCompatible(e)) {
663            throw new RSRuntimeException("Called BLAS with wrong Element type");
664        }
665
666        int N = A.getType().getX();
667
668        if (X.getType().getY() > 1) {
669            throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
670        }
671        if (N != A.getType().getY()) {
672            throw new RSRuntimeException("A must be a symmetric matrix");
673        }
674        if (incX <= 0) {
675            throw new RSRuntimeException("Vector increments must be greater than 0");
676        }
677        int expectedXDim = 1 + (N - 1) * incX;
678        if (X.getType().getX() != expectedXDim) {
679            throw new RSRuntimeException("Incorrect vector dimensions for SYR");
680        }
681        return N;
682    }
683    static int validateSPR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Ap) {
684        validateUplo(Uplo);
685        if (!Ap.getType().getElement().isCompatible(e) ||
686            !X.getType().getElement().isCompatible(e)) {
687            throw new RSRuntimeException("Called BLAS with wrong Element type");
688        }
689        if (X.getType().getY() > 1) {
690            throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
691        }
692
693        if (Ap.getType().getY() > 1) {
694            throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
695        }
696
697        int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
698        if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
699            throw new RSRuntimeException("Invalid dimension for Ap");
700        }
701        if (incX <= 0) {
702            throw new RSRuntimeException("Vector increments must be greater than 0");
703        }
704        int expectedXDim = 1 + (N - 1) * incX;
705        if (X.getType().getX() != expectedXDim) {
706            throw new RSRuntimeException("Incorrect vector dimensions for SPR");
707        }
708
709        return N;
710    }
711
712    static int validateSYR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
713        validateUplo(Uplo);
714        if (!A.getType().getElement().isCompatible(e) ||
715            !X.getType().getElement().isCompatible(e) ||
716            !Y.getType().getElement().isCompatible(e)) {
717            throw new RSRuntimeException("Called BLAS with wrong Element type");
718        }
719
720        if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
721            throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
722        }
723
724        int N = A.getType().getX();
725
726        if (N != A.getType().getY()) {
727            throw new RSRuntimeException("A must be a symmetric matrix");
728        }
729        if (incX <= 0 || incY <= 0) {
730            throw new RSRuntimeException("Vector increments must be greater than 0");
731        }
732        int expectedXDim = 1 + (N - 1) * incX;
733        int expectedYDim = 1 + (N - 1) * incY;
734        if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) {
735            throw new RSRuntimeException("Incorrect vector dimensions for SYR");
736        }
737        return N;
738
739    }
740    static int validateSPR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
741        validateUplo(Uplo);
742        if (!Ap.getType().getElement().isCompatible(e) ||
743            !X.getType().getElement().isCompatible(e) ||
744            !Y.getType().getElement().isCompatible(e)) {
745            throw new RSRuntimeException("Called BLAS with wrong Element type");
746        }
747        if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
748            throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
749        }
750
751        if (Ap.getType().getY() > 1) {
752            throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
753        }
754
755        int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
756        if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
757            throw new RSRuntimeException("Invalid dimension for Ap");
758        }
759        if (incX <= 0 || incY <= 0) {
760            throw new RSRuntimeException("Vector increments must be greater than 0");
761        }
762        int expectedXDim = 1 + (N - 1) * incX;
763        int expectedYDim = 1 + (N - 1) * incY;
764        if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) {
765            throw new RSRuntimeException("Incorrect vector dimensions for SPR2");
766        }
767
768        return N;
769    }
770
771    public void SSYMV(@Uplo int Uplo, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
772        int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY);
773        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
774    }
775    public void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
776        // SBMV is the same as SYMV + K >= 0
777        if (K < 0) {
778            throw new RSRuntimeException("K must be greater than or equal to 0");
779        }
780        int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY);
781        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
782    }
783    public void SSPMV(@Uplo int Uplo, float alpha, Allocation Ap, Allocation X, int incX, float beta, Allocation Y, int incY) {
784        int N = validateSPMV(Element.F32(mRS), Uplo, Ap, X, incX, Y, incY);
785        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
786    }
787    public void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
788        int M = A.getType().getY();
789        int N = A.getType().getX();
790        validateGER(Element.F32(mRS), X, incX, Y, incY, A);
791        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0);
792    }
793    public void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
794        int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A);
795        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0);
796    }
797    public void SSPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
798        int N = validateSPR(Element.F32(mRS), Uplo, X, incX, Ap);
799        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0);
800    }
801    public void SSYR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
802        int N = validateSYR2(Element.F32(mRS), Uplo, X, incX, Y, incY, A);
803        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0);
804    }
805    public void SSPR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
806        int N = validateSPR2(Element.F32(mRS), Uplo, X, incX, Y, incY, Ap);
807        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0);
808    }
809    public void DSYMV(@Uplo int Uplo, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
810        int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY);
811        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
812    }
813    public void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
814        // SBMV is the same as SYMV + K >= 0
815        if (K < 0) {
816            throw new RSRuntimeException("K must be greater than or equal to 0");
817        }
818        int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY);
819        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
820    }
821    public void DSPMV(@Uplo int Uplo, double alpha, Allocation Ap, Allocation X, int incX, double beta, Allocation Y, int incY) {
822        int N = validateSPMV(Element.F64(mRS), Uplo, Ap, X, incX, Y, incY);
823        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
824    }
825    public void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
826        int M = A.getType().getY();
827        int N = A.getType().getX();
828        validateGER(Element.F64(mRS), X, incX, Y, incY, A);
829        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0);
830    }
831    public void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
832        int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A);
833        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0);
834    }
835    public void DSPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
836        int N = validateSPR(Element.F64(mRS), Uplo, X, incX, Ap);
837        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0);
838    }
839    public void DSYR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
840        int N = validateSYR2(Element.F64(mRS), Uplo, X, incX, Y, incY, A);
841        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0);
842    }
843    public void DSPR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
844        int N = validateSPR2(Element.F64(mRS), Uplo, X, incX, Y, incY, Ap);
845        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0);
846    }
847
848
849    /**
850     * Level 2, C and Z only
851     */
852
853    static void validateGERU(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
854        if (!A.getType().getElement().isCompatible(e) ||
855            !X.getType().getElement().isCompatible(e) ||
856            !Y.getType().getElement().isCompatible(e)) {
857            throw new RSRuntimeException("Called BLAS with wrong Element type");
858        }
859        if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
860            throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
861        }
862
863        int M = A.getType().getY();
864        int N = A.getType().getX();
865        if (incX <= 0 || incY <= 0) {
866            throw new RSRuntimeException("Vector increments must be greater than 0");
867        }
868        int expectedXDim = 1 + (M - 1) * incX;
869        if (X.getType().getX() != expectedXDim) {
870            throw new RSRuntimeException("Incorrect vector dimensions for GERU");
871        }
872        int expectedYDim = 1 + (N - 1) * incY;
873        if (Y.getType().getX() != expectedYDim) {
874            throw new RSRuntimeException("Incorrect vector dimensions for GERU");
875        }
876
877    }
878
879    public void CHEMV(@Uplo int Uplo, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
880        // HEMV is the same as SYR2 validation-wise
881        int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
882        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
883    }
884    public void CHBMV(@Uplo int Uplo, int K, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
885        // HBMV is the same as SYR2 validation-wise
886        int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
887        if (K < 0) {
888            throw new RSRuntimeException("K must be 0 or greater for HBMV");
889        }
890        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
891    }
892    public void CHPMV(@Uplo int Uplo, Float2 alpha, Allocation Ap, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
893        // HPMV is the same as SPR2
894        int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
895        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
896    }
897    public void CGERU(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
898        validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A);
899        int M = A.getType().getY();
900        int N = A.getType().getX();
901        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
902    }
903    public void CGERC(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
904        // same as GERU
905        validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A);
906        int M = A.getType().getY();
907        int N = A.getType().getX();
908        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
909    }
910    public void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
911        // same as SYR
912        int N = validateSYR(Element.F32_2(mRS), Uplo, X, incX, A);
913        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0);
914    }
915    public void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
916        // equivalent to SPR for validation
917        int N = validateSPR(Element.F32_2(mRS), Uplo, X, incX, Ap);
918        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0);
919    }
920    public void CHER2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
921        // same as SYR2
922        int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
923        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
924    }
925    public void CHPR2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
926        // same as SPR2
927        int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
928        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0);
929    }
930    public void ZHEMV(@Uplo int Uplo, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
931        // HEMV is the same as SYR2 validation-wise
932        int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
933        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
934    }
935    public void ZHBMV(@Uplo int Uplo, int K, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
936        // HBMV is the same as SYR2 validation-wise
937        int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
938        if (K < 0) {
939            throw new RSRuntimeException("K must be 0 or greater for HBMV");
940        }
941        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
942    }
943    public void ZHPMV(@Uplo int Uplo, Double2 alpha, Allocation Ap, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
944        // HPMV is the same as SPR2
945        int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
946        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
947    }
948    public void ZGERU(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
949        validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A);
950        int M = A.getType().getY();
951        int N = A.getType().getX();
952        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
953    }
954    public void ZGERC(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
955        // same as GERU
956        validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A);
957        int M = A.getType().getY();
958        int N = A.getType().getX();
959        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
960    }
961    public void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
962        // same as SYR
963        int N = validateSYR(Element.F64_2(mRS), Uplo, X, incX, A);
964        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0);
965    }
966    public void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
967        // equivalent to SPR for validation
968        int N = validateSPR(Element.F64_2(mRS), Uplo, X, incX, Ap);
969        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0);
970    }
971    public void ZHER2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
972        // same as SYR2
973        int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
974        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
975    }
976    public void ZHPR2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
977        // same as SPR2
978        int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
979        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0);
980    }
981
982
983    /**
984     * Level 3 BLAS
985     */
986
987    static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) {
988        int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1;
989        if ((A != null && !A.getType().getElement().isCompatible(e)) ||
990            (B != null && !B.getType().getElement().isCompatible(e)) ||
991            (C != null && !C.getType().getElement().isCompatible(e))) {
992            throw new RSRuntimeException("Called BLAS with wrong Element type");
993        }
994        if (C == null) {
995            //since matrix C is used to store the result, it cannot be null.
996            throw new RSRuntimeException("Allocation C cannot be null");
997        }
998        cM = C.getType().getY();
999        cN = C.getType().getX();
1000
1001        if (Side == RIGHT) {
1002            if ((A == null && B != null) || (A != null && B == null)) {
1003                throw new RSRuntimeException("Provided Matrix A without Matrix B, or vice versa");
1004            }
1005            if (B != null) {
1006                bM = A.getType().getY();
1007                bN = A.getType().getX();
1008            }
1009            if (A != null) {
1010                aM = B.getType().getY();
1011                aN = B.getType().getX();
1012            }
1013        } else {
1014            if (A != null) {
1015                if (TransA != NO_TRANSPOSE) {
1016                    aN = A.getType().getY();
1017                    aM = A.getType().getX();
1018                } else {
1019                    aM = A.getType().getY();
1020                    aN = A.getType().getX();
1021                }
1022            }
1023            if (B != null) {
1024                if (TransB != NO_TRANSPOSE) {
1025                    bN = B.getType().getY();
1026                    bM = B.getType().getX();
1027                } else {
1028                    bM = B.getType().getY();
1029                    bN = B.getType().getX();
1030                }
1031            }
1032        }
1033        if (A != null && B != null && C != null) {
1034            if (aN != bM || aM != cM || bN != cN) {
1035                throw new RSRuntimeException("Called BLAS with invalid dimensions");
1036            }
1037        } else if (A != null && C != null) {
1038            // A and C only, for SYRK
1039            if (cM != cN) {
1040                throw new RSRuntimeException("Matrix C is not symmetric");
1041            }
1042            if (TransA != NO_TRANSPOSE) {
1043                if (aN != cM) {
1044                    throw new RSRuntimeException("Called BLAS with invalid dimensions");
1045                }
1046            } else {
1047                if (aM != cM) {
1048                    throw new RSRuntimeException("Called BLAS with invalid dimensions");
1049                }
1050            }
1051        } else if (A != null && B != null) {
1052            // A and B only
1053            if (aN != bM) {
1054                throw new RSRuntimeException("Called BLAS with invalid dimensions");
1055            }
1056        }
1057
1058    }
1059
1060    public void SGEMM(@Transpose int TransA, @Transpose int TransB, float alpha, Allocation A,
1061                      Allocation B, float beta, Allocation C) {
1062        validateTranspose(TransA);
1063        validateTranspose(TransB);
1064        validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C);
1065
1066        int M = -1, N = -1, K = -1;
1067        if (TransA != NO_TRANSPOSE) {
1068            M = A.getType().getX();
1069            K = A.getType().getY();
1070        } else {
1071            M = A.getType().getY();
1072            K = A.getType().getX();
1073        }
1074        if (TransB != NO_TRANSPOSE) {
1075            N = B.getType().getY();
1076        } else {
1077            N = B.getType().getX();
1078        }
1079        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemm, TransA, TransB, 0, 0, 0, M, N, K,  alpha, A.getID(mRS), B.getID(mRS),
1080                                        beta, C.getID(mRS), 0, 0, 0, 0);
1081    }
1082    public void DGEMM(@Transpose int TransA, @Transpose int TransB, double alpha, Allocation A,
1083                      Allocation B, double beta, Allocation C) {
1084        validateTranspose(TransA);
1085        validateTranspose(TransB);
1086        validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C);
1087        int M = -1, N = -1, K = -1;
1088        if (TransA != NO_TRANSPOSE) {
1089            M = A.getType().getX();
1090            K = A.getType().getY();
1091        } else {
1092            M = A.getType().getY();
1093            K = A.getType().getX();
1094        }
1095        if (TransB != NO_TRANSPOSE) {
1096            N = B.getType().getY();
1097        } else {
1098            N = B.getType().getX();
1099        }
1100        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemm, TransA, TransB, 0, 0, 0, M, N, K,  alpha, A.getID(mRS), B.getID(mRS),
1101                                        beta, C.getID(mRS), 0, 0, 0, 0);
1102    }
1103    public void CGEMM(@Transpose int TransA, @Transpose int TransB, Float2 alpha, Allocation A,
1104                      Allocation B, Float2 beta, Allocation C) {
1105        validateTranspose(TransA);
1106        validateTranspose(TransB);
1107        validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C);
1108        int M = -1, N = -1, K = -1;
1109        if (TransA != NO_TRANSPOSE) {
1110            M = A.getType().getX();
1111            K = A.getType().getY();
1112        } else {
1113            M = A.getType().getY();
1114            K = A.getType().getX();
1115        }
1116        if (TransB != NO_TRANSPOSE) {
1117            N = B.getType().getY();
1118        } else {
1119            N = B.getType().getX();
1120        }
1121        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemm, TransA, TransB, 0, 0, 0, M, N, K,  alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1122                                         beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1123    }
1124
1125    public void ZGEMM(@Transpose int TransA, @Transpose int TransB, Double2 alpha, Allocation A,
1126                      Allocation B, Double2 beta, Allocation C) {
1127        validateTranspose(TransA);
1128        validateTranspose(TransB);
1129        validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C);
1130        int M = -1, N = -1, K = -1;
1131        if (TransA != NO_TRANSPOSE) {
1132            M = A.getType().getX();
1133            K = A.getType().getY();
1134        } else {
1135            M = A.getType().getY();
1136            K = A.getType().getX();
1137        }
1138        if (TransB != NO_TRANSPOSE) {
1139            N = B.getType().getY();
1140        } else {
1141            N = B.getType().getX();
1142        }
1143        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemm, TransA, TransB, 0, 0, 0, M, N, K,  alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1144                                   beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1145    }
1146
1147    public void SSYMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A,
1148                      Allocation B, float beta, Allocation C) {
1149        validateSide(Side);
1150        validateUplo(Uplo);
1151        //For SYMM, Matrix A should be symmetric
1152        if (A.getType().getX() != A.getType().getY()) {
1153            throw new RSRuntimeException("Matrix A is not symmetric");
1154        }
1155        validateL3(Element.F32(mRS), 0, 0, Side, A, B, C);
1156        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
1157                                        beta, C.getID(mRS), 0, 0, 0, 0);
1158    }
1159    public void DSYMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A,
1160                      Allocation B, double beta, Allocation C) {
1161        validateSide(Side);
1162        validateUplo(Uplo);
1163        if (A.getType().getX() != A.getType().getY()) {
1164            throw new RSRuntimeException("Matrix A is not symmetric");
1165        }
1166        validateL3(Element.F64(mRS), 0, 0, Side, A, B, C);
1167        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
1168                                        beta, C.getID(mRS), 0, 0, 0, 0);
1169    }
1170    public void CSYMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A,
1171                      Allocation B, Float2 beta, Allocation C) {
1172        validateSide(Side);
1173        validateUplo(Uplo);
1174        if (A.getType().getX() != A.getType().getY()) {
1175            throw new RSRuntimeException("Matrix A is not symmetric");
1176        }
1177        validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C);
1178        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1179                                         beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1180    }
1181    public void ZSYMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A,
1182                      Allocation B, Double2 beta, Allocation C) {
1183        validateSide(Side);
1184        validateUplo(Uplo);
1185        if (A.getType().getX() != A.getType().getY()) {
1186            throw new RSRuntimeException("Matrix A is not symmetric");
1187        }
1188        validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C);
1189        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1190                                   beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1191    }
1192
1193    public void SSYRK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) {
1194        validateTranspose(Trans);
1195        validateUplo(Uplo);
1196        validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C);
1197        int K = -1;
1198        if (Trans != NO_TRANSPOSE) {
1199            K = A.getType().getY();
1200        } else {
1201            K = A.getType().getX();
1202        }
1203
1204        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0);
1205    }
1206
1207    public void DSYRK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) {
1208        validateTranspose(Trans);
1209        validateUplo(Uplo);
1210        validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C);
1211        int K = -1;
1212        if (Trans != NO_TRANSPOSE) {
1213            K = A.getType().getY();
1214        } else {
1215            K = A.getType().getX();
1216        }
1217        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0);
1218    }
1219    public void CSYRK(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Float2 beta, Allocation C) {
1220        validateTranspose(Trans);
1221        validateUplo(Uplo);
1222        validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C);
1223        int K = -1;
1224        if (Trans != NO_TRANSPOSE) {
1225            K = A.getType().getY();
1226        } else {
1227            K = A.getType().getX();
1228        }
1229        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), 0, beta.x, beta.y,
1230                                         C.getID(mRS), 0, 0, 0, 0);
1231    }
1232    public void ZSYRK(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Double2 beta, Allocation C) {
1233        validateTranspose(Trans);
1234        validateUplo(Uplo);
1235        validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C);
1236        int K = -1;
1237        if (Trans != NO_TRANSPOSE) {
1238            K = A.getType().getY();
1239        } else {
1240            K = A.getType().getX();
1241        }
1242        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), 0, beta.x, beta.y,
1243                                   C.getID(mRS), 0, 0, 0, 0);
1244    }
1245
1246    static void validateSYR2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) {
1247        validateTranspose(Trans);
1248        if (!A.getType().getElement().isCompatible(e) ||
1249            !B.getType().getElement().isCompatible(e) ||
1250            !C.getType().getElement().isCompatible(e)) {
1251            throw new RSRuntimeException("Called BLAS with wrong Element type");
1252        }
1253        int Cdim = -1;
1254        // A is n x k if no transpose, k x n if transpose
1255        // C is n x n
1256        if (Trans == TRANSPOSE) {
1257            // check columns versus C
1258            Cdim = A.getType().getX();
1259        } else {
1260            // check rows versus C
1261            Cdim = A.getType().getY();
1262        }
1263        if (C.getType().getX() != Cdim || C.getType().getY() != Cdim) {
1264            throw new RSRuntimeException("Invalid symmetric matrix in SYR2K");
1265        }
1266        // A dims == B dims
1267        if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) {
1268            throw new RSRuntimeException("Invalid A and B in SYR2K");
1269        }
1270    }
1271    public void SSYR2K(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, Allocation B, float beta, Allocation C) {
1272        validateUplo(Uplo);
1273        validateSYR2K(Element.F32(mRS), Trans, A, B, C);
1274        int K = -1;
1275        if (Trans == TRANSPOSE) {
1276            K = A.getType().getY();
1277        } else {
1278            K = A.getType().getX();
1279        }
1280        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0);
1281    }
1282    public void DSYR2K(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, Allocation B, double beta, Allocation C) {
1283        validateUplo(Uplo);
1284        validateSYR2K(Element.F64(mRS), Trans, A, B, C);
1285        int K = -1;
1286        if (Trans == TRANSPOSE) {
1287            K = A.getType().getY();
1288        } else {
1289            K = A.getType().getX();
1290        }
1291        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0);
1292    }
1293    public void CSYR2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) {
1294        validateUplo(Uplo);
1295        validateSYR2K(Element.F32_2(mRS), Trans, A, B, C);
1296        int K = -1;
1297        if (Trans == TRANSPOSE) {
1298            K = A.getType().getY();
1299        } else {
1300            K = A.getType().getX();
1301        }
1302        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1303    }
1304    public void ZSYR2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) {
1305        validateUplo(Uplo);
1306        validateSYR2K(Element.F64_2(mRS), Trans, A, B, C);
1307        int K = -1;
1308        if (Trans == TRANSPOSE) {
1309            K = A.getType().getY();
1310        } else {
1311            K = A.getType().getX();
1312        }
1313        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1314    }
1315
1316    static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
1317        validateSide(Side);
1318        validateTranspose(TransA);
1319        int aM = -1, aN = -1, bM = -1, bN = -1;
1320        if (!A.getType().getElement().isCompatible(e) ||
1321            !B.getType().getElement().isCompatible(e)) {
1322            throw new RSRuntimeException("Called BLAS with wrong Element type");
1323        }
1324
1325        aM = A.getType().getY();
1326        aN = A.getType().getX();
1327        if (aM != aN) {
1328            throw new RSRuntimeException("Called TRMM with a non-symmetric matrix A");
1329        }
1330
1331        bM = B.getType().getY();
1332        bN = B.getType().getX();
1333        if (Side == LEFT) {
1334            if (aN != bM) {
1335                throw new RSRuntimeException("Called TRMM with invalid matrices");
1336            }
1337        } else {
1338            if (bN != aM) {
1339                throw new RSRuntimeException("Called TRMM with invalid matrices");
1340            }
1341        }
1342    }
1343    public void STRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) {
1344        validateUplo(Uplo);
1345        validateDiag(Diag);
1346        validateTRMM(Element.F32(mRS), Side, TransA, A, B);
1347        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1348                                        alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0);
1349    }
1350    public void DTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) {
1351        validateUplo(Uplo);
1352        validateDiag(Diag);
1353        validateTRMM(Element.F64(mRS), Side, TransA, A, B);
1354        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1355                                        alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0);
1356    }
1357    public void CTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) {
1358        validateUplo(Uplo);
1359        validateDiag(Diag);
1360        validateTRMM(Element.F32_2(mRS), Side, TransA, A, B);
1361        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1362                                         alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1363    }
1364    public void ZTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) {
1365        validateUplo(Uplo);
1366        validateDiag(Diag);
1367        validateTRMM(Element.F64_2(mRS), Side, TransA, A, B);
1368        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1369                                   alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1370    }
1371
1372    static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
1373        int adim = -1, bM = -1, bN = -1;
1374        validateSide(Side);
1375        validateTranspose(TransA);
1376        if (!A.getType().getElement().isCompatible(e) ||
1377            !B.getType().getElement().isCompatible(e)) {
1378            throw new RSRuntimeException("Called BLAS with wrong Element type");
1379        }
1380        adim = A.getType().getX();
1381        if (adim != A.getType().getY()) {
1382            // this may be unnecessary, the restriction could potentially be relaxed
1383            // A needs to contain at least that symmetric matrix but could theoretically be larger
1384            // for now we assume adapters are sufficient, will reevaluate in the future
1385            throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A");
1386        }
1387        bM = B.getType().getY();
1388        bN = B.getType().getX();
1389        if (Side == LEFT) {
1390            // A is M*M
1391            if (adim != bM) {
1392                throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
1393            }
1394        } else {
1395            // A is N*N
1396            if (adim != bN) {
1397                throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
1398            }
1399        }
1400    }
1401    public void STRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) {
1402        validateUplo(Uplo);
1403        validateDiag(Diag);
1404        validateTRSM(Element.F32(mRS), Side, TransA, A, B);
1405        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1406                                        alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0);
1407    }
1408    public void DTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) {
1409        validateUplo(Uplo);
1410        validateDiag(Diag);
1411        validateTRSM(Element.F64(mRS), Side, TransA, A, B);
1412        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1413                                        alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0);
1414    }
1415    public void CTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) {
1416        validateUplo(Uplo);
1417        validateDiag(Diag);
1418        validateTRSM(Element.F32_2(mRS), Side, TransA, A, B);
1419        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1420                                         alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1421    }
1422    public void ZTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) {
1423        validateUplo(Uplo);
1424        validateDiag(Diag);
1425        validateTRSM(Element.F64_2(mRS), Side, TransA, A, B);
1426        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1427                                   alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1428    }
1429
1430    static void validateHEMM(Element e, @Side int Side, Allocation A, Allocation B, Allocation C) {
1431        validateSide(Side);
1432
1433        if (!A.getType().getElement().isCompatible(e) ||
1434            !B.getType().getElement().isCompatible(e) ||
1435            !C.getType().getElement().isCompatible(e)) {
1436            throw new RSRuntimeException("Called BLAS with wrong Element type");
1437        }
1438
1439        // A must be square; can potentially be relaxed similar to TRSM
1440        int adim = A.getType().getX();
1441        if (adim != A.getType().getY()) {
1442            throw new RSRuntimeException("Called HEMM with non-square A");
1443        }
1444        if ((Side == LEFT && adim != B.getType().getY()) ||
1445            (Side == RIGHT && adim != B.getType().getX())) {
1446            throw new RSRuntimeException("Called HEMM with invalid B");
1447        }
1448        if (B.getType().getX() != C.getType().getX() ||
1449            B.getType().getY() != C.getType().getY()) {
1450            throw new RSRuntimeException("Called HEMM with mismatched B and C");
1451        }
1452    }
1453    public void CHEMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) {
1454        validateUplo(Uplo);
1455        validateHEMM(Element.F32_2(mRS), Side, A, B, C);
1456        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
1457                                         alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1458    }
1459    public void ZHEMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) {
1460        validateUplo(Uplo);
1461        validateHEMM(Element.F64_2(mRS), Side, A, B, C);
1462        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
1463                                   alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1464    }
1465
1466    static void validateHERK(Element e, @Transpose int Trans, Allocation A, Allocation C) {
1467        if (!A.getType().getElement().isCompatible(e) ||
1468            !C.getType().getElement().isCompatible(e)) {
1469            throw new RSRuntimeException("Called BLAS with wrong Element type");
1470        }
1471        validateConjTranspose(Trans);
1472        int cdim = C.getType().getX();
1473        if (cdim != C.getType().getY()) {
1474            throw new RSRuntimeException("Called HERK with non-square C");
1475        }
1476        if (Trans == NO_TRANSPOSE) {
1477            if (cdim != A.getType().getY()) {
1478                throw new RSRuntimeException("Called HERK with invalid A");
1479            }
1480        } else {
1481            if (cdim != A.getType().getX()) {
1482                throw new RSRuntimeException("Called HERK with invalid A");
1483            }
1484        }
1485    }
1486    public void CHERK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) {
1487        validateUplo(Uplo);
1488        validateHERK(Element.F32_2(mRS), Trans, A, C);
1489        int k = 0;
1490        if (Trans == CONJ_TRANSPOSE) {
1491            k = A.getType().getY();
1492        } else {
1493            k = A.getType().getX();
1494        }
1495        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k,
1496                                         alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0);
1497    }
1498    public void ZHERK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) {
1499        validateUplo(Uplo);
1500        validateHERK(Element.F64_2(mRS), Trans, A, C);
1501        int k = 0;
1502        if (Trans == CONJ_TRANSPOSE) {
1503            k = A.getType().getY();
1504        } else {
1505            k = A.getType().getX();
1506        }
1507        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k,
1508                                   alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0);
1509    }
1510
1511    static void validateHER2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) {
1512        if (!A.getType().getElement().isCompatible(e) ||
1513            !B.getType().getElement().isCompatible(e) ||
1514            !C.getType().getElement().isCompatible(e)) {
1515            throw new RSRuntimeException("Called BLAS with wrong Element type");
1516        }
1517        validateConjTranspose(Trans);
1518        int cdim = C.getType().getX();
1519        if (cdim != C.getType().getY()) {
1520            throw new RSRuntimeException("Called HER2K with non-square C");
1521        }
1522        if (Trans == NO_TRANSPOSE) {
1523            if (A.getType().getY() != cdim) {
1524                throw new RSRuntimeException("Called HER2K with invalid matrices");
1525            }
1526        } else {
1527            if (A.getType().getX() != cdim) {
1528                throw new RSRuntimeException("Called HER2K with invalid matrices");
1529            }
1530        }
1531        if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) {
1532            throw new RSRuntimeException("Called HER2K with invalid A and B matrices");
1533        }
1534    }
1535    public void CHER2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, float beta, Allocation C) {
1536        validateUplo(Uplo);
1537        validateHER2K(Element.F32_2(mRS), Trans, A, B, C);
1538        int k = 0;
1539        if (Trans == NO_TRANSPOSE) {
1540            k = A.getType().getX();
1541        } else {
1542            k = A.getType().getY();
1543        }
1544        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y,
1545                                         A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
1546    }
1547    public void ZHER2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, double beta, Allocation C) {
1548        validateUplo(Uplo);
1549        validateHER2K(Element.F64_2(mRS), Trans, A, B, C);
1550        int k = 0;
1551        if (Trans == NO_TRANSPOSE) {
1552            k = A.getType().getX();
1553        } else {
1554            k = A.getType().getY();
1555        }
1556        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y,
1557                                   A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
1558    }
1559
1560
1561    /**
1562     *
1563     * 8-bit GEMM-like operation for neural networks
1564     *
1565     **/
1566    public void BNNM(Allocation A, int a_offset, Allocation B, int b_offset, Allocation C, int c_offset, int c_mult) {
1567        validateL3(Element.U8(mRS), NO_TRANSPOSE, TRANSPOSE, 0, A, B, C);
1568
1569        int M = -1, N = -1, K = -1;
1570        M = A.getType().getY();
1571        N = B.getType().getY();
1572        K = A.getType().getX();
1573
1574
1575        mRS.nScriptIntrinsicBLAS_BNNM(getID(mRS), M, N, K, A.getID(mRS), a_offset, B.getID(mRS), b_offset, C.getID(mRS), c_offset, c_mult);
1576
1577    }
1578
1579}
1580