This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 3ff64c72266 [Python] Implement combiner deferred side inputs (#35601)
3ff64c72266 is described below

commit 3ff64c72266813b296fdbf3241412b84a1d3c0d9
Author: Hai Joey Tran <[email protected]>
AuthorDate: Tue Jul 22 14:59:49 2025 -0400

    [Python] Implement combiner deferred side inputs (#35601)
    
    * squash
    
    * revert unnecessary fn_runner change
    
    * improve test
    
    * add a test with streaming
    
    * add streaming with matching window test
    
    * implement streaming support
    
    * add combiner support
    
    * implement args support
    
    * more rigorously test which combinefn methods are called with side inputs
    
    * clean up
    
    * use pack/unpack terminology
    
    * add an explanatory comment
    
    * revert old unneeded changes
    
    * tidy
    
    * use temp dir for json test file
    
    * add combineglobally test
    
    * connect args/kwargs to all combinefn methods after all
    
    * enable streaming for streaming tests
    
    * remove streaming options
    
    * move liftedcombineperkey
    
    * add additional docstring to liftedcombineperkey
    
    * isort
    
    * Update sdks/python/apache_beam/transforms/combiners.py
    
    Co-authored-by: Danny McCormick <[email protected]>
    
    * privatize a couple transforms
    
    ---------
    
    Co-authored-by: Danny McCormick <[email protected]>
---
 .../apache_beam/runners/direct/direct_runner.py    |   2 +-
 .../runners/direct/helper_transforms.py            | 120 -------------
 sdks/python/apache_beam/transforms/combiners.py    | 124 ++++++++++++++
 .../apache_beam/transforms/combiners_test.py       | 186 +++++++++++++++++++++
 sdks/python/apache_beam/transforms/core.py         |  16 ++
 5 files changed, 327 insertions(+), 121 deletions(-)

diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py 
b/sdks/python/apache_beam/runners/direct/direct_runner.py
index fcc13ae1024..a629c12a058 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -389,7 +389,7 @@ def _get_transform_overrides(pipeline_options):
 
   # Importing following locally to avoid a circular dependency.
   from apache_beam.pipeline import PTransformOverride
-  from apache_beam.runners.direct.helper_transforms import LiftedCombinePerKey
+  from apache_beam.transforms.combiners import LiftedCombinePerKey
   from apache_beam.runners.direct.sdf_direct_runner import 
ProcessKeyedElementsViaKeyedWorkItemsOverride
   from apache_beam.runners.direct.sdf_direct_runner import 
SplittableParDoOverride
 
diff --git a/sdks/python/apache_beam/runners/direct/helper_transforms.py 
b/sdks/python/apache_beam/runners/direct/helper_transforms.py
deleted file mode 100644
index 0e88c021e2f..00000000000
--- a/sdks/python/apache_beam/runners/direct/helper_transforms.py
+++ /dev/null
@@ -1,120 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# pytype: skip-file
-
-import collections
-import itertools
-import typing
-
-import apache_beam as beam
-from apache_beam import typehints
-from apache_beam.internal.util import ArgumentPlaceholder
-from apache_beam.transforms.combiners import _CurriedFn
-from apache_beam.utils.windowed_value import WindowedValue
-
-
-class LiftedCombinePerKey(beam.PTransform):
-  """An implementation of CombinePerKey that does mapper-side pre-combining.
-  """
-  def __init__(self, combine_fn, args, kwargs):
-    args_to_check = itertools.chain(args, kwargs.values())
-    if isinstance(combine_fn, _CurriedFn):
-      args_to_check = itertools.chain(
-          args_to_check, combine_fn.args, combine_fn.kwargs.values())
-    if any(isinstance(arg, ArgumentPlaceholder) for arg in args_to_check):
-      # This isn't implemented in dataflow either...
-      raise NotImplementedError('Deferred CombineFn side inputs.')
-    self._combine_fn = beam.transforms.combiners.curry_combine_fn(
-        combine_fn, args, kwargs)
-
-  def expand(self, pcoll):
-    return (
-        pcoll
-        | beam.ParDo(PartialGroupByKeyCombiningValues(self._combine_fn))
-        | beam.GroupByKey()
-        | beam.ParDo(FinishCombine(self._combine_fn)))
-
-
-class PartialGroupByKeyCombiningValues(beam.DoFn):
-  """Aggregates values into a per-key-window cache.
-
-  As bundles are in-memory-sized, we don't bother flushing until the very end.
-  """
-  def __init__(self, combine_fn):
-    self._combine_fn = combine_fn
-
-  def setup(self):
-    self._combine_fn.setup()
-
-  def start_bundle(self):
-    self._cache = collections.defaultdict(self._combine_fn.create_accumulator)
-
-  def process(self, element, window=beam.DoFn.WindowParam):
-    k, vi = element
-    self._cache[k, window] = self._combine_fn.add_input(
-        self._cache[k, window], vi)
-
-  def finish_bundle(self):
-    for (k, w), va in self._cache.items():
-      # We compact the accumulator since a GBK (which necessitates encoding)
-      # will follow.
-      yield WindowedValue((k, self._combine_fn.compact(va)), w.end, (w, ))
-
-  def teardown(self):
-    self._combine_fn.teardown()
-
-  def default_type_hints(self):
-    hints = self._combine_fn.get_type_hints()
-    K = typehints.TypeVariable('K')
-    if hints.input_types:
-      args, kwargs = hints.input_types
-      args = (typehints.Tuple[K, args[0]], ) + args[1:]
-      hints = hints.with_input_types(*args, **kwargs)
-    else:
-      hints = hints.with_input_types(typehints.Tuple[K, typing.Any])
-    hints = hints.with_output_types(typehints.Tuple[K, typing.Any])
-    return hints
-
-
-class FinishCombine(beam.DoFn):
-  """Merges partially combined results.
-  """
-  def __init__(self, combine_fn):
-    self._combine_fn = combine_fn
-
-  def setup(self):
-    self._combine_fn.setup()
-
-  def process(self, element):
-    k, vs = element
-    return [(
-        k,
-        self._combine_fn.extract_output(
-            self._combine_fn.merge_accumulators(vs)))]
-
-  def teardown(self):
-    self._combine_fn.teardown()
-
-  def default_type_hints(self):
-    hints = self._combine_fn.get_type_hints()
-    K = typehints.TypeVariable('K')
-    hints = hints.with_input_types(typehints.Tuple[K, typing.Any])
-    if hints.output_types:
-      main_output_type = hints.simple_output_type('')
-      hints = hints.with_output_types(typehints.Tuple[K, main_output_type])
-    return hints
diff --git a/sdks/python/apache_beam/transforms/combiners.py 
b/sdks/python/apache_beam/transforms/combiners.py
index 58267ef97ac..6e4647fecef 100644
--- a/sdks/python/apache_beam/transforms/combiners.py
+++ b/sdks/python/apache_beam/transforms/combiners.py
@@ -41,6 +41,7 @@ from apache_beam.typehints import with_input_types
 from apache_beam.typehints import with_output_types
 from apache_beam.utils.timestamp import Duration
 from apache_beam.utils.timestamp import Timestamp
+from apache_beam.utils.windowed_value import WindowedValue
 
 __all__ = [
     'Count',
@@ -985,3 +986,126 @@ class LatestCombineFn(core.CombineFn):
 
   def extract_output(self, accumulator):
     return accumulator[0]
+
+
+class LiftedCombinePerKey(core.PTransform):
+  """An implementation of CombinePerKey that does mapper-side pre-combining.
+
+  This shouldn't generally be used directly except for use-cases where a
+  runner doesn't support CombinePerKey. This implementation manually implements
+  a CombinePerKey using ParDos, as opposed to runner implementations which may
+  use a more efficient implementation.
+  """
+  def __init__(self, combine_fn, args, kwargs):
+    side_inputs = _pack_side_inputs(args, kwargs)
+    self._side_inputs: dict = side_inputs
+    if not isinstance(combine_fn, core.CombineFn):
+      combine_fn = core.CombineFn.from_callable(combine_fn)
+    self._combine_fn = combine_fn
+
+  def expand(self, pcoll):
+    return (
+        pcoll
+        | core.ParDo(
+            _PartialGroupByKeyCombiningValues(self._combine_fn),
+            **self._side_inputs)
+        | core.GroupByKey()
+        | core.ParDo(_FinishCombine(self._combine_fn), **self._side_inputs))
+
+
+def _pack_side_inputs(side_input_args, side_input_kwargs):
+  if len(side_input_args) >= 10:
+    # If we have more than 10 side inputs, we can't use the
+    # _side_input_arg_{i} as our keys since they won't sort
+    # correctly. Just punt for now, more than 10 args probably
+    # doesn't happen often.
+    raise NotImplementedError
+  side_inputs = {}
+  for i, si in enumerate(side_input_args):
+    side_inputs[f'_side_input_arg_{i}'] = si
+  for k, v in side_input_kwargs.items():
+    side_inputs[k] = v
+  return side_inputs
+
+
+def _unpack_side_inputs(side_inputs):
+  side_input_args = []
+  side_input_kwargs = {}
+  for k, v in sorted(side_inputs.items(), key=lambda x: x[0]):
+    if k.startswith('_side_input_arg_'):
+      side_input_args.append(v)
+    else:
+      side_input_kwargs[k] = v
+  return side_input_args, side_input_kwargs
+
+
+class _PartialGroupByKeyCombiningValues(core.DoFn):
+  """Aggregates values into a per-key-window cache.
+
+  As bundles are in-memory-sized, we don't bother flushing until the very end.
+  """
+  def __init__(self, combine_fn):
+    self._combine_fn = combine_fn
+    self.side_input_args = []
+    self.side_input_kwargs = {}
+
+  def setup(self):
+    self._combine_fn.setup()
+
+  def start_bundle(self):
+    self._cache = dict()
+    self._cached_windowed_side_inputs = {}
+
+  def process(self, element, window=core.DoFn.WindowParam, **side_inputs):
+    k, vi = element
+    side_input_args, side_input_kwargs = _unpack_side_inputs(side_inputs)
+    if (k, window) not in self._cache:
+      self._cache[(k, window)] = self._combine_fn.create_accumulator(
+          *side_input_args, **side_input_kwargs)
+
+    self._cache[k, window] = self._combine_fn.add_input(
+        self._cache[k, window], vi, *side_input_args, **side_input_kwargs)
+    self._cached_windowed_side_inputs[window] = (
+        side_input_args, side_input_kwargs)
+
+  def finish_bundle(self):
+    for (k, w), va in self._cache.items():
+      # We compact the accumulator since a GBK (which necessitates encoding)
+      # will follow.
+      side_input_args, side_input_kwargs = (
+        self._cached_windowed_side_inputs[w])
+      yield WindowedValue((
+          k,
+          self._combine_fn.compact(va, *side_input_args, **side_input_kwargs)),
+                          w.end, (w, ))
+
+  def teardown(self):
+    self._combine_fn.teardown()
+
+
+class _FinishCombine(core.DoFn):
+  """Merges partially combined results.
+  """
+  def __init__(self, combine_fn):
+    self._combine_fn = combine_fn
+
+  def setup(self):
+    self._combine_fn.setup()
+
+  def process(self, element, window=core.DoFn.WindowParam, **side_inputs):
+
+    k, vs = element
+    side_input_args, side_input_kwargs = _unpack_side_inputs(side_inputs)
+    return [(
+        k,
+        self._combine_fn.extract_output(
+            self._combine_fn.merge_accumulators(
+                vs, *side_input_args, **side_input_kwargs),
+            *side_input_args,
+            **side_input_kwargs))]
+
+  def teardown(self):
+    try:
+      self._combine_fn.teardown()
+    except AttributeError:
+      pass
diff --git a/sdks/python/apache_beam/transforms/combiners_test.py 
b/sdks/python/apache_beam/transforms/combiners_test.py
index a8979239f83..ba9e21f8556 100644
--- a/sdks/python/apache_beam/transforms/combiners_test.py
+++ b/sdks/python/apache_beam/transforms/combiners_test.py
@@ -19,15 +19,20 @@
 # pytype: skip-file
 
 import itertools
+import json
+import os
 import random
+import tempfile
 import time
 import unittest
+from pathlib import Path
 
 import hamcrest as hc
 import pytest
 
 import apache_beam as beam
 import apache_beam.transforms.combiners as combine
+from apache_beam import pvalue
 from apache_beam.metrics import Metrics
 from apache_beam.metrics import MetricsFilter
 from apache_beam.options.pipeline_options import PipelineOptions
@@ -1021,5 +1026,186 @@ class CombineGloballyTest(unittest.TestCase):
           | beam.CombineGlobally(sum).without_defaults())
 
 
+def get_common_items(sets, excluded_chars=""):
+  # set.intersection() takes multiple sets as separate arguments.
+  # We unpack the `sets` list into multiple arguments with the * operator.
+  # The combine transform might give us an empty list of `sets`,
+  # so we use a list with an empty set as a default value.
+  common = set.intersection(*(sets or [set()]))
+  return common.difference(excluded_chars)
+
+
+class CombinerWithSideInputs(unittest.TestCase):
+  def test_cpk_with_side_input(self):
+    test_cases = [(get_common_items, True),
+                  (beam.CombineFn.from_callable(get_common_items), True),
+                  (get_common_items, False),
+                  (beam.CombineFn.from_callable(get_common_items), False)]
+    for combiner, with_kwarg in test_cases:
+      self._check_combineperkey_with_side_input(combiner, with_kwarg)
+      self._check_combineglobally_with_side_input(combiner, with_kwarg)
+
+  def _check_combineperkey_with_side_input(self, combiner, with_kwarg):
+    with beam.Pipeline() as pipeline:
+      pc = (pipeline | beam.Create(['🍅']))
+      if with_kwarg:
+        cpk = beam.CombinePerKey(
+            combiner, excluded_chars=beam.pvalue.AsSingleton(pc))
+      else:
+        cpk = beam.CombinePerKey(combiner, beam.pvalue.AsSingleton(pc))
+      common_items = (
+          pipeline
+          | 'Create produce' >> beam.Create([
+              {'🍓', '🥕', '🍌', '🍅', '🌶️'},
+              {'🍇', '🥕', '🥝', '🍅', '🥔'},
+              {'🍉', '🥕', '🍆', '🍅', '🍍'},
+              {'🥑', '🥕', '🌽', '🍅', '🥥'},
+          ])
+          | beam.WithKeys(lambda x: None)
+          | cpk)
+      assert_that(common_items, equal_to([(None, {'🥕'})]))
+
+  def _check_combineglobally_with_side_input(self, combiner, with_kwarg):
+    with beam.Pipeline() as pipeline:
+      pc = (pipeline | beam.Create(['🍅']))
+      if with_kwarg:
+        cpk = beam.CombineGlobally(
+            combiner, excluded_chars=beam.pvalue.AsSingleton(pc))
+      else:
+        cpk = beam.CombineGlobally(combiner, beam.pvalue.AsSingleton(pc))
+      common_items = (
+          pipeline
+          | 'Create produce' >> beam.Create([
+              {'🍓', '🥕', '🍌', '🍅', '🌶️'},
+              {'🍇', '🥕', '🥝', '🍅', '🥔'},
+              {'🍉', '🥕', '🍆', '🍅', '🍍'},
+              {'🥑', '🥕', '🌽', '🍅', '🥥'},
+          ])
+          | cpk)
+      assert_that(common_items, equal_to([{'🥕'}]))
+
+  def test_combinefn_methods_with_side_input(self):
+    # Test that the expected combinefn methods are called with the
+    # expected arguments when using side inputs in CombinePerKey.
+    with tempfile.TemporaryDirectory() as tmp_dirname:
+      fname = str(Path(tmp_dirname) / "combinefn_calls.json")
+      with open(fname, "w") as f:
+        json.dump({}, f)
+
+      def set_in_json(key, values):
+        current_json = {}
+        if os.path.exists(fname):
+          with open(fname, "r") as f:
+            current_json = json.load(f)
+        current_json[key] = values
+        with open(fname, "w") as f:
+          json.dump(current_json, f)
+
+      class MyCombiner(beam.CombineFn):
+        def create_accumulator(self, *args, **kwargs):
+          set_in_json("create_accumulator_args", args)
+          set_in_json("create_accumulator_kwargs", kwargs)
+          return args, kwargs
+
+        def add_input(self, accumulator, input, *args, **kwargs):
+          set_in_json("add_input_args", args)
+          set_in_json("add_input_kwargs", kwargs)
+          return accumulator
+
+        def merge_accumulators(self, accumulators, *args, **kwargs):
+          set_in_json("merge_accumulators_args", args)
+          set_in_json("merge_accumulators_kwargs", kwargs)
+          return args, kwargs
+
+        def compact(self, accumulator, *args, **kwargs):
+          set_in_json("compact_args", args)
+          set_in_json("compact_kwargs", kwargs)
+          return accumulator
+
+        def extract_output(self, accumulator, *args, **kwargs):
+          set_in_json("extract_output_args", args)
+          set_in_json("extract_output_kwargs", kwargs)
+          return accumulator
+
+      with beam.Pipeline() as p:
+        static_pos_arg = 0
+        deferred_pos_arg = beam.pvalue.AsSingleton(
+            p | "CreateDeferredSideInput" >> beam.Create([1]))
+        static_kwarg = 2
+        deferred_kwarg = beam.pvalue.AsSingleton(
+            p | "CreateDeferredSideInputKwarg" >> beam.Create([3]))
+        res = (
+            p
+            | "CreateInputs" >> beam.Create([(None, None)])
+            | beam.CombinePerKey(
+                MyCombiner(),
+                static_pos_arg,
+                deferred_pos_arg,
+                static_kwarg=static_kwarg,
+                deferred_kwarg=deferred_kwarg))
+        assert_that(
+            res,
+            equal_to([
+                (None, ((0, 1), {
+                    'static_kwarg': 2, 'deferred_kwarg': 3
+                }))
+            ]))
+
+      # Check that the combinefn was called with the expected arguments
+      with open(fname, "r") as f:
+        data = json.load(f)
+        expected_args = [0, 1]
+        expected_kwargs = {"static_kwarg": 2, "deferred_kwarg": 3}
+        method_names = [
+            "create_accumulator",
+            "compact",
+            "add_input",
+            "merge_accumulators",
+            "extract_output"
+        ]
+        for key in method_names:
+          print(f"Checking {key}")
+          self.assertEqual(data[key + "_args"], expected_args)
+          self.assertEqual(data[key + "_kwargs"], expected_kwargs)
+
+  def test_cpk_with_windows(self):
+    # With global window side input
+    with TestPipeline() as p:
+
+      def sum_with_floor(vals, min_value=0):
+        vals_sum = sum(vals)
+        if vals_sum < min_value:
+          vals_sum += min_value
+        return vals_sum
+
+      res = (
+          p
+          | "CreateInputs" >> beam.Create([1, 2, 100, 101, 102])
+          | beam.Map(lambda x: window.TimestampedValue(('k', x), x))
+          | beam.WindowInto(FixedWindows(99))
+          | beam.CombinePerKey(
+              sum_with_floor,
+              min_value=pvalue.AsSingleton(p | beam.Create([100]))))
+      assert_that(res, equal_to([('k', 103), ('k', 303)]))
+
+    # with matching window side input
+    with TestPipeline() as p:
+      min_value = (
+          p
+          | "CreateMinValue" >> beam.Create([
+              window.TimestampedValue(50, 5),
+              window.TimestampedValue(1000, 100)
+          ])
+          | "WindowSideInputs" >> beam.WindowInto(FixedWindows(99)))
+      res = (
+          p
+          | "CreateInputs" >> beam.Create([1, 2, 100, 101, 102])
+          | beam.Map(lambda x: window.TimestampedValue(('k', x), x))
+          | beam.WindowInto(FixedWindows(99))
+          | beam.CombinePerKey(
+              sum_with_floor, min_value=pvalue.AsSingleton(min_value)))
+      assert_that(res, equal_to([('k', 53), ('k', 1303)]))
+
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/transforms/core.py 
b/sdks/python/apache_beam/transforms/core.py
index c043f768574..6e0170c04ea 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -2927,6 +2927,22 @@ class CombinePerKey(PTransformWithSideInputs):
   Returns:
     A PObject holding the result of the combine operation.
   """
+  def __new__(cls, *args, **kwargs):
+    def has_side_inputs():
+      return (
+          any(isinstance(arg, pvalue.AsSideInput) for arg in args) or
+          any(isinstance(arg, pvalue.AsSideInput) for arg in kwargs.values()))
+
+    if has_side_inputs():
+      # If the CombineFn has deferred side inputs, the python SDK
+      # doesn't implement it.
+      # Use a ParDo-based CombinePerKey instead.
+      from apache_beam.transforms.combiners import \
+        LiftedCombinePerKey
+      combine_fn, *args = args
+      return LiftedCombinePerKey(combine_fn, args, kwargs)
+    return super(CombinePerKey, cls).__new__(cls)
+
   def with_hot_key_fanout(self, fanout):
     """A per-key combine operation like self but with two levels of 
aggregation.
 

Reply via email to