Adds an assertion to source_test_utils for testing reentrancy.

Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/bdcb04cb
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/bdcb04cb
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/bdcb04cb

Branch: refs/heads/python-sdk
Commit: bdcb04cb9146d035339f02559127a810166721ab
Parents: 2ab8d62
Author: Chamikara Jayalath <chamik...@google.com>
Authored: Sat Oct 15 17:49:13 2016 -0700
Committer: Robert Bradshaw <rober...@google.com>
Committed: Tue Oct 18 12:09:19 2016 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/io/avroio_test.py       | 30 ++++-------
 sdks/python/apache_beam/io/iobase.py            |  5 +-
 sdks/python/apache_beam/io/source_test_utils.py | 55 ++++++++++++++++++++
 sdks/python/apache_beam/io/textio.py            | 29 +++++++++--
 sdks/python/apache_beam/io/textio_test.py       | 34 +++---------
 5 files changed, 99 insertions(+), 54 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/bdcb04cb/sdks/python/apache_beam/io/avroio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py
index eb2c81c..f72c3f3 100644
--- a/sdks/python/apache_beam/io/avroio_test.py
+++ b/sdks/python/apache_beam/io/avroio_test.py
@@ -110,7 +110,7 @@ class TestAvro(unittest.TestCase):
     return file_name_prefix + os.path.sep + 'mytemp*'
 
   def _run_avro_test(self, pattern, desired_bundle_size, perform_splitting,
-                     expected_result, test_reentrancy=False):
+                     expected_result):
     source = AvroSource(pattern)
 
     read_records = []
@@ -128,23 +128,9 @@ class TestAvro(unittest.TestCase):
           (split.source, split.start_position, split.stop_position)
           for split in splits
       ]
-      if test_reentrancy:
-        for source_info in sources_info:
-          reader_iter = source_info[0].read(source_info[0].get_range_tracker(
-              source_info[1], source_info[2]))
-          try:
-            next(reader_iter)
-          except StopIteration:
-            # Ignoring empty bundle
-            pass
-
       source_test_utils.assertSourcesEqualReferenceSource((source, None, None),
                                                           sources_info)
     else:
-      if test_reentrancy:
-        reader_iter = source.read(source.get_range_tracker(None, None))
-        next(reader_iter)
-
       read_records = source_test_utils.readFromSource(source, None, None)
       self.assertItemsEqual(expected_result, read_records)
 
@@ -160,15 +146,17 @@ class TestAvro(unittest.TestCase):
 
   def test_read_reentrant_without_splitting(self):
     file_name = self._write_data()
-    expected_result = self.RECORDS
-    self._run_avro_test(file_name, None, False, expected_result,
-                        test_reentrancy=True)
+    source_test_utils.assertReentrantReadsSucceed(
+        (AvroSource(file_name), None, None))
 
   def test_read_reantrant_with_splitting(self):
     file_name = self._write_data()
-    expected_result = self.RECORDS
-    self._run_avro_test(file_name, 100, True, expected_result,
-                        test_reentrancy=True)
+    source = AvroSource(file_name)
+    splits = list(source.split(desired_bundle_size=100000))
+    self.assertEqual(1, len(splits))  # expect a single split
+    split = splits[0]
+    source_test_utils.assertReentrantReadsSucceed(
+        (split.source, split.start_position, split.stop_position))
 
   def test_read_without_splitting_multiple_blocks(self):
     file_name = self._write_data(count=12000)

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/bdcb04cb/sdks/python/apache_beam/io/iobase.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index edd3524..9701964 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -94,11 +94,10 @@ class BoundedSource(object):
 
   **Mutability**
 
-  A ``BoundedSource`` object should be fully mutated before being submitted
-  for reading. A ``BoundedSource`` object should not be mutated while
+  A ``BoundedSource`` object should not be mutated while
   its methods (for example, ``read()``) are being invoked by a runner. Runner
   implementations may invoke methods of ``BoundedSource`` objects through
