Repository: beam
Updated Branches:
  refs/heads/master 2eeeaa257 -> 9ec22f173


Add support for reading/writing headers to text files

Footer support is not added, since it is not clear how to implement it
for the source, and it is unclear how much value it would provide.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/a4201a13
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/a4201a13
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/a4201a13

Branch: refs/heads/master
Commit: a4201a13c8b1eb178f5ecd563d0a50994f2578a2
Parents: 2eeeaa2
Author: Slaven Bilac <[email protected]>
Authored: Wed Feb 8 10:16:46 2017 -0800
Committer: Robert Bradshaw <[email protected]>
Committed: Thu Feb 9 10:42:00 2017 -0800

----------------------------------------------------------------------
 sdks/python/apache_beam/io/fileio.py      |  17 +-
 sdks/python/apache_beam/io/fileio_test.py |  75 ++++++++-
 sdks/python/apache_beam/io/textio.py      | 109 ++++++++----
 sdks/python/apache_beam/io/textio_test.py | 224 +++++++++++++++++++++----
 4 files changed, 361 insertions(+), 64 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/a4201a13/sdks/python/apache_beam/io/fileio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/fileio.py b/sdks/python/apache_beam/io/fileio.py
index 59fabb3..3b3a183 100644
--- a/sdks/python/apache_beam/io/fileio.py
+++ b/sdks/python/apache_beam/io/fileio.py
@@ -95,7 +95,7 @@ class CompressionTypes(object):
 
   @classmethod
   def detect_compression_type(cls, file_path):
-    """Returns the compression type of a file (based on its suffix)"""
+    """Returns the compression type of a file (based on its suffix)."""
     compression_types_by_suffix = {'.bz2': cls.BZIP2, '.gz': cls.GZIP}
     lowercased_path = file_path.lower()
     for suffix, compression_type in compression_types_by_suffix.iteritems():
@@ -326,7 +326,8 @@ class _CompressedFile(object):
                compression_type=CompressionTypes.GZIP,
                read_size=gcsio.DEFAULT_READ_BUFFER_SIZE):
     if not fileobj:
