This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a99d44d  Typing fixes needed to deprecation warning fixes (#20376)
a99d44d is described below

commit a99d44d678bfc89a5b8f30cb83df457011aa7a1d
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Fri Dec 17 20:06:20 2021 +0800

    Typing fixes needed to deprecation warning fixes (#20376)
---
 airflow/models/baseoperator.py    |  2 +-
 airflow/models/skipmixin.py       |  4 +--
 airflow/operators/datetime.py     |  2 +-
 airflow/operators/python.py       | 14 +++++-----
 airflow/sensors/base.py           | 57 +++++++++++++++++++++------------------
 airflow/sensors/external_task.py  |  9 ++++---
 airflow/sensors/python.py         |  3 ++-
 airflow/utils/context.py          | 15 ++++++++---
 airflow/utils/context.pyi         |  4 ++-
 airflow/utils/helpers.py          | 21 +++++++++++----
 airflow/utils/operator_helpers.py | 14 +++++++---
 airflow/utils/weekday.py          |  7 ++---
 12 files changed, 93 insertions(+), 59 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 9f5d2d8..1233939 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -988,7 +988,7 @@ class BaseOperator(Operator, LoggingMixin, DependencyMixin, metaclass=BaseOperat
         if self._pre_execute_hook is not None:
             self._pre_execute_hook(context)
 
-    def execute(self, context: Any):
+    def execute(self, context: Context) -> Any:
         """
         This is the main method to derive when creating an operator.
         Context is the same dictionary used as when rendering jinja templates.
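
For illustration, a subclass using the newly typed signature might look like
this minimal sketch (HelloOperator is hypothetical, not part of this commit):

    from typing import Any

    from airflow.models.baseoperator import BaseOperator
    from airflow.utils.context import Context

    class HelloOperator(BaseOperator):
        def execute(self, context: Context) -> Any:
            # The typed context is the same mapping used for Jinja rendering,
            # so e.g. context["ds"] is the date stamp of the run.
            self.log.info("Running for %s", context["ds"])
            return "hello"
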
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index de5f1fd..a552af8 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union
 from airflow.models.taskinstance import TaskInstance
 from airflow.utils import timezone
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import create_session, provide_session
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.state import State
 
 if TYPE_CHECKING:
@@ -69,7 +69,7 @@ class SkipMixin(LoggingMixin):
         dag_run: "DagRun",
         execution_date: "DateTime",
         tasks: Sequence["BaseOperator"],
-        session: "Session",
+        session: "Session" = NEW_SESSION,
     ):
         """
         Sets tasks instances to skipped from the same dag run.
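
For background, NEW_SESSION is a typing sentinel for @provide_session-decorated
callables: it satisfies the type checker as a default value, while the
decorator injects a real session at runtime. A rough sketch (the query is
illustrative only):

    from sqlalchemy.orm import Session

    from airflow.models.taskinstance import TaskInstance
    from airflow.utils.session import NEW_SESSION, provide_session

    @provide_session
    def count_task_instances(dag_id: str, session: Session = NEW_SESSION) -> int:
        # provide_session supplies a real Session when the caller omits one;
        # NEW_SESSION should never reach this body in normal use.
        return session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).count()
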
diff --git a/airflow/operators/datetime.py b/airflow/operators/datetime.py
index 6b1acf7..6750f12 100644
--- a/airflow/operators/datetime.py
+++ b/airflow/operators/datetime.py
@@ -100,7 +100,7 @@ def target_times_as_dates(
     if upper is not None and isinstance(upper, datetime.time):
         upper = datetime.datetime.combine(base_date, upper)
 
-    if any(date is None for date in (lower, upper)):
+    if lower is None or upper is None:
         return lower, upper
 
     if upper < lower:
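
The point of this rewrite is type narrowing: mypy cannot see through any(...)
that both values are non-None, but it can with the direct checks. A standalone
illustration (order_bounds is a made-up function, not from the commit):

    import datetime
    from typing import Optional, Tuple

    def order_bounds(
        lower: Optional[datetime.datetime],
        upper: Optional[datetime.datetime],
    ) -> Tuple[Optional[datetime.datetime], Optional[datetime.datetime]]:
        if lower is None or upper is None:
            return lower, upper
        # Both names are narrowed to datetime here, so the comparison
        # type-checks; behind any(...) it would not.
        if upper < lower:
            upper += datetime.timedelta(days=1)
        return lower, upper
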
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index 31e3a6f..c908bce 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -24,7 +24,7 @@ import types
 import warnings
 from tempfile import TemporaryDirectory
 from textwrap import dedent
-from typing import Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union
 
 import dill
 
@@ -32,7 +32,7 @@ from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.models.skipmixin import SkipMixin
 from airflow.models.taskinstance import _CURRENT_CONTEXT
-from airflow.utils.context import Context
+from airflow.utils.context import Context, context_copy_partial
 from airflow.utils.operator_helpers import determine_kwargs
 from airflow.utils.process_utils import execute_in_subprocess
 from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script
@@ -172,7 +172,7 @@ class PythonOperator(BaseOperator):
             self.template_ext = templates_exts
         self.show_return_value_in_logs = show_return_value_in_logs
 
-    def execute(self, context: Dict):
+    def execute(self, context: Context) -> Any:
         context.update(self.op_kwargs)
         context['templates_dict'] = self.templates_dict
 
@@ -210,7 +210,7 @@ class BranchPythonOperator(PythonOperator, SkipMixin):
     to be inferred.
     """
 
-    def execute(self, context: Dict):
+    def execute(self, context: Context) -> Any:
         branch = super().execute(context)
         # TODO: The logic should be moved to SkipMixin to be available to all branch operators.
         if isinstance(branch, str):
@@ -242,7 +242,7 @@ class ShortCircuitOperator(PythonOperator, SkipMixin):
     The condition is determined by the result of `python_callable`.
     """
 
-    def execute(self, context: Dict):
+    def execute(self, context: Context) -> Any:
         condition = super().execute(context)
         self.log.info("Condition result is %s", condition)
 
@@ -398,9 +398,9 @@ class PythonVirtualenvOperator(PythonOperator):
                 self.requirements.append('dill')
         self.pickling_library = dill if self.use_dill else pickle
 
-    def execute(self, context: Context):
+    def execute(self, context: Context) -> Any:
         serializable_keys = set(self._iter_serializable_context_keys())
-        serializable_context = context.copy_only(serializable_keys)
+        serializable_context = context_copy_partial(context, serializable_keys)
         return super().execute(context=serializable_context)
 
     def execute_callable(self):
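
The context_copy_partial free function introduced here replaces the removed
Context.copy_only method (see the airflow/utils/context.py hunk below). A
usage sketch, with keys chosen purely for illustration:

    from airflow.utils.context import Context, context_copy_partial

    context = Context({"ds": "2021-12-17", "ti": object(), "conf": object()})
    # Keep only the keys that can be serialized into the subprocess.
    serializable = context_copy_partial(context, {"ds"})
    assert set(serializable) == {"ds"}
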
diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py
index b15596e..523caae 100644
--- a/airflow/sensors/base.py
+++ b/airflow/sensors/base.py
@@ -22,7 +22,7 @@ import hashlib
 import time
 import warnings
 from datetime import timedelta
-from typing import Any, Callable, Dict, Iterable
+from typing import Any, Callable, Iterable, Union
 
 from airflow import settings
 from airflow.configuration import conf
@@ -37,6 +37,7 @@ from airflow.models.skipmixin import SkipMixin
 from airflow.models.taskreschedule import TaskReschedule
 from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
 from airflow.utils import timezone
+from airflow.utils.context import Context
 
 # We need to keep the import here because GCSToLocalFilesystemOperator released in
 # Google Provider before 3.0.0 imported apply_defaults from here.
@@ -150,7 +151,7 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
                     f"mode since it will take reschedule time over MySQL's 
TIMESTAMP limit."
                 )
 
-    def poke(self, context: Dict) -> bool:
+    def poke(self, context: Context) -> bool:
         """
         Function that the sensors defined while deriving this class should
         override.
@@ -224,8 +225,8 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
             result['execution_timeout'] = result['execution_timeout'].total_seconds()
         return result
 
-    def execute(self, context: Dict) -> Any:
-        started_at = None
+    def execute(self, context: Context) -> Any:
+        started_at: Union[datetime.datetime, float]
 
         if self.reschedule:
 
@@ -235,23 +236,22 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
             task_reschedules = TaskReschedule.find_for_task_instance(
                 context['ti'], try_number=first_try_number
             )
-            if task_reschedules:
-                started_at = task_reschedules[0].start_date
+            if not task_reschedules:
+                start_date = timezone.utcnow()
             else:
-                started_at = timezone.utcnow()
+                start_date = task_reschedules[0].start_date
+            started_at = start_date
 
             def run_duration() -> float:
                 # If we are in reschedule mode, then we have to compute diff
                 # based on the time in a DB, so can't use time.monotonic
-                nonlocal started_at
-                return (timezone.utcnow() - started_at).total_seconds()
+                return (timezone.utcnow() - start_date).total_seconds()
 
         else:
-            started_at = time.monotonic()
+            started_at = start_monotonic = time.monotonic()
 
             def run_duration() -> float:
-                nonlocal started_at
-                return time.monotonic() - started_at
+                return time.monotonic() - start_monotonic
 
         try_number = 1
         log_dag_id = self.dag.dag_id if self.has_dag() else ""
@@ -277,23 +277,28 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
                 try_number += 1
         self.log.info("Success criteria met. Exiting.")
 
-    def _get_next_poke_interval(self, started_at: Any, run_duration: Callable[[], int], try_number):
+    def _get_next_poke_interval(
+        self,
+        started_at: Union[datetime.datetime, float],
+        run_duration: Callable[[], float],
+        try_number: int,
+    ) -> float:
         """Using the similar logic which is used for exponential backoff retry 
delay for operators."""
-        if self.exponential_backoff:
-            min_backoff = int(self.poke_interval * (2 ** (try_number - 2)))
+        if not self.exponential_backoff:
+            return self.poke_interval
 
-            run_hash = int(
-                hashlib.sha1(f"{self.dag_id}#{self.task_id}#{started_at}#{try_number}".encode()).hexdigest(),
-                16,
-            )
-            modded_hash = min_backoff + run_hash % min_backoff
+        min_backoff = int(self.poke_interval * (2 ** (try_number - 2)))
 
-            delay_backoff_in_seconds = min(modded_hash, timedelta.max.total_seconds() - 1)
-            new_interval = min(self.timeout - int(run_duration()), delay_backoff_in_seconds)
-            self.log.info("new %s interval is %s", self.mode, new_interval)
-            return new_interval
-        else:
-            return self.poke_interval
+        run_hash = int(
+            hashlib.sha1(f"{self.dag_id}#{self.task_id}#{started_at}#{try_number}".encode()).hexdigest(),
+            16,
+        )
+        modded_hash = min_backoff + run_hash % min_backoff
+
+        delay_backoff_in_seconds = min(modded_hash, timedelta.max.total_seconds() - 1)
+        new_interval = min(self.timeout - int(run_duration()), delay_backoff_in_seconds)
+        self.log.info("new %s interval is %s", self.mode, new_interval)
+        return new_interval
 
     def prepare_for_execution(self) -> BaseOperator:
         task = super().prepare_for_execution()
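
The timer refactoring above trades nonlocal on a Union-typed variable for
closures over concretely typed names, which mypy can verify. The pattern in
isolation (make_timer is a hypothetical helper, simplified from execute):

    import time
    from datetime import datetime, timezone
    from typing import Callable, Tuple, Union

    def make_timer(use_db_time: bool) -> Tuple[Union[datetime, float], Callable[[], float]]:
        started_at: Union[datetime, float]
        if use_db_time:
            start_date = datetime.now(timezone.utc)
            started_at = start_date

            def run_duration() -> float:
                # start_date is always a datetime, so this arithmetic checks.
                return (datetime.now(timezone.utc) - start_date).total_seconds()

        else:
            started_at = start_monotonic = time.monotonic()

            def run_duration() -> float:
                return time.monotonic() - start_monotonic

        return started_at, run_duration
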
diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py
index ef68c77..9295682 100644
--- a/airflow/sensors/external_task.py
+++ b/airflow/sensors/external_task.py
@@ -18,7 +18,7 @@
 
 import datetime
 import os
-from typing import Any, Callable, FrozenSet, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Collection, FrozenSet, Iterable, Optional, Union
 
 from sqlalchemy import func
 
@@ -96,7 +96,7 @@ class ExternalTaskSensor(BaseSensorOperator):
         *,
         external_dag_id: str,
         external_task_id: Optional[str] = None,
-        external_task_ids: Optional[Iterable[str]] = None,
+        external_task_ids: Optional[Collection[str]] = None,
         allowed_states: Optional[Iterable[str]] = None,
         failed_states: Optional[Iterable[str]] = None,
         execution_delta: Optional[datetime.timedelta] = None,
@@ -108,8 +108,7 @@ class ExternalTaskSensor(BaseSensorOperator):
         self.allowed_states = list(allowed_states) if allowed_states else [State.SUCCESS]
         self.failed_states = list(failed_states) if failed_states else []
 
-        total_states = self.allowed_states + self.failed_states
-        total_states = set(total_states)
+        total_states = set(self.allowed_states + self.failed_states)
 
         if set(self.failed_states).intersection(set(self.allowed_states)):
             raise AirflowException(
@@ -266,6 +265,8 @@ class ExternalTaskSensor(BaseSensorOperator):
         # Add "context" in the kwargs for backward compatibility (because 
context used to be
         # an acceptable argument of execution_date_fn)
         kwargs["context"] = context
+        if TYPE_CHECKING:
+            assert self.execution_date_fn is not None
         kwargs_callable = make_kwargs_callable(self.execution_date_fn)
         return kwargs_callable(execution_date, **kwargs)
 
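The TYPE_CHECKING assert is a narrowing idiom: the assert never executes at
runtime (TYPE_CHECKING is False there), but mypy, which treats the constant as
True, uses it to narrow the Optional attribute. In isolation (Example is a
made-up class):

    from typing import TYPE_CHECKING, Callable, Optional

    class Example:
        def __init__(self, fn: Optional[Callable[[], int]] = None) -> None:
            self.fn = fn

        def call(self) -> int:
            if TYPE_CHECKING:
                # Never runs; exists only so mypy narrows
                # Optional[Callable[[], int]] to Callable[[], int].
                assert self.fn is not None
            return self.fn()
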
diff --git a/airflow/sensors/python.py b/airflow/sensors/python.py
index 5780cd1..06a7ca6 100644
--- a/airflow/sensors/python.py
+++ b/airflow/sensors/python.py
@@ -18,6 +18,7 @@
 from typing import Callable, Dict, List, Optional
 
 from airflow.sensors.base import BaseSensorOperator
+from airflow.utils.context import Context
 from airflow.utils.operator_helpers import determine_kwargs
 
 
@@ -62,7 +63,7 @@ class PythonSensor(BaseSensorOperator):
         self.op_kwargs = op_kwargs or {}
         self.templates_dict = templates_dict
 
-    def poke(self, context: Dict):
+    def poke(self, context: Context) -> bool:
         context.update(self.op_kwargs)
         context['templates_dict'] = self.templates_dict
         self.op_kwargs = determine_kwargs(self.python_callable, self.op_args, context)
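
determine_kwargs, used here, filters a mapping down to the parameters the
callable actually declares. A quick sketch (probe is a made-up callable):

    from airflow.utils.operator_helpers import determine_kwargs

    def probe(ds, threshold=0):
        return True

    kwargs = determine_kwargs(probe, (), {"ds": "2021-12-17", "threshold": 5, "unused": 1})
    assert kwargs == {"ds": "2021-12-17", "threshold": 5}
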
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index dd17ac1..5412b09 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -191,7 +191,14 @@ class Context(MutableMapping[str, Any]):
     def values(self):
         return ValuesView(self._context)
 
-    def copy_only(self, keys: Container[str]) -> "Context":
-        new = type(self)({k: v for k, v in self._context.items() if k in keys})
-        new._deprecation_replacements = self._deprecation_replacements.copy()
-        return new
+
+def context_copy_partial(source: Context, keys: Container[str]) -> "Context":
+    """Create a context by copying items under selected keys in ``source``.
+
+    This is implemented as a free function because the ``Context`` type is
+    "faked" as a ``TypedDict`` in ``context.pyi``, which cannot have custom
+    functions.
+    """
+    new = Context({k: v for k, v in source._context.items() if k in keys})
+    new._deprecation_replacements = source._deprecation_replacements.copy()
+    return new
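
The constraint mentioned in the docstring is a TypedDict rule: a TypedDict
body may only declare keys, so type checkers reject methods on it, and helper
functions must live at module level. A toy example (Point is hypothetical;
typing.TypedDict needs Python 3.8+):

    from typing import TypedDict

    class Point(TypedDict, total=False):
        x: int
        y: int
        # A "def scale(self, ...)" here would be rejected by type checkers,
        # which is exactly why context_copy_partial is a module-level function.

    def shift_x(p: Point, dx: int) -> int:
        return p.get("x", 0) + dx

    assert shift_x({"x": 1}, 2) == 3
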
diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi
index fdcd6ec..c479991 100644
--- a/airflow/utils/context.pyi
+++ b/airflow/utils/context.pyi
@@ -25,7 +25,7 @@
 # undefined attribute errors from Mypy. Hopefully there will be a mechanism to
 # declare "these are defined, but don't error if others are accessed" someday.
 
-from typing import Any, Optional, Union
+from typing import Any, Container, Optional, Union
 
 from pendulum import DateTime
 
@@ -88,3 +88,5 @@ class Context(TypedDict, total=False):
     var: _VariableAccessors
     yesterday_ds: str
     yesterday_ds_nodash: str
+
+def context_copy_partial(source: Context, keys: Container[str]) -> Context: ...
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index 8d1e53b..8950dcb 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -22,7 +22,19 @@ import warnings
 from datetime import datetime
 from functools import reduce
 from itertools import filterfalse, tee
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    MutableMapping,
+    Optional,
+    Tuple,
+    TypeVar,
+)
 from urllib import parse
 
 import flask
@@ -31,7 +43,6 @@ import jinja2.nativetypes
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
-from airflow.utils.context import Context
 from airflow.utils.module_loading import import_string
 
 if TYPE_CHECKING:
@@ -251,7 +262,7 @@ def build_airflow_url_with_query(query: Dict[str, Any]) -> str:
 
 # The 'template' argument is typed as Any because the jinja2.Template is too
 # dynamic to be effectively type-checked.
-def render_template(template: Any, context: Context, *, native: bool) -> Any:
+def render_template(template: Any, context: MutableMapping[str, Any], *, native: bool) -> Any:
     """Render a Jinja2 template with given Airflow context.
 
     The default implementation of ``jinja2.Template.render()`` converts the
@@ -278,12 +289,12 @@ def render_template(template: Any, context: Context, *, native: bool) -> Any:
     return "".join(nodes)
 
 
-def render_template_to_string(template: jinja2.Template, context: Context) -> str:
+def render_template_to_string(template: jinja2.Template, context: MutableMapping[str, Any]) -> str:
     """Shorthand to ``render_template(native=False)`` with better typing 
support."""
     return render_template(template, context, native=False)
 
 
-def render_template_as_native(template: jinja2.Template, context: Context) -> Any:
+def render_template_as_native(template: jinja2.Template, context: MutableMapping[str, Any]) -> Any:
     """Shorthand to ``render_template(native=True)`` with better typing 
support."""
     return render_template(template, context, native=True)
 
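Loosening the parameter from Context to MutableMapping[str, Any] means any
mutable mapping should now type-check, not just an Airflow Context. A sketch
under that assumption:

    import jinja2

    from airflow.utils.helpers import render_template_to_string

    tmpl = jinja2.Template("Hello {{ name }}")
    # A plain dict now satisfies the parameter type.
    assert render_template_to_string(tmpl, {"name": "world"}) == "Hello world"
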
diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py
index 8c5125b..e320f3c 100644
--- a/airflow/utils/operator_helpers.py
+++ b/airflow/utils/operator_helpers.py
@@ -17,7 +17,9 @@
 # under the License.
 #
 from datetime import datetime
-from typing import Callable, Dict, List, Mapping, Tuple, Union
+from typing import Any, Callable, Dict, Mapping, Sequence, TypeVar
+
+R = TypeVar("R")
 
 AIRFLOW_VAR_NAME_FORMAT_MAPPING = {
     'AIRFLOW_CONTEXT_DAG_ID': {'default': 'airflow.ctx.dag_id', 'env_var_format': 'AIRFLOW_CTX_DAG_ID'},
@@ -41,7 +43,7 @@ AIRFLOW_VAR_NAME_FORMAT_MAPPING = {
 }
 
 
-def context_to_airflow_vars(context, in_env_var_format=False):
+def context_to_airflow_vars(context: Mapping[str, Any], in_env_var_format: bool = False) -> Dict[str, str]:
     """
     Given a context, this function provides a dictionary of values that can be used to
     externally reconstruct relations between dags, dag_runs, tasks and task_instances.
@@ -88,7 +90,11 @@ def context_to_airflow_vars(context, in_env_var_format=False):
     return params
 
 
-def determine_kwargs(func: Callable, args: Union[Tuple, List], kwargs: Mapping) -> Dict:
+def determine_kwargs(
+    func: Callable[..., Any],
+    args: Sequence[Any],
+    kwargs: Mapping[str, Any],
+) -> Mapping[str, Any]:
     """
     Inspect the signature of a given callable to determine which arguments in kwargs need
     to be passed to the callable.
@@ -118,7 +124,7 @@ def determine_kwargs(func: Callable, args: Union[Tuple, List], kwargs: Mapping)
     return {key: kwargs[key] for key in signature.parameters if key in kwargs}
 
 
-def make_kwargs_callable(func: Callable) -> Callable:
+def make_kwargs_callable(func: Callable[..., R]) -> Callable[..., R]:
     """
     Make a new callable that can accept any number of positional or keyword arguments
     but only forwards those required by the given callable func.
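
The TypeVar R threads the wrapped callable's return type through to the
wrapper, so callers keep a precise type instead of a bare Callable. A
simplified sketch of the signature's effect (not the real implementation,
which also filters kwargs via determine_kwargs):

    from typing import Any, Callable, TypeVar

    R = TypeVar("R")

    def passthrough(func: Callable[..., R]) -> Callable[..., R]:
        def wrapper(*args: Any, **kwargs: Any) -> R:
            return func(*args, **kwargs)
        return wrapper

    def answer() -> int:
        return 42

    value = passthrough(answer)()  # mypy infers int, not Any
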
diff --git a/airflow/utils/weekday.py b/airflow/utils/weekday.py
index 2bfdea5..6841307 100644
--- a/airflow/utils/weekday.py
+++ b/airflow/utils/weekday.py
@@ -16,7 +16,7 @@
 # under the License.
 """Get the ISO standard day number of the week from a given day string"""
 import enum
-from typing import Iterable, List, Set, Union
+from typing import Iterable, Set, Union
 
 
 @enum.unique
@@ -56,8 +56,9 @@ class WeekDay(enum.IntEnum):
 
     @classmethod
     def validate_week_day(
-        cls, week_day: Union[str, 'WeekDay', Set[str], Set['WeekDay'], List[str], List['WeekDay']]
-    ):
+        cls,
+        week_day: Union[str, "WeekDay", Iterable[str], Iterable["WeekDay"]],
+    ) -> Set[int]:
         """Validate each item of iterable and create a set to ease compare of 
values"""
         if not isinstance(week_day, Iterable):
             if isinstance(week_day, WeekDay):
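
With the widened signature, any iterable of day names or WeekDay members is
accepted, and the result is a set of ISO day numbers (WeekDay is an IntEnum).
A usage sketch:

    from airflow.utils.weekday import WeekDay

    days = WeekDay.validate_week_day({"TUESDAY", "SATURDAY"})
    assert days == {WeekDay.TUESDAY, WeekDay.SATURDAY}  # i.e. {2, 6}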
