This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-3-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 1db9c1713d37bd551571913a5657cc7c554025ef
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Mon May 30 03:24:08 2022 -0400

    Ensure @contextmanager decorates generator func (#23103)
    
    (cherry picked from commit e58985598f202395098e15b686aec33645a906ff)
---
 airflow/cli/commands/task_command.py                  |  4 ++--
 airflow/models/taskinstance.py                        |  3 +--
 airflow/providers/google/cloud/hooks/gcs.py           | 19 ++++++++++++++++---
 .../google/cloud/utils/credentials_provider.py        |  9 ++++++---
 airflow/providers/google/common/hooks/base_google.py  | 10 +++++-----
 airflow/providers/microsoft/psrp/hooks/psrp.py        |  4 ++--
 airflow/utils/db.py                                   | 11 ++++++++---
 airflow/utils/process_utils.py                        |  4 ++--
 airflow/utils/session.py                              |  4 ++--
 dev/breeze/src/airflow_breeze/utils/run_utils.py      |  4 ++--
 dev/provider_packages/prepare_provider_packages.py    |  4 ++--
 11 files changed, 48 insertions(+), 28 deletions(-)

diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index ea20ebb646..2b743b91fe 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -23,7 +23,7 @@ import logging
 import os
 import textwrap
 from contextlib import contextmanager, redirect_stderr, redirect_stdout
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, Generator, List, Optional, Tuple, Union
 
 from pendulum.parsing.exceptions import ParserError
 from sqlalchemy.orm.exc import NoResultFound
@@ -269,7 +269,7 @@ def _extract_external_executor_id(args) -> Optional[str]:
 
 
 @contextmanager
-def _capture_task_logs(ti):
+def _capture_task_logs(ti: TaskInstance) -> Generator[None, None, None]:
     """Manage logging context for a task run
 
     - Replace the root logger configuration with the airflow.task configuration
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 2885d56b54..5cd582ce3e 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -40,7 +40,6 @@ from typing import (
     Dict,
     Generator,
     Iterable,
-    Iterator,
     List,
     NamedTuple,
     Optional,
@@ -142,7 +141,7 @@ if TYPE_CHECKING:
 
 
 @contextlib.contextmanager
-def set_current_context(context: Context) -> Iterator[Context]:
+def set_current_context(context: Context) -> Generator[Context, None, None]:
     """
     Sets the current execution context to the provided context object.
     This method should be called once per Task execution, before calling operator.execute.
diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py
index 29ad6ac438..93717e00e9 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -29,7 +29,20 @@ from functools import partial
 from io import BytesIO
 from os import path
 from tempfile import NamedTemporaryFile
-from typing import Callable, List, Optional, Sequence, Set, Tuple, TypeVar, Union, cast, overload
+from typing import (
+    IO,
+    Callable,
+    Generator,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+)
 from urllib.parse import urlparse
 
 from google.api_core.exceptions import NotFound
@@ -385,7 +398,7 @@ class GCSHook(GoogleBaseHook):
         object_name: Optional[str] = None,
         object_url: Optional[str] = None,
         dir: Optional[str] = None,
-    ):
+    ) -> Generator[IO[bytes], None, None]:
         """
         Downloads the file to a temporary directory and returns a file handle
 
@@ -413,7 +426,7 @@ class GCSHook(GoogleBaseHook):
         bucket_name: str = PROVIDE_BUCKET,
         object_name: Optional[str] = None,
         object_url: Optional[str] = None,
-    ):
+    ) -> Generator[IO[bytes], None, None]:
         """
         Creates temporary file, returns a file handle and uploads the files content
         on close.
