This is an automated email from the ASF dual-hosted git repository.
shahar1 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 23fbf878220 Fix consistent return response from PubSubPullOperator (#66156)
23fbf878220 is described below
commit 23fbf8782209dd747a67327b63881db5894faa99
Author: Maksim <[email protected]>
AuthorDate: Tue May 5 17:17:46 2026 +0200
Fix consistent return response from PubSubPullOperator (#66156)
---
.../providers/google/cloud/operators/pubsub.py | 17 ++++-
.../providers/google/cloud/sensors/pubsub.py | 4 +-
.../unit/google/cloud/operators/test_pubsub.py | 79 ++++++++++++++++++++++
3 files changed, 97 insertions(+), 3 deletions(-)
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py b/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py
index be905677586..33a1b302786 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py
@@ -30,6 +30,7 @@ from functools import cached_property
from typing import TYPE_CHECKING, Any
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
+from google.cloud import pubsub_v1
from google.cloud.pubsub_v1.types import (
DeadLetterPolicy,
Duration,
@@ -56,6 +57,10 @@ if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage
+class PubSubMessageTransformException(Exception):
+ """Raise when messages failed to convert pubsub received format."""
+
+
class PubSubCreateTopicOperator(GoogleCloudBaseOperator):
"""
Create a PubSub topic.
@@ -871,12 +876,22 @@ class PubSubPullOperator(GoogleCloudBaseOperator):
if event["status"] == "success":
self.log.info("Sensor pulls messages: %s", event["message"])
messages_callback = self.messages_callback or self._default_message_callback
- _return_value = messages_callback(event["message"], context)
+ received_messages = self._convert_to_received_messages(event["message"])
+ _return_value = messages_callback(received_messages, context)
return _return_value
self.log.info("Sensor failed: %s", event["message"])
raise AirflowException(event["message"])
+ def _convert_to_received_messages(self, messages: Any) -> list[ReceivedMessage]:
+ try:
+ received_messages = [pubsub_v1.types.ReceivedMessage(msg) for msg in messages]
+ return received_messages
+ except Exception as e:
+ raise PubSubMessageTransformException(
+ f"Error converting triggerer event message back to received message format: {e}"
+ ) from e
+
def _default_message_callback(
self,
pulled_messages: list[ReceivedMessage],
diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/pubsub.py b/providers/google/src/airflow/providers/google/cloud/sensors/pubsub.py
index a6fe4db15f1..f138271b66e 100644
--- a/providers/google/src/airflow/providers/google/cloud/sensors/pubsub.py
+++ b/providers/google/src/airflow/providers/google/cloud/sensors/pubsub.py
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
from airflow.providers.common.compat.sdk import Context
-class PubSubMessageTransformException(AirflowException):
+class PubSubMessageTransformException(Exception):
"""Raise when messages failed to convert pubsub received format."""
@@ -200,7 +200,7 @@ class PubSubPullSensor(BaseSensorOperator):
except Exception as e:
raise PubSubMessageTransformException(
f"Error converting triggerer event message back to received message format: {e}"
- )
+ ) from e
def _default_message_callback(
self,
diff --git a/providers/google/tests/unit/google/cloud/operators/test_pubsub.py b/providers/google/tests/unit/google/cloud/operators/test_pubsub.py
index 7a0ce1ba02b..3537c5266db 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_pubsub.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_pubsub.py
@@ -22,6 +22,7 @@ from unittest import mock
import pytest
from google.api_core.gapic_v1.method import DEFAULT
+from google.cloud import pubsub_v1
from google.cloud.pubsub_v1.types import ReceivedMessage
from airflow.providers.common.compat.sdk import TaskDeferred
@@ -552,3 +553,81 @@ class TestPubSubPullOperator:
assert len(result.outputs) == 1
assert result.outputs[0].namespace == "pubsub"
assert result.outputs[0].name == f"subscription:{TEST_PROJECT}:{TEST_SUBSCRIPTION}"
+
+ @mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook")
+ def test_execute_complete_use_message_callback(self, mock_hook):
+ test_message = [
+ {
+ "ack_id": "UAYWLF1GSFE3GQhoUQ5PXiM_NSAoRRIJB08CKF15MU0sQVhwaFENGXJ9YHxrUxsDV0ECel1RGQdoTm11H4GglfRLQ1RrWBIHB01Vel5TEwxoX11wBnm4vPO6v8vgfwk9OpX-8tltO6ywsP9GZiM9XhJLLD5-LzlFQV5AEkwkDERJUytDCypYEU4EISE-MD5FU0Q",
+ "message": {
+ "data": "aGkgZnJvbSBjbG91ZCBjb25zb2xlIQ==",
+ "message_id": "12165864188103151",
+ "publish_time": "2024-08-28T11:49:50.962Z",
+ "attributes": {},
+ "ordering_key": "",
+ },
+ "delivery_attempt": 0,
+ }
+ ]
+
+ received_messages = [pubsub_v1.types.ReceivedMessage(msg) for msg in test_message]
+
+ messages_callback_return_value = "custom_message_from_callback"
+
+ def messages_callback(
+ pulled_messages: list[ReceivedMessage],
+ context: dict[str, Any],
+ ):
+ assert pulled_messages == received_messages
+
+ assert isinstance(context, dict)
+ for key in context.keys():
+ assert isinstance(key, str)
+
+ return messages_callback_return_value
+
+ operator = PubSubPullOperator(
+ task_id="test_task",
+ ack_messages=True,
+ project_id=TEST_PROJECT,
+ subscription=TEST_SUBSCRIPTION,
+ deferrable=True,
+ messages_callback=messages_callback,
+ )
+ mock_hook.return_value.pull.return_value = received_messages
+
+ with mock.patch.object(operator.log, "info") as mock_log_info:
+ resp = operator.execute_complete(context={}, event={"status": "success", "message": test_message})
+ mock_log_info.assert_called_with("Sensor pulls messages: %s", test_message)
+ assert resp == messages_callback_return_value
+
+ @mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook")
+ def test_execute_complete_use_default_message_callback(self, mock_hook):
+ test_message = [
+ {
+ "ack_id": "UAYWLF1GSFE3GQhoUQ5PXiM_NSAoRRIJB08CKF15MU0sQVhwaFENGXJ9YHxrUxsDV0ECel1RGQdoTm11H4GglfRLQ1RrWBIHB01Vel5TEwxoX11wBnm4vPO6v8vgfwk9OpX-8tltO6ywsP9GZiM9XhJLLD5-LzlFQV5AEkwkDERJUytDCypYEU4EISE-MD5FU0Q",
+ "message": {
+ "data": "aGkgZnJvbSBjbG91ZCBjb25zb2xlIQ==",
+ "message_id": "12165864188103151",
+ "publish_time": "2024-08-28T11:49:50.962Z",
+ "attributes": {},
+ "ordering_key": "",
+ },
+ "delivery_attempt": 0,
+ }
+ ]
+ received_messages = [pubsub_v1.types.ReceivedMessage(msg) for msg in test_message]
+
+ operator = PubSubPullOperator(
+ task_id="test_task",
+ ack_messages=True,
+ project_id=TEST_PROJECT,
+ subscription=TEST_SUBSCRIPTION,
+ deferrable=True,
+ )
+ mock_hook.return_value.pull.return_value = received_messages
+
+ with mock.patch.object(operator.log, "info") as mock_log_info:
+ resp = operator.execute_complete(context={}, event={"status": "success", "message": test_message})
+ mock_log_info.assert_called_with("Sensor pulls messages: %s", test_message)
+ assert resp == [ReceivedMessage.to_dict(m) for m in received_messages]