-      raise ValueError('fileobj must be opened file but was %s' % fileobj)
+      raise ValueError('File object must be opened file but was at %s' %
+                       fileobj)
 
     if not CompressionTypes.is_valid_compression_type(compression_type):
       raise TypeError('compression_type must be CompressionType object but '
@@ -339,6 +340,11 @@ class _CompressedFile(object):
     self._file = fileobj
     self._compression_type = compression_type
 
+    if self._file.tell() != 0:
+      raise ValueError('File object must be at position 0 but was %d' %
+                       self._file.tell())
+    self._uncompressed_position = 0
+
     if self.readable():
       self._read_size = read_size
       self._read_buffer = cStringIO.StringIO()
@@ -375,6 +381,7 @@ class _CompressedFile(object):
     """Write data to file."""
     if not self._compressor:
       raise ValueError('compressor not initialized')
+    self._uncompressed_position += len(data)
     compressed = self._compressor.compress(data)
     if compressed:
       self._file.write(compressed)
@@ -429,6 +436,7 @@ class _CompressedFile(object):
     self._read_buffer.seek(self._read_position)
     result = read_fn()
     self._read_position += len(result)
+    self._uncompressed_position += len(result)
     self._read_buffer.seek(0, os.SEEK_END)  # Allow future writes.
     return result
 
@@ -477,10 +485,15 @@ class _CompressedFile(object):
       self._file.write(self._compressor.flush())
     self._file.flush()
 
+  @property
   def seekable(self):
     # TODO: Add support for seeking to a file position.
     return False
 
+  def tell(self):
+    """Returns current position in uncompressed file."""
+    return self._uncompressed_position
+
   def __enter__(self):
     return self
 

http://git-wip-us.apache.org/repos/asf/beam/blob/a4201a13/sdks/python/apache_beam/io/fileio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/fileio_test.py b/sdks/python/apache_beam/io/fileio_test.py
index 6c33f53..a963c67 100644
--- a/sdks/python/apache_beam/io/fileio_test.py
+++ b/sdks/python/apache_beam/io/fileio_test.py
@@ -21,6 +21,7 @@
 import glob
 import logging
 import os
+import shutil
 import tempfile
 import unittest
 
@@ -79,6 +80,68 @@ class TestChannelFactory(unittest.TestCase):
     gcsio_mock.size.assert_called_once_with('gs://bucket/file2')
 
 
+# TODO: Refactor code so all io tests are using same library
+# TestCaseWithTempDirCleanup class.
+class _TestCaseWithTempDirCleanUp(unittest.TestCase):
+  """Base class for TestCases that deals with TempDir clean-up.
+
+  Inherited test cases will call self._new_tempdir() to start a temporary dir
+  which will be deleted at the end of the tests (when tearDown() is called).
+  """
+
+  def setUp(self):
+    self._tempdirs = []
+
+  def tearDown(self):
+    for path in self._tempdirs:
+      if os.path.exists(path):
+        shutil.rmtree(path)
+    self._tempdirs = []
+
+  def _new_tempdir(self):
+    result = tempfile.mkdtemp()
+    self._tempdirs.append(result)
+    return result
+
+  def _create_temp_file(self, name='', suffix=''):
+    if not name:
+      name = tempfile.template
+    file_name = tempfile.NamedTemporaryFile(
+        delete=False, prefix=name,
+        dir=self._new_tempdir(), suffix=suffix).name
+    return file_name
+
+
+class TestCompressedFile(_TestCaseWithTempDirCleanUp):
+
+  def test_seekable(self):
+    readable = fileio._CompressedFile(open(self._create_temp_file(), 'r'))
+    self.assertFalse(readable.seekable)
+
+    writeable = fileio._CompressedFile(open(self._create_temp_file(), 'w'))
+    self.assertFalse(writeable.seekable)
+
+  def test_tell(self):
+    lines = ['line%d\n' % i for i in range(10)]
+    tmpfile = self._create_temp_file()
+    writeable = fileio._CompressedFile(open(tmpfile, 'w'))
+    current_offset = 0
+    for line in lines:
+      writeable.write(line)
+      current_offset += len(line)
+      self.assertEqual(current_offset, writeable.tell())
+
+    writeable.close()
+    readable = fileio._CompressedFile(open(tmpfile))
+    current_offset = 0
+    while True:
+      line = readable.readline()
+      current_offset += len(line)
+      self.assertEqual(current_offset, readable.tell())
+      if not line:
+        break
+
+
 class MyFileSink(fileio.FileSink):
 
   def open(self, temp_path):
@@ -100,10 +163,10 @@ class MyFileSink(fileio.FileSink):
     file_handle = fileio.FileSink.close(self, file_handle)
 
 
-class TestFileSink(unittest.TestCase):
+class TestFileSink(_TestCaseWithTempDirCleanUp):
 
   def test_file_sink_writing(self):
-    temp_path = tempfile.NamedTemporaryFile().name
+    temp_path = os.path.join(self._new_tempdir(), 'filesink')
     sink = MyFileSink(
         temp_path, file_name_suffix='.foo', coder=coders.ToStringCoder())
 
@@ -136,7 +199,7 @@ class TestFileSink(unittest.TestCase):
     self.assertItemsEqual([shard1, shard2], glob.glob(temp_path + '*'))
 
   def test_file_sink_display_data(self):
-    temp_path = tempfile.NamedTemporaryFile().name
+    temp_path = os.path.join(self._new_tempdir(), 'display')
     sink = MyFileSink(
         temp_path, file_name_suffix='.foo', coder=coders.ToStringCoder())
     dd = DisplayData.create_from(sink)
@@ -161,7 +224,7 @@ class TestFileSink(unittest.TestCase):
         open(temp_path + '-00000-of-00001.foo').read(), '[start][end]')
 
   def test_fixed_shard_write(self):
