This is an automated email from the ASF dual-hosted git repository.

ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fb3c8117fb3  Include simple context in triggerer async callback 
(#55241)
fb3c8117fb3 is described below

commit fb3c8117fb3e5482712d448d702c032b19b29039
Author: Ramit Kataria <[email protected]>
AuthorDate: Mon Sep 8 12:24:08 2025 -0700

     Include simple context in triggerer async callback (#55241)
    
    - Added a simple context dict in the kwargs
    - Set `context` as a reserved field for kwargs in callback definition
    
    Eventually, we should probably use the TaskSDK API in the triggerer to
    fetch the full context, but this solution covers most use cases for now.
---
 airflow-core/src/airflow/models/deadline.py        | 12 ++++++-
 airflow-core/src/airflow/triggers/deadline.py      | 11 +++---
 airflow-core/tests/unit/models/test_deadline.py    | 42 +++++++++++++---------
 airflow-core/tests/unit/triggers/test_deadline.py  | 11 +++---
 task-sdk/src/airflow/sdk/definitions/deadline.py   |  4 ++-
 .../tests/task_sdk/definitions/test_deadline.py    | 11 ++++++
 6 files changed, 62 insertions(+), 29 deletions(-)

diff --git a/airflow-core/src/airflow/models/deadline.py 
b/airflow-core/src/airflow/models/deadline.py
index f41fe648418..cfb99160a11 100644
--- a/airflow-core/src/airflow/models/deadline.py
+++ b/airflow-core/src/airflow/models/deadline.py
@@ -202,10 +202,20 @@ class Deadline(Base):
         """Handle a missed deadline by running the callback in the appropriate 
host and updating the `callback_state`."""
         from airflow.sdk.definitions.deadline import AsyncCallback, 
SyncCallback
 
+        def get_simple_context():
+            from airflow.api_fastapi.core_api.datamodels.dag_run import 
DAGRunResponse
+
+            # TODO: Use the TaskAPI from within Triggerer to fetch full 
context instead of sending this context
+            #  from the scheduler
+            return {
+                "dag_run": 
DAGRunResponse.model_validate(self.dagrun).model_dump(mode="json"),
+                "deadline": {"id": self.id, "deadline_time": 
self.deadline_time},
+            }
+
         if isinstance(self.callback, AsyncCallback):
             callback_trigger = DeadlineCallbackTrigger(
                 callback_path=self.callback.path,
-                callback_kwargs=self.callback.kwargs,
+                callback_kwargs=(self.callback.kwargs or {}) | {"context": 
get_simple_context()},
             )
             trigger_orm = Trigger.from_object(callback_trigger)
             session.add(trigger_orm)
diff --git a/airflow-core/src/airflow/triggers/deadline.py 
b/airflow-core/src/airflow/triggers/deadline.py
index bcff27fd1b2..bd8a665b9fc 100644
--- a/airflow-core/src/airflow/triggers/deadline.py
+++ b/airflow-core/src/airflow/triggers/deadline.py
@@ -49,15 +49,13 @@ class DeadlineCallbackTrigger(BaseTrigger):
         from airflow.models.deadline import DeadlineCallbackState  # to avoid 
cyclic imports
 
         try:
-            callback = import_string(self.callback_path)
             yield TriggerEvent({PAYLOAD_STATUS_KEY: 
DeadlineCallbackState.RUNNING})
+            callback = import_string(self.callback_path)
 
-            # TODO: get airflow context
-            context: dict = {}
-
-            result = await callback(**self.callback_kwargs, context=context)
-            log.info("Deadline callback completed with return value: %s", 
result)
+            # TODO: get full context and run template rendering. Right now, a 
simple context is included in `callback_kwargs`
+            result = await callback(**self.callback_kwargs)
             yield TriggerEvent({PAYLOAD_STATUS_KEY: 
DeadlineCallbackState.SUCCESS, PAYLOAD_BODY_KEY: result})
+
         except Exception as e:
             if isinstance(e, ImportError):
                 message = "Failed to import this deadline callback on the 
triggerer"
@@ -65,6 +63,7 @@ class DeadlineCallbackTrigger(BaseTrigger):
                 message = "Failed to run this deadline callback because it is 
not awaitable"
             else:
                 message = "An error occurred during execution of this deadline 
callback"
+
             log.exception("%s: %s; kwargs: %s\n%s", message, 
self.callback_path, self.callback_kwargs, e)
             yield TriggerEvent(
                 {
diff --git a/airflow-core/tests/unit/models/test_deadline.py 
b/airflow-core/tests/unit/models/test_deadline.py
index 1412b16bdbb..5e152935707 100644
--- a/airflow-core/tests/unit/models/test_deadline.py
+++ b/airflow-core/tests/unit/models/test_deadline.py
@@ -23,6 +23,7 @@ import pytest
 import time_machine
 from sqlalchemy.exc import SQLAlchemyError
 
+from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse
 from airflow.models import DagRun, Trigger
 from airflow.models.deadline import Deadline, DeadlineCallbackState, 
ReferenceModels, _fetch_from_db
 from airflow.providers.standard.operators.empty import EmptyOperator
@@ -160,16 +161,37 @@ class TestDeadline:
         )
 
     @pytest.mark.db_test
-    def test_handle_miss_async_callback(self, dagrun, deadline_orm, session):
+    @pytest.mark.parametrize(
+        "kwargs",
+        [
+            pytest.param(TEST_CALLBACK_KWARGS, id="non-empty kwargs"),
+            pytest.param(None, id="null kwargs"),
+        ],
+    )
+    def test_handle_miss_async_callback(self, dagrun, session, kwargs):
+        deadline_orm = Deadline(
+            deadline_time=DEFAULT_DATE,
+            callback=AsyncCallback(TEST_CALLBACK_PATH, kwargs),
+            dagrun_id=dagrun.id,
+        )
+        session.add(deadline_orm)
+        session.flush()
         deadline_orm.handle_miss(session=session)
         session.flush()
 
         assert deadline_orm.trigger_id is not None
-
         trigger = session.query(Trigger).filter(Trigger.id == 
deadline_orm.trigger_id).one()
         assert trigger is not None
+
         assert trigger.kwargs["callback_path"] == TEST_CALLBACK_PATH
-        assert trigger.kwargs["callback_kwargs"] == TEST_CALLBACK_KWARGS
+
+        trigger_kwargs = trigger.kwargs["callback_kwargs"]
+        context = trigger_kwargs.pop("context")
+        assert trigger_kwargs == (kwargs or {})
+
+        assert context["deadline"]["id"] == str(deadline_orm.id)
+        assert context["deadline"]["deadline_time"].timestamp() == 
deadline_orm.deadline_time.timestamp()
+        assert context["dag_run"] == 
DAGRunResponse.model_validate(dagrun).model_dump(mode="json")
 
     @pytest.mark.db_test
     def test_handle_miss_sync_callback(self, dagrun, session):
@@ -232,20 +254,6 @@ class TestDeadline:
         else:
             assert deadline_orm.callback_state == DeadlineCallbackState.QUEUED
 
-    def test_handle_miss_creates_trigger(self, dagrun, deadline_orm, session):
-        """Test that handle_miss creates a trigger with correct parameters."""
-        deadline_orm.handle_miss(session)
-        session.flush()
-
-        # Check trigger was created
-        trigger = session.query(Trigger).first()
-        assert trigger is not None
-        assert deadline_orm.trigger_id == trigger.id
-
-        # Check trigger has correct kwargs
-        assert trigger.kwargs["callback_path"] == TEST_CALLBACK_PATH
-        assert trigger.kwargs["callback_kwargs"] == TEST_CALLBACK_KWARGS
-
     def test_handle_miss_sets_callback_state(self, dagrun, deadline_orm, 
session):
         """Test that handle_miss sets the callback state to QUEUED."""
         deadline_orm.handle_miss(session)
diff --git a/airflow-core/tests/unit/triggers/test_deadline.py 
b/airflow-core/tests/unit/triggers/test_deadline.py
index 72bea33f188..955b6cb49c0 100644
--- a/airflow-core/tests/unit/triggers/test_deadline.py
+++ b/airflow-core/tests/unit/triggers/test_deadline.py
@@ -27,7 +27,7 @@ from airflow.triggers.deadline import PAYLOAD_BODY_KEY, 
PAYLOAD_STATUS_KEY, Dead
 
 TEST_MESSAGE = "test_message"
 TEST_CALLBACK_PATH = "classpath.test_callback_for_deadline"
-TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE}
+TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE, "context": {"dag_run": 
"test"}}
 TEST_TRIGGER = DeadlineCallbackTrigger(callback_path=TEST_CALLBACK_PATH, 
callback_kwargs=TEST_CALLBACK_KWARGS)
 
 
@@ -85,7 +85,7 @@ class TestDeadlineCallbackTrigger:
 
         success_event = await anext(trigger_gen)
         mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
-        mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS, 
context=mock.ANY)
+        mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS)
         assert success_event.payload[PAYLOAD_STATUS_KEY] == 
