This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 282854b55f Add endpoint_url in test_connection (#32664)
282854b55f is described below

commit 282854b55fd8b0ef46ae0b9032b67654b4789249
Author: ieunea1128 <[email protected]>
AuthorDate: Tue Jul 25 02:10:07 2023 +0900

    Add endpoint_url in test_connection (#32664)
---
 airflow/providers/amazon/aws/hooks/base_aws.py    |  6 +++++-
 tests/providers/amazon/aws/hooks/test_base_aws.py | 20 +++++++++++++++++++-
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index 94782401ac..01748089c8 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -809,7 +809,11 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         """
         try:
             session = self.get_session()
-            conn_info = session.client("sts").get_caller_identity()
+            test_endpoint_url = self.conn_config.extra_config.get("test_endpoint_url")
+            conn_info = session.client(
+                "sts",
+                endpoint_url=test_endpoint_url,
+            ).get_caller_identity()
             metadata = conn_info.pop("ResponseMetadata", {})
             if metadata.get("HTTPStatusCode") != 200:
                 try:
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index 8ba8b6b812..084c282382 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -295,6 +295,7 @@ class TestSessionFactory:
         conn = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id, extra=extra)
         sf = BaseSessionFactory(conn=conn)
         session = sf.create_session()
+
         assert session.region_name == region_name
         # Validate method of botocore credentials provider.
         # It shouldn't be 'explicit' which refers in this case to initial credentials.
@@ -858,7 +859,7 @@ class TestAwsBaseHook:
         result, message = hook.test_connection()
         assert not result
         assert message == json.dumps(response_metadata)
-        mock_sts_client.assert_called_once_with("sts")
+        mock_sts_client.assert_called_once_with("sts", endpoint_url=None)
 
         def mock_error():
             raise ConnectionError("Test Error")
@@ -872,6 +873,23 @@ class TestAwsBaseHook:
 
         assert hook.client_type == "ec2"
 
+    @mock_sts
+    @pytest.mark.parametrize(
+        "test_endpoint_url, result_url",
+        [
+            (None, "https://sts.amazonaws.com"),
+            ("https://sts.us-east-1.amazonaws.com", "https://sts.us-east-1.amazonaws.com"),
+        ],
+    )
+    def test_hook_connection_endpoint_url_valid(self, test_endpoint_url, result_url):
+        """Test if test_endpoint_url is valid in test connection"""
+        conn = AwsConnectionWrapper.from_connection_metadata(conn_id=MOCK_AWS_CONN_ID)
+        sf = BaseSessionFactory(conn=conn)
+        session = sf.create_session()
+        client = session.client("sts", endpoint_url=test_endpoint_url)
+
+        assert client._endpoint.host == result_url
+
     @mock.patch.dict(os.environ, {f"AIRFLOW_CONN_{MOCK_AWS_CONN_ID.upper()}": "aws://"})
     def test_conn_config_conn_id_exists(self):
         """Test retrieve connection config if aws_conn_id exists."""

Reply via email to