ping

On Fri, Aug 6 2021 at 12:43:14 AM -0400, Hamza Mahfooz <[email protected]> wrote:
It is mentioned in commit e7096c131e516 ("net: WireGuard secure
network tunnel") that it is desirable to move away from the statically
sized hash-table implementation.

Signed-off-by: Hamza Mahfooz <[email protected]>
---
 drivers/net/wireguard/device.c     |   4 +
 drivers/net/wireguard/device.h     |   2 +-
 drivers/net/wireguard/noise.c      |   1 +
 drivers/net/wireguard/noise.h      |   1 +
 drivers/net/wireguard/peer.h       |   2 +-
 drivers/net/wireguard/peerlookup.c | 190 ++++++++++++++---------------
 drivers/net/wireguard/peerlookup.h |  27 ++--
 7 files changed, 112 insertions(+), 115 deletions(-)

diff --git a/drivers/net/wireguard/device.c b/drivers/net/wireguard/device.c
index 551ddaaaf540..3bd43c9481ef 100644
--- a/drivers/net/wireguard/device.c
+++ b/drivers/net/wireguard/device.c
@@ -243,7 +243,9 @@ static void wg_destruct(struct net_device *dev)
        skb_queue_purge(&wg->incoming_handshakes);
        free_percpu(dev->tstats);
        free_percpu(wg->incoming_handshakes_worker);
+       wg_index_hashtable_destroy(wg->index_hashtable);
        kvfree(wg->index_hashtable);
+       wg_pubkey_hashtable_destroy(wg->peer_hashtable);
        kvfree(wg->peer_hashtable);
        mutex_unlock(&wg->device_update_lock);

@@ -382,8 +384,10 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,
 err_free_tstats:
        free_percpu(dev->tstats);
 err_free_index_hashtable:
+       wg_index_hashtable_destroy(wg->index_hashtable);
        kvfree(wg->index_hashtable);
 err_free_peer_hashtable:
+       wg_pubkey_hashtable_destroy(wg->peer_hashtable);
        kvfree(wg->peer_hashtable);
        return ret;
 }