diff --git a/airflow/providers/google/cloud/utils/credentials_provider.py b/airflow/providers/google/cloud/utils/credentials_provider.py
index 0a8143ceae..1cf33ea70b 100644
--- a/airflow/providers/google/cloud/utils/credentials_provider.py
+++ b/airflow/providers/google/cloud/utils/credentials_provider.py
@@ -74,7 +74,10 @@ def build_gcp_conn(
 
 
 @contextmanager
-def provide_gcp_credentials(key_file_path: Optional[str] = None, key_file_dict: Optional[Dict] = None):
+def provide_gcp_credentials(
+    key_file_path: Optional[str] = None,
+    key_file_dict: Optional[Dict] = None,
+) -> Generator[None, None, None]:
     """
     Context manager that provides a Google Cloud credentials for application supporting
     `Application Default Credentials (ADC) strategy`__.
@@ -111,7 +114,7 @@ def provide_gcp_connection(
     key_file_path: Optional[str] = None,
     scopes: Optional[Sequence] = None,
     project_id: Optional[str] = None,
-) -> Generator:
+) -> Generator[None, None, None]:
     """
     Context manager that provides a temporary value of :envvar:`AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT`
     connection. It build a new connection that includes path to provided service json,
@@ -135,7 +138,7 @@ def provide_gcp_conn_and_credentials(
     key_file_path: Optional[str] = None,
     scopes: Optional[Sequence] = None,
     project_id: Optional[str] = None,
-) -> Generator:
+) -> Generator[None, None, None]:
     """
     Context manager that provides both:
 
diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py
index f2c0d5157a..d9fe5daba5 100644
--- a/airflow/providers/google/common/hooks/base_google.py
+++ b/airflow/providers/google/common/hooks/base_google.py
@@ -25,7 +25,7 @@ import tempfile
 import warnings
 from contextlib import ExitStack, contextmanager
 from subprocess import check_output
-from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar, Union, cast
+from typing import Any, Callable, Dict, Generator, Optional, Sequence, Tuple, TypeVar, Union, cast
 
 import google.auth
 import google.auth.credentials
@@ -459,7 +459,7 @@ class GoogleBaseHook(BaseHook):
         return cast(T, wrapper)
 
     @contextmanager
-    def provide_gcp_credential_file_as_context(self):
+    def provide_gcp_credential_file_as_context(self) -> Generator[Optional[str], None, None]:
         """
         Context manager that provides a Google Cloud credentials for application supporting `Application
         Default Credentials (ADC) strategy <https://cloud.google.com/docs/authentication/production>`__.
@@ -467,8 +467,8 @@ class GoogleBaseHook(BaseHook):
         It can be used to provide credentials for external programs (e.g. gcloud) that expect authorization
         file in ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable.
         """
-        key_path = self._get_field('key_path', None)  # type: Optional[str]    #
-        keyfile_dict = self._get_field('keyfile_dict', None)  # type: Optional[Dict]
+        key_path: Optional[str] = self._get_field('key_path', None)
+        keyfile_dict: Optional[str] = self._get_field('keyfile_dict', None)
         if key_path and keyfile_dict:
             raise AirflowException(
                 "The `keyfile_dict` and `key_path` fields are mutually exclusive. "
@@ -490,7 +490,7 @@ class GoogleBaseHook(BaseHook):
             yield None
 
     @contextmanager
-    def provide_authorized_gcloud(self):
+    def provide_authorized_gcloud(self) -> Generator[None, None, None]:
         """
         Provides a separate gcloud configuration with current credentials.
 
diff --git a/airflow/providers/microsoft/psrp/hooks/psrp.py b/airflow/providers/microsoft/psrp/hooks/psrp.py
index 005f1e215d..0aebe63d03 100644
--- a/airflow/providers/microsoft/psrp/hooks/psrp.py
+++ b/airflow/providers/microsoft/psrp/hooks/psrp.py
@@ -19,7 +19,7 @@
 from contextlib import contextmanager
 from copy import copy
 from logging import DEBUG, ERROR, INFO, WARNING
-from typing import Any, Callable, Dict, Iterator, Optional
+from typing import Any, Callable, Dict, Generator, Optional
 from weakref import WeakKeyDictionary
 
 from pypsrp.host import PSHost
@@ -155,7 +155,7 @@ class PsrpHook(BaseHook):
         return pool
 
     @contextmanager
