Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 723e8fc2 authored by Phil Burk's avatar Phil Burk Committed by Android (Google) Code Review
Browse files

Merge "aaudio: use weak pointer to prevent UAF" into oc-mr1-dev

parents 215185d3 92c3e263
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -49,9 +49,9 @@ public:

    virtual aaudio_result_t close() = 0;

    virtual aaudio_result_t registerStream(android::sp<AAudioServiceStreamBase> stream);
    aaudio_result_t registerStream(android::sp<AAudioServiceStreamBase> stream);

    virtual aaudio_result_t unregisterStream(android::sp<AAudioServiceStreamBase> stream);
    aaudio_result_t unregisterStream(android::sp<AAudioServiceStreamBase> stream);

    virtual aaudio_result_t startStream(android::sp<AAudioServiceStreamBase> stream,
                                        audio_port_handle_t *clientHandle) = 0;
+34 −23
Original line number Diff line number Diff line
@@ -104,6 +104,9 @@ aaudio_result_t AAudioServiceStreamBase::open(const aaudio::AAudioStreamRequest
            goto error;
        }

        // This is not protected by a lock because the stream cannot be
        // referenced until the service returns a handle to the client.
        // So only one thread can open a stream.
        mServiceEndpoint = mEndpointManager.openEndpoint(mAudioService,
                                                         request,
                                                         sharingMode);
@@ -112,6 +115,9 @@ aaudio_result_t AAudioServiceStreamBase::open(const aaudio::AAudioStreamRequest
            result = AAUDIO_ERROR_UNAVAILABLE;
            goto error;
        }
        // Save a weak pointer that we will use to access the endpoint.
        mServiceEndpointWeak = mServiceEndpoint;

        mFramesPerBurst = mServiceEndpoint->getFramesPerBurst();
        copyFrom(*mServiceEndpoint);
    }
