This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-3-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit ce0caa0038d0d56d1850e2d0217e07b26bceca24
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Mon Jul 18 17:20:35 2022 +0800

    Bump typing-extensions and mypy for ParamSpec (#25088)

    * Bump typing-extensions and mypy for ParamSpec

    I want to use them in some @task signature improvements. Mypy added
    this in 0.950, but let's just bump to latest since why not.

    Changelog of typing-extensions is spotty before 4.0, but ParamSpec was
    introduced some time before that (likely some time in 2021), and it
    seems to be a reasonable minimum to bump to.

    For more about ParamSpec, read PEP 612:
    https://peps.python.org/pep-0612/

    (cherry picked from commit e32e9c58802fe9363cc87ea283a59218df7cec3a)
---
 airflow/jobs/scheduler_job.py                      |  4 +-
 airflow/mypy/plugin/decorators.py                  |  5 +-
 .../amazon/aws/transfers/dynamodb_to_s3.py         |  1 +
 .../providers/amazon/aws/transfers/sql_to_s3.py    | 19 ++++---
 .../providers/google/cloud/operators/cloud_sql.py  |  2 +-
 airflow/providers/microsoft/azure/hooks/cosmos.py  | 62 +++++++++++++---------
 airflow/utils/context.py                           |  6 +--
 .../airflow_breeze/commands/testing_commands.py    |  8 +--
 scripts/in_container/run_migration_reference.py    |  1 +
 setup.cfg                                          |  2 +-
 setup.py                                           |  2 +-
 .../microsoft/azure/hooks/test_azure_cosmos.py     |  8 ++-
 12 files changed, 73 insertions(+), 47 deletions(-)

diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 3440832275..3613b9be47 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -170,7 +170,7 @@ class SchedulerJob(BaseJob):
         signal.signal(signal.SIGTERM, self._exit_gracefully)
         signal.signal(signal.SIGUSR2, self._debug_dump)
 
-    def _exit_gracefully(self, signum, frame) -> None:
+    def _exit_gracefully(self, signum: int, frame) -> None:
         """Helper method to clean up processor_agent to avoid leaving orphan processes."""
         if not _is_parent_process():
             # Only the parent process should perform the cleanup.
@@ -181,7 +181,7 @@ class SchedulerJob(BaseJob):
             self.processor_agent.end()
         sys.exit(os.EX_OK)
 
-    def _debug_dump(self, signum, frame):
+    def _debug_dump(self, signum: int, frame) -> None:
         if not _is_parent_process():
             # Only the parent process should perform the debug dump.
             return
diff --git a/airflow/mypy/plugin/decorators.py b/airflow/mypy/plugin/decorators.py
index 76f1af54cd..32e1113876 100644
--- a/airflow/mypy/plugin/decorators.py
+++ b/airflow/mypy/plugin/decorators.py
@@ -68,7 +68,10 @@ def _change_decorator_function_type(
     # Mark provided arguments as optional
     decorator.arg_types = copy.copy(decorated.arg_types)
     for argument in provided_arguments:
-        index = decorated.arg_names.index(argument)
+        try:
+            index = decorated.arg_names.index(argument)
+        except ValueError:
+            continue
         decorated_type = decorated.arg_types[index]
         decorator.arg_types[index] = UnionType.make_union([decorated_type, NoneType()])
         decorated.arg_kinds[index] = ARG_NAMED_OPT
diff --git a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
index a6f5f8da21..218f4dc16c 100644
--- a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
@@ -114,6 +114,7 @@ class DynamoDBToS3Operator(BaseOperator):
 
         scan_kwargs = copy(self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {}
         err = None
+        f: IO[Any]
         with NamedTemporaryFile() as f:
             try:
                 f = self._scan_dynamodb_and_upload_to_s3(f, scan_kwargs, table)
diff --git a/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
index f399c27141..d9bebf5a39 100644
--- a/airflow/providers/amazon/aws/transfers/sql_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -16,8 +16,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import enum
 from collections import namedtuple
-from enum import Enum
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
 
@@ -35,10 +35,13 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-FILE_FORMAT = Enum(
-    "FILE_FORMAT",
-    "CSV, JSON, PARQUET",
-)
+class FILE_FORMAT(enum.Enum):
+    """Possible file formats."""
+
+    CSV = enum.auto()
+    JSON = enum.auto()
+    PARQUET = enum.auto()
+
 
 FileOptions = namedtuple('FileOptions', ['mode', 'suffix', 'function'])
 
@@ -118,9 +121,9 @@ class SqlToS3Operator(BaseOperator):
         if "path_or_buf" in self.pd_kwargs:
             raise AirflowException('The argument path_or_buf is not allowed, please remove it')
 
-        self.file_format = getattr(FILE_FORMAT, file_format.upper(), None)
-
-        if self.file_format is None:
+        try:
+            self.file_format = FILE_FORMAT[file_format.upper()]
+        except KeyError:
             raise AirflowException(f"The argument file_format doesn't support {file_format} value.")
 
     @staticmethod
diff --git a/airflow/providers/google/cloud/operators/cloud_sql.py b/airflow/providers/google/cloud/operators/cloud_sql.py
index 1441f518b4..fb5a88593e 100644
--- a/airflow/providers/google/cloud/operators/cloud_sql.py
+++ b/airflow/providers/google/cloud/operators/cloud_sql.py
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
 SETTINGS = 'settings'
 SETTINGS_VERSION = 'settingsVersion'
 
-CLOUD_SQL_CREATE_VALIDATION = [
+CLOUD_SQL_CREATE_VALIDATION: Sequence[dict] = [
     dict(name="name", allow_empty=False),
     dict(
         name="settings",
diff --git a/airflow/providers/microsoft/azure/hooks/cosmos.py b/airflow/providers/microsoft/azure/hooks/cosmos.py
index ed475978b0..954b584846 100644
--- a/airflow/providers/microsoft/azure/hooks/cosmos.py
+++ b/airflow/providers/microsoft/azure/hooks/cosmos.py
@@ -23,6 +23,7 @@ Airflow connection of type `azure_cosmos` exists.
 Authorization can be done by supplying a login (=Endpoint uri), password (=secret key) and extra
 fields database_name and collection_name to specify the default database and collection to use
 (see connection `azure_cosmos_default` for an example).
 """
+import json
 import uuid
 from typing import Any, Dict, Optional
@@ -140,14 +141,22 @@ class AzureCosmosDBHook(BaseHook):
         existing_container = list(
             self.get_conn()
             .get_database_client(self.__get_database_name(database_name))
-            .query_containers("SELECT * FROM r WHERE r.id=@id", [{"name": "@id", "value": collection_name}])
+            .query_containers(
+                "SELECT * FROM r WHERE r.id=@id",
+                parameters=[json.dumps({"name": "@id", "value": collection_name})],
+            )
         )
         if len(existing_container) == 0:
             return False
 
         return True
 
-    def create_collection(self, collection_name: str, database_name: Optional[str] = None) -> None:
+    def create_collection(
+        self,
+        collection_name: str,
+        database_name: Optional[str] = None,
+        partition_key: Optional[str] = None,
+    ) -> None:
         """Creates a new collection in the CosmosDB database."""
         if collection_name is None:
             raise AirflowBadRequest("Collection name cannot be None.")
@@ -157,13 +166,16 @@ class AzureCosmosDBHook(BaseHook):
         existing_container = list(
             self.get_conn()
             .get_database_client(self.__get_database_name(database_name))
-            .query_containers("SELECT * FROM r WHERE r.id=@id", [{"name": "@id", "value": collection_name}])
+            .query_containers(
+                "SELECT * FROM r WHERE r.id=@id",
+                parameters=[json.dumps({"name": "@id", "value": collection_name})],
+            )
         )
 
         # Only create if we did not find it already existing
         if len(existing_container) == 0:
             self.get_conn().get_database_client(self.__get_database_name(database_name)).create_container(
-                collection_name
+                collection_name, partition_key=partition_key
             )
 
     def does_database_exist(self, database_name: str) -> bool:
@@ -173,10 +185,8 @@ class AzureCosmosDBHook(BaseHook):
 
         existing_database = list(
             self.get_conn().query_databases(
-                {
-                    "query": "SELECT * FROM r WHERE r.id=@id",
-                    "parameters": [{"name": "@id", "value": database_name}],
-                }
+                "SELECT * FROM r WHERE r.id=@id",
+                parameters=[json.dumps({"name": "@id", "value": database_name})],
             )
         )
         if len(existing_database) == 0:
@@ -193,10 +203,8 @@ class AzureCosmosDBHook(BaseHook):
         # to create it twice
         existing_database = list(
             self.get_conn().query_databases(
-                {
-                    "query": "SELECT * FROM r WHERE r.id=@id",
-                    "parameters": [{"name": "@id", "value": database_name}],
-                }
+                "SELECT * FROM r WHERE r.id=@id",
+                parameters=[json.dumps({"name": "@id", "value": database_name})],
             )
         )
 
@@ -267,18 +275,28 @@ class AzureCosmosDBHook(BaseHook):
         return created_documents
 
     def delete_document(
-        self, document_id: str, database_name: Optional[str] = None, collection_name: Optional[str] = None
+        self,
+        document_id: str,
+        database_name: Optional[str] = None,
+        collection_name: Optional[str] = None,
+        partition_key: Optional[str] = None,
     ) -> None:
         """Delete an existing document out of a collection in the CosmosDB database."""
         if document_id is None:
             raise AirflowBadRequest("Cannot delete a document without an id")
-
-        self.get_conn().get_database_client(self.__get_database_name(database_name)).get_container_client(
-            self.__get_collection_name(collection_name)
-        ).delete_item(document_id)
+        (
+            self.get_conn()
+            .get_database_client(self.__get_database_name(database_name))
+            .get_container_client(self.__get_collection_name(collection_name))
+            .delete_item(document_id, partition_key=partition_key)
+        )
 
     def get_document(
-        self, document_id: str, database_name: Optional[str] = None, collection_name: Optional[str] = None
+        self,
+        document_id: str,
+        database_name: Optional[str] = None,
+        collection_name: Optional[str] = None,
+        partition_key: Optional[str] = None,
     ):
         """Get a document from an existing collection in the CosmosDB database."""
         if document_id is None:
@@ -289,7 +307,7 @@ class AzureCosmosDBHook(BaseHook):
                 self.get_conn()
                 .get_database_client(self.__get_database_name(database_name))
                 .get_container_client(self.__get_collection_name(collection_name))
-                .read_item(document_id)
+                .read_item(document_id, partition_key=partition_key)
             )
         except CosmosHttpResponseError:
             return None
@@ -305,17 +323,13 @@ class AzureCosmosDBHook(BaseHook):
         if sql_string is None:
             raise AirflowBadRequest("SQL query string cannot be None")
 
-        # Query them in SQL
-        query = {'query': sql_string}
-
         try:
             result_iterable = (
                 self.get_conn()
                 .get_database_client(self.__get_database_name(database_name))
                 .get_container_client(self.__get_collection_name(collection_name))
-                .query_items(query, partition_key)
+                .query_items(sql_string, partition_key=partition_key)
             )
-
             return list(result_iterable)
         except CosmosHttpResponseError:
             return None
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index 04dababa24..ffa0e6b95c 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -23,12 +23,12 @@ import copy
 import functools
 import warnings
 from typing import (
-    AbstractSet,
     Any,
     Container,
     Dict,
     ItemsView,
     Iterator,
+    KeysView,
     List,
     Mapping,
     MutableMapping,
@@ -175,7 +175,7 @@ class Context(MutableMapping[str, Any]):
     }
 
     def __init__(self, context: Optional[MutableMapping[str, Any]] = None, **kwargs: Any) -> None:
-        self._context = context or {}
+        self._context: MutableMapping[str, Any] = context or {}
         if kwargs:
             self._context.update(kwargs)
         self._deprecation_replacements = self._DEPRECATION_REPLACEMENTS.copy()
@@ -231,7 +231,7 @@ class Context(MutableMapping[str, Any]):
             return NotImplemented
         return self._context != other._context
 
-    def keys(self) -> AbstractSet[str]:
+    def keys(self) -> KeysView[str]:
         return self._context.keys()
 
     def items(self):
diff --git a/dev/breeze/src/airflow_breeze/commands/testing_commands.py b/dev/breeze/src/airflow_breeze/commands/testing_commands.py
index b53333ea64..05aa3aa7e8 100644
--- a/dev/breeze/src/airflow_breeze/commands/testing_commands.py
+++ b/dev/breeze/src/airflow_breeze/commands/testing_commands.py
@@ -197,9 +197,9 @@ def run_with_progress(
 ) -> RunCommandResult:
     title = f"Running tests: {test_type}, Python: {python}, Backend: {backend}:{version}"
     try:
-        with tempfile.NamedTemporaryFile(mode='w+t', delete=False) as f:
+        with tempfile.NamedTemporaryFile(mode='w+t', delete=False) as tf:
             get_console().print(f"[info]Starting test = {title}[/]")
-            thread = MonitoringThread(title=title, file_name=f.name)
+            thread = MonitoringThread(title=title, file_name=tf.name)
             thread.start()
             try:
                 result = run_command(
@@ -208,14 +208,14 @@ def run_with_progress(
                     dry_run=dry_run,
                     env=env_variables,
                     check=False,
-                    stdout=f,
+                    stdout=tf,
                     stderr=subprocess.STDOUT,
                 )
             finally:
                 thread.stop()
                 thread.join()
         with ci_group(f"Result of {title}", message_type=message_type_from_return_code(result.returncode)):
-            with open(f.name) as f:
+            with open(tf.name) as f:
                 shutil.copyfileobj(f, sys.stdout)
     finally:
         os.unlink(f.name)
diff --git a/scripts/in_container/run_migration_reference.py b/scripts/in_container/run_migration_reference.py
index cc05408c2a..12ff265c55 100755
--- a/scripts/in_container/run_migration_reference.py
+++ b/scripts/in_container/run_migration_reference.py
@@ -102,6 +102,7 @@ def revision_suffix(rev: "Script"):
 
 def ensure_airflow_version(revisions: Iterable["Script"]):
     for rev in revisions:
+        assert rev.module.__file__ is not None  # For Mypy.
         file = Path(rev.module.__file__)
         content = file.read_text()
         if not has_version(content):
diff --git a/setup.cfg b/setup.cfg
index 0e1e9f7b84..2c96a0de42 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -146,7 +146,7 @@ install_requires =
     tabulate>=0.7.5
     tenacity>=6.2.0
     termcolor>=1.1.0
-    typing-extensions>=3.7.4
+    typing-extensions>=4.0.0
     unicodecsv>=0.14.1
     werkzeug>=2.0
 
diff --git a/setup.py b/setup.py
index 4d5dbd1bb8..6447281e5d 100644
--- a/setup.py
+++ b/setup.py
@@ -578,7 +578,7 @@ zendesk = [
 # mypyd which does not support installing the types dynamically with --install-types
 mypy_dependencies = [
     # TODO: upgrade to newer versions of MyPy continuously as they are released
-    'mypy==0.910',
+    'mypy==0.950',
     'types-boto',
     'types-certifi',
     'types-croniter',
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
index b407fbdb3c..e157a5276b 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
@@ -90,7 +90,9 @@ class TestAzureCosmosDbHook(unittest.TestCase):
         hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
         hook.create_collection(self.test_collection_name, self.test_database_name)
         expected_calls = [
-            mock.call().get_database_client('test_database_name').create_container('test_collection_name')
+            mock.call()
+            .get_database_client('test_database_name')
+            .create_container('test_collection_name', partition_key=None)
         ]
         mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
         mock_cosmos.assert_has_calls(expected_calls)
@@ -100,7 +102,9 @@ class TestAzureCosmosDbHook(unittest.TestCase):
         hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
         hook.create_collection(self.test_collection_name)
         expected_calls = [
-            mock.call().get_database_client('test_database_name').create_container('test_collection_name')
+            mock.call()
+            .get_database_client('test_database_name')
+            .create_container('test_collection_name', partition_key=None)
         ]
         mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
         mock_cosmos.assert_has_calls(expected_calls)
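For context on the commit message above, here is a minimal sketch of the decorator
typing that ParamSpec (PEP 612) enables. It is illustrative only and not part of the
patch: the decorator name `task` and the function `add` are hypothetical, and it
assumes typing-extensions>=4.0.0 (for Python < 3.10) and mypy>=0.950, matching the
pins bumped above.

# Minimal ParamSpec sketch (illustrative, not from this patch).
# ParamSpec lets a decorator forward the wrapped callable's exact
# parameter types, so mypy still type-checks call sites of the
# decorated function instead of degrading to (*args, **kwargs).
from typing import Callable, TypeVar

from typing_extensions import ParamSpec

P = ParamSpec("P")
R = TypeVar("R")


def task(func: Callable[P, R]) -> Callable[P, R]:
    """Hypothetical decorator that preserves func's signature for mypy."""

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        # Forward everything unchanged; P.args/P.kwargs keep the types.
        return func(*args, **kwargs)

    return wrapper


@task
def add(x: int, y: int) -> int:
    return x + y


add(1, 2)  # OK: mypy sees (int, int) -> int through the decorator.
# add("a", 2) would be rejected by mypy, since the signature is preserved.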
