From: Nicolin Chen <[email protected]>

This allows iommufd_hwpt_alloc() to have a common routine but jump to a
different allocator and hold a different mutex, corresponding to the type
of HWPT allocation (either kernel-managed or user-managed). This shared
function pointer takes "pt_obj" as an input that would be converted into
an IOAS pointer or a parent HWPT pointer.

Then, update the kernel-managed allocator to follow this pt_obj change.

Signed-off-by: Nicolin Chen <[email protected]>
Signed-off-by: Yi Liu <[email protected]>
---
 drivers/iommu/iommufd/device.c          |  2 +-
 drivers/iommu/iommufd/hw_pagetable.c    | 46 ++++++++++++++++++-------
 drivers/iommu/iommufd/iommufd_private.h |  3 +-
 3 files changed, 37 insertions(+), 14 deletions(-)

diff --git a/drivers/iommu/iommufd/device.c b/drivers/iommu/iommufd/device.c
index e04900f101f1..eb120f70a3e3 100644
--- a/drivers/iommu/iommufd/device.c
+++ b/drivers/iommu/iommufd/device.c
@@ -539,7 +539,7 @@ iommufd_device_auto_get_domain(struct iommufd_device *idev,
                goto out_unlock;
        }
 
-       hwpt = iommufd_hw_pagetable_alloc(idev->ictx, ioas, idev,
+       hwpt = iommufd_hw_pagetable_alloc(idev->ictx, &ioas->obj, idev,
                                          0, IOMMU_HWPT_TYPE_DEFAULT,
                                          NULL, immediate_attach);
        if (IS_ERR(hwpt)) {
diff --git a/drivers/iommu/iommufd/hw_pagetable.c b/drivers/iommu/iommufd/hw_pagetable.c
index 1cc7178121d1..b2af68776877 100644
--- a/drivers/iommu/iommufd/hw_pagetable.c
+++ b/drivers/iommu/iommufd/hw_pagetable.c
@@ -69,7 +69,7 @@ int iommufd_hw_pagetable_enforce_cc(struct iommufd_hw_pagetable *hwpt)
 /**
  * iommufd_hw_pagetable_alloc() - Get a kernel-managed iommu_domain for a device
  * @ictx: iommufd context
- * @ioas: IOAS to associate the domain with
+ * @pt_obj: An object to an IOAS to associate the domain with
  * @idev: Device to get an iommu_domain for
  * @flags: Flags from userspace
  * @hwpt_type: Requested type of hw_pagetable
@@ -85,12 +85,15 @@ int iommufd_hw_pagetable_enforce_cc(struct iommufd_hw_pagetable *hwpt)
  * the returned hwpt.
  */
 struct iommufd_hw_pagetable *
-iommufd_hw_pagetable_alloc(struct iommufd_ctx *ictx, struct iommufd_ioas *ioas,
+iommufd_hw_pagetable_alloc(struct iommufd_ctx *ictx,
+                          struct iommufd_object *pt_obj,
                           struct iommufd_device *idev, u32 flags,
                           enum iommu_hwpt_type hwpt_type,
                           struct iommu_user_data *user_data,
                           bool immediate_attach)
 {
+       struct iommufd_ioas *ioas =
+               container_of(pt_obj, struct iommufd_ioas, obj);
        const struct iommu_ops *ops = dev_iommu_ops(idev->dev);
        struct iommufd_hw_pagetable *hwpt;
        int rc;
@@ -184,10 +187,19 @@ iommufd_hw_pagetable_alloc(struct iommufd_ctx *ictx, struct iommufd_ioas *ioas,
 
 int iommufd_hwpt_alloc(struct iommufd_ucmd *ucmd)
 {
+       struct iommufd_hw_pagetable *(*alloc_fn)(
+                                       struct iommufd_ctx *ictx,
+                                       struct iommufd_object *pt_obj,
+                                       struct iommufd_device *idev,
+                                       u32 flags, enum iommu_hwpt_type type,
+                                       struct iommu_user_data *user_data,
+                                       bool flag);
        struct iommu_hwpt_alloc *cmd = ucmd->cmd;
        struct iommufd_hw_pagetable *hwpt;
+       struct iommufd_object *pt_obj;
        struct iommufd_device *idev;
        struct iommufd_ioas *ioas;
+       struct mutex *mutex;
        int rc;
 
        if (cmd->flags & ~IOMMU_HWPT_ALLOC_NEST_PARENT || cmd->__reserved)
@@ -197,17 +209,26 @@ int iommufd_hwpt_alloc(struct iommufd_ucmd *ucmd)
        if (IS_ERR(idev))
                return PTR_ERR(idev);
 
-       ioas = iommufd_get_ioas(ucmd->ictx, cmd->pt_id);
-       if (IS_ERR(ioas)) {
-               rc = PTR_ERR(ioas);
+       pt_obj = iommufd_get_object(ucmd->ictx, cmd->pt_id, IOMMUFD_OBJ_ANY);
+       if (IS_ERR(pt_obj)) {
+               rc = -EINVAL;
                goto out_put_idev;
        }
 
-       mutex_lock(&ioas->mutex);
-       hwpt = iommufd_hw_pagetable_alloc(ucmd->ictx, ioas,
-                                         idev, cmd->flags,
-                                         IOMMU_HWPT_TYPE_DEFAULT,
-                                         NULL, false);
+       switch (pt_obj->type) {
+       case IOMMUFD_OBJ_IOAS:
+               ioas = container_of(pt_obj, struct iommufd_ioas, obj);
+               mutex = &ioas->mutex;
+               alloc_fn = iommufd_hw_pagetable_alloc;
+               break;
+       default:
+               rc = -EINVAL;
+               goto out_put_pt;
+       }
+
+       mutex_lock(mutex);
+       hwpt = alloc_fn(ucmd->ictx, pt_obj, idev, cmd->flags,
+                       IOMMU_HWPT_TYPE_DEFAULT, NULL, false);
        if (IS_ERR(hwpt)) {
                rc = PTR_ERR(hwpt);
                goto out_unlock;
@@ -223,8 +244,9 @@ int iommufd_hwpt_alloc(struct iommufd_ucmd *ucmd)
 out_hwpt:
        iommufd_object_abort_and_destroy(ucmd->ictx, &hwpt->obj);
 out_unlock:
-       mutex_unlock(&ioas->mutex);
-       iommufd_put_object(&ioas->obj);
+       mutex_unlock(mutex);
+out_put_pt:
+       iommufd_put_object(pt_obj);
 out_put_idev:
        iommufd_put_object(&idev->obj);
        return rc;
diff --git a/drivers/iommu/iommufd/iommufd_private.h b/drivers/iommu/iommufd/iommufd_private.h
index 3e89c3d530f3..e4d06ae6b0c5 100644
--- a/drivers/iommu/iommufd/iommufd_private.h
+++ b/drivers/iommu/iommufd/iommufd_private.h
@@ -250,7 +250,8 @@ struct iommufd_hw_pagetable {
 };
 
 struct iommufd_hw_pagetable *
-iommufd_hw_pagetable_alloc(struct iommufd_ctx *ictx, struct iommufd_ioas *ioas,
+iommufd_hw_pagetable_alloc(struct iommufd_ctx *ictx,
+                          struct iommufd_object *pt_obj,
                           struct iommufd_device *idev, u32 flags,
                           enum iommu_hwpt_type hwpt_type,
                           struct iommu_user_data *user_data,
-- 
2.34.1

Reply via email to