This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 51dbabca597 Fix `FileTaskHandler` only read from default executor 
(#45631)
51dbabca597 is described below

commit 51dbabca5977bf405f853bef1734c1eb2fd66dd8
Author: LIU ZHE YOU <[email protected]>
AuthorDate: Fri Jan 24 10:58:44 2025 +0800

    Fix `FileTaskHandler` only read from default executor (#45631)
    
    * Fix FileTaskHandler only read from default executor
    
    * Add cached_property back to avoid loading executors
    
    * Add test for multi-executors scenario
    
    * Allow calling load_executor without init_executors
    
    * Refactor by caching necessary executors
    
    * Refactor test with default executor case
    
    * Fix side effect from executor_loader
    
    * Fix KubernetesExecutor test
    
    - The previous test failure was caused by the cached state of executor_instances
    - ti.state should be set to RUNNING after ti.run
    
    * Fix side effect from executor_loader
    
    - The side effect only shows up with Postgres as the backend, as the
      previous fix only resolved the side effect with SQLite as the backend.
    - Also refactor clean_executor_loader into a pytest fixture with setup
      and teardown
    
    * Capitalize default executor key
    
    * Refactor clean_executor_loader fixture
---
 airflow/executors/executor_loader.py               |  4 +
 airflow/utils/log/file_task_handler.py             | 33 ++++++--
 .../kubernetes/log_handlers/test_log_handlers.py   |  8 +-
 tests/executors/test_executor_loader.py            | 89 +++++++++++----------
 tests/ti_deps/deps/test_ready_to_reschedule_dep.py |  1 +
 tests/utils/test_log_handlers.py                   | 91 +++++++++++++++++++++-
 tests_common/pytest_plugin.py                      | 13 ++++
 tests_common/test_utils/executor_loader.py         | 34 ++++++++
 8 files changed, 223 insertions(+), 50 deletions(-)

diff --git a/airflow/executors/executor_loader.py 
b/airflow/executors/executor_loader.py
index 2651718bbad..6d6b8d115bc 100644
--- a/airflow/executors/executor_loader.py
+++ b/airflow/executors/executor_loader.py
@@ -231,6 +231,10 @@ class ExecutorLoader:
     @classmethod
     def lookup_executor_name_by_str(cls, executor_name_str: str) -> 
ExecutorName:
         # lookup the executor by alias first, if not check if we're given a 
module path
+        if not _classname_to_executors or not _module_to_executors or not 
_alias_to_executors:
+            # if we haven't loaded the executors yet, such as directly calling 
load_executor
+            cls._get_executor_names()
+
         if executor_name := _alias_to_executors.get(executor_name_str):
             return executor_name
         elif executor_name := _module_to_executors.get(executor_name_str):
diff --git a/airflow/utils/log/file_task_handler.py 
b/airflow/utils/log/file_task_handler.py
index 73ee79126a9..21b745affbc 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -24,7 +24,6 @@ import os
 from collections.abc import Iterable
 from contextlib import suppress
 from enum import Enum
-from functools import cached_property
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable
 from urllib.parse import urljoin
@@ -44,6 +43,7 @@ from airflow.utils.state import State, TaskInstanceState
 if TYPE_CHECKING:
     from pendulum import DateTime
 
+    from airflow.executors.base_executor import BaseExecutor
     from airflow.models.taskinstance import TaskInstance
     from airflow.models.taskinstancekey import TaskInstanceKey
 
@@ -179,6 +179,8 @@ class FileTaskHandler(logging.Handler):
     inherits_from_empty_operator_log_message = (
         "Operator inherits from empty operator and thus does not have logs"
     )
+    executor_instances: dict[str, BaseExecutor] = {}
+    DEFAULT_EXECUTOR_KEY = "_default_executor"
 
     def __init__(
         self,
@@ -314,11 +316,27 @@ class FileTaskHandler(logging.Handler):
     def _read_grouped_logs(self):
         return False
 
-    @cached_property
-    def _executor_get_task_log(self) -> Callable[[TaskInstance, int], 
tuple[list[str], list[str]]]:
-        """This cached property avoids loading executor repeatedly."""
-        executor = ExecutorLoader.get_default_executor()
-        return executor.get_task_log
+    def _get_executor_get_task_log(
+        self, ti: TaskInstance
+    ) -> Callable[[TaskInstance, int], tuple[list[str], list[str]]]:
+        """
+        Get the get_task_log method from executor of current task instance.
+
+        Since there might be multiple executors, so we need to get the 
executor of current task instance instead of getting from default executor.
+
+        :param ti: task instance object
+        :return: get_task_log method of the executor
+        """
+        executor_name = ti.executor or self.DEFAULT_EXECUTOR_KEY
+        executor = self.executor_instances.get(executor_name)
+        if executor is not None:
+            return executor.get_task_log
+
+        if executor_name == self.DEFAULT_EXECUTOR_KEY:
+            self.executor_instances[executor_name] = 
ExecutorLoader.get_default_executor()
+        else:
+            self.executor_instances[executor_name] = 
ExecutorLoader.load_executor(executor_name)
+        return self.executor_instances[executor_name].get_task_log
 
     def _read(
         self,
@@ -360,7 +378,8 @@ class FileTaskHandler(logging.Handler):
             messages_list.extend(remote_messages)
         has_k8s_exec_pod = False
         if ti.state == TaskInstanceState.RUNNING:
-            response = self._executor_get_task_log(ti, try_number)
+            executor_get_task_log = self._get_executor_get_task_log(ti)
+            response = executor_get_task_log(ti, try_number)
             if response:
                 executor_messages, executor_logs = response
             if executor_messages:
diff --git a/providers/tests/cncf/kubernetes/log_handlers/test_log_handlers.py 
b/providers/tests/cncf/kubernetes/log_handlers/test_log_handlers.py
index 9cbebcf8df9..d89fbdf6edb 100644
--- a/providers/tests/cncf/kubernetes/log_handlers/test_log_handlers.py
+++ b/providers/tests/cncf/kubernetes/log_handlers/test_log_handlers.py
@@ -74,6 +74,7 @@ class TestFileTaskLogHandler:
         
"airflow.providers.cncf.kubernetes.executors.kubernetes_executor.KubernetesExecutor.get_task_log"
     )
     @pytest.mark.parametrize("state", [TaskInstanceState.RUNNING, 
TaskInstanceState.SUCCESS])
+    @pytest.mark.usefixtures("clean_executor_loader")
     def test__read_for_k8s_executor(self, mock_k8s_get_task_log, 
create_task_instance, state):
         """Test for k8s executor, the log is read from get_task_log method"""
         mock_k8s_get_task_log.return_value = ([], [])
@@ -86,6 +87,7 @@ class TestFileTaskLogHandler:
         )
         ti.state = state
         ti.triggerer_job = None
+        ti.executor = executor_name
         with conf_vars({("core", "executor"): executor_name}):
             reload(executor_loader)
             fth = FileTaskHandler("")
@@ -105,11 +107,12 @@ class TestFileTaskLogHandler:
             
pytest.param(k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="pod-name-xxx")), 
"default"),
         ],
     )
