This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-7-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 0655d883ae391c52ad0b9126cf6932cf641faa41
Author: Igor Khrol <[email protected]>
AuthorDate: Sun Aug 6 00:15:12 2023 +0300

    Skip served logs for non-running task try (#32561)
    
    Co-authored-by: eladkal <[email protected]>
    (cherry picked from commit 29d5e955fca5e6bee30b14ac9fcf85eebc94ae6d)
---
 airflow/utils/log/file_task_handler.py | 12 ++++++------
 tests/utils/test_log_handlers.py       | 16 +++++++++++++---
 2 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index 50938f96cb..6c8073b005 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -306,6 +306,10 @@ class FileTaskHandler(logging.Handler):
         executor_messages: list[str] = []
         executor_logs: list[str] = []
         served_logs: list[str] = []
+        is_running = ti.try_number == try_number and ti.state in (
+            TaskInstanceState.RUNNING,
+            TaskInstanceState.DEFERRED,
+        )
         with suppress(NotImplementedError):
             remote_messages, remote_logs = self._read_remote_logs(ti, 
try_number, metadata)
             messages_list.extend(remote_messages)
@@ -320,7 +324,7 @@ class FileTaskHandler(logging.Handler):
             worker_log_full_path = Path(self.local_base, worker_log_rel_path)
             local_messages, local_logs = 
self._read_from_local(worker_log_full_path)
             messages_list.extend(local_messages)
-        if ti.state in (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED) 
and not executor_messages:
+        if is_running and not executor_messages:
             served_messages, served_logs = self._read_from_logs_server(ti, 
worker_log_rel_path)
             messages_list.extend(served_messages)
         elif ti.state not in State.unfinished and not (local_logs or 
remote_logs):
@@ -340,15 +344,11 @@ class FileTaskHandler(logging.Handler):
         )
         log_pos = len(logs)
         messages = "".join([f"*** {x}\n" for x in messages_list])
-        end_of_log = ti.try_number != try_number or ti.state not in (
-            TaskInstanceState.RUNNING,
-            TaskInstanceState.DEFERRED,
-        )
         if metadata and "log_pos" in metadata:
             previous_chars = metadata["log_pos"]
             logs = logs[previous_chars:]  # Cut off previously passed log test 
as new tail
         out_message = logs if "log_pos" in (metadata or {}) else messages + 
logs
-        return out_message, {"end_of_log": end_of_log, "log_pos": log_pos}
+        return out_message, {"end_of_log": not is_running, "log_pos": log_pos}
 
     @staticmethod
     def _get_pod_namespace(ti: TaskInstance):
diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py
index 8c772c5799..30eb480076 100644
--- a/tests/utils/test_log_handlers.py
+++ b/tests/utils/test_log_handlers.py
@@ -298,7 +298,7 @@ class TestFileTaskLogHandler:
 
     def test__read_for_celery_executor_fallbacks_to_worker(self, 
create_task_instance):
         """Test for executors which do not have `get_task_log` method, it 
fallbacks to reading
-        log from worker"""
+        log from worker. But it happens only for the latest try_number."""
         executor_name = "CeleryExecutor"
 
         ti = create_task_instance(
@@ -308,14 +308,24 @@ class TestFileTaskLogHandler:
             execution_date=DEFAULT_DATE,
         )
         ti.state = TaskInstanceState.RUNNING
+        ti.try_number = 2
         with conf_vars({("core", "executor"): executor_name}):
             fth = FileTaskHandler("")
 
             fth._read_from_logs_server = mock.Mock()
             fth._read_from_logs_server.return_value = ["this message"], 
["this\nlog\ncontent"]
-            actual = fth._read(ti=ti, try_number=1)
+            actual = fth._read(ti=ti, try_number=2)
             fth._read_from_logs_server.assert_called_once()
-        assert actual == ("*** this message\nthis\nlog\ncontent", 
{"end_of_log": True, "log_pos": 16})
+            assert actual == ("*** this message\nthis\nlog\ncontent", 
{"end_of_log": False, "log_pos": 16})
+
+            # Previous try_number is from remote logs without reaching worker 
server
+            fth._read_from_logs_server.reset_mock()
+            fth._read_remote_logs = mock.Mock()
+            fth._read_remote_logs.return_value = ["remote logs"], 
["remote\nlog\ncontent"]
+            actual = fth._read(ti=ti, try_number=1)
+            fth._read_remote_logs.assert_called_once()
+            fth._read_from_logs_server.assert_not_called()
+            assert actual == ("*** remote logs\nremote\nlog\ncontent", 
{"end_of_log": True, "log_pos": 18})
 
     @pytest.mark.parametrize(
         "remote_logs, local_logs, served_logs_checked",

Reply via email to