This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 62b7bd6095 Coerce LazyXComAccess to list when pushed to XCom (#27251)
62b7bd6095 is described below
commit 62b7bd6095b7cdd61d828a66aa7dd7ac5643b0da
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Mon Nov 7 10:27:27 2022 +0800
Coerce LazyXComAccess to list when pushed to XCom (#27251)
LazyXComAccess is intended to work as a "lazy list" to avoid pulling a ton
of XComs unnecessarily. If it is pushed into XCom, it is now coerced to a
list and a warning is logged, so the user is aware of the performance
implications and the implementation detail does not leak.
---
airflow/models/taskinstance.py | 104 +-----------------
airflow/models/xcom.py | 118 ++++++++++++++++++++-
.../concepts/dynamic-task-mapping.rst | 24 +++--
tests/models/test_taskinstance.py | 2 +-
4 files changed, 137 insertions(+), 111 deletions(-)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index e8b19f03dc..e024b4db18 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -30,20 +30,9 @@ from collections import defaultdict
from datetime import datetime, timedelta
from functools import partial
from types import TracebackType
-from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- Collection,
- ContextManager,
- Generator,
- Iterable,
- NamedTuple,
- Tuple,
-)
+from typing import TYPE_CHECKING, Any, Callable, Collection, Generator,
Iterable, NamedTuple, Tuple
from urllib.parse import quote
-import attr
import dill
import jinja2
import lazy_object_proxy
@@ -69,8 +58,6 @@ from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import reconstructor, relationship
from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
-from sqlalchemy.orm.exc import NoResultFound
-from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.elements import BooleanClauseList
from sqlalchemy.sql.expression import ColumnOperators
@@ -101,7 +88,7 @@ from airflow.models.param import process_params
from airflow.models.taskfail import TaskFail
from airflow.models.taskmap import TaskMap
from airflow.models.taskreschedule import TaskReschedule
-from airflow.models.xcom import XCOM_RETURN_KEY, XCom
+from airflow.models.xcom import XCOM_RETURN_KEY, LazyXComAccess, XCom
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sentry import Sentry
from airflow.stats import Stats
@@ -291,91 +278,6 @@ def clear_task_instances(
session.flush()
-class _LazyXComAccessIterator(collections.abc.Iterator):
- __slots__ = ['_cm', '_it']
-
- def __init__(self, cm: ContextManager[Query]):
- self._cm = cm
- self._it = None
-
- def __del__(self):
- if self._it:
- self._cm.__exit__(None, None, None)
-
- def __iter__(self):
- return self
-
- def __next__(self):
- if not self._it:
- self._it = iter(self._cm.__enter__())
- return XCom.deserialize_value(next(self._it))
-
-
-@attr.define
-class _LazyXComAccess(collections.abc.Sequence):
- """Wrapper to lazily pull XCom with a sequence-like interface.
-
- Note that since the session bound to the parent query may have died when we
- actually access the sequence's content, we must create a new session
- for every function call with ``with_session()``.
- """
-
- dag_id: str
- run_id: str
- task_id: str
- _query: Query = attr.ib(repr=False)
- _len: int | None = attr.ib(init=False, repr=False, default=None)
-
- @classmethod
- def build_from_single_xcom(cls, first: XCom, query: Query) ->
_LazyXComAccess:
- return cls(
- dag_id=first.dag_id,
- run_id=first.run_id,
- task_id=first.task_id,
- query=query.with_entities(XCom.value)
- .filter(
- XCom.run_id == first.run_id,
- XCom.task_id == first.task_id,
- XCom.dag_id == first.dag_id,
- XCom.map_index >= 0,
- )
- .order_by(None)
- .order_by(XCom.map_index.asc()),
- )
-
- def __len__(self):
- if self._len is None:
- with self._get_bound_query() as query:
- self._len = query.count()
- return self._len
-
- def __iter__(self):
- return _LazyXComAccessIterator(self._get_bound_query())
-
- def __getitem__(self, key):
- if not isinstance(key, int):
- raise ValueError("only support index access for now")
- try:
- with self._get_bound_query() as query:
- r = query.offset(key).limit(1).one()
- except NoResultFound:
- raise IndexError(key) from None
- return XCom.deserialize_value(r)
-
- @contextlib.contextmanager
- def _get_bound_query(self) -> Generator[Query, None, None]:
- # Do we have a valid session already?
- if self._query.session and self._query.session.is_active:
- yield self._query
- return
-
- session = settings.Session()
- try:
- yield self._query.with_session(session)
- finally:
- session.close()
-
-
class TaskInstanceKey(NamedTuple):
"""Key used to identify task instance."""
@@ -2439,7 +2341,7 @@ class TaskInstance(Base, LoggingMixin):
if map_indexes is not None or first.map_index < 0:
return XCom.deserialize_value(first)
- return _LazyXComAccess.build_from_single_xcom(first, query)
+ return LazyXComAccess.build_from_single_xcom(first, query)
# At this point either task_ids or map_indexes is explicitly
multi-value.
diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index cd97aacc7a..6386f2d8af 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -17,6 +17,8 @@
# under the License.
from __future__ import annotations
+import collections.abc
+import contextlib
import datetime
import inspect
import json
@@ -24,8 +26,9 @@ import logging
import pickle
import warnings
from functools import wraps
-from typing import TYPE_CHECKING, Any, Iterable, cast, overload
+from typing import TYPE_CHECKING, Any, Generator, Iterable, cast, overload
+import attr
import pendulum
from sqlalchemy import (
Column,
@@ -41,6 +44,8 @@ from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Query, Session, reconstructor, relationship
from sqlalchemy.orm.exc import NoResultFound
+from airflow import settings
+from airflow.compat.functools import cached_property
from airflow.configuration import conf
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
@@ -203,6 +208,27 @@ class BaseXCom(Base, LoggingMixin):
if dag_run_id is None:
raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID
{run_id!r}")
+ # Seamlessly resolve LazyXComAccess to a list. This is intended to work
+ # as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but
if
+ # it's pushed into XCom, the user should be aware of the performance
+ # implications, and this avoids leaking the implementation detail.
+ if isinstance(value, LazyXComAccess):
+ warning_message = (
+ "Coercing mapped lazy proxy %s from task %s (DAG %s, run %s) "
+ "to list, which may degrade performance. Review resource "
+ "requirements for this operation, and call list() to suppress "
+ "this message. See Dynamic Task Mapping documentation for "
+ "more information about lazy proxy objects."
+ )
+ log.warning(
+ warning_message,
+ "return value" if key == XCOM_RETURN_KEY else f"value {key}",
+ task_id,
+ dag_id,
+ run_id or execution_date,
+ )
+ value = list(value)
+
value = cls.serialize_value(
value=value,
key=key,
@@ -589,8 +615,8 @@ class BaseXCom(Base, LoggingMixin):
dag_id: str | None = None,
run_id: str | None = None,
map_index: int | None = None,
- ):
- """Serialize XCom value to str or pickled object"""
+ ) -> Any:
+ """Serialize XCom value to str or pickled object."""
if conf.getboolean('core', 'enable_xcom_pickling'):
return pickle.dumps(value)
try:
@@ -632,6 +658,92 @@ class BaseXCom(Base, LoggingMixin):
return BaseXCom.deserialize_value(self)
+class _LazyXComAccessIterator(collections.abc.Iterator):
+ def __init__(self, cm: contextlib.AbstractContextManager[Query]) -> None:
+ self._cm = cm
+ self._entered = False
+
+ def __del__(self) -> None:
+ if self._entered:
+ self._cm.__exit__(None, None, None)
+
+ def __iter__(self) -> collections.abc.Iterator:
+ return self
+
+ def __next__(self) -> Any:
+ return XCom.deserialize_value(next(self._it))
+
+ @cached_property
+ def _it(self) -> collections.abc.Iterator:
+ self._entered = True
+ return iter(self._cm.__enter__())
+
+
+@attr.define(slots=True)
+class LazyXComAccess(collections.abc.Sequence):
+ """Wrapper to lazily pull XCom with a sequence-like interface.
+
+ Note that since the session bound to the parent query may have died when we
+ actually access the sequence's content, we must create a new session
+ for every function call with ``with_session()``.
+ """
+
+ dag_id: str
+ run_id: str
+ task_id: str
+ _query: Query = attr.ib(repr=False)
+ _len: int | None = attr.ib(init=False, repr=False, default=None)
+
+ @classmethod
+ def build_from_single_xcom(cls, first: XCom, query: Query) ->
LazyXComAccess:
+ return cls(
+ dag_id=first.dag_id,
+ run_id=first.run_id,
+ task_id=first.task_id,
+ query=query.with_entities(XCom.value)
+ .filter(
+ XCom.run_id == first.run_id,
+ XCom.task_id == first.task_id,
+ XCom.dag_id == first.dag_id,
+ XCom.map_index >= 0,
+ )
+ .order_by(None)
+ .order_by(XCom.map_index.asc()),
+ )
+
+ def __len__(self):
+ if self._len is None:
+ with self._get_bound_query() as query:
+ self._len = query.count()
+ return self._len
+
+ def __iter__(self):
+ return _LazyXComAccessIterator(self._get_bound_query())
+
+ def __getitem__(self, key):
+ if not isinstance(key, int):
+ raise ValueError("only support index access for now")
+ try:
+ with self._get_bound_query() as query:
+ r = query.offset(key).limit(1).one()
+ except NoResultFound:
+ raise IndexError(key) from None
+ return XCom.deserialize_value(r)
+
+ @contextlib.contextmanager
+ def _get_bound_query(self) -> Generator[Query, None, None]:
+ # Do we have a valid session already?
+ if self._query.session and self._query.session.is_active:
+ yield self._query
+ return
+
+ session = settings.Session()
+ try:
+ yield self._query.with_session(session)
+ finally:
+ session.close()
+
+
def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str])
-> None:
"""Patch a custom ``serialize_value`` to accept the modern signature.
diff --git a/docs/apache-airflow/concepts/dynamic-task-mapping.rst
b/docs/apache-airflow/concepts/dynamic-task-mapping.rst
index 10ff5fff01..5ae0e9fb82 100644
--- a/docs/apache-airflow/concepts/dynamic-task-mapping.rst
+++ b/docs/apache-airflow/concepts/dynamic-task-mapping.rst
@@ -68,20 +68,32 @@ The grid view also provides visibility into your mapped
tasks in the details pan
In the above example, ``values`` received by ``sum_it`` is an aggregation
of all values returned by each mapped instance of ``add_one``. However, since
it is impossible to know how many instances of ``add_one`` we will have in
advance, ``values`` is not a normal list, but a "lazy sequence" that retrieves
each individual value only when asked. Therefore, if you run ``print(values)``
directly, you would get something like this::
- _LazyXComAccess(dag_id='simple_mapping', run_id='test_run',
task_id='add_one')
+ LazyXComAccess(dag_id='simple_mapping', run_id='test_run',
task_id='add_one')
- You can use normal sequence syntax on this object (e.g. ``values[0]``), or
iterate through it normally with a ``for`` loop. ``list(values)`` will give you
a "real" ``list``, but please be aware of the potential performance
implications if the list is large.
+ You can use normal sequence syntax on this object (e.g. ``values[0]``), or
iterate through it normally with a ``for`` loop. ``list(values)`` will give you
a "real" ``list``, but since this would eagerly load values from *all* of the
referenced upstream mapped tasks, you must be aware of the potential
performance implications if the number of mapped task instances is large.
- Note that the same also applies to when you push this proxy object into
XCom. This, for example, would not
- work with the default XCom backend:
+ Note that the same also applies when you push this proxy object into
XCom. Airflow coerces the value to a list automatically, but emits a
warning so you are aware of the performance implications. For example:
.. code-block:: python
@task
def forward_values(values):
- return values # This is a lazy proxy and can't be pushed!
+ return values # This is a lazy proxy!
- You need to explicitly call ``list(values)`` instead, and accept the
performance implications.
+ will emit a warning like this:
+
+ .. code-block:: text
+
+ Coercing mapped lazy proxy return value from task forward_values (DAG
<dag_id>, run <run_id>) to list, which may degrade
+ performance. Review resource requirements for this operation, and call
list() to suppress this message. See Dynamic Task Mapping
documentation for more information about lazy proxy objects.
+
+ The message can be suppressed by modifying the task like this:
+
+ .. code-block:: python
+
+ @task
+ def forward_values(values):
+ return list(values)
.. note:: A reduce task is not required.
diff --git a/tests/models/test_taskinstance.py
b/tests/models/test_taskinstance.py
index d21203b2d7..c167e4d5b5 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -3485,7 +3485,7 @@ def
test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_v
joined = ti_2.xcom_pull("task_1", session=session)
assert mock_deserialize_value.call_count == 0
- assert repr(joined) == "_LazyXComAccess(dag_id='test_xcom', run_id='test',
task_id='task_1')"
+ assert repr(joined) == "LazyXComAccess(dag_id='test_xcom', run_id='test',
task_id='task_1')"
# Only when we go through the iterable does deserialization happen.
it = iter(joined)