-    @patch.dict("os.environ", AIRFLOW__CORE__EXECUTOR="KubernetesExecutor")
+    @conf_vars({("core", "executor"): "KubernetesExecutor"})
     @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
     def test_read_from_k8s_under_multi_namespace_mode(
         self, mock_kube_client, pod_override, namespace_to_call
     ):
+        reload(executor_loader)
         mock_read_log = mock_kube_client.return_value.read_namespaced_pod_log
         mock_list_pod = mock_kube_client.return_value.list_namespaced_pod
 
@@ -139,6 +142,7 @@ class TestFileTaskLogHandler:
         )
         ti = TaskInstance(task=task, run_id=dagrun.run_id)
         ti.try_number = 3
+        ti.executor = "KubernetesExecutor"
 
         logger = ti.log
         ti.log.disabled = False
@@ -147,6 +151,8 @@ class TestFileTaskLogHandler:
         set_context(logger, ti)
         ti.run(ignore_ti_state=True)
         ti.state = TaskInstanceState.RUNNING
+        # clear executor_instances cache
+        file_handler.executor_instances = {}
         file_handler.read(ti, 2)
 
         # first we find pod name
diff --git a/tests/executors/test_executor_loader.py 
b/tests/executors/test_executor_loader.py
index 87455bd841b..de6703954b1 100644
--- a/tests/executors/test_executor_loader.py
+++ b/tests/executors/test_executor_loader.py
@@ -16,14 +16,13 @@
 # under the License.
 from __future__ import annotations
 
