This is an automated email from the ASF dual-hosted git repository.

tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new fbe61eab131 [Python] Bound the memory used for fnapi outbound data messages and receiving messages. (#38407)
fbe61eab131 is described below

commit fbe61eab13151dd95c80ebb0c39bab5e630ab73b
Author: Sam Whittle <[email protected]>
AuthorDate: Tue May 12 19:37:32 2026 +0200

    [Python] Bound the memory used for fnapi outbound data messages and receiving messages. (#38407)
    
    * [Python] Bound the memory used for fnapi outbound data messages.
    
    Previously an unbounded queue was used for pending data outputs to
    be sent over the fnapi to the runner. If outputs were being generated
    faster than the runner was consuming them, this would lead to memory
    growth and possible OOMs. This PR introduces a byte-limited queue
    data structure that is used instead to limit the # of bytes in the
    queue. This was preferred to just using a queue with max number of
    elements because the size of elements can vary greatly.  For batch
    pipelines they are likely large while for streaming pipelines there
    may be more small outputs.
    
    * monotonic and not shutdown restriction
    
    * change to not subclass queue.Queue and to be fair
    
    * fixups
    
    * add missing pxd file, fixup test
    
    * use 64-bit for size in pxd
    
    * address comments
    
    * add condition caching
---
 .../apache_beam/runners/worker/data_plane.py       |  70 +++---
 .../apache_beam/utils/byte_limited_queue.pxd       |  31 +++
 .../python/apache_beam/utils/byte_limited_queue.py | 204 ++++++++++++++++
 .../apache_beam/utils/byte_limited_queue_test.py   | 270 +++++++++++++++++++++
 sdks/python/setup.py                               |   1 +
 5 files changed, 547 insertions(+), 29 deletions(-)

diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py
index cbd28f8b0a3..a5589ac33a1 100644
--- a/sdks/python/apache_beam/runners/worker/data_plane.py
+++ b/sdks/python/apache_beam/runners/worker/data_plane.py
@@ -49,6 +49,7 @@ from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.portability.api import beam_fn_api_pb2_grpc
 from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
 from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
+from apache_beam.utils.byte_limited_queue import ByteLimitedQueue
 
 if TYPE_CHECKING:
   import apache_beam.coders.slow_stream
@@ -455,11 +456,14 @@ class _GrpcDataChannel(DataChannel):
 
   def __init__(self, data_buffer_time_limit_ms=0):
     # type: (int) -> None
+
     self._data_buffer_time_limit_ms = data_buffer_time_limit_ms
-    self._to_send = queue.Queue()  # type: queue.Queue[DataOrTimers]
+    self._to_send = ByteLimitedQueue(
+        maxsize=10000,
+        maxbytes=100 << 20)  # type: ByteLimitedQueue[DataOrTimers]
     self._received = collections.defaultdict(
-        lambda: queue.Queue(maxsize=5)
-    )  # type: DefaultDict[str, queue.Queue[DataOrTimers]]
+        lambda: ByteLimitedQueue(maxsize=5, maxbytes=100 << 20)
+    )  # type: DefaultDict[str, ByteLimitedQueue[DataOrTimers]]
 
     # Keep a cache of completed instructions. Data for completed instructions
     # must be discarded. See input_elements() and _clean_receiving_queue().
@@ -474,7 +478,7 @@ class _GrpcDataChannel(DataChannel):
 
   def close(self):
     # type: () -> None
-    self._to_send.put(self._WRITES_FINISHED)
+    self._to_send.put(self._WRITES_FINISHED, 0)
     self._closed = True
 
   def wait(self, timeout=None):
@@ -482,7 +486,7 @@ class _GrpcDataChannel(DataChannel):
     self._reads_finished.wait(timeout)
 
   def _receiving_queue(self, instruction_id):
-    # type: (str) -> Optional[queue.Queue[DataOrTimers]]
+    # type: (str) -> Optional[ByteLimitedQueue[DataOrTimers]]
 
     """
     Gets or creates queue for a instruction_id. Or, returns None if the
@@ -585,21 +589,19 @@ class _GrpcDataChannel(DataChannel):
     def add_to_send_queue(data):
       # type: (bytes) -> None
       if data:
-        self._to_send.put(
-            beam_fn_api_pb2.Elements.Data(
-                instruction_id=instruction_id,
-                transform_id=transform_id,
-                data=data))
+        elem = beam_fn_api_pb2.Elements.Data(
+            instruction_id=instruction_id, transform_id=transform_id, data=data)
+        self._to_send.put(elem, self._get_element_size_bytes(elem))
 
     def close_callback(data):
       # type: (bytes) -> None
       add_to_send_queue(data)
       # End of stream marker.
-      self._to_send.put(
-          beam_fn_api_pb2.Elements.Data(
-              instruction_id=instruction_id,
-              transform_id=transform_id,
-              is_last=True))
+      elem = beam_fn_api_pb2.Elements.Data(
+          instruction_id=instruction_id,
+          transform_id=transform_id,
+          is_last=True)
+      self._to_send.put(elem, self._get_element_size_bytes(elem))
 
     return ClosableOutputStream.create(
         close_callback, add_to_send_queue, self._data_buffer_time_limit_ms)
@@ -614,23 +616,23 @@ class _GrpcDataChannel(DataChannel):
     def add_to_send_queue(timer):
       # type: (bytes) -> None
       if timer:
-        self._to_send.put(
-            beam_fn_api_pb2.Elements.Timers(
-                instruction_id=instruction_id,
-                transform_id=transform_id,
-                timer_family_id=timer_family_id,
-                timers=timer,
-                is_last=False))
+        elem = beam_fn_api_pb2.Elements.Timers(
+            instruction_id=instruction_id,
+            transform_id=transform_id,
+            timer_family_id=timer_family_id,
+            timers=timer,
+            is_last=False)
+        self._to_send.put(elem, self._get_element_size_bytes(elem))
 
     def close_callback(timer):
       # type: (bytes) -> None
       add_to_send_queue(timer)
-      self._to_send.put(
-          beam_fn_api_pb2.Elements.Timers(
-              instruction_id=instruction_id,
-              transform_id=transform_id,
-              timer_family_id=timer_family_id,
-              is_last=True))
+      elem = beam_fn_api_pb2.Elements.Timers(
+          instruction_id=instruction_id,
+          transform_id=transform_id,
+          timer_family_id=timer_family_id,
+          is_last=True)
+      self._to_send.put(elem, self._get_element_size_bytes(elem))
 
     return ClosableOutputStream.create(
         close_callback, add_to_send_queue, self._data_buffer_time_limit_ms)
@@ -665,6 +667,15 @@ class _GrpcDataChannel(DataChannel):
             raise ValueError('Unexpected output element type %s' % type(stream))
         yield beam_fn_api_pb2.Elements(data=data_stream, timers=timer_stream)
 
+  def _get_element_size_bytes(self, element):
+    # type: (Union[beam_fn_api_pb2.Elements.Data, beam_fn_api_pb2.Elements.Timers]) -> int
+    if isinstance(element, beam_fn_api_pb2.Elements.Data):
+      return len(element.data)
+    elif isinstance(element, beam_fn_api_pb2.Elements.Timers):
+      return len(element.timers)
+    else:
+      return 0
+
   def _read_inputs(self, elements_iterator):
     # type: (Iterable[beam_fn_api_pb2.Elements]) -> None
 
@@ -691,7 +702,8 @@ class _GrpcDataChannel(DataChannel):
             next_discard_log_time = current_time + 10
           return
         try:
-          input_queue.put(element, timeout=1)
+          input_queue.put(
+              element, self._get_element_size_bytes(element), timeout=1)
           return
         except queue.Full:
           current_time = time.time()
diff --git a/sdks/python/apache_beam/utils/byte_limited_queue.pxd b/sdks/python/apache_beam/utils/byte_limited_queue.pxd
new file mode 100644
index 00000000000..396185e8e10
--- /dev/null
+++ b/sdks/python/apache_beam/utils/byte_limited_queue.pxd
@@ -0,0 +1,31 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# cython: overflowcheck=True
+
+cdef class ByteLimitedQueue(object):
+  cdef readonly Py_ssize_t max_elements
+  cdef readonly Py_ssize_t max_bytes
+  cdef readonly Py_ssize_t _byte_size
+  cdef readonly object _mutex
+  cdef readonly object _not_empty
+  cdef readonly object _waiting_writers
+  cdef readonly list _condition_pool
+  cdef readonly object _queue
+  cdef readonly Py_ssize_t _blocked_bytes
+
+  cpdef bint _can_fit(self, Py_ssize_t item_bytes) except -1
diff --git a/sdks/python/apache_beam/utils/byte_limited_queue.py b/sdks/python/apache_beam/utils/byte_limited_queue.py
new file mode 100644
index 00000000000..a6ff669800c
--- /dev/null
+++ b/sdks/python/apache_beam/utils/byte_limited_queue.py
@@ -0,0 +1,204 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A thread-safe queue that limits capacity by total byte size."""
+
+import collections
+import queue
+import threading
+import time
+import types
+
+
+class ByteLimitedQueue(object):
+  """A fair queue that limits by both element count and total byte size.
+
+  A single element is allowed to exceed the maxbytes to avoid deadlock.
+  """
+  __class_getitem__ = classmethod(types.GenericAlias)
+
+  def __init__(
+      self,
+      maxsize=0,  # type: int
+      maxbytes=0,  # type: int
+  ):
+    # type: (...) -> None
+
+    """Initializes a ByteLimitedQueue.
+
+    Args:
+      maxsize: The maximum number of items allowed in the queue. If 0 or
+        negative, there is no limit on the number of elements.
+      maxbytes: The maximum accumulated bytes allowed in the queue. If 0 or
+        negative, there is no limit on the total bytes of the elements.
+    """
+    self.max_elements = maxsize
+    self.max_bytes = maxbytes
+
+    self._byte_size = 0
+    self._blocked_bytes = 0
+    self._mutex = threading.Lock()
+    self._not_empty = threading.Condition(self._mutex)
+
+    self._waiting_writers = collections.deque()
+    self._condition_pool = []
+    self._queue = collections.deque()
+
+  def put(self, item, item_bytes, *, block=True, timeout=None):
+    """Put an item into the queue.
+
+    If the queue is full, block until a free slot is available, unless `block`
+    is false or a timeout occurs.
+
+    Args:
+      item: The item to put into the queue.
+      item_bytes: The size of the item.
+      block: If True, block until space is available. If False, raise queue.Full
+        immediately if the queue is full.
+      timeout: If block is True, wait for at most `timeout` seconds. If None,
+        block indefinitely.
+
+    Raises:
+      ValueError: If timeout or item_bytes is negative.
+      queue.Full: If the queue is full and block is False or the timeout occurs.
+    """
+    if timeout is not None and timeout < 0:
+      raise ValueError("'timeout' must be a non-negative number")
+    if item_bytes < 0:
+      raise ValueError("'item_bytes' must be a non-negative number")
+
+    with self._mutex:
+      if not self._waiting_writers and self._can_fit(item_bytes):
+        self._queue.append((item, item_bytes))
+        self._byte_size += item_bytes
+        self._not_empty.notify()
+        return
+
+      if not block:
+        raise queue.Full
+
+      # Reuse or create a condition
+      my_cond = (
+          self._condition_pool.pop()
+          if self._condition_pool else threading.Condition(self._mutex))
+
+      endtime = time.monotonic() + timeout if timeout is not None else None
+
+      try:
+        self._blocked_bytes += item_bytes
+        self._waiting_writers.append(my_cond)
+        while True:
+          if timeout is None:
+            my_cond.wait()
+          else:
+            remaining = endtime - time.monotonic()
+            if remaining <= 0.0:
+              raise queue.Full
+            my_cond.wait(remaining)
+
+          if self._waiting_writers[0] is my_cond and self._can_fit(item_bytes):
+            break
+
+        self._queue.append((item, item_bytes))
+        self._byte_size += item_bytes
+        self._not_empty.notify()
+      finally:
+        self._blocked_bytes -= item_bytes
+        if self._waiting_writers:
+          was_first = (self._waiting_writers[0] is my_cond)
+          if was_first:
+            self._waiting_writers.popleft()
+          else:
+            self._waiting_writers.remove(my_cond)
+          self._condition_pool.append(my_cond)
+          if was_first and self._waiting_writers:
+            self._waiting_writers[0].notify()
+
+  def get(self, *, block=True, timeout=None):
+    """Remove and return an item from the queue.
+
+    If the queue is empty, block until an item is available, unless `block`
+    is false or a timeout occurs.
+
+    Args:
+      block: If True, block until an item is available. If False, raise
+        queue.Empty immediately if the queue is empty.
+      timeout: If block is True, wait for at most `timeout` seconds. If None,
+        block indefinitely.
+
+    Returns:
+      The item removed from the queue.
+
+    Raises:
+      ValueError: If timeout is negative.
+      queue.Empty: If the queue is empty and block is False or the timeout
+        occurs.
+    """
+    if timeout is not None and timeout < 0:
+      raise ValueError("'timeout' must be a non-negative number")
+
+    with self._mutex:
+      if not block:
+        if not self._queue:
+          raise queue.Empty
+      elif timeout is None:
+        while not self._queue:
+          self._not_empty.wait()
+      else:
+        endtime = time.monotonic() + timeout
+        while not self._queue:
+          remaining = endtime - time.monotonic()
+          if remaining <= 0.0:
+            raise queue.Empty
+          self._not_empty.wait(remaining)
+
+      item, item_bytes = self._queue.popleft()
+      self._byte_size -= item_bytes
+
+      if self._waiting_writers:
+        self._waiting_writers[0].notify()
+
+      return item
+
+  def get_nowait(self):
+    """Remove and return an item from the queue without blocking."""
+    return self.get(block=False)
+
+  def byte_size(self):
+    """Return the total byte size of elements in the queue."""
+    with self._mutex:
+      return self._byte_size
+
+  def blocked_byte_size(self):
+    """Return the total byte size of elements in the queue that are blocked."""
+    with self._mutex:
+      return self._blocked_bytes
+
+  def qsize(self):
+    """Return the total number of elements in the queue."""
+    with self._mutex:
+      return len(self._queue)
+
+  def _can_fit(self, item_bytes):
+    # Always let in a single element, regardless of size.
+    if not self._queue:
+      return True
+    if self.max_elements > 0 and len(self._queue) >= self.max_elements:
+      return False
+    if self.max_bytes > 0 and self._byte_size + item_bytes > self.max_bytes:
+      return False
+    return True
diff --git a/sdks/python/apache_beam/utils/byte_limited_queue_test.py b/sdks/python/apache_beam/utils/byte_limited_queue_test.py
new file mode 100644
index 00000000000..27ccb242184
--- /dev/null
+++ b/sdks/python/apache_beam/utils/byte_limited_queue_test.py
@@ -0,0 +1,270 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for byte-limited queue."""
+
+import queue
+import threading
+import time
+import unittest
+
+from apache_beam.utils.byte_limited_queue import ByteLimitedQueue
+
+
+class ByteLimitedQueueTest(unittest.TestCase):
+  def test_unbounded(self):
+    bq = ByteLimitedQueue()
+    for i in range(201):
+      bq.put(str(i), i)
+    self.assertEqual(bq.byte_size(), sum(range(201)))
+    self.assertEqual(bq.qsize(), 201)
+
+  def test_put_and_get(self):
+    bq = ByteLimitedQueue(maxbytes=200)
+    bq.put('50', 50)
+    bq.put('140', 140)
+    self.assertEqual(bq.byte_size(), 190)
+    self.assertEqual(bq.qsize(), 2)
+    # Putting another would exceed 200.
+    with self.assertRaises(queue.Full):
+      bq.put('20', 20, block=False)
+    bq.put('10', 10, block=False)
+    self.assertEqual(bq.byte_size(), 200)
+    self.assertEqual(bq.qsize(), 3)
+
+    self.assertEqual(bq.get(), '50')
+    self.assertEqual(bq.byte_size(), 150)
+    self.assertEqual(bq.qsize(), 2)
+    bq.put('20', 20, block=False)
+
+  def test_dual_limit(self):
+    # Queue limits: at most 3 items, OR at most 100 item bytes.
+    bq = ByteLimitedQueue(maxsize=3, maxbytes=100)
+    bq.put('30', 30)
+    bq.put('40', 40)
+    bq.put('20', 20)
+    self.assertEqual(bq.byte_size(), 90)
+    self.assertEqual(bq.qsize(), 3)
+    # Full on element count (size=3).
+    with self.assertRaises(queue.Full):
+      bq.put('10', 10, block=False)
+    self.assertEqual(bq.get(), '30')
+    self.assertEqual(bq.get(), '40')
+    bq.put('10', 10)
+    # Full on byte count
+    with self.assertRaises(queue.Full):
+      bq.put('90', 90, block=False)
+    self.assertEqual(bq.get(), '20')
+    bq.put('90', 90, block=False)
+
+  def test_multithreading(self):
+    bq = ByteLimitedQueue(maxsize=0, maxbytes=100)
+    received = []
+
+    def producer():
+      for i in range(101):
+        bq.put(str(i), i)
+
+    poison_pill = 'POISON'
+
+    def consumer():
+      while True:
+        item = bq.get()
+        if item == poison_pill:
+          break
+        received.append(int(item))
+
+    t1 = threading.Thread(target=producer)
+    t2 = threading.Thread(target=producer)
+    t3 = threading.Thread(target=consumer)
+
+    t1.start()
+    t2.start()
+    t3.start()
+
+    t1.join()
+    t2.join()
+    bq.put(poison_pill, 0)
+
+    t3.join()
+
+    self.assertEqual(len(received), 202)
+    self.assertEqual(sum(received), 2 * sum(range(101)))
+
+  def test_put_timeout(self):
+    bq = ByteLimitedQueue(maxsize=0, maxbytes=10)
+    bq.put('10', 10)
+
+    # The queue is completely full. A timeout put should raise queue.Full.
+    with self.assertRaises(queue.Full):
+      bq.put('5', 5, timeout=0.01)
+
+    def delayed_consumer():
+      time.sleep(0.05)
+      bq.get()
+
+    # Start a thread that will free up space after 50ms.
+    t = threading.Thread(target=delayed_consumer)
+    t.start()
+
+    # The put should succeed once the consumer runs, use a high timeout to
+    # avoid flakiness.
+    bq.put('item', 5, timeout=60)
+    t.join()
+
+  def test_get_timeout(self):
+    bq = ByteLimitedQueue(maxsize=0, maxbytes=100)
+    with self.assertRaises(queue.Empty):
+      bq.get(block=False)
+    with self.assertRaises(queue.Empty):
+      bq.get(timeout=0.0)
+    with self.assertRaises(queue.Empty):
+      bq.get(timeout=.01)
+
+    bq.put('1', 1)
+    self.assertEqual('1', bq.get(timeout=0))
+
+    bq.put('2', 2)
+    self.assertEqual('2', bq.get(timeout=0.1))
+
+    def delayed_producer():
+      time.sleep(0.05)
+      bq.put('3', 3)
+
+    # Start a thread that will produce soon
+    t = threading.Thread(target=delayed_producer)
+    t.start()
+
+    # The get should succeed once the producer runs, use a high timeout to
+    # avoid flakiness.
+    self.assertEqual('3', bq.get(timeout=60))
+    t.join()
+
+  def test_negative_timeout(self):
+    bq = ByteLimitedQueue()
+    # Putting an item with a negative timeout should raise ValueError.
+    with self.assertRaises(ValueError):
+      bq.put('5', 5, timeout=-1)
+    with self.assertRaises(ValueError):
+      bq.get(timeout=-1)
+
+  def test_single_element_override(self):
+    bq = ByteLimitedQueue(maxbytes=10)
+    # An item of size 50 exceeds maxbytes 10, but should be admitted
+    # immediately without blocking since the queue is currently empty!
+    bq.put('50', 50, block=False)
+    self.assertEqual(bq.qsize(), 1)
+    self.assertEqual(bq.byte_size(), 50)
+
+  def test_fairness(self):
+    bq = ByteLimitedQueue(maxbytes=10)
+    # Put an initial item so that the queue is not empty,
+    # causing the subsequent large item to block.
+    bq.put('first', 2)
+    self.assertEqual(bq.blocked_byte_size(), 0)
+
+    def producer(item, size):
+      bq.put(item, size)
+
+    # Add an item in a background thread that should block due to exceeding
+    # the limit.
+    t1 = threading.Thread(target=producer, args=('too_large', 9))
+    t1.start()
+
+    # Wait until the background write is queued.
+    while bq.blocked_byte_size() < 1:
+      time.sleep(0.005)
+    self.assertEqual(bq.blocked_byte_size(), 9)
+
+    # Add smaller items afterwards.
+    t2 = threading.Thread(target=producer, args=('small1', 1))
+    t2.start()
+
+    while bq.blocked_byte_size() < 10:
+      time.sleep(0.005)
+    self.assertEqual(bq.blocked_byte_size(), 10)
+
+    t3 = threading.Thread(target=producer, args=('small2', 1))
+    t3.start()
+
+    while bq.blocked_byte_size() < 11:
+      time.sleep(0.005)
+    self.assertEqual(bq.blocked_byte_size(), 11)
+
+    # Verify all items are received in order.
+    self.assertEqual(bq.get(), 'first')
+    t1.join()
+    t2.join()
+    self.assertEqual(bq.get(), 'too_large')
+    t3.join()
+    self.assertEqual(bq.get(), 'small1')
+    self.assertEqual(bq.get(), 'small2')
+
+  def test_blocked_waiter_timeout_multiple(self):
+    bq = ByteLimitedQueue(maxbytes=10)
+    bq.put('initial', 5)
+
+    status = []
+    lock = threading.Lock()
+
+    def producer(name, size, timeout_val):
+      try:
+        bq.put(name, size, timeout=timeout_val)
+        with lock:
+          status.append((name, 'success'))
+      except queue.Full:
+        with lock:
+          status.append((name, 'timeout'))
+
+    threads = []
+    threads.append(threading.Thread(target=producer, args=('t1', 8, 0.2)))
+    threads.append(threading.Thread(target=producer, args=('t2', 8, 60.0)))
+    threads.append(threading.Thread(target=producer, args=('t3', 3, 0.1)))
+    threads.append(threading.Thread(target=producer, args=('t4', 3, 60.0)))
+    threads.append(threading.Thread(target=producer, args=('t5', 3, 0.1)))
+    for t in threads:
+      t.start()
+
+    # Wait for the short-timeout threads.
+    threads[4].join()
+    threads[2].join()
+    threads[0].join()
+
+    # Now waiting writers should just be t2 and t4
+    self.assertEqual(bq.blocked_byte_size(), 11)
+
+    self.assertEqual(bq.get(), 'initial')
+    threads[1].join()
+    self.assertGreater(bq.blocked_byte_size(), 0)
+
+    elem = bq.get()
+    self.assertTrue(elem == 't2' or elem == 't4')
+    threads[3].join()
+    self.assertEqual(bq.blocked_byte_size(), 0)
+    elem = bq.get()
+    self.assertTrue(elem == 't2' or elem == 't4')
+
+    with lock:
+      self.assertIn(('t1', 'timeout'), status)
+      self.assertIn(('t2', 'success'), status)
+      self.assertIn(('t3', 'timeout'), status)
+      self.assertIn(('t4', 'success'), status)
+      self.assertIn(('t5', 'timeout'), status)
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index b3fb98d8b0e..45781a44c4b 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -368,6 +368,7 @@ if __name__ == '__main__':
         'apache_beam/runners/worker/operations.py',
         'apache_beam/transforms/cy_combiners.py',
         'apache_beam/transforms/stats.py',
+        'apache_beam/utils/byte_limited_queue.py',
         'apache_beam/utils/counters.py',
         'apache_beam/utils/windowed_value.py',
     ])

Reply via email to