Repository: beam
Updated Branches:
  refs/heads/master 4e01fc1ac -> 9088a3e39


Adds two new Read PTransforms that can be used to read a massive number of 
files.

textio.ReadAllFromText is for reading a PCollection of text files/file patterns.
avroio.ReadAllFromAvro is for reading a PCollection of Avro files/file patterns.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/5e998532
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/5e998532
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/5e998532

Branch: refs/heads/master
Commit: 5e99853225baff818a7c23020b33ff25b28b23a2
Parents: 4e01fc1
Author: chamik...@google.com <chamik...@google.com>
Authored: Fri Jul 28 19:39:02 2017 -0700
Committer: chamik...@google.com <chamik...@google.com>
Committed: Thu Aug 10 13:38:18 2017 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/io/avroio.py            | 103 ++++++++----
 sdks/python/apache_beam/io/avroio_test.py       |  33 +++-
 sdks/python/apache_beam/io/filebasedsource.py   | 165 ++++++++++++++++---
 sdks/python/apache_beam/io/range_trackers.py    |  42 +++++
 .../apache_beam/io/range_trackers_test.py       |  37 +++++
 sdks/python/apache_beam/io/textio.py            |  82 ++++++++-
 sdks/python/apache_beam/io/textio_test.py       |  95 ++++++++++-
 7 files changed, 495 insertions(+), 62 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/5e998532/sdks/python/apache_beam/io/avroio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/avroio.py b/sdks/python/apache_beam/io/avroio.py
index 7df9983..47ea282 100644
--- a/sdks/python/apache_beam/io/avroio.py
+++ b/sdks/python/apache_beam/io/avroio.py
@@ -14,11 +14,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-"""Implements a source for reading Avro files."""
+"""``PTransforms`` for reading from and writing to Avro files.
+
+Provides two read ``PTransform``s, ``ReadFromAvro`` and ``ReadAllFromAvro``,
+that produce a ``PCollection`` of records.
+Each record of this ``PCollection`` will contain a single record read from
+an Avro file. Records that are of simple types will be mapped into
+corresponding Python types. Records that are of Avro type 'RECORD' will be
+mapped to Python dictionaries that comply with the schema contained in the
+Avro file that contains those records. In this case, keys of each dictionary
+will contain the corresponding field names and will be of type ``string``
+while the values of the dictionary will be of the type defined in the
+corresponding Avro schema.
+
+For example, if the schema of the Avro file is the following:
+{"namespace": "example.avro","type": "record","name": "User","fields":
+[{"name": "name", "type": "string"},
+{"name": "favorite_number",  "type": ["int", "null"]},
+{"name": "favorite_color", "type": ["string", "null"]}]}
+
+Then records generated by read transforms will be dictionaries of the
+following form.
+{u'name': u'Alyssa', u'favorite_number': 256, u'favorite_color': None}
+
+Additionally, this module provides a write ``PTransform`` ``WriteToAvro``
+that can be used to write a given ``PCollection`` of Python objects to an
+Avro file.
+"""
 
 import cStringIO
 import os
 import zlib
+from functools import partial
 
 import avro
 from avro import datafile
@@ -33,40 +60,25 @@ from apache_beam.io.filesystem import CompressionTypes
 from apache_beam.io.iobase import Read
 from apache_beam.transforms import PTransform
 
-__all__ = ['ReadFromAvro', 'WriteToAvro']
+__all__ = ['ReadFromAvro', 'ReadAllFromAvro', 'WriteToAvro']
 
 
 class ReadFromAvro(PTransform):
