diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py index e3df43632b21..dbb143e6e168 100644 --- a/sdks/python/apache_beam/transforms/combiners.py +++ b/sdks/python/apache_beam/transforms/combiners.py @@ -522,15 +522,35 @@ class Sample(object): """Combiners for sampling n elements without replacement.""" # pylint: disable=no-self-argument - @staticmethod - @ptransform.ptransform_fn - def FixedSizeGlobally(pcoll, n): - return pcoll | core.CombineGlobally(SampleCombineFn(n)) + class FixedSizeGlobally(ptransform.PTransform): + """Sample n elements from the input PCollection without replacement.""" - @staticmethod - @ptransform.ptransform_fn - def FixedSizePerKey(pcoll, n): - return pcoll | core.CombinePerKey(SampleCombineFn(n)) + def __init__(self, n): + self._n = n + + def expand(self, pcoll): + return pcoll | core.CombineGlobally(SampleCombineFn(self._n)) + + def display_data(self): + return {'n': self._n} + + def default_label(self): + return 'FixedSizeGlobally(%d)' % self._n + + class FixedSizePerKey(ptransform.PTransform): + """Sample n elements associated with each key without replacement.""" + + def __init__(self, n): + self._n = n + + def expand(self, pcoll): + return pcoll | core.CombinePerKey(SampleCombineFn(self._n)) + + def display_data(self): + return {'n': self._n} + + def default_label(self): + return 'FixedSizePerKey(%d)' % self._n @with_input_types(T) diff --git a/sdks/python/apache_beam/transforms/combiners_test.py b/sdks/python/apache_beam/transforms/combiners_test.py index 637a41f3dcb5..3db019a03599 100644 --- a/sdks/python/apache_beam/transforms/combiners_test.py +++ b/sdks/python/apache_beam/transforms/combiners_test.py @@ -177,26 +177,16 @@ def individual_test_per_key_dd(combineFn): individual_test_per_key_dd(combine.Largest(5)) def test_combine_sample_display_data(self): - def individual_test_per_key_dd(sampleFn, args, kwargs): - trs = [sampleFn(*args, **kwargs)] + def individual_test_per_key_dd(sampleFn, n): + trs = [sampleFn(n)] for transform in trs: dd = DisplayData.create_from(transform) - expected_items = [ - DisplayDataItemMatcher('fn', transform._fn.__name__)] - if args: - expected_items.append( - DisplayDataItemMatcher('args', str(args))) - if kwargs: - expected_items.append( - DisplayDataItemMatcher('kwargs', str(kwargs))) - hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) - - individual_test_per_key_dd(combine.Sample.FixedSizePerKey, - args=(5,), - kwargs={}) - individual_test_per_key_dd(combine.Sample.FixedSizeGlobally, - args=(8,), - kwargs={'arg': 9}) + hc.assert_that( + dd.items, + hc.contains_inanyorder(DisplayDataItemMatcher('n', transform._n))) + + individual_test_per_key_dd(combine.Sample.FixedSizePerKey, 5) + individual_test_per_key_dd(combine.Sample.FixedSizeGlobally, 5) def test_combine_globally_display_data(self): transform = beam.CombineGlobally(combine.Smallest(5))
With regards, Apache Git Services