Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit d0b4c1c3 authored by Miao Wang's avatar Miao Wang Committed by Android Git Automerger
Browse files

am f8e9fd2c: am a57a8a8c: Merge "[RenderScript] improve & minor fixes of L2...

am f8e9fd2c: am a57a8a8c: Merge "[RenderScript] improve & minor fixes of L2 BLAS validation." into mnc-dev

* commit 'f8e9fd2c':
  [RenderScript] improve & minor fixes of L2 BLAS validation.
parents 6c57c327 f8e9fd2c
Loading
Loading
Loading
Loading
+82 −42
Original line number Diff line number Diff line
@@ -276,7 +276,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
            expectedYDim = 1 + (N - 1) * incY;
        }
        if (X.getType().getX() != expectedXDim ||
            Y.getType().getY() != expectedXDim) {
            Y.getType().getX() != expectedYDim) {
            throw new RSRuntimeException("Incorrect vector dimensions for GEMV");
        }
    }
@@ -346,8 +346,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU);
    }

    static void validateTRMV(Element e, @Transpose int TransA, Allocation A, Allocation X, int incX) {
    static void validateTRMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
        validateTranspose(TransA);
        validateUplo(Uplo);
        validateDiag(Diag);
        int N = A.getType().getY();
        if (A.getType().getX() != N) {
            throw new RSRuntimeException("A must be a square matrix for TRMV");
@@ -386,59 +388,75 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        }

        int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
        //is it really doing anything?
        if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
            throw new RSRuntimeException("Invalid dimension for Ap");
        }

        if (incX <= 0) {
            throw new RSRuntimeException("Vector increments must be greater than 0");
        }
        int expectedXDim = 1 + (N - 1) * incX;
        if (X.getType().getX() != expectedXDim) {
            throw new RSRuntimeException("Incorrect vector dimensions for SYMV");
            throw new RSRuntimeException("Incorrect vector dimensions for TPMV");
        }

        return N;
    }

    void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
        validateTRMV(Element.F32(mRS), TransA, A, X, incX);
        validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
        int N = A.getType().getY();
        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
    }
    void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
        validateTRMV(Element.F64(mRS), TransA, A, X, incX);
        validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
        int N = A.getType().getY();
        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
    }
    void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
        validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
        validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
        int N = A.getType().getY();
        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
    }
    void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
        validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
        validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
        int N = A.getType().getY();
        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
    }

    void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  int K, Allocation A,  Allocation X,  int incX) {
        // TBMV has the same requirements as TRMV
        validateTRMV(Element.F32(mRS), TransA, A, X, incX);
        // TBMV has the same requirements as TRMV + K >= 0
        if (K < 0) {
            throw new RSRuntimeException("K must be greater than or equal to 0");
        }
        validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
        int N = A.getType().getY();
        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
    }
    void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  int K, Allocation A,  Allocation X,  int incX) {
        // TBMV has the same requirements as TRMV
        validateTRMV(Element.F64(mRS), TransA, A, X, incX);
        // TBMV has the same requirements as TRMV + K >= 0
        if (K < 0) {
            throw new RSRuntimeException("K must be greater than or equal to 0");
        }
        validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
        int N = A.getType().getY();
        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
    }
    void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  int K, Allocation A,  Allocation X,  int incX) {
        // TBMV has the same requirements as TRMV
        validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
        // TBMV has the same requirements as TRMV + K >= 0
        if (K < 0) {
            throw new RSRuntimeException("K must be greater than or equal to 0");
        }
        validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
        int N = A.getType().getY();
        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
    }
    void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  int K, Allocation A,  Allocation X,  int incX) {
        // TBMV has the same requirements as TRMV
        validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
        // TBMV has the same requirements as TRMV + K >= 0
        if (K < 0) {
            throw new RSRuntimeException("K must be greater than or equal to 0");
        }
        validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
        int N = A.getType().getY();
        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
    }
