This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c8af0592c0 Improve taskflow type hints with ParamSpec (#25173)
c8af0592c0 is described below
commit c8af0592c08017ee48f69f608ad4a6529ee14292
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Tue Jul 26 17:13:06 2022 +0800
Improve taskflow type hints with ParamSpec (#25173)
---
airflow/decorators/__init__.pyi | 12 ++--
airflow/decorators/base.py | 78 +++++++++++++---------
.../cncf/kubernetes/operators/kubernetes_pod.py | 2 +-
airflow/providers/dbt/cloud/hooks/dbt.py | 2 +-
airflow/providers/google/cloud/operators/gcs.py | 28 ++++----
airflow/providers/qubole/hooks/qubole.py | 3 +-
airflow/providers/salesforce/operators/bulk.py | 2 +-
airflow/typing_compat.py | 32 +++++++--
8 files changed, 96 insertions(+), 63 deletions(-)
diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi
index 726798f134..35a56e231e 100644
--- a/airflow/decorators/__init__.pyi
+++ b/airflow/decorators/__init__.pyi
@@ -20,9 +20,9 @@
# necessarily exist at run time. See "Creating Custom @task Decorators"
# documentation for more details.
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Union, overload
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union, overload
-from airflow.decorators.base import Function, Task, TaskDecorator
+from airflow.decorators.base import FParams, FReturn, Task, TaskDecorator
from airflow.decorators.branch_python import branch_task
from airflow.decorators.python import python_task
from airflow.decorators.python_virtualenv import virtualenv_task
@@ -68,7 +68,7 @@ class TaskDecoratorCollection:
"""
# [START mixin_for_typing]
@overload
- def python(self, python_callable: Function) -> Task[Function]: ...
+ def python(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
# [END mixin_for_typing]
@overload
def __call__(
@@ -81,7 +81,7 @@ class TaskDecoratorCollection:
) -> TaskDecorator:
"""Aliasing ``python``; signature should match exactly."""
@overload
- def __call__(self, python_callable: Function) -> Task[Function]:
+ def __call__(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]:
"""Aliasing ``python``; signature should match exactly."""
@overload
def virtualenv(
@@ -122,7 +122,7 @@ class TaskDecoratorCollection:
such as transmission a large amount of XCom to TaskAPI.
"""
@overload
- def virtualenv(self, python_callable: Function) -> Task[Function]: ...
+ def virtualenv(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
@overload
def branch(self, *, multiple_outputs: Optional[bool] = None, **kwargs) -> TaskDecorator:
"""Create a decorator to wrap the decorated callable into a
BranchPythonOperator.
@@ -134,7 +134,7 @@ class TaskDecoratorCollection:
Dict will unroll to XCom values with keys as XCom keys. Defaults to False.
"""
@overload
- def branch(self, python_callable: Function) -> Task[Function]: ...
+ def branch(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
# [START decorator_signature]
def docker(
self,
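For context, a minimal sketch (not part of this commit) of what the new stub overloads buy: with Task[FParams, FReturn], a type checker carries the decorated function's parameters through @task.python instead of collapsing them to a bare callable.

from airflow.decorators import task

@task.python
def add(x: int, y: int) -> int:
    return x + y

add(1, 2)      # accepted; inside a DAG this evaluates to an XComArg at runtime
add("1", 2)    # now flagged by the type checker instead of passing silently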
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index a36d7d5d43..0a4b75cece 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-import functools
import inspect
import re
from typing import (
@@ -68,7 +67,7 @@ from airflow.models.mappedoperator import (
)
from airflow.models.pool import Pool
from airflow.models.xcom_arg import XComArg
-from airflow.typing_compat import Protocol
+from airflow.typing_compat import ParamSpec, Protocol
from airflow.utils import timezone
from airflow.utils.context import KNOWN_CONTEXT_KEYS, Context
from airflow.utils.task_group import TaskGroup, TaskGroupContext
@@ -236,13 +235,15 @@ class DecoratedOperator(BaseOperator):
return args, kwargs
-Function = TypeVar("Function", bound=Callable)
+FParams = ParamSpec("FParams")
+
+FReturn = TypeVar("FReturn")
OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")
@attr.define(slots=False)
-class _TaskDecorator(Generic[Function, OperatorSubclass]):
+class _TaskDecorator(Generic[FParams, FReturn, OperatorSubclass]):
"""
Helper class for providing dynamic task mapping to decorated functions.
@@ -251,7 +252,7 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
:meta private:
"""
- function: Function = attr.ib()
+ function: Callable[FParams, FReturn] = attr.ib()
operator_class: Type[OperatorSubclass]
multiple_outputs: bool = attr.ib()
kwargs: Dict[str, Any] = attr.ib(factory=dict)
@@ -272,7 +273,7 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
raise TypeError(f"@{self.decorator_name} does not support methods")
self.kwargs.setdefault('task_id', self.function.__name__)
- def __call__(self, *args, **kwargs) -> XComArg:
+ def __call__(self, *args: "FParams.args", **kwargs: "FParams.kwargs") ->
XComArg:
op = self.operator_class(
python_callable=self.function,
op_args=args,
@@ -285,7 +286,7 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
return XComArg(op)
@property
- def __wrapped__(self) -> Function:
+ def __wrapped__(self) -> Callable[FParams, FReturn]:
return self.function
@cached_property
@@ -337,9 +338,7 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
# to False to skip the checks on execution.
return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)
- def expand_kwargs(self, kwargs: "XComArg", *, strict: bool = True) ->
XComArg:
- from airflow.models.xcom_arg import XComArg
-
+ def expand_kwargs(self, kwargs: XComArg, *, strict: bool = True) -> XComArg:
if not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg object, not
{type(kwargs).__name__}")
return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)
@@ -420,14 +419,14 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
)
return XComArg(operator=operator)
- def partial(self, **kwargs: Any) -> "_TaskDecorator[Function, OperatorSubclass]":
+ def partial(self, **kwargs: Any) -> "_TaskDecorator[FParams, FReturn, OperatorSubclass]":
self._validate_arg_names("partial", kwargs)
old_kwargs = self.kwargs.get("op_kwargs", {})
prevent_duplicates(old_kwargs, kwargs, fail_reason="duplicate partial")
kwargs.update(old_kwargs)
return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": kwargs})
- def override(self, **kwargs: Any) -> "_TaskDecorator[Function, OperatorSubclass]":
+ def override(self, **kwargs: Any) -> "_TaskDecorator[FParams, FReturn, OperatorSubclass]":
return attr.evolve(self, kwargs={**self.kwargs, **kwargs})
@@ -506,7 +505,7 @@ class DecoratedMappedOperator(MappedOperator):
return {k: _render_if_not_already_resolved(k, v) for k, v in value.items()}
-class Task(Generic[Function]):
+class Task(Generic[FParams, FReturn]):
"""Declaration of a @task-decorated callable for type-checking.
An instance of this type inherits the call signature of the decorated
@@ -517,18 +516,21 @@ class Task(Generic[Function]):
This type is implemented by ``_TaskDecorator`` at runtime.
"""
- __call__: Function
+ __call__: Callable[FParams, XComArg]
- function: Function
+ function: Callable[FParams, FReturn]
@property
- def __wrapped__(self) -> Function:
+ def __wrapped__(self) -> Callable[FParams, FReturn]:
+ ...
+
+ def partial(self, **kwargs: Any) -> "Task[FParams, FReturn]":
...
def expand(self, **kwargs: "Mappable") -> XComArg:
...
- def partial(self, **kwargs: Any) -> "Task[Function]":
+ def expand_kwargs(self, kwargs: XComArg, *, strict: bool = True) -> XComArg:
...
@@ -536,7 +538,10 @@ class TaskDecorator(Protocol):
"""Type declaration for ``task_decorator_factory`` return type."""
@overload
- def __call__(self, python_callable: Function) -> Task[Function]:
+ def __call__( # type: ignore[misc]
+ self,
+ python_callable: Callable[FParams, FReturn],
+ ) -> Task[FParams, FReturn]:
"""For the "bare decorator" ``@task`` case."""
@overload
@@ -545,7 +550,7 @@ class TaskDecorator(Protocol):
*,
multiple_outputs: Optional[bool] = None,
**kwargs: Any,
- ) -> Callable[[Function], Task[Function]]:
+ ) -> Callable[[Callable[FParams, FReturn]], Task[FParams, FReturn]]:
"""For the decorator factory ``@task()`` case."""
@@ -556,16 +561,20 @@ def task_decorator_factory(
decorated_operator_class: Type[BaseOperator],
**kwargs,
) -> TaskDecorator:
- """
- A factory that generates a wrapper that wraps a function into an Airflow operator.
- Accepts kwargs for operator kwarg. Can be reused in a single DAG.
+ """Generate a wrapper that wraps a function into an Airflow operator.
- :param python_callable: Function to decorate
- :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to
- multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False.
- :param decorated_operator_class: The operator that executes the logic needed to run the python function in
- the correct environment
+ Can be reused in a single DAG.
+ :param python_callable: Function to decorate.
+ :param multiple_outputs: If set to True, the decorated function's return
+ value will be unrolled to multiple XCom values. Dict will unroll to XCom
+ values with its keys as XCom keys. If set to False (default), only at
+ most one XCom value is pushed.
+ :param decorated_operator_class: The operator that executes the logic needed
+ to run the python function in the correct environment.
+
+ Other kwargs are directly forwarded to the underlying operator class when
+ it's instantiated.
"""
if multiple_outputs is None:
multiple_outputs = cast(bool, attr.NOTHING)
@@ -579,10 +588,13 @@ def task_decorator_factory(
return cast(TaskDecorator, decorator)
elif python_callable is not None:
raise TypeError('No args allowed while using @task, use kwargs instead')
- decorator_factory = functools.partial(
- _TaskDecorator,
- multiple_outputs=multiple_outputs,
- operator_class=decorated_operator_class,
- kwargs=kwargs,
- )
+
+ def decorator_factory(python_callable):
+ return _TaskDecorator(
+ function=python_callable,
+ multiple_outputs=multiple_outputs,
+ operator_class=decorated_operator_class,
+ kwargs=kwargs,
+ )
+
return cast(TaskDecorator, decorator_factory)
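As an aside, the ParamSpec pattern adopted above can be sketched standalone roughly as follows; Wrapper and double are illustrative names only, not Airflow API. The ParamSpec ties the *args/**kwargs of __call__ to the wrapped callable's own signature.

from typing import Callable, Generic, TypeVar

from typing_extensions import ParamSpec  # typing.ParamSpec on Python 3.10+

P = ParamSpec("P")
R = TypeVar("R")

class Wrapper(Generic[P, R]):
    def __init__(self, fn: Callable[P, R]) -> None:
        self.fn = fn

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
        # Callers of __call__ are checked against fn's own parameters.
        return self.fn(*args, **kwargs)

def double(x: int) -> int:
    return x * 2

w = Wrapper(double)
w(3)      # OK
w("3")    # rejected by a type checker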
diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
index 90d4fb88de..ef2ebed6ae 100644
--- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
@@ -328,7 +328,7 @@ class KubernetesPodOperator(BaseOperator):
if include_try_number:
labels.update(try_number=ti.try_number)
# In the case of sub dags this is just useful
- if context['dag'].is_subdag:
+ if context['dag'].parent_dag:
labels['parent_dag_id'] = context['dag'].parent_dag.dag_id
# Ensure that label is valid for Kube,
# and if not truncate/remove invalid chars and replace with short hash.
diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py b/airflow/providers/dbt/cloud/hooks/dbt.py
index 07a327a045..fd370502a5 100644
--- a/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -87,7 +87,7 @@ class TokenAuth(AuthBase):
class JobRunInfo(TypedDict):
"""Type class for the ``job_run_info`` dictionary."""
- account_id: int
+ account_id: Optional[int]
run_id: int
diff --git a/airflow/providers/google/cloud/operators/gcs.py b/airflow/providers/google/cloud/operators/gcs.py
index 7c58ee7c5a..77a1ab656e 100644
--- a/airflow/providers/google/cloud/operators/gcs.py
+++ b/airflow/providers/google/cloud/operators/gcs.py
@@ -30,7 +30,6 @@ if TYPE_CHECKING:
from google.api_core.exceptions import Conflict
from google.cloud.exceptions import GoogleCloudError
-from pendulum.datetime import DateTime
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -723,22 +722,25 @@ class GCSTimeSpanFileTransformOperator(BaseOperator):
def execute(self, context: "Context") -> List[str]:
# Define intervals and prefixes.
try:
- timespan_start = context["data_interval_start"]
- timespan_end = context["data_interval_end"]
+ orig_start = context["data_interval_start"]
+ orig_end = context["data_interval_end"]
except KeyError:
- timespan_start = pendulum.instance(context["execution_date"])
+ orig_start = pendulum.instance(context["execution_date"])
following_execution_date = context["dag"].following_schedule(context["execution_date"])
if following_execution_date is None:
- timespan_end = None
+ orig_end = None
else:
- timespan_end = pendulum.instance(following_execution_date)
-
- if timespan_end is None: # Only possible in Airflow before 2.2.
- self.log.warning("No following schedule found, setting timespan
end to max %s", timespan_end)
- timespan_end = DateTime.max
- elif timespan_start >= timespan_end: # Airflow 2.2 sets start == end for non-perodic schedules.
- self.log.warning("DAG schedule not periodic, setting timespan end
to max %s", timespan_end)
- timespan_end = DateTime.max
+ orig_end = pendulum.instance(following_execution_date)
+
+ timespan_start = orig_start
+ if orig_end is None: # Only possible in Airflow before 2.2.
+ self.log.warning("No following schedule found, setting timespan
end to max %s", orig_end)
+ timespan_end = pendulum.instance(datetime.datetime.max)
+ elif orig_start >= orig_end: # Airflow 2.2 sets start == end for non-perodic schedules.
+ self.log.warning("DAG schedule not periodic, setting timespan end
to max %s", orig_end)
+ timespan_end = pendulum.instance(datetime.datetime.max)
+ else:
+ timespan_end = orig_end
timespan_start = timespan_start.in_timezone(timezone.utc)
timespan_end = timespan_end.in_timezone(timezone.utc)
diff --git a/airflow/providers/qubole/hooks/qubole.py b/airflow/providers/qubole/hooks/qubole.py
index 3b0d4bdd1a..340cf4fe13 100644
--- a/airflow/providers/qubole/hooks/qubole.py
+++ b/airflow/providers/qubole/hooks/qubole.py
@@ -46,6 +46,7 @@ from airflow.hooks.base import BaseHook
from airflow.utils.state import State
if TYPE_CHECKING:
+ from airflow.models.taskinstance import TaskInstance
from airflow.utils.context import Context
@@ -139,7 +140,7 @@ class QuboleHook(BaseHook):
self.kwargs = kwargs
self.cls = COMMAND_CLASSES[self.kwargs['command_type']]
self.cmd: Optional[Command] = None
- self.task_instance = None
+ self.task_instance: Optional["TaskInstance"] = None
@staticmethod
def handle_failure_retry(context) -> None:
diff --git a/airflow/providers/salesforce/operators/bulk.py b/airflow/providers/salesforce/operators/bulk.py
index 110ed685ea..39de722032 100644
--- a/airflow/providers/salesforce/operators/bulk.py
+++ b/airflow/providers/salesforce/operators/bulk.py
@@ -47,7 +47,7 @@ class SalesforceBulkOperator(BaseOperator):
def __init__(
self,
*,
- operation: Literal[available_operations],
+ operation: Literal['insert', 'update', 'upsert', 'delete', 'hard_delete'],
object_name: str,
payload: list,
external_id_field: str = 'Id',
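For reference, the reason for the replacement above: Literal[...] must be parameterized with values spelled out inline, not a variable holding them. A standalone sketch, with BulkOperation and submit as illustrative names only, not provider code:

from typing import Literal  # typing_extensions.Literal before Python 3.8

# Rejected by type checkers: Literal cannot take a variable such as
# available_operations; the allowed values must appear literally.
BulkOperation = Literal['insert', 'update', 'upsert', 'delete', 'hard_delete']

def submit(operation: BulkOperation) -> None:
    print(operation)

submit('upsert')   # OK
submit('merge')    # flagged by a type checker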
diff --git a/airflow/typing_compat.py b/airflow/typing_compat.py
index 163889b8a2..ec1846438f 100644
--- a/airflow/typing_compat.py
+++ b/airflow/typing_compat.py
@@ -21,10 +21,28 @@ This module provides helper code to make type annotation within Airflow
codebase easier.
"""
-try:
- # Literal, Protocol and TypedDict are only added to typing module starting from
- # python 3.8 we can safely remove this shim import after Airflow drops
- # support for <3.8
- from typing import Literal, Protocol, TypedDict, runtime_checkable # type: ignore
-except ImportError:
- from typing_extensions import Literal, Protocol, TypedDict, runtime_checkable # type: ignore # noqa
+__all__ = [
+ "Literal",
+ "ParamSpec",
+ "Protocol",
+ "TypedDict",
+ "runtime_checkable",
+]
+
+import sys
+
+if sys.version_info >= (3, 8):
+ from typing import Protocol, TypedDict, runtime_checkable
+else:
+ from typing_extensions import Protocol, TypedDict, runtime_checkable
+
+# Literal in 3.8 is limited to one single argument, not e.g. "Literal[1, 2]".
+if sys.version_info >= (3, 9):
+ from typing import Literal
+else:
+ from typing_extensions import Literal
+
+if sys.version_info >= (3, 10):
+ from typing import ParamSpec
+else:
+ from typing_extensions import ParamSpec
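Roughly how downstream modules would consume the shim, importing from airflow.typing_compat so the version gate stays in one place; log_calls is an illustrative example, not Airflow code.

from typing import Callable, TypeVar

from airflow.typing_compat import ParamSpec

P = ParamSpec("P")
R = TypeVar("R")

def log_calls(fn: Callable[P, R]) -> Callable[P, R]:
    # Preserves fn's full signature for type checkers across Python versions.
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        print(f"calling {fn.__name__}")
        return fn(*args, **kwargs)
    return wrapper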