Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 2424b716 authored by Tomasz Wasilczyk's avatar Tomasz Wasilczyk
Browse files

Implement socket receive iterator

Bug: 162032964
Test: watch logcat for ifautocf and l2repeater
Test: adb shell canhalctrl up test virtual vcan2
Change-Id: Icbae3951113391846cfcf9a6747ed565bdaa7dd7
parent 34d22ce5
Loading
Loading
Loading
Loading
+16 −20
Original line number Diff line number Diff line
@@ -165,11 +165,7 @@ void waitFor(std::set<std::string> ifnames, WaitCondition cnd, bool allOf) {

    LOG(DEBUG) << "Waiting for " << (allOf ? "" : "any of ") << toString(ifnames) << " to "
               << toString(cnd);
    while (true) {
        const auto msgBuf = sock.receive();
        CHECK(msgBuf.has_value()) << "Can't read Netlink socket";

        for (const auto rawMsg : *msgBuf) {
    for (const auto rawMsg : sock) {
        const auto msg = nl::Message<ifinfomsg>::parse(rawMsg, {RTM_NEWLINK, RTM_DELLINK});
        if (!msg.has_value()) continue;

@@ -181,12 +177,12 @@ void waitFor(std::set<std::string> ifnames, WaitCondition cnd, bool allOf) {
        states[ifname] = {present, up};

        if (isFullySatisfied()) {
                LOG(DEBUG) << "Finished waiting for " << (allOf ? "" : "some of ")
                           << toString(ifnames) << " to " << toString(cnd);
            LOG(DEBUG) << "Finished waiting for " << (allOf ? "" : "some of ") << toString(ifnames)
                       << " to " << toString(cnd);
            return;
        }
    }
    }
    LOG(FATAL) << "Can't read Netlink socket";
}

}  // namespace android::netdevice
+63 −17
Original line number Diff line number Diff line
@@ -68,6 +68,16 @@ bool Socket::send(const Buffer<nlmsghdr>& msg, const sockaddr_nl& sa) {
    return true;
}

bool Socket::increaseReceiveBuffer(size_t maxSize) {
    if (maxSize == 0) {
        LOG(ERROR) << "Maximum receive size should not be zero";
        return false;
    }

    if (mReceiveBuffer.size() < maxSize) mReceiveBuffer.resize(maxSize);
    return true;
}

std::optional<Buffer<nlmsghdr>> Socket::receive(size_t maxSize) {
    return receiveFrom(maxSize).first;
}
@@ -75,11 +85,7 @@ std::optional<Buffer<nlmsghdr>> Socket::receive(size_t maxSize) {
std::pair<std::optional<Buffer<nlmsghdr>>, sockaddr_nl> Socket::receiveFrom(size_t maxSize) {
    if (mFailed) return {std::nullopt, {}};

    if (maxSize == 0) {
        LOG(ERROR) << "Maximum receive size should not be zero";
        return {std::nullopt, {}};
    }
    if (mReceiveBuffer.size() < maxSize) mReceiveBuffer.resize(maxSize);
    if (!increaseReceiveBuffer(maxSize)) return {std::nullopt, {}};

    sockaddr_nl sa = {};
    socklen_t saLen = sizeof(sa);
@@ -120,11 +126,9 @@ bool Socket::receiveAck(uint32_t seq) {

std::optional<Buffer<nlmsghdr>> Socket::receive(const std::set<nlmsgtype_t>& msgtypes,
                                                size_t maxSize) {
    while (!mFailed) {
        const auto msgBuf = receive(maxSize);
        if (!msgBuf.has_value()) return std::nullopt;
    if (mFailed || !increaseReceiveBuffer(maxSize)) return std::nullopt;

        for (const auto rawMsg : *msgBuf) {
    for (const auto rawMsg : *this) {
        if (msgtypes.count(rawMsg->nlmsg_type) == 0) {
            LOG(WARNING) << "Received (and ignored) unexpected Netlink message of type "
                         << rawMsg->nlmsg_type;
@@ -133,7 +137,6 @@ std::optional<Buffer<nlmsghdr>> Socket::receive(const std::set<nlmsgtype_t>& msg

        return rawMsg;
    }
    }

    return std::nullopt;
}
@@ -150,4 +153,47 @@ std::optional<unsigned> Socket::getPid() {
    return sa.nl_pid;
}

Socket::receive_iterator::receive_iterator(Socket& socket, bool end)
    : mSocket(socket), mIsEnd(end) {
    if (!end) receive();
}

Socket::receive_iterator Socket::receive_iterator::operator++() {
    CHECK(!mIsEnd) << "Trying to increment end iterator";
    ++mCurrent;
    if (mCurrent.isEnd()) receive();
    return *this;
}

bool Socket::receive_iterator::operator==(const receive_iterator& other) const {
    if (mIsEnd != other.mIsEnd) return false;
    if (mIsEnd && other.mIsEnd) return true;
    return mCurrent == other.mCurrent;
}

const Buffer<nlmsghdr>& Socket::receive_iterator::operator*() const {
    CHECK(!mIsEnd) << "Trying to dereference end iterator";
    return *mCurrent;
}

void Socket::receive_iterator::receive() {
    CHECK(!mIsEnd) << "Trying to receive on end iterator";
    CHECK(mCurrent.isEnd()) << "Trying to receive without draining previous read";

    const auto buf = mSocket.receive();
    if (buf.has_value()) {
        mCurrent = buf->begin();
    } else {
        mIsEnd = true;
    }
}

Socket::receive_iterator Socket::begin() {
    return {*this, false};
}

Socket::receive_iterator Socket::end() {
    return {*this, true};
}

}  // namespace android::nl
+4 −2
Original line number Diff line number Diff line
@@ -94,7 +94,7 @@ class Buffer {
    class iterator {
      public:
        iterator() : mCurrent(nullptr, size_t(0)) {
            CHECK(!mCurrent.ok()) << "end() iterator should indicate it's beyond end";
            CHECK(isEnd()) << "end() iterator should indicate it's beyond end";
        }
        iterator(const Buffer<T>& buf) : mCurrent(buf) {}

@@ -108,13 +108,15 @@ class Buffer {

        bool operator==(const iterator& other) const {
            // all iterators beyond end are the same
            if (!mCurrent.ok() && !other.mCurrent.ok()) return true;
            if (isEnd() && other.isEnd()) return true;

            return uintptr_t(other.mCurrent.mData) == uintptr_t(mCurrent.mData);
        }

        const Buffer<T>& operator*() const { return mCurrent; }

        bool isEnd() const { return !mCurrent.ok(); }

      protected:
        Buffer<T> mCurrent;
    };
+37 −0
Original line number Diff line number Diff line
@@ -173,6 +173,42 @@ class Socket {
     */
    std::optional<unsigned> getPid();

    /**
     * Live iterator continuously receiving messages from Netlink socket.
     *
     * Iteration ends when socket fails to receive a buffer.
     *
     * Example:
     * ```
     *     nl::Socket sock(NETLINK_ROUTE, 0, RTMGRP_LINK);
     *     for (const auto rawMsg : sock) {
     *         const auto msg = nl::Message<ifinfomsg>::parse(rawMsg, {RTM_NEWLINK, RTM_DELLINK});
     *         if (!msg.has_value()) continue;
     *
     *         LOG(INFO) << msg->attributes.get<std::string>(IFLA_IFNAME)
     *                   << " is " << ((msg->data.ifi_flags & IFF_UP) ? "up" : "down");
     *     }
     *     LOG(FATAL) << "Failed to read from Netlink socket";
     * ```
     */
    class receive_iterator {
      public:
        receive_iterator(Socket& socket, bool end);

        receive_iterator operator++();
        bool operator==(const receive_iterator& other) const;
        const Buffer<nlmsghdr>& operator*() const;

      private:
        Socket& mSocket;
        bool mIsEnd;
        Buffer<nlmsghdr>::iterator mCurrent;

        void receive();
    };
    receive_iterator begin();
    receive_iterator end();

  private:
    const int mProtocol;
    base::unique_fd mFd;
@@ -181,6 +217,7 @@ class Socket {
    bool mFailed = false;
    uint32_t mSeq = 0;

    bool increaseReceiveBuffer(size_t maxSize);
    std::optional<Buffer<nlmsghdr>> receive(const std::set<nlmsgtype_t>& msgtypes, size_t maxSize);

    DISALLOW_COPY_AND_ASSIGN(Socket);
+1 −1
Original line number Diff line number Diff line
@@ -18,7 +18,7 @@

namespace android::nl::protocols {

NetlinkProtocol::NetlinkProtocol(int protocol, const std::string name,
NetlinkProtocol::NetlinkProtocol(int protocol, const std::string& name,
                                 const MessageDescriptorList&& messageDescrs)
    : mProtocol(protocol), mName(name), mMessageDescrs(toMap(messageDescrs, protocol)) {}

Loading