From: Yulei Zhang <[email protected]>

Add a 'memmap' member to struct dmem_region that maps each dmem page to a
struct dmempage.

Currently struct dmempage has a single member, '_refcount', which counts
the modules that hold a reference to the dmem page.

A module that allocates a dmem page from the dmem pool takes the first
reference, setting _refcount to 1.

A module that frees a dmem page back to the dmem pool decrements _refcount;
the page is actually freed only if _refcount reaches zero.

Each time module A passes a dmem page to module B, module B should call
get_dmem_pfn() to increase _refcount before using the page, so that it
cannot end up referencing a dmem page that another module frees in
parallel. Conversely, once module B has finished using the dmem page it
must call put_dmem_pfn() to decrease _refcount.

Signed-off-by: Chen Zhuo <[email protected]>
Signed-off-by: Yulei Zhang <[email protected]>
---
 include/linux/dmem.h |   5 ++
 mm/dmem.c            | 147 ++++++++++++++++++++++++++++++++++++++++++++++-----
 2 files changed, 139 insertions(+), 13 deletions(-)

diff --git a/include/linux/dmem.h b/include/linux/dmem.h
index fe0b270..8aaa80b 100644
--- a/include/linux/dmem.h
+++ b/include/linux/dmem.h
@@ -22,6 +22,9 @@
 bool is_dmem_pfn(unsigned long pfn);
 #define dmem_free_page(addr)   dmem_free_pages(addr, 1)
 
