This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d34b434648e Do not call runtime-checks api from the task sdk anymore
(#48125)
d34b434648e is described below
commit d34b434648edf59d0f5008c374e415b398ad96d7
Author: Sneha Prabhu <[email protected]>
AuthorDate: Wed Mar 26 21:57:06 2025 +0530
Do not call runtime-checks api from the task sdk anymore (#48125)
We'd already removed any meaningful checks in this endpoint in a previous
PR
Although right now we don't have any API clients calling the previous
version
since we aren't released yet and could have simply deleted the endpoint
outright, this is a good opportunity to use this as a practice run for API
versioning
---------
Co-authored-by: Sneha Prabhu <[email protected]>
Co-authored-by: Ash Berlin-Taylor <[email protected]>
---
.../src/airflow/api_fastapi/execution_api/app.py | 6 +-
.../execution_api/routes/task_instances.py | 1 +
.../execution_api/versions/__init__.py} | 16 ++---
.../execution_api/versions/v2025_03_26.py} | 15 +++--
.../api_fastapi/execution_api/versions/README.md | 22 +++++++
.../execution_api/{routes => versions}/__init__.py | 0
.../{routes => versions/head}/__init__.py | 0
.../{routes => versions/head}/test_asset_events.py | 0
.../{routes => versions/head}/test_assets.py | 0
.../{routes => versions/head}/test_connections.py | 0
.../{routes => versions/head}/test_dag_runs.py | 0
.../{routes => versions/head}/test_health.py | 0
.../head}/test_task_instances.py | 29 ---------
.../{routes => versions/head}/test_variables.py | 0
.../{routes => versions/head}/test_xcoms.py | 0
.../{routes => versions/v2025_03_19}/__init__.py | 0
.../versions/v2025_03_19/test_task_instances.py | 73 +++++++++++++++++++++
task-sdk/dev/generate_task_sdk_models.py | 2 +-
task-sdk/src/airflow/sdk/api/client.py | 15 -----
.../src/airflow/sdk/api/datamodels/_generated.py | 14 +---
task-sdk/src/airflow/sdk/execution_time/comms.py | 6 --
.../src/airflow/sdk/execution_time/supervisor.py | 4 --
.../src/airflow/sdk/execution_time/task_runner.py | 13 ----
.../task_sdk/execution_time/test_supervisor.py | 21 ------
.../task_sdk/execution_time/test_task_runner.py | 75 +---------------------
25 files changed, 117 insertions(+), 195 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py
b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
index 12e86129faa..8e73cfd08c4 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
@@ -26,9 +26,6 @@ import attrs
import svcs
from cadwyn import (
Cadwyn,
- HeadVersion,
- Version,
- VersionBundle,
)
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
@@ -127,6 +124,7 @@ class CadwynWithOpenAPICustomization(Cadwyn):
def create_task_execution_api_app() -> FastAPI:
"""Create FastAPI app for task execution API."""
from airflow.api_fastapi.execution_api.routes import execution_api_router
+ from airflow.api_fastapi.execution_api.versions import bundle
# See https://docs.cadwyn.dev/concepts/version_changes/ for info about API
versions
app = CadwynWithOpenAPICustomization(
@@ -134,7 +132,7 @@ def create_task_execution_api_app() -> FastAPI:
description="The private Airflow Task Execution API.",
lifespan=lifespan,
api_version_parameter_name="Airflow-API-Version",
- versions=VersionBundle(HeadVersion(), Version("2025-03-19")),
+ versions=bundle,
)
app.generate_and_include_versioned_routers(execution_api_router)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index afc514f08f1..e78e2d41063 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -564,6 +564,7 @@ def get_previous_successful_dagrun(
return PrevSuccessfulDagRunResponse.model_validate(dag_run)
[email protected]_exists_in_older_versions
@router.post(
"/{task_instance_id}/runtime-checks",
status_code=status.HTTP_204_NO_CONTENT,
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_health.py
b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
similarity index 74%
copy from
airflow-core/tests/unit/api_fastapi/execution_api/routes/test_health.py
copy to airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
index dc13f9880a8..af93aab29ba 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_health.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
@@ -14,15 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
-
-import pytest
-pytestmark = pytest.mark.db_test
+from __future__ import annotations
+from cadwyn import HeadVersion, Version, VersionBundle
-def test_health(client):
- response = client.get("/execution/health")
+from airflow.api_fastapi.execution_api.versions.v2025_03_26 import
RemoveTIRuntimeChecksEndpoint
- assert response.status_code == 200
- assert response.json() == {"status": "healthy"}
+bundle = VersionBundle(
+ HeadVersion(),
+ Version("2025-03-26", RemoveTIRuntimeChecksEndpoint),
+ Version("2025-03-19"),
+)
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_health.py
b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_03_26.py
similarity index 70%
copy from
airflow-core/tests/unit/api_fastapi/execution_api/routes/test_health.py
copy to
airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_03_26.py
index dc13f9880a8..98d9b985399 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_health.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_03_26.py
@@ -14,15 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
-import pytest
+from __future__ import annotations
-pytestmark = pytest.mark.db_test
+from cadwyn import VersionChange, endpoint
-def test_health(client):
- response = client.get("/execution/health")
+class RemoveTIRuntimeChecksEndpoint(VersionChange):
+ """Remove the runtime-check endpoint as it does nothing anymore."""
- assert response.status_code == 200
- assert response.json() == {"status": "healthy"}
+ description = __doc__
+ instructions_to_migrate_to_previous_version = (
+ endpoint("/task-instances/{task_instance_id}/runtime-checks",
["POST"]).existed,
+ )
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/README.md
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/README.md
new file mode 100644
index 00000000000..c4a59f3c652
--- /dev/null
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/README.md
@@ -0,0 +1,22 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+ -->
+
+# Note about test structure
+
+This test package follows the approach laid out in [Cadwyn's Testing
page](https://docs.cadwyn.dev/concepts/testing/).
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/routes/__init__.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/__init__.py
similarity index 100%
copy from airflow-core/tests/unit/api_fastapi/execution_api/routes/__init__.py
copy to airflow-core/tests/unit/api_fastapi/execution_api/versions/__init__.py
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/routes/__init__.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/__init__.py
similarity index 100%
copy from airflow-core/tests/unit/api_fastapi/execution_api/routes/__init__.py
copy to
airflow-core/tests/unit/api_fastapi/execution_api/versions/head/__init__.py
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_asset_events.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_events.py
similarity index 100%
rename from
airflow-core/tests/unit/api_fastapi/execution_api/routes/test_asset_events.py
rename to
airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_events.py
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_assets.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_assets.py
similarity index 100%
rename from
airflow-core/tests/unit/api_fastapi/execution_api/routes/test_assets.py
rename to
airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_assets.py
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_connections.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connections.py
similarity index 100%
rename from
airflow-core/tests/unit/api_fastapi/execution_api/routes/test_connections.py
rename to
airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connections.py
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_dag_runs.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py
similarity index 100%
rename from
airflow-core/tests/unit/api_fastapi/execution_api/routes/test_dag_runs.py
rename to
airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_health.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_health.py
similarity index 100%
rename from
airflow-core/tests/unit/api_fastapi/execution_api/routes/test_health.py
rename to
airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_health.py
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
similarity index 98%
rename from
airflow-core/tests/unit/api_fastapi/execution_api/routes/test_task_instances.py
rename to
airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 2ab0f943e6d..a99291ed4d6 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_task_instances.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -804,35 +804,6 @@ class TestTIUpdateState:
assert ti.next_kwargs is None
assert ti.duration == 3600.00
- @pytest.mark.parametrize(
- ("state", "expected_status_code"),
- [
- (State.RUNNING, 204),
- (State.SUCCESS, 409),
- (State.QUEUED, 409),
- (State.FAILED, 409),
- ],
- )
- def test_ti_runtime_checks_success(
- self, client, session, create_task_instance, state,
expected_status_code
- ):
- ti = create_task_instance(
- task_id="test_ti_runtime_checks",
- state=state,
- )
- session.commit()
-
- response = client.post(
- f"/execution/task-instances/{ti.id}/runtime-checks",
- json={
- "inlets": [],
- "outlets": [],
- },
- )
- assert response.status_code == expected_status_code
-
- session.expire_all()
-
class TestTISkipDownstream:
def setup_method(self):
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_variables.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py
similarity index 100%
rename from
airflow-core/tests/unit/api_fastapi/execution_api/routes/test_variables.py
rename to
airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/routes/test_xcoms.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
similarity index 100%
rename from
airflow-core/tests/unit/api_fastapi/execution_api/routes/test_xcoms.py
rename to
airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/routes/__init__.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_03_19/__init__.py
similarity index 100%
rename from airflow-core/tests/unit/api_fastapi/execution_api/routes/__init__.py
rename to
airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_03_19/__init__.py
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_03_19/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_03_19/test_task_instances.py
new file mode 100644
index 00000000000..9ac6880ba04
--- /dev/null
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_03_19/test_task_instances.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pytest
+
+from airflow.utils import timezone
+from airflow.utils.state import State
+
+from tests_common.test_utils.db import clear_db_assets, clear_db_runs
+
+pytestmark = pytest.mark.db_test
+
+
+DEFAULT_START_DATE = timezone.parse("2024-10-31T11:00:00Z")
+DEFAULT_END_DATE = timezone.parse("2024-10-31T12:00:00Z")
+
+
+class TestTIUpdateState:
+ def setup_method(self):
+ clear_db_assets()
+ clear_db_runs()
+
+ def teardown_method(self):
+ clear_db_assets()
+ clear_db_runs()
+
+ @pytest.mark.parametrize(
+ ("state", "expected_status_code"),
+ [
+ (State.RUNNING, 204),
+ (State.SUCCESS, 409),
+ (State.QUEUED, 409),
+ (State.FAILED, 409),
+ ],
+ )
+ def test_ti_runtime_checks_success(
+ self, client, session, create_task_instance, state,
expected_status_code
+ ):
+ # Last version this endpoint exists in
+ client.headers["Airflow-API-Version"] = "2025-03-19"
+
+ ti = create_task_instance(
+ task_id="test_ti_runtime_checks",
+ state=state,
+ )
+ session.commit()
+
+ response = client.post(
+ f"/execution/task-instances/{ti.id}/runtime-checks",
+ json={
+ "inlets": [],
+ "outlets": [],
+ },
+ )
+ assert response.status_code == expected_status_code
+
+ session.expire_all()
diff --git a/task-sdk/dev/generate_task_sdk_models.py
b/task-sdk/dev/generate_task_sdk_models.py
index bce9af5172a..43c30dc5022 100644
--- a/task-sdk/dev/generate_task_sdk_models.py
+++ b/task-sdk/dev/generate_task_sdk_models.py
@@ -78,7 +78,7 @@ def generate_file():
app = InProcessExecutionAPI()
- latest_version = app.app.versions.version_values[-1]
+ latest_version = app.app.versions.version_values[0]
client = httpx.Client(transport=app.transport)
openapi_schema = (
client.get(f"http://localhost/openapi.json?version={latest_version}").raise_for_status().text
diff --git a/task-sdk/src/airflow/sdk/api/client.py
b/task-sdk/src/airflow/sdk/api/client.py
index 6dc5123530b..0ab6bae0344 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -32,7 +32,6 @@ from retryhttp import retry, wait_retry_after
from tenacity import before_log, wait_random_exponential
from uuid6 import uuid7
-from airflow.api_fastapi.execution_api.datamodels.taskinstance import
TIRuntimeCheckPayload
from airflow.sdk import __version__
from airflow.sdk.api.datamodels._generated import (
API_VERSION,
@@ -62,7 +61,6 @@ from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import (
ErrorResponse,
OKResponse,
- RuntimeCheckOnTask,
SkipDownstreamTasks,
TaskRescheduleStartDate,
)
@@ -197,19 +195,6 @@ class TaskInstanceOperations:
resp =
self.client.get(f"task-instances/{id}/previous-successful-dagrun")
return PrevSuccessfulDagRunResponse.model_validate_json(resp.read())
- def runtime_checks(self, id: uuid.UUID, msg: RuntimeCheckOnTask) ->
OKResponse:
- body = TIRuntimeCheckPayload(**msg.model_dump(exclude_unset=True,
exclude={"type"}))
- try:
- self.client.post(f"task-instances/{id}/runtime-checks",
content=body.model_dump_json())
- return OKResponse(ok=True)
- except ServerResponseError as e:
- if e.response.status_code == 400:
- return OKResponse(ok=False)
- elif e.response.status_code == 409:
- # The TI isn't in the right state to perform the check, but we
shouldn't fail the task for that
- return OKResponse(ok=True)
- raise
-
def get_reschedule_start_date(self, id: uuid.UUID, try_number: int = 1) ->
TaskRescheduleStartDate:
"""Get the start date of a task reschedule via the API server."""
resp = self.client.get(f"task-reschedules/{id}/start_date",
params={"try_number": try_number})
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index 2fb929516cf..af24f756675 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -27,7 +27,7 @@ from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field, JsonValue
-API_VERSION: Final[str] = "2025-03-19"
+API_VERSION: Final[str] = "2025-03-26"
class AssetProfile(BaseModel):
@@ -223,18 +223,6 @@ class TIRetryStatePayload(BaseModel):
end_date: Annotated[datetime, Field(title="End Date")]
-class TIRuntimeCheckPayload(BaseModel):
- """
- Payload for performing Runtime checks on the TaskInstance model as
requested by the SDK.
- """
-
- model_config = ConfigDict(
- extra="forbid",
- )
- inlets: Annotated[list[AssetProfile] | None, Field(title="Inlets")] = None
- outlets: Annotated[list[AssetProfile] | None, Field(title="Outlets")] =
None
-
-
class TISkippedDownstreamTasksStatePayload(BaseModel):
"""
Schema for updating downstream tasks to a skipped state.
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index 20867b22ca5..477ec7c1fd5 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -63,7 +63,6 @@ from airflow.sdk.api.datamodels._generated import (
TIRescheduleStatePayload,
TIRetryStatePayload,
TIRunContext,
- TIRuntimeCheckPayload,
TISkippedDownstreamTasksStatePayload,
TISuccessStatePayload,
TriggerDAGRunPayload,
@@ -312,10 +311,6 @@ class
SkipDownstreamTasks(TISkippedDownstreamTasksStatePayload):
type: Literal["SkipDownstreamTasks"] = "SkipDownstreamTasks"
-class RuntimeCheckOnTask(TIRuntimeCheckPayload):
- type: Literal["RuntimeCheckOnTask"] = "RuntimeCheckOnTask"
-
-
class GetXCom(BaseModel):
key: str
dag_id: str
@@ -470,7 +465,6 @@ ToSupervisor = Annotated[
SetXCom,
TaskState,
TriggerDagRun,
- RuntimeCheckOnTask,
DeleteXCom,
],
Field(discriminator="type"),
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 371330a379a..27dd1bc4339 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -82,7 +82,6 @@ from airflow.sdk.execution_time.comms import (
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
- RuntimeCheckOnTask,
SetRenderedFields,
SetXCom,
SkipDownstreamTasks,
@@ -875,9 +874,6 @@ class ActivitySubprocess(WatchedSubprocess):
if isinstance(msg, TaskState):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
- elif isinstance(msg, RuntimeCheckOnTask):
- runtime_check_resp =
self.client.task_instances.runtime_checks(id=self.id, msg=msg)
- resp = runtime_check_resp.model_dump_json().encode()
elif isinstance(msg, SucceedTask):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 31c24c33324..04efa13616f 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -60,10 +60,8 @@ from airflow.sdk.execution_time.comms import (
ErrorResponse,
GetDagRunState,
GetTaskRescheduleStartDate,
- OKResponse,
RescheduleTask,
RetryTask,
- RuntimeCheckOnTask,
SetRenderedFields,
SkipDownstreamTasks,
StartupDetails,
@@ -582,17 +580,6 @@ def _serialize_outlet_events(events:
OutletEventAccessorsProtocol) -> Iterator[d
def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) ->
ToSupervisor | None:
ti.hostname = get_hostname()
ti.task = ti.task.prepare_for_execution()
- if ti.task.inlets or ti.task.outlets:
- inlets = [asset.asprofile() for asset in ti.task.inlets if
isinstance(asset, Asset)]
- outlets = [asset.asprofile() for asset in ti.task.outlets if
isinstance(asset, Asset)]
- SUPERVISOR_COMMS.send_request(msg=RuntimeCheckOnTask(inlets=inlets,
outlets=outlets), log=log) # type: ignore
- ok_response = SUPERVISOR_COMMS.get_message() # type: ignore
- if not isinstance(ok_response, OKResponse) or not ok_response.ok:
- log.info("Runtime checks failed for task, marking task as
failed..")
- return TaskState(
- state=TerminalTIState.FAILED,
- end_date=datetime.now(tz=timezone.utc),
- )
jinja_env = ti.task.dag.get_template_env()
ti.render_templates(context=context, jinja_env=jinja_env)
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 9751932b8e8..f1011bb9131 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -47,7 +47,6 @@ from airflow.sdk.api import client as sdk_client
from airflow.sdk.api.client import ServerResponseError
from airflow.sdk.api.datamodels._generated import (
AssetEventResponse,
- AssetProfile,
AssetResponse,
DagRunState,
TaskInstance,
@@ -76,7 +75,6 @@ from airflow.sdk.execution_time.comms import (
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
- RuntimeCheckOnTask,
SetRenderedFields,
SetXCom,
SucceedTask,
@@ -1294,25 +1292,6 @@ class TestHandleRequest:
),
id="get_prev_successful_dagrun",
),
- pytest.param(
- RuntimeCheckOnTask(
- inlets=[AssetProfile(name="alias", uri="alias",
type="asset")],
- outlets=[AssetProfile(name="alias", uri="alias",
type="asset")],
- ),
- b'{"ok":true,"type":"OKResponse"}\n',
- "task_instances.runtime_checks",
- (),
- {
- "id": TI_ID,
- "msg": RuntimeCheckOnTask(
- inlets=[AssetProfile(name="alias", uri="alias",
type="asset")],
- outlets=[AssetProfile(name="alias", uri="alias",
type="asset")],
- type="RuntimeCheckOnTask",
- ),
- },
- OKResponse(ok=True),
- id="runtime_check_on_task",
- ),
pytest.param(
TriggerDagRun(
dag_id="test_dag",
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index d968f761620..0dfe64bbce4 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -72,7 +72,6 @@ from airflow.sdk.execution_time.comms import (
GetXCom,
OKResponse,
PrevSuccessfulDagRunResult,
- RuntimeCheckOnTask,
SetRenderedFields,
SetXCom,
SkipDownstreamTasks,
@@ -804,9 +803,6 @@ def test_run_with_asset_outlets(
ti = create_runtime_ti(task=task, dag_id="dag_with_asset_outlet_task")
instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)
- mock_supervisor_comms.get_message.return_value = OKResponse(
- ok=True,
- )
run(ti, context=ti.get_template_context(), log=mock.MagicMock())
@@ -857,75 +853,6 @@ def test_run_with_asset_inlets(create_runtime_ti,
mock_supervisor_comms):
inlet_events[Asset(name="no such asset in inlets")]
[email protected](
- ["ok", "last_expected_msg"],
- [
- pytest.param(
- True,
- SucceedTask(
- end_date=timezone.datetime(2024, 12, 3, 10, 0),
- task_outlets=[
- AssetProfile(name="name", uri="s3://bucket/my-task",
type="Asset"),
- AssetProfile(name="new-name", uri="s3://bucket/my-task",
type="Asset"),
- ],
- outlet_events=[],
- ),
- id="runtime_checks_pass",
- ),
- pytest.param(
- False,
- TaskState(
- state=TerminalTIState.FAILED,
- end_date=timezone.datetime(2024, 12, 3, 10, 0),
- ),
- id="runtime_checks_fail",
- ),
- ],
-)
-def test_run_with_inlets_and_outlets(
- create_runtime_ti, mock_supervisor_comms, time_machine, ok,
last_expected_msg
-):
- """Test running a basic tasks with inlets and outlets."""
-
- instant = timezone.datetime(2024, 12, 3, 10, 0)
- time_machine.move_to(instant, tick=False)
-
- from airflow.providers.standard.operators.bash import BashOperator
-
- task = BashOperator(
- outlets=[
- Asset(name="name", uri="s3://bucket/my-task"),
- Asset(name="new-name", uri="s3://bucket/my-task"),
- ],
- inlets=[
- Asset(name="name", uri="s3://bucket/my-task"),
- Asset(name="new-name", uri="s3://bucket/my-task"),
- ],
- task_id="inlets-and-outlets",
- bash_command="echo 'hi'",
- )
-
- ti = create_runtime_ti(task=task, dag_id="dag_with_inlets_and_outlets")
- mock_supervisor_comms.get_message.return_value = OKResponse(
- ok=ok,
- )
-
- run(ti, context=ti.get_template_context(), log=mock.MagicMock())
-
- expected = RuntimeCheckOnTask(
- inlets=[
- AssetProfile(name="name", uri="s3://bucket/my-task", type="Asset"),
- AssetProfile(name="new-name", uri="s3://bucket/my-task",
type="Asset"),
- ],
- outlets=[
- AssetProfile(name="name", uri="s3://bucket/my-task", type="Asset"),
- AssetProfile(name="new-name", uri="s3://bucket/my-task",
type="Asset"),
- ],
- )
- mock_supervisor_comms.send_request.assert_any_call(msg=expected,
log=mock.ANY)
- mock_supervisor_comms.send_request.assert_any_call(msg=last_expected_msg,
log=mock.ANY)
-
-
@mock.patch("airflow.sdk.execution_time.task_runner.context_to_airflow_vars")
@mock.patch.dict(os.environ, {}, clear=True)
def test_execute_task_exports_env_vars(
@@ -2218,7 +2145,7 @@ class TestTriggerDagRunOperator:
ti = create_runtime_ti(dag_id="test_handle_trigger_dag_run",
run_id="test_run", task=task)
log = mock.MagicMock()
- mock_supervisor_comms.get_message.return_value = OKResponse(ok=True)
+
state, msg, _ = run(ti, ti.get_template_context(), log)
assert state == TaskInstanceState.SUCCESS