Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit dad76bf7 authored by K. Y. Srinivasan's avatar K. Y. Srinivasan Committed by Greg Kroah-Hartman
Browse files

Staging: hv: vmbus: Properly deal with de-registering channel callback



Ensure that we correctly handle racing invocations of the channel callback
when the channel is being closed. We do this using the channel's inbound_lock.
A side-effect of this strategy is that we avoid repeatedly picking up this lock
as we drain the inbound ring-buffer.

Signed-off-by: default avatarK. Y. Srinivasan <kys@microsoft.com>
Signed-off-by: default avatarHaiyang Zhang <haiyangz@microsoft.com>
Signed-off-by: default avatarGreg Kroah-Hartman <gregkh@suse.de>
parent 76c39d42
Loading
Loading
Loading
Loading
+5 −15
Original line number Diff line number Diff line
@@ -513,9 +513,12 @@ void vmbus_close(struct vmbus_channel *channel)
{
	struct vmbus_channel_close_channel *msg;
	int ret;
	unsigned long flags;

	/* Stop callback and cancel the timer asap */
	spin_lock_irqsave(&channel->inbound_lock, flags);
	channel->onchannel_callback = NULL;
	spin_unlock_irqrestore(&channel->inbound_lock, flags);

	/* Send a closing message */

@@ -735,19 +738,15 @@ int vmbus_recvpacket(struct vmbus_channel *channel, void *buffer,
	u32 packetlen;
	u32 userlen;
	int ret;
	unsigned long flags;

	*buffer_actual_len = 0;
	*requestid = 0;

	spin_lock_irqsave(&channel->inbound_lock, flags);

	ret = hv_ringbuffer_peek(&channel->inbound, &desc,
			     sizeof(struct vmpacket_descriptor));
	if (ret != 0) {
		spin_unlock_irqrestore(&channel->inbound_lock, flags);
	if (ret != 0)
		return 0;
	}

	packetlen = desc.len8 << 3;
	userlen = packetlen - (desc.offset8 << 3);
@@ -755,7 +754,6 @@ int vmbus_recvpacket(struct vmbus_channel *channel, void *buffer,
	*buffer_actual_len = userlen;

	if (userlen > bufferlen) {
		spin_unlock_irqrestore(&channel->inbound_lock, flags);

		pr_err("Buffer too small - got %d needs %d\n",
			   bufferlen, userlen);
@@ -768,7 +766,6 @@ int vmbus_recvpacket(struct vmbus_channel *channel, void *buffer,
	ret = hv_ringbuffer_read(&channel->inbound, buffer, userlen,
			     (desc.offset8 << 3));

	spin_unlock_irqrestore(&channel->inbound_lock, flags);

	return 0;
}
@@ -785,19 +782,15 @@ int vmbus_recvpacket_raw(struct vmbus_channel *channel, void *buffer,
	u32 packetlen;
	u32 userlen;
	int ret;
	unsigned long flags;

	*buffer_actual_len = 0;
	*requestid = 0;

	spin_lock_irqsave(&channel->inbound_lock, flags);

	ret = hv_ringbuffer_peek(&channel->inbound, &desc,
			     sizeof(struct vmpacket_descriptor));
	if (ret != 0) {
		spin_unlock_irqrestore(&channel->inbound_lock, flags);
	if (ret != 0)
		return 0;
	}


	packetlen = desc.len8 << 3;
@@ -806,8 +799,6 @@ int vmbus_recvpacket_raw(struct vmbus_channel *channel, void *buffer,
	*buffer_actual_len = packetlen;

	if (packetlen > bufferlen) {
		spin_unlock_irqrestore(&channel->inbound_lock, flags);

		pr_err("Buffer too small - needed %d bytes but "
			"got space for only %d bytes\n",
			packetlen, bufferlen);
@@ -819,7 +810,6 @@ int vmbus_recvpacket_raw(struct vmbus_channel *channel, void *buffer,
	/* Copy over the entire packet to the user buffer */
	ret = hv_ringbuffer_read(&channel->inbound, buffer, packetlen, 0);

	spin_unlock_irqrestore(&channel->inbound_lock, flags);
	return 0;
}
EXPORT_SYMBOL_GPL(vmbus_recvpacket_raw);
+3 −0
Original line number Diff line number Diff line
@@ -215,6 +215,7 @@ struct vmbus_channel *relid2channel(u32 relid)
static void process_chn_event(u32 relid)
{
	struct vmbus_channel *channel;
	unsigned long flags;

	/*
	 * Find the channel based on this relid and invokes the
@@ -222,11 +223,13 @@ static void process_chn_event(u32 relid)
	 */
	channel = relid2channel(relid);

	spin_lock_irqsave(&channel->inbound_lock, flags);
	if (channel && (channel->onchannel_callback != NULL)) {
		channel->onchannel_callback(channel->channel_callback_context);
	} else {
		pr_err("channel not found for relid - %u\n", relid);
	}
	spin_unlock_irqrestore(&channel->inbound_lock, flags);
}

/*
+0 −3
Original line number Diff line number Diff line
@@ -62,9 +62,7 @@ static struct netvsc_device *get_outbound_net_device(struct hv_device *device)
static struct netvsc_device *get_inbound_net_device(struct hv_device *device)
{
	struct netvsc_device *net_device;
	unsigned long flags;

	spin_lock_irqsave(&device->channel->inbound_lock, flags);
	net_device = device->ext;

	if (!net_device)
@@ -75,7 +73,6 @@ static struct netvsc_device *get_inbound_net_device(struct hv_device *device)
		net_device = NULL;

get_in_err:
	spin_unlock_irqrestore(&device->channel->inbound_lock, flags);
	return net_device;
}

+0 −3
Original line number Diff line number Diff line
@@ -352,9 +352,7 @@ static inline struct storvsc_device *get_in_stor_device(
					struct hv_device *device)
{
	struct storvsc_device *stor_device;
	unsigned long flags;

	spin_lock_irqsave(&device->channel->inbound_lock, flags);
	stor_device = (struct storvsc_device *)device->ext;

	if (!stor_device)
@@ -370,7 +368,6 @@ static inline struct storvsc_device *get_in_stor_device(
		stor_device = NULL;

get_in_err:
	spin_unlock_irqrestore(&device->channel->inbound_lock, flags);
	return stor_device;

}