SameerMesiah97 commented on code in PR #61965:
URL: https://github.com/apache/airflow/pull/61965#discussion_r2810287855


##########
providers/postgres/src/airflow/providers/postgres/hooks/postgres.py:
##########
@@ -476,9 +476,12 @@ def get_aws_iam_token(self, conn: Connection) -> 
tuple[str, str, int]:
             port = conn.port or 5439
             # Pull the custer-identifier from the beginning of the Redshift URL
             # ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com 
returns my-cluster
-            cluster_identifier = conn.extra_dejson.get(
-                "cluster-identifier", cast("str", conn.host).split(".")[0]
-            )
+            cluster_identifier = conn.extra_dejson.get("cluster-identifier")
+            if cluster_identifier is None and not conn.host:
+                raise ValueError(
+                    "connection host is required for AWS IAM token when 
cluster-identifier is not set in extras."
+                )
+            cluster_identifier = cluster_identifier or 
(conn.host.split(".")[0] if conn.host else None)

Review Comment:
   I think this will behave slightly differently when `cluster_identifier` is 
falsy, e.g. an empty string or None. Before, that would resolve to the falsy 
value, but with your change it will fall back to `conn.host.split(".")[0]`. Is 
this intentional? If so, this should be made more explicit, like this:
   
   ```
   cluster_identifier = conn.extra_dejson.get("cluster-identifier")
   
   if cluster_identifier is None:
       if not conn.host:
           raise ValueError(
               "Connection host is required for AWS IAM token when "
               "'cluster-identifier' is not set in extras."
           )
       cluster_identifier = conn.host.split(".")[0]
   ```



##########
providers/postgres/tests/unit/postgres/hooks/test_postgres.py:
##########
@@ -353,45 +358,52 @@ def test_get_conn_rds_iam_redshift(
             "DbUser": mock_db_user,
         }
         type(mock_aws_hook_instance).conn = 
mocker.PropertyMock(return_value=mock_client)
-
-        self.db_hook.get_conn()
-        # Check AwsHook initialization
-        mock_aws_hook_class.assert_called_once_with(
-            # If aws_conn_id not set than fallback to aws_default
-            aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else 
"aws_default",
-            client_type="redshift",
-        )
-        # Check boto3 'redshift' client method `get_cluster_credentials` call 
args
-        mock_client.get_cluster_credentials.assert_called_once_with(
-            DbUser=self.connection.login,
-            DbName=self.connection.schema,
-            ClusterIdentifier=expected_cluster_identifier,
-            AutoCreate=False,
-        )
-        # Check expected psycopg2 connection call args
-        mock_connect.assert_called_once_with(
-            user=mock_db_user,
-            password=mock_db_pass,
-            host=host,
-            dbname=self.connection.schema,
-            port=(port or 5439),
-        )
+        if raises_exception:
+            with pytest.raises(ValueError, match="connection host is 
required"):
+                self.db_hook.get_conn()
+        else:
+            self.db_hook.get_conn()
+            # Check AwsHook initialization
+            mock_aws_hook_class.assert_called_once_with(
+                # If aws_conn_id not set than fallback to aws_default
+                aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else 
"aws_default",
+                client_type="redshift",
+            )
+            # Check boto3 'redshift' client method `get_cluster_credentials` 
call args
+            mock_client.get_cluster_credentials.assert_called_once_with(
+                DbUser=self.connection.login,
+                DbName=self.connection.schema,
+                ClusterIdentifier=expected_cluster_identifier,
+                AutoCreate=False,
+            )
+            # Check expected psycopg2 connection call args
+            mock_connect.assert_called_once_with(
+                user=mock_db_user,
+                password=mock_db_pass,
+                host=host,
+                dbname=self.connection.schema,
+                port=(port or 5439),
+            )
 
     @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
     @pytest.mark.parametrize("port", [5432, 5439, None])
     @pytest.mark.parametrize(
-        ("host", "conn_workgroup_name", "expected_workgroup_name"),
+        ("host", "conn_workgroup_name", "expected_workgroup_name", 
"raises_exception"),
         [
             (
                 
"serverless-workgroup.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
                 NOTSET,
                 "serverless-workgroup",
+                False,
             ),
             (
                 
"cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
                 "different-workgroup",
                 "different-workgroup",
+                False,
             ),
+            (None, NOTSET, None, True),
+            (None, "serverless-workgroup", "serverless-workgroup", False),
         ],

Review Comment:
   Above comment applies here too.



##########
providers/postgres/src/airflow/providers/postgres/hooks/postgres.py:
##########
@@ -494,7 +497,12 @@ def get_aws_iam_token(self, conn: Connection) -> 
tuple[str, str, int]:
             # Pull the workgroup-name from the query params/extras, if not 
there then pull it from the
             # beginning of the Redshift URL
             # ex. workgroup-name.ccdre4hpd39h.us-east-1.redshift.amazonaws.com 
returns workgroup-name
-            workgroup_name = conn.extra_dejson.get("workgroup-name", 
cast("str", conn.host).split(".")[0])
+            workgroup_name = conn.extra_dejson.get("workgroup-name")
+            if workgroup_name is None and not conn.host:
+                raise ValueError(
+                    "connection host is required for AWS IAM token when 
workgroup-name is not set in extras."
+                )
+            workgroup_name = workgroup_name or (conn.host.split(".")[0] if 
conn.host else None)

Review Comment:
   The above comment applies to this as well. 



##########
providers/postgres/tests/unit/postgres/hooks/test_postgres.py:
##########
@@ -304,18 +304,22 @@ def test_get_conn_extra(self, mock_connect):
     @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
     @pytest.mark.parametrize("port", [5432, 5439, None])
     @pytest.mark.parametrize(
-        ("host", "conn_cluster_identifier", "expected_cluster_identifier"),
+        ("host", "conn_cluster_identifier", "expected_cluster_identifier", 
"raises_exception"),
         [
             (
                 
"cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
                 NOTSET,
                 "cluster-identifier",
+                False,
             ),
             (
                 
"cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
                 "different-identifier",
                 "different-identifier",
+                False,
             ),
+            (None, NOTSET, None, True),
+            (None, "cluster-identifier", "cluster-identifier", False),

Review Comment:
   I think you should add another case for falsy `conn_cluster_identifier` 
scenarios (such as ""). 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to