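If the MSG_ZEROCOPY flag is passed to rds_sendmsg() and the
SO_ZEROCOPY socket option has been set on the PF_RDS socket, the
application pages passed to rds_sendmsg() are pinned rather than
copied. Zerocopy sends are currently supported only on the TCP
transport; other transports fail with -EOPNOTSUPP.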

The pinning uses the accounting infrastructure added by
commit a91dbff551a6 ("sock: ulimit on MSG_ZEROCOPY pages").

The payload bytes in the message must not be modified while the
message is pinned. A multi-threaded application using this
infrastructure may therefore need to be notified of send completion
so that it can free or reuse the buffers passed to rds_sendmsg().
Each send-completion notification identifies the message buffer by
a 32-bit cookie, which the application must supply as ancillary
data to rds_sendmsg() with cmsg_level == SOL_RDS and
cmsg_type == RDS_CMSG_ZCOPY_COOKIE. A MSG_ZEROCOPY send without
this cookie is rejected with -EINVAL.

Signed-off-by: Sowmini Varadhan <sowmini.varad...@oracle.com>
---
 include/uapi/linux/rds.h |    1 +
 net/rds/message.c        |   48 +++++++++++++++++++++++++++++++++++++++++++++-
 net/rds/rds.h            |    3 +-
 net/rds/send.c           |   44 ++++++++++++++++++++++++++++++++++++-----
 4 files changed, 88 insertions(+), 8 deletions(-)

diff --git a/include/uapi/linux/rds.h b/include/uapi/linux/rds.h
index e71d449..12e3bca 100644
--- a/include/uapi/linux/rds.h
+++ b/include/uapi/linux/rds.h
@@ -103,6 +103,7 @@
 #define RDS_CMSG_MASKED_ATOMIC_FADD    8
 #define RDS_CMSG_MASKED_ATOMIC_CSWP    9
 #define RDS_CMSG_RXPATH_LATENCY                11
+#define RDS_CMSG_ZCOPY_COOKIE          12
 
 #define RDS_INFO_FIRST                 10000
 #define RDS_INFO_COUNTERS              10000
diff --git a/net/rds/message.c b/net/rds/message.c
index 7ca968a..79b24db 100644
--- a/net/rds/message.c
+++ b/net/rds/message.c
@@ -362,12 +362,14 @@ struct rds_message *rds_message_map_pages(unsigned long *page_addrs, unsigned in
        return rm;
 }
 
-int rds_message_copy_from_user(struct rds_message *rm, struct iov_iter *from)
+int rds_message_copy_from_user(struct rds_message *rm, struct iov_iter *from,
+                              bool zcopy)
 {
        unsigned long to_copy, nbytes;
        unsigned long sg_off;
        struct scatterlist *sg;
        int ret = 0;
+       int length = iov_iter_count(from);
 
        rm->m_inc.i_hdr.h_len = cpu_to_be32(iov_iter_count(from));
 
@@ -377,6 +379,50 @@ int rds_message_copy_from_user(struct rds_message *rm, struct iov_iter *from)
        sg = rm->data.op_sg;
        sg_off = 0; /* Dear gcc, sg->page will be null from kzalloc. */
 
+       if (zcopy) {
+               int total_copied = 0;
+               struct sk_buff *skb;
+
+               skb = alloc_skb(SO_EE_ORIGIN_MAX_ZCOOKIES * sizeof(u32),
+                               GFP_KERNEL);
+               if (!skb)
+                       return -ENOMEM;
+               rm->data.op_mmp_znotifier = RDS_ZCOPY_SKB(skb);
+               memset(rm->data.op_mmp_znotifier, 0,
+                      sizeof(*rm->data.op_mmp_znotifier));
+               if (mm_account_pinned_pages(&rm->data.op_mmp_znotifier->z_mmp,
+                                           length)) {
+                       consume_skb(skb);
+                       rm->data.op_mmp_znotifier = NULL;
+                       return -ENOMEM;
+               }
+               while (iov_iter_count(from)) {
+                       struct page *pages;
+                       size_t start;
+                       ssize_t copied;
+
+                       copied = iov_iter_get_pages(from, &pages, PAGE_SIZE,
+                                                   1, &start);
+                       if (copied < 0) {
+                               struct mmpin *mmp;
+
+                               mmp = &rm->data.op_mmp_znotifier->z_mmp;
+                               mm_unaccount_pinned_pages(mmp);
+                               consume_skb(skb);
+                               rm->data.op_mmp_znotifier = NULL;
+                               return -EFAULT;
+                       }
+                       total_copied += copied;
+                       iov_iter_advance(from, copied);
+                       length -= copied;
+                       sg_set_page(sg, pages, copied, start);
+                       rm->data.op_nents++;
+                       sg++;
+               }
+               WARN_ON_ONCE(length != 0);
+               return ret;
+       } /* zcopy */
+
        while (iov_iter_count(from)) {
                if (!sg_page(sg)) {
                        ret = rds_page_remainder_alloc(sg, iov_iter_count(from),
diff --git a/net/rds/rds.h b/net/rds/rds.h
index c375dd8..9dfc23c 100644
--- a/net/rds/rds.h
+++ b/net/rds/rds.h
@@ -789,7 +789,8 @@ void rds_for_each_conn_info(struct socket *sock, unsigned int len,
 /* message.c */
 struct rds_message *rds_message_alloc(unsigned int nents, gfp_t gfp);
 struct scatterlist *rds_message_alloc_sgs(struct rds_message *rm, int nents);
-int rds_message_copy_from_user(struct rds_message *rm, struct iov_iter *from);
+int rds_message_copy_from_user(struct rds_message *rm, struct iov_iter *from,
+                              bool zcopy);
 struct rds_message *rds_message_map_pages(unsigned long *page_addrs, unsigned 
int total_len);
 void rds_message_populate_header(struct rds_header *hdr, __be16 sport,
                                 __be16 dport, u64 seq);
diff --git a/net/rds/send.c b/net/rds/send.c
index 5ac0925..1f72c8a 100644
--- a/net/rds/send.c
+++ b/net/rds/send.c
@@ -875,12 +875,13 @@ static int rds_send_queue_rm(struct rds_sock *rs, struct rds_connection *conn,
  * rds_message is getting to be quite complicated, and we'd like to allocate
  * it all in one go. This figures out how big it needs to be up front.
  */
-static int rds_rm_size(struct msghdr *msg, int data_len)
+static int rds_rm_size(struct msghdr *msg, int data_len, int num_sgs)
 {
        struct cmsghdr *cmsg;
        int size = 0;
        int cmsg_groups = 0;
        int retval;
+       bool zcopy_cookie = false;
 
        for_each_cmsghdr(cmsg, msg) {
                if (!CMSG_OK(msg, cmsg))
@@ -899,6 +900,8 @@ static int rds_rm_size(struct msghdr *msg, int data_len)
 
                        break;
 
+               case RDS_CMSG_ZCOPY_COOKIE:
+                       zcopy_cookie = true;
                case RDS_CMSG_RDMA_DEST:
                case RDS_CMSG_RDMA_MAP:
                        cmsg_groups |= 2;
@@ -919,7 +922,10 @@ static int rds_rm_size(struct msghdr *msg, int data_len)
 
        }
 
-       size += ceil(data_len, PAGE_SIZE) * sizeof(struct scatterlist);
+       if ((msg->msg_flags & MSG_ZEROCOPY) && !zcopy_cookie)
+               return -EINVAL;
+
+       size += num_sgs * sizeof(struct scatterlist);
 
        /* Ensure (DEST, MAP) are never used with (ARGS, ATOMIC) */
        if (cmsg_groups == 3)
@@ -928,6 +934,32 @@ static int rds_rm_size(struct msghdr *msg, int data_len)
        return size;
 }

+/*
+ * Store the u32 cookie from an RDS_CMSG_ZCOPY_COOKIE cmsg in the
+ * message's zerocopy notifier so that the send-completion
+ * notification for this message can identify the buffer.
+ *
+ * op_mmp_znotifier is only set by the zerocopy path of
+ * rds_message_copy_from_user(). If the cookie is sent without a
+ * zerocopy send (no MSG_ZEROCOPY, SO_ZEROCOPY unset, or no payload)
+ * the notifier was never set up, so fail instead of dereferencing it.
+ *
+ * Returns 0 on success, or -EINVAL if the cmsg is too short to hold
+ * a u32 or the message has no zerocopy notifier.
+ */
+static int rds_cmsg_zcopy(struct rds_sock *rs, struct rds_message *rm,
+                         struct cmsghdr *cmsg)
+{
+       u32 *cookie;
+
+       if (cmsg->cmsg_len < CMSG_LEN(sizeof(*cookie)) ||
+           !rm->data.op_mmp_znotifier)
+               return -EINVAL;
+       cookie = CMSG_DATA(cmsg);
+       rm->data.op_mmp_znotifier->z_cookie = *cookie;
+       return 0;
+}
+
 static int rds_cmsg_send(struct rds_sock *rs, struct rds_message *rm,
                         struct msghdr *msg, int *allocated_mr)
 {
@@ -970,6 +988,10 @@ static int rds_cmsg_send(struct rds_sock *rs, struct rds_message *rm,
                        ret = rds_cmsg_atomic(rs, rm, cmsg);
                        break;
 
+               case RDS_CMSG_ZCOPY_COOKIE:
+                       ret = rds_cmsg_zcopy(rs, rm, cmsg);
+                       break;
+
                default:
                        return -EINVAL;
                }
@@ -1040,10 +1062,13 @@ int rds_sendmsg(struct socket *sock, struct msghdr *msg, size_t payload_len)
        long timeo = sock_sndtimeo(sk, nonblock);
        struct rds_conn_path *cpath;
        size_t total_payload_len = payload_len, rdma_payload_len = 0;
+       bool zcopy = ((msg->msg_flags & MSG_ZEROCOPY) &&
+                     sock_flag(rds_rs_to_sk(rs), SOCK_ZEROCOPY));
+       int num_sgs = ceil(payload_len, PAGE_SIZE);
 
        /* Mirror Linux UDP mirror of BSD error message compatibility */
        /* XXX: Perhaps MSG_MORE someday */
-       if (msg->msg_flags & ~(MSG_DONTWAIT | MSG_CMSG_COMPAT)) {
+       if (msg->msg_flags & ~(MSG_DONTWAIT | MSG_CMSG_COMPAT | MSG_ZEROCOPY)) {
                ret = -EOPNOTSUPP;
                goto out;
        }
@@ -1087,8 +1112,15 @@ int rds_sendmsg(struct socket *sock, struct msghdr *msg, size_t payload_len)
                goto out;
        }
 
+       if (zcopy) {
+               if (rs->rs_transport->t_type != RDS_TRANS_TCP) {
+                       ret = -EOPNOTSUPP;
+                       goto out;
+               }
+               num_sgs = iov_iter_npages(&msg->msg_iter, INT_MAX);
+       }
        /* size of rm including all sgs */
-       ret = rds_rm_size(msg, payload_len);
+       ret = rds_rm_size(msg, payload_len, num_sgs);
        if (ret < 0)
                goto out;
 
@@ -1100,12 +1132,12 @@ int rds_sendmsg(struct socket *sock, struct msghdr *msg, size_t payload_len)
 
        /* Attach data to the rm */
        if (payload_len) {
-               rm->data.op_sg = rds_message_alloc_sgs(rm, ceil(payload_len, 
PAGE_SIZE));
+               rm->data.op_sg = rds_message_alloc_sgs(rm, num_sgs);
                if (!rm->data.op_sg) {
                        ret = -ENOMEM;
                        goto out;
                }
-               ret = rds_message_copy_from_user(rm, &msg->msg_iter);
+               ret = rds_message_copy_from_user(rm, &msg->msg_iter, zcopy);
                if (ret)
                        goto out;
        }
-- 
1.7.1

Reply via email to