Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit f33562f5 authored by TreeHugger Robot's avatar TreeHugger Robot Committed by Automerger Merge Worker
Browse files

Merge "Add benchmarks for WASM bidding logic" into tm-mainline-prod am: 5c7f06a2

parents 2f3466f5 5c7f06a2
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -33,6 +33,7 @@ android_test {
        "compatibility-device-util-axt",
        "platform-test-annotations",
        "adservices-service-core",
        "androidx.core_core",
    ],
    test_suites: ["device-tests"],
    data: [":perfetto_artifacts"],
+1.25 MiB

File added.

No diff preview for this file type.

+24 −0
Original line number Diff line number Diff line
function generateBid(ad, wasmModule) {
  let input = ad.metadata.input;

  const instance = new WebAssembly.Instance(wasmModule);

  const memory = instance.exports.memory;
  const input_in_memory = new Float32Array(memory.buffer, 0, 200);
  for (let i = 0; i < input.length; ++i) {
    input_in_memory[i] = input[i];
  }
  const results = [
    instance.exports.nn_forward_model0(input_in_memory.length, input_in_memory),
    instance.exports.nn_forward_model1(input_in_memory.length, input_in_memory),
    instance.exports.nn_forward_model2(input_in_memory.length, input_in_memory),
    instance.exports.nn_forward_model3(input_in_memory.length, input_in_memory),
    instance.exports.nn_forward_model4(input_in_memory.length, input_in_memory),
  ];
  const bid = results.map(x => Math.max(x, 1)).reduce((x, y) => x * y);
  return {
    ad: 'example',
    bid: bid,
    render: ad.renderUrl
  }
}
 No newline at end of file
+85 −12
Original line number Diff line number Diff line
@@ -24,6 +24,9 @@ import static com.android.adservices.service.js.JSScriptArgument.stringArrayArg;

import static com.google.common.truth.Truth.assertThat;

import static org.junit.Assume.assumeTrue;

import android.annotation.SuppressLint;
import android.content.Context;
import android.perftests.utils.BenchmarkState;
import android.perftests.utils.PerfStatusReporter;
@@ -45,48 +48,44 @@ import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;

import org.json.JSONArray;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/** To run the unit tests for this class, run "atest RubidiumPerfTests:JSScriptEnginePerfTests" */
@MediumTest
@RunWith(AndroidJUnit4.class)
public class JSScriptEnginePerfTests {
    private static final String TAG = JSScriptEnginePerfTests.class.getSimpleName();
    private static final String TAG = JSScriptEngine.TAG;
    private static final Context sContext = ApplicationProvider.getApplicationContext();
    private static final ExecutorService sExecutorService = Executors.newFixedThreadPool(10);

    private static JSScriptEngine sJSScriptEngine;
    private static final JSScriptEngine sJSScriptEngine =
            JSScriptEngine.getInstanceForTesting(
                    sContext, Profiler.createInstance(JSScriptEngine.TAG));

    @Rule public PerfStatusReporter mPerfStatusReporter = new PerfStatusReporter();

    @Before
    public void before() throws Exception {
        Profiler profiler = Profiler.createInstance(JSScriptEngine.TAG);
        sJSScriptEngine = JSScriptEngine.getInstanceForTesting(sContext, profiler);

        // Warm up the sandbox env.
        callJSEngine(
                "function test() { return \"hello world\";" + " }", ImmutableList.of(), "test");
    }

    @After
    public void after() {
        sJSScriptEngine.shutdown();
    }

