This is an automated email from the ASF dual-hosted git repository.

dabla pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3ac0d74db71 Fix read out-of-order issue with send method in CommsDecoder (#64894)
3ac0d74db71 is described below

commit 3ac0d74db71ad358794cfbfd20665d0bcf1f0132
Author: David Blain <[email protected]>
AuthorDate: Thu Apr 9 09:13:22 2026 +0200

    Fix read out-of-order issue with send method in CommsDecoder (#64894)
    
    * refactor: Fix read out-of-order issue with send method in CommsDecoder
---
 task-sdk/src/airflow/sdk/execution_time/comms.py   |  2 +-
 .../tests/task_sdk/execution_time/test_comms.py    | 49 +++++++++++++++++++++-
 2 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index a4755424977..f4d83baef6d 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -216,7 +216,7 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
                 # always be in the return type union
                 return resp  # type: ignore[return-value]
 
-        return self._get_response()
+            return self._get_response()
 
     async def asend(self, msg: SendMsgType) -> ReceiveMsgType | None:
         """
diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py
index 861b1d51c8a..5c6d8843925 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_comms.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py
@@ -23,10 +23,17 @@ from socket import socketpair
 
 import msgspec
 import pytest
+import structlog
 
 from airflow.sdk import timezone
-from airflow.sdk.execution_time.comms import BundleInfo, MaskSecret, 
StartupDetails, _ResponseFrame
-from airflow.sdk.execution_time.task_runner import CommsDecoder
+from airflow.sdk.execution_time.comms import (
+    BundleInfo,
+    CommsDecoder,
+    MaskSecret,
+    StartupDetails,
+    VariableResult,
+    _ResponseFrame,
+)
 
 
 class TestCommsModels:
@@ -148,3 +155,41 @@ class TestCommsDecoder:
         # It actually failed to read at all for large values, but lets just 
make sure we get it all
         assert len(msg.value) == 10 * 1024 * 1024 + 1
         assert msg.value[-1] == "b"
+
+    def test_send_thread_safety(self):
+        r, w = socketpair()
+        decoder = CommsDecoder(socket=r, log=structlog.get_logger())
+        num_threads = 5
+        results = [None] * num_threads
+        errors = [None] * num_threads
+        request_sent = [threading.Event() for _ in range(num_threads)]
+
+        def send_and_store(idx):
+            request_sent[idx].set()  # Signal that this thread is about to send
+            try:
+                msg = VariableResult(key=f"key{idx}", value=f"value{idx}", 
type="VariableResult")
+                results[idx] = decoder.send(msg)
+            except Exception as e:
+                errors[idx] = e
+
+        threads = [threading.Thread(target=send_and_store, args=(i,)) for i in 
range(num_threads)]
+        for t in threads:
+            t.start()
+
+        # For each thread, wait until it signals it's ready, then send the 
response
+        for idx in range(num_threads):
+            request_sent[idx].wait()
+            resp = {"type": "VariableResult", "key": f"key{idx}", "value": 
f"value{idx}"}
+            frame = _ResponseFrame(idx, resp, None)
+            data = msgspec.msgpack.encode(frame)
+            w.sendall(len(data).to_bytes(4, byteorder="big") + data)
+
+        for t in threads:
+            t.join(timeout=5)
+        for idx, t in enumerate(threads):
+            assert not t.is_alive(), f"Thread {idx} did not finish (possible 
deadlock or hang in send method)"
+
+        for idx in range(num_threads):
+            assert errors[idx] is None, f"Thread {idx} error: {errors[idx]}"
+            assert results[idx].key == f"key{idx}", f"Out-of-order or missing 
response for thread {idx}"
+            assert results[idx].value == f"value{idx}", f"Incorrect value for 
thread {idx}"

Reply via email to