This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v3-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v3-0-test by this push:
new acc2c97100a [v3-0-test] feat(task_instances): guard ti update state
and set task to fail if exception encountered (#51295) (#51470)
acc2c97100a is described below
commit acc2c97100a4393a41a279b0047f9933a21e0381
Author: github-actions[bot]
<41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Fri Jun 6 14:24:08 2025 +0530
[v3-0-test] feat(task_instances): guard ti update state and set task to
fail if exception encountered (#51295) (#51470)
* feat(task_instances): guard ti update state and set to fail if exception
encountered
* feat(task_instances): catch mysql error and set the task to fail
* test: remove unnecessary check
* fix(task_instances): handle mysql error
* refactor(task-instances): merge mysql logic back to the original private
function
(cherry picked from commit b5a3b4e7d02d3a286a7daf81c6e5e8d4855eb553)
Co-authored-by: Wei Lee <[email protected]>
---
.../execution_api/routes/task_instances.py | 117 +++++++++++++--------
.../versions/head/test_task_instances.py | 43 +++++++-
task-sdk/src/airflow/sdk/api/client.py | 8 +-
3 files changed, 119 insertions(+), 49 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 6c86ef38035..e22d0a5f34d 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -22,6 +22,7 @@ import itertools
import json
from collections import defaultdict
from collections.abc import Iterator
+from datetime import datetime
from typing import TYPE_CHECKING, Annotated, Any
from uuid import UUID
@@ -55,7 +56,7 @@ from
airflow.api_fastapi.execution_api.datamodels.taskinstance import (
TITerminalStatePayload,
)
from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep
-from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException,
TaskNotFound
+from airflow.exceptions import TaskNotFound
from airflow.models.asset import AssetActive
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun as DR
@@ -70,6 +71,8 @@ from airflow.utils import timezone
from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
if TYPE_CHECKING:
+ from sqlalchemy.sql.dml import Update
+
from airflow.sdk.types import Operator
@@ -381,43 +384,74 @@ def ti_update_state(
# We exclude_unset to avoid updating fields that are not set in the payload
data = ti_patch_payload.model_dump(exclude={"task_outlets",
"outlet_events"}, exclude_unset=True)
-
query = update(TI).where(TI.id == ti_id_str).values(data)
- if isinstance(ti_patch_payload, TITerminalStatePayload):
- updated_state = ti_patch_payload.state
- query = TI.duration_expression_update(ti_patch_payload.end_date,
query, session.bind)
- query = query.values(state=updated_state)
+ try:
+ query, updated_state = _create_ti_state_update_query_and_update_state(
+ ti_patch_payload=ti_patch_payload,
+ ti_id_str=ti_id_str,
+ session=session,
+ query=query,
+ updated_state=updated_state,
+ dag_id=dag_id,
+ dag_bag=dag_bag,
+ )
+ except Exception:
+ # Set a task to failed in case any unexpected exception happened
during task state update
+ log.exception("Error updating Task Instance state to %s. Set the task
to failed", updated_state)
+ ti = session.get(TI, ti_id_str)
+ query = TI.duration_expression_update(datetime.now(tz=timezone.utc),
query, session.bind)
+ query = query.values(state=TaskInstanceState.FAILED)
+ _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session,
dag_bag=dag_bag)
- if updated_state == TerminalTIState.FAILED:
- ti = session.get(TI, ti_id_str)
- ser_dag = dag_bag.get_dag(dag_id)
- if ser_dag and getattr(ser_dag, "fail_fast", False):
- task_dict = getattr(ser_dag, "task_dict")
- task_teardown_map = {k: v.is_teardown for k, v in
task_dict.items()}
- _stop_remaining_tasks(task_instance=ti,
task_teardown_map=task_teardown_map, session=session)
-
- elif isinstance(ti_patch_payload, TIRetryStatePayload):
+ # TODO: Replace this with FastAPI's Custom Exception handling:
+ #
https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
+ try:
+ result = session.execute(query)
+ log.info("Task instance state updated", new_state=updated_state,
rows_affected=result.rowcount)
+ except SQLAlchemyError as e:
+ log.error("Error updating Task Instance state", error=str(e))
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Database error occurred"
+ )
+
+
+def _handle_fail_fast_for_dag(ti: TI, dag_id: str, session: SessionDep,
dag_bag: DagBagDep) -> None:
+ ser_dag = dag_bag.get_dag(dag_id)
+ if ser_dag and getattr(ser_dag, "fail_fast", False):
+ task_dict = getattr(ser_dag, "task_dict")
+ task_teardown_map = {k: v.is_teardown for k, v in task_dict.items()}
+ _stop_remaining_tasks(task_instance=ti,
task_teardown_map=task_teardown_map, session=session)
+
+
+def _create_ti_state_update_query_and_update_state(
+ *,
+ ti_patch_payload: TIStateUpdate,
+ ti_id_str: str,
+ query: Update,
+ updated_state,
+ session: SessionDep,
+ dag_bag: DagBagDep,
+ dag_id: str,
+) -> tuple[Update, TaskInstanceState]:
+ if isinstance(ti_patch_payload, (TITerminalStatePayload,
TIRetryStatePayload, TISuccessStatePayload)):
ti = session.get(TI, ti_id_str)
updated_state = ti_patch_payload.state
- ti.prepare_db_for_next_try(session)
query = TI.duration_expression_update(ti_patch_payload.end_date,
query, session.bind)
query = query.values(state=updated_state)
- elif isinstance(ti_patch_payload, TISuccessStatePayload):
- query = TI.duration_expression_update(ti_patch_payload.end_date,
query, session.bind)
- updated_state = ti_patch_payload.state
- task_instance = session.get(TI, ti_id_str)
- try:
+
+ if updated_state == TerminalTIState.FAILED:
+ # This is the only case needs extra handling for
TITerminalStatePayload
+ _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session,
dag_bag=dag_bag)
+ elif isinstance(ti_patch_payload, TIRetryStatePayload):
+ ti.prepare_db_for_next_try(session)
+ elif isinstance(ti_patch_payload, TISuccessStatePayload):
TI.register_asset_changes_in_db(
- task_instance,
+ ti,
ti_patch_payload.task_outlets, # type: ignore
ti_patch_payload.outlet_events,
session,
)
- except AirflowInactiveAssetInInletOrOutletException as err:
- log.error("Asset registration failed due to conflicting asset:
%s", err)
-
- query = query.values(state=updated_state)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
@@ -468,14 +502,17 @@ def ti_update_state(
# As documented in
https://dev.mysql.com/doc/refman/5.7/en/datetime.html.
_MYSQL_TIMESTAMP_MAX = timezone.datetime(2038, 1, 19, 3, 14, 7)
if ti_patch_payload.reschedule_date > _MYSQL_TIMESTAMP_MAX:
- raise HTTPException(
- status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- detail={
- "reason": "invalid_reschedule_date",
- "message": f"Cannot reschedule to
{ti_patch_payload.reschedule_date.isoformat()} "
- f"since it is over MySQL's TIMESTAMP storage limit.",
- },
+ # Set a task to failed in case any unexpected exception
happened during task state update
+ log.exception(
+ "Error updating Task Instance state to %s. Set the task to
failed", updated_state
)
+ data =
ti_patch_payload.model_dump(exclude={"reschedule_date"}, exclude_unset=True)
+ query = update(TI).where(TI.id == ti_id_str).values(data)
+ query =
TI.duration_expression_update(datetime.now(tz=timezone.utc), query,
session.bind)
+ query = query.values(state=TaskInstanceState.FAILED)
+ ti = session.get(TI, ti_id_str)
+ _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id,
session=session, dag_bag=dag_bag)
+ return query, updated_state
task_instance = session.get(TI, ti_id_str)
actual_start_date = timezone.utcnow()
@@ -494,16 +531,10 @@ def ti_update_state(
# clear the next_method and next_kwargs so that none of the retries
pick them up
query = query.values(state=TaskInstanceState.UP_FOR_RESCHEDULE,
next_method=None, next_kwargs=None)
updated_state = TaskInstanceState.UP_FOR_RESCHEDULE
- # TODO: Replace this with FastAPI's Custom Exception handling:
- #
https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
- try:
- result = session.execute(query)
- log.info("Task instance state updated", new_state=updated_state,
rows_affected=result.rowcount)
- except SQLAlchemyError as e:
- log.error("Error updating Task Instance state", error=str(e))
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Database error occurred"
- )
+ else:
+ raise ValueError(f"Unexpected Payload Type {type(ti_patch_payload)}")
+
+ return query, updated_state
@ti_id_router.patch(
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 13f0b8df846..3c107f61863 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -705,6 +705,40 @@ class TestTIUpdateState:
assert event[0].asset == AssetModel(name="my-task",
uri="s3://bucket/my-task", extra={})
assert event[0].extra == expected_extra
+ def test_ti_update_state_to_failed_with_inactive_asset(self, client,
session, create_task_instance):
+ # inactive
+ asset = AssetModel(
+ id=1,
+ name="my-task-2",
+ uri="s3://bucket/my-task",
+ group="asset",
+ extra={},
+ )
+ session.add(asset)
+
+ ti = create_task_instance(
+ task_id="test_ti_update_state_to_success_with_asset_events",
+ start_date=DEFAULT_START_DATE,
+ state=State.RUNNING,
+ )
+ session.commit()
+
+ response = client.patch(
+ f"/execution/task-instances/{ti.id}/state",
+ json={
+ "state": "success",
+ "end_date": DEFAULT_END_DATE.isoformat(),
+ "task_outlets": [{"name": "my-task-2", "uri":
"s3://bucket/my-task", "type": "Asset"}],
+ "outlet_events": [],
+ },
+ )
+
+ assert response.status_code == 204
+ session.expire_all()
+
+ ti = session.get(TaskInstance, ti.id)
+ assert ti.state == State.FAILED
+
@pytest.mark.parametrize(
"outlet_events, expected_extra",
[
@@ -976,8 +1010,13 @@ class TestTIUpdateState:
},
)
- assert response.status_code == 422
- assert response.json()["detail"]["reason"] == "invalid_reschedule_date"
+ assert response.status_code == 204
+ assert response.text == ""
+
+ session.expire_all()
+
+ ti = session.get(TaskInstance, ti.id)
+ assert ti.state == State.FAILED
def test_ti_update_state_handle_retry(self, client, session,
create_task_instance):
ti = create_task_instance(
diff --git a/task-sdk/src/airflow/sdk/api/client.py
b/task-sdk/src/airflow/sdk/api/client.py
index 7ef0a68a6c0..89bee9807e1 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -174,10 +174,6 @@ class TaskInstanceOperations:
)
self.client.patch(f"task-instances/{id}/state",
content=body.model_dump_json())
- def heartbeat(self, id: uuid.UUID, pid: int):
- body = TIHeartbeatInfo(pid=pid, hostname=get_hostname())
- self.client.put(f"task-instances/{id}/heartbeat",
content=body.model_dump_json())
-
def defer(self, id: uuid.UUID, msg):
"""Tell the API server that this TI has been deferred."""
body = TIDeferredStatePayload(**msg.model_dump(exclude_unset=True,
exclude={"type"}))
@@ -192,6 +188,10 @@ class TaskInstanceOperations:
# Create a reschedule state payload from msg
self.client.patch(f"task-instances/{id}/state",
content=body.model_dump_json())
+ def heartbeat(self, id: uuid.UUID, pid: int):
+ body = TIHeartbeatInfo(pid=pid, hostname=get_hostname())
+ self.client.put(f"task-instances/{id}/heartbeat",
content=body.model_dump_json())
+
def skip_downstream_tasks(self, id: uuid.UUID, msg: SkipDownstreamTasks):
"""Tell the API server to skip the downstream tasks of this TI."""
body = TISkippedDownstreamTasksStatePayload(tasks=msg.tasks)