Commit a663b3c47ab1 ("i2c: virtio: Avoid hang by using interruptible
completion wait") switched virtio_i2c_complete_reqs() to
wait_for_completion_interruptible() so a stuck device cannot hang a
task forever. That left a use-after-free: if the wait returns early on
a signal, virtio_i2c_xfer() frees reqs and DMA bounce buffers while the
device may still hold virtqueue tokens pointing at &reqs[i] and DMA
into read buffers. When those requests complete later,
virtio_i2c_msg_done() calls complete() on freed memory.

Waiting uninterruptibly for every completion before freeing avoids the
UAF but can hang the caller indefinitely if the virtio side never
completes the request. Performing a queue reset / device reset is
possible in this scenario, but adds complexity.

This commit manages the freeing of the xfer allocations via kref, and
ensures that each in-flight request holds a reference. This fixes the
use-after-free by ensuring that the virtio device has a valid location
to write to until the request completes. This will cause a memory leak
in cases where the device hangs, but that is much preferable to memory
corruption.

Additionally, force usage of a bounce buffer even if the i2c_msg buf is
DMA-safe, since the buffer passed to the virtqueue must remain valid
even if the transfer is interrupted. Remove usage of
i2c_get_dma_safe_msg_buf() since it may pass through msg->buf directly.
All bounce buffers are part of the single xfer allocation, so there is
no additional allocation overhead.

Signed-off-by: Gavin Li <[email protected]> --- Changes in v5: - DMA-align all bounce buffers Changes in v4: - Pack bounce buffers into a single allocation after reqs[] - Remove per-request buf pointer and xfer->num - Remove req.msg, track message direction with req->read - Simplify xfer release to a single kfree() - Restore using j to track successful transfers in _complete_xfer() --- drivers/i2c/busses/i2c-virtio.c | 135 +++++++++++++++++++++++++------- 1 file changed, 108 insertions(+), 27 deletions(-) diff --git a/drivers/i2c/busses/i2c-virtio.c b/drivers/i2c/busses/i2c-virtio.c index 5da6fef92bec3..a5602865102d9 100644 --- a/drivers/i2c/busses/i2c-virtio.c +++ b/drivers/i2c/busses/i2c-virtio.c @@ -10,10 +10,13 @@ #include <linux/acpi.h> #include <linux/completion.h> +#include <linux/dma-mapping.h> #include <linux/err.h> #include <linux/i2c.h> #include <linux/kernel.h> +#include <linux/kref.h> #include <linux/module.h> +#include <linux/overflow.h> #include <linux/virtio.h> #include <linux/virtio_ids.h> #include <linux/virtio_config.h> @@ -31,39 +34,80 @@ struct virtio_i2c { struct virtqueue *vq; }; +struct virtio_i2c_xfer; + /** * struct virtio_i2c_req - the virtio I2C request structure + * @xfer: owning transfer * @completion: completion of virtio I2C message + * @read: true if this is a read message (I2C_M_RD is set) * @out_hdr: the OUT header of the virtio I2C message - * @buf: the buffer into which data is read, or from which it's written * @in_hdr: the IN header of the virtio I2C message */ struct virtio_i2c_req { + struct virtio_i2c_xfer *xfer; struct completion completion; + bool read; struct virtio_i2c_out_hdr out_hdr ____cacheline_aligned; - uint8_t *buf ____cacheline_aligned; struct virtio_i2c_in_hdr in_hdr ____cacheline_aligned; }; +/** + * struct virtio_i2c_xfer - a queued I2C transfer + * @ref: one ref for the caller, plus one per in-flight virtqueue request + * @bounce_buf_base: start of bounce buffer region + * @reqs: the virtio I2C requests + * 
+ * Allocation layout: + * - struct virtio_i2c_xfer xfer + * - struct virtio_i2c_req reqs[num] + * - padding to dma_get_cache_alignment() + * - u8 bounce_buf[virtio_i2c_bounce_size(msgs[0].len)] + * ... + * - u8 bounce_buf[virtio_i2c_bounce_size(msgs[num-1].len)] + */ +struct virtio_i2c_xfer { + struct kref ref; + u8 *bounce_buf_base; + struct virtio_i2c_req reqs[]; +}; + +static size_t virtio_i2c_bounce_size(unsigned int len) +{ + return ALIGN(len, dma_get_cache_alignment()); +} + +static void virtio_i2c_xfer_release(struct kref *ref) +{ + struct virtio_i2c_xfer *xfer = container_of(ref, struct virtio_i2c_xfer, ref); + kfree(xfer); +} + static void virtio_i2c_msg_done(struct virtqueue *vq) { struct virtio_i2c_req *req; unsigned int len; - while ((req = virtqueue_get_buf(vq, &len))) + while ((req = virtqueue_get_buf(vq, &len))) { complete(&req->completion); + kref_put(&req->xfer->ref, virtio_i2c_xfer_release); + } } -static int virtio_i2c_prepare_reqs(struct virtqueue *vq, - struct virtio_i2c_req *reqs, +static int virtio_i2c_prepare_xfer(struct virtqueue *vq, + struct virtio_i2c_xfer *xfer, struct i2c_msg *msgs, int num) { struct scatterlist *sgs[3], out_hdr, msg_buf, in_hdr; + struct virtio_i2c_req *reqs = xfer->reqs; + u8 *bounce_buf = xfer->bounce_buf_base; int i; for (i = 0; i < num; i++) { int outcnt = 0, incnt = 0; + reqs[i].xfer = xfer; + reqs[i].read = !!(msgs[i].flags & I2C_M_RD); init_completion(&reqs[i].completion); /* @@ -82,23 +126,31 @@ static int virtio_i2c_prepare_reqs(struct virtqueue *vq, sgs[outcnt++] = &out_hdr; if (msgs[i].len) { - reqs[i].buf = i2c_get_dma_safe_msg_buf(&msgs[i], 1); - if (!reqs[i].buf) - break; + /* + * Even if msg->flags has I2C_M_DMA_SAFE set, a bounce + * buffer is required because the transfer may be + * interrupted, after which msg->buf is no longer valid. 
+ */ + if (!(msgs[i].flags & I2C_M_RD)) + memcpy(bounce_buf, msgs[i].buf, msgs[i].len); - sg_init_one(&msg_buf, reqs[i].buf, msgs[i].len); + sg_init_one(&msg_buf, bounce_buf, msgs[i].len); if (msgs[i].flags & I2C_M_RD) sgs[outcnt + incnt++] = &msg_buf; else sgs[outcnt++] = &msg_buf; } + bounce_buf += virtio_i2c_bounce_size(msgs[i].len); sg_init_one(&in_hdr, &reqs[i].in_hdr, sizeof(reqs[i].in_hdr)); sgs[outcnt + incnt++] = &in_hdr; + /* This reference is released in virtio_i2c_msg_done(). */ + kref_get(&xfer->ref); + if (virtqueue_add_sgs(vq, sgs, outcnt, incnt, &reqs[i], GFP_KERNEL)) { - i2c_put_dma_safe_msg_buf(reqs[i].buf, &msgs[i], false); + kref_put(&xfer->ref, virtio_i2c_xfer_release); break; } } @@ -106,26 +158,38 @@ static int virtio_i2c_prepare_reqs(struct virtqueue *vq, return i; } -static int virtio_i2c_complete_reqs(struct virtqueue *vq, - struct virtio_i2c_req *reqs, - struct i2c_msg *msgs, int num) +static int virtio_i2c_complete_xfer(struct virtio_i2c_xfer *xfer, + struct i2c_msg *msgs, + int num) { + struct virtio_i2c_req *reqs = xfer->reqs; + u8 *bounce_buf = xfer->bounce_buf_base; bool failed = false; int i, j = 0; for (i = 0; i < num; i++) { struct virtio_i2c_req *req = &reqs[i]; + struct i2c_msg *msg = &msgs[i]; + + if (wait_for_completion_interruptible(&req->completion)) + return -EINTR; + + if (req->in_hdr.status != VIRTIO_I2C_MSG_OK) { + /* + * Don't break yet. Try to wait until all requests + * complete to ensure that the virtqueue has enough + * descriptor slots for the next transfer. 
+ */ + failed = true; + } if (!failed) { - if (wait_for_completion_interruptible(&req->completion)) - failed = true; - else if (req->in_hdr.status != VIRTIO_I2C_MSG_OK) - failed = true; - else - j++; + if (req->read) + memcpy(msg->buf, bounce_buf, msg->len); + j++; } - i2c_put_dma_safe_msg_buf(reqs[i].buf, &msgs[i], !failed); + bounce_buf += virtio_i2c_bounce_size(msg->len); } return j; @@ -136,14 +200,31 @@ static int virtio_i2c_xfer(struct i2c_adapter *adap, struct i2c_msg *msgs, { struct virtio_i2c *vi = i2c_get_adapdata(adap); struct virtqueue *vq = vi->vq; - struct virtio_i2c_req *reqs; - int count; + struct virtio_i2c_xfer *xfer; + size_t alloc_size; + int i, count; + + alloc_size = struct_size(xfer, reqs, num); + if (check_add_overflow(alloc_size, + dma_get_cache_alignment() - 1, + &alloc_size)) /* padding for PTR_ALIGN() */ + return -EOVERFLOW; + for (i = 0; i < num; i++) { + if (check_add_overflow(alloc_size, + virtio_i2c_bounce_size(msgs[i].len), + &alloc_size)) + return -EOVERFLOW; + } - reqs = kzalloc_objs(*reqs, num); - if (!reqs) + xfer = kzalloc(alloc_size, GFP_KERNEL); + if (!xfer) return -ENOMEM; - count = virtio_i2c_prepare_reqs(vq, reqs, msgs, num); + kref_init(&xfer->ref); + xfer->bounce_buf_base = PTR_ALIGN((u8 *)(xfer->reqs + num), + dma_get_cache_alignment()); + + count = virtio_i2c_prepare_xfer(vq, xfer, msgs, num); if (!count) goto err_free; @@ -157,10 +238,10 @@ static int virtio_i2c_xfer(struct i2c_adapter *adap, struct i2c_msg *msgs, */ virtqueue_kick(vq); - count = virtio_i2c_complete_reqs(vq, reqs, msgs, count); + count = virtio_i2c_complete_xfer(xfer, msgs, count); err_free: - kfree(reqs); + kref_put(&xfer->ref, virtio_i2c_xfer_release); return count; } -- 2.54.0
