This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 62b7bd6095 Coerce LazyXComAccess to list when pushed to XCom (#27251)
62b7bd6095 is described below

commit 62b7bd6095b7cdd61d828a66aa7dd7ac5643b0da
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Mon Nov 7 10:27:27 2022 +0800

    Coerce LazyXComAccess to list when pushed to XCom (#27251)
    
    The class is intended to work as a "lazy list" to avoid pulling a ton
    of XComs unnecessarily, but if it's pushed into XCom, the user should
    be aware of the performance implications, and this avoids leaking the
    implementation detail.
---
 airflow/models/taskinstance.py                     | 104 +-----------------
 airflow/models/xcom.py                             | 118 ++++++++++++++++++++-
 .../concepts/dynamic-task-mapping.rst              |  24 +++--
 tests/models/test_taskinstance.py                  |   2 +-
 4 files changed, 137 insertions(+), 111 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index e8b19f03dc..e024b4db18 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -30,20 +30,9 @@ from collections import defaultdict
 from datetime import datetime, timedelta
 from functools import partial
 from types import TracebackType
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Collection,
-    ContextManager,
-    Generator,
-    Iterable,
-    NamedTuple,
-    Tuple,
-)
+from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, NamedTuple, Tuple
 from urllib.parse import quote
 
-import attr
 import dill
 import jinja2
 import lazy_object_proxy
@@ -69,8 +58,6 @@ from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.ext.mutable import MutableDict
 from sqlalchemy.orm import reconstructor, relationship
 from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
-from sqlalchemy.orm.exc import NoResultFound
-from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.elements import BooleanClauseList
 from sqlalchemy.sql.expression import ColumnOperators
@@ -101,7 +88,7 @@ from airflow.models.param import process_params
 from airflow.models.taskfail import TaskFail
 from airflow.models.taskmap import TaskMap
 from airflow.models.taskreschedule import TaskReschedule
-from airflow.models.xcom import XCOM_RETURN_KEY, XCom
+from airflow.models.xcom import XCOM_RETURN_KEY, LazyXComAccess, XCom
 from airflow.plugins_manager import integrate_macros_plugins
 from airflow.sentry import Sentry
 from airflow.stats import Stats
@@ -291,91 +278,6 @@ def clear_task_instances(
     session.flush()
 
 
-class _LazyXComAccessIterator(collections.abc.Iterator):
-    __slots__ = ['_cm', '_it']
-
-    def __init__(self, cm: ContextManager[Query]):
-        self._cm = cm
-        self._it = None
-
-    def __del__(self):
-        if self._it:
-            self._cm.__exit__(None, None, None)
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        if not self._it:
-            self._it = iter(self._cm.__enter__())
-        return XCom.deserialize_value(next(self._it))
-
-
-@attr.define
-class _LazyXComAccess(collections.abc.Sequence):
-    """Wrapper to lazily pull XCom with a sequence-like interface.
-
-    Note that since the session bound to the parent query may have died when we
-    actually access the sequence's content, we must create a new session
-    for every function call with ``with_session()``.
-    """
-
-    dag_id: str
-    run_id: str
-    task_id: str
-    _query: Query = attr.ib(repr=False)
-    _len: int | None = attr.ib(init=False, repr=False, default=None)
-
-    @classmethod
-    def build_from_single_xcom(cls, first: XCom, query: Query) -> _LazyXComAccess:
-        return cls(
-            dag_id=first.dag_id,
-            run_id=first.run_id,
-            task_id=first.task_id,
-            query=query.with_entities(XCom.value)
-            .filter(
-                XCom.run_id == first.run_id,
-                XCom.task_id == first.task_id,
-                XCom.dag_id == first.dag_id,
-                XCom.map_index >= 0,
-            )
-            .order_by(None)
-            .order_by(XCom.map_index.asc()),
-        )
-
-    def __len__(self):
-        if self._len is None:
-            with self._get_bound_query() as query:
-                self._len = query.count()
-        return self._len
-
-    def __iter__(self):
-        return _LazyXComAccessIterator(self._get_bound_query())
-
-    def __getitem__(self, key):
-        if not isinstance(key, int):
-            raise ValueError("only support index access for now")
-        try:
-            with self._get_bound_query() as query:
-                r = query.offset(key).limit(1).one()
-        except NoResultFound:
-            raise IndexError(key) from None
-        return XCom.deserialize_value(r)
-
-    @contextlib.contextmanager
-    def _get_bound_query(self) -> Generator[Query, None, None]:
-        # Do we have a valid session already?
-        if self._query.session and self._query.session.is_active:
-            yield self._query
-            return
-
-        session = settings.Session()
-        try:
-            yield self._query.with_session(session)
-        finally:
-            session.close()
-
-
 class TaskInstanceKey(NamedTuple):
     """Key used to identify task instance."""
 
@@ -2439,7 +2341,7 @@ class TaskInstance(Base, LoggingMixin):
             if map_indexes is not None or first.map_index < 0:
                 return XCom.deserialize_value(first)
 
-            return _LazyXComAccess.build_from_single_xcom(first, query)
+            return LazyXComAccess.build_from_single_xcom(first, query)
 
         # At this point either task_ids or map_indexes is explicitly multi-value.
 
diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index cd97aacc7a..6386f2d8af 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -17,6 +17,8 @@
 # under the License.
 from __future__ import annotations
 
+import collections.abc
+import contextlib
 import datetime
 import inspect
 import json
@@ -24,8 +26,9 @@ import logging
 import pickle
 import warnings
 from functools import wraps
-from typing import TYPE_CHECKING, Any, Iterable, cast, overload
+from typing import TYPE_CHECKING, Any, Generator, Iterable, cast, overload
 
+import attr
 import pendulum
 from sqlalchemy import (
     Column,
@@ -41,6 +44,8 @@ from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import Query, Session, reconstructor, relationship
 from sqlalchemy.orm.exc import NoResultFound
 
+from airflow import settings
+from airflow.compat.functools import cached_property
 from airflow.configuration import conf
 from airflow.exceptions import RemovedInAirflow3Warning
 from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
@@ -203,6 +208,27 @@ class BaseXCom(Base, LoggingMixin):
             if dag_run_id is None:
                 raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")
 
+        # Seamlessly resolve LazyXComAccess to a list. This is intended to work
+        # as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if
+        # it's pushed into XCom, the user should be aware of the performance
+        # implications, and this avoids leaking the implementation detail.
+        if isinstance(value, LazyXComAccess):
+            warning_message = (
+                "Coercing mapped lazy proxy %s from task %s (DAG %s, run %s) "
+                "to list, which may degrade performance. Review resource "
+                "requirements for this operation, and call list() to suppress "
+                "this message. See Dynamic Task Mapping documentation for "
+                "more information about lazy proxy objects."
+            )
+            log.warning(
+                warning_message,
+                "return value" if key == XCOM_RETURN_KEY else f"value {key}",
+                task_id,
+                dag_id,
+                run_id or execution_date,
+            )
+            value = list(value)
+
         value = cls.serialize_value(
             value=value,
             key=key,
@@ -589,8 +615,8 @@ class BaseXCom(Base, LoggingMixin):
         dag_id: str | None = None,
         run_id: str | None = None,
         map_index: int | None = None,
-    ):
-        """Serialize XCom value to str or pickled object"""
+    ) -> Any:
+        """Serialize XCom value to str or pickled object."""
         if conf.getboolean('core', 'enable_xcom_pickling'):
             return pickle.dumps(value)
         try:
@@ -632,6 +658,92 @@ class BaseXCom(Base, LoggingMixin):
         return BaseXCom.deserialize_value(self)
 
 
+class _LazyXComAccessIterator(collections.abc.Iterator):
+    def __init__(self, cm: contextlib.AbstractContextManager[Query]) -> None:
+        self._cm = cm
+        self._entered = False
+
+    def __del__(self) -> None:
+        if self._entered:
+            self._cm.__exit__(None, None, None)
+
+    def __iter__(self) -> collections.abc.Iterator:
+        return self
+
+    def __next__(self) -> Any:
+        return XCom.deserialize_value(next(self._it))
+
+    @cached_property
+    def _it(self) -> collections.abc.Iterator:
+        self._entered = True
+        return iter(self._cm.__enter__())
+
+
+@attr.define(slots=True)
+class LazyXComAccess(collections.abc.Sequence):
+    """Wrapper to lazily pull XCom with a sequence-like interface.
+
+    Note that since the session bound to the parent query may have died when we
+    actually access the sequence's content, we must create a new session
+    for every function call with ``with_session()``.
+    """
+
+    dag_id: str
+    run_id: str
+    task_id: str
+    _query: Query = attr.ib(repr=False)
+    _len: int | None = attr.ib(init=False, repr=False, default=None)
+
+    @classmethod
+    def build_from_single_xcom(cls, first: XCom, query: Query) -> LazyXComAccess:
+        return cls(
+            dag_id=first.dag_id,
+            run_id=first.run_id,
+            task_id=first.task_id,
+            query=query.with_entities(XCom.value)
+            .filter(
+                XCom.run_id == first.run_id,
+                XCom.task_id == first.task_id,
+                XCom.dag_id == first.dag_id,
+                XCom.map_index >= 0,
+            )
+            .order_by(None)
+            .order_by(XCom.map_index.asc()),
+        )
+
+    def __len__(self):
+        if self._len is None:
+            with self._get_bound_query() as query:
+                self._len = query.count()
+        return self._len
+
+    def __iter__(self):
+        return _LazyXComAccessIterator(self._get_bound_query())
+
+    def __getitem__(self, key):
+        if not isinstance(key, int):
+            raise ValueError("only support index access for now")
+        try:
+            with self._get_bound_query() as query:
+                r = query.offset(key).limit(1).one()
+        except NoResultFound:
+            raise IndexError(key) from None
+        return XCom.deserialize_value(r)
+
+    @contextlib.contextmanager
+    def _get_bound_query(self) -> Generator[Query, None, None]:
+        # Do we have a valid session already?
+        if self._query.session and self._query.session.is_active:
+            yield self._query
+            return
+
+        session = settings.Session()
+        try:
+            yield self._query.with_session(session)
+        finally:
+            session.close()
+
+
 def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None:
     """Patch a custom ``serialize_value`` to accept the modern signature.
 
diff --git a/docs/apache-airflow/concepts/dynamic-task-mapping.rst b/docs/apache-airflow/concepts/dynamic-task-mapping.rst
index 10ff5fff01..5ae0e9fb82 100644
--- a/docs/apache-airflow/concepts/dynamic-task-mapping.rst
+++ b/docs/apache-airflow/concepts/dynamic-task-mapping.rst
@@ -68,20 +68,32 @@ The grid view also provides visibility into your mapped tasks in the details pan
 
     In the above example, ``values`` received by ``sum_it`` is an aggregation of all values returned by each mapped instance of ``add_one``. However, since it is impossible to know how many instances of ``add_one`` we will have in advance, ``values`` is not a normal list, but a "lazy sequence" that retrieves each individual value only when asked. Therefore, if you run ``print(values)`` directly, you would get something like this::
 
-        _LazyXComAccess(dag_id='simple_mapping', run_id='test_run', task_id='add_one')
+        LazyXComAccess(dag_id='simple_mapping', run_id='test_run', task_id='add_one')
 
-    You can use normal sequence syntax on this object (e.g. ``values[0]``), or iterate through it normally with a ``for`` loop. ``list(values)`` will give you a "real" ``list``, but please be aware of the potential performance implications if the list is large.
+    You can use normal sequence syntax on this object (e.g. ``values[0]``), or iterate through it normally with a ``for`` loop. ``list(values)`` will give you a "real" ``list``, but since this would eagerly load values from *all* of the referenced upstream mapped tasks, you must be aware of the potential performance implications if the number of mapped task instances is large.
 
-    Note that the same also applies to when you push this proxy object into XCom. This, for example, would not
-    work with the default XCom backend:
+    Note that the same also applies when you push this proxy object into XCom. Airflow coerces the value to a real ``list`` automatically, but emits a warning so that you are aware of the potential performance implications. For example:
 
     .. code-block:: python
 
         @task
         def forward_values(values):
-            return values  # This is a lazy proxy and can't be pushed!
+            return values  # This is a lazy proxy!
 
-    You need to explicitly call ``list(values)`` instead, and accept the performance implications.
+    will emit a warning like this:
+
+    .. code-block:: text
+
+        Coercing mapped lazy proxy return value from task forward_values (DAG ..., run ...) to list, which may degrade
+        performance. Review resource requirements for this operation, and call list() to suppress this message. See Dynamic Task Mapping documentation for more information about lazy proxy objects.
+
+    The message can be suppressed by modifying the task like this:
+
+    .. code-block:: python
+
+        @task
+        def forward_values(values):
+            return list(values)
 
 .. note:: A reduce task is not required.
 
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index d21203b2d7..c167e4d5b5 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -3485,7 +3485,7 @@ def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_v
     joined = ti_2.xcom_pull("task_1", session=session)
     assert mock_deserialize_value.call_count == 0
 
-    assert repr(joined) == "_LazyXComAccess(dag_id='test_xcom', run_id='test', task_id='task_1')"
+    assert repr(joined) == "LazyXComAccess(dag_id='test_xcom', run_id='test', task_id='task_1')"
 
     # Only when we go through the iterable does deserialization happen.
     it = iter(joined)

Reply via email to