Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit c9990ab3 authored by Jason Gunthorpe, committed by Doug Ledford
Browse files

RDMA/umem: Move all the ODP related stuff out of ucontext and into per_mm



This is the first step to make ODP use the owning_mm that is now part of
struct ib_umem.

Each ODP umem is linked to a single per_mm structure, which, in turn, is
linked to a single mm via the embedded mmu_notifier. This first patch
introduces the structure and reworks everything to use it.

This also needs to introduce tgid into the ib_ucontext_per_mm, as
get_user_pages_remote() requires the originating task for statistics
tracking.

Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Doug Ledford <dledford@redhat.com>
parent 597ecc5a
Loading
Loading
Loading
Loading
+68 −59
Original line number Diff line number Diff line
@@ -115,34 +115,35 @@ static void ib_umem_notifier_end_account(struct ib_umem_odp *umem_odp)
}

/* Account for a new mmu notifier in an ib_ucontext_per_mm. */
static void ib_ucontext_notifier_start_account(struct ib_ucontext *context)
static void
ib_ucontext_notifier_start_account(struct ib_ucontext_per_mm *per_mm)
{
	atomic_inc(&context->notifier_count);
	atomic_inc(&per_mm->notifier_count);
}

/* Account for a terminating mmu notifier in an ib_ucontext_per_mm.
 *
 * Must be called with the per_mm->umem_rwsem semaphore unlocked, since
 * the function takes the semaphore itself. */
static void ib_ucontext_notifier_end_account(struct ib_ucontext *context)
static void ib_ucontext_notifier_end_account(struct ib_ucontext_per_mm *per_mm)
{
	int zero_notifiers = atomic_dec_and_test(&context->notifier_count);
	int zero_notifiers = atomic_dec_and_test(&per_mm->notifier_count);

	if (zero_notifiers &&
	    !list_empty(&context->no_private_counters)) {
	    !list_empty(&per_mm->no_private_counters)) {
		/* No currently running mmu notifiers. Now is the chance to
		 * add private accounting to all previously added umems. */
		struct ib_umem_odp *odp_data, *next;

		/* Prevent concurrent mmu notifiers from working on the
		 * no_private_counters list. */
		down_write(&context->umem_rwsem);
		down_write(&per_mm->umem_rwsem);

		/* Read the notifier_count again, with the umem_rwsem
		 * semaphore taken for write. */
		if (!atomic_read(&context->notifier_count)) {
		if (!atomic_read(&per_mm->notifier_count)) {
			list_for_each_entry_safe(odp_data, next,
						 &context->no_private_counters,
						 &per_mm->no_private_counters,
						 no_private_counters) {
				mutex_lock(&odp_data->umem_mutex);
				odp_data->mn_counters_active = true;
@@ -152,7 +153,7 @@ static void ib_ucontext_notifier_end_account(struct ib_ucontext *context)
			}
		}

		up_write(&context->umem_rwsem);
		up_write(&per_mm->umem_rwsem);
	}
}

@@ -179,19 +180,20 @@ static int ib_umem_notifier_release_trampoline(struct ib_umem_odp *umem_odp,
static void ib_umem_notifier_release(struct mmu_notifier *mn,
				     struct mm_struct *mm)
{
	struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);
	struct ib_ucontext_per_mm *per_mm =
		container_of(mn, struct ib_ucontext_per_mm, mn);

	if (!context->invalidate_range)
	if (!per_mm->context->invalidate_range)
		return;

	ib_ucontext_notifier_start_account(context);
	down_read(&context->umem_rwsem);
	rbt_ib_umem_for_each_in_range(&context->umem_tree, 0,
	ib_ucontext_notifier_start_account(per_mm);
	down_read(&per_mm->umem_rwsem);
	rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, 0,
				      ULLONG_MAX,
				      ib_umem_notifier_release_trampoline,
				      true,
				      NULL);
	up_read(&context->umem_rwsem);
	up_read(&per_mm->umem_rwsem);
}

static int invalidate_page_trampoline(struct ib_umem_odp *item, u64 start,
@@ -217,23 +219,24 @@ static int ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
						    unsigned long end,
						    bool blockable)
{
	struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);
	struct ib_ucontext_per_mm *per_mm =
		container_of(mn, struct ib_ucontext_per_mm, mn);
	int ret;

	if (!context->invalidate_range)
	if (!per_mm->context->invalidate_range)
		return 0;

	if (blockable)
		down_read(&context->umem_rwsem);
	else if (!down_read_trylock(&context->umem_rwsem))
		down_read(&per_mm->umem_rwsem);
	else if (!down_read_trylock(&per_mm->umem_rwsem))
		return -EAGAIN;

	ib_ucontext_notifier_start_account(context);
	ret = rbt_ib_umem_for_each_in_range(&context->umem_tree, start,
	ib_ucontext_notifier_start_account(per_mm);
	ret = rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, start,
				      end,
				      invalidate_range_start_trampoline,
				      blockable, NULL);
	up_read(&context->umem_rwsem);
	up_read(&per_mm->umem_rwsem);

	return ret;
}
@@ -250,9 +253,10 @@ static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
						  unsigned long start,
						  unsigned long end)
{
	struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);
	struct ib_ucontext_per_mm *per_mm =
		container_of(mn, struct ib_ucontext_per_mm, mn);

	if (!context->invalidate_range)
	if (!per_mm->context->invalidate_range)
		return;

	/*
@@ -260,12 +264,12 @@ static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
	 * in ib_umem_notifier_invalidate_range_start so we shouldn't really block
	 * here. But this is ugly and fragile.
	 */
	down_read(&context->umem_rwsem);
	rbt_ib_umem_for_each_in_range(&context->umem_tree, start,
	down_read(&per_mm->umem_rwsem);
	rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, start,
				      end,
				      invalidate_range_end_trampoline, true, NULL);
	up_read(&context->umem_rwsem);
	ib_ucontext_notifier_end_account(context);
	up_read(&per_mm->umem_rwsem);
	ib_ucontext_notifier_end_account(per_mm);
}