-    temp_path = tempfile.NamedTemporaryFile().name
+    temp_path = os.path.join(self._new_tempdir(), 'empty')
     sink = MyFileSink(
         temp_path,
         file_name_suffix='.foo',
@@ -180,7 +243,7 @@ class TestFileSink(unittest.TestCase):
     self.assertTrue('][b][' in concat, concat)
 
   def test_file_sink_multi_shards(self):
-    temp_path = tempfile.NamedTemporaryFile().name
+    temp_path = os.path.join(self._new_tempdir(), 'multishard')
     sink = MyFileSink(
         temp_path, file_name_suffix='.foo', coder=coders.ToStringCoder())
 
@@ -215,7 +278,7 @@ class TestFileSink(unittest.TestCase):
     self.assertItemsEqual(res, glob.glob(temp_path + '*'))
 
   def test_file_sink_io_error(self):
-    temp_path = tempfile.NamedTemporaryFile().name
+    temp_path = os.path.join(self._new_tempdir(), 'ioerror')
     sink = MyFileSink(
         temp_path, file_name_suffix='.foo', coder=coders.ToStringCoder())
 

http://git-wip-us.apache.org/repos/asf/beam/blob/a4201a13/sdks/python/apache_beam/io/textio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py
index ed63770..7a35844 100644
--- a/sdks/python/apache_beam/io/textio.py
+++ b/sdks/python/apache_beam/io/textio.py
@@ -17,7 +17,9 @@
 
 """A source and a sink for reading from and writing to text files."""
 
+
 from __future__ import absolute_import
+import logging
 
 from apache_beam import coders
 from apache_beam.io import filebasedsource
@@ -31,7 +33,7 @@ __all__ = ['ReadFromText', 'WriteToText']
 
 
 class _TextSource(filebasedsource.FileBasedSource):
-  """A source for reading text files.
+  r"""A source for reading text files.
 
   Parses a text file as newline-delimited elements. Supports newline delimiters
   '\n' and '\r\n.
@@ -71,9 +73,15 @@ class _TextSource(filebasedsource.FileBasedSource):
                          'size of data %d.', value, len(self._data))
       self._position = value
 
-  def __init__(self, file_pattern, min_bundle_size,
-               compression_type, strip_trailing_newlines, coder,
-               buffer_size=DEFAULT_READ_BUFFER_SIZE, validate=True):
+  def __init__(self,
+               file_pattern,
+               min_bundle_size,
+               compression_type,
+               strip_trailing_newlines,
+               coder,
+               buffer_size=DEFAULT_READ_BUFFER_SIZE,
+               validate=True,
+               skip_header_lines=0):
     super(_TextSource, self).__init__(file_pattern, min_bundle_size,
                                       compression_type=compression_type,
                                       validate=validate)
@@ -82,6 +90,14 @@ class _TextSource(filebasedsource.FileBasedSource):
     self._compression_type = compression_type
     self._coder = coder
     self._buffer_size = buffer_size
+    if skip_header_lines < 0:
+      raise ValueError('Cannot skip negative number of header lines: %d',
+                       skip_header_lines)
+    elif skip_header_lines > 10:
+      logging.warning(
+          'Skipping %d header lines. Skipping large number of header '
+          'lines might significantly slow down processing.')
+    self._skip_header_lines = skip_header_lines
 
   def display_data(self):
     parent_dd = super(_TextSource, self).display_data()
@@ -101,13 +117,18 @@ class _TextSource(filebasedsource.FileBasedSource):
     read_buffer = _TextSource.ReadBuffer('', 0)
 
     with self.open_file(file_name) as file_to_read:
-      if start_offset > 0:
+      position_after_skipping_header_lines = self._skip_lines(
+          file_to_read, read_buffer,
+          self._skip_header_lines) if self._skip_header_lines else 0
+      start_offset = max(start_offset, position_after_skipping_header_lines)
+      if start_offset > position_after_skipping_header_lines:
         # Seeking to one position before the start index and ignoring the
         # current line. If start_position is at beginning if the line, that line
         # belongs to the current bundle, hence ignoring that is incorrect.
         # Seeking to one byte before prevents that.
 
         file_to_read.seek(start_offset - 1)
