This is an automated email from the ASF dual-hosted git repository.
scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new c08afeae60d Enable MapState and SetState for dataflow streaming engine
pipelines with legacy runner by building on top of MultimapState. (#31453)
c08afeae60d is described below
commit c08afeae60dfb1a15a0f4c8669085662a847249f
Author: Sam Whittle <[email protected]>
AuthorDate: Thu Jul 4 22:22:21 2024 +0200
Enable MapState and SetState for dataflow streaming engine pipelines with
legacy runner by building on top of MultimapState. (#31453)
---
CHANGES.md | 1 +
.../org/apache/beam/runners/core/StateTags.java | 8 +
.../beam/runners/dataflow/DataflowRunner.java | 35 +---
.../beam/runners/dataflow/DataflowRunnerTest.java | 59 ------
.../dataflow/worker/StreamingDataflowWorker.java | 11 +-
.../worker/windmill/state/AbstractWindmillMap.java | 23 +++
.../worker/windmill/state/CachingStateTable.java | 53 +++--
.../worker/windmill/state/WindmillMap.java | 24 +--
.../windmill/state/WindmillMapViaMultimap.java | 164 +++++++++++++++
.../worker/windmill/state/WindmillMultimap.java | 4 +-
.../worker/windmill/state/WindmillSet.java | 36 +---
.../worker/windmill/state/WindmillStateCache.java | 46 +++--
.../windmill/state/WindmillStateInternals.java | 14 +-
.../worker/StreamingModeExecutionContextTest.java | 5 +-
.../dataflow/worker/WindmillStateTestUtils.java | 2 +-
.../dataflow/worker/WorkerCustomSourcesTest.java | 5 +-
.../windmill/state/WindmillStateCacheTest.java | 2 +-
.../windmill/state/WindmillStateInternalsTest.java | 225 ++++++++++++++++++++-
.../refresh/DispatchedActiveWorkRefresherTest.java | 2 +-
.../java/org/apache/beam/sdk/state/StateSpecs.java | 23 +++
.../org/apache/beam/sdk/transforms/ParDoTest.java | 28 ++-
21 files changed, 573 insertions(+), 197 deletions(-)
diff --git a/CHANGES.md b/CHANGES.md
index 38fa6e44b73..0a620038f11 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -68,6 +68,7 @@
* Multiple RunInference instances can now share the same model instance by
setting the model_identifier parameter (Python)
([#31665](https://github.com/apache/beam/issues/31665)).
* Removed a 3rd party LGPL dependency from the Go SDK
([#31765](https://github.com/apache/beam/issues/31765)).
+* Support for MapState and SetState when using Dataflow Runner v1 with
Streaming Engine (Java)
([#18200](https://github.com/apache/beam/issues/18200)).
## Breaking Changes
diff --git
a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java
b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java
index 7ffb10c85c0..6ed7f8525fd 100644
---
a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java
+++
b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java
@@ -257,6 +257,18 @@ public class StateTags {
new StructuredId(setTag.getId()),
StateSpecs.convertToMapSpecInternal(setTag.getSpec()));
}
+ /**
+ * Converts a {@link MapState} tag into a {@link MultimapState} tag with the same id, deriving
+ * the multimap spec via {@link StateSpecs#convertToMultimapSpecInternal}.
+ */
+ public static <KeyT, ValueT> StateTag<MultimapState<KeyT, ValueT>>
convertToMultiMapTagInternal(
+ StateTag<MapState<KeyT, ValueT>> mapTag) {
+ StateSpec<MapState<KeyT, ValueT>> spec = mapTag.getSpec();
+ StateSpec<MultimapState<KeyT, ValueT>> multimapSpec =
+ StateSpecs.convertToMultimapSpecInternal(spec);
+ return new SimpleStateTag<>(new StructuredId(mapTag.getId()),
multimapSpec);
+ }
+
private static class StructuredId implements Serializable {
private final StateKind kind;
private final String rawId;
diff --git
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index de566599bf8..708c6341326 100644
---
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -2564,11 +2564,6 @@ public class DataflowRunner extends
PipelineRunner<DataflowPipelineJob> {
|| hasExperiment(options, "use_portable_job_submission");
}
- static boolean useStreamingEngine(DataflowPipelineOptions options) {
- return hasExperiment(options, GcpOptions.STREAMING_ENGINE_EXPERIMENT)
- || hasExperiment(options, GcpOptions.WINDMILL_SERVICE_EXPERIMENT);
- }
-
static void verifyDoFnSupported(
DoFn<?, ?> fn, boolean streaming, DataflowPipelineOptions options) {
if (!streaming && DoFnSignatures.usesMultimapState(fn)) {
@@ -2583,8 +2578,6 @@ public class DataflowRunner extends
PipelineRunner<DataflowPipelineJob> {
"%s does not currently support @RequiresTimeSortedInput in
streaming mode.",
DataflowRunner.class.getSimpleName()));
}
-
- boolean streamingEngine = useStreamingEngine(options);
boolean isUnifiedWorker = useUnifiedWorker(options);
if (DoFnSignatures.usesMultimapState(fn) && isUnifiedWorker) {
@@ -2593,25 +2586,17 @@ public class DataflowRunner extends
PipelineRunner<DataflowPipelineJob> {
"%s does not currently support %s running using streaming on
unified worker",
DataflowRunner.class.getSimpleName(),
MultimapState.class.getSimpleName()));
}
- if (DoFnSignatures.usesSetState(fn)) {
- if (streaming && (isUnifiedWorker || streamingEngine)) {
- throw new UnsupportedOperationException(
- String.format(
- "%s does not currently support %s when using %s",
- DataflowRunner.class.getSimpleName(),
- SetState.class.getSimpleName(),
- isUnifiedWorker ? "streaming on unified worker" : "streaming
engine"));
- }
+ if (DoFnSignatures.usesSetState(fn) && streaming && isUnifiedWorker) {
+ throw new UnsupportedOperationException(
+ String.format(
+ "%s does not currently support %s when using streaming on
unified worker",
+ DataflowRunner.class.getSimpleName(),
SetState.class.getSimpleName()));
}
- if (DoFnSignatures.usesMapState(fn)) {
- if (streaming && (isUnifiedWorker || streamingEngine)) {
- throw new UnsupportedOperationException(
- String.format(
- "%s does not currently support %s when using %s",
- DataflowRunner.class.getSimpleName(),
- MapState.class.getSimpleName(),
- isUnifiedWorker ? "streaming on unified worker" : "streaming
engine"));
- }
+ if (DoFnSignatures.usesMapState(fn) && streaming && isUnifiedWorker) {
+ throw new UnsupportedOperationException(
+ String.format(
+ "%s does not currently support %s when using streaming on
unified worker",
+ DataflowRunner.class.getSimpleName(),
MapState.class.getSimpleName()));
}
if (DoFnSignatures.usesBundleFinalizer(fn) && !isUnifiedWorker) {
throw new UnsupportedOperationException(
diff --git
a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
index 55bfc44ee62..cf1066e41d2 100644
---
a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
+++
b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
@@ -131,8 +131,6 @@ import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.runners.TransformHierarchy.Node;
-import org.apache.beam.sdk.state.MapState;
-import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.ValueState;
@@ -1880,63 +1878,6 @@ public class DataflowRunnerTest implements Serializable {
}
}
- private void verifyMapStateUnsupported(PipelineOptions options) throws
Exception {
- Pipeline p = Pipeline.create(options);
- p.apply(Create.of(KV.of(13, 42)))
- .apply(
- ParDo.of(
- new DoFn<KV<Integer, Integer>, Void>() {
-
- @StateId("fizzle")
- private final StateSpec<MapState<Void, Void>> voidState =
StateSpecs.map();
-
- @ProcessElement
- public void process() {}
- }));
-
- thrown.expectMessage("MapState");
- thrown.expect(UnsupportedOperationException.class);
- p.run();
- }
-
- @Test
- public void testMapStateUnsupportedStreamingEngine() throws Exception {
- PipelineOptions options = buildPipelineOptions();
- ExperimentalOptions.addExperiment(
- options.as(ExperimentalOptions.class),
GcpOptions.STREAMING_ENGINE_EXPERIMENT);
- options.as(DataflowPipelineOptions.class).setStreaming(true);
-
- verifyMapStateUnsupported(options);
- }
-
- private void verifySetStateUnsupported(PipelineOptions options) throws
Exception {
- Pipeline p = Pipeline.create(options);
- p.apply(Create.of(KV.of(13, 42)))
- .apply(
- ParDo.of(
- new DoFn<KV<Integer, Integer>, Void>() {
-
- @StateId("fizzle")
- private final StateSpec<SetState<Void>> voidState =
StateSpecs.set();
-
- @ProcessElement
- public void process() {}
- }));
-
- thrown.expectMessage("SetState");
- thrown.expect(UnsupportedOperationException.class);
- p.run();
- }
-
- @Test
- public void testSetStateUnsupportedStreamingEngine() throws Exception {
- PipelineOptions options = buildPipelineOptions();
- ExperimentalOptions.addExperiment(
- options.as(ExperimentalOptions.class),
GcpOptions.STREAMING_ENGINE_EXPERIMENT);
- options.as(DataflowPipelineOptions.class).setStreaming(true);
- verifySetStateUnsupported(options);
- }
-
/** Records all the composite transforms visited within the Pipeline. */
private static class CompositeTransformRecorder extends
PipelineVisitor.Defaults {
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index 59819db88a0..0e46e7e4687 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -324,7 +324,10 @@ public class StreamingDataflowWorker {
BoundedQueueExecutor workExecutor = createWorkUnitExecutor(options);
AtomicInteger maxWorkItemCommitBytes = new
AtomicInteger(Integer.MAX_VALUE);
WindmillStateCache windmillStateCache =
- WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb());
+ WindmillStateCache.builder()
+ .setSizeMb(options.getWorkerCacheMb())
+ .setSupportMapViaMultimap(options.isEnableStreamingEngine())
+ .build();
Function<String, ScheduledExecutorService> executorSupplier =
threadName ->
Executors.newSingleThreadScheduledExecutor(
@@ -478,7 +481,11 @@ public class StreamingDataflowWorker {
ConcurrentMap<String, StageInfo> stageInfo = new ConcurrentHashMap<>();
AtomicInteger maxWorkItemCommitBytes = new
AtomicInteger(maxWorkItemCommitBytesOverrides);
BoundedQueueExecutor workExecutor = createWorkUnitExecutor(options);
- WindmillStateCache stateCache =
WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb());
+ WindmillStateCache stateCache =
+ WindmillStateCache.builder()
+ .setSizeMb(options.getWorkerCacheMb())
+ .setSupportMapViaMultimap(options.isEnableStreamingEngine())
+ .build();
ComputationConfig.Fetcher configFetcher =
options.isEnableStreamingEngine()
? StreamingEngineComputationConfigFetcher.forTesting(
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/AbstractWindmillMap.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/AbstractWindmillMap.java
new file mode 100644
index 00000000000..e144d5cf8c3
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/AbstractWindmillMap.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.state;
+
+import org.apache.beam.sdk.state.MapState;
+
+/**
+ * Common base type of the Windmill-backed {@link MapState} implementations ({@link WindmillMap}
+ * and {@link WindmillMapViaMultimap}), so callers such as {@link WindmillSet} can hold either.
+ */
+public abstract class AbstractWindmillMap<K, V> extends SimpleWindmillState
+ implements MapState<K, V> {}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/CachingStateTable.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/CachingStateTable.java
index bcaf8bf21a2..c026aac4f96 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/CachingStateTable.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/CachingStateTable.java
@@ -24,17 +24,9 @@ import org.apache.beam.runners.core.StateNamespace;
import org.apache.beam.runners.core.StateTable;
import org.apache.beam.runners.core.StateTag;
import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.sdk.coders.BooleanCoder;
import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.state.BagState;
-import org.apache.beam.sdk.state.CombiningState;
-import org.apache.beam.sdk.state.MapState;
-import org.apache.beam.sdk.state.MultimapState;
-import org.apache.beam.sdk.state.OrderedListState;
-import org.apache.beam.sdk.state.SetState;
-import org.apache.beam.sdk.state.State;
-import org.apache.beam.sdk.state.StateContext;
-import org.apache.beam.sdk.state.ValueState;
-import org.apache.beam.sdk.state.WatermarkHoldState;
+import org.apache.beam.sdk.state.*;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.CombineWithContext;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
@@ -50,6 +42,7 @@ final class CachingStateTable extends StateTable {
private final Supplier<Closeable> scopedReadStateSupplier;
private final @Nullable StateTable derivedStateTable;
private final boolean isNewKey;
+ private final boolean mapStateViaMultimapState;
private CachingStateTable(Builder builder) {
this.stateFamily = builder.stateFamily;
@@ -59,6 +52,7 @@ final class CachingStateTable extends StateTable {
this.isNewKey = builder.isNewKey;
this.scopedReadStateSupplier = builder.scopedReadStateSupplier;
this.derivedStateTable = builder.derivedStateTable;
+ this.mapStateViaMultimapState = builder.mapStateViaMultimapState;
if (this.isSystemTable) {
Preconditions.checkState(derivedStateTable == null);
@@ -103,30 +97,39 @@ final class CachingStateTable extends StateTable {
@Override
public <T> SetState<T> bindSet(StateTag<SetState<T>> spec, Coder<T>
elemCoder) {
+ StateTag<MapState<T, Boolean>> internalMapAddress =
StateTags.convertToMapTagInternal(spec);
WindmillSet<T> result =
- new WindmillSet<>(namespace, spec, stateFamily, elemCoder, cache,
isNewKey);
+ new WindmillSet<>(bindMap(internalMapAddress, elemCoder,
BooleanCoder.of()));
result.initializeForWorkItem(reader, scopedReadStateSupplier);
return result;
}
@Override
- public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
+ public <KeyT, ValueT> AbstractWindmillMap<KeyT, ValueT> bindMap(
StateTag<MapState<KeyT, ValueT>> spec, Coder<KeyT> keyCoder,
Coder<ValueT> valueCoder) {
- WindmillMap<KeyT, ValueT> result =
- cache
- .get(namespace, spec)
- .map(mapState -> (WindmillMap<KeyT, ValueT>) mapState)
- .orElseGet(
- () ->
- new WindmillMap<>(
- namespace, spec, stateFamily, keyCoder,
valueCoder, isNewKey));
-
+ AbstractWindmillMap<KeyT, ValueT> result;
+ if (mapStateViaMultimapState) {
+ StateTag<MultimapState<KeyT, ValueT>> internalMultimapAddress =
+ StateTags.convertToMultiMapTagInternal(spec);
+ result =
+ new WindmillMapViaMultimap<>(
+ bindMultimap(internalMultimapAddress, keyCoder, valueCoder));
+ } else {
+ result =
+ cache
+ .get(namespace, spec)
+ .map(mapState -> (AbstractWindmillMap<KeyT, ValueT>)
mapState)
+ .orElseGet(
+ () ->
+ new WindmillMap<>(
+ namespace, spec, stateFamily, keyCoder,
valueCoder, isNewKey));
+ }
result.initializeForWorkItem(reader, scopedReadStateSupplier);
return result;
}
@Override
- public <KeyT, ValueT> MultimapState<KeyT, ValueT> bindMultimap(
+ public <KeyT, ValueT> WindmillMultimap<KeyT, ValueT> bindMultimap(
StateTag<MultimapState<KeyT, ValueT>> spec,
Coder<KeyT> keyCoder,
Coder<ValueT> valueCoder) {
@@ -246,6 +249,7 @@ final class CachingStateTable extends StateTable {
private final boolean isNewKey;
private boolean isSystemTable;
private @Nullable StateTable derivedStateTable;
+ private boolean mapStateViaMultimapState = false;
private Builder(
String stateFamily,
@@ -268,6 +272,12 @@ final class CachingStateTable extends StateTable {
return this;
}
+ /**
+ * Makes bindMap (and bindSet, which delegates to it) return WindmillMapViaMultimap instances
+ * backed by multimap state instead of WindmillMap.
+ */
+ Builder withMapStateViaMultimapState() {
+ this.mapStateViaMultimapState = true;
+ return this;
+ }
+
CachingStateTable build() {
return new CachingStateTable(this);
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java
index 9f027af0a87..aed03f33e6d 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java
@@ -21,10 +21,7 @@ import static
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillSta
import java.io.Closeable;
import java.io.IOException;
-import java.util.AbstractMap;
-import java.util.Collections;
-import java.util.Map;
-import java.util.Set;
+import java.util.*;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.function.Function;
@@ -40,6 +37,8 @@ import org.apache.beam.sdk.util.ByteStringOutputStream;
import org.apache.beam.sdk.util.Weighted;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
@@ -51,7 +50,7 @@ import
org.checkerframework.checker.nullness.qual.UnknownKeyFor;
@SuppressWarnings({
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
-public class WindmillMap<K, V> extends SimpleWindmillState implements
MapState<K, V> {
+public class WindmillMap<K, V> extends AbstractWindmillMap<K, V> {
private final StateNamespace namespace;
private final StateTag<MapState<K, V>> address;
private final ByteString stateKeyPrefix;
@@ -327,7 +326,7 @@ public class WindmillMap<K, V> extends SimpleWindmillState
implements MapState<K
@Override
public Iterable<Map.Entry<K, V>> read() {
if (complete) {
- return Iterables.unmodifiableIterable(cachedValues.entrySet());
+ return ImmutableMap.copyOf(cachedValues).entrySet();
}
Future<Iterable<Map.Entry<ByteString, V>>> persistedData = getFuture();
try (Closeable scope = scopedReadState()) {
@@ -352,20 +351,22 @@ public class WindmillMap<K, V> extends
SimpleWindmillState implements MapState<K
cachedValues.putIfAbsent(e.getKey(), e.getValue());
});
complete = true;
- return Iterables.unmodifiableIterable(cachedValues.entrySet());
+ return ImmutableMap.copyOf(cachedValues).entrySet();
} else {
+ ImmutableMap<K, V> cachedCopy = ImmutableMap.copyOf(cachedValues);
+ ImmutableSet<K> removalCopy = ImmutableSet.copyOf(localRemovals);
// This means that the result might be too large to cache, so don't
add it to the
// local cache. Instead merge the iterables, giving priority to any
local additions
- // (represented in cachedValued and localRemovals) that may not have
been committed
+ // (represented in cachedCopy and removalCopy) that may not have
been committed
// yet.
return Iterables.unmodifiableIterable(
Iterables.concat(
- cachedValues.entrySet(),
+ cachedCopy.entrySet(),
Iterables.filter(
transformedData,
e ->
- !cachedValues.containsKey(e.getKey())
- && !localRemovals.contains(e.getKey()))));
+ !cachedCopy.containsKey(e.getKey())
+ && !removalCopy.contains(e.getKey()))));
}
} catch (InterruptedException | ExecutionException | IOException e) {
@@ -428,7 +429,6 @@ public class WindmillMap<K, V> extends SimpleWindmillState
implements MapState<K
negativeCache.add(key);
return defaultValue;
}
- // TODO: Don't do this if it was already in cache.
cachedValues.put(key, persistedValue);
return persistedValue;
} catch (InterruptedException | ExecutionException | IOException e) {
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMapViaMultimap.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMapViaMultimap.java
new file mode 100644
index 00000000000..0ee508a53ba
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMapViaMultimap.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.state;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.function.Function;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.sdk.state.ReadableState;
+import org.apache.beam.sdk.state.ReadableStates;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators;
+
+/**
+ * A {@link org.apache.beam.sdk.state.MapState} implementation that stores its entries in a
+ * {@link WindmillMultimap}, keeping at most one value per key.
+ *
+ * <p>Reads, writes and persistence are all delegated to the wrapped multimap; {@link #put}
+ * removes any existing values for the key before adding the new one so that the single-value
+ * invariant holds.
+ */
+public class WindmillMapViaMultimap<KeyT, ValueT> extends AbstractWindmillMap<KeyT, ValueT> {
+  final WindmillMultimap<KeyT, ValueT> multimap;
+
+  WindmillMapViaMultimap(WindmillMultimap<KeyT, ValueT> multimap) {
+    this.multimap = multimap;
+  }
+
+  @Override
+  protected Windmill.WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+      throws IOException {
+    return multimap.persistDirectly(cache);
+  }
+
+  @Override
+  void initializeForWorkItem(
+      WindmillStateReader reader, Supplier<Closeable> scopedReadStateSupplier) {
+    super.initializeForWorkItem(reader, scopedReadStateSupplier);
+    multimap.initializeForWorkItem(reader, scopedReadStateSupplier);
+  }
+
+  @Override
+  void cleanupAfterWorkItem() {
+    super.cleanupAfterWorkItem();
+    multimap.cleanupAfterWorkItem();
+  }
+
+  /** Replaces any existing value for {@code key} so the multimap never holds more than one. */
+  @Override
+  public void put(KeyT key, ValueT value) {
+    multimap.remove(key);
+    multimap.put(key, value);
+  }
+
+  @Override
+  public ReadableState<ValueT> computeIfAbsent(
+      KeyT key, Function<? super KeyT, ? extends ValueT> mappingFunction) {
+    // The MapState contract describes this read as lazy, but it is performed eagerly here to
+    // match the existing behavior of WindmillMap.
+    Iterable<ValueT> existingValues = multimap.get(key).read();
+    if (Iterables.isEmpty(existingValues)) {
+      ValueT inserted = mappingFunction.apply(key);
+      multimap.put(key, inserted);
+      return ReadableStates.immediate(inserted);
+    } else {
+      return ReadableStates.immediate(Iterables.getOnlyElement(existingValues));
+    }
+  }
+
+  @Override
+  public void remove(KeyT key) {
+    multimap.remove(key);
+  }
+
+  /**
+   * Adapts the multimap's per-key {@code Iterable} of values to a single value, returning
+   * {@code defaultValue} when the key has no values.
+   */
+  private static class SingleValueIterableAdaptor<T> implements ReadableState<T> {
+    final ReadableState<Iterable<T>> wrapped;
+    final @Nullable T defaultValue;
+
+    SingleValueIterableAdaptor(ReadableState<Iterable<T>> wrapped, @Nullable T defaultValue) {
+      this.wrapped = wrapped;
+      this.defaultValue = defaultValue;
+    }
+
+    @Override
+    public T read() {
+      Iterator<T> iterator = wrapped.read().iterator();
+      if (!iterator.hasNext()) {
+        // Absent key: honor the default passed to getOrDefault (null when called via get).
+        return defaultValue;
+      }
+      // Throws IllegalArgumentException if more than one value is stored for the key.
+      return Iterators.getOnlyElement(iterator);
+    }
+
+    @Override
+    public ReadableState<T> readLater() {
+      wrapped.readLater();
+      return this;
+    }
+  }
+
+  @Override
+  public ReadableState<ValueT> get(KeyT key) {
+    return getOrDefault(key, null);
+  }
+
+  @Override
+  public ReadableState<ValueT> getOrDefault(KeyT key, @Nullable ValueT defaultValue) {
+    return new SingleValueIterableAdaptor<>(multimap.get(key), defaultValue);
+  }
+
+  @Override
+  public ReadableState<Iterable<KeyT>> keys() {
+    return multimap.keys();
+  }
+
+  /** Projects the multimap's entries down to their values. */
+  private static class RemoveKeyAdaptor<K, V> implements ReadableState<Iterable<V>> {
+    final ReadableState<Iterable<Map.Entry<K, V>>> wrapped;
+
+    RemoveKeyAdaptor(ReadableState<Iterable<Map.Entry<K, V>>> wrapped) {
+      this.wrapped = wrapped;
+    }
+
+    @Override
+    public Iterable<V> read() {
+      return Iterables.transform(wrapped.read(), Map.Entry::getValue);
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> readLater() {
+      wrapped.readLater();
+      return this;
+    }
+  }
+
+  @Override
+  public ReadableState<Iterable<ValueT>> values() {
+    return new RemoveKeyAdaptor<>(multimap.entries());
+  }
+
+  @Override
+  public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>> entries() {
+    return multimap.entries();
+  }
+
+  @Override
+  public ReadableState<Boolean> isEmpty() {
+    return multimap.isEmpty();
+  }
+
+  @Override
+  public void clear() {
+    multimap.clear();
+  }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMultimap.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMultimap.java
index 75f33e69e0b..19c79a497d4 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMultimap.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMultimap.java
@@ -216,8 +216,8 @@ public class WindmillMultimap<K, V> extends
SimpleWindmillState implements Multi
if (keyState == null || keyState.existence ==
KeyExistence.KNOWN_NONEXISTENT) {
return;
}
- if (keyState.valuesCached && keyState.valuesSize == 0) {
- // no data in windmill, deleting from local cache is sufficient.
+ if (keyState.valuesCached && keyState.valuesSize == 0 &&
!keyState.removedLocally) {
+ // no data in windmill and no need to keep state, deleting from local
cache is sufficient.
keyStateMap.remove(structuralKey);
} else {
// there may be data in windmill that need to be removed.
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillSet.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillSet.java
index 4afb879e722..ee7e6862c7a 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillSet.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillSet.java
@@ -20,13 +20,7 @@ package
org.apache.beam.runners.dataflow.worker.windmill.state;
import java.io.Closeable;
import java.io.IOException;
import java.util.Optional;
-import org.apache.beam.runners.core.StateNamespace;
-import org.apache.beam.runners.core.StateTag;
-import org.apache.beam.runners.core.StateTags;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
-import org.apache.beam.sdk.coders.BooleanCoder;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.state.MapState;
import org.apache.beam.sdk.state.ReadableState;
import org.apache.beam.sdk.state.SetState;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier;
@@ -35,30 +29,10 @@ import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.UnknownKeyFor;
public class WindmillSet<K> extends SimpleWindmillState implements SetState<K>
{
- private final WindmillMap<K, Boolean> windmillMap;
-
- WindmillSet(
- StateNamespace namespace,
- StateTag<SetState<K>> address,
- String stateFamily,
- Coder<K> keyCoder,
- WindmillStateCache.ForKeyAndFamily cache,
- boolean isNewKey) {
- StateTag<MapState<K, Boolean>> internalMapAddress =
StateTags.convertToMapTagInternal(address);
-
- this.windmillMap =
- cache
- .get(namespace, internalMapAddress)
- .map(map -> (WindmillMap<K, Boolean>) map)
- .orElseGet(
- () ->
- new WindmillMap<>(
- namespace,
- internalMapAddress,
- stateFamily,
- keyCoder,
- BooleanCoder.of(),
- isNewKey));
+ private final AbstractWindmillMap<K, Boolean> windmillMap;
+
+ WindmillSet(AbstractWindmillMap<K, Boolean> windmillMap) {
+ this.windmillMap = windmillMap;
}
@Override
@@ -117,11 +91,13 @@ public class WindmillSet<K> extends SimpleWindmillState
implements SetState<K> {
@Override
void initializeForWorkItem(
WindmillStateReader reader, Supplier<Closeable> scopedReadStateSupplier)
{
+ super.initializeForWorkItem(reader, scopedReadStateSupplier);
windmillMap.initializeForWorkItem(reader, scopedReadStateSupplier);
}
@Override
void cleanupAfterWorkItem() {
+ super.cleanupAfterWorkItem();
windmillMap.cleanupAfterWorkItem();
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
index c6c49134bcb..64eb9dd941b 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.runners.dataflow.worker.windmill.state;
+import com.google.auto.value.AutoBuilder;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.HashMap;
@@ -29,9 +30,7 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.beam.runners.core.StateNamespace;
import org.apache.beam.runners.core.StateTag;
import org.apache.beam.runners.core.StateTags;
-import org.apache.beam.runners.dataflow.worker.StreamingDataflowWorker;
-import org.apache.beam.runners.dataflow.worker.Weighers;
-import org.apache.beam.runners.dataflow.worker.WindmillComputationKey;
+import org.apache.beam.runners.dataflow.worker.*;
import org.apache.beam.runners.dataflow.worker.status.BaseStatusServlet;
import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey;
@@ -76,26 +75,33 @@ public class WindmillStateCache implements
StatusDataProvider {
// entries inaccessible. They will be evicted through normal cache operation.
private final ConcurrentMap<WindmillComputationKey, ForKey> keyIndex;
private final long workerCacheBytes; // Copy workerCacheMb and convert to
bytes.
+ private final boolean supportMapViaMultimap;
- private WindmillStateCache(
- long workerCacheMb,
- ConcurrentMap<WindmillComputationKey, ForKey> keyIndex,
- Cache<StateId, StateCacheEntry> stateCache) {
- this.workerCacheBytes = workerCacheMb * MEGABYTES;
- this.stateCache = stateCache;
- this.keyIndex = keyIndex;
- }
-
- public static WindmillStateCache ofSizeMbs(long workerCacheMb) {
- return new WindmillStateCache(
- workerCacheMb,
- new
MapMaker().weakValues().concurrencyLevel(STATE_CACHE_CONCURRENCY_LEVEL).makeMap(),
+ WindmillStateCache(long sizeMb, boolean supportMapViaMultimap) {
+ this.workerCacheBytes = sizeMb * MEGABYTES;
+ this.stateCache =
CacheBuilder.newBuilder()
- .maximumWeight(workerCacheMb * MEGABYTES)
+ .maximumWeight(workerCacheBytes)
.recordStats()
.weigher(Weighers.weightedKeysAndValues())
.concurrencyLevel(STATE_CACHE_CONCURRENCY_LEVEL)
- .build());
+ .build();
+ this.keyIndex =
+ new
MapMaker().weakValues().concurrencyLevel(STATE_CACHE_CONCURRENCY_LEVEL).makeMap();
+ this.supportMapViaMultimap = supportMapViaMultimap;
+ }
+
+ @AutoBuilder(ofClass = WindmillStateCache.class)
+ public interface Builder {
+ Builder setSizeMb(long sizeMb);
+
+ Builder setSupportMapViaMultimap(boolean supportMapViaMultimap);
+
+ WindmillStateCache build();
+ }
+
+ public static Builder builder() {
+ return new
AutoBuilder_WindmillStateCache_Builder().setSupportMapViaMultimap(false);
}
private EntryStats calculateEntryStats() {
@@ -399,6 +405,10 @@ public class WindmillStateCache implements
StatusDataProvider {
return stateFamily;
}
+ public boolean supportMapStateViaMultimapState() {
+ return supportMapViaMultimap;
+ }
+
public <T extends State> Optional<T> get(StateNamespace namespace,
StateTag<T> address) {
@SuppressWarnings("nullness")
// the mapping function for localCache.computeIfAbsent (i.e
stateCache.getIfPresent) is
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java
index c900228e86b..f757db991fa 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java
@@ -66,13 +66,13 @@ public class WindmillStateInternals<K> implements
StateInternals {
this.key = key;
this.cache = cache;
this.scopedReadStateSupplier = scopedReadStateSupplier;
- this.workItemDerivedState =
- CachingStateTable.builder(stateFamily, reader, cache, isNewKey,
scopedReadStateSupplier)
- .build();
- this.workItemState =
- CachingStateTable.builder(stateFamily, reader, cache, isNewKey,
scopedReadStateSupplier)
- .withDerivedState(workItemDerivedState)
- .build();
+ CachingStateTable.Builder builder =
+ CachingStateTable.builder(stateFamily, reader, cache, isNewKey,
scopedReadStateSupplier);
+ if (cache.supportMapStateViaMultimapState()) {
+ builder = builder.withMapStateViaMultimapState();
+ }
+ this.workItemDerivedState = builder.build();
+ this.workItemState =
builder.withDerivedState(workItemDerivedState).build();
}
@Override
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
index 2193f20f3fe..6c46bda5acf 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
@@ -112,7 +112,10 @@ public class StreamingModeExecutionContextTest {
COMPUTATION_ID,
new ReaderCache(Duration.standardMinutes(1),
Executors.newCachedThreadPool()),
stateNameMap,
-
WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()).forComputation("comp"),
+ WindmillStateCache.builder()
+ .setSizeMb(options.getWorkerCacheMb())
+ .build()
+ .forComputation("comp"),
StreamingStepMetricsContainer.createRegistry(),
new DataflowExecutionStateTracker(
ExecutionStateSampler.newForTest(),
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateTestUtils.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateTestUtils.java
index 17da531d452..8708b9f502d 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateTestUtils.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateTestUtils.java
@@ -66,8 +66,8 @@ public class WindmillStateTestUtils {
boolean accessible = f.isAccessible();
try {
- f.setAccessible(true);
path.add(thisClazz.getName() + "#" + f.getName());
+ f.setAccessible(true);
assertNoReference(f.get(obj), clazz, path, visited);
} finally {
path.remove(path.size() - 1);
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java
index 9f97c9835dd..5d8ebd53400 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java
@@ -964,7 +964,10 @@ public class WorkerCustomSourcesTest {
COMPUTATION_ID,
new ReaderCache(Duration.standardMinutes(1), Runnable::run),
/*stateNameMap=*/ ImmutableMap.of(),
-
WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()).forComputation(COMPUTATION_ID),
+ WindmillStateCache.builder()
+ .setSizeMb(options.getWorkerCacheMb())
+ .build()
+ .forComputation(COMPUTATION_ID),
StreamingStepMetricsContainer.createRegistry(),
new DataflowExecutionStateTracker(
ExecutionStateSampler.newForTest(),
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCacheTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCacheTest.java
index 446a34f73de..ce8da106b0c 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCacheTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCacheTest.java
@@ -148,7 +148,7 @@ public class WindmillStateCacheTest {
@Before
public void setUp() {
options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class);
- cache = WindmillStateCache.ofSizeMbs(400);
+ cache = WindmillStateCache.builder().setSizeMb(400).build();
assertEquals(0, cache.getWeight());
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java
index a53240d6453..33e47623cd0 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java
@@ -20,11 +20,7 @@ package
org.apache.beam.runners.dataflow.worker.windmill.state;
import static
org.apache.beam.runners.dataflow.worker.DataflowMatchers.ByteStringMatcher.byteStringEq;
import static org.apache.beam.sdk.testing.SystemNanoTimeSleeper.sleepMillis;
import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Matchers.eq;
@@ -130,7 +126,9 @@ public class WindmillStateInternalsTest {
@Mock private WindmillStateReader mockReader;
private WindmillStateInternals<String> underTest;
private WindmillStateInternals<String> underTestNewKey;
+ private WindmillStateInternals<String> underTestMapViaMultimap;
private WindmillStateCache cache;
+ private WindmillStateCache cacheViaMultimap;
@Mock private Supplier<Closeable> readStateSupplier;
private static ByteString key(StateNamespace namespace, String addrId) {
@@ -206,7 +204,12 @@ public class WindmillStateInternalsTest {
public void setUp() {
MockitoAnnotations.initMocks(this);
options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class);
- cache = WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb());
+ cache =
WindmillStateCache.builder().setSizeMb(options.getWorkerCacheMb()).build();
+ cacheViaMultimap =
+ WindmillStateCache.builder()
+ .setSizeMb(options.getWorkerCacheMb())
+ .setSupportMapViaMultimap(true)
+ .build();
resetUnderTest();
}
@@ -242,6 +245,21 @@ public class WindmillStateInternalsTest {
workToken)
.forFamily(STATE_FAMILY),
readStateSupplier);
+ underTestMapViaMultimap =
+ new WindmillStateInternals<String>(
+ "dummyNewKey",
+ STATE_FAMILY,
+ mockReader,
+ false,
+ cacheViaMultimap
+ .forComputation("comp")
+ .forKey(
+ WindmillComputationKey.create(
+ "comp", ByteString.copyFrom("dummyNewKey",
Charsets.UTF_8), 123),
+ 17L,
+ workToken)
+ .forFamily(STATE_FAMILY),
+ readStateSupplier);
}
@After
@@ -249,6 +267,7 @@ public class WindmillStateInternalsTest {
// Make sure no WindmillStateReader (a per-WorkItem object) escapes into
the cache
// (a global object).
WindmillStateTestUtils.assertNoReference(cache, WindmillStateReader.class);
+ WindmillStateTestUtils.assertNoReference(cacheViaMultimap,
WindmillStateReader.class);
}
private <T> void waitAndSet(final SettableFuture<T> future, final T value,
final long millis) {
@@ -741,6 +760,38 @@ public class WindmillStateInternalsTest {
assertThat(result.read(), Matchers.containsInAnyOrder(1, 2, 3));
}
+ @Test
+ public void testMapViaMultimapGet() {
+ final String tag = "map";
+ StateTag<MapState<byte[], Integer>> addr =
+ StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of());
+ MapState<byte[], Integer> mapViaMultiMapState =
underTestMapViaMultimap.state(NAMESPACE, addr);
+
+ final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+ final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+ SettableFuture<Iterable<Integer>> future1 = SettableFuture.create();
+ when(mockReader.multimapFetchSingleEntryFuture(
+ encodeWithCoder(key1, ByteArrayCoder.of()),
+ key(NAMESPACE, tag),
+ STATE_FAMILY,
+ VarIntCoder.of()))
+ .thenReturn(future1);
+ SettableFuture<Iterable<Integer>> future2 = SettableFuture.create();
+ when(mockReader.multimapFetchSingleEntryFuture(
+ encodeWithCoder(key2, ByteArrayCoder.of()),
+ key(NAMESPACE, tag),
+ STATE_FAMILY,
+ VarIntCoder.of()))
+ .thenReturn(future2);
+
+ ReadableState<Integer> result1 =
mapViaMultiMapState.get(dup(key1)).readLater();
+ ReadableState<Integer> result2 =
mapViaMultiMapState.get(dup(key2)).readLater();
+ waitAndSet(future1, Collections.singletonList(1), 30);
+ waitAndSet(future2, Collections.emptyList(), 1);
+ assertEquals(Integer.valueOf(1), result1.read());
+ assertNull(result2.read());
+ }
+
@Test
public void testMultimapPutAndGet() {
final String tag = "multimap";
@@ -761,6 +812,41 @@ public class WindmillStateInternalsTest {
ReadableState<Iterable<Integer>> result =
multimapState.get(dup(key)).readLater();
waitAndSet(future, Arrays.asList(1, 2, 3), 30);
assertThat(result.read(), Matchers.containsInAnyOrder(1, 1, 2, 3));
+
+ multimapState.remove(key);
+ multimapState.put(key, 4);
+ multimapState.remove(key);
+ multimapState.put(key, 5);
+ assertThat(result.read(), Matchers.containsInAnyOrder(5));
+ multimapState.clear();
+ assertThat(multimapState.get(key).read(), Matchers.emptyIterable());
+ }
+
+ @Test
+ public void testMapViaMultimapPutAndGet() {
+ final String tag = "map";
+ StateTag<MapState<byte[], Integer>> addr =
+ StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of());
+ MapState<byte[], Integer> mapViaMultiMapState =
underTestMapViaMultimap.state(NAMESPACE, addr);
+
+ final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+ SettableFuture<Iterable<Integer>> future = SettableFuture.create();
+ when(mockReader.multimapFetchSingleEntryFuture(
+ encodeWithCoder(key, ByteArrayCoder.of()),
+ key(NAMESPACE, tag),
+ STATE_FAMILY,
+ VarIntCoder.of()))
+ .thenReturn(future);
+
+ mapViaMultiMapState.put(key, 1);
+ ReadableState<Integer> result =
mapViaMultiMapState.get(dup(key)).readLater();
+ waitAndSet(future, Collections.singletonList(2), 30);
+ assertEquals(Integer.valueOf(1), result.read());
+
+ mapViaMultiMapState.put(key, 3);
+ assertEquals(Integer.valueOf(3), mapViaMultiMapState.get(key).read());
+ mapViaMultiMapState.clear();
+ assertNull(mapViaMultiMapState.get(key).read());
}
@Test
@@ -791,6 +877,33 @@ public class WindmillStateInternalsTest {
assertThat(result2.read(), Matchers.emptyIterable());
}
+ @Test
+ public void testMapViaMultimapRemoveAndGet() {
+ final String tag = "map";
+ StateTag<MapState<byte[], Integer>> addr =
+ StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of());
+ MapState<byte[], Integer> mapViaMultiMapState =
underTestMapViaMultimap.state(NAMESPACE, addr);
+
+ final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+ SettableFuture<Iterable<Integer>> future = SettableFuture.create();
+ when(mockReader.multimapFetchSingleEntryFuture(
+ encodeWithCoder(key, ByteArrayCoder.of()),
+ key(NAMESPACE, tag),
+ STATE_FAMILY,
+ VarIntCoder.of()))
+ .thenReturn(future);
+
+ ReadableState<Integer> result1 = mapViaMultiMapState.get(key).readLater();
+ ReadableState<Integer> result2 =
mapViaMultiMapState.get(dup(key)).readLater();
+ waitAndSet(future, Collections.singletonList(1), 30);
+
+ assertEquals(Integer.valueOf(1), result1.read());
+
+ mapViaMultiMapState.remove(key);
+ assertNull(mapViaMultiMapState.get(dup(key)).read());
+ assertNull(result2.read());
+ }
+
@Test
public void testMultimapRemoveThenPut() {
final String tag = "multimap";
@@ -1030,6 +1143,64 @@ public class WindmillStateInternalsTest {
assertThat(keys, Matchers.containsInAnyOrder(key1, key2, key3));
}
+ @Test
+ public void testMapViaMultimapEntriesAndKeysMergeLocalAddRemoveClear() {
+ final String tag = "map";
+ StateTag<MapState<byte[], Integer>> addr =
+ StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of());
+ MapState<byte[], Integer> mapState =
underTestMapViaMultimap.state(NAMESPACE, addr);
+
+ final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+ final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+ final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8);
+ final byte[] key4 = "key4".getBytes(StandardCharsets.UTF_8);
+
+ SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>>
entriesFuture =
+ SettableFuture.create();
+ when(mockReader.multimapFetchAllFuture(
+ false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+ .thenReturn(entriesFuture);
+ SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>>
keysFuture =
+ SettableFuture.create();
+ when(mockReader.multimapFetchAllFuture(
+ true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+ .thenReturn(keysFuture);
+
+ ReadableState<Iterable<Map.Entry<byte[], Integer>>> entriesResult =
+ mapState.entries().readLater();
+ ReadableState<Iterable<byte[]>> keysResult = mapState.keys().readLater();
+ waitAndSet(entriesFuture, Arrays.asList(multimapEntry(key1, 3),
multimapEntry(key2, 4)), 30);
+ waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1),
multimapEntry(key2)), 30);
+
+ mapState.put(key1, 7);
+ mapState.put(dup(key3), 8);
+ mapState.put(key4, 1);
+ mapState.remove(key4);
+
+ Iterable<Map.Entry<byte[], Integer>> entries = entriesResult.read();
+ assertEquals(3, Iterables.size(entries));
+ assertThat(
+ entries,
+ Matchers.containsInAnyOrder(
+ multimapEntryMatcher(key1, 7),
+ multimapEntryMatcher(key2, 4),
+ multimapEntryMatcher(key3, 8)));
+
+ Iterable<byte[]> keys = keysResult.read();
+ assertEquals(3, Iterables.size(keys));
+ assertThat(keys, Matchers.containsInAnyOrder(key1, key2, key3));
+ assertFalse(mapState.isEmpty().read());
+
+ mapState.clear();
+ assertTrue(mapState.isEmpty().read());
+ assertTrue(Iterables.isEmpty(mapState.keys().read()));
+ assertTrue(Iterables.isEmpty(mapState.entries().read()));
+
+ // Previously read iterable should still have the same result.
+ assertEquals(3, Iterables.size(keys));
+ assertThat(keys, Matchers.containsInAnyOrder(key1, key2, key3));
+ }
+
@Test
public void testMultimapEntriesAndKeysMergeLocalRemove() {
final String tag = "multimap";
@@ -1080,6 +1251,48 @@ public class WindmillStateInternalsTest {
assertThat(keys, Matchers.containsInAnyOrder(key2, key3));
}
+ @Test
+ public void testMapViaMultimapEntriesAndKeysMergeLocalRemove() {
+ final String tag = "map";
+ StateTag<MapState<byte[], Integer>> addr =
+ StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of());
+ MapState<byte[], Integer> mapState =
underTestMapViaMultimap.state(NAMESPACE, addr);
+
+ final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+ final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+ final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8);
+
+ SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>>
entriesFuture =
+ SettableFuture.create();
+ when(mockReader.multimapFetchAllFuture(
+ false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+ .thenReturn(entriesFuture);
+ SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>>
keysFuture =
+ SettableFuture.create();
+ when(mockReader.multimapFetchAllFuture(
+ true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+ .thenReturn(keysFuture);
+
+ ReadableState<Iterable<Map.Entry<byte[], Integer>>> entriesResult =
+ mapState.entries().readLater();
+ ReadableState<Iterable<byte[]>> keysResult = mapState.keys().readLater();
+ waitAndSet(entriesFuture, Arrays.asList(multimapEntry(key1, 1),
multimapEntry(key2, 2)), 30);
+ waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1),
multimapEntry(key2)), 30);
+
+ mapState.remove(dup(key1));
+ mapState.put(key2, 8);
+ mapState.put(dup(key3), 9);
+
+ Iterable<Map.Entry<byte[], Integer>> entries = entriesResult.read();
+ assertEquals(2, Iterables.size(entries));
+ assertThat(
+ entries,
+ Matchers.containsInAnyOrder(multimapEntryMatcher(key2, 8),
multimapEntryMatcher(key3, 9)));
+
+ Iterable<byte[]> keys = keysResult.read();
+ assertThat(keys, Matchers.containsInAnyOrder(key2, key3));
+ }
+
@Test
public void testMultimapCacheComplete() {
final String tag = "multimap";
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java
index 175c8421ff8..13019116767 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java
@@ -207,7 +207,7 @@ public class DispatchedActiveWorkRefresherTest {
int stuckCommitDurationMillis = 100;
Table<ComputationState, ExecutableWork, WindmillStateCache.ForComputation>
computations =
HashBasedTable.create();
- WindmillStateCache stateCache = WindmillStateCache.ofSizeMbs(100);
+ WindmillStateCache stateCache =
WindmillStateCache.builder().setSizeMb(100).build();
ByteString key = ByteString.EMPTY;
for (int i = 0; i < 5; i++) {
WindmillStateCache.ForComputation perComputationStateCache =
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java
index 942881522cf..df5084ad092 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java
@@ -377,6 +377,25 @@ public class StateSpecs {
}
}
+ /**
+ * <b><i>For internal use only; no backwards-compatibility
guarantees.</i></b>
+ *
+ * <p>Convert a map state spec to a multimap-state spec.
+ */
+ @Internal
+ public static <KeyT, ValueT> StateSpec<MultimapState<KeyT, ValueT>>
convertToMultimapSpecInternal(
+ StateSpec<MapState<KeyT, ValueT>> spec) {
+ if (spec instanceof MapStateSpec) {
+ // Checked above; conversion to a multimap spec depends on the provided spec
being one of those
+ // created via the factory methods in this class.
+ @SuppressWarnings("unchecked")
+ MapStateSpec<KeyT, ValueT> typedSpec = (MapStateSpec<KeyT, ValueT>) spec;
+ return typedSpec.asMultimapSpec();
+ } else {
+ throw new IllegalArgumentException("Unexpected StateSpec " + spec);
+ }
+ }
+
/**
* A specification for a state cell holding a settable value of type {@code
T}.
*
@@ -768,6 +787,10 @@ public class StateSpecs {
public int hashCode() {
return Objects.hash(getClass(), keyCoder, valueCoder);
}
+
+ private MultimapStateSpec<K, V> asMultimapSpec() {
+ return new MultimapStateSpec<>(this.keyCoder, this.valueCoder);
+ }
}
private static class MultimapStateSpec<K, V> implements
StateSpec<MultimapState<K, V>> {
diff --git
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
index 89dcafbdf94..fb2321328b3 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
@@ -2709,19 +2709,26 @@ public class ParDoTest implements Serializable {
@StateId(countStateId) CombiningState<Integer, int[], Integer>
count,
OutputReceiver<KV<String, Integer>> r) {
KV<String, Integer> value = element.getValue();
- ReadableState<Iterable<Entry<String, Integer>>> entriesView =
state.entries();
state.put(value.getKey(), value.getValue());
count.add(1);
+
+ @Nullable Integer max = state.get("max").read();
+ state.put("max", Math.max(max == null ? 0 : max,
value.getValue()));
if (count.read() >= 4) {
- Iterable<Map.Entry<String, Integer>> iterate =
state.entries().read();
+ assertEquals(Integer.valueOf(97), state.get("a").read());
+
+ Iterable<Map.Entry<String, Integer>> entriesView =
state.entries().read();
+ Iterable<String> keysView = state.keys().read();
// Make sure that the cached Iterable doesn't change when new
elements are added,
// but that cached ReadableState views of the state do change.
state.put("BadKey", -1);
- assertEquals(3, Iterables.size(iterate));
- assertEquals(4, Iterables.size(entriesView.read()));
- assertEquals(4, Iterables.size(state.entries().read()));
+ assertEquals(4, Iterables.size(entriesView));
+ assertEquals(4, Iterables.size(keysView));
+ assertEquals(5, Iterables.size(state.entries().read()));
+ assertEquals(5, Iterables.size(state.keys().read()));
+ assertEquals(Integer.valueOf(97), state.get("max").read());
- for (Map.Entry<String, Integer> entry : iterate) {
+ for (Map.Entry<String, Integer> entry : entriesView) {
r.output(KV.of(entry.getKey(), entry.getValue()));
}
}
@@ -2732,11 +2739,14 @@ public class ParDoTest implements Serializable {
pipeline
.apply(
Create.of(
- KV.of("hello", KV.of("a", 97)), KV.of("hello",
KV.of("b", 42)),
- KV.of("hello", KV.of("b", 42)), KV.of("hello",
KV.of("c", 12))))
+ KV.of("hello", KV.of("a", 97)),
+ KV.of("hello", KV.of("b", 42)),
+ KV.of("hello", KV.of("b", 42)),
+ KV.of("hello", KV.of("c", 12))))
.apply(ParDo.of(fn));
- PAssert.that(output).containsInAnyOrder(KV.of("a", 97), KV.of("b", 42),
KV.of("c", 12));
+ PAssert.that(output)
+ .containsInAnyOrder(KV.of("a", 97), KV.of("b", 42), KV.of("c", 12),
KV.of("max", 97));
pipeline.run();
}