Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 7713b243 authored by Stefano Galarraga's avatar Stefano Galarraga
Browse files

Add benchmarks for WASM bidding logic

Add new test using new WebView WASM support.
Add JetPack Core dependency to fix the ClassNotFoundException issue when
looking for the ContextCompat class.

Results of running the `evaluate_turtledoveWasm` test on a Pixel 4:

```
evaluate_turtledoveWasm_standardDeviation: 838804
evaluate_turtledoveWasm_min (ns): 21844616
evaluate_turtledoveWasm_median (ns): 22534819
evaluate_turtledoveWasm_mean (ns): 22640871
```

Test: JSScriptEnginePerfTests
Change-Id: Id79234f17258439ccc633cae2e1cbddc356aff99
Bug: 230095467
parent 10f8a27d
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -33,6 +33,7 @@ android_test {
        "compatibility-device-util-axt",
        "platform-test-annotations",
        "adservices-service-core",
        "androidx.core_core",
    ],
    test_suites: ["device-tests"],
    data: [":perfetto_artifacts"],
+1.25 MiB

File added.

No diff preview for this file type.

+24 −0
Original line number Diff line number Diff line
function generateBid(ad, wasmModule) {
  let input = ad.metadata.input;

  const instance = new WebAssembly.Instance(wasmModule);

  const memory = instance.exports.memory;
  const input_in_memory = new Float32Array(memory.buffer, 0, 200);
  for (let i = 0; i < input.length; ++i) {
    input_in_memory[i] = input[i];
  }
  const results = [
    instance.exports.nn_forward_model0(input_in_memory.length, input_in_memory),
    instance.exports.nn_forward_model1(input_in_memory.length, input_in_memory),
    instance.exports.nn_forward_model2(input_in_memory.length, input_in_memory),
    instance.exports.nn_forward_model3(input_in_memory.length, input_in_memory),
    instance.exports.nn_forward_model4(input_in_memory.length, input_in_memory),
  ];
  const bid = results.map(x => Math.max(x, 1)).reduce((x, y) => x * y);
  return {
    ad: 'example',
    bid: bid,
    render: ad.renderUrl
  }
}
 No newline at end of file
+85 −12
Original line number Diff line number Diff line
@@ -24,6 +24,9 @@ import static com.android.adservices.service.js.JSScriptArgument.stringArrayArg;

import static com.google.common.truth.Truth.assertThat;

import static org.junit.Assume.assumeTrue;

import android.annotation.SuppressLint;
import android.content.Context;
import android.perftests.utils.BenchmarkState;
import android.perftests.utils.PerfStatusReporter;
@@ -45,48 +48,44 @@ import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;

import org.json.JSONArray;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/** To run the unit tests for this class, run "atest RubidiumPerfTests:JSScriptEnginePerfTests" */
@MediumTest
@RunWith(AndroidJUnit4.class)
public class JSScriptEnginePerfTests {
    private static final String TAG = JSScriptEnginePerfTests.class.getSimpleName();
    private static final String TAG = JSScriptEngine.TAG;
    private static final Context sContext = ApplicationProvider.getApplicationContext();
    private static final ExecutorService sExecutorService = Executors.newFixedThreadPool(10);

    private static JSScriptEngine sJSScriptEngine;
    private static final JSScriptEngine sJSScriptEngine =
            JSScriptEngine.getInstanceForTesting(
                    sContext, Profiler.createInstance(JSScriptEngine.TAG));

    @Rule public PerfStatusReporter mPerfStatusReporter = new PerfStatusReporter();

    @Before
    public void before() throws Exception {
        Profiler profiler = Profiler.createInstance(JSScriptEngine.TAG);
        sJSScriptEngine = JSScriptEngine.getInstanceForTesting(sContext, profiler);

        // Warm up the sandbox env.
        callJSEngine(
                "function test() { return \"hello world\";" + " }", ImmutableList.of(), "test");
    }

    @After
    public void after() {
        sJSScriptEngine.shutdown();
    }

