From: Jason Gunthorpe <[email protected]>

Some of the configuration done in the attach() and replace() paths should
only apply to an IOMMUFD_OBJ_HWPT_PAGING. Once IOMMUFD_OBJ_HWPT_NESTED is
introduced in a following patch, keeping it unconditionally in the common
routines will no longer work.

Add a hwpt_is_paging() helper and wrap all of the PAGING-only configuration
into helpers. Check hwpt_is_paging() wherever those helpers, or the routines
that undo their work (e.g. iopt_remove_reserved_iova()), are called.

Also, move the "num_devices++" counting closer to where num_devices is
used.

Signed-off-by: Jason Gunthorpe <[email protected]>
Signed-off-by: Nicolin Chen <[email protected]>
Signed-off-by: Yi Liu <[email protected]>
---
 drivers/iommu/iommufd/device.c          | 137 ++++++++++++++++--------
 drivers/iommu/iommufd/iommufd_private.h |   5 +
 2 files changed, 99 insertions(+), 43 deletions(-)

diff --git a/drivers/iommu/iommufd/device.c b/drivers/iommu/iommufd/device.c
index a5f9f20b2a9b..3414276bbd15 100644
--- a/drivers/iommu/iommufd/device.c
+++ b/drivers/iommu/iommufd/device.c
@@ -325,6 +325,35 @@ static int iommufd_group_setup_msi(struct iommufd_group *igroup,
        return 0;
 }
 
+static int iommufd_hwpt_paging_attach(struct iommufd_hw_pagetable *hwpt,
+                                     struct iommufd_device *idev)
+{
+       int rc;
+
+       lockdep_assert_held(&idev->igroup->lock);
+
+       /* Try to upgrade the domain we have */
+       if (idev->enforce_cache_coherency) {
+               rc = iommufd_hw_pagetable_enforce_cc(hwpt);
+               if (rc)
+                       return rc;
+       }
+
+       rc = iopt_table_enforce_dev_resv_regions(&hwpt->ioas->iopt, idev->dev,
+                                                &idev->igroup->sw_msi_start);
+       if (rc)
+               return rc;
+
+       if (list_empty(&idev->igroup->device_list)) {
+               rc = iommufd_group_setup_msi(idev->igroup, hwpt);
+               if (rc) {
+                       iopt_remove_reserved_iova(&hwpt->ioas->iopt, idev->dev);
+                       return rc;
+               }
+       }
+       return 0;
+}
+
 int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
                                struct iommufd_device *idev)
 {
@@ -337,18 +366,12 @@ int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
                goto err_unlock;
        }
 
