Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 7727640c authored by Tim Smith's avatar Tim Smith Committed by David Howells
Browse files

af_rxrpc: Keep rxrpc_call pointers in a hashtable



Keep track of rxrpc_call structures in a hashtable so they can be
found directly from the network parameters which define the call.

This allows incoming packets to be routed directly to a call without walking
through hierarchy of peer -> transport -> connection -> call and all the
spinlocks that that entailed.

Signed-off-by: default avatarTim Smith <tim@electronghost.co.uk>
Signed-off-by: default avatarDavid Howells <dhowells@redhat.com>
parent e8388eb1
Loading
Loading
Loading
Loading
+191 −2
Original line number Original line Diff line number Diff line
@@ -12,6 +12,8 @@
#include <linux/slab.h>
#include <linux/slab.h>
#include <linux/module.h>
#include <linux/module.h>
#include <linux/circ_buf.h>
#include <linux/circ_buf.h>
#include <linux/hashtable.h>
#include <linux/spinlock_types.h>
#include <net/sock.h>
#include <net/sock.h>
#include <net/af_rxrpc.h>
#include <net/af_rxrpc.h>
#include "ar-internal.h"
#include "ar-internal.h"
@@ -55,6 +57,145 @@ static void rxrpc_dead_call_expired(unsigned long _call);
static void rxrpc_ack_time_expired(unsigned long _call);
static void rxrpc_ack_time_expired(unsigned long _call);
static void rxrpc_resend_time_expired(unsigned long _call);
static void rxrpc_resend_time_expired(unsigned long _call);


static DEFINE_SPINLOCK(rxrpc_call_hash_lock);
static DEFINE_HASHTABLE(rxrpc_call_hash, 10);

/*
 * Hash function for rxrpc_call_hash
 */
static unsigned long rxrpc_call_hashfunc(
	u8		clientflag,
	__be32		cid,
	__be32		call_id,
	__be32		epoch,
	__be16		service_id,
	sa_family_t	proto,
	void		*localptr,
	unsigned int	addr_size,
	const u8	*peer_addr)
{
	const u16 *p;
	unsigned int i;
	unsigned long key;
	u32 hcid = ntohl(cid);

	_enter("");

	key = (unsigned long)localptr;
	/* We just want to add up the __be32 values, so forcing the
	 * cast should be okay.
	 */
	key += (__force u32)epoch;
	key += (__force u16)service_id;
	key += (__force u32)call_id;
	key += (hcid & RXRPC_CIDMASK) >> RXRPC_CIDSHIFT;
	key += hcid & RXRPC_CHANNELMASK;
	key += clientflag;
	key += proto;
	/* Step through the peer address in 16-bit portions for speed */
	for (i = 0, p = (const u16 *)peer_addr; i < addr_size >> 1; i++, p++)
		key += *p;
	_leave(" key = 0x%lx", key);
	return key;
}

/*
 * Add a call to the hashtable
 */
static void rxrpc_call_hash_add(struct rxrpc_call *call)
{
	unsigned long key;
	unsigned int addr_size = 0;

	_enter("");
	switch (call->proto) {
	case AF_INET:
		addr_size = sizeof(call->peer_ip.ipv4_addr);
		break;
	case AF_INET6:
		addr_size = sizeof(call->peer_ip.ipv6_addr);
		break;
	default:
		break;
	}
	key = rxrpc_call_hashfunc(call->in_clientflag, call->cid,
				  call->call_id, call->epoch,
				  call->service_id, call->proto,
				  call->conn->trans->local, addr_size,
				  call->peer_ip.ipv6_addr);
	/* Store the full key in the call */
	call->hash_key = key;
	spin_lock(&rxrpc_call_hash_lock);
	hash_add_rcu(rxrpc_call_hash, &call->hash_node, key);
	spin_unlock(&rxrpc_call_hash_lock);
	_leave("");
}

/*
 * Remove a call from the hashtable
 */
static void rxrpc_call_hash_del(struct rxrpc_call *call)
{
	_enter("");
	spin_lock(&rxrpc_call_hash_lock);
	hash_del_rcu(&call->hash_node);
	spin_unlock(&rxrpc_call_hash_lock);
	_leave("");
}

/*
 * Find a call in the hashtable and return it, or NULL if it
 * isn't there.
 */
