This is an automated email from the ASF dual-hosted git repository.

rahulvats pushed a commit to branch v3-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 47ffacd1586f877db9c08fadad183a2f91a1d2fd
Author: Anish Giri <[email protected]>
AuthorDate: Wed Mar 25 16:25:03 2026 -0500

    Raise ``TaskAlreadyRunningError`` when starting an already-running task instance (#60855)
    
    * Fix task marked as failed on executor redelivery
    
    Handle 409 CONFLICT (task already running) from the API server gracefully
    by raising TaskAlreadyRunningError instead of letting it propagate as a
    generic failure.
    
    closes: #58441
    
    * Fix test_get_not_found assertion to match unwrapped detail format
    
    * Address review feedback
    
    * Trigger CI re-run
    
    * Trigger CI re-run
    
    (cherry picked from commit f5ff9671ec81f9d259980a733766fef4c78ba91c)
---
 task-sdk/src/airflow/sdk/api/client.py             | 20 ++++++-
 task-sdk/src/airflow/sdk/exceptions.py             |  4 ++
 task-sdk/tests/task_sdk/api/test_client.py         | 59 +++++++++++++++---
 .../task_sdk/execution_time/test_supervisor.py     | 70 +++++++++++-----------
 4 files changed, 108 insertions(+), 45 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/api/client.py 
b/task-sdk/src/airflow/sdk/api/client.py
index 90374f76be5..19e691281f7 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -76,7 +76,7 @@ from airflow.sdk.api.datamodels._generated import (
     XComSequenceSliceResponse,
 )
 from airflow.sdk.configuration import conf
-from airflow.sdk.exceptions import ErrorType
+from airflow.sdk.exceptions import ErrorType, TaskAlreadyRunningError
 from airflow.sdk.execution_time.comms import (
     CreateHITLDetailPayload,
     DRCount,
@@ -216,7 +216,18 @@ class TaskInstanceOperations:
         """Tell the API server that this TI has started running."""
         body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), 
unixname=getuser(), start_date=when)
 
-        resp = self.client.patch(f"task-instances/{id}/run", 
content=body.model_dump_json())
+        try:
+            resp = self.client.patch(f"task-instances/{id}/run", 
content=body.model_dump_json())
+        except ServerResponseError as e:
+            if e.response.status_code == HTTPStatus.CONFLICT:
+                detail = e.detail
+                if (
+                    isinstance(detail, dict)
+                    and detail.get("reason") == "invalid_state"
+                    and detail.get("previous_state") == "running"
+                ):
+                    raise TaskAlreadyRunningError(f"Task instance {id} is 
already running") from e
+            raise
         return TIRunContext.model_validate_json(resp.read())
 
     def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: 
datetime, rendered_map_index):
@@ -1034,7 +1045,7 @@ class Client(httpx.Client):
 
 # This is only used for parsing. ServerResponseError is raised instead
 class _ErrorBody(BaseModel):
-    detail: list[RemoteValidationError] | str
+    detail: list[RemoteValidationError] | dict[str, Any] | str
 
     def __repr__(self):
         return repr(self.detail)
@@ -1068,6 +1079,9 @@ class ServerResponseError(httpx.HTTPStatusError):
             if isinstance(body.detail, list):
                 detail = body.detail
                 msg = "Remote server returned validation error"
+            elif isinstance(body.detail, dict):
+                detail = body.detail
+                msg = "Server returned error"
             else:
                 msg = body.detail or "Un-parseable error"
         except Exception:
diff --git a/task-sdk/src/airflow/sdk/exceptions.py 
b/task-sdk/src/airflow/sdk/exceptions.py
index 9f8a5f11fbf..b69abe62265 100644
--- a/task-sdk/src/airflow/sdk/exceptions.py
+++ b/task-sdk/src/airflow/sdk/exceptions.py
@@ -330,6 +330,10 @@ class TaskNotFound(AirflowException):
     """Raise when a Task is not available in the system."""
 
 
+class TaskAlreadyRunningError(AirflowException):
+    """Raised when a task is already running on another worker."""
+
+
 class FailFastDagInvalidTriggerRule(AirflowException):
     """Raise when a dag has 'fail_fast' enabled yet has a non-default trigger 
rule."""
 
diff --git a/task-sdk/tests/task_sdk/api/test_client.py 
b/task-sdk/tests/task_sdk/api/test_client.py
index 7d960b76570..0df8839c55f 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -47,7 +47,7 @@ from airflow.sdk.api.datamodels._generated import (
     VariableResponse,
     XComResponse,
 )
