This is an automated email from the ASF dual-hosted git repository. boyuanz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new cd6e54b Move ThreadsafeRestrictionTracker and RestrictionTrackerView out from iobase.py new bcc3e13 Merge pull request #10802 from boyuanzz/refactor cd6e54b is described below commit cd6e54bc19f5e69cab49d22f5044b1c869b9ec69 Author: Boyuan Zhang <boyu...@google.com> AuthorDate: Fri Feb 7 16:07:20 2020 -0800 Move ThreadsafeRestrictionTracker and RestrictionTrackerView out from iobase.py --- sdks/python/apache_beam/io/iobase.py | 128 --------------- sdks/python/apache_beam/io/iobase_test.py | 86 +--------- sdks/python/apache_beam/runners/common.py | 69 ++++---- .../runners/portability/fn_api_runner_test.py | 8 +- sdks/python/apache_beam/runners/sdf_utils.py | 176 +++++++++++++++++++++ sdks/python/apache_beam/runners/sdf_utils_test.py | 114 +++++++++++++ .../apache_beam/runners/worker/bundle_processor.py | 16 +- 7 files changed, 337 insertions(+), 260 deletions(-) diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py index 10d2933..1302c38 100644 --- a/sdks/python/apache_beam/io/iobase.py +++ b/sdks/python/apache_beam/io/iobase.py @@ -37,7 +37,6 @@ from __future__ import division import logging import math import random -import threading import uuid from builtins import object from builtins import range @@ -65,9 +64,7 @@ from apache_beam.utils import urns from apache_beam.utils.windowed_value import WindowedValue if TYPE_CHECKING: - from apache_beam.io import restriction_trackers from apache_beam.runners.pipeline_context import PipelineContext - from apache_beam.utils.timestamp import Timestamp __all__ = [ 'BoundedSource', @@ -1246,131 +1243,6 @@ class RestrictionTracker(object): raise NotImplementedError -class ThreadsafeRestrictionTracker(object): - """A thread-safe wrapper which wraps a `RestritionTracker`. - - This wrapper guarantees synchronization of modifying restrictions across - multi-thread. 
- """ - def __init__(self, restriction_tracker): - # type: (RestrictionTracker) -> None - if not isinstance(restriction_tracker, RestrictionTracker): - raise ValueError( - 'Initialize ThreadsafeRestrictionTracker requires' - 'RestrictionTracker.') - self._restriction_tracker = restriction_tracker - # Records an absolute timestamp when defer_remainder is called. - self._deferred_timestamp = None - self._lock = threading.RLock() - self._deferred_residual = None - self._deferred_watermark = None - - def current_restriction(self): - with self._lock: - return self._restriction_tracker.current_restriction() - - def try_claim(self, position): - with self._lock: - return self._restriction_tracker.try_claim(position) - - def defer_remainder(self, deferred_time=None): - """Performs self-checkpoint on current processing restriction with an - expected resuming time. - - Self-checkpoint could happen during processing elements. When executing an - DoFn.process(), you may want to stop processing an element and resuming - later if current element has been processed quit a long time or you also - want to have some outputs from other elements. ``defer_remainder()`` can be - called on per element if needed. - - Args: - deferred_time: A relative ``timestamp.Duration`` that indicates the ideal - time gap between now and resuming, or an absolute ``timestamp.Timestamp`` - for resuming execution time. If the time_delay is None, the deferred work - will be executed as soon as possible. - """ - - # Record current time for calculating deferred_time later. 
- self._deferred_timestamp = timestamp.Timestamp.now() - if (deferred_time and not isinstance(deferred_time, timestamp.Duration) and - not isinstance(deferred_time, timestamp.Timestamp)): - raise ValueError( - 'The timestamp of deter_remainder() should be a ' - 'Duration or a Timestamp, or None.') - self._deferred_watermark = deferred_time - checkpoint = self.try_split(0) - if checkpoint: - _, self._deferred_residual = checkpoint - - def check_done(self): - with self._lock: - return self._restriction_tracker.check_done() - - def current_progress(self): - with self._lock: - return self._restriction_tracker.current_progress() - - def try_split(self, fraction_of_remainder): - with self._lock: - return self._restriction_tracker.try_split(fraction_of_remainder) - - def deferred_status(self): - # type: () -> Optional[Tuple[Any, Timestamp]] - - """Returns deferred work which is produced by ``defer_remainder()``. - - When there is a self-checkpoint performed, the system needs to fulfill the - DelayedBundleApplication with deferred_work for a ProcessBundleResponse. - The system calls this API to get deferred_residual with watermark together - to help the runner to schedule a future work. - - Returns: (deferred_residual, time_delay) if having any residual, else None. - """ - if self._deferred_residual: - # If _deferred_watermark is None, create Duration(0). - if not self._deferred_watermark: - self._deferred_watermark = timestamp.Duration() - # If an absolute timestamp is provided, calculate the delta between - # the absoluted time and the time deferred_status() is called. - elif isinstance(self._deferred_watermark, timestamp.Timestamp): - self._deferred_watermark = ( - self._deferred_watermark - timestamp.Timestamp.now()) - # If a Duration is provided, the deferred time should be: - # provided duration - the spent time since the defer_remainder() is - # called. 
- elif isinstance(self._deferred_watermark, timestamp.Duration): - self._deferred_watermark -= ( - timestamp.Timestamp.now() - self._deferred_timestamp) - return self._deferred_residual, self._deferred_watermark - return None - - -class RestrictionTrackerView(object): - """A DoFn view of thread-safe RestrictionTracker. - - The RestrictionTrackerView wraps a ThreadsafeRestrictionTracker and only - exposes APIs that will be called by a ``DoFn.process()``. During execution - time, the RestrictionTrackerView will be fed into the ``DoFn.process`` as a - restriction_tracker. - """ - def __init__(self, threadsafe_restriction_tracker): - if not isinstance(threadsafe_restriction_tracker, - ThreadsafeRestrictionTracker): - raise ValueError( - 'Initialize RestrictionTrackerView requires ' - 'ThreadsafeRestrictionTracker.') - self._threadsafe_restriction_tracker = threadsafe_restriction_tracker - - def current_restriction(self): - return self._threadsafe_restriction_tracker.current_restriction() - - def try_claim(self, position): - return self._threadsafe_restriction_tracker.try_claim(position) - - def defer_remainder(self, deferred_time=None): - self._threadsafe_restriction_tracker.defer_remainder(deferred_time) - - class RestrictionProgress(object): """Used to record the progress of a restriction. diff --git a/sdks/python/apache_beam/io/iobase_test.py b/sdks/python/apache_beam/io/iobase_test.py index 7460b71..04be8fa 100644 --- a/sdks/python/apache_beam/io/iobase_test.py +++ b/sdks/python/apache_beam/io/iobase_test.py @@ -15,13 +15,12 @@ # limitations under the License. 
# -"""Unit tests for the SDFRestrictionProvider module.""" +"""Unit tests for classes in iobase.py.""" # pytype: skip-file from __future__ import absolute_import -import time import unittest import mock @@ -31,9 +30,6 @@ from apache_beam.io.concat_source import ConcatSource from apache_beam.io.concat_source_test import RangeSource from apache_beam.io import iobase from apache_beam.io.iobase import SourceBundle -from apache_beam.io.restriction_trackers import OffsetRange -from apache_beam.io.restriction_trackers import OffsetRestrictionTracker -from apache_beam.utils import timestamp from apache_beam.options.pipeline_options import DebugOptions from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to @@ -227,85 +223,5 @@ class UseSdfBoundedSourcesTests(unittest.TestCase): self._run_sdf_wrapper_pipeline(RangeSource(0, 4), [0, 1, 2, 3]) -class ThreadsafeRestrictionTrackerTest(unittest.TestCase): - def test_initialization(self): - with self.assertRaises(ValueError): - iobase.ThreadsafeRestrictionTracker(RangeSource(0, 1)) - - def test_defer_remainder_with_wrong_time_type(self): - threadsafe_tracker = iobase.ThreadsafeRestrictionTracker( - OffsetRestrictionTracker(OffsetRange(0, 10))) - with self.assertRaises(ValueError): - threadsafe_tracker.defer_remainder(10) - - def test_self_checkpoint_immediately(self): - restriction_tracker = OffsetRestrictionTracker(OffsetRange(0, 10)) - threadsafe_tracker = iobase.ThreadsafeRestrictionTracker( - restriction_tracker) - threadsafe_tracker.defer_remainder() - deferred_residual, deferred_time = threadsafe_tracker.deferred_status() - expected_residual = OffsetRange(0, 10) - self.assertEqual(deferred_residual, expected_residual) - self.assertTrue(isinstance(deferred_time, timestamp.Duration)) - self.assertEqual(deferred_time, 0) - - def test_self_checkpoint_with_relative_time(self): - threadsafe_tracker = iobase.ThreadsafeRestrictionTracker( - OffsetRestrictionTracker(OffsetRange(0, 10))) - 
threadsafe_tracker.defer_remainder(timestamp.Duration(100)) - time.sleep(2) - _, deferred_time = threadsafe_tracker.deferred_status() - self.assertTrue(isinstance(deferred_time, timestamp.Duration)) - # The expectation = 100 - 2 - some_delta - self.assertTrue(deferred_time <= 98) - - def test_self_checkpoint_with_absolute_time(self): - threadsafe_tracker = iobase.ThreadsafeRestrictionTracker( - OffsetRestrictionTracker(OffsetRange(0, 10))) - now = timestamp.Timestamp.now() - schedule_time = now + timestamp.Duration(100) - self.assertTrue(isinstance(schedule_time, timestamp.Timestamp)) - threadsafe_tracker.defer_remainder(schedule_time) - time.sleep(2) - _, deferred_time = threadsafe_tracker.deferred_status() - self.assertTrue(isinstance(deferred_time, timestamp.Duration)) - # The expectation = - # schedule_time - the time when deferred_status is called - some_delta - self.assertTrue(deferred_time <= 98) - - -class RestrictionTrackerViewTest(unittest.TestCase): - def test_initialization(self): - with self.assertRaises(ValueError): - iobase.RestrictionTrackerView( - OffsetRestrictionTracker(OffsetRange(0, 10))) - - def test_api_expose(self): - threadsafe_tracker = iobase.ThreadsafeRestrictionTracker( - OffsetRestrictionTracker(OffsetRange(0, 10))) - tracker_view = iobase.RestrictionTrackerView(threadsafe_tracker) - current_restriction = tracker_view.current_restriction() - self.assertEqual(current_restriction, OffsetRange(0, 10)) - self.assertTrue(tracker_view.try_claim(0)) - tracker_view.defer_remainder() - deferred_remainder, deferred_watermark = ( - threadsafe_tracker.deferred_status()) - self.assertEqual(deferred_remainder, OffsetRange(1, 10)) - self.assertEqual(deferred_watermark, timestamp.Duration()) - - def test_non_expose_apis(self): - threadsafe_tracker = iobase.ThreadsafeRestrictionTracker( - OffsetRestrictionTracker(OffsetRange(0, 10))) - tracker_view = iobase.RestrictionTrackerView(threadsafe_tracker) - with self.assertRaises(AttributeError): - 
tracker_view.check_done() - with self.assertRaises(AttributeError): - tracker_view.current_progress() - with self.assertRaises(AttributeError): - tracker_view.try_split() - with self.assertRaises(AttributeError): - tracker_view.deferred_status() - - if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index 3651b71..1e17690 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -44,6 +44,10 @@ from past.builtins import unicode from apache_beam.internal import util from apache_beam.options.value_provider import RuntimeValueProvider from apache_beam.pvalue import TaggedOutput +from apache_beam.runners.sdf_utils import RestrictionTrackerView +from apache_beam.runners.sdf_utils import SplitResultPrimary +from apache_beam.runners.sdf_utils import SplitResultResidual +from apache_beam.runners.sdf_utils import ThreadsafeRestrictionTracker from apache_beam.transforms import DoFn from apache_beam.transforms import core from apache_beam.transforms import userstate @@ -57,13 +61,9 @@ from apache_beam.utils.timestamp import Timestamp from apache_beam.utils.windowed_value import WindowedValue if TYPE_CHECKING: - from apache_beam.io import iobase from apache_beam.transforms import sideinputs from apache_beam.transforms.core import TimerSpec -SplitResultType = Tuple[Tuple[WindowedValue, Optional[Timestamp]], - Optional[Timestamp]] - class NameContext(object): """Holds the name information for a step.""" @@ -414,7 +414,7 @@ class DoFnInvoker(object): def invoke_process(self, windowed_value, # type: WindowedValue - restriction_tracker=None, # type: Optional[iobase.RestrictionTracker] + restriction_tracker=None, # type: Optional[RestrictionTracker] additional_args=None, additional_kwargs=None ): @@ -499,7 +499,7 @@ class SimpleInvoker(DoFnInvoker): def invoke_process(self, windowed_value, # type: WindowedValue - restriction_tracker=None, # type: 
Optional[iobase.RestrictionTracker] + restriction_tracker=None, # type: Optional[RestrictionTracker] additional_args=None, additional_kwargs=None ): @@ -536,7 +536,7 @@ class PerWindowInvoker(DoFnInvoker): self.watermark_estimator_param = ( self.signature.process_method.watermark_estimator_arg_name if self.watermark_estimator else None) - self.threadsafe_restriction_tracker = None # type: Optional[iobase.ThreadsafeRestrictionTracker] + self.threadsafe_restriction_tracker = None # type: Optional[ThreadsafeRestrictionTracker] self.current_windowed_value = None # type: Optional[WindowedValue] self.bundle_finalizer_param = bundle_finalizer_param self.is_key_param_required = False @@ -649,11 +649,10 @@ class PerWindowInvoker(DoFnInvoker): raise ValueError( 'A RestrictionTracker %r was provided but DoFn does not have a ' 'RestrictionTrackerParam defined' % restriction_tracker) - from apache_beam.io import iobase - self.threadsafe_restriction_tracker = iobase.ThreadsafeRestrictionTracker( + self.threadsafe_restriction_tracker = ThreadsafeRestrictionTracker( restriction_tracker) additional_kwargs[restriction_tracker_param] = ( - iobase.RestrictionTrackerView(self.threadsafe_restriction_tracker)) + RestrictionTrackerView(self.threadsafe_restriction_tracker)) if self.watermark_estimator: # The watermark estimator needs to be reset for every element. @@ -685,7 +684,7 @@ class PerWindowInvoker(DoFnInvoker): additional_args, additional_kwargs, ): - # type: (...) -> Optional[SplitResultType] + # type: (...) -> Optional[SplitResultResidual] if self.has_windowed_inputs: window, = windowed_value.windows side_inputs = [si[window] for si in self.side_inputs] @@ -766,22 +765,23 @@ class PerWindowInvoker(DoFnInvoker): # ProcessSizedElementAndRestriction. 
self.threadsafe_restriction_tracker.check_done() deferred_status = self.threadsafe_restriction_tracker.deferred_status() - output_watermark = None + current_watermark = None if self.watermark_estimator: - output_watermark = self.watermark_estimator.current_watermark() + current_watermark = self.watermark_estimator.current_watermark() if deferred_status: - deferred_restriction, deferred_watermark = deferred_status + deferred_restriction, deferred_timestamp = deferred_status element = windowed_value.value size = self.signature.get_restriction_provider().restriction_size( element, deferred_restriction) - return (( - windowed_value.with_value(((element, deferred_restriction), size)), - output_watermark), - deferred_watermark) + residual_value = ((element, deferred_restriction), size) + return SplitResultResidual( + residual_value=windowed_value.with_value(residual_value), + current_watermark=current_watermark, + deferred_timestamp=deferred_timestamp) return None def try_split(self, fraction): - # type: (...) -> Optional[Tuple[SplitResultType, SplitResultType]] + # type: (...) -> Optional[Tuple[SplitResultPrimary, SplitResultResidual]] if self.threadsafe_restriction_tracker and self.current_windowed_value: # Temporary workaround for [BEAM-7473]: get current_watermark before # split, in case watermark gets advanced before getting split results. 
@@ -797,20 +797,21 @@ class PerWindowInvoker(DoFnInvoker): restriction_provider = self.signature.get_restriction_provider() primary_size = restriction_provider.restriction_size(element, primary) residual_size = restriction_provider.restriction_size(element, residual) - return ((( - self.current_windowed_value.with_value( - ((element, primary), primary_size)), - None), - None), - (( - self.current_windowed_value.with_value( - ((element, residual), residual_size)), - current_watermark), - None)) + primary_value = ((element, primary), primary_size) + residual_value = ((element, residual), residual_size) + return ( + SplitResultPrimary( + primary_value=self.current_windowed_value.with_value( + primary_value)), + SplitResultResidual( + residual_value=self.current_windowed_value.with_value( + residual_value), + current_watermark=current_watermark, + deferred_timestamp=None)) return None def current_element_progress(self): - # type: () -> Optional[iobase.RestrictionProgress] + # type: () -> Optional[RestrictionProgress] restriction_tracker = self.threadsafe_restriction_tracker if restriction_tracker: return restriction_tracker.current_progress() @@ -900,7 +901,7 @@ class DoFnRunner: bundle_finalizer_param=self.bundle_finalizer_param) def process(self, windowed_value): - # type: (WindowedValue) -> Optional[SplitResultType] + # type: (WindowedValue) -> Optional[SplitResultResidual] try: return self.do_fn_invoker.invoke_process(windowed_value) except BaseException as exn: @@ -908,7 +909,7 @@ class DoFnRunner: return None def process_with_sized_restriction(self, windowed_value): - # type: (WindowedValue) -> Optional[SplitResultType] + # type: (WindowedValue) -> Optional[SplitResultResidual] (element, restriction), _ = windowed_value.value return self.do_fn_invoker.invoke_process( windowed_value.with_value(element), @@ -916,12 +917,12 @@ class DoFnRunner: restriction)) def try_split(self, fraction): - # type: (...) 
-> Optional[Tuple[SplitResultType, SplitResultType]] + # type: (...) -> Optional[Tuple[SplitResultPrimary, SplitResultResidual]] assert isinstance(self.do_fn_invoker, PerWindowInvoker) return self.do_fn_invoker.try_split(fraction) def current_element_progress(self): - # type: () -> Optional[iobase.RestrictionProgress] + # type: () -> Optional[RestrictionProgress] assert isinstance(self.do_fn_invoker, PerWindowInvoker) return self.do_fn_invoker.current_element_progress() diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py index 4e6168e..4e63d67 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py @@ -45,7 +45,6 @@ from tenacity import retry from tenacity import stop_after_attempt import apache_beam as beam -from apache_beam.io import iobase from apache_beam.io import restriction_trackers from apache_beam.metrics import monitoring_infos from apache_beam.metrics.execution import MetricKey @@ -53,6 +52,7 @@ from apache_beam.metrics.metricbase import MetricName from apache_beam.options.pipeline_options import DebugOptions from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.runners.portability import fn_api_runner +from apache_beam.runners.sdf_utils import RestrictionTrackerView from apache_beam.runners.worker import data_plane from apache_beam.runners.worker import sdk_worker from apache_beam.runners.worker import statesampler @@ -455,7 +455,7 @@ class FnApiRunnerTest(unittest.TestCase): element, restriction_tracker=beam.DoFn.RestrictionParam( ExpandStringsProvider())): - assert isinstance(restriction_tracker, iobase.RestrictionTrackerView) + assert isinstance(restriction_tracker, RestrictionTrackerView) cur = restriction_tracker.current_restriction().start while restriction_tracker.try_claim(cur): yield element[cur] @@ -473,7 +473,7 @@ class 
FnApiRunnerTest(unittest.TestCase): element, restriction_tracker=beam.DoFn.RestrictionParam( ExpandStringsProvider())): - assert isinstance(restriction_tracker, iobase.RestrictionTrackerView) + assert isinstance(restriction_tracker, RestrictionTrackerView) cur = restriction_tracker.current_restriction().start while restriction_tracker.try_claim(cur): yield element[cur] @@ -520,7 +520,7 @@ class FnApiRunnerTest(unittest.TestCase): element, restriction_tracker=beam.DoFn.RestrictionParam( ExpandStringsProvider())): - assert isinstance(restriction_tracker, iobase.RestrictionTrackerView) + assert isinstance(restriction_tracker, RestrictionTrackerView) cur = restriction_tracker.current_restriction().start while restriction_tracker.try_claim(cur): counter.inc() diff --git a/sdks/python/apache_beam/runners/sdf_utils.py b/sdks/python/apache_beam/runners/sdf_utils.py new file mode 100644 index 0000000..3e882c5 --- /dev/null +++ b/sdks/python/apache_beam/runners/sdf_utils.py @@ -0,0 +1,176 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# pytype: skip-file + +"""Common utility class to help SDK harness to execute an SDF. 
""" from __future__ import absolute_import +from __future__ import division + +import logging +import threading +from builtins import object +from typing import TYPE_CHECKING +from typing import Any +from typing import NamedTuple +from typing import Optional +from typing import Tuple + +from apache_beam.utils import timestamp +from apache_beam.utils.windowed_value import WindowedValue + +if TYPE_CHECKING: + from apache_beam.io.iobase import RestrictionTracker + from apache_beam.utils.timestamp import Timestamp + +_LOGGER = logging.getLogger(__name__) + +SplitResultPrimary = NamedTuple( + 'SplitResultPrimary', [('primary_value', WindowedValue)]) + +SplitResultResidual = NamedTuple( + 'SplitResultResidual', + [('residual_value', WindowedValue), + ('current_watermark', timestamp.Timestamp), + ('deferred_timestamp', timestamp.Duration)]) + + +class ThreadsafeRestrictionTracker(object): + """A thread-safe wrapper which wraps a `RestrictionTracker`. + + This wrapper synchronizes access to the wrapped restriction tracker across + multiple threads. + """ + def __init__(self, restriction_tracker): + # type: (RestrictionTracker) -> None + from apache_beam.io.iobase import RestrictionTracker + if not isinstance(restriction_tracker, RestrictionTracker): + raise ValueError( + 'Initialize ThreadsafeRestrictionTracker requires' + 'RestrictionTracker.') + self._restriction_tracker = restriction_tracker + # Records an absolute timestamp when defer_remainder is called. 
+ self._deferred_timestamp = None + self._lock = threading.RLock() + self._deferred_residual = None + self._deferred_watermark = None + + def current_restriction(self): + with self._lock: + return self._restriction_tracker.current_restriction() + + def try_claim(self, position): + with self._lock: + return self._restriction_tracker.try_claim(position) + + def defer_remainder(self, deferred_time=None): + """Performs a self-checkpoint on the restriction currently being processed, + with an expected resuming time.
+ + Self-checkpoint could happen while processing elements. When executing a + DoFn.process(), you may want to stop processing an element and resume + later if the current element has been processed for quite a long time, or + if you want to get some outputs from other elements. ``defer_remainder()`` + can be called per element if needed. + + Args: + deferred_time: A relative ``timestamp.Duration`` that indicates the ideal + time gap between now and resuming, or an absolute + ``timestamp.Timestamp`` for resuming execution time. If deferred_time + is None, the deferred work will be executed as soon as possible. + """ + + # Record current time for calculating deferred_time later. + self._deferred_timestamp = timestamp.Timestamp.now() + if (deferred_time and not isinstance(deferred_time, timestamp.Duration) and + not isinstance(deferred_time, timestamp.Timestamp)): + raise ValueError( + 'The timestamp of deter_remainder() should be a ' + 'Duration or a Timestamp, or None.') + self._deferred_watermark = deferred_time + checkpoint = self.try_split(0) + if checkpoint: + _, self._deferred_residual = checkpoint + + def check_done(self): + with self._lock: + return self._restriction_tracker.check_done() + + def current_progress(self): + with self._lock: + return self._restriction_tracker.current_progress() + + def try_split(self, fraction_of_remainder): + with self._lock: + return self._restriction_tracker.try_split(fraction_of_remainder) + + def deferred_status(self): + # type: () -> Optional[Tuple[Any, Timestamp]] + + """Returns deferred work which is produced by ``defer_remainder()``. + + When a self-checkpoint has been performed, the system needs to fill in the + DelayedBundleApplication with deferred_work for a ProcessBundleResponse. + The system calls this API to get deferred_residual together with its time + delay to help the runner schedule future work. 
+ + Returns: (deferred_residual, time_delay) if there is any residual, else None.
+ """ + if self._deferred_residual: + # If _deferred_watermark is None, create Duration(0). + if not self._deferred_watermark: + self._deferred_watermark = timestamp.Duration() + # If an absolute timestamp is provided, calculate the delta between + # the absolute time and the time deferred_status() is called. + elif isinstance(self._deferred_watermark, timestamp.Timestamp): + self._deferred_watermark = ( + self._deferred_watermark - timestamp.Timestamp.now()) + # If a Duration is provided, the deferred time should be: + # provided duration - the time elapsed since defer_remainder() was + # called. + elif isinstance(self._deferred_watermark, timestamp.Duration): + self._deferred_watermark -= ( + timestamp.Timestamp.now() - self._deferred_timestamp) + return self._deferred_residual, self._deferred_watermark + return None + + +class RestrictionTrackerView(object): + """A DoFn view of a thread-safe RestrictionTracker. + + The RestrictionTrackerView wraps a ThreadsafeRestrictionTracker and only + exposes APIs that will be called by a ``DoFn.process()``. At execution + time, the RestrictionTrackerView is passed to ``DoFn.process()`` as its + restriction_tracker.
+ """ + def __init__(self, threadsafe_restriction_tracker): + if not isinstance(threadsafe_restriction_tracker, + ThreadsafeRestrictionTracker): + raise ValueError( + 'Initialize RestrictionTrackerView requires ' + 'ThreadsafeRestrictionTracker.') + self._threadsafe_restriction_tracker = threadsafe_restriction_tracker + + def current_restriction(self): + return self._threadsafe_restriction_tracker.current_restriction() + + def try_claim(self, position): + return self._threadsafe_restriction_tracker.try_claim(position) + + def defer_remainder(self, deferred_time=None): + self._threadsafe_restriction_tracker.defer_remainder(deferred_time) diff --git a/sdks/python/apache_beam/runners/sdf_utils_test.py b/sdks/python/apache_beam/runners/sdf_utils_test.py new file mode 100644 index 0000000..30465ff --- /dev/null +++ b/sdks/python/apache_beam/runners/sdf_utils_test.py @@ -0,0 +1,114 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +"""Unit tests for classes in sdf_utils.py.""" + +# pytype: skip-file + +from __future__ import absolute_import + +import time +import unittest + +from apache_beam.io.concat_source_test import RangeSource +from apache_beam.io.restriction_trackers import OffsetRange +from apache_beam.io.restriction_trackers import OffsetRestrictionTracker +from apache_beam.runners.sdf_utils import RestrictionTrackerView +from apache_beam.runners.sdf_utils import ThreadsafeRestrictionTracker +from apache_beam.utils import timestamp + + +class ThreadsafeRestrictionTrackerTest(unittest.TestCase): + def test_initialization(self): + with self.assertRaises(ValueError): + ThreadsafeRestrictionTracker(RangeSource(0, 1)) + + def test_defer_remainder_with_wrong_time_type(self): + threadsafe_tracker = ThreadsafeRestrictionTracker( + OffsetRestrictionTracker(OffsetRange(0, 10))) + with self.assertRaises(ValueError): + threadsafe_tracker.defer_remainder(10) + + def test_self_checkpoint_immediately(self): + restriction_tracker = OffsetRestrictionTracker(OffsetRange(0, 10)) + threadsafe_tracker = ThreadsafeRestrictionTracker(restriction_tracker) + threadsafe_tracker.defer_remainder() + deferred_residual, deferred_time = threadsafe_tracker.deferred_status() + expected_residual = OffsetRange(0, 10) + self.assertEqual(deferred_residual, expected_residual) + self.assertTrue(isinstance(deferred_time, timestamp.Duration)) + self.assertEqual(deferred_time, 0) + + def test_self_checkpoint_with_relative_time(self): + threadsafe_tracker = ThreadsafeRestrictionTracker( + OffsetRestrictionTracker(OffsetRange(0, 10))) + threadsafe_tracker.defer_remainder(timestamp.Duration(100)) + time.sleep(2) + _, deferred_time = threadsafe_tracker.deferred_status() + self.assertTrue(isinstance(deferred_time, timestamp.Duration)) + # The expectation = 100 - 2 - some_delta + self.assertTrue(deferred_time <= 98) + + def test_self_checkpoint_with_absolute_time(self): + threadsafe_tracker = ThreadsafeRestrictionTracker( + 
OffsetRestrictionTracker(OffsetRange(0, 10))) + now = timestamp.Timestamp.now() + schedule_time = now + timestamp.Duration(100) + self.assertTrue(isinstance(schedule_time, timestamp.Timestamp)) + threadsafe_tracker.defer_remainder(schedule_time) + time.sleep(2) + _, deferred_time = threadsafe_tracker.deferred_status() + self.assertTrue(isinstance(deferred_time, timestamp.Duration)) + # The expectation = + # schedule_time - the time when deferred_status is called - some_delta + self.assertTrue(deferred_time <= 98) + + +class RestrictionTrackerViewTest(unittest.TestCase): + def test_initialization(self): + with self.assertRaises(ValueError): + RestrictionTrackerView(OffsetRestrictionTracker(OffsetRange(0, 10))) + + def test_api_expose(self): + threadsafe_tracker = ThreadsafeRestrictionTracker( + OffsetRestrictionTracker(OffsetRange(0, 10))) + tracker_view = RestrictionTrackerView(threadsafe_tracker) + current_restriction = tracker_view.current_restriction() + self.assertEqual(current_restriction, OffsetRange(0, 10)) + self.assertTrue(tracker_view.try_claim(0)) + tracker_view.defer_remainder() + deferred_remainder, deferred_watermark = ( + threadsafe_tracker.deferred_status()) + self.assertEqual(deferred_remainder, OffsetRange(1, 10)) + self.assertEqual(deferred_watermark, timestamp.Duration()) + + def test_non_expose_apis(self): + threadsafe_tracker = ThreadsafeRestrictionTracker( + OffsetRestrictionTracker(OffsetRange(0, 10))) + tracker_view = RestrictionTrackerView(threadsafe_tracker) + with self.assertRaises(AttributeError): + tracker_view.check_done() + with self.assertRaises(AttributeError): + tracker_view.current_progress() + with self.assertRaises(AttributeError): + tracker_view.try_split() + with self.assertRaises(AttributeError): + tracker_view.deferred_status() + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 
df383d1..f44b5cf 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -906,26 +906,24 @@ class BundleProcessor(object): # type: (...) -> beam_fn_api_pb2.DelayedBundleApplication assert op.input_info is not None # TODO(SDF): For non-root nodes, need main_input_coder + residual_coder. - ((element_and_restriction, output_watermark), - deferred_watermark) = deferred_remainder - if deferred_watermark: - assert isinstance(deferred_watermark, timestamp.Duration) + (element_and_restriction, current_watermark, deferred_timestamp) = ( + deferred_remainder) + if deferred_timestamp: + assert isinstance(deferred_timestamp, timestamp.Duration) proto_deferred_watermark = duration_pb2.Duration() - proto_deferred_watermark.FromMicroseconds(deferred_watermark.micros) + proto_deferred_watermark.FromMicroseconds(deferred_timestamp.micros) else: proto_deferred_watermark = None return beam_fn_api_pb2.DelayedBundleApplication( requested_time_delay=proto_deferred_watermark, application=self.construct_bundle_application( - op, output_watermark, element_and_restriction)) + op, current_watermark, element_and_restriction)) def bundle_application(self, op, # type: operations.DoOperation primary # type: common.SplitResultType ): - ((element_and_restriction, output_watermark), _) = primary - return self.construct_bundle_application( - op, output_watermark, element_and_restriction) + return self.construct_bundle_application(op, None, primary.primary_value) def construct_bundle_application(self, op, output_watermark, element): transform_id, main_input_tag, main_input_coder, outputs = op.input_info