Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit fb778ea1 authored by Marek Lindner's avatar Marek Lindner
Browse files

batman-adv: protect each hash row with rcu locks

parent a775eb84
Loading
Loading
Loading
Loading
+25 −9
Original line number Diff line number Diff line
@@ -27,13 +27,16 @@ static void hash_init(struct hashtable_t *hash)
{
	int i;

	for (i = 0 ; i < hash->size; i++)
	for (i = 0 ; i < hash->size; i++) {
		INIT_HLIST_HEAD(&hash->table[i]);
		spin_lock_init(&hash->list_locks[i]);
	}
}

/* free only the hashtable and the hash itself. */
void hash_destroy(struct hashtable_t *hash)
{
	kfree(hash->list_locks);
	kfree(hash->table);
	kfree(hash);
}
@@ -44,19 +47,32 @@ struct hashtable_t *hash_new(int size)
	struct hashtable_t *hash;

	hash = kmalloc(sizeof(struct hashtable_t), GFP_ATOMIC);

	if (!hash)
		return NULL;

	hash->size = size;
	hash->table = kmalloc(sizeof(struct element_t *) * size, GFP_ATOMIC);
	if (!hash->table)
		goto free_hash;

	hash->list_locks = kmalloc(sizeof(spinlock_t) * size, GFP_ATOMIC);
	if (!hash->list_locks)
		goto free_table;

	hash->size = size;
	hash_init(hash);
	return hash;

	if (!hash->table) {
free_table:
	kfree(hash->table);
free_hash:
	kfree(hash);
	return NULL;
}

	hash_init(hash);
void bucket_free_rcu(struct rcu_head *rcu)
{
	struct element_t *bucket;

	return hash;
	bucket = container_of(rcu, struct element_t, rcu);
	kfree(bucket);
}
+49 −24
Original line number Diff line number Diff line
@@ -39,10 +39,12 @@ typedef void (*hashdata_free_cb)(void *, void *);
struct element_t {
	void *data;		/* pointer to the data */
	struct hlist_node hlist;	/* bucket list pointer */
	struct rcu_head rcu;
};

struct hashtable_t {
	struct hlist_head *table;   /* the hashtable itself, with the buckets */
	struct hlist_head *table;   /* the hashtable itself with the buckets */
	spinlock_t *list_locks;     /* spinlock for each hash list entry */
	int size;		    /* size of hashtable */
};

@@ -52,6 +54,8 @@ struct hashtable_t *hash_new(int size);
/* free only the hashtable and the hash itself. */
void hash_destroy(struct hashtable_t *hash);

void bucket_free_rcu(struct rcu_head *rcu);

/* remove the hash structure. if hashdata_free_cb != NULL, this function will be
 * called to remove the elements inside of the hash.  if you don't remove the
 * elements, memory might be leaked. */
