This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 282854b55f Add endpoint_url in test_connection (#32664)
282854b55f is described below
commit 282854b55fd8b0ef46ae0b9032b67654b4789249
Author: ieunea1128 <[email protected]>
AuthorDate: Tue Jul 25 02:10:07 2023 +0900
Add endpoint_url in test_connection (#32664)
---
airflow/providers/amazon/aws/hooks/base_aws.py | 6 +++++-
tests/providers/amazon/aws/hooks/test_base_aws.py | 20 +++++++++++++++++++-
2 files changed, 24 insertions(+), 2 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index 94782401ac..01748089c8 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -809,7 +809,11 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
"""
try:
session = self.get_session()
- conn_info = session.client("sts").get_caller_identity()
+ test_endpoint_url =
self.conn_config.extra_config.get("test_endpoint_url")
+ conn_info = session.client(
+ "sts",
+ endpoint_url=test_endpoint_url,
+ ).get_caller_identity()
metadata = conn_info.pop("ResponseMetadata", {})
if metadata.get("HTTPStatusCode") != 200:
try:
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index 8ba8b6b812..084c282382 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -295,6 +295,7 @@ class TestSessionFactory:
conn = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id,
extra=extra)
sf = BaseSessionFactory(conn=conn)
session = sf.create_session()
+
assert session.region_name == region_name
# Validate method of botocore credentials provider.
# It shouldn't be 'explicit' which refers in this case to initial
credentials.
@@ -858,7 +859,7 @@ class TestAwsBaseHook:
result, message = hook.test_connection()
assert not result
assert message == json.dumps(response_metadata)
- mock_sts_client.assert_called_once_with("sts")
+ mock_sts_client.assert_called_once_with("sts", endpoint_url=None)
def mock_error():
raise ConnectionError("Test Error")
@@ -872,6 +873,23 @@ class TestAwsBaseHook:
assert hook.client_type == "ec2"
+ @mock_sts
+ @pytest.mark.parametrize(
+ "test_endpoint_url, result_url",
+ [
+ (None, "https://sts.amazonaws.com"),
+ ("https://sts.us-east-1.amazonaws.com",
"https://sts.us-east-1.amazonaws.com"),
+ ],
+ )
+ def test_hook_connection_endpoint_url_valid(self, test_endpoint_url,
result_url):
+ """Test if test_endpoint_url is valid in test connection"""
+ conn =
AwsConnectionWrapper.from_connection_metadata(conn_id=MOCK_AWS_CONN_ID)
+ sf = BaseSessionFactory(conn=conn)
+ session = sf.create_session()
+ client = session.client("sts", endpoint_url=test_endpoint_url)
+
+ assert client._endpoint.host == result_url
+
@mock.patch.dict(os.environ, {f"AIRFLOW_CONN_{MOCK_AWS_CONN_ID.upper()}":
"aws://"})
def test_conn_config_conn_id_exists(self):
"""Test retrieve connection config if aws_conn_id exists."""