This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-7-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 097d2bed4e372e4f4cc86b14c6b8a47dd8d65902 Author: Hussein Awala <[email protected]> AuthorDate: Tue Aug 8 14:43:36 2023 +0200 Replace State by TaskInstanceState in Airflow executors (#32627) * Replace State by TaskInstanceState in Airflow executors * change state type in change_state method, KubernetesResultsType and KubernetesWatchType to TaskInstanceState * Fix change_state annotation in CeleryExecutor --------- Co-authored-by: Tzu-ping Chung <[email protected]> (cherry picked from commit 9556d6d5f611428ac8a3a5891647b720d4498ace) --- airflow/executors/base_executor.py | 8 +++---- airflow/executors/debug_executor.py | 28 +++++++++++----------- airflow/executors/sequential_executor.py | 6 ++--- .../providers/celery/executors/celery_executor.py | 8 +++---- .../kubernetes/executors/kubernetes_executor.py | 9 +++---- .../executors/kubernetes_executor_types.py | 5 ++-- .../executors/kubernetes_executor_utils.py | 14 +++++++---- 7 files changed, 43 insertions(+), 35 deletions(-) diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index 999125afe7..10aebbeb3d 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -32,7 +32,7 @@ from airflow.configuration import conf from airflow.exceptions import RemovedInAirflow3Warning from airflow.stats import Stats from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState PARALLELISM: int = conf.getint("core", "PARALLELISM") @@ -295,7 +295,7 @@ class BaseExecutor(LoggingMixin): self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) self.running.add(key) - def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None: + def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None: """ Changes state of the task. 
@@ -317,7 +317,7 @@ class BaseExecutor(LoggingMixin): :param info: Executor information for the task instance :param key: Unique key for the task instance """ - self.change_state(key, State.FAILED, info) + self.change_state(key, TaskInstanceState.FAILED, info) def success(self, key: TaskInstanceKey, info=None) -> None: """ @@ -326,7 +326,7 @@ class BaseExecutor(LoggingMixin): :param info: Executor information for the task instance :param key: Unique key for the task instance """ - self.change_state(key, State.SUCCESS, info) + self.change_state(key, TaskInstanceState.SUCCESS, info) def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]: """ diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py index ca23b09a67..8a46d6cda0 100644 --- a/airflow/executors/debug_executor.py +++ b/airflow/executors/debug_executor.py @@ -29,7 +29,7 @@ import time from typing import TYPE_CHECKING, Any from airflow.executors.base_executor import BaseExecutor -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstance @@ -68,15 +68,15 @@ class DebugExecutor(BaseExecutor): while self.tasks_to_run: ti = self.tasks_to_run.pop(0) if self.fail_fast and not task_succeeded: - self.log.info("Setting %s to %s", ti.key, State.UPSTREAM_FAILED) - ti.set_state(State.UPSTREAM_FAILED) - self.change_state(ti.key, State.UPSTREAM_FAILED) + self.log.info("Setting %s to %s", ti.key, TaskInstanceState.UPSTREAM_FAILED) + ti.set_state(TaskInstanceState.UPSTREAM_FAILED) + self.change_state(ti.key, TaskInstanceState.UPSTREAM_FAILED) continue if self._terminated.is_set(): - self.log.info("Executor is terminated! Stopping %s to %s", ti.key, State.FAILED) - ti.set_state(State.FAILED) - self.change_state(ti.key, State.FAILED) + self.log.info("Executor is terminated! 
Stopping %s to %s", ti.key, TaskInstanceState.FAILED) + ti.set_state(TaskInstanceState.FAILED) + self.change_state(ti.key, TaskInstanceState.FAILED) continue task_succeeded = self._run_task(ti) @@ -87,11 +87,11 @@ class DebugExecutor(BaseExecutor): try: params = self.tasks_params.pop(ti.key, {}) ti.run(job_id=ti.job_id, **params) - self.change_state(key, State.SUCCESS) + self.change_state(key, TaskInstanceState.SUCCESS) return True except Exception as e: - ti.set_state(State.FAILED) - self.change_state(key, State.FAILED) + ti.set_state(TaskInstanceState.FAILED) + self.change_state(key, TaskInstanceState.FAILED) self.log.exception("Failed to execute task: %s.", str(e)) return False @@ -148,14 +148,14 @@ class DebugExecutor(BaseExecutor): def end(self) -> None: """Set states of queued tasks to UPSTREAM_FAILED marking them as not executed.""" for ti in self.tasks_to_run: - self.log.info("Setting %s to %s", ti.key, State.UPSTREAM_FAILED) - ti.set_state(State.UPSTREAM_FAILED) - self.change_state(ti.key, State.UPSTREAM_FAILED) + self.log.info("Setting %s to %s", ti.key, TaskInstanceState.UPSTREAM_FAILED) + ti.set_state(TaskInstanceState.UPSTREAM_FAILED) + self.change_state(ti.key, TaskInstanceState.UPSTREAM_FAILED) def terminate(self) -> None: self._terminated.set() - def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None: + def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None: self.log.debug("Popping %s from executor task queue.", key) self.running.remove(key) self.event_buffer[key] = state, info diff --git a/airflow/executors/sequential_executor.py b/airflow/executors/sequential_executor.py index 28f88c6b87..2715edad6e 100644 --- a/airflow/executors/sequential_executor.py +++ b/airflow/executors/sequential_executor.py @@ -28,7 +28,7 @@ import subprocess from typing import TYPE_CHECKING, Any from airflow.executors.base_executor import BaseExecutor -from airflow.utils.state import State +from airflow.utils.state 
import TaskInstanceState if TYPE_CHECKING: from airflow.executors.base_executor import CommandType @@ -75,9 +75,9 @@ class SequentialExecutor(BaseExecutor): try: subprocess.check_call(command, close_fds=True) - self.change_state(key, State.SUCCESS) + self.change_state(key, TaskInstanceState.SUCCESS) except subprocess.CalledProcessError as e: - self.change_state(key, State.FAILED) + self.change_state(key, TaskInstanceState.FAILED) self.log.error("Failed to execute task %s.", str(e)) self.commands_to_run = [] diff --git a/airflow/providers/celery/executors/celery_executor.py b/airflow/providers/celery/executors/celery_executor.py index 51287f3c1b..4708ef2137 100644 --- a/airflow/providers/celery/executors/celery_executor.py +++ b/airflow/providers/celery/executors/celery_executor.py @@ -74,7 +74,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowTaskTimeout from airflow.executors.base_executor import BaseExecutor from airflow.stats import Stats -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState log = logging.getLogger(__name__) @@ -299,7 +299,7 @@ class CeleryExecutor(BaseExecutor): self.task_publish_retries.pop(key, None) if isinstance(result, ExceptionWithTraceback): self.log.error(CELERY_SEND_ERR_MSG_HEADER + ": %s\n%s\n", result.exception, result.traceback) - self.event_buffer[key] = (State.FAILED, None) + self.event_buffer[key] = (TaskInstanceState.FAILED, None) elif result is not None: result.backend = cached_celery_backend self.running.add(key) @@ -308,7 +308,7 @@ class CeleryExecutor(BaseExecutor): # Store the Celery task_id in the event buffer. 
This will get "overwritten" if the task # has another event, but that is fine, because the only other events are success/failed at # which point we don't need the ID anymore anyway - self.event_buffer[key] = (State.QUEUED, result.task_id) + self.event_buffer[key] = (TaskInstanceState.QUEUED, result.task_id) # If the task runs _really quickly_ we may already have a result! self.update_task_state(key, result.state, getattr(result, "info", None)) @@ -355,7 +355,7 @@ class CeleryExecutor(BaseExecutor): if state: self.update_task_state(key, state, info) - def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None: + def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None: super().change_state(key, state, info) self.tasks.pop(key, None) diff --git a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py index a5aa36d981..051686c8d5 100644 --- a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py +++ b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py @@ -78,7 +78,7 @@ from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import annota from airflow.utils.event_scheduler import EventScheduler from airflow.utils.log.logging_mixin import remove_escape_codes from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.state import State, TaskInstanceState +from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: from kubernetes import client @@ -425,7 +425,7 @@ class KubernetesExecutor(BaseExecutor): def _change_state( self, key: TaskInstanceKey, - state: str | None, + state: TaskInstanceState | None, pod_name: str, namespace: str, session: Session = NEW_SESSION, @@ -433,12 +433,12 @@ class KubernetesExecutor(BaseExecutor): if TYPE_CHECKING: assert self.kube_scheduler - if state == State.RUNNING: + if state == TaskInstanceState.RUNNING: self.event_buffer[key] = state, 
None return if self.kube_config.delete_worker_pods: - if state != State.FAILED or self.kube_config.delete_worker_pods_on_failure: + if state != TaskInstanceState.FAILED or self.kube_config.delete_worker_pods_on_failure: self.kube_scheduler.delete_pod(pod_name=pod_name, namespace=namespace) self.log.info("Deleted pod: %s in namespace %s", str(key), str(namespace)) else: @@ -455,6 +455,7 @@ class KubernetesExecutor(BaseExecutor): from airflow.models.taskinstance import TaskInstance state = session.query(TaskInstance.state).filter(TaskInstance.filter_for_tis([key])).scalar() + state = TaskInstanceState(state) self.event_buffer[key] = state, None diff --git a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py index a13cd35f8d..80b8f1de72 100644 --- a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py +++ b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py @@ -21,15 +21,16 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple if TYPE_CHECKING: from airflow.executors.base_executor import CommandType from airflow.models.taskinstance import TaskInstanceKey + from airflow.utils.state import TaskInstanceState # TaskInstance key, command, configuration, pod_template_file KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any, Optional[str]] # key, pod state, pod_name, namespace, resource_version - KubernetesResultsType = Tuple[TaskInstanceKey, Optional[str], str, str, str] + KubernetesResultsType = Tuple[TaskInstanceKey, Optional[TaskInstanceState], str, str, str] # pod_name, namespace, pod state, annotations, resource_version - KubernetesWatchType = Tuple[str, str, Optional[str], Dict[str, str], str] + KubernetesWatchType = Tuple[str, str, Optional[TaskInstanceState], Dict[str, str], str] ALL_NAMESPACES = "ALL_NAMESPACES" POD_EXECUTOR_DONE_KEY = "airflow_executor_done" diff --git 
a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py index c1ee9d1ebe..b19d88eb49 100644 --- a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py +++ b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py @@ -36,7 +36,7 @@ from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import ( ) from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState try: from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_types import ( @@ -223,12 +223,16 @@ class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin): # since kube server have received request to delete pod set TI state failed if event["type"] == "DELETED" and pod.metadata.deletion_timestamp: self.log.info("Event: Failed to start pod %s, annotations: %s", pod_name, annotations_string) - self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version)) + self.watcher_queue.put( + (pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version) + ) else: self.log.debug("Event: %s Pending, annotations: %s", pod_name, annotations_string) elif status == "Failed": self.log.error("Event: %s Failed, annotations: %s", pod_name, annotations_string) - self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version)) + self.watcher_queue.put( + (pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version) + ) elif status == "Succeeded": # We get multiple events once the pod hits a terminal state, and we only want to # send it along to the scheduler once. 
@@ -256,7 +260,9 @@ class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin): pod_name, annotations_string, ) - self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version)) + self.watcher_queue.put( + (pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version) + ) else: self.log.info("Event: %s is Running, annotations: %s", pod_name, annotations_string) else:
