Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 06c2c0de authored by chrisweir's avatar chrisweir
Browse files

Extend libnetdevice for Netlink Proxy

Add some additional generic send/receive functionality for
NetlinkSocket.

Bug: 155190864
Test: Manual
Change-Id: I7a882fa642553c61e0b2b3a32638a309089c6d22
parent 87ef3447
Loading
Loading
Loading
Loading
+78 −15
Original line number Diff line number Diff line
@@ -27,7 +27,8 @@ namespace android::netdevice {
 */
static constexpr bool kSuperVerbose = false;

NetlinkSocket::NetlinkSocket(int protocol) : mProtocol(protocol) {
NetlinkSocket::NetlinkSocket(int protocol, unsigned int pid, uint32_t groups)
    : mProtocol(protocol) {
    mFd.reset(socket(AF_NETLINK, SOCK_RAW, protocol));
    if (!mFd.ok()) {
        PLOG(ERROR) << "Can't open Netlink socket";
@@ -35,21 +36,23 @@ NetlinkSocket::NetlinkSocket(int protocol) : mProtocol(protocol) {
        return;
    }

    struct sockaddr_nl sa = {};
    sockaddr_nl sa = {};
    sa.nl_family = AF_NETLINK;
    sa.nl_pid = pid;
    sa.nl_groups = groups;

    if (bind(mFd.get(), reinterpret_cast<struct sockaddr*>(&sa), sizeof(sa)) < 0) {
    if (bind(mFd.get(), reinterpret_cast<sockaddr*>(&sa), sizeof(sa)) < 0) {
        PLOG(ERROR) << "Can't bind Netlink socket";
        mFd.reset();
        mFailed = true;
    }
}

bool NetlinkSocket::send(struct nlmsghdr* nlmsg, size_t totalLen) {
bool NetlinkSocket::send(nlmsghdr* nlmsg, size_t totalLen) {
    if constexpr (kSuperVerbose) {
        nlmsg->nlmsg_seq = mSeq;
        LOG(VERBOSE) << (mFailed ? "(not) " : "")
                     << "sending Netlink message: " << toString(nlmsg, totalLen, mProtocol);
                     << "sending Netlink message: " << toString({nlmsg, totalLen}, mProtocol);
    }

    if (mFailed) return false;
@@ -58,12 +61,12 @@ bool NetlinkSocket::send(struct nlmsghdr* nlmsg, size_t totalLen) {
    nlmsg->nlmsg_seq = mSeq++;
    nlmsg->nlmsg_flags |= NLM_F_ACK;

    struct iovec iov = {nlmsg, nlmsg->nlmsg_len};
    iovec iov = {nlmsg, nlmsg->nlmsg_len};

    struct sockaddr_nl sa = {};
    sockaddr_nl sa = {};
    sa.nl_family = AF_NETLINK;

    struct msghdr msg = {};
    msghdr msg = {};
    msg.msg_name = &sa;
    msg.msg_namelen = sizeof(sa);
    msg.msg_iov = &iov;
@@ -76,15 +79,65 @@ bool NetlinkSocket::send(struct nlmsghdr* nlmsg, size_t totalLen) {
    return true;
}

bool NetlinkSocket::send(const nlbuf<nlmsghdr>& msg, const sockaddr_nl& sa) {
    if constexpr (kSuperVerbose) {
        LOG(VERBOSE) << (mFailed ? "(not) " : "")
                     << "sending Netlink message: " << toString(msg, mProtocol);
    }

    if (mFailed) return false;
    const auto rawMsg = msg.getRaw();
    const auto bytesSent = sendto(mFd.get(), rawMsg.ptr(), rawMsg.len(), 0,
                                  reinterpret_cast<const sockaddr*>(&sa), sizeof(sa));
    if (bytesSent < 0) {
        PLOG(ERROR) << "Can't send Netlink message";
        return false;
    }
    return true;
}

std::optional<nlbuf<nlmsghdr>> NetlinkSocket::receive(void* buf, size_t bufLen) {
    sockaddr_nl sa = {};
    return receive(buf, bufLen, sa);
}

std::optional<nlbuf<nlmsghdr>> NetlinkSocket::receive(void* buf, size_t bufLen, sockaddr_nl& sa) {
    if (mFailed) return std::nullopt;

    socklen_t saLen = sizeof(sa);
    if (bufLen == 0) {
        LOG(ERROR) << "Receive buffer has zero size!";
        return std::nullopt;
    }
    const auto bytesReceived =
            recvfrom(mFd.get(), buf, bufLen, MSG_TRUNC, reinterpret_cast<sockaddr*>(&sa), &saLen);
    if (bytesReceived <= 0) {
        PLOG(ERROR) << "Failed to receive Netlink message";
        return std::nullopt;
    } else if (unsigned(bytesReceived) > bufLen) {
        PLOG(ERROR) << "Received data larger than the receive buffer! " << bytesReceived << " > "
                    << bufLen;
        return std::nullopt;
    }

    nlbuf<nlmsghdr> msg(reinterpret_cast<nlmsghdr*>(buf), bytesReceived);
    if constexpr (kSuperVerbose) {
        LOG(VERBOSE) << "received " << toString(msg, mProtocol);
    }
    return msg;
}

/* TODO(161389935): Migrate receiveAck to use nlmsg<> internally. Possibly reuse
 * NetlinkSocket::receive(). */
bool NetlinkSocket::receiveAck() {
    if (mFailed) return false;

    char buf[8192];

    struct sockaddr_nl sa;
    struct iovec iov = {buf, sizeof(buf)};
    sockaddr_nl sa;
    iovec iov = {buf, sizeof(buf)};

    struct msghdr msg = {};
    msghdr msg = {};
    msg.msg_name = &sa;
    msg.msg_namelen = sizeof(sa);
    msg.msg_iov = &iov;
@@ -102,11 +155,11 @@ bool NetlinkSocket::receiveAck() {
        return false;
    }

    for (auto nlmsg = reinterpret_cast<struct nlmsghdr*>(buf); NLMSG_OK(nlmsg, remainingLen);
    for (auto nlmsg = reinterpret_cast<nlmsghdr*>(buf); NLMSG_OK(nlmsg, remainingLen);
         nlmsg = NLMSG_NEXT(nlmsg, remainingLen)) {
        if constexpr (kSuperVerbose) {
            LOG(VERBOSE) << "received Netlink response: "
                         << toString(nlmsg, sizeof(buf), mProtocol);
                         << toString({nlmsg, nlmsg->nlmsg_len}, mProtocol);
        }

        // We're looking for error/ack message only, ignoring others.
@@ -116,7 +169,7 @@ bool NetlinkSocket::receiveAck() {
        }

        // Found error/ack message, return status.
        auto nlerr = reinterpret_cast<struct nlmsgerr*>(NLMSG_DATA(nlmsg));
        const auto nlerr = reinterpret_cast<nlmsgerr*>(NLMSG_DATA(nlmsg));
        if (nlerr->error != 0) {
            LOG(ERROR) << "Received Netlink error message: " << nlerr->error;
            return false;
@@ -127,4 +180,14 @@ bool NetlinkSocket::receiveAck() {
    return false;
}

std::optional<unsigned int> NetlinkSocket::getSocketPid() {
    sockaddr_nl sa = {};
    socklen_t sasize = sizeof(sa);
    if (getsockname(mFd.get(), reinterpret_cast<sockaddr*>(&sa), &sasize) < 0) {
        PLOG(ERROR) << "Failed to getsockname() for netlink_fd!";
        return std::nullopt;
    }
    return sa.nl_pid;
}

}  // namespace android::netdevice
+53 −4
Original line number Diff line number Diff line
@@ -19,9 +19,12 @@
#include <android-base/macros.h>
#include <android-base/unique_fd.h>
#include <libnetdevice/NetlinkRequest.h>
#include <libnetdevice/nlbuf.h>

#include <linux/netlink.h>

#include <optional>

namespace android::netdevice {

/**
@@ -31,12 +34,23 @@ namespace android::netdevice {
 * use multiple instances over multiple threads.
 */
struct NetlinkSocket {
    NetlinkSocket(int protocol);
    /**
     * NetlinkSocket constructor.
     *
     * \param protocol the Netlink protocol to use.
     * \param pid port id. Default value of 0 allows the kernel to assign us a unique pid. (NOTE:
     * this is NOT the same as process id!)
     * \param groups Netlink multicast groups to listen to. This is a 32-bit bitfield, where each
     * bit is a different group. Default value of 0 means no groups are selected. See man netlink.7
     * for more details.
     */
    NetlinkSocket(int protocol, unsigned int pid = 0, uint32_t groups = 0);

    /**
     * Send Netlink message to Kernel.
     * Send Netlink message to Kernel. The sequence number will be automatically incremented, and
     * the NLM_F_ACK (request ACK) flag will be set.
     *
     * \param msg Message to send, nlmsg_seq will be set to next sequence number
     * \param msg Message to send.
     * \return true, if succeeded
     */
    template <class T, unsigned int BUFSIZE>
@@ -45,6 +59,34 @@ struct NetlinkSocket {
        return send(req.header(), req.totalLength);
    }

    /**
     * Send Netlink message. The message will be sent as is, without any modification.
     *
     * \param msg Message to send.
     * \param sa Destination address.
     * \return true, if succeeded
     */
    bool send(const nlbuf<nlmsghdr>& msg, const sockaddr_nl& sa);

    /**
     * Receive Netlink data.
     *
     * \param buf buffer to hold message data.
     * \param bufLen length of buf.
     * \return nlbuf with message data, std::nullopt on error.
     */
    std::optional<nlbuf<nlmsghdr>> receive(void* buf, size_t bufLen);

    /**
     * Receive Netlink data with address info.
     *
     * \param buf buffer to hold message data.
     * \param bufLen length of buf.
     * \param sa Blank struct that recvfrom will populate with address info.
     * \return nlbuf with message data, std::nullopt on error.
     */
    std::optional<nlbuf<nlmsghdr>> receive(void* buf, size_t bufLen, sockaddr_nl& sa);

    /**
     * Receive Netlink ACK message from Kernel.
     *
@@ -52,6 +94,13 @@ struct NetlinkSocket {
     */
    bool receiveAck();

    /**
     * Gets the PID assigned to mFd.
     *
     * \return pid that mSocket is bound to.
     */
    std::optional<unsigned int> getSocketPid();

  private:
    const int mProtocol;

@@ -59,7 +108,7 @@ struct NetlinkSocket {
    base::unique_fd mFd;
    bool mFailed = false;

    bool send(struct nlmsghdr* msg, size_t totalLen);
    bool send(nlmsghdr* msg, size_t totalLen);

    DISALLOW_COPY_AND_ASSIGN(NetlinkSocket);
};
+6 −0
Original line number Diff line number Diff line
@@ -53,6 +53,12 @@ class nlbuf {
    static constexpr size_t hdrlen = align(sizeof(T));

  public:
    /**
     * Constructor for nlbuf.
     *
     * \param data A pointer to the data the nlbuf wraps.
     * \param bufferLen Length of buffer.
     */
    nlbuf(const T* data, size_t bufferLen) : mData(data), mBufferEnd(pointerAdd(data, bufferLen)) {}

    const T* operator->() const {
+11 −1
Original line number Diff line number Diff line
@@ -16,12 +16,22 @@

#pragma once

#include <libnetdevice/nlbuf.h>

#include <linux/netlink.h>

#include <string>

namespace android::netdevice {

std::string toString(const nlmsghdr* hdr, size_t bufLen, int protocol);
/**
 * Stringify a Netlink message.
 *
 * \param hdr Pointer to the message(s) to print.
 * \param protocol Which Netlink protocol hdr uses.
 * \param printPayload True will stringify message data, false will only stringify the header(s).
 * \return Stringified message.
 */
std::string toString(const nlbuf<nlmsghdr> hdr, int protocol, bool printPayload = false);

}  // namespace android::netdevice
+12 −6
Original line number Diff line number Diff line
@@ -62,13 +62,20 @@ static void flagsToStream(std::stringstream& ss, __u16 nlmsg_flags) {
}

static void toStream(std::stringstream& ss, const nlbuf<uint8_t> data) {
    const auto rawData = data.getRaw();
    const auto dataLen = rawData.len();
    ss << std::hex;
    if (dataLen > 16) ss << std::endl << " 0000 ";
    int i = 0;
    for (const auto byte : data.getRaw()) {
    for (const auto byte : rawData) {
        if (i++ > 0) ss << ' ';
        ss << std::setw(2) << unsigned(byte);
        if (i % 16 == 0) {
            ss << std::endl << ' ' << std::dec << std::setw(4) << i << std::hex;
        }
    }
    ss << std::dec;
    if (dataLen > 16) ss << std::endl;
}

static void toStream(std::stringstream& ss, const nlbuf<nlattr> attr,
@@ -105,7 +112,7 @@ static void toStream(std::stringstream& ss, const nlbuf<nlattr> attr,
    }
}

static std::string toString(const nlbuf<nlmsghdr> hdr, int protocol) {
std::string toString(const nlbuf<nlmsghdr> hdr, int protocol, bool printPayload) {
    if (!hdr.firstOk()) return "nlmsg{buffer overflow}";

    std::stringstream ss;
@@ -133,10 +140,13 @@ static std::string toString(const nlbuf<nlmsghdr> hdr, int protocol) {
    }
    if (hdr->nlmsg_seq != 0) ss << ", seq=" << hdr->nlmsg_seq;
    if (hdr->nlmsg_pid != 0) ss << ", pid=" << hdr->nlmsg_pid;
    ss << ", len=" << hdr->nlmsg_len;

    ss << ", crc=" << std::hex << std::setw(4) << crc16(hdr.data<uint8_t>()) << std::dec;
    ss << "} ";

    if (!printPayload) return ss.str();

    if (!msgDescMaybe.has_value()) {
        toStream(ss, hdr.data<uint8_t>());
    } else {
@@ -161,8 +171,4 @@ static std::string toString(const nlbuf<nlmsghdr> hdr, int protocol) {
    return ss.str();
}

std::string toString(const nlmsghdr* hdr, size_t bufLen, int protocol) {
    return toString({hdr, bufLen}, protocol);
}

}  // namespace android::netdevice
Loading