This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 708932149f1 extract semaphore logic out of WeightBoundedQueue to allow for sharing the weigher (#32905)
708932149f1 is described below

commit 708932149f1a86af87d5d599335157801a98ec9e
Author: martin trieu <[email protected]>
AuthorDate: Wed Nov 6 03:24:13 2024 -0600

    extract semaphore logic out of WeightBoundedQueue to allow for sharing the weigher (#32905)
---
 .../dataflow/worker/StreamingDataflowWorker.java   |  2 +
 .../worker/streaming/WeightedBoundedQueue.java     | 45 +++++-------
 .../worker/streaming/WeightedSemaphore.java        | 53 ++++++++++++++
 .../worker/windmill/client/commits/Commits.java    | 36 ++++++++++
 .../commits/StreamingApplianceWorkCommitter.java   |  8 +--
 .../commits/StreamingEngineWorkCommitter.java      | 16 +++--
 .../worker/streaming/WeightBoundedQueueTest.java   | 81 +++++++++++++++++-----
 .../commits/StreamingEngineWorkCommitterTest.java  |  2 +
 8 files changed, 184 insertions(+), 59 deletions(-)

diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index ff72add83e4..6ce60283735 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -65,6 +65,7 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
 import 
org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer;
 import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
 import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool;
+import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commits;
 import 
org.apache.beam.runners.dataflow.worker.windmill.client.commits.CompleteCommit;
 import 
org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingApplianceWorkCommitter;
 import 
org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingEngineWorkCommitter;
@@ -199,6 +200,7 @@ public final class StreamingDataflowWorker {
     this.workCommitter =
         windmillServiceEnabled
             ? StreamingEngineWorkCommitter.builder()
+                .setCommitByteSemaphore(Commits.maxCommitByteSemaphore())
                 .setCommitWorkStreamFactory(
                     WindmillStreamPool.create(
                             numCommitThreads,
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java
index f2893f3e719..5f039be7b00 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java
@@ -18,33 +18,24 @@
 package org.apache.beam.runners.dataflow.worker.streaming;
 
 import java.util.concurrent.LinkedBlockingQueue;
-import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
-import java.util.function.Function;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.checkerframework.checker.nullness.qual.Nullable;
 
-/** Bounded set of queues, with a maximum total weight. */
+/** Queue bounded by a {@link WeightedSemaphore}. */
 public final class WeightedBoundedQueue<V> {
 
   private final LinkedBlockingQueue<V> queue;
-  private final int maxWeight;
-  private final Semaphore limit;
-  private final Function<V, Integer> weigher;
+  private final WeightedSemaphore<V> weightedSemaphore;
 
   private WeightedBoundedQueue(
-      LinkedBlockingQueue<V> linkedBlockingQueue,
-      int maxWeight,
-      Semaphore limit,
-      Function<V, Integer> weigher) {
+      LinkedBlockingQueue<V> linkedBlockingQueue, WeightedSemaphore<V> 
weightedSemaphore) {
     this.queue = linkedBlockingQueue;
-    this.maxWeight = maxWeight;
-    this.limit = limit;
-    this.weigher = weigher;
+    this.weightedSemaphore = weightedSemaphore;
   }
 
-  public static <V> WeightedBoundedQueue<V> create(int maxWeight, Function<V, 
Integer> weigherFn) {
-    return new WeightedBoundedQueue<>(
-        new LinkedBlockingQueue<>(), maxWeight, new Semaphore(maxWeight, 
true), weigherFn);
+  public static <V> WeightedBoundedQueue<V> create(WeightedSemaphore<V> 
weightedSemaphore) {
+    return new WeightedBoundedQueue<>(new LinkedBlockingQueue<>(), 
weightedSemaphore);
   }
 
   /**
@@ -52,15 +43,15 @@ public final class WeightedBoundedQueue<V> {
    * limit.
    */
   public void put(V value) {
-    limit.acquireUninterruptibly(weigher.apply(value));
+    weightedSemaphore.acquireUninterruptibly(value);
     queue.add(value);
   }
 
   /** Returns and removes the next value, or null if there is no such value. */
   public @Nullable V poll() {
-    V result = queue.poll();
+    @Nullable V result = queue.poll();
     if (result != null) {
-      limit.release(weigher.apply(result));
+      weightedSemaphore.release(result);
     }
     return result;
   }
@@ -76,26 +67,22 @@ public final class WeightedBoundedQueue<V> {
    * @throws InterruptedException if interrupted while waiting
    */
   public @Nullable V poll(long timeout, TimeUnit unit) throws 
InterruptedException {
-    V result = queue.poll(timeout, unit);
+    @Nullable V result = queue.poll(timeout, unit);
     if (result != null) {
-      limit.release(weigher.apply(result));
+      weightedSemaphore.release(result);
     }
     return result;
   }
 
   /** Returns and removes the next value, or blocks until one is available. */
-  public @Nullable V take() throws InterruptedException {
+  public V take() throws InterruptedException {
     V result = queue.take();
-    limit.release(weigher.apply(result));
+    weightedSemaphore.release(result);
     return result;
   }
 
-  /** Returns the current weight of the queue. */
-  public int queuedElementsWeight() {
-    return maxWeight - limit.availablePermits();
-  }
-
-  public int size() {
+  @VisibleForTesting
+  int size() {
     return queue.size();
   }
 }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedSemaphore.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedSemaphore.java
new file mode 100644
index 00000000000..d92dd07cb1a
--- /dev/null
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedSemaphore.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming;
+
+import java.util.concurrent.Semaphore;
+import java.util.function.Function;
+
+public final class WeightedSemaphore<V> {
+  private final int maxWeight;
+  private final Semaphore limit;
+  private final Function<V, Integer> weigher;
+
+  private WeightedSemaphore(int maxWeight, Semaphore limit, Function<V, 
Integer> weigher) {
+    this.maxWeight = maxWeight;
+    this.limit = limit;
+    this.weigher = weigher;
+  }
+
+  public static <V> WeightedSemaphore<V> create(int maxWeight, Function<V, 
Integer> weigherFn) {
+    return new WeightedSemaphore<>(maxWeight, new Semaphore(maxWeight, true), 
weigherFn);
+  }
+
+  public void acquireUninterruptibly(V value) {
+    limit.acquireUninterruptibly(computePermits(value));
+  }
+
+  public void release(V value) {
+    limit.release(computePermits(value));
+  }
+
+  private int computePermits(V value) {
+    return Math.min(weigher.apply(value), maxWeight);
+  }
+
+  public int currentWeight() {
+    return maxWeight - limit.availablePermits();
+  }
+}
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java
new file mode 100644
index 00000000000..498e90f78e2
--- /dev/null
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.commits;
+
+import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore;
+import org.apache.beam.sdk.annotations.Internal;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+
+/** Utility class for commits. */
+@Internal
+public final class Commits {
+
+  /** Max bytes of commits queued on the user worker. */
+  @VisibleForTesting static final int MAX_QUEUED_COMMITS_BYTES = 500 << 20; // 
500MB
+
+  private Commits() {}
+
+  public static WeightedSemaphore<Commit> maxCommitByteSemaphore() {
+    return WeightedSemaphore.create(MAX_QUEUED_COMMITS_BYTES, Commit::getSize);
+  }
+}
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java
index 6889764afe6..20b95b0661d 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java
@@ -42,7 +42,6 @@ import org.slf4j.LoggerFactory;
 public final class StreamingApplianceWorkCommitter implements WorkCommitter {
   private static final Logger LOG = 
LoggerFactory.getLogger(StreamingApplianceWorkCommitter.class);
   private static final long TARGET_COMMIT_BUNDLE_BYTES = 32 << 20;
-  private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB
 
   private final Consumer<CommitWorkRequest> commitWorkFn;
   private final WeightedBoundedQueue<Commit> commitQueue;
@@ -53,9 +52,7 @@ public final class StreamingApplianceWorkCommitter implements 
WorkCommitter {
   private StreamingApplianceWorkCommitter(
       Consumer<CommitWorkRequest> commitWorkFn, Consumer<CompleteCommit> 
onCommitComplete) {
     this.commitWorkFn = commitWorkFn;
-    this.commitQueue =
-        WeightedBoundedQueue.create(
-            MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, 
commit.getSize()));
+    this.commitQueue = 
WeightedBoundedQueue.create(Commits.maxCommitByteSemaphore());
     this.commitWorkers =
         Executors.newSingleThreadExecutor(
             new ThreadFactoryBuilder()
@@ -73,10 +70,9 @@ public final class StreamingApplianceWorkCommitter 
implements WorkCommitter {
   }
 
   @Override
-  @SuppressWarnings("FutureReturnValueIgnored")
   public void start() {
     if (!commitWorkers.isShutdown()) {
-      commitWorkers.submit(this::commitLoop);
+      commitWorkers.execute(this::commitLoop);
     }
   }
 
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
index bf1007bc4bf..85fa1d67c6c 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
@@ -28,6 +28,7 @@ import java.util.function.Supplier;
 import javax.annotation.Nullable;
 import javax.annotation.concurrent.ThreadSafe;
 import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue;
+import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore;
 import org.apache.beam.runners.dataflow.worker.streaming.Work;
 import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream;
 import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
@@ -46,7 +47,6 @@ import org.slf4j.LoggerFactory;
 public final class StreamingEngineWorkCommitter implements WorkCommitter {
   private static final Logger LOG = 
LoggerFactory.getLogger(StreamingEngineWorkCommitter.class);
   private static final int TARGET_COMMIT_BATCH_KEYS = 5;
-  private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB
   private static final String NO_BACKEND_WORKER_TOKEN = "";
 
   private final Supplier<CloseableStream<CommitWorkStream>> 
commitWorkStreamFactory;
@@ -61,11 +61,10 @@ public final class StreamingEngineWorkCommitter implements 
WorkCommitter {
       Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory,
       int numCommitSenders,
       Consumer<CompleteCommit> onCommitComplete,
-      String backendWorkerToken) {
+      String backendWorkerToken,
+      WeightedSemaphore<Commit> commitByteSemaphore) {
     this.commitWorkStreamFactory = commitWorkStreamFactory;
-    this.commitQueue =
-        WeightedBoundedQueue.create(
-            MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, 
commit.getSize()));
+    this.commitQueue = WeightedBoundedQueue.create(commitByteSemaphore);
     this.commitSenders =
         Executors.newFixedThreadPool(
             numCommitSenders,
@@ -90,12 +89,11 @@ public final class StreamingEngineWorkCommitter implements 
WorkCommitter {
   }
 
   @Override
-  @SuppressWarnings("FutureReturnValueIgnored")
   public void start() {
     Preconditions.checkState(
         isRunning.compareAndSet(false, true), "Multiple calls to 
WorkCommitter.start().");
     for (int i = 0; i < numCommitSenders; i++) {
-      commitSenders.submit(this::streamingCommitLoop);
+      commitSenders.execute(this::streamingCommitLoop);
     }
   }
 
@@ -166,6 +164,8 @@ public final class StreamingEngineWorkCommitter implements 
WorkCommitter {
             return;
           }
         }
+
+        // take() blocks until a value is available in the commitQueue.
         Preconditions.checkNotNull(initialCommit);
 
         if (initialCommit.work().isFailed()) {
@@ -258,6 +258,8 @@ public final class StreamingEngineWorkCommitter implements 
WorkCommitter {
     Builder setCommitWorkStreamFactory(
         Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory);
 
+    Builder setCommitByteSemaphore(WeightedSemaphore<Commit> 
commitByteSemaphore);
+
     Builder setNumCommitSenders(int numCommitSenders);
 
     Builder setOnCommitComplete(Consumer<CompleteCommit> onCommitComplete);
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java
index 4f035c88774..c71001fbeee 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java
@@ -22,6 +22,7 @@ import static org.junit.Assert.assertNull;
 
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
+import javax.annotation.Nullable;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.Timeout;
@@ -30,27 +31,29 @@ import org.junit.runners.JUnit4;
 
 @RunWith(JUnit4.class)
 public class WeightBoundedQueueTest {
-  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
   private static final int MAX_WEIGHT = 10;
+  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
 
   @Test
   public void testPut_hasCapacity() {
-    WeightedBoundedQueue<Integer> queue =
-        WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+    WeightedSemaphore<Integer> weightedSemaphore =
+        WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+    WeightedBoundedQueue<Integer> queue = 
WeightedBoundedQueue.create(weightedSemaphore);
 
     int insertedValue = 1;
 
     queue.put(insertedValue);
 
-    assertEquals(insertedValue, queue.queuedElementsWeight());
+    assertEquals(insertedValue, weightedSemaphore.currentWeight());
     assertEquals(1, queue.size());
     assertEquals(insertedValue, (int) queue.poll());
   }
 
   @Test
   public void testPut_noCapacity() throws InterruptedException {
-    WeightedBoundedQueue<Integer> queue =
-        WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+    WeightedSemaphore<Integer> weightedSemaphore =
+        WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+    WeightedBoundedQueue<Integer> queue = 
WeightedBoundedQueue.create(weightedSemaphore);
 
     // Insert value that takes all the capacity into the queue.
     queue.put(MAX_WEIGHT);
@@ -71,7 +74,7 @@ public class WeightBoundedQueueTest {
 
     // Should only see the first value in the queue, since the queue is at 
capacity.  thread2
     // should be blocked.
-    assertEquals(MAX_WEIGHT, queue.queuedElementsWeight());
+    assertEquals(MAX_WEIGHT, weightedSemaphore.currentWeight());
     assertEquals(1, queue.size());
 
     // Poll the queue, pulling off the only value inside and freeing up the 
capacity in the queue.
@@ -80,14 +83,15 @@ public class WeightBoundedQueueTest {
     // Wait for the putThread which was previously blocked due to the queue 
being at capacity.
     putThread.join();
 
-    assertEquals(MAX_WEIGHT, queue.queuedElementsWeight());
+    assertEquals(MAX_WEIGHT, weightedSemaphore.currentWeight());
     assertEquals(1, queue.size());
   }
 
   @Test
   public void testPoll() {
-    WeightedBoundedQueue<Integer> queue =
-        WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+    WeightedSemaphore<Integer> weightedSemaphore =
+        WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+    WeightedBoundedQueue<Integer> queue = 
WeightedBoundedQueue.create(weightedSemaphore);
 
     int insertedValue1 = 1;
     int insertedValue2 = 2;
@@ -95,7 +99,7 @@ public class WeightBoundedQueueTest {
     queue.put(insertedValue1);
     queue.put(insertedValue2);
 
-    assertEquals(insertedValue1 + insertedValue2, 
queue.queuedElementsWeight());
+    assertEquals(insertedValue1 + insertedValue2, 
weightedSemaphore.currentWeight());
     assertEquals(2, queue.size());
     assertEquals(insertedValue1, (int) queue.poll());
     assertEquals(1, queue.size());
@@ -104,7 +108,8 @@ public class WeightBoundedQueueTest {
   @Test
   public void testPoll_withTimeout() throws InterruptedException {
     WeightedBoundedQueue<Integer> queue =
-        WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+        WeightedBoundedQueue.create(
+            WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, 
i)));
     int pollWaitTimeMillis = 10000;
     int insertedValue1 = 1;
 
@@ -132,7 +137,8 @@ public class WeightBoundedQueueTest {
   @Test
   public void testPoll_withTimeout_timesOut() throws InterruptedException {
     WeightedBoundedQueue<Integer> queue =
-        WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+        WeightedBoundedQueue.create(
+            WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, 
i)));
     int defaultPollResult = -10;
     int pollWaitTimeMillis = 100;
     int insertedValue1 = 1;
@@ -144,13 +150,17 @@ public class WeightBoundedQueueTest {
     Thread pollThread =
         new Thread(
             () -> {
-              int polled;
+              @Nullable Integer polled;
               try {
                 polled = queue.poll(pollWaitTimeMillis, TimeUnit.MILLISECONDS);
-                pollResult.set(polled);
+                if (polled != null) {
+                  pollResult.set(polled);
+                }
               } catch (InterruptedException e) {
                 throw new RuntimeException(e);
               }
+
+              assertNull(polled);
             });
 
     pollThread.start();
@@ -164,7 +174,8 @@ public class WeightBoundedQueueTest {
   @Test
   public void testPoll_emptyQueue() {
     WeightedBoundedQueue<Integer> queue =
-        WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+        WeightedBoundedQueue.create(
+            WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, 
i)));
 
     assertNull(queue.poll());
   }
@@ -172,7 +183,8 @@ public class WeightBoundedQueueTest {
   @Test
   public void testTake() throws InterruptedException {
     WeightedBoundedQueue<Integer> queue =
-        WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+        WeightedBoundedQueue.create(
+            WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, 
i)));
 
     AtomicInteger value = new AtomicInteger();
     // Should block until value is available
@@ -194,4 +206,39 @@ public class WeightBoundedQueueTest {
 
     assertEquals(MAX_WEIGHT, value.get());
   }
+
+  @Test
+  public void testPut_sharedWeigher() throws InterruptedException {
+    WeightedSemaphore<Integer> weigher =
+        WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+    WeightedBoundedQueue<Integer> queue1 = 
WeightedBoundedQueue.create(weigher);
+    WeightedBoundedQueue<Integer> queue2 = 
WeightedBoundedQueue.create(weigher);
+
+    // Insert value that takes all the weight into the queue1.
+    queue1.put(MAX_WEIGHT);
+
+    // Try to insert a value into the queue2. This will block since there is 
no capacity in the
+    // weigher.
+    Thread putThread = new Thread(() -> queue2.put(MAX_WEIGHT));
+    putThread.start();
+    // Should only see the first value in the queue, since the queue is at 
capacity. putThread
+    // should be blocked. The weight should be the same however, since queue1 
and queue2 are sharing
+    // the weigher.
+    Thread.sleep(100);
+    assertEquals(MAX_WEIGHT, weigher.currentWeight());
+    assertEquals(1, queue1.size());
+    assertEquals(0, queue2.size());
+
+    // Poll queue1, pulling off the only value inside and freeing up the 
capacity in the weigher.
+    queue1.poll();
+
+    // Wait for the putThread which was previously blocked due to the weigher 
being at capacity.
+    putThread.join();
+
+    assertEquals(MAX_WEIGHT, weigher.currentWeight());
+    assertEquals(1, queue2.size());
+    queue2.poll();
+    assertEquals(0, queue2.size());
+    assertEquals(0, weigher.currentWeight());
+  }
 }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java
index 546a2883e3b..c05a4dd340d 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java
@@ -121,6 +121,7 @@ public class StreamingEngineWorkCommitterTest {
 
   private WorkCommitter createWorkCommitter(Consumer<CompleteCommit> 
onCommitComplete) {
     return StreamingEngineWorkCommitter.builder()
+        .setCommitByteSemaphore(Commits.maxCommitByteSemaphore())
         .setCommitWorkStreamFactory(commitWorkStreamFactory)
         .setOnCommitComplete(onCommitComplete)
         .build();
@@ -342,6 +343,7 @@ public class StreamingEngineWorkCommitterTest {
     Set<CompleteCommit> completeCommits = Collections.newSetFromMap(new 
ConcurrentHashMap<>());
     workCommitter =
         StreamingEngineWorkCommitter.builder()
+            .setCommitByteSemaphore(Commits.maxCommitByteSemaphore())
             .setCommitWorkStreamFactory(commitWorkStreamFactory)
             .setNumCommitSenders(5)
             .setOnCommitComplete(completeCommits::add)

Reply via email to