This is an automated email from the ASF dual-hosted git repository. rahulvats pushed a commit to branch v3-2-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 47ffacd1586f877db9c08fadad183a2f91a1d2fd Author: Anish Giri <[email protected]> AuthorDate: Wed Mar 25 16:25:03 2026 -0500 Raise ``TaskAlreadyRunningError`` when starting an already-running task instance (#60855) * Fix task marked as failed on executor redelivery Handle 409 CONFLICT (task already running) from the API server gracefully by raising TaskAlreadyRunningError instead of letting it propagate as a generic failure. closes: #58441 * Fix test_get_not_found assertion to match unwrapped detail format * address review feed back * Trigger CI re-run * Trigger CI re-run (cherry picked from commit f5ff9671ec81f9d259980a733766fef4c78ba91c) --- task-sdk/src/airflow/sdk/api/client.py | 20 ++++++- task-sdk/src/airflow/sdk/exceptions.py | 4 ++ task-sdk/tests/task_sdk/api/test_client.py | 59 +++++++++++++++--- .../task_sdk/execution_time/test_supervisor.py | 70 +++++++++++----------- 4 files changed, 108 insertions(+), 45 deletions(-) diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 90374f76be5..19e691281f7 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -76,7 +76,7 @@ from airflow.sdk.api.datamodels._generated import ( XComSequenceSliceResponse, ) from airflow.sdk.configuration import conf -from airflow.sdk.exceptions import ErrorType +from airflow.sdk.exceptions import ErrorType, TaskAlreadyRunningError from airflow.sdk.execution_time.comms import ( CreateHITLDetailPayload, DRCount, @@ -216,7 +216,18 @@ class TaskInstanceOperations: """Tell the API server that this TI has started running.""" body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), unixname=getuser(), start_date=when) - resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json()) + try: + resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json()) + except ServerResponseError as e: + if e.response.status_code == HTTPStatus.CONFLICT: + detail = e.detail + if ( + isinstance(detail, dict) + and detail.get("reason") == "invalid_state" + and detail.get("previous_state") == "running" + ): + raise TaskAlreadyRunningError(f"Task instance {id} is already running") from e + raise return TIRunContext.model_validate_json(resp.read()) def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: datetime, rendered_map_index): @@ -1034,7 +1045,7 @@ class Client(httpx.Client): # This is only used for parsing. ServerResponseError is raised instead class _ErrorBody(BaseModel): - detail: list[RemoteValidationError] | str + detail: list[RemoteValidationError] | dict[str, Any] | str def __repr__(self): return repr(self.detail) @@ -1068,6 +1079,9 @@ class ServerResponseError(httpx.HTTPStatusError): if isinstance(body.detail, list): detail = body.detail msg = "Remote server returned validation error" + elif isinstance(body.detail, dict): + detail = body.detail + msg = "Server returned error" else: msg = body.detail or "Un-parseable error" except Exception: diff --git a/task-sdk/src/airflow/sdk/exceptions.py b/task-sdk/src/airflow/sdk/exceptions.py index 9f8a5f11fbf..b69abe62265 100644 --- a/task-sdk/src/airflow/sdk/exceptions.py +++ b/task-sdk/src/airflow/sdk/exceptions.py @@ -330,6 +330,10 @@ class TaskNotFound(AirflowException): """Raise when a Task is not available in the system.""" +class TaskAlreadyRunningError(AirflowException): + """Raised when a task is already running on another worker.""" + + class FailFastDagInvalidTriggerRule(AirflowException): """Raise when a dag has 'fail_fast' enabled yet has a non-default trigger rule.""" diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 7d960b76570..0df8839c55f 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -47,7 +47,7 @@ from airflow.sdk.api.datamodels._generated import ( VariableResponse, XComResponse, ) -from airflow.sdk.exceptions import ErrorType +from airflow.sdk.exceptions import ErrorType, TaskAlreadyRunningError from airflow.sdk.execution_time.comms import ( DeferTask, ErrorResponse, @@ -161,7 +161,7 @@ class TestClient: err = exc_info.value assert err.args == ("Server returned error",) - assert err.detail == {"detail": {"message": "Invalid input"}} + assert err.detail == {"message": "Invalid input"} # Check that the error is picklable pickled = pickle.dumps(err) @@ -171,7 +171,7 @@ class TestClient: # Test that unpickled error has the same attributes as the original assert unpickled.response.json() == {"detail": {"message": "Invalid input"}} - assert unpickled.detail == {"detail": {"message": "Invalid input"}} + assert unpickled.detail == {"message": "Invalid input"} assert unpickled.response.status_code == 404 assert unpickled.request.url == "http://error" @@ -333,6 +333,53 @@ class TestTaskInstanceOperations: assert resp == ti_context assert call_count == 3 + def test_task_instance_start_already_running(self): + """Test that start() raises TaskAlreadyRunningError when TI is already running.""" + ti_id = uuid6.uuid7() + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == f"/task-instances/{ti_id}/run": + return httpx.Response( + 409, + json={ + "detail": { + "reason": "invalid_state", + "message": "TI was not in a state where it could be marked as running", + "previous_state": "running", + } + }, + ) + return httpx.Response(status_code=204) + + client = make_client(transport=httpx.MockTransport(handle_request)) + + with pytest.raises(TaskAlreadyRunningError, match="already running"): + client.task_instances.start(ti_id, 100, datetime(2024, 10, 31, tzinfo=timezone.utc)) + + @pytest.mark.parametrize("previous_state", ["failed", "success", "skipped"]) + def test_task_instance_start_other_invalid_states(self, previous_state): + """Test that start() raises ServerResponseError for non-running invalid states.""" + ti_id = uuid6.uuid7() + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == f"/task-instances/{ti_id}/run": + return httpx.Response( + 409, + json={ + "detail": { + "reason": "invalid_state", + "message": "TI was not in a state where it could be marked as running", + "previous_state": previous_state, + } + }, + ) + return httpx.Response(status_code=204) + + client = make_client(transport=httpx.MockTransport(handle_request)) + + with pytest.raises(ServerResponseError): + client.task_instances.start(ti_id, 100, datetime(2024, 10, 31, tzinfo=timezone.utc)) + @pytest.mark.parametrize( "state", [state for state in TerminalTIState if state != TerminalTIState.SUCCESS] ) @@ -1627,10 +1674,8 @@ class TestDagsOperations: assert exc_info.value.response.status_code == 404 assert exc_info.value.detail == { - "detail": { - "message": "The Dag with dag_id: `missing_dag` was not found", - "reason": "not_found", - } + "message": "The Dag with dag_id: `missing_dag` was not found", + "reason": "not_found", } def test_get_server_error(self): diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index ef6cd19b8d7..b486ce77766 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -63,7 +63,7 @@ from airflow.sdk.api.datamodels._generated import ( TaskInstance, TaskInstanceState, ) -from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType +from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType, TaskAlreadyRunningError from airflow.sdk.execution_time import task_runner from airflow.sdk.execution_time.comms import ( AssetEventsResult, @@ -731,40 +731,6 @@ class TestWatchedSubprocess: "task_instance_id": str(ti.id), } in captured_logs - def test_supervisor_handles_already_running_task(self): - """Test that Supervisor prevents starting a Task Instance that is already running.""" - ti = TaskInstance( - id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() - ) - - # Mock API Server response indicating the TI is already running - # The API Server would return a 409 Conflict status code if the TI is not - # in a "queued" state. - def handle_request(request: httpx.Request) -> httpx.Response: - if request.url.path == f"/task-instances/{ti.id}/run": - return httpx.Response( - 409, - json={ - "reason": "invalid_state", - "message": "TI was not in a state where it could be marked as running", - "previous_state": "running", - }, - ) - - return httpx.Response(status_code=204) - - client = make_client(transport=httpx.MockTransport(handle_request)) - - with pytest.raises(ServerResponseError, match="Server returned error") as err: - ActivitySubprocess.start(dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, what=ti, client=client) - - assert err.value.response.status_code == 409 - assert err.value.detail == { - "reason": "invalid_state", - "message": "TI was not in a state where it could be marked as running", - "previous_state": "running", - } - @pytest.mark.parametrize("captured_logs", [logging.ERROR], indirect=True, ids=["log_level=error"]) def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch, mocker, make_ti_context_dict): """ @@ -865,6 +831,40 @@ class TestWatchedSubprocess: }, ] + def test_start_raises_task_already_running_and_kills_subprocess(self): + """Test that ActivitySubprocess.start() raises TaskAlreadyRunningError and kills the child + when the API returns 409 with previous_state='running'.""" + ti_id = uuid7() + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == f"/task-instances/{ti_id}/run": + return httpx.Response( + 409, + json={ + "detail": { + "reason": "invalid_state", + "message": "TI was not in a state where it could be marked as running", + "previous_state": "running", + } + }, + ) + return httpx.Response(status_code=204) + + def subprocess_main(): + # Ensure we follow the "protocol" and get the startup message before we do anything + CommsDecoder()._get_response() + + with pytest.raises(TaskAlreadyRunningError, match="already running"): + ActivitySubprocess.start( + dag_rel_path=os.devnull, + bundle_info=FAKE_BUNDLE, + what=TaskInstance( + id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() + ), + client=make_client(transport=httpx.MockTransport(handle_request)), + target=subprocess_main, + ) + @pytest.mark.parametrize("captured_logs", [logging.WARNING], indirect=True) def test_heartbeat_failures_handling(self, monkeypatch, mocker, captured_logs, time_machine): """
