This is an automated email from the ASF dual-hosted git repository.

boyuanz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new a2451d4  [BEAM-7746] Add typing for try_split
     new a9e6d21  Merge pull request #10593 from chadrik/python-typing-try-split
a2451d4 is described below

commit a2451d437d4357e40c204d863fb74b159df910de
Author: Chad Dombrova <[email protected]>
AuthorDate: Thu Jan 16 19:58:45 2020 -0800

    [BEAM-7746] Add typing for try_split
---
 sdks/python/apache_beam/io/iobase.py                     |  1 +
 sdks/python/apache_beam/runners/common.py                | 16 +++++++++++-----
 .../apache_beam/runners/worker/bundle_processor.py       | 16 ++++++++++------
 sdks/python/apache_beam/runners/worker/operations.py     | 16 ++++++++++++++--
 4 files changed, 36 insertions(+), 13 deletions(-)

diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index 6f49801..10d2933 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -1342,6 +1342,7 @@ class ThreadsafeRestrictionTracker(object):
         self._deferred_watermark -= (
             timestamp.Timestamp.now() - self._deferred_timestamp)
       return self._deferred_residual, self._deferred_watermark
+    return None
 
 
 class RestrictionTrackerView(object):
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 892facf..2c8e0df 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -61,6 +61,9 @@ if TYPE_CHECKING:
   from apache_beam.transforms import sideinputs
   from apache_beam.transforms.core import TimerSpec
 
+SplitResultType = Tuple[Tuple[WindowedValue, Optional[Timestamp]],
+                        Optional[Timestamp]]
+
 
 class NameContext(object):
   """Holds the name information for a step."""
@@ -415,7 +418,7 @@ class DoFnInvoker(object):
                      additional_args=None,
                      additional_kwargs=None
                     ):