+        read_buffer = _TextSource.ReadBuffer('', 0)
         sep_bounds = self._find_separator_bounds(file_to_read, read_buffer)
         if not sep_bounds:
           # Could not find a separator after (start_offset - 1). This means that
@@ -116,14 +137,13 @@ class _TextSource(filebasedsource.FileBasedSource):
 
         _, sep_end = sep_bounds
         read_buffer.data = read_buffer.data[sep_end:]
-        next_record_start_position = start_offset -1 + sep_end
+        next_record_start_position = start_offset - 1 + sep_end
       else:
-        next_record_start_position = 0
+        next_record_start_position = position_after_skipping_header_lines
 
       while range_tracker.try_claim(next_record_start_position):
         record, num_bytes_to_next_record = self._read_record(file_to_read,
                                                              read_buffer)
-
         # For compressed text files that use an unsplittable OffsetRangeTracker
         # with infinity as the end position, above 'try_claim()' invocation
         # would pass for an empty record at the end of file that is not
@@ -184,6 +204,20 @@ class _TextSource(filebasedsource.FileBasedSource):
 
     return True
 
+  def _skip_lines(self, file_to_read, read_buffer, num_lines):
+    """Skip num_lines from file_to_read, return num_lines+1 start position."""
+    if file_to_read.tell() > 0:
+      file_to_read.seek(0)
+    position = 0
+    for _ in range(num_lines):
+      _, num_bytes_to_next_record = self._read_record(file_to_read, read_buffer)
+      if num_bytes_to_next_record < 0:
+        # We reached end of file. It is OK to just break here
+        # because subsequent _read_record will return same result.
+        break
+      position += num_bytes_to_next_record
+    return position
+
   def _read_record(self, file_to_read, read_buffer):
     # Returns a tuple containing the current_record and number of bytes to the
     # next record starting from 'self._next_position_in_buffer'. If EOF is
@@ -224,7 +258,8 @@ class _TextSink(fileio.FileSink):
                num_shards=0,
                shard_name_template=None,
                coder=coders.ToStringCoder(),
-               compression_type=fileio.CompressionTypes.AUTO):
+               compression_type=fileio.CompressionTypes.AUTO,
+               header=None):
     """Initialize a _TextSink.
 
     Args:
@@ -251,10 +286,12 @@ class _TextSink(fileio.FileSink):
         generated. The default pattern used is '-SSSSS-of-NNNNN'.
       coder: Coder used to encode each line.
       compression_type: Used to handle compressed output files. Typical value
-          is CompressionTypes.AUTO, in which case the final file path's
-          extension (as determined by file_path_prefix, file_name_suffix,
-          num_shards and shard_name_template) will be used to detect the
-          compression.
+        is CompressionTypes.AUTO, in which case the final file path's
+        extension (as determined by file_path_prefix, file_name_suffix,
+        num_shards and shard_name_template) will be used to detect the
+        compression.
+      header: String to write at beginning of file as a header. If not None and
+        append_trailing_newlines is set, '\n' will be added.
 
     Returns:
       A _TextSink object usable for writing.
@@ -267,7 +304,16 @@ class _TextSink(fileio.FileSink):
         coder=coder,
         mime_type='text/plain',
         compression_type=compression_type)
-    self.append_trailing_newlines = append_trailing_newlines
+    self._append_trailing_newlines = append_trailing_newlines
+    self._header = header
+
+  def open(self, temp_path):
+    file_handle = super(_TextSink, self).open(temp_path)
+    if self._header is not None:
+      file_handle.write(self._header)
+      if self._append_trailing_newlines:
+        file_handle.write('\n')
+    return file_handle
 
   def display_data(self):
     dd_parent = super(_TextSink, self).display_data()
@@ -279,7 +325,7 @@ class _TextSink(fileio.FileSink):
   def write_encoded_record(self, file_handle, encoded_value):
     """Writes a single encoded record."""
     file_handle.write(encoded_value)
-    if self.append_trailing_newlines:
+    if self._append_trailing_newlines:
       file_handle.write('\n')
 
 
@@ -299,6 +345,7 @@ class ReadFromText(PTransform):
       strip_trailing_newlines=True,
       coder=coders.StrUtf8Coder(),
       validate=True,
+      skip_header_lines=0,
       **kwargs):
     """Initialize the ReadFromText transform.
 
