This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ebd65ce579 Add log for running callback (#38892)
ebd65ce579 is described below

commit ebd65ce5795e230e1bc56d97c6104d1ffc210a57
Author: rom sharon <[email protected]>
AuthorDate: Wed Apr 17 17:51:39 2024 +0300

    Add log for running callback (#38892)
    
    * add log for running callback
    
    * get callback name before try statement
    
    Co-authored-by: Andrey Anshin <[email protected]>
    
    * add tests
    
    * fix test
    
    * change logging
    
    * fix tests
    
    ---------
    
    Co-authored-by: Andrey Anshin <[email protected]>
---
 airflow/models/taskinstance.py    |  5 ++---
 tests/models/test_taskinstance.py | 30 ++++++++++++++----------------
 2 files changed, 16 insertions(+), 19 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 8f9d71cfe7..b4d5d5d65a 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -117,7 +117,6 @@ from airflow.utils.context import (
 from airflow.utils.email import send_email
 from airflow.utils.helpers import prune_dict, render_template_to_string
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.module_loading import qualname
 from airflow.utils.net import get_hostname
 from airflow.utils.operator_helpers import ExecutionCallableRunner, 
context_to_airflow_vars
 from airflow.utils.platform import getuser
@@ -1230,11 +1229,11 @@ def _run_finished_callback(
     if callbacks:
         callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
         for callback in callbacks:
+            log.info("Executing %s callback", callback.__name__)
             try:
                 callback(context)
             except Exception:
-                callback_name = qualname(callback).split(".")[-1]
-                log.exception("Error when executing %s callback", 
callback_name)  # type: ignore[attr-defined]
+                log.exception("Error when executing %s callback", 
callback.__name__)  # type: ignore[attr-defined]
 
 
 def _log_state(*, task_instance: TaskInstance | TaskInstancePydantic, 
lead_msg: str = "") -> None:
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index e4c9e17b21..247afb012b 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -2849,20 +2849,7 @@ class TestTaskInstance:
             ti.refresh_from_db()
             assert ti.state == State.SUCCESS
 
-    @pytest.mark.parametrize(
-        "finished_state",
-        [
-            State.SUCCESS,
-            State.UP_FOR_RETRY,
-            State.FAILED,
-        ],
-    )
-    @patch("logging.Logger.exception")
-    def test_finished_callbacks_handle_and_log_exception(
-        self, mock_log, finished_state, create_task_instance
-    ):
-        called = completed = False
-
+    def test_finished_callbacks_handle_and_log_exception(self, caplog):
         def on_finish_callable(context):
             nonlocal called, completed
             called = True
@@ -2870,14 +2857,16 @@ class TestTaskInstance:
             completed = True
 
         for callback_input in [[on_finish_callable], on_finish_callable]:
+            called = completed = False
+            caplog.clear()
             _run_finished_callback(callbacks=callback_input, context={})
 
             assert called
             assert not completed
             callback_name = callback_input[0] if isinstance(callback_input, 
list) else callback_input
             callback_name = qualname(callback_name).split(".")[-1]
-            expected_message = "Error when executing %s callback"
-            mock_log.assert_called_with(expected_message, callback_name)
+            assert "Executing on_finish_callable callback" in caplog.text
+            assert "Error when executing on_finish_callable callback" in 
caplog.text
 
     @provide_session
     def test_handle_failure(self, create_dummy_dag, session=None):
@@ -2890,7 +2879,9 @@ class TestTaskInstance:
         get_listener_manager().pm.hook.on_task_instance_failed = 
listener_callback_on_error
 
         mock_on_failure_1 = mock.MagicMock()
+        mock_on_failure_1.__name__ = "mock_on_failure_1"
         mock_on_retry_1 = mock.MagicMock()
+        mock_on_retry_1.__name__ = "mock_on_retry_1"
         dag, task1 = create_dummy_dag(
             dag_id="test_handle_failure",
             schedule=None,
@@ -2927,7 +2918,9 @@ class TestTaskInstance:
         mock_on_retry_1.assert_not_called()
 
         mock_on_failure_2 = mock.MagicMock()
+        mock_on_failure_2.__name__ = "mock_on_failure_2"
         mock_on_retry_2 = mock.MagicMock()
+        mock_on_retry_2.__name__ = "mock_on_retry_2"
         task2 = EmptyOperator(
             task_id="test_handle_failure_on_retry",
             on_failure_callback=mock_on_failure_2,
@@ -2949,7 +2942,9 @@ class TestTaskInstance:
 
         # test the scenario where normally we would retry but have been asked 
to fail
         mock_on_failure_3 = mock.MagicMock()
+        mock_on_failure_3.__name__ = "mock_on_failure_3"
         mock_on_retry_3 = mock.MagicMock()
+        mock_on_retry_3.__name__ = "mock_on_retry_3"
         task3 = EmptyOperator(
             task_id="test_handle_failure_on_force_fail",
             on_failure_callback=mock_on_failure_3,
@@ -3465,6 +3460,7 @@ class TestTaskInstance:
             raise AirflowSkipException
 
         callback_function = mock.MagicMock()
+        callback_function.__name__ = "callback_function"
 
         with dag_maker(dag_id="test_skipped_task"):
             task = PythonOperator(
@@ -3560,6 +3556,7 @@ def test_sensor_timeout(mode, retries, dag_maker):
         raise AirflowSensorTimeout
 
     mock_on_failure = mock.MagicMock()
+    mock_on_failure.__name__ = "mock_on_failure"
     with dag_maker(dag_id=f"test_sensor_timeout_{mode}_{retries}"):
         PythonSensor(
             task_id="test_raise_sensor_timeout",
@@ -3588,6 +3585,7 @@ def test_mapped_sensor_timeout(mode, retries, dag_maker):
         raise AirflowSensorTimeout
 
     mock_on_failure = mock.MagicMock()
+    mock_on_failure.__name__ = "mock_on_failure"
     with dag_maker(dag_id=f"test_sensor_timeout_{mode}_{retries}"):
         PythonSensor.partial(
             task_id="test_raise_sensor_timeout",

Reply via email to