This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7d79812b0f task runner: notify of component start and finish (#27855)
7d79812b0f is described below
commit 7d79812b0f96ae72531f78572d2cfc181074f72c
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Thu Nov 24 13:38:30 2022 +0100
task runner: notify of component start and finish (#27855)
Signed-off-by: Maciej Obuchowski <[email protected]>
Signed-off-by: Maciej Obuchowski <[email protected]>
---
airflow/cli/commands/task_command.py | 21 ++++++++--
tests/listeners/file_write_listener.py | 44 +++++++++++++++++++
.../task/task_runner/test_standard_task_runner.py | 49 ++++++++++++++++++++++
3 files changed, 110 insertions(+), 4 deletions(-)
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index c5feefb4eb..a217d2c78d 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -37,6 +37,7 @@ from airflow.configuration import conf
from airflow.exceptions import AirflowException, DagRunNotFound, TaskInstanceNotFound
from airflow.executors.executor_loader import ExecutorLoader
from airflow.jobs.local_task_job import LocalTaskJob
+from airflow.listeners.listener import get_listener_manager
from airflow.models import DagPickle, TaskInstance
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
@@ -313,6 +314,10 @@ def _capture_task_logs(ti: TaskInstance) -> Generator[None, None, None]:
root_logger.handlers[:] = orig_handlers
+class TaskCommandMarker:
+ """Marker for listener hooks, to properly detect from which component they are called."""
+
+
@cli_utils.action_cli(check_db=False)
def task_run(args, dag=None):
"""
@@ -364,6 +369,8 @@ def task_run(args, dag=None):
# processing hundreds of simultaneous tasks.
settings.reconfigure_orm(disable_connection_pool=True)
+ get_listener_manager().hook.on_starting(component=TaskCommandMarker())
+
if args.pickle:
print(f"Loading pickle id: {args.pickle}")
dag = get_dag_by_pickle(args.pickle)
@@ -380,11 +387,17 @@ def task_run(args, dag=None):
log.info("Running %s on host %s", ti, hostname)
- if args.interactive:
- _run_task_by_selected_method(args, dag, ti)
- else:
- with _capture_task_logs(ti):
+ try:
+ if args.interactive:
_run_task_by_selected_method(args, dag, ti)
+ else:
+ with _capture_task_logs(ti):
+ _run_task_by_selected_method(args, dag, ti)
+ finally:
+ try:
+ get_listener_manager().hook.before_stopping(component=TaskCommandMarker())
+ except Exception:
+ pass
@cli_utils.action_cli(check_db=False)
diff --git a/tests/listeners/file_write_listener.py b/tests/listeners/file_write_listener.py
new file mode 100644
index 0000000000..7d51ad05c7
--- /dev/null
+++ b/tests/listeners/file_write_listener.py
@@ -0,0 +1,44 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+
+from airflow.cli.commands.task_command import TaskCommandMarker
+from airflow.listeners import hookimpl
+
+log = logging.getLogger(__name__)
+
+
+class FileWriteListener:
+ def __init__(self, path):
+ self.path = path
+
+ def write(self, line: str):
+ with open(self.path, "a") as f:
+ f.write(line + "\n")
+
+ @hookimpl
+ def on_starting(self, component):
+ if isinstance(component, TaskCommandMarker):
+ self.write("on_starting")
+
+ @hookimpl
+ def before_stopping(self, component):
+ if isinstance(component, TaskCommandMarker):
+ self.write("before_stopping")
diff --git a/tests/task/task_runner/test_standard_task_runner.py b/tests/task/task_runner/test_standard_task_runner.py
index 7bcac6b4d6..c54a27ae89 100644
--- a/tests/task/task_runner/test_standard_task_runner.py
+++ b/tests/task/task_runner/test_standard_task_runner.py
@@ -29,6 +29,7 @@ import pytest
from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG
from airflow.jobs.local_task_job import LocalTaskJob
+from airflow.listeners.listener import get_listener_manager
from airflow.models.dagbag import DagBag
from airflow.models.taskinstance import TaskInstance
from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
@@ -37,6 +38,7 @@ from airflow.utils.platform import getuser
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.timeout import timeout
+from tests.listeners.file_write_listener import FileWriteListener
from tests.test_utils.db import clear_db_runs
TEST_DAG_FOLDER = os.environ["AIRFLOW__CORE__DAGS_FOLDER"]
@@ -111,6 +113,53 @@ class TestStandardTaskRunner:
assert runner.return_code() is not None
+ def test_notifies_about_start_and_stop(self):
+ path_listener_writer = "/tmp/path_listener_writer"
+ try:
+ os.unlink(path_listener_writer)
+ except OSError:
+ pass
+
+ lm = get_listener_manager()
+ lm.add_listener(FileWriteListener(path_listener_writer))
+
+ dagbag = DagBag(
+ dag_folder=TEST_DAG_FOLDER,
+ include_examples=False,
+ )
+ dag = dagbag.dags.get("test_example_bash_operator")
+ task = dag.get_task("runme_1")
+
+ with create_session() as session:
+ dag.create_dagrun(
+ run_id="test",
+ data_interval=(DEFAULT_DATE, DEFAULT_DATE),
+ state=State.RUNNING,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
+ ti = TaskInstance(task=task, run_id="test")
+ job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
+ session.commit()
+ ti.refresh_from_task(task)
+
+ runner = StandardTaskRunner(job1)
+ runner.start()
+
+ # Wait until process sets its pgid to be equal to pid
+ with timeout(seconds=1):
+ while True:
+ runner_pgid = os.getpgid(runner.process.pid)
+ if runner_pgid == runner.process.pid:
+ break
+ time.sleep(0.01)
+
+ # Wait till process finishes
+ assert runner.return_code(timeout=10) is not None
+ with open(path_listener_writer) as f:
+ assert f.readline() == "on_starting\n"
+ assert f.readline() == "before_stopping\n"
+
def test_start_and_terminate_run_as_user(self):
local_task_job = mock.Mock()
local_task_job.task_instance = mock.MagicMock()