scwhittle commented on code in PR #38814:
URL: https://github.com/apache/beam/pull/38814#discussion_r3490366175


##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java:
##########
@@ -384,89 +376,142 @@ private ExecuteWorkResult executeWork(
 
     try {
       WindmillStateReader stateReader = work.createWindmillStateReader();
-      SideInputStateFetcher localSideInputStateFetcher =
-          
sideInputStateFetcherFactory.createSideInputStateFetcher(work::fetchSideInput);
-
-      // If the read output KVs, then we can decode Windmill's byte key into 
userland
-      // key object and provide it to the execution context for use with 
per-key state.
-      // Otherwise, we pass null.
-      //
-      // The coder type that will be present is:
-      //     WindowedValueCoder(TimerOrElementCoder(KvCoder))
-      Optional<Coder<?>> keyCoder = computationWorkExecutor.keyCoder();
-      @SuppressWarnings("deprecation")
-      @Nullable
-      final Object executionKey =
-          !keyCoder.isPresent() ? null : keyCoder.get().decode(key.newInput(), 
Coder.Context.OUTER);
-
-      if (workItem.hasHotKeyInfo()) {
-        Windmill.HotKeyInfo hotKeyInfo = workItem.getHotKeyInfo();
-        Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 
1000);
-
-        String stepName = 
getShuffleTaskStepName(computationState.getMapTask());
-        if (executionKey != null
-            && (options.isHotKeyLoggingEnabled()
-                || hasExperiment(options, "enable_hot_key_logging"))
-            && keyCoder.isPresent()) {
-          hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, executionKey);
-        } else {
-          hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge);
-        }
-      }
 
-      // Blocks while executing work.
-      computationWorkExecutor.executeWork(
-          executionKey, work, stateReader, localSideInputStateFetcher, 
outputBuilder);
+      KeyTransitionListener keyTransitionListener = 
createKeyTransitionListener();
+
+      List<Work> workBatch;
+      List<Windmill.WorkItemCommitRequest> workItemCommits;
+      Map<Long, Pair<Instant, Runnable>> accumulatedCallbacks;
+      long stateBytesRead;
+      {
+        // Blocks while executing work.
+        StreamingModeExecutionContext context =
+            computationWorkExecutor.executeWork(
+                work, stateReader, workExecutor, handle, 
keyTransitionListener);
+        if (context.workIsFailed()) {
+          throw new 
WorkItemCancelledException(work.getWorkItem().getShardingKey());
+        }
 
-      if (work.isFailed()) {
-        throw new WorkItemCancelledException(workItem.getShardingKey());
-      }
+        // Retrieve executed works, work item commits, and accumulated 
callbacks from execution
+        // context
+        workBatch = context.getExecutedWorks();
+        workItemCommits = context.getWorkItemCommits();
+        accumulatedCallbacks = context.getAccumulatedCallbacks();
+        stateBytesRead = context.getStateBytesRead();
 
-      // Reports source bytes processed to WorkItemCommitRequest if available.
-      try {
-        long sourceBytesProcessed =
-            computationWorkExecutor.computeSourceBytesProcessed(
-                computationState.sourceBytesProcessCounterName());
-        outputBuilder.setSourceBytesProcessed(sourceBytesProcessed);
-      } catch (Exception e) {
-        LOG.error("{}", e.toString());
+        context.reset(); // Don't use context after this.
       }
-
-      
commitFinalizer.cacheCommitFinalizers(computationWorkExecutor.context().flushState());
-
       // Release the execution state for another thread to use.
       computationState.releaseComputationWorkExecutor(computationWorkExecutor);
       computationWorkExecutor = null;
 
-      work.setState(Work.State.COMMIT_QUEUED);
-      
outputBuilder.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler));
-
       return ExecuteWorkResult.create(
-          outputBuilder, stateReader.getBytesRead() + 
localSideInputStateFetcher.getBytesRead());
+          workBatch, workItemCommits, accumulatedCallbacks, stateBytesRead);
     } catch (Throwable t) {
       if (computationWorkExecutor != null) {
         // If processing failed due to a thrown exception, close the 
executionState. Do not
         // return/release the executionState back to computationState as that 
will lead to this
         // executionState instance being reused.
-        LOG.debug("Invalidating executor after work item {} failed", 
workItem.getWorkToken(), t);
+        LOG.debug(
+            "Invalidating executor after work item {} failed",
+            work.getWorkItem().getWorkToken(),
+            t);
         computationWorkExecutor.invalidate();
       }
-
       // Re-throw the exception, it will be caught and handled by 
workFailureProcessor downstream.
       throw t;
     }
   }
 
+  private void handleOnlyFinalize(
+      ComputationState computationState, Work work, Windmill.WorkItem 
workItem) {
+    Windmill.WorkItemCommitRequest.Builder outputBuilder =
+        initializeOutputBuilder(workItem.getKey(), workItem);
+    
outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true));
+    work.setState(Work.State.COMMIT_QUEUED);
+    work.queueCommit(outputBuilder.build(), computationState);
+  }
+
+  private StageInfo getStageInfo(ComputationState computationState) {
+    MapTask mapTask = computationState.getMapTask();
+    return stageInfoMap.computeIfAbsent(
+        mapTask.getStageName(), s -> StageInfo.create(s, 
mapTask.getSystemName()));
+  }
+
+  private void commitWorkBatch(
+      ComputationState computationState,
+      List<Work> workBatch,
+      List<Windmill.WorkItemCommitRequest> workItemCommits) {
+    Preconditions.checkState(
+        workBatch.size() == 1, "Expected single-key work batch, got: " + 
workBatch.size());
+    commitSingleKeyWork(computationState, workBatch.get(0), 
workItemCommits.get(0));
+  }
+
+  private void commitSingleKeyWork(
+      ComputationState computationState, Work work, 
Windmill.WorkItemCommitRequest commitRequest) {
+    // Validate the commit request, possibly requesting truncation if the 
commitSize is too large.
+    Windmill.WorkItemCommitRequest validatedCommitRequest =
+        validateCommitRequestSize(
+            commitRequest, computationState.getComputationId(), 
work.getWorkItem());
+    work.setState(Work.State.COMMIT_QUEUED);
+    validatedCommitRequest =
+        validatedCommitRequest
+            .toBuilder()
+            
.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler))
+            .build();
+    work.queueCommit(validatedCommitRequest, computationState);
+  }
+
+  private void recordProcessingTime(
+      StageInfo stageInfo,
+      @Nullable List<Work> worksToCleanup,
+      Work work,
+      long processingStartTimeNanos) {
+    long processingTimeMsecs =
+        TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - 
processingStartTimeNanos);
+    stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs);
+    if (anyWorkHasTimers(worksToCleanup, work)) {
+      // Attribute all the processing to timers if the work item contains any 
timers.
+      // Tests show that work items rarely contain both timers and message 
bundles. It should
+      // be a fairly close approximation.
+      // Another option: Derive time split between messages and timers based 
on recent totals.
+      // either here or in DFE.
+      stageInfo.timerProcessingMsecs().addValue(processingTimeMsecs);
+    }
+  }
+
+  private static boolean anyWorkHasTimers(@Nullable List<Work> works, Work 
primaryWork) {
+    if (works != null && !works.isEmpty()) {
+      return works.stream().anyMatch(w -> w.getWorkItem().hasTimers());
+    }
+    return primaryWork.getWorkItem().hasTimers();
+  }
+
+  private KeyTransitionListener createKeyTransitionListener() {
+    return (oldWork, newWork) -> {
+      setLoggingContextWorkId(newWork.getLatencyTrackingId());
+      newWork.setProcessingThreadName(oldWork.getProcessingThreadName());

Review Comment:
   See above; maybe we should set the oldWork state to something showing it is 
waiting for the batch.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java:
##########
@@ -384,89 +376,142 @@ private ExecuteWorkResult executeWork(
 
     try {
       WindmillStateReader stateReader = work.createWindmillStateReader();
-      SideInputStateFetcher localSideInputStateFetcher =
-          
sideInputStateFetcherFactory.createSideInputStateFetcher(work::fetchSideInput);
-
-      // If the read output KVs, then we can decode Windmill's byte key into 
userland
-      // key object and provide it to the execution context for use with 
per-key state.
-      // Otherwise, we pass null.
-      //
-      // The coder type that will be present is:
-      //     WindowedValueCoder(TimerOrElementCoder(KvCoder))
-      Optional<Coder<?>> keyCoder = computationWorkExecutor.keyCoder();
-      @SuppressWarnings("deprecation")
-      @Nullable
-      final Object executionKey =
-          !keyCoder.isPresent() ? null : keyCoder.get().decode(key.newInput(), 
Coder.Context.OUTER);
-
-      if (workItem.hasHotKeyInfo()) {
-        Windmill.HotKeyInfo hotKeyInfo = workItem.getHotKeyInfo();
-        Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 
1000);
-
-        String stepName = 
getShuffleTaskStepName(computationState.getMapTask());
-        if (executionKey != null
-            && (options.isHotKeyLoggingEnabled()
-                || hasExperiment(options, "enable_hot_key_logging"))
-            && keyCoder.isPresent()) {
-          hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, executionKey);
-        } else {
-          hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge);
-        }
-      }
 
-      // Blocks while executing work.
-      computationWorkExecutor.executeWork(
-          executionKey, work, stateReader, localSideInputStateFetcher, 
outputBuilder);
+      KeyTransitionListener keyTransitionListener = 
createKeyTransitionListener();
+
+      List<Work> workBatch;
+      List<Windmill.WorkItemCommitRequest> workItemCommits;
+      Map<Long, Pair<Instant, Runnable>> accumulatedCallbacks;
+      long stateBytesRead;
+      {
+        // Blocks while executing work.
+        StreamingModeExecutionContext context =
+            computationWorkExecutor.executeWork(
+                work, stateReader, workExecutor, handle, 
keyTransitionListener);
+        if (context.workIsFailed()) {
+          throw new 
WorkItemCancelledException(work.getWorkItem().getShardingKey());
+        }
 
-      if (work.isFailed()) {
-        throw new WorkItemCancelledException(workItem.getShardingKey());
-      }
+        // Retrieve executed works, work item commits, and accumulated 
callbacks from execution
+        // context
+        workBatch = context.getExecutedWorks();
+        workItemCommits = context.getWorkItemCommits();
+        accumulatedCallbacks = context.getAccumulatedCallbacks();
+        stateBytesRead = context.getStateBytesRead();
 
-      // Reports source bytes processed to WorkItemCommitRequest if available.
-      try {
-        long sourceBytesProcessed =
-            computationWorkExecutor.computeSourceBytesProcessed(
-                computationState.sourceBytesProcessCounterName());
-        outputBuilder.setSourceBytesProcessed(sourceBytesProcessed);
-      } catch (Exception e) {
-        LOG.error("{}", e.toString());
+        context.reset(); // Don't use context after this.
       }
-
-      
commitFinalizer.cacheCommitFinalizers(computationWorkExecutor.context().flushState());
-
       // Release the execution state for another thread to use.
       computationState.releaseComputationWorkExecutor(computationWorkExecutor);
       computationWorkExecutor = null;
 
-      work.setState(Work.State.COMMIT_QUEUED);
-      
outputBuilder.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler));
-
       return ExecuteWorkResult.create(
-          outputBuilder, stateReader.getBytesRead() + 
localSideInputStateFetcher.getBytesRead());
+          workBatch, workItemCommits, accumulatedCallbacks, stateBytesRead);
     } catch (Throwable t) {
       if (computationWorkExecutor != null) {
         // If processing failed due to a thrown exception, close the 
executionState. Do not
         // return/release the executionState back to computationState as that 
will lead to this
         // executionState instance being reused.
-        LOG.debug("Invalidating executor after work item {} failed", 
workItem.getWorkToken(), t);
+        LOG.debug(
+            "Invalidating executor after work item {} failed",
+            work.getWorkItem().getWorkToken(),
+            t);
         computationWorkExecutor.invalidate();
       }
-
       // Re-throw the exception, it will be caught and handled by 
workFailureProcessor downstream.
       throw t;
     }
   }
 
