This call chain uses dev->iommu->fwspec to pass the fwspec around
between its three parts (acpi_iommu_configure_id(),
acpi_iommu_fwspec_init() and iommu_probe_device()).

However, there is no locking around the accesses to dev->iommu, so this is
all racy.

Allocate a clean, local fwspec at the start of acpi_iommu_configure_id(),
pass it through all functions on the stack to fill it with data, and
finally pass it into iommu_probe_device_fwspec(), which takes ownership
of it and loads it into dev->iommu under a lock.

Acked-by: Rafael J. Wysocki <rafael.j.wysocki@intel.com>
Reviewed-by: Jerry Snitselaar <jsnitsel@redhat.com>
Reviewed-by: Moritz Fischer <mdf@kernel.org>
Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
---
 drivers/acpi/arm64/iort.c | 42 +++++++++---------
 drivers/acpi/scan.c       | 89 ++++++++++++++++++---------------------
 drivers/acpi/viot.c       | 45 +++++++++++---------
 drivers/iommu/iommu.c     |  5 +--
 include/acpi/acpi_bus.h   |  8 ++--
 include/linux/acpi_iort.h |  8 +++-
 include/linux/acpi_viot.h |  5 ++-
 include/linux/iommu.h     |  2 +
 8 files changed, 103 insertions(+), 101 deletions(-)

diff --git a/drivers/acpi/arm64/iort.c b/drivers/acpi/arm64/iort.c
index 6496ff5a6ba20d..b08682282ee5a7 100644
--- a/drivers/acpi/arm64/iort.c
+++ b/drivers/acpi/arm64/iort.c
@@ -1218,10 +1218,9 @@ static bool iort_pci_rc_supports_ats(struct acpi_iort_node *node)
        return pci_rc->ats_attribute & ACPI_IORT_ATS_SUPPORTED;
 }
 
-static int iort_iommu_xlate(struct device *dev, struct acpi_iort_node *node,
-                           u32 streamid)
+static int iort_iommu_xlate(struct iommu_fwspec *fwspec, struct device *dev,
+                           struct acpi_iort_node *node, u32 streamid)
 {
-       const struct iommu_ops *ops;
        struct fwnode_handle *iort_fwnode;
 
        if (!node)
@@ -1239,17 +1238,14 @@ static int iort_iommu_xlate(struct device *dev, struct acpi_iort_node *node,
         * in the kernel or not, defer the IOMMU configuration
         * or just abort it.
         */
-       ops = iommu_ops_from_fwnode(iort_fwnode);
-       if (!ops)
-               return iort_iommu_driver_enabled(node->type) ?
-                      -EPROBE_DEFER : -ENODEV;
-
-       return acpi_iommu_fwspec_init(dev, streamid, iort_fwnode, ops);
+       return acpi_iommu_fwspec_init(fwspec, dev, streamid, iort_fwnode,
+                                     iort_iommu_driver_enabled(node->type));
 }
 
 struct iort_pci_alias_info {
        struct device *dev;
        struct acpi_iort_node *node;
+       struct iommu_fwspec *fwspec;
 };
 
 static int iort_pci_iommu_init(struct pci_dev *pdev, u16 alias, void *data)
@@ -1260,7 +1256,7 @@ static int iort_pci_iommu_init(struct pci_dev *pdev, u16 alias, void *data)
 
        parent = iort_node_map_id(info->node, alias, &streamid,
                                  IORT_IOMMU_TYPE);
-       return iort_iommu_xlate(info->dev, parent, streamid);
+       return iort_iommu_xlate(info->fwspec, info->dev, parent, streamid);
 }
 
 static void iort_named_component_init(struct device *dev,
@@ -1280,7 +1276,8 @@ static void iort_named_component_init(struct device *dev,
                dev_warn(dev, "Could not add device properties\n");
 }
 
-static int iort_nc_iommu_map(struct device *dev, struct acpi_iort_node *node)
+static int iort_nc_iommu_map(struct iommu_fwspec *fwspec, struct device *dev,
+                            struct acpi_iort_node *node)
 {
        struct acpi_iort_node *parent;
        int err = -ENODEV, i = 0;
@@ -1293,13 +1290,13 @@ static int iort_nc_iommu_map(struct device *dev, struct acpi_iort_node *node)
                                                   i++);
 
                if (parent)
-                       err = iort_iommu_xlate(dev, parent, streamid);
+                       err = iort_iommu_xlate(fwspec, dev, parent, streamid);
        } while (parent && !err);
 
        return err;
 }
 
