This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 514b03bebcc check for cachetoken representing a retry before activating and completing work (#29082)
514b03bebcc is described below

commit 514b03bebcc1addf39d2250c09c6aa42ee68b3db
Author: martin trieu <[email protected]>
AuthorDate: Tue Feb 13 04:41:41 2024 -0800

    check for cachetoken representing a retry before activating and completing work (#29082)
---
 .../dataflow/worker/StreamingDataflowWorker.java   |  40 ++--
 .../dataflow/worker/streaming/ActiveWorkState.java | 161 +++++++--------
 .../worker/streaming/ComputationState.java         |  14 +-
 .../runners/dataflow/worker/streaming/Work.java    |  10 +
 .../runners/dataflow/worker/streaming/WorkId.java  |  48 +++++
 .../worker/StreamingDataflowWorkerTest.java        |  94 +++++++--
 .../worker/streaming/ActiveWorkStateTest.java      | 216 ++++++++++++++++-----
 7 files changed, 418 insertions(+), 165 deletions(-)

diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index 2f9e18cde67..4d2ef6a03cf 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -87,7 +87,6 @@ import org.apache.beam.runners.dataflow.worker.status.DebugCapture.Capturable;
 import 
org.apache.beam.runners.dataflow.worker.status.LastExceptionDataProvider;
 import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
 import org.apache.beam.runners.dataflow.worker.status.WorkerStatusPages;
-import 
org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.FailedTokens;
 import org.apache.beam.runners.dataflow.worker.streaming.Commit;
 import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
 import org.apache.beam.runners.dataflow.worker.streaming.ExecutionState;
@@ -97,6 +96,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.StageInfo;
 import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue;
 import org.apache.beam.runners.dataflow.worker.streaming.Work;
 import org.apache.beam.runners.dataflow.worker.streaming.Work.State;
+import org.apache.beam.runners.dataflow.worker.streaming.WorkId;
 import 
org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher;
 import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
 import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor;
@@ -104,6 +104,7 @@ import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter
 import 
org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter;
 import 
org.apache.beam.runners.dataflow.worker.util.common.worker.ReadOperation;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
@@ -1311,7 +1312,7 @@ public class StreamingDataflowWorker {
         // Consider the item invalid. It will eventually be retried by 
Windmill if it still needs to
         // be processed.
         computationState.completeWorkAndScheduleNextWorkForKey(
-            ShardedKey.create(key, workItem.getShardingKey()), 
workItem.getWorkToken());
+            ShardedKey.create(key, workItem.getShardingKey()), work.id());
       }
     } finally {
       // Update total processing time counters. Updating in finally clause 
ensures that
@@ -1389,7 +1390,10 @@ public class StreamingDataflowWorker {
         for (Windmill.WorkItemCommitRequest workRequest : 
entry.getValue().getRequestsList()) {
           computationState.completeWorkAndScheduleNextWorkForKey(
               ShardedKey.create(workRequest.getKey(), 
workRequest.getShardingKey()),
-              workRequest.getWorkToken());
+              WorkId.builder()
+                  .setCacheToken(workRequest.getCacheToken())
+                  .setWorkToken(workRequest.getWorkToken())
+                  .build());
         }
       }
     }
@@ -1409,7 +1413,11 @@ public class StreamingDataflowWorker {
           .forComputation(state.getComputationId())
           .invalidate(request.getKey(), request.getShardingKey());
       state.completeWorkAndScheduleNextWorkForKey(
-          ShardedKey.create(request.getKey(), request.getShardingKey()), 
request.getWorkToken());
+          ShardedKey.create(request.getKey(), request.getShardingKey()),
+          WorkId.builder()
+              .setWorkToken(request.getWorkToken())
+              .setCacheToken(request.getCacheToken())
+              .build());
       return true;
     }
 
@@ -1431,7 +1439,10 @@ public class StreamingDataflowWorker {
           activeCommitBytes.addAndGet(-size);
           state.completeWorkAndScheduleNextWorkForKey(
               ShardedKey.create(request.getKey(), request.getShardingKey()),
-              request.getWorkToken());
+              WorkId.builder()
+                  .setCacheToken(request.getCacheToken())
+                  .setWorkToken(request.getWorkToken())
+                  .build());
         })) {
       return true;
     } else {
@@ -1963,20 +1974,19 @@ public class StreamingDataflowWorker {
     }
   }
 
