Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit bc4773d1 authored by Edwin Wong's avatar Edwin Wong Committed by Android (Google) Code Review
Browse files

Merge "Add KEY_TYPE_OFFLINE support."

parents 99f5cb42 92f92de0
Loading
Loading
Loading
Loading
+14 −0
Original line number Diff line number Diff line
@@ -27,4 +27,18 @@ bool operator<(const Vector<uint8_t> &lhs, const Vector<uint8_t> &rhs) {
    return memcmp((void *)lhs.array(), (void *)rhs.array(), rhs.size()) < 0;
}

std::string ByteArrayToHexString(const uint8_t* in_buffer, size_t length) {
    static const char kHexChars[] = "0123456789ABCDEF";

    // Each input byte creates two output hex characters.
    std::string out_buffer(length * 2, '\0');

    for (size_t i = 0; i < length; ++i) {
        char byte = in_buffer[i];
        out_buffer[(i * 2)] = kHexChars[(byte >> 4) & 0xf];
        out_buffer[(i * 2) + 1] = kHexChars[byte & 0xf];
    }
    return out_buffer;
}

} // namespace android
+4 −2
Original line number Diff line number Diff line
@@ -17,14 +17,16 @@
#ifndef CLEARKEY_UTILS_H_
#define CLEARKEY_UTILS_H_

#include <string>
#include <utils/Vector.h>

