commit a663b3c47ab1 ("i2c: virtio: Avoid hang by using interruptible
completion wait") switched virtio_i2c_complete_reqs() to
wait_for_completion_interruptible() so a stuck device cannot hang a
task forever. However, that introduced a use-after-free: if the wait
returns early because of a signal, virtio_i2c_xfer() frees reqs and the
DMA bounce buffers while the device may still hold virtqueue tokens
pointing at &reqs[i] and may still DMA into the read buffers. When those
requests complete later, virtio_i2c_msg_done() calls complete() on freed
memory.

Simply waiting uninterruptibly for every completion before freeing
would avoid the UAF, but it could hang the caller indefinitely if the
device never completes a request. Because the virtio spec provides no
way to cancel an in-flight request, such a hang could never be
recovered from, so that is not an acceptable tradeoff either.

This commit makes two changes:

- Manage the lifetime of the xfer allocations with a kref, and make
  each in-flight request hold a reference. This fixes the
  use-after-free by guaranteeing that the memory the virtio device
  writes to stays valid until the request completes. If the device
  never completes a request, that memory is leaked, but a leak is far
  preferable to memory corruption.

- Use wait_for_completion_killable() instead of _interruptible(). Even
  a partially completed I2C transaction can have side effects, so the
  only time it makes sense to abort one is when the process is being
  killed. Most existing I2C drivers don't support interruption at all,
  so this should not break userspace applications. This also fixes
  problems with Go programs that access devices via the I2C userspace
  API: the Go runtime sends SIGURG to running threads at arbitrary
  times (for asynchronous preemption), so with _interruptible() a
  transfer can be aborted midway, leaving partial side effects from
  which it is impossible to restart cleanly.

Fixes: a663b3c47ab1 ("i2c: virtio: Avoid hang by using interruptible completion wait")
Signed-off-by: Gavin Li <[email protected]>
---
 drivers/i2c/busses/i2c-virtio.c | 89 ++++++++++++++++++++++++---------
 1 file changed, 64 insertions(+), 25 deletions(-)

diff --git a/drivers/i2c/busses/i2c-virtio.c b/drivers/i2c/busses/i2c-virtio.c
index 726c162cabd86..f7320a67a3409 100644
--- a/drivers/i2c/busses/i2c-virtio.c
+++ b/drivers/i2c/busses/i2c-virtio.c
@@ -13,6 +13,7 @@
 #include <linux/err.h>
 #include <linux/i2c.h>
 #include <linux/kernel.h>
+#include <linux/kref.h>
 #include <linux/module.h>
 #include <linux/virtio.h>
 #include <linux/virtio_ids.h>
@@ -31,39 +32,77 @@ struct virtio_i2c {
        struct virtqueue *vq;
 };
 
+struct virtio_i2c_xfer;
+
 /**
  * struct virtio_i2c_req - the virtio I2C request structure
+ * @xfer: owning transfer; its refcount keeps this request alive while queued
+ * @msg: copy of the caller's message, used by virtio_i2c_xfer_release() to free @buf
  * @completion: completion of virtio I2C message
  * @out_hdr: the OUT header of the virtio I2C message
  * @buf: the buffer into which data is read, or from which it's written
  * @in_hdr: the IN header of the virtio I2C message
  */
 struct virtio_i2c_req {
+       struct virtio_i2c_xfer *xfer;
+       struct i2c_msg msg;
        struct completion completion;
        struct virtio_i2c_out_hdr out_hdr       ____cacheline_aligned;
        uint8_t *buf                            ____cacheline_aligned;
        struct virtio_i2c_in_hdr in_hdr         ____cacheline_aligned;
 };
 
+/**
+ * struct virtio_i2c_xfer - a queued I2C transfer, freed when its last ref is dropped
+ * @ref: one ref for the caller, plus one per in-flight virtqueue request
+ * @num: number of requests successfully queued by virtio_i2c_prepare_reqs()
+ * @reqs: per-message requests; only the first @num are released on free
+ */
+struct virtio_i2c_xfer {
+       struct kref ref;
+       int num;
+       struct virtio_i2c_req reqs[];
+};
+
+static void virtio_i2c_xfer_release(struct kref *ref)
+{
+       struct virtio_i2c_xfer *xfer = container_of(ref, struct virtio_i2c_xfer, ref);
+       int i;
+
+       for (i = 0; i < xfer->num; i++) { /* later reqs were never queued */
+               struct virtio_i2c_req *req = &xfer->reqs[i];
+
+               i2c_put_dma_safe_msg_buf(req->buf, &req->msg, false);
+       }
+       kfree(xfer);
+}
+
 static void virtio_i2c_msg_done(struct virtqueue *vq)
 {
        struct virtio_i2c_req *req;
        unsigned int len;

-       while ((req = virtqueue_get_buf(vq, &len)))
+       while ((req = virtqueue_get_buf(vq, &len))) {
                complete(&req->completion);
+               kref_put(&req->xfer->ref, virtio_i2c_xfer_release); /* may free req and xfer */
+       }
 }
 
 static int virtio_i2c_prepare_reqs(struct virtqueue *vq,
-                                  struct virtio_i2c_req *reqs,
+                                  struct virtio_i2c_xfer *xfer,
                                   struct i2c_msg *msgs, int num)
 {
        struct scatterlist *sgs[3], out_hdr, msg_buf, in_hdr;
+       struct virtio_i2c_req *reqs = xfer->reqs;
        int i;

+       kref_init(&xfer->ref);
+
        for (i = 0; i < num; i++) {
                int outcnt = 0, incnt = 0;

+               reqs[i].xfer = xfer;
+               reqs[i].msg = msgs[i];
                init_completion(&reqs[i].completion);

                /*
@@ -99,36 +138,36 @@ static int virtio_i2c_prepare_reqs(struct virtqueue *vq,

+               kref_get(&xfer->ref); /* before add: virtio_i2c_msg_done() may run at once */
                if (virtqueue_add_sgs(vq, sgs, outcnt, incnt, &reqs[i], GFP_KERNEL)) {
+                       kref_put(&xfer->ref, virtio_i2c_xfer_release); /* not the last ref */
                        i2c_put_dma_safe_msg_buf(reqs[i].buf, &msgs[i], false);
+                       reqs[i].buf = NULL; /* prevent free by virtio_i2c_xfer_release */
                        break;
                }
        }

+       xfer->num = i;
        return i;
 }
 
-static int virtio_i2c_complete_reqs(struct virtqueue *vq,
-                                   struct virtio_i2c_req *reqs,
-                                   struct i2c_msg *msgs, int num)
+static int virtio_i2c_complete_reqs(struct virtio_i2c_xfer *xfer)
 {
-       bool failed = false;
-       int i, j = 0;
+       struct virtio_i2c_req *reqs = xfer->reqs;
+       int i, fail_index = -1;

-       for (i = 0; i < num; i++) {
+       for (i = 0; i < xfer->num; i++) {
                struct virtio_i2c_req *req = &reqs[i];
-
-               if (!failed) {
-                       if (wait_for_completion_interruptible(&req->completion))
-                               failed = true;
-                       else if (req->in_hdr.status != VIRTIO_I2C_MSG_OK)
-                               failed = true;
-                       else
-                               j++;
+               if (wait_for_completion_killable(&req->completion))
+                       return -EINTR;
+               if (req->in_hdr.status != VIRTIO_I2C_MSG_OK) {
+                       /* Don't break yet: keep waiting for the remaining requests. */
+                       if (fail_index < 0)
+                               fail_index = i;
                }
-
-               i2c_put_dma_safe_msg_buf(reqs[i].buf, &msgs[i], !failed);
+               i2c_put_dma_safe_msg_buf(req->buf, &req->msg, fail_index < 0);
+               req->buf = NULL; /* prevent free by virtio_i2c_xfer_release */
        }

-       return j;
+       return fail_index >= 0 ? fail_index : xfer->num; /* number of successful transactions */
 }
 
 static int virtio_i2c_xfer(struct i2c_adapter *adap, struct i2c_msg *msgs,
@@ -136,14 +175,14 @@ static int virtio_i2c_xfer(struct i2c_adapter *adap, struct i2c_msg *msgs,
 {
        struct virtio_i2c *vi = i2c_get_adapdata(adap);
        struct virtqueue *vq = vi->vq;
-       struct virtio_i2c_req *reqs;
+       struct virtio_i2c_xfer *xfer;
        int count;

-       reqs = kcalloc(num, sizeof(*reqs), GFP_KERNEL);
-       if (!reqs)
+       xfer = kzalloc(struct_size(xfer, reqs, num), GFP_KERNEL);
+       if (!xfer)
                return -ENOMEM;

-       count = virtio_i2c_prepare_reqs(vq, reqs, msgs, num);
+       count = virtio_i2c_prepare_reqs(vq, xfer, msgs, num);
        if (!count)
                goto err_free;

@@ -157,10 +196,10 @@ static int virtio_i2c_xfer(struct i2c_adapter *adap, struct i2c_msg *msgs,
         */
        virtqueue_kick(vq);

-       count = virtio_i2c_complete_reqs(vq, reqs, msgs, count);
+       count = virtio_i2c_complete_reqs(xfer);

 err_free:
-       kfree(reqs);
+       kref_put(&xfer->ref, virtio_i2c_xfer_release); /* queued reqs hold their own refs */
        return count;
 }
 
-- 
2.54.0


Reply via email to