This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 840dd25891 Type related import optimization for Executors (#30361)
840dd25891 is described below

commit 840dd2589132040c91f197c23594ddfe0af5aaf9
Author: Niko Oliveira <[email protected]>
AuthorDate: Fri Apr 7 03:52:29 2023 -0700

    Type related import optimization for Executors (#30361)
    
    Move some expensive typing related imports to be under TYPE_CHECKING
---
 airflow/executors/base_executor.py              | 37 +++++++++++++------------
 airflow/executors/celery_executor.py            | 21 ++++++++------
 airflow/executors/celery_kubernetes_executor.py | 10 +++++--
 airflow/executors/dask_executor.py              |  8 ++++--
 airflow/executors/debug_executor.py             |  9 ++++--
 airflow/executors/kubernetes_executor.py        | 24 +++++++++-------
 airflow/executors/local_executor.py             | 15 ++++++----
 airflow/executors/local_kubernetes_executor.py  | 10 +++++--
 airflow/executors/sequential_executor.py        |  9 ++++--
 tests/executors/test_debug_executor.py          |  2 +-
 10 files changed, 88 insertions(+), 57 deletions(-)

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index cf1947992d..8b72c919f0 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -23,39 +23,40 @@ import warnings
 from collections import OrderedDict, defaultdict
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple
 
 import pendulum
 
-from airflow.callbacks.base_callback_sink import BaseCallbackSink
-from airflow.callbacks.callback_requests import CallbackRequest
 from airflow.configuration import conf
 from airflow.exceptions import RemovedInAirflow3Warning
-from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
 from airflow.stats import Stats
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.state import State
 
 PARALLELISM: int = conf.getint("core", "PARALLELISM")
 
-# Command to execute - list of strings
-# the first element is always "airflow".
-# It should be result of TaskInstance.generate_command method.q
-CommandType = List[str]
+if TYPE_CHECKING:
+    from airflow.callbacks.base_callback_sink import BaseCallbackSink
+    from airflow.callbacks.callback_requests import CallbackRequest
+    from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
 
+    # Command to execute - list of strings
+    # the first element is always "airflow".
+    # It should be result of TaskInstance.generate_command method.
+    CommandType = List[str]
 
-# Task that is queued. It contains all the information that is
-# needed to run the task.
-#
-# Tuple of: command, priority, queue name, TaskInstance
-QueuedTaskInstanceType = Tuple[CommandType, int, Optional[str], TaskInstance]
+    # Task that is queued. It contains all the information that is
+    # needed to run the task.
+    #
+    # Tuple of: command, priority, queue name, TaskInstance
+    QueuedTaskInstanceType = Tuple[CommandType, int, Optional[str], TaskInstance]
 
-# Event_buffer dict value type
-# Tuple of: state, info
-EventBufferValueType = Tuple[Optional[str], Any]
+    # Event_buffer dict value type
+    # Tuple of: state, info
+    EventBufferValueType = Tuple[Optional[str], Any]
 
-# Task tuple to send to be executed
-TaskTuple = Tuple[TaskInstanceKey, CommandType, Optional[str], Optional[Any]]
+    # Task tuple to send to be executed
+    TaskTuple = Tuple[TaskInstanceKey, CommandType, Optional[str], Optional[Any]]
 
 log = logging.getLogger(__name__)
 
diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index ec7e9f008d..5be3f7fc60 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -35,7 +35,7 @@ from collections import Counter
 from concurrent.futures import ProcessPoolExecutor
 from enum import Enum
 from multiprocessing import cpu_count
-from typing import Any, Mapping, MutableMapping, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Sequence, Tuple
 
 from celery import Celery, Task, states as celery_states
 from celery.backends.base import BaseKeyValueStoreBackend
@@ -49,8 +49,7 @@ import airflow.settings as settings
 from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowTaskTimeout
-from airflow.executors.base_executor import BaseExecutor, CommandType, EventBufferValueType, TaskTuple
-from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+from airflow.executors.base_executor import BaseExecutor
 from airflow.stats import Stats
 from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
 from airflow.utils.log.logging_mixin import LoggingMixin
@@ -60,6 +59,15 @@ from airflow.utils.state import State
 from airflow.utils.timeout import timeout
 from airflow.utils.timezone import utcnow
 
+if TYPE_CHECKING:
+    from airflow.executors.base_executor import CommandType, EventBufferValueType, TaskTuple
+    from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+
+    # Task instance that is sent over Celery queues
+    # TaskInstanceKey, Command, queue_name, CallableTask
+    TaskInstanceInCelery = Tuple[TaskInstanceKey, CommandType, Optional[str], Task]
+
+
 log = logging.getLogger(__name__)
 
 # Make it constant for unit test.
@@ -164,11 +172,6 @@ class ExceptionWithTraceback:
         self.traceback = exception_traceback
 
 
-# Task instance that is sent over Celery queues
-# TaskInstanceKey, Command, queue_name, CallableTask
-TaskInstanceInCelery = Tuple[TaskInstanceKey, CommandType, Optional[str], Task]
-
-
 def send_task_to_executor(
     task_tuple: TaskInstanceInCelery,
 ) -> tuple[TaskInstanceKey, CommandType, AsyncResult | ExceptionWithTraceback]:
@@ -392,6 +395,8 @@ class CeleryExecutor(BaseExecutor):
     def _send_stalled_tis_back_to_scheduler(
         self, keys: list[TaskInstanceKey], session: Session = NEW_SESSION
     ) -> None:
+        from airflow.models.taskinstance import TaskInstance
+
         try:
             session.query(TaskInstance).filter(
                 TaskInstance.filter_for_tis(keys),
diff --git a/airflow/executors/celery_kubernetes_executor.py b/airflow/executors/celery_kubernetes_executor.py
index 00a7f15830..2f6101d14e 100644
--- a/airflow/executors/celery_kubernetes_executor.py
+++ b/airflow/executors/celery_kubernetes_executor.py
@@ -17,17 +17,19 @@
 # under the License.
 from __future__ import annotations
 
-from typing import Sequence
+from typing import TYPE_CHECKING, Sequence
 
 from airflow.callbacks.base_callback_sink import BaseCallbackSink
 from airflow.callbacks.callback_requests import CallbackRequest
 from airflow.configuration import conf
-from airflow.executors.base_executor import CommandType, EventBufferValueType, QueuedTaskInstanceType
 from airflow.executors.celery_executor import CeleryExecutor
 from airflow.executors.kubernetes_executor import KubernetesExecutor
-from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey
 from airflow.utils.log.logging_mixin import LoggingMixin
 
+if TYPE_CHECKING:
+    from airflow.executors.base_executor import CommandType, EventBufferValueType, QueuedTaskInstanceType
+    from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey
+
 
 class CeleryKubernetesExecutor(LoggingMixin):
     """
@@ -126,6 +128,8 @@ class CeleryKubernetesExecutor(LoggingMixin):
         cfg_path: str | None = None,
     ) -> None:
         """Queues task instance via celery or kubernetes executor."""
+        from airflow.models.taskinstance import SimpleTaskInstance
+
         executor = self._router(SimpleTaskInstance.from_ti(task_instance))
         self.log.debug(
             "Using executor: %s to queue_task_instance for %s", executor.__class__.__name__, task_instance.key
diff --git a/airflow/executors/dask_executor.py b/airflow/executors/dask_executor.py
index 0114933d0c..c152abe4fa 100644
--- a/airflow/executors/dask_executor.py
+++ b/airflow/executors/dask_executor.py
@@ -32,8 +32,12 @@ from distributed.security import Security
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
-from airflow.executors.base_executor import BaseExecutor, CommandType
-from airflow.models.taskinstance import TaskInstanceKey
+from airflow.executors.base_executor import BaseExecutor
+
+if TYPE_CHECKING:
+    from airflow.executors.base_executor import CommandType
+    from airflow.models.taskinstance import TaskInstanceKey
+
 
 # queue="default" is a special case since this is the base config default queue name,
 # with respect to DaskExecutor, treat it as if no queue is provided
diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py
index 60fd51282e..59b76d6937 100644
--- a/airflow/executors/debug_executor.py
+++ b/airflow/executors/debug_executor.py
@@ -26,13 +26,14 @@ from __future__ import annotations
 
 import threading
 import time
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
-from airflow.configuration import conf
 from airflow.executors.base_executor import BaseExecutor
-from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
 from airflow.utils.state import State
 
+if TYPE_CHECKING:
+    from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+
 
 class DebugExecutor(BaseExecutor):
     """
@@ -54,6 +55,8 @@ class DebugExecutor(BaseExecutor):
         self.tasks_to_run: list[TaskInstance] = []
         # Place where we keep information for task instance raw run
         self.tasks_params: dict[TaskInstanceKey, dict[str, Any]] = {}
+        from airflow.configuration import conf
+
         self.fail_fast = conf.getboolean("debug", "fail_fast")
 
     def execute_async(self, *args, **kwargs) -> None:
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index 3c56e7ceaf..bed56856c6 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -41,30 +41,33 @@ from urllib3.exceptions import ReadTimeoutError
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, PodMutationHookException, PodReconciliationError
-from airflow.executors.base_executor import BaseExecutor, CommandType
+from airflow.executors.base_executor import BaseExecutor
 from airflow.kubernetes import pod_generator
 from airflow.kubernetes.kube_client import get_kube_client
 from airflow.kubernetes.kube_config import KubeConfig
 from airflow.kubernetes.kubernetes_helper_functions import annotations_to_key, create_pod_id
 from airflow.kubernetes.pod_generator import PodGenerator
-from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
 from airflow.utils import timezone
 from airflow.utils.event_scheduler import EventScheduler
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.state import State, TaskInstanceState
 
-ALL_NAMESPACES = "ALL_NAMESPACES"
-POD_EXECUTOR_DONE_KEY = "airflow_executor_done"
+if TYPE_CHECKING:
+    from airflow.executors.base_executor import CommandType
+    from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+
+    # TaskInstance key, command, configuration, pod_template_file
+    KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any, Optional[str]]
 
-# TaskInstance key, command, configuration, pod_template_file
-KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any, Optional[str]]
+    # key, pod state, pod_id, namespace, resource_version
+    KubernetesResultsType = Tuple[TaskInstanceKey, Optional[str], str, str, str]
 
-# key, pod state, pod_id, namespace, resource_version
-KubernetesResultsType = Tuple[TaskInstanceKey, Optional[str], str, str, str]
+    # pod_id, namespace, pod state, annotations, resource_version
+    KubernetesWatchType = Tuple[str, str, Optional[str], Dict[str, str], str]
 
-# pod_id, namespace, pod state, annotations, resource_version
-KubernetesWatchType = Tuple[str, str, Optional[str], Dict[str, str], str]
+ALL_NAMESPACES = "ALL_NAMESPACES"
+POD_EXECUTOR_DONE_KEY = "airflow_executor_done"
 
 
 class ResourceVersion:
@@ -512,6 +515,7 @@ class KubernetesExecutor(BaseExecutor):
         """
         if TYPE_CHECKING:
             assert self.kube_client
+        from airflow.models.taskinstance import TaskInstance
 
         self.log.debug("Clearing tasks that have not been launched")
         query = session.query(TaskInstance).filter(
diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py
index fc6eaf221b..6a9bb1a339 100644
--- a/airflow/executors/local_executor.py
+++ b/airflow/executors/local_executor.py
@@ -37,15 +37,18 @@ from setproctitle import getproctitle, setproctitle
 
 from airflow import settings
 from airflow.exceptions import AirflowException
-from airflow.executors.base_executor import PARALLELISM, BaseExecutor, CommandType
-from airflow.models.taskinstance import TaskInstanceKey, TaskInstanceStateType
+from airflow.executors.base_executor import PARALLELISM, BaseExecutor
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.state import State
 
-# This is a work to be executed by a worker.
-# It can Key and Command - but it can also be None, None which is actually a
-# "Poison Pill" - worker seeing Poison Pill should take the pill and ... die instantly.
-ExecutorWorkType = Tuple[Optional[TaskInstanceKey], Optional[CommandType]]
+if TYPE_CHECKING:
+    from airflow.executors.base_executor import CommandType
+    from airflow.models.taskinstance import TaskInstanceKey, TaskInstanceStateType
+
+    # This is a work to be executed by a worker.
+    # It can Key and Command - but it can also be None, None which is actually a
+    # "Poison Pill" - worker seeing Poison Pill should take the pill and ... die instantly.
+    ExecutorWorkType = Tuple[Optional[TaskInstanceKey], Optional[CommandType]]
 
 
 class LocalWorkerBase(Process, LoggingMixin):
diff --git a/airflow/executors/local_kubernetes_executor.py b/airflow/executors/local_kubernetes_executor.py
index 916d838391..85c61eca84 100644
--- a/airflow/executors/local_kubernetes_executor.py
+++ b/airflow/executors/local_kubernetes_executor.py
@@ -17,17 +17,19 @@
 # under the License.
 from __future__ import annotations
 
-from typing import Sequence
+from typing import TYPE_CHECKING, Sequence
 
 from airflow.callbacks.base_callback_sink import BaseCallbackSink
 from airflow.callbacks.callback_requests import CallbackRequest
 from airflow.configuration import conf
-from airflow.executors.base_executor import CommandType, EventBufferValueType, QueuedTaskInstanceType
 from airflow.executors.kubernetes_executor import KubernetesExecutor
 from airflow.executors.local_executor import LocalExecutor
-from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey
 from airflow.utils.log.logging_mixin import LoggingMixin
 
+if TYPE_CHECKING:
+    from airflow.executors.base_executor import CommandType, EventBufferValueType, QueuedTaskInstanceType
+    from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey
+
 
 class LocalKubernetesExecutor(LoggingMixin):
     """
@@ -127,6 +129,8 @@ class LocalKubernetesExecutor(LoggingMixin):
         cfg_path: str | None = None,
     ) -> None:
         """Queues task instance via local or kubernetes executor."""
