This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit d5c7d046eb80a7d66e0517107de81a20c4ad3fa0
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Thu Oct 24 19:20:48 2024 +0100

    make mypy happy [skip ci]
---
 airflow/decorators/bash.py                             |  4 ++--
 airflow/decorators/sensor.py                           |  2 +-
 airflow/models/abstractoperator.py                     |  4 ++--
 airflow/models/xcom_arg.py                             |  8 ++++----
 airflow/sensors/external_task.py                       |  4 ++--
 scripts/ci/pre_commit/sync_init_decorator.py           |  8 +++++++-
 .../src/airflow/sdk/definitions/abstractoperator.py    |  6 ------
 task_sdk/src/airflow/sdk/definitions/contextmanager.py |  4 ++--
 task_sdk/src/airflow/sdk/definitions/dag.py            | 16 +++++++---------
 task_sdk/src/airflow/sdk/definitions/mixins.py         |  6 ++----
 task_sdk/src/airflow/sdk/definitions/taskgroup.py      | 18 +++++++++---------
 11 files changed, 38 insertions(+), 42 deletions(-)

diff --git a/airflow/decorators/bash.py b/airflow/decorators/bash.py
index 44738492da0..e4dc19745e0 100644
--- a/airflow/decorators/bash.py
+++ b/airflow/decorators/bash.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 import warnings
-from typing import Any, Callable, Collection, Mapping, Sequence
+from typing import Any, Callable, ClassVar, Collection, Mapping, Sequence
 
 from airflow.decorators.base import DecoratedOperator, TaskDecorator, 
task_decorator_factory
 from airflow.providers.standard.operators.bash import BashOperator
@@ -39,7 +39,7 @@ class _BashDecoratedOperator(DecoratedOperator, BashOperator):
     """
 
     template_fields: Sequence[str] = (*DecoratedOperator.template_fields, 
*BashOperator.template_fields)
-    template_fields_renderers: dict[str, str] = {
+    template_fields_renderers: ClassVar[dict[str, str]] = {
         **DecoratedOperator.template_fields_renderers,
         **BashOperator.template_fields_renderers,
     }
diff --git a/airflow/decorators/sensor.py b/airflow/decorators/sensor.py
index c37cd08d6b4..6ed3e9cc398 100644
--- a/airflow/decorators/sensor.py
+++ b/airflow/decorators/sensor.py
@@ -42,7 +42,7 @@ class DecoratedSensorOperator(PythonSensor):
     """
 
     template_fields: Sequence[str] = ("op_args", "op_kwargs")
-    template_fields_renderers: dict[str, str] = {"op_args": "py", "op_kwargs": 
"py"}
+    template_fields_renderers: ClassVar[dict[str, str]] = {"op_args": "py", 
"op_kwargs": "py"}
 
     custom_operator_name = "@task.sensor"
 
diff --git a/airflow/models/abstractoperator.py 
b/airflow/models/abstractoperator.py
index 93c0fd2d93a..feafb0b6b63 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -40,8 +40,6 @@ from airflow.utils.task_group import MappedTaskGroup
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.weight_rule import WeightRule
 
-TaskStateChangeCallback = Callable[[Context], None]
-
 if TYPE_CHECKING:
     from collections.abc import Mapping
 
@@ -58,6 +56,8 @@ if TYPE_CHECKING:
     from airflow.triggers.base import StartTriggerArgs
     from airflow.utils.task_group import TaskGroup
 
+TaskStateChangeCallback = Callable[[Context], None]
+
 DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
 DEFAULT_POOL_SLOTS: int = 1
 DEFAULT_PRIORITY_WEIGHT: int = 1
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 940a7f1a066..c28af6acbe5 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -393,15 +393,15 @@ class PlainXComArg(XComArg):
     def as_teardown(
         self,
         *,
-        setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET,
-        on_failure_fail_dagrun=NOTSET,
+        setups: BaseOperator | Iterable[BaseOperator] | None = None,
+        on_failure_fail_dagrun: bool | None = None,
     ):
         for operator, _ in self.iter_references():
             operator.is_teardown = True
             operator.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
-            if on_failure_fail_dagrun is not NOTSET:
+            if on_failure_fail_dagrun is not None:
                 operator.on_failure_fail_dagrun = on_failure_fail_dagrun
-            if not isinstance(setups, ArgNotSet):
+            if setups is not None:
                 setups = [setups] if isinstance(setups, DependencyMixin) else 
setups
                 for s in setups:
                     s.is_setup = True
diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py
index 8eb501e281d..331e17168ba 100644
--- a/airflow/sensors/external_task.py
+++ b/airflow/sensors/external_task.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 import datetime
 import os
 import warnings
-from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowSkipException
@@ -476,7 +476,7 @@ class ExternalTaskMarker(EmptyOperator):
     operator_extra_links = [ExternalDagLink()]
 
     # The _serialized_fields are lazily loaded when get_serialized_fields() 
