Repository: beam
Updated Branches:
  refs/heads/master 5fe11a2fc -> 93ae666be


[BEAM-782] support runner names to be partial or case insensitive


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/75d5348d
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/75d5348d
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/75d5348d

Branch: refs/heads/master
Commit: 75d5348d4757280513db66fde06ef8ea6aecb4d5
Parents: 5fe11a2
Author: Sourabh Bajaj <sourabhba...@google.com>
Authored: Fri Feb 10 17:21:58 2017 -0800
Committer: Ahmet Altay <al...@google.com>
Committed: Sun Feb 12 20:49:38 2017 -0800

----------------------------------------------------------------------
 sdks/python/apache_beam/runners/runner.py      | 33 +++++++++++++++++----
 sdks/python/apache_beam/runners/runner_test.py | 12 ++++++++
 2 files changed, 39 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/75d5348d/sdks/python/apache_beam/runners/runner.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py
index cf75a2b..34b4dcb 100644
--- a/sdks/python/apache_beam/runners/runner.py
+++ b/sdks/python/apache_beam/runners/runner.py
@@ -26,9 +26,30 @@ import shutil
 import tempfile
 
 
+def _get_runner_map(runner_names, module_path):
+  """Create a map of runner name in lower case to full import path to the
+  runner class.
+  """
+  return {runner_name.lower(): module_path + runner_name
+          for runner_name in runner_names}
+
+
+_DIRECT_RUNNER_PATH = 'apache_beam.runners.direct.direct_runner.'
+_DATAFLOW_RUNNER_PATH = 'apache_beam.runners.dataflow_runner.'
+_TEST_RUNNER_PATH = 'apache_beam.runners.test.'
+
 _KNOWN_DIRECT_RUNNERS = ('DirectRunner', 'EagerRunner')
 _KNOWN_DATAFLOW_RUNNERS = ('DataflowRunner',)
 _KNOWN_TEST_RUNNERS = ('TestDataflowRunner',)
+
+_RUNNER_MAP = {}
+_RUNNER_MAP.update(_get_runner_map(_KNOWN_DIRECT_RUNNERS,
+                                   _DIRECT_RUNNER_PATH))
+_RUNNER_MAP.update(_get_runner_map(_KNOWN_DATAFLOW_RUNNERS,
+                                   _DATAFLOW_RUNNER_PATH))
+_RUNNER_MAP.update(_get_runner_map(_KNOWN_TEST_RUNNERS,
+                                   _TEST_RUNNER_PATH))
+
 _ALL_KNOWN_RUNNERS = (
     _KNOWN_DIRECT_RUNNERS + _KNOWN_DATAFLOW_RUNNERS + _KNOWN_TEST_RUNNERS)
 
@@ -47,12 +68,12 @@ def create_runner(runner_name):
     RuntimeError: if an invalid runner name is used.
   """
 
-  if runner_name in _KNOWN_DIRECT_RUNNERS:
-    runner_name = 'apache_beam.runners.direct.direct_runner.' + runner_name
-  elif runner_name in _KNOWN_DATAFLOW_RUNNERS:
-    runner_name = 'apache_beam.runners.dataflow_runner.' + runner_name
-  elif runner_name in _KNOWN_TEST_RUNNERS:
-    runner_name = 'apache_beam.runners.test.' + runner_name
+  # Get the qualified runner name by using the lower case runner name. If that
+  # fails try appending the name with 'runner' and check if it matches.
+  # If that also fails, use the given runner name as is.
+  runner_name = _RUNNER_MAP.get(
+      runner_name.lower(),
+      _RUNNER_MAP.get(runner_name.lower() + 'runner', runner_name))
 
   if '.' in runner_name:
     module, runner = runner_name.rsplit('.', 1)

http://git-wip-us.apache.org/repos/asf/beam/blob/75d5348d/sdks/python/apache_beam/runners/runner_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/runner_test.py b/sdks/python/apache_beam/runners/runner_test.py
index ef2b994..d7a50aa 100644
--- a/sdks/python/apache_beam/runners/runner_test.py
+++ b/sdks/python/apache_beam/runners/runner_test.py
@@ -67,6 +67,18 @@ class RunnerTest(unittest.TestCase):
                    TestDataflowRunner))
     self.assertRaises(ValueError, create_runner, 'xyz')
 
+  def test_create_runner_shorthand(self):
+    self.assertTrue(
+        isinstance(create_runner('DiReCtRuNnEr'), DirectRunner))
+    self.assertTrue(
+        isinstance(create_runner('directrunner'), DirectRunner))
+    self.assertTrue(
+        isinstance(create_runner('direct'), DirectRunner))
+    self.assertTrue(
+        isinstance(create_runner('DiReCt'), DirectRunner))
+    self.assertTrue(
+        isinstance(create_runner('Direct'), DirectRunner))
+
   def test_remote_runner_translation(self):
     remote_runner = DataflowRunner()
     p = Pipeline(remote_runner,

Reply via email to