[ https://issues.apache.org/jira/browse/BEAM-4546?focusedWorklogId=111794&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-111794 ]

ASF GitHub Bot logged work on BEAM-4546:
----------------------------------------

                Author: ASF GitHub Bot
            Created on: 14/Jun/18 06:25
            Start Date: 14/Jun/18 06:25
    Worklog Time Spent: 10m 
      Work Description: robertwb closed pull request #5639: [BEAM-4546] Multi level combine
URL: https://github.com/apache/beam/pull/5639
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/sdks/python/apache_beam/transforms/combiners_test.py b/sdks/python/apache_beam/transforms/combiners_test.py
index 705106e784c..afba5558620 100644
--- a/sdks/python/apache_beam/transforms/combiners_test.py
+++ b/sdks/python/apache_beam/transforms/combiners_test.py
@@ -17,6 +17,8 @@
 
 """Unit tests for our libraries of combine PTransforms."""
 
+import itertools
+import random
 import unittest
 
 import hamcrest as hc
@@ -311,6 +313,29 @@ def expand(self, pcoll):
       assert_that(result1, equal_to([0]), label='r1')
       assert_that(result2, equal_to([10]), label='r2')
 
+  def test_hot_key_fanout(self):
+    with TestPipeline() as p:
+      result = (
+          p
+          | beam.Create(itertools.product(['hot', 'cold'], range(10)))
+          | beam.CombinePerKey(combine.MeanCombineFn()).with_hot_key_fanout(
+              lambda key: (key == 'hot') * 5))
+      assert_that(result, equal_to([('hot', 4.5), ('cold', 4.5)]))
+
+  def test_hot_key_fanout_sharded(self):
+    # Lots of elements with the same key with varying/no fanout.
+    with TestPipeline() as p:
+      elements = [(None, e) for e in range(1000)]
+      random.shuffle(elements)
+      shards = [p | "Shard%s" % shard >> beam.Create(elements[shard::20])
+                for shard in range(20)]
+      result = (
+          shards
+          | beam.Flatten()
+          | beam.CombinePerKey(combine.MeanCombineFn()).with_hot_key_fanout(
+              lambda key: random.randrange(0, 5)))
+      assert_that(result, equal_to([(None, 499.5)]))
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 2850928099d..f02609ad2c9 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -21,6 +21,7 @@
 
 import copy
 import inspect
+import random
 import types
 
 from six import string_types
@@ -1089,6 +1090,7 @@ class CombineGlobally(PTransform):
   """
   has_defaults = True
   as_view = False
+  fanout = None
 
   def __init__(self, fn, *args, **kwargs):
     if not (isinstance(fn, CombineFn) or callable(fn)):
@@ -1115,6 +1117,9 @@ def _clone(self, **extra_attributes):
     clone.__dict__.update(extra_attributes)
     return clone
 
+  def with_fanout(self, fanout):
+    return self._clone(fanout=fanout)
+
   def with_defaults(self, has_defaults=True):
     return self._clone(has_defaults=has_defaults)
 
@@ -1131,12 +1136,15 @@ def add_input_types(transform):
         return transform.with_input_types(type_hints.input_types[0][0])
       return transform
 
+    combine_per_key = CombinePerKey(self.fn, *self.args, **self.kwargs)
+    if self.fanout:
+      combine_per_key = combine_per_key.with_hot_key_fanout(self.fanout)
+
     combined = (pcoll
                 | 'KeyWithVoid' >> add_input_types(
                     Map(lambda v: (None, v)).with_output_types(
                         KV[None, pcoll.element_type]))
-                | 'CombinePerKey' >> CombinePerKey(
-                    self.fn, *self.args, **self.kwargs)
+                | 'CombinePerKey' >> combine_per_key
                 | 'UnKey' >> Map(lambda k_v: k_v[1]))
 
     if not self.has_defaults and not self.as_view:
@@ -1191,6 +1199,34 @@ class CombinePerKey(PTransformWithSideInputs):
   Returns:
     A PObject holding the result of the combine operation.
   """
+  def with_hot_key_fanout(self, fanout):
+    """A per-key combine operation like self but with two levels of 
aggregation.
+
+    If a given key is produced by too many upstream bundles, the final
+    reduction can become a bottleneck despite partial combining being lifted
+    pre-GroupByKey.  In these cases it can be helpful to perform intermediate
+    partial aggregations in parallel and then re-group to perform a final
+    (per-key) combine.  This is also useful for high-volume keys in streaming
+    where combiners are not generally lifted for latency reasons.
+
+    Note that a fanout greater than 1 requires the data to be sent through
+    two GroupByKeys, and a high fanout can also result in more shuffle data
+    due to less per-bundle combining. Setting the fanout for a key to 1 or
+    less places its values on the "cold key" path, which skips the
+    intermediate level of aggregation.
+
+    Args:
+      fanout: either an int, for a constant-degree fanout, or a callable
+          mapping keys to a key-specific degree of fanout
+
+    Returns:
+      A per-key combining PTransform with the specified fanout.
+    """
+    from apache_beam.transforms.combiners import curry_combine_fn
+    return _CombinePerKeyWithHotKeyFanout(
+        curry_combine_fn(self.fn, self.args, self.kwargs),
+        fanout)
+
   def display_data(self):
     return {'combine_fn':
             DisplayDataItem(self.fn.__class__, label='Combine Function'),
@@ -1337,6 +1373,78 @@ def default_type_hints(self):
     return hints
 
 
+class _CombinePerKeyWithHotKeyFanout(PTransform):
+
+  def __init__(self, combine_fn, fanout):
+    self._fanout_fn = (
+        (lambda key: fanout) if isinstance(fanout, int) else fanout)
+    self._combine_fn = combine_fn
+
+  def expand(self, pcoll):
+
+    from apache_beam.transforms.trigger import AccumulationMode
+    combine_fn = self._combine_fn
+    fanout_fn = self._fanout_fn
+
+    class SplitHotCold(DoFn):
+
+      def start_bundle(self):
+        self.counter = random.randrange(10000)
+
+      def process(self, element):
+        key, value = element
+        fanout = fanout_fn(key)
+        if fanout <= 1:
+          # Boolean indicates this is not an accumulator.
+          yield pvalue.TaggedOutput('cold', (key, (False, value)))
+        else:
+          # Round-robin should spread things more evenly than random assignment.
+          self.counter += 1.
+          yield pvalue.TaggedOutput('hot',
+                                    ((self.counter % fanout, key), value))
+
+    class PreCombineFn(CombineFn):
+      @staticmethod
+      def extract_output(accumulator):
+        # Boolean indicates this is an accumulator.
+        return (True, accumulator)
+      create_accumulator = combine_fn.create_accumulator
+      add_input = combine_fn.add_input
+      merge_accumulators = combine_fn.merge_accumulators
+
+    class PostCombineFn(CombineFn):
+      @staticmethod
+      def add_input(accumulator, input):
+        is_accumulator, value = input
+        if is_accumulator:
+          return combine_fn.merge_accumulators([accumulator, value])
+        else:
+          return combine_fn.add_input(accumulator, value)
+      create_accumulator = combine_fn.create_accumulator
+      merge_accumulators = combine_fn.merge_accumulators
+      extract_output = combine_fn.extract_output
+
+    def StripNonce(nonce_key_value):
+      (_, key), value = nonce_key_value
+      return key, value
+
+    hot, cold = pcoll | ParDo(SplitHotCold()).with_outputs('hot', 'cold')
+    # No multi-output type hints.
+    cold.element_type = typehints.Any
+    precombined_hot = (
+        hot
+        # Avoid double counting that may happen with stacked accumulating mode.
+        | 'WindowIntoDiscarding' >> WindowInto(
+            pcoll.windowing, accumulation_mode=AccumulationMode.DISCARDING)
+        | CombinePerKey(PreCombineFn())
+        | Map(StripNonce)
+        | 'WindowIntoOriginal' >> WindowInto(pcoll.windowing))
+    return (
+        (cold, precombined_hot)
+        | Flatten()
+        | CombinePerKey(PostCombineFn()))
+
+
 @typehints.with_input_types(typehints.KV[K, V])
 @typehints.with_output_types(typehints.KV[K, typehints.Iterable[V]])
 class GroupByKey(PTransform):
@@ -1614,14 +1722,23 @@ def __init__(self, windowfn, **kwargs):
       timestamp_combiner: (optional) Timestamp combiner used for windowing,
           or None for default.
     """
+    if isinstance(windowfn, Windowing):
+      # Overlay windowing with kwargs.
+      windowing = windowfn
+      windowfn = windowing.windowfn
+      # Use windowing to fill in defaults for kwargs.
+      kwargs = dict(dict(
+          trigger=windowing.triggerfn,
+          accumulation_mode=windowing.accumulation_mode,
+          timestamp_combiner=windowing.timestamp_combiner), **kwargs)
     # Use kwargs to simulate keyword-only arguments.
     triggerfn = kwargs.pop('trigger', None)
     accumulation_mode = kwargs.pop('accumulation_mode', None)
     timestamp_combiner = kwargs.pop('timestamp_combiner', None)
     if kwargs:
       raise ValueError('Unexpected keyword arguments: %s' % kwargs.keys())
-    self.windowing = Windowing(windowfn, triggerfn, accumulation_mode,
-                               timestamp_combiner)
+    self.windowing = Windowing(
+        windowfn, triggerfn, accumulation_mode, timestamp_combiner)
     super(WindowInto, self).__init__(self.WindowIntoFn(self.windowing))
 
   def get_windowing(self, unused_inputs):
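
For reference, a minimal usage sketch of the API added above (a hedged
illustration, not part of the PR; the pipeline, labels, and data are made up,
and combiners is imported as combine in the same style as combiners_test.py):

  import apache_beam as beam
  from apache_beam.transforms import combiners as combine

  with beam.Pipeline() as p:
    # Per-key combine, with extra fanout only for the known-hot key.
    means = (
        p
        | 'PerKeyInput' >> beam.Create(
            [('hot', i) for i in range(100)] + [('cold', 0)])
        | beam.CombinePerKey(combine.MeanCombineFn()).with_hot_key_fanout(
            lambda key: 10 if key == 'hot' else 0))

    # Global combine routed through the same two-level machinery via
    # CombineGlobally.with_fanout, also added in this diff.
    total = (
        p
        | 'GlobalInput' >> beam.Create(range(1000))
        | beam.CombineGlobally(sum).with_fanout(16))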


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Issue Time Tracking
-------------------

    Worklog Id:     (was: 111794)
    Time Spent: 1.5h  (was: 1h 20m)

> Implement with hot key fanout for combiners
> -------------------------------------------
>
>                 Key: BEAM-4546
>                 URL: https://issues.apache.org/jira/browse/BEAM-4546
>             Project: Beam
>          Issue Type: New Feature
>          Components: sdk-py-core
>            Reporter: Ahmet Altay
>            Assignee: Robert Bradshaw
>            Priority: Major
>          Time Spent: 1.5h
>  Remaining Estimate: 0h
>




--
This message was sent by Atlassian JIRA
(v7.6.3#76005)
