Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e5c1e519 authored by David S. Miller's avatar David S. Miller
Browse files

Merge branch 'l2tp_session_find-fixes'



Guillaume Nault says:

====================
l2tp: fix usage of l2tp_session_find()

l2tp_session_find() doesn't take a reference on the session returned to
its caller. Virtually all l2tp_session_find() users are racy, either
because the session can disappear from under them or because they take
a reference too late. This leads to bugs like 'use after free' or
failure to notice duplicate session creations.

In some cases, taking a reference on the session is not enough. The
special callbacks .ref() and .deref() also have to be called in cases
where the PPP pseudo-wire uses the socket associated with the session.
Therefore, when looking up a session, we also have to pass a flag
indicating if the .ref() callback has to be called.

In the future, we probably could drop the .ref() and .deref() callbacks
entirely by protecting the .sock field of struct pppol2tp_session with
RCU, thus allowing it to be freed and set to NULL even if the L2TP
session is still alive.
====================

Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents afe89962 2777e2ab
Loading
Loading
Loading
Loading
+120 −32
Original line number Diff line number Diff line
@@ -278,6 +278,55 @@ struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunn
}
EXPORT_SYMBOL_GPL(l2tp_session_find);

/* Like l2tp_session_find() but takes a reference on the returned session.
 * Optionally calls session->ref() too if do_ref is true.
 */
struct l2tp_session *l2tp_session_get(struct net *net,
				      struct l2tp_tunnel *tunnel,
				      u32 session_id, bool do_ref)
{
	struct hlist_head *session_list;
	struct l2tp_session *session;

	if (!tunnel) {
		struct l2tp_net *pn = l2tp_pernet(net);

		session_list = l2tp_session_id_hash_2(pn, session_id);

		rcu_read_lock_bh();
		hlist_for_each_entry_rcu(session, session_list, global_hlist) {
			if (session->session_id == session_id) {
				l2tp_session_inc_refcount(session);
				if (do_ref && session->ref)
					session->ref(session);
				rcu_read_unlock_bh();

				return session;
			}
		}
		rcu_read_unlock_bh();

		return NULL;
	}

	session_list = l2tp_session_id_hash(tunnel, session_id);
	read_lock_bh(&tunnel->hlist_lock);
	hlist_for_each_entry(session, session_list, hlist) {
		if (session->session_id == session_id) {
			l2tp_session_inc_refcount(session);
			if (do_ref && session->ref)
				session->ref(session);
			read_unlock_bh(&tunnel->hlist_lock);

			return session;
		}
	}
	read_unlock_bh(&tunnel->hlist_lock);

	return NULL;
}
EXPORT_SYMBOL_GPL(l2tp_session_get);

struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth)
{
	int hash;
@@ -303,7 +352,8 @@ EXPORT_SYMBOL_GPL(l2tp_session_find_nth);
/* Lookup a session by interface name.
 * This is very inefficient but is only used by management interfaces.
 */
struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname)
struct l2tp_session *l2tp_session_get_by_ifname(struct net *net, char *ifname,
						bool do_ref)
{
	struct l2tp_net *pn = l2tp_pernet(net);
	int hash;
@@ -313,7 +363,11 @@ struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname)
	for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) {
		hlist_for_each_entry_rcu(session, &pn->l2tp_session_hlist[hash], global_hlist) {
			if (!strcmp(session->ifname, ifname)) {
				l2tp_session_inc_refcount(session);
				if (do_ref && session->ref)
					session->ref(session);
				rcu_read_unlock_bh();

				return session;
			}
		}
@@ -323,7 +377,49 @@ struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname)

	return NULL;
}
EXPORT_SYMBOL_GPL(l2tp_session_find_by_ifname);
EXPORT_SYMBOL_GPL(l2tp_session_get_by_ifname);

static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel,
				      struct l2tp_session *session)
{
	struct l2tp_session *session_walk;
	struct hlist_head *g_head;
	struct hlist_head *head;
	struct l2tp_net *pn;

	head = l2tp_session_id_hash(tunnel, session->session_id);

	write_lock_bh(&tunnel->hlist_lock);
	hlist_for_each_entry(session_walk, head, hlist)
		if (session_walk->session_id == session->session_id)
			goto exist;

	if (tunnel->version == L2TP_HDR_VER_3) {
		pn = l2tp_pernet(tunnel->l2tp_net);
		g_head = l2tp_session_id_hash_2(l2tp_pernet(tunnel->l2tp_net),
						session->session_id);

		spin_lock_bh(&pn->l2tp_session_hlist_lock);
		hlist_for_each_entry(session_walk, g_head, global_hlist)
			if (session_walk->session_id == session->session_id)
				goto exist_glob;

		hlist_add_head_rcu(&session->global_hlist, g_head);
		spin_unlock_bh(&pn->l2tp_session_hlist_lock);
	}

