This is an automated email from the ASF dual-hosted git repository.
scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new c6ab5fe272f Refactor StateFetcher (#28755)
c6ab5fe272f is described below
commit c6ab5fe272f74749713a5ccf7c21a398d017a606
Author: martin trieu <[email protected]>
AuthorDate: Fri Oct 20 05:20:37 2023 -0700
Refactor StateFetcher (#28755)
Refactor and cleanup of StateFetcher in preparation for future changes
---
.../beam/runners/dataflow/worker/StateFetcher.java | 291 ---------------------
.../dataflow/worker/StreamingDataflowWorker.java | 11 +-
.../worker/StreamingModeExecutionContext.java | 114 ++++----
.../dataflow/worker/StreamingSideInputFetcher.java | 8 +-
.../worker/streaming/sideinput/SideInput.java | 50 ++++
.../worker/streaming/sideinput/SideInputCache.java | 113 ++++++++
.../worker/streaming/sideinput/SideInputState.java | 25 ++
.../streaming/sideinput/SideInputStateFetcher.java | 245 +++++++++++++++++
.../worker/StreamingDataflowWorkerTest.java | 6 +-
.../worker/StreamingModeExecutionContextTest.java | 9 +-
.../worker/StreamingSideInputDoFnRunnerTest.java | 2 +-
.../worker/StreamingSideInputFetcherTest.java | 2 +-
.../sideinput/SideInputStateFetcherTest.java} | 170 +++++++-----
13 files changed, 626 insertions(+), 420 deletions(-)
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StateFetcher.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StateFetcher.java
deleted file mode 100644
index 0cbcd2e8301..00000000000
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StateFetcher.java
+++ /dev/null
@@ -1,291 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.dataflow.worker;
-
-import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
-
-import java.io.Closeable;
-import java.util.Collections;
-import java.util.Objects;
-import java.util.Set;
-import java.util.concurrent.Callable;
-import java.util.concurrent.TimeUnit;
-import org.apache.beam.runners.core.InMemoryMultimapSideInputView;
-import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.IterableCoder;
-import org.apache.beam.sdk.coders.KvCoder;
-import org.apache.beam.sdk.transforms.Materializations;
-import org.apache.beam.sdk.transforms.Materializations.IterableView;
-import org.apache.beam.sdk.transforms.Materializations.MultimapView;
-import org.apache.beam.sdk.transforms.ViewFn;
-import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.beam.sdk.util.ByteStringOutputStream;
-import org.apache.beam.sdk.values.PCollectionView;
-import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
-import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Optional;
-import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache;
-import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder;
-import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Weigher;
-import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
-import org.checkerframework.checker.nullness.qual.Nullable;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/** Class responsible for fetching state from the windmill server. */
-@SuppressWarnings({
- "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
- "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-class StateFetcher {
- private static final Set<String> SUPPORTED_MATERIALIZATIONS =
- ImmutableSet.of(
- Materializations.ITERABLE_MATERIALIZATION_URN,
- Materializations.MULTIMAP_MATERIALIZATION_URN);
-
- private static final Logger LOG =
LoggerFactory.getLogger(StateFetcher.class);
-
- private Cache<SideInputId, SideInputCacheEntry> sideInputCache;
- private MetricTrackingWindmillServerStub server;
- private long bytesRead = 0L;
-
- public StateFetcher(MetricTrackingWindmillServerStub server) {
- this(
- server,
- CacheBuilder.newBuilder()
- .maximumWeight(100000000 /* 100 MB */)
- .expireAfterWrite(1, TimeUnit.MINUTES)
- .weigher((Weigher<SideInputId, SideInputCacheEntry>) (id, entry)
-> entry.size())
- .build());
- }
-
- public StateFetcher(
- MetricTrackingWindmillServerStub server,
- Cache<SideInputId, SideInputCacheEntry> sideInputCache) {
- this.server = server;
- this.sideInputCache = sideInputCache;
- }
-
- /** Returns a view of the underlying cache that keeps track of bytes read
separately. */
- public StateFetcher byteTrackingView() {
- return new StateFetcher(server, sideInputCache);
- }
-
- public long getBytesRead() {
- return bytesRead;
- }
-
- /** Indicates the caller's knowledge of whether a particular side input has
been computed. */
- public enum SideInputState {
- CACHED_IN_WORKITEM,
- KNOWN_READY,
- UNKNOWN;
- }
-
- /**
- * Fetch the given side input, storing it in a process-level cache.
- *
- * <p>If state is KNOWN_READY, attempt to fetch the data regardless of
whether a not-ready entry
- * was cached.
- *
- * <p>Returns {@literal null} if the side input was not ready, {@literal
Optional.absent()} if the
- * side input was null, and {@literal Optional.present(...)} if the side
input was non-null.
- */
- public @Nullable <T, SideWindowT extends BoundedWindow> Optional<T>
fetchSideInput(
- final PCollectionView<T> view,
- final SideWindowT sideWindow,
- final String stateFamily,
- SideInputState state,
- final Supplier<Closeable> scopedReadStateSupplier) {
- final SideInputId id = new SideInputId(view.getTagInternal(), sideWindow);
-
- Callable<SideInputCacheEntry> fetchCallable =
- () -> {
- @SuppressWarnings("unchecked")
- WindowingStrategy<?, SideWindowT> sideWindowStrategy =
- (WindowingStrategy<?, SideWindowT>)
view.getWindowingStrategyInternal();
-
- Coder<SideWindowT> windowCoder =
sideWindowStrategy.getWindowFn().windowCoder();
-
- ByteStringOutputStream windowStream = new ByteStringOutputStream();
- windowCoder.encode(sideWindow, windowStream, Coder.Context.OUTER);
-
- @SuppressWarnings("unchecked")
- Windmill.GlobalDataRequest request =
- Windmill.GlobalDataRequest.newBuilder()
- .setDataId(
- Windmill.GlobalDataId.newBuilder()
- .setTag(view.getTagInternal().getId())
- .setVersion(windowStream.toByteString())
- .build())
- .setStateFamily(stateFamily)
- .setExistenceWatermarkDeadline(
- WindmillTimeUtils.harnessToWindmillTimestamp(
- sideWindowStrategy
- .getTrigger()
- .getWatermarkThatGuaranteesFiring(sideWindow)))
- .build();
-
- Windmill.GlobalData data;
- try (Closeable scope = scopedReadStateSupplier.get()) {
- data = server.getSideInputData(request);
- }
-
- bytesRead += data.getSerializedSize();
-
- checkState(
-
SUPPORTED_MATERIALIZATIONS.contains(view.getViewFn().getMaterialization().getUrn()),
- "Only materializations of type %s supported, received %s",
- SUPPORTED_MATERIALIZATIONS,
- view.getViewFn().getMaterialization().getUrn());
-
- Iterable<?> rawData;
- if (data.getIsReady()) {
- if (data.getData().size() > 0) {
- rawData =
- IterableCoder.of(view.getCoderInternal())
- .decode(data.getData().newInput(), Coder.Context.OUTER);
- } else {
- rawData = Collections.emptyList();
- }
-
- switch (view.getViewFn().getMaterialization().getUrn()) {
- case Materializations.ITERABLE_MATERIALIZATION_URN:
- {
- ViewFn<IterableView, T> viewFn = (ViewFn<IterableView, T>)
view.getViewFn();
- return SideInputCacheEntry.ready(
- viewFn.apply(() -> rawData), data.getData().size());
- }
- case Materializations.MULTIMAP_MATERIALIZATION_URN:
- {
- ViewFn<MultimapView, T> viewFn = (ViewFn<MultimapView, T>)
view.getViewFn();
- Coder<?> keyCoder = ((KvCoder<?, ?>)
view.getCoderInternal()).getKeyCoder();
- return SideInputCacheEntry.ready(
- viewFn.apply(
- InMemoryMultimapSideInputView.fromIterable(keyCoder,
(Iterable) rawData)),
- data.getData().size());
- }
- default:
- throw new IllegalStateException(
- String.format(
- "Unknown side input materialization format requested
'%s'",
- view.getViewFn().getMaterialization().getUrn()));
- }
- } else {
- return SideInputCacheEntry.notReady();
- }
- };
-
- try {
- if (state == SideInputState.KNOWN_READY) {
- SideInputCacheEntry entry = sideInputCache.getIfPresent(id);
- if (entry == null) {
- return sideInputCache.get(id, fetchCallable).getValue();
- } else if (!entry.isReady()) {
- // Invalidate the existing not-ready entry. This must be done
atomically
- // so that another thread doesn't replace the entry with a ready
entry, which
- // would then be deleted here.
- synchronized (entry) {
- SideInputCacheEntry newEntry = sideInputCache.getIfPresent(id);
- if (newEntry != null && !newEntry.isReady()) {
- sideInputCache.invalidate(id);
- }
- }
-
- return sideInputCache.get(id, fetchCallable).getValue();
- } else {
- return entry.getValue();
- }
- } else {
- return sideInputCache.get(id, fetchCallable).getValue();
- }
- } catch (Exception e) {
- LOG.error("Fetch failed: ", e);
- throw new RuntimeException("Exception while fetching side input: ", e);
- }
- }
-
- /** Struct representing a side input for a particular window. */
- static class SideInputId {
- private final TupleTag<?> tag;
- private final BoundedWindow window;
-
- public SideInputId(TupleTag<?> tag, BoundedWindow window) {
- this.tag = tag;
- this.window = window;
- }
-
- @Override
- public boolean equals(@Nullable Object other) {
- if (other instanceof SideInputId) {
- SideInputId otherId = (SideInputId) other;
- return tag.equals(otherId.tag) && window.equals(otherId.window);
- }
- return false;
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(tag, window);
- }
- }
-
- /**
- * Entry in the side input cache that stores the value (null if not ready),
and the encoded size
- * of the value.
- */
- static class SideInputCacheEntry {
- private final boolean ready;
- private final Object value;
- private final int encodedSize;
-
- private SideInputCacheEntry(boolean ready, Object value, int encodedSize) {
- this.ready = ready;
- this.value = value;
- this.encodedSize = encodedSize;
- }
-
- public static SideInputCacheEntry ready(Object value, int encodedSize) {
- return new SideInputCacheEntry(true, value, encodedSize);
- }
-
- public static SideInputCacheEntry notReady() {
- return new SideInputCacheEntry(false, null, 0);
- }
-
- public boolean isReady() {
- return ready;
- }
-
- /**
- * Returns {@literal null} if the side input was not ready, {@literal
Optional.absent()} if the
- * side input was null, and {@literal Optional.present(...)} if the side
input was non-null.
- */
- public @Nullable <T> Optional<T> getValue() {
- @SuppressWarnings("unchecked")
- T typed = (T) value;
- return ready ? Optional.fromNullable(typed) : null;
- }
-
- public int size() {
- return encodedSize;
- }
- }
-}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index 4c1693d6138..77f5205cf7e 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -94,6 +94,7 @@ import
org.apache.beam.runners.dataflow.worker.streaming.StageInfo;
import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue;
import org.apache.beam.runners.dataflow.worker.streaming.Work;
import org.apache.beam.runners.dataflow.worker.streaming.Work.State;
+import
org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher;
import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor;
import
org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter;
@@ -228,7 +229,7 @@ public class StreamingDataflowWorker {
private final Thread commitThread;
private final AtomicLong activeCommitBytes = new AtomicLong();
private final AtomicBoolean running = new AtomicBoolean();
- private final StateFetcher stateFetcher;
+ private final SideInputStateFetcher sideInputStateFetcher;
private final StreamingDataflowWorkerOptions options;
private final boolean windmillServiceEnabled;
private final long clientId;
@@ -406,7 +407,7 @@ public class StreamingDataflowWorker {
this.metricTrackingWindmillServer =
new MetricTrackingWindmillServerStub(windmillServer, memoryMonitor,
windmillServiceEnabled);
this.metricTrackingWindmillServer.start();
- this.stateFetcher = new StateFetcher(metricTrackingWindmillServer);
+ this.sideInputStateFetcher = new
SideInputStateFetcher(metricTrackingWindmillServer);
this.clientId = clientIdGenerator.nextLong();
for (MapTask mapTask : mapTasks) {
@@ -1078,7 +1079,7 @@ public class StreamingDataflowWorker {
}
};
});
- StateFetcher localStateFetcher = stateFetcher.byteTrackingView();
+ SideInputStateFetcher localSideInputStateFetcher =
sideInputStateFetcher.byteTrackingView();
// If the read output KVs, then we can decode Windmill's byte key into a
userland
// key object and provide it to the execution context for use with
per-key state.
@@ -1114,7 +1115,7 @@ public class StreamingDataflowWorker {
outputDataWatermark,
synchronizedProcessingTime,
stateReader,
- localStateFetcher,
+ localSideInputStateFetcher,
outputBuilder);
// Blocks while executing work.
@@ -1184,7 +1185,7 @@ public class StreamingDataflowWorker {
shuffleBytesRead += message.getSerializedSize();
}
}
- long stateBytesRead = stateReader.getBytesRead() +
localStateFetcher.getBytesRead();
+ long stateBytesRead = stateReader.getBytesRead() +
localSideInputStateFetcher.getBytesRead();
windmillShuffleBytesRead.addValue(shuffleBytesRead);
windmillStateBytesRead.addValue(stateBytesRead);
windmillStateBytesWritten.addValue(stateBytesWritten);
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
index c8fa6e6dfb7..d630601c28a 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
@@ -30,6 +30,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NavigableSet;
+import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicLong;
@@ -45,6 +46,9 @@ import
org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.Ste
import org.apache.beam.runners.dataflow.worker.counters.CounterFactory;
import org.apache.beam.runners.dataflow.worker.counters.NameContext;
import
org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.ProfileScope;
+import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInput;
+import
org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputState;
+import
org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataId;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
@@ -62,7 +66,7 @@ import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
-import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Optional;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.FluentIterable;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashBasedTable;
@@ -86,7 +90,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
private static final Logger LOG =
LoggerFactory.getLogger(StreamingModeExecutionContext.class);
private final String computationId;
- private final Map<TupleTag<?>, Map<BoundedWindow, Object>> sideInputCache;
+ private final Map<TupleTag<?>, Map<BoundedWindow, SideInput<?>>>
sideInputCache;
// Per-key cache of active Reader objects in use by this process.
private final ImmutableMap<String, String> stateNameMap;
private final WindmillStateCache.ForComputation stateCache;
@@ -104,7 +108,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
private Windmill.WorkItem work;
private WindmillComputationKey computationKey;
- private StateFetcher stateFetcher;
+ private SideInputStateFetcher sideInputStateFetcher;
private Windmill.WorkItemCommitRequest.Builder outputBuilder;
private UnboundedSource.UnboundedReader<?> activeReader;
private volatile long backlogBytes;
@@ -145,20 +149,20 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
@Nullable Instant outputDataWatermark,
@Nullable Instant synchronizedProcessingTime,
WindmillStateReader stateReader,
- StateFetcher stateFetcher,
+ SideInputStateFetcher sideInputStateFetcher,
Windmill.WorkItemCommitRequest.Builder outputBuilder) {
this.key = key;
this.work = work;
this.computationKey =
WindmillComputationKey.create(computationId, work.getKey(),
work.getShardingKey());
- this.stateFetcher = stateFetcher;
+ this.sideInputStateFetcher = sideInputStateFetcher;
this.outputBuilder = outputBuilder;
this.sideInputCache.clear();
clearSinkFullHint();
Instant processingTime = Instant.now();
// Ensure that the processing time is greater than any fired processing
time
- // timers. Otherwise a trigger could ignore the timer and orphan the
window.
+ // timers. Otherwise, a trigger could ignore the timer and orphan the
window.
for (Windmill.Timer timer : work.getTimers().getTimersList()) {
if (timer.getType() == Windmill.Timer.Type.REALTIME) {
Instant inferredFiringTime =
@@ -208,42 +212,67 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
return StreamingModeSideInputReader.of(views, this);
}
+ @SuppressWarnings("deprecation")
+ private <T> TupleTag<?> getInternalTag(PCollectionView<T> view) {
+ return view.getTagInternal();
+ }
+
/**
* Fetches the requested sideInput, and maintains a view of the cache that
doesn't remove items
* until the active work item is finished.
*
- * <p>If the side input was not ready, throws {@code IllegalStateException}
if the state is
- * {@literal CACHED_IN_WORKITEM} or returns null otherwise.
- *
- * <p>If the side input was ready and null, returns {@literal
Optional.absent()}. If the side
- * input was ready and non-null returns {@literal Optional.present(...)}.
+ * <p>If the side input was not cached, throws {@code IllegalStateException} if the state is
+ * {@literal CACHED_IN_WORK_ITEM}; otherwise fetches it and returns a {@code SideInput<T>}, which
+ * wraps an {@code Optional<T>} value.
*/
- private @Nullable <T> Optional<T> fetchSideInput(
+ private <T> SideInput<T> fetchSideInput(
+ PCollectionView<T> view,
+ BoundedWindow sideInputWindow,
+ @Nullable String stateFamily,
+ SideInputState state,
+ @Nullable Supplier<Closeable> scopedReadStateSupplier) {
+ TupleTag<?> viewInternalTag = getInternalTag(view);
+ Map<BoundedWindow, SideInput<?>> tagCache =
+ sideInputCache.computeIfAbsent(viewInternalTag, k -> new HashMap<>());
+
+ @SuppressWarnings("unchecked")
+ Optional<SideInput<T>> cachedSideInput =
+ Optional.ofNullable((SideInput<T>) tagCache.get(sideInputWindow));
+
+ if (cachedSideInput.isPresent()) {
+ return cachedSideInput.get();
+ }
+
+ if (state == SideInputState.CACHED_IN_WORK_ITEM) {
+ throw new IllegalStateException(
+ "Expected side input to be cached. Tag: " + viewInternalTag.getId());
+ }
+
+ return fetchSideInputFromWindmill(
+ view,
+ sideInputWindow,
+ Preconditions.checkNotNull(stateFamily),
+ state,
+ Preconditions.checkNotNull(scopedReadStateSupplier),
+ tagCache);
+ }
+
+ private <T> SideInput<T> fetchSideInputFromWindmill(
PCollectionView<T> view,
BoundedWindow sideInputWindow,
String stateFamily,
- StateFetcher.SideInputState state,
- Supplier<Closeable> scopedReadStateSupplier) {
- Map<BoundedWindow, Object> tagCache =
- sideInputCache.computeIfAbsent(view.getTagInternal(), k -> new
HashMap<>());
+ SideInputState state,
+ Supplier<Closeable> scopedReadStateSupplier,
+ Map<BoundedWindow, SideInput<?>> tagCache) {
+ SideInput<T> fetched =
+ sideInputStateFetcher.fetchSideInput(
+ view, sideInputWindow, stateFamily, state,
scopedReadStateSupplier);
- if (tagCache.containsKey(sideInputWindow)) {
- @SuppressWarnings("unchecked")
- T typed = (T) tagCache.get(sideInputWindow);
- return Optional.fromNullable(typed);
- } else {
- if (state == StateFetcher.SideInputState.CACHED_IN_WORKITEM) {
- throw new IllegalStateException(
- "Expected side input to be cached. Tag: " +
view.getTagInternal().getId());
- }
- Optional<T> fetched =
- stateFetcher.fetchSideInput(
- view, sideInputWindow, stateFamily, state,
scopedReadStateSupplier);
- if (fetched != null) {
- tagCache.put(sideInputWindow, fetched.orNull());
- }
- return fetched;
+ if (fetched.isReady()) {
+ tagCache.put(sideInputWindow, fetched);
}
+
+ return fetched;
}
public Iterable<Windmill.GlobalDataId> getSideInputNotifications() {
@@ -378,8 +407,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
interface StreamingModeStepContext {
- boolean issueSideInputFetch(
- PCollectionView<?> view, BoundedWindow w, StateFetcher.SideInputState
s);
+ boolean issueSideInputFetch(PCollectionView<?> view, BoundedWindow w,
SideInputState s);
void addBlockingSideInput(Windmill.GlobalDataRequest blocked);
@@ -412,10 +440,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
// 2. The reporting thread calls extractUpdate which reads the current sum
*AND* sets it to 0.
private final AtomicLong totalMillisInState = new AtomicLong();
- // The worker that created this state. Used to report lulls back to the
worker.
- @SuppressWarnings("unused") // Affects a public api
- private final StreamingDataflowWorker worker;
-
+ @SuppressWarnings("unused")
public StreamingModeExecutionState(
NameContext nameContext,
String stateName,
@@ -424,7 +449,6 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
StreamingDataflowWorker worker) {
// TODO: Take in the requesting step name and side input index for
streaming.
super(nameContext, stateName, null, null, metricsContainer,
profileScope);
- this.worker = worker;
}
/**
@@ -513,8 +537,7 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
}
@Override
- public boolean issueSideInputFetch(
- PCollectionView<?> view, BoundedWindow w, StateFetcher.SideInputState
s) {
+ public boolean issueSideInputFetch(PCollectionView<?> view, BoundedWindow
w, SideInputState s) {
return wrapped.issueSideInputFetch(view, w, s);
}
@@ -609,9 +632,10 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
view,
window,
null /* unused stateFamily */,
- StateFetcher.SideInputState.CACHED_IN_WORKITEM,
+ SideInputState.CACHED_IN_WORK_ITEM,
null /* unused readStateSupplier */)
- .orNull();
+ .value()
+ .orElse(null);
}
@Override
@@ -883,10 +907,10 @@ public class StreamingModeExecutionContext extends
DataflowExecutionContext<Step
/** Fetch the given side input asynchronously and return true if it is
present. */
@Override
public boolean issueSideInputFetch(
- PCollectionView<?> view, BoundedWindow mainInputWindow,
StateFetcher.SideInputState state) {
+ PCollectionView<?> view, BoundedWindow mainInputWindow, SideInputState
state) {
BoundedWindow sideInputWindow =
view.getWindowMappingFn().getSideInputWindow(mainInputWindow);
return fetchSideInput(view, sideInputWindow, stateFamily, state,
scopedReadStateSupplier)
- != null;
+ .isReady();
}
/** Note that there is data on the current key that is blocked on the
given side input. */
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcher.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcher.java
index 2b551acd2d8..4f585e1c01b 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcher.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcher.java
@@ -33,6 +33,7 @@ import org.apache.beam.runners.core.StateTags;
import org.apache.beam.runners.core.TimerInternals.TimerData;
import org.apache.beam.runners.core.TimerInternals.TimerDataCoder;
import org.apache.beam.runners.core.TimerInternals.TimerDataCoderV2;
+import
org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputState;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
import org.apache.beam.sdk.coders.AtomicCoder;
@@ -135,8 +136,7 @@ public class StreamingSideInputFetcher<InputT, W extends
BoundedWindow> {
W window = entry.getKey();
boolean allSideInputsCached = true;
for (PCollectionView<?> view : sideInputViews.values()) {
- if (!stepContext.issueSideInputFetch(
- view, window, StateFetcher.SideInputState.KNOWN_READY)) {
+ if (!stepContext.issueSideInputFetch(view, window,
SideInputState.KNOWN_READY)) {
Windmill.GlobalDataRequest request =
buildGlobalDataRequest(view, window);
stepContext.addBlockingSideInput(request);
windowBlockedSet.add(request);
@@ -192,7 +192,7 @@ public class StreamingSideInputFetcher<InputT, W extends
BoundedWindow> {
Set<Windmill.GlobalDataRequest> blocked = blockedMap().get(window);
if (blocked == null) {
for (PCollectionView<?> view : sideInputViews.values()) {
- if (!stepContext.issueSideInputFetch(view, window,
StateFetcher.SideInputState.UNKNOWN)) {
+ if (!stepContext.issueSideInputFetch(view, window,
SideInputState.UNKNOWN)) {
if (blocked == null) {
blocked = new HashSet<>();
blockedMap().put(window, blocked);
@@ -222,7 +222,7 @@ public class StreamingSideInputFetcher<InputT, W extends
BoundedWindow> {
boolean blocked = false;
for (PCollectionView<?> view : sideInputViews.values()) {
- if (!stepContext.issueSideInputFetch(view, window,
StateFetcher.SideInputState.UNKNOWN)) {
+ if (!stepContext.issueSideInputFetch(view, window,
SideInputState.UNKNOWN)) {
blocked = true;
}
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInput.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInput.java
new file mode 100644
index 00000000000..04eecadc1e5
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInput.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming.sideinput;
+
+import com.google.auto.value.AutoValue;
+import java.util.Optional;
+import javax.annotation.Nullable;
+
+/**
+ * Entry in the side input cache that stores the value and the encoded size of
the value.
+ *
+ * <p>Can be in 1 of 3 states:
+ *
+ * <ul>
+ * <li>Ready with a {@code T} value.
+ * <li>Ready with no value, represented as an empty {@link Optional}.
+ * <li>Not ready.
+ * </ul>
+ */
+@AutoValue
+public abstract class SideInput<T> {
+ static <T> SideInput<T> ready(@Nullable T value, int encodedSize) {
+ return new AutoValue_SideInput<>(true, Optional.ofNullable(value),
encodedSize);
+ }
+
+ static <T> SideInput<T> notReady() {
+ return new AutoValue_SideInput<>(false, Optional.empty(), 0);
+ }
+
+ public abstract boolean isReady();
+
+ public abstract Optional<T> value();
+
+ public abstract int size();
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputCache.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputCache.java
new file mode 100644
index 00000000000..721c477435e
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputCache.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming.sideinput;
+
+import com.google.auto.value.AutoValue;
+import com.google.errorprone.annotations.CheckReturnValue;
+import java.util.Optional;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Weigher;
+
+/**
+ * Wrapper around {@code Cache<Key<?>, SideInput<?>>} that mostly delegates to the underlying
+ * cache, but adds thread-safe functionality to invalidate and load entries that are not ready.
+ *
+ * @implNote Returned values are explicitly cast, because the {@link
#sideInputCache} holds wildcard
+ * types of all objects.
+ */
+@CheckReturnValue
+final class SideInputCache {
+
+ private static final long MAXIMUM_CACHE_WEIGHT = 100000000; /* 100 MB */
+ private static final long CACHE_ENTRY_EXPIRY_MINUTES = 1L;
+
+ private final Cache<Key<?>, SideInput<?>> sideInputCache;
+
+ SideInputCache(Cache<Key<?>, SideInput<?>> sideInputCache) {
+ this.sideInputCache = sideInputCache;
+ }
+
+ static SideInputCache create() {
+ return new SideInputCache(
+ CacheBuilder.newBuilder()
+ .maximumWeight(MAXIMUM_CACHE_WEIGHT)
+ .expireAfterWrite(CACHE_ENTRY_EXPIRY_MINUTES, TimeUnit.MINUTES)
+ .weigher((Weigher<Key<?>, SideInput<?>>) (id, entry) ->
entry.size())
+ .build());
+ }
+
+ synchronized <T> SideInput<T> invalidateThenLoadNewEntry(
+ Key<T> key, Callable<SideInput<T>> cacheLoaderFn) throws
ExecutionException {
+ // Invalidate the existing not-ready entry. This must be done atomically
+ // so that another thread doesn't replace the entry with a ready entry,
which
+ // would then be deleted here.
+ Optional<SideInput<T>> newEntry = getIfPresentUnchecked(key);
+ if (newEntry.isPresent() && !newEntry.get().isReady()) {
+ sideInputCache.invalidate(key);
+ }
+
+ return getUnchecked(key, cacheLoaderFn);
+ }
+
+ <T> Optional<SideInput<T>> get(Key<T> key) {
+ return getIfPresentUnchecked(key);
+ }
+
+ <T> SideInput<T> getOrLoad(Key<T> key, Callable<SideInput<T>> cacheLoaderFn)
+ throws ExecutionException {
+ return getUnchecked(key, cacheLoaderFn);
+ }
+
+ @SuppressWarnings({
+ "unchecked" // cacheLoaderFn loads SideInput<T>, and key is of type T, so
value for Key is
+ // always SideInput<T>.
+ })
+ private <T> SideInput<T> getUnchecked(Key<T> key, Callable<SideInput<T>>
cacheLoaderFn)
+ throws ExecutionException {
+ return (SideInput<T>) sideInputCache.get(key, cacheLoaderFn);
+ }
+
+ @SuppressWarnings({
+ "unchecked" // cacheLoaderFn loads SideInput<T>, and key is of type T, so
value for Key is
+ // always SideInput<T>.
+ })
+ private <T> Optional<SideInput<T>> getIfPresentUnchecked(Key<T> key) {
+ return Optional.ofNullable((SideInput<T>)
sideInputCache.getIfPresent(key));
+ }
+
+ @AutoValue
+ abstract static class Key<T> {
+ static <T> Key<T> create(
+ TupleTag<?> tag, BoundedWindow window, TypeDescriptor<T>
typeDescriptor) {
+ return new AutoValue_SideInputCache_Key<>(tag, window, typeDescriptor);
+ }
+
+ abstract TupleTag<?> tag();
+
+ abstract BoundedWindow window();
+
+ abstract TypeDescriptor<T> typeDescriptor();
+ }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputState.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputState.java
new file mode 100644
index 00000000000..d7af10d29e1
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputState.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming.sideinput;
+
+/** Indicates the caller's knowledge of whether a particular side input has been computed. */
+public enum SideInputState {
+ // NOTE(review): presumably the side input was already cached while handling the current work
+ // item; confirm against the callers that pass this value.
+ CACHED_IN_WORK_ITEM,
+ // The caller knows the side input has been computed.
+ KNOWN_READY,
+ // The caller does not know whether the side input has been computed.
+ UNKNOWN
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java
new file mode 100644
index 00000000000..aa61c421935
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming.sideinput;
+
+import static
org.apache.beam.sdk.transforms.Materializations.ITERABLE_MATERIALIZATION_URN;
+import static
org.apache.beam.sdk.transforms.Materializations.MULTIMAP_MATERIALIZATION_URN;
+import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.Callable;
+import javax.annotation.concurrent.NotThreadSafe;
+import org.apache.beam.runners.core.InMemoryMultimapSideInputView;
+import
org.apache.beam.runners.dataflow.worker.MetricTrackingWindmillServerStub;
+import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.IterableCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.transforms.Materializations.IterableView;
+import org.apache.beam.sdk.transforms.Materializations.MultimapView;
+import org.apache.beam.sdk.transforms.ViewFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.ByteStringOutputStream;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Class responsible for fetching state from the windmill server.
+ *
+ * <p>Side inputs are requested from the server as {@code GlobalData}, decoded, and kept in a
+ * {@link SideInputCache}. The number of serialized bytes fetched by this instance is tracked in
+ * {@link #getBytesRead()}.
+ */
+@NotThreadSafe
+public class SideInputStateFetcher {
+ private static final Logger LOG =
LoggerFactory.getLogger(SideInputStateFetcher.class);
+
+ // Materialization URNs that createSideInputCacheEntry knows how to turn into a view value.
+ private static final Set<String> SUPPORTED_MATERIALIZATIONS =
+ ImmutableSet.of(ITERABLE_MATERIALIZATION_URN,
MULTIMAP_MATERIALIZATION_URN);
+
+ private final SideInputCache sideInputCache;
+ private final MetricTrackingWindmillServerStub server;
+ // Sum of GlobalData#getSerializedSize() over the responses fetched through this instance.
+ private long bytesRead = 0L;
+
+ /** Creates a fetcher backed by a new cache from {@link SideInputCache#create()}. */
+ public SideInputStateFetcher(MetricTrackingWindmillServerStub server) {
+ this(server, SideInputCache.create());
+ }
+
+ /** Creates a fetcher that uses the given cache. */
+ SideInputStateFetcher(MetricTrackingWindmillServerStub server,
SideInputCache sideInputCache) {
+ this.server = server;
+ this.sideInputCache = sideInputCache;
+ }
+
+ /**
+ * Decodes the payload of {@code data} as an iterable of elements using the view's internal
+ * coder. An empty payload yields an empty list without invoking the coder.
+ */
+ private static <T> Iterable<?> decodeRawData(PCollectionView<T> view,
GlobalData data)
+ throws IOException {
+ return !data.getData().isEmpty()
+ ? IterableCoder.of(getCoder(view)).decode(data.getData().newInput())
+ : Collections.emptyList();
+ }
+
+ @SuppressWarnings({
+ "deprecation" // Required as part of the SideInputCache.Key, and not exposed otherwise.
+ })
+ private static <T> TupleTag<?> getInternalTag(PCollectionView<T> view) {
+ return view.getTagInternal();
+ }
+
+ // The deprecation warning raised by PCollectionView#getViewFn() is suppressed here so callers
+ // in this class can obtain the ViewFn through a single helper.
+ @SuppressWarnings("deprecation")
+ private static <T> ViewFn<?, T> getViewFn(PCollectionView<T> view) {
+ return view.getViewFn();
+ }
+
+ @SuppressWarnings({
+ "deprecation" // The view's internal coder is required to decode the raw data.
+ })
+ private static <T> Coder<?> getCoder(PCollectionView<T> view) {
+ return view.getCoderInternal();
+ }
+
+ /**
+ * Returns a view of the underlying cache that keeps track of bytes read separately.
+ *
+ * <p>The returned fetcher uses the same server and the same cache as this one, but starts with
+ * its own byte counter at zero.
+ */
+ public SideInputStateFetcher byteTrackingView() {
+ return new SideInputStateFetcher(server, sideInputCache);
+ }
+
+ /** Returns the serialized size, in bytes, of all responses fetched through this instance. */
+ public long getBytesRead() {
+ return bytesRead;
+ }
+
+ /**
+ * Fetch the given side input, storing it in the cache shared by this fetcher and its {@link
+ * #byteTrackingView()}s.
+ *
+ * <p>The cache key is built from the view's internal tag, {@code sideWindow}, and the type
+ * descriptor of the view's {@code ViewFn}.
+ *
+ * <p>If state is KNOWN_READY, attempt to fetch the data regardless of whether a not-ready entry
+ * was cached: a missing entry is loaded, a cached not-ready entry is invalidated and reloaded,
+ * and a cached ready entry is returned as-is. For any other state the cached entry is returned
+ * if present (even when it is not ready) and loaded from Windmill otherwise.
+ *
+ * @return the side input, which may be a not-ready entry if Windmill reported the data as not
+ * ready
+ * @throws RuntimeException wrapping any exception raised while reading the cache or loading
+ * from Windmill; the failure is also logged at error level
+ */
+ public <T> SideInput<T> fetchSideInput(
+ PCollectionView<T> view,
+ BoundedWindow sideWindow,
+ String stateFamily,
+ SideInputState state,
+ Supplier<Closeable> scopedReadStateSupplier) {
+ // Loader handed to the cache; it only runs when the cache needs to load the entry.
+ Callable<SideInput<T>> loadSideInputFromWindmill =
+ () -> loadSideInputFromWindmill(view, sideWindow, stateFamily,
scopedReadStateSupplier);
+ SideInputCache.Key<T> sideInputCacheKey =
+ SideInputCache.Key.create(
+ getInternalTag(view), sideWindow,
getViewFn(view).getTypeDescriptor());
+
+ try {
+ if (state == SideInputState.KNOWN_READY) {
+ Optional<SideInput<T>> existingCacheEntry =
sideInputCache.get(sideInputCacheKey);
+ if (!existingCacheEntry.isPresent()) {
+ return sideInputCache.getOrLoad(sideInputCacheKey,
loadSideInputFromWindmill);
+ }
+
+ // A stale not-ready entry is cached although the caller knows the data is ready: drop it
+ // and fetch again.
+ if (!existingCacheEntry.get().isReady()) {
+ return sideInputCache.invalidateThenLoadNewEntry(
+ sideInputCacheKey, loadSideInputFromWindmill);
+ }
+
+ return existingCacheEntry.get();
+ }
+
+ return sideInputCache.getOrLoad(sideInputCacheKey,
loadSideInputFromWindmill);
+ } catch (Exception e) {
+ LOG.error("Fetch failed: ", e);
+ throw new RuntimeException("Exception while fetching side input: ", e);
+ }
+ }
+
+ /**
+ * Requests the global data for {@code view} in {@code sideWindow} from the Windmill server.
+ *
+ * <p>The request identifies the data by the view's internal tag id and by the window encoded
+ * with the window coder of the view's windowing strategy (sent as the data id's version). The
+ * existence watermark deadline is the trigger's watermark that guarantees firing for the
+ * window, converted with {@code WindmillTimeUtils.harnessToWindmillTimestamp}.
+ *
+ * <p>The {@code Closeable} obtained from {@code scopedReadStateSupplier} is held open for the
+ * duration of the server call and closed afterwards.
+ */
+ private <T, SideWindowT extends BoundedWindow> GlobalData
fetchGlobalDataFromWindmill(
+ PCollectionView<T> view,
+ SideWindowT sideWindow,
+ String stateFamily,
+ Supplier<Closeable> scopedReadStateSupplier)
+ throws IOException {
+ @SuppressWarnings({
+ "deprecation", // Internal windowing strategy is required to fetch side input data from
+ // Windmill.
+ "unchecked" // Internal windowing strategy matches WindowingStrategy<?, SideWindowT>.
+ })
+ WindowingStrategy<?, SideWindowT> sideWindowStrategy =
+ (WindowingStrategy<?, SideWindowT>)
view.getWindowingStrategyInternal();
+
+ Coder<SideWindowT> windowCoder =
sideWindowStrategy.getWindowFn().windowCoder();
+
+ ByteStringOutputStream windowStream = new ByteStringOutputStream();
+ windowCoder.encode(sideWindow, windowStream);
+
+ Windmill.GlobalDataRequest request =
+ Windmill.GlobalDataRequest.newBuilder()
+ .setDataId(
+ Windmill.GlobalDataId.newBuilder()
+ .setTag(getInternalTag(view).getId())
+ .setVersion(windowStream.toByteString())
+ .build())
+ .setStateFamily(stateFamily)
+ .setExistenceWatermarkDeadline(
+ WindmillTimeUtils.harnessToWindmillTimestamp(
+
sideWindowStrategy.getTrigger().getWatermarkThatGuaranteesFiring(sideWindow)))
+ .build();
+
+ try (Closeable ignored = scopedReadStateSupplier.get()) {
+ return server.getSideInputData(request);
+ }
+ }
+
+ /**
+ * Validates the view's materialization, fetches its global data from Windmill, adds the
+ * response's serialized size to {@link #bytesRead}, and returns a ready entry built from the
+ * data if Windmill reports it as ready, or a not-ready entry otherwise.
+ */
+ private <T> SideInput<T> loadSideInputFromWindmill(
+ PCollectionView<T> view,
+ BoundedWindow sideWindow,
+ String stateFamily,
+ Supplier<Closeable> scopedReadStateSupplier)
+ throws IOException {
+ validateViewMaterialization(view);
+ GlobalData data =
+ fetchGlobalDataFromWindmill(view, sideWindow, stateFamily,
scopedReadStateSupplier);
+ bytesRead += data.getSerializedSize();
+ return data.getIsReady() ? createSideInputCacheEntry(view, data) :
SideInput.notReady();
+ }
+
+ /**
+ * Fails with an {@link IllegalStateException} (via {@code checkState}) unless the view's
+ * materialization URN is one of {@link #SUPPORTED_MATERIALIZATIONS}.
+ */
+ private <T> void validateViewMaterialization(PCollectionView<T> view) {
+ String materializationUrn = getViewFn(view).getMaterialization().getUrn();
+ checkState(
+ SUPPORTED_MATERIALIZATIONS.contains(materializationUrn),
+ "Only materialization's of type %s supported, received %s",
+ SUPPORTED_MATERIALIZATIONS,
+ materializationUrn);
+ }
+
+ /**
+ * Builds a ready {@link SideInput} from {@code data} by decoding its payload and applying the
+ * view's {@code ViewFn}, according to the view's materialization URN. The entry's size is the
+ * size of the encoded payload.
+ */
+ private <T> SideInput<T> createSideInputCacheEntry(PCollectionView<T> view,
GlobalData data)
+ throws IOException {
+ Iterable<?> rawData = decodeRawData(view, data);
+ switch (getViewFn(view).getMaterialization().getUrn()) {
+ case ITERABLE_MATERIALIZATION_URN:
+ {
+ @SuppressWarnings({
+ "unchecked", // ITERABLE_MATERIALIZATION_URN has ViewFn<IterableView, T>.
+ "rawtypes" // TODO(https://github.com/apache/beam/issues/20447)
+ })
+ ViewFn<IterableView, T> viewFn = (ViewFn<IterableView, T>)
getViewFn(view);
+ // The lambda serves as the IterableView over the decoded elements.
+ return SideInput.ready(viewFn.apply(() -> rawData),
data.getData().size());
+ }
+ case MULTIMAP_MATERIALIZATION_URN:
+ {
+ @SuppressWarnings({
+ "unchecked", // MULTIMAP_MATERIALIZATION_URN has ViewFn<MultimapView, T>.
+ "rawtypes" // TODO(https://github.com/apache/beam/issues/20447)
+ })
+ ViewFn<MultimapView, T> viewFn = (ViewFn<MultimapView, T>)
getViewFn(view);
+ // For multimap views the view's coder is cast to a KvCoder to obtain the key coder.
+ Coder<?> keyCoder = ((KvCoder<?, ?>) getCoder(view)).getKeyCoder();
+
+ @SuppressWarnings({
+ "unchecked", // Safe since multimap rawData is of type Iterable<KV<K, V>>
+ "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
+ })
+ T multimapSideInputValue =
+ viewFn.apply(
+ InMemoryMultimapSideInputView.fromIterable(keyCoder,
(Iterable) rawData));
+ return SideInput.ready(multimapSideInputValue,
data.getData().size());
+ }
+ default:
+ {
+ // loadSideInputFromWindmill validates the materialization before calling this method,
+ // so this branch is a safeguard for unsupported URNs.
+ throw new IllegalStateException(
+ "Unknown side input materialization format requested: "
+ + getViewFn(view).getMaterialization().getUrn());
+ }
+ }
+ }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index c4042e37c3b..fdec36d688e 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -281,11 +281,7 @@ public class StreamingDataflowWorkerTest {
}
static Work createMockWork(long workToken) {
- return Work.create(
-
Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(workToken).build(),
- Instant::now,
- Collections.emptyList(),
- work -> {});
+ return createMockWork(workToken, work -> {});
}
static Work createMockWork(long workToken, Consumer<Work> processWorkFn) {
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
index 6620dbdaab7..9991520d593 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
@@ -55,6 +55,7 @@ import
org.apache.beam.runners.dataflow.worker.counters.CounterSet;
import org.apache.beam.runners.dataflow.worker.counters.NameContext;
import
org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.NoopProfileScope;
import
org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.ProfileScope;
+import
org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
import
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader;
@@ -83,10 +84,10 @@ import org.mockito.MockitoAnnotations;
@RunWith(JUnit4.class)
public class StreamingModeExecutionContextTest {
- @Mock private StateFetcher stateFetcher;
+ @Mock private SideInputStateFetcher sideInputStateFetcher;
@Mock private WindmillStateReader stateReader;
- private StreamingModeExecutionStateRegistry executionStateRegistry =
+ private final StreamingModeExecutionStateRegistry executionStateRegistry =
new StreamingModeExecutionStateRegistry(null);
private StreamingModeExecutionContext executionContext;
DataflowWorkerHarnessOptions options;
@@ -133,7 +134,7 @@ public class StreamingModeExecutionContextTest {
null, // output watermark
null, // synchronized processing time
stateReader,
- stateFetcher,
+ sideInputStateFetcher,
outputBuilder);
TimerInternals timerInternals = stepContext.timerInternals();
@@ -183,7 +184,7 @@ public class StreamingModeExecutionContextTest {
null, // output watermark
null, // synchronized processing time
stateReader,
- stateFetcher,
+ sideInputStateFetcher,
outputBuilder);
TimerInternals timerInternals = stepContext.timerInternals();
assertTrue(timerTimestamp.isBefore(timerInternals.currentProcessingTime()));
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunnerTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunnerTest.java
index 05e0ff41761..3c121ab27f7 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunnerTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunnerTest.java
@@ -39,7 +39,7 @@ import org.apache.beam.runners.core.InMemoryStateInternals;
import org.apache.beam.runners.core.SideInputReader;
import org.apache.beam.runners.core.StateInternals;
import org.apache.beam.runners.core.StateNamespaces;
-import org.apache.beam.runners.dataflow.worker.StateFetcher.SideInputState;
+import
org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputState;
import org.apache.beam.runners.dataflow.worker.util.ListOutputManager;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcherTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcherTest.java
index 9ce462be321..a7196613fbb 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcherTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcherTest.java
@@ -31,7 +31,7 @@ import org.apache.beam.runners.core.SideInputReader;
import org.apache.beam.runners.core.StateInternals;
import org.apache.beam.runners.core.StateNamespaces;
import org.apache.beam.runners.core.TimerInternals.TimerData;
-import org.apache.beam.runners.dataflow.worker.StateFetcher.SideInputState;
+import
org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputState;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.state.BagState;
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StateFetcherTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherTest.java
similarity index 67%
rename from
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StateFetcherTest.java
rename to
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherTest.java
index 13d8a9bd3ff..daf81461879 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StateFetcherTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherTest.java
@@ -15,11 +15,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.beam.runners.dataflow.worker;
+package org.apache.beam.runners.dataflow.worker.streaming.sideinput;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -27,10 +29,10 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import java.io.Closeable;
-import java.util.Arrays;
+import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
-import org.apache.beam.runners.dataflow.worker.StateFetcher.SideInputState;
+import
org.apache.beam.runners.dataflow.worker.MetricTrackingWindmillServerStub;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.ListCoder;
@@ -56,14 +58,16 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
-/** Unit tests for {@link StateFetcher}. */
+/** Unit tests for {@link SideInputStateFetcher}. */
+// TODO: Add tests with different encoded windows to verify version is
correctly plumbed.
+@SuppressWarnings("deprecation")
@RunWith(JUnit4.class)
-public class StateFetcherTest {
+public class SideInputStateFetcherTest {
private static final String STATE_FAMILY = "state";
- @Mock MetricTrackingWindmillServerStub server;
+ @Mock private MetricTrackingWindmillServerStub server;
- @Mock Supplier<Closeable> readStateSupplier;
+ @Mock private Supplier<Closeable> readStateSupplier;
@Before
public void setUp() {
@@ -72,10 +76,11 @@ public class StateFetcherTest {
@Test
public void testFetchGlobalDataBasic() throws Exception {
- StateFetcher fetcher = new StateFetcher(server);
+ SideInputStateFetcher fetcher = new SideInputStateFetcher(server);
ByteStringOutputStream stream = new ByteStringOutputStream();
- ListCoder.of(StringUtf8Coder.of()).encode(Arrays.asList("data"), stream,
Coder.Context.OUTER);
+ ListCoder.of(StringUtf8Coder.of())
+ .encode(Collections.singletonList("data"), stream,
Coder.Context.OUTER);
ByteString encodedIterable = stream.toByteString();
PCollectionView<String> view =
@@ -87,17 +92,29 @@ public class StateFetcherTest {
// then the data is already cached.
when(server.getSideInputData(any(Windmill.GlobalDataRequest.class)))
.thenReturn(
- buildGlobalDataResponse(tag, ByteString.EMPTY, false, null),
- buildGlobalDataResponse(tag, ByteString.EMPTY, true,
encodedIterable));
+ buildGlobalDataResponse(tag, false, null),
+ buildGlobalDataResponse(tag, true, encodedIterable));
+
+ assertFalse(
+ fetcher
+ .fetchSideInput(
+ view,
+ GlobalWindow.INSTANCE,
+ STATE_FAMILY,
+ SideInputState.UNKNOWN,
+ readStateSupplier)
+ .isReady());
+
+ assertFalse(
+ fetcher
+ .fetchSideInput(
+ view,
+ GlobalWindow.INSTANCE,
+ STATE_FAMILY,
+ SideInputState.UNKNOWN,
+ readStateSupplier)
+ .isReady());
- assertEquals(
- null,
- fetcher.fetchSideInput(
- view, GlobalWindow.INSTANCE, STATE_FAMILY, SideInputState.UNKNOWN,
readStateSupplier));
- assertEquals(
- null,
- fetcher.fetchSideInput(
- view, GlobalWindow.INSTANCE, STATE_FAMILY, SideInputState.UNKNOWN,
readStateSupplier));
assertEquals(
"data",
fetcher
@@ -107,7 +124,8 @@ public class StateFetcherTest {
STATE_FAMILY,
SideInputState.KNOWN_READY,
readStateSupplier)
- .orNull());
+ .value()
+ .orElse(null));
assertEquals(
"data",
fetcher
@@ -117,18 +135,20 @@ public class StateFetcherTest {
STATE_FAMILY,
SideInputState.KNOWN_READY,
readStateSupplier)
- .orNull());
+ .value()
+ .orElse(null));
- verify(server, times(2)).getSideInputData(buildGlobalDataRequest(tag,
ByteString.EMPTY));
+ verify(server, times(2)).getSideInputData(buildGlobalDataRequest(tag));
verifyNoMoreInteractions(server);
}
@Test
public void testFetchGlobalDataNull() throws Exception {
- StateFetcher fetcher = new StateFetcher(server);
+ SideInputStateFetcher fetcher = new SideInputStateFetcher(server);
ByteStringOutputStream stream = new ByteStringOutputStream();
- ListCoder.of(VoidCoder.of()).encode(Arrays.asList((Void) null), stream,
Coder.Context.OUTER);
+ ListCoder.of(VoidCoder.of())
+ .encode(Collections.singletonList(null), stream, Coder.Context.OUTER);
ByteString encodedIterable = stream.toByteString();
PCollectionView<Void> view =
@@ -140,19 +160,28 @@ public class StateFetcherTest {
// then the data is already cached.
when(server.getSideInputData(any(Windmill.GlobalDataRequest.class)))
.thenReturn(
- buildGlobalDataResponse(tag, ByteString.EMPTY, false, null),
- buildGlobalDataResponse(tag, ByteString.EMPTY, true,
encodedIterable));
+ buildGlobalDataResponse(tag, false, null),
+ buildGlobalDataResponse(tag, true, encodedIterable));
- assertEquals(
- null,
- fetcher.fetchSideInput(
- view, GlobalWindow.INSTANCE, STATE_FAMILY, SideInputState.UNKNOWN,
readStateSupplier));
- assertEquals(
- null,
- fetcher.fetchSideInput(
- view, GlobalWindow.INSTANCE, STATE_FAMILY, SideInputState.UNKNOWN,
readStateSupplier));
- assertEquals(
- null,
+ assertFalse(
+ fetcher
+ .fetchSideInput(
+ view,
+ GlobalWindow.INSTANCE,
+ STATE_FAMILY,
+ SideInputState.UNKNOWN,
+ readStateSupplier)
+ .isReady());
+ assertFalse(
+ fetcher
+ .fetchSideInput(
+ view,
+ GlobalWindow.INSTANCE,
+ STATE_FAMILY,
+ SideInputState.UNKNOWN,
+ readStateSupplier)
+ .isReady());
+ assertNull(
fetcher
.fetchSideInput(
view,
@@ -160,9 +189,9 @@ public class StateFetcherTest {
STATE_FAMILY,
SideInputState.KNOWN_READY,
readStateSupplier)
- .orNull());
- assertEquals(
- null,
+ .value()
+ .orElse(null));
+ assertNull(
fetcher
.fetchSideInput(
view,
@@ -170,9 +199,10 @@ public class StateFetcherTest {
STATE_FAMILY,
SideInputState.KNOWN_READY,
readStateSupplier)
- .orNull());
+ .value()
+ .orElse(null));
- verify(server, times(2)).getSideInputData(buildGlobalDataRequest(tag,
ByteString.EMPTY));
+ verify(server, times(2)).getSideInputData(buildGlobalDataRequest(tag));
verifyNoMoreInteractions(server);
}
@@ -181,15 +211,14 @@ public class StateFetcherTest {
Coder<List<String>> coder = ListCoder.of(StringUtf8Coder.of());
ByteStringOutputStream stream = new ByteStringOutputStream();
- coder.encode(Arrays.asList("data1"), stream, Coder.Context.OUTER);
+ coder.encode(Collections.singletonList("data1"), stream,
Coder.Context.OUTER);
ByteString encodedIterable1 = stream.toByteStringAndReset();
- coder.encode(Arrays.asList("data2"), stream, Coder.Context.OUTER);
+ coder.encode(Collections.singletonList("data2"), stream,
Coder.Context.OUTER);
ByteString encodedIterable2 = stream.toByteString();
- Cache<StateFetcher.SideInputId, StateFetcher.SideInputCacheEntry> cache =
- CacheBuilder.newBuilder().build();
+ Cache<SideInputCache.Key<?>, SideInput<?>> cache =
CacheBuilder.newBuilder().build();
- StateFetcher fetcher = new StateFetcher(server, cache);
+ SideInputStateFetcher fetcher = new SideInputStateFetcher(server, new
SideInputCache(cache));
PCollectionView<String> view1 =
TestPipeline.create().apply(Create.empty(StringUtf8Coder.of())).apply(View.asSingleton());
@@ -204,9 +233,9 @@ public class StateFetcherTest {
// then view 1 again twice.
when(server.getSideInputData(any(Windmill.GlobalDataRequest.class)))
.thenReturn(
- buildGlobalDataResponse(tag1, ByteString.EMPTY, true,
encodedIterable1),
- buildGlobalDataResponse(tag2, ByteString.EMPTY, true,
encodedIterable2),
- buildGlobalDataResponse(tag1, ByteString.EMPTY, true,
encodedIterable1));
+ buildGlobalDataResponse(tag1, true, encodedIterable1),
+ buildGlobalDataResponse(tag2, true, encodedIterable2),
+ buildGlobalDataResponse(tag1, true, encodedIterable1));
assertEquals(
"data1",
@@ -217,7 +246,8 @@ public class StateFetcherTest {
STATE_FAMILY,
SideInputState.UNKNOWN,
readStateSupplier)
- .orNull());
+ .value()
+ .orElse(null));
assertEquals(
"data2",
fetcher
@@ -227,7 +257,8 @@ public class StateFetcherTest {
STATE_FAMILY,
SideInputState.UNKNOWN,
readStateSupplier)
- .orNull());
+ .value()
+ .orElse(null));
cache.invalidateAll();
assertEquals(
"data1",
@@ -238,7 +269,8 @@ public class StateFetcherTest {
STATE_FAMILY,
SideInputState.UNKNOWN,
readStateSupplier)
- .orNull());
+ .value()
+ .orElse(null));
assertEquals(
"data1",
fetcher
@@ -248,7 +280,8 @@ public class StateFetcherTest {
STATE_FAMILY,
SideInputState.UNKNOWN,
readStateSupplier)
- .orNull());
+ .value()
+ .orElse(null));
ArgumentCaptor<Windmill.GlobalDataRequest> captor =
ArgumentCaptor.forClass(Windmill.GlobalDataRequest.class);
@@ -259,14 +292,14 @@ public class StateFetcherTest {
assertThat(
captor.getAllValues(),
contains(
- buildGlobalDataRequest(tag1, ByteString.EMPTY),
- buildGlobalDataRequest(tag2, ByteString.EMPTY),
- buildGlobalDataRequest(tag1, ByteString.EMPTY)));
+ buildGlobalDataRequest(tag1),
+ buildGlobalDataRequest(tag2),
+ buildGlobalDataRequest(tag1)));
}
@Test
public void testEmptyFetchGlobalData() throws Exception {
- StateFetcher fetcher = new StateFetcher(server);
+ SideInputStateFetcher fetcher = new SideInputStateFetcher(server);
ByteString encodedIterable = ByteString.EMPTY;
@@ -280,7 +313,7 @@ public class StateFetcherTest {
// Test three calls in a row. First, data is not ready, then data is ready,
// then the data is already cached.
when(server.getSideInputData(any(Windmill.GlobalDataRequest.class)))
- .thenReturn(buildGlobalDataResponse(tag, ByteString.EMPTY, true,
encodedIterable));
+ .thenReturn(buildGlobalDataResponse(tag, true, encodedIterable));
assertEquals(
0L,
@@ -292,17 +325,22 @@ public class StateFetcherTest {
STATE_FAMILY,
SideInputState.UNKNOWN,
readStateSupplier)
- .orNull());
+ .value()
+ .orElse(null));
- verify(server).getSideInputData(buildGlobalDataRequest(tag,
ByteString.EMPTY));
+ verify(server).getSideInputData(buildGlobalDataRequest(tag));
verifyNoMoreInteractions(server);
}
- private Windmill.GlobalData buildGlobalDataResponse(
- String tag, ByteString version, boolean isReady, ByteString data) {
+ private static Windmill.GlobalData buildGlobalDataResponse(
+ String tag, boolean isReady, ByteString data) {
Windmill.GlobalData.Builder builder =
Windmill.GlobalData.newBuilder()
-
.setDataId(Windmill.GlobalDataId.newBuilder().setTag(tag).setVersion(version).build());
+ .setDataId(
+ Windmill.GlobalDataId.newBuilder()
+ .setTag(tag)
+ .setVersion(ByteString.EMPTY)
+ .build());
if (isReady) {
builder.setIsReady(true).setData(data);
@@ -312,7 +350,7 @@ public class StateFetcherTest {
return builder.build();
}
- private Windmill.GlobalDataRequest buildGlobalDataRequest(String tag,
ByteString version) {
+ private static Windmill.GlobalDataRequest buildGlobalDataRequest(String tag,
ByteString version) {
Windmill.GlobalDataId id =
Windmill.GlobalDataId.newBuilder().setTag(tag).setVersion(version).build();
@@ -323,4 +361,8 @@ public class StateFetcherTest {
TimeUnit.MILLISECONDS.toMicros(GlobalWindow.INSTANCE.maxTimestamp().getMillis()))
.build();
}
+
+ private static Windmill.GlobalDataRequest buildGlobalDataRequest(String tag)
{
+ return buildGlobalDataRequest(tag, ByteString.EMPTY);
+ }
}