Preparing a decryption request requires several memory chunks
(aead_req, sgin, sgout, iv, aad). When the request is submitted to an
accelerator, every buffer the accelerator reads must be DMA-able, i.e.
it must not live on the stack. The aad and iv buffers could each be
kmalloc'ed separately, but that is inefficient. This patch instead makes
one combined allocation per decryption request and segments it into
aead_req || sgin || sgout || aad || iv. The per-socket rx_aad_ciphertext
and rx_aad_plaintext buffers in struct tls_sw_context_rx are therefore
no longer needed and are removed.

Signed-off-by: Vakul Garg <vakul.g...@nxp.com>
---
 include/net/tls.h |   4 -
 net/tls/tls_sw.c  | 257 +++++++++++++++++++++++++++++++-----------------------
 2 files changed, 148 insertions(+), 113 deletions(-)

diff --git a/include/net/tls.h b/include/net/tls.h
index d8b3b6578c01..d5c683e8bb22 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -124,10 +124,6 @@ struct tls_sw_context_rx {
        struct sk_buff *recv_pkt;
        u8 control;
        bool decrypted;
-
-       char rx_aad_ciphertext[TLS_AAD_SPACE_SIZE];
-       char rx_aad_plaintext[TLS_AAD_SPACE_SIZE];
-
 };
 
 struct tls_record_info {
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index cd5ed2d1dbe8..a478b06fc015 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -118,19 +118,13 @@ static int tls_do_decryption(struct sock *sk,
                             struct scatterlist *sgout,
                             char *iv_recv,
                             size_t data_len,
-                            struct sk_buff *skb,
-                            gfp_t flags)
+                            struct aead_request *aead_req)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
-       struct aead_request *aead_req;
-
        int ret;
 
-       aead_req = aead_request_alloc(ctx->aead_recv, flags);
-       if (!aead_req)
-               return -ENOMEM;
-
+       aead_request_set_tfm(aead_req, ctx->aead_recv);
        aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
        aead_request_set_crypt(aead_req, sgin, sgout,
                               data_len + tls_ctx->rx.tag_size,
@@ -139,8 +133,6 @@ static int tls_do_decryption(struct sock *sk,
                                  crypto_req_done, &ctx->async_wait);
 
        ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
-
-       aead_request_free(aead_req);
        return ret;
 }
 
@@ -727,8 +719,138 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int 
flags,
        return skb;
 }
 
