Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 8c900061 authored by Linus Nilsson's avatar Linus Nilsson Committed by Android (Google) Code Review
Browse files

Merge "Transcoder: Add support for progress updates."

parents 6643c8fd e2cdd1f5
Loading
Loading
Loading
Loading
+45 −10
Original line number Diff line number Diff line
@@ -78,14 +78,14 @@ MediaSampleWriter::~MediaSampleWriter() {
    }
}

bool MediaSampleWriter::init(int fd, const OnWritingFinishedCallback& callback) {
    return init(DefaultMuxer::create(fd), callback);
bool MediaSampleWriter::init(int fd, const std::weak_ptr<CallbackInterface>& callbacks) {
    return init(DefaultMuxer::create(fd), callbacks);
}

bool MediaSampleWriter::init(const std::shared_ptr<MediaSampleWriterMuxerInterface>& muxer,
                             const OnWritingFinishedCallback& callback) {
    if (callback == nullptr) {
        LOG(ERROR) << "Callback cannot be null";
                             const std::weak_ptr<CallbackInterface>& callbacks) {
    if (callbacks.lock() == nullptr) {
        LOG(ERROR) << "Callback object cannot be null";
        return false;
    } else if (muxer == nullptr) {
        LOG(ERROR) << "Muxer cannot be null";
@@ -100,7 +100,7 @@ bool MediaSampleWriter::init(const std::shared_ptr<MediaSampleWriterMuxerInterfa

    mState = INITIALIZED;
    mMuxer = muxer;
    mWritingFinishedCallback = callback;
    mCallbacks = callbacks;
    return true;
}

@@ -127,7 +127,11 @@ bool MediaSampleWriter::addTrack(const std::shared_ptr<MediaSampleQueue>& sample
        durationUs = 0;
    }

    mTracks.emplace_back(sampleQueue, static_cast<size_t>(trackIndex), durationUs);
    const char* mime = nullptr;
    const bool isVideo = AMediaFormat_getString(trackFormat.get(), AMEDIAFORMAT_KEY_MIME, &mime) &&
                         (strncmp(mime, "video/", 6) == 0);

    mTracks.emplace_back(sampleQueue, static_cast<size_t>(trackIndex), durationUs, isVideo);
    return true;
}

@@ -144,7 +148,9 @@ bool MediaSampleWriter::start() {

    mThread = std::thread([this] {
        media_status_t status = writeSamples();
        mWritingFinishedCallback(status);
        if (auto callbacks = mCallbacks.lock()) {
            callbacks->onFinished(this, status);
        }
    });
    mState = STARTED;
    return true;
@@ -191,6 +197,18 @@ media_status_t MediaSampleWriter::runWriterLoop() {
    AMediaCodecBufferInfo bufferInfo;
    uint32_t segmentEndTimeUs = mTrackSegmentLengthUs;
    bool samplesLeft = true;
    int32_t lastProgressUpdate = 0;

    // Set the "primary" track that will be used to determine progress to the track with longest
    // duration.
    int primaryTrackIndex = -1;
    int64_t longestDurationUs = 0;
    for (int trackIndex = 0; trackIndex < mTracks.size(); ++trackIndex) {
        if (mTracks[trackIndex].mDurationUs > longestDurationUs) {
            primaryTrackIndex = trackIndex;
            longestDurationUs = mTracks[trackIndex].mDurationUs;
        }
    }

    while (samplesLeft) {
        samplesLeft = false;
@@ -216,9 +234,10 @@ media_status_t MediaSampleWriter::runWriterLoop() {
                    samplesLeft = true;
                }

                // Record the first sample's timestamp in order to translate duration to EOS time
                // for tracks that does not start at 0.
                track.mPrevSampleTimeUs = sample->info.presentationTimeUs;
                if (!track.mFirstSampleTimeSet) {
                    // Record the first sample's timestamp in order to translate duration to EOS
                    // time for tracks that does not start at 0.
                    track.mFirstSampleTimeUs = sample->info.presentationTimeUs;
                    track.mFirstSampleTimeSet = true;
                }
@@ -238,6 +257,22 @@ media_status_t MediaSampleWriter::runWriterLoop() {
            } while (sample->info.presentationTimeUs < segmentEndTimeUs && !track.mReachedEos);
        }

        // TODO(lnilsson): Add option to toggle progress reporting on/off.
        if (primaryTrackIndex >= 0) {
            const TrackRecord& track = mTracks[primaryTrackIndex];

            const int64_t elapsed = track.mPrevSampleTimeUs - track.mFirstSampleTimeUs;
            int32_t progress = (elapsed * 100) / track.mDurationUs;
            progress = std::clamp(progress, 0, 100);

            if (progress > lastProgressUpdate) {
                if (auto callbacks = mCallbacks.lock()) {
                    callbacks->onProgressUpdate(this, progress);
                }
                lastProgressUpdate = progress;
            }
        }

        segmentEndTimeUs += mTrackSegmentLengthUs;
    }

+7 −3
Original line number Diff line number Diff line
@@ -151,11 +151,16 @@ void MediaTranscoder::onTrackError(const MediaTrackTranscoder* transcoder, media
    sendCallback(status);
}

void MediaTranscoder::onSampleWriterFinished(media_status_t status) {
void MediaTranscoder::onFinished(const MediaSampleWriter* writer __unused, media_status_t status) {
    LOG((status != AMEDIA_OK) ? ERROR : DEBUG) << "Sample writer finished with status " << status;
    sendCallback(status);
}

void MediaTranscoder::onProgressUpdate(const MediaSampleWriter* writer __unused, int32_t progress) {
    // Dispatch progress updated to the client.
    mCallbacks->onProgressUpdate(this, progress);
}

MediaTranscoder::MediaTranscoder(const std::shared_ptr<CallbackInterface>& callbacks)
      : mCallbacks(callbacks) {}

@@ -288,8 +293,7 @@ media_status_t MediaTranscoder::configureDestination(int fd) {
    }

    mSampleWriter = std::make_unique<MediaSampleWriter>();
    const bool initOk = mSampleWriter->init(
            fd, std::bind(&MediaTranscoder::onSampleWriterFinished, this, std::placeholders::_1));
    const bool initOk = mSampleWriter->init(fd, shared_from_this());

    if (!initOk) {
        LOG(ERROR) << "Unable to initialize sample writer with destination fd: " << fd;
+26 −15
Original line number Diff line number Diff line
@@ -71,18 +71,27 @@ public:
    /** The default segment length. */
    static constexpr uint32_t kDefaultTrackSegmentLengthUs = 1 * 1000 * 1000;  // 1 sec.

    /** Client callback for when the writer is finished. */
    using OnWritingFinishedCallback = std::function<void(media_status_t)>;
    /** Callback interface. */
    class CallbackInterface {
    public:
        /**
         * Sample writer finished. The finished callback is only called after the sample writer has
         * been successfully started.
         */
        virtual void onFinished(const MediaSampleWriter* writer, media_status_t status) = 0;

        /** Sample writer progress update in percent. */
        virtual void onProgressUpdate(const MediaSampleWriter* writer, int32_t progress) = 0;

        virtual ~CallbackInterface() = default;
    };

    /**
     * Constructor with custom segment length.
     * @param trackSegmentLengthUs The segment length to use for this MediaSampleWriter.
     */
    MediaSampleWriter(uint32_t trackSegmentLengthUs)
          : mTrackSegmentLengthUs(trackSegmentLengthUs),
            mWritingFinishedCallback(nullptr),
            mMuxer(nullptr),
            mState(UNINITIALIZED){};
          : mTrackSegmentLengthUs(trackSegmentLengthUs), mMuxer(nullptr), mState(UNINITIALIZED){};

    /** Constructor using the default segment length. */
    MediaSampleWriter() : MediaSampleWriter(kDefaultTrackSegmentLengthUs){};
@@ -95,21 +104,19 @@ public:
     * to be initialized before tracks are added and can only be initialized once.
     * @param fd An open file descriptor to write to. The caller is responsible for closing this
     *        file descriptor and it is safe to do so once this method returns.
     * @param callback Client callback that gets called when the sample writer has finished, after
     *        it was successfully started.
     * @param callbacks Client callback object that gets called by the sample writer.
     * @return True if the writer was successfully initialized.
     */
    bool init(int fd, const OnWritingFinishedCallback& callback /* nonnull */);
    bool init(int fd, const std::weak_ptr<CallbackInterface>& callbacks /* nonnull */);

    /**
     * Initializes the sample writer with a custom muxer interface implementation.
     * @param muxer The custom muxer interface implementation.
     * @param callback Client callback that gets called when the sample writer has finished, after
     *        it was successfully started.
     * @param @param callbacks Client callback object that gets called by the sample writer.
     * @return True if the writer was successfully initialized.
     */
    bool init(const std::shared_ptr<MediaSampleWriterMuxerInterface>& muxer /* nonnull */,
              const OnWritingFinishedCallback& callback /* nonnull */);
              const std::weak_ptr<CallbackInterface>& callbacks /* nonnull */);

    /**
     * Adds a new track to the sample writer. Tracks must be added after the sample writer has been
@@ -145,24 +152,28 @@ private:

    struct TrackRecord {
        TrackRecord(const std::shared_ptr<MediaSampleQueue>& sampleQueue, size_t trackIndex,
                    int64_t durationUs)
                    int64_t durationUs, bool isVideo)
              : mSampleQueue(sampleQueue),
                mTrackIndex(trackIndex),
                mDurationUs(durationUs),
                mFirstSampleTimeUs(0),
                mPrevSampleTimeUs(0),
                mFirstSampleTimeSet(false),
                mReachedEos(false) {}
                mReachedEos(false),
                mIsVideo(isVideo) {}

        std::shared_ptr<MediaSampleQueue> mSampleQueue;
        const size_t mTrackIndex;
        int64_t mDurationUs;
        int64_t mFirstSampleTimeUs;
        int64_t mPrevSampleTimeUs;
        bool mFirstSampleTimeSet;
        bool mReachedEos;
        bool mIsVideo;
    };

    const uint32_t mTrackSegmentLengthUs;
    OnWritingFinishedCallback mWritingFinishedCallback;
    std::weak_ptr<CallbackInterface> mCallbacks;
    std::shared_ptr<MediaSampleWriterMuxerInterface> mMuxer;
    std::vector<TrackRecord> mTracks;
    std::thread mThread;
+11 −2
Original line number Diff line number Diff line
@@ -17,6 +17,9 @@
#ifndef ANDROID_MEDIA_TRANSCODER_H
#define ANDROID_MEDIA_TRANSCODER_H

#include <binder/Parcel.h>
#include <binder/Parcelable.h>
#include <media/MediaSampleWriter.h>
#include <media/MediaTrackTranscoderCallback.h>
#include <media/NdkMediaError.h>
#include <media/NdkMediaFormat.h>
@@ -30,11 +33,11 @@
namespace android {

class MediaSampleReader;
class MediaSampleWriter;
class Parcel;

class MediaTranscoder : public std::enable_shared_from_this<MediaTranscoder>,
                        public MediaTrackTranscoderCallback {
                        public MediaTrackTranscoderCallback,
                        public MediaSampleWriter::CallbackInterface {
public:
    /** Callbacks from transcoder to client. */
    class CallbackInterface {
@@ -126,6 +129,12 @@ private:
    virtual void onTrackError(const MediaTrackTranscoder* transcoder,
                              media_status_t status) override;
    // ~MediaTrackTranscoderCallback

    // MediaSampleWriter::CallbackInterface
    virtual void onFinished(const MediaSampleWriter* writer, media_status_t status) override;
    virtual void onProgressUpdate(const MediaSampleWriter* writer, int32_t progress) override;
    // ~MediaSampleWriter::CallbackInterface

    void onSampleWriterFinished(media_status_t status);
    void sendCallback(media_status_t status);

+115 −57
Original line number Diff line number Diff line
@@ -32,28 +32,6 @@

namespace android {

/** Minimal one-shot semaphore */
class SimpleSemaphore {
public:
    void signal() {
        std::unique_lock<std::mutex> lock(mMutex);
        mSignaled = true;
        mCondition.notify_all();
    }

    void wait() {
        std::unique_lock<std::mutex> lock(mMutex);
        while (!mSignaled) {
            mCondition.wait(lock);
        }
    }

private:
    std::mutex mMutex;
    std::condition_variable mCondition;
    bool mSignaled = false;
};

/** Muxer interface to enable MediaSampleWriter testing. */
class TestMuxer : public MediaSampleWriterMuxerInterface {
public:
@@ -151,11 +129,22 @@ public:
        for (size_t trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
            AMediaFormat* trackFormat = AMediaExtractor_getTrackFormat(mExtractor, trackIndex);
            ASSERT_NE(trackFormat, nullptr);

            const char* mime = nullptr;
            AMediaFormat_getString(trackFormat, AMEDIAFORMAT_KEY_MIME, &mime);
            if (strncmp(mime, "video/", 6) == 0) {
                mVideoTrackIndex = trackIndex;
            } else if (strncmp(mime, "audio/", 6) == 0) {
                mAudioTrackIndex = trackIndex;
            }

            mTrackFormats.push_back(
                    std::shared_ptr<AMediaFormat>(trackFormat, &AMediaFormat_delete));

            AMediaExtractor_selectTrack(mExtractor, trackIndex);
        }
        EXPECT_GE(mVideoTrackIndex, 0);
        EXPECT_GE(mAudioTrackIndex, 0);
    }

    void reset() const {
@@ -167,6 +156,60 @@ public:
    AMediaExtractor* mExtractor = nullptr;
    size_t mTrackCount = 0;
    std::vector<std::shared_ptr<AMediaFormat>> mTrackFormats;
    int mVideoTrackIndex = -1;
    int mAudioTrackIndex = -1;
};

class TestCallbacks : public MediaSampleWriter::CallbackInterface {
public:
    TestCallbacks(bool expectSuccess = true) : mExpectSuccess(expectSuccess) {}

    bool hasFinished() {
        std::unique_lock<std::mutex> lock(mMutex);
        return mFinished;
    }

    // MediaSampleWriter::CallbackInterface
    virtual void onFinished(const MediaSampleWriter* writer __unused,
                            media_status_t status) override {
        std::unique_lock<std::mutex> lock(mMutex);
        EXPECT_FALSE(mFinished);
        if (mExpectSuccess) {
            EXPECT_EQ(status, AMEDIA_OK);
        } else {
            EXPECT_NE(status, AMEDIA_OK);
        }
        mFinished = true;
        mCondition.notify_all();
    }

    virtual void onProgressUpdate(const MediaSampleWriter* writer __unused,
                                  int32_t progress) override {
        EXPECT_GT(progress, mLastProgress);
        EXPECT_GE(progress, 0);
        EXPECT_LE(progress, 100);

        mLastProgress = progress;
        mProgressUpdateCount++;
    }
    // ~MediaSampleWriter::CallbackInterface

    void waitForWritingFinished() {
        std::unique_lock<std::mutex> lock(mMutex);
        while (!mFinished) {
            mCondition.wait(lock);
        }
    }

    uint32_t getProgressUpdateCount() const { return mProgressUpdateCount; }

private:
    std::mutex mMutex;
    std::condition_variable mCondition;
    bool mFinished = false;
    bool mExpectSuccess;
    int32_t mLastProgress = -1;
    uint32_t mProgressUpdateCount = 0;
};

class MediaSampleWriterTests : public ::testing::Test {
@@ -222,7 +265,7 @@ public:
protected:
    std::shared_ptr<TestMuxer> mTestMuxer;
    std::shared_ptr<MediaSampleQueue> mSampleQueue;
    const MediaSampleWriter::OnWritingFinishedCallback mEmptyCallback = [](media_status_t) {};
    std::shared_ptr<TestCallbacks> mTestCallbacks = std::make_shared<TestCallbacks>();
};

TEST_F(MediaSampleWriterTests, TestAddTrackWithoutInit) {
@@ -239,14 +282,14 @@ TEST_F(MediaSampleWriterTests, TestStartWithoutInit) {

TEST_F(MediaSampleWriterTests, TestStartWithoutTracks) {
    MediaSampleWriter writer{};
    EXPECT_TRUE(writer.init(mTestMuxer, mEmptyCallback));
    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));
    EXPECT_FALSE(writer.start());
    EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::NoEvent);
}

TEST_F(MediaSampleWriterTests, TestAddInvalidTrack) {
    MediaSampleWriter writer{};
    EXPECT_TRUE(writer.init(mTestMuxer, mEmptyCallback));
    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));

    EXPECT_FALSE(writer.addTrack(mSampleQueue, nullptr));
    EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::NoEvent);
@@ -259,31 +302,25 @@ TEST_F(MediaSampleWriterTests, TestAddInvalidTrack) {
TEST_F(MediaSampleWriterTests, TestDoubleStartStop) {
    MediaSampleWriter writer{};

    bool callbackFired = false;
    MediaSampleWriter::OnWritingFinishedCallback stoppedCallback =
            [&callbackFired](media_status_t status) {
                EXPECT_NE(status, AMEDIA_OK);
                EXPECT_FALSE(callbackFired);
                callbackFired = true;
            };

    EXPECT_TRUE(writer.init(mTestMuxer, stoppedCallback));
    std::shared_ptr<TestCallbacks> callbacks =
            std::make_shared<TestCallbacks>(false /* expectSuccess */);
    EXPECT_TRUE(writer.init(mTestMuxer, callbacks));

    const TestMediaSource& mediaSource = getMediaSource();
    EXPECT_TRUE(writer.addTrack(mSampleQueue, mediaSource.mTrackFormats[0]));
    EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::AddTrack(mediaSource.mTrackFormats[0].get()));

    EXPECT_TRUE(writer.start());
    ASSERT_TRUE(writer.start());
    EXPECT_FALSE(writer.start());

    EXPECT_TRUE(writer.stop());
    EXPECT_TRUE(callbackFired);
    EXPECT_TRUE(callbacks->hasFinished());
    EXPECT_FALSE(writer.stop());
}

TEST_F(MediaSampleWriterTests, TestStopWithoutStart) {
    MediaSampleWriter writer{};
    EXPECT_TRUE(writer.init(mTestMuxer, mEmptyCallback));
    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));

    const TestMediaSource& mediaSource = getMediaSource();
    EXPECT_TRUE(writer.addTrack(mSampleQueue, mediaSource.mTrackFormats[0]));
@@ -295,22 +332,48 @@ TEST_F(MediaSampleWriterTests, TestStopWithoutStart) {

TEST_F(MediaSampleWriterTests, TestStartWithoutCallback) {
    MediaSampleWriter writer{};
    EXPECT_FALSE(writer.init(mTestMuxer, nullptr));

    std::weak_ptr<MediaSampleWriter::CallbackInterface> unassignedWp;
    EXPECT_FALSE(writer.init(mTestMuxer, unassignedWp));

    std::shared_ptr<MediaSampleWriter::CallbackInterface> unassignedSp;
    EXPECT_FALSE(writer.init(mTestMuxer, unassignedSp));

    const TestMediaSource& mediaSource = getMediaSource();
    EXPECT_FALSE(writer.addTrack(mSampleQueue, mediaSource.mTrackFormats[0]));
    ASSERT_FALSE(writer.start());
}

TEST_F(MediaSampleWriterTests, TestProgressUpdate) {
    static constexpr uint32_t kSegmentLengthUs = 1;
    const TestMediaSource& mediaSource = getMediaSource();

    MediaSampleWriter writer{kSegmentLengthUs};
    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));

    std::shared_ptr<AMediaFormat> videoFormat =
            std::shared_ptr<AMediaFormat>(AMediaFormat_new(), &AMediaFormat_delete);
    AMediaFormat_copy(videoFormat.get(),
                      mediaSource.mTrackFormats[mediaSource.mVideoTrackIndex].get());

    AMediaFormat_setInt64(videoFormat.get(), AMEDIAFORMAT_KEY_DURATION, 100);
    EXPECT_TRUE(writer.addTrack(mSampleQueue, videoFormat));
    ASSERT_TRUE(writer.start());

    for (int64_t pts = 0; pts < 100; ++pts) {
        mSampleQueue->enqueue(newSampleWithPts(pts));
    }
    mSampleQueue->enqueue(newSampleEos());
    mTestCallbacks->waitForWritingFinished();

    EXPECT_EQ(mTestCallbacks->getProgressUpdateCount(), 100);
}

TEST_F(MediaSampleWriterTests, TestInterleaving) {
    static constexpr uint32_t kSegmentLength = MediaSampleWriter::kDefaultTrackSegmentLengthUs;
    SimpleSemaphore semaphore;

    MediaSampleWriter writer{kSegmentLength};
    EXPECT_TRUE(writer.init(mTestMuxer, [&semaphore](media_status_t status) {
        EXPECT_EQ(status, AMEDIA_OK);
        semaphore.signal();
    }));
    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));

    // Use two tracks for this test.
    static constexpr int kNumTracks = 2;
@@ -356,7 +419,7 @@ TEST_F(MediaSampleWriterTests, TestInterleaving) {
    ASSERT_TRUE(writer.start());

    // Wait for writer to complete.
    semaphore.wait();
    mTestCallbacks->waitForWritingFinished();
    EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Start());

    // Verify sample order.
@@ -386,16 +449,14 @@ TEST_F(MediaSampleWriterTests, TestInterleaving) {

    EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Stop());
    EXPECT_TRUE(writer.stop());
    EXPECT_TRUE(mTestCallbacks->hasFinished());
}

