Execute windowing in Fn API runner.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/ccc32a25 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/ccc32a25 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/ccc32a25 Branch: refs/heads/master Commit: ccc32a25fc139135b13c5fd14353377c4343e403 Parents: a92c45f Author: Robert Bradshaw <[email protected]> Authored: Thu Sep 14 00:30:10 2017 -0700 Committer: Robert Bradshaw <[email protected]> Committed: Thu Sep 21 10:59:52 2017 -0700 ---------------------------------------------------------------------- .../runners/portability/fn_api_runner.py | 52 ++++++++++++++------ sdks/python/apache_beam/transforms/core.py | 16 ++---- sdks/python/apache_beam/transforms/trigger.py | 13 +++++ 3 files changed, 52 insertions(+), 29 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/ccc32a25/sdks/python/apache_beam/runners/portability/fn_api_runner.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index b0faa38..74bae11 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -46,6 +46,7 @@ from apache_beam.runners.worker import bundle_processor from apache_beam.runners.worker import data_plane from apache_beam.runners.worker import operation_specs from apache_beam.runners.worker import sdk_worker +from apache_beam.transforms import trigger from apache_beam.transforms.window import GlobalWindows from apache_beam.utils import proto_utils from apache_beam.utils import urns @@ -126,25 +127,31 @@ OLDE_SOURCE_SPLITTABLE_DOFN_DATA = pickler.dumps( class _GroupingBuffer(object): """Used to accumulate groupded (shuffled) results.""" - def __init__(self, pre_grouped_coder, post_grouped_coder): + def __init__(self, pre_grouped_coder, post_grouped_coder, windowing): self._key_coder = pre_grouped_coder.key_coder() self._pre_grouped_coder = pre_grouped_coder self._post_grouped_coder = post_grouped_coder self._table = collections.defaultdict(list) + self._windowing = windowing def append(self, elements_data): input_stream = create_InputStream(elements_data) while input_stream.size() > 0: - key, value = self._pre_grouped_coder.get_impl().decode_from_stream( - input_stream, True).value - self._table[self._key_coder.encode(key)].append(value) + windowed_key_value = self._pre_grouped_coder.get_impl( + ).decode_from_stream(input_stream, True) + key = windowed_key_value.value[0] + windowed_value = windowed_key_value.with_value( + windowed_key_value.value[1]) + self._table[self._key_coder.encode(key)].append(windowed_value) def __iter__(self): output_stream = create_OutputStream() - for encoded_key, values in self._table.items(): + trigger_driver = trigger.create_trigger_driver(self._windowing, True) + for encoded_key, windowed_values in self._table.items(): key = self._key_coder.decode(encoded_key) - self._post_grouped_coder.get_impl().encode_to_stream( - GlobalWindows.windowed_value((key, values)), output_stream, True) + for wkvs in trigger_driver.process_entire_key(key, windowed_values): + self._post_grouped_coder.get_impl().encode_to_stream( + wkvs, output_stream, True) return iter([output_stream.get()]) @@ -326,7 +333,7 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner): for stage in stages: assert len(stage.transforms) == 1 transform = stage.transforms[0] - if transform.spec.urn == urns.GROUP_BY_KEY_ONLY_TRANSFORM: + if transform.spec.urn == urns.GROUP_BY_KEY_TRANSFORM: for pcoll_id in transform.inputs.values(): fix_pcoll_coder(pipeline_components.pcollections[pcoll_id]) for pcoll_id in transform.outputs.values(): @@ -608,11 +615,21 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner): pcoll.coder_id = coders.get_id(coder) coders.populate_map(pipeline_components.coders) - # Initial set of stages are singleton transforms. + known_composites = set([urns.GROUP_BY_KEY_TRANSFORM]) + + def leaf_transforms(root_ids): + for root_id in root_ids: + root = pipeline_proto.components.transforms[root_id] + if root.spec.urn in known_composites or not root.subtransforms: + yield root_id + else: + for leaf in leaf_transforms(root.subtransforms): + yield leaf + + # Initial set of stages are singleton leaf transforms. stages = [ - Stage(name, [transform]) - for name, transform in pipeline_proto.components.transforms.items() - if not transform.subtransforms] + Stage(name, [pipeline_proto.components.transforms[name]]) + for name in leaf_transforms(pipeline_proto.root_transform_ids)] # Apply each phase in order. for phase in [ @@ -645,7 +662,7 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner): def run_stage( self, controller, pipeline_components, stage, pcoll_buffers, safe_coders): - coders = pipeline_context.PipelineContext(pipeline_components).coders + context = pipeline_context.PipelineContext(pipeline_components) data_operation_spec = controller.data_operation_spec() def extract_endpoints(stage): @@ -744,12 +761,15 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner): original_gbk_transform] input_pcoll = only_element(transform_proto.inputs.values()) output_pcoll = only_element(transform_proto.outputs.values()) - pre_gbk_coder = coders[safe_coders[ + pre_gbk_coder = context.coders[safe_coders[ pipeline_components.pcollections[input_pcoll].coder_id]] - post_gbk_coder = coders[safe_coders[ + post_gbk_coder = context.coders[safe_coders[ pipeline_components.pcollections[output_pcoll].coder_id]] + windowing_strategy = context.windowing_strategies[ + pipeline_components + .pcollections[output_pcoll].windowing_strategy_id] pcoll_buffers[pcoll_id] = _GroupingBuffer( - pre_gbk_coder, post_gbk_coder) + pre_gbk_coder, post_gbk_coder, windowing_strategy) pcoll_buffers[pcoll_id].append(output.data) else: # These should be the only two identifiers we produce for now, http://git-wip-us.apache.org/repos/asf/beam/blob/ccc32a25/sdks/python/apache_beam/transforms/core.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index ceaa60a..0a82de2 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -38,7 +38,6 @@ from apache_beam.transforms.display import DisplayDataItem from apache_beam.transforms.display import HasDisplayData from apache_beam.transforms.ptransform import PTransform from apache_beam.transforms.ptransform import PTransformWithSideInputs -from apache_beam.transforms.window import MIN_TIMESTAMP from apache_beam.transforms.window import GlobalWindows from apache_beam.transforms.window import TimestampCombiner from apache_beam.transforms.window import TimestampedValue @@ -1186,6 +1185,8 @@ class GroupByKey(PTransform): # Initialize type-hints used below to enforce type-checking and to pass # downstream to further PTransforms. key_type, value_type = trivial_inference.key_value_types(input_type) + # Enforce the input to a GBK has a KV element type. + pcoll.element_type = KV[key_type, value_type] typecoders.registry.verify_deterministic( typecoders.registry.get_coder(key_type), 'GroupByKey operation "%s"' % self.label) @@ -1281,24 +1282,13 @@ class _GroupAlsoByWindowDoFn(DoFn): def start_bundle(self): # pylint: disable=wrong-import-order, wrong-import-position - from apache_beam.transforms.trigger import InMemoryUnmergedState from apache_beam.transforms.trigger import create_trigger_driver # pylint: enable=wrong-import-order, wrong-import-position self.driver = create_trigger_driver(self.windowing, True) - self.state_type = InMemoryUnmergedState def process(self, element): k, vs = element - state = self.state_type() - # TODO(robertwb): Conditionally process in smaller chunks. - for wvalue in self.driver.process_elements(state, vs, MIN_TIMESTAMP): - yield wvalue.with_value((k, wvalue.value)) - while state.timers: - fired = state.get_and_clear_timers() - for timer_window, (name, time_domain, fire_time) in fired: - for wvalue in self.driver.process_timer( - timer_window, name, time_domain, fire_time, state): - yield wvalue.with_value((k, wvalue.value)) + return self.driver.process_entire_key(k, vs) class Partition(PTransformWithSideInputs): http://git-wip-us.apache.org/repos/asf/beam/blob/ccc32a25/sdks/python/apache_beam/transforms/trigger.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py index 8175d30..3583e62 100644 --- a/sdks/python/apache_beam/transforms/trigger.py +++ b/sdks/python/apache_beam/transforms/trigger.py @@ -859,6 +859,19 @@ class TriggerDriver(object): def process_timer(self, window_id, name, time_domain, timestamp, state): pass + def process_entire_key( + self, key, windowed_values, output_watermark=MIN_TIMESTAMP): + state = InMemoryUnmergedState() + for wvalue in self.process_elements( + state, windowed_values, output_watermark): + yield wvalue.with_value((key, wvalue.value)) + while state.timers: + fired = state.get_and_clear_timers() + for timer_window, (name, time_domain, fire_time) in fired: + for wvalue in self.process_timer( + timer_window, name, time_domain, fire_time, state): + yield wvalue.with_value((key, wvalue.value)) + class _UnwindowedValues(observable.ObservableMixin): """Exposes iterable of windowed values as iterable of unwindowed values."""
