mtk was doing a lot of stuff under of_xlate, and it looked kind of like it
might support multi-instances. But the dt files don't do that, and the
driver has no way to keep track of which instance the ids are for.

Enforce single instance with iommu_of_get_single_iommu().

Introduce per-device data to store the iommu and ids list. Allocate and
initialize it with iommu_fw_alloc_per_device_ids(). Remove
mtk_iommu_of_xlate().

Convert the rest of the funcs from calling dev_iommu_fwspec_get() to using
the per-device data and remove all use of fwspec.

Convert the places using dev_iommu_priv_get() to use the per-device data,
not the iommu.

Signed-off-by: Jason Gunthorpe <[email protected]>
---
 drivers/iommu/mtk_iommu.c | 116 ++++++++++++++++++++------------------
 1 file changed, 62 insertions(+), 54 deletions(-)

diff --git a/drivers/iommu/mtk_iommu.c b/drivers/iommu/mtk_iommu.c
index 7abe9e85a57063..477171e83eaa6e 100644
--- a/drivers/iommu/mtk_iommu.c
+++ b/drivers/iommu/mtk_iommu.c
@@ -13,6 +13,7 @@
 #include <linux/interrupt.h>
 #include <linux/io.h>
 #include <linux/iommu.h>
+#include <linux/iommu-driver.h>
 #include <linux/iopoll.h>
 #include <linux/io-pgtable.h>
 #include <linux/list.h>
@@ -277,6 +278,12 @@ struct mtk_iommu_data {
        struct mtk_smi_larb_iommu       larb_imu[MTK_LARB_NR_MAX];
 };
 
