Loading net/wireguard/Makefile +2 −3 Original line number Diff line number Diff line Loading @@ -2,10 +2,9 @@ # # Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. ccflags-y := -O3 ccflags-$(CONFIG_WIREGUARD_DEBUG) += -DDEBUG -g ccflags-y += -D'pr_fmt(fmt)=KBUILD_MODNAME ": " fmt' ccflags-y := -D'pr_fmt(fmt)=KBUILD_MODNAME ": " fmt' ccflags-y += -Wframe-larger-than=2048 ccflags-$(CONFIG_WIREGUARD_DEBUG) += -DDEBUG -g ccflags-$(if $(WIREGUARD_VERSION),y,) += -D'WIREGUARD_VERSION="$(WIREGUARD_VERSION)"' wireguard-y := main.o noise.o device.o peer.o timers.o queueing.o send.o receive.o socket.o peerlookup.o allowedips.o ratelimiter.o cookie.o netlink.o Loading net/wireguard/allowedips.c +96 −92 Original line number Diff line number Diff line Loading @@ -6,6 +6,8 @@ #include "allowedips.h" #include "peer.h" static struct kmem_cache *node_cache; static void swap_endian(u8 *dst, const u8 *src, u8 bits) { if (bits == 32) { Loading @@ -28,12 +30,10 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, node->bitlen = bits; memcpy(node->bits, src, bits / 8U); } #define CHOOSE_NODE(parent, key) \ parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1] static void node_free_rcu(struct rcu_head *rcu) static inline u8 choose(struct allowedips_node *node, const u8 *key) { kfree(container_of(rcu, struct allowedips_node, rcu)); return (key[node->bit_at_a] >> node->bit_at_b) & 1; } static void push_rcu(struct allowedips_node **stack, Loading @@ -45,6 +45,11 @@ static void push_rcu(struct allowedips_node **stack, } } static void node_free_rcu(struct rcu_head *rcu) { kmem_cache_free(node_cache, container_of(rcu, struct allowedips_node, rcu)); } static void root_free_rcu(struct rcu_head *rcu) { struct allowedips_node *node, *stack[128] = { Loading @@ -54,7 +59,7 @@ static void root_free_rcu(struct rcu_head *rcu) while (len > 0 && (node = stack[--len])) { push_rcu(stack, node->bit[0], &len); push_rcu(stack, node->bit[1], &len); kfree(node); kmem_cache_free(node_cache, node); } } Loading @@ -71,60 +76,6 @@ static void root_remove_peer_lists(struct allowedips_node *root) } } static void walk_remove_by_peer(struct allowedips_node __rcu **top, struct wg_peer *peer, struct mutex *lock) { #define REF(p) rcu_access_pointer(p) #define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock)) #define PUSH(p) ({ \ WARN_ON(IS_ENABLED(DEBUG) && len >= 128); \ stack[len++] = p; \ }) struct allowedips_node __rcu **stack[128], **nptr; struct allowedips_node *node, *prev; unsigned int len; if (unlikely(!peer || !REF(*top))) return; for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) { nptr = stack[len - 1]; node = DEREF(nptr); if (!node) { --len; continue; } if (!prev || REF(prev->bit[0]) == node || REF(prev->bit[1]) == node) { if (REF(node->bit[0])) PUSH(&node->bit[0]); else if (REF(node->bit[1])) PUSH(&node->bit[1]); } else if (REF(node->bit[0]) == prev) { if (REF(node->bit[1])) PUSH(&node->bit[1]); } else { if (rcu_dereference_protected(node->peer, lockdep_is_held(lock)) == peer) { RCU_INIT_POINTER(node->peer, NULL); list_del_init(&node->peer_list); if (!node->bit[0] || !node->bit[1]) { rcu_assign_pointer(*nptr, DEREF( &node->bit[!REF(node->bit[0])])); call_rcu(&node->rcu, node_free_rcu); node = DEREF(nptr); } } --len; } } #undef REF #undef DEREF #undef PUSH } static unsigned int fls128(u64 a, u64 b) { return a ? fls64(a) + 64U : fls64(b); Loading Loading @@ -164,7 +115,7 @@ static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits, found = node; if (node->cidr == bits) break; node = rcu_dereference_bh(CHOOSE_NODE(node, key)); node = rcu_dereference_bh(node->bit[choose(node, key)]); } return found; } Loading Loading @@ -196,8 +147,7 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key, u8 cidr, u8 bits, struct allowedips_node **rnode, struct mutex *lock) { struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock)); struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock)); struct allowedips_node *parent = NULL; bool exact = false; Loading @@ -207,13 +157,24 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key, exact = true; break; } node = rcu_dereference_protected(CHOOSE_NODE(parent, key), lockdep_is_held(lock)); node = rcu_dereference_protected(parent->bit[choose(parent, key)], lockdep_is_held(lock)); } *rnode = parent; return exact; } static inline void connect_node(struct allowedips_node __rcu **parent, u8 bit, struct allowedips_node *node) { node->parent_bit_packed = (unsigned long)parent | bit; rcu_assign_pointer(*parent, node); } static inline void choose_and_connect_node(struct allowedips_node *parent, struct allowedips_node *node) { u8 bit = choose(parent, node->bits); connect_node(&parent->bit[bit], bit, node); } static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, u8 cidr, struct wg_peer *peer, struct mutex *lock) { Loading @@ -223,13 +184,13 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, return -EINVAL; if (!rcu_access_pointer(*trie)) { node = kzalloc(sizeof(*node), GFP_KERNEL); node = kmem_cache_zalloc(node_cache, GFP_KERNEL); if (unlikely(!node)) return -ENOMEM; RCU_INIT_POINTER(node->peer, peer); list_add_tail(&node->peer_list, &peer->allowedips_list); copy_and_assign_cidr(node, key, cidr, bits); rcu_assign_pointer(*trie, node); connect_node(trie, 2, node); return 0; } if (node_placement(*trie, key, cidr, bits, &node, lock)) { Loading @@ -238,7 +199,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, return 0; } newnode = kzalloc(sizeof(*newnode), GFP_KERNEL); newnode = kmem_cache_zalloc(node_cache, GFP_KERNEL); if (unlikely(!newnode)) return -ENOMEM; RCU_INIT_POINTER(newnode->peer, peer); Loading @@ -248,10 +209,10 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, if (!node) { down = rcu_dereference_protected(*trie, lockdep_is_held(lock)); } else { down = rcu_dereference_protected(CHOOSE_NODE(node, key), lockdep_is_held(lock)); const u8 bit = choose(node, key); down = rcu_dereference_protected(node->bit[bit], lockdep_is_held(lock)); if (!down) { rcu_assign_pointer(CHOOSE_NODE(node, key), newnode); connect_node(&node->bit[bit], bit, newnode); return 0; } } Loading @@ -259,30 +220,29 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, parent = node; if (newnode->cidr == cidr) { rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down); choose_and_connect_node(newnode, down); if (!parent) rcu_assign_pointer(*trie, newnode); connect_node(trie, 2, newnode); else rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits), newnode); } else { node = kzalloc(sizeof(*node), GFP_KERNEL); choose_and_connect_node(parent, newnode); return 0; } node = kmem_cache_zalloc(node_cache, GFP_KERNEL); if (unlikely(!node)) { list_del(&newnode->peer_list); kfree(newnode); kmem_cache_free(node_cache, newnode); return -ENOMEM; } INIT_LIST_HEAD(&node->peer_list); copy_and_assign_cidr(node, newnode->bits, cidr, bits); rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down); rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode); choose_and_connect_node(node, down); choose_and_connect_node(node, newnode); if (!parent) rcu_assign_pointer(*trie, node); connect_node(trie, 2, node); else rcu_assign_pointer(CHOOSE_NODE(parent, node->bits), node); } choose_and_connect_node(parent, node); return 0; } Loading Loading @@ -340,9 +300,41 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, void wg_allowedips_remove_by_peer(struct allowedips *table, struct wg_peer *peer, struct mutex *lock) { struct allowedips_node *node, *child, **parent_bit, *parent, *tmp; bool free_parent; if (list_empty(&peer->allowedips_list)) return; ++table->seq; walk_remove_by_peer(&table->root4, peer, lock); walk_remove_by_peer(&table->root6, peer, lock); list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) { list_del_init(&node->peer_list); RCU_INIT_POINTER(node->peer, NULL); if (node->bit[0] && node->bit[1]) continue; child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])], lockdep_is_held(lock)); if (child) child->parent_bit_packed = node->parent_bit_packed; parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL); *parent_bit = child; parent = (void *)parent_bit - offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]); free_parent = !rcu_access_pointer(node->bit[0]) && !rcu_access_pointer(node->bit[1]) && (node->parent_bit_packed & 3) <= 1 && !rcu_access_pointer(parent->peer); if (free_parent) child = rcu_dereference_protected( parent->bit[!(node->parent_bit_packed & 1)], lockdep_is_held(lock)); call_rcu(&node->rcu, node_free_rcu); if (!free_parent) continue; if (child) child->parent_bit_packed = parent->parent_bit_packed; *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child; call_rcu(&parent->rcu, node_free_rcu); } } int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr) Loading Loading @@ -379,4 +371,16 @@ struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table, return NULL; } int __init wg_allowedips_slab_init(void) { node_cache = KMEM_CACHE(allowedips_node, 0); return node_cache ? 0 : -ENOMEM; } void wg_allowedips_slab_uninit(void) { rcu_barrier(); kmem_cache_destroy(node_cache); } #include "selftest/allowedips.c" net/wireguard/allowedips.h +7 −7 Original line number Diff line number Diff line Loading @@ -15,14 +15,11 @@ struct wg_peer; struct allowedips_node { struct wg_peer __rcu *peer; struct allowedips_node __rcu *bit[2]; /* While it may seem scandalous that we waste space for v4, * we're alloc'ing to the nearest power of 2 anyway, so this * doesn't actually make a difference. */ u8 bits[16] __aligned(__alignof(u64)); u8 cidr, bit_at_a, bit_at_b, bitlen; u8 bits[16] __aligned(__alignof(u64)); /* Keep rarely used list at bottom to be beyond cache line. */ /* Keep rarely used members at bottom to be beyond cache line. */ unsigned long parent_bit_packed; union { struct list_head peer_list; struct rcu_head rcu; Loading @@ -33,7 +30,7 @@ struct allowedips { struct allowedips_node __rcu *root4; struct allowedips_node __rcu *root6; u64 seq; }; } __aligned(4); /* We pack the lower 2 bits of &root, but m68k only gives 16-bit alignment. */ void wg_allowedips_init(struct allowedips *table); void wg_allowedips_free(struct allowedips *table, struct mutex *mutex); Loading @@ -56,4 +53,7 @@ struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table, bool wg_allowedips_selftest(void); #endif int wg_allowedips_slab_init(void); void wg_allowedips_slab_uninit(void); #endif /* _WG_ALLOWEDIPS_H */ net/wireguard/main.c +16 −1 Original line number Diff line number Diff line Loading @@ -26,13 +26,22 @@ static int __init mod_init(void) (ret = curve25519_mod_init())) return ret; ret = wg_allowedips_slab_init(); if (ret < 0) goto err_allowedips; #ifdef DEBUG ret = -ENOTRECOVERABLE; if (!wg_allowedips_selftest() || !wg_packet_counter_selftest() || !wg_ratelimiter_selftest()) return -ENOTRECOVERABLE; goto err_peer; #endif wg_noise_init(); ret = wg_peer_init(); if (ret < 0) goto err_peer; ret = wg_device_init(); if (ret < 0) goto err_device; Loading @@ -49,6 +58,10 @@ static int __init mod_init(void) err_netlink: wg_device_uninit(); err_device: wg_peer_uninit(); err_peer: wg_allowedips_slab_uninit(); err_allowedips: return ret; } Loading @@ -56,6 +69,8 @@ static void __exit mod_exit(void) { wg_genetlink_uninit(); wg_device_uninit(); wg_peer_uninit(); wg_allowedips_slab_uninit(); } module_init(mod_init); Loading net/wireguard/peer.c +20 −7 Original line number Diff line number Diff line Loading @@ -15,6 +15,7 @@ #include <linux/rcupdate.h> #include <linux/list.h> static struct kmem_cache *peer_cache; static atomic64_t peer_counter = ATOMIC64_INIT(0); struct wg_peer *wg_peer_create(struct wg_device *wg, Loading @@ -29,10 +30,10 @@ struct wg_peer *wg_peer_create(struct wg_device *wg, if (wg->num_peers >= MAX_PEERS_PER_DEVICE) return ERR_PTR(ret); peer = kzalloc(sizeof(*peer), GFP_KERNEL); peer = kmem_cache_zalloc(peer_cache, GFP_KERNEL); if (unlikely(!peer)) return ERR_PTR(ret); if (dst_cache_init(&peer->endpoint_cache, GFP_KERNEL)) if (unlikely(dst_cache_init(&peer->endpoint_cache, GFP_KERNEL))) goto err; peer->device = wg; Loading Loading @@ -64,7 +65,7 @@ struct wg_peer *wg_peer_create(struct wg_device *wg, return peer; err: kfree(peer); kmem_cache_free(peer_cache, peer); return ERR_PTR(ret); } Loading @@ -88,7 +89,7 @@ static void peer_make_dead(struct wg_peer *peer) /* Mark as dead, so that we don't allow jumping contexts after. */ WRITE_ONCE(peer->is_dead, true); /* The caller must now synchronize_rcu() for this to take effect. */ /* The caller must now synchronize_net() for this to take effect. */ } static void peer_remove_after_dead(struct wg_peer *peer) Loading Loading @@ -160,7 +161,7 @@ void wg_peer_remove(struct wg_peer *peer) lockdep_assert_held(&peer->device->device_update_lock); peer_make_dead(peer); synchronize_rcu(); synchronize_net(); peer_remove_after_dead(peer); } Loading @@ -178,7 +179,7 @@ void wg_peer_remove_all(struct wg_device *wg) peer_make_dead(peer); list_add_tail(&peer->peer_list, &dead_peers); } synchronize_rcu(); synchronize_net(); list_for_each_entry_safe(peer, temp, &dead_peers, peer_list) peer_remove_after_dead(peer); } Loading @@ -193,7 +194,8 @@ static void rcu_release(struct rcu_head *rcu) /* The final zeroing takes care of clearing any remaining handshake key * material and other potentially sensitive information. */ kfree_sensitive(peer); memzero_explicit(peer, sizeof(*peer)); kmem_cache_free(peer_cache, peer); } static void kref_release(struct kref *refcount) Loading Loading @@ -225,3 +227,14 @@ void wg_peer_put(struct wg_peer *peer) return; kref_put(&peer->refcount, kref_release); } int __init wg_peer_init(void) { peer_cache = KMEM_CACHE(wg_peer, 0); return peer_cache ? 0 : -ENOMEM; } void wg_peer_uninit(void) { kmem_cache_destroy(peer_cache); } Loading
net/wireguard/Makefile +2 −3 Original line number Diff line number Diff line Loading @@ -2,10 +2,9 @@ # # Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. ccflags-y := -O3 ccflags-$(CONFIG_WIREGUARD_DEBUG) += -DDEBUG -g ccflags-y += -D'pr_fmt(fmt)=KBUILD_MODNAME ": " fmt' ccflags-y := -D'pr_fmt(fmt)=KBUILD_MODNAME ": " fmt' ccflags-y += -Wframe-larger-than=2048 ccflags-$(CONFIG_WIREGUARD_DEBUG) += -DDEBUG -g ccflags-$(if $(WIREGUARD_VERSION),y,) += -D'WIREGUARD_VERSION="$(WIREGUARD_VERSION)"' wireguard-y := main.o noise.o device.o peer.o timers.o queueing.o send.o receive.o socket.o peerlookup.o allowedips.o ratelimiter.o cookie.o netlink.o Loading
net/wireguard/allowedips.c +96 −92 Original line number Diff line number Diff line Loading @@ -6,6 +6,8 @@ #include "allowedips.h" #include "peer.h" static struct kmem_cache *node_cache; static void swap_endian(u8 *dst, const u8 *src, u8 bits) { if (bits == 32) { Loading @@ -28,12 +30,10 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, node->bitlen = bits; memcpy(node->bits, src, bits / 8U); } #define CHOOSE_NODE(parent, key) \ parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1] static void node_free_rcu(struct rcu_head *rcu) static inline u8 choose(struct allowedips_node *node, const u8 *key) { kfree(container_of(rcu, struct allowedips_node, rcu)); return (key[node->bit_at_a] >> node->bit_at_b) & 1; } static void push_rcu(struct allowedips_node **stack, Loading @@ -45,6 +45,11 @@ static void push_rcu(struct allowedips_node **stack, } } static void node_free_rcu(struct rcu_head *rcu) { kmem_cache_free(node_cache, container_of(rcu, struct allowedips_node, rcu)); } static void root_free_rcu(struct rcu_head *rcu) { struct allowedips_node *node, *stack[128] = { Loading @@ -54,7 +59,7 @@ static void root_free_rcu(struct rcu_head *rcu) while (len > 0 && (node = stack[--len])) { push_rcu(stack, node->bit[0], &len); push_rcu(stack, node->bit[1], &len); kfree(node); kmem_cache_free(node_cache, node); } } Loading @@ -71,60 +76,6 @@ static void root_remove_peer_lists(struct allowedips_node *root) } } static void walk_remove_by_peer(struct allowedips_node __rcu **top, struct wg_peer *peer, struct mutex *lock) { #define REF(p) rcu_access_pointer(p) #define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock)) #define PUSH(p) ({ \ WARN_ON(IS_ENABLED(DEBUG) && len >= 128); \ stack[len++] = p; \ }) struct allowedips_node __rcu **stack[128], **nptr; struct allowedips_node *node, *prev; unsigned int len; if (unlikely(!peer || !REF(*top))) return; for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) { nptr = stack[len - 1]; node = DEREF(nptr); if (!node) { --len; continue; } if (!prev || REF(prev->bit[0]) == node || REF(prev->bit[1]) == node) { if (REF(node->bit[0])) PUSH(&node->bit[0]); else if (REF(node->bit[1])) PUSH(&node->bit[1]); } else if (REF(node->bit[0]) == prev) { if (REF(node->bit[1])) PUSH(&node->bit[1]); } else { if (rcu_dereference_protected(node->peer, lockdep_is_held(lock)) == peer) { RCU_INIT_POINTER(node->peer, NULL); list_del_init(&node->peer_list); if (!node->bit[0] || !node->bit[1]) { rcu_assign_pointer(*nptr, DEREF( &node->bit[!REF(node->bit[0])])); call_rcu(&node->rcu, node_free_rcu); node = DEREF(nptr); } } --len; } } #undef REF #undef DEREF #undef PUSH } static unsigned int fls128(u64 a, u64 b) { return a ? fls64(a) + 64U : fls64(b); Loading Loading @@ -164,7 +115,7 @@ static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits, found = node; if (node->cidr == bits) break; node = rcu_dereference_bh(CHOOSE_NODE(node, key)); node = rcu_dereference_bh(node->bit[choose(node, key)]); } return found; } Loading Loading @@ -196,8 +147,7 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key, u8 cidr, u8 bits, struct allowedips_node **rnode, struct mutex *lock) { struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock)); struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock)); struct allowedips_node *parent = NULL; bool exact = false; Loading @@ -207,13 +157,24 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key, exact = true; break; } node = rcu_dereference_protected(CHOOSE_NODE(parent, key), lockdep_is_held(lock)); node = rcu_dereference_protected(parent->bit[choose(parent, key)], lockdep_is_held(lock)); } *rnode = parent; return exact; } static inline void connect_node(struct allowedips_node __rcu **parent, u8 bit, struct allowedips_node *node) { node->parent_bit_packed = (unsigned long)parent | bit; rcu_assign_pointer(*parent, node); } static inline void choose_and_connect_node(struct allowedips_node *parent, struct allowedips_node *node) { u8 bit = choose(parent, node->bits); connect_node(&parent->bit[bit], bit, node); } static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, u8 cidr, struct wg_peer *peer, struct mutex *lock) { Loading @@ -223,13 +184,13 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, return -EINVAL; if (!rcu_access_pointer(*trie)) { node = kzalloc(sizeof(*node), GFP_KERNEL); node = kmem_cache_zalloc(node_cache, GFP_KERNEL); if (unlikely(!node)) return -ENOMEM; RCU_INIT_POINTER(node->peer, peer); list_add_tail(&node->peer_list, &peer->allowedips_list); copy_and_assign_cidr(node, key, cidr, bits); rcu_assign_pointer(*trie, node); connect_node(trie, 2, node); return 0; } if (node_placement(*trie, key, cidr, bits, &node, lock)) { Loading @@ -238,7 +199,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, return 0; } newnode = kzalloc(sizeof(*newnode), GFP_KERNEL); newnode = kmem_cache_zalloc(node_cache, GFP_KERNEL); if (unlikely(!newnode)) return -ENOMEM; RCU_INIT_POINTER(newnode->peer, peer); Loading @@ -248,10 +209,10 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, if (!node) { down = rcu_dereference_protected(*trie, lockdep_is_held(lock)); } else { down = rcu_dereference_protected(CHOOSE_NODE(node, key), lockdep_is_held(lock)); const u8 bit = choose(node, key); down = rcu_dereference_protected(node->bit[bit], lockdep_is_held(lock)); if (!down) { rcu_assign_pointer(CHOOSE_NODE(node, key), newnode); connect_node(&node->bit[bit], bit, newnode); return 0; } } Loading @@ -259,30 +220,29 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, parent = node; if (newnode->cidr == cidr) { rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down); choose_and_connect_node(newnode, down); if (!parent) rcu_assign_pointer(*trie, newnode); connect_node(trie, 2, newnode); else rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits), newnode); } else { node = kzalloc(sizeof(*node), GFP_KERNEL); choose_and_connect_node(parent, newnode); return 0; } node = kmem_cache_zalloc(node_cache, GFP_KERNEL); if (unlikely(!node)) { list_del(&newnode->peer_list); kfree(newnode); kmem_cache_free(node_cache, newnode); return -ENOMEM; } INIT_LIST_HEAD(&node->peer_list); copy_and_assign_cidr(node, newnode->bits, cidr, bits); rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down); rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode); choose_and_connect_node(node, down); choose_and_connect_node(node, newnode); if (!parent) rcu_assign_pointer(*trie, node); connect_node(trie, 2, node); else rcu_assign_pointer(CHOOSE_NODE(parent, node->bits), node); } choose_and_connect_node(parent, node); return 0; } Loading Loading @@ -340,9 +300,41 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, void wg_allowedips_remove_by_peer(struct allowedips *table, struct wg_peer *peer, struct mutex *lock) { struct allowedips_node *node, *child, **parent_bit, *parent, *tmp; bool free_parent; if (list_empty(&peer->allowedips_list)) return; ++table->seq; walk_remove_by_peer(&table->root4, peer, lock); walk_remove_by_peer(&table->root6, peer, lock); list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) { list_del_init(&node->peer_list); RCU_INIT_POINTER(node->peer, NULL); if (node->bit[0] && node->bit[1]) continue; child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])], lockdep_is_held(lock)); if (child) child->parent_bit_packed = node->parent_bit_packed; parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL); *parent_bit = child; parent = (void *)parent_bit - offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]); free_parent = !rcu_access_pointer(node->bit[0]) && !rcu_access_pointer(node->bit[1]) && (node->parent_bit_packed & 3) <= 1 && !rcu_access_pointer(parent->peer); if (free_parent) child = rcu_dereference_protected( parent->bit[!(node->parent_bit_packed & 1)], lockdep_is_held(lock)); call_rcu(&node->rcu, node_free_rcu); if (!free_parent) continue; if (child) child->parent_bit_packed = parent->parent_bit_packed; *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child; call_rcu(&parent->rcu, node_free_rcu); } } int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr) Loading Loading @@ -379,4 +371,16 @@ struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table, return NULL; } int __init wg_allowedips_slab_init(void) { node_cache = KMEM_CACHE(allowedips_node, 0); return node_cache ? 0 : -ENOMEM; } void wg_allowedips_slab_uninit(void) { rcu_barrier(); kmem_cache_destroy(node_cache); } #include "selftest/allowedips.c"
net/wireguard/allowedips.h +7 −7 Original line number Diff line number Diff line Loading @@ -15,14 +15,11 @@ struct wg_peer; struct allowedips_node { struct wg_peer __rcu *peer; struct allowedips_node __rcu *bit[2]; /* While it may seem scandalous that we waste space for v4, * we're alloc'ing to the nearest power of 2 anyway, so this * doesn't actually make a difference. */ u8 bits[16] __aligned(__alignof(u64)); u8 cidr, bit_at_a, bit_at_b, bitlen; u8 bits[16] __aligned(__alignof(u64)); /* Keep rarely used list at bottom to be beyond cache line. */ /* Keep rarely used members at bottom to be beyond cache line. */ unsigned long parent_bit_packed; union { struct list_head peer_list; struct rcu_head rcu; Loading @@ -33,7 +30,7 @@ struct allowedips { struct allowedips_node __rcu *root4; struct allowedips_node __rcu *root6; u64 seq; }; } __aligned(4); /* We pack the lower 2 bits of &root, but m68k only gives 16-bit alignment. */ void wg_allowedips_init(struct allowedips *table); void wg_allowedips_free(struct allowedips *table, struct mutex *mutex); Loading @@ -56,4 +53,7 @@ struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table, bool wg_allowedips_selftest(void); #endif int wg_allowedips_slab_init(void); void wg_allowedips_slab_uninit(void); #endif /* _WG_ALLOWEDIPS_H */
net/wireguard/main.c +16 −1 Original line number Diff line number Diff line Loading @@ -26,13 +26,22 @@ static int __init mod_init(void) (ret = curve25519_mod_init())) return ret; ret = wg_allowedips_slab_init(); if (ret < 0) goto err_allowedips; #ifdef DEBUG ret = -ENOTRECOVERABLE; if (!wg_allowedips_selftest() || !wg_packet_counter_selftest() || !wg_ratelimiter_selftest()) return -ENOTRECOVERABLE; goto err_peer; #endif wg_noise_init(); ret = wg_peer_init(); if (ret < 0) goto err_peer; ret = wg_device_init(); if (ret < 0) goto err_device; Loading @@ -49,6 +58,10 @@ static int __init mod_init(void) err_netlink: wg_device_uninit(); err_device: wg_peer_uninit(); err_peer: wg_allowedips_slab_uninit(); err_allowedips: return ret; } Loading @@ -56,6 +69,8 @@ static void __exit mod_exit(void) { wg_genetlink_uninit(); wg_device_uninit(); wg_peer_uninit(); wg_allowedips_slab_uninit(); } module_init(mod_init); Loading
net/wireguard/peer.c +20 −7 Original line number Diff line number Diff line Loading @@ -15,6 +15,7 @@ #include <linux/rcupdate.h> #include <linux/list.h> static struct kmem_cache *peer_cache; static atomic64_t peer_counter = ATOMIC64_INIT(0); struct wg_peer *wg_peer_create(struct wg_device *wg, Loading @@ -29,10 +30,10 @@ struct wg_peer *wg_peer_create(struct wg_device *wg, if (wg->num_peers >= MAX_PEERS_PER_DEVICE) return ERR_PTR(ret); peer = kzalloc(sizeof(*peer), GFP_KERNEL); peer = kmem_cache_zalloc(peer_cache, GFP_KERNEL); if (unlikely(!peer)) return ERR_PTR(ret); if (dst_cache_init(&peer->endpoint_cache, GFP_KERNEL)) if (unlikely(dst_cache_init(&peer->endpoint_cache, GFP_KERNEL))) goto err; peer->device = wg; Loading Loading @@ -64,7 +65,7 @@ struct wg_peer *wg_peer_create(struct wg_device *wg, return peer; err: kfree(peer); kmem_cache_free(peer_cache, peer); return ERR_PTR(ret); } Loading @@ -88,7 +89,7 @@ static void peer_make_dead(struct wg_peer *peer) /* Mark as dead, so that we don't allow jumping contexts after. */ WRITE_ONCE(peer->is_dead, true); /* The caller must now synchronize_rcu() for this to take effect. */ /* The caller must now synchronize_net() for this to take effect. */ } static void peer_remove_after_dead(struct wg_peer *peer) Loading Loading @@ -160,7 +161,7 @@ void wg_peer_remove(struct wg_peer *peer) lockdep_assert_held(&peer->device->device_update_lock); peer_make_dead(peer); synchronize_rcu(); synchronize_net(); peer_remove_after_dead(peer); } Loading @@ -178,7 +179,7 @@ void wg_peer_remove_all(struct wg_device *wg) peer_make_dead(peer); list_add_tail(&peer->peer_list, &dead_peers); } synchronize_rcu(); synchronize_net(); list_for_each_entry_safe(peer, temp, &dead_peers, peer_list) peer_remove_after_dead(peer); } Loading @@ -193,7 +194,8 @@ static void rcu_release(struct rcu_head *rcu) /* The final zeroing takes care of clearing any remaining handshake key * material and other potentially sensitive information. */ kfree_sensitive(peer); memzero_explicit(peer, sizeof(*peer)); kmem_cache_free(peer_cache, peer); } static void kref_release(struct kref *refcount) Loading Loading @@ -225,3 +227,14 @@ void wg_peer_put(struct wg_peer *peer) return; kref_put(&peer->refcount, kref_release); } int __init wg_peer_init(void) { peer_cache = KMEM_CACHE(wg_peer, 0); return peer_cache ? 0 : -ENOMEM; } void wg_peer_uninit(void) { kmem_cache_destroy(peer_cache); }