Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit f7fa9b10 authored by Patrick McHardy's avatar Patrick McHardy Committed by David S. Miller
Browse files

[NETLINK]: Support dynamic number of multicast groups per netlink family

parent ab33a171
Loading
Loading
Loading
Loading
+51 −18
Original line number Diff line number Diff line
@@ -60,21 +60,24 @@
#include <net/scm.h>

#define Nprintk(a...)
#define NLGRPSZ(x)	(ALIGN(x, sizeof(unsigned long) * 8) / 8)

struct netlink_sock {
	/* struct sock has to be the first member of netlink_sock */
	struct sock		sk;
	u32			pid;
	unsigned int		groups;
	u32			dst_pid;
	u32			dst_group;
	u32			flags;
	u32			subscriptions;
	u32			ngroups;
	unsigned long		*groups;
	unsigned long		state;
	wait_queue_head_t	wait;
	struct netlink_callback	*cb;
	spinlock_t		cb_lock;
	void			(*data_ready)(struct sock *sk, int bytes);
	struct module		*module;
	u32			flags;
};

#define NETLINK_KERNEL_SOCKET	0x1
@@ -101,6 +104,7 @@ struct netlink_table {
	struct nl_pid_hash hash;
	struct hlist_head mc_list;
	unsigned int nl_nonroot;
	unsigned int groups;
	struct module *module;
	int registered;
};
@@ -138,6 +142,7 @@ static void netlink_sock_destruct(struct sock *sk)
	BUG_TRAP(!atomic_read(&sk->sk_rmem_alloc));
	BUG_TRAP(!atomic_read(&sk->sk_wmem_alloc));
	BUG_TRAP(!nlk_sk(sk)->cb);
	BUG_TRAP(!nlk_sk(sk)->groups);
}

/* This lock without WQ_FLAG_EXCLUSIVE is good on UP and it is _very_ bad on SMP.
@@ -333,7 +338,7 @@ static void netlink_remove(struct sock *sk)
	netlink_table_grab();
	if (sk_del_node_init(sk))
		nl_table[sk->sk_protocol].hash.entries--;
	if (nlk_sk(sk)->groups)
	if (nlk_sk(sk)->subscriptions)
		__sk_del_bind_node(sk);
	netlink_table_ungrab();
}
@@ -369,6 +374,8 @@ static int __netlink_create(struct socket *sock, int protocol)
static int netlink_create(struct socket *sock, int protocol)
{
	struct module *module = NULL;
	struct netlink_sock *nlk;
	unsigned int groups;
	int err = 0;

	sock->state = SS_UNCONNECTED;
@@ -392,15 +399,23 @@ static int netlink_create(struct socket *sock, int protocol)
		module = nl_table[protocol].module;
	else
		err = -EPROTONOSUPPORT;
	groups = nl_table[protocol].groups;
	netlink_unlock_table();

	if (err)
		goto out;
	if (err || (err = __netlink_create(sock, protocol) < 0))
		goto out_module;

	if ((err = __netlink_create(sock, protocol) < 0))
	nlk = nlk_sk(sock->sk);

	nlk->groups = kmalloc(NLGRPSZ(groups), GFP_KERNEL);
	if (nlk->groups == NULL) {
		err = -ENOMEM;
		goto out_module;
	}
	memset(nlk->groups, 0, NLGRPSZ(groups));
	nlk->ngroups = groups;

	nlk_sk(sock->sk)->module = module;
	nlk->module = module;
out:
	return err;

@@ -437,7 +452,7 @@ static int netlink_release(struct socket *sock)

	skb_queue_purge(&sk->sk_write_queue);

	if (nlk->pid && !nlk->groups) {
	if (nlk->pid && !nlk->subscriptions) {
		struct netlink_notify n = {
						.protocol = sk->sk_protocol,
						.pid = nlk->pid,
@@ -455,6 +470,9 @@ static int netlink_release(struct socket *sock)
		netlink_table_ungrab();
	}

	kfree(nlk->groups);
	nlk->groups = NULL;

	sock_put(sk);
	return 0;
}
@@ -503,6 +521,18 @@ static inline int netlink_capable(struct socket *sock, unsigned int flag)
	       capable(CAP_NET_ADMIN);
} 

static void
netlink_update_subscriptions(struct sock *sk, unsigned int subscriptions)
{
	struct netlink_sock *nlk = nlk_sk(sk);

	if (nlk->subscriptions && !subscriptions)
		__sk_del_bind_node(sk);
	else if (!nlk->subscriptions && subscriptions)
		sk_add_bind_node(sk, &nl_table[sk->sk_protocol].mc_list);
	nlk->subscriptions = subscriptions;
}

static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
{
	struct sock *sk = sock->sk;
@@ -528,15 +558,14 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len
			return err;
	}

	if (!nladdr->nl_groups && !nlk->groups)
	if (!nladdr->nl_groups && !(u32)nlk->groups[0])
		return 0;

	netlink_table_grab();
	if (nlk->groups && !nladdr->nl_groups)
		__sk_del_bind_node(sk);
	else if (!nlk->groups && nladdr->nl_groups)
		sk_add_bind_node(sk, &nl_table[sk->sk_protocol].mc_list);
	nlk->groups = nladdr->nl_groups;
	netlink_update_subscriptions(sk, nlk->subscriptions +
	                                 hweight32(nladdr->nl_groups) -
	                                 hweight32(nlk->groups[0]));
	nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | nladdr->nl_groups; 
	netlink_table_ungrab();

	return 0;
@@ -590,7 +619,7 @@ static int netlink_getname(struct socket *sock, struct sockaddr *addr, int *addr
		nladdr->nl_groups = netlink_group_mask(nlk->dst_group);
	} else {
		nladdr->nl_pid = nlk->pid;
		nladdr->nl_groups = nlk->groups; 
		nladdr->nl_groups = nlk->groups[0];
	}
	return 0;
}
@@ -791,7 +820,8 @@ static inline int do_one_broadcast(struct sock *sk,
	if (p->exclude_sk == sk)
		goto out;

	if (nlk->pid == p->pid || !(nlk->groups & netlink_group_mask(p->group)))
	if (nlk->pid == p->pid || p->group - 1 >= nlk->ngroups ||
	    !test_bit(p->group - 1, nlk->groups))
		goto out;

	if (p->failure) {
@@ -887,7 +917,8 @@ static inline int do_one_set_err(struct sock *sk,
	if (sk == p->exclude_sk)
		goto out;

	if (nlk->pid == p->pid || !(nlk->groups & netlink_group_mask(p->group)))
	if (nlk->pid == p->pid || p->group - 1 >= nlk->ngroups ||
	    !test_bit(p->group - 1, nlk->groups))
		goto out;

	sk->sk_err = p->code;
@@ -1112,6 +1143,7 @@ netlink_kernel_create(int unit, void (*input)(struct sock *sk, int len), struct
	nlk->flags |= NETLINK_KERNEL_SOCKET;

	netlink_table_grab();
	nl_table[unit].groups = 32;
	nl_table[unit].module = module;
	nl_table[unit].registered = 1;
	netlink_table_ungrab();
@@ -1358,7 +1390,8 @@ static int netlink_seq_show(struct seq_file *seq, void *v)
			   s,
			   s->sk_protocol,
			   nlk->pid,
			   nlk->groups,
			   nlk->flags & NETLINK_KERNEL_SOCKET ?
				0 : (unsigned int)nlk->groups[0],
			   atomic_read(&s->sk_rmem_alloc),
			   atomic_read(&s->sk_wmem_alloc),
			   nlk->cb,