This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 11fb956a300 Add commit_offset option to AwaitMessageSensor and
AwaitMessageTrigger (#62916)
11fb956a300 is described below
commit 11fb956a3001c9e11a75e7aa402fda4734f6ac6e
Author: Filipp <[email protected]>
AuthorDate: Thu Mar 5 20:36:53 2026 +0300
Add commit_offset option to AwaitMessageSensor and AwaitMessageTrigger
(#62916)
---
.../src/airflow/providers/apache/kafka/sensors/kafka.py | 14 +++++++++++---
.../providers/apache/kafka/triggers/await_message.py | 17 +++++++++++++++--
.../unit/apache/kafka/triggers/test_await_message.py | 2 ++
.../tests/unit/apache/kafka/triggers/test_msg_queue.py | 1 +
4 files changed, 29 insertions(+), 5 deletions(-)
diff --git
a/providers/apache/kafka/src/airflow/providers/apache/kafka/sensors/kafka.py
b/providers/apache/kafka/src/airflow/providers/apache/kafka/sensors/kafka.py
index 42c249b6f58..05ad8259550 100644
--- a/providers/apache/kafka/src/airflow/providers/apache/kafka/sensors/kafka.py
+++ b/providers/apache/kafka/src/airflow/providers/apache/kafka/sensors/kafka.py
@@ -35,6 +35,7 @@ class AwaitMessageSensor(BaseSensorOperator):
- poll the Kafka topics for a message
- if no message returned, sleep
- process the message with provided callable and commit the message offset
+ - if commit_offset is True (default), commit the message offset after
processing
- if callable returns any data, raise a TriggerEvent with the return data
- else continue to next message
- return event (as default xcom or specific xcom key)
@@ -53,6 +54,9 @@ class AwaitMessageSensor(BaseSensorOperator):
:param poll_interval: How long the kafka consumer should sleep after
reaching the end of the Kafka log,
defaults to 5
:param xcom_push_key: the name of a key to push the returned message to,
defaults to None
+ :param commit_offset: Whether to commit the message offset after
processing.
+ If False, the offset is not committed by the sensor, allowing
downstream
+ tasks to commit it manually (e.g., after successful processing).
Defaults to True.
:param soft_fail: Set to true to mark the task as SKIPPED on failure
:param timeout: Time elapsed before the task times out and fails (in
seconds)
:param poke_interval: This parameter is inherited but not used in this
deferrable implementation
@@ -70,18 +74,20 @@ class AwaitMessageSensor(BaseSensorOperator):
"apply_function_args",
"apply_function_kwargs",
"kafka_config_id",
+ "commit_offset",
)
def __init__(
self,
topics: Sequence[str],
- apply_function: str,
+ apply_function: str | None,
kafka_config_id: str = "kafka_default",
apply_function_args: Sequence[Any] | None = None,
apply_function_kwargs: dict[Any, Any] | None = None,
poll_timeout: float = 1,
poll_interval: float = 5,
- xcom_push_key=None,
+ xcom_push_key: str | None = None,
+ commit_offset: bool = True,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -94,6 +100,7 @@ class AwaitMessageSensor(BaseSensorOperator):
self.poll_timeout = poll_timeout
self.poll_interval = poll_interval
self.xcom_push_key = xcom_push_key
+ self.commit_offset = commit_offset
def execute(self, context) -> Any:
self.defer(
@@ -105,6 +112,7 @@ class AwaitMessageSensor(BaseSensorOperator):
kafka_config_id=self.kafka_config_id,
poll_timeout=self.poll_timeout,
poll_interval=self.poll_interval,
+ commit_offset=self.commit_offset,
),
method_name="execute_complete",
)
@@ -163,7 +171,7 @@ class AwaitMessageTriggerFunctionSensor(BaseSensorOperator):
def __init__(
self,
topics: Sequence[str],
- apply_function: str,
+ apply_function: str | None,
event_triggered_function: Callable,
kafka_config_id: str = "kafka_default",
apply_function_args: Sequence[Any] | None = None,
diff --git
a/providers/apache/kafka/src/airflow/providers/apache/kafka/triggers/await_message.py
b/providers/apache/kafka/src/airflow/providers/apache/kafka/triggers/await_message.py
index 0e431b82897..755969bbc58 100644
---
a/providers/apache/kafka/src/airflow/providers/apache/kafka/triggers/await_message.py
+++
b/providers/apache/kafka/src/airflow/providers/apache/kafka/triggers/await_message.py
@@ -46,6 +46,10 @@ class AwaitMessageTrigger(BaseEventTrigger):
- if callable is provided and returns any data, raise a TriggerEvent
with the return data
- else raise a TriggerEvent with the original message
+ - by default, the message offset is committed after processing. This can be
+ disabled by setting ``commit_offset=False``, allowing manual offset
management
+ in downstream tasks.
+
:param kafka_config_id: The connection object to use, defaults to
"kafka_default"
:param topics: The topic (or topic regex) that should be searched for
messages
:param apply_function: the location of the function to apply to messages
for determination of matching
@@ -57,6 +61,10 @@ class AwaitMessageTrigger(BaseEventTrigger):
Kafka (seconds), defaults to 1
:param poll_interval: How long the trigger should sleep after reaching the
end of the Kafka log
(seconds), defaults to 5
+ :param commit_offset: Whether to commit the message offset after processing.
+ If set to False, the offset is not committed automatically, allowing
+ downstream tasks to handle offset committing manually (e.g., after
+ successful processing). Defaults to True.
"""
@@ -69,6 +77,7 @@ class AwaitMessageTrigger(BaseEventTrigger):
apply_function_kwargs: dict[Any, Any] | None = None,
poll_timeout: float = 1,
poll_interval: float = 5,
+ commit_offset: bool = True,
) -> None:
self.topics = topics
self.apply_function = apply_function
@@ -77,6 +86,7 @@ class AwaitMessageTrigger(BaseEventTrigger):
self.kafka_config_id = kafka_config_id
self.poll_timeout = poll_timeout
self.poll_interval = poll_interval
+ self.commit_offset = commit_offset
def serialize(self) -> tuple[str, dict[str, Any]]:
return (
@@ -89,6 +99,7 @@ class AwaitMessageTrigger(BaseEventTrigger):
"kafka_config_id": self.kafka_config_id,
"poll_timeout": self.poll_timeout,
"poll_interval": self.poll_interval,
+ "commit_offset": self.commit_offset,
},
)
@@ -122,9 +133,11 @@ class AwaitMessageTrigger(BaseEventTrigger):
else message.value().decode("utf-8")
)
if event:
- await async_commit(message=message, asynchronous=False)
+ if self.commit_offset:
+ await async_commit(message=message, asynchronous=False)
yield TriggerEvent(event)
break
else:
- await async_commit(message=message, asynchronous=False)
+ if self.commit_offset:
+ await async_commit(message=message, asynchronous=False)
await asyncio.sleep(self.poll_interval)
diff --git
a/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_await_message.py
b/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_await_message.py
index f08c17dfb5f..a746503c693 100644
---
a/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_await_message.py
+++
b/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_await_message.py
@@ -82,6 +82,7 @@ class TestTrigger:
apply_function_kwargs=dict(one=1, two=2),
poll_timeout=10,
poll_interval=5,
+ commit_offset=True,
)
assert isinstance(trigger, AwaitMessageTrigger)
@@ -97,6 +98,7 @@ class TestTrigger:
apply_function_kwargs=dict(one=1, two=2),
poll_timeout=10,
poll_interval=5,
+ commit_offset=True,
)
@pytest.mark.parametrize(
diff --git
a/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_msg_queue.py
b/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_msg_queue.py
index 151b5f48d32..372925e757d 100644
--- a/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_msg_queue.py
+++ b/providers/apache/kafka/tests/unit/apache/kafka/triggers/test_msg_queue.py
@@ -112,6 +112,7 @@ class TestMessageQueueTrigger:
"apply_function_kwargs": {"one": 1, "two": 2},
"poll_timeout": 10,
"poll_interval": 5,
+ "commit_offset": True,
}
@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Requires Airflow
3.0.+")