This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit ec44ab1e764714b82579c528abb26bb8e6d7abd3
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Fri Oct 18 22:16:38 2024 +0100

    Get more tests passing [ci skip]
---
 airflow/models/baseoperator.py                     | 52 ++++------------------
 airflow/models/dag.py                              |  2 +-
 .../src/airflow/sdk/definitions/baseoperator.py    | 45 ++++++++++++++++++-
 task_sdk/src/airflow/sdk/definitions/dag.py        | 42 +++++++----------
 task_sdk/src/airflow/sdk/definitions/taskgroup.py  | 15 +++++--
 task_sdk/tests/defintions/test_dag.py              | 42 +++++++++++++++++
 tests/models/test_dag.py                           | 43 ++----------------
 7 files changed, 127 insertions(+), 114 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 39b7a1ba6f4..0dc533bef7c 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -24,6 +24,7 @@ Base operator for all operators.
 from __future__ import annotations
 
 import collections.abc
+import contextlib
 import copy
 import functools
 import logging
@@ -391,7 +392,14 @@ class BaseOperatorMeta(TaskSDKBaseOperatorMeta):
         execute_method = namespace.get("execute")
         if callable(execute_method) and not getattr(execute_method, "__isabstractmethod__", False):
             namespace["execute"] = ExecutorSafeguard().decorator(execute_method)
-        return super().__new__(cls, name, bases, namespace, **kwargs)
+        new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
+        with contextlib.suppress(KeyError):
+            # Update the partial descriptor with the class method, so it calls the actual function
+            # (but let subclasses override it if they need to)
+            partial_desc = vars(new_cls)["partial"]
+            if isinstance(partial_desc, _PartialDescriptor):
+                partial_desc.class_method = classmethod(partial)
+        return new_cls
 
 
 class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperatorMeta):
@@ -620,16 +628,6 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperator
         self._pre_execute_hook = pre_execute
         self._post_execute_hook = post_execute
 
-    # base list which includes all the attrs that don't need deep copy.
-    _base_operator_shallow_copy_attrs: tuple[str, ...] = (
-        "user_defined_macros",
-        "user_defined_filters",
-        "params",
-    )
-
-    # each operator should override this class attr for shallow copy attrs.
-    shallow_copy_attrs: Sequence[str] = ()
-
     # Defines the operator level extra links
     operator_extra_links: Collection[BaseOperatorLink] = ()
 
@@ -719,38 +717,6 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperator
             logger=self.log,
         ).run(context, result)
 
-    def __deepcopy__(self, memo):
-        # Hack sorting double chained task lists by task_id to avoid hitting
-        # max_depth on deepcopy operations.
-        sys.setrecursionlimit(5000)  # TODO fix this in a better way
-
-        cls = self.__class__
-        result = cls.__new__(cls)
-        memo[id(self)] = result
-
-        shallow_copy = cls.shallow_copy_attrs + cls._base_operator_shallow_copy_attrs
-
-        for k, v in self.__dict__.items():
-            if k == "_BaseOperator__instantiated":
-                # Don't set this until the _end_, as it changes behaviour of __setattr__
-                continue
-            if k not in shallow_copy:
-                setattr(result, k, copy.deepcopy(v, memo))
-            else:
-                setattr(result, k, copy.copy(v))
-        result.__instantiated = self.__instantiated
-        return result
-
-    def __getstate__(self):
-        state = dict(self.__dict__)
-        if self._log:
-            del state["_log"]
-
-        return state
-
-    def __setstate__(self, state):
-        self.__dict__ = state
-
     def render_template_fields(
         self,
         context: Context,
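For context, the BaseOperatorMeta change above patches a class-level descriptor from the metaclass's __new__: vars(new_cls) only sees attributes defined directly on the class being created, so subclasses that redefine "partial" are left untouched and the suppressed KeyError means "nothing to patch here". A minimal standalone sketch of that pattern, with illustrative names rather than Airflow's actual classes:

import contextlib


class _Descriptor:
    """Placeholder descriptor; the metaclass fills in class_method later."""

    class_method = None

    def __get__(self, obj, cls=None):
        # Delegate attribute access to the patched-in classmethod
        return self.class_method.__get__(obj, cls)


def _real_partial(cls, **kwargs):
    return f"partial({cls.__name__}, {kwargs})"


class Meta(type):
    def __new__(mcs, name, bases, namespace, **kwargs):
        new_cls = super().__new__(mcs, name, bases, namespace, **kwargs)
        # vars() excludes inherited attributes, so the suppressed KeyError
        # skips classes that don't define "partial" themselves.
        with contextlib.suppress(KeyError):
            desc = vars(new_cls)["partial"]
            if isinstance(desc, _Descriptor):
                desc.class_method = classmethod(_real_partial)
        return new_cls


class Operator(metaclass=Meta):
    partial = _Descriptor()


print(Operator.partial(task_id="t1"))  # -> partial(Operator, {'task_id': 't1'})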
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 50a4eed22f0..691a64ed290 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -413,7 +413,7 @@ class DAG(TaskSDKDag, LoggingMixin):
     """
 
     partial: bool = False
-    last_loaded: datetime | None = None
+    last_loaded: datetime | None = attrs.field(factory=timezone.utcnow)
 
     on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None
     on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None
diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
index fc43d840bae..eb7e57907eb 100644
--- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -22,13 +22,14 @@ import collections.abc
 import contextlib
 import copy
 import inspect
+import sys
 import warnings
 from collections.abc import Collection, Iterable, Sequence
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta
 from functools import total_ordering, wraps
 from types import FunctionType
-from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast
+from typing import TYPE_CHECKING, Any, ClassVar, Final, TypeVar, cast
 
 import attrs
 
@@ -617,6 +618,16 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     # start_trigger_args: StartTriggerArgs | None = None
     # start_from_trigger: bool = False
 
+    # base list which includes all the attrs that don't need deep copy.
+    _base_operator_shallow_copy_attrs: Final[tuple[str, ...]] = (
+        "user_defined_macros",
+        "user_defined_filters",
+        "params",
+    )
+
+    # each operator should override this class attr for shallow copy attrs.
+    shallow_copy_attrs: ClassVar[Sequence[str]] = ()
+
     def __setattr__(self: BaseOperator, key: str, value: Any):
         if converter := getattr(self, f"_convert_{key}", None):
             value = converter(value)
@@ -917,6 +928,38 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
 
         return self
 
+    def __deepcopy__(self, memo: dict[int, Any]):
+        # Hack sorting double chained task lists by task_id to avoid hitting
+        # max_depth on deepcopy operations.
+        sys.setrecursionlimit(5000)  # TODO fix this in a better way
+
+        cls = self.__class__
+        result = cls.__new__(cls)
+        memo[id(self)] = result
+
+        shallow_copy = cls.shallow_copy_attrs + cls._base_operator_shallow_copy_attrs
+
+        for k, v in self.__dict__.items():
+            if k not in shallow_copy:
+                v = copy.deepcopy(v, memo)
+            else:
+                v = copy.copy(v)
+
+            # Bypass any setters, and set it on the object directly. This works since we are
+            # cloning ourselves, so we know the type is already fine
+            object.__setattr__(result, k, v)
+        return result
+
+    def __getstate__(self):
+        state = dict(self.__dict__)
+        if self._log:
+            del state["_log"]
+
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__ = state
+
     def add_inlets(self, inlets: Iterable[Any]):
         """Set inlets to this operator."""
         self.inlets.extend(inlets)
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py
index 2eafeb54814..0ad82e52455 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -108,13 +108,14 @@ _DAG_HASH_ATTRS = frozenset(
     {
         "dag_id",
         "task_ids",
-        "parent_dag",
         "start_date",
         "end_date",
         "fileloc",
         "template_searchpath",
         "last_loaded",
-        "timetable",
+        "schedule",
+        # TODO: Task-SDK: we should be hashing on timetable now, not schedule!
+        # "timetable",
     }
 )
 
@@ -218,7 +219,8 @@ else:
     dict_copy = copy.copy
 
 
[email protected](repr=False, field_transformer=_all_after_dag_id_to_kw_only)
+# TODO: Task-SDK: look at re-enabling slots after we remove pickling
[email protected](repr=False, field_transformer=_all_after_dag_id_to_kw_only, slots=False)
 class DAG:
     """
     A dag (directed acyclic graph) is a collection of tasks with directional dependencies.
@@ -330,16 +332,6 @@ class DAG:
     :param dag_display_name: The display name of the DAG which appears on the UI.
     """
 
-    _comps = {
-        "dag_id",
-        "task_ids",
-        "start_date",
-        "end_date",
-        "fileloc",
-        "template_searchpath",
-        "last_loaded",
-    }
-
     __serialized_fields: ClassVar[frozenset[str] | None] = None
 
     # NOTE: When updating arguments here, please also keep arguments in @dag()
@@ -430,7 +422,8 @@ class DAG:
         from airflow.assets import AssetAll
 
         schedule = self.schedule
-        delattr(self, "schedule")
+        # TODO: Once
+        # delattr(self, "schedule")
         if isinstance(schedule, Timetable):
             return schedule
         elif isinstance(schedule, BaseAsset):
@@ -495,8 +488,9 @@ class DAG:
         return f"<DAG: {self.dag_id}>"
 
     def __eq__(self, other: Self | Any):
-        if not isinstance(other, type(self)):
-            return NotImplemented
+        # TODO: This subclassing behaviour seems wrong, but it's what Airflow has done for ~ever.
+        if type(self) is not type(other):
+            return False
         return all(getattr(self, c, None) == getattr(other, c, None) for c in _DAG_HASH_ATTRS)
 
     def __ne__(self, other: Any):
@@ -685,7 +679,7 @@ class DAG:
 
         return tuple(nested_topo(self.task_group))
 
-    def __deepcopy__(self, memo):
+    def __deepcopy__(self, memo: dict[int, Any]):
         # Switcharoo to go around deepcopying objects coming through the
         # backdoor
         cls = self.__class__
@@ -693,7 +687,7 @@ class DAG:
         memo[id(self)] = result
         for k, v in self.__dict__.items():
             if k not in ("user_defined_macros", "user_defined_filters", "_log"):
-                setattr(result, k, copy.deepcopy(v, memo))
+                object.__setattr__(result, k, copy.deepcopy(v, memo))
 
         result.user_defined_macros = self.user_defined_macros
         result.user_defined_filters = self.user_defined_filters
@@ -763,13 +757,13 @@ class DAG:
                 upstream = (u for u in t.upstream_list if isinstance(u, (BaseOperator, MappedOperator)))
                 direct_upstreams.extend(upstream)
 
-        # Compiling the unique list of tasks that made the cut
         # Make sure to not recursively deepcopy the dag or task_group while copying the task.
         # task_group is reset later
         def _deepcopy_task(t) -> Operator:
            memo.setdefault(id(t.task_group), None)
            return copy.deepcopy(t, memo)
 
+        # Compiling the unique list of tasks that made the cut
         dag.task_dict = {
             t.task_id: _deepcopy_task(t)
             for t in itertools.chain(matched_tasks, also_include, direct_upstreams)
@@ -785,12 +779,10 @@ class DAG:
             memo[id(group.children)] = {}
             if parent_group:
                 memo[id(group.parent_group)] = parent_group
-            for attr, value in copied.__dict__.items():
-                if id(value) in memo:
-                    value = memo[id(value)]
-                else:
-                    value = copy.deepcopy(value, memo)
-                copied.__dict__[attr] = value
+            for attr in type(group).__slots__:
+                value = getattr(group, attr)
+                value = copy.deepcopy(value, memo)
+                object.__setattr__(copied, attr, value)
 
             proxy = weakref.proxy(copied)
 
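The DAG.__eq__ change above moves from isinstance(), which lets a subclass instance compare equal to a base instance and returns NotImplemented for other types, to an exact type check that simply returns False. The behavioural difference, sketched with illustrative classes:

class Base:
    def __init__(self, dag_id: str):
        self.dag_id = dag_id

    def __eq__(self, other):
        # Exact-type check: a subclass instance is never equal to a Base.
        if type(self) is not type(other):
            return False
        return self.dag_id == other.dag_id


class Sub(Base):
    pass


assert Base("x") == Base("x")
assert Base("x") != Sub("x")  # an isinstance() check would have allowed equality here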
diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
index e60542ddf96..72cd9ed3bc7 100644
--- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -96,7 +96,7 @@ class TaskGroup(DAGNode):
 
     _group_id: str | None
     prefix_group_id: bool = True
-    parent_group: TaskGroup | None = None
+    parent_group: TaskGroup | None = attrs.field()
     dag: DAG = attrs.field()
     default_args: dict[str, Any] = attrs.field(factory=dict, converter=copy.deepcopy)
     tooltip: str = ""
@@ -112,17 +112,24 @@ class TaskGroup(DAGNode):
     ui_color: str = "CornflowerBlue"
     ui_fgcolor: str = "#000"
 
+    @parent_group.default
+    def _default_parent_group(self):
+        from airflow.sdk.definitions.contextmanager import TaskGroupContext
+
+        return TaskGroupContext.get_current()
+
     @dag.default
     def _default_dag(self):
         from airflow.sdk.definitions.contextmanager import DagContext
 
         if self.parent_group is not None:
             return self.parent_group.dag
-        dag = DagContext.get_current()
+        return DagContext.get_current()
+
+    @dag.validator
+    def _validate_dag(self, _attr, dag):
         if not dag:
             raise RuntimeError("TaskGroup can only be used inside a dag")
-        self.parent_group = dag.task_group
-        return dag
 
     def __attrs_post_init__(self):
         if self.parent_group:
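The TaskGroup change above replaces imperative __init__ logic with attrs per-field hooks: an @field.default that may read fields declared before it, plus an @field.validator for the error path. A compact sketch of that attrs pattern, with toy names standing in for the real TaskGroup:

import attrs


@attrs.define
class Group:
    parent: "Group | None" = attrs.field()
    registry: dict = attrs.field()

    @parent.default
    def _default_parent(self):
        # Airflow pulls this from TaskGroupContext.get_current() instead
        return None

    @registry.default
    def _default_registry(self):
        # Decorated defaults run in declaration order, so self.parent is
        # already set by the time this fires.
        if self.parent is not None:
            return self.parent.registry
        return {}

    @registry.validator
    def _validate_registry(self, _attr, value):
        if value is None:
            raise RuntimeError("Group can only be used with a registry")


root = Group(parent=None, registry={"r": 1})
child = Group(parent=root)
assert child.registry is root.registry  # inherited through the default hook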
diff --git a/task_sdk/tests/defintions/test_dag.py b/task_sdk/tests/defintions/test_dag.py
index e6ff426b7fa..dd4ded2f4fe 100644
--- a/task_sdk/tests/defintions/test_dag.py
+++ b/task_sdk/tests/defintions/test_dag.py
@@ -203,3 +203,45 @@ class TestDag:
         # Check that we get a ValueError 'start_date' for self.start_date when schedule is non-none
         with pytest.raises(ValueError, match="start_date is required when catchup=True"):
             DAG(dag_id="dag_with_non_none_schedule_and_empty_start_date", schedule="@hourly", catchup=True)
+
+    def test_partial_subset_updates_all_references_while_deepcopy(self):
+        with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
+            op1 = BaseOperator(task_id="t1")
+            op2 = BaseOperator(task_id="t2")
+            op3 = BaseOperator(task_id="t3")
+            op1 >> op2
+            op2 >> op3
+
+        partial = dag.partial_subset("t2", include_upstream=True, include_downstream=False)
+        assert id(partial.task_dict["t1"].downstream_list[0].dag) == id(partial)
+
+        # Copied DAG should not include unused task IDs in used_group_ids
+        assert "t3" not in partial.task_group.used_group_ids
+
+    def test_partial_subset_taskgroup_join_ids(self):
+        from airflow.sdk import TaskGroup
+
+        with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
+            start = BaseOperator(task_id="start")
+            with TaskGroup(group_id="outer", prefix_group_id=False) as outer_group:
+                with TaskGroup(group_id="tg1", prefix_group_id=False) as tg1:
+                    BaseOperator(task_id="t1")
+                with TaskGroup(group_id="tg2", prefix_group_id=False) as tg2:
+                    BaseOperator(task_id="t2")
+
+            start >> tg1 >> tg2
+
+        # Pre-condition checks
+        task = dag.get_task("t2")
+        assert task.task_group.upstream_group_ids == {"tg1"}
+        assert isinstance(task.task_group.parent_group, weakref.ProxyType)
+        assert task.task_group.parent_group == outer_group
+
+        partial = dag.partial_subset(["t2"], include_upstream=True, include_downstream=False)
+        copied_task = partial.get_task("t2")
+        assert copied_task.task_group.upstream_group_ids == {"tg1"}
+        assert isinstance(copied_task.task_group.parent_group, weakref.ProxyType)
+        assert copied_task.task_group.parent_group
+
+        # Make sure we don't affect the original!
+        assert task.task_group.upstream_group_ids is not copied_task.task_group.upstream_group_ids
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index f739c34a4a4..0b0b3c340f7 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -1116,46 +1116,6 @@ class TestDag:
         dag = DAG("DAG", schedule=None, default_args=default_args)
         assert dag.timezone.name == local_tz.name
 
-    def test_partial_subset_updates_all_references_while_deepcopy(self):
-        with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
-            op1 = EmptyOperator(task_id="t1")
-            op2 = EmptyOperator(task_id="t2")
-            op3 = EmptyOperator(task_id="t3")
-            op1 >> op2
-            op2 >> op3
-
-        partial = dag.partial_subset("t2", include_upstream=True, include_downstream=False)
-        assert id(partial.task_dict["t1"].downstream_list[0].dag) == id(partial)
-
-        # Copied DAG should not include unused task IDs in used_group_ids
-        assert "t3" not in partial.task_group.used_group_ids
-
-    def test_partial_subset_taskgroup_join_ids(self):
-        with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
-            start = EmptyOperator(task_id="start")
-            with TaskGroup(group_id="outer", prefix_group_id=False) as outer_group:
-                with TaskGroup(group_id="tg1", prefix_group_id=False) as tg1:
-                    EmptyOperator(task_id="t1")
-                with TaskGroup(group_id="tg2", prefix_group_id=False) as tg2:
-                    EmptyOperator(task_id="t2")
-
-            start >> tg1 >> tg2
-
-        # Pre-condition checks
-        task = dag.get_task("t2")
-        assert task.task_group.upstream_group_ids == {"tg1"}
-        assert isinstance(task.task_group.parent_group, weakref.ProxyType)
-        assert task.task_group.parent_group == outer_group
-
-        partial = dag.partial_subset(["t2"], include_upstream=True, include_downstream=False)
-        copied_task = partial.get_task("t2")
-        assert copied_task.task_group.upstream_group_ids == {"tg1"}
-        assert isinstance(copied_task.task_group.parent_group, weakref.ProxyType)
-        assert copied_task.task_group.parent_group
-
-        # Make sure we don't affect the original!
-        assert task.task_group.upstream_group_ids is not copied_task.task_group.upstream_group_ids
-
     def test_schedule_dag_no_previous_runs(self):
         """
         Tests scheduling a dag with no previous runs
@@ -1539,6 +1499,9 @@ class TestDag:
 
         # a fail stop dag should not allow a non-default trigger rule
         with pytest.raises(FailStopDagInvalidTriggerRule):
+            task_with_non_default_trigger_rule = EmptyOperator(
+                task_id="task_with_non_default_trigger_rule", trigger_rule=TriggerRule.ALWAYS
+            )
             fail_stop_dag.add_task(task_with_non_default_trigger_rule)
 
     def test_dag_add_task_sets_default_task_group(self):
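The final tests/models/test_dag.py hunk above moves the operator construction inside the pytest.raises block, so the expected exception is caught whether it fires at construction time or from add_task(). Sketched generically:

import pytest


def test_error_inside_the_raises_block():
    with pytest.raises(ValueError):
        value = int("not-a-number")  # may raise at "construction" time...
        print(value)                 # ...never reached, but still inside the block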
