This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v3-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v3-0-test by this push:
new b6a11e70f33 [v3-0-test] Ensure TaskInstance.end_date and duration are
populated before invoking failure callbacks (#52729) (#54458)
b6a11e70f33 is described below
commit b6a11e70f33973c3c792d3c36c160338dadc7de6
Author: github-actions[bot]
<41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Wed Aug 13 13:51:47 2025 +0100
[v3-0-test] Ensure TaskInstance.end_date and duration are populated before
invoking failure callbacks (#52729) (#54458)
* Fix: Set TaskInstance end_date before failure callbacks (#52630)
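- Set ti.end_date before creating TaskState in exception handlers, and before
building the result messages in _handle_current_task_success and
_handle_current_task_failed
- Ensures the task instance in the on_failure_callback context has end_date
set, so its duration can be computed from start_date
- Fixes an ordering issue where callbacks received a TaskInstance whose
end_date had not been set yet
- Add tests to verify callback contexts contain the required timing information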
* Enhance task callbacks to include end_date and duration metrics
(cherry picked from commit eda89ae67b2be02fe519f23451a9d2916c1a7ebc)
Co-authored-by: Ranuga Disansa
<[email protected]>
---
.../src/airflow/sdk/execution_time/task_runner.py | 20 +++-
.../task_sdk/execution_time/test_task_runner.py | 127 +++++++++++++++++++++
2 files changed, 141 insertions(+), 6 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index d4d09c8e20c..10043b5baad 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -957,9 +957,10 @@ def run(
# If AirflowFailException is raised, task should not retry.
# If a sensor in reschedule mode reaches timeout, task should not
retry.
log.exception("Task failed with exception")
+ ti.end_date = datetime.now(tz=timezone.utc)
msg = TaskState(
state=TaskInstanceState.FAILED,
- end_date=datetime.now(tz=timezone.utc),
+ end_date=ti.end_date,
rendered_map_index=ti.rendered_map_index,
)
state = TaskInstanceState.FAILED
@@ -974,9 +975,10 @@ def run(
# updated already be another UI API. So, these exceptions should
ideally never be thrown.
# If these are thrown, we should mark the TI state as failed.
log.exception("Task failed with exception")
+ ti.end_date = datetime.now(tz=timezone.utc)
msg = TaskState(
state=TaskInstanceState.FAILED,
- end_date=datetime.now(tz=timezone.utc),
+ end_date=ti.end_date,
rendered_map_index=ti.rendered_map_index,
)
state = TaskInstanceState.FAILED
@@ -1003,10 +1005,12 @@ def _handle_current_task_success(
context: Context,
ti: RuntimeTaskInstance,
) -> tuple[SucceedTask, TaskInstanceState]:
+ end_date = datetime.now(tz=timezone.utc)
+ ti.end_date = end_date
task_outlets = list(_build_asset_profiles(ti.task.outlets))
outlet_events = list(_serialize_outlet_events(context["outlet_events"]))
msg = SucceedTask(
- end_date=datetime.now(tz=timezone.utc),
+ end_date=end_date,
task_outlets=task_outlets,
outlet_events=outlet_events,
rendered_map_index=ti.rendered_map_index,
@@ -1018,11 +1022,15 @@ def _handle_current_task_failed(
ti: RuntimeTaskInstance,
) -> tuple[RetryTask, TaskInstanceState] | tuple[TaskState, TaskInstanceState]:
end_date = datetime.now(tz=timezone.utc)
+ ti.end_date = end_date
if ti._ti_context_from_server and ti._ti_context_from_server.should_retry:
return RetryTask(end_date=end_date), TaskInstanceState.UP_FOR_RETRY
- return TaskState(
- state=TaskInstanceState.FAILED, end_date=end_date,
rendered_map_index=ti.rendered_map_index
- ), TaskInstanceState.FAILED
+ return (
+ TaskState(
+ state=TaskInstanceState.FAILED, end_date=end_date,
rendered_map_index=ti.rendered_map_index
+ ),
+ TaskInstanceState.FAILED,
+ )
def _handle_trigger_dag_run(
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 2556845dac8..fc64d02fbc7 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -2566,6 +2566,133 @@ class TestTaskRunnerCallsCallbacks:
assert state == expected_state
assert collected_results == expected_results
+ def test_task_runner_on_failure_callback_context(self, create_runtime_ti):
+ """Test that on_failure_callback context has end_date and duration."""
+ from airflow.exceptions import AirflowException
+
+ def failure_callback(context):
+ ti = context["task_instance"]
+ assert isinstance(ti.end_date, datetime)
+ duration = (ti.end_date - ti.start_date).total_seconds()
+ assert duration is not None
+ assert duration >= 0
+
+ class FailingOperator(BaseOperator):
+ def execute(self, context):
+ raise AirflowException("Failing task")
+
+ task = FailingOperator(task_id="failing_task",
on_failure_callback=failure_callback)
+ runtime_ti = create_runtime_ti(dag_id="dag", task=task)
+ log = mock.MagicMock()
+ context = runtime_ti.get_template_context()
+ state, _, error = run(runtime_ti, context, log)
+ finalize(runtime_ti, state, context, log, error)
+
+ assert state == TaskInstanceState.FAILED
+
+ def test_task_runner_on_success_callback_context(self, create_runtime_ti):
+ """Test that on_success_callback context has end_date and duration."""
+ callback_data = {} # Store callback data for inspection
+
+ def success_callback(context):
+ ti = context["task_instance"]
+ callback_data["end_date"] = ti.end_date
+ callback_data["duration"] = (ti.end_date -
ti.start_date).total_seconds() if ti.end_date else None
+ callback_data["start_date"] = ti.start_date
+
+ class SuccessOperator(BaseOperator):
+ def execute(self, context):
+ return "success"
+
+ task = SuccessOperator(task_id="success_task",
on_success_callback=success_callback)
+ runtime_ti = create_runtime_ti(dag_id="dag", task=task)
+ log = mock.MagicMock()
+ context = runtime_ti.get_template_context()
+
+ state, _, error = run(runtime_ti, context, log)
+ finalize(runtime_ti, state, context, log, error)
+
+ assert state == TaskInstanceState.SUCCESS
+
+ # Verify callback was called and data was captured
+ assert "end_date" in callback_data, "Success callback should have been
called"
+ assert isinstance(callback_data["end_date"], datetime), (
+ f"end_date should be datetime, got
{type(callback_data['end_date'])}"
+ )
+ assert callback_data["duration"] is not None, (
+ f"duration should not be None, got {callback_data['duration']}"
+ )
+ assert callback_data["duration"] >= 0, f"duration should be >= 0, got
{callback_data['duration']}"
+
+ def test_task_runner_both_callbacks_have_timing_info(self,
create_runtime_ti):
+ """Test that both success and failure callbacks receive accurate
timing information."""
+ import time
+
+ from airflow.exceptions import AirflowException
+
+ success_data = {}
+ failure_data = {}
+
+ def success_callback(context):
+ ti = context["task_instance"]
+ success_data["end_date"] = ti.end_date
+ success_data["start_date"] = ti.start_date
+ success_data["duration"] = (ti.end_date -
ti.start_date).total_seconds() if ti.end_date else None
+
+ def failure_callback(context):
+ ti = context["task_instance"]
+ failure_data["end_date"] = ti.end_date
+ failure_data["start_date"] = ti.start_date
+ failure_data["duration"] = (ti.end_date -
ti.start_date).total_seconds() if ti.end_date else None
+
+ # Test success callback
+ class SuccessOperator(BaseOperator):
+ def execute(self, context):
+ time.sleep(0.01) # Add small delay to ensure measurable
duration
+ return "success"
+
+ success_task = SuccessOperator(task_id="success_task",
on_success_callback=success_callback)
+ success_runtime_ti = create_runtime_ti(dag_id="dag", task=success_task)
+ success_log = mock.MagicMock()
+ success_context = success_runtime_ti.get_template_context()
+
+ success_state, _, success_error = run(success_runtime_ti,
success_context, success_log)
+ finalize(success_runtime_ti, success_state, success_context,
success_log, success_error)
+
+ # Test failure callback
+ class FailureOperator(BaseOperator):
+ def execute(self, context):
+ time.sleep(0.01) # Add small delay to ensure measurable
duration
+ raise AirflowException("Test failure")
+
+ failure_task = FailureOperator(task_id="failure_task",
on_failure_callback=failure_callback)
+ failure_runtime_ti = create_runtime_ti(dag_id="dag", task=failure_task)
+ failure_log = mock.MagicMock()
+ failure_context = failure_runtime_ti.get_template_context()
+
+ failure_state, _, failure_error = run(failure_runtime_ti,
failure_context, failure_log)
+ finalize(failure_runtime_ti, failure_state, failure_context,
failure_log, failure_error)
+
+ # Assertions for success callback
+ assert success_state == TaskInstanceState.SUCCESS
+ assert "end_date" in success_data, "Success callback should have been
called"
+ assert isinstance(success_data["end_date"], datetime)
+ assert isinstance(success_data["start_date"], datetime)
+ assert success_data["duration"] is not None
+ assert success_data["duration"] >= 0.01, (
+ f"Success duration should be >= 0.01, got
{success_data['duration']}"
+ )
+
+ # Assertions for failure callback
+ assert failure_state == TaskInstanceState.FAILED
+ assert "end_date" in failure_data, "Failure callback should have been
called"
+ assert isinstance(failure_data["end_date"], datetime)
+ assert isinstance(failure_data["start_date"], datetime)
+ assert failure_data["duration"] is not None
+ assert failure_data["duration"] >= 0.01, (
+ f"Failure duration should be >= 0.01, got
{failure_data['duration']}"
+ )
+
@pytest.mark.parametrize(
"callback_to_test, execute_impl, should_retry, expected_state,
expected_results, extra_exceptions",
[