-  """A ``PTransform`` for reading avro files."""
+  """A ``PTransform`` for reading Avro files.
+
+  Uses source '_AvroSource' to read a set of Avro files defined by a given
+  file pattern.
+  If '/mypath/myavrofiles*' is a file-pattern that points to a set of Avro
+  files, a ``PCollection`` for the records in these Avro files can be created
+  in the following manner.
+
+  p = df.Pipeline(argv=pipeline_args)
+  records = p | 'Read' >> df.io.ReadFromAvro('/mypath/myavrofiles*')
+  """
 
   def __init__(self, file_pattern=None, min_bundle_size=0, validate=True):
     """Initializes ``ReadFromAvro``.
 
-    Uses source '_AvroSource' to read a set of Avro files defined by a given
-    file pattern.
-    If '/mypath/myavrofiles*' is a file-pattern that points to a set of Avro
-    files, a ``PCollection`` for the records in these Avro files can be created
-    in the following manner.
-      p = df.Pipeline(argv=pipeline_args)
-      records = p | 'Read' >> df.io.ReadFromAvro('/mypath/myavrofiles*')
-
-    Each record of this ``PCollection`` will contain a single record read from a
-    source. Records that are of simple types will be mapped into corresponding
-    Python types. Records that are of Avro type 'RECORD' will be mapped to
-    Python dictionaries that comply with the schema contained in the Avro file
-    that contains those records. In this case, keys of each dictionary
-    will contain the corresponding field names and will be of type ``string``
-    while the values of the dictionary will be of the type defined in the
-    corresponding Avro schema.
-    For example, if schema of the Avro file is the following.
-      {"namespace": "example.avro","type": "record","name": "User","fields":
-      [{"name": "name", "type": "string"},
-       {"name": "favorite_number",  "type": ["int", "null"]},
-       {"name": "favorite_color", "type": ["string", "null"]}]}
-    Then records generated by ``AvroSource`` will be dictionaries of the
-    following form.
-      {u'name': u'Alyssa', u'favorite_number': 256, u'favorite_color': None}).
-
     Args:
       file_pattern: the set of files to be read.
       min_bundle_size: the minimum size in bytes, to be considered when
@@ -84,6 +96,35 @@ class ReadFromAvro(PTransform):
     return {'source_dd': self._source}
 
 
+class ReadAllFromAvro(PTransform):
+  """A ``PTransform`` for reading a ``PCollection`` of Avro files.
+
+  Uses source '_AvroSource' to read a ``PCollection`` of Avro files or
+  file patterns and produce a ``PCollection`` of Avro records.
+  """
+
+  DEFAULT_DESIRED_BUNDLE_SIZE = 64 * 1024 * 1024  # 64MB
+
+  def __init__(self, min_bundle_size=0,
+               desired_bundle_size=DEFAULT_DESIRED_BUNDLE_SIZE):
+    """Initializes ``ReadAllFromAvro``.
+
+    Args:
+      min_bundle_size: the minimum size in bytes, to be considered when
+                       splitting the input into bundles.
+      desired_bundle_size: the desired size in bytes, to be considered when
+                           splitting the input into bundles.
+    """
+    # Let the base PTransform run its own initialization, as ReadAllFromText
+    # does; skipping it leaves base-class state unset.
+    super(ReadAllFromAvro, self).__init__()
+    source_from_file = partial(
+        _create_avro_source, min_bundle_size=min_bundle_size)
+    # Files are requested as splittable; with CompressionTypes.AUTO, a file
+    # whose extension indicates compression is read as a single range.
+    self._read_all_files = filebasedsource.ReadAllFiles(
+        True, CompressionTypes.AUTO, desired_bundle_size, min_bundle_size,
+        source_from_file)
+
+  def expand(self, pvalue):
+    return pvalue | 'ReadAllFiles' >> self._read_all_files
+
+
 class _AvroUtils(object):
 
   @staticmethod
@@ -176,6 +217,12 @@ class _AvroUtils(object):
         data = f.read(buf_size)
 
 
+def _create_avro_source(file_pattern=None, min_bundle_size=None):
+  """Builds an ``_AvroSource`` for ``file_pattern`` with validation disabled.
+
+  ``ReadAllFromAvro`` binds ``min_bundle_size`` with ``functools.partial`` and
+  hands the result to ``ReadAllFiles`` as its ``source_from_file`` factory.
+  """
+  return _AvroSource(
+      file_pattern=file_pattern, min_bundle_size=min_bundle_size,
+      validate=False)
+
+
 class _AvroBlock(object):
   """Represents a block of an Avro file."""
 

http://git-wip-us.apache.org/repos/asf/beam/blob/5e998532/sdks/python/apache_beam/io/avroio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py
index 6dcf121..969f440 100644
--- a/sdks/python/apache_beam/io/avroio_test.py
+++ b/sdks/python/apache_beam/io/avroio_test.py
@@ -22,6 +22,7 @@ import tempfile
 import unittest
 
 import apache_beam as beam
+from apache_beam import Create
 from apache_beam.io import iobase
 from apache_beam.io import avroio
 from apache_beam.io import filebasedsource
@@ -346,11 +347,41 @@ class TestAvro(unittest.TestCase):
       source_test_utils.read_from_source(source, None, None)
       self.assertEqual(0, exn.exception.message.find('Unexpected sync marker'))
 
-  def test_source_transform(self):
+  def test_read_from_avro(self):
+    """ReadFromAvro reads back the records written by ``_write_data``."""
     path = self._write_data()
     with TestPipeline() as p:
       assert_that(p | avroio.ReadFromAvro(path), equal_to(self.RECORDS))
 
+  def test_read_all_from_avro_single_file(self):
+    """ReadAllFromAvro reads all records of a single file path."""
+    path = self._write_data()
+    with TestPipeline() as p:
+      assert_that(p | Create([path]) | avroio.ReadAllFromAvro(),
+                  equal_to(self.RECORDS))
+
+  def test_read_all_from_avro_many_single_files(self):
+    """ReadAllFromAvro reads and combines the records of several files."""
+    path1 = self._write_data()
+    path2 = self._write_data()
+    path3 = self._write_data()
+    with TestPipeline() as p:
+      assert_that(p | Create([path1, path2, path3]) | avroio.ReadAllFromAvro(),
+                  equal_to(self.RECORDS * 3))
+
+  def test_read_all_from_avro_file_pattern(self):
+    """ReadAllFromAvro expands a file pattern and reads every matched file."""
+    file_pattern = self._write_pattern(5)
+    with TestPipeline() as p:
+      assert_that(p | Create([file_pattern]) | avroio.ReadAllFromAvro(),
+                  equal_to(self.RECORDS * 5))
+
+  def test_read_all_from_avro_many_file_patterns(self):
+    """ReadAllFromAvro reads the files matched by several patterns (5+2+3)."""
+    file_pattern1 = self._write_pattern(5)
+    file_pattern2 = self._write_pattern(2)
+    file_pattern3 = self._write_pattern(3)
+    with TestPipeline() as p:
+      assert_that(p
+                  | Create([file_pattern1, file_pattern2, file_pattern3])
+                  | avroio.ReadAllFromAvro(),
+                  equal_to(self.RECORDS * 10))
+
   def test_sink_transform(self):
     with tempfile.NamedTemporaryFile() as dst:
       path = dst.name

http://git-wip-us.apache.org/repos/asf/beam/blob/5e998532/sdks/python/apache_beam/io/filebasedsource.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py
index bb9efc4..f78bf3f 100644
--- a/sdks/python/apache_beam/io/filebasedsource.py
+++ b/sdks/python/apache_beam/io/filebasedsource.py
@@ -24,17 +24,26 @@ for more details.
 
 For an example implementation of ``FileBasedSource`` see ``avroio.AvroSource``.
 """
