This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 840dd25891 Type related import optimization for Executors (#30361)
840dd25891 is described below
commit 840dd2589132040c91f197c23594ddfe0af5aaf9
Author: Niko Oliveira <[email protected]>
AuthorDate: Fri Apr 7 03:52:29 2023 -0700
Type related import optimization for Executors (#30361)
Move some expensive typing related imports to be under TYPE_CHECKING
---
airflow/executors/base_executor.py | 37 +++++++++++++------------
airflow/executors/celery_executor.py | 21 ++++++++------
airflow/executors/celery_kubernetes_executor.py | 10 +++++--
airflow/executors/dask_executor.py | 8 ++++--
airflow/executors/debug_executor.py | 9 ++++--
airflow/executors/kubernetes_executor.py | 24 +++++++++-------
airflow/executors/local_executor.py | 15 ++++++----
airflow/executors/local_kubernetes_executor.py | 10 +++++--
airflow/executors/sequential_executor.py | 9 ++++--
tests/executors/test_debug_executor.py | 2 +-
10 files changed, 88 insertions(+), 57 deletions(-)
diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index cf1947992d..8b72c919f0 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -23,39 +23,40 @@ import warnings
from collections import OrderedDict, defaultdict
from dataclasses import dataclass, field
from datetime import datetime
-from typing import Any, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple
import pendulum
-from airflow.callbacks.base_callback_sink import BaseCallbackSink
-from airflow.callbacks.callback_requests import CallbackRequest
from airflow.configuration import conf
from airflow.exceptions import RemovedInAirflow3Warning
-from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State
PARALLELISM: int = conf.getint("core", "PARALLELISM")
-# Command to execute - list of strings
-# the first element is always "airflow".
-# It should be result of TaskInstance.generate_command method.q
-CommandType = List[str]
+if TYPE_CHECKING:
+ from airflow.callbacks.base_callback_sink import BaseCallbackSink
+ from airflow.callbacks.callback_requests import CallbackRequest
+ from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+ # Command to execute - list of strings
+ # the first element is always "airflow".
+ # It should be result of TaskInstance.generate_command method.
+ CommandType = List[str]
-# Task that is queued. It contains all the information that is
-# needed to run the task.
-#
-# Tuple of: command, priority, queue name, TaskInstance
-QueuedTaskInstanceType = Tuple[CommandType, int, Optional[str], TaskInstance]
+ # Task that is queued. It contains all the information that is
+ # needed to run the task.
+ #
+ # Tuple of: command, priority, queue name, TaskInstance
+ QueuedTaskInstanceType = Tuple[CommandType, int, Optional[str], TaskInstance]
-# Event_buffer dict value type
-# Tuple of: state, info
-EventBufferValueType = Tuple[Optional[str], Any]
+ # Event_buffer dict value type
+ # Tuple of: state, info
+ EventBufferValueType = Tuple[Optional[str], Any]
-# Task tuple to send to be executed
-TaskTuple = Tuple[TaskInstanceKey, CommandType, Optional[str], Optional[Any]]
+ # Task tuple to send to be executed
+ TaskTuple = Tuple[TaskInstanceKey, CommandType, Optional[str], Optional[Any]]
log = logging.getLogger(__name__)
diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index ec7e9f008d..5be3f7fc60 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -35,7 +35,7 @@ from collections import Counter
from concurrent.futures import ProcessPoolExecutor
from enum import Enum
from multiprocessing import cpu_count
-from typing import Any, Mapping, MutableMapping, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Sequence, Tuple
from celery import Celery, Task, states as celery_states
from celery.backends.base import BaseKeyValueStoreBackend
@@ -49,8 +49,7 @@ import airflow.settings as settings
from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowTaskTimeout
-from airflow.executors.base_executor import BaseExecutor, CommandType, EventBufferValueType, TaskTuple
-from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+from airflow.executors.base_executor import BaseExecutor
from airflow.stats import Stats
from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -60,6 +59,15 @@ from airflow.utils.state import State
from airflow.utils.timeout import timeout
from airflow.utils.timezone import utcnow
+if TYPE_CHECKING:
+ from airflow.executors.base_executor import CommandType, EventBufferValueType, TaskTuple
+ from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+
+ # Task instance that is sent over Celery queues
+ # TaskInstanceKey, Command, queue_name, CallableTask
+ TaskInstanceInCelery = Tuple[TaskInstanceKey, CommandType, Optional[str], Task]
+
+
log = logging.getLogger(__name__)
# Make it constant for unit test.
@@ -164,11 +172,6 @@ class ExceptionWithTraceback:
self.traceback = exception_traceback
-# Task instance that is sent over Celery queues
-# TaskInstanceKey, Command, queue_name, CallableTask
-TaskInstanceInCelery = Tuple[TaskInstanceKey, CommandType, Optional[str], Task]
-
-
def send_task_to_executor(
task_tuple: TaskInstanceInCelery,
) -> tuple[TaskInstanceKey, CommandType, AsyncResult | ExceptionWithTraceback]:
@@ -392,6 +395,8 @@ class CeleryExecutor(BaseExecutor):
def _send_stalled_tis_back_to_scheduler(
self, keys: list[TaskInstanceKey], session: Session = NEW_SESSION
) -> None:
+ from airflow.models.taskinstance import TaskInstance
+
try:
session.query(TaskInstance).filter(
TaskInstance.filter_for_tis(keys),
diff --git a/airflow/executors/celery_kubernetes_executor.py b/airflow/executors/celery_kubernetes_executor.py
index 00a7f15830..2f6101d14e 100644
--- a/airflow/executors/celery_kubernetes_executor.py
+++ b/airflow/executors/celery_kubernetes_executor.py
@@ -17,17 +17,19 @@
# under the License.
from __future__ import annotations
-from typing import Sequence
+from typing import TYPE_CHECKING, Sequence
from airflow.callbacks.base_callback_sink import BaseCallbackSink
from airflow.callbacks.callback_requests import CallbackRequest
from airflow.configuration import conf
-from airflow.executors.base_executor import CommandType, EventBufferValueType, QueuedTaskInstanceType
from airflow.executors.celery_executor import CeleryExecutor
from airflow.executors.kubernetes_executor import KubernetesExecutor
-from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey
from airflow.utils.log.logging_mixin import LoggingMixin
+if TYPE_CHECKING:
+ from airflow.executors.base_executor import CommandType, EventBufferValueType, QueuedTaskInstanceType
+ from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey
+
class CeleryKubernetesExecutor(LoggingMixin):
"""
@@ -126,6 +128,8 @@ class CeleryKubernetesExecutor(LoggingMixin):
cfg_path: str | None = None,
) -> None:
"""Queues task instance via celery or kubernetes executor."""
+ from airflow.models.taskinstance import SimpleTaskInstance
+
executor = self._router(SimpleTaskInstance.from_ti(task_instance))
self.log.debug(
"Using executor: %s to queue_task_instance for %s", executor.__class__.__name__, task_instance.key
diff --git a/airflow/executors/dask_executor.py b/airflow/executors/dask_executor.py
index 0114933d0c..c152abe4fa 100644
--- a/airflow/executors/dask_executor.py
+++ b/airflow/executors/dask_executor.py
@@ -32,8 +32,12 @@ from distributed.security import Security
from airflow.configuration import conf
from airflow.exceptions import AirflowException
-from airflow.executors.base_executor import BaseExecutor, CommandType
-from airflow.models.taskinstance import TaskInstanceKey
+from airflow.executors.base_executor import BaseExecutor
+
+if TYPE_CHECKING:
+ from airflow.executors.base_executor import CommandType
+ from airflow.models.taskinstance import TaskInstanceKey
+
# queue="default" is a special case since this is the base config default queue name,
# with respect to DaskExecutor, treat it as if no queue is provided
diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py
index 60fd51282e..59b76d6937 100644
--- a/airflow/executors/debug_executor.py
+++ b/airflow/executors/debug_executor.py
@@ -26,13 +26,14 @@ from __future__ import annotations
import threading
import time
-from typing import Any
+from typing import TYPE_CHECKING, Any
-from airflow.configuration import conf
from airflow.executors.base_executor import BaseExecutor
-from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.utils.state import State
+if TYPE_CHECKING:
+ from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+
class DebugExecutor(BaseExecutor):
"""
@@ -54,6 +55,8 @@ class DebugExecutor(BaseExecutor):
self.tasks_to_run: list[TaskInstance] = []
# Place where we keep information for task instance raw run
self.tasks_params: dict[TaskInstanceKey, dict[str, Any]] = {}
+ from airflow.configuration import conf
+
self.fail_fast = conf.getboolean("debug", "fail_fast")
def execute_async(self, *args, **kwargs) -> None:
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index 3c56e7ceaf..bed56856c6 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -41,30 +41,33 @@ from urllib3.exceptions import ReadTimeoutError
from airflow.configuration import conf
from airflow.exceptions import AirflowException, PodMutationHookException, PodReconciliationError
-from airflow.executors.base_executor import BaseExecutor, CommandType
+from airflow.executors.base_executor import BaseExecutor
from airflow.kubernetes import pod_generator
from airflow.kubernetes.kube_client import get_kube_client
from airflow.kubernetes.kube_config import KubeConfig
from airflow.kubernetes.kubernetes_helper_functions import annotations_to_key, create_pod_id
from airflow.kubernetes.pod_generator import PodGenerator
-from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.utils import timezone
from airflow.utils.event_scheduler import EventScheduler
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import State, TaskInstanceState
-ALL_NAMESPACES = "ALL_NAMESPACES"
-POD_EXECUTOR_DONE_KEY = "airflow_executor_done"
+if TYPE_CHECKING:
+ from airflow.executors.base_executor import CommandType
+ from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+
+ # TaskInstance key, command, configuration, pod_template_file
+ KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any, Optional[str]]
-# TaskInstance key, command, configuration, pod_template_file
-KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any, Optional[str]]
+ # key, pod state, pod_id, namespace, resource_version
+ KubernetesResultsType = Tuple[TaskInstanceKey, Optional[str], str, str, str]
-# key, pod state, pod_id, namespace, resource_version
-KubernetesResultsType = Tuple[TaskInstanceKey, Optional[str], str, str, str]
+ # pod_id, namespace, pod state, annotations, resource_version
+ KubernetesWatchType = Tuple[str, str, Optional[str], Dict[str, str], str]
-# pod_id, namespace, pod state, annotations, resource_version
-KubernetesWatchType = Tuple[str, str, Optional[str], Dict[str, str], str]
+ALL_NAMESPACES = "ALL_NAMESPACES"
+POD_EXECUTOR_DONE_KEY = "airflow_executor_done"
class ResourceVersion:
@@ -512,6 +515,7 @@ class KubernetesExecutor(BaseExecutor):
"""
if TYPE_CHECKING:
assert self.kube_client
+ from airflow.models.taskinstance import TaskInstance
self.log.debug("Clearing tasks that have not been launched")
query = session.query(TaskInstance).filter(
diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py
index fc6eaf221b..6a9bb1a339 100644
--- a/airflow/executors/local_executor.py
+++ b/airflow/executors/local_executor.py
@@ -37,15 +37,18 @@ from setproctitle import getproctitle, setproctitle
from airflow import settings
from airflow.exceptions import AirflowException
-from airflow.executors.base_executor import PARALLELISM, BaseExecutor, CommandType
-from airflow.models.taskinstance import TaskInstanceKey, TaskInstanceStateType
+from airflow.executors.base_executor import PARALLELISM, BaseExecutor
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State
-# This is a work to be executed by a worker.
-# It can Key and Command - but it can also be None, None which is actually a
-# "Poison Pill" - worker seeing Poison Pill should take the pill and ... die instantly.
-ExecutorWorkType = Tuple[Optional[TaskInstanceKey], Optional[CommandType]]
+if TYPE_CHECKING:
+ from airflow.executors.base_executor import CommandType
+ from airflow.models.taskinstance import TaskInstanceKey, TaskInstanceStateType
+
+ # This is a work to be executed by a worker.
+ # It can Key and Command - but it can also be None, None which is actually a
+ # "Poison Pill" - worker seeing Poison Pill should take the pill and ... die instantly.
+ ExecutorWorkType = Tuple[Optional[TaskInstanceKey], Optional[CommandType]]
class LocalWorkerBase(Process, LoggingMixin):
diff --git a/airflow/executors/local_kubernetes_executor.py b/airflow/executors/local_kubernetes_executor.py
index 916d838391..85c61eca84 100644
--- a/airflow/executors/local_kubernetes_executor.py
+++ b/airflow/executors/local_kubernetes_executor.py
@@ -17,17 +17,19 @@
# under the License.
from __future__ import annotations
-from typing import Sequence
+from typing import TYPE_CHECKING, Sequence
from airflow.callbacks.base_callback_sink import BaseCallbackSink
from airflow.callbacks.callback_requests import CallbackRequest
from airflow.configuration import conf
-from airflow.executors.base_executor import CommandType, EventBufferValueType, QueuedTaskInstanceType
from airflow.executors.kubernetes_executor import KubernetesExecutor
from airflow.executors.local_executor import LocalExecutor
-from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey
from airflow.utils.log.logging_mixin import LoggingMixin
+if TYPE_CHECKING:
+ from airflow.executors.base_executor import CommandType, EventBufferValueType, QueuedTaskInstanceType
+ from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey
+
class LocalKubernetesExecutor(LoggingMixin):
"""
@@ -127,6 +129,8 @@ class LocalKubernetesExecutor(LoggingMixin):
cfg_path: str | None = None,
) -> None:
"""Queues task instance via local or kubernetes executor."""
+ from airflow.models.taskinstance import SimpleTaskInstance
+
executor = self._router(SimpleTaskInstance.from_ti(task_instance))
self.log.debug(
"Using executor: %s to queue_task_instance for %s", executor.__class__.__name__, task_instance.key
diff --git a/airflow/executors/sequential_executor.py b/airflow/executors/sequential_executor.py
index 0f75c3d930..b3da5af080 100644
--- a/airflow/executors/sequential_executor.py
+++ b/airflow/executors/sequential_executor.py
@@ -25,12 +25,15 @@ SequentialExecutor.
from __future__ import annotations
import subprocess
-from typing import Any
+from typing import TYPE_CHECKING, Any
-from airflow.executors.base_executor import BaseExecutor, CommandType
-from airflow.models.taskinstance import TaskInstanceKey
+from airflow.executors.base_executor import BaseExecutor
from airflow.utils.state import State
+if TYPE_CHECKING:
+ from airflow.executors.base_executor import CommandType
+ from airflow.models.taskinstance import TaskInstanceKey
+
class SequentialExecutor(BaseExecutor):
"""
diff --git a/tests/executors/test_debug_executor.py b/tests/executors/test_debug_executor.py
index cb939a1e6a..d4b40f637a 100644
--- a/tests/executors/test_debug_executor.py
+++ b/tests/executors/test_debug_executor.py
@@ -38,7 +38,7 @@ class TestDebugExecutor:
assert not executor.tasks_to_run
run_task_mock.assert_has_calls([mock.call(ti1), mock.call(ti2)])
- @mock.patch("airflow.executors.debug_executor.TaskInstance")
+ @mock.patch("airflow.models.taskinstance.TaskInstance")
def test_run_task(self, task_instance_mock):
ti_key = "key"
job_id = " job_id"