We'll need to pass extra information for buffer registration apart from
the iovec, so add a flag to struct io_uring_rsrc_update2 indicating that
its data field points to an extended registration structure, i.e.
struct io_uring_reg_buffer. To do a normal registration the user has to
set the target_fd and dmabuf_fd fields to -1; any other combination is
currently rejected.

Signed-off-by: Pavel Begunkov <[email protected]>
---
 include/uapi/linux/io_uring.h | 13 ++++++++-
 io_uring/rsrc.c               | 53 +++++++++++++++++++++++++++--------
 2 files changed, 54 insertions(+), 12 deletions(-)

diff --git a/include/uapi/linux/io_uring.h b/include/uapi/linux/io_uring.h
index deb772222b6d..f64d1f246b93 100644
--- a/include/uapi/linux/io_uring.h
+++ b/include/uapi/linux/io_uring.h
@@ -765,15 +765,26 @@ struct io_uring_rsrc_update {
        __aligned_u64 data;
 };
 
+/* struct io_uring_rsrc_update2::flags */
+enum io_uring_rsrc_reg_flags {
+       IORING_RSRC_F_EXTENDED_UPDATE           = 1,
+};
+
 struct io_uring_rsrc_update2 {
        __u32 offset;
-       __u32 resv;
+       __u32 flags;
        __aligned_u64 data;
        __aligned_u64 tags;
        __u32 nr;
        __u32 resv2;
 };
 
+struct io_uring_reg_buffer {
+       __aligned_u64           iov_uaddr;
+       __s32                   target_fd;
+       __s32                   dmabuf_fd;
+};
+
 /* Skip updating fd indexes set to this value in the fd table */
 #define IORING_REGISTER_FILES_SKIP     (-2)
 
diff --git a/io_uring/rsrc.c b/io_uring/rsrc.c
index 21548942e80d..691f9645d04c 100644
--- a/io_uring/rsrc.c
+++ b/io_uring/rsrc.c
@@ -27,7 +27,8 @@ struct io_rsrc_update {
        u32                             offset;
 };
 
