When the open_device() op is called, container_users is incremented and
held elevated until close_device(). Thus, as long as drivers call these
functions within their open_device()/close_device() region, they do not
need to worry about container_users.

These functions can all only be called between
open_device()/close_device():

  vfio_pin_pages()
  vfio_unpin_pages()
  vfio_dma_rw()
  vfio_register_notifier()
  vfio_unregister_notifier()

So eliminate the paired vfio_group_add_container_user() /
vfio_group_try_dissolve_container() calls from these paths, and add a
simple WARN_ON() on vdev->open_count to detect misuse by drivers.

Signed-off-by: Jason Gunthorpe <j...@nvidia.com>
---
 drivers/vfio/vfio.c | 67 +++++++++------------------------------------
 1 file changed, 13 insertions(+), 54 deletions(-)

diff --git a/drivers/vfio/vfio.c b/drivers/vfio/vfio.c
index 3d75505bf3cc26..ab0c3f5635905c 100644
--- a/drivers/vfio/vfio.c
+++ b/drivers/vfio/vfio.c
@@ -2121,9 +2121,8 @@ int vfio_pin_pages(struct vfio_device *vdev, unsigned 
long *user_pfn, int npage,
        if (group->dev_counter > 1)
                return -EINVAL;
 
-       ret = vfio_group_add_container_user(group);
-       if (ret)
-               return ret;
+       if (WARN_ON(!READ_ONCE(vdev->open_count)))
+               return -EINVAL;
 
        container = group->container;
        driver = container->iommu_driver;
@@ -2134,8 +2133,6 @@ int vfio_pin_pages(struct vfio_device *vdev, unsigned 
long *user_pfn, int npage,
        else
                ret = -ENOTTY;
 
-       vfio_group_try_dissolve_container(group);
-
        return ret;
 }
 EXPORT_SYMBOL(vfio_pin_pages);
@@ -2162,9 +2159,8 @@ int vfio_unpin_pages(struct vfio_device *vdev, unsigned 
long *user_pfn,
        if (npage > VFIO_PIN_PAGES_MAX_ENTRIES)
                return -E2BIG;
 
-       ret = vfio_group_add_container_user(vdev->group);
-       if (ret)
-               return ret;
+       if (WARN_ON(!READ_ONCE(vdev->open_count)))
+               return -EINVAL;
 
        container = vdev->group->container;
        driver = container->iommu_driver;
@@ -2174,8 +2170,6 @@ int vfio_unpin_pages(struct vfio_device *vdev, unsigned 
long *user_pfn,
        else
                ret = -ENOTTY;
 
-       vfio_group_try_dissolve_container(vdev->group);
-
        return ret;
 }
 EXPORT_SYMBOL(vfio_unpin_pages);
@@ -2207,9 +2201,8 @@ int vfio_dma_rw(struct vfio_device *vdev, dma_addr_t 
user_iova,
        if (!data || len <= 0)
                return -EINVAL;
 
-       ret = vfio_group_add_container_user(vdev->group);
-       if (ret)
-               return ret;
+       if (WARN_ON(!READ_ONCE(vdev->open_count)))
+               return -EINVAL;
 
        container = vdev->group->container;
        driver = container->iommu_driver;
@@ -2219,9 +2212,6 @@ int vfio_dma_rw(struct vfio_device *vdev, dma_addr_t 
user_iova,
                                          user_iova, data, len, write);
        else
                ret = -ENOTTY;
-
-       vfio_group_try_dissolve_container(vdev->group);
-
        return ret;
 }
 EXPORT_SYMBOL(vfio_dma_rw);
@@ -2234,10 +2224,6 @@ static int vfio_register_iommu_notifier(struct 
vfio_group *group,
        struct vfio_iommu_driver *driver;
        int ret;
 
-       ret = vfio_group_add_container_user(group);
-       if (ret)
-               return -EINVAL;
-
        container = group->container;
        driver = container->iommu_driver;
        if (likely(driver && driver->ops->register_notifier))
@@ -2245,9 +2231,6 @@ static int vfio_register_iommu_notifier(struct vfio_group 
*group,
                                                     events, nb);
        else
                ret = -ENOTTY;
-
-       vfio_group_try_dissolve_container(group);
-
        return ret;
 }
 
@@ -2258,10 +2241,6 @@ static int vfio_unregister_iommu_notifier(struct 
vfio_group *group,
        struct vfio_iommu_driver *driver;
        int ret;
 
-       ret = vfio_group_add_container_user(group);
-       if (ret)
-               return -EINVAL;
-
        container = group->container;
        driver = container->iommu_driver;
        if (likely(driver && driver->ops->unregister_notifier))
@@ -2269,9 +2248,6 @@ static int vfio_unregister_iommu_notifier(struct 
vfio_group *group,
                                                       nb);
        else
                ret = -ENOTTY;
-
-       vfio_group_try_dissolve_container(group);
-
        return ret;
 }
 
@@ -2300,10 +2276,6 @@ static int vfio_register_group_notifier(struct 
vfio_group *group,
        if (*events)
                return -EINVAL;
 
-       ret = vfio_group_add_container_user(group);
-       if (ret)
-               return -EINVAL;
-
        ret = blocking_notifier_chain_register(&group->notifier, nb);
 
        /*
@@ -2313,25 +2285,6 @@ static int vfio_register_group_notifier(struct 
vfio_group *group,
        if (!ret && set_kvm && group->kvm)
                blocking_notifier_call_chain(&group->notifier,
                                        VFIO_GROUP_NOTIFY_SET_KVM, group->kvm);
-
-       vfio_group_try_dissolve_container(group);
-
-       return ret;
-}
-
-static int vfio_unregister_group_notifier(struct vfio_group *group,
-                                        struct notifier_block *nb)
-{
-       int ret;
-
-       ret = vfio_group_add_container_user(group);
-       if (ret)
-               return -EINVAL;
-
-       ret = blocking_notifier_chain_unregister(&group->notifier, nb);
-
-       vfio_group_try_dissolve_container(group);
-
        return ret;
 }
 
@@ -2344,6 +2297,9 @@ int vfio_register_notifier(struct vfio_device *dev, enum 
vfio_notify_type type,
        if (!nb || !events || (*events == 0))
                return -EINVAL;
 
+       if (WARN_ON(!READ_ONCE(dev->open_count)))
+               return -EINVAL;
+
        switch (type) {
        case VFIO_IOMMU_NOTIFY:
                ret = vfio_register_iommu_notifier(group, events, nb);
@@ -2368,12 +2324,15 @@ int vfio_unregister_notifier(struct vfio_device *dev,
        if (!nb)
                return -EINVAL;
 
+       if (WARN_ON(!READ_ONCE(dev->open_count)))
+               return -EINVAL;
+
        switch (type) {
        case VFIO_IOMMU_NOTIFY:
                ret = vfio_unregister_iommu_notifier(group, nb);
                break;
        case VFIO_GROUP_NOTIFY:
-               ret = vfio_unregister_group_notifier(group, nb);
+               ret = blocking_notifier_chain_unregister(&group->notifier, nb);
                break;
        default:
                ret = -EINVAL;
-- 
2.35.1

Reply via email to