This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 479719297f add async wait method to the "with logging" aws utils
(#32055)
479719297f is described below
commit 479719297ff4efa8373dc7b6909bfc59a5444c3a
Author: Raphaël Vandon <[email protected]>
AuthorDate: Thu Jun 22 10:21:32 2023 -0700
add async wait method to the "with logging" aws utils (#32055)
Also made the status formatting in the logs lazy, so that it is skipped when
the log level does not include INFO.
---
.../amazon/aws/utils/waiter_with_logging.py | 90 ++++++++++++++++++----
.../amazon/aws/utils/test_waiter_with_logging.py | 59 +++++++++++++-
2 files changed, 133 insertions(+), 16 deletions(-)
diff --git a/airflow/providers/amazon/aws/utils/waiter_with_logging.py
b/airflow/providers/amazon/aws/utils/waiter_with_logging.py
index 8c9e33077f..b883e36bdb 100644
--- a/airflow/providers/amazon/aws/utils/waiter_with_logging.py
+++ b/airflow/providers/amazon/aws/utils/waiter_with_logging.py
@@ -17,8 +17,10 @@
from __future__ import annotations
+import asyncio
import logging
import time
+from typing import Any
import jmespath
from botocore.exceptions import WaiterError
@@ -31,10 +33,10 @@ def wait(
waiter: Waiter,
waiter_delay: int,
max_attempts: int,
- args: dict,
+ args: dict[str, Any],
failure_message: str,
status_message: str,
- status_args: list,
+ status_args: list[str],
) -> None:
"""
Use a boto waiter to poll an AWS service for the specified state. Although
this function
@@ -47,7 +49,7 @@ def wait(
:param args: The arguments to pass to the waiter.
:param failure_message: The message to log if a failure state is reached.
:param status_message: The message logged when printing the status of the
service.
- :param status_args: A list containing the arguments to retrieve status
information from
+ :param status_args: A list containing the JMESPath queries to retrieve
status information from
the waiter response.
e.g.
response = {"Cluster": {"state": "CREATING"}}
@@ -68,23 +70,83 @@ def wait(
except WaiterError as error:
if "terminal failure" in str(error):
raise AirflowException(f"{failure_message}: {error}")
- status_string = _format_status_string(status_args,
error.last_response)
- log.info("%s: %s", status_message, status_string)
+
+ log.info("%s: %s", status_message,
_LazyStatusFormatter(status_args, error.last_response))
+ if attempt >= max_attempts:
+ raise AirflowException("Waiter error: max attempts reached")
+
time.sleep(waiter_delay)
+
+async def async_wait(
+ waiter: Waiter,
+ waiter_delay: int,
+ max_attempts: int,
+ args: dict[str, Any],
+ failure_message: str,
+ status_message: str,
+ status_args: list[str],
+):
+ """
+ Use an async boto waiter to poll an AWS service for the specified state.
Although this function
+ uses boto waiters to poll the state of the service, it logs the response
of the service
+ after every attempt, which is not currently supported by boto waiters.
+
+ :param waiter: The boto waiter to use.
+ :param waiter_delay: The amount of time in seconds to wait between
attempts.
+ :param max_attempts: The maximum number of attempts to be made.
+ :param args: The arguments to pass to the waiter.
+ :param failure_message: The message to log if a failure state is reached.
+ :param status_message: The message logged when printing the status of the
service.
+ :param status_args: A list containing the JMESPath queries to retrieve
status information from
+ the waiter response.
+ e.g.
+ response = {"Cluster": {"state": "CREATING"}}
+ status_args = ["Cluster.state"]
+
+ response = {
+ "Clusters": [{"state": "CREATING", "details": "User initiated."},]
+ }
+ status_args = ["Clusters[0].state", "Clusters[0].details"]
+ """
+ log = logging.getLogger(__name__)
+ attempt = 0
+ while True:
+ attempt += 1
+ try:
+ await waiter.wait(**args, WaiterConfig={"MaxAttempts": 1})
+ break
+ except WaiterError as error:
+ if "terminal failure" in str(error):
+ raise AirflowException(f"{failure_message}: {error}")
+
+ log.info("%s: %s", status_message,
_LazyStatusFormatter(status_args, error.last_response))
if attempt >= max_attempts:
raise AirflowException("Waiter error: max attempts reached")
+ await asyncio.sleep(waiter_delay)
+
-def _format_status_string(args, response):
+class _LazyStatusFormatter:
"""
- Loops through the supplied args list and generates a string
- which contains values from the waiter response.
+ a wrapper containing the info necessary to extract the status from a
response,
+ that'll only compute the value when necessary.
+ Used to avoid computations if the logs are disabled at the given level.
"""
- values = []
- for arg in args:
- value = jmespath.search(arg, response)
- if value is not None and value != "":
- values.append(str(value))
- return " - ".join(values)
+ def __init__(self, jmespath_queries: list[str], response: dict[str, Any]):
+ self.jmespath_queries = jmespath_queries
+ self.response = response
+
+ def __str__(self):
+ """
+ Loops through the supplied args list and generates a string
+ which contains values from the waiter response.
+ """
+ values = []
+ for query in self.jmespath_queries:
+ value = jmespath.search(query, self.response)
+ if value is not None and value != "":
+ values.append(str(value))
+
+ return " - ".join(values)
diff --git a/tests/providers/amazon/aws/utils/test_waiter_with_logging.py
b/tests/providers/amazon/aws/utils/test_waiter_with_logging.py
index 2ca74936d7..4580c21054 100644
--- a/tests/providers/amazon/aws/utils/test_waiter_with_logging.py
+++ b/tests/providers/amazon/aws/utils/test_waiter_with_logging.py
@@ -20,12 +20,13 @@ from __future__ import annotations
import logging
from typing import Any
from unittest import mock
+from unittest.mock import AsyncMock
import pytest
from botocore.exceptions import WaiterError
from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
+from airflow.providers.amazon.aws.utils.waiter_with_logging import
_LazyStatusFormatter, async_wait, wait
def generate_response(state: str) -> dict[str, Any]:
@@ -63,7 +64,7 @@ class TestWaiter:
"MaxAttempts": 1,
},
)
- mock_waiter.wait.call_count == 3
+ assert mock_waiter.wait.call_count == 3
mock_sleep.assert_called_with(123)
assert (
caplog.record_tuples
@@ -77,6 +78,36 @@ class TestWaiter:
* 2
)
+ @pytest.mark.asyncio
+ async def test_async_wait(self, caplog):
+ mock_waiter = mock.MagicMock()
+ error = WaiterError(
+ name="test_waiter",
+ reason="test_reason",
+ last_response=generate_response("Pending"),
+ )
+ mock_waiter.wait = AsyncMock()
+ mock_waiter.wait.side_effect = [error, error, True]
+
+ await async_wait(
+ waiter=mock_waiter,
+ waiter_delay=0,
+ max_attempts=456,
+ args={"test_arg": "test_value"},
+ failure_message="test failure message",
+ status_message="test status message",
+ status_args=["Status.State"],
+ )
+
+ mock_waiter.wait.assert_called_with(
+ **{"test_arg": "test_value"},
+ WaiterConfig={
+ "MaxAttempts": 1,
+ },
+ )
+ assert mock_waiter.wait.call_count == 3
+ assert caplog.messages == ["test status message: Pending", "test
status message: Pending"]
+
@mock.patch("time.sleep")
def test_wait_max_attempts_exceeded(self, mock_sleep, caplog):
mock_sleep.return_value = True
@@ -302,3 +333,27 @@ class TestWaiter:
]
* 2
)
+
+ @mock.patch.object(_LazyStatusFormatter, "__str__")
+ def test_status_formatting_not_done_if_higher_log_level(self,
status_format_mock: mock.MagicMock, caplog):
+ mock_waiter = mock.MagicMock()
+ error = WaiterError(
+ name="test_waiter",
+ reason="test_reason",
+ last_response=generate_response("Pending"),
+ )
+ mock_waiter.wait.side_effect = [error, error, True]
+
+ with caplog.at_level(level=logging.WARNING):
+ wait(
+ waiter=mock_waiter,
+ waiter_delay=0,
+ max_attempts=456,
+ args={"test_arg": "test_value"},
+ failure_message="test failure message",
+ status_message="test status message",
+ status_args=["Status.State"],
+ )
+
+ assert len(caplog.messages) == 0
+ status_format_mock.assert_not_called()