https://github.com/python/cpython/commit/415964417771946dcb7a163951913adf84644b6d
commit: 415964417771946dcb7a163951913adf84644b6d
branch: main
author: Pierre Ossman (ThinLinc team) <oss...@cendio.se>
committer: gvanrossum <gvanros...@gmail.com>
date: 2024-03-18T13:15:53-07:00
summary:

gh-113538: Add asyncio.Server.{close,abort}_clients (redo) (#116784)

These give applications the option of more forcefully terminating client
connections for asyncio servers. This is useful when shutting down a service
and there is limited time to wait for clients to finish their work.

This is a do-over of gh-114432 (which was reverted), now with a fix for its tests.

files:
A Misc/NEWS.d/next/Library/2024-01-22-15-50-58.gh-issue-113538.v2wrwg.rst
M Doc/library/asyncio-eventloop.rst
M Doc/whatsnew/3.13.rst
M Lib/asyncio/base_events.py
M Lib/asyncio/events.py
M Lib/asyncio/proactor_events.py
M Lib/asyncio/selector_events.py
M Lib/test/test_asyncio/test_server.py

diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst
index 06c5c877ccc173..d6ed817b13676f 100644
--- a/Doc/library/asyncio-eventloop.rst
+++ b/Doc/library/asyncio-eventloop.rst
@@ -1641,6 +1641,31 @@ Do not instantiate the :class:`Server` class directly.
       coroutine to wait until the server is closed (and no more
       connections are active).
 
+   .. method:: close_clients()
+
+      Close all existing incoming client connections.
+
+      Calls :meth:`~asyncio.BaseTransport.close` on all associated
+      transports.
+
+      :meth:`close` should be called before :meth:`close_clients` when
+      closing the server to avoid races with new clients connecting.
+
+      .. versionadded:: 3.13
+
+   .. method:: abort_clients()
+
+      Close all existing incoming client connections immediately,
+      without waiting for pending operations to complete.
+
+      Calls :meth:`~asyncio.WriteTransport.abort` on all associated
+      transports.
+
+      :meth:`close` should be called before :meth:`abort_clients` when
+      closing the server to avoid races with new clients connecting.
+
+      .. versionadded:: 3.13
+
    .. method:: get_loop()
 
       Return the event loop associated with the server object.
diff --git a/Doc/whatsnew/3.13.rst b/Doc/whatsnew/3.13.rst
index b665e6f1c85915..0553cc97c5c75a 100644
--- a/Doc/whatsnew/3.13.rst
+++ b/Doc/whatsnew/3.13.rst
@@ -270,6 +270,11 @@ asyncio
   the buffer size.
   (Contributed by Jamie Phan in :gh:`115199`.)
 
+* Add :meth:`asyncio.Server.close_clients` and
+  :meth:`asyncio.Server.abort_clients` methods, which allow an asyncio
+  server to be closed more forcefully.
+  (Contributed by Pierre Ossman in :gh:`113538`.)
+
 base64
 ---
 
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index 6c5cf28e7c59d4..f0e690b61a73dd 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -279,7 +279,9 @@ def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
                  ssl_handshake_timeout, ssl_shutdown_timeout=None):
         self._loop = loop
         self._sockets = sockets
-        self._active_count = 0
+        # Weak references so we don't break Transport's ability to
+        # detect abandoned transports
+        self._clients = weakref.WeakSet()
         self._waiters = []
         self._protocol_factory = protocol_factory
         self._backlog = backlog
@@ -292,14 +294,13 @@ def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
     def __repr__(self):
         return f'<{self.__class__.__name__} sockets={self.sockets!r}>'
 
-    def _attach(self):
+    def _attach(self, transport):
         assert self._sockets is not None
-        self._active_count += 1
+        self._clients.add(transport)
 
-    def _detach(self):
-        assert self._active_count > 0
-        self._active_count -= 1
-        if self._active_count == 0 and self._sockets is None:
+    def _detach(self, transport):
+        self._clients.discard(transport)
+        if len(self._clients) == 0 and self._sockets is None:
             self._wakeup()
 
     def _wakeup(self):
@@ -348,9 +349,17 @@ def close(self):
             self._serving_forever_fut.cancel()
             self._serving_forever_fut = None
 
-        if self._active_count == 0:
+        if len(self._clients) == 0:
             self._wakeup()
 
