This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3b25168c41 AIP-51 - Executor Coupling in Logging (#28161)
3b25168c41 is described below

commit 3b25168c413a8434f8f65efb09aaf949cf7adc3b
Author: sanjayp <sanjaypilla...@gmail.com>
AuthorDate: Tue Jan 24 14:38:49 2023 -0600

    AIP-51 - Executor Coupling in Logging (#28161)
    
    Executors may now implement a method to vend task logs
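
    For illustration only (not part of this commit): a minimal sketch of how a custom
    executor might implement the new hook added to BaseExecutor below. The MyExecutor
    class and its _fetch_backend_log() helper are hypothetical; only the return
    contract (None, a log string, or a (log, metadata) tuple) comes from this change.

        from __future__ import annotations

        from airflow.executors.base_executor import BaseExecutor
        from airflow.models.taskinstance import TaskInstance


        class MyExecutor(BaseExecutor):  # hypothetical executor, for illustration only
            def _fetch_backend_log(self, ti: TaskInstance) -> str:
                # Placeholder: a real executor would query its own backend here.
                return f"log lines for {ti.task_id}\n"

            def get_task_log(
                self, ti: TaskInstance, log: str = ""
            ) -> None | str | tuple[str, dict[str, bool]]:
                try:
                    log += f"*** Fetching logs for {ti.task_id} from MyExecutor ***\n"
                    log += self._fetch_backend_log(ti)
                    return log
                except Exception as e:
                    # A (log, {"end_of_log": True}) tuple tells FileTaskHandler to stop here.
                    return log + f"*** Unable to fetch logs: {e}\n", {"end_of_log": True}

    Returning None (the BaseExecutor default) makes FileTaskHandler._read fall back to
    fetching the log from the worker over HTTP, as the handler changes below show.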
---
 airflow/executors/base_executor.py                 |   9 ++
 airflow/executors/celery_kubernetes_executor.py    |   6 +
 airflow/executors/kubernetes_executor.py           |  53 +++++++
 airflow/executors/local_kubernetes_executor.py     |   7 +
 airflow/utils/log/file_task_handler.py             | 166 ++++++++-------------
 tests/executors/test_base_executor.py              |   8 +-
 tests/executors/test_celery_kubernetes_executor.py |  16 ++
 tests/executors/test_kubernetes_executor.py        |  28 ++++
 tests/executors/test_local_kubernetes_executor.py  |  19 +++
 .../amazon/aws/log/test_s3_task_handler.py         |   2 +-
 tests/utils/test_log_handlers.py                   |  97 +++++++-----
 11 files changed, 272 insertions(+), 139 deletions(-)

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 40563a11d4..47a37e1401 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -355,6 +355,15 @@ class BaseExecutor(LoggingMixin):
         """
         raise NotImplementedError()
 
+    def get_task_log(self, ti: TaskInstance, log: str = "") -> None | str | tuple[str, dict[str, bool]]:
+        """
+        This method can be implemented by any child class to return the task logs.
+
+        :param ti: A TaskInstance object
+        :param log: log str
+        :return: logs or tuple of logs and meta dict
+        """
+
     def end(self) -> None:  # pragma: no cover
         """Wait synchronously for the previously submitted job to complete."""
         raise NotImplementedError()
diff --git a/airflow/executors/celery_kubernetes_executor.py b/airflow/executors/celery_kubernetes_executor.py
index b477d25eaa..8426fb526f 100644
--- a/airflow/executors/celery_kubernetes_executor.py
+++ b/airflow/executors/celery_kubernetes_executor.py
@@ -141,6 +141,12 @@ class CeleryKubernetesExecutor(LoggingMixin):
             cfg_path=cfg_path,
         )
 
+    def get_task_log(self, ti: TaskInstance, log: str = "") -> None | str | tuple[str, dict[str, bool]]:
+        """Fetch task log from Kubernetes executor"""
+        if ti.queue == self.kubernetes_executor.kubernetes_queue:
+            return self.kubernetes_executor.get_task_log(ti=ti, log=log)
+        return None
+
     def has_task(self, task_instance: TaskInstance) -> bool:
         """
         Checks if a task is either queued or running in either celery or kubernetes executor.
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index e1d7b06a98..739b41de5d 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -28,6 +28,7 @@ import logging
 import multiprocessing
 import time
 from collections import defaultdict
+from contextlib import suppress
 from datetime import timedelta
 from queue import Empty, Queue
 from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple
@@ -37,6 +38,7 @@ from kubernetes.client import Configuration, models as k8s
 from kubernetes.client.rest import ApiException
 from urllib3.exceptions import ReadTimeoutError
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException, PodMutationHookException, PodReconciliationError
 from airflow.executors.base_executor import BaseExecutor, CommandType
 from airflow.kubernetes import pod_generator
@@ -771,6 +773,57 @@ class KubernetesExecutor(BaseExecutor):
             # do this once, so only do it when we remove the task from running
             self.event_buffer[key] = state, None
 
+    @staticmethod
+    def _get_pod_namespace(ti: TaskInstance):
+        pod_override = ti.executor_config.get("pod_override")
+        namespace = None
+        with suppress(Exception):
+            namespace = pod_override.metadata.namespace
+        return namespace or conf.get("kubernetes_executor", "namespace", fallback="default")
+
+    def get_task_log(self, ti: TaskInstance, log: str = "") -> str | tuple[str, dict[str, bool]]:
+
+        try:
+            from airflow.kubernetes.pod_generator import PodGenerator
+
+            client = get_kube_client()
+
+            log += f"*** Trying to get logs (last 100 lines) from worker pod {ti.hostname} ***\n\n"
+            selector = PodGenerator.build_selector_for_k8s_executor_pod(
+                dag_id=ti.dag_id,
+                task_id=ti.task_id,
+                try_number=ti.try_number,
+                map_index=ti.map_index,
+                run_id=ti.run_id,
+                airflow_worker=ti.queued_by_job_id,
+            )
+            namespace = self._get_pod_namespace(ti)
+            pod_list = client.list_namespaced_pod(
+                namespace=namespace,
+                label_selector=selector,
+            ).items
+            if not pod_list:
+                raise RuntimeError("Cannot find pod for ti %s", ti)
+            elif len(pod_list) > 1:
+                raise RuntimeError("Found multiple pods for ti %s: %s", ti, pod_list)
+            res = client.read_namespaced_pod_log(
+                name=pod_list[0].metadata.name,
+                namespace=namespace,
+                container="base",
+                follow=False,
+                tail_lines=100,
+                _preload_content=False,
+            )
+
+            for line in res:
+                log += line.decode()
+
+            return log
+
+        except Exception as f:
+            log += f"*** Unable to fetch logs from worker pod {ti.hostname} ***\n{str(f)}\n\n"
+            return log, {"end_of_log": True}
+
     def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
         tis_to_flush = [ti for ti in tis if not ti.queued_by_job_id]
         scheduler_job_ids = {ti.queued_by_job_id for ti in tis}
diff --git a/airflow/executors/local_kubernetes_executor.py b/airflow/executors/local_kubernetes_executor.py
index f723ab9998..258135f31c 100644
--- a/airflow/executors/local_kubernetes_executor.py
+++ b/airflow/executors/local_kubernetes_executor.py
@@ -142,6 +142,13 @@ class LocalKubernetesExecutor(LoggingMixin):
             cfg_path=cfg_path,
         )
 
+    def get_task_log(self, ti: TaskInstance, log: str = "") -> None | str | tuple[str, dict[str, bool]]:
+        """Fetch task log from kubernetes executor"""
+        if ti.queue == self.kubernetes_executor.kubernetes_queue:
+            return self.kubernetes_executor.get_task_log(ti=ti, log=log)
+
+        return None
+
     def has_task(self, task_instance: TaskInstance) -> bool:
         """
         Checks if a task is either queued or running in either local or kubernetes executor.
diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index 09fbbbe097..0d54783244 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -26,8 +26,9 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any
 from urllib.parse import urljoin
 
-from airflow.configuration import AirflowConfigException, conf
-from airflow.exceptions import RemovedInAirflow3Warning
+from airflow.configuration import conf
+from airflow.exceptions import AirflowConfigException, RemovedInAirflow3Warning
+from airflow.executors.executor_loader import ExecutorLoader
 from airflow.utils.context import Context
 from airflow.utils.helpers import parse_template_string, render_template_to_string
 from airflow.utils.log.logging_mixin import SetContextPropagate
@@ -146,23 +147,54 @@ class FileTaskHandler(logging.Handler):
     def _read_grouped_logs(self):
         return False
 
-    @staticmethod
-    def _should_check_k8s(queue):
-        """
-        If the task is running through kubernetes executor, return True.
+    def _get_task_log_from_worker(
+        self, ti: TaskInstance, log: str, log_relative_path: str
+    ) -> str | tuple[str, dict[str, bool]]:
+        import httpx
 
-        When logs aren't available locally, in this case we read from k8s pod logs.
-        """
-        executor = conf.get("core", "executor")
-        if executor == "KubernetesExecutor":
-            return True
-        elif executor == "LocalKubernetesExecutor":
-            if queue == conf.get("local_kubernetes_executor", "kubernetes_queue"):
-                return True
-        elif executor == "CeleryKubernetesExecutor":
-            if queue == conf.get("celery_kubernetes_executor", "kubernetes_queue"):
-                return True
-        return False
+        from airflow.utils.jwt_signer import JWTSigner
+
+        url = self._get_log_retrieval_url(ti, log_relative_path)
+        log += f"*** Fetching from: {url}\n"
+
+        try:
+            timeout = None  # No timeout
+            try:
+                timeout = conf.getint("webserver", "log_fetch_timeout_sec")
+            except (AirflowConfigException, ValueError):
+                pass
+
+            signer = JWTSigner(
+                secret_key=conf.get("webserver", "secret_key"),
+                expiration_time_in_seconds=conf.getint("webserver", "log_request_clock_grace", fallback=30),
+                audience="task-instance-logs",
+            )
+            response = httpx.get(
+                url,
+                timeout=timeout,
+                headers={"Authorization": signer.generate_signed_token({"filename": log_relative_path})},
+            )
+            response.encoding = "utf-8"
+
+            if response.status_code == 403:
+                log += (
+                    "*** !!!! Please make sure that all your Airflow components (e.g. "
+                    "schedulers, webservers and workers) have "
+                    "the same 'secret_key' configured in 'webserver' section and "
+                    "time is synchronized on all your machines (for example with ntpd) !!!!!\n***"
+                )
+                log += (
+                    "*** See more at https://airflow.apache.org/docs/apache-airflow/"
+                    "stable/configurations-ref.html#secret-key\n***"
+                )
+            # Check if the resource was properly fetched
+            response.raise_for_status()
+
+            log += "\n" + response.text
+            return log
+        except Exception as e:
+            log += f"*** Failed to fetch log file from worker. {str(e)}\n"
+            return log, {"end_of_log": True}
 
     def _read(self, ti: TaskInstance, try_number: int, metadata: dict[str, Any] | None = None):
         """
@@ -186,8 +218,6 @@ class FileTaskHandler(logging.Handler):
                              This is determined by the status of the TaskInstance
                  log_pos: (absolute) Char position to which the log is retrieved
         """
-        from airflow.utils.jwt_signer import JWTSigner
-
         # Task instance here might be different from task instance when
         # initializing the handler. Thus explicitly getting log location
         # is needed to get correct log path.
@@ -204,91 +234,23 @@ class FileTaskHandler(logging.Handler):
                 log = f"*** Failed to load local log file: {location}\n"
                 log += f"*** {str(e)}\n"
                 return log, {"end_of_log": True}
-        elif self._should_check_k8s(ti.queue):
-            try:
-                from airflow.kubernetes.kube_client import get_kube_client
-                from airflow.kubernetes.pod_generator import PodGenerator
-
-                client = get_kube_client()
-
-                log += f"*** Trying to get logs (last 100 lines) from worker pod {ti.hostname} ***\n\n"
-                selector = PodGenerator.build_selector_for_k8s_executor_pod(
-                    dag_id=ti.dag_id,
-                    task_id=ti.task_id,
-                    try_number=ti.try_number,
-                    map_index=ti.map_index,
-                    run_id=ti.run_id,
-                    airflow_worker=ti.queued_by_job_id,
-                )
-                namespace = self._get_pod_namespace(ti)
-                pod_list = client.list_namespaced_pod(
-                    namespace=namespace,
-                    label_selector=selector,
-                ).items
-                if not pod_list:
-                    raise RuntimeError("Cannot find pod for ti %s", ti)
-                elif len(pod_list) > 1:
-                    raise RuntimeError("Found multiple pods for ti %s: %s", ti, pod_list)
-                res = client.read_namespaced_pod_log(
-                    name=pod_list[0].metadata.name,
-                    namespace=namespace,
-                    container="base",
-                    follow=False,
-                    tail_lines=100,
-                    _preload_content=False,
-                )
+        else:
+            log += f"*** Local log file does not exist: {location}\n"
+            executor = ExecutorLoader.get_default_executor()
+            task_log = None
 
-                for line in res:
-                    log += line.decode()
+            task_log = executor.get_task_log(ti=ti, log=log)
+            if isinstance(task_log, tuple):
+                return task_log
 
-            except Exception as f:
-                log += f"*** Unable to fetch logs from worker pod {ti.hostname} ***\n{str(f)}\n\n"
-                return log, {"end_of_log": True}
-        else:
-            import httpx
+            if task_log is None:
+                log += "*** Failed to fetch log from executor. Falling back to fetching log from worker.\n"
+                task_log = self._get_task_log_from_worker(ti, log, log_relative_path=log_relative_path)
 
-            url = self._get_log_retrieval_url(ti, log_relative_path)
-            log += f"*** Log file does not exist: {location}\n"
-            log += f"*** Fetching from: {url}\n"
-            try:
-                timeout = None  # No timeout
-                try:
-                    timeout = conf.getint("webserver", "log_fetch_timeout_sec")
-                except (AirflowConfigException, ValueError):
-                    pass
-
-                signer = JWTSigner(
-                    secret_key=conf.get("webserver", "secret_key"),
-                    expiration_time_in_seconds=conf.getint(
-                        "webserver", "log_request_clock_grace", fallback=30
-                    ),
-                    audience="task-instance-logs",
-                )
-                response = httpx.get(
-                    url,
-                    timeout=timeout,
-                    headers={"Authorization": signer.generate_signed_token({"filename": log_relative_path})},
-                )
-                response.encoding = "utf-8"
-
-                if response.status_code == 403:
-                    log += (
-                        "*** !!!! Please make sure that all your Airflow components (e.g. "
-                        "schedulers, webservers and workers) have "
-                        "the same 'secret_key' configured in 'webserver' section and "
-                        "time is synchronized on all your machines (for example with ntpd) !!!!!\n***"
-                    )
-                    log += (
-                        "*** See more at https://airflow.apache.org/docs/apache-airflow/"
-                        "stable/configurations-ref.html#secret-key\n***"
-                    )
-                # Check if the resource was properly fetched
-                response.raise_for_status()
-
-                log += "\n" + response.text
-            except Exception as e:
-                log += f"*** Failed to fetch log file from worker. {str(e)}\n"
-                return log, {"end_of_log": True}
+            if isinstance(task_log, tuple):
+                return task_log
+
+            log = str(task_log)
 
         # Process tailing if log is not at it's end
         end_of_log = ti.try_number != try_number or ti.state not in [State.RUNNING, State.DEFERRED]
diff --git a/tests/executors/test_base_executor.py b/tests/executors/test_base_executor.py
index 80650c83d4..30ddaaacc4 100644
--- a/tests/executors/test_base_executor.py
+++ b/tests/executors/test_base_executor.py
@@ -27,7 +27,7 @@ from pytest import mark
 
 from airflow.executors.base_executor import BaseExecutor, RunningRetryAttemptType
 from airflow.models.baseoperator import BaseOperator
-from airflow.models.taskinstance import TaskInstanceKey
+from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
 from airflow.utils import timezone
 from airflow.utils.state import State
 
@@ -44,6 +44,12 @@ def test_is_local_default_value():
     assert not BaseExecutor.is_local
 
 
+def test_get_task_log():
+    executor = BaseExecutor()
+    ti = TaskInstance(task=BaseOperator(task_id="dummy"))
+    assert executor.get_task_log(ti=ti) is None
+
+
 def test_serve_logs_default_value():
     assert not BaseExecutor.serve_logs
 
diff --git a/tests/executors/test_celery_kubernetes_executor.py b/tests/executors/test_celery_kubernetes_executor.py
index 361481417f..89ccfada2f 100644
--- a/tests/executors/test_celery_kubernetes_executor.py
+++ b/tests/executors/test_celery_kubernetes_executor.py
@@ -173,6 +173,22 @@ class TestCeleryKubernetesExecutor:
         celery_executor_mock.try_adopt_task_instances.assert_called_once_with(celery_tis)
         k8s_executor_mock.try_adopt_task_instances.assert_called_once_with(k8s_tis)
 
+    def test_log_is_fetched_from_k8s_executor_only_for_k8s_queue(self):
+        celery_executor_mock = mock.MagicMock()
+        k8s_executor_mock = mock.MagicMock()
+        cke = CeleryKubernetesExecutor(celery_executor_mock, k8s_executor_mock)
+        simple_task_instance = mock.MagicMock()
+        simple_task_instance.queue = KUBERNETES_QUEUE
+        cke.get_task_log(ti=simple_task_instance, log="")
+        k8s_executor_mock.get_task_log.assert_called_once_with(ti=simple_task_instance, log=mock.ANY)
+
+        k8s_executor_mock.reset_mock()
+
+        simple_task_instance.queue = "test-queue"
+        log = cke.get_task_log(ti=simple_task_instance, log="")
+        k8s_executor_mock.get_task_log.assert_not_called()
+        assert log is None
+
     def test_get_event_buffer(self):
         celery_executor_mock = mock.MagicMock()
         k8s_executor_mock = mock.MagicMock()
diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py
index 0ca2cd3a96..99d6faf527 100644
--- a/tests/executors/test_kubernetes_executor.py
+++ b/tests/executors/test_kubernetes_executor.py
@@ -34,6 +34,7 @@ from airflow import AirflowException
 from airflow.exceptions import PodReconciliationError
 from airflow.models.taskinstance import TaskInstanceKey
 from airflow.operators.bash import BashOperator
+from airflow.operators.empty import EmptyOperator
 from airflow.utils import timezone
 from tests.test_utils.config import conf_vars
 
@@ -1215,6 +1216,33 @@ class TestKubernetesExecutor:
         assert ti0.state == State.SCHEDULED
         assert ti1.state == State.QUEUED
 
+    @mock.patch("airflow.executors.kubernetes_executor.get_kube_client")
+    def test_get_task_log(self, mock_get_kube_client, create_task_instance_of_operator):
+        """fetch task log from pod"""
+        mock_kube_client = mock_get_kube_client.return_value
+
+        mock_kube_client.read_namespaced_pod_log.return_value = [b"a_", b"b_", b"c_"]
+        mock_pod = mock.Mock()
+        mock_pod.metadata.name = "x"
+        mock_kube_client.list_namespaced_pod.return_value.items = [mock_pod]
+        ti = create_task_instance_of_operator(EmptyOperator, dag_id="test_k8s_log_dag", task_id="test_task")
+
+        executor = KubernetesExecutor()
+        log = executor.get_task_log(ti=ti, log="test_init_log")
+
+        mock_kube_client.read_namespaced_pod_log.assert_called_once()
+        assert "test_init_log" in log
+        assert "Trying to get logs (last 100 lines) from worker pod" in log
+        assert "a_b_c" in log
+
+        mock_kube_client.reset_mock()
+        mock_kube_client.read_namespaced_pod_log.side_effect = Exception("error_fetching_pod_log")
+
+        log = executor.get_task_log(ti=ti, log="test_init_log")
+        assert len(log) == 2
+        assert "error_fetching_pod_log" in log[0]
+        assert log[1]["end_of_log"]
+
     def test_supports_pickling(self):
         assert KubernetesExecutor.supports_pickling
 
diff --git a/tests/executors/test_local_kubernetes_executor.py b/tests/executors/test_local_kubernetes_executor.py
index 809b0277df..497d3a5f9b 100644
--- a/tests/executors/test_local_kubernetes_executor.py
+++ b/tests/executors/test_local_kubernetes_executor.py
@@ -83,6 +83,25 @@ class TestLocalKubernetesExecutor:
 
         assert k8s_executor_mock.kubernetes_queue == conf.get("local_kubernetes_executor", "kubernetes_queue")
 
+    def test_log_is_fetched_from_k8s_executor_only_for_k8s_queue(self):
+        local_executor_mock = mock.MagicMock()
+        k8s_executor_mock = mock.MagicMock()
+
+        KUBERNETES_QUEUE = conf.get("local_kubernetes_executor", "kubernetes_queue")
+        LocalKubernetesExecutor(local_executor_mock, k8s_executor_mock)
+        local_k8s_exec = LocalKubernetesExecutor(local_executor_mock, k8s_executor_mock)
+        simple_task_instance = mock.MagicMock()
+        simple_task_instance.queue = KUBERNETES_QUEUE
+        local_k8s_exec.get_task_log(ti=simple_task_instance, log="")
+        k8s_executor_mock.get_task_log.assert_called_once_with(ti=simple_task_instance, log=mock.ANY)
+
+        k8s_executor_mock.reset_mock()
+
+        simple_task_instance.queue = "test-queue"
+        log = local_k8s_exec.get_task_log(ti=simple_task_instance, log="")
+        k8s_executor_mock.get_task_log.assert_not_called()
+        assert log is None
+
     def test_send_callback(self):
         local_executor_mock = mock.MagicMock()
         k8s_executor_mock = mock.MagicMock()
diff --git a/tests/providers/amazon/aws/log/test_s3_task_handler.py b/tests/providers/amazon/aws/log/test_s3_task_handler.py
index 3abd472952..8b01025e3f 100644
--- a/tests/providers/amazon/aws/log/test_s3_task_handler.py
+++ b/tests/providers/amazon/aws/log/test_s3_task_handler.py
@@ -138,7 +138,7 @@ class TestS3TaskHandler:
 
         assert 1 == len(log)
         assert len(log) == len(metadata)
-        assert "*** Log file does not exist:" in log[0][0][-1]
+        assert "*** Local log file does not exist:" in log[0][0][-1]
         assert {"end_of_log": True} == metadata[0]
 
     def test_s3_read_when_log_missing(self):
diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py
index 3a49a47f6b..04cdeec8c4 100644
--- a/tests/utils/test_log_handlers.py
+++ b/tests/utils/test_log_handlers.py
@@ -21,7 +21,8 @@ import logging
 import logging.config
 import os
 import re
-from unittest.mock import patch
+from unittest import mock
+from unittest.mock import mock_open, patch
 
 import pytest
 from kubernetes.client import models as k8s
@@ -35,6 +36,7 @@ from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.timezone import datetime
 from airflow.utils.types import DagRunType
+from tests.test_utils.config import conf_vars
 
 DEFAULT_DATE = datetime(2016, 1, 1)
 TASK_LOGGER = "airflow.task"
@@ -220,6 +222,64 @@ class TestFileTaskLogHandler:
         # Remove the generated tmp log file.
         os.remove(log_filename)
 
+    def test__read_from_location(self, create_task_instance):
+        """Test if local log file exists, then log is read from it"""
+        local_log_file_read = create_task_instance(
+            dag_id="dag_for_testing_local_log_read",
+            task_id="task_for_testing_local_log_read",
+            run_type=DagRunType.SCHEDULED,
+            execution_date=DEFAULT_DATE,
+        )
+        with patch("os.path.exists", return_value=True):
+            opener = mock_open(read_data="dummy test log data")
+            with patch("airflow.utils.log.file_task_handler.open", opener):
+                fth = FileTaskHandler("")
+                log = fth._read(ti=local_log_file_read, try_number=1)
+                assert len(log) == 2
+                assert "dummy test log data" in log[0]
+
+    @mock.patch("airflow.executors.kubernetes_executor.KubernetesExecutor.get_task_log")
+    def test__read_for_k8s_executor(self, mock_k8s_get_task_log, create_task_instance):
+        """Test for k8s executor, the log is read from get_task_log method"""
+        executor_name = "KubernetesExecutor"
+        ti = create_task_instance(
+            dag_id="dag_for_testing_k8s_executor_log_read",
+            task_id="task_for_testing_k8s_executor_log_read",
+            run_type=DagRunType.SCHEDULED,
+            execution_date=DEFAULT_DATE,
+        )
+
+        with conf_vars({("core", "executor"): executor_name}):
+            with patch("os.path.exists", return_value=False):
+                fth = FileTaskHandler("")
+                fth._read(ti=ti, try_number=1)
+                mock_k8s_get_task_log.assert_called_once_with(ti=ti, log=mock.ANY)
+
+    def test__read_for_celery_executor_fallbacks_to_worker(self, create_task_instance):
+        """Test for executors which do not have `get_task_log` method, it fallbacks to reading
+        log from worker"""
+        executor_name = "CeleryExecutor"
+
+        ti = create_task_instance(
+            dag_id="dag_for_testing_celery_executor_log_read",
+            task_id="task_for_testing_celery_executor_log_read",
+            run_type=DagRunType.SCHEDULED,
+            execution_date=DEFAULT_DATE,
+        )
+
+        with conf_vars({("core", "executor"): executor_name}):
+            with patch("os.path.exists", return_value=False):
+                fth = FileTaskHandler("")
+
+                def mock_log_from_worker(ti, log, log_relative_path):
+                    return (log, {"end_of_log": True})
+
+                fth._get_task_log_from_worker = mock.Mock(side_effect=mock_log_from_worker)
+                log = fth._read(ti=ti, try_number=1)
+                fth._get_task_log_from_worker.assert_called_once()
+                assert "Local log file does not exist" in log[0]
+                assert "Failed to fetch log from executor. Falling back to fetching log from worker" in log[0]
+
     @pytest.mark.parametrize(
         "pod_override, namespace_to_call",
         [
@@ -231,7 +291,7 @@ class TestFileTaskLogHandler:
         ],
     )
     @patch.dict("os.environ", AIRFLOW__CORE__EXECUTOR="KubernetesExecutor")
-    @patch("airflow.kubernetes.kube_client.get_kube_client")
+    @patch("airflow.executors.kubernetes_executor.get_kube_client")
     def test_read_from_k8s_under_multi_namespace_mode(
         self, mock_kube_client, pod_override, namespace_to_call
     ):
@@ -342,36 +402,3 @@ class TestLogUrl:
         log_url_ti.hostname = "hostname"
         url = FileTaskHandler._get_log_retrieval_url(log_url_ti, "DYNAMIC_PATH")
         assert url == "http://hostname:8793/log/DYNAMIC_PATH"
-
-
-@pytest.mark.parametrize(
-    "config, queue, expected",
-    [
-        (dict(AIRFLOW__CORE__EXECUTOR="LocalExecutor"), None, False),
-        (dict(AIRFLOW__CORE__EXECUTOR="LocalExecutor"), "kubernetes", False),
-        (dict(AIRFLOW__CORE__EXECUTOR="KubernetesExecutor"), None, True),
-        (dict(AIRFLOW__CORE__EXECUTOR="CeleryKubernetesExecutor"), "any", False),
-        (dict(AIRFLOW__CORE__EXECUTOR="CeleryKubernetesExecutor"), "kubernetes", True),
-        (
-            dict(
-                AIRFLOW__CORE__EXECUTOR="CeleryKubernetesExecutor",
-                AIRFLOW__CELERY_KUBERNETES_EXECUTOR__KUBERNETES_QUEUE="hithere",
-            ),
-            "hithere",
-            True,
-        ),
-        (dict(AIRFLOW__CORE__EXECUTOR="LocalKubernetesExecutor"), "any", False),
-        (dict(AIRFLOW__CORE__EXECUTOR="LocalKubernetesExecutor"), "kubernetes", True),
-        (
-            dict(
-                AIRFLOW__CORE__EXECUTOR="LocalKubernetesExecutor",
-                AIRFLOW__LOCAL_KUBERNETES_EXECUTOR__KUBERNETES_QUEUE="hithere",
-            ),
-            "hithere",
-            True,
-        ),
-    ],
-)
-def test__should_check_k8s(config, queue, expected):
-    with patch.dict("os.environ", **config):
-        assert FileTaskHandler._should_check_k8s(queue) == expected

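As context for the file_task_handler.py changes above, a condensed standalone sketch of
the log-resolution order that FileTaskHandler._read now follows (local file, then the
executor's get_task_log hook, then an HTTP fetch from the worker). The resolve_task_log()
function is illustrative only; the real method also handles path templating, tailing and
metadata.

    import os

    from airflow.executors.executor_loader import ExecutorLoader


    def resolve_task_log(handler, ti, location, log_relative_path):
        """Illustrative sketch of the fallback order introduced by this commit."""
        log = ""
        if os.path.exists(location):
            # Local-file branch (simplified): the handler reads the file directly.
            with open(location) as file:
                return log + file.read()
        log += f"*** Local log file does not exist: {location}\n"
        executor = ExecutorLoader.get_default_executor()
        task_log = executor.get_task_log(ti=ti, log=log)
        if isinstance(task_log, tuple):
            # A (log, {"end_of_log": True}) tuple from the executor short-circuits the read.
            return task_log
        if task_log is None:
            # Executors that do not implement get_task_log fall back to the worker fetch.
            log += "*** Failed to fetch log from executor. Falling back to fetching log from worker.\n"
            task_log = handler._get_task_log_from_worker(ti, log, log_relative_path=log_relative_path)
        return task_log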