Repository: beam Updated Branches: refs/heads/master 84682109b -> 91c7d3d1f
Cleanup and fix ptransform_fn decorator. Previously CallablePTransform was being used both as the factory and the transform itself, which could result in state getting carried between pipelines. Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/2b86a61e Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/2b86a61e Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/2b86a61e Branch: refs/heads/master Commit: 2b86a61e5bb07d3bd7f958e124bc8d79dc300c3f Parents: 8468210 Author: Robert Bradshaw <[email protected]> Authored: Tue Jul 11 14:32:47 2017 -0700 Committer: Robert Bradshaw <[email protected]> Committed: Tue Jul 11 18:08:01 2017 -0700 ---------------------------------------------------------------------- sdks/python/apache_beam/transforms/combiners.py | 8 ++++ .../apache_beam/transforms/combiners_test.py | 7 +--- .../python/apache_beam/transforms/ptransform.py | 41 +++++++++----------- 3 files changed, 28 insertions(+), 28 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/2b86a61e/sdks/python/apache_beam/transforms/combiners.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py index fa0742d..875306f 100644 --- a/sdks/python/apache_beam/transforms/combiners.py +++ b/sdks/python/apache_beam/transforms/combiners.py @@ -149,6 +149,7 @@ class Top(object): """Combiners for obtaining extremal elements.""" # pylint: disable=no-self-argument + @staticmethod @ptransform.ptransform_fn def Of(pcoll, n, compare=None, *args, **kwargs): """Obtain a list of the compare-most N elements in a PCollection. @@ -177,6 +178,7 @@ class Top(object): return pcoll | core.CombineGlobally( TopCombineFn(n, compare, key, reverse), *args, **kwargs) + @staticmethod @ptransform.ptransform_fn def PerKey(pcoll, n, compare=None, *args, **kwargs): """Identifies the compare-most N elements associated with each key. @@ -210,21 +212,25 @@ class Top(object): return pcoll | core.CombinePerKey( TopCombineFn(n, compare, key, reverse), *args, **kwargs) + @staticmethod @ptransform.ptransform_fn def Largest(pcoll, n): """Obtain a list of the greatest N elements in a PCollection.""" return pcoll | Top.Of(n) + @staticmethod @ptransform.ptransform_fn def Smallest(pcoll, n): """Obtain a list of the least N elements in a PCollection.""" return pcoll | Top.Of(n, reverse=True) + @staticmethod @ptransform.ptransform_fn def LargestPerKey(pcoll, n): """Identifies the N greatest elements associated with each key.""" return pcoll | Top.PerKey(n) + @staticmethod @ptransform.ptransform_fn def SmallestPerKey(pcoll, n, reverse=True): """Identifies the N least elements associated with each key.""" @@ -369,10 +375,12 @@ class Sample(object): """Combiners for sampling n elements without replacement.""" # pylint: disable=no-self-argument + @staticmethod @ptransform.ptransform_fn def FixedSizeGlobally(pcoll, n): return pcoll | core.CombineGlobally(SampleCombineFn(n)) + @staticmethod @ptransform.ptransform_fn def FixedSizePerKey(pcoll, n): return pcoll | core.CombinePerKey(SampleCombineFn(n)) http://git-wip-us.apache.org/repos/asf/beam/blob/2b86a61e/sdks/python/apache_beam/transforms/combiners_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/combiners_test.py b/sdks/python/apache_beam/transforms/combiners_test.py index c79fec8..cd2b595 100644 --- a/sdks/python/apache_beam/transforms/combiners_test.py +++ b/sdks/python/apache_beam/transforms/combiners_test.py @@ -156,14 +156,11 @@ class CombineTest(unittest.TestCase): def test_combine_sample_display_data(self): def individual_test_per_key_dd(sampleFn, args, kwargs): - trs = [beam.CombinePerKey(sampleFn(*args, **kwargs)), - beam.CombineGlobally(sampleFn(*args, **kwargs))] + trs = [sampleFn(*args, **kwargs)] for transform in trs: dd = DisplayData.create_from(transform) expected_items = [ - DisplayDataItemMatcher('fn', sampleFn.fn.__name__), - DisplayDataItemMatcher('combine_fn', - transform.fn.__class__)] + DisplayDataItemMatcher('fn', transform._fn.__name__)] if args: expected_items.append( DisplayDataItemMatcher('args', str(args))) http://git-wip-us.apache.org/repos/asf/beam/blob/2b86a61e/sdks/python/apache_beam/transforms/ptransform.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index 6041353..cd84122 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -595,32 +595,23 @@ class PTransformWithSideInputs(PTransform): return '%s(%s)' % (self.__class__.__name__, self.fn.default_label()) -class CallablePTransform(PTransform): +class _PTransformFnPTransform(PTransform): """A class wrapper for a function-based transform.""" - def __init__(self, fn): - # pylint: disable=super-init-not-called - # This is a helper class for a function decorator. Only when the class - # is called (and __call__ invoked) we will have all the information - # needed to initialize the super class. - self.fn = fn - self._args = () - self._kwargs = {} + def __init__(self, fn, *args, **kwargs): + super(_PTransformFnPTransform, self).__init__() + self._fn = fn + self._args = args + self._kwargs = kwargs def display_data(self): - res = {'fn': (self.fn.__name__ - if hasattr(self.fn, '__name__') - else self.fn.__class__), + res = {'fn': (self._fn.__name__ + if hasattr(self._fn, '__name__') + else self._fn.__class__), 'args': DisplayDataItem(str(self._args)).drop_if_default('()'), 'kwargs': DisplayDataItem(str(self._kwargs)).drop_if_default('{}')} return res - def __call__(self, *args, **kwargs): - super(CallablePTransform, self).__init__() - self._args = args - self._kwargs = kwargs - return self - def expand(self, pcoll): # Since the PTransform will be implemented entirely as a function # (once called), we need to pass through any type-hinting information that @@ -629,18 +620,18 @@ class CallablePTransform(PTransform): kwargs = dict(self._kwargs) args = tuple(self._args) try: - if 'type_hints' in inspect.getargspec(self.fn).args: + if 'type_hints' in inspect.getargspec(self._fn).args: args = (self.get_type_hints(),) + args except TypeError: # Might not be a function. pass - return self.fn(pcoll, *args, **kwargs) + return self._fn(pcoll, *args, **kwargs) def default_label(self): if self._args: return '%s(%s)' % ( - label_from_callable(self.fn), label_from_callable(self._args[0])) - return label_from_callable(self.fn) + label_from_callable(self._fn), label_from_callable(self._args[0])) + return label_from_callable(self._fn) def ptransform_fn(fn): @@ -684,7 +675,11 @@ def ptransform_fn(fn): operator (i.e., `|`) will inject the pcoll argument in its proper place (first argument if no label was specified and second argument otherwise). """ - return CallablePTransform(fn) + # TODO(robertwb): Consider removing staticmethod to allow for self parameter. + + def callable_ptransform_factory(*args, **kwargs): + return _PTransformFnPTransform(fn, *args, **kwargs) + return callable_ptransform_factory def label_from_callable(fn):
