This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fd2687d0d5 Strong-type all single-state enum values (#32537)
fd2687d0d5 is described below

commit fd2687d0d51fa444f63aac705d843f30e5922319
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Jul 12 14:50:20 2023 +0800

    Strong-type all single-state enum values (#32537)
---
 airflow/jobs/backfill_job_runner.py    |  4 ++--
 airflow/jobs/scheduler_job_runner.py   |  4 ++--
 airflow/models/dagrun.py               |  2 +-
 airflow/models/pool.py                 |  8 ++++----
 airflow/models/taskinstance.py         |  4 ++--
 airflow/operators/trigger_dagrun.py    | 16 +++++++++++-----
 airflow/sensors/external_task.py       | 34 ++++++++++++++++++----------------
 airflow/ti_deps/dependencies_states.py | 30 +++++++++++++++---------------
 airflow/triggers/external_task.py      |  3 ++-
 9 files changed, 57 insertions(+), 48 deletions(-)

diff --git a/airflow/jobs/backfill_job_runner.py b/airflow/jobs/backfill_job_runner.py
index 35910b83ae..162f103642 100644
--- a/airflow/jobs/backfill_job_runner.py
+++ b/airflow/jobs/backfill_job_runner.py
@@ -696,7 +696,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin):
             _dag_runs = ti_status.active_runs[:]
             for run in _dag_runs:
                 run.update_state(session=session)
-                if run.state in State.finished:
+                if run.state in State.finished_dr_states:
                     ti_status.finished_runs += 1
                     ti_status.active_runs.remove(run)
                     executed_run_dates.append(run.execution_date)
@@ -824,7 +824,7 @@ class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin):
         """
         for dag_run in dag_runs:
             dag_run.update_state()
-            if dag_run.state not in State.finished:
+            if dag_run.state not in State.finished_dr_states:
                 dag_run.set_state(DagRunState.FAILED)
             session.merge(dag_run)
 
diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index 7b13ecc300..c0e6877076 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -625,7 +625,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
         """
         # actually enqueue them
         for ti in task_instances:
-            if ti.dag_run.state in State.finished:
+            if ti.dag_run.state in State.finished_dr_states:
                 ti.set_state(None, session=session)
                 continue
             command = ti.command_as_list(
@@ -1449,7 +1449,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
         # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else?
         schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False)
         # Check if DAG not scheduled then skip interval calculation to same scheduler runtime
-        if dag_run.state in State.finished:
+        if dag_run.state in State.finished_dr_states:
             # Work out if we should allow creating a new DagRun now?
             if self._should_update_dag_next_dagruns(dag, dag_model, session=session):
                 dag_model.calculate_dagrun_date_fields(dag, dag.get_run_data_interval(dag_run))
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index e3ed0bda00..2d3ff26d7f 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -270,7 +270,7 @@ class DagRun(Base, LoggingMixin):
             raise ValueError(f"invalid DagRun state: {state}")
         if self._state != state:
             self._state = state
-            self.end_date = timezone.utcnow() if self._state in State.finished else None
+            self.end_date = timezone.utcnow() if self._state in State.finished_dr_states else None
             if state == DagRunState.QUEUED:
                 self.queued_at = timezone.utcnow()
 
diff --git a/airflow/models/pool.py b/airflow/models/pool.py
index 83ec0368bd..0aab6fb634 100644
--- a/airflow/models/pool.py
+++ b/airflow/models/pool.py
@@ -175,12 +175,12 @@ class Pool(Base):
 
         state_count_by_pool = session.execute(
             select(TaskInstance.pool, TaskInstance.state, func.sum(TaskInstance.pool_slots))
-            .filter(TaskInstance.state.in_(list(EXECUTION_STATES)))
+            .filter(TaskInstance.state.in_(EXECUTION_STATES))
             .group_by(TaskInstance.pool, TaskInstance.state)
         )
 
         # calculate queued and running metrics
-        for (pool_name, state, count) in state_count_by_pool:
+        for pool_name, state, count in state_count_by_pool:
             # Some databases return decimal.Decimal here.
             count = int(count)
 
@@ -188,9 +188,9 @@ class Pool(Base):
             if not stats_dict:
                 continue
             # TypedDict key must be a string literal, so we use if-statements to set value
-            if state == "running":
+            if state == TaskInstanceState.RUNNING:
                 stats_dict["running"] = count
