scwhittle commented on code in PR #38814:
URL: https://github.com/apache/beam/pull/38814#discussion_r3361253798


##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java:
##########
@@ -237,6 +243,17 @@ public void setProcessingThreadName(String 
processingThreadName) {
   @Override
   public void setFailed() {
     this.isFailed = true;
+    Runnable listener = onFailureListener;
+    if (listener != null) {
+      listener.run();
+    }
+  }
+
+  public void setOnFailureListener(@Nullable Runnable listener) {
+    this.onFailureListener = listener;
+    if (isFailed && listener != null) {

Review Comment:
   It seems we might call the listener twice if setOnFailureListener runs just
   after setFailed has set isFailed to true but before setFailed reads
   onFailureListener. Should we synchronize the isFailed and onFailureListener
   fields together?
   
   This is moot if we remove the listener and just share a failure bit.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java:
##########
@@ -388,86 +372,143 @@ private ExecuteWorkResult executeWork(
       SideInputStateFetcher localSideInputStateFetcher =
           
sideInputStateFetcherFactory.createSideInputStateFetcher(work::fetchSideInput);
 
-      // If the read output KVs, then we can decode Windmill's byte key into 
userland
-      // key object and provide it to the execution context for use with 
per-key state.
-      // Otherwise, we pass null.
-      //
-      // The coder type that will be present is:
-      //     WindowedValueCoder(TimerOrElementCoder(KvCoder))
-      Optional<Coder<?>> keyCoder = computationWorkExecutor.keyCoder();
-      @SuppressWarnings("deprecation")
-      @Nullable
-      final Object executionKey =
-          !keyCoder.isPresent() ? null : keyCoder.get().decode(key.newInput(), 
Coder.Context.OUTER);
-
-      if (workItem.hasHotKeyInfo()) {
-        Windmill.HotKeyInfo hotKeyInfo = workItem.getHotKeyInfo();
-        Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 
1000);
-
-        String stepName = 
getShuffleTaskStepName(computationState.getMapTask());
-        if (executionKey != null
-            && (options.isHotKeyLoggingEnabled()
-                || hasExperiment(options, "enable_hot_key_logging"))
-            && keyCoder.isPresent()) {
-          hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, executionKey);
-        } else {
-          hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge);
-        }
-      }
+      StreamingModeExecutionContext.KeySwitchListener keySwitchListener =
+          createKeySwitchListener(computationState);
 
       // Blocks while executing work.
       computationWorkExecutor.executeWork(
-          executionKey, work, stateReader, localSideInputStateFetcher, 
outputBuilder);
+          work, stateReader, localSideInputStateFetcher, workExecutor, handle, 
keySwitchListener);
 
-      if (work.isFailed()) {
-        throw new WorkItemCancelledException(workItem.getShardingKey());
+      StreamingModeExecutionContext context = 
computationWorkExecutor.context();
+      if (context.workIsFailed()) {
+        throw new 
WorkItemCancelledException(work.getWorkItem().getShardingKey());
       }
 
-      // Reports source bytes processed to WorkItemCommitRequest if available.
-      try {
-        long sourceBytesProcessed =
-            computationWorkExecutor.computeSourceBytesProcessed(
-                computationState.sourceBytesProcessCounterName());
-        outputBuilder.setSourceBytesProcessed(sourceBytesProcessed);
-      } catch (Exception e) {
-        LOG.error("{}", e.toString());
-      }
-
-      
commitFinalizer.cacheCommitFinalizers(computationWorkExecutor.context().flushState());
+      // Retrieve executed works, output builders, and accumulated callbacks 
from execution context
+      List<Work> workBatch = context.getExecutedWorks();
+      List<Windmill.WorkItemCommitRequest.Builder> outputBuilders = 
context.getOutputBuilders();
+      Map<Long, Pair<Instant, Runnable>> accumulatedCallbacks = 
context.getAccumulatedCallbacks();
 
+      context.clear();
       // Release the execution state for another thread to use.
       computationState.releaseComputationWorkExecutor(computationWorkExecutor);
       computationWorkExecutor = null;
 
-      work.setState(Work.State.COMMIT_QUEUED);
-      
outputBuilder.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler));
-
       return ExecuteWorkResult.create(
-          outputBuilder, stateReader.getBytesRead() + 
localSideInputStateFetcher.getBytesRead());
+          workBatch,
+          outputBuilders,
+          accumulatedCallbacks,
+          context.getStateBytesRead() + 
localSideInputStateFetcher.getBytesRead());
     } catch (Throwable t) {
       if (computationWorkExecutor != null) {
         // If processing failed due to a thrown exception, close the 
executionState. Do not
         // return/release the executionState back to computationState as that 
will lead to this
         // executionState instance being reused.
-        LOG.debug("Invalidating executor after work item {} failed", 
workItem.getWorkToken(), t);
+        LOG.debug(
+            "Invalidating executor after work item {} failed",
+            work.getWorkItem().getWorkToken(),
+            t);
         computationWorkExecutor.invalidate();
       }
-
       // Re-throw the exception, it will be caught and handled by 
workFailureProcessor downstream.
       throw t;
     }
   }
 
