This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a9f3e20133d Add ability to get previous TI on ``RuntimeTaskInstance``
(#59712)
a9f3e20133d is described below
commit a9f3e20133df19b317045886165d292af06868a4
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Dec 23 16:20:21 2025 +0000
Add ability to get previous TI on ``RuntimeTaskInstance`` (#59712)
Add `get_previous_ti()` method to `RuntimeTaskInstance` for retrieving
previous task instance executions. Supports filtering by state,
logical_date, and map_index.
Fixes #59609
- Match RuntimeTaskInstanceProtocol signature (AwareDatetime,
TaskInstanceState)
- Use TaskInstanceState enum in tests
- Add GetPreviousTI/PreviousTIResult to triggerer and dag processor
exclusion lists
---
.../execution_api/datamodels/taskinstance.py | 15 ++
.../execution_api/routes/task_instances.py | 53 ++++++
.../src/airflow/dag_processing/processor.py | 12 ++
.../src/airflow/jobs/triggerer_job_runner.py | 10 ++
.../versions/head/test_task_instances.py | 189 +++++++++++++++++++++
airflow-core/tests/unit/jobs/test_triggerer_job.py | 1 +
task-sdk/src/airflow/sdk/api/client.py | 27 +++
.../src/airflow/sdk/api/datamodels/_generated.py | 53 ++++--
task-sdk/src/airflow/sdk/execution_time/comms.py | 21 +++
.../src/airflow/sdk/execution_time/supervisor.py | 9 +
.../src/airflow/sdk/execution_time/task_runner.py | 43 +++++
task-sdk/src/airflow/sdk/types.py | 9 +-
task-sdk/tests/task_sdk/api/test_client.py | 120 +++++++++++++
.../task_sdk/execution_time/test_supervisor.py | 52 ++++++
.../task_sdk/execution_time/test_task_runner.py | 128 ++++++++++++++
15 files changed, 726 insertions(+), 16 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index 78faf620991..134019b6e31 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -350,6 +350,21 @@ class PrevSuccessfulDagRunResponse(BaseModel):
end_date: UtcDateTime | None = None
+class PreviousTIResponse(BaseModel):
+ """Schema for response with previous TaskInstance information."""
+
+ task_id: str
+ dag_id: str
+ run_id: str
+ logical_date: UtcDateTime | None = None
+ start_date: UtcDateTime | None = None
+ end_date: UtcDateTime | None = None
+ state: str | None = None
+ try_number: int
+ map_index: int | None = -1
+ duration: float | None = None
+
+
class TaskStatesResponse(BaseModel):
"""Response for task states with run_id, task and state."""
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index c45f585f4ed..61b205864e6 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -44,6 +44,7 @@ from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.compat import HTTP_422_UNPROCESSABLE_CONTENT
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
InactiveAssetsResponse,
+ PreviousTIResponse,
PrevSuccessfulDagRunResponse,
TaskBreadcrumbsResponse,
TaskStatesResponse,
@@ -872,6 +873,58 @@ def get_task_instance_count(
return count or 0
[email protected]("/previous/{dag_id}/{task_id}", status_code=status.HTTP_200_OK)
+def get_previous_task_instance(
+ dag_id: str,
+ task_id: str,
+ session: SessionDep,
+ logical_date: Annotated[UtcDateTime | None, Query()] = None,
+ map_index: Annotated[int, Query()] = -1,
+ state: Annotated[TaskInstanceState | None, Query()] = None,
+) -> PreviousTIResponse | None:
+ """
+ Get the previous task instance matching the given criteria.
+
+ :param dag_id: DAG ID (from path)
+ :param task_id: Task ID (from path)
+ :param logical_date: If provided, finds TI with logical_date < this value
(before filter)
+ :param map_index: Map index to filter by (defaults to -1 for non-mapped
tasks)
+ :param state: If provided, filters by TaskInstance state
+ """
+ query = (
+ select(TI)
+ .join(DR, (TI.dag_id == DR.dag_id) & (TI.run_id == DR.run_id))
+ .options(joinedload(TI.dag_run))
+ .where(TI.dag_id == dag_id, TI.task_id == task_id, TI.map_index ==
map_index)
+ .order_by(DR.logical_date.desc())
+ )
+
+ if logical_date:
+ # Find TI with logical_date BEFORE the provided date (previous)
+ query = query.where(DR.logical_date < logical_date)
+
+ if state:
+ query = query.where(TI.state == state)
+
+ ti = session.scalars(query).first()
+
+ if not ti:
+ return None
+
+ return PreviousTIResponse(
+ task_id=ti.task_id,
+ dag_id=ti.dag_id,
+ run_id=ti.run_id,
+ logical_date=ti.dag_run.logical_date,
+ start_date=ti.start_date,
+ end_date=ti.end_date,
+ state=ti.state,
+ try_number=ti.try_number,
+ map_index=ti.map_index,
+ duration=ti.duration,
+ )
+
+
@router.get("/states", status_code=status.HTTP_200_OK)
def get_task_instance_states(
dag_id: str,
diff --git a/airflow-core/src/airflow/dag_processing/processor.py
b/airflow-core/src/airflow/dag_processing/processor.py
index 41181f7bca8..82711527803 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -44,6 +44,7 @@ from airflow.sdk.execution_time.comms import (
ErrorResponse,
GetConnection,
GetPreviousDagRun,
+ GetPreviousTI,
GetPrevSuccessfulDagRun,
GetTaskStates,
GetTICount,
@@ -55,6 +56,7 @@ from airflow.sdk.execution_time.comms import (
MaskSecret,
OKResponse,
PreviousDagRunResult,
+ PreviousTIResult,
PrevSuccessfulDagRunResult,
PutVariable,
TaskStatesResult,
@@ -127,6 +129,7 @@ ToManager = Annotated[
| DeleteVariable
| GetPrevSuccessfulDagRun
| GetPreviousDagRun
+ | GetPreviousTI
| GetXCom
| GetXComCount
| GetXComSequenceItem
@@ -141,6 +144,7 @@ ToDagProcessor = Annotated[
| VariableResult
| TaskStatesResult
| PreviousDagRunResult
+ | PreviousTIResult
| PrevSuccessfulDagRunResult
| ErrorResponse
| OKResponse
@@ -632,6 +636,14 @@ class DagFileProcessorProcess(WatchedSubprocess):
resp = TaskStatesResult.from_api_response(task_states_map)
else:
resp = task_states_map
+ elif isinstance(msg, GetPreviousTI):
+ resp = self.client.task_instances.get_previous(
+ dag_id=msg.dag_id,
+ task_id=msg.task_id,
+ logical_date=msg.logical_date,
+ map_index=msg.map_index,
+ state=msg.state,
+ )
else:
log.error("Unhandled request", msg=msg)
self.send_msg(
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 13ec255832f..258dd47f602 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -60,6 +60,7 @@ from airflow.sdk.execution_time.comms import (
GetDagRunState,
GetDRCount,
GetHITLDetailResponse,
+ GetPreviousTI,
GetTaskStates,
GetTICount,
GetVariable,
@@ -269,6 +270,7 @@ ToTriggerSupervisor = Annotated[
| GetTaskStates
| GetDagRunState
| GetDRCount
+ | GetPreviousTI
| GetHITLDetailResponse
| UpdateHITLDetail
| MaskSecret,
@@ -498,6 +500,14 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
resp =
TaskStatesResult.from_api_response(run_id_task_state_map)
else:
resp = run_id_task_state_map
+ elif isinstance(msg, GetPreviousTI):
+ resp = self.client.task_instances.get_previous(
+ dag_id=msg.dag_id,
+ task_id=msg.task_id,
+ logical_date=msg.logical_date,
+ map_index=msg.map_index,
+ state=msg.state,
+ )
elif isinstance(msg, UpdateHITLDetail):
api_resp = self.client.hitl.update_response(
ti_id=msg.ti_id,
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index d9f4e1f4242..d794cb4fe20 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -2138,6 +2138,195 @@ class TestGetCount:
assert response.json() == expected_count
+class TestGetPreviousTI:
+ def setup_method(self):
+ clear_db_runs()
+
+ def teardown_method(self):
+ clear_db_runs()
+
+ def test_get_previous_ti_basic(self, client, session,
create_task_instance):
+ """Test basic get_previous_ti without filters."""
+ # Create TIs with different logical dates
+ create_task_instance(
+ task_id="test_task",
+ state=State.SUCCESS,
+ logical_date=timezone.datetime(2025, 1, 1),
+ run_id="run1",
+ )
+ create_task_instance(
+ task_id="test_task",
+ state=State.SUCCESS,
+ logical_date=timezone.datetime(2025, 1, 2),
+ run_id="run2",
+ )
+ session.commit()
+
+ response = client.get(
+ "/execution/task-instances/previous/dag/test_task",
+ params={
+ "logical_date": "2025-01-02T00:00:00Z",
+ },
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert data["task_id"] == "test_task"
+ assert data["dag_id"] == "dag"
+ assert data["run_id"] == "run1"
+ assert data["state"] == State.SUCCESS
+
+ def test_get_previous_ti_with_state_filter(self, client, session,
create_task_instance):
+ """Test get_previous_ti with state filter."""
+ # Create TIs with different states
+ create_task_instance(
+ task_id="test_task",
+ state=State.FAILED,
+ logical_date=timezone.datetime(2025, 1, 1),
+ run_id="run1",
+ )
+ create_task_instance(
+ task_id="test_task",
+ state=State.SUCCESS,
+ logical_date=timezone.datetime(2025, 1, 2),
+ run_id="run2",
+ )
+ create_task_instance(
+ task_id="test_task",
+ state=State.FAILED,
+ logical_date=timezone.datetime(2025, 1, 3),
+ run_id="run3",
+ )
+ session.commit()
+
+ # Query for previous successful TI before 2025-01-03
+ response = client.get(
+ "/execution/task-instances/previous/dag/test_task",
+ params={
+ "logical_date": "2025-01-03T00:00:00Z",
+ "state": State.SUCCESS,
+ },
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert data["task_id"] == "test_task"
+ assert data["run_id"] == "run2"
+ assert data["state"] == State.SUCCESS
+
+ def test_get_previous_ti_with_map_index_filter(self, client, session,
create_task_instance):
+ """Test get_previous_ti with map_index filter for mapped tasks."""
+ # Create TIs with different map_index values
+ # map_index=0 in run1
+ create_task_instance(
+ task_id="test_task",
+ state=State.SUCCESS,
+ logical_date=timezone.datetime(2025, 1, 1),
+ run_id="run1",
+ map_index=0,
+ )
+ # map_index=1 in run1_alt (different logical date to avoid constraint)
+ create_task_instance(
+ task_id="test_task",
+ state=State.SUCCESS,
+ logical_date=timezone.datetime(2025, 1, 1, 12, 0, 0),
+ run_id="run1_alt",
+ map_index=1,
+ )
+ # map_index=0 in run2
+ create_task_instance(
+ task_id="test_task",
+ state=State.SUCCESS,
+ logical_date=timezone.datetime(2025, 1, 2),
+ run_id="run2",
+ map_index=0,
+ )
+ session.commit()
+
+ # Query for previous TI with map_index=0 before 2025-01-02
+ response = client.get(
+ "/execution/task-instances/previous/dag/test_task",
+ params={
+ "logical_date": "2025-01-02T00:00:00Z",
+ "map_index": 0,
+ },
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert data["run_id"] == "run1"
+ assert data["map_index"] == 0
+
+ def test_get_previous_ti_not_found(self, client, session):
+ """Test get_previous_ti when no previous TI exists."""
+ response = client.get(
+ "/execution/task-instances/previous/dag/test_task",
+ )
+ assert response.status_code == 200
+ assert response.json() is None
+
+ def test_get_previous_ti_returns_most_recent(self, client, session,
create_task_instance):
+ """Test that get_previous_ti returns the most recent matching TI."""
+ # Create multiple TIs
+ for i in range(5):
+ create_task_instance(
+ task_id="test_task",
+ state=State.SUCCESS,
+ logical_date=timezone.datetime(2025, 1, i + 1),
+ run_id=f"run{i + 1}",
+ )
+ session.commit()
+
+ # Query for TI before 2025-01-05
+ response = client.get(
+ "/execution/task-instances/previous/dag/test_task",
+ params={
+ "logical_date": "2025-01-05T00:00:00Z",
+ },
+ )
+ assert response.status_code == 200
+ data = response.json()
+ # Should return the most recent one (2025-01-04)
+ assert data["run_id"] == "run4"
+
+ def test_get_previous_ti_with_all_filters(self, client, session,
create_task_instance):
+ """Test get_previous_ti with all filters combined."""
+ # Create TIs with different states and map_index values
+ create_task_instance(
+ task_id="test_task",
+ state=State.SUCCESS,
+ logical_date=timezone.datetime(2025, 1, 1),
+ run_id="target_run_1",
+ map_index=0,
+ )
+ create_task_instance(
+ task_id="test_task",
+ state=State.FAILED,
+ logical_date=timezone.datetime(2025, 1, 2),
+ run_id="target_run_2",
+ map_index=0,
+ )
+ create_task_instance(
+ task_id="test_task",
+ state=State.SUCCESS,
+ logical_date=timezone.datetime(2025, 1, 3),
+ run_id="target_run_3",
+ map_index=1,
+ )
+ session.commit()
+
+ # Query for previous successful TI before 2025-01-03 with map_index=0
+ response = client.get(
+ "/execution/task-instances/previous/dag/test_task",
+ params={
+ "logical_date": "2025-01-03T00:00:00Z",
+ "state": State.SUCCESS,
+ "map_index": 0,
+ },
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert data["run_id"] == "target_run_1"
+ assert data["state"] == State.SUCCESS
+
+
class TestGetTaskStates:
def setup_method(self):
clear_db_runs()
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py
b/airflow-core/tests/unit/jobs/test_triggerer_job.py
index b1c09e3de92..b9df47ec768 100644
--- a/airflow-core/tests/unit/jobs/test_triggerer_job.py
+++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -1236,6 +1236,7 @@ class TestTriggererMessageTypes:
"XComSequenceIndexResult",
"XComSequenceSliceResult",
"PreviousDagRunResult",
+ "PreviousTIResult",
"HITLDetailRequestResult",
}
diff --git a/task-sdk/src/airflow/sdk/api/client.py
b/task-sdk/src/airflow/sdk/api/client.py
index fece35271f5..b520d0ab226 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -82,6 +82,7 @@ from airflow.sdk.execution_time.comms import (
ErrorResponse,
OKResponse,
PreviousDagRunResult,
+ PreviousTIResult,
SkipDownstreamTasks,
TaskRescheduleStartDate,
TICount,
@@ -322,6 +323,32 @@ class TaskInstanceOperations:
resp = self.client.get("task-instances/count", params=params)
return TICount(count=resp.json())
+ def get_previous(
+ self,
+ dag_id: str,
+ task_id: str,
+ logical_date: datetime | None = None,
+ map_index: int = -1,
+ state: TaskInstanceState | str | None = None,
+ ) -> PreviousTIResult:
+ """
+ Get the previous task instance matching the given criteria.
+
+ :param dag_id: DAG ID
+ :param task_id: Task ID
+ :param logical_date: If provided, finds TI with logical_date < this
value (before filter)
+ :param map_index: Map index to filter by (defaults to -1 for
non-mapped tasks)
+ :param state: If provided, filters by TaskInstance state
+ """
+ params: dict[str, Any] = {"map_index": map_index}
+ if logical_date:
+ params["logical_date"] = logical_date.isoformat()
+ if state:
+ params["state"] = state.value if isinstance(state,
TaskInstanceState) else state
+
+ resp = self.client.get(f"task-instances/previous/{dag_id}/{task_id}",
params=params)
+ return PreviousTIResult(task_instance=resp.json())
+
def get_task_states(
self,
dag_id: str,
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index 6b71bb83c9a..63e8edbf81c 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -171,6 +171,23 @@ class PrevSuccessfulDagRunResponse(BaseModel):
end_date: Annotated[AwareDatetime | None, Field(title="End Date")] = None
+class PreviousTIResponse(BaseModel):
+ """
+ Schema for response with previous TaskInstance information.
+ """
+
+ task_id: Annotated[str, Field(title="Task Id")]
+ dag_id: Annotated[str, Field(title="Dag Id")]
+ run_id: Annotated[str, Field(title="Run Id")]
+ logical_date: Annotated[AwareDatetime | None, Field(title="Logical Date")]
= None
+ start_date: Annotated[AwareDatetime | None, Field(title="Start Date")] =
None
+ end_date: Annotated[AwareDatetime | None, Field(title="End Date")] = None
+ state: Annotated[str | None, Field(title="State")] = None
+ try_number: Annotated[int, Field(title="Try Number")]
+ map_index: Annotated[int | None, Field(title="Map Index")] = -1
+ duration: Annotated[float | None, Field(title="Duration")] = None
+
+
class TIDeferredStatePayload(BaseModel):
"""
Schema for updating TaskInstance to a deferred state.
@@ -286,6 +303,27 @@ class TaskBreadcrumbsResponse(BaseModel):
breadcrumbs: Annotated[list[dict[str, Any]], Field(title="Breadcrumbs")]
+class TaskInstanceState(str, Enum):
+ """
+ All possible states that a Task Instance can be in.
+
+ Note that None is also allowed, so always use this in a type hint with
Optional.
+ """
+
+ REMOVED = "removed"
+ SCHEDULED = "scheduled"
+ QUEUED = "queued"
+ RUNNING = "running"
+ SUCCESS = "success"
+ RESTARTING = "restarting"
+ FAILED = "failed"
+ UP_FOR_RETRY = "up_for_retry"
+ UP_FOR_RESCHEDULE = "up_for_reschedule"
+ UPSTREAM_FAILED = "upstream_failed"
+ SKIPPED = "skipped"
+ DEFERRED = "deferred"
+
+
class TaskStatesResponse(BaseModel):
"""
Response for task states with run_id, task and state.
@@ -424,21 +462,6 @@ class TerminalTIState(str, Enum):
REMOVED = "removed"
-class TaskInstanceState(str, Enum):
- REMOVED = "removed"
- SCHEDULED = "scheduled"
- QUEUED = "queued"
- RUNNING = "running"
- SUCCESS = "success"
- RESTARTING = "restarting"
- FAILED = "failed"
- UP_FOR_RETRY = "up_for_retry"
- UP_FOR_RESCHEDULE = "up_for_reschedule"
- UPSTREAM_FAILED = "upstream_failed"
- SKIPPED = "skipped"
- DEFERRED = "deferred"
-
-
class WeightRule(str, Enum):
DOWNSTREAM = "downstream"
UPSTREAM = "upstream"
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index 0420122286b..7ef112a4921 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -73,6 +73,7 @@ from airflow.sdk.api.datamodels._generated import (
DagRunStateResponse,
HITLDetailRequest,
InactiveAssetsResponse,
+ PreviousTIResponse,
PrevSuccessfulDagRunResponse,
TaskBreadcrumbsResponse,
TaskInstance,
@@ -542,6 +543,13 @@ class PreviousDagRunResult(BaseModel):
type: Literal["PreviousDagRunResult"] = "PreviousDagRunResult"
+class PreviousTIResult(BaseModel):
+ """Response containing previous task instance data."""
+
+ task_instance: PreviousTIResponse | None = None
+ type: Literal["PreviousTIResult"] = "PreviousTIResult"
+
+
class PrevSuccessfulDagRunResult(PrevSuccessfulDagRunResponse):
type: Literal["PrevSuccessfulDagRunResult"] = "PrevSuccessfulDagRunResult"
@@ -655,6 +663,7 @@ ToTask = Annotated[
| DRCount
| ErrorResponse
| PrevSuccessfulDagRunResult
+ | PreviousTIResult
| SentFDs
| StartupDetails
| TaskRescheduleStartDate
@@ -866,6 +875,17 @@ class GetPreviousDagRun(BaseModel):
type: Literal["GetPreviousDagRun"] = "GetPreviousDagRun"
+class GetPreviousTI(BaseModel):
+ """Request to get previous task instance."""
+
+ dag_id: str
+ task_id: str
+ logical_date: AwareDatetime | None = None
+ map_index: int = -1
+ state: TaskInstanceState | None = None
+ type: Literal["GetPreviousTI"] = "GetPreviousTI"
+
+
class GetAssetByName(BaseModel):
name: str
type: Literal["GetAssetByName"] = "GetAssetByName"
@@ -984,6 +1004,7 @@ ToSupervisor = Annotated[
| GetDRCount
| GetPrevSuccessfulDagRun
| GetPreviousDagRun
+ | GetPreviousTI
| GetTaskRescheduleStartDate
| GetTICount
| GetTaskBreadcrumbs
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index c2c838911fe..45432ac37e5 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -88,6 +88,7 @@ from airflow.sdk.execution_time.comms import (
GetDagRunState,
GetDRCount,
GetPreviousDagRun,
+ GetPreviousTI,
GetPrevSuccessfulDagRun,
GetTaskBreadcrumbs,
GetTaskRescheduleStartDate,
@@ -1418,6 +1419,14 @@ class ActivitySubprocess(WatchedSubprocess):
logical_date=msg.logical_date,
state=msg.state,
)
+ elif isinstance(msg, GetPreviousTI):
+ resp = self.client.task_instances.get_previous(
+ dag_id=msg.dag_id,
+ task_id=msg.task_id,
+ logical_date=msg.logical_date,
+ map_index=msg.map_index,
+ state=msg.state,
+ )
elif isinstance(msg, DeleteVariable):
resp = self.client.variables.delete(msg.key)
elif isinstance(msg, ValidateInletsAndOutlets):
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 8f1e8c37da9..85fd83551d4 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -45,6 +45,7 @@ from airflow.sdk.api.client import get_hostname, getuser
from airflow.sdk.api.datamodels._generated import (
AssetProfile,
DagRun,
+ PreviousTIResponse,
TaskInstance,
TaskInstanceState,
TIRunContext,
@@ -76,12 +77,14 @@ from airflow.sdk.execution_time.comms import (
GetDagRunState,
GetDRCount,
GetPreviousDagRun,
+ GetPreviousTI,
GetTaskBreadcrumbs,
GetTaskRescheduleStartDate,
GetTaskStates,
GetTICount,
InactiveAssetsResult,
PreviousDagRunResult,
+ PreviousTIResult,
RescheduleTask,
ResendLoggingFD,
RetryTask,
@@ -481,6 +484,46 @@ class RuntimeTaskInstance(TaskInstance):
return response.dag_run
+ def get_previous_ti(
+ self,
+ state: TaskInstanceState | None = None,
+ logical_date: AwareDatetime | None = None,
+ map_index: int = -1,
+ ) -> PreviousTIResponse | None:
+ """
+ Return the previous task instance matching the given criteria.
+
+ :param state: Filter by TaskInstance state
+ :param logical_date: Filter by logical date (returns TI before this
date)
+ :param map_index: Filter by map_index (defaults to -1 for non-mapped
tasks)
+ :return: Previous task instance or None if not found
+ """
+ context = self.get_template_context()
+ dag_run = context.get("dag_run")
+
+ log = structlog.get_logger(logger_name="task")
+ log.debug("Getting previous task instance", task_id=self.task_id,
state=state)
+
+ # Use current dag run's logical_date if not provided
+ effective_logical_date = logical_date
+ if effective_logical_date is None and dag_run and dag_run.logical_date:
+ effective_logical_date = dag_run.logical_date
+
+ response = SUPERVISOR_COMMS.send(
+ msg=GetPreviousTI(
+ dag_id=self.dag_id,
+ task_id=self.task_id,
+ logical_date=effective_logical_date,
+ map_index=map_index,
+ state=state,
+ )
+ )
+
+ if TYPE_CHECKING:
+ assert isinstance(response, PreviousTIResult)
+
+ return response.task_instance
+
@staticmethod
def get_ti_count(
dag_id: str,
diff --git a/task-sdk/src/airflow/sdk/types.py
b/task-sdk/src/airflow/sdk/types.py
index c87a7ebd6d1..9bb5e4f480f 100644
--- a/task-sdk/src/airflow/sdk/types.py
+++ b/task-sdk/src/airflow/sdk/types.py
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
from pydantic import AwareDatetime, JsonValue
from airflow.sdk._shared.logging.types import Logger as Logger
- from airflow.sdk.api.datamodels._generated import TaskInstanceState
+ from airflow.sdk.api.datamodels._generated import PreviousTIResponse,
TaskInstanceState
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.definitions.asset import Asset, AssetAlias,
AssetAliasEvent, AssetRef, BaseAssetUniqueKey
from airflow.sdk.definitions.context import Context
@@ -92,6 +92,13 @@ class RuntimeTaskInstanceProtocol(Protocol):
def get_previous_dagrun(self, state: str | None = None) -> DagRunProtocol
| None: ...
+ def get_previous_ti(
+ self,
+ state: TaskInstanceState | None = None,
+ logical_date: AwareDatetime | None = None,
+ map_index: int = -1,
+ ) -> PreviousTIResponse | None: ...
+
@staticmethod
def get_ti_count(
dag_id: str,
diff --git a/task-sdk/tests/task_sdk/api/test_client.py
b/task-sdk/tests/task_sdk/api/test_client.py
index 0ebcd79c39d..22b85eeff6f 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -52,6 +52,7 @@ from airflow.sdk.execution_time.comms import (
ErrorResponse,
OKResponse,
PreviousDagRunResult,
+ PreviousTIResult,
RescheduleTask,
TaskRescheduleStartDate,
)
@@ -555,6 +556,125 @@ class TestTaskInstanceOperations:
)
assert result.task_states == {"run_id": {"group1.task1": "success",
"group1.task2": "failed"}}
+ def test_get_previous_basic(self):
+ """Test basic get_previous functionality."""
+ logical_date = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc)
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path ==
"/task-instances/previous/test_dag/test_task":
+ assert request.url.params == httpx.QueryParams(
+ logical_date=logical_date.isoformat(),
+ map_index="-1",
+ )
+ # Return complete TI data
+ return httpx.Response(
+ status_code=200,
+ json={
+ "task_id": "test_task",
+ "dag_id": "test_dag",
+ "run_id": "prev_run",
+ "logical_date": "2024-01-14T12:00:00+00:00",
+ "start_date": "2024-01-14T12:05:00+00:00",
+ "end_date": "2024-01-14T12:10:00+00:00",
+ "state": "success",
+ "try_number": 1,
+ "map_index": -1,
+ "duration": 300.0,
+ },
+ )
+ return httpx.Response(status_code=422)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.task_instances.get_previous(
+ dag_id="test_dag", task_id="test_task", logical_date=logical_date
+ )
+
+ assert isinstance(result, PreviousTIResult)
+ assert result.task_instance.task_id == "test_task"
+ assert result.task_instance.dag_id == "test_dag"
+ assert result.task_instance.run_id == "prev_run"
+ assert result.task_instance.state == "success"
+
+ def test_get_previous_with_state_filter(self):
+ """Test get_previous functionality with state filtering."""
+ logical_date = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc)
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path ==
"/task-instances/previous/test_dag/test_task":
+ assert request.url.params == httpx.QueryParams(
+ logical_date=logical_date.isoformat(),
+ map_index="-1",
+ state="success",
+ )
+ return httpx.Response(
+ status_code=200,
+ json={
+ "task_id": "test_task",
+ "dag_id": "test_dag",
+ "run_id": "prev_run",
+ "logical_date": "2024-01-14T12:00:00+00:00",
+ "start_date": "2024-01-14T12:05:00+00:00",
+ "end_date": "2024-01-14T12:10:00+00:00",
+ "state": "success",
+ "try_number": 1,
+ "map_index": -1,
+ "duration": 300.0,
+ },
+ )
+ return httpx.Response(status_code=422)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.task_instances.get_previous(
+ dag_id="test_dag", task_id="test_task", logical_date=logical_date,
state="success"
+ )
+
+ assert result.task_instance.state == "success"
+
+ def test_get_previous_with_map_index_filter(self):
+ """Test get_previous functionality with map_index filtering."""
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path ==
"/task-instances/previous/test_dag/test_task":
+ assert request.url.params == httpx.QueryParams(
+ map_index="0",
+ )
+ return httpx.Response(
+ status_code=200,
+ json={
+ "task_id": "test_task",
+ "dag_id": "test_dag",
+ "run_id": "prev_run",
+ "logical_date": "2024-01-14T12:00:00+00:00",
+ "start_date": "2024-01-14T12:05:00+00:00",
+ "end_date": "2024-01-14T12:10:00+00:00",
+ "state": "success",
+ "try_number": 1,
+ "map_index": 0,
+ "duration": 300.0,
+ },
+ )
+ return httpx.Response(status_code=422)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.task_instances.get_previous(dag_id="test_dag",
task_id="test_task", map_index=0)
+
+ assert result.task_instance.map_index == 0
+
+ def test_get_previous_not_found(self):
+ """Test get_previous when no previous TI exists returns None."""
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path ==
"/task-instances/previous/test_dag/test_task":
+ # Return None (null) when no previous TI found
+ return httpx.Response(status_code=200, content="null")
+ return httpx.Response(status_code=422)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.task_instances.get_previous(dag_id="test_dag",
task_id="test_task")
+
+ assert isinstance(result, PreviousTIResult)
+ assert result.task_instance is None
+
class TestVariableOperations:
"""
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 63291652efd..9c0eb2a79f0 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -59,6 +59,7 @@ from airflow.sdk.api.datamodels._generated import (
DagRun,
DagRunState,
DagRunType,
+ PreviousTIResponse,
TaskInstance,
TaskInstanceState,
)
@@ -87,6 +88,7 @@ from airflow.sdk.execution_time.comms import (
GetDRCount,
GetHITLDetailResponse,
GetPreviousDagRun,
+ GetPreviousTI,
GetPrevSuccessfulDagRun,
GetTaskBreadcrumbs,
GetTaskRescheduleStartDate,
@@ -102,6 +104,7 @@ from airflow.sdk.execution_time.comms import (
MaskSecret,
OKResponse,
PreviousDagRunResult,
+ PreviousTIResult,
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
@@ -2167,6 +2170,55 @@ REQUEST_TEST_CASES = [
),
test_id="get_previous_dagrun_with_state",
),
+ RequestTestCase(
+ message=GetPreviousTI(
+ dag_id="test_dag",
+ task_id="test_task",
+ logical_date=timezone.parse("2024-01-15T12:00:00Z"),
+ map_index=0,
+ state=TaskInstanceState.SUCCESS,
+ ),
+ expected_body={
+ "task_instance": {
+ "task_id": "test_task",
+ "dag_id": "test_dag",
+ "run_id": "prev_run",
+ "logical_date": timezone.parse("2024-01-14T12:00:00Z"),
+ "start_date": timezone.parse("2024-01-14T12:05:00Z"),
+ "end_date": timezone.parse("2024-01-14T12:10:00Z"),
+ "state": "success",
+ "try_number": 1,
+ "map_index": 0,
+ "duration": 300.0,
+ },
+ "type": "PreviousTIResult",
+ },
+ client_mock=ClientMock(
+ method_path="task_instances.get_previous",
+ kwargs={
+ "dag_id": "test_dag",
+ "task_id": "test_task",
+ "logical_date": timezone.parse("2024-01-15T12:00:00Z"),
+ "map_index": 0,
+ "state": TaskInstanceState.SUCCESS,
+ },
+ response=PreviousTIResult(
+ task_instance=PreviousTIResponse(
+ task_id="test_task",
+ dag_id="test_dag",
+ run_id="prev_run",
+ logical_date=timezone.parse("2024-01-14T12:00:00Z"),
+ start_date=timezone.parse("2024-01-14T12:05:00Z"),
+ end_date=timezone.parse("2024-01-14T12:10:00Z"),
+ state="success",
+ try_number=1,
+ map_index=0,
+ duration=300.0,
+ )
+ ),
+ ),
+ test_id="get_previous_ti",
+ ),
RequestTestCase(
message=GetTaskRescheduleStartDate(ti_id=TI_ID),
expected_body={
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 9914f68e265..7830d8543db 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -53,6 +53,7 @@ from airflow.sdk.api.datamodels._generated import (
AssetResponse,
DagRun,
DagRunState,
+ PreviousTIResponse,
TaskInstance,
TaskInstanceState,
TIRunContext,
@@ -84,6 +85,7 @@ from airflow.sdk.execution_time.comms import (
GetDagRunState,
GetDRCount,
GetPreviousDagRun,
+ GetPreviousTI,
GetTaskStates,
GetTICount,
GetVariable,
@@ -92,6 +94,7 @@ from airflow.sdk.execution_time.comms import (
MaskSecret,
OKResponse,
PreviousDagRunResult,
+ PreviousTIResult,
PrevSuccessfulDagRunResult,
SetRenderedFields,
SetXCom,
@@ -2143,6 +2146,131 @@ class TestRuntimeTaskInstance:
assert dr.run_id == "prev_success_run"
assert dr.state == "success"
+ def test_get_previous_ti_basic(self, create_runtime_ti,
mock_supervisor_comms):
+ """Test that get_previous_ti sends the correct request without
filters."""
+
+ task = BaseOperator(task_id="test_task")
+ dag_id = "test_dag"
+ runtime_ti = create_runtime_ti(task=task, dag_id=dag_id,
logical_date=timezone.datetime(2025, 1, 2))
+
+ ti_data = PreviousTIResponse(
+ task_id="test_task",
+ dag_id=dag_id,
+ run_id="prev_run",
+ logical_date=timezone.datetime(2025, 1, 1),
+ start_date=timezone.datetime(2025, 1, 1, 12, 0, 0),
+ end_date=timezone.datetime(2025, 1, 1, 12, 5, 0),
+ state="success",
+ try_number=1,
+ map_index=-1,
+ duration=300.0,
+ )
+
+ mock_supervisor_comms.send.return_value =
PreviousTIResult(task_instance=ti_data)
+
+ prev_ti = runtime_ti.get_previous_ti()
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ msg=GetPreviousTI(
+ dag_id="test_dag",
+ task_id="test_task",
+ logical_date=timezone.datetime(2025, 1, 2),
+ map_index=-1,
+ state=None,
+ ),
+ )
+ assert prev_ti.task_id == "test_task"
+ assert prev_ti.dag_id == "test_dag"
+ assert prev_ti.run_id == "prev_run"
+ assert prev_ti.state == "success"
+
+ def test_get_previous_ti_with_state(self, create_runtime_ti,
mock_supervisor_comms):
+ """Test get_previous_ti with state filter."""
+
+ task = BaseOperator(task_id="test_task")
+ dag_id = "test_dag"
+ runtime_ti = create_runtime_ti(task=task, dag_id=dag_id,
logical_date=timezone.datetime(2025, 1, 2))
+
+ ti_data = PreviousTIResponse(
+ task_id="test_task",
+ dag_id=dag_id,
+ run_id="prev_success_run",
+ logical_date=timezone.datetime(2025, 1, 1),
+ start_date=timezone.datetime(2025, 1, 1, 12, 0, 0),
+ end_date=timezone.datetime(2025, 1, 1, 12, 5, 0),
+ state="success",
+ try_number=1,
+ map_index=-1,
+ duration=300.0,
+ )
+
+ mock_supervisor_comms.send.return_value =
PreviousTIResult(task_instance=ti_data)
+
+ prev_ti = runtime_ti.get_previous_ti(state=TaskInstanceState.SUCCESS)
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ msg=GetPreviousTI(
+ dag_id="test_dag",
+ task_id="test_task",
+ logical_date=timezone.datetime(2025, 1, 2),
+ map_index=-1,
+ state=TaskInstanceState.SUCCESS,
+ ),
+ )
+ assert prev_ti.state == "success"
+ assert prev_ti.run_id == "prev_success_run"
+
+ def test_get_previous_ti_with_map_index(self, create_runtime_ti,
mock_supervisor_comms):
+ """Test get_previous_ti with explicit map_index filter."""
+
+ task = BaseOperator(task_id="test_task")
+ dag_id = "test_dag"
+ runtime_ti = create_runtime_ti(
+ task=task, dag_id=dag_id, logical_date=timezone.datetime(2025, 1,
2), map_index=0
+ )
+
+ ti_data = PreviousTIResponse(
+ task_id="test_task",
+ dag_id=dag_id,
+ run_id="prev_run",
+ logical_date=timezone.datetime(2025, 1, 1),
+ start_date=timezone.datetime(2025, 1, 1, 12, 0, 0),
+ end_date=timezone.datetime(2025, 1, 1, 12, 5, 0),
+ state="success",
+ try_number=1,
+ map_index=1,
+ duration=300.0,
+ )
+
+ mock_supervisor_comms.send.return_value =
PreviousTIResult(task_instance=ti_data)
+
+ # Query for a different map_index than current TI
+ prev_ti = runtime_ti.get_previous_ti(map_index=1)
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ msg=GetPreviousTI(
+ dag_id="test_dag",
+ task_id="test_task",
+ logical_date=timezone.datetime(2025, 1, 2),
+ map_index=1,
+ state=None,
+ ),
+ )
+ assert prev_ti.map_index == 1
+
+ def test_get_previous_ti_not_found(self, create_runtime_ti,
mock_supervisor_comms):
+ """Test get_previous_ti when no previous TI exists."""
+
+ task = BaseOperator(task_id="test_task")
+ dag_id = "test_dag"
+ runtime_ti = create_runtime_ti(task=task, dag_id=dag_id,
logical_date=timezone.datetime(2025, 1, 2))
+
+ mock_supervisor_comms.send.return_value =
PreviousTIResult(task_instance=None)
+
+ prev_ti = runtime_ti.get_previous_ti()
+
+ assert prev_ti is None
+
@pytest.mark.parametrize(
"map_index",
[