Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 5adb5bc6 authored by Zach Brown's avatar Zach Brown Committed by Andy Grover
Browse files

RDS: have sockets get transport module references



Right now there's nothing to stop the various paths that use
rs->rs_transport from racing with rmmod and executing freed transport
code.  The simple fix is to have binding to a transport also hold a
reference to the transport's module, removing this class of races.

We already had an unused t_owner field which was set for the modular
transports and which wasn't set for the built-in loop transport.

Signed-off-by: default avatarZach Brown <zach.brown@oracle.com>
parent 77510481
Loading
Loading
Loading
Loading
+2 −0
Original line number Original line Diff line number Diff line
@@ -90,6 +90,8 @@ static int rds_release(struct socket *sock)
	rds_sock_count--;
	rds_sock_count--;
	spin_unlock_irqrestore(&rds_sock_lock, flags);
	spin_unlock_irqrestore(&rds_sock_lock, flags);


	rds_trans_put(rs->rs_transport);

	sock->sk = NULL;
	sock->sk = NULL;
	sock_put(sk);
	sock_put(sk);
out:
out:
+4 −1
Original line number Original line Diff line number Diff line
@@ -117,6 +117,7 @@ static struct rds_connection *__rds_conn_create(__be32 laddr, __be32 faddr,
{
{
	struct rds_connection *conn, *parent = NULL;
	struct rds_connection *conn, *parent = NULL;
	struct hlist_head *head = rds_conn_bucket(laddr, faddr);
	struct hlist_head *head = rds_conn_bucket(laddr, faddr);
	struct rds_transport *loop_trans;
	unsigned long flags;
	unsigned long flags;
	int ret;
	int ret;


@@ -163,7 +164,9 @@ static struct rds_connection *__rds_conn_create(__be32 laddr, __be32 faddr,
	 * can bind to the destination address then we'd rather the messages
	 * can bind to the destination address then we'd rather the messages
	 * flow through loopback rather than either transport.
	 * flow through loopback rather than either transport.
	 */
	 */
	if (rds_trans_get_preferred(faddr)) {
	loop_trans = rds_trans_get_preferred(faddr);
	if (loop_trans) {
		rds_trans_put(loop_trans);
		conn->c_loopback = 1;
		conn->c_loopback = 1;
		if (is_outgoing && trans->t_prefer_loopback) {
		if (is_outgoing && trans->t_prefer_loopback) {
			/* "outgoing" connection - and the transport
			/* "outgoing" connection - and the transport
+1 −0
Original line number Original line Diff line number Diff line
@@ -798,6 +798,7 @@ void rds_connect_complete(struct rds_connection *conn);
int rds_trans_register(struct rds_transport *trans);
int rds_trans_register(struct rds_transport *trans);
void rds_trans_unregister(struct rds_transport *trans);
void rds_trans_unregister(struct rds_transport *trans);
struct rds_transport *rds_trans_get_preferred(__be32 addr);
struct rds_transport *rds_trans_get_preferred(__be32 addr);
void rds_trans_put(struct rds_transport *trans);
unsigned int rds_trans_stats_info_copy(struct rds_info_iterator *iter,
unsigned int rds_trans_stats_info_copy(struct rds_info_iterator *iter,
				       unsigned int avail);
				       unsigned int avail);
int rds_trans_init(void);
int rds_trans_init(void);
+14 −5
Original line number Original line Diff line number Diff line
@@ -71,19 +71,28 @@ void rds_trans_unregister(struct rds_transport *trans)
}
}
EXPORT_SYMBOL_GPL(rds_trans_unregister);
EXPORT_SYMBOL_GPL(rds_trans_unregister);


void rds_trans_put(struct rds_transport *trans)
{
	if (trans && trans->t_owner)
		module_put(trans->t_owner);
}

struct rds_transport *rds_trans_get_preferred(__be32 addr)
struct rds_transport *rds_trans_get_preferred(__be32 addr)
{
{
	struct rds_transport *ret = NULL;
	struct rds_transport *ret = NULL;
	int i;
	struct rds_transport *trans;
	unsigned int i;


	if (IN_LOOPBACK(ntohl(addr)))
	if (IN_LOOPBACK(ntohl(addr)))
		return &rds_loop_transport;
		return &rds_loop_transport;


	down_read(&rds_trans_sem);
	down_read(&rds_trans_sem);
	for (i = 0; i < RDS_TRANS_COUNT; i++)
	for (i = 0; i < RDS_TRANS_COUNT; i++) {
	{
		trans = transports[i];
		if (transports[i] && (transports[i]->laddr_check(addr) == 0)) {

			ret = transports[i];
		if (trans && (trans->laddr_check(addr) == 0) &&
		    (!trans->t_owner || try_module_get(trans->t_owner))) {
			ret = trans;
			break;
			break;
		}
		}
	}
	}