This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit ec44ab1e764714b82579c528abb26bb8e6d7abd3
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Fri Oct 18 22:16:38 2024 +0100

    Get more tests passing
    
    [ci skip]
---
 airflow/models/baseoperator.py                     | 52 ++++------------------
 airflow/models/dag.py                              |  2 +-
 .../src/airflow/sdk/definitions/baseoperator.py    | 45 ++++++++++++++++++-
 task_sdk/src/airflow/sdk/definitions/dag.py        | 42 +++++++----------
 task_sdk/src/airflow/sdk/definitions/taskgroup.py  | 15 +++++--
 task_sdk/tests/defintions/test_dag.py              | 42 +++++++++++++++++
 tests/models/test_dag.py                           | 43 ++----------------
 7 files changed, 127 insertions(+), 114 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 39b7a1ba6f4..0dc533bef7c 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -24,6 +24,7 @@ Base operator for all operators.
 from __future__ import annotations
 
 import collections.abc
+import contextlib
 import copy
 import functools
 import logging
@@ -391,7 +392,14 @@ class BaseOperatorMeta(TaskSDKBaseOperatorMeta):
         execute_method = namespace.get("execute")
        if callable(execute_method) and not getattr(execute_method, "__isabstractmethod__", False):
             namespace["execute"] = 
ExecutorSafeguard().decorator(execute_method)
-        return super().__new__(cls, name, bases, namespace, **kwargs)
+        new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
+        with contextlib.suppress(KeyError):
+            # Update the partial descriptor with the class method, so it calls the actual function
+            # (but let subclasses override it if they need to)
+            partial_desc = vars(new_cls)["partial"]
+            if isinstance(partial_desc, _PartialDescriptor):
+                partial_desc.class_method = classmethod(partial)
+        return new_cls
 
 
class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperatorMeta):
@@ -620,16 +628,6 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperator
         self._pre_execute_hook = pre_execute
         self._post_execute_hook = post_execute
 
-    # base list which includes all the attrs that don't need deep copy.
-    _base_operator_shallow_copy_attrs: tuple[str, ...] = (
-        "user_defined_macros",
-        "user_defined_filters",
-        "params",
-    )
-
-    # each operator should override this class attr for shallow copy attrs.
-    shallow_copy_attrs: Sequence[str] = ()
-
     # Defines the operator level extra links
     operator_extra_links: Collection[BaseOperatorLink] = ()
 
@@ -719,38 +717,6 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperator
             logger=self.log,
         ).run(context, result)
 
-    def __deepcopy__(self, memo):
-        # Hack sorting double chained task lists by task_id to avoid hitting
-        # max_depth on deepcopy operations.
-        sys.setrecursionlimit(5000)  # TODO fix this in a better way
-
-        cls = self.__class__
-        result = cls.__new__(cls)
-        memo[id(self)] = result
-
-        shallow_copy = cls.shallow_copy_attrs + cls._base_operator_shallow_copy_attrs
-
-        for k, v in self.__dict__.items():
-            if k == "_BaseOperator__instantiated":
-                # Don't set this until the _end_, as it changes behaviour of __setattr__
-                continue
-            if k not in shallow_copy:
-                setattr(result, k, copy.deepcopy(v, memo))
-            else:
-                setattr(result, k, copy.copy(v))
-        result.__instantiated = self.__instantiated
-        return result
-
-    def __getstate__(self):
-        state = dict(self.__dict__)
-        if self._log:
-            del state["_log"]
-
-        return state
-
-    def __setstate__(self, state):
-        self.__dict__ = state
-
     def render_template_fields(
         self,
         context: Context,
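
Note: the metaclass hunk above installs a classmethod on the partial descriptor of each newly created class. A minimal runnable sketch of the pattern, with simplified stand-ins for _PartialDescriptor and partial() (Meta and Op are illustrative names, not Airflow's actual classes):

    import contextlib

    def partial(cls, **kwargs):
        # hypothetical stand-in for Airflow's partial() helper
        return (cls, kwargs)

    class _PartialDescriptor:
        # simplified: defer to whatever class method gets installed
        class_method = None

        def __get__(self, obj, cls=None):
            return self.class_method.__get__(cls, cls)

    class Meta(type):
        def __new__(mcs, name, bases, namespace, **kwargs):
            new_cls = super().__new__(mcs, name, bases, namespace, **kwargs)
            # vars() only sees attributes defined directly on this class, so
            # subclasses that override `partial` are left alone, and
            # suppress(KeyError) skips classes that don't define it at all.
            with contextlib.suppress(KeyError):
                desc = vars(new_cls)["partial"]
                if isinstance(desc, _PartialDescriptor):
                    desc.class_method = classmethod(partial)
            return new_cls

    class Op(metaclass=Meta):
        partial = _PartialDescriptor()

    assert Op.partial(retries=2) == (Op, {"retries": 2})
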
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 50a4eed22f0..691a64ed290 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -413,7 +413,7 @@ class DAG(TaskSDKDag, LoggingMixin):
     """
 
     partial: bool = False
-    last_loaded: datetime | None = None
+    last_loaded: datetime | None = attrs.field(factory=timezone.utcnow)
     on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None
     on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None
 
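Note: the last_loaded change swaps a static None default for an attrs factory, so every DAG instance records its own load time. A small sketch of the difference, with a simplified utcnow standing in for airflow.utils.timezone.utcnow:

    from __future__ import annotations

    from datetime import datetime, timezone

    import attrs

    def utcnow() -> datetime:
        # simplified stand-in for airflow.utils.timezone.utcnow
        return datetime.now(timezone.utc)

    @attrs.define
    class Loaded:
        # factory= is called once per *instance*, at construction time;
        # a plain `= utcnow()` default would be evaluated only once, at
        # class-definition time, and shared by every instance.
        last_loaded: datetime | None = attrs.field(factory=utcnow)

    a, b = Loaded(), Loaded()
    assert a.last_loaded is not None
    assert a.last_loaded <= b.last_loaded  # b was constructed after a
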
diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
index fc43d840bae..eb7e57907eb 100644
--- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -22,13 +22,14 @@ import collections.abc
 import contextlib
 import copy
 import inspect
+import sys
 import warnings
 from collections.abc import Collection, Iterable, Sequence
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta
 from functools import total_ordering, wraps
 from types import FunctionType
-from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast
+from typing import TYPE_CHECKING, Any, ClassVar, Final, TypeVar, cast
 
 import attrs
 
@@ -617,6 +618,16 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     # start_trigger_args: StartTriggerArgs | None = None
     # start_from_trigger: bool = False
 
+    # base list which includes all the attrs that don't need deep copy.
+    _base_operator_shallow_copy_attrs: Final[tuple[str, ...]] = (
+        "user_defined_macros",
+        "user_defined_filters",
+        "params",
+    )
+
+    # each operator should override this class attr for shallow copy attrs.
+    shallow_copy_attrs: ClassVar[Sequence[str]] = ()
+
     def __setattr__(self: BaseOperator, key: str, value: Any):
         if converter := getattr(self, f"_convert_{key}", None):
             value = converter(value)
@@ -917,6 +928,38 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
 
         return self
 
+    def __deepcopy__(self, memo: dict[int, Any]):
+        # Hack sorting double chained task lists by task_id to avoid hitting
+        # max_depth on deepcopy operations.
+        sys.setrecursionlimit(5000)  # TODO fix this in a better way
+
+        cls = self.__class__
+        result = cls.__new__(cls)
+        memo[id(self)] = result
+
+        shallow_copy = cls.shallow_copy_attrs + cls._base_operator_shallow_copy_attrs
+
+        for k, v in self.__dict__.items():
+            if k not in shallow_copy:
+                v = copy.deepcopy(v, memo)
+            else:
+                v = copy.copy(v)
+
+            # Bypass any setters, and set it on the object directly. This works since we are cloning ourselves,
+            # so we know the type is already fine
+            object.__setattr__(result, k, v)
+        return result
+
+    def __getstate__(self):
+        state = dict(self.__dict__)
+        if self._log:
+            del state["_log"]
+
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__ = state
+
     def add_inlets(self, inlets: Iterable[Any]):
         """Set inlets to this operator."""
         self.inlets.extend(inlets)
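
Note: the __deepcopy__ that moves into the SDK splits attributes into shallow- and deep-copied sets and writes them back with object.__setattr__ to bypass the converters wired into __setattr__. A stripped-down sketch of the same strategy (Node is an illustrative class, not Airflow code):

    import copy

    class Node:
        # attrs listed here are shared between copies instead of cloned
        shallow_copy_attrs: tuple[str, ...] = ("params",)

        def __init__(self, name, params, children):
            self.name = name
            self.params = params
            self.children = children

        def __deepcopy__(self, memo):
            cls = self.__class__
            result = cls.__new__(cls)
            # register the clone before copying fields so cyclic references
            # back to `self` resolve to `result` instead of recursing forever
            memo[id(self)] = result
            for k, v in self.__dict__.items():
                v = copy.copy(v) if k in cls.shallow_copy_attrs else copy.deepcopy(v, memo)
                # bypass any custom __setattr__ hooks; safe when cloning an
                # already-validated object of the same type
                object.__setattr__(result, k, v)
            return result

    params = {"retries": 3}
    n = Node("t1", params, [])
    n2 = copy.deepcopy(n)
    assert n2.params is params            # shallow: shared
    assert n2.children is not n.children  # deep: cloned
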
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py
index 2eafeb54814..0ad82e52455 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -108,13 +108,14 @@ _DAG_HASH_ATTRS = frozenset(
     {
         "dag_id",
         "task_ids",
-        "parent_dag",
         "start_date",
         "end_date",
         "fileloc",
         "template_searchpath",
         "last_loaded",
-        "timetable",
+        "schedule",
+        # TODO: Task-SDK: we should be hashing on timetable now, not schedule!
+        # "timetable",
     }
 )
 
@@ -218,7 +219,8 @@ else:
     dict_copy = copy.copy
 
 
-@attrs.define(repr=False, field_transformer=_all_after_dag_id_to_kw_only)
+# TODO: Task-SDK: look at re-enabling slots after we remove pickling
+@attrs.define(repr=False, field_transformer=_all_after_dag_id_to_kw_only, slots=False)
 class DAG:
     """
    A dag (directed acyclic graph) is a collection of tasks with directional dependencies.
@@ -330,16 +332,6 @@ class DAG:
    :param dag_display_name: The display name of the DAG which appears on the UI.
     """
 
-    _comps = {
-        "dag_id",
-        "task_ids",
-        "start_date",
-        "end_date",
-        "fileloc",
-        "template_searchpath",
-        "last_loaded",
-    }
-
     __serialized_fields: ClassVar[frozenset[str] | None] = None
 
     # NOTE: When updating arguments here, please also keep arguments in @dag()
@@ -430,7 +422,8 @@ class DAG:
         from airflow.assets import AssetAll
 
         schedule = self.schedule
-        delattr(self, "schedule")
+        # TODO: Once
+        # delattr(self, "schedule")
         if isinstance(schedule, Timetable):
             return schedule
         elif isinstance(schedule, BaseAsset):
@@ -495,8 +488,9 @@ class DAG:
         return f"<DAG: {self.dag_id}>"
 
     def __eq__(self, other: Self | Any):
-        if not isinstance(other, type(self)):
-            return NotImplemented
+        # TODO: This subclassing behaviour seems wrong, but it's what Airflow has done for ~ever.
+        if type(self) is not type(other):
+            return False
        return all(getattr(self, c, None) == getattr(other, c, None) for c in _DAG_HASH_ATTRS)
 
     def __ne__(self, other: Any):
@@ -685,7 +679,7 @@ class DAG:
 
         return tuple(nested_topo(self.task_group))
 
-    def __deepcopy__(self, memo):
+    def __deepcopy__(self, memo: dict[int, Any]):
         # Switcharoo to go around deepcopying objects coming through the
         # backdoor
         cls = self.__class__
@@ -693,7 +687,7 @@ class DAG:
         memo[id(self)] = result
         for k, v in self.__dict__.items():
             if k not in ("user_defined_macros", "user_defined_filters", 
"_log"):
-                setattr(result, k, copy.deepcopy(v, memo))
+                object.__setattr__(result, k, copy.deepcopy(v, memo))
 
         result.user_defined_macros = self.user_defined_macros
         result.user_defined_filters = self.user_defined_filters
@@ -763,13 +757,13 @@ class DAG:
                upstream = (u for u in t.upstream_list if isinstance(u, (BaseOperator, MappedOperator)))
                 direct_upstreams.extend(upstream)
 
-        # Compiling the unique list of tasks that made the cut
         # Make sure to not recursively deepcopy the dag or task_group while copying the task.
         # task_group is reset later
         def _deepcopy_task(t) -> Operator:
             memo.setdefault(id(t.task_group), None)
             return copy.deepcopy(t, memo)
 
+        # Compiling the unique list of tasks that made the cut
         dag.task_dict = {
             t.task_id: _deepcopy_task(t)
            for t in itertools.chain(matched_tasks, also_include, direct_upstreams)
@@ -785,12 +779,10 @@ class DAG:
             memo[id(group.children)] = {}
             if parent_group:
                 memo[id(group.parent_group)] = parent_group
-            for attr, value in copied.__dict__.items():
-                if id(value) in memo:
-                    value = memo[id(value)]
-                else:
-                    value = copy.deepcopy(value, memo)
-                copied.__dict__[attr] = value
+            for attr in type(group).__slots__:
+                value = getattr(group, attr)
+                value = copy.deepcopy(value, memo)
+                object.__setattr__(copied, attr, value)
 
             proxy = weakref.proxy(copied)
 
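Note: the memo.setdefault(id(t.task_group), None) call in _deepcopy_task relies on how deepcopy's memo dict works: pre-seeding an object's id makes every reference to it copy as the seeded value instead of being cloned. A minimal illustration:

    import copy

    class Group:
        def __init__(self):
            self.children = []

    class Task:
        def __init__(self, task_id, group):
            self.task_id = task_id
            self.task_group = group
            group.children.append(self)

    g = Group()
    t = Task("t1", g)

    # Map the group's id to None before copying: deepcopy consults the memo
    # first, so t_copy.task_group becomes None rather than dragging the whole
    # group (and every task inside it) into the copy.
    memo: dict[int, object] = {}
    memo.setdefault(id(g), None)
    t_copy = copy.deepcopy(t, memo)

    assert t_copy is not t
    assert t_copy.task_group is None
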
diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
index e60542ddf96..72cd9ed3bc7 100644
--- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -96,7 +96,7 @@ class TaskGroup(DAGNode):
 
     _group_id: str | None
     prefix_group_id: bool = True
-    parent_group: TaskGroup | None = None
+    parent_group: TaskGroup | None = attrs.field()
     dag: DAG = attrs.field()
     default_args: dict[str, Any] = attrs.field(factory=dict, converter=copy.deepcopy)
     tooltip: str = ""
@@ -112,17 +112,24 @@ class TaskGroup(DAGNode):
     ui_color: str = "CornflowerBlue"
     ui_fgcolor: str = "#000"
 
+    @parent_group.default
+    def _default_parent_group(self):
+        from airflow.sdk.definitions.contextmanager import TaskGroupContext
+
+        return TaskGroupContext.get_current()
+
     @dag.default
     def _default_dag(self):
         from airflow.sdk.definitions.contextmanager import DagContext
 
         if self.parent_group is not None:
             return self.parent_group.dag
-        dag = DagContext.get_current()
+        return DagContext.get_current()
+
+    @dag.validator
+    def _validate_dag(self, _attr, dag):
         if not dag:
             raise RuntimeError("TaskGroup can only be used inside a dag")
-        self.parent_group = dag.task_group
-        return dag
 
     def __attrs_post_init__(self):
         if self.parent_group:
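
Note: the TaskGroup change moves the context-manager lookup into an attrs default method and the dag check into a validator. A small sketch of how attrs field defaults and validators behave (Widget is illustrative; the "current-dag" string stands in for DagContext.get_current()):

    import attrs

    @attrs.define
    class Widget:
        parent: "Widget | None" = attrs.field()
        dag: object = attrs.field()

        # runs at __init__ time when the caller didn't pass `parent`;
        # it may read fields declared *before* this one
        @parent.default
        def _default_parent(self):
            return None

        # defaults can chain: this one consults `parent`, mirroring how
        # TaskGroup inherits its dag from the parent group
        @dag.default
        def _default_dag(self):
            if self.parent is not None:
                return self.parent.dag
            return "current-dag"  # stand-in for DagContext.get_current()

        # validators run after defaults are resolved, on every __init__
        # (and, with @attrs.define, on attribute assignment too)
        @dag.validator
        def _validate_dag(self, _attr, value):
            if not value:
                raise RuntimeError("Widget can only be used inside a dag")

    w = Widget()
    assert w.parent is None and w.dag == "current-dag"
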
diff --git a/task_sdk/tests/defintions/test_dag.py b/task_sdk/tests/defintions/test_dag.py
index e6ff426b7fa..dd4ded2f4fe 100644
--- a/task_sdk/tests/defintions/test_dag.py
+++ b/task_sdk/tests/defintions/test_dag.py
@@ -203,3 +203,45 @@ class TestDag:
        # Check that we get a ValueError 'start_date' for self.start_date when schedule is non-none
        with pytest.raises(ValueError, match="start_date is required when catchup=True"):
            DAG(dag_id="dag_with_non_none_schedule_and_empty_start_date", schedule="@hourly", catchup=True)
+
+    def test_partial_subset_updates_all_references_while_deepcopy(self):
+        with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
+            op1 = BaseOperator(task_id="t1")
+            op2 = BaseOperator(task_id="t2")
+            op3 = BaseOperator(task_id="t3")
+            op1 >> op2
+            op2 >> op3
+
+        partial = dag.partial_subset("t2", include_upstream=True, include_downstream=False)
+        assert id(partial.task_dict["t1"].downstream_list[0].dag) == id(partial)
+
+        # Copied DAG should not include unused task IDs in used_group_ids
+        assert "t3" not in partial.task_group.used_group_ids
+
+    def test_partial_subset_taskgroup_join_ids(self):
+        from airflow.sdk import TaskGroup
+
+        with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
+            start = BaseOperator(task_id="start")
+            with TaskGroup(group_id="outer", prefix_group_id=False) as 
outer_group:
+                with TaskGroup(group_id="tg1", prefix_group_id=False) as tg1:
+                    BaseOperator(task_id="t1")
+                with TaskGroup(group_id="tg2", prefix_group_id=False) as tg2:
+                    BaseOperator(task_id="t2")
+
+                start >> tg1 >> tg2
+
+        # Pre-condition checks
+        task = dag.get_task("t2")
+        assert task.task_group.upstream_group_ids == {"tg1"}
+        assert isinstance(task.task_group.parent_group, weakref.ProxyType)
+        assert task.task_group.parent_group == outer_group
+
+        partial = dag.partial_subset(["t2"], include_upstream=True, 
include_downstream=False)
+        copied_task = partial.get_task("t2")
+        assert copied_task.task_group.upstream_group_ids == {"tg1"}
+        assert isinstance(copied_task.task_group.parent_group, weakref.ProxyType)
+        assert copied_task.task_group.parent_group
+
+        # Make sure we don't affect the original!
+        assert task.task_group.upstream_group_ids is not copied_task.task_group.upstream_group_ids
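
Note: the new tests assert that parent_group is a weakref.ProxyType. A quick sketch of why a proxy fits here: it compares equal to its referent and forwards attribute access, but holds no strong reference, so parent/child group links don't keep each other alive:

    import weakref

    class Group:
        pass

    outer = Group()
    outer.label = "outer"

    ref = weakref.proxy(outer)
    assert isinstance(ref, weakref.ProxyType)
    assert ref == outer and ref is not outer  # equality forwards to referent
    assert ref.label == "outer"               # attribute access forwards too

    del outer  # CPython's refcounting collects the referent immediately
    try:
        ref.label
    except ReferenceError:
        pass  # a dead proxy raises instead of dangling
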
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index f739c34a4a4..0b0b3c340f7 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -1116,46 +1116,6 @@ class TestDag:
         dag = DAG("DAG", schedule=None, default_args=default_args)
         assert dag.timezone.name == local_tz.name
 
-    def test_partial_subset_updates_all_references_while_deepcopy(self):
-        with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
-            op1 = EmptyOperator(task_id="t1")
-            op2 = EmptyOperator(task_id="t2")
-            op3 = EmptyOperator(task_id="t3")
-            op1 >> op2
-            op2 >> op3
-
-        partial = dag.partial_subset("t2", include_upstream=True, include_downstream=False)
-        assert id(partial.task_dict["t1"].downstream_list[0].dag) == id(partial)
-
-        # Copied DAG should not include unused task IDs in used_group_ids
-        assert "t3" not in partial.task_group.used_group_ids
-
-    def test_partial_subset_taskgroup_join_ids(self):
-        with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
-            start = EmptyOperator(task_id="start")
-            with TaskGroup(group_id="outer", prefix_group_id=False) as 
outer_group:
-                with TaskGroup(group_id="tg1", prefix_group_id=False) as tg1:
-                    EmptyOperator(task_id="t1")
-                with TaskGroup(group_id="tg2", prefix_group_id=False) as tg2:
-                    EmptyOperator(task_id="t2")
-
-                start >> tg1 >> tg2
-
-        # Pre-condition checks
-        task = dag.get_task("t2")
-        assert task.task_group.upstream_group_ids == {"tg1"}
-        assert isinstance(task.task_group.parent_group, weakref.ProxyType)
-        assert task.task_group.parent_group == outer_group
-
-        partial = dag.partial_subset(["t2"], include_upstream=True, 
include_downstream=False)
-        copied_task = partial.get_task("t2")
-        assert copied_task.task_group.upstream_group_ids == {"tg1"}
-        assert isinstance(copied_task.task_group.parent_group, weakref.ProxyType)
-        assert copied_task.task_group.parent_group
-
-        # Make sure we don't affect the original!
-        assert task.task_group.upstream_group_ids is not copied_task.task_group.upstream_group_ids
-
     def test_schedule_dag_no_previous_runs(self):
         """
         Tests scheduling a dag with no previous runs
@@ -1539,6 +1499,9 @@ class TestDag:
 
         # a fail stop dag should not allow a non-default trigger rule
         with pytest.raises(FailStopDagInvalidTriggerRule):
+            task_with_non_default_trigger_rule = EmptyOperator(
+                task_id="task_with_non_default_trigger_rule", 
trigger_rule=TriggerRule.ALWAYS
+            )
             fail_stop_dag.add_task(task_with_non_default_trigger_rule)
 
     def test_dag_add_task_sets_default_task_group(self):
