Module Name:    src
Committed By:   riastradh
Date:           Mon Aug 31 20:23:56 UTC 2020

Modified Files:
        src/sys/net: if_wg.c

Log Message:
wg: Use thmap(9) for peer and session lookup.

Also make sure we never assign the same session index to two sessions
at once: if a randomly chosen index is already taken, pick another.


To generate a diff of this commit:
cvs rdiff -u -r1.36 -r1.37 src/sys/net/if_wg.c

Please note that diffs are not public domain; they are subject to the
copyright notices on the relevant files.

Modified files:

Index: src/sys/net/if_wg.c
diff -u src/sys/net/if_wg.c:1.36 src/sys/net/if_wg.c:1.37
--- src/sys/net/if_wg.c:1.36	Mon Aug 31 20:21:30 2020
+++ src/sys/net/if_wg.c	Mon Aug 31 20:23:56 2020
@@ -1,4 +1,4 @@
-/*	$NetBSD: if_wg.c,v 1.36 2020/08/31 20:21:30 riastradh Exp $	*/
+/*	$NetBSD: if_wg.c,v 1.37 2020/08/31 20:23:56 riastradh Exp $	*/
 
 /*
  * Copyright (C) Ryota Ozaki <ozaki.ry...@gmail.com>
@@ -41,7 +41,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: if_wg.c,v 1.36 2020/08/31 20:21:30 riastradh Exp $");
+__KERNEL_RCSID(0, "$NetBSD: if_wg.c,v 1.37 2020/08/31 20:23:56 riastradh Exp $");
 
 #ifdef _KERNEL_OPT
 #include "opt_inet.h"
@@ -77,6 +77,7 @@ __KERNEL_RCSID(0, "$NetBSD: if_wg.c,v 1.
 #include <sys/sysctl.h>
 #include <sys/syslog.h>
 #include <sys/systm.h>
+#include <sys/thmap.h>
 #include <sys/time.h>
 #include <sys/timespec.h>
 
@@ -603,6 +604,9 @@ struct wg_softc {
 
 	int		wg_npeers;
 	struct pslist_head	wg_peers;
+	struct thmap	*wg_peers_bypubkey;
+	struct thmap	*wg_peers_byname;
+	struct thmap	*wg_sessions_byindex;
 	uint16_t	wg_listen_port;
 
 	struct wg_worker	*wg_worker;
@@ -1159,6 +1163,52 @@ wg_unlock_session(struct wg_peer *wgp, s
 }
 #endif
 
+static uint32_t
+wg_assign_sender_index(struct wg_softc *wg, struct wg_session *wgs)
+{
+	struct wg_peer *wgp = wgs->wgs_peer;
+	struct wg_session *wgs0;
+	uint32_t index;
+	void *garbage;
+
+	mutex_enter(wgs->wgs_lock);
+
+	/* Release the current index, if there is one.  */
+	while ((index = wgs->wgs_sender_index) != 0) {
+		/* Remove the session by index.  */
+		thmap_del(wg->wg_sessions_byindex, &index, sizeof index);
+		wgs->wgs_sender_index = 0;
+		mutex_exit(wgs->wgs_lock);
+
+		/* Wait for all thmap_gets to complete, and GC.  */
+		garbage = thmap_stage_gc(wg->wg_sessions_byindex);
+		mutex_enter(wgs->wgs_peer->wgp_lock);
+		pserialize_perform(wgp->wgp_psz);
+		mutex_exit(wgs->wgs_peer->wgp_lock);
+		thmap_gc(wg->wg_sessions_byindex, garbage);
+
+		mutex_enter(wgs->wgs_lock);
+	}
+
+restart:
+	/* Pick a uniform random nonzero index.  */
+	while (__predict_false((index = cprng_strong32()) == 0))
+		continue;
+
+	/* Try to take it.  */
+	wgs->wgs_sender_index = index;
+	wgs0 = thmap_put(wg->wg_sessions_byindex,
+	    &wgs->wgs_sender_index, sizeof wgs->wgs_sender_index, wgs);
+
+	/* If someone else beat us, start over.  */
+	if (__predict_false(wgs0 != wgs))
+		goto restart;
+
+	mutex_exit(wgs->wgs_lock);
+
+	return index;
+}
+
 /*
  * Handshake patterns
  *
@@ -1192,7 +1242,7 @@ wg_fill_msg_init(struct wg_softc *wg, st
 	uint8_t privkey[WG_EPHEMERAL_KEY_LEN];
 
 	wgmi->wgmi_type = WG_MSG_TYPE_INIT;
-	wgmi->wgmi_sender = cprng_strong32();
+	wgmi->wgmi_sender = wg_assign_sender_index(wg, wgs);
 
 	/* [W] 5.4.2: First Message: Initiator to Responder */
 
