commit a663b3c47ab1 ("i2c: virtio: Avoid hang by using interruptible
completion wait") switched virtio_i2c_complete_reqs() to
wait_for_completion_interruptible() so a stuck device cannot hang a
task forever. That left a use-after-free: if the wait returns early on
a signal, virtio_i2c_xfer() frees reqs and DMA bounce buffers while the
device may still hold virtqueue tokens pointing at &reqs[i] and DMA
into read buffers. When those requests complete later,
virtio_i2c_msg_done() calls complete() on freed memory.

Waiting uninterruptibly for every completion before freeing avoids the
UAF, but can hang the caller indefinitely if the virtio device never
completes the request. Resetting the virtqueue or the device is also
possible in this scenario, but adds complexity.

Manage the lifetime of the xfer allocation with a kref, and make each
in-flight request hold a reference. This fixes the use-after-free by
keeping the memory the device writes to valid until the device returns
the request. If the device hangs, the allocation is leaked, which is
much preferable to memory corruption.

Additionally, always use a bounce buffer, even if the i2c_msg buffer is
DMA-safe, because the buffer passed to the virtqueue must remain valid
after an interrupted transfer returns. Stop using
i2c_get_dma_safe_msg_buf(), since it may pass msg->buf through directly.
All bounce buffers are part of the single xfer allocation, so no
additional allocations are needed.

Signed-off-by: Gavin Li <[email protected]>

---

Changes in v5:
- DMA-align all bounce buffers

Changes in v4:
- Pack bounce buffers into a single allocation after reqs[]
- Remove per-request buf pointer and xfer->num
- Remove req.msg, track message direction with req->read
- Simplify xfer release to a single kfree()
- Restore using j to track successful transfers in _complete_xfer()
---
 drivers/i2c/busses/i2c-virtio.c | 135 +++++++++++++++++++++++++-------
 1 file changed, 108 insertions(+), 27 deletions(-)

diff --git a/drivers/i2c/busses/i2c-virtio.c b/drivers/i2c/busses/i2c-virtio.c
index 5da6fef92bec3..a5602865102d9 100644
--- a/drivers/i2c/busses/i2c-virtio.c
+++ b/drivers/i2c/busses/i2c-virtio.c
@@ -10,10 +10,13 @@
 
 #include <linux/acpi.h>
 #include <linux/completion.h>
+#include <linux/dma-mapping.h>
 #include <linux/err.h>
 #include <linux/i2c.h>
 #include <linux/kernel.h>
+#include <linux/kref.h>
 #include <linux/module.h>
+#include <linux/overflow.h>
 #include <linux/virtio.h>
 #include <linux/virtio_ids.h>
 #include <linux/virtio_config.h>
@@ -31,39 +34,80 @@ struct virtio_i2c {
        struct virtqueue *vq;
 };
 
+struct virtio_i2c_xfer;
+
 /**
  * struct virtio_i2c_req - the virtio I2C request structure
+ * @xfer: owning transfer; its kref keeps this request alive until the device returns it
  * @completion: completion of virtio I2C message
+ * @read: true if this is a read message (I2C_M_RD is set)
  * @out_hdr: the OUT header of the virtio I2C message
- * @buf: the buffer into which data is read, or from which it's written
  * @in_hdr: the IN header of the virtio I2C message
  */
 struct virtio_i2c_req {
+       struct virtio_i2c_xfer *xfer;
        struct completion completion;
+       bool read;
        struct virtio_i2c_out_hdr out_hdr       ____cacheline_aligned;
-       uint8_t *buf                            ____cacheline_aligned;
        struct virtio_i2c_in_hdr in_hdr         ____cacheline_aligned;
 };
 
+/**
+ * struct virtio_i2c_xfer - a queued I2C transfer
+ * @ref: one ref for the caller, plus one per in-flight virtqueue request
+ * @bounce_buf_base: start of the bounce buffers, aligned to dma_get_cache_alignment()
+ * @reqs: the virtio I2C requests, one per i2c_msg
+ *
+ * Allocation layout:
+ * - struct virtio_i2c_xfer xfer
+ * - struct virtio_i2c_req reqs[num]
+ * - padding to dma_get_cache_alignment()
+ * - u8 bounce_buf[virtio_i2c_bounce_size(msgs[0].len)]
+ *   ...
+ * - u8 bounce_buf[virtio_i2c_bounce_size(msgs[num-1].len)]
+ */
+struct virtio_i2c_xfer {
+       struct kref ref;
+       u8 *bounce_buf_base;
+       struct virtio_i2c_req reqs[];
+};
+
+static size_t virtio_i2c_bounce_size(unsigned int len)
+{
+       return ALIGN(len, dma_get_cache_alignment()); /* keep the next buffer DMA-aligned */
+}
+
+static void virtio_i2c_xfer_release(struct kref *ref)
+{
+       struct virtio_i2c_xfer *xfer = container_of(ref, struct virtio_i2c_xfer, ref);
+       kfree(xfer); /* also frees reqs[] and the bounce buffers */
+}
+
 static void virtio_i2c_msg_done(struct virtqueue *vq)
 {
        struct virtio_i2c_req *req;
        unsigned int len;
 
-       while ((req = virtqueue_get_buf(vq, &len)))
+       while ((req = virtqueue_get_buf(vq, &len))) {
                complete(&req->completion);
+               kref_put(&req->xfer->ref, virtio_i2c_xfer_release); /* may free req */
+       }
 }
 
-static int virtio_i2c_prepare_reqs(struct virtqueue *vq,
-                                  struct virtio_i2c_req *reqs,
+static int virtio_i2c_prepare_xfer(struct virtqueue *vq,
+                                  struct virtio_i2c_xfer *xfer,
                                   struct i2c_msg *msgs, int num)
 {
        struct scatterlist *sgs[3], out_hdr, msg_buf, in_hdr;
+       struct virtio_i2c_req *reqs = xfer->reqs;
+       u8 *bounce_buf = xfer->bounce_buf_base;
        int i;
 
        for (i = 0; i < num; i++) {
                int outcnt = 0, incnt = 0;
 
+               reqs[i].xfer = xfer;
+               reqs[i].read = !!(msgs[i].flags & I2C_M_RD);
                init_completion(&reqs[i].completion);
 
                /*
@@ -82,23 +126,31 @@ static int virtio_i2c_prepare_reqs(struct virtqueue *vq,
                sgs[outcnt++] = &out_hdr;
 
                if (msgs[i].len) {
-                       reqs[i].buf = i2c_get_dma_safe_msg_buf(&msgs[i], 1);
-                       if (!reqs[i].buf)
-                               break;
+                       /*
+                        * Even if msg->flags has I2C_M_DMA_SAFE set, a bounce
+                        * buffer is required because the transfer may be
+                        * interrupted, after which msg->buf is no longer valid.
+                        */
+                       if (!(msgs[i].flags & I2C_M_RD))
+                               memcpy(bounce_buf, msgs[i].buf, msgs[i].len);
 
-                       sg_init_one(&msg_buf, reqs[i].buf, msgs[i].len);
+                       sg_init_one(&msg_buf, bounce_buf, msgs[i].len);
 
                        if (msgs[i].flags & I2C_M_RD)
                                sgs[outcnt + incnt++] = &msg_buf;
                        else
                                sgs[outcnt++] = &msg_buf;
                }
+               bounce_buf += virtio_i2c_bounce_size(msgs[i].len);
 
                sg_init_one(&in_hdr, &reqs[i].in_hdr, sizeof(reqs[i].in_hdr));
                sgs[outcnt + incnt++] = &in_hdr;
 
+               /* This reference is released in virtio_i2c_msg_done(). */
+               kref_get(&xfer->ref);
+
+               if (virtqueue_add_sgs(vq, sgs, outcnt, incnt, &reqs[i], GFP_KERNEL)) {
-                       i2c_put_dma_safe_msg_buf(reqs[i].buf, &msgs[i], false);
+                       kref_put(&xfer->ref, virtio_i2c_xfer_release);
                        break;
                }
        }
@@ -106,26 +158,38 @@ static int virtio_i2c_prepare_reqs(struct virtqueue *vq,
        return i;
 }
 
-static int virtio_i2c_complete_reqs(struct virtqueue *vq,
-                                   struct virtio_i2c_req *reqs,
-                                   struct i2c_msg *msgs, int num)
+static int virtio_i2c_complete_xfer(struct virtio_i2c_xfer *xfer,
+                                   struct i2c_msg *msgs,
+                                   int num)
 {
+       struct virtio_i2c_req *reqs = xfer->reqs;
+       u8 *bounce_buf = xfer->bounce_buf_base;
        bool failed = false;
        int i, j = 0;
 
        for (i = 0; i < num; i++) {
                struct virtio_i2c_req *req = &reqs[i];
+               struct i2c_msg *msg = &msgs[i];
+
+               if (wait_for_completion_interruptible(&req->completion))
+                       return -EINTR; /* reqs still owned by the device hold xfer refs */
+
+               if (req->in_hdr.status != VIRTIO_I2C_MSG_OK) {
+                       /*
+                        * Don't break yet. Try to wait until all requests
+                        * complete to ensure that the virtqueue has enough
+                        * descriptor slots for the next transfer.
+                        */
+                       failed = true;
+               }
 
                if (!failed) {
-                       if (wait_for_completion_interruptible(&req->completion))
-                               failed = true;
-                       else if (req->in_hdr.status != VIRTIO_I2C_MSG_OK)
-                               failed = true;
-                       else
-                               j++;
+                       if (req->read)
+                               memcpy(msg->buf, bounce_buf, msg->len);
+                       j++;
                }
 
-               i2c_put_dma_safe_msg_buf(reqs[i].buf, &msgs[i], !failed);
+               bounce_buf += virtio_i2c_bounce_size(msg->len);
        }
 
        return j;
@@ -136,14 +200,31 @@ static int virtio_i2c_xfer(struct i2c_adapter *adap, struct i2c_msg *msgs,
 {
        struct virtio_i2c *vi = i2c_get_adapdata(adap);
        struct virtqueue *vq = vi->vq;
-       struct virtio_i2c_req *reqs;
-       int count;
+       struct virtio_i2c_xfer *xfer;
+       size_t alloc_size;
+       int i, count;
+
+       alloc_size = struct_size(xfer, reqs, num);
+       if (check_add_overflow(alloc_size,
+                              dma_get_cache_alignment() - 1,
+                              &alloc_size)) /* padding for PTR_ALIGN() */
+               return -EOVERFLOW;
+       for (i = 0; i < num; i++) {
+               if (check_add_overflow(alloc_size,
+                                      virtio_i2c_bounce_size(msgs[i].len),
+                                      &alloc_size))
+                       return -EOVERFLOW;
+       }
 
-       reqs = kzalloc_objs(*reqs, num);
-       if (!reqs)
+       xfer = kzalloc(alloc_size, GFP_KERNEL);
+       if (!xfer)
                return -ENOMEM;
 
-       count = virtio_i2c_prepare_reqs(vq, reqs, msgs, num);
+       kref_init(&xfer->ref);
+       xfer->bounce_buf_base = PTR_ALIGN((u8 *)(xfer->reqs + num),
+                                         dma_get_cache_alignment());
+
+       count = virtio_i2c_prepare_xfer(vq, xfer, msgs, num);
        if (!count)
                goto err_free;
 
@@ -157,10 +238,10 @@ static int virtio_i2c_xfer(struct i2c_adapter *adap, struct i2c_msg *msgs,
         */
        virtqueue_kick(vq);
 
-       count = virtio_i2c_complete_reqs(vq, reqs, msgs, count);
+       count = virtio_i2c_complete_xfer(xfer, msgs, count);
 
 err_free:
-       kfree(reqs);
+       kref_put(&xfer->ref, virtio_i2c_xfer_release); /* drop the caller's ref */
        return count;
 }
 
-- 
2.54.0


Reply via email to