	hlist_add_head(&session->hlist, head);
	write_unlock_bh(&tunnel->hlist_lock);

	return 0;

exist_glob:
	spin_unlock_bh(&pn->l2tp_session_hlist_lock);
exist:
	write_unlock_bh(&tunnel->hlist_lock);

	return -EEXIST;
}

/* Lookup a tunnel by id
 */
@@ -633,6 +729,9 @@ static int l2tp_recv_data_seq(struct l2tp_session *session, struct sk_buff *skb)
 * a data (not control) frame before coming here. Fields up to the
 * session-id have already been parsed and ptr points to the data
 * after the session-id.
 *
 * session->ref() must have been called prior to l2tp_recv_common().
 * session->deref() will be called automatically after skb is processed.
 */
void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
		      unsigned char *ptr, unsigned char *optr, u16 hdrflags,
@@ -642,14 +741,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
	int offset;
	u32 ns, nr;

	/* The ref count is increased since we now hold a pointer to
	 * the session. Take care to decrement the refcnt when exiting
	 * this function from now on...
	 */
	l2tp_session_inc_refcount(session);
	if (session->ref)
		(*session->ref)(session);

	/* Parse and check optional cookie */
	if (session->peer_cookie_len > 0) {
		if (memcmp(ptr, &session->peer_cookie[0], session->peer_cookie_len)) {
@@ -802,8 +893,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
	/* Try to dequeue as many skbs from reorder_q as we can. */
	l2tp_recv_dequeue(session);

	l2tp_session_dec_refcount(session);

	return;

discard:
@@ -812,8 +901,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,

	if (session->deref)
		(*session->deref)(session);

	l2tp_session_dec_refcount(session);
}
EXPORT_SYMBOL(l2tp_recv_common);

@@ -920,8 +1007,14 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb,
	}

	/* Find the session context */
	session = l2tp_session_find(tunnel->l2tp_net, tunnel, session_id);
	session = l2tp_session_get(tunnel->l2tp_net, tunnel, session_id, true);
	if (!session || !session->recv_skb) {
		if (session) {
			if (session->deref)
				session->deref(session);
			l2tp_session_dec_refcount(session);
		}

		/* Not found? Pass to userspace to deal with */
		l2tp_info(tunnel, L2TP_MSG_DATA,
			  "%s: no session found (%u/%u). Passing up.\n",
@@ -930,6 +1023,7 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb,
	}

	l2tp_recv_common(session, skb, ptr, optr, hdrflags, length, payload_hook);
	l2tp_session_dec_refcount(session);

	return 0;

@@ -1738,6 +1832,7 @@ EXPORT_SYMBOL_GPL(l2tp_session_set_header_len);
struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunnel, u32 session_id, u32 peer_session_id, struct l2tp_session_cfg *cfg)
{
	struct l2tp_session *session;
	int err;

	session = kzalloc(sizeof(struct l2tp_session) + priv_size, GFP_KERNEL);
	if (session != NULL) {
@@ -1793,6 +1888,13 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn

		l2tp_session_set_header_len(session, tunnel->version);

		err = l2tp_session_add_to_tunnel(tunnel, session);
		if (err) {
			kfree(session);

			return ERR_PTR(err);
		}

		/* Bump the reference count. The session context is deleted
		 * only when this drops to zero.
		 */
@@ -1802,29 +1904,15 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn
		/* Ensure tunnel socket isn't deleted */
		sock_hold(tunnel->sock);

		/* Add session to the tunnel's hash list */
		write_lock_bh(&tunnel->hlist_lock);
		hlist_add_head(&session->hlist,
			       l2tp_session_id_hash(tunnel, session_id));
		write_unlock_bh(&tunnel->hlist_lock);

		/* And to the global session list if L2TPv3 */
		if (tunnel->version != L2TP_HDR_VER_2) {
			struct l2tp_net *pn = l2tp_pernet(tunnel->l2tp_net);

			spin_lock_bh(&pn->l2tp_session_hlist_lock);
			hlist_add_head_rcu(&session->global_hlist,
					   l2tp_session_id_hash_2(pn, session_id));
			spin_unlock_bh(&pn->l2tp_session_hlist_lock);
		}

		/* Ignore management session in session count value */
		if (session->session_id != 0)
			atomic_inc(&l2tp_session_count);
	}

		return session;
	}

	return ERR_PTR(-ENOMEM);
}
EXPORT_SYMBOL_GPL(l2tp_session_create);