-static int iort_nc_iommu_map_id(struct device *dev,
+static int iort_nc_iommu_map_id(struct iommu_fwspec *fwspec, struct device *dev,
                                struct acpi_iort_node *node,
                                const u32 *in_id)
 {
@@ -1308,7 +1305,7 @@ static int iort_nc_iommu_map_id(struct device *dev,
 
        parent = iort_node_map_id(node, *in_id, &streamid, IORT_IOMMU_TYPE);
        if (parent)
-               return iort_iommu_xlate(dev, parent, streamid);
+               return iort_iommu_xlate(fwspec, dev, parent, streamid);
 
        return -ENODEV;
 }
@@ -1317,20 +1314,22 @@ static int iort_nc_iommu_map_id(struct device *dev,
 /**
  * iort_iommu_configure_id - Set-up IOMMU configuration for a device.
  *
+ * @fwspec: The iommu_fwspec to fill in with fw information about the device
  * @dev: device to configure
  * @id_in: optional input id const value pointer
  *
  * Returns: 0 on success, <0 on failure
  */
-int iort_iommu_configure_id(struct device *dev, const u32 *id_in)
+int iort_iommu_configure_id(struct iommu_fwspec *fwspec, struct device *dev,
+                           const u32 *id_in)
 {
        struct acpi_iort_node *node;
        int err = -ENODEV;
 
        if (dev_is_pci(dev)) {
-               struct iommu_fwspec *fwspec;
                struct pci_bus *bus = to_pci_dev(dev)->bus;
-               struct iort_pci_alias_info info = { .dev = dev };
+               struct iort_pci_alias_info info = { .dev = dev,
+                                                   .fwspec = fwspec };
 
                node = iort_scan_node(ACPI_IORT_NODE_PCI_ROOT_COMPLEX,
                                      iort_match_node_callback, &bus->dev);
@@ -1341,8 +1340,7 @@ int iort_iommu_configure_id(struct device *dev, const u32 *id_in)
                err = pci_for_each_dma_alias(to_pci_dev(dev),
                                             iort_pci_iommu_init, &info);
 
-               fwspec = dev_iommu_fwspec_get(dev);
-               if (fwspec && iort_pci_rc_supports_ats(node))
+               if (iort_pci_rc_supports_ats(node))
                        fwspec->flags |= IOMMU_FWSPEC_PCI_RC_ATS;
        } else {
                node = iort_scan_node(ACPI_IORT_NODE_NAMED_COMPONENT,
@@ -1350,8 +1348,8 @@ int iort_iommu_configure_id(struct device *dev, const u32 *id_in)
                if (!node)
                        return -ENODEV;
 
-               err = id_in ? iort_nc_iommu_map_id(dev, node, id_in) :
-                             iort_nc_iommu_map(dev, node);
+               err = id_in ? iort_nc_iommu_map_id(fwspec, dev, node, id_in) :
+                             iort_nc_iommu_map(fwspec, dev, node);
 
                if (!err)
                        iort_named_component_init(dev, node);
@@ -1363,8 +1361,6 @@ int iort_iommu_configure_id(struct device *dev, const u32 *id_in)
 #else
 void iort_iommu_get_resv_regions(struct device *dev, struct list_head *head)
 { }
-int iort_iommu_configure_id(struct device *dev, const u32 *input_id)
-{ return -ENODEV; }
 #endif
 
 static int nc_dma_get_range(struct device *dev, u64 *size)
diff --git a/drivers/acpi/scan.c b/drivers/acpi/scan.c
index d171d193f2a51c..5d467ff58ff24b 100644
--- a/drivers/acpi/scan.c
+++ b/drivers/acpi/scan.c
@@ -1543,74 +1543,67 @@ int acpi_dma_get_range(struct device *dev, const struct bus_dma_region **map)
 }
 
 #ifdef CONFIG_IOMMU_API
-int acpi_iommu_fwspec_init(struct device *dev, u32 id,
-                          struct fwnode_handle *fwnode,
-                          const struct iommu_ops *ops)
+int acpi_iommu_fwspec_init(struct iommu_fwspec *fwspec, struct device *dev,
+                          u32 id, struct fwnode_handle *fwnode,
+                          bool iommu_driver_available)
 {
-       int ret = iommu_fwspec_init(dev, fwnode, ops);
+       int ret;
 
-       if (!ret)
-               ret = iommu_fwspec_add_ids(dev, &id, 1);
-
-       return ret;
-}
-
-static inline const struct iommu_ops *acpi_iommu_fwspec_ops(struct device *dev)
-{
-       struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
-
-       return fwspec ? fwspec->ops : NULL;
+       ret = iommu_fwspec_assign_iommu(fwspec, dev, fwnode);
+       if (ret) {
+               if (ret == -EPROBE_DEFER && !iommu_driver_available)
+                       return -ENODEV;
+               return ret;
+       }
+       return iommu_fwspec_append_ids(fwspec, &id, 1);
 }
 
 static int acpi_iommu_configure_id(struct device *dev, const u32 *id_in)
 {
        int err;
-       const struct iommu_ops *ops;
+       struct iommu_fwspec *fwspec;
 
-       /*
-        * If we already translated the fwspec there is nothing left to do,
-        * return the iommu_ops.
-        */
-       ops = acpi_iommu_fwspec_ops(dev);
-       if (ops)
-               return 0;
+       fwspec = iommu_fwspec_alloc();
+       if (IS_ERR(fwspec))
+               return PTR_ERR(fwspec);
 
-       err = iort_iommu_configure_id(dev, id_in);
-       if (err && err != -EPROBE_DEFER)
-               err = viot_iommu_configure(dev);
+       err = iort_iommu_configure_id(fwspec, dev, id_in);
+       if (err == -ENODEV)
+               err = viot_iommu_configure(fwspec, dev);
+       if (err == -ENODEV || err == -EPROBE_DEFER)
+               goto err_free;
+       if (err)
+               goto err_log;
 
-       /*
-        * If we have reason to believe the IOMMU driver missed the initial
-        * iommu_probe_device() call for dev, replay it to get things in order.
-        */
-       if (!err && dev->bus)
-               err = iommu_probe_device(dev);
-
-       /* Ignore all other errors apart from EPROBE_DEFER */
-       if (err == -EPROBE_DEFER) {
-               return err;
-       } else if (err) {
-               dev_dbg(dev, "Adding to IOMMU failed: %d\n", err);
-               return -ENODEV;
+       err = iommu_probe_device_fwspec(dev, fwspec);
+       if (err) {
+               /*
+                * iommu_probe_device_fwspec() always takes ownership of
+                * fwspec, even on failure; do not free it here.
+                */
+               fwspec = NULL;
+               goto err_log;
        }
-       if (!acpi_iommu_fwspec_ops(dev))
-               return -ENODEV;
-       return 0;
+       return 0;
+err_log:
+       dev_dbg(dev, "Adding to IOMMU failed: %d\n", err);
+err_free:
+       iommu_fwspec_dealloc(fwspec);
+       return err;
 }
 
 #else /* !CONFIG_IOMMU_API */
 
-int acpi_iommu_fwspec_init(struct device *dev, u32 id,
-                          struct fwnode_handle *fwnode,
-                          const struct iommu_ops *ops)
+int acpi_iommu_fwspec_init(struct iommu_fwspec *fwspec, struct device *dev,
+                          u32 id, struct fwnode_handle *fwnode,
+                          bool iommu_driver_available)
 {
        return -ENODEV;
 }
 
-static const struct iommu_ops *acpi_iommu_configure_id(struct device *dev,
-                                                      const u32 *id_in)
+static int acpi_iommu_configure_id(struct device *dev, const u32 *id_in)
 {
-       return NULL;
+       return -ENODEV;
 }
 
 #endif /* !CONFIG_IOMMU_API */
diff --git a/drivers/acpi/viot.c b/drivers/acpi/viot.c
index c8025921c129b2..1d0da99e65e6af 100644
--- a/drivers/acpi/viot.c
+++ b/drivers/acpi/viot.c
@@ -304,11 +304,9 @@ void __init acpi_viot_init(void)
        acpi_put_table(hdr);
 }
 
-static int viot_dev_iommu_init(struct device *dev, struct viot_iommu *viommu,
-                              u32 epid)
+static int viot_dev_iommu_init(struct iommu_fwspec *fwspec, struct device *dev,
+                              struct viot_iommu *viommu, u32 epid)
 {
-       const struct iommu_ops *ops;
-
        if (!viommu)
                return -ENODEV;
 
@@ -316,19 +314,20 @@ static int viot_dev_iommu_init(struct device *dev, struct viot_iommu *viommu,
        if (device_match_fwnode(dev, viommu->fwnode))
                return -EINVAL;
 
-       ops = iommu_ops_from_fwnode(viommu->fwnode);
-       if (!ops)
-               return IS_ENABLED(CONFIG_VIRTIO_IOMMU) ?
-                       -EPROBE_DEFER : -ENODEV;
-
-       return acpi_iommu_fwspec_init(dev, epid, viommu->fwnode, ops);
+       return acpi_iommu_fwspec_init(fwspec, dev, epid, viommu->fwnode,
+                                     IS_ENABLED(CONFIG_VIRTIO_IOMMU));
 }
 
+struct viot_pci_alias_info {
+       struct device *dev;
+       struct iommu_fwspec *fwspec;
+};
+
 static int viot_pci_dev_iommu_init(struct pci_dev *pdev, u16 dev_id, void *data)
 {
        u32 epid;
        struct viot_endpoint *ep;
-       struct device *aliased_dev = data;
+       struct viot_pci_alias_info *info = data;
        u32 domain_nr = pci_domain_nr(pdev->bus);
 
        list_for_each_entry(ep, &viot_pci_ranges, list) {
@@ -339,14 +338,15 @@ static int viot_pci_dev_iommu_init(struct pci_dev *pdev, u16 dev_id, void *data)
                        epid = ((domain_nr - ep->segment_start) << 16) +
                                dev_id - ep->bdf_start + ep->endpoint_id;
 
-                       return viot_dev_iommu_init(aliased_dev, ep->viommu,
-                                                  epid);
+                       return viot_dev_iommu_init(info->fwspec, info->dev,
+                                                  ep->viommu, epid);
                }
        }
        return -ENODEV;
 }
 
-static int viot_mmio_dev_iommu_init(struct platform_device *pdev)
+static int viot_mmio_dev_iommu_init(struct iommu_fwspec *fwspec,
+                                   struct platform_device *pdev)
 {
        struct resource *mem;
        struct viot_endpoint *ep;
@@ -357,24 +357,29 @@ static int viot_mmio_dev_iommu_init(struct platform_device *pdev)
 
        list_for_each_entry(ep, &viot_mmio_endpoints, list) {
                if (ep->address == mem->start)
-                       return viot_dev_iommu_init(&pdev->dev, ep->viommu,
-                                                  ep->endpoint_id);
+                       return viot_dev_iommu_init(fwspec, &pdev->dev,
+                                                  ep->viommu, ep->endpoint_id);
        }
        return -ENODEV;
 }
 
 /**
  * viot_iommu_configure - Setup IOMMU ops for an endpoint described by VIOT
+ * @fwspec: The iommu_fwspec to fill in with fw information about the device
  * @dev: the endpoint
  *
  * Return: 0 on success, <0 on failure
  */
-int viot_iommu_configure(struct device *dev)
+int viot_iommu_configure(struct iommu_fwspec *fwspec, struct device *dev)
 {
-       if (dev_is_pci(dev))
+       if (dev_is_pci(dev)) {
+               struct viot_pci_alias_info info = { .dev = dev,
+                                                   .fwspec = fwspec };
                return pci_for_each_dma_alias(to_pci_dev(dev),
-                                             viot_pci_dev_iommu_init, dev);
+                                             viot_pci_dev_iommu_init, &info);
+       }
        else if (dev_is_platform(dev))
-               return viot_mmio_dev_iommu_init(to_platform_device(dev));
+               return viot_mmio_dev_iommu_init(fwspec,
+                                               to_platform_device(dev));
        return -ENODEV;
 }
diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
index 2076345d0d6653..f7bda1c0959d34 100644
--- a/drivers/iommu/iommu.c
+++ b/drivers/iommu/iommu.c
@@ -2943,9 +2943,8 @@ const struct iommu_ops *iommu_ops_from_fwnode(struct fwnode_handle *fwnode)
        return ops;
 }
 
-static int iommu_fwspec_assign_iommu(struct iommu_fwspec *fwspec,
-                                    struct device *dev,
-                                    struct fwnode_handle *iommu_fwnode)
+int iommu_fwspec_assign_iommu(struct iommu_fwspec *fwspec, struct device *dev,
+                             struct fwnode_handle *iommu_fwnode)
 {
        const struct iommu_ops *ops;
 
diff --git a/include/acpi/acpi_bus.h b/include/acpi/acpi_bus.h
index afeed6e72049e4..4197c0004a30b0 100644
--- a/include/acpi/acpi_bus.h
+++ b/include/acpi/acpi_bus.h
@@ -12,6 +12,8 @@
 #include <linux/device.h>
 #include <linux/property.h>
 
+struct iommu_fwspec;
+
 struct acpi_handle_list {
        u32 count;
        acpi_handle* handles;
@@ -628,9 +630,9 @@ struct acpi_pci_root {
 
 bool acpi_dma_supported(const struct acpi_device *adev);
 enum dev_dma_attr acpi_get_dma_attr(struct acpi_device *adev);
-int acpi_iommu_fwspec_init(struct device *dev, u32 id,
-                          struct fwnode_handle *fwnode,
-                          const struct iommu_ops *ops);
+int acpi_iommu_fwspec_init(struct iommu_fwspec *fwspec, struct device *dev,
+                          u32 id, struct fwnode_handle *fwnode,
+                          bool iommu_driver_available);
 int acpi_dma_get_range(struct device *dev, const struct bus_dma_region **map);
 int acpi_dma_configure_id(struct device *dev, enum dev_dma_attr attr,
                           const u32 *input_id);
diff --git a/include/linux/acpi_iort.h b/include/linux/acpi_iort.h
index 1cb65592c95dd3..7644922ecb0705 100644
--- a/include/linux/acpi_iort.h
+++ b/include/linux/acpi_iort.h
@@ -11,6 +11,8 @@
 #include <linux/fwnode.h>
 #include <linux/irqdomain.h>
 
+struct iommu_fwspec;
+
 #define IORT_IRQ_MASK(irq)             (irq & 0xffffffffULL)
 #define IORT_IRQ_TRIGGER_MASK(irq)     ((irq >> 32) & 0xffffffffULL)
 
@@ -40,7 +42,8 @@ void iort_put_rmr_sids(struct fwnode_handle *iommu_fwnode,
                       struct list_head *head);
 /* IOMMU interface */
 int iort_dma_get_ranges(struct device *dev, u64 *size);
-int iort_iommu_configure_id(struct device *dev, const u32 *id_in);
+int iort_iommu_configure_id(struct iommu_fwspec *fwspec, struct device *dev,
+                           const u32 *id_in);
 void iort_iommu_get_resv_regions(struct device *dev, struct list_head *head);
 phys_addr_t acpi_iort_dma_get_max_cpu_address(void);
 #else
@@ -57,7 +60,8 @@ void iort_put_rmr_sids(struct fwnode_handle *iommu_fwnode, struct list_head *hea
 /* IOMMU interface */
 static inline int iort_dma_get_ranges(struct device *dev, u64 *size)
 { return -ENODEV; }
-static inline int iort_iommu_configure_id(struct device *dev, const u32 *id_in)
+static inline int iort_iommu_configure_id(struct iommu_fwspec *fwspec,
+                                         struct device *dev, const u32 *id_in)
 { return -ENODEV; }
 static inline
 void iort_iommu_get_resv_regions(struct device *dev, struct list_head *head)
diff --git a/include/linux/acpi_viot.h b/include/linux/acpi_viot.h
index a5a12243156377..f1874cb6d43c09 100644
--- a/include/linux/acpi_viot.h
+++ b/include/linux/acpi_viot.h
@@ -8,11 +8,12 @@
 #ifdef CONFIG_ACPI_VIOT
 void __init acpi_viot_early_init(void);
 void __init acpi_viot_init(void);
-int viot_iommu_configure(struct device *dev);
+int viot_iommu_configure(struct iommu_fwspec *fwspec, struct device *dev);
 #else
 static inline void acpi_viot_early_init(void) {}
 static inline void acpi_viot_init(void) {}
-static inline int viot_iommu_configure(struct device *dev)
+static inline int viot_iommu_configure(struct iommu_fwspec *fwspec,
+                                      struct device *dev)
 {
        return -ENODEV;
 }
diff --git a/include/linux/iommu.h b/include/linux/iommu.h
index bbbba7d0a159ba..72ec71bd31a376 100644
--- a/include/linux/iommu.h
+++ b/include/linux/iommu.h
@@ -818,6 +818,8 @@ void iommu_fwspec_dealloc(struct iommu_fwspec *fwspec);
 int iommu_fwspec_of_xlate(struct iommu_fwspec *fwspec, struct device *dev,
                          struct fwnode_handle *iommu_fwnode,
                          struct of_phandle_args *iommu_spec);
+int iommu_fwspec_assign_iommu(struct iommu_fwspec *fwspec, struct device *dev,
+                             struct fwnode_handle *iommu_fwnode);
 
 int iommu_fwspec_init(struct device *dev, struct fwnode_handle *iommu_fwnode,
                      const struct iommu_ops *ops);
-- 
2.42.0


Reply via email to