Prepare to handle premature release of memcg->vmstats_percpu data.
Currently it is a plain pointer, which is expected to be non-NULL
for the whole lifetime of a memcg. Switch it over to an
rcu-protected pointer, and carefully check it for being non-NULL
before every dereference.
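
Every access site now follows the pattern sketched below (a condensed
illustration of the hunks in this patch; the MEMCG_CHARGE_BATCH
batching is elided):

  rcu_read_lock();
  vmstats_percpu = (struct memcg_vmstats_percpu __percpu *)
          rcu_dereference(memcg->vmstats_percpu);
  if (likely(vmstats_percpu)) {
          /* fast path: the update goes to the per-cpu counter */
          __this_cpu_add(vmstats_percpu->stat[idx], val);
  } else {
          /* per-cpu data is already gone: fall back to the atomic counter */
          atomic_long_add(val, &memcg->vmstats[idx]);
  }
  rcu_read_unlock();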

This change is a required step towards dynamic premature release
of percpu memcg data.
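
For context, a later release path (not part of this patch; the shape
below is only a rough sketch) is expected to clear the pointer, wait
for a grace period, and only then free the per-cpu data:

  /* make the per-cpu data unreachable for new readers */
  vmstats_percpu = (struct memcg_vmstats_percpu __percpu *)
          rcu_dereference_protected(memcg->vmstats_percpu, 1);
  rcu_assign_pointer(memcg->vmstats_percpu, NULL);
  /* wait for all rcu read-side critical sections to finish */
  synchronize_rcu();
  free_percpu(vmstats_percpu);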

Signed-off-by: Roman Gushchin <g...@fb.com>
Acked-by: Johannes Weiner <han...@cmpxchg.org>
---
 include/linux/memcontrol.h | 40 +++++++++++++++++-------
 mm/memcontrol.c            | 62 +++++++++++++++++++++++++++++---------
 2 files changed, 77 insertions(+), 25 deletions(-)

diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index 534267947664..05ca77767c6a 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -274,7 +274,7 @@ struct mem_cgroup {
        struct task_struct      *move_lock_task;
 
        /* memory.stat */
-       struct memcg_vmstats_percpu __percpu *vmstats_percpu;
+       struct memcg_vmstats_percpu __rcu /* __percpu */ *vmstats_percpu;
 
        MEMCG_PADDING(_pad2_);
 
@@ -597,17 +597,26 @@ static inline unsigned long memcg_page_state(struct mem_cgroup *memcg,
 static inline void __mod_memcg_state(struct mem_cgroup *memcg,
                                     int idx, int val)
 {
+       struct memcg_vmstats_percpu __percpu *vmstats_percpu;
        long x;
 
        if (mem_cgroup_disabled())
                return;
 
-       x = val + __this_cpu_read(memcg->vmstats_percpu->stat[idx]);
-       if (unlikely(abs(x) > MEMCG_CHARGE_BATCH)) {
-               atomic_long_add(x, &memcg->vmstats[idx]);
-               x = 0;
+       rcu_read_lock();
+       vmstats_percpu = (struct memcg_vmstats_percpu __percpu *)
+               rcu_dereference(memcg->vmstats_percpu);
+       if (likely(vmstats_percpu)) {
+               x = val + __this_cpu_read(vmstats_percpu->stat[idx]);
+               if (unlikely(abs(x) > MEMCG_CHARGE_BATCH)) {
+                       atomic_long_add(x, &memcg->vmstats[idx]);
+                       x = 0;
+               }
+               __this_cpu_write(vmstats_percpu->stat[idx], x);
+       } else {
+               atomic_long_add(val, &memcg->vmstats[idx]);
        }
-       __this_cpu_write(memcg->vmstats_percpu->stat[idx], x);
+       rcu_read_unlock();
 }
 
 /* idx can be of type enum memcg_stat_item or node_stat_item */
@@ -740,17 +749,26 @@ static inline void __count_memcg_events(struct mem_cgroup *memcg,
                                        enum vm_event_item idx,
                                        unsigned long count)
 {
+       struct memcg_vmstats_percpu __percpu *vmstats_percpu;
        unsigned long x;
 
        if (mem_cgroup_disabled())
                return;
 
-       x = count + __this_cpu_read(memcg->vmstats_percpu->events[idx]);
-       if (unlikely(x > MEMCG_CHARGE_BATCH)) {
-               atomic_long_add(x, &memcg->vmevents[idx]);
-               x = 0;
+       rcu_read_lock();
+       vmstats_percpu = (struct memcg_vmstats_percpu __percpu *)
+               rcu_dereference(memcg->vmstats_percpu);
+       if (likely(vmstats_percpu)) {
+               x = count + __this_cpu_read(vmstats_percpu->events[idx]);
+               if (unlikely(x > MEMCG_CHARGE_BATCH)) {
+                       atomic_long_add(x, &memcg->vmevents[idx]);
+                       x = 0;
+               }
+               __this_cpu_write(vmstats_percpu->events[idx], x);
+       } else {
+               atomic_long_add(count, &memcg->vmevents[idx]);
        }
-       __this_cpu_write(memcg->vmstats_percpu->events[idx], x);
+       rcu_read_unlock();
 }
 
 static inline void count_memcg_events(struct mem_cgroup *memcg,
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index c532f8685aa3..803c772f354b 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -697,6 +697,8 @@ static void mem_cgroup_charge_statistics(struct mem_cgroup *memcg,
                                         struct page *page,
                                         bool compound, int nr_pages)
 {
+       struct memcg_vmstats_percpu __percpu *vmstats_percpu;
+
        /*
         * Here, RSS means 'mapped anon' and anon's SwapCache. Shmem/tmpfs is
         * counted as CACHE even if it's on ANON LRU.
@@ -722,7 +724,12 @@ static void mem_cgroup_charge_statistics(struct mem_cgroup *memcg,
                nr_pages = -nr_pages; /* for event */
        }
 
-       __this_cpu_add(memcg->vmstats_percpu->nr_page_events, nr_pages);
+       rcu_read_lock();
+       vmstats_percpu = (struct memcg_vmstats_percpu __percpu *)
+               rcu_dereference(memcg->vmstats_percpu);
+       if (likely(vmstats_percpu))
+               __this_cpu_add(vmstats_percpu->nr_page_events, nr_pages);
+       rcu_read_unlock();
 }
 
 unsigned long mem_cgroup_node_nr_lru_pages(struct mem_cgroup *memcg,
@@ -756,10 +763,18 @@ static unsigned long mem_cgroup_nr_lru_pages(struct mem_cgroup *memcg,
 static bool mem_cgroup_event_ratelimit(struct mem_cgroup *memcg,
                                       enum mem_cgroup_events_target target)
 {
+       struct memcg_vmstats_percpu __percpu *vmstats_percpu;
        unsigned long val, next;
+       bool ret = false;
 
-       val = __this_cpu_read(memcg->vmstats_percpu->nr_page_events);
-       next = __this_cpu_read(memcg->vmstats_percpu->targets[target]);
+       rcu_read_lock();
+       vmstats_percpu = (struct memcg_vmstats_percpu __percpu *)
+               rcu_dereference(memcg->vmstats_percpu);
+       if (!vmstats_percpu)
+               goto out;
+
+       val = __this_cpu_read(vmstats_percpu->nr_page_events);
+       next = __this_cpu_read(vmstats_percpu->targets[target]);
        /* from time_after() in jiffies.h */
        if ((long)(next - val) < 0) {
                switch (target) {
@@ -775,10 +790,12 @@ static bool mem_cgroup_event_ratelimit(struct mem_cgroup *memcg,
                default:
                        break;
                }
-               __this_cpu_write(memcg->vmstats_percpu->targets[target], next);
-               return true;
+               __this_cpu_write(vmstats_percpu->targets[target], next);
+               ret = true;
        }
-       return false;
+out:
+       rcu_read_unlock();
+       return ret;
 }
 
 /*
@@ -2104,22 +2121,29 @@ static void drain_all_stock(struct mem_cgroup *root_memcg)
 
 static int memcg_hotplug_cpu_dead(unsigned int cpu)
 {
+       struct memcg_vmstats_percpu __percpu *vmstats_percpu;
        struct memcg_stock_pcp *stock;
        struct mem_cgroup *memcg;
 
        stock = &per_cpu(memcg_stock, cpu);
        drain_stock(stock);
 
+       rcu_read_lock();
        for_each_mem_cgroup(memcg) {
                int i;
 
+               vmstats_percpu = (struct memcg_vmstats_percpu __percpu *)
+                       rcu_dereference(memcg->vmstats_percpu);
+
                for (i = 0; i < MEMCG_NR_STAT; i++) {
                        int nid;
                        long x;
 
-                       x = this_cpu_xchg(memcg->vmstats_percpu->stat[i], 0);
-                       if (x)
-                               atomic_long_add(x, &memcg->vmstats[i]);
+                       if (vmstats_percpu) {
+                               x = this_cpu_xchg(vmstats_percpu->stat[i], 0);
+                               if (x)
+                                       atomic_long_add(x, &memcg->vmstats[i]);
+                       }
 
                        if (i >= NR_VM_NODE_STAT_ITEMS)
                                continue;
@@ -2137,11 +2161,14 @@ static int memcg_hotplug_cpu_dead(unsigned int cpu)
                for (i = 0; i < NR_VM_EVENT_ITEMS; i++) {
                        long x;
 
-                       x = this_cpu_xchg(memcg->vmstats_percpu->events[i], 0);
-                       if (x)
-                               atomic_long_add(x, &memcg->vmevents[i]);
+                       if (vmstats_percpu) {
+                               x = this_cpu_xchg(vmstats_percpu->events[i], 0);
+                               if (x)
+                                       atomic_long_add(x, &memcg->vmevents[i]);
+                       }
                }
        }
+       rcu_read_unlock();
 
        return 0;
 }
@@ -4464,7 +4491,8 @@ static struct mem_cgroup *mem_cgroup_alloc(void)
        if (memcg->id.id < 0)
                goto fail;
 
-       memcg->vmstats_percpu = alloc_percpu(struct memcg_vmstats_percpu);
+       rcu_assign_pointer(memcg->vmstats_percpu,
+                          alloc_percpu(struct memcg_vmstats_percpu));
        if (!memcg->vmstats_percpu)
                goto fail;
 
@@ -6054,6 +6082,7 @@ static void uncharge_batch(const struct uncharge_gather *ug)
 {
        unsigned long nr_pages = ug->nr_anon + ug->nr_file + ug->nr_kmem;
        unsigned long flags;
+       struct memcg_vmstats_percpu __percpu *vmstats_percpu;
 
        if (!mem_cgroup_is_root(ug->memcg)) {
                page_counter_uncharge(&ug->memcg->memory, nr_pages);
@@ -6070,7 +6099,12 @@ static void uncharge_batch(const struct uncharge_gather *ug)
        __mod_memcg_state(ug->memcg, MEMCG_RSS_HUGE, -ug->nr_huge);
        __mod_memcg_state(ug->memcg, NR_SHMEM, -ug->nr_shmem);
        __count_memcg_events(ug->memcg, PGPGOUT, ug->pgpgout);
-       __this_cpu_add(ug->memcg->vmstats_percpu->nr_page_events, nr_pages);
+       rcu_read_lock();
+       vmstats_percpu = (struct memcg_vmstats_percpu __percpu *)
+               rcu_dereference(ug->memcg->vmstats_percpu);
+       if (likely(vmstats_percpu))
+               __this_cpu_add(vmstats_percpu->nr_page_events, nr_pages);
+       rcu_read_unlock();
        memcg_check_events(ug->memcg, ug->dummy_page);
        local_irq_restore(flags);
 
-- 
2.20.1
