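Use the common workqueue (cmwq) mechanism for processing vhost work
by creating an unbound workqueue per vhost device. The backend workers
can still be shared and are not visible to vhost. The new behaviour is
selected by the read-only, experimental cmwq_worker module parameter.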

Signed-off-by: Bandan Das <[email protected]>
---
 drivers/vhost/vhost.c | 103 +++++++++++++++++++++++++++++++++++++++++++-------
 drivers/vhost/vhost.h |   2 +
 2 files changed, 92 insertions(+), 13 deletions(-)

diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index ad2146a..162e25e 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -30,6 +30,10 @@
 
 #include "vhost.h"
 
+static int cmwq_worker = 1;
+module_param(cmwq_worker, int, 0444);  /* read-only: fixed at load time */
+MODULE_PARM_DESC(cmwq_worker, "Use cmwq for worker threads - Experimental, 1 - Enable; 0 - Disable");
+
 static ushort max_mem_regions = 64;
 module_param(max_mem_regions, ushort, 0444);
 MODULE_PARM_DESC(max_mem_regions,
@@ -238,7 +242,10 @@ void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
                list_add_tail(&work->node, &dev->work_list);
                work->queue_seq++;
                spin_unlock_irqrestore(&dev->work_lock, flags);
-               wake_up_process(dev->worker);
+               if (cmwq_worker)
+                       queue_work(dev->qworker, &dev->qwork);
+               else
+                       wake_up_process(dev->worker);
        } else {
                spin_unlock_irqrestore(&dev->work_lock, flags);
        }
@@ -370,6 +377,52 @@ static void vhost_dev_free_iovecs(struct vhost_dev *dev)
                vhost_vq_free_iovecs(dev->vqs[i]);
 }
 
+/* cmwq handler: run queued vhost_work items with the owner's mm attached
+ * until dev->work_list is empty, waking flushers after each item. */
+static void vhost_wq_worker(struct work_struct *qwork)
+{
+       struct vhost_dev *dev = container_of(qwork, struct vhost_dev, qwork);
+       struct vhost_work *work;
+       unsigned seq;
+       struct mm_struct *prev_mm = NULL;
+       mm_segment_t oldfs = get_fs();
+
+       set_fs(USER_DS);
+
+       for (;;) {
+               spin_lock_irq(&dev->work_lock);
+               if (list_empty(&dev->work_list)) {
+                       spin_unlock_irq(&dev->work_lock);
+                       break;
+               }
+               work = list_first_entry(&dev->work_list,
+                                       struct vhost_work, node);
+               list_del_init(&work->node);
+               seq = work->queue_seq;
+               spin_unlock_irq(&dev->work_lock);
+
+               /* dev->mm is not guarded by work_lock: switch mm unlocked */
+               if (prev_mm != dev->mm) {
+                       if (prev_mm)
+                               unuse_mm(prev_mm);
+                       prev_mm = dev->mm;
+                       use_mm(prev_mm);
+               }
+               work->fn(work);
+
+               spin_lock_irq(&dev->work_lock);
+               work->done_seq = seq;
+               if (work->flushing)
+                       wake_up_all(&work->done);
+               spin_unlock_irq(&dev->work_lock);
+               cond_resched(); /* work may keep arriving; don't hog CPU */
+       }
+
+       if (prev_mm)
+               unuse_mm(prev_mm);
+       set_fs(oldfs);
+}
+
 void vhost_dev_init(struct vhost_dev *dev,
                    struct vhost_virtqueue **vqs, int nvqs)
 {
@@ -386,6 +439,10 @@ void vhost_dev_init(struct vhost_dev *dev,
        spin_lock_init(&dev->work_lock);
        INIT_LIST_HEAD(&dev->work_list);
        dev->worker = NULL;
+       dev->qworker = NULL;
+
+       if (cmwq_worker)
+               INIT_WORK(&dev->qwork, vhost_wq_worker);
 
        for (i = 0; i < dev->nvqs; ++i) {
                vq = dev->vqs[i];
@@ -445,7 +502,8 @@ EXPORT_SYMBOL_GPL(vhost_dev_has_owner);
 /* Caller should have device mutex */
 long vhost_dev_set_owner(struct vhost_dev *dev)
 {
-       struct task_struct *worker;
+       struct task_struct *worker = NULL;
+       struct workqueue_struct *qworker;
        int err;
 
        /* Is there an owner already? */
@@ -456,18 +514,31 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
 
        /* No owner, become one */
        dev->mm = get_task_mm(current);
-       worker = kthread_create(vhost_worker, dev, "vhost-%d", current->pid);
-       if (IS_ERR(worker)) {
-               err = PTR_ERR(worker);
-               goto err_worker;
-       }
+       if (cmwq_worker) {
+               qworker = alloc_workqueue("vhost-wq-%d",
+                                         WQ_UNBOUND|WQ_CGROUPS,
+                                         0, current->pid);
+               if (!qworker) {
+                       err = -ENOMEM;
+                       goto err_worker;
+               }
+               dev->qworker = qworker;
+       } else {
+               worker = kthread_create(vhost_worker, dev,
+                                       "vhost-%d", current->pid);
+               if (IS_ERR(worker)) {
+                       err = PTR_ERR(worker);
+                       goto err_worker;
+               }
 
-       dev->worker = worker;
-       wake_up_process(worker);        /* avoid contributing to loadavg */
+               dev->worker = worker;
+               /* avoid contributing to loadavg */
+               wake_up_process(worker);
 
-       err = vhost_attach_cgroups(dev);
-       if (err)
-               goto err_cgroup;
+               err = vhost_attach_cgroups(dev);
+               if (err)
+                       goto err_cgroup;
+       }
 
        err = vhost_dev_alloc_iovecs(dev);
        if (err)
@@ -475,7 +546,8 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
 
        return 0;
 err_cgroup:
-       kthread_stop(worker);
+       if (worker)
+               kthread_stop(worker);
        dev->worker = NULL;
 err_worker:
        if (dev->mm)
@@ -556,6 +628,11 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
                kthread_stop(dev->worker);
                dev->worker = NULL;
        }
+       if (dev->qworker) {
+               /* destroy does flush */
+               destroy_workqueue(dev->qworker);
+               dev->qworker = NULL;
+       }
        if (dev->mm)
                mmput(dev->mm);
        dev->mm = NULL;
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index d3f7674..e2ce0c3 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -127,6 +127,8 @@ struct vhost_dev {
        spinlock_t work_lock;
        struct list_head work_list;
        struct task_struct *worker;
+       struct workqueue_struct *qworker;
+       struct work_struct qwork;
 };
 
 void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs);
-- 
2.5.0

Reply via email to