This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 737e50a02a Fix assume role if user explicitly sets credentials (#26946)
737e50a02a is described below

commit 737e50a02a7031bf0123e57496a55e477cb61b8c
Author: Andrey Anshin <[email protected]>
AuthorDate: Fri Oct 21 22:13:04 2022 +0400

    Fix assume role if user explicitly sets credentials (#26946)
    
    * Fix assume role if user explicitly sets credentials
---
 airflow/providers/amazon/aws/hooks/base_aws.py    | 14 +++++-
 tests/providers/amazon/aws/hooks/test_base_aws.py | 58 +++++++++++++++--------
 2 files changed, 51 insertions(+), 21 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index b61df51e4b..a8c0ea2951 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -125,7 +125,19 @@ class BaseSessionFactory(LoggingMixin):
             return boto3.session.Session(region_name=self.region_name)
         elif not self.role_arn:
             return self.basic_session
-        return 
self._create_session_with_assume_role(session_kwargs=self.conn.session_kwargs)
+
+        # Values stored in ``AwsConnectionWrapper.session_kwargs`` are 
intended to be used only
+        # to create the initial boto3 session.
+        # If the user wants to use the 'assume_role' mechanism then only the 
'region_name' needs to be
+        # provided, otherwise other parameters might conflict with the base 
botocore session.
+        # Unfortunately it is not a part of public boto3 API, see source of 
boto3.session.Session:
+        # 
https://boto3.amazonaws.com/v1/documentation/api/latest/_modules/boto3/session.html#Session
+        # If we provide 'aws_access_key_id' or 'aws_secret_access_key' or 
'aws_session_token'
+        # as part of session kwargs it will use them instead of assumed 
credentials.
+        assume_session_kwargs = {}
+        if self.conn.region_name:
+            assume_session_kwargs["region_name"] = self.conn.region_name
+        return 
self._create_session_with_assume_role(session_kwargs=assume_session_kwargs)
 
     def _create_basic_session(self, session_kwargs: dict[str, Any]) -> 
boto3.session.Session:
         return boto3.session.Session(**session_kwargs)
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index d7b48b5ece..2cff1ce805 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -226,6 +226,44 @@ class TestSessionFactory:
         mock_boto3_session.assert_called_once_with(**expected_arguments)
         assert session == MOCK_BOTO3_SESSION
 
+    @pytest.mark.skipif(mock_sts is None, reason="mock_sts package not 
present")
+    @mock_sts
+    @pytest.mark.parametrize(
+        "conn_id, conn_extra",
+        [
+            (
+                "assume-with-initial-creds",
+                {
+                    "aws_access_key_id": "mock_aws_access_key_id",
+                    "aws_secret_access_key": "mock_aws_access_key_id",
+                    "aws_session_token": "mock_aws_session_token",
+                },
+            ),
+            ("assume-without-initial-creds", {}),
+        ],
+    )
+    @pytest.mark.parametrize("region_name", ["ap-southeast-2", "sa-east-1"])
+    @pytest.mark.parametrize("role_session_name", [None, "test-session-name"])
+    def test_get_credentials_from_role_arn(self, conn_id, conn_extra, 
region_name, role_session_name):
+        """Test creation session which set role_arn extra in connection."""
+        extra = {
+            **conn_extra,
+            "role_arn": "arn:aws:iam::123456:role/role_arn",
+            "region_name": region_name,
+        }
+        if role_session_name:
+            extra["assume_role_kwargs"] = {"RoleSessionName": 
role_session_name}
+        conn = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id, 
extra=extra)
+        sf = BaseSessionFactory(conn=conn)
+        session = sf.create_session()
+        assert session.region_name == region_name
+        # Validate method of botocore credentials provider.
+        # It shouldn't be 'explicit' which refers in this case to initial 
credentials.
+        assert session.get_credentials().method == 'sts-assume-role'
+
+        user_id = session.client("sts").get_caller_identity()["UserId"]
+        assert user_id.endswith(role_session_name if role_session_name else 
f"airflow_{conn_id}")
+
 
 class TestAwsBaseHook:
     @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
@@ -332,26 +370,6 @@ class TestAwsBaseHook:
         ]
         mock_boto3.assert_has_calls(calls_assume_role)
 
-    @unittest.skipIf(mock_sts is None, 'mock_sts package not present')
-    @mock.patch.object(AwsBaseHook, 'get_connection')
-    @mock_sts
-    def test_get_credentials_from_role_arn(self, mock_get_connection):
-        mock_connection = Connection(
-            conn_id='aws_default',
-            conn_type=MOCK_CONN_TYPE,
-            extra='{"role_arn":"arn:aws:iam::123456:role/role_arn"}',
-        )
-        mock_get_connection.return_value = mock_connection
-        hook = AwsBaseHook(aws_conn_id='aws_default', 
client_type='airflow_test')
-        credentials_from_hook = hook.get_credentials()
-        assert "ASIA" in credentials_from_hook.access_key
-
-        # We assert the length instead of actual values as the values are 
random:
-        # Details: 
https://github.com/spulec/moto/commit/ab0d23a0ba2506e6338ae20b3fde70da049f7b03
-        assert 20 == len(credentials_from_hook.access_key)
-        assert 40 == len(credentials_from_hook.secret_key)
-        assert 356 == len(credentials_from_hook.token)
-
     def test_get_credentials_from_gcp_credentials(self):
         mock_connection = Connection(
             extra=json.dumps(

Reply via email to