This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b459af3ee0 Replace State usages with strong-typed enums (#31735)
b459af3ee0 is described below
commit b459af3ee0f94cd246e3d401ca3eec18ffd85db0
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Sat Jun 17 04:48:57 2023 +0800
Replace State usages with strong-typed enums (#31735)
This change only covers the main Airflow code base. There are many more
State usages in tests that I might tackle some day.
Additionally, there are some cases where the task instance (TI) state is used
for "job" state. We may deal with this later by introducing a new type,
ExecutorState.
---
airflow/api/common/delete_dag.py | 4 +-
airflow/api/common/mark_tasks.py | 30 +++++++++---
.../endpoints/task_instance_endpoint.py | 4 +-
airflow/api_connexion/schemas/enum_schemas.py | 4 +-
airflow/dag_processing/processor.py | 16 +++++--
airflow/executors/base_executor.py | 10 ++--
airflow/executors/celery_executor.py | 8 ++--
airflow/executors/celery_executor_utils.py | 8 ++--
airflow/executors/debug_executor.py | 28 +++++------
airflow/executors/kubernetes_executor.py | 24 ++++++----
airflow/executors/local_executor.py | 16 +++----
airflow/executors/sequential_executor.py | 6 +--
airflow/jobs/backfill_job_runner.py | 6 +--
airflow/jobs/job.py | 7 ++-
airflow/jobs/local_task_job_runner.py | 6 +--
airflow/jobs/scheduler_job_runner.py | 28 +++++------
airflow/listeners/spec/taskinstance.py | 6 +--
airflow/models/dag.py | 18 ++++----
airflow/models/dagrun.py | 42 ++++++++---------
airflow/models/pool.py | 8 ++--
airflow/models/skipmixin.py | 4 +-
airflow/models/taskinstance.py | 54 +++++++++++-----------
airflow/operators/subdag.py | 20 ++++----
airflow/operators/trigger_dagrun.py | 6 +--
airflow/sensors/external_task.py | 28 +++++------
airflow/sentry.py | 4 +-
airflow/ti_deps/dependencies_states.py | 30 ++++++------
airflow/ti_deps/deps/dagrun_exists_dep.py | 4 +-
airflow/ti_deps/deps/not_in_retry_period_dep.py | 4 +-
airflow/ti_deps/deps/not_previously_skipped_dep.py | 4 +-
airflow/ti_deps/deps/prev_dagrun_dep.py | 4 +-
airflow/ti_deps/deps/ready_to_reschedule.py | 4 +-
airflow/ti_deps/deps/task_not_running_dep.py | 4 +-
airflow/utils/dot_renderer.py | 2 +-
airflow/utils/log/file_task_handler.py | 5 +-
airflow/utils/log/log_reader.py | 23 +++++----
airflow/utils/state.py | 9 +---
airflow/www/utils.py | 10 ++--
airflow/www/views.py | 26 +++++------
39 files changed, 280 insertions(+), 244 deletions(-)
diff --git a/airflow/api/common/delete_dag.py b/airflow/api/common/delete_dag.py
index 45611729ea..1d879a667a 100644
--- a/airflow/api/common/delete_dag.py
+++ b/airflow/api/common/delete_dag.py
@@ -29,7 +29,7 @@ from airflow.models import DagModel, TaskFail
from airflow.models.serialized_dag import SerializedDagModel
from airflow.utils.db import get_sqla_model_classes
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
log = logging.getLogger(__name__)
@@ -50,7 +50,7 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True,
session: Session =
running_tis = session.scalar(
select(models.TaskInstance.state)
.where(models.TaskInstance.dag_id == dag_id)
- .where(models.TaskInstance.state == State.RUNNING)
+ .where(models.TaskInstance.state == TaskInstanceState.RUNNING)
.limit(1)
)
if running_tis:
diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index b965237bdf..184251b515 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -155,7 +155,7 @@ def set_state(
for task_instance in tis_altered:
# The try_number was decremented when setting to up_for_reschedule
and deferred.
# Increment it back when changing the state again
- if task_instance.state in [State.DEFERRED,
State.UP_FOR_RESCHEDULE]:
+ if task_instance.state in (TaskInstanceState.DEFERRED,
TaskInstanceState.UP_FOR_RESCHEDULE):
task_instance._try_number += 1
task_instance.set_state(state, session=session)
session.flush()
@@ -362,7 +362,7 @@ def _set_dag_run_state(dag_id: str, run_id: str, state:
DagRunState, session: SA
select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
).scalar_one()
dag_run.state = state
- if state == State.RUNNING:
+ if state == DagRunState.RUNNING:
dag_run.start_date = timezone.utcnow()
dag_run.end_date = None
else:
@@ -415,7 +415,13 @@ def set_dag_run_state_to_success(
# Mark all task instances of the dag run to success.
for task in dag.tasks:
task.dag = dag
- return set_state(tasks=dag.tasks, run_id=run_id, state=State.SUCCESS,
commit=commit, session=session)
+ return set_state(
+ tasks=dag.tasks,
+ run_id=run_id,
+ state=TaskInstanceState.SUCCESS,
+ commit=commit,
+ session=session,
+ )
@provide_session
@@ -468,7 +474,9 @@ def set_dag_run_state_to_failed(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id == run_id,
TaskInstance.task_id.in_(task_ids),
- TaskInstance.state.in_([State.RUNNING, State.DEFERRED,
State.UP_FOR_RESCHEDULE]),
+ TaskInstance.state.in_(
+ (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED,
TaskInstanceState.UP_FOR_RESCHEDULE),
+ ),
)
)
@@ -487,16 +495,24 @@ def set_dag_run_state_to_failed(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id == run_id,
TaskInstance.state.not_in(State.finished),
- TaskInstance.state.not_in([State.RUNNING, State.DEFERRED,
State.UP_FOR_RESCHEDULE]),
+ TaskInstance.state.not_in(
+ (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED,
TaskInstanceState.UP_FOR_RESCHEDULE),
+ ),
)
)
tis = [ti for ti in tis]
if commit:
for ti in tis:
- ti.set_state(State.SKIPPED)
+ ti.set_state(TaskInstanceState.SKIPPED)
- return tis + set_state(tasks=tasks, run_id=run_id, state=State.FAILED,
commit=commit, session=session)
+ return tis + set_state(
+ tasks=tasks,
+ run_id=run_id,
+ state=TaskInstanceState.FAILED,
+ commit=commit,
+ session=session,
+ )
def __set_dag_run_state_to_running_or_queued(
diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py
b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index 533d97c858..7ccc6b8848 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -50,7 +50,7 @@ from airflow.models.taskinstance import TaskInstance as TI,
clear_task_instances
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import DagRunState, State
+from airflow.utils.state import DagRunState
T = TypeVar("T")
@@ -264,7 +264,7 @@ def get_mapped_task_instances(
def _convert_state(states: Iterable[str] | None) -> list[str | None] | None:
if not states:
return None
- return [State.NONE if s == "none" else s for s in states]
+ return [None if s == "none" else s for s in states]
def _apply_array_filter(query: Query, key: ClauseElement, values:
Iterable[Any] | None) -> Query:
diff --git a/airflow/api_connexion/schemas/enum_schemas.py
b/airflow/api_connexion/schemas/enum_schemas.py
index 981a3669b1..ba82010783 100644
--- a/airflow/api_connexion/schemas/enum_schemas.py
+++ b/airflow/api_connexion/schemas/enum_schemas.py
@@ -26,7 +26,7 @@ class DagStateField(fields.String):
def __init__(self, **metadata):
super().__init__(**metadata)
- self.validators = [validate.OneOf(State.dag_states)] +
list(self.validators)
+ self.validators = [validate.OneOf(State.dag_states), *self.validators]
class TaskInstanceStateField(fields.String):
@@ -34,4 +34,4 @@ class TaskInstanceStateField(fields.String):
def __init__(self, **metadata):
super().__init__(**metadata)
- self.validators = [validate.OneOf(State.task_states)] +
list(self.validators)
+ self.validators = [validate.OneOf(State.task_states), *self.validators]
diff --git a/airflow/dag_processing/processor.py
b/airflow/dag_processing/processor.py
index 9ee865bbf8..d742c5314a 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -56,7 +56,7 @@ from airflow.utils.file import iter_airflow_imports,
might_contain_dag
from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter,
set_context
from airflow.utils.mixins import MultiprocessingStartMethodMixin
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from airflow.models.operator import Operator
@@ -432,9 +432,11 @@ class DagFileProcessor(LoggingMixin):
qry = (
session.query(TI.task_id,
func.max(DR.execution_date).label("max_ti"))
.join(TI.dag_run)
- .filter(TI.dag_id == dag.dag_id)
- .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED))
- .filter(TI.task_id.in_(dag.task_ids))
+ .filter(
+ TI.dag_id == dag.dag_id,
+ or_(TI.state == TaskInstanceState.SUCCESS, TI.state ==
TaskInstanceState.SKIPPED),
+ TI.task_id.in_(dag.task_ids),
+ )
.group_by(TI.task_id)
.subquery("sq")
)
@@ -500,7 +502,11 @@ class DagFileProcessor(LoggingMixin):
sla_dates: list[datetime] = [sla.execution_date for sla in slas]
fetched_tis: list[TI] = (
session.query(TI)
- .filter(TI.state != State.SUCCESS,
TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id)
+ .filter(
+ TI.state != TaskInstanceState.SUCCESS,
+ TI.execution_date.in_(sla_dates),
+ TI.dag_id == dag.dag_id,
+ )
.all()
)
blocking_tis: list[TI] = []
diff --git a/airflow/executors/base_executor.py
b/airflow/executors/base_executor.py
index 9599adabdf..6ae80dced1 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -31,7 +31,7 @@ from airflow.configuration import conf
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
PARALLELISM: int = conf.getint("core", "PARALLELISM")
@@ -54,7 +54,7 @@ if TYPE_CHECKING:
# Event_buffer dict value type
# Tuple of: state, info
- EventBufferValueType = Tuple[Optional[str], Any]
+ EventBufferValueType = Tuple[Optional[TaskInstanceState], Any]
# Task tuple to send to be executed
TaskTuple = Tuple[TaskInstanceKey, CommandType, Optional[str],
Optional[Any]]
@@ -298,7 +298,7 @@ class BaseExecutor(LoggingMixin):
self.execute_async(key=key, command=command, queue=queue,
executor_config=executor_config)
self.running.add(key)
- def change_state(self, key: TaskInstanceKey, state: str, info=None) ->
None:
+ def change_state(self, key: TaskInstanceKey, state: TaskInstanceState,
info=None) -> None:
"""
Changes state of the task.
@@ -320,7 +320,7 @@ class BaseExecutor(LoggingMixin):
:param info: Executor information for the task instance
:param key: Unique key for the task instance
"""
- self.change_state(key, State.FAILED, info)
+ self.change_state(key, TaskInstanceState.FAILED, info)
def success(self, key: TaskInstanceKey, info=None) -> None:
"""
@@ -329,7 +329,7 @@ class BaseExecutor(LoggingMixin):
:param info: Executor information for the task instance
:param key: Unique key for the task instance
"""
- self.change_state(key, State.SUCCESS, info)
+ self.change_state(key, TaskInstanceState.SUCCESS, info)
def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey,
EventBufferValueType]:
"""
diff --git a/airflow/executors/celery_executor.py
b/airflow/executors/celery_executor.py
index de59804a04..c9f4ded309 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -38,7 +38,7 @@ from airflow.configuration import conf
from airflow.exceptions import AirflowTaskTimeout
from airflow.executors.base_executor import BaseExecutor
from airflow.stats import Stats
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
log = logging.getLogger(__name__)
@@ -150,7 +150,7 @@ class CeleryExecutor(BaseExecutor):
self.task_publish_retries.pop(key, None)
if isinstance(result, ExceptionWithTraceback):
self.log.error(CELERY_SEND_ERR_MSG_HEADER + ": %s\n%s\n",
result.exception, result.traceback)
- self.event_buffer[key] = (State.FAILED, None)
+ self.event_buffer[key] = (TaskInstanceState.FAILED, None)
elif result is not None:
result.backend = cached_celery_backend
self.running.add(key)
@@ -159,7 +159,7 @@ class CeleryExecutor(BaseExecutor):
# Store the Celery task_id in the event buffer. This will get
"overwritten" if the task
# has another event, but that is fine, because the only other
events are success/failed at
# which point we don't need the ID anymore anyway
- self.event_buffer[key] = (State.QUEUED, result.task_id)
+ self.event_buffer[key] = (TaskInstanceState.QUEUED,
result.task_id)
# If the task runs _really quickly_ we may already have a
result!
self.update_task_state(key, result.state, getattr(result,
"info", None))
@@ -206,7 +206,7 @@ class CeleryExecutor(BaseExecutor):
if state:
self.update_task_state(key, state, info)
- def change_state(self, key: TaskInstanceKey, state: str, info=None) ->
None:
+ def change_state(self, key: TaskInstanceKey, state: TaskInstanceState,
info=None) -> None:
super().change_state(key, state, info)
self.tasks.pop(key, None)
diff --git a/airflow/executors/celery_executor_utils.py
b/airflow/executors/celery_executor_utils.py
index 2c8af4cf91..80460c3c8a 100644
--- a/airflow/executors/celery_executor_utils.py
+++ b/airflow/executors/celery_executor_utils.py
@@ -46,6 +46,7 @@ from airflow.stats import Stats
from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
+from airflow.utils.state import TaskInstanceState
from airflow.utils.timeout import timeout
log = logging.getLogger(__name__)
@@ -192,9 +193,10 @@ def send_task_to_executor(
return key, command, result
-def fetch_celery_task_state(async_result: AsyncResult) -> tuple[str, str |
ExceptionWithTraceback, Any]:
- """
- Fetch and return the state of the given celery task.
+def fetch_celery_task_state(
+ async_result: AsyncResult,
+) -> tuple[str, TaskInstanceState | ExceptionWithTraceback, Any]:
+ """Fetch and return the state of the given celery task.
The scope of this function is global so that it can be called by
subprocesses in the pool.
diff --git a/airflow/executors/debug_executor.py
b/airflow/executors/debug_executor.py
index ca23b09a67..8a46d6cda0 100644
--- a/airflow/executors/debug_executor.py
+++ b/airflow/executors/debug_executor.py
@@ -29,7 +29,7 @@ import time
from typing import TYPE_CHECKING, Any
from airflow.executors.base_executor import BaseExecutor
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance
@@ -68,15 +68,15 @@ class DebugExecutor(BaseExecutor):
while self.tasks_to_run:
ti = self.tasks_to_run.pop(0)
if self.fail_fast and not task_succeeded:
- self.log.info("Setting %s to %s", ti.key,
State.UPSTREAM_FAILED)
- ti.set_state(State.UPSTREAM_FAILED)
- self.change_state(ti.key, State.UPSTREAM_FAILED)
+ self.log.info("Setting %s to %s", ti.key,
TaskInstanceState.UPSTREAM_FAILED)
+ ti.set_state(TaskInstanceState.UPSTREAM_FAILED)
+ self.change_state(ti.key, TaskInstanceState.UPSTREAM_FAILED)
continue
if self._terminated.is_set():
- self.log.info("Executor is terminated! Stopping %s to %s",
ti.key, State.FAILED)
- ti.set_state(State.FAILED)
- self.change_state(ti.key, State.FAILED)
+ self.log.info("Executor is terminated! Stopping %s to %s",
ti.key, TaskInstanceState.FAILED)
+ ti.set_state(TaskInstanceState.FAILED)
+ self.change_state(ti.key, TaskInstanceState.FAILED)
continue
task_succeeded = self._run_task(ti)
@@ -87,11 +87,11 @@ class DebugExecutor(BaseExecutor):
try:
params = self.tasks_params.pop(ti.key, {})
ti.run(job_id=ti.job_id, **params)
- self.change_state(key, State.SUCCESS)
+ self.change_state(key, TaskInstanceState.SUCCESS)
return True
except Exception as e:
- ti.set_state(State.FAILED)
- self.change_state(key, State.FAILED)
+ ti.set_state(TaskInstanceState.FAILED)
+ self.change_state(key, TaskInstanceState.FAILED)
self.log.exception("Failed to execute task: %s.", str(e))
return False
@@ -148,14 +148,14 @@ class DebugExecutor(BaseExecutor):
def end(self) -> None:
"""Set states of queued tasks to UPSTREAM_FAILED marking them as not
executed."""
for ti in self.tasks_to_run:
- self.log.info("Setting %s to %s", ti.key, State.UPSTREAM_FAILED)
- ti.set_state(State.UPSTREAM_FAILED)
- self.change_state(ti.key, State.UPSTREAM_FAILED)
+ self.log.info("Setting %s to %s", ti.key,
TaskInstanceState.UPSTREAM_FAILED)
+ ti.set_state(TaskInstanceState.UPSTREAM_FAILED)
+ self.change_state(ti.key, TaskInstanceState.UPSTREAM_FAILED)
def terminate(self) -> None:
self._terminated.set()
- def change_state(self, key: TaskInstanceKey, state: str, info=None) ->
None:
+ def change_state(self, key: TaskInstanceKey, state: TaskInstanceState,
info=None) -> None:
self.log.debug("Popping %s from executor task queue.", key)
self.running.remove(key)
self.event_buffer[key] = state, info
diff --git a/airflow/executors/kubernetes_executor.py
b/airflow/executors/kubernetes_executor.py
index 8707bf9699..092a84f470 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -53,7 +53,7 @@ from airflow.kubernetes.pod_generator import PodGenerator
from airflow.utils.event_scheduler import EventScheduler
from airflow.utils.log.logging_mixin import LoggingMixin, remove_escape_codes
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State, TaskInstanceState
+from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from airflow.executors.base_executor import CommandType
@@ -228,12 +228,16 @@ class KubernetesJobWatcher(multiprocessing.Process,
LoggingMixin):
# since kube server have received request to delete pod set TI
state failed
if event["type"] == "DELETED" and pod.metadata.deletion_timestamp:
self.log.info("Event: Failed to start pod %s, annotations:
%s", pod_name, annotations_string)
- self.watcher_queue.put((pod_name, namespace, State.FAILED,
annotations, resource_version))
+ self.watcher_queue.put(
+ (pod_name, namespace, TaskInstanceState.FAILED,
annotations, resource_version),
+ )
else:
self.log.debug("Event: %s Pending, annotations: %s", pod_name,
annotations_string)
elif status == "Failed":
self.log.error("Event: %s Failed, annotations: %s", pod_name,
annotations_string)
- self.watcher_queue.put((pod_name, namespace, State.FAILED,
annotations, resource_version))
+ self.watcher_queue.put(
+ (pod_name, namespace, TaskInstanceState.FAILED, annotations,
resource_version),
+ )
elif status == "Succeeded":
# We get multiple events once the pod hits a terminal state, and
we only want to
# send it along to the scheduler once.
@@ -261,7 +265,9 @@ class KubernetesJobWatcher(multiprocessing.Process,
LoggingMixin):
pod_name,
annotations_string,
)
- self.watcher_queue.put((pod_name, namespace, State.FAILED,
annotations, resource_version))
+ self.watcher_queue.put(
+ (pod_name, namespace, TaskInstanceState.FAILED,
annotations, resource_version),
+ )
else:
self.log.info("Event: %s is Running, annotations: %s",
pod_name, annotations_string)
else:
@@ -700,7 +706,7 @@ class KubernetesExecutor(BaseExecutor):
last_resource_version[namespace] = resource_version
self.log.info("Changing state of %s to %s", results, state)
try:
- self._change_state(key, state, pod_name, namespace)
+ self._change_state(key, TaskInstanceState(state),
pod_name, namespace)
except Exception as e:
self.log.exception(
"Exception: %s when attempting to change state of
%s to %s, re-queueing.",
@@ -767,7 +773,7 @@ class KubernetesExecutor(BaseExecutor):
def _change_state(
self,
key: TaskInstanceKey,
- state: str | None,
+ state: TaskInstanceState | None,
pod_name: str,
namespace: str,
session: Session = NEW_SESSION,
@@ -776,12 +782,12 @@ class KubernetesExecutor(BaseExecutor):
assert self.kube_scheduler
from airflow.models.taskinstance import TaskInstance
- if state == State.RUNNING:
+ if state == TaskInstanceState.RUNNING:
self.event_buffer[key] = state, None
return
if self.kube_config.delete_worker_pods:
- if state != State.FAILED or
self.kube_config.delete_worker_pods_on_failure:
+ if state != TaskInstanceState.FAILED or
self.kube_config.delete_worker_pods_on_failure:
self.kube_scheduler.delete_pod(pod_name=pod_name,
namespace=namespace)
self.log.info("Deleted pod: %s in namespace %s", str(key),
str(namespace))
else:
@@ -1011,7 +1017,7 @@ class KubernetesExecutor(BaseExecutor):
"Changing state of %s to %s : resource_version=%d",
results, state, resource_version
)
try:
- self._change_state(key, state, pod_name, namespace)
+ self._change_state(key, TaskInstanceState(state),
pod_name, namespace)
except Exception as e:
self.log.exception(
"Ignoring exception: %s when attempting to change
state of %s to %s.",
diff --git a/airflow/executors/local_executor.py
b/airflow/executors/local_executor.py
index 715bdb42ae..550a6519e1 100644
--- a/airflow/executors/local_executor.py
+++ b/airflow/executors/local_executor.py
@@ -39,7 +39,7 @@ from airflow import settings
from airflow.exceptions import AirflowException
from airflow.executors.base_executor import PARALLELISM, BaseExecutor
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from airflow.executors.base_executor import CommandType
@@ -94,20 +94,20 @@ class LocalWorkerBase(Process, LoggingMixin):
# Remove the command since the worker is done executing the task
setproctitle("airflow worker -- LocalExecutor")
- def _execute_work_in_subprocess(self, command: CommandType) -> str:
+ def _execute_work_in_subprocess(self, command: CommandType) ->
TaskInstanceState:
try:
subprocess.check_call(command, close_fds=True)
- return State.SUCCESS
+ return TaskInstanceState.SUCCESS
except subprocess.CalledProcessError as e:
self.log.error("Failed to execute task %s.", str(e))
- return State.FAILED
+ return TaskInstanceState.FAILED
- def _execute_work_in_fork(self, command: CommandType) -> str:
+ def _execute_work_in_fork(self, command: CommandType) -> TaskInstanceState:
pid = os.fork()
if pid:
# In parent, wait for the child
pid, ret = os.waitpid(pid, 0)
- return State.SUCCESS if ret == 0 else State.FAILED
+ return TaskInstanceState.SUCCESS if ret == 0 else
TaskInstanceState.FAILED
from airflow.sentry import Sentry
@@ -130,10 +130,10 @@ class LocalWorkerBase(Process, LoggingMixin):
args.func(args)
ret = 0
- return State.SUCCESS
+ return TaskInstanceState.SUCCESS
except Exception as e:
self.log.exception("Failed to execute task %s.", e)
- return State.FAILED
+ return TaskInstanceState.FAILED
finally:
Sentry.flush()
logging.shutdown()
diff --git a/airflow/executors/sequential_executor.py
b/airflow/executors/sequential_executor.py
index 28f88c6b87..2715edad6e 100644
--- a/airflow/executors/sequential_executor.py
+++ b/airflow/executors/sequential_executor.py
@@ -28,7 +28,7 @@ import subprocess
from typing import TYPE_CHECKING, Any
from airflow.executors.base_executor import BaseExecutor
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from airflow.executors.base_executor import CommandType
@@ -75,9 +75,9 @@ class SequentialExecutor(BaseExecutor):
try:
subprocess.check_call(command, close_fds=True)
- self.change_state(key, State.SUCCESS)
+ self.change_state(key, TaskInstanceState.SUCCESS)
except subprocess.CalledProcessError as e:
- self.change_state(key, State.FAILED)
+ self.change_state(key, TaskInstanceState.FAILED)
self.log.error("Failed to execute task %s.", str(e))
self.commands_to_run = []
diff --git a/airflow/jobs/backfill_job_runner.py
b/airflow/jobs/backfill_job_runner.py
index 40c31c4451..a0efcb17af 100644
--- a/airflow/jobs/backfill_job_runner.py
+++ b/airflow/jobs/backfill_job_runner.py
@@ -68,7 +68,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin):
job_type = "BackfillJob"
- STATES_COUNT_AS_RUNNING = (State.RUNNING, State.QUEUED)
+ STATES_COUNT_AS_RUNNING = (TaskInstanceState.RUNNING,
TaskInstanceState.QUEUED)
@attr.define
class _DagRunTaskStatus:
@@ -219,7 +219,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin):
# is changed externally, e.g. by clearing tasks from the ui. We
need to cover
# for that as otherwise those tasks would fall outside the scope of
# the backfill suddenly.
- elif ti.state == State.NONE:
+ elif ti.state is None:
self.log.warning(
"FIXME: task instance %s state was set to none externally
or "
"reaching concurrency limits. Re-adding task to queue.",
@@ -1004,7 +1004,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin):
).all()
for ti in reset_tis:
- ti.state = State.NONE
+ ti.state = None
session.merge(ti)
return result + reset_tis
diff --git a/airflow/jobs/job.py b/airflow/jobs/job.py
index fe75508f63..903cfb9aca 100644
--- a/airflow/jobs/job.py
+++ b/airflow/jobs/job.py
@@ -19,7 +19,7 @@ from __future__ import annotations
from functools import cached_property
from time import sleep
-from typing import Callable, NoReturn
+from typing import TYPE_CHECKING, Callable, NoReturn
from sqlalchemy import Column, Index, Integer, String, case, select
from sqlalchemy.exc import OperationalError
@@ -42,6 +42,9 @@ from airflow.utils.session import NEW_SESSION,
create_session, provide_session
from airflow.utils.sqlalchemy import UtcDateTime
from airflow.utils.state import State
+if TYPE_CHECKING:
+ from airflow.executors.base_executor import BaseExecutor
+
def _resolve_dagrun_model():
from airflow.models.dagrun import DagRun
@@ -117,7 +120,7 @@ class Job(Base, LoggingMixin):
super().__init__(**kwargs)
@cached_property
- def executor(self):
+ def executor(self) -> BaseExecutor:
return ExecutorLoader.get_default_executor()
def is_alive(self, grace_multiplier=2.1):
diff --git a/airflow/jobs/local_task_job_runner.py
b/airflow/jobs/local_task_job_runner.py
index 9f6a4b55e8..fd234e4150 100644
--- a/airflow/jobs/local_task_job_runner.py
+++ b/airflow/jobs/local_task_job_runner.py
@@ -35,7 +35,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.platform import IS_WINDOWS
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
SIGSEGV_MESSAGE = """
******************************************* Received SIGSEGV
*******************************************
@@ -243,7 +243,7 @@ class LocalTaskJobRunner(BaseJobRunner["Job |
JobPydantic"], LoggingMixin):
self.task_instance.refresh_from_db()
ti = self.task_instance
- if ti.state == State.RUNNING:
+ if ti.state == TaskInstanceState.RUNNING:
fqdn = get_hostname()
same_hostname = fqdn == ti.hostname
if not same_hostname:
@@ -273,7 +273,7 @@ class LocalTaskJobRunner(BaseJobRunner["Job |
JobPydantic"], LoggingMixin):
)
raise AirflowException("PID of job runner does not match")
elif self.task_runner.return_code() is None and
hasattr(self.task_runner, "process"):
- if ti.state == State.SKIPPED:
+ if ti.state == TaskInstanceState.SKIPPED:
# A DagRun timeout will cause tasks to be externally marked as
skipped.
dagrun = ti.get_dagrun(session=session)
execution_time = (dagrun.end_date or timezone.utcnow()) -
dagrun.start_date
diff --git a/airflow/jobs/scheduler_job_runner.py
b/airflow/jobs/scheduler_job_runner.py
index a85a0925fb..16764b1980 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -610,7 +610,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
)
for ti in executable_tis:
- ti.emit_state_change_metric(State.QUEUED)
+ ti.emit_state_change_metric(TaskInstanceState.QUEUED)
for ti in executable_tis:
make_transient(ti)
@@ -626,7 +626,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
# actually enqueue them
for ti in task_instances:
if ti.dag_run.state in State.finished:
- ti.set_state(State.NONE, session=session)
+ ti.set_state(None, session=session)
continue
command = ti.command_as_list(
local=True,
@@ -682,13 +682,13 @@ class SchedulerJobRunner(BaseJobRunner[Job],
LoggingMixin):
# Report execution
for ti_key, value in event_buffer.items():
- state: str
+ state: TaskInstanceState | None
state, _ = value
# We create map (dag_id, task_id, execution_date) -> in-memory
try_number
ti_primary_key_to_try_number_map[ti_key.primary] =
ti_key.try_number
self.log.info("Received executor event with state %s for task
instance %s", state, ti_key)
- if state in (State.FAILED, State.SUCCESS, State.QUEUED):
+ if state in (TaskInstanceState.FAILED, TaskInstanceState.SUCCESS,
TaskInstanceState.QUEUED):
tis_with_right_state.append(ti_key)
# Return if no finished tasks
@@ -712,7 +712,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
buffer_key = ti.key.with_try_number(try_number)
state, info = event_buffer.pop(buffer_key)
- if state == State.QUEUED:
+ if state == TaskInstanceState.QUEUED:
ti.external_executor_id = info
self.log.info("Setting external_id for %s to %s", ti, info)
continue
@@ -1535,7 +1535,7 @@ class SchedulerJobRunner(BaseJobRunner[Job],
LoggingMixin):
tasks_stuck_in_queued = session.scalars(
select(TI).where(
- TI.state == State.QUEUED,
+ TI.state == TaskInstanceState.QUEUED,
TI.queued_dttm < (timezone.utcnow() -
timedelta(seconds=self._task_queued_timeout)),
TI.queued_by_job_id == self.job.id,
)
@@ -1602,20 +1602,19 @@ class SchedulerJobRunner(BaseJobRunner[Job],
LoggingMixin):
self.log.info("Marked %d SchedulerJob instances as
failed", num_failed)
Stats.incr(self.__class__.__name__.lower() + "_end",
num_failed)
- resettable_states = [TaskInstanceState.QUEUED,
TaskInstanceState.RUNNING]
query = (
select(TI)
- .where(TI.state.in_(resettable_states))
+ .where(TI.state.in_((TaskInstanceState.QUEUED,
TaskInstanceState.RUNNING)))
# outerjoin is because we didn't use to have
queued_by_job
# set, so we need to pick up anything pre upgrade.
This (and the
# "or queued_by_job_id IS NONE") can go as soon as
scheduler HA is
# released.
.outerjoin(TI.queued_by_job)
- .where(or_(TI.queued_by_job_id.is_(None), Job.state !=
State.RUNNING))
+ .where(or_(TI.queued_by_job_id.is_(None), Job.state !=
TaskInstanceState.RUNNING))
.join(TI.dag_run)
.where(
DagRun.run_type != DagRunType.BACKFILL_JOB,
- DagRun.state == State.RUNNING,
+ DagRun.state == DagRunState.RUNNING,
)
.options(load_only(TI.dag_id, TI.task_id, TI.run_id))
)
@@ -1630,7 +1629,7 @@ class SchedulerJobRunner(BaseJobRunner[Job],
LoggingMixin):
reset_tis_message = []
for ti in to_reset:
reset_tis_message.append(repr(ti))
- ti.state = State.NONE
+ ti.state = None
ti.queued_by_job_id = None
for ti in set(tis_to_reset_or_adopt) - set(to_reset):
@@ -1697,12 +1696,7 @@ class SchedulerJobRunner(BaseJobRunner[Job],
LoggingMixin):
.join(Job, TI.job_id == Job.id)
.join(DM, TI.dag_id == DM.dag_id)
.where(TI.state == TaskInstanceState.RUNNING)
- .where(
- or_(
- Job.state != State.RUNNING,
- Job.latest_heartbeat < limit_dttm,
- )
- )
+ .where(or_(Job.state != State.RUNNING,
Job.latest_heartbeat < limit_dttm))
.where(Job.job_type == "LocalTaskJob")
.where(TI.queued_by_job_id == self.job.id)
)
diff --git a/airflow/listeners/spec/taskinstance.py
b/airflow/listeners/spec/taskinstance.py
index 78de8a5f62..56b4cb7322 100644
--- a/airflow/listeners/spec/taskinstance.py
+++ b/airflow/listeners/spec/taskinstance.py
@@ -34,18 +34,18 @@ hookspec = HookspecMarker("airflow")
def on_task_instance_running(
previous_state: TaskInstanceState, task_instance: TaskInstance, session:
Session | None
):
- """Called when task state changes to RUNNING. Previous_state can be
State.NONE."""
+ """Called when task state changes to RUNNING. Previous state can be
None."""
@hookspec
def on_task_instance_success(
previous_state: TaskInstanceState, task_instance: TaskInstance, session:
Session | None
):
- """Called when task state changes to SUCCESS. Previous_state can be
State.NONE."""
+ """Called when task state changes to SUCCESS. Previous state can be
None."""
@hookspec
def on_task_instance_failed(
previous_state: TaskInstanceState, task_instance: TaskInstance, session:
Session | None
):
- """Called when task state changes to FAIL. Previous_state can be
State.NONE."""
+ """Called when task state changes to FAIL. Previous state can be None."""
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index b7361d6f3b..57b15e5f32 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -103,7 +103,7 @@ from airflow.utils.helpers import at_most_one, exactly_one,
validate_key
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked,
tuple_in_condition, with_row_locks
-from airflow.utils.state import DagRunState, State, TaskInstanceState
+from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
if TYPE_CHECKING:
@@ -1281,7 +1281,7 @@ class DAG(LoggingMixin):
TI = TaskInstance
qry = session.query(func.count(TI.task_id)).filter(
TI.dag_id == self.dag_id,
- TI.state == State.RUNNING,
+ TI.state == TaskInstanceState.RUNNING,
)
return qry.scalar() >= self.max_active_tasks
@@ -1368,7 +1368,7 @@ class DAG(LoggingMixin):
:return: List of execution dates
"""
- runs = DagRun.find(dag_id=self.dag_id, state=State.RUNNING)
+ runs = DagRun.find(dag_id=self.dag_id, state=DagRunState.RUNNING)
active_dates = []
for run in runs:
@@ -1388,9 +1388,9 @@ class DAG(LoggingMixin):
# .count() is inefficient
query = session.query(func.count()).filter(DagRun.dag_id ==
self.dag_id)
if only_running:
- query = query.filter(DagRun.state == State.RUNNING)
+ query = query.filter(DagRun.state == DagRunState.RUNNING)
else:
- query = query.filter(DagRun.state.in_({State.RUNNING,
State.QUEUED}))
+ query = query.filter(DagRun.state.in_({DagRunState.RUNNING,
DagRunState.QUEUED}))
if external_trigger is not None:
query = query.filter(
@@ -2077,7 +2077,7 @@ class DAG(LoggingMixin):
@provide_session
def set_dag_runs_state(
self,
- state: str = State.RUNNING,
+ state: DagRunState = DagRunState.RUNNING,
session: Session = NEW_SESSION,
start_date: datetime | None = None,
end_date: datetime | None = None,
@@ -2160,10 +2160,10 @@ class DAG(LoggingMixin):
state = []
if only_failed:
- state += [State.FAILED, State.UPSTREAM_FAILED]
+ state += [TaskInstanceState.FAILED,
TaskInstanceState.UPSTREAM_FAILED]
if only_running:
# Yes, having `+=` doesn't make sense, but this was the existing
behaviour
- state += [State.RUNNING]
+ state += [TaskInstanceState.RUNNING]
tis = self._get_task_instances(
task_ids=task_ids,
@@ -2694,7 +2694,7 @@ class DAG(LoggingMixin):
# Instead of starting a scheduler, we run the minimal loop possible to
check
# for task readiness and dependency management. This is notably faster
# than creating a BackfillJob and allows us to surface logs to the user
- while dr.state == State.RUNNING:
+ while dr.state == DagRunState.RUNNING:
schedulable_tis, _ = dr.update_state(session=session)
try:
for ti in schedulable_tis:
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 45989be15a..bd94977599 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -111,7 +111,7 @@ class DagRun(Base, LoggingMixin):
execution_date = Column(UtcDateTime, default=timezone.utcnow,
nullable=False)
start_date = Column(UtcDateTime)
end_date = Column(UtcDateTime)
- _state = Column("state", String(50), default=State.QUEUED)
+ _state = Column("state", String(50), default=DagRunState.QUEUED)
run_id = Column(StringID(), nullable=False)
creating_job_id = Column(Integer)
external_trigger = Column(Boolean, default=True)
@@ -218,7 +218,7 @@ class DagRun(Base, LoggingMixin):
if state is not None:
self.state = state
if queued_at is NOTSET:
- self.queued_at = timezone.utcnow() if state == State.QUEUED else
None
+ self.queued_at = timezone.utcnow() if state == DagRunState.QUEUED
else None
else:
self.queued_at = queued_at
self.run_type = run_type
@@ -256,7 +256,7 @@ class DagRun(Base, LoggingMixin):
if self._state != state:
self._state = state
self.end_date = timezone.utcnow() if self._state in State.finished
else None
- if state == State.QUEUED:
+ if state == DagRunState.QUEUED:
self.queued_at = timezone.utcnow()
@declared_attr
@@ -289,9 +289,9 @@ class DagRun(Base, LoggingMixin):
# because SQLAlchemy doesn't accept a set here.
query = query.filter(cls.dag_id.in_(set(dag_ids)))
if only_running:
- query = query.filter(cls.state == State.RUNNING)
+ query = query.filter(cls.state == DagRunState.RUNNING)
else:
- query = query.filter(cls.state.in_([State.RUNNING, State.QUEUED]))
+ query = query.filter(cls.state.in_((DagRunState.RUNNING,
DagRunState.QUEUED)))
query = query.group_by(cls.dag_id)
return {dag_id: count for dag_id, count in query.all()}
@@ -323,7 +323,7 @@ class DagRun(Base, LoggingMixin):
.join(DagModel, DagModel.dag_id == cls.dag_id)
.filter(DagModel.is_paused == false(), DagModel.is_active ==
true())
)
- if state == State.QUEUED:
+ if state == DagRunState.QUEUED:
# For dag runs in the queued state, we check if they have reached
the max_active_runs limit
# and if so we drop them
running_drs = (
@@ -462,7 +462,7 @@ class DagRun(Base, LoggingMixin):
tis = tis.filter(TI.state == state)
else:
# this is required to deal with NULL values
- if State.NONE in state:
+ if None in state:
if all(x is None for x in state):
tis = tis.filter(TI.state.is_(None))
else:
@@ -734,9 +734,9 @@ class DagRun(Base, LoggingMixin):
try:
ti.task = dag.get_task(ti.task_id)
except TaskNotFound:
- if ti.state != State.REMOVED:
+ if ti.state != TaskInstanceState.REMOVED:
self.log.error("Failed to get task for ti %s. Marking
it as removed.", ti)
- ti.state = State.REMOVED
+ ti.state = TaskInstanceState.REMOVED
session.flush()
else:
yield ti
@@ -945,7 +945,7 @@ class DagRun(Base, LoggingMixin):
self.log.warning("Failed to record first_task_scheduling_delay
metric:", exc_info=True)
def _emit_duration_stats_for_finished_state(self):
- if self.state == State.RUNNING:
+ if self.state == DagRunState.RUNNING:
return
if self.start_date is None:
self.log.warning("Failed to record duration of %s: start_date is
not set.", self)
@@ -1017,22 +1017,22 @@ class DagRun(Base, LoggingMixin):
try:
task = dag.get_task(ti.task_id)
- should_restore_task = (task is not None) and ti.state ==
State.REMOVED
+ should_restore_task = (task is not None) and ti.state ==
TaskInstanceState.REMOVED
if should_restore_task:
self.log.info("Restoring task '%s' which was previously
removed from DAG '%s'", ti, dag)
Stats.incr(f"task_restored_to_dag.{dag.dag_id}",
tags=self.stats_tags)
# Same metric with tagging
Stats.incr("task_restored_to_dag",
tags={**self.stats_tags, "dag_id": dag.dag_id})
- ti.state = State.NONE
+ ti.state = None
except AirflowException:
- if ti.state == State.REMOVED:
+ if ti.state == TaskInstanceState.REMOVED:
pass # ti has already been removed, just ignore it
- elif self.state != State.RUNNING and not dag.partial:
+ elif self.state != DagRunState.RUNNING and not dag.partial:
self.log.warning("Failed to get task '%s' for dag '%s'.
Marking it as removed.", ti, dag)
Stats.incr(f"task_removed_from_dag.{dag.dag_id}",
tags=self.stats_tags)
# Same metric with tagging
Stats.incr("task_removed_from_dag",
tags={**self.stats_tags, "dag_id": dag.dag_id})
- ti.state = State.REMOVED
+ ti.state = TaskInstanceState.REMOVED
continue
try:
@@ -1049,7 +1049,7 @@ class DagRun(Base, LoggingMixin):
self.log.debug(
"Removing the unmapped TI '%s' as the mapping
can't be resolved yet", ti
)
- ti.state = State.REMOVED
+ ti.state = TaskInstanceState.REMOVED
continue
# Upstreams finished, check there aren't any extras
if ti.map_index >= total_length:
@@ -1058,7 +1058,7 @@ class DagRun(Base, LoggingMixin):
ti,
total_length,
)
- ti.state = State.REMOVED
+ ti.state = TaskInstanceState.REMOVED
else:
# Check if the number of mapped literals has changed, and we
need to mark this TI as removed.
if ti.map_index >= num_mapped_tis:
@@ -1067,10 +1067,10 @@ class DagRun(Base, LoggingMixin):
ti,
num_mapped_tis,
)
- ti.state = State.REMOVED
+ ti.state = TaskInstanceState.REMOVED
elif ti.map_index < 0:
self.log.debug("Removing the unmapped TI '%s' as the
mapping can now be performed", ti)
- ti.state = State.REMOVED
+ ti.state = TaskInstanceState.REMOVED
return task_ids
@@ -1342,7 +1342,7 @@ class DagRun(Base, LoggingMixin):
TI.run_id == self.run_id,
tuple_in_condition((TI.task_id, TI.map_index),
schedulable_ti_ids_chunk),
)
- .update({TI.state: State.SCHEDULED},
synchronize_session=False)
+ .update({TI.state: TaskInstanceState.SCHEDULED},
synchronize_session=False)
)
# Tasks using EmptyOperator should not be executed, mark them as
success
@@ -1358,7 +1358,7 @@ class DagRun(Base, LoggingMixin):
)
.update(
{
- TI.state: State.SUCCESS,
+ TI.state: TaskInstanceState.SUCCESS,
TI.start_date: timezone.utcnow(),
TI.end_date: timezone.utcnow(),
TI.duration: 0,
diff --git a/airflow/models/pool.py b/airflow/models/pool.py
index d1766d4a0a..cfad662691 100644
--- a/airflow/models/pool.py
+++ b/airflow/models/pool.py
@@ -28,7 +28,7 @@ from airflow.ti_deps.dependencies_states import
EXECUTION_STATES
from airflow.typing_compat import TypedDict
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import nowait, with_row_locks
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
class PoolStats(TypedDict):
@@ -245,7 +245,7 @@ class Pool(Base):
return int(
session.query(func.sum(TaskInstance.pool_slots))
.filter(TaskInstance.pool == self.pool)
- .filter(TaskInstance.state == State.RUNNING)
+ .filter(TaskInstance.state == TaskInstanceState.RUNNING)
.scalar()
or 0
)
@@ -263,7 +263,7 @@ class Pool(Base):
return int(
session.query(func.sum(TaskInstance.pool_slots))
.filter(TaskInstance.pool == self.pool)
- .filter(TaskInstance.state == State.QUEUED)
+ .filter(TaskInstance.state == TaskInstanceState.QUEUED)
.scalar()
or 0
)
@@ -281,7 +281,7 @@ class Pool(Base):
return int(
session.query(func.sum(TaskInstance.pool_slots))
.filter(TaskInstance.pool == self.pool)
- .filter(TaskInstance.state == State.SCHEDULED)
+ .filter(TaskInstance.state == TaskInstanceState.SCHEDULED)
.scalar()
or 0
)
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index d75a4a0e4d..73a733bd03 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -26,7 +26,7 @@ from airflow.serialization.pydantic.dag_run import
DagRunPydantic
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from pendulum import DateTime
@@ -72,7 +72,7 @@ class SkipMixin(LoggingMixin):
TaskInstance.task_id.in_(d.task_id for d in tasks),
).update(
{
- TaskInstance.state: State.SKIPPED,
+ TaskInstance.state: TaskInstanceState.SKIPPED,
TaskInstance.start_date: now,
TaskInstance.end_date: now,
},
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 28dc168ec3..829c630786 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -591,7 +591,7 @@ class TaskInstance(Base, LoggingMixin):
database, in all other cases this will be incremented.
"""
# This is designed so that task logs end up in the right file.
- if self.state == State.RUNNING:
+ if self.state == TaskInstanceState.RUNNING:
return self._try_number
return self._try_number + 1
@@ -804,7 +804,7 @@ class TaskInstance(Base, LoggingMixin):
:param session: SQLAlchemy ORM Session
"""
self.log.error("Recording the task instance as FAILED")
- self.state = State.FAILED
+ self.state = TaskInstanceState.FAILED
session.merge(self)
session.commit()
@@ -931,7 +931,7 @@ class TaskInstance(Base, LoggingMixin):
self.log.debug("Setting task state for %s to %s", self, state)
self.state = state
self.start_date = self.start_date or current_time
- if self.state in State.finished or self.state == State.UP_FOR_RETRY:
+ if self.state in State.finished or self.state ==
TaskInstanceState.UP_FOR_RETRY:
self.end_date = self.end_date or current_time
self.duration = (self.end_date - self.start_date).total_seconds()
session.merge(self)
@@ -944,7 +944,7 @@ class TaskInstance(Base, LoggingMixin):
has elapsed.
"""
# is the task still in the retry waiting period?
- return self.state == State.UP_FOR_RETRY and not self.ready_for_retry()
+ return self.state == TaskInstanceState.UP_FOR_RETRY and not
self.ready_for_retry()
@provide_session
def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
@@ -967,7 +967,7 @@ class TaskInstance(Base, LoggingMixin):
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id.in_(task.downstream_task_ids),
TaskInstance.run_id == self.run_id,
- TaskInstance.state.in_([State.SKIPPED, State.SUCCESS]),
+ TaskInstance.state.in_((TaskInstanceState.SKIPPED,
TaskInstanceState.SUCCESS)),
)
count = ti[0][0]
return count == len(task.downstream_task_ids)
@@ -1204,7 +1204,7 @@ class TaskInstance(Base, LoggingMixin):
Checks on whether the task instance is in the right state and timeframe
to be retried.
"""
- return self.state == State.UP_FOR_RETRY and self.next_retry_datetime()
< timezone.utcnow()
+ return self.state == TaskInstanceState.UP_FOR_RETRY and
self.next_retry_datetime() < timezone.utcnow()
@provide_session
def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
@@ -1273,7 +1273,7 @@ class TaskInstance(Base, LoggingMixin):
self.hostname = get_hostname()
self.pid = None
- if not ignore_all_deps and not ignore_ti_state and self.state ==
State.SUCCESS:
+ if not ignore_all_deps and not ignore_ti_state and self.state ==
TaskInstanceState.SUCCESS:
Stats.incr("previously_succeeded", tags=self.stats_tags)
if not mark_success:
@@ -1301,7 +1301,7 @@ class TaskInstance(Base, LoggingMixin):
# start date that is recorded in task_reschedule table
# If the task continues after being deferred (next_method is set),
use the original start_date
self.start_date = self.start_date if self.next_method else
timezone.utcnow()
- if self.state == State.UP_FOR_RESCHEDULE:
+ if self.state == TaskInstanceState.UP_FOR_RESCHEDULE:
task_reschedule: TR = TR.query_for_task_instance(self,
session=session).first()
if task_reschedule:
self.start_date = task_reschedule.start_date
@@ -1319,7 +1319,7 @@ class TaskInstance(Base, LoggingMixin):
description="requeueable deps",
)
if not self.are_dependencies_met(dep_context=dep_context,
session=session, verbose=True):
- self.state = State.NONE
+ self.state = None
self.log.warning(
"Rescheduling due to concurrency limits reached "
"at task runtime. Attempt %s of "
@@ -1339,10 +1339,10 @@ class TaskInstance(Base, LoggingMixin):
self._try_number += 1
if not test_mode:
- session.add(Log(State.RUNNING, self))
+ session.add(Log(TaskInstanceState.RUNNING, self))
- self.state = State.RUNNING
- self.emit_state_change_metric(State.RUNNING)
+ self.state = TaskInstanceState.RUNNING
+ self.emit_state_change_metric(TaskInstanceState.RUNNING)
self.external_executor_id = external_executor_id
self.end_date = None
if not test_mode:
@@ -1398,7 +1398,7 @@ class TaskInstance(Base, LoggingMixin):
return
# switch on state and deduce which metric to send
- if new_state == State.RUNNING:
+ if new_state == TaskInstanceState.RUNNING:
metric_name = "queued_duration"
if self.queued_dttm is None:
# this should not really happen except in tests or rare cases,
@@ -1410,7 +1410,7 @@ class TaskInstance(Base, LoggingMixin):
)
return
timing = (timezone.utcnow() - self.queued_dttm).total_seconds()
- elif new_state == State.QUEUED:
+ elif new_state == TaskInstanceState.QUEUED:
metric_name = "scheduled_duration"
if self.start_date is None:
# same comment as above
@@ -1497,7 +1497,7 @@ class TaskInstance(Base, LoggingMixin):
self._execute_task_with_callbacks(context, test_mode)
if not test_mode:
self.refresh_from_db(lock_for_update=True, session=session)
- self.state = State.SUCCESS
+ self.state = TaskInstanceState.SUCCESS
except TaskDeferred as defer:
# The task has signalled it wants to defer execution based on
# a trigger.
@@ -1521,7 +1521,7 @@ class TaskInstance(Base, LoggingMixin):
self.log.info(e)
if not test_mode:
self.refresh_from_db(lock_for_update=True, session=session)
- self.state = State.SKIPPED
+ self.state = TaskInstanceState.SKIPPED
except AirflowRescheduleException as reschedule_exception:
self._handle_reschedule(actual_start_date, reschedule_exception,
test_mode, session=session)
session.commit()
@@ -1744,7 +1744,7 @@ class TaskInstance(Base, LoggingMixin):
# Then, update ourselves so it matches the deferral request
# Keep an eye on the logic in
`check_and_change_state_before_execution()`
# depending on self.next_method semantics
- self.state = State.DEFERRED
+ self.state = TaskInstanceState.DEFERRED
self.trigger_id = trigger_row.id
self.next_method = defer.method_name
self.next_kwargs = defer.kwargs or {}
@@ -1863,7 +1863,7 @@ class TaskInstance(Base, LoggingMixin):
)
# set state
- self.state = State.UP_FOR_RESCHEDULE
+ self.state = TaskInstanceState.UP_FOR_RESCHEDULE
# Decrement try_number so subsequent runs will use the same try number
and write
# to same log file.
@@ -1928,7 +1928,7 @@ class TaskInstance(Base, LoggingMixin):
Stats.incr("ti_failures", tags=self.stats_tags)
if not test_mode:
- session.add(Log(State.FAILED, self))
+ session.add(Log(TaskInstanceState.FAILED, self))
# Log failure duration
session.add(TaskFail(ti=self))
@@ -1962,7 +1962,7 @@ class TaskInstance(Base, LoggingMixin):
self.log.error("Unable to unmap task to determine if we need to
send an alert email")
if force_fail or not self.is_eligible_to_retry():
- self.state = State.FAILED
+ self.state = TaskInstanceState.FAILED
email_for_state = operator.attrgetter("email_on_failure")
callbacks = task.on_failure_callback if task else None
callback_type = "on_failure"
@@ -1971,10 +1971,10 @@ class TaskInstance(Base, LoggingMixin):
tis = self.get_dagrun(session).get_task_instances()
stop_all_tasks_in_dag(tis, session, self.task_id)
else:
- if self.state == State.QUEUED:
+ if self.state == TaskInstanceState.QUEUED:
# We increase the try_number so as to fail the task if it
fails to start after sometime
self._try_number += 1
- self.state = State.UP_FOR_RETRY
+ self.state = TaskInstanceState.UP_FOR_RETRY
email_for_state = operator.attrgetter("email_on_retry")
callbacks = task.on_retry_callback if task else None
callback_type = "on_retry"
@@ -1995,7 +1995,7 @@ class TaskInstance(Base, LoggingMixin):
def is_eligible_to_retry(self):
"""Is task instance is eligible for retry."""
- if self.state == State.RESTARTING:
+ if self.state == TaskInstanceState.RESTARTING:
# If a task is cleared when running, it goes into RESTARTING state
and is always
# eligible for retry
return True
@@ -2336,8 +2336,8 @@ class TaskInstance(Base, LoggingMixin):
'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
)
- # This function is called after changing the state from State.RUNNING,
- # so we need to subtract 1 from self.try_number here.
+ # This function is called after changing the state from RUNNING, so we
+ # need to subtract 1 from try_number here.
current_try_number = self.try_number - 1
additional_context: dict[str, Any] = {
"exception": exception,
@@ -2563,7 +2563,7 @@ class TaskInstance(Base, LoggingMixin):
num_running_task_instances_query = session.query(func.count()).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == self.task_id,
- TaskInstance.state == State.RUNNING,
+ TaskInstance.state == TaskInstanceState.RUNNING,
)
if same_dagrun:
num_running_task_instances_query =
num_running_task_instances_query.filter(
@@ -2885,7 +2885,7 @@ def _is_further_mapped_inside(operator: Operator,
container: TaskGroup) -> bool:
# State of the task instance.
# Stores string version of the task state.
-TaskInstanceStateType = Tuple[TaskInstanceKey, str]
+TaskInstanceStateType = Tuple[TaskInstanceKey, TaskInstanceState]
class SimpleTaskInstance:
diff --git a/airflow/operators/subdag.py b/airflow/operators/subdag.py
index de7afcb4ea..af784d5d32 100644
--- a/airflow/operators/subdag.py
+++ b/airflow/operators/subdag.py
@@ -36,7 +36,7 @@ from airflow.models.taskinstance import TaskInstance
from airflow.sensors.base import BaseSensorOperator
from airflow.utils.context import Context
from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.types import DagRunType
@@ -137,17 +137,17 @@ class SubDagOperator(BaseSensorOperator):
:param execution_date: Execution date to select task instances.
"""
with create_session() as session:
- dag_run.state = State.RUNNING
+ dag_run.state = DagRunState.RUNNING
session.merge(dag_run)
failed_task_instances = (
session.query(TaskInstance)
.filter(TaskInstance.dag_id == self.subdag.dag_id)
.filter(TaskInstance.execution_date == execution_date)
- .filter(TaskInstance.state.in_([State.FAILED,
State.UPSTREAM_FAILED]))
+ .filter(TaskInstance.state.in_((TaskInstanceState.FAILED,
TaskInstanceState.UPSTREAM_FAILED)))
)
for task_instance in failed_task_instances:
- task_instance.state = State.NONE
+ task_instance.state = None
session.merge(task_instance)
session.commit()
@@ -164,7 +164,7 @@ class SubDagOperator(BaseSensorOperator):
dag_run = self.subdag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=execution_date,
- state=State.RUNNING,
+ state=DagRunState.RUNNING,
conf=self.conf,
external_trigger=True,
data_interval=data_interval,
@@ -172,13 +172,13 @@ class SubDagOperator(BaseSensorOperator):
self.log.info("Created DagRun: %s", dag_run.run_id)
else:
self.log.info("Found existing DagRun: %s", dag_run.run_id)
- if dag_run.state == State.FAILED:
+ if dag_run.state == DagRunState.FAILED:
self._reset_dag_run_and_task_instances(dag_run, execution_date)
def poke(self, context: Context):
execution_date = context["execution_date"]
dag_run = self._get_dagrun(execution_date=execution_date)
- return dag_run.state != State.RUNNING
+ return dag_run.state != DagRunState.RUNNING
def post_execute(self, context, result=None):
super().post_execute(context)
@@ -186,7 +186,7 @@ class SubDagOperator(BaseSensorOperator):
dag_run = self._get_dagrun(execution_date=execution_date)
self.log.info("Execution finished. State is %s", dag_run.state)
- if dag_run.state != State.SUCCESS:
+ if dag_run.state != DagRunState.SUCCESS:
raise AirflowException(f"Expected state: SUCCESS. Actual state:
{dag_run.state}")
if self.propagate_skipped_state and
self._check_skipped_states(context):
@@ -196,9 +196,9 @@ class SubDagOperator(BaseSensorOperator):
leaves_tis = self._get_leaves_tis(context["execution_date"])
if self.propagate_skipped_state ==
SkippedStatePropagationOptions.ANY_LEAF:
- return any(ti.state == State.SKIPPED for ti in leaves_tis)
+ return any(ti.state == TaskInstanceState.SKIPPED for ti in
leaves_tis)
if self.propagate_skipped_state ==
SkippedStatePropagationOptions.ALL_LEAVES:
- return all(ti.state == State.SKIPPED for ti in leaves_tis)
+ return all(ti.state == TaskInstanceState.SKIPPED for ti in
leaves_tis)
raise AirflowException(
f"Unimplemented SkippedStatePropagationOptions
{self.propagate_skipped_state} used."
)
diff --git a/airflow/operators/trigger_dagrun.py
b/airflow/operators/trigger_dagrun.py
index f3560bffc7..a3a6bf7c7c 100644
--- a/airflow/operators/trigger_dagrun.py
+++ b/airflow/operators/trigger_dagrun.py
@@ -36,7 +36,7 @@ from airflow.utils import timezone
from airflow.utils.context import Context
from airflow.utils.helpers import build_airflow_url_with_query
from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType
XCOM_EXECUTION_DATE_ISO = "trigger_execution_date_iso"
@@ -116,8 +116,8 @@ class TriggerDagRunOperator(BaseOperator):
self.reset_dag_run = reset_dag_run
self.wait_for_completion = wait_for_completion
self.poke_interval = poke_interval
- self.allowed_states = allowed_states or [State.SUCCESS]
- self.failed_states = failed_states or [State.FAILED]
+ self.allowed_states = allowed_states or [DagRunState.SUCCESS]
+ self.failed_states = failed_states or [DagRunState.FAILED]
self._defer = deferrable
if execution_date is not None and not isinstance(execution_date, (str,
datetime.datetime)):
diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py
index 959ebe5131..158032a2cc 100644
--- a/airflow/sensors/external_task.py
+++ b/airflow/sensors/external_task.py
@@ -37,7 +37,7 @@ from airflow.utils.file import correct_maybe_zipped
from airflow.utils.helpers import build_airflow_url_with_query
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import tuple_in_condition
-from airflow.utils.state import State
+from airflow.utils.state import State, TaskInstanceState
if TYPE_CHECKING:
from sqlalchemy.orm import Query, Session
@@ -76,23 +76,25 @@ class ExternalTaskSensor(BaseSensorOperator):
without also having to clear the sensor).
By default, the ExternalTaskSensor will not skip if the external task
skips.
- To change this, simply set ``skipped_states=[State.SKIPPED]``. Note that if
- you are monitoring multiple tasks, and one enters error state and the other
- enters a skipped state, then the external task will react to whichever one
- it sees first. If both happen together, then the failed state takes
priority.
+ To change this, simply set ``skipped_states=[TaskInstanceState.SKIPPED]``.
+ Note that if you are monitoring multiple tasks, and one enters error state
+ and the other enters a skipped state, then the external task will react to
+ whichever one it sees first. If both happen together, then the failed state
+ takes priority.
It is possible to alter the default behavior by setting states which
- cause the sensor to fail, e.g. by setting ``allowed_states=[State.FAILED]``
- and ``failed_states=[State.SUCCESS]`` you will flip the behaviour to get a
- sensor which goes green when the external task *fails* and immediately goes
- red if the external task *succeeds*!
+ cause the sensor to fail, e.g. by setting
+ ``allowed_states=[TaskInstanceState.FAILED]`` and
+ ``failed_states=[TaskInstanceState.SUCCESS]``, you will flip the behaviour
+ to get a sensor which goes green when the external task *fails* and
+ immediately goes red if the external task *succeeds*!
Note that ``soft_fail`` is respected when examining the failed_states. Thus
if the external task enters a failed state and ``soft_fail == True`` the
sensor will _skip_ rather than fail. As a result, setting
``soft_fail=True``
- and ``failed_states=[State.SKIPPED]`` will result in the sensor skipping if
- the external task skips. However, this is a contrived example - consider
- using ``skipped_states`` if you would like this behaviour. Using
+ and ``failed_states=[TaskInstanceState.SKIPPED]`` will result in the sensor
+ skipping if the external task skips. However, this is a contrived example;
+ consider using ``skipped_states`` if you would like this behaviour. Using
``skipped_states`` allows the sensor to skip if the target fails, but still
enter failed state on timeout. Using ``soft_fail == True`` as above will
cause the sensor to skip if the target fails, but also if it times out.
@@ -146,7 +148,7 @@ class ExternalTaskSensor(BaseSensorOperator):
**kwargs,
):
super().__init__(**kwargs)
- self.allowed_states = list(allowed_states) if allowed_states else
[State.SUCCESS]
+ self.allowed_states = list(allowed_states) if allowed_states else
[TaskInstanceState.SUCCESS]
self.skipped_states = list(skipped_states) if skipped_states else []
self.failed_states = list(failed_states) if failed_states else []
diff --git a/airflow/sentry.py b/airflow/sentry.py
index 8742552e7e..fbc1715eed 100644
--- a/airflow/sentry.py
+++ b/airflow/sentry.py
@@ -25,7 +25,7 @@ from typing import TYPE_CHECKING
from airflow.configuration import conf
from airflow.executors.executor_loader import ExecutorLoader
from airflow.utils.session import find_session_idx, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from sqlalchemy.orm import Session
@@ -143,7 +143,7 @@ if conf.getboolean("sentry", "sentry_on", fallback=False):
return
dr = task_instance.get_dagrun(session)
task_instances = dr.get_task_instances(
- state={State.SUCCESS, State.FAILED},
+ state={TaskInstanceState.SUCCESS, TaskInstanceState.FAILED},
session=session,
)
diff --git a/airflow/ti_deps/dependencies_states.py
b/airflow/ti_deps/dependencies_states.py
index 543ce3528a..fd25d62f6d 100644
--- a/airflow/ti_deps/dependencies_states.py
+++ b/airflow/ti_deps/dependencies_states.py
@@ -16,38 +16,38 @@
# under the License.
from __future__ import annotations
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
EXECUTION_STATES = {
- State.RUNNING,
- State.QUEUED,
+ TaskInstanceState.RUNNING,
+ TaskInstanceState.QUEUED,
}
# In order to be able to get queued a task must have one of these states
SCHEDULEABLE_STATES = {
- State.NONE,
- State.UP_FOR_RETRY,
- State.UP_FOR_RESCHEDULE,
+ None,
+ TaskInstanceState.UP_FOR_RETRY,
+ TaskInstanceState.UP_FOR_RESCHEDULE,
}
RUNNABLE_STATES = {
# For cases like unit tests and run manually
- State.NONE,
- State.UP_FOR_RETRY,
- State.UP_FOR_RESCHEDULE,
+ None,
+ TaskInstanceState.UP_FOR_RETRY,
+ TaskInstanceState.UP_FOR_RESCHEDULE,
# For normal scheduler/backfill cases
- State.QUEUED,
+ TaskInstanceState.QUEUED,
}
QUEUEABLE_STATES = {
- State.SCHEDULED,
+ TaskInstanceState.SCHEDULED,
}
BACKFILL_QUEUEABLE_STATES = {
# For cases like unit tests and run manually
- State.NONE,
- State.UP_FOR_RESCHEDULE,
- State.UP_FOR_RETRY,
+ None,
+ TaskInstanceState.UP_FOR_RESCHEDULE,
+ TaskInstanceState.UP_FOR_RETRY,
# For normal backfill cases
- State.SCHEDULED,
+ TaskInstanceState.SCHEDULED,
}
diff --git a/airflow/ti_deps/deps/dagrun_exists_dep.py
b/airflow/ti_deps/deps/dagrun_exists_dep.py
index 781ab0ebaf..0a364628c7 100644
--- a/airflow/ti_deps/deps/dagrun_exists_dep.py
+++ b/airflow/ti_deps/deps/dagrun_exists_dep.py
@@ -19,7 +19,7 @@ from __future__ import annotations
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState
class DagrunRunningDep(BaseTIDep):
@@ -31,7 +31,7 @@ class DagrunRunningDep(BaseTIDep):
@provide_session
def _get_dep_statuses(self, ti, session, dep_context):
dr = ti.get_dagrun(session)
- if dr.state != State.RUNNING:
+ if dr.state != DagRunState.RUNNING:
yield self._failing_status(
reason=f"Task instance's dagrun was not in the 'running' state
but in the state '{dr.state}'."
)
diff --git a/airflow/ti_deps/deps/not_in_retry_period_dep.py
b/airflow/ti_deps/deps/not_in_retry_period_dep.py
index b3b5d4ec56..90954f29f2 100644
--- a/airflow/ti_deps/deps/not_in_retry_period_dep.py
+++ b/airflow/ti_deps/deps/not_in_retry_period_dep.py
@@ -20,7 +20,7 @@ from __future__ import annotations
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils import timezone
from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
class NotInRetryPeriodDep(BaseTIDep):
@@ -38,7 +38,7 @@ class NotInRetryPeriodDep(BaseTIDep):
)
return
- if ti.state != State.UP_FOR_RETRY:
+ if ti.state != TaskInstanceState.UP_FOR_RETRY:
yield self._passing_status(reason="The task instance was not
marked for retrying.")
return
diff --git a/airflow/ti_deps/deps/not_previously_skipped_dep.py
b/airflow/ti_deps/deps/not_previously_skipped_dep.py
index fdc8274d90..855f04af53 100644
--- a/airflow/ti_deps/deps/not_previously_skipped_dep.py
+++ b/airflow/ti_deps/deps/not_previously_skipped_dep.py
@@ -40,7 +40,7 @@ class NotPreviouslySkippedDep(BaseTIDep):
XCOM_SKIPMIXIN_SKIPPED,
SkipMixin,
)
- from airflow.utils.state import State
+ from airflow.utils.state import TaskInstanceState
upstream = ti.task.get_direct_relatives(upstream=True)
@@ -87,7 +87,7 @@ class NotPreviouslySkippedDep(BaseTIDep):
reason=("Task should be skipped but the the
past depends are not met")
)
return
- ti.set_state(State.SKIPPED, session)
+ ti.set_state(TaskInstanceState.SKIPPED, session)
yield self._failing_status(
reason=f"Skipping because of previous XCom result from
parent task {parent.task_id}"
)
diff --git a/airflow/ti_deps/deps/prev_dagrun_dep.py
b/airflow/ti_deps/deps/prev_dagrun_dep.py
index b3165eb7f2..62acdbca33 100644
--- a/airflow/ti_deps/deps/prev_dagrun_dep.py
+++ b/airflow/ti_deps/deps/prev_dagrun_dep.py
@@ -22,7 +22,7 @@ from sqlalchemy import func
from airflow.models.taskinstance import PAST_DEPENDS_MET, TaskInstance as TI
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
class PrevDagrunDep(BaseTIDep):
@@ -107,7 +107,7 @@ class PrevDagrunDep(BaseTIDep):
)
return
- if previous_ti.state not in {State.SKIPPED, State.SUCCESS}:
+ if previous_ti.state not in {TaskInstanceState.SKIPPED,
TaskInstanceState.SUCCESS}:
yield self._failing_status(
reason=(
f"depends_on_past is true for this task, but the previous
task instance {previous_ti} "
diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py
b/airflow/ti_deps/deps/ready_to_reschedule.py
index 66aa5c5613..8394907081 100644
--- a/airflow/ti_deps/deps/ready_to_reschedule.py
+++ b/airflow/ti_deps/deps/ready_to_reschedule.py
@@ -22,7 +22,7 @@ from airflow.models.taskreschedule import TaskReschedule
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils import timezone
from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
class ReadyToRescheduleDep(BaseTIDep):
@@ -31,7 +31,7 @@ class ReadyToRescheduleDep(BaseTIDep):
NAME = "Ready To Reschedule"
IGNORABLE = True
IS_TASK_DEP = True
- RESCHEDULEABLE_STATES = {State.UP_FOR_RESCHEDULE, State.NONE}
+ RESCHEDULEABLE_STATES = {None, TaskInstanceState.UP_FOR_RESCHEDULE}
@provide_session
def _get_dep_statuses(self, ti, session, dep_context):
diff --git a/airflow/ti_deps/deps/task_not_running_dep.py
b/airflow/ti_deps/deps/task_not_running_dep.py
index fd76873466..7299a4f3e2 100644
--- a/airflow/ti_deps/deps/task_not_running_dep.py
+++ b/airflow/ti_deps/deps/task_not_running_dep.py
@@ -20,7 +20,7 @@ from __future__ import annotations
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
class TaskNotRunningDep(BaseTIDep):
@@ -37,7 +37,7 @@ class TaskNotRunningDep(BaseTIDep):
@provide_session
def _get_dep_statuses(self, ti, session, dep_context=None):
- if ti.state != State.RUNNING:
+ if ti.state != TaskInstanceState.RUNNING:
yield self._passing_status(reason="Task is not in running state.")
return
diff --git a/airflow/utils/dot_renderer.py b/airflow/utils/dot_renderer.py
index 3f35d1b21d..d9c2acce80 100644
--- a/airflow/utils/dot_renderer.py
+++ b/airflow/utils/dot_renderer.py
@@ -58,7 +58,7 @@ def _draw_task(
) -> None:
"""Draw a single task on the given parent_graph."""
if states_by_task_id:
- state = states_by_task_id.get(task.task_id, State.NONE)
+ state = states_by_task_id.get(task.task_id, None)
color = State.color_fg(state)
fill_color = State.color(state)
else:
diff --git a/airflow/utils/log/file_task_handler.py
b/airflow/utils/log/file_task_handler.py
index 0c98d25503..9184a20420 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -341,7 +341,10 @@ class FileTaskHandler(logging.Handler):
)
log_pos = len(logs)
messages = "".join([f"*** {x}\n" for x in messages_list])
- end_of_log = ti.try_number != try_number or ti.state not in
[State.RUNNING, State.DEFERRED]
+ end_of_log = ti.try_number != try_number or ti.state not in (
+ TaskInstanceState.RUNNING,
+ TaskInstanceState.DEFERRED,
+ )
if metadata and "log_pos" in metadata:
previous_chars = metadata["log_pos"]
logs = logs[previous_chars:] # Cut off previously passed log test
as new tail
diff --git a/airflow/utils/log/log_reader.py b/airflow/utils/log/log_reader.py
index a4589ebca0..e8c3b9897a 100644
--- a/airflow/utils/log/log_reader.py
+++ b/airflow/utils/log/log_reader.py
@@ -28,7 +28,7 @@ from airflow.models.taskinstance import TaskInstance
from airflow.utils.helpers import render_log_filename
from airflow.utils.log.logging_mixin import ExternalLoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
class TaskLogReader:
@@ -81,19 +81,26 @@ class TaskLogReader:
metadata.pop("max_offset", None)
metadata.pop("offset", None)
metadata.pop("log_pos", None)
+
while True:
logs, metadata = self.read_log_chunks(ti, current_try_number,
metadata)
for host, log in logs[0]:
yield "\n".join([host or "", log]) + "\n"
- if "end_of_log" not in metadata or (
- not metadata["end_of_log"] and ti.state not in
[State.RUNNING, State.DEFERRED]
- ):
- if not logs[0]:
- # we did not receive any logs in this loop
- # sleeping to conserve resources / limit requests on
external services
- time.sleep(self.STREAM_LOOP_SLEEP_SECONDS)
+ try:
+ end_of_log = bool(metadata["end_of_log"])
+ except KeyError:
+ continue_read = True
else:
+ continue_read = not end_of_log and ti.state not in (
+ TaskInstanceState.RUNNING,
+ TaskInstanceState.DEFERRED,
+ )
+ if not continue_read:
break
+ if not logs[0]:
+ # we did not receive any logs in this loop
+ # sleeping to conserve resources / limit requests on
external services
+ time.sleep(self.STREAM_LOOP_SLEEP_SECONDS)
@cached_property
def log_handler(self):
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index f4a8dc1a0a..b6297dfb71 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -98,14 +98,9 @@ class State:
finished_dr_states: frozenset[DagRunState] =
frozenset([DagRunState.SUCCESS, DagRunState.FAILED])
unfinished_dr_states: frozenset[DagRunState] =
frozenset([DagRunState.QUEUED, DagRunState.RUNNING])
- task_states: tuple[TaskInstanceState | None, ...] = (None,) +
tuple(TaskInstanceState)
+ task_states: tuple[TaskInstanceState | None, ...] = (None,
*TaskInstanceState)
- dag_states: tuple[DagRunState, ...] = (
- DagRunState.QUEUED,
- DagRunState.SUCCESS,
- DagRunState.RUNNING,
- DagRunState.FAILED,
- )
+ dag_states: tuple[DagRunState, ...] = tuple(DagRunState)
state_color: dict[TaskInstanceState | None, str] = {
None: "lightblue",
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 25fc1a28f9..58836beefe 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -85,8 +85,10 @@ def get_instance_with_map(task_instance, session):
return get_mapped_summary(task_instance, mapped_instances)
-def get_try_count(try_number: int, state: State):
- return try_number + 1 if state in [State.DEFERRED,
State.UP_FOR_RESCHEDULE] else try_number
+def get_try_count(try_number: int, state: TaskInstanceState) -> int:
+ if state in (TaskInstanceState.DEFERRED,
TaskInstanceState.UP_FOR_RESCHEDULE):
+ return try_number + 1
+ return try_number
priority: list[None | TaskInstanceState] = [
@@ -426,7 +428,7 @@ def task_instance_link(attr):
def state_token(state):
- """Returns a formatted string with HTML for a given State."""
+ """Returns a formatted string with HTML for a given state."""
color = State.color(state)
fg_color = State.color_fg(state)
return Markup(
@@ -438,7 +440,7 @@ def state_token(state):
def state_f(attr):
- """Gets 'state' & returns a formatted string with HTML for a given
State."""
+ """Gets 'state' & returns a formatted string with HTML for a given
state."""
state = attr.get("state")
return state_token(state)
diff --git a/airflow/www/views.py b/airflow/www/views.py
index f50df518fe..3da11f5681 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -754,7 +754,7 @@ class Airflow(AirflowBaseView):
# find DAGs which have a RUNNING DagRun
running_dags = dags_query.join(DagRun, DagModel.dag_id ==
DagRun.dag_id).filter(
- DagRun.state == State.RUNNING
+ DagRun.state == DagRunState.RUNNING
)
# find DAGs for which the latest DagRun is FAILED
@@ -765,7 +765,7 @@ class Airflow(AirflowBaseView):
)
subq_failed = (
session.query(DagRun.dag_id,
func.max(DagRun.start_date).label("start_date"))
- .filter(DagRun.state == State.FAILED)
+ .filter(DagRun.state == DagRunState.FAILED)
.group_by(DagRun.dag_id)
.subquery()
)
@@ -1101,7 +1101,7 @@ class Airflow(AirflowBaseView):
running_dag_run_query_result = (
session.query(DagRun.dag_id, DagRun.run_id)
.join(DagModel, DagModel.dag_id == DagRun.dag_id)
- .filter(DagRun.state == State.RUNNING, DagModel.is_active)
+ .filter(DagRun.state == DagRunState.RUNNING, DagModel.is_active)
)
running_dag_run_query_result =
running_dag_run_query_result.filter(DagRun.dag_id.in_(filter_dag_ids))
@@ -1125,7 +1125,7 @@ class Airflow(AirflowBaseView):
last_dag_run = (
session.query(DagRun.dag_id,
sqla.func.max(DagRun.execution_date).label("execution_date"))
.join(DagModel, DagModel.dag_id == DagRun.dag_id)
- .filter(DagRun.state != State.RUNNING, DagModel.is_active)
+ .filter(DagRun.state != DagRunState.RUNNING,
DagModel.is_active)
.group_by(DagRun.dag_id)
)
@@ -1820,7 +1820,7 @@ class Airflow(AirflowBaseView):
"Airflow administrator for assistance.".format(
"- This task instance already ran and had it's state
changed manually "
"(e.g. cleared in the UI)<br>"
- if ti and ti.state == State.NONE
+ if ti and ti.state is None
else ""
),
)
@@ -2119,7 +2119,7 @@ class Airflow(AirflowBaseView):
run_type=DagRunType.MANUAL,
execution_date=execution_date,
data_interval=dag.timetable.infer_manual_data_interval(run_after=execution_date),
- state=State.QUEUED,
+ state=DagRunState.QUEUED,
conf=run_conf,
external_trigger=True,
dag_hash=get_airflow_app().dag_bag.dags_hash.get(dag_id),
@@ -5337,14 +5337,14 @@ class
DagRunModelView(AirflowPrivilegeVerifierModelView):
@action_logging
def action_set_queued(self, drs: list[DagRun]):
"""Set state to queued."""
- return self._set_dag_runs_to_active_state(drs, State.QUEUED)
+ return self._set_dag_runs_to_active_state(drs, DagRunState.QUEUED)
@action("set_running", "Set state to 'running'", "", single=False)
@action_has_dag_edit_access
@action_logging
def action_set_running(self, drs: list[DagRun]):
"""Set state to running."""
- return self._set_dag_runs_to_active_state(drs, State.RUNNING)
+ return self._set_dag_runs_to_active_state(drs, DagRunState.RUNNING)
@provide_session
def _set_dag_runs_to_active_state(self, drs: list[DagRun], state: str,
session: Session = NEW_SESSION):
@@ -5353,7 +5353,7 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView):
count = 0
for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for
dagrun in drs)):
count += 1
- if state == State.RUNNING:
+ if state == DagRunState.RUNNING:
dr.start_date = timezone.utcnow()
dr.state = state
session.commit()
@@ -5790,7 +5790,7 @@ class
TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
@action_logging
def action_set_running(self, tis):
"""Set state to 'running'."""
- self.set_task_instance_state(tis, State.RUNNING)
+ self.set_task_instance_state(tis, TaskInstanceState.RUNNING)
self.update_redirect()
return redirect(self.get_redirect())
@@ -5799,7 +5799,7 @@ class
TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
@action_logging
def action_set_failed(self, tis):
"""Set state to 'failed'."""
- self.set_task_instance_state(tis, State.FAILED)
+ self.set_task_instance_state(tis, TaskInstanceState.FAILED)
self.update_redirect()
return redirect(self.get_redirect())
@@ -5808,7 +5808,7 @@ class
TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
@action_logging
def action_set_success(self, tis):
"""Set state to 'success'."""
- self.set_task_instance_state(tis, State.SUCCESS)
+ self.set_task_instance_state(tis, TaskInstanceState.SUCCESS)
self.update_redirect()
return redirect(self.get_redirect())
@@ -5817,7 +5817,7 @@ class
TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
@action_logging
def action_set_retry(self, tis):
"""Set state to 'up_for_retry'."""
- self.set_task_instance_state(tis, State.UP_FOR_RETRY)
+ self.set_task_instance_state(tis, TaskInstanceState.UP_FOR_RETRY)
self.update_redirect()
return redirect(self.get_redirect())