@@ -130,15 +136,19 @@ aaudio_result_t AAudioServiceStreamBase::close() {

    stop();

    if (mServiceEndpoint == nullptr) {
    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        result = AAUDIO_ERROR_INVALID_STATE;
    } else {
        mServiceEndpoint->unregisterStream(this);
        AAudioEndpointManager &mEndpointManager = AAudioEndpointManager::getInstance();
        mEndpointManager.closeEndpoint(mServiceEndpoint);
        mServiceEndpoint.clear();
        endpoint->unregisterStream(this);
        AAudioEndpointManager &endpointManager = AAudioEndpointManager::getInstance();
        endpointManager.closeEndpoint(endpoint);

        // AAudioService::closeStream() prevents two threads from closing at the same time.
        mServiceEndpoint.clear(); // endpoint will hold the pointer until this method returns.
    }


    {
        std::lock_guard<std::mutex> lock(mUpMessageQueueLock);
        stopTimestampThread();
@@ -152,7 +162,12 @@ aaudio_result_t AAudioServiceStreamBase::close() {

aaudio_result_t AAudioServiceStreamBase::startDevice() {
    mClientHandle = AUDIO_PORT_HANDLE_NONE;
    return mServiceEndpoint->startStream(this, &mClientHandle);
    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }
    return endpoint->startStream(this, &mClientHandle);
}

/**
@@ -162,16 +177,11 @@ aaudio_result_t AAudioServiceStreamBase::startDevice() {
 */
aaudio_result_t AAudioServiceStreamBase::start() {
    aaudio_result_t result = AAUDIO_OK;

    if (isRunning()) {
        return AAUDIO_OK;
    }

    if (mServiceEndpoint == nullptr) {
        ALOGE("AAudioServiceStreamBase::start() missing endpoint");
        result = AAUDIO_ERROR_INVALID_STATE;
        goto error;
    }

    // Start with fresh presentation timestamps.
    mAtomicTimestamp.clear();

@@ -198,10 +208,6 @@ aaudio_result_t AAudioServiceStreamBase::pause() {
    if (!isRunning()) {
        return result;
    }
    if (mServiceEndpoint == nullptr) {
        ALOGE("AAudioServiceStreamShared::pause() missing endpoint");
        return AAUDIO_ERROR_INVALID_STATE;
    }

    // Send it now because the timestamp gets rounded up when stopStream() is called below.
    // Also we don't need the timestamps while we are shutting down.
@@ -213,7 +219,12 @@ aaudio_result_t AAudioServiceStreamBase::pause() {
        return result;
    }

    result = mServiceEndpoint->stopStream(this, mClientHandle);
    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }
    result = endpoint->stopStream(this, mClientHandle);
    if (result != AAUDIO_OK) {
        ALOGE("AAudioServiceStreamShared::pause() mServiceEndpoint returned %d", result);
        disconnect(); // TODO should we return or pause Base first?
@@ -230,11 +241,6 @@ aaudio_result_t AAudioServiceStreamBase::stop() {
        return result;
    }

    if (mServiceEndpoint == nullptr) {
        ALOGE("AAudioServiceStreamShared::stop() missing endpoint");
        return AAUDIO_ERROR_INVALID_STATE;
    }

    // Send it now because the timestamp gets rounded up when stopStream() is called below.
    // Also we don't need the timestamps while we are shutting down.
    sendCurrentTimestamp(); // warning - this calls a virtual function
@@ -244,8 +250,13 @@ aaudio_result_t AAudioServiceStreamBase::stop() {
        return result;
    }

    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }
    // TODO wait for data to be played out
    result = mServiceEndpoint->stopStream(this, mClientHandle);
    result = endpoint->stopStream(this, mClientHandle);
    if (result != AAUDIO_OK) {
        ALOGE("AAudioServiceStreamShared::stop() mServiceEndpoint returned %d", result);
        disconnect();
+5 −0
Original line number Diff line number Diff line
@@ -233,7 +233,12 @@ protected:
    SimpleDoubleBuffer<Timestamp>  mAtomicTimestamp;

    android::AAudioService &mAudioService;

    // The mServiceEndpoint variable can be accessed by multiple threads.
    // So we access it by locally promoting a weak pointer to a smart pointer,
    // which is thread-safe.
    android::sp<AAudioServiceEndpoint> mServiceEndpoint;
    android::wp<AAudioServiceEndpoint> mServiceEndpointWeak;

private:
    aaudio_handle_t         mHandle = -1;
+44 −11
Original line number Diff line number Diff line
@@ -70,14 +70,19 @@ aaudio_result_t AAudioServiceStreamMMAP::open(const aaudio::AAudioStreamRequest
        return result;
    }

    result = mServiceEndpoint->registerStream(keep);
    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }

    result = endpoint->registerStream(keep);
    if (result != AAUDIO_OK) {
        goto error;
        return result;
    }

    setState(AAUDIO_STREAM_STATE_OPEN);

error:
    return AAUDIO_OK;
}

@@ -122,21 +127,37 @@ aaudio_result_t AAudioServiceStreamMMAP::stop() {

aaudio_result_t AAudioServiceStreamMMAP::startClient(const android::AudioClient& client,
                                                       audio_port_handle_t *clientHandle) {
    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }
    // Start the client on behalf of the application. Generate a new porthandle.
    aaudio_result_t result = mServiceEndpoint->startClient(client, clientHandle);
    aaudio_result_t result = endpoint->startClient(client, clientHandle);
    return result;
}

aaudio_result_t AAudioServiceStreamMMAP::stopClient(audio_port_handle_t clientHandle) {
    aaudio_result_t result = mServiceEndpoint->stopClient(clientHandle);
    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }
    aaudio_result_t result = endpoint->stopClient(clientHandle);
    return result;
}

// Get free-running DSP or DMA hardware position from the HAL.
aaudio_result_t AAudioServiceStreamMMAP::getFreeRunningPosition(int64_t *positionFrames,
                                                                  int64_t *timeNanos) {
    sp<AAudioServiceEndpointMMAP> serviceEndpointMMAP{
            static_cast<AAudioServiceEndpointMMAP *>(mServiceEndpoint.get())};
    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }
    sp<AAudioServiceEndpointMMAP> serviceEndpointMMAP =
            static_cast<AAudioServiceEndpointMMAP *>(endpoint.get());

    aaudio_result_t result = serviceEndpointMMAP->getFreeRunningPosition(positionFrames, timeNanos);
    if (result == AAUDIO_OK) {
        Timestamp timestamp(*positionFrames, *timeNanos);
@@ -152,8 +173,15 @@ aaudio_result_t AAudioServiceStreamMMAP::getFreeRunningPosition(int64_t *positio
// Get timestamp that was written by getFreeRunningPosition()
aaudio_result_t AAudioServiceStreamMMAP::getHardwareTimestamp(int64_t *positionFrames,
                                                                int64_t *timeNanos) {
    sp<AAudioServiceEndpointMMAP> serviceEndpointMMAP{
            static_cast<AAudioServiceEndpointMMAP *>(mServiceEndpoint.get())};

    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }
    sp<AAudioServiceEndpointMMAP> serviceEndpointMMAP =
            static_cast<AAudioServiceEndpointMMAP *>(endpoint.get());

    // TODO Get presentation timestamp from the HAL
    if (mAtomicTimestamp.isValid()) {
        Timestamp timestamp = mAtomicTimestamp.read();
@@ -171,7 +199,12 @@ aaudio_result_t AAudioServiceStreamMMAP::getHardwareTimestamp(int64_t *positionF
aaudio_result_t AAudioServiceStreamMMAP::getAudioDataDescription(
        AudioEndpointParcelable &parcelable)
{
    sp<AAudioServiceEndpointMMAP> serviceEndpointMMAP{
            static_cast<AAudioServiceEndpointMMAP *>(mServiceEndpoint.get())};
    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }
    sp<AAudioServiceEndpointMMAP> serviceEndpointMMAP =
            static_cast<AAudioServiceEndpointMMAP *>(endpoint.get());
    return serviceEndpointMMAP->getDownDataDescription(parcelable);
}
+20 −9
Original line number Diff line number Diff line
@@ -128,6 +128,11 @@ aaudio_result_t AAudioServiceStreamShared::open(const aaudio::AAudioStreamReques

    const AAudioStreamConfiguration &configurationInput = request.getConstantConfiguration();

    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        result = AAUDIO_ERROR_INVALID_STATE;
        goto error;
    }

    // Is the request compatible with the shared endpoint?
    setFormat(configurationInput.getFormat());
@@ -141,20 +146,20 @@ aaudio_result_t AAudioServiceStreamShared::open(const aaudio::AAudioStreamReques

    setSampleRate(configurationInput.getSampleRate());
    if (getSampleRate() == AAUDIO_UNSPECIFIED) {
        setSampleRate(mServiceEndpoint->getSampleRate());
    } else if (getSampleRate() != mServiceEndpoint->getSampleRate()) {
        setSampleRate(endpoint->getSampleRate());
    } else if (getSampleRate() != endpoint->getSampleRate()) {
        ALOGE("AAudioServiceStreamShared::open() mSampleRate = %d, need %d",
              getSampleRate(), mServiceEndpoint->getSampleRate());
              getSampleRate(), endpoint->getSampleRate());
        result = AAUDIO_ERROR_INVALID_RATE;
        goto error;
    }

    setSamplesPerFrame(configurationInput.getSamplesPerFrame());
    if (getSamplesPerFrame() == AAUDIO_UNSPECIFIED) {
        setSamplesPerFrame(mServiceEndpoint->getSamplesPerFrame());
    } else if (getSamplesPerFrame() != mServiceEndpoint->getSamplesPerFrame()) {
        setSamplesPerFrame(endpoint->getSamplesPerFrame());
    } else if (getSamplesPerFrame() != endpoint->getSamplesPerFrame()) {
        ALOGE("AAudioServiceStreamShared::open() mSamplesPerFrame = %d, need %d",
              getSamplesPerFrame(), mServiceEndpoint->getSamplesPerFrame());
              getSamplesPerFrame(), endpoint->getSamplesPerFrame());
        result = AAUDIO_ERROR_OUT_OF_RANGE;
        goto error;
    }
@@ -181,9 +186,9 @@ aaudio_result_t AAudioServiceStreamShared::open(const aaudio::AAudioStreamReques
    }

    ALOGD("AAudioServiceStreamShared::open() actual rate = %d, channels = %d, deviceId = %d",
          getSampleRate(), getSamplesPerFrame(), mServiceEndpoint->getDeviceId());
          getSampleRate(), getSamplesPerFrame(), endpoint->getDeviceId());

    result = mServiceEndpoint->registerStream(keep);
    result = endpoint->registerStream(keep);
    if (result != AAUDIO_OK) {
        goto error;
    }
@@ -250,7 +255,13 @@ aaudio_result_t AAudioServiceStreamShared::getHardwareTimestamp(int64_t *positio
                                                                int64_t *timeNanos) {

    int64_t position = 0;
    aaudio_result_t result = mServiceEndpoint->getTimestamp(&position, timeNanos);
    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
    if (endpoint == nullptr) {
        ALOGE("%s() has no endpoint", __func__);
        return AAUDIO_ERROR_INVALID_STATE;
    }

    aaudio_result_t result = endpoint->getTimestamp(&position, timeNanos);
    if (result == AAUDIO_OK) {
        int64_t offset = mTimestampPositionOffset.load();
        // TODO, do not go below starting value