Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 1891d57f authored by Cong Wang's avatar Cong Wang Committed by Greg Kroah-Hartman
Browse files

net_sched: add a temporary refcnt for struct tcindex_data



[ Upstream commit 304e024216a802a7dc8ba75d36de82fa136bbf3e ]

Although we intentionally use an ordered workqueue for all tc
filter works, the ordering is not guaranteed by RCU work,
given that tcf_queue_work() is essentially a call_rcu().

This problem is demonstrated by Thomas:

  CPU 0:
    tcf_queue_work()
      tcf_queue_work(&r->rwork, tcindex_destroy_rexts_work);

  -> Migration to CPU 1

  CPU 1:
     tcf_queue_work(&p->rwork, tcindex_destroy_work);

so the 2nd work could be queued before the 1st one, which leads
to a free-after-free.

Enforcing this order in RCU work is hard, as it requires changing
RCU code too. Fortunately, we can work around this problem in the tcindex
filter by taking a temporary refcnt: we only take the refcnt right before
we begin to destroy it. This simplifies the code a lot, as a full
refcnt requires many more changes in tcindex_set_parms().

Reported-by: <syzbot+46f513c3033d592409d2@syzkaller.appspotmail.com>
Fixes: 3d210534 ("net_sched: fix a race condition in tcindex_destroy()")
Cc: Thomas Gleixner <tglx@linutronix.de>
Cc: Paul E. McKenney <paulmck@kernel.org>
Cc: Jamal Hadi Salim <jhs@mojatatu.com>
Cc: Jiri Pirko <jiri@resnulli.us>
Signed-off-by: Cong Wang <xiyou.wangcong@gmail.com>
Reviewed-by: Paul E. McKenney <paulmck@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
Signed-off-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
parent 1189ba9e
Loading
Loading
Loading
Loading
+38 −6
Original line number Original line Diff line number Diff line
@@ -11,6 +11,7 @@
#include <linux/skbuff.h>
#include <linux/skbuff.h>
#include <linux/errno.h>
#include <linux/errno.h>
#include <linux/slab.h>
#include <linux/slab.h>
#include <linux/refcount.h>
#include <net/act_api.h>
#include <net/act_api.h>
#include <net/netlink.h>
#include <net/netlink.h>
#include <net/pkt_cls.h>
#include <net/pkt_cls.h>
@@ -26,9 +27,12 @@
#define DEFAULT_HASH_SIZE	64	/* optimized for diffserv */
#define DEFAULT_HASH_SIZE	64	/* optimized for diffserv */




struct tcindex_data;

struct tcindex_filter_result {
struct tcindex_filter_result {
	struct tcf_exts		exts;
	struct tcf_exts		exts;
	struct tcf_result	res;
	struct tcf_result	res;
	struct tcindex_data	*p;
	struct rcu_work		rwork;
	struct rcu_work		rwork;
};
};


@@ -49,6 +53,7 @@ struct tcindex_data {
	u32 hash;		/* hash table size; 0 if undefined */
	u32 hash;		/* hash table size; 0 if undefined */
	u32 alloc_hash;		/* allocated size */
	u32 alloc_hash;		/* allocated size */
	u32 fall_through;	/* 0: only classify if explicit match */
	u32 fall_through;	/* 0: only classify if explicit match */
	refcount_t refcnt;	/* a temporary refcnt for perfect hash */
	struct rcu_work rwork;
	struct rcu_work rwork;
};
};


@@ -57,6 +62,20 @@ static inline int tcindex_filter_is_set(struct tcindex_filter_result *r)
	return tcf_exts_has_actions(&r->exts) || r->res.classid;
	return tcf_exts_has_actions(&r->exts) || r->res.classid;
}
}


static void tcindex_data_get(struct tcindex_data *p)
{
	refcount_inc(&p->refcnt);
}

/* Drop one reference on @p.
 *
 * When refcount_dec_and_test() reports that this was the final
 * reference, release p->perfect, p->h and then @p itself. Otherwise
 * nothing is freed and @p remains valid for the remaining holders.
 */
static void tcindex_data_put(struct tcindex_data *p)
{
	/* Other holders remain: nothing to free yet. */
	if (!refcount_dec_and_test(&p->refcnt))
		return;

	kfree(p->perfect);
	kfree(p->h);
	kfree(p);
}

