This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f41ec5f65c4 Add validation for missing host and cluster/workgroup
identifier in AWS IAM token retrieval (#61965)
f41ec5f65c4 is described below
commit f41ec5f65c4948e5103a5a534fdf5f145562f38c
Author: Justin Pakzad <[email protected]>
AuthorDate: Wed Mar 11 21:18:04 2026 -0400
Add validation for missing host and cluster/workgroup identifier in AWS IAM
token retrieval (#61965)
* Add validation for missing host and cluster/workgroup identifier in AWS
IAM token retrieval
* Fix mypy type-checking errors
* Make the handling of the cluster identifier more explicit
---
.../airflow/providers/postgres/hooks/postgres.py | 18 +++-
.../tests/unit/postgres/hooks/test_postgres.py | 108 ++++++++++++---------
2 files changed, 76 insertions(+), 50 deletions(-)
diff --git
a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
index 8e3edabb94a..d38171bbe21 100644
--- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
+++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
@@ -483,9 +483,13 @@ class PostgresHook(DbApiHook):
port = conn.port or 5439
# Pull the custer-identifier from the beginning of the Redshift URL
# ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com
returns my-cluster
- cluster_identifier = conn.extra_dejson.get(
- "cluster-identifier", cast("str", conn.host).split(".")[0]
- )
+ cluster_identifier = conn.extra_dejson.get("cluster-identifier")
+ if cluster_identifier is None:
+ if not conn.host:
+ raise ValueError(
+ "connection host is required for AWS IAM token when
cluster-identifier is not set in extras."
+ )
+ cluster_identifier = conn.host.split(".")[0]
redshift_client = AwsBaseHook(aws_conn_id=aws_conn_id,
client_type="redshift").conn
#
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift/client/get_cluster_credentials.html#Redshift.Client.get_cluster_credentials
cluster_creds = redshift_client.get_cluster_credentials(
@@ -501,7 +505,13 @@ class PostgresHook(DbApiHook):
# Pull the workgroup-name from the query params/extras, if not
there then pull it from the
# beginning of the Redshift URL
# ex. workgroup-name.ccdre4hpd39h.us-east-1.redshift.amazonaws.com
returns workgroup-name
- workgroup_name = conn.extra_dejson.get("workgroup-name",
cast("str", conn.host).split(".")[0])
+ workgroup_name = conn.extra_dejson.get("workgroup-name")
+ if workgroup_name is None:
+ if not conn.host:
+ raise ValueError(
+ "connection host is required for AWS IAM token when
workgroup-name is not set in extras."
+ )
+ workgroup_name = conn.host.split(".")[0]
redshift_serverless_client = AwsBaseHook(
aws_conn_id=aws_conn_id, client_type="redshift-serverless"
).conn
diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
index f4570f6563b..6cad536c07b 100644
--- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
+++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
@@ -304,18 +304,22 @@ class TestPostgresHookConn:
@pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
@pytest.mark.parametrize("port", [5432, 5439, None])
@pytest.mark.parametrize(
- ("host", "conn_cluster_identifier", "expected_cluster_identifier"),
+ ("host", "conn_cluster_identifier", "expected_cluster_identifier",
"raises_exception"),
[
(
"cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
NOTSET,
"cluster-identifier",
+ False,
),
(
"cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
"different-identifier",
"different-identifier",
+ False,
),
+ (None, NOTSET, None, True),
+ (None, "cluster-identifier", "cluster-identifier", False),
],
)
def test_get_conn_rds_iam_redshift(
@@ -327,6 +331,7 @@ class TestPostgresHookConn:
host,
conn_cluster_identifier,
expected_cluster_identifier,
+ raises_exception,
):
mock_aws_hook_class =
mocker.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook")
@@ -353,45 +358,52 @@ class TestPostgresHookConn:
"DbUser": mock_db_user,
}
type(mock_aws_hook_instance).conn =
mocker.PropertyMock(return_value=mock_client)
-
- self.db_hook.get_conn()
- # Check AwsHook initialization
- mock_aws_hook_class.assert_called_once_with(
- # If aws_conn_id not set than fallback to aws_default
- aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else
"aws_default",
- client_type="redshift",
- )
- # Check boto3 'redshift' client method `get_cluster_credentials` call
args
- mock_client.get_cluster_credentials.assert_called_once_with(
- DbUser=self.connection.login,
- DbName=self.connection.schema,
- ClusterIdentifier=expected_cluster_identifier,
- AutoCreate=False,
- )
- # Check expected psycopg2 connection call args
- mock_connect.assert_called_once_with(
- user=mock_db_user,
- password=mock_db_pass,
- host=host,
- dbname=self.connection.schema,
- port=(port or 5439),
- )
+ if raises_exception:
+ with pytest.raises(ValueError, match="connection host is
required"):
+ self.db_hook.get_conn()
+ else:
+ self.db_hook.get_conn()
+ # Check AwsHook initialization
+ mock_aws_hook_class.assert_called_once_with(
+ # If aws_conn_id not set than fallback to aws_default
+ aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else
"aws_default",
+ client_type="redshift",
+ )
+ # Check boto3 'redshift' client method `get_cluster_credentials`
call args
+ mock_client.get_cluster_credentials.assert_called_once_with(
+ DbUser=self.connection.login,
+ DbName=self.connection.schema,
+ ClusterIdentifier=expected_cluster_identifier,
+ AutoCreate=False,
+ )
+ # Check expected psycopg2 connection call args
+ mock_connect.assert_called_once_with(
+ user=mock_db_user,
+ password=mock_db_pass,
+ host=host,
+ dbname=self.connection.schema,
+ port=(port or 5439),
+ )
@pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
@pytest.mark.parametrize("port", [5432, 5439, None])
@pytest.mark.parametrize(
- ("host", "conn_workgroup_name", "expected_workgroup_name"),
+ ("host", "conn_workgroup_name", "expected_workgroup_name",
"raises_exception"),
[
(
"serverless-workgroup.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
NOTSET,
"serverless-workgroup",
+ False,
),
(
"cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
"different-workgroup",
"different-workgroup",
+ False,
),
+ (None, NOTSET, None, True),
+ (None, "serverless-workgroup", "serverless-workgroup", False),
],
)
def test_get_conn_rds_iam_redshift_serverless(
@@ -403,6 +415,7 @@ class TestPostgresHookConn:
host,
conn_workgroup_name,
expected_workgroup_name,
+ raises_exception,
):
mock_aws_hook_class =
mocker.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook")
@@ -429,27 +442,30 @@ class TestPostgresHookConn:
"dbUser": mock_db_user,
}
type(mock_aws_hook_instance).conn =
mocker.PropertyMock(return_value=mock_client)
-
- self.db_hook.get_conn()
- # Check AwsHook initialization
- mock_aws_hook_class.assert_called_once_with(
- # If aws_conn_id not set than fallback to aws_default
- aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else
"aws_default",
- client_type="redshift-serverless",
- )
- # Check boto3 'redshift' client method `get_cluster_credentials` call
args
- mock_client.get_credentials.assert_called_once_with(
- dbName=self.connection.schema,
- workgroupName=expected_workgroup_name,
- )
- # Check expected psycopg2 connection call args
- mock_connect.assert_called_once_with(
- user=mock_db_user,
- password=mock_db_pass,
- host=host,
- dbname=self.connection.schema,
- port=(port or 5439),
- )
+ if raises_exception:
+ with pytest.raises(ValueError, match="connection host is
required"):
+ self.db_hook.get_conn()
+ else:
+ self.db_hook.get_conn()
+ # Check AwsHook initialization
+ mock_aws_hook_class.assert_called_once_with(
+ # If aws_conn_id not set than fallback to aws_default
+ aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else
"aws_default",
+ client_type="redshift-serverless",
+ )
+ # Check boto3 'redshift' client method `get_cluster_credentials`
call args
+ mock_client.get_credentials.assert_called_once_with(
+ dbName=self.connection.schema,
+ workgroupName=expected_workgroup_name,
+ )
+ # Check expected psycopg2 connection call args
+ mock_connect.assert_called_once_with(
+ user=mock_db_user,
+ password=mock_db_pass,
+ host=host,
+ dbname=self.connection.schema,
+ port=(port or 5439),
+ )
def test_get_conn_azure_iam(self, mocker, mock_connect):
mock_azure_conn_id = "azure_conn1"