Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 9e4dbf05 authored by Miao Wang, committed by Android (Google) Code Review
Browse files

Merge changes I99f9f9ff,I559b5c56 into mnc-dev

* changes:
  [RenderScript] L2 BLAS, fix element type in ZHER
  [RenderScript] fixes for L3 BLAS APIs
parents b62dc82b cecc00ab
Loading
Loading
Loading
Loading
+89 −58
Original line number Diff line number Diff line
@@ -241,7 +241,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
    }

    static void validateUplo(@Uplo int Uplo) {
        if (Uplo != LEFT && Uplo != RIGHT) {
        if (Uplo != UPPER && Uplo != LOWER) {
            throw new RSRuntimeException("Invalid uplo passed to BLAS");
        }
    }
@@ -960,7 +960,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
    }
    /**
     * Double-precision complex ZHER entry point (Hermitian rank-1 update in
     * the BLAS reference: A := alpha*x*x**H + A).
     *
     * The arguments are validated with the same rules as SYR, against the
     * double-precision complex element type {@code F64_2}, and the operation
     * is then dispatched to the native {@code RsBlas_zher} kernel.
     *
     * @param Uplo  which triangle of A is used; forwarded to validation and
     *              to the native call.
     * @param alpha the real scalar forwarded to the native call.
     * @param X     the input vector allocation; must hold F64_2 elements.
     * @param incX  the increment for the elements of X.
     * @param A     the matrix allocation; must hold F64_2 elements.
     */
    public void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
        // same as SYR
        // ZHER operates on complex doubles, so validation must use F64_2 (not F64).
        int N = validateSYR(Element.F64_2(mRS), Uplo, X, incX, A);
        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0);
    }
    public void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
