Update nouveau to only use the mmu interval notifiers.

Signed-off-by: Ralph Campbell <[email protected]>
---
 drivers/gpu/drm/nouveau/nouveau_svm.c | 313 +++++++++++++++++---------
 1 file changed, 201 insertions(+), 112 deletions(-)

diff --git a/drivers/gpu/drm/nouveau/nouveau_svm.c 
b/drivers/gpu/drm/nouveau/nouveau_svm.c
index df9bf1fd1bc0..0343e48d41d7 100644
--- a/drivers/gpu/drm/nouveau/nouveau_svm.c
+++ b/drivers/gpu/drm/nouveau/nouveau_svm.c
@@ -88,7 +88,7 @@ nouveau_ivmm_find(struct nouveau_svm *svm, u64 inst)
 }
 
 struct nouveau_svmm {
-       struct mmu_notifier notifier;
+       struct mm_struct *mm;
        struct nouveau_vmm *vmm;
        struct {
                unsigned long start;
@@ -98,6 +98,13 @@ struct nouveau_svmm {
        struct mutex mutex;
 };
 
+struct svmm_interval {
+       struct mmu_interval_notifier notifier;
+       struct nouveau_svmm *svmm;
+};
+
+static const struct mmu_interval_notifier_ops nouveau_svm_mni_ops;
+
 #define SVMM_DBG(s,f,a...)                                                     
\
        NV_DEBUG((s)->vmm->cli->drm, "svm-%p: "f"\n", (s), ##a)
 #define SVMM_ERR(s,f,a...)                                                     
\
@@ -236,6 +243,8 @@ nouveau_svmm_join(struct nouveau_svmm *svmm, u64 inst)
 static void
 nouveau_svmm_invalidate(struct nouveau_svmm *svmm, u64 start, u64 limit)
 {
+       SVMM_DBG(svmm, "invalidate %016llx-%016llx", start, limit);
+
        if (limit > start) {
                bool super = svmm->vmm->vmm.object.client->super;
                svmm->vmm->vmm.object.client->super = true;
@@ -248,58 +257,25 @@ nouveau_svmm_invalidate(struct nouveau_svmm *svmm, u64 
start, u64 limit)
        }
 }
 
-static int
-nouveau_svmm_invalidate_range_start(struct mmu_notifier *mn,
-                                   const struct mmu_notifier_range *update)
-{
-       struct nouveau_svmm *svmm =
-               container_of(mn, struct nouveau_svmm, notifier);
-       unsigned long start = update->start;
-       unsigned long limit = update->end;
-
-       if (!mmu_notifier_range_blockable(update))
-               return -EAGAIN;
-
-       SVMM_DBG(svmm, "invalidate %016lx-%016lx", start, limit);
-
-       mutex_lock(&svmm->mutex);
-       if (unlikely(!svmm->vmm))
-               goto out;
-
-       if (limit > svmm->unmanaged.start && start < svmm->unmanaged.limit) {
-               if (start < svmm->unmanaged.start) {
-                       nouveau_svmm_invalidate(svmm, start,
-                                               svmm->unmanaged.limit);
-               }
-               start = svmm->unmanaged.limit;
-       }
-
-       nouveau_svmm_invalidate(svmm, start, limit);
-
-out:
-       mutex_unlock(&svmm->mutex);
-       return 0;
-}
-
-static void nouveau_svmm_free_notifier(struct mmu_notifier *mn)
-{
-       kfree(container_of(mn, struct nouveau_svmm, notifier));
-}
-
-static const struct mmu_notifier_ops nouveau_mn_ops = {
-       .invalidate_range_start = nouveau_svmm_invalidate_range_start,
-       .free_notifier = nouveau_svmm_free_notifier,
-};
-
 void
 nouveau_svmm_fini(struct nouveau_svmm **psvmm)
 {
        struct nouveau_svmm *svmm = *psvmm;
+       struct mmu_interval_notifier *mni;
+
        if (svmm) {
                mutex_lock(&svmm->mutex);
+               while (true) {
+                       mni = mmu_interval_notifier_find(svmm->mm,
+                                       &nouveau_svm_mni_ops, 0UL, ~0UL);
+                       if (!mni)
+                               break;
+                       mmu_interval_notifier_put(mni);
+               }
                svmm->vmm = NULL;
                mutex_unlock(&svmm->mutex);
-               mmu_notifier_put(&svmm->notifier);
+               mmdrop(svmm->mm);
+               kfree(svmm);
                *psvmm = NULL;
        }
 }
@@ -343,11 +319,12 @@ nouveau_svmm_init(struct drm_device *dev, void *data,
                goto out_free;
 
        down_write(&current->mm->mmap_sem);
-       svmm->notifier.ops = &nouveau_mn_ops;
-       ret = __mmu_notifier_register(&svmm->notifier, current->mm);
+       ret = __mmu_notifier_register(NULL, current->mm);
        if (ret)
                goto out_mm_unlock;
-       /* Note, ownership of svmm transfers to mmu_notifier */
+
+       mmgrab(current->mm);
+       svmm->mm = current->mm;
 
        cli->svm.svmm = svmm;
        cli->svm.cli = cli;
@@ -482,65 +459,212 @@ nouveau_svm_fault_cache(struct nouveau_svm *svm,
                fault->inst, fault->addr, fault->access);
 }
 
-struct svm_notifier {
-       struct mmu_interval_notifier notifier;
-       struct nouveau_svmm *svmm;
-};
+static struct svmm_interval *nouveau_svmm_new_interval(
+                                       struct nouveau_svmm *svmm,
+                                       unsigned long start,
+                                       unsigned long last)
+{
+       struct svmm_interval *smi;
+       int ret;
+
+       smi = kmalloc(sizeof(*smi), GFP_ATOMIC);
+       if (!smi)
+               return NULL;
+
+       smi->svmm = svmm;
+
+       ret = mmu_interval_notifier_insert_safe(&smi->notifier, svmm->mm,
+                               start, last - start + 1, &nouveau_svm_mni_ops);
+       if (ret) {
+               kfree(smi);
+               return NULL;
+       }
+
+       return smi;
+}
+
+static void nouveau_svmm_do_unmap(struct mmu_interval_notifier *mni,
+                                const struct mmu_notifier_range *range)
+{
+       struct svmm_interval *smi =
+               container_of(mni, struct svmm_interval, notifier);
+       struct nouveau_svmm *svmm = smi->svmm;
+       unsigned long start = mmu_interval_notifier_start(mni);
+       unsigned long last = mmu_interval_notifier_last(mni);
+
+       if (start >= range->start) {
+               /* Remove the whole interval or keep the right-hand part. */
+               if (last <= range->end)
+                       mmu_interval_notifier_put(mni);
+               else
+                       mmu_interval_notifier_update(mni, range->end, last);
+               return;
+       }
+
+       /* Keep the left-hand part of the interval. */
+       mmu_interval_notifier_update(mni, start, range->start - 1);
+
+       /* If a hole is created, create an interval for the right-hand part. */
+       if (last >= range->end) {
+               smi = nouveau_svmm_new_interval(svmm, range->end, last);
+               /*
+                * If we can't allocate an interval, we won't get invalidation
+                * callbacks so clear the mapping and rely on faults to reload
+                * the mappings if needed.
+                */
+               if (!smi)
+                       nouveau_svmm_invalidate(svmm, range->end, last + 1);
+       }
+}
 
-static bool nouveau_svm_range_invalidate(struct mmu_interval_notifier *mni,
-                                        const struct mmu_notifier_range *range,
-                                        unsigned long cur_seq)
+static bool nouveau_svmm_interval_invalidate(struct mmu_interval_notifier *mni,
+                               const struct mmu_notifier_range *range,
+                               unsigned long cur_seq)
 {
-       struct svm_notifier *sn =
-               container_of(mni, struct svm_notifier, notifier);
+       struct svmm_interval *smi =
+               container_of(mni, struct svmm_interval, notifier);
+       struct nouveau_svmm *svmm = smi->svmm;
 
        /*
-        * serializes the update to mni->invalidate_seq done by caller and
+        * Serializes the update to mni->invalidate_seq done by the caller and
         * prevents invalidation of the PTE from progressing while HW is being
-        * programmed. This is very hacky and only works because the normal
-        * notifier that does invalidation is always called after the range
-        * notifier.
+        * programmed.
         */
        if (mmu_notifier_range_blockable(range))
-               mutex_lock(&sn->svmm->mutex);
-       else if (!mutex_trylock(&sn->svmm->mutex))
+               mutex_lock(&svmm->mutex);
+       else if (!mutex_trylock(&svmm->mutex))
                return false;
+
        mmu_interval_set_seq(mni, cur_seq);
-       mutex_unlock(&sn->svmm->mutex);
+       nouveau_svmm_invalidate(svmm, range->start, range->end);
+
+       /* Stop tracking the range if it is an unmap. */
+       if (range->event == MMU_NOTIFY_UNMAP)
+               nouveau_svmm_do_unmap(mni, range);
+
+       mutex_unlock(&svmm->mutex);
        return true;
 }
 
+static void nouveau_svmm_interval_release(struct mmu_interval_notifier *mni)
+{
+       struct svmm_interval *smi =
+               container_of(mni, struct svmm_interval, notifier);
+
+       kfree(smi);
+}
+
 static const struct mmu_interval_notifier_ops nouveau_svm_mni_ops = {
-       .invalidate = nouveau_svm_range_invalidate,
+       .invalidate = nouveau_svmm_interval_invalidate,
+       .release = nouveau_svmm_interval_release,
 };
 
+/*
+ * Find or create a mmu_interval_notifier for the given range.
+ * Although mmu_interval_notifier_insert_safe() can handle overlapping
+ * intervals, we only create non-overlapping intervals, shrinking the hmm_range
+ * if it spans more than one svmm_interval.
+ */
+static int nouveau_svmm_interval_find(struct nouveau_svmm *svmm,
+                                struct hmm_range *range)
+{
+       struct mmu_interval_notifier *mni;
+       struct svmm_interval *smi;
+       struct vm_area_struct *vma;
+       unsigned long start = range->start;
+       unsigned long last = range->end - 1;
+       int ret;
+
+       mutex_lock(&svmm->mutex);
+       mni = mmu_interval_notifier_find(svmm->mm, &nouveau_svm_mni_ops, start,
+                                        last);
+       if (mni) {
+               if (start >= mmu_interval_notifier_start(mni)) {
+                       smi = container_of(mni, struct svmm_interval, notifier);
+                       if (last > mmu_interval_notifier_last(mni))
+                               range->end =
+                                       mmu_interval_notifier_last(mni) + 1;
+                       goto found;
+               }
+               WARN_ON(last <= mmu_interval_notifier_start(mni));
+               range->end = mmu_interval_notifier_start(mni);
+               last = range->end - 1;
+       }
+       /*
+        * Might as well create an interval covering the underlying VMA to
+        * avoid having to create a bunch of small intervals.
+        */
+       vma = find_vma(svmm->mm, range->start);
+       if (!vma || start < vma->vm_start) {
+               ret = -ENOENT;
+               goto err;
+       }
+       if (range->end > vma->vm_end) {
+               range->end = vma->vm_end;
+               last = range->end - 1;
+       } else if (!mni) {
+               /* Anything registered on the right part of the vma? */
+               mni = mmu_interval_notifier_find(svmm->mm, &nouveau_svm_mni_ops,
+                                                range->end, vma->vm_end - 1);
+               if (mni)
+                       last = mmu_interval_notifier_start(mni) - 1;
+               else
+                       last = vma->vm_end - 1;
+       }
+       /* Anything registered on the left part of the vma? */
+       mni = mmu_interval_notifier_find(svmm->mm, &nouveau_svm_mni_ops,
+                                        vma->vm_start, start - 1);
+       if (mni)
+               start = mmu_interval_notifier_last(mni) + 1;
+       else
+               start = vma->vm_start;
+       smi = nouveau_svmm_new_interval(svmm, start, last);
+       if (!smi) {
+               ret = -ENOMEM;
+               goto err;
+       }
+
+found:
+       range->notifier = &smi->notifier;
+       mutex_unlock(&svmm->mutex);
+       return 0;
+
+err:
+       mutex_unlock(&svmm->mutex);
+       return ret;
+}
+
 static int nouveau_range_fault(struct nouveau_svmm *svmm,
                               struct nouveau_drm *drm, void *data, u32 size,
-                              u64 *pfns, struct svm_notifier *notifier)
+                              u64 *pfns, u64 start, u64 end)
 {
        unsigned long timeout =
                jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
        /* Have HMM fault pages within the fault window to the GPU. */
        struct hmm_range range = {
-               .notifier = &notifier->notifier,
-               .start = notifier->notifier.interval_tree.start,
-               .end = notifier->notifier.interval_tree.last + 1,
+               .start = start,
+               .end = end,
                .pfns = pfns,
                .flags = nouveau_svm_pfn_flags,
                .values = nouveau_svm_pfn_values,
+               .default_flags = 0,
+               .pfn_flags_mask = ~0UL,
                .pfn_shift = NVIF_VMM_PFNMAP_V0_ADDR_SHIFT,
        };
-       struct mm_struct *mm = notifier->notifier.mm;
+       struct mm_struct *mm = svmm->mm;
        long ret;
 
        while (true) {
                if (time_after(jiffies, timeout))
                        return -EBUSY;
 
-               range.notifier_seq = mmu_interval_read_begin(range.notifier);
-               range.default_flags = 0;
-               range.pfn_flags_mask = -1UL;
                down_read(&mm->mmap_sem);
+               ret = nouveau_svmm_interval_find(svmm, &range);
+               if (ret) {
+                       up_read(&mm->mmap_sem);
+                       return ret;
+               }
+               range.notifier_seq = mmu_interval_read_begin(range.notifier);
                ret = hmm_range_fault(&range, 0);
                up_read(&mm->mmap_sem);
                if (ret <= 0) {
@@ -585,7 +709,6 @@ nouveau_svm_fault(struct nvif_notify *notify)
                } i;
                u64 phys[16];
        } args;
-       struct vm_area_struct *vma;
        u64 inst, start, limit;
        int fi, fn, pi, fill;
        int replay = 0, ret;
@@ -640,7 +763,6 @@ nouveau_svm_fault(struct nvif_notify *notify)
        args.i.p.version = 0;
 
        for (fi = 0; fn = fi + 1, fi < buffer->fault_nr; fi = fn) {
-               struct svm_notifier notifier;
                struct mm_struct *mm;
 
                /* Cancel any faults from non-SVM channels. */
@@ -662,36 +784,12 @@ nouveau_svm_fault(struct nvif_notify *notify)
                        start = max_t(u64, start, svmm->unmanaged.limit);
                SVMM_DBG(svmm, "wndw %016llx-%016llx", start, limit);
 
-               mm = svmm->notifier.mm;
+               mm = svmm->mm;
                if (!mmget_not_zero(mm)) {
                        nouveau_svm_fault_cancel_fault(svm, buffer->fault[fi]);
                        continue;
                }
 
-               /* Intersect fault window with the CPU VMA, cancelling
-                * the fault if the address is invalid.
-                */
-               down_read(&mm->mmap_sem);
-               vma = find_vma_intersection(mm, start, limit);
-               if (!vma) {
-                       SVMM_ERR(svmm, "wndw %016llx-%016llx", start, limit);
-                       up_read(&mm->mmap_sem);
-                       mmput(mm);
-                       nouveau_svm_fault_cancel_fault(svm, buffer->fault[fi]);
-                       continue;
-               }
-               start = max_t(u64, start, vma->vm_start);
-               limit = min_t(u64, limit, vma->vm_end);
-               up_read(&mm->mmap_sem);
-               SVMM_DBG(svmm, "wndw %016llx-%016llx", start, limit);
-
-               if (buffer->fault[fi]->addr != start) {
-                       SVMM_ERR(svmm, "addr %016llx", buffer->fault[fi]->addr);
-                       mmput(mm);
-                       nouveau_svm_fault_cancel_fault(svm, buffer->fault[fi]);
-                       continue;
-               }
-
                /* Prepare the GPU-side update of all pages within the
                 * fault window, determining required pages and access
                 * permissions based on pending faults.
@@ -743,18 +841,9 @@ nouveau_svm_fault(struct nvif_notify *notify)
                         args.i.p.addr,
                         args.i.p.addr + args.i.p.size, fn - fi);
 
-               notifier.svmm = svmm;
-               ret = mmu_interval_notifier_insert(&notifier.notifier,
-                                                  svmm->notifier.mm,
-                                                  args.i.p.addr, args.i.p.size,
-                                                  &nouveau_svm_mni_ops);
-               if (!ret) {
-                       ret = nouveau_range_fault(
-                               svmm, svm->drm, &args,
-                               sizeof(args.i) + pi * sizeof(args.phys[0]),
-                               args.phys, &notifier);
-                       mmu_interval_notifier_remove(&notifier.notifier);
-               }
+               ret = nouveau_range_fault(svmm, svm->drm, &args,
+                       sizeof(args.i) + pi * sizeof(args.phys[0]), args.phys,
+                       start, start + args.i.p.size);
                mmput(mm);
 
                /* Cancel any faults in the window whose pages didn't manage
-- 
2.20.1

_______________________________________________
Nouveau mailing list
[email protected]
https://lists.freedesktop.org/mailman/listinfo/nouveau

Reply via email to