Execute windowing in Fn API runner.

Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/ccc32a25
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/ccc32a25
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/ccc32a25

Branch: refs/heads/master
Commit: ccc32a25fc139135b13c5fd14353377c4343e403
Parents: a92c45f
Author: Robert Bradshaw <[email protected]>
Authored: Thu Sep 14 00:30:10 2017 -0700
Committer: Robert Bradshaw <[email protected]>
Committed: Thu Sep 21 10:59:52 2017 -0700

----------------------------------------------------------------------
 .../runners/portability/fn_api_runner.py        | 52 ++++++++++++++------
 sdks/python/apache_beam/transforms/core.py      | 16 ++----
 sdks/python/apache_beam/transforms/trigger.py   | 13 +++++
 3 files changed, 52 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/ccc32a25/sdks/python/apache_beam/runners/portability/fn_api_runner.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index b0faa38..74bae11 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -46,6 +46,7 @@ from apache_beam.runners.worker import bundle_processor
 from apache_beam.runners.worker import data_plane
 from apache_beam.runners.worker import operation_specs
 from apache_beam.runners.worker import sdk_worker
+from apache_beam.transforms import trigger
 from apache_beam.transforms.window import GlobalWindows
 from apache_beam.utils import proto_utils
 from apache_beam.utils import urns
@@ -126,25 +127,31 @@ OLDE_SOURCE_SPLITTABLE_DOFN_DATA = pickler.dumps(
 
 class _GroupingBuffer(object):
   """Used to accumulate groupded (shuffled) results."""
-  def __init__(self, pre_grouped_coder, post_grouped_coder):
+  def __init__(self, pre_grouped_coder, post_grouped_coder, windowing):
     self._key_coder = pre_grouped_coder.key_coder()
     self._pre_grouped_coder = pre_grouped_coder
     self._post_grouped_coder = post_grouped_coder
     self._table = collections.defaultdict(list)
+    self._windowing = windowing
 
   def append(self, elements_data):
     input_stream = create_InputStream(elements_data)
     while input_stream.size() > 0:
-      key, value = self._pre_grouped_coder.get_impl().decode_from_stream(
-          input_stream, True).value
-      self._table[self._key_coder.encode(key)].append(value)
+      windowed_key_value = self._pre_grouped_coder.get_impl(
+          ).decode_from_stream(input_stream, True)
+      key = windowed_key_value.value[0]
+      windowed_value = windowed_key_value.with_value(
+          windowed_key_value.value[1])
+      self._table[self._key_coder.encode(key)].append(windowed_value)
 
   def __iter__(self):
     output_stream = create_OutputStream()
-    for encoded_key, values in self._table.items():
+    trigger_driver = trigger.create_trigger_driver(self._windowing, True)
+    for encoded_key, windowed_values in self._table.items():
       key = self._key_coder.decode(encoded_key)
-      self._post_grouped_coder.get_impl().encode_to_stream(
-          GlobalWindows.windowed_value((key, values)), output_stream, True)
+      for wkvs in trigger_driver.process_entire_key(key, windowed_values):
+        self._post_grouped_coder.get_impl().encode_to_stream(
+            wkvs, output_stream, True)
     return iter([output_stream.get()])
 
 
@@ -326,7 +333,7 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
       for stage in stages:
         assert len(stage.transforms) == 1
         transform = stage.transforms[0]
-        if transform.spec.urn == urns.GROUP_BY_KEY_ONLY_TRANSFORM:
+        if transform.spec.urn == urns.GROUP_BY_KEY_TRANSFORM:
           for pcoll_id in transform.inputs.values():
             fix_pcoll_coder(pipeline_components.pcollections[pcoll_id])
           for pcoll_id in transform.outputs.values():
@@ -608,11 +615,21 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
         pcoll.coder_id = coders.get_id(coder)
     coders.populate_map(pipeline_components.coders)
 
-    # Initial set of stages are singleton transforms.
+    known_composites = set([urns.GROUP_BY_KEY_TRANSFORM])
+
+    def leaf_transforms(root_ids):
+      for root_id in root_ids:
+        root = pipeline_proto.components.transforms[root_id]
+        if root.spec.urn in known_composites or not root.subtransforms:
+          yield root_id
+        else:
+          for leaf in leaf_transforms(root.subtransforms):
+            yield leaf
+
+    # Initial set of stages are singleton leaf transforms.
     stages = [
-        Stage(name, [transform])
-        for name, transform in pipeline_proto.components.transforms.items()
-        if not transform.subtransforms]
+        Stage(name, [pipeline_proto.components.transforms[name]])
+        for name in leaf_transforms(pipeline_proto.root_transform_ids)]
 
     # Apply each phase in order.
     for phase in [
@@ -645,7 +662,7 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
   def run_stage(
       self, controller, pipeline_components, stage, pcoll_buffers, safe_coders):
 
-    coders = pipeline_context.PipelineContext(pipeline_components).coders
+    context = pipeline_context.PipelineContext(pipeline_components)
     data_operation_spec = controller.data_operation_spec()
 
     def extract_endpoints(stage):
@@ -744,12 +761,15 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
                 original_gbk_transform]
             input_pcoll = only_element(transform_proto.inputs.values())
             output_pcoll = only_element(transform_proto.outputs.values())
