
Commit 3c81bdd9 authored by Linus Torvalds
Pull vhost infrastructure updates from Michael S. Tsirkin:
 "This reworks vhost core dropping unnecessary RCU uses in favor of VQ
  mutexes which are used on fast path anyway.  This fixes worst-case
  latency for users which change the memory mappings a lot.  Memory
  allocation for vhost-net now supports fallback on vmalloc (same as for
  vhost-scsi) this makes it possible to create the device on systems
  where memory is very fragmented, with slightly lower performance"

* tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost:
  vhost: move memory pointer to VQs
  vhost: move acked_features to VQs
  vhost: replace rcu with mutex
  vhost-net: extend device allocation to vmalloc
parents 7ec6131b 47283bef
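
The vmalloc fallback mentioned above follows the common try-kmalloc-then-vmalloc pattern, with a free helper that checks which allocator produced the pointer. A minimal standalone sketch (function names are illustrative; the real helper added in the first hunk below is vhost_net_free()):

	/* Sketch: attempt a physically contiguous allocation first and
	 * fall back to vmalloc() when memory is too fragmented. */
	static void *example_alloc(size_t size)
	{
		void *p = kmalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);

		if (!p)
			p = vmalloc(size);
		return p;
	}

	static void example_free(void *addr)
	{
		if (is_vmalloc_addr(addr))
			vfree(addr);
		else
			kfree(addr);
	}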
drivers/vhost/net.c  +23 −12
@@ -17,6 +17,7 @@
 #include <linux/workqueue.h>
 #include <linux/file.h>
 #include <linux/slab.h>
+#include <linux/vmalloc.h>
 
 #include <linux/net.h>
 #include <linux/if_packet.h>
@@ -373,7 +374,7 @@ static void handle_tx(struct vhost_net *net)
 			      % UIO_MAXIOV == nvq->done_idx))
 			break;
 
-		head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
+		head = vhost_get_vq_desc(vq, vq->iov,
 					 ARRAY_SIZE(vq->iov),
 					 &out, &in,
 					 NULL, NULL);
@@ -505,7 +506,7 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
 			r = -ENOBUFS;
 			goto err;
 		}
-		r = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg,
+		r = vhost_get_vq_desc(vq, vq->iov + seg,
 				      ARRAY_SIZE(vq->iov) - seg, &out,
 				      &in, log, log_num);
 		if (unlikely(r < 0))
@@ -584,9 +585,9 @@ static void handle_rx(struct vhost_net *net)
 	vhost_hlen = nvq->vhost_hlen;
 	sock_hlen = nvq->sock_hlen;
 
-	vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
+	vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ?
 		vq->log : NULL;
-	mergeable = vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF);
+	mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
 
 	while ((sock_len = peek_head_len(sock->sk))) {
 		sock_len += sock_hlen;
@@ -699,18 +700,30 @@ static void handle_rx_net(struct vhost_work *work)
 	handle_rx(net);
 }
 
+static void vhost_net_free(void *addr)
+{
+	if (is_vmalloc_addr(addr))
+		vfree(addr);
+	else
+		kfree(addr);
+}
+
 static int vhost_net_open(struct inode *inode, struct file *f)
 {
-	struct vhost_net *n = kmalloc(sizeof *n, GFP_KERNEL);
+	struct vhost_net *n;
 	struct vhost_dev *dev;
 	struct vhost_virtqueue **vqs;
 	int i;
 
-	if (!n)
-		return -ENOMEM;
+	n = kmalloc(sizeof *n, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
+	if (!n) {
+		n = vmalloc(sizeof *n);
+		if (!n)
+			return -ENOMEM;
+	}
 	vqs = kmalloc(VHOST_NET_VQ_MAX * sizeof(*vqs), GFP_KERNEL);
 	if (!vqs) {
-		kfree(n);
+		vhost_net_free(n);
 		return -ENOMEM;
 	}
 
@@ -827,7 +840,7 @@ static int vhost_net_release(struct inode *inode, struct file *f)
 	 * since jobs can re-queue themselves. */
 	vhost_net_flush(n);
 	kfree(n->dev.vqs);
-	kfree(n);
+	vhost_net_free(n);
 	return 0;
 }
 
@@ -1038,15 +1051,13 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
 		mutex_unlock(&n->dev.mutex);
 		return -EFAULT;
 	}
-	n->dev.acked_features = features;
-	smp_wmb();
 	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
 		mutex_lock(&n->vqs[i].vq.mutex);
+		n->vqs[i].vq.acked_features = features;
 		n->vqs[i].vhost_hlen = vhost_hlen;
 		n->vqs[i].sock_hlen = sock_hlen;
 		mutex_unlock(&n->vqs[i].vq.mutex);
 	}
-	vhost_net_flush(n);
 	mutex_unlock(&n->dev.mutex);
 	return 0;
 }
drivers/vhost/scsi.c  +15 −11
@@ -606,7 +606,7 @@ tcm_vhost_do_evt_work(struct vhost_scsi *vs, struct tcm_vhost_evt *evt)
 
 again:
 	vhost_disable_notify(&vs->dev, vq);
-	head = vhost_get_vq_desc(&vs->dev, vq, vq->iov,
+	head = vhost_get_vq_desc(vq, vq->iov,
 			ARRAY_SIZE(vq->iov), &out, &in,
 			NULL, NULL);
 	if (head < 0) {
@@ -945,7 +945,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
 	vhost_disable_notify(&vs->dev, vq);
 
 	for (;;) {
-		head = vhost_get_vq_desc(&vs->dev, vq, vq->iov,
+		head = vhost_get_vq_desc(vq, vq->iov,
 					ARRAY_SIZE(vq->iov), &out, &in,
 					NULL, NULL);
 		pr_debug("vhost_get_vq_desc: head: %d, out: %u in: %u\n",
@@ -1373,6 +1373,9 @@ vhost_scsi_clear_endpoint(struct vhost_scsi *vs,
 
 static int vhost_scsi_set_features(struct vhost_scsi *vs, u64 features)
 {
+	struct vhost_virtqueue *vq;
+	int i;
+
 	if (features & ~VHOST_SCSI_FEATURES)
 		return -EOPNOTSUPP;
 
@@ -1382,9 +1385,13 @@ static int vhost_scsi_set_features(struct vhost_scsi *vs, u64 features)
 		mutex_unlock(&vs->dev.mutex);
 		return -EFAULT;
 	}
-	vs->dev.acked_features = features;
-	smp_wmb();
-	vhost_scsi_flush(vs);
+
+	for (i = 0; i < VHOST_SCSI_MAX_VQ; i++) {
+		vq = &vs->vqs[i].vq;
+		mutex_lock(&vq->mutex);
+		vq->acked_features = features;
+		mutex_unlock(&vq->mutex);
+	}
 	mutex_unlock(&vs->dev.mutex);
 	return 0;
 }
@@ -1591,10 +1598,6 @@ tcm_vhost_do_plug(struct tcm_vhost_tpg *tpg,
 		return;
 
 	mutex_lock(&vs->dev.mutex);
-	if (!vhost_has_feature(&vs->dev, VIRTIO_SCSI_F_HOTPLUG)) {
-		mutex_unlock(&vs->dev.mutex);
-		return;
-	}
 
 	if (plug)
 		reason = VIRTIO_SCSI_EVT_RESET_RESCAN;
@@ -1603,6 +1606,7 @@ tcm_vhost_do_plug(struct tcm_vhost_tpg *tpg,
 
 	vq = &vs->vqs[VHOST_SCSI_VQ_EVT].vq;
 	mutex_lock(&vq->mutex);
-	tcm_vhost_send_evt(vs, tpg, lun,
-			   VIRTIO_SCSI_T_TRANSPORT_RESET, reason);
+	if (vhost_has_feature(vq, VIRTIO_SCSI_F_HOTPLUG))
+		tcm_vhost_send_evt(vs, tpg, lun,
+				   VIRTIO_SCSI_T_TRANSPORT_RESET, reason);
 	mutex_unlock(&vq->mutex);
drivers/vhost/test.c  +7 −4
@@ -53,7 +53,7 @@ static void handle_vq(struct vhost_test *n)
 	vhost_disable_notify(&n->dev, vq);
 
 	for (;;) {
-		head = vhost_get_vq_desc(&n->dev, vq, vq->iov,
+		head = vhost_get_vq_desc(vq, vq->iov,
 					 ARRAY_SIZE(vq->iov),
 					 &out, &in,
 					 NULL, NULL);
@@ -241,15 +241,18 @@ static long vhost_test_reset_owner(struct vhost_test *n)
 
 static int vhost_test_set_features(struct vhost_test *n, u64 features)
 {
+	struct vhost_virtqueue *vq;
+
 	mutex_lock(&n->dev.mutex);
 	if ((features & (1 << VHOST_F_LOG_ALL)) &&
 	    !vhost_log_access_ok(&n->dev)) {
 		mutex_unlock(&n->dev.mutex);
 		return -EFAULT;
 	}
-	n->dev.acked_features = features;
-	smp_wmb();
-	vhost_test_flush(n);
+	vq = &n->vqs[VHOST_TEST_VQ];
+	mutex_lock(&vq->mutex);
+	vq->acked_features = features;
+	mutex_unlock(&vq->mutex);
 	mutex_unlock(&n->dev.mutex);
 	return 0;
 }
drivers/vhost/vhost.c  +50 −47
@@ -18,7 +18,6 @@
 #include <linux/mmu_context.h>
 #include <linux/miscdevice.h>
 #include <linux/mutex.h>
-#include <linux/rcupdate.h>
 #include <linux/poll.h>
 #include <linux/file.h>
 #include <linux/highmem.h>
@@ -191,6 +190,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vq->log_used = false;
 	vq->log_addr = -1ull;
 	vq->private_data = NULL;
+	vq->acked_features = 0;
 	vq->log_base = NULL;
 	vq->error_ctx = NULL;
 	vq->error = NULL;
@@ -198,6 +198,7 @@
 	vq->call_ctx = NULL;
 	vq->call = NULL;
 	vq->log_ctx = NULL;
+	vq->memory = NULL;
 }
 
 static int vhost_worker(void *data)
@@ -415,11 +416,18 @@ EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
 /* Caller should have device mutex */
 void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory)
 {
+	int i;
+
 	vhost_dev_cleanup(dev, true);
 
 	/* Restore memory to default empty mapping. */
 	memory->nregions = 0;
-	RCU_INIT_POINTER(dev->memory, memory);
+	dev->memory = memory;
+	/* We don't need VQ locks below since vhost_dev_cleanup makes sure
+	 * VQs aren't running.
+	 */
+	for (i = 0; i < dev->nvqs; ++i)
+		dev->vqs[i]->memory = memory;
 }
 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
 
@@ -462,10 +470,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
 		fput(dev->log_file);
 	dev->log_file = NULL;
 	/* No one will access memory at this point */
-	kfree(rcu_dereference_protected(dev->memory,
-					locked ==
-						lockdep_is_held(&dev->mutex)));
-	RCU_INIT_POINTER(dev->memory, NULL);
+	kfree(dev->memory);
+	dev->memory = NULL;
 	WARN_ON(!list_empty(&dev->work_list));
 	if (dev->worker) {
 		kthread_stop(dev->worker);
@@ -524,11 +530,13 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
 
 	for (i = 0; i < d->nvqs; ++i) {
 		int ok;
+		bool log;
+
 		mutex_lock(&d->vqs[i]->mutex);
+		log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
 		/* If ring is inactive, will check when it's enabled. */
 		if (d->vqs[i]->private_data)
-			ok = vq_memory_access_ok(d->vqs[i]->log_base, mem,
-						 log_all);
+			ok = vq_memory_access_ok(d->vqs[i]->log_base, mem, log);
 		else
 			ok = 1;
 		mutex_unlock(&d->vqs[i]->mutex);
@@ -538,12 +546,12 @@
 	return 1;
 }
 
-static int vq_access_ok(struct vhost_dev *d, unsigned int num,
+static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
 			struct vring_desc __user *desc,
 			struct vring_avail __user *avail,
 			struct vring_used __user *used)
 {
-	size_t s = vhost_has_feature(d, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
+	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
 	return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
 	       access_ok(VERIFY_READ, avail,
 			 sizeof *avail + num * sizeof *avail->ring + s) &&
@@ -555,26 +563,19 @@ static int vq_access_ok(struct vhost_dev *d, unsigned int num,
 /* Caller should have device mutex but not vq mutex */
 int vhost_log_access_ok(struct vhost_dev *dev)
 {
-	struct vhost_memory *mp;
-
-	mp = rcu_dereference_protected(dev->memory,
-				       lockdep_is_held(&dev->mutex));
-	return memory_access_ok(dev, mp, 1);
+	return memory_access_ok(dev, dev->memory, 1);
 }
 EXPORT_SYMBOL_GPL(vhost_log_access_ok);
 
 /* Verify access for write logging. */
 /* Caller should have vq mutex and device mutex */
-static int vq_log_access_ok(struct vhost_dev *d, struct vhost_virtqueue *vq,
+static int vq_log_access_ok(struct vhost_virtqueue *vq,
 			    void __user *log_base)
 {
-	struct vhost_memory *mp;
-	size_t s = vhost_has_feature(d, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
+	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
 
-	mp = rcu_dereference_protected(vq->dev->memory,
-				       lockdep_is_held(&vq->mutex));
-	return vq_memory_access_ok(log_base, mp,
-			    vhost_has_feature(vq->dev, VHOST_F_LOG_ALL)) &&
+	return vq_memory_access_ok(log_base, vq->memory,
+				   vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
 		(!vq->log_used || log_access_ok(log_base, vq->log_addr,
 					sizeof *vq->used +
 					vq->num * sizeof *vq->used->ring + s));
@@ -584,8 +585,8 @@ static int vq_log_access_ok(struct vhost_dev *d, struct vhost_virtqueue *vq,
 /* Caller should have vq mutex and device mutex */
 int vhost_vq_access_ok(struct vhost_virtqueue *vq)
 {
-	return vq_access_ok(vq->dev, vq->num, vq->desc, vq->avail, vq->used) &&
-		vq_log_access_ok(vq->dev, vq, vq->log_base);
+	return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) &&
+		vq_log_access_ok(vq, vq->log_base);
 }
 EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
 
@@ -593,6 +594,7 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
 {
 	struct vhost_memory mem, *newmem, *oldmem;
 	unsigned long size = offsetof(struct vhost_memory, regions);
+	int i;
 
 	if (copy_from_user(&mem, m, size))
 		return -EFAULT;
@@ -611,15 +613,19 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
 		return -EFAULT;
 	}
 
-	if (!memory_access_ok(d, newmem,
-			      vhost_has_feature(d, VHOST_F_LOG_ALL))) {
+	if (!memory_access_ok(d, newmem, 0)) {
 		kfree(newmem);
 		return -EFAULT;
 	}
-	oldmem = rcu_dereference_protected(d->memory,
-					   lockdep_is_held(&d->mutex));
-	rcu_assign_pointer(d->memory, newmem);
-	synchronize_rcu();
+	oldmem = d->memory;
+	d->memory = newmem;
+
+	/* All memory accesses are done under some VQ mutex. */
+	for (i = 0; i < d->nvqs; ++i) {
+		mutex_lock(&d->vqs[i]->mutex);
+		d->vqs[i]->memory = newmem;
+		mutex_unlock(&d->vqs[i]->mutex);
+	}
 	kfree(oldmem);
 	return 0;
 }
@@ -718,7 +724,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
 		 * If it is not, we don't as size might not have been setup.
 		 * We will verify when backend is configured. */
 		if (vq->private_data) {
-			if (!vq_access_ok(d, vq->num,
+			if (!vq_access_ok(vq, vq->num,
 				(void __user *)(unsigned long)a.desc_user_addr,
 				(void __user *)(unsigned long)a.avail_user_addr,
 				(void __user *)(unsigned long)a.used_user_addr)) {
@@ -858,7 +864,7 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
 			vq = d->vqs[i];
 			mutex_lock(&vq->mutex);
 			/* If ring is inactive, will check when it's enabled. */
-			if (vq->private_data && !vq_log_access_ok(d, vq, base))
+			if (vq->private_data && !vq_log_access_ok(vq, base))
 				r = -EFAULT;
 			else
 				vq->log_base = base;
@@ -1044,7 +1050,7 @@ int vhost_init_used(struct vhost_virtqueue *vq)
 }
 EXPORT_SYMBOL_GPL(vhost_init_used);
 
-static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
+static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
 			  struct iovec iov[], int iov_size)
 {
 	const struct vhost_memory_region *reg;
@@ -1053,9 +1059,7 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
 	u64 s = 0;
 	int ret = 0;
 
-	rcu_read_lock();
-
-	mem = rcu_dereference(dev->memory);
+	mem = vq->memory;
 	while ((u64)len > s) {
 		u64 size;
 		if (unlikely(ret >= iov_size)) {
@@ -1077,7 +1081,6 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
 		++ret;
 	}
 
-	rcu_read_unlock();
 	return ret;
 }
 
@@ -1102,7 +1105,7 @@ static unsigned next_desc(struct vring_desc *desc)
 	return next;
 }
 
-static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
+static int get_indirect(struct vhost_virtqueue *vq,
 			struct iovec iov[], unsigned int iov_size,
 			unsigned int *out_num, unsigned int *in_num,
 			struct vhost_log *log, unsigned int *log_num,
@@ -1121,7 +1124,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 		return -EINVAL;
 	}
 
-	ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect,
+	ret = translate_desc(vq, indirect->addr, indirect->len, vq->indirect,
 			     UIO_MAXIOV);
 	if (unlikely(ret < 0)) {
 		vq_err(vq, "Translation failure %d in indirect.\n", ret);
@@ -1161,7 +1164,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 			return -EINVAL;
 		}
 
-		ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count,
+		ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
 				     iov_size - iov_count);
 		if (unlikely(ret < 0)) {
 			vq_err(vq, "Translation failure %d indirect idx %d\n",
@@ -1198,7 +1201,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
  * This function returns the descriptor number found, or vq->num (which is
  * never a valid descriptor number) if none was found.  A negative code is
  * returned on error. */
-int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
+int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 		      struct iovec iov[], unsigned int iov_size,
 		      unsigned int *out_num, unsigned int *in_num,
 		      struct vhost_log *log, unsigned int *log_num)
@@ -1272,7 +1275,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 			return -EFAULT;
 		}
 		if (desc.flags & VRING_DESC_F_INDIRECT) {
-			ret = get_indirect(dev, vq, iov, iov_size,
+			ret = get_indirect(vq, iov, iov_size,
 					   out_num, in_num,
 					   log, log_num, &desc);
 			if (unlikely(ret < 0)) {
@@ -1283,7 +1286,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 			continue;
 		}
 
-		ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count,
+		ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
 				     iov_size - iov_count);
 		if (unlikely(ret < 0)) {
 			vq_err(vq, "Translation failure %d descriptor idx %d\n",
@@ -1426,11 +1429,11 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 	 * interrupts. */
 	smp_mb();
 
-	if (vhost_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
+	if (vhost_has_feature(vq, VIRTIO_F_NOTIFY_ON_EMPTY) &&
 	    unlikely(vq->avail_idx == vq->last_avail_idx))
 		return true;
 
-	if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
+	if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
 		__u16 flags;
 		if (__get_user(flags, &vq->avail->flags)) {
 			vq_err(vq, "Failed to get flags");
@@ -1491,7 +1494,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 	if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
 		return false;
 	vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
-	if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
+	if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
 		r = vhost_update_used_flags(vq);
 		if (r) {
 			vq_err(vq, "Failed to enable notification at %p: %d\n",
@@ -1528,7 +1531,7 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 	if (vq->used_flags & VRING_USED_F_NO_NOTIFY)
 		return;
 	vq->used_flags |= VRING_USED_F_NO_NOTIFY;
-	if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
+	if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
 		r = vhost_update_used_flags(vq);
 		if (r)
 			vq_err(vq, "Failed to enable notification at %p: %d\n",
drivers/vhost/vhost.h  +6 −13
@@ -104,20 +104,18 @@ struct vhost_virtqueue {
 	struct iovec *indirect;
 	struct vring_used_elem *heads;
 	/* Protected by virtqueue mutex. */
+	struct vhost_memory *memory;
 	void *private_data;
+	unsigned acked_features;
 	/* Log write descriptors */
 	void __user *log_base;
 	struct vhost_log *log;
 };
 
 struct vhost_dev {
-	/* Readers use RCU to access memory table pointer
-	 * log base pointer and features.
-	 * Writers use mutex below.*/
-	struct vhost_memory __rcu *memory;
+	struct vhost_memory *memory;
 	struct mm_struct *mm;
 	struct mutex mutex;
-	unsigned acked_features;
 	struct vhost_virtqueue **vqs;
 	int nvqs;
 	struct file *log_file;
@@ -140,7 +138,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp);
 int vhost_vq_access_ok(struct vhost_virtqueue *vq);
 int vhost_log_access_ok(struct vhost_dev *);
 
-int vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *,
+int vhost_get_vq_desc(struct vhost_virtqueue *,
 		      struct iovec iov[], unsigned int iov_count,
 		      unsigned int *out_num, unsigned int *in_num,
 		      struct vhost_log *log, unsigned int *log_num);
@@ -174,13 +172,8 @@ enum {
 			 (1ULL << VHOST_F_LOG_ALL),
 };
 
-static inline int vhost_has_feature(struct vhost_dev *dev, int bit)
+static inline int vhost_has_feature(struct vhost_virtqueue *vq, int bit)
 {
-	unsigned acked_features;
-
-	/* TODO: check that we are running from vhost_worker or dev mutex is
-	 * held? */
-	acked_features = rcu_dereference_index_check(dev->acked_features, 1);
-	return acked_features & (1 << bit);
+	return vq->acked_features & (1 << bit);
 }
 #endif