Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 8217ca65 authored by Martin KaFai Lau's avatar Martin KaFai Lau Committed by Daniel Borkmann
Browse files

bpf: Enable BPF_PROG_TYPE_SK_REUSEPORT bpf prog in reuseport selection



This patch allows a BPF_PROG_TYPE_SK_REUSEPORT bpf prog to select a
SO_REUSEPORT sk from a BPF_MAP_TYPE_REUSEPORT_ARRAY introduced in
the earlier patch.  "bpf_run_sk_reuseport()" will return -ECONNREFUSED
when the BPF_PROG_TYPE_SK_REUSEPORT prog returns SK_DROP.
The callers, in inet[6]_hashtable.c and ipv[46]/udp.c, are modified to
handle this case and return NULL immediately instead of continuing the
sk search from its hashtable.

It re-uses the existing SO_ATTACH_REUSEPORT_EBPF setsockopt to attach
BPF_PROG_TYPE_SK_REUSEPORT.  The "sk_reuseport_attach_bpf()" will check
if the attaching bpf prog is in the new SK_REUSEPORT or the existing
SOCKET_FILTER type and then check different things accordingly.

One level of "__reuseport_attach_prog()" call is removed.  The
"sk_unhashed() && ..." and "sk->sk_reuseport_cb" tests are pushed
back to "reuseport_attach_prog()" in sock_reuseport.c.  sock_reuseport.c
seems to have more knowledge on those test requirements than filter.c.
In "reuseport_attach_prog()", after new_prog is attached to reuse->prog,
the old_prog (if any) is also directly freed instead of returning the
old_prog to the caller and asking the caller to free.

The sysctl_optmem_max check is moved back to the
"sk_reuseport_attach_filter()" and "sk_reuseport_attach_bpf()".
As of other bpf prog types, the new BPF_PROG_TYPE_SK_REUSEPORT is only
bounded by the usual "bpf_prog_charge_memlock()" during load time
instead of bounded by both bpf_prog_charge_memlock and sysctl_optmem_max.

Signed-off-by: default avatarMartin KaFai Lau <kafai@fb.com>
Acked-by: default avatarAlexei Starovoitov <ast@kernel.org>
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
parent 2dbb9b9e
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -753,6 +753,7 @@ int sk_attach_filter(struct sock_fprog *fprog, struct sock *sk);
int sk_attach_bpf(u32 ufd, struct sock *sk);
int sk_reuseport_attach_filter(struct sock_fprog *fprog, struct sock *sk);
int sk_reuseport_attach_bpf(u32 ufd, struct sock *sk);
void sk_reuseport_prog_free(struct bpf_prog *prog);
int sk_detach_filter(struct sock *sk);
int sk_get_filter(struct sock *sk, struct sock_filter __user *filter,
		  unsigned int len);
+1 −2
Original line number Diff line number Diff line
@@ -34,8 +34,7 @@ extern struct sock *reuseport_select_sock(struct sock *sk,
					  u32 hash,
					  struct sk_buff *skb,
					  int hdr_len);
extern struct bpf_prog *reuseport_attach_prog(struct sock *sk,
					      struct bpf_prog *prog);
extern int reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog);
int reuseport_get_id(struct sock_reuseport *reuse);

#endif  /* _SOCK_REUSEPORT_H */
+52 −35
Original line number Diff line number Diff line
@@ -1453,30 +1453,6 @@ static int __sk_attach_prog(struct bpf_prog *prog, struct sock *sk)
	return 0;
}

static int __reuseport_attach_prog(struct bpf_prog *prog, struct sock *sk)
{
	struct bpf_prog *old_prog;
	int err;

	if (bpf_prog_size(prog->len) > sysctl_optmem_max)
		return -ENOMEM;

	if (sk_unhashed(sk) && sk->sk_reuseport) {
		err = reuseport_alloc(sk, false);
		if (err)
			return err;
	} else if (!rcu_access_pointer(sk->sk_reuseport_cb)) {
		/* The socket wasn't bound with SO_REUSEPORT */
		return -EINVAL;
	}

	old_prog = reuseport_attach_prog(sk, prog);
	if (old_prog)
		bpf_prog_destroy(old_prog);

	return 0;
}

static
struct bpf_prog *__get_filter(struct sock_fprog *fprog, struct sock *sk)
{
@@ -1550,13 +1526,15 @@ int sk_reuseport_attach_filter(struct sock_fprog *fprog, struct sock *sk)
	if (IS_ERR(prog))
		return PTR_ERR(prog);

	err = __reuseport_attach_prog(prog, sk);
	if (err < 0) {
	if (bpf_prog_size(prog->len) > sysctl_optmem_max)
		err = -ENOMEM;
	else
		err = reuseport_attach_prog(sk, prog);

	if (err)
		__bpf_prog_release(prog);
		return err;
	}

	return 0;
	return err;
}

static struct bpf_prog *__get_bpf(u32 ufd, struct sock *sk)
@@ -1586,19 +1564,58 @@ int sk_attach_bpf(u32 ufd, struct sock *sk)

