This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 216f0d9  Fix grpc data read thread block with finished instruction_id 
in _GrpcDataChannel (#15293)
216f0d9 is described below

commit 216f0d9f80c3f2a169e139a0818a1b6b059f3219
Author: Minbo Bae <[email protected]>
AuthorDate: Wed Aug 18 06:58:17 2021 +0900

    Fix grpc data read thread block with finished instruction_id in 
_GrpcDataChannel (#15293)
---
 .../apache_beam/runners/worker/data_plane.py       | 193 ++++++++++++++-------
 1 file changed, 133 insertions(+), 60 deletions(-)

diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py 
b/sdks/python/apache_beam/runners/worker/data_plane.py
index ffe48d1..e89a669 100644
--- a/sdks/python/apache_beam/runners/worker/data_plane.py
+++ b/sdks/python/apache_beam/runners/worker/data_plane.py
@@ -70,6 +70,10 @@ _LOGGER = logging.getLogger(__name__)
 _DEFAULT_SIZE_FLUSH_THRESHOLD = 10 << 20  # 10MB
 _DEFAULT_TIME_FLUSH_THRESHOLD_MS = 0  # disable time-based flush by default
 
+# Keep a set of completed instructions to discard late received data. The set
+# can have up to _MAX_CLEANED_INSTRUCTIONS items. See _GrpcDataChannel.
+_MAX_CLEANED_INSTRUCTIONS = 10000
+
 
 class ClosableOutputStream(OutputStream):
   """A Outputstream for use with CoderImpls that has a close() method."""
@@ -95,10 +99,11 @@ class ClosableOutputStream(OutputStream):
     pass
 
   @staticmethod
-  def create(close_callback,  # type: Optional[Callable[[bytes], None]]
-             flush_callback,  # type: Optional[Callable[[bytes], None]]
-             data_buffer_time_limit_ms  # type: int
-            ):
+  def create(
+      close_callback,  # type: Optional[Callable[[bytes], None]]
+      flush_callback,  # type: Optional[Callable[[bytes], None]]
+      data_buffer_time_limit_ms  # type: int
+  ):
     # type: (...) -> ClosableOutputStream
     if data_buffer_time_limit_ms > 0:
       return TimeBasedBufferingClosableOutputStream(
@@ -112,12 +117,12 @@ class ClosableOutputStream(OutputStream):
 
 class SizeBasedBufferingClosableOutputStream(ClosableOutputStream):
   """A size-based buffering OutputStream."""
-
-  def __init__(self,
-               close_callback=None,  # type: Optional[Callable[[bytes], None]]
-               flush_callback=None,  # type: Optional[Callable[[bytes], None]]
-               size_flush_threshold=_DEFAULT_SIZE_FLUSH_THRESHOLD  # type: int
-              ):
+  def __init__(
+      self,
+      close_callback=None,  # type: Optional[Callable[[bytes], None]]
+      flush_callback=None,  # type: Optional[Callable[[bytes], None]]
+      size_flush_threshold=_DEFAULT_SIZE_FLUSH_THRESHOLD  # type: int
+  ):
     super(SizeBasedBufferingClosableOutputStream, 
self).__init__(close_callback)
     self._flush_callback = flush_callback
     self._size_flush_threshold = size_flush_threshold
@@ -187,12 +192,13 @@ class TimeBasedBufferingClosableOutputStream(
 
 class PeriodicThread(threading.Thread):
   """Call a function periodically with the specified number of seconds"""
-  def __init__(self,
-               interval,  # type: float
-               function,  # type: Callable
-               args=None,  # type: Optional[Iterable]
-               kwargs=None  # type: Optional[Mapping[str, Any]]
-              ):
+  def __init__(
+      self,
+      interval,  # type: float
+      function,  # type: Callable
+      args=None,  # type: Optional[Iterable]
+      kwargs=None  # type: Optional[Mapping[str, Any]]
+  ):
     # type: (...) -> None
     threading.Thread.__init__(self)
     self._interval = interval
@@ -243,11 +249,12 @@ class DataChannel(metaclass=abc.ABCMeta):
     data_channel.close()
   """
   @abc.abstractmethod
-  def input_elements(self,
-                     instruction_id,  # type: str
-                     expected_inputs,  # type: Collection[Union[str, 
Tuple[str, str]]]
-                     abort_callback=None  # type: Optional[Callable[[], bool]]
-                    ):
+  def input_elements(
+      self,
+      instruction_id,  # type: str
+      expected_inputs,  # type: Collection[Union[str, Tuple[str, str]]]
+      abort_callback=None  # type: Optional[Callable[[], bool]]
+  ):
     # type: (...) -> Iterator[DataOrTimers]
 
     """Returns an iterable of all Element.Data and Element.Timers bundles for
@@ -281,11 +288,12 @@ class DataChannel(metaclass=abc.ABCMeta):
     raise NotImplementedError(type(self))
 
   @abc.abstractmethod
-  def output_timer_stream(self,
-                          instruction_id,  # type: str
-                          transform_id,  # type: str
-                          timer_family_id  # type: str
-                          ):
+  def output_timer_stream(
+      self,
+      instruction_id,  # type: str
+      transform_id,  # type: str
+      timer_family_id  # type: str
+  ):
     # type: (...) -> ClosableOutputStream
 
     """Returns an output stream written timers to transform_id.
@@ -328,11 +336,12 @@ class InMemoryDataChannel(DataChannel):
     # type: () -> InMemoryDataChannel
     return self._inverse
 
-  def input_elements(self,
+  def input_elements(
+      self,
       instruction_id,  # type: str
-      unused_expected_inputs,   # type: Any
+      unused_expected_inputs,  # type: Any
       abort_callback=None  # type: Optional[Callable[[], bool]]
-                     ):
+  ):
     # type: (...) -> Iterator[DataOrTimers]
     other_inputs = []
     for element in self._inputs:
@@ -347,11 +356,12 @@ class InMemoryDataChannel(DataChannel):
         other_inputs.append(element)
     self._inputs = other_inputs
 
-  def output_timer_stream(self,
-                          instruction_id,  # type: str
-                          transform_id,  # type: str
-                          timer_family_id  # type: str
-                          ):
+  def output_timer_stream(
+      self,
+      instruction_id,  # type: str
+      transform_id,  # type: str
+      timer_family_id  # type: str
+  ):
     # type: (...) -> ClosableOutputStream
     def add_to_inverse_output(timer):
       # type: (bytes) -> None
@@ -410,6 +420,12 @@ class _GrpcDataChannel(DataChannel):
         lambda: queue.Queue(maxsize=5)
     )  # type: DefaultDict[str, queue.Queue[DataOrTimers]]
 
+    # Keep a cache of completed instructions. Data for completed instructions
+    # must be discarded. See input_elements() and _clean_receiving_queue().
+    # OrderedDict is used as FIFO set with the value being always `True`.
+    self._cleaned_instruction_ids = collections.OrderedDict(
+    )  # type: collections.OrderedDict[str, bool]
+
     self._receive_lock = threading.Lock()
     self._reads_finished = threading.Event()
     self._closed = False
@@ -425,20 +441,37 @@ class _GrpcDataChannel(DataChannel):
     self._reads_finished.wait(timeout)
 
   def _receiving_queue(self, instruction_id):
-    # type: (str) -> queue.Queue[DataOrTimers]
+    # type: (str) -> Optional[queue.Queue[DataOrTimers]]
+
+    """
+    Gets or creates the queue for an instruction_id, or returns None if the
+    instruction_id is already cleaned up. This is best-effort as we track
+    a limited number of cleaned-up instructions.
+    """
     with self._receive_lock:
+      if instruction_id in self._cleaned_instruction_ids:
+        return None
       return self._received[instruction_id]
 
   def _clean_receiving_queue(self, instruction_id):
     # type: (str) -> None
+
+    """
+    Removes the queue and adds the instruction_id to the cleaned-up list. The
+    instruction_id cannot be reused for a new queue.
+    """
     with self._receive_lock:
       self._received.pop(instruction_id)
+      self._cleaned_instruction_ids[instruction_id] = True
+      while len(self._cleaned_instruction_ids) > _MAX_CLEANED_INSTRUCTIONS:
+        self._cleaned_instruction_ids.popitem(last=False)
 
-  def input_elements(self,
+  def input_elements(
+      self,
       instruction_id,  # type: str
-      expected_inputs,   # type: Collection[Union[str, Tuple[str, str]]]
+      expected_inputs,  # type: Collection[Union[str, Tuple[str, str]]]
       abort_callback=None  # type: Optional[Callable[[], bool]]
-    ):
+  ):
 
     # type: (...) -> Iterator[DataOrTimers]
 
@@ -451,6 +484,8 @@ class _GrpcDataChannel(DataChannel):
       expected_inputs(collection): expected inputs, include both data and 
timer.
     """
     received = self._receiving_queue(instruction_id)
+    if received is None:
+      raise RuntimeError('Instruction cleaned up already %s' % instruction_id)
     done_inputs = set()  # type: Set[Union[str, Tuple[str, str]]]
     abort_callback = abort_callback or (lambda: False)
     try:
@@ -508,11 +543,12 @@ class _GrpcDataChannel(DataChannel):
     return ClosableOutputStream.create(
         close_callback, add_to_send_queue, self._data_buffer_time_limit_ms)
 
-  def output_timer_stream(self,
-                          instruction_id,  # type: str
-                          transform_id,  # type: str
-                          timer_family_id  # type: str
-                          ):
+  def output_timer_stream(
+      self,
+      instruction_id,  # type: str
+      transform_id,  # type: str
+      timer_family_id  # type: str
+  ):
     # type: (...) -> ClosableOutputStream
     def add_to_send_queue(timer):
       # type: (bytes) -> None
@@ -566,12 +602,48 @@ class _GrpcDataChannel(DataChannel):
 
   def _read_inputs(self, elements_iterator):
     # type: (Iterable[beam_fn_api_pb2.Elements]) -> None
+
+    next_discard_log_time = 0  # type: float
+
+    def _put_queue(instruction_id, element):
+      # type: (str, Union[beam_fn_api_pb2.Elements.Data, 
beam_fn_api_pb2.Elements.Timers]) -> None
+
+      """
+      Puts element to the queue of the instruction_id, or discards it if the
+      instruction_id is already cleaned up.
+      """
+      nonlocal next_discard_log_time
+      start_time = time.time()
+      next_waiting_log_time = start_time + 300
+      while True:
+        input_queue = self._receiving_queue(instruction_id)
+        if input_queue is None:
+          current_time = time.time()
+          if next_discard_log_time <= current_time:
+            # Log every 10 seconds across all _put_queue calls
+            _LOGGER.info(
+                'Discard inputs for cleaned up instruction: %s', 
instruction_id)
+            next_discard_log_time = current_time + 10
+          return
+        try:
+          input_queue.put(element, timeout=1)
+          return
+        except queue.Full:
+          current_time = time.time()
+          if next_waiting_log_time <= current_time:
+            # Log every 5 mins in each _put_queue call
+            _LOGGER.info(
+                'Waiting on input queue of instruction: %s for %.2f seconds',
+                instruction_id,
+                current_time - start_time)
+            next_waiting_log_time = current_time + 300
+
     try:
       for elements in elements_iterator:
         for timer in elements.timers:
-          self._receiving_queue(timer.instruction_id).put(timer)
+          _put_queue(timer.instruction_id, timer)
         for data in elements.data:
-          self._receiving_queue(data.instruction_id).put(data)
+          _put_queue(data.instruction_id, data)
     except:  # pylint: disable=bare-except
       if not self._closed:
         _LOGGER.exception('Failed to read inputs in the data plane.')
@@ -592,11 +664,11 @@ class _GrpcDataChannel(DataChannel):
 
 class GrpcClientDataChannel(_GrpcDataChannel):
   """A DataChannel wrapping the client side of a BeamFnData connection."""
-
-  def __init__(self,
-               data_stub,  # type: beam_fn_api_pb2_grpc.BeamFnDataStub
-               data_buffer_time_limit_ms=0  # type: int
-               ):
+  def __init__(
+      self,
+      data_stub,  # type: beam_fn_api_pb2_grpc.BeamFnDataStub
+      data_buffer_time_limit_ms=0  # type: int
+  ):
     # type: (...) -> None
     super(GrpcClientDataChannel, self).__init__(data_buffer_time_limit_ms)
     self.set_inputs(data_stub.Data(self._write_outputs()))
@@ -618,10 +690,11 @@ class 
BeamFnDataServicer(beam_fn_api_pb2_grpc.BeamFnDataServicer):
     with self._lock:
       return self._connections_by_worker_id[worker_id]
 
-  def Data(self,
-           elements_iterator,  # type: Iterable[beam_fn_api_pb2.Elements]
-           context  # type: Any
-          ):
+  def Data(
+      self,
+      elements_iterator,  # type: Iterable[beam_fn_api_pb2.Elements]
+      context  # type: Any
+  ):
     # type: (...) -> Iterator[beam_fn_api_pb2.Elements]
     worker_id = dict(context.invocation_metadata())['worker_id']
     data_conn = self.get_conn_by_worker_id(worker_id)
@@ -659,12 +732,12 @@ class GrpcClientDataChannelFactory(DataChannelFactory):
 
   Caches the created channels by ``data descriptor url``.
   """
-
-  def __init__(self,
-               credentials=None,  # type: Any
-               worker_id=None,  # type: Optional[str]
-               data_buffer_time_limit_ms=0  # type: int
-               ):
+  def __init__(
+      self,
+      credentials=None,  # type: Any
+      worker_id=None,  # type: Optional[str]
+      data_buffer_time_limit_ms=0  # type: int
+  ):
     # type: (...) -> None
     self._data_channel_cache = {}  # type: Dict[str, GrpcClientDataChannel]
     self._lock = threading.Lock()

Reply via email to