Add unit test for unwindowed iterator pickling.

Also lifted this out into a top-level class, rather than defining
it anew for every element, now that it's no longer a simple generator
expression.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/e9b1e412
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/e9b1e412
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/e9b1e412

Branch: refs/heads/python-sdk
Commit: e9b1e41240d5032cdaa2b745c95f94da45475f34
Parents: 63904e0
Author: Robert Bradshaw <rober...@google.com>
Authored: Fri Sep 16 15:41:22 2016 -0700
Committer: Robert Bradshaw <rober...@google.com>
Committed: Fri Sep 16 15:41:22 2016 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/transforms/trigger.py   | 34 ++++++++++++--------
 .../apache_beam/transforms/trigger_test.py      | 12 +++++++
 2 files changed, 33 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/e9b1e412/sdks/python/apache_beam/transforms/trigger.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py
index 58b6154..8c23873 100644
--- a/sdks/python/apache_beam/transforms/trigger.py
+++ b/sdks/python/apache_beam/transforms/trigger.py
@@ -713,6 +713,26 @@ class TriggerDriver(object):
     pass
 
 
+class _UnwindowedValues(observable.ObservableMixin):
+  """Lazily exposes an iterable of windowed values as an iterable of values."""
+
+  def __init__(self, windowed_values):
+    super(_UnwindowedValues, self).__init__()
+    self._windowed_values = windowed_values
+
+  def __iter__(self):
+    for wv in self._windowed_values:
+      unwindowed_value = wv.value
+      self.notify_observers(unwindowed_value)
+      yield unwindowed_value
+
+  def __repr__(self):  # Argument is wrapped so tuple inputs format safely.
+    return '<_UnwindowedValues of %s>' % (self._windowed_values,)
+
+  def __reduce__(self):
+    return list, (list(self),)  # Pickles as a plain list (iterates input).
+
+
 class DefaultGlobalBatchTriggerDriver(TriggerDriver):
   """Breaks a bundles into window (pane)s according to the default triggering.
   """
@@ -725,19 +745,7 @@ class DefaultGlobalBatchTriggerDriver(TriggerDriver):
     if isinstance(windowed_values, list):
       unwindowed = [wv.value for wv in windowed_values]
     else:
-      class UnwindowedValues(observable.ObservableMixin):
-        def __iter__(self):
-          for wv in windowed_values:
-            unwindowed_value = wv.value
-            self.notify_observers(unwindowed_value)
-            yield unwindowed_value
-
-        def __repr__(self):
-          return '<UnwindowedValues of %s>' % windowed_values
-
-        def __reduce__(self):
-          return list, (list(self),)
-      unwindowed = UnwindowedValues()
+      unwindowed = _UnwindowedValues(windowed_values)
     yield WindowedValue(unwindowed, MIN_TIMESTAMP, self.GLOBAL_WINDOW_TUPLE)
 
   def process_timer(self, window_id, name, time_domain, timestamp, state):

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/e9b1e412/sdks/python/apache_beam/transforms/trigger_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/trigger_test.py b/sdks/python/apache_beam/transforms/trigger_test.py
index a3ad8d8..c37d4ae 100644
--- a/sdks/python/apache_beam/transforms/trigger_test.py
+++ b/sdks/python/apache_beam/transforms/trigger_test.py
@@ -19,12 +19,14 @@
 
 import collections
 import os.path
+import pickle
 import unittest
 
 import yaml
 
 import apache_beam as beam
 from apache_beam.pipeline import Pipeline
+from apache_beam.transforms import trigger
 from apache_beam.transforms.core import Windowing
 from apache_beam.transforms.trigger import AccumulationMode
 from apache_beam.transforms.trigger import AfterAll
@@ -366,6 +368,16 @@ class TriggerTest(unittest.TestCase):
          IntervalWindow(0, 17): [set('abcdefgh')]},
         2)
 
+  def test_picklable_output(self):
+    # Non-list input is wrapped lazily in _UnwindowedValues, which must pickle.
+    windows = (trigger.GlobalWindow(),)
+    unpicklable = (WindowedValue(k, 0, windows) for k in range(10))
+    with self.assertRaises(TypeError):
+      pickle.dumps(unpicklable)
+    driver = trigger.DefaultGlobalBatchTriggerDriver()
+    outputs = [pickle.loads(pickle.dumps(wv)).value
+               for wv in driver.process_elements(None, unpicklable, None)]
+    self.assertEqual([list(range(10))], outputs)
 
 class TriggerPipelineTest(unittest.TestCase):
 

Reply via email to