+/* Take an extra reference on the dmem page backing @pfn. */
+void get_dmem_pfn(unsigned long pfn);
+/* Drop one reference on @pfn by freeing the single dmem page at PFN_PHYS(pfn). */
+#define put_dmem_pfn(pfn)      dmem_free_page(PFN_PHYS(pfn))
+
 bool dmem_memory_failure(unsigned long pfn, int flags);
 
 struct dmem_mce_notifier_info {
@@ -45,5 +48,7 @@ static inline bool dmem_memory_failure(unsigned long pfn, int 
flags)
 {
        return false;
 }
+/*
+ * No-op stubs.  They are static inline, like the dmem_memory_failure() stub
+ * above, so that including this header from several files does not produce
+ * multiple external definitions.
+ */
+static inline void get_dmem_pfn(unsigned long pfn) {}
+static inline void put_dmem_pfn(unsigned long pfn) {}
 #endif
 #endif /* _LINUX_DMEM_H */
diff --git a/mm/dmem.c b/mm/dmem.c
index dd81b24..776dbf2 100644
--- a/mm/dmem.c
+++ b/mm/dmem.c
@@ -47,6 +47,7 @@ struct dmem_region {
 
        unsigned long static_error_bitmap;
        unsigned long *error_bitmap;
+       void *memmap;
 };
 
 /*
@@ -91,6 +92,10 @@ struct dmem_pool {
        struct dmem_node nodes[MAX_NUMNODES];
 };
 
+/* Per-page metadata; a region's memmap holds one of these per dpage. */
+struct dmempage {
+       /* reference count; set to 1 when the page is allocated */
+       atomic_t _refcount;
+};
+
 static struct dmem_pool dmem_pool = {
        .lock = __MUTEX_INITIALIZER(dmem_pool.lock),
        .mce_notifier_chain = RAW_NOTIFIER_INIT(dmem_pool.mce_notifier_chain),
@@ -123,6 +128,40 @@ struct dmem_pool {
 #define for_each_dmem_region(_dnode, _dregion)                         \
        list_for_each_entry(_dregion, &(_dnode)->regions, node)
 
+/*
+ * Map a pfn to its struct dmempage entry in @_dregion's memmap.  The index
+ * is computed before it is added to the base pointer so the pointer
+ * arithmetic never leaves the memmap array.
+ */
+#define pfn_to_dmempage(_pfn, _dregion)                                \
+       ((struct dmempage *)(_dregion)->memmap +                        \
+       (pfn_to_dpage(_pfn) - (_dregion)->dpage_start_pfn))
+
+/*
+ * Convert a memmap entry of @_dregion to its dpage number: the entry's index
+ * in the memmap plus the region's dpage_start_pfn.
+ */
+#define dmempage_to_dpage(_dmempage, _dregion)                         \
+       ((_dmempage) - (struct dmempage *)(_dregion)->memmap +          \
+       (_dregion)->dpage_start_pfn)
+
+/* Read the current reference count of @dmempage. */
+static inline int dmempage_count(struct dmempage *dmempage)
+{
+       return atomic_read(&dmempage->_refcount);
+}
+
+/* Overwrite the reference count of @dmempage with @v. */
+static inline void set_dmempage_count(struct dmempage *dmempage, int v)
+{
+       atomic_set(&dmempage->_refcount, v);
+}
+
+/* Atomically add one reference to @dmempage. */
+static inline void dmempage_ref_inc(struct dmempage *dmempage)
+{
+       atomic_inc(&dmempage->_refcount);
+}
+
+/* Atomically drop one reference; returns non-zero if the count became zero. */
+static inline int dmempage_ref_dec_and_test(struct dmempage *dmempage)
+{
+       return atomic_dec_and_test(&dmempage->_refcount);
+}
+
+/*
+ * Drop one reference on @dmempage and report whether it was the last one.
+ * Dropping a reference on a page whose count is already zero is a bug.
+ */
+static inline int put_dmempage_testzero(struct dmempage *dmempage)
+{
+       int was_last;
+
+       VM_BUG_ON(dmempage_count(dmempage) == 0);
+       was_last = dmempage_ref_dec_and_test(dmempage);
+
+       return was_last;
+}
+
 int dmem_register_mce_notifier(struct notifier_block *nb)
 {
        int ret;
@@ -559,10 +598,25 @@ static int __init dmem_late_init(void)
 }
 late_initcall(dmem_late_init);
 
+/*
+ * Allocate a zero-filled memmap holding one struct dmempage for each of
+ * @dpages pages.  Returns NULL if the allocation fails.
+ */
+static void *dmem_memmap_alloc(unsigned long dpages)
+{
+       return vzalloc(dpages * sizeof(struct dmempage));
+}
+
+/* Release a memmap obtained from dmem_memmap_alloc(); NULL is accepted. */
+static void dmem_memmap_free(void *memmap)
+{
+       /* vfree() is a no-op on NULL, so no guard is needed */
+       vfree(memmap);
+}
+
 static int dmem_alloc_region_init(struct dmem_region *dregion,
                                  unsigned long *dpages)
 {
        unsigned long start, end, *bitmap;
+       void *memmap;
 
        start = DMEM_PAGE_UP(dregion->reserved_start_addr);
        end = DMEM_PAGE_DOWN(dregion->reserved_end_addr);
@@ -575,7 +629,14 @@ static int dmem_alloc_region_init(struct dmem_region 
*dregion,
        if (!bitmap)
                return -ENOMEM;
 
+       memmap = dmem_memmap_alloc(*dpages);
+       if (!memmap) {
+               dmem_bitmap_free(*dpages, bitmap, &dregion->static_bitmap);
+               return -ENOMEM;
+       }
+
        dregion->bitmap = bitmap;
+       dregion->memmap = memmap;
        dregion->next_free_pos = 0;
        dregion->dpage_start_pfn = start;
        dregion->dpage_end_pfn = end;
@@ -650,7 +711,9 @@ static void dmem_alloc_region_uinit(struct dmem_region 
*dregion)
        dmem_uinit_check_alloc_bitmap(dregion);
 
        dmem_bitmap_free(dpages, bitmap, &dregion->static_bitmap);
+       dmem_memmap_free(dregion->memmap);
        dregion->bitmap = NULL;
+       dregion->memmap = NULL;
 }
 
 static void __dmem_alloc_uinit(void)
@@ -793,6 +856,16 @@ int dmem_alloc_init(unsigned long dpage_shift)
        return dpage_to_phys(dregion->dpage_start_pfn + pos);
 }
 
+/*
+ * Give each of the @nr dpages that start at physical address @phys in
+ * @dregion its initial reference, i.e. a refcount of 1.
+ */
+static void prep_new_dmempage(unsigned long phys, unsigned int nr,
+                             struct dmem_region *dregion)
+{
+       struct dmempage *cur = pfn_to_dmempage(PHYS_PFN(phys), dregion);
+       struct dmempage *end = cur + nr;
+
+       while (cur != end) {
+               set_dmempage_count(cur, 1);
+               cur++;
+       }
+}
+
 /*
  * allocate dmem pages from the nodelist
  *
@@ -839,6 +912,7 @@ int dmem_alloc_init(unsigned long dpage_shift)
                        if (addr) {
                                dnode_count_free_dpages(dnode,
                                                        -(long)(*result_nr));
+                               prep_new_dmempage(addr, *result_nr, dregion);
                                break;
                        }
                }
@@ -993,6 +1067,41 @@ static struct dmem_region *find_dmem_region(phys_addr_t 
phys_addr,
        return NULL;
 }
 
+/*
+ * Drop one reference on each of @dpages_nr consecutive pages starting at
+ * @dmempage.  Returns how many of those pages reached a refcount of zero.
+ */
+static unsigned int free_dmempages_prepare(struct dmempage *dmempage,
+                                  unsigned int dpages_nr)
+{
+       unsigned int zeroed = 0;
+       unsigned int idx = 0;
+
+       while (idx < dpages_nr) {
+               if (put_dmempage_testzero(&dmempage[idx]))
+                       zeroed++;
+               idx++;
+       }
+
+       return zeroed;
+}
+
+/*
+ * Return @dpages_nr consecutive dpages, starting at the page described by
+ * @dmempage, to @dregion: clear them in the allocation bitmap and update
+ * @pdnode's free-page accounting.  Page refcounts are not touched here.
+ *
+ * NOTE(review): not declared static and no prototype is visible in this
+ * patch -- confirm whether external linkage is intended.
+ */
+void __dmem_free_pages(struct dmempage *dmempage,
+                      unsigned int dpages_nr,
+                      struct dmem_region *dregion,
+                      struct dmem_node *pdnode)
+{
+       phys_addr_t dpage = dmempage_to_dpage(dmempage, dregion);
+       u64 pos;
+       unsigned long err_dpages;
+
+       trace_dmem_free_pages(dpage_to_phys(dpage), dpages_nr);
+       WARN_ON(!dmem_pool.dpage_shift);
+
+       /* pull next_free_pos back to this range if it was further along */
+       pos = dpage - dregion->dpage_start_pfn;
+       dregion->next_free_pos = min(dregion->next_free_pos, pos);
+
+       /* it is not possible to span multiple regions */
+       WARN_ON(dpage + dpages_nr - 1 >= dregion->dpage_end_pfn);
+
+       /* presumably the count of cleared pages marked as errors -- confirm */
+       err_dpages = dmem_alloc_bitmap_clear(dregion, dpage, dpages_nr);
+
+       dnode_count_free_dpages(pdnode, dpages_nr - err_dpages);
+}
+
 /*
  * free dmem page to the dmem pool
  *   @addr: the physical addree will be freed
@@ -1002,27 +1111,26 @@ void dmem_free_pages(phys_addr_t addr, unsigned int 
dpages_nr)
 {
        struct dmem_region *dregion;
        struct dmem_node *pdnode = NULL;
-       phys_addr_t dpage = phys_to_dpage(addr);
-       u64 pos;
-       unsigned long err_dpages;
+       struct dmempage *dmempage;
+       unsigned int nr;
 
        mutex_lock(&dmem_pool.lock);
 
-       trace_dmem_free_pages(addr, dpages_nr);
-       WARN_ON(!dmem_pool.dpage_shift);
-
        dregion = find_dmem_region(addr, &pdnode);
        WARN_ON(!dregion || !dregion->bitmap || !pdnode);
 
-       pos = dpage - dregion->dpage_start_pfn;
-       dregion->next_free_pos = min(dregion->next_free_pos, pos);
-
-       /* it is not possible to span multiple regions */
-       WARN_ON(dpage + dpages_nr - 1 >= dregion->dpage_end_pfn);
+       dmempage = pfn_to_dmempage(PHYS_PFN(addr), dregion);
 
-       err_dpages = dmem_alloc_bitmap_clear(dregion, dpage, dpages_nr);
+       /* drop one reference per page; nr is how many pages reached zero */
+       nr = free_dmempages_prepare(dmempage, dpages_nr);
+       if (nr == dpages_nr) {
+               /* the whole range is unreferenced: release it in one call */
+               __dmem_free_pages(dmempage, dpages_nr, dregion, pdnode);
+       } else if (nr) {
+               /*
+                * Only some pages dropped to zero.  Walk the range from its
+                * first page, bounded by the page count, and release just
+                * the unreferenced pages one at a time.
+                */
+               for (; dpages_nr--; dmempage++) {
+                       if (dmempage_count(dmempage))
+                               continue;
+                       __dmem_free_pages(dmempage, 1, dregion, pdnode);
+               }
+       }
 
-       dnode_count_free_dpages(pdnode, dpages_nr - err_dpages);
        mutex_unlock(&dmem_pool.lock);
 }
 EXPORT_SYMBOL(dmem_free_pages);
@@ -1073,3 +1181,16 @@ bool is_dmem_pfn(unsigned long pfn)
        return !!find_dmem_region(__pfn_to_phys(pfn), &dnode);
 }
 EXPORT_SYMBOL(is_dmem_pfn);
+
+/*
+ * Take an additional reference on the dmem page backing @pfn.  The pfn must
+ * belong to a dmem region that has a memmap, and the page must already hold
+ * at least one reference.
+ */
+void get_dmem_pfn(unsigned long pfn)
+{
+       /*
+        * The node is not needed here, but the other callers of
+        * find_dmem_region() always hand it a valid out-pointer, so do the
+        * same rather than passing NULL.
+        */
+       struct dmem_node *dnode = NULL;
+       struct dmem_region *dregion = find_dmem_region(PFN_PHYS(pfn), &dnode);
+       struct dmempage *dmempage;
+
+       VM_BUG_ON(!dregion || !dregion->memmap);
+
+       dmempage = pfn_to_dmempage(pfn, dregion);
+       /* rejects a count in [-127, 0]: a free page or a wrapped counter */
+       VM_BUG_ON(dmempage_count(dmempage) + 127u <= 127u);
+       dmempage_ref_inc(dmempage);
+}
+EXPORT_SYMBOL(get_dmem_pfn);
-- 
1.8.3.1

Reply via email to