Repository: beam Updated Branches: refs/heads/master fcf3b5619 -> 6df661b0e
Add ValueProvider class for FileBasedSource I/O Transforms

Incorporate a BeamArgumentParser (argparse.ArgumentParser + ValueProviders).
Add StaticValueProvider and RuntimeValueProvider derived from ValueProvider.
Add serialization for ValueProvider objects.
Add testing for ValueProvider objects.
Modify FileBasedSource and FileSink to accept ValueProvider objects.

Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/1e2168a1
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/1e2168a1
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/1e2168a1

Branch: refs/heads/master
Commit: 1e2168a127fb3047fb15d231a001bbf951892e11
Parents: fcf3b56
Author: Maria Garcia Herrero <[email protected]>
Authored: Thu Mar 30 14:21:15 2017 -0700
Committer: Ahmet Altay <[email protected]>
Committed: Tue Apr 4 18:18:38 2017 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/examples/wordcount.py | 32 ++--
 .../apache_beam/internal/gcp/json_value.py | 6 +
 sdks/python/apache_beam/io/filebasedsource.py | 54 ++++--
 .../apache_beam/io/filebasedsource_test.py | 24 +++
 sdks/python/apache_beam/io/fileio.py | 56 +++++--
 sdks/python/apache_beam/io/fileio_test.py | 45 +++--
 .../runners/dataflow/internal/apiclient.py | 1 +
 .../apache_beam/runners/direct/direct_runner.py | 9 +
 sdks/python/apache_beam/transforms/display.py | 1 +
 .../apache_beam/transforms/display_test.py | 36 ++++
 .../apache_beam/utils/pipeline_options.py | 92 ++++++++++-
 .../apache_beam/utils/pipeline_options_test.py | 52 +++++-
 sdks/python/apache_beam/utils/value_provider.py | 110 +++++++++++++
 .../apache_beam/utils/value_provider_test.py | 165 +++++++++++++++++++
 14 files changed, 618 insertions(+), 65 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/beam/blob/1e2168a1/sdks/python/apache_beam/examples/wordcount.py