+        from airflow.models.taskinstance import SimpleTaskInstance
+
         executor = self._router(SimpleTaskInstance.from_ti(task_instance))
         self.log.debug(
             "Using executor: %s to queue_task_instance for %s", executor.__class__.__name__, task_instance.key
diff --git a/airflow/executors/sequential_executor.py b/airflow/executors/sequential_executor.py
index 0f75c3d930..b3da5af080 100644
--- a/airflow/executors/sequential_executor.py
+++ b/airflow/executors/sequential_executor.py
@@ -25,12 +25,15 @@ SequentialExecutor.
 from __future__ import annotations
 
 import subprocess
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
-from airflow.executors.base_executor import BaseExecutor, CommandType
-from airflow.models.taskinstance import TaskInstanceKey
+from airflow.executors.base_executor import BaseExecutor
 from airflow.utils.state import State
 
+if TYPE_CHECKING:
+    from airflow.executors.base_executor import CommandType
+    from airflow.models.taskinstance import TaskInstanceKey
+
 
 class SequentialExecutor(BaseExecutor):
     """
diff --git a/tests/executors/test_debug_executor.py b/tests/executors/test_debug_executor.py
index cb939a1e6a..d4b40f637a 100644
--- a/tests/executors/test_debug_executor.py
+++ b/tests/executors/test_debug_executor.py
@@ -38,7 +38,7 @@ class TestDebugExecutor:
         assert not executor.tasks_to_run
         run_task_mock.assert_has_calls([mock.call(ti1), mock.call(ti2)])
 
-    @mock.patch("airflow.executors.debug_executor.TaskInstance")
+    @mock.patch("airflow.models.taskinstance.TaskInstance")
     def test_run_task(self, task_instance_mock):
         ti_key = "key"
         job_id = " job_id"

Reply via email to