This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3ab2eb9884c fix(amazon): Filter kwargs in AthenaSQLHook to prevent TypeError (#62227)
3ab2eb9884c is described below

commit 3ab2eb9884c74696051e7617b9ebbd46315da7b5
Author: Shivam Rastogi <[email protected]>
AuthorDate: Thu Feb 26 01:00:25 2026 -0800

    fix(amazon): Filter kwargs in AthenaSQLHook to prevent TypeError (#62227)
    
    * fix(amazon): Filter kwargs in AthenaSQLHook to prevent TypeError
    
    BaseSQLOperator.get_hook() passes all connection extra_dejson fields as
    kwargs to the hook constructor. Athena-specific params like s3_staging_dir
    and work_group are not accepted by AwsGenericHook.__init__(), causing a
    TypeError. Use an allowlist to only forward valid AWS-generic params to
    the parent class.
    
    Closes: #55678
    
    * fix: run prek
    
    * fix(athena): Declare explicit constructor params in AthenaSQLHook
    
    Replace kwargs filtering approach with explicit named parameters in
    AthenaSQLHook.__init__ to match AwsGenericHook signature. This fixes
    the TypeError from #55678 where BaseSQLOperator.get_hook() passed
    connection extras like s3_staging_dir as kwargs.
    
    The constructor now declares aws_conn_id, verify, region_name,
    client_type, resource_type, and config as named params with types
    matching the parent. Unexpected kwargs from connection extras are
    absorbed by **kwargs and discarded.
    
    Update ALLOWED_THICK_HOOKS_PARAMETERS in test_hooks_signature.py
    to permit the newly declared params. Add config forwarding assertion
    to existing test.
---
 .../providers/amazon/aws/hooks/athena_sql.py       | 27 ++++++++++++++++--
 .../tests/unit/amazon/aws/hooks/test_athena_sql.py | 33 ++++++++++++++++++++++
 .../unit/amazon/aws/hooks/test_hooks_signature.py  | 10 ++++++-
 3 files changed, 67 insertions(+), 3 deletions(-)

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py
index d9cda93a45d..59a9191b61a 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py
@@ -33,6 +33,7 @@ from airflow.providers.common.compat.sdk import AirflowException, AirflowNotFoun
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 
 if TYPE_CHECKING:
+    from botocore.config import Config
     from pyathena.connection import Connection as AthenaConnection
 
 
@@ -69,8 +70,30 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
     hook_name = "Amazon Athena"
     supports_autocommit = True
 
-    def __init__(self, athena_conn_id: str = default_conn_name, *args, 
**kwargs) -> None:
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        athena_conn_id: str = default_conn_name,
+        aws_conn_id: str | None = AwsBaseHook.default_conn_name,
+        verify: bool | str | None = None,
+        region_name: str | None = None,
+        client_type: str | None = None,
+        resource_type: str | None = None,
+        config: Config | dict[str, Any] | None = None,
+        **kwargs,
+    ) -> None:
+        # AwsGenericHook.__init__() only accepts the params declared above.
+        # Connection extras like s3_staging_dir and work_group are not
+        # constructor params — they are read later from the connection in
+        # get_conn(). BaseSQLOperator.get_hook() passes all connection extras
+        # as kwargs, so we absorb them here via **kwargs and discard them.
+        super().__init__(
+            aws_conn_id=aws_conn_id,
+            verify=verify,
+            region_name=region_name,
+            client_type=client_type,
+            resource_type=resource_type,
+            config=config,
+        )
         self.athena_conn_id = athena_conn_id
 
     @classmethod
diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py
index 3a814d41bbd..fc6fe82737c 100644
--- a/providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py
+++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py
@@ -148,3 +148,36 @@ class TestAthenaSQLHookConn:
         hook = AthenaSQLHook(athena_conn_id=AWS_ATHENA_CONN_ID, 
aws_conn_id=AWS_CONN_ID)
         assert hook.athena_conn_id == AWS_ATHENA_CONN_ID
         assert hook.aws_conn_id == AWS_CONN_ID
+
+    def test_init_ignores_unexpected_kwargs(self):
+        """Verify that connection extras passed as kwargs don't crash the 
constructor.
+
+        BaseSQLOperator.get_hook() passes all connection extras as hook_params 
which
+        end up as constructor kwargs. Extras like s3_staging_dir and 
work_group are
+        not valid params for AwsGenericHook.__init__ and must be filtered out.
+        """
+        hook = AthenaSQLHook(
+            athena_conn_id="athena_conn",
+            s3_staging_dir="s3://mybucket/athena/",
+            work_group="primary",
+            region_name="eu-west-1",
+            driver="rest",
+        )
+        assert hook.athena_conn_id == "athena_conn"
+        # region_name is a valid AwsGenericHook param and should be passed 
through
+        assert hook._region_name == "eu-west-1"
+
+    def test_init_passes_valid_aws_kwargs(self):
+        """Verify that valid AwsGenericHook kwargs are still forwarded 
correctly."""
+        hook = AthenaSQLHook(
+            athena_conn_id="athena_conn",
+            aws_conn_id="custom_aws",
+            verify=False,
+            region_name="us-west-2",
+            config={"retries": {"max_attempts": 5}},
+        )
+        assert hook.athena_conn_id == "athena_conn"
+        assert hook.aws_conn_id == "custom_aws"
+        assert hook._verify is False
+        assert hook._region_name == "us-west-2"
+        assert hook._config is not None
diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_hooks_signature.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_hooks_signature.py
index 6031ac056a6..12cc46df792 100644
--- a/providers/amazon/tests/unit/amazon/aws/hooks/test_hooks_signature.py
+++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_hooks_signature.py
@@ -31,7 +31,15 @@ ALLOWED_THICK_HOOKS_PARAMETERS: dict[str, set[str]] = {
     # This list should only be reduced not extended with new parameters,
     # unless there is an exceptional reason.
     "AthenaHook": {"sleep_time", "log_query"},
-    "AthenaSQLHook": {"athena_conn_id"},
+    "AthenaSQLHook": {
+        "athena_conn_id",
+        "aws_conn_id",
+        "verify",
+        "region_name",
+        "client_type",
+        "resource_type",
+        "config",
+    },
     "BatchClientHook": {"status_retries", "max_retries"},
     "BatchWaitersHook": {"waiter_config"},
     "DataSyncHook": {"wait_interval_seconds"},

Reply via email to