method is called
-    __serialized_fields: frozenset[str] | None = None
+    __serialized_fields: ClassVar[frozenset[str] | None] = None
 
     def __init__(
         self,
diff --git a/scripts/ci/pre_commit/sync_init_decorator.py 
b/scripts/ci/pre_commit/sync_init_decorator.py
index 13e80d62c6c..7b02136ead3 100755
--- a/scripts/ci/pre_commit/sync_init_decorator.py
+++ b/scripts/ci/pre_commit/sync_init_decorator.py
@@ -116,7 +116,13 @@ def _expr_to_ast_dump(expr: str) -> str:
 
 
 ALLOWABLE_TYPE_ANNOTATIONS = {
-    _expr_to_ast_dump("Collection[str] | None"): 
_expr_to_ast_dump("MutableSet[str]")
+    # Mapping of allowable Decorator type -> Class attribute type
+    _expr_to_ast_dump("Collection[str] | None"): 
_expr_to_ast_dump("MutableSet[str]"),
+    _expr_to_ast_dump("ParamsDict | dict[str, Any] | None"): 
_expr_to_ast_dump("ParamsDict"),
+    # TODO: This one is legacy access control. Remove it in 3.0. 
RemovedInAirflow3Warning
+    _expr_to_ast_dump(
+        "dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | 
None"
+    ): _expr_to_ast_dump("dict[str, dict[str, Collection[str]]] | None"),
 }
 
 
diff --git a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py 
b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
index bb5ddf88e23..5285bd97ef4 100644
--- a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
@@ -34,18 +34,12 @@ from airflow.sdk.definitions.node import DAGNode
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.weight_rule import WeightRule
 
-# TaskStateChangeCallback = Callable[[Context], None]
-
 if TYPE_CHECKING:
     from airflow.models.baseoperatorlink import BaseOperatorLink
     from airflow.models.operator import Operator
     from airflow.sdk.definitions.baseoperator import BaseOperator
     from airflow.sdk.definitions.dag import DAG
 
-    # TODO: Task-SDK
-    Context = dict[str, Any]
-
-
 DEFAULT_OWNER: str = "airflow"
 DEFAULT_POOL_SLOTS: int = 1
 DEFAULT_PRIORITY_WEIGHT: int = 1
diff --git a/task_sdk/src/airflow/sdk/definitions/contextmanager.py 
b/task_sdk/src/airflow/sdk/definitions/contextmanager.py
index 8b5458c65b9..ac50dcadbfc 100644
--- a/task_sdk/src/airflow/sdk/definitions/contextmanager.py
+++ b/task_sdk/src/airflow/sdk/definitions/contextmanager.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 import sys
 from collections import deque
 from types import ModuleType
-from typing import Any, Generic, Optional, TypeVar, cast
+from typing import Any, Generic, TypeVar
 
 from airflow.sdk.definitions.dag import DAG
 from airflow.sdk.definitions.taskgroup import TaskGroup
@@ -109,7 +109,7 @@ class DagContext(ContextStack[DAG]):
 
     @classmethod
     def get_current_dag(cls) -> DAG | None:
-        return cast(Optional[DAG], cls.get_current())
+        return cls.get_current()
 
 
 class TaskGroupContext(ContextStack[TaskGroup]):
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py 
b/task_sdk/src/airflow/sdk/definitions/dag.py
index 111f51ce855..a8f222fd8ad 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -63,6 +63,7 @@ from airflow.timetables.simple import (
     NullTimetable,
     OnceTimetable,
 )
+from airflow.utils.context import Context
 from airflow.utils.dag_cycle_tester import check_cycle
 from airflow.utils.decorators import fixup_decorator_warning_stack
 from airflow.utils.trigger_rule import TriggerRule
@@ -88,10 +89,6 @@ __all__ = [
 ]
 
 
-# TODO: Task-SDK
-class Context: ...
-
-
 DagStateChangeCallback = Callable[[Context], None]
 ScheduleInterval = Union[None, str, timedelta, relativedelta]
 
@@ -341,7 +338,8 @@ class DAG:
     default_args: dict[str, Any] = attrs.field(
         factory=dict, validator=attrs.validators.instance_of(dict), 
converter=dict_copy
     )
-    start_date: datetime | None = attrs.field()
+    start_date: datetime | None = attrs.field()  # type: ignore[misc]  # mypy 
doesn't grok the `@dag.default` seemingly
+
     end_date: datetime | None = None
     timezone: FixedTimezone | Timezone = attrs.field(init=False)
     schedule: ScheduleArg = attrs.field(default=None, 
on_setattr=attrs.setters.frozen)
@@ -373,7 +371,7 @@ class DAG:
         default=None,
         converter=attrs.Converter(_convert_params, takes_self=True),  # type: 
ignore[misc, call-overload]
     )