int sk_reuseport_attach_bpf(u32 ufd, struct sock *sk)
{
	struct bpf_prog *prog = __get_bpf(ufd, sk);
	struct bpf_prog *prog;
	int err;

	if (sock_flag(sk, SOCK_FILTER_LOCKED))
		return -EPERM;

	prog = bpf_prog_get_type(ufd, BPF_PROG_TYPE_SOCKET_FILTER);
	if (IS_ERR(prog) && PTR_ERR(prog) == -EINVAL)
		prog = bpf_prog_get_type(ufd, BPF_PROG_TYPE_SK_REUSEPORT);
	if (IS_ERR(prog))
		return PTR_ERR(prog);

	err = __reuseport_attach_prog(prog, sk);
	if (err < 0) {
	if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT) {
		/* Like other non BPF_PROG_TYPE_SOCKET_FILTER
		 * bpf prog (e.g. sockmap).  It depends on the
		 * limitation imposed by bpf_prog_load().
		 * Hence, sysctl_optmem_max is not checked.
		 */
		if ((sk->sk_type != SOCK_STREAM &&
		     sk->sk_type != SOCK_DGRAM) ||
		    (sk->sk_protocol != IPPROTO_UDP &&
		     sk->sk_protocol != IPPROTO_TCP) ||
		    (sk->sk_family != AF_INET &&
		     sk->sk_family != AF_INET6)) {
			err = -ENOTSUPP;
			goto err_prog_put;
		}
	} else {
		/* BPF_PROG_TYPE_SOCKET_FILTER */
		if (bpf_prog_size(prog->len) > sysctl_optmem_max) {
			err = -ENOMEM;
			goto err_prog_put;
		}
	}

	err = reuseport_attach_prog(sk, prog);
err_prog_put:
	if (err)
		bpf_prog_put(prog);

	return err;
}

	return 0;
void sk_reuseport_prog_free(struct bpf_prog *prog)
{
	if (!prog)
		return;

	if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT)
		bpf_prog_put(prog);
	else
		bpf_prog_destroy(prog);
}

struct bpf_scratchpad {
+26 −10
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@
#include <net/sock_reuseport.h>
#include <linux/bpf.h>
#include <linux/idr.h>
#include <linux/filter.h>
#include <linux/rcupdate.h>

#define INIT_SOCKS 128
@@ -133,8 +134,7 @@ static void reuseport_free_rcu(struct rcu_head *head)
	struct sock_reuseport *reuse;

	reuse = container_of(head, struct sock_reuseport, rcu);
	if (reuse->prog)
		bpf_prog_destroy(reuse->prog);
	sk_reuseport_prog_free(rcu_dereference_protected(reuse->prog, 1));
	if (reuse->reuseport_id)
		ida_simple_remove(&reuseport_ida, reuse->reuseport_id);
	kfree(reuse);
@@ -219,7 +219,7 @@ void reuseport_detach_sock(struct sock *sk)
}
EXPORT_SYMBOL(reuseport_detach_sock);

static struct sock *run_bpf(struct sock_reuseport *reuse, u16 socks,
static struct sock *run_bpf_filter(struct sock_reuseport *reuse, u16 socks,
				   struct bpf_prog *prog, struct sk_buff *skb,
				   int hdr_len)
{
@@ -282,9 +282,15 @@ struct sock *reuseport_select_sock(struct sock *sk,
		/* paired with smp_wmb() in reuseport_add_sock() */
		smp_rmb();

		if (prog && skb)
			sk2 = run_bpf(reuse, socks, prog, skb, hdr_len);
		if (!prog || !skb)
			goto select_by_hash;

		if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT)
			sk2 = bpf_run_sk_reuseport(reuse, sk, prog, skb, hash);
		else
			sk2 = run_bpf_filter(reuse, socks, prog, skb, hdr_len);

select_by_hash:
		/* no bpf or invalid bpf result: fall back to hash usage */
		if (!sk2)
			sk2 = reuse->socks[reciprocal_scale(hash, socks)];
@@ -296,12 +302,21 @@ struct sock *reuseport_select_sock(struct sock *sk,
}
EXPORT_SYMBOL(reuseport_select_sock);

struct bpf_prog *
reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog)
int reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog)
{
	struct sock_reuseport *reuse;
	struct bpf_prog *old_prog;

	if (sk_unhashed(sk) && sk->sk_reuseport) {
		int err = reuseport_alloc(sk, false);

		if (err)
			return err;
	} else if (!rcu_access_pointer(sk->sk_reuseport_cb)) {
		/* The socket wasn't bound with SO_REUSEPORT */
		return -EINVAL;
	}

	spin_lock_bh(&reuseport_lock);
	reuse = rcu_dereference_protected(sk->sk_reuseport_cb,
					  lockdep_is_held(&reuseport_lock));
@@ -310,6 +325,7 @@ reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog)
	rcu_assign_pointer(reuse->prog, prog);
	spin_unlock_bh(&reuseport_lock);

	return old_prog;
	sk_reuseport_prog_free(old_prog);
	return 0;
}
EXPORT_SYMBOL(reuseport_attach_prog);
+9 −5
Original line number Diff line number Diff line
@@ -328,7 +328,7 @@ struct sock *__inet_lookup_listener(struct net *net,
				    saddr, sport, daddr, hnum,
				    dif, sdif);
	if (result)
		return result;
		goto done;

	/* Lookup lhash2 with INADDR_ANY */

@@ -337,9 +337,10 @@ struct sock *__inet_lookup_listener(struct net *net,
	if (ilb2->count > ilb->count)
		goto port_lookup;

	return inet_lhash2_lookup(net, ilb2, skb, doff,
	result = inet_lhash2_lookup(net, ilb2, skb, doff,
				    saddr, sport, daddr, hnum,
				    dif, sdif);
	goto done;

port_lookup:
	sk_for_each_rcu(sk, &ilb->head) {
@@ -352,12 +353,15 @@ struct sock *__inet_lookup_listener(struct net *net,
				result = reuseport_select_sock(sk, phash,
							       skb, doff);
				if (result)
					return result;
					goto done;
			}
			result = sk;
			hiscore = score;
		}
	}
done:
	if (unlikely(IS_ERR(result)))
		return NULL;
	return result;
}
EXPORT_SYMBOL_GPL(__inet_lookup_listener);
Loading