struct rxrpc_call *rxrpc_find_call_hash(
	u8		clientflag,
	__be32		cid,
	__be32		call_id,
	__be32		epoch,
	__be16		service_id,
	void		*localptr,
	sa_family_t	proto,
	const u8	*peer_addr)
{
	unsigned long key;
	unsigned int addr_size = 0;
	struct rxrpc_call *call = NULL;
	struct rxrpc_call *ret = NULL;

	_enter("");
	switch (proto) {
	case AF_INET:
		addr_size = sizeof(call->peer_ip.ipv4_addr);
		break;
	case AF_INET6:
		addr_size = sizeof(call->peer_ip.ipv6_addr);
		break;
	default:
		break;
	}

	key = rxrpc_call_hashfunc(clientflag, cid, call_id, epoch,
				  service_id, proto, localptr, addr_size,
				  peer_addr);
	hash_for_each_possible_rcu(rxrpc_call_hash, call, hash_node, key) {
		if (call->hash_key == key &&
		    call->call_id == call_id &&
		    call->cid == cid &&
		    call->in_clientflag == clientflag &&
		    call->service_id == service_id &&
		    call->proto == proto &&
		    call->local == localptr &&
		    memcmp(call->peer_ip.ipv6_addr, peer_addr,
			      addr_size) == 0 &&
		    call->epoch == epoch) {
			ret = call;
			break;
		}
	}
	_leave(" = %p", ret);
	return ret;
}

/*
/*
 * allocate a new call
 * allocate a new call
 */
 */
@@ -136,6 +277,26 @@ static struct rxrpc_call *rxrpc_alloc_client_call(
		return ERR_PTR(ret);
		return ERR_PTR(ret);
	}
	}


	/* Record copies of information for hashtable lookup */
	call->proto = rx->proto;
	call->local = trans->local;
	switch (call->proto) {
	case AF_INET:
		call->peer_ip.ipv4_addr =
			trans->peer->srx.transport.sin.sin_addr.s_addr;
		break;
	case AF_INET6:
		memcpy(call->peer_ip.ipv6_addr,
		       trans->peer->srx.transport.sin6.sin6_addr.in6_u.u6_addr8,
		       sizeof(call->peer_ip.ipv6_addr));
		break;
	}
	call->epoch = call->conn->epoch;
	call->service_id = call->conn->service_id;
	call->in_clientflag = call->conn->in_clientflag;
	/* Add the new call to the hashtable */
	rxrpc_call_hash_add(call);

	spin_lock(&call->conn->trans->peer->lock);
	spin_lock(&call->conn->trans->peer->lock);
	list_add(&call->error_link, &call->conn->trans->peer->error_targets);
	list_add(&call->error_link, &call->conn->trans->peer->error_targets);
	spin_unlock(&call->conn->trans->peer->lock);
	spin_unlock(&call->conn->trans->peer->lock);
