Repository: beam Updated Branches: refs/heads/master 5fe11a2fc -> 93ae666be
[BEAM-782] support runner names to be partial or case insensitive


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/75d5348d
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/75d5348d
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/75d5348d

Branch: refs/heads/master
Commit: 75d5348d4757280513db66fde06ef8ea6aecb4d5
Parents: 5fe11a2
Author: Sourabh Bajaj <sourabhba...@google.com>
Authored: Fri Feb 10 17:21:58 2017 -0800
Committer: Ahmet Altay <al...@google.com>
Committed: Sun Feb 12 20:49:38 2017 -0800

----------------------------------------------------------------------
 sdks/python/apache_beam/runners/runner.py      | 33 +++++++++++++++++----
 sdks/python/apache_beam/runners/runner_test.py | 12 ++++++++
 2 files changed, 39 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/75d5348d/sdks/python/apache_beam/runners/runner.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py
index cf75a2b..34b4dcb 100644
--- a/sdks/python/apache_beam/runners/runner.py
+++ b/sdks/python/apache_beam/runners/runner.py
@@ -26,9 +26,30 @@ import shutil
 import tempfile
 
 
+def _get_runner_map(runner_names, module_path):
+  """Create a map of runner name in lower case to full import path to the
+  runner class.
+  """
+  return {runner_name.lower(): module_path + runner_name
+          for runner_name in runner_names}
+
+
+_DIRECT_RUNNER_PATH = 'apache_beam.runners.direct.direct_runner.'
+_DATAFLOW_RUNNER_PATH = 'apache_beam.runners.dataflow_runner.'
+_TEST_RUNNER_PATH = 'apache_beam.runners.test.'
+
 _KNOWN_DIRECT_RUNNERS = ('DirectRunner', 'EagerRunner')
 _KNOWN_DATAFLOW_RUNNERS = ('DataflowRunner',)
 _KNOWN_TEST_RUNNERS = ('TestDataflowRunner',)
+
+_RUNNER_MAP = {}
+_RUNNER_MAP.update(_get_runner_map(_KNOWN_DIRECT_RUNNERS,
+                                   _DIRECT_RUNNER_PATH))
+_RUNNER_MAP.update(_get_runner_map(_KNOWN_DATAFLOW_RUNNERS,
+                                   _DATAFLOW_RUNNER_PATH))
+_RUNNER_MAP.update(_get_runner_map(_KNOWN_TEST_RUNNERS,
+                                   _TEST_RUNNER_PATH))
+
 _ALL_KNOWN_RUNNERS = (
     _KNOWN_DIRECT_RUNNERS + _KNOWN_DATAFLOW_RUNNERS + _KNOWN_TEST_RUNNERS)
 
@@ -47,12 +68,12 @@ def create_runner(runner_name):
     RuntimeError: if an invalid runner name is used.
   """
 
-  if runner_name in _KNOWN_DIRECT_RUNNERS:
-    runner_name = 'apache_beam.runners.direct.direct_runner.' + runner_name
-  elif runner_name in _KNOWN_DATAFLOW_RUNNERS:
-    runner_name = 'apache_beam.runners.dataflow_runner.' + runner_name
-  elif runner_name in _KNOWN_TEST_RUNNERS:
-    runner_name = 'apache_beam.runners.test.' + runner_name
+  # Get the qualified runner name by using the lower case runner name. If that
+  # fails try appending the name with 'runner' and check if it matches.
+  # If that also fails, use the given runner name as is.
+  runner_name = _RUNNER_MAP.get(
+      runner_name.lower(),
+      _RUNNER_MAP.get(runner_name.lower() + 'runner', runner_name))
 
   if '.' in runner_name:
     module, runner = runner_name.rsplit('.', 1)

http://git-wip-us.apache.org/repos/asf/beam/blob/75d5348d/sdks/python/apache_beam/runners/runner_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/runner_test.py b/sdks/python/apache_beam/runners/runner_test.py
index ef2b994..d7a50aa 100644
--- a/sdks/python/apache_beam/runners/runner_test.py
+++ b/sdks/python/apache_beam/runners/runner_test.py
@@ -67,6 +67,18 @@ class RunnerTest(unittest.TestCase):
                    TestDataflowRunner))
     self.assertRaises(ValueError, create_runner, 'xyz')
 
+  def test_create_runner_shorthand(self):
+    self.assertTrue(
+        isinstance(create_runner('DiReCtRuNnEr'), DirectRunner))
+    self.assertTrue(
+        isinstance(create_runner('directrunner'), DirectRunner))
+    self.assertTrue(
+        isinstance(create_runner('direct'), DirectRunner))
+    self.assertTrue(
+        isinstance(create_runner('DiReCt'), DirectRunner))
+    self.assertTrue(
+        isinstance(create_runner('Direct'), DirectRunner))
+
   def test_remote_runner_translation(self):
     remote_runner = DataflowRunner()
     p = Pipeline(remote_runner,