kaxil commented on code in PR #67635:
URL: https://github.com/apache/airflow/pull/67635#discussion_r3325272486


##########
task-sdk/tests/task_sdk/coordinators/socket/test_coordinator.py:
##########
@@ -0,0 +1,580 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import contextlib
+import socket
+import subprocess
+import threading
+import time
+from unittest.mock import ANY, MagicMock, call, patch
+
+import attrs
+import pytest
+from uuid6 import uuid7
+
+from airflow.sdk.coordinators.socket.coordinator import (
+    SocketCoordinator,
+    _accept_connections,
+    _ResourceTracker,
+    _SocketActivitySubprocess,
+    _start_server,
+)
+from airflow.sdk.execution_time.coordinator import BaseCoordinator
+from airflow.sdk.execution_time.supervisor import ActivitySubprocess
+from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
+
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS
+
+if not AIRFLOW_V_3_3_PLUS:
+    pytest.skip("Coordinator is only compatible with Airflow >= 3.3.0", allow_module_level=True)
+
+
+def _make_ti(dag_id: str = "tutorial_dag", queue: str = "socket") -> TaskInstanceDTO:
+    return TaskInstanceDTO(
+        id=uuid7(),
+        dag_version_id=uuid7(),
+        task_id="task_1",
+        dag_id=dag_id,
+        run_id="run_1",
+        try_number=1,
+        map_index=-1,
+        pool_slots=1,
+        queue=queue,
+        priority_weight=1,
+    )
+
+
+class TestStartServer:
+    def test_returns_listening_socket(self):
+        server = _start_server()
+        try:
+            host, port = server.getsockname()
+        finally:
+            server.close()
+        assert host == "127.0.0.1"
+        assert port > 0
+
+    def test_two_calls_return_different_ports(self):
+        s1 = _start_server()
+        s2 = _start_server()
+        try:
+            _, port1 = s1.getsockname()
+            _, port2 = s2.getsockname()
+        finally:
+            s1.close()
+            s2.close()
+        assert port1 != port2
+
+    def test_accepts_connection(self):
+        conn = client = None
+        server = _start_server()
+        try:
+            _, port = server.getsockname()
+            client = socket.socket()
+            client.connect(("127.0.0.1", port))
+            conn, _ = server.accept()
+            conn.sendall(b"ping")
+            received = client.recv(4)
+        finally:
+            if conn:
+                conn.close()
+            if client:
+                client.close()
+            server.close()
+        assert received == b"ping"
+
+
+class TestAcceptConnections:
+    def _connect_after_delay(self, addr: tuple[str, int], delay: float = 0.0) -> None:
+        def _connect():
+            time.sleep(delay)
+            c = socket.socket()
+            with contextlib.suppress(OSError):
+                c.connect(addr)
+
+        threading.Thread(target=_connect, daemon=True).start()
+
+    def test_accepts_single_server(self):
+        server = _start_server()
+        _, port = server.getsockname()
+        self._connect_after_delay(("127.0.0.1", port))
+
+        mock_proc = MagicMock(spec=subprocess.Popen)
+        mock_proc.poll.return_value = None
+
+        try:
+            accepted, _ = _accept_connections({"comm": server}, {}, mock_proc)
+            assert server in accepted
+            accepted[server].close()
+        finally:
+            server.close()
+
+    def test_accepts_multiple_servers_keyed_by_server_socket(self):
+        comm_server = _start_server()
+        logs_server = _start_server()
+        _, comm_port = comm_server.getsockname()
+        _, logs_port = logs_server.getsockname()
+
+        self._connect_after_delay(("127.0.0.1", comm_port))
+        self._connect_after_delay(("127.0.0.1", logs_port))
+
+        mock_proc = MagicMock(spec=subprocess.Popen)
+        mock_proc.poll.return_value = None
+
+        try:
+            accepted, drained = _accept_connections({"comm": comm_server, "logs": logs_server}, {}, mock_proc)
+            assert set(accepted) == {comm_server, logs_server}
+            assert drained == {}
+            for sock in accepted.values():
+                sock.close()
+        finally:
+            comm_server.close()
+            logs_server.close()
+
+    def test_empty_drains_returns_empty_drained_dict(self):
+        server = _start_server()
+        _, port = server.getsockname()
+        self._connect_after_delay(("127.0.0.1", port))
+
+        mock_proc = MagicMock(spec=subprocess.Popen)
+        mock_proc.poll.return_value = None
+        try:
+            _, drained = _accept_connections({"comm": server}, {}, mock_proc)
+            assert drained == {}
+        finally:
+            server.close()
+
+    def test_drain_socket_present_in_drained_dict(self):
+        server = _start_server()
+        drain_r, drain_w = socket.socketpair()
+        _, port = server.getsockname()
+        self._connect_after_delay(("127.0.0.1", port))
+
+        mock_proc = MagicMock(spec=subprocess.Popen)
+        mock_proc.poll.return_value = None
+        try:
+            _, drained = _accept_connections({"comm": server}, {"stdout": drain_r}, mock_proc)
+            assert drain_r in drained
+        finally:
+            drain_r.close()
+            drain_w.close()
+            server.close()
+
+    def test_drain_captures_early_output(self):
+        """Bytes written to the drain socket before the comm server accepts
+        must be captured and returned in the drained dict."""
+        server = _start_server()
+        drain_r, drain_w = socket.socketpair()
+        _, port = server.getsockname()
+
+        drain_w.sendall(b"early output\n")
+        drain_w.shutdown(socket.SHUT_WR)
+        self._connect_after_delay(("127.0.0.1", port), delay=0.05)
+
+        mock_proc = MagicMock(spec=subprocess.Popen)
+        mock_proc.poll.return_value = None
+        try:
+            _, drained = _accept_connections({"comm": server}, {"stdout": drain_r}, mock_proc)
+            assert drained[drain_r] == b"early output\n"
+        finally:
+            drain_r.close()
+            drain_w.close()
+            server.close()
+
+    def test_raises_timeout_when_no_connection(self):
+        server = _start_server()
+        mock_proc = MagicMock(spec=subprocess.Popen)
+        mock_proc.poll.return_value = None
+        try:
+            with pytest.raises(TimeoutError, match="did not connect within timeout"):
+                _accept_connections({"comm": server}, {}, mock_proc, max_wait=0.05)
+        finally:
+            server.close()
+
+    def test_raises_runtime_error_if_process_exits_before_connecting(self):
+        server = _start_server()
+        mock_proc = MagicMock(spec=subprocess.Popen)
+        mock_proc.poll.return_value = 1
+        mock_proc.returncode = 1
+        try:
+            with pytest.raises(RuntimeError, match="process exited with 1"):
+                _accept_connections({"comm": server}, {}, mock_proc)
+        finally:
+            server.close()
+
+    def test_returned_sockets_are_connected(self):
+        """Accepted sockets should be real, usable connections."""
+        server = _start_server()
+        _, port = server.getsockname()
+
+        client = socket.socket()
+        client.connect(("127.0.0.1", port))
+
+        mock_proc = MagicMock(spec=subprocess.Popen)
+        mock_proc.poll.return_value = None
+
+        try:
+            accepted, _ = _accept_connections({"comm": server}, {}, mock_proc)
+            accepted[server].sendall(b"hello")
+            assert client.recv(5) == b"hello"
+            accepted[server].close()
+            client.close()
+        finally:
+            server.close()
+
+    def test_accepted_dict_keyed_by_server_socket_object(self):
+        """The returned accepted mapping must use server socket objects as keys,
+        not the string names passed in the servers dict."""
+        server = _start_server()
+        _, port = server.getsockname()
+        self._connect_after_delay(("127.0.0.1", port))
+        mock_proc = MagicMock(spec=subprocess.Popen)
+        mock_proc.poll.return_value = None
+        try:
+            accepted, _ = _accept_connections({"comm": server}, {}, mock_proc)
+            # Key must be the socket object itself, not the string "comm"
+            assert server in accepted
+            assert "comm" not in accepted
+            accepted[server].close()
+        finally:
+            server.close()
+
+
+class TestResourceTracker:
+    """
+    Unit tests for the _ResourceTracker context manager.
+
+    _ResourceTracker tracks sockets and Popen objects and ensures they are
+    closed/terminated on context-manager exit, unless explicitly untracked
+    beforehand.
+    """
+
+    def test_track_returns_passed_objects_as_tuple(self):
+        tracker = _ResourceTracker(timeout=0.1)
+        sock = MagicMock(spec=socket.socket)
+        result = tracker.track(sock)
+        assert result == (sock,)
+
+    def test_track_multiple_objects_returns_all(self):
+        tracker = _ResourceTracker(timeout=0.1)
+        sock1 = MagicMock(spec=socket.socket)
+        sock2 = MagicMock(spec=socket.socket)
+        result = tracker.track(sock1, sock2)
+        assert set(result) == {sock1, sock2}
+
+    def test_untrack_returns_objects(self):
+        tracker = _ResourceTracker(timeout=0.1)
+        sock = MagicMock(spec=socket.socket)
+        tracker.track(sock)
+        result = tracker.untrack(sock)
+        assert result == (sock,)
+
+    def test_context_manager_closes_tracked_socket_on_exit(self):
+        sock = MagicMock(spec=socket.socket)
+        with _ResourceTracker(timeout=0.1) as tracker:
+            tracker.track(sock)
+        sock.close.assert_called_once()
+
+    def test_context_manager_terminates_tracked_popen_on_exit(self):
+        proc = MagicMock(spec=subprocess.Popen)
+        with _ResourceTracker(timeout=0.1) as tracker:
+            tracker.track(proc)
+        proc.terminate.assert_called_once()
+
+    def test_untracked_socket_not_closed_on_exit(self):
+        sock = MagicMock(spec=socket.socket)
+        with _ResourceTracker(timeout=0.1) as tracker:
+            tracker.track(sock)
+            tracker.untrack(sock)
+        sock.close.assert_not_called()
+
+    def test_only_remaining_tracked_objects_cleaned_up(self):
+        """After untracking one socket the other must still be closed."""
+        sock_keep = MagicMock(spec=socket.socket)
+        sock_release = MagicMock(spec=socket.socket)
+        with _ResourceTracker(timeout=0.1) as tracker:
+            tracker.track(sock_keep, sock_release)
+            tracker.untrack(sock_release)
+        sock_keep.close.assert_called_once()
+        sock_release.close.assert_not_called()
+
+    def test_untrack_unknown_object_does_not_raise(self):
+        sock = MagicMock(spec=socket.socket)
+        tracker = _ResourceTracker(timeout=0.1)
+        # Untracking something never tracked must be a no-op, not an error
+        tracker.untrack(sock)
+
+
+@attrs.define(kw_only=True)
+class _StubSocketCoordinator(SocketCoordinator):
+    """Minimal SocketCoordinator subclass used to exercise the base machinery."""
+
+    command: list[str]
+    schema_version: str | None = None
+
+    def _build_execute_task_command(self, *, what):
+        return list(self.command), self.schema_version
+
+
+@pytest.fixture
+def mock_client(make_ti_context):
+    client = MagicMock()

Review Comment:
   `mock_client` is still a bare `MagicMock()` -- the one fixture Copilot's 
(now-resolved) spec thread pointed at. The rest of this file already specs its 
mocks (`MagicMock(spec=subprocess.Popen)`, `spec=socket.socket`), so this is 
the odd one out. `MagicMock(spec=Client)` here would catch drift if the 
coordinator's `client.task_instances.start(...)` call surface is ever renamed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to