-            elif state == "queued":
+            elif state == TaskInstanceState.QUEUED:
                 stats_dict["queued"] = count
             else:
                 raise AirflowException(f"Unexpected state. Expected values: {EXECUTION_STATES}.")
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index ee1825c063..3ec15f449d 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -1348,7 +1348,7 @@ class TaskInstance(Base, LoggingMixin):
         self._try_number += 1
 
         if not test_mode:
-            session.add(Log(State.RUNNING, self))
+            session.add(Log(TaskInstanceState.RUNNING.value, self))
 
         self.state = TaskInstanceState.RUNNING
         self.emit_state_change_metric(TaskInstanceState.RUNNING)
@@ -1937,7 +1937,7 @@ class TaskInstance(Base, LoggingMixin):
         Stats.incr("ti_failures", tags=self.stats_tags)
 
         if not test_mode:
-            session.add(Log(State.FAILED, self))
+            session.add(Log(TaskInstanceState.FAILED.value, self))
 
             # Log failure duration
             session.add(TaskFail(ti=self))
diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py
index 548ef91894..5a7be1dcab 100644
--- a/airflow/operators/trigger_dagrun.py
+++ b/airflow/operators/trigger_dagrun.py
@@ -37,7 +37,7 @@ from airflow.utils import timezone
 from airflow.utils.context import Context
 from airflow.utils.helpers import build_airflow_url_with_query
 from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState
 from airflow.utils.types import DagRunType
 
 XCOM_EXECUTION_DATE_ISO = "trigger_execution_date_iso"
@@ -112,8 +112,8 @@ class TriggerDagRunOperator(BaseOperator):
         reset_dag_run: bool = False,
         wait_for_completion: bool = False,
         poke_interval: int = 60,
-        allowed_states: list | None = None,
-        failed_states: list | None = None,
+        allowed_states: list[str] | None = None,
+        failed_states: list[str] | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ) -> None:
@@ -124,8 +124,14 @@ class TriggerDagRunOperator(BaseOperator):
         self.reset_dag_run = reset_dag_run
         self.wait_for_completion = wait_for_completion
         self.poke_interval = poke_interval
-        self.allowed_states = allowed_states or [State.SUCCESS]
-        self.failed_states = failed_states or [State.FAILED]
+        if allowed_states:
+            self.allowed_states = [DagRunState(s) for s in allowed_states]
+        else:
+            self.allowed_states = [DagRunState.SUCCESS]
+        if failed_states:
+            self.failed_states = [DagRunState(s) for s in failed_states]
+        else:
+            self.failed_states = [DagRunState.FAILED]
         self._defer = deferrable
 
         if execution_date is not None and not isinstance(execution_date, (str, datetime.datetime)):
diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py
index bacafc979c..b5c1c1321f 100644
--- a/airflow/sensors/external_task.py
+++ b/airflow/sensors/external_task.py
@@ -37,7 +37,7 @@ from airflow.utils.file import correct_maybe_zipped
 from airflow.utils.helpers import build_airflow_url_with_query
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import tuple_in_condition
-from airflow.utils.state import State
+from airflow.utils.state import State, TaskInstanceState
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Query, Session
@@ -76,26 +76,28 @@ class ExternalTaskSensor(BaseSensorOperator):
     without also having to clear the sensor).
 
     By default, the ExternalTaskSensor will not skip if the external task skips.
-    To change this, simply set ``skipped_states=[State.SKIPPED]``. Note that if
-    you are monitoring multiple tasks, and one enters error state and the other
-    enters a skipped state, then the external task will react to whichever one
-    it sees first. If both happen together, then the failed state takes priority.
+    To change this, simply set ``skipped_states=[TaskInstanceState.SKIPPED]``.
+    Note that if you are monitoring multiple tasks, and one enters error state
+    and the other enters a skipped state, then the external task will react to
+    whichever one it sees first. If both happen together, then the failed state
+    takes priority.
 
     It is possible to alter the default behavior by setting states which
-    cause the sensor to fail, e.g. by setting ``allowed_states=[State.FAILED]``
-    and ``failed_states=[State.SUCCESS]`` you will flip the behaviour to get a
-    sensor which goes green when the external task *fails* and immediately goes
-    red if the external task *succeeds*!
+    cause the sensor to fail, e.g. by setting ``allowed_states=[DagRunState.FAILED]``
+    and ``failed_states=[DagRunState.SUCCESS]`` you will flip the behaviour to
+    get a sensor which goes green when the external task *fails* and immediately
+    goes red if the external task *succeeds*!
 
     Note that ``soft_fail`` is respected when examining the failed_states. Thus
     if the external task enters a failed state and ``soft_fail == True`` the
     sensor will _skip_ rather than fail. As a result, setting ``soft_fail=True``
