This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b1a44b4e3db AIP-72: Improve Supervisor and Task Instance State Validation (#44405)
b1a44b4e3db is described below

commit b1a44b4e3db8a3d3c13dfc810ad02f5785f82eca
Author: Kaxil Naik <kaxiln...@apache.org>
AuthorDate: Wed Nov 27 12:13:44 2024 +0000

    AIP-72: Improve Supervisor and Task Instance State Validation (#44405)
---
 task_sdk/tests/execution_time/test_supervisor.py   | 38 +++++++++++++++++++++-
 .../execution_api/routes/test_task_instances.py    | 13 +++++---
 2 files changed, 45 insertions(+), 6 deletions(-)

diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index 1083364f289..9f582074586 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -28,11 +28,13 @@ from time import sleep
 from typing import TYPE_CHECKING
 from unittest.mock import MagicMock
 
+import httpx
 import pytest
 import structlog
 from uuid6 import uuid7
 
 from airflow.sdk.api import client as sdk_client
+from airflow.sdk.api.client import ServerResponseError
 from airflow.sdk.api.datamodels._generated import TaskInstance
 from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity
 from airflow.sdk.execution_time.comms import (
@@ -46,6 +48,8 @@ from airflow.sdk.execution_time.comms import (
 from airflow.sdk.execution_time.supervisor import WatchedSubprocess, supervise
 from airflow.utils import timezone as tz
 
+from task_sdk.tests.api.test_client import make_client
+
 if TYPE_CHECKING:
     import kgb
 
@@ -73,7 +77,7 @@ class TestWatchedSubprocess:
             print("I'm a short message")
             sys.stdout.write("Message ")
             print("stderr message", file=sys.stderr)
-            # We need a short sleep for the main process to process things. I worry this timining will be
+            # We need a short sleep for the main process to process things. I worry this timing will be
             # fragile, but I can't think of a better way. This lets the stdout be read (partial line) and the
             # stderr full line be read
             sleep(0.1)
@@ -265,6 +269,38 @@ class TestWatchedSubprocess:
             "timestamp": "2024-11-07T12:34:56.078901Z",
         } in captured_logs
 
+    def test_supervisor_handles_already_running_task(self):
+        """Test that Supervisor prevents starting a Task Instance that is already running."""
+        ti = TaskInstance(id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1)
+
+        # Mock API Server response indicating the TI is already running
+        # The API Server would return a 409 Conflict status code if the TI is not
+        # in a "queued" state.
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == f"/task-instances/{ti.id}/state":
+                return httpx.Response(
+                    409,
+                    json={
+                        "reason": "invalid_state",
+                        "message": "TI was not in a state where it could be marked as running",
+                        "previous_state": "running",
+                    },
+                )
+
+            return httpx.Response(status_code=204)
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+
+        with pytest.raises(ServerResponseError, match="Server returned error") as err:
+            WatchedSubprocess.start(path=os.devnull, ti=ti, client=client)
+
+        assert err.value.response.status_code == 409
+        assert err.value.detail == {
+            "reason": "invalid_state",
+            "message": "TI was not in a state where it could be marked as running",
+            "previous_state": "running",
+        }
+
 
 class TestHandleRequest:
     @pytest.fixture
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index efb48ccb533..d2285e7e3a9 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -25,7 +25,7 @@ from sqlalchemy.exc import SQLAlchemyError
 
 from airflow.models.taskinstance import TaskInstance
 from airflow.utils import timezone
-from airflow.utils.state import State
+from airflow.utils.state import State, TaskInstanceState
 
 from tests_common.test_utils.db import clear_db_runs
 
@@ -79,14 +79,17 @@ class TestTIUpdateState:
         assert ti.pid == 100
         assert ti.start_date.isoformat() == "2024-10-31T12:00:00+00:00"
 
-    def test_ti_update_state_conflict_if_not_queued(self, client, session, create_task_instance):
+    @pytest.mark.parametrize("initial_ti_state", [s for s in TaskInstanceState if s != State.QUEUED])
+    def test_ti_update_state_conflict_if_not_queued(
+        self, client, session, create_task_instance, initial_ti_state
+    ):
         """
         Test that a 409 error is returned when the Task Instance is not in a state where it can be marked as
         running. In this case, the Task Instance is first in NONE state so it cannot be marked as running.
         """
         ti = create_task_instance(
             task_id="test_ti_update_state_conflict_if_not_queued",
-            state=State.NONE,
+            state=initial_ti_state,
         )
         session.commit()
 
@@ -105,12 +108,12 @@ class TestTIUpdateState:
         assert response.json() == {
             "detail": {
                 "message": "TI was not in a state where it could be marked as running",
-                "previous_state": State.NONE,
+                "previous_state": initial_ti_state,
                 "reason": "invalid_state",
             }
         }
 
-        assert session.scalar(select(TaskInstance.state).where(TaskInstance.id == ti.id)) == State.NONE
+        assert session.scalar(select(TaskInstance.state).where(TaskInstance.id == ti.id)) == initial_ti_state
 
     @pytest.mark.parametrize(
         ("state", "end_date", "expected_state"),

Reply via email to