This is an automated email from the ASF dual-hosted git repository.

ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b99ae89f8f bugfix: Enforce and document context injection into custom 
callbacks (#62649)
4b99ae89f8f is described below

commit 4b99ae89f8f4b509d400f1cc296df9ed6c849b18
Author: D. Ferruzzi <[email protected]>
AuthorDate: Wed Mar 4 16:21:52 2026 -0800

    bugfix: Enforce and document context injection into custom callbacks 
(#62649)
    
    * Enforce and document context injection into custom callbacks
---
 airflow-core/docs/howto/deadline-alerts.rst        |  7 +++++
 .../src/airflow/executors/workloads/callback.py    | 10 +++++--
 airflow-core/src/airflow/models/callback.py        | 12 ++++++++
 airflow-core/src/airflow/triggers/callback.py      | 11 ++++++--
 airflow-core/tests/unit/models/test_callback.py    | 33 ++++++++++++++++++++++
 airflow-core/tests/unit/triggers/test_callback.py  | 23 ++++++++++-----
 6 files changed, 84 insertions(+), 12 deletions(-)

diff --git a/airflow-core/docs/howto/deadline-alerts.rst 
b/airflow-core/docs/howto/deadline-alerts.rst
index ab1e9da5f69..643e17fc185 100644
--- a/airflow-core/docs/howto/deadline-alerts.rst
+++ b/airflow-core/docs/howto/deadline-alerts.rst
@@ -237,6 +237,13 @@ Triggerer's system path.
       Nested callables are not currently supported.
     * The Triggerer will need to be restarted when a callback is added or 
changed in order to reload the file.
 
+.. note::
+    **Airflow context:** When a deadline is missed, Airflow passes a ``context`` keyword argument to the
+    callback with details about the Dag run and the deadline. To receive it, accept ``**kwargs`` in
+    your callback and read ``kwargs["context"]``, or add a named ``context`` parameter. Callbacks that
+    don't need the context can omit it: Airflow passes ``context`` only to callables that accept it,
+    but other kwargs are always passed. The ``context`` keyword is reserved and cannot be used in the
+    ``kwargs`` parameter of a ``Callback``; doing so will raise a ``ValueError`` at Dag parse time.
 
 A **custom asynchronous callback** might look like this:
 
diff --git a/airflow-core/src/airflow/executors/workloads/callback.py 
b/airflow-core/src/airflow/executors/workloads/callback.py
index c15bb33fba7..2563f9a78f5 100644
--- a/airflow-core/src/airflow/executors/workloads/callback.py
+++ b/airflow-core/src/airflow/executors/workloads/callback.py
@@ -125,6 +125,8 @@ def execute_callback_workload(
     :param log: Logger instance for recording execution
     :return: Tuple of (success: bool, error_message: str | None)
     """
+    from airflow.models.callback import _accepts_context  # circular import
+
     callback_path = callback.data.get("path")
     callback_kwargs = callback.data.get("kwargs", {})
 
@@ -137,15 +139,19 @@ def execute_callback_workload(
         module_path, function_name = callback_path.rsplit(".", 1)
         module = import_module(module_path)
         callback_callable = getattr(module, function_name)
+        context = callback_kwargs.pop("context", None)
 
         log.debug("Executing callback %s(%s)...", callback_path, 
callback_kwargs)
 
         # If the callback is a callable, call it.  If it is a class, 
instantiate it.
-        result = callback_callable(**callback_kwargs)
+        # Rather than forcing all custom callbacks to accept context, 
conditionally provide it only if supported.
+        if _accepts_context(callback_callable) and context is not None:
+            result = callback_callable(**callback_kwargs, context=context)
+        else:
+            result = callback_callable(**callback_kwargs)
 
         # If the callback is a class then it is now instantiated and callable, 
call it.
         if callable(result):
-            context = callback_kwargs.get("context", {})
             log.debug("Calling result with context for %s", callback_path)
             result = result(context)
 
diff --git a/airflow-core/src/airflow/models/callback.py 
b/airflow-core/src/airflow/models/callback.py
index ea482ab7ba8..e2c46153a71 100644
--- a/airflow-core/src/airflow/models/callback.py
+++ b/airflow-core/src/airflow/models/callback.py
@@ -16,6 +16,8 @@
 # under the License.
 from __future__ import annotations
 
+import inspect
+from collections.abc import Callable
 from datetime import datetime
 from enum import Enum
 from importlib import import_module
@@ -50,6 +52,16 @@ ACTIVE_STATES = frozenset((CallbackState.PENDING, 
CallbackState.QUEUED, Callback
 TERMINAL_STATES = frozenset((CallbackState.SUCCESS, CallbackState.FAILED))
 
 
+def _accepts_context(callback: Callable) -> bool:
+    """Return True if ``context`` can be passed to ``callback`` by keyword (named or via ``**kwargs``)."""
+    try:
+        params = inspect.signature(callback).parameters
+    except (ValueError, TypeError):
+        return True  # Signature can't be introspected; assume ``context`` is accepted.
+    named = "context" in params and params["context"].kind != inspect.Parameter.POSITIONAL_ONLY
+    return named or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())
+
+
 class CallbackType(str, Enum):
     """
     Types of Callbacks.
diff --git a/airflow-core/src/airflow/triggers/callback.py 
b/airflow-core/src/airflow/triggers/callback.py
index aadfffe38cc..9c2470c77ea 100644
--- a/airflow-core/src/airflow/triggers/callback.py
+++ b/airflow-core/src/airflow/triggers/callback.py
@@ -23,7 +23,7 @@ from collections.abc import AsyncIterator
 from typing import Any
 
 from airflow._shared.module_loading import import_string, qualname
-from airflow.models.callback import CallbackState
+from airflow.models.callback import CallbackState, _accepts_context
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 log = logging.getLogger(__name__)
@@ -52,9 +52,14 @@ class CallbackTrigger(BaseTrigger):
         try:
             yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.RUNNING})
             callback = import_string(self.callback_path)
+            # TODO: get full context and run template rendering. Right now, a 
simple context is included in `callback_kwargs`
+            context = self.callback_kwargs.pop("context", None)
+
+            if _accepts_context(callback) and context is not None:
+                result = await callback(**self.callback_kwargs, 
context=context)
+            else:
+                result = await callback(**self.callback_kwargs)
 
-            # TODO: get full context and run template rendering. Right now, a 
simple context in included in `callback_kwargs`
-            result = await callback(**self.callback_kwargs)
             yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.SUCCESS, 
PAYLOAD_BODY_KEY: result})
 
         except Exception as e:
