Revision: 372
Author: bslatkin
Date: Sun Jun 27 11:46:09 2010
Log: hub: moves fetch queue to memory
http://code.google.com/p/pubsubhubbub/source/detail?r=372

Modified:
 /trunk/hub/fork_join_queue.py
 /trunk/hub/fork_join_queue_test.py
 /trunk/hub/main.py
 /trunk/hub/main_test.py

=======================================
--- /trunk/hub/fork_join_queue.py       Mon May 17 01:57:39 2010
+++ /trunk/hub/fork_join_queue.py       Sun Jun 27 11:46:09 2010
@@ -88,6 +88,7 @@
 import random
 import time

+from google.net.proto import ProtocolBuffer
 from google.appengine.api import memcache
 from google.appengine.api.labs import taskqueue
 from google.appengine.ext import db
@@ -121,6 +122,9 @@
 class TaskConflictError(Error):
   """The added task has already ran, meaning the work index is invalid."""

+class MemcacheError(Error):
+  """Enqueuing the work item in memcache failed."""
+

 class ForkJoinQueue(object):
   """A fork-join queue for App Engine."""
@@ -133,13 +137,13 @@
                index_property,
                task_path,
                queue_name,
-               batch_size,
-               batch_period_ms,
-               lock_timeout_ms,
-               sync_timeout_ms,
-               stall_timeout_ms,
-               acquire_timeout_ms,
-               acquire_attempts):
+               batch_size=None,
+               batch_period_ms=None,
+               lock_timeout_ms=None,
+               sync_timeout_ms=None,
+               stall_timeout_ms=None,
+               acquire_timeout_ms=None,
+               acquire_attempts=None):
     """Initializer.

     Args:
@@ -190,7 +194,14 @@
                  memget=memcache.get,
                  memincr=memcache.incr,
                  memdecr=memcache.decr):
-    """Returns the next work index, incrementing the writer lock."""
+    """Reserves the next work index.
+
+    Args:
+      memget, memincr, memdecr: Used for testing.
+
+    Returns:
+      The next work index to use.
+    """
     for i in xrange(self.acquire_attempts):
       next_index = memget(self.index_name)
       if next_index is None:
@@ -256,19 +267,23 @@
      True if all writers were definitely finished; False if the reader/writer
      lock timed out and we are proceeding anyway.
     """
+    # Increment the batch index counter so incoming jobs will use a new index.
+    # Don't bother setting an initial value here because next_index() will
+    # do this when it notices no current index is present. Do this *before*
+    # closing the reader/writer lock below to decrease active writers on the
+    # current index.
+    memcache.incr(self.index_name)
+
     # Prevent new writers by making the counter extremely negative. If the
     # decrement fails here we can't recover anyway, so just let the worker go.
     add_counter = self.add_counter_template % last_index
     memcache.decr(add_counter, self.LOCK_OFFSET)

- # Increment the batch index counter so incoming jobs will use a new index.
-    # Don't bother setting an initial value here because next_index() will
-    # do this when it notices no current index is present.
-    memcache.incr(self.index_name)
-
     for i in xrange(self.sync_attempts):
       counter = memcache.get(add_counter)
-      if counter is None or int(counter) == self.LOCK_OFFSET:
+      # Less than or equal to LOCK_OFFSET here in case a writer decrements twice
+      # due to rerunning failure tasks.
+      if counter is None or int(counter) <= self.LOCK_OFFSET:
         # Worst-case the counter will be gone due to memcache eviction, which
         # means the worker can proceed without waiting for writers
         # and just process whatever it can find. This may drop some work.
@@ -279,8 +294,19 @@

     return False

-  def pop(self, request):
-    """Pops work to be done based on a task payload.
+  def _query_work(self, index, cursor):
+    """TODO
+    """
+    query = (self.model_class.all()
+        .filter('%s =' % self.index_property.name, index)
+        .order('__key__'))
+    if cursor:
+      query.with_cursor(cursor)
+    result_list = query.fetch(self.batch_size)
+    return result_list, query.cursor()
+
+  def pop_request(self, request):
+    """Pops work to be done based on a task queue request.

     Args:
       request: webapp.Request with the task payload.
@@ -288,8 +314,19 @@
     Returns:
       A list of work items, if any.
     """
-    cursor = request.get('cursor')
-    task_name = os.environ['HTTP_X_APPENGINE_TASKNAME']
+    return self.pop(os.environ['HTTP_X_APPENGINE_TASKNAME'],
+                    request.get('cursor'))
+
+  def pop(self, task_name, cursor=None):
+    """Pops work to be done based on just the task name.
+
+    Args:
+      task_name: The name of the task.
+      cursor: The value of the cursor for this task (optional).
+
+    Returns:
+      A list of work items, if any.
+    """
     rest, index, generation = task_name.rsplit('-', 2)
     index, generation = int(index), int(generation)

@@ -298,19 +335,15 @@
       # tasks can start processing immediately.
       self._increment_index(index)

-    query = (self.model_class.all()
-        .filter('%s =' % self.index_property.name, index)
-        .order('__key__'))
-    if cursor:
-      query.with_cursor(cursor)
-    result_list = query.fetch(self.batch_size)
+    result_list, cursor = self._query_work(index, cursor)
+
     if len(result_list) == self.batch_size:
       try:
         taskqueue.Task(
           method='POST',
           name='%s-%d-%d' % (rest, index, generation + 1),
           url=self.task_path,
-          params={'cursor': query.cursor()}
+          params={'cursor': cursor}
         ).add(self.get_queue_name(index))
       except (taskqueue.TaskAlreadyExistsError, taskqueue.TombstonedTaskError):
         # This means the continuation chain already started and this root
@@ -335,3 +368,92 @@

   def get_queue_name(self, index):
     return self.queue_name % {'shard': 1 + (index % self.shard_count)}
+
+
+class MemcacheForkJoinQueue(ShardedForkJoinQueue):
+  """A fork-join queue that only stores work items in memcache.
+
+  To use, call next_index() to get the work index, then call the put()
+  method, passing one or more model instances to be enqueued in memcache.
+
+  Also a sharded queue for maximum throughput.
+  """
+
+  def __init__(self, *args, **kwargs):
+    """Initializer.
+
+    Args:
+      *args, **kwargs: Passed to ShardedForkJoinQueue.
+      expiration_seconds: How long items inserted into memcache should remain
+        until they are evicted due to timeout. Default is 0, meaning they
+        will never be evicted.
+    """
+    if 'expiration_seconds' in kwargs:
+      self.expiration_seconds = kwargs.pop('expiration_seconds')
+    else:
+      self.expiration_seconds = 0
+    ShardedForkJoinQueue.__init__(self, *args, **kwargs)
+
+  def _create_length_key(self, index):
+ """Creates a length memecache key for the length of the in-memory queue."""
+    return '%s:length:%d' % (self.name, index)
+
+  def _create_index_key(self, index, number):
+ """Creates an index memcache key for the given in-memory queue location."""
+    return '%s:index:%d-%d' % (self.name, index, number)
+
+  def put(self,
+          index,
+          entity_list,
+          memincr=memcache.incr,
+          memset=memcache.set_multi):
+    """Enqueue a model instance on this queue.
+
+    Does not write to the Datastore.
+
+    Args:
+      index: The work index for this entity.
+      entity_list: List of work entities to insert into the in-memory queue.
+      memincr, memset: Used for testing.
+
+    Raises:
+      MemcacheError if the entities were not successfully added.
+    """
+    length_key = self._create_length_key(index)
+    end = memincr(length_key, len(entity_list), initial_value=0)
+    if end is None:
+      raise MemcacheError('Could not increment length key %r' % length_key)
+
+    start = end - len(entity_list)
+    key_map = {}
+    for number, entity in zip(xrange(start, end), entity_list):
+      key_map[self._create_index_key(index, number)] = db.model_to_protobuf(
+          entity)
+
+    result = memset(key_map, time=self.expiration_seconds)
+    if result:
+      raise MemcacheError('Could not set memcache keys %r' % result)
+
+  def _query_work(self, index, cursor):
+    """Queries for work in memcache."""
+    if cursor:
+      cursor = int(cursor)
+    else:
+      cursor = 0
+
+    key_list = [self._create_index_key(index, n)
+                for n in xrange(cursor, cursor + self.batch_size)]
+    results = memcache.get_multi(key_list)
+
+    result_list = []
+    for key in key_list:
+      proto = results.get(key)
+      if not proto:
+        continue
+      try:
+        result_list.append(db.model_from_protobuf(proto))
+      except ProtocolBuffer.ProtocolBufferDecodeError:
+        logging.exception('Could not decode EntityPb at memcache key %r: %r',
+                          key, proto)
+
+    return result_list, cursor + self.batch_size
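
For orientation, here is a minimal sketch of the intended producer/consumer
flow for MemcacheForkJoinQueue, pieced together from the docstrings above and
the MEMCACHE_QUEUE fixture in the test diff below. The model class MyWorkItem
and the task path '/work/my_task' are hypothetical stand-ins, not part of this
change:

  # Sketch only; MyWorkItem and '/work/my_task' are hypothetical.
  queue = fork_join_queue.MemcacheForkJoinQueue(
      MyWorkItem,
      MyWorkItem.work_index,
      '/work/my_task',
      'default',
      batch_size=10,
      batch_period_ms=500,
      lock_timeout_ms=1000,
      sync_timeout_ms=250,
      stall_timeout_ms=30000,
      acquire_timeout_ms=50,
      acquire_attempts=20,
      shard_count=2,
      expiration_seconds=600)

  # Producer: reserve a work index, write entities to memcache only,
  # then enqueue the batch task.
  work_index = queue.next_index()
  items = [MyWorkItem(key=db.Key.from_path(MyWorkItem.kind(), i),
                      work_index=work_index)
           for i in xrange(1, 4)]
  queue.put(work_index, items)   # raises MemcacheError on failure
  queue.add(work_index)

  # Consumer, inside the webapp handler serving '/work/my_task':
  work_list = queue.pop_request(self.request)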
=======================================
--- /trunk/hub/fork_join_queue_test.py  Mon May 17 01:57:39 2010
+++ /trunk/hub/fork_join_queue_test.py  Sun Jun 27 11:46:09 2010
@@ -71,6 +71,21 @@
     shard_count=4)


+MEMCACHE_QUEUE = fork_join_queue.MemcacheForkJoinQueue(
+    TestModel,
+    TestModel.work_index,
+    '/path/to/my/task',
+    'default',
+    batch_size=3,
+    batch_period_ms=200,
+    lock_timeout_ms=1000,
+    sync_timeout_ms=250,
+    stall_timeout_ms=30000,
+    acquire_timeout_ms=50,
+    acquire_attempts=20,
+    shard_count=4)
+
+
 class ForkJoinQueueTest(unittest.TestCase):
   """Tests for the ForkJoinQueue class."""

@@ -226,7 +241,7 @@
     request = testutil.create_test_request('POST', None)
     os.environ['HTTP_X_APPENGINE_TASKNAME'] = \
         self.expect_task(t1.work_index)['name']
-    result = TEST_QUEUE.pop(request)
+    result = TEST_QUEUE.pop_request(request)

     self.assertEquals(1, len(result))
     self.assertEquals(t1.key(), result[0].key())
@@ -250,7 +265,7 @@
     request = testutil.create_test_request('POST', None)
     os.environ['HTTP_X_APPENGINE_TASKNAME'] = \
         self.expect_task(work_index)['name']
-    result_list = TEST_QUEUE.pop(request)
+    result_list = TEST_QUEUE.pop_request(request)

     self.assertEquals(3, len(result_list))
     for i, result in enumerate(result_list):
@@ -268,7 +283,7 @@
     request = testutil.create_test_request('POST', None,
                                            *next_task['params'].items())
     os.environ['HTTP_X_APPENGINE_TASKNAME'] = next_task['name']
-    result_list = TEST_QUEUE.pop(request)
+    result_list = TEST_QUEUE.pop_request(request)
     self.assertEquals(3, len(result_list))
     for i, result in enumerate(result_list):
       self.assertEquals(work_index, result.work_index)
@@ -286,7 +301,7 @@
     request = testutil.create_test_request('POST', None,
                                            *next_task['params'].items())
     os.environ['HTTP_X_APPENGINE_TASKNAME'] = next_task['name']
-    result_list = TEST_QUEUE.pop(request)
+    result_list = TEST_QUEUE.pop_request(request)
     self.assertEquals([], result_list)
     testutil.get_tasks('default', expected_count=3, usec_eta=True)

@@ -304,10 +319,10 @@
     os.environ['HTTP_X_APPENGINE_TASKNAME'] = \
         self.expect_task(work_index)['name']

-    result_list = TEST_QUEUE.pop(request)
+    result_list = TEST_QUEUE.pop_request(request)
     testutil.get_tasks('default', expected_count=2)

-    result_list = TEST_QUEUE.pop(request)
+    result_list = TEST_QUEUE.pop_request(request)
     testutil.get_tasks('default', expected_count=2)

   def testIncrementIndexFail(self):
@@ -341,7 +356,7 @@
       request = testutil.create_test_request('POST', None)
       os.environ['HTTP_X_APPENGINE_TASKNAME'] = \
           self.expect_task(work_index)['name']
-      result_list = SHARDED_QUEUE.pop(request)
+      result_list = SHARDED_QUEUE.pop_request(request)

       self.assertEquals(3, len(result_list))
       for i, result in enumerate(result_list):
@@ -357,6 +372,105 @@
     finally:
       stub._IsValidQueue = old_valid

+  def testMemcacheQueue(self):
+ """Tests adding and popping from an in-memory queue with continuation."""
+    work_index = MEMCACHE_QUEUE.next_index()
+    work_items = [TestModel(key=db.Key.from_path(TestModel.kind(), i),
+                            work_index=work_index, number=i)
+                  for i in xrange(1, 6)]
+    MEMCACHE_QUEUE.put(work_index, work_items)
+    MEMCACHE_QUEUE.add(work_index, gettime=self.gettime1)
+    testutil.get_tasks('default', expected_count=1)
+
+    # First pop request.
+    request = testutil.create_test_request('POST', None)
+    os.environ['HTTP_X_APPENGINE_TASKNAME'] = \
+        self.expect_task(work_index)['name']
+    result_list = MEMCACHE_QUEUE.pop_request(request)
+
+    self.assertEquals(3, len(result_list))
+    for i, result in enumerate(result_list):
+      self.assertEquals(work_index, result.work_index)
+      self.assertEquals(i + 1, result.number)
+
+    # Continuation task enqueued.
+    next_task = testutil.get_tasks('default',
+                                   expected_count=2,
+                                   index=1)
+    self.assertEquals(3, int(next_task['params']['cursor']))
+    self.assertTrue(next_task['name'].endswith('-1'))
+
+    # Second pop request.
+    request = testutil.create_test_request(
+        'POST', None, *next_task['params'].items())
+    os.environ['HTTP_X_APPENGINE_TASKNAME'] = next_task['name']
+    result_list = MEMCACHE_QUEUE.pop_request(request)
+
+    self.assertEquals(2, len(result_list))
+    for i, result in enumerate(result_list):
+      self.assertEquals(work_index, result.work_index)
+      self.assertEquals(i + 4, result.number)
+
+  def testMemcacheQueue_IncrError(self):
+    """Tests calling put() when memcache increment fails."""
+    work_index = MEMCACHE_QUEUE.next_index()
+    entity = TestModel(work_index=work_index, number=0)
+    self.assertRaises(fork_join_queue.MemcacheError,
+                      MEMCACHE_QUEUE.put,
+                      work_index, [entity],
+                      memincr=lambda *a, **k: None)
+
+  def testMemcacheQueue_PutSetError(self):
+    """Tests calling put() when memcache set fails."""
+    work_index = MEMCACHE_QUEUE.next_index()
+    entity = TestModel(work_index=work_index, number=0)
+    self.assertRaises(fork_join_queue.MemcacheError,
+                      MEMCACHE_QUEUE.put,
+                      work_index, [entity],
+                      memset=lambda *a, **k: ['blah'])
+
+  def testMemcacheQueue_PopError(self):
+    """Tests calling pop() when memcache is down."""
+    work_index = MEMCACHE_QUEUE.next_index()
+    entity = TestModel(work_index=work_index, number=0)
+    MEMCACHE_QUEUE.put(work_index, [entity])
+    memcache.flush_all()
+
+    request = testutil.create_test_request('POST', None)
+    os.environ['HTTP_X_APPENGINE_TASKNAME'] = \
+        self.expect_task(work_index)['name']
+    result_list = MEMCACHE_QUEUE.pop_request(request)
+    self.assertEquals([], result_list)
+
+  def testMemcacheQueue_PopHoles(self):
+    """Tests when there are holes in the memcache array."""
+    work_index = MEMCACHE_QUEUE.next_index()
+    work_items = [TestModel(key=db.Key.from_path(TestModel.kind(), i),
+                            work_index=work_index, number=i)
+                  for i in xrange(1, 6)]
+    MEMCACHE_QUEUE.put(work_index, work_items)
+    memcache.delete(MEMCACHE_QUEUE._create_index_key(work_index, 1))
+
+    request = testutil.create_test_request('POST', None)
+    os.environ['HTTP_X_APPENGINE_TASKNAME'] = \
+        self.expect_task(work_index)['name']
+    result_list = MEMCACHE_QUEUE.pop_request(request)
+    self.assertEquals([1, 3], [r.number for r in result_list])
+
+  def testMemcacheQueue_PopDecodeError(self):
+    """Tests when proto decoding fails on the pop() call."""
+    work_index = MEMCACHE_QUEUE.next_index()
+    work_items = [TestModel(key=db.Key.from_path(TestModel.kind(), i),
+                            work_index=work_index, number=i)
+                  for i in xrange(1, 6)]
+    MEMCACHE_QUEUE.put(work_index, work_items)
+
+    memcache.set(MEMCACHE_QUEUE._create_index_key(work_index, 1), 'bad data')
+    request = testutil.create_test_request('POST', None)
+    os.environ['HTTP_X_APPENGINE_TASKNAME'] = \
+        self.expect_task(work_index)['name']
+    result_list = MEMCACHE_QUEUE.pop_request(request)
+
################################################################################

 if __name__ == '__main__':
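
The new tests above pin down the memcache layout that put() creates and the
fault tolerance of pop(): missing slots (PopHoles) and undecodable slots
(PopDecodeError) are skipped rather than fatal. A sketch of the keyspace,
assuming the queue's name attribute is 'feeds' (that attribute is set outside
this diff):

  # After queue.put(7, [e1, e2, e3]):
  #   'feeds:length:7'  -> 3      # next free slot, via memcache.incr()
  #   'feeds:index:7-0' -> db.model_to_protobuf(e1)
  #   'feeds:index:7-1' -> db.model_to_protobuf(e2)
  #   'feeds:index:7-2' -> db.model_to_protobuf(e3)
  # A pop with batch_size=3 and no cursor issues one get_multi() for
  # slots 0-2, drops any missing or undecodable entries, and hands
  # cursor 3 to the continuation task.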
=======================================
--- /trunk/hub/main.py  Sat Jun  5 18:26:01 2010
+++ /trunk/hub/main.py  Sun Jun 27 11:46:09 2010
@@ -1075,7 +1075,7 @@
     return cls.get_by_key_name(get_hash_key_name(topic))

   @classmethod
-  def insert(cls, topic_list, source_dict=None):
+  def insert(cls, topic_list, source_dict=None, memory_only=True):
     """Inserts a set of FeedToFetch entities for a set of topics.

     Overwrites any existing entities that are already there.
@@ -1084,6 +1084,10 @@
       topic_list: List of the topic URLs of feeds that need to be fetched.
       source_dict: Dictionary of sources for the feed. Defaults to an empty
         dictionary.
+      memory_only: Only save FeedToFetch records to memory, not to disk.
+
+    Returns:
+      The list of FeedToFetch records that was created.
     """
     if not topic_list:
       return
@@ -1098,18 +1102,27 @@
     else:
       cls.FORK_JOIN_QUEUE.queue_name = FEED_QUEUE

-    work_index = cls.FORK_JOIN_QUEUE.next_index()
+    if memory_only:
+      work_index = cls.FORK_JOIN_QUEUE.next_index()
+    else:
+      work_index = None
     try:
       feed_list = [
-          cls(key_name=get_hash_key_name(topic),
+          cls(key=db.Key.from_path(cls.kind(), get_hash_key_name(topic)),
               topic=topic,
               source_keys=list(source_keys),
               source_values=list(source_values),
               work_index=work_index)
           for topic in set(topic_list)]
-      db.put(feed_list)
+      if memory_only:
+        cls.FORK_JOIN_QUEUE.put(work_index, feed_list)
+      else:
+        db.put(feed_list)
     finally:
-      cls.FORK_JOIN_QUEUE.add(work_index)
+      if memory_only:
+        cls.FORK_JOIN_QUEUE.add(work_index)
+
+    return feed_list

   def fetch_failed(self,
                    max_failures=MAX_FEED_PULL_FAILURES,
@@ -1147,7 +1160,9 @@
     take care of this FeedToFetch and we should leave the entry.

     Returns:
-      True if the entity was deleted, False otherwise.
+      True if the entity was deleted, False otherwise. If the FeedToFetch
+      record never made it into the Datastore (because it only ever lived
+      in the in-memory cache), this function returns False.
     """
     def txn():
       other = db.get(self.key())
@@ -1180,7 +1195,7 @@
         return


-FeedToFetch.FORK_JOIN_QUEUE = fork_join_queue.ShardedForkJoinQueue(
+FeedToFetch.FORK_JOIN_QUEUE = fork_join_queue.MemcacheForkJoinQueue(
     FeedToFetch,
     FeedToFetch.work_index,
     '/work/pull_feeds',
@@ -1192,7 +1207,8 @@
     stall_timeout_ms=30000,
     acquire_timeout_ms=10,
     acquire_attempts=50,
-    shard_count=1)
+    shard_count=1,
+    expiration_seconds=600)  # Give up on fetches after 10 minutes.


 class FeedRecord(db.Model):
@@ -2655,7 +2671,7 @@
         return
       self._handle_fetches([work])
     else:
-      work_list = FeedToFetch.FORK_JOIN_QUEUE.pop(self.request)
+      work_list = FeedToFetch.FORK_JOIN_QUEUE.pop_request(self.request)
       self._handle_fetches(work_list)

################################################################################
@@ -2802,7 +2818,9 @@
       for topic in topic_list:
         KnownFeed.record(topic)
     else:
-      FeedToFetch.insert(topic_list)
+      # Force these FeedToFetch records to be written to disk so we ensure
+      # that we will eventually poll the feeds.
+      FeedToFetch.insert(topic_list, memory_only=False)
   except (taskqueue.Error, apiproxy_errors.Error,
           db.Error, runtime.DeadlineExceededError,
           fork_join_queue.Error):
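
Net effect in main.py: the common fetch path now keeps FeedToFetch records
purely in memory, while the bootstrap poller above explicitly opts out. A
short sketch of the two call styles (the topic URL is illustrative):

  # Default: records live only in memcache and a fork-join task is
  # enqueued to drain them.
  FeedToFetch.insert(['http://example.com/feed'])

  # Bootstrap polling: force a durable db.put() so a memcache eviction
  # cannot drop the fetch; no fork-join task is enqueued on this path.
  FeedToFetch.insert(['http://example.com/feed'], memory_only=False)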
=======================================
--- /trunk/hub/main_test.py     Sat Jun  5 18:26:01 2010
+++ /trunk/hub/main_test.py     Sun Jun 27 11:46:09 2010
@@ -725,9 +725,7 @@
   def testInsertAndGet(self):
     """Tests inserting and getting work."""
     all_topics = [self.topic, self.topic2, self.topic3]
-    FeedToFetch.insert(all_topics)
-    found_feeds = [FeedToFetch.get_by_topic(t) for t in all_topics]
-    found_topics = set(t.topic for t in found_feeds)
+    found_feeds = FeedToFetch.insert(all_topics)
     task = testutil.get_tasks(main.FEED_QUEUE, index=0, expected_count=1)
     self.assertTrue(task['name'].endswith('%d-0' % found_feeds[0].work_index))

@@ -745,8 +743,7 @@
   def testDuplicates(self):
     """Tests duplicate urls."""
     all_topics = [self.topic, self.topic, self.topic2, self.topic2]
-    FeedToFetch.insert(all_topics)
-    found_feeds = [FeedToFetch.get_by_topic(t) for t in all_topics]
+    found_feeds = FeedToFetch.insert(all_topics)
     found_topics = set(t.topic for t in found_feeds)
     self.assertEquals(set(all_topics), found_topics)
     task = testutil.get_tasks(main.FEED_QUEUE, index=0, expected_count=1)
@@ -754,17 +751,25 @@

   def testDone(self):
     """Tests marking the feed as completed."""
-    FeedToFetch.insert([self.topic])
-    feed = FeedToFetch.get_by_topic(self.topic)
+    (feed,) = FeedToFetch.insert([self.topic])
+    self.assertFalse(feed.done())
+    self.assertTrue(FeedToFetch.get_by_topic(self.topic) is None)
+
+  def testDoneAfterFailure(self):
+ """Tests done() after a fetch_failed() writes the FeedToFetch to disk."""
+    (feed,) = FeedToFetch.insert([self.topic])
+    feed.fetch_failed()
     self.assertTrue(feed.done())
     self.assertTrue(FeedToFetch.get_by_topic(self.topic) is None)

   def testDoneConflict(self):
     """Tests when another entity was written over the top of this one."""
-    FeedToFetch.insert([self.topic])
-    feed = FeedToFetch.get_by_topic(self.topic)
-    FeedToFetch.insert([self.topic])
-    self.assertFalse(feed.done())
+    (feed1,) = FeedToFetch.insert([self.topic])
+    feed1.put()
+    (feed2,) = FeedToFetch.insert([self.topic])
+    feed2.put()
+
+    self.assertFalse(feed1.done())
     self.assertTrue(FeedToFetch.get_by_topic(self.topic) is not None)

   def testFetchFailed(self):
@@ -772,10 +777,10 @@
     start = datetime.datetime.utcnow()
     now = lambda: start

-    FeedToFetch.insert([self.topic])
+    (feed,) = FeedToFetch.insert([self.topic])
     etas = []
     for i, delay in enumerate((5, 10, 20, 40, 80)):
-      feed = FeedToFetch.get_by_topic(self.topic)
+      feed = FeedToFetch.get_by_topic(self.topic) or feed
       feed.fetch_failed(max_failures=5, retry_period=5, now=now)
       expected_eta = start + datetime.timedelta(seconds=delay)
       self.assertEquals(expected_eta, feed.eta)
@@ -794,14 +799,11 @@
   def testQueuePreserved(self):
"""Tests the request's polling queue is preserved for new FeedToFetch."""
     FeedToFetch.insert([self.topic])
-    feed = FeedToFetch.all().get()
     testutil.get_tasks(main.FEED_QUEUE, expected_count=1)
-    feed.delete()

     os.environ['HTTP_X_APPENGINE_QUEUENAME'] = main.POLLING_QUEUE
     try:
-      FeedToFetch.insert([self.topic])
-      feed = FeedToFetch.all().get()
+      (feed,) = FeedToFetch.insert([self.topic])
       testutil.get_tasks(main.FEED_QUEUE, expected_count=1)
       testutil.get_tasks(main.POLLING_QUEUE, expected_count=1)
     finally:
@@ -811,10 +813,8 @@
     """Tests when sources are supplied."""
     source_dict = {'foo': 'bar', 'meepa': 'stuff'}
     all_topics = [self.topic, self.topic2, self.topic3]
-    FeedToFetch.insert(all_topics, source_dict=source_dict)
-    for topic in all_topics:
-      feed_to_fetch = FeedToFetch.get_by_topic(topic)
-      self.assertEquals(topic, feed_to_fetch.topic)
+    feed_list = FeedToFetch.insert(all_topics, source_dict=source_dict)
+    for feed_to_fetch in feed_list:
       found_source_dict = dict(zip(feed_to_fetch.source_keys,
                                    feed_to_fetch.source_values))
       self.assertEquals(source_dict, found_source_dict)
@@ -1147,6 +1147,11 @@
     self.topic2 = 'http://example.com/second-url'
     self.topic3 = 'http://example.com/third-url'

+  def get_feeds_to_fetch(self):
+    """Gets the enqueued FeedToFetch records."""
+    return FeedToFetch.FORK_JOIN_QUEUE.pop(
+        testutil.get_tasks(main.FEED_QUEUE, index=0, expected_count=1)['name'])
+
   def testDebugFormRenders(self):
     self.handle('get')
     self.assertTrue('<html>' in self.response_body())
@@ -1181,7 +1186,8 @@
                 ('hub.url', self.topic3))
     self.assertEquals(204, self.response_code())
     expected_topics = set([self.topic, self.topic2, self.topic3])
-    inserted_topics = set(f.topic for f in FeedToFetch.all())
+    feed_list = self.get_feeds_to_fetch()
+    inserted_topics = set(f.topic for f in feed_list)
     self.assertEquals(expected_topics, inserted_topics)

   def testIgnoreUnknownFeed(self):
@@ -1191,7 +1197,7 @@
                 ('hub.url', self.topic2),
                 ('hub.url', self.topic3))
     self.assertEquals(204, self.response_code())
-    self.assertEquals([], list(FeedToFetch.all()))
+    testutil.get_tasks(main.FEED_QUEUE, expected_count=0)

   def testDuplicateUrls(self):
     db.put([KnownFeed.create(self.topic),
@@ -1214,7 +1220,7 @@
                 ('hub.url', self.topic2))
     self.assertEquals(204, self.response_code())
     expected_topics = set([self.topic, self.topic2])
-    inserted_topics = set(f.topic for f in FeedToFetch.all())
+    inserted_topics = set(f.topic for f in self.get_feeds_to_fetch())
     self.assertEquals(expected_topics, inserted_topics)

   def testInsertFailure(self):
@@ -1251,7 +1257,7 @@
                 ('hub.url', self.topic3))
     self.assertEquals(204, self.response_code())
     expected_topics = set([self.topic, self.topic2, self.topic3])
-    inserted_topics = set(f.topic for f in FeedToFetch.all())
+    inserted_topics = set(f.topic for f in self.get_feeds_to_fetch())
     self.assertEquals(expected_topics, inserted_topics)

   def testNormalization(self):
@@ -1269,7 +1275,7 @@
                 ('hub.url', self.topic2),
                 ('hub.url', self.topic3))
     self.assertEquals(204, self.response_code())
-    inserted_topics = set(f.topic for f in FeedToFetch.all())
+    inserted_topics = set(f.topic for f in self.get_feeds_to_fetch())
     self.assertEquals(set(normalized), inserted_topics)

   def testIri(self):
@@ -1285,7 +1291,7 @@
                 ('hub.url', self.topic2 + FUNNY_UTF8),
                 ('hub.url', self.topic3 + FUNNY_UTF8))
     self.assertEquals(204, self.response_code())
-    inserted_topics = set(f.topic for f in FeedToFetch.all())
+    inserted_topics = set(f.topic for f in self.get_feeds_to_fetch())
     self.assertEquals(set(normalized), inserted_topics)

   def testUnicode(self):
@@ -1303,7 +1309,7 @@
         '&hub.url=' + urllib.quote(self.topic3) + FUNNY_UTF8)
     self.handle_body('post', payload)
     self.assertEquals(204, self.response_code())
-    inserted_topics = set(f.topic for f in FeedToFetch.all())
+    inserted_topics = set(f.topic for f in self.get_feeds_to_fetch())
     self.assertEquals(set(normalized), inserted_topics)

   def testSources(self):
@@ -1327,8 +1333,7 @@
                   ('hub.url', self.topic3),
                   ('the-real-thing', 'testvalue'))
       self.assertEquals(204, self.response_code())
-      for topic in topics:
-        feed_to_fetch = FeedToFetch.get_by_topic(topic)
+      for feed_to_fetch in self.get_feeds_to_fetch():
         found_source_dict = dict(zip(feed_to_fetch.source_keys,
                                      feed_to_fetch.source_values))
         self.assertEquals(source_dict, found_source_dict)
@@ -1493,6 +1498,15 @@
     main.find_feed_updates = self.old_find_feed_updates
     urlfetch_test_stub.instance.verify_and_reset()

+  def run_fetch_task(self, index=0):
+    """Runs the currently enqueued fetch task."""
+    task = testutil.get_tasks(main.FEED_QUEUE, index=index)
+    os.environ['HTTP_X_APPENGINE_TASKNAME'] = task['name']
+    try:
+      self.handle('post')
+    finally:
+      del os.environ['HTTP_X_APPENGINE_TASKNAME']
+
   def testNoWork(self):
     self.handle('post', ('topic', self.topic))

@@ -1502,7 +1516,7 @@
     urlfetch_test_stub.instance.expect(
         'get', self.topic, 200, self.expected_response,
         response_headers=self.headers)
-    self.handle('post', ('topic', self.topic))
+    self.run_fetch_task()

     # Verify that all feed entry records have been written along with the
     # EventToDeliver and FeedRecord.
@@ -1539,7 +1553,7 @@
     urlfetch_test_stub.instance.expect(
         'get', self.topic, 200, self.expected_response,
         response_headers=self.headers)
-    self.handle('post', ('topic', self.topic))
+    self.run_fetch_task()

     feed_entries = FeedEntryRecord.get_entries_for_topic(
         self.topic, self.all_ids)
@@ -1574,7 +1588,7 @@
     urlfetch_test_stub.instance.expect(
         'get', self.topic, 200, self.expected_response,
         response_headers=self.headers)
-    self.handle('post', ('topic', self.topic))
+    self.run_fetch_task()

     feed_entries = FeedEntryRecord.get_entries_for_topic(
         self.topic, self.all_ids)
@@ -1604,7 +1618,7 @@
     urlfetch_test_stub.instance.expect(
         'get', self.topic, 200, self.expected_response,
         response_headers=self.headers)
-    self.handle('post', ('topic', self.topic))
+    self.run_fetch_task()

     feed = FeedToFetch.get_by_key_name(get_hash_key_name(self.topic))
     self.assertTrue(feed is None)
@@ -1632,7 +1646,7 @@
         'get', self.topic, 304, '',
         request_headers=request_headers,
         response_headers=self.headers)
-    self.handle('post', ('topic', self.topic))
+    self.run_fetch_task()
     self.assertTrue(EventToDeliver.all().get() is None)
     testutil.get_tasks(main.EVENT_QUEUE, expected_count=0)

@@ -1645,7 +1659,7 @@
     urlfetch_test_stub.instance.expect(
         'get', self.topic, 200, self.expected_response,
         response_headers=self.headers)
-    self.handle('post', ('topic', self.topic))
+    self.run_fetch_task()
     self.assertTrue(EventToDeliver.all().get() is None)
     testutil.get_tasks(main.EVENT_QUEUE, expected_count=0)

@@ -1662,7 +1676,7 @@
     FeedToFetch.insert([self.topic])
     urlfetch_test_stub.instance.expect(
         'get', self.topic, 200, self.expected_response, urlfetch_error=True)
-    self.handle('post', ('topic', self.topic))
+    self.run_fetch_task()
     feed = FeedToFetch.get_by_key_name(get_hash_key_name(self.topic))
     self.assertEquals(1, feed.fetching_failures)
     testutil.get_tasks(main.EVENT_QUEUE, expected_count=0)
@@ -1672,12 +1686,39 @@
     self.assertEquals(self.topic, task['params']['topic'])
     self.assertEquals([(0, 1)], main.FETCH_SCORER.get_scores([self.topic]))

+  def testPullRetry(self):
+    """Tests that the task enqueued after a failure will run properly."""
+    FeedToFetch.insert([self.topic])
+    urlfetch_test_stub.instance.expect(
+        'get', self.topic, 200, self.expected_response, urlfetch_error=True)
+    self.run_fetch_task()
+
+    # Verify the failed feed was written to the Datastore.
+    feed = FeedToFetch.get_by_key_name(get_hash_key_name(self.topic))
+    self.assertEquals(1, feed.fetching_failures)
+    testutil.get_tasks(main.EVENT_QUEUE, expected_count=0)
+    testutil.get_tasks(main.FEED_QUEUE, expected_count=1)
+    testutil.get_tasks(main.FEED_RETRIES_QUEUE, expected_count=1)
+    task = testutil.get_tasks(main.FEED_RETRIES_QUEUE,
+                              index=0, expected_count=1)
+    self.assertEquals(self.topic, task['params']['topic'])
+    self.assertEquals([(0, 1)], main.FETCH_SCORER.get_scores([self.topic]))
+
+    urlfetch_test_stub.instance.expect(
+        'get', self.topic, 200, self.expected_response, urlfetch_error=True)
+    self.handle('post', *task['params'].items())
+    feed = FeedToFetch.get_by_key_name(get_hash_key_name(self.topic))
+    self.assertEquals(2, feed.fetching_failures)
+    testutil.get_tasks(main.EVENT_QUEUE, expected_count=0)
+    testutil.get_tasks(main.FEED_QUEUE, expected_count=1)
+    testutil.get_tasks(main.FEED_RETRIES_QUEUE, expected_count=2)
+
   def testPullBadStatusCode(self):
     """Tests when the response status is bad."""
     FeedToFetch.insert([self.topic])
     urlfetch_test_stub.instance.expect(
         'get', self.topic, 500, self.expected_response)
-    self.handle('post', ('topic', self.topic))
+    self.run_fetch_task()
     feed = FeedToFetch.get_by_key_name(get_hash_key_name(self.topic))
     self.assertEquals(1, feed.fetching_failures)

@@ -1693,7 +1734,7 @@
     FeedToFetch.insert([self.topic])
     urlfetch_test_stub.instance.expect(
         'get', self.topic, 200, self.expected_response, apiproxy_error=True)
-    self.handle('post', ('topic', self.topic))
+    self.run_fetch_task()
     feed = FeedToFetch.get_by_key_name(get_hash_key_name(self.topic))
     self.assertEquals(1, feed.fetching_failures)

@@ -1711,7 +1752,7 @@
     self.assertTrue(db.get(KnownFeed.create_key(self.topic)) is not None)
     self.entry_list = []
     FeedToFetch.insert([self.topic])
-    self.handle('post', ('topic', self.topic))
+    self.run_fetch_task()

     # Verify that *no* feed entry records have been written.
     self.assertEquals([], FeedEntryRecord.get_entries_for_topic(
@@ -1742,7 +1783,7 @@
         'get', real_topic, 200, self.expected_response,
         response_headers=self.headers)

-    self.handle('post', ('topic', self.topic))
+    self.run_fetch_task()
     self.assertTrue(EventToDeliver.all().get() is not None)
     testutil.get_tasks(main.EVENT_QUEUE, expected_count=1)

@@ -1765,7 +1806,7 @@
           response_headers=self.headers.copy())
       last_topic = next_topic

-    self.handle('post', ('topic', self.topic))
+    self.run_fetch_task()
     self.assertTrue(EventToDeliver.all().get() is None)

     testutil.get_tasks(main.EVENT_QUEUE, expected_count=0)
@@ -1789,7 +1830,7 @@
         'get', self.topic, 302, '',
         response_headers=self.headers.copy())

-    self.handle('post', ('topic', self.topic))
+    self.run_fetch_task()
     self.assertTrue(EventToDeliver.all().get() is None)
     testutil.get_tasks(main.EVENT_QUEUE, expected_count=0)

@@ -1817,7 +1858,7 @@
     old_max_new = main.MAX_NEW_FEED_ENTRY_RECORDS
     main.MAX_NEW_FEED_ENTRY_RECORDS = len(self.all_ids) + 1
     try:
-        self.handle('post', ('topic', self.topic))
+        self.run_fetch_task()
     finally:
       main.MAX_NEW_FEED_ENTRY_RECORDS = old_max_new

@@ -1872,7 +1913,7 @@
     main.MAX_FEED_RECORD_SAVES = len(self.entry_list) + 1
     main.MAX_NEW_FEED_ENTRY_RECORDS = main.MAX_FEED_RECORD_SAVES
     try:
-      self.handle('post', ('topic', self.topic))
+      self.run_fetch_task()
     finally:
       main.PUT_SPLITTING_ATTEMPTS = old_splitting_attempts
       main.MAX_FEED_RECORD_SAVES = old_max_saves
@@ -1898,7 +1939,7 @@
         'get', self.topic, 200, '',
         response_headers=self.headers,
         urlfetch_size_error=True)
-    self.handle('post', ('topic', self.topic))
+    self.run_fetch_task()
     self.assertEquals([], list(FeedEntryRecord.all()))
     self.assertEquals(None, EventToDeliver.all().get())
     testutil.get_tasks(main.EVENT_QUEUE, expected_count=0)
@@ -1922,7 +1963,7 @@
         'get', self.topic, 200, self.expected_response,
         response_headers=self.headers)

-    self.handle('post', ('topic', self.topic))
+    self.run_fetch_task()

     # Verify that a subset of the entry records are present and the payload
     # only has the first N entries.
@@ -1963,7 +2004,7 @@
       info.update(self.headers)
       info.put()
       FeedToFetch.insert([self.topic])
-      self.handle('post', ('topic', self.topic))
+      self.run_fetch_task()

       # Verify that *no* feed entry records have been written.
       self.assertEquals([], FeedEntryRecord.get_entries_for_topic(
@@ -1984,6 +2025,15 @@

   handler_class = main.PullFeedHandler

+  def run_fetch_task(self, index=0):
+    """Runs the currently enqueued fetch task."""
+    task = testutil.get_tasks(main.FEED_QUEUE, index=index)
+    os.environ['HTTP_X_APPENGINE_TASKNAME'] = task['name']
+    try:
+      self.handle('post')
+    finally:
+      del os.environ['HTTP_X_APPENGINE_TASKNAME']
+
   def testPullBadContent(self):
     """Tests when the content doesn't parse correctly."""
     topic = 'http://example.com/my-topic'
@@ -1992,7 +2042,8 @@
     FeedToFetch.insert([topic])
     urlfetch_test_stub.instance.expect(
         'get', topic, 200, 'this does not parse')
-    self.handle('post', ('topic', topic))
+    self.run_fetch_task()
+    # No retry task should be written.
     feed = FeedToFetch.get_by_key_name(get_hash_key_name(topic))
     self.assertTrue(feed is None)

@@ -2005,7 +2056,8 @@
     self.assertTrue(Subscription.insert(callback, topic, 'token', 'secret'))
     FeedToFetch.insert([topic])
     urlfetch_test_stub.instance.expect('get', topic, 200, data)
-    self.handle('post', ('topic', topic))
+    self.run_fetch_task()
+    # No retry task should be written.
     feed = FeedToFetch.get_by_key_name(get_hash_key_name(topic))
     self.assertTrue(feed is None)

@@ -2019,7 +2071,8 @@
     self.assertTrue(Subscription.insert(callback, topic, 'token', 'secret'))
     FeedToFetch.insert([topic])
     urlfetch_test_stub.instance.expect('get', topic, 200, data)
-    self.handle('post', ('topic', topic))
+    self.run_fetch_task()
+    # No retry task should be written.
     feed = FeedToFetch.get_by_key_name(get_hash_key_name(topic))
     self.assertTrue(feed is None)

@@ -2032,7 +2085,7 @@
     self.assertTrue(Subscription.insert(callback, topic, 'token', 'secret'))
     FeedToFetch.insert([topic])
     urlfetch_test_stub.instance.expect('get', topic, 200, data)
-    self.handle('post', ('topic', topic))
+    self.run_fetch_task()
     feed = FeedToFetch.get_by_key_name(get_hash_key_name(topic))
     self.assertTrue(feed is None)
     event = EventToDeliver.all().get()
@@ -2056,7 +2109,7 @@
         'ETag': '\xe3\x83\x96\xe3\x83\xad\xe3\x82\xb0\xe8\xa1\x86',
         'Content-Type': 'application/atom+xml',
     })
-    self.handle('post', ('topic', topic))
+    self.run_fetch_task()
     feed = FeedToFetch.get_by_key_name(get_hash_key_name(topic))
     self.assertTrue(feed is None)
     event = EventToDeliver.all().get()
@@ -2078,7 +2131,7 @@
     self.assertTrue(Subscription.insert(callback, topic, 'token', 'secret'))
     FeedToFetch.insert([topic])
     urlfetch_test_stub.instance.expect('get', topic, 200, data)
-    self.handle('post', ('topic', topic))
+    self.run_fetch_task()
     feed = FeedToFetch.get_by_key_name(get_hash_key_name(topic))
     self.assertTrue(feed is None)
     event = EventToDeliver.all().get()
@@ -2098,7 +2151,7 @@
     self.assertTrue(Subscription.insert(callback, topic, 'token', 'secret'))
     FeedToFetch.insert([topic])
     urlfetch_test_stub.instance.expect('get', topic, 200, data)
-    self.handle('post', ('topic', topic))
+    self.run_fetch_task()
     feed = FeedToFetch.get_by_key_name(get_hash_key_name(topic))
     self.assertTrue(feed is None)
     event = EventToDeliver.all().get()
@@ -2107,7 +2160,10 @@
     self.assertEquals('rss', FeedRecord.all().get().format)

   def testMultipleFetch(self):
-    """Tests doing multiple fetches asynchronously in parallel."""
+    """Tests doing multiple fetches asynchronously in parallel.
+
+    Exercises the fork-join queue part of the fetching pipeline.
+    """
     data = ('<?xml version="1.0" encoding="utf-8"?>\n<feed><my header="data"/>'
             '<entry><id>1</id><updated>123</updated>wooh</entry></feed>')
     topic_base = 'http://example.com/my-topic'
@@ -3672,8 +3728,9 @@
     called = [False]
     topics = ['one', 'two', 'three']
     @classmethod
-    def new_insert(cls, topic_list):
+    def new_insert(cls, topic_list, memory_only=True):
       called[0] = True
+      self.assertFalse(memory_only)
       self.assertEquals(topic_list, topics)
       raise db.Error('Mock DB error')

@@ -3716,8 +3773,6 @@
     self.assertTrue(FeedToFetch.get_by_topic(topic3) is None)

     # This will repeatedly insert the initial task to start the polling process.
-    # TODO(bslatkin): This is actually broken. Stub needs to be fixed to ignore
-    # duplicate task names.
     self.handle('get')
     self.handle('get')
     self.handle('get')
@@ -3737,7 +3792,7 @@
     # iterating through all KnownFeed entries or the fork-join queue task that
     # will do the actual fetching.
     self.handle('post', *task['params'].items())
-    task = testutil.get_tasks(main.POLLING_QUEUE, index=1, expected_count=3)
+    task = testutil.get_tasks(main.POLLING_QUEUE, index=1, expected_count=2)
     self.assertEquals(sequence, task['params']['sequence'])
     self.assertEquals('bootstrap', task['params']['poll_type'])
     self.assertEquals(str(KnownFeed.create_key(topic2)),
@@ -3754,7 +3809,7 @@
     # the continuation task to prevent doing any more work in the current cycle.
     self.handle('post', *task['params'].items())

-    task_list = testutil.get_tasks(main.POLLING_QUEUE, expected_count=4)
+    task_list = testutil.get_tasks(main.POLLING_QUEUE, expected_count=3)

     # Deal with a stupid race condition
     task = task_list[2]
@@ -3769,7 +3824,7 @@

     # Starting the cycle again will do nothing.
     self.handle('get')
-    testutil.get_tasks(main.POLLING_QUEUE, expected_count=4)
+    testutil.get_tasks(main.POLLING_QUEUE, expected_count=3)

     # Resetting the next start time to before the present time will
     # cause the iteration to start again.
@@ -3779,13 +3834,8 @@
     db.put(the_mark)
     self.handle('get')

-    task_list = testutil.get_tasks(main.POLLING_QUEUE, expected_count=5)
-
-    # Deal with stupid race condition
-    task = task_list[4]
-    if 'params' not in task:
-      task = task_list[3]
-
+    task_list = testutil.get_tasks(main.POLLING_QUEUE, expected_count=4)
+    task = task_list[3]
     self.assertNotEquals(sequence, task['params']['sequence'])

   def testRecord(self):