-  public void 
handleHeartbeatResponses(List<Windmill.ComputationHeartbeatResponse> responses) 
{
-    for (Windmill.ComputationHeartbeatResponse computationHeartbeatResponse : 
responses) {
+  public void handleHeartbeatResponses(List<ComputationHeartbeatResponse> 
responses) {
+    for (ComputationHeartbeatResponse computationHeartbeatResponse : 
responses) {
       // Maps sharding key to (work token, cache token) for work that should 
be marked failed.
-      Map<Long, List<FailedTokens>> failedWork = new HashMap<>();
+      Multimap<Long, WorkId> failedWork = ArrayListMultimap.create();
       for (Windmill.HeartbeatResponse heartbeatResponse :
           computationHeartbeatResponse.getHeartbeatResponsesList()) {
         if (heartbeatResponse.getFailed()) {
-          failedWork
-              .computeIfAbsent(heartbeatResponse.getShardingKey(), key -> new 
ArrayList<>())
-              .add(
-                  FailedTokens.newBuilder()
-                      .setWorkToken(heartbeatResponse.getWorkToken())
-                      .setCacheToken(heartbeatResponse.getCacheToken())
-                      .build());
+          failedWork.put(
+              heartbeatResponse.getShardingKey(),
+              WorkId.builder()
+                  .setWorkToken(heartbeatResponse.getWorkToken())
+                  .setCacheToken(heartbeatResponse.getCacheToken())
+                  .build());
         }
       }
       ComputationState state = 
computationMap.get(computationHeartbeatResponse.getComputationId());
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
index b4b46932393..a989206408e 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
@@ -19,12 +19,12 @@ package org.apache.beam.runners.dataflow.worker.streaming;
 
 import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
 
-import com.google.auto.value.AutoValue;
 import java.io.PrintWriter;
 import java.util.ArrayDeque;
+import java.util.Collection;
 import java.util.Deque;
 import java.util.HashMap;
-import java.util.List;
+import java.util.Iterator;
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Optional;
@@ -46,6 +46,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.Vi
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 import org.slf4j.Logger;
@@ -60,7 +61,7 @@ import org.slf4j.LoggerFactory;
 public final class ActiveWorkState {
   private static final Logger LOG = 
LoggerFactory.getLogger(ActiveWorkState.class);
 
-  /* The max number of keys in COMMITTING or COMMIT_QUEUED status to be 
shown.*/
+  /* The max number of keys in COMMITTING or COMMIT_QUEUED status to be shown 
for observability.*/
   private static final int MAX_PRINTABLE_COMMIT_PENDING_KEYS = 50;
 
   /**
@@ -76,7 +77,7 @@ public final class ActiveWorkState {
   /**
    * Current budget that is being processed or queued on the user worker. 
Incremented when work is
    * activated in {@link #activateWorkForKey(ShardedKey, Work)}, and 
decremented when work is
-   * completed in {@link #completeWorkAndGetNextWorkForKey(ShardedKey, long)}.
+   * completed in {@link #completeWorkAndGetNextWorkForKey(ShardedKey, 
WorkId)}.
    */
   private final AtomicReference<GetWorkBudget> activeGetWorkBudget;
 
@@ -105,8 +106,31 @@ public final class ActiveWorkState {
     return activeFor.toString().substring(2);
   }
 
+  private static Stream<HeartbeatRequest> toHeartbeatRequestStream(
+      Entry<ShardedKey, Deque<Work>> shardedKeyAndWorkQueue,
+      Instant refreshDeadline,
+      DataflowExecutionStateSampler sampler) {
+    ShardedKey shardedKey = shardedKeyAndWorkQueue.getKey();
+    Deque<Work> workQueue = shardedKeyAndWorkQueue.getValue();
+
+    return workQueue.stream()
+        .filter(work -> work.getStartTime().isBefore(refreshDeadline))
+        // Don't send heartbeats for queued work we already know is failed.
+        .filter(work -> !work.isFailed())
+        .map(
+            work ->
+                Windmill.HeartbeatRequest.newBuilder()
+                    .setShardingKey(shardedKey.shardingKey())
+                    .setWorkToken(work.getWorkItem().getWorkToken())
+                    .setCacheToken(work.getWorkItem().getCacheToken())
+                    .addAllLatencyAttribution(
+                        work.getLatencyAttributions(
+                            /* isHeartbeat= */ true, 
work.getLatencyTrackingId(), sampler))
+                    .build());
+  }
+
   /**
-   * Activates {@link Work} for the {@link ShardedKey}. Outcome can be 1 of 3 
{@link
+   * Activates {@link Work} for the {@link ShardedKey}. Outcome can be 1 of 4 
{@link
    * ActivateWorkResult}
    *
    * <p>1. EXECUTE: The {@link ShardedKey} has not been seen before, create a 
{@link Queue<Work>}
@@ -116,7 +140,11 @@ public final class ActiveWorkState {
    * the {@link ShardedKey}'s work queue, mark the {@link Work} as a duplicate.
    *
    * <p>3. QUEUED: A work queue for the {@link ShardedKey} exists, and the 
work is not in the key's
-   * work queue, queue the work for later processing.
+   * work queue, OR the work in the work queue is stale, OR the work in the 
queue has a matching
+   * work token but different cache token, queue the work for later processing.
+   *
+   * <p>4. STALE: A work queue for the {@link ShardedKey} exists, and there is 
a queued {@link Work}
+   * with a greater workToken than the passed in {@link Work}.
    */
   synchronized ActivateWorkResult activateWorkForKey(ShardedKey shardedKey, 
Work work) {
     Deque<Work> workQueue = activeWork.getOrDefault(shardedKey, new 
ArrayDeque<>());
@@ -129,11 +157,26 @@ public final class ActiveWorkState {
       return ActivateWorkResult.EXECUTE;
     }
 
-    // Ensure we don't already have this work token queued.
-    for (Work queuedWork : workQueue) {
-      if (queuedWork.getWorkItem().getWorkToken() == 
work.getWorkItem().getWorkToken()) {
+    // Check to see if we have this work token queued.
+    Iterator<Work> workIterator = workQueue.iterator();
+    while (workIterator.hasNext()) {
+      Work queuedWork = workIterator.next();
+      if (queuedWork.id().equals(work.id())) {
         return ActivateWorkResult.DUPLICATE;
       }
+      if (queuedWork.id().cacheToken() == work.id().cacheToken()) {
+        if (work.id().workToken() > queuedWork.id().workToken()) {
+          // Check to see if the queuedWork is active. We only want to remove 
it if it is NOT
+          // currently active.
+          if (!queuedWork.equals(workQueue.peek())) {
+            workIterator.remove();
+            decrementActiveWorkBudget(queuedWork);
+          }
+          // Continue here to possibly remove more non-active stale work that 
is queued.
+        } else {
+          return ActivateWorkResult.STALE;
+        }
+      }
     }
 
     // Queue the work for later processing.
@@ -142,51 +185,30 @@ public final class ActiveWorkState {
     return ActivateWorkResult.QUEUED;
   }
 
-  @AutoValue
-  public abstract static class FailedTokens {
-    public static Builder newBuilder() {
-      return new AutoValue_ActiveWorkState_FailedTokens.Builder();
-    }
-
-    public abstract long workToken();
-
-    public abstract long cacheToken();
-
-    @AutoValue.Builder
-    public abstract static class Builder {
-      public abstract Builder setWorkToken(long value);
-
-      public abstract Builder setCacheToken(long value);
-
-      public abstract FailedTokens build();
-    }
-  }
-
   /**
    * Fails any active work matching an element of the input Map.
    *
    * @param failedWork a map from sharding_key to tokens for the corresponding 
work.
    */
-  synchronized void failWorkForKey(Map<Long, List<FailedTokens>> failedWork) {
+  synchronized void failWorkForKey(Multimap<Long, WorkId> failedWork) {
     // Note we can't construct a ShardedKey and look it up in activeWork 
directly since
     // HeartbeatResponse doesn't include the user key.
     for (Entry<ShardedKey, Deque<Work>> entry : activeWork.entrySet()) {
-      List<FailedTokens> failedTokens = 
failedWork.get(entry.getKey().shardingKey());
-      if (failedTokens == null) continue;
-      for (FailedTokens failedToken : failedTokens) {
+      Collection<WorkId> failedWorkIds = 
failedWork.get(entry.getKey().shardingKey());
+      for (WorkId failedWorkId : failedWorkIds) {
         for (Work queuedWork : entry.getValue()) {
           WorkItem workItem = queuedWork.getWorkItem();
-          if (workItem.getWorkToken() == failedToken.workToken()
-              && workItem.getCacheToken() == failedToken.cacheToken()) {
+          if (workItem.getWorkToken() == failedWorkId.workToken()
+              && workItem.getCacheToken() == failedWorkId.cacheToken()) {
             LOG.debug(
                 "Failing work "
                     + computationStateCache.getComputation()
                     + " "
                     + entry.getKey().shardingKey()
                     + " "
-                    + failedToken.workToken()
+                    + failedWorkId.workToken()
                     + " "
-                    + failedToken.cacheToken()
+                    + failedWorkId.cacheToken()
                     + ". The work will be retried and is not lost.");
             queuedWork.setFailed();
             break;
@@ -213,34 +235,38 @@ public final class ActiveWorkState {
    * #activeWork}.
    */
   synchronized Optional<Work> completeWorkAndGetNextWorkForKey(
-      ShardedKey shardedKey, long workToken) {
+      ShardedKey shardedKey, WorkId workId) {
     @Nullable Queue<Work> workQueue = activeWork.get(shardedKey);
     if (workQueue == null) {
       // Work may have been completed due to clearing of stuck commits.
-      LOG.warn("Unable to complete inactive work for key {} and token {}.", 
shardedKey, workToken);
+      LOG.warn("Unable to complete inactive work for key {} and token {}.", 
shardedKey, workId);
       return Optional.empty();
     }
-    removeCompletedWorkFromQueue(workQueue, shardedKey, workToken);
+    removeCompletedWorkFromQueue(workQueue, shardedKey, workId);
     return getNextWork(workQueue, shardedKey);
   }
 
   private synchronized void removeCompletedWorkFromQueue(
-      Queue<Work> workQueue, ShardedKey shardedKey, long workToken) {
+      Queue<Work> workQueue, ShardedKey shardedKey, WorkId workId) {
+    // avoid Preconditions.checkState here to prevent eagerly evaluating the
+    // format string parameters for the error message.
     Work completedWork = workQueue.peek();
     if (completedWork == null) {
       // Work may have been completed due to clearing of stuck commits.
-      LOG.warn(
-          String.format("Active key %s without work, expected token %d", 
shardedKey, workToken));
+      LOG.warn("Active key {} without work, expected token {}", shardedKey, 
workId);
       return;
     }
 
-    if (completedWork.getWorkItem().getWorkToken() != workToken) {
+    if (!completedWork.id().equals(workId)) {
       // Work may have been completed due to clearing of stuck commits.
       LOG.warn(
-          "Unable to complete due to token mismatch for key {} and token {}, 
actual token was {}.",
+          "Unable to complete due to token mismatch for "
+              + "key {},"
+              + "expected work_id {}, "
+              + "actual work_id was {}",
           shardedKey,
-          workToken,
-          completedWork.getWorkItem().getWorkToken());
+          workId,
+          completedWork.id());
       return;
     }
 
@@ -263,21 +289,21 @@ public final class ActiveWorkState {
    * before the stuckCommitDeadline.
    */
   synchronized void invalidateStuckCommits(
-      Instant stuckCommitDeadline, BiConsumer<ShardedKey, Long> 
shardedKeyAndWorkTokenConsumer) {
-    for (Entry<ShardedKey, Long> shardedKeyAndWorkToken :
+      Instant stuckCommitDeadline, BiConsumer<ShardedKey, WorkId> 
shardedKeyAndWorkTokenConsumer) {
+    for (Entry<ShardedKey, WorkId> shardedKeyAndWorkId :
         getStuckCommitsAt(stuckCommitDeadline).entrySet()) {
-      ShardedKey shardedKey = shardedKeyAndWorkToken.getKey();
-      long workToken = shardedKeyAndWorkToken.getValue();
+      ShardedKey shardedKey = shardedKeyAndWorkId.getKey();
+      WorkId workId = shardedKeyAndWorkId.getValue();
       computationStateCache.invalidate(shardedKey.key(), 
shardedKey.shardingKey());
-      shardedKeyAndWorkTokenConsumer.accept(shardedKey, workToken);
+      shardedKeyAndWorkTokenConsumer.accept(shardedKey, workId);
     }
   }
 
-  private synchronized ImmutableMap<ShardedKey, Long> getStuckCommitsAt(
+  private synchronized ImmutableMap<ShardedKey, WorkId> getStuckCommitsAt(
       Instant stuckCommitDeadline) {
     // Determine the stuck commit keys but complete them outside the loop 
iterating over
     // activeWork as completeWork may delete the entry from activeWork.
-    ImmutableMap.Builder<ShardedKey, Long> stuckCommits = 
ImmutableMap.builder();
+    ImmutableMap.Builder<ShardedKey, WorkId> stuckCommits = 
ImmutableMap.builder();
     for (Entry<ShardedKey, Deque<Work>> entry : activeWork.entrySet()) {
       ShardedKey shardedKey = entry.getKey();
       @Nullable Work work = entry.getValue().peek();
@@ -287,7 +313,7 @@ public final class ActiveWorkState {
               "Detected key {} stuck in COMMITTING state since {}, completing 
it with error.",
               shardedKey,
               work.getStateStartTime());
-          stuckCommits.put(shardedKey, work.getWorkItem().getWorkToken());
+          stuckCommits.put(shardedKey, work.id());
         }
       }
     }
@@ -302,28 +328,6 @@ public final class ActiveWorkState {
         .collect(toImmutableList());
   }
 
-  private static Stream<HeartbeatRequest> toHeartbeatRequestStream(
-      Entry<ShardedKey, Deque<Work>> shardedKeyAndWorkQueue,
-      Instant refreshDeadline,
-      DataflowExecutionStateSampler sampler) {
-    ShardedKey shardedKey = shardedKeyAndWorkQueue.getKey();
-    Deque<Work> workQueue = shardedKeyAndWorkQueue.getValue();
-
-    return workQueue.stream()
-        .filter(work -> work.getStartTime().isBefore(refreshDeadline))
-        // Don't send heartbeats for queued work we already know is failed.
-        .filter(work -> !work.isFailed())
-        .map(
-            work ->
-                Windmill.HeartbeatRequest.newBuilder()
-                    .setShardingKey(shardedKey.shardingKey())
-                    .setWorkToken(work.getWorkItem().getWorkToken())
-                    .setCacheToken(work.getWorkItem().getCacheToken())
-                    .addAllLatencyAttribution(
-                        work.getLatencyAttributions(true, 
work.getLatencyTrackingId(), sampler))
-                    .build());
-  }
-
   /**
    * Returns the current aggregate {@link GetWorkBudget} that is active on the 
user worker. Active
    * means that the work is received from Windmill, being processed or queued 
to be processed in
@@ -386,6 +390,7 @@ public final class ActiveWorkState {
   enum ActivateWorkResult {
     QUEUED,
     EXECUTE,
-    DUPLICATE
+    DUPLICATE,
+    STALE
   }
 }
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
index 8207a6ef2f0..33ef4950f9a 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
@@ -19,18 +19,17 @@ package org.apache.beam.runners.dataflow.worker.streaming;
 
 import com.google.api.services.dataflow.model.MapTask;
 import java.io.PrintWriter;
-import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import javax.annotation.Nullable;
 import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
-import 
org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.FailedTokens;
 import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap;
 import org.joda.time.Instant;
 
 /**
@@ -81,11 +80,14 @@ public class ComputationState implements AutoCloseable {
 
   /**
    * Mark the given {@link ShardedKey} and {@link Work} as active, and 
schedules execution of {@link
-   * Work} if there is no active {@link Work} for the {@link ShardedKey} 
already processing.
+   * Work} if there is no active {@link Work} for the {@link ShardedKey} 
already processing. Returns
+   * whether the {@link Work} will be activated, either immediately or 
sometime in the future.
    */
   public boolean activateWork(ShardedKey shardedKey, Work work) {
     switch (activeWorkState.activateWorkForKey(shardedKey, work)) {
       case DUPLICATE:
+        // Fall through intentionally. Work was not and will not be activated 
in these cases.
+      case STALE:
         return false;
       case QUEUED:
         return true;
@@ -100,16 +102,16 @@ public class ComputationState implements AutoCloseable {
     }
   }
 
-  public void failWork(Map<Long, List<FailedTokens>> failedWork) {
+  public void failWork(Multimap<Long, WorkId> failedWork) {
     activeWorkState.failWorkForKey(failedWork);
   }
 
   /**
    * Marks the work for the given shardedKey as complete. Schedules queued 
work for the key if any.
    */
-  public void completeWorkAndScheduleNextWorkForKey(ShardedKey shardedKey, 
long workToken) {
+  public void completeWorkAndScheduleNextWorkForKey(ShardedKey shardedKey, 
WorkId workId) {
     activeWorkState
-        .completeWorkAndGetNextWorkForKey(shardedKey, workToken)
+        .completeWorkAndGetNextWorkForKey(shardedKey, workId)
         .ifPresent(this::forceExecute);
   }
 
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
index 6c85c615af1..99cdaad200e 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
@@ -47,6 +47,7 @@ public class Work implements Runnable {
   private final Instant startTime;
   private final Map<Windmill.LatencyAttribution.State, Duration> 
totalDurationPerState;
   private final Consumer<Work> processWorkFn;
+  private final WorkId id;
   private TimedState currentState;
   private volatile boolean isFailed;
 
@@ -58,6 +59,11 @@ public class Work implements Runnable {
     this.totalDurationPerState = new 
EnumMap<>(Windmill.LatencyAttribution.State.class);
     this.currentState = TimedState.initialState(startTime);
     this.isFailed = false;
+    this.id =
+        WorkId.builder()
+            .setCacheToken(workItem.getCacheToken())
+            .setWorkToken(workItem.getWorkToken())
+            .build();
   }
 
   public static Work create(
@@ -116,6 +122,10 @@ public class Work implements Runnable {
     return workIdBuilder.toString();
   }
 
+  public WorkId id() {
+    return id;
+  }
+
   private void recordGetWorkStreamLatencies(
       Collection<Windmill.LatencyAttribution> getWorkStreamLatencies) {
     for (Windmill.LatencyAttribution latency : getWorkStreamLatencies) {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkId.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkId.java
new file mode 100644
index 00000000000..d56b56c184c
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkId.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming;
+
+import com.google.auto.value.AutoValue;
+
+/**
+ * A composite key used to identify a unit of {@link Work}. If multiple units 
of {@link Work} have
+ * the same workToken AND cacheToken, the {@link Work} is a duplicate. If 
multiple units of {@link
+ * Work} have the same workToken, but different cacheTokens, the {@link Work} 
is a retry. If
+ * multiple units of {@link Work} have the same cacheToken, but different 
workTokens, the {@link
+ * Work} is obsolete.
+ */
+@AutoValue
+public abstract class WorkId {
+
+  public static Builder builder() {
+    return new AutoValue_WorkId.Builder();
+  }
+
+  abstract long cacheToken();
+
+  abstract long workToken();
+
+  @AutoValue.Builder
+  public abstract static class Builder {
+    public abstract Builder setCacheToken(long value);
+
+    public abstract Builder setWorkToken(long value);
+
+    public abstract WorkId build();
+  }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index e7eedcf3780..1035fada0ff 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -556,7 +556,8 @@ public class StreamingDataflowWorkerTest {
             + shardingKey
             + "    work_token: "
             + index
-            + "    cache_token: 3"
+            + "    cache_token: "
+            + (index + 1)
             + "    hot_key_info {"
             + "      hot_key_age_usec: 1000000"
             + "    }"
@@ -579,6 +580,47 @@ public class StreamingDataflowWorkerTest {
             Collections.singletonList(DEFAULT_WINDOW)));
   }
 
+  private Windmill.GetWorkResponse makeInput(
+      int workToken, int cacheToken, long timestamp, String key, long 
shardingKey)
+      throws Exception {
+    return buildInput(
+        "work {"
+            + "  computation_id: \""
+            + DEFAULT_COMPUTATION_ID
+            + "\""
+            + "  input_data_watermark: 0"
+            + "  work {"
+            + "    key: \""
+            + key
+            + "\""
+            + "    sharding_key: "
+            + shardingKey
+            + "    work_token: "
+            + workToken
+            + "    cache_token: "
+            + cacheToken
+            + "    hot_key_info {"
+            + "      hot_key_age_usec: 1000000"
+            + "    }"
+            + "    message_bundles {"
+            + "      source_computation_id: \""
+            + DEFAULT_SOURCE_COMPUTATION_ID
+            + "\""
+            + "      messages {"
+            + "        timestamp: "
+            + timestamp
+            + "        data: \"data"
+            + workToken
+            + "\""
+            + "      }"
+            + "    }"
+            + "  }"
+            + "}",
+        CoderUtils.encodeToByteArray(
+            CollectionCoder.of(IntervalWindow.getCoder()),
+            Collections.singletonList(DEFAULT_WINDOW)));
+  }
+
   /**
    * Returns a {@link
    * 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest}
 builder parsed
@@ -655,7 +697,9 @@ public class StreamingDataflowWorkerTest {
     requestBuilder.append("work_token: ");
     requestBuilder.append(index);
     requestBuilder.append(" ");
-    requestBuilder.append("cache_token: 3 ");
+    requestBuilder.append("cache_token: ");
+    requestBuilder.append(index + 1);
+    requestBuilder.append(" ");
     if (hasSourceBytesProcessed) 
requestBuilder.append("source_bytes_processed: 0 ");
 
     return requestBuilder;
@@ -2677,7 +2721,7 @@ public class StreamingDataflowWorkerTest {
     Work m1 = createMockWork(1);
     assertTrue(computationState.activateWork(key1, m1));
     Mockito.verify(mockExecutor).execute(m1, 
m1.getWorkItem().getSerializedSize());
-    computationState.completeWorkAndScheduleNextWorkForKey(key1, 1);
+    computationState.completeWorkAndScheduleNextWorkForKey(key1, m1.id());
     Mockito.verifyNoMoreInteractions(mockExecutor);
 
     // Verify work queues.
@@ -2692,12 +2736,12 @@ public class StreamingDataflowWorkerTest {
     Work m4 = createMockWork(4);
     assertTrue(computationState.activateWork(key2, m4));
     Mockito.verify(mockExecutor).execute(m4, 
m4.getWorkItem().getSerializedSize());
-    computationState.completeWorkAndScheduleNextWorkForKey(key2, 4);
+    computationState.completeWorkAndScheduleNextWorkForKey(key2, m4.id());
     Mockito.verifyNoMoreInteractions(mockExecutor);
 
-    computationState.completeWorkAndScheduleNextWorkForKey(key1, 2);
+    computationState.completeWorkAndScheduleNextWorkForKey(key1, m2.id());
     Mockito.verify(mockExecutor).forceExecute(m3, 
m3.getWorkItem().getSerializedSize());
-    computationState.completeWorkAndScheduleNextWorkForKey(key1, 3);
+    computationState.completeWorkAndScheduleNextWorkForKey(key1, m3.id());
     Mockito.verifyNoMoreInteractions(mockExecutor);
 
     // Verify duplicate work dropped.
@@ -2706,7 +2750,7 @@ public class StreamingDataflowWorkerTest {
     Mockito.verify(mockExecutor).execute(m5, 
m5.getWorkItem().getSerializedSize());
     assertFalse(computationState.activateWork(key1, m5));
     Mockito.verifyNoMoreInteractions(mockExecutor);
-    computationState.completeWorkAndScheduleNextWorkForKey(key1, 5);
+    computationState.completeWorkAndScheduleNextWorkForKey(key1, m5.id());
     Mockito.verifyNoMoreInteractions(mockExecutor);
   }
 
@@ -2727,7 +2771,7 @@ public class StreamingDataflowWorkerTest {
     Work m1 = createMockWork(1);
     assertTrue(computationState.activateWork(key1Shard1, m1));
     Mockito.verify(mockExecutor).execute(m1, 
m1.getWorkItem().getSerializedSize());
-    computationState.completeWorkAndScheduleNextWorkForKey(key1Shard1, 1);
+    computationState.completeWorkAndScheduleNextWorkForKey(key1Shard1, 
m1.id());
     Mockito.verifyNoMoreInteractions(mockExecutor);
 
     // Verify work queues.
@@ -2747,7 +2791,7 @@ public class StreamingDataflowWorkerTest {
 
     // Verify duplicate work dropped
     assertFalse(computationState.activateWork(key1Shard2, m4));
-    computationState.completeWorkAndScheduleNextWorkForKey(key1Shard2, 3);
+    computationState.completeWorkAndScheduleNextWorkForKey(key1Shard2, 
m4.id());
     Mockito.verifyNoMoreInteractions(mockExecutor);
   }
 
@@ -3286,11 +3330,20 @@ public class StreamingDataflowWorkerTest {
     StreamingDataflowWorker worker = makeWorker(instructions, options, true /* 
publishCounters */);
     worker.start();
 
+    GetWorkResponse workItem =
+        makeInput(0, TimeUnit.MILLISECONDS.toMicros(0), "key", 
DEFAULT_SHARDING_KEY);
+    int failedWorkToken = 1;
+    int failedCacheToken = 5;
+    GetWorkResponse workItemToFail =
+        makeInput(
+            failedWorkToken,
+            failedCacheToken,
+            TimeUnit.MILLISECONDS.toMicros(0),
+            "key",
+            DEFAULT_SHARDING_KEY);
+
     // Queue up two work items for the same key.
-    server
-        .whenGetWorkCalled()
-        .thenReturn(makeInput(0, TimeUnit.MILLISECONDS.toMicros(0), "key", 
DEFAULT_SHARDING_KEY))
-        .thenReturn(makeInput(1, TimeUnit.MILLISECONDS.toMicros(0), "key", 
DEFAULT_SHARDING_KEY));
+    server.whenGetWorkCalled().thenReturn(workItem).thenReturn(workItemToFail);
     server.waitForEmptyWorkQueue();
 
     // Mock Windmill sending a heartbeat response failing the second work item 
while the first
@@ -3300,8 +3353,8 @@ public class StreamingDataflowWorkerTest {
     failedHeartbeat
         .setComputationId(DEFAULT_COMPUTATION_ID)
         .addHeartbeatResponsesBuilder()
-        .setCacheToken(3)
-        .setWorkToken(1)
+        .setCacheToken(failedCacheToken)
+        .setWorkToken(failedWorkToken)
         .setShardingKey(DEFAULT_SHARDING_KEY)
         .setFailed(true);
     
server.sendFailedHeartbeats(Collections.singletonList(failedHeartbeat.build()));
@@ -3318,7 +3371,16 @@ public class StreamingDataflowWorkerTest {
   @Test
   public void testLatencyAttributionProtobufsPopulated() {
     FakeClock clock = new FakeClock();
-    Work work = Work.create(null, clock, Collections.emptyList(), unused -> 
{});
+    Work work =
+        Work.create(
+            Windmill.WorkItem.newBuilder()
+                .setKey(ByteString.EMPTY)
+                .setWorkToken(1L)
+                .setCacheToken(1L)
+                .build(),
+            clock,
+            Collections.emptyList(),
+            unused -> {});
 
     clock.sleep(Duration.millis(10));
     work.setState(Work.State.PROCESSING);
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
index 82ff24c03bb..c581638d98b 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
@@ -69,11 +69,16 @@ public class ActiveWorkStateTest {
     return Work.create(workItem, () -> Instant.EPOCH, Collections.emptyList(), 
unused -> {});
   }
 
-  private static Windmill.WorkItem createWorkItem(long workToken) {
+  private static WorkId workId(long workToken, long cacheToken) {
+    return 
WorkId.builder().setCacheToken(cacheToken).setWorkToken(workToken).build();
+  }
+
+  private static Windmill.WorkItem createWorkItem(long workToken, long 
cacheToken) {
     return Windmill.WorkItem.newBuilder()
         .setKey(ByteString.copyFromUtf8(""))
         .setShardingKey(1)
         .setWorkToken(workToken)
+        .setCacheToken(cacheToken)
         .build();
   }
 
@@ -89,7 +94,7 @@ public class ActiveWorkStateTest {
   public void testActivateWorkForKey_EXECUTE_unknownKey() {
     ActivateWorkResult activateWorkResult =
         activeWorkState.activateWorkForKey(
-            shardedKey("someKey", 1L), createWork(createWorkItem(1L)));
+            shardedKey("someKey", 1L), createWork(createWorkItem(1L, 1L)));
 
     assertEquals(ActivateWorkResult.EXECUTE, activateWorkResult);
   }
@@ -98,12 +103,14 @@ public class ActiveWorkStateTest {
   public void testActivateWorkForKey_EXECUTE_emptyWorkQueueForKey() {
     ShardedKey shardedKey = shardedKey("someKey", 1L);
     long workToken = 1L;
+    long cacheToken = 2L;
 
     ActivateWorkResult activateWorkResult =
-        activeWorkState.activateWorkForKey(shardedKey, 
createWork(createWorkItem(workToken)));
+        activeWorkState.activateWorkForKey(
+            shardedKey, createWork(createWorkItem(workToken, cacheToken)));
 
     Optional<Work> nextWorkForKey =
-        activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, 
workToken);
+        activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, 
workId(workToken, cacheToken));
 
     assertEquals(ActivateWorkResult.EXECUTE, activateWorkResult);
     assertEquals(Optional.empty(), nextWorkForKey);
@@ -116,9 +123,9 @@ public class ActiveWorkStateTest {
     ShardedKey shardedKey = shardedKey("someKey", 1L);
 
     // ActivateWork with the same shardedKey, and the same workTokens.
-    activeWorkState.activateWorkForKey(shardedKey, 
createWork(createWorkItem(workToken)));
+    activeWorkState.activateWorkForKey(shardedKey, 
createWork(createWorkItem(workToken, 1L)));
     ActivateWorkResult activateWorkResult =
-        activeWorkState.activateWorkForKey(shardedKey, 
createWork(createWorkItem(workToken)));
+        activeWorkState.activateWorkForKey(shardedKey, 
createWork(createWorkItem(workToken, 1L)));
 
     assertEquals(ActivateWorkResult.DUPLICATE, activateWorkResult);
   }
@@ -128,9 +135,9 @@ public class ActiveWorkStateTest {
     ShardedKey shardedKey = shardedKey("someKey", 1L);
 
     // ActivateWork with the same shardedKey, but different workTokens.
-    activeWorkState.activateWorkForKey(shardedKey, 
createWork(createWorkItem(1L)));
+    activeWorkState.activateWorkForKey(shardedKey, 
createWork(createWorkItem(1L, 1L)));
     ActivateWorkResult activateWorkResult =
-        activeWorkState.activateWorkForKey(shardedKey, 
createWork(createWorkItem(2L)));
+        activeWorkState.activateWorkForKey(shardedKey, 
createWork(createWorkItem(2L, 1L)));
 
     assertEquals(ActivateWorkResult.QUEUED, activateWorkResult);
   }
@@ -139,18 +146,22 @@ public class ActiveWorkStateTest {
   public void testCompleteWorkAndGetNextWorkForKey_noWorkQueueForKey() {
     assertEquals(
         Optional.empty(),
-        activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey("someKey", 
1L), 10L));
+        activeWorkState.completeWorkAndGetNextWorkForKey(
+            shardedKey("someKey", 1L), workId(1L, 1L)));
   }
 
   @Test
-  public void 
testCompleteWorkAndGetNextWorkForKey_currentWorkInQueueDoesNotMatchWorkToComplete()
 {
-    long workTokenToComplete = 1L;
-
-    Work workInQueue = createWork(createWorkItem(2L));
+  public void
+      
testCompleteWorkAndGetNextWorkForKey_currentWorkInQueueWorkTokenDoesNotMatchWorkToComplete()
 {
+    long workTokenInQueue = 2L;
+    long otherWorkToken = 1L;
+    long cacheToken = 1L;
+    Work workInQueue = createWork(createWorkItem(workTokenInQueue, 
cacheToken));
     ShardedKey shardedKey = shardedKey("someKey", 1L);
 
     activeWorkState.activateWorkForKey(shardedKey, workInQueue);
-    activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, 
workTokenToComplete);
+    activeWorkState.completeWorkAndGetNextWorkForKey(
+        shardedKey, workId(otherWorkToken, cacheToken));
 
     assertEquals(1, readOnlyActiveWork.get(shardedKey).size());
     assertEquals(workInQueue, readOnlyActiveWork.get(shardedKey).peek());
@@ -158,15 +169,13 @@ public class ActiveWorkStateTest {
 
   @Test
   public void 
testCompleteWorkAndGetNextWorkForKey_removesWorkFromQueueWhenComplete() {
-    long workTokenToComplete = 1L;
-
-    Work activeWork = createWork(createWorkItem(workTokenToComplete));
-    Work nextWork = createWork(createWorkItem(2L));
+    Work activeWork = createWork(createWorkItem(1L, 1L));
+    Work nextWork = createWork(createWorkItem(2L, 2L));
     ShardedKey shardedKey = shardedKey("someKey", 1L);
 
     activeWorkState.activateWorkForKey(shardedKey, activeWork);
     activeWorkState.activateWorkForKey(shardedKey, nextWork);
-    activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, 
workTokenToComplete);
+    activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, 
activeWork.id());
 
     assertEquals(nextWork, readOnlyActiveWork.get(shardedKey).peek());
     assertEquals(1, readOnlyActiveWork.get(shardedKey).size());
@@ -175,37 +184,33 @@ public class ActiveWorkStateTest {
 
   @Test
   public void 
testCompleteWorkAndGetNextWorkForKey_removesQueueIfNoWorkPresent() {
-    Work workInQueue = createWork(createWorkItem(1L));
+    Work workInQueue = createWork(createWorkItem(1L, 1L));
     ShardedKey shardedKey = shardedKey("someKey", 1L);
 
     activeWorkState.activateWorkForKey(shardedKey, workInQueue);
-    activeWorkState.completeWorkAndGetNextWorkForKey(
-        shardedKey, workInQueue.getWorkItem().getWorkToken());
+    activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, 
workInQueue.id());
 
     assertFalse(readOnlyActiveWork.containsKey(shardedKey));
   }
 
   @Test
   public void testCompleteWorkAndGetNextWorkForKey_returnsWorkIfPresent() {
-    Work workToBeCompleted = createWork(createWorkItem(1L));
-    Work nextWork = createWork(createWorkItem(2L));
+    Work workToBeCompleted = createWork(createWorkItem(1L, 1L));
+    Work nextWork = createWork(createWorkItem(2L, 2L));
     ShardedKey shardedKey = shardedKey("someKey", 1L);
 
     activeWorkState.activateWorkForKey(shardedKey, workToBeCompleted);
     activeWorkState.activateWorkForKey(shardedKey, nextWork);
-    activeWorkState.completeWorkAndGetNextWorkForKey(
-        shardedKey, workToBeCompleted.getWorkItem().getWorkToken());
+    activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, 
workToBeCompleted.id());
 
     Optional<Work> nextWorkOpt =
-        activeWorkState.completeWorkAndGetNextWorkForKey(
-            shardedKey, workToBeCompleted.getWorkItem().getWorkToken());
+        activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, 
workToBeCompleted.id());
 
     assertTrue(nextWorkOpt.isPresent());
     assertSame(nextWork, nextWorkOpt.get());
 
     Optional<Work> endOfWorkQueue =
-        activeWorkState.completeWorkAndGetNextWorkForKey(
-            shardedKey, nextWork.getWorkItem().getWorkToken());
+        activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, 
nextWork.id());
 
     assertFalse(endOfWorkQueue.isPresent());
     assertFalse(readOnlyActiveWork.containsKey(shardedKey));
@@ -214,8 +219,8 @@ public class ActiveWorkStateTest {
   @Test
   public void 
testCurrentActiveWorkBudget_correctlyAggregatesActiveWorkBudget_oneShardKey() {
     ShardedKey shardedKey = shardedKey("someKey", 1L);
-    Work work1 = createWork(createWorkItem(1L));
-    Work work2 = createWork(createWorkItem(2L));
+    Work work1 = createWork(createWorkItem(1L, 1L));
+    Work work2 = createWork(createWorkItem(2L, 2L));
 
     activeWorkState.activateWorkForKey(shardedKey, work1);
     activeWorkState.activateWorkForKey(shardedKey, work2);
@@ -229,8 +234,7 @@ public class ActiveWorkStateTest {
 
     
assertThat(activeWorkState.currentActiveWorkBudget()).isEqualTo(expectedActiveBudget1);
 
-    activeWorkState.completeWorkAndGetNextWorkForKey(
-        shardedKey, work1.getWorkItem().getWorkToken());
+    activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, work1.id());
 
     GetWorkBudget expectedActiveBudget2 =
         GetWorkBudget.builder()
@@ -244,13 +248,12 @@ public class ActiveWorkStateTest {
   @Test
   public void 
testCurrentActiveWorkBudget_correctlyAggregatesActiveWorkBudget_whenWorkCompleted()
 {
     ShardedKey shardedKey = shardedKey("someKey", 1L);
-    Work work1 = createWork(createWorkItem(1L));
-    Work work2 = createWork(createWorkItem(2L));
+    Work work1 = createWork(createWorkItem(1L, 1L));
+    Work work2 = createWork(createWorkItem(2L, 2L));
 
     activeWorkState.activateWorkForKey(shardedKey, work1);
     activeWorkState.activateWorkForKey(shardedKey, work2);
-    activeWorkState.completeWorkAndGetNextWorkForKey(
-        shardedKey, work1.getWorkItem().getWorkToken());
+    activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, work1.id());
 
     GetWorkBudget expectedActiveBudget =
         GetWorkBudget.builder()
@@ -265,8 +268,8 @@ public class ActiveWorkStateTest {
   public void 
testCurrentActiveWorkBudget_correctlyAggregatesActiveWorkBudget_multipleShardKeys()
 {
     ShardedKey shardedKey1 = shardedKey("someKey", 1L);
     ShardedKey shardedKey2 = shardedKey("someKey", 2L);
-    Work work1 = createWork(createWorkItem(1L));
-    Work work2 = createWork(createWorkItem(2L));
+    Work work1 = createWork(createWorkItem(1L, 1L));
+    Work work2 = createWork(createWorkItem(2L, 2L));
 
     activeWorkState.activateWorkForKey(shardedKey1, work1);
     activeWorkState.activateWorkForKey(shardedKey2, work2);
@@ -283,11 +286,11 @@ public class ActiveWorkStateTest {
 
   @Test
   public void testInvalidateStuckCommits() {
-    Map<ShardedKey, Long> invalidatedCommits = new HashMap<>();
+    Map<ShardedKey, WorkId> invalidatedCommits = new HashMap<>();
 
-    Work stuckWork1 = expiredWork(createWorkItem(1L));
+    Work stuckWork1 = expiredWork(createWorkItem(1L, 1L));
     stuckWork1.setState(Work.State.COMMITTING);
-    Work stuckWork2 = expiredWork(createWorkItem(2L));
+    Work stuckWork2 = expiredWork(createWorkItem(2L, 1L));
     stuckWork2.setState(Work.State.COMMITTING);
     ShardedKey shardedKey1 = shardedKey("someKey", 1L);
     ShardedKey shardedKey2 = shardedKey("anotherKey", 2L);
@@ -297,22 +300,135 @@ public class ActiveWorkStateTest {
 
     activeWorkState.invalidateStuckCommits(Instant.now(), 
invalidatedCommits::put);
 
-    assertThat(invalidatedCommits)
-        .containsEntry(shardedKey1, stuckWork1.getWorkItem().getWorkToken());
-    assertThat(invalidatedCommits)
-        .containsEntry(shardedKey2, stuckWork2.getWorkItem().getWorkToken());
+    assertThat(invalidatedCommits).containsEntry(shardedKey1, stuckWork1.id());
+    assertThat(invalidatedCommits).containsEntry(shardedKey2, stuckWork2.id());
     verify(computationStateCache).invalidate(shardedKey1.key(), 
shardedKey1.shardingKey());
     verify(computationStateCache).invalidate(shardedKey2.key(), 
shardedKey2.shardingKey());
   }
 
+  @Test
+  public void
+      
testActivateWorkForKey_withMatchingWorkTokenAndDifferentCacheToken_queuedWorkIsNotActive_QUEUED()
 {
+    long workToken = 10L;
+    long cacheToken1 = 5L;
+    long cacheToken2 = cacheToken1 + 2L;
+
+    Work firstWork = createWork(createWorkItem(workToken, cacheToken1));
+    Work secondWork = createWork(createWorkItem(workToken, cacheToken2));
+    Work differentWorkTokenWork = createWork(createWorkItem(1L, 1L));
+    ShardedKey shardedKey = shardedKey("someKey", 1L);
+
+    activeWorkState.activateWorkForKey(shardedKey, differentWorkTokenWork);
+    // ActivateWork with the same shardedKey, and the same workTokens, but 
different cacheTokens.
+    activeWorkState.activateWorkForKey(shardedKey, firstWork);
+    ActivateWorkResult activateWorkResult =
+        activeWorkState.activateWorkForKey(shardedKey, secondWork);
+
+    assertEquals(ActivateWorkResult.QUEUED, activateWorkResult);
+    assertTrue(readOnlyActiveWork.get(shardedKey).contains(secondWork));
+
+    Optional<Work> nextWork =
+        activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, 
differentWorkTokenWork.id());
+    assertTrue(nextWork.isPresent());
+    assertSame(firstWork, nextWork.get());
+    nextWork = activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, 
firstWork.id());
+    assertTrue(nextWork.isPresent());
+    assertSame(secondWork, nextWork.get());
+  }
+
+  @Test
+  public void
+      
testActivateWorkForKey_withMatchingWorkTokenAndDifferentCacheToken_queuedWorkIsActive_QUEUED()
 {
+    long workToken = 10L;
+    long cacheToken1 = 5L;
+    long cacheToken2 = 7L;
+
+    Work firstWork = createWork(createWorkItem(workToken, cacheToken1));
+    Work secondWork = createWork(createWorkItem(workToken, cacheToken2));
+    ShardedKey shardedKey = shardedKey("someKey", 1L);
+
+    // ActivateWork with the same shardedKey, and the same workTokens, but 
different cacheTokens.
+    activeWorkState.activateWorkForKey(shardedKey, firstWork);
+    ActivateWorkResult activateWorkResult =
+        activeWorkState.activateWorkForKey(shardedKey, secondWork);
+
+    assertEquals(ActivateWorkResult.QUEUED, activateWorkResult);
+    assertEquals(firstWork, readOnlyActiveWork.get(shardedKey).peek());
+    assertTrue(readOnlyActiveWork.get(shardedKey).contains(secondWork));
+    Optional<Work> nextWork =
+        activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, 
firstWork.id());
+    assertTrue(nextWork.isPresent());
+    assertSame(secondWork, nextWork.get());
+  }
+
+  @Test
+  public void
+      
testActivateWorkForKey_matchingCacheTokens_newWorkTokenGreater_queuedWorkIsActive_QUEUED()
 {
+    long cacheToken = 1L;
+    long newWorkToken = 10L;
+    long queuedWorkToken = newWorkToken / 2;
+
+    Work queuedWork = createWork(createWorkItem(queuedWorkToken, cacheToken));
+    Work newWork = createWork(createWorkItem(newWorkToken, cacheToken));
+    ShardedKey shardedKey = shardedKey("someKey", 1L);
+
+    activeWorkState.activateWorkForKey(shardedKey, queuedWork);
+    ActivateWorkResult activateWorkResult = 
activeWorkState.activateWorkForKey(shardedKey, newWork);
+
+    // newWork should be queued and queuedWork should not be removed since it 
is currently active.
+    assertEquals(ActivateWorkResult.QUEUED, activateWorkResult);
+    assertTrue(readOnlyActiveWork.get(shardedKey).contains(newWork));
+    assertEquals(queuedWork, readOnlyActiveWork.get(shardedKey).peek());
+  }
+
+  @Test
+  public void
+      
testActivateWorkForKey_matchingCacheTokens_newWorkTokenGreater_queuedWorkNotActive_QUEUED()
 {
+    long matchingCacheToken = 1L;
+    long newWorkToken = 10L;
+    long queuedWorkToken = newWorkToken / 2;
+
+    Work differentWorkTokenWork = createWork(createWorkItem(100L, 100L));
+    Work queuedWork = createWork(createWorkItem(queuedWorkToken, 
matchingCacheToken));
+    Work newWork = createWork(createWorkItem(newWorkToken, 
matchingCacheToken));
+    ShardedKey shardedKey = shardedKey("someKey", 1L);
+
+    activeWorkState.activateWorkForKey(shardedKey, differentWorkTokenWork);
+    activeWorkState.activateWorkForKey(shardedKey, queuedWork);
+    ActivateWorkResult activateWorkResult = 
activeWorkState.activateWorkForKey(shardedKey, newWork);
+
+    assertEquals(ActivateWorkResult.QUEUED, activateWorkResult);
+    assertTrue(readOnlyActiveWork.get(shardedKey).contains(newWork));
+    assertFalse(readOnlyActiveWork.get(shardedKey).contains(queuedWork));
+    assertEquals(differentWorkTokenWork, 
readOnlyActiveWork.get(shardedKey).peek());
+  }
+
+  @Test
+  public void 
testActivateWorkForKey_matchingCacheTokens_newWorkTokenLesser_STALE() {
+    long cacheToken = 1L;
+    long queuedWorkToken = 10L;
+    long newWorkToken = queuedWorkToken / 2;
+
+    Work queuedWork = createWork(createWorkItem(queuedWorkToken, cacheToken));
+    Work newWork = createWork(createWorkItem(newWorkToken, cacheToken));
+    ShardedKey shardedKey = shardedKey("someKey", 1L);
+
+    activeWorkState.activateWorkForKey(shardedKey, queuedWork);
+    ActivateWorkResult activateWorkResult = 
activeWorkState.activateWorkForKey(shardedKey, newWork);
+
+    assertEquals(ActivateWorkResult.STALE, activateWorkResult);
+    assertFalse(readOnlyActiveWork.get(shardedKey).contains(newWork));
+    assertEquals(queuedWork, readOnlyActiveWork.get(shardedKey).peek());
+  }
+
   @Test
   public void testGetKeyHeartbeats() {
     Instant refreshDeadline = Instant.now();
 
-    Work freshWork = createWork(createWorkItem(3L));
-    Work refreshableWork1 = expiredWork(createWorkItem(1L));
+    Work freshWork = createWork(createWorkItem(3L, 3L));
+    Work refreshableWork1 = expiredWork(createWorkItem(1L, 1L));
     refreshableWork1.setState(Work.State.COMMITTING);
-    Work refreshableWork2 = expiredWork(createWorkItem(2L));
+    Work refreshableWork2 = expiredWork(createWorkItem(2L, 2L));
     refreshableWork2.setState(Work.State.COMMITTING);
     ShardedKey shardedKey1 = shardedKey("someKey", 1L);
     ShardedKey shardedKey2 = shardedKey("anotherKey", 2L);

Reply via email to