-    def invoke(self) -> Iterator[PowerShell]:
+    def invoke(self) -> Generator[PowerShell, None, None]:
         """
         Context manager that yields a PowerShell object to which commands can be
         added. Upon exit, the commands will be invoked.
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 03c0848e6c..f606dfc332 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -24,7 +24,7 @@ import time
 import warnings
 from dataclasses import dataclass
 from tempfile import gettempdir
-from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Generator, Iterable, List, Optional, Tuple, Union
 
 from sqlalchemy import Table, and_, column, exc, func, inspect, or_, select, table, text, tuple_
 from sqlalchemy.orm.session import Session
@@ -68,6 +68,7 @@ from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.version import version
 
 if TYPE_CHECKING:
+    from alembic.runtime.environment import EnvironmentContext
     from alembic.script import ScriptDirectory
     from sqlalchemy.orm import Query
 
@@ -709,7 +710,7 @@ def check_migrations(timeout):
 
 
 @contextlib.contextmanager
-def _configured_alembic_environment():
+def _configured_alembic_environment() -> Generator["EnvironmentContext", None, None]:
     from alembic.runtime.environment import EnvironmentContext
 
     config = _get_alembic_config()
@@ -1606,7 +1607,11 @@ class DBLocks(enum.IntEnum):
 
 
 @contextlib.contextmanager
-def create_global_lock(session: Session, lock: DBLocks, lock_timeout=1800):
+def create_global_lock(
+    session: Session,
+    lock: DBLocks,
+    lock_timeout: int = 1800,
+) -> Generator[None, None, None]:
     """Contextmanager that will create and teardown a global db lock."""
     conn = session.get_bind().connect()
     dialect = conn.dialect
diff --git a/airflow/utils/process_utils.py b/airflow/utils/process_utils.py
index fd63f3e959..2ec782df66 100644
--- a/airflow/utils/process_utils.py
+++ b/airflow/utils/process_utils.py
@@ -34,7 +34,7 @@ if not IS_WINDOWS:
     import pty
 
 from contextlib import contextmanager
-from typing import Dict, List, Optional
+from typing import Dict, Generator, List, Optional
 
 import psutil
 from lockfile.pidlockfile import PIDLockFile
@@ -258,7 +258,7 @@ def kill_child_processes_by_pids(pids_to_kill: List[int], timeout: int = 5) -> N
 
 
 @contextmanager
-def patch_environ(new_env_variables: Dict[str, str]):
+def patch_environ(new_env_variables: Dict[str, str]) -> Generator[None, None, None]:
     """
     Sets environment variables in context. After leaving the context, it restores its original state.
 
diff --git a/airflow/utils/session.py b/airflow/utils/session.py
index 3565e216a2..377ff55cbf 100644
--- a/airflow/utils/session.py
+++ b/airflow/utils/session.py
@@ -17,13 +17,13 @@
 import contextlib
 from functools import wraps
 from inspect import signature
-from typing import Callable, Iterator, TypeVar, cast
+from typing import Callable, Generator, TypeVar, cast
 
 from airflow import settings
 
 
 @contextlib.contextmanager
-def create_session() -> Iterator[settings.SASession]:
+def create_session() -> Generator[settings.SASession, None, None]:
     """Contextmanager that will create and teardown a session."""
     if not settings.Session:
         raise RuntimeError("Session must be set before!")
diff --git a/dev/breeze/src/airflow_breeze/utils/run_utils.py b/dev/breeze/src/airflow_breeze/utils/run_utils.py
index 86b84be4c0..03f3c0532d 100644
--- a/dev/breeze/src/airflow_breeze/utils/run_utils.py
+++ b/dev/breeze/src/airflow_breeze/utils/run_utils.py
@@ -25,7 +25,7 @@ from distutils.version import StrictVersion
 from functools import lru_cache
 from pathlib import Path
 from re import match
-from typing import Dict, List, Mapping, Optional, Union
+from typing import Dict, Generator, List, Mapping, Optional, Union
 
 from airflow_breeze.branch_defaults import AIRFLOW_BRANCH
 from airflow_breeze.params._common_build_params import _CommonBuildParams
@@ -213,7 +213,7 @@ def instruct_build_image(python: str):
 
 
 @contextlib.contextmanager
-def working_directory(source_path: Path):
+def working_directory(source_path: Path) -> Generator[None, None, None]:
     """
     # Equivalent of pushd and popd in bash script.
     # https://stackoverflow.com/a/42441759/3101838
diff --git a/dev/provider_packages/prepare_provider_packages.py b/dev/provider_packages/prepare_provider_packages.py
index 0091dee5c6..5ecf9a4850 100755
--- a/dev/provider_packages/prepare_provider_packages.py
+++ b/dev/provider_packages/prepare_provider_packages.py
@@ -38,7 +38,7 @@ from functools import lru_cache
 from os.path import dirname, relpath
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union
+from typing import Any, Dict, Generator, Iterable, List, NamedTuple, Optional, Set, Tuple, Union
 
 import jsonschema
 import rich_click as click
@@ -195,7 +195,7 @@ argument_package_ids = click.argument('package_ids', nargs=-1)
 
 
 @contextmanager
-def with_group(title):
+def with_group(title: str) -> Generator[None, None, None]:
     """
     If used in GitHub Action, creates an expandable group in the GitHub Action log.
     Otherwise, display simple text groups.

Reply via email to