This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0b341e6b920 Move listeners module to shared library for client server
separation (#59883)
0b341e6b920 is described below
commit 0b341e6b92040360f86f5697da119c58a91aa4c2
Author: Amogh Desai <[email protected]>
AuthorDate: Thu Jan 8 19:46:19 2026 +0530
Move listeners module to shared library for client server separation
(#59883)
Extract the listeners infrastructure to `shared/listeners/` library to
eliminate cross dependencies between airflow-core and task-sdk.
- ListenerManager and hookimpl marker now in shared library
- Hook specs split by callers:
- shared: lifecycle, taskinstance (called from both sdk and core)
- core: dagrun, asset, importerrors (called only from core)
- sdk registers only specs it actually uses (lifecycle, taskinstance)
- core registers all specs for full listener support
---
airflow-core/pyproject.toml | 5 +
airflow-core/src/airflow/_shared/listeners | 1 +
airflow-core/src/airflow/listeners/__init__.py | 4 +-
airflow-core/src/airflow/listeners/listener.py | 82 +++--------
airflow-core/src/airflow/listeners/spec/asset.py | 1 +
.../src/airflow/listeners/spec/importerrors.py | 1 +
.../core_api/routes/public/test_dag_run.py | 11 +-
.../core_api/routes/public/test_task_instances.py | 13 +-
airflow-core/tests/unit/assets/test_manager.py | 9 +-
.../tests/unit/dag_processing/test_collection.py | 6 +-
airflow-core/tests/unit/jobs/test_base_job.py | 5 +-
airflow-core/tests/unit/jobs/test_scheduler_job.py | 11 +-
.../tests/unit/listeners/test_asset_listener.py | 14 +-
.../tests/unit/listeners/test_listeners.py | 60 ++++----
devel-common/src/tests_common/pytest_plugin.py | 40 +++++
.../src/airflow/providers/common/compat/sdk.py | 6 +
.../providers/openlineage/plugins/listener.py | 3 +-
.../tests/system/openlineage/conftest.py | 9 +-
.../unit/openlineage/plugins/test_execution.py | 42 +++---
pyproject.toml | 3 +
.../listeners/pyproject.toml | 63 ++++----
.../src/airflow_shared}/listeners/__init__.py | 0
.../src/airflow_shared}/listeners/listener.py | 44 +++---
.../src/airflow_shared/listeners/spec}/__init__.py | 6 -
.../airflow_shared}/listeners/spec/lifecycle.py | 0
.../airflow_shared}/listeners/spec/taskinstance.py | 10 +-
.../listeners/tests/conftest.py | 4 +-
.../listeners/tests}/listeners/__init__.py | 4 -
.../tests/listeners/test_listener_manager.py | 164 +++++++++++++++++++++
task-sdk/pyproject.toml | 7 +-
task-sdk/src/airflow/sdk/_shared/listeners | 1 +
.../src/airflow/sdk/execution_time/task_runner.py | 2 +-
.../src/airflow/sdk/listener.py | 31 ++--
task-sdk/tests/conftest.py | 8 +-
task-sdk/tests/task_sdk/docs/test_public_api.py | 1 +
.../task_sdk/execution_time/test_task_runner.py | 45 +++---
36 files changed, 433 insertions(+), 283 deletions(-)
diff --git a/airflow-core/pyproject.toml b/airflow-core/pyproject.toml
index 218ed550412..73ee1cd4508 100644
--- a/airflow-core/pyproject.toml
+++ b/airflow-core/pyproject.toml
@@ -161,6 +161,9 @@ dependencies = [
# Start of shared configuration dependencies
"pyyaml>=6.0.3",
# End of shared configuration dependencies
+ # Start of shared listeners dependencies
+ "pluggy>=1.5.0",
+ # End of shared listeners dependencies
]
@@ -235,6 +238,7 @@ exclude = [
"../shared/secrets_backend/src/airflow_shared/secrets_backend" =
"src/airflow/_shared/secrets_backend"
"../shared/secrets_masker/src/airflow_shared/secrets_masker" =
"src/airflow/_shared/secrets_masker"
"../shared/timezones/src/airflow_shared/timezones" =
"src/airflow/_shared/timezones"
+"../shared/listeners/src/airflow_shared/listeners" =
"src/airflow/_shared/listeners"
"../shared/plugins_manager/src/airflow_shared/plugins_manager" =
"src/airflow/_shared/plugins_manager"
[tool.hatch.build.targets.custom]
@@ -305,6 +309,7 @@ apache-airflow-devel-common = { workspace = true }
shared_distributions = [
"apache-airflow-shared-configuration",
"apache-airflow-shared-dagnode",
+ "apache-airflow-shared-listeners",
"apache-airflow-shared-logging",
"apache-airflow-shared-module-loading",
"apache-airflow-shared-observability",
diff --git a/airflow-core/src/airflow/_shared/listeners
b/airflow-core/src/airflow/_shared/listeners
new file mode 120000
index 00000000000..54346425d37
--- /dev/null
+++ b/airflow-core/src/airflow/_shared/listeners
@@ -0,0 +1 @@
+../../../../shared/listeners/src/airflow_shared/listeners
\ No newline at end of file
diff --git a/airflow-core/src/airflow/listeners/__init__.py
b/airflow-core/src/airflow/listeners/__init__.py
index 87840b50e2f..670ecde854c 100644
--- a/airflow-core/src/airflow/listeners/__init__.py
+++ b/airflow-core/src/airflow/listeners/__init__.py
@@ -17,6 +17,6 @@
# under the License.
from __future__ import annotations
-from pluggy import HookimplMarker
+from airflow._shared.listeners import hookimpl
-hookimpl = HookimplMarker("airflow")
+__all__ = ["hookimpl"]
diff --git a/airflow-core/src/airflow/listeners/listener.py
b/airflow-core/src/airflow/listeners/listener.py
index 08869f50947..06be7a1b9b9 100644
--- a/airflow-core/src/airflow/listeners/listener.py
+++ b/airflow-core/src/airflow/listeners/listener.py
@@ -17,72 +17,36 @@
# under the License.
from __future__ import annotations
-import logging
from functools import cache
-from typing import TYPE_CHECKING
-
-import pluggy
+from airflow._shared.listeners.listener import ListenerManager
+from airflow._shared.listeners.spec import lifecycle, taskinstance
+from airflow.listeners.spec import asset, dagrun, importerrors
from airflow.plugins_manager import integrate_listener_plugins
-if TYPE_CHECKING:
- from pluggy._hooks import _HookRelay
-
-log = logging.getLogger(__name__)
-
-
-def _before_hookcall(hook_name, hook_impls, kwargs):
- log.debug("Calling %r with %r", hook_name, kwargs)
- log.debug("Hook impls: %s", hook_impls)
-
-
-def _after_hookcall(outcome, hook_name, hook_impls, kwargs):
- log.debug("Result from %r: %s", hook_name, outcome.get_result())
-
-
-class ListenerManager:
- """Manage listener registration and provides hook property for calling
them."""
-
- def __init__(self):
- from airflow.listeners.spec import (
- asset,
- dagrun,
- importerrors,
- lifecycle,
- taskinstance,
- )
-
- self.pm = pluggy.PluginManager("airflow")
- self.pm.add_hookcall_monitoring(_before_hookcall, _after_hookcall)
- self.pm.add_hookspecs(lifecycle)
- self.pm.add_hookspecs(dagrun)
- self.pm.add_hookspecs(asset)
- self.pm.add_hookspecs(taskinstance)
- self.pm.add_hookspecs(importerrors)
-
- @property
- def has_listeners(self) -> bool:
- return bool(self.pm.get_plugins())
-
- @property
- def hook(self) -> _HookRelay:
- """Return hook, on which plugin methods specified in spec can be
called."""
- return self.pm.hook
-
- def add_listener(self, listener):
- if self.pm.is_registered(listener):
- return
- self.pm.register(listener)
-
- def clear(self):
- """Remove registered plugins."""
- for plugin in self.pm.get_plugins():
- self.pm.unregister(plugin)
-
@cache
def get_listener_manager() -> ListenerManager:
- """Get singleton listener manager."""
+ """
+ Get a listener manager for Airflow core.
+
+ Registers the following listeners:
+ - lifecycle: on_starting, before_stopping
+ - dagrun: on_dag_run_running, on_dag_run_success, on_dag_run_failed
+ - taskinstance: on_task_instance_running, on_task_instance_success, etc.
+ - asset: on_asset_created, on_asset_changed, etc.
+ - importerrors: on_new_dag_import_error, on_existing_dag_import_error
+ """
_listener_manager = ListenerManager()
+
+ _listener_manager.add_hookspecs(lifecycle)
+ _listener_manager.add_hookspecs(dagrun)
+ _listener_manager.add_hookspecs(taskinstance)
+ _listener_manager.add_hookspecs(asset)
+ _listener_manager.add_hookspecs(importerrors)
+
integrate_listener_plugins(_listener_manager)
return _listener_manager
+
+
+__all__ = ["get_listener_manager", "ListenerManager"]
diff --git a/airflow-core/src/airflow/listeners/spec/asset.py
b/airflow-core/src/airflow/listeners/spec/asset.py
index 25d1aacf15f..05ba0809bcd 100644
--- a/airflow-core/src/airflow/listeners/spec/asset.py
+++ b/airflow-core/src/airflow/listeners/spec/asset.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
from __future__ import annotations
from typing import TYPE_CHECKING
diff --git a/airflow-core/src/airflow/listeners/spec/importerrors.py
b/airflow-core/src/airflow/listeners/spec/importerrors.py
index 2cb2b4e454d..048fb38ffa1 100644
--- a/airflow-core/src/airflow/listeners/spec/importerrors.py
+++ b/airflow-core/src/airflow/listeners/spec/importerrors.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
from __future__ import annotations
from pluggy import HookspecMarker
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py
index 87a7b932b74..ef0e4f792e9 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py
@@ -27,7 +27,6 @@ from sqlalchemy import func, select
from airflow._shared.timezones import timezone
from airflow.api_fastapi.core_api.datamodels.dag_versions import
DagVersionResponse
-from airflow.listeners.listener import get_listener_manager
from airflow.models import DagModel, DagRun, Log
from airflow.models.asset import AssetEvent, AssetModel
from airflow.providers.standard.operators.empty import EmptyOperator
@@ -1285,12 +1284,6 @@ class TestPatchDagRun:
body = response.json()
assert body["detail"][0]["msg"] == "Input should be 'queued',
'success' or 'failed'"
- @pytest.fixture(autouse=True)
- def clean_listener_manager(self):
- get_listener_manager().clear()
- yield
- get_listener_manager().clear()
-
@pytest.mark.parametrize(
("state", "listener_state"),
[
@@ -1300,11 +1293,11 @@ class TestPatchDagRun:
],
)
@pytest.mark.usefixtures("configure_git_connection_for_dag_bundle")
- def test_patch_dag_run_notifies_listeners(self, test_client, state,
listener_state):
+ def test_patch_dag_run_notifies_listeners(self, test_client, state,
listener_state, listener_manager):
from unit.listeners.class_listener import ClassBasedListener
listener = ClassBasedListener()
- get_listener_manager().add_listener(listener)
+ listener_manager(listener)
response =
test_client.patch(f"/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}", json={"state":
state})
assert response.status_code == 200
assert listener.state == listener_state
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index b4974b00e52..6a426d8e418 100644
---
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -33,7 +33,6 @@ from airflow.dag_processing.bundles.manager import
DagBundlesManager
from airflow.dag_processing.dagbag import DagBag, sync_bag_to_db
from airflow.jobs.job import Job
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
-from airflow.listeners.listener import get_listener_manager
from airflow.models import DagRun, Log, TaskInstance
from airflow.models.dag_version import DagVersion
from airflow.models.hitl import HITLDetail
@@ -4084,12 +4083,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
TASK_ID = "print_the_context"
RUN_ID = "TEST_DAG_RUN_ID"
- @pytest.fixture(autouse=True)
- def clean_listener_manager(self):
- get_listener_manager().clear()
- yield
- get_listener_manager().clear()
-
@pytest.mark.parametrize(
("state", "listener_state"),
[
@@ -4098,13 +4091,15 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
("skipped", []),
],
)
- def test_patch_task_instance_notifies_listeners(self, test_client,
session, state, listener_state):
+ def test_patch_task_instance_notifies_listeners(
+ self, test_client, session, state, listener_state, listener_manager
+ ):
from unit.listeners.class_listener import ClassBasedListener
self.create_task_instances(session)
listener = ClassBasedListener()
- get_listener_manager().add_listener(listener)
+ listener_manager(listener)
test_client.patch(
self.ENDPOINT_URL,
json={
diff --git a/airflow-core/tests/unit/assets/test_manager.py
b/airflow-core/tests/unit/assets/test_manager.py
index 8036b6d8352..7929ae7c0c5 100644
--- a/airflow-core/tests/unit/assets/test_manager.py
+++ b/airflow-core/tests/unit/assets/test_manager.py
@@ -30,7 +30,6 @@ from sqlalchemy.orm import Session
from airflow import settings
from airflow.assets.manager import AssetManager
-from airflow.listeners.listener import get_listener_manager
from airflow.models.asset import (
AssetAliasModel,
AssetDagRunQueue,
@@ -183,11 +182,11 @@ class TestAssetManager:
assert
session.scalar(select(func.count()).select_from(AssetDagRunQueue)) == 0
def test_register_asset_change_notifies_asset_listener(
- self, session, mock_task_instance, testing_dag_bundle
+ self, session, mock_task_instance, testing_dag_bundle, listener_manager
):
asset_manager = AssetManager()
asset_listener.clear()
- get_listener_manager().add_listener(asset_listener)
+ listener_manager(asset_listener)
bundle_name = "testing"
@@ -207,10 +206,10 @@ class TestAssetManager:
assert len(asset_listener.changed) == 1
assert asset_listener.changed[0].uri == asset.uri
- def test_create_assets_notifies_asset_listener(self, session):
+ def test_create_assets_notifies_asset_listener(self, session,
listener_manager):
asset_manager = AssetManager()
asset_listener.clear()
- get_listener_manager().add_listener(asset_listener)
+ listener_manager(asset_listener)
asset = Asset(uri="test://asset1", name="test_asset_1")
diff --git a/airflow-core/tests/unit/dag_processing/test_collection.py
b/airflow-core/tests/unit/dag_processing/test_collection.py
index ad4a8a73382..458f773046e 100644
--- a/airflow-core/tests/unit/dag_processing/test_collection.py
+++ b/airflow-core/tests/unit/dag_processing/test_collection.py
@@ -41,7 +41,6 @@ from airflow.dag_processing.collection import (
update_dag_parsing_results_in_db,
)
from airflow.exceptions import SerializationError
-from airflow.listeners.listener import get_listener_manager
from airflow.models import DagModel, DagRun
from airflow.models.asset import (
AssetActive,
@@ -321,12 +320,11 @@ class TestUpdateDagParsingResults:
clear_db_import_errors()
@pytest.fixture(name="dag_import_error_listener")
- def _dag_import_error_listener(self):
+ def _dag_import_error_listener(self, listener_manager):
from unit.listeners import dag_import_error_listener
- get_listener_manager().add_listener(dag_import_error_listener)
+ listener_manager(dag_import_error_listener)
yield dag_import_error_listener
- get_listener_manager().clear()
dag_import_error_listener.clear()
@mark_fab_auth_manager_test
diff --git a/airflow-core/tests/unit/jobs/test_base_job.py
b/airflow-core/tests/unit/jobs/test_base_job.py
index f8f780fea07..fbae2a96837 100644
--- a/airflow-core/tests/unit/jobs/test_base_job.py
+++ b/airflow-core/tests/unit/jobs/test_base_job.py
@@ -28,7 +28,6 @@ from sqlalchemy.exc import OperationalError
from airflow._shared.timezones import timezone
from airflow.executors.local_executor import LocalExecutor
from airflow.jobs.job import Job, health_check_threshold, most_recent_job,
perform_heartbeat, run_job
-from airflow.listeners.listener import get_listener_manager
from airflow.utils.session import create_session
from airflow.utils.state import State
@@ -68,11 +67,11 @@ class TestJob:
assert job.state == State.SUCCESS
assert job.end_date is not None
- def test_base_job_respects_plugin_lifecycle(self, dag_maker):
+ def test_base_job_respects_plugin_lifecycle(self, dag_maker,
listener_manager):
"""
Test if DagRun is successful, and if Success callbacks is defined, it
is sent to DagFileProcessor.
"""
- get_listener_manager().add_listener(lifecycle_listener)
+ listener_manager(lifecycle_listener)
job = Job()
job_runner = MockJobRunner(job=job, func=lambda: sys.exit(0))
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py
b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 74932d3d1ab..2b9ed36c7cb 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -114,7 +114,6 @@ from tests_common.test_utils.mock_executor import
MockExecutor
from tests_common.test_utils.mock_operators import CustomOperator
from tests_common.test_utils.taskinstance import create_task_instance,
run_task_instance
from unit.listeners import dag_listener
-from unit.listeners.test_listeners import get_listener_manager
from unit.models import TEST_DAGS_FOLDER
if TYPE_CHECKING:
@@ -3190,7 +3189,9 @@ class TestSchedulerJob:
("state", "expected_callback_msg"), [(State.SUCCESS, "success"),
(State.FAILED, "task_failure")]
)
@conf_vars({("scheduler", "use_job_schedule"): "False"})
- def test_dagrun_plugins_are_notified(self, state, expected_callback_msg,
dag_maker, session):
+ def test_dagrun_plugins_are_notified(
+ self, state, expected_callback_msg, dag_maker, session,
listener_manager
+ ):
"""
Test if DagRun is successful, and if Success callbacks is defined, it
is sent to DagFileProcessor.
"""
@@ -3203,7 +3204,7 @@ class TestSchedulerJob:
EmptyOperator(task_id="dummy")
dag_listener.clear()
- get_listener_manager().add_listener(dag_listener)
+ listener_manager(dag_listener)
scheduler_job = Job(executor=self.null_exec)
self.job_runner = SchedulerJobRunner(job=scheduler_job)
@@ -3374,7 +3375,7 @@ class TestSchedulerJob:
session.close()
@conf_vars({("scheduler", "use_job_schedule"): "False"})
- def test_dagrun_notify_called_success(self, dag_maker):
+ def test_dagrun_notify_called_success(self, dag_maker, listener_manager):
with dag_maker(
dag_id="test_dagrun_notify_called",
on_success_callback=lambda x: print("success"),
@@ -3383,7 +3384,7 @@ class TestSchedulerJob:
EmptyOperator(task_id="dummy")
dag_listener.clear()
- get_listener_manager().add_listener(dag_listener)
+ listener_manager(dag_listener)
executor = MockExecutor(do_update=False)
diff --git a/airflow-core/tests/unit/listeners/test_asset_listener.py
b/airflow-core/tests/unit/listeners/test_asset_listener.py
index 3b5e933f7d8..b2ce78c2443 100644
--- a/airflow-core/tests/unit/listeners/test_asset_listener.py
+++ b/airflow-core/tests/unit/listeners/test_asset_listener.py
@@ -18,7 +18,6 @@ from __future__ import annotations
import pytest
-from airflow.listeners.listener import get_listener_manager
from airflow.models.asset import AssetModel
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk.definitions.asset import Asset
@@ -28,19 +27,18 @@ from unit.listeners import asset_listener
@pytest.fixture(autouse=True)
-def clean_listener_manager():
- lm = get_listener_manager()
- lm.clear()
- lm.add_listener(asset_listener)
+def clean_listener_state():
+ """Clear listener state after each test."""
yield
- lm = get_listener_manager()
- lm.clear()
asset_listener.clear()
@pytest.mark.db_test
@provide_session
-def
test_asset_listener_on_asset_changed_gets_calls(create_task_instance_of_operator,
session):
+def test_asset_listener_on_asset_changed_gets_calls(
+ create_task_instance_of_operator, session, listener_manager
+):
+ listener_manager(asset_listener)
asset_uri = "test://asset/"
asset_name = "test_asset_uri"
asset_group = "test-group"
diff --git a/airflow-core/tests/unit/listeners/test_listeners.py
b/airflow-core/tests/unit/listeners/test_listeners.py
index 5a2a9ff8bb7..aad2ea7b6e8 100644
--- a/airflow-core/tests/unit/listeners/test_listeners.py
+++ b/airflow-core/tests/unit/listeners/test_listeners.py
@@ -25,7 +25,6 @@ import pytest
from airflow._shared.timezones import timezone
from airflow.exceptions import AirflowException
from airflow.jobs.job import Job, run_job
-from airflow.listeners.listener import get_listener_manager
from airflow.providers.standard.operators.bash import BashOperator
from airflow.utils.session import provide_session
from airflow.utils.state import DagRunState, TaskInstanceState
@@ -58,20 +57,16 @@ TEST_DAG_FOLDER = os.environ["AIRFLOW__CORE__DAGS_FOLDER"]
@pytest.fixture(autouse=True)
-def clean_listener_manager():
- lm = get_listener_manager()
- lm.clear()
+def clean_listener_state():
+ """Clear listener state after each test."""
yield
- lm = get_listener_manager()
- lm.clear()
for listener in LISTENERS:
listener.clear()
@provide_session
-def test_listener_gets_calls(create_task_instance, session):
- lm = get_listener_manager()
- lm.add_listener(full_listener)
+def test_listener_gets_calls(create_task_instance, session, listener_manager):
+ listener_manager(full_listener)
ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED)
# Using ti.run() instead of ti._run_raw_task() to capture state change to
RUNNING
@@ -84,12 +79,11 @@ def test_listener_gets_calls(create_task_instance, session):
@provide_session
-def test_multiple_listeners(create_task_instance, session):
- lm = get_listener_manager()
- lm.add_listener(full_listener)
- lm.add_listener(lifecycle_listener)
+def test_multiple_listeners(create_task_instance, session, listener_manager):
+ listener_manager(full_listener)
+ listener_manager(lifecycle_listener)
class_based_listener = class_listener.ClassBasedListener()
- lm.add_listener(class_based_listener)
+ listener_manager(class_based_listener)
job = Job()
job_runner = MockJobRunner(job=job)
@@ -105,9 +99,8 @@ def test_multiple_listeners(create_task_instance, session):
@provide_session
-def test_listener_gets_only_subscribed_calls(create_task_instance, session):
- lm = get_listener_manager()
- lm.add_listener(partial_listener)
+def test_listener_gets_only_subscribed_calls(create_task_instance, session,
listener_manager):
+ listener_manager(partial_listener)
ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED)
# Using ti.run() instead of ti._run_raw_task() to capture state change to
RUNNING
@@ -120,9 +113,8 @@ def
test_listener_gets_only_subscribed_calls(create_task_instance, session):
@provide_session
-def test_listener_suppresses_exceptions(create_task_instance, session,
cap_structlog):
- lm = get_listener_manager()
- lm.add_listener(throwing_listener)
+def test_listener_suppresses_exceptions(create_task_instance, session,
cap_structlog, listener_manager):
+ listener_manager(throwing_listener)
ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED)
ti.run()
@@ -130,9 +122,8 @@ def
test_listener_suppresses_exceptions(create_task_instance, session, cap_struc
@provide_session
-def
test_listener_captures_failed_taskinstances(create_task_instance_of_operator,
session):
- lm = get_listener_manager()
- lm.add_listener(full_listener)
+def
test_listener_captures_failed_taskinstances(create_task_instance_of_operator,
session, listener_manager):
+ listener_manager(full_listener)
ti = create_task_instance_of_operator(
BashOperator, dag_id=DAG_ID, logical_date=LOGICAL_DATE,
task_id=TASK_ID, bash_command="exit 1"
@@ -145,9 +136,10 @@ def
test_listener_captures_failed_taskinstances(create_task_instance_of_operator
@provide_session
-def
test_listener_captures_longrunning_taskinstances(create_task_instance_of_operator,
session):
- lm = get_listener_manager()
- lm.add_listener(full_listener)
+def test_listener_captures_longrunning_taskinstances(
+ create_task_instance_of_operator, session, listener_manager
+):
+ listener_manager(full_listener)
ti = create_task_instance_of_operator(
BashOperator, dag_id=DAG_ID, logical_date=LOGICAL_DATE,
task_id=TASK_ID, bash_command="sleep 5"
@@ -159,10 +151,9 @@ def
test_listener_captures_longrunning_taskinstances(create_task_instance_of_ope
@provide_session
-def test_class_based_listener(create_task_instance, session):
- lm = get_listener_manager()
+def test_class_based_listener(create_task_instance, session, listener_manager):
listener = class_listener.ClassBasedListener()
- lm.add_listener(listener)
+ listener_manager(listener)
ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED)
ti.run()
@@ -170,16 +161,15 @@ def test_class_based_listener(create_task_instance,
session):
assert listener.state == [TaskInstanceState.RUNNING,
TaskInstanceState.SUCCESS, DagRunState.SUCCESS]
-def test_listener_logs_call(caplog, create_task_instance, session):
- caplog.set_level(logging.DEBUG, logger="airflow.listeners.listener")
- lm = get_listener_manager()
- lm.add_listener(full_listener)
+def test_listener_logs_call(caplog, create_task_instance, session,
listener_manager):
+ caplog.set_level(logging.DEBUG,
logger="airflow.sdk._shared.listeners.listener")
+ listener_manager(full_listener)
ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED)
ti.run()
- listener_logs = [r for r in caplog.record_tuples if r[0] ==
"airflow.listeners.listener"]
- assert all(r[:-1] == ("airflow.listeners.listener", logging.DEBUG) for r
in listener_logs)
+ listener_logs = [r for r in caplog.record_tuples if r[0] ==
"airflow.sdk._shared.listeners.listener"]
+ assert all(r[:-1] == ("airflow.sdk._shared.listeners.listener",
logging.DEBUG) for r in listener_logs)
assert listener_logs[0][-1].startswith("Calling 'on_task_instance_running'
with {'")
assert listener_logs[1][-1].startswith("Hook impls: [<HookImpl plugin")
assert listener_logs[2][-1] == "Result from 'on_task_instance_running': []"
diff --git a/devel-common/src/tests_common/pytest_plugin.py
b/devel-common/src/tests_common/pytest_plugin.py
index b7e79451a85..9dd6e782193 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -2885,3 +2885,43 @@ def mock_task_instance():
return mock_ti
return _create_mock_task_instance
+
+
+@pytest.fixture
+def listener_manager():
+ """
+ Fixture that provides a listener manager for tests.
+
+ This fixture registers listeners with both the core listener manager
+ (used by Jobs, DAG runs, etc.) and the SDK listener manager (used by
+ task execution). This ensures listeners work correctly regardless of
+ which code path calls them.
+
+ Usage:
+ def test_something(listener_manager):
+ listener_manager(full_listener)
+ """
+ from airflow.listeners.listener import get_listener_manager as get_core_lm
+
+ try:
+ from airflow.sdk.listener import get_listener_manager as get_sdk_lm
+ except ImportError:
+ get_sdk_lm = None
+
+ core_lm = get_core_lm()
+ sdk_lm = get_sdk_lm() if get_sdk_lm else None
+
+ core_lm.clear()
+ if sdk_lm:
+ sdk_lm.clear()
+
+ def add_listener(listener):
+ core_lm.add_listener(listener)
+ if sdk_lm:
+ sdk_lm.add_listener(listener)
+
+ yield add_listener
+
+ core_lm.clear()
+ if sdk_lm:
+ sdk_lm.clear()
diff --git a/providers/common/compat/src/airflow/providers/common/compat/sdk.py
b/providers/common/compat/src/airflow/providers/common/compat/sdk.py
index 67670f85886..e29bfb1b42f 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/sdk.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/sdk.py
@@ -68,6 +68,7 @@ if TYPE_CHECKING:
task_group as task_group,
teardown as teardown,
)
+ from airflow.sdk._shared.listeners import hookimpl as hookimpl
from airflow.sdk.bases.decorator import (
DecoratedMappedOperator as DecoratedMappedOperator,
DecoratedOperator as DecoratedOperator,
@@ -92,6 +93,7 @@ if TYPE_CHECKING:
TaskDeferred as TaskDeferred,
XComNotFound as XComNotFound,
)
+ from airflow.sdk.listener import get_listener_manager as
get_listener_manager
from airflow.sdk.log import redact as redact
from airflow.sdk.observability.stats import Stats as Stats
from airflow.sdk.plugins_manager import AirflowPlugin as AirflowPlugin
@@ -257,6 +259,10 @@ _IMPORT_MAP: dict[str, str | tuple[str, ...]] = {
"airflow.utils.log.secrets_masker",
),
#
============================================================================
+ # Listeners
+ #
============================================================================
+ "hookimpl": ("airflow.sdk._shared.listeners", "airflow.listeners"),
+ "get_listener_manager": ("airflow.sdk.listener",
"airflow.listeners.listener"),
# Configuration
#
============================================================================
"conf": ("airflow.sdk.configuration", "airflow.configuration"),
diff --git
a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
index cb90f80d198..6a530c2f93a 100644
---
a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
+++
b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
@@ -28,9 +28,8 @@ import psutil
from openlineage.client.serde import Serde
from airflow import settings
-from airflow.listeners import hookimpl
from airflow.models import DagRun, TaskInstance
-from airflow.providers.common.compat.sdk import Stats, timeout, timezone
+from airflow.providers.common.compat.sdk import Stats, hookimpl, timeout,
timezone
from airflow.providers.openlineage import conf
from airflow.providers.openlineage.extractors import ExtractorManager,
OperatorLineage
from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter,
RunState
diff --git a/providers/openlineage/tests/system/openlineage/conftest.py
b/providers/openlineage/tests/system/openlineage/conftest.py
index 4b45c39ba41..aa3b87ad553 100644
--- a/providers/openlineage/tests/system/openlineage/conftest.py
+++ b/providers/openlineage/tests/system/openlineage/conftest.py
@@ -18,19 +18,14 @@ from __future__ import annotations
import pytest
-from airflow.listeners.listener import get_listener_manager
from airflow.providers.openlineage.plugins.listener import OpenLineageListener
from system.openlineage.transport.variable import VariableTransport
@pytest.fixture(autouse=True)
-def set_transport_variable():
- lm = get_listener_manager()
- lm.clear()
+def set_transport_variable(listener_manager):
listener = OpenLineageListener()
listener.adapter._client =
listener.adapter.get_or_create_openlineage_client()
listener.adapter._client.transport = VariableTransport({})
- lm.add_listener(listener)
- yield
- lm.clear()
+ listener_manager(listener)
diff --git
a/providers/openlineage/tests/unit/openlineage/plugins/test_execution.py
b/providers/openlineage/tests/unit/openlineage/plugins/test_execution.py
index 488fc675d0f..ffb32b6ff1d 100644
--- a/providers/openlineage/tests/unit/openlineage/plugins/test_execution.py
+++ b/providers/openlineage/tests/unit/openlineage/plugins/test_execution.py
@@ -27,7 +27,6 @@ from pathlib import Path
import pytest
from airflow.jobs.job import Job
-from airflow.listeners.listener import get_listener_manager
from airflow.models import TaskInstance
from airflow.providers.google.cloud.openlineage.utils import
get_from_nullable_chain
from airflow.providers.openlineage.plugins.listener import OpenLineageListener
@@ -75,22 +74,15 @@ with tempfile.TemporaryDirectory(prefix="venv") as tmp_dir:
def teardown_method(self):
clear_db_runs()
- @pytest.fixture(autouse=True)
- def clean_listener_manager(self):
- get_listener_manager().clear()
- yield
- get_listener_manager().clear()
-
- def setup_job(self, task_name, run_id):
+ def setup_job(self, task_name, run_id, listener_manager):
from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
dirpath = Path(tmp_dir)
if dirpath.exists():
shutil.rmtree(dirpath)
dirpath.mkdir(exist_ok=True, parents=True)
- lm = get_listener_manager()
listener = OpenLineageListener()
- lm.add_listener(listener)
+ listener_manager(listener)
dagbag = DagBag(
dag_folder=TEST_DAG_FOLDER,
@@ -116,10 +108,10 @@ with tempfile.TemporaryDirectory(prefix="venv") as
tmp_dir:
return job_runner.task_runner.return_code(timeout=60)
@conf_vars({("openlineage", "transport"): f'{{"type": "file",
"log_file_path": "{listener_path}"}}'})
- def test_not_stalled_task_emits_proper_lineage(self):
+ def test_not_stalled_task_emits_proper_lineage(self, listener_manager):
task_name = "execute_no_stall"
run_id = "test1"
- self.setup_job(task_name, run_id)
+ self.setup_job(task_name, run_id, listener_manager)
events = get_sorted_events(tmp_dir)
log.info(json.dumps(events, indent=2, sort_keys=True))
@@ -127,10 +119,10 @@ with tempfile.TemporaryDirectory(prefix="venv") as
tmp_dir:
assert has_value_in_events(events, ["inputs", "name"],
"on-complete")
@conf_vars({("openlineage", "transport"): f'{{"type": "file",
"log_file_path": "{listener_path}"}}'})
- def test_not_stalled_failing_task_emits_proper_lineage(self):
+ def test_not_stalled_failing_task_emits_proper_lineage(self,
listener_manager):
task_name = "execute_fail"
run_id = "test_failure"
- self.setup_job(task_name, run_id)
+ self.setup_job(task_name, run_id, listener_manager)
events = get_sorted_events(tmp_dir)
assert has_value_in_events(events, ["inputs", "name"], "on-start")
@@ -142,8 +134,10 @@ with tempfile.TemporaryDirectory(prefix="venv") as tmp_dir:
("openlineage", "execution_timeout"): "15",
}
)
- def test_short_stalled_task_emits_proper_lineage(self):
- self.setup_job("execute_short_stall",
"test_short_stalled_task_emits_proper_lineage")
+ def test_short_stalled_task_emits_proper_lineage(self,
listener_manager):
+ self.setup_job(
+ "execute_short_stall",
"test_short_stalled_task_emits_proper_lineage", listener_manager
+ )
events = get_sorted_events(tmp_dir)
assert has_value_in_events(events, ["inputs", "name"], "on-start")
assert has_value_in_events(events, ["inputs", "name"],
"on-complete")
@@ -154,18 +148,23 @@ with tempfile.TemporaryDirectory(prefix="venv") as
tmp_dir:
("openlineage", "execution_timeout"): "3",
}
)
- def
test_short_stalled_task_extraction_with_low_execution_is_killed_by_ol_timeout(self):
+ def
test_short_stalled_task_extraction_with_low_execution_is_killed_by_ol_timeout(
+ self, listener_manager
+ ):
self.setup_job(
"execute_short_stall",
"test_short_stalled_task_extraction_with_low_execution_is_killed_by_ol_timeout",
+ listener_manager,
)
events = get_sorted_events(tmp_dir)
assert has_value_in_events(events, ["inputs", "name"], "on-start")
assert not has_value_in_events(events, ["inputs", "name"],
"on-complete")
@conf_vars({("openlineage", "transport"): f'{{"type": "file",
"log_file_path": "{listener_path}"}}'})
- def test_mid_stalled_task_is_killed_by_ol_timeout(self):
- self.setup_job("execute_mid_stall",
"test_mid_stalled_task_is_killed_by_openlineage")
+ def test_mid_stalled_task_is_killed_by_ol_timeout(self,
listener_manager):
+ self.setup_job(
+ "execute_mid_stall",
"test_mid_stalled_task_is_killed_by_openlineage", listener_manager
+ )
events = get_sorted_events(tmp_dir)
assert has_value_in_events(events, ["inputs", "name"], "on-start")
assert not has_value_in_events(events, ["inputs", "name"],
"on-complete")
@@ -177,7 +176,7 @@ with tempfile.TemporaryDirectory(prefix="venv") as tmp_dir:
("core", "task_success_overtime"): "3",
}
)
- def test_success_overtime_kills_tasks(self):
+ def test_success_overtime_kills_tasks(self, listener_manager):
# This test checks whether LocalTaskJobRunner kills OL listener
which take
# longer time than permitted by core.task_success_overtime setting
from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
@@ -186,8 +185,7 @@ with tempfile.TemporaryDirectory(prefix="venv") as tmp_dir:
if dirpath.exists():
shutil.rmtree(dirpath)
dirpath.mkdir(exist_ok=True, parents=True)
- lm = get_listener_manager()
- lm.add_listener(OpenLineageListener())
+ listener_manager(OpenLineageListener())
dagbag = DagBag(
dag_folder=TEST_DAG_FOLDER,
diff --git a/pyproject.toml b/pyproject.toml
index 272c2560118..c684593f562 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1290,6 +1290,7 @@ dev = [
"apache-airflow-ctl-tests",
"apache-airflow-shared-configuration",
"apache-airflow-shared-dagnode",
+ "apache-airflow-shared-listeners",
"apache-airflow-shared-logging",
"apache-airflow-shared-module-loading",
"apache-airflow-shared-observability",
@@ -1347,6 +1348,7 @@ apache-airflow-providers = { workspace = true }
apache-aurflow-docker-stack = { workspace = true }
apache-airflow-shared-configuration = { workspace = true }
apache-airflow-shared-dagnode = { workspace = true }
+apache-airflow-shared-listeners = { workspace = true }
apache-airflow-shared-logging = { workspace = true }
apache-airflow-shared-module-loading = { workspace = true }
apache-airflow-shared-observability = { workspace = true }
@@ -1473,6 +1475,7 @@ members = [
"docker-stack-docs",
"shared/configuration",
"shared/dagnode",
+ "shared/listeners",
"shared/logging",
"shared/module_loading",
"shared/observability",
diff --git a/airflow-core/src/airflow/listeners/spec/lifecycle.py
b/shared/listeners/pyproject.toml
similarity index 51%
copy from airflow-core/src/airflow/listeners/spec/lifecycle.py
copy to shared/listeners/pyproject.toml
index c5e3bb52e4d..445db6848e4 100644
--- a/airflow-core/src/airflow/listeners/spec/lifecycle.py
+++ b/shared/listeners/pyproject.toml
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,30 +14,40 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
-
-from pluggy import HookspecMarker
-
-hookspec = HookspecMarker("airflow")
-
-
-@hookspec
-def on_starting(component):
- """
- Execute before Airflow component - jobs like scheduler, worker, or task
runner starts.
-
- It's guaranteed this will be called before any other plugin method.
-
- :param component: Component that calls this method
- """
-
-
-@hookspec
-def before_stopping(component):
- """
- Execute before Airflow component - jobs like scheduler, worker, or task
runner stops.
-
- It's guaranteed this will be called after any other plugin method.
- :param component: Component that calls this method
- """
+[project]
+name = "apache-airflow-shared-listeners"
+description = "Shared listeners code for Airflow distributions"
+version = "0.0"
+classifiers = [
+ "Private :: Do Not Upload",
+]
+
+dependencies = [
+ "pluggy>=1.5.0",
+ "structlog>=25.4.0",
+]
+
+[dependency-groups]
+dev = [
+ "apache-airflow-devel-common",
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/airflow_shared"]
+
+[tool.ruff]
+extend = "../../pyproject.toml"
+src = ["src"]
+
+[tool.ruff.lint.per-file-ignores]
+# Ignore Doc rules et al for anything outside of src (i.e. the tests)
+"!src/*" = ["D", "S101", "TRY002"]
+
+[tool.ruff.lint.flake8-tidy-imports]
+# Override the workspace level default
+ban-relative-imports = "parents"
diff --git a/airflow-core/src/airflow/listeners/__init__.py
b/shared/listeners/src/airflow_shared/listeners/__init__.py
similarity index 100%
copy from airflow-core/src/airflow/listeners/__init__.py
copy to shared/listeners/src/airflow_shared/listeners/__init__.py
diff --git a/airflow-core/src/airflow/listeners/listener.py
b/shared/listeners/src/airflow_shared/listeners/listener.py
similarity index 70%
copy from airflow-core/src/airflow/listeners/listener.py
copy to shared/listeners/src/airflow_shared/listeners/listener.py
index 08869f50947..d4b36c059d4 100644
--- a/airflow-core/src/airflow/listeners/listener.py
+++ b/shared/listeners/src/airflow_shared/listeners/listener.py
@@ -17,18 +17,15 @@
# under the License.
from __future__ import annotations
-import logging
-from functools import cache
from typing import TYPE_CHECKING
import pluggy
-
-from airflow.plugins_manager import integrate_listener_plugins
+import structlog
if TYPE_CHECKING:
from pluggy._hooks import _HookRelay
-log = logging.getLogger(__name__)
+log = structlog.get_logger(__name__)
def _before_hookcall(hook_name, hook_impls, kwargs):
@@ -41,24 +38,25 @@ def _after_hookcall(outcome, hook_name, hook_impls, kwargs):
class ListenerManager:
- """Manage listener registration and provides hook property for calling
them."""
+ """
+ Manage listener registration and provide a hook property for calling them.
- def __init__(self):
- from airflow.listeners.spec import (
- asset,
- dagrun,
- importerrors,
- lifecycle,
- taskinstance,
- )
+ This class provides the base infrastructure for the listener system. Components
+ wanting to register listeners should initialise their own ListenerManager and
+ register the hook specs relevant to that component using add_hookspecs.
+ """
+ def __init__(self):
self.pm = pluggy.PluginManager("airflow")
self.pm.add_hookcall_monitoring(_before_hookcall, _after_hookcall)
- self.pm.add_hookspecs(lifecycle)
- self.pm.add_hookspecs(dagrun)
- self.pm.add_hookspecs(asset)
- self.pm.add_hookspecs(taskinstance)
- self.pm.add_hookspecs(importerrors)
+
+ def add_hookspecs(self, spec_module) -> None:
+ """
+ Register hook specs from a module.
+
+ :param spec_module: A module containing functions decorated with
@hookspec.
+ """
+ self.pm.add_hookspecs(spec_module)
@property
def has_listeners(self) -> bool:
@@ -78,11 +76,3 @@ class ListenerManager:
"""Remove registered plugins."""
for plugin in self.pm.get_plugins():
self.pm.unregister(plugin)
-
-
-@cache
-def get_listener_manager() -> ListenerManager:
- """Get singleton listener manager."""
- _listener_manager = ListenerManager()
- integrate_listener_plugins(_listener_manager)
- return _listener_manager
diff --git a/airflow-core/src/airflow/listeners/__init__.py
b/shared/listeners/src/airflow_shared/listeners/spec/__init__.py
similarity index 87%
copy from airflow-core/src/airflow/listeners/__init__.py
copy to shared/listeners/src/airflow_shared/listeners/spec/__init__.py
index 87840b50e2f..13a83393a91 100644
--- a/airflow-core/src/airflow/listeners/__init__.py
+++ b/shared/listeners/src/airflow_shared/listeners/spec/__init__.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,8 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
-
-from pluggy import HookimplMarker
-
-hookimpl = HookimplMarker("airflow")
diff --git a/airflow-core/src/airflow/listeners/spec/lifecycle.py
b/shared/listeners/src/airflow_shared/listeners/spec/lifecycle.py
similarity index 100%
copy from airflow-core/src/airflow/listeners/spec/lifecycle.py
copy to shared/listeners/src/airflow_shared/listeners/spec/lifecycle.py
diff --git a/airflow-core/src/airflow/listeners/spec/taskinstance.py
b/shared/listeners/src/airflow_shared/listeners/spec/taskinstance.py
similarity index 90%
rename from airflow-core/src/airflow/listeners/spec/taskinstance.py
rename to shared/listeners/src/airflow_shared/listeners/spec/taskinstance.py
index 75b98e8a7b5..d3450d6b05a 100644
--- a/airflow-core/src/airflow/listeners/spec/taskinstance.py
+++ b/shared/listeners/src/airflow_shared/listeners/spec/taskinstance.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
from __future__ import annotations
from typing import TYPE_CHECKING
@@ -22,6 +23,7 @@ from typing import TYPE_CHECKING
from pluggy import HookspecMarker
if TYPE_CHECKING:
+ # These imports are for type checking only - no runtime dependency
from airflow.models.taskinstance import TaskInstance
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
from airflow.utils.state import TaskInstanceState
@@ -30,13 +32,17 @@ hookspec = HookspecMarker("airflow")
@hookspec
-def on_task_instance_running(previous_state: TaskInstanceState | None,
task_instance: RuntimeTaskInstance):
+def on_task_instance_running(
+ previous_state: TaskInstanceState | None,
+ task_instance: RuntimeTaskInstance | TaskInstance,
+):
"""Execute when task state changes to RUNNING. previous_state can be
None."""
@hookspec
def on_task_instance_success(
- previous_state: TaskInstanceState | None, task_instance:
RuntimeTaskInstance | TaskInstance
+ previous_state: TaskInstanceState | None,
+ task_instance: RuntimeTaskInstance | TaskInstance,
):
"""Execute when task state changes to SUCCESS. previous_state can be
None."""
diff --git a/airflow-core/src/airflow/listeners/__init__.py
b/shared/listeners/tests/conftest.py
similarity index 92%
copy from airflow-core/src/airflow/listeners/__init__.py
copy to shared/listeners/tests/conftest.py
index 87840b50e2f..8b61b1b99b9 100644
--- a/airflow-core/src/airflow/listeners/__init__.py
+++ b/shared/listeners/tests/conftest.py
@@ -17,6 +17,6 @@
# under the License.
from __future__ import annotations
-from pluggy import HookimplMarker
+import os
-hookimpl = HookimplMarker("airflow")
+os.environ["_AIRFLOW__AS_LIBRARY"] = "true"
diff --git a/airflow-core/src/airflow/listeners/__init__.py
b/shared/listeners/tests/listeners/__init__.py
similarity index 91%
copy from airflow-core/src/airflow/listeners/__init__.py
copy to shared/listeners/tests/listeners/__init__.py
index 87840b50e2f..03cb33c14c4 100644
--- a/airflow-core/src/airflow/listeners/__init__.py
+++ b/shared/listeners/tests/listeners/__init__.py
@@ -16,7 +16,3 @@
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
-
-from pluggy import HookimplMarker
-
-hookimpl = HookimplMarker("airflow")
diff --git a/shared/listeners/tests/listeners/test_listener_manager.py
b/shared/listeners/tests/listeners/test_listener_manager.py
new file mode 100644
index 00000000000..ebf360dade0
--- /dev/null
+++ b/shared/listeners/tests/listeners/test_listener_manager.py
@@ -0,0 +1,164 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow_shared.listeners import hookimpl
+from airflow_shared.listeners.listener import ListenerManager
+from airflow_shared.listeners.spec import lifecycle, taskinstance
+
+
+class TestListenerManager:
+ def test_initial_state_has_no_listeners(self):
+ """Test that a new ListenerManager has no listeners."""
+ lm = ListenerManager()
+ assert not lm.has_listeners
+ assert len(lm.pm.get_plugins()) == 0
+
+ def test_add_hookspecs_registers_hooks(self):
+ """Test that add_hookspecs makes hooks available."""
+ lm = ListenerManager()
+ lm.add_hookspecs(lifecycle)
+
+ # Verify lifecycle hooks are now available
+ assert hasattr(lm.hook, "on_starting")
+ assert hasattr(lm.hook, "before_stopping")
+
+ def test_add_multiple_hookspecs(self):
+ """Test that multiple hookspecs can be registered."""
+ lm = ListenerManager()
+ lm.add_hookspecs(lifecycle)
+ lm.add_hookspecs(taskinstance)
+
+ # Verify hooks from both specs are available
+ assert hasattr(lm.hook, "on_starting")
+ assert hasattr(lm.hook, "on_task_instance_running")
+
+ def test_add_listener(self):
+ """Test listener registration."""
+
+ class TestListener:
+ def __init__(self):
+ self.called = False
+
+ @hookimpl
+ def on_starting(self, component):
+ self.called = True
+
+ lm = ListenerManager()
+ lm.add_hookspecs(lifecycle)
+ listener = TestListener()
+ lm.add_listener(listener)
+
+ assert lm.has_listeners
+ assert lm.pm.is_registered(listener)
+
+ def test_duplicate_listener_registration(self):
+ """Test adding same listener twice doesn't duplicate."""
+
+ class TestListener:
+ @hookimpl
+ def on_starting(self, component):
+ pass
+
+ lm = ListenerManager()
+ lm.add_hookspecs(lifecycle)
+ listener = TestListener()
+ lm.add_listener(listener)
+ lm.add_listener(listener)
+
+ # Should only be registered once
+ assert len(lm.pm.get_plugins()) == 1
+
+ def test_clear_listeners(self):
+ """Test clearing listeners removes all registered listeners."""
+
+ class TestListener:
+ @hookimpl
+ def on_starting(self, component):
+ pass
+
+ lm = ListenerManager()
+ lm.add_hookspecs(lifecycle)
+ listener1 = TestListener()
+ listener2 = TestListener()
+ lm.add_listener(listener1)
+ lm.add_listener(listener2)
+
+ assert lm.has_listeners
+ assert len(lm.pm.get_plugins()) == 2
+
+ lm.clear()
+
+ assert not lm.has_listeners
+ assert len(lm.pm.get_plugins()) == 0
+
+ def test_hook_calling(self):
+ """Test hooks can be called and listeners receive them."""
+
+ class TestListener:
+ def __init__(self):
+ self.component_received = None
+
+ @hookimpl
+ def on_starting(self, component):
+ self.component_received = component
+
+ lm = ListenerManager()
+ lm.add_hookspecs(lifecycle)
+ listener = TestListener()
+ lm.add_listener(listener)
+
+ test_component = "test_component"
+ lm.hook.on_starting(component=test_component)
+
+ assert listener.component_received == test_component
+
+ def test_taskinstance_hooks(self):
+ """Test taskinstance hook specs work correctly."""
+
+ class TaskInstanceListener:
+ def __init__(self):
+ self.events = []
+
+ @hookimpl
+ def on_task_instance_running(self, previous_state, task_instance):
+ self.events.append(("running", task_instance))
+
+ @hookimpl
+ def on_task_instance_success(self, previous_state, task_instance):
+ self.events.append(("success", task_instance))
+
+ @hookimpl
+ def on_task_instance_failed(self, previous_state, task_instance,
error):
+ self.events.append(("failed", task_instance, error))
+
+ lm = ListenerManager()
+ lm.add_hookspecs(taskinstance)
+ listener = TaskInstanceListener()
+ lm.add_listener(listener)
+
+ mock_ti = "mock_task_instance"
+ lm.hook.on_task_instance_running(previous_state=None,
task_instance=mock_ti)
+ lm.hook.on_task_instance_success(previous_state=None,
task_instance=mock_ti)
+ lm.hook.on_task_instance_failed(previous_state=None,
task_instance=mock_ti, error="test error")
+
+ assert listener.events == [
+ ("running", mock_ti),
+ ("success", mock_ti),
+ ("failed", mock_ti, "test error"),
+ ]
diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml
index e1ece8f2831..fc989724391 100644
--- a/task-sdk/pyproject.toml
+++ b/task-sdk/pyproject.toml
@@ -77,6 +77,9 @@ dependencies = [
"packaging>=24.0",
"typing-extensions>=4.14.1",
# End of shared configuration dependencies
+ # Start of shared listeners dependencies
+ "pluggy>=1.5.0",
+ # End of shared listeners dependencies
# Start of shared module-loading dependencies
'importlib_metadata>=6.5;python_version<"3.12"',
"pathspec>=0.9.0",
@@ -123,10 +126,11 @@ path = "src/airflow/sdk/__init__.py"
"../shared/dagnode/src/airflow_shared/dagnode" =
"src/airflow/sdk/_shared/dagnode"
"../shared/logging/src/airflow_shared/logging" =
"src/airflow/sdk/_shared/logging"
"../shared/module_loading/src/airflow_shared/module_loading" =
"src/airflow/sdk/_shared/module_loading"
-"../shared/observability/src/airflow_shared/observability" =
"src/airflow/_shared/observability"
+"../shared/observability/src/airflow_shared/observability" =
"src/airflow/sdk/_shared/observability"
"../shared/secrets_backend/src/airflow_shared/secrets_backend" =
"src/airflow/sdk/_shared/secrets_backend"
"../shared/secrets_masker/src/airflow_shared/secrets_masker" =
"src/airflow/sdk/_shared/secrets_masker"
"../shared/timezones/src/airflow_shared/timezones" =
"src/airflow/sdk/_shared/timezones"
+"../shared/listeners/src/airflow_shared/listeners" =
"src/airflow/sdk/_shared/listeners"
"../shared/plugins_manager/src/airflow_shared/plugins_manager" =
"src/airflow/sdk/_shared/plugins_manager"
[tool.hatch.build.targets.wheel]
@@ -271,6 +275,7 @@ tmp_path_retention_policy = "failed"
shared_distributions = [
"apache-airflow-shared-configuration",
"apache-airflow-shared-dagnode",
+ "apache-airflow-shared-listeners",
"apache-airflow-shared-logging",
"apache-airflow-shared-module-loading",
"apache-airflow-shared-secrets-backend",
diff --git a/task-sdk/src/airflow/sdk/_shared/listeners
b/task-sdk/src/airflow/sdk/_shared/listeners
new file mode 120000
index 00000000000..fa274373206
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/_shared/listeners
@@ -0,0 +1 @@
+../../../../../shared/listeners/src/airflow_shared/listeners
\ No newline at end of file
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 814673d93fe..cb76097b9fe 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -40,7 +40,6 @@ from pydantic import AwareDatetime, ConfigDict, Field,
JsonValue, TypeAdapter
from airflow.dag_processing.bundles.base import BaseDagBundle,
BundleVersionLock
from airflow.dag_processing.bundles.manager import DagBundlesManager
-from airflow.listeners.listener import get_listener_manager
from airflow.sdk.api.client import get_hostname, getuser
from airflow.sdk.api.datamodels._generated import (
AssetProfile,
@@ -118,6 +117,7 @@ from airflow.sdk.execution_time.context import (
)
from airflow.sdk.execution_time.sentry import Sentry
from airflow.sdk.execution_time.xcom import XCom
+from airflow.sdk.listener import get_listener_manager
from airflow.sdk.observability.stats import Stats
from airflow.sdk.timezone import coerce_datetime
from airflow.triggers.base import BaseEventTrigger
diff --git a/airflow-core/src/airflow/listeners/spec/lifecycle.py
b/task-sdk/src/airflow/sdk/listener.py
similarity index 51%
rename from airflow-core/src/airflow/listeners/spec/lifecycle.py
rename to task-sdk/src/airflow/sdk/listener.py
index c5e3bb52e4d..62c36753ce3 100644
--- a/airflow-core/src/airflow/listeners/spec/lifecycle.py
+++ b/task-sdk/src/airflow/sdk/listener.py
@@ -17,28 +17,29 @@
# under the License.
from __future__ import annotations
-from pluggy import HookspecMarker
+from functools import cache
-hookspec = HookspecMarker("airflow")
+from airflow.sdk._shared.listeners.listener import ListenerManager
+from airflow.sdk._shared.listeners.spec import lifecycle, taskinstance
+from airflow.sdk.plugins_manager import integrate_listener_plugins
-@hookspec
-def on_starting(component):
+@cache
+def get_listener_manager() -> ListenerManager:
"""
- Execute before Airflow component - jobs like scheduler, worker, or task
runner starts.
+ Get a listener manager for task sdk.
- It's guaranteed this will be called before any other plugin method.
-
- :param component: Component that calls this method
+ Registers the following hook specs:
+ - lifecycle: on_starting, before_stopping
+ - taskinstance: on_task_instance_running, on_task_instance_success, etc.
"""
+ _listener_manager = ListenerManager()
+ _listener_manager.add_hookspecs(lifecycle)
+ _listener_manager.add_hookspecs(taskinstance)
-@hookspec
-def before_stopping(component):
- """
- Execute before Airflow component - jobs like scheduler, worker, or task
runner stops.
+ integrate_listener_plugins(_listener_manager) # type: ignore[arg-type]
+ return _listener_manager
- It's guaranteed this will be called after any other plugin method.
- :param component: Component that calls this method
- """
+__all__ = ["get_listener_manager", "ListenerManager"]
diff --git a/task-sdk/tests/conftest.py b/task-sdk/tests/conftest.py
index eadb2e7d594..c1ef3b72c92 100644
--- a/task-sdk/tests/conftest.py
+++ b/task-sdk/tests/conftest.py
@@ -164,14 +164,14 @@ def _disable_ol_plugin():
# 3.12+ issues a warning when os.fork happens. So for this plugin we
disable it
# And we load plugins when setting the priority_weight field
- import airflow.plugins_manager
+ import airflow.sdk.plugins_manager
- old = airflow.plugins_manager._get_plugins
- airflow.plugins_manager._get_plugins = lambda: ([], {})
+ old = airflow.sdk.plugins_manager._get_plugins
+ airflow.sdk.plugins_manager._get_plugins = lambda: ([], {})
yield
- airflow.plugins_manager._get_plugins = old
+ airflow.sdk.plugins_manager._get_plugins = old
@pytest.fixture(autouse=True)
diff --git a/task-sdk/tests/task_sdk/docs/test_public_api.py
b/task-sdk/tests/task_sdk/docs/test_public_api.py
index f02b9ccebf9..e7f653d76a2 100644
--- a/task-sdk/tests/task_sdk/docs/test_public_api.py
+++ b/task-sdk/tests/task_sdk/docs/test_public_api.py
@@ -60,6 +60,7 @@ def test_airflow_sdk_no_unexpected_exports():
"serde",
"observability",
"plugins_manager",
+ "listener",
}
unexpected = actual - public - ignore
assert not unexpected, f"Unexpected exports in airflow.sdk:
{sorted(unexpected)}"
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 4a5c54ba3a4..e6a179f7189 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -35,8 +35,6 @@ import pytest
from task_sdk import FAKE_BUNDLE
from uuid6 import uuid7
-from airflow.listeners import hookimpl
-from airflow.listeners.listener import get_listener_manager
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk import (
DAG,
@@ -48,6 +46,7 @@ from airflow.sdk import (
task as task_decorator,
timezone,
)
+from airflow.sdk._shared.listeners import hookimpl
from airflow.sdk.api.datamodels._generated import (
AssetProfile,
AssetResponse,
@@ -459,9 +458,9 @@ def test_defer_task_queue_assignment(
)
-def test_run_downstream_skipped(mocked_parse, create_runtime_ti,
mock_supervisor_comms):
+def test_run_downstream_skipped(mocked_parse, create_runtime_ti,
mock_supervisor_comms, listener_manager):
listener = TestTaskRunnerCallsListeners.CustomListener()
- get_listener_manager().add_listener(listener)
+ listener_manager(listener)
class CustomOperator(BaseOperator):
def execute(self, context):
@@ -3269,19 +3268,11 @@ class TestTaskRunnerCallsListeners:
self._add_outlet_events(context)
self.error = error
- @pytest.fixture(autouse=True)
- def clean_listener_manager(self):
- lm = get_listener_manager()
- lm.clear()
- yield
- lm = get_listener_manager()
- lm.clear()
-
def test_task_runner_calls_on_startup_before_stopping(
- self, make_ti_context, mocked_parse, mock_supervisor_comms
+ self, make_ti_context, mocked_parse, mock_supervisor_comms,
listener_manager
):
listener = self.CustomListener()
- get_listener_manager().add_listener(listener)
+ listener_manager(listener)
class CustomOperator(BaseOperator):
def execute(self, context):
@@ -3318,9 +3309,9 @@ class TestTaskRunnerCallsListeners:
finalize(runtime_ti, state, context, log)
assert isinstance(listener.component, TaskRunnerMarker)
- def test_task_runner_calls_listeners_success(self, mocked_parse,
mock_supervisor_comms):
+ def test_task_runner_calls_listeners_success(self, mocked_parse,
mock_supervisor_comms, listener_manager):
listener = self.CustomListener()
- get_listener_manager().add_listener(listener)
+ listener_manager(listener)
class CustomOperator(BaseOperator):
def execute(self, context):
@@ -3357,9 +3348,11 @@ class TestTaskRunnerCallsListeners:
AirflowException("oops"),
],
)
- def test_task_runner_calls_listeners_failed(self, mocked_parse,
mock_supervisor_comms, exception):
+ def test_task_runner_calls_listeners_failed(
+ self, mocked_parse, mock_supervisor_comms, exception, listener_manager
+ ):
listener = self.CustomListener()
- get_listener_manager().add_listener(listener)
+ listener_manager(listener)
class CustomOperator(BaseOperator):
def execute(self, context):
@@ -3389,9 +3382,9 @@ class TestTaskRunnerCallsListeners:
assert listener.state == [TaskInstanceState.RUNNING,
TaskInstanceState.FAILED]
assert listener.error == error
- def test_task_runner_calls_listeners_skipped(self, mocked_parse,
mock_supervisor_comms):
+ def test_task_runner_calls_listeners_skipped(self, mocked_parse,
mock_supervisor_comms, listener_manager):
listener = self.CustomListener()
- get_listener_manager().add_listener(listener)
+ listener_manager(listener)
class CustomOperator(BaseOperator):
def execute(self, context):
@@ -3420,10 +3413,12 @@ class TestTaskRunnerCallsListeners:
assert listener.state == [TaskInstanceState.RUNNING,
TaskInstanceState.SKIPPED]
- def test_listener_access_outlet_event_on_running_and_success(self,
mocked_parse, mock_supervisor_comms):
+ def test_listener_access_outlet_event_on_running_and_success(
+ self, mocked_parse, mock_supervisor_comms, listener_manager
+ ):
"""Test listener can access outlet events through invoking
get_template_context() while task running and success"""
listener = self.CustomOutletEventsListener()
- get_listener_manager().add_listener(listener)
+ listener_manager(listener)
test_asset = Asset("test-asset")
test_key = AssetUniqueKey(name="test-asset", uri="test-asset")
@@ -3480,10 +3475,12 @@ class TestTaskRunnerCallsListeners:
],
ids=["ValueError", "SystemExit", "AirflowException"],
)
- def test_listener_access_outlet_event_on_failed(self, mocked_parse,
mock_supervisor_comms, exception):
+ def test_listener_access_outlet_event_on_failed(
+ self, mocked_parse, mock_supervisor_comms, exception, listener_manager
+ ):
"""Test listener can access outlet events through invoking
get_template_context() while task failed"""
listener = self.CustomOutletEventsListener()
- get_listener_manager().add_listener(listener)
+ listener_manager(listener)
test_asset = Asset("test-asset")
test_key = AssetUniqueKey(name="test-asset", uri="test-asset")