This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3689cee485 Use a continuation token to get logs in ecs (#31824)
3689cee485 is described below
commit 3689cee485215651bdb5ef434f24ab8774995a37
Author: Raphaël Vandon <[email protected]>
AuthorDate: Thu Jun 15 13:25:42 2023 -0700
Use a continuation token to get logs in ecs (#31824)
---------
Co-authored-by: Niko Oliveira <[email protected]>
---
airflow/providers/amazon/aws/hooks/ecs.py | 13 ++++++++-----
airflow/providers/amazon/aws/hooks/logs.py | 27 +++++++++++++++++++--------
tests/providers/amazon/aws/hooks/test_ecs.py | 16 ++++++++++++++++
3 files changed, 43 insertions(+), 13 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/ecs.py
b/airflow/providers/amazon/aws/hooks/ecs.py
index 5f74b4c138..17119d5968 100644
--- a/airflow/providers/amazon/aws/hooks/ecs.py
+++ b/airflow/providers/amazon/aws/hooks/ecs.py
@@ -171,17 +171,20 @@ class EcsTaskLogFetcher(Thread):
self.hook = AwsLogsHook(aws_conn_id=aws_conn_id,
region_name=region_name)
def run(self) -> None:
- logs_to_skip = 0
+ continuation_token = AwsLogsHook.ContinuationToken()
while not self.is_stopped():
time.sleep(self.fetch_interval.total_seconds())
- log_events = self._get_log_events(logs_to_skip)
+ log_events = self._get_log_events(continuation_token)
for log_event in log_events:
self.logger.info(self._event_to_str(log_event))
- logs_to_skip += 1
- def _get_log_events(self, skip: int = 0) -> Generator:
+ def _get_log_events(self, skip_token: AwsLogsHook.ContinuationToken | None
= None) -> Generator:
+ if skip_token is None:
+ skip_token = AwsLogsHook.ContinuationToken()
try:
- yield from self.hook.get_log_events(self.log_group,
self.log_stream_name, skip=skip)
+ yield from self.hook.get_log_events(
+ self.log_group, self.log_stream_name,
continuation_token=skip_token
+ )
except ClientError as error:
if error.response["Error"]["Code"] != "ResourceNotFoundException":
self.logger.warning("Error on retrieving Cloudwatch log
events", error)
diff --git a/airflow/providers/amazon/aws/hooks/logs.py
b/airflow/providers/amazon/aws/hooks/logs.py
index 6680e1939e..2dff0aaaf3 100644
--- a/airflow/providers/amazon/aws/hooks/logs.py
+++ b/airflow/providers/amazon/aws/hooks/logs.py
@@ -50,6 +50,12 @@ class AwsLogsHook(AwsBaseHook):
kwargs["client_type"] = "logs"
super().__init__(*args, **kwargs)
+ class ContinuationToken:
+ """Just a wrapper around a str token to allow updating it from the
caller."""
+
+ def __init__(self):
+ self.value: str | None = None
+
def get_log_events(
self,
log_group: str,
@@ -57,6 +63,7 @@ class AwsLogsHook(AwsBaseHook):
start_time: int = 0,
skip: int = 0,
start_from_head: bool = True,
+ continuation_token: ContinuationToken | None = None,
) -> Generator:
"""
A generator for log items in a single stream. This will yield all the
@@ -72,16 +79,20 @@ class AwsLogsHook(AwsBaseHook):
This is for when there are multiple entries at the same timestamp.
:param start_from_head: whether to start from the beginning (True) of
the log or
at the end of the log (False).
+ :param continuation_token: a token indicating where to read logs from.
+ Will be updated as this method reads new logs, to be reused in
subsequent calls.
:return: | A CloudWatch log event with the following key-value pairs:
| 'timestamp' (int): The time in milliseconds of the event.
| 'message' (str): The log event data.
| 'ingestionTime' (int): The time in milliseconds the event
was ingested.
"""
+ if continuation_token is None:
+ continuation_token = AwsLogsHook.ContinuationToken()
+
num_consecutive_empty_response = 0
- next_token = None
while True:
- if next_token is not None:
- token_arg: dict[str, str] = {"nextToken": next_token}
+ if continuation_token.value is not None:
+ token_arg: dict[str, str] = {"nextToken":
continuation_token.value}
else:
token_arg = {}
@@ -105,16 +116,16 @@ class AwsLogsHook(AwsBaseHook):
yield from events
+ if continuation_token.value == response["nextForwardToken"]:
+ return
+
if not event_count:
num_consecutive_empty_response += 1
if num_consecutive_empty_response >=
NUM_CONSECUTIVE_EMPTY_RESPONSE_EXIT_THRESHOLD:
# Exit if there are more than
NUM_CONSECUTIVE_EMPTY_RESPONSE_EXIT_THRESHOLD consecutive
# empty responses
return
- elif next_token != response["nextForwardToken"]:
- num_consecutive_empty_response = 0
else:
- # Exit if the value of nextForwardToken is same in subsequent
calls
- return
+ num_consecutive_empty_response = 0
- next_token = response["nextForwardToken"]
+ continuation_token.value = response["nextForwardToken"]
diff --git a/tests/providers/amazon/aws/hooks/test_ecs.py
b/tests/providers/amazon/aws/hooks/test_ecs.py
index d9a4f53fa8..d90ec65e52 100644
--- a/tests/providers/amazon/aws/hooks/test_ecs.py
+++ b/tests/providers/amazon/aws/hooks/test_ecs.py
@@ -18,12 +18,14 @@ from __future__ import annotations
from datetime import timedelta
from unittest import mock
+from unittest.mock import PropertyMock
import pytest
from botocore.exceptions import ClientError
from airflow.providers.amazon.aws.exceptions import EcsOperatorError,
EcsTaskFailToStart
from airflow.providers.amazon.aws.hooks.ecs import EcsHook, EcsTaskLogFetcher,
should_retry, should_retry_eni
+from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
DEFAULT_CONN_ID: str = "aws_default"
REGION: str = "us-east-1"
@@ -146,6 +148,20 @@ class TestEcsTaskLogFetcher:
with pytest.raises(Exception):
next(self.log_fetcher._get_log_events())
+ @mock.patch.object(AwsLogsHook, "conn", new_callable=PropertyMock)
+ def test_get_log_events_updates_token(self, logs_conn_mock):
+ logs_conn_mock().get_log_events.return_value = {
+ "events": ["my_event"],
+ "nextForwardToken": "my_next_token",
+ }
+
+ token = AwsLogsHook.ContinuationToken()
+ list(self.log_fetcher._get_log_events(token))
+
+ assert token.value == "my_next_token"
+ # 2 calls expected, it's only on the second one that the stop
condition old_token == next_token is met
+ assert logs_conn_mock().get_log_events.call_count == 2
+
def test_event_to_str(self):
events = [
{"timestamp": 1617400267123, "message": "First"},