-
+import uuid
+
+from apache_beam.transforms.core import DoFn
+from apache_beam.transforms.core import ParDo
+from apache_beam.transforms.core import GroupByKey
+from apache_beam.transforms.core import PTransform
+from apache_beam.transforms.core import FlatMap
+from apache_beam.transforms.core import Map
 from apache_beam.internal import pickler
 from apache_beam.io import concat_source
 from apache_beam.io import iobase
 from apache_beam.io import range_trackers
 from apache_beam.io.filesystem import CompressionTypes
 from apache_beam.io.filesystems import FileSystems
+from apache_beam.io.range_trackers import OffsetRange
 from apache_beam.transforms.display import DisplayDataItem
 from apache_beam.options.value_provider import ValueProvider
 from apache_beam.options.value_provider import StaticValueProvider
 from apache_beam.options.value_provider import check_accessible
+from apache_beam.transforms.trigger import DefaultTrigger
 
 MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 25
 
@@ -95,12 +104,7 @@ class FileBasedSource(iobase.BoundedSource):
       raise TypeError('compression_type must be CompressionType object but '
                       'was %s' % type(compression_type))
     self._compression_type = compression_type
-    if compression_type in (CompressionTypes.UNCOMPRESSED,
-                            CompressionTypes.AUTO):
-      self._splittable = splittable
-    else:
-      # We can't split compressed files efficiently so turn off splitting.
-      self._splittable = False
+    self._splittable = splittable
     if validate and file_pattern.is_accessible():
       self._validate()
 
@@ -132,13 +136,10 @@ class FileBasedSource(iobase.BoundedSource):
           continue  # Ignoring empty file.
 
         # We determine splittability of this specific file.
-        splittable = self.splittable
-        if (splittable and
-            self._compression_type == CompressionTypes.AUTO):
-          compression_type = CompressionTypes.detect_compression_type(
-              file_name)
-          if compression_type != CompressionTypes.UNCOMPRESSED:
-            splittable = False
+        splittable = (
+            self.splittable and
+            _determine_splittability_from_compression_type(
+                file_name, self._compression_type))
 
         single_file_source = _SingleFileSource(
             file_based_source_ref, file_name,
@@ -211,6 +212,14 @@ class FileBasedSource(iobase.BoundedSource):
     return self._splittable
 
 
+def _determine_splittability_from_compression_type(
+    file_path, compression_type):
+  """Returns True if a file with the given compression type can be split.
+
+  Only uncompressed files are splittable. When ``compression_type`` is
+  ``CompressionTypes.AUTO``, the effective type is first detected from
+  ``file_path`` with ``CompressionTypes.detect_compression_type``.
+  """
+  if compression_type == CompressionTypes.AUTO:
+    compression_type = CompressionTypes.detect_compression_type(file_path)
+
+  return compression_type == CompressionTypes.UNCOMPRESSED
+
+
 class _SingleFileSource(iobase.BoundedSource):
   """Denotes a source for a specific file type."""
 
@@ -244,24 +253,21 @@ class _SingleFileSource(iobase.BoundedSource):
       stop_offset = self._stop_offset
 
     if self._splittable:
-      bundle_size = max(desired_bundle_size, self._min_bundle_size)
-
-      bundle_start = start_offset
-      while bundle_start < stop_offset:
-        bundle_stop = min(bundle_start + bundle_size, stop_offset)
+      splits = OffsetRange(start_offset, stop_offset).split(
+          desired_bundle_size, self._min_bundle_size)
+      for split in splits:
         yield iobase.SourceBundle(
-            bundle_stop - bundle_start,
+            split.stop - split.start,
             _SingleFileSource(
                 # Copying this so that each sub-source gets a fresh instance.
                 pickler.loads(pickler.dumps(self._file_based_source)),
                 self._file_name,
-                bundle_start,
-                bundle_stop,
+                split.start,
+                split.stop,
                 min_bundle_size=self._min_bundle_size,
                 splittable=self._splittable),
-            bundle_start,
-            bundle_stop)
-        bundle_start = bundle_stop
+            split.start,
+            split.stop)
     else:
       # Returning a single sub-source with end offset set to OFFSET_INFINITY (so
       # that all data of the source gets read) since this source is
