From: Ilya Lesokhin <il...@mellanox.com>

Signed-off-by: Dave Watson <davejwat...@fb.com>
Signed-off-by: Ilya Lesokhin <il...@mellanox.com>
Signed-off-by: Aviad Yehezkel <avia...@mellanox.com>
---
 MAINTAINERS        |   1 +
 include/net/tls.h  |  44 ++++
 net/tls/Makefile   |   2 +-
 net/tls/tls_main.c |  34 +--
 net/tls/tls_sw.c   | 729 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 794 insertions(+), 16 deletions(-)
 create mode 100644 net/tls/tls_sw.c
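
Notes, not intended for the changelog:

The framing used throughout tls_sw.c is TLS 1.2 AES-128-GCM: tls_pre_encrypt()
budgets TLS_OVERHEAD bytes around every payload (5-byte record header + 8-byte
explicit nonce + 16-byte tag = 29 bytes), and tls_make_aad() packs the 13 bytes
of additional authenticated data as nonce | type | version | length. The
standalone userspace sketch below only restates those constants to make the
layout easy to check; it is illustrative, and it assumes TLS_HEADER_SIZE is 5
as defined earlier in this series.

/*
 * Standalone sketch of the record framing used by tls_sw.c. Constants are
 * restated here; in the kernel they come from <net/tls.h>.
 */
#include <stdio.h>
#include <stdint.h>
#include <string.h>

#define TLS_HEADER_SIZE   5   /* type + version + length */
#define TLS_NONCE_SIZE    8   /* explicit nonce / record sequence */
#define TLS_TAG_SIZE      16  /* AES-GCM authentication tag */
#define TLS_PREPEND_SIZE  (TLS_HEADER_SIZE + TLS_NONCE_SIZE)
#define TLS_OVERHEAD      (TLS_PREPEND_SIZE + TLS_TAG_SIZE)
#define TLS_AAD_SIZE      13

/* Same byte layout as tls_make_aad(): nonce | type | version | length */
static void make_aad(uint8_t *buf, const uint8_t *nonce_explicit,
                     uint8_t record_type, uint16_t payload_len)
{
        memcpy(buf, nonce_explicit, TLS_NONCE_SIZE);
        buf[8]  = record_type;  /* 0x17 for application data */
        buf[9]  = 3;            /* TLS 1.2 version major */
        buf[10] = 3;            /* TLS 1.2 version minor */
        buf[11] = payload_len >> 8;
        buf[12] = payload_len & 0xff;
}

int main(void)
{
        uint8_t seq[TLS_NONCE_SIZE] = { 0 };    /* record sequence number */
        uint8_t aad[TLS_AAD_SIZE];
        uint16_t payload_len = 1024;

        make_aad(aad, seq, 0x17, payload_len);

        printf("AAD: type=0x%02x len=%d, per-record overhead=%d bytes\n",
               aad[8], (aad[11] << 8) | aad[12], TLS_OVERHEAD);
        printf("wire size of a %d byte payload: %d bytes\n",
               payload_len, payload_len + TLS_OVERHEAD);
        return 0;
}

Built with any C compiler, it prints the 29-byte per-record overhead that the
transmit page sizing in tls_pre_encrypt() accounts for.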

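Also illustrative only: tls_sw_sendmsg() reuses MSG_OOB to let a caller pick a
record type other than application data; the first byte of the buffer is
consumed as the record type and the remainder is sent as that record's
payload. Below is a minimal userspace sketch of sending a close_notify alert
on a socket that already has the SW context attached; the setsockopt() setup
from the earlier patches in this series is assumed and omitted, and the alert
constants are restated locally.

/*
 * Sketch only: send a two-byte TLS alert (level, description) as its own
 * record. tls_sw_sendmsg() consumes buf[0] as the record type when MSG_OOB
 * is set; it also requires that no partial record is pending and that
 * MSG_MORE is not used for this call.
 */
#include <sys/types.h>
#include <sys/socket.h>

#define TLS_RECORD_TYPE_ALERT   0x15    /* TLS alert content type */

static ssize_t send_tls_alert(int fd, unsigned char level,
                              unsigned char description)
{
        unsigned char buf[3];

        buf[0] = TLS_RECORD_TYPE_ALERT; /* consumed by tls_sw_sendmsg() */
        buf[1] = level;                 /* 1 = warning, 2 = fatal */
        buf[2] = description;           /* 0 = close_notify */

        return send(fd, buf, sizeof(buf), MSG_OOB);
}

For example, send_tls_alert(fd, 1, 0) emits a warning/close_notify record
before the socket is closed.
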
diff --git a/MAINTAINERS b/MAINTAINERS
index e3b70c3..413c1d9 100644
--- a/MAINTAINERS
+++ b/MAINTAINERS
@@ -8491,6 +8491,7 @@ M:        Ilya Lesokhin <il...@mellanox.com>
 M:     Aviad Yehezkel <avia...@mellanox.com>
 M:     Boris Pismenny <bor...@mellanox.com>
 M:     Haggai Eran <hagg...@mellanox.com>
+M:     Dave Watson <davejwat...@fb.com>
 L:     net...@vger.kernel.org
 T:     git git://git.kernel.org/pub/scm/linux/kernel/git/davem/net.git
 T:     git git://git.kernel.org/pub/scm/linux/kernel/git/davem/net-next.git
diff --git a/include/net/tls.h b/include/net/tls.h
index f7f0cde..bb1f41e 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -48,6 +48,7 @@
 
 #define TLS_CRYPTO_INFO_READY(info)    ((info)->cipher_type)
 #define TLS_IS_STATE_HW(info)          ((info)->state == TLS_STATE_HW)
+#define TLS_IS_STATE_SW(info)          ((info)->state == TLS_STATE_SW)
 
 #define TLS_RECORD_TYPE_DATA           0x17
 
@@ -68,6 +69,37 @@ struct tls_offload_context {
        spinlock_t lock;        /* protects records list */
 };
 
+#define TLS_DATA_PAGES                 (TLS_MAX_PAYLOAD_SIZE / PAGE_SIZE)
+/* +1 for aad, +1 for tag, +1 for chaining */
+#define TLS_SG_DATA_SIZE               (TLS_DATA_PAGES + 3)
+#define ALG_MAX_PAGES 16 /* for skb_to_sgvec */
+#define TLS_AAD_SPACE_SIZE             21
+#define TLS_AAD_SIZE                   13
+#define TLS_TAG_SIZE                   16
+
+#define TLS_NONCE_SIZE                 8
+#define TLS_PREPEND_SIZE               (TLS_HEADER_SIZE + TLS_NONCE_SIZE)
+#define TLS_OVERHEAD           (TLS_PREPEND_SIZE + TLS_TAG_SIZE)
+
+struct tls_sw_context {
+       struct sock *sk;
+       void (*sk_write_space)(struct sock *sk);
+       struct crypto_aead *aead_send;
+
+       /* Sending context */
+       struct scatterlist sg_tx_data[TLS_SG_DATA_SIZE];
+       struct scatterlist sg_tx_data2[ALG_MAX_PAGES + 1];
+       char aad_send[TLS_AAD_SPACE_SIZE];
+       char tag_send[TLS_TAG_SIZE];
+       skb_frag_t tx_frag;
+       int wmem_len;
+       int order_npages;
+       struct scatterlist sgaad_send[2];
+       struct scatterlist sgtag_send[2];
+       struct sk_buff_head tx_queue;
+       int unsent;
+};
+
 struct tls_context {
        union {
                struct tls_crypto_info crypto_send;
@@ -102,6 +134,12 @@ int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
 int tls_device_sendpage(struct sock *sk, struct page *page,
                        int offset, size_t size, int flags);
 
+int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx);
+void tls_clear_sw_offload(struct sock *sk);
+int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
+int tls_sw_sendpage(struct sock *sk, struct page *page,
+                   int offset, size_t size, int flags);
+
 struct tls_record_info *tls_get_record(struct tls_offload_context *context,
                                       u32 seq);
 
@@ -174,6 +212,12 @@ static inline struct tls_context *tls_get_ctx(const struct sock *sk)
        return sk->sk_user_data;
 }
 
+static inline struct tls_sw_context *tls_sw_ctx(
+               const struct tls_context *tls_ctx)
+{
+       return (struct tls_sw_context *)tls_ctx->priv_ctx;
+}
+
 static inline struct tls_offload_context *tls_offload_ctx(
                const struct tls_context *tls_ctx)
 {
diff --git a/net/tls/Makefile b/net/tls/Makefile
index 65e5677..61457e0 100644
--- a/net/tls/Makefile
+++ b/net/tls/Makefile
@@ -4,4 +4,4 @@
 
 obj-$(CONFIG_TLS) += tls.o
 
-tls-y := tls_main.o tls_device.o
+tls-y := tls_main.o tls_device.o tls_sw.o
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 6a3df25..a4efd02 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -46,6 +46,7 @@ MODULE_DESCRIPTION("Transport Layer Security Support");
 MODULE_LICENSE("Dual BSD/GPL");
 
 static struct proto tls_device_prot;
+static struct proto tls_sw_prot;
 
 int tls_push_frags(struct sock *sk,
                   struct tls_context *ctx,
@@ -188,13 +189,10 @@ int tls_sk_query(struct sock *sk, int optname, char __user *optval,
                        rc = -EINVAL;
                        goto out;
                }
-               if (TLS_IS_STATE_HW(crypto_info)) {
-                       lock_sock(sk);
-                       memcpy(crypto_info_aes_gcm_128->iv,
-                              ctx->iv,
-                              TLS_CIPHER_AES_GCM_128_IV_SIZE);
-                       release_sock(sk);
-               }
+               lock_sock(sk);
+               memcpy(crypto_info_aes_gcm_128->iv, ctx->iv,
+                      TLS_CIPHER_AES_GCM_128_IV_SIZE);
+               release_sock(sk);
                rc = copy_to_user(optval,
                                  crypto_info_aes_gcm_128,
                                  sizeof(*crypto_info_aes_gcm_128));
@@ -224,6 +222,7 @@ int tls_sk_attach(struct sock *sk, int optname, char __user *optval,
        struct tls_context *ctx = tls_get_ctx(sk);
        struct tls_crypto_info *crypto_info;
        bool allocated_tls_ctx = false;
+       struct proto *prot = NULL;
 
        if (!optval || (optlen < sizeof(*crypto_info))) {
                rc = -EINVAL;
@@ -267,12 +266,6 @@ int tls_sk_attach(struct sock *sk, int optname, char __user *optval,
                goto err_sk_user_data;
        }
 
-       /* currently we support only HW offload */
-       if (!TLS_IS_STATE_HW(crypto_info)) {
-               rc = -ENOPROTOOPT;
-               goto err_crypto_info;
-       }
-
        /* check version */
        if (crypto_info->version != TLS_1_2_VERSION) {
                rc = -ENOTSUPP;
@@ -306,6 +299,12 @@ int tls_sk_attach(struct sock *sk, int optname, char __user *optval,
 
        if (TLS_IS_STATE_HW(crypto_info)) {
                rc = tls_set_device_offload(sk, ctx);
+               prot = &tls_device_prot;
+               if (rc)
+                       goto err_crypto_info;
+       } else if (TLS_IS_STATE_SW(crypto_info)) {
+               rc = tls_set_sw_offload(sk, ctx);
+               prot = &tls_sw_prot;
                if (rc)
                        goto err_crypto_info;
        }
@@ -315,8 +314,9 @@ int tls_sk_attach(struct sock *sk, int optname, char __user *optval,
                goto err_set_device_offload;
        }
 
-       /* TODO: add protection */
-       sk->sk_prot = &tls_device_prot;
+       rc = 0;
+
+       sk->sk_prot = prot;
        goto out;
 
 err_set_device_offload:
@@ -337,6 +337,10 @@ static int __init tls_init(void)
        tls_device_prot.sendmsg         = tls_device_sendmsg;
        tls_device_prot.sendpage        = tls_device_sendpage;
 
+       tls_sw_prot                     = tcp_prot;
+       tls_sw_prot.sendmsg             = tls_sw_sendmsg;
+       tls_sw_prot.sendpage            = tls_sw_sendpage;
+
        return 0;
 }
 
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
new file mode 100644
index 0000000..4698dc7
--- /dev/null
+++ b/net/tls/tls_sw.c
@@ -0,0 +1,729 @@
+/*
+ * af_tls: TLS socket
+ *
+ * Copyright (C) 2016
+ *
+ * Original authors:
+ *   Fridolin Pokorny <fridolin.poko...@gmail.com>
+ *   Nikos Mavrogiannopoulos <n...@gnults.org>
+ *   Dave Watson <davejwat...@fb.com>
+ *   Lance Chao <lancerc...@fb.com>
+ *
+ * Based on RFC 5288, RFC 6347, RFC 5246, RFC 6655
+ *
+ * This program is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU General Public License as
+ * published by the Free Software Foundation; either version 2 of the
+ * License, or (at your option) any later version.
+ */
+
+#include <linux/module.h>
+#include <net/tcp.h>
+#include <net/inet_common.h>
+#include <linux/highmem.h>
+#include <linux/netdevice.h>
+#include <crypto/aead.h>
+
+#include <net/tls.h>
+
+static int tls_kernel_sendpage(struct sock *sk, int flags);
+
+static inline void tls_make_aad(struct sock *sk,
+                               int recv,
+                               char *buf,
+                               size_t size,
+                               char *nonce_explicit,
+                               unsigned char record_type)
+{
+       memcpy(buf, nonce_explicit, TLS_NONCE_SIZE);
+
+       buf[8] = record_type;
+       buf[9] = TLS_1_2_VERSION_MAJOR;
+       buf[10] = TLS_1_2_VERSION_MINOR;
+       buf[11] = size >> 8;
+       buf[12] = size & 0xFF;
+}
+
+static int tls_do_encryption(struct sock *sk, struct scatterlist *sgin,
+                            struct scatterlist *sgout, size_t data_len,
+                            struct sk_buff *skb)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       int ret;
+       unsigned int req_size = sizeof(struct aead_request) +
+               crypto_aead_reqsize(ctx->aead_send);
+       struct aead_request *aead_req;
+
+       pr_debug("tls_do_encryption %p\n", sk);
+
+       aead_req = kmalloc(req_size, GFP_ATOMIC);
+
+       if (!aead_req)
+               return -ENOMEM;
+
+       aead_request_set_tfm(aead_req, ctx->aead_send);
+       aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
+       aead_request_set_crypt(aead_req, sgin, sgout, data_len, tls_ctx->iv);
+
+       ret = crypto_aead_encrypt(aead_req);
+
+       kfree(aead_req);
+       if (ret < 0)
+               return ret;
+       tls_kernel_sendpage(sk, MSG_DONTWAIT);
+
+       return ret;
+}
+
+/* Allocates enough pages to hold the decrypted data, as well as
+ * setting ctx->sg_tx_data to the pages
+ */
+static int tls_pre_encrypt(struct sock *sk, size_t data_len)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       int i;
+       unsigned int npages;
+       size_t aligned_size;
+       size_t encrypt_len;
+       struct scatterlist *sg;
+       int ret = 0;
+       struct page *tx_pages;
+
+       encrypt_len = data_len + TLS_OVERHEAD;
+       npages = encrypt_len / PAGE_SIZE;
+       aligned_size = npages * PAGE_SIZE;
+       if (aligned_size < encrypt_len)
+               npages++;
+
+       ctx->order_npages = order_base_2(npages);
+       WARN_ON(ctx->order_npages < 0 || ctx->order_npages > 3);
+       /* The first entry in sg_tx_data is AAD so skip it */
+       sg_init_table(ctx->sg_tx_data, TLS_SG_DATA_SIZE);
+       sg_set_buf(&ctx->sg_tx_data[0], ctx->aad_send, sizeof(ctx->aad_send));
+       tx_pages = alloc_pages(GFP_KERNEL | __GFP_COMP,
+                              ctx->order_npages);
+       if (!tx_pages) {
+               ret = -ENOMEM;
+               return ret;
+       }
+
+       sg = ctx->sg_tx_data + 1;
+       /* For the first page, leave room for prepend. It will be
+        * copied into the page later
+        */
+       sg_set_page(sg, tx_pages, PAGE_SIZE - TLS_PREPEND_SIZE,
+                   TLS_PREPEND_SIZE);
+       for (i = 1; i < npages; i++)
+               sg_set_page(sg + i, tx_pages + i, PAGE_SIZE, 0);
+
+       __skb_frag_set_page(&ctx->tx_frag, tx_pages);
+
+       return ret;
+}
+
+static void tls_release_tx_frag(struct sock *sk)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct page *tx_page = skb_frag_page(&ctx->tx_frag);
+
+       if (!tls_is_pending_open_record(tls_ctx) && tx_page) {
+               struct sk_buff *head;
+               /* Successfully sent the whole packet, account for it*/
+
+               head = skb_peek(&ctx->tx_queue);
+               skb_dequeue(&ctx->tx_queue);
+               sk->sk_wmem_queued -= ctx->wmem_len;
+               sk_mem_uncharge(sk, ctx->wmem_len);
+               ctx->wmem_len = 0;
+               kfree_skb(head);
+               ctx->unsent -= skb_frag_size(&ctx->tx_frag) - TLS_OVERHEAD;
+               tls_increment_seqno(tls_ctx->iv, sk);
+               __free_pages(tx_page,
+                            ctx->order_npages);
+               __skb_frag_set_page(&ctx->tx_frag, NULL);
+       }
+       ctx->sk_write_space(sk);
+}
+
+static int tls_kernel_sendpage(struct sock *sk, int flags)
+{
+       int ret;
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+
+       skb_frag_size_add(&ctx->tx_frag, TLS_OVERHEAD);
+       ret = tls_push_frags(sk, tls_ctx, &ctx->tx_frag, 1, 0, flags);
+       if (ret >= 0)
+               tls_release_tx_frag(sk);
+       else if (ret != -EAGAIN)
+               tls_err_abort(sk);
+
+       return ret;
+}
+
+static int tls_push_zerocopy(struct sock *sk, struct scatterlist *sgin,
+                            int pages, int bytes, unsigned char record_type)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       int ret;
+
+       tls_make_aad(sk, 0, ctx->aad_send, bytes, tls_ctx->iv, record_type);
+
+       sg_chain(ctx->sgaad_send, 2, sgin);
+       //sg_unmark_end(&sgin[pages - 1]);
+       sg_chain(sgin, pages + 1, ctx->sgtag_send);
+       ret = sg_nents_for_len(ctx->sgaad_send, bytes + 13 + 16);
+
+       ret = tls_pre_encrypt(sk, bytes);
+       if (ret < 0)
+               goto out;
+
+       tls_fill_prepend(tls_ctx,
+                        page_address(skb_frag_page(&ctx->tx_frag)),
+                        bytes, record_type);
+
+       skb_frag_size_set(&ctx->tx_frag, bytes);
+
+       ret = tls_do_encryption(sk,
+                               ctx->sgaad_send,
+                               ctx->sg_tx_data,
+                               bytes, NULL);
+
+       if (ret < 0)
+               goto out;
+
+out:
+       if (ret < 0) {
+               sk->sk_err = EPIPE;
+               return ret;
+       }
+
+       return 0;
+}
+
+static int tls_push(struct sock *sk, unsigned char record_type)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       int bytes = min_t(int, ctx->unsent, (int)TLS_MAX_PAYLOAD_SIZE);
+       int nsg, ret = 0;
+       struct sk_buff *head = skb_peek(&ctx->tx_queue);
+
+       if (!head)
+               return 0;
+
+       bytes = min_t(int, bytes, head->len);
+
+       sg_init_table(ctx->sg_tx_data2, ARRAY_SIZE(ctx->sg_tx_data2));
+       nsg = skb_to_sgvec(head, &ctx->sg_tx_data2[0], 0, bytes);
+
+       /* The length of sg into decryption must not be over
+        * ALG_MAX_PAGES. The aad takes the first sg, so the payload
+        * must be less than ALG_MAX_PAGES - 1
+        */
+       if (nsg > ALG_MAX_PAGES - 1) {
+               ret = -EBADMSG;
+               goto out;
+       }
+
+       tls_make_aad(sk, 0, ctx->aad_send, bytes, tls_ctx->iv, record_type);
+
+       sg_chain(ctx->sgaad_send, 2, ctx->sg_tx_data2);
+       sg_chain(ctx->sg_tx_data2,
+                nsg + 1,
+                ctx->sgtag_send);
+
+       ret = tls_pre_encrypt(sk, bytes);
+       if (ret < 0)
+               goto out;
+
+       tls_fill_prepend(tls_ctx,
+                        page_address(skb_frag_page(&ctx->tx_frag)),
+                        bytes, record_type);
+
+       skb_frag_size_set(&ctx->tx_frag, bytes);
+       tls_ctx->pending_offset = 0;
+       head->sk = sk;
+
+       ret = tls_do_encryption(sk,
+                               ctx->sgaad_send,
+                               ctx->sg_tx_data,
+                               bytes, head);
+
+       if (ret < 0)
+               goto out;
+
+out:
+       if (ret < 0) {
+               sk->sk_err = EPIPE;
+               return ret;
+       }
+
+       return 0;
+}
+
+static int zerocopy_from_iter(struct iov_iter *from,
+                             struct scatterlist *sg, int *bytes)
+{
+       //int len = iov_iter_count(from);
+       int n = 0;
+
+       if (bytes)
+               *bytes = 0;
+
+       //TODO pass in number of pages
+       while (iov_iter_count(from) && n < MAX_SKB_FRAGS - 1) {
+               struct page *pages[MAX_SKB_FRAGS];
+               size_t start;
+               ssize_t copied;
+               int j = 0;
+
+               if (bytes && *bytes >= TLS_MAX_PAYLOAD_SIZE)
+                       break;
+
+               copied = iov_iter_get_pages(from, pages, TLS_MAX_PAYLOAD_SIZE,
+                                           MAX_SKB_FRAGS - n, &start);
+               if (copied < 0)
+                       return -EFAULT;
+               if (bytes)
+                       *bytes += copied;
+
+               iov_iter_advance(from, copied);
+
+               while (copied) {
+                       int size = min_t(int, copied, PAGE_SIZE - start);
+
+                       sg_set_page(&sg[n], pages[j], size, start);
+                       start = 0;
+                       copied -= size;
+                       j++;
+                       n++;
+               }
+       }
+       return n;
+}
+
+int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       int ret = 0;
+       long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
+       bool eor = !(msg->msg_flags & MSG_MORE);
+       struct sk_buff *skb = NULL;
+       size_t copy, copied = 0;
+       unsigned char record_type = TLS_RECORD_TYPE_DATA;
+
+       lock_sock(sk);
+
+       if (msg->msg_flags & MSG_OOB) {
+               if (!eor || ctx->unsent) {
+                       ret = -EINVAL;
+                       goto send_end;
+               }
+
+               ret = copy_from_iter(&record_type, 1, &msg->msg_iter);
+               if (ret != 1) {
+                       ret = -EFAULT;
+                       goto send_end;
+               }
+       }
+
+       while (msg_data_left(msg)) {
+               bool merge = true;
+               int i;
+               struct page_frag *pfrag;
+
+               if (sk->sk_err)
+                       goto send_end;
+               if (!sk_stream_memory_free(sk))
+                       goto wait_for_memory;
+
+               skb = skb_peek_tail(&ctx->tx_queue);
+               // Try for zerocopy
+               if (!skb && !skb_frag_page(&ctx->tx_frag) && eor) {
+                       int pages;
+                       int err;
+                       // TODO can send partial pages?
+                       int page_count = iov_iter_npages(&msg->msg_iter,
+                                                        ALG_MAX_PAGES);
+                       struct scatterlist sgin[ALG_MAX_PAGES + 1];
+                       int bytes;
+
+                       sg_init_table(sgin, ALG_MAX_PAGES + 1);
+
+                       if (page_count >= ALG_MAX_PAGES)
+                               goto reg_send;
+
+                       // TODO check pages?
+                       err = zerocopy_from_iter(&msg->msg_iter, &sgin[0],
+                                                &bytes);
+                       pages = err;
+                       if (err < 0)
+                               goto send_end;
+                       ctx->unsent += bytes;
+
+                       /* Try to send the record */
+                       err = tls_push_zerocopy(sk, sgin, pages, bytes, record_type);
+                       for (; pages > 0; pages--)
+                               put_page(sg_page(&sgin[pages - 1]));
+                       if (err < 0) {
+                               tls_err_abort(sk);
+                               goto send_end;
+                       }
+                       continue;
+               }
+
+reg_send:
+               while (!skb) {
+                       skb = alloc_skb(0, sk->sk_allocation);
+                       if (skb)
+                               __skb_queue_tail(&ctx->tx_queue, skb);
+               }
+
+               i = skb_shinfo(skb)->nr_frags;
+               pfrag = sk_page_frag(sk);
+
+               if (!sk_page_frag_refill(sk, pfrag))
+                       goto wait_for_memory;
+
+               if (!skb_can_coalesce(skb, i, pfrag->page,
+                                     pfrag->offset)) {
+                       if (i == ALG_MAX_PAGES) {
+                               struct sk_buff *tskb;
+
+                               tskb = alloc_skb(0, sk->sk_allocation);
+                               if (!tskb)
+                                       goto wait_for_memory;
+
+                               if (skb)
+                                       skb->next = tskb;
+                               else
+                                       __skb_queue_tail(&ctx->tx_queue,
+                                                        tskb);
+
+                               skb = tskb;
+                               skb->ip_summed = CHECKSUM_UNNECESSARY;
+                               continue;
+                       }
+                       merge = false;
+               }
+
+               copy = min_t(int, msg_data_left(msg),
+                            pfrag->size - pfrag->offset);
+               copy = min_t(int, copy, TLS_MAX_PAYLOAD_SIZE - ctx->unsent);
+
+               if (!sk_wmem_schedule(sk, copy))
+                       goto wait_for_memory;
+
+               ret = skb_copy_to_page_nocache(sk, &msg->msg_iter, skb,
+                                              pfrag->page,
+                                              pfrag->offset,
+                                              copy);
+               ctx->wmem_len += copy;
+               if (ret)
+                       goto send_end;
+
+               /* Update the skb. */
+               if (merge) {
+                       skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], copy);
+               } else {
+                       skb_fill_page_desc(skb, i, pfrag->page,
+                                          pfrag->offset, copy);
+                       get_page(pfrag->page);
+               }
+
+               pfrag->offset += copy;
+               copied += copy;
+               ctx->unsent += copy;
+
+               if (ctx->unsent >= TLS_MAX_PAYLOAD_SIZE) {
+                       ret = tls_push(sk, record_type);
+                       if (ret)
+                               goto send_end;
+               }
+
+               continue;
+
+wait_for_memory:
+               ret = tls_push(sk, record_type);
+               if (ret)
+                       goto send_end;
+//push_wait:
+               set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
+               ret = sk_stream_wait_memory(sk, &timeo);
+               if (ret)
+                       goto send_end;
+       }
+
+       if (eor)
+               ret = tls_push(sk, record_type);
+
+send_end:
+       ret = sk_stream_error(sk, msg->msg_flags, ret);
+
+       /* make sure we wake any epoll edge trigger waiter */
+       if (unlikely(skb_queue_len(&ctx->tx_queue) == 0 && ret == -EAGAIN))
+               sk->sk_write_space(sk);
+
+       release_sock(sk);
+       return ret < 0 ? ret : size;
+}
+
+void tls_sw_sk_destruct(struct sock *sk)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct page *tx_page = skb_frag_page(&ctx->tx_frag);
+
+       crypto_free_aead(ctx->aead_send);
+
+       if (tx_page)
+               __free_pages(tx_page, ctx->order_npages);
+
+       skb_queue_purge(&ctx->tx_queue);
+       tls_sk_destruct(sk, tls_ctx);
+}
+
+int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
+{
+       char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE +
+               TLS_CIPHER_AES_GCM_128_SALT_SIZE];
+       struct tls_crypto_info *crypto_info;
+       struct tls_crypto_info_aes_gcm_128 *gcm_128_info;
+       struct tls_sw_context *sw_ctx;
+       u16 nonce_size, tag_size, iv_size;
+       char *iv;
+       int rc = 0;
+
+       if (!ctx) {
+               rc = -EINVAL;
+               goto out;
+       }
+
+       if (ctx->priv_ctx) {
+               rc = -EEXIST;
+               goto out;
+       }
+
+       sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL);
+       if (!sw_ctx) {
+               rc = -ENOMEM;
+               goto out;
+       }
+
+       ctx->priv_ctx = (struct tls_offload_context *)sw_ctx;
+
+       crypto_info = &ctx->crypto_send;
+       switch (crypto_info->cipher_type) {
+       case TLS_CIPHER_AES_GCM_128: {
+               nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
+               tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+               iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
+               iv = ((struct tls_crypto_info_aes_gcm_128 *)crypto_info)->iv;
+               gcm_128_info =
+                       (struct tls_crypto_info_aes_gcm_128 *)crypto_info;
+               break;
+       }
+       default:
+               rc = -EINVAL;
+               goto out;
+       }
+
+       ctx->prepand_size = TLS_HEADER_SIZE + nonce_size;
+       ctx->tag_size = tag_size;
+       ctx->iv_size = iv_size;
+       ctx->iv = kmalloc(iv_size, GFP_KERNEL);
+       if (!ctx->iv) {
+               rc = -ENOMEM;
+               goto out;
+       }
+       memcpy(ctx->iv, iv, iv_size);
+
+       /* Preallocation for sending
+        *   scatterlist: AAD | data | TAG (for crypto API)
+        *   vec: HEADER | data | TAG
+        */
+       sg_init_table(sw_ctx->sg_tx_data, TLS_SG_DATA_SIZE);
+       sg_set_buf(&sw_ctx->sg_tx_data[0], sw_ctx->aad_send,
+                  sizeof(sw_ctx->aad_send));
+
+       sg_set_buf(sw_ctx->sg_tx_data + TLS_SG_DATA_SIZE - 2,
+                  sw_ctx->tag_send, sizeof(sw_ctx->tag_send));
+       sg_mark_end(sw_ctx->sg_tx_data + TLS_SG_DATA_SIZE - 1);
+
+       sg_init_table(sw_ctx->sgaad_send, 2);
+       sg_init_table(sw_ctx->sgtag_send, 2);
+
+       sg_set_buf(&sw_ctx->sgaad_send[0], sw_ctx->aad_send,
+                  sizeof(sw_ctx->aad_send));
+       /* chaining to tag is performed on actual data size when sending */
+       sg_set_buf(&sw_ctx->sgtag_send[0], sw_ctx->tag_send,
+                  sizeof(sw_ctx->tag_send));
+
+       sg_unmark_end(&sw_ctx->sgaad_send[1]);
+
+       if (!sw_ctx->aead_send) {
+               sw_ctx->aead_send =
+                               crypto_alloc_aead("rfc5288(gcm(aes))",
+                                                 CRYPTO_ALG_INTERNAL, 0);
+               if (IS_ERR(sw_ctx->aead_send)) {
+                       rc = PTR_ERR(sw_ctx->aead_send);
+                       sw_ctx->aead_send = NULL;
+                       pr_err("crypto_alloc_aead failed\n");
+                       goto out;
+               }
+       }
+
+       sk->sk_destruct = tls_sw_sk_destruct;
+       sw_ctx->sk_write_space = ctx->sk_write_space;
+       ctx->sk_write_space =  tls_release_tx_frag;
+
+       skb_queue_head_init(&sw_ctx->tx_queue);
+       sw_ctx->sk = sk;
+
+       memcpy(keyval, gcm_128_info->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
+       memcpy(keyval + TLS_CIPHER_AES_GCM_128_KEY_SIZE, gcm_128_info->salt,
+              TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+
+       rc = crypto_aead_setkey(sw_ctx->aead_send, keyval,
+                               TLS_CIPHER_AES_GCM_128_KEY_SIZE +
+                               TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+       if (rc)
+               goto out;
+
+       rc = crypto_aead_setauthsize(sw_ctx->aead_send, TLS_TAG_SIZE);
+       if (rc)
+               goto out;
+
+out:
+       return rc;
+}
+
+int tls_sw_sendpage(struct sock *sk, struct page *page,
+                   int offset, size_t size, int flags)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       int ret = 0, i;
+       long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
+       bool eor;
+       struct sk_buff *skb = NULL;
+       size_t queued = 0;
+       unsigned char record_type = TLS_RECORD_TYPE_DATA;
+
+       if (flags & MSG_SENDPAGE_NOTLAST)
+               flags |= MSG_MORE;
+
+       /* No MSG_EOR from splice, only look at MSG_MORE */
+       eor = !(flags & MSG_MORE);
+
+       lock_sock(sk);
+
+       if (flags & MSG_OOB) {
+               ret = -ENOTSUPP;
+               goto sendpage_end;
+       }
+       sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
+
+       /* Call the sk_stream functions to manage the sndbuf mem. */
+       while (size > 0) {
+               size_t send_size = min(size, TLS_MAX_PAYLOAD_SIZE);
+
+               if (!sk_stream_memory_free(sk) ||
+                   (ctx->unsent + send_size > TLS_MAX_PAYLOAD_SIZE)) {
+                       ret = tls_push(sk, record_type);
+                       if (ret)
+                               goto sendpage_end;
+                       set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
+                       ret = sk_stream_wait_memory(sk, &timeo);
+                       if (ret)
+                               goto sendpage_end;
+               }
+
+               if (sk->sk_err)
+                       goto sendpage_end;
+
+               skb = skb_peek_tail(&ctx->tx_queue);
+               if (skb) {
+                       i = skb_shinfo(skb)->nr_frags;
+
+                       if (skb_can_coalesce(skb, i, page, offset)) {
+                               skb_frag_size_add(
+                                       &skb_shinfo(skb)->frags[i - 1],
+                                       send_size);
+                               skb_shinfo(skb)->tx_flags |= SKBTX_SHARED_FRAG;
+                               goto coalesced;
+                       }
+
+                       if (i >= ALG_MAX_PAGES) {
+                               struct sk_buff *tskb;
+
+                               tskb = alloc_skb(0, sk->sk_allocation);
+                               while (!tskb) {
+                                       ret = tls_push(sk, record_type);
+                                       if (ret)
+                                               goto sendpage_end;
+                                       set_bit(SOCK_NOSPACE,
+                                               &sk->sk_socket->flags);
+                                       ret = sk_stream_wait_memory(sk, &timeo);
+                                       if (ret)
+                                               goto sendpage_end;
+
+                                       tskb = alloc_skb(0, sk->sk_allocation);
+                               }
+
+                               if (skb)
+                                       skb->next = tskb;
+                               else
+                                       __skb_queue_tail(&ctx->tx_queue,
+                                                        tskb);
+                               skb = tskb;
+                               i = 0;
+                       }
+               } else {
+                       skb = alloc_skb(0, sk->sk_allocation);
+                       __skb_queue_tail(&ctx->tx_queue, skb);
+                       i = 0;
+               }
+
+               get_page(page);
+               skb_fill_page_desc(skb, i, page, offset, send_size);
+               skb_shinfo(skb)->tx_flags |= SKBTX_SHARED_FRAG;
+
+coalesced:
+               skb->len += send_size;
+               skb->data_len += send_size;
+               skb->truesize += send_size;
+               sk->sk_wmem_queued += send_size;
+               ctx->wmem_len += send_size;
+               sk_mem_charge(sk, send_size);
+               ctx->unsent += send_size;
+               queued += send_size;
+               offset += send_size;
+               size -= send_size;
+
+               if (eor || ctx->unsent >= TLS_MAX_PAYLOAD_SIZE) {
+                       ret = tls_push(sk, record_type);
+                       if (ret)
+                               goto sendpage_end;
+               }
+       }
+
+       if (eor || ctx->unsent >= TLS_MAX_PAYLOAD_SIZE)
+               ret = tls_push(sk, record_type);
+
+sendpage_end:
+       ret = sk_stream_error(sk, flags, ret);
+
+       /* make sure we wake any epoll edge trigger waiter */
+       if (unlikely(skb_queue_len(&ctx->tx_queue) == 0 && ret == -EAGAIN))
+               sk->sk_write_space(sk);
+       release_sock(sk);
+
+       return ret < 0 ? ret : queued;
+}
-- 
2.7.4
