This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b1b2331720 Import XComArg from TaskSDK definitions (#49834)
4b1b2331720 is described below

commit 4b1b2331720af91ebac5b5896c378ad35f93af49
Author: GPK <[email protected]>
AuthorDate: Mon Apr 28 14:34:39 2025 +0100

    Import XComArg from TaskSDK definitions (#49834)
---
 task-sdk/src/airflow/sdk/bases/operator.py                    | 11 +++++------
 task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py |  8 ++++----
 task-sdk/src/airflow/sdk/definitions/edges.py                 |  4 ++--
 task-sdk/src/airflow/sdk/definitions/mappedoperator.py        |  8 ++++----
 task-sdk/src/airflow/sdk/definitions/taskgroup.py             |  2 +-
 5 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py
index 09ca8c6c03e..f9cb80b52cc 100644
--- a/task-sdk/src/airflow/sdk/bases/operator.py
+++ b/task-sdk/src/airflow/sdk/bases/operator.py
@@ -77,10 +77,10 @@ if TYPE_CHECKING:
 
     import jinja2
 
-    from airflow.models.xcom_arg import XComArg
     from airflow.sdk.definitions.context import Context
     from airflow.sdk.definitions.dag import DAG
     from airflow.sdk.definitions.taskgroup import TaskGroup
+    from airflow.sdk.definitions.xcom_arg import XComArg
     from airflow.serialization.enums import DagAttributeTypes
     from airflow.task.priority_strategy import PriorityWeightStrategy
     from airflow.triggers.base import BaseTrigger, StartTriggerArgs
@@ -1389,7 +1389,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         return self._dag is not None
 
     def _set_xcomargs_dependencies(self) -> None:
-        from airflow.models.xcom_arg import XComArg
+        from airflow.sdk.definitions.xcom_arg import XComArg
 
         for f in self.template_fields:
             arg = getattr(self, f, NOTSET)
@@ -1418,7 +1418,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
                 generate_content >> send_email
 
         """
-        from airflow.models.xcom_arg import XComArg
+        from airflow.sdk.definitions.xcom_arg import XComArg
 
         if field not in self.template_fields:
             return
@@ -1465,10 +1465,9 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     @property
     def output(self) -> XComArg:
         """Returns reference to XCom pushed by current operator."""
-        from airflow.models.xcom_arg import XComArg
+        from airflow.sdk.definitions.xcom_arg import XComArg
 
-        # TODO: Task-SDK: remove this type ignore once XComArg is ported over
-        return XComArg(operator=self)  # type: ignore[call-overload]
+        return XComArg(operator=self)
 
     @classmethod
     def get_serialized_fields(cls):
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py b/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py
index 3e304b7dfe9..b1c0c6ee5f9 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py
@@ -67,14 +67,14 @@ def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]:
 
 # To replace tedious isinstance() checks.
 def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]:
-    from airflow.models.xcom_arg import XComArg
+    from airflow.sdk.definitions.xcom_arg import XComArg
 
     return not isinstance(v, (MappedArgument, XComArg))
 
 
 # To replace tedious isinstance() checks.
 def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]:
-    from airflow.models.xcom_arg import XComArg
+    from airflow.sdk.definitions.xcom_arg import XComArg
 
     return isinstance(v, (MappedArgument, XComArg))
 
@@ -187,7 +187,7 @@ class DictOfListsExpandInput(ResolveMixin):
         raise IndexError(f"index {map_index} is over mapped length")
 
     def iter_references(self) -> Iterable[tuple[Operator, str]]:
-        from airflow.models.xcom_arg import XComArg
+        from airflow.sdk.definitions.xcom_arg import XComArg
 
         for x in self.value.values():
             if isinstance(x, XComArg):
@@ -238,7 +238,7 @@ class ListOfDictsExpandInput(ResolveMixin):
         raise NotFullyPopulated({"expand_kwargs() argument"})
 
     def iter_references(self) -> Iterable[tuple[Operator, str]]:
-        from airflow.models.xcom_arg import XComArg
+        from airflow.sdk.definitions.xcom_arg import XComArg
 
         if isinstance(self.value, XComArg):
             yield from self.value.iter_references()
diff --git a/task-sdk/src/airflow/sdk/definitions/edges.py b/task-sdk/src/airflow/sdk/definitions/edges.py
index 556cedbe78a..4bc620645dc 100644
--- a/task-sdk/src/airflow/sdk/definitions/edges.py
+++ b/task-sdk/src/airflow/sdk/definitions/edges.py
@@ -70,9 +70,9 @@ class EdgeModifier(DependencyMixin):
         nodes: DependencyMixin | Sequence[DependencyMixin],
         stream: list[DependencyMixin],
     ):
-        from airflow.models.xcom_arg import XComArg
         from airflow.sdk.definitions._internal.node import DAGNode
         from airflow.sdk.definitions.taskgroup import TaskGroup
+        from airflow.sdk.definitions.xcom_arg import XComArg
 
         for node in self._make_list(nodes):
             if isinstance(node, (TaskGroup, XComArg, DAGNode)):
@@ -92,9 +92,9 @@ class EdgeModifier(DependencyMixin):
         the nodes are from the same TaskGroup, we will leave them as DAGNodes and not
         convert them to TaskGroups
         """
-        from airflow.models.xcom_arg import XComArg
         from airflow.sdk.definitions._internal.node import DAGNode
         from airflow.sdk.definitions.taskgroup import TaskGroup
+        from airflow.sdk.definitions.xcom_arg import XComArg
 
         group_ids = set()
         for node in [*self._upstream, *self._downstream]:
diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
index b2e3baaeccf..abcb0366eb4 100644
--- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -68,11 +68,11 @@ if TYPE_CHECKING:
         OperatorExpandArgument,
         OperatorExpandKwargsArgument,
     )
-    from airflow.models.xcom_arg import XComArg
     from airflow.sdk.bases.operator import BaseOperator
     from airflow.sdk.bases.operatorlink import BaseOperatorLink
     from airflow.sdk.definitions.dag import DAG
     from airflow.sdk.definitions.param import ParamsDict
+    from airflow.sdk.definitions.xcom_arg import XComArg
     from airflow.sdk.types import Operator
     from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
     from airflow.triggers.base import StartTriggerArgs
@@ -192,7 +192,7 @@ class OperatorPartial:
         return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)
 
     def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
