Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 28457ee6 authored by Arnd Bergmann's avatar Arnd Bergmann Committed by Paul E. McKenney
Browse files

vhost: add __rcu annotations



Also add rcu_dereference_protected() for code paths where locks are held.

Signed-off-by: default avatarArnd Bergmann <arnd@arndb.de>
Signed-off-by: default avatarPaul E. McKenney <paulmck@linux.vnet.ibm.com>
Cc: "Michael S. Tsirkin" <mst@redhat.com>
parent 65e6bf48
Loading
Loading
Loading
Loading
+12 −4
Original line number Original line Diff line number Diff line
@@ -127,7 +127,10 @@ static void handle_tx(struct vhost_net *net)
	size_t len, total_len = 0;
	size_t len, total_len = 0;
	int err, wmem;
	int err, wmem;
	size_t hdr_size;
	size_t hdr_size;
	struct socket *sock = rcu_dereference(vq->private_data);
	struct socket *sock;

	sock = rcu_dereference_check(vq->private_data,
				     lockdep_is_held(&vq->mutex));
	if (!sock)
	if (!sock)
		return;
		return;


@@ -582,7 +585,10 @@ static void vhost_net_disable_vq(struct vhost_net *n,
static void vhost_net_enable_vq(struct vhost_net *n,
static void vhost_net_enable_vq(struct vhost_net *n,
				struct vhost_virtqueue *vq)
				struct vhost_virtqueue *vq)
{
{
	struct socket *sock = vq->private_data;
	struct socket *sock;

	sock = rcu_dereference_protected(vq->private_data,
					 lockdep_is_held(&vq->mutex));
	if (!sock)
	if (!sock)
		return;
		return;
	if (vq == n->vqs + VHOST_NET_VQ_TX) {
	if (vq == n->vqs + VHOST_NET_VQ_TX) {
@@ -598,7 +604,8 @@ static struct socket *vhost_net_stop_vq(struct vhost_net *n,
	struct socket *sock;
	struct socket *sock;


	mutex_lock(&vq->mutex);
	mutex_lock(&vq->mutex);
	sock = vq->private_data;
	sock = rcu_dereference_protected(vq->private_data,
					 lockdep_is_held(&vq->mutex));
	vhost_net_disable_vq(n, vq);
	vhost_net_disable_vq(n, vq);
	rcu_assign_pointer(vq->private_data, NULL);
	rcu_assign_pointer(vq->private_data, NULL);
	mutex_unlock(&vq->mutex);
	mutex_unlock(&vq->mutex);
@@ -736,7 +743,8 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
	}
	}


	/* start polling new socket */
	/* start polling new socket */
	oldsock = vq->private_data;
	oldsock = rcu_dereference_protected(vq->private_data,
					    lockdep_is_held(&vq->mutex));
	if (sock != oldsock) {
	if (sock != oldsock) {
                vhost_net_disable_vq(n, vq);
                vhost_net_disable_vq(n, vq);
                rcu_assign_pointer(vq->private_data, sock);
                rcu_assign_pointer(vq->private_data, sock);
+16 −6
Original line number Original line Diff line number Diff line
@@ -284,7 +284,7 @@ long vhost_dev_reset_owner(struct vhost_dev *dev)
	vhost_dev_cleanup(dev);
	vhost_dev_cleanup(dev);


	memory->nregions = 0;
	memory->nregions = 0;
	dev->memory = memory;
	RCU_INIT_POINTER(dev->memory, memory);
	return 0;
	return 0;
}
}


@@ -316,8 +316,9 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
		fput(dev->log_file);
		fput(dev->log_file);
	dev->log_file = NULL;
	dev->log_file = NULL;
	/* No one will access memory at this point */
	/* No one will access memory at this point */
	kfree(dev->memory);
	kfree(rcu_dereference_protected(dev->memory,
	dev->memory = NULL;
					lockdep_is_held(&dev->mutex)));
	RCU_INIT_POINTER(dev->memory, NULL);
	if (dev->mm)
	if (dev->mm)
		mmput(dev->mm);
		mmput(dev->mm);
	dev->mm = NULL;
	dev->mm = NULL;
@@ -401,14 +402,22 @@ static int vq_access_ok(unsigned int num,
/* Caller should have device mutex but not vq mutex */
/* Caller should have device mutex but not vq mutex */
int vhost_log_access_ok(struct vhost_dev *dev)
int vhost_log_access_ok(struct vhost_dev *dev)
{
{
	return memory_access_ok(dev, dev->memory, 1);
	struct vhost_memory *mp;

	mp = rcu_dereference_protected(dev->memory,
				       lockdep_is_held(&dev->mutex));
	return memory_access_ok(dev, mp, 1);
}
}


/* Verify access for write logging. */
/* Verify access for write logging. */
/* Caller should have vq mutex and device mutex */
/* Caller should have vq mutex and device mutex */
static int vq_log_access_ok(struct vhost_virtqueue *vq, void __user *log_base)
static int vq_log_access_ok(struct vhost_virtqueue *vq, void __user *log_base)
{
{
	return vq_memory_access_ok(log_base, vq->dev->memory,
	struct vhost_memory *mp;

	mp = rcu_dereference_protected(vq->dev->memory,
				       lockdep_is_held(&vq->mutex));
	return vq_memory_access_ok(log_base, mp,
			    vhost_has_feature(vq->dev, VHOST_F_LOG_ALL)) &&
			    vhost_has_feature(vq->dev, VHOST_F_LOG_ALL)) &&
		(!vq->log_used || log_access_ok(log_base, vq->log_addr,
		(!vq->log_used || log_access_ok(log_base, vq->log_addr,
					sizeof *vq->used +
					sizeof *vq->used +
@@ -448,7 +457,8 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
		kfree(newmem);
		kfree(newmem);
		return -EFAULT;
		return -EFAULT;
	}
	}
	oldmem = d->memory;
	oldmem = rcu_dereference_protected(d->memory,
					   lockdep_is_held(&d->mutex));
	rcu_assign_pointer(d->memory, newmem);
	rcu_assign_pointer(d->memory, newmem);
	synchronize_rcu();
	synchronize_rcu();
	kfree(oldmem);
	kfree(oldmem);
+7 −3
Original line number Original line Diff line number Diff line
@@ -106,7 +106,7 @@ struct vhost_virtqueue {
	 * vhost_work execution acts instead of rcu_read_lock() and the end of
	 * vhost_work execution acts instead of rcu_read_lock() and the end of
	 * vhost_work execution acts instead of rcu_read_lock().
	 * vhost_work execution acts instead of rcu_read_lock().
	 * Writers use virtqueue mutex. */
	 * Writers use virtqueue mutex. */
	void *private_data;
	void __rcu *private_data;
	/* Log write descriptors */
	/* Log write descriptors */
	void __user *log_base;
	void __user *log_base;
	struct vhost_log log[VHOST_NET_MAX_SG];
	struct vhost_log log[VHOST_NET_MAX_SG];
@@ -116,7 +116,7 @@ struct vhost_dev {
	/* Readers use RCU to access memory table pointer
	/* Readers use RCU to access memory table pointer
	 * log base pointer and features.
	 * log base pointer and features.
	 * Writers use mutex below.*/
	 * Writers use mutex below.*/
	struct vhost_memory *memory;
	struct vhost_memory __rcu *memory;
	struct mm_struct *mm;
	struct mm_struct *mm;
	struct mutex mutex;
	struct mutex mutex;
	unsigned acked_features;
	unsigned acked_features;
@@ -173,7 +173,11 @@ enum {


static inline int vhost_has_feature(struct vhost_dev *dev, int bit)
static inline int vhost_has_feature(struct vhost_dev *dev, int bit)
{
{
	unsigned acked_features = rcu_dereference(dev->acked_features);
	unsigned acked_features;

	acked_features =
		rcu_dereference_index_check(dev->acked_features,
					    lockdep_is_held(&dev->mutex));
	return acked_features & (1 << bit);
	return acked_features & (1 << bit);
}
}