This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 65ac62867e Get boto3.session.Session by appropriate method (#25569)
65ac62867e is described below

commit 65ac62867e2a09a406ca443a6ac44f9b667fbc55
Author: Andrey Anshin <[email protected]>
AuthorDate: Sun Aug 7 14:56:33 2022 +0400

    Get boto3.session.Session by appropriate method (#25569)
---
 airflow/providers/amazon/aws/hooks/base_aws.py    | 62 +++++++++++++----------
 airflow/providers/amazon/aws/hooks/glue.py        |  6 +--
 airflow/providers/amazon/aws/hooks/s3.py          | 14 ++---
 tests/providers/amazon/aws/hooks/test_base_aws.py | 43 ++++++++++++++++
 tests/providers/amazon/aws/hooks/test_glue.py     | 44 ++++++++--------
 5 files changed, 107 insertions(+), 62 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index 1132892e73..224722609f 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -329,8 +329,8 @@ class BaseSessionFactory(LoggingMixin):
 
     def _get_region_name(self) -> Optional[str]:
         warnings.warn(
-            "`BaseSessionFactory._get_region_name` method will be deprecated 
in the future."
-            "Please use `BaseSessionFactory.region_name` property instead.",
+            "`BaseSessionFactory._get_region_name` method deprecated and will 
be removed "
+            "in a future releases. Please use `BaseSessionFactory.region_name` 
property instead.",
             PendingDeprecationWarning,
             stacklevel=2,
         )
@@ -338,8 +338,8 @@ class BaseSessionFactory(LoggingMixin):
 
     def _read_role_arn_from_extra_config(self) -> Optional[str]:
         warnings.warn(
-            "`BaseSessionFactory._read_role_arn_from_extra_config` method will 
be deprecated in the future."
-            "Please use `BaseSessionFactory.role_arn` property instead.",
+            "`BaseSessionFactory._read_role_arn_from_extra_config` method 
deprecated and will be removed "
+            "in a future releases. Please use `BaseSessionFactory.role_arn` 
property instead.",
             PendingDeprecationWarning,
             stacklevel=2,
         )
@@ -347,8 +347,8 @@ class BaseSessionFactory(LoggingMixin):
 
     def _read_credentials_from_connection(self) -> Tuple[Optional[str], 
Optional[str]]:
         warnings.warn(
-            "`BaseSessionFactory._read_credentials_from_connection` method 
will be deprecated in the future."
-            "Please use `BaseSessionFactory.conn.aws_access_key_id` and "
+            "`BaseSessionFactory._read_credentials_from_connection` method 
deprecated and will be removed "
+            "in a future releases. Please use 
`BaseSessionFactory.conn.aws_access_key_id` and "
             "`BaseSessionFactory.aws_secret_access_key` properties instead.",
             PendingDeprecationWarning,
             stacklevel=2,
@@ -430,15 +430,12 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         """Configuration for botocore client read-only property."""
         return self.conn_config.botocore_config
 
-    def _get_credentials(self, region_name: Optional[str]) -> 
Tuple[boto3.session.Session, Optional[str]]:
-        self.log.debug("Airflow Connection: aws_conn_id=%s", self.aws_conn_id)
-
-        session = SessionFactory(
+    def get_session(self, region_name: Optional[str] = None) -> 
boto3.session.Session:
+        """Get the underlying 
boto3.session.Session(region_name=region_name)."""
+        return SessionFactory(
             conn=self.conn_config, region_name=region_name, config=self.config
         ).create_session()
 
-        return session, self.conn_config.endpoint_url
-
     def get_client_type(
         self,
         client_type: Optional[str] = None,
@@ -446,8 +443,6 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         config: Optional[Config] = None,
     ) -> boto3.client:
         """Get the underlying boto3 client using boto3 session"""
-        session, endpoint_url = self._get_credentials(region_name=region_name)
-
         if client_type:
             warnings.warn(
                 "client_type is deprecated. Set client_type from class 
attribute.",
@@ -462,7 +457,10 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         if config is None:
             config = self.config
 
-        return session.client(client_type, endpoint_url=endpoint_url, 
config=config, verify=self.verify)
+        session = self.get_session(region_name=region_name)
+        return session.client(
+            client_type, endpoint_url=self.conn_config.endpoint_url, 
config=config, verify=self.verify
+        )
 
     def get_resource_type(
         self,
@@ -471,8 +469,6 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         config: Optional[Config] = None,
     ) -> boto3.resource:
         """Get the underlying boto3 resource using boto3 session"""
-        session, endpoint_url = self._get_credentials(region_name=region_name)
-
         if resource_type:
             warnings.warn(
                 "resource_type is deprecated. Set resource_type from class 
attribute.",
@@ -487,10 +483,13 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         if config is None:
             config = self.config
 
-        return session.resource(resource_type, endpoint_url=endpoint_url, 
config=config, verify=self.verify)
+        session = self.get_session(region_name=region_name)
+        return session.resource(
+            resource_type, endpoint_url=self.conn_config.endpoint_url, 
config=config, verify=self.verify
+        )
 
     @cached_property
-    def conn(self) -> Union[boto3.client, boto3.resource]:
+    def conn(self) -> BaseAwsConnection:
         """
         Get the underlying boto3 client/resource (cached)
 
@@ -538,22 +537,16 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         # Compat shim
         return self.conn
 
-    def get_session(self, region_name: Optional[str] = None) -> 
boto3.session.Session:
-        """Get the underlying boto3.session."""
-        session, _ = self._get_credentials(region_name=region_name)
-        return session
-
     def get_credentials(self, region_name: Optional[str] = None) -> 
ReadOnlyCredentials:
         """
         Get the underlying `botocore.Credentials` object.
 
         This contains the following authentication attributes: access_key, 
secret_key and token.
         """
-        session, _ = self._get_credentials(region_name=region_name)
         # Credentials are refreshable, so accessing your access key and
         # secret key separately can lead to a race condition.
         # See https://stackoverflow.com/a/36291428/8283373
-        return session.get_credentials().get_frozen_credentials()
+        return 
self.get_session(region_name=region_name).get_credentials().get_frozen_credentials()
 
     def expand_role(self, role: str, region_name: Optional[str] = None) -> str:
         """
@@ -567,8 +560,10 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         if "/" in role:
             return role
         else:
-            session, endpoint_url = 
self._get_credentials(region_name=region_name)
-            _client = session.client('iam', endpoint_url=endpoint_url, 
config=self.config, verify=self.verify)
+            session = self.get_session(region_name=region_name)
+            _client = session.client(
+                'iam', endpoint_url=self.conn_config.endpoint_url, 
config=self.config, verify=self.verify
+            )
             return _client.get_role(RoleName=role)["Role"]["Arn"]
 
     @staticmethod
@@ -603,6 +598,17 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
 
         return retry_decorator
 
+    def _get_credentials(self, region_name: Optional[str]) -> 
Tuple[boto3.session.Session, Optional[str]]:
+        warnings.warn(
+            "`AwsGenericHook._get_credentials` method deprecated and will be 
removed in a future releases. "
+            "Please use `AwsGenericHook.get_session` method and "
+            "`AwsGenericHook.conn_config.endpoint_url` property instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+        return self.get_session(region_name=region_name), 
self.conn_config.endpoint_url
+
     @staticmethod
     def get_ui_field_behaviour() -> Dict[str, Any]:
         """Returns custom UI field behaviour for AWS Connection."""
diff --git a/airflow/providers/amazon/aws/hooks/glue.py b/airflow/providers/amazon/aws/hooks/glue.py
index 9201b9f70b..fba765e2ff 100644
--- a/airflow/providers/amazon/aws/hooks/glue.py
+++ b/airflow/providers/amazon/aws/hooks/glue.py
@@ -100,10 +100,10 @@ class GlueJobHook(AwsBaseHook):
 
     def get_iam_execution_role(self) -> Dict:
         """:return: iam role for job execution"""
-        session, endpoint_url = 
self._get_credentials(region_name=self.region_name)
-        iam_client = session.client('iam', endpoint_url=endpoint_url, 
config=self.config, verify=self.verify)
-
         try:
+            iam_client = self.get_session(region_name=self.region_name).client(
+                'iam', endpoint_url=self.conn_config.endpoint_url, 
config=self.config, verify=self.verify
+            )
             glue_execution_role = iam_client.get_role(RoleName=self.role_name)
             self.log.info("Iam Role Name: %s", self.role_name)
             return glue_execution_role
diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index 8bd3f0e670..149887f6c9 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -215,12 +215,9 @@ class S3Hook(AwsBaseHook):
         :return: the bucket object to the bucket name.
         :rtype: boto3.S3.Bucket
         """
-        # Buckets have no regions, and we cannot remove the region name from 
_get_credentials as we would
-        # break compatibility, so we set it explicitly to None.
-        session, endpoint_url = self._get_credentials(region_name=None)
-        s3_resource = session.resource(
+        s3_resource = self.get_session().resource(
             "s3",
-            endpoint_url=endpoint_url,
+            endpoint_url=self.conn_config.endpoint_url,
             config=self.config,
             verify=self.verify,
         )
@@ -465,12 +462,9 @@ class S3Hook(AwsBaseHook):
         :return: the key object from the bucket
         :rtype: boto3.s3.Object
         """
-        # Buckets have no regions, and we cannot remove the region name from 
_get_credentials as we would
-        # break compatibility, so we set it explicitly to None.
-        session, endpoint_url = self._get_credentials(region_name=None)
-        s3_resource = session.resource(
+        s3_resource = self.get_session().resource(
             "s3",
-            endpoint_url=endpoint_url,
+            endpoint_url=self.conn_config.endpoint_url,
             config=self.config,
             verify=self.verify,
         )
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index f2d55cd3bb..67fdd7c186 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -771,6 +771,49 @@ class TestAwsBaseHook:
         assert isinstance(conn_config_fallback_not_exists, 
AwsConnectionWrapper)
         assert not conn_config_fallback_not_exists
 
+    @mock.patch('airflow.providers.amazon.aws.hooks.base_aws.SessionFactory')
+    @pytest.mark.parametrize("hook_region_name", [None, "eu-west-1"])
+    @pytest.mark.parametrize(
+        "hook_botocore_config", [None, 
Config(s3={"us_east_1_regional_endpoint": "regional"})]
+    )
+    @pytest.mark.parametrize("method_region_name", [None, "cn-north-1"])
+    def test_get_session(
+        self, mock_session_factory, hook_region_name, hook_botocore_config, 
method_region_name
+    ):
+        """Test get boto3 Session by hook."""
+        mock_session_factory_instance = mock_session_factory.return_value
+        mock_session_factory_instance.create_session.return_value = 
MOCK_BOTO3_SESSION
+
+        hook = AwsBaseHook(aws_conn_id=None, region_name=hook_region_name, 
config=hook_botocore_config)
+        session = hook.get_session(region_name=method_region_name)
+        mock_session_factory.assert_called_once_with(
+            conn=hook.conn_config,
+            region_name=method_region_name,
+            config=hook_botocore_config,
+        )
+        assert mock_session_factory_instance.create_session.assert_called_once
+        assert session == MOCK_BOTO3_SESSION
+
+    @mock.patch(
+        
'airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook.get_session',
+        return_value=MOCK_BOTO3_SESSION,
+    )
+    @pytest.mark.parametrize("region_name", [None, "aws-global", "eu-west-1"])
+    def test_deprecate_private_method__get_credentials(self, 
mock_boto3_session, region_name):
+        """Test deprecated method AwsGenericHook._get_credentials."""
+        hook = AwsBaseHook(aws_conn_id=None)
+        warning_message = (
+            r"`AwsGenericHook._get_credentials` method deprecated and will be 
removed in a future releases\. "
+            r"Please use `AwsGenericHook.get_session` method and "
+            r"`AwsGenericHook.conn_config.endpoint_url` property instead\."
+        )
+        with pytest.warns(DeprecationWarning, match=warning_message):
+            session, endpoint = hook._get_credentials(region_name)
+
+        mock_boto3_session.assert_called_once_with(region_name=region_name)
+        assert session == MOCK_BOTO3_SESSION
+        assert endpoint == hook.conn_config.endpoint_url
+
 
 class ThrowErrorUntilCount:
     """Holds counter state for invoking a method several times in a row."""
diff --git a/tests/providers/amazon/aws/hooks/test_glue.py b/tests/providers/amazon/aws/hooks/test_glue.py
index e8f55c5b81..394c60a60b 100644
--- a/tests/providers/amazon/aws/hooks/test_glue.py
+++ b/tests/providers/amazon/aws/hooks/test_glue.py
@@ -16,9 +16,11 @@
 # specific language governing permissions and limitations
 # under the License.
 import json
-import unittest
 from unittest import mock
 
+import boto3
+import pytest
+
 from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
 
 try:
@@ -27,19 +29,19 @@ except ImportError:
     mock_iam = mock_glue = None
 
 
-class TestGlueJobHook(unittest.TestCase):
-    def setUp(self):
+class TestGlueJobHook:
+    @pytest.fixture(autouse=True)
+    def setup(self):
         self.some_aws_region = "us-west-2"
 
-    @unittest.skipIf(mock_iam is None, 'mock_iam package not present')
+    @pytest.mark.skipif(mock_glue is None, reason="mock_glue package not 
present")
     @mock_iam
-    def test_get_iam_execution_role(self):
-        hook = GlueJobHook(
-            job_name='aws_test_glue_job', s3_bucket='some_bucket', 
iam_role_name='my_test_role'
-        )
-        iam_role = hook.get_client_type('iam').create_role(
-            Path="/",
-            RoleName='my_test_role',
+    @pytest.mark.parametrize("role_path", ["/", "/custom-path/"])
+    def test_get_iam_execution_role(self, role_path):
+        expected_role = "my_test_role"
+        boto3.client("iam").create_role(
+            Path=role_path,
+            RoleName=expected_role,
             AssumeRolePolicyDocument=json.dumps(
                 {
                     "Version": "2012-10-17",
@@ -51,11 +53,18 @@ class TestGlueJobHook(unittest.TestCase):
                 }
             ),
         )
+
+        hook = GlueJobHook(
+            aws_conn_id=None,
+            job_name='aws_test_glue_job',
+            s3_bucket='some_bucket',
+            iam_role_name=expected_role,
+        )
         iam_role = hook.get_iam_execution_role()
         assert iam_role is not None
         assert "Role" in iam_role
         assert "Arn" in iam_role['Role']
-        assert iam_role['Role']['Arn'] == 
"arn:aws:iam::123456789012:role/my_test_role"
+        assert iam_role['Role']['Arn'] == 
f"arn:aws:iam::123456789012:role{role_path}{expected_role}"
 
     @mock.patch.object(GlueJobHook, "get_conn")
     def test_get_or_create_glue_job_get_existing_job(self, mock_get_conn):
@@ -84,7 +93,7 @@ class TestGlueJobHook(unittest.TestCase):
         
mock_get_conn.return_value.get_job.assert_called_once_with(JobName=hook.job_name)
         assert result == expected_job_name
 
-    @unittest.skipIf(mock_glue is None, "mock_glue package not present")
+    @pytest.mark.skipif(mock_glue is None, reason="mock_glue package not 
present")
     @mock_glue
     @mock.patch.object(GlueJobHook, "get_iam_execution_role")
     def test_get_or_create_glue_job_create_new_job(self, 
mock_get_iam_execution_role):
@@ -135,10 +144,7 @@ class TestGlueJobHook(unittest.TestCase):
         some_script = "s3:/glue-examples/glue-scripts/sample_aws_glue_job.py"
         some_s3_bucket = "my-includes"
 
-        with self.assertRaises(
-            ValueError,
-            msg="ValueError should be raised for specifying the num_of_dpus 
and worker type together!",
-        ):
+        with pytest.raises(ValueError, match="Cannot specify num_of_dpus with 
custom WorkerType"):
             GlueJobHook(
                 job_name='aws_test_glue_job',
                 desc='This is test case job from Airflow',
@@ -175,7 +181,3 @@ class TestGlueJobHook(unittest.TestCase):
         glue_job_run = glue_job_hook.initialize_job(some_script_arguments, 
some_run_kwargs)
         glue_job_run_state = 
glue_job_hook.get_job_state(glue_job_run['JobName'], glue_job_run['JobRunId'])
         assert glue_job_run_state == mock_job_run_state, 'Mocks but be equal'
-
-
-if __name__ == '__main__':
-    unittest.main()

Reply via email to