Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 930bb620 authored by Shivaansh Agrawal's avatar Shivaansh Agrawal Committed by Ray Essick
Browse files

mp3dec: Fix out of bound read error

Add check for required number of bytes before stream read
while reading side info.
Modify bitstream read functions to only read required number of bytes

Bug: 154075955
Bug: 154076193
Test: POC in bug description
Test: atest android.mediav2.cts.CodecDecoderTest
Test: atest Mp3DecoderTest -- --enable-module-dynamic-download=true

Change-Id: I777f22d21cbf026056f1ac69de4bb763846b1a9d
(cherry picked from commit c57092d1)
parent 8579aa2b
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -219,6 +219,11 @@ ERROR_CODE pvmp3_framedecoder(tPVMP3DecoderExternal *pExt,

    if (info->error_protection)
    {
        if (!bitsAvailable(&pVars->inputStream, 16))
        {
            return SIDE_INFO_ERROR;
        }

        /*
         *  Get crc content
         */
+61 −0
Original line number Diff line number Diff line
@@ -73,6 +73,7 @@ Input

#include "pvmp3_get_side_info.h"
#include "pvmp3_crc.h"
#include "pvmp3_getbits.h"


/*----------------------------------------------------------------------------
@@ -125,12 +126,22 @@ ERROR_CODE pvmp3_get_side_info(tmp3Bits *inputStream,
    {
        if (stereo == 1)
        {
            if (!bitsAvailable(inputStream, 14))
            {
                return SIDE_INFO_ERROR;
            }

            tmp = getbits_crc(inputStream, 14, crc, info->error_protection);
            si->main_data_begin = (tmp << 18) >> 23;    /* 9 */
            si->private_bits    = (tmp << 27) >> 27;    /* 5 */
        }
        else
        {
            if (!bitsAvailable(inputStream, 12))
            {
                return SIDE_INFO_ERROR;
            }

            tmp = getbits_crc(inputStream, 12, crc, info->error_protection);
            si->main_data_begin = (tmp << 20) >> 23;    /* 9 */
            si->private_bits    = (tmp << 29) >> 29;    /* 3 */
@@ -139,6 +150,11 @@ ERROR_CODE pvmp3_get_side_info(tmp3Bits *inputStream,

        for (ch = 0; ch < stereo; ch++)
        {
            if (!bitsAvailable(inputStream, 4))
            {
                return SIDE_INFO_ERROR;
            }

            tmp = getbits_crc(inputStream, 4, crc, info->error_protection);
            si->ch[ch].scfsi[0] = (tmp << 28) >> 31;    /* 1 */
            si->ch[ch].scfsi[1] = (tmp << 29) >> 31;    /* 1 */
@@ -150,6 +166,11 @@ ERROR_CODE pvmp3_get_side_info(tmp3Bits *inputStream,
        {
            for (ch = 0; ch < stereo; ch++)
            {
                if (!bitsAvailable(inputStream, 34))
                {
                    return SIDE_INFO_ERROR;
                }

                si->ch[ch].gran[gr].part2_3_length    = getbits_crc(inputStream, 12, crc, info->error_protection);
                tmp = getbits_crc(inputStream, 22, crc, info->error_protection);

@@ -160,6 +181,11 @@ ERROR_CODE pvmp3_get_side_info(tmp3Bits *inputStream,

                if (si->ch[ch].gran[gr].window_switching_flag)
                {
                    if (!bitsAvailable(inputStream, 22))
                    {
                        return SIDE_INFO_ERROR;
                    }

                    tmp = getbits_crc(inputStream, 22, crc, info->error_protection);

                    si->ch[ch].gran[gr].block_type       = (tmp << 10) >> 30;   /* 2 */;
@@ -192,6 +218,11 @@ ERROR_CODE pvmp3_get_side_info(tmp3Bits *inputStream,
                }
                else
                {
                    if (!bitsAvailable(inputStream, 22))
                    {
                        return SIDE_INFO_ERROR;
                    }

                    tmp = getbits_crc(inputStream, 22, crc, info->error_protection);

                    si->ch[ch].gran[gr].table_select[0] = (tmp << 10) >> 27;   /* 5 */;
@@ -204,6 +235,11 @@ ERROR_CODE pvmp3_get_side_info(tmp3Bits *inputStream,
                    si->ch[ch].gran[gr].block_type      = 0;
                }

                if (!bitsAvailable(inputStream, 3))
                {
                    return SIDE_INFO_ERROR;
                }

                tmp = getbits_crc(inputStream, 3, crc, info->error_protection);
                si->ch[ch].gran[gr].preflag            = (tmp << 29) >> 31;    /* 1 */
                si->ch[ch].gran[gr].scalefac_scale     = (tmp << 30) >> 31;    /* 1 */
@@ -213,11 +249,21 @@ ERROR_CODE pvmp3_get_side_info(tmp3Bits *inputStream,
    }
    else /* Layer 3 LSF */
    {
        if (!bitsAvailable(inputStream, 8 + stereo))
        {
            return SIDE_INFO_ERROR;
        }

        si->main_data_begin = getbits_crc(inputStream,      8, crc, info->error_protection);
        si->private_bits    = getbits_crc(inputStream, stereo, crc, info->error_protection);

        for (ch = 0; ch < stereo; ch++)
        {
            if (!bitsAvailable(inputStream, 39))
            {
                return SIDE_INFO_ERROR;
            }

            tmp = getbits_crc(inputStream, 21, crc, info->error_protection);
            si->ch[ch].gran[0].part2_3_length    = (tmp << 11) >> 20;  /* 12 */
            si->ch[ch].gran[0].big_values        = (tmp << 23) >> 23;  /*  9 */
@@ -230,6 +276,11 @@ ERROR_CODE pvmp3_get_side_info(tmp3Bits *inputStream,
            if (si->ch[ch].gran[0].window_switching_flag)
            {

                if (!bitsAvailable(inputStream, 22))
                {
                    return SIDE_INFO_ERROR;
                }

                tmp = getbits_crc(inputStream, 22, crc, info->error_protection);

                si->ch[ch].gran[0].block_type       = (tmp << 10) >> 30;   /* 2 */;
@@ -262,6 +313,11 @@ ERROR_CODE pvmp3_get_side_info(tmp3Bits *inputStream,
            }
            else
            {
                if (!bitsAvailable(inputStream, 22))
                {
                    return SIDE_INFO_ERROR;
                }

                tmp = getbits_crc(inputStream, 22, crc, info->error_protection);

                si->ch[ch].gran[0].table_select[0] = (tmp << 10) >> 27;   /* 5 */;
@@ -274,6 +330,11 @@ ERROR_CODE pvmp3_get_side_info(tmp3Bits *inputStream,
                si->ch[ch].gran[0].block_type      = 0;
            }

            if (!bitsAvailable(inputStream, 2))
            {
                return SIDE_INFO_ERROR;
            }

            tmp = getbits_crc(inputStream, 2, crc, info->error_protection);
            si->ch[ch].gran[0].scalefac_scale     =  tmp >> 1;  /* 1 */
            si->ch[ch].gran[0].count1table_select =  tmp & 1;  /* 1 */
+65 −27
Original line number Diff line number Diff line
@@ -113,10 +113,11 @@ uint32 getNbits(tmp3Bits *ptBitStream,

    uint32    offset;
    uint32    bitIndex;
    uint8     Elem;         /* Needs to be same type as pInput->pBuffer */
    uint8     Elem1;
    uint8     Elem2;
    uint8     Elem3;
    uint32    bytesToFetch;
    uint8     Elem  = 0;         /* Needs to be same type as pInput->pBuffer */
    uint8     Elem1 = 0;
    uint8     Elem2 = 0;
    uint8     Elem3 = 0;
    uint32   returnValue = 0;

    if (!neededBits)
@@ -126,10 +127,25 @@ uint32 getNbits(tmp3Bits *ptBitStream,

    offset = (ptBitStream->usedBits) >> INBUF_ARRAY_INDEX_SHIFT;

    Elem  = *(ptBitStream->pBuffer + module(offset  , BUFSIZE));
    Elem1 = *(ptBitStream->pBuffer + module(offset + 1, BUFSIZE));
    Elem2 = *(ptBitStream->pBuffer + module(offset + 2, BUFSIZE));
    /* Remove extra high bits by shifting up */
    bitIndex = module(ptBitStream->usedBits, INBUF_BIT_WIDTH);

    bytesToFetch = (bitIndex + neededBits + 7 ) >> 3 ;

    switch (bytesToFetch)
    {
    case 4:
        Elem3 = *(ptBitStream->pBuffer + module(offset + 3, BUFSIZE));
        [[fallthrough]];
    case 3:
        Elem2 = *(ptBitStream->pBuffer + module(offset + 2, BUFSIZE));
        [[fallthrough]];
    case 2:
        Elem1 = *(ptBitStream->pBuffer + module(offset + 1, BUFSIZE));
        [[fallthrough]];
    case 1:
        Elem = *(ptBitStream->pBuffer + module(offset, BUFSIZE));
    }


    returnValue = (((uint32)(Elem)) << 24) |
@@ -137,9 +153,6 @@ uint32 getNbits(tmp3Bits *ptBitStream,
                  (((uint32)(Elem2)) << 8) |
                  ((uint32)(Elem3));

    /* Remove extra high bits by shifting up */
    bitIndex = module(ptBitStream->usedBits, INBUF_BIT_WIDTH);

    /* This line is faster than to mask off the high bits. */
    returnValue <<= bitIndex;

@@ -161,22 +174,32 @@ uint16 getUpTo9bits(tmp3Bits *ptBitStream,

    uint32    offset;
    uint32    bitIndex;
    uint8    Elem;         /* Needs to be same type as pInput->pBuffer */
    uint8    Elem1;
    uint32    bytesToFetch;
    uint8    Elem  = 0;         /* Needs to be same type as pInput->pBuffer */
    uint8    Elem1 = 0;
    uint16   returnValue;

    offset = (ptBitStream->usedBits) >> INBUF_ARRAY_INDEX_SHIFT;

    /* Remove extra high bits by shifting up */
    bitIndex = module(ptBitStream->usedBits, INBUF_BIT_WIDTH);

    bytesToFetch = (bitIndex + neededBits + 7 ) >> 3 ;

    if (bytesToFetch > 1)
    {
        Elem = *(ptBitStream->pBuffer + module(offset, BUFSIZE));
        Elem1 = *(ptBitStream->pBuffer + module(offset + 1, BUFSIZE));
    }
    else if (bytesToFetch > 0)
    {
        Elem = *(ptBitStream->pBuffer + module(offset, BUFSIZE));
    }


    returnValue = (((uint16)(Elem)) << 8) |
                  ((uint16)(Elem1));

    /* Remove extra high bits by shifting up */
    bitIndex = module(ptBitStream->usedBits, INBUF_BIT_WIDTH);

    ptBitStream->usedBits += neededBits;
    /* This line is faster than to mask off the high bits. */
    returnValue = (returnValue << (bitIndex));
@@ -197,25 +220,40 @@ uint32 getUpTo17bits(tmp3Bits *ptBitStream,

    uint32    offset;
    uint32    bitIndex;
    uint8     Elem;         /* Needs to be same type as pInput->pBuffer */
    uint8     Elem1;
    uint8     Elem2;
    uint32    bytesToFetch;
    uint8     Elem  = 0;         /* Needs to be same type as pInput->pBuffer */
    uint8     Elem1 = 0;
    uint8     Elem2 = 0;
    uint32   returnValue;

    offset = (ptBitStream->usedBits) >> INBUF_ARRAY_INDEX_SHIFT;

    /* Remove extra high bits by shifting up */
    bitIndex = module(ptBitStream->usedBits, INBUF_BIT_WIDTH);

    bytesToFetch = (bitIndex + neededBits + 7 ) >> 3 ;

    if (bytesToFetch > 2)
    {
        Elem  = *(ptBitStream->pBuffer + module(offset, BUFSIZE));
        Elem1 = *(ptBitStream->pBuffer + module(offset + 1, BUFSIZE));
        Elem2 = *(ptBitStream->pBuffer + module(offset + 2, BUFSIZE));
    }
    else if (bytesToFetch > 1)
    {
        Elem  = *(ptBitStream->pBuffer + module(offset, BUFSIZE));
        Elem1 = *(ptBitStream->pBuffer + module(offset + 1, BUFSIZE));
    }
    else if (bytesToFetch > 0)
    {
        Elem = *(ptBitStream->pBuffer + module(offset, BUFSIZE));
    }


    returnValue = (((uint32)(Elem)) << 16) |
                  (((uint32)(Elem1)) << 8) |
                  ((uint32)(Elem2));

    /* Remove extra high bits by shifting up */
    bitIndex = module(ptBitStream->usedBits, INBUF_BIT_WIDTH);

    ptBitStream->usedBits += neededBits;
    /* This line is faster than to mask off the high bits. */
    returnValue = 0xFFFFFF & (returnValue << (bitIndex));
+5 −0
Original line number Diff line number Diff line
@@ -104,6 +104,11 @@ extern "C"
; Function Prototype declaration
----------------------------------------------------------------------------*/

static inline bool bitsAvailable(tmp3Bits *inputStream, uint32 neededBits)
{
    return (inputStream->inputBufferCurrentLength << 3) >= (neededBits + inputStream->usedBits);
}

/*----------------------------------------------------------------------------
; END
----------------------------------------------------------------------------*/