diff --git a/sdks/python/apache_beam/runners/worker/statesampler_test.py b/sdks/python/apache_beam/runners/worker/statesampler_test.py index dfb4bc10a913..1ff0c5ca6f5d 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_test.py +++ b/sdks/python/apache_beam/runners/worker/statesampler_test.py @@ -24,6 +24,9 @@ import unittest from builtins import range +from tenacity import retry +from tenacity import stop_after_attempt + from apache_beam.runners.worker import statesampler from apache_beam.utils.counters import CounterFactory from apache_beam.utils.counters import CounterName @@ -31,33 +34,41 @@ class StateSamplerTest(unittest.TestCase): + # Due to somewhat non-deterministic nature of state sampling and sleep, + # this test is flaky when state duration is low. + # Since increasing state duration significantly would also slow down + # the test suite, we are retrying twice on failure as a mitigation. + @retry(reraise=True, stop=stop_after_attempt(3)) def test_basic_sampler(self): # Set up state sampler. counter_factory = CounterFactory() sampler = statesampler.StateSampler('basic', counter_factory, sampling_period_ms=1) + # Duration of the fastest state. Total test duration is 6 times longer. + state_duration_ms = 1000 + margin_of_error = 0.25 # Run basic workload transitioning between 3 states.
sampler.start() with sampler.scoped_state('step1', 'statea'): - time.sleep(0.1) + time.sleep(state_duration_ms / 1000) self.assertEqual( sampler.current_state().name, CounterName( 'statea-msecs', step_name='step1', stage_name='basic')) with sampler.scoped_state('step1', 'stateb'): - time.sleep(0.2 / 2) + time.sleep(state_duration_ms / 1000) self.assertEqual( sampler.current_state().name, CounterName( 'stateb-msecs', step_name='step1', stage_name='basic')) with sampler.scoped_state('step1', 'statec'): - time.sleep(0.3) + time.sleep(3 * state_duration_ms / 1000) self.assertEqual( sampler.current_state().name, CounterName( 'statec-msecs', step_name='step1', stage_name='basic')) - time.sleep(0.2 / 2) + time.sleep(state_duration_ms / 1000) sampler.stop() sampler.commit_counters() @@ -68,9 +79,12 @@ def test_basic_sampler(self): # Test that sampled state timings are close to their expected values. expected_counter_values = { - CounterName('statea-msecs', step_name='step1', stage_name='basic'): 100, - CounterName('stateb-msecs', step_name='step1', stage_name='basic'): 200, - CounterName('statec-msecs', step_name='step1', stage_name='basic'): 300, + CounterName('statea-msecs', step_name='step1', stage_name='basic'): + state_duration_ms, + CounterName('stateb-msecs', step_name='step1', stage_name='basic'): + 2 * state_duration_ms, + CounterName('statec-msecs', step_name='step1', stage_name='basic'): + 3 * state_duration_ms, } for counter in counter_factory.get_counters(): self.assertIn(counter.name, expected_counter_values) @@ -78,8 +92,8 @@ def test_basic_sampler(self): actual_value = counter.value() deviation = float(abs(actual_value - expected_value)) / expected_value logging.info('Sampling deviation from expectation: %f', deviation) - self.assertGreater(actual_value, expected_value * 0.75) - self.assertLess(actual_value, expected_value * 1.25) + self.assertGreater(actual_value, expected_value * (1.0 - margin_of_error)) + self.assertLess(actual_value, expected_value * 
(1.0 + margin_of_error)) def test_sampler_transition_overhead(self): # Set up state sampler. diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 807ff09ac5fe..4e2c761185f9 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -133,10 +133,11 @@ def get_version(): REQUIRED_TEST_PACKAGES = [ 'nose>=1.3.7', + 'numpy>=1.14.3,<2', 'pandas>=0.23.4,<0.24', 'parameterized>=0.6.0,<0.7.0', - 'numpy>=1.14.3,<2', 'pyhamcrest>=1.9,<2.0', + 'tenacity>=5.0.2,<6.0', ] GCP_REQUIREMENTS = [
With regards, Apache Git Services