Repository: incubator-airflow
Updated Branches:
  refs/heads/master 33b3f6dc6 -> 664521809


[AIRFLOW-1932] Add GCP Pub/Sub Pull and Ack

Adds the necessary hooks to support pulling and
acknowledging Pub/Sub messages. This is implemented
by adding a PubSubPullSensor operator that will
attempt to retrieve messages from a specified
subscription and will meet its criteria when one
or more messages are available. The configuration
allows those messages to be acknowledged
immediately. In addition, the messages are passed
to downstream tasks via the return value of the
operator's execute method.

An end-to-end example is included showing topic
and subscription
creation, parallel tasks to publish and pull
messages, and a downstream
chain to echo the contents of each message before
cleaning up.

Closes #2885 from prodonjs/airflow-1932-pr


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/66452180
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/66452180
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/66452180

Branch: refs/heads/master
Commit: 6645218092096e4b10fc737a62bacc2670e1d6dc
Parents: 33b3f6d
Author: Jason Prodonovich <[email protected]>
Authored: Wed Dec 20 22:25:12 2017 +0100
Committer: Bolke de Bruin <[email protected]>
Committed: Wed Dec 20 22:25:12 2017 +0100

----------------------------------------------------------------------
 .../contrib/example_dags/example_pubsub_flow.py |  81 +++++++++++++++
 airflow/contrib/hooks/gcp_pubsub_hook.py        |  93 ++++++++++++++---
 airflow/contrib/operators/pubsub_operator.py    |   7 +-
 airflow/contrib/sensors/pubsub_sensor.py        | 100 +++++++++++++++++++
 docs/code.rst                                   |   2 +
 tests/contrib/hooks/test_gcp_pubsub_hook.py     |  74 +++++++++++++-
 tests/contrib/operators/test_pubsub_operator.py |  17 +++-
 tests/contrib/sensors/test_pubsub_sensor.py     |  98 ++++++++++++++++++
 8 files changed, 449 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/66452180/airflow/contrib/example_dags/example_pubsub_flow.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/example_dags/example_pubsub_flow.py 
b/airflow/contrib/example_dags/example_pubsub_flow.py
new file mode 100644
index 0000000..c8843c8
--- /dev/null
+++ b/airflow/contrib/example_dags/example_pubsub_flow.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This example DAG demonstrates how the PubSub*Operators and PubSubPullSensor
+can be used to trigger dependent tasks upon receipt of a Pub/Sub message.
+
+NOTE: project_id must be updated to a GCP project ID accessible with the
+      Google Default Credentials on the machine running the workflow
+"""
+from __future__ import unicode_literals
+from base64 import b64encode
+
+import datetime
+
+from airflow import DAG
+from airflow.operators.bash_operator import BashOperator
+from airflow.contrib.operators.pubsub_operator import (
+    PubSubTopicCreateOperator, PubSubSubscriptionCreateOperator,
+    PubSubPublishOperator, PubSubTopicDeleteOperator,
+    PubSubSubscriptionDeleteOperator
+)
+from airflow.contrib.sensors.pubsub_sensor import PubSubPullSensor
+from airflow.utils import dates
+
+project = 'your-project-id'  # Change this to your own GCP project_id
+topic = 'example-topic'  # Cloud Pub/Sub topic
+subscription = 'subscription-to-example-topic'  # Cloud Pub/Sub subscription
+# Sample messages; base64 bytes are decoded to str so they JSON-serialize
+messages = [
+    {'data': b64encode(b'Hello World').decode()},
+    {'data': b64encode(b'Another message').decode()},
+    {'data': b64encode(b'A final message').decode()}
+]
+
+default_args = {
+    'owner': 'airflow',
+    'depends_on_past': False,
+    'start_date': dates.days_ago(2),
+    'email': ['[email protected]'],
+    'email_on_failure': False,
+    'email_on_retry': False,
+    'project': project,
+    'topic': topic,
+    'subscription': subscription,
+}
+
+
+echo_template = '''
+{% for m in task_instance.xcom_pull(task_ids='pull-messages') %}
+    echo "AckID: {{ m.get('ackId') }}, Base64-Encoded: {{ m.get('message') }}"
+{% endfor %}
+'''
+
+with DAG('pubsub-end-to-end', default_args=default_args,
+         schedule_interval=datetime.timedelta(days=1)) as dag:
+    t1 = PubSubTopicCreateOperator(task_id='create-topic')
+    t2 = PubSubSubscriptionCreateOperator(
+        task_id='create-subscription', topic_project=project,
+        subscription=subscription)
+    t3 = PubSubPublishOperator(
+        task_id='publish-messages', messages=messages)
+    t4 = PubSubPullSensor(task_id='pull-messages', ack_messages=True)
+    t5 = BashOperator(task_id='echo-pulled-messages',
+                      bash_command=echo_template)
+    t6 = PubSubSubscriptionDeleteOperator(task_id='delete-subscription')
+    t7 = PubSubTopicDeleteOperator(task_id='delete-topic')
+
+    t1 >> t2 >> t3
+    t2 >> t4 >> t5 >> t6 >> t7

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/66452180/airflow/contrib/hooks/gcp_pubsub_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_pubsub_hook.py 
b/airflow/contrib/hooks/gcp_pubsub_hook.py
index dc95d89..e45ad63 100644
--- a/airflow/contrib/hooks/gcp_pubsub_hook.py
+++ b/airflow/contrib/hooks/gcp_pubsub_hook.py
@@ -21,11 +21,11 @@ from airflow.contrib.hooks.gcp_api_base_hook import 
GoogleCloudBaseHook
 
 
 def _format_subscription(project, subscription):
-    return 'projects/%s/subscriptions/%s' % (project, subscription)
+    return 'projects/{}/subscriptions/{}'.format(project, subscription)
 
 
 def _format_topic(project, topic):
-    return 'projects/%s/topics/%s' % (project, topic)
+    return 'projects/{}/topics/{}'.format(project, topic)
 
 
 class PubSubException(Exception):
@@ -71,7 +71,7 @@ class PubSubHook(GoogleCloudBaseHook):
             request.execute()
         except errors.HttpError as e:
             raise PubSubException(
-                'Error publishing to topic %s' % full_topic, e)
+                'Error publishing to topic {}'.format(full_topic), e)
 
     def create_topic(self, project, topic, fail_if_exists=False):
         """Creates a Pub/Sub topic, if it does not already exist.
