Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit cef964c1 authored by Chris Weir's avatar Chris Weir Committed by Android (Google) Code Review
Browse files

Merge "Extend libnetdevice for Netlink Proxy"

parents 9b0786c0 06c2c0de
Loading
Loading
Loading
Loading
+78 −15
Original line number Diff line number Diff line
@@ -27,7 +27,8 @@ namespace android::netdevice {
 */
static constexpr bool kSuperVerbose = false;

NetlinkSocket::NetlinkSocket(int protocol) : mProtocol(protocol) {
NetlinkSocket::NetlinkSocket(int protocol, unsigned int pid, uint32_t groups)
    : mProtocol(protocol) {
    mFd.reset(socket(AF_NETLINK, SOCK_RAW, protocol));
    if (!mFd.ok()) {
        PLOG(ERROR) << "Can't open Netlink socket";
@@ -35,21 +36,23 @@ NetlinkSocket::NetlinkSocket(int protocol) : mProtocol(protocol) {
        return;
    }

    struct sockaddr_nl sa = {};
    sockaddr_nl sa = {};
    sa.nl_family = AF_NETLINK;
    sa.nl_pid = pid;
    sa.nl_groups = groups;

    if (bind(mFd.get(), reinterpret_cast<struct sockaddr*>(&sa), sizeof(sa)) < 0) {
    if (bind(mFd.get(), reinterpret_cast<sockaddr*>(&sa), sizeof(sa)) < 0) {
        PLOG(ERROR) << "Can't bind Netlink socket";
        mFd.reset();
        mFailed = true;
    }
}

bool NetlinkSocket::send(struct nlmsghdr* nlmsg, size_t totalLen) {
bool NetlinkSocket::send(nlmsghdr* nlmsg, size_t totalLen) {
    if constexpr (kSuperVerbose) {
        nlmsg->nlmsg_seq = mSeq;
        LOG(VERBOSE) << (mFailed ? "(not) " : "")
                     << "sending Netlink message: " << toString(nlmsg, totalLen, mProtocol);
                     << "sending Netlink message: " << toString({nlmsg, totalLen}, mProtocol);
    }

    if (mFailed) return false;
@@ -58,12 +61,12 @@ bool NetlinkSocket::send(struct nlmsghdr* nlmsg, size_t totalLen) {
    nlmsg->nlmsg_seq = mSeq++;
    nlmsg->nlmsg_flags |= NLM_F_ACK;

    struct iovec iov = {nlmsg, nlmsg->nlmsg_len};
    iovec iov = {nlmsg, nlmsg->nlmsg_len};

    struct sockaddr_nl sa = {};
    sockaddr_nl sa = {};
    sa.nl_family = AF_NETLINK;

    struct msghdr msg = {};
    msghdr msg = {};
    msg.msg_name = &sa;
    msg.msg_namelen = sizeof(sa);
    msg.msg_iov = &iov;
@@ -76,15 +79,65 @@ bool NetlinkSocket::send(struct nlmsghdr* nlmsg, size_t totalLen) {
    return true;
}

bool NetlinkSocket::send(const nlbuf<nlmsghdr>& msg, const sockaddr_nl& sa) {
    if constexpr (kSuperVerbose) {
        LOG(VERBOSE) << (mFailed ? "(not) " : "")
                     << "sending Netlink message: " << toString(msg, mProtocol);
    }

    if (mFailed) return false;
    const auto rawMsg = msg.getRaw();
    const auto bytesSent = sendto(mFd.get(), rawMsg.ptr(), rawMsg.len(), 0,
                                  reinterpret_cast<const sockaddr*>(&sa), sizeof(sa));
    if (bytesSent < 0) {
        PLOG(ERROR) << "Can't send Netlink message";
        return false;
    }
    return true;
}

std::optional<nlbuf<nlmsghdr>> NetlinkSocket::receive(void* buf, size_t bufLen) {
    sockaddr_nl sa = {};
    return receive(buf, bufLen, sa);
}

std::optional<nlbuf<nlmsghdr>> NetlinkSocket::receive(void* buf, size_t bufLen, sockaddr_nl& sa) {
    if (mFailed) return std::nullopt;

    socklen_t saLen = sizeof(sa);
    if (bufLen == 0) {
        LOG(ERROR) << "Receive buffer has zero size!";
        return std::nullopt;
    }
    const auto bytesReceived =
            recvfrom(mFd.get(), buf, bufLen, MSG_TRUNC, reinterpret_cast<sockaddr*>(&sa), &saLen);
    if (bytesReceived <= 0) {
        PLOG(ERROR) << "Failed to receive Netlink message";
        return std::nullopt;
    } else if (unsigned(bytesReceived) > bufLen) {
        PLOG(ERROR) << "Received data larger than the receive buffer! " << bytesReceived << " > "
                    << bufLen;
        return std::nullopt;
    }

    nlbuf<nlmsghdr> msg(reinterpret_cast<nlmsghdr*>(buf), bytesReceived);
    if constexpr (kSuperVerbose) {
        LOG(VERBOSE) << "received " << toString(msg, mProtocol);
    }
    return msg;
}

/* TODO(161389935): Migrate receiveAck to use nlmsg<> internally. Possibly reuse
 * NetlinkSocket::receive(). */
bool NetlinkSocket::receiveAck() {
    if (mFailed) return false;

    char buf[8192];

    struct sockaddr_nl sa;
    struct iovec iov = {buf, sizeof(buf)};
    sockaddr_nl sa;
    iovec iov = {buf, sizeof(buf)};

    struct msghdr msg = {};
    msghdr msg = {};
    msg.msg_name = &sa;
    msg.msg_namelen = sizeof(sa);
    msg.msg_iov = &iov;
@@ -102,11 +155,11 @@ bool NetlinkSocket::receiveAck() {
        return false;
    }

    for (auto nlmsg = reinterpret_cast<struct nlmsghdr*>(buf); NLMSG_OK(nlmsg, remainingLen);
    for (auto nlmsg = reinterpret_cast<nlmsghdr*>(buf); NLMSG_OK(nlmsg, remainingLen);
         nlmsg = NLMSG_NEXT(nlmsg, remainingLen)) {
        if constexpr (kSuperVerbose) {
            LOG(VERBOSE) << "received Netlink response: "
                         << toString(nlmsg, sizeof(buf), mProtocol);
                         << toString({nlmsg, nlmsg->nlmsg_len}, mProtocol);
        }

        // We're looking for error/ack message only, ignoring others.
@@ -116,7 +169,7 @@ bool NetlinkSocket::receiveAck() {
        }

        // Found error/ack message, return status.
        auto nlerr = reinterpret_cast<struct nlmsgerr*>(NLMSG_DATA(nlmsg));
        const auto nlerr = reinterpret_cast<nlmsgerr*>(NLMSG_DATA(nlmsg));
        if (nlerr->error != 0) {
            LOG(ERROR) << "Received Netlink error message: " << strerror(-nlerr->error);
            return false;
@@ -127,4 +180,14 @@ bool NetlinkSocket::receiveAck() {
    return false;
}

std::optional<unsigned int> NetlinkSocket::getSocketPid() {
    sockaddr_nl sa = {};
    socklen_t sasize = sizeof(sa);
    if (getsockname(mFd.get(), reinterpret_cast<sockaddr*>(&sa), &sasize) < 0) {
        PLOG(ERROR) << "Failed to getsockname() for netlink_fd!";
        return std::nullopt;
    }
    return sa.nl_pid;
}

}  // namespace android::netdevice
+53 −4
Original line number Diff line number Diff line
@@ -19,9 +19,12 @@
#include <android-base/macros.h>
#include <android-base/unique_fd.h>
#include <libnetdevice/NetlinkRequest.h>
#include <libnetdevice/nlbuf.h>

#include <linux/netlink.h>

#include <optional>

namespace android::netdevice {

/**
@@ -31,12 +34,23 @@ namespace android::netdevice {
 * use multiple instances over multiple threads.
 */
struct NetlinkSocket {
    NetlinkSocket(int protocol);
    /**
     * NetlinkSocket constructor.
     *
     * \param protocol the Netlink protocol to use.
     * \param pid port id. Default value of 0 allows the kernel to assign us a unique pid. (NOTE:
     * this is NOT the same as process id!)
     * \param groups Netlink multicast groups to listen to. This is a 32-bit bitfield, where each
     * bit is a different group. Default value of 0 means no groups are selected. See man netlink.7
     * for more details.
     */
    NetlinkSocket(int protocol, unsigned int pid = 0, uint32_t groups = 0);

    /**
     * Send Netlink message to Kernel.
     * Send Netlink message to Kernel. The sequence number will be automatically incremented, and
     * the NLM_F_ACK (request ACK) flag will be set.
     *
     * \param msg Message to send, nlmsg_seq will be set to next sequence number
     * \param msg Message to send.
     * \return true, if succeeded
     */
    template <class T, unsigned int BUFSIZE>
@@ -45,6 +59,34 @@ struct NetlinkSocket {
        return send(req.header(), req.totalLength);
    }

    /**
     * Send Netlink message. The message will be sent as is, without any modification.
     *
     * \param msg Message to send.
     * \param sa Destination address.
     * \return true, if succeeded
     */
    bool send(const nlbuf<nlmsghdr>& msg, const sockaddr_nl& sa);

    /**
     * Receive Netlink data.
     *
     * \param buf buffer to hold message data.
     * \param bufLen length of buf.
     * \return nlbuf with message data, std::nullopt on error.
     */
    std::optional<nlbuf<nlmsghdr>> receive(void* buf, size_t bufLen);

    /**
     * Receive Netlink data with address info.
     *
     * \param buf buffer to hold message data.
     * \param bufLen length of buf.
     * \param sa Blank struct that recvfrom will populate with address info.
     * \return nlbuf with message data, std::nullopt on error.
     */
    std::optional<nlbuf<nlmsghdr>> receive(void* buf, size_t bufLen, sockaddr_nl& sa);

    /**
     * Receive Netlink ACK message from Kernel.
     *
@@ -52,6 +94,13 @@ struct NetlinkSocket {
     */
    bool receiveAck();

    /**
     * Gets the PID assigned to mFd.
     *
     * \return pid that mSocket is bound to.
     */
    std::optional<unsigned int> getSocketPid();

  private:
    const int mProtocol;

@@ -59,7 +108,7 @@ struct NetlinkSocket {
    base::unique_fd mFd;
    bool mFailed = false;

    bool send(struct nlmsghdr* msg, size_t totalLen);
    bool send(nlmsghdr* msg, size_t totalLen);

    DISALLOW_COPY_AND_ASSIGN(NetlinkSocket);
};
+6 −0
Original line number Diff line number Diff line
@@ -53,6 +53,12 @@ class nlbuf {
    static constexpr size_t hdrlen = align(sizeof(T));

  public:
    /**
     * Constructor for nlbuf.
     *
     * \param data A pointer to the data the nlbuf wraps.
     * \param bufferLen Length of buffer.
     */
    nlbuf(const T* data, size_t bufferLen) : mData(data), mBufferEnd(pointerAdd(data, bufferLen)) {}

    const T* operator->() const {
+11 −1
Original line number Diff line number Diff line
@@ -16,12 +16,22 @@

#pragma once

#include <libnetdevice/nlbuf.h>

#include <linux/netlink.h>

#include <string>

namespace android::netdevice {

std::string toString(const nlmsghdr* hdr, size_t bufLen, int protocol);
/**
 * Stringify a Netlink message.
 *
 * \param hdr Pointer to the message(s) to print.
 * \param protocol Which Netlink protocol hdr uses.
 * \param printPayload True will stringify message data, false will only stringify the header(s).
 * \return Stringified message.
 */
std::string toString(const nlbuf<nlmsghdr> hdr, int protocol, bool printPayload = false);

}  // namespace android::netdevice
+12 −6
Original line number Diff line number Diff line
@@ -62,13 +62,20 @@ static void flagsToStream(std::stringstream& ss, __u16 nlmsg_flags) {
}

static void toStream(std::stringstream& ss, const nlbuf<uint8_t> data) {
    const auto rawData = data.getRaw();
    const auto dataLen = rawData.len();
    ss << std::hex;
    if (dataLen > 16) ss << std::endl << " 0000 ";
    int i = 0;
    for (const auto byte : data.getRaw()) {
    for (const auto byte : rawData) {
        if (i++ > 0) ss << ' ';
        ss << std::setw(2) << unsigned(byte);
        if (i % 16 == 0) {
            ss << std::endl << ' ' << std::dec << std::setw(4) << i << std::hex;
        }
    }
    ss << std::dec;
    if (dataLen > 16) ss << std::endl;
}

static void toStream(std::stringstream& ss, const nlbuf<nlattr> attr,
@@ -105,7 +112,7 @@ static void toStream(std::stringstream& ss, const nlbuf<nlattr> attr,
    }
}

static std::string toString(const nlbuf<nlmsghdr> hdr, int protocol) {
std::string toString(const nlbuf<nlmsghdr> hdr, int protocol, bool printPayload) {
    if (!hdr.firstOk()) return "nlmsg{buffer overflow}";

    std::stringstream ss;
@@ -133,10 +140,13 @@ static std::string toString(const nlbuf<nlmsghdr> hdr, int protocol) {
    }
    if (hdr->nlmsg_seq != 0) ss << ", seq=" << hdr->nlmsg_seq;
    if (hdr->nlmsg_pid != 0) ss << ", pid=" << hdr->nlmsg_pid;
    ss << ", len=" << hdr->nlmsg_len;

    ss << ", crc=" << std::hex << std::setw(4) << crc16(hdr.data<uint8_t>()) << std::dec;
    ss << "} ";

    if (!printPayload) return ss.str();

    if (!msgDescMaybe.has_value()) {
        toStream(ss, hdr.data<uint8_t>());
    } else {
@@ -161,8 +171,4 @@ static std::string toString(const nlbuf<nlmsghdr> hdr, int protocol) {
    return ss.str();
}

std::string toString(const nlmsghdr* hdr, size_t bufLen, int protocol) {
    return toString({hdr, bufLen}, protocol);
}

}  // namespace android::netdevice
Loading