This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ee54fe93768 Resolve OOM When Reading Large Logs in Webserver (#49470)
ee54fe93768 is described below
commit ee54fe93768e70da5bd2f0bb7398707181bee26a
Author: LIU ZHE YOU <[email protected]>
AuthorDate: Thu Jul 10 19:47:17 2025 +0800
Resolve OOM When Reading Large Logs in Webserver (#49470)
* Add note for new usage of LogMetadata
* Add _stream_parsed_lines_by_chunk
* Refactor _read_from_local/logs_server as return stream
* Refactor _interleave_logs with K-Way Merge
* Add _get_compatible_log_stream
* Refactor _read method to return stream with compatible interface
- Add compatible interface for executor, remote_logs
- Refactor skip log_pos with skip for each log source
* Refactor log_reader to adapt stream
* Fix _read_from_local open closed file error
* Refactor LogReader by yielding in batch
* Add ndjson header to get_log openapi schema
* Fix _add_log_from_parsed_log_streams_to_heap
- Add comparator for StructuredLogMessage
- Refactor parsed_log_streams from list to dict for removing empty logs
* Fix _interleave_logs dedupe logic
- should check the current logs with default timestamp
* Refactor test_log_handlers
- Fix events utils
- Add convert_list_to_stream, mock_parsed_logs_factory utils
- Fix the following test after refactoring FileTaskHandler
- test_file_task_handler_when_ti_value_is_invalid
- test_file_task_handler
- test_file_task_handler_running
- test_file_task_handler_rotate_size_limit
- test__read_when_local
- test__read_served_logs_checked_when_done_and_no_local_or_remote_logs
- test_interleave_interleaves
- test_interleave_logs_correct_ordering
- test_interleave_logs_correct_dedupe
- Add new test for refactoring FileTaskHandler
- test__stream_lines_by_chunk
- test__log_stream_to_parsed_log_stream
- test__sort_key
- test__is_sort_key_with_default_timestamp
- test__is_logs_stream_like
- test__add_log_from_parsed_log_streams_to_heap
* Move test_log_handlers utils to test_common
* Fix unit/celery/log_handlers test
* Fix mypy-providers static check
* Fix _get_compatible_log_stream
- sequential yield instead of parallel yield for all log_stream
* Fix amazon task_handler test
* Fix wasb task handler test
* Fix elasticsearch task handler test
* Fix opensearch task handler test
* Fix TaskLogReader buffer
- don't concat buffer with empty str, yield directly from buffer
* Fix test_log_reader
* Fix CloudWatchRemoteLogIO.read mypy
* Fix test_gcs_task_handler
* Fix core_api test_log
* Fix CloudWatchRemoteLogIO._event_to_str dt format
* Fix TestCloudRemoteLogIO.test_log_message
* Fix es/os task_handler convert_list_to_stream
* Fix compat tests
* Refactor es,os task handler for 3.0 compat
* Fix compat for RedisTaskHandler
* Fix ruff format for test_cloudwatch_task_handler after rebase
* Fix 2.10 compat TestCloudwatchTaskHandler
* Fix 3.0 compat test for celery, wasb
Fix wasb test, spelling
* Fix 3.0 compat test for gcs
* Fix 3.0 compat test for cloudwatch, s3
* Set get_log API default response format to JSON
* Remove "first_time_read" key in log metadata
* Remove "<source>_log_pos" key in log metadata
* Add LogStreamCounter for backward compatibility
* Remove "first_time_read" with backward "log_pos" for tests
- test_log_reader
- test_log_handlers
- test_cloudwatch_task_handler
- test_s3_task_handler
- celery test_log_handler
- test_gcs_task_handler
- test_wasb_task_handler
- fix redis_task_handler
- fix log_pos
* Fix RedisTaskHandler compatibility
* Fix chores in self review
- Fix typo in _read_from_logs_server
- Remove unused parameters in _stream_lines_by_chunk
- read_log_stream
- Fix doc string by removing outdated note
- Only add buffer for full_download
- Add test ndjson format for get_log API
* Fine-tune HEAP_DUMP_SIZE
* Replace get_compatible_output_log_stream with iter
* Remove buffer in log_reader
* Fix log_id not found compat for es_task_handler
* Fix review comments
- rename LogStreamCounter as LogStreamAccumulator
- simply for-yield with yield-from in log_reader
- add type annotation for LogStreamAccumulator
* Refactor LogStreamAccumulator._capture method
- use itertools.islice to get chunk
* Fix type hint, joinedload for ti.dag_run after merge
* Replace _sort_key as _create_sort_key
* Add _flush_logs_out_of_heap common util
* Fix review nits
- _is_logs_stream_like
- add type annotation
- reduce to 1 isinstance call
- construct log_streams in _get_compatible_log_stream inline
- use TypedDict for LogMetadata
- remove len(logs) to check empty
- revert typo of self.log_handler.read in log_reader
- log_stream_accumulator
- refactor flush logic
- make total_lines a property
- make stream as property
* Fix mypy errors after merge
* Fix redis task handler test
* Refactor _capture logic in LogStreamAccumulator
* Add comments for ignore LogMetadata TypedDict
* Add comment for offset; Fix comment for LogMessages
* Refactor with from_iterable, islice
* Fix nits in test
- refactor structured_logs fixtures in TestLogStreamAccumulator
- use f-string in test_file_task_handler
- assert actual value of _create_sort_key
- add detailed comments in test__add_log_from_parsed_log_streams_to_heap
* Refactor test_utils
* Add comment for lazy initialization
* Fix error handling for _stream_lines_by_chunk
* Fix mypy error after merge
* Fix final review nits
* Fix mypy error
---
.../api_fastapi/core_api/routes/public/log.py | 57 ++-
.../src/airflow/utils/log/file_task_handler.py | 452 ++++++++++++++---
airflow-core/src/airflow/utils/log/log_reader.py | 31 +-
.../airflow/utils/log/log_stream_accumulator.py | 155 ++++++
.../api_fastapi/core_api/routes/public/test_log.py | 27 +-
.../tests/unit/utils/log/test_log_reader.py | 45 +-
.../unit/utils/log/test_stream_accumulator.py | 165 +++++++
airflow-core/tests/unit/utils/test_log_handlers.py | 539 +++++++++++++++++----
.../tests_common/test_utils/file_task_handler.py | 76 +++
docs/spelling_wordlist.txt | 1 +
.../amazon/aws/log/cloudwatch_task_handler.py | 16 +-
.../amazon/aws/log/test_cloudwatch_task_handler.py | 49 +-
.../unit/amazon/aws/log/test_s3_task_handler.py | 2 +
.../unit/celery/log_handlers/test_log_handlers.py | 17 +-
.../providers/elasticsearch/log/es_task_handler.py | 15 +-
.../unit/elasticsearch/log/test_es_task_handler.py | 11 +-
.../unit/google/cloud/log/test_gcs_task_handler.py | 2 +
.../microsoft/azure/log/test_wasb_task_handler.py | 6 +-
.../providers/opensearch/log/os_task_handler.py | 18 +-
.../unit/opensearch/log/test_os_task_handler.py | 8 +-
.../providers/redis/log/redis_task_handler.py | 7 +-
.../unit/redis/log/test_redis_task_handler.py | 13 +-
22 files changed, 1439 insertions(+), 273 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py
index 277e705c18d..688a563d2b6 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py
@@ -20,7 +20,8 @@ from __future__ import annotations
import contextlib
import textwrap
-from fastapi import Depends, HTTPException, Request, Response, status
+from fastapi import Depends, HTTPException, Request, status
+from fastapi.responses import StreamingResponse
from itsdangerous import BadSignature, URLSafeSerializer
from pydantic import NonNegativeInt, PositiveInt
from sqlalchemy.orm import joinedload
@@ -120,12 +121,17 @@ def get_log(
)
ti = session.scalar(query)
if ti is None:
- query = select(TaskInstanceHistory).where(
- TaskInstanceHistory.task_id == task_id,
- TaskInstanceHistory.dag_id == dag_id,
- TaskInstanceHistory.run_id == dag_run_id,
- TaskInstanceHistory.map_index == map_index,
- TaskInstanceHistory.try_number == try_number,
+ query = (
+ select(TaskInstanceHistory)
+ .where(
+ TaskInstanceHistory.task_id == task_id,
+ TaskInstanceHistory.dag_id == dag_id,
+ TaskInstanceHistory.run_id == dag_run_id,
+ TaskInstanceHistory.map_index == map_index,
+ TaskInstanceHistory.try_number == try_number,
+ )
+ .options(joinedload(TaskInstanceHistory.dag_run))
+ # we need to joinedload the dag_run, since
FileTaskHandler._render_filename needs ti.dag_run
)
ti = session.scalar(query)
@@ -138,24 +144,27 @@ def get_log(
with contextlib.suppress(TaskNotFound):
ti.task = dag.get_task(ti.task_id)
- if accept == Mimetype.JSON or accept == Mimetype.ANY: # default
- logs, metadata = task_log_reader.read_log_chunks(ti, try_number,
metadata)
- encoded_token = None
+ if accept == Mimetype.NDJSON: # only specified application/x-ndjson will
return streaming response
+ # LogMetadata(TypedDict) is used as type annotation for log_reader;
added ignore to suppress mypy error
+ log_stream = task_log_reader.read_log_stream(ti, try_number, metadata)
# type: ignore[arg-type]
+ headers = None
if not metadata.get("end_of_log", False):
- encoded_token =
URLSafeSerializer(request.app.state.secret_key).dumps(metadata)
- return
TaskInstancesLogResponse.model_construct(continuation_token=encoded_token,
content=logs)
- # text/plain, or something else we don't understand. Return raw log content
-
- # We need to exhaust the iterator before we can generate the continuation
token.
- # We could improve this by making it a streaming/async response, and by
then setting the header using
- # HTTP Trailers
- logs = "".join(task_log_reader.read_log_stream(ti, try_number, metadata))
- headers = None
- if not metadata.get("end_of_log", False):
- headers = {
- "Airflow-Continuation-Token":
URLSafeSerializer(request.app.state.secret_key).dumps(metadata)
- }
- return Response(media_type="application/x-ndjson", content=logs,
headers=headers)
+ headers = {
+ "Airflow-Continuation-Token":
URLSafeSerializer(request.app.state.secret_key).dumps(metadata)
+ }
+ return StreamingResponse(media_type="application/x-ndjson",
content=log_stream, headers=headers)
+
+ # application/json, or something else we don't understand.
+ # Return JSON format, which will be more easily for users to debug.
+
+ # LogMetadata(TypedDict) is used as type annotation for log_reader; added
ignore to suppress mypy error
+ structured_log_stream, out_metadata = task_log_reader.read_log_chunks(ti,
try_number, metadata) # type: ignore[arg-type]
+ encoded_token = None
+ if not out_metadata.get("end_of_log", False):
+ encoded_token =
URLSafeSerializer(request.app.state.secret_key).dumps(out_metadata)
+ return TaskInstancesLogResponse.model_construct(
+ continuation_token=encoded_token, content=list(structured_log_stream)
+ )
@task_instances_log_router.get(
diff --git a/airflow-core/src/airflow/utils/log/file_task_handler.py
b/airflow-core/src/airflow/utils/log/file_task_handler.py
index 6ee087c3858..10595fa7859 100644
--- a/airflow-core/src/airflow/utils/log/file_task_handler.py
+++ b/airflow-core/src/airflow/utils/log/file_task_handler.py
@@ -19,51 +19,114 @@
from __future__ import annotations
-import itertools
+import heapq
+import io
import logging
import os
-from collections.abc import Callable, Iterable
+from collections.abc import Callable, Generator, Iterator
from contextlib import suppress
from datetime import datetime
from enum import Enum
+from itertools import chain, islice
from pathlib import Path
-from typing import TYPE_CHECKING, Any
+from types import GeneratorType
+from typing import IO, TYPE_CHECKING, TypedDict, cast
from urllib.parse import urljoin
import pendulum
from pydantic import BaseModel, ConfigDict, ValidationError
+from typing_extensions import NotRequired
from airflow.configuration import conf
from airflow.executors.executor_loader import ExecutorLoader
from airflow.utils.helpers import parse_template_string, render_template
+from airflow.utils.log.log_stream_accumulator import LogStreamAccumulator
from airflow.utils.log.logging_mixin import SetContextPropagate
from airflow.utils.log.non_caching_file_handler import
NonCachingRotatingFileHandler
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import State, TaskInstanceState
if TYPE_CHECKING:
+ from requests import Response
+
from airflow.executors.base_executor import BaseExecutor
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.typing_compat import TypeAlias
+CHUNK_SIZE = 1024 * 1024 * 5 # 5MB
+DEFAULT_SORT_DATETIME = pendulum.datetime(2000, 1, 1)
+DEFAULT_SORT_TIMESTAMP = int(DEFAULT_SORT_DATETIME.timestamp() * 1000)
+SORT_KEY_OFFSET = 10000000
+"""An offset used by the _create_sort_key utility.
+
+Assuming 50 characters per line, an offset of 10,000,000 can represent
approximately 500 MB of file data, which is sufficient for use as a constant.
+"""
+HEAP_DUMP_SIZE = 5000
+HALF_HEAP_DUMP_SIZE = HEAP_DUMP_SIZE // 2
# These types are similar, but have distinct names to make processing them
less error prone
-LogMessages: TypeAlias = list["StructuredLogMessage"] | list[str]
-"""The log messages themselves, either in already sturcutured form, or a
single string blob to be parsed later"""
+LogMessages: TypeAlias = list[str]
+"""The legacy format of log messages before 3.0.2"""
LogSourceInfo: TypeAlias = list[str]
"""Information _about_ the log fetching process for display to a user"""
-LogMetadata: TypeAlias = dict[str, Any]
+RawLogStream: TypeAlias = Generator[str, None, None]
+"""Raw log stream, containing unparsed log lines."""
+LegacyLogResponse: TypeAlias = tuple[LogSourceInfo, LogMessages]
+"""Legacy log response, containing source information and log messages."""
+LogResponse: TypeAlias = tuple[LogSourceInfo, list[RawLogStream]]
+LogResponseWithSize: TypeAlias = tuple[LogSourceInfo, list[RawLogStream], int]
+"""Log response, containing source information, stream of log lines, and total
log size."""
+StructuredLogStream: TypeAlias = Generator["StructuredLogMessage", None, None]
+"""Structured log stream, containing structured log messages."""
+LogHandlerOutputStream: TypeAlias = (
+ StructuredLogStream | Iterator["StructuredLogMessage"] |
chain["StructuredLogMessage"]
+)
+"""Output stream, containing structured log messages or a chain of them."""
+ParsedLog: TypeAlias = tuple[datetime | None, int, "StructuredLogMessage"]
+"""Parsed log record, containing timestamp, line_num and the structured log
message."""
+ParsedLogStream: TypeAlias = Generator[ParsedLog, None, None]
+LegacyProvidersLogType: TypeAlias = list["StructuredLogMessage"] | str |
list[str]
+"""Return type used by legacy `_read` methods for Alibaba Cloud,
Elasticsearch, OpenSearch, and Redis log handlers.
+
+- For Elasticsearch and OpenSearch: returns either a list of structured log
messages.
+- For Alibaba Cloud: returns a string.
+- For Redis: returns a list of strings.
+"""
+
logger = logging.getLogger(__name__)
+class LogMetadata(TypedDict):
+ """Metadata about the log fetching process, including `end_of_log` and
`log_pos`."""
+
+ end_of_log: bool
+ log_pos: NotRequired[int]
+ # the following attributes are used for Elasticsearch and OpenSearch log
handlers
+ offset: NotRequired[str | int]
+ # Ensure a string here. Large offset numbers will get JSON.parsed
incorrectly
+ # on the client. Sending as a string prevents this issue.
+ #
https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Number/MAX_SAFE_INTEGER
+ last_log_timestamp: NotRequired[str]
+ max_offset: NotRequired[str]
+
+
class StructuredLogMessage(BaseModel):
"""An individual log message."""
timestamp: datetime | None = None
event: str
+ # Collisions of sort_key may occur due to duplicated messages. If this
happens, the heap will use the second element,
+ # which is the StructuredLogMessage for comparison. Therefore, we need to
define a comparator for it.
+ def __lt__(self, other: StructuredLogMessage) -> bool:
+ return self.sort_key < other.sort_key
+
+ @property
+ def sort_key(self) -> datetime:
+ return self.timestamp or DEFAULT_SORT_DATETIME
+
# We don't need to cache string when parsing in to this, as almost every
line will have a different
# values; `extra=allow` means we'll create extra properties as needed.
Only timestamp and event are
# required, everything else is up to what ever is producing the logs
@@ -100,7 +163,7 @@ def _set_task_deferred_context_var():
h.ctx_task_deferred = True
-def _fetch_logs_from_service(url, log_relative_path):
+def _fetch_logs_from_service(url: str, log_relative_path: str) -> Response:
# Import occurs in function scope for perf. Ref:
https://github.com/apache/airflow/pull/21438
import requests
@@ -111,7 +174,6 @@ def _fetch_logs_from_service(url, log_relative_path):
secret_key=get_signing_key("api", "secret_key"),
# Since we are using a secret key, we need to be explicit about the
algorithm here too
algorithm="HS512",
- private_key=None,
issuer=None,
valid_for=conf.getint("webserver", "log_request_clock_grace",
fallback=30),
audience="task-instance-logs",
@@ -120,6 +182,7 @@ def _fetch_logs_from_service(url, log_relative_path):
url,
timeout=timeout,
headers={"Authorization": generator.generate({"filename":
log_relative_path})},
+ stream=True,
)
response.encoding = "utf-8"
return response
@@ -134,28 +197,68 @@ if not _parse_timestamp:
return pendulum.parse(timestamp_str.strip("[]"))
-def _parse_log_lines(
- lines: str | LogMessages,
-) -> Iterable[tuple[datetime | None, int, StructuredLogMessage]]:
+def _stream_lines_by_chunk(
+ log_io: IO[str],
+) -> RawLogStream:
+ """
+ Stream lines from a file-like IO object.
+
+ :param log_io: A file-like IO object to read from.
+ :return: A generator that yields individual lines within the specified
range.
+ """
+ # Skip processing if file is already closed
+ if log_io.closed:
+ return
+
+ # Seek to beginning if possible
+ if log_io.seekable():
+ try:
+ log_io.seek(0)
+ except Exception as e:
+ logger.error("Error seeking in log stream: %s", e)
+ return
+
+ buffer = ""
+ while True:
+ # Check if file is already closed
+ if log_io.closed:
+ break
+
+ try:
+ chunk = log_io.read(CHUNK_SIZE)
+ except Exception as e:
+ logger.error("Error reading log stream: %s", e)
+ break
+
+ if not chunk:
+ break
+
+ buffer += chunk
+ *lines, buffer = buffer.split("\n")
+ yield from lines
+
+ if buffer:
+ yield from buffer.split("\n")
+
+
+def _log_stream_to_parsed_log_stream(
+ log_stream: RawLogStream,
+) -> ParsedLogStream:
+ """
+ Turn a str log stream into a generator of parsed log lines.
+
+ :param log_stream: The stream to parse.
+ :return: A generator of parsed log lines.
+ """
from airflow.utils.timezone import coerce_datetime
timestamp = None
next_timestamp = None
- if isinstance(lines, str):
- lines = lines.splitlines()
- if isinstance(lines, list) and len(lines) and isinstance(lines[0], str):
- # A list of content from each location. It's a super odd format, but
this is what we load
- # [['a\nb\n'], ['c\nd\ne\n']] -> ['a', 'b', 'c', 'd', 'e']
- lines = itertools.chain.from_iterable(map(str.splitlines, lines)) #
type: ignore[assignment,arg-type]
-
- # https://github.com/python/mypy/issues/8586
- for idx, line in enumerate[str | StructuredLogMessage](lines):
+ idx = 0
+ for line in log_stream:
if line:
try:
- if isinstance(line, StructuredLogMessage):
- log = line
- else:
- log = StructuredLogMessage.model_validate_json(line)
+ log = StructuredLogMessage.model_validate_json(line)
except ValidationError:
with suppress(Exception):
# If we can't parse the timestamp, don't attach one to the
row
@@ -166,17 +269,146 @@ def _parse_log_lines(
log.timestamp = coerce_datetime(log.timestamp)
timestamp = log.timestamp
yield timestamp, idx, log
+ idx += 1
+
+
+def _create_sort_key(timestamp: datetime | None, line_num: int) -> int:
+ """
+ Create a sort key for log record, to be used in K-way merge.
+
+ :param timestamp: timestamp of the log line
+ :param line_num: line number of the log line
+ :return: a integer as sort key to avoid overhead of memory usage
+ """
+ return int((timestamp or DEFAULT_SORT_DATETIME).timestamp() * 1000) *
SORT_KEY_OFFSET + line_num
+
+
+def _is_sort_key_with_default_timestamp(sort_key: int) -> bool:
+ """
+ Check if the sort key was generated with the DEFAULT_SORT_TIMESTAMP.
+ This is used to identify log records that don't have timestamp.
-def _interleave_logs(*logs: str | LogMessages) ->
Iterable[StructuredLogMessage]:
- min_date = pendulum.datetime(2000, 1, 1)
+ :param sort_key: The sort key to check
+ :return: True if the sort key was generated with DEFAULT_SORT_TIMESTAMP,
False otherwise
+ """
+ # Extract the timestamp part from the sort key (remove the line number
part)
+ timestamp_part = sort_key // SORT_KEY_OFFSET
+ return timestamp_part == DEFAULT_SORT_TIMESTAMP
+
+
+def _add_log_from_parsed_log_streams_to_heap(
+ heap: list[tuple[int, StructuredLogMessage]],
+ parsed_log_streams: dict[int, ParsedLogStream],
+) -> None:
+ """
+ Add one log record from each parsed log stream to the heap, and will
remove empty log stream from the dict after iterating.
+
+ :param heap: heap to store log records
+ :param parsed_log_streams: dict of parsed log streams
+ """
+ # We intend to initialize the list lazily, as in most cases we don't need
to remove any log streams.
+ # This reduces memory overhead, since this function is called repeatedly
until all log streams are empty.
+ log_stream_to_remove: list[int] | None = None
+ for idx, log_stream in parsed_log_streams.items():
+ record: ParsedLog | None = next(log_stream, None)
+ if record is None:
+ if log_stream_to_remove is None:
+ log_stream_to_remove = []
+ log_stream_to_remove.append(idx)
+ continue
+ # add type hint to avoid mypy error
+ record = cast("ParsedLog", record)
+ timestamp, line_num, line = record
+ # take int as sort key to avoid overhead of memory usage
+ heapq.heappush(heap, (_create_sort_key(timestamp, line_num), line))
+ # remove empty log stream from the dict
+ if log_stream_to_remove is not None:
+ for idx in log_stream_to_remove:
+ del parsed_log_streams[idx]
+
+
+def _flush_logs_out_of_heap(
+ heap: list[tuple[int, StructuredLogMessage]],
+ flush_size: int,
+ last_log_container: list[StructuredLogMessage | None],
+) -> Generator[StructuredLogMessage, None, None]:
+ """
+ Flush logs out of the heap, deduplicating them based on the last log.
+
+ :param heap: heap to flush logs from
+ :param flush_size: number of logs to flush
+ :param last_log_container: a container to store the last log, to avoid
duplicate logs
+ :return: a generator that yields deduplicated logs
+ """
+ last_log = last_log_container[0]
+ for _ in range(flush_size):
+ sort_key, line = heapq.heappop(heap)
+ if line != last_log or _is_sort_key_with_default_timestamp(sort_key):
# dedupe
+ yield line
+ last_log = line
+ # update the last log container with the last log
+ last_log_container[0] = last_log
+
+
+def _interleave_logs(*log_streams: RawLogStream) -> StructuredLogStream:
+ """
+ Merge parsed log streams using K-way merge.
+
+ By yielding HALF_CHUNK_SIZE records when heap size exceeds CHUNK_SIZE, we
can reduce the chance of messing up the global order.
+ Since there are multiple log streams, we can't guarantee that the records
are in global order.
+
+ e.g.
+
+ log_stream1: ----------
+ log_stream2: ----
+ log_stream3: --------
+
+ The first record of log_stream3 is later than the fourth record of
log_stream1 !
+ :param parsed_log_streams: parsed log streams
+ :return: interleaved log stream
+ """
+ # don't need to push whole tuple into heap, which increases too much
overhead
+ # push only sort_key and line into heap
+ heap: list[tuple[int, StructuredLogMessage]] = []
+ # to allow removing empty streams while iterating, also turn the str
stream into parsed log stream
+ parsed_log_streams: dict[int, ParsedLogStream] = {
+ idx: _log_stream_to_parsed_log_stream(log_stream) for idx, log_stream
in enumerate(log_streams)
+ }
+
+ # keep adding records from logs until all logs are empty
+ last_log_container: list[StructuredLogMessage | None] = [None]
+ while parsed_log_streams:
+ _add_log_from_parsed_log_streams_to_heap(heap, parsed_log_streams)
+
+ # yield HALF_HEAP_DUMP_SIZE records when heap size exceeds
HEAP_DUMP_SIZE
+ if len(heap) >= HEAP_DUMP_SIZE:
+ yield from _flush_logs_out_of_heap(heap, HALF_HEAP_DUMP_SIZE,
last_log_container)
+
+ # yield remaining records
+ yield from _flush_logs_out_of_heap(heap, len(heap), last_log_container)
+ # free memory
+ del heap
+ del parsed_log_streams
+
+
+def _is_logs_stream_like(log) -> bool:
+ """Check if the logs are stream-like."""
+ return isinstance(log, (chain, GeneratorType))
+
+
+def _get_compatible_log_stream(
+ log_messages: LogMessages,
+) -> RawLogStream:
+ """
+ Convert legacy log message blobs into a generator that yields log lines.
- records = itertools.chain.from_iterable(_parse_log_lines(log) for log in
logs)
- last = None
- for timestamp, _, msg in sorted(records, key=lambda x: (x[0] or min_date,
x[1])):
- if msg != last or not timestamp: # dedupe
- yield msg
- last = msg
+ :param log_messages: List of legacy log message strings.
+ :return: A generator that yields interleaved log lines.
+ """
+ yield from chain.from_iterable(
+ _stream_lines_by_chunk(io.StringIO(log_message)) for log_message in
log_messages
+ )
class FileTaskHandler(logging.Handler):
@@ -345,8 +577,8 @@ class FileTaskHandler(logging.Handler):
self,
ti: TaskInstance | TaskInstanceHistory,
try_number: int,
- metadata: dict[str, Any] | None = None,
- ):
+ metadata: LogMetadata | None = None,
+ ) -> tuple[LogHandlerOutputStream | LegacyProvidersLogType, LogMetadata]:
"""
Template method that contains custom logic of reading logs given the
try_number.
@@ -370,22 +602,38 @@ class FileTaskHandler(logging.Handler):
# initializing the handler. Thus explicitly getting log location
# is needed to get correct log path.
worker_log_rel_path = self._render_filename(ti, try_number)
+ sources: LogSourceInfo = []
source_list: list[str] = []
- remote_logs: LogMessages | None = []
- local_logs: list[str] = []
- sources: list[str] = []
- executor_logs: list[str] = []
- served_logs: LogMessages = []
+ remote_logs: list[RawLogStream] = []
+ local_logs: list[RawLogStream] = []
+ executor_logs: list[RawLogStream] = []
+ served_logs: list[RawLogStream] = []
with suppress(NotImplementedError):
- sources, remote_logs = self._read_remote_logs(ti, try_number,
metadata)
-
+ sources, logs = self._read_remote_logs(ti, try_number, metadata)
+ if not logs:
+ remote_logs = []
+ elif isinstance(logs, list) and isinstance(logs[0], str):
+ # If the logs are in legacy format, convert them to a
generator of log lines
+ remote_logs = [
+ # We don't need to use the log_pos here, as we are using
the metadata to track the position
+ _get_compatible_log_stream(cast("list[str]", logs))
+ ]
+ elif isinstance(logs, list) and _is_logs_stream_like(logs[0]):
+ # If the logs are already in a stream-like format, we can use
them directly
+ remote_logs = cast("list[RawLogStream]", logs)
+ else:
+ # If the logs are in a different format, raise an error
+ raise TypeError("Logs should be either a list of strings or a
generator of log lines.")
+ # Extend LogSourceInfo
source_list.extend(sources)
has_k8s_exec_pod = False
if ti.state == TaskInstanceState.RUNNING:
executor_get_task_log = self._get_executor_get_task_log(ti)
response = executor_get_task_log(ti, try_number)
if response:
- sources, executor_logs = response
+ sources, logs = response
+ # make the logs stream-like compatible
+ executor_logs = [_get_compatible_log_stream(logs)]
if sources:
source_list.extend(sources)
has_k8s_exec_pod = True
@@ -404,15 +652,13 @@ class FileTaskHandler(logging.Handler):
sources, served_logs = self._read_from_logs_server(ti,
worker_log_rel_path)
source_list.extend(sources)
- logs = list(
- _interleave_logs(
- *local_logs,
- (remote_logs or []),
- *(executor_logs or []),
- *served_logs,
- )
+ out_stream: LogHandlerOutputStream = _interleave_logs(
+ *local_logs,
+ *remote_logs,
+ *executor_logs,
+ *served_logs,
)
- log_pos = len(logs)
+
# Log message source details are grouped: they are not relevant for
most users and can
# distract them from finding the root cause of their errors
header = [
@@ -423,12 +669,22 @@ class FileTaskHandler(logging.Handler):
TaskInstanceState.RUNNING,
TaskInstanceState.DEFERRED,
)
- if metadata and "log_pos" in metadata:
- previous_line = metadata["log_pos"]
- logs = logs[previous_line:] # Cut off previously passed log test
as new tail
- else:
- logs = header + logs
- return logs, {"end_of_log": end_of_log, "log_pos": log_pos}
+
+ with LogStreamAccumulator(out_stream, HEAP_DUMP_SIZE) as
stream_accumulator:
+ log_pos = stream_accumulator.total_lines
+ out_stream = stream_accumulator.stream
+
+ # skip log stream until the last position
+ if metadata and "log_pos" in metadata:
+ islice(out_stream, metadata["log_pos"])
+ else:
+ # first time reading log, add messages before interleaved log
stream
+ out_stream = chain(header, out_stream)
+
+ return out_stream, {
+ "end_of_log": end_of_log,
+ "log_pos": log_pos,
+ }
@staticmethod
@staticmethod
@@ -469,8 +725,8 @@ class FileTaskHandler(logging.Handler):
self,
task_instance: TaskInstance | TaskInstanceHistory,
try_number: int | None = None,
- metadata: dict[str, Any] | None = None,
- ) -> tuple[list[StructuredLogMessage] | str, dict[str, Any]]:
+ metadata: LogMetadata | None = None,
+ ) -> tuple[LogHandlerOutputStream, LogMetadata]:
"""
Read logs of given task instance from local machine.
@@ -489,7 +745,7 @@ class FileTaskHandler(logging.Handler):
event="Task was skipped, no logs available."
)
]
- return logs, {"end_of_log": True}
+ return chain(logs), {"end_of_log": True}
if try_number is None or try_number < 1:
logs = [
@@ -497,9 +753,38 @@ class FileTaskHandler(logging.Handler):
level="error", event=f"Error fetching the logs. Try number
{try_number} is invalid."
)
]
- return logs, {"end_of_log": True}
-
- return self._read(task_instance, try_number, metadata)
+ return chain(logs), {"end_of_log": True}
+
+ # compatibility for es_task_handler and os_task_handler
+ read_result = self._read(task_instance, try_number, metadata)
+ out_stream, metadata = read_result
+ # If the out_stream is None or empty, return the read result
+ if not out_stream:
+ out_stream = cast("Generator[StructuredLogMessage, None, None]",
out_stream)
+ return out_stream, metadata
+
+ if _is_logs_stream_like(out_stream):
+ out_stream = cast("Generator[StructuredLogMessage, None, None]",
out_stream)
+ return out_stream, metadata
+ if isinstance(out_stream, list) and isinstance(out_stream[0],
StructuredLogMessage):
+ out_stream = cast("list[StructuredLogMessage]", out_stream)
+ return (log for log in out_stream), metadata
+ if isinstance(out_stream, list) and isinstance(out_stream[0], str):
+ # If the out_stream is a list of strings, convert it to a generator
+ out_stream = cast("list[str]", out_stream)
+ raw_stream =
_stream_lines_by_chunk(io.StringIO("".join(out_stream)))
+ out_stream = (log for _, _, log in
_log_stream_to_parsed_log_stream(raw_stream))
+ return out_stream, metadata
+ if isinstance(out_stream, str):
+ # If the out_stream is a string, convert it to a generator
+ raw_stream = _stream_lines_by_chunk(io.StringIO(out_stream))
+ out_stream = (log for _, _, log in
_log_stream_to_parsed_log_stream(raw_stream))
+ return out_stream, metadata
+ raise TypeError(
+ "Invalid log stream type. Expected a generator of
StructuredLogMessage, list of StructuredLogMessage, list of str or str."
+ f" Got {type(out_stream).__name__} instead."
+ f" Content type: {type(out_stream[0]).__name__ if
isinstance(out_stream, (list, tuple)) and out_stream else 'empty'}"
+ )
@staticmethod
def _prepare_log_folder(directory: Path, new_folder_permissions: int):
@@ -565,15 +850,28 @@ class FileTaskHandler(logging.Handler):
return full_path
@staticmethod
- def _read_from_local(worker_log_path: Path) -> tuple[list[str], list[str]]:
+ def _read_from_local(
+ worker_log_path: Path,
+ ) -> LogResponse:
+ sources: LogSourceInfo = []
+ log_streams: list[RawLogStream] = []
paths = sorted(worker_log_path.parent.glob(worker_log_path.name + "*"))
- sources = [os.fspath(x) for x in paths]
- logs = [file.read_text() for file in paths]
- return sources, logs
+ if not paths:
+ return sources, log_streams
- def _read_from_logs_server(self, ti, worker_log_rel_path) ->
tuple[LogSourceInfo, LogMessages]:
- sources = []
- logs = []
+ for path in paths:
+ sources.append(os.fspath(path))
+ # Read the log file and yield lines
+ log_streams.append(_stream_lines_by_chunk(open(path,
encoding="utf-8")))
+ return sources, log_streams
+
+ def _read_from_logs_server(
+ self,
+ ti: TaskInstance,
+ worker_log_rel_path: str,
+ ) -> LogResponse:
+ sources: LogSourceInfo = []
+ log_streams: list[RawLogStream] = []
try:
log_type = LogType.TRIGGER if ti.triggerer_job else LogType.WORKER
url, rel_path = self._get_log_retrieval_url(ti,
worker_log_rel_path, log_type=log_type)
@@ -590,20 +888,26 @@ class FileTaskHandler(logging.Handler):
else:
# Check if the resource was properly fetched
response.raise_for_status()
- if response.text:
+ if int(response.headers.get("Content-Length", 0)) > 0:
sources.append(url)
- logs.append(response.text)
+ log_streams.append(
+
_stream_lines_by_chunk(io.TextIOWrapper(cast("IO[bytes]", response.raw)))
+ )
except Exception as e:
from requests.exceptions import InvalidURL
- if isinstance(e, InvalidURL) and
ti.task.inherits_from_empty_operator is True:
+ if (
+ isinstance(e, InvalidURL)
+ and ti.task is not None
+ and ti.task.inherits_from_empty_operator is True
+ ):
sources.append(self.inherits_from_empty_operator_log_message)
else:
sources.append(f"Could not read served logs: {e}")
logger.exception("Could not read served logs")
- return sources, logs
+ return sources, log_streams
- def _read_remote_logs(self, ti, try_number, metadata=None) ->
tuple[LogSourceInfo, LogMessages]:
+ def _read_remote_logs(self, ti, try_number, metadata=None) ->
LegacyLogResponse | LogResponse:
"""
Implement in subclasses to read from the remote service.
diff --git a/airflow-core/src/airflow/utils/log/log_reader.py
b/airflow-core/src/airflow/utils/log/log_reader.py
index 0bb61c52dbc..9f61c2f730c 100644
--- a/airflow-core/src/airflow/utils/log/log_reader.py
+++ b/airflow-core/src/airflow/utils/log/log_reader.py
@@ -18,13 +18,12 @@ from __future__ import annotations
import logging
import time
-from collections.abc import Iterator
+from collections.abc import Generator, Iterator
from functools import cached_property
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
from airflow.configuration import conf
from airflow.utils.helpers import render_log_filename
-from airflow.utils.log.file_task_handler import StructuredLogMessage
from airflow.utils.log.logging_mixin import ExternalLoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import TaskInstanceState
@@ -35,9 +34,11 @@ if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.typing_compat import TypeAlias
+ from airflow.utils.log.file_task_handler import LogHandlerOutputStream,
LogMetadata
-LogMessages: TypeAlias = list[StructuredLogMessage] | str
-LogMetadata: TypeAlias = dict[str, Any]
+LogReaderOutputStream: TypeAlias = Generator[str, None, None]
+
+READ_BATCH_SIZE = 1024
class TaskLogReader:
@@ -54,7 +55,7 @@ class TaskLogReader:
ti: TaskInstance | TaskInstanceHistory,
try_number: int | None,
metadata: LogMetadata,
- ) -> tuple[LogMessages, LogMetadata]:
+ ) -> tuple[LogHandlerOutputStream, LogMetadata]:
"""
Read chunks of Task Instance logs.
@@ -92,24 +93,19 @@ class TaskLogReader:
try_number = ti.try_number
for key in ("end_of_log", "max_offset", "offset", "log_pos"):
- metadata.pop(key, None)
+ #
https://mypy.readthedocs.io/en/stable/typed_dict.html#supported-operations
+ metadata.pop(key, None) # type: ignore[misc]
empty_iterations = 0
while True:
- logs, out_metadata = self.read_log_chunks(ti, try_number, metadata)
- # Update the metadata dict in place so caller can get new
values/end-of-log etc.
-
- for log in logs:
- # It's a bit wasteful here to parse the JSON then dump it back
again.
- # Optimize this so in stream mode we can just pass logs right
through, or even better add
- # support to 307 redirect to a signed URL etc.
- yield (log if isinstance(log, str) else log.model_dump_json())
+ "\n"
+ log_stream, out_metadata = self.read_log_chunks(ti, try_number,
metadata)
+ yield from (f"{log.model_dump_json()}\n" for log in log_stream)
if not out_metadata.get("end_of_log", False) and ti.state not in (
TaskInstanceState.RUNNING,
TaskInstanceState.DEFERRED,
):
- if logs:
+ if log_stream:
empty_iterations = 0
else:
# we did not receive any logs in this loop
@@ -121,7 +117,8 @@ class TaskLogReader:
yield "(Log stream stopped - End of log marker not
found; logs may be incomplete.)\n"
return
else:
- metadata.clear()
+ #
https://mypy.readthedocs.io/en/stable/typed_dict.html#supported-operations
+ metadata.clear() # type: ignore[attr-defined]
metadata.update(out_metadata)
return
diff --git a/airflow-core/src/airflow/utils/log/log_stream_accumulator.py
b/airflow-core/src/airflow/utils/log/log_stream_accumulator.py
new file mode 100644
index 00000000000..953b47dd971
--- /dev/null
+++ b/airflow-core/src/airflow/utils/log/log_stream_accumulator.py
@@ -0,0 +1,155 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import os
+import tempfile
+from itertools import islice
+from typing import IO, TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from airflow.typing_compat import Self
+ from airflow.utils.log.file_task_handler import (
+ LogHandlerOutputStream,
+ StructuredLogMessage,
+ StructuredLogStream,
+ )
+
+
class LogStreamAccumulator:
    """
    Memory-efficient log stream accumulator that tracks the total number of lines while preserving the original stream.

    This class captures logs from a stream and stores them in a buffer, flushing them to disk when the buffer
    exceeds a specified threshold. This approach optimizes memory usage while handling large log streams.

    Usage:

    .. code-block:: python

        with LogStreamAccumulator(stream, threshold) as log_accumulator:
            # Get total number of lines captured
            total_lines = log_accumulator.total_lines

            # Retrieve the original stream of logs
            for log in log_accumulator.stream:
                print(log)
    """

    def __init__(
        self,
        stream: LogHandlerOutputStream,
        threshold: int,
    ) -> None:
        """
        Initialize the LogStreamAccumulator.

        Args:
            stream: The input log stream to capture and count.
            threshold: Maximum number of lines to keep in memory before flushing to disk.
        """
        self._stream = stream
        self._threshold = threshold
        self._buffer: list[StructuredLogMessage] = []
        # Number of lines already flushed to the temporary file.
        self._disk_lines: int = 0
        self._tmpfile: IO[str] | None = None

    def _flush_buffer_to_disk(self) -> None:
        """Flush the buffer contents to a temporary file on disk."""
        if self._tmpfile is None:
            # delete=False: the file is reopened by name in `stream` and removed
            # explicitly in `_cleanup`.
            self._tmpfile = tempfile.NamedTemporaryFile(delete=False, mode="w+", encoding="utf-8")

        self._disk_lines += len(self._buffer)
        self._tmpfile.writelines(f"{log.model_dump_json()}\n" for log in self._buffer)
        self._tmpfile.flush()
        self._buffer.clear()

    def _capture(self) -> None:
        """Capture logs from the stream into the buffer, flushing to disk when threshold is reached."""
        while True:
            # `islice` will try to get up to `self._threshold` lines from the stream.
            self._buffer.extend(islice(self._stream, self._threshold))
            # If there are no more lines to capture, exit the loop.
            if len(self._buffer) < self._threshold:
                break
            self._flush_buffer_to_disk()

    def _cleanup(self) -> None:
        """Clean up the buffer and the temporary file if it exists."""
        self._buffer.clear()
        if self._tmpfile:
            self._tmpfile.close()
            try:
                os.remove(self._tmpfile.name)
            except OSError:
                # Best effort: the file may already be gone.
                pass
            self._tmpfile = None

    @property
    def total_lines(self) -> int:
        """
        Return the total number of lines captured from the stream.

        Returns:
            The sum of lines stored in the buffer and lines written to disk.
        """
        return self._disk_lines + len(self._buffer)

    @property
    def stream(self) -> StructuredLogStream:
        """
        Return the original stream of logs and clean up resources.

        Important: This property automatically cleans up resources after all logs have been yielded.
        Make sure to fully consume the returned generator to ensure proper cleanup.

        Returns:
            A stream of the captured log messages.
        """
        try:
            if not self._tmpfile:
                # if no temporary file was created, return from the buffer
                yield from self._buffer
            else:
                # avoid circular import
                from airflow.utils.log.file_task_handler import StructuredLogMessage

                # Disk lines were flushed first, so they precede the buffered tail.
                with open(self._tmpfile.name, encoding="utf-8") as f:
                    yield from (StructuredLogMessage.model_validate_json(line.strip()) for line in f)
                # yield the remaining buffer
                yield from self._buffer
        finally:
            # Ensure cleanup after yielding
            self._cleanup()

    def __enter__(self) -> Self:
        """
        Context manager entry point that initiates log capture.

        Returns:
            Self instance for use in context manager.
        """
        try:
            self._capture()
        except BaseException:
            # Don't leak the delete=False temporary file if the source stream
            # raises mid-capture: the `stream` property would never run.
            self._cleanup()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """
        Context manager exit.

        Note: On a clean exit, resources are NOT cleaned up here. Cleanup is deferred
        until the ``stream`` property is fully consumed, ensuring all logs are properly
        yielded before cleanup occurs. If the ``with`` body raised, the stream will
        likely never be consumed, so clean up immediately to avoid leaking the
        temporary file.
        """
        if exc_type is not None:
            self._cleanup()
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py
index 64dedcfd328..adb5aebcb21 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import copy
+import json
import logging.config
import sys
from unittest import mock
@@ -36,6 +37,7 @@ from airflow.utils import timezone
from airflow.utils.types import DagRunType
from tests_common.test_utils.db import clear_db_runs
+from tests_common.test_utils.file_task_handler import convert_list_to_stream
pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag]
@@ -233,6 +235,12 @@ class TestTaskInstancesLog:
assert expected_filename in resp_content
assert log_content in resp_content
+ # check content is in ndjson format
+ for line in resp_content.splitlines():
+ log = json.loads(line)
+ assert "event" in log
+ assert "timestamp" in log
+
@pytest.mark.parametrize(
"request_url, expected_filename, extra_query_string, try_number",
[
@@ -304,11 +312,22 @@ class TestTaskInstancesLog:
@pytest.mark.parametrize("try_number", [1, 2])
def test_get_logs_with_metadata_as_download_large_file(self, try_number):
+ from airflow.utils.log.file_task_handler import StructuredLogMessage
+
with
mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read") as
read_mock:
- first_return = (["", "1st line"], {})
- second_return = (["", "2nd line"], {"end_of_log": False})
- third_return = (["", "3rd line"], {"end_of_log": True})
- fourth_return = (["", "should never be read"], {"end_of_log":
True})
+ first_return =
(convert_list_to_stream([StructuredLogMessage(event="", message="1st line")]),
{})
+ second_return = (
+ convert_list_to_stream([StructuredLogMessage(event="",
message="2nd line")]),
+ {"end_of_log": False},
+ )
+ third_return = (
+ convert_list_to_stream([StructuredLogMessage(event="",
message="3rd line")]),
+ {"end_of_log": True},
+ )
+ fourth_return = (
+ convert_list_to_stream([StructuredLogMessage(event="",
message="should never be read")]),
+ {"end_of_log": True},
+ )
read_mock.side_effect = [first_return, second_return,
third_return, fourth_return]
response = self.client.get(
diff --git a/airflow-core/tests/unit/utils/log/test_log_reader.py
b/airflow-core/tests/unit/utils/log/test_log_reader.py
index 5ca675f8d08..cd8b430090a 100644
--- a/airflow-core/tests/unit/utils/log/test_log_reader.py
+++ b/airflow-core/tests/unit/utils/log/test_log_reader.py
@@ -41,6 +41,7 @@ from airflow.utils.types import DagRunType
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.db import clear_db_dags, clear_db_runs
+from tests_common.test_utils.file_task_handler import convert_list_to_stream
pytestmark = pytest.mark.db_test
@@ -127,6 +128,7 @@ class TestLogView:
ti.state = TaskInstanceState.SUCCESS
logs, metadata = task_log_reader.read_log_chunks(ti=ti, try_number=1,
metadata={})
+ logs = list(logs)
assert logs[0].event == "::group::Log message source details"
assert logs[0].sources == [
f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log"
@@ -141,6 +143,7 @@ class TestLogView:
ti.state = TaskInstanceState.SUCCESS
logs, metadata = task_log_reader.read_log_chunks(ti=ti,
try_number=None, metadata={})
+ logs = list(logs)
assert logs[0].event == "::group::Log message source details"
assert logs[0].sources == [
f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/3.log"
@@ -180,16 +183,31 @@ class TestLogView:
@mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read")
def test_read_log_stream_should_support_multiple_chunks(self, mock_read):
- first_return = (["1st line"], {})
- second_return = (["2nd line"], {"end_of_log": False})
- third_return = (["3rd line"], {"end_of_log": True})
- fourth_return = (["should never be read"], {"end_of_log": True})
+ from airflow.utils.log.file_task_handler import StructuredLogMessage
+
+ first_return =
(convert_list_to_stream([StructuredLogMessage(event="1st line")]), {})
+ second_return = (
+ convert_list_to_stream([StructuredLogMessage(event="2nd line")]),
+ {"end_of_log": False},
+ )
+ third_return = (
+ convert_list_to_stream([StructuredLogMessage(event="3rd line")]),
+ {"end_of_log": True},
+ )
+ fourth_return = (
+ convert_list_to_stream([StructuredLogMessage(event="should never
be read")]),
+ {"end_of_log": True},
+ )
mock_read.side_effect = [first_return, second_return, third_return,
fourth_return]
task_log_reader = TaskLogReader()
self.ti.state = TaskInstanceState.SUCCESS
log_stream = task_log_reader.read_log_stream(ti=self.ti, try_number=1,
metadata={})
- assert list(log_stream) == ["1st line\n", "2nd line\n", "3rd line\n"]
+ assert list(log_stream) == [
+ '{"timestamp":null,"event":"1st line"}\n',
+ '{"timestamp":null,"event":"2nd line"}\n',
+ '{"timestamp":null,"event":"3rd line"}\n',
+ ]
# as the metadata is now updated in place, when the latest run update
metadata.
# the metadata stored in the mock_read will also be updated
@@ -205,11 +223,18 @@ class TestLogView:
@mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read")
def test_read_log_stream_should_read_each_try_in_turn(self, mock_read):
- mock_read.side_effect = [(["try_number=3."], {"end_of_log": True})]
+ from airflow.utils.log.file_task_handler import StructuredLogMessage
+
+ mock_read.side_effect = [
+ (
+
convert_list_to_stream([StructuredLogMessage(event="try_number=3.")]),
+ {"end_of_log": True},
+ )
+ ]
task_log_reader = TaskLogReader()
log_stream = task_log_reader.read_log_stream(ti=self.ti,
try_number=None, metadata={})
- assert list(log_stream) == ["try_number=3.\n"]
+ assert list(log_stream) ==
['{"timestamp":null,"event":"try_number=3."}\n']
mock_read.assert_has_calls(
[
@@ -220,8 +245,10 @@ class TestLogView:
@mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read")
def test_read_log_stream_no_end_of_log_marker(self, mock_read):
+ from airflow.utils.log.file_task_handler import StructuredLogMessage
+
mock_read.side_effect = [
- (["hello"], {"end_of_log": False}),
+ ([StructuredLogMessage(event="hello")], {"end_of_log": False}),
*[([], {"end_of_log": False}) for _ in range(10)],
]
@@ -230,7 +257,7 @@ class TestLogView:
task_log_reader.STREAM_LOOP_SLEEP_SECONDS = 0.001 # to speed up the
test
log_stream = task_log_reader.read_log_stream(ti=self.ti, try_number=1,
metadata={})
assert list(log_stream) == [
- "hello\n",
+ '{"timestamp":null,"event":"hello"}\n',
"(Log stream stopped - End of log marker not found; logs may be
incomplete.)\n",
]
assert mock_read.call_count == 11
diff --git a/airflow-core/tests/unit/utils/log/test_stream_accumulator.py
b/airflow-core/tests/unit/utils/log/test_stream_accumulator.py
new file mode 100644
index 00000000000..fd2856d851e
--- /dev/null
+++ b/airflow-core/tests/unit/utils/log/test_stream_accumulator.py
@@ -0,0 +1,165 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING
+from unittest import mock
+
+import pendulum
+import pytest
+
+from airflow.utils.log.file_task_handler import StructuredLogMessage
+from airflow.utils.log.log_stream_accumulator import LogStreamAccumulator
+
+if TYPE_CHECKING:
+ from airflow.utils.log.file_task_handler import LogHandlerOutputStream
+
+LOG_START_DATETIME = pendulum.datetime(2023, 10, 1, 0, 0, 0)
+LOG_COUNT = 20
+
+
class TestLogStreamAccumulator:
    """Test cases for the LogStreamAccumulator class."""

    @pytest.fixture
    def structured_logs(self):
        """Create a stream of mock structured log messages."""

        def generate_logs():
            yield from (
                StructuredLogMessage(
                    event=f"test_event_{i + 1}",
                    timestamp=LOG_START_DATETIME.add(seconds=i),
                    level="INFO",
                    message=f"Test log message {i + 1}",
                )
                for i in range(LOG_COUNT)
            )

        return generate_logs()

    def validate_log_stream(self, log_stream: LogHandlerOutputStream):
        """Validate the log stream content and the total number of lines."""
        count = 0
        for i, log in enumerate(log_stream):
            assert log.event == f"test_event_{i + 1}"
            assert log.timestamp == LOG_START_DATETIME.add(seconds=i)
            count += 1
        # Compare against the module constant rather than a magic number.
        assert count == LOG_COUNT

    def test__capture(self, structured_logs):
        """Test that entering the context manager triggers log capture exactly once."""
        accumulator = LogStreamAccumulator(structured_logs, 5)
        with mock.patch.object(accumulator, "_capture") as mock_capture:
            with accumulator:
                mock_capture.assert_called_once()

    def test__flush_buffer_to_disk(self, structured_logs):
        """Test flush-to-disk behavior with a small threshold."""
        threshold = 6

        # Mock the temporary file to verify it's being written to
        with mock.patch("tempfile.NamedTemporaryFile") as mock_tmpfile:
            mock_file = mock.MagicMock()
            mock_tmpfile.return_value = mock_file

            with LogStreamAccumulator(structured_logs, threshold) as accumulator:
                mock_tmpfile.assert_called_once_with(
                    delete=False,
                    mode="w+",
                    encoding="utf-8",
                )
                # Verify _flush_buffer_to_disk was called multiple times
                # (20 logs / 6 threshold = 3 flushes + 2 remaining logs in buffer)
                assert accumulator._disk_lines == 18
                assert mock_file.writelines.call_count == 3
                assert len(accumulator._buffer) == 2

    @pytest.mark.parametrize(
        "threshold",
        [
            pytest.param(30, id="buffer_only"),
            pytest.param(5, id="flush_to_disk"),
        ],
    )
    def test_get_stream(self, structured_logs, threshold):
        """Test that stream property returns all logs regardless of whether they were flushed to disk."""
        tmpfile_name = None
        with LogStreamAccumulator(structured_logs, threshold) as accumulator:
            out_stream = accumulator.stream

            # Check if the temporary file was created
            if threshold < LOG_COUNT:
                tmpfile_name = accumulator._tmpfile.name
                assert os.path.exists(tmpfile_name)
            else:
                assert accumulator._tmpfile is None

            # Validate the log stream
            self.validate_log_stream(out_stream)

            # Verify temp file was cleaned up once the stream was fully consumed
            if threshold < LOG_COUNT:
                assert accumulator._tmpfile is None
                assert tmpfile_name is not None
                assert not os.path.exists(tmpfile_name)

    @pytest.mark.parametrize(
        "threshold, expected_buffer_size, expected_disk_lines",
        [
            pytest.param(30, 20, 0, id="no_flush_needed"),
            pytest.param(10, 0, 20, id="single_flush_needed"),
            pytest.param(3, 2, 18, id="multiple_flushes_needed"),
        ],
    )
    def test_total_lines(self, structured_logs, threshold, expected_buffer_size, expected_disk_lines):
        """Test that LogStreamAccumulator correctly counts lines across buffer and disk."""
        with LogStreamAccumulator(structured_logs, threshold) as accumulator:
            # Check buffer and disk line counts
            assert len(accumulator._buffer) == expected_buffer_size
            assert accumulator._disk_lines == expected_disk_lines
            # Validate the log stream and line counts
            self.validate_log_stream(accumulator.stream)

    def test__cleanup(self, structured_logs):
        """Test that cleanup happens when stream property is fully consumed, not on context exit."""
        accumulator = LogStreamAccumulator(structured_logs, 5)
        with mock.patch.object(accumulator, "_cleanup") as mock_cleanup:
            with accumulator:
                # _cleanup should not be called yet
                mock_cleanup.assert_not_called()

            # Get the stream but don't iterate through it yet
            stream = accumulator.stream
            mock_cleanup.assert_not_called()

            # Now iterate through the stream
            for _ in stream:
                pass

            # After fully consuming the stream, cleanup should be called
            mock_cleanup.assert_called_once()
diff --git a/airflow-core/tests/unit/utils/test_log_handlers.py
b/airflow-core/tests/unit/utils/test_log_handlers.py
index 8367aed97ce..9c32e9b232a 100644
--- a/airflow-core/tests/unit/utils/test_log_handlers.py
+++ b/airflow-core/tests/unit/utils/test_log_handlers.py
@@ -17,15 +17,17 @@
# under the License.
from __future__ import annotations
+import heapq
+import io
import itertools
import logging
import logging.config
import os
import re
-from collections.abc import Iterable
from http import HTTPStatus
from importlib import reload
from pathlib import Path
+from typing import cast
from unittest import mock
from unittest.mock import patch
@@ -47,43 +49,42 @@ from airflow.models.taskinstancehistory import
TaskInstanceHistory
from airflow.models.trigger import Trigger
from airflow.providers.standard.operators.python import PythonOperator
from airflow.utils.log.file_task_handler import (
+ DEFAULT_SORT_DATETIME,
FileTaskHandler,
LogType,
+ ParsedLogStream,
StructuredLogMessage,
+ _add_log_from_parsed_log_streams_to_heap,
+ _create_sort_key,
_fetch_logs_from_service,
+ _flush_logs_out_of_heap,
_interleave_logs,
- _parse_log_lines,
+ _is_logs_stream_like,
+ _is_sort_key_with_default_timestamp,
+ _log_stream_to_parsed_log_stream,
+ _stream_lines_by_chunk,
)
from airflow.utils.log.logging_mixin import set_context
from airflow.utils.net import get_hostname
from airflow.utils.session import create_session
from airflow.utils.state import State, TaskInstanceState
-from airflow.utils.timezone import datetime
from airflow.utils.types import DagRunType
from tests_common.test_utils.config import conf_vars
+from tests_common.test_utils.file_task_handler import (
+ convert_list_to_stream,
+ extract_events,
+ mock_parsed_logs_factory,
+)
from tests_common.test_utils.markers import
skip_if_force_lowest_dependencies_marker
pytestmark = pytest.mark.db_test
-DEFAULT_DATE = datetime(2016, 1, 1)
+DEFAULT_DATE = pendulum.datetime(2016, 1, 1)
TASK_LOGGER = "airflow.task"
FILE_TASK_HANDLER = "task"
-def events(logs: Iterable[StructuredLogMessage], skip_source_info=True) ->
list[str]:
- """Helper function to return just the event (a.k.a message) from a list of
StructuredLogMessage"""
- logs = iter(logs)
- if skip_source_info:
-
- def is_source_group(log: StructuredLogMessage):
- return not hasattr(log, "timestamp") or log.event == "::endgroup"
-
- logs = itertools.dropwhile(is_source_group, logs)
-
- return [s.event for s in logs]
-
-
class TestFileTaskLogHandler:
def clean_up(self):
with create_session() as session:
@@ -146,9 +147,11 @@ class TestFileTaskLogHandler:
assert hasattr(file_handler, "read")
# Return value of read must be a tuple of list and list.
# passing invalid `try_number` to read function
- log, metadata = file_handler.read(ti, 0)
+ log_handler_output_stream, metadata = file_handler.read(ti, 0)
assert isinstance(metadata, dict)
- assert log[0].event == "Error fetching the logs. Try number 0 is
invalid."
+ assert extract_events(log_handler_output_stream) == [
+ "Error fetching the logs. Try number 0 is invalid."
+ ]
# Remove the generated tmp log file.
os.remove(log_filename)
@@ -185,7 +188,8 @@ class TestFileTaskLogHandler:
assert log_filename.endswith("0.log"), log_filename
# Return value of read must be a tuple of list and list.
- logs, metadata = file_handler.read(ti)
+ log_handler_output_stream, metadata = file_handler.read(ti)
+ logs = list(log_handler_output_stream)
assert logs[0].event == "Task was skipped, no logs available."
# Remove the generated tmp log file.
@@ -228,14 +232,16 @@ class TestFileTaskLogHandler:
file_handler.close()
assert hasattr(file_handler, "read")
- log, metadata = file_handler.read(ti, 1)
+ log_handler_output_stream, metadata = file_handler.read(ti, 1)
assert isinstance(metadata, dict)
target_re = re.compile(r"\A\[[^\]]+\] {test_log_handlers.py:\d+} INFO
- test\Z")
# We should expect our log line from the callable above to appear in
# the logs we read back
- assert any(re.search(target_re, e) for e in events(log)), "Logs were "
+ str(log)
+ assert any(re.search(target_re, e) for e in
extract_events(log_handler_output_stream)), (
+ f"Logs were {log_handler_output_stream}"
+ )
# Remove the generated tmp log file.
os.remove(log_filename)
@@ -356,8 +362,8 @@ class TestFileTaskLogHandler:
logger.info("Test")
# Return value of read must be a tuple of list and list.
- logs, metadata = file_handler.read(ti)
- assert isinstance(logs, list)
+ log_handler_output_stream, metadata = file_handler.read(ti)
+ assert _is_logs_stream_like(log_handler_output_stream)
# Logs for running tasks should show up too.
assert isinstance(metadata, dict)
@@ -435,7 +441,7 @@ class TestFileTaskLogHandler:
assert find_rotate_log_1 is True
# Logs for running tasks should show up too.
- assert isinstance(logs, list)
+ assert _is_logs_stream_like(logs)
# Remove the two generated tmp log files.
os.remove(log_filename)
@@ -450,7 +456,7 @@ class TestFileTaskLogHandler:
path = Path(
"dag_id=dag_for_testing_local_log_read/run_id=scheduled__2016-01-01T00:00:00+00:00/task_id=task_for_testing_local_log_read/attempt=1.log"
)
- mock_read_local.return_value = (["the messages"], ["the log"])
+ mock_read_local.return_value = (["the messages"],
[convert_list_to_stream(["the log"])])
local_log_file_read = create_task_instance(
dag_id="dag_for_testing_local_log_read",
task_id="task_for_testing_local_log_read",
@@ -458,24 +464,23 @@ class TestFileTaskLogHandler:
logical_date=DEFAULT_DATE,
)
fth = FileTaskHandler("")
- logs, metadata = fth._read(ti=local_log_file_read, try_number=1)
+ log_handler_output_stream, metadata =
fth._read(ti=local_log_file_read, try_number=1)
mock_read_local.assert_called_with(path)
- as_text = events(logs)
- assert logs[0].sources == ["the messages"]
- assert as_text[-1] == "the log"
+ assert extract_events(log_handler_output_stream) == ["the log"]
assert metadata == {"end_of_log": True, "log_pos": 1}
def test__read_from_local(self, tmp_path):
"""Tests the behavior of method _read_from_local"""
path1 = tmp_path / "hello1.log"
path2 = tmp_path / "hello1.log.suffix.log"
- path1.write_text("file1 content")
- path2.write_text("file2 content")
+ path1.write_text("file1 content\nfile1 content2")
+ path2.write_text("file2 content\nfile2 content2")
fth = FileTaskHandler("")
- assert fth._read_from_local(path1) == (
- [str(path1), str(path2)],
- ["file1 content", "file2 content"],
- )
+ log_source_info, log_streams = fth._read_from_local(path1)
+ assert log_source_info == [str(path1), str(path2)]
+ assert len(log_streams) == 2
+ assert list(log_streams[0]) == ["file1 content", "file1 content2"]
+ assert list(log_streams[1]) == ["file2 content", "file2 content2"]
@pytest.mark.parametrize(
"remote_logs, local_logs, served_logs_checked",
@@ -505,32 +510,57 @@ class TestFileTaskLogHandler:
logical_date=DEFAULT_DATE,
)
ti.state = TaskInstanceState.SUCCESS # we're testing scenario when
task is done
+ expected_logs = ["::group::Log message source details", "::endgroup::"]
with conf_vars({("core", "executor"): executor_name}):
reload(executor_loader)
fth = FileTaskHandler("")
if remote_logs:
fth._read_remote_logs = mock.Mock()
fth._read_remote_logs.return_value = ["found remote logs"],
["remote\nlog\ncontent"]
+ expected_logs.extend(
+ [
+ "remote",
+ "log",
+ "content",
+ ]
+ )
if local_logs:
fth._read_from_local = mock.Mock()
- fth._read_from_local.return_value = ["found local logs"],
["local\nlog\ncontent"]
+ fth._read_from_local.return_value = (
+ ["found local logs"],
+
[convert_list_to_stream("local\nlog\ncontent".splitlines())],
+ )
+ # only when not read from remote and TI is unfinished will
read from local
+ if not remote_logs:
+ expected_logs.extend(
+ [
+ "local",
+ "log",
+ "content",
+ ]
+ )
fth._read_from_logs_server = mock.Mock()
- fth._read_from_logs_server.return_value = ["this message"],
["this\nlog\ncontent"]
+ fth._read_from_logs_server.return_value = (
+ ["this message"],
+ [convert_list_to_stream("this\nlog\ncontent".splitlines())],
+ )
+ # only when not read from remote and not read from local will read
from logs server
+ if served_logs_checked:
+ expected_logs.extend(
+ [
+ "this",
+ "log",
+ "content",
+ ]
+ )
+
logs, metadata = fth._read(ti=ti, try_number=1)
if served_logs_checked:
fth._read_from_logs_server.assert_called_once()
- assert events(logs) == [
- "::group::Log message source details",
- "::endgroup::",
- "this",
- "log",
- "content",
- ]
- assert metadata == {"end_of_log": True, "log_pos": 3}
else:
fth._read_from_logs_server.assert_not_called()
- assert logs
- assert metadata
+ assert extract_events(logs, False) == expected_logs
+ assert metadata == {"end_of_log": True, "log_pos": 3}
def test_add_triggerer_suffix(self):
sample = "any/path/to/thing.txt"
@@ -738,11 +768,119 @@
AIRFLOW_CTX_DAG_RUN_ID=manual__2022-11-16T08:05:52.324532+00:00
"""
-def test_parse_timestamps():
- actual = []
- for timestamp, _, _ in _parse_log_lines(log_sample.splitlines()):
- actual.append(timestamp)
- assert actual == [
[email protected](
+ "chunk_size, expected_read_calls",
+ [
+ (10, 4),
+ (20, 3),
+ # will read all logs in one call, but still need another call to get
empty string at the end to escape the loop
+ (50, 2),
+ (100, 2),
+ ],
+)
+def test__stream_lines_by_chunk(chunk_size, expected_read_calls):
+ # Mock CHUNK_SIZE to a smaller value to test
+ with mock.patch("airflow.utils.log.file_task_handler.CHUNK_SIZE",
chunk_size):
+ log_io = io.StringIO("line1\nline2\nline3\nline4\n")
+ log_io.read = mock.MagicMock(wraps=log_io.read)
+
+ # Stream lines using the function
+ streamed_lines = list(_stream_lines_by_chunk(log_io))
+
+ # Verify the output matches the input split by lines
+ expected_output = ["line1", "line2", "line3", "line4"]
+ assert log_io.read.call_count == expected_read_calls, (
+ f"Expected {expected_read_calls} calls to read, got
{log_io.read.call_count}"
+ )
+ assert streamed_lines == expected_output, f"Expected
{expected_output}, got {streamed_lines}"
+
+
[email protected](
+ "seekable",
+ [
+ pytest.param(True, id="seekable_stream"),
+ pytest.param(False, id="non_seekable_stream"),
+ ],
+)
[email protected](
+ "closed",
+ [
+ pytest.param(False, id="not_closed_stream"),
+ pytest.param(True, id="closed_stream"),
+ ],
+)
[email protected](
+ "unexpected_exception",
+ [
+ pytest.param(None, id="no_exception"),
+ pytest.param(ValueError, id="value_error"),
+ pytest.param(IOError, id="io_error"),
+ pytest.param(Exception, id="generic_exception"),
+ ],
+)
[email protected](
+ "airflow.utils.log.file_task_handler.CHUNK_SIZE", 10
+) # Mock CHUNK_SIZE to a smaller value for testing
+def test__stream_lines_by_chunk_error_handling(seekable, closed,
unexpected_exception):
+ """
+ Test that _stream_lines_by_chunk handles errors correctly.
+ """
+ log_io = io.StringIO("line1\nline2\nline3\nline4\n")
+ log_io.seekable = mock.MagicMock(return_value=seekable)
+ log_io.seek = mock.MagicMock(wraps=log_io.seek)
+ # Mock the read method to check the call count and handle exceptions
+ if unexpected_exception:
+ expected_error = unexpected_exception("An error occurred while reading
the log stream.")
+ log_io.read = mock.MagicMock(side_effect=expected_error)
+ else:
+ log_io.read = mock.MagicMock(wraps=log_io.read)
+
+ # Setup closed state if needed - must be done before starting the test
+ if closed:
+ log_io.close()
+
+ # If an exception is expected, we mock the read method to raise it
+ if unexpected_exception and not closed:
+ # Only expect logger error if stream is not closed and there's an
exception
+ with mock.patch("airflow.utils.log.file_task_handler.logger.error") as
mock_logger_error:
+ result = list(_stream_lines_by_chunk(log_io))
+ mock_logger_error.assert_called_once_with("Error reading log
stream: %s", expected_error)
+ else:
+ # For normal case or closed stream with exception, collect the output
+ result = list(_stream_lines_by_chunk(log_io))
+
+ # Check if seekable was called properly
+ if seekable and not closed:
+ log_io.seek.assert_called_once_with(0)
+ if not seekable:
+ log_io.seek.assert_not_called()
+
+ # Validate the results based on the conditions
+ if not closed and not unexpected_exception: # Non-seekable streams
without errors should still get lines
+        assert log_io.read.call_count > 1, "Expected read method to be called more than once."
+ assert result == ["line1", "line2", "line3", "line4"]
+ elif closed:
+ assert log_io.read.call_count == 0, "Read method should not be called
on a closed stream."
+        assert result == [], "Expected no lines to be yielded from a closed stream."
+ elif unexpected_exception: # If an exception was raised
+ assert log_io.read.call_count == 1, "Read method should be called
once."
+        assert result == [], "Expected no lines to be yielded from a stream that raised an exception."
+
+
+def test__log_stream_to_parsed_log_stream():
+ parsed_log_stream =
_log_stream_to_parsed_log_stream(io.StringIO(log_sample))
+
+ actual_timestamps = []
+ last_idx = -1
+ for parsed_log in parsed_log_stream:
+ timestamp, idx, structured_log = parsed_log
+ actual_timestamps.append(timestamp)
+ if last_idx != -1:
+ assert idx > last_idx
+ last_idx = idx
+ assert isinstance(structured_log, StructuredLogMessage)
+
+ assert actual_timestamps == [
pendulum.parse("2022-11-16T00:05:54.278000-08:00"),
pendulum.parse("2022-11-16T00:05:54.278000-08:00"),
pendulum.parse("2022-11-16T00:05:54.278000-08:00"),
@@ -766,34 +904,249 @@ def test_parse_timestamps():
]
+def test__create_sort_key():
+    # _create_sort_key should return an int
+ sort_key =
_create_sort_key(pendulum.parse("2022-11-16T00:05:54.278000-08:00"), 10)
+ assert sort_key == 16685859542780000010
+
+
[email protected](
+ "timestamp, line_num, expected",
+ [
+ pytest.param(
+ pendulum.parse("2022-11-16T00:05:54.278000-08:00"),
+ 10,
+ False,
+ id="normal_timestamp_1",
+ ),
+ pytest.param(
+ pendulum.parse("2022-11-16T00:05:54.457000-08:00"),
+ 2025,
+ False,
+ id="normal_timestamp_2",
+ ),
+ pytest.param(
+ DEFAULT_SORT_DATETIME,
+ 200,
+ True,
+ id="default_timestamp",
+ ),
+ ],
+)
+def test__is_sort_key_with_default_timestamp(timestamp, line_num, expected):
+ assert _is_sort_key_with_default_timestamp(_create_sort_key(timestamp,
line_num)) == expected
+
+
[email protected](
+ "log_stream, expected",
+ [
+ pytest.param(
+ convert_list_to_stream(
+ [
+ "2022-11-16T00:05:54.278000-08:00",
+ "2022-11-16T00:05:54.457000-08:00",
+ ]
+ ),
+ True,
+ id="normal_log_stream",
+ ),
+ pytest.param(
+ itertools.chain(
+ [
+ "2022-11-16T00:05:54.278000-08:00",
+ "2022-11-16T00:05:54.457000-08:00",
+ ],
+ convert_list_to_stream(
+ [
+ "2022-11-16T00:05:54.278000-08:00",
+ "2022-11-16T00:05:54.457000-08:00",
+ ]
+ ),
+ ),
+ True,
+ id="chain_log_stream",
+ ),
+ pytest.param(
+ [
+ "2022-11-16T00:05:54.278000-08:00",
+ "2022-11-16T00:05:54.457000-08:00",
+ ],
+ False,
+ id="non_stream_log",
+ ),
+ ],
+)
+def test__is_logs_stream_like(log_stream, expected):
+ assert _is_logs_stream_like(log_stream) == expected
+
+
+def test__add_log_from_parsed_log_streams_to_heap():
+ """
+ Test cases:
+
+ Timestamp: 26 27 28 29 30 31
+ Source 1: --
+ Source 2: -- --
+ Source 3: -- -- --
+ """
+ heap: list[tuple[int, StructuredLogMessage]] = []
+ input_parsed_log_streams: dict[int, ParsedLogStream] = {
+ 0: convert_list_to_stream(
+ mock_parsed_logs_factory("Source 1",
pendulum.parse("2022-11-16T00:05:54.270000-08:00"), 1)
+ ),
+ 1: convert_list_to_stream(
+ mock_parsed_logs_factory("Source 2",
pendulum.parse("2022-11-16T00:05:54.290000-08:00"), 2)
+ ),
+ 2: convert_list_to_stream(
+ mock_parsed_logs_factory("Source 3",
pendulum.parse("2022-11-16T00:05:54.380000-08:00"), 3)
+ ),
+ }
+
+ # Check that we correctly get the first line of each non-empty log stream
+
+ # First call: should add log records for all log streams
+ _add_log_from_parsed_log_streams_to_heap(heap, input_parsed_log_streams)
+ assert len(input_parsed_log_streams) == 3
+ assert len(heap) == 3
+ # Second call: source 1 is empty, should add log records for source 2 and
source 3
+ _add_log_from_parsed_log_streams_to_heap(heap, input_parsed_log_streams)
+ assert len(input_parsed_log_streams) == 2 # Source 1 should be removed
+ assert len(heap) == 5
+ # Third call: source 1 and source 2 are empty, should add log records for
source 3
+ _add_log_from_parsed_log_streams_to_heap(heap, input_parsed_log_streams)
+ assert len(input_parsed_log_streams) == 1 # Source 2 should be removed
+ assert len(heap) == 6
+ # Fourth call: source 1, source 2, and source 3 are empty, should not add
any log records
+ _add_log_from_parsed_log_streams_to_heap(heap, input_parsed_log_streams)
+ assert len(input_parsed_log_streams) == 0 # Source 3 should be removed
+ assert len(heap) == 6
+ # Fifth call: all sources are empty, should not add any log records
+ assert len(input_parsed_log_streams) == 0 # remains empty
+ assert len(heap) == 6 # no change in heap size
+ # Check heap
+ expected_logs: list[str] = [
+ "Source 1 Event 0",
+ "Source 2 Event 0",
+ "Source 3 Event 0",
+ "Source 2 Event 1",
+ "Source 3 Event 1",
+ "Source 3 Event 2",
+ ]
+ actual_logs: list[str] = []
+ for _ in range(len(heap)):
+ _, log = heapq.heappop(heap)
+ actual_logs.append(log.event)
+ assert actual_logs == expected_logs
+
+
[email protected](
+ "heap_setup, flush_size, last_log, expected_events",
+ [
+ pytest.param(
+ [("msg1", "2023-01-01"), ("msg2", "2023-01-02")],
+ 2,
+ None,
+ ["msg1", "msg2"],
+ id="exact_size_flush",
+ ),
+ pytest.param(
+ [
+ ("msg1", "2023-01-01"),
+ ("msg2", "2023-01-02"),
+ ("msg3", "2023-01-03"),
+ ("msg3", "2023-01-03"),
+ ("msg5", "2023-01-05"),
+ ],
+ 5,
+ None,
+ ["msg1", "msg2", "msg3", "msg5"], # msg3 is deduplicated, msg5
has default timestamp
+ id="flush_with_duplicates",
+ ),
+ pytest.param(
+ [("msg1", "2023-01-01"), ("msg1", "2023-01-01"), ("msg2",
"2023-01-02")],
+ 3,
+ "msg1",
+ ["msg2"], # The last_log is "msg1", so any duplicates of "msg1"
should be skipped
+ id="flush_with_last_log",
+ ),
+ pytest.param(
+ [("msg1", "DEFAULT"), ("msg1", "DEFAULT"), ("msg2", "DEFAULT")],
+ 3,
+ "msg1",
+ [
+ "msg1",
+ "msg1",
+ "msg2",
+ ], # All messages have default timestamp, so they should be
flushed even if last_log is "msg1"
+ id="flush_with_default_timestamp_and_last_log",
+ ),
+ pytest.param(
+ [("msg1", "2023-01-01"), ("msg2", "2023-01-02"), ("msg3",
"2023-01-03")],
+ 2,
+ None,
+ ["msg1", "msg2"], # Only the first two messages should be flushed
+ id="flush_size_smaller_than_heap",
+ ),
+ ],
+)
+def test__flush_logs_out_of_heap(heap_setup, flush_size, last_log,
expected_events):
+ """Test the _flush_logs_out_of_heap function with different scenarios."""
+
+ # Create structured log messages from the test setup
+ heap = []
+ messages = {}
+ for i, (event, timestamp_str) in enumerate(heap_setup):
+ if timestamp_str == "DEFAULT":
+ timestamp = DEFAULT_SORT_DATETIME
+ else:
+ timestamp = pendulum.parse(timestamp_str)
+
+ msg = StructuredLogMessage(event=event, timestamp=timestamp)
+ messages[event] = msg
+ heapq.heappush(heap, (_create_sort_key(msg.timestamp, i), msg))
+
+ # Set last_log if specified in the test case
+ last_log_obj = messages.get(last_log) if last_log is not None else None
+ last_log_container = [last_log_obj]
+
+ # Run the function under test
+ result = list(_flush_logs_out_of_heap(heap, flush_size,
last_log_container))
+
+ # Verify the results
+ assert len(result) == len(expected_events)
+ assert len(heap) == (len(heap_setup) - flush_size)
+ for i, expected_event in enumerate(expected_events):
+ assert result[i].event == expected_event, f"result = {result},
expected_event = {expected_events}"
+
+ # verify that the last log is updated correctly
+ last_log_obj = last_log_container[0]
+ assert last_log_obj is not None
+ last_log_obj = cast("StructuredLogMessage", last_log_obj)
+ assert last_log_obj.event == expected_events[-1]
+
+
def test_interleave_interleaves():
- log_sample1 = "\n".join(
- [
- "[2022-11-16T00:05:54.278-0800] {taskinstance.py:1258} INFO -
Starting attempt 1 of 1",
- ]
- )
- log_sample2 = "\n".join(
- [
- "[2022-11-16T00:05:54.295-0800] {taskinstance.py:1278} INFO -
Executing <Task(TimeDeltaSensorAsync): wait> on 2022-11-16
08:05:52.324532+00:00",
- "[2022-11-16T00:05:54.300-0800] {standard_task_runner.py:55} INFO
- Started process 52536 to run task",
- "[2022-11-16T00:05:54.300-0800] {standard_task_runner.py:55} INFO
- Started process 52536 to run task",
- "[2022-11-16T00:05:54.300-0800] {standard_task_runner.py:55} INFO
- Started process 52536 to run task",
- "[2022-11-16T00:05:54.306-0800] {standard_task_runner.py:82} INFO
- Running: ['airflow', 'tasks', 'run', 'simple_async_timedelta', 'wait',
'manual__2022-11-16T08:05:52.324532+00:00', '--job-id', '33648', '--raw',
'--subdir',
'/Users/dstandish/code/airflow/airflow/example_dags/example_time_delta_sensor_async.py',
'--cfg-path', '/var/folders/7_/1xx0hqcs3txd7kqt0ngfdjth0000gn/T/tmp725r305n']",
- "[2022-11-16T00:05:54.309-0800] {standard_task_runner.py:83} INFO
- Job 33648: Subtask wait",
- ]
- )
- log_sample3 = "\n".join(
- [
- "[2022-11-16T00:05:54.457-0800] {task_command.py:376} INFO -
Running <TaskInstance: simple_async_timedelta.wait
manual__2022-11-16T08:05:52.324532+00:00 [running]> on host daniels-mbp-2.lan",
- "[2022-11-16T00:05:54.592-0800] {taskinstance.py:1485} INFO -
Exporting env vars: AIRFLOW_CTX_DAG_OWNER=airflow",
- "AIRFLOW_CTX_DAG_ID=simple_async_timedelta",
- "AIRFLOW_CTX_TASK_ID=wait",
- "AIRFLOW_CTX_LOGICAL_DATE=2022-11-16T08:05:52.324532+00:00",
- "AIRFLOW_CTX_TRY_NUMBER=1",
- "AIRFLOW_CTX_DAG_RUN_ID=manual__2022-11-16T08:05:52.324532+00:00",
- "[2022-11-16T00:05:54.604-0800] {taskinstance.py:1360} INFO -
Pausing task as DEFERRED. dag_id=simple_async_timedelta, task_id=wait,
execution_date=20221116T080552, start_date=20221116T080554",
- ]
- )
+ log_sample1 = [
+ "[2022-11-16T00:05:54.278-0800] {taskinstance.py:1258} INFO - Starting
attempt 1 of 1",
+ ]
+ log_sample2 = [
+ "[2022-11-16T00:05:54.295-0800] {taskinstance.py:1278} INFO -
Executing <Task(TimeDeltaSensorAsync): wait> on 2022-11-16
08:05:52.324532+00:00",
+ "[2022-11-16T00:05:54.300-0800] {standard_task_runner.py:55} INFO -
Started process 52536 to run task",
+ "[2022-11-16T00:05:54.300-0800] {standard_task_runner.py:55} INFO -
Started process 52536 to run task",
+ "[2022-11-16T00:05:54.300-0800] {standard_task_runner.py:55} INFO -
Started process 52536 to run task",
+ "[2022-11-16T00:05:54.306-0800] {standard_task_runner.py:82} INFO -
Running: ['airflow', 'tasks', 'run', 'simple_async_timedelta', 'wait',
'manual__2022-11-16T08:05:52.324532+00:00', '--job-id', '33648', '--raw',
'--subdir',
'/Users/dstandish/code/airflow/airflow/example_dags/example_time_delta_sensor_async.py',
'--cfg-path', '/var/folders/7_/1xx0hqcs3txd7kqt0ngfdjth0000gn/T/tmp725r305n']",
+ "[2022-11-16T00:05:54.309-0800] {standard_task_runner.py:83} INFO -
Job 33648: Subtask wait",
+ ]
+ log_sample3 = [
+ "[2022-11-16T00:05:54.457-0800] {task_command.py:376} INFO - Running
<TaskInstance: simple_async_timedelta.wait
manual__2022-11-16T08:05:52.324532+00:00 [running]> on host daniels-mbp-2.lan",
+ "[2022-11-16T00:05:54.592-0800] {taskinstance.py:1485} INFO -
Exporting env vars: AIRFLOW_CTX_DAG_OWNER=airflow",
+ "AIRFLOW_CTX_DAG_ID=simple_async_timedelta",
+ "AIRFLOW_CTX_TASK_ID=wait",
+ "AIRFLOW_CTX_LOGICAL_DATE=2022-11-16T08:05:52.324532+00:00",
+ "AIRFLOW_CTX_TRY_NUMBER=1",
+ "AIRFLOW_CTX_DAG_RUN_ID=manual__2022-11-16T08:05:52.324532+00:00",
+ "[2022-11-16T00:05:54.604-0800] {taskinstance.py:1360} INFO - Pausing
task as DEFERRED. dag_id=simple_async_timedelta, task_id=wait,
execution_date=20221116T080552, start_date=20221116T080554",
+ ]
# -08:00
tz = pendulum.tz.fixed_timezone(-28800)
@@ -870,11 +1223,14 @@ def test_interleave_interleaves():
},
]
    # Use a type adapter to turn it into dicts -- makes it easier to compare/test than a bunch of objects
- results = TypeAdapter(list[StructuredLogMessage]).dump_python(
- _interleave_logs(log_sample2, log_sample1, log_sample3)
+ results: list[StructuredLogMessage] = list(
+ _interleave_logs(
+ convert_list_to_stream(log_sample2),
+ convert_list_to_stream(log_sample1),
+ convert_list_to_stream(log_sample3),
+ )
)
- # TypeAdapter gives us a generator out when it's generator is an input.
Nice, but not useful for testing
- results = list(results)
+ results: list[dict] =
TypeAdapter(list[StructuredLogMessage]).dump_python(results)
assert results == expected
@@ -891,7 +1247,13 @@ def test_interleave_logs_correct_ordering():
[2023-01-17T12:47:11.883-0800] {triggerer_job.py:540} INFO - Trigger
<airflow.triggers.temporal.DateTimeTrigger
moment=2023-01-17T20:47:11.254388+00:00> (ID 1) fired:
TriggerEvent<DateTime(2023, 1, 17, 20, 47, 11, 254388, tzinfo=Timezone('UTC'))>
"""
- logs = events(_interleave_logs(sample_with_dupe, "", sample_with_dupe))
+ logs = extract_events(
+ _interleave_logs(
+ convert_list_to_stream(sample_with_dupe.splitlines()),
+ convert_list_to_stream([]),
+ convert_list_to_stream(sample_with_dupe.splitlines()),
+ )
+ )
assert sample_with_dupe == "\n".join(logs)
@@ -907,7 +1269,12 @@ def test_interleave_logs_correct_dedupe():
test,
test"""
- logs = events(_interleave_logs(",\n ".join(["test"] * 10)))
+ input_logs = ",\n ".join(["test"] * 10)
+ logs = extract_events(
+ _interleave_logs(
+ convert_list_to_stream(input_logs.splitlines()),
+ )
+ )
assert sample_without_dupe == "\n".join(logs)
diff --git a/devel-common/src/tests_common/test_utils/file_task_handler.py
b/devel-common/src/tests_common/test_utils/file_task_handler.py
new file mode 100644
index 00000000000..5153fcfc511
--- /dev/null
+++ b/devel-common/src/tests_common/test_utils/file_task_handler.py
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import itertools
+from collections.abc import Generator, Iterable
+from datetime import datetime
+from typing import TYPE_CHECKING
+
+import pendulum
+
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+
+if TYPE_CHECKING:
+ from airflow.utils.log.file_task_handler import ParsedLog,
StructuredLogMessage
+
+
+def extract_events(logs: Iterable[StructuredLogMessage],
skip_source_info=True) -> list[str]:
+ """Helper function to return just the event (a.k.a message) from a list of
StructuredLogMessage"""
+ logs = iter(logs)
+ if skip_source_info:
+
+ def is_source_group(log: StructuredLogMessage) -> bool:
+ return not hasattr(log, "timestamp") or log.event ==
"::endgroup::" or hasattr(log, "sources")
+
+ logs = itertools.dropwhile(is_source_group, logs)
+
+ return [s.event for s in logs]
+
+
+def convert_list_to_stream(input_list: list[str]) -> Generator[str, None,
None]:
+ """
+ Convert a list of strings to a stream-like object.
+ This function yields each string in the list one by one.
+ """
+ yield from input_list
+
+
+def mock_parsed_logs_factory(
+ event_prefix: str,
+ start_datetime: datetime,
+ count: int,
+) -> list[ParsedLog]:
+ """
+ Create a list of ParsedLog objects with the specified start datetime and
count.
+ Each ParsedLog object contains a timestamp and a list of
StructuredLogMessage objects.
+ """
+ if AIRFLOW_V_3_0_PLUS:
+ from airflow.utils.log.file_task_handler import StructuredLogMessage
+
+ return [
+ (
+ pendulum.instance(start_datetime + pendulum.duration(seconds=i)),
+ i,
+ StructuredLogMessage(
+ timestamp=pendulum.instance(start_datetime +
pendulum.duration(seconds=i)),
+ event=f"{event_prefix} Event {i}",
+ ),
+ )
+ for i in range(count)
+ ]
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 62b5b70fa9b..6b20d228bad 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1708,6 +1708,7 @@ StrictUndefined
Stringified
stringified
Struct
+StructuredLogMessage
STS
subchart
subclassed
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
b/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
index b5f6aaa1266..0961b091a58 100644
---
a/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
+++
b/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
@@ -33,7 +33,6 @@ import watchtower
from airflow.configuration import conf
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms
-from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils.log.file_task_handler import FileTaskHandler
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -170,15 +169,7 @@ class CloudWatchRemoteLogIO(LoggingMixin): # noqa: D101
f"Reading remote log from Cloudwatch log_group: {self.log_group}
log_stream: {relative_path}"
]
try:
- if AIRFLOW_V_3_0_PLUS:
- from airflow.utils.log.file_task_handler import
StructuredLogMessage
-
- logs = [
- StructuredLogMessage.model_validate(log)
- for log in self.get_cloudwatch_logs(relative_path, ti)
- ]
- else:
- logs = [self.get_cloudwatch_logs(relative_path, ti)] # type:
ignore[arg-value]
+ logs = [self.get_cloudwatch_logs(relative_path, ti)] # type:
ignore[arg-value]
except Exception as e:
logs = None
messages.append(str(e))
@@ -206,8 +197,6 @@ class CloudWatchRemoteLogIO(LoggingMixin): # noqa: D101
log_stream_name=stream_name,
end_time=end_time,
)
- if AIRFLOW_V_3_0_PLUS:
- return list(self._event_to_dict(e) for e in events)
return "\n".join(self._event_to_str(event) for event in events)
def _event_to_dict(self, event: dict) -> dict:
@@ -222,7 +211,8 @@ class CloudWatchRemoteLogIO(LoggingMixin): # noqa: D101
def _event_to_str(self, event: dict) -> str:
event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0,
tz=timezone.utc)
- formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
+ # Format a datetime object to a string in Zulu time without
milliseconds.
+ formatted_event_dt = event_dt.strftime("%Y-%m-%dT%H:%M:%SZ")
message = event["message"]
return f"[{formatted_event_dt}] {message}"
diff --git
a/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py
b/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py
index 097ccc9eabb..4ca32105256 100644
--- a/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py
+++ b/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py
@@ -25,11 +25,11 @@ from unittest import mock
from unittest.mock import ANY, call
import boto3
+import pendulum
import pytest
import time_machine
from moto import mock_aws
from pydantic import TypeAdapter
-from pydantic_core import TzInfo
from watchtower import CloudWatchLogHandler
from airflow.models import DAG, DagRun, TaskInstance
@@ -50,7 +50,7 @@ from tests_common.test_utils.version_compat import
AIRFLOW_V_3_0_PLUS
def get_time_str(time_in_milliseconds):
dt_time = dt.fromtimestamp(time_in_milliseconds / 1000.0, tz=timezone.utc)
- return dt_time.strftime("%Y-%m-%d %H:%M:%S,000")
+ return dt_time.strftime("%Y-%m-%dT%H:%M:%SZ")
@pytest.fixture(autouse=True)
@@ -148,23 +148,12 @@ class TestCloudRemoteLogIO:
stream_name = self.task_log_path.replace(":", "_")
logs = self.subject.read(stream_name, self.ti)
- if AIRFLOW_V_3_0_PLUS:
- from airflow.utils.log.file_task_handler import
StructuredLogMessage
-
- metadata, logs = logs
+ metadata, logs = logs
- results =
TypeAdapter(list[StructuredLogMessage]).dump_python(logs)
- assert metadata == [
- f"Reading remote log from Cloudwatch log_group:
log_group_name log_stream: {stream_name}"
- ]
- assert results == [
- {
- "event": "Hi",
- "foo": "bar",
- "level": "info",
- "timestamp": datetime(2025, 3, 27, 21, 58, 1, 2000,
tzinfo=TzInfo(0)),
- },
- ]
+ assert metadata == [
+ f"Reading remote log from Cloudwatch log_group: log_group_name
log_stream: {stream_name}"
+ ]
+ assert logs == ['[2025-03-27T21:58:01Z] {"foo": "bar", "event":
"Hi", "level": "info"}']
def test_event_to_str(self):
handler = self.subject
@@ -282,7 +271,11 @@ class TestCloudwatchTaskHandler:
{"timestamp": current_time, "message": "Third"},
],
)
- monkeypatch.setattr(self.cloudwatch_task_handler,
"_read_from_logs_server", lambda a, b: ([], []))
+ monkeypatch.setattr(
+ self.cloudwatch_task_handler,
+ "_read_from_logs_server",
+ lambda ti, worker_log_rel_path: ([], []),
+ )
msg_template = textwrap.dedent("""
INFO - ::group::Log message source details
*** Reading remote log from Cloudwatch log_group: {} log_stream: {}
@@ -294,14 +287,24 @@ class TestCloudwatchTaskHandler:
if AIRFLOW_V_3_0_PLUS:
from airflow.utils.log.file_task_handler import
StructuredLogMessage
+ logs = list(logs)
results = TypeAdapter(list[StructuredLogMessage]).dump_python(logs)
assert results[-4:] == [
{"event": "::endgroup::", "timestamp": None},
- {"event": "First", "timestamp": datetime(2025, 3, 27, 21, 57,
59)},
- {"event": "Second", "timestamp": datetime(2025, 3, 27, 21, 58,
0)},
- {"event": "Third", "timestamp": datetime(2025, 3, 27, 21, 58,
1)},
+ {
+ "event": "[2025-03-27T21:57:59Z] First",
+ "timestamp": pendulum.datetime(2025, 3, 27, 21, 57, 59),
+ },
+ {
+ "event": "[2025-03-27T21:58:00Z] Second",
+ "timestamp": pendulum.datetime(2025, 3, 27, 21, 58, 0),
+ },
+ {
+ "event": "[2025-03-27T21:58:01Z] Third",
+ "timestamp": pendulum.datetime(2025, 3, 27, 21, 58, 1),
+ },
]
- assert metadata["log_pos"] == 3
+ assert metadata == {"end_of_log": False, "log_pos": 3}
else:
events = "\n".join(
[
diff --git a/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py
b/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py
index f254eeeec07..7d58b5e5b30 100644
--- a/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py
+++ b/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py
@@ -264,6 +264,7 @@ class TestS3TaskHandler:
expected_s3_uri = f"s3://bucket/{self.remote_log_key}"
if AIRFLOW_V_3_0_PLUS:
+ log = list(log)
assert log[0].event == "::group::Log message source details"
assert expected_s3_uri in log[0].sources
assert log[1].event == "::endgroup::"
@@ -284,6 +285,7 @@ class TestS3TaskHandler:
self.s3_task_handler._read_from_logs_server =
mock.Mock(return_value=([], []))
log, metadata = self.s3_task_handler.read(ti)
if AIRFLOW_V_3_0_PLUS:
+ log = list(log)
assert len(log) == 2
assert metadata == {"end_of_log": True, "log_pos": 0}
else:
diff --git
a/providers/celery/tests/unit/celery/log_handlers/test_log_handlers.py
b/providers/celery/tests/unit/celery/log_handlers/test_log_handlers.py
index 456f680de41..c95ab43236e 100644
--- a/providers/celery/tests/unit/celery/log_handlers/test_log_handlers.py
+++ b/providers/celery/tests/unit/celery/log_handlers/test_log_handlers.py
@@ -37,6 +37,9 @@ from airflow.utils.timezone import datetime
from airflow.utils.types import DagRunType
from tests_common.test_utils.config import conf_vars
+from tests_common.test_utils.file_task_handler import (
+ convert_list_to_stream,
+)
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
pytestmark = pytest.mark.db_test
@@ -78,14 +81,24 @@ class TestFileTaskLogHandler:
fth = FileTaskHandler("")
fth._read_from_logs_server = mock.Mock()
- fth._read_from_logs_server.return_value = ["this message"],
["this\nlog\ncontent"]
+
+ # compat with 2.x and 3.x
+ if AIRFLOW_V_3_0_PLUS:
+ fth._read_from_logs_server.return_value = (
+ ["this message"],
+ [convert_list_to_stream(["this", "log", "content"])],
+ )
+ else:
+ fth._read_from_logs_server.return_value = ["this message"],
["this\nlog\ncontent"]
+
logs, metadata = fth._read(ti=ti, try_number=1)
fth._read_from_logs_server.assert_called_once()
if AIRFLOW_V_3_0_PLUS:
- assert metadata == {"end_of_log": False, "log_pos": 3}
+ logs = list(logs)
assert logs[0].sources == ["this message"]
assert [x.event for x in logs[-3:]] == ["this", "log", "content"]
+ assert metadata == {"end_of_log": False, "log_pos": 3}
else:
assert "*** this message\n" in logs
assert logs.endswith("this\nlog\ncontent")
diff --git
a/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py
b/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py
index 420f33b67f0..3c6124603c7 100644
---
a/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py
+++
b/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py
@@ -29,7 +29,7 @@ import time
from collections import defaultdict
from collections.abc import Callable
from operator import attrgetter
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, cast
from urllib.parse import quote, urlparse
# Using `from elasticsearch import *` would break elasticsearch mocking used
in unit test.
@@ -56,6 +56,7 @@ if TYPE_CHECKING:
from datetime import datetime
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+ from airflow.utils.log.file_task_handler import LogMetadata
LOG_LINE_DEFAULTS = {"exc_text": "", "stack_info": ""}
@@ -294,8 +295,8 @@ class ElasticsearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMix
return True
def _read(
- self, ti: TaskInstance, try_number: int, metadata: dict | None = None
- ) -> tuple[EsLogMsgType, dict]:
+ self, ti: TaskInstance, try_number: int, metadata: LogMetadata | None
= None
+ ) -> tuple[EsLogMsgType, LogMetadata]:
"""
Endpoint for streaming log.
@@ -306,7 +307,9 @@ class ElasticsearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMix
:return: a list of tuple with host and log documents, metadata.
"""
if not metadata:
- metadata = {"offset": 0}
+ # LogMetadata(TypedDict) is used as type annotation for
log_reader; added ignore to suppress mypy error
+ metadata = {"offset": 0} # type: ignore[assignment]
+ metadata = cast("LogMetadata", metadata)
if "offset" not in metadata:
metadata["offset"] = 0
@@ -346,7 +349,9 @@ class ElasticsearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMix
"Otherwise, the logs for this task instance may have been
removed."
)
if AIRFLOW_V_3_0_PLUS:
- return missing_log_message, metadata
+ from airflow.utils.log.file_task_handler import
StructuredLogMessage
+
+ return [StructuredLogMessage(event=missing_log_message)],
metadata
return [("", missing_log_message)], metadata # type:
ignore[list-item]
if (
# Assume end of log after not receiving new log for N min,
diff --git
a/providers/elasticsearch/tests/unit/elasticsearch/log/test_es_task_handler.py
b/providers/elasticsearch/tests/unit/elasticsearch/log/test_es_task_handler.py
index be3e36c8f90..2d90c110832 100644
---
a/providers/elasticsearch/tests/unit/elasticsearch/log/test_es_task_handler.py
+++
b/providers/elasticsearch/tests/unit/elasticsearch/log/test_es_task_handler.py
@@ -208,6 +208,7 @@ class TestElasticsearchTaskHandler:
)
if AIRFLOW_V_3_0_PLUS:
+ logs = list(logs)
assert logs[0].event == "::group::Log message source details"
assert logs[0].sources == ["default_host"]
assert logs[1].event == "::endgroup::"
@@ -235,6 +236,7 @@ class TestElasticsearchTaskHandler:
)
if AIRFLOW_V_3_0_PLUS:
+ logs = list(logs)
assert logs[0].event == "::group::Log message source details"
assert logs[0].sources == ["default_host"]
assert logs[1].event == "::endgroup::"
@@ -304,10 +306,11 @@ class TestElasticsearchTaskHandler:
ts = pendulum.now().add(seconds=-seconds)
logs, metadatas = self.es_task_handler.read(ti, 1, {"offset": 0,
"last_log_timestamp": str(ts)})
if AIRFLOW_V_3_0_PLUS:
+ logs = list(logs)
if seconds > 5:
# we expect a log not found message when checking began more
than 5 seconds ago
expected_pattern = r"^\*\*\* Log .* not found in
Elasticsearch.*"
- assert re.match(expected_pattern, logs) is not None
+ assert re.match(expected_pattern, logs[0].event) is not None
assert metadatas["end_of_log"] is True
else:
# we've "waited" less than 5 seconds so it should not be "end
of log" and should be no log message
@@ -360,6 +363,7 @@ class TestElasticsearchTaskHandler:
},
)
if AIRFLOW_V_3_0_PLUS:
+ logs = list(logs)
assert logs[0].event == "::group::Log message source details"
assert logs[0].sources == ["default_host"]
assert logs[1].event == "::endgroup::"
@@ -382,6 +386,7 @@ class TestElasticsearchTaskHandler:
def test_read_with_none_metadata(self, ti):
logs, metadatas = self.es_task_handler.read(ti, 1)
if AIRFLOW_V_3_0_PLUS:
+ logs = list(logs)
assert logs[0].event == "::group::Log message source details"
assert logs[0].sources == ["default_host"]
assert logs[1].event == "::endgroup::"
@@ -431,6 +436,7 @@ class TestElasticsearchTaskHandler:
ts = pendulum.now()
logs, metadatas = self.es_task_handler.read(ti, 1, {})
if AIRFLOW_V_3_0_PLUS:
+ logs = list(logs)
assert logs[0].event == "::group::Log message source details"
assert logs[0].sources == ["default_host"]
assert logs[1].event == "::endgroup::"
@@ -520,6 +526,7 @@ class TestElasticsearchTaskHandler:
},
)
if AIRFLOW_V_3_0_PLUS:
+ logs = list(logs)
assert logs[0].event == "::group::Log message source details"
assert logs[0].sources == ["default_host"]
assert logs[1].event == "::endgroup::"
@@ -601,6 +608,7 @@ class TestElasticsearchTaskHandler:
)
expected_message = "[2020-12-24 19:25:00,962] {taskinstance.py:851}
INFO - some random stuff - "
if AIRFLOW_V_3_0_PLUS:
+ logs = list(logs)
assert logs[2].event == expected_message
else:
assert logs[0][0][1] == expected_message
@@ -634,6 +642,7 @@ class TestElasticsearchTaskHandler:
)
expected_message = "[2020-12-24 19:25:00,962] {taskinstance.py:851}
INFO - some random stuff - "
if AIRFLOW_V_3_0_PLUS:
+ logs = list(logs)
assert logs[2].event == expected_message
else:
assert logs[0][0][1] == expected_message
diff --git
a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py
b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py
index 205bcfc5cb1..cbe6df81c75 100644
--- a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py
+++ b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py
@@ -114,6 +114,7 @@ class TestGCSTaskHandler:
mock_blob.from_string.assert_called_once_with(expected_gs_uri,
mock_client.return_value)
if AIRFLOW_V_3_0_PLUS:
+ logs = list(logs)
assert logs[0].event == "::group::Log message source details"
assert logs[0].sources == [expected_gs_uri]
assert logs[1].event == "::endgroup::"
@@ -143,6 +144,7 @@ class TestGCSTaskHandler:
expected_gs_uri = f"gs://bucket/{mock_obj.name}"
if AIRFLOW_V_3_0_PLUS:
+ log = list(log)
assert log[0].event == "::group::Log message source details"
assert log[0].sources == [
expected_gs_uri,
diff --git
a/providers/microsoft/azure/tests/unit/microsoft/azure/log/test_wasb_task_handler.py
b/providers/microsoft/azure/tests/unit/microsoft/azure/log/test_wasb_task_handler.py
index ba7fe5f1b2a..dc87b5a9da7 100644
---
a/providers/microsoft/azure/tests/unit/microsoft/azure/log/test_wasb_task_handler.py
+++
b/providers/microsoft/azure/tests/unit/microsoft/azure/log/test_wasb_task_handler.py
@@ -117,14 +117,12 @@ class TestWasbTaskHandler:
logs, metadata = self.wasb_task_handler.read(ti)
if AIRFLOW_V_3_0_PLUS:
+ logs = list(logs)
assert logs[0].event == "::group::Log message source details"
assert logs[0].sources ==
["https://wasb-container.blob.core.windows.net/abc/hello.log"]
assert logs[1].event == "::endgroup::"
assert logs[2].event == "Log line"
- assert metadata == {
- "end_of_log": True,
- "log_pos": 1,
- }
+ assert metadata == {"end_of_log": True, "log_pos": 1}
else:
assert logs[0][0][0] == "localhost"
assert (
diff --git
a/providers/opensearch/src/airflow/providers/opensearch/log/os_task_handler.py
b/providers/opensearch/src/airflow/providers/opensearch/log/os_task_handler.py
index 58b328cfefe..e770d7bb489 100644
---
a/providers/opensearch/src/airflow/providers/opensearch/log/os_task_handler.py
+++
b/providers/opensearch/src/airflow/providers/opensearch/log/os_task_handler.py
@@ -25,7 +25,7 @@ from collections import defaultdict
from collections.abc import Callable
from datetime import datetime
from operator import attrgetter
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, cast
import pendulum
from opensearchpy import OpenSearch
@@ -45,6 +45,7 @@ from airflow.utils.session import create_session
if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+ from airflow.utils.log.file_task_handler import LogMetadata
if AIRFLOW_V_3_0_PLUS:
@@ -333,8 +334,8 @@ class OpensearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMixin)
)
def _read(
- self, ti: TaskInstance, try_number: int, metadata: dict | None = None
- ) -> tuple[OsLogMsgType, dict]:
+ self, ti: TaskInstance, try_number: int, metadata: LogMetadata | None
= None
+ ) -> tuple[OsLogMsgType, LogMetadata]:
"""
Endpoint for streaming log.
@@ -345,7 +346,10 @@ class OpensearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMixin)
:return: a list of tuple with host and log documents, metadata.
"""
if not metadata:
- metadata = {"offset": 0}
+ # LogMetadata(TypedDict) is used as type annotation for
log_reader; added ignore to suppress mypy error
+ metadata = {"offset": 0} # type: ignore[assignment]
+ metadata = cast("LogMetadata", metadata)
+
if "offset" not in metadata:
metadata["offset"] = 0
@@ -384,6 +388,12 @@ class OpensearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMixin)
"If your task started recently, please wait a moment and
reload this page. "
"Otherwise, the logs for this task instance may have been
removed."
)
+ if AIRFLOW_V_3_0_PLUS:
+ from airflow.utils.log.file_task_handler import
StructuredLogMessage
+
+ # return list of StructuredLogMessage for Airflow 3.0+
+ return [StructuredLogMessage(event=missing_log_message)],
metadata
+
return [("", missing_log_message)], metadata # type:
ignore[list-item]
if (
# Assume end of log after not receiving new log for N min,
diff --git
a/providers/opensearch/tests/unit/opensearch/log/test_os_task_handler.py
b/providers/opensearch/tests/unit/opensearch/log/test_os_task_handler.py
index fb51c56e469..4faeb2223f9 100644
--- a/providers/opensearch/tests/unit/opensearch/log/test_os_task_handler.py
+++ b/providers/opensearch/tests/unit/opensearch/log/test_os_task_handler.py
@@ -204,6 +204,7 @@ class TestOpensearchTaskHandler:
"on 2023-07-09 07:47:32+00:00"
)
if AIRFLOW_V_3_0_PLUS:
+ logs = list(logs)
assert logs[0].event == "::group::Log message source details"
assert logs[0].sources == ["default_host"]
assert logs[1].event == "::endgroup::"
@@ -235,6 +236,7 @@ class TestOpensearchTaskHandler:
"on 2023-07-09 07:47:32+00:00"
)
if AIRFLOW_V_3_0_PLUS:
+ logs = list(logs)
assert logs[0].event == "::group::Log message source details"
assert logs[0].sources == ["default_host"]
assert logs[1].event == "::endgroup::"
@@ -332,10 +334,11 @@ class TestOpensearchTaskHandler:
):
logs, metadatas = self.os_task_handler.read(ti, 1, {"offset": 0,
"last_log_timestamp": str(ts)})
if AIRFLOW_V_3_0_PLUS:
+ logs = list(logs)
if seconds > 5:
# we expect a log not found message when checking began more
than 5 seconds ago
- assert len(logs[0]) == 2
- actual_message = logs[0][1]
+ assert len(logs) == 1
+ actual_message = logs[0].event
expected_pattern = r"^\*\*\* Log .* not found in Opensearch.*"
assert re.match(expected_pattern, actual_message) is not None
assert metadatas["end_of_log"] is True
@@ -374,6 +377,7 @@ class TestOpensearchTaskHandler:
"on 2023-07-09 07:47:32+00:00"
)
if AIRFLOW_V_3_0_PLUS:
+ logs = list(logs)
assert logs[0].event == "::group::Log message source details"
assert logs[0].sources == ["default_host"]
assert logs[1].event == "::endgroup::"
diff --git
a/providers/redis/src/airflow/providers/redis/log/redis_task_handler.py
b/providers/redis/src/airflow/providers/redis/log/redis_task_handler.py
index d1fd1cc8de0..bdc5b6ed0f8 100644
--- a/providers/redis/src/airflow/providers/redis/log/redis_task_handler.py
+++ b/providers/redis/src/airflow/providers/redis/log/redis_task_handler.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import logging
from functools import cached_property
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
from airflow.configuration import conf
from airflow.providers.redis.hooks.redis import RedisHook
@@ -31,6 +31,7 @@ if TYPE_CHECKING:
from redis import Redis
from airflow.models import TaskInstance
+ from airflow.utils.log.file_task_handler import LogMetadata
class RedisTaskHandler(FileTaskHandler, LoggingMixin):
@@ -75,8 +76,8 @@ class RedisTaskHandler(FileTaskHandler, LoggingMixin):
self,
ti: TaskInstance,
try_number: int,
- metadata: dict[str, Any] | None = None,
- ):
+ metadata: LogMetadata | None = None,
+ ) -> tuple[str | list[str], LogMetadata]:
log_str = b"\n".join(
self.conn.lrange(self._render_filename(ti, try_number), start=0,
end=-1)
).decode()
diff --git a/providers/redis/tests/unit/redis/log/test_redis_task_handler.py
b/providers/redis/tests/unit/redis/log/test_redis_task_handler.py
index 99bb497c563..2a64f8d7674 100644
--- a/providers/redis/tests/unit/redis/log/test_redis_task_handler.py
+++ b/providers/redis/tests/unit/redis/log/test_redis_task_handler.py
@@ -31,7 +31,11 @@ from airflow.utils.state import State
from airflow.utils.timezone import datetime
from tests_common.test_utils.config import conf_vars
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+from tests_common.test_utils.file_task_handler import extract_events
+from tests_common.test_utils.version_compat import (
+ AIRFLOW_V_3_0_PLUS,
+ get_base_airflow_version_tuple,
+)
class TestRedisTaskHandler:
@@ -120,7 +124,12 @@ class TestRedisTaskHandler:
logs = handler.read(ti)
if AIRFLOW_V_3_0_PLUS:
- assert logs == (["Line 1\nLine 2"], {"end_of_log": True})
+ if get_base_airflow_version_tuple() < (3, 1, 0):
+ assert logs == (["Line 1\nLine 2"], {"end_of_log": True})
+ else:
+ log_stream, metadata = logs
+ assert extract_events(log_stream) == ["Line 1", "Line 2"]
+ assert metadata == {"end_of_log": True}
else:
assert logs == ([[("", "Line 1\nLine 2")]], [{"end_of_log": True}])
lrange.assert_called_once_with(key, start=0, end=-1)