-    access_control: dict[str, dict[str, Collection[str]]] | dict[str, 
Collection[str]] | None = attrs.field(
+    access_control: dict[str, dict[str, Collection[str]]] | None = attrs.field(
         default=None,
         converter=attrs.Converter(_convert_access_control, takes_self=True),  
# type: ignore[misc, call-overload]
     )
@@ -384,11 +382,11 @@ class DAG:
     owner_links: dict[str, str] = attrs.field(factory=dict)
     auto_register: bool = attrs.field(default=True, converter=bool)
     fail_stop: bool = attrs.field(default=False, converter=bool)
-    dag_display_name: str = 
attrs.field(validator=attrs.validators.instance_of(str))
+    dag_display_name: str = 
attrs.field(validator=attrs.validators.instance_of(str))  # type: ignore[misc]  
# mypy doesn't grok the `@dag.default` seemingly
 
     task_dict: dict[str, Operator] = attrs.field(factory=dict, init=False)
 
-    task_group: TaskGroup = attrs.field(on_setattr=attrs.setters.frozen)
+    task_group: TaskGroup = attrs.field(on_setattr=attrs.setters.frozen)  # 
type: ignore[misc]  # mypy doesn't grok the `@dag.default` seemingly
 
     fileloc: str = attrs.field(init=False)
     partial: bool = attrs.field(init=False, default=False)
@@ -1036,7 +1034,7 @@ if TYPE_CHECKING:
         on_success_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None,
         on_failure_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None,
         doc_md: str | None = None,
-        params: ParamsDict | None = None,
+        params: ParamsDict | dict[str, Any] | None = None,
         access_control: dict[str, dict[str, Collection[str]]] | dict[str, 
Collection[str]] | None = None,
         is_paused_upon_creation: bool | None = None,
         jinja_environment_kwargs: dict | None = None,
diff --git a/task_sdk/src/airflow/sdk/definitions/mixins.py 
b/task_sdk/src/airflow/sdk/definitions/mixins.py
index e9d6e162927..de63772615d 100644
--- a/task_sdk/src/airflow/sdk/definitions/mixins.py
+++ b/task_sdk/src/airflow/sdk/definitions/mixins.py
@@ -21,8 +21,6 @@ from abc import abstractmethod
 from collections.abc import Iterable, Sequence
 from typing import TYPE_CHECKING, Any
 
-from airflow.sdk.types import NOTSET, ArgNotSet
-
 if TYPE_CHECKING:
     from airflow.sdk.definitions.baseoperator import BaseOperator
     from airflow.sdk.definitions.edges import EdgeModifier
@@ -72,8 +70,8 @@ class DependencyMixin:
     def as_teardown(
         self,
         *,
-        setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET,
-        on_failure_fail_dagrun: bool | ArgNotSet = NOTSET,
+        setups: BaseOperator | Iterable[BaseOperator] | None = None,
+        on_failure_fail_dagrun: bool | None = None,
     ) -> DependencyMixin:
         """Mark a task as teardown and set its setups as direct relatives."""
         raise NotImplementedError()
diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py 
b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
index e417eab2760..54602961cb8 100644
--- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -63,6 +63,12 @@ TASKGROUP_ARGS_EXPECTED_TYPES = {
 }
 
 
+def _default_parent_group() -> TaskGroup | None:
+    from airflow.sdk.definitions.contextmanager import TaskGroupContext
+
+    return TaskGroupContext.get_current()
+
+
 @attrs.define(repr=False)
 class TaskGroup(DAGNode):
     """
@@ -96,8 +102,8 @@ class TaskGroup(DAGNode):
 
     _group_id: str | None
     prefix_group_id: bool = True
-    parent_group: TaskGroup | None = attrs.field()
-    dag: DAG = attrs.field()
+    parent_group: TaskGroup | None = attrs.field(factory=_default_parent_group)
+    dag: DAG = attrs.field()  # type: ignore[misc]  # mypy doesn't grok the 
`@dag.default` seemingly
     default_args: dict[str, Any] = attrs.field(factory=dict, 
converter=copy.deepcopy)
     tooltip: str = ""
     children: dict[str, DAGNode] = attrs.field(factory=dict, init=False)
@@ -114,12 +120,6 @@ class TaskGroup(DAGNode):
 
     add_suffix_on_collision: bool = False
 
-    @parent_group.default
-    def _default_parent_group(self):
-        from airflow.sdk.definitions.contextmanager import TaskGroupContext
-
-        return TaskGroupContext.get_current()
-
     @dag.default
     def _default_dag(self):
         from airflow.sdk.definitions.contextmanager import DagContext
@@ -247,7 +247,7 @@ class TaskGroup(DAGNode):
     @property
     def group_id(self) -> str | None:
         """group_id of this TaskGroup."""
-        if self.parent_group and self.parent_group.prefix_group_id and 
self.parent_group.group_id:
+        if self.parent_group and self.parent_group.prefix_group_id and 
self.parent_group.node_id:
             # defer to parent whether it adds a prefix
             return self.parent_group.child_id(self.group_id)
 

Reply via email to