yifanmai commented on a change in pull request #13048:
URL: https://github.com/apache/beam/pull/13048#discussion_r513623504


##########
File path: CHANGES.md
##########
@@ -62,6 +62,7 @@
 ## New Features / Improvements
 * Added support for avro payload format in Beam SQL Kafka Table ([BEAM-10885](https://issues.apache.org/jira/browse/BEAM-10885))
+* Added CombineFn.setup and CombineFn.teardown to Python SDK. These methods let you initialize a state before any of the other methods of the CombineFn is executed and clean that state up later on. ([BEAM-3736](https://issues.apache.org/jira/browse/BEAM-3736))

Review comment:
       nit: 'a state' -> 'the CombineFn's state'
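For context, a minimal sketch of what the new hooks look like from user code. The `CachedLookupSumFn` class and its lookup-table "resource" are invented for illustration; they are not taken from the PR.

```python
import apache_beam as beam


class CachedLookupSumFn(beam.CombineFn):
  """Illustrative CombineFn that acquires a resource in setup()."""
  def setup(self):
    # Called before any other method on this CombineFn instance; a real
    # pipeline might open a connection or load a lookup file here.
    self._weights = {'a': 1, 'b': 2, 'c': 3}

  def create_accumulator(self):
    return 0

  def add_input(self, accumulator, element):
    return accumulator + self._weights.get(element, 0)

  def merge_accumulators(self, accumulators):
    return sum(accumulators)

  def extract_output(self, accumulator):
    return accumulator

  def teardown(self):
    # Called once the CombineFn is no longer needed; release the resource.
    self._weights = None


with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(['a', 'b', 'b', 'c'])
      | beam.CombineGlobally(CachedLookupSumFn()))
```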
##########
File path: sdks/python/apache_beam/transforms/core.py
##########
@@ -1975,10 +1990,14 @@ def add_input_types(transform):
       return combined
     if self.has_defaults:
-      combine_fn = (
-          self.fn if isinstance(self.fn, CombineFn) else
-          CombineFn.from_callable(self.fn))
-      default_value = combine_fn.apply([], *self.args, **self.kwargs)
+      combine_fn = copy.deepcopy(
+          self.fn if isinstance(self.fn, CombineFn) else CombineFn.

Review comment:
       nit: this can be `copy.deepcopy(self.fn) if...` i.e. copy is only needed in the first branch
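A sketch of the shape the reviewer seems to be suggesting, pulled out into a free function for readability. The name `_resolve_combine_fn` is hypothetical, and the assumption is that `CombineFn.from_callable` already constructs a fresh object that needs no defensive copy.

```python
import copy

from apache_beam.transforms.core import CombineFn


def _resolve_combine_fn(fn):
  # Deep-copy only when handed an existing CombineFn instance;
  # from_callable builds a fresh wrapper, so copying it again is redundant.
  return (
      copy.deepcopy(fn) if isinstance(fn, CombineFn)
      else CombineFn.from_callable(fn))
```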
##########
File path: sdks/python/apache_beam/transforms/core.py
##########
@@ -877,17 +877,19 @@ class CombineFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
   combining process proceeds as follows:

   1. Input values are partitioned into one or more batches.
-  2. For each batch, the create_accumulator method is invoked to create a fresh
+  2. For each batch, the setup method is invoked.
+  3. For each batch, the create_accumulator method is invoked to create a fresh
      initial "accumulator" value representing the combination of zero values.
-  3. For each input value in the batch, the add_input method is invoked to
+  4. For each input value in the batch, the add_input method is invoked to
      combine more values with the accumulator for that batch.
-  4. The merge_accumulators method is invoked to combine accumulators from
+  5. The merge_accumulators method is invoked to combine accumulators from
      separate batches into a single combined output accumulator value, once all
      of the accumulators have had all the input value in their batches added to
      them. This operation is invoked repeatedly, until there is only one
      accumulator value left.
-  5. The extract_output operation is invoked on the final accumulator to get
+  6. The extract_output operation is invoked on the final accumulator to get
      the output value.
+  7. The teardown method is invoked.

Review comment:
       Question: What is the expected behavior if setup throws an exception? Should teardown still be called?
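To make the numbered steps concrete, here is a small illustrative driver that walks a single CombineFn instance through the documented order. It is not runner code, and it deliberately leaves the reviewer's setup-failure question open.

```python
def combine_in_documented_order(combine_fn, batches):
  # Illustration of the docstring's call sequence for one instance only;
  # a real runner may spread batches across several instances, each of
  # which gets its own setup()/teardown().
  combine_fn.setup()                                     # step 2
  accumulators = []
  for batch in batches:
    acc = combine_fn.create_accumulator()                # step 3
    for value in batch:
      acc = combine_fn.add_input(acc, value)             # step 4
    accumulators.append(acc)
  merged = combine_fn.merge_accumulators(accumulators)   # step 5
  output = combine_fn.extract_output(merged)             # step 6
  combine_fn.teardown()                                  # step 7
  return output
```

With the CallSequenceEnforcingCombineFn from the test file further down, `combine_in_documented_order(fn, [[1, 2], [3]])` would return 6 and trip none of its ordering assertions.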
##########
File path: sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
##########
@@ -411,6 +411,33 @@ def visit_transform(self, transform_node):

     return FlattenInputVisitor()

+  @staticmethod
+  def combinefn_visitor():
+    # Imported here to avoid circular dependencies.
+    from apache_beam.pipeline import PipelineVisitor
+    from apache_beam import core
+
+    class CombineFnVisitor(PipelineVisitor):
+      """Checks if `CombineFn` has non-default setup or teardown methods.
+      If yes, raises `ValueError`.
+      """
+      def visit_transform(self, applied_transform):
+        transform = applied_transform.transform
+        if isinstance(transform, core.ParDo) and isinstance(
+            transform.fn, core.CombineValuesDoFn):
+          if self._overrides_setup_or_teardown(transform.fn.combinefn):
+            raise ValueError(
+                'CombineFn.setup and CombineFn.teardown are '
+                'not supported with non-portable Dataflow '
+                'runner. Please use Dataflow Runner V2 instead.')

Review comment:
       Question: Is there any plan to support this in non-portable Dataflow Runner, or will this be a V2 feature only?
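The `_overrides_setup_or_teardown` helper is referenced but falls outside the hunk shown above. A plausible sketch of such a check, hypothetical rather than the PR's actual implementation, is to compare the resolved methods against the base-class defaults:

```python
from apache_beam.transforms import core


def _overrides_setup_or_teardown_sketch(combinefn):
  # Hypothetical reconstruction: an override exists if the concrete class
  # resolves setup/teardown to something other than CombineFn's defaults.
  cls = type(combinefn)
  return (cls.setup is not core.CombineFn.setup or
          cls.teardown is not core.CombineFn.teardown)
```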
##########
File path: sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py
##########
@@ -0,0 +1,131 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+from typing import Set
+from typing import Tuple
+
+import apache_beam as beam
+from apache_beam.options.pipeline_options import TypeOptions
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.transforms import combiners
+from apache_beam.transforms import trigger
+from apache_beam.transforms import userstate
+from apache_beam.transforms import window
+from apache_beam.typehints import with_input_types
+from apache_beam.typehints import with_output_types
+
+
+@with_input_types(int)
+@with_output_types(int)
+class CallSequenceEnforcingCombineFn(beam.CombineFn):
+  instances = set()  # type: Set[CallSequenceEnforcingCombineFn]
+
+  def __init__(self):
+    super(CallSequenceEnforcingCombineFn, self).__init__()
+    self._setup_called = False
+    self._teardown_called = False
+
+  def setup(self, *args, **kwargs):
+    assert not self._setup_called, 'setup should not be called twice'
+    assert not self._teardown_called, 'setup should be called before teardown'
+    # Keep track of instances so that we can check if teardown is called
+    # properly after pipeline execution.
+    self.instances.add(self)
+    self._setup_called = True
+
+  def create_accumulator(self, *args, **kwargs):
+    assert self._setup_called, 'setup should have been called'
+    assert not self._teardown_called, 'teardown should not have been called'
+    return 0
+
+  def add_input(self, mutable_accumulator, element, *args, **kwargs):
+    assert self._setup_called, 'setup should have been called'
+    assert not self._teardown_called, 'teardown should not have been called'
+    mutable_accumulator += element
+    return mutable_accumulator
+
+  def add_inputs(self, mutable_accumulator, elements, *args, **kwargs):
+    return self.add_input(mutable_accumulator, sum(elements))
+
+  def merge_accumulators(self, accumulators, *args, **kwargs):
+    assert self._setup_called, 'setup should have been called'
+    assert not self._teardown_called, 'teardown should not have been called'
+    return sum(accumulators)
+
+  def extract_output(self, accumulator, *args, **kwargs):
+    assert self._setup_called, 'setup should have been called'
+    assert not self._teardown_called, 'teardown should not have been called'
+    return accumulator
+
+  def teardown(self, *args, **kwargs):
+    assert self._setup_called, 'setup should have been called'
+    assert not self._teardown_called, 'teardown should not be called twice'
+    self._teardown_called = True
+
+
+@with_input_types(Tuple[None, str])
+@with_output_types(Tuple[int, str])
+class IndexAssigningDoFn(beam.DoFn):
+  state_param = beam.DoFn.StateParam(
+      userstate.CombiningValueStateSpec(
+          'index', beam.coders.VarIntCoder(), CallSequenceEnforcingCombineFn()))
+
+  def process(self, element, state=state_param):
+    _, value = element
+    current_index = state.read()
+    yield current_index, value
+    state.add(1)
+
+
+def run_combine(pipeline, input_elements=5, lift_combiners=True):
+  # Calculate the excepted result, which is the sum of an arythmetic sequence.

Review comment:
       nit: 'arythmetic' -> 'arithmetic'

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]