@@ -308,3 +314,112 @@ class _SingleFileSource(iobase.BoundedSource):
 
   def default_output_coder(self):
     return self._file_based_source.default_output_coder()
+
+
+class _ExpandIntoRanges(DoFn):
+  """Expands a file path or pattern into ``(metadata, OffsetRange)`` pairs.
+
+  Each file matched by the input element is emitted together with byte
+  ranges to read. Splittable files are cut into ranges sized according to
+  ``desired_bundle_size`` and ``min_bundle_size`` (see
+  ``OffsetRange.split``); other files get one range covering the whole file.
+  """
+
+  def __init__(self, splittable, compression_type, desired_bundle_size,
+               min_bundle_size):
+    self._desired_bundle_size = desired_bundle_size
+    self._min_bundle_size = min_bundle_size
+    self._splittable = splittable
+    self._compression_type = compression_type
+
+  def process(self, element, *args, **kwargs):
+    match_results = FileSystems.match([element])
+    for metadata in match_results[0].metadata_list:
+      if metadata.size_in_bytes == 0:
+        # Ignore empty files, as FileBasedSource does: there is nothing to
+        # read and no non-empty byte range to emit for them.
+        continue
+
+      splittable = (
+          self._splittable and
+          _determine_splittability_from_compression_type(
+              metadata.path, self._compression_type))
+
+      if splittable:
+        for split in OffsetRange(
+            0, metadata.size_in_bytes).split(
+                self._desired_bundle_size, self._min_bundle_size):
+          yield (metadata, split)
+      else:
+        # Unsplittable files are read from start to end in a single range.
+        yield (metadata, OffsetRange(
+            0, range_trackers.OffsetRangeTracker.OFFSET_INFINITY))
+
+
+# Replace following with a generic reshard transform once
+# https://issues.apache.org/jira/browse/BEAM-1872 is implemented.
+class _Reshard(PTransform):
+  """Redistributes elements by grouping them on random keys.
+
+  Each element is keyed with a fresh ``uuid.uuid4()``, grouped with
+  ``GroupByKey`` and un-keyed again. Only non-merging windows with the
+  default trigger are supported; anything else raises ``ValueError``.
+  """
+
+  def expand(self, pvalue):
+    keyed_pc = (pvalue
+                | 'AssignKey' >> Map(lambda x: (uuid.uuid4(), x)))
+    if keyed_pc.windowing.windowfn.is_merging():
+      raise ValueError('Transform ReadAllFiles cannot be used in the presence '
+                       'of merging windows')
+    if not isinstance(keyed_pc.windowing.triggerfn, DefaultTrigger):
+      raise ValueError('Transform ReadAllFiles cannot be used in the presence '
+                       'of non-trivial triggers')
+
+    return (keyed_pc | 'GroupByKey' >> GroupByKey()
+            # Using FlatMap below due to the possibility of key collisions.
+            # The (key, values) pair is indexed instead of unpacked in the
+            # lambda signature, which Python 3 no longer allows (PEP 3113).
+            | 'DropKey' >> FlatMap(lambda kv: kv[1]))
+
+
+class _ReadRange(DoFn):
+  """Reads the records in one ``(metadata, OffsetRange)`` element.
+
+  ``source_from_file`` is called with the file's path to build the source,
+  which is then read over the element's offset range.
+  """
+
+  def __init__(self, source_from_file):
+    self._source_from_file = source_from_file
+
+  def process(self, element, *args, **kwargs):
+    metadata, offset_range = element
+    source = self._source_from_file(metadata.path)
+    # The split() below is needed to obtain a proper _SingleFileSource.
+    # Otherwise what we have is a ConcatSource that contains a single
+    # _SingleFileSource; ConcatSource.read() expects a RangeTracker for a
+    # sub-source range and reads full sub-sources (not byte ranges).
+    # list(...)[0] (rather than next()) keeps an unexpectedly empty split
+    # loud instead of silently ending this generator.
+    source = list(source.split(float('inf')))[0].source
+    for record in source.read(offset_range.new_tracker()):
+      yield record
+
+
+class ReadAllFiles(PTransform):
+  """A Read transform that reads a PCollection of files.
+
+  Pipeline authors should not use this directly. This is to be used by Read
+  PTransform authors who wish to implement file-based Read transforms that
+  read a PCollection of files.
+
+  Each input element (a file path or a file pattern) is matched against the
+  file system and expanded into byte ranges. The ranges are resharded and
+  then read with the source returned by ``source_from_file``.
+  """
+
+  def __init__(
+      self, splittable, compression_type, desired_bundle_size, min_bundle_size,
+      source_from_file):
+    """Initializes ``ReadAllFiles``.
+
+    Args:
+      splittable: If False, files won't be split into sub-ranges. If True,
+                  files may or may not be split into data ranges; a file is
+                  only split when its (possibly auto-detected) compression
+                  type is ``CompressionTypes.UNCOMPRESSED``.
+      compression_type: A ``CompressionTypes`` value that specifies the
+                  compression type of the files that will be processed. If
+                  ``CompressionTypes.AUTO``, system will try to automatically
+                  determine the compression type based on the extension of
+                  files.
+      desired_bundle_size: the desired size of data ranges that should be
+                           generated when splitting a file into data ranges.
+      min_bundle_size: minimum size of data ranges that should be generated
+                       when splitting a file into data ranges.
+      source_from_file: a function that produces a ``BoundedSource`` given a
+                        file name. System will use this function to generate
+                        ``BoundedSource`` objects for file paths. Note that
+                        file paths passed to this will be for individual
+                        files, not for file patterns, even if the
+                        ``PCollection`` of files processed by the transform
+                        consists of file patterns.
+    """
+    self._splittable = splittable
+    self._compression_type = compression_type
+    self._desired_bundle_size = desired_bundle_size
+    self._min_bundle_size = min_bundle_size
+    self._source_from_file = source_from_file
+
+  def expand(self, pvalue):
+    return (pvalue
+            | 'ExpandIntoRanges' >> ParDo(_ExpandIntoRanges(
+                self._splittable, self._compression_type,
+                self._desired_bundle_size, self._min_bundle_size))
+            | 'Reshard' >> _Reshard()
+            | 'ReadRange' >> ParDo(_ReadRange(self._source_from_file)))