+  private void handleOnlyFinalize(
+      ComputationState computationState, Work work, Windmill.WorkItem 
workItem) {
+    Windmill.WorkItemCommitRequest.Builder outputBuilder =
+        initializeOutputBuilder(workItem.getKey(), workItem);
+    
outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true));
+    work.setState(Work.State.COMMIT_QUEUED);
+    work.queueCommit(outputBuilder.build(), computationState);
+  }
+
+  private StageInfo getStageInfo(ComputationState computationState) {
+    MapTask mapTask = computationState.getMapTask();
+    return stageInfoMap.computeIfAbsent(
+        mapTask.getStageName(), s -> StageInfo.create(s, 
mapTask.getSystemName()));
+  }
+
+  private void commitWorkBatch(
+      ComputationState computationState,
+      List<Work> workBatch,
+      List<Windmill.WorkItemCommitRequest.Builder> outputBuilders) {
+    Preconditions.checkState(
+        workBatch.size() == 1, "Expected single-key work batch, got: " + 
workBatch.size());
+    commitSingleKeyWork(computationState, workBatch.get(0), 
outputBuilders.get(0));
+  }
+
+  private void commitSingleKeyWork(
+      ComputationState computationState,
+      Work work,
+      Windmill.WorkItemCommitRequest.Builder commitRequestBuilder) {
+    // Validate the commit request, possibly requesting truncation if the 
commitSize is too large.
+    Windmill.WorkItemCommitRequest validatedCommitRequest =
+        validateCommitRequestSize(
+            commitRequestBuilder.build(), computationState.getComputationId(), 
work.getWorkItem());
+    work.setState(Work.State.COMMIT_QUEUED);
+    validatedCommitRequest =
+        validatedCommitRequest
+            .toBuilder()
+            
.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler))
+            .build();
+    work.queueCommit(validatedCommitRequest, computationState);
+  }
+
+  private void recordProcessingTime(
+      StageInfo stageInfo,
+      @Nullable List<Work> worksToCleanup,
+      Work work,
+      long processingStartTimeNanos) {
+    // Update total processing time counters. Updating in finally clause 
ensures that
+    // work items causing exceptions are also accounted in time spent.
+    long processingTimeMsecs =
+        TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - 
processingStartTimeNanos);
+    stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs);
+    if (anyWorkHasTimers(worksToCleanup, work)) {
+      // Attribute all the processing to timers if the work item contains any 
timers.
+      // Tests show that work items rarely contain both timers and message 
bundles. It should
+      // be a fairly close approximation.
+      // Another option: Derive time split between messages and timers based 
on recent totals.
+      // either here or in DFE.
+      stageInfo.timerProcessingMsecs().addValue(processingTimeMsecs);
+    }
+  }
+
+  private static boolean anyWorkHasTimers(@Nullable List<Work> works, Work 
primaryWork) {
+    if (works != null && !works.isEmpty()) {
+      return works.stream().anyMatch(w -> w.getWorkItem().hasTimers());
+    }
+    return primaryWork.getWorkItem().hasTimers();
+  }
+
+  private StreamingModeExecutionContext.KeySwitchListener 
createKeySwitchListener(
+      ComputationState computationState) {
+    return (oldWork, newWork) -> {
+      resetWorkLoggingContext();
+      setUpWorkLoggingContext(newWork.getLatencyTrackingId(), 
computationState.getComputationId());
+      newWork.setProcessingThreadName(Thread.currentThread().getName());
+      oldWork.setProcessingThreadName("");
+    };
+  }
+
   @AutoValue
   abstract static class ExecuteWorkResult {
-
-    private static ExecuteWorkResult create(
-        Windmill.WorkItemCommitRequest.Builder commitWorkRequest, long 
stateBytesRead) {
+    static ExecuteWorkResult create(
+        List<Work> workBatch,
+        List<Windmill.WorkItemCommitRequest.Builder> outputBuilders,
+        Map<Long, Pair<Instant, Runnable>> accumulatedCallbacks,
+        long stateBytesRead) {
       return new AutoValue_StreamingWorkScheduler_ExecuteWorkResult(
-          commitWorkRequest, stateBytesRead);
+          workBatch, outputBuilders, accumulatedCallbacks, stateBytesRead);
     }
 
-    abstract Windmill.WorkItemCommitRequest.Builder commitWorkRequest();
+    abstract List<Work> workBatch();
+
+    abstract List<Windmill.WorkItemCommitRequest.Builder> outputBuilders();
+
+    abstract Map<Long, Pair<Instant, Runnable>> accumulatedCallbacks();

Review Comment:
   Please add comments describing what these are and what the params are.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java:
##########


Review Comment:
   Should we get rid of this and merge it into StreamingModeExecutionContext?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########
@@ -555,7 +666,92 @@ public Map<Long, Pair<Instant, Runnable>> flushState() {
       // RestrictionTracker.getProgress() or GetSize() are not defined.
       outputBuilder.setSourceBacklogBytes(backlogBytes);
     }
-    return callbacks;
+
+    this.accumulatedCallbacks.putAll(callbacks);
+
+    outputBuilder.setSourceBytesProcessed(
+        computeSourceBytesProcessed(sourceBytesProcessCounterName));
+  }
+
+  private final long computeSourceBytesProcessed(String 
sourceBytesCounterName) {
+    if (!(workExecutor instanceof DataflowMapTaskExecutor)) {
+      return 0L;
+    }
+    HashMap<String, ElementCounter> counters =
+        ((DataflowMapTaskExecutor) workExecutor)
+            .getReadOperation()
+            .receivers[0]
+            .getOutputCounters();
+
+    return Optional.ofNullable(counters.get(sourceBytesCounterName))
+        .map(counter -> ((OutputObjectAndByteCounter) 
counter).getByteCount().getAndReset())
+        .orElse(0L);
+  }
+
+  public Map<Long, Pair<Instant, Runnable>> flushState() {
+    return accumulatedCallbacks;
+  }
+
+  public boolean advance() {
+    return false;
+  }
+
+  private void startForNewKey(Work newWork, WindmillStateReader reader) {
+    if (keySwitchListener != null && this.work != null && this.work != 
newWork) {
+      keySwitchListener.onKeySwitch(this.work, newWork);
+    }
+    this.key = decodeKey(newWork);
+    this.work = newWork;
+    this.finishKeyCalled = false;
+    this.computationKey = WindmillComputationKey.create(computationId, 
newWork.getShardedKey());
+
+    this.outputBuilder = createOutputBuilder(newWork);
+    this.outputBuilders.add(this.outputBuilder);
+    newWork.setOnFailureListener(() -> this.workIsFailed = true);
+    this.executedWorks.add(newWork);
+
+    logHotKeyIfDetected(newWork, this.key);
+
+    // Note: We do NOT clear sideInputCache here, allowing Key B to reuse warm 
side inputs!
+
+    // Re-initialize state cache and state/timer internals across all step 
contexts
+    Instant processingTime =
+        
computeProcessingTime(newWork.getWorkItem().getTimers().getTimersList());
+    if (!getAllStepContexts().isEmpty()) {
+      // This must be only created once for a workItem as token validation 
will fail if the same
+      // work token is reused.
+      WindmillStateCache.ForKey cacheForKey =
+          stateCache.forKey(
+              getComputationKey(), newWork.getWorkItem().getCacheToken(), 
getWorkToken());
+      this.activeStateReader = reader;
+      startStepContexts(reader, processingTime, cacheForKey, 
newWork.watermarks());
+    } else {
+      this.activeStateReader = null;
+    }
+  }
+
+  public List<Work> getExecutedWorks() {
+    return executedWorks;
+  }
+
+  public long getStateBytesRead() {
+    return stateBytesRead;
+  }
+
+  public List<Windmill.WorkItemCommitRequest.Builder> getOutputBuilders() {

Review Comment:
   I'm a little concerned about exposing the internal data structures which 
might be modified externally. Can we instead add methods to add to these as 
needed?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java:
##########
@@ -258,37 +247,36 @@ private void processWork(
     // Before any processing starts, call any pending OnCommit callbacks.  
Nothing that requires
     // cleanup should be done before this, since we might exit early here.
     
commitFinalizer.finalizeCommits(workItem.getSourceState().getFinalizeIdsList());
+
     if (workItem.getSourceState().getOnlyFinalize()) {
-      Windmill.WorkItemCommitRequest.Builder outputBuilder = 
initializeOutputBuilder(key, workItem);
-      
outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true));
-      work.setState(Work.State.COMMIT_QUEUED);
-      work.queueCommit(outputBuilder.build(), computationState);
+      handleOnlyFinalize(computationState, work, workItem);
       return;
     }
 
     long processingStartTimeNanos = System.nanoTime();
