This is an automated email from the ASF dual-hosted git repository.

jason810496 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e371845a222 Verify TCP connection ownership before accepting Java coordinator supervisor channel (#67781)
e371845a222 is described below

commit e371845a2226f9c499dc27212c74f9d10e101833
Author: GPK <[email protected]>
AuthorDate: Sat Jun 6 03:44:15 2026 +0100

    Verify TCP connection ownership before accepting Java coordinator supervisor channel (#67781)
    
    * Validate Java coordinator TCP clients
    
    * resolve conflicts
---
 .../src/airflow/sdk/coordinators/_subprocess.py    |  35 +++++++
 .../tests/task_sdk/coordinators/test_subprocess.py | 101 +++++++++++++++++++++
 2 files changed, 136 insertions(+)

diff --git a/task-sdk/src/airflow/sdk/coordinators/_subprocess.py b/task-sdk/src/airflow/sdk/coordinators/_subprocess.py
index da3ba5243a5..c223243eb6f 100644
--- a/task-sdk/src/airflow/sdk/coordinators/_subprocess.py
+++ b/task-sdk/src/airflow/sdk/coordinators/_subprocess.py
@@ -37,6 +37,7 @@ import time
 from typing import TYPE_CHECKING, TypeVar, cast
 
 import attrs
+import psutil
 import structlog
 
 from airflow.sdk.execution_time.coordinator import BaseCoordinator
@@ -65,6 +66,31 @@ def _start_server() -> socket.socket:
     return server
 
 
+def _socket_address(value: tuple | str) -> tuple[str, int] | None:
+    if not isinstance(value, tuple) or len(value) < 2:
+        return None
+    host, port = value[:2]
+    return str(host), int(port)
+
+
+def _is_connection_from_process(conn: socket.socket, proc: subprocess.Popen) -> bool:
+    """Return whether the accepted TCP connection belongs to the child process."""
+    peer = _socket_address(conn.getpeername())
+    local = _socket_address(conn.getsockname())
+    if peer is None or local is None:
+        return False
+    try:
+        process = psutil.Process(proc.pid)
+        connections = process.net_connections(kind="tcp")
+    except (psutil.AccessDenied, psutil.NoSuchProcess, psutil.ZombieProcess, OSError):
+        log.warning("Unable to verify child process connection", pid=proc.pid, exc_info=True)
+        return False
+    for connection in connections:
+        if _socket_address(connection.laddr) == peer and _socket_address(connection.raddr) == local:
+            return True
+    return False
+
+
 def _accept_connections(
     servers: dict[str, socket.socket],
     drains: dict[str, socket.socket],
@@ -102,6 +128,15 @@ def _accept_connections(
                 else:
                     log.debug("Accepting child process connection", key=event.data)
                     conn, _ = soc.accept()
+                    if not _is_connection_from_process(conn, proc):
+                        log.warning(
+                            "Rejected connection not owned by child process",
+                            key=event.data,
+                            pid=proc.pid,
+                            peer=conn.getpeername(),
+                        )
+                        conn.close()
+                        continue
                     sel.unregister(soc)
                     accepted[soc] = conn
     return accepted, drained
diff --git a/task-sdk/tests/task_sdk/coordinators/test_subprocess.py b/task-sdk/tests/task_sdk/coordinators/test_subprocess.py
index e5105e9f223..95a41e6282b 100644
--- a/task-sdk/tests/task_sdk/coordinators/test_subprocess.py
+++ b/task-sdk/tests/task_sdk/coordinators/test_subprocess.py
@@ -18,8 +18,10 @@
 from __future__ import annotations
 
 import contextlib
+import os
 import socket
 import subprocess
+import sys
 import threading
 import time
 from unittest.mock import ANY, MagicMock, call, patch
@@ -32,6 +34,7 @@ from airflow.sdk.api.client import Client, TaskInstanceOperations
 from airflow.sdk.coordinators._subprocess import (
     SubprocessCoordinator,
     _accept_connections,
+    _is_connection_from_process,
     _PopenActivitySubprocess,
     _ResourceTracker,
     _start_server,
@@ -102,6 +105,14 @@ class TestStartServer:
 
 
 class TestAcceptConnections:
+    @pytest.fixture(autouse=True)
+    def mock_child_connection_check(self):
+        with patch(
+            "airflow.sdk.coordinators._subprocess._is_connection_from_process",
+            return_value=True,
+        ) as mock_check:
+            yield mock_check
+
     def _connect_after_delay(self, addr: tuple[str, int], delay: float = 0.0) -> None:
         def _connect():
             time.sleep(delay)
@@ -256,6 +267,96 @@ class TestAcceptConnections:
         finally:
             server.close()
 
+    def test_rejects_connections_not_owned_by_child_process(self, mock_child_connection_check):
+        server = _start_server()
+        _, port = server.getsockname()
+        mock_child_connection_check.side_effect = [False, True]
+        self._connect_after_delay(("127.0.0.1", port))
+        self._connect_after_delay(("127.0.0.1", port), delay=0.05)
+
+        mock_proc = MagicMock(spec=subprocess.Popen)
+        mock_proc.pid = 12345
+        mock_proc.poll.return_value = None
+
+        try:
+            accepted, _ = _accept_connections({"comm": server}, {}, mock_proc)
+            assert mock_child_connection_check.call_count == 2
+            assert server in accepted
+            accepted[server].close()
+        finally:
+            server.close()
+
+
+class TestAcceptConnectionsProcessValidation:
+    def _start_connector_process(self, addr: tuple[str, int], *, delay: float = 0.0) -> subprocess.Popen:
+        script = """
+import socket
+import sys
+import time
+
+time.sleep(float(sys.argv[3]))
+sock = socket.socket()
+sock.connect((sys.argv[1], int(sys.argv[2])))
+sock.recv(1)
+"""
+        return subprocess.Popen([sys.executable, "-c", script, addr[0], str(addr[1]), str(delay)])
+
+    def test_rejects_racing_connection_from_other_process(self):
+        server = _start_server()
+        addr = server.getsockname()
+        attacker = socket.socket()
+        attacker.connect(addr)
+        child_proc = self._start_connector_process(addr, delay=0.05)
+
+        try:
+            accepted, _ = _accept_connections({"comm": server}, {}, child_proc)
+            accepted[server].sendall(b"x")
+            accepted[server].close()
+            assert child_proc.wait(timeout=5) == 0
+            assert attacker.recv(1) == b""
+        finally:
+            attacker.close()
+            server.close()
+            if child_proc.poll() is None:
+                child_proc.terminate()
+                child_proc.wait(timeout=5)
+
+
+class TestConnectionFromProcess:
+    def test_matches_child_process_tcp_connection(self):
+        server = _start_server()
+        _, port = server.getsockname()
+        client = socket.socket()
+        client.connect(("127.0.0.1", port))
+        conn, _ = server.accept()
+        mock_proc = MagicMock(spec=subprocess.Popen)
+        mock_proc.pid = os.getpid()
+
+        try:
+            assert _is_connection_from_process(conn, mock_proc) is True
+        finally:
+            conn.close()
+            client.close()
+            server.close()
+
+    def test_rejects_tcp_connection_not_owned_by_child_process(self):
+        server = _start_server()
+        _, port = server.getsockname()
+        client = socket.socket()
+        client.connect(("127.0.0.1", port))
+        conn, _ = server.accept()
+        mock_proc = MagicMock(spec=subprocess.Popen)
+        mock_proc.pid = os.getpid()
+
+        try:
+            with patch("airflow.sdk.coordinators._subprocess.psutil.Process") as mock_process:
+                mock_process.return_value.net_connections.return_value = []
+                assert _is_connection_from_process(conn, mock_proc) is False
+        finally:
+            conn.close()
+            client.close()
+            server.close()
+
 
 class TestResourceTracker:
     """

Reply via email to