Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 7f72f747 authored by Tim Murray's avatar Tim Murray Committed by Gerrit Code Review
Browse files

Merge "Add BNNM intrinsic."

parents 71ba4e46 9cb16a2f
Loading
Loading
Loading
Loading
+11 −0
Original line number Diff line number Diff line
@@ -951,6 +951,17 @@ public class RenderScript {
        rsnScriptIntrinsicBLAS_Z(mContext, id, func, TransA, TransB, Side, Uplo, Diag, M, N, K, alphaX, alphaY, A, B, betaX, betaY, C, incX, incY, KL, KU);
    }

    native void rsnScriptIntrinsicBLAS_BNNM(long con, long id, int M, int N, int K,
                                             long A, int a_offset, long B, int b_offset, long C, int c_offset,
                                             int c_mult_int);
    synchronized void nScriptIntrinsicBLAS_BNNM(long id, int M, int N, int K,
                                             long A, int a_offset, long B, int b_offset, long C, int c_offset,
                                             int c_mult_int) {
        validate();
        rsnScriptIntrinsicBLAS_BNNM(mContext, id, M, N, K, A, a_offset, B, b_offset, C, c_offset, c_mult_int);
    }



    long     mDev;
    long     mContext;
+21 −0
Original line number Diff line number Diff line
@@ -176,6 +176,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
    private static final int RsBlas_zherk = 141;
    private static final int RsBlas_zher2k = 142;

    // BLAS extensions start here
    private static final int RsBlas_bnnm = 1000;

    /**
     */
    public static ScriptIntrinsicBLAS create(RenderScript rs) {
@@ -1485,5 +1488,23 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
    }


    /**
     *
     * 8-bit GEMM-like operation for neural networks
     *
     * @hide
     **/
    public void BNNM(Allocation A, int a_offset, Allocation B, int b_offset, Allocation C, int c_offset, int c_mult) {
        validateL3(Element.U8(mRS), NO_TRANSPOSE, TRANSPOSE, 0, A, B, C);

        int M = -1, N = -1, K = -1;
        M = A.getType().getY();
        N = B.getType().getY();
        K = A.getType().getX();


        mRS.nScriptIntrinsicBLAS_BNNM(getID(mRS), M, N, K, A.getID(mRS), a_offset, B.getID(mRS), b_offset, C.getID(mRS), c_offset, c_mult);

    }

}
+28 −0
Original line number Diff line number Diff line
@@ -583,6 +583,32 @@ nScriptIntrinsicBLAS_Z(JNIEnv *_env, jobject _this, jlong con, jlong id, jint fu
}


static void
nScriptIntrinsicBLAS_BNNM(JNIEnv *_env, jobject _this, jlong con, jlong id, jint M, jint N, jint K,
                                             jlong A, jint a_offset, jlong B, jint b_offset, jlong C, jint c_offset,
                                             jint c_mult_int) {
    RsBlasCall call;
    memset(&call, 0, sizeof(call));
    call.func = RsBlas_bnnm;
    call.M = M;
    call.N = N;
    call.K = K;
    call.a_offset = a_offset;
    call.b_offset = b_offset;
    call.c_offset = c_offset;
    call.c_mult_int = c_mult_int;

    RsAllocation in_allocs[3];
    in_allocs[0] = (RsAllocation)A;
    in_allocs[1] = (RsAllocation)B;
    in_allocs[2] = (RsAllocation)C;

    rsScriptForEachMulti((RsContext)con, (RsScript)id, 0,
                         in_allocs, sizeof(in_allocs), nullptr,
                         &call, sizeof(call), nullptr, 0);
}


static void
nAssignName(JNIEnv *_env, jobject _this, jlong con, jlong obj, jbyteArray str)
{
@@ -2427,6 +2453,8 @@ static JNINativeMethod methods[] = {
{"rsnScriptIntrinsicBLAS_Complex",   "(JJIIIIIIIIIFFJJFFJIIII)V",             (void*)nScriptIntrinsicBLAS_Complex },
{"rsnScriptIntrinsicBLAS_Z",         "(JJIIIIIIIIIDDJJDDJIIII)V",             (void*)nScriptIntrinsicBLAS_Z },

{"rsnScriptIntrinsicBLAS_BNNM",      "(JJIIIJIJIJII)V",                       (void*)nScriptIntrinsicBLAS_BNNM },

{"rsnProgramStoreCreate",            "(JZZZZZZIII)J",                         (void*)nProgramStoreCreate },

{"rsnProgramBindConstants",          "(JJIJ)V",                               (void*)nProgramBindConstants },