This is an automated email from the ASF dual-hosted git repository.
scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 5e7edc45598 Heartbeats (#29963)
5e7edc45598 is described below
commit 5e7edc45598b6438761386856e96a66487704b69
Author: Andrew Crites <[email protected]>
AuthorDate: Tue Jan 23 01:28:20 2024 -0800
Heartbeats (#29963)
* Adds sending new HeartbeatRequest protos to StreamingDataflowWorker. If
any HeartbeatResponses are sent from Windmill containing failed work items,
aborts processing those work items as soon as possible.
* Adds sending new HeartbeatRequest protos when using streaming RPCs
(streaming engine). Also adds a test.
* Adds new test for custom source reader exiting early for failed work.
Adds special exception for handling failed work.
* Removes some extra cache invalidations and unneeded log statements.
* Added streaming_engine prefix to experiment enabling heartbeats and
changed exception in state reader to be WorkItemFailedException.
* Adds check that heartbeat response sets failed before failing work.
* Adds ability to plumb experiments to test server for
GrpcWindmillServerTest so we can test the new style heartbeats.
* Changes StreamingDataflowWorkerTest to look for latency attribution in
new-style heartbeat requests since that's what FakeWindmillServer returns now.
---
.../worker/MetricTrackingWindmillServerStub.java | 32 ++--
.../beam/runners/dataflow/worker/PubsubReader.java | 8 +
.../dataflow/worker/StreamingDataflowWorker.java | 51 +++++-
.../worker/StreamingModeExecutionContext.java | 14 +-
.../dataflow/worker/UngroupedWindmillReader.java | 8 +
.../worker/WorkItemCancelledException.java | 39 +++++
.../dataflow/worker/WorkerCustomSources.java | 3 +-
.../dataflow/worker/streaming/ActiveWorkState.java | 66 +++++++-
.../worker/streaming/ComputationState.java | 14 +-
.../runners/dataflow/worker/streaming/Work.java | 11 ++
.../worker/windmill/WindmillServerStub.java | 6 +
.../worker/windmill/client/WindmillStream.java | 5 +-
.../windmill/client/grpc/GrpcGetDataStream.java | 107 ++++++++++---
.../windmill/client/grpc/GrpcWindmillServer.java | 35 +++-
.../client/grpc/GrpcWindmillStreamFactory.java | 16 +-
.../worker/windmill/state/WindmillStateCache.java | 5 +
.../worker/windmill/state/WindmillStateReader.java | 18 ++-
.../dataflow/worker/FakeWindmillServer.java | 47 +++++-
.../worker/StreamingDataflowWorkerTest.java | 76 ++++++++-
.../worker/StreamingModeExecutionContextTest.java | 6 +-
.../dataflow/worker/WorkerCustomSourcesTest.java | 82 +++++++++-
.../worker/streaming/ActiveWorkStateTest.java | 50 +++---
.../client/grpc/GrpcWindmillServerTest.java | 177 ++++++++++++++++++---
.../worker/windmill/src/main/proto/windmill.proto | 51 +++++-
24 files changed, 801 insertions(+), 126 deletions(-)
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
index 0e929249b3a..800504f4451 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
@@ -27,7 +27,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.concurrent.GuardedBy;
import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
-import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool;
@@ -239,25 +239,37 @@ public class MetricTrackingWindmillServerStub {
}
/** Tells windmill processing is ongoing for the given keys. */
- public void refreshActiveWork(Map<String, List<KeyedGetDataRequest>> active)
{
- activeHeartbeats.set(active.size());
+ public void refreshActiveWork(Map<String, List<HeartbeatRequest>>
heartbeats) {
+ activeHeartbeats.set(heartbeats.size());
try {
if (useStreamingRequests) {
// With streaming requests, always send the request even when it is
empty, to ensure that
// we trigger health checks for the stream even when it is idle.
GetDataStream stream = streamPool.getStream();
try {
- stream.refreshActiveWork(active);
+ stream.refreshActiveWork(heartbeats);
} finally {
streamPool.releaseStream(stream);
}
- } else if (!active.isEmpty()) {
+ } else if (!heartbeats.isEmpty()) {
+ // This code path is only used by appliance which sends heartbeats
(used to refresh active
+ // work) as KeyedGetDataRequests. So we must translate the
HeartbeatRequest to a
+ // KeyedGetDataRequest here regardless of the value of
sendKeyedGetDataRequests.
Windmill.GetDataRequest.Builder builder =
Windmill.GetDataRequest.newBuilder();
- for (Map.Entry<String, List<KeyedGetDataRequest>> entry :
active.entrySet()) {
- builder.addRequests(
- Windmill.ComputationGetDataRequest.newBuilder()
- .setComputationId(entry.getKey())
- .addAllRequests(entry.getValue()));
+ for (Map.Entry<String, List<HeartbeatRequest>> entry :
heartbeats.entrySet()) {
+ Windmill.ComputationGetDataRequest.Builder perComputationBuilder =
+ Windmill.ComputationGetDataRequest.newBuilder();
+ perComputationBuilder.setComputationId(entry.getKey());
+ for (HeartbeatRequest request : entry.getValue()) {
+ perComputationBuilder.addRequests(
+ Windmill.KeyedGetDataRequest.newBuilder()
+ .setShardingKey(request.getShardingKey())
+ .setWorkToken(request.getWorkToken())
+ .setCacheToken(request.getCacheToken())
+
.addAllLatencyAttribution(request.getLatencyAttributionList())
+ .build());
+ }
+ builder.addRequests(perComputationBuilder.build());
}
server.getData(builder.build());
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java
index d0931e02cc8..be0bccec026 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java
@@ -112,6 +112,14 @@ class PubsubReader<T> extends
NativeReader<WindowedValue<T>> {
super(work);
}
+ @Override
+ public boolean advance() throws IOException {
+ if (context.workIsFailed()) {
+ return false;
+ }
+ return super.advance();
+ }
+
@Override
protected WindowedValue<T> decodeMessage(Windmill.Message message) throws
IOException {
T value;
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index f2d7c02729c..a95e7828881 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -85,6 +85,7 @@ import
org.apache.beam.runners.dataflow.worker.status.DebugCapture.Capturable;
import
org.apache.beam.runners.dataflow.worker.status.LastExceptionDataProvider;
import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
import org.apache.beam.runners.dataflow.worker.status.WorkerStatusPages;
+import
org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.FailedTokens;
import org.apache.beam.runners.dataflow.worker.streaming.Commit;
import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
import org.apache.beam.runners.dataflow.worker.streaming.ExecutionState;
@@ -422,6 +423,7 @@ public class StreamingDataflowWorker {
this.publishCounters = publishCounters;
this.windmillServer = options.getWindmillServerStub();
+
this.windmillServer.setProcessHeartbeatResponses(this::handleHeartbeatResponses);
this.metricTrackingWindmillServer =
new MetricTrackingWindmillServerStub(windmillServer, memoryMonitor,
windmillServiceEnabled);
this.metricTrackingWindmillServer.start();
@@ -982,6 +984,9 @@ public class StreamingDataflowWorker {
String counterName = "dataflow_source_bytes_processed-" +
mapTask.getSystemName();
try {
+ if (work.isFailed()) {
+ throw new WorkItemCancelledException(workItem.getShardingKey());
+ }
executionState = computationState.getExecutionStateQueue().poll();
if (executionState == null) {
MutableNetwork<Node, Edge> mapTaskNetwork =
mapTaskToNetwork.apply(mapTask);
@@ -1098,7 +1103,8 @@ public class StreamingDataflowWorker {
work.setState(State.PROCESSING);
}
};
- });
+ },
+ work::isFailed);
SideInputStateFetcher localSideInputStateFetcher =
sideInputStateFetcher.byteTrackingView();
// If the read output KVs, then we can decode Windmill's byte key into a
userland
@@ -1136,12 +1142,16 @@ public class StreamingDataflowWorker {
synchronizedProcessingTime,
stateReader,
localSideInputStateFetcher,
- outputBuilder);
+ outputBuilder,
+ work::isFailed);
// Blocks while executing work.
executionState.workExecutor().execute();
- // Reports source bytes processed to workitemcommitrequest if available.
+ if (work.isFailed()) {
+ throw new WorkItemCancelledException(workItem.getShardingKey());
+ }
+ // Reports source bytes processed to WorkItemCommitRequest if available.
try {
long sourceBytesProcessed = 0;
HashMap<String, ElementCounter> counters =
@@ -1234,6 +1244,12 @@ public class StreamingDataflowWorker {
+ "Work will not be retried locally.",
computationId,
key.toStringUtf8());
+ } else if (WorkItemCancelledException.isWorkItemCancelledException(t)) {
+ LOG.debug(
+ "Execution of work for computation '{}' on key '{}' failed. "
+ + "Work will not be retried locally.",
+ computationId,
+ workItem.getShardingKey());
} else {
LastExceptionDataProvider.reportException(t);
LOG.debug("Failed work: {}", work);
@@ -1369,6 +1385,10 @@ public class StreamingDataflowWorker {
// Adds the commit to the commitStream if it fits, returning true iff it is
consumed.
private boolean addCommitToStream(Commit commit, CommitWorkStream
commitStream) {
Preconditions.checkNotNull(commit);
+ // Drop commits for failed work. Such commits will be dropped by Windmill
anyway.
+ if (commit.work().isFailed()) {
+ return true;
+ }
final ComputationState state = commit.computationState();
final Windmill.WorkItemCommitRequest request = commit.request();
final int size = commit.getSize();
@@ -1896,6 +1916,25 @@ public class StreamingDataflowWorker {
}
}
+ public void
handleHeartbeatResponses(List<Windmill.ComputationHeartbeatResponse> responses)
{
+ for (Windmill.ComputationHeartbeatResponse computationHeartbeatResponse :
responses) {
+ // Maps sharding key to (work token, cache token) for work that should
be marked failed.
+ Map<Long, List<FailedTokens>> failedWork = new HashMap<>();
+ for (Windmill.HeartbeatResponse heartbeatResponse :
+ computationHeartbeatResponse.getHeartbeatResponsesList()) {
+ if (heartbeatResponse.getFailed()) {
+ failedWork
+ .computeIfAbsent(heartbeatResponse.getShardingKey(), key -> new
ArrayList<>())
+ .add(
+ new FailedTokens(
+ heartbeatResponse.getWorkToken(),
heartbeatResponse.getCacheToken()));
+ }
+ }
+ ComputationState state =
computationMap.get(computationHeartbeatResponse.getComputationId());
+ if (state != null) state.failWork(failedWork);
+ }
+ }
+
/**
* Sends a GetData request to Windmill for all sufficiently old active work.
*
@@ -1904,15 +1943,15 @@ public class StreamingDataflowWorker {
* StreamingDataflowWorkerOptions#getActiveWorkRefreshPeriodMillis}.
*/
private void refreshActiveWork() {
- Map<String, List<Windmill.KeyedGetDataRequest>> active = new HashMap<>();
+ Map<String, List<Windmill.HeartbeatRequest>> heartbeats = new HashMap<>();
Instant refreshDeadline =
clock.get().minus(Duration.millis(options.getActiveWorkRefreshPeriodMillis()));
for (Map.Entry<String, ComputationState> entry :
computationMap.entrySet()) {
- active.put(entry.getKey(),
entry.getValue().getKeysToRefresh(refreshDeadline, sampler));
+ heartbeats.put(entry.getKey(),
entry.getValue().getKeyHeartbeats(refreshDeadline, sampler));
}
- metricTrackingWindmillServer.refreshActiveWork(active);
+ metricTrackingWindmillServer.refreshActiveWork(heartbeats);
}
private void invalidateStuckCommits() {
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
index d630601c28a..83cf49112a8 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
@@ -112,6 +112,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
private Windmill.WorkItemCommitRequest.Builder outputBuilder;
private UnboundedSource.UnboundedReader<?> activeReader;
private volatile long backlogBytes;
+ private Supplier<Boolean> workIsFailed;
public StreamingModeExecutionContext(
CounterFactory counterFactory,
@@ -135,6 +136,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
this.stateNameMap = ImmutableMap.copyOf(stateNameMap);
this.stateCache = stateCache;
this.backlogBytes = UnboundedSource.UnboundedReader.BACKLOG_UNKNOWN;
+ this.workIsFailed = () -> Boolean.FALSE;
}
@VisibleForTesting
@@ -142,6 +144,10 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
return backlogBytes;
}
+ public boolean workIsFailed() {
+ return workIsFailed.get();
+ }
+
public void start(
@Nullable Object key,
Windmill.WorkItem work,
@@ -150,9 +156,11 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
@Nullable Instant synchronizedProcessingTime,
WindmillStateReader stateReader,
SideInputStateFetcher sideInputStateFetcher,
- Windmill.WorkItemCommitRequest.Builder outputBuilder) {
+ Windmill.WorkItemCommitRequest.Builder outputBuilder,
+ @Nullable Supplier<Boolean> workFailed) {
this.key = key;
this.work = work;
+ this.workIsFailed = (workFailed != null) ? workFailed : () ->
Boolean.FALSE;
this.computationKey =
WindmillComputationKey.create(computationId, work.getKey(),
work.getShardingKey());
this.sideInputStateFetcher = sideInputStateFetcher;
@@ -429,7 +437,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
/**
* Execution states in Streaming are shared between multiple map-task
executors. Thus this class
- * needs to be thread safe for multiple writers. A single stage could have
have multiple executors
+ * needs to be thread safe for multiple writers. A single stage could have
multiple executors
* running concurrently.
*/
public static class StreamingModeExecutionState
@@ -670,7 +678,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
private NavigableSet<TimerData>
modifiedUserSynchronizedProcessingTimersOrdered = null;
// A list of timer keys that were modified by user processing earlier in
this bundle. This
// serves a tombstone, so
- // that we know not to fire any bundle tiemrs that were moddified.
+ // that we know not to fire any bundle timers that were modified.
private Table<String, StateNamespace, TimerData> modifiedUserTimerKeys =
null;
public StepContext(DataflowOperationContext operationContext) {
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java
index e4e56a96c15..4aac93ceb3f 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java
@@ -99,6 +99,14 @@ class UngroupedWindmillReader<T> extends
NativeReader<WindowedValue<T>> {
super(work);
}
+ @Override
+ public boolean advance() throws IOException {
+ if (context.workIsFailed()) {
+ return false;
+ }
+ return super.advance();
+ }
+
@Override
protected WindowedValue<T> decodeMessage(Windmill.Message message) throws
IOException {
Instant timestampMillis =
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java
new file mode 100644
index 00000000000..934977fe098
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker;
+
+/** Indicates that the work item was cancelled and should not be retried. */
+@SuppressWarnings({
+ "nullness" // TODO(https://github.com/apache/beam/issues/20497)
+})
+public class WorkItemCancelledException extends RuntimeException {
+ public WorkItemCancelledException(long sharding_key) {
+ super("Work item cancelled for key " + sharding_key);
+ }
+
+ /** Returns whether an exception was caused by a {@link
WorkItemCancelledException}. */
+ public static boolean isWorkItemCancelledException(Throwable t) {
+ while (t != null) {
+ if (t instanceof WorkItemCancelledException) {
+ return true;
+ }
+ t = t.getCause();
+ }
+ return false;
+ }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java
index a9050236efc..2dc3494af5e 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java
@@ -836,7 +836,8 @@ public class WorkerCustomSources {
while (true) {
if (elemsRead >= maxElems
|| Instant.now().isAfter(endTime)
- || context.isSinkFullHintSet()) {
+ || context.isSinkFullHintSet()
+ || context.workIsFailed()) {
return false;
}
try {
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
index 16266de9d47..54942dfeee1 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
@@ -23,6 +23,7 @@ import java.io.PrintWriter;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
@@ -34,8 +35,10 @@ import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
-import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
import
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
+import org.apache.beam.sdk.annotations.Internal;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
@@ -50,7 +53,8 @@ import org.slf4j.LoggerFactory;
* activate, queue, and complete {@link Work} (including invalidating stuck
{@link Work}).
*/
@ThreadSafe
-final class ActiveWorkState {
+@Internal
+public final class ActiveWorkState {
private static final Logger LOG =
LoggerFactory.getLogger(ActiveWorkState.class);
/* The max number of keys in COMMITTING or COMMIT_QUEUED status to be
shown.*/
@@ -120,6 +124,50 @@ final class ActiveWorkState {
return ActivateWorkResult.QUEUED;
}
+ public static final class FailedTokens {
+ public long workToken;
+ public long cacheToken;
+
+ public FailedTokens(long workToken, long cacheToken) {
+ this.workToken = workToken;
+ this.cacheToken = cacheToken;
+ }
+ }
+
+ /**
+ * Fails any active work matching an element of the input Map.
+ *
+ * @param failedWork a map from sharding_key to tokens for the corresponding
work.
+ */
+ synchronized void failWorkForKey(Map<Long, List<FailedTokens>> failedWork) {
+ // Note we can't construct a ShardedKey and look it up in activeWork
directly since
+ // HeartbeatResponse doesn't include the user key.
+ for (Entry<ShardedKey, Deque<Work>> entry : activeWork.entrySet()) {
+ List<FailedTokens> failedTokens =
failedWork.get(entry.getKey().shardingKey());
+ if (failedTokens == null) continue;
+ for (FailedTokens failedToken : failedTokens) {
+ for (Work queuedWork : entry.getValue()) {
+ WorkItem workItem = queuedWork.getWorkItem();
+ if (workItem.getWorkToken() == failedToken.workToken
+ && workItem.getCacheToken() == failedToken.cacheToken) {
+ LOG.debug(
+ "Failing work "
+ + computationStateCache.getComputation()
+ + " "
+ + entry.getKey().shardingKey()
+ + " "
+ + failedToken.workToken
+ + " "
+ + failedToken.cacheToken
+ + ". The work will be retried and is not lost.");
+ queuedWork.setFailed();
+ break;
+ }
+ }
+ }
+ }
+ }
+
/**
* Removes the complete work from the {@link Queue<Work>}. The {@link Work}
is marked as completed
* if its workToken matches the one that is passed in. Returns the next
{@link Work} in the {@link
@@ -211,14 +259,14 @@ final class ActiveWorkState {
return stuckCommits.build();
}
- synchronized ImmutableList<KeyedGetDataRequest> getKeysToRefresh(
+ synchronized ImmutableList<HeartbeatRequest> getKeyHeartbeats(
Instant refreshDeadline, DataflowExecutionStateSampler sampler) {
return activeWork.entrySet().stream()
- .flatMap(entry -> toKeyedGetDataRequestStream(entry, refreshDeadline,
sampler))
+ .flatMap(entry -> toHeartbeatRequestStream(entry, refreshDeadline,
sampler))
.collect(toImmutableList());
}
- private static Stream<KeyedGetDataRequest> toKeyedGetDataRequestStream(
+ private static Stream<HeartbeatRequest> toHeartbeatRequestStream(
Entry<ShardedKey, Deque<Work>> shardedKeyAndWorkQueue,
Instant refreshDeadline,
DataflowExecutionStateSampler sampler) {
@@ -227,12 +275,14 @@ final class ActiveWorkState {
return workQueue.stream()
.filter(work -> work.getStartTime().isBefore(refreshDeadline))
+ // Don't send heartbeats for queued work we already know is failed.
+ .filter(work -> !work.isFailed())
.map(
work ->
- Windmill.KeyedGetDataRequest.newBuilder()
- .setKey(shardedKey.key())
+ Windmill.HeartbeatRequest.newBuilder()
.setShardingKey(shardedKey.shardingKey())
.setWorkToken(work.getWorkItem().getWorkToken())
+ .setCacheToken(work.getWorkItem().getCacheToken())
.addAllLatencyAttribution(
work.getLatencyAttributions(true,
work.getLatencyTrackingId(), sampler))
.build());
@@ -250,7 +300,7 @@ final class ActiveWorkState {
for (Map.Entry<ShardedKey, Deque<Work>> entry : activeWork.entrySet()) {
Queue<Work> workQueue = Preconditions.checkNotNull(entry.getValue());
Work activeWork = Preconditions.checkNotNull(workQueue.peek());
- Windmill.WorkItem workItem = activeWork.getWorkItem();
+ WorkItem workItem = activeWork.getWorkItem();
if (activeWork.isCommitPending()) {
if (++commitsPendingCount >= MAX_PRINTABLE_COMMIT_PENDING_KEYS) {
continue;
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
index 4ac1d8bc9fa..8207a6ef2f0 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
@@ -19,12 +19,14 @@ package org.apache.beam.runners.dataflow.worker.streaming;
import com.google.api.services.dataflow.model.MapTask;
import java.io.PrintWriter;
+import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentLinkedQueue;
import javax.annotation.Nullable;
import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
+import
org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.FailedTokens;
import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
-import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
@@ -98,6 +100,10 @@ public class ComputationState implements AutoCloseable {
}
}
+ public void failWork(Map<Long, List<FailedTokens>> failedWork) {
+ activeWorkState.failWorkForKey(failedWork);
+ }
+
/**
* Marks the work for the given shardedKey as complete. Schedules queued
work for the key if any.
*/
@@ -120,10 +126,10 @@ public class ComputationState implements AutoCloseable {
executor.forceExecute(work, work.getWorkItem().getSerializedSize());
}
- /** Adds any work started before the refreshDeadline to the GetDataRequest
builder. */
- public ImmutableList<KeyedGetDataRequest> getKeysToRefresh(
+ /** Gets HeartbeatRequests for any work started before refreshDeadline. */
+ public ImmutableList<HeartbeatRequest> getKeyHeartbeats(
Instant refreshDeadline, DataflowExecutionStateSampler sampler) {
- return activeWorkState.getKeysToRefresh(refreshDeadline, sampler);
+ return activeWorkState.getKeyHeartbeats(refreshDeadline, sampler);
}
public void printActiveWork(PrintWriter writer) {
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
index 3a77a8322b4..69f2a0dcee7 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
@@ -50,6 +50,8 @@ public class Work implements Runnable {
private final Consumer<Work> processWorkFn;
private TimedState currentState;
+ private boolean isFailed;
+
private Work(Windmill.WorkItem workItem, Supplier<Instant> clock,
Consumer<Work> processWorkFn) {
this.workItem = workItem;
this.clock = clock;
@@ -57,6 +59,7 @@ public class Work implements Runnable {
this.startTime = clock.get();
this.totalDurationPerState = new
EnumMap<>(Windmill.LatencyAttribution.State.class);
this.currentState = TimedState.initialState(startTime);
+ this.isFailed = false;
}
public static Work create(
@@ -95,6 +98,10 @@ public class Work implements Runnable {
this.currentState = TimedState.create(state, now);
}
+ public void setFailed() {
+ this.isFailed = true;
+ }
+
public boolean isCommitPending() {
return currentState.isCommitPending();
}
@@ -180,6 +187,10 @@ public class Work implements Runnable {
return builder;
}
+ public boolean isFailed() {
+ return isFailed;
+ }
+
boolean isStuckCommittingAt(Instant stuckCommitDeadline) {
return currentState.state() == Work.State.COMMITTING
&& currentState.startTime().isBefore(stuckCommitDeadline);
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
index c327e68d7e9..25581bee208 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
@@ -19,8 +19,11 @@ package org.apache.beam.runners.dataflow.worker.windmill;
import java.io.IOException;
import java.io.PrintWriter;
+import java.util.List;
import java.util.Set;
+import java.util.function.Consumer;
import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
@@ -79,6 +82,9 @@ public abstract class WindmillServerStub implements
StatusDataProvider {
@Override
public void appendSummaryHtml(PrintWriter writer) {}
+ public void setProcessHeartbeatResponses(
+ Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses)
{}
+
/** Generic Exception type for implementors to use to represent errors while
making RPCs. */
public static final class RpcException extends RuntimeException {
public RpcException(Throwable cause) {
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
index fa1f797a191..7c22f4fb576 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
@@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
import org.joda.time.Instant;
@@ -59,7 +60,9 @@ public interface WindmillStream {
Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request);
/** Tells windmill processing is ongoing for the given keys. */
- void refreshActiveWork(Map<String, List<Windmill.KeyedGetDataRequest>>
active);
+ void refreshActiveWork(Map<String, List<HeartbeatRequest>> heartbeats);
+
+ void onHeartbeatResponse(List<Windmill.ComputationHeartbeatResponse>
responses);
}
/** Interface for streaming CommitWorkRequests to Windmill. */
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
index a04a961ca9c..b6600e04a09 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
@@ -32,10 +32,15 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Consumer;
import java.util.function.Function;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatRequest;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
@@ -64,6 +69,10 @@ public final class GrpcGetDataStream
private final ThrottleTimer getDataThrottleTimer;
private final JobHeader jobHeader;
private final int streamingRpcBatchLimit;
+  // If true, then active work refreshes are sent as KeyedGetDataRequests. Otherwise, the
+  // newer ComputationHeartbeatRequests are used.
+ private final boolean sendKeyedGetDataRequests;
+ private Consumer<List<ComputationHeartbeatResponse>>
processHeartbeatResponses;
private GrpcGetDataStream(
Function<StreamObserver<StreamingGetDataResponse>,
StreamObserver<StreamingGetDataRequest>>
@@ -75,7 +84,9 @@ public final class GrpcGetDataStream
ThrottleTimer getDataThrottleTimer,
JobHeader jobHeader,
AtomicLong idGenerator,
- int streamingRpcBatchLimit) {
+ int streamingRpcBatchLimit,
+ boolean sendKeyedGetDataRequests,
+ Consumer<List<Windmill.ComputationHeartbeatResponse>>
processHeartbeatResponses) {
super(
startGetDataRpcFn, backoff, streamObserverFactory, streamRegistry,
logEveryNStreamFailures);
this.idGenerator = idGenerator;
@@ -84,6 +95,8 @@ public final class GrpcGetDataStream
this.streamingRpcBatchLimit = streamingRpcBatchLimit;
this.batches = new ConcurrentLinkedDeque<>();
this.pending = new ConcurrentHashMap<>();
+ this.sendKeyedGetDataRequests = sendKeyedGetDataRequests;
+ this.processHeartbeatResponses = processHeartbeatResponses;
}
public static GrpcGetDataStream create(
@@ -96,7 +109,9 @@ public final class GrpcGetDataStream
ThrottleTimer getDataThrottleTimer,
JobHeader jobHeader,
AtomicLong idGenerator,
- int streamingRpcBatchLimit) {
+ int streamingRpcBatchLimit,
+ boolean sendKeyedGetDataRequests,
+ Consumer<List<Windmill.ComputationHeartbeatResponse>>
processHeartbeatResponses) {
GrpcGetDataStream getDataStream =
new GrpcGetDataStream(
startGetDataRpcFn,
@@ -107,7 +122,9 @@ public final class GrpcGetDataStream
getDataThrottleTimer,
jobHeader,
idGenerator,
- streamingRpcBatchLimit);
+ streamingRpcBatchLimit,
+ sendKeyedGetDataRequests,
+ processHeartbeatResponses);
getDataStream.startStream();
return getDataStream;
}
@@ -138,6 +155,7 @@ public final class GrpcGetDataStream
checkArgument(chunk.getRequestIdCount() ==
chunk.getSerializedResponseCount());
checkArgument(chunk.getRemainingBytesForResponse() == 0 ||
chunk.getRequestIdCount() == 1);
getDataThrottleTimer.stop();
+ onHeartbeatResponse(chunk.getComputationHeartbeatResponseList());
for (int i = 0; i < chunk.getRequestIdCount(); ++i) {
AppendableInputStream responseStream =
pending.get(chunk.getRequestId(i));
@@ -171,30 +189,71 @@ public final class GrpcGetDataStream
}
@Override
- public void refreshActiveWork(Map<String, List<KeyedGetDataRequest>> active)
{
- long builderBytes = 0;
+ public void refreshActiveWork(Map<String, List<HeartbeatRequest>>
heartbeats) {
StreamingGetDataRequest.Builder builder =
StreamingGetDataRequest.newBuilder();
- for (Map.Entry<String, List<KeyedGetDataRequest>> entry :
active.entrySet()) {
- for (KeyedGetDataRequest request : entry.getValue()) {
- // Calculate the bytes with some overhead for proto encoding.
- long bytes = (long) entry.getKey().length() +
request.getSerializedSize() + 10;
- if (builderBytes > 0
- && (builderBytes + bytes >
AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE
- || builder.getRequestIdCount() >= streamingRpcBatchLimit)) {
- send(builder.build());
- builderBytes = 0;
- builder.clear();
+ if (sendKeyedGetDataRequests) {
+ long builderBytes = 0;
+ for (Map.Entry<String, List<HeartbeatRequest>> entry :
heartbeats.entrySet()) {
+ for (HeartbeatRequest request : entry.getValue()) {
+ // Calculate the bytes with some overhead for proto encoding.
+ long bytes = (long) entry.getKey().length() +
request.getSerializedSize() + 10;
+ if (builderBytes > 0
+ && (builderBytes + bytes >
AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE
+ || builder.getRequestIdCount() >= streamingRpcBatchLimit)) {
+ send(builder.build());
+ builderBytes = 0;
+ builder.clear();
+ }
+ builderBytes += bytes;
+ builder.addStateRequest(
+ ComputationGetDataRequest.newBuilder()
+ .setComputationId(entry.getKey())
+ .addRequests(
+ Windmill.KeyedGetDataRequest.newBuilder()
+ .setShardingKey(request.getShardingKey())
+ .setWorkToken(request.getWorkToken())
+ .setCacheToken(request.getCacheToken())
+
.addAllLatencyAttribution(request.getLatencyAttributionList())
+ .build()));
}
- builderBytes += bytes;
- builder.addStateRequest(
- ComputationGetDataRequest.newBuilder()
- .setComputationId(entry.getKey())
- .addRequests(request));
+ }
+
+ if (builderBytes > 0) {
+ send(builder.build());
+ }
+ } else {
+ // No translation necessary, but we must still respect
`RPC_STREAM_CHUNK_SIZE`.
+ long builderBytes = 0;
+ for (Map.Entry<String, List<HeartbeatRequest>> entry :
heartbeats.entrySet()) {
+ ComputationHeartbeatRequest.Builder computationHeartbeatBuilder =
+
ComputationHeartbeatRequest.newBuilder().setComputationId(entry.getKey());
+ for (HeartbeatRequest request : entry.getValue()) {
+ long bytes = (long) entry.getKey().length() +
request.getSerializedSize() + 10;
+ if (builderBytes > 0
+ && builderBytes + bytes >
AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) {
+ if (computationHeartbeatBuilder.getHeartbeatRequestsCount() > 0) {
+
builder.addComputationHeartbeatRequest(computationHeartbeatBuilder.build());
+ }
+ send(builder.build());
+ builderBytes = 0;
+ builder.clear();
+
computationHeartbeatBuilder.clear().setComputationId(entry.getKey());
+ }
+ builderBytes += bytes;
+ computationHeartbeatBuilder.addHeartbeatRequests(request);
+ }
+
builder.addComputationHeartbeatRequest(computationHeartbeatBuilder.build());
+ }
+
+ if (builderBytes > 0) {
+ send(builder.build());
}
}
- if (builderBytes > 0) {
- send(builder.build());
- }
+ }
+
+ @Override
+ public void onHeartbeatResponse(List<Windmill.ComputationHeartbeatResponse>
responses) {
+ processHeartbeatResponses.accept(responses);
}
@Override
@@ -277,7 +336,7 @@ public final class GrpcGetDataStream
waitForSendLatch.await();
}
// Finalize the batch so that no additional requests will be added.
Leave the batch in the
- // queue so that a subsequent batch will wait for it's completion.
+ // queue so that a subsequent batch will wait for its completion.
synchronized (batches) {
verify(batch == batches.peekFirst());
batch.markFinalized();
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java
index 3a881df7146..9f0126a9cc6 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java
@@ -28,13 +28,17 @@ import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
+import java.util.function.Consumer;
import java.util.function.Supplier;
import javax.annotation.Nullable;
+import org.apache.beam.runners.dataflow.DataflowRunner;
import
org.apache.beam.runners.dataflow.worker.options.StreamingDataflowWorkerOptions;
import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkResponse;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetConfigRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetConfigResponse;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataRequest;
@@ -93,6 +97,10 @@ public final class GrpcWindmillServer extends
WindmillServerStub {
private final StreamingEngineThrottleTimers throttleTimers;
private Duration maxBackoff;
private @Nullable WindmillApplianceGrpc.WindmillApplianceBlockingStub
syncApplianceStub;
+  // If true, then active work refreshes are sent as KeyedGetDataRequests. Otherwise, the
+  // newer ComputationHeartbeatRequests are used.
+ private final boolean sendKeyedGetDataRequests;
+ private Consumer<List<ComputationHeartbeatResponse>>
processHeartbeatResponses;
private GrpcWindmillServer(
StreamingDataflowWorkerOptions options, GrpcDispatcherClient
grpcDispatcherClient) {
@@ -118,9 +126,21 @@ public final class GrpcWindmillServer extends
WindmillServerStub {
this.dispatcherClient = grpcDispatcherClient;
this.syncApplianceStub = null;
+ this.sendKeyedGetDataRequests =
+ !options.isEnableStreamingEngine()
+ || !DataflowRunner.hasExperiment(
+ options, "streaming_engine_send_new_heartbeat_requests");
+ this.processHeartbeatResponses = (responses) -> {};
}
- private static StreamingDataflowWorkerOptions testOptions(boolean
enableStreamingEngine) {
+ @Override
+ public void setProcessHeartbeatResponses(
+ Consumer<List<Windmill.ComputationHeartbeatResponse>>
processHeartbeatResponses) {
+ this.processHeartbeatResponses = processHeartbeatResponses;
+ };
+
+ private static StreamingDataflowWorkerOptions testOptions(
+ boolean enableStreamingEngine, List<String> additionalExperiments) {
StreamingDataflowWorkerOptions options =
PipelineOptionsFactory.create().as(StreamingDataflowWorkerOptions.class);
options.setProject("project");
@@ -131,6 +151,7 @@ public final class GrpcWindmillServer extends
WindmillServerStub {
if (enableStreamingEngine) {
experiments.add(GcpOptions.STREAMING_ENGINE_EXPERIMENT);
}
+ experiments.addAll(additionalExperiments);
options.setExperiments(experiments);
options.setWindmillServiceStreamingRpcBatchLimit(Integer.MAX_VALUE);
@@ -162,7 +183,7 @@ public final class GrpcWindmillServer extends
WindmillServerStub {
}
@VisibleForTesting
- static GrpcWindmillServer newTestInstance(String name) {
+ static GrpcWindmillServer newTestInstance(String name, List<String>
experiments) {
ManagedChannel inProcessChannel = inProcessChannel(name);
CloudWindmillServiceV1Alpha1Stub stub =
CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel);
@@ -173,14 +194,15 @@ public final class GrpcWindmillServer extends
WindmillServerStub {
WindmillStubFactory.inProcessStubFactory(name, unused ->
inProcessChannel),
dispatcherStubs,
dispatcherEndpoints);
- return new GrpcWindmillServer(testOptions(/* enableStreamingEngine= */
true), dispatcherClient);
+ return new GrpcWindmillServer(
+ testOptions(/* enableStreamingEngine= */ true, experiments),
dispatcherClient);
}
@VisibleForTesting
static GrpcWindmillServer newApplianceTestInstance(Channel channel) {
GrpcWindmillServer testServer =
new GrpcWindmillServer(
- testOptions(/* enableStreamingEngine= */ false),
+ testOptions(/* enableStreamingEngine= */ false, new ArrayList<>()),
// No-op, Appliance does not use Dispatcher to call Streaming
Engine.
GrpcDispatcherClient.create(WindmillStubFactory.inProcessStubFactory("test")));
testServer.syncApplianceStub =
createWindmillApplianceStubWithDeadlineInterceptor(channel);
@@ -319,7 +341,10 @@ public final class GrpcWindmillServer extends
WindmillServerStub {
@Override
public GetDataStream getDataStream() {
return windmillStreamFactory.createGetDataStream(
- dispatcherClient.getDispatcherStub(),
throttleTimers.getDataThrottleTimer());
+ dispatcherClient.getDispatcherStub(),
+ throttleTimers.getDataThrottleTimer(),
+ sendKeyedGetDataRequests,
+ this.processHeartbeatResponses);
}
@Override
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
index 099be8db0fd..7dc43e791e3 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
@@ -21,6 +21,7 @@ import static
org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWi
import com.google.auto.value.AutoBuilder;
import java.io.PrintWriter;
+import java.util.List;
import java.util.Set;
import java.util.Timer;
import java.util.TimerTask;
@@ -32,6 +33,7 @@ import java.util.function.Supplier;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints;
@@ -152,7 +154,10 @@ public class GrpcWindmillStreamFactory implements
StatusDataProvider {
}
public GetDataStream createGetDataStream(
- CloudWindmillServiceV1Alpha1Stub stub, ThrottleTimer
getDataThrottleTimer) {
+ CloudWindmillServiceV1Alpha1Stub stub,
+ ThrottleTimer getDataThrottleTimer,
+ boolean sendKeyedGetDataRequests,
+ Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses) {
return GrpcGetDataStream.create(
responseObserver -> withDeadline(stub).getDataStream(responseObserver),
grpcBackOff.get(),
@@ -162,7 +167,14 @@ public class GrpcWindmillStreamFactory implements
StatusDataProvider {
getDataThrottleTimer,
jobHeader,
streamIdGenerator,
- streamingRpcBatchLimit);
+ streamingRpcBatchLimit,
+ sendKeyedGetDataRequests,
+ processHeartbeatResponses);
+ }
+
+ public GetDataStream createGetDataStream(
+ CloudWindmillServiceV1Alpha1Stub stub, ThrottleTimer
getDataThrottleTimer) {
+ return createGetDataStream(stub, getDataThrottleTimer, false, (response)
-> {});
}
public CommitWorkStream createCommitWorkStream(
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
index 6c1239d6ebd..5a9e5443a50 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
@@ -296,6 +296,11 @@ public class WindmillStateCache implements
StatusDataProvider {
this.computation = computation;
}
+  /** Returns the computation associated with this instance. */
+ public String getComputation() {
+ return this.computation;
+ }
+
/** Invalidate all cache entries for this computation and {@code
processingKey}. */
public void invalidate(ByteString processingKey, long shardingKey) {
WindmillComputationKey key =
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java
index c28939c59ee..637b838c7fe 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java
@@ -39,6 +39,7 @@ import javax.annotation.Nullable;
import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException;
import
org.apache.beam.runners.dataflow.worker.MetricTrackingWindmillServerStub;
import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils;
+import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
@@ -123,6 +124,7 @@ public class WindmillStateReader {
private final MetricTrackingWindmillServerStub
metricTrackingWindmillServerStub;
private final ConcurrentHashMap<StateTag<?>, CoderAndFuture<?>> waiting;
private long bytesRead = 0L;
+ private final Supplier<Boolean> workItemIsFailed;
public WindmillStateReader(
MetricTrackingWindmillServerStub metricTrackingWindmillServerStub,
@@ -130,7 +132,8 @@ public class WindmillStateReader {
ByteString key,
long shardingKey,
long workToken,
- Supplier<AutoCloseable> readWrapperSupplier) {
+ Supplier<AutoCloseable> readWrapperSupplier,
+ Supplier<Boolean> workItemIsFailed) {
this.metricTrackingWindmillServerStub = metricTrackingWindmillServerStub;
this.computation = computation;
this.key = key;
@@ -139,6 +142,7 @@ public class WindmillStateReader {
this.readWrapperSupplier = readWrapperSupplier;
this.waiting = new ConcurrentHashMap<>();
this.pendingLookups = new ConcurrentLinkedQueue<>();
+ this.workItemIsFailed = workItemIsFailed;
}
public WindmillStateReader(
@@ -147,7 +151,14 @@ public class WindmillStateReader {
ByteString key,
long shardingKey,
long workToken) {
- this(metricTrackingWindmillServerStub, computation, key, shardingKey,
workToken, () -> null);
+ this(
+ metricTrackingWindmillServerStub,
+ computation,
+ key,
+ shardingKey,
+ workToken,
+ () -> null,
+ () -> Boolean.FALSE);
}
private <FutureT> Future<FutureT> stateFuture(StateTag<?> stateTag,
@Nullable Coder<?> coder) {
@@ -404,6 +415,9 @@ public class WindmillStateReader {
private KeyedGetDataResponse tryGetDataFromWindmill(HashSet<StateTag<?>>
stateTags)
throws Exception {
+ if (workItemIsFailed.get()) {
+ throw new WorkItemCancelledException(shardingKey);
+ }
KeyedGetDataRequest keyedGetDataRequest = createRequest(stateTags);
try (AutoCloseable ignored = readWrapperSupplier.get()) {
return Optional.ofNullable(
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
index a434b200120..2cfec6d3139 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
@@ -46,8 +46,11 @@ import
org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkResponse;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationCommitWorkRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatRequest;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataResponse;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution.State;
@@ -80,9 +83,10 @@ class FakeWindmillServer extends WindmillServerStub {
private final ErrorCollector errorCollector;
private final ConcurrentHashMap<Long, Consumer<Windmill.CommitStatus>>
droppedStreamingCommits;
private int commitsRequested = 0;
- private List<Windmill.GetDataRequest> getDataRequests = new ArrayList<>();
+ private final List<Windmill.GetDataRequest> getDataRequests = new
ArrayList<>();
private boolean isReady = true;
private boolean dropStreamingCommits = false;
+ private Consumer<List<Windmill.ComputationHeartbeatResponse>>
processHeartbeatResponses;
public FakeWindmillServer(ErrorCollector errorCollector) {
workToOffer =
@@ -91,7 +95,7 @@ class FakeWindmillServer extends WindmillServerStub {
dataToOffer =
new ResponseQueue<GetDataRequest, GetDataResponse>()
.returnByDefault(GetDataResponse.getDefaultInstance())
- // Sleep for a little bit to ensure that *-windmill-read
state-sampled counters show up.
+ // Sleep for a bit to ensure that *-windmill-read state-sampled
counters show up.
.delayEachResponseBy(Duration.millis(500));
commitsToOffer =
new ResponseQueue<Windmill.CommitWorkRequest, CommitWorkResponse>()
@@ -102,6 +106,13 @@ class FakeWindmillServer extends WindmillServerStub {
this.errorCollector = errorCollector;
statsReceived = new ArrayList<>();
droppedStreamingCommits = new ConcurrentHashMap<>();
+ processHeartbeatResponses = (responses) -> {};
+ }
+
+ @Override
+ public void setProcessHeartbeatResponses(
+ Consumer<List<Windmill.ComputationHeartbeatResponse>>
processHeartbeatResponses) {
+ this.processHeartbeatResponses = processHeartbeatResponses;
}
public void setDropStreamingCommits(boolean dropStreamingCommits) {
@@ -116,6 +127,10 @@ class FakeWindmillServer extends WindmillServerStub {
return dataToOffer;
}
+ public void sendFailedHeartbeats(List<Windmill.ComputationHeartbeatResponse>
responses) {
+ getDataStream().onHeartbeatResponse(responses);
+ }
+
public ResponseQueue<Windmill.CommitWorkRequest, Windmill.CommitWorkResponse>
whenCommitWorkCalled() {
return commitsToOffer;
@@ -304,17 +319,23 @@ class FakeWindmillServer extends WindmillServerStub {
}
@Override
- public void refreshActiveWork(Map<String, List<KeyedGetDataRequest>>
active) {
+ public void refreshActiveWork(Map<String, List<HeartbeatRequest>>
heartbeats) {
Windmill.GetDataRequest.Builder builder =
Windmill.GetDataRequest.newBuilder();
- for (Map.Entry<String, List<KeyedGetDataRequest>> entry :
active.entrySet()) {
- builder.addRequests(
- ComputationGetDataRequest.newBuilder()
+ for (Map.Entry<String, List<HeartbeatRequest>> entry :
heartbeats.entrySet()) {
+ builder.addComputationHeartbeatRequest(
+ ComputationHeartbeatRequest.newBuilder()
.setComputationId(entry.getKey())
- .addAllRequests(entry.getValue()));
+ .addAllHeartbeatRequests(entry.getValue()));
}
+
getData(builder.build());
}
+ @Override
+ public void onHeartbeatResponse(List<ComputationHeartbeatResponse>
responses) {
+ processHeartbeatResponses.accept(responses);
+ }
+
@Override
public void close() {}
@@ -383,6 +404,18 @@ class FakeWindmillServer extends WindmillServerStub {
}
}
+ public Map<Long, WorkItemCommitRequest> waitForAndGetCommitsWithTimeout(
+ int numCommits, Duration timeout) {
+ LOG.debug("waitForAndGetCommitsWithTimeout: {} {}", numCommits, timeout);
+ Instant waitStart = Instant.now();
+ while (commitsReceived.size() < commitsRequested + numCommits
+ && Instant.now().isBefore(waitStart.plus(timeout))) {
+ Uninterruptibles.sleepUninterruptibly(1000, TimeUnit.MILLISECONDS);
+ }
+ commitsRequested += numCommits;
+ return commitsReceived;
+ }
+
public Map<Long, WorkItemCommitRequest> waitForAndGetCommits(int numCommits)
{
LOG.debug("waitForAndGetCommitsRequest: {}", numCommits);
int maxTries = 10;
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index 31a9af9004a..9526c96fd04 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -109,9 +109,12 @@ import
org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataResponse;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatRequest;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataResponse;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkResponse;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.InputMessageBundle;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
@@ -577,8 +580,9 @@ public class StreamingDataflowWorkerTest {
}
/**
- * Returns a {@link
org.apache.beam.runners.dataflow.windmill.Windmill.WorkItemCommitRequest}
- * builder parsed from the provided text format proto.
+ * Returns a {@link
+ *
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest}
builder parsed
+ * from the provided text format proto.
*/
private WorkItemCommitRequest.Builder parseCommitRequest(String output)
throws Exception {
WorkItemCommitRequest.Builder builder =
Windmill.WorkItemCommitRequest.newBuilder();
@@ -3258,6 +3262,49 @@ public class StreamingDataflowWorkerTest {
assertThat(server.numGetDataRequests(), greaterThan(0));
}
+ @Test
+ public void testActiveWorkFailure() throws Exception {
+ List<ParallelInstruction> instructions =
+ Arrays.asList(
+ makeSourceInstruction(StringUtf8Coder.of()),
+ makeDoFnInstruction(blockingFn, 0, StringUtf8Coder.of()),
+ makeSinkInstruction(StringUtf8Coder.of(), 0));
+
+ FakeWindmillServer server = new FakeWindmillServer(errorCollector);
+ StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server);
+ options.setActiveWorkRefreshPeriodMillis(100);
+ StreamingDataflowWorker worker = makeWorker(instructions, options, true /*
publishCounters */);
+ worker.start();
+
+ // Queue up two work items for the same key.
+ server
+ .whenGetWorkCalled()
+ .thenReturn(makeInput(0, TimeUnit.MILLISECONDS.toMicros(0), "key",
DEFAULT_SHARDING_KEY))
+ .thenReturn(makeInput(1, TimeUnit.MILLISECONDS.toMicros(0), "key",
DEFAULT_SHARDING_KEY));
+ server.waitForEmptyWorkQueue();
+
+ // Mock Windmill sending a heartbeat response failing the second work item
while the first
+ // is still processing.
+ ComputationHeartbeatResponse.Builder failedHeartbeat =
+ ComputationHeartbeatResponse.newBuilder();
+ failedHeartbeat
+ .setComputationId(DEFAULT_COMPUTATION_ID)
+ .addHeartbeatResponsesBuilder()
+ .setCacheToken(3)
+ .setWorkToken(1)
+ .setShardingKey(DEFAULT_SHARDING_KEY)
+ .setFailed(true);
+
server.sendFailedHeartbeats(Collections.singletonList(failedHeartbeat.build()));
+
+ // Release the blocked calls.
+ BlockingFn.blocker.countDown();
+ Map<Long, Windmill.WorkItemCommitRequest> commits =
+ server.waitForAndGetCommitsWithTimeout(2,
Duration.standardSeconds((5)));
+ assertEquals(1, commits.size());
+
+ worker.stop();
+ }
+
@Test
public void testLatencyAttributionProtobufsPopulated() {
FakeClock clock = new FakeClock();
@@ -3573,7 +3620,10 @@ public class StreamingDataflowWorkerTest {
Windmill.GetDataRequest heartbeat = server.getGetDataRequests().get(2);
for (LatencyAttribution la :
- heartbeat.getRequests(0).getRequests(0).getLatencyAttributionList()) {
+ heartbeat
+ .getComputationHeartbeatRequest(0)
+ .getHeartbeatRequests(0)
+ .getLatencyAttributionList()) {
if (la.getState() == State.ACTIVE) {
assertTrue(la.getActiveLatencyBreakdownCount() > 0);
assertTrue(la.getActiveLatencyBreakdown(0).hasActiveMessageMetadata());
@@ -3768,7 +3818,7 @@ public class StreamingDataflowWorkerTest {
server
.whenGetWorkCalled()
.thenReturn(makeInput(1, TimeUnit.MILLISECONDS.toMicros(1),
DEFAULT_KEY_STRING, 1));
- // Ensure that the this work item processes.
+ // Ensure that this work item processes.
Map<Long, Windmill.WorkItemCommitRequest> result =
server.waitForAndGetCommits(1);
// Now ensure that nothing happens if a dropped commit actually completes.
droppedCommits.values().iterator().next().accept(CommitStatus.OK);
@@ -4129,7 +4179,7 @@ public class StreamingDataflowWorkerTest {
FakeClock.this.schedule(Duration.millis(unit.toMillis(delay)),
this);
}
});
- FakeClock.this.sleep(Duration.ZERO); // Execute work that has an
intial delay of zero.
+ FakeClock.this.sleep(Duration.ZERO); // Execute work that has an
initial delay of zero.
return null;
}
}
@@ -4167,6 +4217,7 @@ public class StreamingDataflowWorkerTest {
}
boolean isActiveWorkRefresh(GetDataRequest request) {
+ if (request.getComputationHeartbeatRequestCount() > 0) return true;
for (ComputationGetDataRequest computationRequest :
request.getRequestsList()) {
if
(!computationRequest.getComputationId().equals(DEFAULT_COMPUTATION_ID)) {
return false;
@@ -4203,6 +4254,21 @@ public class StreamingDataflowWorkerTest {
}
}
}
+ for (ComputationHeartbeatRequest heartbeatRequest :
+ request.getComputationHeartbeatRequestList()) {
+ for (HeartbeatRequest heartbeat :
heartbeatRequest.getHeartbeatRequestsList()) {
+ for (LatencyAttribution la : heartbeat.getLatencyAttributionList()) {
+ EnumMap<LatencyAttribution.State, Duration> durations =
+ totalDurations.computeIfAbsent(
+ heartbeat.getWorkToken(),
+ (Long workToken) ->
+ new EnumMap<LatencyAttribution.State, Duration>(
+ LatencyAttribution.State.class));
+ Duration cur = Duration.millis(la.getTotalDurationMillis());
+ durations.compute(la.getState(), (s, d) -> d == null ||
d.isShorterThan(cur) ? cur : d);
+ }
+ }
+ }
return EMPTY_DATA_RESPONDER.apply(request);
}
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
index 60ecaa3e37e..451ec649aa2 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
@@ -137,7 +137,8 @@ public class StreamingModeExecutionContextTest {
null, // synchronized processing time
stateReader,
sideInputStateFetcher,
- outputBuilder);
+ outputBuilder,
+ null);
TimerInternals timerInternals = stepContext.timerInternals();
@@ -187,7 +188,8 @@ public class StreamingModeExecutionContextTest {
null, // synchronized processing time
stateReader,
sideInputStateFetcher,
- outputBuilder);
+ outputBuilder,
+ null);
TimerInternals timerInternals = stepContext.timerInternals();
assertTrue(timerTimestamp.isBefore(timerInternals.currentProcessingTime()));
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java
index 6fa2ffe711f..b488641d1ca 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java
@@ -37,6 +37,7 @@ import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
@@ -90,9 +91,12 @@ import
org.apache.beam.runners.dataflow.worker.WorkerCustomSources.SplittableOnl
import org.apache.beam.runners.dataflow.worker.counters.CounterSet;
import org.apache.beam.runners.dataflow.worker.counters.NameContext;
import
org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.NoopProfileScope;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
import org.apache.beam.runners.dataflow.worker.testing.TestCountingSource;
import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader;
+import
org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader.NativeReaderIterator;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.coders.Coder;
@@ -613,7 +617,8 @@ public class WorkerCustomSourcesTest {
null, // synchronized processing time
null, // StateReader
null, // StateFetcher
- Windmill.WorkItemCommitRequest.newBuilder());
+ Windmill.WorkItemCommitRequest.newBuilder(),
+ null);
@SuppressWarnings({"unchecked", "rawtypes"})
NativeReader<WindowedValue<ValueWithRecordId<KV<Integer, Integer>>>>
reader =
@@ -931,4 +936,79 @@ public class WorkerCustomSourcesTest {
assertNull(progress.getRemainingParallelism());
logged.verifyWarn("remaining parallelism");
}
+
+ @Test
+ public void testFailedWorkItemsAbort() throws Exception {
+ CounterSet counterSet = new CounterSet();
+ StreamingModeExecutionStateRegistry executionStateRegistry =
+ new StreamingModeExecutionStateRegistry(null);
+ StreamingModeExecutionContext context =
+ new StreamingModeExecutionContext(
+ counterSet,
+ "computationId",
+ new ReaderCache(Duration.standardMinutes(1), Runnable::run),
+ /*stateNameMap=*/ ImmutableMap.of(),
+ new
WindmillStateCache(options.getWorkerCacheMb()).forComputation("computationId"),
+ StreamingStepMetricsContainer.createRegistry(),
+ new DataflowExecutionStateTracker(
+ ExecutionStateSampler.newForTest(),
+ executionStateRegistry.getState(
+ NameContext.forStage("stageName"), "other", null,
NoopProfileScope.NOOP),
+ counterSet,
+ PipelineOptionsFactory.create(),
+ "test-work-item-id"),
+ executionStateRegistry,
+ Long.MAX_VALUE);
+
+ options.setNumWorkers(5);
+ int maxElements = 100;
+ DataflowPipelineDebugOptions debugOptions =
options.as(DataflowPipelineDebugOptions.class);
+ debugOptions.setUnboundedReaderMaxElements(maxElements);
+
+ ByteString state = ByteString.EMPTY;
+ Windmill.WorkItem workItem =
+ Windmill.WorkItem.newBuilder()
+ .setKey(ByteString.copyFromUtf8("0000000000000001")) // key is
zero-padded index.
+ .setWorkToken(0)
+ .setCacheToken(1)
+ .setSourceState(
+ Windmill.SourceState.newBuilder().setState(state).build()) //
Source state.
+ .build();
+ Work dummyWork = Work.create(workItem, Instant::now,
Collections.emptyList(), unused -> {});
+
+ context.start(
+ "key",
+ workItem,
+ new Instant(0), // input watermark
+ null, // output watermark
+ null, // synchronized processing time
+ null, // StateReader
+ null, // StateFetcher
+ Windmill.WorkItemCommitRequest.newBuilder(),
+ dummyWork::isFailed);
+
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ NativeReader<WindowedValue<ValueWithRecordId<KV<Integer, Integer>>>>
reader =
+ (NativeReader)
+ WorkerCustomSources.create(
+ (CloudObject)
+ serializeToCloudSource(new
TestCountingSource(Integer.MAX_VALUE), options)
+ .getSpec(),
+ options,
+ context);
+
+ NativeReaderIterator<WindowedValue<ValueWithRecordId<KV<Integer,
Integer>>>> readerIterator =
+ reader.iterator();
+ int numReads = 0;
+ while ((numReads == 0) ? readerIterator.start() :
readerIterator.advance()) {
+ WindowedValue<ValueWithRecordId<KV<Integer, Integer>>> value =
readerIterator.getCurrent();
+ assertEquals(KV.of(0, numReads), value.getValue().getValue());
+ numReads++;
+ // Fail the work item after reading two elements.
+ if (numReads == 2) {
+ dummyWork.setFailed();
+ }
+ }
+ assertThat(numReads, equalTo(2));
+ }
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
index ea57f687fd9..de30fd0f8d5 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
@@ -36,7 +36,7 @@ import javax.annotation.Nullable;
import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
import
org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.ActivateWorkResult;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
-import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
@@ -239,7 +239,7 @@ public class ActiveWorkStateTest {
}
@Test
- public void testGetKeysToRefresh() {
+ public void testGetKeyHeartbeats() {
Instant refreshDeadline = Instant.now();
Work freshWork = createWork(createWorkItem(3L));
@@ -254,47 +254,51 @@ public class ActiveWorkStateTest {
activeWorkState.activateWorkForKey(shardedKey1, freshWork);
activeWorkState.activateWorkForKey(shardedKey2, refreshableWork2);
- ImmutableList<KeyedGetDataRequest> requests =
- activeWorkState.getKeysToRefresh(refreshDeadline,
DataflowExecutionStateSampler.instance());
+ ImmutableList<HeartbeatRequest> requests =
+ activeWorkState.getKeyHeartbeats(refreshDeadline,
DataflowExecutionStateSampler.instance());
- ImmutableList<GetDataRequestKeyShardingKeyAndWorkToken> expected =
+ ImmutableList<HeartbeatRequestShardingKeyWorkTokenAndCacheToken> expected =
ImmutableList.of(
- GetDataRequestKeyShardingKeyAndWorkToken.from(shardedKey1,
refreshableWork1),
- GetDataRequestKeyShardingKeyAndWorkToken.from(shardedKey2,
refreshableWork2));
+
HeartbeatRequestShardingKeyWorkTokenAndCacheToken.from(shardedKey1,
refreshableWork1),
+
HeartbeatRequestShardingKeyWorkTokenAndCacheToken.from(shardedKey2,
refreshableWork2));
- ImmutableList<GetDataRequestKeyShardingKeyAndWorkToken> actual =
+ ImmutableList<HeartbeatRequestShardingKeyWorkTokenAndCacheToken> actual =
requests.stream()
- .map(GetDataRequestKeyShardingKeyAndWorkToken::from)
+ .map(HeartbeatRequestShardingKeyWorkTokenAndCacheToken::from)
.collect(toImmutableList());
assertThat(actual).containsExactlyElementsIn(expected);
}
@AutoValue
- abstract static class GetDataRequestKeyShardingKeyAndWorkToken {
+ abstract static class HeartbeatRequestShardingKeyWorkTokenAndCacheToken {
- private static GetDataRequestKeyShardingKeyAndWorkToken create(
- ByteString key, long shardingKey, long workToken) {
- return new
AutoValue_ActiveWorkStateTest_GetDataRequestKeyShardingKeyAndWorkToken(
- key, shardingKey, workToken);
+ private static HeartbeatRequestShardingKeyWorkTokenAndCacheToken create(
+ long shardingKey, long workToken, long cacheToken) {
+ return new
AutoValue_ActiveWorkStateTest_HeartbeatRequestShardingKeyWorkTokenAndCacheToken(
+ shardingKey, workToken, cacheToken);
}
- private static GetDataRequestKeyShardingKeyAndWorkToken from(
- KeyedGetDataRequest keyedGetDataRequest) {
+ private static HeartbeatRequestShardingKeyWorkTokenAndCacheToken from(
+ HeartbeatRequest heartbeatRequest) {
return create(
- keyedGetDataRequest.getKey(),
- keyedGetDataRequest.getShardingKey(),
- keyedGetDataRequest.getWorkToken());
+ heartbeatRequest.getShardingKey(),
+ heartbeatRequest.getWorkToken(),
+ heartbeatRequest.getCacheToken());
}
- private static GetDataRequestKeyShardingKeyAndWorkToken from(ShardedKey
shardedKey, Work work) {
- return create(shardedKey.key(), shardedKey.shardingKey(),
work.getWorkItem().getWorkToken());
+ private static HeartbeatRequestShardingKeyWorkTokenAndCacheToken from(
+ ShardedKey shardedKey, Work work) {
+ return create(
+ shardedKey.shardingKey(),
+ work.getWorkItem().getWorkToken(),
+ work.getWorkItem().getCacheToken());
}
- abstract ByteString key();
-
abstract long shardingKey();
abstract long workToken();
+
+ abstract long cacheToken();
}
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
index 5f8a452a043..0ea25302767 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
@@ -44,6 +44,7 @@ import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Al
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationWorkItemMetadata;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkStreamTimingInfo;
@@ -51,6 +52,7 @@ import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkStreamTi
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataId;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
@@ -126,7 +128,7 @@ public class GrpcWindmillServerTest {
.build()
.start();
- this.client = GrpcWindmillServer.newTestInstance(name);
+ this.client = GrpcWindmillServer.newTestInstance(name, new ArrayList<>());
}
@After
@@ -744,7 +746,7 @@ public class GrpcWindmillServerTest {
while (true) {
Thread.sleep(100);
int tmpErrorsBeforeClose = errorsBeforeClose.get();
- // wait for at least 1 errors before close
+ // wait for at least 1 error before close
if (tmpErrorsBeforeClose > 0) {
break;
}
@@ -765,7 +767,7 @@ public class GrpcWindmillServerTest {
while (true) {
Thread.sleep(100);
int tmpErrorsAfterClose = errorsAfterClose.get();
- // wait for at least 1 errors after close
+ // wait for at least 1 error after close
if (tmpErrorsAfterClose > 0) {
break;
}
@@ -786,22 +788,36 @@ public class GrpcWindmillServerTest {
assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS));
}
- private List<KeyedGetDataRequest> makeHeartbeatRequest(List<String> keys) {
+ private List<KeyedGetDataRequest> makeGetDataHeartbeatRequest(List<String>
keys) {
List<KeyedGetDataRequest> result = new ArrayList<>();
for (String key : keys) {
result.add(
Windmill.KeyedGetDataRequest.newBuilder()
- .setKey(ByteString.copyFromUtf8(key))
+ .setShardingKey(key.hashCode())
.setWorkToken(0)
+ .setCacheToken(0)
+ .build());
+ }
+ return result;
+ }
+
+ private List<HeartbeatRequest> makeHeartbeatRequest(List<String> keys) {
+ List<HeartbeatRequest> result = new ArrayList<>();
+ for (String key : keys) {
+ result.add(
+ Windmill.HeartbeatRequest.newBuilder()
+ .setShardingKey(key.hashCode())
+ .setWorkToken(0)
+ .setCacheToken(0)
.build());
}
return result;
}
@Test
- public void testStreamingGetDataHeartbeats() throws Exception {
+ public void testStreamingGetDataHeartbeatsAsKeyedGetDataRequests() throws
Exception {
// This server records the heartbeats observed but doesn't respond.
- final Map<String, List<KeyedGetDataRequest>> heartbeats = new HashMap<>();
+ final Map<String, List<KeyedGetDataRequest>> getDataHeartbeats = new
HashMap<>();
serviceRegistry.addService(
new CloudWindmillServiceV1Alpha1ImplBase() {
@@ -826,16 +842,17 @@ public class GrpcWindmillServerTest {
.build()));
sawHeader = true;
} else {
- LOG.info("Received {} heartbeats",
chunk.getStateRequestCount());
+ LOG.info("Received {} getDataHeartbeats",
chunk.getStateRequestCount());
errorCollector.checkThat(
chunk.getSerializedSize(),
Matchers.lessThanOrEqualTo(STREAM_CHUNK_SIZE));
errorCollector.checkThat(chunk.getRequestIdCount(),
Matchers.is(0));
- synchronized (heartbeats) {
+ synchronized (getDataHeartbeats) {
for (ComputationGetDataRequest request :
chunk.getStateRequestList()) {
errorCollector.checkThat(request.getRequestsCount(),
Matchers.is(1));
- heartbeats.putIfAbsent(request.getComputationId(), new
ArrayList<>());
- heartbeats
+ getDataHeartbeats.putIfAbsent(
+ request.getComputationId(), new ArrayList<>());
+ getDataHeartbeats
.get(request.getComputationId())
.add(request.getRequestsList().get(0));
}
@@ -857,7 +874,6 @@ public class GrpcWindmillServerTest {
}
});
- Map<String, List<KeyedGetDataRequest>> activeMap = new HashMap<>();
List<String> computation1Keys = new ArrayList<>();
List<String> computation2Keys = new ArrayList<>();
@@ -865,22 +881,141 @@ public class GrpcWindmillServerTest {
computation1Keys.add("Computation1Key" + i);
computation2Keys.add("Computation2Key" + largeString(i * 20));
}
- activeMap.put("Computation1", makeHeartbeatRequest(computation1Keys));
- activeMap.put("Computation2", makeHeartbeatRequest(computation2Keys));
+ // We're adding HeartbeatRequests to refreshActiveWork, but expecting to
get back
+ // KeyedGetDataRequests, so make a Map of both types.
+ Map<String, List<KeyedGetDataRequest>> expectedKeyedGetDataRequests = new
HashMap<>();
+ expectedKeyedGetDataRequests.put("Computation1",
makeGetDataHeartbeatRequest(computation1Keys));
+ expectedKeyedGetDataRequests.put("Computation2",
makeGetDataHeartbeatRequest(computation2Keys));
+ Map<String, List<HeartbeatRequest>> heartbeatsToRefresh = new HashMap<>();
+ heartbeatsToRefresh.put("Computation1",
makeHeartbeatRequest(computation1Keys));
+ heartbeatsToRefresh.put("Computation2",
makeHeartbeatRequest(computation2Keys));
GetDataStream stream = client.getDataStream();
- stream.refreshActiveWork(activeMap);
+ stream.refreshActiveWork(heartbeatsToRefresh);
stream.close();
assertTrue(stream.awaitTermination(60, TimeUnit.SECONDS));
- while (true) {
+ boolean receivedAllGetDataHeartbeats = false;
+ while (!receivedAllGetDataHeartbeats) {
Thread.sleep(100);
- synchronized (heartbeats) {
- if (heartbeats.size() != activeMap.size()) {
+ synchronized (getDataHeartbeats) {
+ if (getDataHeartbeats.size() != expectedKeyedGetDataRequests.size()) {
continue;
}
- assertEquals(heartbeats, activeMap);
- break;
+ assertEquals(expectedKeyedGetDataRequests, getDataHeartbeats);
+ receivedAllGetDataHeartbeats = true;
+ }
+ }
+ }
+
+ @Test
+ public void testStreamingGetDataHeartbeatsAsHeartbeatRequests() throws
Exception {
+ // Create a client and server different from the one in SetUp so we can
add an experiment to the
+ // options passed in.
+ this.server =
+ InProcessServerBuilder.forName("TestServer")
+ .fallbackHandlerRegistry(serviceRegistry)
+ .executor(Executors.newFixedThreadPool(1))
+ .build()
+ .start();
+ this.client =
+ GrpcWindmillServer.newTestInstance(
+ "TestServer",
+
Collections.singletonList("streaming_engine_send_new_heartbeat_requests"));
+ // This server records the heartbeats observed but doesn't respond.
+ final List<ComputationHeartbeatRequest> receivedHeartbeats = new
ArrayList<>();
+
+ serviceRegistry.addService(
+ new CloudWindmillServiceV1Alpha1ImplBase() {
+ @Override
+ public StreamObserver<StreamingGetDataRequest> getDataStream(
+ StreamObserver<StreamingGetDataResponse> responseObserver) {
+ return new StreamObserver<StreamingGetDataRequest>() {
+ boolean sawHeader = false;
+
+ @Override
+ public void onNext(StreamingGetDataRequest chunk) {
+ try {
+ if (!sawHeader) {
+ LOG.info("Received header");
+ errorCollector.checkThat(
+ chunk.getHeader(),
+ Matchers.equalTo(
+ JobHeader.newBuilder()
+ .setJobId("job")
+ .setProjectId("project")
+ .setWorkerId("worker")
+ .build()));
+ sawHeader = true;
+ } else {
+ LOG.info(
+ "Received {} computationHeartbeatRequests",
+ chunk.getComputationHeartbeatRequestCount());
+ errorCollector.checkThat(
+ chunk.getSerializedSize(),
Matchers.lessThanOrEqualTo(STREAM_CHUNK_SIZE));
+ errorCollector.checkThat(chunk.getRequestIdCount(),
Matchers.is(0));
+
+ synchronized (receivedHeartbeats) {
+
receivedHeartbeats.addAll(chunk.getComputationHeartbeatRequestList());
+ }
+ }
+ } catch (Exception e) {
+ errorCollector.addError(e);
+ }
+ }
+
+ @Override
+ public void onError(Throwable throwable) {}
+
+ @Override
+ public void onCompleted() {
+ responseObserver.onCompleted();
+ }
+ };
+ }
+ });
+
+ List<String> computation1Keys = new ArrayList<>();
+ List<String> computation2Keys = new ArrayList<>();
+
+ // When sending heartbeats as HeartbeatRequest protos, all keys for the
same computation should
+ // be batched into the same ComputationHeartbeatRequest. Compare to the
KeyedGetDataRequest
+ // version in the test above, which only sends one key per
ComputationGetDataRequest.
+ List<ComputationHeartbeatRequest> expectedHeartbeats = new ArrayList<>();
+ ComputationHeartbeatRequest.Builder comp1Builder =
+
ComputationHeartbeatRequest.newBuilder().setComputationId("Computation1");
+ ComputationHeartbeatRequest.Builder comp2Builder =
+
ComputationHeartbeatRequest.newBuilder().setComputationId("Computation2");
+ for (int i = 0; i < 100; ++i) {
+ String computation1Key = "Computation1Key" + i;
+ computation1Keys.add(computation1Key);
+ comp1Builder.addHeartbeatRequests(
+
makeHeartbeatRequest(Collections.singletonList(computation1Key)).get(0));
+ String computation2Key = "Computation2Key" + largeString(i * 20);
+ computation2Keys.add(computation2Key);
+ comp2Builder.addHeartbeatRequests(
+
makeHeartbeatRequest(Collections.singletonList(computation2Key)).get(0));
+ }
+ expectedHeartbeats.add(comp1Builder.build());
+ expectedHeartbeats.add(comp2Builder.build());
+ Map<String, List<HeartbeatRequest>> heartbeatRequestMap = new HashMap<>();
+ heartbeatRequestMap.put("Computation1",
makeHeartbeatRequest(computation1Keys));
+ heartbeatRequestMap.put("Computation2",
makeHeartbeatRequest(computation2Keys));
+
+ GetDataStream stream = client.getDataStream();
+ stream.refreshActiveWork(heartbeatRequestMap);
+ stream.close();
+ assertTrue(stream.awaitTermination(60, TimeUnit.SECONDS));
+
+ boolean receivedAllHeartbeatRequests = false;
+ while (!receivedAllHeartbeatRequests) {
+ Thread.sleep(100);
+ synchronized (receivedHeartbeats) {
+ if (receivedHeartbeats.size() != expectedHeartbeats.size()) {
+ continue;
+ }
+ assertEquals(expectedHeartbeats, receivedHeartbeats);
+ receivedAllHeartbeatRequests = true;
}
}
}
@@ -888,7 +1023,7 @@ public class GrpcWindmillServerTest {
@Test
public void testThrottleSignal() throws Exception {
// This server responds with work items until the throttleMessage limit is
hit at which point it
- // returns RESROUCE_EXHAUSTED errors for throttleTime msecs after which it
resumes sending
+ // returns RESOURCE_EXHAUSTED errors for throttleTime msecs after which it
resumes sending
// work items.
final int throttleTime = 2000;
final int throttleMessage = 15;
diff --git
a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
index 6aaeb57001e..0c824ca301b 100644
---
a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
+++
b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
@@ -477,9 +477,10 @@ message GetWorkResponse {
// GetData
message KeyedGetDataRequest {
- required bytes key = 1;
+ optional bytes key = 1;
required fixed64 work_token = 2;
optional fixed64 sharding_key = 6;
+ optional fixed64 cache_token = 11;
repeated TagValue values_to_fetch = 3;
repeated TagValuePrefixRequest tag_value_prefixes_to_fetch = 10;
repeated TagBag bags_to_fetch = 8;
@@ -507,6 +508,8 @@ message GetDataRequest {
// Assigned worker id for the instance.
optional string worker_id = 6;
+ // SE only. Will only be set by a compatible client.
+ repeated ComputationHeartbeatRequest computation_heartbeat_request = 7;
// DEPRECATED
repeated GlobalDataId global_data_to_fetch = 2;
}
@@ -536,6 +539,44 @@ message ComputationGetDataResponse {
message GetDataResponse {
repeated ComputationGetDataResponse data = 1;
repeated GlobalData global_data = 2;
+ // Only set if ComputationHeartbeatRequest was sent, prior versions do not
+ // expect a response for heartbeats. SE only.
+ repeated ComputationHeartbeatResponse computation_heartbeat_response = 3;
+}
+
+// Heartbeats
+//
+// Heartbeats are sent over the GetData stream in Streaming Engine and
+// indicates the work item that the user worker has previously received from
+// GetWork but not yet committed with CommitWork.
+// Note that implicit heartbeats not expecting a response may be sent as
+// special KeyedGetDataRequests see function KeyedGetDataRequestIsHeartbeat.
+// SE only.
+message HeartbeatRequest {
+ optional fixed64 sharding_key = 1;
+ optional fixed64 work_token = 2;
+ optional fixed64 cache_token = 3;
+ repeated LatencyAttribution latency_attribution = 4;
+}
+
+// Responses for heartbeat requests, indicating which work is no longer valid
+// on the windmill worker and may be dropped/cancelled in the client.
+// SE only.
+message HeartbeatResponse {
+ optional fixed64 sharding_key = 1;
+ optional fixed64 work_token = 2;
+ optional fixed64 cache_token = 3;
+ optional bool failed = 4;
+}
+
+message ComputationHeartbeatRequest {
+ optional string computation_id = 1;
+ repeated HeartbeatRequest heartbeat_requests = 2;
+}
+
+message ComputationHeartbeatResponse {
+ optional string computation_id = 1;
+ repeated HeartbeatResponse heartbeat_responses = 2;
}
// CommitWork
@@ -772,6 +813,8 @@ message StreamingGetDataRequest {
repeated fixed64 request_id = 1;
repeated GlobalDataRequest global_data_request = 3;
repeated ComputationGetDataRequest state_request = 4;
+ // Will only be set by compatible client
+ repeated ComputationHeartbeatRequest computation_heartbeat_request = 5;
}
message StreamingGetDataResponse {
@@ -784,6 +827,12 @@ message StreamingGetDataResponse {
repeated bytes serialized_response = 2;
// Remaining bytes field applies only to the last serialized_response
optional int64 remaining_bytes_for_response = 3;
+
+ // Only set if ComputationHeartbeatRequest was sent, prior versions do not
+ // expect a response for heartbeats.
+ repeated ComputationHeartbeatResponse computation_heartbeat_response = 5;
+
+ reserved 4;
}
message StreamingCommitWorkRequest {