ScriptIntrinsicBLAS.java revision 2514806817ec394d334595d76e20f3129117da6e
1/* 2 * Copyright (C) 2015 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17package android.renderscript; 18 19import android.annotation.IntDef; 20import java.lang.annotation.Retention; 21import java.lang.annotation.RetentionPolicy; 22 23/** 24 * 25 * BLAS 26 * 27 * @hide 28 **/ 29public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { 30 private Allocation mLUT; 31 32 private ScriptIntrinsicBLAS(long id, RenderScript rs) { 33 super(id, rs); 34 } 35 36 private static final int RsBlas_sdsdot = 1; 37 private static final int RsBlas_dsdot = 2; 38 private static final int RsBlas_sdot = 3; 39 private static final int RsBlas_ddot = 4; 40 private static final int RsBlas_cdotu_sub = 5; 41 private static final int RsBlas_cdotc_sub = 6; 42 private static final int RsBlas_zdotu_sub = 7; 43 private static final int RsBlas_zdotc_sub = 8; 44 private static final int RsBlas_snrm2 = 9; 45 private static final int RsBlas_sasum = 10; 46 private static final int RsBlas_dnrm2 = 11; 47 private static final int RsBlas_dasum = 12; 48 private static final int RsBlas_scnrm2 = 13; 49 private static final int RsBlas_scasum = 14; 50 private static final int RsBlas_dznrm2 = 15; 51 private static final int RsBlas_dzasum = 16; 52 private static final int RsBlas_isamax = 17; 53 private static final int RsBlas_idamax = 18; 54 private static final int RsBlas_icamax = 19; 55 private static final int RsBlas_izamax = 20; 56 private static final int RsBlas_sswap = 21; 57 private static final int RsBlas_scopy = 22; 58 private static final int RsBlas_saxpy = 23; 59 private static final int RsBlas_dswap = 24; 60 private static final int RsBlas_dcopy = 25; 61 private static final int RsBlas_daxpy = 26; 62 private static final int RsBlas_cswap = 27; 63 private static final int RsBlas_ccopy = 28; 64 private static final int RsBlas_caxpy = 29; 65 private static final int RsBlas_zswap = 30; 66 private static final int RsBlas_zcopy = 31; 67 private static final int RsBlas_zaxpy = 32; 68 private static final int RsBlas_srotg = 33; 69 private static final int RsBlas_srotmg = 34; 70 private static final int RsBlas_srot = 35; 71 private static final int RsBlas_srotm = 36; 72 private static final int RsBlas_drotg = 37; 73 private static final int RsBlas_drotmg = 38; 74 private static final int RsBlas_drot = 39; 75 private static final int RsBlas_drotm = 40; 76 private static final int RsBlas_sscal = 41; 77 private static final int RsBlas_dscal = 42; 78 private static final int RsBlas_cscal = 43; 79 private static final int RsBlas_zscal = 44; 80 private static final int RsBlas_csscal = 45; 81 private static final int RsBlas_zdscal = 46; 82 private static final int RsBlas_sgemv = 47; 83 private static final int RsBlas_sgbmv = 48; 84 private static final int RsBlas_strmv = 49; 85 private static final int RsBlas_stbmv = 50; 86 private static final int RsBlas_stpmv = 51; 87 private static final int RsBlas_strsv = 52; 88 private static final int RsBlas_stbsv = 53; 89 private static final int RsBlas_stpsv = 54; 90 private static final int RsBlas_dgemv = 55; 91 private static final int RsBlas_dgbmv = 56; 92 private static final int RsBlas_dtrmv = 57; 93 private static final int RsBlas_dtbmv = 58; 94 private static final int RsBlas_dtpmv = 59; 95 private static final int RsBlas_dtrsv = 60; 96 private static final int RsBlas_dtbsv = 61; 97 private static final int RsBlas_dtpsv = 62; 98 private static final int RsBlas_cgemv = 63; 99 private static final int RsBlas_cgbmv = 64; 100 private static final int RsBlas_ctrmv = 65; 101 private static final int RsBlas_ctbmv = 66; 102 private static final int RsBlas_ctpmv = 67; 103 private static final int RsBlas_ctrsv = 68; 104 private static final int RsBlas_ctbsv = 69; 105 private static final int RsBlas_ctpsv = 70; 106 private static final int RsBlas_zgemv = 71; 107 private static final int RsBlas_zgbmv = 72; 108 private static final int RsBlas_ztrmv = 73; 109 private static final int RsBlas_ztbmv = 74; 110 private static final int RsBlas_ztpmv = 75; 111 private static final int RsBlas_ztrsv = 76; 112 private static final int RsBlas_ztbsv = 77; 113 private static final int RsBlas_ztpsv = 78; 114 private static final int RsBlas_ssymv = 79; 115 private static final int RsBlas_ssbmv = 80; 116 private static final int RsBlas_sspmv = 81; 117 private static final int RsBlas_sger = 82; 118 private static final int RsBlas_ssyr = 83; 119 private static final int RsBlas_sspr = 84; 120 private static final int RsBlas_ssyr2 = 85; 121 private static final int RsBlas_sspr2 = 86; 122 private static final int RsBlas_dsymv = 87; 123 private static final int RsBlas_dsbmv = 88; 124 private static final int RsBlas_dspmv = 89; 125 private static final int RsBlas_dger = 90; 126 private static final int RsBlas_dsyr = 91; 127 private static final int RsBlas_dspr = 92; 128 private static final int RsBlas_dsyr2 = 93; 129 private static final int RsBlas_dspr2 = 94; 130 private static final int RsBlas_chemv = 95; 131 private static final int RsBlas_chbmv = 96; 132 private static final int RsBlas_chpmv = 97; 133 private static final int RsBlas_cgeru = 98; 134 private static final int RsBlas_cgerc = 99; 135 private static final int RsBlas_cher = 100; 136 private static final int RsBlas_chpr = 101; 137 private static final int RsBlas_cher2 = 102; 138 private static final int RsBlas_chpr2 = 103; 139 private static final int RsBlas_zhemv = 104; 140 private static final int RsBlas_zhbmv = 105; 141 private static final int RsBlas_zhpmv = 106; 142 private static final int RsBlas_zgeru = 107; 143 private static final int RsBlas_zgerc = 108; 144 private static final int RsBlas_zher = 109; 145 private static final int RsBlas_zhpr = 110; 146 private static final int RsBlas_zher2 = 111; 147 private static final int RsBlas_zhpr2 = 112; 148 private static final int RsBlas_sgemm = 113; 149 private static final int RsBlas_ssymm = 114; 150 private static final int RsBlas_ssyrk = 115; 151 private static final int RsBlas_ssyr2k = 116; 152 private static final int RsBlas_strmm = 117; 153 private static final int RsBlas_strsm = 118; 154 private static final int RsBlas_dgemm = 119; 155 private static final int RsBlas_dsymm = 120; 156 private static final int RsBlas_dsyrk = 121; 157 private static final int RsBlas_dsyr2k = 122; 158 private static final int RsBlas_dtrmm = 123; 159 private static final int RsBlas_dtrsm = 124; 160 private static final int RsBlas_cgemm = 125; 161 private static final int RsBlas_csymm = 126; 162 private static final int RsBlas_csyrk = 127; 163 private static final int RsBlas_csyr2k = 128; 164 private static final int RsBlas_ctrmm = 129; 165 private static final int RsBlas_ctrsm = 130; 166 private static final int RsBlas_zgemm = 131; 167 private static final int RsBlas_zsymm = 132; 168 private static final int RsBlas_zsyrk = 133; 169 private static final int RsBlas_zsyr2k = 134; 170 private static final int RsBlas_ztrmm = 135; 171 private static final int RsBlas_ztrsm = 136; 172 private static final int RsBlas_chemm = 137; 173 private static final int RsBlas_cherk = 138; 174 private static final int RsBlas_cher2k = 139; 175 private static final int RsBlas_zhemm = 140; 176 private static final int RsBlas_zherk = 141; 177 private static final int RsBlas_zher2k = 142; 178 179 // BLAS extensions start here 180 private static final int RsBlas_bnnm = 1000; 181 182 /** 183 */ 184 public static ScriptIntrinsicBLAS create(RenderScript rs) { 185 long id = rs.nScriptIntrinsicCreate(13, Element.U32(rs).getID(rs)); 186 return new ScriptIntrinsicBLAS(id, rs); 187 } 188 189 @IntDef({NO_TRANSPOSE, TRANSPOSE, CONJ_TRANSPOSE}) 190 @Retention(RetentionPolicy.SOURCE) 191 public @interface Transpose {} 192 193 @IntDef({UPPER, LOWER}) 194 @Retention(RetentionPolicy.SOURCE) 195 public @interface Uplo {} 196 197 @IntDef({NON_UNIT, UNIT}) 198 @Retention(RetentionPolicy.SOURCE) 199 public @interface Diag {} 200 201 @IntDef({LEFT, RIGHT}) 202 @Retention(RetentionPolicy.SOURCE) 203 public @interface Side {} 204 205 public static final int NO_TRANSPOSE = 111; 206 public static final int TRANSPOSE = 112; 207 public static final int CONJ_TRANSPOSE = 113; 208 209 public static final int UPPER = 121; 210 public static final int LOWER = 122; 211 212 public static final int NON_UNIT = 131; 213 public static final int UNIT = 132; 214 215 public static final int LEFT = 141; 216 public static final int RIGHT = 142; 217 218 static void validateSide(@Side int Side) { 219 if (Side != LEFT && Side != RIGHT) { 220 throw new RSRuntimeException("Invalid side passed to BLAS"); 221 } 222 } 223 224 static void validateTranspose(@Transpose int Trans) { 225 if (Trans != NO_TRANSPOSE && Trans != TRANSPOSE && 226 Trans != CONJ_TRANSPOSE) { 227 throw new RSRuntimeException("Invalid transpose passed to BLAS"); 228 } 229 } 230 231 static void validateConjTranspose(@Transpose int Trans) { 232 if (Trans != NO_TRANSPOSE && 233 Trans != CONJ_TRANSPOSE) { 234 throw new RSRuntimeException("Invalid transpose passed to BLAS"); 235 } 236 } 237 238 static void validateDiag(@Diag int Diag) { 239 if (Diag != NON_UNIT && Diag != UNIT) { 240 throw new RSRuntimeException("Invalid diag passed to BLAS"); 241 } 242 } 243 244 static void validateUplo(@Uplo int Uplo) { 245 if (Uplo != UPPER && Uplo != LOWER) { 246 throw new RSRuntimeException("Invalid uplo passed to BLAS"); 247 } 248 } 249 250 251 /** 252 * Level 2 BLAS 253 */ 254 255 static void validateGEMV(Element e, int TransA, Allocation A, Allocation X, int incX, Allocation Y, int incY) { 256 validateTranspose(TransA); 257 int M = A.getType().getY(); 258 int N = A.getType().getX(); 259 if (!A.getType().getElement().isCompatible(e) || 260 !X.getType().getElement().isCompatible(e) || 261 !Y.getType().getElement().isCompatible(e)) { 262 throw new RSRuntimeException("Called BLAS with wrong Element type"); 263 } 264 if (X.getType().getY() > 1 || Y.getType().getY() > 1) { 265 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); 266 } 267 268 if (incX <= 0 || incY <= 0) { 269 throw new RSRuntimeException("Vector increments must be greater than 0"); 270 } 271 int expectedXDim = -1, expectedYDim = -1; 272 if (TransA == NO_TRANSPOSE) { 273 expectedXDim = 1 + (N - 1) * incX; 274 expectedYDim = 1 + (M - 1) * incY; 275 } else { 276 expectedXDim = 1 + (M - 1) * incX; 277 expectedYDim = 1 + (N - 1) * incY; 278 } 279 if (X.getType().getX() != expectedXDim || 280 Y.getType().getX() != expectedYDim) { 281 throw new RSRuntimeException("Incorrect vector dimensions for GEMV"); 282 } 283 } 284 public void SGEMV(@Transpose int TransA, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { 285 validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY); 286 int M = A.getType().getY(); 287 int N = A.getType().getX(); 288 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); 289 } 290 public void DGEMV(@Transpose int TransA, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { 291 validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY); 292 int M = A.getType().getY(); 293 int N = A.getType().getX(); 294 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); 295 } 296 public void CGEMV(@Transpose int TransA, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { 297 validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY); 298 int M = A.getType().getY(); 299 int N = A.getType().getX(); 300 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); 301 } 302 public void ZGEMV(@Transpose int TransA, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { 303 validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY); 304 int M = A.getType().getY(); 305 int N = A.getType().getX(); 306 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); 307 } 308 309 public void SGBMV(@Transpose int TransA, int KL, int KU, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { 310 // GBMV has the same validation requirements as GEMV + KL and KU >= 0 311 validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY); 312 if (KL < 0 || KU < 0) { 313 throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); 314 } 315 int M = A.getType().getY(); 316 int N = A.getType().getX(); 317 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU); 318 } 319 public void DGBMV(@Transpose int TransA, int KL, int KU, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { 320 // GBMV has the same validation requirements as GEMV + KL and KU >= 0 321 validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY); 322 if (KL < 0 || KU < 0) { 323 throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); 324 } 325 int M = A.getType().getY(); 326 int N = A.getType().getX(); 327 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU); 328 } 329 public void CGBMV(@Transpose int TransA, int KL, int KU, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { 330 // GBMV has the same validation requirements as GEMV + KL and KU >= 0 331 validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY); 332 if (KL < 0 || KU < 0) { 333 throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); 334 } 335 int M = A.getType().getY(); 336 int N = A.getType().getX(); 337 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU); 338 } 339 public void ZGBMV(@Transpose int TransA, int KL, int KU, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { 340 // GBMV has the same validation requirements as GEMV + KL and KU >= 0 341 validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY); 342 if (KL < 0 || KU < 0) { 343 throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); 344 } 345 int M = A.getType().getY(); 346 int N = A.getType().getX(); 347 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU); 348 } 349 350 static void validateTRMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { 351 validateTranspose(TransA); 352 validateUplo(Uplo); 353 validateDiag(Diag); 354 int N = A.getType().getY(); 355 if (A.getType().getX() != N) { 356 throw new RSRuntimeException("A must be a square matrix for TRMV"); 357 } 358 if (!A.getType().getElement().isCompatible(e) || 359 !X.getType().getElement().isCompatible(e)) { 360 throw new RSRuntimeException("Called BLAS with wrong Element type"); 361 } 362 if (X.getType().getY() > 1) { 363 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); 364 } 365 366 if (incX <= 0) { 367 throw new RSRuntimeException("Vector increments must be greater than 0"); 368 } 369 int expectedXDim = 1 + (N - 1) * incX; 370 if (X.getType().getX() != expectedXDim) { 371 throw new RSRuntimeException("Incorrect vector dimensions for TRMV"); 372 } 373 } 374 375 static int validateTPMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { 376 validateTranspose(TransA); 377 validateUplo(Uplo); 378 validateDiag(Diag); 379 if (!Ap.getType().getElement().isCompatible(e) || 380 !X.getType().getElement().isCompatible(e)) { 381 throw new RSRuntimeException("Called BLAS with wrong Element type"); 382 } 383 if (X.getType().getY() > 1) { 384 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); 385 } 386 387 if (Ap.getType().getY() > 1) { 388 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); 389 } 390 391 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); 392 //is it really doing anything? 393 if (Ap.getType().getX() != ((N * (N+1)) / 2)) { 394 throw new RSRuntimeException("Invalid dimension for Ap"); 395 } 396 if (incX <= 0) { 397 throw new RSRuntimeException("Vector increments must be greater than 0"); 398 } 399 int expectedXDim = 1 + (N - 1) * incX; 400 if (X.getType().getX() != expectedXDim) { 401 throw new RSRuntimeException("Incorrect vector dimensions for TPMV"); 402 } 403 404 return N; 405 } 406 407 public void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { 408 validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX); 409 int N = A.getType().getY(); 410 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); 411 } 412 public void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { 413 validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX); 414 int N = A.getType().getY(); 415 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); 416 } 417 public void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { 418 validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX); 419 int N = A.getType().getY(); 420 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); 421 } 422 public void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { 423 validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX); 424 int N = A.getType().getY(); 425 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); 426 } 427 428 public void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { 429 // TBMV has the same requirements as TRMV + K >= 0 430 if (K < 0) { 431 throw new RSRuntimeException("K must be greater than or equal to 0"); 432 } 433 validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX); 434 int N = A.getType().getY(); 435 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); 436 } 437 public void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { 438 // TBMV has the same requirements as TRMV + K >= 0 439 if (K < 0) { 440 throw new RSRuntimeException("K must be greater than or equal to 0"); 441 } 442 validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX); 443 int N = A.getType().getY(); 444 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); 445 } 446 public void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { 447 // TBMV has the same requirements as TRMV + K >= 0 448 if (K < 0) { 449 throw new RSRuntimeException("K must be greater than or equal to 0"); 450 } 451 validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX); 452 int N = A.getType().getY(); 453 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); 454 } 455 public void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { 456 // TBMV has the same requirements as TRMV + K >= 0 457 if (K < 0) { 458 throw new RSRuntimeException("K must be greater than or equal to 0"); 459 } 460 validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX); 461 int N = A.getType().getY(); 462 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); 463 } 464 public void STPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { 465 int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX); 466 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); 467 } 468 public void DTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { 469 int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX); 470 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); 471 } 472 public void CTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { 473 int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); 474 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); 475 } 476 public void ZTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { 477 int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); 478 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); 479 } 480 public void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { 481 // TRSV is the same as TRMV 482 validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX); 483 int N = A.getType().getY(); 484 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); 485 486 } 487 public void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { 488 // TRSV is the same as TRMV 489 validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX); 490 int N = A.getType().getY(); 491 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); 492 493 } 494 public void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { 495 // TRSV is the same as TRMV 496 validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX); 497 int N = A.getType().getY(); 498 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); 499 500 } 501 public void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { 502 // TRSV is the same as TRMV 503 validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX); 504 int N = A.getType().getY(); 505 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); 506 507 } 508 public void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { 509 // TBSV is the same as TRMV + K >= 0 510 validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX); 511 int N = A.getType().getY(); 512 if (K < 0) { 513 throw new RSRuntimeException("Number of diagonals must be positive"); 514 } 515 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); 516 } 517 public void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { 518 // TBSV is the same as TRMV + K >= 0 519 validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX); 520 int N = A.getType().getY(); 521 if (K < 0) { 522 throw new RSRuntimeException("Number of diagonals must be positive"); 523 } 524 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); 525 } 526 public void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { 527 // TBSV is the same as TRMV + K >= 0 528 validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX); 529 int N = A.getType().getY(); 530 if (K < 0) { 531 throw new RSRuntimeException("Number of diagonals must be positive"); 532 } 533 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); 534 } 535 public void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { 536 // TBSV is the same as TRMV + K >= 0 537 validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX); 538 int N = A.getType().getY(); 539 if (K < 0) { 540 throw new RSRuntimeException("Number of diagonals must be positive"); 541 } 542 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); 543 } 544 public void STPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { 545 // TPSV is same as TPMV 546 int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX); 547 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); 548 } 549 public void DTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { 550 // TPSV is same as TPMV 551 int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX); 552 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); 553 } 554 public void CTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { 555 // TPSV is same as TPMV 556 int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); 557 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); 558 } 559 public void ZTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { 560 // TPSV is same as TPMV 561 int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); 562 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); 563 } 564 565 /** 566 * Level 2, S and D only 567 */ 568 static int validateSYMV(Element e, @Uplo int Uplo, Allocation A, Allocation X, Allocation Y, int incX, int incY) { 569 validateUplo(Uplo); 570 int N = A.getType().getY(); 571 if (A.getType().getX() != N) { 572 throw new RSRuntimeException("A must be a square matrix for SYMV"); 573 } 574 if (!A.getType().getElement().isCompatible(e) || 575 !X.getType().getElement().isCompatible(e) || 576 !Y.getType().getElement().isCompatible(e) ) { 577 throw new RSRuntimeException("Called BLAS with wrong Element type"); 578 } 579 if (X.getType().getY() > 1 || Y.getType().getY() > 1) { 580 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); 581 } 582 583 if (incX <= 0 || incY <= 0) { 584 throw new RSRuntimeException("Vector increments must be greater than 0"); 585 } 586 int expectedXDim = 1 + (N - 1) * incX; 587 if (X.getType().getX() != expectedXDim) { 588 throw new RSRuntimeException("Incorrect vector dimensions for SYMV"); 589 } 590 int expectedYDim = 1 + (N - 1) * incY; 591 if (Y.getType().getX() != expectedYDim) { 592 throw new RSRuntimeException("Incorrect vector dimensions for SYMV"); 593 } 594 return N; 595 } 596 static int validateSPMV(Element e, @Uplo int Uplo, Allocation Ap, Allocation X, int incX, Allocation Y, int incY) { 597 validateUplo(Uplo); 598 if (!Ap.getType().getElement().isCompatible(e) || 599 !X.getType().getElement().isCompatible(e) || 600 !Y.getType().getElement().isCompatible(e)) { 601 throw new RSRuntimeException("Called BLAS with wrong Element type"); 602 } 603 if (X.getType().getY() > 1 || Y.getType().getY() > 1) { 604 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); 605 } 606 607 if (Ap.getType().getY() > 1) { 608 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); 609 } 610 611 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); 612 if (Ap.getType().getX() != ((N * (N+1)) / 2)) { 613 throw new RSRuntimeException("Invalid dimension for Ap"); 614 } 615 if (incX <= 0 || incY <= 0) { 616 throw new RSRuntimeException("Vector increments must be greater than 0"); 617 } 618 int expectedXDim = 1 + (N - 1) * incX; 619 if (X.getType().getX() != expectedXDim) { 620 throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); 621 } 622 int expectedYDim = 1 + (N - 1) * incY; 623 if (Y.getType().getX() != expectedYDim) { 624 throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); 625 } 626 627 return N; 628 } 629 static void validateGER(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) { 630 if (!A.getType().getElement().isCompatible(e) || 631 !X.getType().getElement().isCompatible(e) || 632 !Y.getType().getElement().isCompatible(e) ) { 633 throw new RSRuntimeException("Called BLAS with wrong Element type"); 634 } 635 636 if (X.getType().getY() > 1 || Y.getType().getY() > 1) { 637 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); 638 } 639 640 int M = A.getType().getY(); 641 int N = A.getType().getX(); 642 643 if (N < 1 || M < 1) { 644 throw new RSRuntimeException("M and N must be 1 or greater for GER"); 645 } 646 if (incX <= 0 || incY <= 0) { 647 throw new RSRuntimeException("Vector increments must be greater than 0"); 648 } 649 int expectedXDim = 1 + (M - 1) * incX; 650 if (X.getType().getX() != expectedXDim) { 651 throw new RSRuntimeException("Incorrect vector dimensions for GER"); 652 } 653 int expectedYDim = 1 + (N - 1) * incY; 654 if (Y.getType().getX() != expectedYDim) { 655 throw new RSRuntimeException("Incorrect vector dimensions for GER"); 656 } 657 658 659 } 660 static int validateSYR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation A) { 661 validateUplo(Uplo); 662 if (!A.getType().getElement().isCompatible(e) || 663 !X.getType().getElement().isCompatible(e)) { 664 throw new RSRuntimeException("Called BLAS with wrong Element type"); 665 } 666 667 int N = A.getType().getX(); 668 669 if (X.getType().getY() > 1) { 670 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); 671 } 672 if (N != A.getType().getY()) { 673 throw new RSRuntimeException("A must be a symmetric matrix"); 674 } 675 if (incX <= 0) { 676 throw new RSRuntimeException("Vector increments must be greater than 0"); 677 } 678 int expectedXDim = 1 + (N - 1) * incX; 679 if (X.getType().getX() != expectedXDim) { 680 throw new RSRuntimeException("Incorrect vector dimensions for SYR"); 681 } 682 return N; 683 } 684 static int validateSPR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Ap) { 685 validateUplo(Uplo); 686 if (!Ap.getType().getElement().isCompatible(e) || 687 !X.getType().getElement().isCompatible(e)) { 688 throw new RSRuntimeException("Called BLAS with wrong Element type"); 689 } 690 if (X.getType().getY() > 1) { 691 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); 692 } 693 694 if (Ap.getType().getY() > 1) { 695 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); 696 } 697 698 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); 699 if (Ap.getType().getX() != ((N * (N+1)) / 2)) { 700 throw new RSRuntimeException("Invalid dimension for Ap"); 701 } 702 if (incX <= 0) { 703 throw new RSRuntimeException("Vector increments must be greater than 0"); 704 } 705 int expectedXDim = 1 + (N - 1) * incX; 706 if (X.getType().getX() != expectedXDim) { 707 throw new RSRuntimeException("Incorrect vector dimensions for SPR"); 708 } 709 710 return N; 711 } 712 713 static int validateSYR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation A) { 714 validateUplo(Uplo); 715 if (!A.getType().getElement().isCompatible(e) || 716 !X.getType().getElement().isCompatible(e) || 717 !Y.getType().getElement().isCompatible(e)) { 718 throw new RSRuntimeException("Called BLAS with wrong Element type"); 719 } 720 721 if (X.getType().getY() > 1 || Y.getType().getY() > 1) { 722 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); 723 } 724 725 int N = A.getType().getX(); 726 727 if (N != A.getType().getY()) { 728 throw new RSRuntimeException("A must be a symmetric matrix"); 729 } 730 if (incX <= 0 || incY <= 0) { 731 throw new RSRuntimeException("Vector increments must be greater than 0"); 732 } 733 int expectedXDim = 1 + (N - 1) * incX; 734 int expectedYDim = 1 + (N - 1) * incY; 735 if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) { 736 throw new RSRuntimeException("Incorrect vector dimensions for SYR"); 737 } 738 return N; 739 740 } 741 static int validateSPR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { 742 validateUplo(Uplo); 743 if (!Ap.getType().getElement().isCompatible(e) || 744 !X.getType().getElement().isCompatible(e) || 745 !Y.getType().getElement().isCompatible(e)) { 746 throw new RSRuntimeException("Called BLAS with wrong Element type"); 747 } 748 if (X.getType().getY() > 1 || Y.getType().getY() > 1) { 749 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); 750 } 751 752 if (Ap.getType().getY() > 1) { 753 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); 754 } 755 756 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); 757 if (Ap.getType().getX() != ((N * (N+1)) / 2)) { 758 throw new RSRuntimeException("Invalid dimension for Ap"); 759 } 760 if (incX <= 0 || incY <= 0) { 761 throw new RSRuntimeException("Vector increments must be greater than 0"); 762 } 763 int expectedXDim = 1 + (N - 1) * incX; 764 int expectedYDim = 1 + (N - 1) * incY; 765 if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) { 766 throw new RSRuntimeException("Incorrect vector dimensions for SPR2"); 767 } 768 769 return N; 770 } 771 772 public void SSYMV(@Uplo int Uplo, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { 773 int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY); 774 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); 775 } 776 public void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { 777 // SBMV is the same as SYMV + K >= 0 778 if (K < 0) { 779 throw new RSRuntimeException("K must be greater than or equal to 0"); 780 } 781 int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY); 782 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); 783 } 784 public void SSPMV(@Uplo int Uplo, float alpha, Allocation Ap, Allocation X, int incX, float beta, Allocation Y, int incY) { 785 int N = validateSPMV(Element.F32(mRS), Uplo, Ap, X, incX, Y, incY); 786 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); 787 } 788 public void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { 789 int M = A.getType().getY(); 790 int N = A.getType().getX(); 791 validateGER(Element.F32(mRS), X, incX, Y, incY, A); 792 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0); 793 } 794 public void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) { 795 int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A); 796 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0); 797 } 798 public void SSPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) { 799 int N = validateSPR(Element.F32(mRS), Uplo, X, incX, Ap); 800 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0); 801 } 802 public void SSYR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { 803 int N = validateSYR2(Element.F32(mRS), Uplo, X, incX, Y, incY, A); 804 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0); 805 } 806 public void SSPR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { 807 int N = validateSPR2(Element.F32(mRS), Uplo, X, incX, Y, incY, Ap); 808 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0); 809 } 810 public void DSYMV(@Uplo int Uplo, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { 811 int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY); 812 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); 813 } 814 public void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { 815 // SBMV is the same as SYMV + K >= 0 816 if (K < 0) { 817 throw new RSRuntimeException("K must be greater than or equal to 0"); 818 } 819 int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY); 820 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); 821 } 822 public void DSPMV(@Uplo int Uplo, double alpha, Allocation Ap, Allocation X, int incX, double beta, Allocation Y, int incY) { 823 int N = validateSPMV(Element.F64(mRS), Uplo, Ap, X, incX, Y, incY); 824 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); 825 } 826 public void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { 827 int M = A.getType().getY(); 828 int N = A.getType().getX(); 829 validateGER(Element.F64(mRS), X, incX, Y, incY, A); 830 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0); 831 } 832 public void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { 833 int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A); 834 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0); 835 } 836 public void DSPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) { 837 int N = validateSPR(Element.F64(mRS), Uplo, X, incX, Ap); 838 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0); 839 } 840 public void DSYR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { 841 int N = validateSYR2(Element.F64(mRS), Uplo, X, incX, Y, incY, A); 842 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0); 843 } 844 public void DSPR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { 845 int N = validateSPR2(Element.F64(mRS), Uplo, X, incX, Y, incY, Ap); 846 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0); 847 } 848 849 850 /** 851 * Level 2, C and Z only 852 */ 853 854 static void validateGERU(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) { 855 if (!A.getType().getElement().isCompatible(e) || 856 !X.getType().getElement().isCompatible(e) || 857 !Y.getType().getElement().isCompatible(e)) { 858 throw new RSRuntimeException("Called BLAS with wrong Element type"); 859 } 860 if (X.getType().getY() > 1 || Y.getType().getY() > 1) { 861 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); 862 } 863 864 int M = A.getType().getY(); 865 int N = A.getType().getX(); 866 if (incX <= 0 || incY <= 0) { 867 throw new RSRuntimeException("Vector increments must be greater than 0"); 868 } 869 int expectedXDim = 1 + (M - 1) * incX; 870 if (X.getType().getX() != expectedXDim) { 871 throw new RSRuntimeException("Incorrect vector dimensions for GERU"); 872 } 873 int expectedYDim = 1 + (N - 1) * incY; 874 if (Y.getType().getX() != expectedYDim) { 875 throw new RSRuntimeException("Incorrect vector dimensions for GERU"); 876 } 877 878 } 879 880 public void CHEMV(@Uplo int Uplo, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { 881 // HEMV is the same as SYR2 validation-wise 882 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A); 883 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); 884 } 885 public void CHBMV(@Uplo int Uplo, int K, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { 886 // HBMV is the same as SYR2 validation-wise 887 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A); 888 if (K < 0) { 889 throw new RSRuntimeException("K must be 0 or greater for HBMV"); 890 } 891 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); 892 } 893 public void CHPMV(@Uplo int Uplo, Float2 alpha, Allocation Ap, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { 894 // HPMV is the same as SPR2 895 int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap); 896 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); 897 } 898 public void CGERU(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { 899 validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A); 900 int M = A.getType().getY(); 901 int N = A.getType().getX(); 902 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); 903 } 904 public void CGERC(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { 905 // same as GERU 906 validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A); 907 int M = A.getType().getY(); 908 int N = A.getType().getX(); 909 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); 910 } 911 public void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) { 912 // same as SYR 913 int N = validateSYR(Element.F32_2(mRS), Uplo, X, incX, A); 914 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0); 915 } 916 public void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) { 917 // equivalent to SPR for validation 918 int N = validateSPR(Element.F32_2(mRS), Uplo, X, incX, Ap); 919 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0); 920 } 921 public void CHER2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { 922 // same as SYR2 923 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A); 924 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); 925 } 926 public void CHPR2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { 927 // same as SPR2 928 int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap); 929 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0); 930 } 931 public void ZHEMV(@Uplo int Uplo, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { 932 // HEMV is the same as SYR2 validation-wise 933 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A); 934 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); 935 } 936 public void ZHBMV(@Uplo int Uplo, int K, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { 937 // HBMV is the same as SYR2 validation-wise 938 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A); 939 if (K < 0) { 940 throw new RSRuntimeException("K must be 0 or greater for HBMV"); 941 } 942 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); 943 } 944 public void ZHPMV(@Uplo int Uplo, Double2 alpha, Allocation Ap, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { 945 // HPMV is the same as SPR2 946 int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap); 947 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); 948 } 949 public void ZGERU(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { 950 validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A); 951 int M = A.getType().getY(); 952 int N = A.getType().getX(); 953 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); 954 } 955 public void ZGERC(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { 956 // same as GERU 957 validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A); 958 int M = A.getType().getY(); 959 int N = A.getType().getX(); 960 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); 961 } 962 public void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { 963 // same as SYR 964 int N = validateSYR(Element.F64_2(mRS), Uplo, X, incX, A); 965 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0); 966 } 967 public void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) { 968 // equivalent to SPR for validation 969 int N = validateSPR(Element.F64_2(mRS), Uplo, X, incX, Ap); 970 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0); 971 } 972 public void ZHER2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { 973 // same as SYR2 974 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A); 975 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); 976 } 977 public void ZHPR2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { 978 // same as SPR2 979 int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap); 980 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0); 981 } 982 983 984 /** 985 * Level 3 BLAS 986 */ 987 988 static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) { 989 int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1; 990 if ((A != null && !A.getType().getElement().isCompatible(e)) || 991 (B != null && !B.getType().getElement().isCompatible(e)) || 992 (C != null && !C.getType().getElement().isCompatible(e))) { 993 throw new RSRuntimeException("Called BLAS with wrong Element type"); 994 } 995 if (C == null) { 996 //since matrix C is used to store the result, it cannot be null. 997 throw new RSRuntimeException("Allocation C cannot be null"); 998 } 999 cM = C.getType().getY(); 1000 cN = C.getType().getX(); 1001 1002 if (Side == RIGHT) { 1003 if ((A == null && B != null) || (A != null && B == null)) { 1004 throw new RSRuntimeException("Provided Matrix A without Matrix B, or vice versa"); 1005 } 1006 if (B != null) { 1007 bM = A.getType().getY(); 1008 bN = A.getType().getX(); 1009 } 1010 if (A != null) { 1011 aM = B.getType().getY(); 1012 aN = B.getType().getX(); 1013 } 1014 } else { 1015 if (A != null) { 1016 if (TransA == TRANSPOSE || TransA == CONJ_TRANSPOSE) { 1017 aN = A.getType().getY(); 1018 aM = A.getType().getX(); 1019 } else { 1020 aM = A.getType().getY(); 1021 aN = A.getType().getX(); 1022 } 1023 } 1024 if (B != null) { 1025 if (TransB == TRANSPOSE || TransB == CONJ_TRANSPOSE) { 1026 bN = B.getType().getY(); 1027 bM = B.getType().getX(); 1028 } else { 1029 bM = B.getType().getY(); 1030 bN = B.getType().getX(); 1031 } 1032 } 1033 } 1034 if (A != null && B != null && C != null) { 1035 if (aN != bM || aM != cM || bN != cN) { 1036 throw new RSRuntimeException("Called BLAS with invalid dimensions"); 1037 } 1038 } else if (A != null && C != null) { 1039 // A and C only, for SYRK 1040 if (cM != cN) { 1041 throw new RSRuntimeException("Matrix C is not symmetric"); 1042 } 1043 if (aM != cM) { 1044 throw new RSRuntimeException("Called BLAS with invalid dimensions"); 1045 } 1046 } else if (A != null && B != null) { 1047 // A and B only 1048 if (aN != bM) { 1049 throw new RSRuntimeException("Called BLAS with invalid dimensions"); 1050 } 1051 } 1052 1053 } 1054 1055 public void SGEMM(@Transpose int TransA, @Transpose int TransB, float alpha, Allocation A, 1056 Allocation B, float beta, Allocation C) { 1057 validateTranspose(TransA); 1058 validateTranspose(TransB); 1059 validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C); 1060 1061 int M = -1, N = -1, K = -1; 1062 if (TransA != NO_TRANSPOSE) { 1063 M = A.getType().getX(); 1064 K = A.getType().getY(); 1065 } else { 1066 M = A.getType().getY(); 1067 K = A.getType().getX(); 1068 } 1069 if (TransB != NO_TRANSPOSE) { 1070 N = B.getType().getY(); 1071 } else { 1072 N = B.getType().getX(); 1073 } 1074 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS), 1075 beta, C.getID(mRS), 0, 0, 0, 0); 1076 } 1077 public void DGEMM(@Transpose int TransA, @Transpose int TransB, double alpha, Allocation A, 1078 Allocation B, double beta, Allocation C) { 1079 validateTranspose(TransA); 1080 validateTranspose(TransB); 1081 validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C); 1082 int M = -1, N = -1, K = -1; 1083 if (TransA != NO_TRANSPOSE) { 1084 M = A.getType().getX(); 1085 K = A.getType().getY(); 1086 } else { 1087 M = A.getType().getY(); 1088 K = A.getType().getX(); 1089 } 1090 if (TransB != NO_TRANSPOSE) { 1091 N = B.getType().getY(); 1092 } else { 1093 N = B.getType().getX(); 1094 } 1095 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS), 1096 beta, C.getID(mRS), 0, 0, 0, 0); 1097 } 1098 public void CGEMM(@Transpose int TransA, @Transpose int TransB, Float2 alpha, Allocation A, 1099 Allocation B, Float2 beta, Allocation C) { 1100 validateTranspose(TransA); 1101 validateTranspose(TransB); 1102 validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C); 1103 int M = -1, N = -1, K = -1; 1104 if (TransA != NO_TRANSPOSE) { 1105 M = A.getType().getX(); 1106 K = A.getType().getY(); 1107 } else { 1108 M = A.getType().getY(); 1109 K = A.getType().getX(); 1110 } 1111 if (TransB != NO_TRANSPOSE) { 1112 N = B.getType().getY(); 1113 } else { 1114 N = B.getType().getX(); 1115 } 1116 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 1117 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); 1118 } 1119 1120 public void ZGEMM(@Transpose int TransA, @Transpose int TransB, Double2 alpha, Allocation A, 1121 Allocation B, Double2 beta, Allocation C) { 1122 validateTranspose(TransA); 1123 validateTranspose(TransB); 1124 validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C); 1125 int M = -1, N = -1, K = -1; 1126 if (TransA != NO_TRANSPOSE) { 1127 M = A.getType().getX(); 1128 K = A.getType().getY(); 1129 } else { 1130 M = A.getType().getY(); 1131 K = A.getType().getX(); 1132 } 1133 if (TransB != NO_TRANSPOSE) { 1134 N = B.getType().getY(); 1135 } else { 1136 N = B.getType().getX(); 1137 } 1138 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 1139 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); 1140 } 1141 1142 public void SSYMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A, 1143 Allocation B, float beta, Allocation C) { 1144 validateSide(Side); 1145 validateUplo(Uplo); 1146 //For SYMM, Matrix A should be symmetric 1147 if (A.getType().getX() != A.getType().getY()) { 1148 throw new RSRuntimeException("Matrix A is not symmetric"); 1149 } 1150 validateL3(Element.F32(mRS), 0, 0, Side, A, B, C); 1151 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), 1152 beta, C.getID(mRS), 0, 0, 0, 0); 1153 } 1154 public void DSYMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A, 1155 Allocation B, double beta, Allocation C) { 1156 validateSide(Side); 1157 validateUplo(Uplo); 1158 if (A.getType().getX() != A.getType().getY()) { 1159 throw new RSRuntimeException("Matrix A is not symmetric"); 1160 } 1161 validateL3(Element.F64(mRS), 0, 0, Side, A, B, C); 1162 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), 1163 beta, C.getID(mRS), 0, 0, 0, 0); 1164 } 1165 public void CSYMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A, 1166 Allocation B, Float2 beta, Allocation C) { 1167 validateSide(Side); 1168 validateUplo(Uplo); 1169 if (A.getType().getX() != A.getType().getY()) { 1170 throw new RSRuntimeException("Matrix A is not symmetric"); 1171 } 1172 validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C); 1173 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 1174 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); 1175 } 1176 public void ZSYMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, 1177 Allocation B, Double2 beta, Allocation C) { 1178 validateSide(Side); 1179 validateUplo(Uplo); 1180 if (A.getType().getX() != A.getType().getY()) { 1181 throw new RSRuntimeException("Matrix A is not symmetric"); 1182 } 1183 validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C); 1184 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 1185 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); 1186 } 1187 1188 public void SSYRK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) { 1189 validateTranspose(Trans); 1190 validateUplo(Uplo); 1191 validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C); 1192 int K = -1; 1193 if (Trans != NO_TRANSPOSE) { 1194 K = A.getType().getY(); 1195 } else { 1196 K = A.getType().getX(); 1197 } 1198 1199 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0); 1200 } 1201 1202 public void DSYRK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) { 1203 validateTranspose(Trans); 1204 validateUplo(Uplo); 1205 validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C); 1206 int K = -1; 1207 if (Trans != NO_TRANSPOSE) { 1208 K = A.getType().getY(); 1209 } else { 1210 K = A.getType().getX(); 1211 } 1212 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0); 1213 } 1214 public void CSYRK(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Float2 beta, Allocation C) { 1215 validateTranspose(Trans); 1216 validateUplo(Uplo); 1217 validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C); 1218 int K = -1; 1219 if (Trans != NO_TRANSPOSE) { 1220 K = A.getType().getY(); 1221 } else { 1222 K = A.getType().getX(); 1223 } 1224 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), 0, beta.x, beta.y, 1225 C.getID(mRS), 0, 0, 0, 0); 1226 } 1227 public void ZSYRK(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Double2 beta, Allocation C) { 1228 validateTranspose(Trans); 1229 validateUplo(Uplo); 1230 validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C); 1231 int K = -1; 1232 if (Trans != NO_TRANSPOSE) { 1233 K = A.getType().getY(); 1234 } else { 1235 K = A.getType().getX(); 1236 } 1237 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), 0, beta.x, beta.y, 1238 C.getID(mRS), 0, 0, 0, 0); 1239 } 1240 1241 static void validateSYR2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) { 1242 validateTranspose(Trans); 1243 if (!A.getType().getElement().isCompatible(e) || 1244 !B.getType().getElement().isCompatible(e) || 1245 !C.getType().getElement().isCompatible(e)) { 1246 throw new RSRuntimeException("Called BLAS with wrong Element type"); 1247 } 1248 int Cdim = -1; 1249 // A is n x k if no transpose, k x n if transpose 1250 // C is n x n 1251 if (Trans == TRANSPOSE) { 1252 // check columns versus C 1253 Cdim = A.getType().getX(); 1254 } else { 1255 // check rows versus C 1256 Cdim = A.getType().getY(); 1257 } 1258 if (C.getType().getX() != Cdim || C.getType().getY() != Cdim) { 1259 throw new RSRuntimeException("Invalid symmetric matrix in SYR2K"); 1260 } 1261 // A dims == B dims 1262 if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) { 1263 throw new RSRuntimeException("Invalid A and B in SYR2K"); 1264 } 1265 } 1266 public void SSYR2K(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, Allocation B, float beta, Allocation C) { 1267 validateUplo(Uplo); 1268 validateSYR2K(Element.F32(mRS), Trans, A, B, C); 1269 int K = -1; 1270 if (Trans != NO_TRANSPOSE) { 1271 K = A.getType().getY(); 1272 } else { 1273 K = A.getType().getX(); 1274 } 1275 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); 1276 } 1277 public void DSYR2K(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, Allocation B, double beta, Allocation C) { 1278 validateUplo(Uplo); 1279 validateSYR2K(Element.F64(mRS), Trans, A, B, C); 1280 int K = -1; 1281 if (Trans != NO_TRANSPOSE) { 1282 K = A.getType().getY(); 1283 } else { 1284 K = A.getType().getX(); 1285 } 1286 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); 1287 } 1288 public void CSYR2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) { 1289 validateUplo(Uplo); 1290 validateSYR2K(Element.F32_2(mRS), Trans, A, B, C); 1291 int K = -1; 1292 if (Trans != NO_TRANSPOSE) { 1293 K = A.getType().getY(); 1294 } else { 1295 K = A.getType().getX(); 1296 } 1297 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); 1298 } 1299 public void ZSYR2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) { 1300 validateUplo(Uplo); 1301 validateSYR2K(Element.F64_2(mRS), Trans, A, B, C); 1302 int K = -1; 1303 if (Trans != NO_TRANSPOSE) { 1304 K = A.getType().getY(); 1305 } else { 1306 K = A.getType().getX(); 1307 } 1308 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); 1309 } 1310 1311 static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { 1312 validateSide(Side); 1313 validateTranspose(TransA); 1314 int aM = -1, aN = -1, bM = -1, bN = -1; 1315 if (!A.getType().getElement().isCompatible(e) || 1316 !B.getType().getElement().isCompatible(e)) { 1317 throw new RSRuntimeException("Called BLAS with wrong Element type"); 1318 } 1319 1320 aM = A.getType().getY(); 1321 aN = A.getType().getX(); 1322 if (aM != aN) { 1323 throw new RSRuntimeException("Called TRMM with a non-symmetric matrix A"); 1324 } 1325 1326 bM = B.getType().getY(); 1327 bN = B.getType().getX(); 1328 if (Side == LEFT) { 1329 if (aN != bM) { 1330 throw new RSRuntimeException("Called TRMM with invalid matrices"); 1331 } 1332 } else { 1333 if (bN != aM) { 1334 throw new RSRuntimeException("Called TRMM with invalid matrices"); 1335 } 1336 } 1337 } 1338 public void STRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) { 1339 validateUplo(Uplo); 1340 validateDiag(Diag); 1341 validateTRMM(Element.F32(mRS), Side, TransA, A, B); 1342 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, 1343 alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0); 1344 } 1345 public void DTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) { 1346 validateUplo(Uplo); 1347 validateDiag(Diag); 1348 validateTRMM(Element.F64(mRS), Side, TransA, A, B); 1349 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, 1350 alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0); 1351 } 1352 public void CTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) { 1353 validateUplo(Uplo); 1354 validateDiag(Diag); 1355 validateTRMM(Element.F32_2(mRS), Side, TransA, A, B); 1356 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, 1357 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); 1358 } 1359 public void ZTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) { 1360 validateUplo(Uplo); 1361 validateDiag(Diag); 1362 validateTRMM(Element.F64_2(mRS), Side, TransA, A, B); 1363 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, 1364 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); 1365 } 1366 1367 static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { 1368 int adim = -1, bM = -1, bN = -1; 1369 validateSide(Side); 1370 validateTranspose(TransA); 1371 if (!A.getType().getElement().isCompatible(e) || 1372 !B.getType().getElement().isCompatible(e)) { 1373 throw new RSRuntimeException("Called BLAS with wrong Element type"); 1374 } 1375 adim = A.getType().getX(); 1376 if (adim != A.getType().getY()) { 1377 // this may be unnecessary, the restriction could potentially be relaxed 1378 // A needs to contain at least that symmetric matrix but could theoretically be larger 1379 // for now we assume adapters are sufficient, will reevaluate in the future 1380 throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A"); 1381 } 1382 bM = B.getType().getY(); 1383 bN = B.getType().getX(); 1384 if (Side == LEFT) { 1385 // A is M*M 1386 if (adim != bM) { 1387 throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); 1388 } 1389 } else { 1390 // A is N*N 1391 if (adim != bN) { 1392 throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); 1393 } 1394 } 1395 } 1396 public void STRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) { 1397 validateUplo(Uplo); 1398 validateDiag(Diag); 1399 validateTRSM(Element.F32(mRS), Side, TransA, A, B); 1400 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, 1401 alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0); 1402 } 1403 public void DTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) { 1404 validateUplo(Uplo); 1405 validateDiag(Diag); 1406 validateTRSM(Element.F64(mRS), Side, TransA, A, B); 1407 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, 1408 alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0); 1409 } 1410 public void CTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) { 1411 validateUplo(Uplo); 1412 validateDiag(Diag); 1413 validateTRSM(Element.F32_2(mRS), Side, TransA, A, B); 1414 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, 1415 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); 1416 } 1417 public void ZTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) { 1418 validateUplo(Uplo); 1419 validateDiag(Diag); 1420 validateTRSM(Element.F64_2(mRS), Side, TransA, A, B); 1421 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, 1422 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); 1423 } 1424 1425 static void validateHEMM(Element e, @Side int Side, Allocation A, Allocation B, Allocation C) { 1426 validateSide(Side); 1427 1428 if (!A.getType().getElement().isCompatible(e) || 1429 !B.getType().getElement().isCompatible(e) || 1430 !C.getType().getElement().isCompatible(e)) { 1431 throw new RSRuntimeException("Called BLAS with wrong Element type"); 1432 } 1433 1434 // A must be square; can potentially be relaxed similar to TRSM 1435 int adim = A.getType().getX(); 1436 if (adim != A.getType().getY()) { 1437 throw new RSRuntimeException("Called HEMM with non-square A"); 1438 } 1439 if ((Side == LEFT && adim != B.getType().getY()) || 1440 (Side == RIGHT && adim != B.getType().getX())) { 1441 throw new RSRuntimeException("Called HEMM with invalid B"); 1442 } 1443 if (B.getType().getX() != C.getType().getX() || 1444 B.getType().getY() != C.getType().getY()) { 1445 throw new RSRuntimeException("Called HEMM with mismatched B and C"); 1446 } 1447 } 1448 public void CHEMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) { 1449 validateUplo(Uplo); 1450 validateHEMM(Element.F32_2(mRS), Side, A, B, C); 1451 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, 1452 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); 1453 } 1454 public void ZHEMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) { 1455 validateUplo(Uplo); 1456 validateHEMM(Element.F64_2(mRS), Side, A, B, C); 1457 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, 1458 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); 1459 } 1460 1461 static void validateHERK(Element e, @Transpose int Trans, Allocation A, Allocation C) { 1462 if (!A.getType().getElement().isCompatible(e) || 1463 !C.getType().getElement().isCompatible(e)) { 1464 throw new RSRuntimeException("Called BLAS with wrong Element type"); 1465 } 1466 validateConjTranspose(Trans); 1467 int cdim = C.getType().getX(); 1468 if (cdim != C.getType().getY()) { 1469 throw new RSRuntimeException("Called HERK with non-square C"); 1470 } 1471 if (Trans == NO_TRANSPOSE) { 1472 if (cdim != A.getType().getY()) { 1473 throw new RSRuntimeException("Called HERK with invalid A"); 1474 } 1475 } else { 1476 if (cdim != A.getType().getX()) { 1477 throw new RSRuntimeException("Called HERK with invalid A"); 1478 } 1479 } 1480 } 1481 public void CHERK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) { 1482 validateUplo(Uplo); 1483 validateHERK(Element.F32_2(mRS), Trans, A, C); 1484 int k = 0; 1485 if (Trans == CONJ_TRANSPOSE) { 1486 k = A.getType().getY(); 1487 } else { 1488 k = A.getType().getX(); 1489 } 1490 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, 1491 alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0); 1492 } 1493 public void ZHERK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) { 1494 validateUplo(Uplo); 1495 validateHERK(Element.F64_2(mRS), Trans, A, C); 1496 int k = 0; 1497 if (Trans == CONJ_TRANSPOSE) { 1498 k = A.getType().getY(); 1499 } else { 1500 k = A.getType().getX(); 1501 } 1502 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, 1503 alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0); 1504 } 1505 1506 static void validateHER2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) { 1507 if (!A.getType().getElement().isCompatible(e) || 1508 !B.getType().getElement().isCompatible(e) || 1509 !C.getType().getElement().isCompatible(e)) { 1510 throw new RSRuntimeException("Called BLAS with wrong Element type"); 1511 } 1512 validateConjTranspose(Trans); 1513 int cdim = C.getType().getX(); 1514 if (cdim != C.getType().getY()) { 1515 throw new RSRuntimeException("Called HER2K with non-square C"); 1516 } 1517 if (Trans == NO_TRANSPOSE) { 1518 if (A.getType().getY() != cdim) { 1519 throw new RSRuntimeException("Called HER2K with invalid matrices"); 1520 } 1521 } else { 1522 if (A.getType().getX() != cdim) { 1523 throw new RSRuntimeException("Called HER2K with invalid matrices"); 1524 } 1525 } 1526 if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) { 1527 throw new RSRuntimeException("Called HER2K with invalid A and B matrices"); 1528 } 1529 } 1530 public void CHER2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, float beta, Allocation C) { 1531 validateUplo(Uplo); 1532 validateHER2K(Element.F32_2(mRS), Trans, A, B, C); 1533 int k = 0; 1534 if (Trans == NO_TRANSPOSE) { 1535 k = A.getType().getX(); 1536 } else { 1537 k = A.getType().getY(); 1538 } 1539 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y, 1540 A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); 1541 } 1542 public void ZHER2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, double beta, Allocation C) { 1543 validateUplo(Uplo); 1544 validateHER2K(Element.F64_2(mRS), Trans, A, B, C); 1545 int k = 0; 1546 if (Trans == NO_TRANSPOSE) { 1547 k = A.getType().getX(); 1548 } else { 1549 k = A.getType().getY(); 1550 } 1551 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y, 1552 A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); 1553 } 1554 1555 1556 /** 1557 * 8-bit GEMM-like operation for neural networks: C = B.transposed() * A 1558 * Calculations are done in 1.10.21 fixed-point format for the final output, 1559 * just before there's a shift down to drop the fractional parts. The output 1560 * values are gated to 0 to 255 to fit in a byte, but the 10-bit format 1561 * gives some headroom to avoid wrapping around on small overflows. 1562 * 1563 * @param A The input allocation contains matrix A, supported elements type {@link Element#U8}. 1564 * @param a_offset The offset for all values in matrix A, e.g A[i,j] = A[i,j] - a_offset. Value should be from 0 to 255. 1565 * @param B The input allocation contains matrix B, supported elements type {@link Element#U8}. 1566 * @param b_offset The offset for all values in matrix B, e.g B[i,j] = B[i,j] - b_offset. Value should be from 0 to 255. 1567 * @param C The input allocation contains matrix C, supported elements type {@link Element#U8}. 1568 * @param c_offset The offset for all values in matrix C. 1569 * @param c_mult The multiplier for all values in matrix C, e.g C[i,j] = (C[i,j] + c_offset) * c_mult. 1570 **/ 1571 public void BNNM(Allocation A, int a_offset, Allocation B, int b_offset, Allocation C, int c_offset, int c_mult) { 1572 validateL3(Element.U8(mRS), NO_TRANSPOSE, TRANSPOSE, 0, A, B, C); 1573 1574 if (a_offset < 0 || a_offset > 255) { 1575 throw new RSRuntimeException("Invalid a_offset passed to BNNM"); 1576 } 1577 if (b_offset < 0 || b_offset > 255) { 1578 throw new RSRuntimeException("Invalid b_offset passed to BNNM"); 1579 } 1580 int M = -1, N = -1, K = -1; 1581 M = A.getType().getY(); 1582 N = B.getType().getY(); 1583 K = A.getType().getX(); 1584 1585 1586 mRS.nScriptIntrinsicBLAS_BNNM(getID(mRS), M, N, K, A.getID(mRS), a_offset, B.getID(mRS), b_offset, C.getID(mRS), c_offset, c_mult); 1587 1588 } 1589 1590} 1591