@@ -94,12 +94,13 @@ class PubSubHook(GoogleCloudBaseHook):
         except errors.HttpError as e:
             # Status code 409 indicates that the topic already exists.
             if str(e.resp['status']) == '409':
-                message = 'Topic already exists: %s' % full_topic
+                message = 'Topic already exists: {}'.format(full_topic)
                 self.log.warning(message)
                 if fail_if_exists:
                     raise PubSubException(message)
             else:
-                raise PubSubException('Error creating topic %s' % full_topic, 
e)
+                raise PubSubException(
+                    'Error creating topic {}'.format(full_topic), e)
 
     def delete_topic(self, project, topic, fail_if_not_exists=False):
         """Deletes a Pub/Sub topic if it exists.
@@ -120,12 +121,13 @@ class PubSubHook(GoogleCloudBaseHook):
         except errors.HttpError as e:
             # Status code 409 indicates that the topic was not found
             if str(e.resp['status']) == '404':
-                message = 'Topic does not exist: %s' % full_topic
+                message = 'Topic does not exist: {}'.format(full_topic)
                 self.log.warning(message)
                 if fail_if_not_exists:
                     raise PubSubException(message)
             else:
-                raise PubSubException('Error deleting topic %s' % full_topic, 
e)
+                raise PubSubException(
+                    'Error deleting topic {}'.format(full_topic), e)
 
     def create_subscription(self, topic_project, topic, subscription=None,
                             subscription_project=None, ack_deadline_secs=10,
@@ -158,7 +160,7 @@ class PubSubHook(GoogleCloudBaseHook):
         service = self.get_conn()
         full_topic = _format_topic(topic_project, topic)
         if not subscription:
-            subscription = 'sub-%s' % uuid4()
+            subscription = 'sub-{}'.format(uuid4())
         if not subscription_project:
             subscription_project = topic_project
         full_subscription = _format_subscription(subscription_project,
@@ -173,13 +175,15 @@ class PubSubHook(GoogleCloudBaseHook):
         except errors.HttpError as e:
             # Status code 409 indicates that the subscription already exists.
             if str(e.resp['status']) == '409':
-                message = 'Subscription already exists: %s' % full_subscription
+                message = 'Subscription already exists: {}'.format(
+                    full_subscription)
                 self.log.warning(message)
                 if fail_if_exists:
                     raise PubSubException(message)
             else:
                 raise PubSubException(
-                    'Error creating subscription %s' % full_subscription, e)
+                    'Error creating subscription {}'.format(full_subscription),
+                    e)
         return subscription
 
     def delete_subscription(self, project, subscription,
@@ -203,10 +207,73 @@ class PubSubHook(GoogleCloudBaseHook):
         except errors.HttpError as e:
             # Status code 404 indicates that the subscription was not found
             if str(e.resp['status']) == '404':
-                message = 'Subscription does not exist: %s' % full_subscription
+                message = 'Subscription does not exist: {}'.format(
+                    full_subscription)
                 self.log.warning(message)
                 if fail_if_not_exists:
                     raise PubSubException(message)
             else:
-                raise PubSubException('Error deleting subscription %s' %
-                                      full_subscription, e)
+                raise PubSubException(
+                    'Error deleting subscription {}'.format(full_subscription),
+                    e)
+
+    def pull(self, project, subscription, max_messages,
+             return_immediately=False):
+        """Pulls up to ``max_messages`` messages from Pub/Sub subscription.
+
+        :param project: the GCP project ID where the subscription exists
+        :type project: string
+        :param subscription: the Pub/Sub subscription name to pull from; do not
+            include the 'projects/{project}/subscriptions/' prefix.
+        :type subscription: string
+        :param max_messages: The maximum number of messages to return.
+        :type max_messages: int
+        :param return_immediately: If set, the Pub/Sub API will immediately
+            return if no messages are available. Otherwise, the request will
+            block for an undisclosed, but bounded period of time
+        :type return_immediately: bool
+        :return: A list of Pub/Sub ReceivedMessage objects each containing
+            an ``ackId`` property and a ``message`` property, which includes
+            the base64-encoded message content; empty if none were returned.
+            See https://cloud.google.com/pubsub/docs/reference/rest/v1/projects.subscriptions/pull#ReceivedMessage
+        :rtype: list
+        :raises PubSubException: if the pull request fails
+        """
+        service = self.get_conn()
+        full_subscription = _format_subscription(project, subscription)
+        body = {
+            'maxMessages': max_messages,
+            'returnImmediately': return_immediately
+        }
+        try:
+            response = service.projects().subscriptions().pull(
+                subscription=full_subscription, body=body).execute()
+        except errors.HttpError as e:
+            raise PubSubException(
+                'Error pulling messages from subscription {}'.format(
+                    full_subscription), e)
+        return response.get('receivedMessages', [])
+
+    def acknowledge(self, project, subscription, ack_ids):
+        """Acknowledges messages previously pulled from a subscription.
+
+        :param project: the GCP project ID in which the subscription exists
+        :type project: string
+        :param subscription: the Pub/Sub subscription name; do not include
+            the 'projects/{project}/subscriptions/' prefix.
+        :type subscription: string
+        :param ack_ids: List of ReceivedMessage ackIds from a previous pull
+            response
+        :type ack_ids: list
+        :raises PubSubException: if the acknowledge request fails
+        """
+        service = self.get_conn()
+        full_subscription = _format_subscription(project, subscription)
+        try:
+            service.projects().subscriptions().acknowledge(
+                subscription=full_subscription, body={'ackIds': ack_ids}
+            ).execute()
+        except errors.HttpError as e:
+            raise PubSubException(
+                'Error acknowledging {} messages pulled from subscription {}'
+                .format(len(ack_ids), full_subscription), e)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/66452180/airflow/contrib/operators/pubsub_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/pubsub_operator.py 
b/airflow/contrib/operators/pubsub_operator.py
index 7793787..40c863f 100644
--- a/airflow/contrib/operators/pubsub_operator.py
+++ b/airflow/contrib/operators/pubsub_operator.py
@@ -196,7 +196,7 @@ class PubSubSubscriptionCreateOperator(BaseOperator):
         hook = PubSubHook(gcp_conn_id=self.gcp_conn_id,
                           delegate_to=self.delegate_to)
 
-        hook.create_subscription(
+        return hook.create_subscription(
             self.topic_project, self.topic, self.subscription,
             self.subscription_project, self.ack_deadline_secs,
             self.fail_if_exists)
@@ -368,13 +368,12 @@ class PubSubPublishOperator(BaseOperator):
         m3 = {'attributes': {'foo': ''}}
 
         t1 = PubSubPublishOperator(
-            project='my-project',
-            topic='my_topic',
+            project='my-project', topic='my_topic',
             messages=[m1, m2, m3],
             create_topic=True,
             dag=dag)
 
-    ``project``, ``topic``, and ``messages`` are templated so you can use
+    ``project``, ``topic``, and ``messages`` are templated so you can use
     variables in them.
     """
     template_fields = ['project', 'topic', 'messages']

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/66452180/airflow/contrib/sensors/pubsub_sensor.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/sensors/pubsub_sensor.py 
b/airflow/contrib/sensors/pubsub_sensor.py
new file mode 100644
index 0000000..112f777
--- /dev/null
+++ b/airflow/contrib/sensors/pubsub_sensor.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from airflow.contrib.hooks.gcp_pubsub_hook import PubSubHook
+from airflow.operators.sensors import BaseSensorOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class PubSubPullSensor(BaseSensorOperator):
+    """Pulls messages from a PubSub subscription and passes them through XCom.
+
+    This sensor operator will pull up to ``max_messages`` messages from the
+    specified PubSub subscription. When the subscription returns messages,
+    the poke method's criteria will be fulfilled and the messages will be
+    returned from the operator and passed through XCom for downstream tasks.
+
+    If ``ack_messages`` is set to True, messages will be immediately
+    acknowledged before being returned, otherwise, downstream tasks will be
+    responsible for acknowledging them.
+
+    ``project`` and ``subscription`` are templated so you can use
+    variables in them.
+    """
+    template_fields = ['project', 'subscription']
+    ui_color = '#ff7f50'
+
+    @apply_defaults
+    def __init__(
+            self,
+            project,
+            subscription,
+            max_messages=5,
+            return_immediately=False,
+            ack_messages=False,
+            gcp_conn_id='google_cloud_default',
+            delegate_to=None,
+            *args,
+            **kwargs):
+        """
+        :param project: the GCP project ID for the subscription (templated)
+        :type project: string
+        :param subscription: the Pub/Sub subscription name (templated). Do not
+            include the full subscription path.
+        :type subscription: string
+        :param max_messages: The maximum number of messages to retrieve per
+            PubSub pull request
+        :type max_messages: int
+        :param return_immediately: If True, instruct the PubSub API to return
+            immediately if no messages are available for delivery.
+        :type return_immediately: bool
+        :param ack_messages: If True, each message will be acknowledged
+            immediately rather than by any downstream tasks
+        :type ack_messages: bool
+        :param gcp_conn_id: The connection ID to use connecting to
+            Google Cloud Platform.
+        :type gcp_conn_id: string
+        :param delegate_to: The account to impersonate, if any.
+            For this to work, the service account making the request
+            must have domain-wide delegation enabled.
+        :type delegate_to: string
+        """
+        super(PubSubPullSensor, self).__init__(*args, **kwargs)
+
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+        self.project = project
+        self.subscription = subscription
+        self.max_messages = max_messages
+        self.return_immediately = return_immediately
+        self.ack_messages = ack_messages
+
+        self._messages = None
+
+    def execute(self, context):
+        """Runs the poke loop, then returns the last pulled messages."""
+        super(PubSubPullSensor, self).execute(context)
+        return self._messages
+
+    def poke(self, context):
+        """Pulls messages once; a non-empty result satisfies the sensor."""
+        hook = PubSubHook(gcp_conn_id=self.gcp_conn_id,
+                          delegate_to=self.delegate_to)
+        self._messages = hook.pull(self.project, self.subscription,
+                                   self.max_messages, self.return_immediately)
+        if self._messages and self.ack_messages:
+            ack_ids = [m['ackId'] for m in self._messages if m.get('ackId')]
+            if ack_ids:  # never send an acknowledge request with no ackIds
+                hook.acknowledge(self.project, self.subscription, ack_ids)
+        return self._messages

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/66452180/docs/code.rst
----------------------------------------------------------------------
diff --git a/docs/code.rst b/docs/code.rst
index 021a05e..4dfdbfa 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -105,7 +105,9 @@ Community-contributed Operators
 .. autoclass:: airflow.contrib.operators.pubsub_operator.PubSubTopicCreateOperator
 .. autoclass:: airflow.contrib.operators.pubsub_operator.PubSubTopicDeleteOperator
 .. autoclass:: airflow.contrib.operators.pubsub_operator.PubSubSubscriptionCreateOperator
