This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch listener-move-onrunning-callback
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 53d31fd5df87c6ce056030cb32fc1ad23f0110d5
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Thu Jul 20 14:20:00 2023 +0200

    listener: call on_task_instance_running after rendering templates
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 airflow/models/taskinstance.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 4abfd94cd0..520d07f092 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -1498,14 +1498,9 @@ class TaskInstance(Base, LoggingMixin):
         self.task = self.task.prepare_for_execution()
         context = self.get_template_context(ignore_param_exceptions=False)
 
-        # We lose previous state because it's changed in other process in LocalTaskJob.
-        # We could probably pass it through here though...
-        get_listener_manager().hook.on_task_instance_running(
-            previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session
-        )
         try:
             if not mark_success:
-                self._execute_task_with_callbacks(context, test_mode)
+                self._execute_task_with_callbacks(context, test_mode, session=session)
             if not test_mode:
                 self.refresh_from_db(lock_for_update=True, session=session)
             self.state = TaskInstanceState.SUCCESS
@@ -1601,7 +1596,7 @@ class TaskInstance(Base, LoggingMixin):
                     session=session,
                 )
 
-    def _execute_task_with_callbacks(self, context, test_mode=False):
+    def _execute_task_with_callbacks(self, context, test_mode: bool = False, *, session: Session):
         """Prepare Task for Execution."""
         from airflow.models.renderedtifields import RenderedTaskInstanceFields
 
@@ -1651,16 +1646,23 @@ class TaskInstance(Base, LoggingMixin):
                 )
 
             # Run pre_execute callback
-            self.task.pre_execute(context=context)
+            # Is never MappedOperator at this point
+            self.task.pre_execute(context=context)  # type: ignore[union-attr]
 
             # Run on_execute callback
             self._run_execute_callback(context, self.task)
 
+            # Run on_task_instance_running event
+            get_listener_manager().hook.on_task_instance_running(
+                previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session
+            )
+
             # Execute the task
             with set_current_context(context):
                 result = self._execute_task(context, task_orig)
             # Run post_execute callback
-            self.task.post_execute(context=context, result=result)
+            # Is never MappedOperator at this point
+            self.task.post_execute(context=context, result=result)  # type: ignore[union-attr]
 
         Stats.incr(f"operator_successes_{self.task.task_type}", tags=self.stats_tags)
         # Same metric with tagging

Reply via email to