-    MapTask mapTask = computationState.getMapTask();
-    StageInfo stageInfo =
-        stageInfoMap.computeIfAbsent(
-            mapTask.getStageName(), s -> StageInfo.create(s, 
mapTask.getSystemName()));
+    StageInfo stageInfo = getStageInfo(computationState);
 
+    List<Work> worksToCleanup = null;
     try {
       if (work.isFailed()) {
         throw new WorkItemCancelledException(workItem.getShardingKey());
       }
 
-      // Execute the user code for the Work.
-      ExecuteWorkResult executeWorkResult = executeWork(work, stageInfo, 
computationState);
-      Windmill.WorkItemCommitRequest.Builder commitRequest = 
executeWorkResult.commitWorkRequest();
+      // Execute the user code for the Work batch.
+      ExecuteWorkResult executeWorkResult = executeWork(work, stageInfo, 
computationState, handle);
+      List<Work> workBatch = executeWorkResult.workBatch();
+      worksToCleanup = workBatch;
+      List<Windmill.WorkItemCommitRequest.Builder> outputBuilders =
+          executeWorkResult.outputBuilders();
+      Map<Long, Pair<Instant, Runnable>> accumulatedCallbacks =

Review Comment:
   Just pass the function result directly as the parameter at its single use.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########
@@ -240,54 +280,123 @@ public byte[] getCurrentRecordOffset() {
     return activeReader.getCurrentRecordOffset();
   }
 
+  public void clear() {
+    for (Work w : executedWorks) {
+      w.setOnFailureListener(null);
+    }
+    this.executedWorks = new ArrayList<>();
+    this.outputBuilders = new ArrayList<>();
+    this.accumulatedCallbacks = new HashMap<>();
+    this.workIsFailed = false;
+    this.sideInputCache.clear();
+    this.activeStateReader = null;
+    this.activeReader = null;
+    this.keyCoder = null;
+    this.workExecutor = null;
+    this.workQueueExecutor = null;
+    this.budgetHandle = null;
+    this.keySwitchListener = null;
+    this.work = null;
+    this.key = null;
+    this.outputBuilder = null;
+  }
+
   public void start(
-      @Nullable Object key,
       Work work,
       WindmillStateReader stateReader,
       SideInputStateFetcher sideInputStateFetcher,
-      Windmill.WorkItemCommitRequest.Builder outputBuilder,
-      WorkExecutor workExecutor) {
-    this.key = key;
-    this.work = work;
+      WorkExecutor workExecutor,
+      BoundedQueueExecutor workQueueExecutor,
+      BoundedQueueExecutorWorkHandle budgetHandle,
+      @Nullable Coder<?> keyCoder,
+      KeySwitchListener keySwitchListener) {
+    clear();
+    this.keyCoder = keyCoder;
     this.workExecutor = workExecutor;
-    this.finishKeyCalled = false;
-    this.computationKey = WindmillComputationKey.create(computationId, 
work.getShardedKey());
-    this.sideInputStateFetcher = sideInputStateFetcher;
+    this.workQueueExecutor = workQueueExecutor;
+    this.budgetHandle = budgetHandle;
+    this.keySwitchListener = keySwitchListener;
+
+    this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
+    clearSinkFullHint();
+    this.stateBytesRead = 0;
+
     StreamingGlobalConfig config = globalConfigHandle.getConfig();
     // Snapshot the limits for entire bundle processing.
     this.operationalLimits = config.operationalLimits();
     this.windmillTagEncoding =
         config.enableStateTagEncodingV2()
             ? WindmillTagEncodingV2.instance()
             : WindmillTagEncodingV1.instance();
-    this.outputBuilder = outputBuilder;
-    this.sideInputCache.clear();
-    this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
-    clearSinkFullHint();
+    this.sideInputStateFetcher = sideInputStateFetcher;
 
-    Instant processingTime = 
computeProcessingTime(work.getWorkItem().getTimers().getTimersList());
+    startForNewKey(work, stateReader);
+  }
 
-    Collection<? extends StepContext> stepContexts = getAllStepContexts();
-    if (!stepContexts.isEmpty()) {
-      // This must be only created once for the workItem as token validation 
will fail if the same
-      // work token is reused.
-      WindmillStateCache.ForKey cacheForKey =
-          stateCache.forKey(getComputationKey(), 
getWorkItem().getCacheToken(), getWorkToken());
-      for (StepContext stepContext : stepContexts) {
-        stepContext.start(stateReader, processingTime, cacheForKey, 
work.watermarks());
+  private @Nullable Object decodeKey(Work work) {

Review Comment:
   just noting for myself that this moved from StreamingWorkScheduler so we're 
not decoding more than before



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java:
##########
@@ -354,27 +330,35 @@ private Windmill.WorkItemCommitRequest 
validateCommitRequestSize(
   }
 
   private void recordProcessingStats(
-      Windmill.WorkItemCommitRequest.Builder outputBuilder,
-      Windmill.WorkItem workItem,
-      ExecuteWorkResult executeWorkResult) {
-    // Compute shuffle and state byte statistics these will be flushed 
asynchronously.
-    long stateBytesWritten =
-        outputBuilder
-            .clearOutputMessages()
-            .clearPerWorkItemLatencyAttributions()
-            .build()
-            .getSerializedSize();
-
-    
streamingCounters.windmillShuffleBytesRead().addValue(computeShuffleBytesRead(workItem));
-    
streamingCounters.windmillStateBytesRead().addValue(executeWorkResult.stateBytesRead());
-    streamingCounters.windmillStateBytesWritten().addValue(stateBytesWritten);
+      List<Work> workBatch,
+      List<Windmill.WorkItemCommitRequest.Builder> outputBuilders,
+      long totalStateBytesRead) {
+    long totalStateBytesWritten = 0;
+    long totalShuffleBytesRead = 0;
+    for (int i = 0; i < workBatch.size(); i++) {
+      Windmill.WorkItem workItem = workBatch.get(i).getWorkItem();
+      Windmill.WorkItemCommitRequest.Builder outputBuilder = 
outputBuilders.get(i);

Review Comment:
   checkState that the workBatch size and the outputBuilders size are the same?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java:
##########
@@ -306,22 +294,10 @@ private void processWork(
         throw ExceptionUtils.safeWrapThrowableAsException(t2);
       }
     } finally {
-      // Update total processing time counters. Updating in finally clause 
ensures that
-      // work items causing exceptions are also accounted in time spent.
-      long processingTimeMsecs =
-          TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - 
processingStartTimeNanos);
-      stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs);
-
-      // Attribute all the processing to timers if the work item contains any 
timers.
-      // Tests show that work items rarely contain both timers and message 
bundles. It should
-      // be a fairly close approximation.
-      // Another option: Derive time split between messages and timers based 
on recent totals.
-      // either here or in DFE.
-      if (work.getWorkItem().hasTimers()) {
-        stageInfo.timerProcessingMsecs().addValue(processingTimeMsecs);
-      }
+      recordProcessingTime(stageInfo, worksToCleanup, work, 
processingStartTimeNanos);

Review Comment:
   Keep the comment here about why it is done in the finally clause.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java:
##########
@@ -388,86 +372,143 @@ private ExecuteWorkResult executeWork(
       SideInputStateFetcher localSideInputStateFetcher =
           
sideInputStateFetcherFactory.createSideInputStateFetcher(work::fetchSideInput);
 
-      // If the read output KVs, then we can decode Windmill's byte key into 
userland
-      // key object and provide it to the execution context for use with 
per-key state.
-      // Otherwise, we pass null.
-      //
-      // The coder type that will be present is:
-      //     WindowedValueCoder(TimerOrElementCoder(KvCoder))
-      Optional<Coder<?>> keyCoder = computationWorkExecutor.keyCoder();
-      @SuppressWarnings("deprecation")
-      @Nullable
-      final Object executionKey =
-          !keyCoder.isPresent() ? null : keyCoder.get().decode(key.newInput(), 
Coder.Context.OUTER);
-
-      if (workItem.hasHotKeyInfo()) {
-        Windmill.HotKeyInfo hotKeyInfo = workItem.getHotKeyInfo();
-        Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 
1000);
-
-        String stepName = 
getShuffleTaskStepName(computationState.getMapTask());
-        if (executionKey != null
-            && (options.isHotKeyLoggingEnabled()
-                || hasExperiment(options, "enable_hot_key_logging"))
-            && keyCoder.isPresent()) {
-          hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, executionKey);
-        } else {
-          hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge);
-        }
-      }
+      StreamingModeExecutionContext.KeySwitchListener keySwitchListener =
+          createKeySwitchListener(computationState);
 
       // Blocks while executing work.
       computationWorkExecutor.executeWork(
-          executionKey, work, stateReader, localSideInputStateFetcher, 
outputBuilder);
+          work, stateReader, localSideInputStateFetcher, workExecutor, handle, 
keySwitchListener);
 
-      if (work.isFailed()) {
-        throw new WorkItemCancelledException(workItem.getShardingKey());
+      StreamingModeExecutionContext context = 
computationWorkExecutor.context();
+      if (context.workIsFailed()) {
+        throw new 
WorkItemCancelledException(work.getWorkItem().getShardingKey());
       }
 
-      // Reports source bytes processed to WorkItemCommitRequest if available.
-      try {
-        long sourceBytesProcessed =
-            computationWorkExecutor.computeSourceBytesProcessed(
-                computationState.sourceBytesProcessCounterName());
-        outputBuilder.setSourceBytesProcessed(sourceBytesProcessed);
-      } catch (Exception e) {
-        LOG.error("{}", e.toString());
-      }
-
-      
commitFinalizer.cacheCommitFinalizers(computationWorkExecutor.context().flushState());
+      // Retrieve executed works, output builders, and accumulated callbacks 
from execution context
+      List<Work> workBatch = context.getExecutedWorks();
+      List<Windmill.WorkItemCommitRequest.Builder> outputBuilders = 
context.getOutputBuilders();
+      Map<Long, Pair<Instant, Runnable>> accumulatedCallbacks = 
context.getAccumulatedCallbacks();
 
+      context.clear();
       // Release the execution state for another thread to use.
       computationState.releaseComputationWorkExecutor(computationWorkExecutor);
       computationWorkExecutor = null;
 
-      work.setState(Work.State.COMMIT_QUEUED);
-      
outputBuilder.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler));
-
       return ExecuteWorkResult.create(
-          outputBuilder, stateReader.getBytesRead() + 
localSideInputStateFetcher.getBytesRead());
+          workBatch,
+          outputBuilders,
+          accumulatedCallbacks,
+          context.getStateBytesRead() + 
localSideInputStateFetcher.getBytesRead());
     } catch (Throwable t) {
       if (computationWorkExecutor != null) {
         // If processing failed due to a thrown exception, close the 
executionState. Do not
         // return/release the executionState back to computationState as that 
will lead to this
         // executionState instance being reused.
-        LOG.debug("Invalidating executor after work item {} failed", 
workItem.getWorkToken(), t);
+        LOG.debug(
+            "Invalidating executor after work item {} failed",
+            work.getWorkItem().getWorkToken(),
+            t);
         computationWorkExecutor.invalidate();
       }
-
       // Re-throw the exception, it will be caught and handled by 
workFailureProcessor downstream.
       throw t;
     }
   }
 
