This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch task-sdk-first-code in repository https://gitbox.apache.org/repos/asf/airflow.git
commit d5c7d046eb80a7d66e0517107de81a20c4ad3fa0 Author: Ash Berlin-Taylor <[email protected]> AuthorDate: Thu Oct 24 19:20:48 2024 +0100 make mpypy happy [skip ci] --- airflow/decorators/bash.py | 4 ++-- airflow/decorators/sensor.py | 2 +- airflow/models/abstractoperator.py | 4 ++-- airflow/models/xcom_arg.py | 8 ++++---- airflow/sensors/external_task.py | 4 ++-- scripts/ci/pre_commit/sync_init_decorator.py | 8 +++++++- .../src/airflow/sdk/definitions/abstractoperator.py | 6 ------ task_sdk/src/airflow/sdk/definitions/contextmanager.py | 4 ++-- task_sdk/src/airflow/sdk/definitions/dag.py | 16 +++++++--------- task_sdk/src/airflow/sdk/definitions/mixins.py | 6 ++---- task_sdk/src/airflow/sdk/definitions/taskgroup.py | 18 +++++++++--------- 11 files changed, 38 insertions(+), 42 deletions(-) diff --git a/airflow/decorators/bash.py b/airflow/decorators/bash.py index 44738492da0..e4dc19745e0 100644 --- a/airflow/decorators/bash.py +++ b/airflow/decorators/bash.py @@ -18,7 +18,7 @@ from __future__ import annotations import warnings -from typing import Any, Callable, Collection, Mapping, Sequence +from typing import Any, Callable, ClassVar, Collection, Mapping, Sequence from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory from airflow.providers.standard.operators.bash import BashOperator @@ -39,7 +39,7 @@ class _BashDecoratedOperator(DecoratedOperator, BashOperator): """ template_fields: Sequence[str] = (*DecoratedOperator.template_fields, *BashOperator.template_fields) - template_fields_renderers: dict[str, str] = { + template_fields_renderers: ClassVar[dict[str, str]] = { **DecoratedOperator.template_fields_renderers, **BashOperator.template_fields_renderers, } diff --git a/airflow/decorators/sensor.py b/airflow/decorators/sensor.py index c37cd08d6b4..6ed3e9cc398 100644 --- a/airflow/decorators/sensor.py +++ b/airflow/decorators/sensor.py @@ -42,7 +42,7 @@ class DecoratedSensorOperator(PythonSensor): """ template_fields: Sequence[str] = ("op_args", "op_kwargs") - template_fields_renderers: dict[str, str] = {"op_args": "py", "op_kwargs": "py"} + template_fields_renderers: ClassVar[dict[str, str]] = {"op_args": "py", "op_kwargs": "py"} custom_operator_name = "@task.sensor" diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index 93c0fd2d93a..feafb0b6b63 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -40,8 +40,6 @@ from airflow.utils.task_group import MappedTaskGroup from airflow.utils.trigger_rule import TriggerRule from airflow.utils.weight_rule import WeightRule -TaskStateChangeCallback = Callable[[Context], None] - if TYPE_CHECKING: from collections.abc import Mapping @@ -58,6 +56,8 @@ if TYPE_CHECKING: from airflow.triggers.base import StartTriggerArgs from airflow.utils.task_group import TaskGroup +TaskStateChangeCallback = Callable[[Context], None] + DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner") DEFAULT_POOL_SLOTS: int = 1 DEFAULT_PRIORITY_WEIGHT: int = 1 diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 940a7f1a066..c28af6acbe5 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -393,15 +393,15 @@ class PlainXComArg(XComArg): def as_teardown( self, *, - setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET, - on_failure_fail_dagrun=NOTSET, + setups: BaseOperator | Iterable[BaseOperator] | None = None, + on_failure_fail_dagrun: bool | None = None, ): for operator, _ in self.iter_references(): operator.is_teardown = True operator.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS - if on_failure_fail_dagrun is not NOTSET: + if on_failure_fail_dagrun is not None: operator.on_failure_fail_dagrun = on_failure_fail_dagrun - if not isinstance(setups, ArgNotSet): + if setups is not None: setups = [setups] if isinstance(setups, DependencyMixin) else setups for s in setups: s.is_setup = True diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py index 8eb501e281d..331e17168ba 100644 --- a/airflow/sensors/external_task.py +++ b/airflow/sensors/external_task.py @@ -20,7 +20,7 @@ from __future__ import annotations import datetime import os import warnings -from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowSkipException @@ -476,7 +476,7 @@ class ExternalTaskMarker(EmptyOperator): operator_extra_links = [ExternalDagLink()] # The _serialized_fields are lazily loaded when get_serialized_fields() method is called - __serialized_fields: frozenset[str] | None = None + __serialized_fields: ClassVar[frozenset[str] | None] = None def __init__( self, diff --git a/scripts/ci/pre_commit/sync_init_decorator.py b/scripts/ci/pre_commit/sync_init_decorator.py index 13e80d62c6c..7b02136ead3 100755 --- a/scripts/ci/pre_commit/sync_init_decorator.py +++ b/scripts/ci/pre_commit/sync_init_decorator.py @@ -116,7 +116,13 @@ def _expr_to_ast_dump(expr: str) -> str: ALLOWABLE_TYPE_ANNOTATIONS = { - _expr_to_ast_dump("Collection[str] | None"): _expr_to_ast_dump("MutableSet[str]") + # Mapping of allowble Decorator type -> Class attribute type + _expr_to_ast_dump("Collection[str] | None"): _expr_to_ast_dump("MutableSet[str]"), + _expr_to_ast_dump("ParamsDict | dict[str, Any] | None"): _expr_to_ast_dump("ParamsDict"), + # TODO: This one is legacy access control. Remove it in 3.0. RemovedInAirflow3Warning + _expr_to_ast_dump( + "dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None" + ): _expr_to_ast_dump("dict[str, dict[str, Collection[str]]] | None"), } diff --git a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py index bb5ddf88e23..5285bd97ef4 100644 --- a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py @@ -34,18 +34,12 @@ from airflow.sdk.definitions.node import DAGNode from airflow.utils.trigger_rule import TriggerRule from airflow.utils.weight_rule import WeightRule -# TaskStateChangeCallback = Callable[[Context], None] - if TYPE_CHECKING: from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.operator import Operator from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG - # TODO: Task-SDK - Context = dict[str, Any] - - DEFAULT_OWNER: str = "airflow" DEFAULT_POOL_SLOTS: int = 1 DEFAULT_PRIORITY_WEIGHT: int = 1 diff --git a/task_sdk/src/airflow/sdk/definitions/contextmanager.py b/task_sdk/src/airflow/sdk/definitions/contextmanager.py index 8b5458c65b9..ac50dcadbfc 100644 --- a/task_sdk/src/airflow/sdk/definitions/contextmanager.py +++ b/task_sdk/src/airflow/sdk/definitions/contextmanager.py @@ -20,7 +20,7 @@ from __future__ import annotations import sys from collections import deque from types import ModuleType -from typing import Any, Generic, Optional, TypeVar, cast +from typing import Any, Generic, TypeVar from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.taskgroup import TaskGroup @@ -109,7 +109,7 @@ class DagContext(ContextStack[DAG]): @classmethod def get_current_dag(cls) -> DAG | None: - return cast(Optional[DAG], cls.get_current()) + return cls.get_current() class TaskGroupContext(ContextStack[TaskGroup]): diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py index 111f51ce855..a8f222fd8ad 100644 --- a/task_sdk/src/airflow/sdk/definitions/dag.py +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -63,6 +63,7 @@ from airflow.timetables.simple import ( NullTimetable, OnceTimetable, ) +from airflow.utils.context import Context from airflow.utils.dag_cycle_tester import check_cycle from airflow.utils.decorators import fixup_decorator_warning_stack from airflow.utils.trigger_rule import TriggerRule @@ -88,10 +89,6 @@ __all__ = [ ] -# TODO: Task-SDK -class Context: ... - - DagStateChangeCallback = Callable[[Context], None] ScheduleInterval = Union[None, str, timedelta, relativedelta] @@ -341,7 +338,8 @@ class DAG: default_args: dict[str, Any] = attrs.field( factory=dict, validator=attrs.validators.instance_of(dict), converter=dict_copy ) - start_date: datetime | None = attrs.field() + start_date: datetime | None = attrs.field() # type: ignore[misc] # mypy doesn't grok the `@dag.default` seemingly + end_date: datetime | None = None timezone: FixedTimezone | Timezone = attrs.field(init=False) schedule: ScheduleArg = attrs.field(default=None, on_setattr=attrs.setters.frozen) @@ -373,7 +371,7 @@ class DAG: default=None, converter=attrs.Converter(_convert_params, takes_self=True), # type: ignore[misc, call-overload] ) - access_control: dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None = attrs.field( + access_control: dict[str, dict[str, Collection[str]]] | None = attrs.field( default=None, converter=attrs.Converter(_convert_access_control, takes_self=True), # type: ignore[misc, call-overload] ) @@ -384,11 +382,11 @@ class DAG: owner_links: dict[str, str] = attrs.field(factory=dict) auto_register: bool = attrs.field(default=True, converter=bool) fail_stop: bool = attrs.field(default=False, converter=bool) - dag_display_name: str = attrs.field(validator=attrs.validators.instance_of(str)) + dag_display_name: str = attrs.field(validator=attrs.validators.instance_of(str)) # type: ignore[misc] # mypy doesn't grok the `@dag.default` seemingly task_dict: dict[str, Operator] = attrs.field(factory=dict, init=False) - task_group: TaskGroup = attrs.field(on_setattr=attrs.setters.frozen) + task_group: TaskGroup = attrs.field(on_setattr=attrs.setters.frozen) # type: ignore[misc] # mypy doesn't grok the `@dag.default` seemingly fileloc: str = attrs.field(init=False) partial: bool = attrs.field(init=False, default=False) @@ -1036,7 +1034,7 @@ if TYPE_CHECKING: on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, doc_md: str | None = None, - params: ParamsDict | None = None, + params: ParamsDict | dict[str, Any] | None = None, access_control: dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None = None, is_paused_upon_creation: bool | None = None, jinja_environment_kwargs: dict | None = None, diff --git a/task_sdk/src/airflow/sdk/definitions/mixins.py b/task_sdk/src/airflow/sdk/definitions/mixins.py index e9d6e162927..de63772615d 100644 --- a/task_sdk/src/airflow/sdk/definitions/mixins.py +++ b/task_sdk/src/airflow/sdk/definitions/mixins.py @@ -21,8 +21,6 @@ from abc import abstractmethod from collections.abc import Iterable, Sequence from typing import TYPE_CHECKING, Any -from airflow.sdk.types import NOTSET, ArgNotSet - if TYPE_CHECKING: from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.edges import EdgeModifier @@ -72,8 +70,8 @@ class DependencyMixin: def as_teardown( self, *, - setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET, - on_failure_fail_dagrun: bool | ArgNotSet = NOTSET, + setups: BaseOperator | Iterable[BaseOperator] | None = None, + on_failure_fail_dagrun: bool | None = None, ) -> DependencyMixin: """Mark a task as teardown and set its setups as direct relatives.""" raise NotImplementedError() diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py index e417eab2760..54602961cb8 100644 --- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py @@ -63,6 +63,12 @@ TASKGROUP_ARGS_EXPECTED_TYPES = { } +def _default_parent_group() -> TaskGroup | None: + from airflow.sdk.definitions.contextmanager import TaskGroupContext + + return TaskGroupContext.get_current() + + @attrs.define(repr=False) class TaskGroup(DAGNode): """ @@ -96,8 +102,8 @@ class TaskGroup(DAGNode): _group_id: str | None prefix_group_id: bool = True - parent_group: TaskGroup | None = attrs.field() - dag: DAG = attrs.field() + parent_group: TaskGroup | None = attrs.field(factory=_default_parent_group) + dag: DAG = attrs.field() # type: ignore[misc] # mypy doesn't grok the `@dag.default` seemingly default_args: dict[str, Any] = attrs.field(factory=dict, converter=copy.deepcopy) tooltip: str = "" children: dict[str, DAGNode] = attrs.field(factory=dict, init=False) @@ -114,12 +120,6 @@ class TaskGroup(DAGNode): add_suffix_on_collision: bool = False - @parent_group.default - def _default_parent_group(self): - from airflow.sdk.definitions.contextmanager import TaskGroupContext - - return TaskGroupContext.get_current() - @dag.default def _default_dag(self): from airflow.sdk.definitions.contextmanager import DagContext @@ -247,7 +247,7 @@ class TaskGroup(DAGNode): @property def group_id(self) -> str | None: """group_id of this TaskGroup.""" - if self.parent_group and self.parent_group.prefix_group_id and self.parent_group.group_id: + if self.parent_group and self.parent_group.prefix_group_id and self.parent_group.node_id: # defer to parent whether it adds a prefix return self.parent_group.child_id(self.group_id)
