Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 08674238 authored by Pablo Gamito's avatar Pablo Gamito Committed by Android (Google) Code Review
Browse files

Merge "Set TLS and incremental state lazily from Java" into main

parents 2d7a8466 44d230c4
Loading
Loading
Loading
Loading
+9 −9
Original line number Diff line number Diff line
@@ -18,8 +18,6 @@ package android.tracing.perfetto;

import android.util.proto.ProtoInputStream;

import com.android.internal.annotations.VisibleForTesting;

import dalvik.annotation.optimization.CriticalNative;

/**
@@ -73,7 +71,8 @@ public abstract class DataSource<DataSourceInstanceType extends DataSourceInstan
     * @param fun The tracing lambda that will be called with the tracing contexts of each active
     *            tracing instance.
     */
    public final void trace(TraceFunction<TlsStateType, IncrementalStateType> fun) {
    public final void trace(
            TraceFunction<DataSourceInstanceType, TlsStateType, IncrementalStateType> fun) {
        boolean startedIterator = nativePerfettoDsTraceIterateBegin(mNativeObj);

        if (!startedIterator) {
@@ -82,8 +81,10 @@ public abstract class DataSource<DataSourceInstanceType extends DataSourceInstan

        try {
            do {
                TracingContext<TlsStateType, IncrementalStateType> ctx =
                        new TracingContext<>(mNativeObj);
                int instanceIndex = nativeGetPerfettoDsInstanceIndex(mNativeObj);

                TracingContext<DataSourceInstanceType, TlsStateType, IncrementalStateType> ctx =
                        new TracingContext<>(this, instanceIndex);
                fun.trace(ctx);

                ctx.flush();
@@ -104,9 +105,7 @@ public abstract class DataSource<DataSourceInstanceType extends DataSourceInstan
     * Override this method to create a custom TlsState object for your DataSource. A new instance
     * will be created per trace instance per thread.
     *
     * NOTE: Should only be called from native side.
     */
    @VisibleForTesting
    public TlsStateType createTlsState(CreateTlsStateArgs<DataSourceInstanceType> args) {
        return null;
    }
@@ -114,9 +113,8 @@ public abstract class DataSource<DataSourceInstanceType extends DataSourceInstan
    /**
     * Override this method to create and use a custom IncrementalState object for your DataSource.
     *
     * NOTE: Should only be called from native side.
     */
    protected IncrementalStateType createIncrementalState(
    public IncrementalStateType createIncrementalState(
            CreateIncrementalStateArgs<DataSourceInstanceType> args) {
        return null;
    }
@@ -185,4 +183,6 @@ public abstract class DataSource<DataSourceInstanceType extends DataSourceInstan
    private static native boolean nativePerfettoDsTraceIterateNext(long dataSourcePtr);
    @CriticalNative
    private static native void nativePerfettoDsTraceIterateBreak(long dataSourcePtr);
    @CriticalNative
    private static native int nativeGetPerfettoDsInstanceIndex(long dataSourcePtr);
}
+4 −2
Original line number Diff line number Diff line
@@ -19,12 +19,14 @@ package android.tracing.perfetto;
/**
 * The interface for the trace function called from native on a trace call with a context.
 *
 * @param <DataSourceInstanceType> The type of DataSource this tracing context is for.
 * @param <TlsStateType> The type of the custom TLS state, if any is used.
 * @param <IncrementalStateType> The type of the custom incremental state, if any is used.
 *
 * @hide
 */
public interface TraceFunction<TlsStateType, IncrementalStateType> {
public interface TraceFunction<DataSourceInstanceType extends DataSourceInstance,
        TlsStateType, IncrementalStateType> {

    /**
     * This function will be called synchronously (i.e., always before trace() returns) only if
@@ -34,5 +36,5 @@ public interface TraceFunction<TlsStateType, IncrementalStateType> {
     *
     * @param ctx the tracing context to trace for in the trace function.
     */
    void trace(TracingContext<TlsStateType, IncrementalStateType> ctx);
    void trace(TracingContext<DataSourceInstanceType, TlsStateType, IncrementalStateType> ctx);
}
+36 −8
Original line number Diff line number Diff line
@@ -24,18 +24,25 @@ import java.util.List;
/**
 * Argument passed to the lambda function passed to Trace().
 *
 * @param <DataSourceInstanceType> The type of the datasource this tracing context is for.
 * @param <TlsStateType> The type of the custom TLS state, if any is used.
 * @param <IncrementalStateType> The type of the custom incremental state, if any is used.
 *
 * @hide
 */
public class TracingContext<TlsStateType, IncrementalStateType> {
public class TracingContext<DataSourceInstanceType extends DataSourceInstance, TlsStateType,
        IncrementalStateType> {

    private final long mNativeDsPtr;
    private final DataSource<DataSourceInstanceType, TlsStateType, IncrementalStateType>
            mDataSource;
    private final int mInstanceIndex;
    private final List<ProtoOutputStream> mTracePackets = new ArrayList<>();

    TracingContext(long nativeDsPtr) {
        this.mNativeDsPtr = nativeDsPtr;
    TracingContext(DataSource<DataSourceInstanceType, TlsStateType, IncrementalStateType>
            dataSource,
            int instanceIndex) {
        this.mDataSource = dataSource;
        this.mInstanceIndex = instanceIndex;
    }

    /**
@@ -61,18 +68,26 @@ public class TracingContext<TlsStateType, IncrementalStateType> {
     * Stop timeout expires.
     */
    public void flush() {
        nativeFlush(mNativeDsPtr, getAndClearAllPendingTracePackets());
        nativeFlush(mDataSource.mNativeObj, getAndClearAllPendingTracePackets());
    }

    /**
     * Can optionally be used to store custom per-sequence
     * session data, which is not reset when incremental state is cleared
     * (e.g. configuration options).
     *
     *h
     * @return The TlsState instance for the tracing thread and instance.
     */
    public TlsStateType getCustomTlsState() {
        return (TlsStateType) nativeGetCustomTls(mNativeDsPtr);
        TlsStateType tlsState = (TlsStateType) nativeGetCustomTls(mDataSource.mNativeObj);
        if (tlsState == null) {
            final CreateTlsStateArgs<DataSourceInstanceType> args =
                    new CreateTlsStateArgs<>(mDataSource, mInstanceIndex);
            tlsState = mDataSource.createTlsState(args);
            nativeSetCustomTls(mDataSource.mNativeObj, tlsState);
        }

        return tlsState;
    }

    /**
@@ -82,7 +97,16 @@ public class TracingContext<TlsStateType, IncrementalStateType> {
     * @return The current IncrementalState object instance.
     */
    public IncrementalStateType getIncrementalState() {
        return (IncrementalStateType) nativeGetIncrementalState(mNativeDsPtr);
        IncrementalStateType incrementalState =
                (IncrementalStateType) nativeGetIncrementalState(mDataSource.mNativeObj);
        if (incrementalState == null) {
            final CreateIncrementalStateArgs<DataSourceInstanceType> args =
                    new CreateIncrementalStateArgs<>(mDataSource, mInstanceIndex);
            incrementalState = mDataSource.createIncrementalState(args);
            nativeSetIncrementalState(mDataSource.mNativeObj, incrementalState);
        }

        return incrementalState;
    }

    private byte[][] getAndClearAllPendingTracePackets() {
@@ -97,6 +121,10 @@ public class TracingContext<TlsStateType, IncrementalStateType> {
    }

    private static native void nativeFlush(long dataSourcePtr, byte[][] packetData);

    private static native Object nativeGetCustomTls(long nativeDsPtr);
    private static native void nativeSetCustomTls(long nativeDsPtr, Object tlsState);

    private static native Object nativeGetIncrementalState(long nativeDsPtr);
    private static native void nativeSetIncrementalState(long nativeDsPtr, Object incrementalState);
}
+5 −2
Original line number Diff line number Diff line
@@ -440,6 +440,7 @@ public class PerfettoProtoLogImpl implements IProtoLog {
    }

    private int internStacktraceString(TracingContext<
            ProtoLogDataSource.Instance,
            ProtoLogDataSource.TlsState,
            ProtoLogDataSource.IncrementalState> ctx,
            String stacktrace) {
@@ -449,7 +450,8 @@ public class PerfettoProtoLogImpl implements IProtoLog {
    }

    private int internStringArg(
            TracingContext<ProtoLogDataSource.TlsState, ProtoLogDataSource.IncrementalState> ctx,
            TracingContext<ProtoLogDataSource.Instance, ProtoLogDataSource.TlsState,
                    ProtoLogDataSource.IncrementalState> ctx,
            String string
    ) {
        final ProtoLogDataSource.IncrementalState incrementalState = ctx.getIncrementalState();
@@ -458,7 +460,8 @@ public class PerfettoProtoLogImpl implements IProtoLog {
    }

    private int internString(
            TracingContext<ProtoLogDataSource.TlsState, ProtoLogDataSource.IncrementalState> ctx,
            TracingContext<ProtoLogDataSource.Instance, ProtoLogDataSource.TlsState,
                    ProtoLogDataSource.IncrementalState> ctx,
            Map<String, Integer> internMap,
            long fieldId,
            String string
+66 −60
Original line number Diff line number Diff line
@@ -93,49 +93,6 @@ jobject PerfettoDataSource::newInstance(JNIEnv* env, void* ds_config, size_t ds_
    return instance;
}

jobject PerfettoDataSource::createTlsStateGlobalRef(JNIEnv* env, PerfettoDsInstanceIndex inst_id) {
    ScopedLocalRef<jobject> args(env,
                                 env->NewObject(gCreateTlsStateArgsClassInfo.clazz,
                                                gCreateTlsStateArgsClassInfo.init, mJavaDataSource,
                                                inst_id));

    ScopedLocalRef<jobject> tslState(env,
                                     env->CallObjectMethod(mJavaDataSource,
                                                           gPerfettoDataSourceClassInfo
                                                                   .createTlsState,
                                                           args.get()));

    if (env->ExceptionCheck()) {
        LOGE_EX(env);
        env->ExceptionClear();
        LOG_ALWAYS_FATAL("Failed to create new Java Perfetto incremental state");
    }

    return env->NewGlobalRef(tslState.get());
}

jobject PerfettoDataSource::createIncrementalStateGlobalRef(JNIEnv* env,
                                                            PerfettoDsInstanceIndex inst_id) {
    ScopedLocalRef<jobject> args(env,
                                 env->NewObject(gCreateIncrementalStateArgsClassInfo.clazz,
                                                gCreateIncrementalStateArgsClassInfo.init,
                                                mJavaDataSource, inst_id));

    ScopedLocalRef<jobject> incrementalState(env,
                                             env->CallObjectMethod(mJavaDataSource,
                                                                   gPerfettoDataSourceClassInfo
                                                                           .createIncrementalState,
                                                                   args.get()));

    if (env->ExceptionCheck()) {
        LOGE_EX(env);
        env->ExceptionClear();
        LOG_ALWAYS_FATAL("Failed to create Java Perfetto incremental state");
    }

    return env->NewGlobalRef(incrementalState.get());
}

bool PerfettoDataSource::TraceIterateBegin() {
    if (gInIteration) {
        return false;
@@ -177,6 +134,15 @@ void PerfettoDataSource::TraceIterateBreak() {
    gInIteration = false;
}

PerfettoDsInstanceIndex PerfettoDataSource::GetInstanceIndex() {
    if (!gInIteration) {
        LOG_ALWAYS_FATAL("Tried calling GetInstanceIndex outside of a tracer iteration.");
        return -1;
    }

    return gIterator.impl.inst_id;
}

jobject PerfettoDataSource::GetCustomTls() {
    if (!gInIteration) {
        LOG_ALWAYS_FATAL("Tried getting CustomTls outside of a tracer iteration.");
@@ -189,6 +155,18 @@ jobject PerfettoDataSource::GetCustomTls() {
    return tls_state->jobj;
}

void PerfettoDataSource::SetCustomTls(jobject tlsState) {
    if (!gInIteration) {
        LOG_ALWAYS_FATAL("Tried getting CustomTls outside of a tracer iteration.");
        return;
    }

    TlsState* tls_state =
            reinterpret_cast<TlsState*>(PerfettoDsGetCustomTls(&dataSource, &gIterator));

    tls_state->jobj = tlsState;
}

jobject PerfettoDataSource::GetIncrementalState() {
    if (!gInIteration) {
        LOG_ALWAYS_FATAL("Tried getting IncrementalState outside of a tracer iteration.");
@@ -201,6 +179,18 @@ jobject PerfettoDataSource::GetIncrementalState() {
    return incr_state->jobj;
}

void PerfettoDataSource::SetIncrementalState(jobject incrementalState) {
    if (!gInIteration) {
        LOG_ALWAYS_FATAL("Tried getting IncrementalState outside of a tracer iteration.");
        return;
    }

    IncrementalState* incr_state = reinterpret_cast<IncrementalState*>(
            PerfettoDsGetIncrementalState(&dataSource, &gIterator));

    incr_state->jobj = incrementalState;
}

void PerfettoDataSource::WritePackets(JNIEnv* env, jobjectArray packets) {
    if (!gInIteration) {
        LOG_ALWAYS_FATAL("Tried writing packets outside of a tracer iteration.");
@@ -264,7 +254,7 @@ void nativeFlushAll(JNIEnv* env, jclass clazz, jlong ptr) {
}

void nativeRegisterDataSource(JNIEnv* env, jclass clazz, jlong datasource_ptr,
                              int buffer_exhausted_policy) {
                              jint buffer_exhausted_policy) {
    sp<PerfettoDataSource> datasource = reinterpret_cast<PerfettoDataSource*>(datasource_ptr);

    struct PerfettoDsParams params = PerfettoDsParamsDefault();
@@ -291,13 +281,8 @@ void nativeRegisterDataSource(JNIEnv* env, jclass clazz, jlong datasource_ptr,

    params.on_create_tls_cb = [](struct PerfettoDsImpl* ds_impl, PerfettoDsInstanceIndex inst_id,
                                 struct PerfettoDsTracerImpl* tracer, void* user_arg) -> void* {
        JNIEnv* env = GetOrAttachJNIEnvironment(gVm, JNI_VERSION_1_6);

        auto* datasource = reinterpret_cast<PerfettoDataSource*>(user_arg);

        jobject java_tls_state = datasource->createTlsStateGlobalRef(env, inst_id);

        auto* tls_state = new TlsState(java_tls_state);
        // Populated later and only if required by the java side
        auto* tls_state = new TlsState(NULL);
        return static_cast<void*>(tls_state);
    };

@@ -306,18 +291,16 @@ void nativeRegisterDataSource(JNIEnv* env, jclass clazz, jlong datasource_ptr,

        TlsState* tls_state = reinterpret_cast<TlsState*>(ptr);

        if (tls_state->jobj != NULL) {
            env->DeleteGlobalRef(tls_state->jobj);
        }
        delete tls_state;
    };

    params.on_create_incr_cb = [](struct PerfettoDsImpl* ds_impl, PerfettoDsInstanceIndex inst_id,
                                  struct PerfettoDsTracerImpl* tracer, void* user_arg) -> void* {
        JNIEnv* env = GetOrAttachJNIEnvironment(gVm, JNI_VERSION_1_6);

        auto* datasource = reinterpret_cast<PerfettoDataSource*>(user_arg);
        jobject java_incr_state = datasource->createIncrementalStateGlobalRef(env, inst_id);

        auto* incr_state = new IncrementalState(java_incr_state);
        // Populated later and only if required by the java side
        auto* incr_state = new IncrementalState(NULL);
        return static_cast<void*>(incr_state);
    };

@@ -326,7 +309,9 @@ void nativeRegisterDataSource(JNIEnv* env, jclass clazz, jlong datasource_ptr,

        IncrementalState* incr_state = reinterpret_cast<IncrementalState*>(ptr);

        if (incr_state->jobj != NULL) {
            env->DeleteGlobalRef(incr_state->jobj);
        }
        delete incr_state;
    };

@@ -401,16 +386,34 @@ void nativePerfettoDsTraceIterateBreak(jlong dataSourcePtr) {
    return datasource->TraceIterateBreak();
}

jint nativeGetPerfettoDsInstanceIndex(jlong dataSourcePtr) {
    sp<PerfettoDataSource> datasource = reinterpret_cast<PerfettoDataSource*>(dataSourcePtr);
    return (jint)datasource->GetInstanceIndex();
}

jobject nativeGetCustomTls(JNIEnv* /* env */, jclass /* clazz */, jlong dataSourcePtr) {
    sp<PerfettoDataSource> datasource = reinterpret_cast<PerfettoDataSource*>(dataSourcePtr);
    return datasource->GetCustomTls();
}

void nativeSetCustomTls(JNIEnv* env, jclass /* clazz */, jlong dataSourcePtr, jobject tlsState) {
    sp<PerfettoDataSource> datasource = reinterpret_cast<PerfettoDataSource*>(dataSourcePtr);
    tlsState = env->NewGlobalRef(tlsState);
    return datasource->SetCustomTls(tlsState);
}

jobject nativeGetIncrementalState(JNIEnv* /* env */, jclass /* clazz */, jlong dataSourcePtr) {
    sp<PerfettoDataSource> datasource = reinterpret_cast<PerfettoDataSource*>(dataSourcePtr);
    return datasource->GetIncrementalState();
}

void nativeSetIncrementalState(JNIEnv* env, jclass /* clazz */, jlong dataSourcePtr,
                               jobject incrementalState) {
    sp<PerfettoDataSource> datasource = reinterpret_cast<PerfettoDataSource*>(dataSourcePtr);
    incrementalState = env->NewGlobalRef(incrementalState);
    return datasource->SetIncrementalState(incrementalState);
}

const JNINativeMethod gMethods[] = {
        /* name, signature, funcPtr */
        {"nativeCreate", "(Landroid/tracing/perfetto/DataSource;Ljava/lang/String;)J",
@@ -425,13 +428,16 @@ const JNINativeMethod gMethods[] = {

        {"nativePerfettoDsTraceIterateBegin", "(J)Z", (void*)nativePerfettoDsTraceIterateBegin},
        {"nativePerfettoDsTraceIterateNext", "(J)Z", (void*)nativePerfettoDsTraceIterateNext},
        {"nativePerfettoDsTraceIterateBreak", "(J)V", (void*)nativePerfettoDsTraceIterateBreak}};
        {"nativePerfettoDsTraceIterateBreak", "(J)V", (void*)nativePerfettoDsTraceIterateBreak},
        {"nativeGetPerfettoDsInstanceIndex", "(J)I", (void*)nativeGetPerfettoDsInstanceIndex}};

const JNINativeMethod gMethodsTracingContext[] = {
        /* name, signature, funcPtr */
        {"nativeFlush", "(J[[B)V", (void*)nativeFlush},
        {"nativeGetCustomTls", "(J)Ljava/lang/Object;", (void*)nativeGetCustomTls},
        {"nativeGetIncrementalState", "(J)Ljava/lang/Object;", (void*)nativeGetIncrementalState},
        {"nativeSetCustomTls", "(JLjava/lang/Object;)V", (void*)nativeSetCustomTls},
        {"nativeSetIncrementalState", "(JLjava/lang/Object;)V", (void*)nativeSetIncrementalState},
};

int register_android_tracing_PerfettoDataSource(JNIEnv* env) {
Loading