Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 8043e5b0 authored by Narsinga Rao Chella's avatar Narsinga Rao Chella
Browse files

ipc: apr: Fix apr buffer override issue



Need use one global lock to prevent multiple threads visit single hab
socket at same time, as hab can't guarantee that the coming message is
received by the right thread.

Change-Id: Ie63a48cd9518327c9e4078000e613767ee1321b5
Signed-off-by: default avatarNarsinga Rao Chella <nrchella@codeaurora.org>
parent 0fa983ff
Loading
Loading
Loading
Loading
+13 −22
Original line number Diff line number Diff line
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (c) 2010-2014, 2016-2020 The Linux Foundation. All rights reserved.
* Copyright (c) 2010-2014, 2016-2021 The Linux Foundation. All rights reserved.
*/

#include <linux/kernel.h>
@@ -122,7 +122,6 @@ struct apr_svc_table {
 *    apr handle and store in svc tbl.
 */

static struct mutex m_lock_tbl_qdsp6;

static struct apr_svc_table svc_tbl_qdsp6[] = {
	{
@@ -214,7 +213,6 @@ static struct apr_svc_table svc_tbl_qdsp6[] = {
	},
};

static struct mutex m_lock_tbl_voice;

static struct apr_svc_table svc_tbl_voice[] = {
	{
@@ -635,10 +633,10 @@ static int apr_vm_get_svc(const char *svc_name, int domain_id, int *client_id,
	int i;
	int size;
	struct apr_svc_table *tbl;
	struct mutex *lock;
	struct aprv2_vm_cmd_register_rsp_t apr_rsp;
	uint32_t apr_len;
	int ret = 0;
	unsigned long flags;
	struct {
		uint32_t cmd_id;
		struct aprv2_vm_cmd_register_t reg_cmd;
@@ -647,14 +645,11 @@ static int apr_vm_get_svc(const char *svc_name, int domain_id, int *client_id,
	if (domain_id == APR_DOMAIN_ADSP) {
		tbl = svc_tbl_qdsp6;
		size = ARRAY_SIZE(svc_tbl_qdsp6);
		lock = &m_lock_tbl_qdsp6;
	} else {
		tbl = svc_tbl_voice;
		size = ARRAY_SIZE(svc_tbl_voice);
		lock = &m_lock_tbl_voice;
	}

	mutex_lock(lock);
	spin_lock_irqsave(&hab_tx_lock, flags);
	for (i = 0; i < size; i++) {
		if (!strcmp(svc_name, tbl[i].name)) {
			*client_id = tbl[i].client_id;
@@ -678,7 +673,7 @@ static int apr_vm_get_svc(const char *svc_name, int domain_id, int *client_id,
				if (ret) {
					pr_err("%s: habmm_socket_send failed %d\n",
						__func__, ret);
					mutex_unlock(lock);
					spin_unlock_irqrestore(&hab_tx_lock, flags);
					return ret;
				}
				/* wait for response */
@@ -690,14 +685,14 @@ static int apr_vm_get_svc(const char *svc_name, int domain_id, int *client_id,
				if (ret) {
					pr_err("%s: apr_vm_nb_receive failed %d\n",
						__func__, ret);
					mutex_unlock(lock);
					spin_unlock_irqrestore(&hab_tx_lock, flags);
					return ret;
				}
				if (apr_rsp.status) {
					pr_err("%s: apr_vm_nb_receive status %d\n",
						__func__, apr_rsp.status);
					ret = apr_rsp.status;
					mutex_unlock(lock);
					spin_unlock_irqrestore(&hab_tx_lock, flags);
					return ret;
				}
				/* update svc table */
@@ -711,7 +706,7 @@ static int apr_vm_get_svc(const char *svc_name, int domain_id, int *client_id,
			break;
		}
	}
	mutex_unlock(lock);
	spin_unlock_irqrestore(&hab_tx_lock, flags);

	pr_debug("%s: svc_name = %s client_id = %d domain_id = %d\n",
		 __func__, svc_name, *client_id, domain_id);
@@ -731,10 +726,10 @@ static int apr_vm_rel_svc(int domain_id, int svc_id, int handle)
	int i;
	int size;
	struct apr_svc_table *tbl;
	struct mutex *lock;
	struct aprv2_vm_cmd_deregister_rsp_t apr_rsp;
	uint32_t apr_len;
	int ret = 0;
	unsigned long flags;
	struct {
		uint32_t cmd_id;
		struct aprv2_vm_cmd_deregister_t dereg_cmd;
@@ -743,14 +738,12 @@ static int apr_vm_rel_svc(int domain_id, int svc_id, int handle)
	if (domain_id == APR_DOMAIN_ADSP) {
		tbl = svc_tbl_qdsp6;
		size = ARRAY_SIZE(svc_tbl_qdsp6);
		lock = &m_lock_tbl_qdsp6;
	} else {
		tbl = svc_tbl_voice;
		size = ARRAY_SIZE(svc_tbl_voice);
		lock = &m_lock_tbl_voice;
	}

	mutex_lock(lock);
	spin_lock_irqsave(&hab_tx_lock, flags);
	for (i = 0; i < size; i++) {
		if (tbl[i].id == svc_id && tbl[i].handle == handle) {
			/* need to deregister a service */
@@ -790,8 +783,7 @@ static int apr_vm_rel_svc(int domain_id, int svc_id, int handle)
			break;
		}
	}
	mutex_unlock(lock);

	spin_unlock_irqrestore(&hab_tx_lock, flags);
	if (i == size) {
		pr_err("%s: APR: Wrong svc id %d handle %d\n",
				__func__, svc_id, handle);
@@ -861,7 +853,7 @@ int apr_send_pkt(void *handle, uint32_t *buf)
		return -ENETRESET;
	}

	spin_lock_irqsave(&svc->w_lock, flags);
	spin_lock_irqsave(&hab_tx_lock, flags);
	if (!svc->id || !svc->vm_handle) {
		pr_err("APR: Still service is not yet opened\n");
		ret = -EINVAL;
@@ -930,7 +922,7 @@ int apr_send_pkt(void *handle, uint32_t *buf)
	ret = hdr->pkt_size;

done:
	spin_unlock_irqrestore(&svc->w_lock, flags);
	spin_unlock_irqrestore(&hab_tx_lock, flags);
	return ret;
}
EXPORT_SYMBOL(apr_send_pkt);
@@ -1413,7 +1405,6 @@ static int apr_probe(struct platform_device *pdev)
			mutex_init(&client[i][j].m_lock);
			for (k = 0; k < APR_SVC_MAX; k++) {
				mutex_init(&client[i][j].svc[k].m_lock);
				spin_lock_init(&client[i][j].svc[k].w_lock);
			}
		}
	spin_lock(&apr_priv->apr_lock);