struct vfio_pci_vendor_driver_ops consists of two parts:
(1) .probe() and .remove() callbacks, which are called by vfio_pci_probe()
and vfio_pci_remove() respectively;
(2) a pointer to a struct vfio_device_ops, which is registered as the ops of
the vfio device if .probe() succeeds.

Suggested-by: Alex Williamson <alex.william...@redhat.com>
Signed-off-by: Yan Zhao <yan.y.z...@intel.com>
---
 drivers/vfio/pci/vfio_pci.c         | 102 +++++++++++++++++++++++++++-
 drivers/vfio/pci/vfio_pci_private.h |   7 ++
 include/linux/vfio.h                |   9 +++
 3 files changed, 117 insertions(+), 1 deletion(-)

diff --git a/drivers/vfio/pci/vfio_pci.c b/drivers/vfio/pci/vfio_pci.c
index 6c6b37b5c04e..43d10d34cbc2 100644
--- a/drivers/vfio/pci/vfio_pci.c
+++ b/drivers/vfio/pci/vfio_pci.c
@@ -68,6 +68,11 @@ static inline bool vfio_vga_disabled(void)
 #endif
 }
 
+/*
+ * Registry of vfio-pci vendor drivers.
+ *
+ * Statically initialized: vfio_pci_probe() -> probe_vendor_drivers() can
+ * run as soon as vfio_pci_init() registers the PCI driver or adds dynamic
+ * IDs via vfio_pci_fill_ids(), i.e. before any runtime initialization at
+ * the end of vfio_pci_init() would have executed.
+ */
+static struct vfio_pci {
+       struct  mutex           vendor_drivers_lock;
+       struct  list_head       vendor_drivers_list;
+} vfio_pci = {
+       .vendor_drivers_lock =
+               __MUTEX_INITIALIZER(vfio_pci.vendor_drivers_lock),
+       .vendor_drivers_list = LIST_HEAD_INIT(vfio_pci.vendor_drivers_list),
+};
+
 /*
  * Our VGA arbiter participation is limited since we don't know anything
  * about the device itself.  However, if the device is the only VGA device
@@ -1570,6 +1575,35 @@ static int vfio_pci_bus_notifier(struct notifier_block 
*nb,
        return 0;
 }
 
+/*
+ * probe_vendor_drivers - offer @vdev to each registered vendor driver
+ * @vdev: vfio-pci device currently being probed
+ *
+ * First asks for the "vfio-pci:<vendor>-<device>" module alias so that a
+ * matching vendor module can load and register itself.  The first vendor
+ * driver whose ->probe() does not return an ERR_PTR() claims the device:
+ * the driver entry and its private data are stored in @vdev, and the
+ * reference taken on the vendor module is kept (vfio_pci_remove() drops it).
+ *
+ * Return: 0 if a vendor driver claimed the device, -ENODEV otherwise.
+ */
+static int probe_vendor_drivers(struct vfio_pci_device *vdev)
+{
+       struct pci_dev *pdev = vdev->pdev;
+       struct vfio_pci_vendor_driver *candidate;
+       bool claimed = false;
+
+       request_module("vfio-pci:%x-%x", pdev->vendor, pdev->device);
+
+       mutex_lock(&vfio_pci.vendor_drivers_lock);
+       list_for_each_entry(candidate, &vfio_pci.vendor_drivers_list, next) {
+               void *priv;
+
+               /* Skip vendor modules that are on their way out */
+               if (!try_module_get(candidate->ops->owner))
+                       continue;
+
+               priv = candidate->ops->probe(pdev);
+               if (!IS_ERR(priv)) {
+                       vdev->vendor_driver = candidate;
+                       vdev->vendor_data = priv;
+                       claimed = true;
+                       break;
+               }
+
+               /* Declined: release the reference taken above */
+               module_put(candidate->ops->owner);
+       }
+       mutex_unlock(&vfio_pci.vendor_drivers_lock);
+
+       return claimed ? 0 : -ENODEV;
+}
+
 static int vfio_pci_probe(struct pci_dev *pdev, const struct pci_device_id *id)
 {
        struct vfio_pci_device *vdev;
@@ -1609,7 +1643,12 @@ static int vfio_pci_probe(struct pci_dev *pdev, const 
struct pci_device_id *id)
        mutex_init(&vdev->ioeventfds_lock);
        INIT_LIST_HEAD(&vdev->ioeventfds_list);
 
-       ret = vfio_add_group_dev(&pdev->dev, &vfio_pci_ops, vdev);
+       if (probe_vendor_drivers(vdev))
+               ret = vfio_add_group_dev(&pdev->dev, &vfio_pci_ops, vdev);
+       else
+               ret = vfio_add_group_dev(&pdev->dev,
+                                        vdev->vendor_driver->ops->device_ops,
+                                        vdev);
        if (ret)
                goto out_free;
 
@@ -1698,6 +1737,11 @@ static void vfio_pci_remove(struct pci_dev *pdev)
        if (!disable_idle_d3)
                vfio_pci_set_power_state(vdev, PCI_D0);
 
+       if (vdev->vendor_driver) {
+               vdev->vendor_driver->ops->remove(vdev->vendor_data);
+               module_put(vdev->vendor_driver->ops->owner);
+       }
+
        kfree(vdev->pm_save);
        kfree(vdev);
 
@@ -2035,6 +2079,8 @@ static int __init vfio_pci_init(void)
 
        vfio_pci_fill_ids();
 
+       mutex_init(&vfio_pci.vendor_drivers_lock);
+       INIT_LIST_HEAD(&vfio_pci.vendor_drivers_list);
        return 0;
 
 out_driver:
@@ -2042,6 +2088,60 @@ static int __init vfio_pci_init(void)
        return ret;
 }
 
