Move the sanity and compatibility tests from the attach_dev and set_dev_pasid
callbacks into the new test_dev callback function. The IOMMU core ensures that
attach_dev() is only invoked after a successful test_dev() call.

Signed-off-by: Nicolin Chen <[email protected]>
---
 drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h   |   2 +
 .../arm/arm-smmu-v3/arm-smmu-v3-iommufd.c     |   6 +-
 .../iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c   |   4 +-
 drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c   | 113 +++++++++++-------
 4 files changed, 74 insertions(+), 51 deletions(-)

diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h 
b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h
index ae23aacc38402..acb1dbc592cf0 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h
@@ -963,6 +963,8 @@ void arm_smmu_write_cd_entry(struct arm_smmu_master 
*master, int ssid,
                             struct arm_smmu_cd *cdptr,
                             const struct arm_smmu_cd *target);
 
+int arm_smmu_domain_test_dev(struct iommu_domain *domain, struct device *dev,
+                            ioasid_t pasid, struct iommu_domain *old_domain);
 int arm_smmu_set_pasid(struct arm_smmu_master *master,
                       struct arm_smmu_domain *smmu_domain, ioasid_t pasid,
                       struct arm_smmu_cd *cd, struct iommu_domain *old);
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-iommufd.c 
b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-iommufd.c
index 313201a616991..a253f9c8bb290 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-iommufd.c
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-iommufd.c
@@ -152,11 +152,6 @@ static int arm_smmu_attach_dev_nested(struct iommu_domain 
*domain,
        struct arm_smmu_ste ste;
        int ret;
 
-       if (nested_domain->vsmmu->smmu != master->smmu)
-               return -EINVAL;
-       if (arm_smmu_ssids_in_use(&master->cd_table))
-               return -EBUSY;
-
        mutex_lock(&arm_smmu_asid_lock);
        /*
         * The VM has to control the actual ATS state at the PCI device because
@@ -187,6 +182,7 @@ static void arm_smmu_domain_nested_free(struct iommu_domain 
*domain)
 }
 
 static const struct iommu_domain_ops arm_smmu_nested_ops = {
+       .test_dev = arm_smmu_domain_test_dev,
        .attach_dev = arm_smmu_attach_dev_nested,
        .free = arm_smmu_domain_nested_free,
 };
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c 
b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
index 59a480974d80f..610d9e826c07e 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
@@ -276,9 +276,6 @@ static int arm_smmu_sva_set_dev_pasid(struct iommu_domain 
*domain,
        struct arm_smmu_cd target;
        int ret;
 
-       if (!(master->smmu->features & ARM_SMMU_FEAT_SVA))
-               return -EOPNOTSUPP;
-
        /* Prevent arm_smmu_mm_release from being called while we are attaching 
*/
        if (!mmget_not_zero(domain->mm))
                return -EINVAL;
@@ -319,6 +316,7 @@ static void arm_smmu_sva_domain_free(struct iommu_domain 
*domain)
 }
 
 static const struct iommu_domain_ops arm_smmu_sva_domain_ops = {
+       .test_dev               = arm_smmu_domain_test_dev,
        .set_dev_pasid          = arm_smmu_sva_set_dev_pasid,
        .free                   = arm_smmu_sva_domain_free
 };
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c 
b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
index a33fbd12a0dd9..3448e55bbcdbb 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
@@ -2765,9 +2765,6 @@ static int arm_smmu_enable_iopf(struct arm_smmu_master 
*master,
 
        iommu_group_mutex_assert(master->dev);
 
-       if (!IS_ENABLED(CONFIG_ARM_SMMU_V3_SVA))
-               return -EOPNOTSUPP;
-
        /*
         * Drivers for devices supporting PRI or stall require iopf others have
         * device-specific fault handlers and don't need IOPF, so this is not a
@@ -2776,10 +2773,6 @@ static int arm_smmu_enable_iopf(struct arm_smmu_master 
*master,
        if (!master->stall_enabled)
                return 0;
 
-       /* We're not keeping track of SIDs in fault events */
-       if (master->num_streams != 1)
-               return -EOPNOTSUPP;
-
        if (master->iopf_refcount) {
                master->iopf_refcount++;
                master_domain->using_iopf = true;
@@ -2937,6 +2930,12 @@ int arm_smmu_attach_prepare(struct arm_smmu_attach_state *state,
                 * one of them.
                 */
                spin_lock_irqsave(&smmu_domain->devices_lock, flags);
+               /*
+                * arm_smmu_domain_test_dev() makes this check too, but only
+                * here is it done under the same devices_lock hold as the
+                * list_add() below, which keeps it atomic with respect to
+                * arm_smmu_enforce_cache_coherency().
+                */
                if (smmu_domain->enforce_cache_coherency &&
                    !arm_smmu_master_canwbs(master)) {
                        spin_unlock_irqrestore(&smmu_domain->devices_lock,
@@ -3002,13 +2985,123 @@ void arm_smmu_attach_commit(struct arm_smmu_attach_state *state)
        master->ats_enabled = state->ats_enabled;
 }
 
+/*
+ * Check whether a PASID may be attached to @master given its current RID
+ * domain: either cd_table.in_ste is already set, or the RID domain is
+ * IDENTITY or BLOCKED. Returns 0 if allowed, -EINVAL otherwise.
+ */
+static int arm_smmu_test_dev_pasid(struct arm_smmu_master *master)
+{
+       struct iommu_domain *sid_domain = iommu_get_domain_for_dev(master->dev);
+
+       if (!master->cd_table.in_ste &&
+           sid_domain->type != IOMMU_DOMAIN_IDENTITY &&
+           sid_domain->type != IOMMU_DOMAIN_BLOCKED)
+               return -EINVAL;
+       return 0;
+}
+
+/**
+ * arm_smmu_domain_test_dev() - Check if @domain can be attached to @dev
+ * @domain: Paging, SVA or nested domain to be attached
+ * @dev: Device being attached
+ * @pasid: PASID to attach at, or IOMMU_NO_PASID for the RID
+ * @old_domain: Not used by this implementation
+ *
+ * Runs the sanity and compatibility checks for attach_dev/set_dev_pasid.
+ * It does not modify any domain or master state.
+ *
+ * Return: 0 if the attachment is allowed, -ENOENT if @dev has no fwspec,
+ * or -EINVAL/-EBUSY/-EOPNOTSUPP if it is incompatible.
+ */
+int arm_smmu_domain_test_dev(struct iommu_domain *domain, struct device *dev,
+                            ioasid_t pasid, struct iommu_domain *old_domain)
+{
+       struct arm_smmu_domain *device_domain = to_smmu_domain_devices(domain);
+       struct arm_smmu_master *master = dev_iommu_priv_get(dev);
+       int ret;
+
+       if (!dev_iommu_fwspec_get(dev))
+               return -ENOENT;
+
+       switch (domain->type) {
+       case IOMMU_DOMAIN_NESTED: {
+               struct arm_smmu_nested_domain *nested_domain =
+                       to_smmu_nested_domain(domain);
+
+               if (WARN_ON(pasid != IOMMU_NO_PASID))
+                       return -EOPNOTSUPP;
+               if (nested_domain->vsmmu->smmu != master->smmu)
+                       return -EINVAL;
+               if (arm_smmu_ssids_in_use(&master->cd_table))
+                       return -EBUSY;
+               break;
+       }
+       case IOMMU_DOMAIN_SVA: {
+               struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain);
+
+               if (!(master->smmu->features & ARM_SMMU_FEAT_SVA))
+                       return -EOPNOTSUPP;
+               /*
+                * SVA is installed via arm_smmu_set_pasid(), whose SMMU match
+                * and RID-domain checks now live here, so apply them to SVA.
+                */
+               if (smmu_domain->smmu != master->smmu)
+                       return -EINVAL;
+               ret = arm_smmu_test_dev_pasid(master);
+               if (ret)
+                       return ret;
+               break;
+       }
+       default: {
+               struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain);
+
+               if (smmu_domain->smmu != master->smmu)
+                       return -EINVAL;
+               if (smmu_domain->stage == ARM_SMMU_DOMAIN_S2 &&
+                   arm_smmu_ssids_in_use(&master->cd_table))
+                       return -EBUSY;
+               if (pasid != IOMMU_NO_PASID) {
+                       /* Only S1 domains are accepted for a PASID here */
+                       if (smmu_domain->stage != ARM_SMMU_DOMAIN_S1)
+                               return -EINVAL;
+                       ret = arm_smmu_test_dev_pasid(master);
+                       if (ret)
+                               return ret;
+               }
+               break;
+       }
+       }
+
+       if (domain->iopf_handler) {
+               if (!IS_ENABLED(CONFIG_ARM_SMMU_V3_SVA))
+                       return -EOPNOTSUPP;
+               /* We're not keeping track of SIDs in fault events */
+               if (master->stall_enabled && master->num_streams != 1)
+                       return -EOPNOTSUPP;
+       }
+
+       /*
+        * NOTE: devices_lock is dropped before the attach path adds @master
+        * to the devices list, so this check alone cannot exclude a racing
+        * arm_smmu_enforce_cache_coherency(); it only fails early.
+        */
+       if (device_domain) {
+               scoped_guard(spinlock_irqsave, &device_domain->devices_lock) {
+                       if (device_domain->enforce_cache_coherency &&
+                           !arm_smmu_master_canwbs(master))
+                               return -EINVAL;
+               }
+       }
+
+       return 0;
+}
+
 static int arm_smmu_attach_dev(struct iommu_domain *domain, struct device *dev,
                               struct iommu_domain *old_domain)
 {
        int ret = 0;
        struct arm_smmu_ste target;
-       struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
-       struct arm_smmu_device *smmu;
        struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain);
        struct arm_smmu_attach_state state = {
                .old_domain = old_domain,
@@ -3017,21 +3066,13 @@ static int arm_smmu_attach_dev(struct iommu_domain 
*domain, struct device *dev,
        struct arm_smmu_master *master;
        struct arm_smmu_cd *cdptr;
 
-       if (!fwspec)
-               return -ENOENT;
-
        state.master = master = dev_iommu_priv_get(dev);
-       smmu = master->smmu;
-
-       if (smmu_domain->smmu != smmu)
-               return -EINVAL;
 
        if (smmu_domain->stage == ARM_SMMU_DOMAIN_S1) {
                cdptr = arm_smmu_alloc_cd_ptr(master, IOMMU_NO_PASID);
                if (!cdptr)
                        return -ENOMEM;
-       } else if (arm_smmu_ssids_in_use(&master->cd_table))
-               return -EBUSY;
+       }
 
        /*
         * Prevent arm_smmu_share_asid() from trying to change the ASID
@@ -3078,15 +3119,8 @@ static int arm_smmu_s1_set_dev_pasid(struct iommu_domain 
*domain,
 {
        struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain);
        struct arm_smmu_master *master = dev_iommu_priv_get(dev);
-       struct arm_smmu_device *smmu = master->smmu;
        struct arm_smmu_cd target_cd;
 
-       if (smmu_domain->smmu != smmu)
-               return -EINVAL;
-
-       if (smmu_domain->stage != ARM_SMMU_DOMAIN_S1)
-               return -EINVAL;
-
        /*
         * We can read cd.asid outside the lock because arm_smmu_set_pasid()
         * will fix it
@@ -3136,14 +3170,6 @@ int arm_smmu_set_pasid(struct arm_smmu_master *master,
 
        /* The core code validates pasid */
 
-       if (smmu_domain->smmu != master->smmu)
-               return -EINVAL;
-
-       if (!master->cd_table.in_ste &&
-           sid_domain->type != IOMMU_DOMAIN_IDENTITY &&
-           sid_domain->type != IOMMU_DOMAIN_BLOCKED)
-               return -EINVAL;
-
        cdptr = arm_smmu_alloc_cd_ptr(master, pasid);
        if (!cdptr)
                return -ENOMEM;
@@ -3695,6 +3721,7 @@ static const struct iommu_ops arm_smmu_ops = {
        .user_pasid_table       = 1,
        .owner                  = THIS_MODULE,
        .default_domain_ops = &(const struct iommu_domain_ops) {
+               .test_dev               = arm_smmu_domain_test_dev,
                .attach_dev             = arm_smmu_attach_dev,
                .enforce_cache_coherency = arm_smmu_enforce_cache_coherency,
                .set_dev_pasid          = arm_smmu_s1_set_dev_pasid,
-- 
2.43.0


Reply via email to