This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4b1b2331720 Import XComArg from TaskSDK definitions (#49834)
4b1b2331720 is described below
commit 4b1b2331720af91ebac5b5896c378ad35f93af49
Author: GPK <[email protected]>
AuthorDate: Mon Apr 28 14:34:39 2025 +0100
Import XComArg from TaskSDK definitions (#49834)
---
task-sdk/src/airflow/sdk/bases/operator.py | 11 +++++------
task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py | 8 ++++----
task-sdk/src/airflow/sdk/definitions/edges.py | 4 ++--
task-sdk/src/airflow/sdk/definitions/mappedoperator.py | 8 ++++----
task-sdk/src/airflow/sdk/definitions/taskgroup.py | 2 +-
5 files changed, 16 insertions(+), 17 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py
index 09ca8c6c03e..f9cb80b52cc 100644
--- a/task-sdk/src/airflow/sdk/bases/operator.py
+++ b/task-sdk/src/airflow/sdk/bases/operator.py
@@ -77,10 +77,10 @@ if TYPE_CHECKING:
import jinja2
- from airflow.models.xcom_arg import XComArg
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.taskgroup import TaskGroup
+ from airflow.sdk.definitions.xcom_arg import XComArg
from airflow.serialization.enums import DagAttributeTypes
from airflow.task.priority_strategy import PriorityWeightStrategy
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
@@ -1389,7 +1389,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
return self._dag is not None
def _set_xcomargs_dependencies(self) -> None:
- from airflow.models.xcom_arg import XComArg
+ from airflow.sdk.definitions.xcom_arg import XComArg
for f in self.template_fields:
arg = getattr(self, f, NOTSET)
@@ -1418,7 +1418,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
generate_content >> send_email
"""
- from airflow.models.xcom_arg import XComArg
+ from airflow.sdk.definitions.xcom_arg import XComArg
if field not in self.template_fields:
return
@@ -1465,10 +1465,9 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
@property
def output(self) -> XComArg:
"""Returns reference to XCom pushed by current operator."""
- from airflow.models.xcom_arg import XComArg
+ from airflow.sdk.definitions.xcom_arg import XComArg
- # TODO: Task-SDK: remove this type ignore once XComArg is ported over
- return XComArg(operator=self) # type: ignore[call-overload]
+ return XComArg(operator=self)
@classmethod
def get_serialized_fields(cls):
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py b/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py
index 3e304b7dfe9..b1c0c6ee5f9 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py
@@ -67,14 +67,14 @@ def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]:
# To replace tedious isinstance() checks.
def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]:
- from airflow.models.xcom_arg import XComArg
+ from airflow.sdk.definitions.xcom_arg import XComArg
return not isinstance(v, (MappedArgument, XComArg))
# To replace tedious isinstance() checks.
def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]:
- from airflow.models.xcom_arg import XComArg
+ from airflow.sdk.definitions.xcom_arg import XComArg
return isinstance(v, (MappedArgument, XComArg))
@@ -187,7 +187,7 @@ class DictOfListsExpandInput(ResolveMixin):
raise IndexError(f"index {map_index} is over mapped length")
def iter_references(self) -> Iterable[tuple[Operator, str]]:
- from airflow.models.xcom_arg import XComArg
+ from airflow.sdk.definitions.xcom_arg import XComArg
for x in self.value.values():
if isinstance(x, XComArg):
@@ -238,7 +238,7 @@ class ListOfDictsExpandInput(ResolveMixin):
raise NotFullyPopulated({"expand_kwargs() argument"})
def iter_references(self) -> Iterable[tuple[Operator, str]]:
- from airflow.models.xcom_arg import XComArg
+ from airflow.sdk.definitions.xcom_arg import XComArg
if isinstance(self.value, XComArg):
yield from self.value.iter_references()
diff --git a/task-sdk/src/airflow/sdk/definitions/edges.py b/task-sdk/src/airflow/sdk/definitions/edges.py
index 556cedbe78a..4bc620645dc 100644
--- a/task-sdk/src/airflow/sdk/definitions/edges.py
+++ b/task-sdk/src/airflow/sdk/definitions/edges.py
@@ -70,9 +70,9 @@ class EdgeModifier(DependencyMixin):
nodes: DependencyMixin | Sequence[DependencyMixin],
stream: list[DependencyMixin],
):
- from airflow.models.xcom_arg import XComArg
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.definitions.taskgroup import TaskGroup
+ from airflow.sdk.definitions.xcom_arg import XComArg
for node in self._make_list(nodes):
if isinstance(node, (TaskGroup, XComArg, DAGNode)):
@@ -92,9 +92,9 @@ class EdgeModifier(DependencyMixin):
the nodes are from the same TaskGroup, we will leave them as DAGNodes and not
convert them to TaskGroups
"""
- from airflow.models.xcom_arg import XComArg
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.definitions.taskgroup import TaskGroup
+ from airflow.sdk.definitions.xcom_arg import XComArg
group_ids = set()
for node in [*self._upstream, *self._downstream]:
diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
index b2e3baaeccf..abcb0366eb4 100644
--- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -68,11 +68,11 @@ if TYPE_CHECKING:
OperatorExpandArgument,
OperatorExpandKwargsArgument,
)
- from airflow.models.xcom_arg import XComArg
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.bases.operatorlink import BaseOperatorLink
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.param import ParamsDict
+ from airflow.sdk.definitions.xcom_arg import XComArg
from airflow.sdk.types import Operator
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.triggers.base import StartTriggerArgs
@@ -192,7 +192,7 @@ class OperatorPartial:
return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)
def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
- from airflow.models.xcom_arg import XComArg
+ from airflow.sdk.definitions.xcom_arg import XComArg
if isinstance(kwargs, Sequence):
for item in kwargs:
@@ -338,7 +338,7 @@ class MappedOperator(AbstractOperator):
return f"<Mapped({self._task_type}): {self.task_id}>"
def __attrs_post_init__(self):
- from airflow.models.xcom_arg import XComArg
+ from airflow.sdk.definitions.xcom_arg import XComArg
if self.get_closest_mapped_task_group() is not None:
raise NotImplementedError("operator expansion in an expanded task group is not yet supported")
@@ -675,7 +675,7 @@ class MappedOperator(AbstractOperator):
@property
def output(self) -> XComArg:
"""Return reference to XCom pushed by current operator."""
- from airflow.models.xcom_arg import XComArg
+ from airflow.sdk.definitions.xcom_arg import XComArg
return XComArg(operator=self)
diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
index 3363424dee6..5ef139c14eb 100644
--- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -655,7 +655,7 @@ class MappedTaskGroup(TaskGroup):
def iter_mapped_dependencies(self) -> Iterator[Operator]:
"""Upstream dependencies that provide XComs used by this mapped task group."""
- from airflow.models.xcom_arg import XComArg
+ from airflow.sdk.definitions.xcom_arg import XComArg
for op, _ in XComArg.iter_xcom_references(self._expand_input):
yield op