Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 150b9e51 authored by Michael S. Tsirkin's avatar Michael S. Tsirkin
Browse files

vhost: fix error handling in RESET_OWNER ioctl



RESET_OWNER ioctl would leave the fd in a bad state if
memory allocation failed: device is stopped
but owner is not reset. Make state changes
after allocating memory, such that a failed
ioctl has no effect.

Signed-off-by: default avatarMichael S. Tsirkin <mst@redhat.com>
parent 061b16cf
Loading
Loading
Loading
Loading
+7 −1
Original line number Original line Diff line number Diff line
@@ -967,14 +967,20 @@ static long vhost_net_reset_owner(struct vhost_net *n)
	struct socket *tx_sock = NULL;
	struct socket *tx_sock = NULL;
	struct socket *rx_sock = NULL;
	struct socket *rx_sock = NULL;
	long err;
	long err;
	struct vhost_memory *memory;


	mutex_lock(&n->dev.mutex);
	mutex_lock(&n->dev.mutex);
	err = vhost_dev_check_owner(&n->dev);
	err = vhost_dev_check_owner(&n->dev);
	if (err)
	if (err)
		goto done;
		goto done;
	memory = vhost_dev_reset_owner_prepare();
	if (!memory) {
		err = -ENOMEM;
		goto done;
	}
	vhost_net_stop(n, &tx_sock, &rx_sock);
	vhost_net_stop(n, &tx_sock, &rx_sock);
	vhost_net_flush(n);
	vhost_net_flush(n);
	err = vhost_dev_reset_owner(&n->dev);
	vhost_dev_reset_owner(&n->dev, memory);
	vhost_net_vq_reset(n);
	vhost_net_vq_reset(n);
done:
done:
	mutex_unlock(&n->dev.mutex);
	mutex_unlock(&n->dev.mutex);
+8 −1
Original line number Original line Diff line number Diff line
@@ -219,13 +219,20 @@ static long vhost_test_reset_owner(struct vhost_test *n)
{
{
	void *priv = NULL;
	void *priv = NULL;
	long err;
	long err;
	struct vhost_memory *memory;

	mutex_lock(&n->dev.mutex);
	mutex_lock(&n->dev.mutex);
	err = vhost_dev_check_owner(&n->dev);
	err = vhost_dev_check_owner(&n->dev);
	if (err)
	if (err)
		goto done;
		goto done;
	memory = vhost_dev_reset_owner_prepare();
	if (!memory) {
		err = -ENOMEM;
		goto done;
	}
	vhost_test_stop(n, &priv);
	vhost_test_stop(n, &priv);
	vhost_test_flush(n);
	vhost_test_flush(n);
	err = vhost_dev_reset_owner(&n->dev);
	vhost_dev_reset_owner(&n->dev, memory);
done:
done:
	mutex_unlock(&n->dev.mutex);
	mutex_unlock(&n->dev.mutex);
	return err;
	return err;
+7 −9
Original line number Original line Diff line number Diff line
@@ -386,21 +386,19 @@ static long vhost_dev_set_owner(struct vhost_dev *dev)
	return err;
	return err;
}
}


/* Caller should have device mutex */
struct vhost_memory *vhost_dev_reset_owner_prepare(void)
long vhost_dev_reset_owner(struct vhost_dev *dev)
{
{
	struct vhost_memory *memory;
	return kmalloc(offsetof(struct vhost_memory, regions), GFP_KERNEL);

}
	/* Restore memory to default empty mapping. */
	memory = kmalloc(offsetof(struct vhost_memory, regions), GFP_KERNEL);
	if (!memory)
		return -ENOMEM;


/* Caller should have device mutex */
void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory)
{
	vhost_dev_cleanup(dev, true);
	vhost_dev_cleanup(dev, true);


	/* Restore memory to default empty mapping. */
	memory->nregions = 0;
	memory->nregions = 0;
	RCU_INIT_POINTER(dev->memory, memory);
	RCU_INIT_POINTER(dev->memory, memory);
	return 0;
}
}


void vhost_dev_stop(struct vhost_dev *dev)
void vhost_dev_stop(struct vhost_dev *dev)
+2 −1
Original line number Original line Diff line number Diff line
@@ -136,7 +136,8 @@ struct vhost_dev {


long vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs);
long vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs);
long vhost_dev_check_owner(struct vhost_dev *);
long vhost_dev_check_owner(struct vhost_dev *);
long vhost_dev_reset_owner(struct vhost_dev *);
struct vhost_memory *vhost_dev_reset_owner_prepare(void);
void vhost_dev_reset_owner(struct vhost_dev *, struct vhost_memory *);
void vhost_dev_cleanup(struct vhost_dev *, bool locked);
void vhost_dev_cleanup(struct vhost_dev *, bool locked);
void vhost_dev_stop(struct vhost_dev *);
void vhost_dev_stop(struct vhost_dev *);
long vhost_dev_ioctl(struct vhost_dev *, unsigned int ioctl, void __user *argp);
long vhost_dev_ioctl(struct vhost_dev *, unsigned int ioctl, void __user *argp);