+/* Decrypt the TLS record in @skb. With *@zc set, plaintext is written
+ * directly to @out_iov (pages mapped via zerocopy_from_iter) or @out_sg;
+ * otherwise, or if mapping @out_iov fails, it is decrypted in place in the
+ * skb buffers and *@zc / *@chunk are reset to false / 0. Returns the result
+ * of tls_do_decryption() or a negative errno if the request can't be built.
+ * NOTE(review): that in-place fallback reuses an sgin built without
+ * skb_cow_data() and zeroes pages without put_page() - please verify.
+ */
+static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
+                           struct iov_iter *out_iov,
+                           struct scatterlist *out_sg,
+                           int *chunk, bool *zc)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+       struct strp_msg *rxm = strp_msg(skb);
+       int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
+       struct aead_request *aead_req;
+       struct sk_buff *unused;
+       u8 *aad, *iv, *mem = NULL;
+       struct scatterlist *sgin = NULL;
+       struct scatterlist *sgout = NULL;
+       const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;
+
+       if (*zc && (out_iov || out_sg)) {
+               if (out_iov)
+                       n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
+               else if (out_sg)
+                       n_sgout = sg_nents(out_sg);
+               else
+                       goto no_zerocopy;
+
+               n_sgin = skb_nsg(skb, rxm->offset + tls_ctx->rx.prepend_size,
+                                rxm->full_len - tls_ctx->rx.prepend_size);
+       } else {
+no_zerocopy:
+               n_sgout = 0;
+               *zc = false;
+               n_sgin = skb_cow_data(skb, 0, &unused);
+       }
+
+       if (n_sgin < 1)
+               return -EBADMSG;
+
+       /* One extra sgin entry: sgin[0] holds the AAD */
+       n_sgin = n_sgin + 1;
+
+       nsg = n_sgin + n_sgout;
+
+       aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
+       mem_size = aead_size + (nsg * sizeof(struct scatterlist));
+       mem_size = mem_size + TLS_AAD_SPACE_SIZE;
+       mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
+
+       /* Single block: aead_req || sgin[] || sgout[] || aad || iv.
+        * NOTE(review): sgin/sgout alignment assumes aead_size is a multiple
+        * of __alignof__(struct scatterlist); nothing here enforces that.
+        */
+       mem = kmalloc(mem_size, sk->sk_allocation);
+       if (!mem)
+               return -ENOMEM;
+
+       /* Carve the block into its parts (order must match mem_size above) */
+       aead_req = (struct aead_request *)mem;
+       sgin = (struct scatterlist *)(mem + aead_size);
+       sgout = sgin + n_sgin;
+       aad = (u8 *)(sgout + n_sgout);
+       iv = aad + TLS_AAD_SPACE_SIZE;
+
+       /* IV = rx.iv salt || rx.iv_size bytes that follow the record header */
+       err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
+                           iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+                           tls_ctx->rx.iv_size);
+       if (err < 0) {
+               kfree(mem);
+               return err;
+       }
+       memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+
+       /* AAD: plaintext length, rx record sequence number and ctx->control */
+       tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size,
+                    tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
+                    ctx->control);
+
+       /* sgin: AAD entry, then the record data after rx.prepend_size */
+       sg_init_table(sgin, n_sgin);
+       sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
+       err = skb_to_sgvec(skb, &sgin[1],
+                          rxm->offset + tls_ctx->rx.prepend_size,
+                          rxm->full_len - tls_ctx->rx.prepend_size);
+       if (err < 0) {
+               kfree(mem);
+               return err;
+       }
+
+       if (n_sgout) {
+               if (out_iov) {
+                       sg_init_table(sgout, n_sgout);
+                       sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);
+
+                       *chunk = 0;
+                       err = zerocopy_from_iter(sk, out_iov, data_len, &pages,
+                                                chunk, &sgout[1],
+                                                (n_sgout - 1), false);
+                       if (err < 0)
+                               goto fallback_to_reg_recv;
+               } else if (out_sg) {
+                       memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
+               } else {
+                       goto fallback_to_reg_recv;
+               }
+       } else {
+fallback_to_reg_recv:
+               sgout = sgin;
+               pages = 0;
+               *chunk = 0;
+               *zc = false;
+       }
+
+       /* Prepare and submit the AEAD request; sgout == sgin means in place */
+       err = tls_do_decryption(sk, sgin, sgout, iv, data_len, aead_req);
+
+       /* Put the pages zerocopy_from_iter() mapped into sgout[1..pages] */
+       for (; pages > 0; pages--)
+               put_page(sg_page(&sgout[pages]));
+
+       kfree(mem);
+       return err;
+}
+
 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
