Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 85f8c13e authored by Jozsef Kadlecsik's avatar Jozsef Kadlecsik
Browse files

netfilter: ipset: Rewrite cidr book keeping to handle /0

parent b9fed748
Loading
Loading
Loading
Loading
+55 −49
Original line number Original line Diff line number Diff line
@@ -137,50 +137,59 @@ htable_bits(u32 hashsize)
#endif
#endif


#define SET_HOST_MASK(family)	(family == AF_INET ? 32 : 128)
#define SET_HOST_MASK(family)	(family == AF_INET ? 32 : 128)
#ifdef IP_SET_HASH_WITH_MULTI
#define NETS_LENGTH(family)	(SET_HOST_MASK(family) + 1)
#else
#define NETS_LENGTH(family)	SET_HOST_MASK(family)
#endif


/* Network cidr size book keeping when the hash stores different
/* Network cidr size book keeping when the hash stores different
 * sized networks */
 * sized networks */
static void
static void
add_cidr(struct ip_set_hash *h, u8 cidr, u8 host_mask)
add_cidr(struct ip_set_hash *h, u8 cidr, u8 nets_length)
{
{
	u8 i;
	int i, j;

	++h->nets[cidr-1].nets;

	pr_debug("add_cidr added %u: %u\n", cidr, h->nets[cidr-1].nets);


	if (h->nets[cidr-1].nets > 1)
		return;

	/* New cidr size */
	for (i = 0; i < host_mask && h->nets[i].cidr; i++) {
	/* Add in increasing prefix order, so larger cidr first */
	/* Add in increasing prefix order, so larger cidr first */
		if (h->nets[i].cidr < cidr)
	for (i = 0, j = -1; i < nets_length && h->nets[i].nets; i++) {
			swap(h->nets[i].cidr, cidr);
		if (j != -1)
			continue;
		else if (h->nets[i].cidr < cidr)
			j = i;
		else if (h->nets[i].cidr == cidr) {
			h->nets[i].nets++;
			return;
		}
	}
	if (j != -1) {
		for (; i > j; i--) {
			h->nets[i].cidr = h->nets[i - 1].cidr;
			h->nets[i].nets = h->nets[i - 1].nets;
		}
	}
	}
	if (i < host_mask)
	h->nets[i].cidr = cidr;
	h->nets[i].cidr = cidr;
	h->nets[i].nets = 1;
}
}


static void
static void
del_cidr(struct ip_set_hash *h, u8 cidr, u8 host_mask)
del_cidr(struct ip_set_hash *h, u8 cidr, u8 nets_length)
{
{
	u8 i;
	u8 i, j;

	--h->nets[cidr-1].nets;


	pr_debug("del_cidr deleted %u: %u\n", cidr, h->nets[cidr-1].nets);
	for (i = 0; i < nets_length - 1 && h->nets[i].cidr != cidr; i++)
		;
	h->nets[i].nets--;


	if (h->nets[cidr-1].nets != 0)
	if (h->nets[i].nets != 0)
		return;
		return;


	/* All entries with this cidr size deleted, so cleanup h->cidr[] */
	for (j = i; j < nets_length - 1 && h->nets[j].nets; j++) {
	for (i = 0; i < host_mask - 1 && h->nets[i].cidr; i++) {
		h->nets[j].cidr = h->nets[j + 1].cidr;
		if (h->nets[i].cidr == cidr)
		h->nets[j].nets = h->nets[j + 1].nets;
			h->nets[i].cidr = cidr = h->nets[i+1].cidr;
	}
	}
	h->nets[i - 1].cidr = 0;
}
}
#else
#define NETS_LENGTH(family)		0
#endif
#endif


/* Destroy the hashtable part of the set */
/* Destroy the hashtable part of the set */
@@ -202,14 +211,14 @@ ahash_destroy(struct htable *t)


/* Calculate the actual memory size of the set data */
/* Calculate the actual memory size of the set data */
static size_t
static size_t
ahash_memsize(const struct ip_set_hash *h, size_t dsize, u8 host_mask)
ahash_memsize(const struct ip_set_hash *h, size_t dsize, u8 nets_length)
{
{
	u32 i;
	u32 i;
	struct htable *t = h->table;
	struct htable *t = h->table;
	size_t memsize = sizeof(*h)
	size_t memsize = sizeof(*h)
			 + sizeof(*t)
			 + sizeof(*t)
#ifdef IP_SET_HASH_WITH_NETS
#ifdef IP_SET_HASH_WITH_NETS
			 + sizeof(struct ip_set_hash_nets) * host_mask
			 + sizeof(struct ip_set_hash_nets) * nets_length
#endif
#endif
			 + jhash_size(t->htable_bits) * sizeof(struct hbucket);
			 + jhash_size(t->htable_bits) * sizeof(struct hbucket);


@@ -238,7 +247,7 @@ ip_set_hash_flush(struct ip_set *set)
	}
	}
#ifdef IP_SET_HASH_WITH_NETS
#ifdef IP_SET_HASH_WITH_NETS
	memset(h->nets, 0, sizeof(struct ip_set_hash_nets)
	memset(h->nets, 0, sizeof(struct ip_set_hash_nets)
			   * SET_HOST_MASK(set->family));
			   * NETS_LENGTH(set->family));