-        from airflow.models.xcom_arg import XComArg
+        from airflow.sdk.definitions.xcom_arg import XComArg
 
         if isinstance(kwargs, Sequence):
             for item in kwargs:
@@ -338,7 +338,7 @@ class MappedOperator(AbstractOperator):
         return f"<Mapped({self._task_type}): {self.task_id}>"
 
     def __attrs_post_init__(self):
-        from airflow.models.xcom_arg import XComArg
+        from airflow.sdk.definitions.xcom_arg import XComArg
 
         if self.get_closest_mapped_task_group() is not None:
             raise NotImplementedError("operator expansion in an expanded task group is not yet supported")
@@ -675,7 +675,7 @@ class MappedOperator(AbstractOperator):
     @property
     def output(self) -> XComArg:
         """Return reference to XCom pushed by current operator."""
-        from airflow.models.xcom_arg import XComArg
+        from airflow.sdk.definitions.xcom_arg import XComArg
 
         return XComArg(operator=self)
 
diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
index 3363424dee6..5ef139c14eb 100644
--- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -655,7 +655,7 @@ class MappedTaskGroup(TaskGroup):
 
     def iter_mapped_dependencies(self) -> Iterator[Operator]:
         """Upstream dependencies that provide XComs used by this mapped task group."""
-        from airflow.models.xcom_arg import XComArg
+        from airflow.sdk.definitions.xcom_arg import XComArg
 
         for op, _ in XComArg.iter_xcom_references(self._expand_input):
             yield op

Reply via email to