---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/examples/wordcount.py b/sdks/python/apache_beam/examples/wordcount.py index 50c0328..27b9dcb 100644 --- a/sdks/python/apache_beam/examples/wordcount.py +++ b/sdks/python/apache_beam/examples/wordcount.py @@ -19,7 +19,6 @@ from __future__ import absolute_import -import argparse import logging import re @@ -67,24 +66,29 @@ class WordExtractingDoFn(beam.DoFn): def run(argv=None): """Main entry point; defines and runs the wordcount pipeline.""" - parser = argparse.ArgumentParser() - parser.add_argument('--input', - dest='input', - default='gs://dataflow-samples/shakespeare/kinglear.txt', - help='Input file to process.') - parser.add_argument('--output', - dest='output', - required=True, - help='Output file to write results to.') - known_args, pipeline_args = parser.parse_known_args(argv) + class WordcountOptions(PipelineOptions): + @classmethod + def _add_argparse_args(cls, parser): + parser.add_value_provider_argument( + '--input', + dest='input', + default='gs://dataflow-samples/shakespeare/kinglear.txt', + help='Input file to process.') + parser.add_value_provider_argument( + '--output', + dest='output', + required=True, + help='Output file to write results to.') + pipeline_options = PipelineOptions(argv) + wordcount_options = pipeline_options.view_as(WordcountOptions) + # We use the save_main_session option because one or more DoFn's in this # workflow rely on global context (e.g., a module imported at module level). - pipeline_options = PipelineOptions(pipeline_args) pipeline_options.view_as(SetupOptions).save_main_session = True p = beam.Pipeline(options=pipeline_options) # Read the text file[pattern] into a PCollection. - lines = p | 'read' >> ReadFromText(known_args.input) + lines = p | 'read' >> ReadFromText(wordcount_options.input) # Count the occurrences of each word. 
counts = (lines @@ -99,7 +103,7 @@ def run(argv=None): # Write the output using a "Write" transform that has side effects. # pylint: disable=expression-not-assigned - output | 'write' >> WriteToText(known_args.output) + output | 'write' >> WriteToText(wordcount_options.output) # Actually run the pipeline (all operations above are deferred). result = p.run() http://git-wip-us.apache.org/repos/asf/beam/blob/1e2168a1/sdks/python/apache_beam/internal/gcp/json_value.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/internal/gcp/json_value.py b/sdks/python/apache_beam/internal/gcp/json_value.py index c8b5393..4099c1a 100644 --- a/sdks/python/apache_beam/internal/gcp/json_value.py +++ b/sdks/python/apache_beam/internal/gcp/json_value.py @@ -25,6 +25,8 @@ except ImportError: extra_types = None # pylint: enable=wrong-import-order, wrong-import-position +from apache_beam.utils.value_provider import ValueProvider + _MAXINT64 = (1 << 63) - 1 _MININT64 = - (1 << 63) @@ -104,6 +106,10 @@ def to_json_value(obj, with_type=False): raise TypeError('Can not encode {} as a 64-bit integer'.format(obj)) elif isinstance(obj, float): return extra_types.JsonValue(double_value=obj) + elif isinstance(obj, ValueProvider): + if obj.is_accessible(): + return to_json_value(obj.get()) + return extra_types.JsonValue(is_null=True) else: raise TypeError('Cannot convert %s to a JSON value.' 
% repr(obj)) http://git-wip-us.apache.org/repos/asf/beam/blob/1e2168a1/sdks/python/apache_beam/io/filebasedsource.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py index a3e0667..930d958 100644 --- a/sdks/python/apache_beam/io/filebasedsource.py +++ b/sdks/python/apache_beam/io/filebasedsource.py @@ -32,6 +32,9 @@ from apache_beam.io import range_trackers from apache_beam.io.filesystem import CompressionTypes from apache_beam.io.filesystems_util import get_filesystem from apache_beam.transforms.display import DisplayDataItem +from apache_beam.utils.value_provider import ValueProvider +from apache_beam.utils.value_provider import StaticValueProvider +from apache_beam.utils.value_provider import check_accessible MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 25 @@ -50,7 +53,8 @@ class FileBasedSource(iobase.BoundedSource): """Initializes ``FileBasedSource``. Args: - file_pattern: the file glob to read. + file_pattern: the file glob to read a string or a ValueProvider + (placeholder to inject a runtime value). min_bundle_size: minimum size of bundles that should be generated when performing initial splitting on this source. compression_type: compression type to use @@ -68,17 +72,25 @@ class FileBasedSource(iobase.BoundedSource): creation time. Raises: TypeError: when compression_type is not valid or if file_pattern is not a - string. + string or a ValueProvider. ValueError: when compression and splittable files are specified. IOError: when the file pattern specified yields an empty result. 
""" - if not isinstance(file_pattern, basestring): - raise TypeError( - '%s: file_pattern must be a string; got %r instead' % - (self.__class__.__name__, file_pattern)) + if (not (isinstance(file_pattern, basestring) + or isinstance(file_pattern, ValueProvider))): + raise TypeError('%s: file_pattern must be of type string' + ' or ValueProvider; got %r instead' + % (self.__class__.__name__, file_pattern)) + + if isinstance(file_pattern, basestring): + file_pattern = StaticValueProvider(str, file_pattern) self._pattern = file_pattern - self._file_system = get_filesystem(file_pattern) + if file_pattern.is_accessible(): + self._file_system = get_filesystem(file_pattern.get()) + else: + self._file_system = None + self._concat_source = None self._min_bundle_size = min_bundle_size if not CompressionTypes.is_valid_compression_type(compression_type): @@ -91,19 +103,24 @@ class FileBasedSource(iobase.BoundedSource): else: # We can't split compressed files efficiently so turn off splitting. self._splittable = False - if validate: + if validate and file_pattern.is_accessible(): self._validate() def display_data(self): - return {'file_pattern': DisplayDataItem(self._pattern, + return {'file_pattern': DisplayDataItem(str(self._pattern), label="File Pattern"), 'compression': DisplayDataItem(str(self._compression_type), label='Compression Type')} + @check_accessible(['_pattern']) def _get_concat_source(self): if self._concat_source is None: + pattern = self._pattern.get() + single_file_sources = [] - match_result = self._file_system.match([self._pattern])[0] + if self._file_system is None: + self._file_system = get_filesystem(pattern) + match_result = self._file_system.match([pattern])[0] files_metadata = match_result.metadata_list # We create a reference for FileBasedSource that will be serialized along @@ -142,14 +159,19 @@ class FileBasedSource(iobase.BoundedSource): file_name, 'application/octet-stream', compression_type=self._compression_type) + 
@check_accessible(['_pattern']) def _validate(self): """Validate if there are actual files in the specified glob pattern """ + pattern = self._pattern.get() + if self._file_system is None: + self._file_system = get_filesystem(pattern) + # Limit the responses as we only want to check if something exists - match_result = self._file_system.match([self._pattern], limits=[1])[0] + match_result = self._file_system.match([pattern], limits=[1])[0] if len(match_result.metadata_list) <= 0: raise IOError( - 'No files found based on the file pattern %s' % self._pattern) + 'No files found based on the file pattern %s' % pattern) def split( self, desired_bundle_size=None, start_position=None, stop_position=None): @@ -158,8 +180,12 @@ class FileBasedSource(iobase.BoundedSource): start_position=start_position, stop_position=stop_position) + @check_accessible(['_pattern']) def estimate_size(self): - match_result = self._file_system.match([self._pattern])[0] + pattern = self._pattern.get() + if self._file_system is None: + self._file_system = get_filesystem(pattern) + match_result = self._file_system.match([pattern])[0] return sum([f.size_in_bytes for f in match_result.metadata_list]) def read(self, range_tracker): @@ -184,7 +210,7 @@ class FileBasedSource(iobase.BoundedSource): defined by a given ``RangeTracker``. Returns: - a iterator that gives the records read from the given file. + an iterator that gives the records read from the given file. 
""" raise NotImplementedError http://git-wip-us.apache.org/repos/asf/beam/blob/1e2168a1/sdks/python/apache_beam/io/filebasedsource_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/filebasedsource_test.py b/sdks/python/apache_beam/io/filebasedsource_test.py index 7b7ec8a..c25ca5d 100644 --- a/sdks/python/apache_beam/io/filebasedsource_test.py +++ b/sdks/python/apache_beam/io/filebasedsource_test.py @@ -43,6 +43,8 @@ from apache_beam.transforms.display import DisplayData from apache_beam.transforms.display_test import DisplayDataItemMatcher from apache_beam.transforms.util import assert_that from apache_beam.transforms.util import equal_to +from apache_beam.utils.value_provider import StaticValueProvider +from apache_beam.utils.value_provider import RuntimeValueProvider class LineSource(FileBasedSource): @@ -221,6 +223,28 @@ class TestFileBasedSource(unittest.TestCase): # environments with limited amount of resources. filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2 + def test_string_or_value_provider_only(self): + str_file_pattern = tempfile.NamedTemporaryFile(delete=False).name + self.assertEqual(str_file_pattern, + FileBasedSource(str_file_pattern)._pattern.value) + + static_vp_file_pattern = StaticValueProvider(value_type=str, + value=str_file_pattern) + self.assertEqual(static_vp_file_pattern, + FileBasedSource(static_vp_file_pattern)._pattern) + + runtime_vp_file_pattern = RuntimeValueProvider( + option_name='arg', + value_type=str, + default_value=str_file_pattern, + options_id=1) + self.assertEqual(runtime_vp_file_pattern, + FileBasedSource(runtime_vp_file_pattern)._pattern) + + invalid_file_pattern = 123 + with self.assertRaises(TypeError): + FileBasedSource(invalid_file_pattern) + def test_validation_file_exists(self): file_name, _ = write_data(10) LineSource(file_name) http://git-wip-us.apache.org/repos/asf/beam/blob/1e2168a1/sdks/python/apache_beam/io/fileio.py 
---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/fileio.py b/sdks/python/apache_beam/io/fileio.py index f33942a..84949dc 100644 --- a/sdks/python/apache_beam/io/fileio.py +++ b/sdks/python/apache_beam/io/fileio.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # + """File-based sources and sinks.""" from __future__ import absolute_import @@ -30,6 +31,9 @@ from apache_beam.io.filesystem import CompressedFile as _CompressedFile from apache_beam.io.filesystem import CompressionTypes from apache_beam.io.filesystems_util import get_filesystem from apache_beam.transforms.display import DisplayDataItem +from apache_beam.utils.value_provider import ValueProvider +from apache_beam.utils.value_provider import StaticValueProvider +from apache_beam.utils.value_provider import check_accessible MAX_BATCH_OPERATION_SIZE = 100 DEFAULT_SHARD_NAME_TEMPLATE = '-SSSSS-of-NNNNN' @@ -149,25 +153,30 @@ class FileSink(iobase.Sink): compression_type=CompressionTypes.AUTO): """ Raises: - TypeError: if file path parameters are not a string or if compression_type - is not member of CompressionTypes. + TypeError: if file path parameters are not a string or ValueProvider, + or if compression_type is not member of CompressionTypes. ValueError: if shard_name_template is not of expected format. 
""" - if not isinstance(file_path_prefix, basestring): - raise TypeError('file_path_prefix must be a string; got %r instead' % - file_path_prefix) - if not isinstance(file_name_suffix, basestring): - raise TypeError('file_name_suffix must be a string; got %r instead' % - file_name_suffix) + if not (isinstance(file_path_prefix, basestring) + or isinstance(file_path_prefix, ValueProvider)): + raise TypeError('file_path_prefix must be a string or ValueProvider;' + 'got %r instead' % file_path_prefix) + if not (isinstance(file_name_suffix, basestring) + or isinstance(file_name_suffix, ValueProvider)): + raise TypeError('file_name_suffix must be a string or ValueProvider;' + 'got %r instead' % file_name_suffix) if not CompressionTypes.is_valid_compression_type(compression_type): raise TypeError('compression_type must be CompressionType object but ' 'was %s' % type(compression_type)) - if shard_name_template is None: shard_name_template = DEFAULT_SHARD_NAME_TEMPLATE elif shard_name_template is '': num_shards = 1 + if isinstance(file_path_prefix, basestring): + file_path_prefix = StaticValueProvider(str, file_path_prefix) + if isinstance(file_name_suffix, basestring): + file_name_suffix = StaticValueProvider(str, file_name_suffix) self.file_path_prefix = file_path_prefix self.file_name_suffix = file_name_suffix self.num_shards = num_shards @@ -175,7 +184,10 @@ class FileSink(iobase.Sink): self.shard_name_format = self._template_to_format(shard_name_template) self.compression_type = compression_type self.mime_type = mime_type - self._file_system = get_filesystem(file_path_prefix) + if file_path_prefix.is_accessible(): + self._file_system = get_filesystem(file_path_prefix.get()) + else: + self._file_system = None def display_data(self): return {'shards': @@ -189,12 +201,15 @@ class FileSink(iobase.Sink): self.file_name_suffix), label='File Pattern')} + @check_accessible(['file_path_prefix']) def open(self, temp_path): """Opens ``temp_path``, returning an opaque file handle 
object. The returned file handle is passed to ``write_[encoded_]record`` and ``close``. """ + if self._file_system is None: + self._file_system = get_filesystem(self.file_path_prefix.get()) return self._file_system.create(temp_path, self.mime_type, self.compression_type) @@ -221,22 +236,33 @@ class FileSink(iobase.Sink): if file_handle is not None: file_handle.close() + @check_accessible(['file_path_prefix', 'file_name_suffix']) def initialize_write(self): - tmp_dir = self.file_path_prefix + self.file_name_suffix + time.strftime( + file_path_prefix = self.file_path_prefix.get() + file_name_suffix = self.file_name_suffix.get() + tmp_dir = file_path_prefix + file_name_suffix + time.strftime( '-temp-%Y-%m-%d_%H-%M-%S') + if self._file_system is None: + self._file_system = get_filesystem(file_path_prefix) self._file_system.mkdirs(tmp_dir) return tmp_dir + @check_accessible(['file_path_prefix', 'file_name_suffix']) def open_writer(self, init_result, uid): # A proper suffix is needed for AUTO compression detection. # We also ensure there will be no collisions with uid and a # (possibly unsharded) file_path_prefix and a (possibly empty) # file_name_suffix. + file_path_prefix = self.file_path_prefix.get() + file_name_suffix = self.file_name_suffix.get() suffix = ( - '.' + os.path.basename(self.file_path_prefix) + self.file_name_suffix) + '.' 
+ os.path.basename(file_path_prefix) + file_name_suffix) return FileSinkWriter(self, os.path.join(init_result, uid) + suffix) + @check_accessible(['file_path_prefix', 'file_name_suffix']) def finalize_write(self, init_result, writer_results): + file_path_prefix = self.file_path_prefix.get() + file_name_suffix = self.file_name_suffix.get() writer_results = sorted(writer_results) num_shards = len(writer_results) min_threads = min(num_shards, FileSink._MAX_RENAME_THREADS) @@ -246,8 +272,8 @@ class FileSink(iobase.Sink): destination_files = [] for shard_num, shard in enumerate(writer_results): final_name = ''.join([ - self.file_path_prefix, self.shard_name_format % dict( - shard_num=shard_num, num_shards=num_shards), self.file_name_suffix + file_path_prefix, self.shard_name_format % dict( + shard_num=shard_num, num_shards=num_shards), file_name_suffix ]) source_files.append(shard) destination_files.append(final_name) @@ -270,6 +296,8 @@ class FileSink(iobase.Sink): """_rename_batch executes batch rename operations.""" source_files, destination_files = batch exceptions = [] + if self._file_system is None: + self._file_system = get_filesystem(file_path_prefix) try: self._file_system.rename(source_files, destination_files) return exceptions http://git-wip-us.apache.org/repos/asf/beam/blob/1e2168a1/sdks/python/apache_beam/io/fileio_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/fileio_test.py b/sdks/python/apache_beam/io/fileio_test.py index 6b7437d..504a2b9 100644 --- a/sdks/python/apache_beam/io/fileio_test.py +++ b/sdks/python/apache_beam/io/fileio_test.py @@ -35,6 +35,8 @@ from apache_beam.test_pipeline import TestPipeline from apache_beam.transforms.display import DisplayData from apache_beam.transforms.display_test import DisplayDataItemMatcher +from apache_beam.utils.value_provider import StaticValueProvider + # TODO: Refactor code so all io tests are using same library # 
TestCaseWithTempDirCleanup class. @@ -124,7 +126,7 @@ class TestFileSink(_TestCaseWithTempDirCleanUp): def test_file_sink_writing(self): temp_path = os.path.join(self._new_tempdir(), 'filesink') sink = MyFileSink( - temp_path, file_name_suffix='.foo', coder=coders.ToStringCoder()) + temp_path, file_name_suffix='.output', coder=coders.ToStringCoder()) # Manually invoke the generic Sink API. init_token = sink.initialize_write() @@ -145,8 +147,8 @@ class TestFileSink(_TestCaseWithTempDirCleanUp): res = list(sink.finalize_write(init_token, [res1, res2])) # Check the results. - shard1 = temp_path + '-00000-of-00002.foo' - shard2 = temp_path + '-00001-of-00002.foo' + shard1 = temp_path + '-00000-of-00002.output' + shard2 = temp_path + '-00001-of-00002.output' self.assertEqual(res, [shard1, shard2]) self.assertEqual(open(shard1).read(), '[start][a][b][end]') self.assertEqual(open(shard2).read(), '[start][x][y][z][end]') @@ -157,33 +159,48 @@ class TestFileSink(_TestCaseWithTempDirCleanUp): def test_file_sink_display_data(self): temp_path = os.path.join(self._new_tempdir(), 'display') sink = MyFileSink( - temp_path, file_name_suffix='.foo', coder=coders.ToStringCoder()) + temp_path, file_name_suffix='.output', coder=coders.ToStringCoder()) dd = DisplayData.create_from(sink) expected_items = [ DisplayDataItemMatcher( 'compression', 'auto'), DisplayDataItemMatcher( 'file_pattern', - '{}{}'.format(temp_path, - '-%(shard_num)05d-of-%(num_shards)05d.foo'))] - + '{}{}'.format( + temp_path, + '-%(shard_num)05d-of-%(num_shards)05d.output'))] hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) def test_empty_write(self): temp_path = tempfile.NamedTemporaryFile().name sink = MyFileSink( - temp_path, file_name_suffix='.foo', coder=coders.ToStringCoder()) + temp_path, file_name_suffix='.output', coder=coders.ToStringCoder() + ) + p = TestPipeline() + p | beam.Create([]) | beam.io.Write(sink) # pylint: disable=expression-not-assigned + p.run() + self.assertEqual( + 
open(temp_path + '-00000-of-00001.output').read(), '[start][end]') + + def test_static_value_provider_empty_write(self): + temp_path = StaticValueProvider(value_type=str, + value=tempfile.NamedTemporaryFile().name) + sink = MyFileSink( + temp_path, + file_name_suffix=StaticValueProvider(value_type=str, value='.output'), + coder=coders.ToStringCoder() + ) p = TestPipeline() p | beam.Create([]) | beam.io.Write(sink) # pylint: disable=expression-not-assigned p.run() self.assertEqual( - open(temp_path + '-00000-of-00001.foo').read(), '[start][end]') + open(temp_path.get() + '-00000-of-00001.output').read(), '[start][end]') def test_fixed_shard_write(self): temp_path = os.path.join(self._new_tempdir(), 'empty') sink = MyFileSink( temp_path, - file_name_suffix='.foo', + file_name_suffix='.output', num_shards=3, shard_name_template='_NN_SSS_', coder=coders.ToStringCoder()) @@ -193,7 +210,7 @@ class TestFileSink(_TestCaseWithTempDirCleanUp): p.run() concat = ''.join( - open(temp_path + '_03_%03d_.foo' % shard_num).read() + open(temp_path + '_03_%03d_.output' % shard_num).read() for shard_num in range(3)) self.assertTrue('][a][' in concat, concat) self.assertTrue('][b][' in concat, concat) @@ -201,7 +218,7 @@ class TestFileSink(_TestCaseWithTempDirCleanUp): def test_file_sink_multi_shards(self): temp_path = os.path.join(self._new_tempdir(), 'multishard') sink = MyFileSink( - temp_path, file_name_suffix='.foo', coder=coders.ToStringCoder()) + temp_path, file_name_suffix='.output', coder=coders.ToStringCoder()) # Manually invoke the generic Sink API. 
init_token = sink.initialize_write() @@ -224,7 +241,7 @@ class TestFileSink(_TestCaseWithTempDirCleanUp): res = sorted(res_second) for i in range(num_shards): - shard_name = '%s-%05d-of-%05d.foo' % (temp_path, i, num_shards) + shard_name = '%s-%05d-of-%05d.output' % (temp_path, i, num_shards) uuid = 'uuid-%05d' % i self.assertEqual(res[i], shard_name) self.assertEqual( @@ -236,7 +253,7 @@ class TestFileSink(_TestCaseWithTempDirCleanUp): def test_file_sink_io_error(self): temp_path = os.path.join(self._new_tempdir(), 'ioerror') sink = MyFileSink( - temp_path, file_name_suffix='.foo', coder=coders.ToStringCoder()) + temp_path, file_name_suffix='.output', coder=coders.ToStringCoder()) # Manually invoke the generic Sink API. init_token = sink.initialize_write() http://git-wip-us.apache.org/repos/asf/beam/blob/1e2168a1/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py index 6fa2f26..6d4e538 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py @@ -233,6 +233,7 @@ class Environment(object): options_dict = {k: v for k, v in sdk_pipeline_options.iteritems() if v is not None} + options_dict['_options_id'] = options._options_id self.proto.sdkPipelineOptions.additionalProperties.append( dataflow.Environment.SdkPipelineOptionsValue.AdditionalProperty( key='options', value=to_json_value(options_dict))) http://git-wip-us.apache.org/repos/asf/beam/blob/1e2168a1/sdks/python/apache_beam/runners/direct/direct_runner.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index efad2e0..1a5775f 100644 --- 
a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -33,6 +33,7 @@ from apache_beam.runners.runner import PipelineRunner from apache_beam.runners.runner import PipelineState from apache_beam.runners.runner import PValueCache from apache_beam.utils.pipeline_options import DirectOptions +from apache_beam.utils.value_provider import RuntimeValueProvider class DirectRunner(PipelineRunner): @@ -86,6 +87,9 @@ class DirectRunner(PipelineRunner): evaluation_context) # Start the executor. This is a non-blocking call, it will start the # execution in background threads and return. + + if pipeline.options: + RuntimeValueProvider.set_runtime_options(pipeline.options._options_id, {}) executor.start(self.visitor.root_transforms) result = DirectPipelineResult(executor, evaluation_context) @@ -95,6 +99,11 @@ class DirectRunner(PipelineRunner): result.wait_until_finish() self._cache.finalize() + # Unset runtime options after the pipeline finishes. + # TODO: Move this to a post finish hook and clean for all cases. 
+ if pipeline.options: + RuntimeValueProvider.unset_runtime_options(pipeline.options._options_id) + return result @property http://git-wip-us.apache.org/repos/asf/beam/blob/1e2168a1/sdks/python/apache_beam/transforms/display.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/display.py b/sdks/python/apache_beam/transforms/display.py index 2ced1af..f2ce0fc 100644 --- a/sdks/python/apache_beam/transforms/display.py +++ b/sdks/python/apache_beam/transforms/display.py @@ -40,6 +40,7 @@ from datetime import datetime, timedelta import inspect import json + __all__ = ['HasDisplayData', 'DisplayDataItem', 'DisplayData'] http://git-wip-us.apache.org/repos/asf/beam/blob/1e2168a1/sdks/python/apache_beam/transforms/display_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/display_test.py b/sdks/python/apache_beam/transforms/display_test.py index 5e106e5..7d1130b 100644 --- a/sdks/python/apache_beam/transforms/display_test.py +++ b/sdks/python/apache_beam/transforms/display_test.py @@ -114,6 +114,42 @@ class DisplayDataTest(unittest.TestCase): with self.assertRaises(ValueError): DisplayData.create_from_options(MyDisplayComponent()) + def test_value_provider_display_data(self): + class TestOptions(PipelineOptions): + @classmethod + def _add_argparse_args(cls, parser): + parser.add_value_provider_argument( + '--int_flag', + type=int, + help='int_flag description') + parser.add_value_provider_argument( + '--str_flag', + type=str, + default='hello', + help='str_flag description') + parser.add_value_provider_argument( + '--float_flag', + type=float, + help='float_flag description') + options = TestOptions(['--int_flag', '1']) + items = DisplayData.create_from_options(options).items + expected_items = [ + DisplayDataItemMatcher( + 'int_flag', + '1'), + DisplayDataItemMatcher( + 'str_flag', + 'RuntimeValueProvider(option: str_flag,' + ' 
type: str, default_value: \'hello\')' + ), + DisplayDataItemMatcher( + 'float_flag', + 'RuntimeValueProvider(option: float_flag,' + ' type: float, default_value: None)' + ) + ] + hc.assert_that(items, hc.contains_inanyorder(*expected_items)) + def test_create_list_display_data(self): flags = ['--extra_package', 'package1', '--extra_package', 'package2'] pipeline_options = PipelineOptions(flags=flags) http://git-wip-us.apache.org/repos/asf/beam/blob/1e2168a1/sdks/python/apache_beam/utils/pipeline_options.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/utils/pipeline_options.py b/sdks/python/apache_beam/utils/pipeline_options.py index c2a44ad..769beb3 100644 --- a/sdks/python/apache_beam/utils/pipeline_options.py +++ b/sdks/python/apache_beam/utils/pipeline_options.py @@ -18,8 +18,81 @@ """Pipeline options obtained from command line parsing.""" import argparse +import itertools from apache_beam.transforms.display import HasDisplayData +from apache_beam.utils.value_provider import StaticValueProvider +from apache_beam.utils.value_provider import RuntimeValueProvider +from apache_beam.utils.value_provider import ValueProvider + + +def _static_value_provider_of(value_type): + """"Helper function to plug a ValueProvider into argparse. + + Args: + value_type: the type of the value. Since the type param of argparse's + add_argument will always be ValueProvider, we need to + preserve the type of the actual value. + Returns: + A partially constructed StaticValueProvider in the form of a function. + + """ + def _f(value): + _f.func_name = value_type.__name__ + return StaticValueProvider(value_type, value) + return _f + + +class BeamArgumentParser(argparse.ArgumentParser): + """An ArgumentParser that supports ValueProvider options. 
+ + Example Usage:: + + class TemplateUserOptions(PipelineOptions): + @classmethod + + def _add_argparse_args(cls, parser): + parser.add_value_provider_argument('--vp-arg1', default='start') + parser.add_value_provider_argument('--vp-arg2') + parser.add_argument('--non-vp-arg') + + """ + def __init__(self, options_id, *args, **kwargs): + self._options_id = options_id + super(BeamArgumentParser, self).__init__(*args, **kwargs) + + def add_value_provider_argument(self, *args, **kwargs): + """ValueProvider arguments can be either of type keyword or positional. + At runtime, even positional arguments will need to be supplied in the + key/value form. + """ + # Extract the option name from positional argument ['pos_arg'] + assert args != () and len(args[0]) >= 1 + if args[0][0] != '-': + option_name = args[0] + if kwargs.get('nargs') is None: # make them optionally templated + kwargs['nargs'] = '?' + else: + # or keyword arguments like [--kw_arg, -k, -w] or [--kw-arg] + option_name = [i.replace('--', '') for i in args if i[:2] == '--'][0] + + # reassign the type to make room for using + # StaticValueProvider as the type for add_argument + value_type = kwargs.get('type') or str + kwargs['type'] = _static_value_provider_of(value_type) + + # reassign default to default_value to make room for using + # RuntimeValueProvider as the default for add_argument + default_value = kwargs.get('default') + kwargs['default'] = RuntimeValueProvider( + option_name=option_name, + value_type=value_type, + default_value=default_value, + options_id=self._options_id + ) + + # have add_argument do most of the work + self.add_argument(*args, **kwargs) class PipelineOptions(HasDisplayData): @@ -49,8 +122,9 @@ class PipelineOptions(HasDisplayData): By default the options classes will use command line arguments to initialize the options. 
""" + _options_id_generator = itertools.count(1) - def __init__(self, flags=None, **kwargs): + def __init__(self, flags=None, options_id=None, **kwargs): """Initialize an options class. The initializer will traverse all subclasses, add all their argparse @@ -67,7 +141,10 @@ class PipelineOptions(HasDisplayData): """ self._flags = flags self._all_options = kwargs - parser = argparse.ArgumentParser() + self._options_id = ( + options_id or PipelineOptions._options_id_generator.next()) + parser = BeamArgumentParser(self._options_id) + for cls in type(self).mro(): if cls == PipelineOptions: break @@ -119,13 +196,12 @@ class PipelineOptions(HasDisplayData): # TODO(BEAM-1319): PipelineOption sub-classes in the main session might be # repeated. Pick last unique instance of each subclass to avoid conflicts. - parser = argparse.ArgumentParser() subset = {} + parser = BeamArgumentParser(self._options_id) for cls in PipelineOptions.__subclasses__(): subset[str(cls)] = cls for cls in subset.values(): cls._add_argparse_args(parser) # pylint: disable=protected-access - known_args, _ = parser.parse_known_args(self._flags) result = vars(known_args) @@ -133,7 +209,9 @@ class PipelineOptions(HasDisplayData): for k in result.keys(): if k in self._all_options: result[k] = self._all_options[k] - if drop_default and parser.get_default(k) == result[k]: + if (drop_default and + parser.get_default(k) == result[k] and + not isinstance(parser.get_default(k), ValueProvider)): del result[k] return result @@ -142,7 +220,7 @@ class PipelineOptions(HasDisplayData): return self.get_all_options(True) def view_as(self, cls): - view = cls(self._flags) + view = cls(self._flags, options_id=self._options_id) view._all_options = self._all_options return view @@ -166,7 +244,7 @@ class PipelineOptions(HasDisplayData): (type(self).__name__, name)) def __setattr__(self, name, value): - if name in ('_flags', '_all_options', '_visible_options'): + if name in ('_flags', '_all_options', '_visible_options', 
'_options_id'): super(PipelineOptions, self).__setattr__(name, value) elif name in self._visible_option_list(): self._all_options[name] = value http://git-wip-us.apache.org/repos/asf/beam/blob/1e2168a1/sdks/python/apache_beam/utils/pipeline_options_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/utils/pipeline_options_test.py b/sdks/python/apache_beam/utils/pipeline_options_test.py index 507a827..633d7da 100644 --- a/sdks/python/apache_beam/utils/pipeline_options_test.py +++ b/sdks/python/apache_beam/utils/pipeline_options_test.py @@ -24,9 +24,13 @@ import hamcrest as hc from apache_beam.transforms.display import DisplayData from apache_beam.transforms.display_test import DisplayDataItemMatcher from apache_beam.utils.pipeline_options import PipelineOptions +from apache_beam.utils.value_provider import StaticValueProvider +from apache_beam.utils.value_provider import RuntimeValueProvider class PipelineOptionsTest(unittest.TestCase): + def setUp(self): + RuntimeValueProvider.runtime_options_map = {} TEST_CASES = [ {'flags': ['--num_workers', '5'], @@ -131,7 +135,7 @@ class PipelineOptionsTest(unittest.TestCase): options.view_as(PipelineOptionsTest.MockOptions).mock_flag = True self.assertEqual(options.get_all_options()['num_workers'], 5) - self.assertEqual(options.get_all_options()['mock_flag'], True) + self.assertTrue(options.get_all_options()['mock_flag']) def test_experiments(self): options = PipelineOptions(['--experiment', 'abc', '--experiment', 'def']) @@ -185,7 +189,51 @@ class PipelineOptionsTest(unittest.TestCase): parser.add_argument('--redefined_flag', action='store_true') options = PipelineOptions(['--redefined_flag']) - self.assertEqual(options.get_all_options()['redefined_flag'], True) + self.assertTrue(options.get_all_options()['redefined_flag']) + + def test_value_provider_options(self): + class UserOptions(PipelineOptions): + @classmethod + def _add_argparse_args(cls, parser): + 
parser.add_value_provider_argument( + '--vp_arg', + help='This flag is a value provider') + + parser.add_value_provider_argument( + '--vp_arg2', + default=1, + type=int) + + parser.add_argument( + '--non_vp_arg', + default=1, + type=int + ) + + # Provide values: if not provided, the option becomes of the type runtime vp + options = UserOptions(['--vp_arg', 'hello']) + self.assertIsInstance(options.vp_arg, StaticValueProvider) + self.assertIsInstance(options.vp_arg2, RuntimeValueProvider) + self.assertIsInstance(options.non_vp_arg, int) + + # Values can be overwritten + options = UserOptions(vp_arg=5, + vp_arg2=StaticValueProvider(value_type=str, + value='bye'), + non_vp_arg=RuntimeValueProvider( + option_name='foo', + value_type=int, + default_value=10, + options_id=10)) + self.assertEqual(options.vp_arg, 5) + self.assertTrue(options.vp_arg2.is_accessible(), + '%s is not accessible' % options.vp_arg2) + self.assertEqual(options.vp_arg2.get(), 'bye') + self.assertFalse(options.non_vp_arg.is_accessible()) + + with self.assertRaises(RuntimeError): + options.non_vp_arg.get() + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) http://git-wip-us.apache.org/repos/asf/beam/blob/1e2168a1/sdks/python/apache_beam/utils/value_provider.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/utils/value_provider.py b/sdks/python/apache_beam/utils/value_provider.py new file mode 100644 index 0000000..a72fc4c --- /dev/null +++ b/sdks/python/apache_beam/utils/value_provider.py @@ -0,0 +1,110 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A ValueProvider class to implement templates with both statically
and dynamically provided values.
"""