-    and ``failed_states=[State.SKIPPED]`` will result in the sensor skipping if
-    the external task skips. However, this is a contrived example - consider
-    using ``skipped_states`` if you would like this behaviour. Using
-    ``skipped_states`` allows the sensor to skip if the target fails, but still
-    enter failed state on timeout. Using ``soft_fail == True`` as above will
-    cause the sensor to skip if the target fails, but also if it times out.
+    and ``failed_states=[DagRunState.SKIPPED]`` will result in the sensor
+    skipping if the external task skips. However, this is a contrived
+    example---consider using ``skipped_states`` if you would like this
+    behaviour. Using ``skipped_states`` allows the sensor to skip if the target
+    fails, but still enter failed state on timeout. Using ``soft_fail == True``
+    as above will cause the sensor to skip if the target fails, but also if it
+    times out.
 
     :param external_dag_id: The dag_id that contains the task you want to
         wait for. (templated)
@@ -146,7 +148,7 @@ class ExternalTaskSensor(BaseSensorOperator):
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.allowed_states = list(allowed_states) if allowed_states else [State.SUCCESS]
+        self.allowed_states = list(allowed_states) if allowed_states else [TaskInstanceState.SUCCESS.value]
         self.skipped_states = list(skipped_states) if skipped_states else []
         self.failed_states = list(failed_states) if failed_states else []
 
diff --git a/airflow/ti_deps/dependencies_states.py b/airflow/ti_deps/dependencies_states.py
index 543ce3528a..fd25d62f6d 100644
--- a/airflow/ti_deps/dependencies_states.py
+++ b/airflow/ti_deps/dependencies_states.py
@@ -16,38 +16,38 @@
 # under the License.
 from __future__ import annotations
 
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 EXECUTION_STATES = {
-    State.RUNNING,
-    State.QUEUED,
+    TaskInstanceState.RUNNING,
+    TaskInstanceState.QUEUED,
 }
 
 # In order to be able to get queued a task must have one of these states
 SCHEDULEABLE_STATES = {
-    State.NONE,
-    State.UP_FOR_RETRY,
-    State.UP_FOR_RESCHEDULE,
+    None,
+    TaskInstanceState.UP_FOR_RETRY,
+    TaskInstanceState.UP_FOR_RESCHEDULE,
 }
 
 RUNNABLE_STATES = {
     # For cases like unit tests and run manually
-    State.NONE,
-    State.UP_FOR_RETRY,
-    State.UP_FOR_RESCHEDULE,
+    None,
+    TaskInstanceState.UP_FOR_RETRY,
+    TaskInstanceState.UP_FOR_RESCHEDULE,
     # For normal scheduler/backfill cases
-    State.QUEUED,
+    TaskInstanceState.QUEUED,
 }
 
 QUEUEABLE_STATES = {
-    State.SCHEDULED,
+    TaskInstanceState.SCHEDULED,
 }
 
 BACKFILL_QUEUEABLE_STATES = {
     # For cases like unit tests and run manually
-    State.NONE,
-    State.UP_FOR_RESCHEDULE,
-    State.UP_FOR_RETRY,
+    None,
+    TaskInstanceState.UP_FOR_RESCHEDULE,
+    TaskInstanceState.UP_FOR_RETRY,
     # For normal backfill cases
-    State.SCHEDULED,
+    TaskInstanceState.SCHEDULED,
 }
diff --git a/airflow/triggers/external_task.py b/airflow/triggers/external_task.py
index 547556e315..e739c7a7cb 100644
--- a/airflow/triggers/external_task.py
+++ b/airflow/triggers/external_task.py
@@ -27,6 +27,7 @@ from sqlalchemy.orm import Session
 from airflow.models import DagRun, TaskInstance
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.state import DagRunState
 
 
 class TaskStateTrigger(BaseTrigger):
@@ -110,7 +111,7 @@ class DagStateTrigger(BaseTrigger):
     def __init__(
         self,
         dag_id: str,
-        states: list[str],
+        states: list[DagRunState],
         execution_dates: list[datetime.datetime],
         poll_interval: float = 5.0,
     ):

Reply via email to