This is an automated email from the ASF dual-hosted git repository.
altay pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new f792e2e Add helper functions for reading and writing to PubSub directly from Python (#9212)
f792e2e is described below
commit f792e2e46925ace3e0221ff6bf17fdede3383fbd
Author: Alexey Strokach <[email protected]>
AuthorDate: Wed Aug 7 17:16:18 2019 -0700
Add helper functions for reading and writing to PubSub directly from Python (#9212)
* Add helper functions for reading and writing to PubSub directly from
Python
These functions are helpful when writing tests and when working with
streaming pipelines interactively (e.g. inside a Jupyter notebook).
Notes:
- Not sure if apache_beam/testing/test_utils.py is a better place for the
helper functions than apache_beam/io/gcp/tests/utils.py?
- google.cloud.exceptions seems to have moved to
google.api_core.exceptions. Currently, google.cloud.exceptions re-imports some,
but not all, of the exceptions defined in google.api_core.exceptions.
---
sdks/python/apache_beam/io/gcp/tests/utils.py | 57 +++++-
sdks/python/apache_beam/io/gcp/tests/utils_test.py | 200 ++++++++++++++++++++-
2 files changed, 249 insertions(+), 8 deletions(-)
diff --git a/sdks/python/apache_beam/io/gcp/tests/utils.py b/sdks/python/apache_beam/io/gcp/tests/utils.py
index 68d3f43..4ed9af3 100644
--- a/sdks/python/apache_beam/io/gcp/tests/utils.py
+++ b/sdks/python/apache_beam/io/gcp/tests/utils.py
@@ -25,15 +25,16 @@ import random
import time
from apache_beam.io import filesystems
+from apache_beam.io.gcp.pubsub import PubsubMessage
from apache_beam.utils import retry
# Protect against environments where bigquery library is not available.
try:
+ from google.api_core import exceptions as gexc
from google.cloud import bigquery
- from google.cloud.exceptions import NotFound
except ImportError:
+ gexc = None
bigquery = None
- NotFound = None
class GcpTestIOError(retry.PermanentException):
@@ -98,7 +99,7 @@ def delete_bq_table(project, dataset_id, table_id):
table_ref = client.dataset(dataset_id).table(table_id)
try:
client.delete_table(table_ref)
- except NotFound:
+ except gexc.NotFound:
raise GcpTestIOError('BigQuery table does not exist: %s' % table_ref)
@@ -113,3 +114,53 @@ def delete_directory(directory):
"gs://mybucket/mydir/", "s3://...", ...)
"""
filesystems.FileSystems.delete([directory])
+
+
+def write_to_pubsub(pub_client,
+ topic_path,
+ messages,
+ with_attributes=False,
+ chunk_size=100,
+ delay_between_chunks=0.1):
+ for start in range(0, len(messages), chunk_size):
+ message_chunk = messages[start:start + chunk_size]
+ if with_attributes:
+ futures = [
+ pub_client.publish(topic_path, message.data, **message.attributes)
+ for message in message_chunk
+ ]
+ else:
+ futures = [
+ pub_client.publish(topic_path, message) for message in message_chunk
+ ]
+ for future in futures:
+ future.result()
+ time.sleep(delay_between_chunks)
+
+
+def read_from_pubsub(sub_client,
+ subscription_path,
+ with_attributes=False,
+ number_of_elements=None,
+ timeout=None):
+ if number_of_elements is None and timeout is None:
+ raise ValueError("Either number_of_elements or timeout must be specified.")
+ messages = []
+ start_time = time.time()
+
+ while ((number_of_elements is None or len(messages) < number_of_elements) and
+ (timeout is None or (time.time() - start_time) < timeout)):
+ try:
+ response = sub_client.pull(
+ subscription_path, max_messages=1000, retry=None, timeout=10)
+ except (gexc.RetryError, gexc.DeadlineExceeded):
+ continue
+ ack_ids = [msg.ack_id for msg in response.received_messages]
+ sub_client.acknowledge(subscription_path, ack_ids)
+ for msg in response.received_messages:
+ message = PubsubMessage._from_message(msg.message)
+ if with_attributes:
+ messages.append(message)
+ else:
+ messages.append(message.data)
+ return messages
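
A rough usage sketch of the two new helpers (not part of the commit; the
project, topic, and subscription names below are placeholders, and the
Pub/Sub topic and subscription are assumed to already exist):

    from google.cloud import pubsub

    from apache_beam.io.gcp.pubsub import PubsubMessage
    from apache_beam.io.gcp.tests import utils

    pub_client = pubsub.PublisherClient()
    sub_client = pubsub.SubscriberClient()
    topic_path = pub_client.topic_path('my-project', 'my-topic')
    subscription_path = sub_client.subscription_path('my-project', 'my-sub')

    # Publish raw byte payloads and read them back as bytes.
    utils.write_to_pubsub(pub_client, topic_path, [b'a', b'b', b'c'])
    print(utils.read_from_pubsub(
        sub_client, subscription_path, number_of_elements=3))

    # Publish PubsubMessage objects when attributes matter; read_from_pubsub
    # then returns PubsubMessage objects as well.
    utils.write_to_pubsub(
        pub_client, topic_path,
        [PubsubMessage(b'payload', {'key': 'value'})],
        with_attributes=True)
    print(utils.read_from_pubsub(
        sub_client, subscription_path, with_attributes=True,
        number_of_elements=1))
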
diff --git a/sdks/python/apache_beam/io/gcp/tests/utils_test.py b/sdks/python/apache_beam/io/gcp/tests/utils_test.py
index 8af7497..c9e96d1 100644
--- a/sdks/python/apache_beam/io/gcp/tests/utils_test.py
+++ b/sdks/python/apache_beam/io/gcp/tests/utils_test.py
@@ -24,16 +24,19 @@ import unittest
import mock
+from apache_beam.io.gcp.pubsub import PubsubMessage
from apache_beam.io.gcp.tests import utils
-from apache_beam.testing.test_utils import patch_retry
+from apache_beam.testing import test_utils
# Protect against environments where bigquery library is not available.
try:
+ from google.api_core import exceptions as gexc
from google.cloud import bigquery
- from google.cloud.exceptions import NotFound
+ from google.cloud import pubsub
except ImportError:
+ gexc = None
bigquery = None
- NotFound = None
+ pubsub = None
@unittest.skipIf(bigquery is None, 'Bigquery dependencies are not installed.')
@@ -41,7 +44,7 @@ except ImportError:
class UtilsTest(unittest.TestCase):
def setUp(self):
- patch_retry(self, utils)
+ test_utils.patch_retry(self, utils)
@mock.patch.object(bigquery, 'Dataset')
def test_create_bq_dataset(self, mock_dataset, mock_client):
@@ -68,7 +71,7 @@ class UtilsTest(unittest.TestCase):
def test_delete_table_fails_not_found(self, mock_client):
mock_client.return_value.dataset.return_value.table.return_value = (
'table_ref')
- mock_client.return_value.delete_table.side_effect = NotFound('test')
+ mock_client.return_value.delete_table.side_effect = gexc.NotFound('test')
with self.assertRaisesRegexp(Exception, r'does not exist:.*table_ref'):
utils.delete_bq_table('unused_project',
@@ -76,6 +79,193 @@ class UtilsTest(unittest.TestCase):
'unused_table')
[email protected](pubsub is None, 'GCP dependencies are not installed')
+class PubSubUtilTest(unittest.TestCase):
+
+ def test_write_to_pubsub(self):
+ mock_pubsub = mock.Mock()
+ topic_path = "project/fakeproj/topics/faketopic"
+ data = b'data'
+ utils.write_to_pubsub(mock_pubsub, topic_path, [data])
+ mock_pubsub.publish.assert_has_calls(
+ [mock.call(topic_path, data),
+ mock.call().result()])
+
+ def test_write_to_pubsub_with_attributes(self):
+ mock_pubsub = mock.Mock()
+ topic_path = "project/fakeproj/topics/faketopic"
+ data = b'data'
+ attributes = {'key': 'value'}
+ message = PubsubMessage(data, attributes)
+ utils.write_to_pubsub(
+ mock_pubsub, topic_path, [message], with_attributes=True)
+ mock_pubsub.publish.assert_has_calls(
+ [mock.call(topic_path, data, **attributes),
+ mock.call().result()])
+
+ def test_write_to_pubsub_delay(self):
+ number_of_elements = 2
+ chunk_size = 1
+ mock_pubsub = mock.Mock()
+ topic_path = "project/fakeproj/topics/faketopic"
+ data = b'data'
+ with mock.patch('apache_beam.io.gcp.tests.utils.time') as mock_time:
+ utils.write_to_pubsub(
+ mock_pubsub,
+ topic_path, [data] * number_of_elements,
+ chunk_size=chunk_size,
+ delay_between_chunks=123)
+ mock_time.sleep.assert_called_with(123)
+ mock_pubsub.publish.assert_has_calls(
+ [mock.call(topic_path, data),
+ mock.call().result()] * number_of_elements)
+
+ def test_write_to_pubsub_many_chunks(self):
+ number_of_elements = 83
+ chunk_size = 11
+ mock_pubsub = mock.Mock()
+ topic_path = "project/fakeproj/topics/faketopic"
+ data_list = [
+ 'data {}'.format(i).encode("utf-8") for i in range(number_of_elements)
+ ]
+ utils.write_to_pubsub(
+ mock_pubsub, topic_path, data_list, chunk_size=chunk_size)
+ call_list = []
+ for start in range(0, number_of_elements, chunk_size):
+ # Publish a batch of messages
+ call_list += [
+ mock.call(topic_path, data)
+ for data in data_list[start:start + chunk_size]
+ ]
+ # Wait for those messages to be received
+ call_list += [
+ mock.call().result() for _ in data_list[start:start + chunk_size]
+ ]
+ mock_pubsub.publish.assert_has_calls(call_list)
+
+ def test_read_from_pubsub(self):
+ mock_pubsub = mock.Mock()
+ subscription_path = "project/fakeproj/subscriptions/fakesub"
+ data = b'data'
+ ack_id = 'ack_id'
+ pull_response = test_utils.create_pull_response(
+ [test_utils.PullResponseMessage(data, ack_id=ack_id)])
+ mock_pubsub.pull.return_value = pull_response
+ output = utils.read_from_pubsub(
+ mock_pubsub, subscription_path, number_of_elements=1)
+ self.assertEqual([data], output)
+ mock_pubsub.acknowledge.assert_called_once_with(subscription_path,
+ [ack_id])
+
+ def test_read_from_pubsub_with_attributes(self):
+ mock_pubsub = mock.Mock()
+ subscription_path = "project/fakeproj/subscriptions/fakesub"
+ data = b'data'
+ ack_id = 'ack_id'
+ attributes = {'key': 'value'}
+ message = PubsubMessage(data, attributes)
+ pull_response = test_utils.create_pull_response(
+ [test_utils.PullResponseMessage(data, attributes, ack_id=ack_id)])
+ mock_pubsub.pull.return_value = pull_response
+ output = utils.read_from_pubsub(
+ mock_pubsub,
+ subscription_path,
+ with_attributes=True,
+ number_of_elements=1)
+ self.assertEqual([message], output)
+ mock_pubsub.acknowledge.assert_called_once_with(subscription_path,
+ [ack_id])
+
+ def test_read_from_pubsub_flaky(self):
+ number_of_elements = 10
+ mock_pubsub = mock.Mock()
+ subscription_path = "project/fakeproj/subscriptions/fakesub"
+ data = b'data'
+ ack_id = 'ack_id'
+ pull_response = test_utils.create_pull_response(
+ [test_utils.PullResponseMessage(data, ack_id=ack_id)])
+
+ class FlakyPullResponse(object):
+
+ def __init__(self, pull_response):
+ self.pull_response = pull_response
+ self._state = -1
+
+ def __call__(self, *args, **kwargs):
+ self._state += 1
+ if self._state % 3 == 0:
+ raise gexc.RetryError("", "")
+ if self._state % 3 == 1:
+ raise gexc.DeadlineExceeded("")
+ if self._state % 3 == 2:
+ return self.pull_response
+
+ mock_pubsub.pull.side_effect = FlakyPullResponse(pull_response)
+ output = utils.read_from_pubsub(
+ mock_pubsub, subscription_path, number_of_elements=number_of_elements)
+ self.assertEqual([data] * number_of_elements, output)
+ self._assert_ack_ids_equal(mock_pubsub, [ack_id] * number_of_elements)
+
+ def test_read_from_pubsub_many(self):
+ response_size = 33
+ number_of_elements = 100
+ mock_pubsub = mock.Mock()
+ subscription_path = "project/fakeproj/subscriptions/fakesub"
+ data_list = [
+ 'data {}'.format(i).encode("utf-8") for i in range(number_of_elements)
+ ]
+ attributes_list = [{
+ 'key': 'value {}'.format(i)
+ } for i in range(number_of_elements)]
+ ack_ids = ['ack_id_{}'.format(i) for i in range(number_of_elements)]
+ messages = [
+ PubsubMessage(data, attributes)
+ for data, attributes in zip(data_list, attributes_list)
+ ]
+ response_messages = [
+ test_utils.PullResponseMessage(data, attributes, ack_id=ack_id)
+ for data, attributes, ack_id in zip(data_list, attributes_list,
+ ack_ids)
+ ]
+
+ class SequentialPullResponse(object):
+
+ def __init__(self, response_messages, response_size):
+ self.response_messages = response_messages
+ self.response_size = response_size
+ self._index = 0
+
+ def __call__(self, *args, **kwargs):
+ start = self._index
+ self._index += self.response_size
+ response = test_utils.create_pull_response(
+ self.response_messages[start:start + self.response_size])
+ return response
+
+ mock_pubsub.pull.side_effect = SequentialPullResponse(
+ response_messages, response_size)
+ output = utils.read_from_pubsub(
+ mock_pubsub,
+ subscription_path,
+ with_attributes=True,
+ number_of_elements=number_of_elements)
+ self.assertEqual(messages, output)
+ self._assert_ack_ids_equal(mock_pubsub, ack_ids)
+
+ def test_read_from_pubsub_invalid_arg(self):
+ sub_client = mock.Mock()
+ subscription_path = "project/fakeproj/subscriptions/fakesub"
+ with self.assertRaisesRegexp(ValueError, "number_of_elements"):
+ utils.read_from_pubsub(sub_client, subscription_path)
+ with self.assertRaisesRegexp(ValueError, "number_of_elements"):
+ utils.read_from_pubsub(
+ sub_client, subscription_path, with_attributes=True)
+
+ def _assert_ack_ids_equal(self, mock_pubsub, ack_ids):
+ actual_ack_ids = [
+ ack_id for args_list in mock_pubsub.acknowledge.call_args_list
+ for ack_id in args_list[0][1]
+ ]
+ self.assertEqual(actual_ack_ids, ack_ids)
+
+
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()