
Commit 41b4deea authored by Jason Gunthorpe, committed by Doug Ledford

RDMA/umem: Make ib_umem_odp into a sub structure of ib_umem



These two structures are linked together; use the container_of pattern
instead of a double allocation to make the code simpler and easier to
follow.
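
A minimal standalone sketch of the pattern, using hypothetical
base/derived struct names rather than the kernel's ib_umem types:
embedding the base struct as a member lets a single allocation cover
both structures, and container_of() recovers the outer structure from a
pointer to the embedded member, so no back-pointer field is needed. The
container_of() here is a simplified userspace stand-in for the kernel
macro.

	#include <stddef.h>
	#include <stdlib.h>

	/* Simplified stand-in for the kernel's container_of() */
	#define container_of(ptr, type, member) \
		((type *)((char *)(ptr) - offsetof(type, member)))

	struct base {
		unsigned long length;
	};

	struct derived {
		struct base base;	/* embedded, not a pointer */
		int extra_state;
	};

	int main(void)
	{
		/* One allocation, sized for the outer structure... */
		struct derived *d = calloc(1, sizeof(*d));
		if (!d)
			return 1;

		/* ...hands out a pointer to the embedded base... */
		struct base *b = &d->base;

		/* ...and container_of() maps it back to the wrapper. */
		struct derived *d2 = container_of(b, struct derived, base);
		(void)d2;	/* d2 == d; no second allocation needed */

		free(d);
		return 0;
	}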

Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Doug Ledford <dledford@redhat.com>
parent b5231b01
drivers/infiniband/core/umem.c  +22 −14
@@ -108,34 +108,39 @@ struct ib_umem *ib_umem_get(struct ib_ucontext *context, unsigned long addr,
 	if (!can_do_mlock())
 		return ERR_PTR(-EPERM);

-	umem = kzalloc(sizeof *umem, GFP_KERNEL);
-	if (!umem)
-		return ERR_PTR(-ENOMEM);
+	if (access & IB_ACCESS_ON_DEMAND) {
+		umem = kzalloc(sizeof(struct ib_umem_odp), GFP_KERNEL);
+		if (!umem)
+			return ERR_PTR(-ENOMEM);
+		umem->odp_data = to_ib_umem_odp(umem);
+	} else {
+		umem = kzalloc(sizeof(*umem), GFP_KERNEL);
+		if (!umem)
+			return ERR_PTR(-ENOMEM);
+	}

 	umem->context    = context;
 	umem->length     = size;
 	umem->address    = addr;
 	umem->page_shift = PAGE_SHIFT;
 	umem->writable   = ib_access_writable(access);
+	umem->owning_mm = mm = current->mm;
+	mmgrab(mm);

 	if (access & IB_ACCESS_ON_DEMAND) {
-		ret = ib_umem_odp_get(context, umem, access);
+		ret = ib_umem_odp_get(to_ib_umem_odp(umem), access);
 		if (ret)
 			goto umem_kfree;
 		return umem;
 	}

-	umem->owning_mm = mm = current->mm;
-	mmgrab(mm);
-	umem->odp_data = NULL;
-
 	/* We assume the memory is from hugetlb until proved otherwise */
 	umem->hugetlb   = 1;

 	page_list = (struct page **) __get_free_page(GFP_KERNEL);
 	if (!page_list) {
 		ret = -ENOMEM;
-		goto umem_kfree_drop;
+		goto umem_kfree;
 	}

 	/*
@@ -226,12 +231,11 @@ struct ib_umem *ib_umem_get(struct ib_ucontext *context, unsigned long addr,
 	if (vma_list)
 		free_page((unsigned long) vma_list);
 	free_page((unsigned long) page_list);
-umem_kfree_drop:
-	if (ret)
-		mmdrop(umem->owning_mm);
 umem_kfree:
-	if (ret)
+	if (ret) {
+		mmdrop(umem->owning_mm);
 		kfree(umem);
+	}
 	return ret ? ERR_PTR(ret) : umem;
 }
 EXPORT_SYMBOL(ib_umem_get);
@@ -239,6 +243,9 @@ EXPORT_SYMBOL(ib_umem_get);
 static void __ib_umem_release_tail(struct ib_umem *umem)
 {
 	mmdrop(umem->owning_mm);
-	kfree(umem);
+	if (umem->odp_data)
+		kfree(to_ib_umem_odp(umem));
+	else
+		kfree(umem);
 }

@@ -263,6 +270,7 @@ void ib_umem_release(struct ib_umem *umem)

 	if (umem->odp_data) {
 		ib_umem_odp_release(to_ib_umem_odp(umem));
+		__ib_umem_release_tail(umem);
 		return;
 	}

drivers/infiniband/core/umem_odp.c  +30 −49
@@ -58,7 +58,7 @@ static u64 node_start(struct umem_odp_node *n)
 	struct ib_umem_odp *umem_odp =
 			container_of(n, struct ib_umem_odp, interval_tree);

-	return ib_umem_start(umem_odp->umem);
+	return ib_umem_start(&umem_odp->umem);
 }

 /* Note that the representation of the intervals in the interval tree
@@ -71,7 +71,7 @@ static u64 node_last(struct umem_odp_node *n)
 	struct ib_umem_odp *umem_odp =
 			container_of(n, struct ib_umem_odp, interval_tree);

-	return ib_umem_end(umem_odp->umem) - 1;
+	return ib_umem_end(&umem_odp->umem) - 1;
 }

 INTERVAL_TREE_DEFINE(struct umem_odp_node, rb, u64, __subtree_last,
@@ -159,7 +159,7 @@ static void ib_ucontext_notifier_end_account(struct ib_ucontext *context)
 static int ib_umem_notifier_release_trampoline(struct ib_umem_odp *umem_odp,
 					       u64 start, u64 end, void *cookie)
 {
-	struct ib_umem *umem = umem_odp->umem;
+	struct ib_umem *umem = &umem_odp->umem;

 	/*
 	 * Increase the number of notifiers running, to
@@ -198,7 +198,7 @@ static int invalidate_page_trampoline(struct ib_umem_odp *item, u64 start,
 				      u64 end, void *cookie)
 {
 	ib_umem_notifier_start_account(item);
-	item->umem->context->invalidate_range(item, start, start + PAGE_SIZE);
+	item->umem.context->invalidate_range(item, start, start + PAGE_SIZE);
 	ib_umem_notifier_end_account(item);
 	return 0;
 }
@@ -207,7 +207,7 @@ static int invalidate_range_start_trampoline(struct ib_umem_odp *item,
 					     u64 start, u64 end, void *cookie)
 {
 	ib_umem_notifier_start_account(item);
-	item->umem->context->invalidate_range(item, start, end);
+	item->umem.context->invalidate_range(item, start, end);
 	return 0;
 }

@@ -277,28 +277,21 @@ static const struct mmu_notifier_ops ib_umem_notifiers = {
 struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
 				      unsigned long addr, size_t size)
 {
-	struct ib_umem *umem;
 	struct ib_umem_odp *odp_data;
+	struct ib_umem *umem;
 	int pages = size >> PAGE_SHIFT;
 	int ret;

-	umem = kzalloc(sizeof(*umem), GFP_KERNEL);
-	if (!umem)
+	odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
+	if (!odp_data)
 		return ERR_PTR(-ENOMEM);
-
+	umem = &odp_data->umem;
 	umem->context    = context;
 	umem->length     = size;
 	umem->address    = addr;
 	umem->page_shift = PAGE_SHIFT;
 	umem->writable   = 1;

-	odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
-	if (!odp_data) {
-		ret = -ENOMEM;
-		goto out_umem;
-	}
-	odp_data->umem = umem;
-
 	mutex_init(&odp_data->umem_mutex);
 	init_completion(&odp_data->notifier_completion);

@@ -334,15 +327,14 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
 	vfree(odp_data->page_list);
 out_odp_data:
 	kfree(odp_data);
-out_umem:
-	kfree(umem);
 	return ERR_PTR(ret);
 }
 EXPORT_SYMBOL(ib_alloc_odp_umem);

-int ib_umem_odp_get(struct ib_ucontext *context, struct ib_umem *umem,
-		    int access)
+int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 {
+	struct ib_ucontext *context = umem_odp->umem.context;
+	struct ib_umem *umem = &umem_odp->umem;
 	int ret_val;
 	struct pid *our_pid;
 	struct mm_struct *mm = get_task_mm(current);
@@ -378,30 +370,23 @@ int ib_umem_odp_get(struct ib_ucontext *context, struct ib_umem *umem,
 		goto out_mm;
 	}

-	umem->odp_data = kzalloc(sizeof(*umem->odp_data), GFP_KERNEL);
-	if (!umem->odp_data) {
-		ret_val = -ENOMEM;
-		goto out_mm;
-	}
-	umem->odp_data->umem = umem;
+	mutex_init(&umem_odp->umem_mutex);

-	mutex_init(&umem->odp_data->umem_mutex);
-
-	init_completion(&umem->odp_data->notifier_completion);
+	init_completion(&umem_odp->notifier_completion);

 	if (ib_umem_num_pages(umem)) {
-		umem->odp_data->page_list =
-			vzalloc(array_size(sizeof(*umem->odp_data->page_list),
+		umem_odp->page_list =
+			vzalloc(array_size(sizeof(*umem_odp->page_list),
 					   ib_umem_num_pages(umem)));
-		if (!umem->odp_data->page_list) {
+		if (!umem_odp->page_list) {
 			ret_val = -ENOMEM;
-			goto out_odp_data;
+			goto out_mm;
 		}

-		umem->odp_data->dma_list =
-			vzalloc(array_size(sizeof(*umem->odp_data->dma_list),
+		umem_odp->dma_list =
+			vzalloc(array_size(sizeof(*umem_odp->dma_list),
 					   ib_umem_num_pages(umem)));
-		if (!umem->odp_data->dma_list) {
+		if (!umem_odp->dma_list) {
 			ret_val = -ENOMEM;
 			goto out_page_list;
 		}
@@ -415,13 +400,13 @@ int ib_umem_odp_get(struct ib_ucontext *context, struct ib_umem *umem,
 	down_write(&context->umem_rwsem);
 	context->odp_mrs_count++;
 	if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
-		rbt_ib_umem_insert(&umem->odp_data->interval_tree,
+		rbt_ib_umem_insert(&umem_odp->interval_tree,
 				   &context->umem_tree);
 	if (likely(!atomic_read(&context->notifier_count)) ||
 	    context->odp_mrs_count == 1)
-		umem->odp_data->mn_counters_active = true;
+		umem_odp->mn_counters_active = true;
 	else
-		list_add(&umem->odp_data->no_private_counters,
+		list_add(&umem_odp->no_private_counters,
 			 &context->no_private_counters);
 	downgrade_write(&context->umem_rwsem);

@@ -454,11 +439,9 @@ int ib_umem_odp_get(struct ib_ucontext *context, struct ib_umem *umem,

 out_mutex:
 	up_read(&context->umem_rwsem);
-	vfree(umem->odp_data->dma_list);
+	vfree(umem_odp->dma_list);
 out_page_list:
-	vfree(umem->odp_data->page_list);
-out_odp_data:
-	kfree(umem->odp_data);
+	vfree(umem_odp->page_list);
 out_mm:
 	mmput(mm);
 	return ret_val;
@@ -466,7 +449,7 @@ int ib_umem_odp_get(struct ib_ucontext *context, struct ib_umem *umem,

 void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
 {
-	struct ib_umem *umem = umem_odp->umem;
+	struct ib_umem *umem = &umem_odp->umem;
 	struct ib_ucontext *context = umem->context;

 	/*
@@ -528,8 +511,6 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp)

 	vfree(umem_odp->dma_list);
 	vfree(umem_odp->page_list);
-	kfree(umem_odp);
-	kfree(umem);
 }

 /*
@@ -557,7 +538,7 @@ static int ib_umem_odp_map_dma_single_page(
 		u64 access_mask,
 		unsigned long current_seq)
 {
-	struct ib_umem *umem = umem_odp->umem;
+	struct ib_umem *umem = &umem_odp->umem;
 	struct ib_device *dev = umem->context->device;
 	dma_addr_t dma_addr;
 	int stored_page = 0;
@@ -643,7 +624,7 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
 			      u64 bcnt, u64 access_mask,
 			      unsigned long current_seq)
 {
-	struct ib_umem *umem = umem_odp->umem;
+	struct ib_umem *umem = &umem_odp->umem;
 	struct task_struct *owning_process  = NULL;
 	struct mm_struct   *owning_mm       = NULL;
 	struct page       **local_page_list = NULL;
@@ -759,7 +740,7 @@ EXPORT_SYMBOL(ib_umem_odp_map_dma_pages);
 void ib_umem_odp_unmap_dma_pages(struct ib_umem_odp *umem_odp, u64 virt,
 				 u64 bound)
 {
-	struct ib_umem *umem = umem_odp->umem;
+	struct ib_umem *umem = &umem_odp->umem;
 	int idx;
 	u64 addr;
 	struct ib_device *dev = umem->context->device;
drivers/infiniband/hw/mlx5/odp.c  +13 −13
@@ -64,7 +64,7 @@ static int check_parent(struct ib_umem_odp *odp,
 static struct ib_umem_odp *odp_next(struct ib_umem_odp *odp)
 {
 	struct mlx5_ib_mr *mr = odp->private, *parent = mr->parent;
-	struct ib_ucontext *ctx = odp->umem->context;
+	struct ib_ucontext *ctx = odp->umem.context;
 	struct rb_node *rb;

 	down_read(&ctx->umem_rwsem);
@@ -102,7 +102,7 @@ static struct ib_umem_odp *odp_lookup(struct ib_ucontext *ctx,
 		if (!rb)
 			goto not_found;
 		odp = rb_entry(rb, struct ib_umem_odp, interval_tree.rb);
-		if (ib_umem_start(odp->umem) > start + length)
+		if (ib_umem_start(&odp->umem) > start + length)
 			goto not_found;
 	}
 not_found:
@@ -137,7 +137,7 @@ void mlx5_odp_populate_klm(struct mlx5_klm *pklm, size_t offset,
 	for (i = 0; i < nentries; i++, pklm++) {
 		pklm->bcount = cpu_to_be32(MLX5_IMR_MTT_SIZE);
 		va = (offset + i) * MLX5_IMR_MTT_SIZE;
-		if (odp && odp->umem->address == va) {
+		if (odp && odp->umem.address == va) {
 			struct mlx5_ib_mr *mtt = odp->private;

 			pklm->key = cpu_to_be32(mtt->ibmr.lkey);
@@ -153,13 +153,13 @@ void mlx5_odp_populate_klm(struct mlx5_klm *pklm, size_t offset,
 static void mr_leaf_free_action(struct work_struct *work)
 {
 	struct ib_umem_odp *odp = container_of(work, struct ib_umem_odp, work);
-	int idx = ib_umem_start(odp->umem) >> MLX5_IMR_MTT_SHIFT;
+	int idx = ib_umem_start(&odp->umem) >> MLX5_IMR_MTT_SHIFT;
 	struct mlx5_ib_mr *mr = odp->private, *imr = mr->parent;

 	mr->parent = NULL;
 	synchronize_srcu(&mr->dev->mr_srcu);

-	ib_umem_release(odp->umem);
+	ib_umem_release(&odp->umem);
 	if (imr->live)
 		mlx5_ib_update_xlt(imr, idx, 1, 0,
 				   MLX5_IB_UPD_XLT_INDIRECT |
@@ -185,7 +185,7 @@ void mlx5_ib_invalidate_range(struct ib_umem_odp *umem_odp, unsigned long start,
 		pr_err("invalidation called on NULL umem or non-ODP umem\n");
 		return;
 	}
-	umem = umem_odp->umem;
+	umem = &umem_odp->umem;

 	mr = umem_odp->private;

@@ -392,16 +392,16 @@ static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr,
 			return ERR_CAST(odp);
 		}

-		mtt = implicit_mr_alloc(mr->ibmr.pd, odp->umem, 0,
+		mtt = implicit_mr_alloc(mr->ibmr.pd, &odp->umem, 0,
 					mr->access_flags);
 		if (IS_ERR(mtt)) {
 			mutex_unlock(&mr->umem->odp_data->umem_mutex);
-			ib_umem_release(odp->umem);
+			ib_umem_release(&odp->umem);
 			return ERR_CAST(mtt);
 		}

 		odp->private = mtt;
-		mtt->umem = odp->umem;
+		mtt->umem = &odp->umem;
 		mtt->mmkey.iova = addr;
 		mtt->parent = mr;
 		INIT_WORK(&odp->work, mr_leaf_free_action);
@@ -418,7 +418,7 @@ static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr,
 	addr += MLX5_IMR_MTT_SIZE;
 	if (unlikely(addr < io_virt + bcnt)) {
 		odp = odp_next(odp);
-		if (odp && odp->umem->address != addr)
+		if (odp && odp->umem.address != addr)
 			odp = NULL;
 		goto next_mr;
 	}
@@ -465,7 +465,7 @@ static int mr_leaf_free(struct ib_umem_odp *umem_odp, u64 start, u64 end,
 			void *cookie)
 {
 	struct mlx5_ib_mr *mr = umem_odp->private, *imr = cookie;
-	struct ib_umem *umem = umem_odp->umem;
+	struct ib_umem *umem = &umem_odp->umem;

 	if (mr->parent != imr)
 		return 0;
@@ -518,7 +518,7 @@ static int pagefault_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr,
 	}

 next_mr:
-	size = min_t(size_t, bcnt, ib_umem_end(odp->umem) - io_virt);
+	size = min_t(size_t, bcnt, ib_umem_end(&odp->umem) - io_virt);

 	page_shift = mr->umem->page_shift;
 	page_mask = ~(BIT(page_shift) - 1);
@@ -577,7 +577,7 @@ static int pagefault_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr,

 		io_virt += size;
 		next = odp_next(odp);
-		if (unlikely(!next || next->umem->address != io_virt)) {
+		if (unlikely(!next || next->umem.address != io_virt)) {
 			mlx5_ib_dbg(dev, "next implicit leaf removed at 0x%llx. got %p\n",
 				    io_virt, next);
 			return -EAGAIN;
include/rdma/ib_umem_odp.h  +4 −7
@@ -43,6 +43,7 @@ struct umem_odp_node {
 };

 struct ib_umem_odp {
+	struct ib_umem umem;
 	/*
 	 * An array of the pages included in the on-demand paging umem.
 	 * Indices of pages that are currently not mapped into the device will
@@ -72,7 +73,6 @@ struct ib_umem_odp {
 	/* A linked list of umems that don't have private mmu notifier
 	 * counters yet. */
 	struct list_head no_private_counters;
-	struct ib_umem		*umem;

 	/* Tree tracking */
 	struct umem_odp_node	interval_tree;
@@ -84,13 +84,12 @@ struct ib_umem_odp {

 static inline struct ib_umem_odp *to_ib_umem_odp(struct ib_umem *umem)
 {
-	return umem->odp_data;
+	return container_of(umem, struct ib_umem_odp, umem);
 }

 #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING

-int ib_umem_odp_get(struct ib_ucontext *context, struct ib_umem *umem,
-		    int access);
+int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access);
 struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
 				      unsigned long addr, size_t size);
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp);
@@ -158,9 +157,7 @@ static inline int ib_umem_mmu_notifier_retry(struct ib_umem_odp *umem_odp,

 #else /* CONFIG_INFINIBAND_ON_DEMAND_PAGING */

-static inline int ib_umem_odp_get(struct ib_ucontext *context,
-				  struct ib_umem *umem,
-				  int access)
+static inline int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 {
 	return -EINVAL;
 }