Repository: incubator-beam
Updated Branches:
  refs/heads/python-sdk 53ab635c7 -> b4716d9dc


Better top implementation.

When selecting the top k of n, it is common that k << n.
Using a heap is O(n log k) while select algorithms can
achieve O(n + k log k).

This also sidesteps the limitation that heapq does not accept a
comparator argument, which had forced the use of _HeapItem wrapper
objects that were cumbersome and expensive to serialize.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/adb3ed93
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/adb3ed93
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/adb3ed93

Branch: refs/heads/python-sdk
Commit: adb3ed93053c83b4e28e7baa879e9aee82f02785
Parents: 53ab635
Author: Robert Bradshaw <rober...@gmail.com>
Authored: Wed Jul 27 10:09:49 2016 -0700
Committer: Robert Bradshaw <rober...@gmail.com>
Committed: Thu Jul 28 11:08:05 2016 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/transforms/combiners.py | 111 +++++++++----------
 1 file changed, 51 insertions(+), 60 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adb3ed93/sdks/python/apache_beam/transforms/combiners.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py
index 8c56e5a..453c0f8 100644
--- a/sdks/python/apache_beam/transforms/combiners.py
+++ b/sdks/python/apache_beam/transforms/combiners.py
@@ -228,76 +228,67 @@ class TopCombineFn(core.CombineFn):
   apply call become additional arguments to the comparator.
   """
 
-  # Actually pickling the comparison operators (including, often, their
-  # entire globals) can be very expensive.  Instead refer to them by index
-  # in this dictionary, which is populated on construction (including
-  # unpickling).
-  compare_by_id = {}
-
-  def __init__(self, n, compare, _compare_id=None):  # pylint: disable=invalid-name
+  # TODO(robertwb): Allow taking a key rather than a compare.
+  def __init__(self, n, compare):
     self._n = n
+    self._buffer_size = min(2 * n, n + 1000)
     self._compare = compare
-    self._compare_id = _compare_id or id(compare)
-    TopCombineFn.compare_by_id[self._compare_id] = self._compare
-
-  def __reduce_ex__(self, _):
-    return TopCombineFn, (self._n, self._compare, self._compare_id)
 
-  class _HeapItem(object):
-    """A wrapper for values supporting arbitrary comparisons.
-
-    The heap implementation supplied by Python is a min heap that always uses
-    the __lt__ operator if one is available. This wrapper overloads __lt__,
-    letting us specify arbitrary precedence for elements in the PCollection.
-    """
+  def create_accumulator(self, *args, **kwargs):
+    return None, []
 
-    def __init__(self, item, compare_id, *args, **kwargs):
-      # item:         wrapped item.
-      # compare:      an implementation of the pairwise < operator.
-      # args, kwargs: extra arguments supplied to the compare function.
-      self.item = item
-      self.compare_id = compare_id
-      self.args = args
-      self.kwargs = kwargs
+  def add_input(self, accumulator, element, *args, **kwargs):
+    if args or kwargs:
+      lt = lambda a, b: self._compare(a, b, *args, **kwargs)
+    else:
+      lt = self._compare
 
-    def __lt__(self, other):
-      return TopCombineFn.compare_by_id[self.compare_id](
-          self.item, other.item, *self.args, **self.kwargs)
+    threshold, buffer = accumulator
+    if len(buffer) < self._n:
+      if not buffer:
+        return element, [element]
+      else:
+        buffer.append(element)
+        if lt(element, threshold):  # element < threshold
+          return element, buffer
+        else:
+          return accumulator  # with mutated buffer
+    elif lt(threshold, element):  # threshold < element
+      buffer.append(element)
+      if len(buffer) < self._buffer_size:
+        return accumulator
+      else:
+        buffer.sort(cmp=lambda a, b: (not lt(a, b)) - (not lt(b, a)))
+        return buffer[-self._n], buffer[-self._n:]
+    else:
+      return accumulator
 
-  def create_accumulator(self, *args, **kwargs):
-    return []  # Empty heap.
-
-  def add_input(self, heap, element, *args, **kwargs):
-    # Note that because heap is a min heap, heappushpop will discard incoming
-    # elements that are lesser (according to compare) than those in the heap
-    # (since that's what you would get if you pushed a small element on and
-    # popped the smallest element off). So, filtering a collection with a
-    # min-heap gives you the largest elements in the collection.
-    item = self._HeapItem(element, self._compare_id, *args, **kwargs)
-    if len(heap) < self._n:
-      heapq.heappush(heap, item)
+  def merge_accumulators(self, accumulators, *args, **kwargs):
+    accumulators = list(accumulators)
+    if args or kwargs:
+      add_input = lambda accumulator, element: self.add_input(
+          accumulator, element, *args, **kwargs)
     else:
-      heapq.heappushpop(heap, item)
-    return heap
-
-  def merge_accumulators(self, heaps, *args, **kwargs):
-    heap = []
-    for e in itertools.chain(*heaps):
-      if len(heap) < self._n:
-        heapq.heappush(heap, e)
-      else:
-        heapq.heappushpop(heap, e)
-    return heap
+      add_input = self.add_input
 
-  def extract_output(self, heap, *args, **kwargs):
-    # Items in the heap are heap-ordered. We put them in sorted order, but we
-    # have to use the reverse order because the result is expected to go
-    # from greatest to least (as defined by the supplied comparison function).
-    return [e.item for e in sorted(heap, reverse=True)]
+    total_accumulator = None
+    for accumulator in accumulators:
+      if total_accumulator is None:
+        total_accumulator = accumulator
+      else:
+        for element in accumulator[1]:
+          total_accumulator = add_input(total_accumulator, element)
+    return total_accumulator
 
+  def extract_output(self, accumulator, *args, **kwargs):
+    if args or kwargs:
+      lt = lambda a, b: self._compare(a, b, *args, **kwargs)
+    else:
+      lt = self._compare
 
-# Python's pickling is broken for nested classes.
-_HeapItem = TopCombineFn._HeapItem  # pylint: disable=protected-access
+    _, buffer = accumulator
+    buffer.sort(cmp=lambda a, b: (not lt(a, b)) - (not lt(b, a)))
+    return buffer[:-self._n-1:-1]
 
 
 class Largest(TopCombineFn):

Reply via email to