Repository: beam Updated Branches: refs/heads/master 781e4172c -> 7447147e0
Add `validate` argument to tfrecordio.ReadFromTFRecord Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/b616505d Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/b616505d Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/b616505d Branch: refs/heads/master Commit: b616505d5c6ae569f3cda5de97cbd8ddb07735c3 Parents: 781e417 Author: Neda Mirian <[email protected]> Authored: Wed Mar 8 17:17:48 2017 -0800 Committer: Ahmet Altay <[email protected]> Committed: Mon Mar 13 10:35:19 2017 -0700 ---------------------------------------------------------------------- sdks/python/apache_beam/io/tfrecordio.py | 11 ++++++--- sdks/python/apache_beam/io/tfrecordio_test.py | 26 ++++++++++++++++++---- 2 files changed, 30 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/b616505d/sdks/python/apache_beam/io/tfrecordio.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/tfrecordio.py b/sdks/python/apache_beam/io/tfrecordio.py index be9f839..05c0a13 100644 --- a/sdks/python/apache_beam/io/tfrecordio.py +++ b/sdks/python/apache_beam/io/tfrecordio.py @@ -146,12 +146,14 @@ class _TFRecordSource(filebasedsource.FileBasedSource): def __init__(self, file_pattern, coder, - compression_type): + compression_type, + validate): """Initialize a TFRecordSource. See ReadFromTFRecord for details.""" super(_TFRecordSource, self).__init__( file_pattern=file_pattern, compression_type=compression_type, - splittable=False) + splittable=False, + validate=validate) self._coder = coder def read_records(self, file_name, offset_range_tracker): @@ -179,6 +181,7 @@ class ReadFromTFRecord(PTransform): file_pattern, coder=coders.BytesCoder(), compression_type=fileio.CompressionTypes.AUTO, + validate=True, **kwargs): """Initialize a ReadFromTFRecord transform. @@ -188,6 +191,8 @@ class ReadFromTFRecord(PTransform): compression_type: Used to handle compressed input files. Default value is CompressionTypes.AUTO, in which case the file_path's extension will be used to detect the compression. + validate: Boolean flag to verify that the files exist during the pipeline + creation time. **kwargs: optional args dictionary. These are passed through to parent constructor. @@ -195,7 +200,7 @@ class ReadFromTFRecord(PTransform): A ReadFromTFRecord transform object. """ super(ReadFromTFRecord, self).__init__(**kwargs) - self._args = (file_pattern, coder, compression_type) + self._args = (file_pattern, coder, compression_type, validate) def expand(self, pvalue): return pvalue.pipeline | Read(_TFRecordSource(*self._args)) http://git-wip-us.apache.org/repos/asf/beam/blob/b616505d/sdks/python/apache_beam/io/tfrecordio_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/tfrecordio_test.py b/sdks/python/apache_beam/io/tfrecordio_test.py index e5dcbdc..df33fcb 100644 --- a/sdks/python/apache_beam/io/tfrecordio_test.py +++ b/sdks/python/apache_beam/io/tfrecordio_test.py @@ -252,7 +252,8 @@ class TestTFRecordSource(_TestCaseWithTempDirCleanUp): _TFRecordSource( path, coder=coders.BytesCoder(), - compression_type=fileio.CompressionTypes.AUTO))) + compression_type=fileio.CompressionTypes.AUTO, + validate=True))) beam.assert_that(result, beam.equal_to(['foo'])) def test_process_multiple(self): @@ -264,7 +265,8 @@ class TestTFRecordSource(_TestCaseWithTempDirCleanUp): _TFRecordSource( path, coder=coders.BytesCoder(), - compression_type=fileio.CompressionTypes.AUTO))) + compression_type=fileio.CompressionTypes.AUTO, + validate=True))) beam.assert_that(result, beam.equal_to(['foo', 'bar'])) def test_process_gzip(self): @@ -276,7 +278,8 @@ class TestTFRecordSource(_TestCaseWithTempDirCleanUp): _TFRecordSource( path, coder=coders.BytesCoder(), - compression_type=fileio.CompressionTypes.GZIP))) + compression_type=fileio.CompressionTypes.GZIP, + validate=True))) beam.assert_that(result, beam.equal_to(['foo', 'bar'])) def test_process_auto(self): @@ -288,7 +291,8 @@ class TestTFRecordSource(_TestCaseWithTempDirCleanUp): _TFRecordSource( path, coder=coders.BytesCoder(), - compression_type=fileio.CompressionTypes.AUTO))) + compression_type=fileio.CompressionTypes.AUTO, + validate=True))) beam.assert_that(result, beam.equal_to(['foo', 'bar'])) @@ -383,6 +387,20 @@ class TestEnd2EndWriteAndRead(_TestCaseWithTempDirCleanUp): coder=beam.coders.ProtoCoder(example.__class__))) beam.assert_that(actual_data, beam.equal_to([example])) + def test_end2end_read_write_read(self): + path = os.path.join(self._new_tempdir(), 'result') + with TestPipeline() as p: + # Initial read to validate the pipeline doesn't fail before the file is + # created. + _ = p | ReadFromTFRecord(path + '-*', validate=False) + expected_data = [self.create_inputs() for _ in range(0, 10)] + _ = p | beam.Create(expected_data) | WriteToTFRecord( + path, file_name_suffix='.gz') + + # Read the file back and compare. + with TestPipeline() as p: + actual_data = p | ReadFromTFRecord(path+'-*', validate=True) + beam.assert_that(actual_data, beam.equal_to(expected_data)) if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO)