-from airflow.sdk.exceptions import ErrorType
+from airflow.sdk.exceptions import ErrorType, TaskAlreadyRunningError
 from airflow.sdk.execution_time.comms import (
     DeferTask,
     ErrorResponse,
@@ -161,7 +161,7 @@ class TestClient:
 
         err = exc_info.value
         assert err.args == ("Server returned error",)
-        assert err.detail == {"detail": {"message": "Invalid input"}}
+        assert err.detail == {"message": "Invalid input"}
 
         # Check that the error is picklable
         pickled = pickle.dumps(err)
@@ -171,7 +171,7 @@ class TestClient:
 
         # Test that unpickled error has the same attributes as the original
         assert unpickled.response.json() == {"detail": {"message": "Invalid 
input"}}
-        assert unpickled.detail == {"detail": {"message": "Invalid input"}}
+        assert unpickled.detail == {"message": "Invalid input"}
         assert unpickled.response.status_code == 404
         assert unpickled.request.url == "http://error"
 
@@ -333,6 +333,53 @@ class TestTaskInstanceOperations:
             assert resp == ti_context
             assert call_count == 3
 
+    def test_task_instance_start_already_running(self):
+        """Test that start() raises TaskAlreadyRunningError when TI is already 
running."""
+        ti_id = uuid6.uuid7()
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == f"/task-instances/{ti_id}/run":
+                return httpx.Response(
+                    409,
+                    json={
+                        "detail": {
+                            "reason": "invalid_state",
+                            "message": "TI was not in a state where it could 
be marked as running",
+                            "previous_state": "running",
+                        }
+                    },
+                )
+            return httpx.Response(status_code=204)
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+
+        with pytest.raises(TaskAlreadyRunningError, match="already running"):
+            client.task_instances.start(ti_id, 100, datetime(2024, 10, 31, 
tzinfo=timezone.utc))
+
+    @pytest.mark.parametrize("previous_state", ["failed", "success", 
"skipped"])
+    def test_task_instance_start_other_invalid_states(self, previous_state):
+        """Test that start() raises ServerResponseError for non-running 
invalid states."""
+        ti_id = uuid6.uuid7()
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == f"/task-instances/{ti_id}/run":
+                return httpx.Response(
+                    409,
+                    json={
+                        "detail": {
+                            "reason": "invalid_state",
+                            "message": "TI was not in a state where it could 
be marked as running",
+                            "previous_state": previous_state,
+                        }
+                    },
+                )
+            return httpx.Response(status_code=204)
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+
+        with pytest.raises(ServerResponseError):
+            client.task_instances.start(ti_id, 100, datetime(2024, 10, 31, 
tzinfo=timezone.utc))
+
     @pytest.mark.parametrize(
         "state", [state for state in TerminalTIState if state != 
TerminalTIState.SUCCESS]
     )
@@ -1627,10 +1674,8 @@ class TestDagsOperations:
 
         assert exc_info.value.response.status_code == 404
         assert exc_info.value.detail == {
-            "detail": {
-                "message": "The Dag with dag_id: `missing_dag` was not found",
-                "reason": "not_found",
-            }
+            "message": "The Dag with dag_id: `missing_dag` was not found",
+            "reason": "not_found",
         }
 
     def test_get_server_error(self):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py 
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index ef6cd19b8d7..b486ce77766 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -63,7 +63,7 @@ from airflow.sdk.api.datamodels._generated import (
     TaskInstance,
     TaskInstanceState,
 )
-from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
+from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType, 
TaskAlreadyRunningError
 from airflow.sdk.execution_time import task_runner
 from airflow.sdk.execution_time.comms import (
     AssetEventsResult,
@@ -731,40 +731,6 @@ class TestWatchedSubprocess:
             "task_instance_id": str(ti.id),
         } in captured_logs
 
-    def test_supervisor_handles_already_running_task(self):
-        """Test that Supervisor prevents starting a Task Instance that is 
already running."""
-        ti = TaskInstance(
-            id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1, 
dag_version_id=uuid7()
-        )
-
-        # Mock API Server response indicating the TI is already running
-        # The API Server would return a 409 Conflict status code if the TI is 
not
-        # in a "queued" state.
-        def handle_request(request: httpx.Request) -> httpx.Response:
-            if request.url.path == f"/task-instances/{ti.id}/run":
-                return httpx.Response(
-                    409,
-                    json={
-                        "reason": "invalid_state",
-                        "message": "TI was not in a state where it could be 
marked as running",
-                        "previous_state": "running",
-                    },
-                )
-
-            return httpx.Response(status_code=204)
-
-        client = make_client(transport=httpx.MockTransport(handle_request))
-
-        with pytest.raises(ServerResponseError, match="Server returned error") 
as err:
-            ActivitySubprocess.start(dag_rel_path=os.devnull, 
bundle_info=FAKE_BUNDLE, what=ti, client=client)
-
-        assert err.value.response.status_code == 409
-        assert err.value.detail == {
-            "reason": "invalid_state",
-            "message": "TI was not in a state where it could be marked as 
running",
-            "previous_state": "running",
-        }
-
     @pytest.mark.parametrize("captured_logs", [logging.ERROR], indirect=True, 
ids=["log_level=error"])
     def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch, 
mocker, make_ti_context_dict):
         """
@@ -865,6 +831,40 @@ class TestWatchedSubprocess:
             },
         ]
 
+    def test_start_raises_task_already_running_and_kills_subprocess(self):
+        """Test that ActivitySubprocess.start() raises TaskAlreadyRunningError 
and kills the child
+        when the API returns 409 with previous_state='running'."""
+        ti_id = uuid7()
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == f"/task-instances/{ti_id}/run":
+                return httpx.Response(
+                    409,
+                    json={
+                        "detail": {
+                            "reason": "invalid_state",
+                            "message": "TI was not in a state where it could 
be marked as running",
+                            "previous_state": "running",
+                        }
+                    },
+                )
+            return httpx.Response(status_code=204)
+
+        def subprocess_main():
+            # Ensure we follow the "protocol" and get the startup message 
before we do anything
+            CommsDecoder()._get_response()
+
+        with pytest.raises(TaskAlreadyRunningError, match="already running"):
+            ActivitySubprocess.start(
+                dag_rel_path=os.devnull,
+                bundle_info=FAKE_BUNDLE,
+                what=TaskInstance(
+                    id=ti_id, task_id="b", dag_id="c", run_id="d", 
try_number=1, dag_version_id=uuid7()
+                ),
+                
client=make_client(transport=httpx.MockTransport(handle_request)),
+                target=subprocess_main,
+            )
+
     @pytest.mark.parametrize("captured_logs", [logging.WARNING], indirect=True)
     def test_heartbeat_failures_handling(self, monkeypatch, mocker, 
captured_logs, time_machine):
         """

Reply via email to