+struct mtk_iommu_device {
+       struct mtk_iommu_data *iommu;
+       unsigned int num_ids;
+       u32 ids[] __counted_by(num_ids);
+};
+
 struct mtk_iommu_domain {
        struct io_pgtable_cfg           cfg;
        struct io_pgtable_ops           *iop;
@@ -526,14 +533,14 @@ static irqreturn_t mtk_iommu_isr(int irq, void *dev_id)
 static unsigned int mtk_iommu_get_bank_id(struct device *dev,
                                          const struct mtk_iommu_plat_data 
*plat_data)
 {
-       struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
+       struct mtk_iommu_device *mtkdev = dev_iommu_priv_get(dev);
        unsigned int i, portmsk = 0, bankid = 0;
 
        if (plat_data->banks_num == 1)
                return bankid;
 
-       for (i = 0; i < fwspec->num_ids; i++)
-               portmsk |= BIT(MTK_M4U_TO_PORT(fwspec->ids[i]));
+       for (i = 0; i < mtkdev->num_ids; i++)
+               portmsk |= BIT(MTK_M4U_TO_PORT(mtkdev->ids[i]));
 
        for (i = 0; i < plat_data->banks_num && i < MTK_IOMMU_BANK_MAX; i++) {
                if (!plat_data->banks_enable[i])
@@ -550,7 +557,7 @@ static unsigned int mtk_iommu_get_bank_id(struct device 
*dev,
 static int mtk_iommu_get_iova_region_id(struct device *dev,
                                        const struct mtk_iommu_plat_data 
*plat_data)
 {
-       struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
+       struct mtk_iommu_device *mtkdev = dev_iommu_priv_get(dev);
        unsigned int portidmsk = 0, larbid;
        const u32 *rgn_larb_msk;
        int i;
@@ -558,9 +565,9 @@ static int mtk_iommu_get_iova_region_id(struct device *dev,
        if (plat_data->iova_region_nr == 1)
                return 0;
 
-       larbid = MTK_M4U_TO_LARB(fwspec->ids[0]);
-       for (i = 0; i < fwspec->num_ids; i++)
-               portidmsk |= BIT(MTK_M4U_TO_PORT(fwspec->ids[i]));
+       larbid = MTK_M4U_TO_LARB(mtkdev->ids[0]);
+       for (i = 0; i < mtkdev->num_ids; i++)
+               portidmsk |= BIT(MTK_M4U_TO_PORT(mtkdev->ids[i]));
 
        for (i = 0; i < plat_data->iova_region_nr; i++) {
                rgn_larb_msk = plat_data->iova_region_larb_msk[i];
@@ -579,22 +586,22 @@ static int mtk_iommu_get_iova_region_id(struct device 
*dev,
 static int mtk_iommu_config(struct mtk_iommu_data *data, struct device *dev,
                            bool enable, unsigned int regionid)
 {
+       struct mtk_iommu_device *mtkdev = dev_iommu_priv_get(dev);
        struct mtk_smi_larb_iommu    *larb_mmu;
        unsigned int                 larbid, portid;
-       struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
        const struct mtk_iommu_iova_region *region;
        unsigned long portid_msk = 0;
        struct arm_smccc_res res;
        int i, ret = 0;
 
-       for (i = 0; i < fwspec->num_ids; ++i) {
-               portid = MTK_M4U_TO_PORT(fwspec->ids[i]);
+       for (i = 0; i < mtkdev->num_ids; ++i) {
+               portid = MTK_M4U_TO_PORT(mtkdev->ids[i]);
                portid_msk |= BIT(portid);
        }
 
        if (MTK_IOMMU_IS_TYPE(data->plat_data, MTK_IOMMU_TYPE_MM)) {
                /* All ports should be in the same larb. just use 0 here */
-               larbid = MTK_M4U_TO_LARB(fwspec->ids[0]);
+               larbid = MTK_M4U_TO_LARB(mtkdev->ids[0]);
                larb_mmu = &data->larb_imu[larbid];
                region = data->plat_data->iova_region + regionid;
 
@@ -618,7 +625,7 @@ static int mtk_iommu_config(struct mtk_iommu_data *data, 
struct device *dev,
                } else {
                        /* PCI dev has only one output id, enable the next 
writing bit for PCIe */
                        if (dev_is_pci(dev)) {
-                               if (fwspec->num_ids != 1) {
+                               if (mtkdev->num_ids != 1) {
                                        dev_err(dev, "PCI dev can only have one 
port.\n");
                                        return -ENODEV;
                                }
@@ -708,7 +715,9 @@ static void mtk_iommu_domain_free(struct iommu_domain 
*domain)
 static int mtk_iommu_attach_device(struct iommu_domain *domain,
                                   struct device *dev)
 {
-       struct mtk_iommu_data *data = dev_iommu_priv_get(dev), *frstdata;
+       struct mtk_iommu_device *mtkdev = dev_iommu_priv_get(dev);
+       struct mtk_iommu_data *data = mtkdev->iommu;
+       struct mtk_iommu_data *frstdata;
        struct mtk_iommu_domain *dom = to_mtk_domain(domain);
        struct list_head *hw_list = data->hw_list;
        struct device *m4udev = data->dev;
@@ -777,12 +786,12 @@ static int mtk_iommu_identity_attach(struct iommu_domain 
*identity_domain,
                                     struct device *dev)
 {
        struct iommu_domain *domain = iommu_get_domain_for_dev(dev);
-       struct mtk_iommu_data *data = dev_iommu_priv_get(dev);
+       struct mtk_iommu_device *mtkdev = dev_iommu_priv_get(dev);
 
        if (domain == identity_domain || !domain)
                return 0;
 
-       mtk_iommu_config(data, dev, false, 0);
+       mtk_iommu_config(mtkdev->iommu, dev, false, 0);
        return 0;
 }
 
@@ -860,14 +869,28 @@ static phys_addr_t mtk_iommu_iova_to_phys(struct 
iommu_domain *domain,
        return pa;
 }
 
-static struct iommu_device *mtk_iommu_probe_device(struct device *dev)
+static struct iommu_device *
+mtk_iommu_probe_device(struct iommu_probe_info *pinf)
 {
-       struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
-       struct mtk_iommu_data *data = dev_iommu_priv_get(dev);
+       struct mtk_iommu_device *mtkdev;
+       struct device *dev = pinf->dev;
+       struct mtk_iommu_data *data;
        struct device_link *link;
        struct device *larbdev;
        unsigned int larbid, larbidx, i;
 
+       data = iommu_of_get_single_iommu(pinf, &mtk_iommu_ops, 1,
+                                        struct mtk_iommu_data, iommu);
+       if (IS_ERR(data))
+               return ERR_CAST(data);
+
+       mtkdev = iommu_fw_alloc_per_device_ids(pinf, mtkdev);
+       if (IS_ERR(mtkdev))
+               return ERR_CAST(mtkdev);
+       mtkdev->iommu = data;
+
+       dev_iommu_priv_set(dev, mtkdev);
+
        if (!MTK_IOMMU_IS_TYPE(data->plat_data, MTK_IOMMU_TYPE_MM))
                return &data->iommu;
 
@@ -876,42 +899,46 @@ static struct iommu_device *mtk_iommu_probe_device(struct 
device *dev)
         * The device that connects with each a larb is a independent HW.
         * All the ports in each a device should be in the same larbs.
         */
-       larbid = MTK_M4U_TO_LARB(fwspec->ids[0]);
+       larbid = MTK_M4U_TO_LARB(mtkdev->ids[0]);
        if (larbid >= MTK_LARB_NR_MAX)
-               return ERR_PTR(-EINVAL);
+               goto err_out;
 
-       for (i = 1; i < fwspec->num_ids; i++) {
-               larbidx = MTK_M4U_TO_LARB(fwspec->ids[i]);
+       for (i = 1; i < mtkdev->num_ids; i++) {
+               larbidx = MTK_M4U_TO_LARB(mtkdev->ids[i]);
                if (larbid != larbidx) {
                        dev_err(dev, "Can only use one larb. Fail@larb%d-%d.\n",
                                larbid, larbidx);
-                       return ERR_PTR(-EINVAL);
+                       goto err_out;
                }
        }
        larbdev = data->larb_imu[larbid].dev;
        if (!larbdev)
-               return ERR_PTR(-EINVAL);
+               goto err_out;
 
        link = device_link_add(dev, larbdev,
                               DL_FLAG_PM_RUNTIME | DL_FLAG_STATELESS);
        if (!link)
                dev_err(dev, "Unable to link %s\n", dev_name(larbdev));
        return &data->iommu;
+
+err_out:
+       kfree(mtkdev);
+       return ERR_PTR(-EINVAL);
 }
 
 static void mtk_iommu_release_device(struct device *dev)
 {
-       struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
-       struct mtk_iommu_data *data;
+       struct mtk_iommu_device *mtkdev = dev_iommu_priv_get(dev);
+       struct mtk_iommu_data *data = mtkdev->iommu;
        struct device *larbdev;
        unsigned int larbid;
 
-       data = dev_iommu_priv_get(dev);
        if (MTK_IOMMU_IS_TYPE(data->plat_data, MTK_IOMMU_TYPE_MM)) {
-               larbid = MTK_M4U_TO_LARB(fwspec->ids[0]);
+               larbid = MTK_M4U_TO_LARB(mtkdev->ids[0]);
                larbdev = data->larb_imu[larbid].dev;
                device_link_remove(dev, larbdev);
        }
+       kfree(mtkdev);
 }
 
 static int mtk_iommu_get_group_id(struct device *dev, const struct 
mtk_iommu_plat_data *plat_data)
@@ -931,7 +958,9 @@ static int mtk_iommu_get_group_id(struct device *dev, const 
struct mtk_iommu_pla
 
 static struct iommu_group *mtk_iommu_device_group(struct device *dev)
 {
-       struct mtk_iommu_data *c_data = dev_iommu_priv_get(dev), *data;
+       struct mtk_iommu_device *mtkdev = dev_iommu_priv_get(dev);
+       struct mtk_iommu_data *c_data = mtkdev->iommu;
+       struct mtk_iommu_data *data;
        struct list_head *hw_list = c_data->hw_list;
        struct iommu_group *group;
        int groupid;
@@ -957,32 +986,11 @@ static struct iommu_group *mtk_iommu_device_group(struct 
device *dev)
        return group;
 }
 
-static int mtk_iommu_of_xlate(struct device *dev, struct of_phandle_args *args)
-{
-       struct platform_device *m4updev;
-
-       if (args->args_count != 1) {
-               dev_err(dev, "invalid #iommu-cells(%d) property for IOMMU\n",
-                       args->args_count);
-               return -EINVAL;
-       }
-
-       if (!dev_iommu_priv_get(dev)) {
-               /* Get the m4u device */
-               m4updev = of_find_device_by_node(args->np);
-               if (WARN_ON(!m4updev))
-                       return -EINVAL;
-
-               dev_iommu_priv_set(dev, platform_get_drvdata(m4updev));
-       }
-
-       return iommu_fwspec_add_ids(dev, args->args, 1);
-}
-
 static void mtk_iommu_get_resv_regions(struct device *dev,
                                       struct list_head *head)
 {
-       struct mtk_iommu_data *data = dev_iommu_priv_get(dev);
+       struct mtk_iommu_device *mtkdev = dev_iommu_priv_get(dev);
+       struct mtk_iommu_data *data = mtkdev->iommu;
        unsigned int regionid = mtk_iommu_get_iova_region_id(dev, 
data->plat_data), i;
        const struct mtk_iommu_iova_region *resv, *curdom;
        struct iommu_resv_region *region;
@@ -1012,10 +1020,10 @@ static void mtk_iommu_get_resv_regions(struct device 
*dev,
 static const struct iommu_ops mtk_iommu_ops = {
        .identity_domain = &mtk_iommu_identity_domain,
        .domain_alloc_paging = mtk_iommu_domain_alloc_paging,
-       .probe_device   = mtk_iommu_probe_device,
+       .probe_device_pinf = mtk_iommu_probe_device,
        .release_device = mtk_iommu_release_device,
        .device_group   = mtk_iommu_device_group,
-       .of_xlate       = mtk_iommu_of_xlate,
+       .of_xlate = iommu_dummy_of_xlate,
        .get_resv_regions = mtk_iommu_get_resv_regions,
        .pgsize_bitmap  = SZ_4K | SZ_64K | SZ_1M | SZ_16M,
        .owner          = THIS_MODULE,
-- 
2.42.0


Reply via email to