hmm_range_fault() currently triggers page faults from inside the page-table
walk callbacks: hmm_vma_walk_pmd(), hmm_vma_walk_pud(),
hmm_vma_walk_hugetlb_entry() and the pte-level helper all call
hmm_vma_fault(), which in turn calls handle_mm_fault() while the walk is
still in progress.  Each call site has to unwind its own locking first:
the pte spinlock is dropped explicitly by each caller, and the hugetlb
path manually drops and retakes hugetlb_vma_lock_read() around the fault
to dodge a deadlock against the walk framework's unconditional unlock.
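
For reference, the pre-patch hugetlb path looks roughly like this (a
condensed sketch of the lines this patch removes; see the diff below for
the exact code):

  /* pre-patch hmm_vma_walk_hugetlb_entry(), condensed */
  if (required_fault) {
          spin_unlock(ptl);
          hugetlb_vma_unlock_read(vma);   /* avoid deadlock vs. the fault */
          ret = hmm_vma_fault(addr, end, required_fault, walk);
          hugetlb_vma_lock_read(vma);     /* walk framework unlocks later */
          return ret;
  }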

This layering does not extend cleanly to fault handlers that may release
mmap_lock (VM_FAULT_RETRY, VM_FAULT_COMPLETED). If mmap_lock is dropped
while walk_page_range() is mid-traversal, the VMA can be freed before the
walk framework reaches its matching hugetlb_vma_unlock_read(), turning
that unlock into a use-after-free.
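
The hazardous interleaving, assuming a fault handler that returns
VM_FAULT_RETRY (an illustrative sequence, not an observed trace):

  walk_page_range()                       /* hugetlb_vma_lock_read(vma) */
    hmm_vma_walk_hugetlb_entry()
      hugetlb_vma_unlock_read(vma)
      hmm_vma_fault()
        handle_mm_fault()                 /* drops mmap_lock on retry */
                                          /* vma can be freed here    */
      hugetlb_vma_lock_read(vma)          /* use-after-free */
  hugetlb_vma_unlock_read(vma)            /* framework's unlock: UAF too */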

Split the responsibilities the way get_user_pages() does. Walk callbacks
become inspect-only: when they detect a range that needs to be faulted in,
they record it in struct hmm_vma_walk and return a private sentinel
(HMM_FAULT_PENDING). The outer loop in hmm_range_fault() then drops out of
walk_page_range(), invokes a new helper hmm_do_fault() that calls
handle_mm_fault() with only mmap_lock held, and restarts the walk so the
now-present entries are collected into hmm_pfns.
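
The resulting control flow in hmm_range_fault() becomes (condensed from
the diff below; error handling elided):

  do {
          ret = walk_page_range(mm, hmm_vma_walk.last, range->end,
                                &hmm_walk_ops, &hmm_vma_walk);
          if (ret == HMM_FAULT_PENDING)
                  /* only mmap_lock held; returns -EBUSY on success */
                  ret = hmm_do_fault(mm, &hmm_vma_walk);
  } while (ret == -EBUSY);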

No functional change for existing callers. As a side effect the hugetlb
callback no longer needs the hugetlb_vma_{un}lock_read() dance, and every
fault-path exit from the callbacks now funnels through the same
record-and-return step, with the pte spinlock already released at the
call site. This refactor is also a precursor for a follow-up patch that
adds a variant of hmm_range_fault() able to tolerate handle_mm_fault()
dropping mmap_lock.
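
Drivers keep using the documented retry pattern, e.g. (a sketch following
the hmm_range_fault() contract described in Documentation/mm/hmm.rst;
notifier setup elided):

  do {
          range.notifier_seq = mmu_interval_read_begin(range.notifier);
          mmap_read_lock(mm);
          ret = hmm_range_fault(&range);
          mmap_read_unlock(mm);
  } while (ret == -EBUSY);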

Signed-off-by: Stanislav Kinsburskii <[email protected]>
---
 mm/hmm.c |  118 +++++++++++++++++++++++++++++++++++++++-----------------------
 1 file changed, 75 insertions(+), 43 deletions(-)

diff --git a/mm/hmm.c b/mm/hmm.c
index 5955f2f0c83db..2b157fcbc2928 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -33,8 +33,17 @@
 struct hmm_vma_walk {
        struct hmm_range        *range;
        unsigned long           last;
+       unsigned long           end;
+       unsigned int            required_fault;
 };
 
+/*
+ * Internal sentinel returned by walk callbacks when they need a page fault.
+ * The callback stores end/required_fault in hmm_vma_walk; the outer loop
+ * consumes the sentinel and never propagates it to the caller.
+ */
+#define HMM_FAULT_PENDING      -EAGAIN
+
 enum {
        HMM_NEED_FAULT = 1 << 0,
        HMM_NEED_WRITE_FAULT = 1 << 1,
@@ -60,37 +69,25 @@ static int hmm_pfns_fill(unsigned long addr, unsigned long end,
 }
 
 /*
- * hmm_vma_fault() - fault in a range lacking valid pmd or pte(s)
- * @addr: range virtual start address (inclusive)
- * @end: range virtual end address (exclusive)
- * @required_fault: HMM_NEED_* flags
- * @walk: mm_walk structure
- * Return: -EBUSY after page fault, or page fault error
+ * hmm_record_fault() - record a range that needs to be faulted in
  *
- * This function will be called whenever pmd_none() or pte_none() returns true,
- * or whenever there is no page directory covering the virtual address range.
+ * Called by the walk callbacks when they discover that part of the range
+ * needs a page fault.  The callback records what to fault and returns
+ * HMM_FAULT_PENDING; the outer loop in hmm_range_fault() drops back out of
+ * walk_page_range() and invokes handle_mm_fault() from a context where no
+ * page-table or hugetlb_vma_lock is held.
  */
-static int hmm_vma_fault(unsigned long addr, unsigned long end,
-                        unsigned int required_fault, struct mm_walk *walk)
+static int hmm_record_fault(unsigned long addr, unsigned long end,
+                           unsigned int required_fault,
+                           struct mm_walk *walk)
 {
        struct hmm_vma_walk *hmm_vma_walk = walk->private;
-       struct vm_area_struct *vma = walk->vma;
-       unsigned int fault_flags = FAULT_FLAG_REMOTE;
 
        WARN_ON_ONCE(!required_fault);
        hmm_vma_walk->last = addr;
-
-       if (required_fault & HMM_NEED_WRITE_FAULT) {
-               if (!(vma->vm_flags & VM_WRITE))
-                       return -EPERM;
-               fault_flags |= FAULT_FLAG_WRITE;
-       }
-
-       for (; addr < end; addr += PAGE_SIZE)
-               if (handle_mm_fault(vma, addr, fault_flags, NULL) &
-                   VM_FAULT_ERROR)
-                       return -EFAULT;
-       return -EBUSY;
+       hmm_vma_walk->end = end;
+       hmm_vma_walk->required_fault = required_fault;
+       return HMM_FAULT_PENDING;
 }
 
 static unsigned int hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
@@ -174,7 +171,7 @@ static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
                return hmm_pfns_fill(addr, end, range, HMM_PFN_ERROR);
        }
        if (required_fault)
-               return hmm_vma_fault(addr, end, required_fault, walk);
+               return hmm_record_fault(addr, end, required_fault, walk);
        return hmm_pfns_fill(addr, end, range, 0);
 }
 
@@ -209,7 +206,7 @@ static int hmm_vma_handle_pmd(struct mm_walk *walk, unsigned long addr,
        required_fault =
                hmm_range_need_fault(hmm_vma_walk, hmm_pfns, npages, cpu_flags);
        if (required_fault)
-               return hmm_vma_fault(addr, end, required_fault, walk);
+               return hmm_record_fault(addr, end, required_fault, walk);
 
        pfn = pmd_pfn(pmd) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
        for (i = 0; addr < end; addr += PAGE_SIZE, i++, pfn++) {
@@ -328,7 +325,7 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
 fault:
        pte_unmap(ptep);
        /* Fault any virtual address we were asked to fault */
-       return hmm_vma_fault(addr, end, required_fault, walk);
+       return hmm_record_fault(addr, end, required_fault, walk);
 }
 
 #ifdef CONFIG_ARCH_ENABLE_THP_MIGRATION
@@ -371,7 +368,7 @@ static int hmm_vma_handle_absent_pmd(struct mm_walk *walk, unsigned long start,
                                              npages, 0);
        if (required_fault) {
                if (softleaf_is_device_private(entry))
-                       return hmm_vma_fault(addr, end, required_fault, walk);
+                       return hmm_record_fault(addr, end, required_fault, walk);
                else
                        return -EFAULT;
        }
@@ -517,7 +514,7 @@ static int hmm_vma_walk_pud(pud_t *pudp, unsigned long start, unsigned long end,
                                                      npages, cpu_flags);
                if (required_fault) {
                        spin_unlock(ptl);
-                       return hmm_vma_fault(addr, end, required_fault, walk);
+                       return hmm_record_fault(addr, end, required_fault, walk);
                }
 
                pfn = pud_pfn(pud) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
@@ -564,21 +561,8 @@ static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
        required_fault =
                hmm_pte_need_fault(hmm_vma_walk, pfn_req_flags, cpu_flags);
        if (required_fault) {
-               int ret;
-
                spin_unlock(ptl);
-               hugetlb_vma_unlock_read(vma);
-               /*
-                * Avoid deadlock: drop the vma lock before calling
-                * hmm_vma_fault(), which will itself potentially take and
-                * drop the vma lock. This is also correct from a
-                * protection point of view, because there is no further
-                * use here of either pte or ptl after dropping the vma
-                * lock.
-                */
-               ret = hmm_vma_fault(addr, end, required_fault, walk);
-               hugetlb_vma_lock_read(vma);
-               return ret;
+               return hmm_record_fault(addr, end, required_fault, walk);
        }
 
        pfn = pte_pfn(entry) + ((start & ~hmask) >> PAGE_SHIFT);
@@ -637,6 +621,44 @@ static const struct mm_walk_ops hmm_walk_ops = {
        .walk_lock      = PGWALK_RDLOCK,
 };
 
+/*
+ * hmm_do_fault - fault in a range recorded by a walk callback
+ *
+ * Called from the outer loop in hmm_range_fault() after a callback
+ * returned HMM_FAULT_PENDING.  At this point we hold only mmap_lock;
+ * the page-table spinlock and any hugetlb_vma_lock acquired by the walk
+ * framework have already been released by the unwind.
+ *
+ * Returns -EBUSY on success (all pages faulted, caller should re-walk).
+ * Returns a negative errno on failure.
+ */
+static int hmm_do_fault(struct mm_struct *mm,
+                       struct hmm_vma_walk *hmm_vma_walk)
+{
+       unsigned long addr = hmm_vma_walk->last;
+       unsigned long end = hmm_vma_walk->end;
+       unsigned int required_fault = hmm_vma_walk->required_fault;
+       unsigned int fault_flags = FAULT_FLAG_REMOTE;
+       struct vm_area_struct *vma;
+
+       vma = vma_lookup(mm, addr);
+       if (!vma)
+               return -EFAULT;
+
+       if (required_fault & HMM_NEED_WRITE_FAULT) {
+               if (!(vma->vm_flags & VM_WRITE))
+                       return -EPERM;
+               fault_flags |= FAULT_FLAG_WRITE;
+       }
+
+       for (; addr < end; addr += PAGE_SIZE)
+               if (handle_mm_fault(vma, addr, fault_flags, NULL) &
+                   VM_FAULT_ERROR)
+                       return -EFAULT;
+
+       return -EBUSY;
+}
+
 /**
  * hmm_range_fault - try to fault some address in a virtual address range
  * @range:     argument structure
@@ -674,6 +696,16 @@ int hmm_range_fault(struct hmm_range *range)
                        return -EBUSY;
                ret = walk_page_range(mm, hmm_vma_walk.last, range->end,
                                      &hmm_walk_ops, &hmm_vma_walk);
+               /*
+                * When HMM_FAULT_PENDING is returned a walk callback
+                * recorded a range that needs handle_mm_fault();
+                * hmm_do_fault() runs the fault outside walk_page_range()
+                * (so no page-table or hugetlb_vma_lock is held) and
+                * returns -EBUSY so the loop re-walks and picks up the
+                * now-present entries.
+                */
+               if (ret == HMM_FAULT_PENDING)
+                       ret = hmm_do_fault(mm, &hmm_vma_walk);
                /*
                 * When -EBUSY is returned the loop restarts with
                 * hmm_vma_walk.last set to an address that has not been stored