Impose a new default strict MMIO mapping mode where the vma for
a VM_PFNMAP mapping must be backed by a vfio device.  This allows
holding a reference to the device and registering a notifier for the
device, which additionally keeps the device in an IOMMU context for
the extent of the DMA mapping.  On notification of device release,
automatically drop the DMA mappings for it.

Signed-off-by: Alex Williamson <alex.william...@redhat.com>
---
 drivers/vfio/vfio_iommu_type1.c |  163 ++++++++++++++++++++++++++++-----------
 1 file changed, 116 insertions(+), 47 deletions(-)

diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index f22c07a40521..e89f11141dee 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -101,6 +101,20 @@ struct vfio_dma {
        struct task_struct      *task;
        struct rb_root          pfn_list;       /* Ex-user pinned pfn list */
        unsigned long           *bitmap;
+       struct pfnmap_obj       *pfnmap;
+};
+
+/*
+ * Separate object used for tracking pfnmaps to allow reference release and
+ * unregistering notifier outside of callback chain.
+ */
+struct pfnmap_obj {
+       struct notifier_block   nb;
+       struct work_struct      work;
+       struct vfio_iommu       *iommu;
+       struct vfio_dma         *dma;
+       struct vfio_device      *device;
+       unsigned long           base_pfn;
 };
 
 struct vfio_batch {
@@ -506,42 +520,6 @@ static void vfio_batch_fini(struct vfio_batch *batch)
                free_page((unsigned long)batch->pages);
 }
 
-static int follow_fault_pfn(struct vm_area_struct *vma, struct mm_struct *mm,
-                           unsigned long vaddr, unsigned long *pfn,
-                           bool write_fault)
-{
-       pte_t *ptep;
-       spinlock_t *ptl;
-       int ret;
-
-       ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
-       if (ret) {
-               bool unlocked = false;
-
-               ret = fixup_user_fault(mm, vaddr,
-                                      FAULT_FLAG_REMOTE |
-                                      (write_fault ?  FAULT_FLAG_WRITE : 0),
-                                      &unlocked);
-               if (unlocked)
-                       return -EAGAIN;
-
-               if (ret)
-                       return ret;
-
-               ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
-               if (ret)
-                       return ret;
-       }
-
-       if (write_fault && !pte_write(*ptep))
-               ret = -EFAULT;
-       else
-               *pfn = pte_pfn(*ptep);
-
-       pte_unmap_unlock(ptep, ptl);
-       return ret;
-}
-
 /* Return 1 if iommu->lock dropped and notified, 0 if done */
 static int unmap_dma_pfn_list(struct vfio_iommu *iommu, struct vfio_dma *dma,
                              struct vfio_dma **dma_last, int *retries)
@@ -575,6 +553,52 @@ static int unmap_dma_pfn_list(struct vfio_iommu *iommu, struct vfio_dma *dma,
        return 0;
 }
 
+static void unregister_device_bg(struct work_struct *work)
+{
+       struct pfnmap_obj *pfnmap = container_of(work, struct pfnmap_obj, work);
+
+       vfio_device_unregister_notifier(pfnmap->device, &pfnmap->nb);
+       vfio_device_put(pfnmap->device);
+       kfree(pfnmap);
+}
+
+static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma);
+
+static int vfio_device_nb_cb(struct notifier_block *nb,
+                            unsigned long action, void *unused)
+{
+       struct pfnmap_obj *pfnmap = container_of(nb, struct pfnmap_obj, nb);
+
+       switch (action) {
+       case VFIO_DEVICE_RELEASE:
+       {
+               struct vfio_dma *dma_last = NULL;
+               int retries = 0;
+again:
+               mutex_lock(&pfnmap->iommu->lock);
+               if (pfnmap->dma) {
+                       struct vfio_dma *dma = pfnmap->dma;
+
+                       if (unmap_dma_pfn_list(pfnmap->iommu, dma,
+                                              &dma_last, &retries))
+                               goto again;
+
+                       dma->pfnmap = NULL;
+                       pfnmap->dma = NULL;
+                       vfio_remove_dma(pfnmap->iommu, dma);
+               }
+               mutex_unlock(&pfnmap->iommu->lock);
+
+               /* Cannot unregister notifier from callback chain */
+               INIT_WORK(&pfnmap->work, unregister_device_bg);
+               schedule_work(&pfnmap->work);
+               break;
+       }
+       }
+
+       return NOTIFY_OK;
+}
+
 /*
  * Returns the positive number of pfns successfully obtained or a negative
  * error code.
@@ -601,21 +625,60 @@ static int vaddr_get_pfns(struct vfio_iommu *iommu, struct vfio_dma *dma,
 
        vaddr = untagged_addr(vaddr);
 
-retry:
        vma = find_vma_intersection(mm, vaddr, vaddr + 1);
 
        if (vma && vma->vm_flags & VM_PFNMAP) {
-               ret = follow_fault_pfn(vma, mm, vaddr, pfn,
-                                      dma->prot & IOMMU_WRITE);
-               if (ret == -EAGAIN)
-                       goto retry;
-
-               if (!ret) {
-                       if (is_invalid_reserved_pfn(*pfn))
-                               ret = 1;
-                       else
-                               ret = -EFAULT;
+               if ((dma->prot & IOMMU_WRITE && !(vma->vm_flags & VM_WRITE)) ||
+                   (dma->prot & IOMMU_READ && !(vma->vm_flags & VM_READ))) {
+                       ret = -EFAULT;
+                       goto done;
+               }
+
+               if (!dma->pfnmap) {
+                       struct vfio_device *device;
+                       unsigned long base_pfn;
+                       struct pfnmap_obj *pfnmap;
+
+                       device = vfio_device_get_from_vma(vma);
+                       if (IS_ERR(device)) {
+                               ret = PTR_ERR(device);
+                               goto done;
+                       }
+
+                       ret = vfio_vma_to_pfn(vma, &base_pfn);
+                       if (ret) {
+                               vfio_device_put(device);
+                               goto done;
+                       }
+
+                       pfnmap = kzalloc(sizeof(*pfnmap), GFP_KERNEL);
+                       if (!pfnmap) {
+                               vfio_device_put(device);
+                               ret = -ENOMEM;
+                               goto done;
+                       }
+
+                       pfnmap->nb.notifier_call = vfio_device_nb_cb;
+                       pfnmap->iommu = iommu;
+                       pfnmap->dma = dma;
+                       pfnmap->device = device;
+                       pfnmap->base_pfn = base_pfn;
+
+                       dma->pfnmap = pfnmap;
+
+                       ret = vfio_device_register_notifier(device,
+                                                           &pfnmap->nb);
+                       if (ret) {
+                               dma->pfnmap = NULL;
+                               kfree(pfnmap);
+                               vfio_device_put(device);
+                               goto done;
+                       }
                }
+
+               *pfn = ((vaddr - vma->vm_start) >> PAGE_SHIFT) +
+                                                       dma->pfnmap->base_pfn;
+               ret = 1;
        }
 done:
        mmap_read_unlock(mm);
@@ -1189,6 +1252,12 @@ static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
 static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
 {
        WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list));
+       if (dma->pfnmap) {
+               vfio_device_unregister_notifier(dma->pfnmap->device,
+                                               &dma->pfnmap->nb);
+               vfio_device_put(dma->pfnmap->device);
+               kfree(dma->pfnmap);
+       }
        vfio_unmap_unpin(iommu, dma, true);
        vfio_unlink_dma(iommu, dma);
        put_task_struct(dma->task);

Reply via email to