This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 872de48f3bb2579fba791cbce2e7f7314acc3884
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Fri Oct 25 17:56:33 2024 +0100

    Fix mypy typing
---
 airflow/decorators/base.py                         |   2 +-
 airflow/decorators/bash.py                         |   4 +-
 airflow/decorators/sensor.py                       |   6 +-
 airflow/operators/python.py                        |   7 +-
 dev/mypy/plugin/outputs.py                         |   1 +
 .../providers/cncf/kubernetes/operators/pod.py     |   4 +-
 .../cncf/kubernetes/operators/spark_kubernetes.py  |   3 +-
 .../src/airflow/sdk/definitions/baseoperator.py    |  17 ++-
 task_sdk/src/airflow/sdk/definitions/dag.py        | 144 +++++++++++----------
 task_sdk/src/airflow/sdk/definitions/taskgroup.py  |  23 ++--
 10 files changed, 114 insertions(+), 97 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 6129dc1dd42..c9e4cf170f9 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -186,7 +186,7 @@ class DecoratedOperator(BaseOperator):
 
     # since we won't mutate the arguments, we should just do the shallow copy
     # there are some cases we can't deepcopy the objects (e.g protobuf).
-    shallow_copy_attrs: ClassVar[Sequence[str]] = ("python_callable",)
+    shallow_copy_attrs: Sequence[str] = ("python_callable",)
 
     def __init__(
         self,
diff --git a/airflow/decorators/bash.py b/airflow/decorators/bash.py
index e4dc19745e0..44738492da0 100644
--- a/airflow/decorators/bash.py
+++ b/airflow/decorators/bash.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 import warnings
-from typing import Any, Callable, ClassVar, Collection, Mapping, Sequence
+from typing import Any, Callable, Collection, Mapping, Sequence
 
 from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
 from airflow.providers.standard.operators.bash import BashOperator
@@ -39,7 +39,7 @@ class _BashDecoratedOperator(DecoratedOperator, BashOperator):
     """
 
     template_fields: Sequence[str] = (*DecoratedOperator.template_fields, *BashOperator.template_fields)
-    template_fields_renderers: ClassVar[dict[str, str]] = {
+    template_fields_renderers: dict[str, str] = {
         **DecoratedOperator.template_fields_renderers,
         **BashOperator.template_fields_renderers,
     }
diff --git a/airflow/decorators/sensor.py b/airflow/decorators/sensor.py
index 6ed3e9cc398..c332a78f95c 100644
--- a/airflow/decorators/sensor.py
+++ b/airflow/decorators/sensor.py
@@ -17,7 +17,7 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Callable, ClassVar, Sequence
+from typing import TYPE_CHECKING, Callable, Sequence
 
 from airflow.decorators.base import get_unique_task_id, task_decorator_factory
 from airflow.sensors.python import PythonSensor
@@ -42,13 +42,13 @@ class DecoratedSensorOperator(PythonSensor):
     """
 
     template_fields: Sequence[str] = ("op_args", "op_kwargs")
-    template_fields_renderers: ClassVar[dict[str, str]] = {"op_args": "py", "op_kwargs": "py"}
+    template_fields_renderers: dict[str, str] = {"op_args": "py", "op_kwargs": "py"}
 
     custom_operator_name = "@task.sensor"
 
     # since we won't mutate the arguments, we should just do the shallow copy
     # there are some cases we can't deepcopy the objects (e.g protobuf).
-    shallow_copy_attrs: ClassVar[Sequence[str]] = ("python_callable",)
+    shallow_copy_attrs: Sequence[str] = ("python_callable",)
 
     def __init__(
         self,
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index dc2e772af0e..3d40ad2c845 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -33,7 +33,7 @@ from collections.abc import Container
 from functools import cache
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Mapping, NamedTuple, Sequence
+from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Mapping, NamedTuple, Sequence
 
 import lazy_object_proxy
 
@@ -197,10 +197,7 @@ class PythonOperator(BaseOperator):
 
     # since we won't mutate the arguments, we should just do the shallow copy
     # there are some cases we can't deepcopy the objects(e.g protobuf).
-    shallow_copy_attrs: ClassVar[Sequence[str]] = (
-        "python_callable",
-        "op_kwargs",
-    )
+    shallow_copy_attrs: Sequence[str] = ("python_callable", "op_kwargs")
 
     def __init__(
         self,
diff --git a/dev/mypy/plugin/outputs.py b/dev/mypy/plugin/outputs.py
index fe1ccd5e7cf..a3ba7351f55 100644
--- a/dev/mypy/plugin/outputs.py
+++ b/dev/mypy/plugin/outputs.py
@@ -25,6 +25,7 @@ from mypy.types import AnyType, Type, TypeOfAny
 OUTPUT_PROPERTIES = {
     "airflow.models.baseoperator.BaseOperator.output",
     "airflow.models.mappedoperator.MappedOperator.output",
+    "airflow.sdk.definitions.baseoperator.BaseOperator.output",
 }
 
 TASK_CALL_FUNCTIONS = {
diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py
index e51397447c3..62f08439d41 100644
--- a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -27,7 +27,7 @@ import re
 import shlex
 import string
 import warnings
-from collections.abc import Container
+from collections.abc import Container, Mapping
 from contextlib import AbstractContextManager
 from enum import Enum
 from functools import cached_property
@@ -436,7 +436,7 @@ class KubernetesPodOperator(BaseOperator):
     def _render_nested_template_fields(
         self,
         content: Any,
-        context: Context,
+        context: Mapping[str, Any],
         jinja_env: jinja2.Environment,
         seen_oids: set,
     ) -> None:
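The `context: Context` -> `context: Mapping[str, Any]` change in pod.py above (mirrored in spark_kubernetes.py below) relies on mypy allowing an override to broaden a parameter type: Airflow's `Context` behaves as a string-keyed mapping, so a method accepting any `Mapping[str, Any]` still satisfies every caller of the base method. A minimal sketch of that pattern, using illustrative stand-in classes rather than the real Airflow ones:

from collections.abc import Mapping
from typing import Any


class Context(dict[str, Any]):
    """Stand-in for Airflow's Context, which behaves as a string-keyed mapping."""


class Base:
    def render(self, context: Context) -> None: ...


class Widened(Base):
    # Parameter types are contravariant: widening Context to
    # Mapping[str, Any] is a safe override, and mypy accepts it.
    def render(self, context: Mapping[str, Any]) -> None: ...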
diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index c3dd4755b98..c1f5b36d6d3 100644
--- a/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+from collections.abc import Mapping
 from functools import cached_property
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
@@ -127,7 +128,7 @@ class SparkKubernetesOperator(KubernetesPodOperator):
     def _render_nested_template_fields(
         self,
         content: Any,
-        context: Context,
+        context: Mapping[str, Any],
         jinja_env: jinja2.Environment,
         seen_oids: set,
     ) -> None:
diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
index 7b78aaca83d..e99ad835fb7 100644
--- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -507,7 +507,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     email_on_retry: bool = True
     email_on_failure: bool = True
     retries: int | None = DEFAULT_RETRIES
-    retry_delay: timedelta | float = DEFAULT_RETRY_DELAY
+    retry_delay: timedelta = DEFAULT_RETRY_DELAY
     retry_exponential_backoff: bool = False
     max_retry_delay: timedelta | float | None = None
     start_date: datetime | None = None
@@ -561,10 +561,11 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     is_setup: bool = False
     is_teardown: bool = False
 
+    # TODO: Task-SDK: Make these ClassVar[]?
     template_fields: Collection[str] = ()
     template_ext: Sequence[str] = ()
 
-    template_fields_renderers: ClassVar[dict[str, str]] = {}
+    template_fields_renderers: dict[str, str] = field(default_factory=dict, init=False)
 
     # Defines the color in the UI
     ui_color: str = "#fff"
@@ -575,6 +576,10 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
 
     _dag: DAG | None = field(init=False, default=None)
 
+    # Make this optional so the type matches the one defined in LoggingMixin
+    _log_config_logger_name: str | None = field(default="airflow.task.operators", init=False)
+    _logger_name: str | None = None
+
     # The _serialized_fields are lazily loaded when get_serialized_fields() method is called
     __serialized_fields: ClassVar[frozenset[str] | None] = None
 
@@ -633,7 +638,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     )
 
     # each operator should override this class attr for shallow copy attrs.
-    shallow_copy_attrs: ClassVar[Sequence[str]] = ()
+    shallow_copy_attrs: Sequence[str] = ()
 
     def __setattr__(self: BaseOperator, key: str, value: Any):
         if converter := getattr(self, f"_convert_{key}", None):
@@ -789,7 +794,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         if wait_for_downstream:
             self.depends_on_past = True
 
-        self.retry_delay = retry_delay
+        # Converted by setattr
+        self.retry_delay = retry_delay  # type: ignore[assignment]
         self.retry_exponential_backoff = retry_exponential_backoff
         if max_retry_delay is not None:
             self.max_retry_delay = max_retry_delay
@@ -817,10 +823,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
 
         self.allow_nested_operators = allow_nested_operators
 
-        """
-        self._log_config_logger_name = "airflow.task.operators"
         self._logger_name = logger_name
-        """
 
         # Lineage
         if inlets:
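The `ClassVar` removals earlier in the patch line up with the Task SDK `BaseOperator` above, which now annotates `shallow_copy_attrs` and `template_fields_renderers` as plain attributes (note the `TODO: Task-SDK: Make these ClassVar[]?`). mypy tracks class variables and instance variables as different kinds and rejects overrides that switch between them, which is presumably the error family this commit resolves; a minimal reproduction, independent of Airflow:

from typing import ClassVar, Sequence


class Base:
    # A plain annotation: mypy records this as an instance variable.
    shallow_copy_attrs: Sequence[str] = ()


class Broken(Base):
    # mypy: Cannot override instance variable (previously declared on
    # base class "Base") with class variable
    shallow_copy_attrs: ClassVar[Sequence[str]] = ("python_callable",)


class Fixed(Base):
    # Keeping the same kind of variable type-checks cleanly.
    shallow_copy_attrs: Sequence[str] = ("python_callable",)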
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py
index 2bb15e9f2df..9cc24828458 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -213,6 +213,37 @@ else:
     dict_copy = copy.copy
 
 
+def _default_start_date(instance: DAG):
+    # Find start date inside default_args for compat with Airflow 2.
+    from airflow.utils import timezone
+
+    if date := instance.default_args.get("start_date"):
+        if not isinstance(date, datetime):
+            date = timezone.parse(date)
+            instance.default_args["start_date"] = date
+        return date
+    return None
+
+
+def _default_dag_display_name(instance: DAG) -> str:
+    return instance.dag_id
+
+
+def _default_fileloc() -> str:
+    # Skip over this frame, and the 'attrs generated init'
+    back = sys._getframe().f_back
+    if not back or not (back := back.f_back):
+        # We expect two frames back, if not we don't know where we are
+        return ""
+    return back.f_code.co_filename if back else ""
+
+
+def _default_task_group(instance: DAG) -> TaskGroup:
+    from airflow.sdk.definitions.taskgroup import TaskGroup
+
+    return TaskGroup.create_root(dag=instance)
+
+
 # TODO: Task-SDK: look at re-enabling slots after we remove pickling
 @attrs.define(repr=False, field_transformer=_all_after_dag_id_to_kw_only, slots=False)
 class DAG:
@@ -328,6 +359,11 @@ class DAG:
 
     __serialized_fields: ClassVar[frozenset[str] | None] = None
 
+    # Note: mypy gets very confused about the use of `@${attr}.default` for attrs without init=False -- and it
+    # doesn't correctly track/notice that they have default values (it gives errors about `Missing positional
+    # argument "description" in call to "DAG"` etc), so for init=True args we use the `default=Factory()`
+    # style
+
     # NOTE: When updating arguments here, please also keep arguments in @dag()
     # below in sync. (Search for 'def dag(' in this file.)
     dag_id: str = attrs.field(kw_only=False, validator=attrs.validators.instance_of(str))
@@ -338,7 +374,9 @@ class DAG:
     default_args: dict[str, Any] = attrs.field(
         factory=dict, validator=attrs.validators.instance_of(dict), converter=dict_copy
     )
-    start_date: datetime | None = attrs.field()  # type: ignore[misc] # mypy doesn't grok the `@dag.default` seemingly
+    start_date: datetime | None = attrs.field(
+        default=attrs.Factory(_default_start_date, takes_self=True),
+    )
     end_date: datetime | None = None
 
     timezone: FixedTimezone | Timezone = attrs.field(init=False)
@@ -382,13 +420,18 @@ class DAG:
     owner_links: dict[str, str] = attrs.field(factory=dict)
     auto_register: bool = attrs.field(default=True, converter=bool)
     fail_stop: bool = attrs.field(default=False, converter=bool)
-    dag_display_name: str = attrs.field(validator=attrs.validators.instance_of(str))  # type: ignore[misc] # mypy doesn't grok the `@dag.default` seemingly
+    dag_display_name: str = attrs.field(
+        default=attrs.Factory(_default_dag_display_name, takes_self=True),
+        validator=attrs.validators.instance_of(str),
+    )
 
     task_dict: dict[str, Operator] = attrs.field(factory=dict, init=False)
-    task_group: TaskGroup = attrs.field(on_setattr=attrs.setters.frozen)  # type: ignore[misc] # mypy doesn't grok the `@dag.default` seemingly
+    task_group: TaskGroup = attrs.field(
+        on_setattr=attrs.setters.frozen, default=attrs.Factory(_default_task_group, takes_self=True)
+    )
 
-    fileloc: str = attrs.field(init=False)
+    fileloc: str = attrs.field(init=False, factory=_default_fileloc)
     partial: bool = attrs.field(init=False, default=False)
 
     edge_info: dict[str, dict[str, EdgeInfoType]] = attrs.field(init=False, factory=dict)
@@ -406,68 +449,6 @@ class DAG:
             self.start_date = timezone.convert_to_utc(self.start_date)
             self.end_date = timezone.convert_to_utc(self.end_date)
 
-    @fileloc.default
-    def _default_fileloc(self) -> str:
-        # Skip over this frame, and the 'attrs generated init'
-        back = sys._getframe().f_back
-        if not back or not (back := back.f_back):
-            # We expect two frames back, if not we don't know where we are
-            return ""
-        return back.f_code.co_filename if back else ""
-
-    @dag_display_name.default
-    def _default_dag_display_name(self) -> str:
-        return self.dag_id
-
-    @task_group.default
-    def _default_task_group(self) -> TaskGroup:
-        from airflow.sdk.definitions.taskgroup import TaskGroup
-
-        return TaskGroup.create_root(dag=self)
-
-    @timetable.default
-    def _default_timetable(self):
-        from airflow.assets import AssetAll
-
-        schedule = self.schedule
-        # TODO: Once
-        # delattr(self, "schedule")
-        if isinstance(schedule, Timetable):
-            return schedule
-        elif isinstance(schedule, BaseAsset):
-            return AssetTriggeredTimetable(schedule)
-        elif isinstance(schedule, Collection) and not isinstance(schedule, str):
-            if not all(isinstance(x, (Asset, AssetAlias)) for x in schedule):
-                raise ValueError("All elements in 'schedule' should be assets or asset aliases")
-            return AssetTriggeredTimetable(AssetAll(*schedule))
-        else:
-            return _create_timetable(schedule, self.timezone)
-
-    @start_date.default
-    def _default_start_date(self):
-        # Find start date inside default_args for compat with Airflow 2.
-        from airflow.utils import timezone
-
-        if date := self.default_args.get("start_date"):
-            if not isinstance(date, datetime):
-                date = timezone.parse(date)
-                self.default_args["start_date"] = date
-            return date
-        return None
-
-    @timezone.default
-    def _extract_tz(self):
-        import pendulum
-
-        from airflow.utils import timezone
-
-        # TODO: Task-SDK: get default dag tz from settings
-        tz = timezone.utc
-        if self.start_date and (tzinfo := self.start_date.tzinfo):
-            tzinfo = None if tzinfo else tz
-            tz = pendulum.instance(self.start_date, tz=tzinfo).timezone
-        return tz
-
     @params.validator
     def _validate_params(self, _, params: ParamsDict):
         """
@@ -506,6 +487,37 @@ class DAG:
                 f"requires max_active_runs <= {self.timetable.active_runs_limit}"
             )
 
+    @timetable.default
+    def _default_timetable(instance: DAG):
+        from airflow.assets import AssetAll
+
+        schedule = instance.schedule
+        # TODO: Once
+        # delattr(self, "schedule")
+        if isinstance(schedule, Timetable):
+            return schedule
+        elif isinstance(schedule, BaseAsset):
+            return AssetTriggeredTimetable(schedule)
+        elif isinstance(schedule, Collection) and not isinstance(schedule, str):
+            if not all(isinstance(x, (Asset, AssetAlias)) for x in schedule):
+                raise ValueError("All elements in 'schedule' should be assets or asset aliases")
+            return AssetTriggeredTimetable(AssetAll(*schedule))
+        else:
+            return _create_timetable(schedule, instance.timezone)
+
+    @timezone.default
+    def _extract_tz(instance):
+        import pendulum
+
+        from airflow.utils import timezone
+
+        # TODO: Task-SDK: get default dag tz from settings
+        tz = timezone.utc
+        if instance.start_date and (tzinfo := instance.start_date.tzinfo):
+            tzinfo = None if tzinfo else tz
+            tz = pendulum.instance(instance.start_date, tz=tzinfo).timezone
+        return tz
+
     @has_on_success_callback.default
     def _has_on_success_callback(self) -> bool:
         return self.on_success_callback is not None
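The note added to dag.py above explains the shape of these hunks: instead of `@field.default`-decorated methods (which mypy mishandles for `init=True` attrs fields), defaults become module-level functions wired in with `attrs.Factory(..., takes_self=True)`. A self-contained sketch of that pattern with hypothetical names, not Airflow's:

import attrs


def _default_display_name(instance: "Config") -> str:
    # takes_self=True passes the partially initialised instance, so a
    # default can be derived from fields declared before this one.
    return instance.name.title()


@attrs.define
class Config:
    name: str
    # Equivalent to an @display_name.default method, but declared inline,
    # which keeps mypy aware that the field genuinely has a default.
    display_name: str = attrs.field(
        default=attrs.Factory(_default_display_name, takes_self=True)
    )


print(Config("data pipeline").display_name)  # -> "Data Pipeline"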
diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
index 54602961cb8..26b1f6c45e4 100644
--- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -69,6 +69,17 @@ def _default_parent_group() -> TaskGroup | None:
     return TaskGroupContext.get_current()
 
 
+# This could be achieved with `@dag.default`, making this a method, but for some unknown reason when we do
+# that it makes Mypy (1.9.0 and 1.13.0 tested) seem to entirely lose track that this is an Attrs class. So
+# we've gone with this and moved on with our lives; mypy is too much of a dark beast to battle over this.
+def _default_dag(instance: TaskGroup):
+    from airflow.sdk.definitions.contextmanager import DagContext
+
+    if (pg := instance.parent_group) is not None:
+        return pg.dag
+    return DagContext.get_current()
+
+
 @attrs.define(repr=False)
 class TaskGroup(DAGNode):
     """
@@ -101,9 +112,9 @@ class TaskGroup(DAGNode):
     """
 
     _group_id: str | None
-    prefix_group_id: bool = True
+    prefix_group_id: bool = attrs.field(default=True)
     parent_group: TaskGroup | None = attrs.field(factory=_default_parent_group)
-    dag: DAG = attrs.field()  # type: ignore[misc] # mypy doesn't grok the `@dag.default` seemingly
+    dag: DAG = attrs.field(default=attrs.Factory(_default_dag, takes_self=True))
     default_args: dict[str, Any] = attrs.field(factory=dict, converter=copy.deepcopy)
     tooltip: str = ""
     children: dict[str, DAGNode] = attrs.field(factory=dict, init=False)
@@ -120,14 +131,6 @@ class TaskGroup(DAGNode):
 
     add_suffix_on_collision: bool = False
 
-    @dag.default
-    def _default_dag(self):
-        from airflow.sdk.definitions.contextmanager import DagContext
-
-        if self.parent_group is not None:
-            return self.parent_group.dag
-        return DagContext.get_current()
-
     @dag.validator
     def _validate_dag(self, _attr, dag):
         if not dag:
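Finally, the `_default_fileloc` helper that dag.py moved to module level relies on frame inspection: it walks two frames up (past itself and the attrs-generated `__init__`) so `DAG.fileloc` records the file that instantiated the DAG. The same trick in isolation, with hypothetical names:

import sys


def _caller_filename() -> str:
    # Frame 0 is this function; f_back is our direct caller (here the
    # constructor, standing in for the attrs-generated __init__); one
    # more f_back reaches the user's file.
    back = sys._getframe().f_back
    if not back or not (back := back.f_back):
        return ""  # not enough frames to tell where we were called from
    return back.f_code.co_filename


class Widget:
    def __init__(self) -> None:
        # Recorded at construction time, like DAG.fileloc.
        self.fileloc = _caller_filename()


print(Widget().fileloc)  # the path of the file that created the Widget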
