When the guest receives mergeable rx buffers, it can merge
the scattered rx buffers into one big buffer and then copy
it to user space.

In addition, this patch uses an iovec to replace buf in struct
virtio_vsock_pkt, keeping tx and rx consistent. The only
difference is that tx still uses a single segment of
contiguous physical memory.

Signed-off-by: Yiwen Jiang <jiangyi...@huawei.com>
---
 drivers/vhost/vsock.c                   |  31 +++++++---
 include/linux/virtio_vsock.h            |   6 +-
 net/vmw_vsock/virtio_transport.c        | 105 ++++++++++++++++++++++++++++----
 net/vmw_vsock/virtio_transport_common.c |  59 ++++++++++++++----
 4 files changed, 166 insertions(+), 35 deletions(-)

diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index dc52b0f..c7ab0dd 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -179,6 +179,8 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
                size_t nbytes;
                size_t len;
                s16 headcount;
+               size_t remain_len;
+               int i;

                spin_lock_bh(&vsock->send_pkt_list_lock);
                if (list_empty(&vsock->send_pkt_list)) {
@@ -221,11 +223,19 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
                        break;
                }

-               nbytes = copy_to_iter(pkt->buf, pkt->len, &iov_iter);
-               if (nbytes != pkt->len) {
-                       virtio_transport_free_pkt(pkt);
-                       vq_err(vq, "Faulted on copying pkt buf\n");
-                       break;
+               remain_len = pkt->len;
+               for (i = 0; i < pkt->nr_vecs; i++) {
+                       int tmp_len;
+
+                       tmp_len = min(remain_len, pkt->vec[i].iov_len);
+                       nbytes = copy_to_iter(pkt->vec[i].iov_base, tmp_len, &iov_iter);
+                       if (nbytes != tmp_len) {
+                               virtio_transport_free_pkt(pkt);
+                               vq_err(vq, "Faulted on copying pkt buf\n");
+                               break;
+                       }
+
+                       remain_len -= tmp_len;
                }

                vhost_add_used_n(vq, vq->heads, headcount);
@@ -341,6 +351,7 @@ static void vhost_transport_send_pkt_work(struct vhost_work *work)
        struct iov_iter iov_iter;
        size_t nbytes;
        size_t len;
+       void *buf;

        if (in != 0) {
                vq_err(vq, "Expected 0 input buffers, got %u\n", in);
@@ -375,13 +386,17 @@ static void vhost_transport_send_pkt_work(struct vhost_work *work)
                return NULL;
        }

-       pkt->buf = kmalloc(pkt->len, GFP_KERNEL);
-       if (!pkt->buf) {
+       buf = kmalloc(pkt->len, GFP_KERNEL);
+       if (!buf) {
                kfree(pkt);
                return NULL;
        }

-       nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
+       pkt->vec[0].iov_base = buf;
+       pkt->vec[0].iov_len = pkt->len;
+       pkt->nr_vecs = 1;
+
+       nbytes = copy_from_iter(buf, pkt->len, &iov_iter);
        if (nbytes != pkt->len) {
                vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
                       pkt->len, nbytes);
diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
index da9e1fe..734eeed 100644
--- a/include/linux/virtio_vsock.h
+++ b/include/linux/virtio_vsock.h
@@ -13,6 +13,8 @@
 #define VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE       (1024 * 4)
 #define VIRTIO_VSOCK_MAX_BUF_SIZE              0xFFFFFFFFUL
 #define VIRTIO_VSOCK_MAX_PKT_BUF_SIZE          (1024 * 64)
+/* virtio_vsock_pkt + max_pkt_len(default MAX_PKT_BUF_SIZE) */
+#define VIRTIO_VSOCK_MAX_VEC_NUM ((VIRTIO_VSOCK_MAX_PKT_BUF_SIZE / PAGE_SIZE) + 1)

 /* Virtio-vsock feature */
 #define VIRTIO_VSOCK_F_MRG_RXBUF 0 /* Host can merge receive buffers. */
@@ -55,10 +57,12 @@ struct virtio_vsock_pkt {
        struct list_head list;
        /* socket refcnt not held, only use for cancellation */
        struct vsock_sock *vsk;
-       void *buf;
+       struct kvec vec[VIRTIO_VSOCK_MAX_VEC_NUM];
+       int nr_vecs;
        u32 len;
        u32 off;
        bool reply;
+       bool mergeable;
 };

 struct virtio_vsock_pkt_info {
diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
index c4a465c..148b58a 100644
--- a/net/vmw_vsock/virtio_transport.c
+++ b/net/vmw_vsock/virtio_transport.c
@@ -155,8 +155,10 @@ static int virtio_transport_send_pkt_loopback(struct virtio_vsock *vsock,

                sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
                sgs[out_sg++] = &hdr;
-               if (pkt->buf) {
-                       sg_init_one(&buf, pkt->buf, pkt->len);
+               if (pkt->len) {
+                       /* Currently only support a segment of memory in tx */
+                       BUG_ON(pkt->vec[0].iov_len != pkt->len);
+                       sg_init_one(&buf, pkt->vec[0].iov_base, pkt->vec[0].iov_len);
                        sgs[out_sg++] = &buf;
                }

@@ -304,23 +306,28 @@ static int fill_old_rx_buff(struct virtqueue *vq)
        struct virtio_vsock_pkt *pkt;
        struct scatterlist hdr, buf, *sgs[2];
        int ret;
+       void *pkt_buf;

        pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
        if (!pkt)
                return -ENOMEM;

-       pkt->buf = kmalloc(buf_len, GFP_KERNEL);
-       if (!pkt->buf) {
+       pkt_buf = kmalloc(buf_len, GFP_KERNEL);
+       if (!pkt_buf) {
                virtio_transport_free_pkt(pkt);
                return -ENOMEM;
        }

+       pkt->vec[0].iov_base = pkt_buf;
+       pkt->vec[0].iov_len = buf_len;
+       pkt->nr_vecs = 1;
+
        pkt->len = buf_len;

        sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
        sgs[0] = &hdr;

-       sg_init_one(&buf, pkt->buf, buf_len);
+       sg_init_one(&buf, pkt->vec[0].iov_base, buf_len);
        sgs[1] = &buf;
        ret = virtqueue_add_sgs(vq, sgs, 0, 2, pkt, GFP_KERNEL);
        if (ret)
@@ -388,11 +395,78 @@ static bool virtio_transport_more_replies(struct virtio_vsock *vsock)
        return val < virtqueue_get_vring_size(vq);
 }

+static struct virtio_vsock_pkt *receive_mergeable(struct virtqueue *vq,
+               struct virtio_vsock *vsock, unsigned int *total_len)
+{
+       struct virtio_vsock_pkt *pkt;
+       u16 num_buf;
+       void *buf;
+       unsigned int len;
+       size_t vsock_hlen = sizeof(struct virtio_vsock_pkt);
+
+       buf = virtqueue_get_buf(vq, &len);
+       if (!buf)
+               return NULL;
+
+       *total_len = len;
+       vsock->rx_buf_nr--;
+
+       if (unlikely(len < vsock_hlen)) {
+               put_page(virt_to_head_page(buf));
+               return NULL;
+       }
+
+       pkt = buf;
+       num_buf = le16_to_cpu(pkt->mrg_rxbuf_hdr.num_buffers);
+       if (!num_buf || num_buf > VIRTIO_VSOCK_MAX_VEC_NUM) {
+               put_page(virt_to_head_page(buf));
+               return NULL;
+       }
+
+       /* Initialize pkt residual structure */
+       memset(&pkt->work, 0, vsock_hlen - sizeof(struct virtio_vsock_hdr) -
+                       sizeof(struct virtio_vsock_mrg_rxbuf_hdr));
+
+       pkt->mergeable = true;
+       pkt->len = le32_to_cpu(pkt->hdr.len);
+       if (!pkt->len)
+               return pkt;
+
+       len -= vsock_hlen;
+       if (len) {
+               pkt->vec[pkt->nr_vecs].iov_base = buf + vsock_hlen;
+               pkt->vec[pkt->nr_vecs].iov_len = len;
+               /* Shared page with pkt, so get page in advance */
+               get_page(virt_to_head_page(buf));
+               pkt->nr_vecs++;
+       }
+
+       while (--num_buf) {
+               buf = virtqueue_get_buf(vq, &len);
+               if (!buf)
+                       goto err;
+
+               *total_len += len;
+               vsock->rx_buf_nr--;
+
+               pkt->vec[pkt->nr_vecs].iov_base = buf;
+               pkt->vec[pkt->nr_vecs].iov_len = len;
+               pkt->nr_vecs++;
+       }
+
+       return pkt;
+err:
+       virtio_transport_free_pkt(pkt);
+       return NULL;
+}
+
 static void virtio_transport_rx_work(struct work_struct *work)
 {
        struct virtio_vsock *vsock =
                container_of(work, struct virtio_vsock, rx_work);
        struct virtqueue *vq;
+       size_t vsock_hlen = vsock->mergeable ? sizeof(struct virtio_vsock_pkt) :
+                       sizeof(struct virtio_vsock_hdr);

        vq = vsock->vqs[VSOCK_VQ_RX];

@@ -412,21 +486,26 @@ static void virtio_transport_rx_work(struct work_struct *work)
                                goto out;
                        }

-                       pkt = virtqueue_get_buf(vq, &len);
-                       if (!pkt) {
-                               break;
-                       }
+                       if (likely(vsock->mergeable)) {
+                               pkt = receive_mergeable(vq, vsock, &len);
+                               if (!pkt)
+                                       break;
+                       } else {
+                               pkt = virtqueue_get_buf(vq, &len);
+                               if (!pkt)
+                                       break;

-                       vsock->rx_buf_nr--;
+                               vsock->rx_buf_nr--;
+                       }

                        /* Drop short/long packets */
-                       if (unlikely(len < sizeof(pkt->hdr) ||
-                                    len > sizeof(pkt->hdr) + pkt->len)) {
+                       if (unlikely(len < vsock_hlen ||
+                                    len > vsock_hlen + pkt->len)) {
                                virtio_transport_free_pkt(pkt);
                                continue;
                        }

-                       pkt->len = len - sizeof(pkt->hdr);
+                       pkt->len = len - vsock_hlen;
                        virtio_transport_deliver_tap_pkt(pkt);
                        virtio_transport_recv_pkt(pkt);
                }
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index 3ae3a33..123a8b6 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -44,6 +44,7 @@ static const struct virtio_transport *virtio_transport_get_ops(void)
 {
        struct virtio_vsock_pkt *pkt;
        int err;
+       void *buf = NULL;

        pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
        if (!pkt)
@@ -62,12 +63,16 @@ static const struct virtio_transport *virtio_transport_get_ops(void)
        pkt->vsk                = info->vsk;

        if (info->msg && len > 0) {
-               pkt->buf = kmalloc(len, GFP_KERNEL);
-               if (!pkt->buf)
+               buf = kmalloc(len, GFP_KERNEL);
+               if (!buf)
                        goto out_pkt;
-               err = memcpy_from_msg(pkt->buf, info->msg, len);
+               err = memcpy_from_msg(buf, info->msg, len);
                if (err)
                        goto out;
+
+               pkt->vec[0].iov_base = buf;
+               pkt->vec[0].iov_len = len;
+               pkt->nr_vecs = 1;
        }

        trace_virtio_transport_alloc_pkt(src_cid, src_port,
@@ -80,7 +85,7 @@ static const struct virtio_transport *virtio_transport_get_ops(void)
        return pkt;

 out:
-       kfree(pkt->buf);
+       kfree(buf);
 out_pkt:
        kfree(pkt);
        return NULL;
@@ -92,6 +97,7 @@ static struct sk_buff *virtio_transport_build_skb(void *opaque)
        struct virtio_vsock_pkt *pkt = opaque;
        struct af_vsockmon_hdr *hdr;
        struct sk_buff *skb;
+       int i;

        skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + pkt->len,
                        GFP_ATOMIC);
@@ -134,7 +140,8 @@ static struct sk_buff *virtio_transport_build_skb(void *opaque)
        skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));

        if (pkt->len) {
-               skb_put_data(skb, pkt->buf, pkt->len);
+               for (i = 0; i < pkt->nr_vecs; i++)
+                       skb_put_data(skb, pkt->vec[i].iov_base, pkt->vec[i].iov_len);
        }

        return skb;
@@ -260,6 +267,9 @@ static int virtio_transport_send_credit_update(struct vsock_sock *vsk,

        spin_lock_bh(&vvs->rx_lock);
        while (total < len && !list_empty(&vvs->rx_queue)) {
+               size_t copy_bytes, last_vec_total = 0, vec_off;
+               int i;
+
                pkt = list_first_entry(&vvs->rx_queue,
                                       struct virtio_vsock_pkt, list);

@@ -272,14 +282,28 @@ static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
                 */
                spin_unlock_bh(&vvs->rx_lock);

-               err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
-               if (err)
-                       goto out;
+               for (i = 0; i < pkt->nr_vecs; i++) {
+                       if (pkt->off > last_vec_total + pkt->vec[i].iov_len) {
+                               last_vec_total += pkt->vec[i].iov_len;
+                               continue;
+                       }
+
+                       vec_off = pkt->off - last_vec_total;
+                       copy_bytes = min(pkt->vec[i].iov_len - vec_off, bytes);
+                       err = memcpy_to_msg(msg, pkt->vec[i].iov_base + vec_off,
+                                       copy_bytes);
+                       if (err)
+                               goto out;
+
+                       bytes -= copy_bytes;
+                       pkt->off += copy_bytes;
+                       total += copy_bytes;
+                       last_vec_total += pkt->vec[i].iov_len;
+                       if (!bytes)
+                               break;
+               }

                spin_lock_bh(&vvs->rx_lock);
-
-               total += bytes;
-               pkt->off += bytes;
                if (pkt->off == pkt->len) {
                        virtio_transport_dec_rx_pkt(vvs, pkt);
                        list_del(&pkt->list);
@@ -1050,8 +1074,17 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)

 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
 {
-       kfree(pkt->buf);
-       kfree(pkt);
+       int i;
+
+       if (pkt->mergeable) {
+               for (i = 0; i < pkt->nr_vecs; i++)
+                       put_page(virt_to_head_page(pkt->vec[i].iov_base));
+               put_page(virt_to_head_page((void *)pkt));
+       } else {
+               for (i = 0; i < pkt->nr_vecs; i++)
+                       kfree(pkt->vec[i].iov_base);
+               kfree(pkt);
+       }
 }
 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);

-- 
1.8.3.1


Reply via email to