This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 41b151e7dde AIP-72: Get Previous Successful Dag Run in Task Context
(#45813)
41b151e7dde is described below
commit 41b151e7dde473ec445f9f78fb4b8db826c368fc
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Jan 21 11:14:25 2025 +0530
AIP-72: Get Previous Successful Dag Run in Task Context (#45813)
closes https://github.com/apache/airflow/issues/45814
Adds the following keys to the Task Context:
- prev_data_interval_start_success
- prev_data_interval_end_success
- prev_start_date_success
- prev_end_date_success
---
.../execution_api/datamodels/taskinstance.py | 9 +++
.../execution_api/routes/task_instances.py | 39 ++++++++++++-
task_sdk/src/airflow/sdk/api/client.py | 11 ++++
.../src/airflow/sdk/api/datamodels/_generated.py | 28 +++++-----
task_sdk/src/airflow/sdk/execution_time/comms.py | 46 ++++++++++++---
task_sdk/src/airflow/sdk/execution_time/context.py | 27 ++++++++-
.../src/airflow/sdk/execution_time/supervisor.py | 6 ++
.../src/airflow/sdk/execution_time/task_runner.py | 18 ++++--
task_sdk/tests/execution_time/test_supervisor.py | 20 +++++++
task_sdk/tests/execution_time/test_task_runner.py | 37 +++++++++++-
.../execution_api/routes/test_task_instances.py | 65 ++++++++++++++++++++++
11 files changed, 278 insertions(+), 28 deletions(-)
diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index 563b32a2693..bb6d643fb8e 100644
--- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -199,3 +199,12 @@ class TIRunContext(BaseModel):
connections: Annotated[list[ConnectionResponse],
Field(default_factory=list)]
"""Connections that can be accessed by the task instance."""
+
+
+class PrevSuccessfulDagRunResponse(BaseModel):
+ """Schema for response with previous successful DagRun information for
Task Template Context."""
+
+ data_interval_start: UtcDateTime | None = None
+ data_interval_end: UtcDateTime | None = None
+ start_date: UtcDateTime | None = None
+ end_date: UtcDateTime | None = None
diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow/api_fastapi/execution_api/routes/task_instances.py
index ba6ea0c14b6..4899e93c612 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -31,6 +31,7 @@ from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
DagRun,
+ PrevSuccessfulDagRunResponse,
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
@@ -45,7 +46,7 @@ from airflow.models.taskreschedule import TaskReschedule
from airflow.models.trigger import Trigger
from airflow.models.xcom import XCom
from airflow.utils import timezone
-from airflow.utils.state import State, TerminalTIState
+from airflow.utils.state import DagRunState, State, TerminalTIState
# TODO: Add dependency on JWT token
router = AirflowRouter()
@@ -393,6 +394,42 @@ def ti_put_rtif(
return {"message": "Rendered task instance fields successfully set"}
[email protected](
+ "/{task_instance_id}/previous-successful-dagrun",
+ status_code=status.HTTP_200_OK,
+ responses={
+ status.HTTP_404_NOT_FOUND: {"description": "Task Instance or Dag Run
not found"},
+ },
+)
+def get_previous_successful_dagrun(
+ task_instance_id: UUID, session: SessionDep
+) -> PrevSuccessfulDagRunResponse:
+ """
+ Get the previous successful DagRun for a TaskInstance.
+
+ The data from this endpoint is used to get values for Task Context.
+ """
+ ti_id_str = str(task_instance_id)
+ task_instance = session.scalar(select(TI).where(TI.id == ti_id_str))
+ if not task_instance:
+ return PrevSuccessfulDagRunResponse()
+
+ dag_run = session.scalar(
+ select(DR)
+ .where(
+ DR.dag_id == task_instance.dag_id,
+ DR.logical_date < task_instance.logical_date,
+ DR.state == DagRunState.SUCCESS,
+ )
+ .order_by(DR.logical_date.desc())
+ .limit(1)
+ )
+ if not dag_run:
+ return PrevSuccessfulDagRunResponse()
+
+ return PrevSuccessfulDagRunResponse.model_validate(dag_run)
+
+
def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
"""Is task instance is eligible for retry."""
if state == State.RESTARTING:
diff --git a/task_sdk/src/airflow/sdk/api/client.py
b/task_sdk/src/airflow/sdk/api/client.py
index e73e5aebea6..b984669aa74 100644
--- a/task_sdk/src/airflow/sdk/api/client.py
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -37,6 +37,7 @@ from airflow.sdk.api.datamodels._generated import (
AssetResponse,
ConnectionResponse,
DagRunType,
+ PrevSuccessfulDagRunResponse,
TerminalTIState,
TIDeferredStatePayload,
TIEnterRunningPayload,
@@ -161,6 +162,15 @@ class TaskInstanceOperations:
# decouple from the server response string
return {"ok": True}
+ def get_previous_successful_dagrun(self, id: uuid.UUID) ->
PrevSuccessfulDagRunResponse:
+ """
+ Get the previous successful dag run for a given task instance.
+
+ The data from it is used to get values for Task Context.
+ """
+ resp =
self.client.get(f"task-instances/{id}/previous-successful-dagrun")
+ return PrevSuccessfulDagRunResponse.model_validate_json(resp.read())
+
class ConnectionOperations:
__slots__ = ("client",)
@@ -181,6 +191,7 @@ class ConnectionOperations:
status_code=e.response.status_code,
)
return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND,
detail={"conn_id": conn_id})
+ raise
return ConnectionResponse.model_validate_json(resp.read())
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index f0a04da21c8..7d8bd25e959 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -29,13 +29,15 @@ from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
-class AssetAliasResponse(BaseModel):
+class AssetResponse(BaseModel):
"""
- Asset alias schema with fields that are needed for Runtime.
+ Asset schema for responses with fields that are needed for Runtime.
"""
name: Annotated[str, Field(title="Name")]
+ uri: Annotated[str, Field(title="Uri")]
group: Annotated[str, Field(title="Group")]
+ extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None
class ConnectionResponse(BaseModel):
@@ -78,6 +80,17 @@ class IntermediateTIState(str, Enum):
DEFERRED = "deferred"
+class PrevSuccessfulDagRunResponse(BaseModel):
+ """
+ Schema for response with previous successful DagRun information for Task
Template Context.
+ """
+
+ data_interval_start: Annotated[datetime | None, Field(title="Data Interval
Start")] = None
+ data_interval_end: Annotated[datetime | None, Field(title="Data Interval
End")] = None
+ start_date: Annotated[datetime | None, Field(title="Start Date")] = None
+ end_date: Annotated[datetime | None, Field(title="End Date")] = None
+
+
class TIDeferredStatePayload(BaseModel):
"""
Schema for updating TaskInstance to a deferred state.
@@ -196,17 +209,6 @@ class TaskInstance(BaseModel):
hostname: Annotated[str | None, Field(title="Hostname")] = None
-class AssetResponse(BaseModel):
- """
- Asset schema for responses with fields that are needed for Runtime.
- """
-
- name: Annotated[str, Field(title="Name")]
- uri: Annotated[str, Field(title="Uri")]
- group: Annotated[str, Field(title="Group")]
- extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None
-
-
class DagRun(BaseModel):
"""
Schema for DagRun model with minimal required fields needed for Runtime.
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py
b/task_sdk/src/airflow/sdk/execution_time/comms.py
index f8aaab65af4..007e3fe10fe 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -45,6 +45,7 @@ from __future__ import annotations
from datetime import datetime
from typing import Annotated, Literal, Union
+from uuid import UUID
from fastapi import Body
from pydantic import BaseModel, ConfigDict, Field, JsonValue
@@ -53,6 +54,7 @@ from airflow.sdk.api.datamodels._generated import (
AssetResponse,
BundleInfo,
ConnectionResponse,
+ PrevSuccessfulDagRunResponse,
TaskInstance,
TerminalTIState,
TIDeferredStatePayload,
@@ -146,6 +148,20 @@ class VariableResult(VariableResponse):
return cls(**variable_response.model_dump(exclude_defaults=True),
type="VariableResult")
+class PrevSuccessfulDagRunResult(PrevSuccessfulDagRunResponse):
+ type: Literal["PrevSuccessfulDagRunResult"] = "PrevSuccessfulDagRunResult"
+
+ @classmethod
+ def from_dagrun_response(cls, prev_dag_run: PrevSuccessfulDagRunResponse)
-> PrevSuccessfulDagRunResult:
+ """
+ Get a result object from response object.
+
+ PrevSuccessfulDagRunResponse is autogenerated from the API schema, so
we need to convert it to
+ PrevSuccessfulDagRunResult for communication between the Supervisor
and the task process.
+ """
+ return cls(**prev_dag_run.model_dump(exclude_defaults=True),
type="PrevSuccessfulDagRunResult")
+
+
class ErrorResponse(BaseModel):
error: ErrorType = ErrorType.GENERIC_ERROR
detail: dict | None = None
@@ -153,7 +169,15 @@ class ErrorResponse(BaseModel):
ToTask = Annotated[
- Union[StartupDetails, XComResult, ConnectionResult, VariableResult,
ErrorResponse, AssetResult],
+ Union[
+ AssetResult,
+ ConnectionResult,
+ ErrorResponse,
+ PrevSuccessfulDagRunResult,
+ StartupDetails,
+ VariableResult,
+ XComResult,
+ ],
Field(discriminator="type"),
]
@@ -261,19 +285,25 @@ class GetAssetByUri(BaseModel):
type: Literal["GetAssetByUri"] = "GetAssetByUri"
+class GetPrevSuccessfulDagRun(BaseModel):
+ ti_id: UUID
+ type: Literal["GetPrevSuccessfulDagRun"] = "GetPrevSuccessfulDagRun"
+
+
ToSupervisor = Annotated[
Union[
- TaskState,
- GetXCom,
- GetConnection,
- GetVariable,
+ DeferTask,
GetAssetByName,
GetAssetByUri,
- DeferTask,
+ GetConnection,
+ GetPrevSuccessfulDagRun,
+ GetVariable,
+ GetXCom,
PutVariable,
- SetXCom,
- SetRenderedFields,
RescheduleTask,
+ SetRenderedFields,
+ SetXCom,
+ TaskState,
],
Field(discriminator="type"),
]
diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py
b/task_sdk/src/airflow/sdk/execution_time/context.py
index a068b53aec7..984919ea1c8 100644
--- a/task_sdk/src/airflow/sdk/execution_time/context.py
+++ b/task_sdk/src/airflow/sdk/execution_time/context.py
@@ -18,6 +18,7 @@ from __future__ import annotations
import contextlib
from collections.abc import Generator, Iterator, Mapping
+from functools import cache
from typing import TYPE_CHECKING, Any, Union
import attrs
@@ -39,10 +40,17 @@ from airflow.sdk.definitions.asset import (
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
if TYPE_CHECKING:
+ from uuid import UUID
+
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.variable import Variable
- from airflow.sdk.execution_time.comms import AssetResult,
ConnectionResult, VariableResult
+ from airflow.sdk.execution_time.comms import (
+ AssetResult,
+ ConnectionResult,
+ PrevSuccessfulDagRunResponse,
+ VariableResult,
+ )
log = structlog.get_logger(logger_name="task")
@@ -272,6 +280,23 @@ class OutletEventAccessors(Mapping[Union[Asset,
AssetAlias], OutletEventAccessor
return Asset(**msg.model_dump(exclude={"type"}))
+@cache # Prevent multiple API access.
+def get_previous_dagrun_success(ti_id: UUID) -> PrevSuccessfulDagRunResponse:
+ from airflow.sdk.execution_time.comms import (
+ GetPrevSuccessfulDagRun,
+ PrevSuccessfulDagRunResponse,
+ PrevSuccessfulDagRunResult,
+ )
+ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+ SUPERVISOR_COMMS.send_request(log=log,
msg=GetPrevSuccessfulDagRun(ti_id=ti_id))
+ msg = SUPERVISOR_COMMS.get_message()
+
+ if TYPE_CHECKING:
+ assert isinstance(msg, PrevSuccessfulDagRunResult)
+ return PrevSuccessfulDagRunResponse(**msg.model_dump(exclude={"type"}))
+
+
@contextlib.contextmanager
def set_current_context(context: Context) -> Generator[Context, None, None]:
"""
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index bd50ee5126b..45da306722f 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -67,8 +67,10 @@ from airflow.sdk.execution_time.comms import (
GetAssetByName,
GetAssetByUri,
GetConnection,
+ GetPrevSuccessfulDagRun,
GetVariable,
GetXCom,
+ PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
SetRenderedFields,
@@ -798,6 +800,10 @@ class ActivitySubprocess(WatchedSubprocess):
asset_resp = self.client.assets.get(uri=msg.uri)
asset_result = AssetResult.from_asset_response(asset_resp)
resp = asset_result.model_dump_json(exclude_unset=True).encode()
+ elif isinstance(msg, GetPrevSuccessfulDagRun):
+ dagrun_resp =
self.client.task_instances.get_previous_successful_dagrun(self.id)
+ dagrun_result =
PrevSuccessfulDagRunResult.from_dagrun_response(dagrun_resp)
+ resp = dagrun_result.model_dump_json(exclude_unset=True).encode()
else:
log.error("Unhandled request", msg=msg)
return
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index d4816c8ae59..230f5414a87 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -28,6 +28,7 @@ from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar
import attrs
+import lazy_object_proxy
import structlog
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter
@@ -52,6 +53,7 @@ from airflow.sdk.execution_time.context import (
MacrosAccessor,
OutletEventAccessors,
VariableAccessor,
+ get_previous_dagrun_success,
set_current_context,
)
from airflow.utils.net import get_hostname
@@ -100,10 +102,6 @@ class RuntimeTaskInstance(TaskInstance):
"macros": MacrosAccessor(),
# "params": validated_params,
# TODO: Make this go through Public API longer term.
- # "prev_data_interval_start_success":
get_prev_data_interval_start_success(),
- # "prev_data_interval_end_success":
get_prev_data_interval_end_success(),
- # "prev_start_date_success": get_prev_start_date_success(),
- # "prev_end_date_success": get_prev_end_date_success(),
# "test_mode": task_instance.test_mode,
# "triggering_asset_events":
lazy_object_proxy.Proxy(get_triggering_events),
"var": {
@@ -134,6 +132,18 @@ class RuntimeTaskInstance(TaskInstance):
"ts": ts,
"ts_nodash": ts_nodash,
"ts_nodash_with_tz": ts_nodash_with_tz,
+ "prev_data_interval_start_success": lazy_object_proxy.Proxy(
+ lambda:
get_previous_dagrun_success(self.id).data_interval_start
+ ),
+ "prev_data_interval_end_success": lazy_object_proxy.Proxy(
+ lambda:
get_previous_dagrun_success(self.id).data_interval_end
+ ),
+ "prev_start_date_success": lazy_object_proxy.Proxy(
+ lambda: get_previous_dagrun_success(self.id).start_date
+ ),
+ "prev_end_date_success": lazy_object_proxy.Proxy(
+ lambda: get_previous_dagrun_success(self.id).end_date
+ ),
}
context.update(context_from_server)
diff --git a/task_sdk/tests/execution_time/test_supervisor.py
b/task_sdk/tests/execution_time/test_supervisor.py
index cae34a90adf..f2fcca8a2ab 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -47,8 +47,10 @@ from airflow.sdk.execution_time.comms import (
GetAssetByName,
GetAssetByUri,
GetConnection,
+ GetPrevSuccessfulDagRun,
GetVariable,
GetXCom,
+ PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
SetRenderedFields,
@@ -976,6 +978,24 @@ class TestHandleRequest:
AssetResult(name="asset", uri="s3://bucket/obj",
group="asset"),
id="get_asset_by_uri",
),
+ pytest.param(
+ GetPrevSuccessfulDagRun(ti_id=TI_ID),
+ (
+
b'{"data_interval_start":"2025-01-10T12:00:00Z","data_interval_end":"2025-01-10T14:00:00Z",'
+
b'"start_date":"2025-01-10T12:00:00Z","end_date":"2025-01-10T14:00:00Z",'
+ b'"type":"PrevSuccessfulDagRunResult"}\n'
+ ),
+ "task_instances.get_previous_successful_dagrun",
+ (TI_ID,),
+ {},
+ PrevSuccessfulDagRunResult(
+ start_date=timezone.parse("2025-01-10T12:00:00Z"),
+ end_date=timezone.parse("2025-01-10T14:00:00Z"),
+ data_interval_start=timezone.parse("2025-01-10T12:00:00Z"),
+ data_interval_end=timezone.parse("2025-01-10T14:00:00Z"),
+ ),
+ id="get_prev_successful_dagrun",
+ ),
],
)
def test_handle_requests(
diff --git a/task_sdk/tests/execution_time/test_task_runner.py
b/task_sdk/tests/execution_time/test_task_runner.py
index 1a09c05908d..9889c192e90 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -46,6 +46,7 @@ from airflow.sdk.execution_time.comms import (
GetConnection,
GetVariable,
GetXCom,
+ PrevSuccessfulDagRunResult,
SetRenderedFields,
StartupDetails,
TaskState,
@@ -626,7 +627,7 @@ class TestRuntimeTaskInstance:
"ti": runtime_ti,
}
- def test_get_context_with_ti_context_from_server(self, create_runtime_ti):
+ def test_get_context_with_ti_context_from_server(self, create_runtime_ti,
mock_supervisor_comms):
"""Test the context keys are added when sent from API server
(mocked)"""
from airflow.utils import timezone
@@ -639,6 +640,13 @@ class TestRuntimeTaskInstance:
dr = runtime_ti._ti_context_from_server.dag_run
+ mock_supervisor_comms.get_message.return_value =
PrevSuccessfulDagRunResult(
+ data_interval_end=dr.logical_date - timedelta(hours=1),
+ data_interval_start=dr.logical_date - timedelta(hours=2),
+ start_date=dr.start_date - timedelta(hours=1),
+ end_date=dr.start_date,
+ )
+
context = runtime_ti.get_template_context()
assert context == {
@@ -653,6 +661,10 @@ class TestRuntimeTaskInstance:
"map_index_template": task.map_index_template,
"outlet_events": OutletEventAccessors(),
"outlets": task.outlets,
+ "prev_data_interval_end_success": timezone.datetime(2024, 12, 1,
0, 0, 0),
+ "prev_data_interval_start_success": timezone.datetime(2024, 11,
30, 23, 0, 0),
+ "prev_end_date_success": timezone.datetime(2024, 12, 1, 1, 0, 0),
+ "prev_start_date_success": timezone.datetime(2024, 12, 1, 0, 0, 0),
"run_id": "test_run",
"task": task,
"task_instance": runtime_ti,
@@ -670,6 +682,29 @@ class TestRuntimeTaskInstance:
"ts_nodash_with_tz": "20241201T010000+0000",
}
+ def test_lazy_loading_not_triggered_until_accessed(self,
create_runtime_ti, mock_supervisor_comms):
+ """Ensure lazy-loaded attributes are not resolved until accessed."""
+ task = BaseOperator(task_id="hello")
+ runtime_ti = create_runtime_ti(task=task, dag_id="basic_task")
+
+ mock_supervisor_comms.get_message.return_value =
PrevSuccessfulDagRunResult(
+ data_interval_end=timezone.datetime(2025, 1, 1, 2, 0, 0),
+ data_interval_start=timezone.datetime(2025, 1, 1, 1, 0, 0),
+ start_date=timezone.datetime(2025, 1, 1, 1, 0, 0),
+ end_date=timezone.datetime(2025, 1, 1, 2, 0, 0),
+ )
+
+ context = runtime_ti.get_template_context()
+
+ # Assert lazy attributes are not resolved initially
+ mock_supervisor_comms.get_message.assert_not_called()
+
+ # Access a lazy-loaded attribute to trigger computation
+ assert context["prev_data_interval_start_success"] ==
timezone.datetime(2025, 1, 1, 1, 0, 0)
+
+ # Now the lazy attribute should trigger the call
+ mock_supervisor_comms.get_message.assert_called_once()
+
def test_get_connection_from_context(self, create_runtime_ti,
mock_supervisor_comms):
"""Test that the connection is fetched from the API server via the
Supervisor lazily when accessed"""
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py
b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index e6da0f3a192..c11fbb21bb2 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -785,3 +785,68 @@ class TestTIPutRTIF:
response = client.put(f"/execution/task-instances/{random_id}/rtif",
json=payload)
assert response.status_code == 404
assert response.json()["detail"] == "Not Found"
+
+
+class TestPreviousDagRun:
+ def setup_method(self):
+ clear_db_runs()
+
+ def teardown_method(self):
+ clear_db_runs()
+
+ def test_ti_previous_dag_run(self, client, session, create_task_instance,
dag_maker):
+ """Test that the previous dag run is returned correctly for a task
instance."""
+ ti = create_task_instance(
+ task_id="test_ti_previous_dag_run",
+ dag_id="test_dag",
+ logical_date=timezone.datetime(2025, 1, 19),
+ state=State.RUNNING,
+ start_date=timezone.datetime(2024, 1, 17),
+ session=session,
+ )
+ session.commit()
+
+ # Create 2 DagRuns for the same DAG to verify that the correct DagRun
(last) is returned
+ dr1 = dag_maker.create_dagrun(
+ run_id="test_run_id_1",
+ logical_date=timezone.datetime(2025, 1, 17),
+ run_type="scheduled",
+ state=State.SUCCESS,
+ session=session,
+ )
+ dr1.end_date = timezone.datetime(2025, 1, 17, 1, 0, 0)
+
+ dr2 = dag_maker.create_dagrun(
+ run_id="test_run_id_2",
+ logical_date=timezone.datetime(2025, 1, 18),
+ run_type="scheduled",
+ state=State.SUCCESS,
+ session=session,
+ )
+
+ dr2.end_date = timezone.datetime(2025, 1, 18, 1, 0, 0)
+
+ session.commit()
+
+ response =
client.get(f"/execution/task-instances/{ti.id}/previous-successful-dagrun")
+ assert response.status_code == 200
+ assert response.json() == {
+ "data_interval_start": "2025-01-18T00:00:00Z",
+ "data_interval_end": "2025-01-19T00:00:00Z",
+ "start_date": "2024-01-17T00:00:00Z",
+ "end_date": "2025-01-18T01:00:00Z",
+ }
+
+ def test_ti_previous_dag_run_not_found(self, client, session):
+ ti_id = "0182e924-0f1e-77e6-ab50-e977118bc139"
+
+ assert session.get(TaskInstance, ti_id) is None
+
+ response =
client.get(f"/execution/task-instances/{ti_id}/previous-successful-dagrun")
+ assert response.status_code == 200
+ assert response.json() == {
+ "data_interval_start": None,
+ "data_interval_end": None,
+ "start_date": None,
+ "end_date": None,
+ }