+.. autoclass:: airflow.contrib.operators.pubsub_operator.PubSubSubscriptionDeleteOperator
 .. autoclass:: airflow.contrib.operators.pubsub_operator.PubSubPublishOperator
+.. autoclass:: airflow.contrib.sensors.pubsub_sensor.PubSubPullSensor
 .. autoclass:: airflow.contrib.operators.QuboleOperator
 .. autoclass:: airflow.contrib.operators.hipchat_operator.HipChatAPIOperator
 .. autoclass:: airflow.contrib.operators.hipchat_operator.HipChatAPISendRoomNotificationOperator

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/66452180/tests/contrib/hooks/test_gcp_pubsub_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_pubsub_hook.py 
b/tests/contrib/hooks/test_gcp_pubsub_hook.py
index 7226618..9397e8a 100644
--- a/tests/contrib/hooks/test_gcp_pubsub_hook.py
+++ b/tests/contrib/hooks/test_gcp_pubsub_hook.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
 from __future__ import unicode_literals
 
 from base64 import b64encode as b64e
@@ -45,9 +46,9 @@ TEST_MESSAGES = [
     {'data': b64e(b'Knock, knock')},
     {'attributes': {'foo': ''}}]
 
-EXPANDED_TOPIC = 'projects/%s/topics/%s' % (TEST_PROJECT, TEST_TOPIC)
-EXPANDED_SUBSCRIPTION = 'projects/%s/subscriptions/%s' % (TEST_PROJECT,
-                                                          TEST_SUBSCRIPTION)
+EXPANDED_TOPIC = 'projects/{}/topics/{}'.format(TEST_PROJECT, TEST_TOPIC)
+EXPANDED_SUBSCRIPTION = 'projects/{}/subscriptions/{}'.format(
+    TEST_PROJECT, TEST_SUBSCRIPTION)
 
 
 def mock_init(self, gcp_conn_id, delegate_to=None):
