If srcmap contains invalid data, such as HOLE and UNWRITTEN, the dest
page should be zeroed.  Otherwise, since it is pmem, old data may
remain on the dest page and the result of CoW will be incorrect.

The function name is also not easy to understand; rename it to
"dax_iomap_copy_around()", which means it copies data around the range.

Signed-off-by: Shiyang Ruan <ruansy.f...@fujitsu.com>
---
 fs/dax.c | 78 ++++++++++++++++++++++++++++++++++----------------------
 1 file changed, 48 insertions(+), 30 deletions(-)

diff --git a/fs/dax.c b/fs/dax.c
index 482dda85ccaf..6b6e07ad8d80 100644
--- a/fs/dax.c
+++ b/fs/dax.c
@@ -1092,7 +1092,7 @@ static int dax_iomap_direct_access(const struct iomap 
*iomap, loff_t pos,
 }
 
 /**
- * dax_iomap_cow_copy - Copy the data from source to destination before write
+ * dax_iomap_copy_around - Copy the data from source to destination before 
write
  * @pos:       address to do copy from.
  * @length:    size of copy operation.
  * @align_size:        aligned w.r.t align_size (either PMD_SIZE or PAGE_SIZE)
@@ -1101,35 +1101,50 @@ static int dax_iomap_direct_access(const struct iomap 
*iomap, loff_t pos,
  *
  * This can be called from two places. Either during DAX write fault (page
  * aligned), to copy the length size data to daddr. Or, while doing normal DAX
- * write operation, dax_iomap_actor() might call this to do the copy of either
+ * write operation, dax_iomap_iter() might call this to do the copy of either
  * start or end unaligned address. In the latter case the rest of the copy of
- * aligned ranges is taken care by dax_iomap_actor() itself.
+ * aligned ranges is taken care by dax_iomap_iter() itself.
+ * If the srcmap contains invalid data, such as HOLE and UNWRITTEN, zero the
+ * area to make sure no old data remains.
  */
-static int dax_iomap_cow_copy(loff_t pos, uint64_t length, size_t align_size,
+static int dax_iomap_copy_around(loff_t pos, uint64_t length, size_t 
align_size,
                const struct iomap *srcmap, void *daddr)
 {
        loff_t head_off = pos & (align_size - 1);
        size_t size = ALIGN(head_off + length, align_size);
        loff_t end = pos + length;
        loff_t pg_end = round_up(end, align_size);
+       /* copy_all is usually in page fault case */
        bool copy_all = head_off == 0 && end == pg_end;
+       /* zero the edges if srcmap is a HOLE or IOMAP_UNWRITTEN */
+       bool zero_edge = srcmap->flags & IOMAP_F_SHARED ||
+                        srcmap->type == IOMAP_UNWRITTEN;
        void *saddr = 0;
        int ret = 0;
 
-       ret = dax_iomap_direct_access(srcmap, pos, size, &saddr, NULL);
-       if (ret)
-               return ret;
+       if (!zero_edge) {
+               ret = dax_iomap_direct_access(srcmap, pos, size, &saddr, NULL);
+               if (ret)
+                       return ret;
+       }
 
        if (copy_all) {
-               ret = copy_mc_to_kernel(daddr, saddr, length);
-               return ret ? -EIO : 0;
+               if (zero_edge)
+                       memset(daddr, 0, size);
+               else
+                       ret = copy_mc_to_kernel(daddr, saddr, length);
+               goto out;
        }
 
        /* Copy the head part of the range */
        if (head_off) {
-               ret = copy_mc_to_kernel(daddr, saddr, head_off);
-               if (ret)
-                       return -EIO;
+               if (zero_edge)
+                       memset(daddr, 0, head_off);
+               else {
+                       ret = copy_mc_to_kernel(daddr, saddr, head_off);
+                       if (ret)
+                               return -EIO;
+               }
        }
 
        /* Copy the tail part of the range */
@@ -1137,12 +1152,19 @@ static int dax_iomap_cow_copy(loff_t pos, uint64_t 
length, size_t align_size,
                loff_t tail_off = head_off + length;
                loff_t tail_len = pg_end - end;
 
-               ret = copy_mc_to_kernel(daddr + tail_off, saddr + tail_off,
-                                       tail_len);
-               if (ret)
-                       return -EIO;
+               if (zero_edge)
+                       memset(daddr + tail_off, 0, tail_len);
+               else {
+                       ret = copy_mc_to_kernel(daddr + tail_off,
+                                               saddr + tail_off, tail_len);
+                       if (ret)
+                               return -EIO;
+               }
        }
-       return 0;
+out:
+       if (zero_edge)
+               dax_flush(srcmap->dax_dev, daddr, size);
+       return ret ? -EIO : 0;
 }
 
 /*
@@ -1241,13 +1263,10 @@ static int dax_memzero(struct iomap_iter *iter, loff_t 
pos, size_t size)
        if (ret < 0)
                return ret;
        memset(kaddr + offset, 0, size);
-       if (srcmap->addr != iomap->addr) {
-               ret = dax_iomap_cow_copy(pos, size, PAGE_SIZE, srcmap,
-                                        kaddr);
-               if (ret < 0)
-                       return ret;
-               dax_flush(iomap->dax_dev, kaddr, PAGE_SIZE);
-       } else
+       if (iomap->flags & IOMAP_F_SHARED)
+               ret = dax_iomap_copy_around(pos, size, PAGE_SIZE, srcmap,
+                                           kaddr);
+       else
                dax_flush(iomap->dax_dev, kaddr + offset, size);
        return ret;
 }
@@ -1401,8 +1420,8 @@ static loff_t dax_iomap_iter(const struct iomap_iter 
*iomi,
                }
 
                if (cow) {
-                       ret = dax_iomap_cow_copy(pos, length, PAGE_SIZE, srcmap,
-                                                kaddr);
+                       ret = dax_iomap_copy_around(pos, length, PAGE_SIZE,
+                                                   srcmap, kaddr);
                        if (ret)
                                break;
                }
@@ -1547,7 +1566,7 @@ static vm_fault_t dax_fault_iter(struct vm_fault *vmf,
                struct xa_state *xas, void **entry, bool pmd)
 {
        const struct iomap *iomap = &iter->iomap;
-       const struct iomap *srcmap = &iter->srcmap;
+       const struct iomap *srcmap = iomap_iter_srcmap(iter);
        size_t size = pmd ? PMD_SIZE : PAGE_SIZE;
        loff_t pos = (loff_t)xas->xa_index << PAGE_SHIFT;
        bool write = iter->flags & IOMAP_WRITE;
@@ -1578,9 +1597,8 @@ static vm_fault_t dax_fault_iter(struct vm_fault *vmf,
 
        *entry = dax_insert_entry(xas, vmf, iter, *entry, pfn, entry_flags);
 
-       if (write &&
-           srcmap->type != IOMAP_HOLE && srcmap->addr != iomap->addr) {
-               err = dax_iomap_cow_copy(pos, size, size, srcmap, kaddr);
+       if (write && iomap->flags & IOMAP_F_SHARED) {
+               err = dax_iomap_copy_around(pos, size, size, srcmap, kaddr);
                if (err)
                        return dax_fault_return(err);
        }
-- 
2.38.1


Reply via email to