Module Name:    src
Committed By:   riastradh
Date:           Thu Aug 20 21:29:44 UTC 2020

Modified Files:
        src/sys/net: if_wg.c

Log Message:
[ozaki-r] Fix bugs found by maxv's audits


To generate a diff of this commit:
cvs rdiff -u -r1.1 -r1.2 src/sys/net/if_wg.c

Please note that diffs are not public domain; they are subject to the
copyright notices on the relevant files.

Modified files:

Index: src/sys/net/if_wg.c
diff -u src/sys/net/if_wg.c:1.1 src/sys/net/if_wg.c:1.2
--- src/sys/net/if_wg.c:1.1	Thu Aug 20 21:28:01 2020
+++ src/sys/net/if_wg.c	Thu Aug 20 21:29:44 2020
@@ -1,4 +1,4 @@
-/*	$NetBSD: if_wg.c,v 1.1 2020/08/20 21:28:01 riastradh Exp $	*/
+/*	$NetBSD: if_wg.c,v 1.2 2020/08/20 21:29:44 riastradh Exp $	*/
 
 /*
  * Copyright (C) Ryota Ozaki <ozaki.ry...@gmail.com>
@@ -43,7 +43,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: if_wg.c,v 1.1 2020/08/20 21:28:01 riastradh Exp $");
+__KERNEL_RCSID(0, "$NetBSD: if_wg.c,v 1.2 2020/08/20 21:29:44 riastradh Exp $");
 
 #ifdef _KERNEL_OPT
 #include "opt_inet.h"
@@ -2047,35 +2047,40 @@ wg_change_endpoint(struct wg_peer *wgp, 
 	wg_schedule_peer_task(wgp, WGP_TASK_ENDPOINT_CHANGED);
 }
 
-static int
-wg_determine_af(char *packet)
+static bool
+wg_validate_inner_packet(char *packet, size_t decrypted_len, int *af)
 {
+	uint16_t packet_len;
 	struct ip *ip;
-	int af;
+
+	if (__predict_false(decrypted_len < sizeof(struct ip)))
+		return false;
 
 	ip = (struct ip *)packet;
-	af = ip->ip_v == 4 ? AF_INET : AF_INET6;
-	WG_DLOG("af=%d\n", af);
+	if (ip->ip_v == 4)
+		*af = AF_INET;
+	else if (ip->ip_v == 6)
+		*af = AF_INET6;
+	else
+		return false;
 
-	return af;
-}
+	WG_DLOG("af=%d\n", *af);
 
-static bool
-wg_validate_inner_length(int af, char *packet, size_t expected_len)
-{
-	uint16_t actual_len;
-
-	if (af == AF_INET) {
-		struct ip *ip = (struct ip *)packet;
-		actual_len = ntohs(ip->ip_len);
+	if (*af == AF_INET) {
+		packet_len = ntohs(ip->ip_len);
 	} else {
-		struct ip6_hdr *ip6 = (struct ip6_hdr *)packet;
-		actual_len = sizeof(struct ip6_hdr) + ntohs(ip6->ip6_plen);
+		struct ip6_hdr *ip6;
+
+		if (__predict_false(decrypted_len < sizeof(struct ip6_hdr)))
+			return false;
+
+		ip6 = (struct ip6_hdr *)packet;
+		packet_len = sizeof(struct ip6_hdr) + ntohs(ip6->ip6_plen);
 	}
-	WG_DLOG("actual_len=%u\n", actual_len);
-	if (actual_len > expected_len) {
+
+	WG_DLOG("packet_len=%u\n", packet_len);
+	if (packet_len > decrypted_len)
 		return false;
-	}
 
 	return true;
 }
@@ -2197,7 +2202,7 @@ static void
 wg_handle_msg_data(struct wg_softc *wg, struct mbuf *m,
     const struct sockaddr *src)
 {
-	struct wg_msg_data *wgmd = mtod(m, struct wg_msg_data *);
+	struct wg_msg_data *wgmd;
 	char *encrypted_buf = NULL, *decrypted_buf;
 	size_t encrypted_len, decrypted_len;
 	struct wg_session *wgs;
@@ -2208,6 +2213,13 @@ wg_handle_msg_data(struct wg_softc *wg, 
 	bool success, free_encrypted_buf = false, ok;
 	struct mbuf *n;
 
+	if (m->m_len < sizeof(struct wg_msg_data)) {
+		m = m_pullup(m, sizeof(struct wg_msg_data));
+		if (m == NULL)
+			return;
+	}
+	wgmd = mtod(m, struct wg_msg_data *);
+
 	KASSERT(wgmd->wgmd_type == WG_MSG_TYPE_DATA);
 	WG_TRACE("data");
 
@@ -2222,6 +2234,11 @@ wg_handle_msg_data(struct wg_softc *wg, 
 	mlen = m_length(m);
 	encrypted_len = mlen - sizeof(*wgmd);
 
+	if (encrypted_len < WG_AUTHTAG_LEN) {
+		WG_DLOG("Short encrypted_len: %lu\n", encrypted_len);
+		goto out;
+	}
+
 	success = m_ensure_contig(&m, sizeof(*wgmd) + encrypted_len);
 	if (success) {
 		encrypted_buf = mtod(m, char *) + sizeof(*wgmd);
@@ -2231,15 +2248,20 @@ wg_handle_msg_data(struct wg_softc *wg, 
 			WG_DLOG("failed to allocate encrypted_buf\n");
 			goto out;
 		}
-		m_copydata(m, sizeof(*wgmd), mlen - sizeof(*wgmd),
-		    encrypted_buf);
+		m_copydata(m, sizeof(*wgmd), encrypted_len, encrypted_buf);
 		free_encrypted_buf = true;
 	}
 	/* m_ensure_contig may change m regardless of its result */
 	wgmd = mtod(m, struct wg_msg_data *);
 
