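Several places in the code need to look up the svm and sdev pointers for
a given PASID and device. Add a helper that does this lookup, for code
consolidation and readability.

Note that this changes the return value of intel_svm_unbind_gpasid() and
intel_svm_unbind_mm(): they now return 0 instead of an error when no svm
exists for the PASID or the device is not bound to it.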

Signed-off-by: Lu Baolu <baolu...@linux.intel.com>
Reviewed-by: Kevin Tian <kevin.t...@intel.com>
---
 drivers/iommu/intel/svm.c | 115 +++++++++++++++++++++-----------------
 1 file changed, 65 insertions(+), 50 deletions(-)

diff --git a/drivers/iommu/intel/svm.c b/drivers/iommu/intel/svm.c
index 65d2327dcd0d..c104a50a625c 100644
--- a/drivers/iommu/intel/svm.c
+++ b/drivers/iommu/intel/svm.c
@@ -228,13 +228,57 @@ static LIST_HEAD(global_svm_list);
        list_for_each_entry((sdev), &(svm)->devs, list) \
                if ((d) != (sdev)->dev) {} else
 
+static int pasid_to_svm_sdev(struct device *dev, unsigned int pasid,
+                            struct intel_svm **rsvm,
+                            struct intel_svm_dev **rsdev)
+{
+       struct intel_svm_dev *d, *sdev = NULL;
+       struct intel_svm *svm;
+
+       /* Needs pasid_mutex held; *rsvm and *rsdev may be NULL on success. */
+       if (WARN_ON(!mutex_is_locked(&pasid_mutex)))
+               return -EINVAL;
+
+       if (pasid == INVALID_IOASID || pasid >= PASID_MAX)
+               return -EINVAL;
+
+       svm = ioasid_find(NULL, pasid, NULL);
+       if (IS_ERR(svm))
+               return PTR_ERR(svm);
+
+       if (!svm)
+               goto out;
+
+       /*
+        * A PASID with an svm must have at least one device bound to it;
+        * otherwise the svm should already have been freed.
+        */
+       if (WARN_ON(list_empty(&svm->devs)))
+               return -EINVAL;
+
+       rcu_read_lock();
+       list_for_each_entry_rcu(d, &svm->devs, list) {
+               if (d->dev == dev) {
+                       sdev = d;
+                       break;
+               }
+       }
+       rcu_read_unlock();
+
+out:
+       *rsvm = svm;
+       *rsdev = sdev;
+
+       return 0;
+}
+
 int intel_svm_bind_gpasid(struct iommu_domain *domain, struct device *dev,
                          struct iommu_gpasid_bind_data *data)
 {
        struct intel_iommu *iommu = device_to_iommu(dev, NULL, NULL);
+       struct intel_svm_dev *sdev = NULL;
        struct dmar_domain *dmar_domain;
-       struct intel_svm_dev *sdev;
-       struct intel_svm *svm;
+       struct intel_svm *svm = NULL;
        int ret = 0;
 
        if (WARN_ON(!iommu) || !data)
@@ -261,35 +305,23 @@ int intel_svm_bind_gpasid(struct iommu_domain *domain, 
struct device *dev,
        dmar_domain = to_dmar_domain(domain);
 
        mutex_lock(&pasid_mutex);
-       svm = ioasid_find(NULL, data->hpasid, NULL);
-       if (IS_ERR(svm)) {
-               ret = PTR_ERR(svm);
+       ret = pasid_to_svm_sdev(dev, data->hpasid, &svm, &sdev);
+       if (ret)
                goto out;
-       }
-
-       if (svm) {
-               /*
-                * If we found svm for the PASID, there must be at
-                * least one device bond, otherwise svm should be freed.
-                */
-               if (WARN_ON(list_empty(&svm->devs))) {
-                       ret = -EINVAL;
-                       goto out;
-               }
 
+       if (sdev) {
                /*
                 * Do not allow multiple bindings of the same device-PASID since
                 * there is only one SL page tables per PASID. We may revisit
                 * once sharing PGD across domains are supported.
                 */
-               for_each_svm_dev(sdev, svm, dev) {
-                       dev_warn_ratelimited(dev,
-                                            "Already bound with PASID %u\n",
-                                            svm->pasid);
-                       ret = -EBUSY;
-                       goto out;
-               }
-       } else {
+               dev_warn_ratelimited(dev, "Already bound with PASID %u\n",
+                                    svm->pasid);
+               ret = -EBUSY;
+               goto out;
+       }
+
+       if (!svm) {
                /* We come here when PASID has never been bond to a device. */
                svm = kzalloc(sizeof(*svm), GFP_KERNEL);
                if (!svm) {
@@ -372,25 +404,17 @@ int intel_svm_unbind_gpasid(struct device *dev, int pasid)
        struct intel_iommu *iommu = device_to_iommu(dev, NULL, NULL);
        struct intel_svm_dev *sdev;
        struct intel_svm *svm;
-       int ret = -EINVAL;
+       int ret;
 
        if (WARN_ON(!iommu))
                return -EINVAL;
 
        mutex_lock(&pasid_mutex);
-       svm = ioasid_find(NULL, pasid, NULL);
-       if (!svm) {
-               ret = -EINVAL;
-               goto out;
-       }
-
-       if (IS_ERR(svm)) {
-               ret = PTR_ERR(svm);
+       ret = pasid_to_svm_sdev(dev, pasid, &svm, &sdev);
+       if (ret)
                goto out;
-       }
 
-       for_each_svm_dev(sdev, svm, dev) {
-               ret = 0;
+       if (sdev) {
                if (iommu_dev_feature_enabled(dev, IOMMU_DEV_FEAT_AUX))
                        sdev->users--;
                if (!sdev->users) {
@@ -414,7 +438,6 @@ int intel_svm_unbind_gpasid(struct device *dev, int pasid)
                                kfree(svm);
                        }
                }
-               break;
        }
 out:
        mutex_unlock(&pasid_mutex);
@@ -592,7 +615,7 @@ intel_svm_bind_mm(struct device *dev, int flags, struct 
svm_dev_ops *ops,
        if (sd)
                *sd = sdev;
        ret = 0;
- out:
+out:
        return ret;
 }
 
@@ -608,17 +631,11 @@ static int intel_svm_unbind_mm(struct device *dev, int 
pasid)
        if (!iommu)
                goto out;
 
-       svm = ioasid_find(NULL, pasid, NULL);
-       if (!svm)
-               goto out;
-
-       if (IS_ERR(svm)) {
-               ret = PTR_ERR(svm);
+       ret = pasid_to_svm_sdev(dev, pasid, &svm, &sdev);
+       if (ret)
                goto out;
-       }
 
-       for_each_svm_dev(sdev, svm, dev) {
-               ret = 0;
+       if (sdev) {
                sdev->users--;
                if (!sdev->users) {
                        list_del_rcu(&sdev->list);
@@ -647,10 +664,8 @@ static int intel_svm_unbind_mm(struct device *dev, int 
pasid)
                                kfree(svm);
                        }
                }
-               break;
        }
- out:
-
+out:
        return ret;
 }
 
-- 
2.17.1

_______________________________________________
iommu mailing list
iommu@lists.linux-foundation.org
https://lists.linuxfoundation.org/mailman/listinfo/iommu

Reply via email to