diff --git a/airflow-core/tests/unit/models/test_callback.py 
b/airflow-core/tests/unit/models/test_callback.py
index 6ab6ad2d02d..20bbba29fc1 100644
--- a/airflow-core/tests/unit/models/test_callback.py
+++ b/airflow-core/tests/unit/models/test_callback.py
@@ -16,6 +16,8 @@
 # under the License.
 from __future__ import annotations
 
+from unittest.mock import patch
+
 import pytest
 from sqlalchemy import select
 
@@ -26,6 +28,7 @@ from airflow.models.callback import (
     CallbackState,
     ExecutorCallback,
     TriggererCallback,
+    _accepts_context,
 )
 from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback
 from airflow.triggers.base import TriggerEvent
@@ -208,3 +211,33 @@ class TestExecutorCallback:
 
 
 # Note: class DagProcessorCallback is tested in 
airflow-core/tests/unit/dag_processing/test_manager.py
+
+
+class TestAcceptsContext:
+    def test_true_when_var_keyword_present(self):
+        def takes_arbitrary_kwargs(**extra):
+            return None
+
+        assert _accepts_context(takes_arbitrary_kwargs) is True
+
+    def test_true_when_context_param_present(self):
+        def takes_named_context(context, alert_type):
+            return None
+
+        assert _accepts_context(takes_named_context) is True
+
+    def test_false_when_no_context_or_var_keyword(self):
+        def takes_unrelated_args(a, b):
+            return None
+
+        assert _accepts_context(takes_unrelated_args) is False
+
+    def test_false_when_no_params(self):
+        def takes_nothing():
+            return None
+
+        assert _accepts_context(takes_nothing) is False
+
+    def test_true_for_uninspectable_callable(self):
+        with patch("airflow.models.callback.inspect.signature", side_effect=ValueError):
+            assert _accepts_context(lambda: None) is True
diff --git a/airflow-core/tests/unit/triggers/test_callback.py 
b/airflow-core/tests/unit/triggers/test_callback.py
index ca59ea735f8..99eca603323 100644
--- a/airflow-core/tests/unit/triggers/test_callback.py
+++ b/airflow-core/tests/unit/triggers/test_callback.py
@@ -28,7 +28,6 @@ from airflow.triggers.callback import PAYLOAD_BODY_KEY, 
PAYLOAD_STATUS_KEY, Call
 TEST_MESSAGE = "test_message"
 TEST_CALLBACK_PATH = "classpath.test_callback"
 TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE, "context": {"dag_run": 
"test"}}
-TEST_TRIGGER = CallbackTrigger(callback_path=TEST_CALLBACK_PATH, 
callback_kwargs=TEST_CALLBACK_KWARGS)
 
 
 class ExampleAsyncNotifier(BaseNotifier):
