Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 5c7f06a2 authored by TreeHugger Robot's avatar TreeHugger Robot Committed by Android (Google) Code Review
Browse files

Merge "Add benchmarks for WASM bidding logic" into tm-mainline-prod

parents a2a2d0bd 7713b243
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -33,6 +33,7 @@ android_test {
        "compatibility-device-util-axt",
        "platform-test-annotations",
        "adservices-service-core",
        "androidx.core_core",
    ],
    test_suites: ["device-tests"],
    data: [":perfetto_artifacts"],
+1.25 MiB

File added.

No diff preview for this file type.

+24 −0
Original line number Diff line number Diff line
function generateBid(ad, wasmModule) {
  let input = ad.metadata.input;

  const instance = new WebAssembly.Instance(wasmModule);

  const memory = instance.exports.memory;
  const input_in_memory = new Float32Array(memory.buffer, 0, 200);
  for (let i = 0; i < input.length; ++i) {
    input_in_memory[i] = input[i];
  }
  const results = [
    instance.exports.nn_forward_model0(input_in_memory.length, input_in_memory),
    instance.exports.nn_forward_model1(input_in_memory.length, input_in_memory),
    instance.exports.nn_forward_model2(input_in_memory.length, input_in_memory),
    instance.exports.nn_forward_model3(input_in_memory.length, input_in_memory),
    instance.exports.nn_forward_model4(input_in_memory.length, input_in_memory),
  ];
  const bid = results.map(x => Math.max(x, 1)).reduce((x, y) => x * y);
  return {
    ad: 'example',
    bid: bid,
    render: ad.renderUrl
  }
}
 No newline at end of file
+85 −12
Original line number Diff line number Diff line
@@ -24,6 +24,9 @@ import static com.android.adservices.service.js.JSScriptArgument.stringArrayArg;

import static com.google.common.truth.Truth.assertThat;

import static org.junit.Assume.assumeTrue;

import android.annotation.SuppressLint;
import android.content.Context;
import android.perftests.utils.BenchmarkState;
import android.perftests.utils.PerfStatusReporter;
@@ -45,48 +48,44 @@ import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;

import org.json.JSONArray;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/** To run the unit tests for this class, run "atest RubidiumPerfTests:JSScriptEnginePerfTests" */
@MediumTest
@RunWith(AndroidJUnit4.class)
public class JSScriptEnginePerfTests {
    private static final String TAG = JSScriptEnginePerfTests.class.getSimpleName();
    private static final String TAG = JSScriptEngine.TAG;
    private static final Context sContext = ApplicationProvider.getApplicationContext();
    private static final ExecutorService sExecutorService = Executors.newFixedThreadPool(10);

    private static JSScriptEngine sJSScriptEngine;
    private static final JSScriptEngine sJSScriptEngine =
            JSScriptEngine.getInstanceForTesting(
                    sContext, Profiler.createInstance(JSScriptEngine.TAG));

    @Rule public PerfStatusReporter mPerfStatusReporter = new PerfStatusReporter();

    @Before
    public void before() throws Exception {
        Profiler profiler = Profiler.createInstance(JSScriptEngine.TAG);
        sJSScriptEngine = JSScriptEngine.getInstanceForTesting(sContext, profiler);

        // Warm up the sandbox env.
        callJSEngine(
                "function test() { return \"hello world\";" + " }", ImmutableList.of(), "test");
    }

    @After
    public void after() {
        sJSScriptEngine.shutdown();
    }

