Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ca22354b authored by Jason Gunthorpe's avatar Jason Gunthorpe
Browse files

RDMA/rxe: Close a race after ib_register_device



Since rxe allows unregistration from other threads the rxe pointer can
become invalid any moment after ib_register_driver returns. This could
cause a user triggered use after free.

Add another driver callback to be called right after the device becomes
registered to complete any device setup required post-registration.  This
callback has enough core locking to prevent the device from becoming
unregistered.

Signed-off-by: default avatarJason Gunthorpe <jgg@mellanox.com>
parent 6cc2c8e5
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
@@ -803,6 +803,12 @@ static int enable_device_and_get(struct ib_device *device)
	 */
	downgrade_write(&devices_rwsem);

	if (device->ops.enable_driver) {
		ret = device->ops.enable_driver(device);
		if (ret)
			goto out;
	}

	down_read(&clients_rwsem);
	xa_for_each_marked (&clients, index, client, CLIENT_REGISTERED) {
		ret = add_client_context(device, client);
@@ -810,6 +816,8 @@ static int enable_device_and_get(struct ib_device *device)
			break;
	}
	up_read(&clients_rwsem);

out:
	up_read(&devices_rwsem);
	return ret;
}
@@ -1775,6 +1783,7 @@ void ib_set_device_ops(struct ib_device *dev, const struct ib_device_ops *ops)
	SET_DEVICE_OP(dev_ops, disassociate_ucontext);
	SET_DEVICE_OP(dev_ops, drain_rq);
	SET_DEVICE_OP(dev_ops, drain_sq);
	SET_DEVICE_OP(dev_ops, enable_driver);
	SET_DEVICE_OP(dev_ops, fill_res_entry);
	SET_DEVICE_OP(dev_ops, get_dev_fw_str);
	SET_DEVICE_OP(dev_ops, get_dma_mr);
+4 −4
Original line number Diff line number Diff line
@@ -517,24 +517,24 @@ enum rdma_link_layer rxe_link_layer(struct rxe_dev *rxe, unsigned int port_num)
	return IB_LINK_LAYER_ETHERNET;
}

struct rxe_dev *rxe_net_add(struct net_device *ndev)
int rxe_net_add(struct net_device *ndev)
{
	int err;
	struct rxe_dev *rxe = NULL;

	rxe = ib_alloc_device(rxe_dev, ib_dev);
	if (!rxe)
		return NULL;
		return -ENOMEM;

	rxe->ndev = ndev;

	err = rxe_add(rxe, ndev->mtu);
	if (err) {
		ib_dealloc_device(&rxe->ib_dev);
		return NULL;
		return err;
	}

	return rxe;
	return 0;
}

static void rxe_port_event(struct rxe_dev *rxe,
+1 −1
Original line number Diff line number Diff line
@@ -43,7 +43,7 @@ struct rxe_recv_sockets {
	struct socket *sk6;
};

struct rxe_dev *rxe_net_add(struct net_device *ndev);
int rxe_net_add(struct net_device *ndev);

int rxe_net_init(void);
void rxe_net_exit(void);
+2 −7
Original line number Diff line number Diff line
@@ -60,7 +60,6 @@ static int rxe_param_set_add(const char *val, const struct kernel_param *kp)
	char intf[32];
	struct net_device *ndev;
	struct rxe_dev *exists;
	struct rxe_dev *rxe;

	len = sanitize_arg(val, intf, sizeof(intf));
	if (!len) {
@@ -82,16 +81,12 @@ static int rxe_param_set_add(const char *val, const struct kernel_param *kp)
		goto err;
	}

	rxe = rxe_net_add(ndev);
	if (!rxe) {
	err = rxe_net_add(ndev);
	if (err) {
		pr_err("failed to add %s\n", intf);
		err = -EINVAL;
		goto err;
	}

	rxe_set_port_state(rxe);
	dev_info(&rxe->ib_dev.dev, "added %s\n", intf);

err:
	dev_put(ndev);
	return err;
+14 −0
Original line number Diff line number Diff line
@@ -1125,6 +1125,15 @@ static const struct attribute_group rxe_attr_group = {
	.attrs = rxe_dev_attributes,
};

static int rxe_enable_driver(struct ib_device *ib_dev)
{
	struct rxe_dev *rxe = container_of(ib_dev, struct rxe_dev, ib_dev);

	rxe_set_port_state(rxe);
	dev_info(&rxe->ib_dev.dev, "added %s\n", netdev_name(rxe->ndev));
	return 0;
}

static const struct ib_device_ops rxe_dev_ops = {
	.alloc_hw_stats = rxe_ib_alloc_hw_stats,
	.alloc_mr = rxe_alloc_mr,
@@ -1144,6 +1153,7 @@ static const struct ib_device_ops rxe_dev_ops = {
	.destroy_qp = rxe_destroy_qp,
	.destroy_srq = rxe_destroy_srq,
	.detach_mcast = rxe_detach_mcast,
	.enable_driver = rxe_enable_driver,
	.get_dma_mr = rxe_get_dma_mr,
	.get_hw_stats = rxe_ib_get_hw_stats,
	.get_link_layer = rxe_get_link_layer,
@@ -1245,5 +1255,9 @@ int rxe_register_device(struct rxe_dev *rxe)
	if (err)
		pr_warn("%s failed with error %d\n", __func__, err);

	/*
	 * Note that rxe may be invalid at this point if another thread
	 * unregistered it.
	 */
	return err;
}
Loading