tvalentyn commented on a change in pull request #13048:
URL: https://github.com/apache/beam/pull/13048#discussion_r503029540



##########
File path: sdks/python/apache_beam/transforms/core.py
##########
@@ -875,18 +875,20 @@ class CombineFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
   input argument, which is an instance of CombineFnProcessContext). The
   combining process proceeds as follows:
 
-  1. Input values are partitioned into one or more batches.
-  2. For each batch, the create_accumulator method is invoked to create a fresh
+  1. The setup method is invoked.

Review comment:
       This suggests that the setup method is called once for the entire collection. Wouldn't it be called per batch? Aggregation may happen on multiple workers, and I imagine that in that case each worker would call the setup/teardown methods. Should we switch steps 1 and 2?
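
For context, a minimal sketch of a ``CombineFn`` using the lifecycle under discussion (illustrative only; the class name and resource are made up, not from this PR):

```python
import apache_beam as beam


class ResourceUsingCombineFn(beam.CombineFn):
  """Illustrative sketch only; not part of this PR."""
  def setup(self):
    # Runs once per CombineFn instance on a worker, before any accumulators
    # are created -- which is why the ordering of steps 1 and 2 matters.
    self._resource = []  # stand-in for an expensive resource

  def create_accumulator(self):
    return 0

  def add_input(self, accumulator, element):
    return accumulator + element

  def merge_accumulators(self, accumulators):
    return sum(accumulators)

  def extract_output(self, accumulator):
    return accumulator

  def teardown(self):
    # Runs when the worker is done with this instance.
    self._resource = None
```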

##########
File path: sdks/python/apache_beam/transforms/core.py
##########
@@ -895,6 +897,15 @@ class CombineFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
   def default_label(self):
     return self.__class__.__name__
 
+  def setup(self):
+    """Called to prepare an instance for combining.
+
+    This method can be useful if there is some state that needs to be loaded
+    before executing any of the other methods. The resources can then be
+    disposed in ``CombineFn.teardown``.

Review comment:
       nit: s/disposed/disposed of

##########
File path: sdks/python/apache_beam/transforms/core.py
##########
@@ -1970,10 +1985,14 @@ def add_input_types(transform):
       return combined
 
     if self.has_defaults:
-      combine_fn = (
-          self.fn if isinstance(self.fn, CombineFn) else
-          CombineFn.from_callable(self.fn))
-      default_value = combine_fn.apply([], *self.args, **self.kwargs)
+      combine_fn = copy.copy(

Review comment:
       What is the reason for a shallow copy here?
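
One plausible reading (an assumption, not confirmed here): applying the combine locally to compute the default value exercises the fn's lifecycle, and copying keeps that mutated state off the original instance that gets serialized with the pipeline. A sketch of the effect, using a made-up ``StatefulCombineFn``:

```python
import copy

import apache_beam as beam


class StatefulCombineFn(beam.CombineFn):
  """Illustrative only; not the PR's code."""
  def __init__(self):
    self._setup_called = False

  def setup(self):
    self._setup_called = True

  def create_accumulator(self):
    return 0

  def add_input(self, accumulator, element):
    return accumulator + element

  def merge_accumulators(self, accumulators):
    return sum(accumulators)

  def extract_output(self, accumulator):
    return accumulator


fn = StatefulCombineFn()
fn_copy = copy.copy(fn)  # a distinct instance; attributes copied shallowly
fn_copy.setup()
default_value = fn_copy.apply([])  # drives the combine locally, yielding 0
assert fn_copy._setup_called and not fn._setup_called
```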

##########
File path: sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
##########
@@ -0,0 +1,147 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""ValidatesRunner tests for CombineFn lifecycle and bundle methods."""
+
+# pytype: skip-file
+
+import unittest
+from weakref import WeakSet
+
+from nose.plugins.attrib import attr
+
+import apache_beam as beam
+from apache_beam.options.pipeline_options import DebugOptions
+from apache_beam.options.pipeline_options import StandardOptions
+from apache_beam.runners.direct import direct_runner
+from apache_beam.runners.portability import fn_api_runner
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.transforms import trigger
+from apache_beam.transforms import window
+
+
+class CallSequenceEnforcingCombineFn(beam.CombineFn):
+  instances = WeakSet()
+
+  def __init__(self):
+    super(CallSequenceEnforcingCombineFn, self).__init__()
+    self._setup_called = False
+    self._accumulators_created = 0
+    self._teardown_called = False
+
+  def setup(self):
+    assert not self._setup_called, 'setup should not be called twice'
+    assert not self._teardown_called, 'setup should be called before teardown'
+    # Keep track of instances so that we can check if teardown is called
+    # properly after pipeline execution.
+    self.instances.add(self)
+    self._setup_called = True
+
+  def create_accumulator(self):
+    assert self._setup_called, 'setup should have been called'
+    assert not self._teardown_called, 'teardown should not have been called'
+    self._accumulators_created += 1
+    return 0
+
+  def add_input(self, mutable_accumulator, element):
+    assert self._setup_called, 'setup should have been called'
+    assert self._accumulators_created > 0, \
+        'create_accumulator should have been called'
+    assert not self._teardown_called, 'teardown should not have been called'
+    mutable_accumulator += element
+    return mutable_accumulator
+
+  def add_inputs(self, mutable_accumulator, elements):
+    return self.add_input(mutable_accumulator, sum(elements))
+
+  def merge_accumulators(self, accumulators):
+    assert self._setup_called, 'setup should have been called'
+    assert not self._teardown_called, 'teardown should not have been called'
+    return sum(accumulators)
+
+  def extract_output(self, accumulator):
+    assert self._setup_called, 'setup should have been called'
+    assert not self._teardown_called, 'teardown should not have been called'
+    return accumulator
+
+  def teardown(self):
+    assert self._setup_called, 'setup should have been called'
+    assert not self._teardown_called, 'teardown should not be called twice'
+    self._teardown_called = True
+
+
+class BaseCombineFnLifecycleTest(unittest.TestCase):
+  def start(self, pipeline, lift_combiners=True):
+    with pipeline as p:
+      pcoll = p | 'Start' >> beam.Create(range(5))
+
+      # Certain triggers, such as AfterCount, are incompatible with combiner
+      # lifting. We can use that fact to prevent combiners from being lifted.
+      if not lift_combiners:
+        pcoll |= beam.WindowInto(
+            window.GlobalWindows(),
+            trigger=trigger.AfterCount(5),
+            accumulation_mode=trigger.AccumulationMode.DISCARDING)
+
+      pcoll |= 'Do' >> beam.CombineGlobally(CallSequenceEnforcingCombineFn())
+      assert_that(pcoll, equal_to([10]))
+
+    # Ensure that _teardown_called equals True for all CombineFns.
+    for instance in CallSequenceEnforcingCombineFn.instances:
+      self.assertTrue(instance._teardown_called)
+
+
+@attr('ValidatesRunner')
+class CombineFnLifecycleTest(BaseCombineFnLifecycleTest):
+  def setUp(self):
+    self.pipeline = TestPipeline(is_integration_test=True)
+    options = self.pipeline.get_pipeline_options()
+    standard_options = options.view_as(StandardOptions)
+    experiments = options.view_as(DebugOptions).experiments or []
+
+    if 'DataflowRunner' in standard_options.runner and \
+       not standard_options.streaming and \
+       'beam_fn_api' not in experiments and 'use_runner_v2' not in experiments:
+      self.skipTest(
+          'Non-portable Dataflow batch worker does not support '

Review comment:
       Should we make the non-portable Dataflow runner detect usage of combiner initialization and alert the user that this functionality is unsupported?
   cc: @robertwb
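
A hypothetical way such detection could work (an assumption, not an existing Beam API): compare the class's lifecycle methods against the ``CombineFn`` base class to see whether the user overrode them, and warn or fail fast at submission time:

```python
import apache_beam as beam


def overrides_lifecycle_methods(fn):
  # True if fn's class provides its own setup or teardown. A runner that
  # never invokes these methods could use this check to alert the user.
  return (type(fn).setup is not beam.CombineFn.setup or
          type(fn).teardown is not beam.CombineFn.teardown)
```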




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]