This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d0cea6d  Allow using default celery command group with executors subclassed from Celery-based executors. (#18189)
d0cea6d is described below

commit d0cea6d849ccf11e2b1e55d3280fcca59948eb53
Author: Georgy Borodin <[email protected]>
AuthorDate: Sat Dec 4 18:19:40 2021 +0300

    Allow using default celery command group with executors subclassed from Celery-based executors. (#18189)
---
 airflow/cli/cli_parser.py            | 18 ++++++++----
 airflow/executors/executor_loader.py | 55 +++++++++++++++++++++++-------------
 tests/cli/conftest.py                | 13 +++++++++
 tests/cli/test_cli_parser.py         | 12 ++++++--
 4 files changed, 72 insertions(+), 26 deletions(-)

diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py
index 62943ee..85e5b05 100644
--- a/airflow/cli/cli_parser.py
+++ b/airflow/cli/cli_parser.py
@@ -30,7 +30,9 @@ from airflow import PY37, settings
 from airflow.cli.commands.legacy_commands import check_legacy_command
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
+from airflow.executors import celery_executor, celery_kubernetes_executor
 from airflow.executors.executor_constants import CELERY_EXECUTOR, 
CELERY_KUBERNETES_EXECUTOR
+from airflow.executors.executor_loader import ExecutorLoader
 from airflow.utils.cli import ColorMode
 from airflow.utils.helpers import partition
 from airflow.utils.module_loading import import_string
@@ -60,10 +62,17 @@ class DefaultHelpParser(argparse.ArgumentParser):
         if action.dest == 'subcommand' and value == 'celery':
             executor = conf.get('core', 'EXECUTOR')
             if executor not in (CELERY_EXECUTOR, CELERY_KUBERNETES_EXECUTOR):
-                message = (
-                    f'celery subcommand works only with CeleryExecutor, your 
current executor: {executor}'
-                )
-                raise ArgumentError(action, message)
+                executor_cls, _ = ExecutorLoader.import_executor_cls(executor)
+                if not issubclass(
+                    executor_cls,
+                    (celery_executor.CeleryExecutor, 
celery_kubernetes_executor.CeleryKubernetesExecutor),
+                ):
+                    message = (
+                        f'celery subcommand works only with CeleryExecutor, 
CeleryKubernetesExecutor and '
+                        f'executors derived from them, your current executor: 
{executor}, subclassed from: '
+                        f'{", ".join([base_cls.__qualname__ for base_cls in 
executor_cls.__bases__])}'
+                    )
+                    raise ArgumentError(action, message)
         if action.dest == 'subcommand' and value == 'kubernetes':
             try:
                 import kubernetes.client  # noqa: F401
@@ -810,7 +819,6 @@ class GroupCommand(NamedTuple):
 
 CLICommand = Union[ActionCommand, GroupCommand]
 
