Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 7374d1d8 authored by Linux Build Service Account's avatar Linux Build Service Account Committed by Gerrit - the friendly Code Review server
Browse files

Merge "udp6: fix socket leak on early demux"

parents da57c4f0 04cb69df
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -247,6 +247,7 @@ static inline __be16 udp_flow_src_port(struct net *net, struct sk_buff *skb,

/* net/ipv4/udp.c */
void udp_v4_early_demux(struct sk_buff *skb);
void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst);
int udp_get_port(struct sock *sk, unsigned short snum,
		 int (*saddr_cmp)(const struct sock *,
				  const struct sock *));
+2 −1
Original line number Diff line number Diff line
@@ -1627,7 +1627,7 @@ int udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
/* For TCP sockets, sk_rx_dst is protected by socket lock
 * For UDP, we use xchg() to guard against concurrent changes.
 */
static void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst)
void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst)
{
	struct dst_entry *old;

@@ -1635,6 +1635,7 @@ static void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst)
	old = xchg(&sk->sk_rx_dst, dst);
	dst_release(old);
}
EXPORT_SYMBOL(udp_sk_rx_dst_set);

/*
 *	Multicasts and broadcasts go to each listener.
+7 −1
Original line number Diff line number Diff line
@@ -250,8 +250,14 @@ int __ip6_datagram_connect(struct sock *sk, struct sockaddr *uaddr,
	 */

	err = ip6_datagram_dst_update(sk, true);
	if (err)
	if (err) {
		/* Reset daddr and dport so that udp_v6_early_demux()
		 * fails to find this socket
		 */
		memset(&sk->sk_v6_daddr, 0, sizeof(sk->sk_v6_daddr));
		inet->inet_dport = 0;
		goto out;
	}

	sk->sk_state = TCP_ESTABLISHED;
	sk_set_txhash(sk);
+86 −9
Original line number Diff line number Diff line
@@ -46,6 +46,7 @@
#include <net/tcp_states.h>
#include <net/ip6_checksum.h>
#include <net/xfrm.h>
#include <net/inet_hashtables.h>
#include <net/inet6_hashtables.h>
#include <net/busy_poll.h>
#include <net/sock_reuseport.h>
@@ -277,11 +278,7 @@ static struct sock *__udp6_lib_lookup_skb(struct sk_buff *skb,
					  struct udp_table *udptable)
{
	const struct ipv6hdr *iph = ipv6_hdr(skb);
	struct sock *sk;

	sk = skb_steal_sock(skb);
	if (unlikely(sk))
		return sk;
	return __udp6_lib_lookup(dev_net(skb->dev), &iph->saddr, sport,
				 &iph->daddr, dport, inet6_iif(skb),
				 udptable, skb);
@@ -799,6 +796,24 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
	if (udp6_csum_init(skb, uh, proto))
		goto csum_error;

	/* Check if the socket is already available, e.g. due to early demux */
	sk = skb_steal_sock(skb);
	if (sk) {
		struct dst_entry *dst = skb_dst(skb);
		int ret;

		if (unlikely(sk->sk_rx_dst != dst))
			udp_sk_rx_dst_set(sk, dst);

		ret = udpv6_queue_rcv_skb(sk, skb);
		sock_put(sk);

		/* a return value > 0 means to resubmit the input */
		if (ret > 0)
			return ret;
		return 0;
	}

	/*
	 *	Multicast receive code
	 */
@@ -807,11 +822,6 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
				saddr, daddr, udptable, proto);

	/* Unicast */

	/*
	 * check socket cache ... must talk to Alan about his plans
	 * for sock caches... i'll skip this for now.
	 */
	sk = __udp6_lib_lookup_skb(skb, uh->source, uh->dest, udptable);
	if (sk) {
		int ret;
@@ -866,6 +876,72 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
	return 0;
}

static struct sock *__udp6_lib_demux_lookup(struct net *net,
					    __be16 loc_port,
					    const struct in6_addr *loc_addr,
					    __be16 rmt_port,
					    const struct in6_addr *rmt_addr,
					    int dif)
{
	unsigned short hnum = ntohs(loc_port);
	unsigned int hash2 = udp6_portaddr_hash(net, loc_addr, hnum);
	unsigned int slot2 = hash2 & udp_table.mask;
	struct udp_hslot *hslot2 = &udp_table.hash2[slot2];

	const __portpair ports = INET_COMBINED_PORTS(rmt_port, hnum);
	struct sock *sk;

	udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
		if (sk->sk_state == TCP_ESTABLISHED &&
		    INET6_MATCH(sk, net, rmt_addr, loc_addr, ports, dif))
			return sk;
		/* Only check first socket in chain */
		break;
	}
	return NULL;
}

static void udp_v6_early_demux(struct sk_buff *skb)
{
	struct net *net = dev_net(skb->dev);
	const struct udphdr *uh;
	struct sock *sk;
	struct dst_entry *dst;
	int dif = skb->dev->ifindex;

	if (!pskb_may_pull(skb, skb_transport_offset(skb) +
	    sizeof(struct udphdr)))
		return;

	uh = udp_hdr(skb);

	if (skb->pkt_type == PACKET_HOST)
		sk = __udp6_lib_demux_lookup(net, uh->dest,
					     &ipv6_hdr(skb)->daddr,
					     uh->source, &ipv6_hdr(skb)->saddr,
					     dif);
	else
		return;

	if (!sk || !atomic_inc_not_zero_hint(&sk->sk_refcnt, 2))
		return;

	skb->sk = sk;
	skb->destructor = sock_efree;
	dst = READ_ONCE(sk->sk_rx_dst);

	if (dst)
		dst = dst_check(dst, inet6_sk(sk)->rx_dst_cookie);
	if (dst) {
		if (dst->flags & DST_NOCACHE) {
			if (likely(atomic_inc_not_zero(&dst->__refcnt)))
				skb_dst_set(skb, dst);
		} else {
			skb_dst_set_noref(skb, dst);
		}
	}
}

static __inline__ int udpv6_rcv(struct sk_buff *skb)
{
	return __udp6_lib_rcv(skb, &udp_table, IPPROTO_UDP);
@@ -1381,6 +1457,7 @@ int compat_udpv6_getsockopt(struct sock *sk, int level, int optname,
#endif

static const struct inet6_protocol udpv6_protocol = {
	.early_demux	=	udp_v6_early_demux,
	.handler	=	udpv6_rcv,
	.err_handler	=	udpv6_err,
	.flags		=	INET6_PROTO_NOPOLICY|INET6_PROTO_FINAL,