Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 41c6d650 authored by Reshetova, Elena's avatar Reshetova, Elena Committed by David S. Miller
Browse files

net: convert sock.sk_refcnt from atomic_t to refcount_t



refcount_t type and corresponding API should be
used instead of atomic_t when the variable is used as
a reference counter. This allows to avoid accidental
refcounter overflows that might lead to use-after-free
situations.

This patch uses refcount_inc_not_zero() instead of
atomic_inc_not_zero_hint() due to absense of a _hint()
version of refcount API. If the hint() version must
be used, we might need to revisit API.

Signed-off-by: default avatarElena Reshetova <elena.reshetova@intel.com>
Signed-off-by: default avatarHans Liljestrand <ishkamiel@gmail.com>
Signed-off-by: default avatarKees Cook <keescook@chromium.org>
Signed-off-by: default avatarDavid Windsor <dwindsor@gmail.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 14afee4b
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -877,7 +877,7 @@ static void aead_sock_destruct(struct sock *sk)
	unsigned int ivlen = crypto_aead_ivsize(
				crypto_aead_reqtfm(&ctx->aead_req));

	WARN_ON(atomic_read(&sk->sk_refcnt) != 0);
	WARN_ON(refcount_read(&sk->sk_refcnt) != 0);
	aead_put_sgl(sk);
	sock_kzfree_s(sk, ctx->iv, ivlen);
	sock_kfree_s(sk, ctx, ctx->len);
+2 −2
Original line number Diff line number Diff line
@@ -32,7 +32,7 @@
#include <net/tcp_states.h>
#include <net/netns/hash.h>

#include <linux/atomic.h>
#include <linux/refcount.h>
#include <asm/byteorder.h>

/* This is for all connections with a full identity, no wildcards.
@@ -334,7 +334,7 @@ static inline struct sock *inet_lookup(struct net *net,
	sk = __inet_lookup(net, hashinfo, skb, doff, saddr, sport, daddr,
			   dport, dif, &refcounted);

	if (sk && !refcounted && !atomic_inc_not_zero(&sk->sk_refcnt))
	if (sk && !refcounted && !refcount_inc_not_zero(&sk->sk_refcnt))
		sk = NULL;
	return sk;
}
+5 −4
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@
#include <linux/spinlock.h>
#include <linux/types.h>
#include <linux/bug.h>
#include <linux/refcount.h>

#include <net/sock.h>

@@ -89,7 +90,7 @@ reqsk_alloc(const struct request_sock_ops *ops, struct sock *sk_listener,
		return NULL;
	req->rsk_listener = NULL;
	if (attach_listener) {
		if (unlikely(!atomic_inc_not_zero(&sk_listener->sk_refcnt))) {
		if (unlikely(!refcount_inc_not_zero(&sk_listener->sk_refcnt))) {
			kmem_cache_free(ops->slab, req);
			return NULL;
		}
@@ -100,7 +101,7 @@ reqsk_alloc(const struct request_sock_ops *ops, struct sock *sk_listener,
	sk_node_init(&req_to_sk(req)->sk_node);
	sk_tx_queue_clear(req_to_sk(req));
	req->saved_syn = NULL;
	atomic_set(&req->rsk_refcnt, 0);
	refcount_set(&req->rsk_refcnt, 0);

	return req;
}
@@ -108,7 +109,7 @@ reqsk_alloc(const struct request_sock_ops *ops, struct sock *sk_listener,
static inline void reqsk_free(struct request_sock *req)
{
	/* temporary debugging */
	WARN_ON_ONCE(atomic_read(&req->rsk_refcnt) != 0);
	WARN_ON_ONCE(refcount_read(&req->rsk_refcnt) != 0);

	req->rsk_ops->destructor(req);
	if (req->rsk_listener)
@@ -119,7 +120,7 @@ static inline void reqsk_free(struct request_sock *req)

static inline void reqsk_put(struct request_sock *req)
{
	if (atomic_dec_and_test(&req->rsk_refcnt))
	if (refcount_dec_and_test(&req->rsk_refcnt))
		reqsk_free(req);
}

+9 −8
Original line number Diff line number Diff line
@@ -66,6 +66,7 @@
#include <linux/poll.h>

#include <linux/atomic.h>
#include <linux/refcount.h>
#include <net/dst.h>
#include <net/checksum.h>
#include <net/tcp_states.h>
@@ -219,7 +220,7 @@ struct sock_common {
		u32		skc_tw_rcv_nxt; /* struct tcp_timewait_sock  */
	};

	atomic_t		skc_refcnt;
	refcount_t		skc_refcnt;
	/* private: */
	int                     skc_dontcopy_end[0];
	union {
@@ -611,7 +612,7 @@ static inline bool __sk_del_node_init(struct sock *sk)

static __always_inline void sock_hold(struct sock *sk)
{
	atomic_inc(&sk->sk_refcnt);
	refcount_inc(&sk->sk_refcnt);
}

/* Ungrab socket in the context, which assumes that socket refcnt
@@ -619,7 +620,7 @@ static __always_inline void sock_hold(struct sock *sk)
 */
static __always_inline void __sock_put(struct sock *sk)
{
	atomic_dec(&sk->sk_refcnt);
	refcount_dec(&sk->sk_refcnt);
}

static inline bool sk_del_node_init(struct sock *sk)
@@ -628,7 +629,7 @@ static inline bool sk_del_node_init(struct sock *sk)

	if (rc) {
		/* paranoid for a while -acme */
		WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
		WARN_ON(refcount_read(&sk->sk_refcnt) == 1);
		__sock_put(sk);
	}
	return rc;
@@ -650,7 +651,7 @@ static inline bool sk_nulls_del_node_init_rcu(struct sock *sk)

	if (rc) {
		/* paranoid for a while -acme */
		WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
		WARN_ON(refcount_read(&sk->sk_refcnt) == 1);
		__sock_put(sk);
	}
	return rc;
@@ -1144,9 +1145,9 @@ static inline void sk_refcnt_debug_dec(struct sock *sk)

static inline void sk_refcnt_debug_release(const struct sock *sk)
{
	if (atomic_read(&sk->sk_refcnt) != 1)
	if (refcount_read(&sk->sk_refcnt) != 1)
		printk(KERN_DEBUG "Destruction of the %s socket %p delayed, refcnt=%d\n",
		       sk->sk_prot->name, sk, atomic_read(&sk->sk_refcnt));
		       sk->sk_prot->name, sk, refcount_read(&sk->sk_refcnt));
}
#else /* SOCK_REFCNT_DEBUG */
#define sk_refcnt_debug_inc(sk) do { } while (0)
@@ -1636,7 +1637,7 @@ void sock_init_data(struct socket *sock, struct sock *sk);
/* Ungrab socket and destroy it, if it was the last reference. */
static inline void sock_put(struct sock *sk)
{
	if (atomic_dec_and_test(&sk->sk_refcnt))
	if (refcount_dec_and_test(&sk->sk_refcnt))
		sk_free(sk);
}
/* Generic version of sock_put(), dealing with all sockets
+1 −1
Original line number Diff line number Diff line
@@ -211,7 +211,7 @@ static void vcc_info(struct seq_file *seq, struct atm_vcc *vcc)
		   vcc->flags, sk->sk_err,
		   sk_wmem_alloc_get(sk), sk->sk_sndbuf,
		   sk_rmem_alloc_get(sk), sk->sk_rcvbuf,
		   atomic_read(&sk->sk_refcnt));
		   refcount_read(&sk->sk_refcnt));
}

static void svc_info(struct seq_file *seq, struct atm_vcc *vcc)
Loading