Module Name:    src
Committed By:   riastradh
Date:           Thu Aug 27 02:53:47 UTC 2020

Modified Files:
        src/sys/net: if_wg.c

Log Message:
wg: Use m_pullup to make message header contiguous before processing.


To generate a diff of this commit:
cvs rdiff -u -r1.25 -r1.26 src/sys/net/if_wg.c

Please note that diffs are not public domain; they are subject to the
copyright notices on the relevant files.

Modified files:

Index: src/sys/net/if_wg.c
diff -u src/sys/net/if_wg.c:1.25 src/sys/net/if_wg.c:1.26
--- src/sys/net/if_wg.c:1.25	Thu Aug 27 02:52:33 2020
+++ src/sys/net/if_wg.c	Thu Aug 27 02:53:47 2020
@@ -1,4 +1,4 @@
-/*	$NetBSD: if_wg.c,v 1.25 2020/08/27 02:52:33 riastradh Exp $	*/
+/*	$NetBSD: if_wg.c,v 1.26 2020/08/27 02:53:47 riastradh Exp $	*/
 
 /*
  * Copyright (C) Ryota Ozaki <ozaki.ry...@gmail.com>
@@ -41,7 +41,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: if_wg.c,v 1.25 2020/08/27 02:52:33 riastradh Exp $");
+__KERNEL_RCSID(0, "$NetBSD: if_wg.c,v 1.26 2020/08/27 02:53:47 riastradh Exp $");
 
 #ifdef _KERNEL_OPT
 #include "opt_inet.h"
@@ -2343,11 +2343,7 @@ wg_handle_msg_data(struct wg_softc *wg, 
 	bool success, free_encrypted_buf = false, ok;
 	struct mbuf *n;
 
-	if (m->m_len < sizeof(struct wg_msg_data)) {
-		m = m_pullup(m, sizeof(struct wg_msg_data));
-		if (m == NULL)
-			return;
-	}
+	KASSERT(m->m_len >= sizeof(struct wg_msg_data));
 	wgmd = mtod(m, struct wg_msg_data *);
 
 	KASSERT(wgmd->wgmd_type == WG_MSG_TYPE_DATA);
@@ -2573,42 +2569,63 @@ out:
 	wg_put_session(wgs, &psref);
 }
 
-static bool
-wg_validate_msg_length(struct wg_softc *wg, const struct mbuf *m)
+static struct mbuf *
+wg_validate_msg_header(struct wg_softc *wg, struct mbuf *m)
 {
-	struct wg_msg *wgm;
-	size_t mlen;
+	struct wg_msg wgm;
+	size_t mbuflen;
+	size_t msglen;
 
-	mlen = m_length(m);
-	if (__predict_false(mlen < sizeof(struct wg_msg)))
-		return false;
+	/*
+	 * Get the mbuf chain length.  It is already guaranteed, by
+	 * wg_overudp_cb, to be large enough for a struct wg_msg.
+	 */
+	mbuflen = m_length(m);
+	KASSERT(mbuflen >= sizeof(struct wg_msg));
 
-	wgm = mtod(m, struct wg_msg *);
-	switch (wgm->wgm_type) {
+	/*
+	 * Copy the message header (32-bit message type) out -- we'll
+	 * worry about contiguity and alignment later.
+	 */
+	m_copydata(m, 0, sizeof(wgm), &wgm);
+	switch (wgm.wgm_type) {
 	case WG_MSG_TYPE_INIT:
-		if (__predict_true(mlen >= sizeof(struct wg_msg_init)))
-			return true;
+		msglen = sizeof(struct wg_msg_init);
 		break;
 	case WG_MSG_TYPE_RESP:
-		if (__predict_true(mlen >= sizeof(struct wg_msg_resp)))
-			return true;
+		msglen = sizeof(struct wg_msg_resp);
 		break;
 	case WG_MSG_TYPE_COOKIE:
-		if (__predict_true(mlen >= sizeof(struct wg_msg_cookie)))
-			return true;
+		msglen = sizeof(struct wg_msg_cookie);
 		break;
 	case WG_MSG_TYPE_DATA:
-		if (__predict_true(mlen >= sizeof(struct wg_msg_data)))
-			return true;
+		msglen = sizeof(struct wg_msg_data);
 		break;
 	default:
 		WG_LOG_RATECHECK(&wg->wg_ppsratecheck, LOG_DEBUG,
-		    "Unexpected msg type: %u\n", wgm->wgm_type);
-		return false;
+		    "Unexpected msg type: %u\n", wgm.wgm_type);
+		goto error;
 	}
-	WG_DLOG("Invalid msg size: mlen=%lu type=%u\n", mlen, wgm->wgm_type);
 
-	return false;
+	/* Verify the mbuf chain is long enough for this type of message.  */
+	if (__predict_false(mbuflen < msglen)) {
+		WG_DLOG("Invalid msg size: mbuflen=%lu type=%u\n", mbuflen,
+		    wgm.wgm_type);
+		goto error;
+	}
+
+	/* Make the message header contiguous if necessary.  */
+	if (__predict_false(m->m_len < msglen)) {
+		m = m_pullup(m, msglen);
+		if (m == NULL)
+			return NULL;
+	}
+
+	return m;
+
+error:
+	m_freem(m);
+	return NULL;
 }
 
 static void
@@ -2616,14 +2633,12 @@ wg_handle_packet(struct wg_softc *wg, st
     const struct sockaddr *src)
 {
 	struct wg_msg *wgm;
-	bool valid;
 
-	valid = wg_validate_msg_length(wg, m);
-	if (!valid) {
-		m_freem(m);
+	m = wg_validate_msg_header(wg, m);
+	if (__predict_false(m == NULL))
 		return;
-	}
 
+	KASSERT(m->m_len >= sizeof(struct wg_msg));
 	wgm = mtod(m, struct wg_msg *);
 	switch (wgm->wgm_type) {
 	case WG_MSG_TYPE_INIT:
@@ -2639,7 +2654,7 @@ wg_handle_packet(struct wg_softc *wg, st
 		wg_handle_msg_data(wg, m, src);
 		break;
 	default:
-		/* wg_validate_msg_length should already reject this case */
+		/* wg_validate_msg_header should already reject this case */
 		break;
 	}
 }

Reply via email to