http://git-wip-us.apache.org/repos/asf/beam/blob/5e998532/sdks/python/apache_beam/io/range_trackers.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/range_trackers.py b/sdks/python/apache_beam/io/range_trackers.py
index bef77d4..4bd19f8 100644
--- a/sdks/python/apache_beam/io/range_trackers.py
+++ b/sdks/python/apache_beam/io/range_trackers.py
@@ -28,6 +28,48 @@ __all__ = ['OffsetRangeTracker', 'LexicographicKeyRangeTracker',
            'OrderedPositionRangeTracker', 'UnsplittableRangeTracker']
 
 
+class OffsetRange(object):
+  """A half-open range of offsets ``[start, stop)``."""
+
+  def __init__(self, start, stop):
+    # An empty range (start == stop), e.g. for a zero-length file, is
+    # allowed; splitting it yields no sub-ranges.
+    if start > stop:
+      raise ValueError(
+          'Start offset must not be larger than the stop offset. '
+          'Received %s and %s respectively.' % (start, stop))
+    self.start = start
+    self.stop = stop
+
+  def __eq__(self, other):
+    if not isinstance(other, OffsetRange):
+      return False
+
+    return self.start == other.start and self.stop == other.stop
+
+  def __ne__(self, other):
+    return not self == other
+
+  def __hash__(self):
+    # Keep hashing consistent with __eq__.
+    return hash((self.start, self.stop))
+
+  def __repr__(self):
+    return 'OffsetRange(%r, %r)' % (self.start, self.stop)
+
+  def split(self, desired_num_offsets_per_split, min_num_offsets_per_split=1):
+    """Splits this range into consecutive, non-overlapping sub-ranges.
+
+    Args:
+      desired_num_offsets_per_split: preferred number of offsets per split.
+      min_num_offsets_per_split: minimum number of offsets per split.
+
+    Yields:
+      ``OffsetRange`` objects that together cover ``[start, stop)`` in order.
+      Each is at most ``max(desired, min)`` offsets long, except that a
+      trailing remainder smaller than a quarter of the desired size (or
+      smaller than the minimum size) is merged into the last split. An empty
+      range yields nothing.
+    """
+    current_split_start = self.start
+    max_split_size = max(desired_num_offsets_per_split,
+                         min_num_offsets_per_split)
+    while current_split_start < self.stop:
+      current_split_stop = min(current_split_start + max_split_size, self.stop)
+      remaining = self.stop - current_split_stop
+
+      # Avoiding a small split at the end.
+      if (remaining < desired_num_offsets_per_split / 4 or
+          remaining < min_num_offsets_per_split):
+        current_split_stop = self.stop
+
+      yield OffsetRange(current_split_start, current_split_stop)
+      current_split_start = current_split_stop
+
+  def new_tracker(self):
+    """Returns an ``OffsetRangeTracker`` over ``[start, stop)``."""
+    return OffsetRangeTracker(self.start, self.stop)
+
+
 class OffsetRangeTracker(iobase.RangeTracker):
   """A 'RangeTracker' for non-negative positions of type 'long'."""
 

