
Commit f27a0d50 authored by Jason Gunthorpe, committed by Doug Ledford

RDMA/umem: Use umem->owning_mm inside ODP



Since ODP had a single struct mmu_notifier located in the ucontext, it
could only handle a single MM at a time, and this prevented it from using
the new owning_mm system.

With the prior rework it is now simple to let ODP track multiple MMs per
ucontext, so finish the job: allocate the per_mm on a per-MM basis and
free it when the last umem is dropped from the ucontext.

As a side effect the new saner locking removes the lockdep splat about
nesting the umem_rwsem between mmu_notifier_unregister and
ib_umem_odp_release.

It also makes ODP work with multiple processes, across fork, etc.
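
For readers following along, the lifecycle described above boils down to a
lookup-or-allocate plus refcount pattern: each ucontext keeps a list of
per_mm entries, an umem takes a reference on the entry matching its
owning_mm (creating it on first use), and the entry is torn down when the
last umem drops its reference. The following is a minimal, self-contained
userspace sketch of that pattern only; the demo_ctx/demo_per_mm types and
helpers are invented for illustration and merely stand in for the kernel's
ib_ucontext, ib_ucontext_per_mm, get_per_mm() and put_per_mm(). Locking
and the mmu_notifier registration are noted in comments, not implemented.

/* Illustrative sketch, not kernel code. */
#include <stdio.h>
#include <stdlib.h>

struct demo_per_mm {
        void *mm;                    /* stands in for struct mm_struct * */
        unsigned int odp_mrs_count;  /* number of umems using this mm   */
        struct demo_per_mm *next;
};

struct demo_ctx {
        struct demo_per_mm *per_mm_list;
};

/* Find the entry for this mm, or allocate one on first use. */
static struct demo_per_mm *get_per_mm(struct demo_ctx *ctx, void *mm)
{
        struct demo_per_mm *p;

        for (p = ctx->per_mm_list; p; p = p->next)
                if (p->mm == mm)
                        goto found;

        p = calloc(1, sizeof(*p));
        if (!p)
                return NULL;
        p->mm = mm;
        /* The real code registers an mmu_notifier on mm at this point. */
        p->next = ctx->per_mm_list;
        ctx->per_mm_list = p;
found:
        p->odp_mrs_count++;
        return p;
}

/* Drop one umem's reference; free the entry when the last one goes. */
static void put_per_mm(struct demo_ctx *ctx, struct demo_per_mm *per_mm)
{
        struct demo_per_mm **pp;

        if (--per_mm->odp_mrs_count)
                return;

        for (pp = &ctx->per_mm_list; *pp; pp = &(*pp)->next) {
                if (*pp == per_mm) {
                        *pp = per_mm->next;
                        break;
                }
        }
        /* The real code unregisters the mmu_notifier here. */
        free(per_mm);
}

int main(void)
{
        struct demo_ctx ctx = { 0 };
        int mm_a, mm_b;  /* two fake "mm" identities */

        struct demo_per_mm *p1 = get_per_mm(&ctx, &mm_a);
        struct demo_per_mm *p2 = get_per_mm(&ctx, &mm_a); /* same entry */
        struct demo_per_mm *p3 = get_per_mm(&ctx, &mm_b); /* new entry  */

        printf("same entry for mm_a: %s\n", p1 == p2 ? "yes" : "no");
        put_per_mm(&ctx, p1);
        put_per_mm(&ctx, p2); /* last reference for mm_a -> entry freed */
        put_per_mm(&ctx, p3); /* last reference for mm_b -> entry freed */
        return 0;
}

In the patch below the list walk and the refcount are serialized by
per_mm_list_lock, and the per_mm additionally owns the mmu_notifier
registration and a tgid reference, both released by put_per_mm() once
odp_mrs_count reaches zero.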

Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Doug Ledford <dledford@redhat.com>
parent c9990ab3
+160 −141
@@ -278,10 +278,135 @@ static const struct mmu_notifier_ops ib_umem_notifiers = {
	.invalidate_range_end       = ib_umem_notifier_invalidate_range_end,
};

struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
				      unsigned long addr, size_t size)
static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp)
{
	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
	struct ib_umem *umem = &umem_odp->umem;

	down_write(&per_mm->umem_rwsem);
	if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
		rbt_ib_umem_insert(&umem_odp->interval_tree,
				   &per_mm->umem_tree);

	if (likely(!atomic_read(&per_mm->notifier_count)))
		umem_odp->mn_counters_active = true;
	else
		list_add(&umem_odp->no_private_counters,
			 &per_mm->no_private_counters);
	up_write(&per_mm->umem_rwsem);
}

static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
{
	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
	struct ib_umem *umem = &umem_odp->umem;

	down_write(&per_mm->umem_rwsem);
	if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
		rbt_ib_umem_remove(&umem_odp->interval_tree,
				   &per_mm->umem_tree);
	if (!umem_odp->mn_counters_active) {
		list_del(&umem_odp->no_private_counters);
		complete_all(&umem_odp->notifier_completion);
	}

	up_write(&per_mm->umem_rwsem);
}

static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
					       struct mm_struct *mm)
{
	struct ib_ucontext_per_mm *per_mm;
	int ret;

	per_mm = kzalloc(sizeof(*per_mm), GFP_KERNEL);
	if (!per_mm)
		return ERR_PTR(-ENOMEM);

	per_mm->context = ctx;
	per_mm->mm = mm;
	per_mm->umem_tree = RB_ROOT_CACHED;
	init_rwsem(&per_mm->umem_rwsem);
	INIT_LIST_HEAD(&per_mm->no_private_counters);

	rcu_read_lock();
	per_mm->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
	rcu_read_unlock();

	WARN_ON(mm != current->mm);

	per_mm->mn.ops = &ib_umem_notifiers;
	ret = mmu_notifier_register(&per_mm->mn, per_mm->mm);
	if (ret) {
		dev_err(&ctx->device->dev,
			"Failed to register mmu_notifier %d\n", ret);
		goto out_pid;
	}

	list_add(&per_mm->ucontext_list, &ctx->per_mm_list);
	return per_mm;

out_pid:
	put_pid(per_mm->tgid);
	kfree(per_mm);
	return ERR_PTR(ret);
}

static int get_per_mm(struct ib_umem_odp *umem_odp)
{
	struct ib_ucontext *ctx = umem_odp->umem.context;
	struct ib_ucontext_per_mm *per_mm;

	/*
	 * Generally speaking we expect only one or two per_mm in this list,
	 * so no reason to optimize this search today.
	 */
	mutex_lock(&ctx->per_mm_list_lock);
	list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
		if (per_mm->mm == umem_odp->umem.owning_mm)
			goto found;
	}

	per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm);
	if (IS_ERR(per_mm)) {
		mutex_unlock(&ctx->per_mm_list_lock);
		return PTR_ERR(per_mm);
	}

found:
	umem_odp->per_mm = per_mm;
	per_mm->odp_mrs_count++;
	mutex_unlock(&ctx->per_mm_list_lock);

	return 0;
}

void put_per_mm(struct ib_umem_odp *umem_odp)
{
	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
	struct ib_ucontext *ctx = umem_odp->umem.context;
	bool need_free;

	mutex_lock(&ctx->per_mm_list_lock);
	umem_odp->per_mm = NULL;
	per_mm->odp_mrs_count--;
	need_free = per_mm->odp_mrs_count == 0;
	if (need_free)
		list_del(&per_mm->ucontext_list);
	mutex_unlock(&ctx->per_mm_list_lock);

	if (!need_free)
		return;

	mmu_notifier_unregister(&per_mm->mn, per_mm->mm);
	put_pid(per_mm->tgid);
	kfree(per_mm);
}

struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm,
				      unsigned long addr, size_t size)
{
	struct ib_ucontext *ctx = per_mm->context;
	struct ib_umem_odp *odp_data;
	struct ib_umem *umem;
	int pages = size >> PAGE_SHIFT;
@@ -291,13 +416,13 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
	if (!odp_data)
		return ERR_PTR(-ENOMEM);
	umem = &odp_data->umem;
	umem->context    = context;
	umem->context    = ctx;
	umem->length     = size;
	umem->address    = addr;
	umem->page_shift = PAGE_SHIFT;
	umem->writable   = 1;
	umem->is_odp = 1;
	odp_data->per_mm = per_mm = &context->per_mm;
	odp_data->per_mm = per_mm;

	mutex_init(&odp_data->umem_mutex);
	init_completion(&odp_data->notifier_completion);
@@ -316,15 +441,14 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
		goto out_page_list;
	}

	down_write(&per_mm->umem_rwsem);
	/*
	 * Caller must ensure that the umem_odp that the per_mm came from
	 * cannot be freed during the call to ib_alloc_odp_umem.
	 */
	mutex_lock(&ctx->per_mm_list_lock);
	per_mm->odp_mrs_count++;
	rbt_ib_umem_insert(&odp_data->interval_tree, &per_mm->umem_tree);
	if (likely(!atomic_read(&per_mm->notifier_count)))
		odp_data->mn_counters_active = true;
	else
		list_add(&odp_data->no_private_counters,
			 &per_mm->no_private_counters);
	up_write(&per_mm->umem_rwsem);
	mutex_unlock(&ctx->per_mm_list_lock);
	add_umem_to_per_mm(odp_data);

	return odp_data;

@@ -338,15 +462,13 @@ EXPORT_SYMBOL(ib_alloc_odp_umem);

int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
{
	struct ib_ucontext *context = umem_odp->umem.context;
	struct ib_umem *umem = &umem_odp->umem;
	struct ib_ucontext_per_mm *per_mm;
	/*
	 * NOTE: This must called in a process context where umem->owning_mm
	 * == current->mm
	 */
	struct mm_struct *mm = umem->owning_mm;
	int ret_val;
	struct pid *our_pid;
	struct mm_struct *mm = get_task_mm(current);

	if (!mm)
		return -EINVAL;

	if (access & IB_ACCESS_HUGETLB) {
		struct vm_area_struct *vma;
@@ -366,16 +488,6 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
		umem->hugetlb = 0;
	}

	/* Prevent creating ODP MRs in child processes */
	rcu_read_lock();
	our_pid = get_task_pid(current->group_leader, PIDTYPE_PID);
	rcu_read_unlock();
	put_pid(our_pid);
	if (context->tgid != our_pid) {
		ret_val = -EINVAL;
		goto out_mm;
	}

	mutex_init(&umem_odp->umem_mutex);

	init_completion(&umem_odp->notifier_completion);
@@ -384,10 +496,8 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
		umem_odp->page_list =
			vzalloc(array_size(sizeof(*umem_odp->page_list),
					   ib_umem_num_pages(umem)));
		if (!umem_odp->page_list) {
			ret_val = -ENOMEM;
			goto out_mm;
		}
		if (!umem_odp->page_list)
			return -ENOMEM;

		umem_odp->dma_list =
			vzalloc(array_size(sizeof(*umem_odp->dma_list),
@@ -398,67 +508,23 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
		}
	}

	/*
	 * When using MMU notifiers, we will get a
	 * notification before the "current" task (and MM) is
	 * destroyed. We use the umem_rwsem semaphore to synchronize.
	 */
	umem_odp->per_mm = per_mm = &context->per_mm;

	down_write(&per_mm->umem_rwsem);
	per_mm->odp_mrs_count++;
	if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
		rbt_ib_umem_insert(&umem_odp->interval_tree,
				   &per_mm->umem_tree);
	if (likely(!atomic_read(&per_mm->notifier_count)) ||
	    per_mm->odp_mrs_count == 1)
		umem_odp->mn_counters_active = true;
	else
		list_add(&umem_odp->no_private_counters,
			 &per_mm->no_private_counters);
	downgrade_write(&per_mm->umem_rwsem);

	if (per_mm->odp_mrs_count == 1) {
		/*
		 * Note that at this point, no MMU notifier is running
		 * for this per_mm!
		 */
		atomic_set(&per_mm->notifier_count, 0);
		INIT_HLIST_NODE(&per_mm->mn.hlist);
		per_mm->mn.ops = &ib_umem_notifiers;
		ret_val = mmu_notifier_register(&per_mm->mn, mm);
		if (ret_val) {
			pr_err("Failed to register mmu_notifier %d\n", ret_val);
			ret_val = -EBUSY;
			goto out_mutex;
		}
	}

	up_read(&per_mm->umem_rwsem);
	ret_val = get_per_mm(umem_odp);
	if (ret_val)
		goto out_dma_list;
	add_umem_to_per_mm(umem_odp);

	/*
	 * Note that doing an mmput can cause a notifier for the relevant mm.
	 * If the notifier is called while we hold the umem_rwsem, this will
	 * cause a deadlock. Therefore, we release the reference only after we
	 * released the semaphore.
	 */
	mmput(mm);
	return 0;

out_mutex:
	up_read(&per_mm->umem_rwsem);
out_dma_list:
	vfree(umem_odp->dma_list);
out_page_list:
	vfree(umem_odp->page_list);
out_mm:
	mmput(mm);
	return ret_val;
}

void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
{
	struct ib_umem *umem = &umem_odp->umem;
	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;

	/*
	 * Ensure that no more pages are mapped in the umem.
@@ -469,54 +535,8 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
	ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem),
				    ib_umem_end(umem));

	down_write(&per_mm->umem_rwsem);
	if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
		rbt_ib_umem_remove(&umem_odp->interval_tree,
				   &per_mm->umem_tree);
	per_mm->odp_mrs_count--;
	if (!umem_odp->mn_counters_active) {
		list_del(&umem_odp->no_private_counters);
		complete_all(&umem_odp->notifier_completion);
	}

	/*
	 * Downgrade the lock to a read lock. This ensures that the notifiers
	 * (who lock the mutex for reading) will be able to finish, and we
	 * will be able to enventually obtain the mmu notifiers SRCU. Note
	 * that since we are doing it atomically, no other user could register
	 * and unregister while we do the check.
	 */
	downgrade_write(&per_mm->umem_rwsem);
	if (!per_mm->odp_mrs_count) {
		struct task_struct *owning_process = NULL;
		struct mm_struct *owning_mm        = NULL;

		owning_process =
			get_pid_task(umem_odp->umem.context->tgid, PIDTYPE_PID);
		if (owning_process == NULL)
			/*
			 * The process is already dead, notifier were removed
			 * already.
			 */
			goto out;

		owning_mm = get_task_mm(owning_process);
		if (owning_mm == NULL)
			/*
			 * The process' mm is already dead, notifier were
			 * removed already.
			 */
			goto out_put_task;
		mmu_notifier_unregister(&per_mm->mn, owning_mm);

		mmput(owning_mm);

out_put_task:
		put_task_struct(owning_process);
	}
out:
	up_read(&per_mm->umem_rwsem);

	remove_umem_from_per_mm(umem_odp);
	put_per_mm(umem_odp);
	vfree(umem_odp->dma_list);
	vfree(umem_odp->page_list);
}
@@ -634,7 +654,7 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
{
	struct ib_umem *umem = &umem_odp->umem;
	struct task_struct *owning_process  = NULL;
	struct mm_struct   *owning_mm       = NULL;
	struct mm_struct *owning_mm = umem_odp->umem.owning_mm;
	struct page       **local_page_list = NULL;
	u64 page_mask, off;
	int j, k, ret = 0, start_idx, npages = 0, page_shift;
@@ -658,15 +678,14 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
	user_virt = user_virt & page_mask;
	bcnt += off; /* Charge for the first page offset as well. */

	owning_process = get_pid_task(umem->context->tgid, PIDTYPE_PID);
	if (owning_process == NULL) {
	/*
	 * owning_process is allowed to be NULL, this means somehow the mm is
	 * existing beyond the lifetime of the originating process.. Presumably
	 * mmget_not_zero will fail in this case.
	 */
	owning_process = get_pid_task(umem_odp->per_mm->tgid, PIDTYPE_PID);
	if (WARN_ON(!mmget_not_zero(umem_odp->umem.owning_mm))) {
		ret = -EINVAL;
		goto out_no_task;
	}

	owning_mm = get_task_mm(owning_process);
	if (owning_mm == NULL) {
		ret = -ENOENT;
		goto out_put_task;
	}

@@ -738,8 +757,8 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,

	mmput(owning_mm);
out_put_task:
	if (owning_process)
		put_task_struct(owning_process);
out_no_task:
	free_page((unsigned long)local_page_list);
	return ret;
}
+2 −6
@@ -124,12 +124,8 @@ ssize_t ib_uverbs_get_context(struct ib_uverbs_file *file,
	ucontext->cleanup_retryable = false;

#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
	ucontext->per_mm.umem_tree = RB_ROOT_CACHED;
	init_rwsem(&ucontext->per_mm.umem_rwsem);
	ucontext->per_mm.odp_mrs_count = 0;
	INIT_LIST_HEAD(&ucontext->per_mm.no_private_counters);
	ucontext->per_mm.context = ucontext;

	mutex_init(&ucontext->per_mm_list_lock);
	INIT_LIST_HEAD(&ucontext->per_mm_list);
	if (!(ib_dev->attrs.device_cap_flags & IB_DEVICE_ON_DEMAND_PAGING))
		ucontext->invalidate_range = NULL;

+7 −0
@@ -1861,6 +1861,13 @@ static int mlx5_ib_dealloc_ucontext(struct ib_ucontext *ibcontext)
	struct mlx5_ib_dev *dev = to_mdev(ibcontext->device);
	struct mlx5_bfreg_info *bfregi;

#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
	/* All umem's must be destroyed before destroying the ucontext. */
	mutex_lock(&ibcontext->per_mm_list_lock);
	WARN_ON(!list_empty(&ibcontext->per_mm_list));
	mutex_unlock(&ibcontext->per_mm_list_lock);
#endif

	if (context->devx_uid)
		mlx5_ib_devx_destroy(dev, context);

+1 −1
@@ -393,7 +393,7 @@ static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr,
		if (nentries)
			nentries++;
	} else {
		odp = ib_alloc_odp_umem(odp_mr->umem.context, addr,
		odp = ib_alloc_odp_umem(odp_mr->per_mm, addr,
					MLX5_IMR_MTT_SIZE);
		if (IS_ERR(odp)) {
			mutex_unlock(&odp_mr->umem_mutex);
+19 −1
@@ -91,8 +91,26 @@ static inline struct ib_umem_odp *to_ib_umem_odp(struct ib_umem *umem)

#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING

struct ib_ucontext_per_mm {
	struct ib_ucontext *context;
	struct mm_struct *mm;
	struct pid *tgid;

	struct rb_root_cached umem_tree;
	/* Protects umem_tree */
	struct rw_semaphore umem_rwsem;
	atomic_t notifier_count;

	struct mmu_notifier mn;
	/* A list of umems that don't have private mmu notifier counters yet. */
	struct list_head no_private_counters;
	unsigned int odp_mrs_count;

	struct list_head ucontext_list;
};

int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access);
struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm,
				      unsigned long addr, size_t size);
void ib_umem_odp_release(struct ib_umem_odp *umem_odp);