@@ -460,35 +478,35 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
    }
    void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  Allocation A,  Allocation X,  int incX) {
        // TRSV is the same as TRMV
        validateTRMV(Element.F32(mRS), TransA, A, X, incX);
        validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
        int N = A.getType().getY();
        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);

    }
    void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  Allocation A,  Allocation X,  int incX) {
        // TRSV is the same as TRMV
        validateTRMV(Element.F64(mRS), TransA, A, X, incX);
        validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
        int N = A.getType().getY();
        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);

    }
    void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  Allocation A,  Allocation X,  int incX) {
        // TRSV is the same as TRMV
        validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
        validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
        int N = A.getType().getY();
        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);

    }
    void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  Allocation A,  Allocation X,  int incX) {
        // TRSV is the same as TRMV
        validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
        validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
        int N = A.getType().getY();
        mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);

    }
    void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  int K, Allocation A,  Allocation X,  int incX) {
        // TBSV is the same as TRMV
        validateTRMV(Element.F32(mRS), TransA, A, X, incX);
        // TBSV is the same as TRMV + K >= 0
        validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
        int N = A.getType().getY();
        if (K < 0) {
            throw new RSRuntimeException("Number of diagonals must be positive");
@@ -496,8 +514,8 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
    }
    void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  int K, Allocation A,  Allocation X,  int incX) {
        // TBSV is the same as TRMV
        validateTRMV(Element.F64(mRS), TransA, A, X, incX);
        // TBSV is the same as TRMV + K >= 0
        validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
        int N = A.getType().getY();
        if (K < 0) {
            throw new RSRuntimeException("Number of diagonals must be positive");
@@ -505,8 +523,8 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
    }
    void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  int K, Allocation A,  Allocation X,  int incX) {
        // TBSV is the same as TRMV
        validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
        // TBSV is the same as TRMV + K >= 0
        validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
        int N = A.getType().getY();
        if (K < 0) {
            throw new RSRuntimeException("Number of diagonals must be positive");
@@ -514,8 +532,8 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
    }
    void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag,  int K, Allocation A,  Allocation X,  int incX) {
        // TBSV is the same as TRMV
        validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
        // TBSV is the same as TRMV + K >= 0
        validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
        int N = A.getType().getY();
        if (K < 0) {
            throw new RSRuntimeException("Number of diagonals must be positive");
@@ -593,7 +611,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
            throw new RSRuntimeException("Invalid dimension for Ap");
        }

        if (incX <= 0 || incY <= 0) {
            throw new RSRuntimeException("Vector increments must be greater than 0");
        }
        int expectedXDim = 1 + (N - 1) * incX;
        if (X.getType().getX() != expectedXDim) {
            throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
@@ -622,8 +642,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        if (N < 1 || M < 1) {
            throw new RSRuntimeException("M and N must be 1 or greater for GER");
        }

        int expectedXDim = 1 + (N - 1) * incX;
        if (incX <= 0 || incY <= 0) {
            throw new RSRuntimeException("Vector increments must be greater than 0");
        }
        int expectedXDim = 1 + (M - 1) * incX;
        if (X.getType().getX() != expectedXDim) {
            throw new RSRuntimeException("Incorrect vector dimensions for GER");
        }
@@ -649,7 +671,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        if (N != A.getType().getY()) {
            throw new RSRuntimeException("A must be a symmetric matrix");
        }

        if (incX <= 0) {
            throw new RSRuntimeException("Vector increments must be greater than 0");
        }
        int expectedXDim = 1 + (N - 1) * incX;
        if (X.getType().getX() != expectedXDim) {
            throw new RSRuntimeException("Incorrect vector dimensions for SYR");
@@ -674,10 +698,12 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
            throw new RSRuntimeException("Invalid dimension for Ap");
        }

        if (incX <= 0) {
            throw new RSRuntimeException("Vector increments must be greater than 0");
        }
        int expectedXDim = 1 + (N - 1) * incX;
        if (X.getType().getX() != expectedXDim) {
            throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
            throw new RSRuntimeException("Incorrect vector dimensions for SPR");
        }

        return N;
@@ -700,7 +726,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        if (N != A.getType().getY()) {
            throw new RSRuntimeException("A must be a symmetric matrix");
        }

        if (incX <= 0 || incY <= 0) {
            throw new RSRuntimeException("Vector increments must be greater than 0");
        }
        int expectedXDim = 1 + (N - 1) * incX;
        int expectedYDim = 1 + (N - 1) * incY;
        if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) {
@@ -728,11 +756,13 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
            throw new RSRuntimeException("Invalid dimension for Ap");
        }

        if (incX <= 0 || incY <= 0) {
            throw new RSRuntimeException("Vector increments must be greater than 0");
        }
        int expectedXDim = 1 + (N - 1) * incX;
        int expectedYDim = 1 + (N - 1) * incY;
        if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) {
            throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
            throw new RSRuntimeException("Incorrect vector dimensions for SPR2");
        }

        return N;
@@ -743,7 +773,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
    }
    void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
        // SBMV is the same as SYMV
        // SBMV is the same as SYMV + K >= 0
        if (K < 0) {
            throw new RSRuntimeException("K must be greater than or equal to 0");
        }
        int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY);
        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
    }
@@ -754,6 +787,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
    void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
        int M = A.getType().getY();
        int N = A.getType().getX();
        validateGER(Element.F32(mRS), X, incX, Y, incY, A);
        mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0);
    }
    void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
@@ -777,7 +811,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
    }
    void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
        // SBMV is the same as SYMV
        // SBMV is the same as SYMV + K >= 0
        if (K < 0) {
            throw new RSRuntimeException("K must be greater than or equal to 0");
        }
        int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY);
        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
    }
@@ -788,6 +825,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
    void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
        int M = A.getType().getY();
        int N = A.getType().getX();
        validateGER(Element.F64(mRS), X, incX, Y, incY, A);
        mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0);
    }
    void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
@@ -824,8 +862,10 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {

        int M = A.getType().getY();
        int N = A.getType().getX();

        int expectedXDim = 1 + (N - 1) * incX;
        if (incX <= 0 || incY <= 0) {
            throw new RSRuntimeException("Vector increments must be greater than 0");
        }
        int expectedXDim = 1 + (M - 1) * incX;
        if (X.getType().getX() != expectedXDim) {
            throw new RSRuntimeException("Incorrect vector dimensions for GERU");
        }
@@ -869,7 +909,7 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
    }
    void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
        // same as SYR
        int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A);
        int N = validateSYR(Element.F32_2(mRS), Uplo, X, incX, A);
        mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0);
    }
    void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {