From: Pavel Begunkov <[email protected]>

It's tedious and error-prone to manually check ctx->account_mem before
each call to io_{,un}account_mem(). Instead, make the functions take
struct io_ring_ctx directly and do the check themselves. They get
inlined in any case.

Signed-off-by: Pavel Begunkov <[email protected]>
---
 fs/io_uring.c | 59 ++++++++++++++++++++++-----------------------------
 1 file changed, 25 insertions(+), 34 deletions(-)

diff --git a/fs/io_uring.c b/fs/io_uring.c
index 3fd884b4e0be..f47f7abe19eb 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -2704,14 +2704,19 @@ static int io_sq_offload_start(struct io_ring_ctx *ctx,
        return ret;
 }
 
-static void io_unaccount_mem(struct user_struct *user, unsigned long nr_pages)
+static void io_unaccount_mem(struct io_ring_ctx *ctx, unsigned long nr_pages)
 {
-       atomic_long_sub(nr_pages, &user->locked_vm);
+       if (ctx->account_mem)
+               atomic_long_sub(nr_pages, &ctx->user->locked_vm);
 }
 
-static int io_account_mem(struct user_struct *user, unsigned long nr_pages)
+static int io_account_mem(struct io_ring_ctx *ctx, unsigned long nr_pages)
 {
        unsigned long page_limit, cur_pages, new_pages;
+       struct user_struct *user = ctx->user;
+
+       if (!ctx->account_mem)
+               return 0;
 
        /* Don't allow more pages than we can safely lock */
        page_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
@@ -2773,8 +2778,7 @@ static int io_sqe_buffer_unregister(struct io_ring_ctx *ctx)
                for (j = 0; j < imu->nr_bvecs; j++)
                        put_page(imu->bvec[j].bv_page);
 
-               if (ctx->account_mem)
-                       io_unaccount_mem(ctx->user, imu->nr_bvecs);
+               io_unaccount_mem(ctx, imu->nr_bvecs);
                kvfree(imu->bvec);
                imu->nr_bvecs = 0;
        }
@@ -2857,11 +2861,9 @@ static int io_sqe_buffer_register(struct io_ring_ctx *ctx, void __user *arg,
                start = ubuf >> PAGE_SHIFT;
                nr_pages = end - start;
 
-               if (ctx->account_mem) {
-                       ret = io_account_mem(ctx->user, nr_pages);
-                       if (ret)
-                               goto err;
-               }
+               ret = io_account_mem(ctx, nr_pages);
+               if (ret)
+                       goto err;
 
                ret = 0;
                if (!pages || nr_pages > got_pages) {
@@ -2874,8 +2876,7 @@ static int io_sqe_buffer_register(struct io_ring_ctx *ctx, void __user *arg,
                                        GFP_KERNEL);
                        if (!pages || !vmas) {
                                ret = -ENOMEM;
-                               if (ctx->account_mem)
-                                       io_unaccount_mem(ctx->user, nr_pages);
+                               io_unaccount_mem(ctx, nr_pages);
                                goto err;
                        }
                        got_pages = nr_pages;
@@ -2885,8 +2886,7 @@ static int io_sqe_buffer_register(struct io_ring_ctx *ctx, void __user *arg,
                                                GFP_KERNEL);
                ret = -ENOMEM;
                if (!imu->bvec) {
-                       if (ctx->account_mem)
-                               io_unaccount_mem(ctx->user, nr_pages);
+                       io_unaccount_mem(ctx, nr_pages);
                        goto err;
                }
 
@@ -2919,8 +2919,7 @@ static int io_sqe_buffer_register(struct io_ring_ctx *ctx, void __user *arg,
                                for (j = 0; j < pret; j++)
                                        put_page(pages[j]);
                        }
-                       if (ctx->account_mem)
-                               io_unaccount_mem(ctx->user, nr_pages);
+                       io_unaccount_mem(ctx, nr_pages);
                        kvfree(imu->bvec);
                        goto err;
                }
@@ -3009,9 +3008,7 @@ static void io_ring_ctx_free(struct io_ring_ctx *ctx)
        io_mem_free(ctx->cq_ring);
 
        percpu_ref_exit(&ctx->refs);
-       if (ctx->account_mem)
-               io_unaccount_mem(ctx->user,
-                               ring_pages(ctx->sq_entries, ctx->cq_entries));
+       io_unaccount_mem(ctx, ring_pages(ctx->sq_entries, ctx->cq_entries));
        free_uid(ctx->user);
        kfree(ctx);
 }
@@ -3253,7 +3250,6 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p)
 {
        struct user_struct *user = NULL;
        struct io_ring_ctx *ctx;
-       bool account_mem;
        int ret;
 
        if (!entries || entries > IORING_MAX_ENTRIES)
@@ -3269,29 +3265,24 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p)
        p->cq_entries = 2 * p->sq_entries;
 
        user = get_uid(current_user());
-       account_mem = !capable(CAP_IPC_LOCK);
-
-       if (account_mem) {
-               ret = io_account_mem(user,
-                               ring_pages(p->sq_entries, p->cq_entries));
-               if (ret) {
-                       free_uid(user);
-                       return ret;
-               }
-       }
 
        ctx = io_ring_ctx_alloc(p);
        if (!ctx) {
-               if (account_mem)
-                       io_unaccount_mem(user, ring_pages(p->sq_entries,
-                                                               p->cq_entries));
                free_uid(user);
                return -ENOMEM;
        }
+
        ctx->compat = in_compat_syscall();
-       ctx->account_mem = account_mem;
+       ctx->account_mem = !capable(CAP_IPC_LOCK);
        ctx->user = user;
 
+       ret = io_account_mem(ctx, ring_pages(p->sq_entries, p->cq_entries));
+       if (ret) {
+               free_uid(user);
+               kfree(ctx);
+               return ret;
+       }
+
        ret = io_allocate_scq_urings(ctx, p);
        if (ret)
                goto err;
-- 
2.22.0

Reply via email to