kamilwu commented on a change in pull request #13048: URL: https://github.com/apache/beam/pull/13048#discussion_r503336032
########## File path: sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py ########## @@ -0,0 +1,147 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""ValidatesRunner tests for CombineFn lifecycle and bundle methods.""" + +# pytype: skip-file + +import unittest +from weakref import WeakSet + +from nose.plugins.attrib import attr + +import apache_beam as beam +from apache_beam.options.pipeline_options import DebugOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.direct import direct_runner +from apache_beam.runners.portability import fn_api_runner +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to +from apache_beam.transforms import trigger +from apache_beam.transforms import window + + +class CallSequenceEnforcingCombineFn(beam.CombineFn): + instances = WeakSet() + + def __init__(self): + super(CallSequenceEnforcingCombineFn, self).__init__() + self._setup_called = False + self._accumulators_created = 0 + self._teardown_called = False + + def setup(self): + assert not self._setup_called, 'setup should not be called twice' + assert not self._teardown_called, 
'setup should be called before teardown' + # Keep track of instances so that we can check if teardown is called + # properly after pipeline execution. + self.instances.add(self) + self._setup_called = True + + def create_accumulator(self): + assert self._setup_called, 'setup should have been called' + assert not self._teardown_called, 'teardown should not have been called' + self._accumulators_created += 1 + return 0 + + def add_input(self, mutable_accumulator, element): + assert self._setup_called, 'setup should have been called' + assert self._accumulators_created > 0, \ + 'create_accumulator should have been called' + assert not self._teardown_called, 'teardown should not have been called' + mutable_accumulator += element + return mutable_accumulator + + def add_inputs(self, mutable_accumulator, elements): + return self.add_input(mutable_accumulator, sum(elements)) + + def merge_accumulators(self, accumulators): + assert self._setup_called, 'setup should have been called' + assert not self._teardown_called, 'teardown should not have been called' + return sum(accumulators) + + def extract_output(self, accumulator): + assert self._setup_called, 'setup should have been called' + assert not self._teardown_called, 'teardown should not have been called' + return accumulator + + def teardown(self): + assert self._setup_called, 'setup should have been called' + assert not self._teardown_called, 'teardown should not be called twice' + self._teardown_called = True + + +class BaseCombineFnLifecycleTest(unittest.TestCase): + def start(self, pipeline, lift_combiners=True): + with pipeline as p: + pcoll = p | 'Start' >> beam.Create(range(5)) + + # Certain triggers, such as AfterCount, are incompatible with combiner + # lifting. We can use that fact to prevent combiners from being lifted. 
+ if not lift_combiners: + pcoll |= beam.WindowInto( + window.GlobalWindows(), + trigger=trigger.AfterCount(5), + accumulation_mode=trigger.AccumulationMode.DISCARDING) + + pcoll |= 'Do' >> beam.CombineGlobally(CallSequenceEnforcingCombineFn()) + assert_that(pcoll, equal_to([10])) + + # Ensure that _teardown_called equals True for all CombineFns. + for instance in CallSequenceEnforcingCombineFn.instances: + self.assertTrue(instance._teardown_called) + + +@attr('ValidatesRunner') +class CombineFnLifecycleTest(BaseCombineFnLifecycleTest): + def setUp(self): + self.pipeline = TestPipeline(is_integration_test=True) + options = self.pipeline.get_pipeline_options() + standard_options = options.view_as(StandardOptions) + experiments = options.view_as(DebugOptions).experiments or [] + + if 'DataflowRunner' in standard_options.runner and \ + not standard_options.streaming and \ + 'beam_fn_api' not in experiments and 'use_runner_v2' not in experiments: + self.skipTest( + 'Non-portable Dataflow batch worker does not support ' Review comment: Yes, it's probably a good idea. Should we raise an exception and exit the program abnormally if user-provided setup and teardown are detected, or just inform the user that those methods won't be called? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org