http://git-wip-us.apache.org/repos/asf/beam/blob/5e998532/sdks/python/apache_beam/io/range_trackers_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/range_trackers_test.py b/sdks/python/apache_beam/io/range_trackers_test.py
index 3e92663..762d654 100644
--- a/sdks/python/apache_beam/io/range_trackers_test.py
+++ b/sdks/python/apache_beam/io/range_trackers_test.py
@@ -23,6 +23,43 @@ import math
 import unittest
 
 from apache_beam.io import range_trackers
+from apache_beam.io.range_trackers import OffsetRange
+
+
+class OffsetRangeTest(unittest.TestCase):
+  """Tests for ``OffsetRange`` construction and splitting."""
+
+  def test_create(self):
+    OffsetRange(0, 10)
+    OffsetRange(10, 100)
+
+    with self.assertRaises(ValueError):
+      OffsetRange(10, 9)
+
+  def test_split_respects_desired_num_splits(self):
+    offset_range = OffsetRange(10, 100)
+    splits = list(offset_range.split(desired_num_offsets_per_split=25))
+    self.assertEqual(4, len(splits))
+    self.assertIn(OffsetRange(10, 35), splits)
+    self.assertIn(OffsetRange(35, 60), splits)
+    self.assertIn(OffsetRange(60, 85), splits)
+    self.assertIn(OffsetRange(85, 100), splits)
+
+  def test_split_respects_min_num_splits(self):
+    offset_range = OffsetRange(10, 100)
+    splits = list(offset_range.split(desired_num_offsets_per_split=5,
+                                     min_num_offsets_per_split=25))
+    self.assertEqual(3, len(splits))
+    self.assertIn(OffsetRange(10, 35), splits)
+    self.assertIn(OffsetRange(35, 60), splits)
+    self.assertIn(OffsetRange(60, 100), splits)
+
+  def test_split_no_small_split_at_end(self):
+    offset_range = OffsetRange(10, 90)
+    splits = list(offset_range.split(desired_num_offsets_per_split=25))
+    self.assertEqual(3, len(splits))
+    self.assertIn(OffsetRange(10, 35), splits)
+    self.assertIn(OffsetRange(35, 60), splits)
+    self.assertIn(OffsetRange(60, 90), splits)
 
 
 class OffsetRangeTrackerTest(unittest.TestCase):

http://git-wip-us.apache.org/repos/asf/beam/blob/5e998532/sdks/python/apache_beam/io/textio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py
index 60e1512..9c6532e 100644
--- a/sdks/python/apache_beam/io/textio.py
+++ b/sdks/python/apache_beam/io/textio.py
@@ -19,19 +19,21 @@
 
 
 from __future__ import absolute_import
+from functools import partial
 import logging
 
 from apache_beam.coders import coders
 from apache_beam.io import filebasedsource
 from apache_beam.io import filebasedsink
 from apache_beam.io import iobase
+from apache_beam.io.filebasedsource import ReadAllFiles
 from apache_beam.io.filesystem import CompressionTypes
 from apache_beam.io.iobase import Read
 from apache_beam.io.iobase import Write
 from apache_beam.transforms import PTransform
 from apache_beam.transforms.display import DisplayDataItem
 
-__all__ = ['ReadFromText', 'WriteToText']
+__all__ = ['ReadFromText', 'ReadAllFromText', 'WriteToText']
 
 
 class _TextSource(filebasedsource.FileBasedSource):
@@ -342,8 +344,80 @@ class _TextSink(filebasedsink.FileBasedSink):
       file_handle.write('\n')
 
 
