Allow subclasses of tuple, list, and dict as pvalueish inputs/outputs.

Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/72960b31
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/72960b31
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/72960b31

Branch: refs/heads/master
Commit: 72960b31843d1dcdf2b43a55db0797a15f48ef18
Parents: 28d4f09
Author: Robert Bradshaw <rober...@gmail.com>
Authored: Fri Sep 8 17:53:09 2017 -0700
Committer: Robert Bradshaw <rober...@gmail.com>
Committed: Tue Sep 19 15:23:28 2017 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/pipeline.py             |  2 +-
 .../python/apache_beam/transforms/ptransform.py | 62 ++++++++++----------
 .../apache_beam/transforms/ptransform_test.py   | 16 +++++
 3 files changed, 48 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/72960b31/sdks/python/apache_beam/pipeline.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index 1ebd099..c670978 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -438,7 +438,7 @@ class Pipeline(object):
     if type_options is not None and type_options.pipeline_type_check:
       transform.type_check_outputs(pvalueish_result)
 
-    for result in ptransform.GetPValues().visit(pvalueish_result):
+    for result in ptransform.get_nested_pvalues(pvalueish_result):
       assert isinstance(result, (pvalue.PValue, pvalue.DoOutputsTuple))
 
       # Make sure we set the producer only for a leaf node in the transform DAG.

http://git-wip-us.apache.org/repos/asf/beam/blob/72960b31/sdks/python/apache_beam/transforms/ptransform.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index eccaccd..f630977 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -73,27 +73,25 @@ class _PValueishTransform(object):
 
   This visits a PValueish, contstructing a (possibly mutated) copy.
   """
-  def visit(self, node, *args):
-    return getattr(
-        self,
-        'visit_' + node.__class__.__name__,
-        lambda x, *args: x)(node, *args)
-
-  def visit_list(self, node, *args):
-    return [self.visit(x, *args) for x in node]
-
-  def visit_tuple(self, node, *args):
-    return tuple(self.visit(x, *args) for x in node)
-
-  def visit_dict(self, node, *args):
-    return {key: self.visit(value, *args) for (key, value) in node.items()}
+  def visit_nested(self, node, *args):
+    if isinstance(node, (tuple, list)):
+      # namedtuples require unpacked arguments in their constructor,
+      # but do have a _make method that takes a sequence.
+      return getattr(node.__class__, '_make', node.__class__)(
+          [self.visit(x, *args) for x in node])
+    elif isinstance(node, dict):
+      return node.__class__(
+          {key: self.visit(value, *args) for (key, value) in node.items()})
+    else:
+      return node
 
 
 class _SetInputPValues(_PValueishTransform):
   def visit(self, node, replacements):
     if id(node) in replacements:
       return replacements[id(node)]
-    return super(_SetInputPValues, self).visit(node, replacements)
+    else:
+      return self.visit_nested(node, replacements)
 
 
 class _MaterializedDoOutputsTuple(pvalue.DoOutputsTuple):
@@ -116,22 +114,25 @@ class _MaterializePValues(_PValueishTransform):
       return self._pvalue_cache.get_unwindowed_pvalue(node)
     elif isinstance(node, pvalue.DoOutputsTuple):
       return _MaterializedDoOutputsTuple(node, self._pvalue_cache)
-    return super(_MaterializePValues, self).visit(node)
+    else:
+      return self.visit_nested(node)
 
 
-class GetPValues(_PValueishTransform):
-  def visit(self, node, pvalues=None):
-    if pvalues is None:
-      pvalues = []
-      self.visit(node, pvalues)
-      return pvalues
-    elif isinstance(node, (pvalue.PValue, pvalue.DoOutputsTuple)):
+class _GetPValues(_PValueishTransform):
+  def visit(self, node, pvalues):
+    if isinstance(node, (pvalue.PValue, pvalue.DoOutputsTuple)):
       pvalues.append(node)
     else:
-      super(GetPValues, self).visit(node, pvalues)
+      self.visit_nested(node, pvalues)
+
+
+def get_nested_pvalues(pvalueish):
+  pvalues = []
+  _GetPValues().visit(pvalueish, pvalues)
+  return pvalues
 
 
-class _ZipPValues(_PValueishTransform):
+class _ZipPValues(object):
   """Pairs each PValue in a pvalueish with a value in a parallel out sibling.
 
   Sibling should have the same nested structure as pvalueish.  Leaves in
@@ -153,10 +154,12 @@ class _ZipPValues(_PValueishTransform):
       return pairs
     elif isinstance(pvalueish, (pvalue.PValue, pvalue.DoOutputsTuple)):
       pairs.append((context, pvalueish, sibling))
-    else:
-      super(_ZipPValues, self).visit(pvalueish, sibling, pairs, context)
+    elif isinstance(pvalueish, (list, tuple)):
+      self.visit_sequence(pvalueish, sibling, pairs, context)
+    elif isinstance(pvalueish, dict):
+      self.visit_dict(pvalueish, sibling, pairs, context)
 
-  def visit_list(self, pvalueish, sibling, pairs, context):
+  def visit_sequence(self, pvalueish, sibling, pairs, context):
     if isinstance(sibling, (list, tuple)):
       for ix, (p, s) in enumerate(zip(
           pvalueish, list(sibling) + [None] * len(pvalueish))):
@@ -165,9 +168,6 @@ class _ZipPValues(_PValueishTransform):
       for p in pvalueish:
         self.visit(p, sibling, pairs, context)
 
-  def visit_tuple(self, pvalueish, sibling, pairs, context):
-    self.visit_list(pvalueish, sibling, pairs, context)
-
   def visit_dict(self, pvalueish, sibling, pairs, context):
     if isinstance(sibling, dict):
       for key, p in pvalueish.items():

http://git-wip-us.apache.org/repos/asf/beam/blob/72960b31/sdks/python/apache_beam/transforms/ptransform_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py
index 435270e..c328cb1 100644
--- a/sdks/python/apache_beam/transforms/ptransform_test.py
+++ b/sdks/python/apache_beam/transforms/ptransform_test.py
@@ -19,6 +19,7 @@
 
 from __future__ import absolute_import
 
+import collections
 import operator
 import re
 import unittest
@@ -670,6 +671,21 @@ class PTransformTest(unittest.TestCase):
     self.assertEqual(['x', 'x', 'y', 'y', 'z'], sorted(res['b']))
     self.assertEqual([], sorted(res['c']))
 
+  def test_named_tuple(self):
+    MinMax = collections.namedtuple('MinMax', ['min', 'max'])
+
+    class MinMaxTransform(PTransform):
+      def expand(self, pcoll):
+        return MinMax(
+            min=pcoll | beam.CombineGlobally(min).without_defaults(),
+            max=pcoll | beam.CombineGlobally(max).without_defaults())
+    res = [1, 2, 4, 8] | MinMaxTransform()
+    self.assertIsInstance(res, MinMax)
+    self.assertEqual(res, MinMax(min=[1], max=[8]))
+
+    flat = res | beam.Flatten()
+    self.assertEqual(sorted(flat), [1, 8])
+
 
 @beam.ptransform_fn
 def SamplePTransform(pcoll):

Reply via email to