This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c8af0592c0 Improve taskflow type hints with ParamSpec (#25173)
c8af0592c0 is described below

commit c8af0592c08017ee48f69f608ad4a6529ee14292
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Tue Jul 26 17:13:06 2022 +0800

    Improve taskflow type hints with ParamSpec (#25173)
---
 airflow/decorators/__init__.pyi                    | 12 ++--
 airflow/decorators/base.py                         | 78 +++++++++++++---------
 .../cncf/kubernetes/operators/kubernetes_pod.py    |  2 +-
 airflow/providers/dbt/cloud/hooks/dbt.py           |  2 +-
 airflow/providers/google/cloud/operators/gcs.py    | 28 ++++----
 airflow/providers/qubole/hooks/qubole.py           |  3 +-
 airflow/providers/salesforce/operators/bulk.py     |  2 +-
 airflow/typing_compat.py                           | 32 +++++++--
 8 files changed, 96 insertions(+), 63 deletions(-)

diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi
index 726798f134..35a56e231e 100644
--- a/airflow/decorators/__init__.pyi
+++ b/airflow/decorators/__init__.pyi
@@ -20,9 +20,9 @@
 # necessarily exist at run time. See "Creating Custom @task Decorators"
 # documentation for more details.
 
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Union, 
overload
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, 
Union, overload
 
-from airflow.decorators.base import Function, Task, TaskDecorator
+from airflow.decorators.base import FParams, FReturn, Task, TaskDecorator
 from airflow.decorators.branch_python import branch_task
 from airflow.decorators.python import python_task
 from airflow.decorators.python_virtualenv import virtualenv_task
@@ -68,7 +68,7 @@ class TaskDecoratorCollection:
         """
     # [START mixin_for_typing]
     @overload
-    def python(self, python_callable: Function) -> Task[Function]: ...
+    def python(self, python_callable: Callable[FParams, FReturn]) -> 
Task[FParams, FReturn]: ...
     # [END mixin_for_typing]
     @overload
     def __call__(
@@ -81,7 +81,7 @@ class TaskDecoratorCollection:
     ) -> TaskDecorator:
         """Aliasing ``python``; signature should match exactly."""
     @overload
-    def __call__(self, python_callable: Function) -> Task[Function]:
+    def __call__(self, python_callable: Callable[FParams, FReturn]) -> 
Task[FParams, FReturn]:
         """Aliasing ``python``; signature should match exactly."""
     @overload
     def virtualenv(
@@ -122,7 +122,7 @@ class TaskDecoratorCollection:
             such as transmission a large amount of XCom to TaskAPI.
         """
     @overload
-    def virtualenv(self, python_callable: Function) -> Task[Function]: ...
+    def virtualenv(self, python_callable: Callable[FParams, FReturn]) -> 
Task[FParams, FReturn]: ...
     @overload
     def branch(self, *, multiple_outputs: Optional[bool] = None, **kwargs) -> 
TaskDecorator:
         """Create a decorator to wrap the decorated callable into a 
BranchPythonOperator.
@@ -134,7 +134,7 @@ class TaskDecoratorCollection:
             Dict will unroll to XCom values with keys as XCom keys. Defaults 
to False.
         """
     @overload
-    def branch(self, python_callable: Function) -> Task[Function]: ...
+    def branch(self, python_callable: Callable[FParams, FReturn]) -> 
Task[FParams, FReturn]: ...
     # [START decorator_signature]
     def docker(
         self,
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index a36d7d5d43..0a4b75cece 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import functools
 import inspect
 import re
 from typing import (
@@ -68,7 +67,7 @@ from airflow.models.mappedoperator import (
 )
 from airflow.models.pool import Pool
 from airflow.models.xcom_arg import XComArg
-from airflow.typing_compat import Protocol
+from airflow.typing_compat import ParamSpec, Protocol
 from airflow.utils import timezone
 from airflow.utils.context import KNOWN_CONTEXT_KEYS, Context
 from airflow.utils.task_group import TaskGroup, TaskGroupContext
@@ -236,13 +235,15 @@ class DecoratedOperator(BaseOperator):
         return args, kwargs
 
 
-Function = TypeVar("Function", bound=Callable)
+FParams = ParamSpec("FParams")
+
+FReturn = TypeVar("FReturn")
 
 OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")
 
 
 @attr.define(slots=False)
-class _TaskDecorator(Generic[Function, OperatorSubclass]):
+class _TaskDecorator(Generic[FParams, FReturn, OperatorSubclass]):
     """
     Helper class for providing dynamic task mapping to decorated functions.
 
@@ -251,7 +252,7 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
     :meta private:
     """
 
-    function: Function = attr.ib()
+    function: Callable[FParams, FReturn] = attr.ib()
     operator_class: Type[OperatorSubclass]
     multiple_outputs: bool = attr.ib()
     kwargs: Dict[str, Any] = attr.ib(factory=dict)
@@ -272,7 +273,7 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
             raise TypeError(f"@{self.decorator_name} does not support methods")
         self.kwargs.setdefault('task_id', self.function.__name__)
 
-    def __call__(self, *args, **kwargs) -> XComArg:
+    def __call__(self, *args: "FParams.args", **kwargs: "FParams.kwargs") -> 
XComArg:
         op = self.operator_class(
             python_callable=self.function,
             op_args=args,
@@ -285,7 +286,7 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
         return XComArg(op)
 
     @property
-    def __wrapped__(self) -> Function:
+    def __wrapped__(self) -> Callable[FParams, FReturn]:
         return self.function
 
     @cached_property
@@ -337,9 +338,7 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
         # to False to skip the checks on execution.
         return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)
 
-    def expand_kwargs(self, kwargs: "XComArg", *, strict: bool = True) -> 
XComArg:
-        from airflow.models.xcom_arg import XComArg
-
+    def expand_kwargs(self, kwargs: XComArg, *, strict: bool = True) -> 
XComArg:
         if not isinstance(kwargs, XComArg):
             raise TypeError(f"expected XComArg object, not 
{type(kwargs).__name__}")
         return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)
@@ -420,14 +419,14 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
         )
         return XComArg(operator=operator)
 
-    def partial(self, **kwargs: Any) -> "_TaskDecorator[Function, 
OperatorSubclass]":
+    def partial(self, **kwargs: Any) -> "_TaskDecorator[FParams, FReturn, 
OperatorSubclass]":
         self._validate_arg_names("partial", kwargs)
         old_kwargs = self.kwargs.get("op_kwargs", {})
         prevent_duplicates(old_kwargs, kwargs, fail_reason="duplicate partial")
         kwargs.update(old_kwargs)
         return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": kwargs})
 
-    def override(self, **kwargs: Any) -> "_TaskDecorator[Function, 
OperatorSubclass]":
+    def override(self, **kwargs: Any) -> "_TaskDecorator[FParams, FReturn, 
OperatorSubclass]":
         return attr.evolve(self, kwargs={**self.kwargs, **kwargs})
 
 
