This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f8d7290178 Improve testing AWS Connection response (#26953)
f8d7290178 is described below

commit f8d7290178dba6b96ba0ec2cc28a5c4289902229
Author: Andrey Anshin <[email protected]>
AuthorDate: Mon Oct 10 13:13:22 2022 +0400

    Improve testing AWS Connection response (#26953)
---
 airflow/providers/amazon/aws/hooks/base_aws.py    | 18 +++++++-------
 tests/providers/amazon/aws/hooks/test_base_aws.py | 29 +++++++++++++++++++++++
 2 files changed, 38 insertions(+), 9 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index 01b5618602..b61df51e4b 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -635,21 +635,21 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         .. seealso::
             
https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html
         """
-        orig_client_type, self.client_type = self.client_type, 'sts'
         try:
-            res = self.get_client_type().get_caller_identity()
-            metadata = res.pop("ResponseMetadata", {})
-            if metadata.get("HTTPStatusCode") == 200:
-                return True, json.dumps(res)
-            else:
+            session = self.get_session()
+            conn_info = session.client("sts").get_caller_identity()
+            metadata = conn_info.pop("ResponseMetadata", {})
+            if metadata.get("HTTPStatusCode") != 200:
                 try:
                     return False, json.dumps(metadata)
                 except TypeError:
                     return False, str(metadata)
+            conn_info["credentials_method"] = session.get_credentials().method
+            conn_info["region_name"] = session.region_name
+            return True, ", ".join(f"{k}={v!r}" for k, v in conn_info.items())
+
         except Exception as e:
-            return False, str(e)
-        finally:
-            self.client_type = orig_client_type
+            return False, str(f"{type(e).__name__!r} error occurred while 
testing connection: {e}")
 
 
 class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]):
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index 13cffbb33f..07e1cc7edc 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -670,6 +670,35 @@ class TestAwsBaseHook:
         assert result
         assert hook.client_type == "s3"  # Same client_type which defined 
during initialisation
 
+    @mock.patch("boto3.session.Session")
+    def test_hook_connection_test_failed(self, mock_boto3_session):
+        """Test ``test_connection`` failure."""
+        hook = AwsBaseHook(client_type="ec2")
+
+        # Tests that STS API return non 200 code. Under normal circumstances 
this is hardly possible.
+        response_metadata = {"HTTPStatusCode": 500, "reason": "Test Failure"}
+        mock_sts_client = mock.MagicMock()
+        mock_sts_client.return_value.get_caller_identity.return_value = {
+            "ResponseMetadata": response_metadata
+        }
+        mock_boto3_session.return_value.client = mock_sts_client
+        result, message = hook.test_connection()
+        assert not result
+        assert message == json.dumps(response_metadata)
+        mock_sts_client.assert_called_once_with("sts")
+
+        def mock_error():
+            raise ConnectionError("Test Error")
+
+        # Something bad happen during boto3.session.Session creation (e.g. 
wrong credentials or conn error)
+        mock_boto3_session.reset_mock()
+        mock_boto3_session.side_effect = mock_error
+        result, message = hook.test_connection()
+        assert not result
+        assert message == "'ConnectionError' error occurred while testing 
connection: Test Error"
+
+        assert hook.client_type == "ec2"
+
     @mock.patch.dict(os.environ, {f"AIRFLOW_CONN_{MOCK_AWS_CONN_ID.upper()}": 
"aws://"})
     def test_conn_config_conn_id_exists(self):
         """Test retrieve connection config if aws_conn_id exists."""

Reply via email to