Allow subclasses of tuple, list, and dict as pvalueish inputs/outputs.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/72960b31 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/72960b31 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/72960b31 Branch: refs/heads/master Commit: 72960b31843d1dcdf2b43a55db0797a15f48ef18 Parents: 28d4f09 Author: Robert Bradshaw <rober...@gmail.com> Authored: Fri Sep 8 17:53:09 2017 -0700 Committer: Robert Bradshaw <rober...@gmail.com> Committed: Tue Sep 19 15:23:28 2017 -0700 ---------------------------------------------------------------------- sdks/python/apache_beam/pipeline.py | 2 +- .../python/apache_beam/transforms/ptransform.py | 62 ++++++++++---------- .../apache_beam/transforms/ptransform_test.py | 16 +++++ 3 files changed, 48 insertions(+), 32 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/72960b31/sdks/python/apache_beam/pipeline.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index 1ebd099..c670978 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -438,7 +438,7 @@ class Pipeline(object): if type_options is not None and type_options.pipeline_type_check: transform.type_check_outputs(pvalueish_result) - for result in ptransform.GetPValues().visit(pvalueish_result): + for result in ptransform.get_nested_pvalues(pvalueish_result): assert isinstance(result, (pvalue.PValue, pvalue.DoOutputsTuple)) # Make sure we set the producer only for a leaf node in the transform DAG. 
http://git-wip-us.apache.org/repos/asf/beam/blob/72960b31/sdks/python/apache_beam/transforms/ptransform.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index eccaccd..f630977 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -73,27 +73,25 @@ class _PValueishTransform(object): This visits a PValueish, contstructing a (possibly mutated) copy. """ - def visit(self, node, *args): - return getattr( - self, - 'visit_' + node.__class__.__name__, - lambda x, *args: x)(node, *args) - - def visit_list(self, node, *args): - return [self.visit(x, *args) for x in node] - - def visit_tuple(self, node, *args): - return tuple(self.visit(x, *args) for x in node) - - def visit_dict(self, node, *args): - return {key: self.visit(value, *args) for (key, value) in node.items()} + def visit_nested(self, node, *args): + if isinstance(node, (tuple, list)): + # namedtuples require unpacked arguments in their constructor, + # but do have a _make method that takes a sequence. 
+ return getattr(node.__class__, '_make', node.__class__)( + [self.visit(x, *args) for x in node]) + elif isinstance(node, dict): + return node.__class__( + {key: self.visit(value, *args) for (key, value) in node.items()}) + else: + return node class _SetInputPValues(_PValueishTransform): def visit(self, node, replacements): if id(node) in replacements: return replacements[id(node)] - return super(_SetInputPValues, self).visit(node, replacements) + else: + return self.visit_nested(node, replacements) class _MaterializedDoOutputsTuple(pvalue.DoOutputsTuple): @@ -116,22 +114,25 @@ class _MaterializePValues(_PValueishTransform): return self._pvalue_cache.get_unwindowed_pvalue(node) elif isinstance(node, pvalue.DoOutputsTuple): return _MaterializedDoOutputsTuple(node, self._pvalue_cache) - return super(_MaterializePValues, self).visit(node) + else: + return self.visit_nested(node) -class GetPValues(_PValueishTransform): - def visit(self, node, pvalues=None): - if pvalues is None: - pvalues = [] - self.visit(node, pvalues) - return pvalues - elif isinstance(node, (pvalue.PValue, pvalue.DoOutputsTuple)): +class _GetPValues(_PValueishTransform): + def visit(self, node, pvalues): + if isinstance(node, (pvalue.PValue, pvalue.DoOutputsTuple)): pvalues.append(node) else: - super(GetPValues, self).visit(node, pvalues) + self.visit_nested(node, pvalues) + + +def get_nested_pvalues(pvalueish): + pvalues = [] + _GetPValues().visit(pvalueish, pvalues) + return pvalues -class _ZipPValues(_PValueishTransform): +class _ZipPValues(object): """Pairs each PValue in a pvalueish with a value in a parallel out sibling. Sibling should have the same nested structure as pvalueish. 
Leaves in @@ -153,10 +154,12 @@ class _ZipPValues(_PValueishTransform): return pairs elif isinstance(pvalueish, (pvalue.PValue, pvalue.DoOutputsTuple)): pairs.append((context, pvalueish, sibling)) - else: - super(_ZipPValues, self).visit(pvalueish, sibling, pairs, context) + elif isinstance(pvalueish, (list, tuple)): + self.visit_sequence(pvalueish, sibling, pairs, context) + elif isinstance(pvalueish, dict): + self.visit_dict(pvalueish, sibling, pairs, context) - def visit_list(self, pvalueish, sibling, pairs, context): + def visit_sequence(self, pvalueish, sibling, pairs, context): if isinstance(sibling, (list, tuple)): for ix, (p, s) in enumerate(zip( pvalueish, list(sibling) + [None] * len(pvalueish))): @@ -165,9 +168,6 @@ class _ZipPValues(_PValueishTransform): for p in pvalueish: self.visit(p, sibling, pairs, context) - def visit_tuple(self, pvalueish, sibling, pairs, context): - self.visit_list(pvalueish, sibling, pairs, context) - def visit_dict(self, pvalueish, sibling, pairs, context): if isinstance(sibling, dict): for key, p in pvalueish.items(): http://git-wip-us.apache.org/repos/asf/beam/blob/72960b31/sdks/python/apache_beam/transforms/ptransform_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index 435270e..c328cb1 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -19,6 +19,7 @@ from __future__ import absolute_import +import collections import operator import re import unittest @@ -670,6 +671,21 @@ class PTransformTest(unittest.TestCase): self.assertEqual(['x', 'x', 'y', 'y', 'z'], sorted(res['b'])) self.assertEqual([], sorted(res['c'])) + def test_named_tuple(self): + MinMax = collections.namedtuple('MinMax', ['min', 'max']) + + class MinMaxTransform(PTransform): + def expand(self, pcoll): + return MinMax( + min=pcoll | 
beam.CombineGlobally(min).without_defaults(), + max=pcoll | beam.CombineGlobally(max).without_defaults()) + res = [1, 2, 4, 8] | MinMaxTransform() + self.assertIsInstance(res, MinMax) + self.assertEqual(res, MinMax(min=[1], max=[8])) + + flat = res | beam.Flatten() + self.assertEqual(sorted(flat), [1, 8]) + @beam.ptransform_fn def SamplePTransform(pcoll):