+  private void handleOnlyFinalize(
+      ComputationState computationState, Work work, Windmill.WorkItem 
workItem) {
+    Windmill.WorkItemCommitRequest.Builder outputBuilder =
+        initializeOutputBuilder(workItem.getKey(), workItem);
+    
outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true));
+    work.setState(Work.State.COMMIT_QUEUED);
+    work.queueCommit(outputBuilder.build(), computationState);
+  }
+
+  private StageInfo getStageInfo(ComputationState computationState) {
+    MapTask mapTask = computationState.getMapTask();
+    return stageInfoMap.computeIfAbsent(
+        mapTask.getStageName(), s -> StageInfo.create(s, 
mapTask.getSystemName()));
+  }
+
+  private void commitWorkBatch(
+      ComputationState computationState,
+      List<Work> workBatch,
+      List<Windmill.WorkItemCommitRequest.Builder> outputBuilders) {
+    Preconditions.checkState(
+        workBatch.size() == 1, "Expected single-key work batch, got: " + 
workBatch.size());
+    commitSingleKeyWork(computationState, workBatch.get(0), 
outputBuilders.get(0));
+  }
+
+  private void commitSingleKeyWork(
+      ComputationState computationState,
+      Work work,
+      Windmill.WorkItemCommitRequest.Builder commitRequestBuilder) {
+    // Validate the commit request, possibly requesting truncation if the 
commitSize is too large.
+    Windmill.WorkItemCommitRequest validatedCommitRequest =
+        validateCommitRequestSize(
+            commitRequestBuilder.build(), computationState.getComputationId(), 
work.getWorkItem());
+    work.setState(Work.State.COMMIT_QUEUED);
+    validatedCommitRequest =
+        validatedCommitRequest
+            .toBuilder()
+            
.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler))
+            .build();
+    work.queueCommit(validatedCommitRequest, computationState);
+  }
+
+  private void recordProcessingTime(
+      StageInfo stageInfo,
+      @Nullable List<Work> worksToCleanup,
+      Work work,
+      long processingStartTimeNanos) {
+    // Update total processing time counters. Updating in finally clause 
ensures that
+    // work items causing exceptions are also accounted in time spent.
+    long processingTimeMsecs =
+        TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - 
processingStartTimeNanos);
+    stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs);
+    if (anyWorkHasTimers(worksToCleanup, work)) {
+      // Attribute all the processing to timers if the work item contains any 
timers.
+      // Tests show that work items rarely contain both timers and message 
bundles. It should
+      // be a fairly close approximation.
+      // Another option: Derive time split between messages and timers based 
on recent totals.
+      // either here or in DFE.
+      stageInfo.timerProcessingMsecs().addValue(processingTimeMsecs);
+    }
+  }
+
+  private static boolean anyWorkHasTimers(@Nullable List<Work> works, Work 
primaryWork) {
+    if (works != null && !works.isEmpty()) {
+      return works.stream().anyMatch(w -> w.getWorkItem().hasTimers());
+    }
+    return primaryWork.getWorkItem().hasTimers();
+  }
+
+  private StreamingModeExecutionContext.KeySwitchListener 
createKeySwitchListener(
+      ComputationState computationState) {
+    return (oldWork, newWork) -> {
+      resetWorkLoggingContext();
+      setUpWorkLoggingContext(newWork.getLatencyTrackingId(), 
computationState.getComputationId());
+      newWork.setProcessingThreadName(Thread.currentThread().getName());

Review Comment:
   could just take it from oldWork without thread-local state



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java:
##########
@@ -388,86 +372,143 @@ private ExecuteWorkResult executeWork(
       SideInputStateFetcher localSideInputStateFetcher =
           
sideInputStateFetcherFactory.createSideInputStateFetcher(work::fetchSideInput);
 
-      // If the read output KVs, then we can decode Windmill's byte key into 
userland
-      // key object and provide it to the execution context for use with 
per-key state.
-      // Otherwise, we pass null.
-      //
-      // The coder type that will be present is:
-      //     WindowedValueCoder(TimerOrElementCoder(KvCoder))
-      Optional<Coder<?>> keyCoder = computationWorkExecutor.keyCoder();
-      @SuppressWarnings("deprecation")
-      @Nullable
-      final Object executionKey =
-          !keyCoder.isPresent() ? null : keyCoder.get().decode(key.newInput(), 
Coder.Context.OUTER);
-
-      if (workItem.hasHotKeyInfo()) {
-        Windmill.HotKeyInfo hotKeyInfo = workItem.getHotKeyInfo();
-        Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 
1000);
-
-        String stepName = 
getShuffleTaskStepName(computationState.getMapTask());
-        if (executionKey != null
-            && (options.isHotKeyLoggingEnabled()
-                || hasExperiment(options, "enable_hot_key_logging"))
-            && keyCoder.isPresent()) {
-          hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, executionKey);
-        } else {
-          hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge);
-        }
-      }
+      StreamingModeExecutionContext.KeySwitchListener keySwitchListener =
+          createKeySwitchListener(computationState);
 
       // Blocks while executing work.
       computationWorkExecutor.executeWork(
-          executionKey, work, stateReader, localSideInputStateFetcher, 
outputBuilder);
+          work, stateReader, localSideInputStateFetcher, workExecutor, handle, 
keySwitchListener);
 