DeadlineCallbackState.SUCCESS
         assert success_event.payload[PAYLOAD_BODY_KEY] == callback_return_value
 
@@ -102,7 +102,10 @@ class TestDeadlineCallbackTrigger:
         success_event = await anext(trigger_gen)
         mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
         assert success_event.payload[PAYLOAD_STATUS_KEY] == 
DeadlineCallbackState.SUCCESS
-        assert success_event.payload[PAYLOAD_BODY_KEY] == f"Async 
notification: {TEST_MESSAGE}, context: {{}}"
+        assert (
+            success_event.payload[PAYLOAD_BODY_KEY]
+            == f"Async notification: {TEST_MESSAGE}, context: {{'dag_run': 
'test'}}"
+        )
 
     @pytest.mark.asyncio
     async def test_run_failure(self, mock_import_string):
@@ -117,6 +120,6 @@ class TestDeadlineCallbackTrigger:
 
         failure_event = await anext(trigger_gen)
         mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
-        mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS, 
context=mock.ANY)
+        mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS)
         assert failure_event.payload[PAYLOAD_STATUS_KEY] == 
DeadlineCallbackState.FAILED
         assert all(s in failure_event.payload[PAYLOAD_BODY_KEY] for s in 
["raise", "RuntimeError", exc_msg])
diff --git a/task-sdk/src/airflow/sdk/definitions/deadline.py 
b/task-sdk/src/airflow/sdk/definitions/deadline.py
index 966e2b926a6..46e5eeb7be2 100644
--- a/task-sdk/src/airflow/sdk/definitions/deadline.py
+++ b/task-sdk/src/airflow/sdk/definitions/deadline.py
@@ -120,8 +120,10 @@ class Callback(ABC):
     path: str
     kwargs: dict | None
 
