Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 08520107 authored by Jilai Wang
Browse files

msm: npu: Prevent network from being used after free



The function wait_for_completion_interruptible_timeout returns early when
the calling process receives a signal; the caller could then proceed while
the network data structure is freed elsewhere, leading to a use-after-free.
This change handles the signal-interrupted case properly to avoid that issue.

Change-Id: Icb74b3e7a5cb6c3201738c1952948d308333993e
Signed-off-by: Jilai Wang <jilaiw@codeaurora.org>
parent cf4e0244
Loading
Loading
Loading
Loading
+5 −5
Original line number Diff line number Diff line
@@ -957,7 +957,7 @@ static int npu_load_network(struct npu_client *client,

	ret = npu_host_load_network(client, &req);
	if (ret) {
		pr_err("network load failed: %d\n", ret);
		pr_err("npu_host_load_network failed %d\n", ret);
		return ret;
	}

@@ -1014,7 +1014,7 @@ static int npu_load_network_v2(struct npu_client *client,

	kfree(patch_info);
	if (ret) {
		pr_err("network load failed: %d\n", ret);
		pr_err("npu_host_load_network_v2 failed %d\n", ret);
		return ret;
	}

@@ -1046,7 +1046,7 @@ static int npu_unload_network(struct npu_client *client,
	ret = npu_host_unload_network(client, &req);

	if (ret) {
		pr_err("npu_host_unload_network failed\n");
		pr_err("npu_host_unload_network failed %d\n", ret);
		return ret;
	}

@@ -1084,7 +1084,7 @@ static int npu_exec_network(struct npu_client *client,
	ret = npu_host_exec_network(client, &req);

	if (ret) {
		pr_err("npu_host_exec_network failed\n");
		pr_err("npu_host_exec_network failed %d\n", ret);
		return ret;
	}

@@ -1143,7 +1143,7 @@ static int npu_exec_network_v2(struct npu_client *client,

	kfree(patch_buf_info);
	if (ret) {
		pr_err("npu_host_exec_network failed\n");
		pr_err("npu_host_exec_network_v2 failed %d\n", ret);
		return ret;
	}

+363 −104
Original line number Diff line number Diff line
@@ -46,10 +46,13 @@ static int wait_for_status_ready(struct npu_device *npu_dev,
static struct npu_network *alloc_network(struct npu_host_ctx *ctx,
	struct npu_client *client);
static struct npu_network *get_network_by_hdl(struct npu_host_ctx *ctx,
	uint32_t hdl);
	struct npu_client *client, uint32_t hdl);
static struct npu_network *get_network_by_id(struct npu_host_ctx *ctx,
	struct npu_client *client, int64_t id);
static void free_network(struct npu_host_ctx *ctx, struct npu_client *client,
	int64_t id);
static void free_network(struct npu_host_ctx *ctx, int64_t id);
static int network_get(struct npu_network *network);
static int network_put(struct npu_network *network);
static void app_msg_proc(struct npu_host_ctx *host_ctx, uint32_t *msg);
static void log_msg_proc(struct npu_device *npu_dev, uint32_t *msg);
static void host_session_msg_hdlr(struct npu_device *npu_dev);
@@ -488,86 +491,129 @@ static int npu_notify_aop(struct npu_device *npu_dev, bool on)
 * Function Definitions - Network Management
 * -------------------------------------------------------------------------
 */
/*
 * Drop one reference on @network and return the resulting reference count.
 * NULL-safe: returns 0 when @network is NULL so callers need not check.
 * NOTE(review): caller is expected to hold host_ctx->lock when acting on
 * the returned count (e.g. freeing at zero) — confirm against call sites.
 */
static int network_put(struct npu_network *network)
{
	if (!network)
		return 0;

	return atomic_dec_return(&network->ref_cnt);
}

/*
 * Take one reference on @network and return the resulting reference count.
 * NULL-safe: returns 0 when @network is NULL so callers need not check.
 */
static int network_get(struct npu_network *network)
{
	if (!network)
		return 0;

	return atomic_inc_return(&network->ref_cnt);
}

static struct npu_network *alloc_network(struct npu_host_ctx *ctx,
	struct npu_client *client)
{
	int32_t i;
	struct npu_network *network = ctx->networks;

	mutex_lock(&ctx->lock);
	WARN_ON(!mutex_is_locked(&ctx->lock));

	for (i = 0; i < MAX_LOADED_NETWORK; i++) {
		if (network->id == 0) {
			network->id = i + 1;
			network->network_hdl = 0;
		if (network->id == 0)
			break;
		}

		network++;
	}
	if (i >= MAX_LOADED_NETWORK)
		network = NULL;
	else
		ctx->network_num++;
	mutex_unlock(&ctx->lock);

	if (network) {
	if (i == MAX_LOADED_NETWORK) {
		pr_err("No free network\n");
		return NULL;
	}

	memset(network, 0, sizeof(struct npu_network));
	network->id = i + 1;
	init_completion(&network->cmd_done);
	network->is_valid = true;
		network->fw_error = false;
		network->cmd_pending = false;
	network->client = client;
	network->stats_buf = kzalloc(NPU_MAX_STATS_BUF_SIZE,
		GFP_KERNEL);
	if (!network->stats_buf) {
			free_network(ctx, network->id);
			network = NULL;
		}
		memset(network, 0, sizeof(struct npu_network));
		return NULL;
	}

	ctx->network_num++;
	return network;
}

static struct npu_network *get_network_by_hdl(struct npu_host_ctx *ctx,
	uint32_t hdl)
	struct npu_client *client, uint32_t hdl)
{
	int32_t i;
	struct npu_network *network = ctx->networks;

	WARN_ON(!mutex_is_locked(&ctx->lock));

	for (i = 0; i < MAX_LOADED_NETWORK; i++) {
		if (network->network_hdl == hdl)
			break;

		network++;
	}

	if ((i == MAX_LOADED_NETWORK) || !network->is_valid) {
		pr_err("network hdl invalid %d\n", hdl);
		network = NULL;
		return NULL;
	}

	if (client && (client != network->client)) {
		pr_err("network %d doesn't belong to this client\n",
			network->id);
		return NULL;
	}

	network_get(network);
	return network;
}

static struct npu_network *get_network_by_id(struct npu_host_ctx *ctx,
	int64_t id)
	struct npu_client *client, int64_t id)
{
	struct npu_network *network = NULL;

	WARN_ON(!mutex_is_locked(&ctx->lock));

	if (id < 1 || id > MAX_LOADED_NETWORK ||
		!ctx->networks[id - 1].is_valid) {
		pr_err("network id invalid %d\n", (int32_t)id);
		pr_err("Invalid network id %d\n", (int32_t)id);
		return NULL;
	}

	return &ctx->networks[id - 1];
	network = &ctx->networks[id - 1];
	if (client && (client != network->client)) {
		pr_err("network %d doesn't belong to this client\n", id);
		return NULL;
	}

	network_get(network);
	return network;
}

static void free_network(struct npu_host_ctx *ctx, int64_t id)
static void free_network(struct npu_host_ctx *ctx, struct npu_client *client,
	int64_t id)
{
	struct npu_network *network = get_network_by_id(ctx, id);
	struct npu_network *network = NULL;

	WARN_ON(!mutex_is_locked(&ctx->lock));

	network = get_network_by_id(ctx, client, id);
	if (network) {
		network_put(network);
		if (atomic_read(&network->ref_cnt) == 0) {
			kfree(network->stats_buf);
		mutex_lock(&ctx->lock);
			memset(network, 0, sizeof(struct npu_network));
			ctx->network_num--;
		mutex_unlock(&ctx->lock);
		} else {
			pr_warn("network %d:%d is in use\n", network->id,
				atomic_read(&network->ref_cnt));
		}
	}
}

@@ -617,11 +663,22 @@ static void app_msg_proc(struct npu_host_ctx *host_ctx, uint32_t *msg)
		pr_debug("total_num_layers: %d\n",
			exe_rsp_pkt->stats.exe_stats.total_num_layers);

		network = get_network_by_hdl(host_ctx,
		mutex_lock(&host_ctx->lock);
		network = get_network_by_hdl(host_ctx, NULL,
			exe_rsp_pkt->network_hdl);
		if (!network) {
			pr_err("can't find network %x\n",
				exe_rsp_pkt->network_hdl);
			mutex_unlock(&host_ctx->lock);
			break;
		}

		if (network->trans_id != exe_rsp_pkt->header.trans_id) {
			pr_err("execute_pkt trans_id is not match %d:%d\n",
				network->trans_id,
				exe_rsp_pkt->header.trans_id);
			network_put(network);
			mutex_unlock(&host_ctx->lock);
			break;
		}

@@ -640,6 +697,8 @@ static void app_msg_proc(struct npu_host_ctx *host_ctx, uint32_t *msg)
			if (npu_queue_event(network->client, &kevt))
				pr_err("queue npu event failed\n");
		}
		network_put(network);
		mutex_unlock(&host_ctx->lock);

		break;
	}
@@ -653,11 +712,22 @@ static void app_msg_proc(struct npu_host_ctx *host_ctx, uint32_t *msg)
			exe_rsp_pkt->header.status);
		pr_debug("trans_id : %d", exe_rsp_pkt->header.trans_id);

		network = get_network_by_hdl(host_ctx,
		mutex_lock(&host_ctx->lock);
		network = get_network_by_hdl(host_ctx, NULL,
			exe_rsp_pkt->network_hdl);
		if (!network) {
			pr_err("can't find network %x\n",
				exe_rsp_pkt->network_hdl);
			mutex_unlock(&host_ctx->lock);
			break;
		}

		if (network->trans_id != exe_rsp_pkt->header.trans_id) {
			pr_err("execute_pkt_v2 trans_id is not match %d:%d\n",
				network->trans_id,
				exe_rsp_pkt->header.trans_id);
			network_put(network);
			mutex_unlock(&host_ctx->lock);
			break;
		}

@@ -690,6 +760,8 @@ static void app_msg_proc(struct npu_host_ctx *host_ctx, uint32_t *msg)
		} else {
			complete(&network->cmd_done);
		}
		network_put(network);
		mutex_unlock(&host_ctx->lock);
		break;
	}
	case NPU_IPC_MSG_LOAD_DONE:
@@ -708,16 +780,30 @@ static void app_msg_proc(struct npu_host_ctx *host_ctx, uint32_t *msg)
		 */
		pr_debug("network_hdl: %x\n", load_rsp_pkt->network_hdl);
		network_id = load_rsp_pkt->network_hdl >> 16;
		network = get_network_by_id(host_ctx, network_id);
		mutex_lock(&host_ctx->lock);
		network = get_network_by_id(host_ctx, NULL, network_id);
		if (!network) {
			pr_err("can't find network %d\n", network_id);
			mutex_unlock(&host_ctx->lock);
			break;
		}

		if (network->trans_id != load_rsp_pkt->header.trans_id) {
			pr_err("load_rsp_pkt trans_id is not match %d:%d\n",
				network->trans_id,
				load_rsp_pkt->header.trans_id);
			network_put(network);
			mutex_unlock(&host_ctx->lock);
			break;
		}

		network->network_hdl = load_rsp_pkt->network_hdl;
		network->cmd_pending = false;
		network->cmd_ret_status = load_rsp_pkt->header.status;

		complete(&network->cmd_done);
		network_put(network);
		mutex_unlock(&host_ctx->lock);
		break;
	}
	case NPU_IPC_MSG_UNLOAD_DONE:
@@ -729,11 +815,22 @@ static void app_msg_proc(struct npu_host_ctx *host_ctx, uint32_t *msg)
			unload_rsp_pkt->header.status,
			unload_rsp_pkt->header.trans_id);

		network = get_network_by_hdl(host_ctx,
		mutex_lock(&host_ctx->lock);
		network = get_network_by_hdl(host_ctx, NULL,
			unload_rsp_pkt->network_hdl);
		if (!network) {
			pr_err("can't find network %x\n",
				unload_rsp_pkt->network_hdl);
			mutex_unlock(&host_ctx->lock);
			break;
		}

		if (network->trans_id != unload_rsp_pkt->header.trans_id) {
			pr_err("unload_rsp_pkt trans_id is not match %d:%d\n",
				network->trans_id,
				unload_rsp_pkt->header.trans_id);
			network_put(network);
			mutex_unlock(&host_ctx->lock);
			break;
		}

@@ -741,6 +838,8 @@ static void app_msg_proc(struct npu_host_ctx *host_ctx, uint32_t *msg)
		network->cmd_ret_status = unload_rsp_pkt->header.status;

		complete(&network->cmd_done);
		network_put(network);
		mutex_unlock(&host_ctx->lock);
		break;
	}
	case NPU_IPC_MSG_LOOPBACK_DONE:
@@ -855,7 +954,6 @@ static int npu_send_network_cmd(struct npu_device *npu_dev,
	struct npu_host_ctx *host_ctx = &npu_dev->host_ctx;
	int ret = 0;

	mutex_lock(&host_ctx->lock);
	if (network->fw_error || host_ctx->fw_error ||
		(host_ctx->fw_state == FW_DISABLED)) {
		pr_err("fw is in error state or disabled, can't send network cmd\n");
@@ -870,12 +968,12 @@ static int npu_send_network_cmd(struct npu_device *npu_dev,
		network->cmd_async = async;
		network->cmd_ret_status = 0;
		network->cmd_pending = true;
		network->trans_id = atomic_read(&host_ctx->ipc_trans_id);
		ret = npu_host_ipc_send_cmd(npu_dev,
			IPC_QUEUE_APPS_EXEC, cmd_ptr);
		if (ret)
			network->cmd_pending = false;
	}
	mutex_unlock(&host_ctx->lock);

	return ret;
}
@@ -975,12 +1073,14 @@ int32_t npu_host_load_network(struct npu_client *client,
	if (ret)
		return ret;

	mutex_lock(&host_ctx->lock);
	network = alloc_network(host_ctx, client);
	if (!network) {
		ret = -ENOMEM;
		goto err_deinit_fw;
	}

	network_get(network);
	network->buf_hdl = load_ioctl->buf_ion_hdl;
	network->size = load_ioctl->buf_size;
	network->phy_add = load_ioctl->buf_phys_addr;
@@ -1021,16 +1121,26 @@ int32_t npu_host_load_network(struct npu_client *client,
		goto error_free_network;
	}

	if (!wait_for_completion_interruptible_timeout(
	mutex_unlock(&host_ctx->lock);

	ret = wait_for_completion_interruptible_timeout(
		&network->cmd_done,
		(host_ctx->fw_dbg_mode & FW_DBG_MODE_INC_TIMEOUT) ?
		NW_DEBUG_TIMEOUT : NW_CMD_TIMEOUT)) {
		pr_err_ratelimited("npu: NPU_IPC_CMD_LOAD time out\n");
		NW_DEBUG_TIMEOUT : NW_CMD_TIMEOUT);

	mutex_lock(&host_ctx->lock);
	if (!ret) {
		pr_err_ratelimited("NPU_IPC_CMD_LOAD time out\n");
		ret = -ETIMEDOUT;
		goto error_free_network;
	} else if (network->fw_error) {
	} else if (ret < 0) {
		pr_err("NPU_IPC_CMD_LOAD is interrupted by signal\n");
		goto error_free_network;
	}

	if (network->fw_error) {
		ret = -EIO;
		pr_err("load cmd returns with error\n");
		pr_err("fw is in error state during load network\n");
		goto error_free_network;
	}

@@ -1039,12 +1149,17 @@ int32_t npu_host_load_network(struct npu_client *client,
		goto error_free_network;

	load_ioctl->network_hdl = network->network_hdl;
	network->is_active = true;
	network_put(network);
	mutex_unlock(&host_ctx->lock);

	return ret;

error_free_network:
	free_network(host_ctx, network->id);
	network_put(network);
	free_network(host_ctx, client, network->id);
err_deinit_fw:
	mutex_unlock(&host_ctx->lock);
	fw_deinit(npu_dev, false);
	return ret;
}
@@ -1065,12 +1180,14 @@ int32_t npu_host_load_network_v2(struct npu_client *client,
	if (ret)
		return ret;

	mutex_lock(&host_ctx->lock);
	network = alloc_network(host_ctx, client);
	if (!network) {
		ret = -ENOMEM;
		goto err_deinit_fw;
	}

	network_get(network);
	num_patch_params = load_ioctl->patch_info_num;
	pkt_size = sizeof(*load_packet) +
		num_patch_params * sizeof(struct npu_patch_tuple_v2);
@@ -1128,16 +1245,27 @@ int32_t npu_host_load_network_v2(struct npu_client *client,
		goto error_free_network;
	}

	if (!wait_for_completion_interruptible_timeout(
	mutex_unlock(&host_ctx->lock);

	ret = wait_for_completion_interruptible_timeout(
		&network->cmd_done,
		(host_ctx->fw_dbg_mode & FW_DBG_MODE_INC_TIMEOUT) ?
		NW_DEBUG_TIMEOUT : NW_CMD_TIMEOUT)) {
		NW_DEBUG_TIMEOUT : NW_CMD_TIMEOUT);

	mutex_lock(&host_ctx->lock);

	if (!ret) {
		pr_err_ratelimited("npu: NPU_IPC_CMD_LOAD time out\n");
		ret = -ETIMEDOUT;
		goto error_free_network;
	} else if (network->fw_error) {
	} else if (ret < 0) {
		pr_err("NPU_IPC_CMD_LOAD_V2 is interrupted by signal\n");
		goto error_free_network;
	}

	if (network->fw_error) {
		ret = -EIO;
		pr_err("load cmd returns with error\n");
		pr_err("fw is in error state during load_v2 network\n");
		goto error_free_network;
	}

@@ -1146,13 +1274,18 @@ int32_t npu_host_load_network_v2(struct npu_client *client,
		goto error_free_network;

	load_ioctl->network_hdl = network->network_hdl;
	network->is_active = true;
	network_put(network);
	mutex_unlock(&host_ctx->lock);

	return ret;

error_free_network:
	kfree(load_packet);
	free_network(host_ctx, network->id);
	network_put(network);
	free_network(host_ctx, client, network->id);
err_deinit_fw:
	mutex_unlock(&host_ctx->lock);
	fw_deinit(npu_dev, false);
	return ret;
}
@@ -1167,15 +1300,27 @@ int32_t npu_host_unload_network(struct npu_client *client,
	struct npu_host_ctx *host_ctx = &npu_dev->host_ctx;

	/* get the corresponding network for ipc trans id purpose */
	network = get_network_by_hdl(host_ctx, unload->network_hdl);
	if (!network)
	mutex_lock(&host_ctx->lock);
	network = get_network_by_hdl(host_ctx, client,
		unload->network_hdl);
	if (!network) {
		mutex_unlock(&host_ctx->lock);
		return -EINVAL;
	}

	if (!network->is_active) {
		pr_err("network is not active\n");
		network_put(network);
		mutex_unlock(&host_ctx->lock);
		return -EINVAL;
	}

	if (network->fw_error) {
		pr_err("fw in error state, skip unload network in fw\n");
		goto skip_fw;
		goto free_network;
	}

	pr_debug("Unload network %d\n", network->id);
	/* prepare IPC packet for UNLOAD */
	unload_packet.header.cmd_type = NPU_IPC_CMD_UNLOAD;
	unload_packet.header.size = sizeof(struct ipc_cmd_unload_pkt);
@@ -1190,23 +1335,55 @@ int32_t npu_host_unload_network(struct npu_client *client,

	if (ret) {
		pr_err("NPU_IPC_CMD_UNLOAD sent failed: %d\n", ret);
	} else if (!wait_for_completion_interruptible_timeout(
		/*
		 * If another command is running on this network,
		 * don't free_network now.
		 */
		if (ret == -EBUSY) {
			pr_err("Network is running, retry later\n");
			network_put(network);
			mutex_unlock(&host_ctx->lock);
			return ret;
		}
		goto free_network;
	}

	mutex_unlock(&host_ctx->lock);

	ret = wait_for_completion_interruptible_timeout(
		&network->cmd_done,
		(host_ctx->fw_dbg_mode & FW_DBG_MODE_INC_TIMEOUT) ?
		NW_DEBUG_TIMEOUT : NW_CMD_TIMEOUT)) {
		NW_DEBUG_TIMEOUT : NW_CMD_TIMEOUT);

	mutex_lock(&host_ctx->lock);

	if (!ret) {
		pr_err_ratelimited("npu: NPU_IPC_CMD_UNLOAD time out\n");
		network->cmd_pending = false;
		ret = -ETIMEDOUT;
	} else if (network->fw_error) {
		goto free_network;
	} else if (ret < 0) {
		pr_err("Wait for unload done interrupted by signal\n");
		network->cmd_pending = false;
		goto free_network;
	}

	if (network->fw_error) {
		ret = -EIO;
		pr_err("unload cmd returns with error\n");
		pr_err("fw is in error state during unload network\n");
	} else {
		ret = network->cmd_ret_status;
		pr_debug("unload network status %d", ret);
	}

skip_fw:
free_network:
	/*
	 * free the network on the kernel if the corresponding ACO
	 * handle is unloaded on the firmware side
	 */
	free_network(host_ctx, network->id);
	network_put(network);
	free_network(host_ctx, client, network->id);
	mutex_unlock(&host_ctx->lock);
	fw_deinit(npu_dev, false);
	return ret;
}
@@ -1223,20 +1400,35 @@ int32_t npu_host_exec_network(struct npu_client *client,
	struct npu_host_ctx *host_ctx = &npu_dev->host_ctx;
	bool async_ioctl = !!exec_ioctl->async;

	network = get_network_by_hdl(host_ctx, exec_ioctl->network_hdl);
	mutex_lock(&host_ctx->lock);
	network = get_network_by_hdl(host_ctx, client,
		exec_ioctl->network_hdl);

	if (!network)
	if (!network) {
		mutex_unlock(&host_ctx->lock);
		return -EINVAL;
	}

	if (!network->is_active) {
		pr_err("network is not active\n");
		ret = -EINVAL;
		goto exec_done;
	}

	if (network->fw_error)
		return -EIO;
	if (network->fw_error) {
		pr_err("fw is in error state\n");
		ret = -EIO;
		goto exec_done;
	}

	pr_debug("execute network %d\n", network->id);
	memset(&exec_packet, 0, sizeof(exec_packet));
	if (exec_ioctl->patching_required) {
		if ((exec_ioctl->input_layer_num != 1) ||
			(exec_ioctl->output_layer_num != 1)) {
			pr_err("Invalid input/output layer num\n");
			return -EINVAL;
			ret = -EINVAL;
			goto exec_done;
		}

		input_off = exec_ioctl->input_layers[0].buf_phys_addr;
@@ -1245,7 +1437,8 @@ int32_t npu_host_exec_network(struct npu_client *client,
		if (!npu_mem_verify_addr(client, input_off) ||
			!npu_mem_verify_addr(client, output_off)) {
			pr_err("Invalid patch buf address\n");
			return -EINVAL;
			ret = -EINVAL;
			goto exec_done;
		}

		exec_packet.patch_params.num_params = 2;
@@ -1270,27 +1463,46 @@ int32_t npu_host_exec_network(struct npu_client *client,

	if (ret) {
		pr_err("NPU_IPC_CMD_EXECUTE sent failed: %d\n", ret);
	} else if (async_ioctl) {
		goto exec_done;
	}

	if (async_ioctl) {
		pr_debug("Async ioctl, return now\n");
	} else if (!wait_for_completion_interruptible_timeout(
		goto exec_done;
	}

	mutex_unlock(&host_ctx->lock);

	ret = wait_for_completion_interruptible_timeout(
		&network->cmd_done,
		(host_ctx->fw_dbg_mode & FW_DBG_MODE_INC_TIMEOUT) ?
		NW_DEBUG_TIMEOUT : NW_CMD_TIMEOUT)) {
		NW_DEBUG_TIMEOUT : NW_CMD_TIMEOUT);

	mutex_lock(&host_ctx->lock);
	if (!ret) {
		pr_err_ratelimited("npu: NPU_IPC_CMD_EXECUTE time out\n");
		/* dump debug stats */
		npu_dump_debug_timeout_stats(npu_dev);
		network->cmd_pending = false;

		/* treat execution timed out as ssr */
		fw_deinit(npu_dev, true);
		ret = -ETIMEDOUT;
	} else if (network->fw_error) {
		goto exec_done;
	} else if (ret < 0) {
		pr_err("Wait for execution done interrupted by signal\n");
		network->cmd_pending = false;
		goto exec_done;
	}

	if (network->fw_error) {
		ret = -EIO;
		pr_err("execute cmd returns with error\n");
		pr_err("fw is in error state during execute network\n");
	} else {
		ret = network->cmd_ret_status;
		pr_debug("execution status %d", ret);
	}

exec_done:
	network_put(network);
	mutex_unlock(&host_ctx->lock);
	return ret;
}

@@ -1307,21 +1519,37 @@ int32_t npu_host_exec_network_v2(struct npu_client *client,
	bool async_ioctl = !!exec_ioctl->async;
	int i;

	network = get_network_by_hdl(host_ctx, exec_ioctl->network_hdl);
	mutex_lock(&host_ctx->lock);
	network = get_network_by_hdl(host_ctx, client,
		exec_ioctl->network_hdl);

	if (!network)
	if (!network) {
		mutex_unlock(&host_ctx->lock);
		return -EINVAL;
	}

	if (network->fw_error)
		return -EIO;
	if (!network->is_active) {
		pr_err("network is not active\n");
		ret = -EINVAL;
		goto exec_v2_done;
	}

	if (network->fw_error) {
		pr_err("fw is in error state\n");
		ret = -EIO;
		goto exec_v2_done;
	}

	pr_debug("execute_v2 network %d\n", network->id);
	num_patch_params = exec_ioctl->patch_buf_info_num;
	pkt_size = num_patch_params * sizeof(struct npu_patch_params_v2) +
		sizeof(*exec_packet);
	exec_packet = kzalloc(pkt_size, GFP_KERNEL);

	if (!exec_packet)
		return -ENOMEM;
	if (!exec_packet) {
		ret = -ENOMEM;
		goto exec_v2_done;
	}

	for (i = 0; i < num_patch_params; i++) {
		exec_packet->patch_params[i].id = patch_buf_info[i].buf_id;
@@ -1336,8 +1564,8 @@ int32_t npu_host_exec_network_v2(struct npu_client *client,
		if (!npu_mem_verify_addr(client,
			patch_buf_info[i].buf_phys_addr)) {
			pr_err("Invalid patch value\n");
			kfree(exec_packet);
			return -EINVAL;
			ret = -EINVAL;
			goto free_exec_packet;
		}
	}

@@ -1362,23 +1590,41 @@ int32_t npu_host_exec_network_v2(struct npu_client *client,

	if (ret) {
		pr_err("NPU_IPC_CMD_EXECUTE_V2 sent failed: %d\n", ret);
	} else if (async_ioctl) {
		goto free_exec_packet;
	}

	if (async_ioctl) {
		pr_debug("Async ioctl, return now\n");
	} else if (!wait_for_completion_interruptible_timeout(
		goto free_exec_packet;
	}

	mutex_unlock(&host_ctx->lock);

	ret = wait_for_completion_interruptible_timeout(
		&network->cmd_done,
		(host_ctx->fw_dbg_mode & FW_DBG_MODE_INC_TIMEOUT) ?
		NW_DEBUG_TIMEOUT : NW_CMD_TIMEOUT)) {
		NW_DEBUG_TIMEOUT : NW_CMD_TIMEOUT);

	mutex_lock(&host_ctx->lock);
	if (!ret) {
		pr_err_ratelimited("npu: NPU_IPC_CMD_EXECUTE_V2 time out\n");
		/* dump debug stats */
		npu_dump_debug_timeout_stats(npu_dev);
		network->cmd_pending = false;
		/* treat execution timed out as ssr */
		fw_deinit(npu_dev, true);
		ret = -ETIMEDOUT;
	} else if (network->fw_error) {
		goto free_exec_packet;
	} else if (ret < 0) {
		pr_err("Wait for execution_v2 done interrupted by signal\n");
		network->cmd_pending = false;
		goto free_exec_packet;
	}

	if (network->fw_error) {
		ret = -EIO;
		pr_err("execute cmd returns with error\n");
	} else {
		pr_err("fw is in error state during execute_v2 network\n");
		goto free_exec_packet;
	}

	ret = network->cmd_ret_status;
	if (!ret) {
		exec_ioctl->stats_buf_size = network->stats_buf_size;
@@ -1389,10 +1635,15 @@ int32_t npu_host_exec_network_v2(struct npu_client *client,
			pr_err("copy stats to user failed\n");
			exec_ioctl->stats_buf_size = 0;
		}
		}
	} else {
		pr_err("execution failed %d\n", ret);
	}

free_exec_packet:
	kfree(exec_packet);
exec_v2_done:
	network_put(network);
	mutex_unlock(&host_ctx->lock);
	return ret;
}

@@ -1418,14 +1669,22 @@ int32_t npu_host_loopback_test(struct npu_device *npu_dev)

	if (ret) {
		pr_err("NPU_IPC_CMD_LOOPBACK sent failed: %d\n", ret);
	} else if (!wait_for_completion_interruptible_timeout(
		goto loopback_exit;
	}

	ret = wait_for_completion_interruptible_timeout(
		&host_ctx->loopback_done,
		(host_ctx->fw_dbg_mode & FW_DBG_MODE_INC_TIMEOUT) ?
		NW_DEBUG_TIMEOUT : NW_CMD_TIMEOUT)) {
		NW_DEBUG_TIMEOUT : NW_CMD_TIMEOUT);

	if (!ret) {
		pr_err_ratelimited("npu: NPU_IPC_CMD_LOOPBACK time out\n");
		ret = -ETIMEDOUT;
	} else if (ret < 0) {
		pr_err("Wait for loopback done interrupted by signal\n");
	}

loopback_exit:
	fw_deinit(npu_dev, false);

	return ret;
+3 −0
Original line number Diff line number Diff line
@@ -54,7 +54,10 @@ struct npu_network {
	void *stats_buf;
	void __user *stats_buf_u;
	uint32_t stats_buf_size;
	uint32_t trans_id;
	atomic_t ref_cnt;
	bool is_valid;
	bool is_active;
	bool fw_error;
	bool cmd_pending;
	bool cmd_async;