-    # type: (...) -> Optional[Tuple[WindowedValue, Timestamp]]
+    # type: (...) -> Optional[SplitResultType]
 
     """Invokes the DoFn.process() function.
 
@@ -619,7 +622,7 @@ class PerWindowInvoker(DoFnInvoker):
                      additional_args=None,
                      additional_kwargs=None
                     ):
-    # type: (...) -> Optional[Tuple[WindowedValue, Timestamp]]
+    # type: (...) -> Optional[SplitResultType]
     if not additional_args:
       additional_args = []
     if not additional_kwargs:
@@ -682,7 +685,7 @@ class PerWindowInvoker(DoFnInvoker):
                                  additional_args,
                                  additional_kwargs,
                                 ):
-    # type: (...) -> Optional[Tuple[WindowedValue, Timestamp]]
+    # type: (...) -> Optional[SplitResultType]
     if self.has_windowed_inputs:
       window, = windowed_value.windows
       side_inputs = [si[window] for si in self.side_inputs]
@@ -804,6 +807,7 @@ class PerWindowInvoker(DoFnInvoker):
                         ((element, residual), residual_size)),
                     current_watermark),
                  None))
+    return None
 
   def current_element_progress(self):
     # type: () -> Optional[iobase.RestrictionProgress]
@@ -900,7 +904,7 @@ class DoFnRunner(Receiver):
     self.process(windowed_value)
 
   def process(self, windowed_value):
-    # type: (WindowedValue) -> Optional[Tuple[WindowedValue, Timestamp]]
+    # type: (WindowedValue) -> Optional[SplitResultType]
     try:
       return self.do_fn_invoker.invoke_process(windowed_value)
     except BaseException as exn:
@@ -908,7 +912,7 @@ class DoFnRunner(Receiver):
       return None
 
   def process_with_sized_restriction(self, windowed_value):
-    # type: (WindowedValue) -> Optional[Tuple[WindowedValue, Timestamp]]
+    # type: (WindowedValue) -> Optional[SplitResultType]
     (element, restriction), _ = windowed_value.value
     return self.do_fn_invoker.invoke_process(
         windowed_value.with_value(element),
@@ -916,6 +920,8 @@ class DoFnRunner(Receiver):
             restriction))
 
   def try_split(self, fraction):
+    # type: (...) -> Optional[Tuple[SplitResultType, SplitResultType]]
+    assert isinstance(self.do_fn_invoker, PerWindowInvoker)
     return self.do_fn_invoker.try_split(fraction)
 
   def current_element_progress(self):
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 6237d07..df383d1 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -216,16 +216,17 @@ class DataInputOperation(RunnerIOOperation):
       self.output(decoded_value)
 
   def try_split(self, fraction_of_remainder, total_buffer_size):
+    # type: (...) -> Optional[Tuple[int, Optional[Tuple[operations.DoOperation, common.SplitResultType]], Optional[Tuple[operations.DoOperation, common.SplitResultType]], int]]
     with self.splitting_lock:
       if not self.started:
-        return
+        return None
       if total_buffer_size < self.index + 1:
         total_buffer_size = self.index + 1
       elif self.stop and total_buffer_size > self.stop:
         total_buffer_size = self.stop
       if self.index == -1:
         # We are "finished" with the (non-existent) previous element.
-        current_element_progress = 1
+        current_element_progress = 1.0
       else:
         current_element_progress_object = (
             self.receivers[0].current_element_progress())
@@ -900,7 +901,7 @@ class BundleProcessor(object):
 
   def delayed_bundle_application(self,
                                  op,  # type: operations.DoOperation
-                                 deferred_remainder  # type: Tuple[windowed_value.WindowedValue, Timestamp]
+                                 deferred_remainder  # type: common.SplitResultType
                                 ):
     # type: (...) -> beam_fn_api_pb2.DelayedBundleApplication
     assert op.input_info is not None
@@ -918,7 +919,10 @@ class BundleProcessor(object):
         application=self.construct_bundle_application(
             op, output_watermark, element_and_restriction))
 
-  def bundle_application(self, op, primary):
+  def bundle_application(self,
+                         op,  # type: operations.DoOperation
+                         primary  # type: common.SplitResultType
+                        ):
     ((element_and_restriction, output_watermark), _) = primary
     return self.construct_bundle_application(
         op, output_watermark, element_and_restriction)
@@ -1043,7 +1047,7 @@ class BundleProcessor(object):
 class ExecutionContext(object):
   def __init__(self):
     self.delayed_applications = [
-    ]  # type: List[Tuple[operations.DoOperation, Tuple[windowed_value.WindowedValue, Timestamp]]]
+    ]  # type: List[Tuple[operations.DoOperation, common.SplitResultType]]
 
 
 class BeamTransformFactory(object):
@@ -1412,7 +1416,7 @@ def create_par_do(
 
 
 def _create_pardo_operation(
-    factory,
+    factory,  # type: BeamTransformFactory
     transform_id,  # type: str
     transform_proto,  # type: beam_runner_api_pb2.PTransform
     consumers,
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
index e712a30..ec0a843 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -32,6 +32,7 @@ from builtins import filter
 from builtins import object
 from builtins import zip
 from typing import TYPE_CHECKING
+from typing import Any
 from typing import DefaultDict
 from typing import Dict
 from typing import FrozenSet
@@ -39,6 +40,7 @@ from typing import Hashable
 from typing import Iterator
 from typing import List
 from typing import Optional
+from typing import Tuple
 from typing import Union
 
 from apache_beam import pvalue
@@ -132,6 +134,7 @@ class ConsumerSet(Receiver):
     self.update_counters_finish()
 
   def try_split(self, fraction_of_remainder):
+    # type: (...) -> Optional[Any]
     # TODO(SDF): Consider supporting splitting each consumer individually.
     # This would never come up in the existing SDF expansion, but might
     # be useful to support fused SDF nodes.
@@ -169,8 +172,13 @@ class ConsumerSet(Receiver):
 
 
 class SingletonConsumerSet(ConsumerSet):
-  def __init__(
-      self, counter_factory, step_name, output_index, consumers, coder):
+  def __init__(self,
+               counter_factory,
+               step_name,
+               output_index,
+               consumers,  # type: List[Operation]
+               coder
+              ):
     assert len(consumers) == 1
     super(SingletonConsumerSet, self).__init__(
         counter_factory, step_name, output_index, consumers, coder)
@@ -183,6 +191,7 @@ class SingletonConsumerSet(ConsumerSet):
     self.update_counters_finish()
 
   def try_split(self, fraction_of_remainder):
+    # type: (...) -> Optional[Any]
     return self.consumer.try_split(fraction_of_remainder)
 
   def current_element_progress(self):
@@ -288,6 +297,7 @@ class Operation(object):
     return False
 
   def try_split(self, fraction_of_remainder):
+    # type: (...) -> Optional[Any]
     return None
 
   def current_element_progress(self):
@@ -774,10 +784,12 @@ class SdfProcessSizedElements(DoOperation):
           self.element_start_output_bytes = None
 
   def try_split(self, fraction_of_remainder):
+    # type: (...) -> Optional[Tuple[Tuple[DoOperation, common.SplitResultType], Tuple[DoOperation, common.SplitResultType]]]
     split = self.dofn_runner.try_split(fraction_of_remainder)
     if split:
       primary, residual = split
       return (self, primary), (self, residual)
+    return None
 
   def current_element_progress(self):
     with self.lock:

Reply via email to