This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6802605ee8d Fix mypy errors for sqla2 in aws hooks (#56751)
6802605ee8d is described below
commit 6802605ee8db4d6fe1efebed6b87a4664e3bcec2
Author: Niko Oliveira <[email protected]>
AuthorDate: Mon Oct 20 23:52:17 2025 -0700
Fix mypy errors for sqla2 in aws hooks (#56751)
---
.../providers/amazon/aws/hooks/athena_sql.py | 20 +++++++++-------
.../providers/amazon/aws/hooks/redshift_sql.py | 28 +++++++++++++++++-----
2 files changed, 33 insertions(+), 15 deletions(-)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py
index f0491153fbd..70d9a1f8f9b 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py
@@ -56,7 +56,7 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
.. note::
- get_uri() depends on SQLAlchemy and PyAthena.
+ get_uri() depends on SQLAlchemy and PyAthena
"""
conn_name_attr = "athena_conn_id"
@@ -155,14 +155,16 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
conn_params = self._get_conn_params()
creds = self.get_credentials(region_name=conn_params["region_name"])
- return URL.create(
- f"awsathena+{conn_params['driver']}",
- username=creds.access_key,
- password=creds.secret_key,
- host=f"athena.{conn_params['region_name']}.{conn_params['aws_domain']}",
- port=443,
- database=conn_params["schema_name"],
- query={"aws_session_token": creds.token, **self.conn.extra_dejson},
+ return str(
+ URL.create(
+ f"awsathena+{conn_params['driver']}",
+ username=creds.access_key,
+ password=creds.secret_key,
+ host=f"athena.{conn_params['region_name']}.{conn_params['aws_domain']}",
+ port=443,
+ database=conn_params["schema_name"],
+ query={"aws_session_token": creds.token, **self.conn.extra_dejson},
+ )
)
def get_conn(self) -> AthenaConnection:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py
index a4e3a76250a..660cdd942be 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -51,7 +51,7 @@ class RedshiftSQLHook(DbApiHook):
:ref:`Amazon Redshift connection id<howto/connection:redshift>`
.. note::
- get_sqlalchemy_engine() and get_uri() depend on sqlalchemy-amazon-redshift
+ get_sqlalchemy_engine() and get_uri() depend on sqlalchemy-amazon-redshift.
"""
conn_name_attr = "redshift_conn_id"
@@ -155,10 +155,23 @@ class RedshiftSQLHook(DbApiHook):
if "user" in conn_params:
conn_params["username"] = conn_params.pop("user")
- # Compatibility: The 'create' factory method was added in SQLAlchemy 1.4
- # to replace calling the default URL constructor directly.
- create_url = getattr(URL, "create", URL)
- return str(create_url(drivername="postgresql", **conn_params))
+ # Use URL.create for SQLAlchemy 2 compatibility
+ username = conn_params.get("username")
+ password = conn_params.get("password")
+ host = conn_params.get("host")
+ port = conn_params.get("port")
+ database = conn_params.get("database")
+
+ return str(
+ URL.create(
+ drivername="postgresql",
+ username=str(username) if username is not None else None,
+ password=str(password) if password is not None else None,
+ host=str(host) if host is not None else None,
+ port=int(port) if port is not None else None,
+ database=str(database) if database is not None else None,
+ )
+ )
def get_sqlalchemy_engine(self, engine_kwargs=None):
"""Overridden to pass Redshift-specific arguments."""
@@ -237,7 +250,10 @@ class RedshiftSQLHook(DbApiHook):
region_name = AwsBaseHook(aws_conn_id=self.aws_conn_id).region_name
identifier = f"{cluster_identifier}.{region_name}"
if not cluster_identifier:
- identifier = self._get_identifier_from_hostname(connection.host)
+ if connection.host:
identifier = self._get_identifier_from_hostname(connection.host)
+ else:
raise AirflowException("Host is required when cluster_identifier is not provided.")
return f"{identifier}:{port}"
def _get_identifier_from_hostname(self, hostname: str) -> str: