This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v3-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v3-0-test by this push:
     new b6a11e70f33 [v3-0-test] Ensure TaskInstance.end_date and duration are populated before invoking failure callbacks (#52729) (#54458)
b6a11e70f33 is described below

commit b6a11e70f33973c3c792d3c36c160338dadc7de6
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Wed Aug 13 13:51:47 2025 +0100

    [v3-0-test] Ensure TaskInstance.end_date and duration are populated before invoking failure callbacks (#52729) (#54458)
    
    * Fix: Set TaskInstance end_date before failure callbacks (#52630)
    
    - Set ti.end_date before creating TaskState in exception handlers
    - Ensures on_failure_callback context has proper end_date and duration
    - Fixes a bug where callbacks received a TaskInstance whose end_date (and hence duration) had not been set yet
    - Add test to verify callback context contains required timing information
    
    * Enhance task callbacks to include end_date and duration metrics
    (cherry picked from commit eda89ae67b2be02fe519f23451a9d2916c1a7ebc)
    
    Co-authored-by: Ranuga Disansa <[email protected]>
---
 .../src/airflow/sdk/execution_time/task_runner.py  |  20 +++-
 .../task_sdk/execution_time/test_task_runner.py    | 127 +++++++++++++++++++++
 2 files changed, 141 insertions(+), 6 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index d4d09c8e20c..10043b5baad 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -957,9 +957,10 @@ def run(
         # If AirflowFailException is raised, task should not retry.
         # If a sensor in reschedule mode reaches timeout, task should not 
retry.
         log.exception("Task failed with exception")
+        ti.end_date = datetime.now(tz=timezone.utc)
         msg = TaskState(
             state=TaskInstanceState.FAILED,
-            end_date=datetime.now(tz=timezone.utc),
+            end_date=ti.end_date,
             rendered_map_index=ti.rendered_map_index,
         )
         state = TaskInstanceState.FAILED
@@ -974,9 +975,10 @@ def run(
         # updated already be another UI API. So, these exceptions should 
ideally never be thrown.
         # If these are thrown, we should mark the TI state as failed.
         log.exception("Task failed with exception")
+        ti.end_date = datetime.now(tz=timezone.utc)
         msg = TaskState(
             state=TaskInstanceState.FAILED,
-            end_date=datetime.now(tz=timezone.utc),
+            end_date=ti.end_date,
             rendered_map_index=ti.rendered_map_index,
         )
         state = TaskInstanceState.FAILED
@@ -1003,10 +1005,12 @@ def _handle_current_task_success(
     context: Context,
     ti: RuntimeTaskInstance,
 ) -> tuple[SucceedTask, TaskInstanceState]:
+    end_date = datetime.now(tz=timezone.utc)
+    ti.end_date = end_date
     task_outlets = list(_build_asset_profiles(ti.task.outlets))
     outlet_events = list(_serialize_outlet_events(context["outlet_events"]))
     msg = SucceedTask(
-        end_date=datetime.now(tz=timezone.utc),
+        end_date=end_date,
         task_outlets=task_outlets,
         outlet_events=outlet_events,
         rendered_map_index=ti.rendered_map_index,
@@ -1018,11 +1022,15 @@ def _handle_current_task_failed(
     ti: RuntimeTaskInstance,
 ) -> tuple[RetryTask, TaskInstanceState] | tuple[TaskState, TaskInstanceState]:
     end_date = datetime.now(tz=timezone.utc)
+    ti.end_date = end_date
     if ti._ti_context_from_server and ti._ti_context_from_server.should_retry:
         return RetryTask(end_date=end_date), TaskInstanceState.UP_FOR_RETRY
-    return TaskState(
-        state=TaskInstanceState.FAILED, end_date=end_date, 
rendered_map_index=ti.rendered_map_index
-    ), TaskInstanceState.FAILED
+    return (
+        TaskState(
+            state=TaskInstanceState.FAILED, end_date=end_date, 
rendered_map_index=ti.rendered_map_index
+        ),
+        TaskInstanceState.FAILED,
+    )
 
 
 def _handle_trigger_dag_run(
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py 
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 2556845dac8..fc64d02fbc7 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -2566,6 +2566,133 @@ class TestTaskRunnerCallsCallbacks:
         assert state == expected_state
         assert collected_results == expected_results
 
+    def test_task_runner_on_failure_callback_context(self, create_runtime_ti):
+        """Test that on_failure_callback context has end_date and duration."""
+        from airflow.exceptions import AirflowException
+
+        def failure_callback(context):
+            ti = context["task_instance"]
+            assert isinstance(ti.end_date, datetime)
+            duration = (ti.end_date - ti.start_date).total_seconds()
+            assert duration is not None
+            assert duration >= 0
+
+        class FailingOperator(BaseOperator):
+            def execute(self, context):
+                raise AirflowException("Failing task")
+
+        task = FailingOperator(task_id="failing_task", 
on_failure_callback=failure_callback)
+        runtime_ti = create_runtime_ti(dag_id="dag", task=task)
+        log = mock.MagicMock()
+        context = runtime_ti.get_template_context()
+        state, _, error = run(runtime_ti, context, log)
+        finalize(runtime_ti, state, context, log, error)
+
+        assert state == TaskInstanceState.FAILED
+
+    def test_task_runner_on_success_callback_context(self, create_runtime_ti):
+        """Test that on_success_callback context has end_date and duration."""
+        callback_data = {}  # Store callback data for inspection
+
+        def success_callback(context):
+            ti = context["task_instance"]
+            callback_data["end_date"] = ti.end_date
+            callback_data["duration"] = (ti.end_date - 
ti.start_date).total_seconds() if ti.end_date else None
+            callback_data["start_date"] = ti.start_date
+
+        class SuccessOperator(BaseOperator):
+            def execute(self, context):
+                return "success"
+
+        task = SuccessOperator(task_id="success_task", 
on_success_callback=success_callback)
+        runtime_ti = create_runtime_ti(dag_id="dag", task=task)
+        log = mock.MagicMock()
+        context = runtime_ti.get_template_context()
+
+        state, _, error = run(runtime_ti, context, log)
+        finalize(runtime_ti, state, context, log, error)
+
+        assert state == TaskInstanceState.SUCCESS
+
+        # Verify callback was called and data was captured
+        assert "end_date" in callback_data, "Success callback should have been 
called"
+        assert isinstance(callback_data["end_date"], datetime), (
+            f"end_date should be datetime, got 
{type(callback_data['end_date'])}"
+        )
+        assert callback_data["duration"] is not None, (
+            f"duration should not be None, got {callback_data['duration']}"
+        )
+        assert callback_data["duration"] >= 0, f"duration should be >= 0, got 
{callback_data['duration']}"
+
+    def test_task_runner_both_callbacks_have_timing_info(self, 
create_runtime_ti):
+        """Test that both success and failure callbacks receive accurate 
timing information."""
+        import time
+
+        from airflow.exceptions import AirflowException
+
+        success_data = {}
+        failure_data = {}
+
+        def success_callback(context):
+            ti = context["task_instance"]
+            success_data["end_date"] = ti.end_date
+            success_data["start_date"] = ti.start_date
+            success_data["duration"] = (ti.end_date - 
ti.start_date).total_seconds() if ti.end_date else None
+
+        def failure_callback(context):
+            ti = context["task_instance"]
+            failure_data["end_date"] = ti.end_date
+            failure_data["start_date"] = ti.start_date
+            failure_data["duration"] = (ti.end_date - 
ti.start_date).total_seconds() if ti.end_date else None
+
+        # Test success callback
+        class SuccessOperator(BaseOperator):
+            def execute(self, context):
+                time.sleep(0.01)  # Add small delay to ensure measurable 
duration
+                return "success"
+
+        success_task = SuccessOperator(task_id="success_task", 
on_success_callback=success_callback)
+        success_runtime_ti = create_runtime_ti(dag_id="dag", task=success_task)
+        success_log = mock.MagicMock()
+        success_context = success_runtime_ti.get_template_context()
+
+        success_state, _, success_error = run(success_runtime_ti, 
success_context, success_log)
+        finalize(success_runtime_ti, success_state, success_context, 
success_log, success_error)
+
+        # Test failure callback
+        class FailureOperator(BaseOperator):
+            def execute(self, context):
+                time.sleep(0.01)  # Add small delay to ensure measurable 
duration
+                raise AirflowException("Test failure")
+
+        failure_task = FailureOperator(task_id="failure_task", 
on_failure_callback=failure_callback)
+        failure_runtime_ti = create_runtime_ti(dag_id="dag", task=failure_task)
+        failure_log = mock.MagicMock()
+        failure_context = failure_runtime_ti.get_template_context()
+
+        failure_state, _, failure_error = run(failure_runtime_ti, 
failure_context, failure_log)
+        finalize(failure_runtime_ti, failure_state, failure_context, 
failure_log, failure_error)
+
+        # Assertions for success callback
+        assert success_state == TaskInstanceState.SUCCESS
+        assert "end_date" in success_data, "Success callback should have been 
called"
+        assert isinstance(success_data["end_date"], datetime)
+        assert isinstance(success_data["start_date"], datetime)
+        assert success_data["duration"] is not None
+        assert success_data["duration"] >= 0.01, (
+            f"Success duration should be >= 0.01, got 
{success_data['duration']}"
+        )
+
+        # Assertions for failure callback
+        assert failure_state == TaskInstanceState.FAILED
+        assert "end_date" in failure_data, "Failure callback should have been 
called"
+        assert isinstance(failure_data["end_date"], datetime)
+        assert isinstance(failure_data["start_date"], datetime)
+        assert failure_data["duration"] is not None
+        assert failure_data["duration"] >= 0.01, (
+            f"Failure duration should be >= 0.01, got 
{failure_data['duration']}"
+        )
+
     @pytest.mark.parametrize(
         "callback_to_test, execute_impl, should_retry, expected_state, 
expected_results, extra_exceptions",
         [

Reply via email to