@@ -61,19 +65,22 @@ static inline void hash_delete(struct hashtable_t *hash,
	struct hlist_head *head;
	struct hlist_node *walk, *safe;
	struct element_t *bucket;
	spinlock_t *list_lock; /* spinlock to protect write access */
	int i;

	for (i = 0; i < hash->size; i++) {
		head = &hash->table[i];
		list_lock = &hash->list_locks[i];

		hlist_for_each_safe(walk, safe, head) {
			bucket = hlist_entry(walk, struct element_t, hlist);
		spin_lock_bh(list_lock);
		hlist_for_each_entry_safe(bucket, walk, safe, head, hlist) {
			if (free_cb)
				free_cb(bucket->data, arg);

			hlist_del(walk);
			kfree(bucket);
			hlist_del_rcu(walk);
			call_rcu(&bucket->rcu, bucket_free_rcu);
		}
		spin_unlock_bh(list_lock);
	}

	hash_destroy(hash);
@@ -88,29 +95,39 @@ static inline int hash_add(struct hashtable_t *hash,
	struct hlist_head *head;
	struct hlist_node *walk, *safe;
	struct element_t *bucket;
	spinlock_t *list_lock; /* spinlock to protect write access */

	if (!hash)
		return -1;
		goto err;

	index = choose(data, hash->size);
	head = &hash->table[index];
	list_lock = &hash->list_locks[index];

	hlist_for_each_safe(walk, safe, head) {
		bucket = hlist_entry(walk, struct element_t, hlist);
	rcu_read_lock();
	hlist_for_each_entry_safe(bucket, walk, safe, head, hlist) {
		if (compare(bucket->data, data))
			return -1;
			goto err_unlock;
	}
	rcu_read_unlock();

	/* no duplicate found in list, add new element */
	bucket = kmalloc(sizeof(struct element_t), GFP_ATOMIC);

	if (!bucket)
		return -1;
		goto err;

	bucket->data = data;
	hlist_add_head(&bucket->hlist, head);

	spin_lock_bh(list_lock);
	hlist_add_head_rcu(&bucket->hlist, head);
	spin_unlock_bh(list_lock);

	return 0;

err_unlock:
	rcu_read_unlock();
err:
	return -1;
}

/* removes data from hash, if found. returns pointer do data on success, so you
@@ -125,25 +142,31 @@ static inline void *hash_remove(struct hashtable_t *hash,
	struct hlist_node *walk;
	struct element_t *bucket;
	struct hlist_head *head;
	void *data_save;
	void *data_save = NULL;

	index = choose(data, hash->size);
	head = &hash->table[index];

	spin_lock_bh(&hash->list_locks[index]);
	hlist_for_each_entry(bucket, walk, head, hlist) {
		if (compare(bucket->data, data)) {
			data_save = bucket->data;
			hlist_del(walk);
			kfree(bucket);
			return data_save;
			hlist_del_rcu(walk);
			call_rcu(&bucket->rcu, bucket_free_rcu);
			break;
		}
	}
	spin_unlock_bh(&hash->list_locks[index]);

	return NULL;
	return data_save;
}

/* finds data, based on the key in keydata. returns the found data on success,
 * or NULL on error */
/**
 * finds data, based on the key in keydata. returns the found data on success,
 * or NULL on error
 *
 * caller must lock with rcu_read_lock() / rcu_read_unlock()
 **/
static inline void *hash_find(struct hashtable_t *hash,
			      hashdata_compare_cb compare,
			      hashdata_choose_cb choose, void *keydata)
@@ -152,6 +175,7 @@ static inline void *hash_find(struct hashtable_t *hash,
	struct hlist_head *head;
	struct hlist_node *walk;
	struct element_t *bucket;
	void *bucket_data = NULL;

	if (!hash)
		return NULL;
@@ -159,13 +183,14 @@ static inline void *hash_find(struct hashtable_t *hash,
	index = choose(keydata , hash->size);
	head = &hash->table[index];

	hlist_for_each(walk, head) {
		bucket = hlist_entry(walk, struct element_t, hlist);
		if (compare(bucket->data, keydata))
			return bucket->data;
	hlist_for_each_entry(bucket, walk, head, hlist) {
		if (compare(bucket->data, keydata)) {
			bucket_data = bucket->data;
			break;
		}
	}

	return NULL;
	return bucket_data;
}

#endif /* _NET_BATMAN_ADV_HASH_H_ */
+2 −0
Original line number Diff line number Diff line
@@ -220,9 +220,11 @@ static ssize_t bat_socket_write(struct file *file, const char __user *buff,
		goto dst_unreach;

	spin_lock_bh(&bat_priv->orig_hash_lock);
	rcu_read_lock();
	orig_node = ((struct orig_node *)hash_find(bat_priv->orig_hash,
						   compare_orig, choose_orig,
						   icmp_packet->dst));
	rcu_read_unlock();

	if (!orig_node)
		goto unlock;
+20 −7
Original line number Diff line number Diff line
@@ -150,9 +150,11 @@ struct orig_node *get_orig_node(struct bat_priv *bat_priv, uint8_t *addr)
	int size;
	int hash_added;

	rcu_read_lock();
	orig_node = ((struct orig_node *)hash_find(bat_priv->orig_hash,
						   compare_orig, choose_orig,
						   addr));
	rcu_read_unlock();

	if (orig_node)
		return orig_node;
@@ -294,6 +296,7 @@ static void _purge_orig(struct bat_priv *bat_priv)
	struct hlist_node *walk, *safe;
	struct hlist_head *head;
	struct element_t *bucket;
	spinlock_t *list_lock; /* spinlock to protect write access */
	struct orig_node *orig_node;
	int i;

@@ -305,22 +308,26 @@ static void _purge_orig(struct bat_priv *bat_priv)
	/* for all origins... */
	for (i = 0; i < hash->size; i++) {
		head = &hash->table[i];
		list_lock = &hash->list_locks[i];

		spin_lock_bh(list_lock);
		hlist_for_each_entry_safe(bucket, walk, safe, head, hlist) {
			orig_node = bucket->data;

			if (purge_orig_node(bat_priv, orig_node)) {
				if (orig_node->gw_flags)
					gw_node_delete(bat_priv, orig_node);
				hlist_del(walk);
				kfree(bucket);
				hlist_del_rcu(walk);
				call_rcu(&bucket->rcu, bucket_free_rcu);
				free_orig_node(orig_node, bat_priv);
				continue;
			}

			if (time_after(jiffies, orig_node->last_frag_packet +
						msecs_to_jiffies(FRAG_TIMEOUT)))
				frag_list_free(&orig_node->frag_list);
		}
		spin_unlock_bh(list_lock);
	}

	spin_unlock_bh(&bat_priv->orig_hash_lock);
@@ -387,7 +394,8 @@ int orig_seq_print_text(struct seq_file *seq, void *offset)
	for (i = 0; i < hash->size; i++) {
		head = &hash->table[i];

		hlist_for_each_entry(bucket, walk, head, hlist) {
		rcu_read_lock();
		hlist_for_each_entry_rcu(bucket, walk, head, hlist) {
			orig_node = bucket->data;

			if (!orig_node->router)
@@ -408,17 +416,16 @@ int orig_seq_print_text(struct seq_file *seq, void *offset)
				   neigh_node->addr,
				   neigh_node->if_incoming->net_dev->name);

			rcu_read_lock();
			hlist_for_each_entry_rcu(neigh_node, node,
						 &orig_node->neigh_list, list) {
				seq_printf(seq, " %pM (%3i)", neigh_node->addr,
						neigh_node->tq_avg);
			}
			rcu_read_unlock();

			seq_printf(seq, "\n");
			batman_count++;
		}
		rcu_read_unlock();
	}

	spin_unlock_bh(&bat_priv->orig_hash_lock);
@@ -476,18 +483,21 @@ int orig_hash_add_if(struct batman_if *batman_if, int max_if_num)
	for (i = 0; i < hash->size; i++) {
		head = &hash->table[i];

		hlist_for_each_entry(bucket, walk, head, hlist) {
		rcu_read_lock();
		hlist_for_each_entry_rcu(bucket, walk, head, hlist) {
			orig_node = bucket->data;

			if (orig_node_add_if(orig_node, max_if_num) == -1)
				goto err;
		}
		rcu_read_unlock();
	}

	spin_unlock_bh(&bat_priv->orig_hash_lock);
	return 0;

err:
	rcu_read_unlock();
	spin_unlock_bh(&bat_priv->orig_hash_lock);
	return -ENOMEM;
}
@@ -562,7 +572,8 @@ int orig_hash_del_if(struct batman_if *batman_if, int max_if_num)
	for (i = 0; i < hash->size; i++) {
		head = &hash->table[i];

		hlist_for_each_entry(bucket, walk, head, hlist) {
		rcu_read_lock();
		hlist_for_each_entry_rcu(bucket, walk, head, hlist) {
			orig_node = bucket->data;

			ret = orig_node_del_if(orig_node, max_if_num,
@@ -571,6 +582,7 @@ int orig_hash_del_if(struct batman_if *batman_if, int max_if_num)
			if (ret == -1)
				goto err;
		}
		rcu_read_unlock();
	}

	/* renumber remaining batman interfaces _inside_ of orig_hash_lock */
@@ -595,6 +607,7 @@ int orig_hash_del_if(struct batman_if *batman_if, int max_if_num)
	return 0;

err:
	rcu_read_unlock();
	spin_unlock_bh(&bat_priv->orig_hash_lock);
	return -ENOMEM;
}
+15 −1
Original line number Diff line number Diff line
@@ -52,7 +52,8 @@ void slide_own_bcast_window(struct batman_if *batman_if)
	for (i = 0; i < hash->size; i++) {
		head = &hash->table[i];

		hlist_for_each_entry(bucket, walk, head, hlist) {
		rcu_read_lock();
		hlist_for_each_entry_rcu(bucket, walk, head, hlist) {
			orig_node = bucket->data;
			word_index = batman_if->if_num * NUM_WORDS;
			word = &(orig_node->bcast_own[word_index]);
@@ -61,6 +62,7 @@ void slide_own_bcast_window(struct batman_if *batman_if)
			orig_node->bcast_own_sum[batman_if->if_num] =
				bit_packet_count(word);
		}
		rcu_read_unlock();
	}

	spin_unlock_bh(&bat_priv->orig_hash_lock);
@@ -873,9 +875,11 @@ static int recv_my_icmp_packet(struct bat_priv *bat_priv,
	/* answer echo request (ping) */
	/* get routing information */
	spin_lock_bh(&bat_priv->orig_hash_lock);
	rcu_read_lock();
	orig_node = ((struct orig_node *)hash_find(bat_priv->orig_hash,
						   compare_orig, choose_orig,
						   icmp_packet->orig));
	rcu_read_unlock();
	ret = NET_RX_DROP;

	if ((orig_node) && (orig_node->router)) {
@@ -931,9 +935,11 @@ static int recv_icmp_ttl_exceeded(struct bat_priv *bat_priv,

	/* get routing information */
	spin_lock_bh(&bat_priv->orig_hash_lock);
	rcu_read_lock();
	orig_node = ((struct orig_node *)
		     hash_find(bat_priv->orig_hash, compare_orig, choose_orig,
			       icmp_packet->orig));
	rcu_read_unlock();
	ret = NET_RX_DROP;

	if ((orig_node) && (orig_node->router)) {
@@ -1023,9 +1029,11 @@ int recv_icmp_packet(struct sk_buff *skb, struct batman_if *recv_if)

	/* get routing information */
	spin_lock_bh(&bat_priv->orig_hash_lock);
	rcu_read_lock();
	orig_node = ((struct orig_node *)
		     hash_find(bat_priv->orig_hash, compare_orig, choose_orig,
			       icmp_packet->dst));
	rcu_read_unlock();

	if ((orig_node) && (orig_node->router)) {

@@ -1094,9 +1102,11 @@ struct neigh_node *find_router(struct bat_priv *bat_priv,
				router_orig->orig, ETH_ALEN) == 0) {
		primary_orig_node = router_orig;
	} else {
		rcu_read_lock();
		primary_orig_node = hash_find(bat_priv->orig_hash, compare_orig,
					       choose_orig,
					       router_orig->primary_addr);
		rcu_read_unlock();

		if (!primary_orig_node)
			return orig_node->router;
@@ -1199,9 +1209,11 @@ int route_unicast_packet(struct sk_buff *skb, struct batman_if *recv_if,

	/* get routing information */
	spin_lock_bh(&bat_priv->orig_hash_lock);
	rcu_read_lock();
	orig_node = ((struct orig_node *)
		     hash_find(bat_priv->orig_hash, compare_orig, choose_orig,
			       unicast_packet->dest));
	rcu_read_unlock();

	router = find_router(bat_priv, orig_node, recv_if);

@@ -1345,9 +1357,11 @@ int recv_bcast_packet(struct sk_buff *skb, struct batman_if *recv_if)
		return NET_RX_DROP;

	spin_lock_bh(&bat_priv->orig_hash_lock);
	rcu_read_lock();
	orig_node = ((struct orig_node *)
		     hash_find(bat_priv->orig_hash, compare_orig, choose_orig,
			       bcast_packet->orig));
	rcu_read_unlock();

	if (!orig_node) {
		spin_unlock_bh(&bat_priv->orig_hash_lock);
Loading