This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v3-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v3-0-test by this push:
new 0044ec712ac [v3-0-test] Deserialize response of `get_all` when we call `XCom.get_all` (#53020) (#53102)
0044ec712ac is described below
commit 0044ec712ac38b22f85abbfc0c12fc806e595ea7
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Thu Jul 10 00:56:10 2025 +0530
[v3-0-test] Deserialize response of `get_all` when we call `XCom.get_all` (#53020) (#53102)
(cherry picked from commit bdc9cd115d103614159ea7492db8ff16607c6959)
Co-authored-by: Amogh Desai <[email protected]>
---
task-sdk/src/airflow/sdk/bases/xcom.py | 6 ++-
.../src/airflow/sdk/execution_time/task_runner.py | 12 +++--
.../task_sdk/execution_time/test_task_runner.py | 53 ++++++++++++++++++++++
3 files changed, 65 insertions(+), 6 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py
index 770dbf53df7..82df8d151ab 100644
--- a/task-sdk/src/airflow/sdk/bases/xcom.py
+++ b/task-sdk/src/airflow/sdk/bases/xcom.py
@@ -290,6 +290,7 @@ class BaseXCom:
:return: List of all XCom values if found.
"""
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+ from airflow.serialization.serde import deserialize
msg = SUPERVISOR_COMMS.send(
msg=GetXComSequenceSlice(
@@ -306,7 +307,10 @@ class BaseXCom:
if not isinstance(msg, XComSequenceSliceResult):
raise TypeError(f"Expected XComSequenceSliceResult, received: {type(msg)} {msg}")
- return msg.root
+ result = deserialize(msg.root)
+ if not result:
+ return None
+ return result
@staticmethod
def serialize_value(
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 6c6e597f65e..bdd3b240daf 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -351,17 +351,19 @@ class RuntimeTaskInstance(TaskInstance):
# If map_indexes is not specified, pull xcoms from all map indexes for each task
if isinstance(map_indexes, ArgNotSet):
- xcoms = [
- value
- for t_id in task_ids
- for value in XCom.get_all(
+ xcoms: list[Any] = []
+ for t_id in task_ids:
+ values = XCom.get_all(
run_id=run_id,
key=key,
task_id=t_id,
dag_id=dag_id,
)
- ]
+ if values is None:
+ xcoms.append(None)
+ else:
+ xcoms.extend(values)
# For single task pulling from unmapped task, return single value
if single_task_requested and len(xcoms) == 1:
return xcoms[0]
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index f2ff8967081..362b89e3ad0 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -1355,6 +1355,7 @@ class TestRuntimeTaskInstance:
pytest.param("hello", id="string_value"),
pytest.param("'hello'", id="quoted_string_value"),
pytest.param({"key": "value"}, id="json_value"),
+ pytest.param([], id="empty_list_no_xcoms_found"),
pytest.param((1, 2, 3), id="tuple_int_value"),
pytest.param([1, 2, 3], id="list_int_value"),
pytest.param(42, id="int_value"),
@@ -1377,6 +1378,9 @@ class TestRuntimeTaskInstance:
"""
map_indexes_kwarg = {} if map_indexes is NOTSET else {"map_indexes": map_indexes}
task_ids_kwarg = {} if task_ids is NOTSET else {"task_ids": task_ids}
+ from airflow.serialization.serde import deserialize
+
+ spy_agency.spy_on(deserialize)
class CustomOperator(BaseOperator):
def execute(self, context):
@@ -1402,6 +1406,7 @@ class TestRuntimeTaskInstance:
mock_supervisor_comms.send.side_effect = mock_send_side_effect
run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+ spy_agency.assert_spy_called_with(deserialize, ser_value)
if not isinstance(task_ids, Iterable) or isinstance(task_ids, str):
task_ids = [task_ids]
@@ -1507,6 +1512,54 @@ class TestRuntimeTaskInstance:
assert mock_get_one.called
assert not mock_get_all.called
+ @pytest.mark.parametrize(
+ "api_return_value",
+ [
+ pytest.param(("data", "test_value"), id="api returns tuple"),
+ pytest.param({"data": "test_value"}, id="api returns dict"),
+ pytest.param(None, id="api returns None, no xcom found"),
+ ],
+ )
+ def test_xcom_pull_with_no_map_index(
+ self,
+ api_return_value,
+ create_runtime_ti,
+ mock_supervisor_comms,
+ ):
+ """
+ Test xcom_pull when map_indexes is not specified, so that XCom.get_all is called.
+ The test also tests if the response is deserialized and returned.
+ """
+ test_task_id = "pull_task"
+ task = BaseOperator(task_id=test_task_id)
+ runtime_ti = create_runtime_ti(task=task)
+
+ ser_value = BaseXCom.serialize_value(api_return_value)
+
+ def mock_send_side_effect(*args, **kwargs):
+ msg = kwargs.get("msg") or args[0]
+ if isinstance(msg, GetXComSequenceSlice):
+ return XComSequenceSliceResult(root=[ser_value])
+ return XComResult(key="test_key", value=None)
+
+ mock_supervisor_comms.send.side_effect = mock_send_side_effect
+ result = runtime_ti.xcom_pull(key="test_key", task_ids="task_a")
+
+ # if the API returns a tuple or dict, the below assertion assures that the value is deserialized correctly by XCom.get_all
+ assert result == api_return_value
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ msg=GetXComSequenceSlice(
+ key="test_key",
+ dag_id=runtime_ti.dag_id,
+ run_id=runtime_ti.run_id,
+ task_id="task_a",
+ start=None,
+ stop=None,
+ step=None,
+ ),
+ )
+
def test_get_param_from_context(
self, mocked_parse, make_ti_context, mock_supervisor_comms, create_runtime_ti
):