madvise_collapse() computes the THP-aligned window:
  hstart = (start + ~HPAGE_PMD_MASK) & HPAGE_PMD_MASK;  /* round up */
  hend = end & HPAGE_PMD_MASK;                          /* round down */
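The rounding can be checked with a small user-space sketch. It assumes 4 KiB base pages and a 2 MiB PMD (HPAGE_PMD_SHIFT = 21, as on x86-64). The macros are local stand-ins rather than the kernel headers, and thp_window() is a hypothetical helper.

```c
#include <assert.h>

/* Local stand-ins, assuming 4 KiB base pages and a 2 MiB PMD (x86-64). */
#define PAGE_SIZE	(1UL << 12)
#define HPAGE_PMD_SHIFT	21
#define HPAGE_PMD_SIZE	(1UL << HPAGE_PMD_SHIFT)
#define HPAGE_PMD_MASK	(~(HPAGE_PMD_SIZE - 1))

/* Hypothetical helper: round start up and end down to PMD boundaries. */
static void thp_window(unsigned long start, unsigned long end,
		       unsigned long *hstart, unsigned long *hend)
{
	assert(start <= end);	/* madvise() hands over a valid range */

	/* ~HPAGE_PMD_MASK == HPAGE_PMD_SIZE - 1, so this rounds up. */
	*hstart = (start + ~HPAGE_PMD_MASK) & HPAGE_PMD_MASK;
	/* Clearing the low bits rounds down. */
	*hend = end & HPAGE_PMD_MASK;
}
```

For example, a one-page range starting at 0x201000 yields hstart = 0x400000 and hend = 0x200000, so hstart ends up above hend.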
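Currently hstart and hend are computed only after kmalloc_obj(), and
nothing handles a range that contains no complete PMD-aligned window
(hstart >= hend). When hstart == hend the loop does nothing and 0 is
returned, but when hstart > hend the unsigned difference (hend - hstart)
wraps to a huge value, the final comparison against the number of
collapsed THPs fails, and -EINVAL is returned instead of 0. Consider two
single-page calls, where aligned is a 2 MiB-aligned address in a VMA
that passes thp_vma_allowable_order():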
  /* hstart == hend == aligned: 0 THPs expected, 0 collapsed -> success */
  madvise(aligned, PAGE_SIZE, MADV_COLLAPSE);

  /* hstart == aligned + 2 MiB, hend == aligned: (hend - hstart) wraps
   * as unsigned -> madvise_collapse() returns -EINVAL, madvise() fails
   * with EINVAL */
  madvise(aligned + PAGE_SIZE, PAGE_SIZE, MADV_COLLAPSE);
Both calls cover less than one THP and collapse nothing; both should
return 0.
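A user-space model of the pre-fix tail of madvise_collapse() reproduces both results. It assumes 4 KiB pages and 2 MiB PMDs, assumes the final check compares the collapsed count thps with (hend - hstart) >> HPAGE_PMD_SHIFT, and reduces the error path to -EINVAL. old_result_no_window() is hypothetical and only covers ranges with no complete PMD window, where the collapse loop never runs.

```c
#include <assert.h>
#include <errno.h>

/* Local stand-ins, assuming 4 KiB base pages and a 2 MiB PMD (x86-64). */
#define PAGE_SIZE	(1UL << 12)
#define HPAGE_PMD_SHIFT	21
#define HPAGE_PMD_SIZE	(1UL << HPAGE_PMD_SHIFT)
#define HPAGE_PMD_MASK	(~(HPAGE_PMD_SIZE - 1))

/* Hypothetical model of the pre-fix return value for sub-THP ranges. */
static int old_result_no_window(unsigned long start, unsigned long end)
{
	unsigned long hstart = (start + ~HPAGE_PMD_MASK) & HPAGE_PMD_MASK;
	unsigned long hend = end & HPAGE_PMD_MASK;
	unsigned long thps = 0;	/* loop never runs, nothing collapsed */

	/* This model only covers ranges without a complete PMD window. */
	assert(start <= end && hstart >= hend);

	/* Unsigned subtraction: wraps to a huge value when hstart > hend. */
	return thps == ((hend - hstart) >> HPAGE_PMD_SHIFT) ? 0 : -EINVAL;
}
```

On a 64-bit build the second call computes (hend - hstart) >> 21 = 2^43 - 1, which can never equal thps == 0.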
In addition, kmalloc_obj(), mmgrab() and lru_add_drain_all() all run
before the code discovers there is nothing to do, only to drop the mm
reference, kfree() cc and return right afterwards.
Fix both by computing hstart/hend after thp_vma_allowable_order() but
before kmalloc_obj(), and returning 0 early when hstart >= hend.
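The fixed ordering can be sketched the same way. fixed_collapse(), struct collapse_ctl and the two counters are hypothetical stand-ins: malloc() plays kmalloc_obj(), setup_calls stands in for mmgrab() plus lru_add_drain_all(), the thp_vma_allowable_order() check is omitted, and every collapse attempt is assumed to succeed.

```c
#include <assert.h>
#include <errno.h>
#include <stdlib.h>

/* Local stand-ins, assuming a 2 MiB PMD (x86-64 with 4 KiB pages). */
#define HPAGE_PMD_SHIFT	21
#define HPAGE_PMD_SIZE	(1UL << HPAGE_PMD_SHIFT)
#define HPAGE_PMD_MASK	(~(HPAGE_PMD_SIZE - 1))

struct collapse_ctl { int unused; };	/* hypothetical collapse_control */

static int setup_calls;		/* hypothetical: mmgrab() + lru_add_drain_all() */
static int pmds_visited;	/* hypothetical: collapse attempts */

static int fixed_collapse(unsigned long start, unsigned long end)
{
	unsigned long hstart = (start + ~HPAGE_PMD_MASK) & HPAGE_PMD_MASK;
	unsigned long hend = end & HPAGE_PMD_MASK;
	struct collapse_ctl *cc;
	unsigned long addr;

	assert(start <= end);

	/* No complete PMD-aligned window: succeed before any setup work. */
	if (hstart >= hend)
		return 0;

	cc = malloc(sizeof(*cc));	/* stands in for kmalloc_obj(*cc) */
	if (!cc)
		return -ENOMEM;
	setup_calls++;

	for (addr = hstart; addr < hend; addr += HPAGE_PMD_SIZE)
		pmds_visited++;		/* assume each collapse succeeds */

	free(cc);
	return 0;
}
```

Both single-page calls from the example now return 0 without touching setup_calls, while a 4 MiB aligned range still allocates once and visits two PMDs.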
Fixes: 7d8faaf15545 ("mm/madvise: introduce MADV_COLLAPSE sync hugepage collapse")
Signed-off-by: Chen Wandun <[email protected]>
---
mm/khugepaged.c | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/mm/khugepaged.c b/mm/khugepaged.c
index b8452dbdb043..92473d93e837 100644
--- a/mm/khugepaged.c
+++ b/mm/khugepaged.c
@@ -2836,6 +2836,12 @@ int madvise_collapse(struct vm_area_struct *vma, unsigned long start,
 	if (!thp_vma_allowable_order(vma, vma->vm_flags, TVA_FORCED_COLLAPSE,
 				     PMD_ORDER))
 		return -EINVAL;
+	hstart = (start + ~HPAGE_PMD_MASK) & HPAGE_PMD_MASK;
+	hend = end & HPAGE_PMD_MASK;
+
+	if (hstart >= hend)
+		return 0;
+
 	cc = kmalloc_obj(*cc);
 	if (!cc)
 		return -ENOMEM;
@@ -2845,9 +2851,6 @@ int madvise_collapse(struct vm_area_struct *vma, unsigned long start,
 	mmgrab(mm);
 	lru_add_drain_all();
-	hstart = (start + ~HPAGE_PMD_MASK) & HPAGE_PMD_MASK;
-	hend = end & HPAGE_PMD_MASK;
-
 	for (addr = hstart; addr < hend; addr += HPAGE_PMD_SIZE) {
 		enum scan_result result = SCAN_FAIL;
--
2.43.0