Now that all other kernel callers of get_user_pages(FOLL_LONGTERM)
have been converted to pin_longterm_pages(), lock down
get_user_pages() so that it can no longer be called with FOLL_LONGTERM:

1) Add an assertion to get_user_pages(), preventing callers from
   passing FOLL_LONGTERM (in addition to the existing assertion that
   prevents FOLL_PIN).

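2) Remove the associated GUP_LONGTERM_BENCHMARK test (the gup_benchmark
   ioctl command and the selftest's -L option). It calls
   get_user_pages() with FOLL_LONGTERM, which the new assertion rejects.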

Signed-off-by: John Hubbard <jhubb...@nvidia.com>
---
 mm/gup.c                                   | 8 ++++----
 mm/gup_benchmark.c                         | 9 +--------
 tools/testing/selftests/vm/gup_benchmark.c | 7 ++-----
 3 files changed, 7 insertions(+), 17 deletions(-)

diff --git a/mm/gup.c b/mm/gup.c
index e51b3820a995..9a28935a2cb1 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -1744,11 +1744,11 @@ long get_user_pages(unsigned long start, unsigned long nr_pages,
                struct vm_area_struct **vmas)
 {
        /*
-        * As detailed above, FOLL_PIN must only be set internally by the
-        * pin_user_page*() and pin_longterm_*() APIs, never directly by the
-        * caller, so enforce that with an assertion:
+        * As detailed above, FOLL_PIN and FOLL_LONGTERM must only be set
+        * internally by the pin_user_page*() and pin_longterm_*() APIs, never
+        * directly by the caller, so enforce that with an assertion:
         */
-       if (WARN_ON_ONCE(gup_flags & FOLL_PIN))
+       if (WARN_ON_ONCE(gup_flags & (FOLL_PIN | FOLL_LONGTERM)))
                return -EINVAL;
 
        return __gup_longterm_locked(current, current->mm, start, nr_pages,
diff --git a/mm/gup_benchmark.c b/mm/gup_benchmark.c
index 2bb0f5df4803..de6941855b7e 100644
--- a/mm/gup_benchmark.c
+++ b/mm/gup_benchmark.c
@@ -6,7 +6,7 @@
 #include <linux/debugfs.h>
 
 #define GUP_FAST_BENCHMARK     _IOWR('g', 1, struct gup_benchmark)
-#define GUP_LONGTERM_BENCHMARK _IOWR('g', 2, struct gup_benchmark)
+/* Command 2 has been deleted. */
 #define GUP_BENCHMARK          _IOWR('g', 3, struct gup_benchmark)
 #define PIN_FAST_BENCHMARK     _IOWR('g', 4, struct gup_benchmark)
 #define PIN_LONGTERM_BENCHMARK _IOWR('g', 5, struct gup_benchmark)
@@ -28,7 +28,6 @@ static void put_back_pages(int cmd, struct page **pages, unsigned long nr_pages)
 
        switch (cmd) {
        case GUP_FAST_BENCHMARK:
-       case GUP_LONGTERM_BENCHMARK:
        case GUP_BENCHMARK:
                for (i = 0; i < nr_pages; i++)
                        put_page(pages[i]);
@@ -94,11 +93,6 @@ static int __gup_benchmark_ioctl(unsigned int cmd,
                        nr = get_user_pages_fast(addr, nr, gup->flags & 1,
                                                 pages + i);
                        break;
-               case GUP_LONGTERM_BENCHMARK:
-                       nr = get_user_pages(addr, nr,
-                                           (gup->flags & 1) | FOLL_LONGTERM,
-                                           pages + i, NULL);
-                       break;
                case GUP_BENCHMARK:
                        nr = get_user_pages(addr, nr, gup->flags & 1, pages + i,
                                            NULL);
@@ -157,7 +151,6 @@ static long gup_benchmark_ioctl(struct file *filep, unsigned int cmd,
 
        switch (cmd) {
        case GUP_FAST_BENCHMARK:
-       case GUP_LONGTERM_BENCHMARK:
        case GUP_BENCHMARK:
        case PIN_FAST_BENCHMARK:
        case PIN_LONGTERM_BENCHMARK:
diff --git a/tools/testing/selftests/vm/gup_benchmark.c b/tools/testing/selftests/vm/gup_benchmark.c
index c5c934c0f402..5ef3cf8f3da5 100644
--- a/tools/testing/selftests/vm/gup_benchmark.c
+++ b/tools/testing/selftests/vm/gup_benchmark.c
@@ -15,7 +15,7 @@
 #define PAGE_SIZE sysconf(_SC_PAGESIZE)
 
 #define GUP_FAST_BENCHMARK     _IOWR('g', 1, struct gup_benchmark)
-#define GUP_LONGTERM_BENCHMARK _IOWR('g', 2, struct gup_benchmark)
+/* Command 2 has been deleted. */
 #define GUP_BENCHMARK          _IOWR('g', 3, struct gup_benchmark)
 
 /*
@@ -46,7 +46,7 @@ int main(int argc, char **argv)
        char *file = "/dev/zero";
        char *p;
 
-       while ((opt = getopt(argc, argv, "m:r:n:f:abctTLUuwSH")) != -1) {
+       while ((opt = getopt(argc, argv, "m:r:n:f:abctTUuwSH")) != -1) {
                switch (opt) {
                case 'a':
                        cmd = PIN_FAST_BENCHMARK;
@@ -72,9 +72,6 @@ int main(int argc, char **argv)
                case 'T':
                        thp = 0;
                        break;
-               case 'L':
-                       cmd = GUP_LONGTERM_BENCHMARK;
-                       break;
                case 'U':
                        cmd = GUP_BENCHMARK;
                        break;
-- 
2.23.0

Reply via email to