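With TLS inline crypto, one direction may be handled in software while the
other is offloaded to hardware. Therefore, split the TLS configuration into
separate receive and transmit structures, each with its own configuration
(tx_conf/rx_conf) and private context (priv_ctx_tx/priv_ctx_rx).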

Signed-off-by: Boris Pismenny <bor...@mellanox.com>
---
 include/net/tls.h  |  51 +++++++++++++-------
 net/tls/tls_main.c | 103 ++++++++++++++++++++--------------------
 net/tls/tls_sw.c   | 134 ++++++++++++++++++++++++++++++-----------------------
 3 files changed, 161 insertions(+), 127 deletions(-)

diff --git a/include/net/tls.h b/include/net/tls.h
index 3da8e13..95a8c60 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -83,21 +83,10 @@ struct tls_device {
        void (*unhash)(struct tls_device *device, struct sock *sk);
 };
 
-struct tls_sw_context {
+struct tls_sw_context_tx {
        struct crypto_aead *aead_send;
-       struct crypto_aead *aead_recv;
        struct crypto_wait async_wait;
 
-       /* Receive context */
-       struct strparser strp;
-       void (*saved_data_ready)(struct sock *sk);
-       unsigned int (*sk_poll)(struct file *file, struct socket *sock,
-                               struct poll_table_struct *wait);
-       struct sk_buff *recv_pkt;
-       u8 control;
-       bool decrypted;
-
-       /* Sending context */
        char aad_space[TLS_AAD_SPACE_SIZE];
 
        unsigned int sg_plaintext_size;
@@ -114,6 +103,19 @@ struct tls_sw_context {
        struct scatterlist sg_aead_out[2];
 };
 
+struct tls_sw_context_rx {
+       struct crypto_aead *aead_recv;
+       struct crypto_wait async_wait;
+
+       struct strparser strp;
+       void (*saved_data_ready)(struct sock *sk);
+       unsigned int (*sk_poll)(struct file *file, struct socket *sock,
+                               struct poll_table_struct *wait);
+       struct sk_buff *recv_pkt;
+       u8 control;
+       bool decrypted;
+};
+
 enum {
        TLS_PENDING_CLOSED_RECORD
 };
@@ -138,9 +140,12 @@ struct tls_context {
                struct tls12_crypto_info_aes_gcm_128 crypto_recv_aes_gcm_128;
        };
 
-       void *priv_ctx;
+       void *priv_ctx_tx;
+       void *priv_ctx_rx;
 
-       u8 conf:3;
+       /* Per-direction config: TLS_BASE, TLS_SW or TLS_HW_RECORD */
+       u8 tx_conf:3;
+       u8 rx_conf:3;
 
        struct cipher_context tx;
        struct cipher_context rx;
@@ -177,7 +185,8 @@ int tls_sk_attach(struct sock *sk, int optname, char __user 
*optval,
 int tls_sw_sendpage(struct sock *sk, struct page *page,
                    int offset, size_t size, int flags);
 void tls_sw_close(struct sock *sk, long timeout);
-void tls_sw_free_resources(struct sock *sk);
+void tls_sw_free_resources_tx(struct sock *sk);
+void tls_sw_free_resources_rx(struct sock *sk);
 int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
                   int nonblock, int flags, int *addr_len);
 unsigned int tls_sw_poll(struct file *file, struct socket *sock,
@@ -297,16 +306,22 @@ static inline struct tls_context *tls_get_ctx(const 
struct sock *sk)
        return icsk->icsk_ulp_data;
 }
 
-static inline struct tls_sw_context *tls_sw_ctx(
+static inline struct tls_sw_context_rx *tls_sw_ctx_rx(
+               const struct tls_context *tls_ctx)
+{
+       return (struct tls_sw_context_rx *)tls_ctx->priv_ctx_rx;
+}
+
+static inline struct tls_sw_context_tx *tls_sw_ctx_tx(
                const struct tls_context *tls_ctx)
 {
-       return (struct tls_sw_context *)tls_ctx->priv_ctx;
+       return (struct tls_sw_context_tx *)tls_ctx->priv_ctx_tx;
 }
 
 static inline struct tls_offload_context *tls_offload_ctx(
                const struct tls_context *tls_ctx)
 {
-       return (struct tls_offload_context *)tls_ctx->priv_ctx;
+       return (struct tls_offload_context *)tls_ctx->priv_ctx_tx;
 }
 
 int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 0d37997..545bf34 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -51,12 +51,11 @@ enum {
        TLSV6,
        TLS_NUM_PROTS,
 };
 
+/* Per-direction configuration, stored in tx_conf and rx_conf */
 enum {
        TLS_BASE,
-       TLS_SW_TX,
-       TLS_SW_RX,
-       TLS_SW_RXTX,
+       TLS_SW,
        TLS_HW_RECORD,
        TLS_NUM_CONFIG,
 };
@@ -65,14 +62,14 @@ enum {
 static DEFINE_MUTEX(tcpv6_prot_mutex);
 static LIST_HEAD(device_list);
 static DEFINE_MUTEX(device_mutex);
-static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG];
+static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
 static struct proto_ops tls_sw_proto_ops;
 
-static inline void update_sk_prot(struct sock *sk, struct tls_context *ctx)
+static void update_sk_prot(struct sock *sk, struct tls_context *ctx)
 {
        int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
 
-       sk->sk_prot = &tls_prots[ip_ver][ctx->conf];
+       sk->sk_prot = &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf];
 }
 
 int wait_on_pending_writer(struct sock *sk, long *timeo)
@@ -245,10 +242,10 @@ static void tls_sk_proto_close(struct sock *sk, long 
timeout)
        lock_sock(sk);
        sk_proto_close = ctx->sk_proto_close;
 
-       if (ctx->conf == TLS_HW_RECORD)
+       if (ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD)
                goto skip_tx_cleanup;
 
-       if (ctx->conf == TLS_BASE) {
+       if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE) {
                kfree(ctx);
                ctx = NULL;
                goto skip_tx_cleanup;
@@ -270,15 +267,24 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
                }
        }
 
-       kfree(ctx->tx.rec_seq);
-       kfree(ctx->tx.iv);
-       kfree(ctx->rx.rec_seq);
-       kfree(ctx->rx.iv);
+       /* Only a direction configured in software owns SW crypto state */
+       if (ctx->tx_conf == TLS_SW) {
+               kfree(ctx->tx.rec_seq);
+               kfree(ctx->tx.iv);
+               tls_sw_free_resources_tx(sk);
+       }
 
-       if (ctx->conf == TLS_SW_TX ||
-           ctx->conf == TLS_SW_RX ||
-           ctx->conf == TLS_SW_RXTX) {
-               tls_sw_free_resources(sk);
+       if (ctx->rx_conf == TLS_SW) {
+               kfree(ctx->rx.rec_seq);
+               kfree(ctx->rx.iv);
+               tls_sw_free_resources_rx(sk);
        }
+
+       /* Neither tls_sw_free_resources_{tx,rx}() frees the tls_context,
+        * so release it here. Clear ctx so that the TLS_HW_RECORD check
+        * after sk_proto_close() does not dereference freed memory.
+        */
+       kfree(ctx);
+       ctx = NULL;
 
 skip_tx_cleanup:
@@ -287,7 +286,8 @@ static void tls_sk_proto_close(struct sock *sk, long 
timeout)
        /* free ctx for TLS_HW_RECORD, used by tcp_set_state
         * for sk->sk_prot->unhash [tls_hw_unhash]
         */
-       if (ctx && ctx->conf == TLS_HW_RECORD)
+       if (ctx && ctx->tx_conf == TLS_HW_RECORD &&
+           ctx->rx_conf == TLS_HW_RECORD)
                kfree(ctx);
 }
 
@@ -441,25 +441,21 @@ static int do_tls_setsockopt_conf(struct sock *sk, char 
__user *optval,
                goto err_crypto_info;
        }
 
-       /* currently SW is default, we will have ethtool in future */
        if (tx) {
                rc = tls_set_sw_offload(sk, ctx, 1);
-               if (ctx->conf == TLS_SW_RX)
-                       conf = TLS_SW_RXTX;
-               else
-                       conf = TLS_SW_TX;
+               conf = TLS_SW;
        } else {
                rc = tls_set_sw_offload(sk, ctx, 0);
-               if (ctx->conf == TLS_SW_TX)
-                       conf = TLS_SW_RXTX;
-               else
-                       conf = TLS_SW_RX;
+               conf = TLS_SW;
        }
 
        if (rc)
                goto err_crypto_info;
 
-       ctx->conf = conf;
+       if (tx)
+               ctx->tx_conf = conf;
+       else
+               ctx->rx_conf = conf;
        update_sk_prot(sk, ctx);
        if (tx) {
                ctx->sk_write_space = sk->sk_write_space;
@@ -535,7 +531,8 @@ static int tls_hw_prot(struct sock *sk)
                        ctx->hash = sk->sk_prot->hash;
                        ctx->unhash = sk->sk_prot->unhash;
                        ctx->sk_proto_close = sk->sk_prot->close;
-                       ctx->conf = TLS_HW_RECORD;
+                       ctx->rx_conf = TLS_HW_RECORD;
+                       ctx->tx_conf = TLS_HW_RECORD;
                        update_sk_prot(sk, ctx);
                        rc = 1;
                        break;
@@ -579,29 +576,29 @@ static int tls_hw_hash(struct sock *sk)
        return err;
 }
 
-static void build_protos(struct proto *prot, struct proto *base)
+/* Fill prot[tx_conf][rx_conf] for the pairs used; the rest stay zeroed */
+static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
+                        struct proto *base)
 {
-       prot[TLS_BASE] = *base;
-       prot[TLS_BASE].setsockopt       = tls_setsockopt;
-       prot[TLS_BASE].getsockopt       = tls_getsockopt;
-       prot[TLS_BASE].close            = tls_sk_proto_close;
-
-       prot[TLS_SW_TX] = prot[TLS_BASE];
-       prot[TLS_SW_TX].sendmsg         = tls_sw_sendmsg;
-       prot[TLS_SW_TX].sendpage        = tls_sw_sendpage;
-
-       prot[TLS_SW_RX] = prot[TLS_BASE];
-       prot[TLS_SW_RX].recvmsg         = tls_sw_recvmsg;
-       prot[TLS_SW_RX].close           = tls_sk_proto_close;
-
-       prot[TLS_SW_RXTX] = prot[TLS_SW_TX];
-       prot[TLS_SW_RXTX].recvmsg       = tls_sw_recvmsg;
-       prot[TLS_SW_RXTX].close         = tls_sk_proto_close;
-
-       prot[TLS_HW_RECORD] = *base;
-       prot[TLS_HW_RECORD].hash        = tls_hw_hash;
-       prot[TLS_HW_RECORD].unhash      = tls_hw_unhash;
-       prot[TLS_HW_RECORD].close       = tls_sk_proto_close;
+       prot[TLS_BASE][TLS_BASE] = *base;
+       prot[TLS_BASE][TLS_BASE].setsockopt     = tls_setsockopt;
+       prot[TLS_BASE][TLS_BASE].getsockopt     = tls_getsockopt;
+       prot[TLS_BASE][TLS_BASE].close          = tls_sk_proto_close;
+
+       prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
+       prot[TLS_SW][TLS_BASE].sendmsg          = tls_sw_sendmsg;
+       prot[TLS_SW][TLS_BASE].sendpage         = tls_sw_sendpage;
+
+       prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
+       prot[TLS_BASE][TLS_SW].recvmsg          = tls_sw_recvmsg;
+
+       prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
+       prot[TLS_SW][TLS_SW].recvmsg    = tls_sw_recvmsg;
+
+       prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
+       prot[TLS_HW_RECORD][TLS_HW_RECORD].hash         = tls_hw_hash;
+       prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash       = tls_hw_unhash;
+       prot[TLS_HW_RECORD][TLS_HW_RECORD].close        = tls_sk_proto_close;
 }
 
 static int tls_init(struct sock *sk)
@@ -643,7 +641,8 @@ static int tls_init(struct sock *sk)
                mutex_unlock(&tcpv6_prot_mutex);
        }
 
-       ctx->conf = TLS_BASE;
+       ctx->tx_conf = TLS_BASE;
+       ctx->rx_conf = TLS_BASE;
        update_sk_prot(sk, ctx);
 out:
        return rc;
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 71e7959..f374cc2 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -52,7 +52,7 @@ static int tls_do_decryption(struct sock *sk,
                             gfp_t flags)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        struct strp_msg *rxm = strp_msg(skb);
        struct aead_request *aead_req;
 
@@ -122,7 +122,7 @@ static void trim_sg(struct sock *sk, struct scatterlist *sg,
 static void trim_both_sgl(struct sock *sk, int target_size)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 
        trim_sg(sk, ctx->sg_plaintext_data,
                &ctx->sg_plaintext_num_elem,
@@ -141,7 +141,7 @@ static void trim_both_sgl(struct sock *sk, int target_size)
 static int alloc_encrypted_sg(struct sock *sk, int len)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
        int rc = 0;
 
        rc = sk_alloc_sg(sk, len,
@@ -155,7 +155,7 @@ static int alloc_encrypted_sg(struct sock *sk, int len)
 static int alloc_plaintext_sg(struct sock *sk, int len)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
        int rc = 0;
 
        rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0,
@@ -181,7 +181,7 @@ static void free_sg(struct sock *sk, struct scatterlist *sg,
 static void tls_free_both_sg(struct sock *sk)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 
        free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem,
                &ctx->sg_encrypted_size);
@@ -191,7 +191,7 @@ static void tls_free_both_sg(struct sock *sk)
 }
 
 static int tls_do_encryption(struct tls_context *tls_ctx,
-                            struct tls_sw_context *ctx, size_t data_len,
+                            struct tls_sw_context_tx *ctx, size_t data_len,
                             gfp_t flags)
 {
        unsigned int req_size = sizeof(struct aead_request) +
@@ -227,7 +227,7 @@ static int tls_push_record(struct sock *sk, int flags,
                           unsigned char record_type)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
        int rc;
 
        sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);
@@ -339,7 +339,7 @@ static int memcopy_from_iter(struct sock *sk, struct 
iov_iter *from,
                             int bytes)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
        struct scatterlist *sg = ctx->sg_plaintext_data;
        int copy, i, rc = 0;
 
@@ -367,7 +367,7 @@ static int memcopy_from_iter(struct sock *sk, struct 
iov_iter *from,
 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
        int ret = 0;
        int required_size;
        long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
@@ -522,7 +522,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
                    int offset, size_t size, int flags)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
        int ret = 0;
        long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
        bool eor;
@@ -636,7 +636,7 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int 
flags,
                                     long timeo, int *err)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        struct sk_buff *skb;
        DEFINE_WAIT_FUNC(wait, woken_wake_function);
 
@@ -674,7 +674,7 @@ static int decrypt_skb(struct sock *sk, struct sk_buff *skb,
                       struct scatterlist *sgout)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + MAX_IV_SIZE];
        struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2];
        struct scatterlist *sgin = &sgin_arr[0];
@@ -724,7 +724,7 @@ static bool tls_sw_advance_skb(struct sock *sk, struct 
sk_buff *skb,
                               unsigned int len)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        struct strp_msg *rxm = strp_msg(skb);
 
        if (len < rxm->full_len) {
@@ -750,7 +750,7 @@ int tls_sw_recvmsg(struct sock *sk,
                   int *addr_len)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        unsigned char control;
        struct strp_msg *rxm;
        struct sk_buff *skb;
@@ -870,7 +870,7 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t 
*ppos,
                           size_t len, unsigned int flags)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        struct strp_msg *rxm = NULL;
        struct sock *sk = sock->sk;
        struct sk_buff *skb;
@@ -923,7 +923,7 @@ unsigned int tls_sw_poll(struct file *file, struct socket 
*sock,
        unsigned int ret;
        struct sock *sk = sock->sk;
        struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 
        /* Grab POLLOUT and POLLHUP from the underlying socket */
        ret = ctx->sk_poll(file, sock, wait);
@@ -939,7 +939,7 @@ unsigned int tls_sw_poll(struct file *file, struct socket 
*sock,
 static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
 {
        struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        char header[tls_ctx->rx.prepend_size];
        struct strp_msg *rxm = strp_msg(skb);
        size_t cipher_overhead;
@@ -988,7 +988,7 @@ static int tls_read_size(struct strparser *strp, struct 
sk_buff *skb)
 static void tls_queue(struct strparser *strp, struct sk_buff *skb)
 {
        struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        struct strp_msg *rxm;
 
        rxm = strp_msg(skb);
@@ -1004,18 +1004,28 @@ static void tls_queue(struct strparser *strp, struct 
sk_buff *skb)
 static void tls_data_ready(struct sock *sk)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 
        strp_data_ready(&ctx->strp);
 }
 
-void tls_sw_free_resources(struct sock *sk)
+void tls_sw_free_resources_tx(struct sock *sk)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 
        if (ctx->aead_send)
                crypto_free_aead(ctx->aead_send);
+       tls_free_both_sg(sk);
+
+       kfree(ctx);
+}
+
+void tls_sw_free_resources_rx(struct sock *sk)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+
        if (ctx->aead_recv) {
                if (ctx->recv_pkt) {
                        kfree_skb(ctx->recv_pkt);
@@ -1031,10 +1041,7 @@ void tls_sw_free_resources(struct sock *sk)
                lock_sock(sk);
        }
 
-       tls_free_both_sg(sk);
-
        kfree(ctx);
-       kfree(tls_ctx);
 }
 
 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
@@ -1042,7 +1049,8 @@ int tls_set_sw_offload(struct sock *sk, struct 
tls_context *ctx, int tx)
        char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE];
        struct tls_crypto_info *crypto_info;
        struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
-       struct tls_sw_context *sw_ctx;
+       struct tls_sw_context_tx *sw_ctx_tx;
+       struct tls_sw_context_rx *sw_ctx_rx;
        struct cipher_context *cctx;
        struct crypto_aead **aead;
        struct strp_callbacks cb;
@@ -1055,27 +1063,32 @@ int tls_set_sw_offload(struct sock *sk, struct 
tls_context *ctx, int tx)
                goto out;
        }
 
-       if (!ctx->priv_ctx) {
-               sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL);
-               if (!sw_ctx) {
+       if (tx) {
+               sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
+               if (!sw_ctx_tx) {
                        rc = -ENOMEM;
                        goto out;
                }
-               crypto_init_wait(&sw_ctx->async_wait);
+               crypto_init_wait(&sw_ctx_tx->async_wait);
+               ctx->priv_ctx_tx = sw_ctx_tx;
        } else {
-               sw_ctx = ctx->priv_ctx;
+               sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
+               if (!sw_ctx_rx) {
+                       rc = -ENOMEM;
+                       goto out;
+               }
+               crypto_init_wait(&sw_ctx_rx->async_wait);
+               ctx->priv_ctx_rx = sw_ctx_rx;
        }
 
-       ctx->priv_ctx = (struct tls_offload_context *)sw_ctx;
-
        if (tx) {
                crypto_info = &ctx->crypto_send;
                cctx = &ctx->tx;
-               aead = &sw_ctx->aead_send;
+               aead = &sw_ctx_tx->aead_send;
        } else {
                crypto_info = &ctx->crypto_recv;
                cctx = &ctx->rx;
-               aead = &sw_ctx->aead_recv;
+               aead = &sw_ctx_rx->aead_recv;
        }
 
        switch (crypto_info->cipher_type) {
@@ -1123,21 +1136,23 @@ int tls_set_sw_offload(struct sock *sk, struct 
tls_context *ctx, int tx)
        memcpy(cctx->rec_seq, rec_seq, rec_seq_size);
 
        if (tx) {
-               sg_init_table(sw_ctx->sg_encrypted_data,
-                             ARRAY_SIZE(sw_ctx->sg_encrypted_data));
-               sg_init_table(sw_ctx->sg_plaintext_data,
-                             ARRAY_SIZE(sw_ctx->sg_plaintext_data));
-
-               sg_init_table(sw_ctx->sg_aead_in, 2);
-               sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space,
-                          sizeof(sw_ctx->aad_space));
-               sg_unmark_end(&sw_ctx->sg_aead_in[1]);
-               sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data);
-               sg_init_table(sw_ctx->sg_aead_out, 2);
-               sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space,
-                          sizeof(sw_ctx->aad_space));
-               sg_unmark_end(&sw_ctx->sg_aead_out[1]);
-               sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data);
+               sg_init_table(sw_ctx_tx->sg_encrypted_data,
+                             ARRAY_SIZE(sw_ctx_tx->sg_encrypted_data));
+               sg_init_table(sw_ctx_tx->sg_plaintext_data,
+                             ARRAY_SIZE(sw_ctx_tx->sg_plaintext_data));
+
+               sg_init_table(sw_ctx_tx->sg_aead_in, 2);
+               sg_set_buf(&sw_ctx_tx->sg_aead_in[0], sw_ctx_tx->aad_space,
+                          sizeof(sw_ctx_tx->aad_space));
+               sg_unmark_end(&sw_ctx_tx->sg_aead_in[1]);
+               sg_chain(sw_ctx_tx->sg_aead_in, 2,
+                        sw_ctx_tx->sg_plaintext_data);
+               sg_init_table(sw_ctx_tx->sg_aead_out, 2);
+               sg_set_buf(&sw_ctx_tx->sg_aead_out[0], sw_ctx_tx->aad_space,
+                          sizeof(sw_ctx_tx->aad_space));
+               sg_unmark_end(&sw_ctx_tx->sg_aead_out[1]);
+               sg_chain(sw_ctx_tx->sg_aead_out, 2,
+                        sw_ctx_tx->sg_encrypted_data);
        }
 
        if (!*aead) {
@@ -1168,16 +1183,16 @@ int tls_set_sw_offload(struct sock *sk, struct 
tls_context *ctx, int tx)
                cb.rcv_msg = tls_queue;
                cb.parse_msg = tls_read_size;
 
-               strp_init(&sw_ctx->strp, sk, &cb);
+               strp_init(&sw_ctx_rx->strp, sk, &cb);
 
                write_lock_bh(&sk->sk_callback_lock);
-               sw_ctx->saved_data_ready = sk->sk_data_ready;
+               sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
                sk->sk_data_ready = tls_data_ready;
                write_unlock_bh(&sk->sk_callback_lock);
 
-               sw_ctx->sk_poll = sk->sk_socket->ops->poll;
+               sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll;
 
-               strp_check_rcv(&sw_ctx->strp);
+               strp_check_rcv(&sw_ctx_rx->strp);
        }
 
        goto out;
@@ -1189,11 +1204,16 @@ int tls_set_sw_offload(struct sock *sk, struct 
tls_context *ctx, int tx)
        kfree(cctx->rec_seq);
        cctx->rec_seq = NULL;
 free_iv:
-       kfree(ctx->tx.iv);
-       ctx->tx.iv = NULL;
+       kfree(cctx->iv);
+       cctx->iv = NULL;
 free_priv:
-       kfree(ctx->priv_ctx);
-       ctx->priv_ctx = NULL;
+       if (tx) {
+               kfree(ctx->priv_ctx_tx);
+               ctx->priv_ctx_tx = NULL;
+       } else {
+               kfree(ctx->priv_ctx_rx);
+               ctx->priv_ctx_rx = NULL;
+       }
 out:
        return rc;
 }
-- 
1.8.3.1

Reply via email to