This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3939d13224 AIP-72: Add "update TI state" endpoint for Execution API 
(#43602)
3939d13224 is described below

commit 3939d13224af28c855f4f79be53e0cae1c48026e
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Nov 5 13:23:23 2024 +0000

    AIP-72: Add "update TI state" endpoint for Execution API (#43602)
    
    Part of https://github.com/apache/airflow/issues/43586
    
    This PR adds a new endpoint `/execution/{task_instance_id}/state` that allows the worker to update the state of a TI.
    
    Some of the interesting changes / TILs (hat tip to @ashb for these):
    
    To streamline the data exchange between workers and the Task Execution API, this PR adds minified schemas for Task Instance updates, i.e. schemas that focus solely on the fields necessary for specific state transitions, reducing payload size and validation overhead. Since our TaskInstance model is huge, this also keeps the code clean by focusing only on the fields that matter for this case.
    
    The endpoint added in this PR also leverages Pydantic's [discriminated unions](https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions) to handle varying payload structures for each target state. This allows a single endpoint to receive different payloads, each with its own validations. For example:
    
    - `TIEnterRunningPayload`: Requires fields such as hostname, unixname, pid, and start_date to mark a task as RUNNING.
    - `TITerminalStatePayload`: Supports terminal states such as SUCCESS, FAILED, and SKIPPED.
    - `TITargetStatePayload`: Allows other non-terminal, non-running states that a task may transition to.
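A minimal, standalone sketch of the discriminated-union approach (simplified model names and fields for illustration, not the exact schemas in this PR):

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class EnterRunning(BaseModel):
    """Fields required to mark a task as running (simplified)."""

    state: Literal["running"]
    hostname: str
    pid: int


class TerminalState(BaseModel):
    """Fields required for a terminal state (simplified)."""

    state: Literal["success", "failed", "skipped"]
    end_date: str


def discriminate(v) -> str:
    # Route the payload to the right schema based on its "state" field.
    state = v.get("state") if isinstance(v, dict) else getattr(v, "state", None)
    return "running" if state == "running" else "_terminal_"


StateUpdate = TypeAdapter(
    Annotated[
        Union[
            Annotated[EnterRunning, Tag("running")],
            Annotated[TerminalState, Tag("_terminal_")],
        ],
        Discriminator(discriminate),
    ]
)

# The same entry point validates two structurally different payloads.
payload = StateUpdate.validate_python(
    {"state": "success", "end_date": "2024-10-31T12:00:00Z"}
)
print(type(payload).__name__)  # TerminalState
```

Because each member of the union has its own required fields, a payload that mixes fields from different transitions fails validation up front instead of reaching the database.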
    
    This is better because it rejects invalid payloads: for example, adding a start_date when a task is marked as SUCCESS makes no sense and likely indicates a client error!
    
    ![Nov-04-2024 
20-00-26](https://github.com/user-attachments/assets/07c1a197-0238-4c1a-9783-f23dd74a8d3e)
    
    `fastapi` allows importing a handy `status` module from starlette, whose constants carry both the status code and the reason in their names. Reference: https://fastapi.tiangolo.com/reference/status/
    Example:
    
    `status.HTTP_204_NO_CONTENT` and `status.HTTP_409_CONFLICT` explain a lot more than a bare "204", which tells the reader very little. I plan to change the current integer codes on the public API to these constants in the coming days.
    
    For now, I have assumed that we/the user don't care about `end_date` for the `REMOVED` and `UPSTREAM_FAILED` states, since those are handled by the scheduler and shouldn't even show up on the worker. The `SKIPPED` state has two scenarios: (1) a user runs the task and raises an `AirflowSkipException`, or (2) the task is skipped by the scheduler itself. For (1) we could set an end date, but (2) doesn't have one.
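A plain-Python sketch of the two SKIPPED paths described above (the helper name and payload shape are hypothetical, not Airflow's API):

```python
from datetime import datetime, timezone


def skipped_payload(*, raised_in_task: bool) -> dict:
    """Build a hypothetical state-update payload for a SKIPPED task.

    Case (1): the task raised AirflowSkipException while running, so the
    worker knows when it finished and can report an end_date.
    Case (2): the scheduler skipped the task before it ever ran, so
    there is no end_date to report.
    """
    payload = {"state": "skipped"}
    if raised_in_task:
        payload["end_date"] = datetime.now(timezone.utc).isoformat()
    return payload


print(skipped_payload(raised_in_task=False))  # {'state': 'skipped'}
```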
    
    - [ ] Pass a [RFC 9457](https://datatracker.ietf.org/doc/html/rfc9457) 
compliant error message in "detail" field of `HTTPException` to provide more 
information about the error
    - [ ] Add a separate heartbeat endpoint to track the TI’s active state.
    - [ ] Replace handling of `SQLAlchemyError` with FastAPI's [Custom 
Exception 
handling](https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers)
 across the Execution API endpoints. That way we don't need duplicate code 
across multiple endpoints.
    - [ ] Replace `None` state on TaskInstance with a `Created` state. 
([link](https://github.com/orgs/apache/projects/405/views/1?pane=issue&itemId=85900878))
    - [ ] Remove redundant code that also sets task type once we remove DB access from the worker. This assumes that the Webserver and the new FastAPI endpoints don't use this endpoint.
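As a rough illustration of the first TODO item, an RFC 9457 "problem details" body for the 409 case might look like the following (the `type` URI is hypothetical; field names follow the RFC):

```python
# Illustrative RFC 9457 problem-details body for the 409 conflict response.
problem = {
    "type": "https://airflow.apache.org/problems/invalid-state",  # hypothetical URI
    "title": "Conflict",
    "status": 409,
    "detail": "TI was not in a state where it could be marked as running",
}
print(problem["status"])  # 409
```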
---
 .../routes/health.py => common/types.py}           |  12 +-
 airflow/api_fastapi/execution_api/app.py           |   1 +
 .../api_fastapi/execution_api/routes/__init__.py   |   2 +
 airflow/api_fastapi/execution_api/routes/health.py |   2 +-
 .../execution_api/routes/task_instance.py          | 131 ++++++++++++++
 airflow/api_fastapi/execution_api/schemas.py       | 114 ++++++++++++
 airflow/models/taskinstance.py                     |  36 ++++
 airflow/utils/state.py                             |   9 +
 .../api_fastapi/execution_api/conftest.py          |  12 +-
 .../execution_api/routes/test_task_instance.py     | 194 +++++++++++++++++++++
 10 files changed, 499 insertions(+), 14 deletions(-)

diff --git a/airflow/api_fastapi/execution_api/routes/health.py 
b/airflow/api_fastapi/common/types.py
similarity index 74%
copy from airflow/api_fastapi/execution_api/routes/health.py
copy to airflow/api_fastapi/common/types.py
index 21ef586b8c..d9664c0722 100644
--- a/airflow/api_fastapi/execution_api/routes/health.py
+++ b/airflow/api_fastapi/common/types.py
@@ -14,14 +14,12 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
 from __future__ import annotations
 
-from airflow.api_fastapi.common.router import AirflowRouter
-
-health_router = AirflowRouter(tags=["Task SDK"])
+from pydantic import AfterValidator, AwareDatetime
+from typing_extensions import Annotated
 
+from airflow.utils import timezone
 
-@health_router.get("/health")
-async def health() -> dict:
-    return {"status": "healthy"}
+UtcDateTime = Annotated[AwareDatetime, AfterValidator(lambda d: 
d.astimezone(timezone.utc))]
+"""UTCDateTime is a datetime with timezone information"""
diff --git a/airflow/api_fastapi/execution_api/app.py 
b/airflow/api_fastapi/execution_api/app.py
index 771b81c43a..8f4cd3fd0a 100644
--- a/airflow/api_fastapi/execution_api/app.py
+++ b/airflow/api_fastapi/execution_api/app.py
@@ -24,6 +24,7 @@ def create_task_execution_api_app(app: FastAPI) -> FastAPI:
     """Create FastAPI app for task execution API."""
     from airflow.api_fastapi.execution_api.routes import execution_api_router
 
+    # TODO: Add versioning to the API
     task_exec_api_app = FastAPI(
         title="Airflow Task Execution API",
         description="The private Airflow Task Execution API.",
diff --git a/airflow/api_fastapi/execution_api/routes/__init__.py 
b/airflow/api_fastapi/execution_api/routes/__init__.py
index 3d8761caef..55ee56b616 100644
--- a/airflow/api_fastapi/execution_api/routes/__init__.py
+++ b/airflow/api_fastapi/execution_api/routes/__init__.py
@@ -18,6 +18,8 @@ from __future__ import annotations
 
 from airflow.api_fastapi.common.router import AirflowRouter
 from airflow.api_fastapi.execution_api.routes.health import health_router
+from airflow.api_fastapi.execution_api.routes.task_instance import ti_router
 
 execution_api_router = AirflowRouter()
 execution_api_router.include_router(health_router)
+execution_api_router.include_router(ti_router)
diff --git a/airflow/api_fastapi/execution_api/routes/health.py 
b/airflow/api_fastapi/execution_api/routes/health.py
index 21ef586b8c..e0d51e3c71 100644
--- a/airflow/api_fastapi/execution_api/routes/health.py
+++ b/airflow/api_fastapi/execution_api/routes/health.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 
 from airflow.api_fastapi.common.router import AirflowRouter
 
-health_router = AirflowRouter(tags=["Task SDK"])
+health_router = AirflowRouter(tags=["Health"])
 
 
 @health_router.get("/health")
diff --git a/airflow/api_fastapi/execution_api/routes/task_instance.py 
b/airflow/api_fastapi/execution_api/routes/task_instance.py
new file mode 100644
index 0000000000..05ce184964
--- /dev/null
+++ b/airflow/api_fastapi/execution_api/routes/task_instance.py
@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import logging
+from uuid import UUID
+
+from fastapi import Body, Depends, HTTPException, status
+from sqlalchemy import update
+from sqlalchemy.exc import NoResultFound, SQLAlchemyError
+from sqlalchemy.orm import Session
+from sqlalchemy.sql import select
+from typing_extensions import Annotated
+
+from airflow.api_fastapi.common.db.common import get_session
+from airflow.api_fastapi.common.router import AirflowRouter
+from airflow.api_fastapi.execution_api import schemas
+from airflow.models.taskinstance import TaskInstance as TI
+from airflow.utils.state import State
+
+# TODO: Add dependency on JWT token
+ti_router = AirflowRouter(
+    prefix="/task_instance",
+    tags=["Task Instance"],
+)
+
+
+log = logging.getLogger(__name__)
+
+
+@ti_router.patch(
+    "/{task_instance_id}/state",
+    status_code=status.HTTP_204_NO_CONTENT,
+    # TODO: Add Operation ID to control the function name in the OpenAPI spec
+    # TODO: Do we need to use create_openapi_http_exception_doc here?
+    responses={
+        status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
+        status.HTTP_409_CONFLICT: {"description": "The TI is already in the 
requested state"},
+        status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload 
for the state transition"},
+    },
+)
+async def ti_update_state(
+    task_instance_id: UUID,
+    ti_patch_payload: Annotated[schemas.TIStateUpdate, Body()],
+    session: Annotated[Session, Depends(get_session)],
+):
+    """
+    Update the state of a TaskInstance.
+
+    Not all state transitions are valid, and transitioning to some states requires extra information to be
+    passed along. (Check out the schemas for details; the rendered docs might not reflect this accurately.)
+    """
+    # We only use UUID above for validation purposes
+    ti_id_str = str(task_instance_id)
+
+    old = select(TI.state).where(TI.id == ti_id_str).with_for_update()
+    try:
+        (previous_state,) = session.execute(old).one()
+    except NoResultFound:
+        log.error("Task Instance %s not found", ti_id_str)
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail={
+                "reason": "not_found",
+                "message": "Task Instance not found",
+            },
+        )
+
+    # We exclude_unset to avoid updating fields that are not set in the payload
+    data = ti_patch_payload.model_dump(exclude_unset=True)
+
+    query = update(TI).where(TI.id == ti_id_str).values(data)
+
+    if isinstance(ti_patch_payload, schemas.TIEnterRunningPayload):
+        if previous_state != State.QUEUED:
+            log.warning(
+                "Can not start Task Instance ('%s') in invalid state: %s",
+                ti_id_str,
+                previous_state,
+            )
+
+            # TODO: Pass a RFC 9457 compliant error message in "detail" field
+            # https://datatracker.ietf.org/doc/html/rfc9457
+            # to provide more information about the error
+            # FastAPI will automatically convert this to a JSON response
+            # This might be added in FastAPI in 
https://github.com/fastapi/fastapi/issues/10370
+            raise HTTPException(
+                status_code=status.HTTP_409_CONFLICT,
+                detail={
+                    "reason": "invalid_state",
+                    "message": "TI was not in a state where it could be marked 
as running",
+                    "previous_state": previous_state,
+                },
+            )
+        log.info("Task with %s state started on %s ", previous_state, 
ti_patch_payload.hostname)
+        # Ensure there is no end date set.
+        query = query.values(
+            end_date=None,
+            hostname=ti_patch_payload.hostname,
+            unixname=ti_patch_payload.unixname,
+            pid=ti_patch_payload.pid,
+            state=State.RUNNING,
+        )
+    elif isinstance(ti_patch_payload, schemas.TITerminalStatePayload):
+        query = TI.duration_expression_update(ti_patch_payload.end_date, 
query, session.bind)
+
+    # TODO: Replace this with FastAPI's Custom Exception handling:
+    # 
https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
+    try:
+        result = session.execute(query)
+        log.info("TI %s state updated: %s row(s) affected", ti_id_str, 
result.rowcount)
+    except SQLAlchemyError as e:
+        log.error("Error updating Task Instance state: %s", e)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 
detail="Database error occurred"
+        )
diff --git a/airflow/api_fastapi/execution_api/schemas.py 
b/airflow/api_fastapi/execution_api/schemas.py
new file mode 100644
index 0000000000..3b60b109d9
--- /dev/null
+++ b/airflow/api_fastapi/execution_api/schemas.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import Annotated, Literal, Union
+
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Discriminator,
+    Field,
+    Tag,
+    WithJsonSchema,
+)
+
+from airflow.api_fastapi.common.types import UtcDateTime
+from airflow.utils.state import State, TaskInstanceState as TIState
+
+
+class TIEnterRunningPayload(BaseModel):
+    """Schema for updating TaskInstance to 'RUNNING' state with minimal 
required fields."""
+
+    model_config = ConfigDict(from_attributes=True)
+
+    state: Annotated[
+        Literal[TIState.RUNNING],
+        # Specify a default in the schema, but not in code, so Pydantic marks 
it as required.
+        WithJsonSchema({"enum": [TIState.RUNNING], "default": 
TIState.RUNNING}),
+    ]
+    hostname: str
+    """Hostname where this task has started"""
+    unixname: str
+    """Local username of the process where this task has started"""
+    pid: int
+    """Process Identifier on `hostname`"""
+    start_date: UtcDateTime
+    """When the task started executing"""
+
+
+class TITerminalStatePayload(BaseModel):
+    """Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or 
FAILED)."""
+
+    state: Annotated[
+        Literal[TIState.SUCCESS, TIState.FAILED, TIState.SKIPPED],
+        Field(title="TerminalState"),
+        WithJsonSchema({"enum": list(State.ran_and_finished_states)}),
+    ]
+
+    end_date: UtcDateTime
+    """When the task completed executing"""
+
+
+class TITargetStatePayload(BaseModel):
+    """Schema for updating TaskInstance to a target state, excluding terminal 
and running states."""
+
+    state: Annotated[
+        TIState,
+        # For the OpenAPI schema generation,
+        #   make sure we do not include RUNNING as a valid state here
+        WithJsonSchema(
+            {
+                "enum": [
+                    state for state in TIState if state not in 
(State.ran_and_finished_states | {State.NONE})
+                ]
+            }
+        ),
+    ]
+
+
+def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
+    """
+    Determine the discriminator key for TaskInstance state transitions.
+
+    This function serves as a discriminator for the TIStateUpdate union schema,
+    categorizing the payload based on the ``state`` attribute in the input 
data.
+    It returns a key that directs FastAPI to the appropriate subclass (schema)
+    based on the requested state.
+    """
+    if isinstance(v, dict):
+        state = v.get("state")
+    else:
+        state = getattr(v, "state", None)
+    if state == TIState.RUNNING:
+        return str(state)
+    elif state in State.ran_and_finished_states:
+        return "_terminal_"
+    return "_other_"
+
+
+# It is called "_terminal_" to avoid future conflicts if we added an actual 
state named "terminal"
+# and "_other_" is a catch-all for all other states that are not covered by 
the other schemas.
+TIStateUpdate = Annotated[
+    Union[
+        Annotated[TIEnterRunningPayload, Tag("running")],
+        Annotated[TITerminalStatePayload, Tag("_terminal_")],
+        Annotated[TITargetStatePayload, Tag("_other_")],
+    ],
+    Discriminator(ti_state_discriminator),
+]
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index dfd776e685..c525a40a14 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -54,6 +54,7 @@ from sqlalchemy import (
     UniqueConstraint,
     and_,
     delete,
+    extract,
     false,
     func,
     inspect,
@@ -151,7 +152,9 @@ if TYPE_CHECKING:
     from pathlib import PurePath
     from types import TracebackType
 
+    from sqlalchemy.engine import Connection as SAConnection, Engine
     from sqlalchemy.orm.session import Session
+    from sqlalchemy.sql import Update
     from sqlalchemy.sql.elements import BooleanClauseList
     from sqlalchemy.sql.expression import ColumnOperators
 
@@ -3843,6 +3846,39 @@ class TaskInstance(Base, LoggingMixin):
                 )
             )
 
+    @classmethod
+    def duration_expression_update(
+        cls, end_date: datetime, query: Update, bind: Engine | SAConnection
+    ) -> Update:
+        """Return a SQL expression for calculating the duration of this TI, 
based on the start and end date columns."""
+        # TODO: Compare it with self._set_duration method
+
+        if bind.dialect.name == "sqlite":
+            return query.values(
+                {
+                    "end_date": end_date,
+                    "duration": (func.julianday(end_date) - 
func.julianday(cls.start_date)) * 86400,
+                }
+            )
+        elif bind.dialect.name == "postgresql":
+            return query.values(
+                {
+                    "end_date": end_date,
+                    "duration": extract("EPOCH", end_date - cls.start_date),
+                }
+            )
+
+        return query.values(
+            {
+                "end_date": end_date,
+                "duration": (
+                    func.timestampdiff(text("MICROSECOND"), cls.start_date, 
end_date)
+                    # Turn microseconds into floating point seconds.
+                    / 1_000_000
+                ),
+            }
+        )
+
 
 def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> 
MappedTaskGroup | None:
     """Given two operators, find their innermost common mapped task group."""
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index 87ce20effc..246c157611 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -199,3 +199,12 @@ class State:
     A list of states indicating that a task can be adopted or reset by a 
scheduler job
     if it was queued by another scheduler job that is not running anymore.
     """
+
+    ran_and_finished_states = frozenset(
+        [TaskInstanceState.SUCCESS, TaskInstanceState.FAILED, 
TaskInstanceState.SKIPPED]
+    )
+    """
+    A list of states indicating that a task has run and finished. This 
excludes states like
+    removed and upstream_failed. Skipped is included because a user can raise an
+    AirflowSkipException in a task and it will be marked as skipped.
+    """
diff --git a/airflow/api_fastapi/execution_api/routes/health.py 
b/tests/api_fastapi/execution_api/conftest.py
similarity index 80%
copy from airflow/api_fastapi/execution_api/routes/health.py
copy to tests/api_fastapi/execution_api/conftest.py
index 21ef586b8c..784cb29249 100644
--- a/airflow/api_fastapi/execution_api/routes/health.py
+++ b/tests/api_fastapi/execution_api/conftest.py
@@ -14,14 +14,14 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
 from __future__ import annotations
 
-from airflow.api_fastapi.common.router import AirflowRouter
+import pytest
+from fastapi.testclient import TestClient
 
-health_router = AirflowRouter(tags=["Task SDK"])
+from airflow.api_fastapi.app import cached_app
 
 
-@health_router.get("/health")
-async def health() -> dict:
-    return {"status": "healthy"}
[email protected]
+def client():
+    return TestClient(cached_app(apps="execution"))
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instance.py 
b/tests/api_fastapi/execution_api/routes/test_task_instance.py
new file mode 100644
index 0000000000..602ed1fbd2
--- /dev/null
+++ b/tests/api_fastapi/execution_api/routes/test_task_instance.py
@@ -0,0 +1,194 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+from sqlalchemy import select
+from sqlalchemy.exc import SQLAlchemyError
+
+from airflow.models.taskinstance import TaskInstance
+from airflow.utils import timezone
+from airflow.utils.state import State
+
+from tests_common.test_utils.db import clear_db_runs
+
+pytestmark = pytest.mark.db_test
+
+
+DEFAULT_START_DATE = timezone.parse("2024-10-31T11:00:00Z")
+DEFAULT_END_DATE = timezone.parse("2024-10-31T12:00:00Z")
+
+
+class TestTIUpdateState:
+    def setup_method(self):
+        clear_db_runs()
+
+    def teardown_method(self):
+        clear_db_runs()
+
+    def test_ti_update_state_to_running(self, client, session, 
create_task_instance):
+        """
+        Test that the Task Instance state is updated to running when the Task 
Instance is in a state where it can be
+        marked as running.
+        """
+
+        ti = create_task_instance(
+            task_id="test_ti_update_state_to_running",
+            state=State.QUEUED,
+            session=session,
+        )
+
+        session.commit()
+
+        response = client.patch(
+            f"/execution/task_instance/{ti.id}/state",
+            json={
+                "state": "running",
+                "hostname": "random-hostname",
+                "unixname": "random-unixname",
+                "pid": 100,
+                "start_date": "2024-10-31T12:00:00Z",
+            },
+        )
+
+        assert response.status_code == 204
+        assert response.text == ""
+
+        # Refresh the Task Instance from the database so that we can check the 
updated values
+        session.refresh(ti)
+        assert ti.state == State.RUNNING
+        assert ti.hostname == "random-hostname"
+        assert ti.unixname == "random-unixname"
+        assert ti.pid == 100
+        assert ti.start_date.isoformat() == "2024-10-31T12:00:00+00:00"
+
+    def test_ti_update_state_conflict_if_not_queued(self, client, session, 
create_task_instance):
+        """
+        Test that a 409 error is returned when the Task Instance is not in a 
state where it can be marked as
+        running. In this case, the Task Instance is first in NONE state so it 
cannot be marked as running.
+        """
+        ti = create_task_instance(
+            task_id="test_ti_update_state_conflict_if_not_queued",
+            state=State.NONE,
+        )
+        session.commit()
+
+        response = client.patch(
+            f"/execution/task_instance/{ti.id}/state",
+            json={
+                "state": "running",
+                "hostname": "random-hostname",
+                "unixname": "random-unixname",
+                "pid": 100,
+                "start_date": "2024-10-31T12:00:00Z",
+            },
+        )
+
+        assert response.status_code == 409
+        assert response.json() == {
+            "detail": {
+                "message": "TI was not in a state where it could be marked as 
running",
+                "previous_state": State.NONE,
+                "reason": "invalid_state",
+            }
+        }
+
+        assert session.scalar(select(TaskInstance.state).where(TaskInstance.id 
== ti.id)) == State.NONE
+
+    @pytest.mark.parametrize(
+        ("state", "end_date", "expected_state"),
+        [
+            (State.SUCCESS, DEFAULT_END_DATE, State.SUCCESS),
+            (State.FAILED, DEFAULT_END_DATE, State.FAILED),
+            (State.SKIPPED, DEFAULT_END_DATE, State.SKIPPED),
+        ],
+    )
+    def test_ti_update_state_to_terminal(
+        self, client, session, create_task_instance, state, end_date, 
expected_state
+    ):
+        ti = create_task_instance(
+            task_id="test_ti_update_state_to_terminal",
+            start_date=DEFAULT_START_DATE,
+            state=State.RUNNING,
+        )
+        session.commit()
+
+        response = client.patch(
+            f"/execution/task_instance/{ti.id}/state",
+            json={
+                "state": state,
+                "end_date": end_date.isoformat(),
+            },
+        )
+
+        assert response.status_code == 204
+        assert response.text == ""
+
+        session.expire_all()
+
+        ti = session.get(TaskInstance, ti.id)
+        assert ti.state == expected_state
+        assert ti.end_date == end_date
+
+    def test_ti_update_state_not_found(self, client, session):
+        """
+        Test that a 404 error is returned when the Task Instance does not 
exist.
+        """
+        task_instance_id = "0182e924-0f1e-77e6-ab50-e977118bc139"
+
+        # Pre-condition: the Task Instance does not exist
+        assert session.scalar(select(TaskInstance.id).where(TaskInstance.id == 
task_instance_id)) is None
+
+        payload = {"state": "success", "end_date": "2024-10-31T12:30:00Z"}
+
+        response = 
client.patch(f"/execution/task_instance/{task_instance_id}/state", json=payload)
+        assert response.status_code == 404
+        assert response.json()["detail"] == {
+            "reason": "not_found",
+            "message": "Task Instance not found",
+        }
+
+    def test_ti_update_state_database_error(self, client, session, 
create_task_instance):
+        """
+        Test that a database error is handled correctly when updating the Task 
Instance state.
+        """
+        ti = create_task_instance(
+            task_id="test_ti_update_state_database_error",
+            state=State.QUEUED,
+        )
+        session.commit()
+        payload = {
+            "state": "running",
+            "hostname": "random-hostname",
+            "unixname": "random-unixname",
+            "pid": 100,
+            "start_date": "2024-10-31T12:00:00Z",
+        }
+
+        with mock.patch(
+            
"airflow.api_fastapi.execution_api.routes.task_instance.Session.execute",
+            side_effect=[
+                mock.Mock(one=lambda: ("queued",)),  # First call returns 
"queued"
+                SQLAlchemyError("Database error"),  # Second call raises an 
error
+            ],
+        ):
+            response = client.patch(f"/execution/task_instance/{ti.id}/state", 
json=payload)
+            assert response.status_code == 500
+            assert response.json()["detail"] == "Database error occurred"
