This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 737e50a02a Fix assume role if user explicit set credentials (#26946)
737e50a02a is described below
commit 737e50a02a7031bf0123e57496a55e477cb61b8c
Author: Andrey Anshin <[email protected]>
AuthorDate: Fri Oct 21 22:13:04 2022 +0400
Fix assume role if user explicit set credentials (#26946)
* Fix assume role if user explicit set credentials
---
airflow/providers/amazon/aws/hooks/base_aws.py | 14 +++++-
tests/providers/amazon/aws/hooks/test_base_aws.py | 58 +++++++++++++++--------
2 files changed, 51 insertions(+), 21 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index b61df51e4b..a8c0ea2951 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -125,7 +125,19 @@ class BaseSessionFactory(LoggingMixin):
             return boto3.session.Session(region_name=self.region_name)
         elif not self.role_arn:
             return self.basic_session
-        return self._create_session_with_assume_role(session_kwargs=self.conn.session_kwargs)
+
+        # Values stored in ``AwsConnectionWrapper.session_kwargs`` are intended to be used only
+        # to create the initial boto3 session.
+        # If the user wants to use the 'assume_role' mechanism then only the 'region_name' needs to be
+        # provided, otherwise other parameters might conflict with the base botocore session.
+        # Unfortunately it is not a part of public boto3 API, see source of boto3.session.Session:
+        # https://boto3.amazonaws.com/v1/documentation/api/latest/_modules/boto3/session.html#Session
+        # If we provide 'aws_access_key_id' or 'aws_secret_access_key' or 'aws_session_token'
+        # as part of session kwargs it will use them instead of assumed credentials.
+        assume_session_kwargs = {}
+        if self.conn.region_name:
+            assume_session_kwargs["region_name"] = self.conn.region_name
+        return self._create_session_with_assume_role(session_kwargs=assume_session_kwargs)
 
     def _create_basic_session(self, session_kwargs: dict[str, Any]) -> boto3.session.Session:
         return boto3.session.Session(**session_kwargs)
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index d7b48b5ece..2cff1ce805 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -226,6 +226,44 @@ class TestSessionFactory:
         mock_boto3_session.assert_called_once_with(**expected_arguments)
         assert session == MOCK_BOTO3_SESSION
 
+    @pytest.mark.skipif(mock_sts is None, reason="mock_sts package not present")
+    @mock_sts
+    @pytest.mark.parametrize(
+        "conn_id, conn_extra",
+        [
+            (
+                "assume-with-initial-creds",
+                {
+                    "aws_access_key_id": "mock_aws_access_key_id",
+                    "aws_secret_access_key": "mock_aws_access_key_id",
+                    "aws_session_token": "mock_aws_session_token",
+                },
+            ),
+            ("assume-without-initial-creds", {}),
+        ],
+    )
+    @pytest.mark.parametrize("region_name", ["ap-southeast-2", "sa-east-1"])
+    @pytest.mark.parametrize("role_session_name", [None, "test-session-name"])
+    def test_get_credentials_from_role_arn(self, conn_id, conn_extra, region_name, role_session_name):
+        """Test creation session which set role_arn extra in connection."""
+        extra = {
+            **conn_extra,
+            "role_arn": "arn:aws:iam::123456:role/role_arn",
+            "region_name": region_name,
+        }
+        if role_session_name:
+            extra["assume_role_kwargs"] = {"RoleSessionName": role_session_name}
+        conn = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id, extra=extra)
+        sf = BaseSessionFactory(conn=conn)
+        session = sf.create_session()
+        assert session.region_name == region_name
+        # Validate method of botocore credentials provider.
+        # It shouldn't be 'explicit' which refers in this case to initial credentials.
+        assert session.get_credentials().method == 'sts-assume-role'
+
+        user_id = session.client("sts").get_caller_identity()["UserId"]
+        assert user_id.endswith(role_session_name if role_session_name else f"airflow_{conn_id}")
+
 
 class TestAwsBaseHook:
     @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
@@ -332,26 +370,6 @@ class TestAwsBaseHook:
         ]
         mock_boto3.assert_has_calls(calls_assume_role)
 
-    @unittest.skipIf(mock_sts is None, 'mock_sts package not present')
-    @mock.patch.object(AwsBaseHook, 'get_connection')
-    @mock_sts
-    def test_get_credentials_from_role_arn(self, mock_get_connection):
-        mock_connection = Connection(
-            conn_id='aws_default',
-            conn_type=MOCK_CONN_TYPE,
-            extra='{"role_arn":"arn:aws:iam::123456:role/role_arn"}',
-        )
-        mock_get_connection.return_value = mock_connection
-        hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
-        credentials_from_hook = hook.get_credentials()
-        assert "ASIA" in credentials_from_hook.access_key
-
-        # We assert the length instead of actual values as the values are random:
-        # Details: https://github.com/spulec/moto/commit/ab0d23a0ba2506e6338ae20b3fde70da049f7b03
-        assert 20 == len(credentials_from_hook.access_key)
-        assert 40 == len(credentials_from_hook.secret_key)
-        assert 356 == len(credentials_from_hook.token)
-
     def test_get_credentials_from_gcp_credentials(self):
         mock_connection = Connection(
             extra=json.dumps(