TEST_F(MediaSampleWriterTests, TestAbortInputQueue) {
    SimpleSemaphore semaphore;

    MediaSampleWriter writer{};
    EXPECT_TRUE(writer.init(mTestMuxer, [&semaphore](media_status_t status) {
        EXPECT_NE(status, AMEDIA_OK);
        semaphore.signal();
    }));
    std::shared_ptr<TestCallbacks> callbacks =
            std::make_shared<TestCallbacks>(false /* expectSuccess */);
    EXPECT_TRUE(writer.init(mTestMuxer, callbacks));

    // Use two tracks for this test.
    static constexpr int kNumTracks = 2;
@@ -417,7 +478,8 @@ TEST_F(MediaSampleWriterTests, TestAbortInputQueue) {
    for (int trackIdx = 0; trackIdx < kNumTracks; ++trackIdx) {
        sampleQueues[trackIdx]->abort();
    }
    semaphore.wait();

    callbacks->waitForWritingFinished();

    EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Start());
    EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Stop());
@@ -465,12 +527,8 @@ TEST_F(MediaSampleWriterTests, TestDefaultMuxer) {
    ASSERT_GT(destinationFd, 0);

    // Initialize writer.
    SimpleSemaphore semaphore;
    MediaSampleWriter writer{};
    EXPECT_TRUE(writer.init(destinationFd, [&semaphore](media_status_t status) {
        EXPECT_EQ(status, AMEDIA_OK);
        semaphore.signal();
    }));
    EXPECT_TRUE(writer.init(destinationFd, mTestCallbacks));
    close(destinationFd);

    // Add tracks.
@@ -497,7 +555,7 @@ TEST_F(MediaSampleWriterTests, TestDefaultMuxer) {
    }

    // Wait for writer.
    semaphore.wait();
    mTestCallbacks->waitForWritingFinished();
    EXPECT_TRUE(writer.stop());

    // Compare output file with source.