This is an automated email from the ASF dual-hosted git repository.

boyuanz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new cd6e54b  Move ThreadsafeRestrictionTracker and RestrictionTrackerView out from iobase.py
     new bcc3e13  Merge pull request #10802 from boyuanzz/refactor
cd6e54b is described below

commit cd6e54bc19f5e69cab49d22f5044b1c869b9ec69
Author: Boyuan Zhang <boyu...@google.com>
AuthorDate: Fri Feb 7 16:07:20 2020 -0800

    Move ThreadsafeRestrictionTracker and RestrictionTrackerView out from iobase.py
---
 sdks/python/apache_beam/io/iobase.py               | 128 ---------------
 sdks/python/apache_beam/io/iobase_test.py          |  86 +---------
 sdks/python/apache_beam/runners/common.py          |  69 ++++----
 .../runners/portability/fn_api_runner_test.py      |   8 +-
 sdks/python/apache_beam/runners/sdf_utils.py       | 176 +++++++++++++++++++++
 sdks/python/apache_beam/runners/sdf_utils_test.py  | 114 +++++++++++++
 .../apache_beam/runners/worker/bundle_processor.py |  16 +-
 7 files changed, 337 insertions(+), 260 deletions(-)

diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index 10d2933..1302c38 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -37,7 +37,6 @@ from __future__ import division
 import logging
 import math
 import random
-import threading
 import uuid
 from builtins import object
 from builtins import range
@@ -65,9 +64,7 @@ from apache_beam.utils import urns
 from apache_beam.utils.windowed_value import WindowedValue
 
 if TYPE_CHECKING:
-  from apache_beam.io import restriction_trackers
   from apache_beam.runners.pipeline_context import PipelineContext
