Commit a663b3c47ab1 ("i2c: virtio: Avoid hang by using interruptible
completion wait") switched virtio_i2c_complete_reqs() to
wait_for_completion_interruptible() so a stuck device cannot hang a
task forever. That left a use-after-free: if the wait returns early on
a signal, virtio_i2c_xfer() frees reqs and DMA bounce buffers while the
device may still hold virtqueue tokens pointing at &reqs[i] and DMA
into read buffers. When those requests complete later,
virtio_i2c_msg_done() calls complete() on freed memory.

Waiting uninterruptibly for every completion before freeing avoids the
UAF but can hang the caller indefinitely if the virtio side never
completes the request. The virtio spec provides no way to cancel an
in-flight transfer, so an uninterruptible wait is not an acceptable
tradeoff.

This commit makes two changes:

- Manage the freeing of the xfer allocations via kref, and ensure that
  each in-flight request holds a reference. This fixes the
  use-after-free by ensuring that the virtio device has a valid
  location to write to until the request completes. This will cause a
  memory leak in cases where the device hangs, but that is much
  preferable to memory corruption.

- Use wait_for_completion_killable() instead of _interruptible(). Even
  partial I2C transactions can have side effects, so the only time it
  makes sense to interrupt a transaction is when a process needs to be
  killed. Most existing I2C drivers don't support interruption at all,
  so this should not break userspace applications. This also addresses
  issues with Go programs accessing devices via the I2C userspace API,
  since the Go runtime stochastically sends SIGURG to running threads;
  leaving this as _interruptible() may cause partial side effects from
  which it is impossible to cleanly restart.
Signed-off-by: Gavin Li <[email protected]> --- drivers/i2c/busses/i2c-virtio.c | 89 ++++++++++++++++++++++++--------- 1 file changed, 64 insertions(+), 25 deletions(-) diff --git a/drivers/i2c/busses/i2c-virtio.c b/drivers/i2c/busses/i2c-virtio.c index 726c162cabd86..f7320a67a3409 100644 --- a/drivers/i2c/busses/i2c-virtio.c +++ b/drivers/i2c/busses/i2c-virtio.c @@ -13,6 +13,7 @@ #include <linux/err.h> #include <linux/i2c.h> #include <linux/kernel.h> +#include <linux/kref.h> #include <linux/module.h> #include <linux/virtio.h> #include <linux/virtio_ids.h> @@ -31,39 +32,77 @@ struct virtio_i2c { struct virtqueue *vq; }; +struct virtio_i2c_xfer; + /** * struct virtio_i2c_req - the virtio I2C request structure + * @xfer: owning transfer + * @msg: copy of the I2C message for virtio_i2c_xfer_release * @completion: completion of virtio I2C message * @out_hdr: the OUT header of the virtio I2C message * @buf: the buffer into which data is read, or from which it's written * @in_hdr: the IN header of the virtio I2C message */ struct virtio_i2c_req { + struct virtio_i2c_xfer *xfer; + struct i2c_msg msg; struct completion completion; struct virtio_i2c_out_hdr out_hdr ____cacheline_aligned; uint8_t *buf ____cacheline_aligned; struct virtio_i2c_in_hdr in_hdr ____cacheline_aligned; }; +/** + * struct virtio_i2c_xfer - a queued I2C transfer + * @ref: one ref for the caller, plus one per in-flight virtqueue request + * @num: number of messages + * @reqs: the virtio I2C requests + */ +struct virtio_i2c_xfer { + struct kref ref; + int num; + struct virtio_i2c_req reqs[]; +}; + +static void virtio_i2c_xfer_release(struct kref *ref) +{ + struct virtio_i2c_xfer *xfer = container_of(ref, struct virtio_i2c_xfer, ref); + int i; + + for (i = 0; i < xfer->num; i++) { + struct virtio_i2c_req *req = &xfer->reqs[i]; + i2c_put_dma_safe_msg_buf(req->buf, &req->msg, false); + } + + kfree(xfer); +} + static void virtio_i2c_msg_done(struct virtqueue *vq) { struct virtio_i2c_req *req; unsigned int 
len; - while ((req = virtqueue_get_buf(vq, &len))) + while ((req = virtqueue_get_buf(vq, &len))) { complete(&req->completion); + kref_put(&req->xfer->ref, virtio_i2c_xfer_release); + } } static int virtio_i2c_prepare_reqs(struct virtqueue *vq, - struct virtio_i2c_req *reqs, + struct virtio_i2c_xfer *xfer, struct i2c_msg *msgs, int num) { struct scatterlist *sgs[3], out_hdr, msg_buf, in_hdr; + struct virtio_i2c_req *reqs = xfer->reqs; int i; + kref_init(&xfer->ref); + for (i = 0; i < num; i++) { int outcnt = 0, incnt = 0; + reqs[i].xfer = xfer; + reqs[i].msg = msgs[i]; init_completion(&reqs[i].completion); /* @@ -99,36 +138,36 @@ static int virtio_i2c_prepare_reqs(struct virtqueue *vq, if (virtqueue_add_sgs(vq, sgs, outcnt, incnt, &reqs[i], GFP_KERNEL)) { i2c_put_dma_safe_msg_buf(reqs[i].buf, &msgs[i], false); + reqs[i].buf = NULL; /* prevent free by virtio_i2c_xfer_release */ break; } + + kref_get(&xfer->ref); /* released in virtio_i2c_msg_done() */ } + xfer->num = i; return i; } -static int virtio_i2c_complete_reqs(struct virtqueue *vq, - struct virtio_i2c_req *reqs, - struct i2c_msg *msgs, int num) +static int virtio_i2c_complete_reqs(struct virtio_i2c_xfer *xfer) { - bool failed = false; - int i, j = 0; + struct virtio_i2c_req *reqs = xfer->reqs; + int i, fail_index = -1; - for (i = 0; i < num; i++) { + for (i = 0; i < xfer->num; i++) { struct virtio_i2c_req *req = &reqs[i]; - - if (!failed) { - if (wait_for_completion_interruptible(&req->completion)) - failed = true; - else if (req->in_hdr.status != VIRTIO_I2C_MSG_OK) - failed = true; - else - j++; + if (wait_for_completion_killable(&req->completion)) { + return -EINTR; + } else if (req->in_hdr.status != VIRTIO_I2C_MSG_OK) { + /* Don't break yet. Try to wait until all requests complete. 
*/ + if (fail_index < 0) + fail_index = i; } - - i2c_put_dma_safe_msg_buf(reqs[i].buf, &msgs[i], !failed); + i2c_put_dma_safe_msg_buf(req->buf, &req->msg, fail_index < 0); + req->buf = NULL; /* prevent free by virtio_i2c_xfer_release */ } - return j; + return fail_index >= 0 ? fail_index : xfer->num; /* number of successful transactions */ } static int virtio_i2c_xfer(struct i2c_adapter *adap, struct i2c_msg *msgs, @@ -136,14 +175,14 @@ static int virtio_i2c_xfer(struct i2c_adapter *adap, struct i2c_msg *msgs, { struct virtio_i2c *vi = i2c_get_adapdata(adap); struct virtqueue *vq = vi->vq; - struct virtio_i2c_req *reqs; + struct virtio_i2c_xfer *xfer; int count; - reqs = kcalloc(num, sizeof(*reqs), GFP_KERNEL); - if (!reqs) + xfer = kzalloc(struct_size(xfer, reqs, num), GFP_KERNEL); + if (!xfer) return -ENOMEM; - count = virtio_i2c_prepare_reqs(vq, reqs, msgs, num); + count = virtio_i2c_prepare_reqs(vq, xfer, msgs, num); if (!count) goto err_free; @@ -157,10 +196,10 @@ static int virtio_i2c_xfer(struct i2c_adapter *adap, struct i2c_msg *msgs, */ virtqueue_kick(vq); - count = virtio_i2c_complete_reqs(vq, reqs, msgs, count); + count = virtio_i2c_complete_reqs(xfer); err_free: - kfree(reqs); + kref_put(&xfer->ref, virtio_i2c_xfer_release); return count; } -- 2.54.0
