This is an automated email from the ASF dual-hosted git repository.
scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 21dbf592f87 Adds Multimap support to JAVA FnApi (#36218)
21dbf592f87 is described below
commit 21dbf592f87d17f2f6e863323c9b946ae6b09b4a
Author: Andrew Crites <[email protected]>
AuthorDate: Tue Oct 28 12:25:11 2025 -0700
Adds Multimap support to JAVA FnApi (#36218)
* Changes multimap state keys() tests to not care about order. There is no
guarantee on the order in which keys are returned. Also fixes a couple of
warnings from other FnApi tests.
* Adds Multimap user state support to the Java FnApi harness. Also adds a
missing FnApi state proto to get all of the entries of a multimap. This type of
access is part of the state API (and supported by the non-portable harness),
but was not present in the protos.
* Adds FnApi binding for entries() method.
* Changes the multimap entries() iterable to return values for the same key
from the backend and from local adds together. This required making
PrefetchableIterables.maybePrefetchable public.
* Adds a test that prefetching multimap entries results in a StateRequest
sent across FnApi.
* Adds an environment capability for multimap state and sets it for the
Java SDK.
---
.../beam/model/fn_execution/v1/beam_fn_api.proto | 24 ++++
.../beam/model/pipeline/v1/beam_runner_api.proto | 10 +-
.../beam/sdk/fn/stream/PrefetchableIterables.java | 2 +-
.../beam/sdk/util/construction/Environments.java | 1 +
.../org/apache/beam/sdk/transforms/ParDoTest.java | 67 ++++++++++
.../sdk/util/construction/EnvironmentsTest.java | 3 +
.../beam/fn/harness/state/FnApiStateAccessor.java | 121 ++++++++++++++++--
.../beam/fn/harness/state/MultimapUserState.java | 139 ++++++++++++++++++++-
.../fn/harness/state/MultimapUserStateTest.java | 129 +++++++++++++++++++
9 files changed, 480 insertions(+), 16 deletions(-)
diff --git
a/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto
b/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto
index 9b32048b499..4eee2ef5d89 100644
---
a/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto
+++
b/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto
@@ -1017,6 +1017,29 @@ message StateKey {
bytes key = 4;
}
+ // Represents a request for all of the entries of a multimap associated with a
+ // specified user key and window for a PTransform. See
+ // https://s.apache.org/beam-fn-state-api-and-bundle-processing for further
+ // details.
+ //
+ // Can only be used to perform StateGetRequests and StateClearRequests on the
+ // user state.
+ //
+ // The response data stream will be a concatenation of pairs, where the first
+ // component is the map key and the second component is a concatenation of
+ // values associated with that map key.
+ message MultimapEntriesUserState {
+ // (Required) The id of the PTransform containing user state.
+ string transform_id = 1;
+ // (Required) The id of the user state.
+ string user_state_id = 2;
+ // (Required) The window encoded in a nested context.
+ bytes window = 3;
+ // (Required) The key of the currently executing element encoded in a
+ // nested context.
+ bytes key = 4;
+ }
+
// Represents a request for the values of the map key associated with a
// specified user key and window for a PTransform. See
// https://s.apache.org/beam-fn-state-api-and-bundle-processing for further
@@ -1072,6 +1095,7 @@ message StateKey {
MultimapKeysSideInput multimap_keys_side_input = 5;
MultimapKeysValuesSideInput multimap_keys_values_side_input = 8;
MultimapKeysUserState multimap_keys_user_state = 6;
+ MultimapEntriesUserState multimap_entries_user_state = 10;
MultimapUserState multimap_user_state = 7;
OrderedListUserState ordered_list_user_state = 9;
}
diff --git
a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto
b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto
index c615b2a5279..0bdc4f69aab 100644
---
a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto
+++
b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto
@@ -1621,13 +1621,13 @@ message AnyOfEnvironmentPayload {
// environment understands.
message StandardProtocols {
enum Enum {
- // Indicates suport for progress reporting via the legacy Metrics proto.
+ // Indicates support for progress reporting via the legacy Metrics proto.
LEGACY_PROGRESS_REPORTING = 0 [(beam_urn) =
"beam:protocol:progress_reporting:v0"];
- // Indicates suport for progress reporting via the new MonitoringInfo
proto.
+ // Indicates support for progress reporting via the new MonitoringInfo
proto.
PROGRESS_REPORTING = 1 [(beam_urn) =
"beam:protocol:progress_reporting:v1"];
- // Indicates suport for worker status protocol defined at
+ // Indicates support for worker status protocol defined at
// https://s.apache.org/beam-fn-api-harness-status.
WORKER_STATUS = 2 [(beam_urn) = "beam:protocol:worker_status:v1"];
@@ -1681,6 +1681,10 @@ message StandardProtocols {
// Indicates support for reading, writing and propagating Element's
metadata
ELEMENT_METADATA = 11
[(beam_urn) = "beam:protocol:element_metadata:v1"];
+
+ // Indicates whether the SDK supports multimap state.
+ MULTIMAP_STATE = 12
+ [(beam_urn) = "beam:protocol:multimap_state:v1"];
}
}
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java
index dd7ec6b0f65..1f7451e72a2 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java
@@ -94,7 +94,7 @@ public class PrefetchableIterables {
* constructed that ensures that {@link PrefetchableIterator#prefetch()} is
a no-op and {@link
* PrefetchableIterator#isReady()} always returns true.
*/
- private static <T> PrefetchableIterable<T> maybePrefetchable(Iterable<T>
iterable) {
+ public static <T> PrefetchableIterable<T> maybePrefetchable(Iterable<T>
iterable) {
if (iterable instanceof PrefetchableIterable) {
return (PrefetchableIterable<T>) iterable;
}
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java
index 55379bf3a80..969bda88d07 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java
@@ -521,6 +521,7 @@ public class Environments {
capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.DATA_SAMPLING));
capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.SDK_CONSUMING_RECEIVED_DATA));
capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.ORDERED_LIST_STATE));
+ capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.MULTIMAP_STATE));
return capabilities.build();
}
diff --git
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
index 8409133772e..8a273127b4f 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
@@ -2917,6 +2917,73 @@ public class ParDoTest implements Serializable {
pipeline.run();
}
+ @Test
+ @Category({ValidatesRunner.class, UsesStatefulParDo.class,
UsesMultimapState.class})
+ public void testMultimapStateEntries() {
+ final String stateId = "foo:";
+ final String countStateId = "count";
+ DoFn<KV<String, KV<String, Integer>>, KV<String, Integer>> fn =
+ new DoFn<KV<String, KV<String, Integer>>, KV<String, Integer>>() {
+
+ @StateId(stateId)
+ private final StateSpec<MultimapState<String, Integer>>
multimapState =
+ StateSpecs.multimap(StringUtf8Coder.of(), VarIntCoder.of());
+
+ @StateId(countStateId)
+ private final StateSpec<CombiningState<Integer, int[], Integer>>
countState =
+ StateSpecs.combiningFromInputInternal(VarIntCoder.of(),
Sum.ofIntegers());
+
+ @ProcessElement
+ public void processElement(
+ ProcessContext c,
+ @Element KV<String, KV<String, Integer>> element,
+ @StateId(stateId) MultimapState<String, Integer> state,
+ @StateId(countStateId) CombiningState<Integer, int[], Integer>
count,
+ OutputReceiver<KV<String, Integer>> r) {
+ // Empty before we process any elements.
+ if (count.read() == 0) {
+ assertThat(state.entries().read(), emptyIterable());
+ }
+ assertEquals(count.read().intValue(),
Iterables.size(state.entries().read()));
+
+ KV<String, Integer> value = element.getValue();
+ state.put(value.getKey(), value.getValue());
+ count.add(1);
+
+ if (count.read() >= 4) {
+ // This should be evaluated only when ReadableState.read is
called.
+ ReadableState<Iterable<Entry<String, Integer>>> entriesView =
state.entries();
+
+ // This is evaluated immediately.
+ Iterable<Entry<String, Integer>> entries =
state.entries().read();
+
+ state.remove("b");
+ assertEquals(4, Iterables.size(entries));
+ state.put("a", 2);
+ state.put("a", 3);
+
+ assertEquals(5, Iterables.size(entriesView.read()));
+ // Note we output the view of state before the modifications
in this if statement.
+ for (Entry<String, Integer> entry : entries) {
+ r.output(KV.of(entry.getKey(), entry.getValue()));
+ }
+ }
+ }
+ };
+ PCollection<KV<String, Integer>> output =
+ pipeline
+ .apply(
+ Create.of(
+ KV.of("hello", KV.of("a", 97)), KV.of("hello",
KV.of("a", 97)),
+ KV.of("hello", KV.of("a", 98)), KV.of("hello",
KV.of("b", 33))))
+ .apply(ParDo.of(fn));
+ PAssert.that(output)
+ .containsInAnyOrder(
+ KV.of("a", 97), KV.of("a", 97),
+ KV.of("a", 98), KV.of("b", 33));
+ pipeline.run();
+ }
+
@Test
@Category({ValidatesRunner.class, UsesStatefulParDo.class,
UsesMultimapState.class})
public void testMultimapStateRemove() {
diff --git
a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/EnvironmentsTest.java
b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/EnvironmentsTest.java
index f12a2a77f99..ebd4e9fbe24 100644
---
a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/EnvironmentsTest.java
+++
b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/EnvironmentsTest.java
@@ -219,6 +219,9 @@ public class EnvironmentsTest implements Serializable {
assertThat(
Environments.getJavaCapabilities(),
hasItem(BeamUrns.getUrn(RunnerApi.StandardProtocols.Enum.ORDERED_LIST_STATE)));
+ assertThat(
+ Environments.getJavaCapabilities(),
+
hasItem(BeamUrns.getUrn(RunnerApi.StandardProtocols.Enum.MULTIMAP_STATE)));
// Check that SDF truncation is supported
assertThat(
Environments.getJavaCapabilities(),
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
index e06a82c8e25..6913c75a5f2 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
@@ -117,7 +117,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
public Factory(
PipelineOptions pipelineOptions,
- Set<String> runnerCapabilites,
+ Set<String> runnerCapabilities,
String ptransformId,
Supplier<String> processBundleInstructionId,
Supplier<List<CacheToken>> cacheTokens,
@@ -128,7 +128,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
Coder<K> keyCoder,
Coder<BoundedWindow> windowCoder) {
this.pipelineOptions = pipelineOptions;
- this.runnerCapabilities = runnerCapabilites;
+ this.runnerCapabilities = runnerCapabilities;
this.ptransformId = ptransformId;
this.processBundleInstructionId = processBundleInstructionId;
this.cacheTokens = cacheTokens;
@@ -240,7 +240,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
}
private final PipelineOptions pipelineOptions;
- private final Set<String> runnerCapabilites;
+ private final Set<String> runnerCapabilities;
private final Map<StateKey, Object> stateKeyObjectCache;
private final Map<TupleTag<?>, SideInputSpec> sideInputSpecMap;
private final BeamFnStateClient beamFnStateClient;
@@ -259,7 +259,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
public FnApiStateAccessor(
PipelineOptions pipelineOptions,
- Set<String> runnerCapabilites,
+ Set<String> runnerCapabilities,
String ptransformId,
Supplier<String> processBundleInstructionId,
Supplier<List<CacheToken>> cacheTokens,
@@ -270,7 +270,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
Coder<K> keyCoder,
Coder<BoundedWindow> windowCoder) {
this.pipelineOptions = pipelineOptions;
- this.runnerCapabilites = runnerCapabilites;
+ this.runnerCapabilities = runnerCapabilities;
this.stateKeyObjectCache = Maps.newHashMap();
this.sideInputSpecMap = sideInputSpecMap;
this.beamFnStateClient = beamFnStateClient;
@@ -414,7 +414,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
key,
((KvCoder)
sideInputSpec.getCoder()).getKeyCoder(),
((KvCoder)
sideInputSpec.getCoder()).getValueCoder(),
- runnerCapabilites.contains(
+ runnerCapabilities.contains(
BeamUrns.getUrn(
RunnerApi.StandardRunnerProtocols.Enum
.MULTIMAP_KEYS_VALUES_SIDE_INPUT))));
@@ -762,8 +762,113 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
StateSpec<MultimapState<KeyT, ValueT>> spec,
Coder<KeyT> keyCoder,
Coder<ValueT> valueCoder) {
- // TODO(https://github.com/apache/beam/issues/23616)
- throw new UnsupportedOperationException("Multimap is not currently
supported with Fn API.");
+ return (MultimapState<KeyT, ValueT>)
+ stateKeyObjectCache.computeIfAbsent(
+ createMultimapKeysUserStateKey(id),
+ new Function<StateKey, Object>() {
+ @Override
+ public Object apply(StateKey stateKey) {
+ return new MultimapState<KeyT, ValueT>() {
+ private final MultimapUserState<KeyT, ValueT> impl =
+ createMultimapUserState(stateKey, keyCoder, valueCoder);
+
+ @Override
+ public void put(KeyT key, ValueT value) {
+ impl.put(key, value);
+ }
+
+ @Override
+ public ReadableState<Iterable<ValueT>> get(KeyT key) {
+ return new ReadableState<Iterable<ValueT>>() {
+ @Override
+ public Iterable<ValueT> read() {
+ return impl.get(key);
+ }
+
+ @Override
+ public ReadableState<Iterable<ValueT>> readLater() {
+ impl.get(key).prefetch();
+ return this;
+ }
+ };
+ }
+
+ @Override
+ public void remove(KeyT key) {
+ impl.remove(key);
+ }
+
+ @Override
+ public ReadableState<Iterable<KeyT>> keys() {
+ return new ReadableState<Iterable<KeyT>>() {
+ @Override
+ public Iterable<KeyT> read() {
+ return impl.keys();
+ }
+
+ @Override
+ public ReadableState<Iterable<KeyT>> readLater() {
+ impl.keys().prefetch();
+ return this;
+ }
+ };
+ }
+
+ @Override
+ public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>>
entries() {
+ return new ReadableState<Iterable<Map.Entry<KeyT,
ValueT>>>() {
+ @Override
+ public Iterable<Map.Entry<KeyT, ValueT>> read() {
+ return impl.entries();
+ }
+
+ @Override
+ public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>>
readLater() {
+ impl.entries().prefetch();
+ return this;
+ }
+ };
+ }
+
+ @Override
+ public ReadableState<Boolean> containsKey(KeyT key) {
+ return new ReadableState<Boolean>() {
+ @Override
+ public Boolean read() {
+ return !Iterables.isEmpty(impl.get(key));
+ }
+
+ @Override
+ public ReadableState<Boolean> readLater() {
+ impl.get(key).prefetch();
+ return this;
+ }
+ };
+ }
+
+ @Override
+ public ReadableState<Boolean> isEmpty() {
+ return new ReadableState<Boolean>() {
+ @Override
+ public Boolean read() {
+ return Iterables.isEmpty(impl.keys());
+ }
+
+ @Override
+ public ReadableState<Boolean> readLater() {
+ impl.keys().prefetch();
+ return this;
+ }
+ };
+ }
+
+ @Override
+ public void clear() {
+ impl.clear();
+ }
+ };
+ }
+ });
}
@Override
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
index 617faba87cc..8e3d76f5fc8 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
@@ -29,6 +29,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
+import java.util.Objects;
import java.util.Set;
import org.apache.beam.fn.harness.Cache;
import org.apache.beam.fn.harness.Caches;
@@ -38,13 +39,19 @@ import
org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.IterableCoder;
+import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
import org.apache.beam.sdk.fn.stream.PrefetchableIterables;
import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
import org.apache.beam.sdk.util.ByteStringOutputStream;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
/**
* An implementation of a multimap user state that utilizes the Beam Fn State
API to fetch, clear
@@ -52,9 +59,6 @@ import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
*
* <p>Calling {@link #asyncClose()} schedules any required persistence
changes. This object should
* no longer be used after it is closed.
- *
- * <p>TODO: Move to an async persist model where persistence is signalled
based upon cache memory
- * pressure and its need to flush.
*/
public class MultimapUserState<K, V> {
@@ -63,8 +67,10 @@ public class MultimapUserState<K, V> {
private final Coder<K> mapKeyCoder;
private final Coder<V> valueCoder;
private final StateRequest keysStateRequest;
+ private final StateRequest entriesStateRequest;
private final StateRequest userStateRequest;
private final CachingStateIterable<K> persistedKeys;
+ private final CachingStateIterable<KV<K, Iterable<V>>> persistedEntries;
private boolean isClosed;
private boolean isCleared;
@@ -90,6 +96,8 @@ public class MultimapUserState<K, V> {
this.mapKeyCoder = mapKeyCoder;
this.valueCoder = valueCoder;
+ // Note: These StateRequest protos are constructed even if we never try to
read the
+ // corresponding state type. Consider constructing them lazily, as needed.
this.keysStateRequest =
StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build();
this.persistedKeys =
@@ -106,6 +114,23 @@ public class MultimapUserState<K, V> {
.setWindow(stateKey.getMultimapKeysUserState().getWindow())
.setKey(stateKey.getMultimapKeysUserState().getKey());
this.userStateRequest = userStateRequestBuilder.build();
+
+ StateRequest.Builder entriesStateRequestBuilder =
StateRequest.newBuilder();
+ entriesStateRequestBuilder
+ .setInstructionId(instructionId)
+ .getStateKeyBuilder()
+ .getMultimapEntriesUserStateBuilder()
+ .setTransformId(stateKey.getMultimapKeysUserState().getTransformId())
+ .setUserStateId(stateKey.getMultimapKeysUserState().getUserStateId())
+ .setWindow(stateKey.getMultimapKeysUserState().getWindow())
+ .setKey(stateKey.getMultimapKeysUserState().getKey());
+ this.entriesStateRequest = entriesStateRequestBuilder.build();
+ this.persistedEntries =
+ StateFetchingIterators.readAllAndDecodeStartingFrom(
+ Caches.subCache(this.cache, "AllEntries"),
+ beamFnStateClient,
+ entriesStateRequest,
+ KvCoder.of(mapKeyCoder, IterableCoder.of(valueCoder)));
}
public void clear() {
@@ -200,7 +225,7 @@ public class MultimapUserState<K, V> {
nextKey = persistedKeysIterator.next();
Object nextKeyStructuralValue =
mapKeyCoder.structuralValue(nextKey);
if (!pendingRemovesNow.contains(nextKeyStructuralValue)) {
- // Remove all keys that we will visit when passing over the
persistedKeysIterator
+ // Remove all keys that we will visit when passing over the
persistedKeysIterator,
// so we do not revisit them when passing over the
pendingAddsNowIterator
if (pendingAddsNow.containsKey(nextKeyStructuralValue)) {
pendingAddsNow.remove(nextKeyStructuralValue);
@@ -235,6 +260,112 @@ public class MultimapUserState<K, V> {
};
}
+ @SuppressWarnings({
+ "nullness" // TODO(https://github.com/apache/beam/issues/21068)
+ })
+ /*
+ * Returns an Iterable containing all <K, V> entries in this multimap.
+ */
+ public PrefetchableIterable<Map.Entry<K, V>> entries() {
+ checkState(
+ !isClosed,
+ "Multimap user state is no longer usable because it is closed for %s",
+ keysStateRequest.getStateKey());
+ // Copy pendingAdds so this iterator reflects the state at the time it was created.
+ // NOTE(review): ImmutableMap.copyOf is shallow, so the List<V> values are still shared - confirm.
+ Map<Object, KV<K, List<V>>> pendingAddsNow =
ImmutableMap.copyOf(pendingAdds);
+ if (isCleared) {
+ return PrefetchableIterables.maybePrefetchable(
+ Iterables.concat(
+ Iterables.transform(
+ pendingAddsNow.entrySet(),
+ entry ->
+ Iterables.transform(
+ entry.getValue().getValue(),
+ value ->
Maps.immutableEntry(entry.getValue().getKey(), value)))));
+ }
+
+ Set<Object> pendingRemovesNow =
ImmutableSet.copyOf(pendingRemoves.keySet());
+ return new PrefetchableIterables.Default<Map.Entry<K, V>>() {
+ @Override
+ public PrefetchableIterator<Map.Entry<K, V>> createIterator() {
+ return new PrefetchableIterator<Map.Entry<K, V>>() {
+ // We can get the same key multiple times from persistedEntries in
the case that its
+ // values are paginated across multiple pages. Keep track of which
keys we've seen, so we
+ // only add in pendingAdds once (with the first page). We'll also
use it to return all
+ // keys not on the backend at the end of the iterator.
+ Set<Object> seenKeys = Sets.newHashSet();
+ final PrefetchableIterator<Map.Entry<K, V>> allEntries =
+ PrefetchableIterables.concat(
+ Iterables.concat(
+ Iterables.filter(
+ Iterables.transform(
+ persistedEntries,
+ entry -> {
+ final Object structuralKey =
+
mapKeyCoder.structuralValue(entry.getKey());
+ if
(pendingRemovesNow.contains(structuralKey)) {
+ return null;
+ }
+ // add returns true if we haven't seen
this key yet.
+ if (seenKeys.add(structuralKey)
+ &&
pendingAddsNow.containsKey(structuralKey)) {
+ return PrefetchableIterables.concat(
+ Iterables.transform(
+
pendingAddsNow.get(structuralKey).getValue(),
+ pendingAdd ->
+
Maps.immutableEntry(entry.getKey(), pendingAdd)),
+ Iterables.transform(
+ entry.getValue(),
+ value ->
Maps.immutableEntry(entry.getKey(), value)));
+ }
+ return Iterables.transform(
+ entry.getValue(),
+ value ->
Maps.immutableEntry(entry.getKey(), value));
+ }),
+ Objects::nonNull)),
+ Iterables.concat(
+ Iterables.filter(
+ Iterables.transform(
+ pendingAddsNow.entrySet(),
+ entry -> {
+ if (seenKeys.contains(entry.getKey())) {
+ return null;
+ }
+ return Iterables.transform(
+ entry.getValue().getValue(),
+ value ->
+
Maps.immutableEntry(entry.getValue().getKey(), value));
+ }),
+ Objects::nonNull)))
+ .iterator();
+
+ @Override
+ public boolean isReady() {
+ return allEntries.isReady();
+ }
+
+ @Override
+ public void prefetch() {
+ if (!isReady()) {
+ allEntries.prefetch();
+ }
+ }
+
+ @Override
+ public boolean hasNext() {
+ return allEntries.hasNext();
+ }
+
+ @Override
+ public Map.Entry<K, V> next() {
+ return allEntries.next();
+ }
+ };
+ }
+ };
+ }
+
/*
* Store a key-value pair in the multimap.
* Allows duplicate key-value pairs.
diff --git
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
index 48c9ce43bdf..67930732182 100644
---
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
+++
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
@@ -22,6 +22,7 @@ import static java.util.Collections.singletonList;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.emptyIterable;
import static org.hamcrest.collection.ArrayMatching.arrayContainingInAnyOrder;
+import static
org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder;
import static org.hamcrest.core.Is.is;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
@@ -34,11 +35,15 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
import org.apache.beam.fn.harness.Cache;
import org.apache.beam.fn.harness.Caches;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.IterableCoder;
+import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.NullableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
@@ -179,6 +184,81 @@ public class MultimapUserStateTest {
assertThrows(IllegalStateException.class, () -> userState.keys());
}
+ @Test
+ public void testEntries() throws Exception {
+ FakeBeamFnStateClient fakeClient =
+ new FakeBeamFnStateClient(
+ ImmutableMap.of(
+ createMultimapEntriesStateKey(),
+ KV.of(
+ KvCoder.of(ByteArrayCoder.of(),
IterableCoder.of(StringUtf8Coder.of())),
+ asList(KV.of(A1, asList("V1", "V2")), KV.of(A2,
asList("V3"))))));
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ Caches.noop(),
+ fakeClient,
+ "instructionId",
+ createMultimapKeyStateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+
+ assertArrayEquals(A1, userState.entries().iterator().next().getKey());
+ assertThat(
+ StreamSupport.stream(userState.entries().spliterator(), false)
+ .map(entry -> KV.of(ByteString.copyFrom(entry.getKey()),
entry.getValue()))
+ .collect(Collectors.toList()),
+ containsInAnyOrder(
+ KV.of(ByteString.copyFrom(A1), "V1"),
+ KV.of(ByteString.copyFrom(A1), "V2"),
+ KV.of(ByteString.copyFrom(A2), "V3")));
+
+ userState.put(A1, "V4");
+ // Iterable is a snapshot of the entries at this time.
+ PrefetchableIterable<Map.Entry<byte[], String>> entriesBeforeOperations =
userState.entries();
+
+ assertThat(
+ StreamSupport.stream(userState.entries().spliterator(), false)
+ .map(entry -> KV.of(ByteString.copyFrom(entry.getKey()),
entry.getValue()))
+ .collect(Collectors.toList()),
+ containsInAnyOrder(
+ KV.of(ByteString.copyFrom(A1), "V1"),
+ KV.of(ByteString.copyFrom(A1), "V2"),
+ KV.of(ByteString.copyFrom(A2), "V3"),
+ KV.of(ByteString.copyFrom(A1), "V4")));
+
+ userState.remove(A1);
+ assertThat(
+ StreamSupport.stream(userState.entries().spliterator(), false)
+ .map(entry -> KV.of(ByteString.copyFrom(entry.getKey()),
entry.getValue()))
+ .collect(Collectors.toList()),
+ containsInAnyOrder(KV.of(ByteString.copyFrom(A2), "V3")));
+
+ userState.put(A1, "V5");
+ assertThat(
+ StreamSupport.stream(userState.entries().spliterator(), false)
+ .map(entry -> KV.of(ByteString.copyFrom(entry.getKey()),
entry.getValue()))
+ .collect(Collectors.toList()),
+ containsInAnyOrder(
+ KV.of(ByteString.copyFrom(A2), "V3"),
KV.of(ByteString.copyFrom(A1), "V5")));
+
+ userState.clear();
+ assertThat(userState.entries(), emptyIterable());
+ // Check that after applying all these operations, our original entries
Iterable contains a
+ // snapshot of state from when it was created.
+ assertThat(
+ StreamSupport.stream(entriesBeforeOperations.spliterator(), false)
+ .map(entry -> KV.of(ByteString.copyFrom(entry.getKey()),
entry.getValue()))
+ .collect(Collectors.toList()),
+ containsInAnyOrder(
+ KV.of(ByteString.copyFrom(A1), "V1"),
+ KV.of(ByteString.copyFrom(A1), "V2"),
+ KV.of(ByteString.copyFrom(A1), "V4"),
+ KV.of(ByteString.copyFrom(A2), "V3")));
+
+ userState.asyncClose();
+ assertThrows(IllegalStateException.class, () -> userState.entries());
+ }
+
@Test
public void testPut() throws Exception {
FakeBeamFnStateClient fakeClient =
@@ -620,6 +700,44 @@ public class MultimapUserStateTest {
assertEquals(0, fakeClient.getCallCount());
}
+ @Test
+ public void testEntriesPrefetched() throws Exception {
+ // Use a really large chunk size so all elements get returned in a single
page. This makes it
+ // easier to count how many get calls we should expect.
+ FakeBeamFnStateClient fakeClient =
+ new FakeBeamFnStateClient(
+ ImmutableMap.of(
+ createMultimapEntriesStateKey(),
+ KV.of(
+ KvCoder.of(ByteArrayCoder.of(),
IterableCoder.of(StringUtf8Coder.of())),
+ asList(KV.of(A1, asList("V1", "V2")), KV.of(A2,
asList("V3"))))),
+ 1000000);
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ Caches.noop(),
+ fakeClient,
+ "instructionId",
+ createMultimapKeyStateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+
+ userState.put(A1, "V4");
+ PrefetchableIterable<Map.Entry<byte[], String>> entries =
userState.entries();
+ assertEquals(0, fakeClient.getCallCount());
+ entries.prefetch();
+ assertEquals(1, fakeClient.getCallCount());
+ assertThat(
+ StreamSupport.stream(entries.spliterator(), false)
+ .map(entry -> KV.of(ByteString.copyFrom(entry.getKey()),
entry.getValue()))
+ .collect(Collectors.toList()),
+ containsInAnyOrder(
+ KV.of(ByteString.copyFrom(A1), "V1"),
+ KV.of(ByteString.copyFrom(A1), "V2"),
+ KV.of(ByteString.copyFrom(A1), "V4"),
+ KV.of(ByteString.copyFrom(A2), "V3")));
+ assertEquals(1, fakeClient.getCallCount());
+ }
+
@Test
public void testClearPrefetch() throws Exception {
FakeBeamFnStateClient fakeClient =
@@ -1053,6 +1171,17 @@ public class MultimapUserStateTest {
.build();
}
+ private StateKey createMultimapEntriesStateKey() throws IOException {
+ return StateKey.newBuilder()
+ .setMultimapEntriesUserState(
+ StateKey.MultimapEntriesUserState.newBuilder()
+ .setWindow(encode(encodedWindow))
+ .setKey(encode(encodedKey))
+ .setTransformId(pTransformId)
+ .setUserStateId(stateId))
+ .build();
+ }
+
private StateKey createMultimapValueStateKey(byte[] key) throws IOException {
return StateKey.newBuilder()
.setMultimapUserState(