Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 456b61bc authored by Eric Dumazet's avatar Eric Dumazet Committed by David S. Miller
Browse files

ipv6: mcast: RCU conversion



ipv6_sk_mc_lock rwlock becomes a spinlock.

readers (inet6_mc_check()) now takes rcu_read_lock() instead of read
lock. Writers dont need to disable BH anymore.

struct ipv6_mc_socklist objects are reclaimed after one RCU grace
period.

Signed-off-by: default avatarEric Dumazet <eric.dumazet@gmail.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 2757a15f
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -364,7 +364,7 @@ struct ipv6_pinfo {

	__u32			dst_cookie;

	struct ipv6_mc_socklist	*ipv6_mc_list;
	struct ipv6_mc_socklist	__rcu *ipv6_mc_list;
	struct ipv6_ac_socklist	*ipv6_ac_list;
	struct ipv6_fl_socklist *ipv6_fl_list;

+2 −1
Original line number Diff line number Diff line
@@ -89,10 +89,11 @@ struct ip6_sf_socklist {
struct ipv6_mc_socklist {
	struct in6_addr		addr;
	int			ifindex;
	struct ipv6_mc_socklist *next;
	struct ipv6_mc_socklist __rcu *next;
	rwlock_t		sflock;
	unsigned int		sfmode;		/* MCAST_{INCLUDE,EXCLUDE} */
	struct ip6_sf_socklist	*sflist;
	struct rcu_head		rcu;
};

struct ip6_sf_list {
+44 −31
Original line number Diff line number Diff line
@@ -82,7 +82,7 @@ static void *__mld2_query_bugs[] __attribute__((__unused__)) = {
static struct in6_addr mld2_all_mcr = MLD2_ALL_MCR_INIT;

/* Big mc list lock for all the sockets */
static DEFINE_RWLOCK(ipv6_sk_mc_lock);
static DEFINE_SPINLOCK(ipv6_sk_mc_lock);

static void igmp6_join_group(struct ifmcaddr6 *ma);
static void igmp6_leave_group(struct ifmcaddr6 *ma);
@@ -123,6 +123,11 @@ int sysctl_mld_max_msf __read_mostly = IPV6_MLD_MAX_MSF;
 *	socket join on multicast group
 */

#define for_each_pmc_rcu(np, pmc)				\
	for (pmc = rcu_dereference(np->ipv6_mc_list);		\
	     pmc != NULL;					\
	     pmc = rcu_dereference(pmc->next))

int ipv6_sock_mc_join(struct sock *sk, int ifindex, const struct in6_addr *addr)
{
	struct net_device *dev = NULL;
@@ -134,15 +139,15 @@ int ipv6_sock_mc_join(struct sock *sk, int ifindex, const struct in6_addr *addr)
	if (!ipv6_addr_is_multicast(addr))
		return -EINVAL;

	read_lock_bh(&ipv6_sk_mc_lock);
	for (mc_lst=np->ipv6_mc_list; mc_lst; mc_lst=mc_lst->next) {
	rcu_read_lock();
	for_each_pmc_rcu(np, mc_lst) {
		if ((ifindex == 0 || mc_lst->ifindex == ifindex) &&
		    ipv6_addr_equal(&mc_lst->addr, addr)) {
			read_unlock_bh(&ipv6_sk_mc_lock);
			rcu_read_unlock();
			return -EADDRINUSE;
		}
	}
	read_unlock_bh(&ipv6_sk_mc_lock);
	rcu_read_unlock();

	mc_lst = sock_kmalloc(sk, sizeof(struct ipv6_mc_socklist), GFP_KERNEL);

@@ -186,33 +191,41 @@ int ipv6_sock_mc_join(struct sock *sk, int ifindex, const struct in6_addr *addr)
		return err;
	}

	write_lock_bh(&ipv6_sk_mc_lock);
	spin_lock(&ipv6_sk_mc_lock);
	mc_lst->next = np->ipv6_mc_list;
	np->ipv6_mc_list = mc_lst;
	write_unlock_bh(&ipv6_sk_mc_lock);
	rcu_assign_pointer(np->ipv6_mc_list, mc_lst);
	spin_unlock(&ipv6_sk_mc_lock);

	rcu_read_unlock();

	return 0;
}

static void ipv6_mc_socklist_reclaim(struct rcu_head *head)
{
	kfree(container_of(head, struct ipv6_mc_socklist, rcu));
}
/*
 *	socket leave on multicast group
 */
int ipv6_sock_mc_drop(struct sock *sk, int ifindex, const struct in6_addr *addr)
{
	struct ipv6_pinfo *np = inet6_sk(sk);
	struct ipv6_mc_socklist *mc_lst, **lnk;
	struct ipv6_mc_socklist *mc_lst;
	struct ipv6_mc_socklist __rcu **lnk;
	struct net *net = sock_net(sk);

	write_lock_bh(&ipv6_sk_mc_lock);
	for (lnk = &np->ipv6_mc_list; (mc_lst = *lnk) !=NULL ; lnk = &mc_lst->next) {
	spin_lock(&ipv6_sk_mc_lock);
	for (lnk = &np->ipv6_mc_list;
	     (mc_lst = rcu_dereference_protected(*lnk,
			lockdep_is_held(&ipv6_sk_mc_lock))) !=NULL ;
	      lnk = &mc_lst->next) {
		if ((ifindex == 0 || mc_lst->ifindex == ifindex) &&
		    ipv6_addr_equal(&mc_lst->addr, addr)) {
			struct net_device *dev;

			*lnk = mc_lst->next;
			write_unlock_bh(&ipv6_sk_mc_lock);
			spin_unlock(&ipv6_sk_mc_lock);

			rcu_read_lock();
			dev = dev_get_by_index_rcu(net, mc_lst->ifindex);
@@ -225,11 +238,12 @@ int ipv6_sock_mc_drop(struct sock *sk, int ifindex, const struct in6_addr *addr)
			} else
				(void) ip6_mc_leave_src(sk, mc_lst, NULL);
			rcu_read_unlock();
			sock_kfree_s(sk, mc_lst, sizeof(*mc_lst));
			atomic_sub(sizeof(*mc_lst), &sk->sk_omem_alloc);
			call_rcu(&mc_lst->rcu, ipv6_mc_socklist_reclaim);
			return 0;
		}
	}
	write_unlock_bh(&ipv6_sk_mc_lock);
	spin_unlock(&ipv6_sk_mc_lock);

	return -EADDRNOTAVAIL;
}
@@ -272,12 +286,13 @@ void ipv6_sock_mc_close(struct sock *sk)
	struct ipv6_mc_socklist *mc_lst;
	struct net *net = sock_net(sk);

	write_lock_bh(&ipv6_sk_mc_lock);
	while ((mc_lst = np->ipv6_mc_list) != NULL) {
	spin_lock(&ipv6_sk_mc_lock);
	while ((mc_lst = rcu_dereference_protected(np->ipv6_mc_list,
				lockdep_is_held(&ipv6_sk_mc_lock))) != NULL) {
		struct net_device *dev;

		np->ipv6_mc_list = mc_lst->next;
		write_unlock_bh(&ipv6_sk_mc_lock);
		spin_unlock(&ipv6_sk_mc_lock);

		rcu_read_lock();
		dev = dev_get_by_index_rcu(net, mc_lst->ifindex);
@@ -290,11 +305,13 @@ void ipv6_sock_mc_close(struct sock *sk)
		} else
			(void) ip6_mc_leave_src(sk, mc_lst, NULL);
		rcu_read_unlock();
		sock_kfree_s(sk, mc_lst, sizeof(*mc_lst));

		write_lock_bh(&ipv6_sk_mc_lock);
		atomic_sub(sizeof(*mc_lst), &sk->sk_omem_alloc);
		call_rcu(&mc_lst->rcu, ipv6_mc_socklist_reclaim);

		spin_lock(&ipv6_sk_mc_lock);
	}
	write_unlock_bh(&ipv6_sk_mc_lock);
	spin_unlock(&ipv6_sk_mc_lock);
}

int ip6_mc_source(int add, int omode, struct sock *sk,
@@ -328,8 +345,7 @@ int ip6_mc_source(int add, int omode, struct sock *sk,

	err = -EADDRNOTAVAIL;

	read_lock(&ipv6_sk_mc_lock);
	for (pmc=inet6->ipv6_mc_list; pmc; pmc=pmc->next) {
	for_each_pmc_rcu(inet6, pmc) {
		if (pgsr->gsr_interface && pmc->ifindex != pgsr->gsr_interface)
			continue;
		if (ipv6_addr_equal(&pmc->addr, group))
@@ -428,7 +444,6 @@ int ip6_mc_source(int add, int omode, struct sock *sk,
done:
	if (pmclocked)
		write_unlock(&pmc->sflock);
	read_unlock(&ipv6_sk_mc_lock);
	read_unlock_bh(&idev->lock);
	rcu_read_unlock();
	if (leavegroup)
@@ -466,14 +481,13 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf)
	dev = idev->dev;

	err = 0;
	read_lock(&ipv6_sk_mc_lock);

	if (gsf->gf_fmode == MCAST_INCLUDE && gsf->gf_numsrc == 0) {
		leavegroup = 1;
		goto done;
	}

	for (pmc=inet6->ipv6_mc_list; pmc; pmc=pmc->next) {
	for_each_pmc_rcu(inet6, pmc) {
		if (pmc->ifindex != gsf->gf_interface)
			continue;
		if (ipv6_addr_equal(&pmc->addr, group))
@@ -521,7 +535,6 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf)
	write_unlock(&pmc->sflock);
	err = 0;
done:
	read_unlock(&ipv6_sk_mc_lock);
	read_unlock_bh(&idev->lock);
	rcu_read_unlock();
	if (leavegroup)
@@ -562,7 +575,7 @@ int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf,
	 * so reading the list is safe.
	 */

	for (pmc=inet6->ipv6_mc_list; pmc; pmc=pmc->next) {
	for_each_pmc_rcu(inet6, pmc) {
		if (pmc->ifindex != gsf->gf_interface)
			continue;
		if (ipv6_addr_equal(group, &pmc->addr))
@@ -612,13 +625,13 @@ int inet6_mc_check(struct sock *sk, const struct in6_addr *mc_addr,
	struct ip6_sf_socklist *psl;
	int rv = 1;

	read_lock(&ipv6_sk_mc_lock);
	for (mc = np->ipv6_mc_list; mc; mc = mc->next) {
	rcu_read_lock();
	for_each_pmc_rcu(np, mc) {
		if (ipv6_addr_equal(&mc->addr, mc_addr))
			break;
	}
	if (!mc) {
		read_unlock(&ipv6_sk_mc_lock);
		rcu_read_unlock();
		return 1;
	}
	read_lock(&mc->sflock);
@@ -638,7 +651,7 @@ int inet6_mc_check(struct sock *sk, const struct in6_addr *mc_addr,
			rv = 0;
	}
	read_unlock(&mc->sflock);
	read_unlock(&ipv6_sk_mc_lock);
	rcu_read_unlock();

	return rv;
}