#endif
#endif
	h->elements = 0;
	h->elements = 0;
}
}
@@ -271,9 +280,6 @@ ip_set_hash_destroy(struct ip_set *set)
(jhash2((u32 *)(data), HKEY_DATALEN/sizeof(u32), initval)	\
(jhash2((u32 *)(data), HKEY_DATALEN/sizeof(u32), initval)	\
	& jhash_mask(htable_bits))
	& jhash_mask(htable_bits))


#define CONCAT(a, b, c)		a##b##c
#define TOKEN(a, b, c)		CONCAT(a, b, c)

/* Type/family dependent function prototypes */
/* Type/family dependent function prototypes */


#define type_pf_data_equal	TOKEN(TYPE, PF, _data_equal)
#define type_pf_data_equal	TOKEN(TYPE, PF, _data_equal)
@@ -478,7 +484,7 @@ type_pf_add(struct ip_set *set, void *value, u32 timeout, u32 flags)
	}
	}


#ifdef IP_SET_HASH_WITH_NETS
#ifdef IP_SET_HASH_WITH_NETS
	add_cidr(h, CIDR(d->cidr), HOST_MASK);
	add_cidr(h, CIDR(d->cidr), NETS_LENGTH(set->family));
#endif
#endif
	h->elements++;
	h->elements++;
out:
out:
@@ -513,7 +519,7 @@ type_pf_del(struct ip_set *set, void *value, u32 timeout, u32 flags)
		n->pos--;
		n->pos--;
		h->elements--;
		h->elements--;
#ifdef IP_SET_HASH_WITH_NETS
#ifdef IP_SET_HASH_WITH_NETS
		del_cidr(h, CIDR(d->cidr), HOST_MASK);
		del_cidr(h, CIDR(d->cidr), NETS_LENGTH(set->family));
