This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 41b151e7dde AIP-72: Get Previous Successful Dag Run in Task Context 
(#45813)
41b151e7dde is described below

commit 41b151e7dde473ec445f9f78fb4b8db826c368fc
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Jan 21 11:14:25 2025 +0530

    AIP-72: Get Previous Successful Dag Run in Task Context (#45813)
    
    closes https://github.com/apache/airflow/issues/45814
    
    Adds the following keys to the Task Context:
    
    - prev_data_interval_start_success
    - prev_data_interval_end_success
    - prev_start_date_success
    - prev_end_date_success
---
 .../execution_api/datamodels/taskinstance.py       |  9 +++
 .../execution_api/routes/task_instances.py         | 39 ++++++++++++-
 task_sdk/src/airflow/sdk/api/client.py             | 11 ++++
 .../src/airflow/sdk/api/datamodels/_generated.py   | 28 +++++-----
 task_sdk/src/airflow/sdk/execution_time/comms.py   | 46 ++++++++++++---
 task_sdk/src/airflow/sdk/execution_time/context.py | 27 ++++++++-
 .../src/airflow/sdk/execution_time/supervisor.py   |  6 ++
 .../src/airflow/sdk/execution_time/task_runner.py  | 18 ++++--
 task_sdk/tests/execution_time/test_supervisor.py   | 20 +++++++
 task_sdk/tests/execution_time/test_task_runner.py  | 37 +++++++++++-
 .../execution_api/routes/test_task_instances.py    | 65 ++++++++++++++++++++++
 11 files changed, 278 insertions(+), 28 deletions(-)

diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py 
b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index 563b32a2693..bb6d643fb8e 100644
--- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -199,3 +199,12 @@ class TIRunContext(BaseModel):
 
     connections: Annotated[list[ConnectionResponse], 
Field(default_factory=list)]
     """Connections that can be accessed by the task instance."""
+
+
+class PrevSuccessfulDagRunResponse(BaseModel):
+    """Schema for response with previous successful DagRun information for 
Task Template Context."""
+
+    data_interval_start: UtcDateTime | None = None
+    data_interval_end: UtcDateTime | None = None
+    start_date: UtcDateTime | None = None
+    end_date: UtcDateTime | None = None
diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py 
b/airflow/api_fastapi/execution_api/routes/task_instances.py
index ba6ea0c14b6..4899e93c612 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -31,6 +31,7 @@ from airflow.api_fastapi.common.db.common import SessionDep
 from airflow.api_fastapi.common.router import AirflowRouter
 from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
     DagRun,
+    PrevSuccessfulDagRunResponse,
     TIDeferredStatePayload,
     TIEnterRunningPayload,
     TIHeartbeatInfo,
@@ -45,7 +46,7 @@ from airflow.models.taskreschedule import TaskReschedule
 from airflow.models.trigger import Trigger
 from airflow.models.xcom import XCom
 from airflow.utils import timezone
-from airflow.utils.state import State, TerminalTIState
+from airflow.utils.state import DagRunState, State, TerminalTIState
 
 # TODO: Add dependency on JWT token
 router = AirflowRouter()
@@ -393,6 +394,42 @@ def ti_put_rtif(
     return {"message": "Rendered task instance fields successfully set"}
 
 
+@router.get(
+    "/{task_instance_id}/previous-successful-dagrun",
+    status_code=status.HTTP_200_OK,
+    responses={
+        status.HTTP_404_NOT_FOUND: {"description": "Task Instance or Dag Run 
not found"},
+    },
+)
+def get_previous_successful_dagrun(
+    task_instance_id: UUID, session: SessionDep
+) -> PrevSuccessfulDagRunResponse:
+    """
+    Get the previous successful DagRun for a TaskInstance.
+
+    The data from this endpoint is used to get values for Task Context.
+    """
+    ti_id_str = str(task_instance_id)
+    task_instance = session.scalar(select(TI).where(TI.id == ti_id_str))
+    if not task_instance:
+        return PrevSuccessfulDagRunResponse()
+
+    dag_run = session.scalar(
+        select(DR)
+        .where(
+            DR.dag_id == task_instance.dag_id,
+            DR.logical_date < task_instance.logical_date,
+            DR.state == DagRunState.SUCCESS,
+        )
+        .order_by(DR.logical_date.desc())
+        .limit(1)
+    )
+    if not dag_run:
+        return PrevSuccessfulDagRunResponse()
+
+    return PrevSuccessfulDagRunResponse.model_validate(dag_run)
+
+
 def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
     """Is task instance is eligible for retry."""
     if state == State.RESTARTING:
diff --git a/task_sdk/src/airflow/sdk/api/client.py 
b/task_sdk/src/airflow/sdk/api/client.py
index e73e5aebea6..b984669aa74 100644
--- a/task_sdk/src/airflow/sdk/api/client.py
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -37,6 +37,7 @@ from airflow.sdk.api.datamodels._generated import (
     AssetResponse,
     ConnectionResponse,
     DagRunType,
+    PrevSuccessfulDagRunResponse,
     TerminalTIState,
     TIDeferredStatePayload,
     TIEnterRunningPayload,
@@ -161,6 +162,15 @@ class TaskInstanceOperations:
         # decouple from the server response string
         return {"ok": True}
 
+    def get_previous_successful_dagrun(self, id: uuid.UUID) -> 
PrevSuccessfulDagRunResponse:
+        """
+        Get the previous successful dag run for a given task instance.
+
+        The data from it is used to get values for Task Context.
+        """
+        resp = 
self.client.get(f"task-instances/{id}/previous-successful-dagrun")
+        return PrevSuccessfulDagRunResponse.model_validate_json(resp.read())
+
 
 class ConnectionOperations:
     __slots__ = ("client",)
@@ -181,6 +191,7 @@ class ConnectionOperations:
                     status_code=e.response.status_code,
                 )
                 return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND, 
detail={"conn_id": conn_id})
+            raise
         return ConnectionResponse.model_validate_json(resp.read())
 
 
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py 
b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index f0a04da21c8..7d8bd25e959 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -29,13 +29,15 @@ from uuid import UUID
 from pydantic import BaseModel, ConfigDict, Field
 
 
-class AssetAliasResponse(BaseModel):
+class AssetResponse(BaseModel):
     """
-    Asset alias schema with fields that are needed for Runtime.
+    Asset schema for responses with fields that are needed for Runtime.
     """
 
     name: Annotated[str, Field(title="Name")]
+    uri: Annotated[str, Field(title="Uri")]
     group: Annotated[str, Field(title="Group")]
+    extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None
 
 
 class ConnectionResponse(BaseModel):
@@ -78,6 +80,17 @@ class IntermediateTIState(str, Enum):
     DEFERRED = "deferred"
 
 
+class PrevSuccessfulDagRunResponse(BaseModel):
+    """
+    Schema for response with previous successful DagRun information for Task 
Template Context.
+    """
+
+    data_interval_start: Annotated[datetime | None, Field(title="Data Interval 
Start")] = None
+    data_interval_end: Annotated[datetime | None, Field(title="Data Interval 
End")] = None
+    start_date: Annotated[datetime | None, Field(title="Start Date")] = None
+    end_date: Annotated[datetime | None, Field(title="End Date")] = None
+
+
 class TIDeferredStatePayload(BaseModel):
     """
     Schema for updating TaskInstance to a deferred state.
@@ -196,17 +209,6 @@ class TaskInstance(BaseModel):
     hostname: Annotated[str | None, Field(title="Hostname")] = None
 
 
-class AssetResponse(BaseModel):
-    """
-    Asset schema for responses with fields that are needed for Runtime.
-    """
-
-    name: Annotated[str, Field(title="Name")]
-    uri: Annotated[str, Field(title="Uri")]
-    group: Annotated[str, Field(title="Group")]
-    extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None
-
-
 class DagRun(BaseModel):
     """
     Schema for DagRun model with minimal required fields needed for Runtime.
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py 
b/task_sdk/src/airflow/sdk/execution_time/comms.py
index f8aaab65af4..007e3fe10fe 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -45,6 +45,7 @@ from __future__ import annotations
 
 from datetime import datetime
 from typing import Annotated, Literal, Union
+from uuid import UUID
 
 from fastapi import Body
 from pydantic import BaseModel, ConfigDict, Field, JsonValue
@@ -53,6 +54,7 @@ from airflow.sdk.api.datamodels._generated import (
     AssetResponse,
     BundleInfo,
     ConnectionResponse,
+    PrevSuccessfulDagRunResponse,
     TaskInstance,
     TerminalTIState,
     TIDeferredStatePayload,
@@ -146,6 +148,20 @@ class VariableResult(VariableResponse):
         return cls(**variable_response.model_dump(exclude_defaults=True), 
type="VariableResult")
 
 
+class PrevSuccessfulDagRunResult(PrevSuccessfulDagRunResponse):
+    type: Literal["PrevSuccessfulDagRunResult"] = "PrevSuccessfulDagRunResult"
+
+    @classmethod
+    def from_dagrun_response(cls, prev_dag_run: PrevSuccessfulDagRunResponse) 
-> PrevSuccessfulDagRunResult:
+        """
+        Get a result object from response object.
+
+        PrevSuccessfulDagRunResponse is autogenerated from the API schema, so 
we need to convert it to
+        PrevSuccessfulDagRunResult for communication between the Supervisor 
and the task process.
+        """
+        return cls(**prev_dag_run.model_dump(exclude_defaults=True), 
type="PrevSuccessfulDagRunResult")
+
+
 class ErrorResponse(BaseModel):
     error: ErrorType = ErrorType.GENERIC_ERROR
     detail: dict | None = None
@@ -153,7 +169,15 @@ class ErrorResponse(BaseModel):
 
 
 ToTask = Annotated[
-    Union[StartupDetails, XComResult, ConnectionResult, VariableResult, 
ErrorResponse, AssetResult],
+    Union[
+        AssetResult,
+        ConnectionResult,
+        ErrorResponse,
+        PrevSuccessfulDagRunResult,
+        StartupDetails,
+        VariableResult,
+        XComResult,
+    ],
     Field(discriminator="type"),
 ]
 
@@ -261,19 +285,25 @@ class GetAssetByUri(BaseModel):
     type: Literal["GetAssetByUri"] = "GetAssetByUri"
 
 
+class GetPrevSuccessfulDagRun(BaseModel):
+    ti_id: UUID
+    type: Literal["GetPrevSuccessfulDagRun"] = "GetPrevSuccessfulDagRun"
+
+
 ToSupervisor = Annotated[
     Union[
-        TaskState,
-        GetXCom,
-        GetConnection,
-        GetVariable,
+        DeferTask,
         GetAssetByName,
         GetAssetByUri,
-        DeferTask,
+        GetConnection,
+        GetPrevSuccessfulDagRun,
+        GetVariable,
+        GetXCom,
         PutVariable,
-        SetXCom,
-        SetRenderedFields,
         RescheduleTask,
+        SetRenderedFields,
+        SetXCom,
+        TaskState,
     ],
     Field(discriminator="type"),
 ]
diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py 
b/task_sdk/src/airflow/sdk/execution_time/context.py
index a068b53aec7..984919ea1c8 100644
--- a/task_sdk/src/airflow/sdk/execution_time/context.py
+++ b/task_sdk/src/airflow/sdk/execution_time/context.py
@@ -18,6 +18,7 @@ from __future__ import annotations
 
 import contextlib
 from collections.abc import Generator, Iterator, Mapping
+from functools import cache
 from typing import TYPE_CHECKING, Any, Union
 
 import attrs
@@ -39,10 +40,17 @@ from airflow.sdk.definitions.asset import (
 from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
 
 if TYPE_CHECKING:
+    from uuid import UUID
+
     from airflow.sdk.definitions.connection import Connection
     from airflow.sdk.definitions.context import Context
     from airflow.sdk.definitions.variable import Variable
-    from airflow.sdk.execution_time.comms import AssetResult, 
ConnectionResult, VariableResult
+    from airflow.sdk.execution_time.comms import (
+        AssetResult,
+        ConnectionResult,
+        PrevSuccessfulDagRunResponse,
+        VariableResult,
+    )
 
 log = structlog.get_logger(logger_name="task")
 
@@ -272,6 +280,23 @@ class OutletEventAccessors(Mapping[Union[Asset, 
AssetAlias], OutletEventAccessor
         return Asset(**msg.model_dump(exclude={"type"}))
 
 
+@cache  # Prevent multiple API access.
+def get_previous_dagrun_success(ti_id: UUID) -> PrevSuccessfulDagRunResponse:
+    from airflow.sdk.execution_time.comms import (
+        GetPrevSuccessfulDagRun,
+        PrevSuccessfulDagRunResponse,
+        PrevSuccessfulDagRunResult,
+    )
+    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+    SUPERVISOR_COMMS.send_request(log=log, 
msg=GetPrevSuccessfulDagRun(ti_id=ti_id))
+    msg = SUPERVISOR_COMMS.get_message()
+
+    if TYPE_CHECKING:
+        assert isinstance(msg, PrevSuccessfulDagRunResult)
+    return PrevSuccessfulDagRunResponse(**msg.model_dump(exclude={"type"}))
+
+
 @contextlib.contextmanager
 def set_current_context(context: Context) -> Generator[Context, None, None]:
     """
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index bd50ee5126b..45da306722f 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -67,8 +67,10 @@ from airflow.sdk.execution_time.comms import (
     GetAssetByName,
     GetAssetByUri,
     GetConnection,
+    GetPrevSuccessfulDagRun,
     GetVariable,
     GetXCom,
+    PrevSuccessfulDagRunResult,
     PutVariable,
     RescheduleTask,
     SetRenderedFields,
@@ -798,6 +800,10 @@ class ActivitySubprocess(WatchedSubprocess):
             asset_resp = self.client.assets.get(uri=msg.uri)
             asset_result = AssetResult.from_asset_response(asset_resp)
             resp = asset_result.model_dump_json(exclude_unset=True).encode()
+        elif isinstance(msg, GetPrevSuccessfulDagRun):
+            dagrun_resp = 
self.client.task_instances.get_previous_successful_dagrun(self.id)
+            dagrun_result = 
PrevSuccessfulDagRunResult.from_dagrun_response(dagrun_resp)
+            resp = dagrun_result.model_dump_json(exclude_unset=True).encode()
         else:
             log.error("Unhandled request", msg=msg)
             return
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index d4816c8ae59..230f5414a87 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -28,6 +28,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar
 
 import attrs
+import lazy_object_proxy
 import structlog
 from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter
 
@@ -52,6 +53,7 @@ from airflow.sdk.execution_time.context import (
     MacrosAccessor,
     OutletEventAccessors,
     VariableAccessor,
+    get_previous_dagrun_success,
     set_current_context,
 )
 from airflow.utils.net import get_hostname
@@ -100,10 +102,6 @@ class RuntimeTaskInstance(TaskInstance):
             "macros": MacrosAccessor(),
             # "params": validated_params,
             # TODO: Make this go through Public API longer term.
-            # "prev_data_interval_start_success": 
get_prev_data_interval_start_success(),
-            # "prev_data_interval_end_success": 
get_prev_data_interval_end_success(),
-            # "prev_start_date_success": get_prev_start_date_success(),
-            # "prev_end_date_success": get_prev_end_date_success(),
             # "test_mode": task_instance.test_mode,
             # "triggering_asset_events": 
lazy_object_proxy.Proxy(get_triggering_events),
             "var": {
@@ -134,6 +132,18 @@ class RuntimeTaskInstance(TaskInstance):
                 "ts": ts,
                 "ts_nodash": ts_nodash,
                 "ts_nodash_with_tz": ts_nodash_with_tz,
+                "prev_data_interval_start_success": lazy_object_proxy.Proxy(
+                    lambda: 
get_previous_dagrun_success(self.id).data_interval_start
+                ),
+                "prev_data_interval_end_success": lazy_object_proxy.Proxy(
+                    lambda: 
get_previous_dagrun_success(self.id).data_interval_end
+                ),
+                "prev_start_date_success": lazy_object_proxy.Proxy(
+                    lambda: get_previous_dagrun_success(self.id).start_date
+                ),
+                "prev_end_date_success": lazy_object_proxy.Proxy(
+                    lambda: get_previous_dagrun_success(self.id).end_date
+                ),
             }
             context.update(context_from_server)
 
diff --git a/task_sdk/tests/execution_time/test_supervisor.py 
b/task_sdk/tests/execution_time/test_supervisor.py
index cae34a90adf..f2fcca8a2ab 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -47,8 +47,10 @@ from airflow.sdk.execution_time.comms import (
     GetAssetByName,
     GetAssetByUri,
     GetConnection,
+    GetPrevSuccessfulDagRun,
     GetVariable,
     GetXCom,
+    PrevSuccessfulDagRunResult,
     PutVariable,
     RescheduleTask,
     SetRenderedFields,
@@ -976,6 +978,24 @@ class TestHandleRequest:
                 AssetResult(name="asset", uri="s3://bucket/obj", 
group="asset"),
                 id="get_asset_by_uri",
             ),
+            pytest.param(
+                GetPrevSuccessfulDagRun(ti_id=TI_ID),
+                (
+                    
b'{"data_interval_start":"2025-01-10T12:00:00Z","data_interval_end":"2025-01-10T14:00:00Z",'
+                    
b'"start_date":"2025-01-10T12:00:00Z","end_date":"2025-01-10T14:00:00Z",'
+                    b'"type":"PrevSuccessfulDagRunResult"}\n'
+                ),
+                "task_instances.get_previous_successful_dagrun",
+                (TI_ID,),
+                {},
+                PrevSuccessfulDagRunResult(
+                    start_date=timezone.parse("2025-01-10T12:00:00Z"),
+                    end_date=timezone.parse("2025-01-10T14:00:00Z"),
+                    data_interval_start=timezone.parse("2025-01-10T12:00:00Z"),
+                    data_interval_end=timezone.parse("2025-01-10T14:00:00Z"),
+                ),
+                id="get_prev_successful_dagrun",
+            ),
         ],
     )
     def test_handle_requests(
diff --git a/task_sdk/tests/execution_time/test_task_runner.py 
b/task_sdk/tests/execution_time/test_task_runner.py
index 1a09c05908d..9889c192e90 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -46,6 +46,7 @@ from airflow.sdk.execution_time.comms import (
     GetConnection,
     GetVariable,
     GetXCom,
+    PrevSuccessfulDagRunResult,
     SetRenderedFields,
     StartupDetails,
     TaskState,
@@ -626,7 +627,7 @@ class TestRuntimeTaskInstance:
             "ti": runtime_ti,
         }
 
-    def test_get_context_with_ti_context_from_server(self, create_runtime_ti):
+    def test_get_context_with_ti_context_from_server(self, create_runtime_ti, 
mock_supervisor_comms):
         """Test the context keys are added when sent from API server 
(mocked)"""
         from airflow.utils import timezone
 
@@ -639,6 +640,13 @@ class TestRuntimeTaskInstance:
 
         dr = runtime_ti._ti_context_from_server.dag_run
 
+        mock_supervisor_comms.get_message.return_value = 
PrevSuccessfulDagRunResult(
+            data_interval_end=dr.logical_date - timedelta(hours=1),
+            data_interval_start=dr.logical_date - timedelta(hours=2),
+            start_date=dr.start_date - timedelta(hours=1),
+            end_date=dr.start_date,
+        )
+
         context = runtime_ti.get_template_context()
 
         assert context == {
@@ -653,6 +661,10 @@ class TestRuntimeTaskInstance:
             "map_index_template": task.map_index_template,
             "outlet_events": OutletEventAccessors(),
             "outlets": task.outlets,
+            "prev_data_interval_end_success": timezone.datetime(2024, 12, 1, 
0, 0, 0),
+            "prev_data_interval_start_success": timezone.datetime(2024, 11, 
30, 23, 0, 0),
+            "prev_end_date_success": timezone.datetime(2024, 12, 1, 1, 0, 0),
+            "prev_start_date_success": timezone.datetime(2024, 12, 1, 0, 0, 0),
             "run_id": "test_run",
             "task": task,
             "task_instance": runtime_ti,
@@ -670,6 +682,29 @@ class TestRuntimeTaskInstance:
             "ts_nodash_with_tz": "20241201T010000+0000",
         }
 
+    def test_lazy_loading_not_triggered_until_accessed(self, 
create_runtime_ti, mock_supervisor_comms):
+        """Ensure lazy-loaded attributes are not resolved until accessed."""
+        task = BaseOperator(task_id="hello")
+        runtime_ti = create_runtime_ti(task=task, dag_id="basic_task")
+
+        mock_supervisor_comms.get_message.return_value = 
PrevSuccessfulDagRunResult(
+            data_interval_end=timezone.datetime(2025, 1, 1, 2, 0, 0),
+            data_interval_start=timezone.datetime(2025, 1, 1, 1, 0, 0),
+            start_date=timezone.datetime(2025, 1, 1, 1, 0, 0),
+            end_date=timezone.datetime(2025, 1, 1, 2, 0, 0),
+        )
+
+        context = runtime_ti.get_template_context()
+
+        # Assert lazy attributes are not resolved initially
+        mock_supervisor_comms.get_message.assert_not_called()
+
+        # Access a lazy-loaded attribute to trigger computation
+        assert context["prev_data_interval_start_success"] == 
timezone.datetime(2025, 1, 1, 1, 0, 0)
+
+        # Now the lazy attribute should trigger the call
+        mock_supervisor_comms.get_message.assert_called_once()
+
     def test_get_connection_from_context(self, create_runtime_ti, 
mock_supervisor_comms):
         """Test that the connection is fetched from the API server via the 
Supervisor lazily when accessed"""
 
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py 
b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index e6da0f3a192..c11fbb21bb2 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -785,3 +785,68 @@ class TestTIPutRTIF:
         response = client.put(f"/execution/task-instances/{random_id}/rtif", 
json=payload)
         assert response.status_code == 404
         assert response.json()["detail"] == "Not Found"
+
+
+class TestPreviousDagRun:
+    def setup_method(self):
+        clear_db_runs()
+
+    def teardown_method(self):
+        clear_db_runs()
+
+    def test_ti_previous_dag_run(self, client, session, create_task_instance, 
dag_maker):
+        """Test that the previous dag run is returned correctly for a task 
instance."""
+        ti = create_task_instance(
+            task_id="test_ti_previous_dag_run",
+            dag_id="test_dag",
+            logical_date=timezone.datetime(2025, 1, 19),
+            state=State.RUNNING,
+            start_date=timezone.datetime(2024, 1, 17),
+            session=session,
+        )
+        session.commit()
+
+        # Create 2 DagRuns for the same DAG to verify that the correct DagRun 
(last) is returned
+        dr1 = dag_maker.create_dagrun(
+            run_id="test_run_id_1",
+            logical_date=timezone.datetime(2025, 1, 17),
+            run_type="scheduled",
+            state=State.SUCCESS,
+            session=session,
+        )
+        dr1.end_date = timezone.datetime(2025, 1, 17, 1, 0, 0)
+
+        dr2 = dag_maker.create_dagrun(
+            run_id="test_run_id_2",
+            logical_date=timezone.datetime(2025, 1, 18),
+            run_type="scheduled",
+            state=State.SUCCESS,
+            session=session,
+        )
+
+        dr2.end_date = timezone.datetime(2025, 1, 18, 1, 0, 0)
+
+        session.commit()
+
+        response = 
client.get(f"/execution/task-instances/{ti.id}/previous-successful-dagrun")
+        assert response.status_code == 200
+        assert response.json() == {
+            "data_interval_start": "2025-01-18T00:00:00Z",
+            "data_interval_end": "2025-01-19T00:00:00Z",
+            "start_date": "2024-01-17T00:00:00Z",
+            "end_date": "2025-01-18T01:00:00Z",
+        }
+
+    def test_ti_previous_dag_run_not_found(self, client, session):
+        ti_id = "0182e924-0f1e-77e6-ab50-e977118bc139"
+
+        assert session.get(TaskInstance, ti_id) is None
+
+        response = 
client.get(f"/execution/task-instances/{ti_id}/previous-successful-dagrun")
+        assert response.status_code == 200
+        assert response.json() == {
+            "data_interval_start": None,
+            "data_interval_end": None,
+            "start_date": None,
+            "end_date": None,
+        }

Reply via email to