This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch fix-error-dags
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 2bf994ace3a0b6f3814ef31d85cf8b6d271e50f4
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Thu Feb 20 22:21:47 2025 +0100

    Pass the error to on_task_instance_failed in the Task SDK
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 .../providers/openlineage/plugins/listener.py      | 27 ++++++++++++++----
 .../src/airflow/sdk/execution_time/task_runner.py  | 32 ++++++++++++++--------
 task_sdk/tests/execution_time/test_task_runner.py  | 15 ++++++----
 3 files changed, 50 insertions(+), 24 deletions(-)

diff --git 
a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py 
b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
index c49880237b8..6954d1cc2a8 100644
--- 
a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
+++ 
b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
@@ -178,9 +178,13 @@ class OpenLineageListener:
                 self.log.debug("Skipping this instance of rescheduled task - 
START event was emitted already")
                 return
 
+            date = dagrun.logical_date
+            if AIRFLOW_V_3_0_PLUS and date is None:
+                date = dagrun.run_after
+
             parent_run_id = self.adapter.build_dag_run_id(
                 dag_id=dag.dag_id,
-                logical_date=dagrun.logical_date,
+                logical_date=date,
                 clear_number=clear_number,
             )
 
@@ -188,7 +192,7 @@ class OpenLineageListener:
                 dag_id=dag.dag_id,
                 task_id=task.task_id,
                 try_number=task_instance.try_number,
-                logical_date=dagrun.logical_date,
+                logical_date=date,
                 map_index=task_instance.map_index,
             )
             event_type = RunState.RUNNING.value.lower()
@@ -276,9 +280,13 @@ class OpenLineageListener:
 
         @print_warning(self.log)
         def on_success():
+            date = dagrun.logical_date
+            if AIRFLOW_V_3_0_PLUS and date is None:
+                date = dagrun.run_after
+
             parent_run_id = self.adapter.build_dag_run_id(
                 dag_id=dag.dag_id,
-                logical_date=dagrun.logical_date,
+                logical_date=date,
                 clear_number=dagrun.clear_number,
             )
 
@@ -286,7 +294,7 @@ class OpenLineageListener:
                 dag_id=dag.dag_id,
                 task_id=task.task_id,
                 try_number=_get_try_number_success(task_instance),
-                logical_date=dagrun.logical_date,
+                logical_date=date,
                 map_index=task_instance.map_index,
             )
             event_type = RunState.COMPLETE.value.lower()
@@ -393,9 +401,16 @@ class OpenLineageListener:
 
         @print_warning(self.log)
         def on_failure():
+            self.log.error(
+                "ELO ELO %s %s %s", type(dagrun), type(dagrun.logical_date), 
type(dagrun.run_after)
+            )
+            date = dagrun.logical_date
+            if AIRFLOW_V_3_0_PLUS and date is None:
+                date = dagrun.run_after
+
             parent_run_id = self.adapter.build_dag_run_id(
                 dag_id=dag.dag_id,
-                logical_date=dagrun.logical_date,
+                logical_date=date,
                 clear_number=dagrun.clear_number,
             )
 
@@ -403,7 +418,7 @@ class OpenLineageListener:
                 dag_id=dag.dag_id,
                 task_id=task.task_id,
                 try_number=task_instance.try_number,
