The branch main has been updated by rscheff:

URL: https://cgit.FreeBSD.org/src/commit/?id=85df11a1dec6eab9efbce9fd20712402a8e7ac7c

commit 85df11a1dec6eab9efbce9fd20712402a8e7ac7c
Author:     Richard Scheffenegger <rscheff@FreeBSD.org>
AuthorDate: 2024-03-13 11:35:51 +0000
Commit:     Richard Scheffenegger <rscheff@FreeBSD.org>
CommitDate: 2024-03-13 12:23:13 +0000

    ktls: deep copy tls_enable struct for in-kernel tcp consumers
    
    Doing a deep copy of the keys early allows users of the
    tls_enable structure to assume kernel memory.
    This enables the socket options to be set by kernel threads.
    
    Reviewed By:    #transport, tuexen, jhb, rrs
    Sponsored by:   NetApp, Inc.
    X-NetApp-PR:    #79
    Differential Revision:  https://reviews.freebsd.org/D44250
---
 sys/kern/uipc_ktls.c     | 96 ++++++++++++++++++++++++++++++++++++++++--------
 sys/netinet/tcp_usrreq.c | 44 ++++------------------
 sys/sys/ktls.h           | 17 +++++----
 3 files changed, 97 insertions(+), 60 deletions(-)

