This is an automated email from the ASF dual-hosted git repository.
boyuanz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new cd6e54b Move ThreadsafeRestrictionTracker and RestrictionTrackerView
out from iobase.py
new bcc3e13 Merge pull request #10802 from boyuanzz/refactor
cd6e54b is described below
commit cd6e54bc19f5e69cab49d22f5044b1c869b9ec69
Author: Boyuan Zhang <[email protected]>
AuthorDate: Fri Feb 7 16:07:20 2020 -0800
Move ThreadsafeRestrictionTracker and RestrictionTrackerView out from
iobase.py
---
sdks/python/apache_beam/io/iobase.py | 128 ---------------
sdks/python/apache_beam/io/iobase_test.py | 86 +---------
sdks/python/apache_beam/runners/common.py | 69 ++++----
.../runners/portability/fn_api_runner_test.py | 8 +-
sdks/python/apache_beam/runners/sdf_utils.py | 176 +++++++++++++++++++++
sdks/python/apache_beam/runners/sdf_utils_test.py | 114 +++++++++++++
.../apache_beam/runners/worker/bundle_processor.py | 16 +-
7 files changed, 337 insertions(+), 260 deletions(-)
diff --git a/sdks/python/apache_beam/io/iobase.py
b/sdks/python/apache_beam/io/iobase.py
index 10d2933..1302c38 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -37,7 +37,6 @@ from __future__ import division
import logging
import math
import random
-import threading
import uuid
from builtins import object
from builtins import range
@@ -65,9 +64,7 @@ from apache_beam.utils import urns
from apache_beam.utils.windowed_value import WindowedValue
if TYPE_CHECKING:
- from apache_beam.io import restriction_trackers
from apache_beam.runners.pipeline_context import PipelineContext
- from apache_beam.utils.timestamp import Timestamp
__all__ = [
'BoundedSource',
@@ -1246,131 +1243,6 @@ class RestrictionTracker(object):
raise NotImplementedError
-class ThreadsafeRestrictionTracker(object):
- """A thread-safe wrapper which wraps a `RestritionTracker`.
-
- This wrapper guarantees synchronization of modifying restrictions across
- multi-thread.
- """
- def __init__(self, restriction_tracker):
- # type: (RestrictionTracker) -> None
- if not isinstance(restriction_tracker, RestrictionTracker):
- raise ValueError(
- 'Initialize ThreadsafeRestrictionTracker requires'
- 'RestrictionTracker.')
- self._restriction_tracker = restriction_tracker
- # Records an absolute timestamp when defer_remainder is called.
- self._deferred_timestamp = None
- self._lock = threading.RLock()
- self._deferred_residual = None
- self._deferred_watermark = None
-
- def current_restriction(self):
- with self._lock:
- return self._restriction_tracker.current_restriction()
-
- def try_claim(self, position):
- with self._lock:
- return self._restriction_tracker.try_claim(position)
-
- def defer_remainder(self, deferred_time=None):
- """Performs self-checkpoint on current processing restriction with an
- expected resuming time.
-
- Self-checkpoint could happen during processing elements. When executing an
- DoFn.process(), you may want to stop processing an element and resuming
- later if current element has been processed quit a long time or you also
- want to have some outputs from other elements. ``defer_remainder()`` can be
- called on per element if needed.
-
- Args:
- deferred_time: A relative ``timestamp.Duration`` that indicates the ideal
- time gap between now and resuming, or an absolute ``timestamp.Timestamp``
- for resuming execution time. If the time_delay is None, the deferred work
- will be executed as soon as possible.
- """
-
- # Record current time for calculating deferred_time later.
- self._deferred_timestamp = timestamp.Timestamp.now()
- if (deferred_time and not isinstance(deferred_time, timestamp.Duration) and
- not isinstance(deferred_time, timestamp.Timestamp)):
- raise ValueError(
- 'The timestamp of deter_remainder() should be a '
- 'Duration or a Timestamp, or None.')
- self._deferred_watermark = deferred_time
- checkpoint = self.try_split(0)
- if checkpoint:
- _, self._deferred_residual = checkpoint
-
- def check_done(self):
- with self._lock:
- return self._restriction_tracker.check_done()
-
- def current_progress(self):
- with self._lock:
- return self._restriction_tracker.current_progress()
-
- def try_split(self, fraction_of_remainder):
- with self._lock:
- return self._restriction_tracker.try_split(fraction_of_remainder)
-
- def deferred_status(self):
- # type: () -> Optional[Tuple[Any, Timestamp]]
-
- """Returns deferred work which is produced by ``defer_remainder()``.
-
- When there is a self-checkpoint performed, the system needs to fulfill the
- DelayedBundleApplication with deferred_work for a ProcessBundleResponse.
- The system calls this API to get deferred_residual with watermark together
- to help the runner to schedule a future work.
-
- Returns: (deferred_residual, time_delay) if having any residual, else None.
- """
- if self._deferred_residual:
- # If _deferred_watermark is None, create Duration(0).
- if not self._deferred_watermark:
- self._deferred_watermark = timestamp.Duration()
- # If an absolute timestamp is provided, calculate the delta between
- # the absoluted time and the time deferred_status() is called.
- elif isinstance(self._deferred_watermark, timestamp.Timestamp):
- self._deferred_watermark = (
- self._deferred_watermark - timestamp.Timestamp.now())
- # If a Duration is provided, the deferred time should be:
- # provided duration - the spent time since the defer_remainder() is
- # called.
- elif isinstance(self._deferred_watermark, timestamp.Duration):
- self._deferred_watermark -= (
- timestamp.Timestamp.now() - self._deferred_timestamp)
- return self._deferred_residual, self._deferred_watermark
- return None
-
-
-class RestrictionTrackerView(object):
- """A DoFn view of thread-safe RestrictionTracker.
-
- The RestrictionTrackerView wraps a ThreadsafeRestrictionTracker and only
- exposes APIs that will be called by a ``DoFn.process()``. During execution
- time, the RestrictionTrackerView will be fed into the ``DoFn.process`` as a
- restriction_tracker.
- """
- def __init__(self, threadsafe_restriction_tracker):
- if not isinstance(threadsafe_restriction_tracker,
- ThreadsafeRestrictionTracker):
- raise ValueError(
- 'Initialize RestrictionTrackerView requires '
- 'ThreadsafeRestrictionTracker.')
- self._threadsafe_restriction_tracker = threadsafe_restriction_tracker
-
- def current_restriction(self):
- return self._threadsafe_restriction_tracker.current_restriction()
-
- def try_claim(self, position):
- return self._threadsafe_restriction_tracker.try_claim(position)
-
- def defer_remainder(self, deferred_time=None):
- self._threadsafe_restriction_tracker.defer_remainder(deferred_time)
-
-
class RestrictionProgress(object):
"""Used to record the progress of a restriction.
diff --git a/sdks/python/apache_beam/io/iobase_test.py
b/sdks/python/apache_beam/io/iobase_test.py
index 7460b71..04be8fa 100644
--- a/sdks/python/apache_beam/io/iobase_test.py
+++ b/sdks/python/apache_beam/io/iobase_test.py
@@ -15,13 +15,12 @@
# limitations under the License.
#
-"""Unit tests for the SDFRestrictionProvider module."""
+"""Unit tests for classes in iobase.py."""
# pytype: skip-file
from __future__ import absolute_import
-import time
import unittest
import mock
@@ -31,9 +30,6 @@ from apache_beam.io.concat_source import ConcatSource
from apache_beam.io.concat_source_test import RangeSource
from apache_beam.io import iobase
from apache_beam.io.iobase import SourceBundle
-from apache_beam.io.restriction_trackers import OffsetRange
-from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
-from apache_beam.utils import timestamp
from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
@@ -227,85 +223,5 @@ class UseSdfBoundedSourcesTests(unittest.TestCase):
self._run_sdf_wrapper_pipeline(RangeSource(0, 4), [0, 1, 2, 3])
-class ThreadsafeRestrictionTrackerTest(unittest.TestCase):
- def test_initialization(self):
- with self.assertRaises(ValueError):
- iobase.ThreadsafeRestrictionTracker(RangeSource(0, 1))
-
- def test_defer_remainder_with_wrong_time_type(self):
- threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
- OffsetRestrictionTracker(OffsetRange(0, 10)))
- with self.assertRaises(ValueError):
- threadsafe_tracker.defer_remainder(10)
-
- def test_self_checkpoint_immediately(self):
- restriction_tracker = OffsetRestrictionTracker(OffsetRange(0, 10))
- threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
- restriction_tracker)
- threadsafe_tracker.defer_remainder()
- deferred_residual, deferred_time = threadsafe_tracker.deferred_status()
- expected_residual = OffsetRange(0, 10)
- self.assertEqual(deferred_residual, expected_residual)
- self.assertTrue(isinstance(deferred_time, timestamp.Duration))
- self.assertEqual(deferred_time, 0)
-
- def test_self_checkpoint_with_relative_time(self):
- threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
- OffsetRestrictionTracker(OffsetRange(0, 10)))
- threadsafe_tracker.defer_remainder(timestamp.Duration(100))
- time.sleep(2)
- _, deferred_time = threadsafe_tracker.deferred_status()
- self.assertTrue(isinstance(deferred_time, timestamp.Duration))
- # The expectation = 100 - 2 - some_delta
- self.assertTrue(deferred_time <= 98)
-
- def test_self_checkpoint_with_absolute_time(self):
- threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
- OffsetRestrictionTracker(OffsetRange(0, 10)))
- now = timestamp.Timestamp.now()
- schedule_time = now + timestamp.Duration(100)
- self.assertTrue(isinstance(schedule_time, timestamp.Timestamp))
- threadsafe_tracker.defer_remainder(schedule_time)
- time.sleep(2)
- _, deferred_time = threadsafe_tracker.deferred_status()
- self.assertTrue(isinstance(deferred_time, timestamp.Duration))
- # The expectation =
- # schedule_time - the time when deferred_status is called - some_delta
- self.assertTrue(deferred_time <= 98)
-
-
-class RestrictionTrackerViewTest(unittest.TestCase):
- def test_initialization(self):
- with self.assertRaises(ValueError):
- iobase.RestrictionTrackerView(
- OffsetRestrictionTracker(OffsetRange(0, 10)))
-
- def test_api_expose(self):
- threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
- OffsetRestrictionTracker(OffsetRange(0, 10)))
- tracker_view = iobase.RestrictionTrackerView(threadsafe_tracker)
- current_restriction = tracker_view.current_restriction()
- self.assertEqual(current_restriction, OffsetRange(0, 10))
- self.assertTrue(tracker_view.try_claim(0))
- tracker_view.defer_remainder()
- deferred_remainder, deferred_watermark = (
- threadsafe_tracker.deferred_status())
- self.assertEqual(deferred_remainder, OffsetRange(1, 10))
- self.assertEqual(deferred_watermark, timestamp.Duration())
-
- def test_non_expose_apis(self):
- threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
- OffsetRestrictionTracker(OffsetRange(0, 10)))
- tracker_view = iobase.RestrictionTrackerView(threadsafe_tracker)
- with self.assertRaises(AttributeError):
- tracker_view.check_done()
- with self.assertRaises(AttributeError):
- tracker_view.current_progress()
- with self.assertRaises(AttributeError):
- tracker_view.try_split()
- with self.assertRaises(AttributeError):
- tracker_view.deferred_status()
-
-
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/runners/common.py
b/sdks/python/apache_beam/runners/common.py
index 3651b71..1e17690 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -44,6 +44,10 @@ from past.builtins import unicode
from apache_beam.internal import util
from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.pvalue import TaggedOutput
+from apache_beam.runners.sdf_utils import RestrictionTrackerView
+from apache_beam.runners.sdf_utils import SplitResultPrimary
+from apache_beam.runners.sdf_utils import SplitResultResidual
+from apache_beam.runners.sdf_utils import ThreadsafeRestrictionTracker
from apache_beam.transforms import DoFn
from apache_beam.transforms import core
from apache_beam.transforms import userstate
@@ -57,13 +61,9 @@ from apache_beam.utils.timestamp import Timestamp
from apache_beam.utils.windowed_value import WindowedValue
if TYPE_CHECKING:
- from apache_beam.io import iobase
from apache_beam.transforms import sideinputs
from apache_beam.transforms.core import TimerSpec
-SplitResultType = Tuple[Tuple[WindowedValue, Optional[Timestamp]],
- Optional[Timestamp]]
-
class NameContext(object):
"""Holds the name information for a step."""
@@ -414,7 +414,7 @@ class DoFnInvoker(object):
def invoke_process(self,
windowed_value, # type: WindowedValue
- restriction_tracker=None, # type:
Optional[iobase.RestrictionTracker]
+ restriction_tracker=None, # type:
Optional[RestrictionTracker]
additional_args=None,
additional_kwargs=None
):
@@ -499,7 +499,7 @@ class SimpleInvoker(DoFnInvoker):
def invoke_process(self,
windowed_value, # type: WindowedValue
- restriction_tracker=None, # type:
Optional[iobase.RestrictionTracker]
+ restriction_tracker=None, # type:
Optional[RestrictionTracker]
additional_args=None,
additional_kwargs=None
):
@@ -536,7 +536,7 @@ class PerWindowInvoker(DoFnInvoker):
self.watermark_estimator_param = (
self.signature.process_method.watermark_estimator_arg_name
if self.watermark_estimator else None)
- self.threadsafe_restriction_tracker = None # type:
Optional[iobase.ThreadsafeRestrictionTracker]
+ self.threadsafe_restriction_tracker = None # type:
Optional[ThreadsafeRestrictionTracker]
self.current_windowed_value = None # type: Optional[WindowedValue]
self.bundle_finalizer_param = bundle_finalizer_param
self.is_key_param_required = False
@@ -649,11 +649,10 @@ class PerWindowInvoker(DoFnInvoker):
raise ValueError(
'A RestrictionTracker %r was provided but DoFn does not have a '
'RestrictionTrackerParam defined' % restriction_tracker)
- from apache_beam.io import iobase
- self.threadsafe_restriction_tracker =
iobase.ThreadsafeRestrictionTracker(
+ self.threadsafe_restriction_tracker = ThreadsafeRestrictionTracker(
restriction_tracker)
additional_kwargs[restriction_tracker_param] = (
- iobase.RestrictionTrackerView(self.threadsafe_restriction_tracker))
+ RestrictionTrackerView(self.threadsafe_restriction_tracker))
if self.watermark_estimator:
# The watermark estimator needs to be reset for every element.
@@ -685,7 +684,7 @@ class PerWindowInvoker(DoFnInvoker):
additional_args,
additional_kwargs,
):
- # type: (...) -> Optional[SplitResultType]
+ # type: (...) -> Optional[SplitResultResidual]
if self.has_windowed_inputs:
window, = windowed_value.windows
side_inputs = [si[window] for si in self.side_inputs]
@@ -766,22 +765,23 @@ class PerWindowInvoker(DoFnInvoker):
# ProcessSizedElementAndRestriction.
self.threadsafe_restriction_tracker.check_done()
deferred_status = self.threadsafe_restriction_tracker.deferred_status()
- output_watermark = None
+ current_watermark = None
if self.watermark_estimator:
- output_watermark = self.watermark_estimator.current_watermark()
+ current_watermark = self.watermark_estimator.current_watermark()
if deferred_status:
- deferred_restriction, deferred_watermark = deferred_status
+ deferred_restriction, deferred_timestamp = deferred_status
element = windowed_value.value
size = self.signature.get_restriction_provider().restriction_size(
element, deferred_restriction)
- return ((
- windowed_value.with_value(((element, deferred_restriction), size)),
- output_watermark),
- deferred_watermark)
+ residual_value = ((element, deferred_restriction), size)
+ return SplitResultResidual(
+ residual_value=windowed_value.with_value(residual_value),
+ current_watermark=current_watermark,
+ deferred_timestamp=deferred_timestamp)
return None
def try_split(self, fraction):
- # type: (...) -> Optional[Tuple[SplitResultType, SplitResultType]]
+ # type: (...) -> Optional[Tuple[SplitResultPrimary, SplitResultResidual]]
if self.threadsafe_restriction_tracker and self.current_windowed_value:
# Temporary workaround for [BEAM-7473]: get current_watermark before
# split, in case watermark gets advanced before getting split results.
@@ -797,20 +797,21 @@ class PerWindowInvoker(DoFnInvoker):
restriction_provider = self.signature.get_restriction_provider()
primary_size = restriction_provider.restriction_size(element, primary)
residual_size = restriction_provider.restriction_size(element,
residual)
- return (((
- self.current_windowed_value.with_value(
- ((element, primary), primary_size)),
- None),
- None),
- ((
- self.current_windowed_value.with_value(
- ((element, residual), residual_size)),
- current_watermark),
- None))
+ primary_value = ((element, primary), primary_size)
+ residual_value = ((element, residual), residual_size)
+ return (
+ SplitResultPrimary(
+ primary_value=self.current_windowed_value.with_value(
+ primary_value)),
+ SplitResultResidual(
+ residual_value=self.current_windowed_value.with_value(
+ residual_value),
+ current_watermark=current_watermark,
+ deferred_timestamp=None))
return None
def current_element_progress(self):
- # type: () -> Optional[iobase.RestrictionProgress]
+ # type: () -> Optional[RestrictionProgress]
restriction_tracker = self.threadsafe_restriction_tracker
if restriction_tracker:
return restriction_tracker.current_progress()
@@ -900,7 +901,7 @@ class DoFnRunner:
bundle_finalizer_param=self.bundle_finalizer_param)
def process(self, windowed_value):
- # type: (WindowedValue) -> Optional[SplitResultType]
+ # type: (WindowedValue) -> Optional[SplitResultResidual]
try:
return self.do_fn_invoker.invoke_process(windowed_value)
except BaseException as exn:
@@ -908,7 +909,7 @@ class DoFnRunner:
return None
def process_with_sized_restriction(self, windowed_value):
- # type: (WindowedValue) -> Optional[SplitResultType]
+ # type: (WindowedValue) -> Optional[SplitResultResidual]
(element, restriction), _ = windowed_value.value
return self.do_fn_invoker.invoke_process(
windowed_value.with_value(element),
@@ -916,12 +917,12 @@ class DoFnRunner:
restriction))
def try_split(self, fraction):
- # type: (...) -> Optional[Tuple[SplitResultType, SplitResultType]]
+ # type: (...) -> Optional[Tuple[SplitResultPrimary, SplitResultResidual]]
assert isinstance(self.do_fn_invoker, PerWindowInvoker)
return self.do_fn_invoker.try_split(fraction)
def current_element_progress(self):
- # type: () -> Optional[iobase.RestrictionProgress]
+ # type: () -> Optional[RestrictionProgress]
assert isinstance(self.do_fn_invoker, PerWindowInvoker)
return self.do_fn_invoker.current_element_progress()
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 4e6168e..4e63d67 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -45,7 +45,6 @@ from tenacity import retry
from tenacity import stop_after_attempt
import apache_beam as beam
-from apache_beam.io import iobase
from apache_beam.io import restriction_trackers
from apache_beam.metrics import monitoring_infos
from apache_beam.metrics.execution import MetricKey
@@ -53,6 +52,7 @@ from apache_beam.metrics.metricbase import MetricName
from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.portability import fn_api_runner
+from apache_beam.runners.sdf_utils import RestrictionTrackerView
from apache_beam.runners.worker import data_plane
from apache_beam.runners.worker import sdk_worker
from apache_beam.runners.worker import statesampler
@@ -455,7 +455,7 @@ class FnApiRunnerTest(unittest.TestCase):
element,
restriction_tracker=beam.DoFn.RestrictionParam(
ExpandStringsProvider())):
- assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+ assert isinstance(restriction_tracker, RestrictionTrackerView)
cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
yield element[cur]
@@ -473,7 +473,7 @@ class FnApiRunnerTest(unittest.TestCase):
element,
restriction_tracker=beam.DoFn.RestrictionParam(
ExpandStringsProvider())):
- assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+ assert isinstance(restriction_tracker, RestrictionTrackerView)
cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
yield element[cur]
@@ -520,7 +520,7 @@ class FnApiRunnerTest(unittest.TestCase):
element,
restriction_tracker=beam.DoFn.RestrictionParam(
ExpandStringsProvider())):
- assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+ assert isinstance(restriction_tracker, RestrictionTrackerView)
cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
counter.inc()
diff --git a/sdks/python/apache_beam/runners/sdf_utils.py
b/sdks/python/apache_beam/runners/sdf_utils.py
new file mode 100644
index 0000000..3e882c5
--- /dev/null
+++ b/sdks/python/apache_beam/runners/sdf_utils.py
@@ -0,0 +1,176 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+"""Common utility classes to help the SDK harness execute an SDF."""
+
+from __future__ import absolute_import
+from __future__ import division
+
+import logging
+import threading
+from builtins import object
+from typing import TYPE_CHECKING
+from typing import Any
+from typing import NamedTuple
+from typing import Optional
+from typing import Tuple
+
+from apache_beam.utils import timestamp
+from apache_beam.utils.windowed_value import WindowedValue
+
+if TYPE_CHECKING:
+ from apache_beam.io.iobase import RestrictionTracker
+ from apache_beam.utils.timestamp import Timestamp
+
+_LOGGER = logging.getLogger(__name__)
+
+SplitResultPrimary = NamedTuple(
+ 'SplitResultPrimary', [('primary_value', WindowedValue)])
+
+SplitResultResidual = NamedTuple(
+ 'SplitResultResidual',
+ [('residual_value', WindowedValue),
+ ('current_watermark', timestamp.Timestamp),
+ ('deferred_timestamp', timestamp.Duration)])
+
+
+class ThreadsafeRestrictionTracker(object):
+ """A thread-safe wrapper which wraps a `RestrictionTracker`.
+
+ This wrapper guarantees that modifications to the restriction are
+ synchronized across multiple threads.
+ """
+ def __init__(self, restriction_tracker):
+ # type: (RestrictionTracker) -> None
+ from apache_beam.io.iobase import RestrictionTracker
+ if not isinstance(restriction_tracker, RestrictionTracker):
+ raise ValueError(
+ 'Initialize ThreadsafeRestrictionTracker requires'
+ 'RestrictionTracker.')
+ self._restriction_tracker = restriction_tracker
+ # Records an absolute timestamp when defer_remainder is called.
+ self._deferred_timestamp = None
+ self._lock = threading.RLock()
+ self._deferred_residual = None
+ self._deferred_watermark = None
+
+ def current_restriction(self):
+ with self._lock:
+ return self._restriction_tracker.current_restriction()
+
+ def try_claim(self, position):
+ with self._lock:
+ return self._restriction_tracker.try_claim(position)
+
+ def defer_remainder(self, deferred_time=None):
+ """Performs self-checkpoint on current processing restriction with an
+ expected resuming time.
+
+ Self-checkpoint can happen while processing elements. When executing a
+ DoFn.process(), you may want to stop processing an element and resume
+ later if the current element has been processed for quite a long time, or
+ if you want to have some outputs from other elements. ``defer_remainder()``
+ can be called on a per-element basis if needed.
+
+ Args:
+ deferred_time: A relative ``timestamp.Duration`` that indicates the ideal
+ time gap between now and resuming, or an absolute
+ ``timestamp.Timestamp`` for resuming execution time. If deferred_time
+ is None, the deferred work will be executed as soon as possible.
+ """
+
+ # Record current time for calculating deferred_time later.
+ self._deferred_timestamp = timestamp.Timestamp.now()
+ if (deferred_time and not isinstance(deferred_time, timestamp.Duration) and
+ not isinstance(deferred_time, timestamp.Timestamp)):
+ raise ValueError(
+ 'The timestamp of deter_remainder() should be a '
+ 'Duration or a Timestamp, or None.')
+ self._deferred_watermark = deferred_time
+ checkpoint = self.try_split(0)
+ if checkpoint:
+ _, self._deferred_residual = checkpoint
+
+ def check_done(self):
+ with self._lock:
+ return self._restriction_tracker.check_done()
+
+ def current_progress(self):
+ with self._lock:
+ return self._restriction_tracker.current_progress()
+
+ def try_split(self, fraction_of_remainder):
+ with self._lock:
+ return self._restriction_tracker.try_split(fraction_of_remainder)
+
+ def deferred_status(self):
+ # type: () -> Optional[Tuple[Any, Timestamp]]
+
+ """Returns deferred work which is produced by ``defer_remainder()``.
+
+ When there is a self-checkpoint performed, the system needs to fulfill the
+ DelayedBundleApplication with deferred_work for a ProcessBundleResponse.
+ The system calls this API to get deferred_residual with watermark together
+ to help the runner to schedule a future work.
+
+ Returns: (deferred_residual, time_delay) if having any residual, else None.
+ """
+ if self._deferred_residual:
+ # If _deferred_watermark is None, create Duration(0).
+ if not self._deferred_watermark:
+ self._deferred_watermark = timestamp.Duration()
+ # If an absolute timestamp is provided, calculate the delta between
+ # the absolute time and the time deferred_status() is called.
+ elif isinstance(self._deferred_watermark, timestamp.Timestamp):
+ self._deferred_watermark = (
+ self._deferred_watermark - timestamp.Timestamp.now())
+ # If a Duration is provided, the deferred time should be:
+ # provided duration - the spent time since the defer_remainder() is
+ # called.
+ elif isinstance(self._deferred_watermark, timestamp.Duration):
+ self._deferred_watermark -= (
+ timestamp.Timestamp.now() - self._deferred_timestamp)
+ return self._deferred_residual, self._deferred_watermark
+ return None
+
+
+class RestrictionTrackerView(object):
+ """A DoFn view of thread-safe RestrictionTracker.
+
+ The RestrictionTrackerView wraps a ThreadsafeRestrictionTracker and only
+ exposes APIs that will be called by a ``DoFn.process()``. During execution
+ time, the RestrictionTrackerView will be fed into the ``DoFn.process`` as a
+ restriction_tracker.
+ """
+ def __init__(self, threadsafe_restriction_tracker):
+ if not isinstance(threadsafe_restriction_tracker,
+ ThreadsafeRestrictionTracker):
+ raise ValueError(
+ 'Initialize RestrictionTrackerView requires '
+ 'ThreadsafeRestrictionTracker.')
+ self._threadsafe_restriction_tracker = threadsafe_restriction_tracker
+
+ def current_restriction(self):
+ return self._threadsafe_restriction_tracker.current_restriction()
+
+ def try_claim(self, position):
+ return self._threadsafe_restriction_tracker.try_claim(position)
+
+ def defer_remainder(self, deferred_time=None):
+ self._threadsafe_restriction_tracker.defer_remainder(deferred_time)
diff --git a/sdks/python/apache_beam/runners/sdf_utils_test.py
b/sdks/python/apache_beam/runners/sdf_utils_test.py
new file mode 100644
index 0000000..30465ff
--- /dev/null
+++ b/sdks/python/apache_beam/runners/sdf_utils_test.py
@@ -0,0 +1,114 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for classes in sdf_utils.py."""
+
+# pytype: skip-file
+
+from __future__ import absolute_import
+
+import time
+import unittest
+
+from apache_beam.io.concat_source_test import RangeSource
+from apache_beam.io.restriction_trackers import OffsetRange
+from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
+from apache_beam.runners.sdf_utils import RestrictionTrackerView
+from apache_beam.runners.sdf_utils import ThreadsafeRestrictionTracker
+from apache_beam.utils import timestamp
+
+
+class ThreadsafeRestrictionTrackerTest(unittest.TestCase):
+ def test_initialization(self):
+ with self.assertRaises(ValueError):
+ ThreadsafeRestrictionTracker(RangeSource(0, 1))
+
+ def test_defer_remainder_with_wrong_time_type(self):
+ threadsafe_tracker = ThreadsafeRestrictionTracker(
+ OffsetRestrictionTracker(OffsetRange(0, 10)))
+ with self.assertRaises(ValueError):
+ threadsafe_tracker.defer_remainder(10)
+
+ def test_self_checkpoint_immediately(self):
+ restriction_tracker = OffsetRestrictionTracker(OffsetRange(0, 10))
+ threadsafe_tracker = ThreadsafeRestrictionTracker(restriction_tracker)
+ threadsafe_tracker.defer_remainder()
+ deferred_residual, deferred_time = threadsafe_tracker.deferred_status()
+ expected_residual = OffsetRange(0, 10)
+ self.assertEqual(deferred_residual, expected_residual)
+ self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+ self.assertEqual(deferred_time, 0)
+
+ def test_self_checkpoint_with_relative_time(self):
+ threadsafe_tracker = ThreadsafeRestrictionTracker(
+ OffsetRestrictionTracker(OffsetRange(0, 10)))
+ threadsafe_tracker.defer_remainder(timestamp.Duration(100))
+ time.sleep(2)
+ _, deferred_time = threadsafe_tracker.deferred_status()
+ self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+ # The expectation = 100 - 2 - some_delta
+ self.assertTrue(deferred_time <= 98)
+
+ def test_self_checkpoint_with_absolute_time(self):
+ threadsafe_tracker = ThreadsafeRestrictionTracker(
+ OffsetRestrictionTracker(OffsetRange(0, 10)))
+ now = timestamp.Timestamp.now()
+ schedule_time = now + timestamp.Duration(100)
+ self.assertTrue(isinstance(schedule_time, timestamp.Timestamp))
+ threadsafe_tracker.defer_remainder(schedule_time)
+ time.sleep(2)
+ _, deferred_time = threadsafe_tracker.deferred_status()
+ self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+ # The expectation =
+ # schedule_time - the time when deferred_status is called - some_delta
+ self.assertTrue(deferred_time <= 98)
+
+
+class RestrictionTrackerViewTest(unittest.TestCase):
+ def test_initialization(self):
+ with self.assertRaises(ValueError):
+ RestrictionTrackerView(OffsetRestrictionTracker(OffsetRange(0, 10)))
+
+ def test_api_expose(self):
+ threadsafe_tracker = ThreadsafeRestrictionTracker(
+ OffsetRestrictionTracker(OffsetRange(0, 10)))
+ tracker_view = RestrictionTrackerView(threadsafe_tracker)
+ current_restriction = tracker_view.current_restriction()
+ self.assertEqual(current_restriction, OffsetRange(0, 10))
+ self.assertTrue(tracker_view.try_claim(0))
+ tracker_view.defer_remainder()
+ deferred_remainder, deferred_watermark = (
+ threadsafe_tracker.deferred_status())
+ self.assertEqual(deferred_remainder, OffsetRange(1, 10))
+ self.assertEqual(deferred_watermark, timestamp.Duration())
+
+ def test_non_expose_apis(self):
+ threadsafe_tracker = ThreadsafeRestrictionTracker(
+ OffsetRestrictionTracker(OffsetRange(0, 10)))
+ tracker_view = RestrictionTrackerView(threadsafe_tracker)
+ with self.assertRaises(AttributeError):
+ tracker_view.check_done()
+ with self.assertRaises(AttributeError):
+ tracker_view.current_progress()
+ with self.assertRaises(AttributeError):
+ tracker_view.try_split()
+ with self.assertRaises(AttributeError):
+ tracker_view.deferred_status()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py
b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index df383d1..f44b5cf 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -906,26 +906,24 @@ class BundleProcessor(object):
# type: (...) -> beam_fn_api_pb2.DelayedBundleApplication
assert op.input_info is not None
# TODO(SDF): For non-root nodes, need main_input_coder + residual_coder.
- ((element_and_restriction, output_watermark),
- deferred_watermark) = deferred_remainder
- if deferred_watermark:
- assert isinstance(deferred_watermark, timestamp.Duration)
+ (element_and_restriction, current_watermark, deferred_timestamp) = (
+ deferred_remainder)
+ if deferred_timestamp:
+ assert isinstance(deferred_timestamp, timestamp.Duration)
proto_deferred_watermark = duration_pb2.Duration()
- proto_deferred_watermark.FromMicroseconds(deferred_watermark.micros)
+ proto_deferred_watermark.FromMicroseconds(deferred_timestamp.micros)
else:
proto_deferred_watermark = None
return beam_fn_api_pb2.DelayedBundleApplication(
requested_time_delay=proto_deferred_watermark,
application=self.construct_bundle_application(
- op, output_watermark, element_and_restriction))
+ op, current_watermark, element_and_restriction))
def bundle_application(self,
op, # type: operations.DoOperation
primary # type: common.SplitResultType
):
- ((element_and_restriction, output_watermark), _) = primary
- return self.construct_bundle_application(
- op, output_watermark, element_and_restriction)
+ return self.construct_bundle_application(op, None, primary.primary_value)
def construct_bundle_application(self, op, output_watermark, element):
transform_id, main_input_tag, main_input_coder, outputs = op.input_info