This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new faa4cde34b Disallow create/extend thick-wrapped hooks in Amazon Provider (#35163)
faa4cde34b is described below

commit faa4cde34b6900a9f5765aff69abc34003b063fa
Author: Andrey Anshin <[email protected]>
AuthorDate: Tue Oct 31 00:56:07 2023 +0400

    Disallow create/extend thick-wrapped hooks in Amazon Provider (#35163)
---
 .../amazon/aws/hooks/test_hooks_signature.py       | 198 +++++++++++++++++++++
 1 file changed, 198 insertions(+)

diff --git a/tests/providers/amazon/aws/hooks/test_hooks_signature.py b/tests/providers/amazon/aws/hooks/test_hooks_signature.py
new file mode 100644
index 0000000000..9b4cfe9f6c
--- /dev/null
+++ b/tests/providers/amazon/aws/hooks/test_hooks_signature.py
@@ -0,0 +1,198 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import inspect
+from importlib import import_module
+from pathlib import Path
+
+import pytest
+
+from airflow.exceptions import AirflowOptionalProviderFeatureException
+from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+# Names of the base hook classes; entries with these names are skipped
+# when a hooks module is scanned for `AwsGenericHook` subclasses.
+BASE_AWS_HOOKS = ["AwsGenericHook", "AwsBaseHook"]
+# Allow-list of extra constructor parameters, keyed by hook class name.
+# Parameters listed here are subtracted from a hook's signature before
+# the signature check decides whether the hook is reported.
+ALLOWED_THICK_HOOKS_PARAMETERS: dict[str, set[str]] = {
+    # This list should only be reduced, not extended with new parameters,
+    # unless there is an exceptional reason.
+    "AthenaHook": {"sleep_time", "log_query"},
+    "BatchClientHook": {"status_retries", "max_retries"},
+    "BatchWaitersHook": {"waiter_config"},
+    "DataSyncHook": {"wait_interval_seconds"},
+    "DynamoDBHook": {"table_name", "table_keys"},
+    "EC2Hook": {"api_type"},
+    "ElastiCacheReplicationGroupHook": {
+        "exponential_back_off_factor",
+        "max_retries",
+        "initial_poke_interval",
+    },
+    "EmrHook": {"emr_conn_id"},
+    "EmrContainerHook": {"virtual_cluster_id"},
+    "FirehoseHook": {"delivery_stream"},
+    "GlueJobHook": {
+        "job_name",
+        "concurrent_run_limit",
+        "job_poll_interval",
+        "create_job_kwargs",
+        "desc",
+        "iam_role_arn",
+        "s3_bucket",
+        "iam_role_name",
+        "update_config",
+        "retry_limit",
+        "num_of_dpus",
+        "script_location",
+    },
+    "S3Hook": {"transfer_config_args", "aws_conn_id", "extra_args"},
+}
+
+
+def get_aws_hooks_modules():
+    """
+    Yield one ``pytest.param`` per public module of the Amazon Provider hooks package.
+
+    The hooks directory is resolved relative to this test file. Module files whose
+    name starts with an underscore are skipped. Modules are yielded in sorted order
+    so that the generated test ids are always produced in the same sequence.
+
+    :raises FileNotFoundError: if the hooks directory does not exist.
+    :raises NotADirectoryError: if the hooks path exists but is not a directory.
+    """
+    hooks_dir = Path(__file__).absolute().parents[5] / "airflow" / "providers" / "amazon" / "aws" / "hooks"
+    if not hooks_dir.exists():
+        msg = f"Amazon Provider hooks directory not found: {str(hooks_dir)!r}"
+        raise FileNotFoundError(msg)
+    if not hooks_dir.is_dir():
+        raise NotADirectoryError(str(hooks_dir))
+
+    # Directory listing order depends on the filesystem; sort it so that test
+    # parametrization (and therefore collection order) is deterministic.
+    for module in sorted(hooks_dir.glob("*.py")):
+        name = module.stem
+        if name.startswith("_"):
+            continue
+        module_string = f"airflow.providers.amazon.aws.hooks.{name}"
+
+        yield pytest.param(module_string, id=name)
+
+
+def get_aws_hooks_from_module(hook_module: str) -> list[tuple[type[AwsGenericHook], str]]:
+    """
+    Import ``hook_module`` and collect the ``AwsGenericHook`` subclasses in its namespace.
+
+    Entries whose name is listed in ``BASE_AWS_HOOKS`` are ignored. If the import raises
+    ``AirflowOptionalProviderFeatureException`` the running test is skipped instead.
+
+    :param hook_module: Dotted path of the module to import.
+    :return: List of ``(hook class, name it is bound to in the module)`` pairs.
+    """
+    try:
+        imported_module = import_module(hook_module)
+    except AirflowOptionalProviderFeatureException as ex:
+        # ``pytest.skip`` raises, so nothing below runs when the import is not possible.
+        pytest.skip(str(ex))
+
+    return [
+        (candidate, attr_name)
+        for attr_name, candidate in vars(imported_module).items()
+        if attr_name not in BASE_AWS_HOOKS
+        and isinstance(candidate, type)
+        and candidate.__module__ != "builtins"
+        and issubclass(candidate, AwsGenericHook)
+    ]
+
+
+def validate_hook(hook: type[AwsGenericHook], hook_name: str, hook_module: str) -> tuple[bool, str | None]:
+    """
+    Check that the constructor of ``hook`` declares no unexpected named parameters.
+
+    ``*args`` / ``**kwargs`` are always accepted. Any other parameter must be listed
+    for ``hook_name`` in ``ALLOWED_THICK_HOOKS_PARAMETERS``.
+
+    :param hook: Hook class whose signature is inspected.
+    :param hook_name: Hook name, used for the allow-list lookup and in the message.
+    :param hook_module: Module the hook was found in; only used in the message.
+    :return: ``(True, None)`` when the signature is acceptable, otherwise
+        ``(False, message)`` with a description of the offending parameters.
+    """
+    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
+    declared_parameters = {
+        name
+        for name, parameter in inspect.signature(hook).parameters.items()
+        if parameter.kind not in variadic_kinds
+    }
+    declared_parameters.discard("self")
+
+    allowed_parameters = ALLOWED_THICK_HOOKS_PARAMETERS.get(hook_name, set())
+    # Remove historically allowed parameters for Thick Wrapped Hooks
+    hook_extra_parameters = declared_parameters - allowed_parameters
+
+    if not hook_extra_parameters:
+        # No additional arguments found
+        return True, None
+
+    # Sets have no stable iteration order; sort so the failure message is reproducible.
+    extra_repr = ", ".join(map(repr, sorted(hook_extra_parameters)))
+    if not allowed_parameters:
+        msg = (
+            f"'{hook_module}.{hook_name}' has additional attributes "
+            f"{extra_repr}. "
+            "Expected that all `boto3` related hooks (based on `AwsGenericHook` or `AwsBaseHook`) "
+            "should not use additional attributes in class constructor, "
+            "please move them to method signatures. "
+            f"Make sure that {hook_name!r} constructor has signature `def __init__(self, *args, **kwargs):`"
+        )
+    else:
+        allowed_repr = ", ".join(map(repr, sorted(allowed_parameters)))
+        msg = (
+            f"'{hook_module}.{hook_name}' allowed only "
+            f"{allowed_repr} additional attributes, "
+            f"but got extra parameters {extra_repr}. "
+            "Please move additional attributes from class constructor into method signatures. "
+        )
+
+    return False, msg
+
+
+@pytest.mark.parametrize("hook_module", get_aws_hooks_modules())
+def test_expected_thin_hooks(hook_module: str):
+    """
+    Test Amazon provider Hooks' signatures.
+
+    All hooks should provide a thin wrapper around boto3 / aiobotocore,
+    which means we should not define additional parameters in the Hook constructor.
+    They should be defined in the appropriate methods instead.
+
+    .. code-block:: python
+
+        # Bad: Thick wrapper
+        from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+        class AwsServiceName(AwsBaseHook):
+            def __init__(self, foo: str, spam: str, *args, **kwargs) -> None:
+                kwargs.update(dict(client_type="service", resource_type=None))
+                super().__init__(*args, **kwargs)
+                self.foo = foo
+                self.spam = spam
+
+            def method1(self):
+                if self.foo == "bar":
+                    ...
+
+            def method2(self):
+                if self.spam == "egg":
+                    ...
+
+    .. code-block:: python
+
+        # Good: Thin wrapper
+        class AwsServiceName(AwsBaseHook):
+            def __init__(self, *args, **kwargs) -> None:
+                kwargs.update(dict(client_type="service", resource_type=None))
+                super().__init__(*args, **kwargs)
+
+            def method1(self, foo: str):
+                if foo == "bar":
+                    ...
+
+            def method2(self, spam: str):
+                if spam == "egg":
+                    ...
+
+    """
+    hooks = get_aws_hooks_from_module(hook_module)
+    if not hooks:
+        pytest.skip(reason=f"Module {hook_module!r} doesn't contain subclasses of `AwsGenericHook`.")
+
+    # Validate every hook first so that a single failure reports all offenders of the module.
+    errors = []
+    for hook, hook_name in hooks:
+        is_valid, msg = validate_hook(hook, hook_name, hook_module)
+        if not is_valid:
+            errors.append(msg)
+
+    if errors:
+        errors_msg = "\n * ".join(errors)
+        pytest.fail(reason=f"Found errors in {hook_module}:\n * {errors_msg}")

Reply via email to