-from importlib import reload
 from unittest import mock
 
 import pytest
 
 from airflow.exceptions import AirflowConfigException
 from airflow.executors import executor_loader
-from airflow.executors.executor_loader import ConnectorSource, ExecutorLoader, 
ExecutorName
+from airflow.executors.executor_loader import ConnectorSource, ExecutorName
 from airflow.executors.local_executor import LocalExecutor
 from airflow.providers.amazon.aws.executors.ecs.ecs_executor import 
AwsEcsExecutor
 from airflow.providers.celery.executors.celery_executor import CeleryExecutor
@@ -35,24 +34,12 @@ class FakeExecutor:
     pass
 
 
[email protected]("clean_executor_loader")
 class TestExecutorLoader:
-    def setup_method(self) -> None:
-        from airflow.executors import executor_loader
-
-        reload(executor_loader)
-        global ExecutorLoader
-        ExecutorLoader = executor_loader.ExecutorLoader  # type: ignore
-
-    def teardown_method(self) -> None:
-        from airflow.executors import executor_loader
-
-        reload(executor_loader)
-        ExecutorLoader.init_executors()
-
     def test_no_executor_configured(self):
         with conf_vars({("core", "executor"): None}):
             with pytest.raises(AirflowConfigException, match=r".*not found in 
config$"):
-                ExecutorLoader.get_default_executor()
+                executor_loader.ExecutorLoader.get_default_executor()
 
     @pytest.mark.parametrize(
         "executor_name",
@@ -66,16 +53,18 @@ class TestExecutorLoader:
     )
     def test_should_support_executor_from_core(self, executor_name):
         with conf_vars({("core", "executor"): executor_name}):
-            executor = ExecutorLoader.get_default_executor()
+            executor = executor_loader.ExecutorLoader.get_default_executor()
             assert executor is not None
             assert executor_name == executor.__class__.__name__
             assert executor.name is not None
-            assert executor.name == 
ExecutorName(ExecutorLoader.executors[executor_name], alias=executor_name)
+            assert executor.name == ExecutorName(
+                executor_loader.ExecutorLoader.executors[executor_name], 
alias=executor_name
+            )
             assert executor.name.connector_source == ConnectorSource.CORE
 
     def test_should_support_custom_path(self):
         with conf_vars({("core", "executor"): 
"tests.executors.test_executor_loader.FakeExecutor"}):
-            executor = ExecutorLoader.get_default_executor()
+            executor = executor_loader.ExecutorLoader.get_default_executor()
             assert executor is not None
             assert executor.__class__.__name__ == "FakeExecutor"
             assert executor.name is not None
@@ -249,17 +238,17 @@ class TestExecutorLoader:
                 
"airflow.executors.executor_loader.ExecutorLoader._get_team_executor_configs",
                 return_value=team_executor_config,
             ):
-                executors = ExecutorLoader._get_executor_names()
+                executors = 
executor_loader.ExecutorLoader._get_executor_names()
                 assert executors == expected_executors_list
 
     def test_init_executors(self):
         with conf_vars({("core", "executor"): "CeleryExecutor"}):
-            executors = ExecutorLoader.init_executors()
-            executor_name = ExecutorLoader.get_default_executor_name()
+            executors = executor_loader.ExecutorLoader.init_executors()
+            executor_name = 
executor_loader.ExecutorLoader.get_default_executor_name()
             assert len(executors) == 1
             assert isinstance(executors[0], CeleryExecutor)
