Module Name:	src
Committed By:	riastradh
Date:		Thu Aug 20 21:29:44 UTC 2020
Modified Files:
	src/sys/net: if_wg.c

Log Message:
[ozaki-r] Fix bugs found by maxv's audits

To generate a diff of this commit:
cvs rdiff -u -r1.1 -r1.2 src/sys/net/if_wg.c

Please note that diffs are not public domain; they are subject to the
copyright notices on the relevant files.
Modified files: Index: src/sys/net/if_wg.c diff -u src/sys/net/if_wg.c:1.1 src/sys/net/if_wg.c:1.2 --- src/sys/net/if_wg.c:1.1 Thu Aug 20 21:28:01 2020 +++ src/sys/net/if_wg.c Thu Aug 20 21:29:44 2020 @@ -1,4 +1,4 @@ -/* $NetBSD: if_wg.c,v 1.1 2020/08/20 21:28:01 riastradh Exp $ */ +/* $NetBSD: if_wg.c,v 1.2 2020/08/20 21:29:44 riastradh Exp $ */ /* * Copyright (C) Ryota Ozaki <ozaki.ry...@gmail.com> @@ -43,7 +43,7 @@ */ #include <sys/cdefs.h> -__KERNEL_RCSID(0, "$NetBSD: if_wg.c,v 1.1 2020/08/20 21:28:01 riastradh Exp $"); +__KERNEL_RCSID(0, "$NetBSD: if_wg.c,v 1.2 2020/08/20 21:29:44 riastradh Exp $"); #ifdef _KERNEL_OPT #include "opt_inet.h" @@ -2047,35 +2047,40 @@ wg_change_endpoint(struct wg_peer *wgp, wg_schedule_peer_task(wgp, WGP_TASK_ENDPOINT_CHANGED); } -static int -wg_determine_af(char *packet) +static bool +wg_validate_inner_packet(char *packet, size_t decrypted_len, int *af) { + uint16_t packet_len; struct ip *ip; - int af; + + if (__predict_false(decrypted_len < sizeof(struct ip))) + return false; ip = (struct ip *)packet; - af = ip->ip_v == 4 ? 
AF_INET : AF_INET6; - WG_DLOG("af=%d\n", af); + if (ip->ip_v == 4) + *af = AF_INET; + else if (ip->ip_v == 6) + *af = AF_INET6; + else + return false; - return af; -} + WG_DLOG("af=%d\n", *af); -static bool -wg_validate_inner_length(int af, char *packet, size_t expected_len) -{ - uint16_t actual_len; - - if (af == AF_INET) { - struct ip *ip = (struct ip *)packet; - actual_len = ntohs(ip->ip_len); + if (*af == AF_INET) { + packet_len = ntohs(ip->ip_len); } else { - struct ip6_hdr *ip6 = (struct ip6_hdr *)packet; - actual_len = sizeof(struct ip6_hdr) + ntohs(ip6->ip6_plen); + struct ip6_hdr *ip6; + + if (__predict_false(decrypted_len < sizeof(struct ip6_hdr))) + return false; + + ip6 = (struct ip6_hdr *)packet; + packet_len = sizeof(struct ip6_hdr) + ntohs(ip6->ip6_plen); } - WG_DLOG("actual_len=%u\n", actual_len); - if (actual_len > expected_len) { + + WG_DLOG("packet_len=%u\n", packet_len); + if (packet_len > decrypted_len) return false; - } return true; } @@ -2197,7 +2202,7 @@ static void wg_handle_msg_data(struct wg_softc *wg, struct mbuf *m, const struct sockaddr *src) { - struct wg_msg_data *wgmd = mtod(m, struct wg_msg_data *); + struct wg_msg_data *wgmd; char *encrypted_buf = NULL, *decrypted_buf; size_t encrypted_len, decrypted_len; struct wg_session *wgs; @@ -2208,6 +2213,13 @@ wg_handle_msg_data(struct wg_softc *wg, bool success, free_encrypted_buf = false, ok; struct mbuf *n; + if (m->m_len < sizeof(struct wg_msg_data)) { + m = m_pullup(m, sizeof(struct wg_msg_data)); + if (m == NULL) + return; + } + wgmd = mtod(m, struct wg_msg_data *); + KASSERT(wgmd->wgmd_type == WG_MSG_TYPE_DATA); WG_TRACE("data"); @@ -2222,6 +2234,11 @@ wg_handle_msg_data(struct wg_softc *wg, mlen = m_length(m); encrypted_len = mlen - sizeof(*wgmd); + if (encrypted_len < WG_AUTHTAG_LEN) { + WG_DLOG("Short encrypted_len: %lu\n", encrypted_len); + goto out; + } + success = m_ensure_contig(&m, sizeof(*wgmd) + encrypted_len); if (success) { encrypted_buf = mtod(m, char *) + 
sizeof(*wgmd); @@ -2231,15 +2248,20 @@ wg_handle_msg_data(struct wg_softc *wg, WG_DLOG("failed to allocate encrypted_buf\n"); goto out; } - m_copydata(m, sizeof(*wgmd), mlen - sizeof(*wgmd), - encrypted_buf); + m_copydata(m, sizeof(*wgmd), encrypted_len, encrypted_buf); free_encrypted_buf = true; } /* m_ensure_contig may change m regardless of its result */ wgmd = mtod(m, struct wg_msg_data *); - decrypted_len = encrypted_len; /* To avoid zero length */ - n = wg_get_mbuf(0, decrypted_len); + decrypted_len = encrypted_len - WG_AUTHTAG_LEN; + if (decrypted_len > MCLBYTES) { + /* FIXME handle larger data than MCLBYTES */ + WG_DLOG("couldn't handle larger data than MCLBYTES\n"); + goto out; + } + + n = wg_get_mbuf(0, decrypted_len + WG_AUTHTAG_LEN); /* To avoid zero length */ if (n == NULL) { WG_DLOG("wg_get_mbuf failed\n"); goto out; @@ -2275,8 +2297,7 @@ wg_handle_msg_data(struct wg_softc *wg, m = NULL; wgmd = NULL; - af = wg_determine_af(decrypted_buf); - ok = wg_validate_inner_length(af, decrypted_buf, decrypted_len); + ok = wg_validate_inner_packet(decrypted_buf, decrypted_len, &af); if (!ok) { /* something wrong... 
*/ m_freem(n); @@ -2359,7 +2380,7 @@ wg_handle_msg_data(struct wg_softc *wg, out: wg_put_session(wgs, &psref); if (m != NULL) - m_free(m); + m_freem(m); if (free_encrypted_buf) kmem_intr_free(encrypted_buf, encrypted_len); } @@ -2410,10 +2431,55 @@ out: wg_put_session(wgs, &psref); } +static bool +wg_validate_msg_length(struct wg_softc *wg, const struct mbuf *m) +{ + struct wg_msg *wgm; + size_t mlen; + + mlen = m_length(m); + if (__predict_false(mlen < sizeof(struct wg_msg))) + return false; + + wgm = mtod(m, struct wg_msg *); + switch (wgm->wgm_type) { + case WG_MSG_TYPE_INIT: + if (__predict_true(mlen >= sizeof(struct wg_msg_init))) + return true; + break; + case WG_MSG_TYPE_RESP: + if (__predict_true(mlen >= sizeof(struct wg_msg_resp))) + return true; + break; + case WG_MSG_TYPE_COOKIE: + if (__predict_true(mlen >= sizeof(struct wg_msg_cookie))) + return true; + break; + case WG_MSG_TYPE_DATA: + if (__predict_true(mlen >= sizeof(struct wg_msg_data))) + return true; + break; + default: + WG_LOG_RATECHECK(&wg->wg_ppsratecheck, LOG_DEBUG, + "Unexpected msg type: %u\n", wgm->wgm_type); + return false; + } + WG_DLOG("Invalid msg size: mlen=%lu type=%u\n", mlen, wgm->wgm_type); + + return false; +} + static void wg_handle_packet(struct wg_softc *wg, struct mbuf *m, const struct sockaddr *src) { struct wg_msg *wgm; + bool valid; + + valid = wg_validate_msg_length(wg, m); + if (!valid) { + m_freem(m); + return; + } wgm = mtod(m, struct wg_msg *); switch (wgm->wgm_type) { @@ -2430,9 +2496,7 @@ wg_handle_packet(struct wg_softc *wg, st wg_handle_msg_data(wg, m, src); break; default: - WG_LOG_RATECHECK(&wg->wg_ppsratecheck, LOG_DEBUG, - "Unexpected msg type: %u\n", wgm->wgm_type); - m_freem(m); + /* wg_validate_msg_length should already reject this case */ break; } } @@ -2677,14 +2741,15 @@ wg_overudp_cb(struct mbuf **mp, int offs struct sockaddr *src, void *arg) { struct wg_softc *wg = arg; - struct wg_msg *wgm; + struct wg_msg wgm; struct mbuf *m = *mp; 
WG_TRACE("enter"); - wgm = (struct wg_msg *)(m->m_data + offset); - WG_DLOG("type=%d\n", wgm->wgm_type); - switch (wgm->wgm_type) { + m_copydata(m, offset, sizeof(struct wg_msg), &wgm); + WG_DLOG("type=%d\n", wgm.wgm_type); + + switch (wgm.wgm_type) { case WG_MSG_TYPE_DATA: m_adj(m, offset); wg_handle_msg_data(wg, m, src); @@ -3368,8 +3433,8 @@ wg_send_data_msg(struct wg_peer *wgp, st mlen = m_length(m); inner_len = mlen; - padded_len = mlen + (mlen % 16); - encrypted_len = mlen + (mlen % 16) + WG_AUTHTAG_LEN; + padded_len = roundup(mlen, 16); + encrypted_len = padded_len + WG_AUTHTAG_LEN; WG_DLOG("inner=%lu, padded=%lu, encrypted_len=%lu\n", inner_len, padded_len, encrypted_len); if (mlen != 0) {