#endif
#endif
		if (n->pos + AHASH_INIT_SIZE < n->size) {
		if (n->pos + AHASH_INIT_SIZE < n->size) {
			void *tmp = kzalloc((n->size - AHASH_INIT_SIZE)
			void *tmp = kzalloc((n->size - AHASH_INIT_SIZE)
@@ -546,10 +552,10 @@ type_pf_test_cidrs(struct ip_set *set, struct type_pf_elem *d, u32 timeout)
	const struct type_pf_elem *data;
	const struct type_pf_elem *data;
	int i, j = 0;
	int i, j = 0;
	u32 key, multi = 0;
	u32 key, multi = 0;
	u8 host_mask = SET_HOST_MASK(set->family);
	u8 nets_length = NETS_LENGTH(set->family);


	pr_debug("test by nets\n");
	pr_debug("test by nets\n");
	for (; j < host_mask && h->nets[j].cidr && !multi; j++) {
	for (; j < nets_length && h->nets[j].nets && !multi; j++) {
		type_pf_data_netmask(d, h->nets[j].cidr);
		type_pf_data_netmask(d, h->nets[j].cidr);
		key = HKEY(d, h->initval, t->htable_bits);
		key = HKEY(d, h->initval, t->htable_bits);
		n = hbucket(t, key);
		n = hbucket(t, key);
@@ -604,7 +610,7 @@ type_pf_head(struct ip_set *set, struct sk_buff *skb)
	memsize = ahash_memsize(h, with_timeout(h->timeout)
	memsize = ahash_memsize(h, with_timeout(h->timeout)
					? sizeof(struct type_pf_telem)
					? sizeof(struct type_pf_telem)
					: sizeof(struct type_pf_elem),
					: sizeof(struct type_pf_elem),
				set->family == AF_INET ? 32 : 128);
				NETS_LENGTH(set->family));
	read_unlock_bh(&set->lock);
	read_unlock_bh(&set->lock);


	nested = ipset_nest_start(skb, IPSET_ATTR_DATA);
	nested = ipset_nest_start(skb, IPSET_ATTR_DATA);
@@ -783,7 +789,7 @@ type_pf_elem_tadd(struct hbucket *n, const struct type_pf_elem *value,


/* Delete expired elements from the hashtable */
/* Delete expired elements from the hashtable */
static void
static void
type_pf_expire(struct ip_set_hash *h)
type_pf_expire(struct ip_set_hash *h, u8 nets_length)
{
{
	struct htable *t = h->table;
	struct htable *t = h->table;
	struct hbucket *n;
	struct hbucket *n;
@@ -798,7 +804,7 @@ type_pf_expire(struct ip_set_hash *h)
			if (type_pf_data_expired(data)) {
			if (type_pf_data_expired(data)) {
				pr_debug("expired %u/%u\n", i, j);
				pr_debug("expired %u/%u\n", i, j);
#ifdef IP_SET_HASH_WITH_NETS
#ifdef IP_SET_HASH_WITH_NETS
				del_cidr(h, CIDR(data->cidr), HOST_MASK);
				del_cidr(h, CIDR(data->cidr), nets_length);
#endif
#endif
				if (j != n->pos - 1)
				if (j != n->pos - 1)
					/* Not last one */
					/* Not last one */
@@ -839,7 +845,7 @@ type_pf_tresize(struct ip_set *set, bool retried)
	if (!retried) {
	if (!retried) {
		i = h->elements;
		i = h->elements;
		write_lock_bh(&set->lock);
		write_lock_bh(&set->lock);
		type_pf_expire(set->data);
		type_pf_expire(set->data, NETS_LENGTH(set->family));
		write_unlock_bh(&set->lock);
		write_unlock_bh(&set->lock);
		if (h->elements <  i)
		if (h->elements <  i)
			return 0;
			return 0;
@@ -904,7 +910,7 @@ type_pf_tadd(struct ip_set *set, void *value, u32 timeout, u32 flags)


	if (h->elements >= h->maxelem)
	if (h->elements >= h->maxelem)
		/* FIXME: when set is full, we slow down here */
		/* FIXME: when set is full, we slow down here */
		type_pf_expire(h);
		type_pf_expire(h, NETS_LENGTH(set->family));
	if (h->elements >= h->maxelem) {
	if (h->elements >= h->maxelem) {
		if (net_ratelimit())
		if (net_ratelimit())
			pr_warning("Set %s is full, maxelem %u reached\n",
			pr_warning("Set %s is full, maxelem %u reached\n",
@@ -933,8 +939,8 @@ type_pf_tadd(struct ip_set *set, void *value, u32 timeout, u32 flags)
	if (j != AHASH_MAX(h) + 1) {
	if (j != AHASH_MAX(h) + 1) {
		data = ahash_tdata(n, j);
		data = ahash_tdata(n, j);
#ifdef IP_SET_HASH_WITH_NETS
#ifdef IP_SET_HASH_WITH_NETS
		del_cidr(h, CIDR(data->cidr), HOST_MASK);
		del_cidr(h, CIDR(data->cidr), NETS_LENGTH(set->family));
		add_cidr(h, CIDR(d->cidr), HOST_MASK);
		add_cidr(h, CIDR(d->cidr), NETS_LENGTH(set->family));
#endif
#endif
		type_pf_data_copy(data, d);
		type_pf_data_copy(data, d);
		type_pf_data_timeout_set(data, timeout);
		type_pf_data_timeout_set(data, timeout);
@@ -952,7 +958,7 @@ type_pf_tadd(struct ip_set *set, void *value, u32 timeout, u32 flags)
	}
	}


#ifdef IP_SET_HASH_WITH_NETS
#ifdef IP_SET_HASH_WITH_NETS
	add_cidr(h, CIDR(d->cidr), HOST_MASK);
	add_cidr(h, CIDR(d->cidr), NETS_LENGTH(set->family));
#endif
#endif
	h->elements++;
	h->elements++;
out:
out:
@@ -986,7 +992,7 @@ type_pf_tdel(struct ip_set *set, void *value, u32 timeout, u32 flags)
		n->pos--;
		n->pos--;
		h->elements--;
		h->elements--;
#ifdef IP_SET_HASH_WITH_NETS
#ifdef IP_SET_HASH_WITH_NETS
		del_cidr(h, CIDR(d->cidr), HOST_MASK);
		del_cidr(h, CIDR(d->cidr), NETS_LENGTH(set->family));
#endif
#endif
		if (n->pos + AHASH_INIT_SIZE < n->size) {
		if (n->pos + AHASH_INIT_SIZE < n->size) {
			void *tmp = kzalloc((n->size - AHASH_INIT_SIZE)
			void *tmp = kzalloc((n->size - AHASH_INIT_SIZE)
@@ -1016,9 +1022,9 @@ type_pf_ttest_cidrs(struct ip_set *set, struct type_pf_elem *d, u32 timeout)
	struct hbucket *n;
	struct hbucket *n;
	int i, j = 0;
	int i, j = 0;
	u32 key, multi = 0;
	u32 key, multi = 0;
	u8 host_mask = SET_HOST_MASK(set->family);
	u8 nets_length = NETS_LENGTH(set->family);


	for (; j < host_mask && h->nets[j].cidr && !multi; j++) {
	for (; j < nets_length && h->nets[j].nets && !multi; j++) {
		type_pf_data_netmask(d, h->nets[j].cidr);
		type_pf_data_netmask(d, h->nets[j].cidr);
		key = HKEY(d, h->initval, t->htable_bits);
		key = HKEY(d, h->initval, t->htable_bits);
		n = hbucket(t, key);
		n = hbucket(t, key);
@@ -1147,7 +1153,7 @@ type_pf_gc(unsigned long ul_set)


	pr_debug("called\n");
	pr_debug("called\n");
	write_lock_bh(&set->lock);
	write_lock_bh(&set->lock);
	type_pf_expire(h);
	type_pf_expire(h, NETS_LENGTH(set->family));
	write_unlock_bh(&set->lock);
	write_unlock_bh(&set->lock);


	h->gc.expires = jiffies + IPSET_GC_PERIOD(h->timeout) * HZ;
	h->gc.expires = jiffies + IPSET_GC_PERIOD(h->timeout) * HZ;