@@ -506,7 +505,7 @@ class DecoratedMappedOperator(MappedOperator):
         return {k: _render_if_not_already_resolved(k, v) for k, v in 
value.items()}
 
 
-class Task(Generic[Function]):
+class Task(Generic[FParams, FReturn]):
     """Declaration of a @task-decorated callable for type-checking.
 
     An instance of this type inherits the call signature of the decorated
@@ -517,18 +516,21 @@ class Task(Generic[Function]):
     This type is implemented by ``_TaskDecorator`` at runtime.
     """
 
-    __call__: Function
+    __call__: Callable[FParams, XComArg]
 
-    function: Function
+    function: Callable[FParams, FReturn]
 
     @property
-    def __wrapped__(self) -> Function:
+    def __wrapped__(self) -> Callable[FParams, FReturn]:
+        ...
+
+    def partial(self, **kwargs: Any) -> "Task[FParams, FReturn]":
         ...
 
     def expand(self, **kwargs: "Mappable") -> XComArg:
         ...
 
-    def partial(self, **kwargs: Any) -> "Task[Function]":
+    def expand_kwargs(self, kwargs: XComArg, *, strict: bool = True) -> 
XComArg:
         ...
 
 
@@ -536,7 +538,10 @@ class TaskDecorator(Protocol):
     """Type declaration for ``task_decorator_factory`` return type."""
 
     @overload
-    def __call__(self, python_callable: Function) -> Task[Function]:
+    def __call__(  # type: ignore[misc]
+        self,
+        python_callable: Callable[FParams, FReturn],
+    ) -> Task[FParams, FReturn]:
         """For the "bare decorator" ``@task`` case."""
 
     @overload
@@ -545,7 +550,7 @@ class TaskDecorator(Protocol):
         *,
         multiple_outputs: Optional[bool] = None,
         **kwargs: Any,
