https://github.com/python/cpython/commit/affffc7ddac68d01ab8d63872770a44e96c46c44
commit: affffc7ddac68d01ab8d63872770a44e96c46c44
branch: 3.12
author: Miss Islington (bot) <[email protected]>
committer: 1st1 <[email protected]>
date: 2024-10-11T20:12:11-07:00
summary:

[3.12] gh-124309: fix staggered race on eager tasks (GH-124847) (#125340)

gh-124309: fix staggered race on eager tasks (GH-124847)

This patch is entirely by Thomas and Peter

(cherry picked from commit 979c0df7c0adfb744159a5fc184043dc733d8534)

Co-authored-by: Thomas Grainger <[email protected]>
Co-authored-by: Peter Bierma <[email protected]>

files:
A Misc/NEWS.d/next/Library/2024-10-01-13-46-58.gh-issue-124390.dK1Zcm.rst
M Lib/asyncio/staggered.py
M Lib/test/test_asyncio/test_eager_task_factory.py
M Lib/test/test_asyncio/test_staggered.py

diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py
index c3a7441a7b091d..7aafcea4d885eb 100644
--- a/Lib/asyncio/staggered.py
+++ b/Lib/asyncio/staggered.py
@@ -69,7 +69,11 @@ async def staggered_race(coro_fns, delay, *, loop=None):
     exceptions = []
     running_tasks = []
 
-    async def run_one_coro(previous_failed) -> None:
+    async def run_one_coro(ok_to_start, previous_failed) -> None:
+        # in eager tasks this waits for the calling task to append this task
+        # to running_tasks, in regular tasks this wait is a no-op that does
+        # not yield a future. See gh-124309.
+        await ok_to_start.wait()
         # Wait for the previous task to finish, or for delay seconds
         if previous_failed is not None:
             with contextlib.suppress(exceptions_mod.TimeoutError):
@@ -85,8 +89,12 @@ async def run_one_coro(previous_failed) -> None:
             return
         # Start task that will run the next coroutine
         this_failed = locks.Event()
-        next_task = loop.create_task(run_one_coro(this_failed))
+        next_ok_to_start = locks.Event()
+        next_task = loop.create_task(run_one_coro(next_ok_to_start, this_failed))
         running_tasks.append(next_task)
+        # next_task has been appended to running_tasks so next_task is ok to
+        # start.
+        next_ok_to_start.set()
         assert len(running_tasks) == this_index + 2
         # Prepare place to put this coroutine's exceptions if not won
         exceptions.append(None)
@@ -116,8 +124,11 @@ async def run_one_coro(previous_failed) -> None:
                 if i != this_index:
                     t.cancel()
 
-    first_task = loop.create_task(run_one_coro(None))
+    ok_to_start = locks.Event()
+    first_task = loop.create_task(run_one_coro(ok_to_start, None))
     running_tasks.append(first_task)
+    # first_task has been appended to running_tasks so first_task is ok to start.
+    ok_to_start.set()
     try:
         # Wait for a growing list of tasks to all finish: poor man's version of
         # curio's TaskGroup or trio's nursery
diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py
index 58c06287bc3c5d..b06832e02f00d6 100644
--- a/Lib/test/test_asyncio/test_eager_task_factory.py
+++ b/Lib/test/test_asyncio/test_eager_task_factory.py
@@ -218,6 +218,52 @@ async def run():
 
         self.run_coro(run())
 
+    def test_staggered_race_with_eager_tasks(self):
+        # See https://github.com/python/cpython/issues/124309
+
+        async def fail():
+            await asyncio.sleep(0)
+            raise ValueError("no good")
+
+        async def run():
+            winner, index, excs = await asyncio.staggered.staggered_race(
+                [
+                    lambda: asyncio.sleep(2, result="sleep2"),
+                    lambda: asyncio.sleep(1, result="sleep1"),
+                    lambda: fail()
+                ],
+                delay=0.25
+            )
+            self.assertEqual(winner, 'sleep1')
+            self.assertEqual(index, 1)
+            self.assertIsNone(excs[index])
+            self.assertIsInstance(excs[0], asyncio.CancelledError)
+            self.assertIsInstance(excs[2], ValueError)
+
+        self.run_coro(run())
+
+    def test_staggered_race_with_eager_tasks_no_delay(self):
+        # See https://github.com/python/cpython/issues/124309
+        async def fail():
+            raise ValueError("no good")
+
+        async def run():
+            winner, index, excs = await asyncio.staggered.staggered_race(
+                [
+                    lambda: fail(),
+                    lambda: asyncio.sleep(1, result="sleep1"),
+                    lambda: asyncio.sleep(0, result="sleep0"),
+                ],
+                delay=None
+            )
+            self.assertEqual(winner, 'sleep1')
+            self.assertEqual(index, 1)
+            self.assertIsNone(excs[index])
+            self.assertIsInstance(excs[0], ValueError)
+            self.assertEqual(len(excs), 2)
+
+        self.run_coro(run())
+
 
 class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
     Task = tasks._PyTask
diff --git a/Lib/test/test_asyncio/test_staggered.py b/Lib/test/test_asyncio/test_staggered.py
index e6e32f7dbbbcba..74941f704c4890 100644
--- a/Lib/test/test_asyncio/test_staggered.py
+++ b/Lib/test/test_asyncio/test_staggered.py
@@ -95,3 +95,30 @@ async def coro(index):
         self.assertEqual(len(excs), 2)
         self.assertIsInstance(excs[0], ValueError)
         self.assertIsInstance(excs[1], ValueError)
+
+
+    async def test_multiple_winners(self):
+        event = asyncio.Event()
+
+        async def coro(index):
+            await event.wait()
+            return index
+
+        async def do_set():
+            event.set()
+            await asyncio.Event().wait()
+
+        winner, index, excs = await staggered_race(
+            [
+                lambda: coro(0),
+                lambda: coro(1),
+                do_set,
+            ],
+            delay=0.1,
+        )
+        self.assertIs(winner, 0)
+        self.assertIs(index, 0)
+        self.assertEqual(len(excs), 3)
+        self.assertIsNone(excs[0], None)
+        self.assertIsInstance(excs[1], asyncio.CancelledError)
+        self.assertIsInstance(excs[2], asyncio.CancelledError)
diff --git a/Misc/NEWS.d/next/Library/2024-10-01-13-46-58.gh-issue-124390.dK1Zcm.rst b/Misc/NEWS.d/next/Library/2024-10-01-13-46-58.gh-issue-124390.dK1Zcm.rst
new file mode 100644
index 00000000000000..89610fa44bf743
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2024-10-01-13-46-58.gh-issue-124390.dK1Zcm.rst
@@ -0,0 +1 @@
+Fixed :exc:`AssertionError` when using :func:`!asyncio.staggered.staggered_race` with :attr:`asyncio.eager_task_factory`.

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: [email protected]

Reply via email to