Repository: beam Updated Branches: refs/heads/master f731de0b3 -> 8d337ff0e
[BEAM-2729] Allow GBK of union-typed PCollections. Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/b10c1c34 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/b10c1c34 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/b10c1c34 Branch: refs/heads/master Commit: b10c1c34e03e7e77add8a4cc7a98c60914becda6 Parents: f731de0 Author: Robert Bradshaw <[email protected]> Authored: Fri Aug 4 11:43:12 2017 -0700 Committer: Robert Bradshaw <[email protected]> Committed: Fri Aug 4 12:13:14 2017 -0700 ---------------------------------------------------------------------- .../runners/dataflow/dataflow_runner.py | 25 +++++++++++++++----- .../runners/dataflow/dataflow_runner_test.py | 22 +++++++++++++++++ 2 files changed, 41 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/b10c1c34/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index 0df1882..880901e 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -188,18 +188,31 @@ class DataflowRunner(PipelineRunner): if not input_type: input_type = typehints.Any - if not isinstance(input_type, typehints.TupleHint.TupleConstraint): - if isinstance(input_type, typehints.AnyTypeConstraint): + def coerce_to_kv_type(element_type): + if isinstance(element_type, typehints.TupleHint.TupleConstraint): + if len(element_type.tuple_types) == 2: + return element_type + else: + raise ValueError( + "Tuple input to GroupByKey must have two components. 
" + "Found %s for %s" % (element_type, pcoll)) + elif isinstance(input_type, typehints.AnyTypeConstraint): # `Any` type needs to be replaced with a KV[Any, Any] to # force a KV coder as the main output coder for the pcollection # preceding a GroupByKey. - pcoll.element_type = typehints.KV[typehints.Any, typehints.Any] + return typehints.KV[typehints.Any, typehints.Any] + elif isinstance(element_type, typehints.UnionConstraint): + union_types = [ + coerce_to_kv_type(t) for t in element_type.union_types] + return typehints.KV[ + typehints.Union[tuple(t.tuple_types[0] for t in union_types)], + typehints.Union[tuple(t.tuple_types[1] for t in union_types)]] else: - # TODO: Handle other valid types, - # e.g. Union[KV[str, int], KV[str, float]] + # TODO: Possibly handle other valid types. raise ValueError( "Input to GroupByKey must be of Tuple or Any type. " - "Found %s for %s" % (input_type, pcoll)) + "Found %s for %s" % (element_type, pcoll)) + pcoll.element_type = coerce_to_kv_type(input_type) return GroupByKeyInputVisitor() http://git-wip-us.apache.org/repos/asf/beam/blob/b10c1c34/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py index a9b8fdb..80414d6 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py @@ -256,6 +256,28 @@ class DataflowRunnerTest(unittest.TestCase): for _ in range(num_inputs): self.assertEqual(inputs[0].element_type, output_type) + def test_gbk_then_flatten_input_visitor(self): + p = TestPipeline( + runner=DataflowRunner(), + options=PipelineOptions(self.default_properties)) + none_str_pc = p | 'c1' >> beam.Create({None: 'a'}) + none_int_pc = p | 'c2' >> beam.Create({None: 3}) + flat = (none_str_pc, none_int_pc) | 
beam.Flatten() + _ = flat | beam.GroupByKey() + + # This may change if type inference changes, but we assert it here + # to make sure the check below is not vacuous. + self.assertNotIsInstance(flat.element_type, typehints.TupleConstraint) + + p.visit(DataflowRunner.group_by_key_input_visitor()) + p.visit(DataflowRunner.flatten_input_visitor()) + + # The dataflow runner requires gbk input to be tuples *and* flatten + # inputs to be equal to their outputs. Assert both hold. + self.assertIsInstance(flat.element_type, typehints.TupleConstraint) + self.assertEqual(flat.element_type, none_str_pc.element_type) + self.assertEqual(flat.element_type, none_int_pc.element_type) + def test_serialize_windowing_strategy(self): # This just tests the basic path; more complete tests # are in window_test.py.
