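kthread_run() wakes the new thread before it returns, but its result was
only stored in sktp[i].tp afterwards, so a child could run before its
->tp was set and trip WARN_ON_ONCE(!sktp->tp) at line 87. The creation
check was broken as well: it tested sktp instead of sktp[i].tp, and it
used KUNIT_EXPECT_NOT_NULL_MSG() even though kthread_run() reports
failure with an ERR_PTR(), never NULL.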

Replace kthread_run() with kthread_create() + wake_up_process() so the
return value can be checked with IS_ERR() and stored in sktp[i].tp
before the child starts running. On creation failure, jump to a common
cleanup path that sets doneflag, stops only the threads that were
already started, and frees sktp (which was previously never freed).
This avoids leaving orphan kthreads and leaking memory when a mid-loop
failure occurs.

Also reset doneflag before spawning threads: the previous run leaves it
set to 1, so back-to-back test invocations would otherwise start with
the stop flag already raised.

To: "Paul E. McKenney" <[email protected]>
To: Petr Mladek <[email protected]>
To: Kees Cook <[email protected]>

Signed-off-by: Jia He <[email protected]>
---
 lib/tests/test_ratelimit.c | 36 ++++++++++++++++++++++++++++--------
 1 file changed, 28 insertions(+), 8 deletions(-)

diff --git a/lib/tests/test_ratelimit.c b/lib/tests/test_ratelimit.c
index 33cea5f3d28b..64f26260c0d8 100644
--- a/lib/tests/test_ratelimit.c
+++ b/lib/tests/test_ratelimit.c
@@ -105,26 +105,46 @@ static void test_ratelimit_stress(struct kunit *test)
        const int n_stress_kthread = cpumask_weight(cpu_online_mask);
        struct stress_kthread skt = { 0 };
        struct stress_kthread *sktp = kzalloc_objs(*sktp, n_stress_kthread);
+       int n_started = 0;
 
-       KUNIT_EXPECT_NOT_NULL_MSG(test, sktp, "Memory allocation failure");
+       KUNIT_ASSERT_NOT_NULL_MSG(test, sktp, "Memory allocation failure");
+       WRITE_ONCE(doneflag, 0);
        for (i = 0; i < n_stress_kthread; i++) {
-               sktp[i].tp = kthread_run(test_ratelimit_stress_child, &sktp[i], "%s/%i",
-                                        "test_ratelimit_stress_child", i);
-               KUNIT_EXPECT_NOT_NULL_MSG(test, sktp, "kthread creation failure");
+               struct task_struct *tp;
+
+               tp = kthread_create(test_ratelimit_stress_child, &sktp[i],
+                                   "%s/%i", "test_ratelimit_stress_child", i);
+               if (IS_ERR(tp)) {
+                       KUNIT_FAIL(test, "kthread_create failed: %ld", PTR_ERR(tp));
+                       goto out_stop;
+               }
+
+               /* Set ->tp before the child can run. */
+               sktp[i].tp = tp;
+               wake_up_process(tp);
+               n_started++;
                pr_alert("Spawned test_ratelimit_stress_child %d\n", i);
        }
        schedule_timeout_idle(stress_duration);
+
+out_stop:
        WRITE_ONCE(doneflag, 1);
-       for (i = 0; i < n_stress_kthread; i++) {
+       for (i = 0; i < n_started; i++) {
                kthread_stop(sktp[i].tp);
                skt.nattempts += sktp[i].nattempts;
                skt.nunlimited += sktp[i].nunlimited;
                skt.nlimited += sktp[i].nlimited;
                skt.nmissed += sktp[i].nmissed;
        }
-       KUNIT_ASSERT_EQ_MSG(test, skt.nunlimited + skt.nlimited, skt.nattempts,
-                           "Outcomes not equal to attempts");
-       KUNIT_ASSERT_EQ_MSG(test, skt.nlimited, skt.nmissed, "Misses not equal to limits");
+       /* Free before asserting: a failed KUNIT_ASSERT aborts the test case. */
+       kfree(sktp);
+
+       if (n_started == n_stress_kthread) {
+               KUNIT_ASSERT_EQ_MSG(test, skt.nunlimited + skt.nlimited, skt.nattempts,
+                                   "Outcomes not equal to attempts");
+               KUNIT_ASSERT_EQ_MSG(test, skt.nlimited, skt.nmissed,
+                                   "Misses not equal to limits");
+       }
 }
 
 static struct kunit_case ratelimit_test_cases[] = {
-- 
2.34.1


Reply via email to