This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new abe3b318b5 Use asserts instead of exceptions for executor not started (#28019)
abe3b318b5 is described below

commit abe3b318b525cca703cd6c0cda25af87cdf19b1b
Author: Daniel Standish <[email protected]>
AuthorDate: Fri Dec 2 15:27:46 2022 -0800

    Use asserts instead of exceptions for executor not started (#28019)
---
 airflow/executors/base_executor.py       |  2 -
 airflow/executors/dask_executor.py       | 32 +++++++------
 airflow/executors/kubernetes_executor.py | 81 +++++++++++++++++---------------
 airflow/executors/local_executor.py      | 39 ++++++++-------
 4 files changed, 80 insertions(+), 74 deletions(-)

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index f9fac2fd3d..0c9af11864 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -33,8 +33,6 @@ from airflow.utils.state import State
 
 PARALLELISM: int = conf.getint("core", "PARALLELISM")
 
-NOT_STARTED_MESSAGE = "The executor should be started first!"
-
 QUEUEING_ATTEMPTS = 5
 
 # Command to execute - list of strings
diff --git a/airflow/executors/dask_executor.py b/airflow/executors/dask_executor.py
index 41a560ddc8..a2c2c57163 100644
--- a/airflow/executors/dask_executor.py
+++ b/airflow/executors/dask_executor.py
@@ -25,14 +25,14 @@ DaskExecutor
 from __future__ import annotations
 
 import subprocess
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 from distributed import Client, Future, as_completed
 from distributed.security import Security
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
-from airflow.executors.base_executor import NOT_STARTED_MESSAGE, BaseExecutor, CommandType
+from airflow.executors.base_executor import BaseExecutor, CommandType
 from airflow.models.taskinstance import TaskInstanceKey
 
 # queue="default" is a special case since this is the base config default queue name,
@@ -78,15 +78,14 @@ class DaskExecutor(BaseExecutor):
         queue: str | None = None,
         executor_config: Any | None = None,
     ) -> None:
+        if TYPE_CHECKING:
+            assert self.client
 
         self.validate_airflow_tasks_run_command(command)
 
         def airflow_run():
             return subprocess.check_call(command, close_fds=True)
 
-        if not self.client:
-            raise AirflowException(NOT_STARTED_MESSAGE)
-
         resources = None
         if queue not in _UNDEFINED_QUEUES:
             scheduler_info = self.client.scheduler_info()
@@ -102,8 +101,9 @@ class DaskExecutor(BaseExecutor):
         self.futures[future] = key  # type: ignore
 
     def _process_future(self, future: Future) -> None:
-        if not self.futures:
-            raise AirflowException(NOT_STARTED_MESSAGE)
+        if TYPE_CHECKING:
+            assert self.futures
+
         if future.done():
             key = self.futures[future]
             if future.exception():
@@ -117,23 +117,25 @@ class DaskExecutor(BaseExecutor):
             self.futures.pop(future)
 
     def sync(self) -> None:
-        if self.futures is None:
-            raise AirflowException(NOT_STARTED_MESSAGE)
+        if TYPE_CHECKING:
+            assert self.futures
+
         # make a copy so futures can be popped during iteration
         for future in self.futures.copy():
             self._process_future(future)
 
     def end(self) -> None:
-        if not self.client:
-            raise AirflowException(NOT_STARTED_MESSAGE)
-        if self.futures is None:
-            raise AirflowException(NOT_STARTED_MESSAGE)
+        if TYPE_CHECKING:
+            assert self.client
+            assert self.futures
+
         self.client.cancel(list(self.futures.keys()))
         for future in as_completed(self.futures.copy()):
             self._process_future(future)
 
     def terminate(self):
-        if self.futures is None:
-            raise AirflowException(NOT_STARTED_MESSAGE)
+        if TYPE_CHECKING:
+            assert self.futures
+
         self.client.cancel(self.futures.keys())
         self.end()
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index 3a3308d6c6..aeb02dfebc 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -30,7 +30,7 @@ import multiprocessing
 import time
 from datetime import timedelta
 from queue import Empty, Queue
-from typing import Any, Dict, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple
 
 from kubernetes import client, watch
 from kubernetes.client import Configuration, models as k8s
@@ -38,7 +38,7 @@ from kubernetes.client.rest import ApiException
 from urllib3.exceptions import ReadTimeoutError
 
 from airflow.exceptions import AirflowException, PodMutationHookException, PodReconciliationError
-from airflow.executors.base_executor import NOT_STARTED_MESSAGE, BaseExecutor, CommandType
+from airflow.executors.base_executor import BaseExecutor, CommandType
 from airflow.kubernetes import pod_generator
 from airflow.kubernetes.kube_client import get_kube_client
 from airflow.kubernetes.kube_config import KubeConfig
@@ -96,9 +96,10 @@ class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin):
 
     def run(self) -> None:
         """Performs watching"""