+/**
+ * __vfio_pci_register_vendor_driver - register a vfio-pci vendor driver
+ * @ops: vendor driver ops; ->probe, ->remove and ->device_ops are mandatory,
+ *       ->owner is the module whose reference is taken per claimed device
+ *
+ * Publishes @ops on the list consulted by vfio_pci_probe() and pins the
+ * vfio-pci module until vfio_pci_unregister_vendor_driver() is called with
+ * the same ->device_ops.
+ *
+ * Return: 0 on success; -EINVAL if @ops is incomplete or its ->device_ops
+ * is already registered; -ENOMEM on allocation failure; -ENODEV if a
+ * reference on vfio-pci cannot be taken.
+ */
+int __vfio_pci_register_vendor_driver(struct vfio_pci_vendor_driver_ops *ops)
+{
+       struct vfio_pci_vendor_driver *driver, *tmp;
+       int ret;
+
+       /* probe_vendor_drivers() and vfio_pci_remove() call these blindly */
+       if (!ops || !ops->probe || !ops->remove || !ops->device_ops)
+               return -EINVAL;
+
+       /*
+        * Take the module reference before publishing the entry, so that a
+        * failure here cannot leave a listed driver whose unregister would
+        * drop a reference that was never taken.
+        */
+       if (!try_module_get(THIS_MODULE))
+               return -ENODEV;
+
+       driver = kzalloc(sizeof(*driver), GFP_KERNEL);
+       if (!driver) {
+               ret = -ENOMEM;
+               goto out_put;
+       }
+
+       driver->ops = ops;
+
+       mutex_lock(&vfio_pci.vendor_drivers_lock);
+
+       /* Each device_ops may be registered only once */
+       list_for_each_entry(tmp, &vfio_pci.vendor_drivers_list, next) {
+               if (tmp->ops->device_ops == ops->device_ops) {
+                       ret = -EINVAL;
+                       goto out_unlock;
+               }
+       }
+
+       list_add(&driver->next, &vfio_pci.vendor_drivers_list);
+       mutex_unlock(&vfio_pci.vendor_drivers_lock);
+       return 0;
+
+out_unlock:
+       mutex_unlock(&vfio_pci.vendor_drivers_lock);
+       kfree(driver);
+out_put:
+       module_put(THIS_MODULE);
+       return ret;
+}
+EXPORT_SYMBOL_GPL(__vfio_pci_register_vendor_driver);
+
+/**
+ * vfio_pci_unregister_vendor_driver - undo __vfio_pci_register_vendor_driver()
+ * @device_ops: the ->device_ops the vendor driver was registered with
+ *
+ * Unlinks the matching entry from the vendor driver list, frees it and
+ * drops the vfio-pci module reference taken at registration.  An unknown
+ * @device_ops is silently ignored.
+ *
+ * NOTE(review): bound devices' vdev->vendor_driver is not cleared here;
+ * presumably safe only because each bound device pins the vendor module
+ * (->owner) so this runs after they are gone — confirm.
+ */
+void vfio_pci_unregister_vendor_driver(struct vfio_device_ops *device_ops)
+{
+       struct vfio_pci_vendor_driver *iter, *victim = NULL;
+
+       mutex_lock(&vfio_pci.vendor_drivers_lock);
+       list_for_each_entry(iter, &vfio_pci.vendor_drivers_list, next) {
+               if (iter->ops->device_ops != device_ops)
+                       continue;
+
+               /* Unlink under the lock; iteration stops right after */
+               list_del(&iter->next);
+               victim = iter;
+               break;
+       }
+       mutex_unlock(&vfio_pci.vendor_drivers_lock);
+
+       if (!victim)
+               return;
+
+       kfree(victim);
+       module_put(THIS_MODULE);
+}
+EXPORT_SYMBOL_GPL(vfio_pci_unregister_vendor_driver);
+
 module_init(vfio_pci_init);
 module_exit(vfio_pci_cleanup);
 
