Copilot commented on code in PR #67635: URL: https://github.com/apache/airflow/pull/67635#discussion_r3316173904
########## task-sdk/tests/task_sdk/coordinators/socket/test_coordinator.py: ########## @@ -0,0 +1,580 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import contextlib +import socket +import subprocess +import threading +import time +from unittest.mock import ANY, MagicMock, call, patch + +import attrs +import pytest +from uuid6 import uuid7 + +from airflow.sdk.coordinators.socket.coordinator import ( + SocketCoordinator, + _accept_connections, + _ResourceTracker, + _SocketActivitySubprocess, + _start_server, +) +from airflow.sdk.execution_time.coordinator import BaseCoordinator +from airflow.sdk.execution_time.supervisor import ActivitySubprocess +from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS + +if not AIRFLOW_V_3_3_PLUS: + pytest.skip("Coordinator is only compatible with Airflow >= 3.3.0", allow_module_level=True) + + +def _make_ti(dag_id: str = "tutorial_dag", queue: str = "socket") -> TaskInstanceDTO: + return TaskInstanceDTO( + id=uuid7(), + dag_version_id=uuid7(), + task_id="task_1", + dag_id=dag_id, + run_id="run_1", + try_number=1, + map_index=-1, + pool_slots=1, + queue=queue, + 
priority_weight=1, + ) + + +class TestStartServer: + def test_returns_listening_socket(self): + server = _start_server() + try: + host, port = server.getsockname() + finally: + server.close() + assert host == "127.0.0.1" + assert port > 0 + + def test_two_calls_return_different_ports(self): + s1 = _start_server() + s2 = _start_server() + try: + _, port1 = s1.getsockname() + _, port2 = s2.getsockname() + finally: + s1.close() + s2.close() + assert port1 != port2 + + def test_accepts_connection(self): + conn = client = None + server = _start_server() + try: + _, port = server.getsockname() + client = socket.socket() + client.connect(("127.0.0.1", port)) + conn, _ = server.accept() + conn.sendall(b"ping") + received = client.recv(4) + finally: + if conn: + conn.close() + if client: + client.close() + server.close() + assert received == b"ping" + + +class TestAcceptConnections: + def _connect_after_delay(self, addr: tuple[str, int], delay: float = 0.0) -> None: + def _connect(): + time.sleep(delay) + c = socket.socket() + with contextlib.suppress(OSError): + c.connect(addr) + + threading.Thread(target=_connect, daemon=True).start() Review Comment: This test uses `time.sleep()` to coordinate a background connector thread. Sleeps tend to make socket/concurrency tests flaky under CI load. Prefer synchronizing with `threading.Event` (or similar) and/or polling with a timeout so the test doesn't rely on wall-clock delays. ########## task-sdk/src/airflow/sdk/coordinators/socket/__init__.py: ########## @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Socket-based subprocess coordinator base for the Apache Airflow Task SDK.""" + +from __future__ import annotations + +from airflow.sdk.coordinators.socket.coordinator import SocketCoordinator + +__all__ = ["SocketCoordinator", "__version__"] + +__version__ = "0.1.0" Review Comment: `airflow.sdk.coordinators.socket` defines and exports its own `__version__ = "0.1.0"`, which is inconsistent with the rest of the Task SDK (the distribution version is exposed as `airflow.sdk.__version__`). Keeping a separate subpackage version risks confusion and stale/incorrect version reporting. Consider removing this `__version__` entirely (and dropping it from `__all__`), or deriving it from the Task SDK package version if you truly need it. ########## task-sdk/tests/task_sdk/coordinators/socket/test_coordinator.py: ########## @@ -0,0 +1,580 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import contextlib +import socket +import subprocess +import threading +import time +from unittest.mock import ANY, MagicMock, call, patch + +import attrs +import pytest +from uuid6 import uuid7 + +from airflow.sdk.coordinators.socket.coordinator import ( + SocketCoordinator, + _accept_connections, + _ResourceTracker, + _SocketActivitySubprocess, + _start_server, +) +from airflow.sdk.execution_time.coordinator import BaseCoordinator +from airflow.sdk.execution_time.supervisor import ActivitySubprocess +from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS + +if not AIRFLOW_V_3_3_PLUS: + pytest.skip("Coordinator is only compatible with Airflow >= 3.3.0", allow_module_level=True) + + +def _make_ti(dag_id: str = "tutorial_dag", queue: str = "socket") -> TaskInstanceDTO: + return TaskInstanceDTO( + id=uuid7(), + dag_version_id=uuid7(), + task_id="task_1", + dag_id=dag_id, + run_id="run_1", + try_number=1, + map_index=-1, + pool_slots=1, + queue=queue, + priority_weight=1, + ) + + +class TestStartServer: + def test_returns_listening_socket(self): + server = _start_server() + try: + host, port = server.getsockname() + finally: + server.close() + assert host == "127.0.0.1" + assert port > 0 + + def test_two_calls_return_different_ports(self): + s1 = _start_server() + s2 = _start_server() + try: + _, port1 = s1.getsockname() + _, port2 = s2.getsockname() + finally: + s1.close() + s2.close() + assert port1 != port2 + + def test_accepts_connection(self): + conn = client = None + server = _start_server() + try: + _, port = server.getsockname() + client = socket.socket() + client.connect(("127.0.0.1", port)) + conn, _ = server.accept() + conn.sendall(b"ping") + received = client.recv(4) + finally: + if conn: + conn.close() + if client: + 
client.close() + server.close() + assert received == b"ping" + + +class TestAcceptConnections: + def _connect_after_delay(self, addr: tuple[str, int], delay: float = 0.0) -> None: + def _connect(): + time.sleep(delay) + c = socket.socket() + with contextlib.suppress(OSError): + c.connect(addr) Review Comment: The background `_connect` helper creates a client socket (`c = socket.socket()`) but never closes it. Even though these are short-lived tests, leaking sockets in a loop can accumulate and cause intermittent failures on some platforms. Use a context manager (`with socket.socket() as c:`) or explicitly `close()` the socket in a `finally:` block. ########## task-sdk/tests/task_sdk/coordinators/socket/test_coordinator.py: ########## @@ -0,0 +1,580 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import contextlib +import socket +import subprocess +import threading +import time +from unittest.mock import ANY, MagicMock, call, patch + +import attrs +import pytest +from uuid6 import uuid7 + +from airflow.sdk.coordinators.socket.coordinator import ( + SocketCoordinator, + _accept_connections, + _ResourceTracker, + _SocketActivitySubprocess, + _start_server, +) +from airflow.sdk.execution_time.coordinator import BaseCoordinator +from airflow.sdk.execution_time.supervisor import ActivitySubprocess +from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS + +if not AIRFLOW_V_3_3_PLUS: + pytest.skip("Coordinator is only compatible with Airflow >= 3.3.0", allow_module_level=True) + + +def _make_ti(dag_id: str = "tutorial_dag", queue: str = "socket") -> TaskInstanceDTO: + return TaskInstanceDTO( + id=uuid7(), + dag_version_id=uuid7(), + task_id="task_1", + dag_id=dag_id, + run_id="run_1", + try_number=1, + map_index=-1, + pool_slots=1, + queue=queue, + priority_weight=1, + ) + + +class TestStartServer: + def test_returns_listening_socket(self): + server = _start_server() + try: + host, port = server.getsockname() + finally: + server.close() + assert host == "127.0.0.1" + assert port > 0 + + def test_two_calls_return_different_ports(self): + s1 = _start_server() + s2 = _start_server() + try: + _, port1 = s1.getsockname() + _, port2 = s2.getsockname() + finally: + s1.close() + s2.close() + assert port1 != port2 + + def test_accepts_connection(self): + conn = client = None + server = _start_server() + try: + _, port = server.getsockname() + client = socket.socket() + client.connect(("127.0.0.1", port)) + conn, _ = server.accept() + conn.sendall(b"ping") + received = client.recv(4) + finally: + if conn: + conn.close() + if client: + client.close() + server.close() + assert received == b"ping" + + +class TestAcceptConnections: + def 
_connect_after_delay(self, addr: tuple[str, int], delay: float = 0.0) -> None: + def _connect(): + time.sleep(delay) + c = socket.socket() + with contextlib.suppress(OSError): + c.connect(addr) + + threading.Thread(target=_connect, daemon=True).start() + + def test_accepts_single_server(self): + server = _start_server() + _, port = server.getsockname() + self._connect_after_delay(("127.0.0.1", port)) + + mock_proc = MagicMock(spec=subprocess.Popen) + mock_proc.poll.return_value = None + + try: + accepted, _ = _accept_connections({"comm": server}, {}, mock_proc) + assert server in accepted + accepted[server].close() + finally: + server.close() + + def test_accepts_multiple_servers_keyed_by_server_socket(self): + comm_server = _start_server() + logs_server = _start_server() + _, comm_port = comm_server.getsockname() + _, logs_port = logs_server.getsockname() + + self._connect_after_delay(("127.0.0.1", comm_port)) + self._connect_after_delay(("127.0.0.1", logs_port)) + + mock_proc = MagicMock(spec=subprocess.Popen) + mock_proc.poll.return_value = None + + try: + accepted, drained = _accept_connections({"comm": comm_server, "logs": logs_server}, {}, mock_proc) + assert set(accepted) == {comm_server, logs_server} + assert drained == {} + for sock in accepted.values(): + sock.close() + finally: + comm_server.close() + logs_server.close() + + def test_empty_drains_returns_empty_drained_dict(self): + server = _start_server() + _, port = server.getsockname() + self._connect_after_delay(("127.0.0.1", port)) + + mock_proc = MagicMock(spec=subprocess.Popen) + mock_proc.poll.return_value = None + try: + _, drained = _accept_connections({"comm": server}, {}, mock_proc) + assert drained == {} + finally: + server.close() + + def test_drain_socket_present_in_drained_dict(self): + server = _start_server() + drain_r, drain_w = socket.socketpair() + _, port = server.getsockname() + self._connect_after_delay(("127.0.0.1", port)) + + mock_proc = MagicMock(spec=subprocess.Popen) + 
mock_proc.poll.return_value = None + try: + _, drained = _accept_connections({"comm": server}, {"stdout": drain_r}, mock_proc) + assert drain_r in drained + finally: + drain_r.close() + drain_w.close() + server.close() + + def test_drain_captures_early_output(self): + """Bytes written to the drain socket before the comm server accepts + must be captured and returned in the drained dict.""" + server = _start_server() + drain_r, drain_w = socket.socketpair() + _, port = server.getsockname() + + drain_w.sendall(b"early output\n") + drain_w.shutdown(socket.SHUT_WR) + self._connect_after_delay(("127.0.0.1", port), delay=0.05) + + mock_proc = MagicMock(spec=subprocess.Popen) + mock_proc.poll.return_value = None + try: + _, drained = _accept_connections({"comm": server}, {"stdout": drain_r}, mock_proc) + assert drained[drain_r] == b"early output\n" + finally: + drain_r.close() + drain_w.close() + server.close() + + def test_raises_timeout_when_no_connection(self): + server = _start_server() + mock_proc = MagicMock(spec=subprocess.Popen) + mock_proc.poll.return_value = None + try: + with pytest.raises(TimeoutError, match="did not connect within timeout"): + _accept_connections({"comm": server}, {}, mock_proc, max_wait=0.05) + finally: + server.close() + + def test_raises_runtime_error_if_process_exits_before_connecting(self): + server = _start_server() + mock_proc = MagicMock(spec=subprocess.Popen) + mock_proc.poll.return_value = 1 + mock_proc.returncode = 1 + try: + with pytest.raises(RuntimeError, match="process exited with 1"): + _accept_connections({"comm": server}, {}, mock_proc) + finally: + server.close() + + def test_returned_sockets_are_connected(self): + """Accepted sockets should be real, usable connections.""" + server = _start_server() + _, port = server.getsockname() + + client = socket.socket() + client.connect(("127.0.0.1", port)) + + mock_proc = MagicMock(spec=subprocess.Popen) + mock_proc.poll.return_value = None + + try: + accepted, _ = 
_accept_connections({"comm": server}, {}, mock_proc) + accepted[server].sendall(b"hello") + assert client.recv(5) == b"hello" + accepted[server].close() + client.close() + finally: + server.close() + + def test_accepted_dict_keyed_by_server_socket_object(self): + """The returned accepted mapping must use server socket objects as keys, + not the string names passed in the servers dict.""" + server = _start_server() + _, port = server.getsockname() + self._connect_after_delay(("127.0.0.1", port)) + mock_proc = MagicMock(spec=subprocess.Popen) + mock_proc.poll.return_value = None + try: + accepted, _ = _accept_connections({"comm": server}, {}, mock_proc) + # Key must be the socket object itself, not the string "comm" + assert server in accepted + assert "comm" not in accepted + accepted[server].close() + finally: + server.close() + + +class TestResourceTracker: + """ + Unit tests for the _ResourceTracker context manager. + + _ResourceTracker tracks sockets and Popen objects and ensures they are + closed/terminated on context-manager exit, unless explicitly untracked + beforehand. 
+ """ + + def test_track_returns_passed_objects_as_tuple(self): + tracker = _ResourceTracker(timeout=0.1) + sock = MagicMock(spec=socket.socket) + result = tracker.track(sock) + assert result == (sock,) + + def test_track_multiple_objects_returns_all(self): + tracker = _ResourceTracker(timeout=0.1) + sock1 = MagicMock(spec=socket.socket) + sock2 = MagicMock(spec=socket.socket) + result = tracker.track(sock1, sock2) + assert set(result) == {sock1, sock2} + + def test_untrack_returns_objects(self): + tracker = _ResourceTracker(timeout=0.1) + sock = MagicMock(spec=socket.socket) + tracker.track(sock) + result = tracker.untrack(sock) + assert result == (sock,) + + def test_context_manager_closes_tracked_socket_on_exit(self): + sock = MagicMock(spec=socket.socket) + with _ResourceTracker(timeout=0.1) as tracker: + tracker.track(sock) + sock.close.assert_called_once() + + def test_context_manager_terminates_tracked_popen_on_exit(self): + proc = MagicMock(spec=subprocess.Popen) + with _ResourceTracker(timeout=0.1) as tracker: + tracker.track(proc) + proc.terminate.assert_called_once() + + def test_untracked_socket_not_closed_on_exit(self): + sock = MagicMock(spec=socket.socket) + with _ResourceTracker(timeout=0.1) as tracker: + tracker.track(sock) + tracker.untrack(sock) + sock.close.assert_not_called() + + def test_only_remaining_tracked_objects_cleaned_up(self): + """After untracking one socket the other must still be closed.""" + sock_keep = MagicMock(spec=socket.socket) + sock_release = MagicMock(spec=socket.socket) + with _ResourceTracker(timeout=0.1) as tracker: + tracker.track(sock_keep, sock_release) + tracker.untrack(sock_release) + sock_keep.close.assert_called_once() + sock_release.close.assert_not_called() + + def test_untrack_unknown_object_does_not_raise(self): + sock = MagicMock(spec=socket.socket) + tracker = _ResourceTracker(timeout=0.1) + # Untracking something never tracked must be a no-op, not an error + tracker.untrack(sock) + +
+@attrs.define(kw_only=True) +class _StubSocketCoordinator(SocketCoordinator): + """Minimal SocketCoordinator subclass used to exercise the base machinery.""" + + command: list[str] + schema_version: str | None = None + + def _build_execute_task_command(self, *, what): + return list(self.command), self.schema_version + + +@pytest.fixture +def mock_client(make_ti_context): + client = MagicMock() + client.task_instances.start.return_value = make_ti_context() + return client Review Comment: `mock_client` is created as a bare `MagicMock()` without `spec`/`autospec`, which can hide real API mismatches (e.g. typos in attribute/method names) and makes these coordinator tests less robust. Please consider using `MagicMock(spec=Client)` (and similarly spec-ing nested mocks like `task_instances`) or `create_autospec` to ensure the mock matches the real client surface. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
