Module Name:	src
Committed By:	riastradh
Date:		Thu Aug 20 21:31:37 UTC 2020
Modified Files:
	src/sys/net: if_wg.c

Log Message:
Implement sliding window for wireguard replay detection.

To generate a diff of this commit:
cvs rdiff -u -r1.5 -r1.6 src/sys/net/if_wg.c

Please note that diffs are not public domain; they are subject to the
copyright notices on the relevant files.
Modified files:

Index: src/sys/net/if_wg.c
diff -u src/sys/net/if_wg.c:1.5 src/sys/net/if_wg.c:1.6
--- src/sys/net/if_wg.c:1.5	Thu Aug 20 21:31:16 2020
+++ src/sys/net/if_wg.c	Thu Aug 20 21:31:36 2020
@@ -1,4 +1,4 @@
-/*	$NetBSD: if_wg.c,v 1.5 2020/08/20 21:31:16 riastradh Exp $	*/
+/*	$NetBSD: if_wg.c,v 1.6 2020/08/20 21:31:36 riastradh Exp $	*/
 
 /*
  * Copyright (C) Ryota Ozaki <ozaki.ry...@gmail.com>
@@ -43,7 +43,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: if_wg.c,v 1.5 2020/08/20 21:31:16 riastradh Exp $");
+__KERNEL_RCSID(0, "$NetBSD: if_wg.c,v 1.6 2020/08/20 21:31:36 riastradh Exp $");
 
 #ifdef _KERNEL_OPT
 #include "opt_inet.h"
@@ -337,6 +337,82 @@ struct wg_msg_cookie {
 #define	WG_MSG_TYPE_DATA		4
 #define	WG_MSG_TYPE_MAX			WG_MSG_TYPE_DATA
 
+/* Sliding windows */
+
+#define	SLIWIN_BITS	2048u
+#define	SLIWIN_TYPE	uint32_t
+#define	SLIWIN_BPW	(NBBY*sizeof(SLIWIN_TYPE))	/* parenthesized: used as right operand of / and % */
+#define	SLIWIN_WORDS	howmany(SLIWIN_BITS, SLIWIN_BPW)
+#define	SLIWIN_NPKT	(SLIWIN_BITS - NBBY*sizeof(SLIWIN_TYPE))
+
+struct sliwin {
+	SLIWIN_TYPE	B[SLIWIN_WORDS];
+	uint64_t	T;
+};
+
+static void
+sliwin_reset(struct sliwin *W)
+{
+
+	memset(W, 0, sizeof(*W));
+}
+
+static int
+sliwin_check_fast(const volatile struct sliwin *W, uint64_t S)
+{
+
+	/*
+	 * If it's more than one window older than the highest sequence
+	 * number we've seen, reject.
+	 */
+	if (S + SLIWIN_NPKT < atomic_load_relaxed(&W->T))
+		return EAUTH;
+
+	/*
+	 * Otherwise, we need to take the lock to decide, so don't
+	 * reject just yet. Caller must serialize a call to
+	 * sliwin_update in this case.
+	 */
+	return 0;
+}
+
+static int
+sliwin_update(struct sliwin *W, uint64_t S)
+{
+	unsigned word, bit;
+
+	/*
+	 * If it's more than one window older than the highest sequence
+	 * number we've seen, reject.
+	 */
+	if (S + SLIWIN_NPKT < W->T)
+		return EAUTH;
+
+	/*
+	 * If it's higher than the highest sequence number we've seen,
+	 * advance the window.
+	 */
+	if (S > W->T) {
+		uint64_t i = W->T / SLIWIN_BPW;
+		uint64_t j = S / SLIWIN_BPW;
+		unsigned k;
+
+		for (k = 0; k < MIN(j - i, SLIWIN_WORDS); k++)
+			W->B[(i + k + 1) % SLIWIN_WORDS] = 0;
+		atomic_store_relaxed(&W->T, S);
+	}
+
+	/* Test and set the bit -- if already set, reject. */
+	word = (S / SLIWIN_BPW) % SLIWIN_WORDS;
+	bit = S % SLIWIN_BPW;
+	if (W->B[word] & (1UL << bit))
+		return EAUTH;
+	W->B[word] |= 1UL << bit;
+
+	/* Accept! */
+	return 0;
+}
+
 struct wg_worker {
 	kmutex_t	wgw_lock;
 	kcondvar_t	wgw_cv;
@@ -370,8 +446,11 @@ struct wg_session {
 	uint32_t	wgs_receiver_index;
 	volatile uint64_t
 			wgs_send_counter;
-	volatile uint64_t
-			wgs_recv_counter;
+
+	struct {
+		kmutex_t	lock;
+		struct sliwin	window;
+	}		*wgs_recvwin;
 
 	uint8_t		wgs_handshake_hash[WG_HASH_LEN];
 	uint8_t		wgs_chaining_key[WG_CHAINING_KEY_LEN];
@@ -1942,7 +2021,7 @@ wg_clear_states(struct wg_session *wgs)
 {
 
 	wgs->wgs_send_counter = 0;
-	wgs->wgs_recv_counter = 0;
+	sliwin_reset(&wgs->wgs_recvwin->window);
 
 #define wgs_clear(v)	explicit_memset(wgs->wgs_##v, 0, sizeof(wgs->wgs_##v))
 	wgs_clear(handshake_hash);
@@ -2231,6 +2310,15 @@ wg_handle_msg_data(struct wg_softc *wg,
 	}
 	wgp = wgs->wgs_peer;
 
+	error = sliwin_check_fast(&wgs->wgs_recvwin->window,
+	    wgmd->wgmd_counter);
+	if (error) {
+		WG_LOG_RATECHECK(&wgp->wgp_ppsratecheck, LOG_DEBUG,
+		    "out-of-window packet: %"PRIu64"\n",
+		    wgmd->wgmd_counter);
+		goto out;
+	}
+
 	mlen = m_length(m);
 	encrypted_len = mlen - sizeof(*wgmd);
 
@@ -2281,17 +2369,17 @@ wg_handle_msg_data(struct wg_softc *wg,
 	}
 	WG_DLOG("outsize=%u\n", (u_int)decrypted_len);
 
-	/* TODO deal with reordering with a sliding window */
-	if (wgs->wgs_recv_counter != 0 &&
-	    wgmd->wgmd_counter <= wgs->wgs_recv_counter) {
+	mutex_enter(&wgs->wgs_recvwin->lock);
+	error = sliwin_update(&wgs->wgs_recvwin->window,
+	    wgmd->wgmd_counter);
+	mutex_exit(&wgs->wgs_recvwin->lock);
+	if (error) {
 		WG_LOG_RATECHECK(&wgp->wgp_ppsratecheck, LOG_DEBUG,
-		    "wgmd_counter is equal to or smaller than wgs_recv_counter:"
-		    " %"PRIu64" <= %"PRIu64"\n", wgmd->wgmd_counter,
-		    wgs->wgs_recv_counter);
+		    "replay or out-of-window packet: %"PRIu64"\n",
+		    wgmd->wgmd_counter);
 		m_freem(n);
 		goto out;
 	}
-	wgs->wgs_recv_counter = wgmd->wgmd_counter;
 
 	m_freem(m);
 	m = NULL;
@@ -3020,11 +3108,16 @@ wg_alloc_peer(struct wg_softc *wg)
 	wgs->wgs_state = WGS_STATE_UNKNOWN;
 	psref_target_init(&wgs->wgs_psref, wg_psref_class);
 	wgs->wgs_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE);
+	wgs->wgs_recvwin = kmem_zalloc(sizeof(*wgs->wgs_recvwin), KM_SLEEP);
+	mutex_init(&wgs->wgs_recvwin->lock, MUTEX_DEFAULT, IPL_NONE);
+
 	wgs = wgp->wgp_session_unstable;
 	wgs->wgs_peer = wgp;
 	wgs->wgs_state = WGS_STATE_UNKNOWN;
 	psref_target_init(&wgs->wgs_psref, wg_psref_class);
 	wgs->wgs_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE);
+	wgs->wgs_recvwin = kmem_zalloc(sizeof(*wgs->wgs_recvwin), KM_SLEEP);
+	mutex_init(&wgs->wgs_recvwin->lock, MUTEX_DEFAULT, IPL_NONE);
 
 	return wgp;
 }
@@ -3061,10 +3154,14 @@ wg_destroy_peer(struct wg_peer *wgp)
 	wgs = wgp->wgp_session_unstable;
 	psref_target_destroy(&wgs->wgs_psref, wg_psref_class);
 	mutex_obj_free(wgs->wgs_lock);
+	mutex_destroy(&wgs->wgs_recvwin->lock);
+	kmem_free(wgs->wgs_recvwin, sizeof(*wgs->wgs_recvwin));
 	kmem_free(wgs, sizeof(*wgs));
 	wgs = wgp->wgp_session_stable;
 	psref_target_destroy(&wgs->wgs_psref, wg_psref_class);
 	mutex_obj_free(wgs->wgs_lock);
+	mutex_destroy(&wgs->wgs_recvwin->lock);
+	kmem_free(wgs->wgs_recvwin, sizeof(*wgs->wgs_recvwin));
 	kmem_free(wgs, sizeof(*wgs));
 
 	psref_target_destroy(&wgp->wgp_endpoint->wgsa_psref, wg_psref_class);