@@ -317,14 +364,18 @@ class ReadFromText(PTransform):
                                decoding that line.
       validate: flag to verify that the files exist during the pipeline
                 creation time.
+      skip_header_lines: Number of header lines to skip. Same number is skipped
+                         from each source file. Must be 0 or higher. Large
+                         number of skipped lines might impact performance.
       coder: Coder used to decode each line.
     """
 
     super(ReadFromText, self).__init__(**kwargs)
     self._strip_trailing_newlines = strip_trailing_newlines
-    self._source = _TextSource(file_pattern, min_bundle_size, compression_type,
-                               strip_trailing_newlines, coder,
-                               validate=validate)
+    self._source = _TextSource(
+        file_pattern, min_bundle_size, compression_type,
+        strip_trailing_newlines, coder, validate=validate,
+        skip_header_lines=skip_header_lines)
 
   def expand(self, pvalue):
     return pvalue.pipeline | Read(self._source)
@@ -333,15 +384,15 @@ class ReadFromText(PTransform):
 class WriteToText(PTransform):
   """A PTransform for writing to text files."""
 
-  def __init__(
-      self,
-      file_path_prefix,
-      file_name_suffix='',
-      append_trailing_newlines=True,
-      num_shards=0,
-      shard_name_template=None,
-      coder=coders.ToStringCoder(),
-      compression_type=fileio.CompressionTypes.AUTO):
+  def __init__(self,
+               file_path_prefix,
+               file_name_suffix='',
+               append_trailing_newlines=True,
+               num_shards=0,
+               shard_name_template=None,
+               coder=coders.ToStringCoder(),
+               compression_type=fileio.CompressionTypes.AUTO,
+               header=None):
     """Initialize a WriteToText PTransform.
 
     Args:
@@ -372,11 +423,13 @@ class WriteToText(PTransform):
           extension (as determined by file_path_prefix, file_name_suffix,
           num_shards and shard_name_template) will be used to detect the
           compression.
