Repository: beam Updated Branches: refs/heads/master 3961ce46c -> 26c61f414
[BEAM-1316] start bundle should not output any elements Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/1360c21d Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/1360c21d Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/1360c21d Branch: refs/heads/master Commit: 1360c21dfbe6131d514a96e2d5290aa2308825de Parents: 3961ce4 Author: Sourabh Bajaj <[email protected]> Authored: Wed Apr 26 15:09:03 2017 -0700 Committer: Ahmet Altay <[email protected]> Committed: Wed Apr 26 18:07:05 2017 -0700 ---------------------------------------------------------------------- sdks/python/apache_beam/runners/common.py | 68 ++++++++++++++------ .../apache_beam/transforms/ptransform_test.py | 44 +++++++++---- 2 files changed, 81 insertions(+), 31 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/1360c21d/sdks/python/apache_beam/runners/common.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index 08071a6..e2a6949 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -161,8 +161,8 @@ class DoFnInvoker(object): defaults = self.signature.start_bundle_method.defaults args = [self.context if d == core.DoFn.ContextParam else d for d in defaults] - self.output_processor.process_outputs( - None, self.signature.start_bundle_method.method_value(*args)) + self.output_processor.start_bundle_outputs( + self.signature.start_bundle_method.method_value(*args)) def invoke_finish_bundle(self): """Invokes the DoFn.finish_bundle() method. @@ -170,8 +170,8 @@ class DoFnInvoker(object): defaults = self.signature.finish_bundle_method.defaults args = [self.context if d == core.DoFn.ContextParam else d for d in defaults] - self.output_processor.process_outputs( - None, self.signature.finish_bundle_method.method_value(*args)) + self.output_processor.finish_bundle_outputs( + self.signature.finish_bundle_method.method_value(*args)) class SimpleInvoker(DoFnInvoker): @@ -436,13 +436,14 @@ class OutputProcessor(object): self.tagged_receivers = tagged_receivers def process_outputs(self, windowed_input_element, results): - """Dispatch the result of computation to the appropriate receivers. + """Dispatch the result of process computation to the appropriate receivers. A value wrapped in a OutputValue object will be unwrapped and then dispatched to the appropriate indexed output. """ if results is None: return + for result in results: tag = None if isinstance(result, OutputValue): @@ -455,19 +456,6 @@ class OutputProcessor(object): if (windowed_input_element is not None and len(windowed_input_element.windows) != 1): windowed_value.windows *= len(windowed_input_element.windows) - elif windowed_input_element is None: - # Start and finish have no element from which to grab context, - # but may emit elements. - if isinstance(result, TimestampedValue): - value = result.value - timestamp = result.timestamp - assign_context = NoContext(value, timestamp) - else: - value = result - timestamp = -1 - assign_context = NoContext(value) - windowed_value = WindowedValue( - value, timestamp, self.window_fn.assign(assign_context)) elif isinstance(result, TimestampedValue): assign_context = WindowFn.AssignContext(result.timestamp, result.value) windowed_value = WindowedValue( @@ -482,6 +470,50 @@ class OutputProcessor(object): else: self.tagged_receivers[tag].output(windowed_value) + def start_bundle_outputs(self, results): + """Validate that start_bundle does not output any elements""" + if results is None: + return + raise RuntimeError( + 'Start Bundle should not output any elements but got %s' % results) + + def finish_bundle_outputs(self, results): + """Dispatch the result of finish_bundle to the appropriate receivers. + + A value wrapped in a OutputValue object will be unwrapped and + then dispatched to the appropriate indexed output. + """ + if results is None: + return + + for result in results: + tag = None + if isinstance(result, OutputValue): + tag = result.tag + if not isinstance(tag, basestring): + raise TypeError('In %s, tag %s is not a string' % (self, tag)) + result = result.value + + if isinstance(result, WindowedValue): + windowed_value = result + elif isinstance(result, TimestampedValue): + value = result.value + timestamp = result.timestamp + assign_context = NoContext(value, timestamp) + windowed_value = WindowedValue( + value, timestamp, self.window_fn.assign(assign_context)) + else: + value = result + timestamp = -1 + assign_context = NoContext(value) + windowed_value = WindowedValue( + value, timestamp, self.window_fn.assign(assign_context)) + + if tag is None: + self.main_receivers.receive(windowed_value) + else: + self.tagged_receivers[tag].output(windowed_value) + class NoContext(WindowFn.AssignContext): """An uninspectable WindowFn.AssignContext.""" http://git-wip-us.apache.org/repos/asf/beam/blob/1360c21d/sdks/python/apache_beam/transforms/ptransform_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index 303dfb8..ae77227 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -22,6 +22,7 @@ from __future__ import absolute_import import operator import re import unittest +import mock import hamcrest as hc from nose.plugins.attrib import attr @@ -46,6 +47,16 @@ from apache_beam.utils.pipeline_options import TypeOptions # Disable frequent lint warning due to pipe operator for chaining transforms. # pylint: disable=expression-not-assigned +class MyDoFn(beam.DoFn): + def start_bundle(self): + pass + + def process(self, element): + pass + + def finish_bundle(self): + yield 'finish' + class PTransformTest(unittest.TestCase): # Enable nose tests running in parallel @@ -274,17 +285,7 @@ class PTransformTest(unittest.TestCase): expected_error_prefix = 'FlatMap and ParDo must return an iterable.' self.assertStartswith(cm.exception.message, expected_error_prefix) - def test_do_fn_with_start_finish(self): - class MyDoFn(beam.DoFn): - def start_bundle(self): - yield 'start' - - def process(self, element): - pass - - def finish_bundle(self): - yield 'finish' - + def test_do_fn_with_finish(self): pipeline = TestPipeline() pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3]) result = pcoll | 'Do' >> beam.ParDo(MyDoFn()) @@ -292,13 +293,30 @@ class PTransformTest(unittest.TestCase): # May have many bundles, but each has a start and finish. def matcher(): def match(actual): - equal_to(['start', 'finish'])(list(set(actual))) - equal_to([actual.count('start')])([actual.count('finish')]) + equal_to(['finish'])(list(set(actual))) + equal_to([1])([actual.count('finish')]) return match assert_that(result, matcher()) pipeline.run() + @mock.patch.object(MyDoFn, 'start_bundle') + def test_do_fn_with_start(self, mock_method): + mock_method.return_value = None + pipeline = TestPipeline() + pipeline | 'Start' >> beam.Create([1, 2, 3]) | 'Do' >> beam.ParDo(MyDoFn()) + pipeline.run() + self.assertTrue(mock_method.called) + + @mock.patch.object(MyDoFn, 'start_bundle') + def test_do_fn_with_start_error(self, mock_method): + mock_method.return_value = [1] + pipeline = TestPipeline() + pipeline | 'Start' >> beam.Create([1, 2, 3]) | 'Do' >> beam.ParDo(MyDoFn()) + with self.assertRaises(RuntimeError): + pipeline.run() + self.assertTrue(mock_method.called) + def test_filter(self): pipeline = TestPipeline() pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3, 4])
