This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 479719297f add async wait method to the "with logging" aws utils 
(#32055)
479719297f is described below

commit 479719297ff4efa8373dc7b6909bfc59a5444c3a
Author: Raphaël Vandon <[email protected]>
AuthorDate: Thu Jun 22 10:21:32 2023 -0700

    add async wait method to the "with logging" aws utils (#32055)
    
    Also changed the status formatting in the logs so that it is skipped when 
INFO-level logging is not enabled
---
 .../amazon/aws/utils/waiter_with_logging.py        | 90 ++++++++++++++++++----
 .../amazon/aws/utils/test_waiter_with_logging.py   | 59 +++++++++++++-
 2 files changed, 133 insertions(+), 16 deletions(-)

diff --git a/airflow/providers/amazon/aws/utils/waiter_with_logging.py 
b/airflow/providers/amazon/aws/utils/waiter_with_logging.py
index 8c9e33077f..b883e36bdb 100644
--- a/airflow/providers/amazon/aws/utils/waiter_with_logging.py
+++ b/airflow/providers/amazon/aws/utils/waiter_with_logging.py
@@ -17,8 +17,10 @@
 
 from __future__ import annotations
 
+import asyncio
 import logging
 import time
+from typing import Any
 
 import jmespath
 from botocore.exceptions import WaiterError
@@ -31,10 +33,10 @@ def wait(
     waiter: Waiter,
     waiter_delay: int,
     max_attempts: int,
-    args: dict,
+    args: dict[str, Any],
     failure_message: str,
     status_message: str,
-    status_args: list,
+    status_args: list[str],
 ) -> None:
     """
     Use a boto waiter to poll an AWS service for the specified state. Although 
this function
@@ -47,7 +49,7 @@ def wait(
     :param args: The arguments to pass to the waiter.
     :param failure_message: The message to log if a failure state is reached.
     :param status_message: The message logged when printing the status of the 
service.
-    :param status_args: A list containing the arguments to retrieve status 
information from
+    :param status_args: A list containing the JMESPath queries to retrieve 
status information from
         the waiter response.
         e.g.
         response = {"Cluster": {"state": "CREATING"}}
@@ -68,23 +70,83 @@ def wait(
         except WaiterError as error:
             if "terminal failure" in str(error):
                 raise AirflowException(f"{failure_message}: {error}")
-            status_string = _format_status_string(status_args, 
error.last_response)
-            log.info("%s: %s", status_message, status_string)
+
+            log.info("%s: %s", status_message, 
_LazyStatusFormatter(status_args, error.last_response))
+            if attempt >= max_attempts:
+                raise AirflowException("Waiter error: max attempts reached")
+
             time.sleep(waiter_delay)
 
+
+async def async_wait(
+    waiter: Waiter,
+    waiter_delay: int,
+    max_attempts: int,
+    args: dict[str, Any],
+    failure_message: str,
+    status_message: str,
+    status_args: list[str],
+):
+    """
+    Use an async boto waiter to poll an AWS service for the specified state. 
Although this function
+    uses boto waiters to poll the state of the service, it logs the response 
of the service
+    after every attempt, which is not currently supported by boto waiters.
+
+    :param waiter: The boto waiter to use.
+    :param waiter_delay: The amount of time in seconds to wait between 
attempts.
+    :param max_attempts: The maximum number of attempts to be made.
+    :param args: The arguments to pass to the waiter.
+    :param failure_message: The message to log if a failure state is reached.
+    :param status_message: The message logged when printing the status of the 
service.
+    :param status_args: A list containing the JMESPath queries to retrieve 
status information from
+        the waiter response.
+        e.g.
+        response = {"Cluster": {"state": "CREATING"}}
+        status_args = ["Cluster.state"]
+
+        response = {
+        "Clusters": [{"state": "CREATING", "details": "User initiated."},]
+        }
+        status_args = ["Clusters[0].state", "Clusters[0].details"]
+    """
+    log = logging.getLogger(__name__)
+    attempt = 0
+    while True:
+        attempt += 1
+        try:
+            await waiter.wait(**args, WaiterConfig={"MaxAttempts": 1})
+            break
+        except WaiterError as error:
+            if "terminal failure" in str(error):
+                raise AirflowException(f"{failure_message}: {error}")
+
+            log.info("%s: %s", status_message, 
_LazyStatusFormatter(status_args, error.last_response))
             if attempt >= max_attempts:
                 raise AirflowException("Waiter error: max attempts reached")
 
+            await asyncio.sleep(waiter_delay)
+
 
-def _format_status_string(args, response):
+class _LazyStatusFormatter:
     """
-    Loops through the supplied args list and generates a string
-    which contains values from the waiter response.
+    a wrapper containing the info necessary to extract the status from a 
response,
+    that'll only compute the value when necessary.
+    Used to avoid computations if the logs are disabled at the given level.
     """
-    values = []
-    for arg in args:
-        value = jmespath.search(arg, response)
-        if value is not None and value != "":
-            values.append(str(value))
 
-    return " - ".join(values)
+    def __init__(self, jmespath_queries: list[str], response: dict[str, Any]):
+        self.jmespath_queries = jmespath_queries
+        self.response = response
+
+    def __str__(self):
+        """
+        Loops through the supplied args list and generates a string
+        which contains values from the waiter response.
+        """
+        values = []
+        for query in self.jmespath_queries:
+            value = jmespath.search(query, self.response)
+            if value is not None and value != "":
+                values.append(str(value))
+
+        return " - ".join(values)
diff --git a/tests/providers/amazon/aws/utils/test_waiter_with_logging.py 
b/tests/providers/amazon/aws/utils/test_waiter_with_logging.py
index 2ca74936d7..4580c21054 100644
--- a/tests/providers/amazon/aws/utils/test_waiter_with_logging.py
+++ b/tests/providers/amazon/aws/utils/test_waiter_with_logging.py
@@ -20,12 +20,13 @@ from __future__ import annotations
 import logging
 from typing import Any
 from unittest import mock
+from unittest.mock import AsyncMock
 
 import pytest
 from botocore.exceptions import WaiterError
 
 from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
+from airflow.providers.amazon.aws.utils.waiter_with_logging import 
_LazyStatusFormatter, async_wait, wait
 
 
 def generate_response(state: str) -> dict[str, Any]:
@@ -63,7 +64,7 @@ class TestWaiter:
                 "MaxAttempts": 1,
             },
         )
-        mock_waiter.wait.call_count == 3
+        assert mock_waiter.wait.call_count == 3
         mock_sleep.assert_called_with(123)
         assert (
             caplog.record_tuples
@@ -77,6 +78,36 @@ class TestWaiter:
             * 2
         )
 
+    @pytest.mark.asyncio
+    async def test_async_wait(self, caplog):
+        mock_waiter = mock.MagicMock()
+        error = WaiterError(
+            name="test_waiter",
+            reason="test_reason",
+            last_response=generate_response("Pending"),
+        )
+        mock_waiter.wait = AsyncMock()
+        mock_waiter.wait.side_effect = [error, error, True]
+
+        await async_wait(
+            waiter=mock_waiter,
+            waiter_delay=0,
+            max_attempts=456,
+            args={"test_arg": "test_value"},
+            failure_message="test failure message",
+            status_message="test status message",
+            status_args=["Status.State"],
+        )
+
+        mock_waiter.wait.assert_called_with(
+            **{"test_arg": "test_value"},
+            WaiterConfig={
+                "MaxAttempts": 1,
+            },
+        )
+        assert mock_waiter.wait.call_count == 3
+        assert caplog.messages == ["test status message: Pending", "test 
status message: Pending"]
+
     @mock.patch("time.sleep")
     def test_wait_max_attempts_exceeded(self, mock_sleep, caplog):
         mock_sleep.return_value = True
@@ -302,3 +333,27 @@ class TestWaiter:
             ]
             * 2
         )
+
+    @mock.patch.object(_LazyStatusFormatter, "__str__")
+    def test_status_formatting_not_done_if_higher_log_level(self, 
status_format_mock: mock.MagicMock, caplog):
+        mock_waiter = mock.MagicMock()
+        error = WaiterError(
+            name="test_waiter",
+            reason="test_reason",
+            last_response=generate_response("Pending"),
+        )
+        mock_waiter.wait.side_effect = [error, error, True]
+
+        with caplog.at_level(level=logging.WARNING):
+            wait(
+                waiter=mock_waiter,
+                waiter_delay=0,
+                max_attempts=456,
+                args={"test_arg": "test_value"},
+                failure_message="test failure message",
+                status_message="test status message",
+                status_args=["Status.State"],
+            )
+
+        assert len(caplog.messages) == 0
+        status_format_mock.assert_not_called()

Reply via email to