This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new bafa9da6fc1 AIP-72: Use explicit TI State enums for TI update state endpoint (#44065)
bafa9da6fc1 is described below
commit bafa9da6fc1314bcc5ed64662167af97f0a79f18
Author: Kaxil Naik <[email protected]>
AuthorDate: Fri Nov 15 17:41:37 2024 +0000
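AIP-72: Use explicit TI State enums for TI update state endpoint (#44065)
Create explicit `TerminalTIState` and `IntermediateTIState` enums so that the
generated client no longer has auto-generated names like `State1`, and so that
state values are defined once and reused by `TaskInstanceState` instead of being
duplicated. Note that `REMOVED` is now classified as a terminal state, and
`State.ran_and_finished_states` is removed in favour of `TerminalTIState`.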
---
airflow/api_fastapi/execution_api/datamodels.py | 25 ++------
airflow/utils/state.py | 58 +++++++++++-------
task_sdk/src/airflow/sdk/api/client.py | 7 +--
.../src/airflow/sdk/api/datamodels/_generated.py | 68 +++++++++-------------
task_sdk/src/airflow/sdk/execution_time/comms.py | 4 +-
.../src/airflow/sdk/execution_time/supervisor.py | 6 +-
6 files changed, 78 insertions(+), 90 deletions(-)
diff --git a/airflow/api_fastapi/execution_api/datamodels.py
b/airflow/api_fastapi/execution_api/datamodels.py
index 32115c9ac5a..ec8be531e10 100644
--- a/airflow/api_fastapi/execution_api/datamodels.py
+++ b/airflow/api_fastapi/execution_api/datamodels.py
@@ -29,7 +29,7 @@ from pydantic import (
)
from airflow.api_fastapi.common.types import UtcDateTime
-from airflow.utils.state import State, TaskInstanceState as TIState
+from airflow.utils.state import IntermediateTIState, TaskInstanceState as
TIState, TerminalTIState
class TIEnterRunningPayload(BaseModel):
@@ -40,7 +40,7 @@ class TIEnterRunningPayload(BaseModel):
state: Annotated[
Literal[TIState.RUNNING],
# Specify a default in the schema, but not in code, so Pydantic marks
it as required.
- WithJsonSchema({"enum": [TIState.RUNNING], "default":
TIState.RUNNING}),
+ WithJsonSchema({"type": "string", "enum": [TIState.RUNNING],
"default": TIState.RUNNING}),
]
hostname: str
"""Hostname where this task has started"""
@@ -55,11 +55,7 @@ class TIEnterRunningPayload(BaseModel):
class TITerminalStatePayload(BaseModel):
"""Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or
FAILED)."""
- state: Annotated[
- Literal[TIState.SUCCESS, TIState.FAILED, TIState.SKIPPED],
- Field(title="TerminalState"),
- WithJsonSchema({"enum": list(State.ran_and_finished_states)}),
- ]
+ state: TerminalTIState
end_date: UtcDateTime
"""When the task completed executing"""
@@ -68,18 +64,7 @@ class TITerminalStatePayload(BaseModel):
class TITargetStatePayload(BaseModel):
"""Schema for updating TaskInstance to a target state, excluding terminal
and running states."""
- state: Annotated[
- TIState,
- # For the OpenAPI schema generation,
- # make sure we do not include RUNNING as a valid state here
- WithJsonSchema(
- {
- "enum": [
- state for state in TIState if state not in
(State.ran_and_finished_states | {State.NONE})
- ]
- }
- ),
- ]
+ state: IntermediateTIState
def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
@@ -97,7 +82,7 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) ->
str:
state = getattr(v, "state", None)
if state == TIState.RUNNING:
return str(state)
- elif state in State.ran_and_finished_states:
+ elif state in set(TerminalTIState):
return "_terminal_"
return "_other_"
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index 246c157611b..e4e2e9db8a5 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -32,6 +32,33 @@ class JobState(str, Enum):
return self.value
+class TerminalTIState(str, Enum):
+ """States that a Task Instance can be in that indicate it has reached a
terminal state."""
+
+ SUCCESS = "success"
+ FAILED = "failed"
+ SKIPPED = "skipped" # A user can raise an AirflowSkipException from a task,
and it will be marked as skipped
+ REMOVED = "removed"
+
+ def __str__(self) -> str:
+ return self.value
+
+
+class IntermediateTIState(str, Enum):
+ """States that a Task Instance can be in that indicate it is not yet in a
terminal or running state."""
+
+ SCHEDULED = "scheduled"
+ QUEUED = "queued"
+ RESTARTING = "restarting"
+ UP_FOR_RETRY = "up_for_retry"
+ UP_FOR_RESCHEDULE = "up_for_reschedule"
+ UPSTREAM_FAILED = "upstream_failed"
+ DEFERRED = "deferred"
+
+ def __str__(self) -> str:
+ return self.value
+
+
class TaskInstanceState(str, Enum):
"""
All possible states that a Task Instance can be in.
@@ -44,20 +71,20 @@ class TaskInstanceState(str, Enum):
# Use None instead if need this state.
# Set by the scheduler
- REMOVED = "removed" # Task vanished from DAG before it ran
- SCHEDULED = "scheduled" # Task should run and will be handed to executor
soon
+ REMOVED = TerminalTIState.REMOVED # Task vanished from DAG before it ran
+ SCHEDULED = IntermediateTIState.SCHEDULED # Task should run and will be
handed to executor soon
# Set by the task instance itself
- QUEUED = "queued" # Executor has enqueued the task
+ QUEUED = IntermediateTIState.QUEUED # Executor has enqueued the task
RUNNING = "running" # Task is executing
- SUCCESS = "success" # Task completed
- RESTARTING = "restarting" # External request to restart (e.g. cleared
when running)
- FAILED = "failed" # Task errored out
- UP_FOR_RETRY = "up_for_retry" # Task failed but has retries left
- UP_FOR_RESCHEDULE = "up_for_reschedule" # A waiting `reschedule` sensor
- UPSTREAM_FAILED = "upstream_failed" # One or more upstream deps failed
- SKIPPED = "skipped" # Skipped by branching or some other mechanism
- DEFERRED = "deferred" # Deferrable operator waiting on a trigger
+ SUCCESS = TerminalTIState.SUCCESS # Task completed
+ RESTARTING = IntermediateTIState.RESTARTING # External request to restart
(e.g. cleared when running)
+ FAILED = TerminalTIState.FAILED # Task errored out
+ UP_FOR_RETRY = IntermediateTIState.UP_FOR_RETRY # Task failed but has
retries left
+ UP_FOR_RESCHEDULE = IntermediateTIState.UP_FOR_RESCHEDULE # A waiting
`reschedule` sensor
+ UPSTREAM_FAILED = IntermediateTIState.UPSTREAM_FAILED # One or more
upstream deps failed
+ SKIPPED = TerminalTIState.SKIPPED # Skipped by branching or some other
mechanism
+ DEFERRED = IntermediateTIState.DEFERRED # Deferrable operator waiting on
a trigger
def __str__(self) -> str:
return self.value
@@ -199,12 +226,3 @@ class State:
A list of states indicating that a task can be adopted or reset by a
scheduler job
if it was queued by another scheduler job that is not running anymore.
"""
-
- ran_and_finished_states = frozenset(
- [TaskInstanceState.SUCCESS, TaskInstanceState.FAILED,
TaskInstanceState.SKIPPED]
- )
- """
- A list of states indicating that a task has run and finished. This
excludes states like
- removed and upstream_failed. Skipped is included because a user can raise a
- AirflowSkipException in a task and it will be marked as skipped.
- """
diff --git a/task_sdk/src/airflow/sdk/api/client.py
b/task_sdk/src/airflow/sdk/api/client.py
index 3706c737812..c740dceb010 100644
--- a/task_sdk/src/airflow/sdk/api/client.py
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -31,8 +31,7 @@ from uuid6 import uuid7
from airflow.sdk import __version__
from airflow.sdk.api.datamodels._generated import (
ConnectionResponse,
- State1 as TerminalState,
- TaskInstanceState,
+ TerminalTIState,
TIEnterRunningPayload,
TITerminalStatePayload,
ValidationError as RemoteValidationError,
@@ -100,9 +99,9 @@ class TaskInstanceOperations:
self.client.patch(f"task-instance/{id}/state",
content=body.model_dump_json())
- def finish(self, id: uuid.UUID, state: TaskInstanceState, when: datetime):
+ def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime):
"""Tell the API server that this TI has reached a terminal state."""
- body = TITerminalStatePayload(end_date=when,
state=TerminalState(state))
+ body = TITerminalStatePayload(end_date=when,
state=TerminalTIState(state))
self.client.patch(f"task-instance/{id}/state",
content=body.model_dump_json())
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index f41508cae2a..e921bee4bc2 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -43,6 +43,21 @@ class ConnectionResponse(BaseModel):
extra: Annotated[str | None, Field(title="Extra")] = None
+class IntermediateTIState(str, Enum):
+ """
+ States that a Task Instance can be in that indicate it is not yet in a
terminal or running state
+ """
+
+ REMOVED = "removed"
+ SCHEDULED = "scheduled"
+ QUEUED = "queued"
+ RESTARTING = "restarting"
+ UP_FOR_RETRY = "up_for_retry"
+ UP_FOR_RESCHEDULE = "up_for_reschedule"
+ UPSTREAM_FAILED = "upstream_failed"
+ DEFERRED = "deferred"
+
+
class TIEnterRunningPayload(BaseModel):
"""
Schema for updating TaskInstance to 'RUNNING' state with minimal required
fields.
@@ -64,60 +79,22 @@ class TIHeartbeatInfo(BaseModel):
pid: Annotated[int, Field(title="Pid")]
-class State(Enum):
- REMOVED = "removed"
- SCHEDULED = "scheduled"
- QUEUED = "queued"
- RUNNING = "running"
- RESTARTING = "restarting"
- UP_FOR_RETRY = "up_for_retry"
- UP_FOR_RESCHEDULE = "up_for_reschedule"
- UPSTREAM_FAILED = "upstream_failed"
- DEFERRED = "deferred"
-
-
class TITargetStatePayload(BaseModel):
"""
Schema for updating TaskInstance to a target state, excluding terminal and
running states.
"""
- state: State
-
+ state: IntermediateTIState
-class State1(Enum):
- FAILED = "failed"
- SUCCESS = "success"
- SKIPPED = "skipped"
-
-class TITerminalStatePayload(BaseModel):
+class TerminalTIState(str, Enum):
"""
- Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or
FAILED).
+ States that a Task Instance can be in that indicate it has reached a
terminal state
"""
- state: Annotated[State1, Field(title="TerminalState")]
- end_date: Annotated[datetime, Field(title="End Date")]
-
-
-class TaskInstanceState(str, Enum):
- """
- All possible states that a Task Instance can be in.
-
- Note that None is also allowed, so always use this in a type hint with
Optional.
- """
-
- REMOVED = "removed"
- SCHEDULED = "scheduled"
- QUEUED = "queued"
- RUNNING = "running"
SUCCESS = "success"
- RESTARTING = "restarting"
FAILED = "failed"
- UP_FOR_RETRY = "up_for_retry"
- UP_FOR_RESCHEDULE = "up_for_reschedule"
- UPSTREAM_FAILED = "upstream_failed"
SKIPPED = "skipped"
- DEFERRED = "deferred"
class ValidationError(BaseModel):
@@ -146,3 +123,12 @@ class XComResponse(BaseModel):
class HTTPValidationError(BaseModel):
detail: Annotated[list[ValidationError] | None, Field(title="Detail")] =
None
+
+
+class TITerminalStatePayload(BaseModel):
+ """
+ Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or
FAILED).
+ """
+
+ state: TerminalTIState
+ end_date: Annotated[datetime, Field(title="End Date")]
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py
b/task_sdk/src/airflow/sdk/execution_time/comms.py
index 3128e98bf43..a78fbb3e33b 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -47,7 +47,7 @@ from typing import Annotated, Any, Literal, Union
from pydantic import BaseModel, ConfigDict, Field
-from airflow.sdk.api.datamodels._generated import TaskInstanceState # noqa:
TCH001
+from airflow.sdk.api.datamodels._generated import TerminalTIState # noqa:
TCH001
from airflow.sdk.api.datamodels.ti import TaskInstance # noqa: TCH001
@@ -95,7 +95,7 @@ class TaskState(BaseModel):
- anything else = FAILED
"""
- state: TaskInstanceState
+ state: TerminalTIState
type: Literal["TaskState"] = "TaskState"
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 3c0623ba1b0..c05c6138f96 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -43,7 +43,7 @@ import structlog
from pydantic import TypeAdapter
from airflow.sdk.api.client import Client
-from airflow.sdk.api.datamodels._generated import TaskInstanceState
+from airflow.sdk.api.datamodels._generated import TerminalTIState
from airflow.sdk.execution_time.comms import (
ConnectionResponse,
GetConnection,
@@ -431,8 +431,8 @@ class WatchedSubprocess:
Not valid before the process has finished.
"""
if self._exit_code == 0:
- return self._terminal_state or TaskInstanceState.SUCCESS
- return TaskInstanceState.FAILED
+ return self._terminal_state or TerminalTIState.SUCCESS
+ return TerminalTIState.FAILED
def __rich_repr__(self):
yield "pid", self.pid