@@ -985,56 +985,74 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
     */

    static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) {
        int aX = -1, aY = -1, bX = -1, bY = -1, cX = -1, cY = -1;
        int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1;
        if ((A != null && !A.getType().getElement().isCompatible(e)) ||
            (B != null && !B.getType().getElement().isCompatible(e)) ||
            (C != null && !C.getType().getElement().isCompatible(e))) {
            throw new RSRuntimeException("Called BLAS with wrong Element type");
        }
        if (C != null) {
            cX = C.getType().getY();
            cY = C.getType().getX();
        if (C == null) {
            //since matrix C is used to store the result, it cannot be null.
            throw new RSRuntimeException("Allocation C cannot be null");
        }
        cM = C.getType().getY();
        cN = C.getType().getX();

        if (Side == RIGHT) {
            if ((A == null && B != null) || (A != null && B == null)) {
                throw new RSRuntimeException("Provided Matrix A without Matrix B, or vice versa");
            }
            if (B != null) {
                bX = A.getType().getY();
                bY = A.getType().getX();
                bM = A.getType().getY();
                bN = A.getType().getX();
            }
            if (A != null) {
                aX = B.getType().getY();
                aY = B.getType().getX();
                aM = B.getType().getY();
                aN = B.getType().getX();
            }
        } else {
            if (A != null) {
                if (TransA == TRANSPOSE) {
                    aY = A.getType().getY();
                    aX = A.getType().getX();
                if (TransA != NO_TRANSPOSE) {
                    aN = A.getType().getY();
                    aM = A.getType().getX();
                } else {
                    aX = A.getType().getY();
                    aY = A.getType().getX();
                    aM = A.getType().getY();
                    aN = A.getType().getX();
                }
            }
            if (B != null) {
                if (TransB == TRANSPOSE) {
                    bY = B.getType().getY();
                    bX = B.getType().getX();
                if (TransB != NO_TRANSPOSE) {
                    bN = B.getType().getY();
                    bM = B.getType().getX();
                } else {
                    bX = B.getType().getY();
                    bY = B.getType().getX();
                    bM = B.getType().getY();
                    bN = B.getType().getX();
                }
            }
        }
        if (A != null && B != null && C != null) {
            if (aY != bX || aX != cX || bY != cY) {
            if (aN != bM || aM != cM || bN != cN) {
                throw new RSRuntimeException("Called BLAS with invalid dimensions");
            }
        } else if (A != null && C != null) {
            // A and C only
            if (aX != cY || aY != cX) {
            // A and C only, for SYRK
            if (cM != cN) {
                throw new RSRuntimeException("Matrix C is not symmetric");
            }
            if (TransA != NO_TRANSPOSE) {
                if (aN != cM) {
                    throw new RSRuntimeException("Called BLAS with invalid dimensions");
                }
            } else {
                if (aM != cM) {
                    throw new RSRuntimeException("Called BLAS with invalid dimensions");
                }
            }
        } else if (A != null && B != null) {
            // A and B only
            if (aN != bM) {
                throw new RSRuntimeException("Called BLAS with invalid dimensions");
            }
        }

    }
@@ -1046,14 +1064,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C);

        int M = -1, N = -1, K = -1;
        if (TransA == TRANSPOSE) {
        if (TransA != NO_TRANSPOSE) {
            M = A.getType().getX();
            K = A.getType().getY();
        } else {
            M = A.getType().getY();
            K = A.getType().getX();
        }
        if (TransB == TRANSPOSE) {
        if (TransB != NO_TRANSPOSE) {
            N = B.getType().getY();
        } else {
            N = B.getType().getX();
@@ -1067,14 +1085,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        validateTranspose(TransB);
        validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C);
        int M = -1, N = -1, K = -1;
        if (TransA == TRANSPOSE) {
        if (TransA != NO_TRANSPOSE) {
            M = A.getType().getX();
            K = A.getType().getY();
        } else {
            M = A.getType().getY();
            K = A.getType().getX();
        }
        if (TransB == TRANSPOSE) {
        if (TransB != NO_TRANSPOSE) {
            N = B.getType().getY();
        } else {
            N = B.getType().getX();
@@ -1088,14 +1106,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        validateTranspose(TransB);
        validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C);
        int M = -1, N = -1, K = -1;
        if (TransA == TRANSPOSE) {
        if (TransA != NO_TRANSPOSE) {
            M = A.getType().getX();
            K = A.getType().getY();
        } else {
            M = A.getType().getY();
            K = A.getType().getX();
        }
        if (TransB == TRANSPOSE) {
        if (TransB != NO_TRANSPOSE) {
            N = B.getType().getY();
        } else {
            N = B.getType().getX();
@@ -1110,14 +1128,14 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        validateTranspose(TransB);
        validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C);
        int M = -1, N = -1, K = -1;
        if (TransA == TRANSPOSE) {
        if (TransA != NO_TRANSPOSE) {
            M = A.getType().getX();
            K = A.getType().getY();
        } else {
            M = A.getType().getY();
            K = A.getType().getX();
        }
        if (TransB == TRANSPOSE) {
        if (TransB != NO_TRANSPOSE) {
            N = B.getType().getY();
        } else {
            N = B.getType().getX();
@@ -1130,6 +1148,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
                      Allocation B, float beta, Allocation C) {
        validateSide(Side);
        validateUplo(Uplo);
        //For SYMM, Matrix A should be symmetric
        if (A.getType().getX() != A.getType().getY()) {
            throw new RSRuntimeException("Matrix A is not symmetric");
        }
        validateL3(Element.F32(mRS), 0, 0, Side, A, B, C);
        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
                                        beta, C.getID(mRS), 0, 0, 0, 0);
@@ -1138,6 +1160,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
                      Allocation B, double beta, Allocation C) {
        validateSide(Side);
        validateUplo(Uplo);
        if (A.getType().getX() != A.getType().getY()) {
            throw new RSRuntimeException("Matrix A is not symmetric");
        }
        validateL3(Element.F64(mRS), 0, 0, Side, A, B, C);
        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
                                        beta, C.getID(mRS), 0, 0, 0, 0);
@@ -1146,6 +1171,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
                      Allocation B, Float2 beta, Allocation C) {
        validateSide(Side);
        validateUplo(Uplo);
        if (A.getType().getX() != A.getType().getY()) {
            throw new RSRuntimeException("Matrix A is not symmetric");
        }
        validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C);
        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
                                         beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
@@ -1154,6 +1182,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
                      Allocation B, Double2 beta, Allocation C) {
        validateSide(Side);
        validateUplo(Uplo);
        if (A.getType().getX() != A.getType().getY()) {
            throw new RSRuntimeException("Matrix A is not symmetric");
        }
        validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C);
        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
                                   beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
@@ -1164,7 +1195,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        validateUplo(Uplo);
        validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C);
        int K = -1;
        if (Trans == TRANSPOSE) {
        if (Trans != NO_TRANSPOSE) {
            K = A.getType().getY();
        } else {
            K = A.getType().getX();
@@ -1178,7 +1209,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        validateUplo(Uplo);
        validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C);
        int K = -1;
        if (Trans == TRANSPOSE) {
        if (Trans != NO_TRANSPOSE) {
            K = A.getType().getY();
        } else {
            K = A.getType().getX();
@@ -1190,7 +1221,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        validateUplo(Uplo);
        validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C);
        int K = -1;
        if (Trans == TRANSPOSE) {
        if (Trans != NO_TRANSPOSE) {
            K = A.getType().getY();
        } else {
            K = A.getType().getX();
@@ -1203,7 +1234,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        validateUplo(Uplo);
        validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C);
        int K = -1;
        if (Trans == TRANSPOSE) {
        if (Trans != NO_TRANSPOSE) {
            K = A.getType().getY();
        } else {
            K = A.getType().getX();
@@ -1229,7 +1260,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
            // check rows versus C
            Cdim = A.getType().getY();
        }
        if (C.getType().getX() != Cdim && C.getType().getY() != Cdim) {
        if (C.getType().getX() != Cdim || C.getType().getY() != Cdim) {
            throw new RSRuntimeException("Invalid symmetric matrix in SYR2K");
        }
        // A dims == B dims
@@ -1285,26 +1316,26 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
    static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
        validateSide(Side);
        validateTranspose(TransA);
        int aX = -1, aY = -1, bX = -1, bY = -1;
        int aM = -1, aN = -1, bM = -1, bN = -1;
        if (!A.getType().getElement().isCompatible(e) ||
            !B.getType().getElement().isCompatible(e)) {
            throw new RSRuntimeException("Called BLAS with wrong Element type");
        }
        if (TransA == TRANSPOSE) {
            aY = A.getType().getY();
            aX = A.getType().getX();
        } else {
            aY = A.getType().getX();
            aX = A.getType().getY();

        aM = A.getType().getY();
        aN = A.getType().getX();
        if (aM != aN) {
            throw new RSRuntimeException("Called TRMM with a non-symmetric matrix A");
        }
        bX = B.getType().getY();
        bY = B.getType().getX();

        bM = B.getType().getY();
        bN = B.getType().getX();
        if (Side == LEFT) {
            if (aX == 0 || aY != bX) {
            if (aN != bM) {
                throw new RSRuntimeException("Called TRMM with invalid matrices");
            }
        } else {
            if (bY != aX || aY == 0) {
            if (bN != aM) {
                throw new RSRuntimeException("Called TRMM with invalid matrices");
            }
        }
@@ -1339,7 +1370,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
    }

    static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
        int adim = -1, bX = -1, bY = -1;
        int adim = -1, bM = -1, bN = -1;
        validateSide(Side);
        validateTranspose(TransA);
        if (!A.getType().getElement().isCompatible(e) ||
@@ -1353,16 +1384,16 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
            // for now we assume adapters are sufficient, will reevaluate in the future
            throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A");
        }
        bX = B.getType().getY();
        bY = B.getType().getX();
        bM = B.getType().getY();
        bN = B.getType().getX();
        if (Side == LEFT) {
            // A is M*M
            if (adim != bY) {
            if (adim != bM) {
                throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
            }
        } else {
            // A is N*N
            if (adim != bX) {
            if (adim != bN) {
                throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
            }
        }
@@ -1427,7 +1458,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
    }
    /**
     * Double-precision complex ZHEMM entry point (Hermitian matrix-matrix
     * multiply in the BLAS reference).
     *
     * Validates {@code Uplo}, validates A, B and C against the
     * double-precision complex element type {@code F64_2}, and dispatches
     * to the native {@code RsBlas_zhemm} kernel. C's row count
     * ({@code getY()}) and column count ({@code getX()}) are passed as the
     * problem dimensions.
     *
     * @param Side  whether A is applied on the left or the right.
     * @param Uplo  which triangle of A is used.
     * @param alpha complex scalar applied to the product.
     * @param A     the Hermitian input matrix; must hold F64_2 elements.
     * @param B     the second input matrix; must hold F64_2 elements.
     * @param beta  complex scalar applied to C.
     * @param C     the result matrix; must hold F64_2 elements.
     */
    public void ZHEMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) {
        validateUplo(Uplo);
        // ZHEMM works on complex doubles, so validate against F64_2 only;
        // an F32_2 check would reject every valid input.
        validateHEMM(Element.F64_2(mRS), Side, A, B, C);
        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
                                   alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
    }
@@ -1443,11 +1474,11 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
            throw new RSRuntimeException("Called HERK with non-square C");
        }
        if (Trans == NO_TRANSPOSE) {
            if (cdim != A.getType().getX()) {
            if (cdim != A.getType().getY()) {
                throw new RSRuntimeException("Called HERK with invalid A");
            }
        } else {
            if (cdim != A.getType().getY()) {
            if (cdim != A.getType().getX()) {
                throw new RSRuntimeException("Called HERK with invalid A");
            }
        }
@@ -1456,7 +1487,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        validateUplo(Uplo);
        validateHERK(Element.F32_2(mRS), Trans, A, C);
        int k = 0;
        if (Trans == TRANSPOSE) {
        if (Trans == CONJ_TRANSPOSE) {
            k = A.getType().getY();
        } else {
            k = A.getType().getX();
@@ -1468,7 +1499,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        validateUplo(Uplo);
        validateHERK(Element.F64_2(mRS), Trans, A, C);
        int k = 0;
        if (Trans == TRANSPOSE) {
        if (Trans == CONJ_TRANSPOSE) {
            k = A.getType().getY();
        } else {
            k = A.getType().getX();