This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 21dbf592f87 Adds Multimap support to JAVA FnApi (#36218)
21dbf592f87 is described below

commit 21dbf592f87d17f2f6e863323c9b946ae6b09b4a
Author: Andrew Crites <[email protected]>
AuthorDate: Tue Oct 28 12:25:11 2025 -0700

    Adds Multimap support to JAVA FnApi (#36218)
    
    * Changes multimap state keys() tests to not care about order. There is no 
guarantee on the order in which keys are returned. Also fixes a couple of warnings 
from other FnApi tests.
    
    * Adds Multimap user state support to the Java FnApi harness. Also adds a 
missing FnApi state proto to get all of the entries of a multimap. This type of 
access is part of the state API (and supported by the non-portable harness), 
but was not present in the protos.
    
    * Adds FnApi binding for entries() method.
    
    * Changes multimap entries() iterable to put values for the same key from 
the backend and local adds together. Also needed to make maybePrefetchable 
public.
    
    * Adds a test that prefetching multimap entries results in a StateRequest 
sent across FnApi.
    
    * Adds an environment capability for multimap state and sets it for the 
Java SDK.
---
 .../beam/model/fn_execution/v1/beam_fn_api.proto   |  24 ++++
 .../beam/model/pipeline/v1/beam_runner_api.proto   |  10 +-
 .../beam/sdk/fn/stream/PrefetchableIterables.java  |   2 +-
 .../beam/sdk/util/construction/Environments.java   |   1 +
 .../org/apache/beam/sdk/transforms/ParDoTest.java  |  67 ++++++++++
 .../sdk/util/construction/EnvironmentsTest.java    |   3 +
 .../beam/fn/harness/state/FnApiStateAccessor.java  | 121 ++++++++++++++++--
 .../beam/fn/harness/state/MultimapUserState.java   | 139 ++++++++++++++++++++-
 .../fn/harness/state/MultimapUserStateTest.java    | 129 +++++++++++++++++++
 9 files changed, 480 insertions(+), 16 deletions(-)

diff --git 
a/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto
 
b/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto
index 9b32048b499..4eee2ef5d89 100644
--- 
a/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto
+++ 
b/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto
@@ -1017,6 +1017,29 @@ message StateKey {
     bytes key = 4;
   }
 
+  // Represents a request for all of the entries of a multimap associated with 
a
+  // specified user key and window for a PTransform. See
+  // https://s.apache.org/beam-fn-state-api-and-bundle-processing for further
+  // details.
+  //
+  // Can only be used to perform StateGetRequests and StateClearRequests on the
+  // user state.
+  //
+  // The response data stream will be a concatenation of pairs, where the first
+  // component is the map key and the second component is a concatenation of
+  // values associated with that map key.
+  message MultimapEntriesUserState {
+    // (Required) The id of the PTransform containing user state.
+    string transform_id = 1;
+    // (Required) The id of the user state.
+    string user_state_id = 2;
+    // (Required) The window encoded in a nested context.
+    bytes window = 3;
+    // (Required) The key of the currently executing element encoded in a
+    // nested context.
+    bytes key = 4;
+  }
+
   // Represents a request for the values of the map key associated with a
   // specified user key and window for a PTransform. See
   // https://s.apache.org/beam-fn-state-api-and-bundle-processing for further
@@ -1072,6 +1095,7 @@ message StateKey {
     MultimapKeysSideInput multimap_keys_side_input = 5;
     MultimapKeysValuesSideInput multimap_keys_values_side_input = 8;
     MultimapKeysUserState multimap_keys_user_state = 6;
+    MultimapEntriesUserState multimap_entries_user_state = 10;
     MultimapUserState multimap_user_state = 7;
     OrderedListUserState ordered_list_user_state = 9;
   }
diff --git 
a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto
 
b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto
index c615b2a5279..0bdc4f69aab 100644
--- 
a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto
+++ 
b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto
@@ -1621,13 +1621,13 @@ message AnyOfEnvironmentPayload {
 // environment understands.
 message StandardProtocols {
   enum Enum {
-    // Indicates suport for progress reporting via the legacy Metrics proto.
+    // Indicates support for progress reporting via the legacy Metrics proto.
     LEGACY_PROGRESS_REPORTING = 0 [(beam_urn) = 
"beam:protocol:progress_reporting:v0"];
 
-    // Indicates suport for progress reporting via the new MonitoringInfo 
proto.
+    // Indicates support for progress reporting via the new MonitoringInfo 
proto.
     PROGRESS_REPORTING = 1 [(beam_urn) = 
"beam:protocol:progress_reporting:v1"];
 
-    // Indicates suport for worker status protocol defined at
+    // Indicates support for worker status protocol defined at
     // https://s.apache.org/beam-fn-api-harness-status.
     WORKER_STATUS = 2 [(beam_urn) = "beam:protocol:worker_status:v1"];
 
@@ -1681,6 +1681,10 @@ message StandardProtocols {
     // Indicates support for reading, writing and propagating Element's 
metadata
     ELEMENT_METADATA = 11
     [(beam_urn) = "beam:protocol:element_metadata:v1"];
+
+    // Indicates support for multimap state.
+    MULTIMAP_STATE = 12
+    [(beam_urn) = "beam:protocol:multimap_state:v1"];
   }
 }
 
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java
index dd7ec6b0f65..1f7451e72a2 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java
@@ -94,7 +94,7 @@ public class PrefetchableIterables {
    * constructed that ensures that {@link PrefetchableIterator#prefetch()} is 
a no-op and {@link
    * PrefetchableIterator#isReady()} always returns true.
    */
-  private static <T> PrefetchableIterable<T> maybePrefetchable(Iterable<T> 
iterable) {
+  public static <T> PrefetchableIterable<T> maybePrefetchable(Iterable<T> 
iterable) {
     if (iterable instanceof PrefetchableIterable) {
       return (PrefetchableIterable<T>) iterable;
     }
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java
index 55379bf3a80..969bda88d07 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java
@@ -521,6 +521,7 @@ public class Environments {
     capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.DATA_SAMPLING));
     
capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.SDK_CONSUMING_RECEIVED_DATA));
     
capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.ORDERED_LIST_STATE));
+    capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.MULTIMAP_STATE));
     return capabilities.build();
   }
 
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
index 8409133772e..8a273127b4f 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
@@ -2917,6 +2917,73 @@ public class ParDoTest implements Serializable {
       pipeline.run();
     }
 
+    @Test
+    @Category({ValidatesRunner.class, UsesStatefulParDo.class, 
UsesMultimapState.class})
+    public void testMultimapStateEntries() {
+      final String stateId = "foo:";
+      final String countStateId = "count";
+      DoFn<KV<String, KV<String, Integer>>, KV<String, Integer>> fn =
+          new DoFn<KV<String, KV<String, Integer>>, KV<String, Integer>>() {
+
+            @StateId(stateId)
+            private final StateSpec<MultimapState<String, Integer>> 
multimapState =
+                StateSpecs.multimap(StringUtf8Coder.of(), VarIntCoder.of());
+
+            @StateId(countStateId)
+            private final StateSpec<CombiningState<Integer, int[], Integer>> 
countState =
+                StateSpecs.combiningFromInputInternal(VarIntCoder.of(), 
Sum.ofIntegers());
+
+            @ProcessElement
+            public void processElement(
+                ProcessContext c,
+                @Element KV<String, KV<String, Integer>> element,
+                @StateId(stateId) MultimapState<String, Integer> state,
+                @StateId(countStateId) CombiningState<Integer, int[], Integer> 
count,
+                OutputReceiver<KV<String, Integer>> r) {
+              // Empty before we process any elements.
+              if (count.read() == 0) {
+                assertThat(state.entries().read(), emptyIterable());
+              }
+              assertEquals(count.read().intValue(), 
Iterables.size(state.entries().read()));
+
+              KV<String, Integer> value = element.getValue();
+              state.put(value.getKey(), value.getValue());
+              count.add(1);
+
+              if (count.read() >= 4) {
+                // This should be evaluated only when ReadableState.read is 
called.
+                ReadableState<Iterable<Entry<String, Integer>>> entriesView = 
state.entries();
+
+                // This is evaluated immediately.
+                Iterable<Entry<String, Integer>> entries = 
state.entries().read();
+
+                state.remove("b");
+                assertEquals(4, Iterables.size(entries));
+                state.put("a", 2);
+                state.put("a", 3);
+
+                assertEquals(5, Iterables.size(entriesView.read()));
+                // Note we output the view of state before the modifications 
in this if statement.
+                for (Entry<String, Integer> entry : entries) {
+                  r.output(KV.of(entry.getKey(), entry.getValue()));
+                }
+              }
+            }
+          };
+      PCollection<KV<String, Integer>> output =
+          pipeline
+              .apply(
+                  Create.of(
+                      KV.of("hello", KV.of("a", 97)), KV.of("hello", 
KV.of("a", 97)),
+                      KV.of("hello", KV.of("a", 98)), KV.of("hello", 
KV.of("b", 33))))
+              .apply(ParDo.of(fn));
+      PAssert.that(output)
+          .containsInAnyOrder(
+              KV.of("a", 97), KV.of("a", 97),
+              KV.of("a", 98), KV.of("b", 33));
+      pipeline.run();
+    }
+
     @Test
     @Category({ValidatesRunner.class, UsesStatefulParDo.class, 
UsesMultimapState.class})
     public void testMultimapStateRemove() {
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/EnvironmentsTest.java
 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/EnvironmentsTest.java
index f12a2a77f99..ebd4e9fbe24 100644
--- 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/EnvironmentsTest.java
+++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/EnvironmentsTest.java
@@ -219,6 +219,9 @@ public class EnvironmentsTest implements Serializable {
     assertThat(
         Environments.getJavaCapabilities(),
         
hasItem(BeamUrns.getUrn(RunnerApi.StandardProtocols.Enum.ORDERED_LIST_STATE)));
+    assertThat(
+        Environments.getJavaCapabilities(),
+        
hasItem(BeamUrns.getUrn(RunnerApi.StandardProtocols.Enum.MULTIMAP_STATE)));
     // Check that SDF truncation is supported
     assertThat(
         Environments.getJavaCapabilities(),
diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
index e06a82c8e25..6913c75a5f2 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
@@ -117,7 +117,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
     public Factory(
         PipelineOptions pipelineOptions,
-        Set<String> runnerCapabilites,
+        Set<String> runnerCapabilities,
         String ptransformId,
         Supplier<String> processBundleInstructionId,
         Supplier<List<CacheToken>> cacheTokens,
@@ -128,7 +128,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
         Coder<K> keyCoder,
         Coder<BoundedWindow> windowCoder) {
       this.pipelineOptions = pipelineOptions;
-      this.runnerCapabilities = runnerCapabilites;
+      this.runnerCapabilities = runnerCapabilities;
       this.ptransformId = ptransformId;
       this.processBundleInstructionId = processBundleInstructionId;
       this.cacheTokens = cacheTokens;
@@ -240,7 +240,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
   }
 
   private final PipelineOptions pipelineOptions;
-  private final Set<String> runnerCapabilites;
+  private final Set<String> runnerCapabilities;
   private final Map<StateKey, Object> stateKeyObjectCache;
   private final Map<TupleTag<?>, SideInputSpec> sideInputSpecMap;
   private final BeamFnStateClient beamFnStateClient;
@@ -259,7 +259,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
   public FnApiStateAccessor(
       PipelineOptions pipelineOptions,
-      Set<String> runnerCapabilites,
+      Set<String> runnerCapabilities,
       String ptransformId,
       Supplier<String> processBundleInstructionId,
       Supplier<List<CacheToken>> cacheTokens,
@@ -270,7 +270,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
       Coder<K> keyCoder,
       Coder<BoundedWindow> windowCoder) {
     this.pipelineOptions = pipelineOptions;
-    this.runnerCapabilites = runnerCapabilites;
+    this.runnerCapabilities = runnerCapabilities;
     this.stateKeyObjectCache = Maps.newHashMap();
     this.sideInputSpecMap = sideInputSpecMap;
     this.beamFnStateClient = beamFnStateClient;
@@ -414,7 +414,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
                               key,
                               ((KvCoder) 
sideInputSpec.getCoder()).getKeyCoder(),
                               ((KvCoder) 
sideInputSpec.getCoder()).getValueCoder(),
-                              runnerCapabilites.contains(
+                              runnerCapabilities.contains(
                                   BeamUrns.getUrn(
                                       RunnerApi.StandardRunnerProtocols.Enum
                                           .MULTIMAP_KEYS_VALUES_SIDE_INPUT))));
@@ -762,8 +762,113 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
       StateSpec<MultimapState<KeyT, ValueT>> spec,
       Coder<KeyT> keyCoder,
       Coder<ValueT> valueCoder) {
-    // TODO(https://github.com/apache/beam/issues/23616)
-    throw new UnsupportedOperationException("Multimap is not currently 
supported with Fn API.");
+    return (MultimapState<KeyT, ValueT>)
+        stateKeyObjectCache.computeIfAbsent(
+            createMultimapKeysUserStateKey(id),
+            new Function<StateKey, Object>() {
+              @Override
+              public Object apply(StateKey stateKey) {
+                return new MultimapState<KeyT, ValueT>() {
+                  private final MultimapUserState<KeyT, ValueT> impl =
+                      createMultimapUserState(stateKey, keyCoder, valueCoder);
+
+                  @Override
+                  public void put(KeyT key, ValueT value) {
+                    impl.put(key, value);
+                  }
+
+                  @Override
+                  public ReadableState<Iterable<ValueT>> get(KeyT key) {
+                    return new ReadableState<Iterable<ValueT>>() {
+                      @Override
+                      public Iterable<ValueT> read() {
+                        return impl.get(key);
+                      }
+
+                      @Override
+                      public ReadableState<Iterable<ValueT>> readLater() {
+                        impl.get(key).prefetch();
+                        return this;
+                      }
+                    };
+                  }
+
+                  @Override
+                  public void remove(KeyT key) {
+                    impl.remove(key);
+                  }
+
+                  @Override
+                  public ReadableState<Iterable<KeyT>> keys() {
+                    return new ReadableState<Iterable<KeyT>>() {
+                      @Override
+                      public Iterable<KeyT> read() {
+                        return impl.keys();
+                      }
+
+                      @Override
+                      public ReadableState<Iterable<KeyT>> readLater() {
+                        impl.keys().prefetch();
+                        return this;
+                      }
+                    };
+                  }
+
+                  @Override
+                  public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>> 
entries() {
+                    return new ReadableState<Iterable<Map.Entry<KeyT, 
ValueT>>>() {
+                      @Override
+                      public Iterable<Map.Entry<KeyT, ValueT>> read() {
+                        return impl.entries();
+                      }
+
+                      @Override
+                      public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>> 
readLater() {
+                        impl.entries().prefetch();
+                        return this;
+                      }
+                    };
+                  }
+
+                  @Override
+                  public ReadableState<Boolean> containsKey(KeyT key) {
+                    return new ReadableState<Boolean>() {
+                      @Override
+                      public Boolean read() {
+                        return !Iterables.isEmpty(impl.get(key));
+                      }
+
+                      @Override
+                      public ReadableState<Boolean> readLater() {
+                        impl.get(key).prefetch();
+                        return this;
+                      }
+                    };
+                  }
+
+                  @Override
+                  public ReadableState<Boolean> isEmpty() {
+                    return new ReadableState<Boolean>() {
+                      @Override
+                      public Boolean read() {
+                        return Iterables.isEmpty(impl.keys());
+                      }
+
+                      @Override
+                      public ReadableState<Boolean> readLater() {
+                        impl.keys().prefetch();
+                        return this;
+                      }
+                    };
+                  }
+
+                  @Override
+                  public void clear() {
+                    impl.clear();
+                  }
+                };
+              }
+            });
   }
 
   @Override
diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
index 617faba87cc..8e3d76f5fc8 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
@@ -29,6 +29,7 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.NoSuchElementException;
+import java.util.Objects;
 import java.util.Set;
 import org.apache.beam.fn.harness.Cache;
 import org.apache.beam.fn.harness.Caches;
@@ -38,13 +39,19 @@ import 
org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.IterableCoder;
+import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
 import org.apache.beam.sdk.fn.stream.PrefetchableIterables;
 import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
 import org.apache.beam.sdk.util.ByteStringOutputStream;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
 
 /**
  * An implementation of a multimap user state that utilizes the Beam Fn State 
API to fetch, clear
@@ -52,9 +59,6 @@ import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
  *
  * <p>Calling {@link #asyncClose()} schedules any required persistence 
changes. This object should
  * no longer be used after it is closed.
- *
- * <p>TODO: Move to an async persist model where persistence is signalled 
based upon cache memory
- * pressure and its need to flush.
  */
 public class MultimapUserState<K, V> {
 
@@ -63,8 +67,10 @@ public class MultimapUserState<K, V> {
   private final Coder<K> mapKeyCoder;
   private final Coder<V> valueCoder;
   private final StateRequest keysStateRequest;
+  private final StateRequest entriesStateRequest;
   private final StateRequest userStateRequest;
   private final CachingStateIterable<K> persistedKeys;
+  private final CachingStateIterable<KV<K, Iterable<V>>> persistedEntries;
 
   private boolean isClosed;
   private boolean isCleared;
@@ -90,6 +96,8 @@ public class MultimapUserState<K, V> {
     this.mapKeyCoder = mapKeyCoder;
     this.valueCoder = valueCoder;
 
+    // Note: These StateRequest protos are constructed even if we never try to 
read the
+    // corresponding state type. Consider constructing them lazily, as needed.
     this.keysStateRequest =
         
StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build();
     this.persistedKeys =
@@ -106,6 +114,23 @@ public class MultimapUserState<K, V> {
         .setWindow(stateKey.getMultimapKeysUserState().getWindow())
         .setKey(stateKey.getMultimapKeysUserState().getKey());
     this.userStateRequest = userStateRequestBuilder.build();
+
+    StateRequest.Builder entriesStateRequestBuilder = 
StateRequest.newBuilder();
+    entriesStateRequestBuilder
+        .setInstructionId(instructionId)
+        .getStateKeyBuilder()
+        .getMultimapEntriesUserStateBuilder()
+        .setTransformId(stateKey.getMultimapKeysUserState().getTransformId())
+        .setUserStateId(stateKey.getMultimapKeysUserState().getUserStateId())
+        .setWindow(stateKey.getMultimapKeysUserState().getWindow())
+        .setKey(stateKey.getMultimapKeysUserState().getKey());
+    this.entriesStateRequest = entriesStateRequestBuilder.build();
+    this.persistedEntries =
+        StateFetchingIterators.readAllAndDecodeStartingFrom(
+            Caches.subCache(this.cache, "AllEntries"),
+            beamFnStateClient,
+            entriesStateRequest,
+            KvCoder.of(mapKeyCoder, IterableCoder.of(valueCoder)));
   }
 
   public void clear() {
@@ -200,7 +225,7 @@ public class MultimapUserState<K, V> {
               nextKey = persistedKeysIterator.next();
               Object nextKeyStructuralValue = 
mapKeyCoder.structuralValue(nextKey);
               if (!pendingRemovesNow.contains(nextKeyStructuralValue)) {
-                // Remove all keys that we will visit when passing over the 
persistedKeysIterator
+                // Remove all keys that we will visit when passing over the 
persistedKeysIterator,
                 // so we do not revisit them when passing over the 
pendingAddsNowIterator
                 if (pendingAddsNow.containsKey(nextKeyStructuralValue)) {
                   pendingAddsNow.remove(nextKeyStructuralValue);
@@ -235,6 +260,112 @@ public class MultimapUserState<K, V> {
     };
   }
 
+  @SuppressWarnings({
+    "nullness" // TODO(https://github.com/apache/beam/issues/21068)
+  })
+  /*
+   * Returns an Iterable containing all <K, V> entries in this multimap.
+   */
+  public PrefetchableIterable<Map.Entry<K, V>> entries() {
+    checkState(
+        !isClosed,
+        "Multimap user state is no longer usable because it is closed for %s",
+        keysStateRequest.getStateKey());
+    // Copy pendingAdds so this iterator sees a snapshot of the pending keys. Note that
+    // ImmutableMap.copyOf is shallow: the per-key value lists are still shared with pendingAdds.
+    Map<Object, KV<K, List<V>>> pendingAddsNow = 
ImmutableMap.copyOf(pendingAdds);
+    if (isCleared) {
+      return PrefetchableIterables.maybePrefetchable(
+          Iterables.concat(
+              Iterables.transform(
+                  pendingAddsNow.entrySet(),
+                  entry ->
+                      Iterables.transform(
+                          entry.getValue().getValue(),
+                          value -> 
Maps.immutableEntry(entry.getValue().getKey(), value)))));
+    }
+
+    Set<Object> pendingRemovesNow = 
ImmutableSet.copyOf(pendingRemoves.keySet());
+    return new PrefetchableIterables.Default<Map.Entry<K, V>>() {
+      @Override
+      public PrefetchableIterator<Map.Entry<K, V>> createIterator() {
+        return new PrefetchableIterator<Map.Entry<K, V>>() {
+          // We can get the same key multiple times from persistedEntries in 
the case that its
+          // values are paginated across multiple pages. Keep track of which 
keys we've seen, so we
+          // only add in pendingAdds once (with the first page). We'll also 
use it to return all
+          // keys not on the backend at the end of the iterator.
+          Set<Object> seenKeys = Sets.newHashSet();
+          final PrefetchableIterator<Map.Entry<K, V>> allEntries =
+              PrefetchableIterables.concat(
+                      Iterables.concat(
+                          Iterables.filter(
+                              Iterables.transform(
+                                  persistedEntries,
+                                  entry -> {
+                                    final Object structuralKey =
+                                        
mapKeyCoder.structuralValue(entry.getKey());
+                                    if 
(pendingRemovesNow.contains(structuralKey)) {
+                                      return null;
+                                    }
+                                    // add returns true if we haven't seen 
this key yet.
+                                    if (seenKeys.add(structuralKey)
+                                        && 
pendingAddsNow.containsKey(structuralKey)) {
+                                      return PrefetchableIterables.concat(
+                                          Iterables.transform(
+                                              
pendingAddsNow.get(structuralKey).getValue(),
+                                              pendingAdd ->
+                                                  
Maps.immutableEntry(entry.getKey(), pendingAdd)),
+                                          Iterables.transform(
+                                              entry.getValue(),
+                                              value -> 
Maps.immutableEntry(entry.getKey(), value)));
+                                    }
+                                    return Iterables.transform(
+                                        entry.getValue(),
+                                        value -> 
Maps.immutableEntry(entry.getKey(), value));
+                                  }),
+                              Objects::nonNull)),
+                      Iterables.concat(
+                          Iterables.filter(
+                              Iterables.transform(
+                                  pendingAddsNow.entrySet(),
+                                  entry -> {
+                                    if (seenKeys.contains(entry.getKey())) {
+                                      return null;
+                                    }
+                                    return Iterables.transform(
+                                        entry.getValue().getValue(),
+                                        value ->
+                                            
Maps.immutableEntry(entry.getValue().getKey(), value));
+                                  }),
+                              Objects::nonNull)))
+                  .iterator();
+
+          @Override
+          public boolean isReady() {
+            return allEntries.isReady();
+          }
+
+          @Override
+          public void prefetch() {
+            if (!isReady()) {
+              allEntries.prefetch();
+            }
+          }
+
+          @Override
+          public boolean hasNext() {
+            return allEntries.hasNext();
+          }
+
+          @Override
+          public Map.Entry<K, V> next() {
+            return allEntries.next();
+          }
+        };
+      }
+    };
+  }
+
   /*
    * Store a key-value pair in the multimap.
    * Allows duplicate key-value pairs.
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
index 48c9ce43bdf..67930732182 100644
--- 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
@@ -22,6 +22,7 @@ import static java.util.Collections.singletonList;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.emptyIterable;
 import static org.hamcrest.collection.ArrayMatching.arrayContainingInAnyOrder;
+import static 
org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder;
 import static org.hamcrest.core.Is.is;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
@@ -34,11 +35,15 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
 import org.apache.beam.fn.harness.Cache;
 import org.apache.beam.fn.harness.Caches;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
 import org.apache.beam.sdk.coders.ByteArrayCoder;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.IterableCoder;
+import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.NullableCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
@@ -179,6 +184,81 @@ public class MultimapUserStateTest {
     assertThrows(IllegalStateException.class, () -> userState.keys());
   }
 
+  /**
+   * Checks that {@code entries()} returns backend and locally added values, reflects {@code put},
+   * {@code remove} and {@code clear}, hands out iterables that keep the entries present when they
+   * were created, and throws {@link IllegalStateException} after {@code asyncClose()}.
+   */
+  @Test
+  public void testEntries() throws Exception {
+    FakeBeamFnStateClient fakeClient =
+        new FakeBeamFnStateClient(
+            ImmutableMap.of(
+                createMultimapEntriesStateKey(),
+                KV.of(
+                    KvCoder.of(ByteArrayCoder.of(), IterableCoder.of(StringUtf8Coder.of())),
+                    asList(KV.of(A1, asList("V1", "V2")), KV.of(A2, asList("V3"))))));
+    MultimapUserState<byte[], String> userState =
+        new MultimapUserState<>(
+            Caches.noop(),
+            fakeClient,
+            "instructionId",
+            createMultimapKeyStateKey(),
+            ByteArrayCoder.of(),
+            StringUtf8Coder.of());
+
+    // Entry order is not something these tests rely on (all other checks use containsInAnyOrder),
+    // so only verify that the first entry of a fresh iterator carries one of the known keys.
+    ByteString firstKey = ByteString.copyFrom(userState.entries().iterator().next().getKey());
+    assertThat(
+        firstKey.equals(ByteString.copyFrom(A1)) || firstKey.equals(ByteString.copyFrom(A2)),
+        is(true));
+    assertThat(
+        StreamSupport.stream(userState.entries().spliterator(), false)
+            .map(entry -> KV.of(ByteString.copyFrom(entry.getKey()), entry.getValue()))
+            .collect(Collectors.toList()),
+        containsInAnyOrder(
+            KV.of(ByteString.copyFrom(A1), "V1"),
+            KV.of(ByteString.copyFrom(A1), "V2"),
+            KV.of(ByteString.copyFrom(A2), "V3")));
+
+    userState.put(A1, "V4");
+    // Iterable is a snapshot of the entries at this time.
+    PrefetchableIterable<Map.Entry<byte[], String>> entriesBeforeOperations = userState.entries();
+
+    assertThat(
+        StreamSupport.stream(userState.entries().spliterator(), false)
+            .map(entry -> KV.of(ByteString.copyFrom(entry.getKey()), entry.getValue()))
+            .collect(Collectors.toList()),
+        containsInAnyOrder(
+            KV.of(ByteString.copyFrom(A1), "V1"),
+            KV.of(ByteString.copyFrom(A1), "V2"),
+            KV.of(ByteString.copyFrom(A2), "V3"),
+            KV.of(ByteString.copyFrom(A1), "V4")));
+
+    userState.remove(A1);
+    assertThat(
+        StreamSupport.stream(userState.entries().spliterator(), false)
+            .map(entry -> KV.of(ByteString.copyFrom(entry.getKey()), entry.getValue()))
+            .collect(Collectors.toList()),
+        containsInAnyOrder(KV.of(ByteString.copyFrom(A2), "V3")));
+
+    userState.put(A1, "V5");
+    assertThat(
+        StreamSupport.stream(userState.entries().spliterator(), false)
+            .map(entry -> KV.of(ByteString.copyFrom(entry.getKey()), entry.getValue()))
+            .collect(Collectors.toList()),
+        containsInAnyOrder(
+            KV.of(ByteString.copyFrom(A2), "V3"), KV.of(ByteString.copyFrom(A1), "V5")));
+
+    userState.clear();
+    assertThat(userState.entries(), emptyIterable());
+    // Check that after applying all these operations, our original entries Iterable contains a
+    // snapshot of state from when it was created.
+    assertThat(
+        StreamSupport.stream(entriesBeforeOperations.spliterator(), false)
+            .map(entry -> KV.of(ByteString.copyFrom(entry.getKey()), entry.getValue()))
+            .collect(Collectors.toList()),
+        containsInAnyOrder(
+            KV.of(ByteString.copyFrom(A1), "V1"),
+            KV.of(ByteString.copyFrom(A1), "V2"),
+            KV.of(ByteString.copyFrom(A1), "V4"),
+            KV.of(ByteString.copyFrom(A2), "V3")));
+
+    userState.asyncClose();
+    assertThrows(IllegalStateException.class, () -> userState.entries());
+  }
+
   @Test
   public void testPut() throws Exception {
     FakeBeamFnStateClient fakeClient =
@@ -620,6 +700,44 @@ public class MultimapUserStateTest {
     assertEquals(0, fakeClient.getCallCount());
   }
 
+  @Test
+  public void testEntriesPrefetched() throws Exception {
+    // The page size passed below dwarfs the backend data, so the whole entries response comes back
+    // as a single page and the expected number of get calls is easy to reason about.
+    FakeBeamFnStateClient stateClient =
+        new FakeBeamFnStateClient(
+            ImmutableMap.of(
+                createMultimapEntriesStateKey(),
+                KV.of(
+                    KvCoder.of(ByteArrayCoder.of(), IterableCoder.of(StringUtf8Coder.of())),
+                    asList(KV.of(A1, asList("V1", "V2")), KV.of(A2, asList("V3"))))),
+            1_000_000);
+    MultimapUserState<byte[], String> multimapState =
+        new MultimapUserState<>(
+            Caches.noop(),
+            stateClient,
+            "instructionId",
+            createMultimapKeyStateKey(),
+            ByteArrayCoder.of(),
+            StringUtf8Coder.of());
+    multimapState.put(A1, "V4");
+
+    // Obtaining the iterable issues no request; only prefetch() reaches the state client.
+    PrefetchableIterable<Map.Entry<byte[], String>> entryIterable = multimapState.entries();
+    assertEquals(0, stateClient.getCallCount());
+    entryIterable.prefetch();
+    assertEquals(1, stateClient.getCallCount());
+
+    // Reading the prefetched iterable (backend values merged with the local put) must not cost an
+    // additional get call.
+    assertThat(
+        StreamSupport.stream(entryIterable.spliterator(), false)
+            .map(e -> KV.of(ByteString.copyFrom(e.getKey()), e.getValue()))
+            .collect(Collectors.toList()),
+        containsInAnyOrder(
+            KV.of(ByteString.copyFrom(A1), "V1"),
+            KV.of(ByteString.copyFrom(A1), "V2"),
+            KV.of(ByteString.copyFrom(A1), "V4"),
+            KV.of(ByteString.copyFrom(A2), "V3")));
+    assertEquals(1, stateClient.getCallCount());
+  }
+
   @Test
   public void testClearPrefetch() throws Exception {
     FakeBeamFnStateClient fakeClient =
@@ -1053,6 +1171,17 @@ public class MultimapUserStateTest {
         .build();
   }
 
+  /**
+   * Builds the {@link StateKey} used for the multimap {@code entries()} request of the state under
+   * test, scoped to the test's window, key, transform id and user state id.
+   */
+  private StateKey createMultimapEntriesStateKey() throws IOException {
+    StateKey.MultimapEntriesUserState.Builder entriesUserState =
+        StateKey.MultimapEntriesUserState.newBuilder()
+            .setWindow(encode(encodedWindow))
+            .setKey(encode(encodedKey))
+            .setTransformId(pTransformId)
+            .setUserStateId(stateId);
+    return StateKey.newBuilder().setMultimapEntriesUserState(entriesUserState).build();
+  }
+
   private StateKey createMultimapValueStateKey(byte[] key) throws IOException {
     return StateKey.newBuilder()
         .setMultimapUserState(


Reply via email to