This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 5e7edc45598 Heartbeats (#29963)
5e7edc45598 is described below

commit 5e7edc45598b6438761386856e96a66487704b69
Author: Andrew Crites <[email protected]>
AuthorDate: Tue Jan 23 01:28:20 2024 -0800

    Heartbeats (#29963)
    
    * Adds sending new HeartbeatRequest protos to StreamingDataflowWorker. If 
any HeartbeatResponses are sent from Windmill containing failed work items, 
aborts processing those work items as soon as possible.
    
    * Adds sending new HeartbeatRequest protos when using streaming RPCs
(streaming engine). Also adds a test.
    
    * Adds new test for custom source reader exiting early for failed work. 
Adds special exception for handling failed work.
    
    * Removes some extra cache invalidations and unneeded log statements.
    
    * Adds streaming_engine prefix to experiment enabling heartbeats and
changes the exception in state reader to be WorkItemFailedException.
    
    * Adds check that heartbeat response sets failed before failing work.
    
    * Adds ability to plumb experiments to test server for 
GrpcWindmillServerTest so we can test the new style heartbeats.
    
    * Changes StreamingDataflowWorkerTest to look for latency attribution in 
new-style heartbeat requests since that's what FakeWindmillServer returns now.
---
 .../worker/MetricTrackingWindmillServerStub.java   |  32 ++--
 .../beam/runners/dataflow/worker/PubsubReader.java |   8 +
 .../dataflow/worker/StreamingDataflowWorker.java   |  51 +++++-
 .../worker/StreamingModeExecutionContext.java      |  14 +-
 .../dataflow/worker/UngroupedWindmillReader.java   |   8 +
 .../worker/WorkItemCancelledException.java         |  39 +++++
 .../dataflow/worker/WorkerCustomSources.java       |   3 +-
 .../dataflow/worker/streaming/ActiveWorkState.java |  66 +++++++-
 .../worker/streaming/ComputationState.java         |  14 +-
 .../runners/dataflow/worker/streaming/Work.java    |  11 ++
 .../worker/windmill/WindmillServerStub.java        |   6 +
 .../worker/windmill/client/WindmillStream.java     |   5 +-
 .../windmill/client/grpc/GrpcGetDataStream.java    | 107 ++++++++++---
 .../windmill/client/grpc/GrpcWindmillServer.java   |  35 +++-
 .../client/grpc/GrpcWindmillStreamFactory.java     |  16 +-
 .../worker/windmill/state/WindmillStateCache.java  |   5 +
 .../worker/windmill/state/WindmillStateReader.java |  18 ++-
 .../dataflow/worker/FakeWindmillServer.java        |  47 +++++-
 .../worker/StreamingDataflowWorkerTest.java        |  76 ++++++++-
 .../worker/StreamingModeExecutionContextTest.java  |   6 +-
 .../dataflow/worker/WorkerCustomSourcesTest.java   |  82 +++++++++-
 .../worker/streaming/ActiveWorkStateTest.java      |  50 +++---
 .../client/grpc/GrpcWindmillServerTest.java        | 177 ++++++++++++++++++---
 .../worker/windmill/src/main/proto/windmill.proto  |  51 +++++-
 24 files changed, 801 insertions(+), 126 deletions(-)

diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
index 0e929249b3a..800504f4451 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
@@ -27,7 +27,7 @@ import java.util.concurrent.atomic.AtomicInteger;
 import javax.annotation.concurrent.GuardedBy;
 import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
-import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
 import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
 import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool;
@@ -239,25 +239,37 @@ public class MetricTrackingWindmillServerStub {
   }
 
   /** Tells windmill processing is ongoing for the given keys. */
-  public void refreshActiveWork(Map<String, List<KeyedGetDataRequest>> active) 
{
-    activeHeartbeats.set(active.size());
+  public void refreshActiveWork(Map<String, List<HeartbeatRequest>> 
heartbeats) {
+    activeHeartbeats.set(heartbeats.size());
     try {
       if (useStreamingRequests) {
         // With streaming requests, always send the request even when it is 
empty, to ensure that
         // we trigger health checks for the stream even when it is idle.
         GetDataStream stream = streamPool.getStream();
         try {
-          stream.refreshActiveWork(active);
+          stream.refreshActiveWork(heartbeats);
         } finally {
           streamPool.releaseStream(stream);
         }
-      } else if (!active.isEmpty()) {
+      } else if (!heartbeats.isEmpty()) {
+        // This code path is only used by appliance which sends heartbeats 
(used to refresh active
+        // work) as KeyedGetDataRequests. So we must translate the 
HeartbeatRequest to a
+        // KeyedGetDataRequest here regardless of the value of 
sendKeyedGetDataRequests.
         Windmill.GetDataRequest.Builder builder = 
Windmill.GetDataRequest.newBuilder();
-        for (Map.Entry<String, List<KeyedGetDataRequest>> entry : 
active.entrySet()) {
-          builder.addRequests(
-              Windmill.ComputationGetDataRequest.newBuilder()
-                  .setComputationId(entry.getKey())
-                  .addAllRequests(entry.getValue()));
+        for (Map.Entry<String, List<HeartbeatRequest>> entry : 
heartbeats.entrySet()) {
+          Windmill.ComputationGetDataRequest.Builder perComputationBuilder =
+              Windmill.ComputationGetDataRequest.newBuilder();
+          perComputationBuilder.setComputationId(entry.getKey());
+          for (HeartbeatRequest request : entry.getValue()) {
+            perComputationBuilder.addRequests(
+                Windmill.KeyedGetDataRequest.newBuilder()
+                    .setShardingKey(request.getShardingKey())
+                    .setWorkToken(request.getWorkToken())
+                    .setCacheToken(request.getCacheToken())
+                    
.addAllLatencyAttribution(request.getLatencyAttributionList())
+                    .build());
+          }
+          builder.addRequests(perComputationBuilder.build());
         }
         server.getData(builder.build());
       }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java
index d0931e02cc8..be0bccec026 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java
@@ -112,6 +112,14 @@ class PubsubReader<T> extends 
NativeReader<WindowedValue<T>> {
       super(work);
     }
 
+    @Override
+    public boolean advance() throws IOException {
+      if (context.workIsFailed()) {
+        return false;
+      }
+      return super.advance();
+    }
+
     @Override
     protected WindowedValue<T> decodeMessage(Windmill.Message message) throws 
IOException {
       T value;
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index f2d7c02729c..a95e7828881 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -85,6 +85,7 @@ import 
org.apache.beam.runners.dataflow.worker.status.DebugCapture.Capturable;
 import 
org.apache.beam.runners.dataflow.worker.status.LastExceptionDataProvider;
 import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
 import org.apache.beam.runners.dataflow.worker.status.WorkerStatusPages;
+import 
org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.FailedTokens;
 import org.apache.beam.runners.dataflow.worker.streaming.Commit;
 import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
 import org.apache.beam.runners.dataflow.worker.streaming.ExecutionState;
@@ -422,6 +423,7 @@ public class StreamingDataflowWorker {
 
     this.publishCounters = publishCounters;
     this.windmillServer = options.getWindmillServerStub();
+    
this.windmillServer.setProcessHeartbeatResponses(this::handleHeartbeatResponses);
     this.metricTrackingWindmillServer =
         new MetricTrackingWindmillServerStub(windmillServer, memoryMonitor, 
windmillServiceEnabled);
     this.metricTrackingWindmillServer.start();
@@ -982,6 +984,9 @@ public class StreamingDataflowWorker {
     String counterName = "dataflow_source_bytes_processed-" + 
mapTask.getSystemName();
 
     try {
+      if (work.isFailed()) {
+        throw new WorkItemCancelledException(workItem.getShardingKey());
+      }
       executionState = computationState.getExecutionStateQueue().poll();
       if (executionState == null) {
         MutableNetwork<Node, Edge> mapTaskNetwork = 
mapTaskToNetwork.apply(mapTask);
@@ -1098,7 +1103,8 @@ public class StreamingDataflowWorker {
                     work.setState(State.PROCESSING);
                   }
                 };
-              });
+              },
+              work::isFailed);
       SideInputStateFetcher localSideInputStateFetcher = 
sideInputStateFetcher.byteTrackingView();
 
       // If the read output KVs, then we can decode Windmill's byte key into a 
userland
@@ -1136,12 +1142,16 @@ public class StreamingDataflowWorker {
               synchronizedProcessingTime,
               stateReader,
               localSideInputStateFetcher,
-              outputBuilder);
+              outputBuilder,
+              work::isFailed);
 
       // Blocks while executing work.
       executionState.workExecutor().execute();
 
-      // Reports source bytes processed to workitemcommitrequest if available.
+      if (work.isFailed()) {
+        throw new WorkItemCancelledException(workItem.getShardingKey());
+      }
+      // Reports source bytes processed to WorkItemCommitRequest if available.
       try {
         long sourceBytesProcessed = 0;
         HashMap<String, ElementCounter> counters =
@@ -1234,6 +1244,12 @@ public class StreamingDataflowWorker {
                 + "Work will not be retried locally.",
             computationId,
             key.toStringUtf8());
+      } else if (WorkItemCancelledException.isWorkItemCancelledException(t)) {
+        LOG.debug(
+            "Execution of work for computation '{}' on key '{}' failed. "
+                + "Work will not be retried locally.",
+            computationId,
+            workItem.getShardingKey());
       } else {
         LastExceptionDataProvider.reportException(t);
         LOG.debug("Failed work: {}", work);
@@ -1369,6 +1385,10 @@ public class StreamingDataflowWorker {
   // Adds the commit to the commitStream if it fits, returning true iff it is 
consumed.
   private boolean addCommitToStream(Commit commit, CommitWorkStream 
commitStream) {
     Preconditions.checkNotNull(commit);
+    // Drop commits for failed work. Such commits will be dropped by Windmill 
anyway.
+    if (commit.work().isFailed()) {
+      return true;
+    }
     final ComputationState state = commit.computationState();
     final Windmill.WorkItemCommitRequest request = commit.request();
     final int size = commit.getSize();
@@ -1896,6 +1916,25 @@ public class StreamingDataflowWorker {
     }
   }
 
+  public void 
handleHeartbeatResponses(List<Windmill.ComputationHeartbeatResponse> responses) 
{
+    for (Windmill.ComputationHeartbeatResponse computationHeartbeatResponse : 
responses) {
+      // Maps sharding key to (work token, cache token) for work that should 
be marked failed.
+      Map<Long, List<FailedTokens>> failedWork = new HashMap<>();
+      for (Windmill.HeartbeatResponse heartbeatResponse :
+          computationHeartbeatResponse.getHeartbeatResponsesList()) {
+        if (heartbeatResponse.getFailed()) {
+          failedWork
+              .computeIfAbsent(heartbeatResponse.getShardingKey(), key -> new 
ArrayList<>())
+              .add(
+                  new FailedTokens(
+                      heartbeatResponse.getWorkToken(), 
heartbeatResponse.getCacheToken()));
+        }
+      }
+      ComputationState state = 
computationMap.get(computationHeartbeatResponse.getComputationId());
+      if (state != null) state.failWork(failedWork);
+    }
+  }
+
   /**
    * Sends a GetData request to Windmill for all sufficiently old active work.
    *
@@ -1904,15 +1943,15 @@ public class StreamingDataflowWorker {
    * StreamingDataflowWorkerOptions#getActiveWorkRefreshPeriodMillis}.
    */
   private void refreshActiveWork() {
-    Map<String, List<Windmill.KeyedGetDataRequest>> active = new HashMap<>();
+    Map<String, List<Windmill.HeartbeatRequest>> heartbeats = new HashMap<>();
     Instant refreshDeadline =
         
clock.get().minus(Duration.millis(options.getActiveWorkRefreshPeriodMillis()));
 
     for (Map.Entry<String, ComputationState> entry : 
computationMap.entrySet()) {
-      active.put(entry.getKey(), 
entry.getValue().getKeysToRefresh(refreshDeadline, sampler));
+      heartbeats.put(entry.getKey(), 
entry.getValue().getKeyHeartbeats(refreshDeadline, sampler));
     }
 
-    metricTrackingWindmillServer.refreshActiveWork(active);
+    metricTrackingWindmillServer.refreshActiveWork(heartbeats);
   }
 
   private void invalidateStuckCommits() {
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
index d630601c28a..83cf49112a8 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
@@ -112,6 +112,7 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
   private Windmill.WorkItemCommitRequest.Builder outputBuilder;
   private UnboundedSource.UnboundedReader<?> activeReader;
   private volatile long backlogBytes;
+  private Supplier<Boolean> workIsFailed;
 
   public StreamingModeExecutionContext(
       CounterFactory counterFactory,
@@ -135,6 +136,7 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
     this.stateNameMap = ImmutableMap.copyOf(stateNameMap);
     this.stateCache = stateCache;
     this.backlogBytes = UnboundedSource.UnboundedReader.BACKLOG_UNKNOWN;
+    this.workIsFailed = () -> Boolean.FALSE;
   }
 
   @VisibleForTesting
@@ -142,6 +144,10 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
     return backlogBytes;
   }
 
+  public boolean workIsFailed() {
+    return workIsFailed.get();
+  }
+
   public void start(
       @Nullable Object key,
       Windmill.WorkItem work,
@@ -150,9 +156,11 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
       @Nullable Instant synchronizedProcessingTime,
       WindmillStateReader stateReader,
       SideInputStateFetcher sideInputStateFetcher,
-      Windmill.WorkItemCommitRequest.Builder outputBuilder) {
+      Windmill.WorkItemCommitRequest.Builder outputBuilder,
+      @Nullable Supplier<Boolean> workFailed) {
     this.key = key;
     this.work = work;
+    this.workIsFailed = (workFailed != null) ? workFailed : () -> 
Boolean.FALSE;
     this.computationKey =
         WindmillComputationKey.create(computationId, work.getKey(), 
work.getShardingKey());
     this.sideInputStateFetcher = sideInputStateFetcher;
@@ -429,7 +437,7 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
 
   /**
    * Execution states in Streaming are shared between multiple map-task 
executors. Thus this class
-   * needs to be thread safe for multiple writers. A single stage could have 
have multiple executors
+   * needs to be thread safe for multiple writers. A single stage could have 
multiple executors
    * running concurrently.
    */
   public static class StreamingModeExecutionState
@@ -670,7 +678,7 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
     private NavigableSet<TimerData> 
modifiedUserSynchronizedProcessingTimersOrdered = null;
     // A list of timer keys that were modified by user processing earlier in 
this bundle. This
     // serves a tombstone, so
-    // that we know not to fire any bundle tiemrs that were moddified.
+    // that we know not to fire any bundle timers that were modified.
     private Table<String, StateNamespace, TimerData> modifiedUserTimerKeys = 
null;
 
     public StepContext(DataflowOperationContext operationContext) {
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java
index e4e56a96c15..4aac93ceb3f 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java
@@ -99,6 +99,14 @@ class UngroupedWindmillReader<T> extends 
NativeReader<WindowedValue<T>> {
       super(work);
     }
 
+    @Override
+    public boolean advance() throws IOException {
+      if (context.workIsFailed()) {
+        return false;
+      }
+      return super.advance();
+    }
+
     @Override
     protected WindowedValue<T> decodeMessage(Windmill.Message message) throws 
IOException {
       Instant timestampMillis =
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java
new file mode 100644
index 00000000000..934977fe098
--- /dev/null
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker;
+
+/** Indicates that the work item was cancelled and should not be retried. */
+@SuppressWarnings({
+  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
+})
+public class WorkItemCancelledException extends RuntimeException {
+  public WorkItemCancelledException(long sharding_key) {
+    super("Work item cancelled for key " + sharding_key);
+  }
+
+  /** Returns whether an exception was caused by a {@link 
WorkItemCancelledException}. */
+  public static boolean isWorkItemCancelledException(Throwable t) {
+    while (t != null) {
+      if (t instanceof WorkItemCancelledException) {
+        return true;
+      }
+      t = t.getCause();
+    }
+    return false;
+  }
+}
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java
index a9050236efc..2dc3494af5e 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java
@@ -836,7 +836,8 @@ public class WorkerCustomSources {
       while (true) {
         if (elemsRead >= maxElems
             || Instant.now().isAfter(endTime)
-            || context.isSinkFullHintSet()) {
+            || context.isSinkFullHintSet()
+            || context.workIsFailed()) {
           return false;
         }
         try {
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
index 16266de9d47..54942dfeee1 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
@@ -23,6 +23,7 @@ import java.io.PrintWriter;
 import java.util.ArrayDeque;
 import java.util.Deque;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Optional;
@@ -34,8 +35,10 @@ import javax.annotation.concurrent.GuardedBy;
 import javax.annotation.concurrent.ThreadSafe;
 import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
-import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
 import 
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
+import org.apache.beam.sdk.annotations.Internal;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
@@ -50,7 +53,8 @@ import org.slf4j.LoggerFactory;
  * activate, queue, and complete {@link Work} (including invalidating stuck 
{@link Work}).
  */
 @ThreadSafe
-final class ActiveWorkState {
+@Internal
+public final class ActiveWorkState {
   private static final Logger LOG = 
LoggerFactory.getLogger(ActiveWorkState.class);
 
   /* The max number of keys in COMMITTING or COMMIT_QUEUED status to be 
shown.*/
@@ -120,6 +124,50 @@ final class ActiveWorkState {
     return ActivateWorkResult.QUEUED;
   }
 
+  public static final class FailedTokens {
+    public long workToken;
+    public long cacheToken;
+
+    public FailedTokens(long workToken, long cacheToken) {
+      this.workToken = workToken;
+      this.cacheToken = cacheToken;
+    }
+  }
+
+  /**
+   * Fails any active work matching an element of the input Map.
+   *
+   * @param failedWork a map from sharding_key to tokens for the corresponding 
work.
+   */
+  synchronized void failWorkForKey(Map<Long, List<FailedTokens>> failedWork) {
+    // Note we can't construct a ShardedKey and look it up in activeWork 
directly since
+    // HeartbeatResponse doesn't include the user key.
+    for (Entry<ShardedKey, Deque<Work>> entry : activeWork.entrySet()) {
+      List<FailedTokens> failedTokens = 
failedWork.get(entry.getKey().shardingKey());
+      if (failedTokens == null) continue;
+      for (FailedTokens failedToken : failedTokens) {
+        for (Work queuedWork : entry.getValue()) {
+          WorkItem workItem = queuedWork.getWorkItem();
+          if (workItem.getWorkToken() == failedToken.workToken
+              && workItem.getCacheToken() == failedToken.cacheToken) {
+            LOG.debug(
+                "Failing work "
+                    + computationStateCache.getComputation()
+                    + " "
+                    + entry.getKey().shardingKey()
+                    + " "
+                    + failedToken.workToken
+                    + " "
+                    + failedToken.cacheToken
+                    + ". The work will be retried and is not lost.");
+            queuedWork.setFailed();
+            break;
+          }
+        }
+      }
+    }
+  }
+
   /**
    * Removes the complete work from the {@link Queue<Work>}. The {@link Work} 
is marked as completed
    * if its workToken matches the one that is passed in. Returns the next 
{@link Work} in the {@link
@@ -211,14 +259,14 @@ final class ActiveWorkState {
     return stuckCommits.build();
   }
 
-  synchronized ImmutableList<KeyedGetDataRequest> getKeysToRefresh(
+  synchronized ImmutableList<HeartbeatRequest> getKeyHeartbeats(
       Instant refreshDeadline, DataflowExecutionStateSampler sampler) {
     return activeWork.entrySet().stream()
-        .flatMap(entry -> toKeyedGetDataRequestStream(entry, refreshDeadline, 
sampler))
+        .flatMap(entry -> toHeartbeatRequestStream(entry, refreshDeadline, 
sampler))
         .collect(toImmutableList());
   }
 
-  private static Stream<KeyedGetDataRequest> toKeyedGetDataRequestStream(
+  private static Stream<HeartbeatRequest> toHeartbeatRequestStream(
       Entry<ShardedKey, Deque<Work>> shardedKeyAndWorkQueue,
       Instant refreshDeadline,
       DataflowExecutionStateSampler sampler) {
@@ -227,12 +275,14 @@ final class ActiveWorkState {
 
     return workQueue.stream()
         .filter(work -> work.getStartTime().isBefore(refreshDeadline))
+        // Don't send heartbeats for queued work we already know is failed.
+        .filter(work -> !work.isFailed())
         .map(
             work ->
-                Windmill.KeyedGetDataRequest.newBuilder()
-                    .setKey(shardedKey.key())
+                Windmill.HeartbeatRequest.newBuilder()
                     .setShardingKey(shardedKey.shardingKey())
                     .setWorkToken(work.getWorkItem().getWorkToken())
+                    .setCacheToken(work.getWorkItem().getCacheToken())
                     .addAllLatencyAttribution(
                         work.getLatencyAttributions(true, 
work.getLatencyTrackingId(), sampler))
                     .build());
@@ -250,7 +300,7 @@ final class ActiveWorkState {
     for (Map.Entry<ShardedKey, Deque<Work>> entry : activeWork.entrySet()) {
       Queue<Work> workQueue = Preconditions.checkNotNull(entry.getValue());
       Work activeWork = Preconditions.checkNotNull(workQueue.peek());
-      Windmill.WorkItem workItem = activeWork.getWorkItem();
+      WorkItem workItem = activeWork.getWorkItem();
       if (activeWork.isCommitPending()) {
         if (++commitsPendingCount >= MAX_PRINTABLE_COMMIT_PENDING_KEYS) {
           continue;
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
index 4ac1d8bc9fa..8207a6ef2f0 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
@@ -19,12 +19,14 @@ package org.apache.beam.runners.dataflow.worker.streaming;
 
 import com.google.api.services.dataflow.model.MapTask;
 import java.io.PrintWriter;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import javax.annotation.Nullable;
 import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
+import 
org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.FailedTokens;
 import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
-import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
@@ -98,6 +100,10 @@ public class ComputationState implements AutoCloseable {
     }
   }
 
+  public void failWork(Map<Long, List<FailedTokens>> failedWork) {
+    activeWorkState.failWorkForKey(failedWork);
+  }
+
   /**
    * Marks the work for the given shardedKey as complete. Schedules queued 
work for the key if any.
    */
@@ -120,10 +126,10 @@ public class ComputationState implements AutoCloseable {
     executor.forceExecute(work, work.getWorkItem().getSerializedSize());
   }
 
-  /** Adds any work started before the refreshDeadline to the GetDataRequest 
builder. */
-  public ImmutableList<KeyedGetDataRequest> getKeysToRefresh(
+  /** Gets HeartbeatRequests for any work started before refreshDeadline. */
+  public ImmutableList<HeartbeatRequest> getKeyHeartbeats(
       Instant refreshDeadline, DataflowExecutionStateSampler sampler) {
-    return activeWorkState.getKeysToRefresh(refreshDeadline, sampler);
+    return activeWorkState.getKeyHeartbeats(refreshDeadline, sampler);
   }
 
   public void printActiveWork(PrintWriter writer) {
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
index 3a77a8322b4..69f2a0dcee7 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
@@ -50,6 +50,8 @@ public class Work implements Runnable {
   private final Consumer<Work> processWorkFn;
   private TimedState currentState;
 
+  private boolean isFailed;
+
   private Work(Windmill.WorkItem workItem, Supplier<Instant> clock, 
Consumer<Work> processWorkFn) {
     this.workItem = workItem;
     this.clock = clock;
@@ -57,6 +59,7 @@ public class Work implements Runnable {
     this.startTime = clock.get();
     this.totalDurationPerState = new 
EnumMap<>(Windmill.LatencyAttribution.State.class);
     this.currentState = TimedState.initialState(startTime);
+    this.isFailed = false;
   }
 
   public static Work create(
@@ -95,6 +98,10 @@ public class Work implements Runnable {
     this.currentState = TimedState.create(state, now);
   }
 
+  public void setFailed() {
+    this.isFailed = true;
+  }
+
   public boolean isCommitPending() {
     return currentState.isCommitPending();
   }
@@ -180,6 +187,10 @@ public class Work implements Runnable {
     return builder;
   }
 
+  public boolean isFailed() {
+    return isFailed;
+  }
+
   boolean isStuckCommittingAt(Instant stuckCommitDeadline) {
     return currentState.state() == Work.State.COMMITTING
         && currentState.startTime().isBefore(stuckCommitDeadline);
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
index c327e68d7e9..25581bee208 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
@@ -19,8 +19,11 @@ package org.apache.beam.runners.dataflow.worker.windmill;
 
 import java.io.IOException;
 import java.io.PrintWriter;
+import java.util.List;
 import java.util.Set;
+import java.util.function.Consumer;
 import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
 import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
 import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
 import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
@@ -79,6 +82,9 @@ public abstract class WindmillServerStub implements 
StatusDataProvider {
   @Override
   public void appendSummaryHtml(PrintWriter writer) {}
 
+  public void setProcessHeartbeatResponses(
+      Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses) 
{}
+
   /** Generic Exception type for implementors to use to represent errors while 
making RPCs. */
   public static final class RpcException extends RuntimeException {
     public RpcException(Throwable cause) {
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
index fa1f797a191..7c22f4fb576 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
@@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit;
 import java.util.function.Consumer;
 import javax.annotation.concurrent.ThreadSafe;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
 import org.joda.time.Instant;
 
@@ -59,7 +60,9 @@ public interface WindmillStream {
     Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request);
 
     /** Tells windmill processing is ongoing for the given keys. */
-    void refreshActiveWork(Map<String, List<Windmill.KeyedGetDataRequest>> 
active);
+    void refreshActiveWork(Map<String, List<HeartbeatRequest>> heartbeats);
+
+    void onHeartbeatResponse(List<Windmill.ComputationHeartbeatResponse> 
responses);
   }
 
   /** Interface for streaming CommitWorkRequests to Windmill. */
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
index a04a961ca9c..b6600e04a09 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
@@ -32,10 +32,15 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentLinkedDeque;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Consumer;
 import java.util.function.Function;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatRequest;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
@@ -64,6 +69,10 @@ public final class GrpcGetDataStream
   private final ThrottleTimer getDataThrottleTimer;
   private final JobHeader jobHeader;
   private final int streamingRpcBatchLimit;
+  // If true, then active work refreshes will be sent as KeyedGetDataRequests. 
Otherwise, use the
+  // newer ComputationHeartbeatRequests.
+  private final boolean sendKeyedGetDataRequests;
+  private Consumer<List<ComputationHeartbeatResponse>> 
processHeartbeatResponses;
 
   private GrpcGetDataStream(
       Function<StreamObserver<StreamingGetDataResponse>, 
StreamObserver<StreamingGetDataRequest>>
@@ -75,7 +84,9 @@ public final class GrpcGetDataStream
       ThrottleTimer getDataThrottleTimer,
       JobHeader jobHeader,
       AtomicLong idGenerator,
-      int streamingRpcBatchLimit) {
+      int streamingRpcBatchLimit,
+      boolean sendKeyedGetDataRequests,
+      Consumer<List<Windmill.ComputationHeartbeatResponse>> 
processHeartbeatResponses) {
     super(
         startGetDataRpcFn, backoff, streamObserverFactory, streamRegistry, 
logEveryNStreamFailures);
     this.idGenerator = idGenerator;
@@ -84,6 +95,8 @@ public final class GrpcGetDataStream
     this.streamingRpcBatchLimit = streamingRpcBatchLimit;
     this.batches = new ConcurrentLinkedDeque<>();
     this.pending = new ConcurrentHashMap<>();
+    this.sendKeyedGetDataRequests = sendKeyedGetDataRequests;
+    this.processHeartbeatResponses = processHeartbeatResponses;
   }
 
   public static GrpcGetDataStream create(
@@ -96,7 +109,9 @@ public final class GrpcGetDataStream
       ThrottleTimer getDataThrottleTimer,
       JobHeader jobHeader,
       AtomicLong idGenerator,
-      int streamingRpcBatchLimit) {
+      int streamingRpcBatchLimit,
+      boolean sendKeyedGetDataRequests,
+      Consumer<List<Windmill.ComputationHeartbeatResponse>> 
processHeartbeatResponses) {
     GrpcGetDataStream getDataStream =
         new GrpcGetDataStream(
             startGetDataRpcFn,
@@ -107,7 +122,9 @@ public final class GrpcGetDataStream
             getDataThrottleTimer,
             jobHeader,
             idGenerator,
-            streamingRpcBatchLimit);
+            streamingRpcBatchLimit,
+            sendKeyedGetDataRequests,
+            processHeartbeatResponses);
     getDataStream.startStream();
     return getDataStream;
   }
@@ -138,6 +155,7 @@ public final class GrpcGetDataStream
     checkArgument(chunk.getRequestIdCount() == 
chunk.getSerializedResponseCount());
     checkArgument(chunk.getRemainingBytesForResponse() == 0 || 
chunk.getRequestIdCount() == 1);
     getDataThrottleTimer.stop();
+    onHeartbeatResponse(chunk.getComputationHeartbeatResponseList());
 
     for (int i = 0; i < chunk.getRequestIdCount(); ++i) {
       AppendableInputStream responseStream = 
pending.get(chunk.getRequestId(i));
@@ -171,30 +189,71 @@ public final class GrpcGetDataStream
   }
 
   @Override
-  public void refreshActiveWork(Map<String, List<KeyedGetDataRequest>> active) 
{
-    long builderBytes = 0;
+  public void refreshActiveWork(Map<String, List<HeartbeatRequest>> 
heartbeats) {
     StreamingGetDataRequest.Builder builder = 
StreamingGetDataRequest.newBuilder();
-    for (Map.Entry<String, List<KeyedGetDataRequest>> entry : 
active.entrySet()) {
-      for (KeyedGetDataRequest request : entry.getValue()) {
-        // Calculate the bytes with some overhead for proto encoding.
-        long bytes = (long) entry.getKey().length() + 
request.getSerializedSize() + 10;
-        if (builderBytes > 0
-            && (builderBytes + bytes > 
AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE
-                || builder.getRequestIdCount() >= streamingRpcBatchLimit)) {
-          send(builder.build());
-          builderBytes = 0;
-          builder.clear();
+    if (sendKeyedGetDataRequests) {
+      long builderBytes = 0;
+      for (Map.Entry<String, List<HeartbeatRequest>> entry : 
heartbeats.entrySet()) {
+        for (HeartbeatRequest request : entry.getValue()) {
+          // Calculate the bytes with some overhead for proto encoding.
+          long bytes = (long) entry.getKey().length() + 
request.getSerializedSize() + 10;
+          if (builderBytes > 0
+              && (builderBytes + bytes > 
AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE
+                  || builder.getRequestIdCount() >= streamingRpcBatchLimit)) {
+            send(builder.build());
+            builderBytes = 0;
+            builder.clear();
+          }
+          builderBytes += bytes;
+          builder.addStateRequest(
+              ComputationGetDataRequest.newBuilder()
+                  .setComputationId(entry.getKey())
+                  .addRequests(
+                      Windmill.KeyedGetDataRequest.newBuilder()
+                          .setShardingKey(request.getShardingKey())
+                          .setWorkToken(request.getWorkToken())
+                          .setCacheToken(request.getCacheToken())
+                          
.addAllLatencyAttribution(request.getLatencyAttributionList())
+                          .build()));
         }
-        builderBytes += bytes;
-        builder.addStateRequest(
-            ComputationGetDataRequest.newBuilder()
-                .setComputationId(entry.getKey())
-                .addRequests(request));
+      }
+
+      if (builderBytes > 0) {
+        send(builder.build());
+      }
+    } else {
+      // No translation necessary, but we must still respect 
`RPC_STREAM_CHUNK_SIZE`.
+      long builderBytes = 0;
+      for (Map.Entry<String, List<HeartbeatRequest>> entry : 
heartbeats.entrySet()) {
+        ComputationHeartbeatRequest.Builder computationHeartbeatBuilder =
+            
ComputationHeartbeatRequest.newBuilder().setComputationId(entry.getKey());
+        for (HeartbeatRequest request : entry.getValue()) {
+          long bytes = (long) entry.getKey().length() + 
request.getSerializedSize() + 10;
+          if (builderBytes > 0
+              && builderBytes + bytes > 
AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) {
+            if (computationHeartbeatBuilder.getHeartbeatRequestsCount() > 0) {
+              
builder.addComputationHeartbeatRequest(computationHeartbeatBuilder.build());
+            }
+            send(builder.build());
+            builderBytes = 0;
+            builder.clear();
+            
computationHeartbeatBuilder.clear().setComputationId(entry.getKey());
+          }
+          builderBytes += bytes;
+          computationHeartbeatBuilder.addHeartbeatRequests(request);
+        }
+        
builder.addComputationHeartbeatRequest(computationHeartbeatBuilder.build());
+      }
+
+      if (builderBytes > 0) {
+        send(builder.build());
       }
     }
-    if (builderBytes > 0) {
-      send(builder.build());
-    }
+  }
+
+  @Override
+  public void onHeartbeatResponse(List<Windmill.ComputationHeartbeatResponse> 
responses) {
+    processHeartbeatResponses.accept(responses);
   }
 
   @Override
@@ -277,7 +336,7 @@ public final class GrpcGetDataStream
         waitForSendLatch.await();
       }
       // Finalize the batch so that no additional requests will be added.  
Leave the batch in the
-      // queue so that a subsequent batch will wait for it's completion.
+      // queue so that a subsequent batch will wait for its completion.
       synchronized (batches) {
         verify(batch == batches.peekFirst());
         batch.markFinalized();
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java
index 3a881df7146..9f0126a9cc6 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java
@@ -28,13 +28,17 @@ import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
+import java.util.function.Consumer;
 import java.util.function.Supplier;
 import javax.annotation.Nullable;
+import org.apache.beam.runners.dataflow.DataflowRunner;
 import 
org.apache.beam.runners.dataflow.worker.options.StreamingDataflowWorkerOptions;
 import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
 import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkResponse;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetConfigRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetConfigResponse;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataRequest;
@@ -93,6 +97,10 @@ public final class GrpcWindmillServer extends 
WindmillServerStub {
   private final StreamingEngineThrottleTimers throttleTimers;
   private Duration maxBackoff;
   private @Nullable WindmillApplianceGrpc.WindmillApplianceBlockingStub 
syncApplianceStub;
+  // If true, then active work refreshes will be sent as KeyedGetDataRequests. 
Otherwise, use the
+  // newer ComputationHeartbeatRequests.
+  private final boolean sendKeyedGetDataRequests;
+  private Consumer<List<ComputationHeartbeatResponse>> 
processHeartbeatResponses;
 
   private GrpcWindmillServer(
       StreamingDataflowWorkerOptions options, GrpcDispatcherClient 
grpcDispatcherClient) {
@@ -118,9 +126,21 @@ public final class GrpcWindmillServer extends 
WindmillServerStub {
 
     this.dispatcherClient = grpcDispatcherClient;
     this.syncApplianceStub = null;
+    this.sendKeyedGetDataRequests =
+        !options.isEnableStreamingEngine()
+            || !DataflowRunner.hasExperiment(
+                options, "streaming_engine_send_new_heartbeat_requests");
+    this.processHeartbeatResponses = (responses) -> {};
   }
 
-  private static StreamingDataflowWorkerOptions testOptions(boolean 
enableStreamingEngine) {
+  @Override
+  public void setProcessHeartbeatResponses(
+      Consumer<List<Windmill.ComputationHeartbeatResponse>> 
processHeartbeatResponses) {
+    this.processHeartbeatResponses = processHeartbeatResponses;
+  };
+
+  private static StreamingDataflowWorkerOptions testOptions(
+      boolean enableStreamingEngine, List<String> additionalExperiments) {
     StreamingDataflowWorkerOptions options =
         
PipelineOptionsFactory.create().as(StreamingDataflowWorkerOptions.class);
     options.setProject("project");
@@ -131,6 +151,7 @@ public final class GrpcWindmillServer extends 
WindmillServerStub {
     if (enableStreamingEngine) {
       experiments.add(GcpOptions.STREAMING_ENGINE_EXPERIMENT);
     }
+    experiments.addAll(additionalExperiments);
     options.setExperiments(experiments);
 
     options.setWindmillServiceStreamingRpcBatchLimit(Integer.MAX_VALUE);
@@ -162,7 +183,7 @@ public final class GrpcWindmillServer extends 
WindmillServerStub {
   }
 
   @VisibleForTesting
-  static GrpcWindmillServer newTestInstance(String name) {
+  static GrpcWindmillServer newTestInstance(String name, List<String> 
experiments) {
     ManagedChannel inProcessChannel = inProcessChannel(name);
     CloudWindmillServiceV1Alpha1Stub stub =
         CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel);
@@ -173,14 +194,15 @@ public final class GrpcWindmillServer extends 
WindmillServerStub {
             WindmillStubFactory.inProcessStubFactory(name, unused -> 
inProcessChannel),
             dispatcherStubs,
             dispatcherEndpoints);
-    return new GrpcWindmillServer(testOptions(/* enableStreamingEngine= */ 
true), dispatcherClient);
+    return new GrpcWindmillServer(
+        testOptions(/* enableStreamingEngine= */ true, experiments), 
dispatcherClient);
   }
 
   @VisibleForTesting
   static GrpcWindmillServer newApplianceTestInstance(Channel channel) {
     GrpcWindmillServer testServer =
         new GrpcWindmillServer(
-            testOptions(/* enableStreamingEngine= */ false),
+            testOptions(/* enableStreamingEngine= */ false, new ArrayList<>()),
             // No-op, Appliance does not use Dispatcher to call Streaming 
Engine.
             
GrpcDispatcherClient.create(WindmillStubFactory.inProcessStubFactory("test")));
     testServer.syncApplianceStub = 
createWindmillApplianceStubWithDeadlineInterceptor(channel);
@@ -319,7 +341,10 @@ public final class GrpcWindmillServer extends 
WindmillServerStub {
   @Override
   public GetDataStream getDataStream() {
     return windmillStreamFactory.createGetDataStream(
-        dispatcherClient.getDispatcherStub(), 
throttleTimers.getDataThrottleTimer());
+        dispatcherClient.getDispatcherStub(),
+        throttleTimers.getDataThrottleTimer(),
+        sendKeyedGetDataRequests,
+        this.processHeartbeatResponses);
   }
 
   @Override
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
index 099be8db0fd..7dc43e791e3 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
@@ -21,6 +21,7 @@ import static 
org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWi
 
 import com.google.auto.value.AutoBuilder;
 import java.io.PrintWriter;
+import java.util.List;
 import java.util.Set;
 import java.util.Timer;
 import java.util.TimerTask;
@@ -32,6 +33,7 @@ import java.util.function.Supplier;
 import javax.annotation.concurrent.ThreadSafe;
 import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
 import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
 import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints;
@@ -152,7 +154,10 @@ public class GrpcWindmillStreamFactory implements 
StatusDataProvider {
   }
 
   public GetDataStream createGetDataStream(
-      CloudWindmillServiceV1Alpha1Stub stub, ThrottleTimer 
getDataThrottleTimer) {
+      CloudWindmillServiceV1Alpha1Stub stub,
+      ThrottleTimer getDataThrottleTimer,
+      boolean sendKeyedGetDataRequests,
+      Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses) {
     return GrpcGetDataStream.create(
         responseObserver -> withDeadline(stub).getDataStream(responseObserver),
         grpcBackOff.get(),
@@ -162,7 +167,14 @@ public class GrpcWindmillStreamFactory implements 
StatusDataProvider {
         getDataThrottleTimer,
         jobHeader,
         streamIdGenerator,
-        streamingRpcBatchLimit);
+        streamingRpcBatchLimit,
+        sendKeyedGetDataRequests,
+        processHeartbeatResponses);
+  }
+
+  public GetDataStream createGetDataStream(
+      CloudWindmillServiceV1Alpha1Stub stub, ThrottleTimer 
getDataThrottleTimer) {
+    return createGetDataStream(stub, getDataThrottleTimer, false, (response) 
-> {});
   }
 
   public CommitWorkStream createCommitWorkStream(
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
index 6c1239d6ebd..5a9e5443a50 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
@@ -296,6 +296,11 @@ public class WindmillStateCache implements 
StatusDataProvider {
       this.computation = computation;
     }
 
+    /** Returns the computation associated to this class. */
+    public String getComputation() {
+      return this.computation;
+    }
+
     /** Invalidate all cache entries for this computation and {@code 
processingKey}. */
     public void invalidate(ByteString processingKey, long shardingKey) {
       WindmillComputationKey key =
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java
index c28939c59ee..637b838c7fe 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java
@@ -39,6 +39,7 @@ import javax.annotation.Nullable;
 import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException;
 import 
org.apache.beam.runners.dataflow.worker.MetricTrackingWindmillServerStub;
 import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils;
+import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
@@ -123,6 +124,7 @@ public class WindmillStateReader {
   private final MetricTrackingWindmillServerStub 
metricTrackingWindmillServerStub;
   private final ConcurrentHashMap<StateTag<?>, CoderAndFuture<?>> waiting;
   private long bytesRead = 0L;
+  private final Supplier<Boolean> workItemIsFailed;
 
   public WindmillStateReader(
       MetricTrackingWindmillServerStub metricTrackingWindmillServerStub,
@@ -130,7 +132,8 @@ public class WindmillStateReader {
       ByteString key,
       long shardingKey,
       long workToken,
-      Supplier<AutoCloseable> readWrapperSupplier) {
+      Supplier<AutoCloseable> readWrapperSupplier,
+      Supplier<Boolean> workItemIsFailed) {
     this.metricTrackingWindmillServerStub = metricTrackingWindmillServerStub;
     this.computation = computation;
     this.key = key;
@@ -139,6 +142,7 @@ public class WindmillStateReader {
     this.readWrapperSupplier = readWrapperSupplier;
     this.waiting = new ConcurrentHashMap<>();
     this.pendingLookups = new ConcurrentLinkedQueue<>();
+    this.workItemIsFailed = workItemIsFailed;
   }
 
   public WindmillStateReader(
@@ -147,7 +151,14 @@ public class WindmillStateReader {
       ByteString key,
       long shardingKey,
       long workToken) {
-    this(metricTrackingWindmillServerStub, computation, key, shardingKey, 
workToken, () -> null);
+    this(
+        metricTrackingWindmillServerStub,
+        computation,
+        key,
+        shardingKey,
+        workToken,
+        () -> null,
+        () -> Boolean.FALSE);
   }
 
   private <FutureT> Future<FutureT> stateFuture(StateTag<?> stateTag, 
@Nullable Coder<?> coder) {
@@ -404,6 +415,9 @@ public class WindmillStateReader {
 
   private KeyedGetDataResponse tryGetDataFromWindmill(HashSet<StateTag<?>> 
stateTags)
       throws Exception {
+    if (workItemIsFailed.get()) {
+      throw new WorkItemCancelledException(shardingKey);
+    }
     KeyedGetDataRequest keyedGetDataRequest = createRequest(stateTags);
     try (AutoCloseable ignored = readWrapperSupplier.get()) {
       return Optional.ofNullable(
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
index a434b200120..2cfec6d3139 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
@@ -46,8 +46,11 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkResponse;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationCommitWorkRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatRequest;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataResponse;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution.State;
@@ -80,9 +83,10 @@ class FakeWindmillServer extends WindmillServerStub {
   private final ErrorCollector errorCollector;
   private final ConcurrentHashMap<Long, Consumer<Windmill.CommitStatus>> 
droppedStreamingCommits;
   private int commitsRequested = 0;
-  private List<Windmill.GetDataRequest> getDataRequests = new ArrayList<>();
+  private final List<Windmill.GetDataRequest> getDataRequests = new 
ArrayList<>();
   private boolean isReady = true;
   private boolean dropStreamingCommits = false;
+  private Consumer<List<Windmill.ComputationHeartbeatResponse>> 
processHeartbeatResponses;
 
   public FakeWindmillServer(ErrorCollector errorCollector) {
     workToOffer =
@@ -91,7 +95,7 @@ class FakeWindmillServer extends WindmillServerStub {
     dataToOffer =
         new ResponseQueue<GetDataRequest, GetDataResponse>()
             .returnByDefault(GetDataResponse.getDefaultInstance())
-            // Sleep for a little bit to ensure that *-windmill-read 
state-sampled counters show up.
+            // Sleep for a bit to ensure that *-windmill-read state-sampled 
counters show up.
             .delayEachResponseBy(Duration.millis(500));
     commitsToOffer =
         new ResponseQueue<Windmill.CommitWorkRequest, CommitWorkResponse>()
@@ -102,6 +106,13 @@ class FakeWindmillServer extends WindmillServerStub {
     this.errorCollector = errorCollector;
     statsReceived = new ArrayList<>();
     droppedStreamingCommits = new ConcurrentHashMap<>();
+    processHeartbeatResponses = (responses) -> {};
+  }
+
+  @Override
+  public void setProcessHeartbeatResponses(
+      Consumer<List<Windmill.ComputationHeartbeatResponse>> 
processHeartbeatResponses) {
+    this.processHeartbeatResponses = processHeartbeatResponses;
   }
 
   public void setDropStreamingCommits(boolean dropStreamingCommits) {
@@ -116,6 +127,10 @@ class FakeWindmillServer extends WindmillServerStub {
     return dataToOffer;
   }
 
+  public void sendFailedHeartbeats(List<Windmill.ComputationHeartbeatResponse> 
responses) {
+    getDataStream().onHeartbeatResponse(responses);
+  }
+
   public ResponseQueue<Windmill.CommitWorkRequest, Windmill.CommitWorkResponse>
       whenCommitWorkCalled() {
     return commitsToOffer;
@@ -304,17 +319,23 @@ class FakeWindmillServer extends WindmillServerStub {
       }
 
       @Override
-      public void refreshActiveWork(Map<String, List<KeyedGetDataRequest>> 
active) {
+      public void refreshActiveWork(Map<String, List<HeartbeatRequest>> 
heartbeats) {
         Windmill.GetDataRequest.Builder builder = 
Windmill.GetDataRequest.newBuilder();
-        for (Map.Entry<String, List<KeyedGetDataRequest>> entry : 
active.entrySet()) {
-          builder.addRequests(
-              ComputationGetDataRequest.newBuilder()
+        for (Map.Entry<String, List<HeartbeatRequest>> entry : 
heartbeats.entrySet()) {
+          builder.addComputationHeartbeatRequest(
+              ComputationHeartbeatRequest.newBuilder()
                   .setComputationId(entry.getKey())
-                  .addAllRequests(entry.getValue()));
+                  .addAllHeartbeatRequests(entry.getValue()));
         }
+
         getData(builder.build());
       }
 
+      @Override
+      public void onHeartbeatResponse(List<ComputationHeartbeatResponse> 
responses) {
+        processHeartbeatResponses.accept(responses);
+      }
+
       @Override
       public void close() {}
 
@@ -383,6 +404,18 @@ class FakeWindmillServer extends WindmillServerStub {
     }
   }
 
+  public Map<Long, WorkItemCommitRequest> waitForAndGetCommitsWithTimeout(
+      int numCommits, Duration timeout) {
+    LOG.debug("waitForAndGetCommitsWithTimeout: {} {}", numCommits, timeout);
+    Instant waitStart = Instant.now();
+    while (commitsReceived.size() < commitsRequested + numCommits
+        && Instant.now().isBefore(waitStart.plus(timeout))) {
+      Uninterruptibles.sleepUninterruptibly(1000, TimeUnit.MILLISECONDS);
+    }
+    commitsRequested += numCommits;
+    return commitsReceived;
+  }
+
   public Map<Long, WorkItemCommitRequest> waitForAndGetCommits(int numCommits) 
{
     LOG.debug("waitForAndGetCommitsRequest: {}", numCommits);
     int maxTries = 10;
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index 31a9af9004a..9526c96fd04 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -109,9 +109,12 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataResponse;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatRequest;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataResponse;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkResponse;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.InputMessageBundle;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
@@ -577,8 +580,9 @@ public class StreamingDataflowWorkerTest {
   }
 
   /**
-   * Returns a {@link 
org.apache.beam.runners.dataflow.windmill.Windmill.WorkItemCommitRequest}
-   * builder parsed from the provided text format proto.
+   * Returns a {@link
+   * 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest}
 builder parsed
+   * from the provided text format proto.
    */
   private WorkItemCommitRequest.Builder parseCommitRequest(String output) 
throws Exception {
     WorkItemCommitRequest.Builder builder = 
Windmill.WorkItemCommitRequest.newBuilder();
@@ -3258,6 +3262,49 @@ public class StreamingDataflowWorkerTest {
     assertThat(server.numGetDataRequests(), greaterThan(0));
   }
 
+  @Test
+  public void testActiveWorkFailure() throws Exception {
+    List<ParallelInstruction> instructions =
+        Arrays.asList(
+            makeSourceInstruction(StringUtf8Coder.of()),
+            makeDoFnInstruction(blockingFn, 0, StringUtf8Coder.of()),
+            makeSinkInstruction(StringUtf8Coder.of(), 0));
+
+    FakeWindmillServer server = new FakeWindmillServer(errorCollector);
+    StreamingDataflowWorkerOptions options = 
createTestingPipelineOptions(server);
+    options.setActiveWorkRefreshPeriodMillis(100);
+    StreamingDataflowWorker worker = makeWorker(instructions, options, true /* 
publishCounters */);
+    worker.start();
+
+    // Queue up two work items for the same key.
+    server
+        .whenGetWorkCalled()
+        .thenReturn(makeInput(0, TimeUnit.MILLISECONDS.toMicros(0), "key", 
DEFAULT_SHARDING_KEY))
+        .thenReturn(makeInput(1, TimeUnit.MILLISECONDS.toMicros(0), "key", 
DEFAULT_SHARDING_KEY));
+    server.waitForEmptyWorkQueue();
+
+    // Mock Windmill sending a heartbeat response failing the second work item 
while the first
+    // is still processing.
+    ComputationHeartbeatResponse.Builder failedHeartbeat =
+        ComputationHeartbeatResponse.newBuilder();
+    failedHeartbeat
+        .setComputationId(DEFAULT_COMPUTATION_ID)
+        .addHeartbeatResponsesBuilder()
+        .setCacheToken(3)
+        .setWorkToken(1)
+        .setShardingKey(DEFAULT_SHARDING_KEY)
+        .setFailed(true);
+    
server.sendFailedHeartbeats(Collections.singletonList(failedHeartbeat.build()));
+
+    // Release the blocked calls.
+    BlockingFn.blocker.countDown();
+    Map<Long, Windmill.WorkItemCommitRequest> commits =
+        server.waitForAndGetCommitsWithTimeout(2, 
Duration.standardSeconds((5)));
+    assertEquals(1, commits.size());
+
+    worker.stop();
+  }
+
   @Test
   public void testLatencyAttributionProtobufsPopulated() {
     FakeClock clock = new FakeClock();
@@ -3573,7 +3620,10 @@ public class StreamingDataflowWorkerTest {
     Windmill.GetDataRequest heartbeat = server.getGetDataRequests().get(2);
 
     for (LatencyAttribution la :
-        heartbeat.getRequests(0).getRequests(0).getLatencyAttributionList()) {
+        heartbeat
+            .getComputationHeartbeatRequest(0)
+            .getHeartbeatRequests(0)
+            .getLatencyAttributionList()) {
       if (la.getState() == State.ACTIVE) {
         assertTrue(la.getActiveLatencyBreakdownCount() > 0);
         assertTrue(la.getActiveLatencyBreakdown(0).hasActiveMessageMetadata());
@@ -3768,7 +3818,7 @@ public class StreamingDataflowWorkerTest {
     server
         .whenGetWorkCalled()
         .thenReturn(makeInput(1, TimeUnit.MILLISECONDS.toMicros(1), 
DEFAULT_KEY_STRING, 1));
-    // Ensure that the this work item processes.
+    // Ensure that this work item processes.
     Map<Long, Windmill.WorkItemCommitRequest> result = 
server.waitForAndGetCommits(1);
     // Now ensure that nothing happens if a dropped commit actually completes.
     droppedCommits.values().iterator().next().accept(CommitStatus.OK);
@@ -4129,7 +4179,7 @@ public class StreamingDataflowWorkerTest {
                 FakeClock.this.schedule(Duration.millis(unit.toMillis(delay)), 
this);
               }
             });
-        FakeClock.this.sleep(Duration.ZERO); // Execute work that has an 
intial delay of zero.
+        FakeClock.this.sleep(Duration.ZERO); // Execute work that has an 
initial delay of zero.
         return null;
       }
     }
@@ -4167,6 +4217,7 @@ public class StreamingDataflowWorkerTest {
     }
 
     boolean isActiveWorkRefresh(GetDataRequest request) {
+      if (request.getComputationHeartbeatRequestCount() > 0) return true;
       for (ComputationGetDataRequest computationRequest : 
request.getRequestsList()) {
         if 
(!computationRequest.getComputationId().equals(DEFAULT_COMPUTATION_ID)) {
           return false;
@@ -4203,6 +4254,21 @@ public class StreamingDataflowWorkerTest {
           }
         }
       }
+      for (ComputationHeartbeatRequest heartbeatRequest :
+          request.getComputationHeartbeatRequestList()) {
+        for (HeartbeatRequest heartbeat : 
heartbeatRequest.getHeartbeatRequestsList()) {
+          for (LatencyAttribution la : heartbeat.getLatencyAttributionList()) {
+            EnumMap<LatencyAttribution.State, Duration> durations =
+                totalDurations.computeIfAbsent(
+                    heartbeat.getWorkToken(),
+                    (Long workToken) ->
+                        new EnumMap<LatencyAttribution.State, Duration>(
+                            LatencyAttribution.State.class));
+            Duration cur = Duration.millis(la.getTotalDurationMillis());
+            durations.compute(la.getState(), (s, d) -> d == null || 
d.isShorterThan(cur) ? cur : d);
+          }
+        }
+      }
       return EMPTY_DATA_RESPONDER.apply(request);
     }
   }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
index 60ecaa3e37e..451ec649aa2 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
@@ -137,7 +137,8 @@ public class StreamingModeExecutionContextTest {
         null, // synchronized processing time
         stateReader,
         sideInputStateFetcher,
-        outputBuilder);
+        outputBuilder,
+        null);
 
     TimerInternals timerInternals = stepContext.timerInternals();
 
@@ -187,7 +188,8 @@ public class StreamingModeExecutionContextTest {
         null, // synchronized processing time
         stateReader,
         sideInputStateFetcher,
-        outputBuilder);
+        outputBuilder,
+        null);
     TimerInternals timerInternals = stepContext.timerInternals();
     
assertTrue(timerTimestamp.isBefore(timerInternals.currentProcessingTime()));
   }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java
index 6fa2ffe711f..b488641d1ca 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java
@@ -37,6 +37,7 @@ import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.lessThan;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
@@ -90,9 +91,12 @@ import 
org.apache.beam.runners.dataflow.worker.WorkerCustomSources.SplittableOnl
 import org.apache.beam.runners.dataflow.worker.counters.CounterSet;
 import org.apache.beam.runners.dataflow.worker.counters.NameContext;
 import 
org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.NoopProfileScope;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
 import org.apache.beam.runners.dataflow.worker.testing.TestCountingSource;
 import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader;
+import 
org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader.NativeReaderIterator;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import 
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
 import org.apache.beam.sdk.coders.Coder;
@@ -613,7 +617,8 @@ public class WorkerCustomSourcesTest {
           null, // synchronized processing time
           null, // StateReader
           null, // StateFetcher
-          Windmill.WorkItemCommitRequest.newBuilder());
+          Windmill.WorkItemCommitRequest.newBuilder(),
+          null);
 
       @SuppressWarnings({"unchecked", "rawtypes"})
       NativeReader<WindowedValue<ValueWithRecordId<KV<Integer, Integer>>>> 
reader =
@@ -931,4 +936,79 @@ public class WorkerCustomSourcesTest {
     assertNull(progress.getRemainingParallelism());
     logged.verifyWarn("remaining parallelism");
   }
+
+  @Test
+  public void testFailedWorkItemsAbort() throws Exception {
+    CounterSet counterSet = new CounterSet();
+    StreamingModeExecutionStateRegistry executionStateRegistry =
+        new StreamingModeExecutionStateRegistry(null);
+    StreamingModeExecutionContext context =
+        new StreamingModeExecutionContext(
+            counterSet,
+            "computationId",
+            new ReaderCache(Duration.standardMinutes(1), Runnable::run),
+            /*stateNameMap=*/ ImmutableMap.of(),
+            new 
WindmillStateCache(options.getWorkerCacheMb()).forComputation("computationId"),
+            StreamingStepMetricsContainer.createRegistry(),
+            new DataflowExecutionStateTracker(
+                ExecutionStateSampler.newForTest(),
+                executionStateRegistry.getState(
+                    NameContext.forStage("stageName"), "other", null, 
NoopProfileScope.NOOP),
+                counterSet,
+                PipelineOptionsFactory.create(),
+                "test-work-item-id"),
+            executionStateRegistry,
+            Long.MAX_VALUE);
+
+    options.setNumWorkers(5);
+    int maxElements = 100;
+    DataflowPipelineDebugOptions debugOptions = 
options.as(DataflowPipelineDebugOptions.class);
+    debugOptions.setUnboundedReaderMaxElements(maxElements);
+
+    ByteString state = ByteString.EMPTY;
+    Windmill.WorkItem workItem =
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.copyFromUtf8("0000000000000001")) // key is 
zero-padded index.
+            .setWorkToken(0)
+            .setCacheToken(1)
+            .setSourceState(
+                Windmill.SourceState.newBuilder().setState(state).build()) // 
Source state.
+            .build();
+    Work dummyWork = Work.create(workItem, Instant::now, 
Collections.emptyList(), unused -> {});
+
+    context.start(
+        "key",
+        workItem,
+        new Instant(0), // input watermark
+        null, // output watermark
+        null, // synchronized processing time
+        null, // StateReader
+        null, // StateFetcher
+        Windmill.WorkItemCommitRequest.newBuilder(),
+        dummyWork::isFailed);
+
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    NativeReader<WindowedValue<ValueWithRecordId<KV<Integer, Integer>>>> 
reader =
+        (NativeReader)
+            WorkerCustomSources.create(
+                (CloudObject)
+                    serializeToCloudSource(new 
TestCountingSource(Integer.MAX_VALUE), options)
+                        .getSpec(),
+                options,
+                context);
+
+    NativeReaderIterator<WindowedValue<ValueWithRecordId<KV<Integer, 
Integer>>>> readerIterator =
+        reader.iterator();
+    int numReads = 0;
+    while ((numReads == 0) ? readerIterator.start() : 
readerIterator.advance()) {
+      WindowedValue<ValueWithRecordId<KV<Integer, Integer>>> value = 
readerIterator.getCurrent();
+      assertEquals(KV.of(0, numReads), value.getValue().getValue());
+      numReads++;
+      // Fail the work item after reading two elements.
+      if (numReads == 2) {
+        dummyWork.setFailed();
+      }
+    }
+    assertThat(numReads, equalTo(2));
+  }
 }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
index ea57f687fd9..de30fd0f8d5 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
@@ -36,7 +36,7 @@ import javax.annotation.Nullable;
 import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
 import 
org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.ActivateWorkResult;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
-import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
 import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
@@ -239,7 +239,7 @@ public class ActiveWorkStateTest {
   }
 
   @Test
-  public void testGetKeysToRefresh() {
+  public void testGetKeyHeartbeats() {
     Instant refreshDeadline = Instant.now();
 
     Work freshWork = createWork(createWorkItem(3L));
@@ -254,47 +254,51 @@ public class ActiveWorkStateTest {
     activeWorkState.activateWorkForKey(shardedKey1, freshWork);
     activeWorkState.activateWorkForKey(shardedKey2, refreshableWork2);
 
-    ImmutableList<KeyedGetDataRequest> requests =
-        activeWorkState.getKeysToRefresh(refreshDeadline, 
DataflowExecutionStateSampler.instance());
+    ImmutableList<HeartbeatRequest> requests =
+        activeWorkState.getKeyHeartbeats(refreshDeadline, 
DataflowExecutionStateSampler.instance());
 
-    ImmutableList<GetDataRequestKeyShardingKeyAndWorkToken> expected =
+    ImmutableList<HeartbeatRequestShardingKeyWorkTokenAndCacheToken> expected =
         ImmutableList.of(
-            GetDataRequestKeyShardingKeyAndWorkToken.from(shardedKey1, 
refreshableWork1),
-            GetDataRequestKeyShardingKeyAndWorkToken.from(shardedKey2, 
refreshableWork2));
+            
HeartbeatRequestShardingKeyWorkTokenAndCacheToken.from(shardedKey1, 
refreshableWork1),
+            
HeartbeatRequestShardingKeyWorkTokenAndCacheToken.from(shardedKey2, 
refreshableWork2));
 
-    ImmutableList<GetDataRequestKeyShardingKeyAndWorkToken> actual =
+    ImmutableList<HeartbeatRequestShardingKeyWorkTokenAndCacheToken> actual =
         requests.stream()
-            .map(GetDataRequestKeyShardingKeyAndWorkToken::from)
+            .map(HeartbeatRequestShardingKeyWorkTokenAndCacheToken::from)
             .collect(toImmutableList());
 
     assertThat(actual).containsExactlyElementsIn(expected);
   }
 
   @AutoValue
-  abstract static class GetDataRequestKeyShardingKeyAndWorkToken {
+  abstract static class HeartbeatRequestShardingKeyWorkTokenAndCacheToken {
 
-    private static GetDataRequestKeyShardingKeyAndWorkToken create(
-        ByteString key, long shardingKey, long workToken) {
-      return new 
AutoValue_ActiveWorkStateTest_GetDataRequestKeyShardingKeyAndWorkToken(
-          key, shardingKey, workToken);
+    private static HeartbeatRequestShardingKeyWorkTokenAndCacheToken create(
+        long shardingKey, long workToken, long cacheToken) {
+      return new 
AutoValue_ActiveWorkStateTest_HeartbeatRequestShardingKeyWorkTokenAndCacheToken(
+          shardingKey, workToken, cacheToken);
     }
 
-    private static GetDataRequestKeyShardingKeyAndWorkToken from(
-        KeyedGetDataRequest keyedGetDataRequest) {
+    private static HeartbeatRequestShardingKeyWorkTokenAndCacheToken from(
+        HeartbeatRequest heartbeatRequest) {
       return create(
-          keyedGetDataRequest.getKey(),
-          keyedGetDataRequest.getShardingKey(),
-          keyedGetDataRequest.getWorkToken());
+          heartbeatRequest.getShardingKey(),
+          heartbeatRequest.getWorkToken(),
+          heartbeatRequest.getCacheToken());
     }
 
-    private static GetDataRequestKeyShardingKeyAndWorkToken from(ShardedKey 
shardedKey, Work work) {
-      return create(shardedKey.key(), shardedKey.shardingKey(), 
work.getWorkItem().getWorkToken());
+    private static HeartbeatRequestShardingKeyWorkTokenAndCacheToken from(
+        ShardedKey shardedKey, Work work) {
+      return create(
+          shardedKey.shardingKey(),
+          work.getWorkItem().getWorkToken(),
+          work.getWorkItem().getCacheToken());
     }
 
-    abstract ByteString key();
-
     abstract long shardingKey();
 
     abstract long workToken();
+
+    abstract long cacheToken();
   }
 }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
index 5f8a452a043..0ea25302767 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
@@ -44,6 +44,7 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Al
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationWorkItemMetadata;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkStreamTimingInfo;
@@ -51,6 +52,7 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkStreamTi
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataId;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
@@ -126,7 +128,7 @@ public class GrpcWindmillServerTest {
             .build()
             .start();
 
-    this.client = GrpcWindmillServer.newTestInstance(name);
+    this.client = GrpcWindmillServer.newTestInstance(name, new ArrayList<>());
   }
 
   @After
@@ -744,7 +746,7 @@ public class GrpcWindmillServerTest {
     while (true) {
       Thread.sleep(100);
       int tmpErrorsBeforeClose = errorsBeforeClose.get();
-      // wait for at least 1 errors before close
+      // wait for at least 1 error before close
       if (tmpErrorsBeforeClose > 0) {
         break;
       }
@@ -765,7 +767,7 @@ public class GrpcWindmillServerTest {
     while (true) {
       Thread.sleep(100);
       int tmpErrorsAfterClose = errorsAfterClose.get();
-      // wait for at least 1 errors after close
+      // wait for at least 1 error after close
       if (tmpErrorsAfterClose > 0) {
         break;
       }
@@ -786,22 +788,36 @@ public class GrpcWindmillServerTest {
     assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS));
   }
 
-  private List<KeyedGetDataRequest> makeHeartbeatRequest(List<String> keys) {
+  private List<KeyedGetDataRequest> makeGetDataHeartbeatRequest(List<String> 
keys) {
     List<KeyedGetDataRequest> result = new ArrayList<>();
     for (String key : keys) {
       result.add(
           Windmill.KeyedGetDataRequest.newBuilder()
-              .setKey(ByteString.copyFromUtf8(key))
+              .setShardingKey(key.hashCode())
               .setWorkToken(0)
+              .setCacheToken(0)
+              .build());
+    }
+    return result;
+  }
+
+  private List<HeartbeatRequest> makeHeartbeatRequest(List<String> keys) {
+    List<HeartbeatRequest> result = new ArrayList<>();
+    for (String key : keys) {
+      result.add(
+          Windmill.HeartbeatRequest.newBuilder()
+              .setShardingKey(key.hashCode())
+              .setWorkToken(0)
+              .setCacheToken(0)
               .build());
     }
     return result;
   }
 
   @Test
-  public void testStreamingGetDataHeartbeats() throws Exception {
+  public void testStreamingGetDataHeartbeatsAsKeyedGetDataRequests() throws 
Exception {
     // This server records the heartbeats observed but doesn't respond.
-    final Map<String, List<KeyedGetDataRequest>> heartbeats = new HashMap<>();
+    final Map<String, List<KeyedGetDataRequest>> getDataHeartbeats = new 
HashMap<>();
 
     serviceRegistry.addService(
         new CloudWindmillServiceV1Alpha1ImplBase() {
@@ -826,16 +842,17 @@ public class GrpcWindmillServerTest {
                                 .build()));
                     sawHeader = true;
                   } else {
-                    LOG.info("Received {} heartbeats", 
chunk.getStateRequestCount());
+                    LOG.info("Received {} getDataHeartbeats", 
chunk.getStateRequestCount());
                     errorCollector.checkThat(
                         chunk.getSerializedSize(), 
Matchers.lessThanOrEqualTo(STREAM_CHUNK_SIZE));
                     errorCollector.checkThat(chunk.getRequestIdCount(), 
Matchers.is(0));
 
-                    synchronized (heartbeats) {
+                    synchronized (getDataHeartbeats) {
                       for (ComputationGetDataRequest request : 
chunk.getStateRequestList()) {
                         errorCollector.checkThat(request.getRequestsCount(), 
Matchers.is(1));
-                        heartbeats.putIfAbsent(request.getComputationId(), new 
ArrayList<>());
-                        heartbeats
+                        getDataHeartbeats.putIfAbsent(
+                            request.getComputationId(), new ArrayList<>());
+                        getDataHeartbeats
                             .get(request.getComputationId())
                             .add(request.getRequestsList().get(0));
                       }
@@ -857,7 +874,6 @@ public class GrpcWindmillServerTest {
           }
         });
 
-    Map<String, List<KeyedGetDataRequest>> activeMap = new HashMap<>();
     List<String> computation1Keys = new ArrayList<>();
     List<String> computation2Keys = new ArrayList<>();
 
@@ -865,22 +881,141 @@ public class GrpcWindmillServerTest {
       computation1Keys.add("Computation1Key" + i);
       computation2Keys.add("Computation2Key" + largeString(i * 20));
     }
-    activeMap.put("Computation1", makeHeartbeatRequest(computation1Keys));
-    activeMap.put("Computation2", makeHeartbeatRequest(computation2Keys));
+    // We're adding HeartbeatRequests to refreshActiveWork, but expecting to 
get back
+    // KeyedGetDataRequests, so make a Map of both types.
+    Map<String, List<KeyedGetDataRequest>> expectedKeyedGetDataRequests = new 
HashMap<>();
+    expectedKeyedGetDataRequests.put("Computation1", 
makeGetDataHeartbeatRequest(computation1Keys));
+    expectedKeyedGetDataRequests.put("Computation2", 
makeGetDataHeartbeatRequest(computation2Keys));
+    Map<String, List<HeartbeatRequest>> heartbeatsToRefresh = new HashMap<>();
+    heartbeatsToRefresh.put("Computation1", 
makeHeartbeatRequest(computation1Keys));
+    heartbeatsToRefresh.put("Computation2", 
makeHeartbeatRequest(computation2Keys));
 
     GetDataStream stream = client.getDataStream();
-    stream.refreshActiveWork(activeMap);
+    stream.refreshActiveWork(heartbeatsToRefresh);
     stream.close();
     assertTrue(stream.awaitTermination(60, TimeUnit.SECONDS));
 
-    while (true) {
+    boolean receivedAllGetDataHeartbeats = false;
+    while (!receivedAllGetDataHeartbeats) {
       Thread.sleep(100);
-      synchronized (heartbeats) {
-        if (heartbeats.size() != activeMap.size()) {
+      synchronized (getDataHeartbeats) {
+        if (getDataHeartbeats.size() != expectedKeyedGetDataRequests.size()) {
           continue;
         }
-        assertEquals(heartbeats, activeMap);
-        break;
+        assertEquals(expectedKeyedGetDataRequests, getDataHeartbeats);
+        receivedAllGetDataHeartbeats = true;
+      }
+    }
+  }
+
+  @Test
+  public void testStreamingGetDataHeartbeatsAsHeartbeatRequests() throws 
Exception {
+    // Create a client and server different from the one in SetUp so we can 
add an experiment to the
+    // options passed in.
+    this.server =
+        InProcessServerBuilder.forName("TestServer")
+            .fallbackHandlerRegistry(serviceRegistry)
+            .executor(Executors.newFixedThreadPool(1))
+            .build()
+            .start();
+    this.client =
+        GrpcWindmillServer.newTestInstance(
+            "TestServer",
+            
Collections.singletonList("streaming_engine_send_new_heartbeat_requests"));
+    // This server records the heartbeats observed but doesn't respond.
+    final List<ComputationHeartbeatRequest> receivedHeartbeats = new 
ArrayList<>();
+
+    serviceRegistry.addService(
+        new CloudWindmillServiceV1Alpha1ImplBase() {
+          @Override
+          public StreamObserver<StreamingGetDataRequest> getDataStream(
+              StreamObserver<StreamingGetDataResponse> responseObserver) {
+            return new StreamObserver<StreamingGetDataRequest>() {
+              boolean sawHeader = false;
+
+              @Override
+              public void onNext(StreamingGetDataRequest chunk) {
+                try {
+                  if (!sawHeader) {
+                    LOG.info("Received header");
+                    errorCollector.checkThat(
+                        chunk.getHeader(),
+                        Matchers.equalTo(
+                            JobHeader.newBuilder()
+                                .setJobId("job")
+                                .setProjectId("project")
+                                .setWorkerId("worker")
+                                .build()));
+                    sawHeader = true;
+                  } else {
+                    LOG.info(
+                        "Received {} computationHeartbeatRequests",
+                        chunk.getComputationHeartbeatRequestCount());
+                    errorCollector.checkThat(
+                        chunk.getSerializedSize(), 
Matchers.lessThanOrEqualTo(STREAM_CHUNK_SIZE));
+                    errorCollector.checkThat(chunk.getRequestIdCount(), 
Matchers.is(0));
+
+                    synchronized (receivedHeartbeats) {
+                      
receivedHeartbeats.addAll(chunk.getComputationHeartbeatRequestList());
+                    }
+                  }
+                } catch (Exception e) {
+                  errorCollector.addError(e);
+                }
+              }
+
+              @Override
+              public void onError(Throwable throwable) {}
+
+              @Override
+              public void onCompleted() {
+                responseObserver.onCompleted();
+              }
+            };
+          }
+        });
+
+    List<String> computation1Keys = new ArrayList<>();
+    List<String> computation2Keys = new ArrayList<>();
+
+    // When sending heartbeats as HeartbeatRequest protos, all keys for the 
same computation should
+    // be batched into the same ComputationHeartbeatRequest. Compare to the 
KeyedGetDataRequest
+    // version in the test above, which only sends one key per 
ComputationGetDataRequest.
+    List<ComputationHeartbeatRequest> expectedHeartbeats = new ArrayList<>();
+    ComputationHeartbeatRequest.Builder comp1Builder =
+        
ComputationHeartbeatRequest.newBuilder().setComputationId("Computation1");
+    ComputationHeartbeatRequest.Builder comp2Builder =
+        
ComputationHeartbeatRequest.newBuilder().setComputationId("Computation2");
+    for (int i = 0; i < 100; ++i) {
+      String computation1Key = "Computation1Key" + i;
+      computation1Keys.add(computation1Key);
+      comp1Builder.addHeartbeatRequests(
+          
makeHeartbeatRequest(Collections.singletonList(computation1Key)).get(0));
+      String computation2Key = "Computation2Key" + largeString(i * 20);
+      computation2Keys.add(computation2Key);
+      comp2Builder.addHeartbeatRequests(
+          
makeHeartbeatRequest(Collections.singletonList(computation2Key)).get(0));
+    }
+    expectedHeartbeats.add(comp1Builder.build());
+    expectedHeartbeats.add(comp2Builder.build());
+    Map<String, List<HeartbeatRequest>> heartbeatRequestMap = new HashMap<>();
+    heartbeatRequestMap.put("Computation1", 
makeHeartbeatRequest(computation1Keys));
+    heartbeatRequestMap.put("Computation2", 
makeHeartbeatRequest(computation2Keys));
+
+    GetDataStream stream = client.getDataStream();
+    stream.refreshActiveWork(heartbeatRequestMap);
+    stream.close();
+    assertTrue(stream.awaitTermination(60, TimeUnit.SECONDS));
+
+    boolean receivedAllHeartbeatRequests = false;
+    while (!receivedAllHeartbeatRequests) {
+      Thread.sleep(100);
+      synchronized (receivedHeartbeats) {
+        if (receivedHeartbeats.size() != expectedHeartbeats.size()) {
+          continue;
+        }
+        assertEquals(expectedHeartbeats, receivedHeartbeats);
+        receivedAllHeartbeatRequests = true;
       }
     }
   }
@@ -888,7 +1023,7 @@ public class GrpcWindmillServerTest {
   @Test
   public void testThrottleSignal() throws Exception {
     // This server responds with work items until the throttleMessage limit is 
hit at which point it
-    // returns RESROUCE_EXHAUSTED errors for throttleTime msecs after which it 
resumes sending
+    // returns RESOURCE_EXHAUSTED errors for throttleTime msecs after which it 
resumes sending
     // work items.
     final int throttleTime = 2000;
     final int throttleMessage = 15;
diff --git 
a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
 
b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
index 6aaeb57001e..0c824ca301b 100644
--- 
a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
+++ 
b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
@@ -477,9 +477,10 @@ message GetWorkResponse {
 // GetData
 
 message KeyedGetDataRequest {
-  required bytes key = 1;
+  optional bytes key = 1;
   required fixed64 work_token = 2;
   optional fixed64 sharding_key = 6;
+  optional fixed64 cache_token = 11;
   repeated TagValue values_to_fetch = 3;
   repeated TagValuePrefixRequest tag_value_prefixes_to_fetch = 10;
   repeated TagBag bags_to_fetch = 8;
@@ -507,6 +508,8 @@ message GetDataRequest {
   // Assigned worker id for the instance.
   optional string worker_id = 6;
 
+  // SE only. Will only be set by a compatible client.
+  repeated ComputationHeartbeatRequest computation_heartbeat_request = 7;
   // DEPRECATED
   repeated GlobalDataId global_data_to_fetch = 2;
 }
@@ -536,6 +539,44 @@ message ComputationGetDataResponse {
 message GetDataResponse {
   repeated ComputationGetDataResponse data = 1;
   repeated GlobalData global_data = 2;
+  // Only set if ComputationHeartbeatRequest was sent; prior versions do not
+  // expect a response for heartbeats. SE only.
+  repeated ComputationHeartbeatResponse computation_heartbeat_response = 3;
+}
+
+// Heartbeats
+//
+// Heartbeats are sent over the GetData stream in Streaming Engine and
+// indicate the work items that the user worker has previously received from
+// GetWork but not yet committed with CommitWork.
+// Note that implicit heartbeats not expecting a response may be sent as
+// special KeyedGetDataRequests; see function KeyedGetDataRequestIsHeartbeat.
+// SE only.
+message HeartbeatRequest {
+  optional fixed64 sharding_key = 1;
+  optional fixed64 work_token = 2;
+  optional fixed64 cache_token = 3;
+  repeated LatencyAttribution latency_attribution = 4;
+}
+
+// Responses for heartbeat requests, indicating which work is no longer valid
+// on the windmill worker and may be dropped/cancelled in the client.
+// SE only.
+message HeartbeatResponse {
+  optional fixed64 sharding_key = 1;
+  optional fixed64 work_token = 2;
+  optional fixed64 cache_token = 3;
+  optional bool failed = 4;
+}
+
+message ComputationHeartbeatRequest {
+  optional string computation_id = 1;
+  repeated HeartbeatRequest heartbeat_requests = 2;
+}
+
+message ComputationHeartbeatResponse {
+  optional string computation_id = 1;
+  repeated HeartbeatResponse heartbeat_responses = 2;
 }
 
 // CommitWork
@@ -772,6 +813,8 @@ message StreamingGetDataRequest {
   repeated fixed64 request_id = 1;
   repeated GlobalDataRequest global_data_request = 3;
   repeated ComputationGetDataRequest state_request = 4;
+  // Will only be set by a compatible client.
+  repeated ComputationHeartbeatRequest computation_heartbeat_request = 5;
 }
 
 message StreamingGetDataResponse {
@@ -784,6 +827,12 @@ message StreamingGetDataResponse {
   repeated bytes serialized_response = 2;
   // Remaining bytes field applies only to the last serialized_response
   optional int64 remaining_bytes_for_response = 3;
+
+  // Only set if ComputationHeartbeatRequest was sent; prior versions do not
+  // expect a response for heartbeats.
+  repeated ComputationHeartbeatResponse computation_heartbeat_response = 5;
+
+  reserved 4;
 }
 
 message StreamingCommitWorkRequest {

Reply via email to