static const struct mmu_notifier_ops ib_umem_notifiers = {
@@ -277,6 +281,7 @@ static const struct mmu_notifier_ops ib_umem_notifiers = {
struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
				      unsigned long addr, size_t size)
{
	struct ib_ucontext_per_mm *per_mm;
	struct ib_umem_odp *odp_data;
	struct ib_umem *umem;
	int pages = size >> PAGE_SHIFT;
@@ -292,6 +297,7 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
	umem->page_shift = PAGE_SHIFT;
	umem->writable   = 1;
	umem->is_odp = 1;
	odp_data->per_mm = per_mm = &context->per_mm;

	mutex_init(&odp_data->umem_mutex);
	init_completion(&odp_data->notifier_completion);
@@ -310,15 +316,15 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
		goto out_page_list;
	}

	down_write(&context->umem_rwsem);
	context->odp_mrs_count++;
	rbt_ib_umem_insert(&odp_data->interval_tree, &context->umem_tree);
	if (likely(!atomic_read(&context->notifier_count)))
	down_write(&per_mm->umem_rwsem);
	per_mm->odp_mrs_count++;
	rbt_ib_umem_insert(&odp_data->interval_tree, &per_mm->umem_tree);
	if (likely(!atomic_read(&per_mm->notifier_count)))
		odp_data->mn_counters_active = true;
	else
		list_add(&odp_data->no_private_counters,
			 &context->no_private_counters);
	up_write(&context->umem_rwsem);
			 &per_mm->no_private_counters);
	up_write(&per_mm->umem_rwsem);

	return odp_data;

