Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit fe38d2a1 authored by Josef Bacik's avatar Josef Bacik Committed by David S. Miller
Browse files

inet: collapse ipv4/v6 rcv_saddr_equal functions into one



We pass these per-protocol equal functions around in various places, but
we can just have one function that checks the sk->sk_family and then do
the right comparison function.  I've also changed the ipv4 version to
not cast to inet_sock since it is unneeded.

Signed-off-by: default avatarJosef Bacik <jbacik@fb.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent ab70e586
Loading
Loading
Loading
Loading
+1 −3
Original line number Diff line number Diff line
@@ -88,9 +88,7 @@ int __ipv6_get_lladdr(struct inet6_dev *idev, struct in6_addr *addr,
		      u32 banned_flags);
int ipv6_get_lladdr(struct net_device *dev, struct in6_addr *addr,
		    u32 banned_flags);
int ipv4_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
			 bool match_wildcard);
int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
int inet_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
			 bool match_wildcard);
void addrconf_join_solict(struct net_device *dev, const struct in6_addr *addr);
void addrconf_leave_solict(struct inet6_dev *idev, const struct in6_addr *addr);
+1 −4
Original line number Diff line number Diff line
@@ -203,10 +203,7 @@ void inet_hashinfo_init(struct inet_hashinfo *h);

bool inet_ehash_insert(struct sock *sk, struct sock *osk);
bool inet_ehash_nolisten(struct sock *sk, struct sock *osk);
int __inet_hash(struct sock *sk, struct sock *osk,
		int (*saddr_same)(const struct sock *sk1,
				  const struct sock *sk2,
				  bool match_wildcard));
int __inet_hash(struct sock *sk, struct sock *osk);
int inet_hash(struct sock *sk);
void inet_unhash(struct sock *sk);

+0 −1
Original line number Diff line number Diff line
@@ -204,7 +204,6 @@ static inline void udp_lib_close(struct sock *sk, long timeout)
}

int udp_lib_get_port(struct sock *sk, unsigned short snum,
		     int (*)(const struct sock *, const struct sock *, bool),
		     unsigned int hash2_nulladdr);

u32 udp_flow_hashrnd(void);
+72 −0
Original line number Diff line number Diff line
@@ -31,6 +31,78 @@ const char inet_csk_timer_bug_msg[] = "inet_csk BUG: unknown timer value\n";
EXPORT_SYMBOL(inet_csk_timer_bug_msg);
#endif

#if IS_ENABLED(CONFIG_IPV6)
/* match_wildcard == true:  IPV6_ADDR_ANY equals to any IPv6 addresses if IPv6
 *                          only, and any IPv4 addresses if not IPv6 only
 * match_wildcard == false: addresses must be exactly the same, i.e.
 *                          IPV6_ADDR_ANY only equals to IPV6_ADDR_ANY,
 *                          and 0.0.0.0 equals to 0.0.0.0 only
 */
static int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
				bool match_wildcard)
{
	const struct in6_addr *sk2_rcv_saddr6 = inet6_rcv_saddr(sk2);
	int sk2_ipv6only = inet_v6_ipv6only(sk2);
	int addr_type = ipv6_addr_type(&sk->sk_v6_rcv_saddr);
	int addr_type2 = sk2_rcv_saddr6 ? ipv6_addr_type(sk2_rcv_saddr6) : IPV6_ADDR_MAPPED;

	/* if both are mapped, treat as IPv4 */
	if (addr_type == IPV6_ADDR_MAPPED && addr_type2 == IPV6_ADDR_MAPPED) {
		if (!sk2_ipv6only) {
			if (sk->sk_rcv_saddr == sk2->sk_rcv_saddr)
				return 1;
			if (!sk->sk_rcv_saddr || !sk2->sk_rcv_saddr)
				return match_wildcard;
		}
		return 0;
	}

	if (addr_type == IPV6_ADDR_ANY && addr_type2 == IPV6_ADDR_ANY)
		return 1;

	if (addr_type2 == IPV6_ADDR_ANY && match_wildcard &&
	    !(sk2_ipv6only && addr_type == IPV6_ADDR_MAPPED))
		return 1;

	if (addr_type == IPV6_ADDR_ANY && match_wildcard &&
	    !(ipv6_only_sock(sk) && addr_type2 == IPV6_ADDR_MAPPED))
		return 1;

	if (sk2_rcv_saddr6 &&
	    ipv6_addr_equal(&sk->sk_v6_rcv_saddr, sk2_rcv_saddr6))
		return 1;

	return 0;
}
#endif

