Revision: 393
Author: bslatkin
Date: Fri Nov  5 17:47:09 2010
Log: hub: better data sharding and tests for subscription reconfirmation periodic mapper
http://code.google.com/p/pubsubhubbub/source/detail?r=393

Added:
 /trunk/hub/offline_jobs_test.py
Modified:
 /trunk/hub/main.py
 /trunk/hub/main_test.py
 /trunk/hub/offline_jobs.py
 /trunk/nonstandard/virtual_feed.py

=======================================
--- /dev/null
+++ /trunk/hub/offline_jobs_test.py     Fri Nov  5 17:47:09 2010
@@ -0,0 +1,252 @@
+#!/usr/bin/env python
+#
+# Copyright 2010 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for the offline_jobs module."""
+
+import datetime
+import logging
+logging.basicConfig(format='%(levelname)-8s %(filename)s] %(message)s')
+import time
+import unittest
+
+import testutil
+testutil.fix_path()
+
+from google.appengine.ext import db
+
+from mapreduce import context
+from mapreduce.lib import key_range
+
+import main
+import offline_jobs
+
+################################################################################
+
+class HashKeyDatastoreInputReaderTest(unittest.TestCase):
+  """Tests for the HashKeyDatastoreInputReader."""
+
+  def setUp(self):
+    """Sets up the test harness."""
+    testutil.setup_for_testing()
+    self.app = 'my-app-id'
+    self.entity_kind = 'my-entity-kind'
+    self.namespace = 'my-namespace'
+
+  def testOneShard(self):
+    """Tests just one shard."""
+    result = (
+        offline_jobs.HashKeyDatastoreInputReader._split_input_from_namespace(
+          self.app, self.namespace, self.entity_kind, 1))
+
+    expected = [
+      key_range.KeyRange(
+          key_start=db.Key.from_path(
+              u'my-entity-kind',
+              u'hash_0000000000000000000000000000000000000000',
+              _app=u'my-app-id'),
+          key_end=db.Key.from_path(
+              u'my-entity-kind',
+              u'hash_ffffffffffffffffffffffffffffffffffffffff',
+              _app=u'my-app-id'),
+          direction='ASC',
+          include_start=True,
+          include_end=True,
+          namespace='my-namespace',
+          _app='my-app-id')
+    ]
+    self.assertEquals(expected, result)
+
+  def testTwoShards(self):
+    """Tests two shards: one for number prefixes, one for letter prefixes."""
+    result = (
+        offline_jobs.HashKeyDatastoreInputReader._split_input_from_namespace(
+          self.app, self.namespace, self.entity_kind, 2))
+
+    expected = [
+      key_range.KeyRange(
+          key_start=db.Key.from_path(
+              u'my-entity-kind',
+              u'hash_0000000000000000000000000000000000000000',
+              _app=u'my-app-id'),
+          key_end=db.Key.from_path(
+              u'my-entity-kind',
+              u'hash_7fffffffffffffffffffffffffffffffffffffff',
+              _app=u'my-app-id'),
+          direction='DESC',
+          include_start=True,
+          include_end=True,
+          namespace='my-namespace',
+          _app='my-app-id'),
+      key_range.KeyRange(
+          key_start=db.Key.from_path(
+              u'my-entity-kind',
+              u'hash_7fffffffffffffffffffffffffffffffffffffff',
+              _app=u'my-app-id'),
+          key_end=db.Key.from_path(
+              u'my-entity-kind',
+              u'hash_ffffffffffffffffffffffffffffffffffffffff',
+              _app=u'my-app-id'),
+          direction='ASC',
+          include_start=False,
+          include_end=True,
+          namespace='my-namespace',
+          _app='my-app-id'),
+    ]
+    self.assertEquals(expected, result)
+
+  def testManyShards(self):
+    """Tests having many shards with multiple levels of splits."""
+    result = (
+        offline_jobs.HashKeyDatastoreInputReader._split_input_from_namespace(
+          self.app, self.namespace, self.entity_kind, 4))
+
+    expected = [
+      key_range.KeyRange(
+          key_start=db.Key.from_path(
+              u'my-entity-kind',
+              u'hash_0000000000000000000000000000000000000000',
+              _app=u'my-app-id'),
+          key_end=db.Key.from_path(
+              u'my-entity-kind',
+              u'hash_3fffffffffffffffffffffffffffffffffffffff',
+              _app=u'my-app-id'),
+          direction='DESC',
+          include_start=True,
+          include_end=True,
+          namespace='my-namespace',
+          _app='my-app-id'),
+      key_range.KeyRange(
+          key_start=db.Key.from_path(
+              u'my-entity-kind',
+              u'hash_3fffffffffffffffffffffffffffffffffffffff',
+              _app=u'my-app-id'),
+          key_end=db.Key.from_path(
+              u'my-entity-kind',
+              u'hash_7fffffffffffffffffffffffffffffffffffffff',
+              _app=u'my-app-id'),
+          direction='ASC',
+          include_start=False,
+          include_end=True,
+          namespace='my-namespace',
+          _app='my-app-id'),
+      key_range.KeyRange(
+          key_start=db.Key.from_path(
+              u'my-entity-kind',
+              u'hash_7fffffffffffffffffffffffffffffffffffffff',
+              _app=u'my-app-id'),
+          key_end=db.Key.from_path(
+              u'my-entity-kind',
+              u'hash_bfffffffffffffffffffffffffffffffffffffff',
+              _app=u'my-app-id'),
+          direction='DESC',
+          include_start=False,
+          include_end=True,
+          namespace='my-namespace',
+          _app='my-app-id'),
+      key_range.KeyRange(
+          key_start=db.Key.from_path(
+              u'my-entity-kind',
+              u'hash_bfffffffffffffffffffffffffffffffffffffff',
+              _app=u'my-app-id'),
+          key_end=db.Key.from_path(
+              u'my-entity-kind',
+              u'hash_ffffffffffffffffffffffffffffffffffffffff',
+              _app=u'my-app-id'),
+          direction='ASC',
+          include_start=False,
+          include_end=True,
+          namespace='my-namespace',
+          _app='my-app-id'),
+    ]
+    self.assertEquals(expected, result)
+
+
+Subscription = main.Subscription
+
+
+class SubscriptionReconfirmMapperTest(unittest.TestCase):
+  """Tests for the SubscriptionReconfirmMapper."""
+
+  def setUp(self):
+    """Sets up the test harness."""
+    testutil.setup_for_testing()
+    self.mapper = offline_jobs.SubscriptionReconfirmMapper()
+    self.callback = 'http://example.com/my-callback-url'
+    self.topic = 'http://example.com/my-topic-url'
+    self.token = 'token'
+    self.secret = 'my secrat'
+
+    self.now = datetime.datetime.utcnow()
+    self.threshold_seconds = 1000
+    self.threshold_timestamp = (
+        time.mktime(self.now.utctimetuple()) + self.threshold_seconds)
+    self.getnow = lambda: self.now
+
+    class FakeMapper(object):
+      params = {'threshold_timestamp': str(self.threshold_timestamp)}
+    class FakeSpec(object):
+      mapreduce_id = '1234'
+      mapper = FakeMapper()
+    self.context = context.Context(FakeSpec(), None)
+    context.Context._set(self.context)
+
+  def get_subscription(self):
+    """Returns the Subscription used for testing."""
+    return Subscription.get_by_key_name(
+        Subscription.create_key_name(self.callback, self.topic))
+
+  def testValidateParams(self):
+    """Tests the validate_params static method."""
+    self.assertRaises(
+        AssertionError,
+        offline_jobs.SubscriptionReconfirmMapper.validate_params,
+        {})
+    offline_jobs.SubscriptionReconfirmMapper.validate_params(
+        {'threshold_timestamp': 123})
+
+  def testIgnoreUnverified(self):
+    """Tests that unverified subscriptions are skipped."""
+    self.assertTrue(Subscription.request_insert(
+        self.callback, self.topic, self.token, self.secret,
+        now=self.getnow))
+    sub = self.get_subscription()
+    self.mapper.run(sub)
+    testutil.get_tasks(main.POLLING_QUEUE, expected_count=0)
+
+  def testAfterThreshold(self):
+    """Tests when a subscription is not yet ready for reconfirmation."""
+    self.assertTrue(Subscription.insert(
+        self.callback, self.topic, self.token, self.secret,
+        now=self.getnow, lease_seconds=self.threshold_seconds))
+    sub = self.get_subscription()
+    self.mapper.run(sub)
+    testutil.get_tasks(main.POLLING_QUEUE, expected_count=0)
+
+  def testBeforeThreshold(self):
+    """Tests when a subscription is ready for reconfirmation."""
+    self.assertTrue(Subscription.insert(
+        self.callback, self.topic, self.token, self.secret,
+        now=self.getnow, lease_seconds=self.threshold_seconds-1))
+    sub = self.get_subscription()
+    self.mapper.run(sub)
+    task = testutil.get_tasks(main.POLLING_QUEUE, index=0, expected_count=1)
+    self.assertEquals('polling', task['headers']['X-AppEngine-QueueName'])
+
+################################################################################
+
+if __name__ == '__main__':
+  unittest.main()
=======================================
--- /trunk/hub/main.py  Fri Nov  5 13:41:59 2010
+++ /trunk/hub/main.py  Fri Nov  5 17:47:09 2010
@@ -977,7 +977,7 @@
         confirmation.
     """
     RETRIES = 3
-    if os.environ.get('HTTP_X_APPENGINE_QUEUENAME') == POLLING_QUEUE:
+    if auto_reconfirm:
       target_queue = POLLING_QUEUE
     else:
       target_queue = SUBSCRIPTION_QUEUE
@@ -2148,15 +2148,15 @@
     self.start_map(
         name='Reconfirm expiring subscriptions',
         handler_spec='offline_jobs.SubscriptionReconfirmMapper.run',
-        reader_spec='mapreduce.input_readers.DatastoreInputReader',
+        reader_spec='offline_jobs.HashKeyDatastoreInputReader',
         reader_parameters=dict(
             processing_rate=100000,
-            entity_kind='main.Subscription'),
+            entity_kind='main.Subscription',
+            threshold_timestamp=int(
+                self.now() + SUBSCRIPTION_CHECK_BUFFER_SECONDS)),
         shard_count=SUBSCRIPTION_RECONFIRM_SHARD_COUNT,
         queue_name=POLLING_QUEUE,
         mapreduce_parameters=dict(
-          threshold_timestamp=int(
-              self.now() + SUBSCRIPTION_CHECK_BUFFER_SECONDS),
           done_callback='/work/cleanup_mapper',
           done_callback_queue=POLLING_QUEUE))

=======================================
--- /trunk/hub/main_test.py     Fri Nov  5 13:41:59 2010
+++ /trunk/hub/main_test.py     Fri Nov  5 17:47:09 2010
@@ -680,18 +680,17 @@
     self.assertEquals(Subscription.STATE_NOT_VERIFIED, sub.subscription_state)
     testutil.get_tasks(main.SUBSCRIPTION_QUEUE, index=0, expected_count=6)

-  def testQueuePreserved(self):
-    """Tests that insert will put the task on the polling queue."""
+  def testQueueSelected(self):
+    """Tests that auto_reconfirm will put the task on the polling queue."""
     self.assertTrue(Subscription.request_insert(
-        self.callback, self.topic, self.token, self.secret))
-    testutil.get_tasks(main.SUBSCRIPTION_QUEUE, expected_count=1)
-    os.environ['HTTP_X_APPENGINE_QUEUENAME'] = main.POLLING_QUEUE
-    try:
-      self.assertFalse(Subscription.request_insert(
-          self.callback, self.topic, self.token, self.secret))
-    finally:
-      del os.environ['HTTP_X_APPENGINE_QUEUENAME']
-
+        self.callback, self.topic, self.token, self.secret,
+        auto_reconfirm=True))
+    testutil.get_tasks(main.SUBSCRIPTION_QUEUE, expected_count=0)
+    testutil.get_tasks(main.POLLING_QUEUE, expected_count=1)
+
+    self.assertFalse(Subscription.request_insert(
+        self.callback, self.topic, self.token, self.secret,
+        auto_reconfirm=False))
     testutil.get_tasks(main.SUBSCRIPTION_QUEUE, expected_count=1)
     testutil.get_tasks(main.POLLING_QUEUE, expected_count=1)

@@ -3512,15 +3511,14 @@
     self.handle('post', ('subscription_key_name', self.sub_key),
                         ('verify_token', self.verify_token),
                         ('secret', self.secret),
-                        ('next_state', Subscription.STATE_VERIFIED),
-                        ('auto_reconfirm', 'True'))
+                        ('next_state', Subscription.STATE_VERIFIED))
     sub = Subscription.get_by_key_name(self.sub_key)
     self.assertEquals(Subscription.STATE_NOT_VERIFIED, sub.subscription_state)
     self.assertEquals(1, sub.confirm_failures)
     self.assertEquals(self.verify_token, sub.verify_token)
     self.assertEquals(self.secret, sub.secret)
-    self.verify_retry_task(sub.eta, Subscription.STATE_VERIFIED,
-                           auto_reconfirm=True,
+    self.verify_retry_task(sub.eta,
+                           Subscription.STATE_VERIFIED,
                            verify_token=self.verify_token,
                            secret=self.secret)

@@ -3699,19 +3697,19 @@
     def start_map(*args, **kwargs):
       self.assertEquals(kwargs, {
           'name': 'Reconfirm expiring subscriptions',
-          'reader_spec': 'mapreduce.input_readers.DatastoreInputReader',
+          'reader_spec': 'offline_jobs.HashKeyDatastoreInputReader',
           'queue_name': 'polling',
           'handler_spec': 'offline_jobs.SubscriptionReconfirmMapper.run',
           'shard_count': 4,
           'reader_parameters': {
             'entity_kind': 'main.Subscription',
-            'processing_rate': 100000
+            'processing_rate': 100000,
+            'threshold_timestamp':
+                int(self.now + main.SUBSCRIPTION_CHECK_BUFFER_SECONDS),
           },
           'mapreduce_parameters': {
             'done_callback': '/work/cleanup_mapper',
             'done_callback_queue': 'polling',
-            'threshold_timestamp':
-                int(self.now + main.SUBSCRIPTION_CHECK_BUFFER_SECONDS)
           },
       })
       self.called = True
=======================================
--- /trunk/hub/offline_jobs.py  Wed Sep 22 15:54:24 2010
+++ /trunk/hub/offline_jobs.py  Fri Nov  5 17:47:09 2010
@@ -19,6 +19,7 @@

 import datetime
 import logging
+import math
 import time

 from google.appengine.ext import db
@@ -26,7 +27,9 @@
 import main

 from mapreduce import context
+from mapreduce import input_readers
 from mapreduce import operation as op
+from mapreduce.lib import key_range


 def RemoveOldFeedEntryRecordPropertiesMapper(feed_entry_record):
@@ -62,6 +65,55 @@
       yield op.db.Delete(event)


+class HashKeyDatastoreInputReader(input_readers.DatastoreInputReader):
+  """A DatastoreInputReader that can split evenly across hash key ranges.
+
+  Assumes key names are in the format supplied by the main.get_hash_key_name
+  function.
+  """
+
+  @classmethod
+  def _split_input_from_namespace(
+      cls, app, namespace, entity_kind_name, shard_count):
+    hex_key_start = db.Key.from_path(
+        entity_kind_name, 0)
+    hex_key_end = db.Key.from_path(
+        entity_kind_name, int('f' * 40, base=16))
+    hex_range = key_range.KeyRange(
+        hex_key_start, hex_key_end, None, True, True,
+        namespace=namespace,
+        _app=app)
+
+    key_range_list = [hex_range]
+    number_of_half_splits = int(math.floor(math.log(shard_count, 2)))
+    for index in xrange(0, number_of_half_splits):
+      new_ranges = []
+      for current_range in key_range_list:
+        new_ranges.extend(current_range.split_range(1))
+      key_range_list = new_ranges
+
+    adjusted_range_list = []
+    for current_range in key_range_list:
+      adjusted_range = key_range.KeyRange(
+          key_start=db.Key.from_path(
+              current_range.key_start.kind(),
+              'hash_%040x' % (current_range.key_start.id() or 0),
+              _app=current_range._app),
+          key_end=db.Key.from_path(
+              current_range.key_end.kind(),
+              'hash_%040x' % (current_range.key_end.id() or 0),
+              _app=current_range._app),
+          direction=current_range.direction,
+          include_start=current_range.include_start,
+          include_end=current_range.include_end,
+          namespace=current_range.namespace,
+          _app=current_range._app)
+
+      adjusted_range_list.append(adjusted_range)
+
+    return adjusted_range_list
+
+
 class SubscriptionReconfirmMapper(object):
   """For reconfirming subscriptions that are nearing expiration."""

@@ -77,9 +129,7 @@
       return

     if self.threshold_timestamp is None:
-      params = context.get().mapreduce_spec.params
-      if 'threshold_timestamp' not in params:
-        params = context.get().mapreduce_spec.mapper.params
+      params = context.get().mapreduce_spec.mapper.params
       self.threshold_timestamp = datetime.datetime.utcfromtimestamp(
           float(params['threshold_timestamp']))

=======================================
--- /trunk/nonstandard/virtual_feed.py  Sun Jul 11 22:24:18 2010
+++ /trunk/nonstandard/virtual_feed.py  Fri Nov  5 17:47:09 2010
@@ -145,8 +145,13 @@

     def txn():
       event_to_deliver = EventToDeliver.create_event_for_topic(
-          fragment.topic, fragment.format, fragment.header_footer,
-          entry_payloads, set_parent=False, max_failures=1)
+          fragment.topic,
+          fragment.format,
+          self.request.headers.get('Content-Type', 'application/atom+xml'),
+          fragment.header_footer,
+          entry_payloads,
+          set_parent=False,
+          max_failures=1)
       db.put(event_to_deliver)
       event_to_deliver.enqueue()

Reply via email to