-            pre_gbk_coder = coders[safe_coders[
+            pre_gbk_coder = context.coders[safe_coders[
                 pipeline_components.pcollections[input_pcoll].coder_id]]
-            post_gbk_coder = coders[safe_coders[
+            post_gbk_coder = context.coders[safe_coders[
                 pipeline_components.pcollections[output_pcoll].coder_id]]
+            windowing_strategy = context.windowing_strategies[
+                pipeline_components
+                .pcollections[output_pcoll].windowing_strategy_id]
             pcoll_buffers[pcoll_id] = _GroupingBuffer(
-                pre_gbk_coder, post_gbk_coder)
+                pre_gbk_coder, post_gbk_coder, windowing_strategy)
           pcoll_buffers[pcoll_id].append(output.data)
         else:
           # These should be the only two identifiers we produce for now,

http://git-wip-us.apache.org/repos/asf/beam/blob/ccc32a25/sdks/python/apache_beam/transforms/core.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index ceaa60a..0a82de2 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -38,7 +38,6 @@ from apache_beam.transforms.display import DisplayDataItem
 from apache_beam.transforms.display import HasDisplayData
 from apache_beam.transforms.ptransform import PTransform
 from apache_beam.transforms.ptransform import PTransformWithSideInputs
-from apache_beam.transforms.window import MIN_TIMESTAMP
 from apache_beam.transforms.window import GlobalWindows
 from apache_beam.transforms.window import TimestampCombiner
 from apache_beam.transforms.window import TimestampedValue
@@ -1186,6 +1185,8 @@ class GroupByKey(PTransform):
       # Initialize type-hints used below to enforce type-checking and to pass
       # downstream to further PTransforms.
       key_type, value_type = trivial_inference.key_value_types(input_type)
+      # Enforce the input to a GBK has a KV element type.
+      pcoll.element_type = KV[key_type, value_type]
       typecoders.registry.verify_deterministic(
           typecoders.registry.get_coder(key_type),
           'GroupByKey operation "%s"' % self.label)
@@ -1281,24 +1282,13 @@ class _GroupAlsoByWindowDoFn(DoFn):
 
   def start_bundle(self):
     # pylint: disable=wrong-import-order, wrong-import-position
-    from apache_beam.transforms.trigger import InMemoryUnmergedState
     from apache_beam.transforms.trigger import create_trigger_driver
     # pylint: enable=wrong-import-order, wrong-import-position
     self.driver = create_trigger_driver(self.windowing, True)
-    self.state_type = InMemoryUnmergedState
 
   def process(self, element):
     k, vs = element
-    state = self.state_type()
-    # TODO(robertwb): Conditionally process in smaller chunks.
-    for wvalue in self.driver.process_elements(state, vs, MIN_TIMESTAMP):
-      yield wvalue.with_value((k, wvalue.value))
-    while state.timers:
-      fired = state.get_and_clear_timers()
-      for timer_window, (name, time_domain, fire_time) in fired:
-        for wvalue in self.driver.process_timer(
-            timer_window, name, time_domain, fire_time, state):
-          yield wvalue.with_value((k, wvalue.value))
+    return self.driver.process_entire_key(k, vs)
 
 
 class Partition(PTransformWithSideInputs):

http://git-wip-us.apache.org/repos/asf/beam/blob/ccc32a25/sdks/python/apache_beam/transforms/trigger.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py
index 8175d30..3583e62 100644
--- a/sdks/python/apache_beam/transforms/trigger.py
+++ b/sdks/python/apache_beam/transforms/trigger.py
@@ -859,6 +859,19 @@ class TriggerDriver(object):
   def process_timer(self, window_id, name, time_domain, timestamp, state):
     pass
 
+  def process_entire_key(
+      self, key, windowed_values, output_watermark=MIN_TIMESTAMP):
+    state = InMemoryUnmergedState()
+    for wvalue in self.process_elements(
+        state, windowed_values, output_watermark):
+      yield wvalue.with_value((key, wvalue.value))
+    while state.timers:
+      fired = state.get_and_clear_timers()
+      for timer_window, (name, time_domain, fire_time) in fired:
+        for wvalue in self.process_timer(
+            timer_window, name, time_domain, fire_time, state):
+          yield wvalue.with_value((key, wvalue.value))
+
 
 class _UnwindowedValues(observable.ObservableMixin):
   """Exposes iterable of windowed values as iterable of unwindowed values."""

Reply via email to