@@ -242,3 +243,70 @@ class PubSubHookTest(unittest.TestCase):
                           .topics.return_value.publish)
         publish_method.assert_called_with(
             topic=EXPANDED_TOPIC, body={'messages': TEST_MESSAGES})
+
+    @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn'))
+    def test_pull(self, mock_service):
+        """pull() sends the expected request and returns receivedMessages."""
+        pull_method = (mock_service.return_value.projects.return_value
+                       .subscriptions.return_value.pull)
+        pulled_messages = [{'ackId': i, 'message': message}
+                           for i, message in enumerate(TEST_MESSAGES)]
+        pull_method.return_value.execute.return_value = {
+            'receivedMessages': pulled_messages}
+
+        response = self.pubsub_hook.pull(TEST_PROJECT, TEST_SUBSCRIPTION, 10)
+        pull_method.assert_called_with(
+            subscription=EXPANDED_SUBSCRIPTION,
+            body={'maxMessages': 10, 'returnImmediately': False})
+        self.assertEqual(pulled_messages, response)
+
+    @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn'))
+    def test_pull_no_messages(self, mock_service):
+        subscriptions = (mock_service.return_value.projects.return_value
+                         .subscriptions.return_value)
+        subscriptions.pull.return_value.execute.return_value = {
+            'receivedMessages': []}
+
+        result = self.pubsub_hook.pull(TEST_PROJECT, TEST_SUBSCRIPTION, 10)
+        subscriptions.pull.assert_called_with(
+            subscription=EXPANDED_SUBSCRIPTION,
+            body={'maxMessages': 10, 'returnImmediately': False})
+        self.assertListEqual([], result)
+
+    @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn'))
+    def test_pull_fails_on_exception(self, mock_service):
+        pull_method = (mock_service.return_value.projects.return_value
+                       .subscriptions.return_value.pull)
+        pull_method.return_value.execute.side_effect = HttpError(
+            resp={'status': '404'}, content=EMPTY_CONTENT)
+
+        with self.assertRaises(Exception):
+            self.pubsub_hook.pull(TEST_PROJECT, TEST_SUBSCRIPTION, 10)
+        pull_method.assert_called_with(
+            subscription=EXPANDED_SUBSCRIPTION,
+            body={'maxMessages': 10, 'returnImmediately': False})
+
+    @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn'))
+    def test_acknowledge(self, mock_service):
+        ack_ids = ['1', '2', '3']
+        subscriptions = (mock_service.return_value.projects.return_value
+                         .subscriptions.return_value)
+        self.pubsub_hook.acknowledge(TEST_PROJECT, TEST_SUBSCRIPTION, ack_ids)
+        subscriptions.acknowledge.assert_called_with(
+            subscription=EXPANDED_SUBSCRIPTION,
+            body={'ackIds': ack_ids})
+
+    @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn'))
+    def test_acknowledge_fails_on_exception(self, mock_service):
+        ack_method = (mock_service.return_value.projects.return_value
+                      .subscriptions.return_value.acknowledge)
+        ack_method.return_value.execute.side_effect = HttpError(
+            resp={'status': '404'}, content=EMPTY_CONTENT)
+
+        with self.assertRaises(Exception):
+            self.pubsub_hook.acknowledge(
+                TEST_PROJECT, TEST_SUBSCRIPTION, ['1', '2', '3'])
+        # Checked outside the with-block: code after the raise never runs.
+        ack_method.assert_called_with(
+            subscription=EXPANDED_SUBSCRIPTION,
+            body={'ackIds': ['1', '2', '3']})

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/66452180/tests/contrib/operators/test_pubsub_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_pubsub_operator.py 
b/tests/contrib/operators/test_pubsub_operator.py
index 560fe54..ec3d564 100644
--- a/tests/contrib/operators/test_pubsub_operator.py
+++ b/tests/contrib/operators/test_pubsub_operator.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from __future__ import unicode_literals
 
 from base64 import b64encode as b64e
