Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 0c01f3d1 authored by Linus Nilsson's avatar Linus Nilsson
Browse files

Transcoder: Fix error with short clips

- Fixes a bug in the sample reader where an error was incorrectly
reported if one track reached EOS right before switching to
sequential access.
- Adds more sample reader tests for different combinations of
sample access patterns and access modes.

Bug: 153453392
Fixes: 173643110
Test: Unit test (MediaSampleReaderNDKTests)
Change-Id: I3b683c5d8eb18a5b57d419ce113e08b40363ba9e
parent fdb3e339
Loading
Loading
Loading
Loading
+9 −0
Original line number Original line Diff line number Diff line
@@ -99,6 +99,7 @@ bool MediaSampleReaderNDK::advanceExtractor_l() {
    }
    }


    if (!AMediaExtractor_advance(mExtractor)) {
    if (!AMediaExtractor_advance(mExtractor)) {
        LOG(DEBUG) << "  EOS in advanceExtractor_l";
        mEosReached = true;
        mEosReached = true;
        for (auto it = mTrackSignals.begin(); it != mTrackSignals.end(); ++it) {
        for (auto it = mTrackSignals.begin(); it != mTrackSignals.end(); ++it) {
            it->second.notify_all();
            it->second.notify_all();
@@ -137,6 +138,8 @@ media_status_t MediaSampleReaderNDK::seekExtractorBackwards_l(int64_t targetTime
        LOG(ERROR) << "Unable to seek to " << seekToTimeUs << ", target " << targetTimeUs;
        LOG(ERROR) << "Unable to seek to " << seekToTimeUs << ", target " << targetTimeUs;
        return status;
        return status;
    }
    }

    mEosReached = false;
    mExtractorTrackIndex = AMediaExtractor_getSampleTrackIndex(mExtractor);
    mExtractorTrackIndex = AMediaExtractor_getSampleTrackIndex(mExtractor);
    int64_t sampleTimeUs = AMediaExtractor_getSampleTime(mExtractor);
    int64_t sampleTimeUs = AMediaExtractor_getSampleTime(mExtractor);


@@ -233,6 +236,8 @@ media_status_t MediaSampleReaderNDK::selectTrack(int trackIndex) {
}
}


media_status_t MediaSampleReaderNDK::setEnforceSequentialAccess(bool enforce) {
media_status_t MediaSampleReaderNDK::setEnforceSequentialAccess(bool enforce) {
    LOG(DEBUG) << "setEnforceSequentialAccess( " << enforce << " )";

    std::scoped_lock lock(mExtractorMutex);
    std::scoped_lock lock(mExtractorMutex);


    if (mEnforceSequentialAccess && !enforce) {
    if (mEnforceSequentialAccess && !enforce) {
@@ -374,7 +379,11 @@ media_status_t MediaSampleReaderNDK::getSampleInfoForTrack(int trackIndex, Media
        info->presentationTimeUs = 0;
        info->presentationTimeUs = 0;
        info->flags = SAMPLE_FLAG_END_OF_STREAM;
        info->flags = SAMPLE_FLAG_END_OF_STREAM;
        info->size = 0;
        info->size = 0;
        LOG(DEBUG) << "  getSampleInfoForTrack #" << trackIndex << ": End Of Stream";
    } else {
        LOG(ERROR) << "  getSampleInfoForTrack #" << trackIndex << ": Error " << status;
    }
    }

    return status;
    return status;
}
}


+2 −3
Original line number Original line Diff line number Diff line
@@ -15,6 +15,8 @@ cc_defaults {


    shared_libs: [
    shared_libs: [
        "libbase",
        "libbase",
        "libbinder_ndk",
        "libcrypto",
        "libcutils",
        "libcutils",
        "libmediandk",
        "libmediandk",
        "libmediatranscoder_asan",
        "libmediatranscoder_asan",
@@ -59,7 +61,6 @@ cc_test {
    name: "MediaTrackTranscoderTests",
    name: "MediaTrackTranscoderTests",
    defaults: ["testdefaults"],
    defaults: ["testdefaults"],
    srcs: ["MediaTrackTranscoderTests.cpp"],
    srcs: ["MediaTrackTranscoderTests.cpp"],
    shared_libs: ["libbinder_ndk"],
}
}


// VideoTrackTranscoder unit test
// VideoTrackTranscoder unit test
@@ -74,7 +75,6 @@ cc_test {
    name: "PassthroughTrackTranscoderTests",
    name: "PassthroughTrackTranscoderTests",
    defaults: ["testdefaults"],
    defaults: ["testdefaults"],
    srcs: ["PassthroughTrackTranscoderTests.cpp"],
    srcs: ["PassthroughTrackTranscoderTests.cpp"],
    shared_libs: ["libcrypto"],
}
}


// MediaSampleWriter unit test
// MediaSampleWriter unit test
@@ -89,5 +89,4 @@ cc_test {
    name: "MediaTranscoderTests",
    name: "MediaTranscoderTests",
    defaults: ["testdefaults"],
    defaults: ["testdefaults"],
    srcs: ["MediaTranscoderTests.cpp"],
    srcs: ["MediaTranscoderTests.cpp"],
    shared_libs: ["libbinder_ndk"],
}
}
+230 −54
Original line number Original line Diff line number Diff line
@@ -25,39 +25,166 @@
#include <fcntl.h>
#include <fcntl.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <media/MediaSampleReaderNDK.h>
#include <media/MediaSampleReaderNDK.h>
#include <openssl/md5.h>
#include <utils/Timers.h>
#include <utils/Timers.h>


#include <cmath>
#include <cmath>
#include <mutex>
#include <mutex>
#include <thread>
#include <thread>


// TODO(b/153453392): Test more asset types and validate sample data from readSampleDataForTrack.
// TODO(b/153453392): Test more asset types (frame reordering?).
// TODO(b/153453392): Test for sequential and parallel (single thread and multi thread) access.
// TODO(b/153453392): Test for switching between sequential and parallel access in different points
//  of time.


namespace android {
namespace android {


#define SEC_TO_USEC(s) ((s)*1000 * 1000)
#define SEC_TO_USEC(s) ((s)*1000 * 1000)


/** Helper class for comparing sample data using checksums. */
class Sample {
public:
    Sample(uint32_t flags, int64_t timestamp, size_t size, const uint8_t* buffer)
          : mFlags{flags}, mTimestamp{timestamp}, mSize{size} {
        initChecksum(buffer);
    }

    Sample(AMediaExtractor* extractor) {
        mFlags = AMediaExtractor_getSampleFlags(extractor);
        mTimestamp = AMediaExtractor_getSampleTime(extractor);
        mSize = static_cast<size_t>(AMediaExtractor_getSampleSize(extractor));

        auto buffer = std::make_unique<uint8_t[]>(mSize);
        AMediaExtractor_readSampleData(extractor, buffer.get(), mSize);

        initChecksum(buffer.get());
    }

    void initChecksum(const uint8_t* buffer) {
        MD5_CTX md5Ctx;
        MD5_Init(&md5Ctx);
        MD5_Update(&md5Ctx, buffer, mSize);
        MD5_Final(mChecksum, &md5Ctx);
    }

    bool operator==(const Sample& rhs) const {
        return mSize == rhs.mSize && mFlags == rhs.mFlags && mTimestamp == rhs.mTimestamp &&
               memcmp(mChecksum, rhs.mChecksum, MD5_DIGEST_LENGTH) == 0;
    }

    uint32_t mFlags;
    int64_t mTimestamp;
    size_t mSize;
    uint8_t mChecksum[MD5_DIGEST_LENGTH];
};

/** Constant for selecting all samples. */
static constexpr int SAMPLE_COUNT_ALL = -1;

/**
 * Utility class to test different sample access patterns combined with sequential or parallel
 * sample access modes.
 */
class SampleAccessTester {
public:
    SampleAccessTester(int sourceFd, size_t fileSize) {
        mSampleReader = MediaSampleReaderNDK::createFromFd(sourceFd, 0, fileSize);
        EXPECT_TRUE(mSampleReader);

        mTrackCount = mSampleReader->getTrackCount();

        for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
            EXPECT_EQ(mSampleReader->selectTrack(trackIndex), AMEDIA_OK);
        }

        mSamples.resize(mTrackCount);
        mTrackThreads.resize(mTrackCount);
    }

    void getSampleInfo(int trackIndex) {
        MediaSampleInfo info;
        media_status_t status = mSampleReader->getSampleInfoForTrack(trackIndex, &info);
        EXPECT_EQ(status, AMEDIA_OK);
    }

    void readSamplesAsync(int trackIndex, int sampleCount) {
        mTrackThreads[trackIndex] = std::thread{[this, trackIndex, sampleCount] {
            int samplesRead = 0;
            MediaSampleInfo info;
            while (samplesRead < sampleCount || sampleCount == SAMPLE_COUNT_ALL) {
                media_status_t status = mSampleReader->getSampleInfoForTrack(trackIndex, &info);
                if (status != AMEDIA_OK) {
                    EXPECT_EQ(status, AMEDIA_ERROR_END_OF_STREAM);
                    EXPECT_TRUE((info.flags & SAMPLE_FLAG_END_OF_STREAM) != 0);
                    break;
                }
                ASSERT_TRUE((info.flags & SAMPLE_FLAG_END_OF_STREAM) == 0);

                auto buffer = std::make_unique<uint8_t[]>(info.size);
                status = mSampleReader->readSampleDataForTrack(trackIndex, buffer.get(), info.size);
                EXPECT_EQ(status, AMEDIA_OK);

                mSampleMutex.lock();
                const uint8_t* bufferPtr = buffer.get();
                mSamples[trackIndex].emplace_back(info.flags, info.presentationTimeUs, info.size,
                                                  bufferPtr);
                mSampleMutex.unlock();
                ++samplesRead;
            }
        }};
    }

    void readSamplesAsync(int sampleCount) {
        for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
            readSamplesAsync(trackIndex, sampleCount);
        }
    }

    void waitForTrack(int trackIndex) {
        ASSERT_TRUE(mTrackThreads[trackIndex].joinable());
        mTrackThreads[trackIndex].join();
    }

    void waitForTracks() {
        for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
            waitForTrack(trackIndex);
        }
    }

    void setEnforceSequentialAccess(bool enforce) {
        media_status_t status = mSampleReader->setEnforceSequentialAccess(enforce);
        EXPECT_EQ(status, AMEDIA_OK);
    }

    std::vector<std::vector<Sample>>& getSamples() { return mSamples; }

    std::shared_ptr<MediaSampleReader> mSampleReader;
    size_t mTrackCount;
    std::mutex mSampleMutex;
    std::vector<std::thread> mTrackThreads;
    std::vector<std::vector<Sample>> mSamples;
};

class MediaSampleReaderNDKTests : public ::testing::Test {
class MediaSampleReaderNDKTests : public ::testing::Test {
public:
public:
    MediaSampleReaderNDKTests() { LOG(DEBUG) << "MediaSampleReaderNDKTests created"; }
    MediaSampleReaderNDKTests() { LOG(DEBUG) << "MediaSampleReaderNDKTests created"; }


    void SetUp() override {
    void SetUp() override {
        LOG(DEBUG) << "MediaSampleReaderNDKTests set up";
        LOG(DEBUG) << "MediaSampleReaderNDKTests set up";

        // Need to start a thread pool to prevent AMediaExtractor binder calls from starving
        // (b/155663561).
        ABinderProcess_startThreadPool();

        const char* sourcePath =
        const char* sourcePath =
                "/data/local/tmp/TranscodingTestAssets/cubicle_avc_480x240_aac_24KHz.mp4";
                "/data/local/tmp/TranscodingTestAssets/cubicle_avc_480x240_aac_24KHz.mp4";


        mExtractor = AMediaExtractor_new();
        ASSERT_NE(mExtractor, nullptr);

        mSourceFd = open(sourcePath, O_RDONLY);
        mSourceFd = open(sourcePath, O_RDONLY);
        ASSERT_GT(mSourceFd, 0);
        ASSERT_GT(mSourceFd, 0);


        mFileSize = lseek(mSourceFd, 0, SEEK_END);
        mFileSize = lseek(mSourceFd, 0, SEEK_END);
        lseek(mSourceFd, 0, SEEK_SET);
        lseek(mSourceFd, 0, SEEK_SET);


        mExtractor = AMediaExtractor_new();
        ASSERT_NE(mExtractor, nullptr);

        media_status_t status =
        media_status_t status =
                AMediaExtractor_setDataSourceFd(mExtractor, mSourceFd, 0, mFileSize);
                AMediaExtractor_setDataSourceFd(mExtractor, mSourceFd, 0, mFileSize);
        ASSERT_EQ(status, AMEDIA_OK);
        ASSERT_EQ(status, AMEDIA_OK);
@@ -68,14 +195,14 @@ public:
        }
        }
    }
    }


    void initExtractorTimestamps() {
    void initExtractorSamples() {
        // Save all sample timestamps, per track, as reported by the extractor.
        if (mExtractorSamples.size() == mTrackCount) return;
        mExtractorTimestamps.resize(mTrackCount);

        // Save sample information, per track, as reported by the extractor.
        mExtractorSamples.resize(mTrackCount);
        do {
        do {
            const int trackIndex = AMediaExtractor_getSampleTrackIndex(mExtractor);
            const int trackIndex = AMediaExtractor_getSampleTrackIndex(mExtractor);
            const int64_t sampleTime = AMediaExtractor_getSampleTime(mExtractor);
            mExtractorSamples[trackIndex].emplace_back(mExtractor);

            mExtractorTimestamps[trackIndex].push_back(sampleTime);
        } while (AMediaExtractor_advance(mExtractor));
        } while (AMediaExtractor_advance(mExtractor));


        AMediaExtractor_seekTo(mExtractor, 0, AMEDIAEXTRACTOR_SEEK_PREVIOUS_SYNC);
        AMediaExtractor_seekTo(mExtractor, 0, AMEDIAEXTRACTOR_SEEK_PREVIOUS_SYNC);
@@ -104,6 +231,22 @@ public:
        return bitrates;
        return bitrates;
    }
    }


    void compareSamples(std::vector<std::vector<Sample>>& readerSamples) {
        initExtractorSamples();
        EXPECT_EQ(readerSamples.size(), mTrackCount);

        for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
            LOG(DEBUG) << "Track " << trackIndex << ", comparing "
                       << readerSamples[trackIndex].size() << " samples.";
            EXPECT_EQ(readerSamples[trackIndex].size(), mExtractorSamples[trackIndex].size());
            for (size_t sampleIndex = 0; sampleIndex < readerSamples[trackIndex].size();
                 sampleIndex++) {
                EXPECT_EQ(readerSamples[trackIndex][sampleIndex],
                          mExtractorSamples[trackIndex][sampleIndex]);
            }
        }
    }

    void TearDown() override {
    void TearDown() override {
        LOG(DEBUG) << "MediaSampleReaderNDKTests tear down";
        LOG(DEBUG) << "MediaSampleReaderNDKTests tear down";
        AMediaExtractor_delete(mExtractor);
        AMediaExtractor_delete(mExtractor);
@@ -116,58 +259,91 @@ public:
    size_t mTrackCount;
    size_t mTrackCount;
    int mSourceFd;
    int mSourceFd;
    size_t mFileSize;
    size_t mFileSize;
    std::vector<std::vector<int64_t>> mExtractorTimestamps;
    std::vector<std::vector<Sample>> mExtractorSamples;
};
};


TEST_F(MediaSampleReaderNDKTests, TestSampleTimes) {
/** Reads all samples from all tracks in parallel. */
    LOG(DEBUG) << "TestSampleTimes Starts";
TEST_F(MediaSampleReaderNDKTests, TestParallelSampleAccess) {
    LOG(DEBUG) << "TestParallelSampleAccess Starts";


    std::shared_ptr<MediaSampleReader> sampleReader =
    SampleAccessTester tester{mSourceFd, mFileSize};
            MediaSampleReaderNDK::createFromFd(mSourceFd, 0, mFileSize);
    tester.readSamplesAsync(SAMPLE_COUNT_ALL);
    ASSERT_TRUE(sampleReader);
    tester.waitForTracks();
    compareSamples(tester.getSamples());
}


    for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
/** Reads all samples from all tracks sequentially. */
        EXPECT_EQ(sampleReader->selectTrack(trackIndex), AMEDIA_OK);
TEST_F(MediaSampleReaderNDKTests, TestSequentialSampleAccess) {
    LOG(DEBUG) << "TestSequentialSampleAccess Starts";

    SampleAccessTester tester{mSourceFd, mFileSize};
    tester.setEnforceSequentialAccess(true);
    tester.readSamplesAsync(SAMPLE_COUNT_ALL);
    tester.waitForTracks();
    compareSamples(tester.getSamples());
}
}


    // Initialize the extractor timestamps.
/** Reads all samples from one track in parallel mode before switching to sequential mode. */
    initExtractorTimestamps();
TEST_F(MediaSampleReaderNDKTests, TestMixedSampleAccessTrackEOS) {
    LOG(DEBUG) << "TestMixedSampleAccessTrackEOS Starts";


    std::mutex timestampMutex;
    for (int readSampleInfoFlag = 0; readSampleInfoFlag <= 1; readSampleInfoFlag++) {
    std::vector<std::thread> trackThreads;
        for (int trackIndToEOS = 0; trackIndToEOS < mTrackCount; ++trackIndToEOS) {
    std::vector<std::vector<int64_t>> readerTimestamps(mTrackCount);
            LOG(DEBUG) << "Testing EOS of track " << trackIndToEOS;


    for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
            SampleAccessTester tester{mSourceFd, mFileSize};
        trackThreads.emplace_back([sampleReader, trackIndex, &timestampMutex, &readerTimestamps] {

            MediaSampleInfo info;
            // If the flag is set, read sample info from a different track before draining the track
            while (true) {
            // under test to force the reader to save the extractor position.
                media_status_t status = sampleReader->getSampleInfoForTrack(trackIndex, &info);
            if (readSampleInfoFlag) {
                if (status != AMEDIA_OK) {
                tester.getSampleInfo((trackIndToEOS + 1) % mTrackCount);
                    EXPECT_EQ(status, AMEDIA_ERROR_END_OF_STREAM);
                    EXPECT_TRUE((info.flags & SAMPLE_FLAG_END_OF_STREAM) != 0);
                    break;
            }
            }
                ASSERT_TRUE((info.flags & SAMPLE_FLAG_END_OF_STREAM) == 0);

                timestampMutex.lock();
            // Read all samples from one track before enabling sequential access
                readerTimestamps[trackIndex].push_back(info.presentationTimeUs);
            tester.readSamplesAsync(trackIndToEOS, SAMPLE_COUNT_ALL);
                timestampMutex.unlock();
            tester.waitForTrack(trackIndToEOS);
                sampleReader->advanceTrack(trackIndex);
            tester.setEnforceSequentialAccess(true);

            for (int trackIndex = 0; trackIndex < mTrackCount; ++trackIndex) {
                if (trackIndex == trackIndToEOS) continue;

                tester.readSamplesAsync(trackIndex, SAMPLE_COUNT_ALL);
                tester.waitForTrack(trackIndex);
            }
            }
        });

            compareSamples(tester.getSamples());
        }
        }
    }
}

