This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new bafa9da6fc1 AIP-72: Use explicit TI State enums for TI update state endpoint (#44065)
bafa9da6fc1 is described below

commit bafa9da6fc1314bcc5ed64662167af97f0a79f18
Author: Kaxil Naik <[email protected]>
AuthorDate: Fri Nov 15 17:41:37 2024 +0000

    AIP-72: Use explicit TI State enums for TI update state endpoint (#44065)
    
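    Created explicit enums so that the generated client doesn't get
    auto-numbered model names like `State1`, and to reduce duplication of
    state values. Note that `removed` is now part of `TerminalTIState`, so it is
    routed to the terminal-state payload (the dropped
    `State.ran_and_finished_states` set excluded it).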
---
 airflow/api_fastapi/execution_api/datamodels.py    | 25 ++------
 airflow/utils/state.py                             | 58 +++++++++++-------
 task_sdk/src/airflow/sdk/api/client.py             |  7 +--
 .../src/airflow/sdk/api/datamodels/_generated.py   | 68 +++++++++-------------
 task_sdk/src/airflow/sdk/execution_time/comms.py   |  4 +-
 .../src/airflow/sdk/execution_time/supervisor.py   |  6 +-
 6 files changed, 78 insertions(+), 90 deletions(-)

diff --git a/airflow/api_fastapi/execution_api/datamodels.py 
b/airflow/api_fastapi/execution_api/datamodels.py
index 32115c9ac5a..ec8be531e10 100644
--- a/airflow/api_fastapi/execution_api/datamodels.py
+++ b/airflow/api_fastapi/execution_api/datamodels.py
@@ -29,7 +29,7 @@ from pydantic import (
 )
 
 from airflow.api_fastapi.common.types import UtcDateTime
-from airflow.utils.state import State, TaskInstanceState as TIState
+from airflow.utils.state import IntermediateTIState, TaskInstanceState as 
TIState, TerminalTIState
 
 
 class TIEnterRunningPayload(BaseModel):
@@ -40,7 +40,7 @@ class TIEnterRunningPayload(BaseModel):
     state: Annotated[
         Literal[TIState.RUNNING],
         # Specify a default in the schema, but not in code, so Pydantic marks 
it as required.
-        WithJsonSchema({"enum": [TIState.RUNNING], "default": 
TIState.RUNNING}),
+        WithJsonSchema({"type": "string", "enum": [TIState.RUNNING], 
"default": TIState.RUNNING}),
     ]
     hostname: str
     """Hostname where this task has started"""
@@ -55,11 +55,7 @@ class TIEnterRunningPayload(BaseModel):
 class TITerminalStatePayload(BaseModel):
     """Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or 
FAILED)."""
 
-    state: Annotated[
-        Literal[TIState.SUCCESS, TIState.FAILED, TIState.SKIPPED],
-        Field(title="TerminalState"),
-        WithJsonSchema({"enum": list(State.ran_and_finished_states)}),
-    ]
+    state: TerminalTIState
 
     end_date: UtcDateTime
     """When the task completed executing"""
@@ -68,18 +64,7 @@ class TITerminalStatePayload(BaseModel):
 class TITargetStatePayload(BaseModel):
     """Schema for updating TaskInstance to a target state, excluding terminal 
and running states."""
 
-    state: Annotated[
-        TIState,
-        # For the OpenAPI schema generation,
-        #   make sure we do not include RUNNING as a valid state here
-        WithJsonSchema(
-            {
-                "enum": [
-                    state for state in TIState if state not in 
(State.ran_and_finished_states | {State.NONE})
-                ]
-            }
-        ),
-    ]
+    state: IntermediateTIState
 
 
 def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
@@ -97,7 +82,7 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> 
str:
         state = getattr(v, "state", None)
     if state == TIState.RUNNING:
         return str(state)
-    elif state in State.ran_and_finished_states:
+    elif state in set(TerminalTIState):
         return "_terminal_"
     return "_other_"
 
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index 246c157611b..e4e2e9db8a5 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -32,6 +32,33 @@ class JobState(str, Enum):
         return self.value
 
 
+class TerminalTIState(str, Enum):
+    """States that a Task Instance can be in that indicate it has reached a 
terminal state."""
+
+    SUCCESS = "success"
+    FAILED = "failed"
+    SKIPPED = "skipped"  # A user can raise a AirflowSkipException from a task 
& it will be marked as skipped
+    REMOVED = "removed"
+
+    def __str__(self) -> str:
+        return self.value
+
+
+class IntermediateTIState(str, Enum):
+    """States that a Task Instance can be in that indicate it is not yet in a 
terminal or running state."""
+
+    SCHEDULED = "scheduled"
+    QUEUED = "queued"
+    RESTARTING = "restarting"
+    UP_FOR_RETRY = "up_for_retry"
+    UP_FOR_RESCHEDULE = "up_for_reschedule"
+    UPSTREAM_FAILED = "upstream_failed"
+    DEFERRED = "deferred"
+
+    def __str__(self) -> str:
+        return self.value
+
+
 class TaskInstanceState(str, Enum):
     """
     All possible states that a Task Instance can be in.
@@ -44,20 +71,20 @@ class TaskInstanceState(str, Enum):
     # Use None instead if need this state.
 
     # Set by the scheduler
-    REMOVED = "removed"  # Task vanished from DAG before it ran
-    SCHEDULED = "scheduled"  # Task should run and will be handed to executor 
soon
+    REMOVED = TerminalTIState.REMOVED  # Task vanished from DAG before it ran
+    SCHEDULED = IntermediateTIState.SCHEDULED  # Task should run and will be 
handed to executor soon
 
     # Set by the task instance itself
-    QUEUED = "queued"  # Executor has enqueued the task
+    QUEUED = IntermediateTIState.QUEUED  # Executor has enqueued the task
     RUNNING = "running"  # Task is executing
-    SUCCESS = "success"  # Task completed
-    RESTARTING = "restarting"  # External request to restart (e.g. cleared 
when running)
-    FAILED = "failed"  # Task errored out
-    UP_FOR_RETRY = "up_for_retry"  # Task failed but has retries left
-    UP_FOR_RESCHEDULE = "up_for_reschedule"  # A waiting `reschedule` sensor
-    UPSTREAM_FAILED = "upstream_failed"  # One or more upstream deps failed
-    SKIPPED = "skipped"  # Skipped by branching or some other mechanism
-    DEFERRED = "deferred"  # Deferrable operator waiting on a trigger
+    SUCCESS = TerminalTIState.SUCCESS  # Task completed
+    RESTARTING = IntermediateTIState.RESTARTING  # External request to restart 
(e.g. cleared when running)
+    FAILED = TerminalTIState.FAILED  # Task errored out
+    UP_FOR_RETRY = IntermediateTIState.UP_FOR_RETRY  # Task failed but has 
retries left
+    UP_FOR_RESCHEDULE = IntermediateTIState.UP_FOR_RESCHEDULE  # A waiting 
`reschedule` sensor
+    UPSTREAM_FAILED = IntermediateTIState.UPSTREAM_FAILED  # One or more 
upstream deps failed
+    SKIPPED = TerminalTIState.SKIPPED  # Skipped by branching or some other 
mechanism
+    DEFERRED = IntermediateTIState.DEFERRED  # Deferrable operator waiting on 
a trigger
 
     def __str__(self) -> str:
         return self.value
@@ -199,12 +226,3 @@ class State:
     A list of states indicating that a task can be adopted or reset by a 
scheduler job
     if it was queued by another scheduler job that is not running anymore.
     """
-
-    ran_and_finished_states = frozenset(
-        [TaskInstanceState.SUCCESS, TaskInstanceState.FAILED, 
TaskInstanceState.SKIPPED]
-    )
-    """
-    A list of states indicating that a task has run and finished. This 
excludes states like
-    removed and upstream_failed. Skipped is included because a user can raise a
-    AirflowSkipException in a task and it will be marked as skipped.
-    """
diff --git a/task_sdk/src/airflow/sdk/api/client.py 
b/task_sdk/src/airflow/sdk/api/client.py
index 3706c737812..c740dceb010 100644
--- a/task_sdk/src/airflow/sdk/api/client.py
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -31,8 +31,7 @@ from uuid6 import uuid7
 from airflow.sdk import __version__
 from airflow.sdk.api.datamodels._generated import (
     ConnectionResponse,
-    State1 as TerminalState,
-    TaskInstanceState,
+    TerminalTIState,
     TIEnterRunningPayload,
     TITerminalStatePayload,
     ValidationError as RemoteValidationError,
@@ -100,9 +99,9 @@ class TaskInstanceOperations:
 
         self.client.patch(f"task-instance/{id}/state", 
content=body.model_dump_json())
 
-    def finish(self, id: uuid.UUID, state: TaskInstanceState, when: datetime):
+    def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime):
         """Tell the API server that this TI has reached a terminal state."""
-        body = TITerminalStatePayload(end_date=when, 
state=TerminalState(state))
+        body = TITerminalStatePayload(end_date=when, 
state=TerminalTIState(state))
 
         self.client.patch(f"task-instance/{id}/state", 
content=body.model_dump_json())
 
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py 
b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index f41508cae2a..e921bee4bc2 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -43,6 +43,21 @@ class ConnectionResponse(BaseModel):
     extra: Annotated[str | None, Field(title="Extra")] = None
 
 
+class IntermediateTIState(str, Enum):
+    """
+    States that a Task Instance can be in that indicate it is not yet in a 
terminal or running state
+    """
+
+    REMOVED = "removed"
+    SCHEDULED = "scheduled"
+    QUEUED = "queued"
+    RESTARTING = "restarting"
+    UP_FOR_RETRY = "up_for_retry"
+    UP_FOR_RESCHEDULE = "up_for_reschedule"
+    UPSTREAM_FAILED = "upstream_failed"
+    DEFERRED = "deferred"
+
+
 class TIEnterRunningPayload(BaseModel):
     """
     Schema for updating TaskInstance to 'RUNNING' state with minimal required 
fields.
@@ -64,60 +79,22 @@ class TIHeartbeatInfo(BaseModel):
     pid: Annotated[int, Field(title="Pid")]
 
 
-class State(Enum):
-    REMOVED = "removed"
-    SCHEDULED = "scheduled"
-    QUEUED = "queued"
-    RUNNING = "running"
-    RESTARTING = "restarting"
-    UP_FOR_RETRY = "up_for_retry"
-    UP_FOR_RESCHEDULE = "up_for_reschedule"
-    UPSTREAM_FAILED = "upstream_failed"
-    DEFERRED = "deferred"
-
-
 class TITargetStatePayload(BaseModel):
     """
     Schema for updating TaskInstance to a target state, excluding terminal and 
running states.
     """
 
-    state: State
-
+    state: IntermediateTIState
 
-class State1(Enum):
-    FAILED = "failed"
-    SUCCESS = "success"
-    SKIPPED = "skipped"
 
-
-class TITerminalStatePayload(BaseModel):
+class TerminalTIState(str, Enum):
     """
-    Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or 
FAILED).
+    States that a Task Instance can be in that indicate it has reached a 
terminal state
     """
 
-    state: Annotated[State1, Field(title="TerminalState")]
-    end_date: Annotated[datetime, Field(title="End Date")]
-
-
-class TaskInstanceState(str, Enum):
-    """
-    All possible states that a Task Instance can be in.
-
-    Note that None is also allowed, so always use this in a type hint with 
Optional.
-    """
-
-    REMOVED = "removed"
-    SCHEDULED = "scheduled"
-    QUEUED = "queued"
-    RUNNING = "running"
     SUCCESS = "success"
-    RESTARTING = "restarting"
     FAILED = "failed"
-    UP_FOR_RETRY = "up_for_retry"
-    UP_FOR_RESCHEDULE = "up_for_reschedule"
-    UPSTREAM_FAILED = "upstream_failed"
     SKIPPED = "skipped"
-    DEFERRED = "deferred"
 
 
 class ValidationError(BaseModel):
@@ -146,3 +123,12 @@ class XComResponse(BaseModel):
 
 class HTTPValidationError(BaseModel):
     detail: Annotated[list[ValidationError] | None, Field(title="Detail")] = 
None
+
+
+class TITerminalStatePayload(BaseModel):
+    """
+    Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or 
FAILED).
+    """
+
+    state: TerminalTIState
+    end_date: Annotated[datetime, Field(title="End Date")]
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py 
b/task_sdk/src/airflow/sdk/execution_time/comms.py
index 3128e98bf43..a78fbb3e33b 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -47,7 +47,7 @@ from typing import Annotated, Any, Literal, Union
 
 from pydantic import BaseModel, ConfigDict, Field
 
-from airflow.sdk.api.datamodels._generated import TaskInstanceState  # noqa: 
TCH001
+from airflow.sdk.api.datamodels._generated import TerminalTIState  # noqa: 
TCH001
 from airflow.sdk.api.datamodels.ti import TaskInstance  # noqa: TCH001
 
 
@@ -95,7 +95,7 @@ class TaskState(BaseModel):
     - anything else = FAILED
     """
 
-    state: TaskInstanceState
+    state: TerminalTIState
     type: Literal["TaskState"] = "TaskState"
 
 
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 3c0623ba1b0..c05c6138f96 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -43,7 +43,7 @@ import structlog
 from pydantic import TypeAdapter
 
 from airflow.sdk.api.client import Client
-from airflow.sdk.api.datamodels._generated import TaskInstanceState
+from airflow.sdk.api.datamodels._generated import TerminalTIState
 from airflow.sdk.execution_time.comms import (
     ConnectionResponse,
     GetConnection,
@@ -431,8 +431,8 @@ class WatchedSubprocess:
         Not valid before the process has finished.
         """
         if self._exit_code == 0:
-            return self._terminal_state or TaskInstanceState.SUCCESS
-        return TaskInstanceState.FAILED
+            return self._terminal_state or TerminalTIState.SUCCESS
+        return TerminalTIState.FAILED
 
     def __rich_repr__(self):
         yield "pid", self.pid

Reply via email to