Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 6e2770ef authored by Phil Burk's avatar Phil Burk
Browse files

aaudio: use weak pointer to prevent UAF

Avoid using the mServiceEndpoint smart pointer
from multiple threads.

Bug: 74122779
Test: see bug for test instructions
Change-Id: Idaf9e32a163b25e51bde35d6f5ea10a372b5d916
parent e5a37268
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -49,9 +49,9 @@ public:

    virtual aaudio_result_t close() = 0;

    virtual aaudio_result_t registerStream(android::sp<AAudioServiceStreamBase> stream);
    aaudio_result_t registerStream(android::sp<AAudioServiceStreamBase> stream);

    virtual aaudio_result_t unregisterStream(android::sp<AAudioServiceStreamBase> stream);
    aaudio_result_t unregisterStream(android::sp<AAudioServiceStreamBase> stream);

    virtual aaudio_result_t startStream(android::sp<AAudioServiceStreamBase> stream,
                                        audio_port_handle_t *clientHandle) = 0;
+34 −24
Original line number Diff line number Diff line
@@ -105,6 +105,9 @@ aaudio_result_t AAudioServiceStreamBase::open(const aaudio::AAudioStreamRequest
            goto error;
        }

        // This is not protected by a lock because the stream cannot be
        // referenced until the service returns a handle to the client.
        // So only one thread can open a stream.
        mServiceEndpoint = mEndpointManager.openEndpoint(mAudioService,
                                                         request,
                                                         sharingMode);
@@ -113,6 +116,9 @@ aaudio_result_t AAudioServiceStreamBase::open(const aaudio::AAudioStreamRequest
            result = AAUDIO_ERROR_UNAVAILABLE;
            goto error;
        }
        // Save a weak pointer that we will use to access the endpoint.
        mServiceEndpointWeak = mServiceEndpoint;

        mFramesPerBurst = mServiceEndpoint->getFramesPerBurst();
        copyFrom(*mServiceEndpoint);
    }