diff --git a/drivers/net/wireguard/device.h b/drivers/net/wireguard/device.h
index 854bc3d97150..24980eb766af 100644
--- a/drivers/net/wireguard/device.h
+++ b/drivers/net/wireguard/device.h
@@ -50,7 +50,7 @@ struct wg_device {
        struct multicore_worker __percpu *incoming_handshakes_worker;
        struct cookie_checker cookie_checker;
        struct pubkey_hashtable *peer_hashtable;
-       struct index_hashtable *index_hashtable;
+       struct rhashtable *index_hashtable;
        struct allowedips peer_allowedips;
        struct mutex device_update_lock, socket_update_lock;
        struct list_head device_list, peer_list;
diff --git a/drivers/net/wireguard/noise.c b/drivers/net/wireguard/noise.c
index c0cfd9b36c0b..d42a0ff2be5d 100644
--- a/drivers/net/wireguard/noise.c
+++ b/drivers/net/wireguard/noise.c
@@ -797,6 +797,7 @@ bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
        new_keypair->i_am_the_initiator = handshake->state ==
                                          HANDSHAKE_CONSUMED_RESPONSE;
        new_keypair->remote_index = handshake->remote_index;
+       new_keypair->entry.index = handshake->entry.index;

        if (new_keypair->i_am_the_initiator)
                derive_keys(&new_keypair->sending, &new_keypair->receiving,
diff --git a/drivers/net/wireguard/noise.h b/drivers/net/wireguard/noise.h
index c527253dba80..ea705747e4e4 100644
--- a/drivers/net/wireguard/noise.h
+++ b/drivers/net/wireguard/noise.h
@@ -72,6 +72,7 @@ struct noise_handshake {

        u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
        u8 remote_static[NOISE_PUBLIC_KEY_LEN];
+       siphash_key_t skey;
        u8 remote_ephemeral[NOISE_PUBLIC_KEY_LEN];
        u8 precomputed_static_static[NOISE_PUBLIC_KEY_LEN];

diff --git a/drivers/net/wireguard/peer.h b/drivers/net/wireguard/peer.h
index 76e4d3128ad4..d5403fb7a6a0 100644
--- a/drivers/net/wireguard/peer.h
+++ b/drivers/net/wireguard/peer.h
@@ -48,7 +48,7 @@ struct wg_peer {
        atomic64_t last_sent_handshake;
        struct work_struct transmit_handshake_work, clear_peer_work, transmit_packet_work;
        struct cookie latest_cookie;
-       struct hlist_node pubkey_hash;
+       struct rhash_head pubkey_hash;
        u64 rx_bytes, tx_bytes;
        struct timer_list timer_retransmit_handshake, timer_send_keepalive;
        struct timer_list timer_new_handshake, timer_zero_key_material;
diff --git a/drivers/net/wireguard/peerlookup.c b/drivers/net/wireguard/peerlookup.c
index f2783aa7a88f..2ea2ba85a33d 100644
--- a/drivers/net/wireguard/peerlookup.c
+++ b/drivers/net/wireguard/peerlookup.c
@@ -7,18 +7,29 @@
 #include "peer.h"
 #include "noise.h"

-static struct hlist_head *pubkey_bucket(struct pubkey_hashtable *table,
-                                       const u8 pubkey[NOISE_PUBLIC_KEY_LEN])
+struct pubkey_pair {
+       u8 key[NOISE_PUBLIC_KEY_LEN];
+       siphash_key_t skey;
+};
+
+static u32 pubkey_hash(const void *data, u32 len, u32 seed)
 {
+       const struct pubkey_pair *pair = data;
+
        /* siphash gives us a secure 64bit number based on a random key. Since
-        * the bits are uniformly distributed, we can then mask off to get the
-        * bits we need.
+        * the bits are uniformly distributed.
         */
-       const u64 hash = siphash(pubkey, NOISE_PUBLIC_KEY_LEN, &table->key);

-       return &table->hashtable[hash & (HASH_SIZE(table->hashtable) - 1)];
+       return (u32)siphash(pair->key, len, &pair->skey);
 }

+static const struct rhashtable_params wg_peer_params = {
+       .key_len = NOISE_PUBLIC_KEY_LEN,
+       .key_offset = offsetof(struct wg_peer, handshake.remote_static),
+       .head_offset = offsetof(struct wg_peer, pubkey_hash),
+       .hashfn = pubkey_hash
+};
+
 struct pubkey_hashtable *wg_pubkey_hashtable_alloc(void)
 {
        struct pubkey_hashtable *table = kvmalloc(sizeof(*table), GFP_KERNEL);
@@ -27,26 +38,25 @@ struct pubkey_hashtable *wg_pubkey_hashtable_alloc(void)
                return NULL;

        get_random_bytes(&table->key, sizeof(table->key));
-       hash_init(table->hashtable);
-       mutex_init(&table->lock);
+       rhashtable_init(&table->hashtable, &wg_peer_params);
+
        return table;
 }

 void wg_pubkey_hashtable_add(struct pubkey_hashtable *table,
                             struct wg_peer *peer)
 {
-       mutex_lock(&table->lock);
-       hlist_add_head_rcu(&peer->pubkey_hash,
-                          pubkey_bucket(table, peer->handshake.remote_static));
-       mutex_unlock(&table->lock);
+       memcpy(&peer->handshake.skey, &table->key, sizeof(table->key));
+       WARN_ON(rhashtable_insert_fast(&table->hashtable, &peer->pubkey_hash,
+                                      wg_peer_params));
 }

 void wg_pubkey_hashtable_remove(struct pubkey_hashtable *table,
                                struct wg_peer *peer)
 {
-       mutex_lock(&table->lock);
-       hlist_del_init_rcu(&peer->pubkey_hash);
-       mutex_unlock(&table->lock);
+       memcpy(&peer->handshake.skey, &table->key, sizeof(table->key));
+       rhashtable_remove_fast(&table->hashtable, &peer->pubkey_hash,
+                              wg_peer_params);
 }

 /* Returns a strong reference to a peer */
@@ -54,41 +64,54 @@ struct wg_peer *
 wg_pubkey_hashtable_lookup(struct pubkey_hashtable *table,
                           const u8 pubkey[NOISE_PUBLIC_KEY_LEN])
 {
-       struct wg_peer *iter_peer, *peer = NULL;
+       struct wg_peer *peer = NULL;
+       struct pubkey_pair pair;
+
+       memcpy(pair.key, pubkey, NOISE_PUBLIC_KEY_LEN);
+       memcpy(&pair.skey, &table->key, sizeof(pair.skey));

        rcu_read_lock_bh();
-       hlist_for_each_entry_rcu_bh(iter_peer, pubkey_bucket(table, pubkey),
-                                   pubkey_hash) {
-               if (!memcmp(pubkey, iter_peer->handshake.remote_static,
-                           NOISE_PUBLIC_KEY_LEN)) {
-                       peer = iter_peer;
-                       break;
-               }
-       }
-       peer = wg_peer_get_maybe_zero(peer);
+       peer = wg_peer_get_maybe_zero(rhashtable_lookup_fast(&table->hashtable,
+                                                            &pair,
+                                                            wg_peer_params));
        rcu_read_unlock_bh();
+
        return peer;
 }

-static struct hlist_head *index_bucket(struct index_hashtable *table,
-                                      const __le32 index)
+void wg_pubkey_hashtable_destroy(struct pubkey_hashtable *table)
+{
+       WARN_ON(atomic_read(&table->hashtable.nelems));
+       rhashtable_destroy(&table->hashtable);
+}
+
+static u32 index_hash(const void *data, u32 len, u32 seed)
 {
+       const __le32 *index = data;
+
        /* Since the indices are random and thus all bits are uniformly
-        * distributed, we can find its bucket simply by masking.
+        * distributed, we can use them as the hash value.
         */
-       return &table->hashtable[(__force u32)index &
-                                (HASH_SIZE(table->hashtable) - 1)];
+
+       return (__force u32)*index;
 }

-struct index_hashtable *wg_index_hashtable_alloc(void)
+static const struct rhashtable_params index_entry_params = {
+       .key_len = sizeof(__le32),
+       .key_offset = offsetof(struct index_hashtable_entry, index),
+       .head_offset = offsetof(struct index_hashtable_entry, index_hash),
+       .hashfn = index_hash
+};
+
+struct rhashtable *wg_index_hashtable_alloc(void)
 {
-       struct index_hashtable *table = kvmalloc(sizeof(*table), GFP_KERNEL);
+       struct rhashtable *table = kvmalloc(sizeof(*table), GFP_KERNEL);

        if (!table)
                return NULL;

-       hash_init(table->hashtable);
-       spin_lock_init(&table->lock);
+       rhashtable_init(table, &index_entry_params);
+
        return table;
 }

@@ -116,111 +139,86 @@ struct index_hashtable *wg_index_hashtable_alloc(void)
  * is another thing to consider moving forward.
  */

-__le32 wg_index_hashtable_insert(struct index_hashtable *table,
+__le32 wg_index_hashtable_insert(struct rhashtable *table,
                                 struct index_hashtable_entry *entry)
 {
        struct index_hashtable_entry *existing_entry;

-       spin_lock_bh(&table->lock);
-       hlist_del_init_rcu(&entry->index_hash);
-       spin_unlock_bh(&table->lock);
+       wg_index_hashtable_remove(table, entry);

        rcu_read_lock_bh();

 search_unused_slot:
        /* First we try to find an unused slot, randomly, while unlocked. */
        entry->index = (__force __le32)get_random_u32();
-       hlist_for_each_entry_rcu_bh(existing_entry,
-                                   index_bucket(table, entry->index),
-                                   index_hash) {
-               if (existing_entry->index == entry->index)
-                       /* If it's already in use, we continue searching. */
-                       goto search_unused_slot;
-       }

-       /* Once we've found an unused slot, we lock it, and then double-check
-        * that nobody else stole it from us.
-        */
-       spin_lock_bh(&table->lock);
-       hlist_for_each_entry_rcu_bh(existing_entry,
-                                   index_bucket(table, entry->index),
-                                   index_hash) {
-               if (existing_entry->index == entry->index) {
-                       spin_unlock_bh(&table->lock);
-                       /* If it was stolen, we start over. */
-                       goto search_unused_slot;
-               }
+       existing_entry = rhashtable_lookup_get_insert_fast(table,
+                                                          &entry->index_hash,
+                                                          index_entry_params);
+
+       if (existing_entry) {
+               WARN_ON(IS_ERR(existing_entry));
+
+               /* If it's already in use, we continue searching. */
+               goto search_unused_slot;
        }
-       /* Otherwise, we know we have it exclusively (since we're locked),
-        * so we insert.
-        */
-       hlist_add_head_rcu(&entry->index_hash,
-                          index_bucket(table, entry->index));
-       spin_unlock_bh(&table->lock);

        rcu_read_unlock_bh();

        return entry->index;
 }

-bool wg_index_hashtable_replace(struct index_hashtable *table,
+bool wg_index_hashtable_replace(struct rhashtable *table,
                                struct index_hashtable_entry *old,
                                struct index_hashtable_entry *new)
 {
-       bool ret;
+       int ret = rhashtable_replace_fast(table, &old->index_hash,
+                                         &new->index_hash,
+                                         index_entry_params);

-       spin_lock_bh(&table->lock);
-       ret = !hlist_unhashed(&old->index_hash);
-       if (unlikely(!ret))
-               goto out;
+       WARN_ON(ret == -EINVAL);

-       new->index = old->index;
-       hlist_replace_rcu(&old->index_hash, &new->index_hash);
-
-       /* Calling init here NULLs out index_hash, and in fact after this
-        * function returns, it's theoretically possible for this to get
-        * reinserted elsewhere. That means the RCU lookup below might either
-        * terminate early or jump between buckets, in which case the packet
-        * simply gets dropped, which isn't terrible.
-        */
-       INIT_HLIST_NODE(&old->index_hash);
-out:
-       spin_unlock_bh(&table->lock);
-       return ret;
+       return ret != -ENOENT;
 }

-void wg_index_hashtable_remove(struct index_hashtable *table,
+void wg_index_hashtable_remove(struct rhashtable *table,
                               struct index_hashtable_entry *entry)
 {
-       spin_lock_bh(&table->lock);
-       hlist_del_init_rcu(&entry->index_hash);
-       spin_unlock_bh(&table->lock);
+       rhashtable_remove_fast(table, &entry->index_hash, index_entry_params);
 }

 /* Returns a strong reference to a entry->peer */
 struct index_hashtable_entry *
-wg_index_hashtable_lookup(struct index_hashtable *table,
+wg_index_hashtable_lookup(struct rhashtable *table,
                          const enum index_hashtable_type type_mask,
                          const __le32 index, struct wg_peer **peer)
 {
-       struct index_hashtable_entry *iter_entry, *entry = NULL;
+       struct index_hashtable_entry *entry = NULL;

        rcu_read_lock_bh();
-       hlist_for_each_entry_rcu_bh(iter_entry, index_bucket(table, index),
-                                   index_hash) {
-               if (iter_entry->index == index) {
-                       if (likely(iter_entry->type & type_mask))
-                               entry = iter_entry;
-                       break;
-               }
-       }
+       entry = rhashtable_lookup_fast(table, &index, index_entry_params);
+
        if (likely(entry)) {
+               if (unlikely(!(entry->type & type_mask))) {
+                       entry = NULL;
+                       goto out;
+               }
+
                entry->peer = wg_peer_get_maybe_zero(entry->peer);
                if (likely(entry->peer))
                        *peer = entry->peer;
                else
                        entry = NULL;
        }
+
+out:
        rcu_read_unlock_bh();
+
        return entry;
 }
+
+void wg_index_hashtable_destroy(struct rhashtable *table)
+{
+       WARN_ON(atomic_read(&table->nelems));
+       rhashtable_destroy(table);
+}
diff --git a/drivers/net/wireguard/peerlookup.h b/drivers/net/wireguard/peerlookup.h
index ced811797680..a3cef26cb733 100644
--- a/drivers/net/wireguard/peerlookup.h
+++ b/drivers/net/wireguard/peerlookup.h
@@ -8,17 +8,14 @@

 #include "messages.h"

-#include <linux/hashtable.h>
-#include <linux/mutex.h>
+#include <linux/rhashtable.h>
 #include <linux/siphash.h>

 struct wg_peer;

 struct pubkey_hashtable {
-       /* TODO: move to rhashtable */
-       DECLARE_HASHTABLE(hashtable, 11);
+       struct rhashtable hashtable;
        siphash_key_t key;
-       struct mutex lock;
 };

 struct pubkey_hashtable *wg_pubkey_hashtable_alloc(void);
@@ -29,12 +26,7 @@ void wg_pubkey_hashtable_remove(struct pubkey_hashtable *table,
 struct wg_peer *
 wg_pubkey_hashtable_lookup(struct pubkey_hashtable *table,
                           const u8 pubkey[NOISE_PUBLIC_KEY_LEN]);
-
-struct index_hashtable {
-       /* TODO: move to rhashtable */
-       DECLARE_HASHTABLE(hashtable, 13);
-       spinlock_t lock;
-};
+void wg_pubkey_hashtable_destroy(struct pubkey_hashtable *table);

 enum index_hashtable_type {
        INDEX_HASHTABLE_HANDSHAKE = 1U << 0,
@@ -43,22 +35,23 @@ enum index_hashtable_type {

 struct index_hashtable_entry {
        struct wg_peer *peer;
-       struct hlist_node index_hash;
+       struct rhash_head index_hash;
        enum index_hashtable_type type;
        __le32 index;
 };

-struct index_hashtable *wg_index_hashtable_alloc(void);
-__le32 wg_index_hashtable_insert(struct index_hashtable *table,
+struct rhashtable *wg_index_hashtable_alloc(void);
+__le32 wg_index_hashtable_insert(struct rhashtable *table,
                                 struct index_hashtable_entry *entry);
-bool wg_index_hashtable_replace(struct index_hashtable *table,
+bool wg_index_hashtable_replace(struct rhashtable *table,
                                struct index_hashtable_entry *old,
                                struct index_hashtable_entry *new);
-void wg_index_hashtable_remove(struct index_hashtable *table,
+void wg_index_hashtable_remove(struct rhashtable *table,
                               struct index_hashtable_entry *entry);
 struct index_hashtable_entry *
-wg_index_hashtable_lookup(struct index_hashtable *table,
+wg_index_hashtable_lookup(struct rhashtable *table,
                          const enum index_hashtable_type type_mask,
                          const __le32 index, struct wg_peer **peer);
+void wg_index_hashtable_destroy(struct rhashtable *table);

 #endif /* _WG_PEERLOOKUP_H */
--
2.32.0



Reply via email to