from functools import wraps


class ValueProvider(object):
  """Interface shared by all value providers.

  Both methods raise NotImplementedError here and are overridden by
  StaticValueProvider and RuntimeValueProvider.
  """

  def is_accessible(self):
    """Report whether the value can currently be read with get()."""
    raise NotImplementedError(
        'ValueProvider.is_accessible implemented in derived classes'
    )

  def get(self):
    """Return the provided value."""
    raise NotImplementedError(
        'ValueProvider.get implemented in derived classes'
    )


class StaticValueProvider(ValueProvider):
  """A provider whose value is fixed when the provider is created."""

  def __init__(self, value_type, value):
    """Store value converted with value_type.

    Args:
      value_type: callable applied to value (for example int or str).
      value: the raw value; whatever value_type(value) raises propagates.
    """
    self.value_type = value_type
    self.value = value_type(value)

  def is_accessible(self):
    """Always True: the value was supplied at construction."""
    return True

  def get(self):
    """Return the converted value."""
    return self.value

  def __str__(self):
    return str(self.value)


class RuntimeValueProvider(ValueProvider):
  """A provider whose value is looked up in options registered later.

  Options are registered per options_id with set_runtime_options and
  removed with unset_runtime_options.
  """

  # Maps options_id -> the options mapping registered for that id. The
  # mapping is stored on the class, so it is shared by every instance.
  runtime_options_map = {}

  def __init__(self, option_name, value_type, default_value, options_id):
    """Describe one option to be resolved from registered options.

    Args:
      option_name: key looked up in the registered options.
      value_type: callable applied to the value found under option_name.
      default_value: returned by get() when the option has no value.
      options_id: key into runtime_options_map; must not be None.
    """
    assert options_id is not None
    self.option_name = option_name
    self.default_value = default_value
    self.value_type = value_type
    self.options_id = options_id

  def is_accessible(self):
    """Return True once options are registered for this options_id."""
    return RuntimeValueProvider.runtime_options_map.get(
        self.options_id) is not None

  def get(self):
    """Return the registered value for option_name, or the default.

    Returns:
      value_type(candidate) when the registered options hold a value other
      than None under option_name; default_value otherwise.

    Raises:
      RuntimeError: if no options are registered for this options_id.
    """
    runtime_options = (
        RuntimeValueProvider.runtime_options_map.get(self.options_id))
    if runtime_options is None:
      raise RuntimeError('%s.get() not called from a runtime context' % self)

    candidate = runtime_options.get(self.option_name)
    # Only None means "not supplied". A truthiness test would discard
    # explicitly supplied falsy values such as 0, 0.0, False or ''.
    if candidate is not None:
      return self.value_type(candidate)
    return self.default_value

  @classmethod
  def set_runtime_options(cls, options_id, pipeline_options):
    """Register pipeline_options under options_id.

    Asserts that nothing is registered under options_id yet.
    """
    assert options_id not in RuntimeValueProvider.runtime_options_map
    RuntimeValueProvider.runtime_options_map[options_id] = pipeline_options

  @classmethod
  def unset_runtime_options(cls, options_id):
    """Remove the options registered under options_id.

    Asserts that options_id is currently registered.
    """
    assert options_id in RuntimeValueProvider.runtime_options_map
    del RuntimeValueProvider.runtime_options_map[options_id]

  def __str__(self):
    return '%s(option: %s, type: %s, default_value: %s)' % (
        self.__class__.__name__,
        self.option_name,
        self.value_type.__name__,
        repr(self.default_value)
    )