@@ -328,9 +489,12 @@ struct rxrpc_call *rxrpc_incoming_call(struct rxrpc_sock *rx,
		parent = *p;
		parent = *p;
		call = rb_entry(parent, struct rxrpc_call, conn_node);
		call = rb_entry(parent, struct rxrpc_call, conn_node);


		if (call_id < call->call_id)
		/* The tree is sorted in order of the __be32 value without
		 * turning it into host order.
		 */
		if ((__force u32)call_id < (__force u32)call->call_id)
			p = &(*p)->rb_left;
			p = &(*p)->rb_left;
		else if (call_id > call->call_id)
		else if ((__force u32)call_id > (__force u32)call->call_id)
			p = &(*p)->rb_right;
			p = &(*p)->rb_right;
		else
		else
			goto old_call;
			goto old_call;
@@ -355,6 +519,28 @@ struct rxrpc_call *rxrpc_incoming_call(struct rxrpc_sock *rx,
	list_add_tail(&call->link, &rxrpc_calls);
	list_add_tail(&call->link, &rxrpc_calls);
	write_unlock_bh(&rxrpc_call_lock);
	write_unlock_bh(&rxrpc_call_lock);


	/* Record copies of information for hashtable lookup */
	call->proto = rx->proto;
	call->local = conn->trans->local;
	switch (call->proto) {
	case AF_INET:
		call->peer_ip.ipv4_addr =
			conn->trans->peer->srx.transport.sin.sin_addr.s_addr;
		break;
	case AF_INET6:
		memcpy(call->peer_ip.ipv6_addr,
		       conn->trans->peer->srx.transport.sin6.sin6_addr.in6_u.u6_addr8,
		       sizeof(call->peer_ip.ipv6_addr));
		break;
	default:
		break;
	}
	call->epoch = conn->epoch;
	call->service_id = conn->service_id;
	call->in_clientflag = conn->in_clientflag;
	/* Add the new call to the hashtable */
	rxrpc_call_hash_add(call);

	_net("CALL incoming %d on CONN %d", call->debug_id, call->conn->debug_id);
	_net("CALL incoming %d on CONN %d", call->debug_id, call->conn->debug_id);


	call->lifetimer.expires = jiffies + rxrpc_max_call_lifetime;
	call->lifetimer.expires = jiffies + rxrpc_max_call_lifetime;
@@ -673,6 +859,9 @@ static void rxrpc_cleanup_call(struct rxrpc_call *call)
		rxrpc_put_connection(call->conn);
		rxrpc_put_connection(call->conn);
	}
	}


	/* Remove the call from the hash */
	rxrpc_call_hash_del(call);

	if (call->acks_window) {
	if (call->acks_window) {
		_debug("kill Tx window %d",
		_debug("kill Tx window %d",
		       CIRC_CNT(call->acks_head, call->acks_tail,
		       CIRC_CNT(call->acks_head, call->acks_tail,
+73 −104
Original line number Original line Diff line number Diff line
@@ -523,36 +523,38 @@ static void rxrpc_process_jumbo_packet(struct rxrpc_call *call,
 * post an incoming packet to the appropriate call/socket to deal with
 * post an incoming packet to the appropriate call/socket to deal with
 * - must get rid of the sk_buff, either by freeing it or by queuing it
 * - must get rid of the sk_buff, either by freeing it or by queuing it
 */
 */
static void rxrpc_post_packet_to_call(struct rxrpc_connection *conn,
static void rxrpc_post_packet_to_call(struct rxrpc_call *call,
				      struct sk_buff *skb)
				      struct sk_buff *skb)
{
{
	struct rxrpc_skb_priv *sp;
	struct rxrpc_skb_priv *sp;
	struct rxrpc_call *call;
	struct rb_node *p;
	__be32 call_id;

	_enter("%p,%p", conn, skb);


	read_lock_bh(&conn->lock);
	_enter("%p,%p", call, skb);


	sp = rxrpc_skb(skb);
	sp = rxrpc_skb(skb);


	/* look at extant calls by channel number first */
	call = conn->channels[ntohl(sp->hdr.cid) & RXRPC_CHANNELMASK];
	if (!call || call->call_id != sp->hdr.callNumber)
		goto call_not_extant;

	_debug("extant call [%d]", call->state);
	_debug("extant call [%d]", call->state);
	ASSERTCMP(call->conn, ==, conn);


	read_lock(&call->state_lock);
	read_lock(&call->state_lock);
	switch (call->state) {
	switch (call->state) {
	case RXRPC_CALL_LOCALLY_ABORTED:
	case RXRPC_CALL_LOCALLY_ABORTED:
		if (!test_and_set_bit(RXRPC_CALL_ABORT, &call->events))
		if (!test_and_set_bit(RXRPC_CALL_ABORT, &call->events)) {
			rxrpc_queue_call(call);
			rxrpc_queue_call(call);
			goto free_unlock;
		}
	case RXRPC_CALL_REMOTELY_ABORTED:
	case RXRPC_CALL_REMOTELY_ABORTED:
	case RXRPC_CALL_NETWORK_ERROR:
	case RXRPC_CALL_NETWORK_ERROR:
	case RXRPC_CALL_DEAD:
	case RXRPC_CALL_DEAD:
		goto dead_call;
	case RXRPC_CALL_COMPLETE:
	case RXRPC_CALL_CLIENT_FINAL_ACK:
		/* complete server call */
		if (call->conn->in_clientflag)
			goto dead_call;
		/* resend last packet of a completed call */
		_debug("final ack again");
		rxrpc_get_call(call);
		set_bit(RXRPC_CALL_ACK_FINAL, &call->events);
		rxrpc_queue_call(call);
		goto free_unlock;
		goto free_unlock;
	default:
	default:
		break;
		break;
@@ -560,7 +562,6 @@ static void rxrpc_post_packet_to_call(struct rxrpc_connection *conn,


	read_unlock(&call->state_lock);
	read_unlock(&call->state_lock);
	rxrpc_get_call(call);
	rxrpc_get_call(call);
	read_unlock_bh(&conn->lock);


	if (sp->hdr.type == RXRPC_PACKET_TYPE_DATA &&
	if (sp->hdr.type == RXRPC_PACKET_TYPE_DATA &&
	    sp->hdr.flags & RXRPC_JUMBO_PACKET)
	    sp->hdr.flags & RXRPC_JUMBO_PACKET)
@@ -571,80 +572,16 @@ static void rxrpc_post_packet_to_call(struct rxrpc_connection *conn,
	rxrpc_put_call(call);
	rxrpc_put_call(call);
	goto done;
	goto done;


call_not_extant:
	/* search the completed calls in case what we're dealing with is
	 * there */
	_debug("call not extant");

	call_id = sp->hdr.callNumber;
	p = conn->calls.rb_node;
	while (p) {
		call = rb_entry(p, struct rxrpc_call, conn_node);

		if (call_id < call->call_id)
			p = p->rb_left;
		else if (call_id > call->call_id)
			p = p->rb_right;
		else
			goto found_completed_call;
	}

dead_call:
dead_call:
	/* it's a either a really old call that we no longer remember or its a
	 * new incoming call */
	read_unlock_bh(&conn->lock);

	if (sp->hdr.flags & RXRPC_CLIENT_INITIATED &&
	    sp->hdr.seq == cpu_to_be32(1)) {
		_debug("incoming call");
		skb_queue_tail(&conn->trans->local->accept_queue, skb);
		rxrpc_queue_work(&conn->trans->local->acceptor);
		goto done;
	}

	_debug("dead call");
	if (sp->hdr.type != RXRPC_PACKET_TYPE_ABORT) {
	if (sp->hdr.type != RXRPC_PACKET_TYPE_ABORT) {
		skb->priority = RX_CALL_DEAD;
		skb->priority = RX_CALL_DEAD;
		rxrpc_reject_packet(conn->trans->local, skb);
		rxrpc_reject_packet(call->conn->trans->local, skb);
	}
		goto unlock;
	goto done;

	/* resend last packet of a completed call
	 * - client calls may have been aborted or ACK'd
	 * - server calls may have been aborted
	 */
found_completed_call:
	_debug("completed call");

	if (atomic_read(&call->usage) == 0)
		goto dead_call;

	/* synchronise any state changes */
	read_lock(&call->state_lock);
	ASSERTIFCMP(call->state != RXRPC_CALL_CLIENT_FINAL_ACK,
		    call->state, >=, RXRPC_CALL_COMPLETE);

	if (call->state == RXRPC_CALL_LOCALLY_ABORTED ||
	    call->state == RXRPC_CALL_REMOTELY_ABORTED ||
	    call->state == RXRPC_CALL_DEAD) {
		read_unlock(&call->state_lock);
		goto dead_call;
	}

	if (call->conn->in_clientflag) {
		read_unlock(&call->state_lock);
		goto dead_call; /* complete server call */
	}
	}

	_debug("final ack again");
	rxrpc_get_call(call);
	set_bit(RXRPC_CALL_ACK_FINAL, &call->events);
	rxrpc_queue_call(call);

free_unlock:
free_unlock:
	read_unlock(&call->state_lock);
	read_unlock_bh(&conn->lock);
	rxrpc_free_skb(skb);
	rxrpc_free_skb(skb);
unlock:
	read_unlock(&call->state_lock);
done:
done:
	_leave("");
	_leave("");
}
}
@@ -663,17 +600,42 @@ static void rxrpc_post_packet_to_conn(struct rxrpc_connection *conn,
	rxrpc_queue_conn(conn);
	rxrpc_queue_conn(conn);
}
}


static struct rxrpc_connection *rxrpc_conn_from_local(struct rxrpc_local *local,
					       struct sk_buff *skb,
					       struct rxrpc_skb_priv *sp)
{
	struct rxrpc_peer *peer;
	struct rxrpc_transport *trans;
	struct rxrpc_connection *conn;

	peer = rxrpc_find_peer(local, ip_hdr(skb)->saddr,
				udp_hdr(skb)->source);
	if (IS_ERR(peer))
		goto cant_find_conn;

	trans = rxrpc_find_transport(local, peer);
	rxrpc_put_peer(peer);
	if (!trans)
		goto cant_find_conn;

	conn = rxrpc_find_connection(trans, &sp->hdr);
	rxrpc_put_transport(trans);
	if (!conn)
		goto cant_find_conn;

	return conn;
cant_find_conn:
	return NULL;
}

/*
/*
 * handle data received on the local endpoint
 * handle data received on the local endpoint
 * - may be called in interrupt context
 * - may be called in interrupt context
 */
 */
void rxrpc_data_ready(struct sock *sk, int count)
void rxrpc_data_ready(struct sock *sk, int count)
{
{
	struct rxrpc_connection *conn;
	struct rxrpc_transport *trans;
	struct rxrpc_skb_priv *sp;
	struct rxrpc_skb_priv *sp;
	struct rxrpc_local *local;
	struct rxrpc_local *local;
	struct rxrpc_peer *peer;
	struct sk_buff *skb;
	struct sk_buff *skb;
	int ret;
	int ret;


@@ -748,27 +710,34 @@ void rxrpc_data_ready(struct sock *sk, int count)
	    (sp->hdr.callNumber == 0 || sp->hdr.seq == 0))
	    (sp->hdr.callNumber == 0 || sp->hdr.seq == 0))
		goto bad_message;
		goto bad_message;


	peer = rxrpc_find_peer(local, ip_hdr(skb)->saddr, udp_hdr(skb)->source);
	if (sp->hdr.callNumber == 0) {
	if (IS_ERR(peer))
		/* This is a connection-level packet. These should be
		goto cant_route_call;
		 * fairly rare, so the extra overhead of looking them up the

		 * old-fashioned way doesn't really hurt */
	trans = rxrpc_find_transport(local, peer);
		struct rxrpc_connection *conn;
	rxrpc_put_peer(peer);
	if (!trans)
		goto cant_route_call;


	conn = rxrpc_find_connection(trans, &sp->hdr);
		conn = rxrpc_conn_from_local(local, skb, sp);
	rxrpc_put_transport(trans);
		if (!conn)
		if (!conn)
			goto cant_route_call;
			goto cant_route_call;


		_debug("CONN %p {%d}", conn, conn->debug_id);
		_debug("CONN %p {%d}", conn, conn->debug_id);

	if (sp->hdr.callNumber == 0)
		rxrpc_post_packet_to_conn(conn, skb);
		rxrpc_post_packet_to_conn(conn, skb);
	else
		rxrpc_post_packet_to_call(conn, skb);
		rxrpc_put_connection(conn);
		rxrpc_put_connection(conn);
	} else {
		struct rxrpc_call *call;
		u8 in_clientflag = 0;

		if (sp->hdr.flags & RXRPC_CLIENT_INITIATED)
			in_clientflag = RXRPC_CLIENT_INITIATED;
		call = rxrpc_find_call_hash(in_clientflag, sp->hdr.cid,
					    sp->hdr.callNumber, sp->hdr.epoch,
					    sp->hdr.serviceId, local, AF_INET,
					    (u8 *)&ip_hdr(skb)->saddr);
		if (call)
			rxrpc_post_packet_to_call(call, skb);
		else
			goto cant_route_call;
	}
	rxrpc_put_local(local);
	rxrpc_put_local(local);
	return;
	return;


+13 −0
Original line number Original line Diff line number Diff line
@@ -396,9 +396,20 @@ struct rxrpc_call {
#define RXRPC_ACKR_WINDOW_ASZ DIV_ROUND_UP(RXRPC_MAXACKS, BITS_PER_LONG)
#define RXRPC_ACKR_WINDOW_ASZ DIV_ROUND_UP(RXRPC_MAXACKS, BITS_PER_LONG)
	unsigned long		ackr_window[RXRPC_ACKR_WINDOW_ASZ + 1];
	unsigned long		ackr_window[RXRPC_ACKR_WINDOW_ASZ + 1];


	struct hlist_node	hash_node;
	unsigned long		hash_key;	/* Full hash key */
	u8			in_clientflag;	/* Copy of conn->in_clientflag for hashing */
	struct rxrpc_local	*local;		/* Local endpoint. Used for hashing. */
	sa_family_t		proto;		/* Frame protocol */
	/* the following should all be in net order */
	/* the following should all be in net order */
	__be32			cid;		/* connection ID + channel index  */
	__be32			cid;		/* connection ID + channel index  */
	__be32			call_id;	/* call ID on connection  */
	__be32			call_id;	/* call ID on connection  */
	__be32			epoch;		/* epoch of this connection */
	__be16			service_id;	/* service ID */
	union {					/* Peer IP address for hashing */
		__be32	ipv4_addr;
		__u8	ipv6_addr[16];		/* Anticipates eventual IPv6 support */
	} peer_ip;
};
};


/*
/*
@@ -453,6 +464,8 @@ extern struct kmem_cache *rxrpc_call_jar;
extern struct list_head rxrpc_calls;
extern struct list_head rxrpc_calls;
extern rwlock_t rxrpc_call_lock;
extern rwlock_t rxrpc_call_lock;


struct rxrpc_call *rxrpc_find_call_hash(u8,  __be32, __be32, __be32,
					__be16, void *, sa_family_t, const u8 *);
struct rxrpc_call *rxrpc_get_client_call(struct rxrpc_sock *,
struct rxrpc_call *rxrpc_get_client_call(struct rxrpc_sock *,
					 struct rxrpc_transport *,
					 struct rxrpc_transport *,
					 struct rxrpc_conn_bundle *,
					 struct rxrpc_conn_bundle *,