-      if (work.isFailed()) {
-        throw new WorkItemCancelledException(workItem.getShardingKey());
+      StreamingModeExecutionContext context = 
computationWorkExecutor.context();
+      if (context.workIsFailed()) {
+        throw new 
WorkItemCancelledException(work.getWorkItem().getShardingKey());
       }
 
-      // Reports source bytes processed to WorkItemCommitRequest if available.
-      try {
-        long sourceBytesProcessed =
-            computationWorkExecutor.computeSourceBytesProcessed(
-                computationState.sourceBytesProcessCounterName());
-        outputBuilder.setSourceBytesProcessed(sourceBytesProcessed);
-      } catch (Exception e) {
-        LOG.error("{}", e.toString());
-      }
-
-      
commitFinalizer.cacheCommitFinalizers(computationWorkExecutor.context().flushState());
+      // Retrieve executed works, output builders, and accumulated callbacks 
from execution context
+      List<Work> workBatch = context.getExecutedWorks();
+      List<Windmill.WorkItemCommitRequest.Builder> outputBuilders = 
context.getOutputBuilders();
+      Map<Long, Pair<Instant, Runnable>> accumulatedCallbacks = 
context.getAccumulatedCallbacks();
 
+      context.clear();
       // Release the execution state for another thread to use.
       computationState.releaseComputationWorkExecutor(computationWorkExecutor);
       computationWorkExecutor = null;
 
-      work.setState(Work.State.COMMIT_QUEUED);
-      
outputBuilder.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler));
-
       return ExecuteWorkResult.create(
-          outputBuilder, stateReader.getBytesRead() + 
localSideInputStateFetcher.getBytesRead());
+          workBatch,
+          outputBuilders,
+          accumulatedCallbacks,
+          context.getStateBytesRead() + 
localSideInputStateFetcher.getBytesRead());
     } catch (Throwable t) {
       if (computationWorkExecutor != null) {
         // If processing failed due to a thrown exception, close the 
executionState. Do not
         // return/release the executionState back to computationState as that 
will lead to this
         // executionState instance being reused.
-        LOG.debug("Invalidating executor after work item {} failed", 
workItem.getWorkToken(), t);
+        LOG.debug(
+            "Invalidating executor after work item {} failed",
+            work.getWorkItem().getWorkToken(),
+            t);
         computationWorkExecutor.invalidate();
       }
-
       // Re-throw the exception, it will be caught and handled by 
workFailureProcessor downstream.
       throw t;
     }
   }
 