    @Test
    public void evaluate_helloWorld() throws Exception {
        BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
@@ -156,6 +155,7 @@ public class JSScriptEnginePerfTests {
        runParametrizedTurtledoveScript(75);
    }

    @SuppressLint("DefaultLocale")
    private void runParametrizedTurtledoveScript(int numAds) throws Exception {
        BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
        state.pauseTiming();
@@ -220,7 +220,34 @@ public class JSScriptEnginePerfTests {
        return arrayArg("foo", Collections.nCopies(numCustomAudiences, interestGroupArg));
    }

    private static String callJSEngine(
    @Test
    public void evaluate_turtledoveWasm() throws Exception {
        assumeTrue(sJSScriptEngine.isWasmSupported().get(3, TimeUnit.SECONDS));

        BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
        state.pauseTiming();

        String jsTestFile = readAsset("generate_bid_using_wasm.js");
        byte[] wasmTestFile = readBinaryAsset("generate_bid.wasm");
        JSScriptArgument[] inputBytes = new JSScriptArgument[200];
        Random rand = new Random();
        for (int i = 0; i < inputBytes.length; i++) {
            byte value = (byte) (rand.nextInt(2 * Byte.MAX_VALUE) - Byte.MIN_VALUE);
            inputBytes[i] = JSScriptArgument.numericArg("_", value);
        }
        JSScriptArgument adDataArgument =
                recordArg(
                        "ad",
                        stringArg("render_url", "http://google.com"),
                        recordArg("metadata", JSScriptArgument.arrayArg("input", inputBytes)));

        state.resumeTiming();
        while (state.keepRunning()) {
            callJSEngine(jsTestFile, wasmTestFile, ImmutableList.of(adDataArgument), "generateBid");
        }
    }

    private String callJSEngine(
            @NonNull String jsScript,
            @NonNull List<JSScriptArgument> args,
            @NonNull String functionName)
@@ -228,6 +255,15 @@ public class JSScriptEnginePerfTests {
        return callJSEngine(sJSScriptEngine, jsScript, args, functionName);
    }

    private String callJSEngine(
            @NonNull String jsScript,
            @NonNull byte[] wasmScript,
            @NonNull List<JSScriptArgument> args,
            @NonNull String functionName)
            throws Exception {
        return callJSEngine(sJSScriptEngine, jsScript, wasmScript, args, functionName);
    }

    private static String callJSEngine(
            @NonNull JSScriptEngine jsScriptEngine,
            @NonNull String jsScript,
@@ -241,6 +277,21 @@ public class JSScriptEnginePerfTests {
        return futureResult.get();
    }

    private String callJSEngine(
            @NonNull JSScriptEngine jsScriptEngine,
            @NonNull String jsScript,
            @NonNull byte[] wasmScript,
            @NonNull List<JSScriptArgument> args,
            @NonNull String functionName)
            throws Exception {
        CountDownLatch resultLatch = new CountDownLatch(1);
        ListenableFuture<String> futureResult =
                callJSEngineAsync(
                        jsScriptEngine, jsScript, wasmScript, args, functionName, resultLatch);
        resultLatch.await();
        return futureResult.get();
    }

    private static ListenableFuture<String> callJSEngineAsync(
            @NonNull String jsScript,
            @NonNull List<JSScriptArgument> args,
@@ -261,4 +312,26 @@ public class JSScriptEnginePerfTests {
        result.addListener(resultLatch::countDown, sExecutorService);
        return result;
    }

    private ListenableFuture<String> callJSEngineAsync(
            @NonNull JSScriptEngine engine,
            @NonNull String jsScript,
            @NonNull byte[] wasmScript,
            @NonNull List<JSScriptArgument> args,
            @NonNull String functionName,
            @NonNull CountDownLatch resultLatch) {
        Objects.requireNonNull(engine);
        Objects.requireNonNull(resultLatch);
        ListenableFuture<String> result = engine.evaluate(jsScript, wasmScript, args, functionName);
        result.addListener(resultLatch::countDown, sExecutorService);
        return result;
    }

    private byte[] readBinaryAsset(@NonNull String assetName) throws IOException {
        return sContext.getAssets().open(assetName).readAllBytes();
    }

    private String readAsset(@NonNull String assetName) throws IOException {
        return new String(readBinaryAsset(assetName), StandardCharsets.UTF_8);
    }
}