
Commit 22d79c9a authored by Jason Gunthorpe

RDMA/odp: Consolidate umem_odp initialization

This is done in two different places; consolidate all of the post-allocation
initialization into a single function.

Link: https://lore.kernel.org/r/20190819111710.18440-5-leon@kernel.org
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
parent fd7dbf03
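
The shape of the refactor is easy to state before reading the diff: both allocation paths now funnel into a single post-allocation helper, ib_init_umem_odp(), and the implicit-ODP child path hands over the per_mm it already owns so the helper can skip the lookup. Below is a minimal, self-contained userspace C sketch of that pattern; the stub types and the stand-in functions init_umem_odp() and get_per_mm() only loosely mirror the kernel names, and the bodies are simplified for illustration, so this is not the kernel code.

#include <stdlib.h>

/* Stub stand-ins for the kernel structures; illustration only. */
struct per_mm {
	int odp_mrs_count;
};

struct umem_odp {
	struct per_mm *per_mm;
	size_t npages;
	void **page_list;
};

/* Stand-in for get_per_mm(): the kernel searches ctx->per_mm_list and
 * allocates a fresh per_mm on a miss; here we always allocate. */
static struct per_mm *get_per_mm(struct umem_odp *u)
{
	(void)u;
	return calloc(1, sizeof(struct per_mm));
}

/* One helper owns all post-allocation setup, cf. ib_init_umem_odp() in
 * the diff. A NULL per_mm means "look one up"; a non-NULL per_mm is
 * reused, which is what ib_alloc_odp_umem() does with root->per_mm. */
static int init_umem_odp(struct umem_odp *u, struct per_mm *per_mm)
{
	if (u->npages) {
		u->page_list = calloc(u->npages, sizeof(*u->page_list));
		if (!u->page_list)
			return -1;
	}

	if (!per_mm) {
		per_mm = get_per_mm(u);
		if (!per_mm) {
			free(u->page_list);
			return -1;
		}
	}
	u->per_mm = per_mm;
	per_mm->odp_mrs_count++;	/* refcounting lives in one place */
	return 0;
}

int main(void)
{
	/* Path 1: no per_mm in hand, cf. ib_umem_odp_get() passing NULL. */
	struct umem_odp a = { .npages = 4 };
	if (init_umem_odp(&a, NULL))
		return 1;

	/* Path 2: a child reusing its parent's per_mm, cf.
	 * ib_alloc_odp_umem() passing root->per_mm. */
	struct umem_odp b = { .npages = 2 };
	if (init_umem_odp(&b, a.per_mm))
		return 1;

	free(b.page_list);
	free(a.page_list);
	free(a.per_mm);		/* shared by both umems in this demo */
	return 0;
}

The design point the patch carries is the same as in the sketch: one helper owns the page-list allocation, the per_mm refcounting and the error unwind, so the two callers can no longer drift apart.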
+86 −114
@@ -171,23 +171,6 @@ static const struct mmu_notifier_ops ib_umem_notifiers = {
 	.invalidate_range_end       = ib_umem_notifier_invalidate_range_end,
 };
 
-static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp)
-{
-	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
-
-	down_write(&per_mm->umem_rwsem);
-	/*
-	 * Note that the representation of the intervals in the interval tree
-	 * considers the ending point as contained in the interval, while the
-	 * function ib_umem_end returns the first address which is not
-	 * contained in the umem.
-	 */
-	umem_odp->interval_tree.start = ib_umem_start(umem_odp);
-	umem_odp->interval_tree.last = ib_umem_end(umem_odp) - 1;
-	interval_tree_insert(&umem_odp->interval_tree, &per_mm->umem_tree);
-	up_write(&per_mm->umem_rwsem);
-}
-
 static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
 {
 	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
@@ -237,33 +220,23 @@ static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
 	return ERR_PTR(ret);
 }
 
-static int get_per_mm(struct ib_umem_odp *umem_odp)
+static struct ib_ucontext_per_mm *get_per_mm(struct ib_umem_odp *umem_odp)
 {
 	struct ib_ucontext *ctx = umem_odp->umem.context;
 	struct ib_ucontext_per_mm *per_mm;
 
+	lockdep_assert_held(&ctx->per_mm_list_lock);
+
 	/*
 	 * Generally speaking we expect only one or two per_mm in this list,
 	 * so no reason to optimize this search today.
 	 */
-	mutex_lock(&ctx->per_mm_list_lock);
 	list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
 		if (per_mm->mm == umem_odp->umem.owning_mm)
-			goto found;
-	}
-
-	per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm);
-	if (IS_ERR(per_mm)) {
-		mutex_unlock(&ctx->per_mm_list_lock);
-		return PTR_ERR(per_mm);
+			return per_mm;
 	}
 
-found:
-	umem_odp->per_mm = per_mm;
-	per_mm->odp_mrs_count++;
-	mutex_unlock(&ctx->per_mm_list_lock);
-
-	return 0;
+	return alloc_per_mm(ctx, umem_odp->umem.owning_mm);
 }
 
 static void free_per_mm(struct rcu_head *rcu)
@@ -304,79 +277,114 @@ static void put_per_mm(struct ib_umem_odp *umem_odp)
 	mmu_notifier_call_srcu(&per_mm->rcu, free_per_mm);
 }
 
-struct ib_umem_odp *ib_alloc_odp_umem(struct ib_umem_odp *root,
-				      unsigned long addr, size_t size)
+static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp,
+				   struct ib_ucontext_per_mm *per_mm)
 {
-	struct ib_ucontext_per_mm *per_mm = root->per_mm;
-	struct ib_ucontext *ctx = per_mm->context;
-	struct ib_umem_odp *odp_data;
-	struct ib_umem *umem;
-	int pages = size >> PAGE_SHIFT;
+	struct ib_ucontext *ctx = umem_odp->umem.context;
 	int ret;
 
-	if (!size)
-		return ERR_PTR(-EINVAL);
+	umem_odp->umem.is_odp = 1;
+	if (!umem_odp->is_implicit_odp) {
+		size_t pages = ib_umem_odp_num_pages(umem_odp);
 
-	odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
-	if (!odp_data)
-		return ERR_PTR(-ENOMEM);
-	umem = &odp_data->umem;
-	umem->context    = ctx;
-	umem->length     = size;
-	umem->address    = addr;
-	odp_data->page_shift = PAGE_SHIFT;
-	umem->writable   = root->umem.writable;
-	umem->is_odp = 1;
-	odp_data->per_mm = per_mm;
-	umem->owning_mm  = per_mm->mm;
-	mmgrab(umem->owning_mm);
+		if (!pages)
+			return -EINVAL;
 
-	mutex_init(&odp_data->umem_mutex);
-	init_completion(&odp_data->notifier_completion);
+		/*
+		 * Note that the representation of the intervals in the
+		 * interval tree considers the ending point as contained in
+		 * the interval, while the function ib_umem_end returns the
+		 * first address which is not contained in the umem.
+		 */
+		umem_odp->interval_tree.start = ib_umem_start(umem_odp);
+		umem_odp->interval_tree.last = ib_umem_end(umem_odp) - 1;
 
-	odp_data->page_list =
-		vzalloc(array_size(pages, sizeof(*odp_data->page_list)));
-	if (!odp_data->page_list) {
-		ret = -ENOMEM;
-		goto out_odp_data;
-	}
+		umem_odp->page_list = vzalloc(
+			array_size(sizeof(*umem_odp->page_list), pages));
+		if (!umem_odp->page_list)
+			return -ENOMEM;
 
-	odp_data->dma_list =
-		vzalloc(array_size(pages, sizeof(*odp_data->dma_list)));
-	if (!odp_data->dma_list) {
-		ret = -ENOMEM;
-		goto out_page_list;
+		umem_odp->dma_list =
+			vzalloc(array_size(sizeof(*umem_odp->dma_list), pages));
+		if (!umem_odp->dma_list) {
+			ret = -ENOMEM;
+			goto out_page_list;
+		}
 	}
 
-	/*
-	 * Caller must ensure that the umem_odp that the per_mm came from
-	 * cannot be freed during the call to ib_alloc_odp_umem.
-	 */
 	mutex_lock(&ctx->per_mm_list_lock);
+	if (!per_mm) {
+		per_mm = get_per_mm(umem_odp);
+		if (IS_ERR(per_mm)) {
+			ret = PTR_ERR(per_mm);
+			goto out_unlock;
+		}
+	}
+	umem_odp->per_mm = per_mm;
 	per_mm->odp_mrs_count++;
 	mutex_unlock(&ctx->per_mm_list_lock);
-	add_umem_to_per_mm(odp_data);
 
-	return odp_data;
+	mutex_init(&umem_odp->umem_mutex);
+	init_completion(&umem_odp->notifier_completion);
 
+	if (!umem_odp->is_implicit_odp) {
+		down_write(&per_mm->umem_rwsem);
+		interval_tree_insert(&umem_odp->interval_tree,
+				     &per_mm->umem_tree);
+		up_write(&per_mm->umem_rwsem);
+	}
+
+	return 0;
+
+out_unlock:
+	mutex_unlock(&ctx->per_mm_list_lock);
+	vfree(umem_odp->dma_list);
 out_page_list:
-	vfree(odp_data->page_list);
-out_odp_data:
-	mmdrop(umem->owning_mm);
-	kfree(odp_data);
-	return ERR_PTR(ret);
+	vfree(umem_odp->page_list);
+	return ret;
+}
+
+struct ib_umem_odp *ib_alloc_odp_umem(struct ib_umem_odp *root,
+				      unsigned long addr, size_t size)
+{
+	/*
+	 * Caller must ensure that root cannot be freed during the call to
+	 * ib_alloc_odp_umem.
+	 */
+	struct ib_umem_odp *odp_data;
+	struct ib_umem *umem;
+	int ret;
+
+	odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
+	if (!odp_data)
+		return ERR_PTR(-ENOMEM);
+	umem = &odp_data->umem;
+	umem->context    = root->umem.context;
+	umem->length     = size;
+	umem->address    = addr;
+	umem->writable   = root->umem.writable;
+	umem->owning_mm  = root->umem.owning_mm;
+	odp_data->page_shift = PAGE_SHIFT;
+
+	ret = ib_init_umem_odp(odp_data, root->per_mm);
+	if (ret) {
+		kfree(odp_data);
+		return ERR_PTR(ret);
+	}
+
+	mmgrab(umem->owning_mm);
+
+	return odp_data;
 }
 EXPORT_SYMBOL(ib_alloc_odp_umem);
 
 int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 {
-	struct ib_umem *umem = &umem_odp->umem;
 	/*
 	 * NOTE: This must called in a process context where umem->owning_mm
 	 * == current->mm
 	 */
-	struct mm_struct *mm = umem->owning_mm;
-	int ret_val;
+	struct mm_struct *mm = umem_odp->umem.owning_mm;
 
 	if (umem_odp->umem.address == 0 && umem_odp->umem.length == 0)
 		umem_odp->is_implicit_odp = 1;
@@ -397,43 +405,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 		up_read(&mm->mmap_sem);
 	}
 
-	mutex_init(&umem_odp->umem_mutex);
-
-	init_completion(&umem_odp->notifier_completion);
-
-	if (!umem_odp->is_implicit_odp) {
-		if (!ib_umem_odp_num_pages(umem_odp))
-			return -EINVAL;
-
-		umem_odp->page_list =
-			vzalloc(array_size(sizeof(*umem_odp->page_list),
-					   ib_umem_odp_num_pages(umem_odp)));
-		if (!umem_odp->page_list)
-			return -ENOMEM;
-
-		umem_odp->dma_list =
-			vzalloc(array_size(sizeof(*umem_odp->dma_list),
-					   ib_umem_odp_num_pages(umem_odp)));
-		if (!umem_odp->dma_list) {
-			ret_val = -ENOMEM;
-			goto out_page_list;
-		}
-	}
-
-	ret_val = get_per_mm(umem_odp);
-	if (ret_val)
-		goto out_dma_list;
-
-	if (!umem_odp->is_implicit_odp)
-		add_umem_to_per_mm(umem_odp);
-
-	return 0;
-
-out_dma_list:
-	vfree(umem_odp->dma_list);
-out_page_list:
-	vfree(umem_odp->page_list);
-	return ret_val;
+	return ib_init_umem_odp(umem_odp, NULL);
 }
 
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp)