Repository: beam
Updated Branches:
  refs/heads/master 781e4172c -> 7447147e0


Add `validate` argument to tfrecordio.ReadFromTFRecord


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/b616505d
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/b616505d
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/b616505d

Branch: refs/heads/master
Commit: b616505d5c6ae569f3cda5de97cbd8ddb07735c3
Parents: 781e417
Author: Neda Mirian <[email protected]>
Authored: Wed Mar 8 17:17:48 2017 -0800
Committer: Ahmet Altay <[email protected]>
Committed: Mon Mar 13 10:35:19 2017 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/io/tfrecordio.py      | 11 ++++++---
 sdks/python/apache_beam/io/tfrecordio_test.py | 26 ++++++++++++++++++----
 2 files changed, 30 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/b616505d/sdks/python/apache_beam/io/tfrecordio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/tfrecordio.py b/sdks/python/apache_beam/io/tfrecordio.py
index be9f839..05c0a13 100644
--- a/sdks/python/apache_beam/io/tfrecordio.py
+++ b/sdks/python/apache_beam/io/tfrecordio.py
@@ -146,12 +146,14 @@ class _TFRecordSource(filebasedsource.FileBasedSource):
   def __init__(self,
                file_pattern,
                coder,
-               compression_type):
+               compression_type,
+               validate):
     """Initialize a TFRecordSource.  See ReadFromTFRecord for details."""
     super(_TFRecordSource, self).__init__(
         file_pattern=file_pattern,
         compression_type=compression_type,
-        splittable=False)
+        splittable=False,
+        validate=validate)
     self._coder = coder
 
   def read_records(self, file_name, offset_range_tracker):
@@ -179,6 +181,7 @@ class ReadFromTFRecord(PTransform):
                file_pattern,
                coder=coders.BytesCoder(),
                compression_type=fileio.CompressionTypes.AUTO,
+               validate=True,
                **kwargs):
     """Initialize a ReadFromTFRecord transform.
 
@@ -188,6 +191,8 @@ class ReadFromTFRecord(PTransform):
       compression_type: Used to handle compressed input files. Default value
          is CompressionTypes.AUTO, in which case the file_path's extension will
           be used to detect the compression.
+      validate: Boolean flag to verify that the files exist during the pipeline
+          creation time.
       **kwargs: optional args dictionary. These are passed through to parent
         constructor.
 
@@ -195,7 +200,7 @@ class ReadFromTFRecord(PTransform):
       A ReadFromTFRecord transform object.
     """
     super(ReadFromTFRecord, self).__init__(**kwargs)
-    self._args = (file_pattern, coder, compression_type)
+    self._args = (file_pattern, coder, compression_type, validate)
 
   def expand(self, pvalue):
     return pvalue.pipeline | Read(_TFRecordSource(*self._args))

http://git-wip-us.apache.org/repos/asf/beam/blob/b616505d/sdks/python/apache_beam/io/tfrecordio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/tfrecordio_test.py b/sdks/python/apache_beam/io/tfrecordio_test.py
index e5dcbdc..df33fcb 100644
--- a/sdks/python/apache_beam/io/tfrecordio_test.py
+++ b/sdks/python/apache_beam/io/tfrecordio_test.py
@@ -252,7 +252,8 @@ class TestTFRecordSource(_TestCaseWithTempDirCleanUp):
                     _TFRecordSource(
                         path,
                         coder=coders.BytesCoder(),
-                        compression_type=fileio.CompressionTypes.AUTO)))
+                        compression_type=fileio.CompressionTypes.AUTO,
+                        validate=True)))
       beam.assert_that(result, beam.equal_to(['foo']))
 
   def test_process_multiple(self):
@@ -264,7 +265,8 @@ class TestTFRecordSource(_TestCaseWithTempDirCleanUp):
                     _TFRecordSource(
                         path,
                         coder=coders.BytesCoder(),
-                        compression_type=fileio.CompressionTypes.AUTO)))
+                        compression_type=fileio.CompressionTypes.AUTO,
+                        validate=True)))
       beam.assert_that(result, beam.equal_to(['foo', 'bar']))
 
   def test_process_gzip(self):
@@ -276,7 +278,8 @@ class TestTFRecordSource(_TestCaseWithTempDirCleanUp):
                     _TFRecordSource(
                         path,
                         coder=coders.BytesCoder(),
-                        compression_type=fileio.CompressionTypes.GZIP)))
+                        compression_type=fileio.CompressionTypes.GZIP,
+                        validate=True)))
       beam.assert_that(result, beam.equal_to(['foo', 'bar']))
 
   def test_process_auto(self):
@@ -288,7 +291,8 @@ class TestTFRecordSource(_TestCaseWithTempDirCleanUp):
                     _TFRecordSource(
                         path,
                         coder=coders.BytesCoder(),
-                        compression_type=fileio.CompressionTypes.AUTO)))
+                        compression_type=fileio.CompressionTypes.AUTO,
+                        validate=True)))
       beam.assert_that(result, beam.equal_to(['foo', 'bar']))
 
 
@@ -383,6 +387,20 @@ class TestEnd2EndWriteAndRead(_TestCaseWithTempDirCleanUp):
           coder=beam.coders.ProtoCoder(example.__class__)))
       beam.assert_that(actual_data, beam.equal_to([example]))
 
+  def test_end2end_read_write_read(self):
+    path = os.path.join(self._new_tempdir(), 'result')
+    with TestPipeline() as p:
+      # Initial read to validate the pipeline doesn't fail before the file is
+      # created.
+      _ = p | ReadFromTFRecord(path + '-*', validate=False)
+      expected_data = [self.create_inputs() for _ in range(0, 10)]
+      _ = p | beam.Create(expected_data) | WriteToTFRecord(
+          path, file_name_suffix='.gz')
+
+    # Read the file back and compare.
+    with TestPipeline() as p:
+      actual_data = p | ReadFromTFRecord(path+'-*', validate=True)
+      beam.assert_that(actual_data, beam.equal_to(expected_data))
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)

Reply via email to