@@ -334,6 +340,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
{
	struct ib_ucontext *context = umem_odp->umem.context;
	struct ib_umem *umem = &umem_odp->umem;
	struct ib_ucontext_per_mm *per_mm;
	int ret_val;
	struct pid *our_pid;
	struct mm_struct *mm = get_task_mm(current);
@@ -396,28 +403,30 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
	 * notification before the "current" task (and MM) is
	 * destroyed. We use the umem_rwsem semaphore to synchronize.
	 */
	down_write(&context->umem_rwsem);
	context->odp_mrs_count++;
	umem_odp->per_mm = per_mm = &context->per_mm;

	down_write(&per_mm->umem_rwsem);
	per_mm->odp_mrs_count++;
	if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
		rbt_ib_umem_insert(&umem_odp->interval_tree,
				   &context->umem_tree);
	if (likely(!atomic_read(&context->notifier_count)) ||
	    context->odp_mrs_count == 1)
				   &per_mm->umem_tree);
	if (likely(!atomic_read(&per_mm->notifier_count)) ||
	    per_mm->odp_mrs_count == 1)
		umem_odp->mn_counters_active = true;
	else
		list_add(&umem_odp->no_private_counters,
			 &context->no_private_counters);
	downgrade_write(&context->umem_rwsem);
			 &per_mm->no_private_counters);
	downgrade_write(&per_mm->umem_rwsem);

	if (context->odp_mrs_count == 1) {
	if (per_mm->odp_mrs_count == 1) {
		/*
		 * Note that at this point, no MMU notifier is running
		 * for this context!
		 * for this per_mm!
		 */
		atomic_set(&context->notifier_count, 0);
		INIT_HLIST_NODE(&context->mn.hlist);
		context->mn.ops = &ib_umem_notifiers;
		ret_val = mmu_notifier_register(&context->mn, mm);
		atomic_set(&per_mm->notifier_count, 0);
		INIT_HLIST_NODE(&per_mm->mn.hlist);
		per_mm->mn.ops = &ib_umem_notifiers;
		ret_val = mmu_notifier_register(&per_mm->mn, mm);
		if (ret_val) {
			pr_err("Failed to register mmu_notifier %d\n", ret_val);
			ret_val = -EBUSY;
@@ -425,7 +434,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
		}
	}

	up_read(&context->umem_rwsem);
	up_read(&per_mm->umem_rwsem);

	/*
	 * Note that doing an mmput can cause a notifier for the relevant mm.
@@ -437,7 +446,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
	return 0;

out_mutex:
	up_read(&context->umem_rwsem);
	up_read(&per_mm->umem_rwsem);
	vfree(umem_odp->dma_list);
out_page_list:
	vfree(umem_odp->page_list);
@@ -449,7 +458,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
{
	struct ib_umem *umem = &umem_odp->umem;
	struct ib_ucontext *context = umem->context;
	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;

	/*
	 * Ensure that no more pages are mapped in the umem.
@@ -460,11 +469,11 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
	ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem),
				    ib_umem_end(umem));

	down_write(&context->umem_rwsem);
	down_write(&per_mm->umem_rwsem);
	if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
		rbt_ib_umem_remove(&umem_odp->interval_tree,
				   &context->umem_tree);
	context->odp_mrs_count--;
				   &per_mm->umem_tree);
	per_mm->odp_mrs_count--;
	if (!umem_odp->mn_counters_active) {
		list_del(&umem_odp->no_private_counters);
		complete_all(&umem_odp->notifier_completion);
@@ -477,13 +486,13 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
	 * that since we are doing it atomically, no other user could register
	 * and unregister while we do the check.
	 */
	downgrade_write(&context->umem_rwsem);
	if (!context->odp_mrs_count) {
	downgrade_write(&per_mm->umem_rwsem);
	if (!per_mm->odp_mrs_count) {
		struct task_struct *owning_process = NULL;
		struct mm_struct *owning_mm        = NULL;

		owning_process = get_pid_task(context->tgid,
					      PIDTYPE_PID);
		owning_process =
			get_pid_task(umem_odp->umem.context->tgid, PIDTYPE_PID);
		if (owning_process == NULL)
			/*
			 * The process is already dead, notifier were removed
@@ -498,7 +507,7 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
			 * removed already.
			 */
			goto out_put_task;
		mmu_notifier_unregister(&context->mn, owning_mm);
		mmu_notifier_unregister(&per_mm->mn, owning_mm);

		mmput(owning_mm);

@@ -506,7 +515,7 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
		put_task_struct(owning_process);
	}
out:
	up_read(&context->umem_rwsem);
	up_read(&per_mm->umem_rwsem);

	vfree(umem_odp->dma_list);
	vfree(umem_odp->page_list);
+5 −4
Original line number Diff line number Diff line
@@ -124,10 +124,11 @@ ssize_t ib_uverbs_get_context(struct ib_uverbs_file *file,
	ucontext->cleanup_retryable = false;

#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
	ucontext->umem_tree = RB_ROOT_CACHED;
	init_rwsem(&ucontext->umem_rwsem);
	ucontext->odp_mrs_count = 0;
	INIT_LIST_HEAD(&ucontext->no_private_counters);
	ucontext->per_mm.umem_tree = RB_ROOT_CACHED;
	init_rwsem(&ucontext->per_mm.umem_rwsem);
	ucontext->per_mm.odp_mrs_count = 0;
	INIT_LIST_HEAD(&ucontext->per_mm.no_private_counters);
	ucontext->per_mm.context = ucontext;

	if (!(ib_dev->attrs.device_cap_flags & IB_DEVICE_ON_DEMAND_PAGING))
		ucontext->invalidate_range = NULL;
+25 −18
Original line number Diff line number Diff line
@@ -61,13 +61,21 @@ static int check_parent(struct ib_umem_odp *odp,
	return mr && mr->parent == parent && !odp->dying;
}

/*
 * Return the per_mm structure of the ODP umem backing @mr.
 *
 * Triggers WARN_ON() and returns NULL when @mr is NULL, has no umem, or
 * its umem is not flagged as ODP.
 */
struct ib_ucontext_per_mm *mr_to_per_mm(struct mlx5_ib_mr *mr)
{
	const bool not_odp_mr = !mr || !mr->umem || !mr->umem->is_odp;

	if (WARN_ON(not_odp_mr))
		return NULL;

	return to_ib_umem_odp(mr->umem)->per_mm;
}

static struct ib_umem_odp *odp_next(struct ib_umem_odp *odp)
{
	struct mlx5_ib_mr *mr = odp->private, *parent = mr->parent;
	struct ib_ucontext *ctx = odp->umem.context;
	struct ib_ucontext_per_mm *per_mm = odp->per_mm;
	struct rb_node *rb;

	down_read(&ctx->umem_rwsem);
	down_read(&per_mm->umem_rwsem);
	while (1) {
		rb = rb_next(&odp->interval_tree.rb);
		if (!rb)
@@ -79,19 +87,19 @@ static struct ib_umem_odp *odp_next(struct ib_umem_odp *odp)
not_found:
	odp = NULL;
end:
	up_read(&ctx->umem_rwsem);
	up_read(&per_mm->umem_rwsem);
	return odp;
}

static struct ib_umem_odp *odp_lookup(struct ib_ucontext *ctx,
				      u64 start, u64 length,
static struct ib_umem_odp *odp_lookup(u64 start, u64 length,
				      struct mlx5_ib_mr *parent)
{
	struct ib_ucontext_per_mm *per_mm = mr_to_per_mm(parent);
	struct ib_umem_odp *odp;
	struct rb_node *rb;

	down_read(&ctx->umem_rwsem);
	odp = rbt_ib_umem_lookup(&ctx->umem_tree, start, length);
	down_read(&per_mm->umem_rwsem);
	odp = rbt_ib_umem_lookup(&per_mm->umem_tree, start, length);
	if (!odp)
		goto end;

@@ -108,7 +116,7 @@ static struct ib_umem_odp *odp_lookup(struct ib_ucontext *ctx,
not_found:
	odp = NULL;
end:
	up_read(&ctx->umem_rwsem);
	up_read(&per_mm->umem_rwsem);
	return odp;
}

@@ -116,7 +124,6 @@ void mlx5_odp_populate_klm(struct mlx5_klm *pklm, size_t offset,
			   size_t nentries, struct mlx5_ib_mr *mr, int flags)
{
	struct ib_pd *pd = mr->ibmr.pd;
	struct ib_ucontext *ctx = pd->uobject->context;
	struct mlx5_ib_dev *dev = to_mdev(pd->device);
	struct ib_umem_odp *odp;
	unsigned long va;
@@ -131,7 +138,7 @@ void mlx5_odp_populate_klm(struct mlx5_klm *pklm, size_t offset,
		return;
	}

	odp = odp_lookup(ctx, offset * MLX5_IMR_MTT_SIZE,
	odp = odp_lookup(offset * MLX5_IMR_MTT_SIZE,
			 nentries * MLX5_IMR_MTT_SIZE, mr);

	for (i = 0; i < nentries; i++, pklm++) {
@@ -368,7 +375,6 @@ static struct mlx5_ib_mr *implicit_mr_alloc(struct ib_pd *pd,
static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr,
						u64 io_virt, size_t bcnt)
{
	struct ib_ucontext *ctx = mr->ibmr.pd->uobject->context;
	struct mlx5_ib_dev *dev = to_mdev(mr->ibmr.pd->device);
	struct ib_umem_odp *odp, *result = NULL;
	struct ib_umem_odp *odp_mr = to_ib_umem_odp(mr->umem);
@@ -377,7 +383,7 @@ static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr,
	struct mlx5_ib_mr *mtt;

	mutex_lock(&odp_mr->umem_mutex);
	odp = odp_lookup(ctx, addr, 1, mr);
	odp = odp_lookup(addr, 1, mr);

	mlx5_ib_dbg(dev, "io_virt:%llx bcnt:%zx addr:%llx odp:%p\n",
		    io_virt, bcnt, addr, odp);
@@ -387,7 +393,8 @@ static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr,
		if (nentries)
			nentries++;
	} else {
		odp = ib_alloc_odp_umem(ctx, addr, MLX5_IMR_MTT_SIZE);
		odp = ib_alloc_odp_umem(odp_mr->umem.context, addr,
					MLX5_IMR_MTT_SIZE);
		if (IS_ERR(odp)) {
			mutex_unlock(&odp_mr->umem_mutex);
			return ERR_CAST(odp);
@@ -486,12 +493,12 @@ static int mr_leaf_free(struct ib_umem_odp *umem_odp, u64 start, u64 end,

void mlx5_ib_free_implicit_mr(struct mlx5_ib_mr *imr)
{
	struct ib_ucontext *ctx = imr->ibmr.pd->uobject->context;
	struct ib_ucontext_per_mm *per_mm = mr_to_per_mm(imr);

	down_read(&ctx->umem_rwsem);
	rbt_ib_umem_for_each_in_range(&ctx->umem_tree, 0, ULLONG_MAX,
	down_read(&per_mm->umem_rwsem);
	rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, 0, ULLONG_MAX,
				      mr_leaf_free, true, imr);
	up_read(&ctx->umem_rwsem);
	up_read(&per_mm->umem_rwsem);

	wait_event(imr->q_leaf_free, !atomic_read(&imr->num_leaf_free));
}
+2 −0
Original line number Diff line number Diff line
@@ -44,6 +44,8 @@ struct umem_odp_node {

struct ib_umem_odp {
	struct ib_umem umem;
	struct ib_ucontext_per_mm *per_mm;

	/*
	 * An array of the pages included in the on-demand paging umem.
	 * Indices of pages that are currently not mapped into the device will
+20 −12
Original line number Diff line number Diff line
@@ -1488,6 +1488,25 @@ struct ib_rdmacg_object {
#endif
};

#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
/*
 * On-demand paging (ODP) bookkeeping that is embedded in struct
 * ib_ucontext as its per_mm member. Every ODP umem points at one of
 * these through ib_umem_odp->per_mm.
 */
struct ib_ucontext_per_mm {
	/* Back-pointer to the ucontext that embeds this structure. */
	struct ib_ucontext *context;

	/* Interval tree holding the ODP umems attached to this per_mm. */
	struct rb_root_cached umem_tree;
	/*
	 * Protects umem_tree, as well as odp_mrs_count and
	 * mmu notifiers registration.
	 */
	struct rw_semaphore umem_rwsem;

	/*
	 * The mmu notifier registered against the mm; the notifier callbacks
	 * recover the per_mm from it with container_of().
	 */
	struct mmu_notifier mn;
	/*
	 * Count of mmu notifier invalidations currently in progress:
	 * incremented when one starts and decremented when it ends.
	 */
	atomic_t notifier_count;
	/* A list of umems that don't have private mmu notifier counters yet. */
	struct list_head no_private_counters;
	/*
	 * Number of ODP umems attached. The mmu notifier is registered when
	 * this becomes 1 and unregistered when it drops back to 0.
	 */
	unsigned int odp_mrs_count;
};
#endif

struct ib_ucontext {
	struct ib_device       *device;
	struct ib_uverbs_file  *ufile;
@@ -1502,20 +1521,9 @@ struct ib_ucontext {

	struct pid             *tgid;
#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
	struct rb_root_cached   umem_tree;
	/*
	 * Protects .umem_rbroot and tree, as well as odp_mrs_count and
	 * mmu notifiers registration.
	 */
	struct rw_semaphore	umem_rwsem;
	void (*invalidate_range)(struct ib_umem_odp *umem_odp,
				 unsigned long start, unsigned long end);

	struct mmu_notifier	mn;
	atomic_t		notifier_count;
	/* A list of umems that don't have private mmu notifier counters yet. */
	struct list_head	no_private_counters;
	int                     odp_mrs_count;
	struct ib_ucontext_per_mm per_mm;
#endif

	struct ib_rdmacg_object	cg_obj;