This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 3732cffb33c [Dataflow Streaming] Fix nullness suppression in 
StreamingModeExecutionContext (#38842)
3732cffb33c is described below

commit 3732cffb33c856905174c3d0c444d093b075d301
Author: Arun Pandian <[email protected]>
AuthorDate: Mon Jun 8 13:01:24 2026 -0700

    [Dataflow Streaming] Fix nullness suppression in 
StreamingModeExecutionContext (#38842)
---
 .../worker/StreamingModeExecutionContext.java      | 271 +++++++++++----------
 .../worker/StreamingModeExecutionContextTest.java  |  74 +++---
 2 files changed, 178 insertions(+), 167 deletions(-)

diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
index 25ce299adf7..00fdf67b8d0 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
@@ -18,7 +18,6 @@
 package org.apache.beam.runners.dataflow.worker;
 
 import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
-import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
 import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
 
 import com.google.api.services.dataflow.model.CounterUpdate;
@@ -62,6 +61,7 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataId;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.Timer;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
 import 
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache.ForComputation;
 import 
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateInternals;
@@ -94,6 +94,7 @@ import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterat
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.PeekingIterator;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table;
+import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
@@ -105,11 +106,7 @@ import org.slf4j.LoggerFactory;
  * state pertaining to a processing its owning computation. Can be reused 
across processing
  * different WorkItems for the same computation.
  */
-@SuppressWarnings({
-  "deprecation",
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-// TODO(m-trieu) fix nullability issues in StreamingModeExecutionContext.java
+@SuppressWarnings({"deprecation"})
 @NotThreadSafe
 @Internal
 public class StreamingModeExecutionContext extends 
DataflowExecutionContext<StepContext> {
@@ -130,7 +127,7 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
    */
   private final Map<TupleTag<?>, Map<BoundedWindow, SideInput<?>>> 
sideInputCache;
 
-  private WindmillTagEncoding windmillTagEncoding;
+  private final WindmillTagEncoding windmillTagEncoding;
   /**
    * The current user-facing key for this execution context.
    *
@@ -143,13 +140,13 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
   private @Nullable Object key = null;
 
   private @Nullable Work work;
-  private WindmillComputationKey computationKey;
-  private SideInputStateFetcher sideInputStateFetcher;
+  private @Nullable WindmillComputationKey computationKey;
+  private @Nullable SideInputStateFetcher sideInputStateFetcher;
   // OperationalLimits is updated in start() because a 
StreamingModeExecutionContext can
   // be used for processing many work items and these values can change during 
the context's
   // lifetime. start() is called for each work item.
   private OperationalLimits operationalLimits;
-  private Windmill.WorkItemCommitRequest.Builder outputBuilder;
+  private Windmill.WorkItemCommitRequest.@Nullable Builder outputBuilder;
 
   /**
    * Current reader used for processing {@link Work}. Set by calling {@link
@@ -188,6 +185,12 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
     this.stateCache = stateCache;
     this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
     this.throwExceptionOnLargeOutput = throwExceptionOnLargeOutput;
+    StreamingGlobalConfig config = globalConfigHandle.getConfig();
+    this.operationalLimits = config.operationalLimits();
+    this.windmillTagEncoding =
+        config.enableStateTagEncodingV2()
+            ? WindmillTagEncodingV2.instance()
+            : WindmillTagEncodingV1.instance();
   }
 
   @VisibleForTesting
@@ -229,7 +232,7 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
       throw new RuntimeException(
           "Unexpected getCurrentRecordId() while offset-based deduplication is 
not enabled.");
     }
-    return activeReader.getCurrentRecordId();
+    return checkStateNotNull(activeReader).getCurrentRecordId();
   }
 
   public byte[] getCurrentRecordOffset() {
@@ -237,7 +240,7 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
       throw new RuntimeException(
           "Unexpected getCurrentRecordOffset() while offset-based 
deduplication is not enabled.");
     }
-    return activeReader.getCurrentRecordOffset();
+    return checkStateNotNull(activeReader).getCurrentRecordOffset();
   }
 
   public void start(
@@ -256,10 +259,6 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
     StreamingGlobalConfig config = globalConfigHandle.getConfig();
     // Snapshot the limits for entire bundle processing.
     this.operationalLimits = config.operationalLimits();
-    this.windmillTagEncoding =
-        config.enableStateTagEncodingV2()
-            ? WindmillTagEncodingV2.instance()
-            : WindmillTagEncodingV1.instance();
     this.outputBuilder = outputBuilder;
     this.sideInputCache.clear();
     this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
@@ -369,9 +368,9 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
     return fetchSideInputFromWindmill(
         view,
         sideInputWindow,
-        checkNotNull(stateFamily),
+        checkStateNotNull(stateFamily),
         state,
-        checkNotNull(scopedReadStateSupplier),
+        checkStateNotNull(scopedReadStateSupplier),
         tagCache);
   }
 
@@ -383,8 +382,8 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
       Supplier<Closeable> scopedReadStateSupplier,
       Map<BoundedWindow, SideInput<?>> tagCache) {
     SideInput<T> fetched =
-        sideInputStateFetcher.fetchSideInput(
-            view, sideInputWindow, stateFamily, state, 
scopedReadStateSupplier);
+        checkStateNotNull(sideInputStateFetcher)
+            .fetchSideInput(view, sideInputWindow, stateFamily, state, 
scopedReadStateSupplier);
 
     if (fetched.isReady()) {
       tagCache.put(sideInputWindow, fetched);
@@ -406,7 +405,7 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
   }
 
   public WindmillComputationKey getComputationKey() {
-    return computationKey;
+    return checkStateNotNull(computationKey);
   }
 
   public long getWorkToken() {
@@ -414,7 +413,7 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
   }
 
   public Windmill.WorkItem getWorkItem() {
-    return checkNotNull(
+    return checkStateNotNull(
             work,
             "work is null. A call to StreamingModeExecutionContext.start(...) 
is required to set"
                 + " work for execution.")
@@ -422,7 +421,7 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
   }
 
   public Windmill.WorkItemCommitRequest.Builder getOutputBuilder() {
-    return outputBuilder;
+    return checkStateNotNull(outputBuilder);
   }
 
   /**
@@ -490,15 +489,16 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
                     throw new RuntimeException("Exception while running bundle 
finalizer", e);
                   }
                 }));
-        outputBuilder.addFinalizeIds(id);
+        getOutputBuilder().addFinalizeIds(id);
       }
     }
 
-    if (activeReader != null) {
-      Windmill.SourceState.Builder sourceStateBuilder =
-          outputBuilder.getSourceStateUpdatesBuilder();
-      final UnboundedSource.CheckpointMark checkpointMark = 
activeReader.getCheckpointMark();
-      final Instant watermark = activeReader.getWatermark();
+    UnboundedReader<?> reader = activeReader;
+    if (reader != null) {
+      Windmill.WorkItemCommitRequest.Builder builder = getOutputBuilder();
+      Windmill.SourceState.Builder sourceStateBuilder = 
builder.getSourceStateUpdatesBuilder();
+      final UnboundedSource.CheckpointMark checkpointMark = 
reader.getCheckpointMark();
+      final Instant watermark = reader.getWatermark();
       long id = ThreadLocalRandom.current().nextLong();
       sourceStateBuilder.addFinalizeIds(id);
       callbacks.put(
@@ -515,7 +515,7 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
 
       @SuppressWarnings("unchecked")
       Coder<UnboundedSource.CheckpointMark> checkpointCoder =
-          ((UnboundedSource<?, UnboundedSource.CheckpointMark>) 
activeReader.getCurrentSource())
+          ((UnboundedSource<?, UnboundedSource.CheckpointMark>) 
reader.getCurrentSource())
               .getCheckpointMarkCoder();
       if (checkpointCoder != null) {
         ByteStringOutputStream stream = new ByteStringOutputStream();
@@ -525,7 +525,7 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
           throw new RuntimeException("Exception while encoding checkpoint", e);
         }
         sourceStateBuilder.setState(stream.toByteString());
-        if 
(activeReader.getCurrentSource().offsetBasedDeduplicationSupported()) {
+        if (reader.getCurrentSource().offsetBasedDeduplicationSupported()) {
           byte[] offsetLimit = checkpointMark.getOffsetLimit();
           if (offsetLimit.length == 0) {
             throw new RuntimeException("Checkpoint offset limit must be 
non-empty.");
@@ -533,31 +533,30 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
           sourceStateBuilder.setOffsetLimit(ByteString.copyFrom(offsetLimit));
         }
       }
-      
outputBuilder.setSourceWatermark(WindmillTimeUtils.harnessToWindmillTimestamp(watermark));
+      
builder.setSourceWatermark(WindmillTimeUtils.harnessToWindmillTimestamp(watermark));
 
-      backlogBytes = activeReader.getSplitBacklogBytes();
+      backlogBytes = reader.getSplitBacklogBytes();
+      ByteString serializedKey = checkStateNotNull(getSerializedKey());
       if (backlogBytes == UnboundedReader.BACKLOG_UNKNOWN
-          && 
WorkerCustomSources.isFirstUnboundedSourceSplit(getSerializedKey())) {
+          && WorkerCustomSources.isFirstUnboundedSourceSplit(serializedKey)) {
         // Only call getTotalBacklogBytes() on the first split.
-        backlogBytes = activeReader.getTotalBacklogBytes();
+        backlogBytes = reader.getTotalBacklogBytes();
       }
-      outputBuilder.setSourceBacklogBytes(backlogBytes);
+      builder.setSourceBacklogBytes(backlogBytes);
 
       readerCache.cacheReader(
-          getComputationKey(),
-          getWorkItem().getCacheToken(),
-          getWorkItem().getWorkToken(),
-          activeReader);
+          getComputationKey(), getWorkItem().getCacheToken(), 
getWorkItem().getWorkToken(), reader);
       activeReader = null;
     } else if (backlogBytes != UnboundedReader.BACKLOG_UNKNOWN && backlogBytes 
!= 1L) {
       // If activeReader is null, we might still have backlogBytes from an 
SDF. We ignore a reported
       // backlogBytes of 1 since older versions of the Java SDK use this value 
as a default when
       // RestrictionTracker.getProgress() or GetSize() are not defined.
-      outputBuilder.setSourceBacklogBytes(backlogBytes);
+      getOutputBuilder().setSourceBacklogBytes(backlogBytes);
     }
     return callbacks;
   }
 
+  @Nullable
   String getStateFamily(NameContext nameContext) {
     return nameContext.userName() == null ? null : 
stateNameMap.get(nameContext.userName());
   }
@@ -599,7 +598,7 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
     public StreamingModeExecutionState(
         NameContext nameContext,
         String stateName,
-        MetricsContainer metricsContainer,
+        @Nullable MetricsContainer metricsContainer,
         ProfileScope profileScope) {
       // TODO: Take in the requesting step name and side input index for 
streaming.
       super(nameContext, stateName, null, null, metricsContainer, 
profileScope);
@@ -642,14 +641,16 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
     protected DataflowExecutionState createState(
         NameContext nameContext,
         String stateName,
-        String requestingStepName,
-        Integer inputIndex,
-        MetricsContainer container,
+        @Nullable String requestingStepName,
+        @Nullable Integer inputIndex,
+        @Nullable MetricsContainer container,
         ProfileScope profileScope) {
       return new StreamingModeExecutionState(nameContext, stateName, 
container, profileScope);
     }
   }
 
+  private static final Closeable NO_OP_CLOSEABLE = () -> {};
+
   private static class ScopedReadStateSupplier implements Supplier<Closeable> {
 
     private final ExecutionState readState;
@@ -662,9 +663,9 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
     }
 
     @Override
-    public @Nullable Closeable get() {
+    public Closeable get() {
       if (stateTracker == null) {
-        return null;
+        return NO_OP_CLOSEABLE;
       }
       return stateTracker.enterState(readState);
     }
@@ -725,7 +726,7 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
     }
 
     @Override
-    public <W extends BoundedWindow> TimerData getNextFiredTimer(Coder<W> 
windowCoder) {
+    public <W extends BoundedWindow> @Nullable TimerData 
getNextFiredTimer(Coder<W> windowCoder) {
       return wrapped.getNextFiredUserTimer(windowCoder);
     }
 
@@ -777,7 +778,7 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
     }
 
     @Override
-    public <T> T get(PCollectionView<T> view, BoundedWindow window) {
+    public <T> @Nullable T get(PCollectionView<T> view, BoundedWindow window) {
       if (!contains(view)) {
         throw new RuntimeException("get() called with unknown view");
       }
@@ -810,31 +811,32 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
   class StepContext extends DataflowExecutionContext.DataflowStepContext
       implements StreamingModeStepContext {
 
-    private final String stateFamily;
+    private final @Nullable String stateFamily;
     private final Supplier<Closeable> scopedReadStateSupplier;
-    private WindmillStateInternals<Object> stateInternals;
-    private WindmillTimerInternals systemTimerInternals;
-    private WindmillTimerInternals userTimerInternals;
+    private @MonotonicNonNull WindmillStateInternals<Object> stateInternals;
+    private @MonotonicNonNull WindmillTimerInternals systemTimerInternals;
+    private @MonotonicNonNull WindmillTimerInternals userTimerInternals;
     // Lazily initialized
-    private Iterator<TimerData> cachedFiredSystemTimers = null;
+    private @Nullable Iterator<TimerData> cachedFiredSystemTimers = null;
     // Lazily initialized
-    private PeekingIterator<TimerData> cachedFiredUserTimers = null;
+    private @Nullable PeekingIterator<TimerData> cachedFiredUserTimers = null;
     // An ordered list of any timers that were set or modified by user 
processing earlier in this
     // bundle.
     // We use a NavigableSet instead of a priority queue to prevent duplicate 
elements from ending
     // up in the queue.
-    private NavigableSet<TimerData> modifiedUserEventTimersOrdered = null;
-    private NavigableSet<TimerData> modifiedUserProcessingTimersOrdered = null;
-    private NavigableSet<TimerData> 
modifiedUserSynchronizedProcessingTimersOrdered = null;
+    private final NavigableSet<TimerData> modifiedUserEventTimersOrdered = 
Sets.newTreeSet();
+    private final NavigableSet<TimerData> modifiedUserProcessingTimersOrdered 
= Sets.newTreeSet();
+    private final NavigableSet<TimerData> 
modifiedUserSynchronizedProcessingTimersOrdered =
+        Sets.newTreeSet();
     // A list of timer keys that were modified by user processing earlier in 
this bundle. This
     // serves a tombstone, so that we know not to fire any bundle timers that 
were modified.
-    private Table<String, StateNamespace, TimerData> modifiedUserTimerKeys = 
null;
+    private final Table<String, StateNamespace, TimerData> 
modifiedUserTimerKeys =
+        HashBasedTable.create();
     private final WindmillBundleFinalizer bundleFinalizer = new 
WindmillBundleFinalizer();
 
     public StepContext(DataflowOperationContext operationContext) {
       super(operationContext.nameContext());
       this.stateFamily = getStateFamily(operationContext.nameContext());
-
       this.scopedReadStateSupplier =
           new ScopedReadStateSupplier(operationContext, 
getExecutionStateTracker());
     }
@@ -845,46 +847,50 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
         Instant processingTime,
         WindmillStateCache.ForKey cacheForKey,
         Watermarks watermarks) {
-      this.stateInternals =
-          new WindmillStateInternals<>(
-              key,
-              stateFamily,
-              stateReader,
-              getWorkItem().getIsNewKey(),
-              cacheForKey.forFamily(stateFamily),
-              windmillTagEncoding,
-              scopedReadStateSupplier);
-
-      this.systemTimerInternals =
-          new WindmillTimerInternals(
-              stateFamily,
-              WindmillTimerType.SYSTEM_TIMER,
-              processingTime,
-              watermarks,
-              windmillTagEncoding,
-              td -> {});
-
-      this.userTimerInternals =
-          new WindmillTimerInternals(
-              stateFamily,
-              WindmillTimerType.USER_TIMER,
-              processingTime,
-              watermarks,
-              windmillTagEncoding,
-              this::onUserTimerModified);
-
+      if (stateFamily != null) {
+        this.stateInternals =
+            new WindmillStateInternals<>(
+                key,
+                stateFamily,
+                stateReader,
+                getWorkItem().getIsNewKey(),
+                cacheForKey.forFamily(stateFamily),
+                windmillTagEncoding,
+                scopedReadStateSupplier);
+
+        this.systemTimerInternals =
+            new WindmillTimerInternals(
+                stateFamily,
+                WindmillTimerType.SYSTEM_TIMER,
+                processingTime,
+                watermarks,
+                windmillTagEncoding,
+                td -> {});
+
+        this.userTimerInternals =
+            new WindmillTimerInternals(
+                stateFamily,
+                WindmillTimerType.USER_TIMER,
+                processingTime,
+                watermarks,
+                windmillTagEncoding,
+                this::onUserTimerModified);
+      }
       this.cachedFiredSystemTimers = null;
       this.cachedFiredUserTimers = null;
-      modifiedUserEventTimersOrdered = Sets.newTreeSet();
-      modifiedUserProcessingTimersOrdered = Sets.newTreeSet();
-      modifiedUserSynchronizedProcessingTimersOrdered = Sets.newTreeSet();
-      modifiedUserTimerKeys = HashBasedTable.create();
+      this.modifiedUserEventTimersOrdered.clear();
+      this.modifiedUserProcessingTimersOrdered.clear();
+      this.modifiedUserSynchronizedProcessingTimersOrdered.clear();
+      this.modifiedUserTimerKeys.clear();
     }
 
     public void flushState() {
-      stateInternals.persist(outputBuilder);
-      systemTimerInternals.persistTo(outputBuilder);
-      userTimerInternals.persistTo(outputBuilder);
+      if (stateFamily != null) {
+        WorkItemCommitRequest.Builder builder = getOutputBuilder();
+        checkStateNotNull(stateInternals).persist(builder);
+        checkStateNotNull(systemTimerInternals).persistTo(builder);
+        checkStateNotNull(userTimerInternals).persistTo(builder);
+      }
     }
 
     @Override
@@ -893,9 +899,14 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
     }
 
     @Override
-    public <W extends BoundedWindow> TimerData getNextFiredTimer(Coder<W> 
windowCoder) {
-      if (cachedFiredSystemTimers == null) {
-        cachedFiredSystemTimers =
+    public <W extends BoundedWindow> @Nullable TimerData 
getNextFiredTimer(Coder<W> windowCoder) {
+      if (stateFamily == null) {
+        // no timers on stateless stages
+        return null;
+      }
+      Iterator<TimerData> firedSystemTimers = cachedFiredSystemTimers;
+      if (firedSystemTimers == null) {
+        firedSystemTimers =
             
FluentIterable.from(StreamingModeExecutionContext.this.getFiredTimers())
                 .filter(timer -> timer.getStateFamily().equals(stateFamily))
                 .transform(
@@ -907,16 +918,17 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
                         windmillTimerData.getWindmillTimerType() == 
WindmillTimerType.SYSTEM_TIMER)
                 .transform(WindmillTimerData::getTimerData)
                 .iterator();
+        cachedFiredSystemTimers = firedSystemTimers;
       }
 
-      if (!cachedFiredSystemTimers.hasNext()) {
+      if (!firedSystemTimers.hasNext()) {
         return null;
       }
-      TimerData nextTimer = cachedFiredSystemTimers.next();
+      TimerData nextTimer = firedSystemTimers.next();
       // system timers ( GC timer) must be explicitly deleted if only there is 
a hold.
       // if timestamp is not equals to outputTimestamp then there should be a 
hold
       if (!nextTimer.getTimestamp().equals(nextTimer.getOutputTimestamp())) {
-        systemTimerInternals.deleteTimer(nextTimer);
+        checkStateNotNull(systemTimerInternals).deleteTimer(nextTimer);
       }
       return nextTimer;
     }
@@ -950,12 +962,19 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
       return updatedTimer == null || updatedTimer.equals(timerData);
     }
 
-    public <W extends BoundedWindow> TimerData getNextFiredUserTimer(Coder<W> 
windowCoder) {
-      if (cachedFiredUserTimers == null) {
+    public <W extends BoundedWindow> @Nullable TimerData getNextFiredUserTimer(
+        Coder<W> windowCoder) {
+      if (stateFamily == null) {
+        // no timers on stateless stages
+        return null;
+      }
+
+      PeekingIterator<TimerData> firedUserTimers = cachedFiredUserTimers;
+      if (firedUserTimers == null) {
         // This is the first call to getNextFiredUserTimer in this bundle. 
Extract any user timers
         // from the bundle
         // and cache the list for the rest of this bundle processing.
-        cachedFiredUserTimers =
+        firedUserTimers =
             Iterators.peekingIterator(
                 
FluentIterable.from(StreamingModeExecutionContext.this.getFiredTimers())
                     .filter(timer -> 
timer.getStateFamily().equals(stateFamily))
@@ -969,17 +988,20 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
                                 == WindmillTimerType.USER_TIMER)
                     .transform(WindmillTimerData::getTimerData)
                     .iterator());
+        cachedFiredUserTimers = firedUserTimers;
       }
 
-      while (cachedFiredUserTimers.hasNext()) {
-        TimerData nextInBundle = cachedFiredUserTimers.peek();
+      WindmillTimerInternals nonNullUserTimerInternals = 
checkStateNotNull(this.userTimerInternals);
+
+      while (firedUserTimers.hasNext()) {
+        TimerData nextInBundle = firedUserTimers.peek();
         NavigableSet<TimerData> modifiedUserTimersOrdered =
             getModifiedUserTimersOrdered(nextInBundle.getDomain());
         // If there is a modified timer that is earlier than the next timer in 
the bundle, try and
         // fire that first.
         while (!modifiedUserTimersOrdered.isEmpty()
             && modifiedUserTimersOrdered.first().compareTo(nextInBundle) <= 0) 
{
-          TimerData earlierTimer = modifiedUserTimersOrdered.pollFirst();
+          TimerData earlierTimer = 
checkStateNotNull(modifiedUserTimersOrdered.pollFirst());
           if (isTimerUnmodified(earlierTimer)) {
             // We must delete the timer. This prevents it from being committed 
to the backing store.
             // It also handles the
@@ -987,15 +1009,15 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
             // without deleting the
             // timer, the runner will still have that future timer stored, and 
would fire it
             // spuriously.
-            userTimerInternals.deleteTimer(earlierTimer);
+            nonNullUserTimerInternals.deleteTimer(earlierTimer);
             return earlierTimer;
           }
         }
         // There is no earlier timer to fire, so return the next timer in the 
bundle.
-        nextInBundle = cachedFiredUserTimers.next();
+        nextInBundle = firedUserTimers.next();
         if (isTimerUnmodified(nextInBundle)) {
           // User timers must be explicitly deleted when delivered, to release 
the implied hold.
-          userTimerInternals.deleteTimer(nextInBundle);
+          nonNullUserTimerInternals.deleteTimer(nextInBundle);
           return nextInBundle;
         }
       }
@@ -1029,12 +1051,6 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
       return StreamingModeExecutionContext.this.getSideInputNotifications();
     }
 
-    private void ensureStateful(String errorPrefix) {
-      if (stateFamily == null) {
-        throw new IllegalStateException(errorPrefix + " for stateless step: " 
+ getNameContext());
-      }
-    }
-
     @Override
     public <T, W extends BoundedWindow> void writePCollectionViewData(
         TupleTag<?> tag,
@@ -1043,7 +1059,8 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
         W window,
         Coder<W> windowCoder)
         throws IOException {
-      if (getSerializedKey().size() != 0) {
+      ByteString serializedKey = checkStateNotNull(getSerializedKey());
+      if (serializedKey.size() != 0) {
         throw new IllegalStateException("writePCollectionViewData must follow 
a Combine.globally");
       }
 
@@ -1053,7 +1070,7 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
       ByteStringOutputStream windowStream = new ByteStringOutputStream();
       windowCoder.encode(window, windowStream, Coder.Context.OUTER);
 
-      ensureStateful("Tried to write view data");
+      String stateFamily = checkStateNotNull(this.stateFamily, "Tried to write 
view data");
 
       Windmill.GlobalData.Builder builder =
           Windmill.GlobalData.newBuilder()
@@ -1065,7 +1082,7 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
               .setData(dataStream.toByteString())
               .setStateFamily(stateFamily);
 
-      outputBuilder.addGlobalDataUpdates(builder.build());
+      getOutputBuilder().addGlobalDataUpdates(builder.build());
     }
 
     /** Fetch the given side input asynchronously and return true if it is 
present. */
@@ -1080,11 +1097,12 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
     /** Note that there is data on the current key that is blocked on the 
given side input. */
     @Override
     public void addBlockingSideInput(Windmill.GlobalDataRequest sideInput) {
-      ensureStateful("Tried to set global data request");
+      String stateFamily = checkStateNotNull(this.stateFamily, "Tried to set 
global data request");
       sideInput =
           
Windmill.GlobalDataRequest.newBuilder(sideInput).setStateFamily(stateFamily).build();
-      outputBuilder.addGlobalDataRequests(sideInput);
-      outputBuilder.addGlobalDataIdRequests(sideInput.getDataId());
+      WorkItemCommitRequest.Builder builder = getOutputBuilder();
+      builder.addGlobalDataRequests(sideInput);
+      builder.addGlobalDataIdRequests(sideInput.getDataId());
     }
 
     /** Note that there is data on the current key that is blocked on the 
given side inputs. */
@@ -1097,14 +1115,12 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
 
     @Override
     public StateInternals stateInternals() {
-      ensureStateful("Tried to access state");
-      return checkNotNull(stateInternals);
+      return checkStateNotNull(stateInternals, "Tried to access state");
     }
 
     @Override
     public TimerInternals timerInternals() {
-      ensureStateful("Tried to access timers");
-      return checkNotNull(systemTimerInternals);
+      return checkStateNotNull(systemTimerInternals, "Tried to access timers");
     }
 
     @Override
@@ -1113,8 +1129,7 @@ public class StreamingModeExecutionContext extends 
DataflowExecutionContext<Step
     }
 
     public TimerInternals userTimerInternals() {
-      ensureStateful("Tried to access user timers");
-      return checkNotNull(userTimerInternals);
+      return checkStateNotNull(userTimerInternals, "Tried to access user 
timers");
     }
 
     public ImmutableList<Pair<Instant, BundleFinalizer.Callback>> 
flushBundleFinalizerCallbacks() {
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
index 216ca538667..13601410bfd 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
@@ -61,6 +61,7 @@ import 
org.apache.beam.runners.dataflow.worker.streaming.Watermarks;
 import org.apache.beam.runners.dataflow.worker.streaming.Work;
 import 
org.apache.beam.runners.dataflow.worker.streaming.config.FakeGlobalConfigHandle;
 import 
org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig;
+import 
org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle;
 import 
org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher;
 import org.apache.beam.runners.dataflow.worker.util.common.worker.WorkExecutor;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
@@ -110,36 +111,40 @@ public class StreamingModeExecutionContextTest {
   DataflowWorkerHarnessOptions options;
   private FakeGlobalConfigHandle globalConfigHandle;
 
+  private StreamingModeExecutionContext createExecutionContext(
+      StreamingGlobalConfigHandle configHandle) {
+    CounterSet counterSet = new CounterSet();
+    ConcurrentHashMap<String, String> stateNameMap = new ConcurrentHashMap<>();
+    stateNameMap.put(NameContextsForTests.nameContextForTest().userName(), 
"testStateFamily");
+    return new StreamingModeExecutionContext(
+        counterSet,
+        COMPUTATION_ID,
+        new ReaderCache(Duration.standardMinutes(1), 
Executors.newCachedThreadPool()),
+        stateNameMap,
+        WindmillStateCache.builder()
+            .setSizeMb(options.getWorkerCacheMb())
+            .build()
+            .forComputation("comp"),
+        StreamingStepMetricsContainer.createRegistry(),
+        new DataflowExecutionStateTracker(
+            ExecutionStateSampler.newForTest(),
+            executionStateRegistry.getState(
+                NameContext.forStage("stage"), "other", null, 
NoopProfileScope.NOOP),
+            counterSet,
+            PipelineOptionsFactory.create(),
+            "test-work-item-id"),
+        executionStateRegistry,
+        configHandle,
+        Long.MAX_VALUE,
+        /*throwExceptionOnLargeOutput=*/ false);
+  }
+
   @Before
   public void setUp() {
     MockitoAnnotations.initMocks(this);
     options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class);
-    CounterSet counterSet = new CounterSet();
-    ConcurrentHashMap<String, String> stateNameMap = new ConcurrentHashMap<>();
     globalConfigHandle = new 
FakeGlobalConfigHandle(StreamingGlobalConfig.builder().build());
-    stateNameMap.put(NameContextsForTests.nameContextForTest().userName(), 
"testStateFamily");
-    executionContext =
-        new StreamingModeExecutionContext(
-            counterSet,
-            COMPUTATION_ID,
-            new ReaderCache(Duration.standardMinutes(1), 
Executors.newCachedThreadPool()),
-            stateNameMap,
-            WindmillStateCache.builder()
-                .setSizeMb(options.getWorkerCacheMb())
-                .build()
-                .forComputation("comp"),
-            StreamingStepMetricsContainer.createRegistry(),
-            new DataflowExecutionStateTracker(
-                ExecutionStateSampler.newForTest(),
-                executionStateRegistry.getState(
-                    NameContext.forStage("stage"), "other", null, 
NoopProfileScope.NOOP),
-                counterSet,
-                PipelineOptionsFactory.create(),
-                "test-work-item-id"),
-            executionStateRegistry,
-            globalConfigHandle,
-            Long.MAX_VALUE,
-            /*throwExceptionOnLargeOutput=*/ false);
+    executionContext = createExecutionContext(globalConfigHandle);
   }
 
   private static Work createMockWork(Windmill.WorkItem workItem, Watermarks 
watermarks) {
@@ -421,20 +426,11 @@ public class StreamingModeExecutionContextTest {
     for (Boolean isV2Encoding : Lists.newArrayList(Boolean.TRUE, 
Boolean.FALSE)) {
       Class<?> expectedEncoding =
           isV2Encoding ? WindmillTagEncodingV2.class : 
WindmillTagEncodingV1.class;
-      Windmill.WorkItemCommitRequest.Builder outputBuilder =
-          Windmill.WorkItemCommitRequest.newBuilder();
-      globalConfigHandle.setConfig(
-          
StreamingGlobalConfig.builder().setEnableStateTagEncodingV2(isV2Encoding).build());
-      executionContext.start(
-          "key",
-          createMockWork(
-              
Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(),
-              Watermarks.builder().setInputDataWatermark(new 
Instant(1000)).build()),
-          stateReader,
-          sideInputStateFetcher,
-          outputBuilder,
-          workExecutor);
-      assertEquals(expectedEncoding, 
executionContext.getWindmillTagEncoding().getClass());
+      FakeGlobalConfigHandle configHandle =
+          new FakeGlobalConfigHandle(
+              
StreamingGlobalConfig.builder().setEnableStateTagEncodingV2(isV2Encoding).build());
+      StreamingModeExecutionContext context = 
createExecutionContext(configHandle);
+      assertEquals(expectedEncoding, 
context.getWindmillTagEncoding().getClass());
     }
   }
 


Reply via email to