    @Test
    public void evaluate_helloWorld() throws Exception {
        BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
@@ -156,6 +155,7 @@ public class JSScriptEnginePerfTests {
        runParametrizedTurtledoveScript(75);
    }

    @SuppressLint("DefaultLocale")
    private void runParametrizedTurtledoveScript(int numAds) throws Exception {
        BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
        state.pauseTiming();
@@ -220,7 +220,34 @@ public class JSScriptEnginePerfTests {
        return arrayArg("foo", Collections.nCopies(numCustomAudiences, interestGroupArg));
    }

    private static String callJSEngine(
    @Test
    public void evaluate_turtledoveWasm() throws Exception {
        assumeTrue(sJSScriptEngine.isWasmSupported().get(3, TimeUnit.SECONDS));

        BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
        state.pauseTiming();

        String jsTestFile = readAsset("generate_bid_using_wasm.js");
        byte[] wasmTestFile = readBinaryAsset("generate_bid.wasm");
        JSScriptArgument[] inputBytes = new JSScriptArgument[200];
        Random rand = new Random();
        for (int i = 0; i < inputBytes.length; i++) {
            byte value = (byte) (rand.nextInt(2 * Byte.MAX_VALUE) - Byte.MIN_VALUE);
            inputBytes[i] = JSScriptArgument.numericArg("_", value);
        }
        JSScriptArgument adDataArgument =
                recordArg(
                        "ad",
                        stringArg("render_url", "http://google.com"),
                        recordArg("metadata", JSScriptArgument.arrayArg("input", inputBytes)));

        state.resumeTiming();
        while (state.keepRunning()) {
            callJSEngine(jsTestFile, wasmTestFile, ImmutableList.of(adDataArgument), "generateBid");
        }
    }

    private String callJSEngine(
            @NonNull String jsScript,
            @NonNull List<JSScriptArgument> args,
            @NonNull String functionName)
@@ -228,6 +255,15 @@ public class JSScriptEnginePerfTests {
        return callJSEngine(sJSScriptEngine, jsScript, args, functionName);
    }

    private String callJSEngine(
            @NonNull String jsScript,
            @NonNull byte[] wasmScript,
            @NonNull List<JSScriptArgument> args,
            @NonNull String functionName)
            throws Exception {
        return callJSEngine(sJSScriptEngine, jsScript, wasmScript, args, functionName);
    }

    private static String callJSEngine(
            @NonNull JSScriptEngine jsScriptEngine,
            @NonNull String jsScript,
@@ -241,6 +277,21 @@ public class JSScriptEnginePerfTests {
        return futureResult.get();
    }

    private String callJSEngine(
            @NonNull JSScriptEngine jsScriptEngine,
            @NonNull String jsScript,
            @NonNull byte[] wasmScript,
            @NonNull List<JSScriptArgument> args,
            @NonNull String functionName)
            throws Exception {
        CountDownLatch resultLatch = new CountDownLatch(1);
        ListenableFuture<String> futureResult =
                callJSEngineAsync(
                        jsScriptEngine, jsScript, wasmScript, args, functionName, resultLatch);
        resultLatch.await();
        return futureResult.get();
    }

    private static ListenableFuture<String> callJSEngineAsync(
            @NonNull String jsScript,
            @NonNull List<JSScriptArgument> args,
@@ -261,4 +312,26 @@ public class JSScriptEnginePerfTests {
        result.addListener(resultLatch::countDown, sExecutorService);
        return result;
    }

    private ListenableFuture<String> callJSEngineAsync(
            @NonNull JSScriptEngine engine,
            @NonNull String jsScript,
            @NonNull byte[] wasmScript,
            @NonNull List<JSScriptArgument> args,
            @NonNull String functionName,
            @NonNull CountDownLatch resultLatch) {
        Objects.requireNonNull(engine);
        Objects.requireNonNull(resultLatch);
        ListenableFuture<String> result = engine.evaluate(jsScript, wasmScript, args, functionName);
        result.addListener(resultLatch::countDown, sExecutorService);
        return result;
    }

    private byte[] readBinaryAsset(@NonNull String assetName) throws IOException {
        return sContext.getAssets().open(assetName).readAllBytes();
    }

    private String readAsset(@NonNull String assetName) throws IOException {
        return new String(readBinaryAsset(assetName), StandardCharsets.UTF_8);
    }
}