This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f41ec5f65c4 Add validation for missing host and cluster/workgroup identifier in aws iam token retrieval (#61965)
f41ec5f65c4 is described below

commit f41ec5f65c4948e5103a5a534fdf5f145562f38c
Author: Justin Pakzad <[email protected]>
AuthorDate: Wed Mar 11 21:18:04 2026 -0400

    Add validation for missing host and cluster/workgroup identifier in aws iam token retrieval (#61965)
    
    * Add validation for missing host and cluster/workgroup identifier in aws iam token retrieval
    
    * mypy fix
    
    * Cleaned up the handling of cluster id to be more explicit
---
 .../airflow/providers/postgres/hooks/postgres.py   |  18 +++-
 .../tests/unit/postgres/hooks/test_postgres.py     | 108 ++++++++++++---------
 2 files changed, 76 insertions(+), 50 deletions(-)

diff --git a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
index 8e3edabb94a..d38171bbe21 100644
--- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
+++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
@@ -483,9 +483,13 @@ class PostgresHook(DbApiHook):
             port = conn.port or 5439
             # Pull the custer-identifier from the beginning of the Redshift URL
             # ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com 
returns my-cluster
-            cluster_identifier = conn.extra_dejson.get(
-                "cluster-identifier", cast("str", conn.host).split(".")[0]
-            )
+            cluster_identifier = conn.extra_dejson.get("cluster-identifier")
+            if cluster_identifier is None:
+                if not conn.host:
+                    raise ValueError(
+                        "connection host is required for AWS IAM token when cluster-identifier is not set in extras."
+                    )
+                cluster_identifier = conn.host.split(".")[0]
             redshift_client = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="redshift").conn
             # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift/client/get_cluster_credentials.html#Redshift.Client.get_cluster_credentials
             cluster_creds = redshift_client.get_cluster_credentials(
@@ -501,7 +505,13 @@ class PostgresHook(DbApiHook):
             # Pull the workgroup-name from the query params/extras, if not 
there then pull it from the
             # beginning of the Redshift URL
             # ex. workgroup-name.ccdre4hpd39h.us-east-1.redshift.amazonaws.com 
returns workgroup-name
-            workgroup_name = conn.extra_dejson.get("workgroup-name", cast("str", conn.host).split(".")[0])
+            workgroup_name = conn.extra_dejson.get("workgroup-name")
+            if workgroup_name is None:
+                if not conn.host:
+                    raise ValueError(
+                        "connection host is required for AWS IAM token when workgroup-name is not set in extras."
+                    )
+                workgroup_name = conn.host.split(".")[0]
             redshift_serverless_client = AwsBaseHook(
                 aws_conn_id=aws_conn_id, client_type="redshift-serverless"
             ).conn
diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
index f4570f6563b..6cad536c07b 100644
--- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
+++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
@@ -304,18 +304,22 @@ class TestPostgresHookConn:
     @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
     @pytest.mark.parametrize("port", [5432, 5439, None])
     @pytest.mark.parametrize(
-        ("host", "conn_cluster_identifier", "expected_cluster_identifier"),
+        ("host", "conn_cluster_identifier", "expected_cluster_identifier", "raises_exception"),
         [
             (
                 
"cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
                 NOTSET,
                 "cluster-identifier",
+                False,
             ),
             (
                 
"cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
                 "different-identifier",
                 "different-identifier",
+                False,
             ),
+            (None, NOTSET, None, True),
+            (None, "cluster-identifier", "cluster-identifier", False),
         ],
     )
     def test_get_conn_rds_iam_redshift(
@@ -327,6 +331,7 @@ class TestPostgresHookConn:
         host,
         conn_cluster_identifier,
         expected_cluster_identifier,
+        raises_exception,
     ):
         mock_aws_hook_class = 
mocker.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook")
 
@@ -353,45 +358,52 @@ class TestPostgresHookConn:
             "DbUser": mock_db_user,
         }
         type(mock_aws_hook_instance).conn = 
mocker.PropertyMock(return_value=mock_client)
-
-        self.db_hook.get_conn()
-        # Check AwsHook initialization
-        mock_aws_hook_class.assert_called_once_with(
-            # If aws_conn_id not set than fallback to aws_default
-            aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else 
"aws_default",
-            client_type="redshift",
-        )
-        # Check boto3 'redshift' client method `get_cluster_credentials` call 
args
-        mock_client.get_cluster_credentials.assert_called_once_with(
-            DbUser=self.connection.login,
-            DbName=self.connection.schema,
-            ClusterIdentifier=expected_cluster_identifier,
-            AutoCreate=False,
-        )
-        # Check expected psycopg2 connection call args
-        mock_connect.assert_called_once_with(
-            user=mock_db_user,
-            password=mock_db_pass,
-            host=host,
-            dbname=self.connection.schema,
-            port=(port or 5439),
-        )
+        if raises_exception:
+            with pytest.raises(ValueError, match="connection host is required"):
+                self.db_hook.get_conn()
+        else:
+            self.db_hook.get_conn()
+            # Check AwsHook initialization
+            mock_aws_hook_class.assert_called_once_with(
+                # If aws_conn_id not set than fallback to aws_default
+                aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else 
"aws_default",
+                client_type="redshift",
+            )
+            # Check boto3 'redshift' client method `get_cluster_credentials` 
call args
+            mock_client.get_cluster_credentials.assert_called_once_with(
+                DbUser=self.connection.login,
+                DbName=self.connection.schema,
+                ClusterIdentifier=expected_cluster_identifier,
+                AutoCreate=False,
+            )
+            # Check expected psycopg2 connection call args
+            mock_connect.assert_called_once_with(
+                user=mock_db_user,
+                password=mock_db_pass,
+                host=host,
+                dbname=self.connection.schema,
+                port=(port or 5439),
+            )
 
     @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
     @pytest.mark.parametrize("port", [5432, 5439, None])
     @pytest.mark.parametrize(
-        ("host", "conn_workgroup_name", "expected_workgroup_name"),
+        ("host", "conn_workgroup_name", "expected_workgroup_name", "raises_exception"),
         [
             (
                 
"serverless-workgroup.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
                 NOTSET,
                 "serverless-workgroup",
+                False,
             ),
             (
                 
"cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
                 "different-workgroup",
                 "different-workgroup",
+                False,
             ),
+            (None, NOTSET, None, True),
+            (None, "serverless-workgroup", "serverless-workgroup", False),
         ],
     )
     def test_get_conn_rds_iam_redshift_serverless(
@@ -403,6 +415,7 @@ class TestPostgresHookConn:
         host,
         conn_workgroup_name,
         expected_workgroup_name,
+        raises_exception,
     ):
         mock_aws_hook_class = 
mocker.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook")
 
@@ -429,27 +442,30 @@ class TestPostgresHookConn:
             "dbUser": mock_db_user,
         }
         type(mock_aws_hook_instance).conn = 
