This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7d79812b0f task runner: notify of component start and finish (#27855)
7d79812b0f is described below

commit 7d79812b0f96ae72531f78572d2cfc181074f72c
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Thu Nov 24 13:38:30 2022 +0100

    task runner: notify of component start and finish (#27855)
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 airflow/cli/commands/task_command.py               | 21 ++++++++--
 tests/listeners/file_write_listener.py             | 44 +++++++++++++++++++
 .../task/task_runner/test_standard_task_runner.py  | 49 ++++++++++++++++++++++
 3 files changed, 110 insertions(+), 4 deletions(-)

diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index c5feefb4eb..a217d2c78d 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -37,6 +37,7 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowException, DagRunNotFound, TaskInstanceNotFound
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.jobs.local_task_job import LocalTaskJob
+from airflow.listeners.listener import get_listener_manager
 from airflow.models import DagPickle, TaskInstance
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.dag import DAG
@@ -313,6 +314,10 @@ def _capture_task_logs(ti: TaskInstance) -> Generator[None, None, None]:
             root_logger.handlers[:] = orig_handlers
 
 
+class TaskCommandMarker:
+    """Marker for listener hooks, to properly detect from which component they are called."""
+
+
 @cli_utils.action_cli(check_db=False)
 def task_run(args, dag=None):
     """
@@ -364,6 +369,8 @@ def task_run(args, dag=None):
     # processing hundreds of simultaneous tasks.
     settings.reconfigure_orm(disable_connection_pool=True)
 
+    get_listener_manager().hook.on_starting(component=TaskCommandMarker())
+
     if args.pickle:
         print(f"Loading pickle id: {args.pickle}")
         dag = get_dag_by_pickle(args.pickle)
@@ -380,11 +387,17 @@ def task_run(args, dag=None):
 
     log.info("Running %s on host %s", ti, hostname)
 
-    if args.interactive:
-        _run_task_by_selected_method(args, dag, ti)
-    else:
-        with _capture_task_logs(ti):
+    try:
+        if args.interactive:
             _run_task_by_selected_method(args, dag, ti)
+        else:
+            with _capture_task_logs(ti):
+                _run_task_by_selected_method(args, dag, ti)
+    finally:
+        try:
+            get_listener_manager().hook.before_stopping(component=TaskCommandMarker())
+        except Exception:
+            pass
 
 
 @cli_utils.action_cli(check_db=False)
diff --git a/tests/listeners/file_write_listener.py b/tests/listeners/file_write_listener.py
new file mode 100644
index 0000000000..7d51ad05c7
--- /dev/null
+++ b/tests/listeners/file_write_listener.py
@@ -0,0 +1,44 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+
+from airflow.cli.commands.task_command import TaskCommandMarker
+from airflow.listeners import hookimpl
+
+log = logging.getLogger(__name__)
+
+
+class FileWriteListener:
+    def __init__(self, path):
+        self.path = path
+
+    def write(self, line: str):
+        with open(self.path, "a") as f:
+            f.write(line + "\n")
+
+    @hookimpl
+    def on_starting(self, component):
+        if isinstance(component, TaskCommandMarker):
+            self.write("on_starting")
+
+    @hookimpl
+    def before_stopping(self, component):
+        if isinstance(component, TaskCommandMarker):
+            self.write("before_stopping")
diff --git a/tests/task/task_runner/test_standard_task_runner.py b/tests/task/task_runner/test_standard_task_runner.py
index 7bcac6b4d6..c54a27ae89 100644
--- a/tests/task/task_runner/test_standard_task_runner.py
+++ b/tests/task/task_runner/test_standard_task_runner.py
@@ -29,6 +29,7 @@ import pytest
 
 from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG
 from airflow.jobs.local_task_job import LocalTaskJob
+from airflow.listeners.listener import get_listener_manager
 from airflow.models.dagbag import DagBag
 from airflow.models.taskinstance import TaskInstance
 from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
@@ -37,6 +38,7 @@ from airflow.utils.platform import getuser
 from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
+from tests.listeners.file_write_listener import FileWriteListener
 from tests.test_utils.db import clear_db_runs
 
 TEST_DAG_FOLDER = os.environ["AIRFLOW__CORE__DAGS_FOLDER"]
@@ -111,6 +113,53 @@ class TestStandardTaskRunner:
 
         assert runner.return_code() is not None
 
+    def test_notifies_about_start_and_stop(self):
+        path_listener_writer = "/tmp/path_listener_writer"
+        try:
+            os.unlink(path_listener_writer)
+        except OSError:
+            pass
+
+        lm = get_listener_manager()
+        lm.add_listener(FileWriteListener(path_listener_writer))
+
+        dagbag = DagBag(
+            dag_folder=TEST_DAG_FOLDER,
+            include_examples=False,
+        )
+        dag = dagbag.dags.get("test_example_bash_operator")
+        task = dag.get_task("runme_1")
+
+        with create_session() as session:
+            dag.create_dagrun(
+                run_id="test",
+                data_interval=(DEFAULT_DATE, DEFAULT_DATE),
+                state=State.RUNNING,
+                start_date=DEFAULT_DATE,
+                session=session,
+            )
+            ti = TaskInstance(task=task, run_id="test")
+            job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
+            session.commit()
+            ti.refresh_from_task(task)
+
+            runner = StandardTaskRunner(job1)
+            runner.start()
+
+            # Wait until process sets its pgid to be equal to pid
+            with timeout(seconds=1):
+                while True:
+                    runner_pgid = os.getpgid(runner.process.pid)
+                    if runner_pgid == runner.process.pid:
+                        break
+                    time.sleep(0.01)
+
+                # Wait till process finishes
+            assert runner.return_code(timeout=10) is not None
+            with open(path_listener_writer) as f:
+                assert f.readline() == "on_starting\n"
+                assert f.readline() == "before_stopping\n"
+
     def test_start_and_terminate_run_as_user(self):
         local_task_job = mock.Mock()
         local_task_job.task_instance = mock.MagicMock()

Reply via email to