This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a99d44d Typing fixes needed to deprecation warning fixes (#20376)
a99d44d is described below
commit a99d44d678bfc89a5b8f30cb83df457011aa7a1d
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Fri Dec 17 20:06:20 2021 +0800
Typing fixes needed to deprecation warning fixes (#20376)
---
airflow/models/baseoperator.py | 2 +-
airflow/models/skipmixin.py | 4 +--
airflow/operators/datetime.py | 2 +-
airflow/operators/python.py | 14 +++++-----
airflow/sensors/base.py | 57 +++++++++++++++++++++------------------
airflow/sensors/external_task.py | 9 ++++---
airflow/sensors/python.py | 3 ++-
airflow/utils/context.py | 15 ++++++++---
airflow/utils/context.pyi | 4 ++-
airflow/utils/helpers.py | 21 +++++++++++----
airflow/utils/operator_helpers.py | 14 +++++++---
airflow/utils/weekday.py | 7 ++---
12 files changed, 93 insertions(+), 59 deletions(-)
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 9f5d2d8..1233939 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -988,7 +988,7 @@ class BaseOperator(Operator, LoggingMixin, DependencyMixin, metaclass=BaseOperat
if self._pre_execute_hook is not None:
self._pre_execute_hook(context)
- def execute(self, context: Any):
+ def execute(self, context: Context) -> Any:
"""
This is the main method to derive when creating an operator.
Context is the same dictionary used as when rendering jinja templates.
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index de5f1fd..a552af8 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import create_session, provide_session
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.state import State
if TYPE_CHECKING:
@@ -69,7 +69,7 @@ class SkipMixin(LoggingMixin):
dag_run: "DagRun",
execution_date: "DateTime",
tasks: Sequence["BaseOperator"],
- session: "Session",
+ session: "Session" = NEW_SESSION,
):
"""
Sets tasks instances to skipped from the same dag run.
diff --git a/airflow/operators/datetime.py b/airflow/operators/datetime.py
index 6b1acf7..6750f12 100644
--- a/airflow/operators/datetime.py
+++ b/airflow/operators/datetime.py
@@ -100,7 +100,7 @@ def target_times_as_dates(
if upper is not None and isinstance(upper, datetime.time):
upper = datetime.datetime.combine(base_date, upper)
- if any(date is None for date in (lower, upper)):
+ if lower is None or upper is None:
return lower, upper
if upper < lower:
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index 31e3a6f..c908bce 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -24,7 +24,7 @@ import types
import warnings
from tempfile import TemporaryDirectory
from textwrap import dedent
-from typing import Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import dill
@@ -32,7 +32,7 @@ from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.models.skipmixin import SkipMixin
from airflow.models.taskinstance import _CURRENT_CONTEXT
-from airflow.utils.context import Context
+from airflow.utils.context import Context, context_copy_partial
from airflow.utils.operator_helpers import determine_kwargs
from airflow.utils.process_utils import execute_in_subprocess
from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script
@@ -172,7 +172,7 @@ class PythonOperator(BaseOperator):
self.template_ext = templates_exts
self.show_return_value_in_logs = show_return_value_in_logs
- def execute(self, context: Dict):
+ def execute(self, context: Context) -> Any:
context.update(self.op_kwargs)
context['templates_dict'] = self.templates_dict
@@ -210,7 +210,7 @@ class BranchPythonOperator(PythonOperator, SkipMixin):
to be inferred.
"""
- def execute(self, context: Dict):
+ def execute(self, context: Context) -> Any:
branch = super().execute(context)
# TODO: The logic should be moved to SkipMixin to be available to all branch operators.
if isinstance(branch, str):
@@ -242,7 +242,7 @@ class ShortCircuitOperator(PythonOperator, SkipMixin):
The condition is determined by the result of `python_callable`.
"""
- def execute(self, context: Dict):
+ def execute(self, context: Context) -> Any:
condition = super().execute(context)
self.log.info("Condition result is %s", condition)
@@ -398,9 +398,9 @@ class PythonVirtualenvOperator(PythonOperator):
self.requirements.append('dill')
self.pickling_library = dill if self.use_dill else pickle
- def execute(self, context: Context):
+ def execute(self, context: Context) -> Any:
serializable_keys = set(self._iter_serializable_context_keys())
- serializable_context = context.copy_only(serializable_keys)
+ serializable_context = context_copy_partial(context, serializable_keys)
return super().execute(context=serializable_context)
def execute_callable(self):
diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py
index b15596e..523caae 100644
--- a/airflow/sensors/base.py
+++ b/airflow/sensors/base.py
@@ -22,7 +22,7 @@ import hashlib
import time
import warnings
from datetime import timedelta
-from typing import Any, Callable, Dict, Iterable
+from typing import Any, Callable, Iterable, Union
from airflow import settings
from airflow.configuration import conf
@@ -37,6 +37,7 @@ from airflow.models.skipmixin import SkipMixin
from airflow.models.taskreschedule import TaskReschedule
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
from airflow.utils import timezone
+from airflow.utils.context import Context
# We need to keep the import here because GCSToLocalFilesystemOperator released in
# Google Provider before 3.0.0 imported apply_defaults from here.
@@ -150,7 +151,7 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
f"mode since it will take reschedule time over MySQL's TIMESTAMP limit."
)
- def poke(self, context: Dict) -> bool:
+ def poke(self, context: Context) -> bool:
"""
Function that the sensors defined while deriving this class should override.
@@ -224,8 +225,8 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
result['execution_timeout'] = result['execution_timeout'].total_seconds()
return result
- def execute(self, context: Dict) -> Any:
- started_at = None
+ def execute(self, context: Context) -> Any:
+ started_at: Union[datetime.datetime, float]
if self.reschedule:
@@ -235,23 +236,22 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
task_reschedules = TaskReschedule.find_for_task_instance(
context['ti'], try_number=first_try_number
)
- if task_reschedules:
- started_at = task_reschedules[0].start_date
+ if not task_reschedules:
+ start_date = timezone.utcnow()
else:
- started_at = timezone.utcnow()
+ start_date = task_reschedules[0].start_date
+ started_at = start_date
def run_duration() -> float:
# If we are in reschedule mode, then we have to compute diff
# based on the time in a DB, so can't use time.monotonic
- nonlocal started_at
- return (timezone.utcnow() - started_at).total_seconds()
+ return (timezone.utcnow() - start_date).total_seconds()
else:
- started_at = time.monotonic()
+ started_at = start_monotonic = time.monotonic()
def run_duration() -> float:
- nonlocal started_at
- return time.monotonic() - started_at
+ return time.monotonic() - start_monotonic
try_number = 1
log_dag_id = self.dag.dag_id if self.has_dag() else ""
@@ -277,23 +277,28 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
try_number += 1
self.log.info("Success criteria met. Exiting.")
- def _get_next_poke_interval(self, started_at: Any, run_duration: Callable[[], int], try_number):
+ def _get_next_poke_interval(
+ self,
+ started_at: Union[datetime.datetime, float],
+ run_duration: Callable[[], float],
+ try_number: int,
+ ) -> float:
"""Using the similar logic which is used for exponential backoff retry delay for operators."""
- if self.exponential_backoff:
- min_backoff = int(self.poke_interval * (2 ** (try_number - 2)))
+ if not self.exponential_backoff:
+ return self.poke_interval
- run_hash = int(
- hashlib.sha1(f"{self.dag_id}#{self.task_id}#{started_at}#{try_number}".encode()).hexdigest(),
- 16,
- )
- modded_hash = min_backoff + run_hash % min_backoff
- delay_backoff_in_seconds = min(modded_hash, timedelta.max.total_seconds() - 1)
- new_interval = min(self.timeout - int(run_duration()), delay_backoff_in_seconds)
- self.log.info("new %s interval is %s", self.mode, new_interval)
- return new_interval
- else:
- return self.poke_interval
+ run_hash = int(
+ hashlib.sha1(f"{self.dag_id}#{self.task_id}#{started_at}#{try_number}".encode()).hexdigest(),
+ 16,
+ )
+ modded_hash = min_backoff + run_hash % min_backoff
+
+ delay_backoff_in_seconds = min(modded_hash, timedelta.max.total_seconds() - 1)
+ new_interval = min(self.timeout - int(run_duration()), delay_backoff_in_seconds)
+ self.log.info("new %s interval is %s", self.mode, new_interval)
+ return new_interval
def prepare_for_execution(self) -> BaseOperator:
task = super().prepare_for_execution()
diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py
index ef68c77..9295682 100644
--- a/airflow/sensors/external_task.py
+++ b/airflow/sensors/external_task.py
@@ -18,7 +18,7 @@
import datetime
import os
-from typing import Any, Callable, FrozenSet, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Collection, FrozenSet, Iterable, Optional, Union
from sqlalchemy import func
@@ -96,7 +96,7 @@ class ExternalTaskSensor(BaseSensorOperator):
*,
external_dag_id: str,
external_task_id: Optional[str] = None,
- external_task_ids: Optional[Iterable[str]] = None,
+ external_task_ids: Optional[Collection[str]] = None,
allowed_states: Optional[Iterable[str]] = None,
failed_states: Optional[Iterable[str]] = None,
execution_delta: Optional[datetime.timedelta] = None,
@@ -108,8 +108,7 @@ class ExternalTaskSensor(BaseSensorOperator):
self.allowed_states = list(allowed_states) if allowed_states else [State.SUCCESS]
self.failed_states = list(failed_states) if failed_states else []
- total_states = self.allowed_states + self.failed_states
- total_states = set(total_states)
+ total_states = set(self.allowed_states + self.failed_states)
if set(self.failed_states).intersection(set(self.allowed_states)):
raise AirflowException(
@@ -266,6 +265,8 @@ class ExternalTaskSensor(BaseSensorOperator):
# Add "context" in the kwargs for backward compatibility (because context used to be
# an acceptable argument of execution_date_fn)
kwargs["context"] = context
+ if TYPE_CHECKING:
+ assert self.execution_date_fn is not None
kwargs_callable = make_kwargs_callable(self.execution_date_fn)
return kwargs_callable(execution_date, **kwargs)
diff --git a/airflow/sensors/python.py b/airflow/sensors/python.py
index 5780cd1..06a7ca6 100644
--- a/airflow/sensors/python.py
+++ b/airflow/sensors/python.py
@@ -18,6 +18,7 @@
from typing import Callable, Dict, List, Optional
from airflow.sensors.base import BaseSensorOperator
+from airflow.utils.context import Context
from airflow.utils.operator_helpers import determine_kwargs
@@ -62,7 +63,7 @@ class PythonSensor(BaseSensorOperator):
self.op_kwargs = op_kwargs or {}
self.templates_dict = templates_dict
- def poke(self, context: Dict):
+ def poke(self, context: Context) -> bool:
context.update(self.op_kwargs)
context['templates_dict'] = self.templates_dict
self.op_kwargs = determine_kwargs(self.python_callable, self.op_args, context)
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index dd17ac1..5412b09 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -191,7 +191,14 @@ class Context(MutableMapping[str, Any]):
def values(self):
return ValuesView(self._context)
- def copy_only(self, keys: Container[str]) -> "Context":
- new = type(self)({k: v for k, v in self._context.items() if k in keys})
- new._deprecation_replacements = self._deprecation_replacements.copy()
- return new
+
+def context_copy_partial(source: Context, keys: Container[str]) -> "Context":
+ """Create a context by copying items under selected keys in ``source``.
+
+ This is implemented as a free function because the ``Context`` type is
+ "faked" as a ``TypedDict`` in ``context.pyi``, which cannot have custom
+ functions.
+ """
+ new = Context({k: v for k, v in source._context.items() if k in keys})
+ new._deprecation_replacements = source._deprecation_replacements.copy()
+ return new
diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi
index fdcd6ec..c479991 100644
--- a/airflow/utils/context.pyi
+++ b/airflow/utils/context.pyi
@@ -25,7 +25,7 @@
# undefined attribute errors from Mypy. Hopefully there will be a mechanism to
# declare "these are defined, but don't error if others are accessed" someday.
-from typing import Any, Optional, Union
+from typing import Any, Container, Optional, Union
from pendulum import DateTime
@@ -88,3 +88,5 @@ class Context(TypedDict, total=False):
var: _VariableAccessors
yesterday_ds: str
yesterday_ds_nodash: str
+
+def context_copy_partial(source: Context, keys: Container[str]) -> Context: ...
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index 8d1e53b..8950dcb 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -22,7 +22,19 @@ import warnings
from datetime import datetime
from functools import reduce
from itertools import filterfalse, tee
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, TypeVar
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ MutableMapping,
+ Optional,
+ Tuple,
+ TypeVar,
+)
from urllib import parse
import flask
@@ -31,7 +43,6 @@ import jinja2.nativetypes
from airflow.configuration import conf
from airflow.exceptions import AirflowException
-from airflow.utils.context import Context
from airflow.utils.module_loading import import_string
if TYPE_CHECKING:
@@ -251,7 +262,7 @@ def build_airflow_url_with_query(query: Dict[str, Any]) -> str:
# The 'template' argument is typed as Any because the jinja2.Template is too
# dynamic to be effectively type-checked.
-def render_template(template: Any, context: Context, *, native: bool) -> Any:
+def render_template(template: Any, context: MutableMapping[str, Any], *, native: bool) -> Any:
"""Render a Jinja2 template with given Airflow context.
The default implementation of ``jinja2.Template.render()`` converts the
@@ -278,12 +289,12 @@ def render_template(template: Any, context: Context, *, native: bool) -> Any:
return "".join(nodes)
-def render_template_to_string(template: jinja2.Template, context: Context) -> str:
+def render_template_to_string(template: jinja2.Template, context: MutableMapping[str, Any]) -> str:
"""Shorthand to ``render_template(native=False)`` with better typing support."""
return render_template(template, context, native=False)
-def render_template_as_native(template: jinja2.Template, context: Context) -> Any:
+def render_template_as_native(template: jinja2.Template, context: MutableMapping[str, Any]) -> Any:
"""Shorthand to ``render_template(native=True)`` with better typing support."""
return render_template(template, context, native=True)
diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py
index 8c5125b..e320f3c 100644
--- a/airflow/utils/operator_helpers.py
+++ b/airflow/utils/operator_helpers.py
@@ -17,7 +17,9 @@
# under the License.
#
from datetime import datetime
-from typing import Callable, Dict, List, Mapping, Tuple, Union
+from typing import Any, Callable, Dict, Mapping, Sequence, TypeVar
+
+R = TypeVar("R")
AIRFLOW_VAR_NAME_FORMAT_MAPPING = {
'AIRFLOW_CONTEXT_DAG_ID': {'default': 'airflow.ctx.dag_id', 'env_var_format': 'AIRFLOW_CTX_DAG_ID'},
@@ -41,7 +43,7 @@ AIRFLOW_VAR_NAME_FORMAT_MAPPING = {
}
-def context_to_airflow_vars(context, in_env_var_format=False):
+def context_to_airflow_vars(context: Mapping[str, Any], in_env_var_format: bool = False) -> Dict[str, str]:
"""
Given a context, this function provides a dictionary of values that can be used to
externally reconstruct relations between dags, dag_runs, tasks and task_instances.
@@ -88,7 +90,11 @@ def context_to_airflow_vars(context, in_env_var_format=False):
return params
-def determine_kwargs(func: Callable, args: Union[Tuple, List], kwargs: Mapping) -> Dict:
+def determine_kwargs(
+ func: Callable[..., Any],
+ args: Sequence[Any],
+ kwargs: Mapping[str, Any],
+) -> Mapping[str, Any]:
"""
Inspect the signature of a given callable to determine which arguments in kwargs need
to be passed to the callable.
@@ -118,7 +124,7 @@ def determine_kwargs(func: Callable, args: Union[Tuple, List], kwargs: Mapping)
return {key: kwargs[key] for key in signature.parameters if key in kwargs}
-def make_kwargs_callable(func: Callable) -> Callable:
+def make_kwargs_callable(func: Callable[..., R]) -> Callable[..., R]:
"""
Make a new callable that can accept any number of positional or keyword arguments
but only forwards those required by the given callable func.
diff --git a/airflow/utils/weekday.py b/airflow/utils/weekday.py
index 2bfdea5..6841307 100644
--- a/airflow/utils/weekday.py
+++ b/airflow/utils/weekday.py
@@ -16,7 +16,7 @@
# under the License.
"""Get the ISO standard day number of the week from a given day string"""
import enum
-from typing import Iterable, List, Set, Union
+from typing import Iterable, Set, Union
@enum.unique
@@ -56,8 +56,9 @@ class WeekDay(enum.IntEnum):
@classmethod
def validate_week_day(
- cls, week_day: Union[str, 'WeekDay', Set[str], Set['WeekDay'], List[str], List['WeekDay']]
- ):
+ cls,
+ week_day: Union[str, "WeekDay", Iterable[str], Iterable["WeekDay"]],
+ ) -> Set[int]:
"""Validate each item of iterable and create a set to ease compare of values"""
if not isinstance(week_day, Iterable):
if isinstance(week_day, WeekDay):