This is an automated email from the ASF dual-hosted git repository.
ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4b99ae89f8f bugfix: Enforce and document context injection into custom
callbacks (#62649)
4b99ae89f8f is described below
commit 4b99ae89f8f4b509d400f1cc296df9ed6c849b18
Author: D. Ferruzzi <[email protected]>
AuthorDate: Wed Mar 4 16:21:52 2026 -0800
bugfix: Enforce and document context injection into custom callbacks
(#62649)
* Enforce and document context injection into custom callbacks
---
airflow-core/docs/howto/deadline-alerts.rst | 7 +++++
.../src/airflow/executors/workloads/callback.py | 10 +++++--
airflow-core/src/airflow/models/callback.py | 12 ++++++++
airflow-core/src/airflow/triggers/callback.py | 11 ++++++--
airflow-core/tests/unit/models/test_callback.py | 33 ++++++++++++++++++++++
airflow-core/tests/unit/triggers/test_callback.py | 23 ++++++++++-----
6 files changed, 84 insertions(+), 12 deletions(-)
diff --git a/airflow-core/docs/howto/deadline-alerts.rst
b/airflow-core/docs/howto/deadline-alerts.rst
index ab1e9da5f69..643e17fc185 100644
--- a/airflow-core/docs/howto/deadline-alerts.rst
+++ b/airflow-core/docs/howto/deadline-alerts.rst
@@ -237,6 +237,13 @@ Triggerer's system path.
Nested callables are not currently supported.
* The Triggerer will need to be restarted when a callback is added or
changed in order to reload the file.
+.. note::
+    **Airflow context:** When a deadline is missed, Airflow automatically passes a ``context`` keyword argument
+    to the callback containing information about the Dag run and the deadline. To receive it,
+    accept ``**kwargs`` in your callback and access ``kwargs["context"]``, or add a named ``context``
+    parameter. Callbacks that don't need the context can omit it: Airflow passes ``context`` only when the
+    callable accepts a ``context`` parameter or ``**kwargs``, while all other configured kwargs are always passed. The ``context`` keyword is reserved and cannot be
+    used in the ``kwargs`` parameter of a ``Callback``; attempting to do so will raise a ``ValueError`` at Dag parse time.
A **custom asynchronous callback** might look like this:
diff --git a/airflow-core/src/airflow/executors/workloads/callback.py
b/airflow-core/src/airflow/executors/workloads/callback.py
index c15bb33fba7..2563f9a78f5 100644
--- a/airflow-core/src/airflow/executors/workloads/callback.py
+++ b/airflow-core/src/airflow/executors/workloads/callback.py
@@ -125,6 +125,8 @@ def execute_callback_workload(
:param log: Logger instance for recording execution
:return: Tuple of (success: bool, error_message: str | None)
"""
+ from airflow.models.callback import _accepts_context # circular import
+
callback_path = callback.data.get("path")
callback_kwargs = callback.data.get("kwargs", {})
@@ -137,15 +139,19 @@ def execute_callback_workload(
module_path, function_name = callback_path.rsplit(".", 1)
module = import_module(module_path)
callback_callable = getattr(module, function_name)
+ context = callback_kwargs.pop("context", None)
log.debug("Executing callback %s(%s)...", callback_path,
callback_kwargs)
# If the callback is a callable, call it. If it is a class,
instantiate it.
- result = callback_callable(**callback_kwargs)
+ # Rather than forcing all custom callbacks to accept context,
conditionally provide it only if supported.
+ if _accepts_context(callback_callable) and context is not None:
+ result = callback_callable(**callback_kwargs, context=context)
+ else:
+ result = callback_callable(**callback_kwargs)
# If the callback is a class then it is now instantiated and callable,
call it.
if callable(result):
- context = callback_kwargs.get("context", {})
log.debug("Calling result with context for %s", callback_path)
result = result(context)
diff --git a/airflow-core/src/airflow/models/callback.py
b/airflow-core/src/airflow/models/callback.py
index ea482ab7ba8..e2c46153a71 100644
--- a/airflow-core/src/airflow/models/callback.py
+++ b/airflow-core/src/airflow/models/callback.py
@@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations
+import inspect
+from collections.abc import Callable
from datetime import datetime
from enum import Enum
from importlib import import_module
@@ -50,6 +52,16 @@ ACTIVE_STATES = frozenset((CallbackState.PENDING,
CallbackState.QUEUED, Callback
TERMINAL_STATES = frozenset((CallbackState.SUCCESS, CallbackState.FAILED))
+def _accepts_context(callback: Callable) -> bool:
+ """Check if callback accepts a 'context' parameter or **kwargs."""
+ try:
+ sig = inspect.signature(callback)
+ except (ValueError, TypeError):
+ return True
+ params = sig.parameters
+ return "context" in params or any(p.kind == inspect.Parameter.VAR_KEYWORD
for p in params.values())
+
+
class CallbackType(str, Enum):
"""
Types of Callbacks.
diff --git a/airflow-core/src/airflow/triggers/callback.py
b/airflow-core/src/airflow/triggers/callback.py
index aadfffe38cc..9c2470c77ea 100644
--- a/airflow-core/src/airflow/triggers/callback.py
+++ b/airflow-core/src/airflow/triggers/callback.py
@@ -23,7 +23,7 @@ from collections.abc import AsyncIterator
from typing import Any
from airflow._shared.module_loading import import_string, qualname
-from airflow.models.callback import CallbackState
+from airflow.models.callback import CallbackState, _accepts_context
from airflow.triggers.base import BaseTrigger, TriggerEvent
log = logging.getLogger(__name__)
@@ -52,9 +52,14 @@ class CallbackTrigger(BaseTrigger):
try:
yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.RUNNING})
callback = import_string(self.callback_path)
+ # TODO: get full context and run template rendering. Right now, a
simple context is included in `callback_kwargs`
+ context = self.callback_kwargs.pop("context", None)
+
+ if _accepts_context(callback) and context is not None:
+ result = await callback(**self.callback_kwargs,
context=context)
+ else:
+ result = await callback(**self.callback_kwargs)
- # TODO: get full context and run template rendering. Right now, a
simple context in included in `callback_kwargs`
- result = await callback(**self.callback_kwargs)
yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.SUCCESS,
PAYLOAD_BODY_KEY: result})
except Exception as e:
diff --git a/airflow-core/tests/unit/models/test_callback.py
b/airflow-core/tests/unit/models/test_callback.py
index 6ab6ad2d02d..20bbba29fc1 100644
--- a/airflow-core/tests/unit/models/test_callback.py
+++ b/airflow-core/tests/unit/models/test_callback.py
@@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations
+from unittest.mock import patch
+
import pytest
from sqlalchemy import select
@@ -26,6 +28,7 @@ from airflow.models.callback import (
CallbackState,
ExecutorCallback,
TriggererCallback,
+ _accepts_context,
)
from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback
from airflow.triggers.base import TriggerEvent
@@ -208,3 +211,33 @@ class TestExecutorCallback:
# Note: class DagProcessorCallback is tested in
airflow-core/tests/unit/dag_processing/test_manager.py
+
+
+class TestAcceptsContext:
+ def test_true_when_var_keyword_present(self):
+ def func_with_var_keyword(**kwargs):
+ pass
+
+ assert _accepts_context(func_with_var_keyword) is True
+
+ def test_true_when_context_param_present(self):
+ def func_with_context(context, alert_type):
+ pass
+
+ assert _accepts_context(func_with_context) is True
+
+ def test_false_when_no_context_or_var_keyword(self):
+ def func_without_context(a, b):
+ pass
+
+ assert _accepts_context(func_without_context) is False
+
+ def test_false_when_no_params(self):
+ def func_no_params():
+ pass
+
+ assert _accepts_context(func_no_params) is False
+
+ def test_true_for_uninspectable_callable(self):
+ with patch("airflow.models.callback.inspect.signature",
side_effect=ValueError):
+ assert _accepts_context(lambda: None) is True
diff --git a/airflow-core/tests/unit/triggers/test_callback.py
b/airflow-core/tests/unit/triggers/test_callback.py
index ca59ea735f8..99eca603323 100644
--- a/airflow-core/tests/unit/triggers/test_callback.py
+++ b/airflow-core/tests/unit/triggers/test_callback.py
@@ -28,7 +28,6 @@ from airflow.triggers.callback import PAYLOAD_BODY_KEY,
PAYLOAD_STATUS_KEY, Call
TEST_MESSAGE = "test_message"
TEST_CALLBACK_PATH = "classpath.test_callback"
TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE, "context": {"dag_run":
"test"}}
-TEST_TRIGGER = CallbackTrigger(callback_path=TEST_CALLBACK_PATH,
callback_kwargs=TEST_CALLBACK_KWARGS)
class ExampleAsyncNotifier(BaseNotifier):
@@ -46,6 +45,14 @@ class ExampleAsyncNotifier(BaseNotifier):
class TestCallbackTrigger:
+ @pytest.fixture
+ def trigger(self):
+ """Create a fresh trigger per test to avoid shared mutable state."""
+ return CallbackTrigger(
+ callback_path=TEST_CALLBACK_PATH,
+ callback_kwargs=dict(TEST_CALLBACK_KWARGS),
+ )
+
@pytest.fixture
def mock_import_string(self):
with mock.patch("airflow.triggers.callback.import_string") as m:
@@ -72,29 +79,30 @@ class TestCallbackTrigger:
}
@pytest.mark.asyncio
- async def test_run_success_with_async_function(self, mock_import_string):
+ async def test_run_success_with_async_function(self, trigger,
mock_import_string):
"""Test trigger handles async functions correctly."""
callback_return_value = "some value"
mock_callback = mock.AsyncMock(return_value=callback_return_value)
mock_import_string.return_value = mock_callback
- trigger_gen = TEST_TRIGGER.run()
+ trigger_gen = trigger.run()
running_event = await anext(trigger_gen)
assert running_event.payload[PAYLOAD_STATUS_KEY] ==
CallbackState.RUNNING
success_event = await anext(trigger_gen)
mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
+ # AsyncMock accepts **kwargs, so _accepts_context returns True and
context is passed through
mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS)
assert success_event.payload[PAYLOAD_STATUS_KEY] ==
CallbackState.SUCCESS
assert success_event.payload[PAYLOAD_BODY_KEY] == callback_return_value
@pytest.mark.asyncio
- async def test_run_success_with_notifier(self, mock_import_string):
+ async def test_run_success_with_notifier(self, trigger,
mock_import_string):
"""Test trigger handles async notifier classes correctly."""
mock_import_string.return_value = ExampleAsyncNotifier
- trigger_gen = TEST_TRIGGER.run()
+ trigger_gen = trigger.run()
running_event = await anext(trigger_gen)
assert running_event.payload[PAYLOAD_STATUS_KEY] ==
CallbackState.RUNNING
@@ -108,18 +116,19 @@ class TestCallbackTrigger:
)
@pytest.mark.asyncio
- async def test_run_failure(self, mock_import_string):
+ async def test_run_failure(self, trigger, mock_import_string):
exc_msg = "Something went wrong"
mock_callback = mock.AsyncMock(side_effect=RuntimeError(exc_msg))
mock_import_string.return_value = mock_callback
- trigger_gen = TEST_TRIGGER.run()
+ trigger_gen = trigger.run()
running_event = await anext(trigger_gen)
assert running_event.payload[PAYLOAD_STATUS_KEY] ==
CallbackState.RUNNING
failure_event = await anext(trigger_gen)
mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
+ # AsyncMock accepts **kwargs, so _accepts_context returns True and
context is passed through
mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS)
assert failure_event.payload[PAYLOAD_STATUS_KEY] ==
CallbackState.FAILED
assert all(s in failure_event.payload[PAYLOAD_BODY_KEY] for s in
["raise", "RuntimeError", exc_msg])