Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 91d2f27b authored by Tim Murray's avatar Tim Murray Committed by Android Git Automerger
Browse files

am 7f72f747: Merge "Add BNNM intrinsic."

* commit '7f72f747':
  Add BNNM intrinsic.
parents 1cb119d3 7f72f747
Loading
Loading
Loading
Loading
+11 −0
Original line number Diff line number Diff line
@@ -951,6 +951,17 @@ public class RenderScript {
        rsnScriptIntrinsicBLAS_Z(mContext, id, func, TransA, TransB, Side, Uplo, Diag, M, N, K, alphaX, alphaY, A, B, betaX, betaY, C, incX, incY, KL, KU);
    }

    native void rsnScriptIntrinsicBLAS_BNNM(long con, long id, int M, int N, int K,
                                             long A, int a_offset, long B, int b_offset, long C, int c_offset,
                                             int c_mult_int);
    synchronized void nScriptIntrinsicBLAS_BNNM(long id, int M, int N, int K,
                                             long A, int a_offset, long B, int b_offset, long C, int c_offset,
                                             int c_mult_int) {
        validate();
        rsnScriptIntrinsicBLAS_BNNM(mContext, id, M, N, K, A, a_offset, B, b_offset, C, c_offset, c_mult_int);
    }



    long     mDev;
    long     mContext;
+21 −0
Original line number Diff line number Diff line
@@ -176,6 +176,9 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
    private static final int RsBlas_zherk = 141;
    private static final int RsBlas_zher2k = 142;

    // BLAS extensions start here
    private static final int RsBlas_bnnm = 1000;

    /**
     */
    public static ScriptIntrinsicBLAS create(RenderScript rs) {
@@ -1485,5 +1488,23 @@ public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
    }


    /**
     *
     * 8-bit GEMM-like operation for neural networks
     *
     * @hide
     **/
    public void BNNM(Allocation A, int a_offset, Allocation B, int b_offset, Allocation C, int c_offset, int c_mult) {
        validateL3(Element.U8(mRS), NO_TRANSPOSE, TRANSPOSE, 0, A, B, C);

        int M = -1, N = -1, K = -1;
        M = A.getType().getY();
        N = B.getType().getY();
        K = A.getType().getX();


        mRS.nScriptIntrinsicBLAS_BNNM(getID(mRS), M, N, K, A.getID(mRS), a_offset, B.getID(mRS), b_offset, C.getID(mRS), c_offset, c_mult);

    }

}
+28 −0
Original line number Diff line number Diff line
@@ -583,6 +583,32 @@ nScriptIntrinsicBLAS_Z(JNIEnv *_env, jobject _this, jlong con, jlong id, jint fu
}


static void
nScriptIntrinsicBLAS_BNNM(JNIEnv *_env, jobject _this, jlong con, jlong id, jint M, jint N, jint K,
                                             jlong A, jint a_offset, jlong B, jint b_offset, jlong C, jint c_offset,
                                             jint c_mult_int) {
    RsBlasCall call;
    memset(&call, 0, sizeof(call));
    call.func = RsBlas_bnnm;
    call.M = M;
    call.N = N;
    call.K = K;
    call.a_offset = a_offset;
    call.b_offset = b_offset;
    call.c_offset = c_offset;
    call.c_mult_int = c_mult_int;

    RsAllocation in_allocs[3];
    in_allocs[0] = (RsAllocation)A;
    in_allocs[1] = (RsAllocation)B;
    in_allocs[2] = (RsAllocation)C;

    rsScriptForEachMulti((RsContext)con, (RsScript)id, 0,
                         in_allocs, sizeof(in_allocs), nullptr,
                         &call, sizeof(call), nullptr, 0);
}


static void
nAssignName(JNIEnv *_env, jobject _this, jlong con, jlong obj, jbyteArray str)
{
@@ -2427,6 +2453,8 @@ static JNINativeMethod methods[] = {
{"rsnScriptIntrinsicBLAS_Complex",   "(JJIIIIIIIIIFFJJFFJIIII)V",             (void*)nScriptIntrinsicBLAS_Complex },
{"rsnScriptIntrinsicBLAS_Z",         "(JJIIIIIIIIIDDJJDDJIIII)V",             (void*)nScriptIntrinsicBLAS_Z },

{"rsnScriptIntrinsicBLAS_BNNM",      "(JJIIIJIJIJII)V",                       (void*)nScriptIntrinsicBLAS_BNNM },

{"rsnProgramStoreCreate",            "(JZZZZZZIII)J",                         (void*)nProgramStoreCreate },

{"rsnProgramBindConstants",          "(JJIJ)V",                               (void*)nProgramBindConstants },