/* match_wildcard == true:  0.0.0.0 equals to any IPv4 addresses
 * match_wildcard == false: addresses must be exactly the same, i.e.
 *                          0.0.0.0 only equals to 0.0.0.0
 */
static int ipv4_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
				bool match_wildcard)
{
	if (!ipv6_only_sock(sk2)) {
		if (sk->sk_rcv_saddr == sk2->sk_rcv_saddr)
			return 1;
		if (!sk->sk_rcv_saddr || !sk2->sk_rcv_saddr)
			return match_wildcard;
	}
	return 0;
}

int inet_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
			 bool match_wildcard)
{
#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6)
		return ipv6_rcv_saddr_equal(sk, sk2, match_wildcard);
#endif
	return ipv4_rcv_saddr_equal(sk, sk2, match_wildcard);
}
EXPORT_SYMBOL(inet_rcv_saddr_equal);

void inet_get_local_port_range(struct net *net, int *low, int *high)
{
	unsigned int seq;
+5 −11
Original line number Diff line number Diff line
@@ -435,10 +435,7 @@ bool inet_ehash_nolisten(struct sock *sk, struct sock *osk)
EXPORT_SYMBOL_GPL(inet_ehash_nolisten);

static int inet_reuseport_add_sock(struct sock *sk,
				   struct inet_listen_hashbucket *ilb,
				   int (*saddr_same)(const struct sock *sk1,
						     const struct sock *sk2,
						     bool match_wildcard))
				   struct inet_listen_hashbucket *ilb)
{
	struct inet_bind_bucket *tb = inet_csk(sk)->icsk_bind_hash;
	struct sock *sk2;
@@ -451,7 +448,7 @@ static int inet_reuseport_add_sock(struct sock *sk,
		    sk2->sk_bound_dev_if == sk->sk_bound_dev_if &&
		    inet_csk(sk2)->icsk_bind_hash == tb &&
		    sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) &&
		    saddr_same(sk, sk2, false))
		    inet_rcv_saddr_equal(sk, sk2, false))
			return reuseport_add_sock(sk, sk2);
	}

@@ -461,10 +458,7 @@ static int inet_reuseport_add_sock(struct sock *sk,
	return 0;
}

int __inet_hash(struct sock *sk, struct sock *osk,
		 int (*saddr_same)(const struct sock *sk1,
				   const struct sock *sk2,
				   bool match_wildcard))
int __inet_hash(struct sock *sk, struct sock *osk)
{
	struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
	struct inet_listen_hashbucket *ilb;
@@ -479,7 +473,7 @@ int __inet_hash(struct sock *sk, struct sock *osk,

	spin_lock(&ilb->lock);
	if (sk->sk_reuseport) {
		err = inet_reuseport_add_sock(sk, ilb, saddr_same);
		err = inet_reuseport_add_sock(sk, ilb);
		if (err)
			goto unlock;
	}
@@ -503,7 +497,7 @@ int inet_hash(struct sock *sk)

	if (sk->sk_state != TCP_CLOSE) {
		local_bh_disable();
		err = __inet_hash(sk, NULL, ipv4_rcv_saddr_equal);
		err = __inet_hash(sk, NULL);
		local_bh_enable();
	}

Loading