This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch listener-task-timeout
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 460fadf1b97c645c00ef38a7877d49f6a643ce13
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Fri May 24 17:17:01 2024 +0200

    local task job: add timeout to not kill on_task_instance_success listener 
prematurely
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 airflow/config_templates/config.yml               |   7 +
 airflow/jobs/local_task_job_runner.py             |  13 +-
 airflow/providers/openlineage/plugins/listener.py |   1 -
 tests/dags/test_mark_state.py                     |  15 +++
 tests/jobs/test_local_task_job.py                 | 150 +++++++++++++++++++++-
 tests/listeners/slow_listener.py                  |  26 ++++
 tests/listeners/very_slow_listener.py             |  26 ++++
 7 files changed, 235 insertions(+), 3 deletions(-)

diff --git a/airflow/config_templates/config.yml 
b/airflow/config_templates/config.yml
index 95d83f9d4c..b71414f6ae 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -329,6 +329,13 @@ core:
       type: string
       example: ~
       default: "downstream"
+    task_listener_timeout:
+      description: |
+        Maximum possible time (in seconds) that task listeners will have for 
their execution.
+      version_added: 2.10.0
+      type: integer
+      example: ~
+      default: "20"
     default_task_execution_timeout:
       description: |
         The default task execution_timeout value for the operators. Expected 
an integer value to
diff --git a/airflow/jobs/local_task_job_runner.py 
b/airflow/jobs/local_task_job_runner.py
index bb520825f2..c26f6735b4 100644
--- a/airflow/jobs/local_task_job_runner.py
+++ b/airflow/jobs/local_task_job_runner.py
@@ -26,6 +26,7 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.jobs.base_job_runner import BaseJobRunner
 from airflow.jobs.job import perform_heartbeat
+from airflow.listeners.listener import get_listener_manager
 from airflow.models.taskinstance import TaskReturnCode
 from airflow.stats import Stats
 from airflow.utils import timezone
@@ -110,6 +111,8 @@ class LocalTaskJobRunner(BaseJobRunner, LoggingMixin):
         self.terminating = False
 
         self._state_change_checks = 0
+        # time spent after the task completed, but before it exited - used to 
measure listener execution time
+        self._overtime = 0.0
 
     def _execute(self) -> int | None:
         from airflow.task.task_runner import get_task_runner
@@ -195,7 +198,6 @@ class LocalTaskJobRunner(BaseJobRunner, LoggingMixin):
                         self.job.heartrate if self.job.heartrate is not None 
else heartbeat_time_limit,
                     ),
                 )
-
                 return_code = 
self.task_runner.return_code(timeout=max_wait_time)
                 if return_code is not None:
                     self.handle_task_exit(return_code)
@@ -290,6 +292,7 @@ class LocalTaskJobRunner(BaseJobRunner, LoggingMixin):
                 )
                 raise AirflowException("PID of job runner does not match")
         elif self.task_runner.return_code() is None and 
hasattr(self.task_runner, "process"):
+            self._overtime = (timezone.utcnow() - (ti.end_date or 
timezone.utcnow())).total_seconds()
             if ti.state == TaskInstanceState.SKIPPED:
                 # A DagRun timeout will cause tasks to be externally marked as 
skipped.
                 dagrun = ti.get_dagrun(session=session)
@@ -303,6 +306,14 @@ class LocalTaskJobRunner(BaseJobRunner, LoggingMixin):
                 if dagrun_timeout and execution_time > dagrun_timeout:
                     self.log.warning("DagRun timed out after %s.", 
execution_time)
 
+            # If the process still runs after being marked as success, let it run 
until the configured overtime
+            # if there are configured listeners
+            if (
+                ti.state == TaskInstanceState.SUCCESS
+                and self._overtime < conf.getint("core", 
"task_listener_timeout")
+                and get_listener_manager().has_listeners
+            ):
+                return
             # potential race condition, the _run_raw_task commits `success` or 
other state
             # but task_runner does not exit right away due to slow process 
shutdown or any other reasons
             # let's do a throttle here, if the above case is true, the 
handle_task_exit will handle it
diff --git a/airflow/providers/openlineage/plugins/listener.py 
b/airflow/providers/openlineage/plugins/listener.py
index 73f8c8c79e..3df5f36d6c 100644
--- a/airflow/providers/openlineage/plugins/listener.py
+++ b/airflow/providers/openlineage/plugins/listener.py
@@ -130,7 +130,6 @@ class OpenLineageListener:
                 dagrun.data_interval_start.isoformat() if 
