This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a2ae2265ce Add JobState for job state constants (#32549)
a2ae2265ce is described below

commit a2ae2265ce960d65bc3c4bf805ee77954a1f895c
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Jul 12 19:03:02 2023 +0800

    Add JobState for job state constants (#32549)
    
    Also fixed a couple of cases where DagRunState and TaskInstanceState
    were incorrectly used for job state.
---
 airflow/cli/cli_config.py            | 17 ++++++++++++-----
 airflow/cli/commands/jobs_command.py |  4 ++--
 airflow/jobs/job.py                  | 16 ++++++++--------
 airflow/jobs/scheduler_job_runner.py | 10 +++++-----
 airflow/models/taskinstance.py       |  6 +++---
 airflow/utils/state.py               | 13 +++++++++++++
 6 files changed, 43 insertions(+), 23 deletions(-)

diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py
index 046cbaad14..92ba87c3de 100644
--- a/airflow/cli/cli_config.py
+++ b/airflow/cli/cli_config.py
@@ -37,7 +37,7 @@ from airflow.executors.executor_loader import ExecutorLoader
 from airflow.settings import _ENABLE_AIP_44
 from airflow.utils.cli import ColorMode
 from airflow.utils.module_loading import import_string
-from airflow.utils.state import DagRunState
+from airflow.utils.state import DagRunState, JobState
 from airflow.utils.timezone import parse as parsedate
 
 BUILD_DOCS = "BUILDING_AIRFLOW_DOCS" in os.environ
@@ -281,9 +281,9 @@ ARG_NO_BACKFILL = Arg(
     ("--no-backfill",), help="filter all the backfill dagruns given the dag 
id", action="store_true"
 )
 dagrun_states = tuple(state.value for state in DagRunState)
-ARG_STATE = Arg(
+ARG_DR_STATE = Arg(
     ("--state",),
-    help="Only list the dag runs corresponding to the state",
+    help="Only list the DAG runs corresponding to the state",
     metavar=", ".join(dagrun_states),
     choices=dagrun_states,
 )
@@ -291,6 +291,13 @@ ARG_STATE = Arg(
 # list_jobs
 ARG_DAG_ID_OPT = Arg(("-d", "--dag-id"), help="The id of the dag")
 ARG_LIMIT = Arg(("--limit",), help="Return a limited number of records")
+job_states = tuple(state.value for state in JobState)
+ARG_JOB_STATE = Arg(
+    ("--state",),
+    help="Only list the jobs corresponding to the state",
+    metavar=", ".join(job_states),
+    choices=job_states,
+)
 
 # next_execution
 ARG_NUM_EXECUTIONS = Arg(
@@ -1161,7 +1168,7 @@ DAGS_COMMANDS = (
         args=(
             ARG_DAG_ID_REQ_FLAG,
             ARG_NO_BACKFILL,
-            ARG_STATE,
+            ARG_DR_STATE,
             ARG_OUTPUT,
             ARG_VERBOSE,
             ARG_START_DATE,
@@ -1172,7 +1179,7 @@ DAGS_COMMANDS = (
         name="list-jobs",
         help="List the jobs",
         func=lazy_load_command("airflow.cli.commands.dag_command.dag_list_jobs"),
-        args=(ARG_DAG_ID_OPT, ARG_STATE, ARG_LIMIT, ARG_OUTPUT, ARG_VERBOSE),
+        args=(ARG_DAG_ID_OPT, ARG_JOB_STATE, ARG_LIMIT, ARG_OUTPUT, ARG_VERBOSE),
     ),
     ActionCommand(
         name="state",
diff --git a/airflow/cli/commands/jobs_command.py b/airflow/cli/commands/jobs_command.py
index 339896d877..bcdd6df475 100644
--- a/airflow/cli/commands/jobs_command.py
+++ b/airflow/cli/commands/jobs_command.py
@@ -22,7 +22,7 @@ from sqlalchemy.orm import Session
 from airflow.jobs.job import Job
 from airflow.utils.net import get_hostname
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State
+from airflow.utils.state import JobState
 
 
 @provide_session
@@ -33,7 +33,7 @@ def check(args, session: Session = NEW_SESSION) -> None:
     if args.hostname and args.local:
         raise SystemExit("You can't use --hostname and --local at the same time")
 
-    query = select(Job).where(Job.state == State.RUNNING).order_by(Job.latest_heartbeat.desc())
+    query = select(Job).where(Job.state == JobState.RUNNING).order_by(Job.latest_heartbeat.desc())
     if args.job_type:
         query = query.where(Job.job_type == args.job_type)
     if args.hostname:
diff --git a/airflow/jobs/job.py b/airflow/jobs/job.py
index 99394cbc45..add43e0aa5 100644
--- a/airflow/jobs/job.py
+++ b/airflow/jobs/job.py
@@ -40,7 +40,7 @@ from airflow.utils.net import get_hostname
 from airflow.utils.platform import getuser
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.sqlalchemy import UtcDateTime
-from airflow.utils.state import State
+from airflow.utils.state import JobState
 
 
 def _resolve_dagrun_model():
@@ -132,7 +132,7 @@ class Job(Base, LoggingMixin):
         else:
             health_check_threshold: int = self.heartrate * grace_multiplier
         return (
-            self.state == State.RUNNING
+            self.state == JobState.RUNNING
             and (timezone.utcnow() - self.latest_heartbeat).total_seconds() < health_check_threshold
         )
 
@@ -181,7 +181,7 @@ class Job(Base, LoggingMixin):
             session.merge(self)
             previous_heartbeat = self.latest_heartbeat
 
-            if self.state in State.terminating_states:
+            if self.state in (JobState.SHUTDOWN, JobState.RESTARTING):
                 # TODO: Make sure it is AIP-44 compliant
                 self.kill()
 
@@ -215,7 +215,7 @@ class Job(Base, LoggingMixin):
     def prepare_for_execution(self, session: Session = NEW_SESSION):
         """Prepares the job for execution."""
         Stats.incr(self.__class__.__name__.lower() + "_start", 1, 1)
-        self.state = State.RUNNING
+        self.state = JobState.RUNNING
         self.start_date = timezone.utcnow()
         session.add(self)
         session.commit()
@@ -251,7 +251,7 @@ def most_recent_job(job_type: str, session: Session = NEW_SESSION) -> Job | None
         .where(Job.job_type == job_type)
         .order_by(
             # Put "running" jobs at the front.
-            case({State.RUNNING: 0}, value=Job.state, else_=1),
+            case({JobState.RUNNING: 0}, value=Job.state, else_=1),
             Job.latest_heartbeat.desc(),
         )
         .limit(1)
@@ -308,12 +308,12 @@ def execute_job(job: Job | JobPydantic, execute_callable: Callable[[], int | Non
     try:
         ret = execute_callable()
         # In case of max runs or max duration
-        job.state = State.SUCCESS
+        job.state = JobState.SUCCESS
     except SystemExit:
         # In case of ^C or SIGTERM
-        job.state = State.SUCCESS
+        job.state = JobState.SUCCESS
     except Exception:
-        job.state = State.FAILED
+        job.state = JobState.FAILED
         raise
     return ret
 
diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index c0e6877076..3b399d75b5 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -74,7 +74,7 @@ from airflow.utils.sqlalchemy import (
     tuple_in_condition,
     with_row_locks,
 )
-from airflow.utils.state import DagRunState, State, TaskInstanceState
+from airflow.utils.state import DagRunState, JobState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
 
 if TYPE_CHECKING:
@@ -1586,10 +1586,10 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
                         update(Job)
                         .where(
                             Job.job_type == "SchedulerJob",
-                            Job.state == State.RUNNING,
+                            Job.state == JobState.RUNNING,
                             Job.latest_heartbeat < (timezone.utcnow() - 
timedelta(seconds=timeout)),
                         )
-                        .values(state=State.FAILED)
+                        .values(state=JobState.FAILED)
                     ).rowcount
 
                     if num_failed:
@@ -1605,7 +1605,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
                         # "or queued_by_job_id IS NONE") can go as soon as 
scheduler HA is
                         # released.
                         .outerjoin(TI.queued_by_job)
-                        .where(or_(TI.queued_by_job_id.is_(None), Job.state != State.RUNNING))
+                        .where(or_(TI.queued_by_job_id.is_(None), Job.state != JobState.RUNNING))
                         .join(TI.dag_run)
                         .where(
                             DagRun.run_type != DagRunType.BACKFILL_JOB,
@@ -1690,7 +1690,7 @@ class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin):
                     .where(TI.state == TaskInstanceState.RUNNING)
                     .where(
                         or_(
-                            Job.state != State.RUNNING,
+                            Job.state != JobState.RUNNING,
                             Job.latest_heartbeat < limit_dttm,
                         )
                     )
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 3ec15f449d..3e328ec208 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -57,6 +57,7 @@ from sqlalchemy import (
     inspect,
     or_,
     text,
+    update,
 )
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.ext.mutable import MutableDict
@@ -123,7 +124,7 @@ from airflow.utils.sqlalchemy import (
     tuple_in_condition,
     with_row_locks,
 )
-from airflow.utils.state import DagRunState, State, TaskInstanceState
+from airflow.utils.state import DagRunState, JobState, State, TaskInstanceState
 from airflow.utils.task_group import MappedTaskGroup
 from airflow.utils.timeout import timeout
 from airflow.utils.xcom import XCOM_RETURN_KEY
@@ -292,8 +293,7 @@ def clear_task_instances(
     if job_ids:
         from airflow.jobs.job import Job
 
-        for job in session.query(Job).filter(Job.id.in_(job_ids)).all():
-            job.state = TaskInstanceState.RESTARTING
+        session.execute(update(Job).where(Job.id.in_(job_ids)).values(state=JobState.RESTARTING))
 
     if activate_dag_runs is not None:
         warnings.warn(
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index fc74732acc..f18fd48c82 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -20,6 +20,19 @@ from __future__ import annotations
 from enum import Enum
 
 
+class JobState(str, Enum):
+    """All possible states that a Job can be in."""
+
+    RUNNING = "running"
+    SUCCESS = "success"
+    SHUTDOWN = "shutdown"
+    RESTARTING = "restarting"
+    FAILED = "failed"
+
+    def __str__(self) -> str:
+        return self.value
+
+
 class TaskInstanceState(str, Enum):
     """All possible states that a Task Instance can be in.
 

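For reference, a minimal sketch (not part of the commit itself) of the pattern
this change standardizes: comparing Job.state against the dedicated JobState
enum rather than the catch-all State. It assumes an environment with this
commit applied; the "running_jobs" name is illustrative only.

    from sqlalchemy import select

    from airflow.jobs.job import Job
    from airflow.utils.state import JobState  # added by this commit

    # JobState subclasses str, so comparing it to the string-typed
    # Job.state column works directly in SQLAlchemy expressions.
    running_jobs = select(Job).where(Job.state == JobState.RUNNING)

    # __str__ returns the raw value, so logging and string formatting
    # behave the same as with the old State constants.
    assert str(JobState.RUNNING) == "running"
    assert JobState.RUNNING == "running"

On the CLI side, "airflow dags list-jobs --state running" now validates
--state against job states rather than DAG run states.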