Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit f2af00e8 authored by Chong Zhang's avatar Chong Zhang Committed by Ryan Longair
Browse files

Fix race condition for cas sessions

Change the session to shared_ptr and use atomic_load/store.

Test: POC; CTS MediaCasTest; CTS MediaDrmClearkeyTest#
testClearKeyPlaybackMpeg2ts
bug: 113027383

Change-Id: I75f4cb33a022f28d45918442d64c5c46df2640ef
(cherry picked from commit 7934a8f7)
parent b1b11dc8
Loading
Loading
Loading
Loading
+13 −11
Original line number Diff line number Diff line
@@ -118,9 +118,9 @@ status_t ClearKeyCasPlugin::openSession(CasSessionId* sessionId) {

status_t ClearKeyCasPlugin::closeSession(const CasSessionId &sessionId) {
    ALOGV("closeSession: sessionId=%s", sessionIdToString(sessionId).string());
    sp<ClearKeyCasSession> session =
    std::shared_ptr<ClearKeyCasSession> session =
            ClearKeySessionLibrary::get()->findSession(sessionId);
    if (session == NULL) {
    if (session.get() == nullptr) {
        return ERROR_CAS_SESSION_NOT_OPENED;
    }

@@ -132,9 +132,9 @@ status_t ClearKeyCasPlugin::setSessionPrivateData(
        const CasSessionId &sessionId, const CasData & /*data*/) {
    ALOGV("setSessionPrivateData: sessionId=%s",
            sessionIdToString(sessionId).string());
    sp<ClearKeyCasSession> session =
    std::shared_ptr<ClearKeyCasSession> session =
            ClearKeySessionLibrary::get()->findSession(sessionId);
    if (session == NULL) {
    if (session.get() == nullptr) {
        return ERROR_CAS_SESSION_NOT_OPENED;
    }
    return OK;
@@ -143,9 +143,9 @@ status_t ClearKeyCasPlugin::setSessionPrivateData(
status_t ClearKeyCasPlugin::processEcm(
        const CasSessionId &sessionId, const CasEcm& ecm) {
    ALOGV("processEcm: sessionId=%s", sessionIdToString(sessionId).string());
    sp<ClearKeyCasSession> session =
    std::shared_ptr<ClearKeyCasSession> session =
            ClearKeySessionLibrary::get()->findSession(sessionId);
    if (session == NULL) {
    if (session.get() == nullptr) {
        return ERROR_CAS_SESSION_NOT_OPENED;
    }

@@ -418,15 +418,15 @@ status_t ClearKeyDescramblerPlugin::setMediaCasSession(
        const CasSessionId &sessionId) {
    ALOGV("setMediaCasSession: sessionId=%s", sessionIdToString(sessionId).string());

    sp<ClearKeyCasSession> session =
    std::shared_ptr<ClearKeyCasSession> session =
            ClearKeySessionLibrary::get()->findSession(sessionId);

    if (session == NULL) {
    if (session.get() == nullptr) {
        ALOGE("ClearKeyDescramblerPlugin: session not found");
        return ERROR_CAS_SESSION_NOT_OPENED;
    }

    mCASSession = session;
    std::atomic_store(&mCASSession, session);
    return OK;
}

@@ -447,12 +447,14 @@ ssize_t ClearKeyDescramblerPlugin::descramble(
          subSamplesToString(subSamples, numSubSamples).string(),
          srcPtr, dstPtr, srcOffset, dstOffset);

    if (mCASSession == NULL) {
    std::shared_ptr<ClearKeyCasSession> session = std::atomic_load(&mCASSession);

    if (session.get() == nullptr) {
        ALOGE("Uninitialized CAS session!");
        return ERROR_CAS_DECRYPT_UNIT_NOT_INITIALIZED;
    }

    return mCASSession->decrypt(
    return session->decrypt(
            secure, scramblingControl,
            numSubSamples, subSamples,
            (uint8_t*)srcPtr + srcOffset,
+1 −1
Original line number Diff line number Diff line
@@ -120,7 +120,7 @@ public:
            AString *errorDetailMsg) override;

private:
    sp<ClearKeyCasSession> mCASSession;
    std::shared_ptr<ClearKeyCasSession> mCASSession;

    String8 subSamplesToString(
            SubSample const *subSamples,
+4 −4
Original line number Diff line number Diff line
@@ -56,7 +56,7 @@ status_t ClearKeySessionLibrary::addSession(

    Mutex::Autolock lock(mSessionsLock);

    sp<ClearKeyCasSession> session = new ClearKeyCasSession(plugin);
    std::shared_ptr<ClearKeyCasSession> session(new ClearKeyCasSession(plugin));

    uint8_t *byteArray = (uint8_t *) &mNextSessionId;
    sessionId->push_back(byteArray[3]);
@@ -69,7 +69,7 @@ status_t ClearKeySessionLibrary::addSession(
    return OK;
}

sp<ClearKeyCasSession> ClearKeySessionLibrary::findSession(
std::shared_ptr<ClearKeyCasSession> ClearKeySessionLibrary::findSession(
        const CasSessionId& sessionId) {
    Mutex::Autolock lock(mSessionsLock);

@@ -88,7 +88,7 @@ void ClearKeySessionLibrary::destroySession(const CasSessionId& sessionId) {
        return;
    }

    sp<ClearKeyCasSession> session = mIDToSessionMap.valueAt(index);
    std::shared_ptr<ClearKeyCasSession> session = mIDToSessionMap.valueAt(index);
    mIDToSessionMap.removeItemsAt(index);
}

@@ -96,7 +96,7 @@ void ClearKeySessionLibrary::destroyPlugin(CasPlugin *plugin) {
    Mutex::Autolock lock(mSessionsLock);

    for (ssize_t index = (ssize_t)mIDToSessionMap.size() - 1; index >= 0; index--) {
        sp<ClearKeyCasSession> session = mIDToSessionMap.valueAt(index);
        std::shared_ptr<ClearKeyCasSession> session = mIDToSessionMap.valueAt(index);
        if (session->getPlugin() == plugin) {
            mIDToSessionMap.removeItemsAt(index);
        }
+6 −4
Original line number Diff line number Diff line
@@ -32,6 +32,10 @@ class KeyFetcher;

class ClearKeyCasSession : public RefBase {
public:
    explicit ClearKeyCasSession(CasPlugin *plugin);

    virtual ~ClearKeyCasSession();

    ssize_t decrypt(
            bool secure,
            DescramblerPlugin::ScramblingControl scramblingControl,
@@ -58,8 +62,6 @@ private:

    friend class ClearKeySessionLibrary;

    explicit ClearKeyCasSession(CasPlugin *plugin);
    virtual ~ClearKeyCasSession();
    CasPlugin* getPlugin() const { return mPlugin; }
    status_t decryptPayload(
            const AES_KEY& key, size_t length, size_t offset, char* buffer) const;
@@ -73,7 +75,7 @@ public:

    status_t addSession(CasPlugin *plugin, CasSessionId *sessionId);

    sp<ClearKeyCasSession> findSession(const CasSessionId& sessionId);
    std::shared_ptr<ClearKeyCasSession> findSession(const CasSessionId& sessionId);

    void destroySession(const CasSessionId& sessionId);

@@ -85,7 +87,7 @@ private:

    Mutex mSessionsLock;
    uint32_t mNextSessionId;
    KeyedVector<CasSessionId, sp<ClearKeyCasSession>> mIDToSessionMap;
    KeyedVector<CasSessionId, std::shared_ptr<ClearKeyCasSession>> mIDToSessionMap;

    ClearKeySessionLibrary();
    DISALLOW_EVIL_CONSTRUCTORS(ClearKeySessionLibrary);