This is an automated email from the ASF dual-hosted git repository.
jason810496 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e371845a222 Verify TCP connection ownership before accepting Java
coordinator supervisor channel (#67781)
e371845a222 is described below
commit e371845a2226f9c499dc27212c74f9d10e101833
Author: GPK <[email protected]>
AuthorDate: Sat Jun 6 03:44:15 2026 +0100
Verify TCP connection ownership before accepting Java coordinator
supervisor channel (#67781)
* Validate Java coordinator TCP clients
* resolve conflicts
---
.../src/airflow/sdk/coordinators/_subprocess.py | 35 +++++++
.../tests/task_sdk/coordinators/test_subprocess.py | 101 +++++++++++++++++++++
2 files changed, 136 insertions(+)
diff --git a/task-sdk/src/airflow/sdk/coordinators/_subprocess.py b/task-sdk/src/airflow/sdk/coordinators/_subprocess.py
index da3ba5243a5..c223243eb6f 100644
--- a/task-sdk/src/airflow/sdk/coordinators/_subprocess.py
+++ b/task-sdk/src/airflow/sdk/coordinators/_subprocess.py
@@ -37,6 +37,7 @@ import time
from typing import TYPE_CHECKING, TypeVar, cast
import attrs
+import psutil
import structlog
from airflow.sdk.execution_time.coordinator import BaseCoordinator
@@ -65,6 +66,31 @@ def _start_server() -> socket.socket:
return server
+def _socket_address(value: tuple | str) -> tuple[str, int] | None:
+ if not isinstance(value, tuple) or len(value) < 2:
+ return None
+ host, port = value[:2]
+ return str(host), int(port)
+
+
+def _is_connection_from_process(conn: socket.socket, proc: subprocess.Popen) -> bool:
+ """Return whether the accepted TCP connection belongs to the child process."""
+ peer = _socket_address(conn.getpeername())
+ local = _socket_address(conn.getsockname())
+ if peer is None or local is None:
+ return False
+ try:
+ process = psutil.Process(proc.pid)
+ connections = process.net_connections(kind="tcp")
+ except (psutil.AccessDenied, psutil.NoSuchProcess, psutil.ZombieProcess, OSError):
+ log.warning("Unable to verify child process connection", pid=proc.pid, exc_info=True)
+ return False
+ for connection in connections:
+ if _socket_address(connection.laddr) == peer and _socket_address(connection.raddr) == local:
+ return True
+ return False
+
+
def _accept_connections(
servers: dict[str, socket.socket],
drains: dict[str, socket.socket],
@@ -102,6 +128,15 @@ def _accept_connections(
else:
log.debug("Accepting child process connection", key=event.data)
conn, _ = soc.accept()
+ if not _is_connection_from_process(conn, proc):
+ log.warning(
+ "Rejected connection not owned by child process",
+ key=event.data,
+ pid=proc.pid,
+ peer=conn.getpeername(),
+ )
+ conn.close()
+ continue
sel.unregister(soc)
accepted[soc] = conn
return accepted, drained
diff --git a/task-sdk/tests/task_sdk/coordinators/test_subprocess.py b/task-sdk/tests/task_sdk/coordinators/test_subprocess.py
index e5105e9f223..95a41e6282b 100644
--- a/task-sdk/tests/task_sdk/coordinators/test_subprocess.py
+++ b/task-sdk/tests/task_sdk/coordinators/test_subprocess.py
@@ -18,8 +18,10 @@
from __future__ import annotations
import contextlib
+import os
import socket
import subprocess
+import sys
import threading
import time
from unittest.mock import ANY, MagicMock, call, patch
@@ -32,6 +34,7 @@ from airflow.sdk.api.client import Client, TaskInstanceOperations
from airflow.sdk.coordinators._subprocess import (
SubprocessCoordinator,
_accept_connections,
+ _is_connection_from_process,
_PopenActivitySubprocess,
_ResourceTracker,
_start_server,
@@ -102,6 +105,14 @@ class TestStartServer:
class TestAcceptConnections:
+ @pytest.fixture(autouse=True)
+ def mock_child_connection_check(self):
+ with patch(
+ "airflow.sdk.coordinators._subprocess._is_connection_from_process",
+ return_value=True,
+ ) as mock_check:
+ yield mock_check
+
def _connect_after_delay(self, addr: tuple[str, int], delay: float = 0.0) -> None:
def _connect():
time.sleep(delay)
@@ -256,6 +267,96 @@ class TestAcceptConnections:
finally:
server.close()
+ def test_rejects_connections_not_owned_by_child_process(self, mock_child_connection_check):
+ server = _start_server()
+ _, port = server.getsockname()
+ mock_child_connection_check.side_effect = [False, True]
+ self._connect_after_delay(("127.0.0.1", port))
+ self._connect_after_delay(("127.0.0.1", port), delay=0.05)
+
+ mock_proc = MagicMock(spec=subprocess.Popen)
+ mock_proc.pid = 12345
+ mock_proc.poll.return_value = None
+
+ try:
+ accepted, _ = _accept_connections({"comm": server}, {}, mock_proc)
+ assert mock_child_connection_check.call_count == 2
+ assert server in accepted
+ accepted[server].close()
+ finally:
+ server.close()
+
+
+class TestAcceptConnectionsProcessValidation:
+ def _start_connector_process(self, addr: tuple[str, int], *, delay: float = 0.0) -> subprocess.Popen:
+ script = """
+import socket
+import sys
+import time
+
+time.sleep(float(sys.argv[3]))
+sock = socket.socket()
+sock.connect((sys.argv[1], int(sys.argv[2])))
+sock.recv(1)
+"""
+ return subprocess.Popen([sys.executable, "-c", script, addr[0], str(addr[1]), str(delay)])
+
+ def test_rejects_racing_connection_from_other_process(self):
+ server = _start_server()
+ addr = server.getsockname()
+ attacker = socket.socket()
+ attacker.connect(addr)
+ child_proc = self._start_connector_process(addr, delay=0.05)
+
+ try:
+ accepted, _ = _accept_connections({"comm": server}, {}, child_proc)
+ accepted[server].sendall(b"x")
+ accepted[server].close()
+ assert child_proc.wait(timeout=5) == 0
+ assert attacker.recv(1) == b""
+ finally:
+ attacker.close()
+ server.close()
+ if child_proc.poll() is None:
+ child_proc.terminate()
+ child_proc.wait(timeout=5)
+
+
+class TestConnectionFromProcess:
+ def test_matches_child_process_tcp_connection(self):
+ server = _start_server()
+ _, port = server.getsockname()
+ client = socket.socket()
+ client.connect(("127.0.0.1", port))
+ conn, _ = server.accept()
+ mock_proc = MagicMock(spec=subprocess.Popen)
+ mock_proc.pid = os.getpid()
+
+ try:
+ assert _is_connection_from_process(conn, mock_proc) is True
+ finally:
+ conn.close()
+ client.close()
+ server.close()
+
+ def test_rejects_tcp_connection_not_owned_by_child_process(self):
+ server = _start_server()
+ _, port = server.getsockname()
+ client = socket.socket()
+ client.connect(("127.0.0.1", port))
+ conn, _ = server.accept()
+ mock_proc = MagicMock(spec=subprocess.Popen)
+ mock_proc.pid = os.getpid()
+
+ try:
+ with patch("airflow.sdk.coordinators._subprocess.psutil.Process") as mock_process:
+ mock_process.return_value.net_connections.return_value = []
+ assert _is_connection_from_process(conn, mock_proc) is False
+ finally:
+ conn.close()
+ client.close()
+ server.close()
+
class TestResourceTracker:
"""