ashb commented on code in PR #68299:
URL: https://github.com/apache/airflow/pull/68299#discussion_r3388006829
##########
task-sdk/src/airflow/sdk/bases/xcom.py:
##########
@@ -324,6 +464,57 @@ def get_all(
return [cls.deserialize_value(_XComValueWrapper(value)) for value in
msg.root]
+ @classmethod
+ async def aget_all(
+ cls,
+ *,
+ key: str,
+ dag_id: str,
+ task_id: str,
+ run_id: str,
+ include_prior_dates: bool = False,
+ ) -> Any:
+ """
+ Retrieve all XCom values for a task asynchronously, typically from all
map indexes.
+
+ XComSequenceSliceResult can never have *None* in it, it returns an
empty list
+ if no values were found.
+
+ This is particularly useful for getting all XCom values from all map
+ indexes of a mapped task at once.
+
+ :param key: A key for the XCom. Only XComs with this key will be
returned.
+ :param run_id: Dag run ID for the task.
+ :param dag_id: Dag ID to pull XComs from.
+ :param task_id: Task ID to pull XComs from.
+ :param include_prior_dates: If *False* (default), only XComs from the
+ specified Dag run are returned. If *True*, the latest matching
XComs are
+ returned regardless of the run they belong to.
+ :return: List of all XCom values if found.
+ """
+ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+ msg = await SUPERVISOR_COMMS.asend(
+ msg=GetXComSequenceSlice(
+ key=key,
+ dag_id=dag_id,
+ task_id=task_id,
+ run_id=run_id,
+ start=None,
+ stop=None,
+ step=None,
+ include_prior_dates=include_prior_dates,
+ ),
+ )
+
+ if not isinstance(msg, XComSequenceSliceResult):
+ raise TypeError(f"Expected XComSequenceSliceResult, received:
{type(msg)} {msg}")
Review Comment:
```suggestion
raise RuntimeError(f"Expected XComSequenceSliceResult, received:
{type(msg)} {msg}")
```
IMO TypeError isn't quite right here, as this is more of an "omg, we got the
wrong response back!" than a type error, which is usually more "you called this
with the wrong type".
##########
task-sdk/tests/task_sdk/execution_time/test_task_runner.py:
##########
@@ -2323,6 +2321,113 @@ def mock_send_side_effect(*args, **kwargs):
),
)
+ @pytest.mark.asyncio
Review Comment:
"Try to leave the code base in a better situation than we found it".
Just because there is an existing pattern that is messy doesn't give us a
reason to blindly copy it.
I agree with Jake here; it's really hard to work out what the test is
actually asserting.
##########
task-sdk/tests/task_sdk/bases/test_xcom.py:
##########
@@ -70,3 +77,190 @@ def
test_delete_includes_map_index_in_delete_xcom_message(self, map_index, mock_
assert sent_message.task_id == "test_task"
assert sent_message.run_id == "test_run"
assert sent_message.map_index == map_index
+
+ @pytest.mark.asyncio
+ async def test_aget_one_returns_value(self, mock_supervisor_comms):
+ """aget_one awaits asend and returns the deserialized value."""
+ mock_supervisor_comms.asend.return_value = XComResult(key="test_key",
value="test_value")
+
+ result = await BaseXCom.aget_one(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ map_index=0,
+ )
+
+ assert result == "test_value"
+ mock_supervisor_comms.asend.assert_called_once_with(
+ GetXCom(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ map_index=0,
+ include_prior_dates=False,
+ )
+ )
+ mock_supervisor_comms.send.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_aget_one_returns_none_when_not_found(self,
mock_supervisor_comms):
+ """aget_one returns None when XCom value is not found."""
+ mock_supervisor_comms.asend.return_value = XComResult(key="test_key",
value=None)
+
+ result = await BaseXCom.aget_one(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ )
+
+ assert result is None
+
+ @pytest.mark.asyncio
+ async def test_aget_one_with_include_prior_dates(self,
mock_supervisor_comms):
+ """aget_one passes include_prior_dates parameter correctly."""
+ mock_supervisor_comms.asend.return_value = XComResult(key="test_key",
value="prior_value")
+
+ result = await BaseXCom.aget_one(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ include_prior_dates=True,
+ )
+
+ assert result == "prior_value"
+ mock_supervisor_comms.asend.assert_called_once_with(
+ GetXCom(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ map_index=None,
+ include_prior_dates=True,
+ )
+ )
+
+ @pytest.mark.asyncio
+ async def test_aget_one_raises_on_invalid_response(self,
mock_supervisor_comms):
+ """aget_one raises TypeError when receiving unexpected response
type."""
+ mock_supervisor_comms.asend.return_value = "invalid_response"
+
+ with pytest.raises(TypeError, match="Expected XComResult"):
+ await BaseXCom.aget_one(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ )
+
+ @pytest.mark.asyncio
+ async def test_aget_all_returns_values(self, mock_supervisor_comms):
+ """aget_all awaits asend and returns deserialized values from all map
indexes."""
+ mock_supervisor_comms.asend.return_value = XComSequenceSliceResult(
+ root=["value1", "value2", "value3"]
+ )
+
+ result = await BaseXCom.aget_all(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ )
+
+ assert result == ["value1", "value2", "value3"]
+ mock_supervisor_comms.asend.assert_called_once_with(
+ msg=GetXComSequenceSlice(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ start=None,
+ stop=None,
+ step=None,
+ include_prior_dates=False,
+ )
+ )
+ mock_supervisor_comms.send.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_aget_all_returns_none_when_empty(self,
mock_supervisor_comms):
+ """aget_all returns None when no XCom values are found."""
+ mock_supervisor_comms.asend.return_value =
XComSequenceSliceResult(root=[])
+
+ result = await BaseXCom.aget_all(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ )
+
+ assert result is None
+
+ @pytest.mark.asyncio
+ async def test_aget_all_with_include_prior_dates(self,
mock_supervisor_comms):
+ """aget_all passes include_prior_dates parameter correctly."""
+ mock_supervisor_comms.asend.return_value =
XComSequenceSliceResult(root=["prior_value"])
+
+ result = await BaseXCom.aget_all(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ include_prior_dates=True,
+ )
+
+ assert result == ["prior_value"]
+ mock_supervisor_comms.asend.assert_called_once_with(
+ msg=GetXComSequenceSlice(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ start=None,
+ stop=None,
+ step=None,
+ include_prior_dates=True,
+ )
+ )
+
+ @pytest.mark.asyncio
+ async def test_aget_all_raises_on_invalid_response(self,
mock_supervisor_comms):
+ """aget_all raises TypeError when receiving unexpected response
type."""
+ mock_supervisor_comms.asend.return_value = "invalid_response"
+
+ with pytest.raises(TypeError, match="Expected
XComSequenceSliceResult"):
+ await BaseXCom.aget_all(
+ key="test_key",
+ dag_id="test_dag",
+ task_id="test_task",
+ run_id="test_run",
+ )
+
+ @pytest.mark.asyncio
Review Comment:
Do we need to mark every test fn? I thought this was automatic for async
test functions?
##########
task-sdk/src/airflow/sdk/bases/xcom.py:
##########
@@ -324,6 +464,57 @@ def get_all(
return [cls.deserialize_value(_XComValueWrapper(value)) for value in
msg.root]
+ @classmethod
+ async def aget_all(
+ cls,
+ *,
+ key: str,
+ dag_id: str,
+ task_id: str,
+ run_id: str,
+ include_prior_dates: bool = False,
+ ) -> Any:
+ """
+ Retrieve all XCom values for a task asynchronously, typically from all
map indexes.
+
+ XComSequenceSliceResult can never have *None* in it, it returns an
empty list
+ if no values were found.
+
+ This is particularly useful for getting all XCom values from all map
+ indexes of a mapped task at once.
+
+ :param key: A key for the XCom. Only XComs with this key will be
returned.
+ :param run_id: Dag run ID for the task.
+ :param dag_id: Dag ID to pull XComs from.
+ :param task_id: Task ID to pull XComs from.
+ :param include_prior_dates: If *False* (default), only XComs from the
+ specified Dag run are returned. If *True*, the latest matching
XComs are
+ returned regardless of the run they belong to.
+ :return: List of all XCom values if found.
+ """
+ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+ msg = await SUPERVISOR_COMMS.asend(
+ msg=GetXComSequenceSlice(
+ key=key,
+ dag_id=dag_id,
+ task_id=task_id,
+ run_id=run_id,
+ start=None,
+ stop=None,
+ step=None,
+ include_prior_dates=include_prior_dates,
+ ),
+ )
+
+ if not isinstance(msg, XComSequenceSliceResult):
+ raise TypeError(f"Expected XComSequenceSliceResult, received:
{type(msg)} {msg}")
+
+ if not msg.root:
+ return None
Review Comment:
```suggestion
return []
```
I think?
##########
task-sdk/src/airflow/sdk/execution_time/task_runner.py:
##########
@@ -529,6 +529,123 @@ def xcom_pull(
return xcoms
+ async def axcom_pull(
+ self,
+ task_ids: str | Iterable[str] | None = None,
+ dag_id: str | None = None,
+ key: str = BaseXCom.XCOM_RETURN_KEY,
+ include_prior_dates: bool = False,
+ *,
+ map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET,
+ default: Any = None,
+ run_id: str | None = None,
+ ) -> Any:
+ """
+ Pull XComs either from the API server (BaseXCom) or from the custom
XCOM backend if configured.
+
+ The pull can be filtered optionally by certain criterion.
+
+ :param key: A key for the XCom. If provided, only XComs with matching
+ keys will be returned. The default key is ``'return_value'``, also
+ available as constant ``XCOM_RETURN_KEY``. This key is
automatically
+ given to XComs returned by tasks (as opposed to being pushed
+ manually).
+ :param task_ids: Only XComs from tasks with matching ids will be
+ pulled. If *None* (default), the task_id of the calling task is
used.
+ :param dag_id: If provided, only pulls XComs from this Dag. If *None*
+ (default), the Dag of the calling task is used.
+ :param map_indexes: If provided, only pull XComs with matching indexes.
+ If *None* (default), this is inferred from the task(s) being pulled
+ (see below for details).
+ :param include_prior_dates: If False, only XComs from the current
+ logical_date are returned. If *True*, XComs from previous dates
+ are returned as well.
+ :param run_id: If provided, only pulls XComs from a DagRun w/a
matching run_id.
+ If *None* (default), the run_id of the calling task is used.
+
+ When pulling one single task (``task_id`` is *None* or a str) without
+ specifying ``map_indexes``, the return value is a single XCom entry
+ (map_indexes is set to map_index of the calling task instance).
+
+ When pulling task is mapped the specified ``map_index`` is used, so by
default
+ pulling on mapped task will result in no matching XComs if the task
instance
+ of the method call is not mapped. Otherwise, the map_index of the
calling task
+ instance is used. Setting ``map_indexes`` to *None* will pull XCom as
it would
+ from a non mapped task.
+
+ In either case, ``default`` (*None* if not specified) is returned if no
+ matching XComs are found.
+
+ When pulling multiple tasks (i.e. either ``task_id`` or ``map_index``
is
+ a non-str iterable), a list of matching XComs is returned. Elements in
+ the list is ordered by item ordering in ``task_id`` and ``map_index``.
+ """
+ if dag_id is None:
+ dag_id = self.dag_id
+ if run_id is None:
+ run_id = self.run_id
+
+ single_task_requested = isinstance(task_ids, (str, type(None)))
+ single_map_index_requested = isinstance(map_indexes, (int, type(None)))
+
+ if task_ids is None:
+ # default to the current task if not provided
+ task_ids = [self.task_id]
+ elif isinstance(task_ids, str):
+ task_ids = [task_ids]
+
+ # If map_indexes is not specified, pull xcoms from all map indexes for
each task
+ if not is_arg_set(map_indexes):
+ xcoms: list[Any] = []
+ for t_id in task_ids:
+ values = await XCom.aget_all(
+ run_id=run_id,
+ key=key,
+ task_id=t_id,
+ dag_id=dag_id,
+ include_prior_dates=include_prior_dates,
+ )
+
+ if values is None:
+ xcoms.append(None)
+ else:
+ xcoms.extend(values)
+ # For single task pulling from unmapped task, return single value
+ if single_task_requested and len(xcoms) == 1:
+ return xcoms[0]
+ return xcoms
+
+ # Original logic when map_indexes is explicitly specified
+ map_indexes_iterable: Iterable[int | None] = []
+ if isinstance(map_indexes, int) or map_indexes is None:
+ map_indexes_iterable = [map_indexes]
+ elif isinstance(map_indexes, Iterable):
+ map_indexes_iterable = map_indexes
+ else:
+ raise TypeError(
+ f"Invalid type for map_indexes: expected int, iterable of
ints, or None, got {type(map_indexes)}"
+ )
Review Comment:
It might be overkill, but this could possibly be a shared helper function
between the sync and async versions.
##########
task-sdk/src/airflow/sdk/execution_time/task_runner.py:
##########
@@ -538,6 +655,15 @@ def xcom_push(self, key: str, value: Any):
"""
_xcom_push(self, key, value)
+ async def axcom_push(self, key: str, value: Any):
+ """
+ Make an XCom available for tasks to pull asynchronously.
+
+ :param key: Key to store the value under.
+ :param value: Value to store. Only be JSON-serializable values may be
used.
+ """
+ await _axcom_push(self, key, value)
Review Comment:
Why is this a separate function, rather than just calling `await
XCom.aset()` inline here?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]