Loading include/net/udp.h +1 −0 Original line number Diff line number Diff line Loading @@ -247,6 +247,7 @@ static inline __be16 udp_flow_src_port(struct net *net, struct sk_buff *skb, /* net/ipv4/udp.c */ void udp_v4_early_demux(struct sk_buff *skb); void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst); int udp_get_port(struct sock *sk, unsigned short snum, int (*saddr_cmp)(const struct sock *, const struct sock *)); Loading net/ipv4/udp.c +2 −1 Original line number Diff line number Diff line Loading @@ -1627,7 +1627,7 @@ int udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) /* For TCP sockets, sk_rx_dst is protected by socket lock * For UDP, we use xchg() to guard against concurrent changes. */ static void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst) void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst) { struct dst_entry *old; Loading @@ -1635,6 +1635,7 @@ static void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst) old = xchg(&sk->sk_rx_dst, dst); dst_release(old); } EXPORT_SYMBOL(udp_sk_rx_dst_set); /* * Multicasts and broadcasts go to each listener. Loading net/ipv6/datagram.c +7 −1 Original line number Diff line number Diff line Loading @@ -250,8 +250,14 @@ int __ip6_datagram_connect(struct sock *sk, struct sockaddr *uaddr, */ err = ip6_datagram_dst_update(sk, true); if (err) if (err) { /* Reset daddr and dport so that udp_v6_early_demux() * fails to find this socket */ memset(&sk->sk_v6_daddr, 0, sizeof(sk->sk_v6_daddr)); inet->inet_dport = 0; goto out; } sk->sk_state = TCP_ESTABLISHED; sk_set_txhash(sk); Loading net/ipv6/udp.c +86 −9 Original line number Diff line number Diff line Loading @@ -46,6 +46,7 @@ #include <net/tcp_states.h> #include <net/ip6_checksum.h> #include <net/xfrm.h> #include <net/inet_hashtables.h> #include <net/inet6_hashtables.h> #include <net/busy_poll.h> #include <net/sock_reuseport.h> Loading Loading @@ -277,11 +278,7 @@ static struct sock *__udp6_lib_lookup_skb(struct sk_buff *skb, struct udp_table *udptable) { const struct ipv6hdr *iph = ipv6_hdr(skb); struct sock *sk; sk = skb_steal_sock(skb); if (unlikely(sk)) return sk; return __udp6_lib_lookup(dev_net(skb->dev), &iph->saddr, sport, &iph->daddr, dport, inet6_iif(skb), udptable, skb); Loading Loading @@ -799,6 +796,24 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, if (udp6_csum_init(skb, uh, proto)) goto csum_error; /* Check if the socket is already available, e.g. due to early demux */ sk = skb_steal_sock(skb); if (sk) { struct dst_entry *dst = skb_dst(skb); int ret; if (unlikely(sk->sk_rx_dst != dst)) udp_sk_rx_dst_set(sk, dst); ret = udpv6_queue_rcv_skb(sk, skb); sock_put(sk); /* a return value > 0 means to resubmit the input */ if (ret > 0) return ret; return 0; } /* * Multicast receive code */ Loading @@ -807,11 +822,6 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, saddr, daddr, udptable, proto); /* Unicast */ /* * check socket cache ... must talk to Alan about his plans * for sock caches... i'll skip this for now. */ sk = __udp6_lib_lookup_skb(skb, uh->source, uh->dest, udptable); if (sk) { int ret; Loading Loading @@ -866,6 +876,72 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, return 0; } static struct sock *__udp6_lib_demux_lookup(struct net *net, __be16 loc_port, const struct in6_addr *loc_addr, __be16 rmt_port, const struct in6_addr *rmt_addr, int dif) { unsigned short hnum = ntohs(loc_port); unsigned int hash2 = udp6_portaddr_hash(net, loc_addr, hnum); unsigned int slot2 = hash2 & udp_table.mask; struct udp_hslot *hslot2 = &udp_table.hash2[slot2]; const __portpair ports = INET_COMBINED_PORTS(rmt_port, hnum); struct sock *sk; udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) { if (sk->sk_state == TCP_ESTABLISHED && INET6_MATCH(sk, net, rmt_addr, loc_addr, ports, dif)) return sk; /* Only check first socket in chain */ break; } return NULL; } static void udp_v6_early_demux(struct sk_buff *skb) { struct net *net = dev_net(skb->dev); const struct udphdr *uh; struct sock *sk; struct dst_entry *dst; int dif = skb->dev->ifindex; if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct udphdr))) return; uh = udp_hdr(skb); if (skb->pkt_type == PACKET_HOST) sk = __udp6_lib_demux_lookup(net, uh->dest, &ipv6_hdr(skb)->daddr, uh->source, &ipv6_hdr(skb)->saddr, dif); else return; if (!sk || !atomic_inc_not_zero_hint(&sk->sk_refcnt, 2)) return; skb->sk = sk; skb->destructor = sock_efree; dst = READ_ONCE(sk->sk_rx_dst); if (dst) dst = dst_check(dst, inet6_sk(sk)->rx_dst_cookie); if (dst) { if (dst->flags & DST_NOCACHE) { if (likely(atomic_inc_not_zero(&dst->__refcnt))) skb_dst_set(skb, dst); } else { skb_dst_set_noref(skb, dst); } } } static __inline__ int udpv6_rcv(struct sk_buff *skb) { return __udp6_lib_rcv(skb, &udp_table, IPPROTO_UDP); Loading Loading @@ -1381,6 +1457,7 @@ int compat_udpv6_getsockopt(struct sock *sk, int level, int optname, #endif static const struct inet6_protocol udpv6_protocol = { .early_demux = udp_v6_early_demux, .handler = udpv6_rcv, .err_handler = udpv6_err, .flags = INET6_PROTO_NOPOLICY|INET6_PROTO_FINAL, Loading Loading
include/net/udp.h +1 −0 Original line number Diff line number Diff line Loading @@ -247,6 +247,7 @@ static inline __be16 udp_flow_src_port(struct net *net, struct sk_buff *skb, /* net/ipv4/udp.c */ void udp_v4_early_demux(struct sk_buff *skb); void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst); int udp_get_port(struct sock *sk, unsigned short snum, int (*saddr_cmp)(const struct sock *, const struct sock *)); Loading
net/ipv4/udp.c +2 −1 Original line number Diff line number Diff line Loading @@ -1627,7 +1627,7 @@ int udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) /* For TCP sockets, sk_rx_dst is protected by socket lock * For UDP, we use xchg() to guard against concurrent changes. */ static void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst) void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst) { struct dst_entry *old; Loading @@ -1635,6 +1635,7 @@ static void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst) old = xchg(&sk->sk_rx_dst, dst); dst_release(old); } EXPORT_SYMBOL(udp_sk_rx_dst_set); /* * Multicasts and broadcasts go to each listener. Loading
net/ipv6/datagram.c +7 −1 Original line number Diff line number Diff line Loading @@ -250,8 +250,14 @@ int __ip6_datagram_connect(struct sock *sk, struct sockaddr *uaddr, */ err = ip6_datagram_dst_update(sk, true); if (err) if (err) { /* Reset daddr and dport so that udp_v6_early_demux() * fails to find this socket */ memset(&sk->sk_v6_daddr, 0, sizeof(sk->sk_v6_daddr)); inet->inet_dport = 0; goto out; } sk->sk_state = TCP_ESTABLISHED; sk_set_txhash(sk); Loading
net/ipv6/udp.c +86 −9 Original line number Diff line number Diff line Loading @@ -46,6 +46,7 @@ #include <net/tcp_states.h> #include <net/ip6_checksum.h> #include <net/xfrm.h> #include <net/inet_hashtables.h> #include <net/inet6_hashtables.h> #include <net/busy_poll.h> #include <net/sock_reuseport.h> Loading Loading @@ -277,11 +278,7 @@ static struct sock *__udp6_lib_lookup_skb(struct sk_buff *skb, struct udp_table *udptable) { const struct ipv6hdr *iph = ipv6_hdr(skb); struct sock *sk; sk = skb_steal_sock(skb); if (unlikely(sk)) return sk; return __udp6_lib_lookup(dev_net(skb->dev), &iph->saddr, sport, &iph->daddr, dport, inet6_iif(skb), udptable, skb); Loading Loading @@ -799,6 +796,24 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, if (udp6_csum_init(skb, uh, proto)) goto csum_error; /* Check if the socket is already available, e.g. due to early demux */ sk = skb_steal_sock(skb); if (sk) { struct dst_entry *dst = skb_dst(skb); int ret; if (unlikely(sk->sk_rx_dst != dst)) udp_sk_rx_dst_set(sk, dst); ret = udpv6_queue_rcv_skb(sk, skb); sock_put(sk); /* a return value > 0 means to resubmit the input */ if (ret > 0) return ret; return 0; } /* * Multicast receive code */ Loading @@ -807,11 +822,6 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, saddr, daddr, udptable, proto); /* Unicast */ /* * check socket cache ... must talk to Alan about his plans * for sock caches... i'll skip this for now. */ sk = __udp6_lib_lookup_skb(skb, uh->source, uh->dest, udptable); if (sk) { int ret; Loading Loading @@ -866,6 +876,72 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, return 0; } static struct sock *__udp6_lib_demux_lookup(struct net *net, __be16 loc_port, const struct in6_addr *loc_addr, __be16 rmt_port, const struct in6_addr *rmt_addr, int dif) { unsigned short hnum = ntohs(loc_port); unsigned int hash2 = udp6_portaddr_hash(net, loc_addr, hnum); unsigned int slot2 = hash2 & udp_table.mask; struct udp_hslot *hslot2 = &udp_table.hash2[slot2]; const __portpair ports = INET_COMBINED_PORTS(rmt_port, hnum); struct sock *sk; udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) { if (sk->sk_state == TCP_ESTABLISHED && INET6_MATCH(sk, net, rmt_addr, loc_addr, ports, dif)) return sk; /* Only check first socket in chain */ break; } return NULL; } static void udp_v6_early_demux(struct sk_buff *skb) { struct net *net = dev_net(skb->dev); const struct udphdr *uh; struct sock *sk; struct dst_entry *dst; int dif = skb->dev->ifindex; if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct udphdr))) return; uh = udp_hdr(skb); if (skb->pkt_type == PACKET_HOST) sk = __udp6_lib_demux_lookup(net, uh->dest, &ipv6_hdr(skb)->daddr, uh->source, &ipv6_hdr(skb)->saddr, dif); else return; if (!sk || !atomic_inc_not_zero_hint(&sk->sk_refcnt, 2)) return; skb->sk = sk; skb->destructor = sock_efree; dst = READ_ONCE(sk->sk_rx_dst); if (dst) dst = dst_check(dst, inet6_sk(sk)->rx_dst_cookie); if (dst) { if (dst->flags & DST_NOCACHE) { if (likely(atomic_inc_not_zero(&dst->__refcnt))) skb_dst_set(skb, dst); } else { skb_dst_set_noref(skb, dst); } } } static __inline__ int udpv6_rcv(struct sk_buff *skb) { return __udp6_lib_rcv(skb, &udp_table, IPPROTO_UDP); Loading Loading @@ -1381,6 +1457,7 @@ int compat_udpv6_getsockopt(struct sock *sk, int level, int optname, #endif static const struct inet6_protocol udpv6_protocol = { .early_demux = udp_v6_early_demux, .handler = udpv6_rcv, .err_handler = udpv6_err, .flags = INET6_PROTO_NOPOLICY|INET6_PROTO_FINAL, Loading