    @Test
    public void evaluate_helloWorld() throws Exception {
        BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
@@ -156,6 +155,7 @@ public class JSScriptEnginePerfTests {
        runParametrizedTurtledoveScript(75);
    }

    @SuppressLint("DefaultLocale")
    private void runParametrizedTurtledoveScript(int numAds) throws Exception {
        BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
        state.pauseTiming();
@@ -220,7 +220,34 @@ public class JSScriptEnginePerfTests {
        return arrayArg("foo", Collections.nCopies(numCustomAudiences, interestGroupArg));
    }

    private static String callJSEngine(
    @Test
    public void evaluate_turtledoveWasm() throws Exception {
        assumeTrue(sJSScriptEngine.isWasmSupported().get(3, TimeUnit.SECONDS));

        BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
        state.pauseTiming();

        String jsTestFile = readAsset("generate_bid_using_wasm.js");
        byte[] wasmTestFile = readBinaryAsset("generate_bid.wasm");
        JSScriptArgument[] inputBytes = new JSScriptArgument[200];
        Random rand = new Random();
        for (int i = 0; i < inputBytes.length; i++) {
            byte value = (byte) (rand.nextInt(2 * Byte.MAX_VALUE) - Byte.MIN_VALUE);
            inputBytes[i] = JSScriptArgument.numericArg("_", value);
        }
        JSScriptArgument adDataArgument =
                recordArg(
                        "ad",
                        stringArg("render_url", "http://google.com"),
                        recordArg("metadata", JSScriptArgument.arrayArg("input", inputBytes)));

        state.resumeTiming();
        while (state.keepRunning()) {
            callJSEngine(jsTestFile, wasmTestFile, ImmutableList.of(adDataArgument), "generateBid");
        }
    }

    private String callJSEngine(
            @NonNull String jsScript,
            @NonNull List<JSScriptArgument> args,
            @NonNull String functionName)
@@ -228,6 +255,15 @@ public class JSScriptEnginePerfTests {
        return callJSEngine(sJSScriptEngine, jsScript, args, functionName);
    }

    private String callJSEngine(
            @NonNull String jsScript,
            @NonNull byte[] wasmScript,
            @NonNull List<JSScriptArgument> args,
            @NonNull String functionName)
            throws Exception {
        return callJSEngine(sJSScriptEngine, jsScript, wasmScript, args, functionName);
    }

    private static String callJSEngine(
            @NonNull JSScriptEngine jsScriptEngine,
            @NonNull String jsScript,
@@ -241,6 +277,21 @@ public class JSScriptEnginePerfTests {
        return futureResult.get();
    }

    private String callJSEngine(
            @NonNull JSScriptEngine jsScriptEngine,
            @NonNull String jsScript,
            @NonNull byte[] wasmScript,
            @NonNull List<JSScriptArgument> args,
            @NonNull String functionName)
            throws Exception {
        CountDownLatch resultLatch = new CountDownLatch(1);
        ListenableFuture<String> futureResult =
                callJSEngineAsync(
                        jsScriptEngine, jsScript, wasmScript, args, functionName, resultLatch);
        resultLatch.await();
        return futureResult.get();
    }

    private static ListenableFuture<String> callJSEngineAsync(
            @NonNull String jsScript,
            @NonNull List<JSScriptArgument> args,
@@ -261,4 +312,26 @@ public class JSScriptEnginePerfTests {
        result.addListener(resultLatch::countDown, sExecutorService);
        return result;
    }

    private ListenableFuture<String> callJSEngineAsync(
            @NonNull JSScriptEngine engine,
            @NonNull String jsScript,
            @NonNull byte[] wasmScript,
            @NonNull List<JSScriptArgument> args,
            @NonNull String functionName,
            @NonNull CountDownLatch resultLatch) {
        Objects.requireNonNull(engine);
        Objects.requireNonNull(resultLatch);
        ListenableFuture<String> result = engine.evaluate(jsScript, wasmScript, args, functionName);
        result.addListener(resultLatch::countDown, sExecutorService);
        return result;
    }

    private byte[] readBinaryAsset(@NonNull String assetName) throws IOException {
        return sContext.getAssets().open(assetName).readAllBytes();
    }

    private String readAsset(@NonNull String assetName) throws IOException {
        return new String(readBinaryAsset(assetName), StandardCharsets.UTF_8);
    }
}