This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new aeec6818178 fix the write-to-es feature for Airflow 3 (#53821)
aeec6818178 is described below
commit aeec681817826fe744745197d97f9c32c39118d1
Author: Owen Leung <[email protected]>
AuthorDate: Fri Feb 20 16:31:19 2026 +0800
fix the write-to-es feature for Airflow 3 (#53821)
* Introduce ElasticsearchRemoteLogIO
* Fix ruff check on test_es_task_handler
* Add ElasticsearchRemoteLogIO into the TaskHandler. Refactor to handle read/write on the RemoteLogIO class
---------
Co-authored-by: Jason(Zhe-You) Liu <[email protected]>
---
providers/elasticsearch/pyproject.toml | 1 +
.../providers/elasticsearch/log/es_response.py | 28 ++
.../providers/elasticsearch/log/es_task_handler.py | 537 ++++++++++++---------
.../providers/elasticsearch/version_compat.py | 3 +-
providers/elasticsearch/tests/conftest.py | 14 +
.../log/elasticmock/fake_elasticsearch.py | 131 ++++-
.../unit/elasticsearch/log/test_es_task_handler.py | 251 +++++++---
7 files changed, 657 insertions(+), 308 deletions(-)
diff --git a/providers/elasticsearch/pyproject.toml b/providers/elasticsearch/pyproject.toml
index 9894c9096a7..b4fb390fe60 100644
--- a/providers/elasticsearch/pyproject.toml
+++ b/providers/elasticsearch/pyproject.toml
@@ -73,6 +73,7 @@ dev = [
"apache-airflow-providers-common-sql",
     # Additional devel dependencies (do not remove this line and add extra development dependencies)
"apache-airflow-providers-common-sql[pandas,polars]",
+ "testcontainers==4.12.0"
]
# To build docs:
diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_response.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_response.py
index 610b03f96e1..2af39ce7364 100644
--- a/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_response.py
+++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_response.py
@@ -17,6 +17,7 @@
from __future__ import annotations
from collections.abc import Iterator
+from typing import Any
def _wrap(val):
@@ -25,6 +26,33 @@ def _wrap(val):
return val
+def resolve_nested(hit: dict[Any, Any], parent_class=None) -> type[Hit]:
+    """
+    Resolve nested hits from Elasticsearch by iteratively navigating the `_nested` field.
+
+    The result is used to fetch the appropriate document class to handle the hit.
+
+    This function can be used with nested Elasticsearch fields which are structured
+    as dictionaries with "field" and "_nested" keys.
+    """
+    doc_class = Hit
+
+    nested_path: list[str] = []
+    nesting = hit["_nested"]
+    while nesting and "field" in nesting:
+        nested_path.append(nesting["field"])
+        nesting = nesting.get("_nested")
+    nested_path_str = ".".join(nested_path)
+
+    if hasattr(parent_class, "_index"):
+        nested_field = parent_class._index.resolve_field(nested_path_str)
+
+        if nested_field is not None:
+            return nested_field._doc_class
+
+    return doc_class
+
+
class AttributeList:
"""Helper class to provide attribute like access to List objects."""
diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py
index 4e76e767eb6..3b5480ff60e 100644
--- a/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py
+++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py
@@ -22,38 +22,45 @@ import inspect
import json
import logging
import os
-import pathlib
import shutil
import sys
import time
from collections import defaultdict
from collections.abc import Callable
from operator import attrgetter
+from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, cast
from urllib.parse import quote, urlparse
+import attrs
+
 # Using `from elasticsearch import *` would break elasticsearch mocking used in unit test.
import elasticsearch
import pendulum
from elasticsearch import helpers
from elasticsearch.exceptions import NotFoundError
-from sqlalchemy import select
+import airflow.logging_config as alc
+from airflow.configuration import conf
from airflow.models.dagrun import DagRun
-from airflow.providers.common.compat.module_loading import import_string
-from airflow.providers.common.compat.sdk import AirflowException, conf, timezone
 from airflow.providers.elasticsearch.log.es_json_formatter import ElasticsearchJSONFormatter
-from airflow.providers.elasticsearch.log.es_response import ElasticSearchResponse, Hit
-from airflow.providers.elasticsearch.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.providers.elasticsearch.log.es_response import ElasticSearchResponse, Hit, resolve_nested
+from airflow.providers.elasticsearch.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS
+from airflow.utils import timezone
from airflow.utils.log.file_task_handler import FileTaskHandler
from airflow.utils.log.logging_mixin import ExternalLoggingMixin, LoggingMixin
-from airflow.utils.session import create_session
+
+if AIRFLOW_V_3_2_PLUS:
+ from airflow._shared.module_loading import import_string
+else:
+    from airflow.utils.module_loading import import_string  # type: ignore[no-redef]
if TYPE_CHECKING:
from datetime import datetime
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
- from airflow.utils.log.file_task_handler import LogMetadata
+ from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
+    from airflow.utils.log.file_task_handler import LogMessages, LogMetadata, LogSourceInfo
if AIRFLOW_V_3_0_PLUS:
@@ -89,28 +96,40 @@ def get_es_kwargs_from_config() -> dict[str, Any]:
return kwargs_dict
-def _ensure_ti(ti: TaskInstanceKey | TaskInstance, session) -> TaskInstance:
+def getattr_nested(obj, item, default):
"""
- Given TI | TIKey, return a TI object.
+ Get item from obj but return default if not found.
+
+ E.g. calling ``getattr_nested(a, 'b.c', "NA")`` will return
+ ``a.b.c`` if such a value exists, and "NA" otherwise.
- Will raise exception if no TI is found in the database.
+ :meta private:
"""
- from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+ try:
+ return attrgetter(item)(obj)
+ except AttributeError:
+ return default
- if not isinstance(ti, TaskInstanceKey):
- return ti
- val = session.scalar(
- select(TaskInstance).where(
- TaskInstance.task_id == ti.task_id,
- TaskInstance.dag_id == ti.dag_id,
- TaskInstance.run_id == ti.run_id,
- TaskInstance.map_index == ti.map_index,
- )
+
+def _render_log_id(log_id_template: str, ti: TaskInstance | TaskInstanceKey, try_number: int) -> str:
+ return log_id_template.format(
+ dag_id=ti.dag_id,
+ task_id=ti.task_id,
+ run_id=getattr(ti, "run_id", ""),
+ try_number=try_number,
+ map_index=getattr(ti, "map_index", ""),
)
- if isinstance(val, TaskInstance):
- val.try_number = ti.try_number
- return val
- raise AirflowException(f"Could not find TaskInstance for {ti}")
+
+
+def _clean_date(value: datetime | None) -> str:
+ """
+    Clean up a date value so that it is safe to query in elasticsearch by removing reserved characters.
+
+    https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-query-string-query.html#_reserved_characters
+ """
+ if value is None:
+ return ""
+ return value.strftime("%Y_%m_%dT%H_%M_%S_%f")
 class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMixin):
@@ -148,8 +167,8 @@ class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMix
base_log_folder: str,
end_of_log_mark: str,
write_stdout: bool,
- json_format: bool,
json_fields: str,
+ json_format: bool = False,
write_to_es: bool = False,
target_index: str = "airflow-logs",
host_field: str = "host",
@@ -198,6 +217,33 @@ class ElasticsearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMix
self.handler: logging.FileHandler | logging.StreamHandler | None = None
self._doc_type_map: dict[Any, Any] = {}
self._doc_type: list[Any] = []
+ self.log_id_template: str = conf.get(
+ "elasticsearch",
+ "log_id_template",
+ fallback="{dag_id}-{task_id}-{run_id}-{map_index}-{try_number}",
+ )
+ self.io = ElasticsearchRemoteLogIO(
+ host=self.host,
+ target_index=self.target_index,
+ write_stdout=self.write_stdout,
+ write_to_es=self.write_to_es,
+ offset_field=self.offset_field,
+ host_field=self.host_field,
+ base_log_folder=base_log_folder,
+ delete_local_copy=self.delete_local_copy,
+ log_id_template=self.log_id_template,
+ )
+        # Airflow 3 introduces REMOTE_TASK_LOG for handling remote logging.
+        # REMOTE_TASK_LOG should be explicitly set in airflow_local_settings.py when trying to use ESTaskHandler.
+        # Before Airflow 3.1, REMOTE_TASK_LOG is not set when trying to use the ES TaskHandler.
+ if AIRFLOW_V_3_0_PLUS:
+ if AIRFLOW_V_3_2_PLUS:
+                from airflow.logging_config import _ActiveLoggingConfig, get_remote_task_log
+
+ if get_remote_task_log() is None:
+ _ActiveLoggingConfig.set(self.io, None)
+ elif alc.REMOTE_TASK_LOG is None: # type: ignore[attr-defined]
+ alc.REMOTE_TASK_LOG = self.io # type: ignore[attr-defined]
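(Not part of the patch, for illustration only: the comment above says REMOTE_TASK_LOG should be set explicitly in airflow_local_settings.py on Airflow 3.0/3.1. A hypothetical sketch of such a logging-settings module, with placeholder host and folder values, might look like this.)

    # airflow_local_settings.py -- hypothetical sketch, assuming Airflow 3.0/3.1 where the
    # logging settings module exposes a module-level REMOTE_TASK_LOG attribute.
    from pathlib import Path

    from airflow.providers.elasticsearch.log.es_task_handler import ElasticsearchRemoteLogIO

    REMOTE_TASK_LOG = ElasticsearchRemoteLogIO(
        host="http://localhost:9200",               # placeholder
        write_to_es=True,
        base_log_folder=Path("/opt/airflow/logs"),  # placeholder
    )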
@staticmethod
def format_url(host: str) -> str:
@@ -221,70 +267,6 @@ class ElasticsearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMix
return host
- def _get_index_patterns(self, ti: TaskInstance | None) -> str:
- """
- Get index patterns by calling index_patterns_callable, if provided, or
the configured index_patterns.
-
- :param ti: A TaskInstance object or None.
- """
- if self.index_patterns_callable:
- self.log.debug("Using index_patterns_callable: %s",
self.index_patterns_callable)
- index_pattern_callable_obj =
import_string(self.index_patterns_callable)
- return index_pattern_callable_obj(ti)
- self.log.debug("Using index_patterns: %s", self.index_patterns)
- return self.index_patterns
-
- def _render_log_id(self, ti: TaskInstance | TaskInstanceKey, try_number:
int) -> str:
- from airflow.models.taskinstance import TaskInstanceKey
-
- with create_session() as session:
- if isinstance(ti, TaskInstanceKey):
- ti = _ensure_ti(ti, session)
- dag_run = ti.get_dagrun(session=session)
- if USE_PER_RUN_LOG_ID:
- log_id_template =
dag_run.get_log_template(session=session).elasticsearch_id
-
- if self.json_format:
- data_interval_start = self._clean_date(dag_run.data_interval_start)
- data_interval_end = self._clean_date(dag_run.data_interval_end)
- logical_date = self._clean_date(dag_run.logical_date)
- else:
- data_interval_start = (
- dag_run.data_interval_start.isoformat() if
dag_run.data_interval_start else ""
- )
- data_interval_end = dag_run.data_interval_end.isoformat() if
dag_run.data_interval_end else ""
- logical_date = dag_run.logical_date.isoformat() if
dag_run.logical_date else ""
-
- return log_id_template.format(
- dag_id=ti.dag_id,
- task_id=ti.task_id,
- run_id=getattr(ti, "run_id", ""),
- data_interval_start=data_interval_start,
- data_interval_end=data_interval_end,
- logical_date=logical_date,
- execution_date=logical_date,
- try_number=try_number,
- map_index=getattr(ti, "map_index", ""),
- )
-
- @staticmethod
- def _clean_date(value: datetime | None) -> str:
- """
- Clean up a date value so that it is safe to query in elasticsearch by
removing reserved characters.
-
-
https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-query-string-query.html#_reserved_characters
- """
- if value is None:
- return ""
- return value.strftime("%Y_%m_%dT%H_%M_%S_%f")
-
- def _group_logs_by_host(self, response: ElasticSearchResponse) ->
dict[str, list[Hit]]:
- grouped_logs = defaultdict(list)
- for hit in response:
- key = getattr_nested(hit, self.host_field, None) or self.host
- grouped_logs[key].append(hit)
- return grouped_logs
-
def _read_grouped_logs(self):
return True
@@ -308,15 +290,15 @@ class ElasticsearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMix
metadata["offset"] = 0
offset = metadata["offset"]
- log_id = self._render_log_id(ti, try_number)
- response = self._es_read(log_id, offset, ti)
+ log_id = _render_log_id(self.log_id_template, ti, try_number)
+ response = self.io._es_read(log_id, offset, ti)
+ # TODO: Can we skip group logs by host ?
if response is not None and response.hits:
- logs_by_host = self._group_logs_by_host(response)
+ logs_by_host = self.io._group_logs_by_host(response)
next_offset = attrgetter(self.offset_field)(response[-1])
else:
logs_by_host = None
next_offset = offset
-
         # Ensure a string here. Large offset numbers will get JSON.parsed incorrectly
         # on the client. Sending as a string prevents this issue.
         # https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Number/MAX_SAFE_INTEGER
@@ -326,7 +308,10 @@ class ElasticsearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMix
# have the log uploaded but will not be stored in elasticsearch.
metadata["end_of_log"] = False
if logs_by_host:
-            if any(x[-1].message == self.end_of_log_mark for x in logs_by_host.values()):
+            end_mark_found = any(
+                self._get_log_message(x[-1]) == self.end_of_log_mark for x in logs_by_host.values()
+            )
+            if end_mark_found:
metadata["end_of_log"] = True
cur_ts = pendulum.now()
@@ -358,12 +343,6 @@ class ElasticsearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMix
         if int(offset) != int(next_offset) or "last_log_timestamp" not in metadata:
metadata["last_log_timestamp"] = str(cur_ts)
- # If we hit the end of the log, remove the actual end_of_log message
- # to prevent it from showing in the UI.
- def concat_logs(hits: list[Hit]) -> str:
- log_range = (len(hits) - 1) if hits[-1].message ==
self.end_of_log_mark else len(hits)
- return "\n".join(self._format_msg(hits[i]) for i in
range(log_range))
-
if logs_by_host:
if AIRFLOW_V_3_0_PLUS:
                 from airflow.utils.log.file_task_handler import StructuredLogMessage
@@ -386,11 +365,12 @@ class ElasticsearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMix
]
else:
message = [
- (host, concat_logs(hits)) # type: ignore[misc]
+ (host, self.concat_logs(hits)) # type: ignore[misc]
for host, hits in logs_by_host.items()
]
else:
message = []
+ metadata["end_of_log"] = True
return message, metadata
def _format_msg(self, hit: Hit):
@@ -404,46 +384,7 @@ class ElasticsearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMix
)
# Just a safe-guard to preserve backwards-compatibility
- return hit.message
-
- def _es_read(self, log_id: str, offset: int | str, ti: TaskInstance) ->
ElasticSearchResponse | None:
- """
- Return the logs matching log_id in Elasticsearch and next offset or ''.
-
- :param log_id: the log_id of the log to read.
- :param offset: the offset start to read log from.
- :param ti: the task instance object
-
- :meta private:
- """
- query: dict[Any, Any] = {
- "bool": {
- "filter": [{"range": {self.offset_field: {"gt":
int(offset)}}}],
- "must": [{"match_phrase": {"log_id": log_id}}],
- }
- }
-
- index_patterns = self._get_index_patterns(ti)
- try:
- max_log_line = self.client.count(index=index_patterns,
query=query)["count"]
- except NotFoundError as e:
- self.log.exception("The target index pattern %s does not exist",
index_patterns)
- raise e
-
- if max_log_line != 0:
- try:
- res = self.client.search(
- index=index_patterns,
- query=query,
- sort=[self.offset_field],
- size=self.MAX_LINE_PER_PAGE,
- from_=self.MAX_LINE_PER_PAGE * self.PAGE,
- )
- return ElasticSearchResponse(self, res)
- except Exception as err:
- self.log.exception("Could not read log with log_id: %s.
Exception: %s", log_id, err)
-
- return None
+ return self._get_log_message(hit)
def emit(self, record):
if self.handler:
@@ -452,6 +393,8 @@ class ElasticsearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMix
     def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None:
"""
+ TODO: This API should be removed in airflow 3.
+
Provide task_instance context to airflow task handler.
:param ti: task instance object
@@ -470,12 +413,10 @@ class ElasticsearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMix
"dag_id": str(ti.dag_id),
"task_id": str(ti.task_id),
date_key: (
- self._clean_date(ti.logical_date)
- if AIRFLOW_V_3_0_PLUS
- else self._clean_date(ti.execution_date)
+                    _clean_date(ti.logical_date) if AIRFLOW_V_3_0_PLUS else _clean_date(ti.execution_date)
),
"try_number": str(ti.try_number),
- "log_id": self._render_log_id(ti, ti.try_number),
+                "log_id": _render_log_id(self.log_id_template, ti, ti.try_number),
},
)
@@ -497,6 +438,7 @@ class ElasticsearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMix
# calling close method. Here we check if logger is already
# closed to prevent uploading the log to remote storage multiple
# times when `logging.shutdown` is called.
+        # TODO: This API should be simplified since Airflow 3 no longer requires it for writing logs to ES.
if self.closed:
return
@@ -519,22 +461,10 @@ class ElasticsearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMix
# so we know where to stop while auto-tailing.
self.emit(logging.makeLogRecord({"msg": self.end_of_log_mark}))
- if self.write_stdout:
+ if self.io.write_stdout:
self.handler.close()
sys.stdout = sys.__stdout__
- if self.write_to_es and not self.write_stdout:
- full_path = self.handler.baseFilename # type: ignore[union-attr]
- log_relative_path =
pathlib.Path(full_path).relative_to(self.local_base).as_posix()
- local_loc = os.path.join(self.local_base, log_relative_path)
- if os.path.exists(local_loc):
- # read log and remove old logs to get just the latest additions
- log = pathlib.Path(local_loc).read_text()
- log_lines = self._parse_raw_log(log)
- success = self._write_to_es(log_lines)
- if success and self.delete_local_copy:
- shutil.rmtree(os.path.dirname(local_loc))
-
super().close()
self.closed = True
@@ -552,7 +482,7 @@ class ElasticsearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMix
:param try_number: task instance try_number to read logs from.
:return: URL to the external log collection service
"""
- log_id = self._render_log_id(task_instance, try_number)
+        log_id = _render_log_id(self.log_id_template, task_instance, try_number)
scheme = "" if "://" in self.frontend else "https://"
return scheme + self.frontend.format(log_id=quote(log_id))
@@ -561,38 +491,12 @@ class ElasticsearchTaskHandler(FileTaskHandler,
ExternalLoggingMixin, LoggingMix
"""Whether we can support external links."""
return bool(self.frontend)
- def _resolve_nested(self, hit: dict[Any, Any], parent_class=None) ->
type[Hit]:
- """
- Resolve nested hits from Elasticsearch by iteratively navigating the
`_nested` field.
-
- The result is used to fetch the appropriate document class to handle
the hit.
-
- This method can be used with nested Elasticsearch fields which are
structured
- as dictionaries with "field" and "_nested" keys.
- """
- doc_class = Hit
-
- nested_path: list[str] = []
- nesting = hit["_nested"]
- while nesting and "field" in nesting:
- nested_path.append(nesting["field"])
- nesting = nesting.get("_nested")
- nested_path_str = ".".join(nested_path)
-
- if hasattr(parent_class, "_index"):
- nested_field = parent_class._index.resolve_field(nested_path_str)
-
- if nested_field is not None:
- return nested_field._doc_class
-
- return doc_class
-
def _get_result(self, hit: dict[Any, Any], parent_class=None) -> Hit:
"""
         Process a hit (i.e., a result) from an Elasticsearch response and transform it into a class instance.
         The transformation depends on the contents of the hit. If the document in hit contains a nested field,
-        the '_resolve_nested' method is used to determine the appropriate class (based on the nested path).
+        the 'resolve_nested' method is used to determine the appropriate class (based on the nested path).
         If the hit has a document type that is present in the '_doc_type_map', the corresponding class is
         used. If not, the method iterates over the '_doc_type' classes and uses the first one whose '_matches'
         method returns True for the hit.
@@ -602,41 +506,12 @@ class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMix
         Finally, the transformed hit is returned. If the determined class has a 'from_es' method, this is
         used to transform the hit
-
- An example of the hit argument:
-
- {'_id': 'jdeZT4kBjAZqZnexVUxk',
- '_index': '.ds-filebeat-8.8.2-2023.07.09-000001',
- '_score': 2.482621,
- '_source': {'@timestamp': '2023-07-13T14:13:15.140Z',
- 'asctime': '2023-07-09T07:47:43.907+0000',
- 'container': {'id': 'airflow'},
- 'dag_id': 'example_bash_operator',
- 'ecs': {'version': '8.0.0'},
- 'logical_date': '2023_07_09T07_47_32_000000',
- 'filename': 'taskinstance.py',
- 'input': {'type': 'log'},
- 'levelname': 'INFO',
- 'lineno': 1144,
- 'log': {'file': {'path':
"/opt/airflow/Documents/GitHub/airflow/logs/
- dag_id=example_bash_operator'/run_id=owen_run_run/
- task_id=run_after_loop/attempt=1.log"},
- 'offset': 0},
- 'log.offset': 1688888863907337472,
- 'log_id':
'example_bash_operator-run_after_loop-owen_run_run--1-1',
- 'message': 'Dependencies all met for
dep_context=non-requeueable '
- 'deps ti=<TaskInstance: '
- 'example_bash_operator.run_after_loop
owen_run_run '
- '[queued]>',
- 'task_id': 'run_after_loop',
- 'try_number': '1'},
- '_type': '_doc'}
"""
doc_class = Hit
dt = hit.get("_type")
if "_nested" in hit:
- doc_class = self._resolve_nested(hit, parent_class)
+ doc_class = resolve_nested(hit, parent_class)
elif dt in self._doc_type_map:
doc_class = self._doc_type_map[dt]
@@ -654,13 +529,97 @@ class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMix
         callback: type[Hit] | Callable[..., Any] = getattr(doc_class, "from_es", doc_class)
return callback(hit)
- def _parse_raw_log(self, log: str) -> list[dict[str, Any]]:
+ def _get_log_message(self, hit: Hit) -> str:
+ """
+ Get log message from hit, supporting both Airflow 2.x and 3.x formats.
+
+ In Airflow 2.x, the log record JSON has a "message" key, e.g.:
+ {
+            "message": "Dag name:dataset_consumes_1 queued_at:2025-08-12 15:05:57.703493+00:00",
+            "offset": 1755011166339518208,
+            "log_id": "dataset_consumes_1-consuming_1-manual__2025-08-12T15:05:57.691303+00:00--1-1"
+ }
+
+ In Airflow 3.x, the "message" field is renamed to "event".
+ We check the correct attribute depending on the Airflow major version.
+ """
+ if hasattr(hit, "event"):
+ return hit.event
+ if hasattr(hit, "message"):
+ return hit.message
+ return ""
+
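(For illustration only, not part of the patch: the same kind of record in the Airflow 3.x shape that _get_log_message handles, mirroring the 2.x example above with "message" renamed to "event".)

    # Hypothetical Airflow 3.x log record as indexed in Elasticsearch (values are placeholders).
    {
        "event": "Dag name:dataset_consumes_1 queued_at:2025-08-12 15:05:57.703493+00:00",
        "offset": 1755011166339518208,
        "log_id": "dataset_consumes_1-consuming_1-manual__2025-08-12T15:05:57.691303+00:00--1-1",
    }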
+ def concat_logs(self, hits: list[Hit]) -> str:
+        log_range = (len(hits) - 1) if self._get_log_message(hits[-1]) == self.end_of_log_mark else len(hits)
+ return "\n".join(self._format_msg(hits[i]) for i in range(log_range))
+
+
+@attrs.define(kw_only=True)
+class ElasticsearchRemoteLogIO(LoggingMixin): # noqa: D101
+ json_format: bool = False
+ write_stdout: bool = False
+ delete_local_copy: bool = False
+ host: str = "http://localhost:9200"
+ host_field: str = "host"
+ target_index: str = "airflow-logs"
+ offset_field: str = "offset"
+ write_to_es: bool = False
+ base_log_folder: Path = attrs.field(converter=Path)
+ log_id_template: str = conf.get(
+ "elasticsearch",
+ "log_id_template",
+ fallback="{dag_id}-{task_id}-{run_id}-{map_index}-{try_number}",
+ )
+
+ processors = ()
+
+ def __attrs_post_init__(self):
+ es_kwargs = get_es_kwargs_from_config()
+ self.client = elasticsearch.Elasticsearch(self.host, **es_kwargs)
+        self.index_patterns_callable = conf.get("elasticsearch", "index_patterns_callable", fallback="")
+ self.PAGE = 0
+ self.MAX_LINE_PER_PAGE = 1000
+ self.index_patterns: str = conf.get("elasticsearch", "index_patterns")
+ self._doc_type_map: dict[Any, Any] = {}
+ self._doc_type: list[Any] = []
+
+ def upload(self, path: os.PathLike | str, ti: RuntimeTI):
+ """Write the log to ElasticSearch."""
+ path = Path(path)
+
+ if path.is_absolute():
+ local_loc = path
+ else:
+ local_loc = self.base_log_folder.joinpath(path)
+
+        log_id = _render_log_id(self.log_id_template, ti, ti.try_number)  # type: ignore[arg-type]
+ if local_loc.is_file() and self.write_stdout:
+ # Intentionally construct the log_id and offset field
+
+ log_lines = self._parse_raw_log(local_loc.read_text(), log_id)
+ for line in log_lines:
+ sys.stdout.write(json.dumps(line) + "\n")
+ sys.stdout.flush()
+
+ if local_loc.is_file() and self.write_to_es:
+ log_lines = self._parse_raw_log(local_loc.read_text(), log_id)
+ success = self._write_to_es(log_lines)
+ if success and self.delete_local_copy:
+ shutil.rmtree(os.path.dirname(local_loc))
+
+ def _parse_raw_log(self, log: str, log_id: str) -> list[dict[str, Any]]:
logs = log.split("\n")
parsed_logs = []
+ offset = 1
for line in logs:
# Make sure line is not empty
if line.strip():
- parsed_logs.append(json.loads(line))
+                # construct log_id which is {dag_id}-{task_id}-{run_id}-{map_index}-{try_number}
+ # also construct the offset field (default is 'offset')
+ log_dict = json.loads(line)
+ log_dict.update({"log_id": log_id, self.offset_field: offset})
+ offset += 1
+ parsed_logs.append(log_dict)
return parsed_logs
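(For illustration only, not part of the patch: what _parse_raw_log now produces, assuming the default offset_field of "offset"; the input lines and log_id below are made-up values.)

    # Hypothetical input/output for ElasticsearchRemoteLogIO._parse_raw_log.
    raw = '{"event": "start"}\n{"event": "end"}\n'
    io._parse_raw_log(raw, log_id="mydag-mytask-run1--1-1")
    # -> [
    #   {"event": "start", "log_id": "mydag-mytask-run1--1-1", "offset": 1},
    #   {"event": "end",   "log_id": "mydag-mytask-run1--1-1", "offset": 2},
    # ]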
@@ -675,21 +634,139 @@ class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMix
try:
_ = helpers.bulk(self.client, bulk_actions)
return True
+ except helpers.BulkIndexError as bie:
+            self.log.exception("Bulk upload failed for %d log(s)", len(bie.errors))
+ for error in bie.errors:
+ self.log.exception(error)
+ return False
except Exception as e:
             self.log.exception("Unable to insert logs into Elasticsearch. Reason: %s", str(e))
return False
+    def read(self, _relative_path: str, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMessages]:
+        log_id = _render_log_id(self.log_id_template, ti, ti.try_number)  # type: ignore[arg-type]
+ self.log.info("Reading log %s from Elasticsearch", log_id)
+ offset = 0
+ response = self._es_read(log_id, offset, ti)
+ if response is not None and response.hits:
+ logs_by_host = self._group_logs_by_host(response)
+ else:
+ logs_by_host = None
-def getattr_nested(obj, item, default):
- """
- Get item from obj but return default if not found.
+ if logs_by_host is None:
+            missing_log_message = (
+                f"*** Log {log_id} not found in Elasticsearch. "
+                "If your task started recently, please wait a moment and reload this page. "
+                "Otherwise, the logs for this task instance may have been removed."
+            )
+ return [], [missing_log_message]
- E.g. calling ``getattr_nested(a, 'b.c', "NA")`` will return
- ``a.b.c`` if such a value exists, and "NA" otherwise.
+ header = []
+ # Start log group
+ header.append("".join([host for host in logs_by_host.keys()]))
- :meta private:
- """
- try:
- return attrgetter(item)(obj)
- except AttributeError:
- return default
+ message = []
+ # Structured log messages
+ for hits in logs_by_host.values():
+ for hit in hits:
+                filtered = {k: v for k, v in hit.to_dict().items() if k.lower() in TASK_LOG_FIELDS}
+ message.append(json.dumps(filtered))
+
+ return header, message
+
+    def _es_read(self, log_id: str, offset: int | str, ti: RuntimeTI) -> ElasticSearchResponse | None:
+ """
+ Return the logs matching log_id in Elasticsearch and next offset or ''.
+
+ :param log_id: the log_id of the log to read.
+ :param offset: the offset start to read log from.
+ :param ti: the task instance object
+
+ :meta private:
+ """
+ query: dict[Any, Any] = {
+ "bool": {
+                "filter": [{"range": {self.offset_field: {"gt": int(offset)}}}],
+ "must": [{"match_phrase": {"log_id": log_id}}],
+ }
+ }
+
+ index_patterns = self._get_index_patterns(ti)
+ try:
+            max_log_line = self.client.count(index=index_patterns, query=query)["count"]
+ except NotFoundError as e:
+            self.log.exception("The target index pattern %s does not exist", index_patterns)
+ raise e
+
+ if max_log_line != 0:
+ try:
+ res = self.client.search(
+ index=index_patterns,
+ query=query,
+ sort=[self.offset_field],
+ size=self.MAX_LINE_PER_PAGE,
+ from_=self.MAX_LINE_PER_PAGE * self.PAGE,
+ )
+ return ElasticSearchResponse(self, res)
+ except Exception as err:
+                self.log.exception("Could not read log with log_id: %s. Exception: %s", log_id, err)
+
+ return None
+
+    def _get_index_patterns(self, ti: RuntimeTI | None) -> str:
+        """
+        Get index patterns by calling index_patterns_callable, if provided, or the configured index_patterns.
+
+        :param ti: A TaskInstance object or None.
+        """
+        if self.index_patterns_callable:
+            self.log.debug("Using index_patterns_callable: %s", self.index_patterns_callable)
+            index_pattern_callable_obj = import_string(self.index_patterns_callable)
+            return index_pattern_callable_obj(ti)
+        self.log.debug("Using index_patterns: %s", self.index_patterns)
+        return self.index_patterns
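(For illustration only, not part of the patch: a hypothetical target for the [elasticsearch] index_patterns_callable option used above; the dotted path to a function like this is resolved via import_string and called with the task instance, or None, to produce the index pattern to search.)

    # Hypothetical index_patterns_callable implementation (function name and patterns are placeholders).
    def dag_scoped_index_patterns(ti) -> str:
        if ti is None:
            return "airflow-logs-*"
        return f"airflow-logs-{ti.dag_id}-*"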
+
+    def _group_logs_by_host(self, response: ElasticSearchResponse) -> dict[str, list[Hit]]:
+ grouped_logs = defaultdict(list)
+ for hit in response:
+ key = getattr_nested(hit, self.host_field, None) or self.host
+ grouped_logs[key].append(hit)
+ return grouped_logs
+
+    def _get_result(self, hit: dict[Any, Any], parent_class=None) -> Hit:
+        """
+        Process a hit (i.e., a result) from an Elasticsearch response and transform it into a class instance.
+
+        The transformation depends on the contents of the hit. If the document in hit contains a nested field,
+        the 'resolve_nested' method is used to determine the appropriate class (based on the nested path).
+        If the hit has a document type that is present in the '_doc_type_map', the corresponding class is
+        used. If not, the method iterates over the '_doc_type' classes and uses the first one whose '_matches'
+        method returns True for the hit.
+
+        If the hit contains any 'inner_hits', these are also processed into 'ElasticSearchResponse' instances
+        using the determined class.
+
+        Finally, the transformed hit is returned. If the determined class has a 'from_es' method, this is
+        used to transform the hit
+        """
+        doc_class = Hit
+        dt = hit.get("_type")
+
+        if "_nested" in hit:
+            doc_class = resolve_nested(hit, parent_class)
+
+        elif dt in self._doc_type_map:
+            doc_class = self._doc_type_map[dt]
+
+        else:
+            for doc_type in self._doc_type:
+                if hasattr(doc_type, "_matches") and doc_type._matches(hit):
+                    doc_class = doc_type
+                    break
+
+        for t in hit.get("inner_hits", ()):
+            hit["inner_hits"][t] = ElasticSearchResponse(self, hit["inner_hits"][t], doc_class=doc_class)
+
+        # callback should get the Hit class if "from_es" is not defined
+        callback: type[Hit] | Callable[..., Any] = getattr(doc_class, "from_es", doc_class)
+        return callback(hit)
diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/version_compat.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/version_compat.py
index f5bb3ae555c..caa8c1a5197 100644
--- a/providers/elasticsearch/src/airflow/providers/elasticsearch/version_compat.py
+++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/version_compat.py
@@ -34,5 +34,6 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)
+AIRFLOW_V_3_2_PLUS: bool = get_base_airflow_version_tuple() >= (3, 2, 0)
-__all__ = ["AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_1_PLUS"]
+__all__ = ["AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_1_PLUS", "AIRFLOW_V_3_2_PLUS"]
diff --git a/providers/elasticsearch/tests/conftest.py b/providers/elasticsearch/tests/conftest.py
index f56ccce0a3f..354827d67d6 100644
--- a/providers/elasticsearch/tests/conftest.py
+++ b/providers/elasticsearch/tests/conftest.py
@@ -16,4 +16,18 @@
# under the License.
from __future__ import annotations
+import pytest
+from testcontainers.elasticsearch import ElasticSearchContainer
+
pytest_plugins = "tests_common.pytest_plugin"
+
+
+@pytest.fixture(scope="session")
+def es_8_container_url() -> str:
+ es = (
+        ElasticSearchContainer("docker.elastic.co/elasticsearch/elasticsearch:8.19.0")
+ .with_env("discovery.type", "single-node")
+ .with_env("cluster.routing.allocation.disk.threshold_enabled", "false")
+ )
+ with es:
+ yield es.get_url()
diff --git a/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/fake_elasticsearch.py b/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/fake_elasticsearch.py
index c7746001d68..df101f2f918 100644
--- a/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/fake_elasticsearch.py
+++ b/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/fake_elasticsearch.py
@@ -80,7 +80,7 @@ class FakeElasticsearch(Elasticsearch):
}
@query_params()
- def sample_log_response(self, headers=None, params=None):
+ def sample_airflow_2_log_response(self, headers=None, params=None):
return {
"_shards": {"failed": 0, "skipped": 0, "successful": 7, "total":
7},
"hits": {
@@ -104,17 +104,16 @@ class FakeElasticsearch(Elasticsearch):
"file": {
"path":
"/opt/airflow/Documents/GitHub/airflow/logs/"
"dag_id=example_bash_operator'"
-
"/run_id=owen_run_run/task_id=run_after_loop/attempt=1.log"
+
"/run_id=run_run/task_id=run_after_loop/attempt=1.log"
},
"offset": 0,
},
"log.offset": 1688888863907337472,
- "log_id":
"example_bash_operator-run_after_loop-owen_run_run--1-1",
+ "log_id":
"example_bash_operator-run_after_loop-run_run--1-1",
"message": "Dependencies all met for "
"dep_context=non-requeueable deps "
"ti=<TaskInstance: "
- "example_bash_operator.run_after_loop "
- "owen_run_run [queued]>",
+ "example_bash_operator.run_after_loop ",
"task_id": "run_after_loop",
"try_number": "1",
},
@@ -139,12 +138,12 @@ class FakeElasticsearch(Elasticsearch):
"file": {
"path":
"/opt/airflow/Documents/GitHub/airflow/logs/"
"dag_id=example_bash_operator"
-
"/run_id=owen_run_run/task_id=run_after_loop/attempt=1.log"
+
"/run_id=run_run/task_id=run_after_loop/attempt=1.log"
},
"offset": 988,
},
"log.offset": 1688888863917961216,
- "log_id":
"example_bash_operator-run_after_loop-owen_run_run--1-1",
+ "log_id":
"example_bash_operator-run_after_loop-run_run--1-1",
"message": "Starting attempt 1 of 1",
"task_id": "run_after_loop",
"try_number": "1",
@@ -170,12 +169,12 @@ class FakeElasticsearch(Elasticsearch):
"file": {
"path":
"/opt/airflow/Documents/GitHub/airflow/logs/"
"dag_id=example_bash_operator"
-
"/run_id=owen_run_run/task_id=run_after_loop/attempt=1.log"
+
"/run_id=run_run/task_id=run_after_loop/attempt=1.log"
},
"offset": 1372,
},
"log.offset": 1688888863928218880,
- "log_id":
"example_bash_operator-run_after_loop-owen_run_run--1-1",
+ "log_id":
"example_bash_operator-run_after_loop-run_run--1-1",
"message": "Executing <Task(BashOperator): "
"run_after_loop> on 2023-07-09 "
"07:47:32+00:00",
@@ -192,6 +191,118 @@ class FakeElasticsearch(Elasticsearch):
"took": 7,
}
+ @query_params()
+ def sample_airflow_3_log_response(self, headers=None, params=None):
+ return {
+ "_shards": {"failed": 0, "skipped": 0, "successful": 7, "total":
7},
+ "hits": {
+ "hits": [
+ {
+ "_id": "jdeZT4kBjAZqZnexVUxk",
+ "_index": ".ds-filebeat-8.8.2-2023.07.09-000001",
+ "_score": 2.482621,
+ "_source": {
+ "@timestamp": "2023-07-13T14:13:15.140Z",
+ "asctime": "2023-07-09T07:47:43.907+0000",
+ "container": {"id": "airflow"},
+ "dag_id": "example_bash_operator",
+ "ecs": {"version": "8.0.0"},
+ "execution_date": "2023_07_09T07_47_32_000000",
+ "filename": "taskinstance.py",
+ "input": {"type": "log"},
+ "levelname": "INFO",
+ "lineno": 1144,
+ "log": {
+ "file": {
+ "path":
"/opt/airflow/Documents/GitHub/airflow/logs/"
+ "dag_id=example_bash_operator'"
+
"/run_id=run_run/task_id=run_after_loop/attempt=1.log"
+ },
+ "offset": 0,
+ },
+ "log.offset": 1688888863907337472,
+ "log_id":
"example_bash_operator-run_after_loop-run_run--1-1",
+ "task_id": "run_after_loop",
+ "try_number": "1",
+ "event": "Dependencies all met for "
+ "dep_context=non-requeueable deps "
+ "ti=<TaskInstance: "
+ "example_bash_operator.run_after_loop ",
+ },
+ "_type": "_doc",
+ },
+ {
+ "_id": "qteZT4kBjAZqZnexVUxl",
+ "_index": ".ds-filebeat-8.8.2-2023.07.09-000001",
+ "_score": 2.482621,
+ "_source": {
+ "@timestamp": "2023-07-13T14:13:15.141Z",
+ "asctime": "2023-07-09T07:47:43.917+0000",
+ "container": {"id": "airflow"},
+ "dag_id": "example_bash_operator",
+ "ecs": {"version": "8.0.0"},
+ "execution_date": "2023_07_09T07_47_32_000000",
+ "filename": "taskinstance.py",
+ "input": {"type": "log"},
+ "levelname": "INFO",
+ "lineno": 1347,
+ "log": {
+ "file": {
+ "path":
"/opt/airflow/Documents/GitHub/airflow/logs/"
+ "dag_id=example_bash_operator"
+
"/run_id=run_run/task_id=run_after_loop/attempt=1.log"
+ },
+ "offset": 988,
+ },
+ "log.offset": 1688888863917961216,
+ "log_id":
"example_bash_operator-run_after_loop-run_run--1-1",
+ "event": "Starting attempt 1 of 1",
+ "task_id": "run_after_loop",
+ "try_number": "1",
+ },
+ "_type": "_doc",
+ },
+ {
+ "_id": "v9eZT4kBjAZqZnexVUx2",
+ "_index": ".ds-filebeat-8.8.2-2023.07.09-000001",
+ "_score": 2.482621,
+ "_source": {
+ "@timestamp": "2023-07-13T14:13:15.143Z",
+ "asctime": "2023-07-09T07:47:43.928+0000",
+ "container": {"id": "airflow"},
+ "dag_id": "example_bash_operator",
+ "ecs": {"version": "8.0.0"},
+ "execution_date": "2023_07_09T07_47_32_000000",
+ "filename": "taskinstance.py",
+ "input": {"type": "log"},
+ "levelname": "INFO",
+ "lineno": 1368,
+ "log": {
+ "file": {
+ "path":
"/opt/airflow/Documents/GitHub/airflow/logs/"
+ "dag_id=example_bash_operator"
+
"/run_id=run_run/task_id=run_after_loop/attempt=1.log"
+ },
+ "offset": 1372,
+ },
+ "log.offset": 1688888863928218880,
+ "log_id":
"example_bash_operator-run_after_loop-run_run--1-1",
+ "task_id": "run_after_loop",
+ "try_number": "1",
+ "event": "Executing <Task(BashOperator): "
+ "run_after_loop> on 2023-07-09 "
+ "07:47:32+00:00",
+ },
+ "_type": "_doc",
+ },
+ ],
+ "max_score": 2.482621,
+ "total": {"relation": "eq", "value": 36},
+ },
+ "timed_out": False,
+ "took": 7,
+ }
+
@query_params(
"consistency",
"op_type",
@@ -479,7 +590,6 @@ class FakeElasticsearch(Elasticsearch):
# TODO: support allow_no_indices query parameter
matches = set()
for target in targets:
- print(f"Loop over:::target = {target}")
if target in ("_all", ""):
matches.update(self.__documents_dict)
elif "*" in target:
@@ -499,7 +609,6 @@ class FakeElasticsearch(Elasticsearch):
else:
# Is it the correct exception to use ?
raise ValueError("Invalid param 'index'")
-
generator = (target for index in searchable_indexes for target in
index.split(","))
return list(self._validate_search_targets(generator, query=query))
diff --git a/providers/elasticsearch/tests/unit/elasticsearch/log/test_es_task_handler.py b/providers/elasticsearch/tests/unit/elasticsearch/log/test_es_task_handler.py
index d3b27b19874..0b3fa0440f9 100644
--- a/providers/elasticsearch/tests/unit/elasticsearch/log/test_es_task_handler.py
+++ b/providers/elasticsearch/tests/unit/elasticsearch/log/test_es_task_handler.py
@@ -22,6 +22,8 @@ import logging
import os
import re
import shutil
+import tempfile
+import uuid
from io import StringIO
from pathlib import Path
from unittest import mock
@@ -36,7 +38,10 @@ from airflow.providers.common.compat.sdk import conf
 from airflow.providers.elasticsearch.log.es_response import ElasticSearchResponse
from airflow.providers.elasticsearch.log.es_task_handler import (
VALID_ES_CONFIG_KEYS,
+ ElasticsearchRemoteLogIO,
ElasticsearchTaskHandler,
+ _clean_date,
+ _render_log_id,
get_es_kwargs_from_config,
getattr_nested,
)
@@ -54,10 +59,11 @@ from unit.elasticsearch.log.elasticmock.utilities import
SearchFailedException
ES_PROVIDER_YAML_FILE = AIRFLOW_PROVIDERS_ROOT_PATH / "elasticsearch" /
"provider.yaml"
-def get_ti(dag_id, task_id, logical_date, create_task_instance):
+def get_ti(dag_id, task_id, run_id, logical_date, create_task_instance):
ti = create_task_instance(
dag_id=dag_id,
task_id=task_id,
+ run_id=run_id,
logical_date=logical_date,
dagrun_state=DagRunState.RUNNING,
state=TaskInstanceState.RUNNING,
@@ -70,9 +76,12 @@ def get_ti(dag_id, task_id, logical_date,
create_task_instance):
class TestElasticsearchTaskHandler:
DAG_ID = "dag_for_testing_es_task_handler"
TASK_ID = "task_for_testing_es_log_handler"
+ RUN_ID = "run_for_testing_es_log_handler"
+ MAP_INDEX = -1
+ TRY_NUM = 1
LOGICAL_DATE = datetime(2016, 1, 1)
- LOG_ID = f"{DAG_ID}-{TASK_ID}-2016-01-01T00:00:00+00:00-1"
-    JSON_LOG_ID = f"{DAG_ID}-{TASK_ID}-{ElasticsearchTaskHandler._clean_date(LOGICAL_DATE)}-1"
+ LOG_ID = f"{DAG_ID}-{TASK_ID}-{RUN_ID}-{MAP_INDEX}-{TRY_NUM}"
+ JSON_LOG_ID = f"{DAG_ID}-{TASK_ID}-{_clean_date(LOGICAL_DATE)}-1"
FILENAME_TEMPLATE = "{try_number}.log"
@pytest.fixture
@@ -88,6 +97,7 @@ class TestElasticsearchTaskHandler:
yield get_ti(
dag_id=self.DAG_ID,
task_id=self.TASK_ID,
+ run_id=self.RUN_ID,
logical_date=self.LOGICAL_DATE,
create_task_instance=create_task_instance,
)
@@ -128,21 +138,24 @@ class TestElasticsearchTaskHandler:
def teardown_method(self):
shutil.rmtree(self.local_log_location.split(os.path.sep)[0],
ignore_errors=True)
- def test_es_response(self):
- sample_response = self.es.sample_log_response()
- es_response = ElasticSearchResponse(self.es_task_handler,
sample_response)
- logs_by_host = self.es_task_handler._group_logs_by_host(es_response)
-
- def concat_logs(lines):
- log_range = -1 if lines[-1].message ==
self.es_task_handler.end_of_log_mark else None
- return "\n".join(self.es_task_handler._format_msg(line) for line
in lines[:log_range])
+ @pytest.mark.parametrize(
+ "sample_response",
+ [
+ pytest.param(lambda self: self.es.sample_airflow_2_log_response(),
id="airflow_2"),
+ pytest.param(lambda self: self.es.sample_airflow_3_log_response(),
id="airflow_3"),
+ ],
+ )
+ def test_es_response(self, sample_response):
+ response = sample_response(self)
+ es_response = ElasticSearchResponse(self.es_task_handler, response)
+ logs_by_host = self.es_task_handler.io._group_logs_by_host(es_response)
for hosted_log in logs_by_host.values():
- message = concat_logs(hosted_log)
+ message = self.es_task_handler.concat_logs(hosted_log)
assert (
message == "Dependencies all met for dep_context=non-requeueable"
- " deps ti=<TaskInstance: example_bash_operator.run_after_loop
owen_run_run [queued]>\n"
+ " deps ti=<TaskInstance: example_bash_operator.run_after_loop \n"
"Starting attempt 1 of 1\nExecuting <Task(BashOperator):
run_after_loop> "
"on 2023-07-09 07:47:32+00:00"
)
@@ -263,7 +276,7 @@ class TestElasticsearchTaskHandler:
@pytest.mark.db_test
def test_read_with_patterns_no_match(self, ti):
ts = pendulum.now()
- with mock.patch.object(self.es_task_handler, "index_patterns",
new="test_other_*,test_another_*"):
+ with mock.patch.object(self.es_task_handler.io, "index_patterns",
new="test_other_*,test_another_*"):
logs, metadatas = self.es_task_handler.read(
ti, 1, {"offset": 0, "last_log_timestamp": str(ts),
"end_of_log": False}
)
@@ -280,14 +293,14 @@ class TestElasticsearchTaskHandler:
metadata = metadatas[0]
assert metadata["offset"] == "0"
- assert not metadata["end_of_log"]
+ assert metadata["end_of_log"]
# last_log_timestamp won't change if no log lines read.
assert timezone.parse(metadata["last_log_timestamp"]) == ts
@pytest.mark.db_test
def test_read_with_missing_index(self, ti):
ts = pendulum.now()
- with mock.patch.object(self.es_task_handler, "index_patterns",
new="nonexistent,test_*"):
+ with mock.patch.object(self.es_task_handler.io, "index_patterns",
new="nonexistent,test_*"):
with pytest.raises(elasticsearch.exceptions.NotFoundError,
match=r"IndexMissingException.*"):
self.es_task_handler.read(
ti,
@@ -302,9 +315,11 @@ class TestElasticsearchTaskHandler:
When the log actually isn't there to be found, we only want to wait
for 5 seconds.
In this case we expect to receive a message of the form 'Log {log_id}
not found in elasticsearch ...'
"""
+ run_id = "wrong_run_id"
ti = get_ti(
self.DAG_ID,
self.TASK_ID,
+ run_id,
pendulum.instance(self.LOGICAL_DATE).add(days=1), # so logs are
not found
create_task_instance=create_task_instance,
)
@@ -320,7 +335,7 @@ class TestElasticsearchTaskHandler:
else:
# we've "waited" less than 5 seconds so it should not be "end
of log" and should be no log message
assert logs == []
- assert metadatas["end_of_log"] is False
+ assert metadatas["end_of_log"] is True
assert metadatas["offset"] == "0"
assert timezone.parse(metadatas["last_log_timestamp"]) == ts
else:
@@ -336,7 +351,7 @@ class TestElasticsearchTaskHandler:
# we've "waited" less than 5 seconds so it should not be "end
of log" and should be no log message
assert len(logs[0]) == 0
assert logs == [[]]
- assert metadatas[0]["end_of_log"] is False
+ assert metadatas[0]["end_of_log"] is True
assert len(logs) == len(metadatas)
assert metadatas[0]["offset"] == "0"
assert timezone.parse(metadatas[0]["last_log_timestamp"]) == ts
@@ -432,7 +447,7 @@ class TestElasticsearchTaskHandler:
metadata = metadatas[0]
assert metadata["offset"] == "0"
- assert not metadata["end_of_log"]
+ assert metadata["end_of_log"]
# last_log_timestamp won't change if no log lines read.
assert timezone.parse(metadata["last_log_timestamp"]) == ts
@@ -440,6 +455,7 @@ class TestElasticsearchTaskHandler:
def test_read_with_empty_metadata(self, ti):
ts = pendulum.now()
logs, metadatas = self.es_task_handler.read(ti, 1, {})
+ print(f"metadatas: {metadatas}")
if AIRFLOW_V_3_0_PLUS:
logs = list(logs)
assert logs[0].event == "::group::Log message source details"
@@ -455,7 +471,7 @@ class TestElasticsearchTaskHandler:
assert self.test_message == logs[0][0][-1]
metadata = metadatas[0]
-
+ print(f"metadatas: {metadatas}")
assert not metadata["end_of_log"]
# offset should be initialized to 0 if not provided.
assert metadata["offset"] == "1"
@@ -477,7 +493,7 @@ class TestElasticsearchTaskHandler:
metadata = metadatas[0]
- assert not metadata["end_of_log"]
+ assert metadata["end_of_log"]
# offset should be initialized to 0 if not provided.
assert metadata["offset"] == "0"
# last_log_timestamp will be initialized using log reading time
@@ -552,27 +568,22 @@ class TestElasticsearchTaskHandler:
@pytest.mark.db_test
def test_read_raises(self, ti):
- with mock.patch.object(self.es_task_handler.log, "exception") as
mock_exception:
- with mock.patch.object(self.es_task_handler.client, "search") as
mock_execute:
+ with mock.patch.object(self.es_task_handler.io.log, "exception") as
mock_exception:
+ with mock.patch.object(self.es_task_handler.io.client, "search")
as mock_execute:
mock_execute.side_effect = SearchFailedException("Failed to
read")
- logs, metadatas = self.es_task_handler.read(ti, 1)
+ log_sources, log_msgs = self.es_task_handler.io.read("", ti)
assert mock_exception.call_count == 1
args, kwargs = mock_exception.call_args
assert "Could not read log with log_id:" in args[0]
if AIRFLOW_V_3_0_PLUS:
- assert logs == []
-
- metadata = metadatas
+ assert log_sources == []
else:
- assert len(logs) == 1
- assert len(logs) == len(metadatas)
- assert logs == [[]]
-
- metadata = metadatas[0]
+ assert len(log_sources) == 0
+ assert len(log_msgs) == 1
+ assert log_sources == []
- assert metadata["offset"] == "0"
- assert not metadata["end_of_log"]
+ assert "not found in Elasticsearch" in log_msgs[0]
@pytest.mark.db_test
def test_set_context(self, ti):
@@ -616,9 +627,7 @@ class TestElasticsearchTaskHandler:
logs = list(logs)
assert logs[2].event == self.test_message
else:
- assert (
- logs[0][0][1] == "[2020-12-24 19:25:00,962]
{taskinstance.py:851} INFO - some random stuff - "
- )
+ assert logs[0][0][1] == self.test_message
@pytest.mark.db_test
def test_read_with_json_format_with_custom_offset_and_host_fields(self,
ti):
@@ -634,7 +643,7 @@ class TestElasticsearchTaskHandler:
self.body = {
"message": self.test_message,
"event": self.test_message,
- "log_id":
f"{self.DAG_ID}-{self.TASK_ID}-2016_01_01T00_00_00_000000-1",
+ "log_id": self.LOG_ID,
"log": {"offset": 1},
"host": {"name": "somehostname"},
"asctime": "2020-12-24 19:25:00,962",
@@ -652,9 +661,7 @@ class TestElasticsearchTaskHandler:
logs = list(logs)
assert logs[2].event == self.test_message
else:
- assert (
- logs[0][0][1] == "[2020-12-24 19:25:00,962]
{taskinstance.py:851} INFO - some random stuff - "
- )
+ assert logs[0][0][1] == self.test_message
@pytest.mark.db_test
def test_read_with_custom_offset_and_host_fields(self, ti):
@@ -753,13 +760,13 @@ class TestElasticsearchTaskHandler:
@pytest.mark.db_test
def test_render_log_id(self, ti):
-        assert self.es_task_handler._render_log_id(ti, 1) == self.LOG_ID
+        assert _render_log_id(self.es_task_handler.log_id_template, ti, 1) == self.LOG_ID
         self.es_task_handler.json_format = True
-        assert self.es_task_handler._render_log_id(ti, 1) == self.JSON_LOG_ID
+        assert _render_log_id(self.es_task_handler.log_id_template, ti, 1) == self.LOG_ID
     def test_clean_date(self):
-        clean_logical_date = self.es_task_handler._clean_date(datetime(2016, 7, 8, 9, 10, 11, 12))
+        clean_logical_date = _clean_date(datetime(2016, 7, 8, 9, 10, 11, 12))
         assert clean_logical_date == "2016_07_08T09_10_11_000012"
@pytest.mark.db_test
@@ -770,7 +777,7 @@ class TestElasticsearchTaskHandler:
(
True,
"localhost:5601/{log_id}",
- "https://localhost:5601/" + quote(JSON_LOG_ID),
+ "https://localhost:5601/" + quote(LOG_ID),
),
(
False,
@@ -867,8 +874,8 @@ class TestElasticsearchTaskHandler:
mock_callable = Mock(return_value="callable_index_pattern")
mock_import_string.return_value = mock_callable
-        self.es_task_handler.index_patterns_callable = "path.to.index_pattern_callable"
-        result = self.es_task_handler._get_index_patterns({})
+        self.es_task_handler.io.index_patterns_callable = "path.to.index_pattern_callable"
+        result = self.es_task_handler.io._get_index_patterns({})
mock_import_string.assert_called_once_with("path.to.index_pattern_callable")
mock_callable.assert_called_once_with({})
@@ -885,25 +892,6 @@ class TestElasticsearchTaskHandler:
filename_template=None,
)
- @pytest.mark.db_test
- def test_write_to_es(self, ti):
- self.es_task_handler.write_to_es = True
- self.es_task_handler.json_format = True
- self.es_task_handler.write_stdout = False
- self.es_task_handler.local_base = Path(os.getcwd()) / "local" / "log"
/ "location"
- formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s
- %(message)s")
- self.es_task_handler.formatter = formatter
-
- self.es_task_handler.set_context(ti)
- with patch(
-
"airflow.providers.elasticsearch.log.es_task_handler.ElasticsearchTaskHandler._write_to_es"
- ) as mock_write_to_es:
- mock_write = Mock(return_value=True)
- mock_write_to_es.return_value = mock_write
- self.es_task_handler._write_to_es = mock_write_to_es
- self.es_task_handler.close()
- mock_write_to_es.assert_called_once()
-
def test_safe_attrgetter():
class A: ...
@@ -963,3 +951,134 @@ def test_self_not_valid_arg():
Test if self is not a valid argument.
"""
assert "self" not in VALID_ES_CONFIG_KEYS
+
+
+@pytest.mark.db_test
+class TestElasticsearchRemoteLogIO:
+ DAG_ID = "dag_for_testing_es_log_handler"
+ TASK_ID = "task_for_testing_es_log_handler"
+ RUN_ID = "run_for_testing_es_log_handler"
+ LOGICAL_DATE = datetime(2016, 1, 1)
+ FILENAME_TEMPLATE = "{try_number}.log"
+
+ @pytest.fixture(autouse=True)
+ def setup_tests(self, ti):
+ self.elasticsearch_io = ElasticsearchRemoteLogIO(
+ write_to_es=True,
+ write_stdout=True,
+ delete_local_copy=True,
+ host="http://localhost:9200",
+ base_log_folder=Path(""),
+ )
+
+ @pytest.fixture
+ def tmp_json_file(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ os.makedirs(tmpdir, exist_ok=True)
+
+ file_path = os.path.join(tmpdir, "1.log")
+ self.tmp_file = file_path
+
+ sample_logs = [
+ {"message": "start"},
+ {"message": "processing"},
+ {"message": "end"},
+ ]
+ with open(file_path, "w") as f:
+ for log in sample_logs:
+ f.write(json.dumps(log) + "\n")
+
+ yield file_path
+
+ del self.tmp_file
+
+ @pytest.fixture
+ def ti(self, create_task_instance, create_log_template):
+ create_log_template(
+ self.FILENAME_TEMPLATE,
+ (
+ "{dag_id}-{task_id}-{logical_date}-{try_number}"
+ if AIRFLOW_V_3_0_PLUS
+ else "{dag_id}-{task_id}-{execution_date}-{try_number}"
+ ),
+ )
+ yield get_ti(
+ dag_id=self.DAG_ID,
+ task_id=self.TASK_ID,
+ run_id=self.RUN_ID,
+ logical_date=self.LOGICAL_DATE,
+ create_task_instance=create_task_instance,
+ )
+ clear_db_runs()
+ clear_db_dags()
+
+ @pytest.fixture
+ def unique_index(self):
+ """Generate a unique index name for each test."""
+ return f"airflow-logs-{uuid.uuid4()}"
+
+ @pytest.mark.integration("elasticsearch")
+ @pytest.mark.setup_timeout(300)
+ @pytest.mark.execution_timeout(300)
+ @patch(
+ "airflow.providers.elasticsearch.log.es_task_handler.TASK_LOG_FIELDS",
+ ["message"],
+ )
+ def test_read_write_to_es(self, tmp_json_file, ti, es_8_container_url):
+ self.elasticsearch_io.host = es_8_container_url
+        self.elasticsearch_io.client = elasticsearch.Elasticsearch(es_8_container_url).options(
+            request_timeout=120, retry_on_timeout=True, max_retries=5
+        )
+ self.elasticsearch_io.write_stdout = False
+ self.elasticsearch_io.upload(tmp_json_file, ti)
+ self.elasticsearch_io.client.indices.refresh(
+ index=self.elasticsearch_io.target_index, request_timeout=120
+ )
+ log_source_info, log_messages = self.elasticsearch_io.read("", ti)
+ assert log_source_info[0] == es_8_container_url
+ assert len(log_messages) == 3
+
+ expected_msg = ["start", "processing", "end"]
+ for msg, log_message in zip(expected_msg, log_messages):
+ print(f"msg: {msg}, log_message: {log_message}")
+ json_log = json.loads(log_message)
+ assert "message" in json_log
+ assert json_log["message"] == msg
+
+ def test_write_to_stdout(self, tmp_json_file, ti, capsys):
+ self.elasticsearch_io.write_to_es = False
+ self.elasticsearch_io.upload(tmp_json_file, ti)
+
+ captured = capsys.readouterr()
+ stdout_lines = captured.out.strip().splitlines()
+ log_entries = [json.loads(line) for line in stdout_lines]
+ assert log_entries[0]["message"] == "start"
+ assert log_entries[1]["message"] == "processing"
+ assert log_entries[2]["message"] == "end"
+
+ def test_invalid_task_log_file_path(self, ti):
+ with (
+            patch.object(self.elasticsearch_io, "_parse_raw_log") as mock_parse,
+ patch.object(self.elasticsearch_io, "_write_to_es") as mock_write,
+ ):
+ self.elasticsearch_io.upload(Path("/invalid/path"), ti)
+
+ mock_parse.assert_not_called()
+ mock_write.assert_not_called()
+
+ def test_raw_log_should_contain_log_id_and_offset(self, tmp_json_file, ti):
+ with open(self.tmp_file) as f:
+ raw_log = f.read()
+ json_log_lines = self.elasticsearch_io._parse_raw_log(raw_log, ti)
+ assert len(json_log_lines) == 3
+ for json_log_line in json_log_lines:
+ assert "log_id" in json_log_line
+ assert "offset" in json_log_line
+
+ @patch("elasticsearch.Elasticsearch.count", return_value={"count": 0})
+ def test_read_with_missing_log(self, mocked_count, ti):
+ log_source_info, log_messages = self.elasticsearch_io.read("", ti)
+        log_id = _render_log_id(self.elasticsearch_io.log_id_template, ti, ti.try_number)
+ assert log_source_info == []
+        assert f"*** Log {log_id} not found in Elasticsearch" in log_messages[0]
+ mocked_count.assert_called_once()