Repository: incubator-beam
Updated Branches:
  refs/heads/python-sdk 15e78b28a -> 560fe79f8
[BEAM-852] Add validation to file based sources during create time

Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/76ad2929
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/76ad2929
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/76ad2929

Branch: refs/heads/python-sdk
Commit: 76ad29296fd57e1eec97bf40d9cf3a1d54a63a3f
Parents: 15e78b2
Author: Sourabh Bajaj <sourabhba...@google.com>
Authored: Mon Nov 14 15:40:10 2016 -0800
Committer: Robert Bradshaw <rober...@google.com>
Committed: Mon Nov 14 15:40:10 2016 -0800

----------------------------------------------------------------------
 sdks/python/apache_beam/io/avroio.py          |  8 +++-
 sdks/python/apache_beam/io/bigquery.py        |  2 +-
 sdks/python/apache_beam/io/filebasedsource.py | 16 +++++++-
 .../apache_beam/io/filebasedsource_test.py    | 41 ++++++++++++++------
 sdks/python/apache_beam/io/textio.py          | 13 +++++--
 5 files changed, 60 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/76ad2929/sdks/python/apache_beam/io/avroio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/avroio.py b/sdks/python/apache_beam/io/avroio.py
index 53ed95a..e7e73dd 100644
--- a/sdks/python/apache_beam/io/avroio.py
+++ b/sdks/python/apache_beam/io/avroio.py
@@ -37,7 +37,7 @@ __all__ = ['ReadFromAvro', 'WriteToAvro']
 class ReadFromAvro(PTransform):
   """A ``PTransform`` for reading avro files."""
 
-  def __init__(self, file_pattern=None, min_bundle_size=0):
+  def __init__(self, file_pattern=None, min_bundle_size=0, validate=True):
     """Initializes ``ReadFromAvro``.
 
     Uses source '_AvroSource' to read a set of Avro files defined by a given
@@ -70,13 +70,17 @@ class ReadFromAvro(PTransform):
       file_pattern: the set of files to be read.
       min_bundle_size: the minimum size in bytes, to be considered when
                        splitting the input into bundles.
+      validate: flag to verify that the files exist during the pipeline
+                creation time.
       **kwargs: Additional keyword arguments to be passed to the base class.
     """
     super(ReadFromAvro, self).__init__()
     self._args = (file_pattern, min_bundle_size)
+    self._validate = validate
 
   def apply(self, pvalue):
-    return pvalue.pipeline | Read(_AvroSource(*self._args))
+    return pvalue.pipeline | Read(_AvroSource(*self._args,
+                                              validate=self._validate))
 
 
 class _AvroUtils(object):
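
For context, the new flag surfaces in ReadFromAvro roughly as in the sketch
below (illustrative only, not part of the patch; the bucket path is made up):

    import apache_beam as beam
    from apache_beam.io.avroio import ReadFromAvro

    p = beam.Pipeline()
    # With the default validate=True, a file pattern that matches nothing now
    # raises IOError at this point, while the pipeline is being constructed,
    # instead of failing later when the source is actually read.
    records = p | ReadFromAvro('gs://my-bucket/records-*.avro')
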
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/76ad2929/sdks/python/apache_beam/io/bigquery.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/bigquery.py b/sdks/python/apache_beam/io/bigquery.py
index f0e88a6..8d7892a 100644
--- a/sdks/python/apache_beam/io/bigquery.py
+++ b/sdks/python/apache_beam/io/bigquery.py
@@ -65,7 +65,7 @@ input entails querying the table for all its rows. The coder argument on
 BigQuerySource controls the reading of the lines in the export files (i.e.,
 transform a JSON object into a PCollection element). The coder is not involved
 when the same table is read as a side input since there is no intermediate
-format involved. We get the table rows directly from the BigQuery service with 
+format involved. We get the table rows directly from the BigQuery service with
 a query.
 
 Users may provide a query to read from rather than reading all of a BigQuery

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/76ad2929/sdks/python/apache_beam/io/filebasedsource.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py
index 58ad118..c7bc27e 100644
--- a/sdks/python/apache_beam/io/filebasedsource.py
+++ b/sdks/python/apache_beam/io/filebasedsource.py
@@ -50,7 +50,8 @@ class FileBasedSource(iobase.BoundedSource):
                file_pattern,
                min_bundle_size=0,
                compression_type=fileio.CompressionTypes.AUTO,
-               splittable=True):
+               splittable=True,
+               validate=True):
     """Initializes ``FileBasedSource``.
 
     Args:
@@ -68,10 +69,13 @@ class FileBasedSource(iobase.BoundedSource):
                   the file, for example, for compressed files where currently
                   it is not possible to efficiently read a data range without
                   decompressing the whole file.
+      validate: Boolean flag to verify that the files exist during the pipeline
+                creation time.
     Raises:
       TypeError: when compression_type is not valid or if file_pattern is not a
                  string.
       ValueError: when compression and splittable files are specified.
+      IOError: when the file pattern specified yields an empty result.
     """
     if not isinstance(file_pattern, basestring):
       raise TypeError(
@@ -91,6 +95,8 @@ class FileBasedSource(iobase.BoundedSource):
     else:
       # We can't split compressed files efficiently so turn off splitting.
       self._splittable = False
+    if validate:
+      self._validate()
 
   def display_data(self):
     return {'filePattern': DisplayDataItem(self._pattern,
                                            label="File Pattern"),
@@ -133,7 +139,6 @@ class FileBasedSource(iobase.BoundedSource):
 
   @staticmethod
   def _estimate_sizes_in_parallel(file_names):
-
     if not file_names:
       return []
     elif len(file_names) == 1:
@@ -150,6 +155,13 @@ class FileBasedSource(iobase.BoundedSource):
     finally:
       pool.terminate()
 
+  def _validate(self):
+    """Validate if there are actual files in the specified glob pattern
+    """
+    if len(fileio.ChannelFactory.glob(self._pattern)) <= 0:
+      raise IOError(
+          'No files found based on the file pattern %s' % self._pattern)
+
   def split(
       self, desired_bundle_size=None, start_position=None, stop_position=None):
     return self._get_concat_source().split(
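
The behaviour the new tests exercise, sketched with the LineSource helper
defined in filebasedsource_test.py (the path below is made up):

    # Any FileBasedSource subclass now checks its file pattern at construction.
    try:
      LineSource('/tmp/no-such-dir/*')  # validate defaults to True
    except IOError as e:
      print(e)  # No files found based on the file pattern /tmp/no-such-dir/*

    # Opting out restores the old behaviour: no check until the source is read.
    source = LineSource('/tmp/no-such-dir/*', validate=False)
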
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/76ad2929/sdks/python/apache_beam/io/filebasedsource_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/filebasedsource_test.py b/sdks/python/apache_beam/io/filebasedsource_test.py
index 7bc31fd..7f4d8d3 100644
--- a/sdks/python/apache_beam/io/filebasedsource_test.py
+++ b/sdks/python/apache_beam/io/filebasedsource_test.py
@@ -220,6 +220,26 @@ class TestFileBasedSource(unittest.TestCase):
     # environments with limited amount of resources.
     filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2
 
+  def test_validation_file_exists(self):
+    file_name, _ = write_data(10)
+    LineSource(file_name)
+
+  def test_validation_directory_non_empty(self):
+    temp_dir = tempfile.mkdtemp()
+    file_name, _ = write_data(10, directory=temp_dir)
+    LineSource(file_name)
+
+  def test_validation_failing(self):
+    no_files_found_error = 'No files found based on the file pattern*'
+    with self.assertRaisesRegexp(IOError, no_files_found_error):
+      LineSource('dummy_pattern')
+    with self.assertRaisesRegexp(IOError, no_files_found_error):
+      temp_dir = tempfile.mkdtemp()
+      LineSource(os.path.join(temp_dir, '*'))
+
+  def test_validation_file_missing_verification_disabled(self):
+    LineSource('dummy_pattern', validate=False)
+
   def test_fully_read_single_file(self):
     file_name, expected_data = write_data(10)
     assert len(expected_data) == 10
@@ -525,7 +545,7 @@ class TestSingleFileSource(unittest.TestCase):
     start_not_a_number_error = 'start_offset must be a number*'
     stop_not_a_number_error = 'stop_offset must be a number*'
     file_name = 'dummy_pattern'
-    fbs = LineSource(file_name)
+    fbs = LineSource(file_name, validate=False)
 
     with self.assertRaisesRegexp(TypeError, start_not_a_number_error):
       SingleFileSource(
@@ -545,7 +565,7 @@ class TestSingleFileSource(unittest.TestCase):
 
   def test_source_creation_display_data(self):
     file_name = 'dummy_pattern'
-    fbs = LineSource(file_name)
+    fbs = LineSource(file_name, validate=False)
     dd = DisplayData.create_from(fbs)
     expected_items = [
         DisplayDataItemMatcher('compression', 'auto'),
@@ -556,8 +576,7 @@ class TestSingleFileSource(unittest.TestCase):
   def test_source_creation_fails_if_start_lg_stop(self):
     start_larger_than_stop_error = (
         'start_offset must be smaller than stop_offset*')
-
-    fbs = LineSource('dummy_pattern')
+    fbs = LineSource('dummy_pattern', validate=False)
     SingleFileSource(
         fbs, file_name='dummy_file', start_offset=99, stop_offset=100)
     with self.assertRaisesRegexp(ValueError, start_larger_than_stop_error):
@@ -568,7 +587,7 @@ class TestSingleFileSource(unittest.TestCase):
         fbs, file_name='dummy_file', start_offset=100, stop_offset=100)
 
   def test_estimates_size(self):
-    fbs = LineSource('dummy_pattern')
+    fbs = LineSource('dummy_pattern', validate=False)
 
     # Should simply return stop_offset - start_offset
     source = SingleFileSource(
@@ -580,7 +599,7 @@ class TestSingleFileSource(unittest.TestCase):
     self.assertEquals(90, source.estimate_size())
 
   def test_read_range_at_beginning(self):
-    fbs = LineSource('dummy_pattern')
+    fbs = LineSource('dummy_pattern', validate=False)
 
     file_name, expected_data = write_data(10)
     assert len(expected_data) == 10
@@ -591,7 +610,7 @@ class TestSingleFileSource(unittest.TestCase):
     self.assertItemsEqual(expected_data[:4], read_data)
 
   def test_read_range_at_end(self):
-    fbs = LineSource('dummy_pattern')
+    fbs = LineSource('dummy_pattern', validate=False)
 
     file_name, expected_data = write_data(10)
     assert len(expected_data) == 10
@@ -602,7 +621,7 @@ class TestSingleFileSource(unittest.TestCase):
     self.assertItemsEqual(expected_data[-3:], read_data)
 
   def test_read_range_at_middle(self):
-    fbs = LineSource('dummy_pattern')
+    fbs = LineSource('dummy_pattern', validate=False)
 
     file_name, expected_data = write_data(10)
     assert len(expected_data) == 10
@@ -613,7 +632,7 @@ class TestSingleFileSource(unittest.TestCase):
     self.assertItemsEqual(expected_data[4:7], read_data)
 
   def test_produces_splits_desiredsize_large_than_size(self):
-    fbs = LineSource('dummy_pattern')
+    fbs = LineSource('dummy_pattern', validate=False)
 
     file_name, expected_data = write_data(10)
     assert len(expected_data) == 10
@@ -629,7 +648,7 @@ class TestSingleFileSource(unittest.TestCase):
     self.assertItemsEqual(expected_data, read_data)
 
   def test_produces_splits_desiredsize_smaller_than_size(self):
-    fbs = LineSource('dummy_pattern')
+    fbs = LineSource('dummy_pattern', validate=False)
 
     file_name, expected_data = write_data(10)
     assert len(expected_data) == 10
@@ -647,7 +666,7 @@ class TestSingleFileSource(unittest.TestCase):
     self.assertItemsEqual(expected_data, read_data)
 
   def test_produce_split_with_start_and_end_positions(self):
-    fbs = LineSource('dummy_pattern')
+    fbs = LineSource('dummy_pattern', validate=False)
 
     file_name, expected_data = write_data(10)
     assert len(expected_data) == 10

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/76ad2929/sdks/python/apache_beam/io/textio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py
index 01f6ef6..e031572 100644
--- a/sdks/python/apache_beam/io/textio.py
+++ b/sdks/python/apache_beam/io/textio.py
@@ -72,9 +72,10 @@ class _TextSource(filebasedsource.FileBasedSource):
 
   def __init__(self, file_pattern, min_bundle_size,
                compression_type, strip_trailing_newlines, coder,
-               buffer_size=DEFAULT_READ_BUFFER_SIZE):
+               buffer_size=DEFAULT_READ_BUFFER_SIZE, validate=True):
     super(_TextSource, self).__init__(file_pattern, min_bundle_size,
-                                      compression_type=compression_type)
+                                      compression_type=compression_type,
+                                      validate=validate)
     self._strip_trailing_newlines = strip_trailing_newlines
     self._compression_type = compression_type
@@ -206,7 +207,6 @@ class ReadFromText(PTransform):
 
   This implementation only supports reading text encoded using UTF-8 or ASCII.
   This does not support other encodings such as UTF-16 or UTF-32."""
-
   def __init__(
       self,
       file_pattern=None,
@@ -214,6 +214,7 @@ class ReadFromText(PTransform):
       compression_type=fileio.CompressionTypes.AUTO,
       strip_trailing_newlines=True,
       coder=coders.StrUtf8Coder(),
+      validate=True,
       **kwargs):
     """Initialize the ReadFromText transform.
 
@@ -230,15 +231,19 @@ class ReadFromText(PTransform):
       strip_trailing_newlines: Indicates whether this source should remove
                                the newline char in each line it reads before
                                decoding that line.
+      validate: flag to verify that the files exist during the pipeline
+                creation time.
      coder: Coder used to decode each line.
     """
     super(ReadFromText, self).__init__(**kwargs)
     self._args = (file_pattern, min_bundle_size, compression_type,
                   strip_trailing_newlines, coder)
+    self._validate = validate
 
   def apply(self, pvalue):
-    return pvalue.pipeline | Read(_TextSource(*self._args))
+    return pvalue.pipeline | Read(_TextSource(*self._args,
+                                              validate=self._validate))
 
 
 class WriteToText(PTransform):
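
End to end, the same flag is plumbed through ReadFromText; a minimal sketch
(illustrative only; the file pattern is made up):

    import apache_beam as beam
    from apache_beam.io.textio import ReadFromText

    p = beam.Pipeline()
    # The glob is checked when the transform is applied to the pipeline, so a
    # non-matching pattern fails during construction rather than at run time.
    lines = p | ReadFromText('gs://my-bucket/logs/*.txt')
    # Passing validate=False would skip the construction-time check.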