This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2ead075ee22 Move SchedulerXComArg to serialization (#59777)
2ead075ee22 is described below
commit 2ead075ee2287f30a8f6bc07dce8202f1a939c99
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Dec 24 22:06:28 2025 +0800
Move SchedulerXComArg to serialization (#59777)
---
airflow-core/src/airflow/__init__.py | 2 +-
airflow-core/src/airflow/models/expandinput.py | 8 +-
airflow-core/src/airflow/models/mappedoperator.py | 2 +-
airflow-core/src/airflow/models/xcom_arg.py | 233 +--------------------
.../airflow/serialization/definitions/taskgroup.py | 2 +-
.../definitions}/xcom_arg.py | 9 -
.../airflow/serialization/serialized_objects.py | 2 +-
.../unit/serialization/test_dag_serialization.py | 12 +-
8 files changed, 21 insertions(+), 249 deletions(-)
diff --git a/airflow-core/src/airflow/__init__.py
b/airflow-core/src/airflow/__init__.py
index cfaf575da1f..302c275f60b 100644
--- a/airflow-core/src/airflow/__init__.py
+++ b/airflow-core/src/airflow/__init__.py
@@ -84,7 +84,7 @@ if not os.environ.get("_AIRFLOW__AS_LIBRARY", None):
__lazy_imports: dict[str, tuple[str, str, bool]] = {
"DAG": (".sdk", "DAG", False),
"Asset": (".sdk", "Asset", False),
- "XComArg": (".models.xcom_arg", "XComArg", False),
+ "XComArg": (".sdk", "XComArg", False),
"version": (".version", "", False),
# Deprecated lazy imports
"AirflowException": (".exceptions", "AirflowException", True),
diff --git a/airflow-core/src/airflow/models/expandinput.py
b/airflow-core/src/airflow/models/expandinput.py
index 58f45a83ea9..8ebf310a875 100644
--- a/airflow-core/src/airflow/models/expandinput.py
+++ b/airflow-core/src/airflow/models/expandinput.py
@@ -40,8 +40,8 @@ if TYPE_CHECKING:
from sqlalchemy.orm import Session
from airflow.models.mappedoperator import MappedOperator
- from airflow.models.xcom_arg import SchedulerXComArg
from airflow.serialization.definitions.baseoperator import
SerializedBaseOperator
+ from airflow.serialization.definitions.xcom_arg import SchedulerXComArg
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
@@ -58,7 +58,7 @@ __all__ = [
def _needs_run_time_resolution(v: OperatorExpandArgument) ->
TypeGuard[MappedArgument | SchedulerXComArg]:
- from airflow.models.xcom_arg import SchedulerXComArg
+ from airflow.serialization.definitions.xcom_arg import SchedulerXComArg
return isinstance(v, (MappedArgument, SchedulerXComArg))
@@ -89,7 +89,7 @@ class SchedulerDictOfListsExpandInput:
If any arguments are not known right now (upstream task not finished),
they will not be present in the dict.
"""
- from airflow.models.xcom_arg import SchedulerXComArg,
get_task_map_length
+ from airflow.serialization.definitions.xcom_arg import
SchedulerXComArg, get_task_map_length
# TODO: This initiates one database call for each XComArg. Would it be
# more efficient to do one single db call and unpack the value here?
@@ -136,7 +136,7 @@ class SchedulerListOfDictsExpandInput:
raise NotFullyPopulated({"expand_kwargs() argument"})
def get_total_map_length(self, run_id: str, *, session: Session) -> int:
- from airflow.models.xcom_arg import get_task_map_length
+ from airflow.serialization.definitions.xcom_arg import
get_task_map_length
if isinstance(self.value, Sized):
return len(self.value)
diff --git a/airflow-core/src/airflow/models/mappedoperator.py
b/airflow-core/src/airflow/models/mappedoperator.py
index 473ca098613..0441643ad65 100644
--- a/airflow-core/src/airflow/models/mappedoperator.py
+++ b/airflow-core/src/airflow/models/mappedoperator.py
@@ -478,7 +478,7 @@ class MappedOperator(DAGNode):
def iter_mapped_dependencies(self) -> Iterator[Operator]:
"""Upstream dependencies that provide XComs used by this task for task
mapping."""
- from airflow.models.xcom_arg import SchedulerXComArg
+ from airflow.serialization.definitions.xcom_arg import SchedulerXComArg
for op, _ in
SchedulerXComArg.iter_xcom_references(self._get_specified_expand_input()):
yield op
diff --git a/airflow-core/src/airflow/models/xcom_arg.py
b/airflow-core/src/airflow/models/xcom_arg.py
index 5abd2aec8e1..5ddb9565872 100644
--- a/airflow-core/src/airflow/models/xcom_arg.py
+++ b/airflow-core/src/airflow/models/xcom_arg.py
@@ -17,229 +17,16 @@
from __future__ import annotations
-from collections.abc import Iterator, Sequence
-from functools import singledispatch
-from typing import TYPE_CHECKING, Any, TypeAlias
+import warnings
-import attrs
-from sqlalchemy import func, or_, select
-from sqlalchemy.orm import Session
+from airflow.sdk import XComArg
+from airflow.utils.deprecation_tools import DeprecatedImportWarning
-from airflow.models.referencemixin import ReferenceMixin
-from airflow.models.xcom import XCOM_RETURN_KEY
-from airflow.serialization.definitions.notset import NOTSET, is_arg_set
-from airflow.utils.db import exists_query
-from airflow.utils.state import State
+__all__ = ["XComArg"]
-__all__ = ["SchedulerXComArg", "deserialize_xcom_arg", "get_task_map_length"]
-
-if TYPE_CHECKING:
- from airflow.models.mappedoperator import MappedOperator
- from airflow.serialization.definitions.baseoperator import
SerializedBaseOperator
- from airflow.serialization.definitions.dag import SerializedDAG
- from airflow.typing_compat import Self
-
- Operator: TypeAlias = MappedOperator | SerializedBaseOperator
-
-
-class SchedulerXComArg:
- """
- Reference to an XCom value pushed from another operator.
-
- This is the safe counterpart to :class:`airflow.sdk.XComArg`.
- """
-
- @classmethod
- def _deserialize(cls, data: dict[str, Any], dag: SerializedDAG) -> Self:
- """
- Deserialize an XComArg.
-
- The implementation should be the inverse function to ``serialize``,
- implementing given a data dict converted from this XComArg derivative,
- how the original XComArg should be created. DAG serialization relies on
- additional information added in ``serialize_xcom_arg`` to dispatch data
- dicts to the correct ``_deserialize`` information, so this function
does
- not need to validate whether the incoming data contains correct keys.
- """
- raise NotImplementedError("This class should not be instantiated
directly")
-
- @classmethod
- def iter_xcom_references(cls, arg: Any) -> Iterator[tuple[Operator, str]]:
- """
- Return XCom references in an arbitrary value.
-
- Recursively traverse ``arg`` and look for XComArg instances in any
- collection objects, and instances with ``template_fields`` set.
- """
- from airflow.models.mappedoperator import MappedOperator
- from airflow.serialization.definitions.baseoperator import
SerializedBaseOperator
-
- if isinstance(arg, ReferenceMixin):
- yield from arg.iter_references()
- elif isinstance(arg, (tuple, set, list)):
- for elem in arg:
- yield from cls.iter_xcom_references(elem)
- elif isinstance(arg, dict):
- for elem in arg.values():
- yield from cls.iter_xcom_references(elem)
- elif isinstance(arg, (MappedOperator, SerializedBaseOperator)):
- for attr in arg.template_fields:
- yield from cls.iter_xcom_references(getattr(arg, attr))
-
- def iter_references(self) -> Iterator[tuple[Operator, str]]:
- raise NotImplementedError("This class should not be instantiated
directly")
-
-
-@attrs.define
-class SchedulerPlainXComArg(SchedulerXComArg):
- operator: Operator
- key: str
-
- @classmethod
- def _deserialize(cls, data: dict[str, Any], dag: SerializedDAG) -> Self:
- return cls(dag.get_task(data["task_id"]), data["key"])
-
- def iter_references(self) -> Iterator[tuple[Operator, str]]:
- yield self.operator, self.key
-
-
-@attrs.define
-class SchedulerMapXComArg(SchedulerXComArg):
- arg: SchedulerXComArg
- callables: Sequence[str]
-
- @classmethod
- def _deserialize(cls, data: dict[str, Any], dag: SerializedDAG) -> Self:
- # We are deliberately NOT deserializing the callables. These are shown
- # in the UI, and displaying a function object is useless.
- return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"])
-
- def iter_references(self) -> Iterator[tuple[Operator, str]]:
- yield from self.arg.iter_references()
-
-
-@attrs.define
-class SchedulerConcatXComArg(SchedulerXComArg):
- args: Sequence[SchedulerXComArg]
-
- @classmethod
- def _deserialize(cls, data: dict[str, Any], dag: SerializedDAG) -> Self:
- return cls([deserialize_xcom_arg(arg, dag) for arg in data["args"]])
-
- def iter_references(self) -> Iterator[tuple[Operator, str]]:
- for arg in self.args:
- yield from arg.iter_references()
-
-
-@attrs.define
-class SchedulerZipXComArg(SchedulerXComArg):
- args: Sequence[SchedulerXComArg]
- fillvalue: Any
-
- @classmethod
- def _deserialize(cls, data: dict[str, Any], dag: SerializedDAG) -> Self:
- return cls(
- [deserialize_xcom_arg(arg, dag) for arg in data["args"]],
- fillvalue=data.get("fillvalue", NOTSET),
- )
-
- def iter_references(self) -> Iterator[tuple[Operator, str]]:
- for arg in self.args:
- yield from arg.iter_references()
-
-
-@singledispatch
-def get_task_map_length(xcom_arg: SchedulerXComArg, run_id: str, *, session:
Session) -> int | None:
- # The base implementation -- specific XComArg subclasses have specialised
implementations
- raise NotImplementedError(f"get_task_map_length not implemented for
{type(xcom_arg)}")
-
-
-@get_task_map_length.register
-def _(xcom_arg: SchedulerPlainXComArg, run_id: str, *, session: Session) ->
int | None:
- from airflow.models.mappedoperator import is_mapped
- from airflow.models.taskinstance import TaskInstance
- from airflow.models.taskmap import TaskMap
- from airflow.models.xcom import XComModel
-
- dag_id = xcom_arg.operator.dag_id
- task_id = xcom_arg.operator.task_id
-
- if is_mapped(xcom_arg.operator):
- unfinished_ti_exists = exists_query(
- TaskInstance.dag_id == dag_id,
- TaskInstance.run_id == run_id,
- TaskInstance.task_id == task_id,
- # Special NULL treatment is needed because 'state' can be NULL.
- # The "IN" part would produce "NULL NOT IN ..." and eventually
- # "NULl = NULL", which is a big no-no in SQL.
- or_(
- TaskInstance.state.is_(None),
- TaskInstance.state.in_(s.value for s in State.unfinished if s
is not None),
- ),
- session=session,
- )
- if unfinished_ti_exists:
- return None # Not all of the expanded tis are done yet.
- query = select(func.count(XComModel.map_index)).where(
- XComModel.dag_id == dag_id,
- XComModel.run_id == run_id,
- XComModel.task_id == task_id,
- XComModel.map_index >= 0,
- XComModel.key == XCOM_RETURN_KEY,
- )
- else:
- query = select(TaskMap.length).where(
- TaskMap.dag_id == dag_id,
- TaskMap.run_id == run_id,
- TaskMap.task_id == task_id,
- TaskMap.map_index < 0,
- )
- return session.scalar(query)
-
-
-@get_task_map_length.register
-def _(xcom_arg: SchedulerMapXComArg, run_id: str, *, session: Session) -> int
| None:
- return get_task_map_length(xcom_arg.arg, run_id, session=session)
-
-
-@get_task_map_length.register
-def _(xcom_arg: SchedulerZipXComArg, run_id: str, *, session: Session) -> int
| None:
- all_lengths = (get_task_map_length(arg, run_id, session=session) for arg
in xcom_arg.args)
- ready_lengths = [length for length in all_lengths if length is not None]
- if len(ready_lengths) != len(xcom_arg.args):
- return None # If any of the referenced XComs is not ready, we are not
ready either.
- if is_arg_set(xcom_arg.fillvalue):
- return max(ready_lengths)
- return min(ready_lengths)
-
-
-@get_task_map_length.register
-def _(xcom_arg: SchedulerConcatXComArg, run_id: str, *, session: Session) ->
int | None:
- all_lengths = (get_task_map_length(arg, run_id, session=session) for arg
in xcom_arg.args)
- ready_lengths = [length for length in all_lengths if length is not None]
- if len(ready_lengths) != len(xcom_arg.args):
- return None # If any of the referenced XComs is not ready, we are not
ready either.
- return sum(ready_lengths)
-
-
-def deserialize_xcom_arg(data: dict[str, Any], dag: SerializedDAG):
- """DAG serialization interface."""
- klass = _XCOM_ARG_TYPES[data.get("type", "")]
- return klass._deserialize(data, dag)
-
-
-_XCOM_ARG_TYPES: dict[str, type[SchedulerXComArg]] = {
- "": SchedulerPlainXComArg,
- "concat": SchedulerConcatXComArg,
- "map": SchedulerMapXComArg,
- "zip": SchedulerZipXComArg,
-}
-
-
-def __getattr__(name: str):
- if name == "XComArg":
- from airflow.sdk import XComArg
-
- return XComArg
-
- raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+warnings.warn(
+ "Importing airflow.models.xcom_arg is deprecated and will be removed in "
+ "the future. Please import from 'airflow.sdk' instead.",
+ DeprecatedImportWarning,
+ stacklevel=2,
+)
diff --git a/airflow-core/src/airflow/serialization/definitions/taskgroup.py
b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
index 414936d7e67..2632c8d4b32 100644
--- a/airflow-core/src/airflow/serialization/definitions/taskgroup.py
+++ b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
@@ -288,7 +288,7 @@ class SerializedMappedTaskGroup(SerializedTaskGroup):
def iter_mapped_dependencies(self) -> Iterator[SerializedOperator]:
"""Upstream dependencies that provide XComs used by this mapped task
group."""
- from airflow.models.xcom_arg import SchedulerXComArg
+ from airflow.serialization.definitions.xcom_arg import SchedulerXComArg
for op, _ in SchedulerXComArg.iter_xcom_references(self._expand_input):
yield op
diff --git a/airflow-core/src/airflow/models/xcom_arg.py
b/airflow-core/src/airflow/serialization/definitions/xcom_arg.py
similarity index 97%
copy from airflow-core/src/airflow/models/xcom_arg.py
copy to airflow-core/src/airflow/serialization/definitions/xcom_arg.py
index 5abd2aec8e1..c12ce096ad8 100644
--- a/airflow-core/src/airflow/models/xcom_arg.py
+++ b/airflow-core/src/airflow/serialization/definitions/xcom_arg.py
@@ -234,12 +234,3 @@ _XCOM_ARG_TYPES: dict[str, type[SchedulerXComArg]] = {
"map": SchedulerMapXComArg,
"zip": SchedulerZipXComArg,
}
-
-
-def __getattr__(name: str):
- if name == "XComArg":
- from airflow.sdk import XComArg
-
- return XComArg
-
- raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py
b/airflow-core/src/airflow/serialization/serialized_objects.py
index c70deab2e0a..a3b840a3dea 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -48,7 +48,6 @@ from airflow.exceptions import AirflowException,
DeserializationError, Serializa
from airflow.models.connection import Connection
from airflow.models.expandinput import create_expand_input
from airflow.models.taskinstancekey import TaskInstanceKey
-from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg
from airflow.sdk import DAG, Asset, AssetAlias, BaseOperator, XComArg
from airflow.sdk.bases.operator import OPERATOR_DEFAULTS # TODO: Copy this
into the scheduler?
from airflow.sdk.definitions.asset import (
@@ -78,6 +77,7 @@ from airflow.serialization.definitions.node import DAGNode
from airflow.serialization.definitions.operatorlink import XComOperatorLink
from airflow.serialization.definitions.param import SerializedParam,
SerializedParamsDict
from airflow.serialization.definitions.taskgroup import
SerializedMappedTaskGroup, SerializedTaskGroup
+from airflow.serialization.definitions.xcom_arg import SchedulerXComArg,
deserialize_xcom_arg
from airflow.serialization.encoders import (
coerce_to_core_timetable,
encode_asset_like,
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py
b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index a59e2efe5f3..d52e9089ac1 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -61,7 +61,7 @@ from airflow.models.taskinstance import TaskInstance as TI
from airflow.models.xcom import XCOM_RETURN_KEY, XComModel
from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
from airflow.providers.standard.operators.bash import BashOperator
-from airflow.sdk import DAG, Asset, AssetAlias, BaseHook, TaskGroup,
WeightRule, teardown
+from airflow.sdk import DAG, Asset, AssetAlias, BaseHook, TaskGroup,
WeightRule, XComArg, teardown
from airflow.sdk.bases.decorator import DecoratedOperator
from airflow.sdk.bases.operator import OPERATOR_DEFAULTS, BaseOperator
from airflow.sdk.definitions._internal.expandinput import EXPAND_INPUT_EMPTY
@@ -74,6 +74,7 @@ from airflow.serialization.definitions.dag import
SerializedDAG
from airflow.serialization.definitions.notset import NOTSET
from airflow.serialization.definitions.operatorlink import XComOperatorLink
from airflow.serialization.definitions.param import SerializedParam
+from airflow.serialization.definitions.xcom_arg import SchedulerPlainXComArg
from airflow.serialization.encoders import ensure_serialized_asset
from airflow.serialization.enums import Encoding
from airflow.serialization.json_schema import load_dag_schema_dict
@@ -81,6 +82,7 @@ from airflow.serialization.serialized_objects import (
BaseSerialization,
DagSerialization,
OperatorSerialization,
+ _XComRef,
)
from airflow.task.priority_strategy import _AbsolutePriorityWeightStrategy,
_DownstreamPriorityWeightStrategy
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
@@ -2656,10 +2658,6 @@ def test_operator_expand_serde():
def test_operator_expand_xcomarg_serde():
- from airflow.models.xcom_arg import SchedulerPlainXComArg
- from airflow.sdk.definitions.xcom_arg import XComArg
- from airflow.serialization.serialized_objects import _XComRef
-
with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as
dag:
task1 = BaseOperator(task_id="op1")
mapped =
MockOperator.partial(task_id="task_2").expand(arg2=XComArg(task1))
@@ -2767,10 +2765,6 @@ def test_operator_expand_kwargs_literal_serde(strict):
@pytest.mark.parametrize("strict", [True, False])
def test_operator_expand_kwargs_xcomarg_serde(strict):
- from airflow.models.xcom_arg import SchedulerPlainXComArg
- from airflow.sdk.definitions.xcom_arg import XComArg
- from airflow.serialization.serialized_objects import _XComRef
-
with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as
dag:
task1 = BaseOperator(task_id="op1")
mapped =
MockOperator.partial(task_id="task_2").expand_kwargs(XComArg(task1),
strict=strict)