The memcg kmem cache creation and deactivation (SLUB only) are
asynchronous. If a root kmem cache is destroyed while one of its memcg
caches is in the process of creation or deactivation, the kernel may
crash.

Example of one such crash:
        general protection fault: 0000 [#1] SMP PTI
        CPU: 1 PID: 1721 Comm: kworker/14:1 Not tainted 4.17.0-smp
        ...
        Workqueue: memcg_kmem_cache kmemcg_deactivate_workfn
        RIP: 0010:has_cpu_slab
        ...
        Call Trace:
        ? on_each_cpu_cond
        __kmem_cache_shrink
        kmemcg_cache_deact_after_rcu
        kmemcg_deactivate_workfn
        process_one_work
        worker_thread
        kthread
        ret_from_fork+0x35/0x40

This issue is due to the lack of reference counting for the root
kmem_caches. There exists a refcount in kmem_cache, but it is actually
a count of aliases, i.e. the number of kmem_caches merged together.

This patch makes the alias count explicit and adds reference counting
to the root kmem_caches. The reference of a root kmem cache is elevated
on merge and while its memcg kmem_cache is in the process of creation
or deactivation.

Signed-off-by: Shakeel Butt <shake...@google.com>
---
 include/linux/slab.h     |  2 +
 include/linux/slab_def.h |  3 +-
 include/linux/slub_def.h |  3 +-
 mm/memcontrol.c          |  7 ++++
 mm/slab.c                |  4 +-
 mm/slab.h                |  5 ++-
 mm/slab_common.c         | 84 ++++++++++++++++++++++++++++++----------
 mm/slub.c                | 14 ++++---
 8 files changed, 90 insertions(+), 32 deletions(-)

diff --git a/include/linux/slab.h b/include/linux/slab.h
index 9ebe659bd4a5..4c28f2483a22 100644
--- a/include/linux/slab.h
+++ b/include/linux/slab.h
@@ -674,6 +674,8 @@ struct memcg_cache_params {
 };
 
 int memcg_update_all_caches(int num_memcgs);
+bool kmem_cache_tryget(struct kmem_cache *s);
+void kmem_cache_put(struct kmem_cache *s);
 
 /**
  * kmalloc_array - allocate memory for an array.
diff --git a/include/linux/slab_def.h b/include/linux/slab_def.h
index d9228e4d0320..4bb22c89a740 100644
--- a/include/linux/slab_def.h
+++ b/include/linux/slab_def.h
@@ -41,7 +41,8 @@ struct kmem_cache {
 /* 4) cache creation/removal */
        const char *name;
        struct list_head list;
-       int refcount;
+       refcount_t refcount;
+       int alias_count;
        int object_size;
        int align;
 
diff --git a/include/linux/slub_def.h b/include/linux/slub_def.h
index 3773e26c08c1..532d4b6f83ed 100644
--- a/include/linux/slub_def.h
+++ b/include/linux/slub_def.h
@@ -97,7 +97,8 @@ struct kmem_cache {
        struct kmem_cache_order_objects max;
        struct kmem_cache_order_objects min;
        gfp_t allocflags;       /* gfp flags to use on each alloc */
-       int refcount;           /* Refcount for slab cache destroy */
+       refcount_t refcount;    /* Refcount for slab cache destroy */
+       int alias_count;        /* Number of root kmem caches merged */
        void (*ctor)(void *);
        unsigned int inuse;             /* Offset to metadata */
        unsigned int align;             /* Alignment */
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index bdb8028c806c..ab5673dbfc4e 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -2185,6 +2185,7 @@ static void memcg_kmem_cache_create_func(struct work_struct *w)
        memcg_create_kmem_cache(memcg, cachep);
 
        css_put(&memcg->css);
+       kmem_cache_put(cachep);
        kfree(cw);
 }
 
@@ -2200,6 +2201,12 @@ static void __memcg_schedule_kmem_cache_create(struct mem_cgroup *memcg,
        if (!cw)
                return;
 
+       /* Make sure root kmem cache does not get destroyed in the middle */
+       if (!kmem_cache_tryget(cachep)) {
+               kfree(cw);
+               return;
+       }
+
        css_get(&memcg->css);
 
        cw->memcg = memcg;
diff --git a/mm/slab.c b/mm/slab.c
index c1fe8099b3cd..080732f5f20d 100644
--- a/mm/slab.c
+++ b/mm/slab.c
@@ -1883,8 +1883,8 @@ __kmem_cache_alias(const char *name, unsigned int size, unsigned int align,
        struct kmem_cache *cachep;
 
        cachep = find_mergeable(size, align, flags, name, ctor);
-       if (cachep) {
-               cachep->refcount++;
+       if (cachep && kmem_cache_tryget(cachep)) {
+               cachep->alias_count++;
 
                /*
                 * Adjust the object sizes so that we clear
diff --git a/mm/slab.h b/mm/slab.h
index 68bdf498da3b..25962ab75ec1 100644
--- a/mm/slab.h
+++ b/mm/slab.h
@@ -25,7 +25,8 @@ struct kmem_cache {
        unsigned int useroffset;/* Usercopy region offset */
        unsigned int usersize;  /* Usercopy region size */
        const char *name;       /* Slab name for sysfs */
-       int refcount;           /* Use counter */
+       refcount_t refcount;    /* Use counter */
+       int alias_count;
        void (*ctor)(void *);   /* Called on object slot creation */
        struct list_head list;  /* List of all slab caches on the system */
 };
@@ -295,7 +296,7 @@ extern void slab_init_memcg_params(struct kmem_cache *);
 extern void memcg_link_cache(struct kmem_cache *s);
 extern void slab_deactivate_memcg_cache_rcu_sched(struct kmem_cache *s,
                                void (*deact_fn)(struct kmem_cache *));
-
+extern void kmem_cache_put_locked(struct kmem_cache *s);
 #else /* CONFIG_MEMCG && !CONFIG_SLOB */
 
 /* If !memcg, all caches are root. */
diff --git a/mm/slab_common.c b/mm/slab_common.c
index b0dd9db1eb2f..390eb47486fd 100644
--- a/mm/slab_common.c
+++ b/mm/slab_common.c
@@ -306,7 +306,7 @@ int slab_unmergeable(struct kmem_cache *s)
        /*
         * We may have set a slab to be unmergeable during bootstrap.
         */
-       if (s->refcount < 0)
+       if (s->alias_count < 0)
                return 1;
 
        return 0;
@@ -391,7 +391,8 @@ static struct kmem_cache *create_cache(const char *name,
        if (err)
                goto out_free_cache;
 
-       s->refcount = 1;
+       s->alias_count = 1;
+       refcount_set(&s->refcount, 1);
        list_add(&s->list, &slab_caches);
        memcg_link_cache(s);
 out:
@@ -611,6 +612,13 @@ void memcg_create_kmem_cache(struct mem_cgroup *memcg,
        if (memcg->kmem_state != KMEM_ONLINE)
                goto out_unlock;
 
+       /*
+        * The root cache has been requested to be destroyed while its memcg
+        * cache was in creation queue.
+        */
+       if (!root_cache->alias_count)
+               goto out_unlock;
+
        idx = memcg_cache_id(memcg);
        arr = rcu_dereference_protected(root_cache->memcg_params.memcg_caches,
                                        lockdep_is_held(&slab_mutex));
@@ -663,6 +671,8 @@ static void kmemcg_deactivate_workfn(struct work_struct *work)
 {
        struct kmem_cache *s = container_of(work, struct kmem_cache,
                                            memcg_params.deact_work);
+       struct kmem_cache *root = s->memcg_params.root_cache;
+       struct mem_cgroup *memcg = s->memcg_params.memcg;
 
        get_online_cpus();
        get_online_mems();
@@ -677,7 +687,8 @@ static void kmemcg_deactivate_workfn(struct work_struct *work)
        put_online_cpus();
 
        /* done, put the ref from slab_deactivate_memcg_cache_rcu_sched() */
-       css_put(&s->memcg_params.memcg->css);
+       css_put(&memcg->css);
+       kmem_cache_put(root);
 }
 
 static void kmemcg_deactivate_rcufn(struct rcu_head *head)
@@ -712,6 +723,10 @@ void slab_deactivate_memcg_cache_rcu_sched(struct kmem_cache *s,
            WARN_ON_ONCE(s->memcg_params.deact_fn))
                return;
 
+       /* Make sure root kmem_cache does not get destroyed in the middle */
+       if (!kmem_cache_tryget(s->memcg_params.root_cache))
+               return;
+
        /* pin memcg so that @s doesn't get destroyed in the middle */
        css_get(&s->memcg_params.memcg->css);
 
@@ -838,21 +853,17 @@ void slab_kmem_cache_release(struct kmem_cache *s)
        kmem_cache_free(kmem_cache, s);
 }
 
-void kmem_cache_destroy(struct kmem_cache *s)
+static void __kmem_cache_destroy(struct kmem_cache *s, bool lock)
 {
        int err;
 
-       if (unlikely(!s))
-               return;
-
-       get_online_cpus();
-       get_online_mems();
+       if (lock) {
+               get_online_cpus();
+               get_online_mems();
+               mutex_lock(&slab_mutex);
+       }
 
-       mutex_lock(&slab_mutex);
-
-       s->refcount--;
-       if (s->refcount)
-               goto out_unlock;
+       VM_BUG_ON(s->alias_count);
 
        err = shutdown_memcg_caches(s);
        if (!err)
@@ -863,11 +874,42 @@ void kmem_cache_destroy(struct kmem_cache *s)
                       s->name);
                dump_stack();
        }
-out_unlock:
-       mutex_unlock(&slab_mutex);
 
-       put_online_mems();
-       put_online_cpus();
+       if (lock) {
+               mutex_unlock(&slab_mutex);
+               put_online_mems();
+               put_online_cpus();
+       }
+}
+
+bool kmem_cache_tryget(struct kmem_cache *s)
+{
+       if (is_root_cache(s))
+               return refcount_inc_not_zero(&s->refcount);
+       return false;
+}
+
+void kmem_cache_put(struct kmem_cache *s)
+{
+       if (is_root_cache(s) &&
+           refcount_dec_and_test(&s->refcount))
+               __kmem_cache_destroy(s, true);
+}
+
+void kmem_cache_put_locked(struct kmem_cache *s)
+{
+       if (is_root_cache(s) &&
+           refcount_dec_and_test(&s->refcount))
+               __kmem_cache_destroy(s, false);
+}
+
+void kmem_cache_destroy(struct kmem_cache *s)
+{
+       if (unlikely(!s))
+               return;
+
+       s->alias_count--;
+       kmem_cache_put(s);
 }
 EXPORT_SYMBOL(kmem_cache_destroy);
 
@@ -919,7 +961,8 @@ void __init create_boot_cache(struct kmem_cache *s, const char *name,
                panic("Creation of kmalloc slab %s size=%u failed. Reason %d\n",
                                        name, size, err);
 
-       s->refcount = -1;       /* Exempt from merging for now */
+       s->alias_count = -1;    /* Exempt from merging for now */
+       refcount_set(&s->refcount, 1);
 }
 
 struct kmem_cache *__init create_kmalloc_cache(const char *name,
@@ -934,7 +977,8 @@ struct kmem_cache *__init create_kmalloc_cache(const char *name,
        create_boot_cache(s, name, size, flags, useroffset, usersize);
        list_add(&s->list, &slab_caches);
        memcg_link_cache(s);
-       s->refcount = 1;
+       s->alias_count = 1;
+       refcount_set(&s->refcount, 1);
        return s;
 }
 
diff --git a/mm/slub.c b/mm/slub.c
index 48f75872c356..2e45f7febc6e 100644
--- a/mm/slub.c
+++ b/mm/slub.c
@@ -4270,8 +4270,8 @@ __kmem_cache_alias(const char *name, unsigned int size, unsigned int align,
        struct kmem_cache *s, *c;
 
        s = find_mergeable(size, align, flags, name, ctor);
-       if (s) {
-               s->refcount++;
+       if (s && kmem_cache_tryget(s)) {
+               s->alias_count++;
 
                /*
                 * Adjust the object sizes so that we clear
@@ -4286,7 +4286,8 @@ __kmem_cache_alias(const char *name, unsigned int size, unsigned int align,
                }
 
                if (sysfs_slab_alias(s, name)) {
-                       s->refcount--;
+                       s->alias_count--;
+                       kmem_cache_put_locked(s);
                        s = NULL;
                }
        }
@@ -5009,7 +5010,8 @@ SLAB_ATTR_RO(ctor);
 
 static ssize_t aliases_show(struct kmem_cache *s, char *buf)
 {
-       return sprintf(buf, "%d\n", s->refcount < 0 ? 0 : s->refcount - 1);
+       return sprintf(buf, "%d\n",
+                      s->alias_count < 0 ? 0 : s->alias_count - 1);
 }
 SLAB_ATTR_RO(aliases);
 
@@ -5162,7 +5164,7 @@ static ssize_t trace_store(struct kmem_cache *s, const char *buf,
         * as well as cause other issues like converting a mergeable
         * cache into an umergeable one.
         */
-       if (s->refcount > 1)
+       if (s->alias_count > 1)
                return -EINVAL;
 
        s->flags &= ~SLAB_TRACE;
@@ -5280,7 +5282,7 @@ static ssize_t failslab_show(struct kmem_cache *s, char *buf)
 static ssize_t failslab_store(struct kmem_cache *s, const char *buf,
                                                        size_t length)
 {
-       if (s->refcount > 1)
+       if (s->alias_count > 1)
                return -EINVAL;
 
        s->flags &= ~SLAB_FAILSLAB;
-- 
2.17.0.441.gb46fe60e1d-goog

Reply via email to