This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 11fb956a300 Add commit_offset option to AwaitMessageSensor and AwaitMessageTrigger (#62916)
11fb956a300 is described below

commit 11fb956a3001c9e11a75e7aa402fda4734f6ac6e
Author: Filipp <[email protected]>
AuthorDate: Thu Mar 5 20:36:53 2026 +0300

    Add commit_offset option to AwaitMessageSensor and AwaitMessageTrigger (#62916)
---
 .../src/airflow/providers/apache/kafka/sensors/kafka.py | 14 +++++++++++---
 .../providers/apache/kafka/triggers/await_message.py    | 17 +++++++++++++++--
 .../unit/apache/kafka/triggers/test_await_message.py    |  2 ++
 .../tests/unit/apache/kafka/triggers/test_msg_queue.py  |  1 +
 4 files changed, 29 insertions(+), 5 deletions(-)

diff --git a/providers/apache/kafka/src/airflow/providers/apache/kafka/sensors/kafka.py b/providers/apache/kafka/src/airflow/providers/apache/kafka/sensors/kafka.py
index 42c249b6f58..05ad8259550 100644
--- a/providers/apache/kafka/src/airflow/providers/apache/kafka/sensors/kafka.py
+++ b/providers/apache/kafka/src/airflow/providers/apache/kafka/sensors/kafka.py
@@ -35,6 +35,7 @@ class AwaitMessageSensor(BaseSensorOperator):
     - poll the Kafka topics for a message
     - if no message returned, sleep
     - process the message with provided callable and commit the message offset
+    - if commit_offset is True (default), commit the message offset after processing
     - if callable returns any data, raise a TriggerEvent with the return data
     - else continue to next message
     - return event (as default xcom or specific xcom key)
@@ -53,6 +54,9 @@ class AwaitMessageSensor(BaseSensorOperator):
     :param poll_interval: How long the kafka consumer should sleep after 
reaching the end of the Kafka log,
         defaults to 5
     :param xcom_push_key: the name of a key to push the returned message to, 
defaults to None
+    :param commit_offset: Whether to commit the message offset after processing.
+        If False, the offset is not committed by the sensor, allowing downstream
+        tasks to commit it manually (e.g., after successful processing). Defaults to True.
     :param soft_fail: Set to true to mark the task as SKIPPED on failure
     :param timeout: Time elapsed before the task times out and fails (in 
seconds)
     :param poke_interval: This parameter is inherited but not used in this 
deferrable implementation
@@ -70,18 +74,20 @@ class AwaitMessageSensor(BaseSensorOperator):
         "apply_function_args",
         "apply_function_kwargs",
         "kafka_config_id",
+        "commit_offset",
     )
 
     def __init__(
         self,
         topics: Sequence[str],
-        apply_function: str,
+        apply_function: str | None,
         kafka_config_id: str = "kafka_default",
         apply_function_args: Sequence[Any] | None = None,
         apply_function_kwargs: dict[Any, Any] | None = None,
         poll_timeout: float = 1,
         poll_interval: float = 5,
-        xcom_push_key=None,
+        xcom_push_key: str | None = None,
+        commit_offset: bool = True,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -94,6 +100,7 @@ class AwaitMessageSensor(BaseSensorOperator):
         self.poll_timeout = poll_timeout
         self.poll_interval = poll_interval
         self.xcom_push_key = xcom_push_key
+        self.commit_offset = commit_offset
 
     def execute(self, context) -> Any:
         self.defer(
@@ -105,6 +112,7 @@ class AwaitMessageSensor(BaseSensorOperator):
                 kafka_config_id=self.kafka_config_id,
                 poll_timeout=self.poll_timeout,
                 poll_interval=self.poll_interval,
+                commit_offset=self.commit_offset,
             ),
             method_name="execute_complete",
         )
@@ -163,7 +171,7 @@ class AwaitMessageTriggerFunctionSensor(BaseSensorOperator):
     def __init__(
         self,
         topics: Sequence[str],
-        apply_function: str,
+        apply_function: str | None,
         event_triggered_function: Callable,
         kafka_config_id: str = "kafka_default",
         apply_function_args: Sequence[Any] | None = None,
diff --git a/providers/apache/kafka/src/airflow/providers/apache/kafka/triggers/await_message.py b/providers/apache/kafka/src/airflow/providers/apache/kafka/triggers/await_message.py
index 0e431b82897..755969bbc58 100644
--- a/providers/apache/kafka/src/airflow/providers/apache/kafka/triggers/await_message.py
+++ b/providers/apache/kafka/src/airflow/providers/apache/kafka/triggers/await_message.py
@@ -46,6 +46,10 @@ class AwaitMessageTrigger(BaseEventTrigger):
         - if callable is provided and returns any data, raise a TriggerEvent 
with the return data
         - else raise a TriggerEvent with the original message
 
+    - by default, the message offset is committed after processing. This can be
+      disabled by setting ``commit_offset=False``, allowing manual offset management
+      in downstream tasks.
+
     :param kafka_config_id: The connection object to use, defaults to 
"kafka_default"
     :param topics: The topic (or topic regex) that should be searched for 
messages
     :param apply_function: the location of the function to apply to messages 
for determination of matching
@@ -57,6 +61,10 @@ class AwaitMessageTrigger(BaseEventTrigger):
         Kafka (seconds), defaults to 1
     :param poll_interval: How long the trigger should sleep after reaching the 
end of the Kafka log
         (seconds), defaults to 5
+    :param commit_offset: Whether to commit the message offset after poll.
+        If set to False, the offset is not committed automatically, allowing
+        downstream tasks to handle offset committing manually (e.g., after
+        successful processing). Defaults to True.
 
     """
 
@@ -69,6 +77,7 @@ class AwaitMessageTrigger(BaseEventTrigger):
         apply_function_kwargs: dict[Any, Any] | None = None,
         poll_timeout: float = 1,
         poll_interval: float = 5,
+        commit_offset: bool = True,
     ) -> None:
         self.topics = topics
         self.apply_function = apply_function
@@ -77,6 +86,7 @@ class AwaitMessageTrigger(BaseEventTrigger):
         self.kafka_config_id = kafka_config_id
         self.poll_timeout = poll_timeout
         self.poll_interval = poll_interval
+        self.commit_offset = commit_offset
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         return (
@@ -89,6 +99,7 @@ class AwaitMessageTrigger(BaseEventTrigger):
                 "kafka_config_id": self.kafka_config_id,
                 "poll_timeout": self.poll_timeout,
                 "poll_interval": self.poll_interval,
+                "commit_offset": self.commit_offset,
             },
         )
 
@@ -122,9 +133,11 @@ class AwaitMessageTrigger(BaseEventTrigger):
                     else message.value().decode("utf-8")
                 )
                 if event:
-                    await async_commit(message=message, asynchronous=False)
+                    if self.commit_offset:
+                        await async_commit(message=message, asynchronous=False)
                     yield TriggerEvent(event)
                     break
                 else:
-                    await async_commit(message=message, asynchronous=False)
+                    if self.commit_offset:
+                        await async_commit(message=message, asynchronous=False)
                     await asyncio.sleep(self.poll_interval)
diff --git a/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_await_message.py b/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_await_message.py
index f08c17dfb5f..a746503c693 100644
--- a/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_await_message.py
+++ b/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_await_message.py
@@ -82,6 +82,7 @@ class TestTrigger:
             apply_function_kwargs=dict(one=1, two=2),
             poll_timeout=10,
             poll_interval=5,
+            commit_offset=True,
         )
 
         assert isinstance(trigger, AwaitMessageTrigger)
@@ -97,6 +98,7 @@ class TestTrigger:
             apply_function_kwargs=dict(one=1, two=2),
             poll_timeout=10,
             poll_interval=5,
+            commit_offset=True,
         )
 
     @pytest.mark.parametrize(
diff --git a/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_msg_queue.py b/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_msg_queue.py
index 151b5f48d32..372925e757d 100644
--- a/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_msg_queue.py
+++ b/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_msg_queue.py
@@ -112,6 +112,7 @@ class TestMessageQueueTrigger:
             "apply_function_kwargs": {"one": 1, "two": 2},
             "poll_timeout": 10,
             "poll_interval": 5,
+            "commit_offset": True,
         }
 
     @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Requires Airflow 
3.0.+")

Reply via email to