-            assert "CeleryExecutor" in ExecutorLoader.executors
-            assert ExecutorLoader.executors["CeleryExecutor"] == 
executor_name.module_path
+            assert "CeleryExecutor" in executor_loader.ExecutorLoader.executors
+            assert executor_loader.ExecutorLoader.executors["CeleryExecutor"] 
== executor_name.module_path
 
     @pytest.mark.parametrize(
         "executor_config",
@@ -276,7 +265,7 @@ class TestExecutorLoader:
             with pytest.raises(
                 AirflowConfigException, match=r".+Duplicate executors are not 
yet supported.+"
             ):
-                ExecutorLoader._get_executor_names()
+                executor_loader.ExecutorLoader._get_executor_names()
 
     @pytest.mark.parametrize(
         "executor_config",
@@ -292,7 +281,7 @@ class TestExecutorLoader:
     def 
test_get_hybrid_executors_from_config_core_executors_bad_config_format(self, 
executor_config):
         with conf_vars({("core", "executor"): executor_config}):
             with pytest.raises(AirflowConfigException):
-                ExecutorLoader._get_executor_names()
+                executor_loader.ExecutorLoader._get_executor_names()
 
     @pytest.mark.parametrize(
         ("executor_config", "expected_value"),
@@ -308,7 +297,7 @@ class TestExecutorLoader:
     )
     def test_should_support_import_executor_from_core(self, executor_config, 
expected_value):
         with conf_vars({("core", "executor"): executor_config}):
-            executor, import_source = 
ExecutorLoader.import_default_executor_cls()
+            executor, import_source = 
executor_loader.ExecutorLoader.import_default_executor_cls()
             assert expected_value == executor.__name__
             assert import_source == ConnectorSource.CORE
 
@@ -322,26 +311,43 @@ class TestExecutorLoader:
     )
     def test_should_support_import_custom_path(self, executor_config):
         with conf_vars({("core", "executor"): executor_config}):
-            executor, import_source = 
ExecutorLoader.import_default_executor_cls()
+            executor, import_source = 
executor_loader.ExecutorLoader.import_default_executor_cls()
             assert executor.__name__ == "FakeExecutor"
             assert import_source == ConnectorSource.CUSTOM_PATH
 
     def test_load_executor(self):
         with conf_vars({("core", "executor"): "LocalExecutor"}):
-            ExecutorLoader.init_executors()
-            assert isinstance(ExecutorLoader.load_executor("LocalExecutor"), 
LocalExecutor)
-            assert 
isinstance(ExecutorLoader.load_executor(executor_loader._executor_names[0]), 
LocalExecutor)
-            assert isinstance(ExecutorLoader.load_executor(None), 
LocalExecutor)
+            executor_loader.ExecutorLoader.init_executors()
+            assert 
isinstance(executor_loader.ExecutorLoader.load_executor("LocalExecutor"), 
LocalExecutor)
+            assert isinstance(
+                
executor_loader.ExecutorLoader.load_executor(executor_loader._executor_names[0]),
+                LocalExecutor,
+            )
+            assert 
isinstance(executor_loader.ExecutorLoader.load_executor(None), LocalExecutor)
 
     def test_load_executor_alias(self):
         with conf_vars({("core", "executor"): 
"local_exec:airflow.executors.local_executor.LocalExecutor"}):
-            ExecutorLoader.init_executors()
-            assert isinstance(ExecutorLoader.load_executor("local_exec"), 
LocalExecutor)
+            executor_loader.ExecutorLoader.init_executors()
+            assert 
isinstance(executor_loader.ExecutorLoader.load_executor("local_exec"), 
LocalExecutor)
             assert isinstance(
-                
ExecutorLoader.load_executor("airflow.executors.local_executor.LocalExecutor"),
+                executor_loader.ExecutorLoader.load_executor(
+                    "airflow.executors.local_executor.LocalExecutor"
+                ),
+                LocalExecutor,
+            )
+            assert isinstance(
+                
executor_loader.ExecutorLoader.load_executor(executor_loader._executor_names[0]),
                 LocalExecutor,
             )
-            assert 
isinstance(ExecutorLoader.load_executor(executor_loader._executor_names[0]), 
LocalExecutor)
+
+    @mock.patch(
+        "airflow.executors.executor_loader.ExecutorLoader._get_executor_names",
+        wraps=executor_loader.ExecutorLoader._get_executor_names,
+    )
+    def test_call_load_executor_method_without_init_executors(self, 
mock_get_executor_names):
+        with conf_vars({("core", "executor"): "LocalExecutor"}):
+            executor_loader.ExecutorLoader.load_executor("LocalExecutor")
+            mock_get_executor_names.assert_called_once()
 
     
@mock.patch("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor",
 autospec=True)
     def test_load_custom_executor_with_classname(self, mock_executor):
@@ -353,15 +359,16 @@ class TestExecutorLoader:
                 ): 