+        if TYPE_CHECKING:
+            assert self.scheduler_job_id
+
         kube_client: client.CoreV1Api = get_kube_client()
-        if not self.scheduler_job_id:
-            raise AirflowException(NOT_STARTED_MESSAGE)
         while True:
             try:
                 self.resource_version = self._run(
@@ -456,10 +457,10 @@ class KubernetesExecutor(BaseExecutor):
         is around, and if not, and there's no matching entry in our own
         task_queue, marks it for re-execution.
         """
-        self.log.debug("Clearing tasks that have not been launched")
-        if not self.kube_client:
-            raise AirflowException(NOT_STARTED_MESSAGE)
+        if TYPE_CHECKING:
+            assert self.kube_client
 
+        self.log.debug("Clearing tasks that have not been launched")
         query = session.query(TaskInstance).filter(
             TaskInstance.state == State.QUEUED, TaskInstance.queued_by_job_id 
== self.job_id
         )
@@ -551,6 +552,9 @@ class KubernetesExecutor(BaseExecutor):
         executor_config: Any | None = None,
     ) -> None:
         """Executes task asynchronously"""
+        if TYPE_CHECKING:
+            assert self.task_queue
+
         if self.log.isEnabledFor(logging.DEBUG):
             self.log.debug("Add task %s with command %s, executor_config %s", 
key, command, executor_config)
         else:
@@ -567,8 +571,6 @@ class KubernetesExecutor(BaseExecutor):
             pod_template_file = executor_config.get("pod_template_file", None)
         else:
             pod_template_file = None
-        if not self.task_queue:
-            raise AirflowException(NOT_STARTED_MESSAGE)
         self.event_buffer[key] = (State.QUEUED, self.scheduler_job_id)
         self.task_queue.put((key, command, kube_executor_config, 
pod_template_file))
         # We keep a temporary local record that we've handled this so we don't
@@ -577,22 +579,18 @@ class KubernetesExecutor(BaseExecutor):
 
     def sync(self) -> None:
         """Synchronize task state."""
+        if TYPE_CHECKING:
+            assert self.scheduler_job_id
+            assert self.kube_scheduler
+            assert self.kube_config
+            assert self.result_queue
+            assert self.task_queue
+            assert self.event_scheduler
+
         if self.running:
             self.log.debug("self.running: %s", self.running)
         if self.queued_tasks:
             self.log.debug("self.queued: %s", self.queued_tasks)
-        if not self.scheduler_job_id:
-            raise AirflowException(NOT_STARTED_MESSAGE)
-        if not self.kube_scheduler:
-            raise AirflowException(NOT_STARTED_MESSAGE)
-        if not self.kube_config:
-            raise AirflowException(NOT_STARTED_MESSAGE)
-        if not self.result_queue:
-            raise AirflowException(NOT_STARTED_MESSAGE)
-        if not self.task_queue:
-            raise AirflowException(NOT_STARTED_MESSAGE)
-        if not self.event_scheduler:
-            raise AirflowException(NOT_STARTED_MESSAGE)
         self.kube_scheduler.sync()
 
         last_resource_version = None
@@ -667,8 +665,9 @@ class KubernetesExecutor(BaseExecutor):
 
     def _check_worker_pods_pending_timeout(self):
         """Check if any pending worker pods have timed out"""
-        if not self.scheduler_job_id:
-            raise AirflowException(NOT_STARTED_MESSAGE)
+        if TYPE_CHECKING:
+            assert self.scheduler_job_id
+
         timeout = self.kube_config.worker_pods_pending_timeout
         self.log.debug("Looking for pending worker pods older than %d 
seconds", timeout)
 
@@ -702,10 +701,11 @@ class KubernetesExecutor(BaseExecutor):
                 self.kube_scheduler.delete_pod(pod.metadata.name, 
pod.metadata.namespace)
 
     def _change_state(self, key: TaskInstanceKey, state: str | None, pod_id: 
str, namespace: str) -> None:
+        if TYPE_CHECKING:
+            assert self.kube_scheduler
+
         if state != State.RUNNING:
             if self.kube_config.delete_worker_pods:
-                if not self.kube_scheduler:
-                    raise AirflowException(NOT_STARTED_MESSAGE)
                 if state != State.FAILED or 
self.kube_config.delete_worker_pods_on_failure:
                     self.kube_scheduler.delete_pod(pod_id, namespace)
                     self.log.info("Deleted pod: %s in namespace %s", str(key), 
str(namespace))
@@ -740,8 +740,9 @@ class KubernetesExecutor(BaseExecutor):
         :param pod: V1Pod spec that we will patch with new label
         :param pod_ids: pod_ids we expect to patch.
         """
-        if not self.scheduler_job_id:
-            raise AirflowException(NOT_STARTED_MESSAGE)
+        if TYPE_CHECKING:
+            assert self.scheduler_job_id
+
         self.log.info("attempting to adopt pod %s", pod.metadata.name)
         pod.metadata.labels["airflow-worker"] = 
pod_generator.make_safe_label_value(self.scheduler_job_id)
         pod_id = annotations_to_key(pod.metadata.annotations)
@@ -767,8 +768,9 @@ class KubernetesExecutor(BaseExecutor):
 
         :param kube_client: kubernetes client for speaking to kube API
         """
-        if not self.scheduler_job_id:
-            raise AirflowException(NOT_STARTED_MESSAGE)
+        if TYPE_CHECKING:
+            assert self.scheduler_job_id
+
         new_worker_id_label = 
pod_generator.make_safe_label_value(self.scheduler_job_id)
         kwargs = {
             "field_selector": "status.phase=Succeeded",
@@ -788,8 +790,9 @@ class KubernetesExecutor(BaseExecutor):
                 self.log.info("Failed to adopt pod %s. Reason: %s", 
pod.metadata.name, e)
 
     def _flush_task_queue(self) -> None:
-        if not self.task_queue:
-            raise AirflowException(NOT_STARTED_MESSAGE)
+        if TYPE_CHECKING:
+            assert self.task_queue
+
         self.log.debug("Executor shutting down, task_queue approximate 
size=%d", self.task_queue.qsize())
         while True:
             try:
@@ -801,8 +804,9 @@ class KubernetesExecutor(BaseExecutor):
                 break
 
     def _flush_result_queue(self) -> None:
-        if not self.result_queue:
-            raise AirflowException(NOT_STARTED_MESSAGE)
+        if TYPE_CHECKING:
+            assert self.result_queue
+
         self.log.debug("Executor shutting down, result_queue approximate 
size=%d", self.result_queue.qsize())
         while True:
             try:
@@ -829,12 +833,11 @@ class KubernetesExecutor(BaseExecutor):
 
     def end(self) -> None:
         """Called when the executor shuts down"""
-        if not self.task_queue:
-            raise AirflowException(NOT_STARTED_MESSAGE)
-        if not self.result_queue:
-            raise AirflowException(NOT_STARTED_MESSAGE)
-        if not self.kube_scheduler:
-            raise AirflowException(NOT_STARTED_MESSAGE)
+        if TYPE_CHECKING:
+            assert self.task_queue
+            assert self.result_queue
+            assert self.kube_scheduler
+
         self.log.info("Shutting down Kubernetes executor")
         self.log.debug("Flushing task_queue...")
         self._flush_task_queue()
diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py
index 4fae990bc3..c2c82d8639 100644
--- a/airflow/executors/local_executor.py
+++ b/airflow/executors/local_executor.py
@@ -31,13 +31,13 @@ from abc import abstractmethod
 from multiprocessing import Manager, Process
 from multiprocessing.managers import SyncManager
 from queue import Empty, Queue
-from typing import Any, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Optional, Tuple
 
 from setproctitle import getproctitle, setproctitle
 
 from airflow import settings
 from airflow.exceptions import AirflowException
-from airflow.executors.base_executor import NOT_STARTED_MESSAGE, PARALLELISM, BaseExecutor, CommandType
+from airflow.executors.base_executor import PARALLELISM, BaseExecutor, CommandType
 from airflow.models.taskinstance import TaskInstanceKey, TaskInstanceStateType
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.state import State
@@ -245,8 +245,9 @@ class LocalExecutor(BaseExecutor):
             :param queue: Name of the queue
             :param executor_config: configuration for the executor
             """
-            if not self.executor.result_queue:
-                raise AirflowException(NOT_STARTED_MESSAGE)
+            if TYPE_CHECKING:
+                assert self.executor.result_queue
+
             local_worker = LocalWorker(self.executor.result_queue, key=key, 
command=command)
             self.executor.workers_used += 1
             self.executor.workers_active += 1
@@ -284,11 +285,11 @@ class LocalExecutor(BaseExecutor):
 
         def start(self) -> None:
             """Starts limited parallelism implementation."""
-            if not self.executor.manager:
-                raise AirflowException(NOT_STARTED_MESSAGE)
+            if TYPE_CHECKING:
+                assert self.executor.manager
+                assert self.executor.result_queue
+
             self.queue = self.executor.manager.Queue()
-            if not self.executor.result_queue:
-                raise AirflowException(NOT_STARTED_MESSAGE)
             self.executor.workers = [
                 QueuedLocalWorker(self.queue, self.executor.result_queue)
                 for _ in range(self.executor.parallelism)
@@ -314,8 +315,9 @@ class LocalExecutor(BaseExecutor):
             :param queue: name of the queue
             :param executor_config: configuration for the executor
             """
-            if not self.queue:
-                raise AirflowException(NOT_STARTED_MESSAGE)
+            if TYPE_CHECKING:
+                assert self.queue
+
             self.queue.put((key, command))
 
         def sync(self):
@@ -365,8 +367,8 @@ class LocalExecutor(BaseExecutor):
         executor_config: Any | None = None,
     ) -> None:
         """Execute asynchronously."""
-        if not self.impl:
-            raise AirflowException(NOT_STARTED_MESSAGE)
+        if TYPE_CHECKING:
+            assert self.impl
 
         self.validate_airflow_tasks_run_command(command)
 
@@ -374,8 +376,9 @@ class LocalExecutor(BaseExecutor):
 
     def sync(self) -> None:
         """Sync will get called periodically by the heartbeat method."""
-        if not self.impl:
-            raise AirflowException(NOT_STARTED_MESSAGE)
+        if TYPE_CHECKING:
+            assert self.impl
+
         self.impl.sync()
 
     def end(self) -> None:
@@ -383,10 +386,10 @@ class LocalExecutor(BaseExecutor):
         Ends the executor.
         :return:
         """
-        if not self.impl:
-            raise AirflowException(NOT_STARTED_MESSAGE)
-        if not self.manager:
-            raise AirflowException(NOT_STARTED_MESSAGE)
+        if TYPE_CHECKING:
+            assert self.impl
+            assert self.manager
+
         self.log.info(
             "Shutting down LocalExecutor"
             "; waiting for running tasks to finish.  Signal again if you don't 
want to wait."

Reply via email to