-    def __init__(self, callback_callable: Callable | str, kwargs: dict | None 
= None):
+    def __init__(self, callback_callable: Callable | str, kwargs: dict[str, 
Any] | None = None):
         self.path = self.get_callback_path(callback_callable)
+        if kwargs and "context" in kwargs:
+            raise ValueError("context is a reserved kwarg for this class")
         self.kwargs = kwargs
 
     @classmethod
diff --git a/task-sdk/tests/task_sdk/definitions/test_deadline.py 
b/task-sdk/tests/task_sdk/definitions/test_deadline.py
index 8bb70a7fad2..654cc41e2b8 100644
--- a/task-sdk/tests/task_sdk/definitions/test_deadline.py
+++ b/task-sdk/tests/task_sdk/definitions/test_deadline.py
@@ -160,6 +160,17 @@ class TestDeadlineAlert:
 
 
 class TestCallback:
+    @pytest.mark.parametrize(
+        "subclass, callable",
+        [
+            pytest.param(AsyncCallback, 
empty_async_callback_for_deadline_tests, id="async"),
+            pytest.param(SyncCallback, empty_sync_callback_for_deadline_tests, 
id="sync"),
+        ],
+    )
+    def test_init_error_reserved_kwarg(self, subclass, callable):
+        with pytest.raises(ValueError, match="context is a reserved kwarg for 
this class"):
+            subclass(callable, {"context": None})
+
     @pytest.mark.parametrize(
         "callback_callable, expected_path",
         [

Reply via email to