+def _create_text_source(
+    file_pattern=None, min_bundle_size=None, compression_type=None,
+    strip_trailing_newlines=None, coder=None, skip_header_lines=None):
+  """Builds a ``_TextSource`` for ``file_pattern`` with validation disabled.
+
+  ``ReadAllFromText`` binds every argument except ``file_pattern`` with
+  ``functools.partial`` and hands the result to ``ReadAllFiles`` as its
+  ``source_from_file`` factory.
+  """
+  return _TextSource(
+      file_pattern=file_pattern, min_bundle_size=min_bundle_size,
+      compression_type=compression_type,
+      strip_trailing_newlines=strip_trailing_newlines,
+      coder=coder, validate=False, skip_header_lines=skip_header_lines)
+
+
+class ReadAllFromText(PTransform):
+  """A ``PTransform`` for reading a ``PCollection`` of text files.
+
+  Reads a ``PCollection`` of text files or file patterns and produces a
+  ``PCollection`` of strings.
+
+  Parses a text file as newline-delimited elements, by default assuming
+  UTF-8 encoding. Supports newline delimiters '\\n' and '\\r\\n'.
+
+  This implementation only supports reading text encoded using UTF-8 or ASCII.
+  This does not support other encodings such as UTF-16 or UTF-32.
+
+  Input files are not validated at pipeline construction time: each element
+  is matched when the pipeline runs, and an element that matches no files
+  produces no output.
+  """
+
+  DEFAULT_DESIRED_BUNDLE_SIZE = 64 * 1024 * 1024  # 64MB
+
+  def __init__(
+      self,
+      min_bundle_size=0,
+      desired_bundle_size=DEFAULT_DESIRED_BUNDLE_SIZE,
+      compression_type=CompressionTypes.AUTO,
+      strip_trailing_newlines=True,
+      coder=coders.StrUtf8Coder(),
+      skip_header_lines=0,
+      **kwargs):
+    """Initialize the ``ReadAllFromText`` transform.
+
+    Args:
+      min_bundle_size: Minimum size of bundles that should be generated when
+        splitting this source into bundles. See ``FileBasedSource`` for more
+        details.
+      desired_bundle_size: Desired size of bundles that should be generated
+        when splitting this source into bundles. See ``FileBasedSource`` for
+        more details.
+      compression_type: Used to handle compressed input files. Typical value
+        is ``CompressionTypes.AUTO``, in which case the underlying file_path's
+        extension will be used to detect the compression.
+      strip_trailing_newlines: Indicates whether this source should remove
+        the newline char in each line it reads before decoding that line.
+      coder: Coder used to decode each line.
+      skip_header_lines: Number of header lines to skip. Same number is skipped
+        from each source file. Must be 0 or higher. Large number of skipped
+        lines might impact performance.
+      **kwargs: passed through to ``PTransform.__init__``.
+    """
+    super(ReadAllFromText, self).__init__(**kwargs)
+    source_from_file = partial(
+        _create_text_source, min_bundle_size=min_bundle_size,
+        compression_type=compression_type,
+        strip_trailing_newlines=strip_trailing_newlines, coder=coder,
+        skip_header_lines=skip_header_lines)
+    self._desired_bundle_size = desired_bundle_size
+    self._min_bundle_size = min_bundle_size
+    self._compression_type = compression_type
+    self._read_all_files = ReadAllFiles(
+        True, compression_type, desired_bundle_size, min_bundle_size,
+        source_from_file)
+
+  def expand(self, pvalue):
+    return pvalue | 'ReadAllFiles' >> self._read_all_files
+
+
 class ReadFromText(PTransform):
-  """A PTransform for reading text files.
+  """A ``PTransform`` for reading text files.
 
   Parses a text file as newline-delimited elements, by default assuming
   UTF-8 encoding. Supports newline delimiters '\\n' and '\\r\\n'.
@@ -361,7 +435,7 @@ class ReadFromText(PTransform):
       validate=True,
       skip_header_lines=0,
       **kwargs):
-    """Initialize the ReadFromText transform.
+    """Initialize the ``ReadFromText`` transform.
 
     Args:
       file_pattern: The file path to read from as a local file path or a GCS
@@ -371,7 +445,7 @@ class ReadFromText(PTransform):
         splitting this source into bundles. See ``FileBasedSource`` for more
         details.
       compression_type: Used to handle compressed input files. Typical value
-        is CompressionTypes.AUTO, in which case the underlying file_path's
+        is ``CompressionTypes.AUTO``, in which case the underlying file_path's
         extension will be used to detect the compression.
       strip_trailing_newlines: Indicates whether this source should remove
         the newline char in each line it reads before decoding that line.

http://git-wip-us.apache.org/repos/asf/beam/blob/5e998532/sdks/python/apache_beam/io/textio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py
index 8bd7116..b29ca5a 100644
--- a/sdks/python/apache_beam/io/textio_test.py
+++ b/sdks/python/apache_beam/io/textio_test.py
@@ -27,7 +27,7 @@ import tempfile
 import unittest
 
 import apache_beam as beam
-from apache_beam.io import iobase
+from apache_beam.io import iobase, ReadAllFromText
 import apache_beam.io.source_test_utils as source_test_utils
 
 # Importing following private classes for testing.
@@ -47,6 +47,8 @@ from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 
+from apache_beam.transforms.core import Create
+
 
 # TODO: Refactor code so all io tests are using same library
 # TestCaseWithTempDirCleanup class.
@@ -334,7 +336,7 @@ class TextSourceTest(_TestCaseWithTempDirCleanUp):
         splits[0].source, splits[0].start_position, splits[0].stop_position,
         perform_multi_threaded_test=False)
 
-  def test_dataflow_single_file(self):
+  def test_read_from_text_single_file(self):
     file_name, expected_data = write_data(5)
     assert len(expected_data) == 5
     pipeline = TestPipeline()
@@ -342,7 +344,53 @@ class TextSourceTest(_TestCaseWithTempDirCleanUp):
     assert_that(pcoll, equal_to(expected_data))
     pipeline.run()
 