+  private void handleOnlyFinalize(
+      ComputationState computationState, Work work, Windmill.WorkItem 
workItem) {
+    Windmill.WorkItemCommitRequest.Builder outputBuilder =
+        initializeOutputBuilder(workItem.getKey(), workItem);
+    
outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true));
+    work.setState(Work.State.COMMIT_QUEUED);
+    work.queueCommit(outputBuilder.build(), computationState);
+  }
+
+  private StageInfo getStageInfo(ComputationState computationState) {
+    MapTask mapTask = computationState.getMapTask();
+    return stageInfoMap.computeIfAbsent(
+        mapTask.getStageName(), s -> StageInfo.create(s, 
mapTask.getSystemName()));
+  }
+
+  private void commitWorkBatch(
+      ComputationState computationState,
+      List<Work> workBatch,
+      List<Windmill.WorkItemCommitRequest.Builder> outputBuilders) {
+    Preconditions.checkState(
+        workBatch.size() == 1, "Expected single-key work batch, got: " + 
workBatch.size());
+    commitSingleKeyWork(computationState, workBatch.get(0), 
outputBuilders.get(0));
+  }
+
+  private void commitSingleKeyWork(
+      ComputationState computationState,
+      Work work,
+      Windmill.WorkItemCommitRequest.Builder commitRequestBuilder) {
+    // Validate the commit request, possibly requesting truncation if the 
commitSize is too large.
+    Windmill.WorkItemCommitRequest validatedCommitRequest =
+        validateCommitRequestSize(
+            commitRequestBuilder.build(), computationState.getComputationId(), 
work.getWorkItem());

Review Comment:
   See the other comment about builders being mutated. Could we lift this 
build() call up so that we pass immutable, already-built requests to both the 
commit path and the record-processing-stats path?
   



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java:
##########
@@ -388,86 +372,143 @@ private ExecuteWorkResult executeWork(
       SideInputStateFetcher localSideInputStateFetcher =
           
sideInputStateFetcherFactory.createSideInputStateFetcher(work::fetchSideInput);
 
-      // If the read output KVs, then we can decode Windmill's byte key into 
userland
-      // key object and provide it to the execution context for use with 
per-key state.
-      // Otherwise, we pass null.
-      //
-      // The coder type that will be present is:
-      //     WindowedValueCoder(TimerOrElementCoder(KvCoder))
-      Optional<Coder<?>> keyCoder = computationWorkExecutor.keyCoder();
-      @SuppressWarnings("deprecation")
-      @Nullable
-      final Object executionKey =
-          !keyCoder.isPresent() ? null : keyCoder.get().decode(key.newInput(), 
Coder.Context.OUTER);
-
-      if (workItem.hasHotKeyInfo()) {
-        Windmill.HotKeyInfo hotKeyInfo = workItem.getHotKeyInfo();
-        Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 
1000);
-
-        String stepName = 
getShuffleTaskStepName(computationState.getMapTask());
-        if (executionKey != null
-            && (options.isHotKeyLoggingEnabled()
-                || hasExperiment(options, "enable_hot_key_logging"))
-            && keyCoder.isPresent()) {
-          hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, executionKey);
-        } else {
-          hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge);
-        }
-      }
+      StreamingModeExecutionContext.KeySwitchListener keySwitchListener =
+          createKeySwitchListener(computationState);
 
       // Blocks while executing work.
       computationWorkExecutor.executeWork(
-          executionKey, work, stateReader, localSideInputStateFetcher, 
outputBuilder);
+          work, stateReader, localSideInputStateFetcher, workExecutor, handle, 
keySwitchListener);
 
-      if (work.isFailed()) {
-        throw new WorkItemCancelledException(workItem.getShardingKey());
+      StreamingModeExecutionContext context = 
computationWorkExecutor.context();
+      if (context.workIsFailed()) {
+        throw new 
WorkItemCancelledException(work.getWorkItem().getShardingKey());
       }
 
-      // Reports source bytes processed to WorkItemCommitRequest if available.
-      try {
-        long sourceBytesProcessed =
-            computationWorkExecutor.computeSourceBytesProcessed(
-                computationState.sourceBytesProcessCounterName());
-        outputBuilder.setSourceBytesProcessed(sourceBytesProcessed);
-      } catch (Exception e) {
-        LOG.error("{}", e.toString());
-      }
-
-      
commitFinalizer.cacheCommitFinalizers(computationWorkExecutor.context().flushState());
+      // Retrieve executed works, output builders, and accumulated callbacks 
from execution context
+      List<Work> workBatch = context.getExecutedWorks();
+      List<Windmill.WorkItemCommitRequest.Builder> outputBuilders = 
context.getOutputBuilders();
+      Map<Long, Pair<Instant, Runnable>> accumulatedCallbacks = 
context.getAccumulatedCallbacks();
 
+      context.clear();
       // Release the execution state for another thread to use.
       computationState.releaseComputationWorkExecutor(computationWorkExecutor);
       computationWorkExecutor = null;
 
-      work.setState(Work.State.COMMIT_QUEUED);
-      
outputBuilder.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler));
-
       return ExecuteWorkResult.create(
-          outputBuilder, stateReader.getBytesRead() + 
localSideInputStateFetcher.getBytesRead());
+          workBatch,
+          outputBuilders,
+          accumulatedCallbacks,
+          context.getStateBytesRead() + 
localSideInputStateFetcher.getBytesRead());
     } catch (Throwable t) {
       if (computationWorkExecutor != null) {
         // If processing failed due to a thrown exception, close the 
executionState. Do not
         // return/release the executionState back to computationState as that 
will lead to this
         // executionState instance being reused.
-        LOG.debug("Invalidating executor after work item {} failed", 
workItem.getWorkToken(), t);
+        LOG.debug(
+            "Invalidating executor after work item {} failed",
+            work.getWorkItem().getWorkToken(),
+            t);
         computationWorkExecutor.invalidate();
       }
-
       // Re-throw the exception, it will be caught and handled by 
workFailureProcessor downstream.
       throw t;
     }
   }
 