@@ -46,6 +45,14 @@ class ExampleAsyncNotifier(BaseNotifier):
 
 
 class TestCallbackTrigger:
+    @pytest.fixture
+    def trigger(self):
+        """Build a new CallbackTrigger per test; ``run()`` pops ``context`` from ``callback_kwargs``."""
+        return CallbackTrigger(
+            callback_kwargs={**TEST_CALLBACK_KWARGS},
+            callback_path=TEST_CALLBACK_PATH,
+        )
+
     @pytest.fixture
     def mock_import_string(self):
         with mock.patch("airflow.triggers.callback.import_string") as m:
@@ -72,29 +79,30 @@ class TestCallbackTrigger:
         }
 
     @pytest.mark.asyncio
-    async def test_run_success_with_async_function(self, mock_import_string):
+    async def test_run_success_with_async_function(self, trigger, 
mock_import_string):
         """Test trigger handles async functions correctly."""
         callback_return_value = "some value"
         mock_callback = mock.AsyncMock(return_value=callback_return_value)
         mock_import_string.return_value = mock_callback
 
-        trigger_gen = TEST_TRIGGER.run()
+        trigger_gen = trigger.run()
 
         running_event = await anext(trigger_gen)
         assert running_event.payload[PAYLOAD_STATUS_KEY] == 
CallbackState.RUNNING
 
         success_event = await anext(trigger_gen)
         mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
+        # AsyncMock accepts **kwargs, so _accepts_context returns True and 
context is passed through
         mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS)
         assert success_event.payload[PAYLOAD_STATUS_KEY] == 
CallbackState.SUCCESS
         assert success_event.payload[PAYLOAD_BODY_KEY] == callback_return_value
 
     @pytest.mark.asyncio
-    async def test_run_success_with_notifier(self, mock_import_string):
+    async def test_run_success_with_notifier(self, trigger, 
mock_import_string):
         """Test trigger handles async notifier classes correctly."""
         mock_import_string.return_value = ExampleAsyncNotifier
 
-        trigger_gen = TEST_TRIGGER.run()
+        trigger_gen = trigger.run()
 
         running_event = await anext(trigger_gen)
         assert running_event.payload[PAYLOAD_STATUS_KEY] == 
CallbackState.RUNNING
@@ -108,18 +116,19 @@ class TestCallbackTrigger:
         )
 
     @pytest.mark.asyncio
-    async def test_run_failure(self, mock_import_string):
+    async def test_run_failure(self, trigger, mock_import_string):
         exc_msg = "Something went wrong"
         mock_callback = mock.AsyncMock(side_effect=RuntimeError(exc_msg))
         mock_import_string.return_value = mock_callback
 
-        trigger_gen = TEST_TRIGGER.run()
+        trigger_gen = trigger.run()
 
         running_event = await anext(trigger_gen)
         assert running_event.payload[PAYLOAD_STATUS_KEY] == 
CallbackState.RUNNING
 
         failure_event = await anext(trigger_gen)
         mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
+        # AsyncMock accepts **kwargs, so _accepts_context returns True and 
context is passed through
         mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS)
         assert failure_event.payload[PAYLOAD_STATUS_KEY] == 
CallbackState.FAILED
         assert all(s in failure_event.payload[PAYLOAD_BODY_KEY] for s in 
["raise", "RuntimeError", exc_msg])

Reply via email to