Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 450e83eb authored by Josh Gao's avatar Josh Gao Committed by Gerrit Code Review
Browse files

Merge "adb: implement zstd compression for file sync."

parents f7fed04a 317d3e17
Loading
Loading
Loading
Loading
+5 −1
Original line number Diff line number Diff line
@@ -124,11 +124,12 @@ cc_defaults {
        "libadbd_core",
        "libadbconnection_server",
        "libasyncio",
        "libbase",
        "libbrotli",
        "libcutils_sockets",
        "libdiagnose_usb",
        "libmdnssd",
        "libbase",
        "libzstd",

        "libadb_protos",
        "libapp_processes_protos_lite",
@@ -351,6 +352,7 @@ cc_binary_host {
        "liblog",
        "libziparchive",
        "libz",
        "libzstd",
    ],

    // Don't add anything here, we don't want additional shared dependencies
@@ -483,6 +485,7 @@ cc_library {
        "libbrotli",
        "libdiagnose_usb",
        "liblz4",
        "libzstd",
    ],

    shared_libs: [
@@ -586,6 +589,7 @@ cc_library {
        "libdiagnose_usb",
        "liblz4",
        "libmdnssd",
        "libzstd",
    ],

    visibility: [
+2 −0
Original line number Diff line number Diff line
@@ -1336,6 +1336,8 @@ static CompressionType parse_compression_type(const std::string& str, bool allow
        return CompressionType::Brotli;
    } else if (str == "lz4") {
        return CompressionType::LZ4;
    } else if (str == "zstd") {
        return CompressionType::Zstd;
    }

    error_exit("unexpected compression type %s", str.c_str());
+26 −3
Original line number Diff line number Diff line
@@ -240,6 +240,7 @@ class SyncConnection {
            have_sendrecv_v2_ = CanUseFeature(*features, kFeatureSendRecv2);
            have_sendrecv_v2_brotli_ = CanUseFeature(*features, kFeatureSendRecv2Brotli);
            have_sendrecv_v2_lz4_ = CanUseFeature(*features, kFeatureSendRecv2LZ4);
            have_sendrecv_v2_zstd_ = CanUseFeature(*features, kFeatureSendRecv2Zstd);
            have_sendrecv_v2_dry_run_send_ = CanUseFeature(*features, kFeatureSendRecv2DryRunSend);
            std::string error;
            fd.reset(adb_connect("sync:", &error));
@@ -268,13 +269,16 @@ class SyncConnection {
    bool HaveSendRecv2() const { return have_sendrecv_v2_; }
    bool HaveSendRecv2Brotli() const { return have_sendrecv_v2_brotli_; }
    bool HaveSendRecv2LZ4() const { return have_sendrecv_v2_lz4_; }
    bool HaveSendRecv2Zstd() const { return have_sendrecv_v2_zstd_; }
    bool HaveSendRecv2DryRunSend() const { return have_sendrecv_v2_dry_run_send_; }

    // Resolve a compression type which might be CompressionType::Any to a specific compression
    // algorithm.
    CompressionType ResolveCompressionType(CompressionType compression) const {
        if (compression == CompressionType::Any) {
            if (HaveSendRecv2LZ4()) {
            if (HaveSendRecv2Zstd()) {
                return CompressionType::Zstd;
            } else if (HaveSendRecv2LZ4()) {
                return CompressionType::LZ4;
            } else if (HaveSendRecv2Brotli()) {
                return CompressionType::Brotli;
@@ -374,6 +378,10 @@ class SyncConnection {
                msg.send_v2_setup.flags = kSyncFlagLZ4;
                break;

            case CompressionType::Zstd:
                msg.send_v2_setup.flags = kSyncFlagZstd;
                break;

            case CompressionType::Any:
                LOG(FATAL) << "unexpected CompressionType::Any";
        }
@@ -421,6 +429,10 @@ class SyncConnection {
                msg.recv_v2_setup.flags |= kSyncFlagLZ4;
                break;

            case CompressionType::Zstd:
                msg.recv_v2_setup.flags |= kSyncFlagZstd;
                break;

            case CompressionType::Any:
                LOG(FATAL) << "unexpected CompressionType::Any";
        }
@@ -631,7 +643,8 @@ class SyncConnection {
        syncsendbuf sbuf;
        sbuf.id = ID_DATA;

        std::variant<std::monostate, NullEncoder, BrotliEncoder, LZ4Encoder> encoder_storage;
        std::variant<std::monostate, NullEncoder, BrotliEncoder, LZ4Encoder, ZstdEncoder>
                encoder_storage;
        Encoder* encoder = nullptr;
        switch (compression) {
            case CompressionType::None:
@@ -646,6 +659,10 @@ class SyncConnection {
                encoder = &encoder_storage.emplace<LZ4Encoder>(SYNC_DATA_MAX);
                break;

            case CompressionType::Zstd:
                encoder = &encoder_storage.emplace<ZstdEncoder>(SYNC_DATA_MAX);
                break;

            case CompressionType::Any:
                LOG(FATAL) << "unexpected CompressionType::Any";
        }
@@ -928,6 +945,7 @@ class SyncConnection {
    bool have_sendrecv_v2_;
    bool have_sendrecv_v2_brotli_;
    bool have_sendrecv_v2_lz4_;
    bool have_sendrecv_v2_zstd_;
    bool have_sendrecv_v2_dry_run_send_;

    TransferLedger global_ledger_;
@@ -1133,7 +1151,8 @@ static bool sync_recv_v2(SyncConnection& sc, const char* rpath, const char* lpat
    uint64_t bytes_copied = 0;

    Block buffer(SYNC_DATA_MAX);
    std::variant<std::monostate, NullDecoder, BrotliDecoder, LZ4Decoder> decoder_storage;
    std::variant<std::monostate, NullDecoder, BrotliDecoder, LZ4Decoder, ZstdDecoder>
            decoder_storage;
    Decoder* decoder = nullptr;

    std::span buffer_span(buffer.data(), buffer.size());
@@ -1150,6 +1169,10 @@ static bool sync_recv_v2(SyncConnection& sc, const char* rpath, const char* lpat
            decoder = &decoder_storage.emplace<LZ4Decoder>(buffer_span);
            break;

        case CompressionType::Zstd:
            decoder = &decoder_storage.emplace<ZstdDecoder>(buffer_span);
            break;

        case CompressionType::Any:
            LOG(FATAL) << "unexpected CompressionType::Any";
    }
+103 −0
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@
#include <brotli/decode.h>
#include <brotli/encode.h>
#include <lz4frame.h>
#include <zstd.h>

#include "types.h"

@@ -381,3 +382,105 @@ struct LZ4Encoder final : public Encoder {
    std::unique_ptr<LZ4F_cctx, LZ4F_errorCode_t (*)(LZ4F_cctx*)> encoder_;
    IOVector output_buffer_;
};

struct ZstdDecoder final : public Decoder {
    explicit ZstdDecoder(std::span<char> output_buffer)
        : Decoder(output_buffer), decoder_(ZSTD_createDStream(), ZSTD_freeDStream) {
        if (!decoder_) {
            LOG(FATAL) << "failed to initialize Zstd decompression context";
        }
    }

    DecodeResult Decode(std::span<char>* output) final {
        ZSTD_inBuffer in;
        in.src = input_buffer_.front_data();
        in.size = input_buffer_.front_size();
        in.pos = 0;

        ZSTD_outBuffer out;
        out.dst = output_buffer_.data();
        // The standard specifies size() as returning size_t, but our current version of
        // libc++ returns a signed value instead.
        out.size = static_cast<size_t>(output_buffer_.size());
        out.pos = 0;

        size_t rc = ZSTD_decompressStream(decoder_.get(), &out, &in);
        if (ZSTD_isError(rc)) {
            LOG(ERROR) << "ZSTD_decompressStream failed: " << ZSTD_getErrorName(rc);
            return DecodeResult::Error;
        }

        input_buffer_.drop_front(in.pos);
        if (rc == 0) {
            if (!input_buffer_.empty()) {
                LOG(ERROR) << "Zstd stream hit end before reading all data";
                return DecodeResult::Error;
            }
            zstd_done_ = true;
        }

        *output = std::span<char>(output_buffer_.data(), out.pos);

        if (finished_) {
            return input_buffer_.empty() && zstd_done_ ? DecodeResult::Done
                                                       : DecodeResult::MoreOutput;
        }
        return DecodeResult::NeedInput;
    }

  private:
    bool zstd_done_ = false;
    std::unique_ptr<ZSTD_DStream, size_t (*)(ZSTD_DStream*)> decoder_;
};

struct ZstdEncoder final : public Encoder {
    explicit ZstdEncoder(size_t output_block_size)
        : Encoder(output_block_size), encoder_(ZSTD_createCStream(), ZSTD_freeCStream) {
        if (!encoder_) {
            LOG(FATAL) << "failed to initialize Zstd compression context";
        }
        ZSTD_CCtx_setParameter(encoder_.get(), ZSTD_c_compressionLevel, 1);
    }

    EncodeResult Encode(Block* output) final {
        ZSTD_inBuffer in;
        in.src = input_buffer_.front_data();
        in.size = input_buffer_.front_size();
        in.pos = 0;

        output->resize(output_block_size_);

        ZSTD_outBuffer out;
        out.dst = output->data();
        out.size = static_cast<size_t>(output->size());
        out.pos = 0;

        ZSTD_EndDirective end_directive = finished_ ? ZSTD_e_end : ZSTD_e_continue;
        size_t rc = ZSTD_compressStream2(encoder_.get(), &out, &in, end_directive);
        if (ZSTD_isError(rc)) {
            LOG(ERROR) << "ZSTD_compressStream2 failed: " << ZSTD_getErrorName(rc);
            return EncodeResult::Error;
        }

        input_buffer_.drop_front(in.pos);
        output->resize(out.pos);

        if (rc == 0) {
            // Zstd finished flushing its data.
            if (finished_) {
                if (!input_buffer_.empty()) {
                    LOG(ERROR) << "ZSTD_compressStream2 finished early";
                    return EncodeResult::Error;
                }
                return EncodeResult::Done;
            } else {
                return input_buffer_.empty() ? EncodeResult::NeedInput : EncodeResult::MoreOutput;
            }
        } else {
            return EncodeResult::MoreOutput;
        }
    }

  private:
    std::unique_ptr<ZSTD_CStream, size_t (*)(ZSTD_CStream*)> encoder_;
};
+30 −2
Original line number Diff line number Diff line
@@ -272,7 +272,8 @@ static bool handle_send_file_data(borrowed_fd s, unique_fd fd, uint32_t* timesta
    syncmsg msg;
    Block buffer(SYNC_DATA_MAX);
    std::span<char> buffer_span(buffer.data(), buffer.size());
    std::variant<std::monostate, NullDecoder, BrotliDecoder, LZ4Decoder> decoder_storage;
    std::variant<std::monostate, NullDecoder, BrotliDecoder, LZ4Decoder, ZstdDecoder>
            decoder_storage;
    Decoder* decoder = nullptr;

    switch (compression) {
@@ -288,6 +289,10 @@ static bool handle_send_file_data(borrowed_fd s, unique_fd fd, uint32_t* timesta
            decoder = &decoder_storage.emplace<LZ4Decoder>(buffer_span);
            break;

        case CompressionType::Zstd:
            decoder = &decoder_storage.emplace<ZstdDecoder>(buffer_span);
            break;

        case CompressionType::Any:
            LOG(FATAL) << "unexpected CompressionType::Any";
    }
@@ -590,6 +595,15 @@ static bool do_send_v2(int s, const std::string& path, std::vector<char>& buffer
        }
        compression = CompressionType::LZ4;
    }
    if (msg.send_v2_setup.flags & kSyncFlagZstd) {
        msg.send_v2_setup.flags &= ~kSyncFlagZstd;
        if (compression) {
            SendSyncFail(s, android::base::StringPrintf("multiple compression flags received: %d",
                                                        orig_flags));
            return false;
        }
        compression = CompressionType::Zstd;
    }
    if (msg.send_v2_setup.flags & kSyncFlagDryRun) {
        msg.send_v2_setup.flags &= ~kSyncFlagDryRun;
        dry_run = true;
@@ -623,7 +637,8 @@ static bool recv_impl(borrowed_fd s, const char* path, CompressionType compressi
    syncmsg msg;
    msg.data.id = ID_DATA;

    std::variant<std::monostate, NullEncoder, BrotliEncoder, LZ4Encoder> encoder_storage;
    std::variant<std::monostate, NullEncoder, BrotliEncoder, LZ4Encoder, ZstdEncoder>
            encoder_storage;
    Encoder* encoder;

    switch (compression) {
@@ -639,6 +654,10 @@ static bool recv_impl(borrowed_fd s, const char* path, CompressionType compressi
            encoder = &encoder_storage.emplace<LZ4Encoder>(SYNC_DATA_MAX);
            break;

        case CompressionType::Zstd:
            encoder = &encoder_storage.emplace<ZstdEncoder>(SYNC_DATA_MAX);
            break;

        case CompressionType::Any:
            LOG(FATAL) << "unexpected CompressionType::Any";
    }
@@ -726,6 +745,15 @@ static bool do_recv_v2(borrowed_fd s, const char* path, std::vector<char>& buffe
        }
        compression = CompressionType::LZ4;
    }
    if (msg.recv_v2_setup.flags & kSyncFlagZstd) {
        msg.recv_v2_setup.flags &= ~kSyncFlagZstd;
        if (compression) {
            SendSyncFail(s, android::base::StringPrintf("multiple compression flags received: %d",
                                                        orig_flags));
            return false;
        }
        compression = CompressionType::Zstd;
    }

    if (msg.recv_v2_setup.flags) {
        SendSyncFail(s, android::base::StringPrintf("unknown flags: %d", msg.recv_v2_setup.flags));
Loading