def check_accessible(value_provider_list):
  """Check accessibility of a list of ValueProvider objects.

  Args:
    value_provider_list: list of attribute names; each is read from the
      decorated method's self and must be accessible when the method runs.

  Returns:
    A decorator for methods. The wrapped method raises RuntimeError if any
    named provider is not accessible, and otherwise calls the original.
  """
  assert isinstance(value_provider_list, list)

  def _check_accessible(fnc):
    @wraps(fnc)
    def _f(self, *args, **kwargs):
      # Resolve every attribute before checking any of them, so a missing
      # attribute surfaces as AttributeError ahead of accessibility errors.
      for obj in [getattr(self, vp) for vp in value_provider_list]:
        if not obj.is_accessible():
          raise RuntimeError('%s not accessible' % obj)
      return fnc(self, *args, **kwargs)
    return _f
  return _check_accessible


# http://git-wip-us.apache.org/repos/asf/beam/blob/1e2168a1/sdks/python/apache_beam/utils/value_provider_test.py
# ----------------------------------------------------------------------
# diff --git a/sdks/python/apache_beam/utils/value_provider_test.py b/sdks/python/apache_beam/utils/value_provider_test.py
# new file mode 100644
# index 0000000..83cb5e9
# --- /dev/null
# +++ b/sdks/python/apache_beam/utils/value_provider_test.py
# @@ -0,0 +1,165 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the ValueProvider class."""