-  def test_dataflow_single_file_with_coder(self):
+  def test_read_all_single_file(self):
+    """ReadAllFromText reads every line of a single file."""
+    file_name, expected_data = write_data(5)
+    assert len(expected_data) == 5
+    pipeline = TestPipeline()
+    pcoll = (pipeline
+             | 'Create' >> Create([file_name])
+             | 'ReadAll' >> ReadAllFromText())
+    assert_that(pcoll, equal_to(expected_data))
+    pipeline.run()
+
+  def test_read_all_many_single_files(self):
+    """ReadAllFromText reads and combines the lines of several files."""
+    file_name1, expected_data1 = write_data(5)
+    assert len(expected_data1) == 5
+    file_name2, expected_data2 = write_data(10)
+    assert len(expected_data2) == 10
+    file_name3, expected_data3 = write_data(15)
+    assert len(expected_data3) == 15
+    expected_data = []
+    expected_data.extend(expected_data1)
+    expected_data.extend(expected_data2)
+    expected_data.extend(expected_data3)
+    pipeline = TestPipeline()
+    pcoll = (pipeline
+             | 'Create' >> Create([file_name1, file_name2, file_name3])
+             | 'ReadAll' >> ReadAllFromText())
+    assert_that(pcoll, equal_to(expected_data))
+    pipeline.run()
+
+  def test_read_all_unavailable_files_ignored(self):
+    """A path that matches no file contributes no lines and raises no error."""
+    file_name1, expected_data1 = write_data(5)
+    assert len(expected_data1) == 5
+    file_name2, expected_data2 = write_data(10)
+    assert len(expected_data2) == 10
+    file_name3, expected_data3 = write_data(15)
+    assert len(expected_data3) == 15
+    file_name4 = "/unavailable_file"
+    expected_data = []
+    expected_data.extend(expected_data1)
+    expected_data.extend(expected_data2)
+    expected_data.extend(expected_data3)
+    pipeline = TestPipeline()
+    pcoll = (pipeline
+             | 'Create' >> Create(
+                 [file_name1, file_name2, file_name3, file_name4])
+             | 'ReadAll' >> ReadAllFromText())
+    assert_that(pcoll, equal_to(expected_data))
+    pipeline.run()
+
+  def test_read_from_text_single_file_with_coder(self):
     class DummyCoder(coders.Coder):
       def encode(self, x):
         raise ValueError
@@ -357,7 +405,7 @@ class TextSourceTest(_TestCaseWithTempDirCleanUp):
     assert_that(pcoll, equal_to([record * 2 for record in expected_data]))
     pipeline.run()
 
-  def test_dataflow_file_pattern(self):
+  def test_read_from_text_file_pattern(self):
     pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4])
     assert len(expected_data) == 40
     pipeline = TestPipeline()
@@ -365,6 +413,33 @@ class TextSourceTest(_TestCaseWithTempDirCleanUp):
     assert_that(pcoll, equal_to(expected_data))
     pipeline.run()
 
+  def test_read_all_file_pattern(self):
+    """ReadAllFromText expands a file pattern and reads every matched file."""
+    pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4])
+    assert len(expected_data) == 40
+    pipeline = TestPipeline()
+    pcoll = (pipeline
+             | 'Create' >> Create([pattern])
+             | 'ReadAll' >> ReadAllFromText())
+    assert_that(pcoll, equal_to(expected_data))
+    pipeline.run()
+
+  def test_read_all_many_file_patterns(self):
+    """ReadAllFromText reads the files matched by several patterns."""
+    pattern1, expected_data1 = write_pattern([5, 3, 12, 8, 8, 4])
+    assert len(expected_data1) == 40
+    pattern2, expected_data2 = write_pattern([3, 7, 9])
+    assert len(expected_data2) == 19
+    pattern3, expected_data3 = write_pattern([11, 20, 5, 5])
+    assert len(expected_data3) == 41
+    expected_data = []
+    expected_data.extend(expected_data1)
+    expected_data.extend(expected_data2)
+    expected_data.extend(expected_data3)
+    pipeline = TestPipeline()
+    pcoll = (pipeline
+             | 'Create' >> Create([pattern1, pattern2, pattern3])
+             | 'ReadAll' >> ReadAllFromText())
+    assert_that(pcoll, equal_to(expected_data))
+    pipeline.run()
+
   def test_read_auto_bzip2(self):
     _, lines = write_data(15)
     file_name = self._create_temp_file(suffix='.bz2')
@@ -528,6 +603,18 @@ class TextSourceTest(_TestCaseWithTempDirCleanUp):
 
     expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
     assert_that(lines, equal_to(expected))
+
+  def test_read_all_gzip(self):
+    """ReadAllFromText reads a GZIP-compressed file when told its type."""
+    # NOTE(review): when this test was added, the preceding test's trailing
+    # ``pipeline.run()`` appears to have become this test's last line, which
+    # leaves the preceding test without a ``pipeline.run()`` call. Confirm
+    # and restore that call.
+    _, lines = write_data(100)
+    file_name = self._create_temp_file()
+    with gzip.GzipFile(file_name, 'wb') as f:
+      f.write('\n'.join(lines))
+    pipeline = TestPipeline()
+    pcoll = (pipeline
+             | Create([file_name])
+             | 'ReadAll' >> ReadAllFromText(
+                 compression_type=CompressionTypes.GZIP))
+    assert_that(pcoll, equal_to(lines))
     pipeline.run()
 
   def test_read_gzip_large(self):

Reply via email to