dagrun.data_interval_start else None
             )
             data_interval_end = dagrun.data_interval_end.isoformat() if 
dagrun.data_interval_end else None
-
             redacted_event = self.adapter.start_task(
                 run_id=task_uuid,
                 job_name=get_job_name(task),
diff --git a/tests/dags/test_mark_state.py b/tests/dags/test_mark_state.py
index 331da2d498..2157fdcff6 100644
--- a/tests/dags/test_mark_state.py
+++ b/tests/dags/test_mark_state.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import time
 from datetime import datetime
 from time import sleep
 
@@ -41,6 +42,16 @@ def success_callback(context):
     assert context["dag_run"].dag_id == dag_id
 
 
+def sleep_execution():
+    time.sleep(1)
+
+
+def slow_execution():
+    import re
+
+    re.match(r"(a?){30}a{30}", "a" * 30)
+
+
 def test_mark_success_no_kill(ti):
     assert ti.state == State.RUNNING
     # Simulate marking this successful in the UI
@@ -103,3 +114,7 @@ def test_mark_skipped_externally(ti):
 PythonOperator(task_id="test_mark_skipped_externally", 
python_callable=test_mark_skipped_externally, dag=dag)
 
 PythonOperator(task_id="dummy", python_callable=lambda: True, dag=dag)
+
+PythonOperator(task_id="slow_execution", python_callable=slow_execution, 
dag=dag)
+
+PythonOperator(task_id="sleep_execution", python_callable=sleep_execution, 
dag=dag)
diff --git a/tests/jobs/test_local_task_job.py 
b/tests/jobs/test_local_task_job.py
index 6de4dcd6a4..90c2c36e6f 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -33,11 +33,12 @@ import psutil
 import pytest
 
 from airflow import settings
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowTaskTimeout
 from airflow.executors.sequential_executor import SequentialExecutor
 from airflow.jobs.job import Job, run_job
 from airflow.jobs.local_task_job_runner import SIGSEGV_MESSAGE, 
LocalTaskJobRunner
 from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
+from airflow.listeners.listener import get_listener_manager
 from airflow.models.dag import DAG
 from airflow.models.dagbag import DagBag
 from airflow.models.serialized_dag import SerializedDagModel
@@ -327,6 +328,7 @@ class TestLocalTaskJob:
         Test that ensures that mark_success in the UI doesn't cause
         the task to fail, and that the task exits
         """
+        get_listener_manager().clear()
         dag = get_test_dag("test_mark_state")
         data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
         dr = dag.create_dagrun(
@@ -583,6 +585,152 @@ class TestLocalTaskJob:
             "State of this instance has been externally set to success. 
Terminating instance." in caplog.text
         )
 
+    def test_success_listeners_executed(self, caplog, get_test_dag):
+        """
+        Test that ensures that when listeners are executed, the task is not 
killed before they finish
+        or time out
+        """
+        from tests.listeners import slow_listener
+
+        lm = get_listener_manager()
+        lm.clear()
+        lm.add_listener(slow_listener)
+
+        dag = get_test_dag("test_mark_state")
+        data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
+        with create_session() as session:
+            dr = dag.create_dagrun(
+                state=State.RUNNING,
+                execution_date=DEFAULT_DATE,
+                run_type=DagRunType.SCHEDULED,
+                session=session,
+                data_interval=data_interval,
+            )
+        task = dag.get_task(task_id="sleep_execution")
+
+        ti = dr.get_task_instance(task.task_id)
+        ti.refresh_from_task(task)
+
+        job = Job(executor=SequentialExecutor(), dag_id=ti.dag_id)
+        job_runner = LocalTaskJobRunner(job=job, task_instance=ti, 
ignore_ti_state=True)
+        with timeout(30):
+            run_job(job=job, execute_callable=job_runner._execute)
+        ti.refresh_from_db()
+        assert (
+            "State of this instance has been externally set to success. 
Terminating instance."
+            not in caplog.text
+        )
+        lm.clear()
+
+    def test_success_slow_listeners_executed_kill(self, caplog, get_test_dag):
+        """
+        Test that ensures that when there are too slow listeners, the task is 
killed
+        """
+        from tests.listeners import very_slow_listener
+
+        lm = get_listener_manager()
+        lm.clear()
+        lm.add_listener(very_slow_listener)
+
+        dag = get_test_dag("test_mark_state")
+        data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
+        with create_session() as session:
+            dr = dag.create_dagrun(
+                state=State.RUNNING,
+                execution_date=DEFAULT_DATE,
+                run_type=DagRunType.SCHEDULED,
+                session=session,
+                data_interval=data_interval,
+            )
+        task = dag.get_task(task_id="sleep_execution")
+
+        ti = dr.get_task_instance(task.task_id)
+        ti.refresh_from_task(task)
+
+        job = Job(executor=SequentialExecutor(), dag_id=ti.dag_id)
+        job_runner = LocalTaskJobRunner(job=job, task_instance=ti, 
ignore_ti_state=True)
+        with timeout(30):
+            run_job(job=job, execute_callable=job_runner._execute)
+        ti.refresh_from_db()
+        assert (
+            "State of this instance has been externally set to success. 
Terminating instance." in caplog.text
+        )
+        lm.clear()
+
+        def test_success_slow_listeners_executed_kill(self, caplog, 
get_test_dag):
+            """
+            Test that ensures that when there are too slow listeners, the task 
is killed
+            """
+            from tests.listeners import very_slow_listener
+
+            lm = get_listener_manager()
+            lm.clear()
+            lm.add_listener(very_slow_listener)
+
+            dag = get_test_dag("test_mark_state")
+            data_interval = 
dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
+            with create_session() as session:
+                dr = dag.create_dagrun(
+                    state=State.RUNNING,
+                    execution_date=DEFAULT_DATE,
+                    run_type=DagRunType.SCHEDULED,
+                    session=session,
+                    data_interval=data_interval,
+                )
+            task = dag.get_task(task_id="sleep_execution")
+
+            ti = dr.get_task_instance(task.task_id)
+            ti.refresh_from_task(task)
+
+            job = Job(executor=SequentialExecutor(), dag_id=ti.dag_id)
+            job_runner = LocalTaskJobRunner(job=job, task_instance=ti, 
ignore_ti_state=True)
+            with timeout(30):
+                run_job(job=job, execute_callable=job_runner._execute)
+            ti.refresh_from_db()
+            assert (
+                "State of this instance has been externally set to success. 
Terminating instance."
+                in caplog.text
+            )
+            lm.clear()
+
+    def 
test_success_slow_task_not_killed_by_overtime_but_regular_timeout(self, caplog, 
get_test_dag):
+        """
+        Test that ensures that when there are listeners, but the task is 
taking a long time anyway,
+        it's not killed by the overtime mechanism.
+        """
+        from tests.listeners import slow_listener
+
+        lm = get_listener_manager()
+        lm.clear()
+        lm.add_listener(slow_listener)
+
+        dag = get_test_dag("test_mark_state")
+        data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
+        with create_session() as session:
+            dr = dag.create_dagrun(
+                state=State.RUNNING,
+                execution_date=DEFAULT_DATE,
+                run_type=DagRunType.SCHEDULED,
+                session=session,
+                data_interval=data_interval,
+            )
+        task = dag.get_task(task_id="sleep_execution")
+
+        ti = dr.get_task_instance(task.task_id)
+        ti.refresh_from_task(task)
+
+        job = Job(executor=SequentialExecutor(), dag_id=ti.dag_id)
+        job_runner = LocalTaskJobRunner(job=job, task_instance=ti, 
ignore_ti_state=True)
+        with pytest.raises(AirflowTaskTimeout):
+            with timeout(30):
+                run_job(job=job, execute_callable=job_runner._execute)
+        ti.refresh_from_db()
+        assert (
+            "State of this instance has been externally set to success. 
Terminating instance."
+            not in caplog.text
+        )
+        lm.clear()
+
     @pytest.mark.parametrize("signal_type", [signal.SIGTERM, signal.SIGKILL])
     def test_process_os_signal_calls_on_failure_callback(
         self, monkeypatch, tmp_path, get_test_dag, signal_type
diff --git a/tests/listeners/slow_listener.py b/tests/listeners/slow_listener.py
new file mode 100644
index 0000000000..0575e50beb
--- /dev/null
+++ b/tests/listeners/slow_listener.py
@@ -0,0 +1,26 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import time
+
+from airflow.listeners import hookimpl
+
+
+@hookimpl
+def on_task_instance_success(previous_state, task_instance, session):
+    time.sleep(5)
diff --git a/tests/listeners/very_slow_listener.py 
b/tests/listeners/very_slow_listener.py
new file mode 100644
index 0000000000..688faded97
--- /dev/null
+++ b/tests/listeners/very_slow_listener.py
@@ -0,0 +1,26 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import time
+
+from airflow.listeners import hookimpl
+
+
+@hookimpl
+def on_task_instance_success(previous_state, task_instance, session):
+    time.sleep(30)

Reply via email to