import unittest

from apache_beam.utils.pipeline_options import PipelineOptions
from apache_beam.utils.value_provider import RuntimeValueProvider
from apache_beam.utils.value_provider import StaticValueProvider


class ValueProviderTests(unittest.TestCase):
  """Exercises value-provider options declared on PipelineOptions subclasses."""

  def setUp(self):
    # Runtime options are stored on the RuntimeValueProvider class itself,
    # so start each test from an empty registry; otherwise options
    # registered by one test stay visible to the tests that run after it.
    RuntimeValueProvider.runtime_options_map = {}

  def test_static_value_provider_keyword_argument(self):
    class UserDefinedOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            '--vp_arg',
            help='This keyword argument is a value provider',
            default='some value')
    options = UserDefinedOptions(['--vp_arg', 'abc'])
    self.assertIsInstance(options.vp_arg, StaticValueProvider)
    self.assertTrue(options.vp_arg.is_accessible())
    self.assertEqual(options.vp_arg.get(), 'abc')

  def test_runtime_value_provider_keyword_argument(self):
    class UserDefinedOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            '--vp_arg',
            help='This keyword argument is a value provider')
    options = UserDefinedOptions()
    self.assertIsInstance(options.vp_arg, RuntimeValueProvider)
    self.assertFalse(options.vp_arg.is_accessible())
    with self.assertRaises(RuntimeError):
      options.vp_arg.get()

  def test_static_value_provider_positional_argument(self):
    class UserDefinedOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            'vp_pos_arg',
            help='This positional argument is a value provider',
            default='some value')
    options = UserDefinedOptions(['abc'])
    self.assertIsInstance(options.vp_pos_arg, StaticValueProvider)
    self.assertTrue(options.vp_pos_arg.is_accessible())
    self.assertEqual(options.vp_pos_arg.get(), 'abc')

  def test_runtime_value_provider_positional_argument(self):
    class UserDefinedOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            'vp_pos_arg',
            help='This positional argument is a value provider')
    options = UserDefinedOptions([])
    self.assertIsInstance(options.vp_pos_arg, RuntimeValueProvider)
    self.assertFalse(options.vp_pos_arg.is_accessible())
    with self.assertRaises(RuntimeError):
      options.vp_pos_arg.get()

  def test_static_value_provider_type_cast(self):
    class UserDefinedOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            '--vp_arg',
            type=int,
            help='This flag is a value provider')

    options = UserDefinedOptions(['--vp_arg', '123'])
    self.assertIsInstance(options.vp_arg, StaticValueProvider)
    self.assertTrue(options.vp_arg.is_accessible())
    self.assertEqual(options.vp_arg.get(), 123)

  def test_set_runtime_option(self):
    # define ValueProvider options, with and without default values
    class UserDefinedOptions1(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            '--vp_arg',
            help='This keyword argument is a value provider')  # set at runtime

        parser.add_value_provider_argument(  # not set, had default int
            '-v', '--vp_arg2',  # with short form
            default=123,
            type=int)

        parser.add_value_provider_argument(  # not set, had default str
            '--vp-arg3',  # with dash in name
            default='123',
            type=str)

        parser.add_value_provider_argument(  # not set and no default
            '--vp_arg4',
            type=float)

        parser.add_value_provider_argument(  # positional argument set
            'vp_pos_arg',  # default & runtime ignored
            help='This positional argument is a value provider',
            type=float,
            default=5.4)

    # provide values at graph-construction time
    # (options not provided here become of the type RuntimeValueProvider)
    options = UserDefinedOptions1(['1.2'])
    self.assertFalse(options.vp_arg.is_accessible())
    self.assertFalse(options.vp_arg2.is_accessible())
    self.assertFalse(options.vp_arg3.is_accessible())
    self.assertFalse(options.vp_arg4.is_accessible())
    self.assertTrue(options.vp_pos_arg.is_accessible())

    # provide values at job-execution time
    # (options not provided here will use their default, if they have one)
    RuntimeValueProvider.set_runtime_options(
        options._options_id, {'vp_arg': 'abc', 'vp_pos_arg': '3.2'})
    self.assertTrue(options.vp_arg.is_accessible())
    self.assertEqual(options.vp_arg.get(), 'abc')
    self.assertTrue(options.vp_arg2.is_accessible())
    self.assertEqual(options.vp_arg2.get(), 123)
    self.assertTrue(options.vp_arg3.is_accessible())
    self.assertEqual(options.vp_arg3.get(), '123')
    self.assertTrue(options.vp_arg4.is_accessible())
    self.assertIsNone(options.vp_arg4.get())
    self.assertTrue(options.vp_pos_arg.is_accessible())
    # the value given at construction wins over the runtime '3.2'
    self.assertEqual(options.vp_pos_arg.get(), 1.2)

  def test_options_id(self):
    class Opt1(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument('--arg1')

    class Opt2(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument('--arg2')

    opt1 = Opt1()
    opt2 = Opt2()
    self.assertFalse(opt1.arg1.is_accessible())
    self.assertFalse(opt2.arg2.is_accessible())
    # registering options for opt1's id must not make opt2 accessible
    RuntimeValueProvider.set_runtime_options(
        opt1.arg1.options_id, {'arg1': 'val1'})
    self.assertTrue(opt1.arg1.is_accessible())
    self.assertFalse(opt2.arg2.is_accessible())


if __name__ == '__main__':
  unittest.main()
