Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 32c1da70 authored by Stephen Hemminger's avatar Stephen Hemminger Committed by David S. Miller
Browse files

[UDP]: Randomize port selection.



This patch causes UDP port allocation to be randomized like TCP.
The earlier code would always choose same port (ie first empty list).

Signed-off-by: default avatarStephen Hemminger <shemminger@linux-foundation.org>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 356f89e1
Loading
Loading
Loading
Loading
+45 −39
Original line number Diff line number Diff line
@@ -113,9 +113,8 @@ DEFINE_SNMP_STAT(struct udp_mib, udp_statistics) __read_mostly;
struct hlist_head udp_hash[UDP_HTABLE_SIZE];
DEFINE_RWLOCK(udp_hash_lock);

static int udp_port_rover;

static inline int __udp_lib_lport_inuse(__u16 num, struct hlist_head udptable[])
static inline int __udp_lib_lport_inuse(__u16 num,
					const struct hlist_head udptable[])
{
	struct sock *sk;
	struct hlist_node *node;
@@ -132,11 +131,10 @@ static inline int __udp_lib_lport_inuse(__u16 num, struct hlist_head udptable[])
 *  @sk:          socket struct in question
 *  @snum:        port number to look up
 *  @udptable:    hash list table, must be of UDP_HTABLE_SIZE
 *  @port_rover:  pointer to record of last unallocated port
 *  @saddr_comp:  AF-dependent comparison of bound local IP addresses
 */
int __udp_lib_get_port(struct sock *sk, unsigned short snum,
		       struct hlist_head udptable[], int *port_rover,
		       struct hlist_head udptable[],
		       int (*saddr_comp)(const struct sock *sk1,
					 const struct sock *sk2 )    )
{
@@ -146,49 +144,56 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
	int    error = 1;

	write_lock_bh(&udp_hash_lock);
	if (snum == 0) {
		int best_size_so_far, best, result, i;

		if (*port_rover > sysctl_local_port_range[1] ||
		    *port_rover < sysctl_local_port_range[0])
			*port_rover = sysctl_local_port_range[0];
		best_size_so_far = 32767;
		best = result = *port_rover;
		for (i = 0; i < UDP_HTABLE_SIZE; i++, result++) {
			int size;

			head = &udptable[result & (UDP_HTABLE_SIZE - 1)];
			if (hlist_empty(head)) {
				if (result > sysctl_local_port_range[1])
					result = sysctl_local_port_range[0] +
						((result - sysctl_local_port_range[0]) &
						 (UDP_HTABLE_SIZE - 1));

	if (!snum) {
		int i;
		int low = sysctl_local_port_range[0];
		int high = sysctl_local_port_range[1];
		unsigned rover, best, best_size_so_far;

		best_size_so_far = UINT_MAX;
		best = rover = net_random() % (high - low) + low;

		/* 1st pass: look for empty (or shortest) hash chain */
		for (i = 0; i < UDP_HTABLE_SIZE; i++) {
			int size = 0;

			head = &udptable[rover & (UDP_HTABLE_SIZE - 1)];
			if (hlist_empty(head))
				goto gotit;
			}
			size = 0;

			sk_for_each(sk2, node, head) {
				if (++size >= best_size_so_far)
					goto next;
			}
			best_size_so_far = size;
			best = result;
			best = rover;
		next:
			;
			/* fold back if end of range */
			if (++rover > high)
				rover = low + ((rover - low)
					       & (UDP_HTABLE_SIZE - 1));


		}
		result = best;
		for (i = 0; i < (1 << 16) / UDP_HTABLE_SIZE;
		     i++, result += UDP_HTABLE_SIZE) {
			if (result > sysctl_local_port_range[1])
				result = sysctl_local_port_range[0]
					+ ((result - sysctl_local_port_range[0]) &
					   (UDP_HTABLE_SIZE - 1));
			if (! __udp_lib_lport_inuse(result, udptable))
				break;

		/* 2nd pass: find hole in shortest hash chain */
		rover = best;
		for (i = 0; i < (1 << 16) / UDP_HTABLE_SIZE; i++) {
			if (! __udp_lib_lport_inuse(rover, udptable))
				goto gotit;
			rover += UDP_HTABLE_SIZE;
			if (rover > high)
				rover = low + ((rover - low)
					       & (UDP_HTABLE_SIZE - 1));
		}
		if (i >= (1 << 16) / UDP_HTABLE_SIZE)


		/* All ports in use! */
		goto fail;

gotit:
		*port_rover = snum = result;
		snum = rover;
	} else {
		head = &udptable[snum & (UDP_HTABLE_SIZE - 1)];

@@ -201,6 +206,7 @@ gotit:
			    (*saddr_comp)(sk, sk2)                             )
				goto fail;
	}

	inet_sk(sk)->num = snum;
	sk->sk_hash = snum;
	if (sk_unhashed(sk)) {
@@ -217,7 +223,7 @@ fail:
int udp_get_port(struct sock *sk, unsigned short snum,
			int (*scmp)(const struct sock *, const struct sock *))
{
	return  __udp_lib_get_port(sk, snum, udp_hash, &udp_port_rover, scmp);
	return  __udp_lib_get_port(sk, snum, udp_hash, scmp);
}

int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2)
+1 −1
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ extern int __udp4_lib_rcv(struct sk_buff *, struct hlist_head [], int );
extern void 	__udp4_lib_err(struct sk_buff *, u32, struct hlist_head []);

extern int	__udp_lib_get_port(struct sock *sk, unsigned short snum,
				   struct hlist_head udptable[], int *port_rover,
				   struct hlist_head udptable[],
				   int (*)(const struct sock*,const struct sock*));
extern int	ipv4_rcv_saddr_equal(const struct sock *, const struct sock *);

+1 −2
Original line number Diff line number Diff line
@@ -16,12 +16,11 @@
DEFINE_SNMP_STAT(struct udp_mib, udplite_statistics)	__read_mostly;

struct hlist_head 	udplite_hash[UDP_HTABLE_SIZE];
static int		udplite_port_rover;

int udplite_get_port(struct sock *sk, unsigned short p,
		     int (*c)(const struct sock *, const struct sock *))
{
	return  __udp_lib_get_port(sk, p, udplite_hash, &udplite_port_rover, c);
	return  __udp_lib_get_port(sk, p, udplite_hash, c);
}

static int udplite_v4_get_port(struct sock *sk, unsigned short snum)