Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit c5042dac authored by David S. Miller's avatar David S. Miller
Browse files

Merge branch 'l2tp-remove-unsafe-calls-to-l2tp_tunnel_find_nth'



Guillaume Nault says:

====================
l2tp: remove unsafe calls to l2tp_tunnel_find_nth()

Using l2tp_tunnel_find_nth() is racy, because the returned tunnel can
go away as soon as this function returns. This series introduce
l2tp_tunnel_get_nth() as a safe replacement to fixes these races.

With this series, all unsafe tunnel/session lookups are finally gone.
====================

Acked-by: default avatarJason Wang <jasowang@redhat.com>
Acked-by: default avatarMichael S. Tsirkin <mst@redhat.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 9267c430 f726214d
Loading
Loading
Loading
Loading
+20 −20
Original line number Diff line number Diff line
@@ -183,6 +183,26 @@ struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id)
}
EXPORT_SYMBOL_GPL(l2tp_tunnel_get);

struct l2tp_tunnel *l2tp_tunnel_get_nth(const struct net *net, int nth)
{
	const struct l2tp_net *pn = l2tp_pernet(net);
	struct l2tp_tunnel *tunnel;
	int count = 0;

	rcu_read_lock_bh();
	list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
		if (++count > nth) {
			l2tp_tunnel_inc_refcount(tunnel);
			rcu_read_unlock_bh();
			return tunnel;
		}
	}
	rcu_read_unlock_bh();

	return NULL;
}
EXPORT_SYMBOL_GPL(l2tp_tunnel_get_nth);

/* Lookup a session. A new reference is held on the returned session. */
struct l2tp_session *l2tp_session_get(const struct net *net,
				      struct l2tp_tunnel *tunnel,
@@ -335,26 +355,6 @@ int l2tp_session_register(struct l2tp_session *session,
}
EXPORT_SYMBOL_GPL(l2tp_session_register);

struct l2tp_tunnel *l2tp_tunnel_find_nth(const struct net *net, int nth)
{
	struct l2tp_net *pn = l2tp_pernet(net);
	struct l2tp_tunnel *tunnel;
	int count = 0;

	rcu_read_lock_bh();
	list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
		if (++count > nth) {
			rcu_read_unlock_bh();
			return tunnel;
		}
	}

	rcu_read_unlock_bh();

	return NULL;
}
EXPORT_SYMBOL_GPL(l2tp_tunnel_find_nth);

/*****************************************************************************
 * Receive data handling
 *****************************************************************************/
+2 −1
Original line number Diff line number Diff line
@@ -212,6 +212,8 @@ static inline void *l2tp_session_priv(struct l2tp_session *session)
}

struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id);
struct l2tp_tunnel *l2tp_tunnel_get_nth(const struct net *net, int nth);

void l2tp_tunnel_free(struct l2tp_tunnel *tunnel);

struct l2tp_session *l2tp_session_get(const struct net *net,
@@ -220,7 +222,6 @@ struct l2tp_session *l2tp_session_get(const struct net *net,
struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth);
struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net,
						const char *ifname);
struct l2tp_tunnel *l2tp_tunnel_find_nth(const struct net *net, int nth);

int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id,
		       u32 peer_tunnel_id, struct l2tp_tunnel_cfg *cfg,