@@ -41,6 +42,7 @@ TEST_MESSAGES = [
     },
     {'data': b64e(b'Knock, knock')},
     {'attributes': {'foo': ''}}]
+TEST_POKE_INTERVAL = 0  # NOTE(review): currently unreferenced in this module
 
 
 class PubSubTopicCreateOperatorTest(unittest.TestCase):
@@ -88,10 +90,13 @@ class 
PubSubSubscriptionCreateOperatorTest(unittest.TestCase):
         operator = PubSubSubscriptionCreateOperator(
             task_id=TASK_ID, topic_project=TEST_PROJECT, topic=TEST_TOPIC,
             subscription=TEST_SUBSCRIPTION)
-        operator.execute(None)
+        mock_hook.return_value.create_subscription.return_value = (
+            TEST_SUBSCRIPTION)
+        response = operator.execute(None)
         mock_hook.return_value.create_subscription.assert_called_once_with(
             TEST_PROJECT, TEST_TOPIC, TEST_SUBSCRIPTION, None,
             10, False)
+        self.assertEquals(response, TEST_SUBSCRIPTION)
 
     @mock.patch('airflow.contrib.operators.pubsub_operator.PubSubHook')
     def test_execute_different_project_ids(self, mock_hook):
@@ -100,18 +105,24 @@ class 
PubSubSubscriptionCreateOperatorTest(unittest.TestCase):
             task_id=TASK_ID, topic_project=TEST_PROJECT, topic=TEST_TOPIC,
             subscription=TEST_SUBSCRIPTION,
             subscription_project=another_project)
