Add a napi_thread_ctx struct that has a back pointer to napi_struct.

Make the NAPI kthread use the thread_ctx as its data pointer so that
it can poll different NAPIs throughout its lifetime.

Always mirror the thread and thread_ctx pointers in napi_config.

On napi_del, park the thread instead of stopping it if the NAPI has a
napi_config, so that the kthread survives and can be reused later.

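When creating a new NAPI kthread, first try to restore the thread and
context saved in napi_config.

Stop any remaining parked threads in netdev_napi_exit() and when
threading is disabled via netif_set_threaded(). Only set
num_napi_configs once napi_config has been allocated, so that the new
cleanup loop cannot index a NULL napi_config array if that allocation
fails.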

Signed-off-by: Shuhao Tan <[email protected]>
---
 include/linux/netdevice.h |  12 +++++
 net/core/dev.c            | 106 +++++++++++++++++++++++++++++++-------
 2 files changed, 99 insertions(+), 19 deletions(-)

diff --git a/include/linux/netdevice.h b/include/linux/netdevice.h
index 9981d637f8b5..05e430f10aba 100644
--- a/include/linux/netdevice.h
+++ b/include/linux/netdevice.h
@@ -63,6 +63,7 @@ struct dsa_port;
 struct ip_tunnel_parm_kern;
 struct macsec_context;
 struct macsec_ops;
+struct napi_struct;
 struct netdev_config;
 struct netdev_name_node;
 struct sd_flow_limit;
@@ -363,6 +364,10 @@ struct gro_node {
        u32                     cached_napi_id;
 };
 
+struct napi_thread_ctx {  /* per-kthread data for napi_threaded_poll() */
+       struct napi_struct *napi; /* NULL after napi_del() parks the kthread */
+};
+
 /*
  * Structure for per-NAPI config
  */
@@ -371,6 +376,12 @@ struct napi_config {
        u64 irq_suspend_timeout;
        u32 defer_hard_irqs;
        cpumask_t affinity_mask;
+       /* thread and thread_ctx mirrors fields of napi_struct when napi_struct
+        * is alive. When the napi_struct gets destroyed, napi_config holds the
+        * sole reference to the now parked kthread.
+        */
+       struct task_struct *thread;
+       struct napi_thread_ctx *thread_ctx;
        u8 threaded;
        unsigned int napi_id;
 };
@@ -404,6 +415,7 @@ struct napi_struct {
        struct hrtimer          timer;
        /* all fields past this point are write-protected by netdev_lock */
        struct task_struct      *thread;
+       struct napi_thread_ctx  *thread_ctx;
        unsigned long           gro_flush_timeout;
        unsigned long           irq_suspend_timeout;
        u32                     defer_hard_irqs;
diff --git a/net/core/dev.c b/net/core/dev.c
index 4b3d5cfdf6e0..c81992c929d9 100644
--- a/net/core/dev.c
+++ b/net/core/dev.c
@@ -1647,20 +1647,45 @@ static int napi_threaded_poll(void *data);
 
 static int napi_kthread_create(struct napi_struct *n)
 {
+       struct napi_thread_ctx *thread_ctx = NULL;
        int err = 0;
 
+       if (n->config && n->config->thread) {
+               n->thread_ctx = n->config->thread_ctx;
+               n->thread = n->config->thread;
+               WRITE_ONCE(n->thread_ctx->napi, n);
+               kthread_unpark(n->thread);
+               return 0;
+       }
+
+       thread_ctx = kvzalloc_obj(*thread_ctx);
+       if (!thread_ctx)
+               return -ENOMEM;
+
        /* Create and wake up the kthread once to put it in
         * TASK_INTERRUPTIBLE mode to avoid the blocked task
         * warning and work with loadavg.
         */
-       n->thread = kthread_run(napi_threaded_poll, n, "napi/%s-%d",
+       thread_ctx->napi = n;
+       n->thread = kthread_run(napi_threaded_poll, thread_ctx, "napi/%s-%d",
                                n->dev->name, n->napi_id);
        if (IS_ERR(n->thread)) {
                err = PTR_ERR(n->thread);
                pr_err("kthread_run failed with err %d\n", err);
                n->thread = NULL;
+               goto free_thread_ctx;
+       }
+       n->thread_ctx = thread_ctx;
+       if (n->config) {
+               n->config->thread = n->thread;
+               n->config->thread_ctx = thread_ctx;
        }
 
+       return 0;
+
+free_thread_ctx:
+       kvfree(thread_ctx);
+
        return err;
 }
 
@@ -7183,7 +7208,13 @@ static void napi_stop_kthread(struct napi_struct *napi)
        }
 
        kthread_stop(napi->thread);
+       kvfree(napi->thread_ctx);
        napi->thread = NULL;
+       napi->thread_ctx = NULL;
+       if (napi->config) {
+               napi->config->thread = NULL;
+               napi->config->thread_ctx = NULL;
+       }
 }
 
 static void napi_set_threaded_state(struct napi_struct *napi,
@@ -7199,13 +7230,11 @@ static void napi_set_threaded_state(struct napi_struct 
*napi,
 int napi_set_threaded(struct napi_struct *napi,
                      enum netdev_napi_threaded threaded)
 {
-       if (threaded) {
-               if (!napi->thread) {
-                       int err = napi_kthread_create(napi);
+       if (threaded && !napi->thread) {
+               int err = napi_kthread_create(napi);
 
-                       if (err)
-                               return err;
-               }
+               if (err)
+                       return err;
        }
 
        if (napi->config)
@@ -7255,8 +7284,15 @@ int netif_set_threaded(struct net_device *dev,
                WARN_ON_ONCE(napi_set_threaded(napi, threaded));
 
        /* Override the config for all NAPIs even if currently not listed */
-       for (i = 0; i < dev->num_napi_configs; i++)
+       for (i = 0; i < dev->num_napi_configs; i++) {
                dev->napi_config[i].threaded = threaded;
+               if (!threaded && dev->napi_config[i].thread) {
+                       kthread_stop(dev->napi_config[i].thread);
+                       kvfree(dev->napi_config[i].thread_ctx);
+                       dev->napi_config[i].thread = NULL;
+                       dev->napi_config[i].thread_ctx = NULL;
+               }
+       }
 
        return err;
 }
@@ -7501,6 +7537,8 @@ static void napi_save_config(struct napi_struct *n)
        n->config->defer_hard_irqs = n->defer_hard_irqs;
        n->config->gro_flush_timeout = n->gro_flush_timeout;
        n->config->irq_suspend_timeout = n->irq_suspend_timeout;
+       n->config->thread = n->thread;
+       n->config->thread_ctx = n->thread_ctx;
        napi_hash_del(n);
 }
 
@@ -7695,6 +7733,21 @@ void __netif_napi_del_locked(struct napi_struct *napi)
        if (test_and_clear_bit(NAPI_STATE_HAS_NOTIFIER, &napi->state))
                irq_set_affinity_notifier(napi->irq, NULL);
 
+       if (napi->thread) {
+               if (napi->config) {
+                       kthread_park(napi->thread);
+                       /* napi->config holds the only reference to the thread
+                        * from now on.
+                        */
+                       napi->thread_ctx->napi = NULL;
+               } else {
+                       kthread_stop(napi->thread);
+                       kvfree(napi->thread_ctx);
+               }
+               napi->thread = NULL;
+               napi->thread_ctx = NULL;
+       }
+
        if (napi->config) {
                napi->index = -1;
                napi->config = NULL;
@@ -7704,11 +7757,6 @@ void __netif_napi_del_locked(struct napi_struct *napi)
        napi_free_frags(napi);
 
        gro_cleanup(&napi->gro);
-
-       if (napi->thread) {
-               kthread_stop(napi->thread);
-               napi->thread = NULL;
-       }
 }
 EXPORT_SYMBOL(__netif_napi_del_locked);
 
@@ -7804,11 +7852,18 @@ static int napi_poll(struct napi_struct *n, struct 
list_head *repoll)
        return work;
 }
 
-static int napi_thread_wait(struct napi_struct *napi)
+static struct napi_struct *napi_thread_wait(struct napi_thread_ctx *thread_ctx)
 {
+       struct napi_struct *napi = READ_ONCE(thread_ctx->napi);
        set_current_state(TASK_INTERRUPTIBLE);
 
        while (!kthread_should_stop()) {
+               if (kthread_should_park()) {
+                       kthread_parkme();
+                       napi = READ_ONCE(thread_ctx->napi);
+                       /* Might be awakened for stopping */
+                       continue;
+               }
                /* Testing SCHED_THREADED bit here to make sure the current
                 * kthread owns this napi and could poll on this napi.
                 * Testing SCHED bit is not enough because SCHED bit might be
@@ -7817,7 +7872,7 @@ static int napi_thread_wait(struct napi_struct *napi)
                if (test_bit(NAPI_STATE_SCHED_THREADED, &napi->state)) {
                        WARN_ON(!list_empty(&napi->poll_list));
                        __set_current_state(TASK_RUNNING);
-                       return 0;
+                       return napi;
                }
 
                schedule();
@@ -7825,7 +7880,7 @@ static int napi_thread_wait(struct napi_struct *napi)
        }
        __set_current_state(TASK_RUNNING);
 
-       return -1;
+       return NULL;
 }
 
 static void napi_threaded_poll_loop(struct napi_struct *napi,
@@ -7882,13 +7937,18 @@ static void napi_threaded_poll_loop(struct napi_struct 
*napi,
 
 static int napi_threaded_poll(void *data)
 {
-       struct napi_struct *napi = data;
+       struct napi_thread_ctx *thread_ctx = data;
        unsigned long last_qs = jiffies;
+       struct napi_struct *napi;
        bool want_busy_poll;
        bool in_busy_poll;
        unsigned long val;
 
-       while (!napi_thread_wait(napi)) {
+       while (1) {
+               napi = napi_thread_wait(thread_ctx);
+               if (!napi)
+                       break;
+
                val = READ_ONCE(napi->state);
 
                want_busy_poll = val & NAPIF_STATE_THREADED_BUSY_POLL;
@@ -12128,11 +12188,11 @@ struct net_device *alloc_netdev_mqs(int sizeof_priv, 
const char *name,
                goto free_all;
        dev->cfg_pending = dev->cfg;
 
-       dev->num_napi_configs = maxqs;
        napi_config_sz = array_size(maxqs, sizeof(*dev->napi_config));
        dev->napi_config = kvzalloc(napi_config_sz, GFP_KERNEL_ACCOUNT);
        if (!dev->napi_config)
                goto free_all;
+       dev->num_napi_configs = maxqs;
 
        strscpy(dev->name, name);
        dev->name_assign_type = name_assign_type;
@@ -12160,6 +12220,8 @@ EXPORT_SYMBOL(alloc_netdev_mqs);
 
 static void netdev_napi_exit(struct net_device *dev)
 {
+       unsigned int i;
+
        if (!list_empty(&dev->napi_list)) {
                struct napi_struct *p, *n;
 
@@ -12171,6 +12233,12 @@ static void netdev_napi_exit(struct net_device *dev)
                synchronize_net();
        }
 
+       for (i = 0; i < dev->num_napi_configs; i++) {
+               if (dev->napi_config[i].thread) {
+                       kthread_stop(dev->napi_config[i].thread);
+                       kvfree(dev->napi_config[i].thread_ctx);
+               }
+       }
        kvfree(dev->napi_config);
 }
 
-- 
2.55.0.rc0.799.gd6f94ed593-goog


Reply via email to