This is an automated email from the ASF dual-hosted git repository.

shahar1 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 23fbf878220 Fix consistent return response from PubSubPullOperator (#66156)
23fbf878220 is described below

commit 23fbf8782209dd747a67327b63881db5894faa99
Author: Maksim <[email protected]>
AuthorDate: Tue May 5 17:17:46 2026 +0200

    Fix consistent return response from PubSubPullOperator (#66156)
---
 .../providers/google/cloud/operators/pubsub.py     | 17 ++++-
 .../providers/google/cloud/sensors/pubsub.py       |  4 +-
 .../unit/google/cloud/operators/test_pubsub.py     | 79 ++++++++++++++++++++++
 3 files changed, 97 insertions(+), 3 deletions(-)

diff --git a/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py b/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py
index be905677586..33a1b302786 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py
@@ -30,6 +30,7 @@ from functools import cached_property
 from typing import TYPE_CHECKING, Any
 
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
+from google.cloud import pubsub_v1
 from google.cloud.pubsub_v1.types import (
     DeadLetterPolicy,
     Duration,
@@ -56,6 +57,10 @@ if TYPE_CHECKING:
     from airflow.providers.openlineage.extractors import OperatorLineage
 
 
+class PubSubMessageTransformException(Exception):
+    """Raise when messages failed to convert pubsub received format."""
+
+
 class PubSubCreateTopicOperator(GoogleCloudBaseOperator):
     """
     Create a PubSub topic.
@@ -871,12 +876,22 @@ class PubSubPullOperator(GoogleCloudBaseOperator):
         if event["status"] == "success":
             self.log.info("Sensor pulls messages: %s", event["message"])
             messages_callback = self.messages_callback or self._default_message_callback
-            _return_value = messages_callback(event["message"], context)
+            received_messages = self._convert_to_received_messages(event["message"])
+            _return_value = messages_callback(received_messages, context)
             return _return_value
 
         self.log.info("Sensor failed: %s", event["message"])
         raise AirflowException(event["message"])
 
+    def _convert_to_received_messages(self, messages: Any) -> list[ReceivedMessage]:
+        try:
+            received_messages = [pubsub_v1.types.ReceivedMessage(msg) for msg in messages]
+            return received_messages
+        except Exception as e:
+            raise PubSubMessageTransformException(
+                f"Error converting triggerer event message back to received message format: {e}"
+            ) from e
+
     def _default_message_callback(
         self,
         pulled_messages: list[ReceivedMessage],
diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/pubsub.py b/providers/google/src/airflow/providers/google/cloud/sensors/pubsub.py
index a6fe4db15f1..f138271b66e 100644
--- a/providers/google/src/airflow/providers/google/cloud/sensors/pubsub.py
+++ b/providers/google/src/airflow/providers/google/cloud/sensors/pubsub.py
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
     from airflow.providers.common.compat.sdk import Context
 
 
-class PubSubMessageTransformException(AirflowException):
+class PubSubMessageTransformException(Exception):
     """Raise when messages failed to convert pubsub received format."""
 
 
@@ -200,7 +200,7 @@ class PubSubPullSensor(BaseSensorOperator):
         except Exception as e:
             raise PubSubMessageTransformException(
                 f"Error converting triggerer event message back to received message format: {e}"
-            )
+            ) from e
 
     def _default_message_callback(
         self,
diff --git a/providers/google/tests/unit/google/cloud/operators/test_pubsub.py b/providers/google/tests/unit/google/cloud/operators/test_pubsub.py
index 7a0ce1ba02b..3537c5266db 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_pubsub.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_pubsub.py
@@ -22,6 +22,7 @@ from unittest import mock
 
 import pytest
 from google.api_core.gapic_v1.method import DEFAULT
+from google.cloud import pubsub_v1
 from google.cloud.pubsub_v1.types import ReceivedMessage
 
 from airflow.providers.common.compat.sdk import TaskDeferred
@@ -552,3 +553,81 @@ class TestPubSubPullOperator:
         assert len(result.outputs) == 1
         assert result.outputs[0].namespace == "pubsub"
         assert result.outputs[0].name == f"subscription:{TEST_PROJECT}:{TEST_SUBSCRIPTION}"
+
+    @mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook")
+    def test_execute_complete_use_message_callback(self, mock_hook):
+        test_message = [
+            {
+                "ack_id": "UAYWLF1GSFE3GQhoUQ5PXiM_NSAoRRIJB08CKF15MU0sQVhwaFENGXJ9YHxrUxsDV0ECel1RGQdoTm11H4GglfRLQ1RrWBIHB01Vel5TEwxoX11wBnm4vPO6v8vgfwk9OpX-8tltO6ywsP9GZiM9XhJLLD5-LzlFQV5AEkwkDERJUytDCypYEU4EISE-MD5FU0Q",
+                "message": {
+                    "data": "aGkgZnJvbSBjbG91ZCBjb25zb2xlIQ==",
+                    "message_id": "12165864188103151",
+                    "publish_time": "2024-08-28T11:49:50.962Z",
+                    "attributes": {},
+                    "ordering_key": "",
+                },
+                "delivery_attempt": 0,
+            }
+        ]
+
+        received_messages = [pubsub_v1.types.ReceivedMessage(msg) for msg in test_message]
+
+        messages_callback_return_value = "custom_message_from_callback"
+
+        def messages_callback(
+            pulled_messages: list[ReceivedMessage],
+            context: dict[str, Any],
+        ):
+            assert pulled_messages == received_messages
+
+            assert isinstance(context, dict)
+            for key in context.keys():
+                assert isinstance(key, str)
+
+            return messages_callback_return_value
+
+        operator = PubSubPullOperator(
+            task_id="test_task",
+            ack_messages=True,
+            project_id=TEST_PROJECT,
+            subscription=TEST_SUBSCRIPTION,
+            deferrable=True,
+            messages_callback=messages_callback,
+        )
+        mock_hook.return_value.pull.return_value = received_messages
+
+        with mock.patch.object(operator.log, "info") as mock_log_info:
+            resp = operator.execute_complete(context={}, event={"status": "success", "message": test_message})
+        mock_log_info.assert_called_with("Sensor pulls messages: %s", test_message)
+        assert resp == messages_callback_return_value
+
+    @mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook")
+    def test_execute_complete_use_default_message_callback(self, mock_hook):
+        test_message = [
+            {
+                "ack_id": "UAYWLF1GSFE3GQhoUQ5PXiM_NSAoRRIJB08CKF15MU0sQVhwaFENGXJ9YHxrUxsDV0ECel1RGQdoTm11H4GglfRLQ1RrWBIHB01Vel5TEwxoX11wBnm4vPO6v8vgfwk9OpX-8tltO6ywsP9GZiM9XhJLLD5-LzlFQV5AEkwkDERJUytDCypYEU4EISE-MD5FU0Q",
+                "message": {
+                    "data": "aGkgZnJvbSBjbG91ZCBjb25zb2xlIQ==",
+                    "message_id": "12165864188103151",
+                    "publish_time": "2024-08-28T11:49:50.962Z",
+                    "attributes": {},
+                    "ordering_key": "",
+                },
+                "delivery_attempt": 0,
+            }
+        ]
+        received_messages = [pubsub_v1.types.ReceivedMessage(msg) for msg in test_message]
+
+        operator = PubSubPullOperator(
+            task_id="test_task",
+            ack_messages=True,
+            project_id=TEST_PROJECT,
+            subscription=TEST_SUBSCRIPTION,
+            deferrable=True,
+        )
+        mock_hook.return_value.pull.return_value = received_messages
+
+        with mock.patch.object(operator.log, "info") as mock_log_info:
+            resp = operator.execute_complete(context={}, event={"status": "success", "message": test_message})
+        mock_log_info.assert_called_with("Sensor pulls messages: %s", test_message)
+        assert resp == [ReceivedMessage.to_dict(m) for m in received_messages]

Reply via email to