-       /* Try to upgrade the domain we have */
-       if (idev->enforce_cache_coherency) {
-               rc = iommufd_hw_pagetable_enforce_cc(hwpt);
+       if (hwpt_is_paging(hwpt)) {
+               rc = iommufd_hwpt_paging_attach(hwpt, idev);
                if (rc)
                        goto err_unlock;
        }
 
-       rc = iopt_table_enforce_dev_resv_regions(&hwpt->ioas->iopt, idev->dev,
-                                                &idev->igroup->sw_msi_start);
-       if (rc)
-               goto err_unlock;
-
        /*
         * Only attach to the group once for the first device that is in the
         * group. All the other devices will follow this attachment. The user
@@ -357,10 +380,6 @@ int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
         * attachment.
         */
        if (list_empty(&idev->igroup->device_list)) {
-               rc = iommufd_group_setup_msi(idev->igroup, hwpt);
-               if (rc)
-                       goto err_unresv;
-
                rc = iommu_attach_group(hwpt->domain, idev->igroup->group);
                if (rc)
                        goto err_unresv;
@@ -371,7 +390,8 @@ int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
        mutex_unlock(&idev->igroup->lock);
        return 0;
 err_unresv:
-       iopt_remove_reserved_iova(&hwpt->ioas->iopt, idev->dev);
+       if (hwpt_is_paging(hwpt))
+               iopt_remove_reserved_iova(&hwpt->ioas->iopt, idev->dev);
 err_unlock:
        mutex_unlock(&idev->igroup->lock);
        return rc;
@@ -388,7 +408,8 @@ iommufd_hw_pagetable_detach(struct iommufd_device *idev)
                iommu_detach_group(hwpt->domain, idev->igroup->group);
                idev->igroup->hwpt = NULL;
        }
-       iopt_remove_reserved_iova(&hwpt->ioas->iopt, idev->dev);
+       if (hwpt_is_paging(hwpt))
+               iopt_remove_reserved_iova(&hwpt->ioas->iopt, idev->dev);
        mutex_unlock(&idev->igroup->lock);
 
        /* Caller must destroy hwpt */
@@ -407,40 +428,36 @@ iommufd_device_do_attach(struct iommufd_device *idev,
        return NULL;
 }
 
-static struct iommufd_hw_pagetable *
-iommufd_device_do_replace(struct iommufd_device *idev,
-                         struct iommufd_hw_pagetable *hwpt)
+static void iommufd_group_remove_reserved_iova(struct iommufd_group *igroup,
+                                              struct iommufd_hw_pagetable *hwpt)
 {
-       struct iommufd_group *igroup = idev->igroup;
-       struct iommufd_hw_pagetable *old_hwpt;
-       unsigned int num_devices = 0;
        struct iommufd_device *cur;
-       int rc;
 
-       mutex_lock(&idev->igroup->lock);
+       lockdep_assert_held(&igroup->lock);
 
-       if (igroup->hwpt == NULL) {
-               rc = -EINVAL;
-               goto err_unlock;
-       }
+       list_for_each_entry(cur, &igroup->device_list, group_item)
+               iopt_remove_reserved_iova(&hwpt->ioas->iopt, cur->dev);
+}
 
-       if (hwpt == igroup->hwpt) {
-               mutex_unlock(&idev->igroup->lock);
-               return NULL;
-       }
+static int iommufd_group_do_replace_paging(struct iommufd_group *igroup,
+                                          struct iommufd_hw_pagetable *hwpt)
+{
+       struct iommufd_hw_pagetable *old_hwpt = igroup->hwpt;
+       struct iommufd_device *cur;
+       int rc;
+
+       lockdep_assert_held(&igroup->lock);
 
        /* Try to upgrade the domain we have */
        list_for_each_entry(cur, &igroup->device_list, group_item) {
-               num_devices++;
                if (cur->enforce_cache_coherency) {
                        rc = iommufd_hw_pagetable_enforce_cc(hwpt);
                        if (rc)
-                               goto err_unlock;
+                               return rc;
                }
        }
 
-       old_hwpt = igroup->hwpt;
-       if (hwpt->ioas != old_hwpt->ioas) {
+       if (!hwpt_is_paging(old_hwpt) || hwpt->ioas != old_hwpt->ioas) {
                list_for_each_entry(cur, &igroup->device_list, group_item) {
                        rc = iopt_table_enforce_dev_resv_regions(
                                &hwpt->ioas->iopt, cur->dev, NULL);
@@ -448,23 +465,57 @@ iommufd_device_do_replace(struct iommufd_device *idev,
                                goto err_unresv;
                }
        }
-
-       rc = iommufd_group_setup_msi(idev->igroup, hwpt);
+       rc = iommufd_group_setup_msi(igroup, hwpt);
        if (rc)
                goto err_unresv;
+       return 0;
+
+err_unresv:
+       iommufd_group_remove_reserved_iova(igroup, hwpt);
+       return rc;
+}
+
+static struct iommufd_hw_pagetable *
+iommufd_device_do_replace(struct iommufd_device *idev,
+                         struct iommufd_hw_pagetable *hwpt)
+{
+       struct iommufd_group *igroup = idev->igroup;
+       struct iommufd_hw_pagetable *old_hwpt;
+       unsigned int num_devices = 0;
+       struct iommufd_device *cur;
+       int rc;
+
+       mutex_lock(&idev->igroup->lock);
+
+       if (igroup->hwpt == NULL) {
+               rc = -EINVAL;
+               goto err_unlock;
+       }
+
+       if (hwpt == igroup->hwpt) {
+               mutex_unlock(&idev->igroup->lock);
+               return NULL;
+       }
+
+       if (hwpt_is_paging(hwpt)) {
+               rc = iommufd_group_do_replace_paging(igroup, hwpt);
+               if (rc)
+                       goto err_unlock;
+       }
 
        rc = iommu_group_replace_domain(igroup->group, hwpt->domain);
        if (rc)
                goto err_unresv;
 
-       if (hwpt->ioas != old_hwpt->ioas) {
-               list_for_each_entry(cur, &igroup->device_list, group_item)
-                       iopt_remove_reserved_iova(&old_hwpt->ioas->iopt,
-                                                 cur->dev);
-       }
+       old_hwpt = igroup->hwpt;
+       if (hwpt_is_paging(old_hwpt) &&
+           (!hwpt_is_paging(hwpt) || hwpt->ioas != old_hwpt->ioas))
+               iommufd_group_remove_reserved_iova(igroup, old_hwpt);
 
        igroup->hwpt = hwpt;
 
+       list_for_each_entry(cur, &igroup->device_list, group_item)
+               num_devices++;
        /*
         * Move the refcounts held by the device_list to the new hwpt. Retain a
         * refcount for this thread as the caller will free it.
@@ -478,8 +529,8 @@ iommufd_device_do_replace(struct iommufd_device *idev,
        /* Caller must destroy old_hwpt */
        return old_hwpt;
 err_unresv:
-       list_for_each_entry(cur, &igroup->device_list, group_item)
-               iopt_remove_reserved_iova(&hwpt->ioas->iopt, cur->dev);
+       if (hwpt_is_paging(hwpt))
+               iommufd_group_remove_reserved_iova(igroup, hwpt);
 err_unlock:
        mutex_unlock(&idev->igroup->lock);
        return ERR_PTR(rc);
diff --git a/drivers/iommu/iommufd/iommufd_private.h b/drivers/iommu/iommufd/iommufd_private.h
index 07ce35f09599..6244ffddee6e 100644
--- a/drivers/iommu/iommufd/iommufd_private.h
+++ b/drivers/iommu/iommufd/iommufd_private.h
@@ -240,6 +240,11 @@ struct iommufd_hw_pagetable {
        struct list_head hwpt_item;
 };
 
+static inline bool hwpt_is_paging(const struct iommufd_hw_pagetable *hwpt)
+{
+       return hwpt->obj.type == IOMMUFD_OBJ_HWPT_PAGING;
+}
+
 struct iommufd_hw_pagetable *
 iommufd_hw_pagetable_alloc(struct iommufd_ctx *ictx, struct iommufd_ioas *ioas,
                           struct iommufd_device *idev, u32 flags,
-- 
2.34.1

Reply via email to