Provided temporary directory management for test cases.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/93e8d19e Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/93e8d19e Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/93e8d19e Branch: refs/heads/python-sdk Commit: 93e8d19e32807fb5279ed711f0f06c3123adfb2e Parents: 88833ba Author: Younghee Kwon <younghee.k...@gmail.com> Authored: Mon Jan 9 11:50:57 2017 -0800 Committer: Robert Bradshaw <rober...@google.com> Committed: Mon Jan 9 13:13:46 2017 -0800 ---------------------------------------------------------------------- sdks/python/apache_beam/io/tfrecordio_test.py | 58 +++++++++++++++------- 1 file changed, 41 insertions(+), 17 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/93e8d19e/sdks/python/apache_beam/io/tfrecordio_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/tfrecordio_test.py b/sdks/python/apache_beam/io/tfrecordio_test.py index ee287b3..ecd58f5 100644 --- a/sdks/python/apache_beam/io/tfrecordio_test.py +++ b/sdks/python/apache_beam/io/tfrecordio_test.py @@ -20,8 +20,10 @@ import cStringIO import glob import gzip import logging +import os import pickle import random +import shutil import tempfile import unittest @@ -134,7 +136,29 @@ class TestTFRecordUtil(unittest.TestCase): self.assertEqual(record, actual) -class TestTFRecordSink(unittest.TestCase): +class _TestCaseWithTempDirCleanUp(unittest.TestCase): + """Base class for TestCases that deals with TempDir clean-up. + + Inherited test cases will call self._new_tempdir() to start a temporary dir + which will be deleted at the end of the tests (when tearDown() is called). + """ + + def setUp(self): + self._tempdirs = [] + + def tearDown(self): + for path in self._tempdirs: + if os.path.exists(path): + shutil.rmtree(path) + self._tempdirs = [] + + def _new_tempdir(self): + result = tempfile.mkdtemp() + self._tempdirs.append(result) + return result + + +class TestTFRecordSink(_TestCaseWithTempDirCleanUp): def _write_lines(self, sink, path, lines): f = sink.open(path) @@ -143,7 +167,7 @@ class TestTFRecordSink(unittest.TestCase): sink.close(f) def test_write_record_single(self): - path = tempfile.NamedTemporaryFile().name + path = os.path.join(self._new_tempdir(), 'result') record = binascii.a2b_base64(FOO_RECORD_BASE64) sink = _TFRecordSink( path, @@ -158,7 +182,7 @@ class TestTFRecordSink(unittest.TestCase): self.assertEqual(f.read(), record) def test_write_record_multiple(self): - path = tempfile.NamedTemporaryFile().name + path = os.path.join(self._new_tempdir(), 'result') record = binascii.a2b_base64(FOO_BAR_RECORD_BASE64) sink = _TFRecordSink( path, @@ -177,8 +201,8 @@ class TestTFRecordSink(unittest.TestCase): class TestWriteToTFRecord(TestTFRecordSink): def test_write_record_gzip(self): + file_path_prefix = os.path.join(self._new_tempdir(), 'result') with beam.Pipeline(DirectRunner()) as p: - file_path_prefix = tempfile.NamedTemporaryFile().name input_data = ['foo', 'bar'] _ = p | beam.Create(input_data) | WriteToTFRecord( file_path_prefix, compression_type=fileio.CompressionTypes.GZIP) @@ -192,8 +216,8 @@ class TestWriteToTFRecord(TestTFRecordSink): self.assertEqual(actual, input_data) def test_write_record_auto(self): + file_path_prefix = os.path.join(self._new_tempdir(), 'result') with beam.Pipeline(DirectRunner()) as p: - file_path_prefix = tempfile.NamedTemporaryFile().name input_data = ['foo', 'bar'] _ = p | beam.Create(input_data) | WriteToTFRecord( file_path_prefix, file_name_suffix='.gz') @@ -207,7 +231,7 @@ class TestWriteToTFRecord(TestTFRecordSink): self.assertEqual(actual, input_data) -class TestTFRecordSource(unittest.TestCase): +class TestTFRecordSource(_TestCaseWithTempDirCleanUp): def _write_file(self, path, base64_records): record = binascii.a2b_base64(base64_records) @@ -220,7 +244,7 @@ class TestTFRecordSource(unittest.TestCase): f.write(record) def test_process_single(self): - path = tempfile.NamedTemporaryFile().name + path = os.path.join(self._new_tempdir(), 'result') self._write_file(path, FOO_RECORD_BASE64) with beam.Pipeline(DirectRunner()) as p: result = (p @@ -232,7 +256,7 @@ class TestTFRecordSource(unittest.TestCase): beam.assert_that(result, beam.equal_to(['foo'])) def test_process_multiple(self): - path = tempfile.NamedTemporaryFile().name + path = os.path.join(self._new_tempdir(), 'result') self._write_file(path, FOO_BAR_RECORD_BASE64) with beam.Pipeline(DirectRunner()) as p: result = (p @@ -244,7 +268,7 @@ class TestTFRecordSource(unittest.TestCase): beam.assert_that(result, beam.equal_to(['foo', 'bar'])) def test_process_gzip(self): - path = tempfile.NamedTemporaryFile().name + path = os.path.join(self._new_tempdir(), 'result') self._write_file_gzip(path, FOO_BAR_RECORD_BASE64) with beam.Pipeline(DirectRunner()) as p: result = (p @@ -256,7 +280,7 @@ class TestTFRecordSource(unittest.TestCase): beam.assert_that(result, beam.equal_to(['foo', 'bar'])) def test_process_auto(self): - path = tempfile.mkstemp(suffix='.gz')[1] + path = os.path.join(self._new_tempdir(), 'result.gz') self._write_file_gzip(path, FOO_BAR_RECORD_BASE64) with beam.Pipeline(DirectRunner()) as p: result = (p @@ -271,7 +295,7 @@ class TestTFRecordSource(unittest.TestCase): class TestReadFromTFRecordSource(TestTFRecordSource): def test_process_gzip(self): - path = tempfile.NamedTemporaryFile().name + path = os.path.join(self._new_tempdir(), 'result') self._write_file_gzip(path, FOO_BAR_RECORD_BASE64) with beam.Pipeline(DirectRunner()) as p: result = (p @@ -280,7 +304,7 @@ class TestReadFromTFRecordSource(TestTFRecordSource): beam.assert_that(result, beam.equal_to(['foo', 'bar'])) def test_process_gzip_auto(self): - path = tempfile.mkstemp(suffix='.gz')[1] + path = os.path.join(self._new_tempdir(), 'result.gz') self._write_file_gzip(path, FOO_BAR_RECORD_BASE64) with beam.Pipeline(DirectRunner()) as p: result = (p @@ -289,7 +313,7 @@ class TestReadFromTFRecordSource(TestTFRecordSource): beam.assert_that(result, beam.equal_to(['foo', 'bar'])) -class TestEnd2EndWriteAndRead(unittest.TestCase): +class TestEnd2EndWriteAndRead(_TestCaseWithTempDirCleanUp): def create_inputs(self): input_array = [[random.random() - 0.5 for _ in xrange(15)] @@ -299,7 +323,7 @@ class TestEnd2EndWriteAndRead(unittest.TestCase): return memfile.getvalue() def test_end2end(self): - file_path_prefix = tempfile.NamedTemporaryFile().name + file_path_prefix = os.path.join(self._new_tempdir(), 'result') # Generate a TFRecord file. with beam.Pipeline(DirectRunner()) as p: @@ -312,7 +336,7 @@ class TestEnd2EndWriteAndRead(unittest.TestCase): beam.assert_that(actual_data, beam.equal_to(expected_data)) def test_end2end_auto_compression(self): - file_path_prefix = tempfile.NamedTemporaryFile().name + file_path_prefix = os.path.join(self._new_tempdir(), 'result') # Generate a TFRecord file. with beam.Pipeline(DirectRunner()) as p: @@ -326,7 +350,7 @@ class TestEnd2EndWriteAndRead(unittest.TestCase): beam.assert_that(actual_data, beam.equal_to(expected_data)) def test_end2end_auto_compression_unsharded(self): - file_path_prefix = tempfile.NamedTemporaryFile().name + file_path_prefix = os.path.join(self._new_tempdir(), 'result') # Generate a TFRecord file. with beam.Pipeline(DirectRunner()) as p: @@ -341,7 +365,7 @@ class TestEnd2EndWriteAndRead(unittest.TestCase): @unittest.skipIf(tf is None, 'tensorflow not installed.') def test_end2end_example_proto(self): - file_path_prefix = tempfile.NamedTemporaryFile().name + file_path_prefix = os.path.join(self._new_tempdir(), 'result') example = tf.train.Example() example.features.feature['int'].int64_list.value.extend(range(3))