"my_alias:airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"
             }
         ):
-            ExecutorLoader.init_executors()
-            assert isinstance(ExecutorLoader.load_executor("my_alias"), 
AwsEcsExecutor)
-            assert isinstance(ExecutorLoader.load_executor("AwsEcsExecutor"), 
AwsEcsExecutor)
+            executor_loader.ExecutorLoader.init_executors()
+            assert 
isinstance(executor_loader.ExecutorLoader.load_executor("my_alias"), 
AwsEcsExecutor)
+            assert 
isinstance(executor_loader.ExecutorLoader.load_executor("AwsEcsExecutor"), 
AwsEcsExecutor)
             assert isinstance(
-                ExecutorLoader.load_executor(
+                executor_loader.ExecutorLoader.load_executor(
                     
"airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"
                 ),
                 AwsEcsExecutor,
             )
             assert isinstance(
-                
ExecutorLoader.load_executor(executor_loader._executor_names[0]), AwsEcsExecutor
+                
executor_loader.ExecutorLoader.load_executor(executor_loader._executor_names[0]),
+                AwsEcsExecutor,
             )
diff --git a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py 
b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
index d982cf4b271..7e6f1b2253e 100644
--- a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
+++ b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
@@ -49,6 +49,7 @@ def not_expected_tr_db_call():
         yield m
 
 
[email protected]("clean_executor_loader")
 class TestNotInReschedulePeriodDep:
     @pytest.fixture(autouse=True)
     def setup_test_cases(self, request, create_task_instance):
diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py
index 454af48d667..fda432e01d1 100644
--- a/tests/utils/test_log_handlers.py
+++ b/tests/utils/test_log_handlers.py
@@ -33,7 +33,7 @@ from pydantic.v1.utils import deep_update
 from requests.adapters import Response
 
 from airflow.config_templates.airflow_local_settings import 
DEFAULT_LOGGING_CONFIG
-from airflow.executors import executor_loader
+from airflow.executors import executor_constants, executor_loader
 from airflow.jobs.job import Job
 from airflow.jobs.triggerer_job_runner import TriggererJobRunner
 from airflow.models.dagrun import DagRun
@@ -187,6 +187,95 @@ class TestFileTaskLogHandler:
         # Remove the generated tmp log file.
         os.remove(log_filename)
 
+    @pytest.mark.parametrize(
+        "executor_name",
+        [
+            (executor_constants.LOCAL_KUBERNETES_EXECUTOR),
+            (executor_constants.CELERY_KUBERNETES_EXECUTOR),
+            (executor_constants.KUBERNETES_EXECUTOR),
+            (None),
+        ],
+    )
+    @conf_vars(
+        {
+            ("core", "EXECUTOR"): ",".join(
+                [
+                    executor_constants.LOCAL_KUBERNETES_EXECUTOR,
+                    executor_constants.CELERY_KUBERNETES_EXECUTOR,
+                    executor_constants.KUBERNETES_EXECUTOR,
+                ]
+            ),
+        }
+    )
+    @patch(
+        "airflow.executors.executor_loader.ExecutorLoader.load_executor",
+        wraps=executor_loader.ExecutorLoader.load_executor,
+    )
+    @patch(
+        
"airflow.executors.executor_loader.ExecutorLoader.get_default_executor",
+        wraps=executor_loader.ExecutorLoader.get_default_executor,
+    )
+    def test_file_task_handler_with_multiple_executors(
+        self,
+        mock_get_default_executor,
+        mock_load_executor,
+        executor_name,
+        create_task_instance,
+        clean_executor_loader,
+    ):
+        executors_mapping = executor_loader.ExecutorLoader.executors
+        default_executor_name = 
executor_loader.ExecutorLoader.get_default_executor_name()
+        path_to_executor_class: str
+        if executor_name is None:
+            path_to_executor_class = 
executors_mapping.get(default_executor_name.alias)
+        else:
+            path_to_executor_class = executors_mapping.get(executor_name)
+
+        with patch(f"{path_to_executor_class}.get_task_log", return_value=([], 
[])) as mock_get_task_log:
+            mock_get_task_log.return_value = ([], [])
+            ti = create_task_instance(
+                dag_id="dag_for_testing_multiple_executors",
+                task_id="task_for_testing_multiple_executors",
+                run_type=DagRunType.SCHEDULED,
+                logical_date=DEFAULT_DATE,
+            )
+            if executor_name is not None:
+                ti.executor = executor_name
+            ti.try_number = 1
+            ti.state = TaskInstanceState.RUNNING
+            logger = ti.log
+            ti.log.disabled = False
+
+            file_handler = next(
+                (handler for handler in logger.handlers if handler.name == 
FILE_TASK_HANDLER), None
+            )
+            assert file_handler is not None
+
+            set_context(logger, ti)
+            # clear executor_instances cache
+            file_handler.executor_instances = {}
+            assert file_handler.handler is not None
+            # We expect set_context generates a file locally.
+            log_filename = file_handler.handler.baseFilename
+            assert os.path.isfile(log_filename)
+            assert log_filename.endswith("1.log"), log_filename
+
+            file_handler.flush()
+            file_handler.close()
+
+            assert hasattr(file_handler, "read")
+            file_handler.read(ti)
+            os.remove(log_filename)
+            mock_get_task_log.assert_called_once()
+
+            if executor_name is None:
+                mock_get_default_executor.assert_called_once()
+                # will be called in `ExecutorLoader.get_default_executor` 
method
+                
mock_load_executor.assert_called_once_with(default_executor_name)
+            else:
+                mock_get_default_executor.assert_not_called()
+                mock_load_executor.assert_called_once_with(executor_name)
+
     def test_file_task_handler_running(self, dag_maker):
         def task_callable(ti):
             ti.log.info("test")
diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py
index 1b68f039eaa..969d0b2a61c 100644
--- a/tests_common/pytest_plugin.py
+++ b/tests_common/pytest_plugin.py
@@ -1567,6 +1567,19 @@ def clean_dags_and_dagruns():
     clear_db_runs()
 
 
[email protected]
+def clean_executor_loader():
+    """Clean the executor_loader state, as it stores global variables in the 
module, causing side effects for some tests."""
+    from airflow.executors.executor_loader import ExecutorLoader
+
+    from tests_common.test_utils.executor_loader import 
clean_executor_loader_module
+
+    clean_executor_loader_module()
+    yield  # Test runs here
+    clean_executor_loader_module()
+    ExecutorLoader.init_executors()
+
+
 @pytest.fixture(scope="session")
 def app():
     from tests_common.test_utils.config import conf_vars
diff --git a/tests_common/test_utils/executor_loader.py 
b/tests_common/test_utils/executor_loader.py
new file mode 100644
index 00000000000..f7dd98b7264
--- /dev/null
+++ b/tests_common/test_utils/executor_loader.py
@@ -0,0 +1,34 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import airflow.executors.executor_loader as executor_loader
+
+if TYPE_CHECKING:
+    from airflow.executors.executor_utils import ExecutorName
+
+
+def clean_executor_loader_module():
+    """Clean the executor_loader state, as it stores global variables in the 
module, causing side effects for some tests."""
+    executor_loader._alias_to_executors: dict[str, ExecutorName] = {}
+    executor_loader._module_to_executors: dict[str, ExecutorName] = {}
+    executor_loader._team_id_to_executors: dict[str | None, ExecutorName] = {}
+    executor_loader._classname_to_executors: dict[str, ExecutorName] = {}
+    executor_loader._executor_names: list[ExecutorName] = []

Reply via email to