@@ -131,13 +137,16 @@ aaudio_result_t AAudioServiceStreamBase::close() {

    stop();

    if (mServiceEndpoint == nullptr) {
    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        result = AAUDIO_ERROR_INVALID_STATE;
    } else {
        mServiceEndpoint->unregisterStream(this);
        AAudioEndpointManager &mEndpointManager = AAudioEndpointManager::getInstance();
        mEndpointManager.closeEndpoint(mServiceEndpoint);
        mServiceEndpoint.clear();
        endpoint->unregisterStream(this);
        AAudioEndpointManager &endpointManager = AAudioEndpointManager::getInstance();
        endpointManager.closeEndpoint(endpoint);

        // AAudioService::closeStream() prevents two threads from closing at the same time.
        mServiceEndpoint.clear(); // endpoint will hold the pointer until this method returns.
    }

    {
@@ -153,7 +162,12 @@ aaudio_result_t AAudioServiceStreamBase::close() {

aaudio_result_t AAudioServiceStreamBase::startDevice() {
    mClientHandle = AUDIO_PORT_HANDLE_NONE;
    return mServiceEndpoint->startStream(this, &mClientHandle);
    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }
    return endpoint->startStream(this, &mClientHandle);
}

/**
@@ -163,16 +177,11 @@ aaudio_result_t AAudioServiceStreamBase::startDevice() {
 */
aaudio_result_t AAudioServiceStreamBase::start() {
    aaudio_result_t result = AAUDIO_OK;

    if (isRunning()) {
        return AAUDIO_OK;
    }

    if (mServiceEndpoint == nullptr) {
        ALOGE("%s() missing endpoint", __func__);
        result = AAUDIO_ERROR_INVALID_STATE;
        goto error;
    }

    setFlowing(false);

    // Start with fresh presentation timestamps.
@@ -201,10 +210,6 @@ aaudio_result_t AAudioServiceStreamBase::pause() {
    if (!isRunning()) {
        return result;
    }
    if (mServiceEndpoint == nullptr) {
        ALOGE("%s() missing endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }

    // Send it now because the timestamp gets rounded up when stopStream() is called below.
    // Also we don't need the timestamps while we are shutting down.
@@ -216,7 +221,12 @@ aaudio_result_t AAudioServiceStreamBase::pause() {
        return result;
    }

    result = mServiceEndpoint->stopStream(this, mClientHandle);
    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }
    result = endpoint->stopStream(this, mClientHandle);
    if (result != AAUDIO_OK) {
        ALOGE("%s() mServiceEndpoint returned %d, %s", __func__, result, getTypeText());
        disconnect(); // TODO should we return or pause Base first?
@@ -233,11 +243,6 @@ aaudio_result_t AAudioServiceStreamBase::stop() {
        return result;
    }

    if (mServiceEndpoint == nullptr) {
        ALOGE("%s() missing endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }

    setState(AAUDIO_STREAM_STATE_STOPPING);

    // Send it now because the timestamp gets rounded up when stopStream() is called below.
@@ -249,10 +254,15 @@ aaudio_result_t AAudioServiceStreamBase::stop() {
        return result;
    }

    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }
    // TODO wait for data to be played out
    result = mServiceEndpoint->stopStream(this, mClientHandle);
    result = endpoint->stopStream(this, mClientHandle);
    if (result != AAUDIO_OK) {
        ALOGE("%s() mServiceEndpoint returned %d, %s", __func__, result, getTypeText());
        ALOGE("%s() stopStream returned %d, %s", __func__, result, getTypeText());
        disconnect();
        // TODO what to do with result here?
    }
+5 −0
Original line number Diff line number Diff line
@@ -279,7 +279,12 @@ protected:
    SimpleDoubleBuffer<Timestamp>  mAtomicTimestamp;

    android::AAudioService &mAudioService;

    // The mServiceEndpoint variable can be accessed by multiple threads.
    // So we access it by locally promoting a weak pointer to a smart pointer,
    // which is thread-safe.
    android::sp<AAudioServiceEndpoint> mServiceEndpoint;
    android::wp<AAudioServiceEndpoint> mServiceEndpointWeak;

private:
    aaudio_handle_t         mHandle = -1;
+44 −11
Original line number Diff line number Diff line
@@ -70,14 +70,19 @@ aaudio_result_t AAudioServiceStreamMMAP::open(const aaudio::AAudioStreamRequest
        return result;
    }

    result = mServiceEndpoint->registerStream(keep);
    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }

    result = endpoint->registerStream(keep);
    if (result != AAUDIO_OK) {
        goto error;
        return result;
    }

    setState(AAUDIO_STREAM_STATE_OPEN);

error:
    return AAUDIO_OK;
}

@@ -118,21 +123,37 @@ aaudio_result_t AAudioServiceStreamMMAP::stop() {

aaudio_result_t AAudioServiceStreamMMAP::startClient(const android::AudioClient& client,
                                                       audio_port_handle_t *clientHandle) {
    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }
    // Start the client on behalf of the application. Generate a new porthandle.
    aaudio_result_t result = mServiceEndpoint->startClient(client, clientHandle);
    aaudio_result_t result = endpoint->startClient(client, clientHandle);
    return result;
}

aaudio_result_t AAudioServiceStreamMMAP::stopClient(audio_port_handle_t clientHandle) {
    aaudio_result_t result = mServiceEndpoint->stopClient(clientHandle);
    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }
    aaudio_result_t result = endpoint->stopClient(clientHandle);
    return result;
}

// Get free-running DSP or DMA hardware position from the HAL.
aaudio_result_t AAudioServiceStreamMMAP::getFreeRunningPosition(int64_t *positionFrames,
                                                                  int64_t *timeNanos) {
    sp<AAudioServiceEndpointMMAP> serviceEndpointMMAP{
            static_cast<AAudioServiceEndpointMMAP *>(mServiceEndpoint.get())};
    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }
    sp<AAudioServiceEndpointMMAP> serviceEndpointMMAP =
            static_cast<AAudioServiceEndpointMMAP *>(endpoint.get());

    aaudio_result_t result = serviceEndpointMMAP->getFreeRunningPosition(positionFrames, timeNanos);
    if (result == AAUDIO_OK) {
        Timestamp timestamp(*positionFrames, *timeNanos);
@@ -148,8 +169,15 @@ aaudio_result_t AAudioServiceStreamMMAP::getFreeRunningPosition(int64_t *positio
// Get timestamp that was written by getFreeRunningPosition()
aaudio_result_t AAudioServiceStreamMMAP::getHardwareTimestamp(int64_t *positionFrames,
                                                                int64_t *timeNanos) {
    sp<AAudioServiceEndpointMMAP> serviceEndpointMMAP{
            static_cast<AAudioServiceEndpointMMAP *>(mServiceEndpoint.get())};

    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }
    sp<AAudioServiceEndpointMMAP> serviceEndpointMMAP =
            static_cast<AAudioServiceEndpointMMAP *>(endpoint.get());

    // TODO Get presentation timestamp from the HAL
    if (mAtomicTimestamp.isValid()) {
        Timestamp timestamp = mAtomicTimestamp.read();
@@ -165,7 +193,12 @@ aaudio_result_t AAudioServiceStreamMMAP::getHardwareTimestamp(int64_t *positionF
aaudio_result_t AAudioServiceStreamMMAP::getAudioDataDescription(
        AudioEndpointParcelable &parcelable)
{
    sp<AAudioServiceEndpointMMAP> serviceEndpointMMAP{
            static_cast<AAudioServiceEndpointMMAP *>(mServiceEndpoint.get())};
    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }
    sp<AAudioServiceEndpointMMAP> serviceEndpointMMAP =
            static_cast<AAudioServiceEndpointMMAP *>(endpoint.get());
    return serviceEndpointMMAP->getDownDataDescription(parcelable);
}
+23 −8
Original line number Diff line number Diff line
@@ -128,6 +128,12 @@ aaudio_result_t AAudioServiceStreamShared::open(const aaudio::AAudioStreamReques

    const AAudioStreamConfiguration &configurationInput = request.getConstantConfiguration();

    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        result = AAUDIO_ERROR_INVALID_STATE;
        goto error;
    }

    // Is the request compatible with the shared endpoint?
    setFormat(configurationInput.getFormat());
    if (getFormat() == AAUDIO_FORMAT_UNSPECIFIED) {
@@ -140,20 +146,20 @@ aaudio_result_t AAudioServiceStreamShared::open(const aaudio::AAudioStreamReques

    setSampleRate(configurationInput.getSampleRate());
    if (getSampleRate() == AAUDIO_UNSPECIFIED) {
        setSampleRate(mServiceEndpoint->getSampleRate());
    } else if (getSampleRate() != mServiceEndpoint->getSampleRate()) {
        setSampleRate(endpoint->getSampleRate());
    } else if (getSampleRate() != endpoint->getSampleRate()) {
        ALOGD("%s() mSampleRate = %d, need %d",
              __func__, getSampleRate(), mServiceEndpoint->getSampleRate());
              __func__, getSampleRate(), endpoint->getSampleRate());
        result = AAUDIO_ERROR_INVALID_RATE;
        goto error;
    }

    setSamplesPerFrame(configurationInput.getSamplesPerFrame());
    if (getSamplesPerFrame() == AAUDIO_UNSPECIFIED) {
        setSamplesPerFrame(mServiceEndpoint->getSamplesPerFrame());
    } else if (getSamplesPerFrame() != mServiceEndpoint->getSamplesPerFrame()) {
        setSamplesPerFrame(endpoint->getSamplesPerFrame());
    } else if (getSamplesPerFrame() != endpoint->getSamplesPerFrame()) {
        ALOGD("%s() mSamplesPerFrame = %d, need %d",
              __func__, getSamplesPerFrame(), mServiceEndpoint->getSamplesPerFrame());
              __func__, getSamplesPerFrame(), endpoint->getSamplesPerFrame());
        result = AAUDIO_ERROR_OUT_OF_RANGE;
        goto error;
    }
@@ -179,7 +185,10 @@ aaudio_result_t AAudioServiceStreamShared::open(const aaudio::AAudioStreamReques
        }
    }

    result = mServiceEndpoint->registerStream(keep);
    ALOGD("AAudioServiceStreamShared::open() actual rate = %d, channels = %d, deviceId = %d",
          getSampleRate(), getSamplesPerFrame(), endpoint->getDeviceId());

    result = endpoint->registerStream(keep);
    if (result != AAUDIO_OK) {
        goto error;
    }
@@ -246,7 +255,13 @@ aaudio_result_t AAudioServiceStreamShared::getHardwareTimestamp(int64_t *positio
                                                                int64_t *timeNanos) {

    int64_t position = 0;
    aaudio_result_t result = mServiceEndpoint->getTimestamp(&position, timeNanos);
    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }

    aaudio_result_t result = endpoint->getTimestamp(&position, timeNanos);
    if (result == AAUDIO_OK) {
        int64_t offset = mTimestampPositionOffset.load();
        // TODO, do not go below starting value