Loading rs/java/android/renderscript/ScriptIntrinsicBLAS.java +89 −58 Original line number Original line Diff line number Diff line Loading @@ -241,7 +241,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } } static void validateUplo(@Uplo int Uplo) { static void validateUplo(@Uplo int Uplo) { if (Uplo != LEFT && Uplo != RIGHT) { if (Uplo != UPPER && Uplo != LOWER) { throw new RSRuntimeException("Invalid uplo passed to BLAS"); throw new RSRuntimeException("Invalid uplo passed to BLAS"); } } } } Loading Loading @@ -960,7 +960,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } } public void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { public void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { // same as SYR // same as SYR int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A); int N = validateSYR(Element.F64_2(mRS), Uplo, X, incX, A); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0); } } public void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) { public void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) { Loading @@ -985,56 +985,74 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { */ */ static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) { static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) { int aX = -1, aY = -1, bX = -1, bY = -1, cX = -1, cY = -1; int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1; if ((A != null && !A.getType().getElement().isCompatible(e)) || if ((A != null && !A.getType().getElement().isCompatible(e)) || (B != null && !B.getType().getElement().isCompatible(e)) || (B != null && !B.getType().getElement().isCompatible(e)) || (C != null && !C.getType().getElement().isCompatible(e))) { (C != null && !C.getType().getElement().isCompatible(e))) { throw new RSRuntimeException("Called BLAS with wrong Element type"); throw new RSRuntimeException("Called BLAS with wrong Element type"); } } if (C != null) { if (C == null) { cX = C.getType().getY(); //since matrix C is used to store the result, it cannot be null. cY = C.getType().getX(); throw new RSRuntimeException("Allocation C cannot be null"); } } cM = C.getType().getY(); cN = C.getType().getX(); if (Side == RIGHT) { if (Side == RIGHT) { if ((A == null && B != null) || (A != null && B == null)) { throw new RSRuntimeException("Provided Matrix A without Matrix B, or vice versa"); } if (B != null) { if (B != null) { bX = A.getType().getY(); bM = A.getType().getY(); bY = A.getType().getX(); bN = A.getType().getX(); } } if (A != null) { if (A != null) { aX = B.getType().getY(); aM = B.getType().getY(); aY = B.getType().getX(); aN = B.getType().getX(); } } } else { } else { if (A != null) { if (A != null) { if (TransA == TRANSPOSE) { if (TransA != NO_TRANSPOSE) { aY = A.getType().getY(); aN = A.getType().getY(); aX = A.getType().getX(); aM = A.getType().getX(); } else { } else { aX = A.getType().getY(); aM = A.getType().getY(); aY = A.getType().getX(); aN = A.getType().getX(); } } } } if (B != null) { if (B != null) { if (TransB == TRANSPOSE) { if (TransB != NO_TRANSPOSE) { bY = B.getType().getY(); bN = B.getType().getY(); bX = B.getType().getX(); bM = B.getType().getX(); } else { } else { bX = B.getType().getY(); bM = B.getType().getY(); bY = B.getType().getX(); bN = B.getType().getX(); } } } } } } if (A != null && B != null && C != null) { if (A != null && B != null && C != null) { if (aY != bX || aX != cX || bY != cY) { if (aN != bM || aM != cM || bN != cN) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } } else if (A != null && C != null) { } else if (A != null && C != null) { // A and C only // A and C only, for SYRK if (aX != cY || aY != cX) { if (cM != cN) { throw new RSRuntimeException("Matrix C is not symmetric"); } if (TransA != NO_TRANSPOSE) { if (aN != cM) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } else { if (aM != cM) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } } } else if (A != null && B != null) { } else if (A != null && B != null) { // A and B only // A and B only if (aN != bM) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } } } } Loading @@ -1046,14 +1064,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C); validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; int M = -1, N = -1, K = -1; if (TransA == TRANSPOSE) { if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); M = A.getType().getX(); K = A.getType().getY(); K = A.getType().getY(); } else { } else { M = A.getType().getY(); M = A.getType().getY(); K = A.getType().getX(); K = A.getType().getX(); } } if (TransB == TRANSPOSE) { if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); N = B.getType().getY(); } else { } else { N = B.getType().getX(); N = B.getType().getX(); Loading @@ -1067,14 +1085,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateTranspose(TransB); validateTranspose(TransB); validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C); validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; int M = -1, N = -1, K = -1; if (TransA == TRANSPOSE) { if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); M = A.getType().getX(); K = A.getType().getY(); K = A.getType().getY(); } else { } else { M = A.getType().getY(); M = A.getType().getY(); K = A.getType().getX(); K = A.getType().getX(); } } if (TransB == TRANSPOSE) { if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); N = B.getType().getY(); } else { } else { N = B.getType().getX(); N = B.getType().getX(); Loading @@ -1088,14 +1106,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateTranspose(TransB); validateTranspose(TransB); validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C); validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; int M = -1, N = -1, K = -1; if (TransA == TRANSPOSE) { if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); M = A.getType().getX(); K = A.getType().getY(); K = A.getType().getY(); } else { } else { M = A.getType().getY(); M = A.getType().getY(); K = A.getType().getX(); K = A.getType().getX(); } } if (TransB == TRANSPOSE) { if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); N = B.getType().getY(); } else { } else { N = B.getType().getX(); N = B.getType().getX(); Loading @@ -1110,14 +1128,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateTranspose(TransB); validateTranspose(TransB); validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C); validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; int M = -1, N = -1, K = -1; if (TransA == TRANSPOSE) { if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); M = A.getType().getX(); K = A.getType().getY(); K = A.getType().getY(); } else { } else { M = A.getType().getY(); M = A.getType().getY(); K = A.getType().getX(); K = A.getType().getX(); } } if (TransB == TRANSPOSE) { if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); N = B.getType().getY(); } else { } else { N = B.getType().getX(); N = B.getType().getX(); Loading @@ -1130,6 +1148,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, float beta, Allocation C) { Allocation B, float beta, Allocation C) { validateSide(Side); validateSide(Side); validateUplo(Uplo); validateUplo(Uplo); //For SYMM, Matrix A should be symmetric if (A.getType().getX() != A.getType().getY()) { throw new RSRuntimeException("Matrix A is not symmetric"); } validateL3(Element.F32(mRS), 0, 0, Side, A, B, C); validateL3(Element.F32(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); beta, C.getID(mRS), 0, 0, 0, 0); Loading @@ -1138,6 +1160,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, double beta, Allocation C) { Allocation B, double beta, Allocation C) { validateSide(Side); validateSide(Side); validateUplo(Uplo); validateUplo(Uplo); if (A.getType().getX() != A.getType().getY()) { throw new RSRuntimeException("Matrix A is not symmetric"); } validateL3(Element.F64(mRS), 0, 0, Side, A, B, C); validateL3(Element.F64(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); beta, C.getID(mRS), 0, 0, 0, 0); Loading @@ -1146,6 +1171,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, Float2 beta, Allocation C) { Allocation B, Float2 beta, Allocation C) { validateSide(Side); validateSide(Side); validateUplo(Uplo); validateUplo(Uplo); if (A.getType().getX() != A.getType().getY()) { throw new RSRuntimeException("Matrix A is not symmetric"); } validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C); validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); Loading @@ -1154,6 +1182,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, Double2 beta, Allocation C) { Allocation B, Double2 beta, Allocation C) { validateSide(Side); validateSide(Side); validateUplo(Uplo); validateUplo(Uplo); if (A.getType().getX() != A.getType().getY()) { throw new RSRuntimeException("Matrix A is not symmetric"); } validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C); validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); Loading @@ -1164,7 +1195,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateUplo(Uplo); validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C); validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C); int K = -1; int K = -1; if (Trans == TRANSPOSE) { if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); K = A.getType().getY(); } else { } else { K = A.getType().getX(); K = A.getType().getX(); Loading @@ -1178,7 +1209,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateUplo(Uplo); validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C); validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C); int K = -1; int K = -1; if (Trans == TRANSPOSE) { if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); K = A.getType().getY(); } else { } else { K = A.getType().getX(); K = A.getType().getX(); Loading @@ -1190,7 +1221,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateUplo(Uplo); validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C); validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C); int K = -1; int K = -1; if (Trans == TRANSPOSE) { if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); K = A.getType().getY(); } else { } else { K = A.getType().getX(); K = A.getType().getX(); Loading @@ -1203,7 +1234,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateUplo(Uplo); validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C); validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C); int K = -1; int K = -1; if (Trans == TRANSPOSE) { if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); K = A.getType().getY(); } else { } else { K = A.getType().getX(); K = A.getType().getX(); Loading @@ -1229,7 +1260,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { // check rows versus C // check rows versus C Cdim = A.getType().getY(); Cdim = A.getType().getY(); } } if (C.getType().getX() != Cdim && C.getType().getY() != Cdim) { if (C.getType().getX() != Cdim || C.getType().getY() != Cdim) { throw new RSRuntimeException("Invalid symmetric matrix in SYR2K"); throw new RSRuntimeException("Invalid symmetric matrix in SYR2K"); } } // A dims == B dims // A dims == B dims Loading Loading @@ -1285,26 +1316,26 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { validateSide(Side); validateSide(Side); validateTranspose(TransA); validateTranspose(TransA); int aX = -1, aY = -1, bX = -1, bY = -1; int aM = -1, aN = -1, bM = -1, bN = -1; if (!A.getType().getElement().isCompatible(e) || if (!A.getType().getElement().isCompatible(e) || !B.getType().getElement().isCompatible(e)) { !B.getType().getElement().isCompatible(e)) { throw new RSRuntimeException("Called BLAS with wrong Element type"); throw new RSRuntimeException("Called BLAS with wrong Element type"); } } if (TransA == TRANSPOSE) { aY = A.getType().getY(); aM = A.getType().getY(); aX = A.getType().getX(); aN = A.getType().getX(); } else { if (aM != aN) { aY = A.getType().getX(); throw new RSRuntimeException("Called TRMM with a non-symmetric matrix A"); aX = A.getType().getY(); } } bX = B.getType().getY(); bY = B.getType().getX(); bM = B.getType().getY(); bN = B.getType().getX(); if (Side == LEFT) { if (Side == LEFT) { if (aX == 0 || aY != bX) { if (aN != bM) { throw new RSRuntimeException("Called TRMM with invalid matrices"); throw new RSRuntimeException("Called TRMM with invalid matrices"); } } } else { } else { if (bY != aX || aY == 0) { if (bN != aM) { throw new RSRuntimeException("Called TRMM with invalid matrices"); throw new RSRuntimeException("Called TRMM with invalid matrices"); } } } } Loading Loading @@ -1339,7 +1370,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } } static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { int adim = -1, bX = -1, bY = -1; int adim = -1, bM = -1, bN = -1; validateSide(Side); validateSide(Side); validateTranspose(TransA); validateTranspose(TransA); if (!A.getType().getElement().isCompatible(e) || if (!A.getType().getElement().isCompatible(e) || Loading @@ -1353,16 +1384,16 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { // for now we assume adapters are sufficient, will reevaluate in the future // for now we assume adapters are sufficient, will reevaluate in the future throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A"); throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A"); } } bX = B.getType().getY(); bM = B.getType().getY(); bY = B.getType().getX(); bN = B.getType().getX(); if (Side == LEFT) { if (Side == LEFT) { // A is M*M // A is M*M if (adim != bY) { if (adim != bM) { throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); } } } else { } else { // A is N*N // A is N*N if (adim != bX) { if (adim != bN) { throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); } } } } Loading Loading @@ -1427,7 +1458,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } } public void ZHEMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) { public void ZHEMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) { validateUplo(Uplo); validateUplo(Uplo); validateHEMM(Element.F32_2(mRS), Side, A, B, C); validateHEMM(Element.F64_2(mRS), Side, A, B, C); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); } } Loading @@ -1443,11 +1474,11 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { throw new RSRuntimeException("Called HERK with non-square C"); throw new RSRuntimeException("Called HERK with non-square C"); } } if (Trans == NO_TRANSPOSE) { if (Trans == NO_TRANSPOSE) { if (cdim != A.getType().getX()) { if (cdim != A.getType().getY()) { throw new RSRuntimeException("Called HERK with invalid A"); throw new RSRuntimeException("Called HERK with invalid A"); } } } else { } else { if (cdim != A.getType().getY()) { if (cdim != A.getType().getX()) { throw new RSRuntimeException("Called HERK with invalid A"); throw new RSRuntimeException("Called HERK with invalid A"); } } } } Loading @@ -1456,7 +1487,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateUplo(Uplo); validateHERK(Element.F32_2(mRS), Trans, A, C); validateHERK(Element.F32_2(mRS), Trans, A, C); int k = 0; int k = 0; if (Trans == TRANSPOSE) { if (Trans == CONJ_TRANSPOSE) { k = A.getType().getY(); k = A.getType().getY(); } else { } else { k = A.getType().getX(); k = A.getType().getX(); Loading @@ -1468,7 +1499,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateUplo(Uplo); validateHERK(Element.F64_2(mRS), Trans, A, C); validateHERK(Element.F64_2(mRS), Trans, A, C); int k = 0; int k = 0; if (Trans == TRANSPOSE) { if (Trans == CONJ_TRANSPOSE) { k = A.getType().getY(); k = A.getType().getY(); } else { } else { k = A.getType().getX(); k = A.getType().getX(); Loading Loading
rs/java/android/renderscript/ScriptIntrinsicBLAS.java +89 −58 Original line number Original line Diff line number Diff line Loading @@ -241,7 +241,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } } static void validateUplo(@Uplo int Uplo) { static void validateUplo(@Uplo int Uplo) { if (Uplo != LEFT && Uplo != RIGHT) { if (Uplo != UPPER && Uplo != LOWER) { throw new RSRuntimeException("Invalid uplo passed to BLAS"); throw new RSRuntimeException("Invalid uplo passed to BLAS"); } } } } Loading Loading @@ -960,7 +960,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } } public void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { public void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { // same as SYR // same as SYR int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A); int N = validateSYR(Element.F64_2(mRS), Uplo, X, incX, A); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0); } } public void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) { public void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) { Loading @@ -985,56 +985,74 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { */ */ static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) { static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) { int aX = -1, aY = -1, bX = -1, bY = -1, cX = -1, cY = -1; int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1; if ((A != null && !A.getType().getElement().isCompatible(e)) || if ((A != null && !A.getType().getElement().isCompatible(e)) || (B != null && !B.getType().getElement().isCompatible(e)) || (B != null && !B.getType().getElement().isCompatible(e)) || (C != null && !C.getType().getElement().isCompatible(e))) { (C != null && !C.getType().getElement().isCompatible(e))) { throw new RSRuntimeException("Called BLAS with wrong Element type"); throw new RSRuntimeException("Called BLAS with wrong Element type"); } } if (C != null) { if (C == null) { cX = C.getType().getY(); //since matrix C is used to store the result, it cannot be null. cY = C.getType().getX(); throw new RSRuntimeException("Allocation C cannot be null"); } } cM = C.getType().getY(); cN = C.getType().getX(); if (Side == RIGHT) { if (Side == RIGHT) { if ((A == null && B != null) || (A != null && B == null)) { throw new RSRuntimeException("Provided Matrix A without Matrix B, or vice versa"); } if (B != null) { if (B != null) { bX = A.getType().getY(); bM = A.getType().getY(); bY = A.getType().getX(); bN = A.getType().getX(); } } if (A != null) { if (A != null) { aX = B.getType().getY(); aM = B.getType().getY(); aY = B.getType().getX(); aN = B.getType().getX(); } } } else { } else { if (A != null) { if (A != null) { if (TransA == TRANSPOSE) { if (TransA != NO_TRANSPOSE) { aY = A.getType().getY(); aN = A.getType().getY(); aX = A.getType().getX(); aM = A.getType().getX(); } else { } else { aX = A.getType().getY(); aM = A.getType().getY(); aY = A.getType().getX(); aN = A.getType().getX(); } } } } if (B != null) { if (B != null) { if (TransB == TRANSPOSE) { if (TransB != NO_TRANSPOSE) { bY = B.getType().getY(); bN = B.getType().getY(); bX = B.getType().getX(); bM = B.getType().getX(); } else { } else { bX = B.getType().getY(); bM = B.getType().getY(); bY = B.getType().getX(); bN = B.getType().getX(); } } } } } } if (A != null && B != null && C != null) { if (A != null && B != null && C != null) { if (aY != bX || aX != cX || bY != cY) { if (aN != bM || aM != cM || bN != cN) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } } else if (A != null && C != null) { } else if (A != null && C != null) { // A and C only // A and C only, for SYRK if (aX != cY || aY != cX) { if (cM != cN) { throw new RSRuntimeException("Matrix C is not symmetric"); } if (TransA != NO_TRANSPOSE) { if (aN != cM) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } else { if (aM != cM) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } } } else if (A != null && B != null) { } else if (A != null && B != null) { // A and B only // A and B only if (aN != bM) { throw new RSRuntimeException("Called BLAS with invalid dimensions"); } } } } } Loading @@ -1046,14 +1064,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C); validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; int M = -1, N = -1, K = -1; if (TransA == TRANSPOSE) { if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); M = A.getType().getX(); K = A.getType().getY(); K = A.getType().getY(); } else { } else { M = A.getType().getY(); M = A.getType().getY(); K = A.getType().getX(); K = A.getType().getX(); } } if (TransB == TRANSPOSE) { if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); N = B.getType().getY(); } else { } else { N = B.getType().getX(); N = B.getType().getX(); Loading @@ -1067,14 +1085,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateTranspose(TransB); validateTranspose(TransB); validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C); validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; int M = -1, N = -1, K = -1; if (TransA == TRANSPOSE) { if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); M = A.getType().getX(); K = A.getType().getY(); K = A.getType().getY(); } else { } else { M = A.getType().getY(); M = A.getType().getY(); K = A.getType().getX(); K = A.getType().getX(); } } if (TransB == TRANSPOSE) { if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); N = B.getType().getY(); } else { } else { N = B.getType().getX(); N = B.getType().getX(); Loading @@ -1088,14 +1106,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateTranspose(TransB); validateTranspose(TransB); validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C); validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; int M = -1, N = -1, K = -1; if (TransA == TRANSPOSE) { if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); M = A.getType().getX(); K = A.getType().getY(); K = A.getType().getY(); } else { } else { M = A.getType().getY(); M = A.getType().getY(); K = A.getType().getX(); K = A.getType().getX(); } } if (TransB == TRANSPOSE) { if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); N = B.getType().getY(); } else { } else { N = B.getType().getX(); N = B.getType().getX(); Loading @@ -1110,14 +1128,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateTranspose(TransB); validateTranspose(TransB); validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C); validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; int M = -1, N = -1, K = -1; if (TransA == TRANSPOSE) { if (TransA != NO_TRANSPOSE) { M = A.getType().getX(); M = A.getType().getX(); K = A.getType().getY(); K = A.getType().getY(); } else { } else { M = A.getType().getY(); M = A.getType().getY(); K = A.getType().getX(); K = A.getType().getX(); } } if (TransB == TRANSPOSE) { if (TransB != NO_TRANSPOSE) { N = B.getType().getY(); N = B.getType().getY(); } else { } else { N = B.getType().getX(); N = B.getType().getX(); Loading @@ -1130,6 +1148,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, float beta, Allocation C) { Allocation B, float beta, Allocation C) { validateSide(Side); validateSide(Side); validateUplo(Uplo); validateUplo(Uplo); //For SYMM, Matrix A should be symmetric if (A.getType().getX() != A.getType().getY()) { throw new RSRuntimeException("Matrix A is not symmetric"); } validateL3(Element.F32(mRS), 0, 0, Side, A, B, C); validateL3(Element.F32(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); beta, C.getID(mRS), 0, 0, 0, 0); Loading @@ -1138,6 +1160,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, double beta, Allocation C) { Allocation B, double beta, Allocation C) { validateSide(Side); validateSide(Side); validateUplo(Uplo); validateUplo(Uplo); if (A.getType().getX() != A.getType().getY()) { throw new RSRuntimeException("Matrix A is not symmetric"); } validateL3(Element.F64(mRS), 0, 0, Side, A, B, C); validateL3(Element.F64(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); beta, C.getID(mRS), 0, 0, 0, 0); Loading @@ -1146,6 +1171,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, Float2 beta, Allocation C) { Allocation B, Float2 beta, Allocation C) { validateSide(Side); validateSide(Side); validateUplo(Uplo); validateUplo(Uplo); if (A.getType().getX() != A.getType().getY()) { throw new RSRuntimeException("Matrix A is not symmetric"); } validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C); validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); Loading @@ -1154,6 +1182,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { Allocation B, Double2 beta, Allocation C) { Allocation B, Double2 beta, Allocation C) { validateSide(Side); validateSide(Side); validateUplo(Uplo); validateUplo(Uplo); if (A.getType().getX() != A.getType().getY()) { throw new RSRuntimeException("Matrix A is not symmetric"); } validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C); validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); Loading @@ -1164,7 +1195,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateUplo(Uplo); validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C); validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C); int K = -1; int K = -1; if (Trans == TRANSPOSE) { if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); K = A.getType().getY(); } else { } else { K = A.getType().getX(); K = A.getType().getX(); Loading @@ -1178,7 +1209,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateUplo(Uplo); validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C); validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C); int K = -1; int K = -1; if (Trans == TRANSPOSE) { if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); K = A.getType().getY(); } else { } else { K = A.getType().getX(); K = A.getType().getX(); Loading @@ -1190,7 +1221,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateUplo(Uplo); validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C); validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C); int K = -1; int K = -1; if (Trans == TRANSPOSE) { if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); K = A.getType().getY(); } else { } else { K = A.getType().getX(); K = A.getType().getX(); Loading @@ -1203,7 +1234,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateUplo(Uplo); validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C); validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C); int K = -1; int K = -1; if (Trans == TRANSPOSE) { if (Trans != NO_TRANSPOSE) { K = A.getType().getY(); K = A.getType().getY(); } else { } else { K = A.getType().getX(); K = A.getType().getX(); Loading @@ -1229,7 +1260,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { // check rows versus C // check rows versus C Cdim = A.getType().getY(); Cdim = A.getType().getY(); } } if (C.getType().getX() != Cdim && C.getType().getY() != Cdim) { if (C.getType().getX() != Cdim || C.getType().getY() != Cdim) { throw new RSRuntimeException("Invalid symmetric matrix in SYR2K"); throw new RSRuntimeException("Invalid symmetric matrix in SYR2K"); } } // A dims == B dims // A dims == B dims Loading Loading @@ -1285,26 +1316,26 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { validateSide(Side); validateSide(Side); validateTranspose(TransA); validateTranspose(TransA); int aX = -1, aY = -1, bX = -1, bY = -1; int aM = -1, aN = -1, bM = -1, bN = -1; if (!A.getType().getElement().isCompatible(e) || if (!A.getType().getElement().isCompatible(e) || !B.getType().getElement().isCompatible(e)) { !B.getType().getElement().isCompatible(e)) { throw new RSRuntimeException("Called BLAS with wrong Element type"); throw new RSRuntimeException("Called BLAS with wrong Element type"); } } if (TransA == TRANSPOSE) { aY = A.getType().getY(); aM = A.getType().getY(); aX = A.getType().getX(); aN = A.getType().getX(); } else { if (aM != aN) { aY = A.getType().getX(); throw new RSRuntimeException("Called TRMM with a non-symmetric matrix A"); aX = A.getType().getY(); } } bX = B.getType().getY(); bY = B.getType().getX(); bM = B.getType().getY(); bN = B.getType().getX(); if (Side == LEFT) { if (Side == LEFT) { if (aX == 0 || aY != bX) { if (aN != bM) { throw new RSRuntimeException("Called TRMM with invalid matrices"); throw new RSRuntimeException("Called TRMM with invalid matrices"); } } } else { } else { if (bY != aX || aY == 0) { if (bN != aM) { throw new RSRuntimeException("Called TRMM with invalid matrices"); throw new RSRuntimeException("Called TRMM with invalid matrices"); } } } } Loading Loading @@ -1339,7 +1370,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } } static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { int adim = -1, bX = -1, bY = -1; int adim = -1, bM = -1, bN = -1; validateSide(Side); validateSide(Side); validateTranspose(TransA); validateTranspose(TransA); if (!A.getType().getElement().isCompatible(e) || if (!A.getType().getElement().isCompatible(e) || Loading @@ -1353,16 +1384,16 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { // for now we assume adapters are sufficient, will reevaluate in the future // for now we assume adapters are sufficient, will reevaluate in the future throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A"); throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A"); } } bX = B.getType().getY(); bM = B.getType().getY(); bY = B.getType().getX(); bN = B.getType().getX(); if (Side == LEFT) { if (Side == LEFT) { // A is M*M // A is M*M if (adim != bY) { if (adim != bM) { throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); } } } else { } else { // A is N*N // A is N*N if (adim != bX) { if (adim != bN) { throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); } } } } Loading Loading @@ -1427,7 +1458,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { } } public void ZHEMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) { public void ZHEMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) { validateUplo(Uplo); validateUplo(Uplo); validateHEMM(Element.F32_2(mRS), Side, A, B, C); validateHEMM(Element.F64_2(mRS), Side, A, B, C); mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); } } Loading @@ -1443,11 +1474,11 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { throw new RSRuntimeException("Called HERK with non-square C"); throw new RSRuntimeException("Called HERK with non-square C"); } } if (Trans == NO_TRANSPOSE) { if (Trans == NO_TRANSPOSE) { if (cdim != A.getType().getX()) { if (cdim != A.getType().getY()) { throw new RSRuntimeException("Called HERK with invalid A"); throw new RSRuntimeException("Called HERK with invalid A"); } } } else { } else { if (cdim != A.getType().getY()) { if (cdim != A.getType().getX()) { throw new RSRuntimeException("Called HERK with invalid A"); throw new RSRuntimeException("Called HERK with invalid A"); } } } } Loading @@ -1456,7 +1487,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateUplo(Uplo); validateHERK(Element.F32_2(mRS), Trans, A, C); validateHERK(Element.F32_2(mRS), Trans, A, C); int k = 0; int k = 0; if (Trans == TRANSPOSE) { if (Trans == CONJ_TRANSPOSE) { k = A.getType().getY(); k = A.getType().getY(); } else { } else { k = A.getType().getX(); k = A.getType().getX(); Loading @@ -1468,7 +1499,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { validateUplo(Uplo); validateUplo(Uplo); validateHERK(Element.F64_2(mRS), Trans, A, C); validateHERK(Element.F64_2(mRS), Trans, A, C); int k = 0; int k = 0; if (Trans == TRANSPOSE) { if (Trans == CONJ_TRANSPOSE) { k = A.getType().getY(); k = A.getType().getY(); } else { } else { k = A.getType().getX(); k = A.getType().getX(); Loading