This is an automated email from the ASF dual-hosted git repository.
scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 708932149f1 extract semaphore logic out of WeightBoundedQueue to allow
for sharing the weigher (#32905)
708932149f1 is described below
commit 708932149f1a86af87d5d599335157801a98ec9e
Author: martin trieu <[email protected]>
AuthorDate: Wed Nov 6 03:24:13 2024 -0600
extract semaphore logic out of WeightBoundedQueue to allow for sharing the
weigher (#32905)
---
.../dataflow/worker/StreamingDataflowWorker.java | 2 +
.../worker/streaming/WeightedBoundedQueue.java | 45 +++++-------
.../worker/streaming/WeightedSemaphore.java | 53 ++++++++++++++
.../worker/windmill/client/commits/Commits.java | 36 ++++++++++
.../commits/StreamingApplianceWorkCommitter.java | 8 +--
.../commits/StreamingEngineWorkCommitter.java | 16 +++--
.../worker/streaming/WeightBoundedQueueTest.java | 81 +++++++++++++++++-----
.../commits/StreamingEngineWorkCommitterTest.java | 2 +
8 files changed, 184 insertions(+), 59 deletions(-)
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index ff72add83e4..6ce60283735 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -65,6 +65,7 @@ import
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
import
org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool;
+import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commits;
import
org.apache.beam.runners.dataflow.worker.windmill.client.commits.CompleteCommit;
import
org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingApplianceWorkCommitter;
import
org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingEngineWorkCommitter;
@@ -199,6 +200,7 @@ public final class StreamingDataflowWorker {
this.workCommitter =
windmillServiceEnabled
? StreamingEngineWorkCommitter.builder()
+ .setCommitByteSemaphore(Commits.maxCommitByteSemaphore())
.setCommitWorkStreamFactory(
WindmillStreamPool.create(
numCommitThreads,
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java
index f2893f3e719..5f039be7b00 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java
@@ -18,33 +18,24 @@
package org.apache.beam.runners.dataflow.worker.streaming;
import java.util.concurrent.LinkedBlockingQueue;
-import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
-import java.util.function.Function;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.checkerframework.checker.nullness.qual.Nullable;
-/** Bounded set of queues, with a maximum total weight. */
+/** Queue bounded by a {@link WeightedSemaphore}. */
public final class WeightedBoundedQueue<V> {
private final LinkedBlockingQueue<V> queue;
- private final int maxWeight;
- private final Semaphore limit;
- private final Function<V, Integer> weigher;
+ private final WeightedSemaphore<V> weightedSemaphore;
private WeightedBoundedQueue(
- LinkedBlockingQueue<V> linkedBlockingQueue,
- int maxWeight,
- Semaphore limit,
- Function<V, Integer> weigher) {
+ LinkedBlockingQueue<V> linkedBlockingQueue, WeightedSemaphore<V>
weightedSemaphore) {
this.queue = linkedBlockingQueue;
- this.maxWeight = maxWeight;
- this.limit = limit;
- this.weigher = weigher;
+ this.weightedSemaphore = weightedSemaphore;
}
- public static <V> WeightedBoundedQueue<V> create(int maxWeight, Function<V,
Integer> weigherFn) {
- return new WeightedBoundedQueue<>(
- new LinkedBlockingQueue<>(), maxWeight, new Semaphore(maxWeight,
true), weigherFn);
+ public static <V> WeightedBoundedQueue<V> create(WeightedSemaphore<V>
weightedSemaphore) {
+ return new WeightedBoundedQueue<>(new LinkedBlockingQueue<>(),
weightedSemaphore);
}
/**
@@ -52,15 +43,15 @@ public final class WeightedBoundedQueue<V> {
* limit.
*/
public void put(V value) {
- limit.acquireUninterruptibly(weigher.apply(value));
+ weightedSemaphore.acquireUninterruptibly(value);
queue.add(value);
}
/** Returns and removes the next value, or null if there is no such value. */
public @Nullable V poll() {
- V result = queue.poll();
+ @Nullable V result = queue.poll();
if (result != null) {
- limit.release(weigher.apply(result));
+ weightedSemaphore.release(result);
}
return result;
}
@@ -76,26 +67,22 @@ public final class WeightedBoundedQueue<V> {
* @throws InterruptedException if interrupted while waiting
*/
public @Nullable V poll(long timeout, TimeUnit unit) throws
InterruptedException {
- V result = queue.poll(timeout, unit);
+ @Nullable V result = queue.poll(timeout, unit);
if (result != null) {
- limit.release(weigher.apply(result));
+ weightedSemaphore.release(result);
}
return result;
}
/** Returns and removes the next value, or blocks until one is available. */
- public @Nullable V take() throws InterruptedException {
+ public V take() throws InterruptedException {
V result = queue.take();
- limit.release(weigher.apply(result));
+ weightedSemaphore.release(result);
return result;
}
- /** Returns the current weight of the queue. */
- public int queuedElementsWeight() {
- return maxWeight - limit.availablePermits();
- }
-
- public int size() {
+ @VisibleForTesting
+ int size() {
return queue.size();
}
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedSemaphore.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedSemaphore.java
new file mode 100644
index 00000000000..d92dd07cb1a
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedSemaphore.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming;
+
+import java.util.concurrent.Semaphore;
+import java.util.function.Function;
+
+public final class WeightedSemaphore<V> {
+ private final int maxWeight;
+ private final Semaphore limit;
+ private final Function<V, Integer> weigher;
+
+ private WeightedSemaphore(int maxWeight, Semaphore limit, Function<V,
Integer> weigher) {
+ this.maxWeight = maxWeight;
+ this.limit = limit;
+ this.weigher = weigher;
+ }
+
+ public static <V> WeightedSemaphore<V> create(int maxWeight, Function<V,
Integer> weigherFn) {
+ return new WeightedSemaphore<>(maxWeight, new Semaphore(maxWeight, true),
weigherFn);
+ }
+
+ public void acquireUninterruptibly(V value) {
+ limit.acquireUninterruptibly(computePermits(value));
+ }
+
+ public void release(V value) {
+ limit.release(computePermits(value));
+ }
+
+ private int computePermits(V value) {
+ return Math.min(weigher.apply(value), maxWeight);
+ }
+
+ public int currentWeight() {
+ return maxWeight - limit.availablePermits();
+ }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java
new file mode 100644
index 00000000000..498e90f78e2
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.commits;
+
+import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore;
+import org.apache.beam.sdk.annotations.Internal;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+
+/** Utility class for commits. */
+@Internal
+public final class Commits {
+
+ /** Max bytes of commits queued on the user worker. */
+ @VisibleForTesting static final int MAX_QUEUED_COMMITS_BYTES = 500 << 20; // 500MB
+
+ private Commits() {}
+
+ public static WeightedSemaphore<Commit> maxCommitByteSemaphore() {
+ return WeightedSemaphore.create(MAX_QUEUED_COMMITS_BYTES, Commit::getSize);
+ }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java
index 6889764afe6..20b95b0661d 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java
@@ -42,7 +42,6 @@ import org.slf4j.LoggerFactory;
public final class StreamingApplianceWorkCommitter implements WorkCommitter {
private static final Logger LOG =
LoggerFactory.getLogger(StreamingApplianceWorkCommitter.class);
private static final long TARGET_COMMIT_BUNDLE_BYTES = 32 << 20;
- private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB
private final Consumer<CommitWorkRequest> commitWorkFn;
private final WeightedBoundedQueue<Commit> commitQueue;
@@ -53,9 +52,7 @@ public final class StreamingApplianceWorkCommitter implements
WorkCommitter {
private StreamingApplianceWorkCommitter(
Consumer<CommitWorkRequest> commitWorkFn, Consumer<CompleteCommit>
onCommitComplete) {
this.commitWorkFn = commitWorkFn;
- this.commitQueue =
- WeightedBoundedQueue.create(
- MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES,
commit.getSize()));
+ this.commitQueue =
WeightedBoundedQueue.create(Commits.maxCommitByteSemaphore());
this.commitWorkers =
Executors.newSingleThreadExecutor(
new ThreadFactoryBuilder()
@@ -73,10 +70,9 @@ public final class StreamingApplianceWorkCommitter
implements WorkCommitter {
}
@Override
- @SuppressWarnings("FutureReturnValueIgnored")
public void start() {
if (!commitWorkers.isShutdown()) {
- commitWorkers.submit(this::commitLoop);
+ commitWorkers.execute(this::commitLoop);
}
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
index bf1007bc4bf..85fa1d67c6c 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
@@ -28,6 +28,7 @@ import java.util.function.Supplier;
import javax.annotation.Nullable;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue;
+import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore;
import org.apache.beam.runners.dataflow.worker.streaming.Work;
import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
@@ -46,7 +47,6 @@ import org.slf4j.LoggerFactory;
public final class StreamingEngineWorkCommitter implements WorkCommitter {
private static final Logger LOG =
LoggerFactory.getLogger(StreamingEngineWorkCommitter.class);
private static final int TARGET_COMMIT_BATCH_KEYS = 5;
- private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB
private static final String NO_BACKEND_WORKER_TOKEN = "";
private final Supplier<CloseableStream<CommitWorkStream>>
commitWorkStreamFactory;
@@ -61,11 +61,10 @@ public final class StreamingEngineWorkCommitter implements
WorkCommitter {
Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory,
int numCommitSenders,
Consumer<CompleteCommit> onCommitComplete,
- String backendWorkerToken) {
+ String backendWorkerToken,
+ WeightedSemaphore<Commit> commitByteSemaphore) {
this.commitWorkStreamFactory = commitWorkStreamFactory;
- this.commitQueue =
- WeightedBoundedQueue.create(
- MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES,
commit.getSize()));
+ this.commitQueue = WeightedBoundedQueue.create(commitByteSemaphore);
this.commitSenders =
Executors.newFixedThreadPool(
numCommitSenders,
@@ -90,12 +89,11 @@ public final class StreamingEngineWorkCommitter implements
WorkCommitter {
}
@Override
- @SuppressWarnings("FutureReturnValueIgnored")
public void start() {
Preconditions.checkState(
isRunning.compareAndSet(false, true), "Multiple calls to WorkCommitter.start().");
for (int i = 0; i < numCommitSenders; i++) {
- commitSenders.submit(this::streamingCommitLoop);
+ commitSenders.execute(this::streamingCommitLoop);
}
}
@@ -166,6 +164,8 @@ public final class StreamingEngineWorkCommitter implements
WorkCommitter {
return;
}
}
+
+ // take() blocks until a value is available in the commitQueue.
Preconditions.checkNotNull(initialCommit);
if (initialCommit.work().isFailed()) {
@@ -258,6 +258,8 @@ public final class StreamingEngineWorkCommitter implements
WorkCommitter {
Builder setCommitWorkStreamFactory(
Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory);
+ Builder setCommitByteSemaphore(WeightedSemaphore<Commit>
commitByteSemaphore);
+
Builder setNumCommitSenders(int numCommitSenders);
Builder setOnCommitComplete(Consumer<CompleteCommit> onCommitComplete);
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java
index 4f035c88774..c71001fbeee 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java
@@ -22,6 +22,7 @@ import static org.junit.Assert.assertNull;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
+import javax.annotation.Nullable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
@@ -30,27 +31,29 @@ import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public class WeightBoundedQueueTest {
- @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
private static final int MAX_WEIGHT = 10;
+ @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
@Test
public void testPut_hasCapacity() {
- WeightedBoundedQueue<Integer> queue =
- WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+ WeightedSemaphore<Integer> weightedSemaphore =
+ WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+ WeightedBoundedQueue<Integer> queue =
WeightedBoundedQueue.create(weightedSemaphore);
int insertedValue = 1;
queue.put(insertedValue);
- assertEquals(insertedValue, queue.queuedElementsWeight());
+ assertEquals(insertedValue, weightedSemaphore.currentWeight());
assertEquals(1, queue.size());
assertEquals(insertedValue, (int) queue.poll());
}
@Test
public void testPut_noCapacity() throws InterruptedException {
- WeightedBoundedQueue<Integer> queue =
- WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+ WeightedSemaphore<Integer> weightedSemaphore =
+ WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+ WeightedBoundedQueue<Integer> queue =
WeightedBoundedQueue.create(weightedSemaphore);
// Insert value that takes all the capacity into the queue.
queue.put(MAX_WEIGHT);
@@ -71,7 +74,7 @@ public class WeightBoundedQueueTest {
// Should only see the first value in the queue, since the queue is at capacity. thread2
// should be blocked.
- assertEquals(MAX_WEIGHT, queue.queuedElementsWeight());
+ assertEquals(MAX_WEIGHT, weightedSemaphore.currentWeight());
assertEquals(1, queue.size());
// Poll the queue, pulling off the only value inside and freeing up the capacity in the queue.
@@ -80,14 +83,15 @@ public class WeightBoundedQueueTest {
// Wait for the putThread which was previously blocked due to the queue being at capacity.
putThread.join();
- assertEquals(MAX_WEIGHT, queue.queuedElementsWeight());
+ assertEquals(MAX_WEIGHT, weightedSemaphore.currentWeight());
assertEquals(1, queue.size());
}
@Test
public void testPoll() {
- WeightedBoundedQueue<Integer> queue =
- WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+ WeightedSemaphore<Integer> weightedSemaphore =
+ WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+ WeightedBoundedQueue<Integer> queue =
WeightedBoundedQueue.create(weightedSemaphore);
int insertedValue1 = 1;
int insertedValue2 = 2;
@@ -95,7 +99,7 @@ public class WeightBoundedQueueTest {
queue.put(insertedValue1);
queue.put(insertedValue2);
- assertEquals(insertedValue1 + insertedValue2,
queue.queuedElementsWeight());
+ assertEquals(insertedValue1 + insertedValue2,
weightedSemaphore.currentWeight());
assertEquals(2, queue.size());
assertEquals(insertedValue1, (int) queue.poll());
assertEquals(1, queue.size());
@@ -104,7 +108,8 @@ public class WeightBoundedQueueTest {
@Test
public void testPoll_withTimeout() throws InterruptedException {
WeightedBoundedQueue<Integer> queue =
- WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+ WeightedBoundedQueue.create(
+ WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT,
i)));
int pollWaitTimeMillis = 10000;
int insertedValue1 = 1;
@@ -132,7 +137,8 @@ public class WeightBoundedQueueTest {
@Test
public void testPoll_withTimeout_timesOut() throws InterruptedException {
WeightedBoundedQueue<Integer> queue =
- WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+ WeightedBoundedQueue.create(
+ WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT,
i)));
int defaultPollResult = -10;
int pollWaitTimeMillis = 100;
int insertedValue1 = 1;
@@ -144,13 +150,17 @@ public class WeightBoundedQueueTest {
Thread pollThread =
new Thread(
() -> {
- int polled;
+ @Nullable Integer polled;
try {
polled = queue.poll(pollWaitTimeMillis, TimeUnit.MILLISECONDS);
- pollResult.set(polled);
+ if (polled != null) {
+ pollResult.set(polled);
+ }
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
+
+ assertNull(polled);
});
pollThread.start();
@@ -164,7 +174,8 @@ public class WeightBoundedQueueTest {
@Test
public void testPoll_emptyQueue() {
WeightedBoundedQueue<Integer> queue =
- WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+ WeightedBoundedQueue.create(
+ WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT,
i)));
assertNull(queue.poll());
}
@@ -172,7 +183,8 @@ public class WeightBoundedQueueTest {
@Test
public void testTake() throws InterruptedException {
WeightedBoundedQueue<Integer> queue =
- WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+ WeightedBoundedQueue.create(
+ WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT,
i)));
AtomicInteger value = new AtomicInteger();
// Should block until value is available
@@ -194,4 +206,39 @@ public class WeightBoundedQueueTest {
assertEquals(MAX_WEIGHT, value.get());
}
+
+ @Test
+ public void testPut_sharedWeigher() throws InterruptedException {
+ WeightedSemaphore<Integer> weigher =
+ WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+ WeightedBoundedQueue<Integer> queue1 =
WeightedBoundedQueue.create(weigher);
+ WeightedBoundedQueue<Integer> queue2 =
WeightedBoundedQueue.create(weigher);
+
+ // Insert value that takes all the weight into the queue1.
+ queue1.put(MAX_WEIGHT);
+
+ // Try to insert a value into the queue2. This will block since there is no capacity in the
+ // weigher.
+ Thread putThread = new Thread(() -> queue2.put(MAX_WEIGHT));
+ putThread.start();
+ // Should only see the first value in the queue, since the queue is at capacity. putThread
+ // should be blocked. The weight should be the same however, since queue1 and queue2 are sharing
+ // the weigher.
+ Thread.sleep(100);
+ assertEquals(MAX_WEIGHT, weigher.currentWeight());
+ assertEquals(1, queue1.size());
+ assertEquals(0, queue2.size());
+
+ // Poll queue1, pulling off the only value inside and freeing up the capacity in the weigher.
+ queue1.poll();
+
+ // Wait for the putThread which was previously blocked due to the weigher being at capacity.
+ putThread.join();
+
+ assertEquals(MAX_WEIGHT, weigher.currentWeight());
+ assertEquals(1, queue2.size());
+ queue2.poll();
+ assertEquals(0, queue2.size());
+ assertEquals(0, weigher.currentWeight());
+ }
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java
index 546a2883e3b..c05a4dd340d 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java
@@ -121,6 +121,7 @@ public class StreamingEngineWorkCommitterTest {
private WorkCommitter createWorkCommitter(Consumer<CompleteCommit>
onCommitComplete) {
return StreamingEngineWorkCommitter.builder()
+ .setCommitByteSemaphore(Commits.maxCommitByteSemaphore())
.setCommitWorkStreamFactory(commitWorkStreamFactory)
.setOnCommitComplete(onCommitComplete)
.build();
@@ -342,6 +343,7 @@ public class StreamingEngineWorkCommitterTest {
Set<CompleteCommit> completeCommits = Collections.newSetFromMap(new
ConcurrentHashMap<>());
workCommitter =
StreamingEngineWorkCommitter.builder()
+ .setCommitByteSemaphore(Commits.maxCommitByteSemaphore())
.setCommitWorkStreamFactory(commitWorkStreamFactory)
.setNumCommitSenders(5)
.setOnCommitComplete(completeCommits::add)