Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e40526cb authored by Daniel Borkmann, committed by David S. Miller
Browse files

packet: fix use after free race in send path when dev is released



Salam reported a use after free bug in PF_PACKET that occurs when
we're sending out frames on a socket bound device and suddenly the
net device is being unregistered. It appears that commit 827d9780
introduced a possible race condition between {t,}packet_snd() and
packet_notifier(). In the case of a bound socket, packet_notifier()
can drop the last reference to the net_device and {t,}packet_snd()
might end up suddenly sending a packet over a freed net_device.

To avoid reverting 827d9780 and thus introducing a performance
regression compared to the current state of things, we decided to
hold a cached RCU protected pointer to the net device and maintain
it on write side via bind spin_lock protected register_prot_hook()
and __unregister_prot_hook() calls.

In {t,}packet_snd() path, we access this pointer under rcu_read_lock
through packet_cached_dev_get() that holds reference to the device
to prevent it from being freed through packet_notifier() while
we're in send path. This is okay to do as dev_put()/dev_hold() are
per-cpu counters, so this should not be a performance issue. Also,
the code simplifies a bit as we don't need need_rls_dev anymore.

Fixes: 827d9780 ("af-packet: Use existing netdev reference for bound sockets.")
Reported-by: Salam Noureddine <noureddine@aristanetworks.com>
Signed-off-by: Daniel Borkmann <dborkman@redhat.com>
Signed-off-by: Salam Noureddine <noureddine@aristanetworks.com>
Cc: Ben Greear <greearb@candelatech.com>
Cc: Eric Dumazet <eric.dumazet@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
parent db739ef3
Loading
Loading
Loading
Loading
+36 −23
Original line number Original line Diff line number Diff line
@@ -244,11 +244,15 @@ static void __fanout_link(struct sock *sk, struct packet_sock *po);
static void register_prot_hook(struct sock *sk)
static void register_prot_hook(struct sock *sk)
{
{
	struct packet_sock *po = pkt_sk(sk);
	struct packet_sock *po = pkt_sk(sk);

	if (!po->running) {
	if (!po->running) {
		if (po->fanout)
		if (po->fanout) {
			__fanout_link(sk, po);
			__fanout_link(sk, po);
		else
		} else {
			dev_add_pack(&po->prot_hook);
			dev_add_pack(&po->prot_hook);
			rcu_assign_pointer(po->cached_dev, po->prot_hook.dev);
		}

		sock_hold(sk);
		sock_hold(sk);
		po->running = 1;
		po->running = 1;
	}
	}
@@ -266,10 +270,13 @@ static void __unregister_prot_hook(struct sock *sk, bool sync)
	struct packet_sock *po = pkt_sk(sk);
	struct packet_sock *po = pkt_sk(sk);


	po->running = 0;
	po->running = 0;
	if (po->fanout)
	if (po->fanout) {
		__fanout_unlink(sk, po);
		__fanout_unlink(sk, po);
	else
	} else {
		__dev_remove_pack(&po->prot_hook);
		__dev_remove_pack(&po->prot_hook);
		RCU_INIT_POINTER(po->cached_dev, NULL);
	}

	__sock_put(sk);
	__sock_put(sk);


	if (sync) {
	if (sync) {
@@ -2052,12 +2059,24 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
	return tp_len;
	return tp_len;
}
}


static struct net_device *packet_cached_dev_get(struct packet_sock *po)
{
	struct net_device *dev;

	rcu_read_lock();
	dev = rcu_dereference(po->cached_dev);
	if (dev)
		dev_hold(dev);
	rcu_read_unlock();

	return dev;
}

static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
{
{
	struct sk_buff *skb;
	struct sk_buff *skb;
	struct net_device *dev;
	struct net_device *dev;
	__be16 proto;
	__be16 proto;
	bool need_rls_dev = false;
	int err, reserve = 0;
	int err, reserve = 0;
	void *ph;
	void *ph;
	struct sockaddr_ll *saddr = (struct sockaddr_ll *)msg->msg_name;
	struct sockaddr_ll *saddr = (struct sockaddr_ll *)msg->msg_name;
@@ -2070,7 +2089,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
	mutex_lock(&po->pg_vec_lock);
	mutex_lock(&po->pg_vec_lock);


	if (saddr == NULL) {
	if (saddr == NULL) {
		dev = po->prot_hook.dev;
		dev	= packet_cached_dev_get(po);
		proto	= po->num;
		proto	= po->num;
		addr	= NULL;
		addr	= NULL;
	} else {
	} else {
@@ -2084,19 +2103,17 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
		proto	= saddr->sll_protocol;
		proto	= saddr->sll_protocol;
		addr	= saddr->sll_addr;
		addr	= saddr->sll_addr;
		dev = dev_get_by_index(sock_net(&po->sk), saddr->sll_ifindex);
		dev = dev_get_by_index(sock_net(&po->sk), saddr->sll_ifindex);
		need_rls_dev = true;
	}
	}


	err = -ENXIO;
	err = -ENXIO;
	if (unlikely(dev == NULL))
	if (unlikely(dev == NULL))
		goto out;
		goto out;

	reserve = dev->hard_header_len;

	err = -ENETDOWN;
	err = -ENETDOWN;
	if (unlikely(!(dev->flags & IFF_UP)))
	if (unlikely(!(dev->flags & IFF_UP)))
		goto out_put;
		goto out_put;


	reserve = dev->hard_header_len;

	size_max = po->tx_ring.frame_size
	size_max = po->tx_ring.frame_size
		- (po->tp_hdrlen - sizeof(struct sockaddr_ll));
		- (po->tp_hdrlen - sizeof(struct sockaddr_ll));


@@ -2173,7 +2190,6 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
	__packet_set_status(po, ph, status);
	__packet_set_status(po, ph, status);
	kfree_skb(skb);
	kfree_skb(skb);
out_put:
out_put:
	if (need_rls_dev)
	dev_put(dev);
	dev_put(dev);
out:
out:
	mutex_unlock(&po->pg_vec_lock);
	mutex_unlock(&po->pg_vec_lock);
@@ -2212,7 +2228,6 @@ static int packet_snd(struct socket *sock,
	struct sk_buff *skb;
	struct sk_buff *skb;
	struct net_device *dev;
	struct net_device *dev;
	__be16 proto;
	__be16 proto;
	bool need_rls_dev = false;
	unsigned char *addr;
	unsigned char *addr;
	int err, reserve = 0;
	int err, reserve = 0;
	struct virtio_net_hdr vnet_hdr = { 0 };
	struct virtio_net_hdr vnet_hdr = { 0 };
@@ -2228,7 +2243,7 @@ static int packet_snd(struct socket *sock,
	 */
	 */


	if (saddr == NULL) {
	if (saddr == NULL) {
		dev = po->prot_hook.dev;
		dev	= packet_cached_dev_get(po);
		proto	= po->num;
		proto	= po->num;
		addr	= NULL;
		addr	= NULL;
	} else {
	} else {
@@ -2240,19 +2255,17 @@ static int packet_snd(struct socket *sock,
		proto	= saddr->sll_protocol;
		proto	= saddr->sll_protocol;
		addr	= saddr->sll_addr;
		addr	= saddr->sll_addr;
		dev = dev_get_by_index(sock_net(sk), saddr->sll_ifindex);
		dev = dev_get_by_index(sock_net(sk), saddr->sll_ifindex);
		need_rls_dev = true;
	}
	}


	err = -ENXIO;
	err = -ENXIO;
	if (dev == NULL)
	if (unlikely(dev == NULL))
		goto out_unlock;
		goto out_unlock;
	if (sock->type == SOCK_RAW)
		reserve = dev->hard_header_len;

	err = -ENETDOWN;
	err = -ENETDOWN;
	if (!(dev->flags & IFF_UP))
	if (unlikely(!(dev->flags & IFF_UP)))
		goto out_unlock;
		goto out_unlock;


	if (sock->type == SOCK_RAW)
		reserve = dev->hard_header_len;
	if (po->has_vnet_hdr) {
	if (po->has_vnet_hdr) {
		vnet_hdr_len = sizeof(vnet_hdr);
		vnet_hdr_len = sizeof(vnet_hdr);


@@ -2386,7 +2399,6 @@ static int packet_snd(struct socket *sock,
	if (err > 0 && (err = net_xmit_errno(err)) != 0)
	if (err > 0 && (err = net_xmit_errno(err)) != 0)
		goto out_unlock;
		goto out_unlock;


	if (need_rls_dev)
	dev_put(dev);
	dev_put(dev);


	return len;
	return len;
@@ -2394,7 +2406,7 @@ static int packet_snd(struct socket *sock,
out_free:
out_free:
	kfree_skb(skb);
	kfree_skb(skb);
out_unlock:
out_unlock:
	if (dev && need_rls_dev)
	if (dev)
		dev_put(dev);
		dev_put(dev);
out:
out:
	return err;
	return err;
@@ -2614,6 +2626,7 @@ static int packet_create(struct net *net, struct socket *sock, int protocol,
	po = pkt_sk(sk);
	po = pkt_sk(sk);
	sk->sk_family = PF_PACKET;
	sk->sk_family = PF_PACKET;
	po->num = proto;
	po->num = proto;
	RCU_INIT_POINTER(po->cached_dev, NULL);


	sk->sk_destruct = packet_sock_destruct;
	sk->sk_destruct = packet_sock_destruct;
	sk_refcnt_debug_inc(sk);
	sk_refcnt_debug_inc(sk);
+1 −0
Original line number Original line Diff line number Diff line
@@ -113,6 +113,7 @@ struct packet_sock {
	unsigned int		tp_loss:1;
	unsigned int		tp_loss:1;
	unsigned int		tp_tx_has_off:1;
	unsigned int		tp_tx_has_off:1;
	unsigned int		tp_tstamp;
	unsigned int		tp_tstamp;
	struct net_device __rcu	*cached_dev;
	struct packet_type	prot_hook ____cacheline_aligned_in_smp;
	struct packet_type	prot_hook ____cacheline_aligned_in_smp;
};
};