scwhittle commented on code in PR #28537:
URL: https://github.com/apache/beam/pull/28537#discussion_r1337172196


##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java:
##########
@@ -0,0 +1,304 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming;
+
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
+
+import com.google.auto.value.AutoValue;
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Optional;
+import java.util.Queue;
+import java.util.function.BiConsumer;
+import java.util.stream.Stream;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.runners.dataflow.worker.WindmillStateCache;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Manages the active {@link Work} queues for their {@link ShardedKey}(s). Gives an interface to
+ * activate, queue, and complete {@link Work} (including invalidating stuck {@link Work}).
+ */
+@ThreadSafe
+final class ActiveWorkState {
+  private static final Logger LOG = LoggerFactory.getLogger(ActiveWorkState.class);
+
+  /**
+   * Map from {@link ShardedKey} to {@link Work} for the key. The first item in the {@link
+   * Queue<Work>} is actively processing.
+   */
+  @GuardedBy("this")
+  private final Map<ShardedKey, Deque<Work>> activeWork;
+
+  @GuardedBy("this")
+  private final WindmillStateCache.ForComputation computationStateCache;
+
+  private ActiveWorkState(
+      Map<ShardedKey, Deque<Work>> activeWork,
+      WindmillStateCache.ForComputation computationStateCache) {
+    this.activeWork = activeWork;
+    this.computationStateCache = computationStateCache;
+  }
+
+  static ActiveWorkState create(WindmillStateCache.ForComputation computationStateCache) {
+    return new ActiveWorkState(new HashMap<>(), computationStateCache);
+  }
+
+  @VisibleForTesting
+  static ActiveWorkState forTesting(
+      Map<ShardedKey, Deque<Work>> activeWork,
+      WindmillStateCache.ForComputation computationStateCache) {
+    return new ActiveWorkState(activeWork, computationStateCache);
+  }
+
+  /**
+   * Activates {@link Work} for the {@link ShardedKey}. Outcome can be 1 of 3 {@link
+   * ActivateWorkResult}
+   *
+   * <p>1. EXECUTE: The {@link ShardedKey} has not been seen before, create a {@link Queue<Work>}
+   * for the key. The caller should execute the work.
+   *
+   * <p>2. DUPLICATE: A work queue for the {@link ShardedKey} exists, and the work already exists in
+   * the {@link ShardedKey}'s work queue, mark the {@link Work} as a duplicate.
+   *
+   * <p>3. QUEUED: A work queue for the {@link ShardedKey} exists, and the work is not in the key's
+   * work queue, queue the work for later processing.
+   */
+  synchronized ActivateWorkResult activateWorkForKey(ShardedKey shardedKey, Work work) {
+    Deque<Work> workQueue = activeWork.getOrDefault(shardedKey, new ArrayDeque<>());
+
+    // This key does not have any work queued up on it. Create one, insert Work, and mark the work
+    // to be executed.
+    if (!activeWork.containsKey(shardedKey) || workQueue.isEmpty()) {
+      workQueue.addLast(work);
+      activeWork.put(shardedKey, workQueue);
+      return ActivateWorkResult.EXECUTE;
+    }
+
+    // Ensure we don't already have this work token queued.
+    for (Work queuedWork : workQueue) {
+      if (queuedWork.getWorkItem().getWorkToken() == work.getWorkItem().getWorkToken()) {
+        return ActivateWorkResult.DUPLICATE;
+      }
+    }
+
+    // Queue the work for later processing.
+    workQueue.addLast(work);
+    return ActivateWorkResult.QUEUED;
+  }
+
+  /**
+   * Removes the complete work from the {@link Queue<Work>}. The {@link Work} is marked as completed
+   * if its workToken matches the one that is passed in. Returns the next {@link Work} in the {@link
+   * ShardedKey}'s work queue, if one exists else removes the {@link ShardedKey} from {@link
+   * #activeWork}.
+   */
+  synchronized Optional<Work> completeWorkAndGetNextWorkForKey(
+      ShardedKey shardedKey, long workToken) {
+    @Nullable Queue<Work> workQueue = activeWork.get(shardedKey);
+    if (workQueue == null) {
+      // Work may have been completed due to clearing of stuck commits.
+      LOG.warn("Unable to complete inactive work for key {} and token {}.", 
shardedKey, workToken);
+      return Optional.empty();
+    }
+    removeCompletedWorkFromQueue(workQueue, shardedKey, workToken);
+    return getNextWork(workQueue, shardedKey);
+  }
+
+  private synchronized void removeCompletedWorkFromQueue(
+      Queue<Work> workQueue, ShardedKey shardedKey, long workToken) {
+    // avoid Preconditions.checkState here to prevent eagerly evaluating the
+    // format string parameters for the error message.
+    Work completedWork =
+        Optional.ofNullable(workQueue.peek())
+            .orElseThrow(
+                () ->
+                    new IllegalStateException(
+                        String.format(
+                            "Active key %s without work, expected token %d",
+                            shardedKey, workToken)));
+
+    if (completedWork.getWorkItem().getWorkToken() != workToken) {

Review Comment:
   We should be comparing both the workToken and the cacheToken here.
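
   A rough sketch of what that could look like (illustrative only; it assumes the cache token
   is threaded through completeWorkAndGetNextWorkForKey as a `cacheToken` parameter, and that
   WorkItem exposes it via getCacheToken()):

   ```java
   // Sketch: 'cacheToken' is assumed to be passed in alongside workToken. Only treat the
   // head of the queue as completed when both tokens match.
   Windmill.WorkItem queuedItem = completedWork.getWorkItem();
   if (queuedItem.getWorkToken() != workToken || queuedItem.getCacheToken() != cacheToken) {
     LOG.warn(
         "Unable to complete work for key {}; expected work token {} and cache token {},"
             + " queued work token was {} and cache token was {}.",
         shardedKey,
         workToken,
         cacheToken,
         queuedItem.getWorkToken(),
         queuedItem.getCacheToken());
     return;
   }

   // We consumed the matching work item.
   workQueue.remove();
   ```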



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java:
##########
@@ -0,0 +1,304 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming;
+
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
+
+import com.google.auto.value.AutoValue;
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Optional;
+import java.util.Queue;
+import java.util.function.BiConsumer;
+import java.util.stream.Stream;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.runners.dataflow.worker.WindmillStateCache;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Manages the active {@link Work} queues for their {@link ShardedKey}(s). Gives an interface to
+ * activate, queue, and complete {@link Work} (including invalidating stuck {@link Work}).
+ */
+@ThreadSafe
+final class ActiveWorkState {
+  private static final Logger LOG = LoggerFactory.getLogger(ActiveWorkState.class);
+
+  /**
+   * Map from {@link ShardedKey} to {@link Work} for the key. The first item in the {@link
+   * Queue<Work>} is actively processing.
+   */
+  @GuardedBy("this")
+  private final Map<ShardedKey, Deque<Work>> activeWork;
+
+  @GuardedBy("this")
+  private final WindmillStateCache.ForComputation computationStateCache;
+
+  private ActiveWorkState(
+      Map<ShardedKey, Deque<Work>> activeWork,
+      WindmillStateCache.ForComputation computationStateCache) {
+    this.activeWork = activeWork;
+    this.computationStateCache = computationStateCache;
+  }
+
+  static ActiveWorkState create(WindmillStateCache.ForComputation computationStateCache) {
+    return new ActiveWorkState(new HashMap<>(), computationStateCache);
+  }
+
+  @VisibleForTesting
+  static ActiveWorkState forTesting(
+      Map<ShardedKey, Deque<Work>> activeWork,
+      WindmillStateCache.ForComputation computationStateCache) {
+    return new ActiveWorkState(activeWork, computationStateCache);
+  }
+
+  /**
+   * Activates {@link Work} for the {@link ShardedKey}. Outcome can be 1 of 3 {@link
+   * ActivateWorkResult}
+   *
+   * <p>1. EXECUTE: The {@link ShardedKey} has not been seen before, create a {@link Queue<Work>}
+   * for the key. The caller should execute the work.
+   *
+   * <p>2. DUPLICATE: A work queue for the {@link ShardedKey} exists, and the work already exists in
+   * the {@link ShardedKey}'s work queue, mark the {@link Work} as a duplicate.
+   *
+   * <p>3. QUEUED: A work queue for the {@link ShardedKey} exists, and the work is not in the key's
+   * work queue, queue the work for later processing.
+   */
+  synchronized ActivateWorkResult activateWorkForKey(ShardedKey shardedKey, Work work) {
+    Deque<Work> workQueue = activeWork.getOrDefault(shardedKey, new ArrayDeque<>());
+
+    // This key does not have any work queued up on it. Create one, insert Work, and mark the work
+    // to be executed.
+    if (!activeWork.containsKey(shardedKey) || workQueue.isEmpty()) {
+      workQueue.addLast(work);
+      activeWork.put(shardedKey, workQueue);
+      return ActivateWorkResult.EXECUTE;
+    }
+
+    // Ensure we don't already have this work token queued.
+    for (Work queuedWork : workQueue) {
+      if (queuedWork.getWorkItem().getWorkToken() == work.getWorkItem().getWorkToken()) {

Review Comment:
   If the cache token is different then it isn't a pure duplicate; it could be a retry. In
   that case it might be better to take the more recently observed item (guessing it is more
   likely the newer one), or to keep both.
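
   For illustration, a sketch of what that could look like (assuming WorkItem exposes the
   cache token via getCacheToken(), and that keeping both items is the chosen policy):

   ```java
   // Sketch of a duplicate check that also considers the cache token.
   for (Work queuedWork : workQueue) {
     Windmill.WorkItem queuedItem = queuedWork.getWorkItem();
     if (queuedItem.getWorkToken() == work.getWorkItem().getWorkToken()) {
       if (queuedItem.getCacheToken() == work.getWorkItem().getCacheToken()) {
         // Same work token and cache token: a true duplicate, drop the new item.
         return ActivateWorkResult.DUPLICATE;
       }
       // Same work token but a different cache token: likely a retry rather than a pure
       // duplicate. Breaking out here keeps both items queued; replacing the queued entry
       // with the newer item would be the "take the more recent" alternative.
       break;
     }
   }

   // Queue the work for later processing.
   workQueue.addLast(work);
   return ActivateWorkResult.QUEUED;
   ```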



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java:
##########
@@ -0,0 +1,301 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+
+import com.google.auto.value.AutoValue;
+import java.util.ArrayDeque;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.dataflow.worker.WindmillStateCache;
+import org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.ActivateWorkResult;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
+import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class ActiveWorkStateTest {
+
+  private final WindmillStateCache.ForComputation computationStateCache =
+      mock(WindmillStateCache.ForComputation.class);
+  private final Map<ShardedKey, Deque<Work>> activeWork = new HashMap<>();
+
+  private ActiveWorkState activeWorkState;
+
+  private static ShardedKey shardedKey(String str, long shardKey) {
+    return ShardedKey.create(ByteString.copyFromUtf8(str), shardKey);
+  }
+
+  private static Work emptyWork() {
+    return createWork(null);
+  }
+
+  private static Work createWork(@Nullable Windmill.WorkItem workItem) {
+    return Work.create(workItem, Instant::now, Collections.emptyList(), unused -> {});
+  }
+
+  private static Work expiredWork(Windmill.WorkItem workItem) {
+    return Work.create(workItem, () -> Instant.EPOCH, Collections.emptyList(), unused -> {});
+  }
+
+  private static Windmill.WorkItem createWorkItem(long workToken) {
+    return Windmill.WorkItem.newBuilder()
+        .setKey(ByteString.copyFromUtf8(""))
+        .setShardingKey(1)
+        .setWorkToken(workToken)
+        .build();
+  }
+
+  @Before
+  public void setup() {
+    activeWork.clear();
+    activeWorkState = ActiveWorkState.forTesting(activeWork, computationStateCache);
+  }
+
+  @Test
+  public void testActivateWorkForKey_EXECUTE_unknownKey() {
+    ActivateWorkResult activateWorkResult =
+        activeWorkState.activateWorkForKey(shardedKey("someKey", 1L), emptyWork());
+
+    assertEquals(ActivateWorkResult.EXECUTE, activateWorkResult);
+  }
+
+  @Test
+  public void testActivateWorkForKey_EXECUTE_emptyWorkQueueForKey() {
+    ShardedKey shardedKey = shardedKey("someKey", 1L);
+    activeWork.put(shardedKey, new ArrayDeque<>());
+
+    ActivateWorkResult activateWorkResult =
+        activeWorkState.activateWorkForKey(shardedKey("someKey", 1L), emptyWork());
+
+    assertEquals(ActivateWorkResult.EXECUTE, activateWorkResult);
+  }
+
+  @Test
+  public void testActivateWorkForKey_DUPLICATE() {
+    long workToken = 10L;
+    ShardedKey shardedKey = shardedKey("someKey", 1L);
+
+    // ActivateWork with the same shardedKey, and the same workTokens.
+    activeWorkState.activateWorkForKey(shardedKey, createWork(createWorkItem(workToken)));
+    ActivateWorkResult activateWorkResult =
+        activeWorkState.activateWorkForKey(shardedKey, createWork(createWorkItem(workToken)));
+
+    assertEquals(ActivateWorkResult.DUPLICATE, activateWorkResult);
+  }
+
+  @Test
+  public void testActivateWorkForKey_QUEUED() {
+    ShardedKey shardedKey = shardedKey("someKey", 1L);
+
+    // ActivateWork with the same shardedKey, but different workTokens.
+    activeWorkState.activateWorkForKey(shardedKey, createWork(createWorkItem(1L)));
+    ActivateWorkResult activateWorkResult =
+        activeWorkState.activateWorkForKey(shardedKey, createWork(createWorkItem(2L)));
+
+    assertEquals(ActivateWorkResult.QUEUED, activateWorkResult);
+  }
+
+  @Test
+  public void testCompleteWorkAndGetNextWorkForKey_noWorkQueueForKey() {
+    assertEquals(
+        Optional.empty(),
+        activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey("someKey", 1L), 10L));
+  }
+
+  @Test
+  public void testCompleteWorkAndGetNextWorkForKey_throwsWhenNoWorkInQueueForKey() {
+    ShardedKey shardedKey = shardedKey("someKey", 1L);
+    activeWork.put(shardedKey, new ArrayDeque<>());
+
+    assertThrows(
+        IllegalStateException.class,
+        () -> activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, 10L));
+  }
+
+  @Test
+  public void testCompleteWorkAndGetNextWorkForKey_currentWorkInQueueDoesNotMatchWorkToComplete() {
+    long workTokenToComplete = 1L;
+
+    Work workInQueue = createWork(createWorkItem(2L));
+    ShardedKey shardedKey = shardedKey("someKey", 1L);
+    Deque<Work> workQueue = new ArrayDeque<>();
+    workQueue.addLast(workInQueue);
+    activeWork.put(shardedKey, workQueue);

Review Comment:
   Ditto, this seems like it should just be combined with the above QUEUED test, which sets
   up the right state, instead of hard-coding the state here, which might not match the
   actual implementation.
   
   If there was a bug where queueing the work returned QUEUED but didn't actually put it in
   the internal state, then both the above test and this test would still pass, so the bug
   would go undetected.
   
   Ditto for the tests below.
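
   For example, something along these lines (a sketch only, driving the state through the
   public methods and reusing the test's existing helpers):

   ```java
   @Test
   public void testCompleteWorkAndGetNextWorkForKey_tokenMismatchKeepsActiveWork() {
     ShardedKey shardedKey = shardedKey("someKey", 1L);

     // Set up state through the interface instead of hard-coding the internal map.
     activeWorkState.activateWorkForKey(shardedKey, createWork(createWorkItem(2L)));

     // Completing with a non-matching token should leave the active work in place.
     Optional<Work> nextWork = activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, 1L);

     assertTrue(nextWork.isPresent());
     assertEquals(2L, nextWork.get().getWorkItem().getWorkToken());
   }
   ```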



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java:
##########
@@ -99,127 +82,47 @@ public ConcurrentLinkedQueue<ExecutionState> getExecutionStateQueue() {
    * Work} if there is no active {@link Work} for the {@link ShardedKey} already processing.
    */
   public boolean activateWork(ShardedKey shardedKey, Work work) {
-    synchronized (activeWork) {
-      Deque<Work> queue = activeWork.get(shardedKey);
-      if (queue != null) {
-        Preconditions.checkState(!queue.isEmpty());
-        // Ensure we don't already have this work token queued.
-        for (Work queuedWork : queue) {
-          if (queuedWork.getWorkItem().getWorkToken() == work.getWorkItem().getWorkToken()) {
-            return false;
-          }
-        }
-        // Queue the work for later processing.
-        queue.addLast(work);
+    switch (activeWorkState.activateWorkForKey(shardedKey, work)) {
+      case DUPLICATE:
+        return false;
+      case QUEUED:
         return true;
-      } else {
-        queue = new ArrayDeque<>();
-        queue.addLast(work);
-        activeWork.put(shardedKey, queue);
-        // Fall through to execute without the lock held.
-      }
+      case EXECUTE:
+        {
+          execute(work);
+          return true;
+        }
+        // This will never happen, the switch is exhaustive.

Review Comment:
   Move this comment inside a default case.
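
   For example (just a sketch of the shape, reusing the surrounding switch):

   ```java
   switch (activeWorkState.activateWorkForKey(shardedKey, work)) {
     case DUPLICATE:
       return false;
     case QUEUED:
       return true;
     case EXECUTE:
       execute(work);
       return true;
     default:
       // This will never happen, the switch is exhaustive.
       throw new IllegalStateException("Unexpected ActivateWorkResult for " + shardedKey);
   }
   ```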



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java:
##########
@@ -0,0 +1,304 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming;
+
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
+
+import com.google.auto.value.AutoValue;
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Optional;
+import java.util.Queue;
+import java.util.function.BiConsumer;
+import java.util.stream.Stream;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.runners.dataflow.worker.WindmillStateCache;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Manages the active {@link Work} queues for their {@link ShardedKey}(s). Gives an interface to
+ * activate, queue, and complete {@link Work} (including invalidating stuck {@link Work}).
+ */
+@ThreadSafe
+final class ActiveWorkState {
+  private static final Logger LOG = LoggerFactory.getLogger(ActiveWorkState.class);
+
+  /**
+   * Map from {@link ShardedKey} to {@link Work} for the key. The first item in the {@link
+   * Queue<Work>} is actively processing.
+   */
+  @GuardedBy("this")
+  private final Map<ShardedKey, Deque<Work>> activeWork;
+
+  @GuardedBy("this")
+  private final WindmillStateCache.ForComputation computationStateCache;
+
+  private ActiveWorkState(
+      Map<ShardedKey, Deque<Work>> activeWork,
+      WindmillStateCache.ForComputation computationStateCache) {
+    this.activeWork = activeWork;
+    this.computationStateCache = computationStateCache;
+  }
+
+  static ActiveWorkState create(WindmillStateCache.ForComputation computationStateCache) {
+    return new ActiveWorkState(new HashMap<>(), computationStateCache);
+  }
+
+  @VisibleForTesting
+  static ActiveWorkState forTesting(
+      Map<ShardedKey, Deque<Work>> activeWork,
+      WindmillStateCache.ForComputation computationStateCache) {
+    return new ActiveWorkState(activeWork, computationStateCache);
+  }
+
+  /**
+   * Activates {@link Work} for the {@link ShardedKey}. Outcome can be 1 of 3 {@link
+   * ActivateWorkResult}
+   *
+   * <p>1. EXECUTE: The {@link ShardedKey} has not been seen before, create a {@link Queue<Work>}
+   * for the key. The caller should execute the work.
+   *
+   * <p>2. DUPLICATE: A work queue for the {@link ShardedKey} exists, and the work already exists in
+   * the {@link ShardedKey}'s work queue, mark the {@link Work} as a duplicate.
+   *
+   * <p>3. QUEUED: A work queue for the {@link ShardedKey} exists, and the work is not in the key's
+   * work queue, queue the work for later processing.
+   */
+  synchronized ActivateWorkResult activateWorkForKey(ShardedKey shardedKey, Work work) {
+    Deque<Work> workQueue = activeWork.getOrDefault(shardedKey, new ArrayDeque<>());
+
+    // This key does not have any work queued up on it. Create one, insert Work, and mark the work
+    // to be executed.
+    if (!activeWork.containsKey(shardedKey) || workQueue.isEmpty()) {
+      workQueue.addLast(work);
+      activeWork.put(shardedKey, workQueue);
+      return ActivateWorkResult.EXECUTE;
+    }
+
+    // Ensure we don't already have this work token queued.
+    for (Work queuedWork : workQueue) {
+      if (queuedWork.getWorkItem().getWorkToken() == work.getWorkItem().getWorkToken()) {
+        return ActivateWorkResult.DUPLICATE;
+      }
+    }
+
+    // Queue the work for later processing.
+    workQueue.addLast(work);
+    return ActivateWorkResult.QUEUED;
+  }
+
+  /**
+   * Removes the complete work from the {@link Queue<Work>}. The {@link Work} is marked as completed
+   * if its workToken matches the one that is passed in. Returns the next {@link Work} in the {@link
+   * ShardedKey}'s work queue, if one exists else removes the {@link ShardedKey} from {@link
+   * #activeWork}.
+   */
+  synchronized Optional<Work> completeWorkAndGetNextWorkForKey(
+      ShardedKey shardedKey, long workToken) {
+    @Nullable Queue<Work> workQueue = activeWork.get(shardedKey);
+    if (workQueue == null) {
+      // Work may have been completed due to clearing of stuck commits.
+      LOG.warn("Unable to complete inactive work for key {} and token {}.", 
shardedKey, workToken);
+      return Optional.empty();
+    }
+    removeCompletedWorkFromQueue(workQueue, shardedKey, workToken);
+    return getNextWork(workQueue, shardedKey);
+  }
+
+  private synchronized void removeCompletedWorkFromQueue(
+      Queue<Work> workQueue, ShardedKey shardedKey, long workToken) {
+    // avoid Preconditions.checkState here to prevent eagerly evaluating the
+    // format string parameters for the error message.
+    Work completedWork =
+        Optional.ofNullable(workQueue.peek())
+            .orElseThrow(
+                () ->
+                    new IllegalStateException(
+                        String.format(
+                            "Active key %s without work, expected token %d",
+                            shardedKey, workToken)));
+
+    if (completedWork.getWorkItem().getWorkToken() != workToken) {
+      // Work may have been completed due to clearing of stuck commits.
+      LOG.warn(
+          "Unable to complete due to token mismatch for key {} and token {}, 
actual token was {}.",
+          shardedKey,
+          workToken,
+          completedWork.getWorkItem().getWorkToken());
+      return;
+    }
+
+    // We consumed the matching work item.
+    workQueue.remove();
+  }
+
+  private synchronized Optional<Work> getNextWork(Queue<Work> workQueue, ShardedKey shardedKey) {
+    Optional<Work> nextWork = Optional.ofNullable(workQueue.peek());
+    if (!nextWork.isPresent()) {
+      Preconditions.checkState(workQueue == activeWork.remove(shardedKey));
+    }
+
+    return nextWork;
+  }
+
+  /**
+   * Invalidates all {@link Work} that is in the {@link Work.State#COMMITTING} state which started
+   * before the stuckCommitDeadline.
+   */
+  synchronized void invalidateStuckCommits(
+      Instant stuckCommitDeadline, BiConsumer<ShardedKey, Long> shardedKeyAndWorkTokenConsumer) {
+    for (Entry<ShardedKey, Long> shardedKeyAndWorkToken :
+        getStuckCommitsAt(stuckCommitDeadline).entrySet()) {
+      ShardedKey shardedKey = shardedKeyAndWorkToken.getKey();
+      long workToken = shardedKeyAndWorkToken.getValue();
+      computationStateCache.invalidate(shardedKey.key(), shardedKey.shardingKey());
+      shardedKeyAndWorkTokenConsumer.accept(shardedKey, workToken);
+    }
+  }
+
+  private synchronized ImmutableMap<ShardedKey, Long> getStuckCommitsAt(
+      Instant stuckCommitDeadline) {
+    // Determine the stuck commit keys but complete them outside the loop iterating over
+    // activeWork as completeWork may delete the entry from activeWork.
+    ImmutableMap.Builder<ShardedKey, Long> stuckCommits = ImmutableMap.builder();
+    for (Entry<ShardedKey, Deque<Work>> entry : activeWork.entrySet()) {
+      ShardedKey shardedKey = entry.getKey();
+      @Nullable Work work = entry.getValue().peek();
+      if (work != null) {
+        if (work.isStuckAt(stuckCommitDeadline)) {
+          LOG.error(
+              "Detected key {} stuck in COMMITTING state since {}, completing 
it with error.",
+              shardedKey,
+              work.getStateStartTime());
+          stuckCommits.put(shardedKey, work.getWorkItem().getWorkToken());
+        }
+      }
+    }
+
+    return stuckCommits.build();
+  }
+
+  synchronized ImmutableList<KeyedGetDataRequest> getKeysToRefresh(Instant refreshDeadline) {
+    return activeWork.entrySet().stream()
+        .flatMap(entry -> toKeyedGetDataRequestStream(entry, refreshDeadline))
+        .collect(toImmutableList());
+  }
+
+  private static Stream<KeyedGetDataRequest> toKeyedGetDataRequestStream(
+      Entry<ShardedKey, Deque<Work>> shardedKeyAndWorkQueue, Instant refreshDeadline) {
+    ShardedKey shardedKey = shardedKeyAndWorkQueue.getKey();
+    Deque<Work> workQueue = shardedKeyAndWorkQueue.getValue();
+
+    return workQueue.stream()
+        .filter(work -> work.getStartTime().isBefore(refreshDeadline))
+        .map(
+            work ->
+                Windmill.KeyedGetDataRequest.newBuilder()
+                    .setKey(shardedKey.key())
+                    .setShardingKey(shardedKey.shardingKey())
+                    .setWorkToken(work.getWorkItem().getWorkToken())
+                    .addAllLatencyAttribution(work.getLatencyAttributions())
+                    .build());
+  }
+
+  synchronized CommitsPendingCountAndActiveWorkStatus getPendingCommitsAndPrintActiveWorkAt(

Review Comment:
   What about just changing this to printTo(PrintWriter writer) and doing the printing here?
   Or just returning a String that we print?
   
   I don't think we need to expose this for testing (I don't see a test for it yet anyway),
   and the current logic is a bit spread out, with MAX_PRINTABLE_COMMIT_PENDING_KEYS used
   both here when building and in the printing. The header for the table is also separated
   from the population of the table here. Just having all of that in a single method,
   instead of introducing CommitsPendingCountAndActiveWorkStatus, seems cleaner.
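
   Something shaped like this, for example (only a sketch of the suggested printTo-style
   signature: the method name, the `now` parameter, and the table columns are guesses since
   the body isn't shown in this hunk, and it assumes java.io.PrintWriter is imported):

   ```java
   synchronized void printActiveWork(PrintWriter writer, Instant now) {
     writer.println("<table>");
     writer.println("<tr><th>Key</th><th>Token</th><th>Queue Length</th><th>Active For</th></tr>");
     // Header, rows, and any cap such as MAX_PRINTABLE_COMMIT_PENDING_KEYS would all live in
     // this one method, so callers only hand over the writer.
     for (Entry<ShardedKey, Deque<Work>> entry : activeWork.entrySet()) {
       Deque<Work> workQueue = entry.getValue();
       @Nullable Work activeWorkItem = workQueue.peek();
       if (activeWorkItem == null) {
         continue;
       }
       writer.format(
           "<tr><td>%s</td><td>%d</td><td>%d</td><td>%s</td></tr>%n",
           entry.getKey(),
           activeWorkItem.getWorkItem().getWorkToken(),
           workQueue.size(),
           new Duration(activeWorkItem.getStartTime(), now));
     }
     writer.println("</table>");
   }
   ```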



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java:
##########
@@ -120,6 +120,11 @@ public Collection<Windmill.LatencyAttribution> 
getLatencyAttributions() {
     return list;
   }
 
+  boolean isStuckAt(Instant stuckCommitDeadline) {

Review Comment:
   isStuckCommittingAt



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java:
##########
@@ -0,0 +1,301 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+
+import com.google.auto.value.AutoValue;
+import java.util.ArrayDeque;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.dataflow.worker.WindmillStateCache;
+import org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.ActivateWorkResult;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
+import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class ActiveWorkStateTest {
+
+  private final WindmillStateCache.ForComputation computationStateCache =
+      mock(WindmillStateCache.ForComputation.class);
+  private final Map<ShardedKey, Deque<Work>> activeWork = new HashMap<>();
+
+  private ActiveWorkState activeWorkState;
+
+  private static ShardedKey shardedKey(String str, long shardKey) {
+    return ShardedKey.create(ByteString.copyFromUtf8(str), shardKey);
+  }
+
+  private static Work emptyWork() {
+    return createWork(null);
+  }
+
+  private static Work createWork(@Nullable Windmill.WorkItem workItem) {
+    return Work.create(workItem, Instant::now, Collections.emptyList(), unused -> {});
+  }
+
+  private static Work expiredWork(Windmill.WorkItem workItem) {
+    return Work.create(workItem, () -> Instant.EPOCH, Collections.emptyList(), unused -> {});
+  }
+
+  private static Windmill.WorkItem createWorkItem(long workToken) {
+    return Windmill.WorkItem.newBuilder()
+        .setKey(ByteString.copyFromUtf8(""))
+        .setShardingKey(1)
+        .setWorkToken(workToken)
+        .build();
+  }
+
+  @Before
+  public void setup() {
+    activeWork.clear();
+    activeWorkState = ActiveWorkState.forTesting(activeWork, computationStateCache);
+  }
+
+  @Test
+  public void testActivateWorkForKey_EXECUTE_unknownKey() {
+    ActivateWorkResult activateWorkResult =
+        activeWorkState.activateWorkForKey(shardedKey("someKey", 1L), emptyWork());
+
+    assertEquals(ActivateWorkResult.EXECUTE, activateWorkResult);
+  }
+
+  @Test
+  public void testActivateWorkForKey_EXECUTE_emptyWorkQueueForKey() {

Review Comment:
   This seems odd to me because I don't think this is a legal state for the class: if
   something is active it is in the queue, and once there is nothing left for the key it
   should clean itself up.
   
   It seems like it would be better to test through the interface of the class instead of
   injecting internal state members, perhaps verifying internal state but not directly
   manipulating it from the test.
   
   For example, here you could do the above: add an item, activate it, complete it, and
   verify that the key no longer exists in the internal map.
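
   e.g. something like this (sketch only, using the test's existing helpers):

   ```java
   @Test
   public void testCompleteWorkAndGetNextWorkForKey_removesKeyWhenQueueIsDrained() {
     long workToken = 1L;
     ShardedKey shardedKey = shardedKey("someKey", 1L);

     // Drive the state through the interface: activate the only item, then complete it.
     activeWorkState.activateWorkForKey(shardedKey, createWork(createWorkItem(workToken)));
     Optional<Work> nextWork =
         activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, workToken);

     // With nothing left to process for the key, it should have been cleaned up.
     assertFalse(nextWork.isPresent());
     assertFalse(activeWork.containsKey(shardedKey));
   }
   ```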



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
