On Tue, 9 Mar 2021 12:26:07 -0700
Alex Williamson <alex.william...@redhat.com> wrote:

> On Tue, 9 Mar 2021 13:47:39 -0500
> Peter Xu <pet...@redhat.com> wrote:
> 
> > On Tue, Mar 09, 2021 at 12:40:04PM -0400, Jason Gunthorpe wrote:  
> > > On Tue, Mar 09, 2021 at 08:29:51AM -0700, Alex Williamson wrote:    
> > > > On Tue, 9 Mar 2021 08:46:09 -0400
> > > > Jason Gunthorpe <j...@nvidia.com> wrote:
> > > >     
> > > > > On Tue, Mar 09, 2021 at 03:49:09AM +0000, Zengtao (B) wrote:    
> > > > > > Hi guys:
> > > > > > 
> > > > > > Thanks for the helpful comments, after rethinking the issue, I have
> > > > > > proposed the following change:
> > > > > > 1. follow_pte instead of follow_pfn.
> > > > > 
> > > > > Still no on follow_pfn, you don't need it once you use vmf_insert_pfn 
> > > > >    
> > > > 
> > > > vmf_insert_pfn() only solves the BUG_ON, follow_pte() is being used
> > > > here to determine whether the translation is already present to avoid
> > > > both duplicate work in inserting the translation and allocating a
> > > > duplicate vma tracking structure.    
> > >  
> > > Oh.. Doing something stateful in fault is not nice at all
> > > 
> > > I would rather see __vfio_pci_add_vma() search the vma_list for dups
> > > than call follow_pfn/pte..    
> > 
> > It seems to me that searching vma list is still the simplest way to fix the
> > problem for the current code base.  I see io_remap_pfn_range() is also used
> > in the new series - maybe that'll need to be moved to where
> > PCI_COMMAND_MEMORY got turned on/off in the new series (I just noticed
> > remap_pfn_range modifies vma flags..), as you suggested in the other email.
> 
> 
> In the new series, I think the fault handler becomes (untested):
> 
> static vm_fault_t vfio_pci_mmap_fault(struct vm_fault *vmf)
> {
>         struct vm_area_struct *vma = vmf->vma;
>         struct vfio_pci_device *vdev = vma->vm_private_data;
>         unsigned long base_pfn, pgoff;
>         vm_fault_t ret = VM_FAULT_SIGBUS;
> 
>         if (vfio_pci_bar_vma_to_pfn(vma, &base_pfn))
>                 return ret;
> 
>         pgoff = (vmf->address - vma->vm_start) >> PAGE_SHIFT;
> 
>         down_read(&vdev->memory_lock);
> 
>         if (__vfio_pci_memory_enabled(vdev))
>                 ret = vmf_insert_pfn(vma, vmf->address, pgoff + base_pfn);
> 
>         up_read(&vdev->memory_lock);
> 
>         return ret;
> }

And I think this is what we end up with for the current code base:

diff --git a/drivers/vfio/pci/vfio_pci.c b/drivers/vfio/pci/vfio_pci.c
index 65e7e6b44578..2f247ab18c66 100644
--- a/drivers/vfio/pci/vfio_pci.c
+++ b/drivers/vfio/pci/vfio_pci.c
@@ -1568,19 +1568,24 @@ void vfio_pci_memory_unlock_and_restore(struct vfio_pci_device *vdev, u16 cmd)
 }
 
 
 /* Caller holds vma_lock */
-static int __vfio_pci_add_vma(struct vfio_pci_device *vdev,
-                             struct vm_area_struct *vma)
+struct vfio_pci_mmap_vma *__vfio_pci_add_vma(struct vfio_pci_device *vdev,
+                                            struct vm_area_struct *vma)
 {
        struct vfio_pci_mmap_vma *mmap_vma;
 
+       list_for_each_entry(mmap_vma, &vdev->vma_list, vma_next) {
+               if (mmap_vma->vma == vma)
+                       return ERR_PTR(-EEXIST);
+       }
+
        mmap_vma = kmalloc(sizeof(*mmap_vma), GFP_KERNEL);
        if (!mmap_vma)
-               return -ENOMEM;
+               return ERR_PTR(-ENOMEM);
 
        mmap_vma->vma = vma;
        list_add(&mmap_vma->vma_next, &vdev->vma_list);
 
-       return 0;
+       return mmap_vma;
 }
 
 /*
@@ -1612,30 +1617,39 @@ static vm_fault_t vfio_pci_mmap_fault(struct vm_fault *vmf)
 {
        struct vm_area_struct *vma = vmf->vma;
        struct vfio_pci_device *vdev = vma->vm_private_data;
-       vm_fault_t ret = VM_FAULT_NOPAGE;
+       struct vfio_pci_mmap_vma *mmap_vma;
+       unsigned long vaddr, pfn;
+       vm_fault_t ret;
 
        mutex_lock(&vdev->vma_lock);
        down_read(&vdev->memory_lock);
 
        if (!__vfio_pci_memory_enabled(vdev)) {
                ret = VM_FAULT_SIGBUS;
-               mutex_unlock(&vdev->vma_lock);
                goto up_out;
        }
 
-       if (__vfio_pci_add_vma(vdev, vma)) {
-               ret = VM_FAULT_OOM;
-               mutex_unlock(&vdev->vma_lock);
+       mmap_vma = __vfio_pci_add_vma(vdev, vma);
+       if (IS_ERR(mmap_vma)) {
+               /* A concurrent fault might have already inserted the page */
+               ret = (PTR_ERR(mmap_vma) == -EEXIST) ? VM_FAULT_NOPAGE :
+                                                      VM_FAULT_OOM;
                goto up_out;
        }
 
-       mutex_unlock(&vdev->vma_lock);
-
-       if (io_remap_pfn_range(vma, vma->vm_start, vma->vm_pgoff,
-                              vma->vm_end - vma->vm_start, vma->vm_page_prot))
-               ret = VM_FAULT_SIGBUS;
-
+       for (vaddr = vma->vm_start, pfn = vma->vm_pgoff;
+            vaddr < vma->vm_end; vaddr += PAGE_SIZE, pfn++) {
+               ret = vmf_insert_pfn(vma, vaddr, pfn);
+               if (ret != VM_FAULT_NOPAGE) {
+                       zap_vma_ptes(vma, vma->vm_start,
+                                    vma->vm_end - vma->vm_start);
+                       list_del(&mmap_vma->vma_next);
+                       kfree(mmap_vma);
+                       break;
+               }
+       }
 up_out:
+       mutex_unlock(&vdev->vma_lock);
        up_read(&vdev->memory_lock);
        return ret;
 }

Reply via email to