This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2ead075ee22 Move SchedulerXComArg to serialization (#59777)
2ead075ee22 is described below

commit 2ead075ee2287f30a8f6bc07dce8202f1a939c99
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Dec 24 22:06:28 2025 +0800

    Move SchedulerXComArg to serialization (#59777)
---
 airflow-core/src/airflow/__init__.py               |   2 +-
 airflow-core/src/airflow/models/expandinput.py     |   8 +-
 airflow-core/src/airflow/models/mappedoperator.py  |   2 +-
 airflow-core/src/airflow/models/xcom_arg.py        | 233 +--------------------
 .../airflow/serialization/definitions/taskgroup.py |   2 +-
 .../definitions}/xcom_arg.py                       |   9 -
 .../airflow/serialization/serialized_objects.py    |   2 +-
 .../unit/serialization/test_dag_serialization.py   |  12 +-
 8 files changed, 21 insertions(+), 249 deletions(-)

diff --git a/airflow-core/src/airflow/__init__.py b/airflow-core/src/airflow/__init__.py
index cfaf575da1f..302c275f60b 100644
--- a/airflow-core/src/airflow/__init__.py
+++ b/airflow-core/src/airflow/__init__.py
@@ -84,7 +84,7 @@ if not os.environ.get("_AIRFLOW__AS_LIBRARY", None):
 __lazy_imports: dict[str, tuple[str, str, bool]] = {
     "DAG": (".sdk", "DAG", False),
     "Asset": (".sdk", "Asset", False),
-    "XComArg": (".models.xcom_arg", "XComArg", False),
+    "XComArg": (".sdk", "XComArg", False),
     "version": (".version", "", False),
     # Deprecated lazy imports
     "AirflowException": (".exceptions", "AirflowException", True),
diff --git a/airflow-core/src/airflow/models/expandinput.py b/airflow-core/src/airflow/models/expandinput.py
index 58f45a83ea9..8ebf310a875 100644
--- a/airflow-core/src/airflow/models/expandinput.py
+++ b/airflow-core/src/airflow/models/expandinput.py
@@ -40,8 +40,8 @@ if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
     from airflow.models.mappedoperator import MappedOperator
-    from airflow.models.xcom_arg import SchedulerXComArg
     from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
+    from airflow.serialization.definitions.xcom_arg import SchedulerXComArg
 
     Operator: TypeAlias = MappedOperator | SerializedBaseOperator
 
@@ -58,7 +58,7 @@ __all__ = [
 
 
 def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | SchedulerXComArg]:
-    from airflow.models.xcom_arg import SchedulerXComArg
+    from airflow.serialization.definitions.xcom_arg import SchedulerXComArg
 
     return isinstance(v, (MappedArgument, SchedulerXComArg))
 
@@ -89,7 +89,7 @@ class SchedulerDictOfListsExpandInput:
         If any arguments are not known right now (upstream task not finished),
         they will not be present in the dict.
         """
-        from airflow.models.xcom_arg import SchedulerXComArg, get_task_map_length
+        from airflow.serialization.definitions.xcom_arg import SchedulerXComArg, get_task_map_length
 
         # TODO: This initiates one database call for each XComArg. Would it be
         # more efficient to do one single db call and unpack the value here?
@@ -136,7 +136,7 @@ class SchedulerListOfDictsExpandInput:
         raise NotFullyPopulated({"expand_kwargs() argument"})
 
     def get_total_map_length(self, run_id: str, *, session: Session) -> int:
-        from airflow.models.xcom_arg import get_task_map_length
+        from airflow.serialization.definitions.xcom_arg import get_task_map_length
 
         if isinstance(self.value, Sized):
             return len(self.value)
diff --git a/airflow-core/src/airflow/models/mappedoperator.py b/airflow-core/src/airflow/models/mappedoperator.py
index 473ca098613..0441643ad65 100644
--- a/airflow-core/src/airflow/models/mappedoperator.py
+++ b/airflow-core/src/airflow/models/mappedoperator.py
@@ -478,7 +478,7 @@ class MappedOperator(DAGNode):
 
     def iter_mapped_dependencies(self) -> Iterator[Operator]:
         """Upstream dependencies that provide XComs used by this task for task mapping."""
-        from airflow.models.xcom_arg import SchedulerXComArg
+        from airflow.serialization.definitions.xcom_arg import SchedulerXComArg
 
         for op, _ in SchedulerXComArg.iter_xcom_references(self._get_specified_expand_input()):
             yield op
diff --git a/airflow-core/src/airflow/models/xcom_arg.py b/airflow-core/src/airflow/models/xcom_arg.py
index 5abd2aec8e1..5ddb9565872 100644
--- a/airflow-core/src/airflow/models/xcom_arg.py
+++ b/airflow-core/src/airflow/models/xcom_arg.py
@@ -17,229 +17,16 @@
 
 from __future__ import annotations
 
-from collections.abc import Iterator, Sequence
-from functools import singledispatch
-from typing import TYPE_CHECKING, Any, TypeAlias
+import warnings
 
-import attrs
-from sqlalchemy import func, or_, select
-from sqlalchemy.orm import Session
+from airflow.sdk import XComArg
+from airflow.utils.deprecation_tools import DeprecatedImportWarning
 
-from airflow.models.referencemixin import ReferenceMixin
-from airflow.models.xcom import XCOM_RETURN_KEY
-from airflow.serialization.definitions.notset import NOTSET, is_arg_set
-from airflow.utils.db import exists_query
-from airflow.utils.state import State
+__all__ = ["XComArg"]
 
-__all__ = ["SchedulerXComArg", "deserialize_xcom_arg", "get_task_map_length"]
-
-if TYPE_CHECKING:
-    from airflow.models.mappedoperator import MappedOperator
-    from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
-    from airflow.serialization.definitions.dag import SerializedDAG
-    from airflow.typing_compat import Self
-
-    Operator: TypeAlias = MappedOperator | SerializedBaseOperator
-
-
-class SchedulerXComArg:
-    """
-    Reference to an XCom value pushed from another operator.
-
-    This is the safe counterpart to :class:`airflow.sdk.XComArg`.
-    """
-
-    @classmethod
-    def _deserialize(cls, data: dict[str, Any], dag: SerializedDAG) -> Self:
-        """
-        Deserialize an XComArg.
-
-        The implementation should be the inverse function to ``serialize``,
-        implementing given a data dict converted from this XComArg derivative,
-        how the original XComArg should be created. DAG serialization relies on
-        additional information added in ``serialize_xcom_arg`` to dispatch data
-        dicts to the correct ``_deserialize`` information, so this function does
-        not need to validate whether the incoming data contains correct keys.
-        """
-        raise NotImplementedError("This class should not be instantiated directly")
-
-    @classmethod
-    def iter_xcom_references(cls, arg: Any) -> Iterator[tuple[Operator, str]]:
-        """
-        Return XCom references in an arbitrary value.
-
-        Recursively traverse ``arg`` and look for XComArg instances in any
-        collection objects, and instances with ``template_fields`` set.
-        """
-        from airflow.models.mappedoperator import MappedOperator
-        from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
-
-        if isinstance(arg, ReferenceMixin):
-            yield from arg.iter_references()
-        elif isinstance(arg, (tuple, set, list)):
-            for elem in arg:
-                yield from cls.iter_xcom_references(elem)
-        elif isinstance(arg, dict):
-            for elem in arg.values():
-                yield from cls.iter_xcom_references(elem)
-        elif isinstance(arg, (MappedOperator, SerializedBaseOperator)):
-            for attr in arg.template_fields:
-                yield from cls.iter_xcom_references(getattr(arg, attr))
-
-    def iter_references(self) -> Iterator[tuple[Operator, str]]:
-        raise NotImplementedError("This class should not be instantiated directly")
-
-
-@attrs.define
-class SchedulerPlainXComArg(SchedulerXComArg):
-    operator: Operator
-    key: str
-
-    @classmethod
-    def _deserialize(cls, data: dict[str, Any], dag: SerializedDAG) -> Self:
-        return cls(dag.get_task(data["task_id"]), data["key"])
-
-    def iter_references(self) -> Iterator[tuple[Operator, str]]:
-        yield self.operator, self.key
-
-
-@attrs.define
-class SchedulerMapXComArg(SchedulerXComArg):
-    arg: SchedulerXComArg
-    callables: Sequence[str]
-
-    @classmethod
-    def _deserialize(cls, data: dict[str, Any], dag: SerializedDAG) -> Self:
-        # We are deliberately NOT deserializing the callables. These are shown
-        # in the UI, and displaying a function object is useless.
-        return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"])
-
-    def iter_references(self) -> Iterator[tuple[Operator, str]]:
-        yield from self.arg.iter_references()
-
-
-@attrs.define
-class SchedulerConcatXComArg(SchedulerXComArg):
-    args: Sequence[SchedulerXComArg]
-
-    @classmethod
-    def _deserialize(cls, data: dict[str, Any], dag: SerializedDAG) -> Self:
-        return cls([deserialize_xcom_arg(arg, dag) for arg in data["args"]])
-
-    def iter_references(self) -> Iterator[tuple[Operator, str]]:
-        for arg in self.args:
-            yield from arg.iter_references()
-
-
-@attrs.define
-class SchedulerZipXComArg(SchedulerXComArg):
-    args: Sequence[SchedulerXComArg]
-    fillvalue: Any
-
-    @classmethod
-    def _deserialize(cls, data: dict[str, Any], dag: SerializedDAG) -> Self:
-        return cls(
-            [deserialize_xcom_arg(arg, dag) for arg in data["args"]],
-            fillvalue=data.get("fillvalue", NOTSET),
-        )
-
-    def iter_references(self) -> Iterator[tuple[Operator, str]]:
-        for arg in self.args:
-            yield from arg.iter_references()
-
-
-@singledispatch
-def get_task_map_length(xcom_arg: SchedulerXComArg, run_id: str, *, session: Session) -> int | None:
-    # The base implementation -- specific XComArg subclasses have specialised implementations
-    raise NotImplementedError(f"get_task_map_length not implemented for {type(xcom_arg)}")
-
-
-@get_task_map_length.register
-def _(xcom_arg: SchedulerPlainXComArg, run_id: str, *, session: Session) -> int | None:
-    from airflow.models.mappedoperator import is_mapped
-    from airflow.models.taskinstance import TaskInstance
-    from airflow.models.taskmap import TaskMap
-    from airflow.models.xcom import XComModel
-
-    dag_id = xcom_arg.operator.dag_id
-    task_id = xcom_arg.operator.task_id
-
-    if is_mapped(xcom_arg.operator):
-        unfinished_ti_exists = exists_query(
-            TaskInstance.dag_id == dag_id,
-            TaskInstance.run_id == run_id,
-            TaskInstance.task_id == task_id,
-            # Special NULL treatment is needed because 'state' can be NULL.
-            # The "IN" part would produce "NULL NOT IN ..." and eventually
-            # "NULl = NULL", which is a big no-no in SQL.
-            or_(
-                TaskInstance.state.is_(None),
-                TaskInstance.state.in_(s.value for s in State.unfinished if s is not None),
-            ),
-            session=session,
-        )
-        if unfinished_ti_exists:
-            return None  # Not all of the expanded tis are done yet.
-        query = select(func.count(XComModel.map_index)).where(
-            XComModel.dag_id == dag_id,
-            XComModel.run_id == run_id,
-            XComModel.task_id == task_id,
-            XComModel.map_index >= 0,
-            XComModel.key == XCOM_RETURN_KEY,
-        )
-    else:
-        query = select(TaskMap.length).where(
-            TaskMap.dag_id == dag_id,
-            TaskMap.run_id == run_id,
-            TaskMap.task_id == task_id,
-            TaskMap.map_index < 0,
-        )
-    return session.scalar(query)
-
-
-@get_task_map_length.register
-def _(xcom_arg: SchedulerMapXComArg, run_id: str, *, session: Session) -> int | None:
-    return get_task_map_length(xcom_arg.arg, run_id, session=session)
-
-
-@get_task_map_length.register
-def _(xcom_arg: SchedulerZipXComArg, run_id: str, *, session: Session) -> int | None:
-    all_lengths = (get_task_map_length(arg, run_id, session=session) for arg in xcom_arg.args)
-    ready_lengths = [length for length in all_lengths if length is not None]
-    if len(ready_lengths) != len(xcom_arg.args):
-        return None  # If any of the referenced XComs is not ready, we are not ready either.
-    if is_arg_set(xcom_arg.fillvalue):
-        return max(ready_lengths)
-    return min(ready_lengths)
-
-
-@get_task_map_length.register
-def _(xcom_arg: SchedulerConcatXComArg, run_id: str, *, session: Session) -> int | None:
-    all_lengths = (get_task_map_length(arg, run_id, session=session) for arg in xcom_arg.args)
-    ready_lengths = [length for length in all_lengths if length is not None]
-    if len(ready_lengths) != len(xcom_arg.args):
-        return None  # If any of the referenced XComs is not ready, we are not ready either.
-    return sum(ready_lengths)
-
-
-def deserialize_xcom_arg(data: dict[str, Any], dag: SerializedDAG):
-    """DAG serialization interface."""
-    klass = _XCOM_ARG_TYPES[data.get("type", "")]
-    return klass._deserialize(data, dag)
-
-
-_XCOM_ARG_TYPES: dict[str, type[SchedulerXComArg]] = {
-    "": SchedulerPlainXComArg,
-    "concat": SchedulerConcatXComArg,
-    "map": SchedulerMapXComArg,
-    "zip": SchedulerZipXComArg,
-}
-
-
-def __getattr__(name: str):
-    if name == "XComArg":
-        from airflow.sdk import XComArg
-
-        return XComArg
-
-    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+warnings.warn(
+    "Importing airflow.models.xcom_arg is deprecated and will be removed in "
+    "the future. Please import from 'airflow.sdk' instead.",
+    DeprecatedImportWarning,
+    stacklevel=2,
+)
diff --git a/airflow-core/src/airflow/serialization/definitions/taskgroup.py b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
index 414936d7e67..2632c8d4b32 100644
--- a/airflow-core/src/airflow/serialization/definitions/taskgroup.py
+++ b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
@@ -288,7 +288,7 @@ class SerializedMappedTaskGroup(SerializedTaskGroup):
 
     def iter_mapped_dependencies(self) -> Iterator[SerializedOperator]:
         """Upstream dependencies that provide XComs used by this mapped task group."""
-        from airflow.models.xcom_arg import SchedulerXComArg
+        from airflow.serialization.definitions.xcom_arg import SchedulerXComArg
 
         for op, _ in SchedulerXComArg.iter_xcom_references(self._expand_input):
             yield op
diff --git a/airflow-core/src/airflow/models/xcom_arg.py b/airflow-core/src/airflow/serialization/definitions/xcom_arg.py
similarity index 97%
copy from airflow-core/src/airflow/models/xcom_arg.py
copy to airflow-core/src/airflow/serialization/definitions/xcom_arg.py
index 5abd2aec8e1..c12ce096ad8 100644
--- a/airflow-core/src/airflow/models/xcom_arg.py
+++ b/airflow-core/src/airflow/serialization/definitions/xcom_arg.py
@@ -234,12 +234,3 @@ _XCOM_ARG_TYPES: dict[str, type[SchedulerXComArg]] = {
     "map": SchedulerMapXComArg,
     "zip": SchedulerZipXComArg,
 }
-
-
-def __getattr__(name: str):
-    if name == "XComArg":
-        from airflow.sdk import XComArg
-
-        return XComArg
-
-    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index c70deab2e0a..a3b840a3dea 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -48,7 +48,6 @@ from airflow.exceptions import AirflowException, DeserializationError, Serializa
 from airflow.models.connection import Connection
 from airflow.models.expandinput import create_expand_input
 from airflow.models.taskinstancekey import TaskInstanceKey
-from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg
 from airflow.sdk import DAG, Asset, AssetAlias, BaseOperator, XComArg
 from airflow.sdk.bases.operator import OPERATOR_DEFAULTS  # TODO: Copy this into the scheduler?
 from airflow.sdk.definitions.asset import (
@@ -78,6 +77,7 @@ from airflow.serialization.definitions.node import DAGNode
 from airflow.serialization.definitions.operatorlink import XComOperatorLink
 from airflow.serialization.definitions.param import SerializedParam, SerializedParamsDict
 from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup, SerializedTaskGroup
+from airflow.serialization.definitions.xcom_arg import SchedulerXComArg, deserialize_xcom_arg
 from airflow.serialization.encoders import (
     coerce_to_core_timetable,
     encode_asset_like,
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index a59e2efe5f3..d52e9089ac1 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -61,7 +61,7 @@ from airflow.models.taskinstance import TaskInstance as TI
 from airflow.models.xcom import XCOM_RETURN_KEY, XComModel
 from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
 from airflow.providers.standard.operators.bash import BashOperator
-from airflow.sdk import DAG, Asset, AssetAlias, BaseHook, TaskGroup, WeightRule, teardown
+from airflow.sdk import DAG, Asset, AssetAlias, BaseHook, TaskGroup, WeightRule, XComArg, teardown
 from airflow.sdk.bases.decorator import DecoratedOperator
 from airflow.sdk.bases.operator import OPERATOR_DEFAULTS, BaseOperator
 from airflow.sdk.definitions._internal.expandinput import EXPAND_INPUT_EMPTY
@@ -74,6 +74,7 @@ from airflow.serialization.definitions.dag import SerializedDAG
 from airflow.serialization.definitions.notset import NOTSET
 from airflow.serialization.definitions.operatorlink import XComOperatorLink
 from airflow.serialization.definitions.param import SerializedParam
+from airflow.serialization.definitions.xcom_arg import SchedulerPlainXComArg
 from airflow.serialization.encoders import ensure_serialized_asset
 from airflow.serialization.enums import Encoding
 from airflow.serialization.json_schema import load_dag_schema_dict
@@ -81,6 +82,7 @@ from airflow.serialization.serialized_objects import (
     BaseSerialization,
     DagSerialization,
     OperatorSerialization,
+    _XComRef,
 )
 from airflow.task.priority_strategy import _AbsolutePriorityWeightStrategy, _DownstreamPriorityWeightStrategy
 from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
@@ -2656,10 +2658,6 @@ def test_operator_expand_serde():
 
 
 def test_operator_expand_xcomarg_serde():
-    from airflow.models.xcom_arg import SchedulerPlainXComArg
-    from airflow.sdk.definitions.xcom_arg import XComArg
-    from airflow.serialization.serialized_objects import _XComRef
-
     with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag:
         task1 = BaseOperator(task_id="op1")
         mapped = MockOperator.partial(task_id="task_2").expand(arg2=XComArg(task1))
@@ -2767,10 +2765,6 @@ def test_operator_expand_kwargs_literal_serde(strict):
 
 @pytest.mark.parametrize("strict", [True, False])
 def test_operator_expand_kwargs_xcomarg_serde(strict):
-    from airflow.models.xcom_arg import SchedulerPlainXComArg
-    from airflow.sdk.definitions.xcom_arg import XComArg
-    from airflow.serialization.serialized_objects import _XComRef
-
     with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag:
         task1 = BaseOperator(task_id="op1")
         mapped = MockOperator.partial(task_id="task_2").expand_kwargs(XComArg(task1), strict=strict)

Reply via email to