Repository: incubator-beam Updated Branches: refs/heads/python-sdk b00f915ee -> 561090580
Add >> operator for labeling PTransforms. Also propagates label passed into PValue.apply(...) method. Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/a740f827 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/a740f827 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/a740f827 Branch: refs/heads/python-sdk Commit: a740f827f1b868e391e5c865fe9ef1f005ed3cef Parents: b00f915 Author: Robert Bradshaw <[email protected]> Authored: Mon Jul 18 12:46:54 2016 -0700 Committer: Robert Bradshaw <[email protected]> Committed: Mon Jul 18 17:55:03 2016 -0700 ---------------------------------------------------------------------- sdks/python/apache_beam/dataflow_test.py | 2 -- sdks/python/apache_beam/pipeline.py | 14 +++++---- sdks/python/apache_beam/pipeline_test.py | 10 +++---- sdks/python/apache_beam/pvalue.py | 6 ++-- .../python/apache_beam/transforms/ptransform.py | 31 ++++++++------------ 5 files changed, 29 insertions(+), 34 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a740f827/sdks/python/apache_beam/dataflow_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/dataflow_test.py b/sdks/python/apache_beam/dataflow_test.py index 9bbb5ff..476f8b2 100644 --- a/sdks/python/apache_beam/dataflow_test.py +++ b/sdks/python/apache_beam/dataflow_test.py @@ -50,8 +50,6 @@ class DataflowTest(unittest.TestCase): SAMPLE_DATA = ['aa bb cc aa bb aa \n'] * 10 SAMPLE_RESULT = [('cc', 10), ('bb', 20), ('aa', 30)] - # TODO(silviuc): Figure out a nice way to specify labels for stages so that - # internal steps get prepended with surorunding stage names. @beam.ptransform_fn def Count(pcoll): # pylint: disable=invalid-name, no-self-argument """A Count transform: v, ... => (v, n), ...""" http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a740f827/sdks/python/apache_beam/pipeline.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index bc1feb2..30ad315 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -53,7 +53,6 @@ from apache_beam import typehints from apache_beam.internal import pickler from apache_beam.runners import create_runner from apache_beam.runners import PipelineRunner -from apache_beam.transforms import format_full_label from apache_beam.transforms import ptransform from apache_beam.typehints import TypeCheckError from apache_beam.utils.options import PipelineOptions @@ -182,7 +181,7 @@ class Pipeline(object): visited = set() self._root_transform().visit(visitor, self, visited) - def apply(self, transform, pvalueish=None): + def apply(self, transform, pvalueish=None, label=None): """Applies a custom transform using the pvalueish specified. Args: @@ -195,16 +194,21 @@ class Pipeline(object): RuntimeError: if the transform object was already applied to this pipeline and needs to be cloned in order to apply again. """ + if isinstance(transform, ptransform._NamedPTransform): + return self.apply(transform.transform, pvalueish, + label or transform.label) + if not isinstance(transform, ptransform.PTransform): raise TypeError("Expected a PTransform object, got %s" % transform) - full_label = format_full_label(self._current_transform(), transform) + full_label = '/'.join([self._current_transform().full_label, + label or transform.label]).lstrip('/') if full_label in self.applied_labels: raise RuntimeError( 'Transform "%s" does not have a stable unique label. ' 'This will prevent updating of pipelines. ' - 'To clone a transform with a new label use: ' - 'transform.clone("NEW LABEL").' + 'To apply a transform with a specified label write ' + 'pvalue | "label" >> transform' % full_label) self.applied_labels.add(full_label) http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a740f827/sdks/python/apache_beam/pipeline_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py index 04cd2ee..c1db5cb 100644 --- a/sdks/python/apache_beam/pipeline_test.py +++ b/sdks/python/apache_beam/pipeline_test.py @@ -157,16 +157,16 @@ class PipelineTest(unittest.TestCase): cm.exception.message, 'Transform "CustomTransform" does not have a stable unique label. ' 'This will prevent updating of pipelines. ' - 'To clone a transform with a new label use: ' - 'transform.clone("NEW LABEL").') + 'To apply a transform with a specified label write ' + 'pvalue | "label" >> transform') def test_reuse_cloned_custom_transform_instance(self): pipeline = Pipeline(self.runner_name) - pcoll1 = pipeline | Create('pcoll1', [1, 2, 3]) - pcoll2 = pipeline | Create('pcoll2', [4, 5, 6]) + pcoll1 = pipeline | 'pc1' >> Create([1, 2, 3]) + pcoll2 = pipeline | 'pc2' >> Create([4, 5, 6]) transform = PipelineTest.CustomTransform() result1 = pcoll1 | transform - result2 = pcoll2 | transform.clone('new label') + result2 = pcoll2 | 'new_label' >> transform assert_that(result1, equal_to([2, 3, 4]), label='r1') assert_that(result2, equal_to([5, 6, 7]), label='r2') pipeline.run() http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a740f827/sdks/python/apache_beam/pvalue.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py index 6fc3041..063d0b5 100644 --- a/sdks/python/apache_beam/pvalue.py +++ b/sdks/python/apache_beam/pvalue.py @@ -79,10 +79,8 @@ class PValue(object): optional first label and a transform/callable object. It will call the pipeline.apply() method with this modified argument list. """ - if isinstance(args[0], str): - # TODO(robertwb): Make sure labels are properly passed during - # ptransform construction and drop this argument. - args = args[1:] + if isinstance(args[0], basestring): + kwargs['label'], args = args[0], args[1:] arglist = list(args) arglist.insert(1, self) return self.pipeline.apply(*arglist, **kwargs) http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a740f827/sdks/python/apache_beam/transforms/ptransform.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index 1457bec..bde05b5 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -369,6 +369,9 @@ class PTransform(WithTypeHints): # TODO(robertwb): Assert all input WindowFns compatible. return inputs[0].windowing + def __rrshift__(self, label): + return _NamedPTransform(self, label) + def __or__(self, right): """Used to compose PTransforms, e.g., ptransform1 | ptransform2.""" if isinstance(right, PTransform): @@ -676,24 +679,6 @@ def ptransform_fn(fn): return CallablePTransform(fn) -def format_full_label(applied_transform, pending_transform): - """Returns a fully formatted cumulative PTransform label. - - Args: - applied_transform: An instance of an AppliedPTransform that has been fully - applied prior to 'pending_transform'. - pending_transform: An instance of PTransform that has yet to be applied to - the Pipeline. - - Returns: - A fully formatted PTransform label. Example: '/foo/bar/baz'. - """ - label = '/'.join([applied_transform.full_label, pending_transform.label]) - # Remove leading backslash because the monitoring UI expects names that do not - # start with such a character. - return label if not label.startswith('/') else label[1:] - - def label_from_callable(fn): if hasattr(fn, 'default_label'): return fn.default_label() @@ -706,3 +691,13 @@ def label_from_callable(fn): return fn.__name__ else: return str(fn) + + +class _NamedPTransform(PTransform): + + def __init__(self, transform, label): + super(_NamedPTransform, self).__init__(label) + self.transform = transform + + def apply(self, pvalue): + raise RuntimeError("Should never be applied directly.")