+    def close_clients(self):
+        for transport in self._clients.copy():
+            transport.close()
+
+    def abort_clients(self):
+        for transport in self._clients.copy():
+            transport.abort()
+
     async def start_serving(self):
         self._start_serving()
         # Skip one loop iteration so that all 'loop.add_reader'
diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py
index 680749325025db..be495469a0558b 100644
--- a/Lib/asyncio/events.py
+++ b/Lib/asyncio/events.py
@@ -175,6 +175,14 @@ def close(self):
         """Stop serving.  This leaves existing connections open."""
         raise NotImplementedError
 
+    def close_clients(self):
+        """Close all active connections."""
+        raise NotImplementedError
+
+    def abort_clients(self):
+        """Close all active connections immediately."""
+        raise NotImplementedError
+
     def get_loop(self):
         """Get the event loop the Server object is attached to."""
         raise NotImplementedError
diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py
index a512db6367b20a..397a8cda757895 100644
--- a/Lib/asyncio/proactor_events.py
+++ b/Lib/asyncio/proactor_events.py
@@ -63,7 +63,7 @@ def __init__(self, loop, sock, protocol, waiter=None,
         self._called_connection_lost = False
         self._eof_written = False
         if self._server is not None:
-            self._server._attach()
+            self._server._attach(self)
         self._loop.call_soon(self._protocol.connection_made, self)
         if waiter is not None:
             # only wake up the waiter when connection_made() has been called
@@ -167,7 +167,7 @@ def _call_connection_lost(self, exc):
             self._sock = None
             server = self._server
             if server is not None:
-                server._detach()
+                server._detach(self)
                 self._server = None
             self._called_connection_lost = True
 
diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py
index 8e888d26ea0737..f94bf10b4225e7 100644
--- a/Lib/asyncio/selector_events.py
+++ b/Lib/asyncio/selector_events.py
@@ -791,7 +791,7 @@ def __init__(self, loop, sock, protocol, extra=None, server=None):
         self._paused = False  # Set when pause_reading() called
 
         if self._server is not None:
-            self._server._attach()
+            self._server._attach(self)
         loop._transports[self._sock_fd] = self
 
     def __repr__(self):
@@ -868,6 +868,8 @@ def __del__(self, _warn=warnings.warn):
         if self._sock is not None:
             _warn(f"unclosed transport {self!r}", ResourceWarning, source=self)
             self._sock.close()
+            if self._server is not None:
+                self._server._detach(self)
 
     def _fatal_error(self, exc, message='Fatal error on transport'):
         # Should be called from exception handler only.
@@ -906,7 +908,7 @@ def _call_connection_lost(self, exc):
             self._loop = None
             server = self._server
             if server is not None:
-                server._detach()
+                server._detach(self)
                 self._server = None
 
     def get_write_buffer_size(self):
diff --git a/Lib/test/test_asyncio/test_server.py b/Lib/test/test_asyncio/test_server.py
index 918faac909b9bf..4ca8a166a0f1a1 100644
--- a/Lib/test/test_asyncio/test_server.py
+++ b/Lib/test/test_asyncio/test_server.py
@@ -125,8 +125,12 @@ async def main(srv):
 class TestServer2(unittest.IsolatedAsyncioTestCase):
 
     async def test_wait_closed_basic(self):
-        async def serve(*args):
-            pass
+        async def serve(rd, wr):
+            try:
+                await rd.read()
+            finally:
+                wr.close()
+                await wr.wait_closed()
 
         srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
         self.addCleanup(srv.close)
@@ -137,7 +141,8 @@ async def serve(*args):
         self.assertFalse(task1.done())
 
         # active count != 0, not closed: should block
-        srv._attach()
+        addr = srv.sockets[0].getsockname()
+        (rd, wr) = await asyncio.open_connection(addr[0], addr[1])
         task2 = asyncio.create_task(srv.wait_closed())
         await asyncio.sleep(0)
         self.assertFalse(task1.done())
@@ -152,7 +157,8 @@ async def serve(*args):
         self.assertFalse(task2.done())
         self.assertFalse(task3.done())
 
-        srv._detach()
+        wr.close()
+        await wr.wait_closed()
         # active count == 0, closed: should unblock
         await task1
         await task2
@@ -161,8 +167,12 @@ async def serve(*args):
 
     async def test_wait_closed_race(self):
         # Test a regression in 3.12.0, should be fixed in 3.12.1
-        async def serve(*args):
-            pass
+        async def serve(rd, wr):
+            try:
+                await rd.read()
+            finally:
+                wr.close()
+                await wr.wait_closed()
 
         srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
         self.addCleanup(srv.close)
@@ -170,13 +180,83 @@ async def serve(*args):
         task = asyncio.create_task(srv.wait_closed())
         await asyncio.sleep(0)
         self.assertFalse(task.done())
-        srv._attach()
+        addr = srv.sockets[0].getsockname()
+        (rd, wr) = await asyncio.open_connection(addr[0], addr[1])
         loop = asyncio.get_running_loop()
         loop.call_soon(srv.close)
-        loop.call_soon(srv._detach)
+        loop.call_soon(wr.close)
         await srv.wait_closed()
 
+    async def test_close_clients(self):
+        async def serve(rd, wr):
+            try:
+                await rd.read()
+            finally:
+                wr.close()
+                await wr.wait_closed()
+
+        srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
+        self.addCleanup(srv.close)
+
+        addr = srv.sockets[0].getsockname()
+        (rd, wr) = await asyncio.open_connection(addr[0], addr[1])
+        self.addCleanup(wr.close)
+
+        task = asyncio.create_task(srv.wait_closed())
+        await asyncio.sleep(0)
+        self.assertFalse(task.done())
+
+        srv.close()
+        srv.close_clients()
+        await asyncio.sleep(0)
+        await asyncio.sleep(0)
+        self.assertTrue(task.done())
+
+    async def test_abort_clients(self):
+        async def serve(rd, wr):
+            fut.set_result((rd, wr))
+            await wr.wait_closed()
+
+        fut = asyncio.Future()
+        srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
+        self.addCleanup(srv.close)
+
+        addr = srv.sockets[0].getsockname()
+        (c_rd, c_wr) = await asyncio.open_connection(addr[0], addr[1], 
limit=4096)
+        self.addCleanup(c_wr.close)
+
+        (s_rd, s_wr) = await fut
+
+        # Limit the socket buffers so we can reliably overfill them
+        s_sock = s_wr.get_extra_info('socket')
+        s_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 65536)
+        c_sock = c_wr.get_extra_info('socket')
+        c_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 65536)
+
+        # Get the reader in to a paused state by sending more than twice
+        # the configured limit
+        s_wr.write(b'a' * 4096)
+        s_wr.write(b'a' * 4096)
+        s_wr.write(b'a' * 4096)
+        while c_wr.transport.is_reading():
+            await asyncio.sleep(0)
+
+        # Get the writer in a waiting state by sending data until the
+        # socket buffers are full on both server and client sockets and
+        # the kernel stops accepting more data
+        s_wr.write(b'a' * c_sock.getsockopt(socket.SOL_SOCKET, 
socket.SO_RCVBUF))
+        s_wr.write(b'a' * s_sock.getsockopt(socket.SOL_SOCKET, 
socket.SO_SNDBUF))
+        self.assertNotEqual(s_wr.transport.get_write_buffer_size(), 0)
+
+        task = asyncio.create_task(srv.wait_closed())
+        await asyncio.sleep(0)
+        self.assertFalse(task.done())
 
+        srv.close()
+        srv.abort_clients()
+        await asyncio.sleep(0)
+        await asyncio.sleep(0)
+        self.assertTrue(task.done())
 
 
 # Test the various corner cases of Unix server socket removal
diff --git a/Misc/NEWS.d/next/Library/2024-01-22-15-50-58.gh-issue-113538.v2wrwg.rst b/Misc/NEWS.d/next/Library/2024-01-22-15-50-58.gh-issue-113538.v2wrwg.rst
new file mode 100644
index 00000000000000..5c59af98e136bb
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2024-01-22-15-50-58.gh-issue-113538.v2wrwg.rst
@@ -0,0 +1,3 @@
+Add :meth:`asyncio.Server.close_clients` and
+:meth:`asyncio.Server.abort_clients` methods, which allow an asyncio server
+to be closed more forcefully.

_______________________________________________
Python-checkins mailing list -- python-checkins@python.org
To unsubscribe send an email to python-checkins-le...@python.org
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: arch...@mail-archive.com

Reply via email to