This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v3-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v3-0-test by this push:
     new acc2c97100a [v3-0-test] feat(task_instances): guard ti update state 
and set task to fail if exception encountered (#51295) (#51470)
acc2c97100a is described below

commit acc2c97100a4393a41a279b0047f9933a21e0381
Author: github-actions[bot] 
<41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Fri Jun 6 14:24:08 2025 +0530

    [v3-0-test] feat(task_instances): guard ti update state and set task to 
fail if exception encountered (#51295) (#51470)
    
    * feat(task_instances): guard ti update state and set to fail if exception 
encountered
    
    * feat(task_instances): catch mysql error and set the task to fail
    
    * test: remove unnecessary check
    
    * fix(task_instances): handle mysql error
    
    * refactor(task-instances): merge mysql logic back to the original private 
function
    (cherry picked from commit b5a3b4e7d02d3a286a7daf81c6e5e8d4855eb553)
    
    Co-authored-by: Wei Lee <[email protected]>
---
 .../execution_api/routes/task_instances.py         | 117 +++++++++++++--------
 .../versions/head/test_task_instances.py           |  43 +++++++-
 task-sdk/src/airflow/sdk/api/client.py             |   8 +-
 3 files changed, 119 insertions(+), 49 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 6c86ef38035..e22d0a5f34d 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -22,6 +22,7 @@ import itertools
 import json
 from collections import defaultdict
 from collections.abc import Iterator
+from datetime import datetime
 from typing import TYPE_CHECKING, Annotated, Any
 from uuid import UUID
 
@@ -55,7 +56,7 @@ from 
airflow.api_fastapi.execution_api.datamodels.taskinstance import (
     TITerminalStatePayload,
 )
 from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep
-from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException, 
TaskNotFound
+from airflow.exceptions import TaskNotFound
 from airflow.models.asset import AssetActive
 from airflow.models.dagbag import DagBag
 from airflow.models.dagrun import DagRun as DR
@@ -70,6 +71,8 @@ from airflow.utils import timezone
 from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
 
 if TYPE_CHECKING:
+    from sqlalchemy.sql.dml import Update
+
     from airflow.sdk.types import Operator
 
 
@@ -381,43 +384,74 @@ def ti_update_state(
 
     # We exclude_unset to avoid updating fields that are not set in the payload
     data = ti_patch_payload.model_dump(exclude={"task_outlets", 
"outlet_events"}, exclude_unset=True)
-
     query = update(TI).where(TI.id == ti_id_str).values(data)
 
-    if isinstance(ti_patch_payload, TITerminalStatePayload):
-        updated_state = ti_patch_payload.state
-        query = TI.duration_expression_update(ti_patch_payload.end_date, 
query, session.bind)
-        query = query.values(state=updated_state)
+    try:
+        query, updated_state = _create_ti_state_update_query_and_update_state(
+            ti_patch_payload=ti_patch_payload,
+            ti_id_str=ti_id_str,
+            session=session,
+            query=query,
+            updated_state=updated_state,
+            dag_id=dag_id,
+            dag_bag=dag_bag,
+        )
+    except Exception:
+        # Set a task to failed in case any unexpected exception happened 
during task state update
+        log.exception("Error updating Task Instance state to %s. Set the task 
to failed", updated_state)
+        ti = session.get(TI, ti_id_str)
+        query = TI.duration_expression_update(datetime.now(tz=timezone.utc), 
query, session.bind)
+        query = query.values(state=TaskInstanceState.FAILED)
+        _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, 
dag_bag=dag_bag)
 
-        if updated_state == TerminalTIState.FAILED:
-            ti = session.get(TI, ti_id_str)
-            ser_dag = dag_bag.get_dag(dag_id)
-            if ser_dag and getattr(ser_dag, "fail_fast", False):
-                task_dict = getattr(ser_dag, "task_dict")
-                task_teardown_map = {k: v.is_teardown for k, v in 
task_dict.items()}
-                _stop_remaining_tasks(task_instance=ti, 
task_teardown_map=task_teardown_map, session=session)
-
-    elif isinstance(ti_patch_payload, TIRetryStatePayload):
+    # TODO: Replace this with FastAPI's Custom Exception handling:
+    # 
https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
+    try:
+        result = session.execute(query)
+        log.info("Task instance state updated", new_state=updated_state, 
rows_affected=result.rowcount)
+    except SQLAlchemyError as e:
+        log.error("Error updating Task Instance state", error=str(e))
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 
detail="Database error occurred"
+        )
+
+
+def _handle_fail_fast_for_dag(ti: TI, dag_id: str, session: SessionDep, 
dag_bag: DagBagDep) -> None:
+    ser_dag = dag_bag.get_dag(dag_id)
+    if ser_dag and getattr(ser_dag, "fail_fast", False):
+        task_dict = getattr(ser_dag, "task_dict")
+        task_teardown_map = {k: v.is_teardown for k, v in task_dict.items()}
+        _stop_remaining_tasks(task_instance=ti, 
task_teardown_map=task_teardown_map, session=session)
+
+
+def _create_ti_state_update_query_and_update_state(
+    *,
+    ti_patch_payload: TIStateUpdate,
+    ti_id_str: str,
+    query: Update,
+    updated_state,
+    session: SessionDep,
+    dag_bag: DagBagDep,
+    dag_id: str,
+) -> tuple[Update, TaskInstanceState]:
+    if isinstance(ti_patch_payload, (TITerminalStatePayload, 
TIRetryStatePayload, TISuccessStatePayload)):
         ti = session.get(TI, ti_id_str)
         updated_state = ti_patch_payload.state
-        ti.prepare_db_for_next_try(session)
         query = TI.duration_expression_update(ti_patch_payload.end_date, 
query, session.bind)
         query = query.values(state=updated_state)
-    elif isinstance(ti_patch_payload, TISuccessStatePayload):
-        query = TI.duration_expression_update(ti_patch_payload.end_date, 
query, session.bind)
-        updated_state = ti_patch_payload.state
-        task_instance = session.get(TI, ti_id_str)
-        try:
+
+        if updated_state == TerminalTIState.FAILED:
+            # This is the only case needs extra handling for 
TITerminalStatePayload
+            _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, 
dag_bag=dag_bag)
+        elif isinstance(ti_patch_payload, TIRetryStatePayload):
+            ti.prepare_db_for_next_try(session)
+        elif isinstance(ti_patch_payload, TISuccessStatePayload):
             TI.register_asset_changes_in_db(
-                task_instance,
+                ti,
                 ti_patch_payload.task_outlets,  # type: ignore
                 ti_patch_payload.outlet_events,
                 session,
             )
-        except AirflowInactiveAssetInInletOrOutletException as err:
-            log.error("Asset registration failed due to conflicting asset: 
%s", err)
-
-        query = query.values(state=updated_state)
     elif isinstance(ti_patch_payload, TIDeferredStatePayload):
         # Calculate timeout if it was passed
         timeout = None
@@ -468,14 +502,17 @@ def ti_update_state(
             # As documented in 
https://dev.mysql.com/doc/refman/5.7/en/datetime.html.
             _MYSQL_TIMESTAMP_MAX = timezone.datetime(2038, 1, 19, 3, 14, 7)
             if ti_patch_payload.reschedule_date > _MYSQL_TIMESTAMP_MAX:
-                raise HTTPException(
-                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-                    detail={
-                        "reason": "invalid_reschedule_date",
-                        "message": f"Cannot reschedule to 
{ti_patch_payload.reschedule_date.isoformat()} "
-                        f"since it is over MySQL's TIMESTAMP storage limit.",
-                    },
+                # Set a task to failed in case any unexpected exception 
happened during task state update
+                log.exception(
+                    "Error updating Task Instance state to %s. Set the task to 
failed", updated_state
                 )
+                data = 
ti_patch_payload.model_dump(exclude={"reschedule_date"}, exclude_unset=True)
+                query = update(TI).where(TI.id == ti_id_str).values(data)
+                query = 
TI.duration_expression_update(datetime.now(tz=timezone.utc), query, 
session.bind)
+                query = query.values(state=TaskInstanceState.FAILED)
+                ti = session.get(TI, ti_id_str)
+                _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, 
session=session, dag_bag=dag_bag)
+                return query, updated_state
 
         task_instance = session.get(TI, ti_id_str)
         actual_start_date = timezone.utcnow()
@@ -494,16 +531,10 @@ def ti_update_state(
         # clear the next_method and next_kwargs so that none of the retries 
pick them up
         query = query.values(state=TaskInstanceState.UP_FOR_RESCHEDULE, 
next_method=None, next_kwargs=None)
         updated_state = TaskInstanceState.UP_FOR_RESCHEDULE
-    # TODO: Replace this with FastAPI's Custom Exception handling:
-    # 
https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
-    try:
-        result = session.execute(query)
-        log.info("Task instance state updated", new_state=updated_state, 
rows_affected=result.rowcount)
-    except SQLAlchemyError as e:
-        log.error("Error updating Task Instance state", error=str(e))
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 
detail="Database error occurred"
-        )
+    else:
+        raise ValueError(f"Unexpected Payload Type {type(ti_patch_payload)}")
+
+    return query, updated_state
 
 
 @ti_id_router.patch(
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 13f0b8df846..3c107f61863 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -705,6 +705,40 @@ class TestTIUpdateState:
         assert event[0].asset == AssetModel(name="my-task", 
uri="s3://bucket/my-task", extra={})
         assert event[0].extra == expected_extra
 
+    def test_ti_update_state_to_failed_with_inactive_asset(self, client, 
session, create_task_instance):
+        # inactive
+        asset = AssetModel(
+            id=1,
+            name="my-task-2",
+            uri="s3://bucket/my-task",
+            group="asset",
+            extra={},
+        )
+        session.add(asset)
+
+        ti = create_task_instance(
+            task_id="test_ti_update_state_to_success_with_asset_events",
+            start_date=DEFAULT_START_DATE,
+            state=State.RUNNING,
+        )
+        session.commit()
+
+        response = client.patch(
+            f"/execution/task-instances/{ti.id}/state",
+            json={
+                "state": "success",
+                "end_date": DEFAULT_END_DATE.isoformat(),
+                "task_outlets": [{"name": "my-task-2", "uri": 
"s3://bucket/my-task", "type": "Asset"}],
+                "outlet_events": [],
+            },
+        )
+
+        assert response.status_code == 204
+        session.expire_all()
+
+        ti = session.get(TaskInstance, ti.id)
+        assert ti.state == State.FAILED
+
     @pytest.mark.parametrize(
         "outlet_events, expected_extra",
         [
@@ -976,8 +1010,13 @@ class TestTIUpdateState:
             },
         )
 
-        assert response.status_code == 422
-        assert response.json()["detail"]["reason"] == "invalid_reschedule_date"
+        assert response.status_code == 204
+        assert response.text == ""
+
+        session.expire_all()
+
+        ti = session.get(TaskInstance, ti.id)
+        assert ti.state == State.FAILED
 
     def test_ti_update_state_handle_retry(self, client, session, 
create_task_instance):
         ti = create_task_instance(
diff --git a/task-sdk/src/airflow/sdk/api/client.py 
b/task-sdk/src/airflow/sdk/api/client.py
index 7ef0a68a6c0..89bee9807e1 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -174,10 +174,6 @@ class TaskInstanceOperations:
         )
         self.client.patch(f"task-instances/{id}/state", 
content=body.model_dump_json())
 
-    def heartbeat(self, id: uuid.UUID, pid: int):
-        body = TIHeartbeatInfo(pid=pid, hostname=get_hostname())
-        self.client.put(f"task-instances/{id}/heartbeat", 
content=body.model_dump_json())
-
     def defer(self, id: uuid.UUID, msg):
         """Tell the API server that this TI has been deferred."""
         body = TIDeferredStatePayload(**msg.model_dump(exclude_unset=True, 
exclude={"type"}))
@@ -192,6 +188,10 @@ class TaskInstanceOperations:
         # Create a reschedule state payload from msg
         self.client.patch(f"task-instances/{id}/state", 
content=body.model_dump_json())
 
+    def heartbeat(self, id: uuid.UUID, pid: int):
+        body = TIHeartbeatInfo(pid=pid, hostname=get_hostname())
+        self.client.put(f"task-instances/{id}/heartbeat", 
content=body.model_dump_json())
+
     def skip_downstream_tasks(self, id: uuid.UUID, msg: SkipDownstreamTasks):
         """Tell the API server to skip the downstream tasks of this TI."""
         body = TISkippedDownstreamTasksStatePayload(tasks=msg.tasks)

Reply via email to