/**
 * Reads different combinations of sample counts from all tracks in parallel mode before switching
 * to sequential mode and reading the rest of the samples.
 */
TEST_F(MediaSampleReaderNDKTests, TestMixedSampleAccess) {
    LOG(DEBUG) << "TestMixedSampleAccess Starts";
    initExtractorSamples();

    for (int trackIndToTest = 0; trackIndToTest < mTrackCount; ++trackIndToTest) {
        for (int sampleCount = 0; sampleCount <= (mExtractorSamples[trackIndToTest].size() + 1);
             ++sampleCount) {
            SampleAccessTester tester{mSourceFd, mFileSize};


    for (auto& thread : trackThreads) {
            for (int trackIndex = 0; trackIndex < mTrackCount; ++trackIndex) {
        thread.join();
                if (trackIndex == trackIndToTest) {
                    tester.readSamplesAsync(trackIndex, sampleCount);
                } else {
                    tester.readSamplesAsync(trackIndex, mExtractorSamples[trackIndex].size() / 2);
                }
            }
            }


    for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
            tester.waitForTracks();
        LOG(DEBUG) << "Track " << trackIndex << ", comparing "
            tester.setEnforceSequentialAccess(true);
                   << readerTimestamps[trackIndex].size() << " samples.";

        EXPECT_EQ(readerTimestamps[trackIndex].size(), mExtractorTimestamps[trackIndex].size());
            tester.readSamplesAsync(SAMPLE_COUNT_ALL);
        for (size_t sampleIndex = 0; sampleIndex < readerTimestamps[trackIndex].size();
            tester.waitForTracks();
             sampleIndex++) {

            EXPECT_EQ(readerTimestamps[trackIndex][sampleIndex],
            compareSamples(tester.getSamples());
                      mExtractorTimestamps[trackIndex][sampleIndex]);
        }
        }
    }
    }
}
}