Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 869e0798 authored by Mike Lockwood's avatar Mike Lockwood Committed by Android (Google) Code Review
Browse files

Merge "MTP: add strict bounds checking for all incoming packets" into lmp-mr1-dev

parents 94691b01 ab063847
Loading
Loading
Loading
Loading
+117 −43
Original line number Diff line number Diff line
@@ -51,104 +51,178 @@ void MtpDataPacket::setTransactionID(MtpTransactionID id) {
    MtpPacket::putUInt32(MTP_CONTAINER_TRANSACTION_ID_OFFSET, id);
}

uint16_t MtpDataPacket::getUInt16() {
bool MtpDataPacket::getUInt8(uint8_t& value) {
    if (mPacketSize - mOffset < sizeof(value))
        return false;
    value = mBuffer[mOffset++];
    return true;
}

bool MtpDataPacket::getUInt16(uint16_t& value) {
    if (mPacketSize - mOffset < sizeof(value))
        return false;
    int offset = mOffset;
    uint16_t result = (uint16_t)mBuffer[offset] | ((uint16_t)mBuffer[offset + 1] << 8);
    mOffset += 2;
    return result;
    value = (uint16_t)mBuffer[offset] | ((uint16_t)mBuffer[offset + 1] << 8);
    mOffset += sizeof(value);
    return true;
}

uint32_t MtpDataPacket::getUInt32() {
bool MtpDataPacket::getUInt32(uint32_t& value) {
    if (mPacketSize - mOffset < sizeof(value))
        return false;
    int offset = mOffset;
    uint32_t result = (uint32_t)mBuffer[offset] | ((uint32_t)mBuffer[offset + 1] << 8) |
    value = (uint32_t)mBuffer[offset] | ((uint32_t)mBuffer[offset + 1] << 8) |
           ((uint32_t)mBuffer[offset + 2] << 16)  | ((uint32_t)mBuffer[offset + 3] << 24);
    mOffset += 4;
    return result;
    mOffset += sizeof(value);
    return true;
}

uint64_t MtpDataPacket::getUInt64() {
bool MtpDataPacket::getUInt64(uint64_t& value) {
    if (mPacketSize - mOffset < sizeof(value))
        return false;
    int offset = mOffset;
    uint64_t result = (uint64_t)mBuffer[offset] | ((uint64_t)mBuffer[offset + 1] << 8) |
    value = (uint64_t)mBuffer[offset] | ((uint64_t)mBuffer[offset + 1] << 8) |
           ((uint64_t)mBuffer[offset + 2] << 16) | ((uint64_t)mBuffer[offset + 3] << 24) |
           ((uint64_t)mBuffer[offset + 4] << 32) | ((uint64_t)mBuffer[offset + 5] << 40) |
           ((uint64_t)mBuffer[offset + 6] << 48)  | ((uint64_t)mBuffer[offset + 7] << 56);
    mOffset += 8;
    return result;
    mOffset += sizeof(value);
    return true;
}

void MtpDataPacket::getUInt128(uint128_t& value) {
    value[0] = getUInt32();
    value[1] = getUInt32();
    value[2] = getUInt32();
    value[3] = getUInt32();
bool MtpDataPacket::getUInt128(uint128_t& value) {
    return getUInt32(value[0]) && getUInt32(value[1]) && getUInt32(value[2]) && getUInt32(value[3]);
}

void MtpDataPacket::getString(MtpStringBuffer& string)
bool MtpDataPacket::getString(MtpStringBuffer& string)
{
    string.readFromPacket(this);
    return string.readFromPacket(this);
}

Int8List* MtpDataPacket::getAInt8() {
    uint32_t count;
    if (!getUInt32(count))
        return NULL;
    Int8List* result = new Int8List;
    int count = getUInt32();
    for (int i = 0; i < count; i++)
        result->push(getInt8());
    for (uint32_t i = 0; i < count; i++) {
        int8_t value;
        if (!getInt8(value)) {
            delete result;
            return NULL;
        }
        result->push(value);
    }
    return result;
}

UInt8List* MtpDataPacket::getAUInt8() {
    uint32_t count;
    if (!getUInt32(count))
        return NULL;
    UInt8List* result = new UInt8List;
    int count = getUInt32();
    for (int i = 0; i < count; i++)
        result->push(getUInt8());
    for (uint32_t i = 0; i < count; i++) {
        uint8_t value;
        if (!getUInt8(value)) {
            delete result;
            return NULL;
        }
        result->push(value);
    }
    return result;
}

Int16List* MtpDataPacket::getAInt16() {
    uint32_t count;
    if (!getUInt32(count))
        return NULL;
    Int16List* result = new Int16List;
    int count = getUInt32();
    for (int i = 0; i < count; i++)
        result->push(getInt16());
    for (uint32_t i = 0; i < count; i++) {
        int16_t value;
        if (!getInt16(value)) {
            delete result;
            return NULL;
        }
        result->push(value);
    }
    return result;
}

UInt16List* MtpDataPacket::getAUInt16() {
    uint32_t count;
    if (!getUInt32(count))
        return NULL;
    UInt16List* result = new UInt16List;
    int count = getUInt32();
    for (int i = 0; i < count; i++)
        result->push(getUInt16());
    for (uint32_t i = 0; i < count; i++) {
        uint16_t value;
        if (!getUInt16(value)) {
            delete result;
            return NULL;
        }
        result->push(value);
    }
    return result;
}

Int32List* MtpDataPacket::getAInt32() {
    uint32_t count;
    if (!getUInt32(count))
        return NULL;
    Int32List* result = new Int32List;
    int count = getUInt32();
    for (int i = 0; i < count; i++)
        result->push(getInt32());
    for (uint32_t i = 0; i < count; i++) {
        int32_t value;
        if (!getInt32(value)) {
            delete result;
            return NULL;
        }
        result->push(value);
    }
    return result;
}

UInt32List* MtpDataPacket::getAUInt32() {
    uint32_t count;
    if (!getUInt32(count))
        return NULL;
    UInt32List* result = new UInt32List;
    int count = getUInt32();
    for (int i = 0; i < count; i++)
        result->push(getUInt32());
    for (uint32_t i = 0; i < count; i++) {
        uint32_t value;
        if (!getUInt32(value)) {
            delete result;
            return NULL;
        }
        result->push(value);
    }
    return result;
}

Int64List* MtpDataPacket::getAInt64() {
    uint32_t count;
    if (!getUInt32(count))
        return NULL;
    Int64List* result = new Int64List;
    int count = getUInt32();
    for (int i = 0; i < count; i++)
        result->push(getInt64());
    for (uint32_t i = 0; i < count; i++) {
        int64_t value;
        if (!getInt64(value)) {
            delete result;
            return NULL;
        }
        result->push(value);
    }
    return result;
}

UInt64List* MtpDataPacket::getAUInt64() {
    uint32_t count;
    if (!getUInt32(count))
        return NULL;
    UInt64List* result = new UInt64List;
    int count = getUInt32();
    for (int i = 0; i < count; i++)
        result->push(getUInt64());
    for (uint32_t i = 0; i < count; i++) {
        uint64_t value;
        if (!getUInt64(value)) {
            delete result;
            return NULL;
        }
        result->push(value);
    }
    return result;
}

+13 −12
Original line number Diff line number Diff line
@@ -30,7 +30,7 @@ class MtpStringBuffer;
class MtpDataPacket : public MtpPacket {
private:
    // current offset for get/put methods
    int                 mOffset;
    size_t              mOffset;

public:
                        MtpDataPacket();
@@ -42,17 +42,18 @@ public:
    void                setTransactionID(MtpTransactionID id);

    inline const uint8_t*     getData() const { return mBuffer + MTP_CONTAINER_HEADER_SIZE; }
    inline uint8_t      getUInt8() { return (uint8_t)mBuffer[mOffset++]; }
    inline int8_t       getInt8() { return (int8_t)mBuffer[mOffset++]; }
    uint16_t            getUInt16();
    inline int16_t      getInt16() { return (int16_t)getUInt16(); }
    uint32_t            getUInt32();
    inline int32_t      getInt32() { return (int32_t)getUInt32(); }
    uint64_t            getUInt64();
    inline int64_t      getInt64() { return (int64_t)getUInt64(); }
    void                getUInt128(uint128_t& value);
    inline void         getInt128(int128_t& value) { getUInt128((uint128_t&)value); }
    void                getString(MtpStringBuffer& string);

    bool                getUInt8(uint8_t& value);
    inline bool         getInt8(int8_t& value) { return getUInt8((uint8_t&)value); }
    bool                getUInt16(uint16_t& value);
    inline bool         getInt16(int16_t& value) { return getUInt16((uint16_t&)value); }
    bool                getUInt32(uint32_t& value);
    inline bool         getInt32(int32_t& value) { return getUInt32((uint32_t&)value); }
    bool                getUInt64(uint64_t& value);
    inline bool         getInt64(int64_t& value) { return getUInt64((uint64_t&)value); }
    bool                getUInt128(uint128_t& value);
    inline bool         getInt128(int128_t& value) { return getUInt128((uint128_t&)value); }
    bool                getString(MtpStringBuffer& string);

    Int8List*           getAInt8();
    UInt8List*          getAUInt8();
+21 −11
Original line number Diff line number Diff line
@@ -313,8 +313,10 @@ MtpDeviceInfo* MtpDevice::getDeviceInfo() {
    MtpResponseCode ret = readResponse();
    if (ret == MTP_RESPONSE_OK) {
        MtpDeviceInfo* info = new MtpDeviceInfo;
        info->read(mData);
        if (info->read(mData))
            return info;
        else
            delete info;
    }
    return NULL;
}
@@ -346,8 +348,10 @@ MtpStorageInfo* MtpDevice::getStorageInfo(MtpStorageID storageID) {
    MtpResponseCode ret = readResponse();
    if (ret == MTP_RESPONSE_OK) {
        MtpStorageInfo* info = new MtpStorageInfo(storageID);
        info->read(mData);
        if (info->read(mData))
            return info;
        else
            delete info;
    }
    return NULL;
}
@@ -385,8 +389,10 @@ MtpObjectInfo* MtpDevice::getObjectInfo(MtpObjectHandle handle) {
    MtpResponseCode ret = readResponse();
    if (ret == MTP_RESPONSE_OK) {
        MtpObjectInfo* info = new MtpObjectInfo(handle);
        info->read(mData);
        if (info->read(mData))
            return info;
        else
            delete info;
    }
    return NULL;
}
@@ -547,8 +553,10 @@ MtpProperty* MtpDevice::getDevicePropDesc(MtpDeviceProperty code) {
    MtpResponseCode ret = readResponse();
    if (ret == MTP_RESPONSE_OK) {
        MtpProperty* property = new MtpProperty;
        property->read(mData);
        if (property->read(mData))
            return property;
        else
            delete property;
    }
    return NULL;
}
@@ -566,15 +574,17 @@ MtpProperty* MtpDevice::getObjectPropDesc(MtpObjectProperty code, MtpObjectForma
    MtpResponseCode ret = readResponse();
    if (ret == MTP_RESPONSE_OK) {
        MtpProperty* property = new MtpProperty;
        property->read(mData);
        if (property->read(mData))
            return property;
        else
            delete property;
    }
    return NULL;
}

bool MtpDevice::readObject(MtpObjectHandle handle,
        bool (* callback)(void* data, int offset, int length, void* clientData),
        int objectSize, void* clientData) {
        size_t objectSize, void* clientData) {
    Mutex::Autolock autoLock(mMutex);
    bool result = false;

+1 −1
Original line number Diff line number Diff line
@@ -98,7 +98,7 @@ public:
    bool                    readObject(MtpObjectHandle handle,
                                    bool (* callback)(void* data, int offset,
                                            int length, void* clientData),
                                    int objectSize, void* clientData);
                                    size_t objectSize, void* clientData);
    bool                    readObject(MtpObjectHandle handle, const char* destPath, int group,
                                    int perm);

+20 −13
Original line number Diff line number Diff line
@@ -28,7 +28,7 @@ MtpDeviceInfo::MtpDeviceInfo()
        mVendorExtensionID(0),
        mVendorExtensionVersion(0),
        mVendorExtensionDesc(NULL),
        mFunctionalCode(0),
        mFunctionalMode(0),
        mOperations(NULL),
        mEvents(NULL),
        mDeviceProperties(NULL),
@@ -59,39 +59,46 @@ MtpDeviceInfo::~MtpDeviceInfo() {
        free(mSerial);
}

void MtpDeviceInfo::read(MtpDataPacket& packet) {
bool MtpDeviceInfo::read(MtpDataPacket& packet) {
    MtpStringBuffer string;

    // read the device info
    mStandardVersion = packet.getUInt16();
    mVendorExtensionID = packet.getUInt32();
    mVendorExtensionVersion = packet.getUInt16();
    if (!packet.getUInt16(mStandardVersion)) return false;
    if (!packet.getUInt32(mVendorExtensionID)) return false;
    if (!packet.getUInt16(mVendorExtensionVersion)) return false;

    packet.getString(string);
    if (!packet.getString(string)) return false;
    mVendorExtensionDesc = strdup((const char *)string);

    mFunctionalCode = packet.getUInt16();
    if (!packet.getUInt16(mFunctionalMode)) return false;
    mOperations = packet.getAUInt16();
    if (!mOperations) return false;
    mEvents = packet.getAUInt16();
    if (!mEvents) return false;
    mDeviceProperties = packet.getAUInt16();
    if (!mDeviceProperties) return false;
    mCaptureFormats = packet.getAUInt16();
    if (!mCaptureFormats) return false;
    mPlaybackFormats = packet.getAUInt16();
    if (!mCaptureFormats) return false;

    packet.getString(string);
    if (!packet.getString(string)) return false;
    mManufacturer = strdup((const char *)string);
    packet.getString(string);
    if (!packet.getString(string)) return false;
    mModel = strdup((const char *)string);
    packet.getString(string);
    if (!packet.getString(string)) return false;
    mVersion = strdup((const char *)string);
    packet.getString(string);
    if (!packet.getString(string)) return false;
    mSerial = strdup((const char *)string);

    return true;
}

void MtpDeviceInfo::print() {
    ALOGV("Device Info:\n\tmStandardVersion: %d\n\tmVendorExtensionID: %d\n\tmVendorExtensionVersiony: %d\n",
            mStandardVersion, mVendorExtensionID, mVendorExtensionVersion);
    ALOGV("\tmVendorExtensionDesc: %s\n\tmFunctionalCode: %d\n\tmManufacturer: %s\n\tmModel: %s\n\tmVersion: %s\n\tmSerial: %s\n",
            mVendorExtensionDesc, mFunctionalCode, mManufacturer, mModel, mVersion, mSerial);
    ALOGV("\tmVendorExtensionDesc: %s\n\tmFunctionalMode: %d\n\tmManufacturer: %s\n\tmModel: %s\n\tmVersion: %s\n\tmSerial: %s\n",
            mVendorExtensionDesc, mFunctionalMode, mManufacturer, mModel, mVersion, mSerial);
}

}  // namespace android
Loading