Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 4297a0ef authored by David Ahern, committed by David S. Miller
Browse files

net: ipv6: add second dif to inet6 socket lookups



Add a second device index, sdif, to inet6 socket lookups. sdif is the
index for ingress devices enslaved to an l3mdev. It allows the lookups
to consider the enslaved device as well as the L3 domain when searching
for a socket.

TCP moves the data in the cb. Prior to tcp_v6_rcv (e.g., early demux) the
ingress index is obtained from IP6CB using inet6_sdif and after tcp_v6_rcv
tcp_v6_sdif is used.

Signed-off-by: David Ahern <dsahern@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
parent 1801b570
Loading
Loading
Loading
Loading
+13 −9
Original line number Original line Diff line number Diff line
@@ -49,7 +49,8 @@ struct sock *__inet6_lookup_established(struct net *net,
					const struct in6_addr *saddr,
					const struct in6_addr *saddr,
					const __be16 sport,
					const __be16 sport,
					const struct in6_addr *daddr,
					const struct in6_addr *daddr,
					const u16 hnum, const int dif);
					const u16 hnum, const int dif,
					const int sdif);


struct sock *inet6_lookup_listener(struct net *net,
struct sock *inet6_lookup_listener(struct net *net,
				   struct inet_hashinfo *hashinfo,
				   struct inet_hashinfo *hashinfo,
@@ -57,7 +58,8 @@ struct sock *inet6_lookup_listener(struct net *net,
				   const struct in6_addr *saddr,
				   const struct in6_addr *saddr,
				   const __be16 sport,
				   const __be16 sport,
				   const struct in6_addr *daddr,
				   const struct in6_addr *daddr,
				   const unsigned short hnum, const int dif);
				   const unsigned short hnum,
				   const int dif, const int sdif);


static inline struct sock *__inet6_lookup(struct net *net,
static inline struct sock *__inet6_lookup(struct net *net,
					  struct inet_hashinfo *hashinfo,
					  struct inet_hashinfo *hashinfo,
@@ -66,24 +68,25 @@ static inline struct sock *__inet6_lookup(struct net *net,
					  const __be16 sport,
					  const __be16 sport,
					  const struct in6_addr *daddr,
					  const struct in6_addr *daddr,
					  const u16 hnum,
					  const u16 hnum,
					  const int dif,
					  const int dif, const int sdif,
					  bool *refcounted)
					  bool *refcounted)
{
{
	struct sock *sk = __inet6_lookup_established(net, hashinfo, saddr,
	struct sock *sk = __inet6_lookup_established(net, hashinfo, saddr,
						sport, daddr, hnum, dif);
						     sport, daddr, hnum,
						     dif, sdif);
	*refcounted = true;
	*refcounted = true;
	if (sk)
	if (sk)
		return sk;
		return sk;
	*refcounted = false;
	*refcounted = false;
	return inet6_lookup_listener(net, hashinfo, skb, doff, saddr, sport,
	return inet6_lookup_listener(net, hashinfo, skb, doff, saddr, sport,
				     daddr, hnum, dif);
				     daddr, hnum, dif, sdif);
}
}


static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
					      struct sk_buff *skb, int doff,
					      struct sk_buff *skb, int doff,
					      const __be16 sport,
					      const __be16 sport,
					      const __be16 dport,
					      const __be16 dport,
					      int iif,
					      int iif, int sdif,
					      bool *refcounted)
					      bool *refcounted)
{
{
	struct sock *sk = skb_steal_sock(skb);
	struct sock *sk = skb_steal_sock(skb);
@@ -95,7 +98,7 @@ static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
	return __inet6_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
	return __inet6_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
			      doff, &ipv6_hdr(skb)->saddr, sport,
			      doff, &ipv6_hdr(skb)->saddr, sport,
			      &ipv6_hdr(skb)->daddr, ntohs(dport),
			      &ipv6_hdr(skb)->daddr, ntohs(dport),
			      iif, refcounted);
			      iif, sdif, refcounted);
}
}


struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
@@ -107,13 +110,14 @@ struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
int inet6_hash(struct sock *sk);
int inet6_hash(struct sock *sk);
#endif /* IS_ENABLED(CONFIG_IPV6) */
#endif /* IS_ENABLED(CONFIG_IPV6) */


#define INET6_MATCH(__sk, __net, __saddr, __daddr, __ports, __dif)	\
#define INET6_MATCH(__sk, __net, __saddr, __daddr, __ports, __dif, __sdif) \
	(((__sk)->sk_portpair == (__ports))			&&	\
	(((__sk)->sk_portpair == (__ports))			&&	\
	 ((__sk)->sk_family == AF_INET6)			&&	\
	 ((__sk)->sk_family == AF_INET6)			&&	\
	 ipv6_addr_equal(&(__sk)->sk_v6_daddr, (__saddr))		&&	\
	 ipv6_addr_equal(&(__sk)->sk_v6_daddr, (__saddr))		&&	\
	 ipv6_addr_equal(&(__sk)->sk_v6_rcv_saddr, (__daddr))	&&	\
	 ipv6_addr_equal(&(__sk)->sk_v6_rcv_saddr, (__daddr))	&&	\
	 (!(__sk)->sk_bound_dev_if	||				\
	 (!(__sk)->sk_bound_dev_if	||				\
	   ((__sk)->sk_bound_dev_if == (__dif))) 		&&	\
	   ((__sk)->sk_bound_dev_if == (__dif))	||			\
	   ((__sk)->sk_bound_dev_if == (__sdif)))		&&	\
	 net_eq(sock_net(__sk), (__net)))
	 net_eq(sock_net(__sk), (__net)))


#endif /* _INET6_HASHTABLES_H */
#endif /* _INET6_HASHTABLES_H */
+10 −0
Original line number Original line Diff line number Diff line
@@ -827,6 +827,16 @@ static inline int tcp_v6_iif(const struct sk_buff *skb)


	return l3_slave ? skb->skb_iif : TCP_SKB_CB(skb)->header.h6.iif;
	return l3_slave ? skb->skb_iif : TCP_SKB_CB(skb)->header.h6.iif;
}
}

/*
 * tcp_v6_sdif - second device index (sdif) for a TCP/IPv6 skb
 *
 * With CONFIG_NET_L3_MASTER_DEV enabled, a non-NULL @skb whose
 * TCP_SKB_CB h6 flags satisfy ipv6_l3mdev_skb() yields the iif saved in
 * TCP_SKB_CB(skb)->header.h6; every other case yields 0.
 *
 * This reads TCP_SKB_CB, so it can not be used from early demux.
 */
static inline int tcp_v6_sdif(const struct sk_buff *skb)
{
	int sdif = 0;

#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV)
	if (skb && ipv6_l3mdev_skb(TCP_SKB_CB(skb)->header.h6.flags))
		sdif = TCP_SKB_CB(skb)->header.h6.iif;
#endif

	return sdif;
}
#endif
#endif