-  multi-threaded and/or re-entrant execution modes.
+  multi-threaded and/or reentrant execution modes.
   """
 
   def estimate_size(self):

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/bdcb04cb/sdks/python/apache_beam/io/source_test_utils.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/source_test_utils.py b/sdks/python/apache_beam/io/source_test_utils.py
index 33ab083..480a95d 100644
--- a/sdks/python/apache_beam/io/source_test_utils.py
+++ b/sdks/python/apache_beam/io/source_test_utils.py
@@ -151,6 +151,61 @@ def assertSourcesEqualReferenceSource(reference_source_info, sources_info):
         'same set of records.')
 
 
+def assertReentrantReadsSucceed(source_info):
+  """Tests if a given source can be read in a reentrant manner.
+
+  Assume that the given source produces the values {v1, v2, v3, ... vn}. For
+  each i in the range [1, n-1] this method starts a read, consumes i values,
+  performs a full reentrant read, then finishes the original read, and
+  verifies that both reads produce the expected multiset of values.
+
+  Args:
+    source_info: a three-tuple that gives the reference
+                 ``iobase.BoundedSource``, position to start reading at, and a
+                 position to stop reading at.
+  Raises:
+    ValueError: if the source produces fewer than 2 values, or if either the
+                original or the reentrant read yields incorrect values.
+  """
+
+  source, start_position, stop_position = source_info
+  assert isinstance(source, iobase.BoundedSource)
+
+  expected_values = list(source.read(source.get_range_tracker(
+      start_position, stop_position)))
+  if len(expected_values) < 2:
+    raise ValueError('Source is too trivial since it produces only %d '
+                     'values. Please give a source that reads at least 2 '
+                     'values.' % len(expected_values))
+
+  for i in range(1, len(expected_values)):  # i = 1, ..., n-1 inclusive.
+    read_iter = source.read(source.get_range_tracker(
+        start_position, stop_position))
+    original_read = []
+    for _ in range(i):
+      original_read.append(next(read_iter))
+
+    # Reentrant read: a full read while the original read is suspended.
+    reentrant_read = list(source.read(
+        source.get_range_tracker(start_position, stop_position)))
+
+    # Continuing original read.
+    for val in read_iter:
+      original_read.append(val)
+
+    if sorted(original_read) != sorted(expected_values):
+      raise ValueError('Source did not produce expected values when '
+                       'performing a reentrant read after reading %d values. '
+                       'Expected %r received %r.' %
+                       (i, expected_values, original_read))
+
+    if sorted(reentrant_read) != sorted(expected_values):
+      raise ValueError('A reentrant read of source after reading %d values '
+                       'did not produce expected values. Expected %r '
+                       'received %r.' %
+                       (i, expected_values, reentrant_read))
+
+
 def assertSplitAtFractionBehavior(source, num_items_to_read_before_split,
                                   split_fraction, expected_outcome):
   """Verifies the behaviour of splitting a source at a given fraction.

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/bdcb04cb/sdks/python/apache_beam/io/textio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py
index dcaceef..01f6ef6 100644
--- a/sdks/python/apache_beam/io/textio.py
+++ b/sdks/python/apache_beam/io/textio.py
@@ -46,8 +46,29 @@ class _TextSource(filebasedsource.FileBasedSource):
     # buffer that should be read.
 
     def __init__(self, data, position):
-      self.data = data
-      self.position = position
+      self._data = data  # skips the data setter's isinstance check
+      self._position = position  # skips the position setter's checks
+
+    @property
+    def data(self):
+      return self._data  # current buffer contents
+
+    @data.setter
+    def data(self, value):
+      assert isinstance(value, bytes)  # NOTE(review): stripped under -O
+      self._data = value  # position is not re-validated against new length
+
+    @property
+    def position(self):
+      return self._position  # offset into data; the setter caps it at len(data)
+
+    @position.setter
+    def position(self, value):
+      assert isinstance(value, (int, long))
+      if value > len(self._data):  # len(data) itself is a valid position.
+        raise ValueError('Cannot set position to %d since it\'s larger than '
+                         'size of data %d.' % (value, len(self._data)))
+      self._position = value
 
   def __init__(self, file_pattern, min_bundle_size,
                compression_type, strip_trailing_newlines, coder,
@@ -119,9 +140,11 @@ class _TextSource(filebasedsource.FileBasedSource):
       # array.
       next_lf = read_buffer.data.find('\n', current_pos)
       if next_lf >= 0:
-        if read_buffer.data[next_lf - 1] == '\r':
+        if next_lf > 0 and read_buffer.data[next_lf - 1] == '\r':
+          # Found a '\r\n'. Accepting that as the next separator.
           return (next_lf - 1, next_lf + 1)
         else:
+          # Found a '\n'. Accepting that as the next separator.
           return (next_lf, next_lf + 1)
 
       current_pos = len(read_buffer.data)

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/bdcb04cb/sdks/python/apache_beam/io/textio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py
index 90ff3cc..81d04ab 100644
--- a/sdks/python/apache_beam/io/textio_test.py
+++ b/sdks/python/apache_beam/io/textio_test.py
@@ -201,39 +201,19 @@ class TextSourceTest(unittest.TestCase):
   def test_read_reentrant_without_splitting(self):
     file_name, expected_data = write_data(10)
     assert len(expected_data) == 10
-    source1 = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
-                         coders.StrUtf8Coder())
-    reader_iter = source1.read(source1.get_range_tracker(None, None))
-    next(reader_iter)
-    next(reader_iter)
-
-    source2 = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
-                         coders.StrUtf8Coder())
-    source_test_utils.assertSourcesEqualReferenceSource((source1, None, None),
-                                                        [(source2, None, None)])
+    source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
+                        coders.StrUtf8Coder())
+    source_test_utils.assertReentrantReadsSucceed((source, None, None))
 
   def test_read_reentrant_after_splitting(self):
     file_name, expected_data = write_data(10)
     assert len(expected_data) == 10
     source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                         coders.StrUtf8Coder())
-    splits1 = [split for split in source.split(desired_bundle_size=100000)]
-    assert len(splits1) == 1
-    reader_iter = splits1[0].source.read(
-        splits1[0].source.get_range_tracker(
-            splits1[0].start_position, splits1[0].stop_position))
-    next(reader_iter)
-    next(reader_iter)
-
-    splits2 = [split for split in source.split(desired_bundle_size=100000)]
-    assert len(splits2) == 1
-    source_test_utils.assertSourcesEqualReferenceSource(
-        (splits1[0].source,
-         splits1[0].start_position,
-         splits1[0].stop_position),
-        [(splits2[0].source,
-          splits2[0].start_position,
-          splits2[0].stop_position)])
+    splits = list(source.split(desired_bundle_size=100000))
+    self.assertEqual(1, len(splits))  # expect a single split
+    source_test_utils.assertReentrantReadsSucceed(
+        (splits[0].source, splits[0].start_position, splits[0].stop_position))
 
   def test_dynamic_work_rebalancing(self):
     file_name, expected_data = write_data(15)

Reply via email to