Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 48d0e023 authored by Paul Moore's avatar Paul Moore
Browse files

audit: fix the RCU locking for the auditd_connection structure



Cong Wang correctly pointed out that the RCU read locking of the
auditd_connection struct was wrong, this patch correct this by
adopting a more traditional, and correct RCU locking model.

This patch is heavily based on an earlier prototype by Cong Wang.

Cc: <stable@vger.kernel.org> # 4.11.x-
Reported-by: default avatarCong Wang <xiyou.wangcong@gmail.com>
Signed-off-by: default avatarCong Wang <xiyou.wangcong@gmail.com>
Signed-off-by: default avatarPaul Moore <paul@paul-moore.com>
parent 8cc96382
Loading
Loading
Loading
Loading
+100 −57
Original line number Diff line number Diff line
@@ -112,18 +112,19 @@ struct audit_net {
 * @pid: auditd PID
 * @portid: netlink portid
 * @net: the associated network namespace
 * @lock: spinlock to protect write access
 * @rcu: RCU head
 *
 * Description:
 * This struct is RCU protected; you must either hold the RCU lock for reading
 * or the included spinlock for writing.
 * or the associated spinlock for writing.
 */
static struct auditd_connection {
	struct pid *pid;
	u32 portid;
	struct net *net;
	spinlock_t lock;
} auditd_conn;
	struct rcu_head rcu;
} *auditd_conn = NULL;
static DEFINE_SPINLOCK(auditd_conn_lock);

/* If audit_rate_limit is non-zero, limit the rate of sending audit records
 * to that number per second.  This prevents DoS attacks, but results in
@@ -215,9 +216,11 @@ struct audit_reply {
int auditd_test_task(struct task_struct *task)
{
	int rc;
	struct auditd_connection *ac;

	rcu_read_lock();
	rc = (auditd_conn.pid && auditd_conn.pid == task_tgid(task) ? 1 : 0);
	ac = rcu_dereference(auditd_conn);
	rc = (ac && ac->pid == task_tgid(task) ? 1 : 0);
	rcu_read_unlock();

	return rc;
@@ -225,22 +228,21 @@ int auditd_test_task(struct task_struct *task)

/**
 * auditd_pid_vnr - Return the auditd PID relative to the namespace
 * @auditd: the auditd connection
 *
 * Description:
 * Returns the PID in relation to the namespace, 0 on failure.  This function
 * takes the RCU read lock internally, but if the caller needs to protect the
 * auditd_connection pointer it should take the RCU read lock as well.
 * Returns the PID in relation to the namespace, 0 on failure.
 */
static pid_t auditd_pid_vnr(const struct auditd_connection *auditd)
static pid_t auditd_pid_vnr(void)
{
	pid_t pid;
	const struct auditd_connection *ac;

	rcu_read_lock();
	if (!auditd || !auditd->pid)
	ac = rcu_dereference(auditd_conn);
	if (!ac || !ac->pid)
		pid = 0;
	else
		pid = pid_vnr(auditd->pid);
		pid = pid_vnr(ac->pid);
	rcu_read_unlock();

	return pid;
@@ -433,6 +435,24 @@ static int audit_set_failure(u32 state)
	return audit_do_config_change("audit_failure", &audit_failure, state);
}

/**
 * auditd_conn_free - RCU helper to release an auditd connection struct
 * @rcu: RCU head
 *
 * Description:
 * Drop any references inside the auditd connection tracking struct and free
 * the memory.
 */
 static void auditd_conn_free(struct rcu_head *rcu)
 {
	struct auditd_connection *ac;

	ac = container_of(rcu, struct auditd_connection, rcu);
	put_pid(ac->pid);
	put_net(ac->net);
	kfree(ac);
 }

/**
 * auditd_set - Set/Reset the auditd connection state
 * @pid: auditd PID
@@ -441,27 +461,33 @@ static int audit_set_failure(u32 state)
 *
 * Description:
 * This function will obtain and drop network namespace references as
 * necessary.
 * necessary.  Returns zero on success, negative values on failure.
 */
static void auditd_set(struct pid *pid, u32 portid, struct net *net)
static int auditd_set(struct pid *pid, u32 portid, struct net *net)
{
	unsigned long flags;
	struct auditd_connection *ac_old, *ac_new;

	spin_lock_irqsave(&auditd_conn.lock, flags);
	if (auditd_conn.pid)
		put_pid(auditd_conn.pid);
	if (pid)
		auditd_conn.pid = get_pid(pid);
	else
		auditd_conn.pid = NULL;
	auditd_conn.portid = portid;
	if (auditd_conn.net)
		put_net(auditd_conn.net);
	if (net)
		auditd_conn.net = get_net(net);
	else
		auditd_conn.net = NULL;
	spin_unlock_irqrestore(&auditd_conn.lock, flags);
	if (!pid || !net)
		return -EINVAL;

	ac_new = kzalloc(sizeof(*ac_new), GFP_KERNEL);
	if (!ac_new)
		return -ENOMEM;
	ac_new->pid = get_pid(pid);
	ac_new->portid = portid;
	ac_new->net = get_net(net);

	spin_lock_irqsave(&auditd_conn_lock, flags);
	ac_old = rcu_dereference_protected(auditd_conn,
					   lockdep_is_held(&auditd_conn_lock));
	rcu_assign_pointer(auditd_conn, ac_new);
	spin_unlock_irqrestore(&auditd_conn_lock, flags);

	if (ac_old)
		call_rcu(&ac_old->rcu, auditd_conn_free);

	return 0;
}

/**
@@ -556,13 +582,19 @@ static void kauditd_retry_skb(struct sk_buff *skb)
 */