-static struct io_rsrc_node *io_sqe_buffer_register(struct io_ring_ctx *ctx,
+static struct io_rsrc_node *
+io_sqe_buffer_register(struct io_ring_ctx *ctx, struct io_uring_reg_buffer *rb,
                        struct iovec *iov, struct page **last_hpage);
 
 /* only define max */
@@ -234,6 +235,8 @@ static int __io_sqe_files_update(struct io_ring_ctx *ctx,
 
        if (!ctx->file_table.data.nr)
                return -ENXIO;
+       if (up->flags)
+               return -EINVAL;
        if (up->offset + nr_args > ctx->file_table.data.nr)
                return -EINVAL;
 
@@ -288,10 +291,18 @@ static int __io_sqe_files_update(struct io_ring_ctx *ctx,
        return done ? done : err;
 }
 
+static inline void io_default_reg_buf(struct io_uring_reg_buffer *rb)
+{
+       memset(rb, 0, sizeof(*rb));
+       rb->target_fd = -1;
+       rb->dmabuf_fd = -1;
+}
+
 static int __io_sqe_buffers_update(struct io_ring_ctx *ctx,
                                   struct io_uring_rsrc_update2 *up,
                                   unsigned int nr_args)
 {
+       bool extended_entry = up->flags & IORING_RSRC_F_EXTENDED_UPDATE;
        u64 __user *tags = u64_to_user_ptr(up->tags);
        struct iovec fast_iov, *iov;
        struct page *last_hpage = NULL;
@@ -302,14 +313,32 @@ static int __io_sqe_buffers_update(struct io_ring_ctx *ctx,
 
        if (!ctx->buf_table.nr)
                return -ENXIO;
+       if (up->flags & ~IORING_RSRC_F_EXTENDED_UPDATE)
+               return -EINVAL;
        if (up->offset + nr_args > ctx->buf_table.nr)
                return -EINVAL;
 
        for (done = 0; done < nr_args; done++) {
+               struct io_uring_reg_buffer rb;
                struct io_rsrc_node *node;
                u64 tag = 0;
 
-               uvec = u64_to_user_ptr(user_data);
+               if (extended_entry) {
+                       if (copy_from_user(&rb, u64_to_user_ptr(user_data),
+                                          sizeof(rb)))
+                               return -EFAULT;
+                       user_data += sizeof(rb);
+               } else {
+                       io_default_reg_buf(&rb);
+                       rb.iov_uaddr = user_data;
+
+                       if (ctx->compat)
+                               user_data += sizeof(struct compat_iovec);
+                       else
+                               user_data += sizeof(struct iovec);
+               }
+
+               uvec = u64_to_user_ptr(rb.iov_uaddr);
                iov = iovec_from_user(uvec, 1, 1, &fast_iov, ctx->compat);
                if (IS_ERR(iov)) {
                        err = PTR_ERR(iov);
@@ -322,7 +351,7 @@ static int __io_sqe_buffers_update(struct io_ring_ctx *ctx,
                err = io_buffer_validate(iov);
                if (err)
                        break;
-               node = io_sqe_buffer_register(ctx, iov, &last_hpage);
+               node = io_sqe_buffer_register(ctx, &rb, iov, &last_hpage);
                if (IS_ERR(node)) {
                        err = PTR_ERR(node);
                        break;
@@ -337,10 +366,6 @@ static int __io_sqe_buffers_update(struct io_ring_ctx *ctx,
                i = array_index_nospec(up->offset + done, ctx->buf_table.nr);
                io_reset_rsrc_node(ctx, &ctx->buf_table, i);
                ctx->buf_table.nodes[i] = node;
-               if (ctx->compat)
-                       user_data += sizeof(struct compat_iovec);
-               else
-                       user_data += sizeof(struct iovec);
        }
        return done ? done : err;
 }
@@ -375,7 +400,7 @@ int io_register_files_update(struct io_ring_ctx *ctx, void __user *arg,
        memset(&up, 0, sizeof(up));
        if (copy_from_user(&up, arg, sizeof(struct io_uring_rsrc_update)))
                return -EFAULT;
-       if (up.resv || up.resv2)
+       if (up.resv2)
                return -EINVAL;
        return __io_register_rsrc_update(ctx, IORING_RSRC_FILE, &up, nr_args);
 }
@@ -389,7 +414,7 @@ int io_register_rsrc_update(struct io_ring_ctx *ctx, void __user *arg,
                return -EINVAL;
        if (copy_from_user(&up, arg, sizeof(up)))
                return -EFAULT;
-       if (!up.nr || up.resv || up.resv2)
+       if (!up.nr || up.resv2)
                return -EINVAL;
        return __io_register_rsrc_update(ctx, type, &up, up.nr);
 }
@@ -493,7 +518,7 @@ int io_files_update(struct io_kiocb *req, unsigned int issue_flags)
        up2.data = up->arg;
        up2.nr = 0;
        up2.tags = 0;
-       up2.resv = 0;
+       up2.flags = 0;
        up2.resv2 = 0;
 
        if (up->offset == IORING_FILE_INDEX_ALLOC) {
@@ -778,6 +803,7 @@ bool io_check_coalesce_buffer(struct page **page_array, int nr_pages,
 }
 
 static struct io_rsrc_node *io_sqe_buffer_register(struct io_ring_ctx *ctx,
+                                                  struct io_uring_reg_buffer *rb,
                                                   struct iovec *iov,
                                                   struct page **last_hpage)
 {
@@ -790,6 +816,9 @@ static struct io_rsrc_node *io_sqe_buffer_register(struct io_ring_ctx *ctx,
        struct io_imu_folio_data data;
        bool coalesced = false;
 
+       if (rb->dmabuf_fd != -1 || rb->target_fd != -1)
+               return NULL;
+
        if (!iov->iov_base)
                return NULL;
 
@@ -887,6 +916,7 @@ int io_sqe_buffers_register(struct io_ring_ctx *ctx, void __user *arg,
                memset(iov, 0, sizeof(*iov));
 
        for (i = 0; i < nr_args; i++) {
+               struct io_uring_reg_buffer rb;
                struct io_rsrc_node *node;
                u64 tag = 0;
 
@@ -913,7 +943,8 @@ int io_sqe_buffers_register(struct io_ring_ctx *ctx, void __user *arg,
                        }
                }
 
-               node = io_sqe_buffer_register(ctx, iov, &last_hpage);
+               io_default_reg_buf(&rb);
+               node = io_sqe_buffer_register(ctx, &rb, iov, &last_hpage);
                if (IS_ERR(node)) {
                        ret = PTR_ERR(node);
                        break;
-- 
2.52.0

Reply via email to