This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b9620bf4b77 Fix reading huge (XCom) response in TaskSDK task process 
(#53186)
b9620bf4b77 is described below

commit b9620bf4b77ec45b0e776aa56cab2b6ed7744dbe
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Fri Jul 11 16:00:48 2025 +0100

    Fix reading huge (XCom) response in TaskSDK task process (#53186)
    
    If you tried to send a large XCom value, it would fail in the task/child
    process side with this error:
    
    > RuntimeError: unable to read full response in child. (We read 36476, but 
expected 1310046)
    
    (The exact number of bytes it was able to read depended on many different
    factors, like the OS, the current state of the socket and other things.
    Sometimes it would read up to 256kb fine, other times only 35kb as here.)
    
    This is because the kernel-level read-side socket buffer was full, so that
    was as much as the Supervisor could send. The fix is to read in a loop
    until we get it all.
---
 task-sdk/src/airflow/sdk/execution_time/comms.py   | 19 +++++++-------
 .../tests/task_sdk/execution_time/test_comms.py    | 30 ++++++++++++++++++++++
 2 files changed, 40 insertions(+), 9 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py 
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index fb9bdbaf6ba..74a53991ff2 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -234,15 +234,16 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
         length = int.from_bytes(len_bytes, byteorder="big")
 
         buffer = bytearray(length)
-        nread = self.socket.recv_into(buffer)
-        if nread != length:
-            raise RuntimeError(
-                f"unable to read full response in child. (We read {nread}, but 
expected {length})"
-            )
-        if nread == 0:
-            raise EOFError(f"Request socket closed before response was 
complete ({self.id_counter=})")
-
-        resp = self.resp_decoder.decode(buffer)
+        mv = memoryview(buffer)
+
+        pos = 0
+        while pos < length:
+            nread = self.socket.recv_into(mv[pos:])
+            if nread == 0:
+                raise EOFError(f"Request socket closed before response was 
complete ({self.id_counter=})")
+            pos += nread
+
+        resp = self.resp_decoder.decode(mv)
         if maxfds:
             return resp, fds or []
         return resp
diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py 
b/task-sdk/tests/task_sdk/execution_time/test_comms.py
index b2e7f5e71c0..5595fc2775f 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_comms.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py
@@ -17,6 +17,7 @@
 
 from __future__ import annotations
 
+import threading
 import uuid
 from socket import socketpair
 
@@ -82,3 +83,32 @@ class TestCommsDecoder:
         assert msg.dag_rel_path == "/dev/null"
         assert msg.bundle_info == BundleInfo(name="any-name", 
version="any-version")
         assert msg.start_date == timezone.datetime(2024, 12, 1, 1)
+
+    def test_huge_payload(self):
+        r, w = socketpair()
+
+        msg = {
+            "type": "XComResult",
+            "key": "a",
+            "value": ("a" * 10 * 1024 * 1024) + "b",  # A 10mb xcom value
+        }
+
+        w.settimeout(1.0)
+        bytes = msgspec.msgpack.encode(_ResponseFrame(0, msg, None))
+
+        # Since `sendall` blocks, we need to do the send in another thread, so 
we can perform the read here
+        t = threading.Thread(target=w.sendall, args=(len(bytes).to_bytes(4, 
byteorder="big") + bytes,))
+        t.start()
+
+        decoder = CommsDecoder(socket=r, log=None)
+
+        try:
+            msg = decoder._get_response()
+        finally:
+            t.join(2)
+
+        assert msg is not None
+
+        # It actually failed to read at all for large values, but lets just 
make sure we get it all
+        assert len(msg.value) == 10 * 1024 * 1024 + 1
+        assert msg.value[-1] == "b"

Reply via email to