static void auditd_reset(void)
{
	unsigned long flags;
	struct sk_buff *skb;
	struct auditd_connection *ac_old;

	/* if it isn't already broken, break the connection */
	rcu_read_lock();
	if (auditd_conn.pid)
		auditd_set(0, 0, NULL);
	rcu_read_unlock();
	spin_lock_irqsave(&auditd_conn_lock, flags);
	ac_old = rcu_dereference_protected(auditd_conn,
					   lockdep_is_held(&auditd_conn_lock));
	rcu_assign_pointer(auditd_conn, NULL);
	spin_unlock_irqrestore(&auditd_conn_lock, flags);

	if (ac_old)
		call_rcu(&ac_old->rcu, auditd_conn_free);

	/* flush all of the main and retry queues to the hold queue */
	while ((skb = skb_dequeue(&audit_retry_queue)))
@@ -588,6 +620,7 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
	u32 portid;
	struct net *net;
	struct sock *sk;
	struct auditd_connection *ac;

	/* NOTE: we can't call netlink_unicast while in the RCU section so
	 *       take a reference to the network namespace and grab local
@@ -597,15 +630,15 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
	 *       section netlink_unicast() should safely return an error */

	rcu_read_lock();
	if (!auditd_conn.pid) {
	ac = rcu_dereference(auditd_conn);
	if (!ac) {
		rcu_read_unlock();
		rc = -ECONNREFUSED;
		goto err;
	}
	net = auditd_conn.net;
	get_net(net);
	net = get_net(ac->net);
	sk = audit_get_sk(net);
	portid = auditd_conn.portid;
	portid = ac->portid;
	rcu_read_unlock();

	rc = netlink_unicast(sk, skb, portid, 0);
@@ -740,6 +773,7 @@ static int kauditd_thread(void *dummy)
	u32 portid = 0;
	struct net *net = NULL;
	struct sock *sk = NULL;
	struct auditd_connection *ac;

#define UNICAST_RETRIES 5

@@ -747,14 +781,14 @@ static int kauditd_thread(void *dummy)
	while (!kthread_should_stop()) {
		/* NOTE: see the lock comments in auditd_send_unicast_skb() */
		rcu_read_lock();
		if (!auditd_conn.pid) {
		ac = rcu_dereference(auditd_conn);
		if (!ac) {
			rcu_read_unlock();
			goto main_queue;
		}
		net = auditd_conn.net;
		get_net(net);
		net = get_net(ac->net);
		sk = audit_get_sk(net);
		portid = auditd_conn.portid;
		portid = ac->portid;
		rcu_read_unlock();

		/* attempt to flush the hold queue */
@@ -1117,7 +1151,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
		s.failure		= audit_failure;
		/* NOTE: use pid_vnr() so the PID is relative to the current
		 *       namespace */
		s.pid			= auditd_pid_vnr(&auditd_conn);
		s.pid			= auditd_pid_vnr();
		s.rate_limit		= audit_rate_limit;
		s.backlog_limit		= audit_backlog_limit;
		s.lost			= atomic_read(&audit_lost);
@@ -1160,7 +1194,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
			/* test the auditd connection */
			audit_replace(req_pid);

			auditd_pid = auditd_pid_vnr(&auditd_conn);
			auditd_pid = auditd_pid_vnr();
			/* only the current auditd can unregister itself */
			if ((!new_pid) && (new_pid != auditd_pid)) {
				audit_log_config_change("audit_pid", new_pid,
@@ -1174,20 +1208,31 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
				return -EEXIST;
			}

			if (audit_enabled != AUDIT_OFF)
				audit_log_config_change("audit_pid", new_pid,
							auditd_pid, 1);

			if (new_pid) {
				/* register a new auditd connection */
				auditd_set(req_pid, NETLINK_CB(skb).portid,
				err = auditd_set(req_pid,
						 NETLINK_CB(skb).portid,
						 sock_net(NETLINK_CB(skb).sk));
				if (audit_enabled != AUDIT_OFF)
					audit_log_config_change("audit_pid",
								new_pid,
								auditd_pid,
								err ? 0 : 1);
				if (err)
					return err;

				/* try to process any backlog */
				wake_up_interruptible(&kauditd_wait);
			} else
			} else {
				if (audit_enabled != AUDIT_OFF)
					audit_log_config_change("audit_pid",
								new_pid,
								auditd_pid, 1);

				/* unregister the auditd connection */
				auditd_reset();
			}
		}
		if (s.mask & AUDIT_STATUS_RATE_LIMIT) {
			err = audit_set_rate_limit(s.rate_limit);
			if (err < 0)
@@ -1454,10 +1499,11 @@ static void __net_exit audit_net_exit(struct net *net)
{
	struct audit_net *aunet = net_generic(net, audit_net_id);

	rcu_read_lock();
	if (net == auditd_conn.net)
		auditd_reset();
	rcu_read_unlock();
	/* NOTE: you would think that we would want to check the auditd
	 * connection and potentially reset it here if it lives in this
	 * namespace, but since the auditd connection tracking struct holds a
	 * reference to this namespace (see auditd_set()) we are only ever
	 * going to get here after that connection has been released */

	netlink_kernel_release(aunet->sk);
}
@@ -1481,9 +1527,6 @@ static int __init audit_init(void)
					       sizeof(struct audit_buffer),
					       0, SLAB_PANIC, NULL);

	memset(&auditd_conn, 0, sizeof(auditd_conn));
	spin_lock_init(&auditd_conn.lock);

	skb_queue_head_init(&audit_queue);
	skb_queue_head_init(&audit_retry_queue);
	skb_queue_head_init(&audit_hold_queue);