-	decrypted_len = encrypted_len; /* To avoid zero length */
-	n = wg_get_mbuf(0, decrypted_len);
+	decrypted_len = encrypted_len - WG_AUTHTAG_LEN;
+	if (decrypted_len > MCLBYTES) {
+		/* FIXME handle larger data than MCLBYTES */
+		WG_DLOG("couldn't handle larger data than MCLBYTES\n");
+		goto out;
+	}
+
+	n = wg_get_mbuf(0, decrypted_len + WG_AUTHTAG_LEN); /* To avoid zero length */
 	if (n == NULL) {
 		WG_DLOG("wg_get_mbuf failed\n");
 		goto out;
@@ -2275,8 +2297,7 @@ wg_handle_msg_data(struct wg_softc *wg, 
 	m = NULL;
 	wgmd = NULL;
 
-	af = wg_determine_af(decrypted_buf);
-	ok = wg_validate_inner_length(af, decrypted_buf, decrypted_len);
+	ok = wg_validate_inner_packet(decrypted_buf, decrypted_len, &af);
 	if (!ok) {
 		/* something wrong... */
 		m_freem(n);
@@ -2359,7 +2380,7 @@ wg_handle_msg_data(struct wg_softc *wg, 
 out:
 	wg_put_session(wgs, &psref);
 	if (m != NULL)
-		m_free(m);
+		m_freem(m);
 	if (free_encrypted_buf)
 		kmem_intr_free(encrypted_buf, encrypted_len);
 }
@@ -2410,10 +2431,55 @@ out:
 	wg_put_session(wgs, &psref);
 }
 
+static bool
+wg_validate_msg_length(struct wg_softc *wg, const struct mbuf *m)
+{
+	struct wg_msg *wgm;
+	size_t mlen;
+
+	mlen = m_length(m);
+	if (__predict_false(mlen < sizeof(struct wg_msg)))
+		return false;
+
+	wgm = mtod(m, struct wg_msg *);
+	switch (wgm->wgm_type) {
+	case WG_MSG_TYPE_INIT:
+		if (__predict_true(mlen >= sizeof(struct wg_msg_init)))
+			return true;
+		break;
+	case WG_MSG_TYPE_RESP:
+		if (__predict_true(mlen >= sizeof(struct wg_msg_resp)))
+			return true;
+		break;
+	case WG_MSG_TYPE_COOKIE:
+		if (__predict_true(mlen >= sizeof(struct wg_msg_cookie)))
+			return true;
+		break;
+	case WG_MSG_TYPE_DATA:
+		if (__predict_true(mlen >= sizeof(struct wg_msg_data)))
+			return true;
+		break;
+	default:
+		WG_LOG_RATECHECK(&wg->wg_ppsratecheck, LOG_DEBUG,
+		    "Unexpected msg type: %u\n", wgm->wgm_type);
+		return false;
+	}
+	WG_DLOG("Invalid msg size: mlen=%lu type=%u\n", mlen, wgm->wgm_type);
+
+	return false;
+}
+
 static void
 wg_handle_packet(struct wg_softc *wg, struct mbuf *m, const struct sockaddr *src)
 {
 	struct wg_msg *wgm;
+	bool valid;
+
+	valid = wg_validate_msg_length(wg, m);
+	if (!valid) {
+		m_freem(m);
+		return;
+	}
 
 	wgm = mtod(m, struct wg_msg *);
 	switch (wgm->wgm_type) {
@@ -2430,9 +2496,7 @@ wg_handle_packet(struct wg_softc *wg, st
 		wg_handle_msg_data(wg, m, src);
 		break;
 	default:
-		WG_LOG_RATECHECK(&wg->wg_ppsratecheck, LOG_DEBUG,
-		    "Unexpected msg type: %u\n", wgm->wgm_type);
-		m_freem(m);
+		/* wg_validate_msg_length should already reject this case */
 		break;
 	}
 }
@@ -2677,14 +2741,15 @@ wg_overudp_cb(struct mbuf **mp, int offs
     struct sockaddr *src, void *arg)
 {
 	struct wg_softc *wg = arg;
-	struct wg_msg *wgm;
+	struct wg_msg wgm;
 	struct mbuf *m = *mp;
 
 	WG_TRACE("enter");
 
-	wgm = (struct wg_msg *)(m->m_data + offset);
-	WG_DLOG("type=%d\n", wgm->wgm_type);
-	switch (wgm->wgm_type) {
+	m_copydata(m, offset, sizeof(struct wg_msg), &wgm);
+	WG_DLOG("type=%d\n", wgm.wgm_type);
+
+	switch (wgm.wgm_type) {
 	case WG_MSG_TYPE_DATA:
 		m_adj(m, offset);
 		wg_handle_msg_data(wg, m, src);
@@ -3368,8 +3433,8 @@ wg_send_data_msg(struct wg_peer *wgp, st
 
 	mlen = m_length(m);
 	inner_len = mlen;
-	padded_len = mlen + (mlen % 16);
-	encrypted_len = mlen + (mlen % 16) + WG_AUTHTAG_LEN;
+	padded_len = roundup(mlen, 16);
+	encrypted_len = padded_len + WG_AUTHTAG_LEN;
 	WG_DLOG("inner=%lu, padded=%lu, encrypted_len=%lu\n",
 	    inner_len, padded_len, encrypted_len);
 	if (mlen != 0) {

Reply via email to