/* TCP_SKB_CB reference means this can not be used from early demux */
/* TCP_SKB_CB reference means this can not be used from early demux */
+2 −2
Original line number Original line Diff line number Diff line
@@ -89,7 +89,7 @@ static void dccp_v6_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
	sk = __inet6_lookup_established(net, &dccp_hashinfo,
	sk = __inet6_lookup_established(net, &dccp_hashinfo,
					&hdr->daddr, dh->dccph_dport,
					&hdr->daddr, dh->dccph_dport,
					&hdr->saddr, ntohs(dh->dccph_sport),
					&hdr->saddr, ntohs(dh->dccph_sport),
					inet6_iif(skb));
					inet6_iif(skb), 0);


	if (!sk) {
	if (!sk) {
		__ICMP6_INC_STATS(net, __in6_dev_get(skb->dev),
		__ICMP6_INC_STATS(net, __in6_dev_get(skb->dev),
@@ -687,7 +687,7 @@ static int dccp_v6_rcv(struct sk_buff *skb)
lookup:
lookup:
	sk = __inet6_lookup_skb(&dccp_hashinfo, skb, __dccp_hdr_len(dh),
	sk = __inet6_lookup_skb(&dccp_hashinfo, skb, __dccp_hdr_len(dh),
			        dh->dccph_sport, dh->dccph_dport,
			        dh->dccph_sport, dh->dccph_dport,
				inet6_iif(skb), &refcounted);
				inet6_iif(skb), 0, &refcounted);
	if (!sk) {
	if (!sk) {
		dccp_pr_debug("failed to look up flow ID in table and "
		dccp_pr_debug("failed to look up flow ID in table and "
			      "get corresponding socket\n");
			      "get corresponding socket\n");
+17 −11
Original line number Original line Diff line number Diff line
@@ -56,7 +56,7 @@ struct sock *__inet6_lookup_established(struct net *net,
					   const __be16 sport,
					   const __be16 sport,
					   const struct in6_addr *daddr,
					   const struct in6_addr *daddr,
					   const u16 hnum,
					   const u16 hnum,
					   const int dif)
					   const int dif, const int sdif)
{
{
	struct sock *sk;
	struct sock *sk;
	const struct hlist_nulls_node *node;
	const struct hlist_nulls_node *node;
@@ -73,12 +73,12 @@ struct sock *__inet6_lookup_established(struct net *net,
	sk_nulls_for_each_rcu(sk, node, &head->chain) {
	sk_nulls_for_each_rcu(sk, node, &head->chain) {
		if (sk->sk_hash != hash)
		if (sk->sk_hash != hash)
			continue;
			continue;
		if (!INET6_MATCH(sk, net, saddr, daddr, ports, dif))
		if (!INET6_MATCH(sk, net, saddr, daddr, ports, dif, sdif))
			continue;
			continue;
		if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
		if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
			goto out;
			goto out;


		if (unlikely(!INET6_MATCH(sk, net, saddr, daddr, ports, dif))) {
		if (unlikely(!INET6_MATCH(sk, net, saddr, daddr, ports, dif, sdif))) {
			sock_gen_put(sk);
			sock_gen_put(sk);
			goto begin;
			goto begin;
		}
		}
@@ -96,7 +96,7 @@ EXPORT_SYMBOL(__inet6_lookup_established);
static inline int compute_score(struct sock *sk, struct net *net,
static inline int compute_score(struct sock *sk, struct net *net,
				const unsigned short hnum,
				const unsigned short hnum,
				const struct in6_addr *daddr,
				const struct in6_addr *daddr,
				const int dif, bool exact_dif)
				const int dif, const int sdif, bool exact_dif)
{
{
	int score = -1;
	int score = -1;


@@ -110,8 +110,12 @@ static inline int compute_score(struct sock *sk, struct net *net,
			score++;
			score++;
		}
		}
		if (sk->sk_bound_dev_if || exact_dif) {
		if (sk->sk_bound_dev_if || exact_dif) {
			if (sk->sk_bound_dev_if != dif)
			bool dev_match = (sk->sk_bound_dev_if == dif ||
					  sk->sk_bound_dev_if == sdif);

			if (exact_dif && !dev_match)
				return -1;
				return -1;
			if (sk->sk_bound_dev_if && dev_match)
				score++;
				score++;
		}
		}
		if (sk->sk_incoming_cpu == raw_smp_processor_id())
		if (sk->sk_incoming_cpu == raw_smp_processor_id())
@@ -126,7 +130,7 @@ struct sock *inet6_lookup_listener(struct net *net,
		struct sk_buff *skb, int doff,
		struct sk_buff *skb, int doff,
		const struct in6_addr *saddr,
		const struct in6_addr *saddr,
		const __be16 sport, const struct in6_addr *daddr,
		const __be16 sport, const struct in6_addr *daddr,
		const unsigned short hnum, const int dif)
		const unsigned short hnum, const int dif, const int sdif)
{
{
	unsigned int hash = inet_lhashfn(net, hnum);
	unsigned int hash = inet_lhashfn(net, hnum);
	struct inet_listen_hashbucket *ilb = &hashinfo->listening_hash[hash];
	struct inet_listen_hashbucket *ilb = &hashinfo->listening_hash[hash];
@@ -136,7 +140,7 @@ struct sock *inet6_lookup_listener(struct net *net,
	u32 phash = 0;
	u32 phash = 0;


	sk_for_each(sk, &ilb->head) {
	sk_for_each(sk, &ilb->head) {
		score = compute_score(sk, net, hnum, daddr, dif, exact_dif);
		score = compute_score(sk, net, hnum, daddr, dif, sdif, exact_dif);
		if (score > hiscore) {
		if (score > hiscore) {
			reuseport = sk->sk_reuseport;
			reuseport = sk->sk_reuseport;
			if (reuseport) {
			if (reuseport) {
@@ -171,7 +175,7 @@ struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
	bool refcounted;
	bool refcounted;


	sk = __inet6_lookup(net, hashinfo, skb, doff, saddr, sport, daddr,
	sk = __inet6_lookup(net, hashinfo, skb, doff, saddr, sport, daddr,
			    ntohs(dport), dif, &refcounted);
			    ntohs(dport), dif, 0, &refcounted);
	if (sk && !refcounted && !refcount_inc_not_zero(&sk->sk_refcnt))
	if (sk && !refcounted && !refcount_inc_not_zero(&sk->sk_refcnt))
		sk = NULL;
		sk = NULL;
	return sk;
	return sk;
@@ -187,8 +191,9 @@ static int __inet6_check_established(struct inet_timewait_death_row *death_row,
	const struct in6_addr *daddr = &sk->sk_v6_rcv_saddr;
	const struct in6_addr *daddr = &sk->sk_v6_rcv_saddr;
	const struct in6_addr *saddr = &sk->sk_v6_daddr;
	const struct in6_addr *saddr = &sk->sk_v6_daddr;
	const int dif = sk->sk_bound_dev_if;
	const int dif = sk->sk_bound_dev_if;
	const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport);
	struct net *net = sock_net(sk);
	struct net *net = sock_net(sk);
	const int sdif = l3mdev_master_ifindex_by_index(net, dif);
	const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport);
	const unsigned int hash = inet6_ehashfn(net, daddr, lport, saddr,
	const unsigned int hash = inet6_ehashfn(net, daddr, lport, saddr,
						inet->inet_dport);
						inet->inet_dport);
	struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash);
	struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash);
@@ -203,7 +208,8 @@ static int __inet6_check_established(struct inet_timewait_death_row *death_row,
		if (sk2->sk_hash != hash)
		if (sk2->sk_hash != hash)
			continue;
			continue;


		if (likely(INET6_MATCH(sk2, net, saddr, daddr, ports, dif))) {
		if (likely(INET6_MATCH(sk2, net, saddr, daddr, ports,
				       dif, sdif))) {
			if (sk2->sk_state == TCP_TIME_WAIT) {
			if (sk2->sk_state == TCP_TIME_WAIT) {
				tw = inet_twsk(sk2);
				tw = inet_twsk(sk2);
				if (twsk_unique(sk, sk2, twp))
				if (twsk_unique(sk, sk2, twp))
+8 −5
Original line number Original line Diff line number Diff line
@@ -350,7 +350,7 @@ static void tcp_v6_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
	sk = __inet6_lookup_established(net, &tcp_hashinfo,
	sk = __inet6_lookup_established(net, &tcp_hashinfo,
					&hdr->daddr, th->dest,
					&hdr->daddr, th->dest,
					&hdr->saddr, ntohs(th->source),
					&hdr->saddr, ntohs(th->source),
					skb->dev->ifindex);
					skb->dev->ifindex, inet6_sdif(skb));


	if (!sk) {
	if (!sk) {
		__ICMP6_INC_STATS(net, __in6_dev_get(skb->dev),
		__ICMP6_INC_STATS(net, __in6_dev_get(skb->dev),
@@ -918,7 +918,8 @@ static void tcp_v6_send_reset(const struct sock *sk, struct sk_buff *skb)
					   &tcp_hashinfo, NULL, 0,
					   &tcp_hashinfo, NULL, 0,
					   &ipv6h->saddr,
					   &ipv6h->saddr,
					   th->source, &ipv6h->daddr,
					   th->source, &ipv6h->daddr,
					   ntohs(th->source), tcp_v6_iif(skb));
					   ntohs(th->source), tcp_v6_iif(skb),
					   tcp_v6_sdif(skb));
		if (!sk1)
		if (!sk1)
			goto out;
			goto out;


@@ -1397,6 +1398,7 @@ static void tcp_v6_fill_cb(struct sk_buff *skb, const struct ipv6hdr *hdr,


static int tcp_v6_rcv(struct sk_buff *skb)
static int tcp_v6_rcv(struct sk_buff *skb)
{
{
	int sdif = inet6_sdif(skb);
	const struct tcphdr *th;
	const struct tcphdr *th;
	const struct ipv6hdr *hdr;
	const struct ipv6hdr *hdr;
	bool refcounted;
	bool refcounted;
@@ -1430,7 +1432,7 @@ static int tcp_v6_rcv(struct sk_buff *skb)


lookup:
lookup:
	sk = __inet6_lookup_skb(&tcp_hashinfo, skb, __tcp_hdrlen(th),
	sk = __inet6_lookup_skb(&tcp_hashinfo, skb, __tcp_hdrlen(th),
				th->source, th->dest, inet6_iif(skb),
				th->source, th->dest, inet6_iif(skb), sdif,
				&refcounted);
				&refcounted);
	if (!sk)
	if (!sk)
		goto no_tcp_socket;
		goto no_tcp_socket;
@@ -1563,7 +1565,8 @@ static int tcp_v6_rcv(struct sk_buff *skb)
					    skb, __tcp_hdrlen(th),
					    skb, __tcp_hdrlen(th),
					    &ipv6_hdr(skb)->saddr, th->source,
					    &ipv6_hdr(skb)->saddr, th->source,
					    &ipv6_hdr(skb)->daddr,
					    &ipv6_hdr(skb)->daddr,
					    ntohs(th->dest), tcp_v6_iif(skb));
					    ntohs(th->dest), tcp_v6_iif(skb),
					    sdif);
		if (sk2) {
		if (sk2) {
			struct inet_timewait_sock *tw = inet_twsk(sk);
			struct inet_timewait_sock *tw = inet_twsk(sk);
			inet_twsk_deschedule_put(tw);
			inet_twsk_deschedule_put(tw);
@@ -1610,7 +1613,7 @@ static void tcp_v6_early_demux(struct sk_buff *skb)
	sk = __inet6_lookup_established(dev_net(skb->dev), &tcp_hashinfo,
	sk = __inet6_lookup_established(dev_net(skb->dev), &tcp_hashinfo,
					&hdr->saddr, th->source,
					&hdr->saddr, th->source,
					&hdr->daddr, ntohs(th->dest),
					&hdr->daddr, ntohs(th->dest),
					inet6_iif(skb));
					inet6_iif(skb), inet6_sdif(skb));
	if (sk) {
	if (sk) {
		skb->sk = sk;
		skb->sk = sk;
		skb->destructor = sock_edemux;
		skb->destructor = sock_edemux;
Loading