Repository: beam Updated Branches: refs/heads/master 2eeeaa257 -> 9ec22f173
Add support for reading/writing headers to text files Footer support is not added since it is not clear how to do it for source and not sure about the value it provides. Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/a4201a13 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/a4201a13 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/a4201a13 Branch: refs/heads/master Commit: a4201a13c8b1eb178f5ecd563d0a50994f2578a2 Parents: 2eeeaa2 Author: Slaven Bilac <[email protected]> Authored: Wed Feb 8 10:16:46 2017 -0800 Committer: Robert Bradshaw <[email protected]> Committed: Thu Feb 9 10:42:00 2017 -0800 ---------------------------------------------------------------------- sdks/python/apache_beam/io/fileio.py | 17 +- sdks/python/apache_beam/io/fileio_test.py | 75 ++++++++- sdks/python/apache_beam/io/textio.py | 109 ++++++++---- sdks/python/apache_beam/io/textio_test.py | 224 +++++++++++++++++++++---- 4 files changed, 361 insertions(+), 64 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/a4201a13/sdks/python/apache_beam/io/fileio.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/fileio.py b/sdks/python/apache_beam/io/fileio.py index 59fabb3..3b3a183 100644 --- a/sdks/python/apache_beam/io/fileio.py +++ b/sdks/python/apache_beam/io/fileio.py @@ -95,7 +95,7 @@ class CompressionTypes(object): @classmethod def detect_compression_type(cls, file_path): - """Returns the compression type of a file (based on its suffix)""" + """Returns the compression type of a file (based on its suffix).""" compression_types_by_suffix = {'.bz2': cls.BZIP2, '.gz': cls.GZIP} lowercased_path = file_path.lower() for suffix, compression_type in compression_types_by_suffix.iteritems(): @@ -326,7 +326,8 @@ class _CompressedFile(object): compression_type=CompressionTypes.GZIP, 
read_size=gcsio.DEFAULT_READ_BUFFER_SIZE): if not fileobj: - raise ValueError('fileobj must be opened file but was %s' % fileobj) + raise ValueError('File object must be opened file but was at %s' % + fileobj) if not CompressionTypes.is_valid_compression_type(compression_type): raise TypeError('compression_type must be CompressionType object but ' @@ -339,6 +340,11 @@ class _CompressedFile(object): self._file = fileobj self._compression_type = compression_type + if self._file.tell() != 0: + raise ValueError('File object must be at position 0 but was %d' % + self._file.tell()) + self._uncompressed_position = 0 + if self.readable(): self._read_size = read_size self._read_buffer = cStringIO.StringIO() @@ -375,6 +381,7 @@ class _CompressedFile(object): """Write data to file.""" if not self._compressor: raise ValueError('compressor not initialized') + self._uncompressed_position += len(data) compressed = self._compressor.compress(data) if compressed: self._file.write(compressed) @@ -429,6 +436,7 @@ class _CompressedFile(object): self._read_buffer.seek(self._read_position) result = read_fn() self._read_position += len(result) + self._uncompressed_position += len(result) self._read_buffer.seek(0, os.SEEK_END) # Allow future writes. return result @@ -477,10 +485,15 @@ class _CompressedFile(object): self._file.write(self._compressor.flush()) self._file.flush() + @property def seekable(self): # TODO: Add support for seeking to a file position. 
return False + def tell(self): + """Returns current position in uncompressed file.""" + return self._uncompressed_position + def __enter__(self): return self http://git-wip-us.apache.org/repos/asf/beam/blob/a4201a13/sdks/python/apache_beam/io/fileio_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/fileio_test.py b/sdks/python/apache_beam/io/fileio_test.py index 6c33f53..a963c67 100644 --- a/sdks/python/apache_beam/io/fileio_test.py +++ b/sdks/python/apache_beam/io/fileio_test.py @@ -21,6 +21,7 @@ import glob import logging import os +import shutil import tempfile import unittest @@ -79,6 +80,68 @@ class TestChannelFactory(unittest.TestCase): gcsio_mock.size.assert_called_once_with('gs://bucket/file2') +# TODO: Refactor code so all io tests are using same library +# TestCaseWithTempDirCleanup class. +class _TestCaseWithTempDirCleanUp(unittest.TestCase): + """Base class for TestCases that deals with TempDir clean-up. + + Inherited test cases will call self._new_tempdir() to start a temporary dir + which will be deleted at the end of the tests (when tearDown() is called). 
+ """ + + def setUp(self): + self._tempdirs = [] + + def tearDown(self): + for path in self._tempdirs: + if os.path.exists(path): + shutil.rmtree(path) + self._tempdirs = [] + + def _new_tempdir(self): + result = tempfile.mkdtemp() + self._tempdirs.append(result) + return result + + def _create_temp_file(self, name='', suffix=''): + if not name: + name = tempfile.template + file_name = tempfile.NamedTemporaryFile( + delete=False, prefix=name, + dir=self._new_tempdir(), suffix=suffix).name + return file_name + + +class TestCompressedFile(_TestCaseWithTempDirCleanUp): + + def test_seekable(self): + readable = fileio._CompressedFile(open(self._create_temp_file(), 'r')) + self.assertFalse(readable.seekable) + + writeable = fileio._CompressedFile(open(self._create_temp_file(), 'w')) + self.assertFalse(writeable.seekable) + + def test_tell(self): + lines = ['line%d\n' % i for i in range(10)] + tmpfile = self._create_temp_file() + writeable = fileio._CompressedFile(open(tmpfile, 'w')) + current_offset = 0 + for line in lines: + writeable.write(line) + current_offset += len(line) + self.assertEqual(current_offset, writeable.tell()) + + writeable.close() + readable = fileio._CompressedFile(open(tmpfile)) + current_offset = 0 + while True: + line = readable.readline() + current_offset += len(line) + self.assertEqual(current_offset, readable.tell()) + if not line: + break + + class MyFileSink(fileio.FileSink): def open(self, temp_path): @@ -100,10 +163,10 @@ class MyFileSink(fileio.FileSink): file_handle = fileio.FileSink.close(self, file_handle) -class TestFileSink(unittest.TestCase): +class TestFileSink(_TestCaseWithTempDirCleanUp): def test_file_sink_writing(self): - temp_path = tempfile.NamedTemporaryFile().name + temp_path = os.path.join(self._new_tempdir(), 'filesink') sink = MyFileSink( temp_path, file_name_suffix='.foo', coder=coders.ToStringCoder()) @@ -136,7 +199,7 @@ class TestFileSink(unittest.TestCase): self.assertItemsEqual([shard1, shard2], glob.glob(temp_path 
+ '*')) def test_file_sink_display_data(self): - temp_path = tempfile.NamedTemporaryFile().name + temp_path = os.path.join(self._new_tempdir(), 'display') sink = MyFileSink( temp_path, file_name_suffix='.foo', coder=coders.ToStringCoder()) dd = DisplayData.create_from(sink) @@ -161,7 +224,7 @@ class TestFileSink(unittest.TestCase): open(temp_path + '-00000-of-00001.foo').read(), '[start][end]') def test_fixed_shard_write(self): - temp_path = tempfile.NamedTemporaryFile().name + temp_path = os.path.join(self._new_tempdir(), 'empty') sink = MyFileSink( temp_path, file_name_suffix='.foo', @@ -180,7 +243,7 @@ class TestFileSink(unittest.TestCase): self.assertTrue('][b][' in concat, concat) def test_file_sink_multi_shards(self): - temp_path = tempfile.NamedTemporaryFile().name + temp_path = os.path.join(self._new_tempdir(), 'multishard') sink = MyFileSink( temp_path, file_name_suffix='.foo', coder=coders.ToStringCoder()) @@ -215,7 +278,7 @@ class TestFileSink(unittest.TestCase): self.assertItemsEqual(res, glob.glob(temp_path + '*')) def test_file_sink_io_error(self): - temp_path = tempfile.NamedTemporaryFile().name + temp_path = os.path.join(self._new_tempdir(), 'ioerror') sink = MyFileSink( temp_path, file_name_suffix='.foo', coder=coders.ToStringCoder()) http://git-wip-us.apache.org/repos/asf/beam/blob/a4201a13/sdks/python/apache_beam/io/textio.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py index ed63770..7a35844 100644 --- a/sdks/python/apache_beam/io/textio.py +++ b/sdks/python/apache_beam/io/textio.py @@ -17,7 +17,9 @@ """A source and a sink for reading from and writing to text files.""" + from __future__ import absolute_import +import logging from apache_beam import coders from apache_beam.io import filebasedsource @@ -31,7 +33,7 @@ __all__ = ['ReadFromText', 'WriteToText'] class _TextSource(filebasedsource.FileBasedSource): - """A source for reading 
text files. + r"""A source for reading text files. Parses a text file as newline-delimited elements. Supports newline delimiters '\n' and '\r\n. @@ -71,9 +73,15 @@ class _TextSource(filebasedsource.FileBasedSource): 'size of data %d.', value, len(self._data)) self._position = value - def __init__(self, file_pattern, min_bundle_size, - compression_type, strip_trailing_newlines, coder, - buffer_size=DEFAULT_READ_BUFFER_SIZE, validate=True): + def __init__(self, + file_pattern, + min_bundle_size, + compression_type, + strip_trailing_newlines, + coder, + buffer_size=DEFAULT_READ_BUFFER_SIZE, + validate=True, + skip_header_lines=0): super(_TextSource, self).__init__(file_pattern, min_bundle_size, compression_type=compression_type, validate=validate) @@ -82,6 +90,14 @@ class _TextSource(filebasedsource.FileBasedSource): self._compression_type = compression_type self._coder = coder self._buffer_size = buffer_size + if skip_header_lines < 0: + raise ValueError('Cannot skip negative number of header lines: %d', + skip_header_lines) + elif skip_header_lines > 10: + logging.warning( + 'Skipping %d header lines. Skipping large number of header ' + 'lines might significantly slow down processing.') + self._skip_header_lines = skip_header_lines def display_data(self): parent_dd = super(_TextSource, self).display_data() @@ -101,13 +117,18 @@ class _TextSource(filebasedsource.FileBasedSource): read_buffer = _TextSource.ReadBuffer('', 0) with self.open_file(file_name) as file_to_read: - if start_offset > 0: + position_after_skipping_header_lines = self._skip_lines( + file_to_read, read_buffer, + self._skip_header_lines) if self._skip_header_lines else 0 + start_offset = max(start_offset, position_after_skipping_header_lines) + if start_offset > position_after_skipping_header_lines: # Seeking to one position before the start index and ignoring the # current line. 
If start_position is at beginning of the line, that line # belongs to the current bundle, hence ignoring that is incorrect. # Seeking to one byte before prevents that. file_to_read.seek(start_offset - 1) + read_buffer = _TextSource.ReadBuffer('', 0) sep_bounds = self._find_separator_bounds(file_to_read, read_buffer) if not sep_bounds: # Could not find a separator after (start_offset - 1). This means that @@ -116,14 +137,13 @@ class _TextSource(filebasedsource.FileBasedSource): _, sep_end = sep_bounds read_buffer.data = read_buffer.data[sep_end:] - next_record_start_position = start_offset -1 + sep_end + next_record_start_position = start_offset - 1 + sep_end else: - next_record_start_position = 0 + next_record_start_position = position_after_skipping_header_lines while range_tracker.try_claim(next_record_start_position): record, num_bytes_to_next_record = self._read_record(file_to_read, read_buffer) - # For compressed text files that use an unsplittable OffsetRangeTracker # with infinity as the end position, above 'try_claim()' invocation # would pass for an empty record at the end of file that is not @@ -184,6 +204,20 @@ return True + def _skip_lines(self, file_to_read, read_buffer, num_lines): + """Skip num_lines from file_to_read, return num_lines+1 start position.""" + if file_to_read.tell() > 0: + file_to_read.seek(0) + position = 0 + for _ in range(num_lines): + _, num_bytes_to_next_record = self._read_record(file_to_read, read_buffer) + if num_bytes_to_next_record < 0: + # We reached end of file. It is OK to just break here + # because subsequent _read_record will return the same result. + break + position += num_bytes_to_next_record + return position + def _read_record(self, file_to_read, read_buffer): # Returns a tuple containing the current_record and number of bytes to the # next record starting from 'self._next_position_in_buffer'. 
If EOF is @@ -224,7 +258,8 @@ class _TextSink(fileio.FileSink): num_shards=0, shard_name_template=None, coder=coders.ToStringCoder(), - compression_type=fileio.CompressionTypes.AUTO): + compression_type=fileio.CompressionTypes.AUTO, + header=None): """Initialize a _TextSink. Args: @@ -251,10 +286,12 @@ class _TextSink(fileio.FileSink): generated. The default pattern used is '-SSSSS-of-NNNNN'. coder: Coder used to encode each line. compression_type: Used to handle compressed output files. Typical value - is CompressionTypes.AUTO, in which case the final file path's - extension (as determined by file_path_prefix, file_name_suffix, - num_shards and shard_name_template) will be used to detect the - compression. + is CompressionTypes.AUTO, in which case the final file path's + extension (as determined by file_path_prefix, file_name_suffix, + num_shards and shard_name_template) will be used to detect the + compression. + header: String to write at beginning of file as a header. If not None and + append_trailing_newlines is set, '\n' will be added. Returns: A _TextSink object usable for writing. 
@@ -267,7 +304,16 @@ class _TextSink(fileio.FileSink): coder=coder, mime_type='text/plain', compression_type=compression_type) - self.append_trailing_newlines = append_trailing_newlines + self._append_trailing_newlines = append_trailing_newlines + self._header = header + + def open(self, temp_path): + file_handle = super(_TextSink, self).open(temp_path) + if self._header is not None: + file_handle.write(self._header) + if self._append_trailing_newlines: + file_handle.write('\n') + return file_handle def display_data(self): dd_parent = super(_TextSink, self).display_data() @@ -279,7 +325,7 @@ class _TextSink(fileio.FileSink): def write_encoded_record(self, file_handle, encoded_value): """Writes a single encoded record.""" file_handle.write(encoded_value) - if self.append_trailing_newlines: + if self._append_trailing_newlines: file_handle.write('\n') @@ -299,6 +345,7 @@ class ReadFromText(PTransform): strip_trailing_newlines=True, coder=coders.StrUtf8Coder(), validate=True, + skip_header_lines=0, **kwargs): """Initialize the ReadFromText transform. @@ -317,14 +364,18 @@ class ReadFromText(PTransform): decoding that line. validate: flag to verify that the files exist during the pipeline creation time. + skip_header_lines: Number of header lines to skip. Same number is skipped + from each source file. Must be 0 or higher. Large + number of skipped lines might impact performance. coder: Coder used to decode each line. 
""" super(ReadFromText, self).__init__(**kwargs) self._strip_trailing_newlines = strip_trailing_newlines - self._source = _TextSource(file_pattern, min_bundle_size, compression_type, - strip_trailing_newlines, coder, - validate=validate) + self._source = _TextSource( + file_pattern, min_bundle_size, compression_type, + strip_trailing_newlines, coder, validate=validate, + skip_header_lines=skip_header_lines) def expand(self, pvalue): return pvalue.pipeline | Read(self._source) @@ -333,15 +384,15 @@ class ReadFromText(PTransform): class WriteToText(PTransform): """A PTransform for writing to text files.""" - def __init__( - self, - file_path_prefix, - file_name_suffix='', - append_trailing_newlines=True, - num_shards=0, - shard_name_template=None, - coder=coders.ToStringCoder(), - compression_type=fileio.CompressionTypes.AUTO): + def __init__(self, + file_path_prefix, + file_name_suffix='', + append_trailing_newlines=True, + num_shards=0, + shard_name_template=None, + coder=coders.ToStringCoder(), + compression_type=fileio.CompressionTypes.AUTO, + header=None): """Initialize a WriteToText PTransform. Args: @@ -372,11 +423,13 @@ class WriteToText(PTransform): extension (as determined by file_path_prefix, file_name_suffix, num_shards and shard_name_template) will be used to detect the compression. + header: String to write at beginning of file as a header. If not None and + append_trailing_newlines is set, '\n' will be added. 
""" self._sink = _TextSink(file_path_prefix, file_name_suffix, append_trailing_newlines, num_shards, - shard_name_template, coder, compression_type) + shard_name_template, coder, compression_type, header) def expand(self, pcoll): return pcoll | Write(self._sink) http://git-wip-us.apache.org/repos/asf/beam/blob/a4201a13/sdks/python/apache_beam/io/textio_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py index a7133ed..ea417b0 100644 --- a/sdks/python/apache_beam/io/textio_test.py +++ b/sdks/python/apache_beam/io/textio_test.py @@ -22,6 +22,7 @@ import glob import gzip import logging import os +import shutil import tempfile import unittest @@ -47,7 +48,39 @@ from apache_beam.transforms.util import assert_that from apache_beam.transforms.util import equal_to -class TextSourceTest(unittest.TestCase): +# TODO: Refactor code so all io tests are using same library +# TestCaseWithTempDirCleanup class. +class _TestCaseWithTempDirCleanUp(unittest.TestCase): + """Base class for TestCases that deals with TempDir clean-up. + + Inherited test cases will call self._new_tempdir() to start a temporary dir + which will be deleted at the end of the tests (when tearDown() is called). + """ + + def setUp(self): + self._tempdirs = [] + + def tearDown(self): + for path in self._tempdirs: + if os.path.exists(path): + shutil.rmtree(path) + self._tempdirs = [] + + def _new_tempdir(self): + result = tempfile.mkdtemp() + self._tempdirs.append(result) + return result + + def _create_temp_file(self, name='', suffix=''): + if not name: + name = tempfile.template + file_name = tempfile.NamedTemporaryFile( + delete=False, prefix=name, + dir=self._new_tempdir(), suffix=suffix).name + return file_name + + +class TextSourceTest(_TestCaseWithTempDirCleanUp): # Number of records that will be written by most tests. 
DEFAULT_NUM_RECORDS = 100 @@ -322,8 +355,7 @@ class TextSourceTest(unittest.TestCase): def test_read_auto_bzip2(self): _, lines = write_data(15) - file_name = tempfile.NamedTemporaryFile( - delete=False, prefix=tempfile.template, suffix='.bz2').name + file_name = self._create_temp_file(suffix='.bz2') with bz2.BZ2File(file_name, 'wb') as f: f.write('\n'.join(lines)) @@ -334,8 +366,8 @@ class TextSourceTest(unittest.TestCase): def test_read_auto_gzip(self): _, lines = write_data(15) - file_name = tempfile.NamedTemporaryFile( - delete=False, prefix=tempfile.template, suffix='.gz').name + file_name = self._create_temp_file(suffix='.gz') + with gzip.GzipFile(file_name, 'wb') as f: f.write('\n'.join(lines)) @@ -346,8 +378,7 @@ class TextSourceTest(unittest.TestCase): def test_read_bzip2(self): _, lines = write_data(15) - file_name = tempfile.NamedTemporaryFile( - delete=False, prefix=tempfile.template).name + file_name = self._create_temp_file() with bz2.BZ2File(file_name, 'wb') as f: f.write('\n'.join(lines)) @@ -360,8 +391,7 @@ class TextSourceTest(unittest.TestCase): def test_read_gzip(self): _, lines = write_data(15) - file_name = tempfile.NamedTemporaryFile( - delete=False, prefix=tempfile.template).name + file_name = self._create_temp_file() with gzip.GzipFile(file_name, 'wb') as f: f.write('\n'.join(lines)) @@ -374,9 +404,9 @@ class TextSourceTest(unittest.TestCase): pipeline.run() def test_read_gzip_large(self): - _, lines = write_data(1000) - file_name = tempfile.NamedTemporaryFile( - delete=False, prefix=tempfile.template).name + _, lines = write_data(10000) + file_name = self._create_temp_file() + with gzip.GzipFile(file_name, 'wb') as f: f.write('\n'.join(lines)) @@ -389,9 +419,8 @@ class TextSourceTest(unittest.TestCase): pipeline.run() def test_read_gzip_large_after_splitting(self): - _, lines = write_data(1000) - file_name = tempfile.NamedTemporaryFile( - delete=False, prefix=tempfile.template).name + _, lines = write_data(10000) + file_name = 
self._create_temp_file() with gzip.GzipFile(file_name, 'wb') as f: f.write('\n'.join(lines)) @@ -411,26 +440,129 @@ class TextSourceTest(unittest.TestCase): reference_source_info, sources_info) def test_read_gzip_empty_file(self): - filename = tempfile.NamedTemporaryFile( - delete=False, prefix=tempfile.template).name + file_name = self._create_temp_file() pipeline = TestPipeline() pcoll = pipeline | 'Read' >> ReadFromText( - filename, + file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder()) assert_that(pcoll, equal_to([])) pipeline.run() + def _remove_lines(self, lines, sublist_lengths, num_to_remove): + """Utility function to remove num_to_remove lines from each sublist. + + Args: + lines: list of items. + sublist_lengths: list of integers representing length of sublist + corresponding to each source file. + num_to_remove: number of lines to remove from each sublist. + Returns: + remaining lines. + """ + curr = 0 + result = [] + for offset in sublist_lengths: + end = curr + offset + start = min(curr + num_to_remove, end) + result += lines[start:end] + curr += offset + return result + + def _read_skip_header_lines(self, file_or_pattern, skip_header_lines): + """Simple wrapper function for instantiating TextSource.""" + source = TextSource( + file_or_pattern, + 0, + CompressionTypes.UNCOMPRESSED, + True, + coders.StrUtf8Coder(), + skip_header_lines=skip_header_lines) -class TextSinkTest(unittest.TestCase): + range_tracker = source.get_range_tracker(None, None) + return [record for record in source.read(range_tracker)] + + def test_read_skip_header_single(self): + file_name, expected_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS) + assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS + skip_header_lines = 1 + expected_data = self._remove_lines(expected_data, + [TextSourceTest.DEFAULT_NUM_RECORDS], + skip_header_lines) + read_data = self._read_skip_header_lines(file_name, skip_header_lines) + self.assertEqual(len(expected_data), 
len(read_data)) + self.assertItemsEqual(expected_data, read_data) + + def test_read_skip_header_pattern(self): + line_counts = [ + TextSourceTest.DEFAULT_NUM_RECORDS * 5, + TextSourceTest.DEFAULT_NUM_RECORDS * 3, + TextSourceTest.DEFAULT_NUM_RECORDS * 12, + TextSourceTest.DEFAULT_NUM_RECORDS * 8, + TextSourceTest.DEFAULT_NUM_RECORDS * 8, + TextSourceTest.DEFAULT_NUM_RECORDS * 4 + ] + skip_header_lines = 2 + pattern, data = write_pattern(line_counts) + + expected_data = self._remove_lines(data, line_counts, skip_header_lines) + read_data = self._read_skip_header_lines(pattern, skip_header_lines) + self.assertEqual(len(expected_data), len(read_data)) + self.assertItemsEqual(expected_data, read_data) + + def test_read_skip_header_pattern_insufficient_lines(self): + line_counts = [ + 5, 3, # Fewer lines in file than we want to skip + 12, 8, 8, 4 + ] + skip_header_lines = 4 + pattern, data = write_pattern(line_counts) + + data = self._remove_lines(data, line_counts, skip_header_lines) + read_data = self._read_skip_header_lines(pattern, skip_header_lines) + self.assertEqual(len(data), len(read_data)) + self.assertItemsEqual(data, read_data) + + def test_read_gzip_with_skip_lines(self): + _, lines = write_data(15) + file_name = self._create_temp_file() + with gzip.GzipFile(file_name, 'wb') as f: + f.write('\n'.join(lines)) + + pipeline = beam.Pipeline('DirectRunner') + pcoll = pipeline | 'Read' >> ReadFromText( + file_name, 0, CompressionTypes.GZIP, + True, coders.StrUtf8Coder(), skip_header_lines=2) + assert_that(pcoll, equal_to(lines[2:])) + pipeline.run() + + def test_read_after_splitting_skip_header(self): + file_name, expected_data = write_data(100) + assert len(expected_data) == 100 + source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True, + coders.StrUtf8Coder(), skip_header_lines=2) + splits = [split for split in source.split(desired_bundle_size=33)] + + reference_source_info = (source, None, None) + sources_info = ([ + (split.source, 
split.start_position, split.stop_position) for + split in splits]) + self.assertGreater(len(sources_info), 1) + reference_lines = source_test_utils.readFromSource(*reference_source_info) + split_lines = [] + for source_info in sources_info: + split_lines.extend(source_test_utils.readFromSource(*source_info)) + + self.assertEqual(expected_data[2:], reference_lines) + self.assertEqual(reference_lines, split_lines) + + +class TextSinkTest(_TestCaseWithTempDirCleanUp): def setUp(self): + super(TextSinkTest, self).setUp() self.lines = ['Line %d' % d for d in range(100)] - self.path = tempfile.NamedTemporaryFile().name - - def tearDown(self): - if os.path.exists(self.path): - os.remove(self.path) + self.path = self._create_temp_file() def _write_lines(self, sink, lines): f = sink.open(self.path) @@ -461,7 +593,7 @@ class TextSinkTest(unittest.TestCase): self.assertEqual(f.read().splitlines(), self.lines) def test_write_bzip2_file_auto(self): - self.path = tempfile.NamedTemporaryFile(suffix='.bz2').name + self.path = self._create_temp_file(suffix='.bz2') sink = TextSink(self.path) self._write_lines(sink, self.lines) @@ -477,7 +609,7 @@ class TextSinkTest(unittest.TestCase): self.assertEqual(f.read().splitlines(), self.lines) def test_write_gzip_file_auto(self): - self.path = tempfile.NamedTemporaryFile(suffix='.gz').name + self.path = self._create_temp_file(suffix='.gz') sink = TextSink(self.path) self._write_lines(sink, self.lines) @@ -492,6 +624,22 @@ class TextSinkTest(unittest.TestCase): with gzip.GzipFile(self.path, 'r') as f: self.assertEqual(f.read().splitlines(), []) + def test_write_text_file_with_header(self): + header = 'header1\nheader2' + sink = TextSink(self.path, header=header) + self._write_lines(sink, self.lines) + + with open(self.path, 'r') as f: + self.assertEqual(f.read().splitlines(), header.splitlines() + self.lines) + + def test_write_text_file_empty_with_header(self): + header = 'header1\nheader2' + sink = TextSink(self.path, header=header) + 
self._write_lines(sink, []) + + with open(self.path, 'r') as f: + self.assertEqual(f.read().splitlines(), header.splitlines()) + def test_write_dataflow(self): pipeline = TestPipeline() pcoll = pipeline | beam.core.Create(self.lines) @@ -520,8 +668,11 @@ class TextSinkTest(unittest.TestCase): def test_write_dataflow_auto_compression_unsharded(self): pipeline = TestPipeline() - pcoll = pipeline | beam.core.Create(self.lines) - pcoll | 'Write' >> WriteToText(self.path + '.gz', shard_name_template='') # pylint: disable=expression-not-assigned + pcoll = pipeline | beam.core.Create('Create', self.lines) + pcoll | 'Write' >> WriteToText( # pylint: disable=expression-not-assigned + self.path + '.gz', + shard_name_template='') + pipeline.run() read_result = [] @@ -531,6 +682,23 @@ class TextSinkTest(unittest.TestCase): self.assertEqual(read_result, self.lines) + def test_write_dataflow_header(self): + pipeline = TestPipeline() + pcoll = pipeline | beam.core.Create('Create', self.lines) + header_text = 'foo' + pcoll | 'Write' >> WriteToText( # pylint: disable=expression-not-assigned + self.path + '.gz', + shard_name_template='', + header=header_text) + pipeline.run() + + read_result = [] + for file_name in glob.glob(self.path + '*'): + with gzip.GzipFile(file_name, 'r') as f: + read_result.extend(f.read().splitlines()) + + self.assertEqual(read_result, [header_text] + self.lines) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO)
