This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3689cee485 Use a continuation token to get logs in ecs (#31824)
3689cee485 is described below

commit 3689cee485215651bdb5ef434f24ab8774995a37
Author: Raphaël Vandon <[email protected]>
AuthorDate: Thu Jun 15 13:25:42 2023 -0700

    Use a continuation token to get logs in ecs (#31824)
    
    ---------
    
    Co-authored-by: Niko Oliveira <[email protected]>
---
 airflow/providers/amazon/aws/hooks/ecs.py    | 13 ++++++++-----
 airflow/providers/amazon/aws/hooks/logs.py   | 27 +++++++++++++++++++--------
 tests/providers/amazon/aws/hooks/test_ecs.py | 16 ++++++++++++++++
 3 files changed, 43 insertions(+), 13 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/ecs.py 
b/airflow/providers/amazon/aws/hooks/ecs.py
index 5f74b4c138..17119d5968 100644
--- a/airflow/providers/amazon/aws/hooks/ecs.py
+++ b/airflow/providers/amazon/aws/hooks/ecs.py
@@ -171,17 +171,20 @@ class EcsTaskLogFetcher(Thread):
         self.hook = AwsLogsHook(aws_conn_id=aws_conn_id, 
region_name=region_name)
 
     def run(self) -> None:
-        logs_to_skip = 0
+        continuation_token = AwsLogsHook.ContinuationToken()
         while not self.is_stopped():
             time.sleep(self.fetch_interval.total_seconds())
-            log_events = self._get_log_events(logs_to_skip)
+            log_events = self._get_log_events(continuation_token)
             for log_event in log_events:
                 self.logger.info(self._event_to_str(log_event))
-                logs_to_skip += 1
 
-    def _get_log_events(self, skip: int = 0) -> Generator:
+    def _get_log_events(self, skip_token: AwsLogsHook.ContinuationToken | None 
= None) -> Generator:
+        if skip_token is None:
+            skip_token = AwsLogsHook.ContinuationToken()
         try:
-            yield from self.hook.get_log_events(self.log_group, 
self.log_stream_name, skip=skip)
+            yield from self.hook.get_log_events(
+                self.log_group, self.log_stream_name, 
continuation_token=skip_token
+            )
         except ClientError as error:
             if error.response["Error"]["Code"] != "ResourceNotFoundException":
                 self.logger.warning("Error on retrieving Cloudwatch log 
events", error)
diff --git a/airflow/providers/amazon/aws/hooks/logs.py 
b/airflow/providers/amazon/aws/hooks/logs.py
index 6680e1939e..2dff0aaaf3 100644
--- a/airflow/providers/amazon/aws/hooks/logs.py
+++ b/airflow/providers/amazon/aws/hooks/logs.py
@@ -50,6 +50,12 @@ class AwsLogsHook(AwsBaseHook):
         kwargs["client_type"] = "logs"
         super().__init__(*args, **kwargs)
 
+    class ContinuationToken:
+        """Just a wrapper around a str token to allow updating it from the 
caller."""
+
+        def __init__(self):
+            self.value: str | None = None
+
     def get_log_events(
         self,
         log_group: str,
@@ -57,6 +63,7 @@ class AwsLogsHook(AwsBaseHook):
         start_time: int = 0,
         skip: int = 0,
         start_from_head: bool = True,
+        continuation_token: ContinuationToken | None = None,
     ) -> Generator:
         """
         A generator for log items in a single stream. This will yield all the
@@ -72,16 +79,20 @@ class AwsLogsHook(AwsBaseHook):
             This is for when there are multiple entries at the same timestamp.
         :param start_from_head: whether to start from the beginning (True) of 
the log or
             at the end of the log (False).
+        :param continuation_token: a token indicating where to read logs from.
+            Will be updated as this method reads new logs, to be reused in 
subsequent calls.
         :return: | A CloudWatch log event with the following key-value pairs:
                  |   'timestamp' (int): The time in milliseconds of the event.
                  |   'message' (str): The log event data.
                  |   'ingestionTime' (int): The time in milliseconds the event 
was ingested.
         """
+        if continuation_token is None:
+            continuation_token = AwsLogsHook.ContinuationToken()
+
         num_consecutive_empty_response = 0
-        next_token = None
         while True:
-            if next_token is not None:
-                token_arg: dict[str, str] = {"nextToken": next_token}
+            if continuation_token.value is not None:
+                token_arg: dict[str, str] = {"nextToken": 
continuation_token.value}
             else:
                 token_arg = {}
 
@@ -105,16 +116,16 @@ class AwsLogsHook(AwsBaseHook):
 
             yield from events
 
+            if continuation_token.value == response["nextForwardToken"]:
+                return
+
             if not event_count:
                 num_consecutive_empty_response += 1
                 if num_consecutive_empty_response >= 
NUM_CONSECUTIVE_EMPTY_RESPONSE_EXIT_THRESHOLD:
                     # Exit if there are more than 
NUM_CONSECUTIVE_EMPTY_RESPONSE_EXIT_THRESHOLD consecutive
                     # empty responses
                     return
-            elif next_token != response["nextForwardToken"]:
-                num_consecutive_empty_response = 0
             else:
-                # Exit if the value of nextForwardToken is same in subsequent 
calls
-                return
+                num_consecutive_empty_response = 0
 
-            next_token = response["nextForwardToken"]
+            continuation_token.value = response["nextForwardToken"]
diff --git a/tests/providers/amazon/aws/hooks/test_ecs.py 
b/tests/providers/amazon/aws/hooks/test_ecs.py
index d9a4f53fa8..d90ec65e52 100644
--- a/tests/providers/amazon/aws/hooks/test_ecs.py
+++ b/tests/providers/amazon/aws/hooks/test_ecs.py
@@ -18,12 +18,14 @@ from __future__ import annotations
 
 from datetime import timedelta
 from unittest import mock
+from unittest.mock import PropertyMock
 
 import pytest
 from botocore.exceptions import ClientError
 
 from airflow.providers.amazon.aws.exceptions import EcsOperatorError, 
EcsTaskFailToStart
 from airflow.providers.amazon.aws.hooks.ecs import EcsHook, EcsTaskLogFetcher, 
should_retry, should_retry_eni
+from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
 
 DEFAULT_CONN_ID: str = "aws_default"
 REGION: str = "us-east-1"
@@ -146,6 +148,20 @@ class TestEcsTaskLogFetcher:
         with pytest.raises(Exception):
             next(self.log_fetcher._get_log_events())
 
+    @mock.patch.object(AwsLogsHook, "conn", new_callable=PropertyMock)
+    def test_get_log_events_updates_token(self, logs_conn_mock):
+        logs_conn_mock().get_log_events.return_value = {
+            "events": ["my_event"],
+            "nextForwardToken": "my_next_token",
+        }
+
+        token = AwsLogsHook.ContinuationToken()
+        list(self.log_fetcher._get_log_events(token))
+
+        assert token.value == "my_next_token"
+        # 2 calls expected, it's only on the second one that the stop 
condition old_token == next_token is met
+        assert logs_conn_mock().get_log_events.call_count == 2
+
     def test_event_to_str(self):
         events = [
             {"timestamp": 1617400267123, "message": "First"},

Reply via email to