The refcount_t type and its corresponding API protect reference counters
from accidental underflow and overflow, and thereby from the
use-after-free situations those can lead to. Convert
device_state->count from atomic_t to refcount_t.

Signed-off-by: Xiyu Yang <xiyuyan...@fudan.edu.cn>
Signed-off-by: Xin Tan <tanxin....@gmail.com>
---
 drivers/iommu/amd/iommu_v2.c | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/drivers/iommu/amd/iommu_v2.c b/drivers/iommu/amd/iommu_v2.c
index f8d4ad421e07..15d64f916fe5 100644
--- a/drivers/iommu/amd/iommu_v2.c
+++ b/drivers/iommu/amd/iommu_v2.c
@@ -6,6 +6,7 @@
 
 #define pr_fmt(fmt)     "AMD-Vi: " fmt
 
+#include <linux/refcount.h>
 #include <linux/mmu_notifier.h>
 #include <linux/amd-iommu.h>
 #include <linux/mm_types.h>
@@ -51,7 +52,7 @@ struct pasid_state {
 struct device_state {
        struct list_head list;
        u16 devid;
-       atomic_t count;
+       refcount_t count;
        struct pci_dev *pdev;
        struct pasid_state **states;
        struct iommu_domain *domain;
@@ -113,7 +114,7 @@ static struct device_state *get_device_state(u16 devid)
        spin_lock_irqsave(&state_lock, flags);
        dev_state = __get_device_state(devid);
        if (dev_state != NULL)
-               atomic_inc(&dev_state->count);
+               refcount_inc(&dev_state->count);
        spin_unlock_irqrestore(&state_lock, flags);
 
        return dev_state;
@@ -144,7 +145,7 @@ static void free_device_state(struct device_state *dev_state)
 
 static void put_device_state(struct device_state *dev_state)
 {
-       if (atomic_dec_and_test(&dev_state->count))
+       if (refcount_dec_and_test(&dev_state->count))
                wake_up(&dev_state->wq);
 }
 
@@ -765,7 +766,7 @@ int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
        for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
                dev_state->pasid_levels += 1;
 
-       atomic_set(&dev_state->count, 1);
+       refcount_set(&dev_state->count, 1);
        dev_state->max_pasids = pasids;
 
        ret = -ENOMEM;
@@ -856,7 +857,7 @@ void amd_iommu_free_device(struct pci_dev *pdev)
         * Wait until the last reference is dropped before freeing
         * the device state.
         */
-       wait_event(dev_state->wq, !atomic_read(&dev_state->count));
+       wait_event(dev_state->wq, !refcount_read(&dev_state->count));
        free_device_state(dev_state);
 }
 EXPORT_SYMBOL(amd_iommu_free_device);
-- 
2.7.4

_______________________________________________
iommu mailing list
iommu@lists.linux-foundation.org
https://lists.linuxfoundation.org/mailman/listinfo/iommu

Reply via email to