Repository: beam
Updated Branches:
  refs/heads/master 4e01fc1ac -> 9088a3e39
Adds two new Read PTransforms that can be used to read a massive number of
files. textio.ReadAllFromText is for reading a PCollection of text files/file
patterns. avroio.ReadAllFromAvro is for reading a PCollection of Avro
files/file patterns.

Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/5e998532
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/5e998532
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/5e998532

Branch: refs/heads/master
Commit: 5e99853225baff818a7c23020b33ff25b28b23a2
Parents: 4e01fc1
Author: [email protected] <[email protected]>
Authored: Fri Jul 28 19:39:02 2017 -0700
Committer: [email protected] <[email protected]>
Committed: Thu Aug 10 13:38:18 2017 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/io/avroio.py          | 103 ++++++++----
 sdks/python/apache_beam/io/avroio_test.py     |  33 +++-
 sdks/python/apache_beam/io/filebasedsource.py | 165 ++++++++++++++++---
 sdks/python/apache_beam/io/range_trackers.py  |  42 +++++
 .../apache_beam/io/range_trackers_test.py     |  37 +++++
 sdks/python/apache_beam/io/textio.py          |  82 ++++++++-
 sdks/python/apache_beam/io/textio_test.py     |  95 ++++++++++-
 7 files changed, 495 insertions(+), 62 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/5e998532/sdks/python/apache_beam/io/avroio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/avroio.py b/sdks/python/apache_beam/io/avroio.py
index 7df9983..47ea282 100644
--- a/sdks/python/apache_beam/io/avroio.py
+++ b/sdks/python/apache_beam/io/avroio.py
@@ -14,11 +14,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-"""Implements a source for reading Avro files."""
+"""``PTransforms`` for reading from and writing to Avro files.
+
+Provides two read ``PTransform``s, ``ReadFromAvro`` and ``ReadAllFromAvro``,
+that produce a ``PCollection`` of records.
+Each record of this ``PCollection`` will contain a single record read from
+an Avro file. Records that are of simple types will be mapped into
+corresponding Python types. Records that are of Avro type 'RECORD' will be
+mapped to Python dictionaries that comply with the schema contained in the
+Avro file that contains those records. In this case, keys of each dictionary
+will contain the corresponding field names and will be of type ``string``
+while the values of the dictionary will be of the type defined in the
+corresponding Avro schema.
+
+For example, if the schema of the Avro file is the following.
+{"namespace": "example.avro","type": "record","name": "User","fields":
+[{"name": "name", "type": "string"},
+{"name": "favorite_number", "type": ["int", "null"]},
+{"name": "favorite_color", "type": ["string", "null"]}]}
+
+Then records generated by read transforms will be dictionaries of the
+following form.
+{u'name': u'Alyssa', u'favorite_number': 256, u'favorite_color': None}.
+
+Additionally, this module provides a write ``PTransform`` ``WriteToAvro``
+that can be used to write a given ``PCollection`` of Python objects to an
+Avro file.
+"""
 import cStringIO
 import os
 import zlib
+from functools import partial
 
 import avro
 from avro import datafile
@@ -33,40 +60,25 @@ from apache_beam.io.filesystem import CompressionTypes
 from apache_beam.io.iobase import Read
 from apache_beam.transforms import PTransform
 
-__all__ = ['ReadFromAvro', 'WriteToAvro']
+__all__ = ['ReadFromAvro', 'ReadAllFromAvro', 'WriteToAvro']
 
 
 class ReadFromAvro(PTransform):
-  """A ``PTransform`` for reading avro files."""
+  """A ``PTransform`` for reading Avro files.
+
+  Uses source '_AvroSource' to read a set of Avro files defined by a given
+  file pattern.
+  If '/mypath/myavrofiles*' is a file-pattern that points to a set of Avro
+  files, a ``PCollection`` for the records in these Avro files can be created
+  in the following manner.
+
+    p = df.Pipeline(argv=pipeline_args)
+    records = p | 'Read' >> df.io.ReadFromAvro('/mypath/myavrofiles*')
+  """
 
   def __init__(self, file_pattern=None, min_bundle_size=0, validate=True):
     """Initializes ``ReadFromAvro``.
 
-    Uses source '_AvroSource' to read a set of Avro files defined by a given
-    file pattern.
-    If '/mypath/myavrofiles*' is a file-pattern that points to a set of Avro
-    files, a ``PCollection`` for the records in these Avro files can be created
-    in the following manner.
-      p = df.Pipeline(argv=pipeline_args)
-      records = p | 'Read' >> df.io.ReadFromAvro('/mypath/myavrofiles*')
-
-    Each record of this ``PCollection`` will contain a single record read from a
-    source. Records that are of simple types will be mapped into corresponding
-    Python types. Records that are of Avro type 'RECORD' will be mapped to
-    Python dictionaries that comply with the schema contained in the Avro file
-    that contains those records. In this case, keys of each dictionary
-    will contain the corresponding field names and will be of type ``string``
-    while the values of the dictionary will be of the type defined in the
-    corresponding Avro schema.
-    For example, if schema of the Avro file is the following.
-    {"namespace": "example.avro","type": "record","name": "User","fields":
-    [{"name": "name", "type": "string"},
-    {"name": "favorite_number", "type": ["int", "null"]},
-    {"name": "favorite_color", "type": ["string", "null"]}]}
-    Then records generated by ``AvroSource`` will be dictionaries of the
-    following form.
-    {u'name': u'Alyssa', u'favorite_number': 256, u'favorite_color': None}).
-
     Args:
       file_pattern: the set of files to be read.
       min_bundle_size: the minimum size in bytes, to be considered when
@@ -84,6 +96,35 @@ class ReadFromAvro(PTransform):
     return {'source_dd': self._source}
 
 
+class ReadAllFromAvro(PTransform):
+  """A ``PTransform`` for reading a ``PCollection`` of Avro files.
+
+  Uses source '_AvroSource' to read a ``PCollection`` of Avro files or
+  file patterns and produce a ``PCollection`` of Avro records.
+  """
+
+  DEFAULT_DESIRED_BUNDLE_SIZE = 64 * 1024 * 1024  # 64MB
+
+  def __init__(self, min_bundle_size=0,
+               desired_bundle_size=DEFAULT_DESIRED_BUNDLE_SIZE):
+    """Initializes ``ReadAllFromAvro``.
+
+    Args:
+      min_bundle_size: the minimum size in bytes, to be considered when
+                       splitting the input into bundles.
+      desired_bundle_size: the desired size in bytes, to be considered when
+                           splitting the input into bundles.
+ """ + source_from_file = partial( + _create_avro_source, min_bundle_size=min_bundle_size) + self._read_all_files = filebasedsource.ReadAllFiles( + True, CompressionTypes.AUTO, desired_bundle_size, min_bundle_size, + source_from_file) + + def expand(self, pvalue): + return pvalue | 'ReadAllFiles' >> self._read_all_files + + class _AvroUtils(object): @staticmethod @@ -176,6 +217,12 @@ class _AvroUtils(object): data = f.read(buf_size) +def _create_avro_source(file_pattern=None, min_bundle_size=None): + return _AvroSource( + file_pattern=file_pattern, min_bundle_size=min_bundle_size, + validate=False) + + class _AvroBlock(object): """Represents a block of an Avro file.""" http://git-wip-us.apache.org/repos/asf/beam/blob/5e998532/sdks/python/apache_beam/io/avroio_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py index 6dcf121..969f440 100644 --- a/sdks/python/apache_beam/io/avroio_test.py +++ b/sdks/python/apache_beam/io/avroio_test.py @@ -22,6 +22,7 @@ import tempfile import unittest import apache_beam as beam +from apache_beam import Create from apache_beam.io import iobase from apache_beam.io import avroio from apache_beam.io import filebasedsource @@ -346,11 +347,41 @@ class TestAvro(unittest.TestCase): source_test_utils.read_from_source(source, None, None) self.assertEqual(0, exn.exception.message.find('Unexpected sync marker')) - def test_source_transform(self): + def test_read_from_avro(self): path = self._write_data() with TestPipeline() as p: assert_that(p | avroio.ReadFromAvro(path), equal_to(self.RECORDS)) + def test_read_all_from_avro_single_file(self): + path = self._write_data() + with TestPipeline() as p: + assert_that(p | Create([path]) | avroio.ReadAllFromAvro(), + equal_to(self.RECORDS)) + + def test_read_all_from_avro_many_single_files(self): + path1 = self._write_data() + path2 = self._write_data() + path3 = self._write_data() + with TestPipeline() as p: + assert_that(p | Create([path1, path2, path3]) | avroio.ReadAllFromAvro(), + equal_to(self.RECORDS * 3)) + + def test_read_all_from_avro_file_pattern(self): + file_pattern = self._write_pattern(5) + with TestPipeline() as p: + assert_that(p | Create([file_pattern]) | avroio.ReadAllFromAvro(), + equal_to(self.RECORDS * 5)) + + def test_read_all_from_avro_many_file_patterns(self): + file_pattern1 = self._write_pattern(5) + file_pattern2 = self._write_pattern(2) + file_pattern3 = self._write_pattern(3) + with TestPipeline() as p: + assert_that(p + | Create([file_pattern1, file_pattern2, file_pattern3]) + | avroio.ReadAllFromAvro(), + equal_to(self.RECORDS * 10)) + def test_sink_transform(self): with tempfile.NamedTemporaryFile() as dst: path = dst.name http://git-wip-us.apache.org/repos/asf/beam/blob/5e998532/sdks/python/apache_beam/io/filebasedsource.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py index bb9efc4..f78bf3f 100644 --- a/sdks/python/apache_beam/io/filebasedsource.py +++ b/sdks/python/apache_beam/io/filebasedsource.py @@ -24,17 +24,26 @@ for more details. For an example implementation of ``FileBasedSource`` see ``avroio.AvroSource``. 
""" - +import uuid + +from apache_beam.transforms.core import DoFn +from apache_beam.transforms.core import ParDo +from apache_beam.transforms.core import GroupByKey +from apache_beam.transforms.core import PTransform +from apache_beam.transforms.core import FlatMap +from apache_beam.transforms.core import Map from apache_beam.internal import pickler from apache_beam.io import concat_source from apache_beam.io import iobase from apache_beam.io import range_trackers from apache_beam.io.filesystem import CompressionTypes from apache_beam.io.filesystems import FileSystems +from apache_beam.io.range_trackers import OffsetRange from apache_beam.transforms.display import DisplayDataItem from apache_beam.options.value_provider import ValueProvider from apache_beam.options.value_provider import StaticValueProvider from apache_beam.options.value_provider import check_accessible +from apache_beam.transforms.trigger import DefaultTrigger MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 25 @@ -95,12 +104,7 @@ class FileBasedSource(iobase.BoundedSource): raise TypeError('compression_type must be CompressionType object but ' 'was %s' % type(compression_type)) self._compression_type = compression_type - if compression_type in (CompressionTypes.UNCOMPRESSED, - CompressionTypes.AUTO): - self._splittable = splittable - else: - # We can't split compressed files efficiently so turn off splitting. - self._splittable = False + self._splittable = splittable if validate and file_pattern.is_accessible(): self._validate() @@ -132,13 +136,10 @@ class FileBasedSource(iobase.BoundedSource): continue # Ignoring empty file. # We determine splittability of this specific file. - splittable = self.splittable - if (splittable and - self._compression_type == CompressionTypes.AUTO): - compression_type = CompressionTypes.detect_compression_type( - file_name) - if compression_type != CompressionTypes.UNCOMPRESSED: - splittable = False + splittable = ( + self.splittable and + _determine_splittability_from_compression_type( + file_name, self._compression_type)) single_file_source = _SingleFileSource( file_based_source_ref, file_name, @@ -211,6 +212,14 @@ class FileBasedSource(iobase.BoundedSource): return self._splittable +def _determine_splittability_from_compression_type( + file_path, compression_type): + if compression_type == CompressionTypes.AUTO: + compression_type = CompressionTypes.detect_compression_type(file_path) + + return compression_type == CompressionTypes.UNCOMPRESSED + + class _SingleFileSource(iobase.BoundedSource): """Denotes a source for a specific file type.""" @@ -244,24 +253,21 @@ class _SingleFileSource(iobase.BoundedSource): stop_offset = self._stop_offset if self._splittable: - bundle_size = max(desired_bundle_size, self._min_bundle_size) - - bundle_start = start_offset - while bundle_start < stop_offset: - bundle_stop = min(bundle_start + bundle_size, stop_offset) + splits = OffsetRange(start_offset, stop_offset).split( + desired_bundle_size, self._min_bundle_size) + for split in splits: yield iobase.SourceBundle( - bundle_stop - bundle_start, + split.stop - split.start, _SingleFileSource( # Copying this so that each sub-source gets a fresh instance. 
                 pickler.loads(pickler.dumps(self._file_based_source)),
                 self._file_name,
-                bundle_start,
-                bundle_stop,
+                split.start,
+                split.stop,
                 min_bundle_size=self._min_bundle_size,
                 splittable=self._splittable),
-            bundle_start,
-            bundle_stop)
-        bundle_start = bundle_stop
+            split.start,
+            split.stop)
     else:
       # Returning a single sub-source with end offset set to OFFSET_INFINITY (so
       # that all data of the source gets read) since this source is
@@ -308,3 +314,112 @@ class _SingleFileSource(iobase.BoundedSource):
 
   def default_output_coder(self):
     return self._file_based_source.default_output_coder()
+
+
+class _ExpandIntoRanges(DoFn):
+
+  def __init__(
+      self, splittable, compression_type, desired_bundle_size, min_bundle_size):
+    self._desired_bundle_size = desired_bundle_size
+    self._min_bundle_size = min_bundle_size
+    self._splittable = splittable
+    self._compression_type = compression_type
+
+  def process(self, element, *args, **kwargs):
+    match_results = FileSystems.match([element])
+    for metadata in match_results[0].metadata_list:
+      splittable = (
+          self._splittable and
+          _determine_splittability_from_compression_type(
+              metadata.path, self._compression_type))
+
+      if splittable:
+        for split in OffsetRange(
+            0, metadata.size_in_bytes).split(
+                self._desired_bundle_size, self._min_bundle_size):
+          yield (metadata, split)
+      else:
+        yield (metadata, OffsetRange(
+            0, range_trackers.OffsetRangeTracker.OFFSET_INFINITY))
+
+
+# Replace following with a generic reshard transform once
+# https://issues.apache.org/jira/browse/BEAM-1872 is implemented.
+class _Reshard(PTransform):
+
+  def expand(self, pvalue):
+    keyed_pc = (pvalue
+                | 'AssignKey' >> Map(lambda x: (uuid.uuid4(), x)))
+    if keyed_pc.windowing.windowfn.is_merging():
+      raise ValueError('Transform ReadAllFiles cannot be used in the presence '
+                       'of merging windows')
+    if not isinstance(keyed_pc.windowing.triggerfn, DefaultTrigger):
+      raise ValueError('Transform ReadAllFiles cannot be used in the presence '
+                       'of non-trivial triggers')
+
+    return (keyed_pc | 'GroupByKey' >> GroupByKey()
+            # Using FlatMap below due to the possibility of key collisions.
+            | 'DropKey' >> FlatMap(lambda (k, values): values))
+
+
+class _ReadRange(DoFn):
+
+  def __init__(self, source_from_file):
+    self._source_from_file = source_from_file
+
+  def process(self, element, *args, **kwargs):
+    metadata, range = element
+    source = self._source_from_file(metadata.path)
+    # The following split() operation has to be performed to create a proper
+    # _SingleFileSource. Otherwise what we have is a ConcatSource that contains
+    # a single _SingleFileSource. ConcatSource.read() expects a RangeTracker
+    # for sub-source range and reads full sub-sources (not byte ranges).
+    source = list(source.split(float('inf')))[0].source
+    for record in source.read(range.new_tracker()):
+      yield record
+
+
+class ReadAllFiles(PTransform):
+  """A Read transform that reads a PCollection of files.
+
+  Pipeline authors should not use this directly. This is to be used by Read
+  PTransform authors who wish to implement file-based Read transforms that
+  read a PCollection of files.
+  """
+
+  def __init__(
+      self, splittable, compression_type, desired_bundle_size, min_bundle_size,
+      source_from_file):
+    """
+    Args:
+      splittable: If False, files won't be split into sub-ranges. If True,
+                  files may or may not be split into data ranges.
+      compression_type: A ``CompressionTypes`` object that specifies the
+                    compression type of the files that will be processed. If
+                    ``CompressionTypes.AUTO``, the system will try to
+                    automatically determine the compression type based on
+                    the extension of files.
+      desired_bundle_size: the desired size of data ranges that should be
+                           generated when splitting a file into data ranges.
+      min_bundle_size: minimum size of data ranges that should be generated
+                       when splitting a file into data ranges.
+      source_from_file: a function that produces a ``BoundedSource`` given a
+                        file name. The system will use this function to
+                        generate ``BoundedSource`` objects for file paths.
+                        Note that file paths passed to this will be for
+                        individual files, not for file patterns even if the
+                        ``PCollection`` of files processed by the transform
+                        consists of file patterns.
+    """
+    self._splittable = splittable
+    self._compression_type = compression_type
+    self._desired_bundle_size = desired_bundle_size
+    self._min_bundle_size = min_bundle_size
+    self._source_from_file = source_from_file
+
+  def expand(self, pvalue):
+    return (pvalue
+            | 'ExpandIntoRanges' >> ParDo(_ExpandIntoRanges(
+                self._splittable, self._compression_type,
+                self._desired_bundle_size, self._min_bundle_size))
+            | 'Reshard' >> _Reshard()
+            | 'ReadRange' >> ParDo(_ReadRange(self._source_from_file)))

http://git-wip-us.apache.org/repos/asf/beam/blob/5e998532/sdks/python/apache_beam/io/range_trackers.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/range_trackers.py b/sdks/python/apache_beam/io/range_trackers.py
index bef77d4..4bd19f8 100644
--- a/sdks/python/apache_beam/io/range_trackers.py
+++ b/sdks/python/apache_beam/io/range_trackers.py
@@ -28,6 +28,48 @@ __all__ = ['OffsetRangeTracker', 'LexicographicKeyRangeTracker',
            'OrderedPositionRangeTracker', 'UnsplittableRangeTracker']
 
 
+class OffsetRange(object):
+
+  def __init__(self, start, stop):
+    if start >= stop:
+      raise ValueError(
+          'Start offset must be smaller than the stop offset. '
+          'Received %d and %d respectively.' % (start, stop))
+    self.start = start
+    self.stop = stop
+
+  def __eq__(self, other):
+    if not isinstance(other, OffsetRange):
+      return False
+
+    return self.start == other.start and self.stop == other.stop
+
+  def __ne__(self, other):
+    if not isinstance(other, OffsetRange):
+      return True
+
+    return not (self.start == other.start and self.stop == other.stop)
+
+  def split(self, desired_num_offsets_per_split, min_num_offsets_per_split=1):
+    current_split_start = self.start
+    max_split_size = max(desired_num_offsets_per_split,
+                         min_num_offsets_per_split)
+    while current_split_start < self.stop:
+      current_split_stop = min(current_split_start + max_split_size, self.stop)
+      remaining = self.stop - current_split_stop
+
+      # Avoiding a small split at the end.
+      if (remaining < desired_num_offsets_per_split / 4 or
+          remaining < min_num_offsets_per_split):
+        current_split_stop = self.stop
+
+      yield OffsetRange(current_split_start, current_split_stop)
+      current_split_start = current_split_stop
+
+  def new_tracker(self):
+    return OffsetRangeTracker(self.start, self.stop)
+
+
 class OffsetRangeTracker(iobase.RangeTracker):
   """A 'RangeTracker' for non-negative positions of type 'long'."""

http://git-wip-us.apache.org/repos/asf/beam/blob/5e998532/sdks/python/apache_beam/io/range_trackers_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/range_trackers_test.py b/sdks/python/apache_beam/io/range_trackers_test.py
index 3e92663..762d654 100644
--- a/sdks/python/apache_beam/io/range_trackers_test.py
+++ b/sdks/python/apache_beam/io/range_trackers_test.py
@@ -23,6 +23,43 @@ import math
 import unittest
 
 from apache_beam.io import range_trackers
+from apache_beam.io.range_trackers import OffsetRange
+
+
+class OffsetRangeTest(unittest.TestCase):
+
+  def test_create(self):
+    OffsetRange(0, 10)
+    OffsetRange(10, 100)
+
+    with self.assertRaises(ValueError):
+      OffsetRange(10, 9)
+
+  def test_split_respects_desired_num_splits(self):
+    range = OffsetRange(10, 100)
+    splits = list(range.split(desired_num_offsets_per_split=25))
+    self.assertEqual(4, len(splits))
+    self.assertIn(OffsetRange(10, 35), splits)
+    self.assertIn(OffsetRange(35, 60), splits)
+    self.assertIn(OffsetRange(60, 85), splits)
+    self.assertIn(OffsetRange(85, 100), splits)
+
+  def test_split_respects_min_num_splits(self):
+    range = OffsetRange(10, 100)
+    splits = list(range.split(desired_num_offsets_per_split=5,
+                              min_num_offsets_per_split=25))
+    self.assertEqual(3, len(splits))
+    self.assertIn(OffsetRange(10, 35), splits)
+    self.assertIn(OffsetRange(35, 60), splits)
+    self.assertIn(OffsetRange(60, 100), splits)
+
+  def test_split_no_small_split_at_end(self):
+    range = OffsetRange(10, 90)
+    splits = list(range.split(desired_num_offsets_per_split=25))
+    self.assertEqual(3, len(splits))
+    self.assertIn(OffsetRange(10, 35), splits)
+    self.assertIn(OffsetRange(35, 60), splits)
+    self.assertIn(OffsetRange(60, 90), splits)
 
 
 class OffsetRangeTrackerTest(unittest.TestCase):

http://git-wip-us.apache.org/repos/asf/beam/blob/5e998532/sdks/python/apache_beam/io/textio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py
index 60e1512..9c6532e 100644
--- a/sdks/python/apache_beam/io/textio.py
+++ b/sdks/python/apache_beam/io/textio.py
@@ -19,19 +19,21 @@
 
 from __future__ import absolute_import
 
+from functools import partial
 import logging
 
 from apache_beam.coders import coders
 from apache_beam.io import filebasedsource
 from apache_beam.io import filebasedsink
 from apache_beam.io import iobase
+from apache_beam.io.filebasedsource import ReadAllFiles
 from apache_beam.io.filesystem import CompressionTypes
 from apache_beam.io.iobase import Read
 from apache_beam.io.iobase import Write
 from apache_beam.transforms import PTransform
 from apache_beam.transforms.display import DisplayDataItem
 
-__all__ = ['ReadFromText', 'WriteToText']
+__all__ = ['ReadFromText', 'ReadAllFromText', 'WriteToText']
 
 
 class _TextSource(filebasedsource.FileBasedSource):
@@ -342,8 +344,80 @@ class _TextSink(filebasedsink.FileBasedSink):
       file_handle.write('\n')
 
 
+def _create_text_source(
+    file_pattern=None, min_bundle_size=None, compression_type=None,
+    strip_trailing_newlines=None, coder=None, skip_header_lines=None):
+  return _TextSource(
+      file_pattern=file_pattern, min_bundle_size=min_bundle_size,
+      compression_type=compression_type,
+      strip_trailing_newlines=strip_trailing_newlines,
+      coder=coder, validate=False, skip_header_lines=skip_header_lines)
+
+
+class ReadAllFromText(PTransform):
+  """A ``PTransform`` for reading a ``PCollection`` of text files.
+
+  Reads a ``PCollection`` of text files or file patterns and produces a
+  ``PCollection`` of strings.
+
+  Parses a text file as newline-delimited elements, by default assuming
+  UTF-8 encoding. Supports newline delimiters '\\n' and '\\r\\n'.
+
+  This implementation only supports reading text encoded using UTF-8 or ASCII.
+  This does not support other encodings such as UTF-16 or UTF-32.
+  """
+
+  DEFAULT_DESIRED_BUNDLE_SIZE = 64 * 1024 * 1024  # 64MB
+
+  def __init__(
+      self,
+      min_bundle_size=0,
+      desired_bundle_size=DEFAULT_DESIRED_BUNDLE_SIZE,
+      compression_type=CompressionTypes.AUTO,
+      strip_trailing_newlines=True,
+      coder=coders.StrUtf8Coder(),
+      skip_header_lines=0,
+      **kwargs):
+    """Initialize the ``ReadAllFromText`` transform.
+
+    Args:
+      min_bundle_size: Minimum size of bundles that should be generated when
+        splitting this source into bundles. See ``FileBasedSource`` for more
+        details.
+      desired_bundle_size: Desired size of bundles that should be generated when
+        splitting this source into bundles. See ``FileBasedSource`` for more
+        details.
+      compression_type: Used to handle compressed input files. Typical value
+        is ``CompressionTypes.AUTO``, in which case the underlying file_path's
+        extension will be used to detect the compression.
+      strip_trailing_newlines: Indicates whether this source should remove
+        the newline char in each line it reads before decoding that line.
+      skip_header_lines: Number of header lines to skip. Same number is skipped
+        from each source file. Must be 0 or higher. Large number of skipped
+        lines might impact performance.
+      coder: Coder used to decode each line.
+    """
+    super(ReadAllFromText, self).__init__(**kwargs)
+    source_from_file = partial(
+        _create_text_source, min_bundle_size=min_bundle_size,
+        compression_type=compression_type,
+        strip_trailing_newlines=strip_trailing_newlines, coder=coder,
+        skip_header_lines=skip_header_lines)
+    self._desired_bundle_size = desired_bundle_size
+    self._min_bundle_size = min_bundle_size
+    self._compression_type = compression_type
+    self._read_all_files = ReadAllFiles(
+        True, compression_type, desired_bundle_size, min_bundle_size,
+        source_from_file)
+
+  def expand(self, pvalue):
+    return pvalue | 'ReadAllFiles' >> self._read_all_files
+
+
 class ReadFromText(PTransform):
-  """A PTransform for reading text files.
+  """A ``PTransform`` for reading text files.
 
   Parses a text file as newline-delimited elements, by default assuming
   UTF-8 encoding. Supports newline delimiters '\\n' and '\\r\\n'.
@@ -361,7 +435,7 @@ class ReadFromText(PTransform):
       validate=True,
       skip_header_lines=0,
       **kwargs):
-    """Initialize the ReadFromText transform.
+    """Initialize the ``ReadFromText`` transform.
 
     Args:
      file_pattern: The file path to read from as a local file path or a GCS
@@ -371,7 +445,7 @@ class ReadFromText(PTransform):
        splitting this source into bundles. See ``FileBasedSource`` for more
        details.
      compression_type: Used to handle compressed input files. Typical value
Typical value - is CompressionTypes.AUTO, in which case the underlying file_path's + is ``CompressionTypes.AUTO``, in which case the underlying file_path's extension will be used to detect the compression. strip_trailing_newlines: Indicates whether this source should remove the newline char in each line it reads before decoding that line. http://git-wip-us.apache.org/repos/asf/beam/blob/5e998532/sdks/python/apache_beam/io/textio_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py index 8bd7116..b29ca5a 100644 --- a/sdks/python/apache_beam/io/textio_test.py +++ b/sdks/python/apache_beam/io/textio_test.py @@ -27,7 +27,7 @@ import tempfile import unittest import apache_beam as beam -from apache_beam.io import iobase +from apache_beam.io import iobase, ReadAllFromText import apache_beam.io.source_test_utils as source_test_utils # Importing following private classes for testing. @@ -47,6 +47,8 @@ from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +from apache_beam.transforms.core import Create + # TODO: Refactor code so all io tests are using same library # TestCaseWithTempDirCleanup class. @@ -334,7 +336,7 @@ class TextSourceTest(_TestCaseWithTempDirCleanUp): splits[0].source, splits[0].start_position, splits[0].stop_position, perform_multi_threaded_test=False) - def test_dataflow_single_file(self): + def test_read_from_text_single_file(self): file_name, expected_data = write_data(5) assert len(expected_data) == 5 pipeline = TestPipeline() @@ -342,7 +344,53 @@ class TextSourceTest(_TestCaseWithTempDirCleanUp): assert_that(pcoll, equal_to(expected_data)) pipeline.run() - def test_dataflow_single_file_with_coder(self): + def test_read_all_single_file(self): + file_name, expected_data = write_data(5) + assert len(expected_data) == 5 + pipeline = TestPipeline() + pcoll = pipeline | 'Create' >> Create( + [file_name]) |'ReadAll' >> ReadAllFromText() + assert_that(pcoll, equal_to(expected_data)) + pipeline.run() + + def test_read_all_many_single_files(self): + file_name1, expected_data1 = write_data(5) + assert len(expected_data1) == 5 + file_name2, expected_data2 = write_data(10) + assert len(expected_data2) == 10 + file_name3, expected_data3 = write_data(15) + assert len(expected_data3) == 15 + expected_data = [] + expected_data.extend(expected_data1) + expected_data.extend(expected_data2) + expected_data.extend(expected_data3) + pipeline = TestPipeline() + pcoll = pipeline | 'Create' >> Create( + [file_name1, file_name2, file_name3]) |'ReadAll' >> ReadAllFromText() + assert_that(pcoll, equal_to(expected_data)) + pipeline.run() + + def test_read_all_unavailable_files_ignored(self): + file_name1, expected_data1 = write_data(5) + assert len(expected_data1) == 5 + file_name2, expected_data2 = write_data(10) + assert len(expected_data2) == 10 + file_name3, expected_data3 = write_data(15) + assert len(expected_data3) == 15 + file_name4 = "/unavailable_file" + expected_data = [] + expected_data.extend(expected_data1) + expected_data.extend(expected_data2) + expected_data.extend(expected_data3) + pipeline = TestPipeline() + pcoll = (pipeline + | 'Create' >> Create( + [file_name1, file_name2, file_name3, file_name4]) + |'ReadAll' >> ReadAllFromText()) + assert_that(pcoll, equal_to(expected_data)) + pipeline.run() + + def test_read_from_text_single_file_with_coder(self): class 
       def encode(self, x):
         raise ValueError
@@ -357,7 +405,7 @@ class TextSourceTest(_TestCaseWithTempDirCleanUp):
     assert_that(pcoll, equal_to([record * 2 for record in expected_data]))
     pipeline.run()
 
-  def test_dataflow_file_pattern(self):
+  def test_read_from_text_file_pattern(self):
     pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4])
     assert len(expected_data) == 40
     pipeline = TestPipeline()
@@ -365,6 +413,33 @@
     assert_that(pcoll, equal_to(expected_data))
     pipeline.run()
 
+  def test_read_all_file_pattern(self):
+    pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4])
+    assert len(expected_data) == 40
+    pipeline = TestPipeline()
+    pcoll = (pipeline
+             | 'Create' >> Create([pattern])
+             | 'ReadAll' >> ReadAllFromText())
+    assert_that(pcoll, equal_to(expected_data))
+    pipeline.run()
+
+  def test_read_all_many_file_patterns(self):
+    pattern1, expected_data1 = write_pattern([5, 3, 12, 8, 8, 4])
+    assert len(expected_data1) == 40
+    pattern2, expected_data2 = write_pattern([3, 7, 9])
+    assert len(expected_data2) == 19
+    pattern3, expected_data3 = write_pattern([11, 20, 5, 5])
+    assert len(expected_data3) == 41
+    expected_data = []
+    expected_data.extend(expected_data1)
+    expected_data.extend(expected_data2)
+    expected_data.extend(expected_data3)
+    pipeline = TestPipeline()
+    pcoll = pipeline | 'Create' >> Create(
+        [pattern1, pattern2, pattern3]) | 'ReadAll' >> ReadAllFromText()
+    assert_that(pcoll, equal_to(expected_data))
+    pipeline.run()
+
   def test_read_auto_bzip2(self):
     _, lines = write_data(15)
     file_name = self._create_temp_file(suffix='.bz2')
@@ -528,6 +603,18 @@ class TextSourceTest(_TestCaseWithTempDirCleanUp):
     expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
     assert_that(lines, equal_to(expected))
+
+  def test_read_all_gzip(self):
+    _, lines = write_data(100)
+    file_name = self._create_temp_file()
+    with gzip.GzipFile(file_name, 'wb') as f:
+      f.write('\n'.join(lines))
+    pipeline = TestPipeline()
+    pcoll = (pipeline
+             | Create([file_name])
+             | 'ReadAll' >> ReadAllFromText(
+                 compression_type=CompressionTypes.GZIP))
+    assert_that(pcoll, equal_to(lines))
     pipeline.run()
 
   def test_read_gzip_large(self):
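
For reference, the new transforms consume a PCollection of file paths and/or
file patterns, as exercised by the tests above. A minimal usage sketch (the
local paths are hypothetical placeholders, not part of this commit):

import apache_beam as beam
from apache_beam import Create
from apache_beam.io.avroio import ReadAllFromAvro
from apache_beam.io.textio import ReadAllFromText

p = beam.Pipeline()

# Each element of the input PCollection is matched and expanded at execution
# time, so the full set of files does not need to be known when the pipeline
# is constructed.
lines = (p
         | 'TextPaths' >> Create(['/tmp/logs/part-*.txt'])
         | 'ReadText' >> ReadAllFromText())

records = (p
           | 'AvroPaths' >> Create(['/tmp/events/shard-*.avro'])
           | 'ReadAvro' >> ReadAllFromAvro())

p.run()

This is what distinguishes ReadAllFromText/ReadAllFromAvro from
ReadFromText/ReadFromAvro, which take a single file pattern as a constructor
argument rather than reading paths and patterns from an incoming PCollection.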

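The bundle boundaries come from the new OffsetRange.split(), whose
"no small split at the end" behavior is covered by the tests above. A quick
sketch of that rule (numbers chosen arbitrarily):

from apache_beam.io.range_trackers import OffsetRange

# Splitting [10, 90) with a desired size of 25: the tail split would be only
# 5 offsets long (85 to 90), less than a quarter of the desired size, so it
# is folded into the previous split rather than emitted as a tiny bundle.
splits = list(OffsetRange(10, 90).split(desired_num_offsets_per_split=25))
print [(s.start, s.stop) for s in splits]
# prints [(10, 35), (35, 60), (60, 90)]

ReadAllFiles applies the same splitting to each matched file in
_ExpandIntoRanges before the resulting ranges are resharded and read.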