+  private void handleOnlyFinalize(
+      ComputationState computationState, Work work, Windmill.WorkItem 
workItem) {
+    Windmill.WorkItemCommitRequest.Builder outputBuilder =
+        initializeOutputBuilder(workItem.getKey(), workItem);
+    
outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true));
+    work.setState(Work.State.COMMIT_QUEUED);
+    work.queueCommit(outputBuilder.build(), computationState);
+  }
+
+  private StageInfo getStageInfo(ComputationState computationState) {
+    MapTask mapTask = computationState.getMapTask();
+    return stageInfoMap.computeIfAbsent(
+        mapTask.getStageName(), s -> StageInfo.create(s, 
mapTask.getSystemName()));
+  }
+
+  private void commitWorkBatch(
+      ComputationState computationState,
+      List<Work> workBatch,
+      List<Windmill.WorkItemCommitRequest> workItemCommits) {
+    Preconditions.checkState(
+        workBatch.size() == 1, "Expected single-key work batch, got: " + 
workBatch.size());

Review Comment:
   Check that the commits and the batch are the same size?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java:
##########
@@ -248,46 +240,39 @@ private void processWork(
   }
 
   private void processWork(
-      ComputationState computationState, Work work, 
BoundedQueueExecutorWorkHandle unusedHandle) {
+      ComputationState computationState, Work work, 
BoundedQueueExecutorWorkHandle handle) {
     Windmill.WorkItem workItem = work.getWorkItem();
     String computationId = computationState.getComputationId();
-    ByteString key = workItem.getKey();
     work.setProcessingThreadName(Thread.currentThread().getName());
     work.setState(Work.State.PROCESSING);

Review Comment:
   The first item is set to PROCESSING here, and then the others are set to 
PROCESSING as they are added to the batch. But the first remains PROCESSING 
until it is transferred to QUEUED. This might confuse user worker latency 
analysis: we attribute too much user processing time to that particular key, and 
if we have N items taking a second each we would have O(N^2) total processing 
seconds instead of N. Should we add a new POST_PROCESSING_QUEUED state or something?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########
@@ -159,6 +168,35 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
   private @Nullable WorkExecutor workExecutor;
   private boolean finishKeyCalled = false;
 
+  @SuppressWarnings("UnusedVariable")

Review Comment:
   any idea why these suppressions are needed?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########
@@ -243,50 +291,120 @@ public byte[] getCurrentRecordOffset() {
     return checkStateNotNull(activeReader).getCurrentRecordOffset();
   }
 
+  /** Reset context before using it on a new bundle */
+  public void reset() {
+    this.executedWorks = new ArrayList<>();
+    this.outputBuilders = new ArrayList<>();
+    this.accumulatedCallbacks = new HashMap<>();
+    // Work from prior bundles might have a reference to the old 
workBatchFailed.
+    // If the work gets retried it'll get the new workBatchFailed to notify 
failure.
+    this.workBatchFailed = new AtomicBoolean(false);
+    this.sideInputCache.clear();
+    this.activeStateReader = null;
+    this.activeReader = null;
+    this.keyCoder = null;
+    this.workExecutor = null;
+    this.workQueueExecutor = null;
+    this.budgetHandle = null;
+    this.keyTransitionListener = null;
+    this.work = null;
+    this.key = null;
+    this.outputBuilder = null;
+    this.sideInputStateFetcher = null;
+    this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
+    clearSinkFullHint();
+    this.stateBytesRead = 0;
+  }
+
   public void start(
-      @Nullable Object key,
       Work work,
       WindmillStateReader stateReader,
-      SideInputStateFetcher sideInputStateFetcher,
-      Windmill.WorkItemCommitRequest.Builder outputBuilder,
-      WorkExecutor workExecutor) {
-    this.key = key;
-    this.work = work;
+      WorkExecutor workExecutor,
+      BoundedQueueExecutor workQueueExecutor,
+      BoundedQueueExecutorWorkHandle budgetHandle,
+      @Nullable Coder<?> keyCoder,
+      KeyTransitionListener keyTransitionListener) {
+    reset();
+    this.keyCoder = keyCoder;
     this.workExecutor = workExecutor;
-    this.finishKeyCalled = false;
-    this.computationKey = WindmillComputationKey.create(computationId, 
work.getShardedKey());
-    this.sideInputStateFetcher = sideInputStateFetcher;
+    this.workQueueExecutor = workQueueExecutor;
+    this.budgetHandle = budgetHandle;
+    this.keyTransitionListener = keyTransitionListener;
+
     StreamingGlobalConfig config = globalConfigHandle.getConfig();
     // Snapshot the limits for entire bundle processing.
     this.operationalLimits = config.operationalLimits();
-    this.outputBuilder = outputBuilder;
-    this.sideInputCache.clear();
-    this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
-    clearSinkFullHint();
 
-    Instant processingTime = 
computeProcessingTime(work.getWorkItem().getTimers().getTimersList());
+    startForNewKey(work, stateReader);
+  }
 
-    Collection<? extends StepContext> stepContexts = getAllStepContexts();
-    if (!stepContexts.isEmpty()) {
-      // This must be only created once for the workItem as token validation 
will fail if the same
-      // work token is reused.
-      WindmillStateCache.ForKey cacheForKey =
-          stateCache.forKey(getComputationKey(), 
getWorkItem().getCacheToken(), getWorkToken());
-      for (StepContext stepContext : stepContexts) {
-        stepContext.start(stateReader, processingTime, cacheForKey, 
work.watermarks());
+  private @Nullable Object decodeKey(Work work) {
+    // If the read output KVs, then we can decode Windmill's byte key into 
userland
+    // key object and provide it to the execution context for use with per-key 
state.
+    // Otherwise, we pass null.
+    //
+    // The coder type that will be present is:
+    //     WindowedValueCoder(TimerOrElementCoder(KvCoder))
+    if (keyCoder != null) {
+      try {
+        return keyCoder.decode(work.getWorkItem().getKey().newInput(), 
Coder.Context.OUTER);
+      } catch (IOException e) {
+        throw new RuntimeException("Failed to decode key during processing", 
e);
+      }
+    }
+    return null;
+  }
+
+  private Windmill.WorkItemCommitRequest.Builder createOutputBuilder(Work 
work) {
+    return Windmill.WorkItemCommitRequest.newBuilder()
+        .setKey(work.getWorkItem().getKey())
+        .setShardingKey(work.getWorkItem().getShardingKey())
+        .setWorkToken(work.getWorkItem().getWorkToken())
+        .setCacheToken(work.getWorkItem().getCacheToken());
+  }
+
+  private void logHotKeyIfDetected(Work work, @Nullable Object decodedKey) {
+    if (work.getWorkItem().hasHotKeyInfo()) {
+      Windmill.HotKeyInfo hotKeyInfo = work.getWorkItem().getHotKeyInfo();
+      Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 
1000);
+      if (decodedKey != null && hotKeyLoggingEnabled) {
+        hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, decodedKey);
+      } else {
+        hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge);
       }
     }
   }
 
+  private void startStepContexts(
+      WindmillStateReader stateReader,
+      Instant processingTime,
+      WindmillStateCache.ForKey cacheForKey,
+      Watermarks watermarks) {
+    Collection<? extends StepContext> stepContexts = getAllStepContexts();

Review Comment:
   Remove the stepContexts variable and just inline the call in the loop?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########
@@ -243,50 +291,120 @@ public byte[] getCurrentRecordOffset() {
     return checkStateNotNull(activeReader).getCurrentRecordOffset();
   }
 
+  /** Reset context before using it on a new bundle */
+  public void reset() {
+    this.executedWorks = new ArrayList<>();
+    this.outputBuilders = new ArrayList<>();
+    this.accumulatedCallbacks = new HashMap<>();
+    // Work from prior bundles might have a reference to the old 
workBatchFailed.
+    // If the work gets retried it'll get the new workBatchFailed to notify 
failure.
+    this.workBatchFailed = new AtomicBoolean(false);
+    this.sideInputCache.clear();
+    this.activeStateReader = null;
+    this.activeReader = null;
+    this.keyCoder = null;
+    this.workExecutor = null;
+    this.workQueueExecutor = null;
+    this.budgetHandle = null;
+    this.keyTransitionListener = null;
+    this.work = null;
+    this.key = null;
+    this.outputBuilder = null;
+    this.sideInputStateFetcher = null;
+    this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
+    clearSinkFullHint();
+    this.stateBytesRead = 0;
+  }
+
   public void start(
-      @Nullable Object key,
       Work work,
       WindmillStateReader stateReader,
-      SideInputStateFetcher sideInputStateFetcher,
-      Windmill.WorkItemCommitRequest.Builder outputBuilder,
-      WorkExecutor workExecutor) {
-    this.key = key;
-    this.work = work;
+      WorkExecutor workExecutor,
+      BoundedQueueExecutor workQueueExecutor,
+      BoundedQueueExecutorWorkHandle budgetHandle,
+      @Nullable Coder<?> keyCoder,
+      KeyTransitionListener keyTransitionListener) {
+    reset();
+    this.keyCoder = keyCoder;
     this.workExecutor = workExecutor;
-    this.finishKeyCalled = false;
-    this.computationKey = WindmillComputationKey.create(computationId, 
work.getShardedKey());
-    this.sideInputStateFetcher = sideInputStateFetcher;
+    this.workQueueExecutor = workQueueExecutor;
+    this.budgetHandle = budgetHandle;
+    this.keyTransitionListener = keyTransitionListener;
+
     StreamingGlobalConfig config = globalConfigHandle.getConfig();
     // Snapshot the limits for entire bundle processing.
     this.operationalLimits = config.operationalLimits();
-    this.outputBuilder = outputBuilder;
-    this.sideInputCache.clear();
-    this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
-    clearSinkFullHint();
 
-    Instant processingTime = 
computeProcessingTime(work.getWorkItem().getTimers().getTimersList());
+    startForNewKey(work, stateReader);
+  }
 
-    Collection<? extends StepContext> stepContexts = getAllStepContexts();
-    if (!stepContexts.isEmpty()) {
-      // This must be only created once for the workItem as token validation 
will fail if the same
-      // work token is reused.
-      WindmillStateCache.ForKey cacheForKey =
-          stateCache.forKey(getComputationKey(), 
getWorkItem().getCacheToken(), getWorkToken());
-      for (StepContext stepContext : stepContexts) {
-        stepContext.start(stateReader, processingTime, cacheForKey, 
work.watermarks());
+  private @Nullable Object decodeKey(Work work) {
+    // If the read output KVs, then we can decode Windmill's byte key into 
userland
+    // key object and provide it to the execution context for use with per-key 
state.
+    // Otherwise, we pass null.
+    //
+    // The coder type that will be present is:
+    //     WindowedValueCoder(TimerOrElementCoder(KvCoder))
+    if (keyCoder != null) {
+      try {
+        return keyCoder.decode(work.getWorkItem().getKey().newInput(), 
Coder.Context.OUTER);
+      } catch (IOException e) {
+        throw new RuntimeException("Failed to decode key during processing", 
e);
+      }
+    }
+    return null;
+  }
+
+  private Windmill.WorkItemCommitRequest.Builder createOutputBuilder(Work 
work) {
+    return Windmill.WorkItemCommitRequest.newBuilder()
+        .setKey(work.getWorkItem().getKey())
+        .setShardingKey(work.getWorkItem().getShardingKey())
+        .setWorkToken(work.getWorkItem().getWorkToken())
+        .setCacheToken(work.getWorkItem().getCacheToken());
+  }
+
+  private void logHotKeyIfDetected(Work work, @Nullable Object decodedKey) {
+    if (work.getWorkItem().hasHotKeyInfo()) {
+      Windmill.HotKeyInfo hotKeyInfo = work.getWorkItem().getHotKeyInfo();
+      Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 
1000);
+      if (decodedKey != null && hotKeyLoggingEnabled) {
+        hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, decodedKey);
+      } else {
+        hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge);
       }
     }
   }
 
+  private void startStepContexts(
+      WindmillStateReader stateReader,
+      Instant processingTime,
+      WindmillStateCache.ForKey cacheForKey,
+      Watermarks watermarks) {
+    Collection<? extends StepContext> stepContexts = getAllStepContexts();
+    for (StepContext stepContext : stepContexts) {
+      stepContext.start(stateReader, processingTime, cacheForKey, watermarks);
+    }
+  }
+
   public void finishKey() {
-    checkState(!finishKeyCalled, "finishKey was already called");
+    if (finishKeyCalled) {
+      return;
+    }
+    if (activeStateReader != null) {
+      this.stateBytesRead += activeStateReader.getBytesRead();
+    }
+    if (sideInputStateFetcher != null) {
+      this.stateBytesRead += sideInputStateFetcher.getBytesRead();
+    }
     checkStateNotNull(workExecutor, "workExecutor must be set before calling 
finishKey()");
     try {
       workExecutor.finishKey(key);
     } catch (Exception e) {
       throw new RuntimeException(e);
     }
     this.finishKeyCalled = true;

Review Comment:
   set to true just after checking?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########
@@ -553,7 +673,102 @@ public Map<Long, Pair<Instant, Runnable>> flushState() {
       // RestrictionTracker.getProgress() or GetSize() are not defined.
       getOutputBuilder().setSourceBacklogBytes(backlogBytes);
     }
-    return callbacks;
+
+    this.accumulatedCallbacks.putAll(callbacks);
+
+    getOutputBuilder()
+        
.setSourceBytesProcessed(computeSourceBytesProcessed(sourceBytesProcessCounterName));
+  }
+
+  private final long computeSourceBytesProcessed(String 
sourceBytesCounterName) {
+    if (!(workExecutor instanceof DataflowMapTaskExecutor)) {
+      return 0L;
+    }
+    HashMap<String, ElementCounter> counters =
+        ((DataflowMapTaskExecutor) workExecutor)
+            .getReadOperation()
+            .receivers[0]
+            .getOutputCounters();
+
+    return Optional.ofNullable(counters.get(sourceBytesCounterName))
+        .map(counter -> ((OutputObjectAndByteCounter) 
counter).getByteCount().getAndReset())
+        .orElse(0L);
+  }
+
+  public Map<Long, Pair<Instant, Runnable>> flushState() {
+    return accumulatedCallbacks;
+  }
+
+  public boolean advance() {

Review Comment:
   This will presumably merge in more keys in the future, but it looks like it 
should just be removed for now. Add a TODO/comment with context?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########
@@ -553,7 +673,102 @@ public Map<Long, Pair<Instant, Runnable>> flushState() {
       // RestrictionTracker.getProgress() or GetSize() are not defined.
       getOutputBuilder().setSourceBacklogBytes(backlogBytes);
     }
-    return callbacks;
+
+    this.accumulatedCallbacks.putAll(callbacks);
+
+    getOutputBuilder()
+        
.setSourceBytesProcessed(computeSourceBytesProcessed(sourceBytesProcessCounterName));
+  }
+
+  private final long computeSourceBytesProcessed(String 
sourceBytesCounterName) {
+    if (!(workExecutor instanceof DataflowMapTaskExecutor)) {
+      return 0L;
+    }
+    HashMap<String, ElementCounter> counters =
+        ((DataflowMapTaskExecutor) workExecutor)
+            .getReadOperation()
+            .receivers[0]
+            .getOutputCounters();
+
+    return Optional.ofNullable(counters.get(sourceBytesCounterName))
+        .map(counter -> ((OutputObjectAndByteCounter) 
counter).getByteCount().getAndReset())
+        .orElse(0L);
+  }
+
+  public Map<Long, Pair<Instant, Runnable>> flushState() {
+    return accumulatedCallbacks;
+  }
+
+  public boolean advance() {
+    return false;
+  }
+
+  private void startForNewKey(Work newWork, WindmillStateReader reader) {
+    newWork.setState(Work.State.PROCESSING);
+    if (keyTransitionListener != null && this.work != null && this.work != 
newWork) {
+      keyTransitionListener.onKeyTransition(this.work, newWork);
+    }
+    this.key = decodeKey(newWork);
+    this.work = newWork;
+    this.finishKeyCalled = false;
+    this.computationKey = WindmillComputationKey.create(computationId, 
newWork.getShardedKey());
+
+    this.outputBuilder = createOutputBuilder(newWork);
+    this.outputBuilders.add(this.outputBuilder);
+    newWork.setOnFailureListener(this.workBatchFailed);
+    this.executedWorks.add(newWork);
+
+    logHotKeyIfDetected(newWork, this.key);
+
+    this.sideInputStateFetcher =
+        
sideInputStateFetcherFactory.createSideInputStateFetcher(newWork::fetchSideInput);
+    this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
+    this.activeReader = null;
+
+    // Note: We do NOT clear sideInputCache here, allowing Key B to reuse warm 
side inputs!
+
+    // Re-initialize state cache and state/timer internals across all step 
contexts
+    Instant processingTime =
+        
computeProcessingTime(newWork.getWorkItem().getTimers().getTimersList());
+    if (!getAllStepContexts().isEmpty()) {

Review Comment:
   Could remove the ! and invert the branches, since there is an else.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java:
##########
@@ -129,73 +126,77 @@ public static <K, T> WindowingWindmillReader<K, T> create(
     return new WindowingWindmillReader<>(coder, context, 
skipUndecodableElements);
   }
 
+  private KeyedWorkItem<K, T> createKeyedWorkItem() {
+    @SuppressWarnings("unchecked")
+    @Nullable
+    K key = (K) context.getKey();
+    return new WindmillKeyedWorkItem<>(
+        key,
+        context.getWorkItem(),
+        windowCoder,
+        windowsCoder,
+        valueCoder,
+        context.getWindmillTagEncoding(),
+        context.getDrainMode(),
+        skipUndecodableElements.isAccessible()
+            && Boolean.TRUE.equals(skipUndecodableElements.get()));
+  }
+
+  private boolean isEmpty(KeyedWorkItem<K, T> keyedWorkItem) {
+    return Iterables.isEmpty(keyedWorkItem.timersIterable())
+        && Iterables.isEmpty(keyedWorkItem.elementsIterable());
+  }
+
   @Override
   public NativeReaderIterator<WindowedValue<KeyedWorkItem<K, T>>> iterator() 
throws IOException {
-    final K key =
-        keyCoder.decode(
-            checkStateNotNull(context.getSerializedKey()).newInput(), 
Coder.Context.OUTER);
-    final WorkItem workItem = context.getWorkItem();
-    KeyedWorkItem<K, T> keyedWorkItem =
-        new WindmillKeyedWorkItem<>(
-            key,
-            workItem,
-            windowCoder,
-            windowsCoder,
-            valueCoder,
-            context.getWindmillTagEncoding(),
-            context.getDrainMode(),
-            skipUndecodableElements.isAccessible()
-                && Boolean.TRUE.equals(skipUndecodableElements.get()));
-    final boolean isEmptyWorkItem =
-        (Iterables.isEmpty(keyedWorkItem.timersIterable())
-            && Iterables.isEmpty(keyedWorkItem.elementsIterable()));
-    final WindowedValue<KeyedWorkItem<K, T>> value = new 
ValueInEmptyWindows<>(keyedWorkItem);
-
-    // Return a noop iterator when current workitem is an empty workitem.
-    if (isEmptyWorkItem) {
-      return new NativeReaderIterator<WindowedValue<KeyedWorkItem<K, T>>>() {
-        @Override
-        public boolean start() throws IOException {
-          context.finishKey();
-          return false;
-        }
-
-        @Override
-        public boolean advance() throws IOException {
-          return false;
+    return new NativeReaderIterator<WindowedValue<KeyedWorkItem<K, T>>>() {
+      private @Nullable WindowedValue<KeyedWorkItem<K, T>> current = null;
+
+      @Override
+      public boolean start() throws IOException {

Review Comment:
   Can we implement both start() and advance() with a shared helper method that
takes a boolean indicating whether or not to advance initially? That seems safer
if more setup/checking is added before processing an item.
   
   



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java:
##########
@@ -436,28 +450,50 @@ public void testStateTagEncodingBasedOnConfig() {
 
   @Test
   public void testSetBacklogBytes() {
-    Windmill.WorkItemCommitRequest.Builder outputBuilder =
-        Windmill.WorkItemCommitRequest.newBuilder();
     NameContext nameContext = NameContextsForTests.nameContextForTest();
     DataflowOperationContext operationContext =
         executionContext.createOperationContext(nameContext);
     StreamingModeExecutionContext.StepContext stepContext =
         executionContext.getStepContext(operationContext);
 
-    executionContext.start(
-        "key",
+    start(
         createMockWork(
             
Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(),
-            Watermarks.builder().setInputDataWatermark(new 
Instant(1000)).build()),
-        stateReader,
-        sideInputStateFetcher,
-        outputBuilder,
-        workExecutor);
+            Watermarks.builder().setInputDataWatermark(new 
Instant(1000)).build()));
 
     stepContext.setBacklogBytes(1234.0);
     executionContext.finishKey();
     executionContext.flushState();
 
-    assertEquals(1234, outputBuilder.getSourceBacklogBytes());
+    assertEquals(1234, 
executionContext.getOutputBuilder().getSourceBacklogBytes());
+  }
+
+  @Test
+  public void testFinishKeyReentrantSafety() {
+    start(
+        createMockWork(
+            
Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(),
+            Watermarks.builder().setInputDataWatermark(new 
Instant(1000)).build()));
+
+    // First call
+    executionContext.finishKey();
+    // Second call - should not throw any Exception
+    executionContext.finishKey();

Review Comment:
   When does this happen? Please add a comment here explaining why we want this to be the case.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReaderTest.java:
##########
@@ -0,0 +1,275 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import java.io.IOException;
+import java.util.List;
+import org.apache.beam.runners.core.KeyedWorkItem;
+import org.apache.beam.runners.dataflow.worker.streaming.Watermarks;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
+import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient;
+import 
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillTagEncodingV1;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.ListCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import 
org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder;
+import org.apache.beam.sdk.util.ByteStringOutputStream;
+import org.apache.beam.sdk.values.WindowedValue;
+import org.apache.beam.sdk.values.WindowedValues.FullWindowedValueCoder;
+import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class WindowingWindmillReaderTest {
+  private StreamingModeExecutionContext mockContext;
+  private WindowingWindmillReader<String, Long> reader;
+
+  @SuppressWarnings("unchecked")
+  @Before
+  public void setUp() {
+    mockContext = mock(StreamingModeExecutionContext.class);
+    when(mockContext.workIsFailed()).thenReturn(false);
+    
when(mockContext.getWindmillTagEncoding()).thenReturn(WindmillTagEncodingV1.instance());
+    when(mockContext.getDrainMode()).thenReturn(false);
+
+    Coder<String> keyCoder = StringUtf8Coder.of();
+    Coder<Long> valueCoder = VarLongCoder.of();
+    KvCoder<String, Long> kvCoder = KvCoder.of(keyCoder, valueCoder);
+    WindmillKeyedWorkItem.FakeKeyedWorkItemCoder<String, Long> 
keyedWorkItemCoder =
+        (WindmillKeyedWorkItem.FakeKeyedWorkItemCoder<String, Long>)
+            WindmillKeyedWorkItem.FakeKeyedWorkItemCoder.of(kvCoder);
+    FullWindowedValueCoder<KeyedWorkItem<String, Long>> coder =
+        FullWindowedValueCoder.of(keyedWorkItemCoder, 
IntervalWindowCoder.of());
+
+    reader =
+        WindowingWindmillReader.create(
+            coder, mockContext, ValueProvider.StaticValueProvider.of(false));
+  }
+
+  private static Work createMockWork(Windmill.WorkItem workItem) {
+    return Work.create(
+        workItem,
+        workItem.getSerializedSize(),
+        Watermarks.builder().setInputDataWatermark(new Instant(1000)).build(),
+        Work.createProcessingContext(
+            "computationId", new FakeGetDataClient(), ignored -> {}, 
mock(HeartbeatSender.class)),
+        false,
+        Instant::now);
+  }
+
+  private static ByteString encodeMetadata(List<IntervalWindow> windows) 
throws IOException {
+    ByteStringOutputStream stream = new ByteStringOutputStream();
+    PaneInfoCoder.INSTANCE.encode(PaneInfo.NO_FIRING, stream);
+    ListCoder.of(IntervalWindowCoder.of()).encode(windows, stream);
+    return stream.toByteString();
+  }
+
+  private static ByteString encodeValue(long value) throws IOException {
+    ByteStringOutputStream stream = new ByteStringOutputStream();
+    VarLongCoder.of().encode(value, stream);
+    return stream.toByteString();
+  }
+
+  @Test
+  public void testSingleNonEmptyKey() throws IOException {
+    IntervalWindow window = new IntervalWindow(new Instant(0), new 
Instant(1000));
+    Windmill.WorkItem workItem =
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.copyFromUtf8("key1"))
+            .setWorkToken(100L)
+            .addMessageBundles(
+                Windmill.InputMessageBundle.newBuilder()
+                    .setSourceComputationId("foo")
+                    .addMessages(
+                        Windmill.Message.newBuilder()
+                            .setTimestamp(1000)
+                            .setData(encodeValue(42L))
+                            
.setMetadata(encodeMetadata(ImmutableList.of(window)))
+                            .build())
+                    .build())
+            .build();
+    Work work = createMockWork(workItem);
+
+    when(mockContext.getKey()).thenReturn("key1");
+    when(mockContext.getWorkItem()).thenReturn(workItem);
+    when(mockContext.getWork()).thenReturn(work);
+    when(mockContext.advance()).thenReturn(false);
+
+    try (NativeReader.NativeReaderIterator<WindowedValue<KeyedWorkItem<String, 
Long>>> iter =
+        reader.iterator()) {
+      assertTrue(iter.start());
+      WindowedValue<KeyedWorkItem<String, Long>> current = iter.getCurrent();
+      assertEquals("key1", current.getValue().key());
+      assertFalse(Iterables.isEmpty(current.getValue().elementsIterable()));
+      WindowedValue<Long> elem = 
Iterables.getOnlyElement(current.getValue().elementsIterable());
+      assertEquals(42L, elem.getValue().longValue());
+
+      assertFalse(iter.advance());
+      verify(mockContext).finishKey();
+    }
+  }
+
+  @Test
+  public void testSingleEmptyKey() throws IOException {
+    Windmill.WorkItem workItem =
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.copyFromUtf8("key1"))
+            .setWorkToken(100L)
+            .build(); // No message bundles or timers
+    Work work = createMockWork(workItem);
+
+    when(mockContext.getKey()).thenReturn("key1");
+    when(mockContext.getWorkItem()).thenReturn(workItem);
+    when(mockContext.getWork()).thenReturn(work);
+    when(mockContext.advance()).thenReturn(false);
+
+    try (NativeReader.NativeReaderIterator<WindowedValue<KeyedWorkItem<String, 
Long>>> iter =
+        reader.iterator()) {
+      assertFalse(
+          iter.start()); // Should skip the empty key and return false because 
advance returns false
+      verify(mockContext).finishKey();
+    }
+  }
+
+  @Test
+  public void testMultipleKeys_withEmptyAndNonEmpty() throws IOException {
+    IntervalWindow window = new IntervalWindow(new Instant(0), new 
Instant(1000));
+    // Key 1: Empty
+    Windmill.WorkItem workItem1 =
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.copyFromUtf8("key1"))
+            .setWorkToken(100L)
+            .build();
+    Work work1 = createMockWork(workItem1);
+
+    // Key 2: Non-empty
+    Windmill.WorkItem workItem2 =
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.copyFromUtf8("key2"))
+            .setWorkToken(200L)
+            .addMessageBundles(
+                Windmill.InputMessageBundle.newBuilder()
+                    .setSourceComputationId("foo")
+                    .addMessages(
+                        Windmill.Message.newBuilder()
+                            .setTimestamp(2000)
+                            .setData(encodeValue(84L))
+                            
.setMetadata(encodeMetadata(ImmutableList.of(window)))
+                            .build())
+                    .build())
+            .build();
+    Work work2 = createMockWork(workItem2);
+
+    // Key 3: Empty
+    Windmill.WorkItem workItem3 =
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.copyFromUtf8("key3"))
+            .setWorkToken(300L)
+            .build();
+    Work work3 = createMockWork(workItem3);
+
+    // Initial state
+    when(mockContext.getKey()).thenReturn("key1");
+    when(mockContext.getWorkItem()).thenReturn(workItem1);
+    when(mockContext.getWork()).thenReturn(work1);
+
+    // Mock transition behaviour of context.advance()
+    when(mockContext.advance())

Review Comment:
   Ditto: it seems a bit odd to mock this in here.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java:
##########
@@ -307,21 +292,18 @@ private void processWork(
     } finally {
       // Update total processing time counters. Updating in finally clause 
ensures that
       // work items causing exceptions are also accounted in time spent.
-      long processingTimeMsecs =
-          TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - 
processingStartTimeNanos);
-      stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs);
-
-      // Attribute all the processing to timers if the work item contains any 
timers.
-      // Tests show that work items rarely contain both timers and message 
bundles. It should
-      // be a fairly close approximation.
-      // Another option: Derive time split between messages and timers based 
on recent totals.
-      // either here or in DFE.
-      if (work.getWorkItem().hasTimers()) {
-        stageInfo.timerProcessingMsecs().addValue(processingTimeMsecs);
+      recordProcessingTime(stageInfo, workBatch, work, 
processingStartTimeNanos);

Review Comment:
   Could we simplify here by creating a single-element workBatch if
workBatch is null, and then using the batch for this method and below?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########
@@ -243,50 +291,120 @@ public byte[] getCurrentRecordOffset() {
     return checkStateNotNull(activeReader).getCurrentRecordOffset();
   }
 
+  /** Reset context before using it on a new bundle */
+  public void reset() {
+    this.executedWorks = new ArrayList<>();
+    this.outputBuilders = new ArrayList<>();
+    this.accumulatedCallbacks = new HashMap<>();
+    // Work from prior bundles might have a reference to the old 
workBatchFailed.
+    // If the work gets retried it'll get the new workBatchFailed to notify 
failure.
+    this.workBatchFailed = new AtomicBoolean(false);
+    this.sideInputCache.clear();
+    this.activeStateReader = null;
+    this.activeReader = null;
+    this.keyCoder = null;
+    this.workExecutor = null;
+    this.workQueueExecutor = null;
+    this.budgetHandle = null;
+    this.keyTransitionListener = null;
+    this.work = null;
+    this.key = null;
+    this.outputBuilder = null;
+    this.sideInputStateFetcher = null;
+    this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
+    clearSinkFullHint();
+    this.stateBytesRead = 0;
+  }
+
   public void start(
-      @Nullable Object key,
       Work work,
       WindmillStateReader stateReader,
-      SideInputStateFetcher sideInputStateFetcher,
-      Windmill.WorkItemCommitRequest.Builder outputBuilder,
-      WorkExecutor workExecutor) {
-    this.key = key;
-    this.work = work;
+      WorkExecutor workExecutor,
+      BoundedQueueExecutor workQueueExecutor,
+      BoundedQueueExecutorWorkHandle budgetHandle,
+      @Nullable Coder<?> keyCoder,
+      KeyTransitionListener keyTransitionListener) {
+    reset();

Review Comment:
   We are also resetting when we are done, so this should be unneeded.
   
   Could we instead have a couple of checks that the lists are empty?
   Or reset() could set them to null and start() could do the allocations.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########
@@ -243,50 +291,120 @@ public byte[] getCurrentRecordOffset() {
     return checkStateNotNull(activeReader).getCurrentRecordOffset();
   }
 
+  /** Reset context before using it on a new bundle */
+  public void reset() {
+    this.executedWorks = new ArrayList<>();
+    this.outputBuilders = new ArrayList<>();
+    this.accumulatedCallbacks = new HashMap<>();
+    // Work from prior bundles might have a reference to the old 
workBatchFailed.
+    // If the work gets retried it'll get the new workBatchFailed to notify 
failure.
+    this.workBatchFailed = new AtomicBoolean(false);
+    this.sideInputCache.clear();
+    this.activeStateReader = null;
+    this.activeReader = null;
+    this.keyCoder = null;
+    this.workExecutor = null;
+    this.workQueueExecutor = null;
+    this.budgetHandle = null;
+    this.keyTransitionListener = null;
+    this.work = null;
+    this.key = null;
+    this.outputBuilder = null;
+    this.sideInputStateFetcher = null;
+    this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
+    clearSinkFullHint();
+    this.stateBytesRead = 0;
+  }
+
   public void start(
-      @Nullable Object key,
       Work work,
       WindmillStateReader stateReader,
-      SideInputStateFetcher sideInputStateFetcher,
-      Windmill.WorkItemCommitRequest.Builder outputBuilder,
-      WorkExecutor workExecutor) {
-    this.key = key;
-    this.work = work;
+      WorkExecutor workExecutor,
+      BoundedQueueExecutor workQueueExecutor,
+      BoundedQueueExecutorWorkHandle budgetHandle,
+      @Nullable Coder<?> keyCoder,
+      KeyTransitionListener keyTransitionListener) {
+    reset();
+    this.keyCoder = keyCoder;
     this.workExecutor = workExecutor;
-    this.finishKeyCalled = false;
-    this.computationKey = WindmillComputationKey.create(computationId, 
work.getShardedKey());
-    this.sideInputStateFetcher = sideInputStateFetcher;
+    this.workQueueExecutor = workQueueExecutor;
+    this.budgetHandle = budgetHandle;
+    this.keyTransitionListener = keyTransitionListener;
+
     StreamingGlobalConfig config = globalConfigHandle.getConfig();
     // Snapshot the limits for entire bundle processing.
     this.operationalLimits = config.operationalLimits();
-    this.outputBuilder = outputBuilder;
-    this.sideInputCache.clear();
-    this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
-    clearSinkFullHint();
 
-    Instant processingTime = 
computeProcessingTime(work.getWorkItem().getTimers().getTimersList());
+    startForNewKey(work, stateReader);
+  }
 
-    Collection<? extends StepContext> stepContexts = getAllStepContexts();
-    if (!stepContexts.isEmpty()) {
-      // This must be only created once for the workItem as token validation 
will fail if the same
-      // work token is reused.
-      WindmillStateCache.ForKey cacheForKey =
-          stateCache.forKey(getComputationKey(), 
getWorkItem().getCacheToken(), getWorkToken());
-      for (StepContext stepContext : stepContexts) {
-        stepContext.start(stateReader, processingTime, cacheForKey, 
work.watermarks());
+  private @Nullable Object decodeKey(Work work) {
+    // If the read output KVs, then we can decode Windmill's byte key into 
userland
+    // key object and provide it to the execution context for use with per-key 
state.
+    // Otherwise, we pass null.
+    //
+    // The coder type that will be present is:
+    //     WindowedValueCoder(TimerOrElementCoder(KvCoder))
+    if (keyCoder != null) {
+      try {
+        return keyCoder.decode(work.getWorkItem().getKey().newInput(), 
Coder.Context.OUTER);
+      } catch (IOException e) {
+        throw new RuntimeException("Failed to decode key during processing", 
e);

Review Comment:
   I wonder if it would be better to have this and the parent methods throw
CoderException. It could make it clearer at the call sites that start or key
transition could possibly fail for this reason (and later support for
--skipInputElementsWithDecodingExceptions could be added if desired).



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########
@@ -243,50 +291,120 @@ public byte[] getCurrentRecordOffset() {
     return checkStateNotNull(activeReader).getCurrentRecordOffset();
   }
 
+  /** Reset context before using it on a new bundle */
+  public void reset() {
+    this.executedWorks = new ArrayList<>();
+    this.outputBuilders = new ArrayList<>();
+    this.accumulatedCallbacks = new HashMap<>();
+    // Work from prior bundles might have a reference to the old 
workBatchFailed.
+    // If the work gets retried it'll get the new workBatchFailed to notify 
failure.
+    this.workBatchFailed = new AtomicBoolean(false);
+    this.sideInputCache.clear();
+    this.activeStateReader = null;
+    this.activeReader = null;
+    this.keyCoder = null;
+    this.workExecutor = null;
+    this.workQueueExecutor = null;
+    this.budgetHandle = null;
+    this.keyTransitionListener = null;
+    this.work = null;
+    this.key = null;
+    this.outputBuilder = null;
+    this.sideInputStateFetcher = null;
+    this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
+    clearSinkFullHint();
+    this.stateBytesRead = 0;
+  }
+
   public void start(
-      @Nullable Object key,
       Work work,
       WindmillStateReader stateReader,
-      SideInputStateFetcher sideInputStateFetcher,
-      Windmill.WorkItemCommitRequest.Builder outputBuilder,
-      WorkExecutor workExecutor) {
-    this.key = key;
-    this.work = work;
+      WorkExecutor workExecutor,
+      BoundedQueueExecutor workQueueExecutor,
+      BoundedQueueExecutorWorkHandle budgetHandle,
+      @Nullable Coder<?> keyCoder,
+      KeyTransitionListener keyTransitionListener) {
+    reset();
+    this.keyCoder = keyCoder;
     this.workExecutor = workExecutor;
-    this.finishKeyCalled = false;
-    this.computationKey = WindmillComputationKey.create(computationId, 
work.getShardedKey());
-    this.sideInputStateFetcher = sideInputStateFetcher;
+    this.workQueueExecutor = workQueueExecutor;
+    this.budgetHandle = budgetHandle;
+    this.keyTransitionListener = keyTransitionListener;
+
     StreamingGlobalConfig config = globalConfigHandle.getConfig();
     // Snapshot the limits for entire bundle processing.
     this.operationalLimits = config.operationalLimits();
-    this.outputBuilder = outputBuilder;
-    this.sideInputCache.clear();
-    this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
-    clearSinkFullHint();
 
-    Instant processingTime = 
computeProcessingTime(work.getWorkItem().getTimers().getTimersList());
+    startForNewKey(work, stateReader);
+  }
 
-    Collection<? extends StepContext> stepContexts = getAllStepContexts();
-    if (!stepContexts.isEmpty()) {
-      // This must be only created once for the workItem as token validation 
will fail if the same
-      // work token is reused.
-      WindmillStateCache.ForKey cacheForKey =
-          stateCache.forKey(getComputationKey(), 
getWorkItem().getCacheToken(), getWorkToken());
-      for (StepContext stepContext : stepContexts) {
-        stepContext.start(stateReader, processingTime, cacheForKey, 
work.watermarks());
+  private @Nullable Object decodeKey(Work work) {
+    // If the read output KVs, then we can decode Windmill's byte key into 
userland
+    // key object and provide it to the execution context for use with per-key 
state.
+    // Otherwise, we pass null.
+    //
+    // The coder type that will be present is:
+    //     WindowedValueCoder(TimerOrElementCoder(KvCoder))
+    if (keyCoder != null) {
+      try {
+        return keyCoder.decode(work.getWorkItem().getKey().newInput(), 
Coder.Context.OUTER);
+      } catch (IOException e) {
+        throw new RuntimeException("Failed to decode key during processing", 
e);
+      }
+    }
+    return null;
+  }
+
+  private Windmill.WorkItemCommitRequest.Builder createOutputBuilder(Work 
work) {
+    return Windmill.WorkItemCommitRequest.newBuilder()
+        .setKey(work.getWorkItem().getKey())
+        .setShardingKey(work.getWorkItem().getShardingKey())
+        .setWorkToken(work.getWorkItem().getWorkToken())
+        .setCacheToken(work.getWorkItem().getCacheToken());
+  }
+
+  private void logHotKeyIfDetected(Work work, @Nullable Object decodedKey) {
+    if (work.getWorkItem().hasHotKeyInfo()) {
+      Windmill.HotKeyInfo hotKeyInfo = work.getWorkItem().getHotKeyInfo();
+      Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 
1000);
+      if (decodedKey != null && hotKeyLoggingEnabled) {
+        hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, decodedKey);
+      } else {
+        hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge);
       }
     }
   }
 
+  private void startStepContexts(
+      WindmillStateReader stateReader,
+      Instant processingTime,
+      WindmillStateCache.ForKey cacheForKey,
+      Watermarks watermarks) {
+    Collection<? extends StepContext> stepContexts = getAllStepContexts();
+    for (StepContext stepContext : stepContexts) {
+      stepContext.start(stateReader, processingTime, cacheForKey, watermarks);
+    }
+  }
+
   public void finishKey() {
-    checkState(!finishKeyCalled, "finishKey was already called");
+    if (finishKeyCalled) {
+      return;
+    }
+    if (activeStateReader != null) {
+      this.stateBytesRead += activeStateReader.getBytesRead();
+    }
+    if (sideInputStateFetcher != null) {
+      this.stateBytesRead += sideInputStateFetcher.getBytesRead();
+    }
     checkStateNotNull(workExecutor, "workExecutor must be set before calling 
finishKey()");

Review Comment:
   move to top?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########
@@ -553,7 +673,102 @@ public Map<Long, Pair<Instant, Runnable>> flushState() {
       // RestrictionTracker.getProgress() or GetSize() are not defined.
       getOutputBuilder().setSourceBacklogBytes(backlogBytes);
     }
-    return callbacks;
+
+    this.accumulatedCallbacks.putAll(callbacks);
+
+    getOutputBuilder()
+        
.setSourceBytesProcessed(computeSourceBytesProcessed(sourceBytesProcessCounterName));
+  }
+
+  private final long computeSourceBytesProcessed(String 
sourceBytesCounterName) {
+    if (!(workExecutor instanceof DataflowMapTaskExecutor)) {
+      return 0L;
+    }
+    HashMap<String, ElementCounter> counters =
+        ((DataflowMapTaskExecutor) workExecutor)
+            .getReadOperation()
+            .receivers[0]
+            .getOutputCounters();
+
+    return Optional.ofNullable(counters.get(sourceBytesCounterName))
+        .map(counter -> ((OutputObjectAndByteCounter) 
counter).getByteCount().getAndReset())
+        .orElse(0L);
+  }
+
+  public Map<Long, Pair<Instant, Runnable>> flushState() {
+    return accumulatedCallbacks;
+  }
+
+  public boolean advance() {
+    return false;
+  }
+
+  private void startForNewKey(Work newWork, WindmillStateReader reader) {
+    newWork.setState(Work.State.PROCESSING);
+    if (keyTransitionListener != null && this.work != null && this.work != 
newWork) {
+      keyTransitionListener.onKeyTransition(this.work, newWork);
+    }
+    this.key = decodeKey(newWork);

Review Comment:
   We could consider moving this decode outside startForNewKey and passing the key
in. That would make it easier to skip keys that had decoding errors and keep
the rest of the batch.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########
@@ -553,7 +673,102 @@ public Map<Long, Pair<Instant, Runnable>> flushState() {
       // RestrictionTracker.getProgress() or GetSize() are not defined.
       getOutputBuilder().setSourceBacklogBytes(backlogBytes);
     }
-    return callbacks;
+
+    this.accumulatedCallbacks.putAll(callbacks);
+
+    getOutputBuilder()
+        
.setSourceBytesProcessed(computeSourceBytesProcessed(sourceBytesProcessCounterName));
+  }
+
+  private final long computeSourceBytesProcessed(String 
sourceBytesCounterName) {
+    if (!(workExecutor instanceof DataflowMapTaskExecutor)) {
+      return 0L;
+    }
+    HashMap<String, ElementCounter> counters =
+        ((DataflowMapTaskExecutor) workExecutor)
+            .getReadOperation()
+            .receivers[0]
+            .getOutputCounters();
+
+    return Optional.ofNullable(counters.get(sourceBytesCounterName))
+        .map(counter -> ((OutputObjectAndByteCounter) 
counter).getByteCount().getAndReset())
+        .orElse(0L);
+  }
+
+  public Map<Long, Pair<Instant, Runnable>> flushState() {

Review Comment:
   The return value is pretty confusing here without a comment or context.
   
   Should we change this to getFinalizationCallbacks() and have flush not
   return anything?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########
@@ -553,7 +673,102 @@ public Map<Long, Pair<Instant, Runnable>> flushState() {
       // RestrictionTracker.getProgress() or GetSize() are not defined.
       getOutputBuilder().setSourceBacklogBytes(backlogBytes);
     }
-    return callbacks;
+
+    this.accumulatedCallbacks.putAll(callbacks);
+
+    getOutputBuilder()
+        
.setSourceBytesProcessed(computeSourceBytesProcessed(sourceBytesProcessCounterName));
+  }
+
+  private final long computeSourceBytesProcessed(String 
sourceBytesCounterName) {
+    if (!(workExecutor instanceof DataflowMapTaskExecutor)) {
+      return 0L;
+    }
+    HashMap<String, ElementCounter> counters =
+        ((DataflowMapTaskExecutor) workExecutor)
+            .getReadOperation()
+            .receivers[0]
+            .getOutputCounters();
+
+    return Optional.ofNullable(counters.get(sourceBytesCounterName))
+        .map(counter -> ((OutputObjectAndByteCounter) 
counter).getByteCount().getAndReset())
+        .orElse(0L);
+  }
+
+  public Map<Long, Pair<Instant, Runnable>> flushState() {
+    return accumulatedCallbacks;
+  }
+
+  public boolean advance() {
+    return false;
+  }
+
+  private void startForNewKey(Work newWork, WindmillStateReader reader) {
+    newWork.setState(Work.State.PROCESSING);
+    if (keyTransitionListener != null && this.work != null && this.work != 
newWork) {
+      keyTransitionListener.onKeyTransition(this.work, newWork);
+    }
+    this.key = decodeKey(newWork);
+    this.work = newWork;
+    this.finishKeyCalled = false;
+    this.computationKey = WindmillComputationKey.create(computationId, 
newWork.getShardedKey());
+
+    this.outputBuilder = createOutputBuilder(newWork);
+    this.outputBuilders.add(this.outputBuilder);
+    newWork.setOnFailureListener(this.workBatchFailed);
+    this.executedWorks.add(newWork);
+
+    logHotKeyIfDetected(newWork, this.key);
+
+    this.sideInputStateFetcher =
+        
sideInputStateFetcherFactory.createSideInputStateFetcher(newWork::fetchSideInput);
+    this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
+    this.activeReader = null;
+
+    // Note: We do NOT clear sideInputCache here, allowing Key B to reuse warm 
side inputs!
+
+    // Re-initialize state cache and state/timer internals across all step 
contexts
+    Instant processingTime =
+        
computeProcessingTime(newWork.getWorkItem().getTimers().getTimersList());
+    if (!getAllStepContexts().isEmpty()) {
+      // This must be only created once for a workItem as token validation 
will fail if the same
+      // work token is reused.
+      WindmillStateCache.ForKey cacheForKey =
+          stateCache.forKey(
+              getComputationKey(), newWork.getWorkItem().getCacheToken(), 
getWorkToken());
+      this.activeStateReader = reader;
+      startStepContexts(reader, processingTime, cacheForKey, 
newWork.watermarks());
+    } else {
+      this.activeStateReader = null;

Review Comment:
   maybe this should just check it is already null?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########
@@ -553,7 +673,102 @@ public Map<Long, Pair<Instant, Runnable>> flushState() {
       // RestrictionTracker.getProgress() or GetSize() are not defined.
       getOutputBuilder().setSourceBacklogBytes(backlogBytes);
     }
-    return callbacks;
+
+    this.accumulatedCallbacks.putAll(callbacks);
+
+    getOutputBuilder()
+        
.setSourceBytesProcessed(computeSourceBytesProcessed(sourceBytesProcessCounterName));
+  }
+
+  private final long computeSourceBytesProcessed(String 
sourceBytesCounterName) {
+    if (!(workExecutor instanceof DataflowMapTaskExecutor)) {
+      return 0L;
+    }
+    HashMap<String, ElementCounter> counters =
+        ((DataflowMapTaskExecutor) workExecutor)
+            .getReadOperation()
+            .receivers[0]
+            .getOutputCounters();
+
+    return Optional.ofNullable(counters.get(sourceBytesCounterName))
+        .map(counter -> ((OutputObjectAndByteCounter) 
counter).getByteCount().getAndReset())
+        .orElse(0L);
+  }
+
+  public Map<Long, Pair<Instant, Runnable>> flushState() {
+    return accumulatedCallbacks;
+  }
+
+  public boolean advance() {
+    return false;
+  }
+
+  private void startForNewKey(Work newWork, WindmillStateReader reader) {
+    newWork.setState(Work.State.PROCESSING);
+    if (keyTransitionListener != null && this.work != null && this.work != 
newWork) {
+      keyTransitionListener.onKeyTransition(this.work, newWork);
+    }
+    this.key = decodeKey(newWork);
+    this.work = newWork;
+    this.finishKeyCalled = false;
+    this.computationKey = WindmillComputationKey.create(computationId, 
newWork.getShardedKey());
+
+    this.outputBuilder = createOutputBuilder(newWork);
+    this.outputBuilders.add(this.outputBuilder);
+    newWork.setOnFailureListener(this.workBatchFailed);
+    this.executedWorks.add(newWork);
+
+    logHotKeyIfDetected(newWork, this.key);
+
+    this.sideInputStateFetcher =
+        
sideInputStateFetcherFactory.createSideInputStateFetcher(newWork::fetchSideInput);
+    this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
+    this.activeReader = null;
+
+    // Note: We do NOT clear sideInputCache here, allowing Key B to reuse warm 
side inputs!
+
+    // Re-initialize state cache and state/timer internals across all step 
contexts
+    Instant processingTime =
+        
computeProcessingTime(newWork.getWorkItem().getTimers().getTimersList());
+    if (!getAllStepContexts().isEmpty()) {
+      // This must be only created once for a workItem as token validation 
will fail if the same
+      // work token is reused.
+      WindmillStateCache.ForKey cacheForKey =
+          stateCache.forKey(
+              getComputationKey(), newWork.getWorkItem().getCacheToken(), 
getWorkToken());
+      this.activeStateReader = reader;
+      startStepContexts(reader, processingTime, cacheForKey, 
newWork.watermarks());
+    } else {
+      this.activeStateReader = null;
+    }
+  }
+
+  public List<Work> getExecutedWorks() {
+    return executedWorks;
+  }
+
+  public long getStateBytesRead() {
+    return stateBytesRead;
+  }
+
+  public List<Windmill.WorkItemCommitRequest> getWorkItemCommits() {
+    List<Windmill.WorkItemCommitRequest> commits = new 
ArrayList<>(outputBuilders.size());
+    for (Windmill.WorkItemCommitRequest.Builder builder : outputBuilders) {
+      commits.add(builder.build());
+    }
+    return commits;
+  }
+
+  public Map<Long, Pair<Instant, Runnable>> getAccumulatedCallbacks() {
+    return accumulatedCallbacks;
+  }
+
+  public @Nullable Object getKey() {
+    return key;
+  }
+
+  public Work getWork() {

Review Comment:
   how about moving getKey, getWorkItem, getWork, getExecutedWorks all next to 
each other with some comments?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########
@@ -553,7 +673,102 @@ public Map<Long, Pair<Instant, Runnable>> flushState() {
       // RestrictionTracker.getProgress() or GetSize() are not defined.
       getOutputBuilder().setSourceBacklogBytes(backlogBytes);
     }
-    return callbacks;
+
+    this.accumulatedCallbacks.putAll(callbacks);
+
+    getOutputBuilder()
+        
.setSourceBytesProcessed(computeSourceBytesProcessed(sourceBytesProcessCounterName));
+  }
+
+  private final long computeSourceBytesProcessed(String 
sourceBytesCounterName) {
+    if (!(workExecutor instanceof DataflowMapTaskExecutor)) {
+      return 0L;
+    }
+    HashMap<String, ElementCounter> counters =
+        ((DataflowMapTaskExecutor) workExecutor)
+            .getReadOperation()
+            .receivers[0]
+            .getOutputCounters();
+
+    return Optional.ofNullable(counters.get(sourceBytesCounterName))
+        .map(counter -> ((OutputObjectAndByteCounter) 
counter).getByteCount().getAndReset())
+        .orElse(0L);
+  }
+
+  public Map<Long, Pair<Instant, Runnable>> flushState() {
+    return accumulatedCallbacks;
+  }
+
+  public boolean advance() {
+    return false;
+  }
+
+  private void startForNewKey(Work newWork, WindmillStateReader reader) {
+    newWork.setState(Work.State.PROCESSING);
+    if (keyTransitionListener != null && this.work != null && this.work != 
newWork) {
+      keyTransitionListener.onKeyTransition(this.work, newWork);
+    }
+    this.key = decodeKey(newWork);
+    this.work = newWork;
+    this.finishKeyCalled = false;
+    this.computationKey = WindmillComputationKey.create(computationId, 
newWork.getShardedKey());
+
+    this.outputBuilder = createOutputBuilder(newWork);
+    this.outputBuilders.add(this.outputBuilder);
+    newWork.setOnFailureListener(this.workBatchFailed);
+    this.executedWorks.add(newWork);
+
+    logHotKeyIfDetected(newWork, this.key);
+
+    this.sideInputStateFetcher =
+        
sideInputStateFetcherFactory.createSideInputStateFetcher(newWork::fetchSideInput);
+    this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
+    this.activeReader = null;
+
+    // Note: We do NOT clear sideInputCache here, allowing Key B to reuse warm 
side inputs!
+
+    // Re-initialize state cache and state/timer internals across all step 
contexts
+    Instant processingTime =
+        
computeProcessingTime(newWork.getWorkItem().getTimers().getTimersList());
+    if (!getAllStepContexts().isEmpty()) {
+      // This must be only created once for a workItem as token validation 
will fail if the same
+      // work token is reused.
+      WindmillStateCache.ForKey cacheForKey =
+          stateCache.forKey(
+              getComputationKey(), newWork.getWorkItem().getCacheToken(), 
getWorkToken());
+      this.activeStateReader = reader;
+      startStepContexts(reader, processingTime, cacheForKey, 
newWork.watermarks());
+    } else {
+      this.activeStateReader = null;
+    }
+  }
+
+  public List<Work> getExecutedWorks() {
+    return executedWorks;
+  }
+
+  public long getStateBytesRead() {
+    return stateBytesRead;
+  }
+
+  public List<Windmill.WorkItemCommitRequest> getWorkItemCommits() {
+    List<Windmill.WorkItemCommitRequest> commits = new 
ArrayList<>(outputBuilders.size());
+    for (Windmill.WorkItemCommitRequest.Builder builder : outputBuilders) {
+      commits.add(builder.build());
+    }
+    return commits;
+  }
+
+  public Map<Long, Pair<Instant, Runnable>> getAccumulatedCallbacks() {
+    return accumulatedCallbacks;
+  }
+
+  public @Nullable Object getKey() {

Review Comment:
   // Returns the current key being processed or null if an unkeyed stage.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java:
##########
@@ -3511,8 +3520,8 @@ public void testExceptionInvalidatesCache() throws 
Exception {
     }
 
     // Ensure that the invalidated dofn had tearDown called on them.
-    assertEquals(1, TestExceptionInvalidatesCacheFn.tearDownCallCount.get());
-    assertEquals(2, TestExceptionInvalidatesCacheFn.setupCallCount.get());
+    assertEquals(2, TestExceptionInvalidatesCacheFn.tearDownCallCount.get());

Review Comment:
   why did this change?



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java:
##########
@@ -131,6 +133,76 @@ public void testFinishKeyCalled() throws Exception {
     }
   }
 
+  @Test
+  public void testAdvanceKeyChaining() throws Exception {
+    StreamingModeExecutionContext mockContext = 
mock(StreamingModeExecutionContext.class);
+    when(mockContext.workIsFailed()).thenReturn(false);
+
+    // Work item A (1 message)
+    Windmill.WorkItem workItemA =
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.copyFromUtf8("keyA"))
+            .setWorkToken(100L)
+            .addMessageBundles(
+                Windmill.InputMessageBundle.newBuilder()
+                    .setSourceComputationId("foo")
+                    .addMessages(
+                        Windmill.Message.newBuilder()
+                            .setTimestamp(1000)
+                            .setData(ByteString.EMPTY)
+                            .build())
+                    .build())
+            .build();
+    when(mockContext.getWorkItem()).thenReturn(workItemA);
+
+    // Work item B (1 message)
+    Windmill.WorkItem workItemB =
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.copyFromUtf8("keyB"))
+            .setWorkToken(200L)
+            .addMessageBundles(
+                Windmill.InputMessageBundle.newBuilder()
+                    .setSourceComputationId("foo")
+                    .addMessages(
+                        Windmill.Message.newBuilder()
+                            .setTimestamp(2000)
+                            .setData(ByteString.EMPTY)
+                            .build())
+                    .build())
+            .build();
+
+    // Set up context.advance() to mock transition

Review Comment:
   It seems a bit odd to test this by just mocking out responses. It 
seems better suited for a test once support of multiple items via advance is 
actually within the context.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to