This is an automated email from the ASF dual-hosted git repository.
scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new c298da550f8 Refactor commit logic out of StreamingDataflowWorker
(#30312)
c298da550f8 is described below
commit c298da550f88b1fd489bc25c08a40d484833e428
Author: martin trieu <[email protected]>
AuthorDate: Fri Mar 15 02:01:18 2024 -0700
Refactor commit logic out of StreamingDataflowWorker (#30312)
---
.../dataflow/worker/StreamingDataflowWorker.java | 237 +++-------------
.../dataflow/worker/WindmillComputationKey.java | 5 +
.../client/CloseableStream.java} | 30 +-
.../worker/windmill/client/WindmillStreamPool.java | 7 +
.../client/commits}/Commit.java | 10 +-
.../windmill/client/commits/CompleteCommit.java | 67 +++++
.../commits/StreamingApplianceWorkCommitter.java | 167 +++++++++++
.../commits/StreamingEngineWorkCommitter.java | 233 ++++++++++++++++
.../windmill/client/commits/WorkCommitter.java | 54 ++++
.../worker/windmill/state/WindmillStateCache.java | 5 +
.../dataflow/worker/FakeWindmillServer.java | 32 ++-
.../worker/StreamingDataflowWorkerTest.java | 2 +-
.../StreamingApplianceWorkCommitterTest.java | 140 ++++++++++
.../commits/StreamingEngineWorkCommitterTest.java | 308 +++++++++++++++++++++
14 files changed, 1077 insertions(+), 220 deletions(-)
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index 6f1bb0847bc..4c3ffd08a0b 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -87,17 +87,14 @@ import
org.apache.beam.runners.dataflow.worker.status.DebugCapture.Capturable;
import
org.apache.beam.runners.dataflow.worker.status.LastExceptionDataProvider;
import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
import org.apache.beam.runners.dataflow.worker.status.WorkerStatusPages;
-import org.apache.beam.runners.dataflow.worker.streaming.Commit;
import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
import org.apache.beam.runners.dataflow.worker.streaming.ExecutionState;
import
org.apache.beam.runners.dataflow.worker.streaming.KeyCommitTooLargeException;
import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey;
import org.apache.beam.runners.dataflow.worker.streaming.StageInfo;
-import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue;
import org.apache.beam.runners.dataflow.worker.streaming.Work;
import org.apache.beam.runners.dataflow.worker.streaming.Work.State;
import
org.apache.beam.runners.dataflow.worker.streaming.WorkHeartbeatResponseProcessor;
-import org.apache.beam.runners.dataflow.worker.streaming.WorkId;
import
org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher;
import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor;
@@ -110,9 +107,13 @@ import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribut
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
import
org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer;
-import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool;
+import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commit;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.commits.CompleteCommit;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingApplianceWorkCommitter;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingEngineWorkCommitter;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
import
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.ChannelzServlet;
import
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillServer;
import
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillStreamFactory;
@@ -217,9 +218,6 @@ public class StreamingDataflowWorker {
final WindmillStateCache stateCache;
// Maps from computation ids to per-computation state.
private final ConcurrentMap<String, ComputationState> computationMap;
- private final WeightedBoundedQueue<Commit> commitQueue =
- WeightedBoundedQueue.create(
- MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES,
commit.getSize()));
// Cache of tokens to commit callbacks.
// Using Cache with time eviction policy helps us to prevent memory leak
when callback ids are
// discarded by Dataflow service and calling commitCallback is best-effort.
@@ -234,8 +232,6 @@ public class StreamingDataflowWorker {
private final BoundedQueueExecutor workUnitExecutor;
private final WindmillServerStub windmillServer;
private final Thread dispatchThread;
- @VisibleForTesting final ImmutableList<Thread> commitThreads;
- private final AtomicLong activeCommitBytes = new AtomicLong();
private final AtomicLong previousTimeAtMaxThreads = new AtomicLong();
private final AtomicBoolean running = new AtomicBoolean();
private final SideInputStateFetcher sideInputStateFetcher;
@@ -296,6 +292,7 @@ public class StreamingDataflowWorker {
private final DataflowExecutionStateSampler sampler =
DataflowExecutionStateSampler.instance();
private final ActiveWorkRefresher activeWorkRefresher;
+ private final WorkCommitter workCommitter;
private StreamingDataflowWorker(
WindmillServerStub windmillServer,
@@ -403,29 +400,6 @@ public class StreamingDataflowWorker {
dispatchThread.setPriority(Thread.MIN_PRIORITY);
dispatchThread.setName("DispatchThread");
- int numCommitThreads = 1;
- if (windmillServiceEnabled && options.getWindmillServiceCommitThreads() >
0) {
- numCommitThreads = options.getWindmillServiceCommitThreads();
- }
-
- ImmutableList.Builder<Thread> commitThreadsBuilder =
ImmutableList.builder();
- for (int i = 0; i < numCommitThreads; ++i) {
- Thread commitThread =
- new Thread(
- () -> {
- if (windmillServiceEnabled) {
- streamingCommitLoop();
- } else {
- commitLoop();
- }
- });
- commitThread.setDaemon(true);
- commitThread.setPriority(Thread.MAX_PRIORITY);
- commitThread.setName("CommitThread " + i);
- commitThreadsBuilder.add(commitThread);
- }
- commitThreads = commitThreadsBuilder.build();
-
this.publishCounters = publishCounters;
this.clientId = clientId;
this.windmillServer = windmillServer;
@@ -438,6 +412,21 @@ public class StreamingDataflowWorker {
this.sideInputStateFetcher =
new
SideInputStateFetcher(metricTrackingWindmillServer::getSideInputData, options);
+ int numCommitThreads = 1;
+ if (windmillServiceEnabled && options.getWindmillServiceCommitThreads() >
0) {
+ numCommitThreads = options.getWindmillServiceCommitThreads();
+ }
+
+ this.workCommitter =
+ windmillServiceEnabled
+ ? StreamingEngineWorkCommitter.create(
+ WindmillStreamPool.create(
+ NUM_COMMIT_STREAMS, COMMIT_STREAM_TIMEOUT,
windmillServer::commitWorkStream)
+ ::getCloseableStream,
+ numCommitThreads,
+ this::onCompleteCommit)
+ : StreamingApplianceWorkCommitter.create(
+ windmillServer::commitWork, this::onCompleteCommit);
// Register standard file systems.
FileSystems.setDefaultPipelineOptions(options);
@@ -705,6 +694,11 @@ public class StreamingDataflowWorker {
return workUnitExecutor.executorQueueIsEmpty();
}
+ @VisibleForTesting
+ int numCommitThreads() {
+ return workCommitter.parallelism();
+ }
+
@SuppressWarnings("FutureReturnValueIgnored")
public void start() {
running.set(true);
@@ -716,7 +710,6 @@ public class StreamingDataflowWorker {
memoryMonitorThread.start();
dispatchThread.start();
- commitThreads.forEach(Thread::start);
sampler.start();
// Periodically report workers counters and other updates.
@@ -778,7 +771,7 @@ public class StreamingDataflowWorker {
TimeUnit.SECONDS);
scheduledExecutors.add(statusPageTimer);
}
-
+ workCommitter.start();
reportHarnessStartup();
}
@@ -834,12 +827,8 @@ public class StreamingDataflowWorker {
running.set(false);
dispatchThread.interrupt();
dispatchThread.join();
- // We need to interrupt the commitThreads in case they are blocking on
pulling
- // from the commitQueue.
- commitThreads.forEach(Thread::interrupt);
- for (Thread commitThread : commitThreads) {
- commitThread.join();
- }
+
+ workCommitter.stop();
memoryMonitor.stop();
memoryMonitorThread.join();
workUnitExecutor.shutdown();
@@ -1086,7 +1075,7 @@ public class StreamingDataflowWorker {
if (workItem.getSourceState().getOnlyFinalize()) {
outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true));
work.setState(State.COMMIT_QUEUED);
- commitQueue.put(Commit.create(outputBuilder.build(), computationState,
work));
+ workCommitter.commit(Commit.create(outputBuilder.build(),
computationState, work));
return;
}
@@ -1315,7 +1304,7 @@ public class StreamingDataflowWorker {
commitRequest = buildWorkItemTruncationRequest(key, workItem,
estimatedCommitSize);
}
- commitQueue.put(Commit.create(commitRequest, computationState, work));
+ workCommitter.commit(Commit.create(commitRequest, computationState,
work));
// Compute shuffle and state byte statistics these will be flushed
asynchronously.
long stateBytesWritten =
@@ -1444,163 +1433,21 @@ public class StreamingDataflowWorker {
return outputBuilder.build();
}
- private void commitLoop() {
- Map<ComputationState, Windmill.ComputationCommitWorkRequest.Builder>
computationRequestMap =
- new HashMap<>();
- while (running.get()) {
- computationRequestMap.clear();
- Windmill.CommitWorkRequest.Builder commitRequestBuilder =
- Windmill.CommitWorkRequest.newBuilder();
- long commitBytes = 0;
- // Block until we have a commit, then batch with additional commits.
- Commit commit = null;
- try {
- commit = commitQueue.take();
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- continue;
- }
- while (commit != null) {
- ComputationState computationState = commit.computationState();
- commit.work().setState(Work.State.COMMITTING);
- Windmill.ComputationCommitWorkRequest.Builder
computationRequestBuilder =
- computationRequestMap.get(computationState);
- if (computationRequestBuilder == null) {
- computationRequestBuilder =
commitRequestBuilder.addRequestsBuilder();
-
computationRequestBuilder.setComputationId(computationState.getComputationId());
- computationRequestMap.put(computationState,
computationRequestBuilder);
- }
- computationRequestBuilder.addRequests(commit.request());
- // Send the request if we've exceeded the bytes or there is no more
- // pending work. commitBytes is a long, so this cannot overflow.
- commitBytes += commit.getSize();
- if (commitBytes >= TARGET_COMMIT_BUNDLE_BYTES) {
- break;
- }
- commit = commitQueue.poll();
- }
- Windmill.CommitWorkRequest commitRequest = commitRequestBuilder.build();
- LOG.trace("Commit: {}", commitRequest);
- activeCommitBytes.addAndGet(commitBytes);
- windmillServer.commitWork(commitRequest);
- activeCommitBytes.addAndGet(-commitBytes);
- for (Map.Entry<ComputationState,
Windmill.ComputationCommitWorkRequest.Builder> entry :
- computationRequestMap.entrySet()) {
- ComputationState computationState = entry.getKey();
- for (Windmill.WorkItemCommitRequest workRequest :
entry.getValue().getRequestsList()) {
- computationState.completeWorkAndScheduleNextWorkForKey(
- ShardedKey.create(workRequest.getKey(),
workRequest.getShardingKey()),
- WorkId.builder()
- .setCacheToken(workRequest.getCacheToken())
- .setWorkToken(workRequest.getWorkToken())
- .build());
- }
- }
- }
- }
-
- // Adds the commit to the commitStream if it fits, returning true iff it is
consumed.
- private boolean addCommitToStream(Commit commit, CommitWorkStream
commitStream) {
- Preconditions.checkNotNull(commit);
- final ComputationState state = commit.computationState();
- final Windmill.WorkItemCommitRequest request = commit.request();
- // Drop commits for failed work. Such commits will be dropped by Windmill
anyway.
- if (commit.work().isFailed()) {
+ private void onCompleteCommit(CompleteCommit completeCommit) {
+ if (completeCommit.status() != Windmill.CommitStatus.OK) {
readerCache.invalidateReader(
WindmillComputationKey.create(
- state.getComputationId(), request.getKey(),
request.getShardingKey()));
+ completeCommit.computationId(), completeCommit.shardedKey()));
stateCache
- .forComputation(state.getComputationId())
- .invalidate(request.getKey(), request.getShardingKey());
- state.completeWorkAndScheduleNextWorkForKey(
- ShardedKey.create(request.getKey(), request.getShardingKey()),
- WorkId.builder()
- .setWorkToken(request.getWorkToken())
- .setCacheToken(request.getCacheToken())
- .build());
- return true;
- }
-
- final int size = commit.getSize();
- commit.work().setState(Work.State.COMMITTING);
- activeCommitBytes.addAndGet(size);
- if (commitStream.commitWorkItem(
- state.getComputationId(),
- request,
- (Windmill.CommitStatus status) -> {
- if (status != Windmill.CommitStatus.OK) {
- readerCache.invalidateReader(
- WindmillComputationKey.create(
- state.getComputationId(), request.getKey(),
request.getShardingKey()));
- stateCache
- .forComputation(state.getComputationId())
- .invalidate(request.getKey(), request.getShardingKey());
- }
- activeCommitBytes.addAndGet(-size);
- state.completeWorkAndScheduleNextWorkForKey(
- ShardedKey.create(request.getKey(), request.getShardingKey()),
- WorkId.builder()
- .setCacheToken(request.getCacheToken())
- .setWorkToken(request.getWorkToken())
- .build());
- })) {
- return true;
- } else {
- // Back out the stats changes since the commit wasn't consumed.
- commit.work().setState(Work.State.COMMIT_QUEUED);
- activeCommitBytes.addAndGet(-size);
- return false;
+ .forComputation(completeCommit.computationId())
+ .invalidate(completeCommit.shardedKey());
}
- }
- // Helper to batch additional commits into the commit stream as long as they
fit.
- // Returns a commit that was removed from the queue but not consumed or null.
- private Commit batchCommitsToStream(CommitWorkStream commitStream) {
- int commits = 1;
- while (running.get()) {
- Commit commit;
- try {
- if (commits < 5) {
- commit = commitQueue.poll(10 - 2L * commits, TimeUnit.MILLISECONDS);
- } else {
- commit = commitQueue.poll();
- }
- } catch (InterruptedException e) {
- // Continue processing until !running.get()
- continue;
- }
- if (commit == null || !addCommitToStream(commit, commitStream)) {
- return commit;
- }
- commits++;
- }
- return null;
- }
-
- private void streamingCommitLoop() {
- WindmillStreamPool<CommitWorkStream> streamPool =
- WindmillStreamPool.create(
- NUM_COMMIT_STREAMS, COMMIT_STREAM_TIMEOUT,
windmillServer::commitWorkStream);
- Commit initialCommit = null;
- while (running.get()) {
- if (initialCommit == null) {
- try {
- initialCommit = commitQueue.take();
- } catch (InterruptedException e) {
- continue;
- }
- }
- // We initialize the commit stream only after we have a commit to make
sure it is fresh.
- CommitWorkStream commitStream = streamPool.getStream();
- if (!addCommitToStream(initialCommit, commitStream)) {
- throw new AssertionError("Initial commit on flushed stream should
always be accepted.");
- }
- // Batch additional commits to the stream and possibly make an
un-batched commit the next
- // initial commit.
- initialCommit = batchCommitsToStream(commitStream);
- commitStream.flush();
- streamPool.releaseStream(commitStream);
- }
+ Optional.ofNullable(computationMap.get(completeCommit.computationId()))
+ .ifPresent(
+ state ->
+ state.completeWorkAndScheduleNextWorkForKey(
+ completeCommit.shardedKey(), completeCommit.workId()));
}
private Windmill.GetWorkResponse getWork() {
@@ -2094,7 +1941,7 @@ public class StreamingDataflowWorker {
writer.println(workUnitExecutor.summaryHtml());
writer.print("Active commit: ");
- appendHumanizedBytes(activeCommitBytes.get(), writer);
+ appendHumanizedBytes(workCommitter.currentActiveCommitBytes(), writer);
writer.println("<br>");
metricTrackingWindmillServer.printHtml(writer);
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillComputationKey.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillComputationKey.java
index a01b1d297c2..274fa3aff02 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillComputationKey.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillComputationKey.java
@@ -18,6 +18,7 @@
package org.apache.beam.runners.dataflow.worker;
import com.google.auto.value.AutoValue;
+import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.TextFormat;
@@ -29,6 +30,10 @@ public abstract class WindmillComputationKey {
return new AutoValue_WindmillComputationKey(computationId, key,
shardingKey);
}
+ public static WindmillComputationKey create(String computationId, ShardedKey
shardedKey) {
+ return create(computationId, shardedKey.key(), shardedKey.shardingKey());
+ }
+
public abstract String computationId();
public abstract ByteString key();
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillComputationKey.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/CloseableStream.java
similarity index 55%
copy from
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillComputationKey.java
copy to
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/CloseableStream.java
index a01b1d297c2..e76cc365965 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillComputationKey.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/CloseableStream.java
@@ -15,29 +15,29 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.beam.runners.dataflow.worker;
+package org.apache.beam.runners.dataflow.worker.windmill.client;
import com.google.auto.value.AutoValue;
-import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
-import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.TextFormat;
+import org.apache.beam.sdk.annotations.Internal;
+/**
+ * Wrapper for a {@link WindmillStream} that lets callers run an action
after the stream is
+ * finished being used. The closing action may be a no-op.
+ */
+@Internal
@AutoValue
-public abstract class WindmillComputationKey {
-
- public static WindmillComputationKey create(
- String computationId, ByteString key, long shardingKey) {
- return new AutoValue_WindmillComputationKey(computationId, key,
shardingKey);
+public abstract class CloseableStream<StreamT extends WindmillStream>
implements AutoCloseable {
+ public static <StreamT extends WindmillStream> CloseableStream<StreamT>
create(
+ StreamT stream, Runnable onClose) {
+ return new AutoValue_CloseableStream<>(stream, onClose);
}
- public abstract String computationId();
-
- public abstract ByteString key();
+ public abstract StreamT stream();
- public abstract long shardingKey();
+ abstract Runnable onClose();
@Override
- public final String toString() {
- return String.format(
- "%s: %s-%d", computationId(), TextFormat.escapeBytes(key()),
shardingKey());
+ public void close() throws Exception {
+ onClose().run();
}
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java
index 9f1b67edc1e..0e4e085c066 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java
@@ -25,6 +25,7 @@ import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Supplier;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.sdk.annotations.Internal;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;
@@ -36,6 +37,7 @@ import org.joda.time.Instant;
* <p>The pool holds a fixed total number of streams, and keeps each stream
open for a specified
* time to allow for better load-balancing.
*/
+@Internal
@ThreadSafe
public class WindmillStreamPool<StreamT extends WindmillStream> {
@@ -131,6 +133,11 @@ public class WindmillStreamPool<StreamT extends
WindmillStream> {
}
}
+ public CloseableStream<StreamT> getCloseableStream() {
+ StreamT stream = getStream();
+ return CloseableStream.create(stream, () -> releaseStream(stream));
+ }
+
private synchronized WindmillStreamPool.StreamData<StreamT>
createAndCacheStream(int cacheKey) {
WindmillStreamPool.StreamData<StreamT> newStreamData =
new WindmillStreamPool.StreamData<>(streamSupplier.get());
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Commit.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java
similarity index 81%
rename from
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Commit.java
rename to
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java
index 94689796756..b840d22a343 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Commit.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java
@@ -15,13 +15,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.beam.runners.dataflow.worker.streaming;
+package org.apache.beam.runners.dataflow.worker.windmill.client.commits;
import com.google.auto.value.AutoValue;
+import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
+import org.apache.beam.sdk.annotations.Internal;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
/** Value class for a queued commit. */
+@Internal
@AutoValue
public abstract class Commit {
@@ -31,6 +35,10 @@ public abstract class Commit {
return new AutoValue_Commit(request, computationState, work);
}
+ public final String computationId() {
+ return computationState().getComputationId();
+ }
+
public abstract WorkItemCommitRequest request();
public abstract ComputationState computationState();
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java
new file mode 100644
index 00000000000..64fec71b000
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.commits;
+
+import com.google.auto.value.AutoValue;
+import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey;
+import org.apache.beam.runners.dataflow.worker.streaming.WorkId;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+
+/**
+ * A {@link Commit} is marked as complete once an attempt has been made to
commit it back to
+ * Streaming Engine/Appliance via {@link
+ *
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub#commitWorkStream(StreamObserver)}
+ * for Streaming Engine or {@link
+ *
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub#commitWork(Windmill.CommitWorkRequest,
+ * StreamObserver)} for Streaming Appliance.
+ */
+@Internal
+@AutoValue
+public abstract class CompleteCommit {
+
+ public static CompleteCommit create(Commit commit, CommitStatus
commitStatus) {
+ return new AutoValue_CompleteCommit(
+ commit.computationId(),
+ ShardedKey.create(commit.request().getKey(),
commit.request().getShardingKey()),
+ WorkId.builder()
+ .setWorkToken(commit.request().getWorkToken())
+ .setCacheToken(commit.request().getCacheToken())
+ .build(),
+ commitStatus);
+ }
+
+ public static CompleteCommit create(
+ String computationId, ShardedKey shardedKey, WorkId workId, CommitStatus
status) {
+ return new AutoValue_CompleteCommit(computationId, shardedKey, workId,
status);
+ }
+
+ public static CompleteCommit forFailedWork(Commit commit) {
+ return create(commit, CommitStatus.ABORTED);
+ }
+
+ public abstract String computationId();
+
+ public abstract ShardedKey shardedKey();
+
+ public abstract WorkId workId();
+
+ public abstract CommitStatus status();
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java
new file mode 100644
index 00000000000..344f04cfd00
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.commits;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Consumer;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
+import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey;
+import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
+import org.apache.beam.runners.dataflow.worker.streaming.WorkId;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkRequest;
+import org.apache.beam.sdk.annotations.Internal;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Streaming appliance implementation of {@link WorkCommitter}. */
+@Internal
+@ThreadSafe
+public final class StreamingApplianceWorkCommitter implements WorkCommitter {
+ private static final Logger LOG =
LoggerFactory.getLogger(StreamingApplianceWorkCommitter.class);
+ private static final long TARGET_COMMIT_BUNDLE_BYTES = 32 << 20;
+ private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB
+
+ private final Consumer<CommitWorkRequest> commitWorkFn;
+ private final WeightedBoundedQueue<Commit> commitQueue;
+ private final ExecutorService commitWorkers;
+ private final AtomicLong activeCommitBytes;
+ private final Consumer<CompleteCommit> onCommitComplete;
+
+ private StreamingApplianceWorkCommitter(
+ Consumer<CommitWorkRequest> commitWorkFn, Consumer<CompleteCommit>
onCommitComplete) {
+ this.commitWorkFn = commitWorkFn;
+ this.commitQueue =
+ WeightedBoundedQueue.create(
+ MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES,
commit.getSize()));
+ this.commitWorkers =
+ Executors.newSingleThreadScheduledExecutor(
+ new ThreadFactoryBuilder()
+ .setDaemon(true)
+ .setPriority(Thread.MAX_PRIORITY)
+ .setNameFormat("CommitThread-%d")
+ .build());
+ this.activeCommitBytes = new AtomicLong();
+ this.onCommitComplete = onCommitComplete;
+ }
+
+ public static StreamingApplianceWorkCommitter create(
+ Consumer<CommitWorkRequest> commitWork, Consumer<CompleteCommit>
onCommitComplete) {
+ return new StreamingApplianceWorkCommitter(commitWork, onCommitComplete);
+ }
+
+ @Override
+ @SuppressWarnings("FutureReturnValueIgnored")
+ public void start() {
+ if (!commitWorkers.isShutdown()) {
+ commitWorkers.submit(this::commitLoop);
+ }
+ }
+
+ @Override
+ public void commit(Commit commit) {
+ commitQueue.put(commit);
+ }
+
+ @Override
+ public long currentActiveCommitBytes() {
+ return activeCommitBytes.get();
+ }
+
+ @Override
+ public void stop() {
+ commitWorkers.shutdownNow();
+ }
+
+ @Override
+ public int parallelism() {
+ return 1;
+ }
+
+ private void commitLoop() {
+ Map<ComputationState, Windmill.ComputationCommitWorkRequest.Builder>
computationRequestMap =
+ new HashMap<>();
+ while (true) {
+ computationRequestMap.clear();
+ CommitWorkRequest.Builder commitRequestBuilder =
CommitWorkRequest.newBuilder();
+ long commitBytes = 0;
+ // Block until we have a commit, then batch with additional commits.
+ Commit commit;
+ try {
+ commit = commitQueue.take();
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ continue;
+ }
+ while (commit != null) {
+ ComputationState computationState = commit.computationState();
+ commit.work().setState(Work.State.COMMITTING);
+ Windmill.ComputationCommitWorkRequest.Builder
computationRequestBuilder =
+ computationRequestMap.get(computationState);
+ if (computationRequestBuilder == null) {
+ computationRequestBuilder =
commitRequestBuilder.addRequestsBuilder();
+
computationRequestBuilder.setComputationId(computationState.getComputationId());
+ computationRequestMap.put(computationState,
computationRequestBuilder);
+ }
+ computationRequestBuilder.addRequests(commit.request());
+ // Send the request once we've reached the target bundle size or there is no more
+ // pending work. commitBytes is a long, so this cannot overflow.
+ commitBytes += commit.getSize();
+ if (commitBytes >= TARGET_COMMIT_BUNDLE_BYTES) {
+ break;
+ }
+ commit = commitQueue.poll();
+ }
+ commitWork(commitRequestBuilder.build(), commitBytes);
+ completeWork(computationRequestMap);
+ }
+ }
+
+ private void commitWork(CommitWorkRequest commitRequest, long commitBytes) {
+ LOG.trace("Commit: {}", commitRequest);
+ activeCommitBytes.addAndGet(commitBytes);
+ commitWorkFn.accept(commitRequest);
+ activeCommitBytes.addAndGet(-commitBytes);
+ }
+
+ private void completeWork(
+ Map<ComputationState, Windmill.ComputationCommitWorkRequest.Builder>
committedWork) {
+ for (Map.Entry<ComputationState,
Windmill.ComputationCommitWorkRequest.Builder> entry :
+ committedWork.entrySet()) {
+ for (Windmill.WorkItemCommitRequest workRequest :
entry.getValue().getRequestsList()) {
+ // Appliance errors are propagated as an exception for the entire batch, so report OK here.
+ onCommitComplete.accept(
+ CompleteCommit.create(
+ entry.getKey().getComputationId(),
+ ShardedKey.create(workRequest.getKey(),
workRequest.getShardingKey()),
+ WorkId.builder()
+ .setCacheToken(workRequest.getCacheToken())
+ .setWorkToken(workRequest.getWorkToken())
+ .build(),
+ Windmill.CommitStatus.OK));
+ }
+ }
+ }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
new file mode 100644
index 00000000000..f6088acf011
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.commits;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
+import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
+import org.apache.beam.sdk.annotations.Internal;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Streaming engine implementation of {@link WorkCommitter}. Commits work back
to Streaming Engine
+ * backend.
+ */
+@Internal
+@ThreadSafe
+public final class StreamingEngineWorkCommitter implements WorkCommitter {
+ private static final Logger LOG =
LoggerFactory.getLogger(StreamingEngineWorkCommitter.class);
+ private static final int TARGET_COMMIT_BATCH_KEYS = 5;
+ private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB
+
+ private final Supplier<CloseableStream<CommitWorkStream>>
commitWorkStreamFactory;
+ private final WeightedBoundedQueue<Commit> commitQueue;
+ private final ExecutorService commitSenders;
+ private final AtomicLong activeCommitBytes;
+ private final Consumer<CompleteCommit> onCommitComplete;
+ private final int numCommitSenders;
+
+ /**
+  * Builds the committer: a weight-bounded commit queue plus a fixed pool of daemon, max-priority
+  * sender threads. No sender task is submitted here; that happens in {@link #start()}.
+  */
+ private StreamingEngineWorkCommitter(
+     Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory,
+     int numCommitSenders,
+     Consumer<CompleteCommit> onCommitComplete) {
+   this.commitWorkStreamFactory = commitWorkStreamFactory;
+   this.numCommitSenders = numCommitSenders;
+   this.onCommitComplete = onCommitComplete;
+   this.activeCommitBytes = new AtomicLong();
+   // Each queued commit is weighed by its size, capped at the queue limit itself
+   // (presumably so that a single oversized commit can still be enqueued).
+   this.commitQueue =
+       WeightedBoundedQueue.create(
+           MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize()));
+   this.commitSenders =
+       Executors.newFixedThreadPool(
+           numCommitSenders,
+           new ThreadFactoryBuilder()
+               .setDaemon(true)
+               .setPriority(Thread.MAX_PRIORITY)
+               .setNameFormat("CommitThread-%d")
+               .build());
+ }
+
+ /**
+  * Creates a {@link StreamingEngineWorkCommitter}. The returned committer does not send any
+  * commits until {@link #start()} is called.
+  *
+  * @param commitWorkStreamFactory supplies the stream used for each batch of commits
+  * @param numCommitSenders number of sender threads
+  * @param onCommitComplete callback invoked with the outcome of every commit
+  */
+ public static StreamingEngineWorkCommitter create(
+     Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory,
+     int numCommitSenders,
+     Consumer<CompleteCommit> onCommitComplete) {
+   StreamingEngineWorkCommitter committer =
+       new StreamingEngineWorkCommitter(
+           commitWorkStreamFactory, numCommitSenders, onCommitComplete);
+   return committer;
+ }
+
+ /** Submits one commit loop per sender thread, unless the executor has already been shut down. */
+ @Override
+ @SuppressWarnings("FutureReturnValueIgnored")
+ public void start() {
+   if (commitSenders.isShutdown()) {
+     return;
+   }
+   for (int sender = 0; sender < numCommitSenders; sender++) {
+     commitSenders.submit(this::streamingCommitLoop);
+   }
+ }
+
+ /** Places {@code commit} on the commit queue for the sender threads to pick up. */
+ @Override
+ public void commit(Commit commit) {
+   commitQueue.put(commit);
+ }
+
+ /** Returns the current value of the active-commit-bytes counter. */
+ @Override
+ public long currentActiveCommitBytes() {
+   return activeCommitBytes.get();
+ }
+
+ @Override
+ public void stop() {
+ if (!commitSenders.isTerminated() || !commitSenders.isShutdown()) {
+ commitSenders.shutdown();
+ try {
+ commitSenders.awaitTermination(10, TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ LOG.warn("Could not shut down commitSenders gracefully, forcing
shutdown.", e);
+ }
+ commitSenders.shutdownNow();
+ }
+ drainCommitQueue();
+ }
+
+ /** Removes every commit still queued and fails each of them. */
+ private void drainCommitQueue() {
+   for (Commit pending = commitQueue.poll(); pending != null; pending = commitQueue.poll()) {
+     failCommit(pending);
+   }
+ }
+
+ /** Marks the commit's work as failed and reports the commit as a failed-work completion. */
+ private void failCommit(Commit commit) {
+   commit.work().setFailed();
+   CompleteCommit failedCommit = CompleteCommit.forFailedWork(commit);
+   onCommitComplete.accept(failedCommit);
+ }
+
+ /** Returns the number of sender threads this committer was configured with. */
+ @Override
+ public int parallelism() {
+   return numCommitSenders;
+ }
+
+ /**
+  * Body of each commit sender thread. Repeatedly takes a commit from the queue, obtains a commit
+  * stream, adds that commit plus any further commits that fit, and flushes the stream.
+  *
+  * <p>Commits whose work has already failed are reported as failed without being sent. The loop
+  * exits when the thread is interrupted after the executor has been shut down; a commit that was
+  * taken from the queue but not yet handed to a stream at that point is failed.
+  */
+ private void streamingCommitLoop() {
+   @Nullable Commit initialCommit = null;
+   try {
+     while (true) {
+       if (initialCommit == null) {
+         try {
+           // Block until we have a commit or are shutting down.
+           initialCommit = commitQueue.take();
+         } catch (InterruptedException e) {
+           if (commitSenders.isShutdown()) {
+             // stop() interrupts the senders via shutdownNow(). Exit here instead of calling
+             // take() again, which would park this thread forever.
+             Thread.currentThread().interrupt();
+             return;
+           }
+           continue;
+         }
+       }
+
+       if (initialCommit.work().isFailed()) {
+         onCommitComplete.accept(CompleteCommit.forFailedWork(initialCommit));
+         initialCommit = null;
+         continue;
+       }
+
+       try (CloseableStream<CommitWorkStream> closeableCommitStream =
+           commitWorkStreamFactory.get()) {
+         CommitWorkStream commitStream = closeableCommitStream.stream();
+         if (!tryAddToCommitStream(initialCommit, commitStream)) {
+           throw new AssertionError("Initial commit on flushed stream should always be accepted.");
+         }
+         // Batch additional commits to the stream and possibly make an un-batched commit the next
+         // initial commit.
+         initialCommit = batchCommitsToStream(commitStream);
+         commitStream.flush();
+       } catch (Exception e) {
+         LOG.error("Error occurred fetching a CommitWorkStream.", e);
+       }
+     }
+   } finally {
+     // Do not silently drop a commit that was dequeued but never handed to a stream.
+     if (initialCommit != null) {
+       failCommit(initialCommit);
+     }
+   }
+ }
+
+ /**
+  * Offers {@code commit} to {@code commitStream}.
+  *
+  * @return true when the stream consumed the commit, false when it did not fit
+  */
+ private boolean tryAddToCommitStream(Commit commit, CommitWorkStream commitStream) {
+   Preconditions.checkNotNull(commit);
+   // Optimistically mark the commit as in flight; both changes are undone below on rejection.
+   commit.work().setState(Work.State.COMMITTING);
+   activeCommitBytes.addAndGet(commit.getSize());
+   boolean accepted =
+       commitStream.commitWorkItem(
+           commit.computationId(),
+           commit.request(),
+           status -> {
+             onCommitComplete.accept(CompleteCommit.create(commit, status));
+             activeCommitBytes.addAndGet(-commit.getSize());
+           });
+   if (accepted) {
+     return true;
+   }
+
+   // The stream rejected the commit: restore the queued state and release the counted bytes.
+   commit.work().setState(Work.State.COMMIT_QUEUED);
+   activeCommitBytes.addAndGet(-commit.getSize());
+   return false;
+ }
+
+ /**
+  * Batches additional queued commits into {@code commitStream} for as long as they fit.
+  *
+  * <p>Until {@code TARGET_COMMIT_BATCH_KEYS} commits have been batched, this waits briefly for
+  * more commits to arrive (the wait shrinks as the batch grows); after that only commits that
+  * are already queued are taken. Commits for failed work are reported as failed and skipped.
+  *
+  * <p>If the thread is interrupted while waiting, batching stops and the interrupt status is
+  * restored so the caller can observe it.
+  *
+  * @return a commit that was removed from the queue but not consumed by the stream, or null
+  */
+ @Nullable
+ private Commit batchCommitsToStream(CommitWorkStream commitStream) {
+   int commits = 1;
+   while (true) {
+     Commit commit;
+     try {
+       if (commits < TARGET_COMMIT_BATCH_KEYS) {
+         commit = commitQueue.poll(10 - 2L * commits, TimeUnit.MILLISECONDS);
+       } else {
+         commit = commitQueue.poll();
+       }
+     } catch (InterruptedException e) {
+       // Swallowing the interrupt would hide a shutdown request from the commit loop; keep the
+       // flag set and let the caller flush what has been batched so far.
+       Thread.currentThread().interrupt();
+       return null;
+     }
+
+     if (commit == null) {
+       return null;
+     }
+
+     // Drop commits for failed work. Such commits will be dropped by Windmill anyway.
+     if (commit.work().isFailed()) {
+       onCommitComplete.accept(CompleteCommit.forFailedWork(commit));
+       continue;
+     }
+
+     if (!tryAddToCommitStream(commit, commitStream)) {
+       return commit;
+     }
+     commits++;
+   }
+ }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/WorkCommitter.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/WorkCommitter.java
new file mode 100644
index 00000000000..11a4c00db9d
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/WorkCommitter.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.commits;
+
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.sdk.annotations.Internal;
+
+/**
+ * Commits {@link org.apache.beam.runners.dataflow.worker.streaming.Work} that
has completed user
+ * processing back to persistence layer.
+ */
+@Internal
+@ThreadSafe
+public interface WorkCommitter {
+
+ /** Starts internal processing of commits. */
+ void start();
+
+ /**
+ * Adds a commit to the {@link WorkCommitter}. This may block the calling thread depending on
+ * the underlying implementation, and persisting to the persistence layer may be done
+ * asynchronously.
+ */
+ void commit(Commit commit);
+
+ /** Number of bytes currently trying to be committed to the backing
persistence layer. */
+ long currentActiveCommitBytes();
+
+ /**
+ * Stops internal processing of commits. In progress and subsequent commits
may be canceled or
+ * dropped.
+ */
+ void stop();
+
+ /**
+ * Number of internal workers {@link WorkCommitter} uses to commit work to
the backing persistence
+ * layer.
+ */
+ int parallelism();
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
index 0d4e7c6b645..85c74fe8591 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
@@ -34,6 +34,7 @@ import org.apache.beam.runners.dataflow.worker.Weighers;
import org.apache.beam.runners.dataflow.worker.WindmillComputationKey;
import org.apache.beam.runners.dataflow.worker.status.BaseStatusServlet;
import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
+import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey;
import org.apache.beam.sdk.state.State;
import org.apache.beam.sdk.util.Weighted;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
@@ -318,6 +319,10 @@ public class WindmillStateCache implements
StatusDataProvider {
keyIndex.remove(key);
}
+ public final void invalidate(ShardedKey shardedKey) {
+ invalidate(shardedKey.key(), shardedKey.shardingKey());
+ }
+
/**
* Returns a per-computation, per-key view of the state cache. Access to
the cached data for
* this key is not thread-safe. Callers should ensure that there is only a
single ForKey object
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
index e4985193d1c..89939d5d341 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
@@ -29,6 +29,7 @@ import static org.junit.Assert.assertFalse;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -45,6 +46,7 @@ import java.util.function.Function;
import javax.annotation.concurrent.GuardedBy;
import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
import
org.apache.beam.runners.dataflow.worker.streaming.WorkHeartbeatResponseProcessor;
+import org.apache.beam.runners.dataflow.worker.streaming.WorkId;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkResponse;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationCommitWorkRequest;
@@ -74,11 +76,12 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** An in-memory Windmill server that offers provided work and data. */
-public class FakeWindmillServer extends WindmillServerStub {
+public final class FakeWindmillServer extends WindmillServerStub {
private static final Logger LOG =
LoggerFactory.getLogger(FakeWindmillServer.class);
private final ResponseQueue<Windmill.GetWorkRequest,
Windmill.GetWorkResponse> workToOffer;
private final ResponseQueue<GetDataRequest, GetDataResponse> dataToOffer;
private final ResponseQueue<Windmill.CommitWorkRequest, CommitWorkResponse>
commitsToOffer;
+ private final Map<WorkId, Windmill.CommitStatus> streamingCommitsToOffer;
// Keys are work tokens.
private final Map<Long, WorkItemCommitRequest> commitsReceived;
private final ArrayList<Windmill.ReportStatsRequest> statsReceived;
@@ -109,6 +112,7 @@ public class FakeWindmillServer extends WindmillServerStub {
commitsToOffer =
new ResponseQueue<Windmill.CommitWorkRequest, CommitWorkResponse>()
.returnByDefault(CommitWorkResponse.getDefaultInstance());
+ streamingCommitsToOffer = new HashMap<>();
commitsReceived = new ConcurrentHashMap<>();
exceptions = new LinkedBlockingQueue<>();
expectedExceptionCount = new AtomicInteger();
@@ -139,6 +143,10 @@ public class FakeWindmillServer extends WindmillServerStub
{
return commitsToOffer;
}
+ public Map<WorkId, Windmill.CommitStatus> whenCommitWorkStreamCalled() {
+ return streamingCommitsToOffer;
+ }
+
@Override
public Windmill.GetWorkResponse getWork(Windmill.GetWorkRequest request) {
LOG.debug("getWorkRequest: {}", request.toString());
@@ -376,7 +384,15 @@ public class FakeWindmillServer extends WindmillServerStub
{
droppedStreamingCommits.put(request.getWorkToken(), onDone);
} else {
commitsReceived.put(request.getWorkToken(), request);
- onDone.accept(Windmill.CommitStatus.OK);
+ onDone.accept(
+ Optional.ofNullable(
+ streamingCommitsToOffer.remove(
+ WorkId.builder()
+ .setWorkToken(request.getWorkToken())
+ .setCacheToken(request.getCacheToken())
+ .build()))
+ // Default to CommitStatus.OK
+ .orElse(Windmill.CommitStatus.OK));
}
// Return true to indicate the request was accepted even if we are
dropping the commit
// to simulate a dropped commit.
@@ -502,32 +518,32 @@ public class FakeWindmillServer extends
WindmillServerStub {
this.isReady = ready;
}
- static class ResponseQueue<T, U> {
+ public static class ResponseQueue<T, U> {
private final Queue<Function<T, U>> responses = new
ConcurrentLinkedQueue<>();
Duration sleep = Duration.ZERO;
private Function<T, U> defaultResponse;
// (Fluent) interface for response producers, accessible from tests.
- ResponseQueue<T, U> thenAnswer(Function<T, U> mapFun) {
+ public ResponseQueue<T, U> thenAnswer(Function<T, U> mapFun) {
responses.add(mapFun);
return this;
}
- ResponseQueue<T, U> thenReturn(U response) {
+ public ResponseQueue<T, U> thenReturn(U response) {
return thenAnswer((request) -> response);
}
- ResponseQueue<T, U> answerByDefault(Function<T, U> mapFun) {
+ public ResponseQueue<T, U> answerByDefault(Function<T, U> mapFun) {
defaultResponse = mapFun;
return this;
}
- ResponseQueue<T, U> returnByDefault(U response) {
+ public ResponseQueue<T, U> returnByDefault(U response) {
return answerByDefault((request) -> response);
}
- ResponseQueue<T, U> delayEachResponseBy(Duration sleep) {
+ public ResponseQueue<T, U> delayEachResponseBy(Duration sleep) {
this.sleep = sleep;
return this;
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index d00ea64d7d4..d8ead447e8e 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -3894,7 +3894,7 @@ public class StreamingDataflowWorkerTest {
options.setWindmillServiceCommitThreads(configNumCommitThreads);
StreamingDataflowWorker worker = makeWorker(instructions, options, true /*
publishCounters */);
worker.start();
- assertEquals(expectedNumCommitThreads, worker.commitThreads.size());
+ assertEquals(expectedNumCommitThreads, worker.numCommitThreads());
worker.stop();
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java
new file mode 100644
index 00000000000..cfad6138547
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.commits;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertNotNull;
+
+import com.google.api.services.dataflow.model.MapTask;
+import com.google.common.truth.Correspondence;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Consumer;
+import org.apache.beam.runners.dataflow.worker.FakeWindmillServer;
+import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
+import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
+import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.joda.time.Instant;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ErrorCollector;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mockito;
+
+@RunWith(JUnit4.class)
+public class StreamingApplianceWorkCommitterTest {
+ @Rule public ErrorCollector errorCollector = new ErrorCollector();
+ private FakeWindmillServer fakeWindmillServer;
+ private StreamingApplianceWorkCommitter workCommitter;
+
+ /** Creates a {@link Work} whose work, sharding and cache tokens all equal {@code workToken}. */
+ private static Work createMockWork(long workToken, Consumer<Work> processWorkFn) {
+   Windmill.WorkItem workItem =
+       Windmill.WorkItem.newBuilder()
+           .setKey(ByteString.EMPTY)
+           .setWorkToken(workToken)
+           .setShardingKey(workToken)
+           .setCacheToken(workToken)
+           .build();
+   return Work.create(workItem, Instant::now, Collections.emptyList(), processWorkFn);
+ }
+
+ /** Creates a {@link ComputationState} for {@code computationId} backed by a mocked executor. */
+ private static ComputationState createComputationState(String computationId) {
+   MapTask mapTask = new MapTask().setSystemName("system").setStageName("stage");
+   return new ComputationState(
+       computationId, mapTask, Mockito.mock(BoundedQueueExecutor.class), ImmutableMap.of(), null);
+ }
+
+ /** Creates a committer that sends its commit requests to the fake Windmill server. */
+ private StreamingApplianceWorkCommitter createWorkCommitter(
+     Consumer<CompleteCommit> onCommitComplete) {
+   StreamingApplianceWorkCommitter committer =
+       StreamingApplianceWorkCommitter.create(fakeWindmillServer::commitWork, onCommitComplete);
+   return committer;
+ }
+
+ /** Creates a fresh fake Windmill server that resolves every computation to a mock state. */
+ @Before
+ public void setUp() {
+   fakeWindmillServer =
+       new FakeWindmillServer(
+           errorCollector, ignored -> Optional.of(Mockito.mock(ComputationState.class)));
+ }
+
+ /** Stops the committer created by the test, if one was created. */
+ @After
+ public void cleanUp() {
+   // A test can fail before assigning workCommitter; do not mask that failure with an NPE here.
+   if (workCommitter != null) {
+     workCommitter.stop();
+   }
+ }
+
+ /** Verifies that committed work reaches the fake server and is reported complete with OK. */
+ @Test
+ public void testCommit() {
+   // The completion callback presumably runs on the committer's own thread while this thread
+   // reads the list, so the collection must be safe for cross-thread access.
+   List<CompleteCommit> completeCommits = Collections.synchronizedList(new ArrayList<>());
+   workCommitter = createWorkCommitter(completeCommits::add);
+   List<Commit> commits = new ArrayList<>();
+   for (int i = 1; i <= 5; i++) {
+     Work work = createMockWork(i, ignored -> {});
+     Windmill.WorkItemCommitRequest commitRequest =
+         Windmill.WorkItemCommitRequest.newBuilder()
+             .setKey(work.getWorkItem().getKey())
+             .setShardingKey(work.getWorkItem().getShardingKey())
+             .setWorkToken(work.getWorkItem().getWorkToken())
+             .setCacheToken(work.getWorkItem().getCacheToken())
+             .build();
+     commits.add(Commit.create(commitRequest, createComputationState("computationId-" + i), work));
+   }
+
+   workCommitter.start();
+   commits.forEach(workCommitter::commit);
+
+   Map<Long, Windmill.WorkItemCommitRequest> committed =
+       fakeWindmillServer.waitForAndGetCommits(commits.size());
+
+   for (Commit commit : commits) {
+     Windmill.WorkItemCommitRequest request =
+         committed.get(commit.work().getWorkItem().getWorkToken());
+     assertNotNull(request);
+     assertThat(request).isEqualTo(commit.request());
+   }
+
+   assertThat(completeCommits).hasSize(commits.size());
+   assertThat(completeCommits)
+       .comparingElementsUsing(
+           Correspondence.from(
+               (CompleteCommit completeCommit, Commit commit) ->
+                   completeCommit.computationId().equals(commit.computationId())
+                       && completeCommit.status() == Windmill.CommitStatus.OK
+                       && completeCommit.workId().equals(commit.work().id())
+                       && completeCommit
+                           .shardedKey()
+                           .equals(
+                               ShardedKey.create(
+                                   commit.request().getKey(), commit.request().getShardingKey())),
+               "expected to equal"))
+       .containsExactlyElementsIn(commits);
+ }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java
new file mode 100644
index 00000000000..1bf2e44f9f0
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.commits;
+
+import static com.google.common.truth.Truth.assertThat;
+import static
org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus.OK;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import com.google.api.services.dataflow.model.MapTask;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import org.apache.beam.runners.dataflow.worker.FakeWindmillServer;
+import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
+import org.apache.beam.runners.dataflow.worker.streaming.WorkId;
+import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ErrorCollector;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mockito;
+
+@RunWith(JUnit4.class)
+public class StreamingEngineWorkCommitterTest {
+
+ @Rule public ErrorCollector errorCollector = new ErrorCollector();
+ private StreamingEngineWorkCommitter workCommitter;
+ private FakeWindmillServer fakeWindmillServer;
+ private Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory;
+
+ /** Creates a {@link Work} whose work, sharding and cache tokens all equal {@code workToken}. */
+ private static Work createMockWork(long workToken, Consumer<Work> processWorkFn) {
+   Windmill.WorkItem workItem =
+       Windmill.WorkItem.newBuilder()
+           .setKey(ByteString.EMPTY)
+           .setWorkToken(workToken)
+           .setShardingKey(workToken)
+           .setCacheToken(workToken)
+           .build();
+   return Work.create(workItem, Instant::now, Collections.emptyList(), processWorkFn);
+ }
+
+ /** Creates a {@link ComputationState} for {@code computationId} backed by a mocked executor. */
+ private static ComputationState createComputationState(String computationId) {
+   MapTask mapTask = new MapTask().setSystemName("system").setStageName("stage");
+   return new ComputationState(
+       computationId, mapTask, Mockito.mock(BoundedQueueExecutor.class), ImmutableMap.of(), null);
+ }
+
+ /**
+  * Builds the {@link CompleteCommit} expected for {@code commit}: the failed-work form when its
+  * work is failed, otherwise a completion carrying {@code status}.
+  */
+ private static CompleteCommit asCompleteCommit(Commit commit, Windmill.CommitStatus status) {
+   return commit.work().isFailed()
+       ? CompleteCommit.forFailedWork(commit)
+       : CompleteCommit.create(commit, status);
+ }
+
+ /**
+  * Creates a fresh fake Windmill server and a stream factory backed by a single-stream pool over
+  * the fake server's commit work stream.
+  */
+ @Before
+ public void setUp() throws IOException {
+   fakeWindmillServer =
+       new FakeWindmillServer(
+           errorCollector, ignored -> Optional.of(Mockito.mock(ComputationState.class)));
+   commitWorkStreamFactory =
+       WindmillStreamPool.create(
+               1, Duration.standardMinutes(1), fakeWindmillServer::commitWorkStream)
+           ::getCloseableStream;
+ }
+
+ /** Stops the committer created by the test, if one was created. */
+ @After
+ public void cleanUp() {
+   // A test can fail before assigning workCommitter; do not mask that failure with an NPE here.
+   if (workCommitter != null) {
+     workCommitter.stop();
+   }
+ }
+
+ /** Creates a committer with a single sender thread that uses the test's stream factory. */
+ private StreamingEngineWorkCommitter createWorkCommitter(
+     Consumer<CompleteCommit> onCommitComplete) {
+   int numCommitSenders = 1;
+   return StreamingEngineWorkCommitter.create(
+       commitWorkStreamFactory, numCommitSenders, onCommitComplete);
+ }
+
+ /** Verifies that queued commits reach the fake server and complete with status OK. */
+ @Test
+ public void testCommit_sendsCommitsToStreamingEngine() {
+   // Completions are reported from committer/stream threads while this thread reads the set, so
+   // the collection must be safe for cross-thread access.
+   Set<CompleteCommit> completeCommits = Collections.synchronizedSet(new HashSet<>());
+   workCommitter = createWorkCommitter(completeCommits::add);
+   List<Commit> commits = new ArrayList<>();
+   for (int i = 1; i <= 5; i++) {
+     Work work = createMockWork(i, ignored -> {});
+     WorkItemCommitRequest commitRequest =
+         WorkItemCommitRequest.newBuilder()
+             .setKey(work.getWorkItem().getKey())
+             .setShardingKey(work.getWorkItem().getShardingKey())
+             .setWorkToken(work.getWorkItem().getWorkToken())
+             .setCacheToken(work.getWorkItem().getCacheToken())
+             .build();
+     commits.add(Commit.create(commitRequest, createComputationState("computationId-" + i), work));
+   }
+
+   workCommitter.start();
+   commits.parallelStream().forEach(workCommitter::commit);
+
+   Map<Long, WorkItemCommitRequest> committed =
+       fakeWindmillServer.waitForAndGetCommits(commits.size());
+
+   for (Commit commit : commits) {
+     WorkItemCommitRequest request = committed.get(commit.work().getWorkItem().getWorkToken());
+     assertNotNull(request);
+     assertThat(request).isEqualTo(commit.request());
+     assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK));
+   }
+ }
+
+ /**
+  * Verifies that commits for failed work are completed as failed without being sent, while the
+  * remaining commits are sent and complete with status OK.
+  */
+ @Test
+ public void testCommit_handlesFailedCommits() {
+   // Completions are reported from committer/stream threads while this thread reads the set, so
+   // the collection must be safe for cross-thread access.
+   Set<CompleteCommit> completeCommits = Collections.synchronizedSet(new HashSet<>());
+   workCommitter = createWorkCommitter(completeCommits::add);
+   List<Commit> commits = new ArrayList<>();
+   for (int i = 1; i <= 10; i++) {
+     Work work = createMockWork(i, ignored -> {});
+     // Fail half of the work.
+     if (i % 2 == 0) {
+       work.setFailed();
+     }
+     WorkItemCommitRequest commitRequest =
+         WorkItemCommitRequest.newBuilder()
+             .setKey(work.getWorkItem().getKey())
+             .setShardingKey(work.getWorkItem().getShardingKey())
+             .setWorkToken(work.getWorkItem().getWorkToken())
+             .setCacheToken(work.getWorkItem().getCacheToken())
+             .build();
+     commits.add(Commit.create(commitRequest, createComputationState("computationId-" + i), work));
+   }
+
+   workCommitter.start();
+   commits.parallelStream().forEach(workCommitter::commit);
+
+   Map<Long, WorkItemCommitRequest> committed =
+       fakeWindmillServer.waitForAndGetCommits(commits.size() / 2);
+
+   for (Commit commit : commits) {
+     if (commit.work().isFailed()) {
+       assertThat(completeCommits)
+           .contains(asCompleteCommit(commit, Windmill.CommitStatus.ABORTED));
+       assertThat(committed).doesNotContainKey(commit.work().getWorkItem().getWorkToken());
+     } else {
+       assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK));
+       assertThat(committed)
+           .containsEntry(commit.work().getWorkItem().getWorkToken(), commit.request());
+     }
+   }
+ }
+
+ /**
+  * Verifies that the commit status returned by the (fake) backend for each work item is the
+  * status reported in the corresponding {@link CompleteCommit}.
+  */
+ @Test
+ public void testCommit_handlesCompleteCommits_commitStatusNotOK() {
+   // Completions are reported from committer/stream threads while this thread reads the set, so
+   // the collection must be safe for cross-thread access.
+   Set<CompleteCommit> completeCommits = Collections.synchronizedSet(new HashSet<>());
+   workCommitter = createWorkCommitter(completeCommits::add);
+   Map<WorkId, Windmill.CommitStatus> expectedCommitStatus = new HashMap<>();
+   Random commitStatusSelector = new Random();
+   int commitStatusSelectorBound = Windmill.CommitStatus.values().length - 1;
+   // Compute the CommitStatus randomly, to test plumbing of different commitStatuses to
+   // StreamingEngine.
+   Function<Work, Windmill.CommitStatus> computeCommitStatusForTest =
+       work -> {
+         Windmill.CommitStatus commitStatus =
+             work.getWorkItem().getWorkToken() % 2 == 0
+                 ? Windmill.CommitStatus.values()[
+                     commitStatusSelector.nextInt(commitStatusSelectorBound)]
+                 : OK;
+         expectedCommitStatus.put(work.id(), commitStatus);
+         return commitStatus;
+       };
+
+   List<Commit> commits = new ArrayList<>();
+   for (int i = 1; i <= 10; i++) {
+     Work work = createMockWork(i, ignored -> {});
+     WorkItemCommitRequest commitRequest =
+         WorkItemCommitRequest.newBuilder()
+             .setKey(work.getWorkItem().getKey())
+             .setShardingKey(work.getWorkItem().getShardingKey())
+             .setWorkToken(work.getWorkItem().getWorkToken())
+             .setCacheToken(work.getWorkItem().getCacheToken())
+             .build();
+     commits.add(Commit.create(commitRequest, createComputationState("computationId-" + i), work));
+     fakeWindmillServer
+         .whenCommitWorkStreamCalled()
+         .put(work.id(), computeCommitStatusForTest.apply(work));
+   }
+
+   workCommitter.start();
+   commits.parallelStream().forEach(workCommitter::commit);
+
+   Map<Long, WorkItemCommitRequest> committed =
+       fakeWindmillServer.waitForAndGetCommits(commits.size());
+
+   for (Commit commit : commits) {
+     WorkItemCommitRequest request = committed.get(commit.work().getWorkItem().getWorkToken());
+     assertNotNull(request);
+     assertThat(request).isEqualTo(commit.request());
+     assertThat(completeCommits)
+         .contains(asCompleteCommit(commit, expectedCommitStatus.get(commit.work().id())));
+   }
+   assertThat(completeCommits.size()).isEqualTo(commits.size());
+ }
+
+  /**
+   * Verifies the outcome of calling {@code stop()} after commits were submitted through a stream
+   * that never accepts them: every submitted commit must be completed with
+   * {@link Windmill.CommitStatus#ABORTED} and its work must be marked as failed.
+   */
+  @Test
+  public void testStop_drainsCommitQueue() {
+    // Use this fake to queue up commits on the committer: commitWorkItem always returns false.
+    Supplier<CommitWorkStream> fakeCommitWorkStream =
+        () ->
+            new CommitWorkStream() {
+              @Override
+              public boolean commitWorkItem(
+                  String computation,
+                  WorkItemCommitRequest request,
+                  Consumer<Windmill.CommitStatus> onDone) {
+                return false;
+              }
+
+              @Override
+              public void flush() {}
+
+              @Override
+              public void close() {}
+
+              @Override
+              public boolean awaitTermination(int time, TimeUnit unit) {
+                return false;
+              }
+
+              @Override
+              public Instant startTime() {
+                return Instant.now();
+              }
+            };
+    commitWorkStreamFactory =
+        WindmillStreamPool.create(1, Duration.standardMinutes(1), fakeCommitWorkStream)
+            ::getCloseableStream;
+
+    Set<CompleteCommit> completeCommits = new HashSet<>();
+    workCommitter = createWorkCommitter(completeCommits::add);
+
+    List<Commit> commits = new ArrayList<>();
+    for (int i = 1; i <= 10; i++) {
+      Work work = createMockWork(i, ignored -> {});
+      WorkItemCommitRequest commitRequest =
+          WorkItemCommitRequest.newBuilder()
+              .setKey(work.getWorkItem().getKey())
+              .setShardingKey(work.getWorkItem().getShardingKey())
+              .setWorkToken(work.getWorkItem().getWorkToken())
+              .setCacheToken(work.getWorkItem().getCacheToken())
+              .build();
+      commits.add(Commit.create(commitRequest, createComputationState("computationId-" + i), work));
+    }
+
+    workCommitter.start();
+    commits.parallelStream().forEach(workCommitter::commit);
+    workCommitter.stop();
+
+    // The value under test (completeCommits) is the subject and the number of submitted commits
+    // is the expectation, so a failure message reports actual and expected the right way round.
+    assertThat(completeCommits).hasSize(commits.size());
+    for (CompleteCommit completeCommit : completeCommits) {
+      assertThat(completeCommit.status()).isEqualTo(Windmill.CommitStatus.ABORTED);
+    }
+
+    for (Commit commit : commits) {
+      assertTrue(commit.work().isFailed());
+    }
+  }
+}