+      header: String to write at beginning of file as a header. If not None and
+          append_trailing_newlines is set, '\n' will be added.
     """
 
     self._sink = _TextSink(file_path_prefix, file_name_suffix,
                            append_trailing_newlines, num_shards,
-                           shard_name_template, coder, compression_type)
+                           shard_name_template, coder, compression_type, header)
 
   def expand(self, pcoll):
     return pcoll | Write(self._sink)

http://git-wip-us.apache.org/repos/asf/beam/blob/a4201a13/sdks/python/apache_beam/io/textio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py
index a7133ed..ea417b0 100644
--- a/sdks/python/apache_beam/io/textio_test.py
+++ b/sdks/python/apache_beam/io/textio_test.py
@@ -22,6 +22,7 @@ import glob
 import gzip
 import logging
 import os
+import shutil
 import tempfile
 import unittest
 
@@ -47,7 +48,39 @@ from apache_beam.transforms.util import assert_that
 from apache_beam.transforms.util import equal_to
 
 
-class TextSourceTest(unittest.TestCase):
+# TODO: Refactor code so all io tests are using same library
+# TestCaseWithTempDirCleanup class.
+class _TestCaseWithTempDirCleanUp(unittest.TestCase):
+  """Base class for TestCases that deals with TempDir clean-up.
+
+  Inherited test cases will call self._new_tempdir() to start a temporary dir
+  which will be deleted at the end of the tests (when tearDown() is called).
+  """
+
+  def setUp(self):
+    self._tempdirs = []
+
+  def tearDown(self):
+    for path in self._tempdirs:
+      if os.path.exists(path):
+        shutil.rmtree(path)
+    self._tempdirs = []
+
+  def _new_tempdir(self):
+    result = tempfile.mkdtemp()
+    self._tempdirs.append(result)
+    return result
+
+  def _create_temp_file(self, name='', suffix=''):
+    if not name:
+      name = tempfile.template
+    file_name = tempfile.NamedTemporaryFile(
+        delete=False, prefix=name,
+        dir=self._new_tempdir(), suffix=suffix).name
+    return file_name
+
+
+class TextSourceTest(_TestCaseWithTempDirCleanUp):
 
   # Number of records that will be written by most tests.
   DEFAULT_NUM_RECORDS = 100
@@ -322,8 +355,7 @@ class TextSourceTest(unittest.TestCase):
 
   def test_read_auto_bzip2(self):
     _, lines = write_data(15)
-    file_name = tempfile.NamedTemporaryFile(
-        delete=False, prefix=tempfile.template, suffix='.bz2').name
+    file_name = self._create_temp_file(suffix='.bz2')
     with bz2.BZ2File(file_name, 'wb') as f:
       f.write('\n'.join(lines))
 
@@ -334,8 +366,8 @@ class TextSourceTest(unittest.TestCase):
 
   def test_read_auto_gzip(self):
     _, lines = write_data(15)
-    file_name = tempfile.NamedTemporaryFile(
-        delete=False, prefix=tempfile.template, suffix='.gz').name
+    file_name = self._create_temp_file(suffix='.gz')
+
     with gzip.GzipFile(file_name, 'wb') as f:
       f.write('\n'.join(lines))
 
@@ -346,8 +378,7 @@ class TextSourceTest(unittest.TestCase):
 
   def test_read_bzip2(self):
     _, lines = write_data(15)
-    file_name = tempfile.NamedTemporaryFile(
-        delete=False, prefix=tempfile.template).name
+    file_name = self._create_temp_file()
     with bz2.BZ2File(file_name, 'wb') as f:
       f.write('\n'.join(lines))
 
@@ -360,8 +391,7 @@ class TextSourceTest(unittest.TestCase):
 
   def test_read_gzip(self):
     _, lines = write_data(15)
-    file_name = tempfile.NamedTemporaryFile(
-        delete=False, prefix=tempfile.template).name
+    file_name = self._create_temp_file()
     with gzip.GzipFile(file_name, 'wb') as f:
       f.write('\n'.join(lines))
 
@@ -374,9 +404,9 @@ class TextSourceTest(unittest.TestCase):
     pipeline.run()
 
   def test_read_gzip_large(self):
-    _, lines = write_data(1000)
-    file_name = tempfile.NamedTemporaryFile(
-        delete=False, prefix=tempfile.template).name
+    _, lines = write_data(10000)
+    file_name = self._create_temp_file()
+
     with gzip.GzipFile(file_name, 'wb') as f:
       f.write('\n'.join(lines))
 
@@ -389,9 +419,8 @@ class TextSourceTest(unittest.TestCase):
     pipeline.run()
 
   def test_read_gzip_large_after_splitting(self):
-    _, lines = write_data(1000)
-    file_name = tempfile.NamedTemporaryFile(
-        delete=False, prefix=tempfile.template).name
+    _, lines = write_data(10000)
+    file_name = self._create_temp_file()
     with gzip.GzipFile(file_name, 'wb') as f:
       f.write('\n'.join(lines))
 
@@ -411,26 +440,129 @@ class TextSourceTest(unittest.TestCase):
         reference_source_info, sources_info)
 
   def test_read_gzip_empty_file(self):
-    filename = tempfile.NamedTemporaryFile(
-        delete=False, prefix=tempfile.template).name
+    file_name = self._create_temp_file()
     pipeline = TestPipeline()
     pcoll = pipeline | 'Read' >> ReadFromText(
-        filename,
+        file_name,
         0, CompressionTypes.GZIP,
         True, coders.StrUtf8Coder())
     assert_that(pcoll, equal_to([]))
     pipeline.run()
 
+  def _remove_lines(self, lines, sublist_lengths, num_to_remove):
+    """Utility function to remove num_to_remove lines from each sublist.
+
+    Args:
+      lines: list of items.
+      sublist_lengths: list of integers representing length of sublist
+        corresponding to each source file.
+      num_to_remove: number of lines to remove from each sublist.
+    Returns:
+      remaining lines.
+    """
+    curr = 0
+    result = []
+    for offset in sublist_lengths:
+      end = curr + offset
+      start = min(curr + num_to_remove, end)
+      result += lines[start:end]
+      curr += offset
+    return result
+
+  def _read_skip_header_lines(self, file_or_pattern, skip_header_lines):
+    """Simple wrapper function for instantiating TextSource."""
+    source = TextSource(
+        file_or_pattern,
+        0,
+        CompressionTypes.UNCOMPRESSED,
+        True,
+        coders.StrUtf8Coder(),
+        skip_header_lines=skip_header_lines)
 
-class TextSinkTest(unittest.TestCase):
+    range_tracker = source.get_range_tracker(None, None)
+    return [record for record in source.read(range_tracker)]
+
+  def test_read_skip_header_single(self):
+    file_name, expected_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS)
+    assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS
+    skip_header_lines = 1
+    expected_data = self._remove_lines(expected_data,
+                                       [TextSourceTest.DEFAULT_NUM_RECORDS],
+                                       skip_header_lines)
+    read_data = self._read_skip_header_lines(file_name, skip_header_lines)
+    self.assertEqual(len(expected_data), len(read_data))
+    self.assertItemsEqual(expected_data, read_data)
+
+  def test_read_skip_header_pattern(self):
+    line_counts = [
+        TextSourceTest.DEFAULT_NUM_RECORDS * 5,
+        TextSourceTest.DEFAULT_NUM_RECORDS * 3,
+        TextSourceTest.DEFAULT_NUM_RECORDS * 12,
+        TextSourceTest.DEFAULT_NUM_RECORDS * 8,
+        TextSourceTest.DEFAULT_NUM_RECORDS * 8,
+        TextSourceTest.DEFAULT_NUM_RECORDS * 4
+    ]
+    skip_header_lines = 2
+    pattern, data = write_pattern(line_counts)
+
+    expected_data = self._remove_lines(data, line_counts, skip_header_lines)
+    read_data = self._read_skip_header_lines(pattern, skip_header_lines)
+    self.assertEqual(len(expected_data), len(read_data))
+    self.assertItemsEqual(expected_data, read_data)
+
+  def test_read_skip_header_pattern_insufficient_lines(self):
+    line_counts = [
+        5, 3, # Fewer lines in file than we want to skip
+        12, 8, 8, 4
+    ]
+    skip_header_lines = 4
+    pattern, data = write_pattern(line_counts)
+
+    data = self._remove_lines(data, line_counts, skip_header_lines)
+    read_data = self._read_skip_header_lines(pattern, skip_header_lines)
+    self.assertEqual(len(data), len(read_data))
+    self.assertItemsEqual(data, read_data)
+
+  def test_read_gzip_with_skip_lines(self):
+    _, lines = write_data(15)
+    file_name = self._create_temp_file()
+    with gzip.GzipFile(file_name, 'wb') as f:
+      f.write('\n'.join(lines))
+
+    pipeline = beam.Pipeline('DirectRunner')
+    pcoll = pipeline | 'Read' >> ReadFromText(
+        file_name, 0, CompressionTypes.GZIP,
+        True, coders.StrUtf8Coder(), skip_header_lines=2)
+    assert_that(pcoll, equal_to(lines[2:]))
+    pipeline.run()
+
+  def test_read_after_splitting_skip_header(self):
+    file_name, expected_data = write_data(100)
+    assert len(expected_data) == 100
+    source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
+                        coders.StrUtf8Coder(), skip_header_lines=2)
+    splits = [split for split in source.split(desired_bundle_size=33)]
+
+    reference_source_info = (source, None, None)
+    sources_info = ([
+        (split.source, split.start_position, split.stop_position) for
+        split in splits])
+    self.assertGreater(len(sources_info), 1)
+    reference_lines = source_test_utils.readFromSource(*reference_source_info)
+    split_lines = []
+    for source_info in sources_info:
+      split_lines.extend(source_test_utils.readFromSource(*source_info))
+
+    self.assertEqual(expected_data[2:], reference_lines)
+    self.assertEqual(reference_lines, split_lines)
+
+
+class TextSinkTest(_TestCaseWithTempDirCleanUp):
 
   def setUp(self):
+    super(TextSinkTest, self).setUp()
     self.lines = ['Line %d' % d for d in range(100)]
-    self.path = tempfile.NamedTemporaryFile().name
-
-  def tearDown(self):
-    if os.path.exists(self.path):
-      os.remove(self.path)
+    self.path = self._create_temp_file()
 
   def _write_lines(self, sink, lines):
     f = sink.open(self.path)
@@ -461,7 +593,7 @@ class TextSinkTest(unittest.TestCase):
       self.assertEqual(f.read().splitlines(), self.lines)
 
   def test_write_bzip2_file_auto(self):
-    self.path = tempfile.NamedTemporaryFile(suffix='.bz2').name
+    self.path = self._create_temp_file(suffix='.bz2')
     sink = TextSink(self.path)
     self._write_lines(sink, self.lines)
 
@@ -477,7 +609,7 @@ class TextSinkTest(unittest.TestCase):
       self.assertEqual(f.read().splitlines(), self.lines)
 
   def test_write_gzip_file_auto(self):
-    self.path = tempfile.NamedTemporaryFile(suffix='.gz').name
+    self.path = self._create_temp_file(suffix='.gz')
     sink = TextSink(self.path)
     self._write_lines(sink, self.lines)
 
@@ -492,6 +624,22 @@ class TextSinkTest(unittest.TestCase):
     with gzip.GzipFile(self.path, 'r') as f:
       self.assertEqual(f.read().splitlines(), [])
 
+  def test_write_text_file_with_header(self):
+    header = 'header1\nheader2'
+    sink = TextSink(self.path, header=header)
+    self._write_lines(sink, self.lines)
+
+    with open(self.path, 'r') as f:
+      self.assertEqual(f.read().splitlines(), header.splitlines() + self.lines)
+
+  def test_write_text_file_empty_with_header(self):
+    header = 'header1\nheader2'
+    sink = TextSink(self.path, header=header)
+    self._write_lines(sink, [])
+
+    with open(self.path, 'r') as f:
+      self.assertEqual(f.read().splitlines(), header.splitlines())
+
   def test_write_dataflow(self):
     pipeline = TestPipeline()
     pcoll = pipeline | beam.core.Create(self.lines)
@@ -520,8 +668,11 @@ class TextSinkTest(unittest.TestCase):
 
   def test_write_dataflow_auto_compression_unsharded(self):
     pipeline = TestPipeline()
-    pcoll = pipeline | beam.core.Create(self.lines)
-    pcoll | 'Write' >> WriteToText(self.path + '.gz', shard_name_template='')  # pylint: disable=expression-not-assigned
+    pcoll = pipeline | beam.core.Create('Create', self.lines)
+    pcoll | 'Write' >> WriteToText(  # pylint: disable=expression-not-assigned
+        self.path + '.gz',
+        shard_name_template='')
+
     pipeline.run()
 
     read_result = []
@@ -531,6 +682,23 @@ class TextSinkTest(unittest.TestCase):
 
     self.assertEqual(read_result, self.lines)
 
+  def test_write_dataflow_header(self):
+    pipeline = TestPipeline()
+    pcoll = pipeline | beam.core.Create('Create', self.lines)
+    header_text = 'foo'
+    pcoll | 'Write' >> WriteToText(  # pylint: disable=expression-not-assigned
+        self.path + '.gz',
+        shard_name_template='',
+        header=header_text)
+    pipeline.run()
+
+    read_result = []
+    for file_name in glob.glob(self.path + '*'):
+      with gzip.GzipFile(file_name, 'r') as f:
+        read_result.extend(f.read().splitlines())
+
+    self.assertEqual(read_result, [header_text] + self.lines)
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)

Reply via email to