-    ) -> Callable[[Function], Task[Function]]:
+    ) -> Callable[[Callable[FParams, FReturn]], Task[FParams, FReturn]]:
         """For the decorator factory ``@task()`` case."""
 
 
@@ -556,16 +561,20 @@ def task_decorator_factory(
     decorated_operator_class: Type[BaseOperator],
     **kwargs,
 ) -> TaskDecorator:
-    """
-    A factory that generates a wrapper that wraps a function into an Airflow 
operator.
-    Accepts kwargs for operator kwarg. Can be reused in a single DAG.
+    """Generate a wrapper that wraps a function into an Airflow operator.
 
-    :param python_callable: Function to decorate
-    :param multiple_outputs: If set to True, the decorated function's return 
value will be unrolled to
-        multiple XCom values. Dict will unroll to XCom values with its keys as 
XCom keys. Defaults to False.
-    :param decorated_operator_class: The operator that executes the logic 
needed to run the python function in
-        the correct environment
+    Can be reused in a single DAG.
 
+    :param python_callable: Function to decorate.
+    :param multiple_outputs: If set to True, the decorated function's return
+        value will be unrolled to multiple XCom values. Dict will unroll to 
XCom
+        values with its keys as XCom keys. If set to False (default), only at
+        most one XCom value is pushed.
+    :param decorated_operator_class: The operator that executes the logic 
needed
+        to run the python function in the correct environment.
+
+    Other kwargs are directly forwarded to the underlying operator class when
+    it's instantiated.
     """
     if multiple_outputs is None:
         multiple_outputs = cast(bool, attr.NOTHING)
@@ -579,10 +588,13 @@ def task_decorator_factory(
         return cast(TaskDecorator, decorator)
     elif python_callable is not None:
         raise TypeError('No args allowed while using @task, use kwargs 
instead')
-    decorator_factory = functools.partial(
-        _TaskDecorator,
-        multiple_outputs=multiple_outputs,
-        operator_class=decorated_operator_class,
-        kwargs=kwargs,
-    )
+
+    def decorator_factory(python_callable):
+        return _TaskDecorator(
+            function=python_callable,
+            multiple_outputs=multiple_outputs,
+            operator_class=decorated_operator_class,
+            kwargs=kwargs,
+        )
+
     return cast(TaskDecorator, decorator_factory)
diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py 
b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
index 90d4fb88de..ef2ebed6ae 100644
--- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
@@ -328,7 +328,7 @@ class KubernetesPodOperator(BaseOperator):
         if include_try_number:
             labels.update(try_number=ti.try_number)
         # In the case of sub dags this is just useful
-        if context['dag'].is_subdag:
+        if context['dag'].parent_dag:
             labels['parent_dag_id'] = context['dag'].parent_dag.dag_id
         # Ensure that label is valid for Kube,
         # and if not truncate/remove invalid chars and replace with short hash.
diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py 
b/airflow/providers/dbt/cloud/hooks/dbt.py
index 07a327a045..fd370502a5 100644
--- a/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -87,7 +87,7 @@ class TokenAuth(AuthBase):
 class JobRunInfo(TypedDict):
     """Type class for the ``job_run_info`` dictionary."""
 
-    account_id: int
+    account_id: Optional[int]
     run_id: int
 
 
diff --git a/airflow/providers/google/cloud/operators/gcs.py 
b/airflow/providers/google/cloud/operators/gcs.py
index 7c58ee7c5a..77a1ab656e 100644
--- a/airflow/providers/google/cloud/operators/gcs.py
+++ b/airflow/providers/google/cloud/operators/gcs.py
@@ -30,7 +30,6 @@ if TYPE_CHECKING:
 
 from google.api_core.exceptions import Conflict
 from google.cloud.exceptions import GoogleCloudError
-from pendulum.datetime import DateTime
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
@@ -723,22 +722,25 @@ class GCSTimeSpanFileTransformOperator(BaseOperator):
     def execute(self, context: "Context") -> List[str]:
         # Define intervals and prefixes.
         try:
-            timespan_start = context["data_interval_start"]
-            timespan_end = context["data_interval_end"]
+            orig_start = context["data_interval_start"]
+            orig_end = context["data_interval_end"]
         except KeyError:
-            timespan_start = pendulum.instance(context["execution_date"])
+            orig_start = pendulum.instance(context["execution_date"])
             following_execution_date = 
context["dag"].following_schedule(context["execution_date"])
             if following_execution_date is None:
-                timespan_end = None
+                orig_end = None
             else:
-                timespan_end = pendulum.instance(following_execution_date)
-
-        if timespan_end is None:  # Only possible in Airflow before 2.2.
-            self.log.warning("No following schedule found, setting timespan 
end to max %s", timespan_end)
-            timespan_end = DateTime.max
-        elif timespan_start >= timespan_end:  # Airflow 2.2 sets start == end 
for non-perodic schedules.
-            self.log.warning("DAG schedule not periodic, setting timespan end 
to max %s", timespan_end)
-            timespan_end = DateTime.max
+                orig_end = pendulum.instance(following_execution_date)
+
+        timespan_start = orig_start
+        if orig_end is None:  # Only possible in Airflow before 2.2.
+            self.log.warning("No following schedule found, setting timespan 
end to max %s", orig_end)
+            timespan_end = pendulum.instance(datetime.datetime.max)
+        elif orig_start >= orig_end:  # Airflow 2.2 sets start == end for 
non-perodic schedules.
+            self.log.warning("DAG schedule not periodic, setting timespan end 
to max %s", orig_end)
+            timespan_end = pendulum.instance(datetime.datetime.max)
+        else:
+            timespan_end = orig_end
 
         timespan_start = timespan_start.in_timezone(timezone.utc)
         timespan_end = timespan_end.in_timezone(timezone.utc)
diff --git a/airflow/providers/qubole/hooks/qubole.py 
b/airflow/providers/qubole/hooks/qubole.py
index 3b0d4bdd1a..340cf4fe13 100644
--- a/airflow/providers/qubole/hooks/qubole.py
+++ b/airflow/providers/qubole/hooks/qubole.py
@@ -46,6 +46,7 @@ from airflow.hooks.base import BaseHook
 from airflow.utils.state import State
 
 if TYPE_CHECKING:
+    from airflow.models.taskinstance import TaskInstance
     from airflow.utils.context import Context
 
 
@@ -139,7 +140,7 @@ class QuboleHook(BaseHook):
         self.kwargs = kwargs
         self.cls = COMMAND_CLASSES[self.kwargs['command_type']]
         self.cmd: Optional[Command] = None
-        self.task_instance = None
+        self.task_instance: Optional["TaskInstance"] = None
 
     @staticmethod
     def handle_failure_retry(context) -> None:
diff --git a/airflow/providers/salesforce/operators/bulk.py 
b/airflow/providers/salesforce/operators/bulk.py
index 110ed685ea..39de722032 100644
--- a/airflow/providers/salesforce/operators/bulk.py
+++ b/airflow/providers/salesforce/operators/bulk.py
@@ -47,7 +47,7 @@ class SalesforceBulkOperator(BaseOperator):
     def __init__(
         self,
         *,
-        operation: Literal[available_operations],
+        operation: Literal['insert', 'update', 'upsert', 'delete', 
'hard_delete'],
         object_name: str,
         payload: list,
         external_id_field: str = 'Id',
diff --git a/airflow/typing_compat.py b/airflow/typing_compat.py
index 163889b8a2..ec1846438f 100644
--- a/airflow/typing_compat.py
+++ b/airflow/typing_compat.py
@@ -21,10 +21,28 @@ This module provides helper code to make type annotation 
within Airflow
 codebase easier.
 """
 
-try:
-    # Literal, Protocol and TypedDict are only added to typing module starting 
from
-    # python 3.8 we can safely remove this shim import after Airflow drops
-    # support for <3.8
-    from typing import Literal, Protocol, TypedDict, runtime_checkable  # 
type: ignore
-except ImportError:
-    from typing_extensions import Literal, Protocol, TypedDict, 
runtime_checkable  # type: ignore # noqa
+__all__ = [
+    "Literal",
+    "ParamSpec",
+    "Protocol",
+    "TypedDict",
+    "runtime_checkable",
+]
+
+import sys
+
+if sys.version_info >= (3, 8):
+    from typing import Protocol, TypedDict, runtime_checkable
+else:
+    from typing_extensions import Protocol, TypedDict, runtime_checkable
+
+# Literal in 3.8 is limited to one single argument, not e.g. "Literal[1, 2]".
+if sys.version_info >= (3, 9):
+    from typing import Literal
+else:
+    from typing_extensions import Literal
+
+if sys.version_info >= (3, 10):
+    from typing import ParamSpec
+else:
+    from typing_extensions import ParamSpec

Reply via email to