This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b1a44b4e3db AIP-72: Improve Supervisor and Task Instance State Validation (#44405)
b1a44b4e3db is described below
commit b1a44b4e3db8a3d3c13dfc810ad02f5785f82eca
Author: Kaxil Naik <[email protected]>
AuthorDate: Wed Nov 27 12:13:44 2024 +0000
AIP-72: Improve Supervisor and Task Instance State Validation (#44405)
---
task_sdk/tests/execution_time/test_supervisor.py | 38 +++++++++++++++++++++-
.../execution_api/routes/test_task_instances.py | 13 +++++---
2 files changed, 45 insertions(+), 6 deletions(-)
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index 1083364f289..9f582074586 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -28,11 +28,13 @@ from time import sleep
from typing import TYPE_CHECKING
from unittest.mock import MagicMock
+import httpx
import pytest
import structlog
from uuid6 import uuid7
from airflow.sdk.api import client as sdk_client
+from airflow.sdk.api.client import ServerResponseError
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity
from airflow.sdk.execution_time.comms import (
@@ -46,6 +48,8 @@ from airflow.sdk.execution_time.comms import (
from airflow.sdk.execution_time.supervisor import WatchedSubprocess, supervise
from airflow.utils import timezone as tz
+from task_sdk.tests.api.test_client import make_client
+
if TYPE_CHECKING:
import kgb
@@ -73,7 +77,7 @@ class TestWatchedSubprocess:
print("I'm a short message")
sys.stdout.write("Message ")
print("stderr message", file=sys.stderr)
- # We need a short sleep for the main process to process things. I
worry this timining will be
+ # We need a short sleep for the main process to process things. I
worry this timing will be
# fragile, but I can't think of a better way. This lets the stdout
be read (partial line) and the
# stderr full line be read
sleep(0.1)
@@ -265,6 +269,38 @@ class TestWatchedSubprocess:
"timestamp": "2024-11-07T12:34:56.078901Z",
} in captured_logs
+ def test_supervisor_handles_already_running_task(self):
+ """Test that Supervisor prevents starting a Task Instance that is
already running."""
+ ti = TaskInstance(id=uuid7(), task_id="b", dag_id="c", run_id="d",
try_number=1)
+
+ # Mock API Server response indicating the TI is already running
+ # The API Server would return a 409 Conflict status code if the TI is
not
+ # in a "queued" state.
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == f"/task-instances/{ti.id}/state":
+ return httpx.Response(
+ 409,
+ json={
+ "reason": "invalid_state",
+ "message": "TI was not in a state where it could be
marked as running",
+ "previous_state": "running",
+ },
+ )
+
+ return httpx.Response(status_code=204)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+
+ with pytest.raises(ServerResponseError, match="Server returned error")
as err:
+ WatchedSubprocess.start(path=os.devnull, ti=ti, client=client)
+
+ assert err.value.response.status_code == 409
+ assert err.value.detail == {
+ "reason": "invalid_state",
+ "message": "TI was not in a state where it could be marked as
running",
+ "previous_state": "running",
+ }
+
class TestHandleRequest:
@pytest.fixture
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index efb48ccb533..d2285e7e3a9 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -25,7 +25,7 @@ from sqlalchemy.exc import SQLAlchemyError
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
-from airflow.utils.state import State
+from airflow.utils.state import State, TaskInstanceState
from tests_common.test_utils.db import clear_db_runs
@@ -79,14 +79,17 @@ class TestTIUpdateState:
assert ti.pid == 100
assert ti.start_date.isoformat() == "2024-10-31T12:00:00+00:00"
- def test_ti_update_state_conflict_if_not_queued(self, client, session,
create_task_instance):
+ @pytest.mark.parametrize("initial_ti_state", [s for s in TaskInstanceState
if s != State.QUEUED])
+ def test_ti_update_state_conflict_if_not_queued(
+ self, client, session, create_task_instance, initial_ti_state
+ ):
"""
Test that a 409 error is returned when the Task Instance is not in a
state where it can be marked as
running. In this case, the Task Instance is first in NONE state so it
cannot be marked as running.
"""
ti = create_task_instance(
task_id="test_ti_update_state_conflict_if_not_queued",
- state=State.NONE,
+ state=initial_ti_state,
)
session.commit()
@@ -105,12 +108,12 @@ class TestTIUpdateState:
assert response.json() == {
"detail": {
"message": "TI was not in a state where it could be marked as
running",
- "previous_state": State.NONE,
+ "previous_state": initial_ti_state,
"reason": "invalid_state",
}
}
- assert session.scalar(select(TaskInstance.state).where(TaskInstance.id
== ti.id)) == State.NONE
+ assert session.scalar(select(TaskInstance.state).where(TaskInstance.id
== ti.id)) == initial_ti_state
@pytest.mark.parametrize(
("state", "end_date", "expected_state"),