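Instead of unconditionally returning -EAGAIN when the invalidation range
is not blockable, return it only when we're sure the range overlaps the
metadata area of a virtqueue. Ranges that don't overlap any metadata
area need no invalidation, so they can succeed without blocking.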

Reported-by: Jason Gunthorpe <j...@ziepe.ca>
Fixes: 7f466032dc9e ("vhost: access vq metadata through kernel virtual address")
Signed-off-by: Jason Wang <jasow...@redhat.com>
---
 drivers/vhost/vhost.c | 32 +++++++++++++++++++-------------
 1 file changed, 19 insertions(+), 13 deletions(-)

diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index d8863aaaf0f6..f98155f28f02 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -371,16 +371,19 @@ static void inline vhost_vq_access_map_end(struct vhost_virtqueue *vq)
        spin_unlock(&vq->mmu_lock);
 }
 
-static void vhost_invalidate_vq_start(struct vhost_virtqueue *vq,
-                                     int index,
-                                     unsigned long start,
-                                     unsigned long end)
+static int vhost_invalidate_vq_start(struct vhost_virtqueue *vq,
+                                    int index,
+                                    unsigned long start,
+                                    unsigned long end,
+                                    bool blockable)
 {
        struct vhost_uaddr *uaddr = &vq->uaddrs[index];
        struct vhost_map *map;
 
        if (!vhost_map_range_overlap(uaddr, start, end))
-               return;
+               return 0;
+       else if (!blockable)
+               return -EAGAIN;
 
        spin_lock(&vq->mmu_lock);
        ++vq->invalidate_count;
@@ -394,6 +397,8 @@ static void vhost_invalidate_vq_start(struct vhost_virtqueue *vq,
                vhost_set_map_dirty(vq, map, index);
                vhost_map_unprefetch(map);
        }
+
+       return 0;
 }
 
 static void vhost_invalidate_vq_end(struct vhost_virtqueue *vq,
@@ -414,18 +419,19 @@ static int vhost_invalidate_range_start(struct mmu_notifier *mn,
 {
        struct vhost_dev *dev = container_of(mn, struct vhost_dev,
                                             mmu_notifier);
-       int i, j;
-
-       if (!mmu_notifier_range_blockable(range))
-               return -EAGAIN;
+       bool blockable = mmu_notifier_range_blockable(range);
+       int i, j, ret;
 
        for (i = 0; i < dev->nvqs; i++) {
                struct vhost_virtqueue *vq = dev->vqs[i];
 
-               for (j = 0; j < VHOST_NUM_ADDRS; j++)
-                       vhost_invalidate_vq_start(vq, j,
-                                                 range->start,
-                                                 range->end);
+               for (j = 0; j < VHOST_NUM_ADDRS; j++) {
+                       ret = vhost_invalidate_vq_start(vq, j,
+                                                       range->start,
+                                                       range->end, blockable);
+                       if (ret)
+                               return ret;
+               }
        }
 
        return 0;
-- 
2.18.1

_______________________________________________
Virtualization mailing list
Virtualization@lists.linux-foundation.org
https://lists.linuxfoundation.org/mailman/listinfo/virtualization

Reply via email to