Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit b428f77f authored by Tomasz Wasilczyk's avatar Tomasz Wasilczyk
Browse files

Netlink socket refactoring

- merge two send() methods into one
- use internal receive buffer instead of asking user to supply one
- move setting sequence number to MessageFactory sending code
- don't limit send function to Kernel as a recipient
- move adding NLM_F_ACK to the caller side
- getSocketPid -> getPid
- unsigned int -> unsigned

One part missing is refactoring receiveAck (b/161389935).

Bug: 162032964
Test: canhalctrl up test virtual vcan3
Change-Id: Ie3d460dbc2ea1251469bf08504cfe2c6e80bbe75
parent 66fc9390
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -70,7 +70,7 @@ bool setBitrate(std::string ifname, uint32_t bitrate) {
    struct can_bittiming bt = {};
    bt.bitrate = bitrate;

    nl::MessageFactory<struct ifinfomsg> req(RTM_NEWLINK, NLM_F_REQUEST);
    nl::MessageFactory<struct ifinfomsg> req(RTM_NEWLINK, NLM_F_REQUEST | NLM_F_ACK);

    const auto ifidx = nametoindex(ifname);
    if (ifidx == 0) {
+2 −2
Original line number Diff line number Diff line
@@ -63,7 +63,7 @@ bool down(std::string ifname) {

bool add(std::string dev, std::string type) {
    nl::MessageFactory<struct ifinfomsg> req(RTM_NEWLINK,
                                             NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL);
                                             NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK);
    req.addattr(IFLA_IFNAME, dev);

    {
@@ -76,7 +76,7 @@ bool add(std::string dev, std::string type) {
}

bool del(std::string dev) {
    nl::MessageFactory<struct ifinfomsg> req(RTM_DELLINK, NLM_F_REQUEST);
    nl::MessageFactory<struct ifinfomsg> req(RTM_DELLINK, NLM_F_REQUEST | NLM_F_ACK);
    req.addattr(IFLA_IFNAME, dev);

    nl::Socket sock(NETLINK_ROUTE);
+1 −1
Original line number Diff line number Diff line
@@ -34,7 +34,7 @@ bool add(const std::string& eth, const std::string& vlan, uint16_t id) {
    }

    nl::MessageFactory<struct ifinfomsg> req(RTM_NEWLINK,
                                             NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL);
                                             NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK);
    req.addattr(IFLA_IFNAME, vlan);
    req.addattr<uint32_t>(IFLA_LINK, ethidx);

+33 −56
Original line number Diff line number Diff line
@@ -27,7 +27,7 @@ namespace android::nl {
 */
static constexpr bool kSuperVerbose = false;

Socket::Socket(int protocol, unsigned int pid, uint32_t groups) : mProtocol(protocol) {
Socket::Socket(int protocol, unsigned pid, uint32_t groups) : mProtocol(protocol) {
    mFd.reset(socket(AF_NETLINK, SOCK_RAW, protocol));
    if (!mFd.ok()) {
        PLOG(ERROR) << "Can't open Netlink socket";
@@ -47,83 +47,60 @@ Socket::Socket(int protocol, unsigned int pid, uint32_t groups) : mProtocol(prot
    }
}

bool Socket::send(nlmsghdr* nlmsg, size_t totalLen) {
    if constexpr (kSuperVerbose) {
        nlmsg->nlmsg_seq = mSeq;
        LOG(VERBOSE) << (mFailed ? "(not) " : "")
                     << "sending Netlink message: " << toString({nlmsg, totalLen}, mProtocol);
    }

    if (mFailed) return false;

    nlmsg->nlmsg_pid = 0;  // kernel
    nlmsg->nlmsg_seq = mSeq++;
    nlmsg->nlmsg_flags |= NLM_F_ACK;

    iovec iov = {nlmsg, nlmsg->nlmsg_len};

    sockaddr_nl sa = {};
    sa.nl_family = AF_NETLINK;

    msghdr msg = {};
    msg.msg_name = &sa;
    msg.msg_namelen = sizeof(sa);
    msg.msg_iov = &iov;
    msg.msg_iovlen = 1;

    if (sendmsg(mFd.get(), &msg, 0) < 0) {
        PLOG(ERROR) << "Can't send Netlink message";
        return false;
    }
    return true;
}

bool Socket::send(const Buffer<nlmsghdr>& msg, const sockaddr_nl& sa) {
    if constexpr (kSuperVerbose) {
        LOG(VERBOSE) << (mFailed ? "(not) " : "")
                     << "sending Netlink message: " << toString(msg, mProtocol);
        LOG(VERBOSE) << (mFailed ? "(not) " : "") << "sending Netlink message ("  //
                     << msg->nlmsg_pid << " -> " << sa.nl_pid << "): " << toString(msg, mProtocol);
    }

    if (mFailed) return false;

    mSeq = msg->nlmsg_seq;
    const auto rawMsg = msg.getRaw();
    const auto bytesSent = sendto(mFd.get(), rawMsg.ptr(), rawMsg.len(), 0,
                                  reinterpret_cast<const sockaddr*>(&sa), sizeof(sa));
    if (bytesSent < 0) {
        PLOG(ERROR) << "Can't send Netlink message";
        return false;
    } else if (size_t(bytesSent) != rawMsg.len()) {
        LOG(ERROR) << "Can't send Netlink message: truncated message";
        return false;
    }
    return true;
}

std::optional<Buffer<nlmsghdr>> Socket::receive(void* buf, size_t bufLen) {
    sockaddr_nl sa = {};
    return receive(buf, bufLen, sa);
std::optional<Buffer<nlmsghdr>> Socket::receive(size_t maxSize) {
    return receiveFrom(maxSize).first;
}

std::optional<Buffer<nlmsghdr>> Socket::receive(void* buf, size_t bufLen, sockaddr_nl& sa) {
    if (mFailed) return std::nullopt;
std::pair<std::optional<Buffer<nlmsghdr>>, sockaddr_nl> Socket::receiveFrom(size_t maxSize) {
    if (mFailed) return {std::nullopt, {}};

    socklen_t saLen = sizeof(sa);
    if (bufLen == 0) {
        LOG(ERROR) << "Receive buffer has zero size!";
        return std::nullopt;
    if (maxSize == 0) {
        LOG(ERROR) << "Maximum receive size should not be zero";
        return {std::nullopt, {}};
    }
    const auto bytesReceived =
            recvfrom(mFd.get(), buf, bufLen, MSG_TRUNC, reinterpret_cast<sockaddr*>(&sa), &saLen);
    if (mReceiveBuffer.size() < maxSize) mReceiveBuffer.resize(maxSize);

    sockaddr_nl sa = {};
    socklen_t saLen = sizeof(sa);
    const auto bytesReceived = recvfrom(mFd.get(), mReceiveBuffer.data(), maxSize, MSG_TRUNC,
                                        reinterpret_cast<sockaddr*>(&sa), &saLen);

    if (bytesReceived <= 0) {
        PLOG(ERROR) << "Failed to receive Netlink message";
        return std::nullopt;
    } else if (unsigned(bytesReceived) > bufLen) {
        PLOG(ERROR) << "Received data larger than the receive buffer! " << bytesReceived << " > "
                    << bufLen;
        return std::nullopt;
        return {std::nullopt, {}};
    } else if (size_t(bytesReceived) > maxSize) {
        PLOG(ERROR) << "Received data larger than maximum receive size: "  //
                    << bytesReceived << " > " << maxSize;
        return {std::nullopt, {}};
    }

    Buffer<nlmsghdr> msg(reinterpret_cast<nlmsghdr*>(buf), bytesReceived);
    Buffer<nlmsghdr> msg(reinterpret_cast<nlmsghdr*>(mReceiveBuffer.data()), bytesReceived);
    if constexpr (kSuperVerbose) {
        LOG(VERBOSE) << "received " << toString(msg, mProtocol);
        LOG(VERBOSE) << "received (" << sa.nl_pid << " -> " << msg->nlmsg_pid << "):"  //
                     << toString(msg, mProtocol);
    }
    return msg;
    return {msg, sa};
}

/* TODO(161389935): Migrate receiveAck to use nlmsg<> internally. Possibly reuse
@@ -179,11 +156,11 @@ bool Socket::receiveAck() {
    return false;
}

std::optional<unsigned int> Socket::getSocketPid() {
std::optional<unsigned> Socket::getPid() {
    sockaddr_nl sa = {};
    socklen_t sasize = sizeof(sa);
    if (getsockname(mFd.get(), reinterpret_cast<sockaddr*>(&sa), &sasize) < 0) {
        PLOG(ERROR) << "Failed to getsockname() for netlink_fd!";
        PLOG(ERROR) << "Failed to get PID of Netlink socket";
        return std::nullopt;
    }
    return sa.nl_pid;
+0 −1
Original line number Diff line number Diff line
@@ -35,7 +35,6 @@ void addattr_nest_end(struct nlmsghdr* n, struct nlattr* nest);

}  // namespace impl

// TODO(twasilczyk): rename to NetlinkMessage
/**
 * Wrapper around NETLINK_ROUTE messages, to build them in C++ style.
 *
Loading