From: Sungho Bae <[email protected]>

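Change __send_to_port() to take a struct port_buffer * instead of a raw
void * buffer, and make it take ownership of that buffer. So that
put_chars() can allocate a port_buffer from atomic context, alloc_buf()
gains a gfp_t parameter; all other callers pass GFP_KERNEL.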

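Previously, put_chars() passed a kmemdup'd buffer to __send_to_port()
and kfree'd it as soon as the call returned. The blocking path, however,
only spins until virtqueue_get_buf() returns *some* completed buffer. If
the virtqueue is shared with nonblocking writers (port_fops_write() and
port_fops_splice_write()), that can be an older buffer queued by one of
them rather than the one just added. put_chars() then kfree'd its own
buffer while the host could still be DMAing from it, causing a
use-after-free and possible data corruption. In addition, the older
buffer returned by virtqueue_get_buf() was dropped without ever being
freed.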

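By transferring ownership of the port_buffer to __send_to_port(), the
blocking path now frees exactly the buffer that virtqueue_get_buf()
returned. The newly added buffer stays in the virtqueue until the host
completes it and it is reclaimed. If virtqueue_add_outbuf() fails,
__send_to_port() frees the buffer itself, so callers must not touch or
free the buffer after handing it over.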

Signed-off-by: Sungho Bae <[email protected]>
---
 drivers/char/virtio_console.c | 69 +++++++++++++++++++----------------
 1 file changed, 37 insertions(+), 32 deletions(-)

diff --git a/drivers/char/virtio_console.c b/drivers/char/virtio_console.c
index 9a33217c68d9..bbf5b3825f12 100644
--- a/drivers/char/virtio_console.c
+++ b/drivers/char/virtio_console.c
@@ -402,7 +402,7 @@ static void reclaim_dma_bufs(void)
 }
 
 static struct port_buffer *alloc_buf(struct virtio_device *vdev, size_t 
buf_size,
-                                    int pages)
+                                    int pages, gfp_t gfp)
 {
        struct port_buffer *buf;
 
@@ -436,11 +436,10 @@ static struct port_buffer *alloc_buf(struct virtio_device 
*vdev, size_t buf_size
 
                /* Increase device refcnt to avoid freeing it */
                get_device(buf->dev);
-               buf->buf = dma_alloc_coherent(buf->dev, buf_size, &buf->dma,
-                                             GFP_KERNEL);
+               buf->buf = dma_alloc_coherent(buf->dev, buf_size, &buf->dma, 
gfp);
        } else {
                buf->dev = NULL;
-               buf->buf = kmalloc(buf_size, GFP_KERNEL);
+               buf->buf = kmalloc(buf_size, gfp);
        }
 
        if (!buf->buf)
@@ -595,7 +594,7 @@ static void reclaim_consumed_buffers(struct port *port)
 
 static ssize_t __send_to_port(struct port *port, struct scatterlist *sg,
                              int nents, size_t in_count,
-                             void *data, bool nonblock)
+                             struct port_buffer *buf, bool nonblock)
 {
        struct virtqueue *out_vq;
        int err;
@@ -608,14 +607,14 @@ static ssize_t __send_to_port(struct port *port, struct 
scatterlist *sg,
 
        reclaim_consumed_buffers(port);
 
-       err = virtqueue_add_outbuf(out_vq, sg, nents, data, GFP_ATOMIC);
+       err = virtqueue_add_outbuf(out_vq, sg, nents, buf, GFP_ATOMIC);
 
        /* Tell Host to go! */
        virtqueue_kick(out_vq);
 
        if (err) {
                in_count = 0;
-               goto done;
+               goto free_and_done;
        }
 
        if (out_vq->num_free == 0)
@@ -632,10 +631,19 @@ static ssize_t __send_to_port(struct port *port, struct 
scatterlist *sg,
         * buffer and relax the spinning requirement.  The downside is
         * we need to kmalloc a GFP_ATOMIC buffer each time the
         * console driver writes something out.
+        *
+        * Spin until host returns the buffer.
+        * Capture the returned buf so we can free it.
+        * If broken, buf == NULL and buf stays in the vq;
+        * remove_vqs() will call virtqueue_detach_unused_buf() -> free_buf().
         */
-       while (!virtqueue_get_buf(out_vq, &len)
+       while (!(buf = virtqueue_get_buf(out_vq, &len))
                && !virtqueue_is_broken(out_vq))
                cpu_relax();
+
+free_and_done:
+       if (buf)
+               free_buf(buf, false);
 done:
        spin_unlock_irqrestore(&port->outvq_lock, flags);
 
@@ -816,14 +824,14 @@ static ssize_t port_fops_write(struct file *filp, const 
char __user *ubuf,
 
        count = min((size_t)(32 * 1024), count);
 
-       buf = alloc_buf(port->portdev->vdev, count, 0);
+       buf = alloc_buf(port->portdev->vdev, count, 0, GFP_KERNEL);
        if (!buf)
                return -ENOMEM;
 
        ret = copy_from_user(buf->buf, ubuf, count);
        if (ret) {
-               ret = -EFAULT;
-               goto free_buf;
+               free_buf(buf, true);
+               return -EFAULT;
        }
 
        /*
@@ -835,15 +843,7 @@ static ssize_t port_fops_write(struct file *filp, const 
char __user *ubuf,
         */
        nonblock = true;
        sg_init_one(sg, buf->buf, count);
-       ret = __send_to_port(port, sg, 1, count, buf, nonblock);
-
-       if (nonblock && ret > 0)
-               goto out;
-
-free_buf:
-       free_buf(buf, true);
-out:
-       return ret;
+       return __send_to_port(port, sg, 1, count, buf, nonblock);
 }
 
 struct sg_list {
@@ -932,7 +932,7 @@ static ssize_t port_fops_splice_write(struct 
pipe_inode_info *pipe,
                goto error_out;
 
        occupancy = pipe_buf_usage(pipe);
-       buf = alloc_buf(port->portdev->vdev, 0, occupancy);
+       buf = alloc_buf(port->portdev->vdev, 0, occupancy, GFP_KERNEL);
 
        if (!buf) {
                ret = -ENOMEM;
@@ -946,11 +946,12 @@ static ssize_t port_fops_splice_write(struct 
pipe_inode_info *pipe,
        sg_init_table(sgl.sg, sgl.size);
        ret = __splice_from_pipe(pipe, &sd, pipe_to_sg);
        pipe_unlock(pipe);
+
        if (likely(ret > 0))
                ret = __send_to_port(port, buf->sg, sgl.n, sgl.len, buf, true);
-
-       if (unlikely(ret <= 0))
+       else
                free_buf(buf, true);
+
        return ret;
 
 error_out:
@@ -1108,21 +1109,25 @@ static ssize_t put_chars(u32 vtermno, const u8 *buf, 
size_t count)
 {
        struct port *port;
        struct scatterlist sg[1];
-       void *data;
-       int ret;
+       struct port_buffer *pbuf;
 
        port = find_port_by_vtermno(vtermno);
        if (!port)
                return -EPIPE;
 
-       data = kmemdup(buf, count, GFP_ATOMIC);
-       if (!data)
+       pbuf = alloc_buf(port->portdev->vdev, count, 0, GFP_ATOMIC);
+       if (!pbuf)
                return -ENOMEM;
 
-       sg_init_one(sg, data, count);
-       ret = __send_to_port(port, sg, 1, count, data, false);
-       kfree(data);
-       return ret;
+       memcpy(pbuf->buf, buf, count);
+       pbuf->len = count;
+       sg_init_one(sg, pbuf->buf, count);
+
+       /*
+        * Ownership of pbuf is transferred to __send_to_port().
+        * Do not touch or free pbuf after this call.
+        */
+       return __send_to_port(port, sg, 1, count, pbuf, false);
 }
 
 /*
@@ -1295,7 +1300,7 @@ static int fill_queue(struct virtqueue *vq, spinlock_t 
*lock)
 
        nr_added_bufs = 0;
        do {
-               buf = alloc_buf(vq->vdev, PAGE_SIZE, 0);
+               buf = alloc_buf(vq->vdev, PAGE_SIZE, 0, GFP_KERNEL);
                if (!buf)
                        return -ENOMEM;
 
-- 
2.34.1


Reply via email to