This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8402e9adf4 Use `boto3.client` linked to resource meta instead of creating a new one for waiters (#33552)
8402e9adf4 is described below

commit 8402e9adf4c7d0ddf234ccfb22fce5c34384920a
Author: Andrey Anshin <[email protected]>
AuthorDate: Mon Aug 21 13:13:55 2023 +0400

    Use `boto3.client` linked to resource meta instead of creating a new one for waiters (#33552)
---
 airflow/providers/amazon/aws/hooks/base_aws.py     | 28 ++++++-------
 .../amazon/aws/waiters/test_custom_waiters.py      | 46 ++++++++++++++++------
 2 files changed, 44 insertions(+), 30 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index 904ed0a385..660d6fa92e 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -680,12 +680,16 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         return self.get_client_type(region_name=self.region_name, 
deferrable=True)
 
     @cached_property
-    def conn_client_meta(self) -> ClientMeta:
-        """Get botocore client metadata from Hook connection (cached)."""
+    def _client(self) -> botocore.client.BaseClient:
         conn = self.conn
         if isinstance(conn, botocore.client.BaseClient):
-            return conn.meta
-        return conn.meta.client.meta
+            return conn
+        return conn.meta.client
+
+    @property
+    def conn_client_meta(self) -> ClientMeta:
+        """Get botocore client metadata from Hook connection (cached)."""
+        return self._client.meta
 
     @property
     def conn_region_name(self) -> str:
@@ -862,19 +866,9 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
 
         if deferrable and not client:
             raise ValueError("client must be provided for a deferrable 
waiter.")
-        client = client or self.conn
+        # Currently, the custom waiter doesn't work with resource_type, only 
client_type is supported.
+        client = client or self._client
         if self.waiter_path and (waiter_name in self._list_custom_waiters()):
-            # Currently, the custom waiter doesn't work with resource_type, 
only client_type is supported.
-            if self.resource_type:
-                credentials = self.get_credentials()
-                client = boto3.client(
-                    self.resource_type,
-                    region_name=self.region_name,
-                    aws_access_key_id=credentials.access_key,
-                    aws_secret_access_key=credentials.secret_key,
-                    aws_session_token=credentials.token,
-                )
-
             # Technically if waiter_name is in custom_waiters then 
self.waiter_path must
             # exist but MyPy doesn't like the fact that self.waiter_path could 
be None.
             with open(self.waiter_path) as config_file:
@@ -909,7 +903,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         return [*self._list_official_waiters(), *self._list_custom_waiters()]
 
     def _list_official_waiters(self) -> list[str]:
-        return self.conn.waiter_names
+        return self._client.waiter_names
 
     def _list_custom_waiters(self) -> list[str]:
         if not self.waiter_path:
diff --git a/tests/providers/amazon/aws/waiters/test_custom_waiters.py b/tests/providers/amazon/aws/waiters/test_custom_waiters.py
index d02c9c49e8..21c051f3b4 100644
--- a/tests/providers/amazon/aws/waiters/test_custom_waiters.py
+++ b/tests/providers/amazon/aws/waiters/test_custom_waiters.py
@@ -26,6 +26,7 @@ from botocore.exceptions import WaiterError
 from botocore.waiter import WaiterModel
 from moto import mock_eks
 
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
 from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, 
EcsTaskDefinitionStates
 from airflow.providers.amazon.aws.hooks.eks import EksHook
@@ -73,6 +74,22 @@ class TestBaseWaiter:
             assert waiter.model.__getattribute__(attr) == 
expected_model.__getattribute__(attr)
         assert waiter.client == client_name
 
+    @pytest.mark.parametrize("boto_type", ["client", "resource"])
+    def test_get_botocore_waiter(self, boto_type, monkeypatch):
+        kw = {f"{boto_type}_type": "s3"}
+        if boto_type == "client":
+            fake_client = boto3.client("s3", region_name="eu-west-3")
+        elif boto_type == "resource":
+            fake_client = boto3.resource("s3", region_name="eu-west-3")
+        else:
+            raise ValueError(f"Unexpected value {boto_type!r} for 
`boto_type`.")
+        monkeypatch.setattr(AwsBaseHook, "conn", fake_client)
+
+        hook = AwsBaseHook(**kw)
+        with mock.patch("botocore.client.BaseClient.get_waiter") as m:
+            hook.get_waiter(waiter_name="FooBar")
+            m.assert_called_once_with("FooBar")
+
 
 class TestCustomEKSServiceWaiters:
     def test_service_waiters(self):
@@ -230,8 +247,9 @@ class TestCustomDynamoDBServiceWaiters:
 
     @pytest.fixture(autouse=True)
     def setup_test_cases(self, monkeypatch):
-        self.client = boto3.client("dynamodb", region_name="eu-west-3")
-        monkeypatch.setattr(DynamoDBHook, "conn", self.client)
+        self.resource = boto3.resource("dynamodb", region_name="eu-west-3")
+        monkeypatch.setattr(DynamoDBHook, "conn", self.resource)
+        self.client = self.resource.meta.client
 
     @pytest.fixture
     def mock_describe_export(self):
@@ -253,16 +271,15 @@ class TestCustomDynamoDBServiceWaiters:
 
     def test_export_table_to_point_in_time_completed(self, 
mock_describe_export):
         """Test state transition from `in progress` to `completed` during 
init."""
-        with mock.patch("boto3.client") as client:
-            client.return_value = self.client
-            waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table", 
client=self.client)
-            mock_describe_export.side_effect = [
-                self.describe_export(self.STATUS_IN_PROGRESS),
-                self.describe_export(self.STATUS_COMPLETED),
-            ]
-            waiter.wait(
-                
ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry",
-            )
+        waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table")
+        mock_describe_export.side_effect = [
+            self.describe_export(self.STATUS_IN_PROGRESS),
+            self.describe_export(self.STATUS_COMPLETED),
+        ]
+        waiter.wait(
+            
ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry",
+            WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
+        )
 
     def test_export_table_to_point_in_time_failed(self, mock_describe_export):
         """Test state transition from `in progress` to `failed` during init."""
@@ -274,4 +291,7 @@ class TestCustomDynamoDBServiceWaiters:
             ]
             waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table", 
client=self.client)
             with pytest.raises(WaiterError, match='we matched expected path: 
"FAILED"'):
-                
waiter.wait(ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry")
+                waiter.wait(
+                    
ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry",
+                    WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
+                )

Reply via email to