-                logical_date=dagrun.logical_date,
+                logical_date=date,
                 map_index=task_instance.map_index,
             )
             event_type = RunState.FAIL.value.lower()
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index 99967579bb4..7924ea1cfb8 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -572,7 +572,7 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: 
Context) -> ToSuperv
 
 def run(
     ti: RuntimeTaskInstance, log: Logger
-) -> tuple[IntermediateTIState | TerminalTIState, ToSupervisor | None]:
+) -> tuple[IntermediateTIState | TerminalTIState, ToSupervisor | None, 
BaseException | None]:
     """Run the task in this process."""
     from airflow.exceptions import (
         AirflowException,
@@ -591,6 +591,7 @@ def run(
 
     msg: ToSupervisor | None = None
     state: IntermediateTIState | TerminalTIState
+    error: BaseException | None = None
     try:
         context = ti.get_template_context()
         with set_current_context(context):
@@ -599,7 +600,7 @@ def run(
             if early_exit := _prepare(ti, log, context):
                 msg = early_exit
                 state = TerminalTIState.FAILED
-                return state, msg
+                return state, msg, error
 
             result = _execute_task(context, ti)
 
@@ -639,7 +640,7 @@ def run(
             reschedule_date=reschedule.reschedule_date, 
end_date=datetime.now(tz=timezone.utc)
         )
         state = IntermediateTIState.UP_FOR_RESCHEDULE
-    except (AirflowFailException, AirflowSensorTimeout):
+    except (AirflowFailException, AirflowSensorTimeout) as e:
         # If AirflowFailException is raised, task should not retry.
         # If a sensor in reschedule mode reaches timeout, task should not 
retry.
         log.exception("Task failed with exception")
@@ -650,7 +651,8 @@ def run(
             end_date=datetime.now(tz=timezone.utc),
         )
         state = TerminalTIState.FAIL_WITHOUT_RETRY
-    except (AirflowTaskTimeout, AirflowException):
+        error = e
+    except (AirflowTaskTimeout, AirflowException) as e:
         # We should allow retries if the task has defined it.
         log.exception("Task failed with exception")
         msg = TaskState(
@@ -658,7 +660,8 @@ def run(
             end_date=datetime.now(tz=timezone.utc),
         )
         state = TerminalTIState.FAILED
-    except AirflowTaskTerminated:
+        error = e
+    except AirflowTaskTerminated as e:
         # External state updates are already handled with `ti_heartbeat` and 
will be
         # updated already be another UI API. So, these exceptions should 
ideally never be thrown.
         # If these are thrown, we should mark the TI state as failed.
@@ -668,7 +671,8 @@ def run(
             end_date=datetime.now(tz=timezone.utc),
         )
         state = TerminalTIState.FAIL_WITHOUT_RETRY
-    except SystemExit:
+        error = e
+    except SystemExit as e:
         # SystemExit needs to be retried if they are eligible.
         log.exception("Task failed with exception")
         msg = TaskState(
@@ -676,15 +680,17 @@ def run(
             end_date=datetime.now(tz=timezone.utc),
         )
         state = TerminalTIState.FAILED
-    except BaseException:
+        error = e
+    except BaseException as e:
         log.exception("Task failed with exception")
         msg = TaskState(state=TerminalTIState.FAILED, 
end_date=datetime.now(tz=timezone.utc))
         state = TerminalTIState.FAILED
+        error = e
     finally:
         if msg:
             SUPERVISOR_COMMS.send_request(msg=msg, log=log)
     # Return the message to make unit tests easier too
-    return state, msg
+    return state, msg, error
 
 
 def _execute_task(context: Context, ti: RuntimeTaskInstance):
@@ -759,7 +765,9 @@ def _push_xcom_if_needed(result: Any, ti: 
RuntimeTaskInstance, log: Logger):
     _xcom_push(ti, "return_value", result, mapped_length=mapped_length)
 
 
-def finalize(ti: RuntimeTaskInstance, state: TerminalTIState, log: Logger):
+def finalize(
+    ti: RuntimeTaskInstance, state: TerminalTIState, log: Logger, error: 
BaseException | None = None
+):
     # Pushing xcom for each operator extra links defined on the operator only.
     for oe in ti.task.operator_extra_links:
         link, xcom_key = oe.get_link(operator=ti.task, ti_key=ti.id), 
oe.xcom_key  # type: ignore[arg-type]
@@ -774,7 +782,7 @@ def finalize(ti: RuntimeTaskInstance, state: 
TerminalTIState, log: Logger):
         # TODO: Run task success callbacks here
     if state in [TerminalTIState.FAILED, TerminalTIState.FAIL_WITHOUT_RETRY]:
         get_listener_manager().hook.on_task_instance_failed(
-            previous_state=TaskInstanceState.RUNNING, task_instance=ti
+            previous_state=TaskInstanceState.RUNNING, task_instance=ti, 
error=error
         )
         # TODO: Run task failure callbacks here
 
@@ -787,8 +795,8 @@ def main():
     SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](input=sys.stdin)
     try:
         ti, log = startup()
-        state, msg = run(ti, log)
-        finalize(ti, state, log)
+        state, msg, error = run(ti, log)
+        finalize(ti, state, log, error)
     except KeyboardInterrupt:
         log = structlog.get_logger(logger_name="task")
         log.exception("Ctrl-c hit")
diff --git a/task_sdk/tests/execution_time/test_task_runner.py 
b/task_sdk/tests/execution_time/test_task_runner.py
index 5e4405165e2..ba53c618e1d 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -1147,7 +1147,7 @@ class TestRuntimeTaskInstance:
                 "a_simple_list": ["one", "two", "three", "actually one value 
is made per line"],
             },
         )
-        _, msg = run(runtime_ti, log=mock.MagicMock())
+        _, msg, _ = run(runtime_ti, log=mock.MagicMock())
         assert isinstance(msg, SucceedTask)
 
     def test_task_run_with_operator_extra_links(self, create_runtime_ti, 
mock_supervisor_comms, time_machine):
@@ -1502,6 +1502,7 @@ class TestTaskRunnerCallsListeners:
         def __init__(self):
             self.state = []
             self.component = None
+            self.error = None
 
         @hookimpl
         def on_starting(self, component):
@@ -1516,8 +1517,9 @@ class TestTaskRunnerCallsListeners:
             self.state.append(TaskInstanceState.SUCCESS)
 
         @hookimpl
-        def on_task_instance_failed(self, previous_state, task_instance):
+        def on_task_instance_failed(self, previous_state, task_instance, 
error):
             self.state.append(TaskInstanceState.FAILED)
+            self.error = error
 
         @hookimpl
         def before_stopping(self, component):
@@ -1566,7 +1568,7 @@ class TestTaskRunnerCallsListeners:
         assert isinstance(listener.component, TaskRunnerMarker)
         del listener.component
 
-        state, _ = run(runtime_ti, log)
+        state, _, _ = run(runtime_ti, log)
         finalize(runtime_ti, state, log)
         assert isinstance(listener.component, TaskRunnerMarker)
 
@@ -1595,7 +1597,7 @@ class TestTaskRunnerCallsListeners:
         )
         log = mock.MagicMock()
 
-        state, _ = run(runtime_ti, log)
+        state, _, _ = run(runtime_ti, log)
         finalize(runtime_ti, state, log)
 
         assert listener.state == [TaskInstanceState.RUNNING, 
TaskInstanceState.SUCCESS]
@@ -1633,7 +1635,8 @@ class TestTaskRunnerCallsListeners:
         )
         log = mock.MagicMock()
 
-        state, _ = run(runtime_ti, log)
-        finalize(runtime_ti, state, log)
+        state, _, error = run(runtime_ti, log)
+        finalize(runtime_ti, state, log, error)
 
         assert listener.state == [TaskInstanceState.RUNNING, 
TaskInstanceState.FAILED]
+        assert listener.error == error

Reply via email to