-
 DAGS_COMMANDS = (
     ActionCommand(
         name='list',
diff --git a/airflow/executors/executor_loader.py b/airflow/executors/executor_loader.py
index 676bf35..01b5f9f 100644
--- a/airflow/executors/executor_loader.py
+++ b/airflow/executors/executor_loader.py
@@ -17,7 +17,8 @@
 """All executors."""
 import logging
 from contextlib import suppress
-from typing import Optional
+from enum import Enum, unique
+from typing import Optional, Tuple, Type
 
 from airflow.exceptions import AirflowConfigException
 from airflow.executors.base_executor import BaseExecutor
@@ -35,6 +36,15 @@ from airflow.utils.module_loading import import_string
 log = logging.getLogger(__name__)
 
 
+@unique
+class ConnectorSource(Enum):
+    """Enum of supported executor import sources."""
+
+    CORE = "core"
+    PLUGIN = "plugin"
+    CUSTOM_PATH = "custom path"
+
+
 class ExecutorLoader:
     """Keeps constants for all the currently available executors."""
 
@@ -77,14 +87,33 @@ class ExecutorLoader:
         """
         if executor_name == CELERY_KUBERNETES_EXECUTOR:
             return cls.__load_celery_kubernetes_executor()
+        try:
+            executor_cls, import_source = 
cls.import_executor_cls(executor_name)
+            log.debug("Loading executor %s from %s", executor_name, 
import_source.value)
+        except ImportError as e:
+            log.error(e)
+            raise AirflowConfigException(
+                f'The module/attribute could not be loaded. Please check 
"executor" key in "core" section. '
+                f'Current value: "{executor_name}".'
+            )
+        log.info("Loaded executor: %s", executor_name)
 
+        return executor_cls()
+
+    @classmethod
+    def import_executor_cls(cls, executor_name: str) -> 
Tuple[Type[BaseExecutor], ConnectorSource]:
+        """
+        Imports the executor class.
+
+        Supports the same formats as ExecutorLoader.load_executor.
+
+        :return: executor class via executor_name and executor import source
+        """
         if executor_name in cls.executors:
-            log.debug("Loading core executor: %s", executor_name)
-            return import_string(cls.executors[executor_name])()
-        # If the executor name looks like "plugin executor path" then try to 
load plugins.
+            return import_string(cls.executors[executor_name]), 
ConnectorSource.CORE
         if executor_name.count(".") == 1:
             log.debug(
-                "The executor name looks like the plugin path 
(executor_name=%s). Trying to load a "
+                "The executor name looks like the plugin path 
(executor_name=%s). Trying to import a "
                 "executor from a plugin",
                 executor_name,
             )
@@ -94,20 +123,8 @@ class ExecutorLoader:
                 from airflow import plugins_manager
 
                 plugins_manager.integrate_executor_plugins()
-                return import_string(f"airflow.executors.{executor_name}")()
-
-        log.debug("Loading executor from custom path: %s", executor_name)
-        try:
-            executor = import_string(executor_name)()
-        except ImportError as e:
-            log.error(e)
-            raise AirflowConfigException(
-                f'The module/attribute could not be loaded. Please check 
"executor" key in "core" section. '
-                f'Current value: "{executor_name}".'
-            )
-        log.info("Loaded executor: %s", executor_name)
-
-        return executor
+                return import_string(f"airflow.executors.{executor_name}"), 
ConnectorSource.PLUGIN
+        return import_string(executor_name), ConnectorSource.CUSTOM_PATH
 
     @classmethod
     def __load_celery_kubernetes_executor(cls) -> BaseExecutor:
diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py
index fe4d857..6d74471 100644
--- a/tests/cli/conftest.py
+++ b/tests/cli/conftest.py
@@ -16,10 +16,23 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+import sys
+
 import pytest
 
 from airflow import models
 from airflow.cli import cli_parser
+from airflow.executors import celery_executor, celery_kubernetes_executor
+
+# Create custom executors here because conftest is imported first
+custom_executor_module = type(sys)('custom_executor')
+custom_executor_module.CustomCeleryExecutor = type(
+    'CustomCeleryExecutor', (celery_executor.CeleryExecutor,), {}
+)
+custom_executor_module.CustomCeleryKubernetesExecutor = type(
+    'CustomCeleryKubernetesExecutor', 
(celery_kubernetes_executor.CeleryKubernetesExecutor,), {}
+)
+sys.modules['custom_executor'] = custom_executor_module
 
 
 @pytest.fixture(scope="session")
diff --git a/tests/cli/test_cli_parser.py b/tests/cli/test_cli_parser.py
index 55ec492..af5788c 100644
--- a/tests/cli/test_cli_parser.py
+++ b/tests/cli/test_cli_parser.py
@@ -206,10 +206,18 @@ class TestCli(TestCase):
             stderr = stderr.getvalue()
         assert (
             "airflow command error: argument GROUP_OR_COMMAND: celery 
subcommand "
-            "works only with CeleryExecutor, your current executor: 
SequentialExecutor, see help above."
+            "works only with CeleryExecutor, CeleryKubernetesExecutor and 
executors derived from them, "
+            "your current executor: SequentialExecutor, subclassed from: 
BaseExecutor, see help above."
         ) in stderr
 
-    @parameterized.expand(["CeleryExecutor", "CeleryKubernetesExecutor"])
+    @parameterized.expand(
+        [
+            "CeleryExecutor",
+            "CeleryKubernetesExecutor",
+            "custom_executor.CustomCeleryExecutor",
+            "custom_executor.CustomCeleryKubernetesExecutor",
+        ]
+    )
     def test_dag_parser_celery_command_accept_celery_executor(self, executor):
         with conf_vars({('core', 'executor'): executor}), 
contextlib.redirect_stderr(io.StringIO()) as stderr:
             parser = cli_parser.get_parser()

Reply via email to