From: "Mike Rapoport (Microsoft)" <[email protected]>

mfill_atomic() passes a lot of parameters down to its callees.

Aggregate them all into an mfill_state structure and pass that structure
to the functions that implement the various UFFDIO_* commands.
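
In other words, instead of the call pattern

        err = mfill_atomic_pte(dst_pmd, dst_vma, dst_addr,
                               src_addr, flags, &folio);

each of the per-PTE helpers now receives a single pointer:

        err = mfill_atomic_pte(&state);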

Tracking the state in a structure will allow moving the code that retries
the data copy for UFFDIO_COPY into mfill_atomic_pte_copy(), making the
loop in mfill_atomic() identical for all UFFDIO operations on PTE-mapped
memory.
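
For illustration, the retry could then live entirely inside
mfill_atomic_pte_copy(), roughly along these lines (a sketch of the
intended direction only, not code from this patch):

        static int mfill_atomic_pte_copy(struct mfill_state *state)
        {
                /* ... allocate the folio and attempt the atomic copy ... */
                ret = mfill_copy_folio_locked(folio, state->src_addr);
                if (unlikely(ret)) {
                        /*
                         * Atomic copy failed: stash the folio in
                         * state->folio, drop the locks, copy from
                         * userspace with copy_from_user() and retry,
                         * instead of bouncing -ENOENT back to the
                         * loop in mfill_atomic().
                         */
                }
                /* ... install the folio at state->dst_addr ... */
        }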

The mfill_state definition is deliberately kept local to mm/userfaultfd.c,
which is why shmem_mfill_atomic_pte() is not converted to take it.
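
The boundary is visible in mfill_atomic_pte(), which keeps unpacking the
fields for the shmem helper; schematically (identifiers as in the diff
below, where foliop is &state->folio):

        /* mm/userfaultfd.c knows about mfill_state ... */
        err = mfill_atomic_pte_copy(state);

        /* ... while mm/shmem.c keeps the unpacked argument list */
        err = shmem_mfill_atomic_pte(dst_pmd, dst_vma, dst_addr,
                                     src_addr, flags, foliop);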

Signed-off-by: Mike Rapoport (Microsoft) <[email protected]>
---
 mm/userfaultfd.c | 149 ++++++++++++++++++++++++++---------------------
 1 file changed, 83 insertions(+), 66 deletions(-)

diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c
index a0885d543f22..6a0697c93ff4 100644
--- a/mm/userfaultfd.c
+++ b/mm/userfaultfd.c
@@ -20,6 +20,20 @@
 #include "internal.h"
 #include "swap.h"
 
+struct mfill_state {
+       struct userfaultfd_ctx *ctx;
+       unsigned long src_start;
+       unsigned long dst_start;
+       unsigned long len;
+       uffd_flags_t flags;
+
+       struct vm_area_struct *vma;
+       unsigned long src_addr;
+       unsigned long dst_addr;
+       struct folio *folio;
+       pmd_t *pmd;
+};
+
 static __always_inline
 bool validate_dst_vma(struct vm_area_struct *dst_vma, unsigned long dst_end)
 {
@@ -272,17 +286,17 @@ static int mfill_copy_folio_locked(struct folio *folio, unsigned long src_addr)
        return ret;
 }
 
-static int mfill_atomic_pte_copy(pmd_t *dst_pmd,
-                                struct vm_area_struct *dst_vma,
-                                unsigned long dst_addr,
-                                unsigned long src_addr,
-                                uffd_flags_t flags,
-                                struct folio **foliop)
+static int mfill_atomic_pte_copy(struct mfill_state *state)
 {
-       int ret;
+       struct vm_area_struct *dst_vma = state->vma;
+       unsigned long dst_addr = state->dst_addr;
+       unsigned long src_addr = state->src_addr;
+       uffd_flags_t flags = state->flags;
+       pmd_t *dst_pmd = state->pmd;
        struct folio *folio;
+       int ret;
 
-       if (!*foliop) {
+       if (!state->folio) {
                ret = -ENOMEM;
                folio = vma_alloc_folio(GFP_HIGHUSER_MOVABLE, 0, dst_vma,
                                        dst_addr);
@@ -294,13 +308,13 @@ static int mfill_atomic_pte_copy(pmd_t *dst_pmd,
                /* fallback to copy_from_user outside mmap_lock */
                if (unlikely(ret)) {
                        ret = -ENOENT;
-                       *foliop = folio;
+                       state->folio = folio;
                        /* don't free the page */
                        goto out;
                }
        } else {
-               folio = *foliop;
-               *foliop = NULL;
+               folio = state->folio;
+               state->folio = NULL;
        }
 
        /*
@@ -357,10 +371,11 @@ static int mfill_atomic_pte_zeroed_folio(pmd_t *dst_pmd,
        return ret;
 }
 
-static int mfill_atomic_pte_zeropage(pmd_t *dst_pmd,
-                                    struct vm_area_struct *dst_vma,
-                                    unsigned long dst_addr)
+static int mfill_atomic_pte_zeropage(struct mfill_state *state)
 {
+       struct vm_area_struct *dst_vma = state->vma;
+       unsigned long dst_addr = state->dst_addr;
+       pmd_t *dst_pmd = state->pmd;
        pte_t _dst_pte, *dst_pte;
        spinlock_t *ptl;
        int ret;
@@ -392,13 +407,14 @@ static int mfill_atomic_pte_zeropage(pmd_t *dst_pmd,
 }
 
 /* Handles UFFDIO_CONTINUE for all shmem VMAs (shared or private). */
-static int mfill_atomic_pte_continue(pmd_t *dst_pmd,
-                                    struct vm_area_struct *dst_vma,
-                                    unsigned long dst_addr,
-                                    uffd_flags_t flags)
+static int mfill_atomic_pte_continue(struct mfill_state *state)
 {
-       struct inode *inode = file_inode(dst_vma->vm_file);
+       struct vm_area_struct *dst_vma = state->vma;
+       unsigned long dst_addr = state->dst_addr;
        pgoff_t pgoff = linear_page_index(dst_vma, dst_addr);
+       struct inode *inode = file_inode(dst_vma->vm_file);
+       uffd_flags_t flags = state->flags;
+       pmd_t *dst_pmd = state->pmd;
        struct folio *folio;
        struct page *page;
        int ret;
@@ -436,15 +452,15 @@ static int mfill_atomic_pte_continue(pmd_t *dst_pmd,
 }
 
 /* Handles UFFDIO_POISON for all non-hugetlb VMAs. */
-static int mfill_atomic_pte_poison(pmd_t *dst_pmd,
-                                  struct vm_area_struct *dst_vma,
-                                  unsigned long dst_addr,
-                                  uffd_flags_t flags)
+static int mfill_atomic_pte_poison(struct mfill_state *state)
 {
-       int ret;
+       struct vm_area_struct *dst_vma = state->vma;
        struct mm_struct *dst_mm = dst_vma->vm_mm;
+       unsigned long dst_addr = state->dst_addr;
+       pmd_t *dst_pmd = state->pmd;
        pte_t _dst_pte, *dst_pte;
        spinlock_t *ptl;
+       int ret;
 
        _dst_pte = make_pte_marker(PTE_MARKER_POISONED);
        ret = -EAGAIN;
@@ -668,22 +684,20 @@ extern ssize_t mfill_atomic_hugetlb(struct userfaultfd_ctx *ctx,
                                    uffd_flags_t flags);
 #endif /* CONFIG_HUGETLB_PAGE */
 
-static __always_inline ssize_t mfill_atomic_pte(pmd_t *dst_pmd,
-                                               struct vm_area_struct *dst_vma,
-                                               unsigned long dst_addr,
-                                               unsigned long src_addr,
-                                               uffd_flags_t flags,
-                                               struct folio **foliop)
+static __always_inline ssize_t mfill_atomic_pte(struct mfill_state *state)
 {
+       struct vm_area_struct *dst_vma = state->vma;
+       unsigned long src_addr = state->src_addr;
+       unsigned long dst_addr = state->dst_addr;
+       struct folio **foliop = &state->folio;
+       uffd_flags_t flags = state->flags;
+       pmd_t *dst_pmd = state->pmd;
        ssize_t err;
 
-       if (uffd_flags_mode_is(flags, MFILL_ATOMIC_CONTINUE)) {
-               return mfill_atomic_pte_continue(dst_pmd, dst_vma,
-                                                dst_addr, flags);
-       } else if (uffd_flags_mode_is(flags, MFILL_ATOMIC_POISON)) {
-               return mfill_atomic_pte_poison(dst_pmd, dst_vma,
-                                              dst_addr, flags);
-       }
+       if (uffd_flags_mode_is(flags, MFILL_ATOMIC_CONTINUE))
+               return mfill_atomic_pte_continue(state);
+       if (uffd_flags_mode_is(flags, MFILL_ATOMIC_POISON))
+               return mfill_atomic_pte_poison(state);
 
        /*
         * The normal page fault path for a shmem will invoke the
@@ -697,12 +711,9 @@ static __always_inline ssize_t mfill_atomic_pte(pmd_t *dst_pmd,
         */
        if (!(dst_vma->vm_flags & VM_SHARED)) {
                if (uffd_flags_mode_is(flags, MFILL_ATOMIC_COPY))
-                       err = mfill_atomic_pte_copy(dst_pmd, dst_vma,
-                                                   dst_addr, src_addr,
-                                                   flags, foliop);
+                       err = mfill_atomic_pte_copy(state);
                else
-                       err = mfill_atomic_pte_zeropage(dst_pmd,
-                                                dst_vma, dst_addr);
+                       err = mfill_atomic_pte_zeropage(state);
        } else {
                err = shmem_mfill_atomic_pte(dst_pmd, dst_vma,
                                             dst_addr, src_addr,
@@ -718,13 +729,21 @@ static __always_inline ssize_t mfill_atomic(struct userfaultfd_ctx *ctx,
                                            unsigned long len,
                                            uffd_flags_t flags)
 {
+       struct mfill_state state = {
+               .ctx = ctx,
+               .dst_start = dst_start,
+               .src_start = src_start,
+               .len = len,
+               .flags = flags,
+
+               .src_addr = src_start,
+               .dst_addr = dst_start,
+       };
        struct mm_struct *dst_mm = ctx->mm;
        struct vm_area_struct *dst_vma;
+       long copied = 0;
        ssize_t err;
        pmd_t *dst_pmd;
-       unsigned long src_addr, dst_addr;
-       long copied;
-       struct folio *folio;
 
        /*
         * Sanitize the command parameters:
@@ -736,10 +755,6 @@ static __always_inline ssize_t mfill_atomic(struct userfaultfd_ctx *ctx,
        VM_WARN_ON_ONCE(src_start + len <= src_start);
        VM_WARN_ON_ONCE(dst_start + len <= dst_start);
 
-       src_addr = src_start;
-       dst_addr = dst_start;
-       copied = 0;
-       folio = NULL;
 retry:
        /*
         * Make sure the vma is not shared, that the dst range is
@@ -790,12 +805,14 @@ static __always_inline ssize_t mfill_atomic(struct userfaultfd_ctx *ctx,
            uffd_flags_mode_is(flags, MFILL_ATOMIC_CONTINUE))
                goto out_unlock;
 
-       while (src_addr < src_start + len) {
-               pmd_t dst_pmdval;
+       state.vma = dst_vma;
 
-               VM_WARN_ON_ONCE(dst_addr >= dst_start + len);
+       while (state.src_addr < src_start + len) {
+               pmd_t dst_pmdval;
+
+               VM_WARN_ON_ONCE(state.dst_addr >= dst_start + len);
 
-               dst_pmd = mm_alloc_pmd(dst_mm, dst_addr);
+               dst_pmd = mm_alloc_pmd(dst_mm, state.dst_addr);
                if (unlikely(!dst_pmd)) {
                        err = -ENOMEM;
                        break;
@@ -827,34 +844,34 @@ static __always_inline ssize_t mfill_atomic(struct userfaultfd_ctx *ctx,
                 * tables under us; pte_offset_map_lock() will deal with that.
                 */
 
-               err = mfill_atomic_pte(dst_pmd, dst_vma, dst_addr,
-                                      src_addr, flags, &folio);
+               state.pmd = dst_pmd;
+               err = mfill_atomic_pte(&state);
                cond_resched();
 
                if (unlikely(err == -ENOENT)) {
                        void *kaddr;
 
                        up_read(&ctx->map_changing_lock);
-                       uffd_mfill_unlock(dst_vma);
-                       VM_WARN_ON_ONCE(!folio);
+                       uffd_mfill_unlock(state.vma);
+                       VM_WARN_ON_ONCE(!state.folio);
 
-                       kaddr = kmap_local_folio(folio, 0);
+                       kaddr = kmap_local_folio(state.folio, 0);
                        err = copy_from_user(kaddr,
-                                            (const void __user *) src_addr,
+                                            (const void __user *)state.src_addr,
                                             PAGE_SIZE);
                        kunmap_local(kaddr);
                        if (unlikely(err)) {
                                err = -EFAULT;
                                goto out;
                        }
-                       flush_dcache_folio(folio);
+                       flush_dcache_folio(state.folio);
                        goto retry;
                } else
-                       VM_WARN_ON_ONCE(folio);
+                       VM_WARN_ON_ONCE(state.folio);
 
                if (!err) {
-                       dst_addr += PAGE_SIZE;
-                       src_addr += PAGE_SIZE;
+                       state.dst_addr += PAGE_SIZE;
+                       state.src_addr += PAGE_SIZE;
                        copied += PAGE_SIZE;
 
                        if (fatal_signal_pending(current))
@@ -866,10 +883,10 @@ static __always_inline ssize_t mfill_atomic(struct userfaultfd_ctx *ctx,
 
 out_unlock:
        up_read(&ctx->map_changing_lock);
-       uffd_mfill_unlock(dst_vma);
+       uffd_mfill_unlock(state.vma);
 out:
-       if (folio)
-               folio_put(folio);
+       if (state.folio)
+               folio_put(state.folio);
        VM_WARN_ON_ONCE(copied < 0);
        VM_WARN_ON_ONCE(err > 0);
        VM_WARN_ON_ONCE(!copied && !err);
-- 
2.51.0

