This is an automated email from the ASF dual-hosted git repository.
pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new f9b9ccc Merge pull request #14869 from [BEAM-12357] improve WithKeys
transform to take args, kwargs
f9b9ccc is described below
commit f9b9ccc64bc44867be46b65af0e35e8287e19be4
Author: heidimhurst <[email protected]>
AuthorDate: Mon Jul 12 22:56:16 2021 -0600
Merge pull request #14869 from [BEAM-12357] improve WithKeys transform to
take args, kwargs
* [BEAM-12357] improve WithKeys transform to take args, kwargs
This commit extends WithKeys to accept extra positional and keyword
arguments, consistent with how Map accepts them. These arguments are
passed to the key-creation function. Example use:
util.WithKeys(key_fn, foo, kwarg1=bar) passes foo and bar (as kwarg1)
to key_fn each time a key is computed.
* [BEAM-12357] PR feedback: move fn_takes_side_inputs to util.py
* Update sdks/python/apache_beam/transforms/util.py
update WithKeys to pass in *args, **kwargs to internal Map function, not
just lambda
Co-authored-by: Pablo <[email protected]>
* Revert "Update sdks/python/apache_beam/transforms/util.py
"
This reverts commit a5a654860684f2978140f63054f1f391163f4b7c.
* Preventing circular import in core.py
* allow WithKeys to take side inputs
Expand handling of args, kwargs within the WithKeys PTransform to
include side inputs as well as static (non-PCollection) inputs. This
includes the following changes:
- additional if-branch in WithKeys that checks for AsSideInput arguments
- additional test case for side-input arguments
- exports AsSideInput from pvalue.py (adds it to __all__)
* fix lint
Co-authored-by: Pablo <[email protected]>
---
sdks/python/apache_beam/pvalue.py | 1 +
sdks/python/apache_beam/transforms/core.py | 16 ++-----------
sdks/python/apache_beam/transforms/util.py | 32 +++++++++++++++++++++++--
sdks/python/apache_beam/transforms/util_test.py | 26 ++++++++++++++++++++
4 files changed, 59 insertions(+), 16 deletions(-)
diff --git a/sdks/python/apache_beam/pvalue.py
b/sdks/python/apache_beam/pvalue.py
index e17cfab..2b593e4 100644
--- a/sdks/python/apache_beam/pvalue.py
+++ b/sdks/python/apache_beam/pvalue.py
@@ -56,6 +56,7 @@ if TYPE_CHECKING:
__all__ = [
'PCollection',
'TaggedOutput',
+ 'AsSideInput',
'AsSingleton',
'AsIter',
'AsList',
diff --git a/sdks/python/apache_beam/transforms/core.py
b/sdks/python/apache_beam/transforms/core.py
index 74c5b74..cb69e64 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -698,19 +698,6 @@ class DoFn(WithTypeHints, HasDisplayData,
urns.RunnerApiFn):
urns.RunnerApiFn.register_pickle_urn(python_urns.PICKLED_DOFN)
-def _fn_takes_side_inputs(fn):
- try:
- signature = get_signature(fn)
- except TypeError:
- # We can't tell; maybe it does.
- return True
-
- return (
- len(signature.parameters) > 1 or any(
- p.kind == p.VAR_POSITIONAL or p.kind == p.VAR_KEYWORD
- for p in signature.parameters.values()))
-
-
class CallableWrapperDoFn(DoFn):
"""For internal use only; no backwards-compatibility guarantees.
@@ -1564,7 +1551,8 @@ def Map(fn, *args, **kwargs): # pylint:
disable=invalid-name
raise TypeError(
'Map can be used only with callable objects. '
'Received %r instead.' % (fn))
- if _fn_takes_side_inputs(fn):
+ from apache_beam.transforms.util import fn_takes_side_inputs
+ if fn_takes_side_inputs(fn):
wrapper = lambda x, *args, **kwargs: [fn(x, *args, **kwargs)]
else:
wrapper = lambda x: [fn(x)]
diff --git a/sdks/python/apache_beam/transforms/util.py
b/sdks/python/apache_beam/transforms/util.py
index d00da29..28024fa 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -40,6 +40,7 @@ from apache_beam import typehints
from apache_beam.metrics import Metrics
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
+from apache_beam.pvalue import AsSideInput
from apache_beam.transforms import window
from apache_beam.transforms.combiners import CountCombineFn
from apache_beam.transforms.core import CombinePerKey
@@ -63,6 +64,7 @@ from apache_beam.transforms.userstate import on_timer
from apache_beam.transforms.window import NonMergingWindowFn
from apache_beam.transforms.window import TimestampCombiner
from apache_beam.transforms.window import TimestampedValue
+from apache_beam.typehints.decorators import get_signature
from apache_beam.typehints.sharded_key_type import ShardedKeyType
from apache_beam.utils import windowed_value
from apache_beam.utils.annotations import deprecated
@@ -741,14 +743,40 @@ class Reshuffle(PTransform):
return Reshuffle()
+def fn_takes_side_inputs(fn):
+  """Returns True if fn may accept arguments beyond its first one."""
+  try:
+    params = get_signature(fn).parameters
+  except TypeError:
+    # The signature cannot be inspected; conservatively assume it does.
+    return True
+  if len(params) > 1:
+    return True
+  return any(
+      p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD) for p in params.values())
+
+
@ptransform_fn
-def WithKeys(pcoll, k):
+def WithKeys(pcoll, k, *args, **kwargs):
   """PTransform that takes a PCollection, and either a constant key or a
   callable, and returns a PCollection of (K, V), where each of the values in
   the input PCollection has been paired with either the constant key or a key
-  computed from the value.
+  computed from the value. The callable may optionally accept positional or
+  keyword arguments, which should be passed to WithKeys directly. These may
+  be side inputs (e.g. AsSingleton), static values, or a mix of both.
   """
   if callable(k):
+    if args or kwargs:
+      has_side_inputs = any(
+          isinstance(arg, AsSideInput)
+          for arg in list(args) + list(kwargs.values()))
+      if has_side_inputs:
+        # Map swaps each AsSideInput for its materialized value and passes
+        # static values through unchanged, so it handles any mix of both.
+        return pcoll | Map(
+            lambda v, *a, **kw: (k(v, *a, **kw), v), *args, **kwargs)
+      # Static values only: bind them into the closure; k must accept them.
+      return pcoll | Map(lambda v: (k(v, *args, **kwargs), v))
     return pcoll | Map(lambda v: (k(v), v))
   return pcoll | Map(lambda v: (k, v))
diff --git a/sdks/python/apache_beam/transforms/util_test.py
b/sdks/python/apache_beam/transforms/util_test.py
index a283e19..94cc8f3 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -38,6 +38,8 @@ from apache_beam.options.pipeline_options import
PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
+from apache_beam.pvalue import AsList
+from apache_beam.pvalue import AsSingleton
from apache_beam.runners import pipeline_context
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream
@@ -628,6 +630,30 @@ class WithKeysTest(unittest.TestCase):
with_keys = pc | util.WithKeys(lambda x: x * x)
assert_that(with_keys, equal_to([(1, 1), (4, 2), (9, 3)]))
+  @staticmethod
+  def _test_args_kwargs_fn(x, multiply, subtract):
+    return multiply * x - subtract
+
+  def test_args_kwargs_k(self):
+    with TestPipeline() as pipeline:
+      keyed = (
+          pipeline | beam.Create(self.l)
+          | util.WithKeys(self._test_args_kwargs_fn, 2, subtract=1))
+      assert_that(keyed, equal_to([(1, 1), (3, 2), (5, 3)]))
+
+  def test_sideinputs(self):
+    def key_fn(x, the_list, the_singleton):
+      return x + sum(the_list) + the_singleton
+
+    with TestPipeline() as pipeline:
+      list_si = AsList(pipeline | "side input 1" >> beam.Create([1, 2, 3]))
+      singleton_si = AsSingleton(
+          pipeline | "side input 2" >> beam.Create([10]))
+      keyed = (
+          pipeline | beam.Create(self.l)
+          | util.WithKeys(key_fn, list_si, the_singleton=singleton_si))
+      assert_that(keyed, equal_to([(17, 1), (18, 2), (19, 3)]))
+
class GroupIntoBatchesTest(unittest.TestCase):
NUM_ELEMENTS = 10