static struct tcindex_filter_result *tcindex_lookup(struct tcindex_data *p,
static struct tcindex_filter_result *tcindex_lookup(struct tcindex_data *p,
						    u16 key)
						    u16 key)
{
{
@@ -141,6 +160,7 @@ static void __tcindex_destroy_rexts(struct tcindex_filter_result *r)
{
{
	tcf_exts_destroy(&r->exts);
	tcf_exts_destroy(&r->exts);
	tcf_exts_put_net(&r->exts);
	tcf_exts_put_net(&r->exts);
	tcindex_data_put(r->p);
}
}


static void tcindex_destroy_rexts_work(struct work_struct *work)
static void tcindex_destroy_rexts_work(struct work_struct *work)
@@ -212,6 +232,8 @@ static int tcindex_delete(struct tcf_proto *tp, void *arg, bool *last,
		else
		else
			__tcindex_destroy_fexts(f);
			__tcindex_destroy_fexts(f);
	} else {
	} else {
		tcindex_data_get(p);

		if (tcf_exts_get_net(&r->exts))
		if (tcf_exts_get_net(&r->exts))
			tcf_queue_work(&r->rwork, tcindex_destroy_rexts_work);
			tcf_queue_work(&r->rwork, tcindex_destroy_rexts_work);
		else
		else
@@ -228,9 +250,7 @@ static void tcindex_destroy_work(struct work_struct *work)
					      struct tcindex_data,
					      struct tcindex_data,
					      rwork);
					      rwork);


	kfree(p->perfect);
	tcindex_data_put(p);
	kfree(p->h);
	kfree(p);
}
}


static inline int
static inline int
@@ -248,9 +268,11 @@ static const struct nla_policy tcindex_policy[TCA_TCINDEX_MAX + 1] = {
};
};


static int tcindex_filter_result_init(struct tcindex_filter_result *r,
static int tcindex_filter_result_init(struct tcindex_filter_result *r,
				      struct tcindex_data *p,
				      struct net *net)
				      struct net *net)
{
{
	memset(r, 0, sizeof(*r));
	memset(r, 0, sizeof(*r));
	r->p = p;
	return tcf_exts_init(&r->exts, net, TCA_TCINDEX_ACT,
	return tcf_exts_init(&r->exts, net, TCA_TCINDEX_ACT,
			     TCA_TCINDEX_POLICE);
			     TCA_TCINDEX_POLICE);
}
}
@@ -290,6 +312,7 @@ static int tcindex_alloc_perfect_hash(struct net *net, struct tcindex_data *cp)
				    TCA_TCINDEX_ACT, TCA_TCINDEX_POLICE);
				    TCA_TCINDEX_ACT, TCA_TCINDEX_POLICE);
		if (err < 0)
		if (err < 0)
			goto errout;
			goto errout;
		cp->perfect[i].p = cp;
	}
	}


	return 0;
	return 0;
@@ -334,6 +357,7 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base,
	cp->alloc_hash = p->alloc_hash;
	cp->alloc_hash = p->alloc_hash;
	cp->fall_through = p->fall_through;
	cp->fall_through = p->fall_through;
	cp->tp = tp;
	cp->tp = tp;
	refcount_set(&cp->refcnt, 1); /* Paired with tcindex_destroy_work() */


	if (tb[TCA_TCINDEX_HASH])
	if (tb[TCA_TCINDEX_HASH])
		cp->hash = nla_get_u32(tb[TCA_TCINDEX_HASH]);
		cp->hash = nla_get_u32(tb[TCA_TCINDEX_HASH]);
@@ -366,7 +390,7 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base,
	}
	}
	cp->h = p->h;
	cp->h = p->h;


	err = tcindex_filter_result_init(&new_filter_result, net);
	err = tcindex_filter_result_init(&new_filter_result, cp, net);
	if (err < 0)
	if (err < 0)
		goto errout_alloc;
		goto errout_alloc;
	if (old_r)
	if (old_r)
@@ -434,7 +458,7 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base,
			goto errout_alloc;
			goto errout_alloc;
		f->key = handle;
		f->key = handle;
		f->next = NULL;
		f->next = NULL;
		err = tcindex_filter_result_init(&f->result, net);
		err = tcindex_filter_result_init(&f->result, cp, net);
		if (err < 0) {
		if (err < 0) {
			kfree(f);
			kfree(f);
			goto errout_alloc;
			goto errout_alloc;
@@ -447,7 +471,7 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base,
	}
	}


	if (old_r && old_r != r) {
	if (old_r && old_r != r) {
		err = tcindex_filter_result_init(old_r, net);
		err = tcindex_filter_result_init(old_r, cp, net);
		if (err < 0) {
		if (err < 0) {
			kfree(f);
			kfree(f);
			goto errout_alloc;
			goto errout_alloc;
@@ -571,6 +595,14 @@ static void tcindex_destroy(struct tcf_proto *tp, bool rtnl_held,
		for (i = 0; i < p->hash; i++) {
		for (i = 0; i < p->hash; i++) {
			struct tcindex_filter_result *r = p->perfect + i;
			struct tcindex_filter_result *r = p->perfect + i;


			/* tcf_queue_work() does not guarantee the ordering we
			 * want, so we have to take this refcnt temporarily to
			 * ensure 'p' is freed after all tcindex_filter_result
			 * here. Imperfect hash does not need this, because it
			 * uses linked lists rather than an array.
			 */
			tcindex_data_get(p);

			tcf_unbind_filter(tp, &r->res);
			tcf_unbind_filter(tp, &r->res);
			if (tcf_exts_get_net(&r->exts))
			if (tcf_exts_get_net(&r->exts))
				tcf_queue_work(&r->rwork,
				tcf_queue_work(&r->rwork,