Module Name:    src
Committed By:   riastradh
Date:           Thu Aug 20 21:31:37 UTC 2020

Modified Files:
        src/sys/net: if_wg.c

Log Message:
Implement sliding window for wireguard replay detection.


To generate a diff of this commit:
cvs rdiff -u -r1.5 -r1.6 src/sys/net/if_wg.c

Please note that diffs are not public domain; they are subject to the
copyright notices on the relevant files.

Modified files:

Index: src/sys/net/if_wg.c
diff -u src/sys/net/if_wg.c:1.5 src/sys/net/if_wg.c:1.6
--- src/sys/net/if_wg.c:1.5	Thu Aug 20 21:31:16 2020
+++ src/sys/net/if_wg.c	Thu Aug 20 21:31:36 2020
@@ -1,4 +1,4 @@
-/*	$NetBSD: if_wg.c,v 1.5 2020/08/20 21:31:16 riastradh Exp $	*/
+/*	$NetBSD: if_wg.c,v 1.6 2020/08/20 21:31:36 riastradh Exp $	*/
 
 /*
  * Copyright (C) Ryota Ozaki <ozaki.ry...@gmail.com>
@@ -43,7 +43,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: if_wg.c,v 1.5 2020/08/20 21:31:16 riastradh Exp $");
+__KERNEL_RCSID(0, "$NetBSD: if_wg.c,v 1.6 2020/08/20 21:31:36 riastradh Exp $");
 
 #ifdef _KERNEL_OPT
 #include "opt_inet.h"
@@ -337,6 +337,82 @@ struct wg_msg_cookie {
 #define WG_MSG_TYPE_DATA		4
 #define WG_MSG_TYPE_MAX			WG_MSG_TYPE_DATA
 
+/* Sliding windows for replay detection (cf. RFC 6479) */
+
+#define	SLIWIN_BITS	2048u		/* total bits in the ring */
+#define	SLIWIN_TYPE	uint32_t	/* type of one ring word */
+#define	SLIWIN_BPW	(NBBY*sizeof(SLIWIN_TYPE))	/* bits per word */
+#define	SLIWIN_WORDS	howmany(SLIWIN_BITS, SLIWIN_BPW)
+#define	SLIWIN_NPKT	(SLIWIN_BITS - SLIWIN_BPW)	/* max lag behind T */
+
+struct sliwin {
+	SLIWIN_TYPE	B[SLIWIN_WORDS];	/* ring bitmap, one bit per counter */
+	uint64_t	T;			/* highest counter seen */
+};
+
+static void
+sliwin_reset(struct sliwin *W)
+{
+	/* Forget every counter seen; T (highest seen) goes back to 0.  */
+	memset(W, 0, sizeof(struct sliwin));
+}
+
+static int
+sliwin_check_fast(const volatile struct sliwin *W, uint64_t S)
+{
+	uint64_t highest;
+
+	/*
+	 * Lock-free early reject: a counter more than one window behind
+	 * the highest one seen would be rejected by sliwin_update anyway.
+	 * NOTE(review): this is a 64-bit atomic load -- confirm every
+	 * port provides __HAVE_ATOMIC64_LOADSTORE.
+	 */
+	highest = atomic_load_relaxed(&W->T);
+	if (highest > S + SLIWIN_NPKT)
+		return EAUTH;
+
+	/* Maybe acceptable: caller must serialize sliwin_update to decide. */
+	return 0;
+}
+
+static int	/* 0: S is new and now recorded; EAUTH: replay or too old */
+sliwin_update(struct sliwin *W, uint64_t S)
+{
+	unsigned word, bit;
+
+	/*
+	 * If S is more than SLIWIN_NPKT behind the highest sequence
+	 * number seen (T), its bit may have been recycled: reject.
+	 */
+	if (S + SLIWIN_NPKT < W->T)
+		return EAUTH;
+
+	/*
+	 * If S is newer than T, advance the window: zero the ring words
+	 * after T's word up to S's word (at most the whole ring).
+	 */
+	if (S > W->T) {
+		uint64_t i = W->T / SLIWIN_BPW;
+		uint64_t j = S / SLIWIN_BPW;
+		unsigned k;
+
+		for (k = 0; k < MIN(j - i, SLIWIN_WORDS); k++)
+			W->B[(i + k + 1) % SLIWIN_WORDS] = 0;
+		atomic_store_relaxed(&W->T, S);	/* read by sliwin_check_fast */
+	}
+
+	/* Test and set S's bit -- if already set, S is a replay: reject.  */
+	word = (S / SLIWIN_BPW) % SLIWIN_WORDS;
+	bit = S % SLIWIN_BPW;
+	if (W->B[word] & (1UL << bit))
+		return EAUTH;
+	W->B[word] |= 1UL << bit;
+
+	/* Accept!  */
+	return 0;
+}
+
 struct wg_worker {
 	kmutex_t	wgw_lock;
 	kcondvar_t	wgw_cv;
@@ -370,8 +446,11 @@ struct wg_session {
 	uint32_t	wgs_receiver_index;
 	volatile uint64_t
 			wgs_send_counter;
-	volatile uint64_t
-			wgs_recv_counter;
+
+	struct {
+		kmutex_t	lock;
+		struct sliwin	window;
+	}		*wgs_recvwin;
 
 	uint8_t		wgs_handshake_hash[WG_HASH_LEN];
 	uint8_t		wgs_chaining_key[WG_CHAINING_KEY_LEN];
@@ -1942,7 +2021,7 @@ wg_clear_states(struct wg_session *wgs)
 {
 
 	wgs->wgs_send_counter = 0;
-	wgs->wgs_recv_counter = 0;
+	sliwin_reset(&wgs->wgs_recvwin->window);
 
 #define wgs_clear(v)	explicit_memset(wgs->wgs_##v, 0, sizeof(wgs->wgs_##v))
 	wgs_clear(handshake_hash);
@@ -2231,6 +2310,15 @@ wg_handle_msg_data(struct wg_softc *wg, 
 	}
 	wgp = wgs->wgs_peer;
 
+	error = sliwin_check_fast(&wgs->wgs_recvwin->window,
+	    wgmd->wgmd_counter);
+	if (error) {
+		WG_LOG_RATECHECK(&wgp->wgp_ppsratecheck, LOG_DEBUG,
+		    "out-of-window packet: %"PRIu64"\n",
+		    wgmd->wgmd_counter);
+		goto out;
+	}
+
 	mlen = m_length(m);
 	encrypted_len = mlen - sizeof(*wgmd);
 
@@ -2281,17 +2369,17 @@ wg_handle_msg_data(struct wg_softc *wg, 
 	}
 	WG_DLOG("outsize=%u\n", (u_int)decrypted_len);
 
-	/* TODO deal with reordering with a sliding window */
-	if (wgs->wgs_recv_counter != 0 &&
-	    wgmd->wgmd_counter <= wgs->wgs_recv_counter) {
+	mutex_enter(&wgs->wgs_recvwin->lock);
+	error = sliwin_update(&wgs->wgs_recvwin->window,
+	    wgmd->wgmd_counter);
+	mutex_exit(&wgs->wgs_recvwin->lock);
+	if (error) {
 		WG_LOG_RATECHECK(&wgp->wgp_ppsratecheck, LOG_DEBUG,
-		    "wgmd_counter is equal to or smaller than wgs_recv_counter:"
-		    " %"PRIu64" <= %"PRIu64"\n", wgmd->wgmd_counter,
-		    wgs->wgs_recv_counter);
+		    "replay or out-of-window packet: %"PRIu64"\n",
+		    wgmd->wgmd_counter);
 		m_freem(n);
 		goto out;
 	}
-	wgs->wgs_recv_counter = wgmd->wgmd_counter;
 
 	m_freem(m);
 	m = NULL;
@@ -3020,11 +3108,16 @@ wg_alloc_peer(struct wg_softc *wg)
 	wgs->wgs_state = WGS_STATE_UNKNOWN;
 	psref_target_init(&wgs->wgs_psref, wg_psref_class);
 	wgs->wgs_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE);
+	wgs->wgs_recvwin = kmem_zalloc(sizeof(*wgs->wgs_recvwin), KM_SLEEP);
+	mutex_init(&wgs->wgs_recvwin->lock, MUTEX_DEFAULT, IPL_NONE);
+
 	wgs = wgp->wgp_session_unstable;
 	wgs->wgs_peer = wgp;
 	wgs->wgs_state = WGS_STATE_UNKNOWN;
 	psref_target_init(&wgs->wgs_psref, wg_psref_class);
 	wgs->wgs_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE);
+	wgs->wgs_recvwin = kmem_zalloc(sizeof(*wgs->wgs_recvwin), KM_SLEEP);
+	mutex_init(&wgs->wgs_recvwin->lock, MUTEX_DEFAULT, IPL_NONE);
 
 	return wgp;
 }
@@ -3061,10 +3154,14 @@ wg_destroy_peer(struct wg_peer *wgp)
 	wgs = wgp->wgp_session_unstable;
 	psref_target_destroy(&wgs->wgs_psref, wg_psref_class);
 	mutex_obj_free(wgs->wgs_lock);
+	mutex_destroy(&wgs->wgs_recvwin->lock);
+	kmem_free(wgs->wgs_recvwin, sizeof(*wgs->wgs_recvwin));
 	kmem_free(wgs, sizeof(*wgs));
 	wgs = wgp->wgp_session_stable;
 	psref_target_destroy(&wgs->wgs_psref, wg_psref_class);
 	mutex_obj_free(wgs->wgs_lock);
+	mutex_destroy(&wgs->wgs_recvwin->lock);
+	kmem_free(wgs->wgs_recvwin, sizeof(*wgs->wgs_recvwin));
 	kmem_free(wgs, sizeof(*wgs));
 
 	psref_target_destroy(&wgp->wgp_endpoint->wgsa_psref, wg_psref_class);

Reply via email to