/*****************************************************************************
+5 −1
Original line number Diff line number Diff line
@@ -230,11 +230,15 @@ static inline struct l2tp_tunnel *l2tp_sock_to_tunnel(struct sock *sk)
	return tunnel;
}

struct l2tp_session *l2tp_session_get(struct net *net,
				      struct l2tp_tunnel *tunnel,
				      u32 session_id, bool do_ref);
struct l2tp_session *l2tp_session_find(struct net *net,
				       struct l2tp_tunnel *tunnel,
				       u32 session_id);
struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth);
struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname);
struct l2tp_session *l2tp_session_get_by_ifname(struct net *net, char *ifname,
						bool do_ref);
struct l2tp_tunnel *l2tp_tunnel_find(struct net *net, u32 tunnel_id);
struct l2tp_tunnel *l2tp_tunnel_find_nth(struct net *net, int nth);

+2 −8
Original line number Diff line number Diff line
@@ -221,12 +221,6 @@ static int l2tp_eth_create(struct net *net, u32 tunnel_id, u32 session_id, u32 p
		goto out;
	}

	session = l2tp_session_find(net, tunnel, session_id);
	if (session) {
		rc = -EEXIST;
		goto out;
	}

	if (cfg->ifname) {
		dev = dev_get_by_name(net, cfg->ifname);
		if (dev) {
@@ -240,8 +234,8 @@ static int l2tp_eth_create(struct net *net, u32 tunnel_id, u32 session_id, u32 p

	session = l2tp_session_create(sizeof(*spriv), tunnel, session_id,
				      peer_session_id, cfg);
	if (!session) {
		rc = -ENOMEM;
	if (IS_ERR(session)) {
		rc = PTR_ERR(session);
		goto out;
	}

+12 −5
Original line number Diff line number Diff line
@@ -143,19 +143,19 @@ static int l2tp_ip_recv(struct sk_buff *skb)
	}

	/* Ok, this is a data packet. Lookup the session. */
	session = l2tp_session_find(net, NULL, session_id);
	if (session == NULL)
	session = l2tp_session_get(net, NULL, session_id, true);
	if (!session)
		goto discard;

	tunnel = session->tunnel;
	if (tunnel == NULL)
		goto discard;
	if (!tunnel)
		goto discard_sess;

	/* Trace packet contents, if enabled */
	if (tunnel->debug & L2TP_MSG_DATA) {
		length = min(32u, skb->len);
		if (!pskb_may_pull(skb, length))
			goto discard;
			goto discard_sess;

		/* Point to L2TP header */
		optr = ptr = skb->data;
@@ -165,6 +165,7 @@ static int l2tp_ip_recv(struct sk_buff *skb)
	}

	l2tp_recv_common(session, skb, ptr, optr, 0, skb->len, tunnel->recv_payload_hook);
	l2tp_session_dec_refcount(session);

	return 0;

@@ -203,6 +204,12 @@ static int l2tp_ip_recv(struct sk_buff *skb)

	return sk_receive_skb(sk, skb, 1);

discard_sess:
	if (session->deref)
		session->deref(session);
	l2tp_session_dec_refcount(session);
	goto discard;

discard_put:
	sock_put(sk);

+13 −5
Original line number Diff line number Diff line
@@ -156,19 +156,19 @@ static int l2tp_ip6_recv(struct sk_buff *skb)
	}

	/* Ok, this is a data packet. Lookup the session. */
	session = l2tp_session_find(net, NULL, session_id);
	if (session == NULL)
	session = l2tp_session_get(net, NULL, session_id, true);
	if (!session)
		goto discard;

	tunnel = session->tunnel;
	if (tunnel == NULL)
		goto discard;
	if (!tunnel)
		goto discard_sess;

	/* Trace packet contents, if enabled */
	if (tunnel->debug & L2TP_MSG_DATA) {
		length = min(32u, skb->len);
		if (!pskb_may_pull(skb, length))
			goto discard;
			goto discard_sess;

		/* Point to L2TP header */
		optr = ptr = skb->data;
@@ -179,6 +179,8 @@ static int l2tp_ip6_recv(struct sk_buff *skb)

	l2tp_recv_common(session, skb, ptr, optr, 0, skb->len,
			 tunnel->recv_payload_hook);
	l2tp_session_dec_refcount(session);

	return 0;

pass_up:
@@ -216,6 +218,12 @@ static int l2tp_ip6_recv(struct sk_buff *skb)

	return sk_receive_skb(sk, skb, 1);

discard_sess:
	if (session->deref)
		session->deref(session);
	l2tp_session_dec_refcount(session);
	goto discard;

discard_put:
	sock_put(sk);

Loading