This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3ab2eb9884c fix(amazon): Filter kwargs in AthenaSQLHook to prevent TypeError (#62227)
3ab2eb9884c is described below
commit 3ab2eb9884c74696051e7617b9ebbd46315da7b5
Author: Shivam Rastogi <[email protected]>
AuthorDate: Thu Feb 26 01:00:25 2026 -0800
fix(amazon): Filter kwargs in AthenaSQLHook to prevent TypeError (#62227)
* fix(amazon): Filter kwargs in AthenaSQLHook to prevent TypeError
BaseSQLOperator.get_hook() passes all connection extra_dejson fields as
kwargs to the hook constructor. Athena-specific params like s3_staging_dir
and work_group are not accepted by AwsGenericHook.__init__(), causing a
TypeError. Use an allowlist to only forward valid AWS-generic params to
the parent class.
Closes: #55678
* fix: run prek
* fix(athena): Declare explicit constructor params in AthenaSQLHook
Replace kwargs filtering approach with explicit named parameters in
AthenaSQLHook.__init__ to match AwsGenericHook signature. This fixes
the TypeError from #55678 where BaseSQLOperator.get_hook() passed
connection extras like s3_staging_dir as kwargs.
The constructor now declares aws_conn_id, verify, region_name,
client_type, resource_type, and config as named params with types
matching the parent. Unexpected kwargs from connection extras are
absorbed by **kwargs and discarded.
Update ALLOWED_THICK_HOOKS_PARAMETERS in test_hooks_signature.py
to permit the newly declared params. Add config forwarding assertion
to existing test.
---
.../providers/amazon/aws/hooks/athena_sql.py | 27 ++++++++++++++++--
.../tests/unit/amazon/aws/hooks/test_athena_sql.py | 33 ++++++++++++++++++++++
.../unit/amazon/aws/hooks/test_hooks_signature.py | 10 ++++++-
3 files changed, 67 insertions(+), 3 deletions(-)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py
index d9cda93a45d..59a9191b61a 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py
@@ -33,6 +33,7 @@ from airflow.providers.common.compat.sdk import AirflowException, AirflowNotFoun
from airflow.providers.common.sql.hooks.sql import DbApiHook
if TYPE_CHECKING:
+ from botocore.config import Config
from pyathena.connection import Connection as AthenaConnection
@@ -69,8 +70,30 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
hook_name = "Amazon Athena"
supports_autocommit = True
- def __init__(self, athena_conn_id: str = default_conn_name, *args, **kwargs) -> None:
- super().__init__(*args, **kwargs)
+ def __init__(
+ self,
+ athena_conn_id: str = default_conn_name,
+ aws_conn_id: str | None = AwsBaseHook.default_conn_name,
+ verify: bool | str | None = None,
+ region_name: str | None = None,
+ client_type: str | None = None,
+ resource_type: str | None = None,
+ config: Config | dict[str, Any] | None = None,
+ **kwargs,
+ ) -> None:
+ # AwsGenericHook.__init__() only accepts the params declared above.
+ # Connection extras like s3_staging_dir and work_group are not
+ # constructor params — they are read later from the connection in
+ # get_conn(). BaseSQLOperator.get_hook() passes all connection extras
+ # as kwargs, so we absorb them here via **kwargs and discard them.
+ super().__init__(
+ aws_conn_id=aws_conn_id,
+ verify=verify,
+ region_name=region_name,
+ client_type=client_type,
+ resource_type=resource_type,
+ config=config,
+ )
self.athena_conn_id = athena_conn_id
@classmethod
diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py
index 3a814d41bbd..fc6fe82737c 100644
--- a/providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py
+++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py
@@ -148,3 +148,36 @@ class TestAthenaSQLHookConn:
hook = AthenaSQLHook(athena_conn_id=AWS_ATHENA_CONN_ID, aws_conn_id=AWS_CONN_ID)
assert hook.athena_conn_id == AWS_ATHENA_CONN_ID
assert hook.aws_conn_id == AWS_CONN_ID
+
+ def test_init_ignores_unexpected_kwargs(self):
+ """Verify that connection extras passed as kwargs don't crash the constructor.
+
+ BaseSQLOperator.get_hook() passes all connection extras as hook_params which
+ end up as constructor kwargs. Extras like s3_staging_dir and work_group are
+ not valid params for AwsGenericHook.__init__ and must be filtered out.
+ """
+ hook = AthenaSQLHook(
+ athena_conn_id="athena_conn",
+ s3_staging_dir="s3://mybucket/athena/",
+ work_group="primary",
+ region_name="eu-west-1",
+ driver="rest",
+ )
+ assert hook.athena_conn_id == "athena_conn"
+ # region_name is a valid AwsGenericHook param and should be passed through
+ assert hook._region_name == "eu-west-1"
+
+ def test_init_passes_valid_aws_kwargs(self):
+ """Verify that valid AwsGenericHook kwargs are still forwarded correctly."""
+ hook = AthenaSQLHook(
+ athena_conn_id="athena_conn",
+ aws_conn_id="custom_aws",
+ verify=False,
+ region_name="us-west-2",
+ config={"retries": {"max_attempts": 5}},
+ )
+ assert hook.athena_conn_id == "athena_conn"
+ assert hook.aws_conn_id == "custom_aws"
+ assert hook._verify is False
+ assert hook._region_name == "us-west-2"
+ assert hook._config is not None
diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_hooks_signature.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_hooks_signature.py
index 6031ac056a6..12cc46df792 100644
--- a/providers/amazon/tests/unit/amazon/aws/hooks/test_hooks_signature.py
+++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_hooks_signature.py
@@ -31,7 +31,15 @@ ALLOWED_THICK_HOOKS_PARAMETERS: dict[str, set[str]] = {
# This list should only be reduced not extended with new parameters,
# unless there is an exceptional reason.
"AthenaHook": {"sleep_time", "log_query"},
- "AthenaSQLHook": {"athena_conn_id"},
+ "AthenaSQLHook": {
+ "athena_conn_id",
+ "aws_conn_id",
+ "verify",
+ "region_name",
+ "client_type",
+ "resource_type",
+ "config",
+ },
"BatchClientHook": {"status_retries", "max_retries"},
"BatchWaitersHook": {"waiter_config"},
"DataSyncHook": {"wait_interval_seconds"},