-        operator.execute(None)
+        mock_hook.return_value.create_subscription.return_value = (
+            TEST_SUBSCRIPTION)
+        response = operator.execute(None)
         mock_hook.return_value.create_subscription.assert_called_once_with(
             TEST_PROJECT, TEST_TOPIC, TEST_SUBSCRIPTION, another_project,
             10, False)
+        self.assertEquals(response, TEST_SUBSCRIPTION)
 
     @mock.patch('airflow.contrib.operators.pubsub_operator.PubSubHook')
     def test_execute_no_subscription(self, mock_hook):
         operator = PubSubSubscriptionCreateOperator(
             task_id=TASK_ID, topic_project=TEST_PROJECT, topic=TEST_TOPIC)
-        operator.execute(None)
+        mock_hook.return_value.create_subscription.return_value = (
+            TEST_SUBSCRIPTION)
+        response = operator.execute(None)
         mock_hook.return_value.create_subscription.assert_called_once_with(
             TEST_PROJECT, TEST_TOPIC, None, None, 10, False)
+        self.assertEquals(response, TEST_SUBSCRIPTION)
 
 
 class PubSubSubscriptionDeleteOperatorTest(unittest.TestCase):

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/66452180/tests/contrib/sensors/test_pubsub_sensor.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/test_pubsub_sensor.py 
b/tests/contrib/sensors/test_pubsub_sensor.py
new file mode 100644
index 0000000..ae59bb7
--- /dev/null
+++ b/tests/contrib/sensors/test_pubsub_sensor.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import unicode_literals
+
+from base64 import b64encode as b64e
+import unittest
+
+from airflow.contrib.sensors.pubsub_sensor import PubSubPullSensor
+from airflow.exceptions import AirflowSensorTimeout
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+TASK_ID = 'test-task-id'
+TEST_PROJECT = 'test-project'
+TEST_TOPIC = 'test-topic'  # not referenced by the sensor tests below
+TEST_SUBSCRIPTION = 'test-subscription'
+TEST_MESSAGES = [  # not referenced by the sensor tests below
+    {
+        'data': b64e(b'Hello, World!'),
+        'attributes': {'type': 'greeting'}
+    },
+    {'data': b64e(b'Knock, knock')},
+    {'attributes': {'foo': ''}}]
+
+
+class PubSubPullSensorTest(unittest.TestCase):
+
+    def _generate_messages(self, count):
+        return [
+            {
+                'ackId': '{}'.format(i),
+                'message': {
+                    'data': b64e('Message {}'.format(i).encode('utf8')),
+                    'attributes': {'type': 'generated message'}
+                }
+            }
+            for i in range(1, count + 1)
+        ]
+
+    @mock.patch('airflow.contrib.sensors.pubsub_sensor.PubSubHook')
+    def test_poke_no_messages(self, mock_hook):
+        operator = PubSubPullSensor(task_id=TASK_ID, project=TEST_PROJECT,
+                                    subscription=TEST_SUBSCRIPTION)
+        mock_hook.return_value.pull.return_value = []
+        self.assertEqual([], operator.poke(None))
+
+    @mock.patch('airflow.contrib.sensors.pubsub_sensor.PubSubHook')
+    def test_poke_with_ack_messages(self, mock_hook):
+        operator = PubSubPullSensor(task_id=TASK_ID, project=TEST_PROJECT,
+                                    subscription=TEST_SUBSCRIPTION,
+                                    ack_messages=True)
+        generated_messages = self._generate_messages(5)
+        mock_hook.return_value.pull.return_value = generated_messages
+        self.assertEqual(generated_messages, operator.poke(None))
+        mock_hook.return_value.acknowledge.assert_called_with(
+            TEST_PROJECT, TEST_SUBSCRIPTION, ['1', '2', '3', '4', '5']
+        )
+
+    @mock.patch('airflow.contrib.sensors.pubsub_sensor.PubSubHook')
+    def test_execute(self, mock_hook):
+        operator = PubSubPullSensor(task_id=TASK_ID, project=TEST_PROJECT,
+                                    subscription=TEST_SUBSCRIPTION,
+                                    poke_interval=0)
+        generated_messages = self._generate_messages(5)
+        mock_hook.return_value.pull.return_value = generated_messages
+        response = operator.execute(None)
+        mock_hook.return_value.pull.assert_called_with(
+            TEST_PROJECT, TEST_SUBSCRIPTION, 5, False)
+        self.assertEqual(response, generated_messages)
+
+    @mock.patch('airflow.contrib.sensors.pubsub_sensor.PubSubHook')
+    def test_execute_timeout(self, mock_hook):
+        operator = PubSubPullSensor(task_id=TASK_ID, project=TEST_PROJECT,
+                                    subscription=TEST_SUBSCRIPTION,
+                                    poke_interval=0, timeout=1)
+        mock_hook.return_value.pull.return_value = []
+        with self.assertRaises(AirflowSensorTimeout):
+            operator.execute(None)
+        mock_hook.return_value.pull.assert_called_with(
+            TEST_PROJECT, TEST_SUBSCRIPTION, 5, False)

Reply via email to