This is an automated email from the ASF dual-hosted git repository.
scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 3732cffb33c [Dataflow Streaming] Fix nullness suppression in
StreamingModeExecutionContext (#38842)
3732cffb33c is described below
commit 3732cffb33c856905174c3d0c444d093b075d301
Author: Arun Pandian <[email protected]>
AuthorDate: Mon Jun 8 13:01:24 2026 -0700
[Dataflow Streaming] Fix nullness suppression in
StreamingModeExecutionContext (#38842)
---
.../worker/StreamingModeExecutionContext.java | 271 +++++++++++----------
.../worker/StreamingModeExecutionContextTest.java | 74 +++---
2 files changed, 178 insertions(+), 167 deletions(-)
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
index 25ce299adf7..00fdf67b8d0 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
@@ -18,7 +18,6 @@
package org.apache.beam.runners.dataflow.worker;
import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
-import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
import com.google.api.services.dataflow.model.CounterUpdate;
@@ -62,6 +61,7 @@ import
org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataId;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.Timer;
+import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
import
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache.ForComputation;
import
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateInternals;
@@ -94,6 +94,7 @@ import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterat
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.PeekingIterator;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table;
+import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;
import org.joda.time.Instant;
@@ -105,11 +106,7 @@ import org.slf4j.LoggerFactory;
* state pertaining to a processing its owning computation. Can be reused
across processing
* different WorkItems for the same computation.
*/
-@SuppressWarnings({
- "deprecation",
- "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-// TODO(m-trieu) fix nullability issues in StreamingModeExecutionContext.java
+@SuppressWarnings({"deprecation"})
@NotThreadSafe
@Internal
public class StreamingModeExecutionContext extends
DataflowExecutionContext<StepContext> {
@@ -130,7 +127,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
*/
private final Map<TupleTag<?>, Map<BoundedWindow, SideInput<?>>>
sideInputCache;
- private WindmillTagEncoding windmillTagEncoding;
+ private final WindmillTagEncoding windmillTagEncoding;
/**
* The current user-facing key for this execution context.
*
@@ -143,13 +140,13 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
private @Nullable Object key = null;
private @Nullable Work work;
- private WindmillComputationKey computationKey;
- private SideInputStateFetcher sideInputStateFetcher;
+ private @Nullable WindmillComputationKey computationKey;
+ private @Nullable SideInputStateFetcher sideInputStateFetcher;
// OperationalLimits is updated in start() because a
StreamingModeExecutionContext can
// be used for processing many work items and these values can change during
the context's
// lifetime. start() is called for each work item.
private OperationalLimits operationalLimits;
- private Windmill.WorkItemCommitRequest.Builder outputBuilder;
+ private Windmill.WorkItemCommitRequest.@Nullable Builder outputBuilder;
/**
* Current reader used for processing {@link Work}. Set by calling {@link
@@ -188,6 +185,12 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
this.stateCache = stateCache;
this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
this.throwExceptionOnLargeOutput = throwExceptionOnLargeOutput;
+ StreamingGlobalConfig config = globalConfigHandle.getConfig();
+ this.operationalLimits = config.operationalLimits();
+ this.windmillTagEncoding =
+ config.enableStateTagEncodingV2()
+ ? WindmillTagEncodingV2.instance()
+ : WindmillTagEncodingV1.instance();
}
@VisibleForTesting
@@ -229,7 +232,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
throw new RuntimeException(
"Unexpected getCurrentRecordId() while offset-based deduplication is
not enabled.");
}
- return activeReader.getCurrentRecordId();
+ return checkStateNotNull(activeReader).getCurrentRecordId();
}
public byte[] getCurrentRecordOffset() {
@@ -237,7 +240,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
throw new RuntimeException(
"Unexpected getCurrentRecordOffset() while offset-based
deduplication is not enabled.");
}
- return activeReader.getCurrentRecordOffset();
+ return checkStateNotNull(activeReader).getCurrentRecordOffset();
}
public void start(
@@ -256,10 +259,6 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
StreamingGlobalConfig config = globalConfigHandle.getConfig();
// Snapshot the limits for entire bundle processing.
this.operationalLimits = config.operationalLimits();
- this.windmillTagEncoding =
- config.enableStateTagEncodingV2()
- ? WindmillTagEncodingV2.instance()
- : WindmillTagEncodingV1.instance();
this.outputBuilder = outputBuilder;
this.sideInputCache.clear();
this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN;
@@ -369,9 +368,9 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
return fetchSideInputFromWindmill(
view,
sideInputWindow,
- checkNotNull(stateFamily),
+ checkStateNotNull(stateFamily),
state,
- checkNotNull(scopedReadStateSupplier),
+ checkStateNotNull(scopedReadStateSupplier),
tagCache);
}
@@ -383,8 +382,8 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
Supplier<Closeable> scopedReadStateSupplier,
Map<BoundedWindow, SideInput<?>> tagCache) {
SideInput<T> fetched =
- sideInputStateFetcher.fetchSideInput(
- view, sideInputWindow, stateFamily, state,
scopedReadStateSupplier);
+ checkStateNotNull(sideInputStateFetcher)
+ .fetchSideInput(view, sideInputWindow, stateFamily, state,
scopedReadStateSupplier);
if (fetched.isReady()) {
tagCache.put(sideInputWindow, fetched);
@@ -406,7 +405,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
}
public WindmillComputationKey getComputationKey() {
- return computationKey;
+ return checkStateNotNull(computationKey);
}
public long getWorkToken() {
@@ -414,7 +413,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
}
public Windmill.WorkItem getWorkItem() {
- return checkNotNull(
+ return checkStateNotNull(
work,
"work is null. A call to StreamingModeExecutionContext.start(...)
is required to set"
+ " work for execution.")
@@ -422,7 +421,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
}
public Windmill.WorkItemCommitRequest.Builder getOutputBuilder() {
- return outputBuilder;
+ return checkStateNotNull(outputBuilder);
}
/**
@@ -490,15 +489,16 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
throw new RuntimeException("Exception while running bundle
finalizer", e);
}
}));
- outputBuilder.addFinalizeIds(id);
+ getOutputBuilder().addFinalizeIds(id);
}
}
- if (activeReader != null) {
- Windmill.SourceState.Builder sourceStateBuilder =
- outputBuilder.getSourceStateUpdatesBuilder();
- final UnboundedSource.CheckpointMark checkpointMark =
activeReader.getCheckpointMark();
- final Instant watermark = activeReader.getWatermark();
+ UnboundedReader<?> reader = activeReader;
+ if (reader != null) {
+ Windmill.WorkItemCommitRequest.Builder builder = getOutputBuilder();
+ Windmill.SourceState.Builder sourceStateBuilder =
builder.getSourceStateUpdatesBuilder();
+ final UnboundedSource.CheckpointMark checkpointMark =
reader.getCheckpointMark();
+ final Instant watermark = reader.getWatermark();
long id = ThreadLocalRandom.current().nextLong();
sourceStateBuilder.addFinalizeIds(id);
callbacks.put(
@@ -515,7 +515,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
@SuppressWarnings("unchecked")
Coder<UnboundedSource.CheckpointMark> checkpointCoder =
- ((UnboundedSource<?, UnboundedSource.CheckpointMark>)
activeReader.getCurrentSource())
+ ((UnboundedSource<?, UnboundedSource.CheckpointMark>)
reader.getCurrentSource())
.getCheckpointMarkCoder();
if (checkpointCoder != null) {
ByteStringOutputStream stream = new ByteStringOutputStream();
@@ -525,7 +525,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
throw new RuntimeException("Exception while encoding checkpoint", e);
}
sourceStateBuilder.setState(stream.toByteString());
- if
(activeReader.getCurrentSource().offsetBasedDeduplicationSupported()) {
+ if (reader.getCurrentSource().offsetBasedDeduplicationSupported()) {
byte[] offsetLimit = checkpointMark.getOffsetLimit();
if (offsetLimit.length == 0) {
throw new RuntimeException("Checkpoint offset limit must be
non-empty.");
@@ -533,31 +533,30 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
sourceStateBuilder.setOffsetLimit(ByteString.copyFrom(offsetLimit));
}
}
-
outputBuilder.setSourceWatermark(WindmillTimeUtils.harnessToWindmillTimestamp(watermark));
+
builder.setSourceWatermark(WindmillTimeUtils.harnessToWindmillTimestamp(watermark));
- backlogBytes = activeReader.getSplitBacklogBytes();
+ backlogBytes = reader.getSplitBacklogBytes();
+ ByteString serializedKey = checkStateNotNull(getSerializedKey());
if (backlogBytes == UnboundedReader.BACKLOG_UNKNOWN
- &&
WorkerCustomSources.isFirstUnboundedSourceSplit(getSerializedKey())) {
+ && WorkerCustomSources.isFirstUnboundedSourceSplit(serializedKey)) {
// Only call getTotalBacklogBytes() on the first split.
- backlogBytes = activeReader.getTotalBacklogBytes();
+ backlogBytes = reader.getTotalBacklogBytes();
}
- outputBuilder.setSourceBacklogBytes(backlogBytes);
+ builder.setSourceBacklogBytes(backlogBytes);
readerCache.cacheReader(
- getComputationKey(),
- getWorkItem().getCacheToken(),
- getWorkItem().getWorkToken(),
- activeReader);
+ getComputationKey(), getWorkItem().getCacheToken(),
getWorkItem().getWorkToken(), reader);
activeReader = null;
} else if (backlogBytes != UnboundedReader.BACKLOG_UNKNOWN && backlogBytes
!= 1L) {
// If activeReader is null, we might still have backlogBytes from an
SDF. We ignore a reported
// backlogBytes of 1 since older versions of the Java SDK use this value
as a default when
// RestrictionTracker.getProgress() or GetSize() are not defined.
- outputBuilder.setSourceBacklogBytes(backlogBytes);
+ getOutputBuilder().setSourceBacklogBytes(backlogBytes);
}
return callbacks;
}
+ @Nullable
String getStateFamily(NameContext nameContext) {
return nameContext.userName() == null ? null :
stateNameMap.get(nameContext.userName());
}
@@ -599,7 +598,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
public StreamingModeExecutionState(
NameContext nameContext,
String stateName,
- MetricsContainer metricsContainer,
+ @Nullable MetricsContainer metricsContainer,
ProfileScope profileScope) {
// TODO: Take in the requesting step name and side input index for
streaming.
super(nameContext, stateName, null, null, metricsContainer,
profileScope);
@@ -642,14 +641,16 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
protected DataflowExecutionState createState(
NameContext nameContext,
String stateName,
- String requestingStepName,
- Integer inputIndex,
- MetricsContainer container,
+ @Nullable String requestingStepName,
+ @Nullable Integer inputIndex,
+ @Nullable MetricsContainer container,
ProfileScope profileScope) {
return new StreamingModeExecutionState(nameContext, stateName,
container, profileScope);
}
}
+ private static final Closeable NO_OP_CLOSEABLE = () -> {};
+
private static class ScopedReadStateSupplier implements Supplier<Closeable> {
private final ExecutionState readState;
@@ -662,9 +663,9 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
}
@Override
- public @Nullable Closeable get() {
+ public Closeable get() {
if (stateTracker == null) {
- return null;
+ return NO_OP_CLOSEABLE;
}
return stateTracker.enterState(readState);
}
@@ -725,7 +726,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
}
@Override
- public <W extends BoundedWindow> TimerData getNextFiredTimer(Coder<W>
windowCoder) {
+ public <W extends BoundedWindow> @Nullable TimerData
getNextFiredTimer(Coder<W> windowCoder) {
return wrapped.getNextFiredUserTimer(windowCoder);
}
@@ -777,7 +778,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
}
@Override
- public <T> T get(PCollectionView<T> view, BoundedWindow window) {
+ public <T> @Nullable T get(PCollectionView<T> view, BoundedWindow window) {
if (!contains(view)) {
throw new RuntimeException("get() called with unknown view");
}
@@ -810,31 +811,32 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
class StepContext extends DataflowExecutionContext.DataflowStepContext
implements StreamingModeStepContext {
- private final String stateFamily;
+ private final @Nullable String stateFamily;
private final Supplier<Closeable> scopedReadStateSupplier;
- private WindmillStateInternals<Object> stateInternals;
- private WindmillTimerInternals systemTimerInternals;
- private WindmillTimerInternals userTimerInternals;
+ private @MonotonicNonNull WindmillStateInternals<Object> stateInternals;
+ private @MonotonicNonNull WindmillTimerInternals systemTimerInternals;
+ private @MonotonicNonNull WindmillTimerInternals userTimerInternals;
// Lazily initialized
- private Iterator<TimerData> cachedFiredSystemTimers = null;
+ private @Nullable Iterator<TimerData> cachedFiredSystemTimers = null;
// Lazily initialized
- private PeekingIterator<TimerData> cachedFiredUserTimers = null;
+ private @Nullable PeekingIterator<TimerData> cachedFiredUserTimers = null;
// An ordered list of any timers that were set or modified by user
processing earlier in this
// bundle.
// We use a NavigableSet instead of a priority queue to prevent duplicate
elements from ending
// up in the queue.
- private NavigableSet<TimerData> modifiedUserEventTimersOrdered = null;
- private NavigableSet<TimerData> modifiedUserProcessingTimersOrdered = null;
- private NavigableSet<TimerData>
modifiedUserSynchronizedProcessingTimersOrdered = null;
+ private final NavigableSet<TimerData> modifiedUserEventTimersOrdered =
Sets.newTreeSet();
+ private final NavigableSet<TimerData> modifiedUserProcessingTimersOrdered
= Sets.newTreeSet();
+ private final NavigableSet<TimerData>
modifiedUserSynchronizedProcessingTimersOrdered =
+ Sets.newTreeSet();
// A list of timer keys that were modified by user processing earlier in
this bundle. This
// serves a tombstone, so that we know not to fire any bundle timers that
were modified.
- private Table<String, StateNamespace, TimerData> modifiedUserTimerKeys =
null;
+ private final Table<String, StateNamespace, TimerData>
modifiedUserTimerKeys =
+ HashBasedTable.create();
private final WindmillBundleFinalizer bundleFinalizer = new
WindmillBundleFinalizer();
public StepContext(DataflowOperationContext operationContext) {
super(operationContext.nameContext());
this.stateFamily = getStateFamily(operationContext.nameContext());
-
this.scopedReadStateSupplier =
new ScopedReadStateSupplier(operationContext,
getExecutionStateTracker());
}
@@ -845,46 +847,50 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
Instant processingTime,
WindmillStateCache.ForKey cacheForKey,
Watermarks watermarks) {
- this.stateInternals =
- new WindmillStateInternals<>(
- key,
- stateFamily,
- stateReader,
- getWorkItem().getIsNewKey(),
- cacheForKey.forFamily(stateFamily),
- windmillTagEncoding,
- scopedReadStateSupplier);
-
- this.systemTimerInternals =
- new WindmillTimerInternals(
- stateFamily,
- WindmillTimerType.SYSTEM_TIMER,
- processingTime,
- watermarks,
- windmillTagEncoding,
- td -> {});
-
- this.userTimerInternals =
- new WindmillTimerInternals(
- stateFamily,
- WindmillTimerType.USER_TIMER,
- processingTime,
- watermarks,
- windmillTagEncoding,
- this::onUserTimerModified);
-
+ if (stateFamily != null) {
+ this.stateInternals =
+ new WindmillStateInternals<>(
+ key,
+ stateFamily,
+ stateReader,
+ getWorkItem().getIsNewKey(),
+ cacheForKey.forFamily(stateFamily),
+ windmillTagEncoding,
+ scopedReadStateSupplier);
+
+ this.systemTimerInternals =
+ new WindmillTimerInternals(
+ stateFamily,
+ WindmillTimerType.SYSTEM_TIMER,
+ processingTime,
+ watermarks,
+ windmillTagEncoding,
+ td -> {});
+
+ this.userTimerInternals =
+ new WindmillTimerInternals(
+ stateFamily,
+ WindmillTimerType.USER_TIMER,
+ processingTime,
+ watermarks,
+ windmillTagEncoding,
+ this::onUserTimerModified);
+ }
this.cachedFiredSystemTimers = null;
this.cachedFiredUserTimers = null;
- modifiedUserEventTimersOrdered = Sets.newTreeSet();
- modifiedUserProcessingTimersOrdered = Sets.newTreeSet();
- modifiedUserSynchronizedProcessingTimersOrdered = Sets.newTreeSet();
- modifiedUserTimerKeys = HashBasedTable.create();
+ this.modifiedUserEventTimersOrdered.clear();
+ this.modifiedUserProcessingTimersOrdered.clear();
+ this.modifiedUserSynchronizedProcessingTimersOrdered.clear();
+ this.modifiedUserTimerKeys.clear();
}
public void flushState() {
- stateInternals.persist(outputBuilder);
- systemTimerInternals.persistTo(outputBuilder);
- userTimerInternals.persistTo(outputBuilder);
+ if (stateFamily != null) {
+ WorkItemCommitRequest.Builder builder = getOutputBuilder();
+ checkStateNotNull(stateInternals).persist(builder);
+ checkStateNotNull(systemTimerInternals).persistTo(builder);
+ checkStateNotNull(userTimerInternals).persistTo(builder);
+ }
}
@Override
@@ -893,9 +899,14 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
}
@Override
- public <W extends BoundedWindow> TimerData getNextFiredTimer(Coder<W>
windowCoder) {
- if (cachedFiredSystemTimers == null) {
- cachedFiredSystemTimers =
+ public <W extends BoundedWindow> @Nullable TimerData
getNextFiredTimer(Coder<W> windowCoder) {
+ if (stateFamily == null) {
+ // no timers on stateless stages
+ return null;
+ }
+ Iterator<TimerData> firedSystemTimers = cachedFiredSystemTimers;
+ if (firedSystemTimers == null) {
+ firedSystemTimers =
FluentIterable.from(StreamingModeExecutionContext.this.getFiredTimers())
.filter(timer -> timer.getStateFamily().equals(stateFamily))
.transform(
@@ -907,16 +918,17 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
windmillTimerData.getWindmillTimerType() ==
WindmillTimerType.SYSTEM_TIMER)
.transform(WindmillTimerData::getTimerData)
.iterator();
+ cachedFiredSystemTimers = firedSystemTimers;
}
- if (!cachedFiredSystemTimers.hasNext()) {
+ if (!firedSystemTimers.hasNext()) {
return null;
}
- TimerData nextTimer = cachedFiredSystemTimers.next();
+ TimerData nextTimer = firedSystemTimers.next();
// system timers ( GC timer) must be explicitly deleted if only there is
a hold.
// if timestamp is not equals to outputTimestamp then there should be a
hold
if (!nextTimer.getTimestamp().equals(nextTimer.getOutputTimestamp())) {
- systemTimerInternals.deleteTimer(nextTimer);
+ checkStateNotNull(systemTimerInternals).deleteTimer(nextTimer);
}
return nextTimer;
}
@@ -950,12 +962,19 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
return updatedTimer == null || updatedTimer.equals(timerData);
}
- public <W extends BoundedWindow> TimerData getNextFiredUserTimer(Coder<W>
windowCoder) {
- if (cachedFiredUserTimers == null) {
+ public <W extends BoundedWindow> @Nullable TimerData getNextFiredUserTimer(
+ Coder<W> windowCoder) {
+ if (stateFamily == null) {
+ // no timers on stateless stages
+ return null;
+ }
+
+ PeekingIterator<TimerData> firedUserTimers = cachedFiredUserTimers;
+ if (firedUserTimers == null) {
// This is the first call to getNextFiredUserTimer in this bundle.
Extract any user timers
// from the bundle
// and cache the list for the rest of this bundle processing.
- cachedFiredUserTimers =
+ firedUserTimers =
Iterators.peekingIterator(
FluentIterable.from(StreamingModeExecutionContext.this.getFiredTimers())
.filter(timer ->
timer.getStateFamily().equals(stateFamily))
@@ -969,17 +988,20 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
== WindmillTimerType.USER_TIMER)
.transform(WindmillTimerData::getTimerData)
.iterator());
+ cachedFiredUserTimers = firedUserTimers;
}
- while (cachedFiredUserTimers.hasNext()) {
- TimerData nextInBundle = cachedFiredUserTimers.peek();
+ WindmillTimerInternals nonNullUserTimerInternals =
checkStateNotNull(this.userTimerInternals);
+
+ while (firedUserTimers.hasNext()) {
+ TimerData nextInBundle = firedUserTimers.peek();
NavigableSet<TimerData> modifiedUserTimersOrdered =
getModifiedUserTimersOrdered(nextInBundle.getDomain());
// If there is a modified timer that is earlier than the next timer in
the bundle, try and
// fire that first.
while (!modifiedUserTimersOrdered.isEmpty()
&& modifiedUserTimersOrdered.first().compareTo(nextInBundle) <= 0)
{
- TimerData earlierTimer = modifiedUserTimersOrdered.pollFirst();
+ TimerData earlierTimer =
checkStateNotNull(modifiedUserTimersOrdered.pollFirst());
if (isTimerUnmodified(earlierTimer)) {
// We must delete the timer. This prevents it from being committed
to the backing store.
// It also handles the
@@ -987,15 +1009,15 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
// without deleting the
// timer, the runner will still have that future timer stored, and
would fire it
// spuriously.
- userTimerInternals.deleteTimer(earlierTimer);
+ nonNullUserTimerInternals.deleteTimer(earlierTimer);
return earlierTimer;
}
}
// There is no earlier timer to fire, so return the next timer in the
bundle.
- nextInBundle = cachedFiredUserTimers.next();
+ nextInBundle = firedUserTimers.next();
if (isTimerUnmodified(nextInBundle)) {
// User timers must be explicitly deleted when delivered, to release
the implied hold.
- userTimerInternals.deleteTimer(nextInBundle);
+ nonNullUserTimerInternals.deleteTimer(nextInBundle);
return nextInBundle;
}
}
@@ -1029,12 +1051,6 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
return StreamingModeExecutionContext.this.getSideInputNotifications();
}
- private void ensureStateful(String errorPrefix) {
- if (stateFamily == null) {
- throw new IllegalStateException(errorPrefix + " for stateless step: "
+ getNameContext());
- }
- }
-
@Override
public <T, W extends BoundedWindow> void writePCollectionViewData(
TupleTag<?> tag,
@@ -1043,7 +1059,8 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
W window,
Coder<W> windowCoder)
throws IOException {
- if (getSerializedKey().size() != 0) {
+ ByteString serializedKey = checkStateNotNull(getSerializedKey());
+ if (serializedKey.size() != 0) {
throw new IllegalStateException("writePCollectionViewData must follow
a Combine.globally");
}
@@ -1053,7 +1070,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
ByteStringOutputStream windowStream = new ByteStringOutputStream();
windowCoder.encode(window, windowStream, Coder.Context.OUTER);
- ensureStateful("Tried to write view data");
+ String stateFamily = checkStateNotNull(this.stateFamily, "Tried to write
view data");
Windmill.GlobalData.Builder builder =
Windmill.GlobalData.newBuilder()
@@ -1065,7 +1082,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
.setData(dataStream.toByteString())
.setStateFamily(stateFamily);
- outputBuilder.addGlobalDataUpdates(builder.build());
+ getOutputBuilder().addGlobalDataUpdates(builder.build());
}
/** Fetch the given side input asynchronously and return true if it is
present. */
@@ -1080,11 +1097,12 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
/** Note that there is data on the current key that is blocked on the
given side input. */
@Override
public void addBlockingSideInput(Windmill.GlobalDataRequest sideInput) {
- ensureStateful("Tried to set global data request");
+ String stateFamily = checkStateNotNull(this.stateFamily, "Tried to set
global data request");
sideInput =
Windmill.GlobalDataRequest.newBuilder(sideInput).setStateFamily(stateFamily).build();
- outputBuilder.addGlobalDataRequests(sideInput);
- outputBuilder.addGlobalDataIdRequests(sideInput.getDataId());
+ WorkItemCommitRequest.Builder builder = getOutputBuilder();
+ builder.addGlobalDataRequests(sideInput);
+ builder.addGlobalDataIdRequests(sideInput.getDataId());
}
/** Note that there is data on the current key that is blocked on the
given side inputs. */
@@ -1097,14 +1115,12 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
@Override
public StateInternals stateInternals() {
- ensureStateful("Tried to access state");
- return checkNotNull(stateInternals);
+ return checkStateNotNull(stateInternals, "Tried to access state");
}
@Override
public TimerInternals timerInternals() {
- ensureStateful("Tried to access timers");
- return checkNotNull(systemTimerInternals);
+ return checkStateNotNull(systemTimerInternals, "Tried to access timers");
}
@Override
@@ -1113,8 +1129,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
}
public TimerInternals userTimerInternals() {
- ensureStateful("Tried to access user timers");
- return checkNotNull(userTimerInternals);
+ return checkStateNotNull(userTimerInternals, "Tried to access user
timers");
}
public ImmutableList<Pair<Instant, BundleFinalizer.Callback>>
flushBundleFinalizerCallbacks() {
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
index 216ca538667..13601410bfd 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
@@ -61,6 +61,7 @@ import
org.apache.beam.runners.dataflow.worker.streaming.Watermarks;
import org.apache.beam.runners.dataflow.worker.streaming.Work;
import
org.apache.beam.runners.dataflow.worker.streaming.config.FakeGlobalConfigHandle;
import
org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig;
+import
org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle;
import
org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher;
import org.apache.beam.runners.dataflow.worker.util.common.worker.WorkExecutor;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
@@ -110,36 +111,40 @@ public class StreamingModeExecutionContextTest {
DataflowWorkerHarnessOptions options;
private FakeGlobalConfigHandle globalConfigHandle;
+ private StreamingModeExecutionContext createExecutionContext(
+ StreamingGlobalConfigHandle configHandle) {
+ CounterSet counterSet = new CounterSet();
+ ConcurrentHashMap<String, String> stateNameMap = new ConcurrentHashMap<>();
+ stateNameMap.put(NameContextsForTests.nameContextForTest().userName(),
"testStateFamily");
+ return new StreamingModeExecutionContext(
+ counterSet,
+ COMPUTATION_ID,
+ new ReaderCache(Duration.standardMinutes(1),
Executors.newCachedThreadPool()),
+ stateNameMap,
+ WindmillStateCache.builder()
+ .setSizeMb(options.getWorkerCacheMb())
+ .build()
+ .forComputation("comp"),
+ StreamingStepMetricsContainer.createRegistry(),
+ new DataflowExecutionStateTracker(
+ ExecutionStateSampler.newForTest(),
+ executionStateRegistry.getState(
+ NameContext.forStage("stage"), "other", null,
NoopProfileScope.NOOP),
+ counterSet,
+ PipelineOptionsFactory.create(),
+ "test-work-item-id"),
+ executionStateRegistry,
+ configHandle,
+ Long.MAX_VALUE,
+ /*throwExceptionOnLargeOutput=*/ false);
+ }
+
@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class);
- CounterSet counterSet = new CounterSet();
- ConcurrentHashMap<String, String> stateNameMap = new ConcurrentHashMap<>();
globalConfigHandle = new
FakeGlobalConfigHandle(StreamingGlobalConfig.builder().build());
- stateNameMap.put(NameContextsForTests.nameContextForTest().userName(),
"testStateFamily");
- executionContext =
- new StreamingModeExecutionContext(
- counterSet,
- COMPUTATION_ID,
- new ReaderCache(Duration.standardMinutes(1),
Executors.newCachedThreadPool()),
- stateNameMap,
- WindmillStateCache.builder()
- .setSizeMb(options.getWorkerCacheMb())
- .build()
- .forComputation("comp"),
- StreamingStepMetricsContainer.createRegistry(),
- new DataflowExecutionStateTracker(
- ExecutionStateSampler.newForTest(),
- executionStateRegistry.getState(
- NameContext.forStage("stage"), "other", null,
NoopProfileScope.NOOP),
- counterSet,
- PipelineOptionsFactory.create(),
- "test-work-item-id"),
- executionStateRegistry,
- globalConfigHandle,
- Long.MAX_VALUE,
- /*throwExceptionOnLargeOutput=*/ false);
+ executionContext = createExecutionContext(globalConfigHandle);
}
private static Work createMockWork(Windmill.WorkItem workItem, Watermarks
watermarks) {
@@ -421,20 +426,11 @@ public class StreamingModeExecutionContextTest {
for (Boolean isV2Encoding : Lists.newArrayList(Boolean.TRUE,
Boolean.FALSE)) {
Class<?> expectedEncoding =
isV2Encoding ? WindmillTagEncodingV2.class :
WindmillTagEncodingV1.class;
- Windmill.WorkItemCommitRequest.Builder outputBuilder =
- Windmill.WorkItemCommitRequest.newBuilder();
- globalConfigHandle.setConfig(
-
StreamingGlobalConfig.builder().setEnableStateTagEncodingV2(isV2Encoding).build());
- executionContext.start(
- "key",
- createMockWork(
-
Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(),
- Watermarks.builder().setInputDataWatermark(new
Instant(1000)).build()),
- stateReader,
- sideInputStateFetcher,
- outputBuilder,
- workExecutor);
- assertEquals(expectedEncoding,
executionContext.getWindmillTagEncoding().getClass());
+ FakeGlobalConfigHandle configHandle =
+ new FakeGlobalConfigHandle(
+
StreamingGlobalConfig.builder().setEnableStateTagEncodingV2(isV2Encoding).build());
+ StreamingModeExecutionContext context =
createExecutionContext(configHandle);
+ assertEquals(expectedEncoding,
context.getWindmillTagEncoding().getClass());
}
}