Repository: beam
Updated Branches:
  refs/heads/master c01ed083e -> a88f59063


Migrate CallableWrapperDoFn to use the NewDoFn type


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/926f8687
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/926f8687
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/926f8687

Branch: refs/heads/master
Commit: 926f86877346a68b95b1c618b69553d1c53e1c98
Parents: c01ed08
Author: Sourabh Bajaj <[email protected]>
Authored: Thu Feb 2 14:48:57 2017 -0800
Committer: Ahmet Altay <[email protected]>
Committed: Thu Feb 2 15:51:26 2017 -0800

----------------------------------------------------------------------
 sdks/python/apache_beam/pipeline_test.py   | 16 ++++++++++++++++
 sdks/python/apache_beam/transforms/core.py | 11 ++++++-----
 2 files changed, 22 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/926f8687/sdks/python/apache_beam/pipeline_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py
index 833293f..95b55b9 100644
--- a/sdks/python/apache_beam/pipeline_test.py
+++ b/sdks/python/apache_beam/pipeline_test.py
@@ -110,6 +110,22 @@ class PipelineTest(unittest.TestCase):
     assert_that(pcoll3, equal_to([14, 15, 16]), label='pcoll3')
     pipeline.run()
 
+  def test_flatmap_builtin(self):  # builtin callable (set) via FlatMap
+    pipeline = TestPipeline()
+    pcoll = pipeline | 'label1' >> Create([1, 2, 3])
+    assert_that(pcoll, equal_to([1, 2, 3]))
+
+    pcoll2 = pcoll | 'do' >> FlatMap(lambda x: [x + 10])  # lambda callable
+    assert_that(pcoll2, equal_to([11, 12, 13]), label='pcoll2')
+
+    pcoll3 = pcoll2 | 'm1' >> Map(lambda x: [x, 12])  # one list per element
+    assert_that(pcoll3,
+                equal_to([[11, 12], [12, 12], [13, 12]]), label='pcoll3')
+
+    pcoll4 = pcoll3 | 'do2' >> FlatMap(set)  # set() dedups: [12, 12] -> {12}
+    assert_that(pcoll4, equal_to([11, 12, 12, 12, 13]), label='pcoll4')
+    pipeline.run()
+
   def test_create_singleton_pcollection(self):
     pipeline = TestPipeline()
     pcoll = pipeline | 'label' >> Create([[1, 2, 3]])

http://git-wip-us.apache.org/repos/asf/beam/blob/926f8687/sdks/python/apache_beam/transforms/core.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 20126d3..f69511f 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -304,7 +304,7 @@ def _fn_takes_side_inputs(fn):
   return len(argspec.args) > 1 + is_bound or argspec.varargs or argspec.keywords
 
 
-class CallableWrapperDoFn(DoFn):
+class CallableWrapperDoFn(NewDoFn):
   """A DoFn (function) object wrapping a callable object.
 
   The purpose of this class is to conveniently wrap simple functions and use
@@ -324,11 +324,12 @@ class CallableWrapperDoFn(DoFn):
       raise TypeError('Expected a callable object instead of: %r' % fn)
 
     self._fn = fn
-    if _fn_takes_side_inputs(fn):
-      self.process = lambda context, *args, **kwargs: fn(
-          context.element, *args, **kwargs)
+    if isinstance(fn, (
+        types.BuiltinFunctionType, types.MethodType, types.FunctionType)):
+      self.process = fn
     else:
-      self.process = lambda context: fn(context.element)
+      # For cases such as set / list where fn is callable but not a function
+      self.process = lambda element: fn(element)
 
     super(CallableWrapperDoFn, self).__init__()
 

Reply via email to