diff --git a/sys/kern/uipc_ktls.c b/sys/kern/uipc_ktls.c
index deba6940bbee..df296090ec97 100644
--- a/sys/kern/uipc_ktls.c
+++ b/sys/kern/uipc_ktls.c
@@ -297,10 +297,86 @@ SYSCTL_COUNTER_U64(_kern_ipc_tls_toe, OID_AUTO, chacha20, CTLFLAG_RD,
 
 static MALLOC_DEFINE(M_KTLS, "ktls", "Kernel TLS");
 
+static void ktls_reclaim_thread(void *ctx);
 static void ktls_reset_receive_tag(void *context, int pending);
 static void ktls_reset_send_tag(void *context, int pending);
 static void ktls_work_thread(void *ctx);
-static void ktls_reclaim_thread(void *ctx);
+
+/*
+ * Allocate len bytes of M_KTLS memory and fill them from src, using
+ * copyin(9) when fromuser is true and bcopy(9) otherwise.  len must already
+ * have been validated as non-negative and bounded.  A zero length allocates
+ * nothing and yields NULL.  On success *dstp receives the buffer; on failure
+ * nothing is left allocated and *dstp is NULL.
+ */
+static int
+ktls_copyin_buf(const void *src, int len, bool fromuser, uint8_t **dstp)
+{
+       uint8_t *buf;
+       int error;
+
+       *dstp = NULL;
+       if (len == 0)
+               return (0);
+
+       buf = malloc(len, M_KTLS, M_WAITOK);
+       if (fromuser) {
+               error = copyin(src, buf, len);
+               if (error != 0) {
+                       zfree(buf, M_KTLS);
+                       return (error);
+               }
+       } else {
+               bcopy(src, buf, len);
+       }
+
+       *dstp = buf;
+       return (0);
+}
+
+/*
+ * Copy a struct tls_enable in from a socket option (e.g. TCP_TXTLS_ENABLE
+ * or TCP_RXTLS_ENABLE) and deep-copy the key and IV buffers it points to.
+ * Later consumers can then assume those pointers refer to kernel memory.
+ *
+ * The legacy struct tls_enable_v0 layout is accepted when sopt_valsize
+ * matches its size; otherwise the current layout is copied in.  When
+ * sopt->sopt_td is non-NULL, the embedded pointers are treated as user
+ * addresses and read with copyin(9).  Otherwise they are treated as kernel
+ * addresses (in-kernel callers) and read with bcopy(9).
+ *
+ * On success, tls->cipher_key, tls->iv and tls->auth_key point to private
+ * M_KTLS copies (NULL for a zero length).  The caller must release them
+ * with ktls_cleanup_tls_enable().  On failure nothing is left allocated.
+ * The pointer fields of *tls then still hold the caller-supplied addresses
+ * and must not be passed to ktls_cleanup_tls_enable().
+ *
+ * Returns 0 on success, EINVAL for a negative or oversized length, or the
+ * error reported by sooptcopyin() or copyin().
+ */
+int
+ktls_copyin_tls_enable(struct sockopt *sopt, struct tls_enable *tls)
+{
+       struct tls_enable_v0 tls_v0;
+       uint8_t *cipher_key = NULL, *iv = NULL, *auth_key = NULL;
+       bool fromuser;
+       int error;
+
+       if (sopt->sopt_valsize == sizeof(tls_v0)) {
+               error = sooptcopyin(sopt, &tls_v0, sizeof(tls_v0),
+                   sizeof(tls_v0));
+               if (error != 0)
+                       return (error);
+               memset(tls, 0, sizeof(*tls));
+               tls->cipher_key = tls_v0.cipher_key;
+               tls->iv = tls_v0.iv;
+               tls->auth_key = tls_v0.auth_key;
+               tls->cipher_algorithm = tls_v0.cipher_algorithm;
+               tls->cipher_key_len = tls_v0.cipher_key_len;
+               tls->iv_len = tls_v0.iv_len;
+               tls->auth_algorithm = tls_v0.auth_algorithm;
+               tls->auth_key_len = tls_v0.auth_key_len;
+               tls->flags = tls_v0.flags;
+               tls->tls_vmajor = tls_v0.tls_vmajor;
+               tls->tls_vminor = tls_v0.tls_vminor;
+       } else {
+               error = sooptcopyin(sopt, tls, sizeof(*tls), sizeof(*tls));
+               if (error != 0)
+                       return (error);
+       }
+
+       /*
+        * The lengths come straight from the socket option.  Bound them
+        * before they reach malloc(9).  A negative int converts to a huge
+        * size_t, and an M_WAITOK allocation never returns NULL, so an
+        * unchecked length would make the kernel attempt an arbitrarily
+        * large sleeping allocation.  This only bounds the allocation size;
+        * it does not replace per-algorithm validation.
+        */
+       if (tls->cipher_key_len < 0 ||
+           tls->cipher_key_len > TLS_MAX_PARAM_SIZE)
+               return (EINVAL);
+       if (tls->iv_len < 0 || tls->iv_len > TLS_MAX_PARAM_SIZE)
+               return (EINVAL);
+       if (tls->auth_key_len < 0 || tls->auth_key_len > TLS_MAX_PARAM_SIZE)
+               return (EINVAL);
+
+       /*
+        * Deep-copy the variable-length arrays into our own allocations.
+        * On any failure, free whatever has been copied so far, so the
+        * caller only owns allocations after a successful return.
+        */
+       fromuser = (sopt->sopt_td != NULL);
+       error = ktls_copyin_buf(tls->cipher_key, tls->cipher_key_len,
+           fromuser, &cipher_key);
+       if (error != 0)
+               goto fail;
+       error = ktls_copyin_buf(tls->iv, tls->iv_len, fromuser, &iv);
+       if (error != 0)
+               goto fail;
+       error = ktls_copyin_buf(tls->auth_key, tls->auth_key_len, fromuser,
+           &auth_key);
+       if (error != 0)
+               goto fail;
+
+       tls->cipher_key = cipher_key;
+       tls->iv = iv;
+       tls->auth_key = auth_key;
+       return (0);
+
+fail:
+       /* zfree() ignores NULL, so buffers not yet allocated are fine. */
+       zfree(cipher_key, M_KTLS);
+       zfree(iv, M_KTLS);
+       zfree(auth_key, M_KTLS);
+       return (error);
+}
+
+/*
+ * Release the private key and IV copies installed in *tls by a successful
+ * ktls_copyin_tls_enable().  Do not call this after a failed copy-in: the
+ * pointer fields then still hold caller-supplied addresses.  zfree(9) zeroes
+ * each buffer before freeing it, so key material does not linger in freed
+ * memory.  NULL entries (zero-length buffers) are ignored.  The pointers are
+ * cleared afterwards so that a repeated cleanup cannot free them twice.
+ */
+void
+ktls_cleanup_tls_enable(struct tls_enable *tls)
+{
+       zfree(__DECONST(void *, tls->cipher_key), M_KTLS);
+       zfree(__DECONST(void *, tls->iv), M_KTLS);
+       zfree(__DECONST(void *, tls->auth_key), M_KTLS);
+       tls->cipher_key = NULL;
+       tls->iv = NULL;
+       tls->auth_key = NULL;
+}
 
 static u_int
 ktls_get_cpu(struct socket *so)
@@ -702,18 +778,12 @@ ktls_create_session(struct socket *so, struct tls_enable *en,
                tls->params.auth_key_len = en->auth_key_len;
                tls->params.auth_key = malloc(en->auth_key_len, M_KTLS,
                    M_WAITOK);
-               error = copyin(en->auth_key, tls->params.auth_key,
-                   en->auth_key_len);
-               if (error)
-                       goto out;
+               bcopy(en->auth_key, tls->params.auth_key, en->auth_key_len);
        }
 
        tls->params.cipher_key_len = en->cipher_key_len;
        tls->params.cipher_key = malloc(en->cipher_key_len, M_KTLS, M_WAITOK);
-       error = copyin(en->cipher_key, tls->params.cipher_key,
-           en->cipher_key_len);
-       if (error)
-               goto out;
+       bcopy(en->cipher_key, tls->params.cipher_key, en->cipher_key_len);
 
        /*
         * This holds the implicit portion of the nonce for AEAD
@@ -722,9 +792,7 @@ ktls_create_session(struct socket *so, struct tls_enable *en,
         */
        if (en->iv_len != 0) {
                tls->params.iv_len = en->iv_len;
-               error = copyin(en->iv, tls->params.iv, en->iv_len);
-               if (error)
-                       goto out;
+               bcopy(en->iv, tls->params.iv, en->iv_len);
 
                /*
                 * For TLS 1.2 with GCM, generate an 8-byte nonce as a
@@ -740,10 +808,6 @@ ktls_create_session(struct socket *so, struct tls_enable *en,
 
        *tlsp = tls;
        return (0);
-
-out:
-       ktls_free(tls);
-       return (error);
 }
 
 static struct ktls_session *
diff --git a/sys/netinet/tcp_usrreq.c b/sys/netinet/tcp_usrreq.c
index a73d2a15c1d5..916fe33e8704 100644
--- a/sys/netinet/tcp_usrreq.c
+++ b/sys/netinet/tcp_usrreq.c
@@ -1914,37 +1914,6 @@ CTASSERT(TCP_CA_NAME_MAX <= TCP_LOG_ID_LEN);
 CTASSERT(TCP_LOG_REASON_LEN <= TCP_LOG_ID_LEN);
 #endif
 
-#ifdef KERN_TLS
-static int
-copyin_tls_enable(struct sockopt *sopt, struct tls_enable *tls)
-{
-       struct tls_enable_v0 tls_v0;
-       int error;
-
-       if (sopt->sopt_valsize == sizeof(tls_v0)) {
-               error = sooptcopyin(sopt, &tls_v0, sizeof(tls_v0),
-                   sizeof(tls_v0));
-               if (error)
-                       return (error);
-               memset(tls, 0, sizeof(*tls));
-               tls->cipher_key = tls_v0.cipher_key;
-               tls->iv = tls_v0.iv;
-               tls->auth_key = tls_v0.auth_key;
-               tls->cipher_algorithm = tls_v0.cipher_algorithm;
-               tls->cipher_key_len = tls_v0.cipher_key_len;
-               tls->iv_len = tls_v0.iv_len;
-               tls->auth_algorithm = tls_v0.auth_algorithm;
-               tls->auth_key_len = tls_v0.auth_key_len;
-               tls->flags = tls_v0.flags;
-               tls->tls_vmajor = tls_v0.tls_vmajor;
-               tls->tls_vminor = tls_v0.tls_vminor;
-               return (0);
-       }
-
-       return (sooptcopyin(sopt, tls, sizeof(*tls), sizeof(*tls)));
-}
-#endif
-
 extern struct cc_algo newreno_cc_algo;
 
 static int
@@ -2292,15 +2261,16 @@ unlock_and_done:
 #ifdef KERN_TLS
                case TCP_TXTLS_ENABLE:
                        INP_WUNLOCK(inp);
-                       error = copyin_tls_enable(sopt, &tls);
-                       if (error)
+                       error = ktls_copyin_tls_enable(sopt, &tls);
+                       if (error != 0)
                                break;
                        error = ktls_enable_tx(so, &tls);
+                       ktls_cleanup_tls_enable(&tls);
                        break;
                case TCP_TXTLS_MODE:
                        INP_WUNLOCK(inp);
                        error = sooptcopyin(sopt, &ui, sizeof(ui), sizeof(ui));
-                       if (error)
+                       if (error != 0)
                                return (error);
 
                        INP_WLOCK_RECHECK(inp);
@@ -2309,11 +2279,11 @@ unlock_and_done:
                        break;
                case TCP_RXTLS_ENABLE:
                        INP_WUNLOCK(inp);
-                       error = sooptcopyin(sopt, &tls, sizeof(tls),
-                           sizeof(tls));
-                       if (error)
+                       error = ktls_copyin_tls_enable(sopt, &tls);
+                       if (error != 0)
                                break;
                        error = ktls_enable_rx(so, &tls);
+                       ktls_cleanup_tls_enable(&tls);
                        break;
 #endif
                case TCP_MAXUNACKTIME:
diff --git a/sys/sys/ktls.h b/sys/sys/ktls.h
index 693864394ffe..9b3433f4b1fd 100644
--- a/sys/sys/ktls.h
+++ b/sys/sys/ktls.h
@@ -174,6 +174,7 @@ struct m_snd_tag;
 struct mbuf;
 struct sockbuf;
 struct socket;
+struct sockopt;
 
 struct ktls_session {
        struct ktls_ocf_session *ocf_session;
@@ -213,27 +214,29 @@ typedef enum {
 } ktls_mbuf_crypto_st_t;
 
 void ktls_check_rx(struct sockbuf *sb);
-ktls_mbuf_crypto_st_t ktls_mbuf_crypto_state(struct mbuf *mb, int offset, int 
len);
+void ktls_cleanup_tls_enable(struct tls_enable *tls);
+int ktls_copyin_tls_enable(struct sockopt *sopt, struct tls_enable *tls);
 void ktls_disable_ifnet(void *arg);
 int ktls_enable_rx(struct socket *so, struct tls_enable *en);
 int ktls_enable_tx(struct socket *so, struct tls_enable *en);
+void ktls_enqueue(struct mbuf *m, struct socket *so, int page_count);
+void ktls_enqueue_to_free(struct mbuf *m);
 void ktls_destroy(struct ktls_session *tls);
 void ktls_frame(struct mbuf *m, struct ktls_session *tls, int *enqueue_cnt,
     uint8_t record_type);
-bool ktls_permit_empty_frames(struct ktls_session *tls);
-void ktls_seq(struct sockbuf *sb, struct mbuf *m);
-void ktls_enqueue(struct mbuf *m, struct socket *so, int page_count);
-void ktls_enqueue_to_free(struct mbuf *m);
 int ktls_get_rx_mode(struct socket *so, int *modep);
-int ktls_set_tx_mode(struct socket *so, int mode);
 int ktls_get_tx_mode(struct socket *so, int *modep);
 int ktls_get_rx_sequence(struct inpcb *inp, uint32_t *tcpseq, uint64_t 
*tlsseq);
 void ktls_input_ifp_mismatch(struct sockbuf *sb, struct ifnet *ifp);
-int ktls_output_eagain(struct inpcb *inp, struct ktls_session *tls);
+ktls_mbuf_crypto_st_t ktls_mbuf_crypto_state(struct mbuf *mb, int offset, int 
len);
 #ifdef RATELIMIT
 int ktls_modify_txrtlmt(struct ktls_session *tls, uint64_t max_pacing_rate);
 #endif
+int ktls_output_eagain(struct inpcb *inp, struct ktls_session *tls);
 bool ktls_pending_rx_info(struct sockbuf *sb, uint64_t *seqnop, size_t 
*residp);
+bool ktls_permit_empty_frames(struct ktls_session *tls);
+void ktls_seq(struct sockbuf *sb, struct mbuf *m);
+int ktls_set_tx_mode(struct socket *so, int mode);
 
 static inline struct ktls_session *
 ktls_hold(struct ktls_session *tls)

Reply via email to