@@ -1267,7 +1317,6 @@ wg_fill_msg_init(struct wg_softc *wg, st
 	memcpy(wgs->wgs_ephemeral_key_priv, privkey, sizeof(privkey));
 	memcpy(wgs->wgs_handshake_hash, hash, sizeof(hash));
 	memcpy(wgs->wgs_chaining_key, ckey, sizeof(ckey));
-	wgs->wgs_sender_index = wgmi->wgmi_sender;
 	WG_DLOG("%s: sender=%x\n", __func__, wgs->wgs_sender_index);
 }
 
@@ -1609,7 +1658,7 @@ wg_fill_msg_resp(struct wg_softc *wg, st
 	memcpy(ckey, wgs->wgs_chaining_key, sizeof(ckey));
 
 	wgmr->wgmr_type = WG_MSG_TYPE_RESP;
-	wgmr->wgmr_sender = cprng_strong32();
+	wgmr->wgmr_sender = wg_assign_sender_index(wg, wgs);
 	wgmr->wgmr_receiver = wgmi->wgmi_sender;
 
 	/* [W] 5.4.3 Second Message: Responder to Initiator */
@@ -1680,7 +1729,6 @@ wg_fill_msg_resp(struct wg_softc *wg, st
 	memcpy(wgs->wgs_chaining_key, ckey, sizeof(ckey));
 	memcpy(wgs->wgs_ephemeral_key_pub, pubkey, sizeof(pubkey));
 	memcpy(wgs->wgs_ephemeral_key_priv, privkey, sizeof(privkey));
-	wgs->wgs_sender_index = wgmr->wgmr_sender;
 	wgs->wgs_receiver_index = wgmi->wgmi_sender;
 	WG_DLOG("sender=%x\n", wgs->wgs_sender_index);
 	WG_DLOG("receiver=%x\n", wgs->wgs_receiver_index);
@@ -1906,12 +1954,7 @@ wg_lookup_peer_by_pubkey(struct wg_softc
 	struct wg_peer *wgp;
 
 	int s = pserialize_read_enter();
-	/* XXX O(n) */
-	WG_PEER_READER_FOREACH(wgp, wg) {
-		if (consttime_memequal(wgp->wgp_pubkey, pubkey,
-			sizeof(wgp->wgp_pubkey)))
-			break;
-	}
+	wgp = thmap_get(wg->wg_peers_bypubkey, pubkey, WG_STATIC_KEY_LEN);
 	if (wgp != NULL)
 		wg_get_peer(wgp, psref);
 	pserialize_read_exit(s);
@@ -2087,24 +2130,10 @@ static struct wg_session *
 wg_lookup_session_by_index(struct wg_softc *wg, const uint32_t index,
     struct psref *psref)
 {
-	struct wg_peer *wgp;
 	struct wg_session *wgs;
 
 	int s = pserialize_read_enter();
-	/* XXX O(n) */
-	WG_PEER_READER_FOREACH(wgp, wg) {
-		wgs = wgp->wgp_session_stable;
-		WG_DLOG("index=%x wgs_sender_index=%x\n",
-		    index, wgs->wgs_sender_index);
-		if (wgs->wgs_sender_index == index)
-			break;
-		wgs = wgp->wgp_session_unstable;
-		WG_DLOG("index=%x wgs_sender_index=%x\n",
-		    index, wgs->wgs_sender_index);
-		if (wgs->wgs_sender_index == index)
-			break;
-		wgs = NULL;
-	}
+	wgs = thmap_get(wg->wg_sessions_byindex, &index, sizeof index);
 	if (wgs != NULL)
 		psref_acquire(psref, &wgs->wgs_psref, wg_psref_class);
 	pserialize_read_exit(s);
@@ -3262,7 +3291,10 @@ wg_destroy_peer(struct wg_peer *wgp)
 {
 	struct wg_session *wgs;
 	struct wg_softc *wg = wgp->wgp_sc;
+	uint32_t index;
+	void *garbage;
 
+	/* Prevent new packets from this peer on any source address.  */
 	rw_enter(wg->wg_rwlock, RW_WRITER);
 	for (int i = 0; i < wgp->wgp_n_allowedips; i++) {
 		struct wg_allowedip *wga = &wgp->wgp_allowedips[i];
@@ -3281,11 +3313,29 @@ wg_destroy_peer(struct wg_peer *wgp)
 	}
 	rw_exit(wg->wg_rwlock);
 
+	/* Halt all packet processing and timeouts.  */
 	softint_disestablish(wgp->wgp_si);
 	callout_halt(&wgp->wgp_rekey_timer, NULL);
 	callout_halt(&wgp->wgp_handshake_timeout_timer, NULL);
 	callout_halt(&wgp->wgp_session_dtor_timer, NULL);
 
+	/* Remove the sessions by index.  */
+	if ((index = wgp->wgp_session_stable->wgs_sender_index) != 0) {
+		thmap_del(wg->wg_sessions_byindex, &index, sizeof index);
+		wgp->wgp_session_stable->wgs_sender_index = 0;
+	}
+	if ((index = wgp->wgp_session_unstable->wgs_sender_index) != 0) {
+		thmap_del(wg->wg_sessions_byindex, &index, sizeof index);
+		wgp->wgp_session_unstable->wgs_sender_index = 0;
+	}
+
+	/* Wait for all thmap_gets to complete, and GC.  */
+	garbage = thmap_stage_gc(wg->wg_sessions_byindex);
+	mutex_enter(wgp->wgp_lock);
+	pserialize_perform(wgp->wgp_psz);
+	mutex_exit(wgp->wgp_lock);
+	thmap_gc(wg->wg_sessions_byindex, garbage);
+
 	wgs = wgp->wgp_session_unstable;
 	psref_target_destroy(&wgs->wgs_psref, wg_psref_class);
 	mutex_obj_free(wgs->wgs_lock);
@@ -3295,6 +3345,7 @@ wg_destroy_peer(struct wg_peer *wgp)
 	mutex_destroy(&wgs->wgs_send_counter_lock);
 #endif
 	kmem_free(wgs, sizeof(*wgs));
+
 	wgs = wgp->wgp_session_stable;
 	psref_target_destroy(&wgs->wgs_psref, wg_psref_class);
 	mutex_obj_free(wgs->wgs_lock);
@@ -3320,11 +3371,23 @@ wg_destroy_peer(struct wg_peer *wgp)
 static void
 wg_destroy_all_peers(struct wg_softc *wg)
 {
-	struct wg_peer *wgp;
+	struct wg_peer *wgp, *wgp0 __diagused;
+	void *garbage_byname, *garbage_bypubkey;
 
 restart:
+	garbage_byname = garbage_bypubkey = NULL;
 	mutex_enter(wg->wg_lock);
 	WG_PEER_WRITER_FOREACH(wgp, wg) {
+		if (wgp->wgp_name[0]) {
+			wgp0 = thmap_del(wg->wg_peers_byname, wgp->wgp_name,
+			    strlen(wgp->wgp_name));
+			KASSERT(wgp0 == wgp);
+			garbage_byname = thmap_stage_gc(wg->wg_peers_byname);
+		}
+		wgp0 = thmap_del(wg->wg_peers_bypubkey, wgp->wgp_pubkey,
+		    sizeof(wgp->wgp_pubkey));
+		KASSERT(wgp0 == wgp);
+		garbage_bypubkey = thmap_stage_gc(wg->wg_peers_bypubkey);
 		WG_PEER_WRITER_REMOVE(wgp);
 		wg->wg_npeers--;
 		mutex_enter(wgp->wgp_lock);
@@ -3342,6 +3405,8 @@ restart:
 	psref_target_destroy(&wgp->wgp_psref, wg_psref_class);
 
 	wg_destroy_peer(wgp);
+	thmap_gc(wg->wg_peers_byname, garbage_byname);
+	thmap_gc(wg->wg_peers_bypubkey, garbage_bypubkey);
 
 	goto restart;
 }
@@ -3349,14 +3414,17 @@ restart:
 static int
 wg_destroy_peer_name(struct wg_softc *wg, const char *name)
 {
-	struct wg_peer *wgp;
+	struct wg_peer *wgp, *wgp0 __diagused;
+	void *garbage_byname, *garbage_bypubkey;
 
 	mutex_enter(wg->wg_lock);
-	WG_PEER_WRITER_FOREACH(wgp, wg) {
-		if (strcmp(wgp->wgp_name, name) == 0)
-			break;
-	}
+	wgp = thmap_del(wg->wg_peers_byname, name, strlen(name));
 	if (wgp != NULL) {
+		wgp0 = thmap_del(wg->wg_peers_bypubkey, wgp->wgp_pubkey,
+		    sizeof(wgp->wgp_pubkey));
+		KASSERT(wgp0 == wgp);
+		garbage_byname = thmap_stage_gc(wg->wg_peers_byname);
+		garbage_bypubkey = thmap_stage_gc(wg->wg_peers_bypubkey);
 		WG_PEER_WRITER_REMOVE(wgp);
 		wg->wg_npeers--;
 		mutex_enter(wgp->wgp_lock);
@@ -3373,6 +3441,8 @@ wg_destroy_peer_name(struct wg_softc *wg
 	psref_target_destroy(&wgp->wgp_psref, wg_psref_class);
 
 	wg_destroy_peer(wgp);
+	thmap_gc(wg->wg_peers_byname, garbage_byname);
+	thmap_gc(wg->wg_peers_bypubkey, garbage_bypubkey);
 
 	return 0;
 }
@@ -3432,6 +3502,9 @@ wg_clone_create(struct if_clone *ifc, in
 #endif
 
 	PSLIST_INIT(&wg->wg_peers);
+	wg->wg_peers_bypubkey = thmap_create(0, NULL, THMAP_NOCOPY);
+	wg->wg_peers_byname = thmap_create(0, NULL, THMAP_NOCOPY);
+	wg->wg_sessions_byindex = thmap_create(0, NULL, THMAP_NOCOPY);
 	wg->wg_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE);
 	wg->wg_rwlock = rw_obj_alloc();
 	wg->wg_ops = &wg_ops_rumpkernel;
@@ -3445,6 +3518,9 @@ wg_clone_create(struct if_clone *ifc, in
 			free(wg->wg_rtable_ipv6, M_RTABLE);
 		PSLIST_DESTROY(&wg->wg_peers);
 		mutex_obj_free(wg->wg_lock);
+		thmap_destroy(wg->wg_sessions_byindex);
+		thmap_destroy(wg->wg_peers_byname);
+		thmap_destroy(wg->wg_peers_bypubkey);
 		kmem_free(wg, sizeof(struct wg_softc));
 		return error;
 	}
@@ -3482,6 +3558,9 @@ wg_clone_destroy(struct ifnet *ifp)
 		free(wg->wg_rtable_ipv6, M_RTABLE);
 
 	PSLIST_DESTROY(&wg->wg_peers);
+	thmap_destroy(wg->wg_sessions_byindex);
+	thmap_destroy(wg->wg_peers_byname);
+	thmap_destroy(wg->wg_peers_bypubkey);
 	mutex_obj_free(wg->wg_lock);
 	rw_obj_free(wg->wg_rwlock);
 
@@ -4076,7 +4155,7 @@ wg_ioctl_add_peer(struct wg_softc *wg, s
 	int error;
 	prop_dictionary_t prop_dict;
 	char *buf = NULL;
-	struct wg_peer *wgp = NULL;
+	struct wg_peer *wgp = NULL, *wgp0 __diagused;
 
 	error = wg_alloc_prop_buf(&buf, ifd);
 	if (error != 0)
@@ -4091,6 +4170,24 @@ wg_ioctl_add_peer(struct wg_softc *wg, s
 		goto out;
 
 	mutex_enter(wg->wg_lock);
+	if (thmap_get(wg->wg_peers_bypubkey, wgp->wgp_pubkey,
+		sizeof(wgp->wgp_pubkey)) != NULL ||
+	    (wgp->wgp_name[0] &&
+		thmap_get(wg->wg_peers_byname, wgp->wgp_name,
+		    strlen(wgp->wgp_name)) != NULL)) {
+		mutex_exit(wg->wg_lock);
+		wg_destroy_peer(wgp);
+		error = EEXIST;
+		goto out;
+	}
+	wgp0 = thmap_put(wg->wg_peers_bypubkey, wgp->wgp_pubkey,
+	    sizeof(wgp->wgp_pubkey), wgp);
+	KASSERT(wgp0 == wgp);
+	if (wgp->wgp_name[0]) {
+		wgp0 = thmap_put(wg->wg_peers_byname, wgp->wgp_name,
+		    strlen(wgp->wgp_name), wgp);
+		KASSERT(wgp0 == wgp);
+	}
 	WG_PEER_WRITER_INSERT_HEAD(wgp, wg);
 	wg->wg_npeers++;
 	mutex_exit(wg->wg_lock);

Reply via email to