Set the element types of a Flatten transform's input PCollections to the element type of its output PCollection
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/8c88d6ab Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/8c88d6ab Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/8c88d6ab Branch: refs/heads/gearpump-runner Commit: 8c88d6ab475db40afb99c08ea44f9a2c61d85862 Parents: 1721cea Author: Vikas Kedigehalli <[email protected]> Authored: Fri Apr 14 13:53:13 2017 -0700 Committer: Luke Cwik <[email protected]> Committed: Wed Apr 19 15:35:06 2017 -0700 ---------------------------------------------------------------------- .../runners/dataflow/dataflow_runner.py | 68 ++++++++++++++++++++ .../runners/dataflow/dataflow_runner_test.py | 65 ++++++++++++++++++- .../apache_beam/runners/direct/direct_runner.py | 2 - sdks/python/apache_beam/runners/runner.py | 41 ------------ sdks/python/apache_beam/runners/runner_test.py | 41 ------------ sdks/python/apache_beam/utils/proto_utils.py | 2 +- 6 files changed, 133 insertions(+), 86 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/8c88d6ab/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index 779db8f..4534895 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -149,6 +149,66 @@ class DataflowRunner(PipelineRunner): result._job = response runner.last_error_msg = last_error_msg + @staticmethod + def group_by_key_input_visitor(): + # Imported here to avoid circular dependencies. 
+ from apache_beam.pipeline import PipelineVisitor + + class GroupByKeyInputVisitor(PipelineVisitor): + """A visitor that replaces `Any` element type for input `PCollection` of + a `GroupByKey` or `GroupByKeyOnly` with a `KV` type. + + TODO(BEAM-115): Once Python SDk is compatible with the new Runner API, + we could directly replace the coder instead of mutating the element type. + """ + + def visit_transform(self, transform_node): + # Imported here to avoid circular dependencies. + # pylint: disable=wrong-import-order, wrong-import-position + from apache_beam import GroupByKey, GroupByKeyOnly + if isinstance(transform_node.transform, (GroupByKey, GroupByKeyOnly)): + pcoll = transform_node.inputs[0] + input_type = pcoll.element_type + # If input_type is not specified, then treat it as `Any`. + if not input_type: + input_type = typehints.Any + + if not isinstance(input_type, typehints.TupleHint.TupleConstraint): + if isinstance(input_type, typehints.AnyTypeConstraint): + # `Any` type needs to be replaced with a KV[Any, Any] to + # force a KV coder as the main output coder for the pcollection + # preceding a GroupByKey. + pcoll.element_type = typehints.KV[typehints.Any, typehints.Any] + else: + # TODO: Handle other valid types, + # e.g. Union[KV[str, int], KV[str, float]] + raise ValueError( + "Input to GroupByKey must be of Tuple or Any type. " + "Found %s for %s" % (input_type, pcoll)) + + return GroupByKeyInputVisitor() + + @staticmethod + def flatten_input_visitor(): + # Imported here to avoid circular dependencies. + from apache_beam.pipeline import PipelineVisitor + + class FlattenInputVisitor(PipelineVisitor): + """A visitor that replaces the element type for input ``PCollections``s of + a ``Flatten`` transform with that of the output ``PCollection``. + """ + + def visit_transform(self, transform_node): + # Imported here to avoid circular dependencies. 
+ # pylint: disable=wrong-import-order, wrong-import-position + from apache_beam import Flatten + if isinstance(transform_node.transform, Flatten): + output_pcoll = transform_node.outputs[None] + for input_pcoll in transform_node.inputs: + input_pcoll.element_type = output_pcoll.element_type + + return FlattenInputVisitor() + def run(self, pipeline): """Remotely executes entire pipeline or parts reachable from node.""" # Import here to avoid adding the dependency for local running scenarios. @@ -161,6 +221,14 @@ class DataflowRunner(PipelineRunner): 'please install apache_beam[gcp]') self.job = apiclient.Job(pipeline.options) + # Dataflow runner requires a KV type for GBK inputs, hence we enforce that + # here. + pipeline.visit(self.group_by_key_input_visitor()) + + # Dataflow runner requires output type of the Flatten to be the same as the + # inputs, hence we enforce that here. + pipeline.visit(self.flatten_input_visitor()) + # The superclass's run will trigger a traversal of all reachable nodes. 
super(DataflowRunner, self).run(pipeline) http://git-wip-us.apache.org/repos/asf/beam/blob/8c88d6ab/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py index b9ed84d..f342be5 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py @@ -26,14 +26,17 @@ import mock import apache_beam as beam import apache_beam.transforms as ptransform -from apache_beam.pipeline import Pipeline +from apache_beam.pipeline import Pipeline, AppliedPTransform +from apache_beam.pvalue import PCollection from apache_beam.runners import create_runner from apache_beam.runners import DataflowRunner from apache_beam.runners import TestDataflowRunner from apache_beam.runners.dataflow.dataflow_runner import DataflowPipelineResult from apache_beam.runners.dataflow.dataflow_runner import DataflowRuntimeException from apache_beam.runners.dataflow.internal.clients import dataflow as dataflow_api +from apache_beam.test_pipeline import TestPipeline from apache_beam.transforms.display import DisplayDataItem +from apache_beam.typehints import typehints from apache_beam.utils.pipeline_options import PipelineOptions # Protect against environments where apitools library is not available. 
@@ -176,6 +179,66 @@ class DataflowRunnerTest(unittest.TestCase): 'RowAsDictJsonCoder')): unused_invalid = rows | beam.GroupByKey() + def test_group_by_key_input_visitor_with_valid_inputs(self): + p = TestPipeline() + pcoll1 = PCollection(p) + pcoll2 = PCollection(p) + pcoll3 = PCollection(p) + for transform in [beam.GroupByKeyOnly(), beam.GroupByKey()]: + pcoll1.element_type = None + pcoll2.element_type = typehints.Any + pcoll3.element_type = typehints.KV[typehints.Any, typehints.Any] + for pcoll in [pcoll1, pcoll2, pcoll3]: + DataflowRunner.group_by_key_input_visitor().visit_transform( + AppliedPTransform(None, transform, "label", [pcoll])) + self.assertEqual(pcoll.element_type, + typehints.KV[typehints.Any, typehints.Any]) + + def test_group_by_key_input_visitor_with_invalid_inputs(self): + p = TestPipeline() + pcoll1 = PCollection(p) + pcoll2 = PCollection(p) + for transform in [beam.GroupByKeyOnly(), beam.GroupByKey()]: + pcoll1.element_type = typehints.TupleSequenceConstraint + pcoll2.element_type = typehints.Set + err_msg = "Input to GroupByKey must be of Tuple or Any type" + for pcoll in [pcoll1, pcoll2]: + with self.assertRaisesRegexp(ValueError, err_msg): + DataflowRunner.group_by_key_input_visitor().visit_transform( + AppliedPTransform(None, transform, "label", [pcoll])) + + def test_group_by_key_input_visitor_for_non_gbk_transforms(self): + p = TestPipeline() + pcoll = PCollection(p) + for transform in [beam.Flatten(), beam.Map(lambda x: x)]: + pcoll.element_type = typehints.Any + DataflowRunner.group_by_key_input_visitor().visit_transform( + AppliedPTransform(None, transform, "label", [pcoll])) + self.assertEqual(pcoll.element_type, typehints.Any) + + def test_flatten_input_with_visitor_with_single_input(self): + self._test_flatten_input_visitor(typehints.KV[int, int], typehints.Any, 1) + + def test_flatten_input_with_visitor_with_multiple_inputs(self): + self._test_flatten_input_visitor( + typehints.KV[int, typehints.Any], typehints.Any, 5) + + def 
_test_flatten_input_visitor(self, input_type, output_type, num_inputs): + p = TestPipeline() + inputs = [] + for _ in range(num_inputs): + input_pcoll = PCollection(p) + input_pcoll.element_type = input_type + inputs.append(input_pcoll) + output_pcoll = PCollection(p) + output_pcoll.element_type = output_type + + flatten = AppliedPTransform(None, beam.Flatten(), "label", inputs) + flatten.add_output(output_pcoll, None) + DataflowRunner.flatten_input_visitor().visit_transform(flatten) + for _ in range(num_inputs): + self.assertEqual(inputs[0].element_type, output_type) + if __name__ == '__main__': unittest.main() http://git-wip-us.apache.org/repos/asf/beam/blob/8c88d6ab/sdks/python/apache_beam/runners/direct/direct_runner.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index d776719..d8d8cb9 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -32,7 +32,6 @@ from apache_beam.runners.runner import PipelineResult from apache_beam.runners.runner import PipelineRunner from apache_beam.runners.runner import PipelineState from apache_beam.runners.runner import PValueCache -from apache_beam.runners.runner import group_by_key_input_visitor from apache_beam.utils.pipeline_options import DirectOptions from apache_beam.utils.value_provider import RuntimeValueProvider @@ -70,7 +69,6 @@ class DirectRunner(PipelineRunner): MetricsEnvironment.set_metrics_supported(True) logging.info('Running pipeline with DirectRunner.') self.consumer_tracking_visitor = ConsumerTrackingPipelineVisitor() - pipeline.visit(group_by_key_input_visitor()) pipeline.visit(self.consumer_tracking_visitor) evaluation_context = EvaluationContext( http://git-wip-us.apache.org/repos/asf/beam/blob/8c88d6ab/sdks/python/apache_beam/runners/runner.py 
---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py index 4d33802..ccb066b 100644 --- a/sdks/python/apache_beam/runners/runner.py +++ b/sdks/python/apache_beam/runners/runner.py @@ -94,46 +94,6 @@ def create_runner(runner_name): runner_name, ', '.join(_ALL_KNOWN_RUNNERS))) -def group_by_key_input_visitor(): - # Imported here to avoid circular dependencies. - from apache_beam.pipeline import PipelineVisitor - - class GroupByKeyInputVisitor(PipelineVisitor): - """A visitor that replaces `Any` element type for input `PCollection` of - a `GroupByKey` or `GroupByKeyOnly` with a `KV` type. - - TODO(BEAM-115): Once Python SDk is compatible with the new Runner API, - we could directly replace the coder instead of mutating the element type. - """ - - def visit_transform(self, transform_node): - # Imported here to avoid circular dependencies. - # pylint: disable=wrong-import-order, wrong-import-position - from apache_beam import GroupByKey, GroupByKeyOnly - from apache_beam import typehints - if isinstance(transform_node.transform, (GroupByKey, GroupByKeyOnly)): - pcoll = transform_node.inputs[0] - input_type = pcoll.element_type - # If input_type is not specified, then treat it as `Any`. - if not input_type: - input_type = typehints.Any - - if not isinstance(input_type, typehints.TupleHint.TupleConstraint): - if isinstance(input_type, typehints.AnyTypeConstraint): - # `Any` type needs to be replaced with a KV[Any, Any] to - # force a KV coder as the main output coder for the pcollection - # preceding a GroupByKey. - pcoll.element_type = typehints.KV[typehints.Any, typehints.Any] - else: - # TODO: Handle other valid types, - # e.g. Union[KV[str, int], KV[str, float]] - raise ValueError( - "Input to GroupByKey must be of Tuple or Any type. 
" - "Found %s for %s" % (input_type, pcoll)) - - return GroupByKeyInputVisitor() - - class PipelineRunner(object): """A runner of a pipeline object. @@ -167,7 +127,6 @@ class PipelineRunner(object): logging.error('Error while visiting %s', transform_node.full_label) raise - pipeline.visit(group_by_key_input_visitor()) pipeline.visit(RunVisitor(self)) def clear(self, pipeline, node=None): http://git-wip-us.apache.org/repos/asf/beam/blob/8c88d6ab/sdks/python/apache_beam/runners/runner_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/runner_test.py b/sdks/python/apache_beam/runners/runner_test.py index 0bebd66..b161cbb 100644 --- a/sdks/python/apache_beam/runners/runner_test.py +++ b/sdks/python/apache_beam/runners/runner_test.py @@ -28,19 +28,14 @@ import hamcrest as hc import apache_beam as beam import apache_beam.transforms as ptransform -from apache_beam import typehints from apache_beam.metrics.cells import DistributionData from apache_beam.metrics.cells import DistributionResult from apache_beam.metrics.execution import MetricKey from apache_beam.metrics.execution import MetricResult from apache_beam.metrics.metricbase import MetricName -from apache_beam.pipeline import AppliedPTransform from apache_beam.pipeline import Pipeline -from apache_beam.pvalue import PCollection from apache_beam.runners import DirectRunner -from apache_beam.runners import runner from apache_beam.runners import create_runner -from apache_beam.test_pipeline import TestPipeline from apache_beam.transforms.util import assert_that from apache_beam.transforms.util import equal_to from apache_beam.utils.pipeline_options import PipelineOptions @@ -123,42 +118,6 @@ class RunnerTest(unittest.TestCase): DistributionResult(DistributionData(15, 5, 1, 5)), DistributionResult(DistributionData(15, 5, 1, 5))))) - def test_group_by_key_input_visitor_with_valid_inputs(self): - p = TestPipeline() - pcoll1 = PCollection(p) - pcoll2 
= PCollection(p) - pcoll3 = PCollection(p) - for transform in [beam.GroupByKeyOnly(), beam.GroupByKey()]: - pcoll1.element_type = None - pcoll2.element_type = typehints.Any - pcoll3.element_type = typehints.KV[typehints.Any, typehints.Any] - for pcoll in [pcoll1, pcoll2, pcoll3]: - runner.group_by_key_input_visitor().visit_transform( - AppliedPTransform(None, transform, "label", [pcoll])) - self.assertEqual(pcoll.element_type, - typehints.KV[typehints.Any, typehints.Any]) - - def test_group_by_key_input_visitor_with_invalid_inputs(self): - p = TestPipeline() - pcoll1 = PCollection(p) - pcoll2 = PCollection(p) - for transform in [beam.GroupByKeyOnly(), beam.GroupByKey()]: - pcoll1.element_type = typehints.TupleSequenceConstraint - pcoll2.element_type = typehints.Set - err_msg = "Input to GroupByKey must be of Tuple or Any type" - for pcoll in [pcoll1, pcoll2]: - with self.assertRaisesRegexp(ValueError, err_msg): - runner.group_by_key_input_visitor().visit_transform( - AppliedPTransform(None, transform, "label", [pcoll])) - - def test_group_by_key_input_visitor_for_non_gbk_transforms(self): - p = TestPipeline() - pcoll = PCollection(p) - for transform in [beam.Flatten(), beam.Map(lambda x: x)]: - pcoll.element_type = typehints.Any - runner.group_by_key_input_visitor().visit_transform( - AppliedPTransform(None, transform, "label", [pcoll])) - self.assertEqual(pcoll.element_type, typehints.Any) if __name__ == '__main__': unittest.main() http://git-wip-us.apache.org/repos/asf/beam/blob/8c88d6ab/sdks/python/apache_beam/utils/proto_utils.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/utils/proto_utils.py b/sdks/python/apache_beam/utils/proto_utils.py index 0243495..d929a92 100644 --- a/sdks/python/apache_beam/utils/proto_utils.py +++ b/sdks/python/apache_beam/utils/proto_utils.py @@ -49,5 +49,5 @@ def pack_Struct(**kwargs): """ msg = struct_pb2.Struct() for key, value in kwargs.items(): - msg[key] = value # 
pylint: disable=unsubscriptable-object + msg[key] = value # pylint: disable=unsubscriptable-object, unsupported-assignment-operation return msg
