This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new fd2687d0d5 Strong-type all single-state enum values (#32537)
fd2687d0d5 is described below
commit fd2687d0d51fa444f63aac705d843f30e5922319
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Jul 12 14:50:20 2023 +0800
Strong-type all single-state enum values (#32537)
---
airflow/jobs/backfill_job_runner.py | 4 ++--
airflow/jobs/scheduler_job_runner.py | 4 ++--
airflow/models/dagrun.py | 2 +-
airflow/models/pool.py | 8 ++++----
airflow/models/taskinstance.py | 4 ++--
airflow/operators/trigger_dagrun.py | 16 +++++++++++-----
airflow/sensors/external_task.py | 34 ++++++++++++++++++----------------
airflow/ti_deps/dependencies_states.py | 30 +++++++++++++++---------------
airflow/triggers/external_task.py | 3 ++-
9 files changed, 57 insertions(+), 48 deletions(-)
diff --git a/airflow/jobs/backfill_job_runner.py b/airflow/jobs/backfill_job_runner.py
index 35910b83ae..162f103642 100644
--- a/airflow/jobs/backfill_job_runner.py
+++ b/airflow/jobs/backfill_job_runner.py
@@ -696,7 +696,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin):
_dag_runs = ti_status.active_runs[:]
for run in _dag_runs:
run.update_state(session=session)
- if run.state in State.finished:
+ if run.state in State.finished_dr_states:
ti_status.finished_runs += 1
ti_status.active_runs.remove(run)
executed_run_dates.append(run.execution_date)
@@ -824,7 +824,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin):
"""
for dag_run in dag_runs:
dag_run.update_state()
- if dag_run.state not in State.finished:
+ if dag_run.state not in State.finished_dr_states:
dag_run.set_state(DagRunState.FAILED)
session.merge(dag_run)
diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index 7b13ecc300..c0e6877076 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -625,7 +625,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
"""
# actually enqueue them
for ti in task_instances:
- if ti.dag_run.state in State.finished:
+ if ti.dag_run.state in State.finished_dr_states:
ti.set_state(None, session=session)
continue
command = ti.command_as_list(
@@ -1449,7 +1449,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
# TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else?
schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False)
# Check if DAG not scheduled then skip interval calculation to same scheduler runtime
- if dag_run.state in State.finished:
+ if dag_run.state in State.finished_dr_states:
# Work out if we should allow creating a new DagRun now?
if self._should_update_dag_next_dagruns(dag, dag_model, session=session):
dag_model.calculate_dagrun_date_fields(dag, dag.get_run_data_interval(dag_run))
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index e3ed0bda00..2d3ff26d7f 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -270,7 +270,7 @@ class DagRun(Base, LoggingMixin):
raise ValueError(f"invalid DagRun state: {state}")
if self._state != state:
self._state = state
- self.end_date = timezone.utcnow() if self._state in State.finished else None
+ self.end_date = timezone.utcnow() if self._state in State.finished_dr_states else None
if state == DagRunState.QUEUED:
self.queued_at = timezone.utcnow()
diff --git a/airflow/models/pool.py b/airflow/models/pool.py
index 83ec0368bd..0aab6fb634 100644
--- a/airflow/models/pool.py
+++ b/airflow/models/pool.py
@@ -175,12 +175,12 @@ class Pool(Base):
state_count_by_pool = session.execute(
select(TaskInstance.pool, TaskInstance.state,
func.sum(TaskInstance.pool_slots))
- .filter(TaskInstance.state.in_(list(EXECUTION_STATES)))
+ .filter(TaskInstance.state.in_(EXECUTION_STATES))
.group_by(TaskInstance.pool, TaskInstance.state)
)
# calculate queued and running metrics
- for (pool_name, state, count) in state_count_by_pool:
+ for pool_name, state, count in state_count_by_pool:
# Some databases return decimal.Decimal here.
count = int(count)
@@ -188,9 +188,9 @@ class Pool(Base):
if not stats_dict:
continue
# TypedDict key must be a string literal, so we use if-statements to set value
- if state == "running":
+ if state == TaskInstanceState.RUNNING:
stats_dict["running"] = count
- elif state == "queued":
+ elif state == TaskInstanceState.QUEUED:
stats_dict["queued"] = count
else:
raise AirflowException(f"Unexpected state. Expected values: {EXECUTION_STATES}.")
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index ee1825c063..3ec15f449d 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -1348,7 +1348,7 @@ class TaskInstance(Base, LoggingMixin):
self._try_number += 1
if not test_mode:
- session.add(Log(State.RUNNING, self))
+ session.add(Log(TaskInstanceState.RUNNING.value, self))
self.state = TaskInstanceState.RUNNING
self.emit_state_change_metric(TaskInstanceState.RUNNING)
@@ -1937,7 +1937,7 @@ class TaskInstance(Base, LoggingMixin):
Stats.incr("ti_failures", tags=self.stats_tags)
if not test_mode:
- session.add(Log(State.FAILED, self))
+ session.add(Log(TaskInstanceState.FAILED.value, self))
# Log failure duration
session.add(TaskFail(ti=self))
diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py
index 548ef91894..5a7be1dcab 100644
--- a/airflow/operators/trigger_dagrun.py
+++ b/airflow/operators/trigger_dagrun.py
@@ -37,7 +37,7 @@ from airflow.utils import timezone
from airflow.utils.context import Context
from airflow.utils.helpers import build_airflow_url_with_query
from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType
XCOM_EXECUTION_DATE_ISO = "trigger_execution_date_iso"
@@ -112,8 +112,8 @@ class TriggerDagRunOperator(BaseOperator):
reset_dag_run: bool = False,
wait_for_completion: bool = False,
poke_interval: int = 60,
- allowed_states: list | None = None,
- failed_states: list | None = None,
+ allowed_states: list[str] | None = None,
+ failed_states: list[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
@@ -124,8 +124,14 @@ class TriggerDagRunOperator(BaseOperator):
self.reset_dag_run = reset_dag_run
self.wait_for_completion = wait_for_completion
self.poke_interval = poke_interval
- self.allowed_states = allowed_states or [State.SUCCESS]
- self.failed_states = failed_states or [State.FAILED]
+ if allowed_states:
+ self.allowed_states = [DagRunState(s) for s in allowed_states]
+ else:
+ self.allowed_states = [DagRunState.SUCCESS]
+ if failed_states:
+ self.failed_states = [DagRunState(s) for s in failed_states]
+ else:
+ self.failed_states = [DagRunState.FAILED]
self._defer = deferrable
if execution_date is not None and not isinstance(execution_date, (str, datetime.datetime)):
diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py
index bacafc979c..b5c1c1321f 100644
--- a/airflow/sensors/external_task.py
+++ b/airflow/sensors/external_task.py
@@ -37,7 +37,7 @@ from airflow.utils.file import correct_maybe_zipped
from airflow.utils.helpers import build_airflow_url_with_query
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import tuple_in_condition
-from airflow.utils.state import State
+from airflow.utils.state import State, TaskInstanceState
if TYPE_CHECKING:
from sqlalchemy.orm import Query, Session
@@ -76,26 +76,28 @@ class ExternalTaskSensor(BaseSensorOperator):
without also having to clear the sensor).
By default, the ExternalTaskSensor will not skip if the external task
skips.
- To change this, simply set ``skipped_states=[State.SKIPPED]``. Note that if
- you are monitoring multiple tasks, and one enters error state and the other
- enters a skipped state, then the external task will react to whichever one
it sees first. If both happen together, then the failed state takes priority.
+ To change this, simply set ``skipped_states=[TaskInstanceState.SKIPPED]``.
+ Note that if you are monitoring multiple tasks, and one enters error state
+ and the other enters a skipped state, then the external task will react to
+ whichever one it sees first. If both happen together, then the failed state
+ takes priority.
It is possible to alter the default behavior by setting states which
- cause the sensor to fail, e.g. by setting ``allowed_states=[State.FAILED]``
- and ``failed_states=[State.SUCCESS]`` you will flip the behaviour to get a
- sensor which goes green when the external task *fails* and immediately goes
- red if the external task *succeeds*!
+ cause the sensor to fail, e.g. by setting ``allowed_states=[DagRunState.FAILED]``
+ and ``failed_states=[DagRunState.SUCCESS]`` you will flip the behaviour to
+ get a sensor which goes green when the external task *fails* and immediately
+ goes red if the external task *succeeds*!
Note that ``soft_fail`` is respected when examining the failed_states. Thus
if the external task enters a failed state and ``soft_fail == True`` the
sensor will _skip_ rather than fail. As a result, setting
``soft_fail=True``
- and ``failed_states=[State.SKIPPED]`` will result in the sensor skipping if
- the external task skips. However, this is a contrived example - consider
- using ``skipped_states`` if you would like this behaviour. Using
- ``skipped_states`` allows the sensor to skip if the target fails, but still
- enter failed state on timeout. Using ``soft_fail == True`` as above will
- cause the sensor to skip if the target fails, but also if it times out.
+ and ``failed_states=[DagRunState.SKIPPED]`` will result in the sensor
+ skipping if the external task skips. However, this is a contrived
+ example---consider using ``skipped_states`` if you would like this
+ behaviour. Using ``skipped_states`` allows the sensor to skip if the target
+ fails, but still enter failed state on timeout. Using ``soft_fail == True``
+ as above will cause the sensor to skip if the target fails, but also if it
+ times out.
:param external_dag_id: The dag_id that contains the task you want to
wait for. (templated)
@@ -146,7 +148,7 @@ class ExternalTaskSensor(BaseSensorOperator):
**kwargs,
):
super().__init__(**kwargs)
- self.allowed_states = list(allowed_states) if allowed_states else [State.SUCCESS]
+ self.allowed_states = list(allowed_states) if allowed_states else [TaskInstanceState.SUCCESS.value]
self.skipped_states = list(skipped_states) if skipped_states else []
self.failed_states = list(failed_states) if failed_states else []
diff --git a/airflow/ti_deps/dependencies_states.py b/airflow/ti_deps/dependencies_states.py
index 543ce3528a..fd25d62f6d 100644
--- a/airflow/ti_deps/dependencies_states.py
+++ b/airflow/ti_deps/dependencies_states.py
@@ -16,38 +16,38 @@
# under the License.
from __future__ import annotations
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
EXECUTION_STATES = {
- State.RUNNING,
- State.QUEUED,
+ TaskInstanceState.RUNNING,
+ TaskInstanceState.QUEUED,
}
# In order to be able to get queued a task must have one of these states
SCHEDULEABLE_STATES = {
- State.NONE,
- State.UP_FOR_RETRY,
- State.UP_FOR_RESCHEDULE,
+ None,
+ TaskInstanceState.UP_FOR_RETRY,
+ TaskInstanceState.UP_FOR_RESCHEDULE,
}
RUNNABLE_STATES = {
# For cases like unit tests and run manually
- State.NONE,
- State.UP_FOR_RETRY,
- State.UP_FOR_RESCHEDULE,
+ None,
+ TaskInstanceState.UP_FOR_RETRY,
+ TaskInstanceState.UP_FOR_RESCHEDULE,
# For normal scheduler/backfill cases
- State.QUEUED,
+ TaskInstanceState.QUEUED,
}
QUEUEABLE_STATES = {
- State.SCHEDULED,
+ TaskInstanceState.SCHEDULED,
}
BACKFILL_QUEUEABLE_STATES = {
# For cases like unit tests and run manually
- State.NONE,
- State.UP_FOR_RESCHEDULE,
- State.UP_FOR_RETRY,
+ None,
+ TaskInstanceState.UP_FOR_RESCHEDULE,
+ TaskInstanceState.UP_FOR_RETRY,
# For normal backfill cases
- State.SCHEDULED,
+ TaskInstanceState.SCHEDULED,
}
diff --git a/airflow/triggers/external_task.py b/airflow/triggers/external_task.py
index 547556e315..e739c7a7cb 100644
--- a/airflow/triggers/external_task.py
+++ b/airflow/triggers/external_task.py
@@ -27,6 +27,7 @@ from sqlalchemy.orm import Session
from airflow.models import DagRun, TaskInstance
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.state import DagRunState
class TaskStateTrigger(BaseTrigger):
@@ -110,7 +111,7 @@ class DagStateTrigger(BaseTrigger):
def __init__(
self,
dag_id: str,
- states: list[str],
+ states: list[DagRunState],
execution_dates: list[datetime.datetime],
poll_interval: float = 5.0,
):