diff --git a/drivers/vfio/pci/vfio_pci_private.h 
b/drivers/vfio/pci/vfio_pci_private.h
index 36ec69081ecd..7758a20546fa 100644
--- a/drivers/vfio/pci/vfio_pci_private.h
+++ b/drivers/vfio/pci/vfio_pci_private.h
@@ -92,6 +92,11 @@ struct vfio_pci_vf_token {
        int                     users;
 };
 
+/*
+ * One registered vendor driver: wraps the caller-supplied ops and links
+ * them onto vfio-pci's vendor driver list.
+ */
+struct vfio_pci_vendor_driver {
+       const struct vfio_pci_vendor_driver_ops *ops;
+       struct list_head                        next;   /* vendor driver list linkage */
+};
+
 struct vfio_pci_device {
        struct pci_dev          *pdev;
        void __iomem            *barmap[PCI_STD_NUM_BARS];
@@ -132,6 +137,8 @@ struct vfio_pci_device {
        struct list_head        ioeventfds_list;
        struct vfio_pci_vf_token        *vf_token;
        struct notifier_block   nb;
+       void                    *vendor_data;
+       struct vfio_pci_vendor_driver   *vendor_driver;
 };
 
 #define is_intx(vdev) (vdev->irq_type == VFIO_PCI_INTX_IRQ_INDEX)
diff --git a/include/linux/vfio.h b/include/linux/vfio.h
index 38d3c6a8dc7e..3e53deb012b6 100644
--- a/include/linux/vfio.h
+++ b/include/linux/vfio.h
@@ -214,4 +214,13 @@ extern int vfio_virqfd_enable(void *opaque,
                              void *data, struct virqfd **pvirqfd, int fd);
 extern void vfio_virqfd_disable(struct virqfd **pvirqfd);
 
+struct pci_dev;
+
+/**
+ * struct vfio_pci_vendor_driver_ops - vendor extension hooks for vfio-pci
+ * @name: human-readable name of the vendor driver
+ * @owner: module providing the callbacks; vfio-pci holds a reference on it
+ *         for as long as a device is claimed by this vendor driver
+ * @probe: called from vfio_pci_probe(); return an ERR_PTR() to decline the
+ *         device, any other value (including NULL) claims it and is handed
+ *         back to @remove
+ * @remove: called from vfio_pci_remove() with the value returned by @probe
+ * @device_ops: registered as the vfio device ops when @probe claims a device
+ */
+struct vfio_pci_vendor_driver_ops {
+       const char              *name;
+       struct module           *owner;
+       void                    *(*probe)(struct pci_dev *pdev);
+       void                    (*remove)(void *vendor_data);
+       struct vfio_device_ops *device_ops;
+};
+int __vfio_pci_register_vendor_driver(struct vfio_pci_vendor_driver_ops *ops);
+void vfio_pci_unregister_vendor_driver(struct vfio_device_ops *device_ops);
 #endif /* VFIO_H */
-- 
2.17.1

Reply via email to