+13 −2
Original line number Diff line number Diff line
@@ -47,7 +47,11 @@ struct l2tp_dfs_seq_data {

static void l2tp_dfs_next_tunnel(struct l2tp_dfs_seq_data *pd)
{
	pd->tunnel = l2tp_tunnel_find_nth(pd->net, pd->tunnel_idx);
	/* Drop reference taken during previous invocation */
	if (pd->tunnel)
		l2tp_tunnel_dec_refcount(pd->tunnel);

	pd->tunnel = l2tp_tunnel_get_nth(pd->net, pd->tunnel_idx);
	pd->tunnel_idx++;
}

@@ -96,7 +100,14 @@ static void *l2tp_dfs_seq_next(struct seq_file *m, void *v, loff_t *pos)

static void l2tp_dfs_seq_stop(struct seq_file *p, void *v)
{
	/* nothing to do */
	struct l2tp_dfs_seq_data *pd = v;

	if (!pd || pd == SEQ_START_TOKEN)
		return;

	/* Drop reference taken by last invocation of l2tp_dfs_next_tunnel() */
	if (pd->tunnel)
		l2tp_tunnel_dec_refcount(pd->tunnel);
}

static void l2tp_dfs_seq_tunnel_show(struct seq_file *m, void *v)
+8 −3
Original line number Diff line number Diff line
@@ -487,14 +487,17 @@ static int l2tp_nl_cmd_tunnel_dump(struct sk_buff *skb, struct netlink_callback
	struct net *net = sock_net(skb->sk);

	for (;;) {
		tunnel = l2tp_tunnel_find_nth(net, ti);
		tunnel = l2tp_tunnel_get_nth(net, ti);
		if (tunnel == NULL)
			goto out;

		if (l2tp_nl_tunnel_send(skb, NETLINK_CB(cb->skb).portid,
					cb->nlh->nlmsg_seq, NLM_F_MULTI,
					tunnel, L2TP_CMD_TUNNEL_GET) < 0)
					tunnel, L2TP_CMD_TUNNEL_GET) < 0) {
			l2tp_tunnel_dec_refcount(tunnel);
			goto out;
		}
		l2tp_tunnel_dec_refcount(tunnel);

		ti++;
	}
@@ -848,7 +851,7 @@ static int l2tp_nl_cmd_session_dump(struct sk_buff *skb, struct netlink_callback

	for (;;) {
		if (tunnel == NULL) {
			tunnel = l2tp_tunnel_find_nth(net, ti);
			tunnel = l2tp_tunnel_get_nth(net, ti);
			if (tunnel == NULL)
				goto out;
		}
@@ -856,6 +859,7 @@ static int l2tp_nl_cmd_session_dump(struct sk_buff *skb, struct netlink_callback
		session = l2tp_session_get_nth(tunnel, si);
		if (session == NULL) {
			ti++;
			l2tp_tunnel_dec_refcount(tunnel);
			tunnel = NULL;
			si = 0;
			continue;
@@ -865,6 +869,7 @@ static int l2tp_nl_cmd_session_dump(struct sk_buff *skb, struct netlink_callback
					 cb->nlh->nlmsg_seq, NLM_F_MULTI,
					 session, L2TP_CMD_SESSION_GET) < 0) {
			l2tp_session_dec_refcount(session);
			l2tp_tunnel_dec_refcount(tunnel);
			break;
		}
		l2tp_session_dec_refcount(session);
+17 −7
Original line number Diff line number Diff line
@@ -1551,16 +1551,19 @@ struct pppol2tp_seq_data {

static void pppol2tp_next_tunnel(struct net *net, struct pppol2tp_seq_data *pd)
{
	/* Drop reference taken during previous invocation */
	if (pd->tunnel)
		l2tp_tunnel_dec_refcount(pd->tunnel);

	for (;;) {
		pd->tunnel = l2tp_tunnel_find_nth(net, pd->tunnel_idx);
		pd->tunnel = l2tp_tunnel_get_nth(net, pd->tunnel_idx);
		pd->tunnel_idx++;

		if (pd->tunnel == NULL)
			break;
		/* Only accept L2TPv2 tunnels */
		if (!pd->tunnel || pd->tunnel->version == 2)
			return;

		/* Ignore L2TPv3 tunnels */
		if (pd->tunnel->version < 3)
			break;
		l2tp_tunnel_dec_refcount(pd->tunnel);
	}
}

@@ -1609,7 +1612,14 @@ static void *pppol2tp_seq_next(struct seq_file *m, void *v, loff_t *pos)

static void pppol2tp_seq_stop(struct seq_file *p, void *v)
{
	/* nothing to do */
	struct pppol2tp_seq_data *pd = v;

	if (!pd || pd == SEQ_START_TOKEN)
		return;

	/* Drop reference taken by last invocation of pppol2tp_next_tunnel() */
	if (pd->tunnel)
		l2tp_tunnel_dec_refcount(pd->tunnel);
}

static void pppol2tp_seq_tunnel_show(struct seq_file *m, void *v)