+  private void handleOnlyFinalize(
+      ComputationState computationState, Work work, Windmill.WorkItem 
workItem) {
+    Windmill.WorkItemCommitRequest.Builder outputBuilder =
+        initializeOutputBuilder(workItem.getKey(), workItem);
+    
outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true));
+    work.setState(Work.State.COMMIT_QUEUED);
+    work.queueCommit(outputBuilder.build(), computationState);
+  }
+
+  private StageInfo getStageInfo(ComputationState computationState) {
+    MapTask mapTask = computationState.getMapTask();
+    return stageInfoMap.computeIfAbsent(
+        mapTask.getStageName(), s -> StageInfo.create(s, 
mapTask.getSystemName()));
+  }
+
+  private void commitWorkBatch(
+      ComputationState computationState,
+      List<Work> workBatch,
+      List<Windmill.WorkItemCommitRequest.Builder> outputBuilders) {
+    Preconditions.checkState(
+        workBatch.size() == 1, "Expected single-key work batch, got: " + 
workBatch.size());
+    commitSingleKeyWork(computationState, workBatch.get(0), 
outputBuilders.get(0));
+  }
+
+  private void commitSingleKeyWork(
+      ComputationState computationState,
+      Work work,
+      Windmill.WorkItemCommitRequest.Builder commitRequestBuilder) {
+    // Validate the commit request, possibly requesting truncation if the 
commitSize is too large.
+    Windmill.WorkItemCommitRequest validatedCommitRequest =
+        validateCommitRequestSize(
+            commitRequestBuilder.build(), computationState.getComputationId(), 
work.getWorkItem());
+    work.setState(Work.State.COMMIT_QUEUED);
+    validatedCommitRequest =
+        validatedCommitRequest
+            .toBuilder()
+            
.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler))
+            .build();
+    work.queueCommit(validatedCommitRequest, computationState);
+  }
+
+  private void recordProcessingTime(
+      StageInfo stageInfo,
+      @Nullable List<Work> worksToCleanup,
+      Work work,
+      long processingStartTimeNanos) {
+    // Update total processing time counters. Updating in finally clause 
ensures that
+    // work items causing exceptions are also accounted in time spent.
+    long processingTimeMsecs =
+        TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - 
processingStartTimeNanos);
+    stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs);
+    if (anyWorkHasTimers(worksToCleanup, work)) {
+      // Attribute all the processing to timers if the work item contains any 
timers.
+      // Tests show that work items rarely contain both timers and message 
bundles. It should
+      // be a fairly close approximation.
+      // Another option: Derive time split between messages and timers based 
on recent totals.
+      // either here or in DFE.
+      stageInfo.timerProcessingMsecs().addValue(processingTimeMsecs);
+    }
+  }
+
+  private static boolean anyWorkHasTimers(@Nullable List<Work> works, Work 
primaryWork) {
+    if (works != null && !works.isEmpty()) {
+      return works.stream().anyMatch(w -> w.getWorkItem().hasTimers());
+    }
+    return primaryWork.getWorkItem().hasTimers();
+  }
+
+  private StreamingModeExecutionContext.KeySwitchListener 
createKeySwitchListener(
+      ComputationState computationState) {
+    return (oldWork, newWork) -> {
+      resetWorkLoggingContext();

Review Comment:
   Can we specialize what we have to migrate for the MDC? This reset doesn't 
seem to have anything that needs to change for a key, and setup doesn't seem 
like it needs to change the computation.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########
@@ -162,6 +167,33 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
   private @Nullable WorkExecutor workExecutor;
   private boolean finishKeyCalled = false;
 
+  @SuppressWarnings("UnusedVariable")
+  private @Nullable BoundedQueueExecutor workQueueExecutor;
+
+  @SuppressWarnings("UnusedVariable")
+  private @Nullable BoundedQueueExecutorWorkHandle budgetHandle;
+
+  private final HotKeyLogger hotKeyLogger;
+  private boolean hotKeyLoggingEnabled = false;
+  private final String stepName;
+  private @Nullable Coder<?> keyCoder;
+
+  // Key switch listener to delegate MDC logging context and thread name 
updates
+  public interface KeySwitchListener {
+    void onKeySwitch(Work oldWork, Work newWork);

Review Comment:
   How about KeyTransitionListener, etc.? To me, "Switch" implies that 
something can move backwards as well.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########
@@ -555,7 +666,92 @@ public Map<Long, Pair<Instant, Runnable>> flushState() {
       // RestrictionTracker.getProgress() or GetSize() are not defined.
       outputBuilder.setSourceBacklogBytes(backlogBytes);
     }
-    return callbacks;
+
+    this.accumulatedCallbacks.putAll(callbacks);
+
+    outputBuilder.setSourceBytesProcessed(
+        computeSourceBytesProcessed(sourceBytesProcessCounterName));
+  }
+
+  private final long computeSourceBytesProcessed(String 
sourceBytesCounterName) {
+    if (!(workExecutor instanceof DataflowMapTaskExecutor)) {
+      return 0L;
+    }
+    HashMap<String, ElementCounter> counters =
+        ((DataflowMapTaskExecutor) workExecutor)
+            .getReadOperation()
+            .receivers[0]
+            .getOutputCounters();
+
+    return Optional.ofNullable(counters.get(sourceBytesCounterName))
+        .map(counter -> ((OutputObjectAndByteCounter) 
counter).getByteCount().getAndReset())
+        .orElse(0L);
+  }
+
+  public Map<Long, Pair<Instant, Runnable>> flushState() {
+    return accumulatedCallbacks;
+  }
+
+  public boolean advance() {
+    return false;
+  }
+
+  private void startForNewKey(Work newWork, WindmillStateReader reader) {
+    if (keySwitchListener != null && this.work != null && this.work != 
newWork) {
+      keySwitchListener.onKeySwitch(this.work, newWork);
+    }
+    this.key = decodeKey(newWork);
+    this.work = newWork;
+    this.finishKeyCalled = false;
+    this.computationKey = WindmillComputationKey.create(computationId, 
newWork.getShardedKey());
+
+    this.outputBuilder = createOutputBuilder(newWork);
+    this.outputBuilders.add(this.outputBuilder);
+    newWork.setOnFailureListener(() -> this.workIsFailed = true);
+    this.executedWorks.add(newWork);
+
+    logHotKeyIfDetected(newWork, this.key);
+
+    // Note: We do NOT clear sideInputCache here, allowing Key B to reuse warm 
side inputs!
+
+    // Re-initialize state cache and state/timer internals across all step 
contexts
+    Instant processingTime =
+        
computeProcessingTime(newWork.getWorkItem().getTimers().getTimersList());
+    if (!getAllStepContexts().isEmpty()) {
+      // This must be only created once for a workItem as token validation 
will fail if the same

Review Comment:
   Not sure what this means; I don't see any guard against recreating it. I'm 
wondering if there are cases (retries, etc.) where we could get the same key 
and possibly the same work token. Should we guard against that somehow?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########
@@ -555,7 +666,92 @@ public Map<Long, Pair<Instant, Runnable>> flushState() {
       // RestrictionTracker.getProgress() or GetSize() are not defined.
       outputBuilder.setSourceBacklogBytes(backlogBytes);
     }
-    return callbacks;
+
+    this.accumulatedCallbacks.putAll(callbacks);
+
+    outputBuilder.setSourceBytesProcessed(
+        computeSourceBytesProcessed(sourceBytesProcessCounterName));
+  }
+
+  private final long computeSourceBytesProcessed(String 
sourceBytesCounterName) {
+    if (!(workExecutor instanceof DataflowMapTaskExecutor)) {
+      return 0L;
+    }
+    HashMap<String, ElementCounter> counters =
+        ((DataflowMapTaskExecutor) workExecutor)
+            .getReadOperation()
+            .receivers[0]
+            .getOutputCounters();
+
+    return Optional.ofNullable(counters.get(sourceBytesCounterName))
+        .map(counter -> ((OutputObjectAndByteCounter) 
counter).getByteCount().getAndReset())
+        .orElse(0L);
+  }
+
+  public Map<Long, Pair<Instant, Runnable>> flushState() {
+    return accumulatedCallbacks;
+  }
+
+  public boolean advance() {
+    return false;
+  }
+
+  private void startForNewKey(Work newWork, WindmillStateReader reader) {
+    if (keySwitchListener != null && this.work != null && this.work != 
newWork) {
+      keySwitchListener.onKeySwitch(this.work, newWork);
+    }
+    this.key = decodeKey(newWork);
+    this.work = newWork;
+    this.finishKeyCalled = false;
+    this.computationKey = WindmillComputationKey.create(computationId, 
newWork.getShardedKey());
+
+    this.outputBuilder = createOutputBuilder(newWork);
+    this.outputBuilders.add(this.outputBuilder);
+    newWork.setOnFailureListener(() -> this.workIsFailed = true);

Review Comment:
   If this is the only failure listener, what about some simpler support for 
it? We could pass an AtomicBoolean (or a wrapper around one that is one-way 
false->true). The individual works could just store that instead of their own 
boolean, so that the work fails as a group.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java:
##########


Review Comment:
   Can we remove this? It is helpful to get static checking when making larger 
changes.
   
   We could do it in a separate PR first if you want.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java:
##########
@@ -57,22 +57,35 @@ public boolean start() throws IOException {
   @Override
   public boolean advance() throws IOException {
     if (context.workIsFailed()) {
-      throw new 
WorkItemCancelledException(context.getWorkItem().getShardingKey());
+      throw new 
WorkItemCancelledException(checkNotNull(context.getWorkItem()).getShardingKey());
     }
 
     while (true) {
       if (bundleIndex >= work.getMessageBundlesCount()) {
-        current = null;
+        // If elements are exhausted, try advancing the execution context to 
the next key in the
+        // group
         context.finishKey();
+        if (context.advance()) {
+          // Transition succeeded! Update iterator references to the new work 
item
+          this.work = context.getWork().getWorkItem();
+          this.bundleIndex = 0;
+          this.messageIndex = -1;
+          continue;
+        }
+
+        // All work items are exhausted. Iterator returns false.

Review Comment:
   Please remove the second sentence of this comment.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java:
##########
@@ -151,51 +151,65 @@ public 
NativeReaderIterator<WindowedValue<KeyedWorkItem<K, T>>> iterator() throw
             && Iterables.isEmpty(keyedWorkItem.elementsIterable()));
     final WindowedValue<KeyedWorkItem<K, T>> value = new 
ValueInEmptyWindows<>(keyedWorkItem);
 
-    // Return a noop iterator when current workitem is an empty workitem.
-    if (isEmptyWorkItem) {
-      return new NativeReaderIterator<WindowedValue<KeyedWorkItem<K, T>>>() {
-        @Override
-        public boolean start() throws IOException {
-          context.finishKey();
-          return false;
+    return new NativeReaderIterator<WindowedValue<KeyedWorkItem<K, T>>>() {
+      private @Nullable WindowedValue<KeyedWorkItem<K, T>> current = null;
+      private boolean started = false;
+
+      @Override
+      public boolean start() throws IOException {
+        if (context.workIsFailed()) {
+          throw new WorkItemCancelledException(
+              checkStateNotNull(context.getWorkItem()).getShardingKey());
         }
-
-        @Override
-        public boolean advance() throws IOException {
+        if (started) {
           return false;
         }
-
-        @Override
-        public WindowedValue<KeyedWorkItem<K, T>> getCurrent() {
-          throw new NoSuchElementException();
+        started = true;
+        if (isEmptyWorkItem) {
+          return advance(); // Try to transition immediately if the first key 
is empty!
         }
-      };
-    } else {
-      return new NativeReaderIterator<WindowedValue<KeyedWorkItem<K, T>>>() {
-        private @Nullable WindowedValue<KeyedWorkItem<K, T>> current = null;
-
-        @Override
-        public boolean start() throws IOException {
-          current = value;
-          return true;
+        current = value;
+        return true;
+      }
+
+      @Override
+      public boolean advance() throws IOException {
+        if (context.workIsFailed()) {
+          throw new WorkItemCancelledException(
+              checkStateNotNull(context.getWorkItem()).getShardingKey());
         }
 
-        @Override
-        public boolean advance() throws IOException {
-          current = null;
-          context.finishKey();
-          return false;
+        context.finishKey();
+        if (context.advance()) {
+          @SuppressWarnings("unchecked")
+          K newKey = (K) context.getKey();

Review Comment:
   Should this be a nullable K?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java:
##########
@@ -354,27 +330,35 @@ private Windmill.WorkItemCommitRequest 
validateCommitRequestSize(
   }
 
   private void recordProcessingStats(
-      Windmill.WorkItemCommitRequest.Builder outputBuilder,
-      Windmill.WorkItem workItem,
-      ExecuteWorkResult executeWorkResult) {
-    // Compute shuffle and state byte statistics these will be flushed 
asynchronously.
-    long stateBytesWritten =
-        outputBuilder
-            .clearOutputMessages()
-            .clearPerWorkItemLatencyAttributions()
-            .build()
-            .getSerializedSize();
-
-    
streamingCounters.windmillShuffleBytesRead().addValue(computeShuffleBytesRead(workItem));
-    
streamingCounters.windmillStateBytesRead().addValue(executeWorkResult.stateBytesRead());
-    streamingCounters.windmillStateBytesWritten().addValue(stateBytesWritten);
+      List<Work> workBatch,
+      List<Windmill.WorkItemCommitRequest.Builder> outputBuilders,
+      long totalStateBytesRead) {
+    long totalStateBytesWritten = 0;
+    long totalShuffleBytesRead = 0;
+    for (int i = 0; i < workBatch.size(); i++) {
+      Windmill.WorkItem workItem = workBatch.get(i).getWorkItem();
+      Windmill.WorkItemCommitRequest.Builder outputBuilder = 
outputBuilders.get(i);
+      // Compute shuffle and state byte statistics these will be flushed 
asynchronously.
+      long stateBytesWritten =
+          outputBuilder
+              .clearOutputMessages()
+              .clearPerWorkItemLatencyAttributions()
+              .build()
+              .getSerializedSize();

Review Comment:
   This does seem possibly dangerous. Also, it seems possibly slow unless the 
serialized size is cached by the Java implementation.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to