This is an automated email from the ASF dual-hosted git repository.

jasonliu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4a5451632ea Fix CloudwatchTaskHandler display error (#54054)
4a5451632ea is described below

commit 4a5451632ea54e0f66c978f4da35f75acc18d144
Author: Jason(Zhe-You) Liu <[email protected]>
AuthorDate: Thu Dec 4 13:29:52 2025 +0800

    Fix CloudwatchTaskHandler display error (#54054)
    
    * Take over ash's fix
    
    * Fix datetime serialization error
    
    * Fix 'generator is not subscriptable' error
    
    * Fix test_cloudwatch_task_handler
    
    * Fix nits in code review
    
    * Add CloudWatchLogEvent type
    
    * Correct return type of .stream method
    
    - The .stream method should return gen[str], but it returned gen[dict] in
      the previous fix
    - Making _parse_cloudwatch_log_event return a str instead of a dict fixes
      the problem
    
    * Revert change in file_task_handler
    
    * Fix test_log_message
    
    * Revert file_task_handler change
    
    * Fix type annotation for CloudWatchRemoteLogIO
    
    * Fix review comments
    
    - consolidate str_logs
    - rename _event_to_dict
---
 .../src/airflow/providers/amazon/aws/hooks/logs.py | 12 +++-
 .../amazon/aws/log/cloudwatch_task_handler.py      | 70 +++++++++++++++-------
 .../amazon/aws/log/test_cloudwatch_task_handler.py | 34 ++++++-----
 3 files changed, 78 insertions(+), 38 deletions(-)

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/logs.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/logs.py
index 884e02478ba..144c45c9d46 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/logs.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/logs.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 
 import asyncio
 from collections.abc import AsyncGenerator, Generator
-from typing import Any
+from typing import Any, TypedDict
 
 from botocore.exceptions import ClientError
 
@@ -35,6 +35,14 @@ from airflow.utils.helpers import prune_dict
 NUM_CONSECUTIVE_EMPTY_RESPONSE_EXIT_THRESHOLD = 3
 
 
+class CloudWatchLogEvent(TypedDict):
+    """TypedDict for CloudWatch Log Event."""
+
+    timestamp: int
+    message: str
+    ingestionTime: int
+
+
 class AwsLogsHook(AwsBaseHook):
     """
     Interact with Amazon CloudWatch Logs.
@@ -67,7 +75,7 @@ class AwsLogsHook(AwsBaseHook):
         start_from_head: bool | None = None,
         continuation_token: ContinuationToken | None = None,
         end_time: int | None = None,
-    ) -> Generator:
+    ) -> Generator[CloudWatchLogEvent, None, None]:
         """
         Return a generator for log items in a single stream; yields all items available at the current moment.
 
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py b/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
index 79d56e7b6ad..f510daca8da 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
@@ -22,6 +22,7 @@ import copy
 import json
 import logging
 import os
+from collections.abc import Generator
 from datetime import date, datetime, timedelta, timezone
 from functools import cached_property
 from pathlib import Path
@@ -40,8 +41,15 @@ if TYPE_CHECKING:
     import structlog.typing
 
     from airflow.models.taskinstance import TaskInstance
+    from airflow.providers.amazon.aws.hooks.logs import CloudWatchLogEvent
     from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
-    from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo
+    from airflow.utils.log.file_task_handler import (
+        LogMessages,
+        LogResponse,
+        LogSourceInfo,
+        RawLogStream,
+        StreamingLogResponse,
+    )
 
 
 def json_serialize_legacy(value: Any) -> str | None:
@@ -163,20 +171,31 @@ class CloudWatchRemoteLogIO(LoggingMixin):  # noqa: D101
         self.close()
         return
 
-    def read(self, relative_path, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMessages | None]:
-        logs: LogMessages | None = []
+    def read(self, relative_path: str, ti: RuntimeTI) -> LogResponse:
+        messages, logs = self.stream(relative_path, ti)
+        str_logs: list[str] = [f"{msg}\n" for group in logs for msg in group]
+
+        return messages, str_logs
+
+    def stream(self, relative_path: str, ti: RuntimeTI) -> StreamingLogResponse:
+        logs: list[RawLogStream] = []
         messages = [
             f"Reading remote log from Cloudwatch log_group: {self.log_group} 
log_stream: {relative_path}"
         ]
         try:
-            logs = [self.get_cloudwatch_logs(relative_path, ti)]
+            gen: RawLogStream = (
+                self._parse_log_event_as_dumped_json(event)
+                for event in self.get_cloudwatch_logs(relative_path, ti)
+            )
+            logs = [gen]
         except Exception as e:
-            logs = None
             messages.append(str(e))
 
         return messages, logs
 
-    def get_cloudwatch_logs(self, stream_name: str, task_instance: RuntimeTI):
+    def get_cloudwatch_logs(
+        self, stream_name: str, task_instance: RuntimeTI
+    ) -> Generator[CloudWatchLogEvent, None, None]:
         """
         Return all logs from the given log stream.
 
@@ -192,29 +211,22 @@ class CloudWatchRemoteLogIO(LoggingMixin):  # noqa: D101
             if (end_date := getattr(task_instance, "end_date", None)) is None
             else datetime_to_epoch_utc_ms(end_date + timedelta(seconds=30))
         )
-        events = self.hook.get_log_events(
+        return self.hook.get_log_events(
             log_group=self.log_group,
             log_stream_name=stream_name,
             end_time=end_time,
         )
-        return "\n".join(self._event_to_str(event) for event in events)
 
-    def _event_to_dict(self, event: dict) -> dict:
+    def _parse_log_event_as_dumped_json(self, event: CloudWatchLogEvent) -> str:
         event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc).isoformat()
-        message = event["message"]
+        event_msg = event["message"]
         try:
-            message = json.loads(message)
+            message = json.loads(event_msg)
             message["timestamp"] = event_dt
-            return message
         except Exception:
-            return {"timestamp": event_dt, "event": message}
+            message = {"timestamp": event_dt, "event": event_msg}
 
-    def _event_to_str(self, event: dict) -> str:
-        event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
-        # Format a datetime object to a string in Zulu time without milliseconds.
-        formatted_event_dt = event_dt.strftime("%Y-%m-%dT%H:%M:%SZ")
-        message = event["message"]
-        return f"[{formatted_event_dt}] {message}"
+        return json.dumps(message)
 
 
 class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
@@ -291,4 +303,22 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
     ) -> tuple[LogSourceInfo, LogMessages]:
         stream_name = self._render_filename(task_instance, try_number)
         messages, logs = self.io.read(stream_name, task_instance)
-        return messages, logs or []
+
+        messages = [
+            f"Reading remote log from Cloudwatch log_group: 
{self.io.log_group} log_stream: {stream_name}"
+        ]
+        try:
+            events = self.io.get_cloudwatch_logs(stream_name, task_instance)
+            logs = ["\n".join(self._event_to_str(event) for event in events)]
+        except Exception as e:
+            logs = []
+            messages.append(str(e))
+
+        return messages, logs
+
+    def _event_to_str(self, event: CloudWatchLogEvent) -> str:
+        event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
+        # Format a datetime object to a string in Zulu time without milliseconds.
+        formatted_event_dt = event_dt.strftime("%Y-%m-%dT%H:%M:%SZ")
+        message = event["message"]
+        return f"[{formatted_event_dt}] {message}"
diff --git a/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py b/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py
index 01eebfa0b1f..e0f9e169c8a 100644
--- a/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py
+++ b/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py
@@ -159,23 +159,9 @@ class TestCloudRemoteLogIO:
             assert metadata == [
                 f"Reading remote log from Cloudwatch log_group: log_group_name 
log_stream: {stream_name}"
             ]
-            assert logs == ['[2025-03-27T21:58:01Z] {"foo": "bar", "event": "Hi", "level": "info"}']
-
-    def test_event_to_str(self):
-        handler = self.subject
-        current_time = int(time.time()) * 1000
-        events = [
-            {"timestamp": current_time - 2000, "message": "First"},
-            {"timestamp": current_time - 1000, "message": "Second"},
-            {"timestamp": current_time, "message": "Third"},
-        ]
-        assert [handler._event_to_str(event) for event in events] == (
-            [
-                f"[{get_time_str(current_time - 2000)}] First",
-                f"[{get_time_str(current_time - 1000)}] Second",
-                f"[{get_time_str(current_time)}] Third",
+            assert logs == [
+                '{"foo": "bar", "event": "Hi", "level": "info", "timestamp": 
"2025-03-27T21:58:01.002000+00:00"}\n'
             ]
-        )
 
 
 @pytest.mark.db_test
@@ -426,6 +412,22 @@ class TestCloudwatchTaskHandler:
             filename_template=None,
         )
 
+    def test_event_to_str(self):
+        handler = self.cloudwatch_task_handler
+        current_time = int(time.time()) * 1000
+        events = [
+            {"timestamp": current_time - 2000, "message": "First"},
+            {"timestamp": current_time - 1000, "message": "Second"},
+            {"timestamp": current_time, "message": "Third"},
+        ]
+        assert [handler._event_to_str(event) for event in events] == (
+            [
+                f"[{get_time_str(current_time - 2000)}] First",
+                f"[{get_time_str(current_time - 1000)}] Second",
+                f"[{get_time_str(current_time)}] Third",
+            ]
+        )
+
 
 def generate_log_events(conn, log_group_name, log_stream_name, log_events):
     conn.create_log_group(logGroupName=log_group_name)
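
For completeness, a small sketch of the Zulu-time rendering that
_event_to_str (now moved onto the handler) applies, matching what the
relocated test_event_to_str asserts; event_to_str below mirrors the patch,
while the sample event is illustrative:

    import time
    from datetime import datetime, timezone

    def event_to_str(event: dict) -> str:
        event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
        # Zulu time without milliseconds, e.g. 2025-03-27T21:58:01Z
        formatted_event_dt = event_dt.strftime("%Y-%m-%dT%H:%M:%SZ")
        return f"[{formatted_event_dt}] {event['message']}"

    current_time = int(time.time()) * 1000
    print(event_to_str({"timestamp": current_time, "message": "Third"}))
    # e.g. [2025-03-27T21:58:01Z] Third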
