This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 122b20bb0fe fix: don't use blocking property access for async purposes 
(#47326)
122b20bb0fe is described below

commit 122b20bb0fe63757a4a7807cb7fbd755b38728d2
Author: Cedrik Neumann <[email protected]>
AuthorDate: Fri Mar 7 11:57:52 2025 -0600

    fix: don't use blocking property access for async purposes (#47326)
    
    The AwsGenericHook provides the `@property` decorated function `async_conn`
    for building and accessing the async boto3 client. Unfortunately, this
    results in blocking calls to the Airflow Db or secrets backends within 
async contexts.
    
    This PR provides an async method `get_async_conn` as an alternative that 
can be awaited
    in async contexts. This method will call the underlying sync code in a
    `sync_to_async` wrapper.
    
    The `async_conn` property is now deprecated and will be removed in future 
versions.
---
 .../airflow/providers/amazon/aws/hooks/base_aws.py |  25 +++++
 .../src/airflow/providers/amazon/aws/hooks/ec2.py  |   2 +-
 .../src/airflow/providers/amazon/aws/hooks/glue.py |   2 +-
 .../src/airflow/providers/amazon/aws/hooks/logs.py |   4 +-
 .../providers/amazon/aws/hooks/redshift_cluster.py |   2 +-
 .../providers/amazon/aws/hooks/redshift_data.py    |   4 +-
 .../providers/amazon/aws/hooks/sagemaker.py        |   2 +-
 .../providers/amazon/aws/triggers/README.md        |   8 +-
 .../airflow/providers/amazon/aws/triggers/base.py  |   2 +-
 .../airflow/providers/amazon/aws/triggers/ecs.py   |   8 +-
 .../airflow/providers/amazon/aws/triggers/eks.py   |   4 +-
 .../airflow/providers/amazon/aws/triggers/glue.py  |   2 +-
 .../airflow/providers/amazon/aws/triggers/s3.py    |   4 +-
 .../providers/amazon/aws/triggers/sagemaker.py     |   4 +-
 .../airflow/providers/amazon/aws/triggers/sqs.py   |   2 +-
 .../unit/amazon/aws/hooks/test_redshift_data.py    |  16 +--
 .../amazon/tests/unit/amazon/aws/hooks/test_s3.py  | 109 +++++++++------------
 .../tests/unit/amazon/aws/triggers/test_base.py    |   4 +-
 .../tests/unit/amazon/aws/triggers/test_ec2.py     |   8 +-
 .../tests/unit/amazon/aws/triggers/test_ecs.py     |  18 ++--
 .../tests/unit/amazon/aws/triggers/test_eks.py     |   4 +-
 .../tests/unit/amazon/aws/triggers/test_neptune.py |  16 +--
 .../tests/unit/amazon/aws/triggers/test_s3.py      |  14 +--
 .../unit/amazon/aws/triggers/test_sagemaker.py     |   4 +-
 24 files changed, 143 insertions(+), 125 deletions(-)

diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py 
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py
index 04c21380155..862a210cd23 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -30,6 +30,7 @@ import inspect
 import json
 import logging
 import os
+import warnings
 from copy import deepcopy
 from functools import cached_property, wraps
 from pathlib import Path
@@ -41,6 +42,7 @@ import botocore.session
 import jinja2
 import requests
 import tenacity
+from asgiref.sync import sync_to_async
 from botocore.config import Config
 from botocore.waiter import Waiter, WaiterModel
 from dateutil.tz import tzlocal
@@ -50,6 +52,7 @@ from airflow.configuration import conf
 from airflow.exceptions import (
     AirflowException,
     AirflowNotFoundException,
+    AirflowProviderDeprecationWarning,
 )
 from airflow.hooks.base import BaseHook
 from airflow.providers.amazon.aws.utils.connection_wrapper import 
AwsConnectionWrapper
@@ -747,7 +750,29 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
 
     @property
     def async_conn(self):
+        """
+        [DEPRECATED] Get an aiobotocore client to use for async operations.
+
+        This property is deprecated. Accessing it in an async context will 
cause the event loop to block.
+        Use the async method `get_async_conn` instead.
+        """
+        warnings.warn(
+            "The property `async_conn` is deprecated. Accessing it in an async 
context will cause the event loop to block. "
+            "Use the async method `get_async_conn` instead.",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+
+        return self._get_async_conn()
+
+    async def get_async_conn(self):
         """Get an aiobotocore client to use for async operations."""
+        # We have to wrap the call `self.get_client_type` in another call 
`_get_async_conn`,
+        # because one of its arguments `self.region_name` is a `@property` 
decorated function
+        # calling the cached property `self.conn_config` at the end.
+        return await sync_to_async(self._get_async_conn)()
+
+    def _get_async_conn(self):
         if not self.client_type:
             raise ValueError("client_type must be specified.")
 
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/ec2.py 
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/ec2.py
index cd13f007c4e..ef4f374eb7b 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/ec2.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/ec2.py
@@ -173,7 +173,7 @@ class EC2Hook(AwsBaseHook):
         return [instance["InstanceId"] for instance in 
self.get_instances(filters=filters)]
 
     async def get_instance_state_async(self, instance_id: str) -> str:
-        async with self.async_conn as client:
+        async with await self.get_async_conn() as client:
             response = await 
client.describe_instances(InstanceIds=[instance_id])
             return response["Reservations"][0]["Instances"][0]["State"]["Name"]
 
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py 
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py
index d025652d0a8..91e69c7745b 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py
@@ -211,7 +211,7 @@ class GlueJobHook(AwsBaseHook):
 
         The async version of get_job_state.
         """
-        async with self.async_conn as client:
+        async with await self.get_async_conn() as client:
             job_run = await client.get_job_run(JobName=job_name, RunId=run_id)
         return job_run["JobRun"]["JobRunState"]
 
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/logs.py 
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/logs.py
index 32047e81dc8..884e02478ba 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/logs.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/logs.py
@@ -152,7 +152,7 @@ class AwsLogsHook(AwsBaseHook):
          If the value is LastEventTime , the results are ordered by the event 
time. The default value is LogStreamName.
         :param count: The maximum number of items returned
         """
-        async with self.async_conn as client:
+        async with await self.get_async_conn() as client:
             try:
                 response: dict[str, Any] = await client.describe_log_streams(
                     logGroupName=log_group,
@@ -194,7 +194,7 @@ class AwsLogsHook(AwsBaseHook):
             else:
                 token_arg = {}
 
-            async with self.async_conn as client:
+            async with await self.get_async_conn() as client:
                 response = await client.get_log_events(
                     logGroupName=log_group,
                     logStreamName=log_stream_name,
diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_cluster.py 
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_cluster.py
index 9c64ded01d2..a2dc837ea2f 100644
--- 
a/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_cluster.py
+++ 
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_cluster.py
@@ -93,7 +93,7 @@ class RedshiftHook(AwsBaseHook):
             return "cluster_not_found"
 
     async def cluster_status_async(self, cluster_identifier: str) -> str:
-        async with self.async_conn as client:
+        async with await self.get_async_conn() as client:
             response = await 
client.describe_clusters(ClusterIdentifier=cluster_identifier)
             return response["Clusters"][0]["ClusterStatus"] if response else 
None
 
diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_data.py 
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_data.py
index 309b35448e7..50a655ec4c4 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -275,7 +275,7 @@ class 
RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
 
         :param statement_id: the UUID of the statement
         """
-        async with self.async_conn as client:
+        async with await self.get_async_conn() as client:
             desc = await client.describe_statement(Id=statement_id)
             return desc["Status"] in RUNNING_STATES
 
@@ -288,6 +288,6 @@ class 
RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
 
         :param statement_id: the UUID of the statement
         """
-        async with self.async_conn as client:
+        async with await self.get_async_conn() as client:
             resp = await client.describe_statement(Id=statement_id)
             return self.parse_statement_response(resp)
diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker.py 
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker.py
index fb2b5ba79b8..7e02a408cb4 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -1318,7 +1318,7 @@ class SageMakerHook(AwsBaseHook):
 
         :param job_name: the name of the training job
         """
-        async with self.async_conn as client:
+        async with await self.get_async_conn() as client:
             response: dict[str, Any] = await 
client.describe_training_job(TrainingJobName=job_name)
             return response
 
diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/triggers/README.md 
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/README.md
index b7796dd603f..3d7c72a139e 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/README.md
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/README.md
@@ -65,10 +65,10 @@ To call the asynchronous `wait` function, first create a 
hook for the particular
 self.redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
 ```
 
-With this hook, we can use the async_conn property to get access to the 
aiobotocore client:
+With this hook, we can use the asynchronous get_async_conn method to get 
access to the aiobotocore client:
 
 ```python
-async with self.redshift_hook.async_conn as client:
+async with await self.redshift_hook.get_async_conn() as client:
     await client.get_waiter("cluster_available").wait(
         ClusterIdentifier=self.cluster_identifier,
         WaiterConfig={
@@ -81,7 +81,7 @@ async with self.redshift_hook.async_conn as client:
 In this case, we are using the built-in cluster_available waiter. If we wanted 
to use a custom waiter, we would change the code slightly to use the 
`get_waiter` function from the hook, rather than the aiobotocore client:
 
 ```python
-async with self.redshift_hook.async_conn as client:
+async with await self.redshift_hook.get_async_conn() as client:
     waiter = self.redshift_hook.get_waiter("cluster_paused", deferrable=True, 
client=client)
     await waiter.wait(
         ClusterIdentifier=self.cluster_identifier,
@@ -131,7 +131,7 @@ For more information about writing custom waiter, see the 
[README.md](https://gi
 In some cases, a built-in or custom waiter may not be able to solve the 
problem. In such cases, the asynchronous method used to poll the boto3 API 
would need to be defined in the hook of the service being used. This method is 
essentially the same as the synchronous version of the method, except that it 
will use the aiobotocore client, and will be awaited. For the Redshift example, 
the async `describe_clusters` method would look as follows:
 
 ```python
-async with self.async_conn as client:
+async with await self.get_async_conn() as client:
     response = 
client.describe_clusters(ClusterIdentifier=self.cluster_identifier)
 ```
 
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py 
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
index eac4a8c56e1..f2c71a99adc 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
@@ -139,7 +139,7 @@ class AwsBaseWaiterTrigger(BaseTrigger):
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
         hook = self.hook()
-        async with hook.async_conn as client:
+        async with await hook.get_async_conn() as client:
             waiter = hook.get_waiter(self.waiter_name, deferrable=True, 
client=client)
             await async_wait(
                 waiter,
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ecs.py 
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ecs.py
index 11c9cf18043..e6e54a838c7 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ecs.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ecs.py
@@ -167,8 +167,12 @@ class TaskDoneTrigger(BaseTrigger):
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
         async with (
-            EcsHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region).async_conn as ecs_client,
-            AwsLogsHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region).async_conn as logs_client,
+            await EcsHook(
+                aws_conn_id=self.aws_conn_id, region_name=self.region
+            ).get_async_conn() as ecs_client,
+            await AwsLogsHook(
+                aws_conn_id=self.aws_conn_id, region_name=self.region
+            ).get_async_conn() as logs_client,
         ):
             waiter = ecs_client.get_waiter("tasks_stopped")
             logs_token = None
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/eks.py 
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/eks.py
index f9b3625644d..3a326e3f8e0 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/eks.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/eks.py
@@ -70,7 +70,7 @@ class EksCreateClusterTrigger(AwsBaseWaiterTrigger):
         return EksHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
 
     async def run(self):
-        async with self.hook().async_conn as client:
+        async with await self.hook().get_async_conn() as client:
             waiter = client.get_waiter(self.waiter_name)
             try:
                 await async_wait(
@@ -140,7 +140,7 @@ class EksDeleteClusterTrigger(AwsBaseWaiterTrigger):
         return EksHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
 
     async def run(self):
-        async with self.hook().async_conn as client:
+        async with await self.hook().get_async_conn() as client:
             waiter = client.get_waiter("cluster_deleted")
             if self.force_delete_compute:
                 await self.delete_any_nodegroups(client=client)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py 
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py
index ad7b3945be5..4a56f47689a 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py
@@ -157,7 +157,7 @@ class GlueCatalogPartitionTrigger(BaseTrigger):
         return bool(partitions)
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
-        async with self.hook.async_conn as client:
+        async with await self.hook.get_async_conn() as client:
             while True:
                 result = await self.poke(client=client)
                 if result:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/s3.py 
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/s3.py
index 9d2b055fe44..0bc192b4798 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/s3.py
@@ -102,7 +102,7 @@ class S3KeyTrigger(BaseTrigger):
     async def run(self) -> AsyncIterator[TriggerEvent]:
         """Make an asynchronous connection using S3HookAsync."""
         try:
-            async with self.hook.async_conn as client:
+            async with await self.hook.get_async_conn() as client:
                 while True:
                     if await self.hook.check_key_async(
                         client, self.bucket_name, self.bucket_key, 
self.wildcard_match, self.use_regex
@@ -216,7 +216,7 @@ class S3KeysUnchangedTrigger(BaseTrigger):
     async def run(self) -> AsyncIterator[TriggerEvent]:
         """Make an asynchronous connection using S3Hook."""
         try:
-            async with self.hook.async_conn as client:
+            async with await self.hook.get_async_conn() as client:
                 while True:
                     result = await self.hook.is_keys_unchanged_async(
                         client=client,
diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker.py 
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker.py
index 5f6350473f0..b8bad2ac810 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker.py
@@ -108,7 +108,7 @@ class SageMakerTrigger(BaseTrigger):
 
     async def run(self):
         self.log.info("job name is %s and job type is %s", self.job_name, 
self.job_type)
-        async with self.hook.async_conn as client:
+        async with await self.hook.get_async_conn() as client:
             waiter = self.hook.get_waiter(
                 self._get_job_type_waiter(self.job_type), deferrable=True, 
client=client
             )
@@ -166,7 +166,7 @@ class SageMakerPipelineTrigger(BaseTrigger):
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
         hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
-        async with hook.async_conn as conn:
+        async with await hook.get_async_conn() as conn:
             waiter = hook.get_waiter(self._waiter_name[self.waiter_type], 
deferrable=True, client=conn)
             for _ in range(self.waiter_max_attempts):
                 try:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/sqs.py 
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/sqs.py
index f565b821920..60e399f97f4 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/sqs.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/sqs.py
@@ -184,7 +184,7 @@ class SqsSensorTrigger(BaseEventTrigger):
         while True:
             # This loop will run indefinitely until the timeout, which is set 
in the self.defer
             # method, is reached.
-            async with self.hook.async_conn as client:
+            async with await self.hook.get_async_conn() as client:
                 result = await self.poke(client=client)
                 if result:
                     yield TriggerEvent({"status": "success", "message_batch": 
result})
diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_data.py 
b/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_data.py
index d5480864498..96e33516d5d 100644
--- a/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_data.py
+++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_data.py
@@ -425,21 +425,23 @@ class TestRedshiftDataHook:
             ({"Status": "ABORTED"}, False),
         ],
     )
-    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.get_async_conn")
     async def test_is_still_running(self, mock_conn, 
describe_statement_response, expected_result):
         hook = RedshiftDataHook()
-        mock_conn.__aenter__.return_value.describe_statement.return_value = 
describe_statement_response
+        
mock_conn.return_value.__aenter__.return_value.describe_statement.return_value 
= (
+            describe_statement_response
+        )
         response = await hook.is_still_running("uuid")
         assert response == expected_result
 
     @pytest.mark.asyncio
-    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.get_async_conn")
     
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running")
     async def test_check_query_is_finished_async(self, mock_is_still_running, 
mock_conn):
         hook = RedshiftDataHook()
         mock_is_still_running.return_value = False
         mock_conn.describe_statement = mock.AsyncMock()
-        mock_conn.__aenter__.return_value.describe_statement.return_value = {
+        
mock_conn.return_value.__aenter__.return_value.describe_statement.return_value 
= {
             "Id": "uuid",
             "Status": "FINISHED",
         }
@@ -457,13 +459,15 @@ class TestRedshiftDataHook:
             ({"Id": "uuid", "Status": "ABORTED"}, 
RedshiftDataQueryAbortedError),
         ),
     )
-    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.get_async_conn")
     
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running")
     async def test_check_query_is_finished_async_exception(
         self, mock_is_still_running, mock_conn, describe_statement_response, 
expected_exception
     ):
         hook = RedshiftDataHook()
         mock_is_still_running.return_value = False
-        mock_conn.__aenter__.return_value.describe_statement.return_value = 
describe_statement_response
+        
mock_conn.return_value.__aenter__.return_value.describe_statement.return_value 
= (
+            describe_statement_response
+        )
         with pytest.raises(expected_exception):
             await hook.check_query_is_finished_async(statement_id="uuid")
diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py 
b/providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py
index d14a4e8698a..8d1da3de743 100644
--- a/providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py
+++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py
@@ -477,9 +477,8 @@ class TestAwsS3Hook:
         assert response["Grants"][1]["Permission"] == "READ"
         assert response["Grants"][0]["Permission"] == "FULL_CONTROL"
 
-    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
     @pytest.mark.asyncio
-    async def test_s3_key_hook_get_file_metadata_async(self, mock_client):
+    async def test_s3_key_hook_get_file_metadata_async(self):
         """
         Test check_wildcard_key for a valid response
         :return:
@@ -498,6 +497,7 @@ class TestAwsS3Hook:
         mock_paginator.paginate.return_value = mock_paginate
 
         s3_hook_async = S3Hook(client_type="S3")
+        mock_client = AsyncMock()
         mock_client.get_paginator = mock.Mock(return_value=mock_paginator)
         keys = [x async for x in 
s3_hook_async.get_file_metadata_async(mock_client, "test_bucket", "test*")]
 
@@ -507,14 +507,13 @@ class TestAwsS3Hook:
         ]
 
     @pytest.mark.asyncio
-    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
-    async def test_s3_key_hook_get_head_object_with_error_async(self, 
mock_client):
+    async def test_s3_key_hook_get_head_object_with_error_async(self):
         """
         Test for 404 error if key not found and assert based on response.
         :return:
         """
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
-
+        mock_client = AsyncMock()
         mock_client.head_object.side_effect = ClientError(
             {
                 "Error": {
@@ -536,15 +535,14 @@ class TestAwsS3Hook:
         )
         assert response is None
 
-    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
     @pytest.mark.asyncio
-    async def test_s3_key_hook_get_head_object_raise_exception_async(self, 
mock_client):
+    async def test_s3_key_hook_get_head_object_raise_exception_async(self):
         """
         Test for 500 error if key not found and assert based on response.
         :return:
         """
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
-
+        mock_client = AsyncMock()
         mock_client.head_object.side_effect = ClientError(
             {
                 "Error": {
@@ -566,8 +564,7 @@ class TestAwsS3Hook:
         assert isinstance(err.value, ClientError)
 
     @pytest.mark.asyncio
-    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
-    async def test_s3_key_hook_get_files_without_wildcard_async(self, 
mock_client):
+    async def test_s3_key_hook_get_files_without_wildcard_async(self):
         """
         Test get_files for a valid response
         :return:
@@ -586,13 +583,13 @@ class TestAwsS3Hook:
         mock_paginator.paginate.return_value = mock_paginate
 
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
+        mock_client = AsyncMock()
         mock_client.get_paginator = mock.Mock(return_value=mock_paginator)
         response = await s3_hook_async.get_files_async(mock_client, 
"test_bucket", "test.txt", False)
         assert response == []
 
     @pytest.mark.asyncio
-    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
-    async def test_s3_key_hook_get_files_with_wildcard_async(self, 
mock_client):
+    async def test_s3_key_hook_get_files_with_wildcard_async(self):
         """
         Test get_files for a valid response
         :return:
@@ -611,13 +608,13 @@ class TestAwsS3Hook:
         mock_paginator.paginate.return_value = mock_paginate
 
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
+        mock_client = AsyncMock()
         mock_client.get_paginator = mock.Mock(return_value=mock_paginator)
         response = await s3_hook_async.get_files_async(mock_client, 
"test_bucket", "test.txt", True)
         assert response == []
 
     @pytest.mark.asyncio
-    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
-    async def test_s3_key_hook_list_keys_async(self, mock_client):
+    async def test_s3_key_hook_list_keys_async(self):
         """
         Test _list_keys for a valid response
         :return:
@@ -636,6 +633,7 @@ class TestAwsS3Hook:
         mock_paginator.paginate.return_value = mock_paginate
 
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
+        mock_client = AsyncMock()
         mock_client.get_paginator = mock.Mock(return_value=mock_paginator)
         response = await s3_hook_async._list_keys_async(mock_client, 
"test_bucket", "test*")
         assert response == ["test_key", "test_key2"]
@@ -647,10 +645,7 @@ class TestAwsS3Hook:
             ("async-prefix1/", "async-prefix2/"),
         ],
     )
-    
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
-    async def test_s3_prefix_sensor_hook_list_prefixes_async(
-        self, mock_client, test_first_prefix, test_second_prefix
-    ):
+    async def test_s3_prefix_sensor_hook_list_prefixes_async(self, 
test_first_prefix, test_second_prefix):
         """
         Test list_prefixes whether it returns a valid response
         """
@@ -661,6 +656,7 @@ class TestAwsS3Hook:
         mock_paginator.paginate.return_value = mock_paginate
 
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
+        mock_client = AsyncMock()
         mock_client.get_paginator = mock.Mock(return_value=mock_paginator)
 
         actual_output = await s3_hook_async.list_prefixes_async(mock_client, 
"test_bucket", "test")
@@ -674,10 +670,9 @@ class TestAwsS3Hook:
             ("async-prefix1", "test_bucket"),
         ],
     )
-    
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
     
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.list_prefixes_async")
     async def test_s3_prefix_sensor_hook_check_for_prefix_async(
-        self, mock_list_prefixes, mock_client, mock_prefix, mock_bucket
+        self, mock_list_prefixes, mock_prefix, mock_bucket
     ):
         """
         Test that _check_for_prefix method returns True when valid prefix is 
used and returns False
@@ -686,15 +681,16 @@ class TestAwsS3Hook:
         mock_list_prefixes.return_value = ["async-prefix1/", "async-prefix2/"]
 
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
+        mock_client = AsyncMock()
 
         response = await s3_hook_async._check_for_prefix_async(
-            client=mock_client.return_value, prefix=mock_prefix, 
bucket_name=mock_bucket, delimiter="/"
+            client=mock_client, prefix=mock_prefix, bucket_name=mock_bucket, 
delimiter="/"
         )
 
         assert response is True
 
         response = await s3_hook_async._check_for_prefix_async(
-            client=mock_client.return_value,
+            client=mock_client,
             prefix="non-existing-prefix",
             bucket_name=mock_bucket,
             delimiter="/",
@@ -704,11 +700,10 @@ class TestAwsS3Hook:
 
     @pytest.mark.asyncio
     
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
-    
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
-    async def test__check_key_async_without_wildcard_match(self, 
mock_get_conn, mock_get_bucket_key):
+    async def test__check_key_async_without_wildcard_match(self, 
mock_get_bucket_key):
         """Test _check_key_async function without using wildcard_match"""
         mock_get_bucket_key.return_value = "test_bucket", "test.txt"
-        mock_client = mock_get_conn.return_value
+        mock_client = AsyncMock()
         mock_client.head_object = AsyncMock(return_value={"ContentLength": 0})
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
         response = await s3_hook_async._check_key_async(
@@ -718,14 +713,11 @@ class TestAwsS3Hook:
 
     @pytest.mark.asyncio
     
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
-    
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
-    async def test_s3__check_key_async_without_wildcard_match_and_get_none(
-        self, mock_get_conn, mock_get_bucket_key
-    ):
+    async def 
test_s3__check_key_async_without_wildcard_match_and_get_none(self, 
mock_get_bucket_key):
         """Test _check_key_async function when get head object returns none"""
         mock_get_bucket_key.return_value = "test_bucket", "test.txt"
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
-        mock_client = mock_get_conn.return_value
+        mock_client = AsyncMock()
         mock_client.head_object = AsyncMock(return_value=None)
         response = await s3_hook_async._check_key_async(
             mock_client, "test_bucket", False, "s3://test_bucket/file/test.txt"
@@ -734,7 +726,6 @@ class TestAwsS3Hook:
 
     # 
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
     @pytest.mark.asyncio
-    
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
     @pytest.mark.parametrize(
         "contents, result",
         [
@@ -774,12 +765,12 @@ class TestAwsS3Hook:
             ),
         ],
     )
-    async def test_s3__check_key_async_with_wildcard_match(self, 
mock_get_conn, contents, result):
+    async def test_s3__check_key_async_with_wildcard_match(self, contents, 
result):
         """Test _check_key_async function"""
-        client = mock_get_conn.return_value
-        paginator = client.get_paginator.return_value
-        r = paginator.paginate.return_value
+        client = Mock()
+        r = AsyncMock()
         r.__aiter__.return_value = [{"Contents": contents}]
+        client.get_paginator.return_value.paginate.return_value = r
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
         response = await s3_hook_async._check_key_async(
             client=client,
@@ -799,15 +790,11 @@ class TestAwsS3Hook:
     )
     @pytest.mark.asyncio
     
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
-    
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
-    async def test__check_key_async_with_use_regex(
-        self, mock_get_conn, mock_get_bucket_key, key, pattern, expected
-    ):
+    async def test__check_key_async_with_use_regex(self, mock_get_bucket_key, 
key, pattern, expected):
         """Match AWS S3 key with regex expression"""
         mock_get_bucket_key.return_value = "test_bucket", pattern
-        client = mock_get_conn.return_value
-        paginator = client.get_paginator.return_value
-        r = paginator.paginate.return_value
+        client = Mock()
+        r = AsyncMock()
         r.__aiter__.return_value = [
             {
                 "Contents": [
@@ -820,6 +807,7 @@ class TestAwsS3Hook:
                 ]
             }
         ]
+        client.get_paginator.return_value.paginate.return_value = r
 
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
         response = await s3_hook_async._check_key_async(
@@ -832,15 +820,14 @@ class TestAwsS3Hook:
         assert response is expected
 
     @pytest.mark.asyncio
-    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
     
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook._list_keys_async")
-    async def test_s3_key_hook_is_keys_unchanged_false_async(self, 
mock_list_keys, mock_client):
+    async def test_s3_key_hook_is_keys_unchanged_false_async(self, 
mock_list_keys):
         """
         Test is_key_unchanged gives False response when the key value is 
unchanged in specified period.
         """
 
         mock_list_keys.return_value = ["test"]
-
+        mock_client = AsyncMock()
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
         response = await s3_hook_async.is_keys_unchanged_async(
             client=mock_client.return_value,
@@ -875,18 +862,18 @@ class TestAwsS3Hook:
         assert response.get("status") == "pending"
 
     @pytest.mark.asyncio
-    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
     
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook._list_keys_async")
-    async def test_s3_key_hook_is_keys_unchanged_exception_async(self, 
mock_list_keys, mock_client):
+    async def test_s3_key_hook_is_keys_unchanged_exception_async(self, 
mock_list_keys):
         """
         Test is_key_unchanged gives AirflowException.
         """
         mock_list_keys.return_value = []
 
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
+        mock_client = AsyncMock()
 
         response = await s3_hook_async.is_keys_unchanged_async(
-            client=mock_client.return_value,
+            client=mock_client,
             bucket_name="test_bucket",
             prefix="test",
             inactivity_period=1,
@@ -900,18 +887,18 @@ class TestAwsS3Hook:
         assert response == {"message": "test_bucket/test between pokes.", 
"status": "error"}
 
     @pytest.mark.asyncio
-    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
     
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook._list_keys_async")
-    async def test_s3_key_hook_is_keys_unchanged_async_handle_tzinfo(self, 
mock_list_keys, mock_client):
+    async def test_s3_key_hook_is_keys_unchanged_async_handle_tzinfo(self, 
mock_list_keys):
         """
         Test is_key_unchanged gives AirflowException.
         """
         mock_list_keys.return_value = []
 
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
+        mock_client = AsyncMock()
 
         response = await s3_hook_async.is_keys_unchanged_async(
-            client=mock_client.return_value,
+            client=mock_client,
             bucket_name="test_bucket",
             prefix="test",
             inactivity_period=1,
@@ -925,18 +912,18 @@ class TestAwsS3Hook:
         assert response.get("status") == "pending"
 
     @pytest.mark.asyncio
-    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
     
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook._list_keys_async")
-    async def test_s3_key_hook_is_keys_unchanged_inactivity_error_async(self, 
mock_list_keys, mock_client):
+    async def test_s3_key_hook_is_keys_unchanged_inactivity_error_async(self, 
mock_list_keys):
         """
         Test is_key_unchanged gives AirflowException.
         """
         mock_list_keys.return_value = []
 
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
+        mock_client = AsyncMock()
 
         response = await s3_hook_async.is_keys_unchanged_async(
-            client=mock_client.return_value,
+            client=mock_client,
             bucket_name="test_bucket",
             prefix="test",
             inactivity_period=0,
@@ -953,20 +940,18 @@ class TestAwsS3Hook:
         }
 
     @pytest.mark.asyncio
-    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
     
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook._list_keys_async")
-    async def test_s3_key_hook_is_keys_unchanged_pending_async_without_tzinfo(
-        self, mock_list_keys, mock_client
-    ):
+    async def 
test_s3_key_hook_is_keys_unchanged_pending_async_without_tzinfo(self, 
mock_list_keys):
         """
         Test is_key_unchanged gives AirflowException.
         """
         mock_list_keys.return_value = []
 
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
+        mock_client = AsyncMock()
 
         response = await s3_hook_async.is_keys_unchanged_async(
-            client=mock_client.return_value,
+            client=mock_client,
             bucket_name="test_bucket",
             prefix="test",
             inactivity_period=1,
@@ -979,18 +964,18 @@ class TestAwsS3Hook:
         assert response.get("status") == "pending"
 
     @pytest.mark.asyncio
-    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
     
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook._list_keys_async")
-    async def 
test_s3_key_hook_is_keys_unchanged_pending_async_with_tzinfo(self, 
mock_list_keys, mock_client):
+    async def 
test_s3_key_hook_is_keys_unchanged_pending_async_with_tzinfo(self, 
mock_list_keys):
         """
         Test is_key_unchanged gives AirflowException.
         """
         mock_list_keys.return_value = []
 
         s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
+        mock_client = AsyncMock()
 
         response = await s3_hook_async.is_keys_unchanged_async(
-            client=mock_client.return_value,
+            client=mock_client,
             bucket_name="test_bucket",
             prefix="test",
             inactivity_period=1,
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_base.py 
b/providers/amazon/tests/unit/amazon/aws/triggers/test_base.py
index 5dcf9437630..7e64cb38e1a 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_base.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_base.py
@@ -18,7 +18,7 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 from unittest import mock
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock
 
 import pytest
 
@@ -38,7 +38,7 @@ class TestImplem(AwsBaseWaiterTrigger):
         super().__init__(**kwargs)
 
     def hook(self) -> AwsGenericHook:
-        return MagicMock()
+        return AsyncMock()
 
 
 class TestAwsBaseWaiterTrigger:
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_ec2.py 
b/providers/amazon/tests/unit/amazon/aws/triggers/test_ec2.py
index e2470e6bfc3..c91b8ecf62e 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_ec2.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_ec2.py
@@ -51,10 +51,10 @@ class TestEC2StateSensorTrigger:
 
     @pytest.mark.asyncio
     
@mock.patch("airflow.providers.amazon.aws.hooks.ec2.EC2Hook.get_instance_state_async")
-    @mock.patch("airflow.providers.amazon.aws.hooks.ec2.EC2Hook.async_conn")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.ec2.EC2Hook.get_async_conn")
     async def test_ec2_state_sensor_run(self, mock_async_conn, 
mock_get_instance_state_async):
         mock = AsyncMock()
-        mock_async_conn.__aenter__.return_value = mock
+        mock_async_conn.return_value.__aenter__.return_value = mock
         mock_get_instance_state_async.return_value = TEST_TARGET_STATE
 
         test_ec2_state_sensor = EC2StateSensorTrigger(
@@ -73,12 +73,12 @@ class TestEC2StateSensorTrigger:
     @pytest.mark.asyncio
     @mock.patch("asyncio.sleep")
     
@mock.patch("airflow.providers.amazon.aws.hooks.ec2.EC2Hook.get_instance_state_async")
-    @mock.patch("airflow.providers.amazon.aws.hooks.ec2.EC2Hook.async_conn")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.ec2.EC2Hook.get_async_conn")
     async def test_ec2_state_sensor_run_multiple(
         self, mock_async_conn, mock_get_instance_state_async, mock_sleep
     ):
         mock = AsyncMock()
-        mock_async_conn.__aenter__.return_value = mock
+        mock_async_conn.return_value.__aenter__.return_value = mock
         mock_get_instance_state_async.side_effect = ["test-state", 
TEST_TARGET_STATE]
         mock_sleep.return_value = True
 
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_ecs.py 
b/providers/amazon/tests/unit/amazon/aws/triggers/test_ecs.py
index a5c11d92986..8b15d9fcf7a 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_ecs.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_ecs.py
@@ -36,12 +36,12 @@ if TYPE_CHECKING:
 
 class TestTaskDoneTrigger:
     @pytest.mark.asyncio
-    @mock.patch.object(EcsHook, "async_conn")
+    @mock.patch.object(EcsHook, "get_async_conn")
     # this mock is only necessary to avoid a "No module named 'aiobotocore'" 
error in the LatestBoto CI step
-    @mock.patch.object(AwsLogsHook, "async_conn")
+    @mock.patch.object(AwsLogsHook, "get_async_conn")
     async def test_run_until_error(self, _, client_mock):
         a_mock = mock.MagicMock()
-        client_mock.__aenter__.return_value = a_mock
+        client_mock.return_value.__aenter__.return_value = a_mock
         wait_mock = AsyncMock()
         wait_mock.side_effect = [
             WaiterError("name", "reason", {"tasks": [{"lastStatus": 
"my_status"}]}),
@@ -57,12 +57,12 @@ class TestTaskDoneTrigger:
         assert wait_mock.call_count == 3
 
     @pytest.mark.asyncio
-    @mock.patch.object(EcsHook, "async_conn")
+    @mock.patch.object(EcsHook, "get_async_conn")
     # this mock is only necessary to avoid a "No module named 'aiobotocore'" 
error in the LatestBoto CI step
-    @mock.patch.object(AwsLogsHook, "async_conn")
+    @mock.patch.object(AwsLogsHook, "get_async_conn")
     async def test_run_until_timeout(self, _, client_mock):
         a_mock = mock.MagicMock()
-        client_mock.__aenter__.return_value = a_mock
+        client_mock.return_value.__aenter__.return_value = a_mock
         wait_mock = AsyncMock()
         wait_mock.side_effect = WaiterError("name", "reason", {"tasks": 
[{"lastStatus": "my_status"}]})
         a_mock.get_waiter().wait = wait_mock
@@ -76,12 +76,12 @@ class TestTaskDoneTrigger:
         assert "max attempts" in str(err.value)
 
     @pytest.mark.asyncio
-    @mock.patch.object(EcsHook, "async_conn")
+    @mock.patch.object(EcsHook, "get_async_conn")
     # this mock is only necessary to avoid a "No module named 'aiobotocore'" 
error in the LatestBoto CI step
-    @mock.patch.object(AwsLogsHook, "async_conn")
+    @mock.patch.object(AwsLogsHook, "get_async_conn")
     async def test_run_success(self, _, client_mock):
         a_mock = mock.MagicMock()
-        client_mock.__aenter__.return_value = a_mock
+        client_mock.return_value.__aenter__.return_value = a_mock
         wait_mock = AsyncMock()
         a_mock.get_waiter().wait = wait_mock
 
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_eks.py 
b/providers/amazon/tests/unit/amazon/aws/triggers/test_eks.py
index 16f49a769d4..b41994ea226 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_eks.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_eks.py
@@ -39,11 +39,11 @@ FARGATE_PROFILES = ["p1", "p2"]
 
 class TestEksTrigger:
     def setup_method(self):
-        self.async_conn_patcher = 
patch("airflow.providers.amazon.aws.hooks.eks.EksHook.async_conn")
+        self.async_conn_patcher = 
patch("airflow.providers.amazon.aws.hooks.eks.EksHook.get_async_conn")
         self.mock_async_conn = self.async_conn_patcher.start()
 
         self.mock_client = AsyncMock()
-        self.mock_async_conn.__aenter__.return_value = self.mock_client
+        self.mock_async_conn.return_value.__aenter__.return_value = 
self.mock_client
 
         self.async_wait_patcher = patch(
             "airflow.providers.amazon.aws.triggers.eks.async_wait", 
return_value=True
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune.py 
b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune.py
index 6e046efa5c0..bcab5f8e5a2 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune.py
@@ -47,9 +47,9 @@ class TestNeptuneClusterAvailableTrigger:
 
     @pytest.mark.asyncio
     
@mock.patch("airflow.providers.amazon.aws.hooks.neptune.NeptuneHook.get_waiter")
-    
@mock.patch("airflow.providers.amazon.aws.hooks.neptune.NeptuneHook.async_conn")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.neptune.NeptuneHook.get_async_conn")
     async def test_run_success(self, mock_async_conn, mock_get_waiter):
-        mock_async_conn.__aenter__.return_value = "available"
+        mock_async_conn.return_value.__aenter__.return_value = "available"
         mock_get_waiter().wait = AsyncMock()
         trigger = NeptuneClusterAvailableTrigger(db_cluster_id=CLUSTER_ID)
         generator = trigger.run()
@@ -73,9 +73,9 @@ class TestNeptuneClusterStoppedTrigger:
 
     @pytest.mark.asyncio
     
@mock.patch("airflow.providers.amazon.aws.hooks.neptune.NeptuneHook.get_waiter")
-    
@mock.patch("airflow.providers.amazon.aws.hooks.neptune.NeptuneHook.async_conn")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.neptune.NeptuneHook.get_async_conn")
     async def test_run_success(self, mock_async_conn, mock_get_waiter):
-        mock_async_conn.__aenter__.return_value = "stopped"
+        mock_async_conn.return_value.__aenter__.return_value = "stopped"
         mock_get_waiter().wait = AsyncMock()
         trigger = NeptuneClusterStoppedTrigger(db_cluster_id=CLUSTER_ID)
         generator = trigger.run()
@@ -102,9 +102,9 @@ class TestNeptuneClusterInstancesAvailableTrigger:
 
     @pytest.mark.asyncio
     
@mock.patch("airflow.providers.amazon.aws.hooks.neptune.NeptuneHook.get_waiter")
-    
@mock.patch("airflow.providers.amazon.aws.hooks.neptune.NeptuneHook.async_conn")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.neptune.NeptuneHook.get_async_conn")
     async def test_run_success(self, mock_async_conn, mock_get_waiter):
-        mock_async_conn.__aenter__.return_value = "available"
+        mock_async_conn.return_value.__aenter__.return_value = "available"
         mock_get_waiter().wait = AsyncMock()
         trigger = 
NeptuneClusterInstancesAvailableTrigger(db_cluster_id=CLUSTER_ID)
         generator = trigger.run()
@@ -114,10 +114,10 @@ class TestNeptuneClusterInstancesAvailableTrigger:
         assert mock_get_waiter().wait.call_count == 1
 
     @pytest.mark.asyncio
-    
@mock.patch("airflow.providers.amazon.aws.hooks.neptune.NeptuneHook.async_conn")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.neptune.NeptuneHook.get_async_conn")
     async def test_run_fail(self, mock_async_conn):
         a_mock = mock.MagicMock()
-        mock_async_conn.__aenter__.return_value = a_mock
+        mock_async_conn.return_value.__aenter__.return_value = a_mock
         wait_mock = AsyncMock()
         wait_mock.side_effect = WaiterError("name", "reason", {"test": 
[{"lastStatus": "my_status"}]})
         a_mock.get_waiter().wait = wait_mock
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py 
b/providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py
index 14c79f1e462..5043e2bb377 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py
@@ -52,12 +52,12 @@ class TestS3KeyTrigger:
         }
 
     @pytest.mark.asyncio
-    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
+    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_async_conn")
     async def test_run_success(self, mock_client):
         """
         Test if the task is run is in triggerr successfully.
         """
-        mock_client.return_value.check_key.return_value = True
+        mock_client.return_value.return_value.check_key.return_value = True
         trigger = S3KeyTrigger(bucket_key="s3://test_bucket/file", 
bucket_name="test_bucket")
         task = asyncio.create_task(trigger.run().__anext__())
         await asyncio.sleep(0.5)
@@ -67,7 +67,7 @@ class TestS3KeyTrigger:
 
     @pytest.mark.asyncio
     
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.check_key_async")
-    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
+    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_async_conn")
     async def test_run_pending(self, mock_client, mock_check_key_async):
         """
         Test if the task is run is in trigger successfully and set check_key 
to return false.
@@ -115,10 +115,10 @@ class TestS3KeysUnchangedTrigger:
         }
 
     @pytest.mark.asyncio
-    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
+    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_async_conn")
     async def test_run_wait(self, mock_client):
         """Test if the task is run in trigger successfully."""
-        mock_client.return_value.check_key.return_value = True
+        mock_client.return_value.return_value.check_key.return_value = True
         trigger = S3KeysUnchangedTrigger(bucket_name="test_bucket", 
prefix="test")
         with mock_client:
             task = asyncio.create_task(trigger.run().__anext__())
@@ -135,7 +135,7 @@ class TestS3KeysUnchangedTrigger:
             S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test", 
inactivity_period=-100)
 
     @pytest.mark.asyncio
-    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
+    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_async_conn")
     
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.is_keys_unchanged_async")
     async def test_run_success(self, mock_is_keys_unchanged, mock_client):
         """
@@ -148,7 +148,7 @@ class TestS3KeysUnchangedTrigger:
         assert TriggerEvent({"status": "success"}) == actual
 
     @pytest.mark.asyncio
-    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
+    
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_async_conn")
     
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.is_keys_unchanged_async")
     async def test_run_pending(self, mock_is_keys_unchanged, mock_client):
         """Test if the task is run in triggerer successfully."""
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_sagemaker.py 
b/providers/amazon/tests/unit/amazon/aws/triggers/test_sagemaker.py
index a69a7837ded..0f1851df048 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_sagemaker.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_sagemaker.py
@@ -60,9 +60,9 @@ class TestSagemakerTrigger:
         ],
     )
     
@mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.get_waiter")
-    
@mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.async_conn")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.get_async_conn")
     async def test_sagemaker_trigger_run_all_job_types(self, mock_async_conn, 
mock_get_waiter, job_type):
-        mock_async_conn.__aenter__.return_value = mock.MagicMock()
+        mock_async_conn.return_value.__aenter__.return_value = mock.MagicMock()
 
         mock_get_waiter().wait = AsyncMock()
 

Reply via email to