Loading rs/java/android/renderscript/ScriptIntrinsicBLAS.java +88 −57 Original line number Diff line number Diff line Loading @@ -242,7 +242,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } static void validateUplo(@Uplo int Uplo) { if (Uplo != LEFT && Uplo != RIGHT) { if (Uplo != UPPER && Uplo != LOWER) { throw new RSRuntimeException("Invalid uplo passed to BLAS"); } } Loading Loading @@ -986,56 +986,74 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { */ static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) { int aX = -1, aY = -1, bX = -1, bY = -1, cX = -1, cY = -1; int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1; if ((A != null && !A.getType().getElement().isCompatible(e)) || (B != null && !B.getType().getElement().isCompatible(e)) || (C != null && !C.getType().getElement().isCompatible(e))) { throw new RSRuntimeException("Called BLAS with wrong Element type"); } if (C != null) { cX = C.getType().getY(); cY = C.getType().getX(); if (C == null) { //since matrix C is used to store the result, it cannot be null. throw new RSRuntimeException("Allocation C cannot be null"); } cM = C.getType().getY(); cN = C.getType().getX(); if (Side == RIGHT) { if ((A == null && B != null) || (A != null && B == null)) { throw new RSRuntimeException("Provided Matrix A without Matrix B, or vice versa"); } if (B != null) { bX = A.getType().getY(); bY = A.getType().getX(); bM = A.getType().getY(); bN = A.getType().getX(); } if (A != null) { aX = B.getType().getY(); aY = B.getType().getX(); aM = B.getType().getY(); aN = B.getType().getX(); } } else { if (A != null) { if (TransA == TRANSPOSE) { aY = A.getType().getY(); aX = A.getType().getX(); if (TransA != NO_TRANSPOSE) { aN = A.getType().getY(); aM = A.getType().getX(); } else { aX = A.getType().getY(); aY = A.getType().getX(); aM = A.getType().getY(); aN = A.getType().getX(); } } if (B != null) { if (TransB == TRANSPOSE) { bY = B.getType().getY(); bX = B.getType().getX(); if (TransB != NO_TRANSPOSE) { bN = B.getType().getY(); bM = B.getType().getX(); } else { bX = B.getType().getY(); bY = B.getType().getX(); bM = B.getType().getY(); bN = B.getType().getX(); } } } if (A != null && B != null && C != null) { if (aY != bX || aX != cX || bY != cY) { if (aN != bM || aM != cM || bN != cN) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } else if (A != null && C != null) { // A and C only if (aX != cY || aY != cX) { // A and C only, for SYRK if (cM != cN) { throw new RSRuntimeException("Matrix C is not symmetric"); } if (TransA != NO_TRANSPOSE) { if (aN != cM) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } else { if (aM != cM) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } } else if (A != null && B != null) { // A and B only if (aN != bM) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } } Loading @@ -1047,14 +1065,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; if (TransA == TRANSPOSE) { if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); K = A.getType().getY(); } else { M = A.getType().getY(); K = A.getType().getX(); } if (TransB == TRANSPOSE) { if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); } else { N = B.getType().getX(); Loading @@ -1068,14 +1086,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateTranspose(TransB); validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; if (TransA == TRANSPOSE) { if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); K = A.getType().getY(); } else { M = A.getType().getY(); K = A.getType().getX(); } if (TransB == TRANSPOSE) { if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); } else { N = B.getType().getX(); Loading @@ -1089,14 +1107,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateTranspose(TransB); validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; if (TransA == TRANSPOSE) { if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); K = A.getType().getY(); } else { M = A.getType().getY(); K = A.getType().getX(); } if (TransB == TRANSPOSE) { if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); } else { N = B.getType().getX(); Loading @@ -1111,14 +1129,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateTranspose(TransB); validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; if (TransA == TRANSPOSE) { if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); K = A.getType().getY(); } else { M = A.getType().getY(); K = A.getType().getX(); } if (TransB == TRANSPOSE) { if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); } else { N = B.getType().getX(); Loading @@ -1131,6 +1149,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, float beta, Allocation C) { validateSide(Side); validateUplo(Uplo); //For SYMM, Matrix A should be symmetric if (A.getType().getX() != A.getType().getY()) { throw new RSRuntimeException("Matrix A is not symmetric"); } validateL3(Element.F32(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); Loading @@ -1139,6 +1161,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, double beta, Allocation C) { validateSide(Side); validateUplo(Uplo); if (A.getType().getX() != A.getType().getY()) { throw new RSRuntimeException("Matrix A is not symmetric"); } validateL3(Element.F64(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); Loading @@ -1147,6 +1172,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, Float2 beta, Allocation C) { validateSide(Side); validateUplo(Uplo); if (A.getType().getX() != A.getType().getY()) { throw new RSRuntimeException("Matrix A is not symmetric"); } validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); Loading @@ -1155,6 +1183,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, Double2 beta, Allocation C) { validateSide(Side); validateUplo(Uplo); if (A.getType().getX() != A.getType().getY()) { throw new RSRuntimeException("Matrix A is not symmetric"); } validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); Loading @@ -1165,7 +1196,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C); int K = -1; if (Trans == TRANSPOSE) { if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); Loading @@ -1179,7 +1210,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C); int K = -1; if (Trans == TRANSPOSE) { if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); Loading @@ -1191,7 +1222,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C); int K = -1; if (Trans == TRANSPOSE) { if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); Loading @@ -1204,7 +1235,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C); int K = -1; if (Trans == TRANSPOSE) { if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); Loading @@ -1230,7 +1261,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { // check rows versus C Cdim = A.getType().getY(); } if (C.getType().getX() != Cdim && C.getType().getY() != Cdim) { if (C.getType().getX() != Cdim || C.getType().getY() != Cdim) { throw new RSRuntimeException("Invalid symmetric matrix in SYR2K"); } // A dims == B dims Loading Loading @@ -1286,26 +1317,26 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { validateSide(Side); validateTranspose(TransA); int aX = -1, aY = -1, bX = -1, bY = -1; int aM = -1, aN = -1, bM = -1, bN = -1; if (!A.getType().getElement().isCompatible(e) || !B.getType().getElement().isCompatible(e)) { throw new RSRuntimeException("Called BLAS with wrong Element type"); } if (TransA == TRANSPOSE) { aY = A.getType().getY(); aX = A.getType().getX(); } else { aY = A.getType().getX(); aX = A.getType().getY(); aM = A.getType().getY(); aN = A.getType().getX(); if (aM != aN) { throw new RSRuntimeException("Called TRMM with a non-symmetric matrix A"); } bX = B.getType().getY(); bY = B.getType().getX(); bM = B.getType().getY(); bN = B.getType().getX(); if (Side == LEFT) { if (aX == 0 || aY != bX) { if (aN != bM) { throw new RSRuntimeException("Called TRMM with invalid matrices"); } } else { if (bY != aX || aY == 0) { if (bN != aM) { throw new RSRuntimeException("Called TRMM with invalid matrices"); } } Loading Loading @@ -1340,7 +1371,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { int adim = -1, bX = -1, bY = -1; int adim = -1, bM = -1, bN = -1; validateSide(Side); validateTranspose(TransA); if (!A.getType().getElement().isCompatible(e) || Loading @@ -1354,16 +1385,16 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { // for now we assume adapters are sufficient, will reevaluate in the future throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A"); } bX = B.getType().getY(); bY = B.getType().getX(); bM = B.getType().getY(); bN = B.getType().getX(); if (Side == LEFT) { // A is M*M if (adim != bY) { if (adim != bM) { throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); } } else { // A is N*N if (adim != bX) { if (adim != bN) { throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); } } Loading Loading @@ -1428,7 +1459,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } public void ZHEMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) { validateUplo(Uplo); validateHEMM(Element.F32_2(mRS), Side, A, B, C); validateHEMM(Element.F64_2(mRS), Side, A, B, C); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); } Loading @@ -1444,11 +1475,11 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { throw new RSRuntimeException("Called HERK with non-square C"); } if (Trans == NO_TRANSPOSE) { if (cdim != A.getType().getX()) { if (cdim != A.getType().getY()) { throw new RSRuntimeException("Called HERK with invalid A"); } } else { if (cdim != A.getType().getY()) { if (cdim != A.getType().getX()) { throw new RSRuntimeException("Called HERK with invalid A"); } } Loading @@ -1457,7 +1488,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateHERK(Element.F32_2(mRS), Trans, A, C); int k = 0; if (Trans == TRANSPOSE) { if (Trans == CONJ_TRANSPOSE) { k = A.getType().getY(); } else { k = A.getType().getX(); Loading @@ -1469,7 +1500,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateHERK(Element.F64_2(mRS), Trans, A, C); int k = 0; if (Trans == TRANSPOSE) { if (Trans == CONJ_TRANSPOSE) { k = A.getType().getY(); } else { k = A.getType().getX(); Loading Loading
rs/java/android/renderscript/ScriptIntrinsicBLAS.java +88 −57 Original line number Diff line number Diff line Loading @@ -242,7 +242,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } static void validateUplo(@Uplo int Uplo) { if (Uplo != LEFT && Uplo != RIGHT) { if (Uplo != UPPER && Uplo != LOWER) { throw new RSRuntimeException("Invalid uplo passed to BLAS"); } } Loading Loading @@ -986,56 +986,74 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { */ static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) { int aX = -1, aY = -1, bX = -1, bY = -1, cX = -1, cY = -1; int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1; if ((A != null && !A.getType().getElement().isCompatible(e)) || (B != null && !B.getType().getElement().isCompatible(e)) || (C != null && !C.getType().getElement().isCompatible(e))) { throw new RSRuntimeException("Called BLAS with wrong Element type"); } if (C != null) { cX = C.getType().getY(); cY = C.getType().getX(); if (C == null) { //since matrix C is used to store the result, it cannot be null. throw new RSRuntimeException("Allocation C cannot be null"); } cM = C.getType().getY(); cN = C.getType().getX(); if (Side == RIGHT) { if ((A == null && B != null) || (A != null && B == null)) { throw new RSRuntimeException("Provided Matrix A without Matrix B, or vice versa"); } if (B != null) { bX = A.getType().getY(); bY = A.getType().getX(); bM = A.getType().getY(); bN = A.getType().getX(); } if (A != null) { aX = B.getType().getY(); aY = B.getType().getX(); aM = B.getType().getY(); aN = B.getType().getX(); } } else { if (A != null) { if (TransA == TRANSPOSE) { aY = A.getType().getY(); aX = A.getType().getX(); if (TransA != NO_TRANSPOSE) { aN = A.getType().getY(); aM = A.getType().getX(); } else { aX = A.getType().getY(); aY = A.getType().getX(); aM = A.getType().getY(); aN = A.getType().getX(); } } if (B != null) { if (TransB == TRANSPOSE) { bY = B.getType().getY(); bX = B.getType().getX(); if (TransB != NO_TRANSPOSE) { bN = B.getType().getY(); bM = B.getType().getX(); } else { bX = B.getType().getY(); bY = B.getType().getX(); bM = B.getType().getY(); bN = B.getType().getX(); } } } if (A != null && B != null && C != null) { if (aY != bX || aX != cX || bY != cY) { if (aN != bM || aM != cM || bN != cN) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } else if (A != null && C != null) { // A and C only if (aX != cY || aY != cX) { // A and C only, for SYRK if (cM != cN) { throw new RSRuntimeException("Matrix C is not symmetric"); } if (TransA != NO_TRANSPOSE) { if (aN != cM) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } else { if (aM != cM) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } } else if (A != null && B != null) { // A and B only if (aN != bM) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } } Loading @@ -1047,14 +1065,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; if (TransA == TRANSPOSE) { if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); K = A.getType().getY(); } else { M = A.getType().getY(); K = A.getType().getX(); } if (TransB == TRANSPOSE) { if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); } else { N = B.getType().getX(); Loading @@ -1068,14 +1086,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateTranspose(TransB); validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; if (TransA == TRANSPOSE) { if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); K = A.getType().getY(); } else { M = A.getType().getY(); K = A.getType().getX(); } if (TransB == TRANSPOSE) { if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); } else { N = B.getType().getX(); Loading @@ -1089,14 +1107,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateTranspose(TransB); validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; if (TransA == TRANSPOSE) { if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); K = A.getType().getY(); } else { M = A.getType().getY(); K = A.getType().getX(); } if (TransB == TRANSPOSE) { if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); } else { N = B.getType().getX(); Loading @@ -1111,14 +1129,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateTranspose(TransB); validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; if (TransA == TRANSPOSE) { if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); K = A.getType().getY(); } else { M = A.getType().getY(); K = A.getType().getX(); } if (TransB == TRANSPOSE) { if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); } else { N = B.getType().getX(); Loading @@ -1131,6 +1149,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, float beta, Allocation C) { validateSide(Side); validateUplo(Uplo); //For SYMM, Matrix A should be symmetric if (A.getType().getX() != A.getType().getY()) { throw new RSRuntimeException("Matrix A is not symmetric"); } validateL3(Element.F32(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); Loading @@ -1139,6 +1161,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, double beta, Allocation C) { validateSide(Side); validateUplo(Uplo); if (A.getType().getX() != A.getType().getY()) { throw new RSRuntimeException("Matrix A is not symmetric"); } validateL3(Element.F64(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); Loading @@ -1147,6 +1172,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, Float2 beta, Allocation C) { validateSide(Side); validateUplo(Uplo); if (A.getType().getX() != A.getType().getY()) { throw new RSRuntimeException("Matrix A is not symmetric"); } validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); Loading @@ -1155,6 +1183,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, Double2 beta, Allocation C) { validateSide(Side); validateUplo(Uplo); if (A.getType().getX() != A.getType().getY()) { throw new RSRuntimeException("Matrix A is not symmetric"); } validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); Loading @@ -1165,7 +1196,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C); int K = -1; if (Trans == TRANSPOSE) { if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); Loading @@ -1179,7 +1210,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C); int K = -1; if (Trans == TRANSPOSE) { if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); Loading @@ -1191,7 +1222,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C); int K = -1; if (Trans == TRANSPOSE) { if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); Loading @@ -1204,7 +1235,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C); int K = -1; if (Trans == TRANSPOSE) { if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); } else { K = A.getType().getX(); Loading @@ -1230,7 +1261,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { // check rows versus C Cdim = A.getType().getY(); } if (C.getType().getX() != Cdim && C.getType().getY() != Cdim) { if (C.getType().getX() != Cdim || C.getType().getY() != Cdim) { throw new RSRuntimeException("Invalid symmetric matrix in SYR2K"); } // A dims == B dims Loading Loading @@ -1286,26 +1317,26 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { validateSide(Side); validateTranspose(TransA); int aX = -1, aY = -1, bX = -1, bY = -1; int aM = -1, aN = -1, bM = -1, bN = -1; if (!A.getType().getElement().isCompatible(e) || !B.getType().getElement().isCompatible(e)) { throw new RSRuntimeException("Called BLAS with wrong Element type"); } if (TransA == TRANSPOSE) { aY = A.getType().getY(); aX = A.getType().getX(); } else { aY = A.getType().getX(); aX = A.getType().getY(); aM = A.getType().getY(); aN = A.getType().getX(); if (aM != aN) { throw new RSRuntimeException("Called TRMM with a non-symmetric matrix A"); } bX = B.getType().getY(); bY = B.getType().getX(); bM = B.getType().getY(); bN = B.getType().getX(); if (Side == LEFT) { if (aX == 0 || aY != bX) { if (aN != bM) { throw new RSRuntimeException("Called TRMM with invalid matrices"); } } else { if (bY != aX || aY == 0) { if (bN != aM) { throw new RSRuntimeException("Called TRMM with invalid matrices"); } } Loading Loading @@ -1340,7 +1371,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { int adim = -1, bX = -1, bY = -1; int adim = -1, bM = -1, bN = -1; validateSide(Side); validateTranspose(TransA); if (!A.getType().getElement().isCompatible(e) || Loading @@ -1354,16 +1385,16 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { // for now we assume adapters are sufficient, will reevaluate in the future throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A"); } bX = B.getType().getY(); bY = B.getType().getX(); bM = B.getType().getY(); bN = B.getType().getX(); if (Side == LEFT) { // A is M*M if (adim != bY) { if (adim != bM) { throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); } } else { // A is N*N if (adim != bX) { if (adim != bN) { throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); } } Loading Loading @@ -1428,7 +1459,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } public void ZHEMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) { validateUplo(Uplo); validateHEMM(Element.F32_2(mRS), Side, A, B, C); validateHEMM(Element.F64_2(mRS), Side, A, B, C); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); } Loading @@ -1444,11 +1475,11 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { throw new RSRuntimeException("Called HERK with non-square C"); } if (Trans == NO_TRANSPOSE) { if (cdim != A.getType().getX()) { if (cdim != A.getType().getY()) { throw new RSRuntimeException("Called HERK with invalid A"); } } else { if (cdim != A.getType().getY()) { if (cdim != A.getType().getX()) { throw new RSRuntimeException("Called HERK with invalid A"); } } Loading @@ -1457,7 +1488,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateHERK(Element.F32_2(mRS), Trans, A, C); int k = 0; if (Trans == TRANSPOSE) { if (Trans == CONJ_TRANSPOSE) { k = A.getType().getY(); } else { k = A.getType().getX(); Loading @@ -1469,7 +1500,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateHERK(Element.F64_2(mRS), Trans, A, C); int k = 0; if (Trans == TRANSPOSE) { if (Trans == CONJ_TRANSPOSE) { k = A.getType().getY(); } else { k = A.getType().getX(); Loading