This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new fa3be084a1 More strong typed state conversion (#32521)
fa3be084a1 is described below
commit fa3be084a1cd84242893a3367ac2c0c4d3a4f480
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Jul 12 05:54:36 2023 +0800
More strong typed state conversion (#32521)
---
airflow/api/common/delete_dag.py | 4 +-
airflow/api/common/mark_tasks.py | 35 +++++++++++----
.../endpoints/task_instance_endpoint.py | 12 +++---
airflow/dag_processing/processor.py | 10 +++--
airflow/executors/local_executor.py | 16 +++----
airflow/jobs/backfill_job_runner.py | 6 +--
airflow/jobs/local_task_job_runner.py | 6 +--
airflow/jobs/scheduler_job_runner.py | 18 ++++----
airflow/listeners/spec/taskinstance.py | 12 +++---
airflow/models/dag.py | 14 +++---
airflow/models/dagrun.py | 42 +++++++++---------
airflow/models/pool.py | 8 ++--
airflow/models/skipmixin.py | 4 +-
airflow/models/taskinstance.py | 50 +++++++++++-----------
airflow/operators/subdag.py | 20 ++++-----
airflow/sentry.py | 21 +++++----
airflow/ti_deps/deps/dagrun_exists_dep.py | 4 +-
airflow/ti_deps/deps/not_in_retry_period_dep.py | 4 +-
airflow/ti_deps/deps/not_previously_skipped_dep.py | 4 +-
airflow/ti_deps/deps/prev_dagrun_dep.py | 4 +-
airflow/ti_deps/deps/ready_to_reschedule.py | 4 +-
airflow/ti_deps/deps/task_not_running_dep.py | 4 +-
airflow/utils/dot_renderer.py | 2 +-
airflow/utils/log/file_task_handler.py | 5 ++-
airflow/utils/log/log_reader.py | 5 ++-
airflow/utils/state.py | 6 +--
airflow/www/utils.py | 4 +-
airflow/www/views.py | 40 ++++++++++-------
dev/perf/scheduler_dag_execution_timing.py | 8 ++--
29 files changed, 204 insertions(+), 168 deletions(-)
diff --git a/airflow/api/common/delete_dag.py b/airflow/api/common/delete_dag.py
index 45611729ea..1d879a667a 100644
--- a/airflow/api/common/delete_dag.py
+++ b/airflow/api/common/delete_dag.py
@@ -29,7 +29,7 @@ from airflow.models import DagModel, TaskFail
from airflow.models.serialized_dag import SerializedDagModel
from airflow.utils.db import get_sqla_model_classes
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
log = logging.getLogger(__name__)
@@ -50,7 +50,7 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True,
session: Session =
running_tis = session.scalar(
select(models.TaskInstance.state)
.where(models.TaskInstance.dag_id == dag_id)
- .where(models.TaskInstance.state == State.RUNNING)
+ .where(models.TaskInstance.state == TaskInstanceState.RUNNING)
.limit(1)
)
if running_tis:
diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index b965237bdf..4d2df78e82 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -155,7 +155,7 @@ def set_state(
for task_instance in tis_altered:
# The try_number was decremented when setting to up_for_reschedule
and deferred.
# Increment it back when changing the state again
- if task_instance.state in [State.DEFERRED,
State.UP_FOR_RESCHEDULE]:
+ if task_instance.state in (TaskInstanceState.DEFERRED,
TaskInstanceState.UP_FOR_RESCHEDULE):
task_instance._try_number += 1
task_instance.set_state(state, session=session)
session.flush()
@@ -362,7 +362,7 @@ def _set_dag_run_state(dag_id: str, run_id: str, state:
DagRunState, session: SA
select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
).scalar_one()
dag_run.state = state
- if state == State.RUNNING:
+ if state == DagRunState.RUNNING:
dag_run.start_date = timezone.utcnow()
dag_run.end_date = None
else:
@@ -415,7 +415,13 @@ def set_dag_run_state_to_success(
# Mark all task instances of the dag run to success.
for task in dag.tasks:
task.dag = dag
- return set_state(tasks=dag.tasks, run_id=run_id, state=State.SUCCESS,
commit=commit, session=session)
+ return set_state(
+ tasks=dag.tasks,
+ run_id=run_id,
+ state=TaskInstanceState.SUCCESS,
+ commit=commit,
+ session=session,
+ )
@provide_session
@@ -461,6 +467,12 @@ def set_dag_run_state_to_failed(
if commit:
_set_dag_run_state(dag.dag_id, run_id, DagRunState.FAILED, session)
+ running_states = (
+ TaskInstanceState.RUNNING,
+ TaskInstanceState.DEFERRED,
+ TaskInstanceState.UP_FOR_RESCHEDULE,
+ )
+
# Mark only RUNNING task instances.
task_ids = [task.task_id for task in dag.tasks]
tis = session.scalars(
@@ -468,7 +480,7 @@ def set_dag_run_state_to_failed(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id == run_id,
TaskInstance.task_id.in_(task_ids),
- TaskInstance.state.in_([State.RUNNING, State.DEFERRED,
State.UP_FOR_RESCHEDULE]),
+ TaskInstance.state.in_(running_states),
)
)
@@ -487,16 +499,21 @@ def set_dag_run_state_to_failed(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id == run_id,
TaskInstance.state.not_in(State.finished),
- TaskInstance.state.not_in([State.RUNNING, State.DEFERRED,
State.UP_FOR_RESCHEDULE]),
+ TaskInstance.state.not_in(running_states),
)
- )
+ ).all()
- tis = [ti for ti in tis]
if commit:
for ti in tis:
- ti.set_state(State.SKIPPED)
+ ti.set_state(TaskInstanceState.SKIPPED)
- return tis + set_state(tasks=tasks, run_id=run_id, state=State.FAILED,
commit=commit, session=session)
+ return tis + set_state(
+ tasks=tasks,
+ run_id=run_id,
+ state=TaskInstanceState.FAILED,
+ commit=commit,
+ session=session,
+ )
def __set_dag_run_state_to_running_or_queued(
diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py
b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index 3028b0bb73..2496433a26 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -49,7 +49,7 @@ from airflow.models.taskinstance import TaskInstance as TI,
clear_task_instances
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import DagRunState, State
+from airflow.utils.state import DagRunState, TaskInstanceState
T = TypeVar("T")
@@ -187,7 +187,7 @@ def get_mapped_task_instances(
) -> APIResponse:
"""Get list of task instances."""
# Because state can be 'none'
- states = _convert_state(state)
+ states = _convert_ti_states(state)
base_query = (
select(TI)
@@ -264,10 +264,10 @@ def get_mapped_task_instances(
)
-def _convert_state(states: Iterable[str] | None) -> list[str | None] | None:
+def _convert_ti_states(states: Iterable[str] | None) -> list[TaskInstanceState
| None] | None:
if not states:
return None
- return [State.NONE if s == "none" else s for s in states]
+ return [None if s == "none" else TaskInstanceState(s) for s in states]
def _apply_array_filter(query: Select, key: ClauseElement, values:
Iterable[Any] | None) -> Select:
@@ -329,7 +329,7 @@ def get_task_instances(
) -> APIResponse:
"""Get list of task instances."""
# Because state can be 'none'
- states = _convert_state(state)
+ states = _convert_ti_states(state)
base_query = select(TI).join(TI.dag_run)
@@ -395,7 +395,7 @@ def get_task_instances_batch(session: Session =
NEW_SESSION) -> APIResponse:
data = task_instance_batch_form.load(body)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))
- states = _convert_state(data["state"])
+ states = _convert_ti_states(data["state"])
base_query = select(TI).join(TI.dag_run)
base_query = _apply_array_filter(base_query, key=TI.dag_id,
values=data["dag_ids"])
diff --git a/airflow/dag_processing/processor.py
b/airflow/dag_processing/processor.py
index 5c175571c0..369f676878 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -56,7 +56,7 @@ from airflow.utils.file import iter_airflow_imports,
might_contain_dag
from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter,
set_context
from airflow.utils.mixins import MultiprocessingStartMethodMixin
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from airflow.models.operator import Operator
@@ -433,7 +433,7 @@ class DagFileProcessor(LoggingMixin):
session.query(TI.task_id,
func.max(DR.execution_date).label("max_ti"))
.join(TI.dag_run)
.filter(TI.dag_id == dag.dag_id)
- .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED))
+ .filter(or_(TI.state == TaskInstanceState.SUCCESS, TI.state ==
TaskInstanceState.SKIPPED))
.filter(TI.task_id.in_(dag.task_ids))
.group_by(TI.task_id)
.subquery("sq")
@@ -500,7 +500,11 @@ class DagFileProcessor(LoggingMixin):
sla_dates: list[datetime] = [sla.execution_date for sla in slas]
fetched_tis: list[TI] = (
session.query(TI)
- .filter(TI.state != State.SUCCESS,
TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id)
+ .filter(
+ TI.dag_id == dag.dag_id,
+ TI.execution_date.in_(sla_dates),
+ TI.state != TaskInstanceState.SUCCESS,
+ )
.all()
)
blocking_tis: list[TI] = []
diff --git a/airflow/executors/local_executor.py
b/airflow/executors/local_executor.py
index ca54a387c8..7f83f8c7a2 100644
--- a/airflow/executors/local_executor.py
+++ b/airflow/executors/local_executor.py
@@ -39,7 +39,7 @@ from airflow import settings
from airflow.exceptions import AirflowException
from airflow.executors.base_executor import PARALLELISM, BaseExecutor
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from airflow.executors.base_executor import CommandType
@@ -94,20 +94,20 @@ class LocalWorkerBase(Process, LoggingMixin):
# Remove the command since the worker is done executing the task
setproctitle("airflow worker -- LocalExecutor")
- def _execute_work_in_subprocess(self, command: CommandType) -> str:
+ def _execute_work_in_subprocess(self, command: CommandType) ->
TaskInstanceState:
try:
subprocess.check_call(command, close_fds=True)
- return State.SUCCESS
+ return TaskInstanceState.SUCCESS
except subprocess.CalledProcessError as e:
self.log.error("Failed to execute task %s.", str(e))
- return State.FAILED
+ return TaskInstanceState.FAILED
- def _execute_work_in_fork(self, command: CommandType) -> str:
+ def _execute_work_in_fork(self, command: CommandType) -> TaskInstanceState:
pid = os.fork()
if pid:
# In parent, wait for the child
pid, ret = os.waitpid(pid, 0)
- return State.SUCCESS if ret == 0 else State.FAILED
+ return TaskInstanceState.SUCCESS if ret == 0 else
TaskInstanceState.FAILED
from airflow.sentry import Sentry
@@ -130,10 +130,10 @@ class LocalWorkerBase(Process, LoggingMixin):
args.func(args)
ret = 0
- return State.SUCCESS
+ return TaskInstanceState.SUCCESS
except Exception as e:
self.log.exception("Failed to execute task %s.", e)
- return State.FAILED
+ return TaskInstanceState.FAILED
finally:
Sentry.flush()
logging.shutdown()
diff --git a/airflow/jobs/backfill_job_runner.py
b/airflow/jobs/backfill_job_runner.py
index 5b13490be7..35910b83ae 100644
--- a/airflow/jobs/backfill_job_runner.py
+++ b/airflow/jobs/backfill_job_runner.py
@@ -68,7 +68,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin):
job_type = "BackfillJob"
- STATES_COUNT_AS_RUNNING = (State.RUNNING, State.QUEUED)
+ STATES_COUNT_AS_RUNNING = (TaskInstanceState.RUNNING,
TaskInstanceState.QUEUED)
@attr.define
class _DagRunTaskStatus:
@@ -219,7 +219,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin):
# is changed externally, e.g. by clearing tasks from the ui. We
need to cover
# for that as otherwise those tasks would fall outside the scope of
# the backfill suddenly.
- elif ti.state == State.NONE:
+ elif ti.state is None:
self.log.warning(
"FIXME: task instance %s state was set to none externally
or "
"reaching concurrency limits. Re-adding task to queue.",
@@ -1000,7 +1000,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin):
).all()
for ti in reset_tis:
- ti.state = State.NONE
+ ti.state = None
session.merge(ti)
return result + reset_tis
diff --git a/airflow/jobs/local_task_job_runner.py
b/airflow/jobs/local_task_job_runner.py
index 9f6a4b55e8..fd234e4150 100644
--- a/airflow/jobs/local_task_job_runner.py
+++ b/airflow/jobs/local_task_job_runner.py
@@ -35,7 +35,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.platform import IS_WINDOWS
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
SIGSEGV_MESSAGE = """
******************************************* Received SIGSEGV
*******************************************
@@ -243,7 +243,7 @@ class LocalTaskJobRunner(BaseJobRunner["Job |
JobPydantic"], LoggingMixin):
self.task_instance.refresh_from_db()
ti = self.task_instance
- if ti.state == State.RUNNING:
+ if ti.state == TaskInstanceState.RUNNING:
fqdn = get_hostname()
same_hostname = fqdn == ti.hostname
if not same_hostname:
@@ -273,7 +273,7 @@ class LocalTaskJobRunner(BaseJobRunner["Job |
JobPydantic"], LoggingMixin):
)
raise AirflowException("PID of job runner does not match")
elif self.task_runner.return_code() is None and
hasattr(self.task_runner, "process"):
- if ti.state == State.SKIPPED:
+ if ti.state == TaskInstanceState.SKIPPED:
# A DagRun timeout will cause tasks to be externally marked as
skipped.
dagrun = ti.get_dagrun(session=session)
execution_time = (dagrun.end_date or timezone.utcnow()) -
dagrun.start_date
diff --git a/airflow/jobs/scheduler_job_runner.py
b/airflow/jobs/scheduler_job_runner.py
index 8e9db65647..7b13ecc300 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -610,7 +610,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
)
for ti in executable_tis:
- ti.emit_state_change_metric(State.QUEUED)
+ ti.emit_state_change_metric(TaskInstanceState.QUEUED)
for ti in executable_tis:
make_transient(ti)
@@ -626,7 +626,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
# actually enqueue them
for ti in task_instances:
if ti.dag_run.state in State.finished:
- ti.set_state(State.NONE, session=session)
+ ti.set_state(None, session=session)
continue
command = ti.command_as_list(
local=True,
@@ -681,14 +681,12 @@ class SchedulerJobRunner(BaseJobRunner[Job],
LoggingMixin):
tis_with_right_state: list[TaskInstanceKey] = []
# Report execution
- for ti_key, value in event_buffer.items():
- state: str
- state, _ = value
+ for ti_key, (state, _) in event_buffer.items():
# We create map (dag_id, task_id, execution_date) -> in-memory
try_number
ti_primary_key_to_try_number_map[ti_key.primary] =
ti_key.try_number
self.log.info("Received executor event with state %s for task
instance %s", state, ti_key)
- if state in (State.FAILED, State.SUCCESS, State.QUEUED):
+ if state in (TaskInstanceState.FAILED, TaskInstanceState.SUCCESS,
TaskInstanceState.QUEUED):
tis_with_right_state.append(ti_key)
# Return if no finished tasks
@@ -712,7 +710,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
buffer_key = ti.key.with_try_number(try_number)
state, info = event_buffer.pop(buffer_key)
- if state == State.QUEUED:
+ if state == TaskInstanceState.QUEUED:
ti.external_executor_id = info
self.log.info("Setting external_id for %s to %s", ti, info)
continue
@@ -1532,7 +1530,7 @@ class SchedulerJobRunner(BaseJobRunner[Job],
LoggingMixin):
tasks_stuck_in_queued = session.scalars(
select(TI).where(
- TI.state == State.QUEUED,
+ TI.state == TaskInstanceState.QUEUED,
TI.queued_dttm < (timezone.utcnow() -
timedelta(seconds=self._task_queued_timeout)),
TI.queued_by_job_id == self.job.id,
)
@@ -1611,7 +1609,7 @@ class SchedulerJobRunner(BaseJobRunner[Job],
LoggingMixin):
.join(TI.dag_run)
.where(
DagRun.run_type != DagRunType.BACKFILL_JOB,
- DagRun.state == State.RUNNING,
+ DagRun.state == DagRunState.RUNNING,
)
.options(load_only(TI.dag_id, TI.task_id, TI.run_id))
)
@@ -1626,7 +1624,7 @@ class SchedulerJobRunner(BaseJobRunner[Job],
LoggingMixin):
reset_tis_message = []
for ti in to_reset:
reset_tis_message.append(repr(ti))
- ti.state = State.NONE
+ ti.state = None
ti.queued_by_job_id = None
for ti in set(tis_to_reset_or_adopt) - set(to_reset):
diff --git a/airflow/listeners/spec/taskinstance.py
b/airflow/listeners/spec/taskinstance.py
index 78de8a5f62..b87043a99d 100644
--- a/airflow/listeners/spec/taskinstance.py
+++ b/airflow/listeners/spec/taskinstance.py
@@ -32,20 +32,20 @@ hookspec = HookspecMarker("airflow")
@hookspec
def on_task_instance_running(
- previous_state: TaskInstanceState, task_instance: TaskInstance, session:
Session | None
+ previous_state: TaskInstanceState | None, task_instance: TaskInstance,
session: Session | None
):
- """Called when task state changes to RUNNING. Previous_state can be
State.NONE."""
+ """Called when task state changes to RUNNING. previous_state can be
None."""
@hookspec
def on_task_instance_success(
- previous_state: TaskInstanceState, task_instance: TaskInstance, session:
Session | None
+ previous_state: TaskInstanceState | None, task_instance: TaskInstance,
session: Session | None
):
- """Called when task state changes to SUCCESS. Previous_state can be
State.NONE."""
+ """Called when task state changes to SUCCESS. previous_state can be
None."""
@hookspec
def on_task_instance_failed(
- previous_state: TaskInstanceState, task_instance: TaskInstance, session:
Session | None
+ previous_state: TaskInstanceState | None, task_instance: TaskInstance,
session: Session | None
):
- """Called when task state changes to FAIL. Previous_state can be
State.NONE."""
+ """Called when task state changes to FAIL. previous_state can be None."""
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 46e7f4608e..391be9e582 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -127,7 +127,7 @@ from airflow.utils.sqlalchemy import (
tuple_in_condition,
with_row_locks,
)
-from airflow.utils.state import DagRunState, State, TaskInstanceState
+from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
if TYPE_CHECKING:
@@ -1416,7 +1416,7 @@ class DAG(LoggingMixin):
:return: List of execution dates
"""
- runs = DagRun.find(dag_id=self.dag_id, state=State.RUNNING)
+ runs = DagRun.find(dag_id=self.dag_id, state=DagRunState.RUNNING)
active_dates = []
for run in runs:
@@ -2118,7 +2118,7 @@ class DAG(LoggingMixin):
@provide_session
def set_dag_runs_state(
self,
- state: str = State.RUNNING,
+ state: DagRunState = DagRunState.RUNNING,
session: Session = NEW_SESSION,
start_date: datetime | None = None,
end_date: datetime | None = None,
@@ -2199,12 +2199,12 @@ class DAG(LoggingMixin):
stacklevel=2,
)
- state = []
+ state: list[TaskInstanceState] = []
if only_failed:
- state += [State.FAILED, State.UPSTREAM_FAILED]
+ state += [TaskInstanceState.FAILED,
TaskInstanceState.UPSTREAM_FAILED]
if only_running:
# Yes, having `+=` doesn't make sense, but this was the existing
behaviour
- state += [State.RUNNING]
+ state += [TaskInstanceState.RUNNING]
tis = self._get_task_instances(
task_ids=task_ids,
@@ -2742,7 +2742,7 @@ class DAG(LoggingMixin):
# Instead of starting a scheduler, we run the minimal loop possible to
check
# for task readiness and dependency management. This is notably faster
# than creating a BackfillJob and allows us to surface logs to the user
- while dr.state == State.RUNNING:
+ while dr.state == DagRunState.RUNNING:
schedulable_tis, _ = dr.update_state(session=session)
try:
for ti in schedulable_tis:
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index f0b29881a8..e3ed0bda00 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -115,7 +115,7 @@ class DagRun(Base, LoggingMixin):
execution_date = Column(UtcDateTime, default=timezone.utcnow,
nullable=False)
start_date = Column(UtcDateTime)
end_date = Column(UtcDateTime)
- _state = Column("state", String(50), default=State.QUEUED)
+ _state = Column("state", String(50), default=DagRunState.QUEUED)
run_id = Column(StringID(), nullable=False)
creating_job_id = Column(Integer)
external_trigger = Column(Boolean, default=True)
@@ -222,7 +222,7 @@ class DagRun(Base, LoggingMixin):
if state is not None:
self.state = state
if queued_at is NOTSET:
- self.queued_at = timezone.utcnow() if state == State.QUEUED else
None
+ self.queued_at = timezone.utcnow() if state == DagRunState.QUEUED
else None
else:
self.queued_at = queued_at
self.run_type = run_type
@@ -265,13 +265,13 @@ class DagRun(Base, LoggingMixin):
def get_state(self):
return self._state
- def set_state(self, state: DagRunState):
+ def set_state(self, state: DagRunState) -> None:
if state not in State.dag_states:
raise ValueError(f"invalid DagRun state: {state}")
if self._state != state:
self._state = state
self.end_date = timezone.utcnow() if self._state in State.finished
else None
- if state == State.QUEUED:
+ if state == DagRunState.QUEUED:
self.queued_at = timezone.utcnow()
@declared_attr
@@ -306,9 +306,9 @@ class DagRun(Base, LoggingMixin):
# because SQLAlchemy doesn't accept a set here.
query = query.where(cls.dag_id.in_(set(dag_ids)))
if only_running:
- query = query.where(cls.state == State.RUNNING)
+ query = query.where(cls.state == DagRunState.RUNNING)
else:
- query = query.where(cls.state.in_([State.RUNNING, State.QUEUED]))
+ query = query.where(cls.state.in_((DagRunState.RUNNING,
DagRunState.QUEUED)))
query = query.group_by(cls.dag_id)
return {dag_id: count for dag_id, count in session.execute(query)}
@@ -340,7 +340,7 @@ class DagRun(Base, LoggingMixin):
.join(DagModel, DagModel.dag_id == cls.dag_id)
.where(DagModel.is_paused == false(), DagModel.is_active == true())
)
- if state == State.QUEUED:
+ if state == DagRunState.QUEUED:
# For dag runs in the queued state, we check if they have reached
the max_active_runs limit
# and if so we drop them
running_drs = (
@@ -477,11 +477,11 @@ class DagRun(Base, LoggingMixin):
tis = tis.where(TI.state == state)
else:
# this is required to deal with NULL values
- if State.NONE in state:
+ if None in state:
if all(x is None for x in state):
tis = tis.where(TI.state.is_(None))
else:
- not_none_state = [s for s in state if s]
+ not_none_state = (s for s in state if s)
tis = tis.where(or_(TI.state.in_(not_none_state),
TI.state.is_(None)))
else:
tis = tis.where(TI.state.in_(state))
@@ -746,9 +746,9 @@ class DagRun(Base, LoggingMixin):
try:
ti.task = dag.get_task(ti.task_id)
except TaskNotFound:
- if ti.state != State.REMOVED:
+ if ti.state != TaskInstanceState.REMOVED:
self.log.error("Failed to get task for ti %s. Marking
it as removed.", ti)
- ti.state = State.REMOVED
+ ti.state = TaskInstanceState.REMOVED
session.flush()
else:
yield ti
@@ -957,7 +957,7 @@ class DagRun(Base, LoggingMixin):
self.log.warning("Failed to record first_task_scheduling_delay
metric:", exc_info=True)
def _emit_duration_stats_for_finished_state(self):
- if self.state == State.RUNNING:
+ if self.state == DagRunState.RUNNING:
return
if self.start_date is None:
self.log.warning("Failed to record duration of %s: start_date is
not set.", self)
@@ -1029,22 +1029,22 @@ class DagRun(Base, LoggingMixin):
try:
task = dag.get_task(ti.task_id)
- should_restore_task = (task is not None) and ti.state ==
State.REMOVED
+ should_restore_task = (task is not None) and ti.state ==
TaskInstanceState.REMOVED
if should_restore_task:
self.log.info("Restoring task '%s' which was previously
removed from DAG '%s'", ti, dag)
Stats.incr(f"task_restored_to_dag.{dag.dag_id}",
tags=self.stats_tags)
# Same metric with tagging
Stats.incr("task_restored_to_dag",
tags={**self.stats_tags, "dag_id": dag.dag_id})
- ti.state = State.NONE
+ ti.state = None
except AirflowException:
- if ti.state == State.REMOVED:
+ if ti.state == TaskInstanceState.REMOVED:
pass # ti has already been removed, just ignore it
- elif self.state != State.RUNNING and not dag.partial:
+ elif self.state != DagRunState.RUNNING and not dag.partial:
self.log.warning("Failed to get task '%s' for dag '%s'.
Marking it as removed.", ti, dag)
Stats.incr(f"task_removed_from_dag.{dag.dag_id}",
tags=self.stats_tags)
# Same metric with tagging
Stats.incr("task_removed_from_dag",
tags={**self.stats_tags, "dag_id": dag.dag_id})
- ti.state = State.REMOVED
+ ti.state = TaskInstanceState.REMOVED
continue
try:
@@ -1061,7 +1061,7 @@ class DagRun(Base, LoggingMixin):
self.log.debug(
"Removing the unmapped TI '%s' as the mapping
can't be resolved yet", ti
)
- ti.state = State.REMOVED
+ ti.state = TaskInstanceState.REMOVED
continue
# Upstreams finished, check there aren't any extras
if ti.map_index >= total_length:
@@ -1070,7 +1070,7 @@ class DagRun(Base, LoggingMixin):
ti,
total_length,
)
- ti.state = State.REMOVED
+ ti.state = TaskInstanceState.REMOVED
else:
# Check if the number of mapped literals has changed, and we
need to mark this TI as removed.
if ti.map_index >= num_mapped_tis:
@@ -1079,10 +1079,10 @@ class DagRun(Base, LoggingMixin):
ti,
num_mapped_tis,
)
- ti.state = State.REMOVED
+ ti.state = TaskInstanceState.REMOVED
elif ti.map_index < 0:
self.log.debug("Removing the unmapped TI '%s' as the
mapping can now be performed", ti)
- ti.state = State.REMOVED
+ ti.state = TaskInstanceState.REMOVED
return task_ids
diff --git a/airflow/models/pool.py b/airflow/models/pool.py
index 60f92506f6..83ec0368bd 100644
--- a/airflow/models/pool.py
+++ b/airflow/models/pool.py
@@ -28,7 +28,7 @@ from airflow.ti_deps.dependencies_states import
EXECUTION_STATES
from airflow.typing_compat import TypedDict
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import nowait, with_row_locks
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
class PoolStats(TypedDict):
@@ -247,7 +247,7 @@ class Pool(Base):
session.scalar(
select(func.sum(TaskInstance.pool_slots))
.filter(TaskInstance.pool == self.pool)
- .filter(TaskInstance.state == State.RUNNING)
+ .filter(TaskInstance.state == TaskInstanceState.RUNNING)
)
or 0
)
@@ -266,7 +266,7 @@ class Pool(Base):
session.scalar(
select(func.sum(TaskInstance.pool_slots))
.filter(TaskInstance.pool == self.pool)
- .filter(TaskInstance.state == State.QUEUED)
+ .filter(TaskInstance.state == TaskInstanceState.QUEUED)
)
or 0
)
@@ -285,7 +285,7 @@ class Pool(Base):
session.scalar(
select(func.sum(TaskInstance.pool_slots))
.filter(TaskInstance.pool == self.pool)
- .filter(TaskInstance.state == State.SCHEDULED)
+ .filter(TaskInstance.state == TaskInstanceState.SCHEDULED)
)
or 0
)
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index 849083e38b..10991cadc7 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -28,7 +28,7 @@ from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.sqlalchemy import tuple_in_condition
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from pendulum import DateTime
@@ -79,7 +79,7 @@ class SkipMixin(LoggingMixin):
query.update(
{
- TaskInstance.state: State.SKIPPED,
+ TaskInstance.state: TaskInstanceState.SKIPPED,
TaskInstance.start_date: now,
TaskInstance.end_date: now,
},
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 65eeacdfd3..ee1825c063 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -593,7 +593,7 @@ class TaskInstance(Base, LoggingMixin):
database, in all other cases this will be incremented.
"""
# This is designed so that task logs end up in the right file.
- if self.state == State.RUNNING:
+ if self.state == TaskInstanceState.RUNNING:
return self._try_number
return self._try_number + 1
@@ -811,7 +811,7 @@ class TaskInstance(Base, LoggingMixin):
:param session: SQLAlchemy ORM Session
"""
self.log.error("Recording the task instance as FAILED")
- self.state = State.FAILED
+ self.state = TaskInstanceState.FAILED
session.merge(self)
session.commit()
@@ -940,7 +940,7 @@ class TaskInstance(Base, LoggingMixin):
self.log.debug("Setting task state for %s to %s", self, state)
self.state = state
self.start_date = self.start_date or current_time
- if self.state in State.finished or self.state == State.UP_FOR_RETRY:
+ if self.state in State.finished or self.state ==
TaskInstanceState.UP_FOR_RETRY:
self.end_date = self.end_date or current_time
self.duration = (self.end_date - self.start_date).total_seconds()
session.merge(self)
@@ -953,7 +953,7 @@ class TaskInstance(Base, LoggingMixin):
has elapsed.
"""
# is the task still in the retry waiting period?
- return self.state == State.UP_FOR_RETRY and not self.ready_for_retry()
+ return self.state == TaskInstanceState.UP_FOR_RETRY and not
self.ready_for_retry()
@provide_session
def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
@@ -976,7 +976,7 @@ class TaskInstance(Base, LoggingMixin):
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id.in_(task.downstream_task_ids),
TaskInstance.run_id == self.run_id,
- TaskInstance.state.in_([State.SKIPPED, State.SUCCESS]),
+ TaskInstance.state.in_((TaskInstanceState.SKIPPED,
TaskInstanceState.SUCCESS)),
)
count = ti[0][0]
return count == len(task.downstream_task_ids)
@@ -1213,7 +1213,7 @@ class TaskInstance(Base, LoggingMixin):
Checks on whether the task instance is in the right state and timeframe
to be retried.
"""
- return self.state == State.UP_FOR_RETRY and self.next_retry_datetime()
< timezone.utcnow()
+ return self.state == TaskInstanceState.UP_FOR_RETRY and
self.next_retry_datetime() < timezone.utcnow()
@provide_session
def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
@@ -1282,7 +1282,7 @@ class TaskInstance(Base, LoggingMixin):
self.hostname = get_hostname()
self.pid = None
- if not ignore_all_deps and not ignore_ti_state and self.state ==
State.SUCCESS:
+ if not ignore_all_deps and not ignore_ti_state and self.state ==
TaskInstanceState.SUCCESS:
Stats.incr("previously_succeeded", tags=self.stats_tags)
if not mark_success:
@@ -1310,7 +1310,7 @@ class TaskInstance(Base, LoggingMixin):
# start date that is recorded in task_reschedule table
# If the task continues after being deferred (next_method is set),
use the original start_date
self.start_date = self.start_date if self.next_method else
timezone.utcnow()
- if self.state == State.UP_FOR_RESCHEDULE:
+ if self.state == TaskInstanceState.UP_FOR_RESCHEDULE:
task_reschedule: TR = TR.query_for_task_instance(self,
session=session).first()
if task_reschedule:
self.start_date = task_reschedule.start_date
@@ -1328,7 +1328,7 @@ class TaskInstance(Base, LoggingMixin):
description="requeueable deps",
)
if not self.are_dependencies_met(dep_context=dep_context,
session=session, verbose=True):
- self.state = State.NONE
+ self.state = None
self.log.warning(
"Rescheduling due to concurrency limits reached "
"at task runtime. Attempt %s of "
@@ -1350,8 +1350,8 @@ class TaskInstance(Base, LoggingMixin):
if not test_mode:
session.add(Log(State.RUNNING, self))
- self.state = State.RUNNING
- self.emit_state_change_metric(State.RUNNING)
+ self.state = TaskInstanceState.RUNNING
+ self.emit_state_change_metric(TaskInstanceState.RUNNING)
self.external_executor_id = external_executor_id
self.end_date = None
if not test_mode:
@@ -1391,7 +1391,7 @@ class TaskInstance(Base, LoggingMixin):
self._date_or_empty("end_date"),
)
- def emit_state_change_metric(self, new_state: TaskInstanceState):
+ def emit_state_change_metric(self, new_state: TaskInstanceState) -> None:
"""
Sends a time metric representing how much time a given state
transition took.
The previous state and metric name is deduced from the state the task
was put in.
@@ -1407,7 +1407,7 @@ class TaskInstance(Base, LoggingMixin):
return
# switch on state and deduce which metric to send
- if new_state == State.RUNNING:
+ if new_state == TaskInstanceState.RUNNING:
metric_name = "queued_duration"
if self.queued_dttm is None:
# this should not really happen except in tests or rare cases,
@@ -1419,7 +1419,7 @@ class TaskInstance(Base, LoggingMixin):
)
return
timing = (timezone.utcnow() - self.queued_dttm).total_seconds()
- elif new_state == State.QUEUED:
+ elif new_state == TaskInstanceState.QUEUED:
metric_name = "scheduled_duration"
if self.start_date is None:
# same comment as above
@@ -1506,7 +1506,7 @@ class TaskInstance(Base, LoggingMixin):
self._execute_task_with_callbacks(context, test_mode)
if not test_mode:
self.refresh_from_db(lock_for_update=True, session=session)
- self.state = State.SUCCESS
+ self.state = TaskInstanceState.SUCCESS
except TaskDeferred as defer:
# The task has signalled it wants to defer execution based on
# a trigger.
@@ -1530,7 +1530,7 @@ class TaskInstance(Base, LoggingMixin):
self.log.info(e)
if not test_mode:
self.refresh_from_db(lock_for_update=True, session=session)
- self.state = State.SKIPPED
+ self.state = TaskInstanceState.SKIPPED
except AirflowRescheduleException as reschedule_exception:
self._handle_reschedule(actual_start_date, reschedule_exception,
test_mode, session=session)
session.commit()
@@ -1753,7 +1753,7 @@ class TaskInstance(Base, LoggingMixin):
# Then, update ourselves so it matches the deferral request
# Keep an eye on the logic in
`check_and_change_state_before_execution()`
# depending on self.next_method semantics
- self.state = State.DEFERRED
+ self.state = TaskInstanceState.DEFERRED
self.trigger_id = trigger_row.id
self.next_method = defer.method_name
self.next_kwargs = defer.kwargs or {}
@@ -1872,7 +1872,7 @@ class TaskInstance(Base, LoggingMixin):
)
# set state
- self.state = State.UP_FOR_RESCHEDULE
+ self.state = TaskInstanceState.UP_FOR_RESCHEDULE
# Decrement try_number so subsequent runs will use the same try number
and write
# to same log file.
@@ -1971,7 +1971,7 @@ class TaskInstance(Base, LoggingMixin):
self.log.error("Unable to unmap task to determine if we need to
send an alert email")
if force_fail or not self.is_eligible_to_retry():
- self.state = State.FAILED
+ self.state = TaskInstanceState.FAILED
email_for_state = operator.attrgetter("email_on_failure")
callbacks = task.on_failure_callback if task else None
callback_type = "on_failure"
@@ -1980,10 +1980,10 @@ class TaskInstance(Base, LoggingMixin):
tis = self.get_dagrun(session).get_task_instances()
stop_all_tasks_in_dag(tis, session, self.task_id)
else:
- if self.state == State.QUEUED:
+ if self.state == TaskInstanceState.QUEUED:
# We increase the try_number so as to fail the task if it
fails to start after sometime
self._try_number += 1
- self.state = State.UP_FOR_RETRY
+ self.state = TaskInstanceState.UP_FOR_RETRY
email_for_state = operator.attrgetter("email_on_retry")
callbacks = task.on_retry_callback if task else None
callback_type = "on_retry"
@@ -2004,7 +2004,7 @@ class TaskInstance(Base, LoggingMixin):
def is_eligible_to_retry(self):
"""Is task instance is eligible for retry."""
- if self.state == State.RESTARTING:
+ if self.state == TaskInstanceState.RESTARTING:
# If a task is cleared when running, it goes into RESTARTING state
and is always
# eligible for retry
return True
@@ -2345,7 +2345,7 @@ class TaskInstance(Base, LoggingMixin):
'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
)
- # This function is called after changing the state from State.RUNNING,
+ # This function is called after changing the state from RUNNING,
# so we need to subtract 1 from self.try_number here.
current_try_number = self.try_number - 1
additional_context: dict[str, Any] = {
@@ -2572,7 +2572,7 @@ class TaskInstance(Base, LoggingMixin):
num_running_task_instances_query = session.query(func.count()).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == self.task_id,
- TaskInstance.state == State.RUNNING,
+ TaskInstance.state == TaskInstanceState.RUNNING,
)
if same_dagrun:
num_running_task_instances_query =
num_running_task_instances_query.filter(
@@ -2894,7 +2894,7 @@ def _is_further_mapped_inside(operator: Operator,
container: TaskGroup) -> bool:
# State of the task instance.
# Stores string version of the task state.
-TaskInstanceStateType = Tuple[TaskInstanceKey, str]
+TaskInstanceStateType = Tuple[TaskInstanceKey, TaskInstanceState]
class SimpleTaskInstance:
diff --git a/airflow/operators/subdag.py b/airflow/operators/subdag.py
index 30daf638bd..0f242c09f3 100644
--- a/airflow/operators/subdag.py
+++ b/airflow/operators/subdag.py
@@ -37,7 +37,7 @@ from airflow.models.taskinstance import TaskInstance
from airflow.sensors.base import BaseSensorOperator
from airflow.utils.context import Context
from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.types import DagRunType
@@ -137,17 +137,17 @@ class SubDagOperator(BaseSensorOperator):
:param execution_date: Execution date to select task instances.
"""
with create_session() as session:
- dag_run.state = State.RUNNING
+ dag_run.state = DagRunState.RUNNING
session.merge(dag_run)
failed_task_instances = (
session.query(TaskInstance)
.filter(TaskInstance.dag_id == self.subdag.dag_id)
.filter(TaskInstance.execution_date == execution_date)
- .filter(TaskInstance.state.in_([State.FAILED,
State.UPSTREAM_FAILED]))
+ .filter(TaskInstance.state.in_((TaskInstanceState.FAILED,
TaskInstanceState.UPSTREAM_FAILED)))
)
for task_instance in failed_task_instances:
- task_instance.state = State.NONE
+ task_instance.state = None
session.merge(task_instance)
session.commit()
@@ -164,7 +164,7 @@ class SubDagOperator(BaseSensorOperator):
dag_run = self.subdag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=execution_date,
- state=State.RUNNING,
+ state=DagRunState.RUNNING,
conf=self.conf,
external_trigger=True,
data_interval=data_interval,
@@ -172,13 +172,13 @@ class SubDagOperator(BaseSensorOperator):
self.log.info("Created DagRun: %s", dag_run.run_id)
else:
self.log.info("Found existing DagRun: %s", dag_run.run_id)
- if dag_run.state == State.FAILED:
+ if dag_run.state == DagRunState.FAILED:
self._reset_dag_run_and_task_instances(dag_run, execution_date)
def poke(self, context: Context):
execution_date = context["execution_date"]
dag_run = self._get_dagrun(execution_date=execution_date)
- return dag_run.state != State.RUNNING
+ return dag_run.state != DagRunState.RUNNING
def post_execute(self, context, result=None):
super().post_execute(context)
@@ -186,7 +186,7 @@ class SubDagOperator(BaseSensorOperator):
dag_run = self._get_dagrun(execution_date=execution_date)
self.log.info("Execution finished. State is %s", dag_run.state)
- if dag_run.state != State.SUCCESS:
+ if dag_run.state != DagRunState.SUCCESS:
raise AirflowException(f"Expected state: SUCCESS. Actual state:
{dag_run.state}")
if self.propagate_skipped_state and
self._check_skipped_states(context):
@@ -196,9 +196,9 @@ class SubDagOperator(BaseSensorOperator):
leaves_tis = self._get_leaves_tis(context["execution_date"])
if self.propagate_skipped_state ==
SkippedStatePropagationOptions.ANY_LEAF:
- return any(ti.state == State.SKIPPED for ti in leaves_tis)
+ return any(ti.state == TaskInstanceState.SKIPPED for ti in
leaves_tis)
if self.propagate_skipped_state ==
SkippedStatePropagationOptions.ALL_LEAVES:
- return all(ti.state == State.SKIPPED for ti in leaves_tis)
+ return all(ti.state == TaskInstanceState.SKIPPED for ti in
leaves_tis)
raise AirflowException(
f"Unimplemented SkippedStatePropagationOptions
{self.propagate_skipped_state} used."
)
diff --git a/airflow/sentry.py b/airflow/sentry.py
index 3e222405c4..443063af8a 100644
--- a/airflow/sentry.py
+++ b/airflow/sentry.py
@@ -25,27 +25,26 @@ from typing import TYPE_CHECKING
from airflow.configuration import conf
from airflow.executors.executor_loader import ExecutorLoader
from airflow.utils.session import find_session_idx, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from sqlalchemy.orm import Session
+ from airflow.models.taskinstance import TaskInstance
+
log = logging.getLogger(__name__)
class DummySentry:
"""Blank class for Sentry."""
- @classmethod
- def add_tagging(cls, task_instance):
+ def add_tagging(self, task_instance):
"""Blank function for tagging."""
- @classmethod
- def add_breadcrumbs(cls, task_instance, session: Session | None = None):
+ def add_breadcrumbs(self, task_instance, session: Session | None = None):
"""Blank function for breadcrumbs."""
- @classmethod
- def enrich_errors(cls, run):
+ def enrich_errors(self, run):
"""Blank function for formatting a TaskInstance._run_raw_task."""
return run
@@ -137,13 +136,17 @@ if conf.getboolean("sentry", "sentry_on", fallback=False):
scope.set_tag("operator", task.__class__.__name__)
@provide_session
- def add_breadcrumbs(self, task_instance, session=None):
+ def add_breadcrumbs(
+ self,
+ task_instance: TaskInstance,
+ session: Session | None = None,
+ ) -> None:
"""Function to add breadcrumbs inside of a task_instance."""
if session is None:
return
dr = task_instance.get_dagrun(session)
task_instances = dr.get_task_instances(
- state={State.SUCCESS, State.FAILED},
+ state={TaskInstanceState.SUCCESS, TaskInstanceState.FAILED},
session=session,
)
diff --git a/airflow/ti_deps/deps/dagrun_exists_dep.py
b/airflow/ti_deps/deps/dagrun_exists_dep.py
index 781ab0ebaf..0a364628c7 100644
--- a/airflow/ti_deps/deps/dagrun_exists_dep.py
+++ b/airflow/ti_deps/deps/dagrun_exists_dep.py
@@ -19,7 +19,7 @@ from __future__ import annotations
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState
class DagrunRunningDep(BaseTIDep):
@@ -31,7 +31,7 @@ class DagrunRunningDep(BaseTIDep):
@provide_session
def _get_dep_statuses(self, ti, session, dep_context):
dr = ti.get_dagrun(session)
- if dr.state != State.RUNNING:
+ if dr.state != DagRunState.RUNNING:
yield self._failing_status(
reason=f"Task instance's dagrun was not in the 'running' state
but in the state '{dr.state}'."
)
diff --git a/airflow/ti_deps/deps/not_in_retry_period_dep.py
b/airflow/ti_deps/deps/not_in_retry_period_dep.py
index b3b5d4ec56..90954f29f2 100644
--- a/airflow/ti_deps/deps/not_in_retry_period_dep.py
+++ b/airflow/ti_deps/deps/not_in_retry_period_dep.py
@@ -20,7 +20,7 @@ from __future__ import annotations
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils import timezone
from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
class NotInRetryPeriodDep(BaseTIDep):
@@ -38,7 +38,7 @@ class NotInRetryPeriodDep(BaseTIDep):
)
return
- if ti.state != State.UP_FOR_RETRY:
+ if ti.state != TaskInstanceState.UP_FOR_RETRY:
yield self._passing_status(reason="The task instance was not
marked for retrying.")
return
diff --git a/airflow/ti_deps/deps/not_previously_skipped_dep.py
b/airflow/ti_deps/deps/not_previously_skipped_dep.py
index fdc8274d90..855f04af53 100644
--- a/airflow/ti_deps/deps/not_previously_skipped_dep.py
+++ b/airflow/ti_deps/deps/not_previously_skipped_dep.py
@@ -40,7 +40,7 @@ class NotPreviouslySkippedDep(BaseTIDep):
XCOM_SKIPMIXIN_SKIPPED,
SkipMixin,
)
- from airflow.utils.state import State
+ from airflow.utils.state import TaskInstanceState
upstream = ti.task.get_direct_relatives(upstream=True)
@@ -87,7 +87,7 @@ class NotPreviouslySkippedDep(BaseTIDep):
reason=("Task should be skipped but the the
past depends are not met")
)
return
- ti.set_state(State.SKIPPED, session)
+ ti.set_state(TaskInstanceState.SKIPPED, session)
yield self._failing_status(
reason=f"Skipping because of previous XCom result from
parent task {parent.task_id}"
)
diff --git a/airflow/ti_deps/deps/prev_dagrun_dep.py
b/airflow/ti_deps/deps/prev_dagrun_dep.py
index b3165eb7f2..62acdbca33 100644
--- a/airflow/ti_deps/deps/prev_dagrun_dep.py
+++ b/airflow/ti_deps/deps/prev_dagrun_dep.py
@@ -22,7 +22,7 @@ from sqlalchemy import func
from airflow.models.taskinstance import PAST_DEPENDS_MET, TaskInstance as TI
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
class PrevDagrunDep(BaseTIDep):
@@ -107,7 +107,7 @@ class PrevDagrunDep(BaseTIDep):
)
return
- if previous_ti.state not in {State.SKIPPED, State.SUCCESS}:
+ if previous_ti.state not in {TaskInstanceState.SKIPPED,
TaskInstanceState.SUCCESS}:
yield self._failing_status(
reason=(
f"depends_on_past is true for this task, but the previous
task instance {previous_ti} "
diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py
b/airflow/ti_deps/deps/ready_to_reschedule.py
index c20ef98a08..4fca6f5538 100644
--- a/airflow/ti_deps/deps/ready_to_reschedule.py
+++ b/airflow/ti_deps/deps/ready_to_reschedule.py
@@ -22,7 +22,7 @@ from airflow.models.taskreschedule import TaskReschedule
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils import timezone
from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
class ReadyToRescheduleDep(BaseTIDep):
@@ -31,7 +31,7 @@ class ReadyToRescheduleDep(BaseTIDep):
NAME = "Ready To Reschedule"
IGNORABLE = True
IS_TASK_DEP = True
- RESCHEDULEABLE_STATES = {State.UP_FOR_RESCHEDULE, State.NONE}
+ RESCHEDULEABLE_STATES = {TaskInstanceState.UP_FOR_RESCHEDULE, None}
@provide_session
def _get_dep_statuses(self, ti, session, dep_context):
diff --git a/airflow/ti_deps/deps/task_not_running_dep.py
b/airflow/ti_deps/deps/task_not_running_dep.py
index fd76873466..7299a4f3e2 100644
--- a/airflow/ti_deps/deps/task_not_running_dep.py
+++ b/airflow/ti_deps/deps/task_not_running_dep.py
@@ -20,7 +20,7 @@ from __future__ import annotations
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
class TaskNotRunningDep(BaseTIDep):
@@ -37,7 +37,7 @@ class TaskNotRunningDep(BaseTIDep):
@provide_session
def _get_dep_statuses(self, ti, session, dep_context=None):
- if ti.state != State.RUNNING:
+ if ti.state != TaskInstanceState.RUNNING:
yield self._passing_status(reason="Task is not in running state.")
return
diff --git a/airflow/utils/dot_renderer.py b/airflow/utils/dot_renderer.py
index 3f35d1b21d..d3b329cb54 100644
--- a/airflow/utils/dot_renderer.py
+++ b/airflow/utils/dot_renderer.py
@@ -58,7 +58,7 @@ def _draw_task(
) -> None:
"""Draw a single task on the given parent_graph."""
if states_by_task_id:
- state = states_by_task_id.get(task.task_id, State.NONE)
+ state = states_by_task_id.get(task.task_id)
color = State.color_fg(state)
fill_color = State.color(state)
else:
diff --git a/airflow/utils/log/file_task_handler.py
b/airflow/utils/log/file_task_handler.py
index 0c98d25503..9184a20420 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -341,7 +341,10 @@ class FileTaskHandler(logging.Handler):
)
log_pos = len(logs)
messages = "".join([f"*** {x}\n" for x in messages_list])
- end_of_log = ti.try_number != try_number or ti.state not in
[State.RUNNING, State.DEFERRED]
+ end_of_log = ti.try_number != try_number or ti.state not in (
+ TaskInstanceState.RUNNING,
+ TaskInstanceState.DEFERRED,
+ )
if metadata and "log_pos" in metadata:
previous_chars = metadata["log_pos"]
logs = logs[previous_chars:] # Cut off previously passed log test
as new tail
diff --git a/airflow/utils/log/log_reader.py b/airflow/utils/log/log_reader.py
index a4589ebca0..d93f15bb1a 100644
--- a/airflow/utils/log/log_reader.py
+++ b/airflow/utils/log/log_reader.py
@@ -28,7 +28,7 @@ from airflow.models.taskinstance import TaskInstance
from airflow.utils.helpers import render_log_filename
from airflow.utils.log.logging_mixin import ExternalLoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
class TaskLogReader:
@@ -86,7 +86,8 @@ class TaskLogReader:
for host, log in logs[0]:
yield "\n".join([host or "", log]) + "\n"
if "end_of_log" not in metadata or (
- not metadata["end_of_log"] and ti.state not in
[State.RUNNING, State.DEFERRED]
+ not metadata["end_of_log"]
+ and ti.state not in (TaskInstanceState.RUNNING,
TaskInstanceState.DEFERRED)
):
if not logs[0]:
# we did not receive any logs in this loop
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index f4a8dc1a0a..fc74732acc 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -21,8 +21,7 @@ from enum import Enum
class TaskInstanceState(str, Enum):
- """
- Enum that represents all possible states that a Task Instance can be in.
+ """All possible states that a Task Instance can be in.
Note that None is also allowed, so always use this in a type hint with
Optional.
"""
@@ -53,8 +52,7 @@ class TaskInstanceState(str, Enum):
class DagRunState(str, Enum):
- """
- Enum that represents all possible states that a DagRun can be in.
+ """All possible states that a DagRun can be in.
These are "shared" with TaskInstanceState in some parts of the code,
so please ensure that their values always match the ones with the
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 1ca6289122..bb48f81ccc 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -87,7 +87,9 @@ def get_instance_with_map(task_instance, session):
def get_try_count(try_number: int, state: State):
- return try_number + 1 if state in [State.DEFERRED,
State.UP_FOR_RESCHEDULE] else try_number
+ if state in (TaskInstanceState.DEFERRED,
TaskInstanceState.UP_FOR_RESCHEDULE):
+ return try_number + 1
+ return try_number
priority: list[None | TaskInstanceState] = [
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 370e15ee00..e198e390ac 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -767,7 +767,7 @@ class Airflow(AirflowBaseView):
# find DAGs which have a RUNNING DagRun
running_dags = dags_query.join(DagRun, DagModel.dag_id ==
DagRun.dag_id).where(
- DagRun.state == State.RUNNING
+ DagRun.state == DagRunState.RUNNING
)
# find DAGs for which the latest DagRun is FAILED
@@ -778,7 +778,7 @@ class Airflow(AirflowBaseView):
)
subq_failed = (
select(DagRun.dag_id,
func.max(DagRun.start_date).label("start_date"))
- .where(DagRun.state == State.FAILED)
+ .where(DagRun.state == DagRunState.FAILED)
.group_by(DagRun.dag_id)
.subquery()
)
@@ -1127,7 +1127,7 @@ class Airflow(AirflowBaseView):
running_dag_run_query_result = (
select(DagRun.dag_id, DagRun.run_id)
.join(DagModel, DagModel.dag_id == DagRun.dag_id)
- .where(DagRun.state == State.RUNNING, DagModel.is_active)
+ .where(DagRun.state == DagRunState.RUNNING, DagModel.is_active)
)
running_dag_run_query_result =
running_dag_run_query_result.where(DagRun.dag_id.in_(filter_dag_ids))
@@ -1151,7 +1151,7 @@ class Airflow(AirflowBaseView):
last_dag_run = (
select(DagRun.dag_id,
sqla.func.max(DagRun.execution_date).label("execution_date"))
.join(DagModel, DagModel.dag_id == DagRun.dag_id)
- .where(DagRun.state != State.RUNNING, DagModel.is_active)
+ .where(DagRun.state != DagRunState.RUNNING, DagModel.is_active)
.group_by(DagRun.dag_id)
)
@@ -1856,7 +1856,7 @@ class Airflow(AirflowBaseView):
"Airflow administrator for assistance.".format(
"- This task instance already ran and had it's state
changed manually "
"(e.g. cleared in the UI)<br>"
- if ti and ti.state == State.NONE
+ if ti and ti.state is None
else ""
),
)
@@ -2169,7 +2169,7 @@ class Airflow(AirflowBaseView):
run_type=DagRunType.MANUAL,
execution_date=execution_date,
data_interval=dag.timetable.infer_manual_data_interval(run_after=execution_date),
- state=State.QUEUED,
+ state=DagRunState.QUEUED,
conf=run_conf,
external_trigger=True,
dag_hash=get_airflow_app().dag_bag.dags_hash.get(dag_id),
@@ -5426,23 +5426,28 @@ class
DagRunModelView(AirflowPrivilegeVerifierModelView):
@action_logging
def action_set_queued(self, drs: list[DagRun]):
"""Set state to queued."""
- return self._set_dag_runs_to_active_state(drs, State.QUEUED)
+ return self._set_dag_runs_to_active_state(drs, DagRunState.QUEUED)
@action("set_running", "Set state to 'running'", "", single=False)
@action_has_dag_edit_access
@action_logging
def action_set_running(self, drs: list[DagRun]):
"""Set state to running."""
- return self._set_dag_runs_to_active_state(drs, State.RUNNING)
+ return self._set_dag_runs_to_active_state(drs, DagRunState.RUNNING)
@provide_session
- def _set_dag_runs_to_active_state(self, drs: list[DagRun], state: str,
session: Session = NEW_SESSION):
+ def _set_dag_runs_to_active_state(
+ self,
+ drs: list[DagRun],
+ state: DagRunState,
+ session: Session = NEW_SESSION,
+ ):
"""This routine only supports Running and Queued state."""
try:
count = 0
for dr in
session.scalars(select(DagRun).where(DagRun.id.in_(dagrun.id for dagrun in
drs))):
count += 1
- if state == State.RUNNING:
+ if state == DagRunState.RUNNING:
dr.start_date = timezone.utcnow()
dr.state = state
session.commit()
@@ -5863,7 +5868,12 @@ class
TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
return redirect(self.get_redirect())
@provide_session
- def set_task_instance_state(self, tis, target_state, session: Session =
NEW_SESSION):
+ def set_task_instance_state(
+ self,
+ tis: Collection[TaskInstance],
+ target_state: TaskInstanceState,
+ session: Session = NEW_SESSION,
+ ) -> None:
"""Set task instance state."""
try:
count = len(tis)
@@ -5879,7 +5889,7 @@ class
TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
@action_logging
def action_set_running(self, tis):
"""Set state to 'running'."""
- self.set_task_instance_state(tis, State.RUNNING)
+ self.set_task_instance_state(tis, TaskInstanceState.RUNNING)
self.update_redirect()
return redirect(self.get_redirect())
@@ -5888,7 +5898,7 @@ class
TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
@action_logging
def action_set_failed(self, tis):
"""Set state to 'failed'."""
- self.set_task_instance_state(tis, State.FAILED)
+ self.set_task_instance_state(tis, TaskInstanceState.FAILED)
self.update_redirect()
return redirect(self.get_redirect())
@@ -5897,7 +5907,7 @@ class
TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
@action_logging
def action_set_success(self, tis):
"""Set state to 'success'."""
- self.set_task_instance_state(tis, State.SUCCESS)
+ self.set_task_instance_state(tis, TaskInstanceState.SUCCESS)
self.update_redirect()
return redirect(self.get_redirect())
@@ -5906,7 +5916,7 @@ class
TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
@action_logging
def action_set_retry(self, tis):
"""Set state to 'up_for_retry'."""
- self.set_task_instance_state(tis, State.UP_FOR_RETRY)
+ self.set_task_instance_state(tis, TaskInstanceState.UP_FOR_RETRY)
self.update_redirect()
return redirect(self.get_redirect())
diff --git a/dev/perf/scheduler_dag_execution_timing.py
b/dev/perf/scheduler_dag_execution_timing.py
index 613a929e9e..db7f4f2e8d 100755
--- a/dev/perf/scheduler_dag_execution_timing.py
+++ b/dev/perf/scheduler_dag_execution_timing.py
@@ -63,7 +63,7 @@ class ShortCircuitExecutorMixin:
Change the state of the scheduler by waiting until the task is complete
and then shut down the scheduler after the task is complete
"""
- from airflow.utils.state import State
+ from airflow.utils.state import TaskInstanceState
super().change_state(key, state, info=info)
@@ -83,7 +83,7 @@ class ShortCircuitExecutorMixin:
run = list(airflow.models.DagRun.find(dag_id=dag_id,
execution_date=execution_date))[0]
self.dags_to_watch[dag_id].runs[execution_date] = run
- if run and all(t.state == State.SUCCESS for t in
run.get_task_instances()):
+ if run and all(t.state == TaskInstanceState.SUCCESS for t in
run.get_task_instances()):
self.dags_to_watch[dag_id].runs.pop(execution_date)
self.dags_to_watch[dag_id].waiting_for -= 1
@@ -156,7 +156,7 @@ def create_dag_runs(dag, num_runs, session):
Create `num_runs` of dag runs for sub-sequent schedules
"""
from airflow.utils import timezone
- from airflow.utils.state import State
+ from airflow.utils.state import DagRunState
try:
from airflow.utils.types import DagRunType
@@ -175,7 +175,7 @@ def create_dag_runs(dag, num_runs, session):
run_id=f"{id_prefix}{logical_date.isoformat()}",
execution_date=logical_date,
start_date=timezone.utcnow(),
- state=State.RUNNING,
+ state=DagRunState.RUNNING,
external_trigger=False,
session=session,
)