This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 65ac62867e Get boto3.session.Session by appropriate method (#25569)
65ac62867e is described below
commit 65ac62867e2a09a406ca443a6ac44f9b667fbc55
Author: Andrey Anshin <[email protected]>
AuthorDate: Sun Aug 7 14:56:33 2022 +0400
Get boto3.session.Session by appropriate method (#25569)
---
airflow/providers/amazon/aws/hooks/base_aws.py | 62 +++++++++++++----------
airflow/providers/amazon/aws/hooks/glue.py | 6 +--
airflow/providers/amazon/aws/hooks/s3.py | 14 ++---
tests/providers/amazon/aws/hooks/test_base_aws.py | 43 ++++++++++++++++
tests/providers/amazon/aws/hooks/test_glue.py | 44 ++++++++--------
5 files changed, 107 insertions(+), 62 deletions(-)
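In short: the private `AwsGenericHook._get_credentials` helper, which returned a `(boto3 session, endpoint_url)` tuple, is replaced by the public `get_session()` method plus the existing `conn_config.endpoint_url` property; the old helper is kept as a deprecated shim that now emits a DeprecationWarning. A rough caller-side sketch of the migration (all names are taken from the diff below; actually running it needs valid AWS credentials):

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    # aws_conn_id=None means no Airflow connection; boto3 falls back to its default credential chain
    hook = S3Hook(aws_conn_id=None)

    # Before this commit callers did:
    #     session, endpoint_url = hook._get_credentials(region_name=None)
    # which now emits a DeprecationWarning.

    # After this commit:
    session = hook.get_session()                  # plain boto3.session.Session
    endpoint_url = hook.conn_config.endpoint_url  # endpoint comes from the connection wrapper
    s3_client = session.client("s3", endpoint_url=endpoint_url, config=hook.config, verify=hook.verify)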
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index 1132892e73..224722609f 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -329,8 +329,8 @@ class BaseSessionFactory(LoggingMixin):
def _get_region_name(self) -> Optional[str]:
warnings.warn(
- "`BaseSessionFactory._get_region_name` method will be deprecated
in the future."
- "Please use `BaseSessionFactory.region_name` property instead.",
+ "`BaseSessionFactory._get_region_name` method deprecated and will
be removed "
+ "in a future releases. Please use `BaseSessionFactory.region_name`
property instead.",
PendingDeprecationWarning,
stacklevel=2,
)
@@ -338,8 +338,8 @@ class BaseSessionFactory(LoggingMixin):
def _read_role_arn_from_extra_config(self) -> Optional[str]:
warnings.warn(
- "`BaseSessionFactory._read_role_arn_from_extra_config` method will
be deprecated in the future."
- "Please use `BaseSessionFactory.role_arn` property instead.",
+ "`BaseSessionFactory._read_role_arn_from_extra_config` method
deprecated and will be removed "
+ "in a future releases. Please use `BaseSessionFactory.role_arn`
property instead.",
PendingDeprecationWarning,
stacklevel=2,
)
@@ -347,8 +347,8 @@ class BaseSessionFactory(LoggingMixin):
def _read_credentials_from_connection(self) -> Tuple[Optional[str], Optional[str]]:
warnings.warn(
- "`BaseSessionFactory._read_credentials_from_connection` method
will be deprecated in the future."
- "Please use `BaseSessionFactory.conn.aws_access_key_id` and "
+ "`BaseSessionFactory._read_credentials_from_connection` method
deprecated and will be removed "
+ "in a future releases. Please use
`BaseSessionFactory.conn.aws_access_key_id` and "
"`BaseSessionFactory.aws_secret_access_key` properties instead.",
PendingDeprecationWarning,
stacklevel=2,
@@ -430,15 +430,12 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
"""Configuration for botocore client read-only property."""
return self.conn_config.botocore_config
- def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]:
- self.log.debug("Airflow Connection: aws_conn_id=%s", self.aws_conn_id)
-
- session = SessionFactory(
+ def get_session(self, region_name: Optional[str] = None) -> boto3.session.Session:
+ """Get the underlying boto3.session.Session(region_name=region_name)."""
+ return SessionFactory(
conn=self.conn_config, region_name=region_name, config=self.config
).create_session()
- return session, self.conn_config.endpoint_url
-
def get_client_type(
self,
client_type: Optional[str] = None,
@@ -446,8 +443,6 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
config: Optional[Config] = None,
) -> boto3.client:
"""Get the underlying boto3 client using boto3 session"""
- session, endpoint_url = self._get_credentials(region_name=region_name)
-
if client_type:
warnings.warn(
"client_type is deprecated. Set client_type from class
attribute.",
@@ -462,7 +457,10 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
if config is None:
config = self.config
- return session.client(client_type, endpoint_url=endpoint_url, config=config, verify=self.verify)
+ session = self.get_session(region_name=region_name)
+ return session.client(
+ client_type, endpoint_url=self.conn_config.endpoint_url, config=config, verify=self.verify
+ )
def get_resource_type(
self,
@@ -471,8 +469,6 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
config: Optional[Config] = None,
) -> boto3.resource:
"""Get the underlying boto3 resource using boto3 session"""
- session, endpoint_url = self._get_credentials(region_name=region_name)
-
if resource_type:
warnings.warn(
"resource_type is deprecated. Set resource_type from class
attribute.",
@@ -487,10 +483,13 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
if config is None:
config = self.config
- return session.resource(resource_type, endpoint_url=endpoint_url, config=config, verify=self.verify)
+ session = self.get_session(region_name=region_name)
+ return session.resource(
+ resource_type, endpoint_url=self.conn_config.endpoint_url, config=config, verify=self.verify
+ )
@cached_property
- def conn(self) -> Union[boto3.client, boto3.resource]:
+ def conn(self) -> BaseAwsConnection:
"""
Get the underlying boto3 client/resource (cached)
@@ -538,22 +537,16 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
# Compat shim
return self.conn
- def get_session(self, region_name: Optional[str] = None) -> boto3.session.Session:
- """Get the underlying boto3.session."""
- session, _ = self._get_credentials(region_name=region_name)
- return session
-
def get_credentials(self, region_name: Optional[str] = None) -> ReadOnlyCredentials:
"""
Get the underlying `botocore.Credentials` object.
This contains the following authentication attributes: access_key, secret_key and token.
"""
- session, _ = self._get_credentials(region_name=region_name)
# Credentials are refreshable, so accessing your access key and
# secret key separately can lead to a race condition.
# See https://stackoverflow.com/a/36291428/8283373
- return session.get_credentials().get_frozen_credentials()
+ return self.get_session(region_name=region_name).get_credentials().get_frozen_credentials()
def expand_role(self, role: str, region_name: Optional[str] = None) -> str:
"""
@@ -567,8 +560,10 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
if "/" in role:
return role
else:
- session, endpoint_url = self._get_credentials(region_name=region_name)
- _client = session.client('iam', endpoint_url=endpoint_url, config=self.config, verify=self.verify)
+ session = self.get_session(region_name=region_name)
+ _client = session.client(
+ 'iam', endpoint_url=self.conn_config.endpoint_url, config=self.config, verify=self.verify
+ )
return _client.get_role(RoleName=role)["Role"]["Arn"]
@staticmethod
@@ -603,6 +598,17 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
return retry_decorator
+ def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]:
+ warnings.warn(
+ "`AwsGenericHook._get_credentials` method deprecated and will be
removed in a future releases. "
+ "Please use `AwsGenericHook.get_session` method and "
+ "`AwsGenericHook.conn_config.endpoint_url` property instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return self.get_session(region_name=region_name), self.conn_config.endpoint_url
+
@staticmethod
def get_ui_field_behaviour() -> Dict[str, Any]:
"""Returns custom UI field behaviour for AWS Connection."""
diff --git a/airflow/providers/amazon/aws/hooks/glue.py b/airflow/providers/amazon/aws/hooks/glue.py
index 9201b9f70b..fba765e2ff 100644
--- a/airflow/providers/amazon/aws/hooks/glue.py
+++ b/airflow/providers/amazon/aws/hooks/glue.py
@@ -100,10 +100,10 @@ class GlueJobHook(AwsBaseHook):
def get_iam_execution_role(self) -> Dict:
""":return: iam role for job execution"""
- session, endpoint_url = self._get_credentials(region_name=self.region_name)
- iam_client = session.client('iam', endpoint_url=endpoint_url, config=self.config, verify=self.verify)
-
try:
+ iam_client = self.get_session(region_name=self.region_name).client(
+ 'iam', endpoint_url=self.conn_config.endpoint_url, config=self.config, verify=self.verify
+ )
glue_execution_role = iam_client.get_role(RoleName=self.role_name)
self.log.info("Iam Role Name: %s", self.role_name)
return glue_execution_role
diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index 8bd3f0e670..149887f6c9 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -215,12 +215,9 @@ class S3Hook(AwsBaseHook):
:return: the bucket object to the bucket name.
:rtype: boto3.S3.Bucket
"""
- # Buckets have no regions, and we cannot remove the region name from _get_credentials as we would
- # break compatibility, so we set it explicitly to None.
- session, endpoint_url = self._get_credentials(region_name=None)
- s3_resource = session.resource(
+ s3_resource = self.get_session().resource(
"s3",
- endpoint_url=endpoint_url,
+ endpoint_url=self.conn_config.endpoint_url,
config=self.config,
verify=self.verify,
)
@@ -465,12 +462,9 @@ class S3Hook(AwsBaseHook):
:return: the key object from the bucket
:rtype: boto3.s3.Object
"""
- # Buckets have no regions, and we cannot remove the region name from _get_credentials as we would
- # break compatibility, so we set it explicitly to None.
- session, endpoint_url = self._get_credentials(region_name=None)
- s3_resource = session.resource(
+ s3_resource = self.get_session().resource(
"s3",
- endpoint_url=endpoint_url,
+ endpoint_url=self.conn_config.endpoint_url,
config=self.config,
verify=self.verify,
)
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index f2d55cd3bb..67fdd7c186 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -771,6 +771,49 @@ class TestAwsBaseHook:
assert isinstance(conn_config_fallback_not_exists, AwsConnectionWrapper)
assert not conn_config_fallback_not_exists
+ @mock.patch('airflow.providers.amazon.aws.hooks.base_aws.SessionFactory')
+ @pytest.mark.parametrize("hook_region_name", [None, "eu-west-1"])
+ @pytest.mark.parametrize(
+ "hook_botocore_config", [None,
Config(s3={"us_east_1_regional_endpoint": "regional"})]
+ )
+ @pytest.mark.parametrize("method_region_name", [None, "cn-north-1"])
+ def test_get_session(
+ self, mock_session_factory, hook_region_name, hook_botocore_config, method_region_name
+ ):
+ """Test get boto3 Session by hook."""
+ mock_session_factory_instance = mock_session_factory.return_value
+ mock_session_factory_instance.create_session.return_value = MOCK_BOTO3_SESSION
+
+ hook = AwsBaseHook(aws_conn_id=None, region_name=hook_region_name, config=hook_botocore_config)
+ session = hook.get_session(region_name=method_region_name)
+ mock_session_factory.assert_called_once_with(
+ conn=hook.conn_config,
+ region_name=method_region_name,
+ config=hook_botocore_config,
+ )
+ mock_session_factory_instance.create_session.assert_called_once()
+ assert session == MOCK_BOTO3_SESSION
+
+ @mock.patch(
+ 'airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook.get_session',
+ return_value=MOCK_BOTO3_SESSION,
+ )
+ @pytest.mark.parametrize("region_name", [None, "aws-global", "eu-west-1"])
+ def test_deprecate_private_method__get_credentials(self, mock_boto3_session, region_name):
+ """Test deprecated method AwsGenericHook._get_credentials."""
+ hook = AwsBaseHook(aws_conn_id=None)
+ warning_message = (
+ r"`AwsGenericHook._get_credentials` method deprecated and will be
removed in a future releases\. "
+ r"Please use `AwsGenericHook.get_session` method and "
+ r"`AwsGenericHook.conn_config.endpoint_url` property instead\."
+ )
+ with pytest.warns(DeprecationWarning, match=warning_message):
+ session, endpoint = hook._get_credentials(region_name)
+
+ mock_boto3_session.assert_called_once_with(region_name=region_name)
+ assert session == MOCK_BOTO3_SESSION
+ assert endpoint == hook.conn_config.endpoint_url
+
class ThrowErrorUntilCount:
"""Holds counter state for invoking a method several times in a row."""
diff --git a/tests/providers/amazon/aws/hooks/test_glue.py b/tests/providers/amazon/aws/hooks/test_glue.py
index e8f55c5b81..394c60a60b 100644
--- a/tests/providers/amazon/aws/hooks/test_glue.py
+++ b/tests/providers/amazon/aws/hooks/test_glue.py
@@ -16,9 +16,11 @@
# specific language governing permissions and limitations
# under the License.
import json
-import unittest
from unittest import mock
+import boto3
+import pytest
+
from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
try:
@@ -27,19 +29,19 @@ except ImportError:
mock_iam = mock_glue = None
-class TestGlueJobHook(unittest.TestCase):
- def setUp(self):
+class TestGlueJobHook:
+ @pytest.fixture(autouse=True)
+ def setup(self):
self.some_aws_region = "us-west-2"
- @unittest.skipIf(mock_iam is None, 'mock_iam package not present')
+ @pytest.mark.skipif(mock_glue is None, reason="mock_glue package not present")
@mock_iam
- def test_get_iam_execution_role(self):
- hook = GlueJobHook(
- job_name='aws_test_glue_job', s3_bucket='some_bucket', iam_role_name='my_test_role'
- )
- iam_role = hook.get_client_type('iam').create_role(
- Path="/",
- RoleName='my_test_role',
+ @pytest.mark.parametrize("role_path", ["/", "/custom-path/"])
+ def test_get_iam_execution_role(self, role_path):
+ expected_role = "my_test_role"
+ boto3.client("iam").create_role(
+ Path=role_path,
+ RoleName=expected_role,
AssumeRolePolicyDocument=json.dumps(
{
"Version": "2012-10-17",
@@ -51,11 +53,18 @@ class TestGlueJobHook(unittest.TestCase):
}
),
)
+
+ hook = GlueJobHook(
+ aws_conn_id=None,
+ job_name='aws_test_glue_job',
+ s3_bucket='some_bucket',
+ iam_role_name=expected_role,
+ )
iam_role = hook.get_iam_execution_role()
assert iam_role is not None
assert "Role" in iam_role
assert "Arn" in iam_role['Role']
- assert iam_role['Role']['Arn'] == "arn:aws:iam::123456789012:role/my_test_role"
+ assert iam_role['Role']['Arn'] == f"arn:aws:iam::123456789012:role{role_path}{expected_role}"
@mock.patch.object(GlueJobHook, "get_conn")
def test_get_or_create_glue_job_get_existing_job(self, mock_get_conn):
@@ -84,7 +93,7 @@ class TestGlueJobHook(unittest.TestCase):
mock_get_conn.return_value.get_job.assert_called_once_with(JobName=hook.job_name)
assert result == expected_job_name
- @unittest.skipIf(mock_glue is None, "mock_glue package not present")
+ @pytest.mark.skipif(mock_glue is None, reason="mock_glue package not present")
@mock_glue
@mock.patch.object(GlueJobHook, "get_iam_execution_role")
def test_get_or_create_glue_job_create_new_job(self, mock_get_iam_execution_role):
@@ -135,10 +144,7 @@ class TestGlueJobHook(unittest.TestCase):
some_script = "s3:/glue-examples/glue-scripts/sample_aws_glue_job.py"
some_s3_bucket = "my-includes"
- with self.assertRaises(
- ValueError,
- msg="ValueError should be raised for specifying the num_of_dpus
and worker type together!",
- ):
+ with pytest.raises(ValueError, match="Cannot specify num_of_dpus with custom WorkerType"):
GlueJobHook(
job_name='aws_test_glue_job',
desc='This is test case job from Airflow',
@@ -175,7 +181,3 @@ class TestGlueJobHook(unittest.TestCase):
glue_job_run = glue_job_hook.initialize_job(some_script_arguments, some_run_kwargs)
glue_job_run_state = glue_job_hook.get_job_state(glue_job_run['JobName'], glue_job_run['JobRunId'])
assert glue_job_run_state == mock_job_run_state, 'Mocks must be equal'
-
-
-if __name__ == '__main__':
- unittest.main()