This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 872de48f3bb2579fba791cbce2e7f7314acc3884
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Fri Oct 25 17:56:33 2024 +0100

    Fix mypy typing
---
 airflow/decorators/base.py                         |   2 +-
 airflow/decorators/bash.py                         |   4 +-
 airflow/decorators/sensor.py                       |   6 +-
 airflow/operators/python.py                        |   7 +-
 dev/mypy/plugin/outputs.py                         |   1 +
 .../providers/cncf/kubernetes/operators/pod.py     |   4 +-
 .../cncf/kubernetes/operators/spark_kubernetes.py  |   3 +-
 .../src/airflow/sdk/definitions/baseoperator.py    |  17 ++-
 task_sdk/src/airflow/sdk/definitions/dag.py        | 144 +++++++++++----------
 task_sdk/src/airflow/sdk/definitions/taskgroup.py  |  23 ++--
 10 files changed, 114 insertions(+), 97 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 6129dc1dd42..c9e4cf170f9 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -186,7 +186,7 @@ class DecoratedOperator(BaseOperator):
 
     # since we won't mutate the arguments, we should just do the shallow copy
     # there are some cases we can't deepcopy the objects (e.g protobuf).
-    shallow_copy_attrs: ClassVar[Sequence[str]] = ("python_callable",)
+    shallow_copy_attrs: Sequence[str] = ("python_callable",)
 
     def __init__(
         self,
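
The ClassVar removals in this file and in the decorator/operator files below track the
task SDK's BaseOperator later in this commit, where `shallow_copy_attrs` and
`template_fields_renderers` become instance-level declarations; mypy does not allow a
subclass to override a ClassVar with an instance attribute (or the reverse), so the
subclasses have to match the base. A minimal sketch of the distinction, using a toy
attrs class rather than the real operators:

    from typing import ClassVar, Sequence

    import attrs

    @attrs.define
    class Base:
        # ClassVar annotations are skipped by attrs: this stays a single
        # shared class attribute and is not part of __init__.
        custom_operator_name: ClassVar[str] = "@task"
        # A plain annotation becomes an attrs field; a subclass may replace
        # the default, and mypy checks the override against the field type.
        shallow_copy_attrs: Sequence[str] = ()

    @attrs.define
    class Sub(Base):
        shallow_copy_attrs: Sequence[str] = ("python_callable",)
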
diff --git a/airflow/decorators/bash.py b/airflow/decorators/bash.py
index e4dc19745e0..44738492da0 100644
--- a/airflow/decorators/bash.py
+++ b/airflow/decorators/bash.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 import warnings
-from typing import Any, Callable, ClassVar, Collection, Mapping, Sequence
+from typing import Any, Callable, Collection, Mapping, Sequence
 
 from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
 from airflow.providers.standard.operators.bash import BashOperator
@@ -39,7 +39,7 @@ class _BashDecoratedOperator(DecoratedOperator, BashOperator):
     """
 
     template_fields: Sequence[str] = (*DecoratedOperator.template_fields, *BashOperator.template_fields)
-    template_fields_renderers: ClassVar[dict[str, str]] = {
+    template_fields_renderers: dict[str, str] = {
         **DecoratedOperator.template_fields_renderers,
         **BashOperator.template_fields_renderers,
     }
diff --git a/airflow/decorators/sensor.py b/airflow/decorators/sensor.py
index 6ed3e9cc398..c332a78f95c 100644
--- a/airflow/decorators/sensor.py
+++ b/airflow/decorators/sensor.py
@@ -17,7 +17,7 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Callable, ClassVar, Sequence
+from typing import TYPE_CHECKING, Callable, Sequence
 
 from airflow.decorators.base import get_unique_task_id, task_decorator_factory
 from airflow.sensors.python import PythonSensor
@@ -42,13 +42,13 @@ class DecoratedSensorOperator(PythonSensor):
     """
 
     template_fields: Sequence[str] = ("op_args", "op_kwargs")
-    template_fields_renderers: ClassVar[dict[str, str]] = {"op_args": "py", "op_kwargs": "py"}
+    template_fields_renderers: dict[str, str] = {"op_args": "py", "op_kwargs": "py"}
 
     custom_operator_name = "@task.sensor"
 
     # since we won't mutate the arguments, we should just do the shallow copy
     # there are some cases we can't deepcopy the objects (e.g protobuf).
-    shallow_copy_attrs: ClassVar[Sequence[str]] = ("python_callable",)
+    shallow_copy_attrs: Sequence[str] = ("python_callable",)
 
     def __init__(
         self,
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index dc2e772af0e..3d40ad2c845 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -33,7 +33,7 @@ from collections.abc import Container
 from functools import cache
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Mapping, NamedTuple, Sequence
+from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Mapping, NamedTuple, Sequence
 
 import lazy_object_proxy
 
@@ -197,10 +197,7 @@ class PythonOperator(BaseOperator):
 
     # since we won't mutate the arguments, we should just do the shallow copy
     # there are some cases we can't deepcopy the objects(e.g protobuf).
-    shallow_copy_attrs: ClassVar[Sequence[str]] = (
-        "python_callable",
-        "op_kwargs",
-    )
+    shallow_copy_attrs: Sequence[str] = ("python_callable", "op_kwargs")
 
     def __init__(
         self,
diff --git a/dev/mypy/plugin/outputs.py b/dev/mypy/plugin/outputs.py
index fe1ccd5e7cf..a3ba7351f55 100644
--- a/dev/mypy/plugin/outputs.py
+++ b/dev/mypy/plugin/outputs.py
@@ -25,6 +25,7 @@ from mypy.types import AnyType, Type, TypeOfAny
 OUTPUT_PROPERTIES = {
     "airflow.models.baseoperator.BaseOperator.output",
     "airflow.models.mappedoperator.MappedOperator.output",
+    "airflow.sdk.definitions.baseoperator.BaseOperator.output",
 }
 
 TASK_CALL_FUNCTIONS = {
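
The added entry registers the task SDK's `BaseOperator.output` property alongside the
existing ones. For context, a minimal sketch of how a mypy plugin can consume such a
set via an attribute hook — this is an assumption built on the public mypy plugin API,
not a copy of the real hook body in outputs.py:

    from __future__ import annotations

    from collections.abc import Callable

    from mypy.plugin import AttributeContext, Plugin
    from mypy.types import AnyType, Type, TypeOfAny

    PROPS = {"airflow.sdk.definitions.baseoperator.BaseOperator.output"}

    class OutputsPlugin(Plugin):
        def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
            # Resolve `.output` on a registered class to Any: the XComArg
            # value it stands for is only known at runtime.
            if fullname in PROPS:
                return _output_hook
            return None

    def _output_hook(ctx: AttributeContext) -> Type:
        return AnyType(TypeOfAny.special_form)

    def plugin(version: str) -> type[Plugin]:
        # Entry point that mypy resolves from its `plugins` setting.
        return OutputsPlugin
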
diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py
index e51397447c3..62f08439d41 100644
--- a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -27,7 +27,7 @@ import re
 import shlex
 import string
 import warnings
-from collections.abc import Container
+from collections.abc import Container, Mapping
 from contextlib import AbstractContextManager
 from enum import Enum
 from functools import cached_property
@@ -436,7 +436,7 @@ class KubernetesPodOperator(BaseOperator):
     def _render_nested_template_fields(
         self,
         content: Any,
-        context: Context,
+        context: Mapping[str, Any],
         jinja_env: jinja2.Environment,
         seen_oids: set,
     ) -> None:
diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index c3dd4755b98..c1f5b36d6d3 100644
--- a/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+from collections.abc import Mapping
 from functools import cached_property
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
@@ -127,7 +128,7 @@ class SparkKubernetesOperator(KubernetesPodOperator):
     def _render_nested_template_fields(
         self,
         content: Any,
-        context: Context,
+        context: Mapping[str, Any],
         jinja_env: jinja2.Environment,
         seen_oids: set,
     ) -> None:
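
Both kubernetes operator changes widen `context` from the operator `Context` type to
`Mapping[str, Any]`, presumably to agree with the base-class signature in the new task
SDK. Narrowing a parameter in an override breaks substitutability and mypy flags it;
an illustrative sketch with toy classes (not the real operators):

    from collections.abc import Mapping
    from typing import Any

    class Context(dict):  # stand-in for airflow.utils.context.Context
        pass

    class Base:
        def _render(self, content: Any, context: Mapping[str, Any]) -> None: ...

    class Narrow(Base):
        # mypy error: argument 2 incompatible with supertype -- Context is
        # narrower than the Mapping the base promises callers it accepts.
        def _render(self, content: Any, context: Context) -> None: ...

    class Wide(Base):
        # Matching (or widening to) Mapping[str, Any] keeps the override valid.
        def _render(self, content: Any, context: Mapping[str, Any]) -> None: ...
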
diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
index 7b78aaca83d..e99ad835fb7 100644
--- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -507,7 +507,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     email_on_retry: bool = True
     email_on_failure: bool = True
     retries: int | None = DEFAULT_RETRIES
-    retry_delay: timedelta | float = DEFAULT_RETRY_DELAY
+    retry_delay: timedelta = DEFAULT_RETRY_DELAY
     retry_exponential_backoff: bool = False
     max_retry_delay: timedelta | float | None = None
     start_date: datetime | None = None
@@ -561,10 +561,11 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     is_setup: bool = False
     is_teardown: bool = False
 
+    # TODO: Task-SDK: Make these ClassVar[]?
     template_fields: Collection[str] = ()
     template_ext: Sequence[str] = ()
 
-    template_fields_renderers: ClassVar[dict[str, str]] = {}
+    template_fields_renderers: dict[str, str] = field(default_factory=dict, init=False)
 
     # Defines the color in the UI
     ui_color: str = "#fff"
@@ -575,6 +576,10 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
 
     _dag: DAG | None = field(init=False, default=None)
 
+    # Make this optional so the type matches the one defined in LoggingMixin
+    _log_config_logger_name: str | None = field(default="airflow.task.operators", init=False)
+    _logger_name: str | None = None
+
     # The _serialized_fields are lazily loaded when get_serialized_fields() method is called
     __serialized_fields: ClassVar[frozenset[str] | None] = None
 
@@ -633,7 +638,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     )
 
     # each operator should override this class attr for shallow copy attrs.
-    shallow_copy_attrs: ClassVar[Sequence[str]] = ()
+    shallow_copy_attrs: Sequence[str] = ()
 
     def __setattr__(self: BaseOperator, key: str, value: Any):
         if converter := getattr(self, f"_convert_{key}", None):
@@ -789,7 +794,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         if wait_for_downstream:
             self.depends_on_past = True
 
-        self.retry_delay = retry_delay
+        # Converted by setattr
+        self.retry_delay = retry_delay  # type: ignore[assignment]
         self.retry_exponential_backoff = retry_exponential_backoff
         if max_retry_delay is not None:
             self.max_retry_delay = max_retry_delay
@@ -817,10 +823,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
 
         self.allow_nested_operators = allow_nested_operators
 
-        """
-        self._log_config_logger_name = "airflow.task.operators"
         self._logger_name = logger_name
-        """
 
         # Lineage
         if inlets:
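
The new `# type: ignore[assignment]` above is needed because `retry_delay` is now
declared as plain `timedelta` while the constructor still accepts `timedelta | float`;
the `__setattr__` shown in an earlier hunk routes assignments through `_convert_<name>`
hooks that normalise the value at runtime. A simplified sketch of that pattern (not
the real BaseOperator):

    from datetime import timedelta
    from typing import Any

    class Op:
        retry_delay: timedelta

        def __setattr__(self, key: str, value: Any) -> None:
            # Same shape as the diff above: run a _convert_<name> hook when
            # one exists, then fall back to normal attribute assignment.
            if converter := getattr(self, f"_convert_{key}", None):
                value = converter(value)
            super().__setattr__(key, value)

        @staticmethod
        def _convert_retry_delay(value: timedelta | float) -> timedelta:
            return value if isinstance(value, timedelta) else timedelta(seconds=value)

    op = Op()
    op.retry_delay = 30.0  # type: ignore[assignment]  # stored as a timedelta
    assert op.retry_delay == timedelta(seconds=30)
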
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py
index 2bb15e9f2df..9cc24828458 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -213,6 +213,37 @@ else:
     dict_copy = copy.copy
 
 
+def _default_start_date(instance: DAG):
+    # Find start date inside default_args for compat with Airflow 2.
+    from airflow.utils import timezone
+
+    if date := instance.default_args.get("start_date"):
+        if not isinstance(date, datetime):
+            date = timezone.parse(date)
+            instance.default_args["start_date"] = date
+        return date
+    return None
+
+
+def _default_dag_display_name(instance: DAG) -> str:
+    return instance.dag_id
+
+
+def _default_fileloc() -> str:
+    # Skip over this frame, and the 'attrs generated init'
+    back = sys._getframe().f_back
+    if not back or not (back := back.f_back):
+        # We expect two frames back, if not we don't know where we are
+        return ""
+    return back.f_code.co_filename if back else ""
+
+
+def _default_task_group(instance: DAG) -> TaskGroup:
+    from airflow.sdk.definitions.taskgroup import TaskGroup
+
+    return TaskGroup.create_root(dag=instance)
+
+
 # TODO: Task-SDK: look at re-enabling slots after we remove pickling
 @attrs.define(repr=False, field_transformer=_all_after_dag_id_to_kw_only, slots=False)
 class DAG:
@@ -328,6 +359,11 @@ class DAG:
 
     __serialized_fields: ClassVar[frozenset[str] | None] = None
 
+    # Note: mypy gets very confused about the use of `@${attr}.default` for attrs fields without
+    # init=False -- it doesn't correctly track/notice that they have default values (it gives errors
+    # about `Missing positional argument "description" in call to "DAG"` etc), so for init=True args
+    # we use the `default=Factory()` style.
+
     # NOTE: When updating arguments here, please also keep arguments in @dag()
     # below in sync. (Search for 'def dag(' in this file.)
     dag_id: str = attrs.field(kw_only=False, validator=attrs.validators.instance_of(str))
@@ -338,7 +374,9 @@ class DAG:
     default_args: dict[str, Any] = attrs.field(
         factory=dict, validator=attrs.validators.instance_of(dict), converter=dict_copy
     )
-    start_date: datetime | None = attrs.field()  # type: ignore[misc]  # mypy doesn't grok the `@dag.default` seemingly
+    start_date: datetime | None = attrs.field(
+        default=attrs.Factory(_default_start_date, takes_self=True),
+    )
 
     end_date: datetime | None = None
     timezone: FixedTimezone | Timezone = attrs.field(init=False)
@@ -382,13 +420,18 @@ class DAG:
     owner_links: dict[str, str] = attrs.field(factory=dict)
     auto_register: bool = attrs.field(default=True, converter=bool)
     fail_stop: bool = attrs.field(default=False, converter=bool)
-    dag_display_name: str = attrs.field(validator=attrs.validators.instance_of(str))  # type: ignore[misc]  # mypy doesn't grok the `@dag.default` seemingly
+    dag_display_name: str = attrs.field(
+        default=attrs.Factory(_default_dag_display_name, takes_self=True),
+        validator=attrs.validators.instance_of(str),
+    )
 
     task_dict: dict[str, Operator] = attrs.field(factory=dict, init=False)
 
-    task_group: TaskGroup = attrs.field(on_setattr=attrs.setters.frozen)  # type: ignore[misc]  # mypy doesn't grok the `@dag.default` seemingly
+    task_group: TaskGroup = attrs.field(
+        on_setattr=attrs.setters.frozen, default=attrs.Factory(_default_task_group, takes_self=True)
+    )
 
-    fileloc: str = attrs.field(init=False)
+    fileloc: str = attrs.field(init=False, factory=_default_fileloc)
     partial: bool = attrs.field(init=False, default=False)
 
     edge_info: dict[str, dict[str, EdgeInfoType]] = attrs.field(init=False, factory=dict)
@@ -406,68 +449,6 @@ class DAG:
         self.start_date = timezone.convert_to_utc(self.start_date)
         self.end_date = timezone.convert_to_utc(self.end_date)
 
-    @fileloc.default
-    def _default_fileloc(self) -> str:
-        # Skip over this frame, and the 'attrs generated init'
-        back = sys._getframe().f_back
-        if not back or not (back := back.f_back):
-            # We expect two frames back, if not we don't know where we are
-            return ""
-        return back.f_code.co_filename if back else ""
-
-    @dag_display_name.default
-    def _default_dag_display_name(self) -> str:
-        return self.dag_id
-
-    @task_group.default
-    def _default_task_group(self) -> TaskGroup:
-        from airflow.sdk.definitions.taskgroup import TaskGroup
-
-        return TaskGroup.create_root(dag=self)
-
-    @timetable.default
-    def _default_timetable(self):
-        from airflow.assets import AssetAll
-
-        schedule = self.schedule
-        # TODO: Once
-        # delattr(self, "schedule")
-        if isinstance(schedule, Timetable):
-            return schedule
-        elif isinstance(schedule, BaseAsset):
-            return AssetTriggeredTimetable(schedule)
-        elif isinstance(schedule, Collection) and not isinstance(schedule, str):
-            if not all(isinstance(x, (Asset, AssetAlias)) for x in schedule):
-                raise ValueError("All elements in 'schedule' should be assets or asset aliases")
-            return AssetTriggeredTimetable(AssetAll(*schedule))
-        else:
-            return _create_timetable(schedule, self.timezone)
-
-    @start_date.default
-    def _default_start_date(self):
-        # Find start date inside default_args for compat with Airflow 2.
-        from airflow.utils import timezone
-
-        if date := self.default_args.get("start_date"):
-            if not isinstance(date, datetime):
-                date = timezone.parse(date)
-                self.default_args["start_date"] = date
-            return date
-        return None
-
-    @timezone.default
-    def _extract_tz(self):
-        import pendulum
-
-        from airflow.utils import timezone
-
-        # TODO: Task-SDK: get default dag tz from settings
-        tz = timezone.utc
-        if self.start_date and (tzinfo := self.start_date.tzinfo):
-            tzinfo = None if tzinfo else tz
-            tz = pendulum.instance(self.start_date, tz=tzinfo).timezone
-        return tz
-
     @params.validator
     def _validate_params(self, _, params: ParamsDict):
         """
@@ -506,6 +487,37 @@ class DAG:
                     f"requires max_active_runs <= 
{self.timetable.active_runs_limit}"
                 )
 
+    @timetable.default
+    def _default_timetable(instance: DAG):
+        from airflow.assets import AssetAll
+
+        schedule = instance.schedule
+        # TODO: Once
+        # delattr(self, "schedule")
+        if isinstance(schedule, Timetable):
+            return schedule
+        elif isinstance(schedule, BaseAsset):
+            return AssetTriggeredTimetable(schedule)
+        elif isinstance(schedule, Collection) and not isinstance(schedule, str):
+            if not all(isinstance(x, (Asset, AssetAlias)) for x in schedule):
+                raise ValueError("All elements in 'schedule' should be assets or asset aliases")
+            return AssetTriggeredTimetable(AssetAll(*schedule))
+        else:
+            return _create_timetable(schedule, instance.timezone)
+
+    @timezone.default
+    def _extract_tz(instance):
+        import pendulum
+
+        from airflow.utils import timezone
+
+        # TODO: Task-SDK: get default dag tz from settings
+        tz = timezone.utc
+        if instance.start_date and (tzinfo := instance.start_date.tzinfo):
+            tzinfo = None if tzinfo else tz
+            tz = pendulum.instance(instance.start_date, tz=tzinfo).timezone
+        return tz
+
     @has_on_success_callback.default
     def _has_on_success_callback(self) -> bool:
         return self.on_success_callback is not None
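
The pattern applied throughout dag.py: each `@field.default` decorator on an init=True
attribute becomes a module-level function wired in as `default=attrs.Factory(fn,
takes_self=True)`, which mypy recognises as a real default (see the note added near
the top of the class). A toy sketch of the two styles, not the real DAG:

    import attrs

    def _default_display(instance: "Thing") -> str:
        return instance.name.title()

    @attrs.define
    class Thing:
        name: str
        # Factory style: mypy sees a genuine default, so Thing("x")
        # type-checks without a `# type: ignore[misc]`.
        display: str = attrs.field(default=attrs.Factory(_default_display, takes_self=True))

        # Decorator style, which this commit moves away from for init=True
        # fields (works at runtime, confuses mypy):
        #
        # @display.default
        # def _display_default(self) -> str:
        #     return self.name.title()

    assert Thing("hello world").display == "Hello World"
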
diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
index 54602961cb8..26b1f6c45e4 100644
--- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -69,6 +69,17 @@ def _default_parent_group() -> TaskGroup | None:
     return TaskGroupContext.get_current()
 
 
+# This could be achieved with `@dag.default`, making this a method, but for some unknown reason when
+# we do that, Mypy (1.9.0 and 1.13.0 tested) seems to entirely lose track that this is an attrs class.
+# So we've gone with this and moved on with our lives; mypy is too much of a dark beast to battle over this.
+def _default_dag(instance: TaskGroup):
+    from airflow.sdk.definitions.contextmanager import DagContext
+
+    if (pg := instance.parent_group) is not None:
+        return pg.dag
+    return DagContext.get_current()
+
+
 @attrs.define(repr=False)
 class TaskGroup(DAGNode):
     """
@@ -101,9 +112,9 @@ class TaskGroup(DAGNode):
     """
 
     _group_id: str | None
-    prefix_group_id: bool = True
+    prefix_group_id: bool = attrs.field(default=True)
     parent_group: TaskGroup | None = attrs.field(factory=_default_parent_group)
-    dag: DAG = attrs.field()  # type: ignore[misc]  # mypy doesn't grok the `@dag.default` seemingly
+    dag: DAG = attrs.field(default=attrs.Factory(_default_dag, takes_self=True))
     default_args: dict[str, Any] = attrs.field(factory=dict, converter=copy.deepcopy)
     tooltip: str = ""
     children: dict[str, DAGNode] = attrs.field(factory=dict, init=False)
@@ -120,14 +131,6 @@ class TaskGroup(DAGNode):
 
     add_suffix_on_collision: bool = False
 
-    @dag.default
-    def _default_dag(self):
-        from airflow.sdk.definitions.contextmanager import DagContext
-
-        if self.parent_group is not None:
-            return self.parent_group.dag
-        return DagContext.get_current()
-
     @dag.validator
     def _validate_dag(self, _attr, dag):
         if not dag:
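
`_default_dag` here (like the existing `_default_parent_group`) falls back to a
context stack when no parent group supplies a dag. A self-contained sketch of that
stack pattern using hypothetical stand-ins — the real DagContext/TaskGroupContext
differ in detail:

    from __future__ import annotations

    import attrs

    class GroupContext:
        # Hypothetical stand-in for TaskGroupContext / DagContext.
        _stack: list[Group] = []

        @classmethod
        def get_current(cls) -> Group | None:
            return cls._stack[-1] if cls._stack else None

    @attrs.define
    class Group:
        name: str
        # Default comes from whichever group is currently "open".
        parent: Group | None = attrs.field(factory=GroupContext.get_current)

        def __enter__(self) -> Group:
            GroupContext._stack.append(self)
            return self

        def __exit__(self, *exc: object) -> None:
            GroupContext._stack.pop()

    with Group("root") as root:
        child = Group("child")  # parent resolves to root via the stack
    assert child.parent is root
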