-                             struct scatterlist *sgout, bool *zc)
+                             struct iov_iter *dest, int *chunk, bool *zc)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -741,7 +863,7 @@ static int decrypt_skb_update(struct sock *sk, struct 
sk_buff *skb,
                return err;
 #endif
        if (!ctx->decrypted) {
-               err = decrypt_skb(sk, skb, sgout);
+               err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
                if (err < 0)
                        return err;
        } else {
@@ -760,67 +882,10 @@ static int decrypt_skb_update(struct sock *sk, struct 
sk_buff *skb,
 int decrypt_skb(struct sock *sk, struct sk_buff *skb,
                struct scatterlist *sgout)
 {
-       struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
-       char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + MAX_IV_SIZE];
-       struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2];
-       struct scatterlist *sgin = &sgin_arr[0];
-       struct strp_msg *rxm = strp_msg(skb);
-       int ret, nsg;
-       struct sk_buff *unused;
-
-       ret = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
-                           iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
-                           tls_ctx->rx.iv_size);
-       if (ret < 0)
-               return ret;
-
-       memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
-       if (!sgout) {
-               nsg = skb_cow_data(skb, 0, &unused);
-       } else {
-               nsg = skb_nsg(skb,
-                             rxm->offset + tls_ctx->rx.prepend_size,
-                             rxm->full_len - tls_ctx->rx.prepend_size);
-               if (nsg <= 0)
-                       return nsg;
-       }
-
-       // We need one extra for ctx->rx_aad_ciphertext
-       nsg++;
-
-       if (nsg > ARRAY_SIZE(sgin_arr))
-               sgin = kmalloc_array(nsg, sizeof(*sgin), sk->sk_allocation);
-
-       if (!sgout)
-               sgout = sgin;
-
-       sg_init_table(sgin, nsg);
-       sg_set_buf(&sgin[0], ctx->rx_aad_ciphertext, TLS_AAD_SPACE_SIZE);
-
-       nsg = skb_to_sgvec(skb, &sgin[1],
-                          rxm->offset + tls_ctx->rx.prepend_size,
-                          rxm->full_len - tls_ctx->rx.prepend_size);
-       if (nsg < 0) {
-               ret = nsg;
-               goto out;
-       }
-
-       tls_make_aad(ctx->rx_aad_ciphertext,
-                    rxm->full_len - tls_ctx->rx.overhead_size,
-                    tls_ctx->rx.rec_seq,
-                    tls_ctx->rx.rec_seq_size,
-                    ctx->control);
-
-       ret = tls_do_decryption(sk, sgin, sgout, iv,
-                               rxm->full_len - tls_ctx->rx.overhead_size,
-                               skb, sk->sk_allocation);
-
-out:
-       if (sgin != &sgin_arr[0])
-               kfree(sgin);
+       bool zc = true;         /* decrypt into @sgout when it is non-NULL */
+       int chunk;              /* out-param of decrypt_internal(); unused */
 
-       return ret;
+       return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
 }
 
 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
@@ -899,43 +964,17 @@ int tls_sw_recvmsg(struct sock *sk,
                }
 
                if (!ctx->decrypted) {
-                       int page_count;
-                       int to_copy;
-
-                       page_count = iov_iter_npages(&msg->msg_iter,
-                                                    MAX_SKB_FRAGS);
-                       to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
-                       if (!is_kvec && to_copy <= len && page_count < 
MAX_SKB_FRAGS &&
-                           likely(!(flags & MSG_PEEK)))  {
-                               struct scatterlist sgin[MAX_SKB_FRAGS + 1];
-                               int pages = 0;
+                       int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
 
+                       if (!is_kvec && to_copy <= len &&
+                           likely(!(flags & MSG_PEEK)))
                                zc = true;
-                               sg_init_table(sgin, MAX_SKB_FRAGS + 1);
-                               sg_set_buf(&sgin[0], ctx->rx_aad_plaintext,
-                                          TLS_AAD_SPACE_SIZE);
-
-                               err = zerocopy_from_iter(sk, &msg->msg_iter,
-                                                        to_copy, &pages,
-                                                        &chunk, &sgin[1],
-                                                        MAX_SKB_FRAGS, false);
-                               if (err < 0)
-                                       goto fallback_to_reg_recv;
-
-                               err = decrypt_skb_update(sk, skb, sgin, &zc);
-                               for (; pages > 0; pages--)
-                                       put_page(sg_page(&sgin[pages]));
-                               if (err < 0) {
-                                       tls_err_abort(sk, EBADMSG);
-                                       goto recv_end;
-                               }
-                       } else {
-fallback_to_reg_recv:
-                               err = decrypt_skb_update(sk, skb, NULL, &zc);
-                               if (err < 0) {
-                                       tls_err_abort(sk, EBADMSG);
-                                       goto recv_end;
-                               }
+
+                       err = decrypt_skb_update(sk, skb, &msg->msg_iter,
+                                                &chunk, &zc);
+                       if (err < 0) {
+                               tls_err_abort(sk, EBADMSG);
+                               goto recv_end;
                        }
                        ctx->decrypted = true;
                }
@@ -986,7 +1025,7 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t 
*ppos,
        int err = 0;
        long timeo;
        int chunk;
-       bool zc;
+       bool zc = false;
 
        lock_sock(sk);
 
@@ -1003,7 +1042,7 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t 
*ppos,
        }
 
        if (!ctx->decrypted) {
-               err = decrypt_skb_update(sk, skb, NULL, &zc);
+               err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
 
                if (err < 0) {
                        tls_err_abort(sk, EBADMSG);
-- 
2.13.6

Reply via email to