This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new dbff6e325a9 AIP-72: Pass context keys from API Server to Workers
(#44899)
dbff6e325a9 is described below
commit dbff6e325a9717b2b3f8b39e084034c68fbddfce
Author: Kaxil Naik <[email protected]>
AuthorDate: Mon Dec 16 21:41:18 2024 +0530
AIP-72: Pass context keys from API Server to Workers (#44899)
Part of https://github.com/apache/airflow/issues/44481
This commit augments the TI context available in the Task Execution
Interface with the one from the Execution API Server.
In future PRs the following will be added:
- More methods on TI, such as ti.xcom_pull, ti.xcom_push, etc.
- Lazy fetching of connections and variables
- Verification that "get_current_context" works
---
.../execution_api/datamodels/taskinstance.py | 39 +++++-
.../execution_api/routes/task_instances.py | 139 ++++++++++++++++-----
task_sdk/src/airflow/sdk/api/client.py | 25 +++-
.../src/airflow/sdk/api/datamodels/_generated.py | 37 ++++++
task_sdk/src/airflow/sdk/execution_time/comms.py | 2 +
.../src/airflow/sdk/execution_time/supervisor.py | 3 +-
.../src/airflow/sdk/execution_time/task_runner.py | 49 ++++++--
task_sdk/tests/api/test_client.py | 18 ++-
task_sdk/tests/conftest.py | 90 ++++++++++++-
task_sdk/tests/execution_time/test_supervisor.py | 17 ++-
task_sdk/tests/execution_time/test_task_runner.py | 97 ++++++++++++--
.../execution_api/routes/test_task_instances.py | 81 +++++++++---
12 files changed, 506 insertions(+), 91 deletions(-)
diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index bbc557d0124..92a1e933dc9 100644
--- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -25,7 +25,10 @@ from pydantic import Discriminator, Field, Tag,
WithJsonSchema
from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.core_api.base import BaseModel
+from airflow.api_fastapi.execution_api.datamodels.connection import
ConnectionResponse
+from airflow.api_fastapi.execution_api.datamodels.variable import
VariableResponse
from airflow.utils.state import IntermediateTIState, TaskInstanceState as
TIState, TerminalTIState
+from airflow.utils.types import DagRunType
class TIEnterRunningPayload(BaseModel):
@@ -94,9 +97,7 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) ->
str:
state = v.get("state")
else:
state = getattr(v, "state", None)
- if state == TIState.RUNNING:
- return str(state)
- elif state in set(TerminalTIState):
+ if state in set(TerminalTIState):
return "_terminal_"
elif state == TIState.DEFERRED:
return "deferred"
@@ -107,7 +108,6 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel)
-> str:
# and "_other_" is a catch-all for all other states that are not covered by
the other schemas.
TIStateUpdate = Annotated[
Union[
- Annotated[TIEnterRunningPayload, Tag("running")],
Annotated[TITerminalStatePayload, Tag("_terminal_")],
Annotated[TITargetStatePayload, Tag("_other_")],
Annotated[TIDeferredStatePayload, Tag("deferred")],
@@ -135,3 +135,34 @@ class TaskInstance(BaseModel):
run_id: str
try_number: int
map_index: int | None = None
+
+
+class DagRun(BaseModel):
+ """Schema for DagRun model with minimal required fields needed for
Runtime."""
+
+ # TODO: `dag_id` and `run_id` are duplicated from TaskInstance
+ # See if we can avoid sending these fields from API server and instead
+ # use the TaskInstance data to get the DAG run information in the client
(Task Execution Interface).
+ dag_id: str
+ run_id: str
+
+ logical_date: UtcDateTime
+ data_interval_start: UtcDateTime | None
+ data_interval_end: UtcDateTime | None
+ start_date: UtcDateTime
+ end_date: UtcDateTime | None
+ run_type: DagRunType
+ conf: Annotated[dict[str, Any], Field(default_factory=dict)]
+
+
+class TIRunContext(BaseModel):
+ """Response schema for TaskInstance run context."""
+
+ dag_run: DagRun
+ """DAG run information for the task instance."""
+
+ variables: Annotated[list[VariableResponse], Field(default_factory=list)]
+ """Variables that can be accessed by the task instance."""
+
+ connections: Annotated[list[ConnectionResponse],
Field(default_factory=list)]
+ """Connections that can be accessed by the task instance."""
diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow/api_fastapi/execution_api/routes/task_instances.py
index e06798209c5..3a1545283e8 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -30,12 +30,15 @@ from sqlalchemy.sql import select
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
+ DagRun,
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
+ TIRunContext,
TIStateUpdate,
TITerminalStatePayload,
)
+from airflow.models.dagrun import DagRun as DR
from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
from airflow.models.trigger import Trigger
from airflow.utils import timezone
@@ -48,6 +51,110 @@ router = AirflowRouter()
log = logging.getLogger(__name__)
[email protected](
+ "/{task_instance_id}/run",
+ status_code=status.HTTP_200_OK,
+ responses={
+ status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
+ status.HTTP_409_CONFLICT: {"description": "The TI is already in the
requested state"},
+ status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload
for the state transition"},
+ },
+)
+def ti_run(
+ task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload,
Body()], session: SessionDep
+) -> TIRunContext:
+ """
+ Run a TaskInstance.
+
+ This endpoint is used to start a TaskInstance that is in the QUEUED state.
+ """
+ # We only use UUID above for validation purposes
+ ti_id_str = str(task_instance_id)
+
+ old = select(TI.state, TI.dag_id, TI.run_id).where(TI.id ==
ti_id_str).with_for_update()
+ try:
+ (previous_state, dag_id, run_id) = session.execute(old).one()
+ except NoResultFound:
+ log.error("Task Instance %s not found", ti_id_str)
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail={
+ "reason": "not_found",
+ "message": "Task Instance not found",
+ },
+ )
+
+ # We exclude_unset to avoid updating fields that are not set in the payload
+ data = ti_run_payload.model_dump(exclude_unset=True)
+
+ query = update(TI).where(TI.id == ti_id_str).values(data)
+
+ # TODO: We will need to change this for other states like:
+ # reschedule, retry, defer etc.
+ if previous_state != State.QUEUED:
+ log.warning(
+ "Can not start Task Instance ('%s') in invalid state: %s",
+ ti_id_str,
+ previous_state,
+ )
+
+ # TODO: Pass a RFC 9457 compliant error message in "detail" field
+ # https://datatracker.ietf.org/doc/html/rfc9457
+ # to provide more information about the error
+ # FastAPI will automatically convert this to a JSON response
+ # This might be added in FastAPI in
https://github.com/fastapi/fastapi/issues/10370
+ raise HTTPException(
+ status_code=status.HTTP_409_CONFLICT,
+ detail={
+ "reason": "invalid_state",
+ "message": "TI was not in a state where it could be marked as
running",
+ "previous_state": previous_state,
+ },
+ )
+ log.info("Task with %s state started on %s ", previous_state,
ti_run_payload.hostname)
+ # Ensure there is no end date set.
+ query = query.values(
+ end_date=None,
+ hostname=ti_run_payload.hostname,
+ unixname=ti_run_payload.unixname,
+ pid=ti_run_payload.pid,
+ state=State.RUNNING,
+ )
+
+ try:
+ result = session.execute(query)
+ log.info("TI %s state updated: %s row(s) affected", ti_id_str,
result.rowcount)
+
+ dr = session.execute(
+ select(
+ DR.run_id,
+ DR.dag_id,
+ DR.data_interval_start,
+ DR.data_interval_end,
+ DR.start_date,
+ DR.end_date,
+ DR.run_type,
+ DR.conf,
+ DR.logical_date,
+ ).filter_by(dag_id=dag_id, run_id=run_id)
+ ).one_or_none()
+
+ if not dr:
+ raise ValueError(f"DagRun with dag_id={dag_id} and run_id={run_id}
not found.")
+
+ return TIRunContext(
+ dag_run=DagRun.model_validate(dr, from_attributes=True),
+ # TODO: Add the variables and connections that the task needs (and has
permission to access)
+ variables=[],
+ connections=[],
+ )
+ except SQLAlchemyError as e:
+ log.error("Error marking Task Instance state as running: %s", e)
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Database error occurred"
+ )
+
+
@router.patch(
"/{task_instance_id}/state",
status_code=status.HTTP_204_NO_CONTENT,
@@ -92,37 +199,7 @@ def ti_update_state(
query = update(TI).where(TI.id == ti_id_str).values(data)
- if isinstance(ti_patch_payload, TIEnterRunningPayload):
- if previous_state != State.QUEUED:
- log.warning(
- "Can not start Task Instance ('%s') in invalid state: %s",
- ti_id_str,
- previous_state,
- )
-
- # TODO: Pass a RFC 9457 compliant error message in "detail" field
- # https://datatracker.ietf.org/doc/html/rfc9457
- # to provide more information about the error
- # FastAPI will automatically convert this to a JSON response
- # This might be added in FastAPI in
https://github.com/fastapi/fastapi/issues/10370
- raise HTTPException(
- status_code=status.HTTP_409_CONFLICT,
- detail={
- "reason": "invalid_state",
- "message": "TI was not in a state where it could be marked
as running",
- "previous_state": previous_state,
- },
- )
- log.info("Task with %s state started on %s ", previous_state,
ti_patch_payload.hostname)
- # Ensure there is no end date set.
- query = query.values(
- end_date=None,
- hostname=ti_patch_payload.hostname,
- unixname=ti_patch_payload.unixname,
- pid=ti_patch_payload.pid,
- state=State.RUNNING,
- )
- elif isinstance(ti_patch_payload, TITerminalStatePayload):
+ if isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date,
query, session.bind)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
diff --git a/task_sdk/src/airflow/sdk/api/client.py
b/task_sdk/src/airflow/sdk/api/client.py
index 568eb3c90bd..5f08f2a6242 100644
--- a/task_sdk/src/airflow/sdk/api/client.py
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -30,10 +30,12 @@ from uuid6 import uuid7
from airflow.sdk import __version__
from airflow.sdk.api.datamodels._generated import (
ConnectionResponse,
+ DagRunType,
TerminalTIState,
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
+ TIRunContext,
TITerminalStatePayload,
ValidationError as RemoteValidationError,
VariablePostBody,
@@ -110,11 +112,12 @@ class TaskInstanceOperations:
def __init__(self, client: Client):
self.client = client
- def start(self, id: uuid.UUID, pid: int, when: datetime):
+ def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext:
"""Tell the API server that this TI has started running."""
body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(),
unixname=getuser(), start_date=when)
- self.client.patch(f"task-instances/{id}/state",
content=body.model_dump_json())
+ resp = self.client.patch(f"task-instances/{id}/run",
content=body.model_dump_json())
+ return TIRunContext.model_validate_json(resp.read())
def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime):
"""Tell the API server that this TI has reached a terminal state."""
@@ -218,7 +221,23 @@ class BearerAuth(httpx.Auth):
# This exists as a aid for debugging or local running via the `dry_run`
argument to Client. It doesn't make
# sense for returning connections etc.
def noop_handler(request: httpx.Request) -> httpx.Response:
- log.debug("Dry-run request", method=request.method, path=request.url.path)
+ path = request.url.path
+ log.debug("Dry-run request", method=request.method, path=path)
+
+ if path.startswith("/task-instances/") and path.endswith("/run"):
+ # Return a fake context
+ return httpx.Response(
+ 200,
+ json={
+ "dag_run": {
+ "dag_id": "test_dag",
+ "run_id": "test_run",
+ "logical_date": "2021-01-01T00:00:00Z",
+ "start_date": "2021-01-01T00:00:00Z",
+ "run_type": DagRunType.MANUAL,
+ },
+ },
+ )
return httpx.Response(200, json={"text": "Hello, world!"})
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index 37659ffcc1b..5a103e78fc0 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -44,6 +44,17 @@ class ConnectionResponse(BaseModel):
extra: Annotated[str | None, Field(title="Extra")] = None
+class DagRunType(str, Enum):
+ """
+ Class with DagRun types.
+ """
+
+ BACKFILL = "backfill"
+ SCHEDULED = "scheduled"
+ MANUAL = "manual"
+ ASSET_TRIGGERED = "asset_triggered"
+
+
class IntermediateTIState(str, Enum):
"""
States that a Task Instance can be in that indicate it is not yet in a
terminal or running state.
@@ -159,10 +170,36 @@ class TaskInstance(BaseModel):
map_index: Annotated[int | None, Field(title="Map Index")] = None
+class DagRun(BaseModel):
+ """
+ Schema for DagRun model with minimal required fields needed for Runtime.
+ """
+
+ dag_id: Annotated[str, Field(title="Dag Id")]
+ run_id: Annotated[str, Field(title="Run Id")]
+ logical_date: Annotated[datetime, Field(title="Logical Date")]
+ data_interval_start: Annotated[datetime | None, Field(title="Data Interval
Start")] = None
+ data_interval_end: Annotated[datetime | None, Field(title="Data Interval
End")] = None
+ start_date: Annotated[datetime, Field(title="Start Date")]
+ end_date: Annotated[datetime | None, Field(title="End Date")] = None
+ run_type: DagRunType
+ conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None
+
+
class HTTPValidationError(BaseModel):
detail: Annotated[list[ValidationError] | None, Field(title="Detail")] =
None
+class TIRunContext(BaseModel):
+ """
+ Response schema for TaskInstance run context.
+ """
+
+ dag_run: DagRun
+ variables: Annotated[list[VariableResponse] | None,
Field(title="Variables")] = None
+ connections: Annotated[list[ConnectionResponse] | None,
Field(title="Connections")] = None
+
+
class TITerminalStatePayload(BaseModel):
"""
Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or
FAILED).
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py
b/task_sdk/src/airflow/sdk/execution_time/comms.py
index 9e6093a092d..03f92c549fd 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -54,6 +54,7 @@ from airflow.sdk.api.datamodels._generated import (
TaskInstance,
TerminalTIState,
TIDeferredStatePayload,
+ TIRunContext,
VariableResponse,
XComResponse,
)
@@ -70,6 +71,7 @@ class StartupDetails(BaseModel):
Responses will come back on stdin
"""
+ ti_context: TIRunContext
type: Literal["StartupDetails"] = "StartupDetails"
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 677030b7bdc..589cae56434 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -397,7 +397,7 @@ class WatchedSubprocess:
# We've forked, but the task won't start doing anything until we
send it the StartupDetails
# message. But before we do that, we need to tell the server it's
started (so it has the chance to
# tell us "no, stop!" for any reason)
- self.client.task_instances.start(ti.id, self.pid,
datetime.now(tz=timezone.utc))
+ ti_context = self.client.task_instances.start(ti.id, self.pid,
datetime.now(tz=timezone.utc))
self._last_successful_heartbeat = time.monotonic()
except Exception:
# On any error kill that subprocess!
@@ -408,6 +408,7 @@ class WatchedSubprocess:
ti=ti,
file=os.fspath(path),
requests_fd=requests_fd,
+ ti_context=ti_context,
)
# Send the message to tell the process what it needs to execute
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index 5aca25f590e..92f400d46e2 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -23,13 +23,13 @@ import os
import sys
from datetime import datetime, timezone
from io import FileIO
-from typing import TYPE_CHECKING, Any, Generic, TextIO, TypeVar
+from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar
import attrs
import structlog
-from pydantic import BaseModel, ConfigDict, JsonValue, TypeAdapter
+from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter
-from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
+from airflow.sdk.api.datamodels._generated import TaskInstance,
TerminalTIState, TIRunContext
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.execution_time.comms import (
DeferTask,
@@ -48,9 +48,13 @@ class RuntimeTaskInstance(TaskInstance):
model_config = ConfigDict(arbitrary_types_allowed=True)
task: BaseOperator
+ _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)]
= None
+ """The Task Instance context from the API server, if any."""
def get_template_context(self):
+ # TODO: Assess if we need to pass it through
airflow.utils.timezone.coerce_datetime()
context: dict[str, Any] = {
+ # From the Task Execution interface
"dag": self.task.dag,
"inlets": self.task.inlets,
"map_index_template": self.task.map_index_template,
@@ -59,15 +63,9 @@ class RuntimeTaskInstance(TaskInstance):
"task": self.task,
"task_instance": self,
"ti": self,
- # "dag_run": dag_run,
- # "data_interval_end": timezone.coerce_datetime(data_interval.end),
- # "data_interval_start":
timezone.coerce_datetime(data_interval.start),
# "outlet_events": OutletEventAccessors(),
- # "ds": ds,
- # "ds_nodash": ds_nodash,
# "expanded_ti_count": expanded_ti_count,
# "inlet_events": InletEventsAccessors(task.inlets,
session=session),
- # "logical_date": logical_date,
# "macros": macros,
# "params": validated_params,
# "prev_data_interval_start_success":
get_prev_data_interval_start_success(),
@@ -77,15 +75,36 @@ class RuntimeTaskInstance(TaskInstance):
# "task_instance_key_str":
f"{task.dag_id}__{task.task_id}__{ds_nodash}",
# "test_mode": task_instance.test_mode,
# "triggering_asset_events":
lazy_object_proxy.Proxy(get_triggering_events),
- # "ts": ts,
- # "ts_nodash": ts_nodash,
- # "ts_nodash_with_tz": ts_nodash_with_tz,
# "var": {
# "json": VariableAccessor(deserialize_json=True),
# "value": VariableAccessor(deserialize_json=False),
# },
# "conn": ConnectionAccessor(),
}
+ if self._ti_context_from_server:
+ dag_run = self._ti_context_from_server.dag_run
+
+ logical_date = dag_run.logical_date
+ ds = logical_date.strftime("%Y-%m-%d")
+ ds_nodash = ds.replace("-", "")
+ ts = logical_date.isoformat()
+ ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
+ ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
+
+ context_from_server = {
+ # TODO: Assess if we need to pass these through
timezone.coerce_datetime
+ "dag_run": dag_run,
+ "data_interval_end": dag_run.data_interval_end,
+ "data_interval_start": dag_run.data_interval_start,
+ "logical_date": logical_date,
+ "ds": ds,
+ "ds_nodash": ds_nodash,
+ "task_instance_key_str":
f"{self.task.dag_id}__{self.task.task_id}__{ds_nodash}",
+ "ts": ts,
+ "ts_nodash": ts_nodash,
+ "ts_nodash_with_tz": ts_nodash_with_tz,
+ }
+ context.update(context_from_server)
return context
@@ -113,7 +132,11 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance:
if not isinstance(task, BaseOperator):
raise TypeError(f"task is of the wrong type, got {type(task)}, wanted
{BaseOperator}")
- return
RuntimeTaskInstance.model_construct(**what.ti.model_dump(exclude_unset=True),
task=task)
+ return RuntimeTaskInstance.model_construct(
+ **what.ti.model_dump(exclude_unset=True),
+ task=task,
+ _ti_context_from_server=what.ti_context,
+ )
SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
diff --git a/task_sdk/tests/api/test_client.py
b/task_sdk/tests/api/test_client.py
index d10531ba1bb..346c3adfcc1 100644
--- a/task_sdk/tests/api/test_client.py
+++ b/task_sdk/tests/api/test_client.py
@@ -94,23 +94,31 @@ class TestTaskInstanceOperations:
response parsing.
"""
- def test_task_instance_start(self):
+ def test_task_instance_start(self, make_ti_context):
# Simulate a successful response from the server that starts a task
ti_id = uuid6.uuid7()
+ start_date = "2024-10-31T12:00:00Z"
+ ti_context = make_ti_context(
+ start_date=start_date,
+ logical_date="2024-10-31T12:00:00Z",
+ run_type="manual",
+ )
def handle_request(request: httpx.Request) -> httpx.Response:
- if request.url.path == f"/task-instances/{ti_id}/state":
+ if request.url.path == f"/task-instances/{ti_id}/run":
actual_body = json.loads(request.read())
assert actual_body["pid"] == 100
- assert actual_body["start_date"] == "2024-10-31T12:00:00Z"
+ assert actual_body["start_date"] == start_date
assert actual_body["state"] == "running"
return httpx.Response(
- status_code=204,
+ status_code=200,
+ json=ti_context.model_dump(mode="json"),
)
return httpx.Response(status_code=400, json={"detail": "Bad
Request"})
client = make_client(transport=httpx.MockTransport(handle_request))
- client.task_instances.start(ti_id, 100, "2024-10-31T12:00:00Z")
+ resp = client.task_instances.start(ti_id, 100, start_date)
+ assert resp == ti_context
@pytest.mark.parametrize("state", [state for state in TerminalTIState])
def test_task_instance_finish(self, state):
diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py
index 04e94008842..25d0a1b0061 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/tests/conftest.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import logging
import os
from pathlib import Path
-from typing import TYPE_CHECKING, NoReturn
+from typing import TYPE_CHECKING, Any, NoReturn, Protocol
import pytest
@@ -29,8 +29,12 @@ pytest_plugins = "tests_common.pytest_plugin"
os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
if TYPE_CHECKING:
+ from datetime import datetime
+
from structlog.typing import EventDict, WrappedLogger
+ from airflow.sdk.api.datamodels._generated import TIRunContext
+
@pytest.hookimpl()
def pytest_addhooks(pluginmanager: pytest.PytestPluginManager):
@@ -116,7 +120,7 @@ def _disable_ol_plugin():
# The OpenLineage plugin imports setproctitle, and that now causes (C)
level thread calls, which on Py
# 3.12+ issues a warning when os.fork happens. So for this plugin we
disable it
- # And we load plugins when setting the priorty_weight field
+ # And we load plugins when setting the priority_weight field
import airflow.plugins_manager
old = airflow.plugins_manager.plugins
@@ -128,3 +132,85 @@ def _disable_ol_plugin():
yield
airflow.plugins_manager.plugins = None
+
+
+class MakeTIContextCallable(Protocol):
+ def __call__(
+ self,
+ dag_id: str = ...,
+ run_id: str = ...,
+ logical_date: str | datetime = ...,
+ data_interval_start: str | datetime = ...,
+ data_interval_end: str | datetime = ...,
+ start_date: str | datetime = ...,
+ run_type: str = ...,
+ ) -> TIRunContext: ...
+
+
+class MakeTIContextDictCallable(Protocol):
+ def __call__(
+ self,
+ dag_id: str = ...,
+ run_id: str = ...,
+ logical_date: str = ...,
+ data_interval_start: str | datetime = ...,
+ data_interval_end: str | datetime = ...,
+ start_date: str | datetime = ...,
+ run_type: str = ...,
+ ) -> dict[str, Any]: ...
+
+
[email protected]
+def make_ti_context() -> MakeTIContextCallable:
+ """Factory for creating TIRunContext objects."""
+ from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext
+
+ def _make_context(
+ dag_id: str = "test_dag",
+ run_id: str = "test_run",
+ logical_date: str | datetime = "2024-12-01T01:00:00Z",
+ data_interval_start: str | datetime = "2024-12-01T00:00:00Z",
+ data_interval_end: str | datetime = "2024-12-01T01:00:00Z",
+ start_date: str | datetime = "2024-12-01T01:00:00Z",
+ run_type: str = "manual",
+ ) -> TIRunContext:
+ return TIRunContext(
+ dag_run=DagRun(
+ dag_id=dag_id,
+ run_id=run_id,
+ logical_date=logical_date, # type: ignore
+ data_interval_start=data_interval_start, # type: ignore
+ data_interval_end=data_interval_end, # type: ignore
+ start_date=start_date, # type: ignore
+ run_type=run_type, # type: ignore
+ )
+ )
+
+ return _make_context
+
+
[email protected]
+def make_ti_context_dict(make_ti_context: MakeTIContextCallable) ->
MakeTIContextDictCallable:
+ """Factory for creating context dictionaries suited for API Server
response."""
+
+ def _make_context_dict(
+ dag_id: str = "test_dag",
+ run_id: str = "test_run",
+ logical_date: str | datetime = "2024-12-01T00:00:00Z",
+ data_interval_start: str | datetime = "2024-12-01T00:00:00Z",
+ data_interval_end: str | datetime = "2024-12-01T01:00:00Z",
+ start_date: str | datetime = "2024-12-01T00:00:00Z",
+ run_type: str = "manual",
+ ) -> dict[str, Any]:
+ context = make_ti_context(
+ dag_id=dag_id,
+ run_id=run_id,
+ logical_date=logical_date,
+ data_interval_start=data_interval_start,
+ data_interval_end=data_interval_end,
+ start_date=start_date,
+ run_type=run_type,
+ )
+ return context.model_dump(exclude_unset=True, mode="json")
+
+ return _make_context_dict
diff --git a/task_sdk/tests/execution_time/test_supervisor.py
b/task_sdk/tests/execution_time/test_supervisor.py
index 406b2ee2699..70f9e264864 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -254,7 +254,7 @@ class TestWatchedSubprocess:
try_number=1,
)
# Assert Exit Code is 0
- assert supervise(ti=ti, dag_path=dagfile_path, token="", server="",
dry_run=True) == 0
+ assert supervise(ti=ti, dag_path=dagfile_path, token="", server="",
dry_run=True) == 0, captured_logs
# We should have a log from the task!
assert {
@@ -265,7 +265,9 @@ class TestWatchedSubprocess:
"timestamp": "2024-11-07T12:34:56.078901Z",
} in captured_logs
- def test_supervise_handles_deferred_task(self, test_dags_dir,
captured_logs, time_machine, mocker):
+ def test_supervise_handles_deferred_task(
+ self, test_dags_dir, captured_logs, time_machine, mocker,
make_ti_context
+ ):
"""
Test that the supervisor handles a deferred task correctly.
@@ -281,12 +283,13 @@ class TestWatchedSubprocess:
# Create a mock client to assert calls to the client
# We assume the implementation of the client is correct and only need
to check the calls
mock_client = mocker.Mock(spec=sdk_client.Client)
+ mock_client.task_instances.start.return_value = make_ti_context()
instant = tz.datetime(2024, 11, 7, 12, 34, 56, 0)
time_machine.move_to(instant, tick=False)
# Assert supervisor runs the task successfully
- assert supervise(ti=ti, dag_path=dagfile_path, token="",
client=mock_client) == 0
+ assert supervise(ti=ti, dag_path=dagfile_path, token="",
client=mock_client) == 0, captured_logs
# Validate calls to the client
mock_client.task_instances.start.assert_called_once_with(ti.id,
mocker.ANY, mocker.ANY)
@@ -320,7 +323,7 @@ class TestWatchedSubprocess:
# The API Server would return a 409 Conflict status code if the TI is
not
# in a "queued" state.
def handle_request(request: httpx.Request) -> httpx.Response:
- if request.url.path == f"/task-instances/{ti.id}/state":
+ if request.url.path == f"/task-instances/{ti.id}/run":
return httpx.Response(
409,
json={
@@ -345,7 +348,7 @@ class TestWatchedSubprocess:
}
@pytest.mark.parametrize("captured_logs", [logging.ERROR], indirect=True,
ids=["log_level=error"])
- def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch,
mocker):
+ def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch,
mocker, make_ti_context_dict):
"""
Test that ensures that the Supervisor does not cause the task to fail
if the Task Instance is no longer
in the running state. Instead, it logs the error and terminates the
task process if it
@@ -383,7 +386,9 @@ class TestWatchedSubprocess:
"current_state": "success",
},
)
- # Return a 204 for all other requests like the initial call to
mark the task as running
+ elif request.url.path == f"/task-instances/{ti_id}/run":
+ return httpx.Response(200, json=make_ti_context_dict())
+ # Return a 204 for all other requests
return httpx.Response(status_code=204)
proc = WatchedSubprocess.start(
diff --git a/task_sdk/tests/execution_time/test_task_runner.py
b/task_sdk/tests/execution_time/test_task_runner.py
index c9755c252bb..2b812c92a73 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -94,7 +94,12 @@ class TestCommsDecoder:
w.makefile("wb").write(
b'{"type":"StartupDetails", "ti": {'
b'"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a",
"try_number": 1, "run_id": "b", "dag_id": "c" }, '
- b'"file": "/dev/null", "requests_fd": ' +
str(w2.fileno()).encode("ascii") + b"}\n"
+
b'"ti_context":{"dag_run":{"dag_id":"c","run_id":"b","logical_date":"2024-12-01T01:00:00Z",'
+
b'"data_interval_start":"2024-12-01T00:00:00Z","data_interval_end":"2024-12-01T01:00:00Z",'
+
b'"start_date":"2024-12-01T01:00:00Z","end_date":null,"run_type":"manual","conf":null},'
+ b'"variables":null,"connections":null},"file": "/dev/null",
"requests_fd": '
+ + str(w2.fileno()).encode("ascii")
+ + b"}\n"
)
decoder = CommsDecoder(input=r.makefile("r"))
@@ -112,12 +117,13 @@ class TestCommsDecoder:
assert decoder.request_socket.fileno() == w2.fileno()
-def test_parse(test_dags_dir: Path):
+def test_parse(test_dags_dir: Path, make_ti_context):
"""Test that checks parsing of a basic dag with an un-mocked parse."""
what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id="a", dag_id="super_basic",
run_id="c", try_number=1),
file=str(test_dags_dir / "super_basic.py"),
requests_fd=0,
+ ti_context=make_ti_context(),
)
ti = parse(what)
@@ -128,12 +134,13 @@ def test_parse(test_dags_dir: Path):
assert isinstance(ti.task.dag, DAG)
-def test_run_basic(time_machine, mocked_parse):
+def test_run_basic(time_machine, mocked_parse, make_ti_context):
"""Test running a basic task."""
what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run",
run_id="c", try_number=1),
file="",
requests_fd=0,
+ ti_context=make_ti_context(),
)
instant = timezone.datetime(2024, 12, 3, 10, 0)
@@ -150,7 +157,7 @@ def test_run_basic(time_machine, mocked_parse):
)
-def test_run_deferred_basic(time_machine, mocked_parse):
+def test_run_deferred_basic(time_machine, mocked_parse, make_ti_context):
"""Test that a task can transition to a deferred state."""
import datetime
@@ -169,6 +176,7 @@ def test_run_deferred_basic(time_machine, mocked_parse):
ti=TaskInstance(id=uuid7(), task_id="async",
dag_id="basic_deferred_run", run_id="c", try_number=1),
file="",
requests_fd=0,
+ ti_context=make_ti_context(),
)
# Expected DeferTask
@@ -194,7 +202,7 @@ def test_run_deferred_basic(time_machine, mocked_parse):
mock_supervisor_comms.send_request.assert_called_once_with(msg=expected_defer_task,
log=mock.ANY)
-def test_run_basic_skipped(time_machine, mocked_parse):
+def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context):
"""Test running a basic task that marks itself skipped."""
from airflow.providers.standard.operators.python import PythonOperator
@@ -209,6 +217,7 @@ def test_run_basic_skipped(time_machine, mocked_parse):
ti=TaskInstance(id=uuid7(), task_id="skip", dag_id="basic_skipped",
run_id="c", try_number=1),
file="",
requests_fd=0,
+ ti_context=make_ti_context(),
)
ti = mocked_parse(what, "basic_skipped", task)
@@ -226,7 +235,7 @@ def test_run_basic_skipped(time_machine, mocked_parse):
)
-def test_startup_basic_templated_dag(mocked_parse):
+def test_startup_basic_templated_dag(mocked_parse, make_ti_context):
"""Test running a DAG with templated task."""
from airflow.providers.standard.operators.bash import BashOperator
@@ -241,6 +250,7 @@ def test_startup_basic_templated_dag(mocked_parse):
),
file="",
requests_fd=0,
+ ti_context=make_ti_context(),
)
mocked_parse(what, "basic_templated_dag", task)
@@ -288,7 +298,9 @@ def test_startup_basic_templated_dag(mocked_parse):
),
],
)
-def test_startup_dag_with_templated_fields(mocked_parse, task_params,
expected_rendered_fields):
+def test_startup_dag_with_templated_fields(
+ mocked_parse, task_params, expected_rendered_fields, make_ti_context
+):
"""Test startup of a DAG with various templated fields."""
class CustomOperator(BaseOperator):
@@ -305,6 +317,7 @@ def test_startup_dag_with_templated_fields(mocked_parse,
task_params, expected_r
ti=TaskInstance(id=uuid7(), task_id="templated_task",
dag_id="basic_dag", run_id="c", try_number=1),
file="",
requests_fd=0,
+ ti_context=make_ti_context(),
)
mocked_parse(what, "basic_dag", task)
@@ -318,3 +331,73 @@ def test_startup_dag_with_templated_fields(mocked_parse,
task_params, expected_r
msg=SetRenderedFields(rendered_fields=expected_rendered_fields),
log=mock.ANY,
)
+
+
+class TestRuntimeTaskInstance:
+ def test_get_context_without_ti_context_from_server(self, mocked_parse,
make_ti_context):
+ """Test get_template_context when _ti_context_from_server has not been set."""
+
+ task = BaseOperator(task_id="hello")
+
+ ti_id = uuid7()
+ ti = TaskInstance(
+ id=ti_id, task_id=task.task_id, dag_id="basic_task",
run_id="test_run", try_number=1
+ )
+
+ what = StartupDetails(ti=ti, file="", requests_fd=0,
ti_context=make_ti_context())
+ runtime_ti = mocked_parse(what, ti.dag_id, task)
+ context = runtime_ti.get_template_context()
+
+ # Without server context only task-derived keys are present (no dag_run, ds, ts, ...)
+ assert context == {
+ "dag": runtime_ti.task.dag,
+ "inlets": task.inlets,
+ "map_index_template": task.map_index_template,
+ "outlets": task.outlets,
+ "run_id": "test_run",
+ "task": task,
+ "task_instance": runtime_ti,
+ "ti": runtime_ti,
+ }
+
+ def test_get_context_with_ti_context_from_server(self, mocked_parse,
make_ti_context):
+ """Test that server-provided context keys (dag_run, ds, ts, ...) are added when sent from the API server
(mocked)."""
+ from airflow.utils import timezone
+
+ ti = TaskInstance(id=uuid7(), task_id="hello", dag_id="basic_task",
run_id="test_run", try_number=1)
+
+ task = BaseOperator(task_id=ti.task_id)
+
+ ti_context = make_ti_context(dag_id=ti.dag_id, run_id=ti.run_id)
+ what = StartupDetails(ti=ti, file="", requests_fd=0,
ti_context=ti_context)
+
+ runtime_ti = mocked_parse(what, ti.dag_id, task)
+
+ # Simulate receiving the context from the API server by setting it directly;
+ # `task_sdk/tests/api/test_client.py::test_task_instance_start` checks
the context is received
+ # from the API server
+ runtime_ti._ti_context_from_server = ti_context
+ dr = ti_context.dag_run
+
+ context = runtime_ti.get_template_context()
+
+ assert context == {
+ "dag": runtime_ti.task.dag,
+ "inlets": task.inlets,
+ "map_index_template": task.map_index_template,
+ "outlets": task.outlets,
+ "run_id": "test_run",
+ "task": task,
+ "task_instance": runtime_ti,
+ "ti": runtime_ti,
+ "dag_run": dr,
+ "data_interval_end": timezone.datetime(2024, 12, 1, 1, 0, 0),
+ "data_interval_start": timezone.datetime(2024, 12, 1, 0, 0, 0),
+ "logical_date": timezone.datetime(2024, 12, 1, 1, 0, 0),
+ "ds": "2024-12-01",
+ "ds_nodash": "20241201",
+ "task_instance_key_str": "basic_task__hello__20241201",
+ "ts": "2024-12-01T01:00:00+00:00",
+ "ts_nodash": "20241201T010000",
+ "ts_nodash_with_tz": "20241201T010000+0000",
+ }
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py
b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index 15e56bbc587..e67d82a718c 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -39,40 +39,58 @@ DEFAULT_START_DATE = timezone.parse("2024-10-31T11:00:00Z")
DEFAULT_END_DATE = timezone.parse("2024-10-31T12:00:00Z")
-class TestTIUpdateState:
+class TestTIRunState:
def setup_method(self):
clear_db_runs()
def teardown_method(self):
clear_db_runs()
- def test_ti_update_state_to_running(self, client, session,
create_task_instance):
+ def test_ti_run_state_to_running(self, client, session,
create_task_instance, time_machine):
"""
Test that the Task Instance state is updated to running when the Task
Instance is in a state where it can be
marked as running.
"""
+ instant_str = "2024-09-30T12:00:00Z"
+ instant = timezone.parse(instant_str)
+ time_machine.move_to(instant, tick=False)
ti = create_task_instance(
- task_id="test_ti_update_state_to_running",
+ task_id="test_ti_run_state_to_running",
state=State.QUEUED,
session=session,
+ start_date=instant,
)
session.commit()
response = client.patch(
- f"/execution/task-instances/{ti.id}/state",
+ f"/execution/task-instances/{ti.id}/run",
json={
"state": "running",
"hostname": "random-hostname",
"unixname": "random-unixname",
"pid": 100,
- "start_date": "2024-10-31T12:00:00Z",
+ "start_date": instant_str,
},
)
- assert response.status_code == 204
- assert response.text == ""
+ assert response.status_code == 200
+ assert response.json() == {
+ "dag_run": {
+ "dag_id": "dag",
+ "run_id": "test",
+ "logical_date": instant_str,
+ "data_interval_start":
instant.subtract(days=1).to_iso8601_string(),
+ "data_interval_end": instant_str,
+ "start_date": instant_str,
+ "end_date": None,
+ "run_type": "manual",
+ "conf": {},
+ },
+ "variables": [],
+ "connections": [],
+ }
# Refresh the Task Instance from the database so that we can check the
updated values
session.refresh(ti)
@@ -80,10 +98,10 @@ class TestTIUpdateState:
assert ti.hostname == "random-hostname"
assert ti.unixname == "random-unixname"
assert ti.pid == 100
- assert ti.start_date.isoformat() == "2024-10-31T12:00:00+00:00"
+ assert ti.start_date == instant
@pytest.mark.parametrize("initial_ti_state", [s for s in TaskInstanceState
if s != State.QUEUED])
- def test_ti_update_state_conflict_if_not_queued(
+ def test_ti_run_state_conflict_if_not_queued(
self, client, session, create_task_instance, initial_ti_state
):
"""
@@ -91,13 +109,13 @@ class TestTIUpdateState:
running. In this case, the Task Instance is first in NONE state so it
cannot be marked as running.
"""
ti = create_task_instance(
- task_id="test_ti_update_state_conflict_if_not_queued",
+ task_id="test_ti_run_state_conflict_if_not_queued",
state=initial_ti_state,
)
session.commit()
response = client.patch(
- f"/execution/task-instances/{ti.id}/state",
+ f"/execution/task-instances/{ti.id}/run",
json={
"state": "running",
"hostname": "random-hostname",
@@ -118,6 +136,14 @@ class TestTIUpdateState:
assert session.scalar(select(TaskInstance.state).where(TaskInstance.id
== ti.id)) == initial_ti_state
+
+class TestTIUpdateState:
+ def setup_method(self):
+ clear_db_runs()
+
+ def teardown_method(self):
+ clear_db_runs()
+
@pytest.mark.parametrize(
("state", "end_date", "expected_state"),
[
@@ -160,7 +186,7 @@ class TestTIUpdateState:
task_instance_id = "0182e924-0f1e-77e6-ab50-e977118bc139"
# Pre-condition: the Task Instance does not exist
- assert session.scalar(select(TaskInstance.id).where(TaskInstance.id ==
task_instance_id)) is None
+ assert session.get(TaskInstance, task_instance_id) is None
payload = {"state": "success", "end_date": "2024-10-31T12:30:00Z"}
@@ -171,6 +197,26 @@ class TestTIUpdateState:
"message": "Task Instance not found",
}
+ def test_ti_update_state_running_errors(self, client, session,
create_task_instance):
+ """
+ Test that a 422 error is returned when the Task Instance state is
RUNNING in the payload.
+
+ Task should be set to Running state via the
/execution/task-instances/{task_instance_id}/run endpoint.
+ """
+
+ ti = create_task_instance(
+ task_id="test_ti_update_state_running_errors",
+ state=State.QUEUED,
+ session=session,
+ start_date=DEFAULT_START_DATE,
+ )
+
+ session.commit()
+
+ response = client.patch(f"/execution/task-instances/{ti.id}/state",
json={"state": "running"})
+
+ assert response.status_code == 422
+
def test_ti_update_state_database_error(self, client, session,
create_task_instance):
"""
Test that a database error is handled correctly when updating the Task
Instance state.
@@ -181,17 +227,14 @@ class TestTIUpdateState:
)
session.commit()
payload = {
- "state": "running",
- "hostname": "random-hostname",
- "unixname": "random-unixname",
- "pid": 100,
- "start_date": "2024-10-31T12:00:00Z",
+ "state": "success",
+ "end_date": "2024-10-31T12:00:00Z",
}
with mock.patch(
"airflow.api_fastapi.common.db.common.Session.execute",
side_effect=[
- mock.Mock(one=lambda: ("queued",)), # First call returns
"queued"
+ mock.Mock(one=lambda: ("running",)), # First call returns
"running"
SQLAlchemyError("Database error"), # Second call raises an
error
],
):
@@ -334,7 +377,7 @@ class TestTIHealthEndpoint:
task_instance_id = "0182e924-0f1e-77e6-ab50-e977118bc139"
# Pre-condition: the Task Instance does not exist
- assert session.scalar(select(TaskInstance.id).where(TaskInstance.id ==
task_instance_id)) is None
+ assert session.get(TaskInstance, task_instance_id) is None
response = client.put(
f"/execution/task-instances/{task_instance_id}/heartbeat",