Repository: incubator-beam
Updated Branches:
  refs/heads/python-sdk ebae225ed -> 4b7fe2dc5
Adds a text source to Python SDK. Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/2d1e7ff6 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/2d1e7ff6 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/2d1e7ff6 Branch: refs/heads/python-sdk Commit: 2d1e7ff6d342442f83b97c382f08a03e2bac6572 Parents: ebae225 Author: Chamikara Jayalath <chamik...@google.com> Authored: Mon Aug 29 18:08:46 2016 -0700 Committer: Robert Bradshaw <rober...@google.com> Committed: Fri Sep 16 18:01:46 2016 -0700 ---------------------------------------------------------------------- sdks/python/apache_beam/io/__init__.py | 2 + sdks/python/apache_beam/io/filebasedsource.py | 66 ++- .../apache_beam/io/filebasedsource_test.py | 84 ++-- sdks/python/apache_beam/io/fileio.py | 5 + sdks/python/apache_beam/io/fileio_test.py | 81 ---- sdks/python/apache_beam/io/source_test_utils.py | 26 +- sdks/python/apache_beam/io/textio.py | 264 ++++++++++++ sdks/python/apache_beam/io/textio_test.py | 423 +++++++++++++++++++ .../runners/inprocess/inprocess_runner_test.py | 5 +- 9 files changed, 811 insertions(+), 145 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2d1e7ff6/sdks/python/apache_beam/io/__init__.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/__init__.py b/sdks/python/apache_beam/io/__init__.py index c12b5e3..4ce9872 100644 --- a/sdks/python/apache_beam/io/__init__.py +++ b/sdks/python/apache_beam/io/__init__.py @@ -18,6 +18,7 @@ """A package defining several input sources and output sinks.""" # pylint: disable=wildcard-import +from apache_beam.io.avroio import * from apache_beam.io.bigquery import * from apache_beam.io.fileio import * from apache_beam.io.iobase import Read @@ -25,4 +26,5 @@ from apache_beam.io.iobase import Sink from apache_beam.io.iobase import Write from apache_beam.io.iobase import Writer from apache_beam.io.pubsub import * +from apache_beam.io.textio import * from apache_beam.io.range_trackers import * http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2d1e7ff6/sdks/python/apache_beam/io/filebasedsource.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py index 7b9fe47..14de140 100644 --- a/sdks/python/apache_beam/io/filebasedsource.py +++ b/sdks/python/apache_beam/io/filebasedsource.py @@ -27,6 +27,7 @@ For an example implementation of ``FileBasedSource`` see ``avroio.AvroSource``. from multiprocessing.pool import ThreadPool +from apache_beam.internal import pickler from apache_beam.io import fileio from apache_beam.io import iobase @@ -96,7 +97,7 @@ class FileBasedSource(iobase.BoundedSource): file_pattern, min_bundle_size=0, # TODO(BEAM-614) - compression_type=fileio.CompressionTypes.NO_COMPRESSION, + compression_type=fileio.CompressionTypes.UNCOMPRESSED, splittable=True): """Initializes ``FileBasedSource``. @@ -116,9 +117,22 @@ class FileBasedSource(iobase.BoundedSource): it is not possible to efficiently read a data range without decompressing the whole file. Raises: - TypeError: when compression_type is not valid. + TypeError: when compression_type is not valid or if file_pattern is not a + string. ValueError: when compression and splittable files are specified. 
""" + if not isinstance(file_pattern, basestring): + raise TypeError( + '%s: file_pattern must be a string; got %r instead' % + (self.__class__.__name__, file_pattern)) + + if compression_type == fileio.CompressionTypes.AUTO: + raise ValueError('FileBasedSource currently does not support ' + 'CompressionTypes.AUTO. Please explicitly specify the ' + 'compression type or use ' + 'CompressionTypes.UNCOMPRESSED if file is ' + 'uncompressed.') + self._pattern = file_pattern self._concat_source = None self._min_bundle_size = min_bundle_size @@ -126,7 +140,7 @@ class FileBasedSource(iobase.BoundedSource): raise TypeError('compression_type must be CompressionType object but ' 'was %s' % type(compression_type)) self._compression_type = compression_type - if compression_type != fileio.CompressionTypes.NO_COMPRESSION: + if compression_type != fileio.CompressionTypes.UNCOMPRESSED: # We can't split compressed files efficiently so turn off splitting. self._splittable = False else: @@ -152,14 +166,9 @@ class FileBasedSource(iobase.BoundedSource): return self._concat_source def open_file(self, file_name): - raw_file = fileio.ChannelFactory.open( - file_name, 'rb', 'application/octet-stream') - if self._compression_type == fileio.CompressionTypes.NO_COMPRESSION: - return raw_file - else: - return fileio._CompressedFile( # pylint: disable=protected-access - fileobj=raw_file, - compression_type=self.compression_type) + return fileio.ChannelFactory.open( + file_name, 'rb', 'application/octet-stream', + compression_type=self._compression_type) @staticmethod def _estimate_sizes_in_parallel(file_names): @@ -225,13 +234,15 @@ class _SingleFileSource(iobase.BoundedSource): if not (isinstance(start_offset, int) or isinstance(start_offset, long)): raise ValueError( 'start_offset must be a number. Received: %r', start_offset) - if not (isinstance(stop_offset, int) or isinstance(stop_offset, long)): - raise ValueError( - 'stop_offset must be a number. Received: %r', stop_offset) - if start_offset >= stop_offset: - raise ValueError( - 'start_offset must be smaller than stop_offset. Received %d and %d ' - 'for start and stop offsets respectively', start_offset, stop_offset) + if stop_offset != range_trackers.OffsetRangeTracker.OFFSET_INFINITY: + if not (isinstance(stop_offset, int) or isinstance(stop_offset, long)): + raise ValueError( + 'stop_offset must be a number. Received: %r', stop_offset) + if start_offset >= stop_offset: + raise ValueError( + 'start_offset must be smaller than stop_offset. Received %d and %d ' + 'for start and stop offsets respectively', + start_offset, stop_offset) self._file_name = file_name self._is_gcs_file = file_name.startswith('gs://') if file_name else False @@ -255,7 +266,8 @@ class _SingleFileSource(iobase.BoundedSource): yield iobase.SourceBundle( bundle_stop - bundle_start, _SingleFileSource( - self._file_based_source, + # Copying this so that each sub-source gets a fresh instance. + pickler.loads(pickler.dumps(self._file_based_source)), self._file_name, bundle_start, bundle_stop, @@ -264,17 +276,21 @@ class _SingleFileSource(iobase.BoundedSource): bundle_stop) bundle_start = bundle_stop else: + # Returning a single sub-source with end offset set to OFFSET_INFINITY (so + # that all data of the source gets read) since this source is + # unsplittable. Choosing size of the file as end offset will be wrong for + # certain unsplittable source, e.g., compressed sources. 
yield iobase.SourceBundle( stop_offset - start_offset, _SingleFileSource( self._file_based_source, self._file_name, start_offset, - stop_offset, + range_trackers.OffsetRangeTracker.OFFSET_INFINITY, min_bundle_size=self._min_bundle_size ), start_offset, - stop_offset + range_trackers.OffsetRangeTracker.OFFSET_INFINITY ) def estimate_size(self): @@ -284,7 +300,13 @@ class _SingleFileSource(iobase.BoundedSource): if start_position is None: start_position = self._start_offset if stop_position is None: - stop_position = self._stop_offset + # If file is unsplittable we choose OFFSET_INFINITY as the default end + # offset so that all data of the source gets read. Choosing size of the + # file as end offset will be wrong for certain unsplittable source, for + # e.g., compressed sources. + stop_position = ( + self._stop_offset if self._file_based_source.splittable + else range_trackers.OffsetRangeTracker.OFFSET_INFINITY) range_tracker = range_trackers.OffsetRangeTracker( start_position, stop_position) http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2d1e7ff6/sdks/python/apache_beam/io/filebasedsource_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/filebasedsource_test.py b/sdks/python/apache_beam/io/filebasedsource_test.py index 725fd34..c4ad026 100644 --- a/sdks/python/apache_beam/io/filebasedsource_test.py +++ b/sdks/python/apache_beam/io/filebasedsource_test.py @@ -60,14 +60,36 @@ class LineSource(FileBasedSource): f.close() -def _write_data(num_lines, directory=None, prefix=tempfile.template): +class EOL(object): + LF = 1 + CRLF = 2 + MIXED = 3 + LF_WITH_NOTHING_AT_LAST_LINE = 4 + + +def write_data( + num_lines, no_data=False, directory=None, prefix=tempfile.template, + eol=EOL.LF): all_data = [] with tempfile.NamedTemporaryFile( delete=False, dir=directory, prefix=prefix) as f: + sep_values = ['\n', '\r\n'] for i in range(num_lines): - data = 'line' + str(i) + data = '' if no_data else 'line' + str(i) all_data.append(data) - f.write(data + '\n') + + if eol == EOL.LF: + sep = sep_values[0] + elif eol == EOL.CRLF: + sep = sep_values[1] + elif eol == EOL.MIXED: + sep = sep_values[i % len(sep_values)] + elif eol == EOL.LF_WITH_NOTHING_AT_LAST_LINE: + sep = '' if i == (num_lines - 1) else sep_values[0] + else: + raise ValueError('Received unknown value %s for eol.', eol) + + f.write(data + sep) return f.name, all_data @@ -79,22 +101,22 @@ def _write_prepared_data(data, directory=None, prefix=tempfile.template): return f.name -def _write_prepared_pattern(data): +def write_prepared_pattern(data): temp_dir = tempfile.mkdtemp() for d in data: file_name = _write_prepared_data(d, temp_dir, prefix='mytemp') return file_name[:file_name.rfind(os.path.sep)] + os.path.sep + 'mytemp*' -def _write_pattern(lines_per_file): +def write_pattern(lines_per_file, no_data=False): temp_dir = tempfile.mkdtemp() all_data = [] file_name = None start_index = 0 for i in range(len(lines_per_file)): - file_name, data = _write_data(lines_per_file[i], - directory=temp_dir, prefix='mytemp') + file_name, data = write_data(lines_per_file[i], no_data=no_data, + directory=temp_dir, prefix='mytemp') all_data.extend(data) start_index += lines_per_file[i] @@ -183,7 +205,7 @@ class TestFileBasedSource(unittest.TestCase): filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2 def test_fully_read_single_file(self): - file_name, expected_data = _write_data(10) + file_name, expected_data = write_data(10) assert len(expected_data) == 10 fbs = 
LineSource(file_name) range_tracker = fbs.get_range_tracker(None, None) @@ -191,7 +213,7 @@ class TestFileBasedSource(unittest.TestCase): self.assertItemsEqual(expected_data, read_data) def test_fully_read_file_pattern(self): - pattern, expected_data = _write_pattern([5, 3, 12, 8, 8, 4]) + pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4]) assert len(expected_data) == 40 fbs = LineSource(pattern) range_tracker = fbs.get_range_tracker(None, None) @@ -199,7 +221,7 @@ class TestFileBasedSource(unittest.TestCase): self.assertItemsEqual(expected_data, read_data) def test_fully_read_file_pattern_with_empty_files(self): - pattern, expected_data = _write_pattern([5, 0, 12, 0, 8, 0]) + pattern, expected_data = write_pattern([5, 0, 12, 0, 8, 0]) assert len(expected_data) == 25 fbs = LineSource(pattern) range_tracker = fbs.get_range_tracker(None, None) @@ -207,24 +229,24 @@ class TestFileBasedSource(unittest.TestCase): self.assertItemsEqual(expected_data, read_data) def test_estimate_size_of_file(self): - file_name, expected_data = _write_data(10) + file_name, expected_data = write_data(10) assert len(expected_data) == 10 fbs = LineSource(file_name) self.assertEquals(10 * 6, fbs.estimate_size()) def test_estimate_size_of_pattern(self): - pattern, expected_data = _write_pattern([5, 3, 10, 8, 8, 4]) + pattern, expected_data = write_pattern([5, 3, 10, 8, 8, 4]) assert len(expected_data) == 38 fbs = LineSource(pattern) self.assertEquals(38 * 6, fbs.estimate_size()) - pattern, expected_data = _write_pattern([5, 3, 9]) + pattern, expected_data = write_pattern([5, 3, 9]) assert len(expected_data) == 17 fbs = LineSource(pattern) self.assertEquals(17 * 6, fbs.estimate_size()) def test_splits_into_subranges(self): - pattern, expected_data = _write_pattern([5, 9, 6]) + pattern, expected_data = write_pattern([5, 9, 6]) assert len(expected_data) == 20 fbs = LineSource(pattern) splits = [split for split in fbs.split(desired_bundle_size=15)] @@ -235,7 +257,7 @@ class TestFileBasedSource(unittest.TestCase): assert len(splits) == expected_num_splits def test_read_splits_single_file(self): - file_name, expected_data = _write_data(100) + file_name, expected_data = write_data(100) assert len(expected_data) == 100 fbs = LineSource(file_name) splits = [split for split in fbs.split(desired_bundle_size=33)] @@ -252,7 +274,7 @@ class TestFileBasedSource(unittest.TestCase): self.assertItemsEqual(expected_data, read_data) def test_read_splits_file_pattern(self): - pattern, expected_data = _write_pattern([34, 66, 40, 24, 24, 12]) + pattern, expected_data = write_pattern([34, 66, 40, 24, 24, 12]) assert len(expected_data) == 200 fbs = LineSource(pattern) splits = [split for split in fbs.split(desired_bundle_size=50)] @@ -276,34 +298,34 @@ class TestFileBasedSource(unittest.TestCase): pipeline.run() def test_dataflow_file(self): - file_name, expected_data = _write_data(100) + file_name, expected_data = write_data(100) assert len(expected_data) == 100 self._run_dataflow_test(file_name, expected_data) def test_dataflow_pattern(self): - pattern, expected_data = _write_pattern([34, 66, 40, 24, 24, 12]) + pattern, expected_data = write_pattern([34, 66, 40, 24, 24, 12]) assert len(expected_data) == 200 self._run_dataflow_test(pattern, expected_data) def test_unsplittable_does_not_split(self): - pattern, expected_data = _write_pattern([5, 9, 6]) + pattern, expected_data = write_pattern([5, 9, 6]) assert len(expected_data) == 20 fbs = LineSource(pattern, splittable=False) splits = [split for split in 
fbs.split(desired_bundle_size=15)] self.assertEquals(3, len(splits)) def test_dataflow_file_unsplittable(self): - file_name, expected_data = _write_data(100) + file_name, expected_data = write_data(100) assert len(expected_data) == 100 self._run_dataflow_test(file_name, expected_data, False) def test_dataflow_pattern_unsplittable(self): - pattern, expected_data = _write_pattern([34, 66, 40, 24, 24, 12]) + pattern, expected_data = write_pattern([34, 66, 40, 24, 24, 12]) assert len(expected_data) == 200 self._run_dataflow_test(pattern, expected_data, False) def test_read_gzip_file(self): - _, lines = _write_data(10) + _, lines = write_data(10) filename = tempfile.NamedTemporaryFile( delete=False, prefix=tempfile.template).name with gzip.GzipFile(filename, 'wb') as f: @@ -317,7 +339,7 @@ class TestFileBasedSource(unittest.TestCase): assert_that(pcoll, equal_to(lines)) def test_read_zlib_file(self): - _, lines = _write_data(10) + _, lines = write_data(10) compressobj = zlib.compressobj( zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, zlib.MAX_WBITS) compressed = compressobj.compress('\n'.join(lines)) + compressobj.flush() @@ -331,7 +353,7 @@ class TestFileBasedSource(unittest.TestCase): assert_that(pcoll, equal_to(lines)) def test_read_zlib_pattern(self): - _, lines = _write_data(200) + _, lines = write_data(200) splits = [0, 34, 100, 140, 164, 188, 200] chunks = [lines[splits[i-1]:splits[i]] for i in xrange(1, len(splits))] compressed_chunks = [] @@ -340,7 +362,7 @@ class TestFileBasedSource(unittest.TestCase): zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, zlib.MAX_WBITS) compressed_chunks.append( compressobj.compress('\n'.join(c)) + compressobj.flush()) - file_pattern = _write_prepared_pattern(compressed_chunks) + file_pattern = write_prepared_pattern(compressed_chunks) pipeline = beam.Pipeline('DirectPipelineRunner') pcoll = pipeline | 'Read' >> beam.Read(LineSource( file_pattern, @@ -407,7 +429,7 @@ class TestSingleFileSource(unittest.TestCase): def test_read_range_at_beginning(self): fbs = LineSource('dymmy_pattern') - file_name, expected_data = _write_data(10) + file_name, expected_data = write_data(10) assert len(expected_data) == 10 source = SingleFileSource(fbs, file_name, 0, 10 * 6) @@ -418,7 +440,7 @@ class TestSingleFileSource(unittest.TestCase): def test_read_range_at_end(self): fbs = LineSource('dymmy_pattern') - file_name, expected_data = _write_data(10) + file_name, expected_data = write_data(10) assert len(expected_data) == 10 source = SingleFileSource(fbs, file_name, 0, 10 * 6) @@ -429,7 +451,7 @@ class TestSingleFileSource(unittest.TestCase): def test_read_range_at_middle(self): fbs = LineSource('dymmy_pattern') - file_name, expected_data = _write_data(10) + file_name, expected_data = write_data(10) assert len(expected_data) == 10 source = SingleFileSource(fbs, file_name, 0, 10 * 6) @@ -440,7 +462,7 @@ class TestSingleFileSource(unittest.TestCase): def test_produces_splits_desiredsize_large_than_size(self): fbs = LineSource('dymmy_pattern') - file_name, expected_data = _write_data(10) + file_name, expected_data = write_data(10) assert len(expected_data) == 10 source = SingleFileSource(fbs, file_name, 0, 10 * 6) splits = [split for split in source.split(desired_bundle_size=100)] @@ -456,7 +478,7 @@ class TestSingleFileSource(unittest.TestCase): def test_produces_splits_desiredsize_smaller_than_size(self): fbs = LineSource('dymmy_pattern') - file_name, expected_data = _write_data(10) + file_name, expected_data = write_data(10) assert len(expected_data) == 10 source = 
SingleFileSource(fbs, file_name, 0, 10 * 6) splits = [split for split in source.split(desired_bundle_size=25)] @@ -474,7 +496,7 @@ class TestSingleFileSource(unittest.TestCase): def test_produce_split_with_start_and_end_positions(self): fbs = LineSource('dymmy_pattern') - file_name, expected_data = _write_data(10) + file_name, expected_data = write_data(10) assert len(expected_data) == 10 source = SingleFileSource(fbs, file_name, 0, 10 * 6) splits = [split for split in http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2d1e7ff6/sdks/python/apache_beam/io/fileio.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/fileio.py b/sdks/python/apache_beam/io/fileio.py index bc93138..e3d4dae 100644 --- a/sdks/python/apache_beam/io/fileio.py +++ b/sdks/python/apache_beam/io/fileio.py @@ -944,6 +944,11 @@ class TextFileSink(FileSink): compression_type=compression_type) self.append_trailing_newlines = append_trailing_newlines + if type(self) is TextFileSink: + logging.warning('Direct usage of TextFileSink is deprecated. Please use ' + '\'textio.WriteToText()\' instead of directly ' + 'instantiating a TextFileSink object.') + def write_encoded_record(self, file_handle, encoded_value): """Writes a single encoded record.""" file_handle.write(encoded_value) http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2d1e7ff6/sdks/python/apache_beam/io/fileio_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/fileio_test.py b/sdks/python/apache_beam/io/fileio_test.py index 6c6fe12..7da0149 100644 --- a/sdks/python/apache_beam/io/fileio_test.py +++ b/sdks/python/apache_beam/io/fileio_test.py @@ -685,87 +685,6 @@ class TestNativeTextFileSink(unittest.TestCase): zlib.decompress(f.read(), zlib.MAX_WBITS).splitlines(), []) -class TestTextFileSink(unittest.TestCase): - - def setUp(self): - self.lines = ['Line %d' % d for d in range(100)] - self.path = tempfile.NamedTemporaryFile().name - - def _write_lines(self, sink, lines): - f = sink.open(self.path) - for line in lines: - sink.write_record(f, line) - sink.close(f) - - def test_write_text_file(self): - sink = fileio.TextFileSink(self.path) - self._write_lines(sink, self.lines) - - with open(self.path, 'r') as f: - self.assertEqual(f.read().splitlines(), self.lines) - - def test_write_text_file_empty(self): - sink = fileio.TextFileSink(self.path) - self._write_lines(sink, []) - - with open(self.path, 'r') as f: - self.assertEqual(f.read().splitlines(), []) - - def test_write_gzip_file(self): - sink = fileio.TextFileSink( - self.path, compression_type=fileio.CompressionTypes.GZIP) - self._write_lines(sink, self.lines) - - with gzip.GzipFile(self.path, 'r') as f: - self.assertEqual(f.read().splitlines(), self.lines) - - def test_write_gzip_file_auto(self): - self.path = tempfile.NamedTemporaryFile(suffix='.gz').name - sink = fileio.TextFileSink(self.path) - self._write_lines(sink, self.lines) - - with gzip.GzipFile(self.path, 'r') as f: - self.assertEqual(f.read().splitlines(), self.lines) - - def test_write_gzip_file_empty(self): - sink = fileio.TextFileSink( - self.path, compression_type=fileio.CompressionTypes.GZIP) - self._write_lines(sink, []) - - with gzip.GzipFile(self.path, 'r') as f: - self.assertEqual(f.read().splitlines(), []) - - def test_write_zlib_file(self): - sink = fileio.TextFileSink( - self.path, compression_type=fileio.CompressionTypes.ZLIB) - self._write_lines(sink, self.lines) - - with open(self.path, 'r') 
as f: - content = f.read() - self.assertEqual( - zlib.decompress(content, zlib.MAX_WBITS).splitlines(), self.lines) - - def test_write_zlib_file_auto(self): - self.path = tempfile.NamedTemporaryFile(suffix='.Z').name - sink = fileio.TextFileSink(self.path) - self._write_lines(sink, self.lines) - - with open(self.path, 'r') as f: - content = f.read() - self.assertEqual( - zlib.decompress(content, zlib.MAX_WBITS).splitlines(), self.lines) - - def test_write_zlib_file_empty(self): - sink = fileio.TextFileSink( - self.path, compression_type=fileio.CompressionTypes.ZLIB) - self._write_lines(sink, []) - - with open(self.path, 'r') as f: - content = f.read() - self.assertEqual( - zlib.decompress(content, zlib.MAX_WBITS).splitlines(), []) - - class MyFileSink(fileio.FileSink): def open(self, temp_path): http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2d1e7ff6/sdks/python/apache_beam/io/source_test_utils.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/source_test_utils.py b/sdks/python/apache_beam/io/source_test_utils.py index 9968e9d..4e3a3e3 100644 --- a/sdks/python/apache_beam/io/source_test_utils.py +++ b/sdks/python/apache_beam/io/source_test_utils.py @@ -48,6 +48,7 @@ from collections import namedtuple import logging from multiprocessing.pool import ThreadPool +from apache_beam.internal import pickler from apache_beam.io import iobase @@ -80,7 +81,7 @@ def readFromSource(source, start_position, stop_position): values = [] range_tracker = source.get_range_tracker(start_position, stop_position) assert isinstance(range_tracker, iobase.RangeTracker) - reader = source.read(range_tracker) + reader = _copy_source(source).read(range_tracker) for value in reader: values.append(value) @@ -172,7 +173,7 @@ def assertSplitAtFractionBehavior(source, num_items_to_read_before_split, source while the second value of the tuple will be '-1'. """ assert isinstance(source, iobase.BoundedSource) - expected_items = readFromSource(source, None, None) + expected_items = readFromSource(_copy_source(source), None, None) return _assertSplitAtFractionBehavior( source, expected_items, num_items_to_read_before_split, split_fraction, expected_outcome) @@ -180,12 +181,12 @@ def assertSplitAtFractionBehavior(source, num_items_to_read_before_split, def _assertSplitAtFractionBehavior( source, expected_items, num_items_to_read_before_split, - split_fraction, expected_outcome): + split_fraction, expected_outcome, start_position=None, stop_position=None): - range_tracker = source.get_range_tracker(None, None) + range_tracker = source.get_range_tracker(start_position, stop_position) assert isinstance(range_tracker, iobase.RangeTracker) current_items = [] - reader = source.read(range_tracker) + reader = _copy_source(source).read(range_tracker) # Reading 'num_items_to_read_before_split' items. reader_iter = iter(reader) for _ in range(num_items_to_read_before_split): @@ -353,7 +354,8 @@ def assertSplitAtFractionFails(source, num_items_to_read_before_split, def assertSplitAtFractionBinary(source, expected_items, num_items_to_read_before_split, left_fraction, left_result, - right_fraction, right_result, stats): + right_fraction, right_result, stats, + start_position=None, stop_position=None): """Performs dynamic work rebalancing for fractions within a given range. 
Asserts that given a start position, a source can be split at every @@ -419,7 +421,9 @@ MAX_CONCURRENT_SPLITTING_TRIALS_PER_ITEM = 100 MAX_CONCURRENT_SPLITTING_TRIALS_TOTAL = 1000 -def assertSplitAtFractionExhaustive(source, perform_multi_threaded_test=True): +def assertSplitAtFractionExhaustive( + source, start_position=None, stop_position=None, + perform_multi_threaded_test=True): """Performs and tests dynamic work rebalancing exhaustively. Asserts that for each possible start position, a source can be split at @@ -436,7 +440,7 @@ def assertSplitAtFractionExhaustive(source, perform_multi_threaded_test=True): ValueError: if the exhaustive splitting test fails. """ - expected_items = readFromSource(source, None, None) + expected_items = readFromSource(source, start_position, stop_position) if not expected_items: raise ValueError('Source %r is empty.', source) @@ -533,7 +537,7 @@ def _assertSplitAtFractionConcurrent( range_tracker = source.get_range_tracker(None, None) stop_position_before_split = range_tracker.stop_position() - reader = source.read(range_tracker) + reader = _copy_source(source).read(range_tracker) reader_iter = iter(reader) current_items = [] @@ -572,3 +576,7 @@ def _assertSplitAtFractionConcurrent( primary_range, residual_range, split_fraction) return res[1] > 0 + + +def _copy_source(source): + return pickler.loads(pickler.dumps(source)) http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2d1e7ff6/sdks/python/apache_beam/io/textio.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py new file mode 100644 index 0000000..28fd949 --- /dev/null +++ b/sdks/python/apache_beam/io/textio.py @@ -0,0 +1,264 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""A source and a sink for reading from and writing to text files.""" + +from __future__ import absolute_import + +from apache_beam import coders +from apache_beam.io import filebasedsource +from apache_beam.io import fileio +from apache_beam.io.iobase import Read +from apache_beam.io.iobase import Write +from apache_beam.transforms import PTransform + + +class _TextSource(filebasedsource.FileBasedSource): + """A source for reading text files. + + Parses a text file as newline-delimited elements. Supports newline delimiters + '\n' and '\r\n. + + This implementation only supports reading text encoded using UTF-8 or + ASCII. 
+ """ + + DEFAULT_READ_BUFFER_SIZE = 8192 + + def __init__(self, file_pattern, min_bundle_size, + compression_type, strip_trailing_newlines, coder, + buffer_size=DEFAULT_READ_BUFFER_SIZE): + super(_TextSource, self).__init__(file_pattern, min_bundle_size, + compression_type=compression_type) + self._buffer = '' + self._next_position_in_buffer = 0 + self._file = None + self._strip_trailing_newlines = strip_trailing_newlines + self._compression_type = compression_type + self._coder = coder + self._buffer_size = buffer_size + + def read_records(self, file_name, range_tracker): + start_offset = range_tracker.start_position() + + self._file = self.open_file(file_name) + try: + if start_offset > 0: + # Seeking to one position before the start index and ignoring the + # current line. If start_position is at beginning if the line, that line + # belongs to the current bundle, hence ignoring that is incorrect. + # Seeking to one byte before prevents that. + + self._file.seek(start_offset - 1) + sep_bounds = self._find_separator_bounds() + if not sep_bounds: + # Could not find a separator after (start_offset - 1). This means that + # none of the records within the file belongs to the current source. + return + + _, sep_end = sep_bounds + self._buffer = self._buffer[sep_end:] + next_record_start_position = start_offset -1 + sep_end + else: + next_record_start_position = 0 + + while range_tracker.try_claim(next_record_start_position): + record, num_bytes_to_next_record = self._read_record() + yield self._coder.decode(record) + if num_bytes_to_next_record < 0: + break + next_record_start_position += num_bytes_to_next_record + finally: + self._file.close() + + def _find_separator_bounds(self): + # Determines the start and end positions within 'self._buffer' of the next + # separator starting from 'self._next_position_in_buffer'. + # Currently supports following separators. + # * '\n' + # * '\r\n' + # This method may increase the size of buffer but it will not decrease the + # size of it. + + current_pos = self._next_position_in_buffer + + while True: + if current_pos >= len(self._buffer): + # Ensuring that there are enough bytes to determine if there is a '\n' + # at current_pos. + if not self._try_to_ensure_num_bytes_in_buffer(current_pos + 1): + return + + # Using find() here is more efficient than a linear scan of the byte + # array. + next_lf = self._buffer.find('\n', current_pos) + if next_lf >= 0: + if self._buffer[next_lf - 1] == '\r': + return (next_lf - 1, next_lf + 1) + else: + return (next_lf, next_lf + 1) + + current_pos = len(self._buffer) + + def _try_to_ensure_num_bytes_in_buffer(self, num_bytes): + # Tries to ensure that there are at least num_bytes bytes in the buffer. + # Returns True if this can be fulfilled, returned False if this cannot be + # fulfilled due to reaching EOF. + while len(self._buffer) < num_bytes: + read_data = self._file.read(self._buffer_size) + if not read_data: + return False + + self._buffer += read_data + + return True + + def _read_record(self): + # Returns a tuple containing the current_record and number of bytes to the + # next record starting from 'self._next_position_in_buffer'. If EOF is + # reached, returns a tuple containing the current record and -1. + + if self._next_position_in_buffer > self._buffer_size: + # Buffer is too large. Truncating it and adjusting + # self._next_position_in_buffer. 
+ self._buffer = self._buffer[self._next_position_in_buffer:] + self._next_position_in_buffer = 0 + + record_start_position_in_buffer = self._next_position_in_buffer + sep_bounds = self._find_separator_bounds() + self._next_position_in_buffer = sep_bounds[1] if sep_bounds else len( + self._buffer) + + if not sep_bounds: + # Reached EOF. Bytes up to the EOF is the next record. Returning '-1' for + # the starting position of the next record. + return (self._buffer[record_start_position_in_buffer:], -1) + + if self._strip_trailing_newlines: + # Current record should not contain the separator. + return (self._buffer[record_start_position_in_buffer:sep_bounds[0]], + sep_bounds[1] - record_start_position_in_buffer) + else: + # Current record should contain the separator. + return (self._buffer[record_start_position_in_buffer:sep_bounds[1]], + sep_bounds[1] - record_start_position_in_buffer) + + +class _TextSink(fileio.TextFileSink): + # TODO: Move code from 'fileio.TextFileSink' to here. + pass + + +class ReadFromText(PTransform): + """A PTransform for reading text files. + + Parses a text file as newline-delimited elements, by default assuming + UTF-8 encoding. Supports newline delimiters '\n' and '\r\n'. + + This implementation only supports reading text encoded using UTF-8 or ASCII. + This does not support other encodings such as UTF-16 or UTF-32.""" + + def __init__(self, file_pattern=None, min_bundle_size=0, + compression_type=fileio.CompressionTypes.UNCOMPRESSED, + strip_trailing_newlines=True, + coder=coders.StrUtf8Coder(), **kwargs): + """Initialize the ReadFromText transform. + + Args: + file_pattern: The file path to read from as a local file path or a GCS + gs:// path. The path can contain glob characters (*, ?, and [...] + sets). + min_bundle_size: Minimum size of bundles that should be generated when + splitting this source into bundles. See + ``FileBasedSource`` for more details. + compression_type: Used to handle compressed input files. Should be an + object of type fileio.CompressionTypes. + strip_trailing_newlines: Indicates whether this source should remove + the newline char in each line it reads before + decoding that line. + coder: Coder used to decode each line. + """ + + super(ReadFromText, self).__init__(**kwargs) + self._file_pattern = file_pattern + self._min_bundle_size = min_bundle_size + self._compression_type = compression_type + self._strip_trailing_newlines = strip_trailing_newlines + self._coder = coder + + def apply(self, pcoll): + return pcoll | Read(_TextSource( + self._file_pattern, + self._min_bundle_size, + self._compression_type, + self._strip_trailing_newlines, + self._coder)) + + +class WriteToText(PTransform): + """A PTransform for writing to text files.""" + + def __init__(self, + file_path_prefix, + file_name_suffix='', + append_trailing_newlines=True, + num_shards=0, + shard_name_template=None, + coder=coders.ToStringCoder(), + compression_type=fileio.CompressionTypes.NO_COMPRESSION, + ): + """Initialize a WriteToText PTransform. + + Args: + file_path_prefix: The file path to write to. The files written will begin + with this prefix, followed by a shard identifier (see num_shards), and + end in a common extension, if given by file_name_suffix. In most cases, + only this argument is specified and num_shards, shard_name_template, and + file_name_suffix use default values. + file_name_suffix: Suffix for the files written. + append_trailing_newlines: indicate whether this sink should write an + additional newline char after writing each element. 
+ num_shards: The number of files (shards) used for output. If not set, the + service will decide on the optimal number of shards. + Constraining the number of shards is likely to reduce + the performance of a pipeline. Setting this value is not recommended + unless you require a specific number of output files. + shard_name_template: A template string containing placeholders for + the shard number and shard count. Currently only '' and + '-SSSSS-of-NNNNN' are patterns accepted by the service. + When constructing a filename for a particular shard number, the + upper-case letters 'S' and 'N' are replaced with the 0-padded shard + number and shard count respectively. This argument can be '' in which + case it behaves as if num_shards was set to 1 and only one file will be + generated. The default pattern used is '-SSSSS-of-NNNNN'. + coder: Coder used to encode each line. + compression_type: Type of compression to use for this sink. + """ + + self._file_path_prefix = file_path_prefix + self._file_name_suffix = file_name_suffix + self._append_trailing_newlines = append_trailing_newlines + self._num_shards = num_shards + self._shard_name_template = shard_name_template + self._coder = coder + self._compression_type = compression_type + + def apply(self, pcoll): + return pcoll | Write(_TextSink( + self._file_path_prefix, self._file_name_suffix, + self._append_trailing_newlines, self._num_shards, + self._shard_name_template, self._coder, self._compression_type)) http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2d1e7ff6/sdks/python/apache_beam/io/textio_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py new file mode 100644 index 0000000..3fa0f9a --- /dev/null +++ b/sdks/python/apache_beam/io/textio_test.py @@ -0,0 +1,423 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for textio module.""" + +import glob +import gzip +import logging +import tempfile +import unittest +import zlib + +import apache_beam as beam +import apache_beam.io.source_test_utils as source_test_utils + +# Importing following private classes for testing. 
+from apache_beam.io.textio import _TextSink as TextSink +from apache_beam.io.textio import _TextSource as TextSource + +from apache_beam.io.textio import ReadFromText +from apache_beam.io.textio import WriteToText + +from apache_beam import coders +from apache_beam.io.filebasedsource_test import EOL +from apache_beam.io.filebasedsource_test import write_data +from apache_beam.io.filebasedsource_test import write_pattern +from apache_beam.io.fileio import CompressionTypes + +from apache_beam.transforms.util import assert_that +from apache_beam.transforms.util import equal_to + + +class TextSourceTest(unittest.TestCase): + + # Number of records that will be written by most tests. + DEFAULT_NUM_RECORDS = 100 + + def _run_read_test(self, file_or_pattern, expected_data, + buffer_size=DEFAULT_NUM_RECORDS): + # Since each record usually takes more than 1 byte, default buffer size is + # smaller than the total size of the file. This is done to + # increase test coverage for cases that hit the buffer boundary. + source = TextSource(file_or_pattern, 0, CompressionTypes.UNCOMPRESSED, + True, coders.StrUtf8Coder(), buffer_size) + range_tracker = source.get_range_tracker(None, None) + read_data = [record for record in source.read(range_tracker)] + self.assertItemsEqual(expected_data, read_data) + + def test_read_single_file(self): + file_name, expected_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS) + assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS + self._run_read_test(file_name, expected_data) + + def test_read_single_file_smaller_than_default_buffer(self): + file_name, expected_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS) + self._run_read_test(file_name, expected_data, + buffer_size=TextSource.DEFAULT_READ_BUFFER_SIZE) + + def test_read_single_file_larger_than_default_buffer(self): + file_name, expected_data = write_data(TextSource.DEFAULT_READ_BUFFER_SIZE) + self._run_read_test(file_name, expected_data, + buffer_size=TextSource.DEFAULT_READ_BUFFER_SIZE) + + def test_read_file_pattern(self): + pattern, expected_data = write_pattern( + [TextSourceTest.DEFAULT_NUM_RECORDS * 5, + TextSourceTest.DEFAULT_NUM_RECORDS * 3, + TextSourceTest.DEFAULT_NUM_RECORDS * 12, + TextSourceTest.DEFAULT_NUM_RECORDS * 8, + TextSourceTest.DEFAULT_NUM_RECORDS * 8, + TextSourceTest.DEFAULT_NUM_RECORDS * 4]) + assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS * 40 + self._run_read_test(pattern, expected_data) + + def test_read_single_file_windows_eol(self): + file_name, expected_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS, + eol=EOL.CRLF) + assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS + self._run_read_test(file_name, expected_data) + + def test_read_single_file_mixed_eol(self): + file_name, expected_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS, + eol=EOL.MIXED) + assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS + self._run_read_test(file_name, expected_data) + + def test_read_single_file_last_line_no_eol(self): + file_name, expected_data = write_data( + TextSourceTest.DEFAULT_NUM_RECORDS, + eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE) + assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS + self._run_read_test(file_name, expected_data) + + def test_read_single_file_single_line_no_eol(self): + file_name, expected_data = write_data( + 1, eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE) + + assert len(expected_data) == 1 + self._run_read_test(file_name, expected_data) + + def test_read_empty_single_file(self): + file_name, written_data = 
write_data( + 1, no_data=True, eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE) + + assert len(written_data) == 1 + # written data has a single entry with an empty string. Reading the source + # should not produce anything since we only wrote a single empty string + # without an end of line character. + self._run_read_test(file_name, []) + + def test_read_single_file_with_empty_lines(self): + file_name, expected_data = write_data( + TextSourceTest.DEFAULT_NUM_RECORDS, no_data=True, eol=EOL.LF) + + assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS + assert not expected_data[0] + + self._run_read_test(file_name, expected_data) + + def test_read_single_file_without_striping_eol_lf(self): + file_name, written_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS, + eol=EOL.LF) + assert len(written_data) == TextSourceTest.DEFAULT_NUM_RECORDS + source = TextSource(file_name, 0, + CompressionTypes.UNCOMPRESSED, + False, coders.StrUtf8Coder()) + + range_tracker = source.get_range_tracker(None, None) + read_data = [record for record in source.read(range_tracker)] + self.assertItemsEqual([line + '\n' for line in written_data], read_data) + + def test_read_single_file_without_striping_eol_crlf(self): + file_name, written_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS, + eol=EOL.CRLF) + assert len(written_data) == TextSourceTest.DEFAULT_NUM_RECORDS + source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, + False, coders.StrUtf8Coder()) + + range_tracker = source.get_range_tracker(None, None) + read_data = [record for record in source.read(range_tracker)] + self.assertItemsEqual([line + '\r\n' for line in written_data], read_data) + + def test_read_file_pattern_with_empty_files(self): + pattern, expected_data = write_pattern( + [5 * TextSourceTest.DEFAULT_NUM_RECORDS, + 3 * TextSourceTest.DEFAULT_NUM_RECORDS, + 12 * TextSourceTest.DEFAULT_NUM_RECORDS, + 8 * TextSourceTest.DEFAULT_NUM_RECORDS, + 8 * TextSourceTest.DEFAULT_NUM_RECORDS, + 4 * TextSourceTest.DEFAULT_NUM_RECORDS], + no_data=True) + assert len(expected_data) == 40 * TextSourceTest.DEFAULT_NUM_RECORDS + assert not expected_data[0] + self._run_read_test(pattern, expected_data) + + def test_read_after_splitting(self): + file_name, expected_data = write_data(10) + assert len(expected_data) == 10 + source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True, + coders.StrUtf8Coder()) + splits = [split for split in source.split(desired_bundle_size=33)] + + reference_source_info = (source, None, None) + sources_info = ([ + (split.source, split.start_position, split.stop_position) for + split in splits]) + source_test_utils.assertSourcesEqualReferenceSource( + reference_source_info, sources_info) + + def test_progress(self): + file_name, expected_data = write_data(10) + assert len(expected_data) == 10 + source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True, + coders.StrUtf8Coder()) + splits = [split for split in source.split(desired_bundle_size=100000)] + assert len(splits) == 1 + fraction_consumed_report = [] + range_tracker = splits[0].source.get_range_tracker( + splits[0].start_position, splits[0].stop_position) + for _ in splits[0].source.read(range_tracker): + fraction_consumed_report.append(range_tracker.fraction_consumed()) + + self.assertEqual( + [float(i) / 10 for i in range(0, 10)], fraction_consumed_report) + + def test_dynamic_work_rebalancing(self): + file_name, expected_data = write_data(15) + assert len(expected_data) == 15 + source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, 
True, + coders.StrUtf8Coder()) + splits = [split for split in source.split(desired_bundle_size=100000)] + assert len(splits) == 1 + source_test_utils.assertSplitAtFractionExhaustive( + splits[0].source, splits[0].start_position, splits[0].stop_position) + + def test_dynamic_work_rebalancing_windows_eol(self): + file_name, expected_data = write_data(15, eol=EOL.CRLF) + assert len(expected_data) == 15 + source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True, + coders.StrUtf8Coder()) + splits = [split for split in source.split(desired_bundle_size=100000)] + assert len(splits) == 1 + source_test_utils.assertSplitAtFractionExhaustive( + splits[0].source, splits[0].start_position, splits[0].stop_position, + perform_multi_threaded_test=False) + + def test_dynamic_work_rebalancing_mixed_eol(self): + file_name, expected_data = write_data(15, eol=EOL.MIXED) + assert len(expected_data) == 15 + source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True, + coders.StrUtf8Coder()) + splits = [split for split in source.split(desired_bundle_size=100000)] + assert len(splits) == 1 + source_test_utils.assertSplitAtFractionExhaustive( + splits[0].source, splits[0].start_position, splits[0].stop_position, + perform_multi_threaded_test=False) + + def test_dataflow_single_file(self): + file_name, expected_data = write_data(5) + assert len(expected_data) == 5 + pipeline = beam.Pipeline('DirectPipelineRunner') + pcoll = pipeline | 'Read' >> ReadFromText(file_name) + assert_that(pcoll, equal_to(expected_data)) + pipeline.run() + + def test_dataflow_single_file_with_coder(self): + class DummyCoder(coders.Coder): + def encode(self, x): + raise ValueError + + def decode(self, x): + return x * 2 + + file_name, expected_data = write_data(5) + assert len(expected_data) == 5 + pipeline = beam.Pipeline('DirectPipelineRunner') + pcoll = pipeline | 'Read' >> ReadFromText(file_name, coder=DummyCoder()) + assert_that(pcoll, equal_to([record * 2 for record in expected_data])) + pipeline.run() + + def test_dataflow_file_pattern(self): + pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4]) + assert len(expected_data) == 40 + pipeline = beam.Pipeline('DirectPipelineRunner') + pcoll = pipeline | 'Read' >> ReadFromText(pattern) + assert_that(pcoll, equal_to(expected_data)) + pipeline.run() + + def test_read_gzip(self): + _, lines = write_data(15) + file_name = tempfile.NamedTemporaryFile( + delete=False, prefix=tempfile.template).name + with gzip.GzipFile(file_name, 'wb') as f: + f.write('\n'.join(lines)) + + pipeline = beam.Pipeline('DirectPipelineRunner') + pcoll = pipeline | 'Read' >> ReadFromText( + file_name, + 0, CompressionTypes.GZIP, + True, coders.StrUtf8Coder()) + assert_that(pcoll, equal_to(lines)) + pipeline.run() + + def test_read_gzip_large(self): + _, lines = write_data(10000) + file_name = tempfile.NamedTemporaryFile( + delete=False, prefix=tempfile.template).name + with gzip.GzipFile(file_name, 'wb') as f: + f.write('\n'.join(lines)) + + pipeline = beam.Pipeline('DirectPipelineRunner') + pcoll = pipeline | 'Read' >> ReadFromText( + file_name, + 0, CompressionTypes.GZIP, + True, coders.StrUtf8Coder()) + assert_that(pcoll, equal_to(lines)) + pipeline.run() + + def test_read_gzip_large_after_splitting(self): + _, lines = write_data(10000) + file_name = tempfile.NamedTemporaryFile( + delete=False, prefix=tempfile.template).name + with gzip.GzipFile(file_name, 'wb') as f: + f.write('\n'.join(lines)) + + source = TextSource(file_name, 0, CompressionTypes.GZIP, True, + 
coders.StrUtf8Coder()) + splits = [split for split in source.split(desired_bundle_size=1000)] + + if len(splits) > 1: + raise ValueError('FileBasedSource generated more than one initial split ' + 'for a compressed file.') + + reference_source_info = (source, None, None) + sources_info = ([ + (split.source, split.start_position, split.stop_position) for + split in splits]) + source_test_utils.assertSourcesEqualReferenceSource( + reference_source_info, sources_info) + + def test_read_gzip_empty_file(self): + filename = tempfile.NamedTemporaryFile( + delete=False, prefix=tempfile.template).name + pipeline = beam.Pipeline('DirectPipelineRunner') + pcoll = pipeline | 'Read' >> ReadFromText( + filename, + 0, CompressionTypes.GZIP, + True, coders.StrUtf8Coder()) + assert_that(pcoll, equal_to([])) + pipeline.run() + + +class TextSinkTest(unittest.TestCase): + + def setUp(self): + self.lines = ['Line %d' % d for d in range(100)] + self.path = tempfile.NamedTemporaryFile().name + + def _write_lines(self, sink, lines): + f = sink.open(self.path) + for line in lines: + sink.write_record(f, line) + sink.close(f) + + def test_write_text_file(self): + sink = TextSink(self.path) + self._write_lines(sink, self.lines) + + with open(self.path, 'r') as f: + self.assertEqual(f.read().splitlines(), self.lines) + + def test_write_text_file_empty(self): + sink = TextSink(self.path) + self._write_lines(sink, []) + + with open(self.path, 'r') as f: + self.assertEqual(f.read().splitlines(), []) + + def test_write_gzip_file(self): + sink = TextSink( + self.path, compression_type=CompressionTypes.GZIP) + self._write_lines(sink, self.lines) + + with gzip.GzipFile(self.path, 'r') as f: + self.assertEqual(f.read().splitlines(), self.lines) + + def test_write_gzip_file_auto(self): + self.path = tempfile.NamedTemporaryFile(suffix='.gz').name + sink = TextSink(self.path) + self._write_lines(sink, self.lines) + + with gzip.GzipFile(self.path, 'r') as f: + self.assertEqual(f.read().splitlines(), self.lines) + + def test_write_gzip_file_empty(self): + sink = TextSink( + self.path, compression_type=CompressionTypes.GZIP) + self._write_lines(sink, []) + + with gzip.GzipFile(self.path, 'r') as f: + self.assertEqual(f.read().splitlines(), []) + + def test_write_zlib_file(self): + sink = TextSink( + self.path, compression_type=CompressionTypes.ZLIB) + self._write_lines(sink, self.lines) + + with open(self.path, 'r') as f: + content = f.read() + self.assertEqual( + zlib.decompress(content, zlib.MAX_WBITS).splitlines(), self.lines) + + def test_write_zlib_file_auto(self): + self.path = tempfile.NamedTemporaryFile(suffix='.Z').name + sink = TextSink(self.path) + self._write_lines(sink, self.lines) + + with open(self.path, 'r') as f: + content = f.read() + self.assertEqual( + zlib.decompress(content, zlib.MAX_WBITS).splitlines(), self.lines) + + def test_write_zlib_file_empty(self): + sink = TextSink( + self.path, compression_type=CompressionTypes.ZLIB) + self._write_lines(sink, []) + + with open(self.path, 'r') as f: + content = f.read() + self.assertEqual( + zlib.decompress(content, zlib.MAX_WBITS).splitlines(), []) + + def test_write_dataflow(self): + pipeline = beam.Pipeline('DirectPipelineRunner') + pcoll = pipeline | beam.core.Create('Create', self.lines) + pcoll | 'Write' >> WriteToText(self.path) # pylint: disable=expression-not-assigned + pipeline.run() + + read_result = [] + for file_name in glob.glob(self.path + '*'): + with open(file_name, 'r') as f: + read_result.extend(f.read().splitlines()) + + 
self.assertEqual(read_result, self.lines) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2d1e7ff6/sdks/python/apache_beam/runners/inprocess/inprocess_runner_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/inprocess/inprocess_runner_test.py b/sdks/python/apache_beam/runners/inprocess/inprocess_runner_test.py index 3ab8383..aa9db24 100644 --- a/sdks/python/apache_beam/runners/inprocess/inprocess_runner_test.py +++ b/sdks/python/apache_beam/runners/inprocess/inprocess_runner_test.py @@ -23,6 +23,7 @@ import unittest from apache_beam import Pipeline import apache_beam.examples.snippets.snippets_test as snippets_test import apache_beam.io.fileio_test as fileio_test +import apache_beam.io.textio_test as textio_test import apache_beam.io.sources_test as sources_test import apache_beam.pipeline_test as pipeline_test import apache_beam.pvalue_test as pvalue_test @@ -96,11 +97,11 @@ class TestNativeTextFileSink( class TestTextFileSink( - TestWithInProcessPipelineRunner, fileio_test.TestTextFileSink): + TestWithInProcessPipelineRunner, textio_test.TextSinkTest): def setUp(self): TestWithInProcessPipelineRunner.setUp(self) - fileio_test.TestTextFileSink.setUp(self) + textio_test.TextSinkTest.setUp(self) class MyFileSink(TestWithInProcessPipelineRunner, fileio_test.MyFileSink):
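
For context, the tests in textio_test.py above exercise the new transforms roughly as in the sketch below. This is only an illustrative usage sketch, not part of the patch: the file paths and the intermediate beam.Map step are invented for illustration, while ReadFromText, WriteToText, the 'label' >> application syntax, and the DirectPipelineRunner are taken directly from the test code in this commit.

import apache_beam as beam
from apache_beam.io.textio import ReadFromText
from apache_beam.io.textio import WriteToText

# Read newline-delimited text (glob patterns such as '*' are supported),
# apply a per-element transform, and write the result back out as text.
pipeline = beam.Pipeline('DirectPipelineRunner')
lines = pipeline | 'Read' >> ReadFromText('/tmp/input*.txt')     # hypothetical input path
upper = lines | 'Upper' >> beam.Map(lambda line: line.upper())   # illustrative transform, not in this patch
upper | 'Write' >> WriteToText('/tmp/output')                    # default sharding yields e.g. output-00000-of-00001
pipeline.run()

With the defaults documented in textio.py above, ReadFromText strips trailing newlines and decodes each line with StrUtf8Coder, and WriteToText appends a newline after each element; these behaviors can be adjusted through the strip_trailing_newlines, coder, and append_trailing_newlines arguments.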