hmm_range_fault() requires the caller to hold the mmap read lock for the
duration of the call. This is incompatible with mappings whose fault
handler may need to release the mmap lock, notably userfaultfd-managed
regions, where handle_mm_fault() can return VM_FAULT_RETRY or
VM_FAULT_COMPLETED after dropping the lock. Drivers that need to populate
device page tables for such mappings currently have no way to do so.

Add hmm_range_fault_unlockable(), modelled on the int *locked pattern
used by get_user_pages_remote() in mm/gup.c. Callers take the mmap read
lock, set locked = 1 and pass &locked. The function may transiently drop
and reacquire the mmap lock while servicing retryable faults, and sets
*locked = 0 if a fault completes with the mmap lock dropped. In the
latter case the caller must reacquire the lock and restart the walk with
a fresh mmu_interval_read_begin() sequence.
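
To make the caller contract concrete, here is a userspace model of it.
It is a sketch, not kernel code: struct model_mm, model_lock(),
model_unlock(), model_range_fault_unlockable() and model_populate() are
hypothetical, and an int flag stands in for the mmap read lock. The stub
drops the "lock" and clears *locked once per scripted completed fault;
the caller loop follows the pattern documented in hmm.rst.

```c
#include <assert.h>
#include <errno.h>

/*
 * Hypothetical userspace model of the *locked contract. None of these
 * names exist in the kernel; lock_held stands in for the mmap read lock.
 */
struct model_mm {
	int lock_held;		/* 1 while the model "mmap lock" is held */
	int completed_faults;	/* faults that complete with the lock dropped */
};

static void model_lock(struct model_mm *mm)
{
	assert(!mm->lock_held);
	mm->lock_held = 1;
}

static void model_unlock(struct model_mm *mm)
{
	assert(mm->lock_held);
	mm->lock_held = 0;
}

/* Stand-in for hmm_range_fault_unlockable(); entered with the lock held. */
static int model_range_fault_unlockable(struct model_mm *mm, int *locked)
{
	assert(mm->lock_held && *locked == 1);
	if (mm->completed_faults > 0) {
		mm->completed_faults--;
		mm->lock_held = 0;	/* fault completed, lock already dropped */
		*locked = 0;
		return 0;		/* caller must restart the walk */
	}
	return 0;			/* range populated, lock still held */
}

/* Caller pattern: returns the number of walks, or a negative errno. */
static int model_populate(struct model_mm *mm)
{
	int walks = 0;
	int locked, ret;

again:
	walks++;
	/* A real caller takes a fresh mmu_interval_read_begin() here. */
	locked = 1;
	model_lock(mm);
	ret = model_range_fault_unlockable(mm, &locked);
	if (locked)
		model_unlock(mm);
	if (ret) {
		if (ret == -EBUSY)
			goto again;
		return ret;
	}
	if (!locked)
		goto again;	/* PFNs from this walk are stale */
	return walks;
}
```

With two scripted completed faults, model_populate() walks three times
and returns with the model lock released.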

The change is confined to hmm_do_fault() and the outer loop in
hmm_range_fault_unlockable(). hmm_do_fault() sets FAULT_FLAG_ALLOW_RETRY |
FAULT_FLAG_KILLABLE only when locked is non-NULL. VM_FAULT_RETRY is
handled internally: mmap_lock is reacquired and the fault is retried with
FAULT_FLAG_TRIED set, so fault handlers make forward progress.
VM_FAULT_COMPLETED is translated into *locked = 0 plus a private return
code (HMM_FAULT_UNLOCKED) consumed by the outer loop, which in turn
returns 0 (or -EINTR if a fatal signal is pending) to the caller.
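
The retry handling can be modelled the same way. Again this is a
hypothetical userspace sketch: the M_* flags, enum m_result and the m_*()
helpers are made up, handle_mm_fault() is replaced by one scripted result
per call, and the fatal-signal check is omitted. It shows RETRY being
absorbed internally (lock reacquired, M_TRIED set, same page retried),
COMPLETED clearing *locked and returning a private sentinel that the
outer step turns into 0, and retry never being enabled when locked is
NULL.

```c
#include <assert.h>
#include <errno.h>
#include <stddef.h>

/* Hypothetical model flags and results, not the kernel's definitions. */
#define M_ALLOW_RETRY		0x1
#define M_TRIED			0x2
#define M_FAULT_UNLOCKED	(-ENOLCK)	/* private sentinel */

enum m_result { M_OK, M_RETRY, M_COMPLETED, M_ERROR };

struct m_ctx {
	const enum m_result *script;	/* scripted outcome per fault */
	size_t next;			/* index of the next outcome */
	int lock_held;			/* stands in for the mmap read lock */
	int tried_seen;			/* a fault ran with M_TRIED set */
};

/* Stand-in for handle_mm_fault(): RETRY and COMPLETED drop the lock. */
static enum m_result m_handle_fault(struct m_ctx *c, unsigned int flags)
{
	enum m_result r = c->script[c->next++];

	assert(c->lock_held);
	if (flags & M_TRIED)
		c->tried_seen = 1;
	/* Without M_ALLOW_RETRY a handler may not drop the lock: fail. */
	if (!(flags & M_ALLOW_RETRY) && (r == M_RETRY || r == M_COMPLETED))
		r = M_ERROR;
	if (r == M_RETRY || r == M_COMPLETED)
		c->lock_held = 0;
	return r;
}

/* Model of hmm_do_fault(): one scripted fault per page. */
static int m_do_fault(struct m_ctx *c, int npages, int *locked)
{
	unsigned int flags = locked ? M_ALLOW_RETRY : 0;
	int page = 0;

	while (page < npages) {
		enum m_result r = m_handle_fault(c, flags);

		if (r == M_COMPLETED) {
			*locked = 0;		/* lock is gone */
			return M_FAULT_UNLOCKED;
		}
		if (r == M_ERROR)
			return -EFAULT;
		if (r == M_RETRY) {
			c->lock_held = 1;	/* reacquire the lock */
			flags |= M_TRIED;	/* and retry the same page */
			continue;
		}
		page++;
	}
	return -EBUSY;				/* outer loop re-walks */
}

/* Outer loop step: the sentinel never reaches the caller. */
static int m_fault_pass(struct m_ctx *c, int npages, int *locked)
{
	int ret = m_do_fault(c, npages, locked);

	return ret == M_FAULT_UNLOCKED ? 0 : ret;
}
```

In the model, a handler that would need to drop the lock fails when
M_ALLOW_RETRY is clear, loosely mirroring how userfaultfd cannot block a
fault that does not allow retry.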

The previous refactor that moved page fault handling out of the
page-table walk callbacks is what makes this change small. Faults now
run after walk_page_range() has unwound, with only the mmap lock held, so
dropping it does not interact with the walker's pte spinlock or
hugetlb_vma_lock. Hugetlb regions therefore participate in the
unlockable path uniformly with PTE- and PMD-level mappings; no special
case is required.
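
The record-then-fault ordering that makes this safe can also be sketched
as a userspace model. It is hypothetical: struct w_state, w_walk_range(),
w_do_fault() and w_range_fault() are made up, a bool plays the pte
spinlock, and each walk records only the first missing page. The assert
in w_do_fault() captures the property this patch relies on: by the time
a fault runs, the walk has unwound and no walk-level lock is held.

```c
#include <assert.h>
#include <errno.h>
#include <stdbool.h>

#define W_FAULT_PENDING	(-EAGAIN)	/* like HMM_FAULT_PENDING */
#define W_NPAGES	4

/* Hypothetical userspace model; none of these names exist in mm/hmm.c. */
struct w_state {
	bool present[W_NPAGES];		/* which model pages are populated */
	bool ptl_held;			/* stands in for the pte spinlock */
	int fault_start, fault_end;	/* range recorded by the callback */
	int faults;			/* number of fault passes */
};

/* Walk: holds the "pte lock", records the first hole, then unwinds. */
static int w_walk_range(struct w_state *w)
{
	int i, ret = 0;

	w->ptl_held = true;
	for (i = 0; i < W_NPAGES; i++) {
		if (!w->present[i]) {
			w->fault_start = i;
			w->fault_end = i + 1;
			ret = W_FAULT_PENDING;
			break;
		}
	}
	w->ptl_held = false;	/* unwinding releases walk-level locks */
	return ret;
}

/* Fault: runs after the walk has unwound, with no pte lock held. */
static int w_do_fault(struct w_state *w)
{
	int i;

	assert(!w->ptl_held);	/* dropping the mmap lock here is safe */
	for (i = w->fault_start; i < w->fault_end; i++)
		w->present[i] = true;
	w->faults++;
	return -EBUSY;		/* re-walk to pick up the new entries */
}

/* Outer loop, loosely shaped like the one in this patch. */
static int w_range_fault(struct w_state *w)
{
	int ret;

	do {
		ret = w_walk_range(w);
		if (ret == W_FAULT_PENDING)
			ret = w_do_fault(w);
	} while (ret == -EBUSY);
	return ret;
}
```

Starting with pages 1 and 3 missing, w_range_fault() needs two fault
passes and three walks before it returns 0.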

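hmm_range_fault() becomes a thin wrapper that passes locked == NULL, so
existing callers never have the mmap lock dropped. Fault errors are now
translated with vm_fault_to_errno(), so VM_FAULT_OOM is reported as
-ENOMEM rather than -EFAULT. hmm_range_fault() remains EXPORT_SYMBOL.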

Documentation/mm/hmm.rst is updated with a description of the new API
and the recommended caller pattern.

Signed-off-by: Stanislav Kinsburskii <[email protected]>
---
 Documentation/mm/hmm.rst |   66 ++++++++++++++++++++++++++++
 include/linux/hmm.h      |    1 
 mm/hmm.c                 |  109 ++++++++++++++++++++++++++++++++++++++++++----
 3 files changed, 166 insertions(+), 10 deletions(-)

diff --git a/Documentation/mm/hmm.rst b/Documentation/mm/hmm.rst
index 7d61b7a8b65b..751ef3fb0434 100644
--- a/Documentation/mm/hmm.rst
+++ b/Documentation/mm/hmm.rst
@@ -208,6 +208,72 @@ invalidate() callback. That lock must be held before calling
 mmu_interval_read_retry() to avoid any race with a concurrent CPU page table
 update.
 
+Dropping the mmap lock during page faults
+=========================================
+
+Some VMAs have fault handlers that need to release the mmap lock while
+servicing a fault (for example, regions managed by ``userfaultfd``).
+``hmm_range_fault()`` cannot be used on such mappings because it must hold the
+mmap lock for the duration of the call. Drivers that need to support them
+should call::
+
+  int hmm_range_fault_unlockable(struct hmm_range *range, int *locked);
+
+The caller sets ``*locked = 1`` and holds ``mmap_read_lock`` before the call.
+If ``handle_mm_fault()`` returns ``VM_FAULT_RETRY``, the function reacquires
+the mmap lock internally and retries the fault with ``FAULT_FLAG_TRIED`` set.
+If ``handle_mm_fault()`` returns ``VM_FAULT_COMPLETED``, the mmap lock has
+already been dropped; the function sets ``*locked = 0`` and returns. A return
+value of ``0`` then means the caller must reacquire the lock and restart the
+walk from ``range->start`` with a fresh notifier sequence; ``-EINTR`` means a
+fatal signal is pending and the caller should abort without unlocking. When
+``locked`` is ``NULL`` the lock is held for the whole call, exactly as with
+``hmm_range_fault()``.
+
+A typical caller looks like this::
+
+ int driver_populate_range_unlockable(...)
+ {
+      struct hmm_range range;
+      int locked;
+      ...
+
+      range.notifier = &interval_sub;
+      range.start = ...;
+      range.end = ...;
+      range.hmm_pfns = ...;
+
+      if (!mmget_not_zero(interval_sub.mm))
+          return -EFAULT;
+
+ again:
+      range.notifier_seq = mmu_interval_read_begin(&interval_sub);
+      locked = 1;
+      mmap_read_lock(mm);
+      ret = hmm_range_fault_unlockable(&range, &locked);
+      if (locked)
+          mmap_read_unlock(mm);
+      if (ret) {
+          if (ret == -EBUSY)
+              goto again;
+          return ret;
+      }
+      if (!locked)
+          goto again;
+
+      take_lock(driver->update);
+      if (mmu_interval_read_retry(&interval_sub, range.notifier_seq)) {
+          release_lock(driver->update);
+          goto again;
+      }
+
+      /* Use pfns array content to update device page table,
+       * under the update lock */
+
+      release_lock(driver->update);
+      return 0;
+ }
+
 Leverage default_flags and pfn_flags_mask
 =========================================
 
diff --git a/include/linux/hmm.h b/include/linux/hmm.h
index db75ffc949a7..46e581865c48 100644
--- a/include/linux/hmm.h
+++ b/include/linux/hmm.h
@@ -123,6 +123,7 @@ struct hmm_range {
  * Please see Documentation/mm/hmm.rst for how to use the range API.
  */
 int hmm_range_fault(struct hmm_range *range);
+int hmm_range_fault_unlockable(struct hmm_range *range, int *locked);
 
 /*
  * HMM_RANGE_DEFAULT_TIMEOUT - default timeout (ms) when waiting for a range
diff --git a/mm/hmm.c b/mm/hmm.c
index 2129b1ee4c35..1869b6df23a6 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -32,6 +32,7 @@
 
 struct hmm_vma_walk {
        struct hmm_range        *range;
+       int                     *locked;
        unsigned long           last;
        unsigned long           end;
        unsigned int            required_fault;
@@ -44,6 +45,14 @@ struct hmm_vma_walk {
  */
 #define HMM_FAULT_PENDING      -EAGAIN
 
+/*
+ * Internal sentinel returned by hmm_do_fault() when handle_mm_fault()
+ * completes a page fault with the mmap lock dropped. hmm_do_fault() sets
+ * *locked = 0; the outer loop consumes the sentinel and never propagates it
+ * to the caller.
+ */
+#define HMM_FAULT_UNLOCKED     -ENOLCK
+
 enum {
        HMM_NEED_FAULT = 1 << 0,
        HMM_NEED_WRITE_FAULT = 1 << 1,
@@ -73,9 +82,9 @@ static int hmm_pfns_fill(unsigned long addr, unsigned long end,
  *
  * Called by the walk callbacks when they discover that part of the range
  * needs a page fault.  The callback records what to fault and returns
- * HMM_FAULT_PENDING; the outer loop in hmm_range_fault() drops back out of
- * walk_page_range() and invokes handle_mm_fault() from a context where no
- * page-table or hugetlb_vma_lock is held.
+ * HMM_FAULT_PENDING; the outer loop in hmm_range_fault_unlockable() drops
+ * back out of walk_page_range() and invokes handle_mm_fault() from a context
+ * where no page-table or hugetlb_vma_lock is held.
  */
 static int hmm_record_fault(unsigned long addr, unsigned long end,
                            unsigned int required_fault,
@@ -624,7 +633,7 @@ static const struct mm_walk_ops hmm_walk_ops = {
 /*
  * hmm_do_fault - fault in a range recorded by a walk callback
  *
- * Called from the outer loop in hmm_range_fault() after a callback
+ * Called from the outer loop in hmm_range_fault_unlockable() after a callback
  * returned HMM_FAULT_PENDING.  At this point we hold only mmap_lock;
  * the page-table spinlock and any hugetlb_vma_lock acquired by the walk
  * framework have already been released by the unwind.
@@ -641,6 +650,14 @@ static int hmm_do_fault(struct mm_struct *mm,
        unsigned int fault_flags = FAULT_FLAG_REMOTE;
        struct vm_area_struct *vma;
 
+       if (hmm_vma_walk->locked)
+               fault_flags |= FAULT_FLAG_ALLOW_RETRY | FAULT_FLAG_KILLABLE;
+
+retry:
+       if ((fault_flags & FAULT_FLAG_TRIED) &&
+           fatal_signal_pending(current))
+               return -EINTR;
+
        vma = vma_lookup(mm, addr);
        if (!vma)
                return -EFAULT;
@@ -651,10 +668,30 @@ static int hmm_do_fault(struct mm_struct *mm,
                fault_flags |= FAULT_FLAG_WRITE;
        }
 
-       for (; addr < end; addr += PAGE_SIZE)
-               if (handle_mm_fault(vma, addr, fault_flags, NULL) &
-                   VM_FAULT_ERROR)
-                       return -EFAULT;
+       for (; addr < end; addr += PAGE_SIZE) {
+               vm_fault_t ret;
+
+               ret = handle_mm_fault(vma, addr, fault_flags, NULL);
+
+               if (ret & VM_FAULT_COMPLETED) {
+                       *hmm_vma_walk->locked = 0;
+                       return HMM_FAULT_UNLOCKED;
+               }
+
+               if (ret & VM_FAULT_ERROR) {
+                       int err = vm_fault_to_errno(ret, 0);
+
+                       if (err)
+                               return err;
+                       BUG();
+               }
+
+               if (ret & VM_FAULT_RETRY) {
+                       mmap_read_lock(mm);
+                       fault_flags |= FAULT_FLAG_TRIED;
+                       goto retry;
+               }
+       }
 
        return -EBUSY;
 }
@@ -677,11 +714,57 @@ static int hmm_do_fault(struct mm_struct *mm,
  *
  * This is similar to get_user_pages(), except that it can read the page tables
  * without mutating them (ie causing faults).
+ *
+ * The mmap lock must be held by the caller and will remain held on return.
+ * For a variant that allows the mmap lock to be dropped during faults (e.g.,
+ * for userfaultfd support), see hmm_range_fault_unlockable().
  */
 int hmm_range_fault(struct hmm_range *range)
+{
+       return hmm_range_fault_unlockable(range, NULL);
+}
+EXPORT_SYMBOL(hmm_range_fault);
+
+/**
+ * hmm_range_fault_unlockable - fault in a range, possibly dropping the mmap
+ *                              lock
+ * @range:     argument structure
+ * @locked:    pointer to caller's lock state, or %NULL
+ *
+ * Behaves like hmm_range_fault(), but allows handle_mm_fault() to drop the
+ * mmap read lock during a fault.  This makes the function usable on mappings
+ * whose fault path may release the lock (for example, userfaultfd-managed
+ * regions).
+ *
+ * If @locked is %NULL the mmap lock is never released and the function
+ * behaves exactly like hmm_range_fault().
+ *
+ * If @locked is non-%NULL the caller must hold mmap_read_lock and set
+ * *@locked = 1 before the call.  Retryable faults may drop and reacquire the
+ * mmap lock internally before retrying the fault with FAULT_FLAG_TRIED set.
+ * On return:
+ *
+ *   *@locked == 1: the mmap lock is still held. The return value has the
+ *                  same meaning as hmm_range_fault() (0 on success, or one
+ *                  of the error codes documented there).
+ *
+ *   *@locked == 0: the mmap lock was dropped during a completed page fault.
+ *                  No PFNs collected so far are guaranteed to be valid because
+ *                  the address space may have changed under us. If the return
+ *                  value is 0, the caller must reacquire the lock and restart
+ *                  with a fresh mmu_interval_read_begin(). If the return value
+ *                  is -EINTR, a fatal signal is pending and the caller should
+ *                  abort; the mmap lock is no longer held.
+ *
+ * -EINTR may also be returned with *@locked == 1 (mmap lock still held) if a
+ * fatal signal is pending when a retried fault is about to be reissued.
+ * See Documentation/mm/hmm.rst for the full usage pattern.
+ */
+int hmm_range_fault_unlockable(struct hmm_range *range, int *locked)
 {
        struct hmm_vma_walk hmm_vma_walk = {
                .range = range,
+               .locked = locked,
                .last = range->start,
        };
        struct mm_struct *mm = range->notifier->mm;
@@ -704,8 +787,14 @@ int hmm_range_fault(struct hmm_range *range)
                 * returns -EBUSY so the loop re-walks and picks up the
                 * now-present entries.
                 */
-               if (ret == HMM_FAULT_PENDING)
+               if (ret == HMM_FAULT_PENDING) {
                        ret = hmm_do_fault(mm, &hmm_vma_walk);
+                       if (ret == HMM_FAULT_UNLOCKED) {
+                               if (fatal_signal_pending(current))
+                                       return -EINTR;
+                               return 0;     /* caller must restart */
+                       }
+               }
                /*
                 * When -EBUSY is returned the loop restarts with
                 * hmm_vma_walk.last set to an address that has not been stored
@@ -715,7 +804,7 @@ int hmm_range_fault(struct hmm_range *range)
        } while (ret == -EBUSY);
        return ret;
 }
-EXPORT_SYMBOL(hmm_range_fault);
+EXPORT_SYMBOL(hmm_range_fault_unlockable);
 
 /**
  * hmm_dma_map_alloc - Allocate HMM map structure