This is an automated email from the ASF dual-hosted git repository. mobuchowski pushed a commit to branch listener-move-onrunning-callback in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 983dcc95ee43d71d4dd022cdc9f6a93c84f6f24c Author: Maciej Obuchowski <[email protected]> AuthorDate: Thu Jul 20 14:20:00 2023 +0200 listener: call on_task_instance_running after rendering templates Signed-off-by: Maciej Obuchowski <[email protected]> --- airflow/models/taskinstance.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 4abfd94cd0..83c25ba3af 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1505,7 +1505,7 @@ class TaskInstance(Base, LoggingMixin): ) try: if not mark_success: - self._execute_task_with_callbacks(context, test_mode) + self._execute_task_with_callbacks(context, test_mode, session) if not test_mode: self.refresh_from_db(lock_for_update=True, session=session) self.state = TaskInstanceState.SUCCESS @@ -1601,7 +1601,8 @@ class TaskInstance(Base, LoggingMixin): session=session, ) - def _execute_task_with_callbacks(self, context, test_mode=False): + @provide_session + def _execute_task_with_callbacks(self, context, test_mode: bool = False, session: Session = NEW_SESSION): """Prepare Task for Execution.""" from airflow.models.renderedtifields import RenderedTaskInstanceFields @@ -1651,7 +1652,13 @@ class TaskInstance(Base, LoggingMixin): ) # Run pre_execute callback - self.task.pre_execute(context=context) + # Is never MappedOperator at this point + self.task.pre_execute(context=context) # type: ignore[union-attr] + + # Run on_task_instance_running event + get_listener_manager().hook.on_task_instance_running( + previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session + ) # Run on_execute callback self._run_execute_callback(context, self.task) @@ -1660,7 +1667,8 @@ class TaskInstance(Base, LoggingMixin): with set_current_context(context): result = self._execute_task(context, task_orig) # Run post_execute callback - self.task.post_execute(context=context, result=result) + # Is never MappedOperator at this point + self.task.post_execute(context=context, result=result) # type: ignore[union-attr] Stats.incr(f"operator_successes_{self.task.task_type}", tags=self.stats_tags) # Same metric with tagging