namespace android {
// Add a comparison operator for this Vector specialization so that it can be
// used as a key in a KeyedVector.
namespace android {

bool operator<(const Vector<uint8_t> &lhs, const Vector<uint8_t> &rhs);

std::string ByteArrayToHexString(const uint8_t* in_buffer, size_t length);

} // namespace android

#define UNUSED(x) (void)(x);
+14 −0
Original line number Diff line number Diff line
@@ -25,10 +25,12 @@ cc_binary {
        "CreatePluginFactories.cpp",
        "CryptoFactory.cpp",
        "CryptoPlugin.cpp",
        "DeviceFiles.cpp",
        "DrmFactory.cpp",
        "DrmPlugin.cpp",
        "InitDataParser.cpp",
        "JsonWebKey.cpp",
        "MemoryFileSystem.cpp",
        "Session.cpp",
        "SessionLibrary.cpp",
        "service.cpp",
@@ -49,11 +51,13 @@ cc_binary {
        "libhidlmemory",
        "libhidltransport",
        "liblog",
        "libprotobuf-cpp-lite",
        "libutils",
    ],

    static_libs: [
        "libclearkeycommon",
        "libclearkeydevicefiles-protos",
        "libjsmn",
    ],

@@ -66,3 +70,13 @@ cc_binary {
    },
}

cc_library_static {
    name: "libclearkeydevicefiles-protos",
    vendor: true,

    proto: {
        export_proto_headers: true,
        type: "lite",
    },
    srcs: ["protos/DeviceFiles.proto"],
}
+240 −0
Original line number Diff line number Diff line
// Copyright 2018 Google LLC. All Rights Reserved. This file and proprietary
// source code may only be used and distributed under the Widevine Master
// License Agreement.

#include <utils/Log.h>

#include <string>
#include <sys/stat.h>

#include "DeviceFiles.h"
#include "Utils.h"

#include <openssl/sha.h>

// Protobuf generated classes.
using android::hardware::drm::V1_1::clearkey::OfflineFile;
using android::hardware::drm::V1_1::clearkey::HashedFile;
using android::hardware::drm::V1_1::clearkey::License;
using android::hardware::drm::V1_1::clearkey::License_LicenseState_ACTIVE;
using android::hardware::drm::V1_1::clearkey::License_LicenseState_RELEASING;

namespace {
const char kLicenseFileNameExt[] = ".lic";

bool Hash(const std::string& data, std::string* hash) {
    if (!hash) return false;

    hash->resize(SHA256_DIGEST_LENGTH);

    const unsigned char* input = reinterpret_cast<const unsigned char*>(data.data());
    unsigned char* output = reinterpret_cast<unsigned char*>(&(*hash)[0]);
    SHA256(input, data.size(), output);
    return true;
}

}  // namespace

namespace android {
namespace hardware {
namespace drm {
namespace V1_1 {
namespace clearkey {

bool DeviceFiles::StoreLicense(
        const std::string& keySetId, LicenseState state,
        const std::string& licenseResponse) {

    OfflineFile file;
    file.set_type(OfflineFile::LICENSE);
    file.set_version(OfflineFile::VERSION_1);

    License* license = file.mutable_license();
    switch (state) {
        case kLicenseStateActive:
            license->set_state(License_LicenseState_ACTIVE);
            license->set_license(licenseResponse);
            break;
        case kLicenseStateReleasing:
            license->set_state(License_LicenseState_RELEASING);
            break;
        default:
            ALOGW("StoreLicense: Unknown license state: %u", state);
            return false;
    }

    std::string serializedFile;
    file.SerializeToString(&serializedFile);

    return StoreFileWithHash(keySetId + kLicenseFileNameExt, serializedFile);
}

bool DeviceFiles::StoreFileWithHash(const std::string& fileName,
        const std::string& serializedFile) {
    std::string hash;
    if (!Hash(serializedFile, &hash)) {
        ALOGE("StoreFileWithHash: Failed to compute hash");
        return false;
    }

    HashedFile hashFile;
    hashFile.set_file(serializedFile);
    hashFile.set_hash(hash);

    std::string serializedHashFile;
    hashFile.SerializeToString(&serializedHashFile);

    return StoreFileRaw(fileName, serializedHashFile);
}

bool DeviceFiles::StoreFileRaw(const std::string& fileName, const std::string& serializedHashFile) {
    MemoryFileSystem::MemoryFile memFile;
    memFile.setFileName(fileName);
    memFile.setContent(serializedHashFile);
    memFile.setFileSize(serializedHashFile.size());
    size_t len = mFileHandle.Write(fileName, memFile);

    if (len != static_cast<size_t>(serializedHashFile.size())) {
        ALOGE("StoreFileRaw: Failed to write %s", fileName.c_str());
        ALOGD("StoreFileRaw: expected=%zd, actual=%zu", serializedHashFile.size(), len);
        return false;
    }

    ALOGD("StoreFileRaw: wrote %zu bytes to %s", serializedHashFile.size(), fileName.c_str());
    return true;
}

bool DeviceFiles::RetrieveLicense(
    const std::string& keySetId, LicenseState* state, std::string* offlineLicense) {
    OfflineFile file;

    if (!RetrieveHashedFile(keySetId + kLicenseFileNameExt, &file)) {
        return false;
    }

    if (file.type() != OfflineFile::LICENSE) {
        ALOGE("RetrieveLicense: Invalid file type");
        return false;
    }

    if (file.version() != OfflineFile::VERSION_1) {
        ALOGE("RetrieveLicense: Invalid file version");
        return false;
    }

    if (!file.has_license()) {
        ALOGE("RetrieveLicense: License not present");
        return false;
    }

    License license = file.license();

    switch (license.state()) {
        case License_LicenseState_ACTIVE:
            *state = kLicenseStateActive;
            break;
        case License_LicenseState_RELEASING:
            *state = kLicenseStateReleasing;
            break;
        default:
            ALOGW("RetrieveLicense: Unrecognized license state: %u",
                    kLicenseStateUnknown);
            *state = kLicenseStateUnknown;
            break;
    }

    *offlineLicense = license.license();
    return true;
}

bool DeviceFiles::DeleteAllLicenses() {
    return mFileHandle.RemoveAllFiles();
}

bool DeviceFiles::LicenseExists(const std::string& keySetId) {
    return mFileHandle.FileExists(keySetId + kLicenseFileNameExt);
}

bool DeviceFiles::RetrieveHashedFile(const std::string& fileName, OfflineFile* deSerializedFile) {
    if (!deSerializedFile) {
        ALOGE("RetrieveHashedFile: invalid file parameter");
        return false;
    }

    if (!FileExists(fileName)) {
        ALOGE("RetrieveHashedFile: %s does not exist", fileName.c_str());
        return false;
    }

    ssize_t bytes = GetFileSize(fileName);
    if (bytes <= 0) {
        ALOGE("RetrieveHashedFile: invalid file size: %s", fileName.c_str());
        // Remove the corrupted file so the caller will not get the same error
        // when trying to access the file repeatedly, causing the system to stall.
        RemoveFile(fileName);
        return false;
    }

    std::string serializedHashFile;
    serializedHashFile.resize(bytes);
    bytes = mFileHandle.Read(fileName, &serializedHashFile);

    if (bytes != static_cast<ssize_t>(serializedHashFile.size())) {
        ALOGE("RetrieveHashedFile: Failed to read from %s", fileName.c_str());
        ALOGV("RetrieveHashedFile: expected: %zd, actual: %zd", serializedHashFile.size(), bytes);
        // Remove the corrupted file so the caller will not get the same error
        // when trying to access the file repeatedly, causing the system to stall.
        RemoveFile(fileName);
        return false;
    }

    ALOGV("RetrieveHashedFile: read %zd from %s", bytes, fileName.c_str());

    HashedFile hashFile;
    if (!hashFile.ParseFromString(serializedHashFile)) {
        ALOGE("RetrieveHashedFile: Unable to parse hash file");
        // Remove corrupt file.
        RemoveFile(fileName);
        return false;
    }

    std::string hash;
    if (!Hash(hashFile.file(), &hash)) {
        ALOGE("RetrieveHashedFile: Hash computation failed");
        return false;
    }

    if (hash != hashFile.hash()) {
        ALOGE("RetrieveHashedFile: Hash mismatch");
        // Remove corrupt file.
        RemoveFile(fileName);
        return false;
    }

    if (!deSerializedFile->ParseFromString(hashFile.file())) {
        ALOGE("RetrieveHashedFile: Unable to parse file");
        // Remove corrupt file.
        RemoveFile(fileName);
        return false;
    }

    return true;
}

bool DeviceFiles::FileExists(const std::string& fileName) const {
    return mFileHandle.FileExists(fileName);
}

bool DeviceFiles::RemoveFile(const std::string& fileName) {
    return mFileHandle.RemoveFile(fileName);
}

ssize_t DeviceFiles::GetFileSize(const std::string& fileName) const {
    return mFileHandle.GetFileSize(fileName);
}

} // namespace clearkey
} // namespace V1_1
} // namespace drm
} // namespace hardware
} // namespace android
+161 −41
Original line number Diff line number Diff line
@@ -25,11 +25,15 @@
#include "ClearKeyDrmProperties.h"
#include "Session.h"
#include "TypeConvert.h"
#include "Utils.h"

namespace {
const std::string kKeySetIdPrefix("ckid");
const int kKeySetIdLength = 16;
const int kSecureStopIdStart = 100;
const std::string kOfflineLicense("\"type\":\"persistent-license\"");
const std::string kStreaming("Streaming");
const std::string kOffline("Offline");
const std::string kTemporaryLicense("\"type\":\"temporary\"");
const std::string kTrue("True");

const std::string kQueryKeyLicenseType("LicenseType");
@@ -66,6 +70,7 @@ DrmPlugin::DrmPlugin(SessionLibrary* sessionLibrary)
    mPlayPolicy.clear();
    initProperties();
    mSecureStops.clear();
    std::srand(std::time(nullptr));
}

void DrmPlugin::initProperties() {
@@ -147,25 +152,53 @@ Status DrmPlugin::getKeyRequestCommon(const hidl_vec<uint8_t>& scope,
        std::string *defaultUrl) {
        UNUSED(optionalParameters);

    // GetKeyRequestOfflineKeyTypeNotSupported() in vts 1.0 and 1.1 expects
    // KeyType::OFFLINE to return ERROR_DRM_CANNOT_HANDLE in clearkey plugin.
    // Those tests pass in an empty initData, we use the empty initData to
    // signal the specific use case.
    if (keyType == KeyType::OFFLINE && 0 == initData.size()) {
        return Status::ERROR_DRM_CANNOT_HANDLE;
    }

    *defaultUrl = "";
    *keyRequestType = KeyRequestType::UNKNOWN;
    *request = std::vector<uint8_t>();

    if (scope.size() == 0) {
    if (scope.size() == 0 ||
            (keyType != KeyType::STREAMING &&
            keyType != KeyType::OFFLINE &&
            keyType != KeyType::RELEASE)) {
        return Status::BAD_VALUE;
    }

    if (keyType != KeyType::STREAMING) {
        return Status::ERROR_DRM_CANNOT_HANDLE;
    }

    sp<Session> session = mSessionLibrary->findSession(toVector(scope));
    const std::vector<uint8_t> scopeId = toVector(scope);
    sp<Session> session;
    if (keyType == KeyType::STREAMING || keyType == KeyType::OFFLINE) {
        std::vector<uint8_t> sessionId(scopeId.begin(), scopeId.end());
        session = mSessionLibrary->findSession(sessionId);
        if (!session.get()) {
            return Status::ERROR_DRM_SESSION_NOT_OPENED;
        }

    Status status = session->getKeyRequest(initData, mimeType, request);
        *keyRequestType = KeyRequestType::INITIAL;
    }

    Status status = session->getKeyRequest(initData, mimeType, keyType, request);

    if (keyType == KeyType::RELEASE) {
        std::vector<uint8_t> keySetId(scopeId.begin(), scopeId.end());
        std::string requestString(request->begin(), request->end());
        if (requestString.find(kOfflineLicense) != std::string::npos) {
            std::string emptyResponse;
            std::string keySetIdString(keySetId.begin(), keySetId.end());
            if (!mFileHandle.StoreLicense(keySetIdString,
                    DeviceFiles::kLicenseStateReleasing,
                    emptyResponse)) {
                ALOGE("Problem releasing offline license");
                return Status::ERROR_DRM_UNKNOWN;
            }
        }
        *keyRequestType = KeyRequestType::RELEASE;
    }
    return status;
}

@@ -227,6 +260,30 @@ void DrmPlugin::setPlayPolicy() {
    mPlayPolicy.push_back(policy);
}

bool DrmPlugin::makeKeySetId(std::string* keySetId) {
    if (!keySetId) {
        ALOGE("keySetId destination not provided");
        return false;
    }
    std::vector<uint8_t> ksid(kKeySetIdPrefix.begin(), kKeySetIdPrefix.end());
    ksid.resize(kKeySetIdLength);
    std::vector<uint8_t> randomData((kKeySetIdLength - kKeySetIdPrefix.size()) / 2, 0);

    while (keySetId->empty()) {
        for (auto itr = randomData.begin(); itr != randomData.end(); ++itr) {
            *itr = std::rand() % 0xff;
        }
        *keySetId = kKeySetIdPrefix + ByteArrayToHexString(
                reinterpret_cast<const uint8_t*>(randomData.data()), randomData.size());
        if (mFileHandle.LicenseExists(*keySetId)) {
            // collision, regenerate
            ALOGV("Retry generating KeySetId");
            keySetId->clear();
        }
    }
    return true;
}

Return<void> DrmPlugin::provideKeyResponse(
        const hidl_vec<uint8_t>& scope,
        const hidl_vec<uint8_t>& response,
@@ -237,22 +294,49 @@ Return<void> DrmPlugin::provideKeyResponse(
        return Void();
    }

    sp<Session> session = mSessionLibrary->findSession(toVector(scope));
    std::string responseString(
            reinterpret_cast<const char*>(response.data()), response.size());
    const std::vector<uint8_t> scopeId = toVector(scope);
    std::vector<uint8_t> sessionId;
    std::string keySetId;

    Status status = Status::OK;
    bool isOfflineLicense = responseString.find(kOfflineLicense) != std::string::npos;
    bool isRelease = (memcmp(scopeId.data(), kKeySetIdPrefix.data(), kKeySetIdPrefix.size()) == 0);
    if (isRelease) {
        keySetId.assign(scopeId.begin(), scopeId.end());
    } else {
        sessionId.assign(scopeId.begin(), scopeId.end());
        sp<Session> session = mSessionLibrary->findSession(sessionId);
        if (!session.get()) {
            _hidl_cb(Status::ERROR_DRM_SESSION_NOT_OPENED, hidl_vec<uint8_t>());
            return Void();
        }

        setPlayPolicy();
    std::vector<uint8_t> keySetId;
        // non offline license returns empty keySetId
        keySetId.clear();

    Status status = session->provideKeyResponse(response);
        status = session->provideKeyResponse(response);
        if (status == Status::OK) {
            if (isOfflineLicense) {
                if (!makeKeySetId(&keySetId)) {
                    _hidl_cb(Status::ERROR_DRM_UNKNOWN, hidl_vec<uint8_t>());
                    return Void();
                }
                bool ok = mFileHandle.StoreLicense(
                        keySetId,
                        DeviceFiles::kLicenseStateActive,
                        std::string(response.begin(), response.end()));
                if (!ok) {
                    ALOGE("Failed to store offline license");
                }
            }

            // Test calling AMediaDrm listeners.
        sendEvent(EventType::VENDOR_DEFINED, toVector(scope), toVector(scope));
            sendEvent(EventType::VENDOR_DEFINED, sessionId, sessionId);

        sendExpirationUpdate(toVector(scope), 100);
            sendExpirationUpdate(sessionId, 100);

            std::vector<KeyStatus> keysStatus;
            KeyStatus keyStatus;
@@ -267,16 +351,52 @@ Return<void> DrmPlugin::provideKeyResponse(
            keyStatus.type = V1_0::KeyStatusType::EXPIRED;
            keysStatus.push_back(keyStatus);

        sendKeysChange(toVector(scope), keysStatus, true);
    }
            sendKeysChange(sessionId, keysStatus, true);

    installSecureStop(scope);
            installSecureStop(sessionId);
        } else {
            ALOGE("Failed to add key, error=%d", status);
        }
    } // keyType::STREAMING || keyType::OFFLINE

    // Returns status and empty keySetId
    _hidl_cb(status, toHidlVec(keySetId));
    std::vector<uint8_t> keySetIdVec(keySetId.begin(), keySetId.end());
    _hidl_cb(status, toHidlVec(keySetIdVec));
    return Void();
}

Return<Status> DrmPlugin::restoreKeys(
        const hidl_vec<uint8_t>& sessionId, const hidl_vec<uint8_t>& keySetId) {
        if (sessionId.size() == 0 || keySetId.size() == 0) {
            return Status::BAD_VALUE;
        }

        DeviceFiles::LicenseState licenseState;
        std::string keySetIdString(keySetId.begin(), keySetId.end());
        std::string offlineLicense;
        Status status = Status::OK;
        if (!mFileHandle.RetrieveLicense(keySetIdString, &licenseState, &offlineLicense)) {
            ALOGE("Failed to restore offline license");
            return Status::ERROR_DRM_NO_LICENSE;
        }

        if (DeviceFiles::kLicenseStateUnknown == licenseState ||
                DeviceFiles::kLicenseStateReleasing == licenseState) {
            ALOGE("Invalid license state=%d", licenseState);
            return Status::ERROR_DRM_NO_LICENSE;
        }

        sp<Session> session = mSessionLibrary->findSession(toVector(sessionId));
        if (!session.get()) {
            return Status::ERROR_DRM_SESSION_NOT_OPENED;
        }
        status = session->provideKeyResponse(std::vector<uint8_t>(offlineLicense.begin(),
                offlineLicense.end()));
        if (status != Status::OK) {
            ALOGE("Failed to restore keys");
        }
        return status;
}

Return<void> DrmPlugin::getPropertyString(
        const hidl_string& propertyName, getPropertyString_cb _hidl_cb) {
    std::string name(propertyName.c_str());
Loading