-  from apache_beam.utils.timestamp import Timestamp
 
 __all__ = [
     'BoundedSource',
@@ -1246,131 +1243,6 @@ class RestrictionTracker(object):
     raise NotImplementedError
 
 
-class ThreadsafeRestrictionTracker(object):
-  """A thread-safe wrapper which wraps a `RestritionTracker`.
-
-  This wrapper guarantees synchronization of modifying restrictions across
-  multi-thread.
-  """
-  def __init__(self, restriction_tracker):
-    # type: (RestrictionTracker) -> None
-    if not isinstance(restriction_tracker, RestrictionTracker):
-      raise ValueError(
-          'Initialize ThreadsafeRestrictionTracker requires'
-          'RestrictionTracker.')
-    self._restriction_tracker = restriction_tracker
-    # Records an absolute timestamp when defer_remainder is called.
-    self._deferred_timestamp = None
-    self._lock = threading.RLock()
-    self._deferred_residual = None
-    self._deferred_watermark = None
-
-  def current_restriction(self):
-    with self._lock:
-      return self._restriction_tracker.current_restriction()
-
-  def try_claim(self, position):
-    with self._lock:
-      return self._restriction_tracker.try_claim(position)
-
-  def defer_remainder(self, deferred_time=None):
-    """Performs self-checkpoint on current processing restriction with an
-    expected resuming time.
-
-    Self-checkpoint could happen during processing elements. When executing an
-    DoFn.process(), you may want to stop processing an element and resuming
-    later if current element has been processed quit a long time or you also
-    want to have some outputs from other elements. ``defer_remainder()`` can be
-    called on per element if needed.
-
-    Args:
-      deferred_time: A relative ``timestamp.Duration`` that indicates the ideal
-      time gap between now and resuming, or an absolute ``timestamp.Timestamp``
-      for resuming execution time. If the time_delay is None, the deferred work
-      will be executed as soon as possible.
-    """
-
-    # Record current time for calculating deferred_time later.
-    self._deferred_timestamp = timestamp.Timestamp.now()
-    if (deferred_time and not isinstance(deferred_time, timestamp.Duration) and
-        not isinstance(deferred_time, timestamp.Timestamp)):
-      raise ValueError(
-          'The timestamp of deter_remainder() should be a '
-          'Duration or a Timestamp, or None.')
-    self._deferred_watermark = deferred_time
-    checkpoint = self.try_split(0)
-    if checkpoint:
-      _, self._deferred_residual = checkpoint
-
-  def check_done(self):
-    with self._lock:
-      return self._restriction_tracker.check_done()
-
-  def current_progress(self):
-    with self._lock:
-      return self._restriction_tracker.current_progress()
-
-  def try_split(self, fraction_of_remainder):
-    with self._lock:
-      return self._restriction_tracker.try_split(fraction_of_remainder)
-
-  def deferred_status(self):
-    # type: () -> Optional[Tuple[Any, Timestamp]]
-
-    """Returns deferred work which is produced by ``defer_remainder()``.
-
-    When there is a self-checkpoint performed, the system needs to fulfill the
-    DelayedBundleApplication with deferred_work for a  ProcessBundleResponse.
-    The system calls this API to get deferred_residual with watermark together
-    to help the runner to schedule a future work.
-
-    Returns: (deferred_residual, time_delay) if having any residual, else None.
-    """
-    if self._deferred_residual:
-      # If _deferred_watermark is None, create Duration(0).
-      if not self._deferred_watermark:
-        self._deferred_watermark = timestamp.Duration()
-      # If an absolute timestamp is provided, calculate the delta between
-      # the absoluted time and the time deferred_status() is called.
-      elif isinstance(self._deferred_watermark, timestamp.Timestamp):
-        self._deferred_watermark = (
-            self._deferred_watermark - timestamp.Timestamp.now())
-      # If a Duration is provided, the deferred time should be:
-      # provided duration - the spent time since the defer_remainder() is
-      # called.
-      elif isinstance(self._deferred_watermark, timestamp.Duration):
-        self._deferred_watermark -= (
-            timestamp.Timestamp.now() - self._deferred_timestamp)
-      return self._deferred_residual, self._deferred_watermark
-    return None
-
-
-class RestrictionTrackerView(object):
-  """A DoFn view of thread-safe RestrictionTracker.
-
-  The RestrictionTrackerView wraps a ThreadsafeRestrictionTracker and only
-  exposes APIs that will be called by a ``DoFn.process()``. During execution
-  time, the RestrictionTrackerView will be fed into the ``DoFn.process`` as a
-  restriction_tracker.
-  """
-  def __init__(self, threadsafe_restriction_tracker):
-    if not isinstance(threadsafe_restriction_tracker,
-                      ThreadsafeRestrictionTracker):
-      raise ValueError(
-          'Initialize RestrictionTrackerView requires '
-          'ThreadsafeRestrictionTracker.')
-    self._threadsafe_restriction_tracker = threadsafe_restriction_tracker
-
-  def current_restriction(self):
-    return self._threadsafe_restriction_tracker.current_restriction()
-
-  def try_claim(self, position):
-    return self._threadsafe_restriction_tracker.try_claim(position)
-
-  def defer_remainder(self, deferred_time=None):
-    self._threadsafe_restriction_tracker.defer_remainder(deferred_time)
-
-
 class RestrictionProgress(object):
   """Used to record the progress of a restriction.
 
diff --git a/sdks/python/apache_beam/io/iobase_test.py b/sdks/python/apache_beam/io/iobase_test.py
index 7460b71..04be8fa 100644
--- a/sdks/python/apache_beam/io/iobase_test.py
+++ b/sdks/python/apache_beam/io/iobase_test.py
@@ -15,13 +15,12 @@
 # limitations under the License.
 #
 
-"""Unit tests for the SDFRestrictionProvider module."""
+"""Unit tests for classes in iobase.py."""
 
 # pytype: skip-file
 
 from __future__ import absolute_import
 
-import time
 import unittest
 
 import mock
@@ -31,9 +30,6 @@ from apache_beam.io.concat_source import ConcatSource
 from apache_beam.io.concat_source_test import RangeSource
 from apache_beam.io import iobase
 from apache_beam.io.iobase import SourceBundle
-from apache_beam.io.restriction_trackers import OffsetRange
-from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
-from apache_beam.utils import timestamp
 from apache_beam.options.pipeline_options import DebugOptions
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
@@ -227,85 +223,5 @@ class UseSdfBoundedSourcesTests(unittest.TestCase):
     self._run_sdf_wrapper_pipeline(RangeSource(0, 4), [0, 1, 2, 3])
 
 
-class ThreadsafeRestrictionTrackerTest(unittest.TestCase):
-  def test_initialization(self):
-    with self.assertRaises(ValueError):
-      iobase.ThreadsafeRestrictionTracker(RangeSource(0, 1))
-
-  def test_defer_remainder_with_wrong_time_type(self):
-    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
-        OffsetRestrictionTracker(OffsetRange(0, 10)))
-    with self.assertRaises(ValueError):
-      threadsafe_tracker.defer_remainder(10)
-
-  def test_self_checkpoint_immediately(self):
-    restriction_tracker = OffsetRestrictionTracker(OffsetRange(0, 10))
-    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
-        restriction_tracker)
-    threadsafe_tracker.defer_remainder()
-    deferred_residual, deferred_time = threadsafe_tracker.deferred_status()
-    expected_residual = OffsetRange(0, 10)
-    self.assertEqual(deferred_residual, expected_residual)
-    self.assertTrue(isinstance(deferred_time, timestamp.Duration))
-    self.assertEqual(deferred_time, 0)
-
-  def test_self_checkpoint_with_relative_time(self):
-    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
-        OffsetRestrictionTracker(OffsetRange(0, 10)))
-    threadsafe_tracker.defer_remainder(timestamp.Duration(100))
-    time.sleep(2)
-    _, deferred_time = threadsafe_tracker.deferred_status()
-    self.assertTrue(isinstance(deferred_time, timestamp.Duration))
-    # The expectation = 100 - 2 - some_delta
-    self.assertTrue(deferred_time <= 98)
-
-  def test_self_checkpoint_with_absolute_time(self):
-    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
-        OffsetRestrictionTracker(OffsetRange(0, 10)))
-    now = timestamp.Timestamp.now()
-    schedule_time = now + timestamp.Duration(100)
-    self.assertTrue(isinstance(schedule_time, timestamp.Timestamp))
-    threadsafe_tracker.defer_remainder(schedule_time)
-    time.sleep(2)
-    _, deferred_time = threadsafe_tracker.deferred_status()
-    self.assertTrue(isinstance(deferred_time, timestamp.Duration))
-    # The expectation =
-    # schedule_time - the time when deferred_status is called - some_delta
-    self.assertTrue(deferred_time <= 98)
-
-
-class RestrictionTrackerViewTest(unittest.TestCase):
-  def test_initialization(self):
-    with self.assertRaises(ValueError):
-      iobase.RestrictionTrackerView(
-          OffsetRestrictionTracker(OffsetRange(0, 10)))
-
-  def test_api_expose(self):
-    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
-        OffsetRestrictionTracker(OffsetRange(0, 10)))
-    tracker_view = iobase.RestrictionTrackerView(threadsafe_tracker)
-    current_restriction = tracker_view.current_restriction()
-    self.assertEqual(current_restriction, OffsetRange(0, 10))
-    self.assertTrue(tracker_view.try_claim(0))
-    tracker_view.defer_remainder()
-    deferred_remainder, deferred_watermark = (
-        threadsafe_tracker.deferred_status())
-    self.assertEqual(deferred_remainder, OffsetRange(1, 10))
-    self.assertEqual(deferred_watermark, timestamp.Duration())
-
-  def test_non_expose_apis(self):
-    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
-        OffsetRestrictionTracker(OffsetRange(0, 10)))
-    tracker_view = iobase.RestrictionTrackerView(threadsafe_tracker)
-    with self.assertRaises(AttributeError):
-      tracker_view.check_done()
-    with self.assertRaises(AttributeError):
-      tracker_view.current_progress()
-    with self.assertRaises(AttributeError):
-      tracker_view.try_split()
-    with self.assertRaises(AttributeError):
-      tracker_view.deferred_status()
-
-
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 3651b71..1e17690 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -44,6 +44,10 @@ from past.builtins import unicode
 from apache_beam.internal import util
 from apache_beam.options.value_provider import RuntimeValueProvider
 from apache_beam.pvalue import TaggedOutput
+from apache_beam.runners.sdf_utils import RestrictionTrackerView
+from apache_beam.runners.sdf_utils import SplitResultPrimary
+from apache_beam.runners.sdf_utils import SplitResultResidual
+from apache_beam.runners.sdf_utils import ThreadsafeRestrictionTracker
 from apache_beam.transforms import DoFn
 from apache_beam.transforms import core
 from apache_beam.transforms import userstate
@@ -57,13 +61,9 @@ from apache_beam.utils.timestamp import Timestamp
 from apache_beam.utils.windowed_value import WindowedValue
 
 if TYPE_CHECKING:
-  from apache_beam.io import iobase
   from apache_beam.transforms import sideinputs
   from apache_beam.transforms.core import TimerSpec
 
-SplitResultType = Tuple[Tuple[WindowedValue, Optional[Timestamp]],
-                        Optional[Timestamp]]
-
 
 class NameContext(object):
   """Holds the name information for a step."""
@@ -414,7 +414,7 @@ class DoFnInvoker(object):
 
   def invoke_process(self,
                      windowed_value,  # type: WindowedValue
-                     restriction_tracker=None,  # type: Optional[iobase.RestrictionTracker]
+                     restriction_tracker=None,  # type: Optional[RestrictionTracker]
                      additional_args=None,
                      additional_kwargs=None
                     ):
@@ -499,7 +499,7 @@ class SimpleInvoker(DoFnInvoker):
 
   def invoke_process(self,
                      windowed_value,  # type: WindowedValue
-                     restriction_tracker=None,  # type: Optional[iobase.RestrictionTracker]
+                     restriction_tracker=None,  # type: Optional[RestrictionTracker]
                      additional_args=None,
                      additional_kwargs=None
                     ):
@@ -536,7 +536,7 @@ class PerWindowInvoker(DoFnInvoker):
     self.watermark_estimator_param = (
         self.signature.process_method.watermark_estimator_arg_name
         if self.watermark_estimator else None)
-    self.threadsafe_restriction_tracker = None  # type: Optional[iobase.ThreadsafeRestrictionTracker]
+    self.threadsafe_restriction_tracker = None  # type: Optional[ThreadsafeRestrictionTracker]
     self.current_windowed_value = None  # type: Optional[WindowedValue]
     self.bundle_finalizer_param = bundle_finalizer_param
     self.is_key_param_required = False
@@ -649,11 +649,10 @@ class PerWindowInvoker(DoFnInvoker):
         raise ValueError(
             'A RestrictionTracker %r was provided but DoFn does not have a '
             'RestrictionTrackerParam defined' % restriction_tracker)
-      from apache_beam.io import iobase
-      self.threadsafe_restriction_tracker = iobase.ThreadsafeRestrictionTracker(
+      self.threadsafe_restriction_tracker = ThreadsafeRestrictionTracker(
           restriction_tracker)
       additional_kwargs[restriction_tracker_param] = (
-          iobase.RestrictionTrackerView(self.threadsafe_restriction_tracker))
+          RestrictionTrackerView(self.threadsafe_restriction_tracker))
 
       if self.watermark_estimator:
         # The watermark estimator needs to be reset for every element.
@@ -685,7 +684,7 @@ class PerWindowInvoker(DoFnInvoker):
                                  additional_args,
                                  additional_kwargs,
                                 ):
-    # type: (...) -> Optional[SplitResultType]
+    # type: (...) -> Optional[SplitResultResidual]
     if self.has_windowed_inputs:
       window, = windowed_value.windows
       side_inputs = [si[window] for si in self.side_inputs]
@@ -766,22 +765,23 @@ class PerWindowInvoker(DoFnInvoker):
       # ProcessSizedElementAndRestriction.
       self.threadsafe_restriction_tracker.check_done()
       deferred_status = self.threadsafe_restriction_tracker.deferred_status()
-      output_watermark = None
+      current_watermark = None
       if self.watermark_estimator:
-        output_watermark = self.watermark_estimator.current_watermark()
+        current_watermark = self.watermark_estimator.current_watermark()
       if deferred_status:
-        deferred_restriction, deferred_watermark = deferred_status
+        deferred_restriction, deferred_timestamp = deferred_status
         element = windowed_value.value
         size = self.signature.get_restriction_provider().restriction_size(
             element, deferred_restriction)
-        return ((
-            windowed_value.with_value(((element, deferred_restriction), size)),
-            output_watermark),
-                deferred_watermark)
+        residual_value = ((element, deferred_restriction), size)
+        return SplitResultResidual(
+            residual_value=windowed_value.with_value(residual_value),
+            current_watermark=current_watermark,
+            deferred_timestamp=deferred_timestamp)
     return None
 
   def try_split(self, fraction):
-    # type: (...) -> Optional[Tuple[SplitResultType, SplitResultType]]
+    # type: (...) -> Optional[Tuple[SplitResultPrimary, SplitResultResidual]]
     if self.threadsafe_restriction_tracker and self.current_windowed_value:
       # Temporary workaround for [BEAM-7473]: get current_watermark before
       # split, in case watermark gets advanced before getting split results.
@@ -797,20 +797,21 @@ class PerWindowInvoker(DoFnInvoker):
         restriction_provider = self.signature.get_restriction_provider()
         primary_size = restriction_provider.restriction_size(element, primary)
         residual_size = restriction_provider.restriction_size(element, residual)
-        return (((
-            self.current_windowed_value.with_value(
-                ((element, primary), primary_size)),
-            None),
-                 None),
-                ((
-                    self.current_windowed_value.with_value(
-                        ((element, residual), residual_size)),
-                    current_watermark),
-                 None))
+        primary_value = ((element, primary), primary_size)
+        residual_value = ((element, residual), residual_size)
+        return (
+            SplitResultPrimary(
+                primary_value=self.current_windowed_value.with_value(
+                    primary_value)),
+            SplitResultResidual(
+                residual_value=self.current_windowed_value.with_value(
+                    residual_value),
+                current_watermark=current_watermark,
+                deferred_timestamp=None))
     return None
 
   def current_element_progress(self):
-    # type: () -> Optional[iobase.RestrictionProgress]
+    # type: () -> Optional[RestrictionProgress]
     restriction_tracker = self.threadsafe_restriction_tracker
     if restriction_tracker:
       return restriction_tracker.current_progress()
@@ -900,7 +901,7 @@ class DoFnRunner:
         bundle_finalizer_param=self.bundle_finalizer_param)
 
   def process(self, windowed_value):
-    # type: (WindowedValue) -> Optional[SplitResultType]
+    # type: (WindowedValue) -> Optional[SplitResultResidual]
     try:
       return self.do_fn_invoker.invoke_process(windowed_value)
     except BaseException as exn:
@@ -908,7 +909,7 @@ class DoFnRunner:
       return None
 
   def process_with_sized_restriction(self, windowed_value):
-    # type: (WindowedValue) -> Optional[SplitResultType]
+    # type: (WindowedValue) -> Optional[SplitResultResidual]
     (element, restriction), _ = windowed_value.value
     return self.do_fn_invoker.invoke_process(
         windowed_value.with_value(element),
@@ -916,12 +917,12 @@ class DoFnRunner:
             restriction))
 
   def try_split(self, fraction):
-    # type: (...) -> Optional[Tuple[SplitResultType, SplitResultType]]
+    # type: (...) -> Optional[Tuple[SplitResultPrimary, SplitResultResidual]]
     assert isinstance(self.do_fn_invoker, PerWindowInvoker)
     return self.do_fn_invoker.try_split(fraction)
 
   def current_element_progress(self):
-    # type: () -> Optional[iobase.RestrictionProgress]
+    # type: () -> Optional[RestrictionProgress]
     assert isinstance(self.do_fn_invoker, PerWindowInvoker)
     return self.do_fn_invoker.current_element_progress()
 
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 4e6168e..4e63d67 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -45,7 +45,6 @@ from tenacity import retry
 from tenacity import stop_after_attempt
 
 import apache_beam as beam
-from apache_beam.io import iobase
 from apache_beam.io import restriction_trackers
 from apache_beam.metrics import monitoring_infos
 from apache_beam.metrics.execution import MetricKey
@@ -53,6 +52,7 @@ from apache_beam.metrics.metricbase import MetricName
 from apache_beam.options.pipeline_options import DebugOptions
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.runners.portability import fn_api_runner
+from apache_beam.runners.sdf_utils import RestrictionTrackerView
 from apache_beam.runners.worker import data_plane
 from apache_beam.runners.worker import sdk_worker
 from apache_beam.runners.worker import statesampler
@@ -455,7 +455,7 @@ class FnApiRunnerTest(unittest.TestCase):
           element,
           restriction_tracker=beam.DoFn.RestrictionParam(
               ExpandStringsProvider())):
-        assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+        assert isinstance(restriction_tracker, RestrictionTrackerView)
         cur = restriction_tracker.current_restriction().start
         while restriction_tracker.try_claim(cur):
           yield element[cur]
@@ -473,7 +473,7 @@ class FnApiRunnerTest(unittest.TestCase):
           element,
           restriction_tracker=beam.DoFn.RestrictionParam(
               ExpandStringsProvider())):
-        assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+        assert isinstance(restriction_tracker, RestrictionTrackerView)
         cur = restriction_tracker.current_restriction().start
         while restriction_tracker.try_claim(cur):
           yield element[cur]
@@ -520,7 +520,7 @@ class FnApiRunnerTest(unittest.TestCase):
           element,
           restriction_tracker=beam.DoFn.RestrictionParam(
               ExpandStringsProvider())):
-        assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+        assert isinstance(restriction_tracker, RestrictionTrackerView)
         cur = restriction_tracker.current_restriction().start
         while restriction_tracker.try_claim(cur):
           counter.inc()
diff --git a/sdks/python/apache_beam/runners/sdf_utils.py b/sdks/python/apache_beam/runners/sdf_utils.py
new file mode 100644
index 0000000..3e882c5
--- /dev/null
+++ b/sdks/python/apache_beam/runners/sdf_utils.py
@@ -0,0 +1,176 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+"""Common utility class to help SDK harness to execute an SDF. """
+
+from __future__ import absolute_import
+from __future__ import division
+
+import logging
+import threading
+from builtins import object
+from typing import TYPE_CHECKING
+from typing import Any
+from typing import NamedTuple
+from typing import Optional
+from typing import Tuple
+
+from apache_beam.utils import timestamp
+from apache_beam.utils.windowed_value import WindowedValue
+
+if TYPE_CHECKING:
+  from apache_beam.io.iobase import RestrictionTracker
+  from apache_beam.utils.timestamp import Timestamp
+
+_LOGGER = logging.getLogger(__name__)
+
+SplitResultPrimary = NamedTuple(
+    'SplitResultPrimary', [('primary_value', WindowedValue)])
+
+SplitResultResidual = NamedTuple(
+    'SplitResultResidual',
+    [('residual_value', WindowedValue),
+     ('current_watermark', timestamp.Timestamp),
+     ('deferred_timestamp', timestamp.Duration)])
+
+
+class ThreadsafeRestrictionTracker(object):
+  """A thread-safe wrapper which wraps a `RestritionTracker`.
+
+  This wrapper guarantees synchronization of modifying restrictions across
+  multi-thread.
+  """
+  def __init__(self, restriction_tracker):
+    # type: (RestrictionTracker) -> None
+    from apache_beam.io.iobase import RestrictionTracker
+    if not isinstance(restriction_tracker, RestrictionTracker):
+      raise ValueError(
+          'Initialize ThreadsafeRestrictionTracker requires'
+          'RestrictionTracker.')
+    self._restriction_tracker = restriction_tracker
+    # Records an absolute timestamp when defer_remainder is called.
+    self._deferred_timestamp = None
+    self._lock = threading.RLock()
+    self._deferred_residual = None
+    self._deferred_watermark = None
+
+  def current_restriction(self):
+    with self._lock:
+      return self._restriction_tracker.current_restriction()
+
+  def try_claim(self, position):
+    with self._lock:
+      return self._restriction_tracker.try_claim(position)
+
+  def defer_remainder(self, deferred_time=None):
+    """Performs self-checkpoint on current processing restriction with an
+    expected resuming time.
+
+    Self-checkpoint could happen during processing elements. When executing an
+    DoFn.process(), you may want to stop processing an element and resuming
+    later if current element has been processed quit a long time or you also
+    want to have some outputs from other elements. ``defer_remainder()`` can be
+    called on per element if needed.
+
+    Args:
+      deferred_time: A relative ``timestamp.Duration`` that indicates the ideal
+        time gap between now and resuming, or an absolute
+        ``timestamp.Timestamp`` for resuming execution time. If the time_delay
+        is None, the deferred work will be executed as soon as possible.
+    """
+
+    # Record current time for calculating deferred_time later.
+    self._deferred_timestamp = timestamp.Timestamp.now()
+    if (deferred_time and not isinstance(deferred_time, timestamp.Duration) and
+        not isinstance(deferred_time, timestamp.Timestamp)):
+      raise ValueError(
+          'The timestamp of deter_remainder() should be a '
+          'Duration or a Timestamp, or None.')
+    self._deferred_watermark = deferred_time
+    checkpoint = self.try_split(0)
+    if checkpoint:
+      _, self._deferred_residual = checkpoint
+
+  def check_done(self):
+    with self._lock:
+      return self._restriction_tracker.check_done()
+
+  def current_progress(self):
+    with self._lock:
+      return self._restriction_tracker.current_progress()
+
+  def try_split(self, fraction_of_remainder):
+    with self._lock:
+      return self._restriction_tracker.try_split(fraction_of_remainder)
+
+  def deferred_status(self):
+    # type: () -> Optional[Tuple[Any, Timestamp]]
+
+    """Returns deferred work which is produced by ``defer_remainder()``.
+
+    When there is a self-checkpoint performed, the system needs to fulfill the
+    DelayedBundleApplication with deferred_work for a  ProcessBundleResponse.
+    The system calls this API to get deferred_residual with watermark together
+    to help the runner to schedule a future work.
+
+    Returns: (deferred_residual, time_delay) if having any residual, else None.
+    """
+    if self._deferred_residual:
+      # If _deferred_watermark is None, create Duration(0).
+      if not self._deferred_watermark:
+        self._deferred_watermark = timestamp.Duration()
+      # If an absolute timestamp is provided, calculate the delta between
+      # the absoluted time and the time deferred_status() is called.
+      elif isinstance(self._deferred_watermark, timestamp.Timestamp):
+        self._deferred_watermark = (
+            self._deferred_watermark - timestamp.Timestamp.now())
+      # If a Duration is provided, the deferred time should be:
+      # provided duration - the spent time since the defer_remainder() is
+      # called.
+      elif isinstance(self._deferred_watermark, timestamp.Duration):
+        self._deferred_watermark -= (
+            timestamp.Timestamp.now() - self._deferred_timestamp)
+      return self._deferred_residual, self._deferred_watermark
+    return None
+
+
+class RestrictionTrackerView(object):
+  """A DoFn view of thread-safe RestrictionTracker.
+
+  The RestrictionTrackerView wraps a ThreadsafeRestrictionTracker and only
+  exposes APIs that will be called by a ``DoFn.process()``. During execution
+  time, the RestrictionTrackerView will be fed into the ``DoFn.process`` as a
+  restriction_tracker.
+  """
+  def __init__(self, threadsafe_restriction_tracker):
+    if not isinstance(threadsafe_restriction_tracker,
+                      ThreadsafeRestrictionTracker):
+      raise ValueError(
+          'Initialize RestrictionTrackerView requires '
+          'ThreadsafeRestrictionTracker.')
+    self._threadsafe_restriction_tracker = threadsafe_restriction_tracker
+
+  def current_restriction(self):
+    return self._threadsafe_restriction_tracker.current_restriction()
+
+  def try_claim(self, position):
+    return self._threadsafe_restriction_tracker.try_claim(position)
+
+  def defer_remainder(self, deferred_time=None):
+    self._threadsafe_restriction_tracker.defer_remainder(deferred_time)
diff --git a/sdks/python/apache_beam/runners/sdf_utils_test.py b/sdks/python/apache_beam/runners/sdf_utils_test.py
new file mode 100644
index 0000000..30465ff
--- /dev/null
+++ b/sdks/python/apache_beam/runners/sdf_utils_test.py
@@ -0,0 +1,114 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for classes in sdf_utils.py."""
+
+# pytype: skip-file
+
+from __future__ import absolute_import
+
+import time
+import unittest
+
+from apache_beam.io.concat_source_test import RangeSource
+from apache_beam.io.restriction_trackers import OffsetRange
+from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
+from apache_beam.runners.sdf_utils import RestrictionTrackerView
+from apache_beam.runners.sdf_utils import ThreadsafeRestrictionTracker
+from apache_beam.utils import timestamp
+
+
+class ThreadsafeRestrictionTrackerTest(unittest.TestCase):
+  def test_initialization(self):
+    with self.assertRaises(ValueError):
+      ThreadsafeRestrictionTracker(RangeSource(0, 1))
+
+  def test_defer_remainder_with_wrong_time_type(self):
+    threadsafe_tracker = ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    with self.assertRaises(ValueError):
+      threadsafe_tracker.defer_remainder(10)
+
+  def test_self_checkpoint_immediately(self):
+    restriction_tracker = OffsetRestrictionTracker(OffsetRange(0, 10))
+    threadsafe_tracker = ThreadsafeRestrictionTracker(restriction_tracker)
+    threadsafe_tracker.defer_remainder()
+    deferred_residual, deferred_time = threadsafe_tracker.deferred_status()
+    expected_residual = OffsetRange(0, 10)
+    self.assertEqual(deferred_residual, expected_residual)
+    self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+    self.assertEqual(deferred_time, 0)
+
+  def test_self_checkpoint_with_relative_time(self):
+    threadsafe_tracker = ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    threadsafe_tracker.defer_remainder(timestamp.Duration(100))
+    time.sleep(2)
+    _, deferred_time = threadsafe_tracker.deferred_status()
+    self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+    # The expectation = 100 - 2 - some_delta
+    self.assertTrue(deferred_time <= 98)
+
+  def test_self_checkpoint_with_absolute_time(self):
+    threadsafe_tracker = ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    now = timestamp.Timestamp.now()
+    schedule_time = now + timestamp.Duration(100)
+    self.assertTrue(isinstance(schedule_time, timestamp.Timestamp))
+    threadsafe_tracker.defer_remainder(schedule_time)
+    time.sleep(2)
+    _, deferred_time = threadsafe_tracker.deferred_status()
+    self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+    # The expectation =
+    # schedule_time - the time when deferred_status is called - some_delta
+    self.assertTrue(deferred_time <= 98)
+
+
+class RestrictionTrackerViewTest(unittest.TestCase):
+  def test_initialization(self):
+    with self.assertRaises(ValueError):
+      RestrictionTrackerView(OffsetRestrictionTracker(OffsetRange(0, 10)))
+
+  def test_api_expose(self):
+    threadsafe_tracker = ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    tracker_view = RestrictionTrackerView(threadsafe_tracker)
+    current_restriction = tracker_view.current_restriction()
+    self.assertEqual(current_restriction, OffsetRange(0, 10))
+    self.assertTrue(tracker_view.try_claim(0))
+    tracker_view.defer_remainder()
+    deferred_remainder, deferred_watermark = (
+        threadsafe_tracker.deferred_status())
+    self.assertEqual(deferred_remainder, OffsetRange(1, 10))
+    self.assertEqual(deferred_watermark, timestamp.Duration())
+
+  def test_non_expose_apis(self):
+    threadsafe_tracker = ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    tracker_view = RestrictionTrackerView(threadsafe_tracker)
+    with self.assertRaises(AttributeError):
+      tracker_view.check_done()
+    with self.assertRaises(AttributeError):
+      tracker_view.current_progress()
+    with self.assertRaises(AttributeError):
+      tracker_view.try_split()
+    with self.assertRaises(AttributeError):
+      tracker_view.deferred_status()
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py 
b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index df383d1..f44b5cf 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -906,26 +906,24 @@ class BundleProcessor(object):
     # type: (...) -> beam_fn_api_pb2.DelayedBundleApplication
     assert op.input_info is not None
     # TODO(SDF): For non-root nodes, need main_input_coder + residual_coder.
-    ((element_and_restriction, output_watermark),
-     deferred_watermark) = deferred_remainder
-    if deferred_watermark:
-      assert isinstance(deferred_watermark, timestamp.Duration)
+    (element_and_restriction, current_watermark, deferred_timestamp) = (
+        deferred_remainder)
+    if deferred_timestamp:
+      assert isinstance(deferred_timestamp, timestamp.Duration)
       proto_deferred_watermark = duration_pb2.Duration()
-      proto_deferred_watermark.FromMicroseconds(deferred_watermark.micros)
+      proto_deferred_watermark.FromMicroseconds(deferred_timestamp.micros)
     else:
       proto_deferred_watermark = None
     return beam_fn_api_pb2.DelayedBundleApplication(
         requested_time_delay=proto_deferred_watermark,
         application=self.construct_bundle_application(
-            op, output_watermark, element_and_restriction))
+            op, current_watermark, element_and_restriction))
 
   def bundle_application(self,
                          op,  # type: operations.DoOperation
                          primary  # type: common.SplitResultType
                         ):
-    ((element_and_restriction, output_watermark), _) = primary
-    return self.construct_bundle_application(
-        op, output_watermark, element_and_restriction)
+    return self.construct_bundle_application(op, None, primary.primary_value)
 
   def construct_bundle_application(self, op, output_watermark, element):
     transform_id, main_input_tag, main_input_coder, outputs = op.input_info

Reply via email to