Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 6539c4f9 authored by Guillaume Nault's avatar Guillaume Nault Committed by Greg Kroah-Hartman
Browse files

l2tp: fix race in l2tp_recv_common()



commit 61b9a047729bb230978178bca6729689d0c50ca2 upstream.

Taking a reference on sessions in l2tp_recv_common() is racy; this
has to be done by the callers.

To this end, a new function is required (l2tp_session_get()) to
atomically lookup a session and take a reference on it. Callers then
have to manually drop this reference.

Fixes: fd558d18 ("l2tp: Split pppol2tp patch into separate l2tp and ppp parts")
Signed-off-by: default avatarGuillaume Nault <g.nault@alphalink.fr>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
Signed-off-by: default avatarAmit Pundir <amit.pundir@linaro.org>
Signed-off-by: default avatarGreg Kroah-Hartman <gregkh@linuxfoundation.org>
parent d2da8d39
Loading
Loading
Loading
Loading
+60 −13
Original line number Diff line number Diff line
@@ -278,6 +278,55 @@ struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunn
}
EXPORT_SYMBOL_GPL(l2tp_session_find);

/* Like l2tp_session_find() but takes a reference on the returned session.
 * Optionally calls session->ref() too if do_ref is true.
 */
struct l2tp_session *l2tp_session_get(struct net *net,
				      struct l2tp_tunnel *tunnel,
				      u32 session_id, bool do_ref)
{
	struct hlist_head *session_list;
	struct l2tp_session *session;

	if (!tunnel) {
		struct l2tp_net *pn = l2tp_pernet(net);

		session_list = l2tp_session_id_hash_2(pn, session_id);

		rcu_read_lock_bh();
		hlist_for_each_entry_rcu(session, session_list, global_hlist) {
			if (session->session_id == session_id) {
				l2tp_session_inc_refcount(session);
				if (do_ref && session->ref)
					session->ref(session);
				rcu_read_unlock_bh();

				return session;
			}
		}
		rcu_read_unlock_bh();

		return NULL;
	}

	session_list = l2tp_session_id_hash(tunnel, session_id);
	read_lock_bh(&tunnel->hlist_lock);
	hlist_for_each_entry(session, session_list, hlist) {
		if (session->session_id == session_id) {
			l2tp_session_inc_refcount(session);
			if (do_ref && session->ref)
				session->ref(session);
			read_unlock_bh(&tunnel->hlist_lock);

			return session;
		}
	}
	read_unlock_bh(&tunnel->hlist_lock);

	return NULL;
}
EXPORT_SYMBOL_GPL(l2tp_session_get);

struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth,
					  bool do_ref)
{
@@ -637,6 +686,9 @@ static int l2tp_recv_data_seq(struct l2tp_session *session, struct sk_buff *skb)
 * a data (not control) frame before coming here. Fields up to the
 * session-id have already been parsed and ptr points to the data
 * after the session-id.
 *
 * session->ref() must have been called prior to l2tp_recv_common().
 * session->deref() will be called automatically after skb is processed.
 */
void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
		      unsigned char *ptr, unsigned char *optr, u16 hdrflags,
@@ -646,14 +698,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
	int offset;
	u32 ns, nr;

	/* The ref count is increased since we now hold a pointer to
	 * the session. Take care to decrement the refcnt when exiting
	 * this function from now on...
	 */
	l2tp_session_inc_refcount(session);
	if (session->ref)
		(*session->ref)(session);

	/* Parse and check optional cookie */
	if (session->peer_cookie_len > 0) {
		if (memcmp(ptr, &session->peer_cookie[0], session->peer_cookie_len)) {
@@ -806,8 +850,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
	/* Try to dequeue as many skbs from reorder_q as we can. */
	l2tp_recv_dequeue(session);

	l2tp_session_dec_refcount(session);

	return;

discard:
@@ -816,8 +858,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,

	if (session->deref)
		(*session->deref)(session);

	l2tp_session_dec_refcount(session);
}
EXPORT_SYMBOL(l2tp_recv_common);

@@ -924,8 +964,14 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb,
	}

	/* Find the session context */
	session = l2tp_session_find(tunnel->l2tp_net, tunnel, session_id);
	session = l2tp_session_get(tunnel->l2tp_net, tunnel, session_id, true);
	if (!session || !session->recv_skb) {
		if (session) {
			if (session->deref)
				session->deref(session);
			l2tp_session_dec_refcount(session);
		}

		/* Not found? Pass to userspace to deal with */
		l2tp_info(tunnel, L2TP_MSG_DATA,
			  "%s: no session found (%u/%u). Passing up.\n",
@@ -934,6 +980,7 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb,
	}

	l2tp_recv_common(session, skb, ptr, optr, hdrflags, length, payload_hook);
	l2tp_session_dec_refcount(session);

	return 0;

+3 −0
Original line number Diff line number Diff line
@@ -240,6 +240,9 @@ static inline struct l2tp_tunnel *l2tp_sock_to_tunnel(struct sock *sk)
	return tunnel;
}

struct l2tp_session *l2tp_session_get(struct net *net,
				      struct l2tp_tunnel *tunnel,
				      u32 session_id, bool do_ref);
struct l2tp_session *l2tp_session_find(struct net *net,
				       struct l2tp_tunnel *tunnel,
				       u32 session_id);
+12 −5
Original line number Diff line number Diff line
@@ -143,19 +143,19 @@ static int l2tp_ip_recv(struct sk_buff *skb)
	}

	/* Ok, this is a data packet. Lookup the session. */
	session = l2tp_session_find(net, NULL, session_id);
	if (session == NULL)
	session = l2tp_session_get(net, NULL, session_id, true);
	if (!session)
		goto discard;

	tunnel = session->tunnel;
	if (tunnel == NULL)
		goto discard;
	if (!tunnel)
		goto discard_sess;

	/* Trace packet contents, if enabled */
	if (tunnel->debug & L2TP_MSG_DATA) {
		length = min(32u, skb->len);
		if (!pskb_may_pull(skb, length))
			goto discard;
			goto discard_sess;

		/* Point to L2TP header */
		optr = ptr = skb->data;
@@ -165,6 +165,7 @@ static int l2tp_ip_recv(struct sk_buff *skb)
	}

	l2tp_recv_common(session, skb, ptr, optr, 0, skb->len, tunnel->recv_payload_hook);
	l2tp_session_dec_refcount(session);

	return 0;

@@ -203,6 +204,12 @@ static int l2tp_ip_recv(struct sk_buff *skb)

	return sk_receive_skb(sk, skb, 1);

discard_sess:
	if (session->deref)
		session->deref(session);
	l2tp_session_dec_refcount(session);
	goto discard;

discard_put:
	sock_put(sk);

+13 −5
Original line number Diff line number Diff line
@@ -156,19 +156,19 @@ static int l2tp_ip6_recv(struct sk_buff *skb)
	}

	/* Ok, this is a data packet. Lookup the session. */
	session = l2tp_session_find(net, NULL, session_id);
	if (session == NULL)
	session = l2tp_session_get(net, NULL, session_id, true);
	if (!session)
		goto discard;

	tunnel = session->tunnel;
	if (tunnel == NULL)
		goto discard;
	if (!tunnel)
		goto discard_sess;

	/* Trace packet contents, if enabled */
	if (tunnel->debug & L2TP_MSG_DATA) {
		length = min(32u, skb->len);
		if (!pskb_may_pull(skb, length))
			goto discard;
			goto discard_sess;

		/* Point to L2TP header */
		optr = ptr = skb->data;
@@ -179,6 +179,8 @@ static int l2tp_ip6_recv(struct sk_buff *skb)

	l2tp_recv_common(session, skb, ptr, optr, 0, skb->len,
			 tunnel->recv_payload_hook);
	l2tp_session_dec_refcount(session);

	return 0;

pass_up:
@@ -216,6 +218,12 @@ static int l2tp_ip6_recv(struct sk_buff *skb)

	return sk_receive_skb(sk, skb, 1);

discard_sess:
	if (session->deref)
		session->deref(session);
	l2tp_session_dec_refcount(session);
	goto discard;

discard_put:
	sock_put(sk);