mocker.PropertyMock(return_value=mock_client)
-
-        self.db_hook.get_conn()
-        # Check AwsHook initialization
-        mock_aws_hook_class.assert_called_once_with(
-            # If aws_conn_id not set than fallback to aws_default
-            aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else 
"aws_default",
-            client_type="redshift-serverless",
-        )
-        # Check boto3 'redshift' client method `get_cluster_credentials` call 
args
-        mock_client.get_credentials.assert_called_once_with(
-            dbName=self.connection.schema,
-            workgroupName=expected_workgroup_name,
-        )
-        # Check expected psycopg2 connection call args
-        mock_connect.assert_called_once_with(
-            user=mock_db_user,
-            password=mock_db_pass,
-            host=host,
-            dbname=self.connection.schema,
-            port=(port or 5439),
-        )
+        if raises_exception:
+            with pytest.raises(ValueError, match="connection host is required"):
+                self.db_hook.get_conn()
+        else:
+            self.db_hook.get_conn()
+            # Check AwsHook initialization
+            mock_aws_hook_class.assert_called_once_with(
+                # If aws_conn_id not set than fallback to aws_default
+                aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else 
"aws_default",
+                client_type="redshift-serverless",
+            )
+            # Check boto3 'redshift' client method `get_cluster_credentials` 
call args
+            mock_client.get_credentials.assert_called_once_with(
+                dbName=self.connection.schema,
+                workgroupName=expected_workgroup_name,
+            )
+            # Check expected psycopg2 connection call args
+            mock_connect.assert_called_once_with(
+                user=mock_db_user,
+                password=mock_db_pass,
+                host=host,
+                dbname=self.connection.schema,
+                port=(port or 5439),
+            )
 
     def test_get_conn_azure_iam(self, mocker, mock_connect):
         mock_azure_conn_id = "azure_conn1"

Reply via email to