This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new ea65a05  [BEAM-13015] Migrate all user state and side implementations to support caching. (#16263)
ea65a05 is described below

commit ea65a054f2fcb6349478d19609a773f66bbfa20e
Author: Lukasz Cwik <[email protected]>
AuthorDate: Tue Jan 4 15:39:36 2022 -0800

    [BEAM-13015] Migrate all user state and side implementations to support caching. (#16263)
    
    This change also ensures that prefetch can be invoked on the iterable to prevent the prefetch being lost once the iterator is discarded.
    
    See https://s.apache.org/beam-fn-state-api-and-bundle-processing#heading=h.tms0ncgbzz6f
---
 .../beam/sdk/fn/stream/PrefetchableIterable.java   |   3 +
 .../beam/sdk/fn/stream/PrefetchableIterables.java  |  53 ++-
 .../sdk/fn/stream/PrefetchableIterablesTest.java   |  20 +
 .../java/org/apache/beam/fn/harness/Caches.java    |   1 +
 .../apache/beam/fn/harness/state/BagUserState.java |  57 ++-
 .../beam/fn/harness/state/FnApiStateAccessor.java  | 135 +++---
 .../beam/fn/harness/state/IterableSideInput.java   |  47 +-
 .../state/LazyCachingIteratorToIterable.java       |   6 +-
 .../beam/fn/harness/state/MultimapSideInput.java   |  69 +--
 .../beam/fn/harness/state/MultimapUserState.java   | 184 +++++---
 .../fn/harness/state/StateFetchingIterators.java   |   9 +-
 .../beam/fn/harness/state/BagUserStateTest.java    | 170 ++++++-
 .../fn/harness/state/IterableSideInputTest.java    |  96 ++++
 .../fn/harness/state/MultimapSideInputTest.java    |  71 ++-
 .../fn/harness/state/MultimapUserStateTest.java    | 524 +++++++++++++++++----
 15 files changed, 1086 insertions(+), 359 deletions(-)

diff --git a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterable.java b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterable.java
index 5700f6c..9a2fde8 100644
--- a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterable.java
+++ b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterable.java
@@ -20,6 +20,9 @@ package org.apache.beam.sdk.fn.stream;
 /** An {@link Iterable} that returns {@link PrefetchableIterator}s. */
 public interface PrefetchableIterable<T> extends Iterable<T> {
 
+  /** Ensures that the next iterator returned has been prefetched. */
+  void prefetch();
+
   @Override
   PrefetchableIterator<T> iterator();
 }
diff --git a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java
index e55ece3..d8696f0 100644
--- a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java
+++ b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java
@@ -18,6 +18,7 @@
 package org.apache.beam.sdk.fn.stream;
 
 import java.util.NoSuchElementException;
+import javax.annotation.Nullable;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.FluentIterable;
 
 /**
@@ -26,8 +27,42 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.FluentIt
  */
 public class PrefetchableIterables {
 
+  /**
+   * A default implementation that caches an iterator to be returned when {@link #prefetch} is
+   * invoked.
+   */
+  public abstract static class Default<T> implements PrefetchableIterable<T> {
+    @Nullable private PrefetchableIterator<T> iterator = null;
+
+    @Override
+    public final void prefetch() {
+      if (iterator != null) {
+        return;
+      }
+      iterator = createIterator();
+      iterator.prefetch();
+    }
+
+    @Override
+    public final PrefetchableIterator<T> iterator() {
+      if (iterator == null) {
+        return createIterator();
+      }
+      PrefetchableIterator<T> rval = iterator;
+      iterator = null;
+      return rval;
+    }
+
+    protected abstract PrefetchableIterator<T> createIterator();
+  }
+
   private static final PrefetchableIterable<Object> EMPTY_ITERABLE =
-      PrefetchableIterators::emptyIterator;
+      new Default<Object>() {
+        @Override
+        protected PrefetchableIterator<Object> createIterator() {
+          return PrefetchableIterators.emptyIterator();
+        }
+      };
 
   /** Returns an empty {@link PrefetchableIterable}. */
   public static <T> PrefetchableIterable<T> emptyIterable() {
@@ -44,9 +79,9 @@ public class PrefetchableIterables {
     if (values.length == 0) {
       return emptyIterable();
     }
-    return new PrefetchableIterable<T>() {
+    return new Default<T>() {
       @Override
-      public PrefetchableIterator<T> iterator() {
+      public PrefetchableIterator<T> createIterator() {
         return PrefetchableIterators.fromArray(values);
       }
     };
@@ -63,9 +98,9 @@ public class PrefetchableIterables {
     if (iterable instanceof PrefetchableIterable) {
       return (PrefetchableIterable<T>) iterable;
     }
-    return new PrefetchableIterable<T>() {
+    return new Default<T>() {
       @Override
-      public PrefetchableIterator<T> iterator() {
+      public PrefetchableIterator<T> createIterator() {
         return PrefetchableIterators.maybePrefetchable(iterable.iterator());
       }
     };
@@ -87,10 +122,10 @@ public class PrefetchableIterables {
     } else if (iterables.length == 1) {
       return maybePrefetchable(iterables[0]);
     }
-    return new PrefetchableIterable<T>() {
+    return new Default<T>() {
       @SuppressWarnings("methodref.receiver.invalid")
       @Override
-      public PrefetchableIterator<T> iterator() {
+      public PrefetchableIterator<T> createIterator() {
         return PrefetchableIterators.concatIterators(
             FluentIterable.from(iterables).transform(Iterable::iterator).iterator());
       }
@@ -100,9 +135,9 @@ public class PrefetchableIterables {
   /** Limits the {@link PrefetchableIterable} to the specified number of elements. */
   public static <T> PrefetchableIterable<T> limit(Iterable<T> iterable, int limit) {
     PrefetchableIterable<T> prefetchableIterable = maybePrefetchable(iterable);
-    return new PrefetchableIterable<T>() {
+    return new Default<T>() {
       @Override
-      public PrefetchableIterator<T> iterator() {
+      public PrefetchableIterator<T> createIterator() {
         return new PrefetchableIterator<T>() {
           PrefetchableIterator<T> delegate = prefetchableIterable.iterator();
           int currentPosition;
diff --git a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIterablesTest.java b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIterablesTest.java
index e4d1b0d..3de5b48 100644
--- a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIterablesTest.java
+++ b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIterablesTest.java
@@ -17,8 +17,12 @@
  */
 package org.apache.beam.sdk.fn.stream;
 
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertTrue;
 
+import org.apache.beam.sdk.fn.stream.PrefetchableIterables.Default;
+import org.apache.beam.sdk.fn.stream.PrefetchableIteratorsTest.ReadyAfterPrefetchUntilNext;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -57,6 +61,22 @@ public class PrefetchableIterablesTest {
   }
 
   @Test
+  public void testDefaultPrefetch() {
+    PrefetchableIterable<String> iterable =
+        new Default<String>() {
+          @Override
+          protected PrefetchableIterator<String> createIterator() {
+            return new ReadyAfterPrefetchUntilNext<>(
+                PrefetchableIterators.fromArray("A", "B", "C"));
+          }
+        };
+
+    assertFalse(iterable.iterator().isReady());
+    iterable.prefetch();
+    assertTrue(iterable.iterator().isReady());
+  }
+
+  @Test
   public void testConcat() {
     verifyIterable(PrefetchableIterables.concat());
 
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java
index e5671aa..e53483e 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java
@@ -165,6 +165,7 @@ public final class Caches {
             cache == null ? "null" : cache.getClass()));
   }
 
+  @VisibleForTesting
   static Cache<Object, Object> forCache(
       org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.Cache<CompositeKey, Object>
           cache) {
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
index 76664ba..5952dcf 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
@@ -17,12 +17,17 @@
  */
 package org.apache.beam.fn.harness.state;
 
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
 
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.List;
+import org.apache.beam.fn.harness.Cache;
+import 
org.apache.beam.fn.harness.state.StateFetchingIterators.CachingStateIterable;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateAppendRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
@@ -39,45 +44,39 @@ import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterable
  *
  * <p>TODO: Move to an async persist model where persistence is signalled 
based upon cache memory
  * pressure and its need to flush.
- *
- * <p>TODO: Support block level caching and prefetch.
  */
 @SuppressWarnings({
   "rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
   "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
 })
 public class BagUserState<T> {
+  private final Cache<?, ?> cache;
   private final BeamFnStateClient beamFnStateClient;
   private final StateRequest request;
   private final Coder<T> valueCoder;
-  private PrefetchableIterable<T> oldValues;
-  private ArrayList<T> newValues;
+  private final CachingStateIterable<T> oldValues;
+  private List<T> newValues;
+  private boolean isCleared;
   private boolean isClosed;
 
+  /** The cache must be namespaced for this state object accordingly. */
   public BagUserState(
+      Cache<?, ?> cache,
       BeamFnStateClient beamFnStateClient,
       String instructionId,
-      String ptransformId,
-      String stateId,
-      ByteString encodedWindow,
-      ByteString encodedKey,
+      StateKey stateKey,
       Coder<T> valueCoder) {
+    checkArgument(
+        stateKey.hasBagUserState(), "Expected BagUserState StateKey but received %s.", stateKey);
+    this.cache = cache;
     this.beamFnStateClient = beamFnStateClient;
     this.valueCoder = valueCoder;
-
-    StateRequest.Builder requestBuilder = StateRequest.newBuilder();
-    requestBuilder
-        .setInstructionId(instructionId)
-        .getStateKeyBuilder()
-        .getBagUserStateBuilder()
-        .setTransformId(ptransformId)
-        .setUserStateId(stateId)
-        .setWindow(encodedWindow)
-        .setKey(encodedKey);
-    request = requestBuilder.build();
+    this.request =
+        StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build();
 
     this.oldValues =
-        StateFetchingIterators.readAllAndDecodeStartingFrom(beamFnStateClient, 
request, valueCoder);
+        StateFetchingIterators.readAllAndDecodeStartingFrom(
+            this.cache, beamFnStateClient, request, valueCoder);
     this.newValues = new ArrayList<>();
   }
 
@@ -86,7 +85,7 @@ public class BagUserState<T> {
         !isClosed,
         "Bag user state is no longer usable because it is closed for %s",
         request.getStateKey());
-    if (oldValues == null) {
+    if (isCleared) {
       // If we were cleared we should disregard old values.
       return 
PrefetchableIterables.limit(Collections.unmodifiableList(newValues), 
newValues.size());
     } else if (newValues.isEmpty()) {
@@ -110,7 +109,7 @@ public class BagUserState<T> {
         !isClosed,
         "Bag user state is no longer usable because it is closed for %s",
         request.getStateKey());
-    oldValues = null;
+    isCleared = true;
     newValues = new ArrayList<>();
   }
 
@@ -120,7 +119,11 @@ public class BagUserState<T> {
         !isClosed,
         "Bag user state is no longer usable because it is closed for %s",
         request.getStateKey());
-    if (oldValues == null) {
+    isClosed = true;
+    if (!isCleared && newValues.isEmpty()) {
+      return;
+    }
+    if (isCleared) {
       beamFnStateClient.handle(
           
request.toBuilder().setClear(StateClearRequest.getDefaultInstance()));
     }
@@ -135,6 +138,12 @@ public class BagUserState<T> {
               .toBuilder()
               
.setAppend(StateAppendRequest.newBuilder().setData(out.toByteString())));
     }
-    isClosed = true;
+
+    // Modify the underlying cached state depending on the mutations performed
+    if (isCleared) {
+      oldValues.clearAndAppend(newValues);
+    } else {
+      oldValues.append(newValues);
+    }
   }
 }
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
index 5a14e2f..682b63d 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
@@ -29,6 +29,7 @@ import java.util.Map;
 import java.util.function.Function;
 import java.util.function.Supplier;
 import org.apache.beam.fn.harness.Cache;
+import org.apache.beam.fn.harness.Caches;
 import 
org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest.CacheToken;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
 import org.apache.beam.runners.core.SideInputReader;
@@ -169,7 +170,6 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
     }
     ByteString encodedWindow = encodedWindowOut.toByteString();
     StateKey.Builder cacheKeyBuilder = StateKey.newBuilder();
-    Object sideInputAccessor;
 
     switch (sideInputSpec.getAccessPattern()) {
       case Materializations.ITERABLE_MATERIALIZATION_URN:
@@ -178,14 +178,6 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
             .setTransformId(ptransformId)
             .setSideInputId(tag.getId())
             .setWindow(encodedWindow);
-        sideInputAccessor =
-            new IterableSideInput<>(
-                beamFnStateClient,
-                processBundleInstructionId.get(),
-                ptransformId,
-                tag.getId(),
-                encodedWindow,
-                sideInputSpec.getCoder());
         break;
 
       case Materializations.MULTIMAP_MATERIALIZATION_URN:
@@ -194,21 +186,11 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
             "Expected %s but received %s.",
             KvCoder.class,
             sideInputSpec.getCoder().getClass());
-        KvCoder<?, ?> kvCoder = (KvCoder) sideInputSpec.getCoder();
         cacheKeyBuilder
-            .getMultimapSideInputBuilder()
+            .getMultimapKeysSideInputBuilder()
             .setTransformId(ptransformId)
             .setSideInputId(tag.getId())
             .setWindow(encodedWindow);
-        sideInputAccessor =
-            new MultimapSideInput<>(
-                beamFnStateClient,
-                processBundleInstructionId.get(),
-                ptransformId,
-                tag.getId(),
-                encodedWindow,
-                kvCoder.getKeyCoder(),
-                kvCoder.getValueCoder());
         break;
 
       default:
@@ -222,10 +204,44 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
                 sideInputSpec.getAccessPattern(),
                 tag));
     }
-
     return (T)
         stateKeyObjectCache.computeIfAbsent(
-            cacheKeyBuilder.build(), key -> 
sideInputSpec.getViewFn().apply(sideInputAccessor));
+            cacheKeyBuilder.build(),
+            key -> {
+              switch (sideInputSpec.getAccessPattern()) {
+                case Materializations.ITERABLE_MATERIALIZATION_URN:
+                  return sideInputSpec
+                      .getViewFn()
+                      .apply(
+                          new IterableSideInput<>(
+                              getCacheFor(key),
+                              beamFnStateClient,
+                              processBundleInstructionId.get(),
+                              key,
+                              sideInputSpec.getCoder()));
+                case Materializations.MULTIMAP_MATERIALIZATION_URN:
+                  return sideInputSpec
+                      .getViewFn()
+                      .apply(
+                          new MultimapSideInput<>(
+                              getCacheFor(key),
+                              beamFnStateClient,
+                              processBundleInstructionId.get(),
+                              key,
+                              ((KvCoder) 
sideInputSpec.getCoder()).getKeyCoder(),
+                              ((KvCoder) 
sideInputSpec.getCoder()).getValueCoder()));
+                default:
+                  throw new IllegalStateException(
+                      String.format(
+                          "This SDK is only capable of dealing with %s materializations "
+                              + "but was asked to handle %s for PCollectionView with tag %s.",
+                          ImmutableList.of(
+                              Materializations.ITERABLE_MATERIALIZATION_URN,
+                              Materializations.MULTIMAP_MATERIALIZATION_URN),
+                          sideInputSpec.getAccessPattern(),
+                          tag));
+              }
+            });
   }
 
   @Override
@@ -247,7 +263,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
               @Override
               public Object apply(StateKey key) {
                 return new ValueState<T>() {
-                  private final BagUserState<T> impl = createBagUserState(id, 
coder);
+                  private final BagUserState<T> impl = createBagUserState(key, 
coder);
 
                   @Override
                   public void clear() {
@@ -272,7 +288,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
                   @Override
                   public ValueState<T> readLater() {
-                    impl.get().iterator().prefetch();
+                    impl.get().prefetch();
                     return this;
                   }
                 };
@@ -289,7 +305,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
               @Override
               public Object apply(StateKey key) {
                 return new BagState<T>() {
-                  private final BagUserState<T> impl = createBagUserState(id, 
elemCoder);
+                  private final BagUserState<T> impl = createBagUserState(key, 
elemCoder);
 
                   @Override
                   public void add(T value) {
@@ -318,7 +334,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
                   @Override
                   public BagState<T> readLater() {
-                    impl.get().iterator().prefetch();
+                    impl.get().prefetch();
                     return this;
                   }
 
@@ -335,13 +351,13 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
   public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec, 
Coder<T> elemCoder) {
     return (SetState<T>)
         stateKeyObjectCache.computeIfAbsent(
-            createMultimapUserStateKey(id),
+            createMultimapKeysUserStateKey(id),
             new Function<StateKey, Object>() {
               @Override
               public Object apply(StateKey key) {
                 return new SetState<T>() {
                   private final MultimapUserState<T, Void> impl =
-                      createMultimapUserState(id, elemCoder, VoidCoder.of());
+                      createMultimapUserState(key, elemCoder, VoidCoder.of());
 
                   @Override
                   public void clear() {
@@ -358,7 +374,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
                       @Override
                       public ReadableState<Boolean> readLater() {
-                        impl.get(t).iterator().prefetch();
+                        impl.get(t).prefetch();
                         return this;
                       }
                     };
@@ -394,7 +410,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
                       @Override
                       public ReadableState<Boolean> readLater() {
-                        impl.keys().iterator().prefetch();
+                        impl.keys().prefetch();
                         return this;
                       }
                     };
@@ -407,7 +423,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
                   @Override
                   public SetState<T> readLater() {
-                    impl.keys().iterator().prefetch();
+                    impl.keys().prefetch();
                     return this;
                   }
                 };
@@ -423,13 +439,13 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
       Coder<ValueT> mapValueCoder) {
     return (MapState<KeyT, ValueT>)
         stateKeyObjectCache.computeIfAbsent(
-            createMultimapUserStateKey(id),
+            createMultimapKeysUserStateKey(id),
             new Function<StateKey, Object>() {
               @Override
               public Object apply(StateKey key) {
                 return new MapState<KeyT, ValueT>() {
                   private final MultimapUserState<KeyT, ValueT> impl =
-                      createMultimapUserState(id, mapKeyCoder, mapValueCoder);
+                      createMultimapUserState(key, mapKeyCoder, mapValueCoder);
 
                   @Override
                   public void clear() {
@@ -474,7 +490,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
                       @Override
                       public ReadableState<ValueT> readLater() {
-                        impl.get(key).iterator().prefetch();
+                        impl.get(key).prefetch();
                         return this;
                       }
                     };
@@ -490,7 +506,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
                       @Override
                       public ReadableState<Iterable<KeyT>> readLater() {
-                        impl.keys().iterator().prefetch();
+                        impl.keys().prefetch();
                         return this;
                       }
                     };
@@ -574,7 +590,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
                 // TODO: Support squashing accumulators depending on whether 
we know of all
                 // remote accumulators and local accumulators or just local 
accumulators.
                 return new CombiningState<ElementT, AccumT, ResultT>() {
-                  private final BagUserState<AccumT> impl = 
createBagUserState(id, accumCoder);
+                  private final BagUserState<AccumT> impl = 
createBagUserState(key, accumCoder);
 
                   @Override
                   public AccumT getAccum() {
@@ -606,7 +622,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
                   @Override
                   public CombiningState<ElementT, AccumT, ResultT> readLater() 
{
-                    impl.get().iterator().prefetch();
+                    impl.get().prefetch();
                     return this;
                   }
 
@@ -636,7 +652,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
                       @Override
                       public ReadableState<Boolean> readLater() {
-                        impl.get().iterator().prefetch();
+                        impl.get().prefetch();
                         return this;
                       }
                     };
@@ -697,31 +713,20 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
     throw new UnsupportedOperationException("WatermarkHoldState is unsupported 
by the Fn API.");
   }
 
-  private <KeyT, ValueT> MultimapUserState<KeyT, ValueT> 
createMultimapUserState(
-      String stateId, Coder<KeyT> keyCoder, Coder<ValueT> valueCoder) {
-    MultimapUserState<KeyT, ValueT> rval =
-        new MultimapUserState(
-            beamFnStateClient,
-            processBundleInstructionId.get(),
-            ptransformId,
-            stateId,
-            encodedCurrentWindowSupplier.get(),
-            encodedCurrentKeySupplier.get(),
-            keyCoder,
-            valueCoder);
-    stateFinalizers.add(rval::asyncClose);
-    return rval;
+  private Cache<?, ?> getCacheFor(StateKey stateKey) {
+    switch (stateKey.getTypeCase()) {
+      default:
+        return Caches.noop();
+    }
   }
 
-  private <T> BagUserState<T> createBagUserState(String stateId, Coder<T> 
valueCoder) {
+  private <T> BagUserState<T> createBagUserState(StateKey stateKey, Coder<T> 
valueCoder) {
     BagUserState<T> rval =
         new BagUserState<>(
+            getCacheFor(stateKey),
             beamFnStateClient,
             processBundleInstructionId.get(),
-            ptransformId,
-            stateId,
-            encodedCurrentWindowSupplier.get(),
-            encodedCurrentKeySupplier.get(),
+            stateKey,
             valueCoder);
     stateFinalizers.add(rval::asyncClose);
     return rval;
@@ -738,7 +743,21 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
     return builder.build();
   }
 
-  private StateKey createMultimapUserStateKey(String stateId) {
+  private <KeyT, ValueT> MultimapUserState<KeyT, ValueT> 
createMultimapUserState(
+      StateKey stateKey, Coder<KeyT> keyCoder, Coder<ValueT> valueCoder) {
+    MultimapUserState<KeyT, ValueT> rval =
+        new MultimapUserState(
+            Caches.noop(),
+            beamFnStateClient,
+            processBundleInstructionId.get(),
+            stateKey,
+            keyCoder,
+            valueCoder);
+    stateFinalizers.add(rval::asyncClose);
+    return rval;
+  }
+
+  private StateKey createMultimapKeysUserStateKey(String stateId) {
     StateKey.Builder builder = StateKey.newBuilder();
     builder
         .getMultimapKeysUserStateBuilder()
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/IterableSideInput.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/IterableSideInput.java
index 48c35de..d59742c 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/IterableSideInput.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/IterableSideInput.java
@@ -17,55 +17,44 @@
  */
 package org.apache.beam.fn.harness.state;
 
+import static 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+
+import org.apache.beam.fn.harness.Cache;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.transforms.Materializations.IterableView;
-import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
 
 /**
  * An implementation of a iterable side input that utilizes the Beam Fn State 
API to fetch values.
- *
- * <p>TODO: Support block level caching and prefetch.
  */
 @SuppressWarnings({
   "rawtypes" // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
 })
 public class IterableSideInput<T> implements IterableView<T> {
 
-  private final BeamFnStateClient beamFnStateClient;
-  private final String instructionId;
-  private final String ptransformId;
-  private final String sideInputId;
-  private final ByteString encodedWindow;
-  private final Coder<T> valueCoder;
+  private final Iterable<T> values;
 
   public IterableSideInput(
+      Cache<?, ?> cache,
       BeamFnStateClient beamFnStateClient,
       String instructionId,
-      String ptransformId,
-      String sideInputId,
-      ByteString encodedWindow,
+      StateKey stateKey,
       Coder<T> valueCoder) {
-    this.beamFnStateClient = beamFnStateClient;
-    this.instructionId = instructionId;
-    this.ptransformId = ptransformId;
-    this.sideInputId = sideInputId;
-    this.encodedWindow = encodedWindow;
-    this.valueCoder = valueCoder;
+    checkArgument(
+        stateKey.hasIterableSideInput(),
+        "Expected IterableSideInput StateKey but received %s.",
+        stateKey);
+    this.values =
+        StateFetchingIterators.readAllAndDecodeStartingFrom(
+            cache,
+            beamFnStateClient,
+            
StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build(),
+            valueCoder);
   }
 
   @Override
   public Iterable<T> get() {
-    StateRequest.Builder requestBuilder = StateRequest.newBuilder();
-    requestBuilder
-        .setInstructionId(instructionId)
-        .getStateKeyBuilder()
-        .getIterableSideInputBuilder()
-        .setTransformId(ptransformId)
-        .setSideInputId(sideInputId)
-        .setWindow(encodedWindow);
-
-    return StateFetchingIterators.readAllAndDecodeStartingFrom(
-        beamFnStateClient, requestBuilder.build(), valueCoder);
+    return values;
   }
 }
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
index 7828f93..d7e76fa 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
@@ -22,7 +22,7 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.NoSuchElementException;
 import java.util.Objects;
-import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterables;
 import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.checkerframework.checker.nullness.qual.Nullable;
@@ -31,7 +31,7 @@ import org.checkerframework.checker.nullness.qual.Nullable;
  * Converts an iterator to an iterable lazily loading values from the 
underlying iterator and
  * caching them to support reiteration.
  */
-class LazyCachingIteratorToIterable<T> implements PrefetchableIterable<T> {
+class LazyCachingIteratorToIterable<T> extends 
PrefetchableIterables.Default<T> {
   private final List<T> cachedElements;
   private final PrefetchableIterator<T> iterator;
 
@@ -41,7 +41,7 @@ class LazyCachingIteratorToIterable<T> implements 
PrefetchableIterable<T> {
   }
 
   @Override
-  public PrefetchableIterator<T> iterator() {
+  public PrefetchableIterator<T> createIterator() {
     return new CachingIterator();
   }
 
diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java
index 2fc72fe9..e754171 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java
@@ -17,7 +17,12 @@
  */
 package org.apache.beam.fn.harness.state;
 
+import static 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+
 import java.io.IOException;
+import org.apache.beam.fn.harness.Cache;
+import org.apache.beam.fn.harness.Caches;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.transforms.Materializations.MultimapView;
@@ -25,8 +30,6 @@ import 
org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
 
 /**
  * An implementation of a multimap side input that utilizes the Beam Fn State 
API to fetch values.
- *
- * <p>TODO: Support block level caching and prefetch.
  */
 @SuppressWarnings({
   "rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
@@ -34,44 +37,35 @@ import 
org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
 })
 public class MultimapSideInput<K, V> implements MultimapView<K, V> {
 
+  private final Cache<?, ?> cache;
   private final BeamFnStateClient beamFnStateClient;
-  private final String instructionId;
-  private final String ptransformId;
-  private final String sideInputId;
-  private final ByteString encodedWindow;
+  private final StateRequest keysRequest;
   private final Coder<K> keyCoder;
   private final Coder<V> valueCoder;
 
   public MultimapSideInput(
+      Cache<?, ?> cache,
       BeamFnStateClient beamFnStateClient,
       String instructionId,
-      String ptransformId,
-      String sideInputId,
-      ByteString encodedWindow,
+      StateKey stateKey,
       Coder<K> keyCoder,
       Coder<V> valueCoder) {
+    checkArgument(
+        stateKey.hasMultimapKeysSideInput(),
+        "Expected MultimapKeysSideInput StateKey but received %s.",
+        stateKey);
+    this.cache = cache;
     this.beamFnStateClient = beamFnStateClient;
-    this.instructionId = instructionId;
-    this.ptransformId = ptransformId;
-    this.sideInputId = sideInputId;
-    this.encodedWindow = encodedWindow;
+    this.keysRequest =
+        
StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build();
     this.keyCoder = keyCoder;
     this.valueCoder = valueCoder;
   }
 
   @Override
   public Iterable<K> get() {
-    StateRequest.Builder requestBuilder = StateRequest.newBuilder();
-    requestBuilder
-        .setInstructionId(instructionId)
-        .getStateKeyBuilder()
-        .getMultimapKeysSideInputBuilder()
-        .setTransformId(ptransformId)
-        .setSideInputId(sideInputId)
-        .setWindow(encodedWindow);
-
     return StateFetchingIterators.readAllAndDecodeStartingFrom(
-        beamFnStateClient, requestBuilder.build(), keyCoder);
+        cache, beamFnStateClient, keysRequest, keyCoder);
   }
 
   @Override
@@ -81,19 +75,26 @@ public class MultimapSideInput<K, V> implements 
MultimapView<K, V> {
       keyCoder.encode(k, output);
     } catch (IOException e) {
       throw new IllegalStateException(
-          String.format("Failed to encode key %s for side input id %s.", k, 
sideInputId), e);
+          String.format(
+              "Failed to encode key %s for side input id %s.",
+              k, 
keysRequest.getStateKey().getMultimapKeysSideInput().getSideInputId()),
+          e);
     }
-    StateRequest.Builder requestBuilder = StateRequest.newBuilder();
-    requestBuilder
-        .setInstructionId(instructionId)
-        .getStateKeyBuilder()
-        .getMultimapSideInputBuilder()
-        .setTransformId(ptransformId)
-        .setSideInputId(sideInputId)
-        .setWindow(encodedWindow)
-        .setKey(output.toByteString());
+    ByteString encodedKey = output.toByteString();
+    StateKey stateKey =
+        StateKey.newBuilder()
+            .setMultimapSideInput(
+                StateKey.MultimapSideInput.newBuilder()
+                    .setTransformId(
+                        
keysRequest.getStateKey().getMultimapKeysSideInput().getTransformId())
+                    .setSideInputId(
+                        
keysRequest.getStateKey().getMultimapKeysSideInput().getSideInputId())
+                    
.setWindow(keysRequest.getStateKey().getMultimapKeysSideInput().getWindow())
+                    .setKey(encodedKey))
+            .build();
 
+    StateRequest request = 
keysRequest.toBuilder().setStateKey(stateKey).build();
     return StateFetchingIterators.readAllAndDecodeStartingFrom(
-        beamFnStateClient, requestBuilder.build(), valueCoder);
+        Caches.subCache(cache, "ValuesForKey", encodedKey), beamFnStateClient, 
request, valueCoder);
   }
 }
diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
index 293d915..70dc2b1 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
@@ -17,10 +17,12 @@
  */
 package org.apache.beam.fn.harness.state;
 
+import static 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
 import static 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
@@ -28,8 +30,12 @@ import java.util.List;
 import java.util.Map;
 import java.util.NoSuchElementException;
 import java.util.Set;
+import org.apache.beam.fn.harness.Cache;
+import org.apache.beam.fn.harness.Caches;
+import 
org.apache.beam.fn.harness.state.StateFetchingIterators.CachingStateIterable;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateAppendRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
@@ -38,7 +44,6 @@ import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
-import org.checkerframework.checker.nullness.qual.Nullable;
 
 /**
  * An implementation of a multimap user state that utilizes the Beam Fn State 
API to fetch, clear
@@ -49,17 +54,16 @@ import org.checkerframework.checker.nullness.qual.Nullable;
  *
  * <p>TODO: Move to an async persist model where persistence is signalled 
based upon cache memory
  * pressure and its need to flush.
- *
- * <p>TODO: Support block level caching and prefetch.
  */
 public class MultimapUserState<K, V> {
 
+  private final Cache<?, ?> cache;
   private final BeamFnStateClient beamFnStateClient;
   private final Coder<K> mapKeyCoder;
   private final Coder<V> valueCoder;
-  private final String stateId;
   private final StateRequest keysStateRequest;
   private final StateRequest userStateRequest;
+  private CachingStateIterable<K> persistedKeys;
 
   private boolean isClosed;
   private boolean isCleared;
@@ -67,44 +71,40 @@ public class MultimapUserState<K, V> {
   private HashMap<Object, K> pendingRemoves = Maps.newHashMap();
   private HashMap<Object, KV<K, List<V>>> pendingAdds = Maps.newHashMap();
   // Values retrieved from persistent storage
-  private HashMap<K, PrefetchableIterable<V>> persistedValues = 
Maps.newHashMap();
-  private @Nullable PrefetchableIterable<K> persistedKeys = null;
+  private HashMap<Object, KV<K, CachingStateIterable<V>>> persistedValues = 
Maps.newHashMap();
 
   public MultimapUserState(
+      Cache<?, ?> cache,
       BeamFnStateClient beamFnStateClient,
       String instructionId,
-      String pTransformId,
-      String stateId,
-      ByteString encodedWindow,
-      ByteString encodedKey,
+      StateKey stateKey,
       Coder<K> mapKeyCoder,
       Coder<V> valueCoder) {
+    checkArgument(
+        stateKey.hasMultimapKeysUserState(),
+        "Expected MultimapKeysUserState StateKey but received %s.",
+        stateKey);
+    this.cache = cache;
     this.beamFnStateClient = beamFnStateClient;
     this.mapKeyCoder = mapKeyCoder;
     this.valueCoder = valueCoder;
-    this.stateId = stateId;
 
-    StateRequest.Builder keysStateRequestBuilder = StateRequest.newBuilder();
-    keysStateRequestBuilder
-        .setInstructionId(instructionId)
-        .getStateKeyBuilder()
-        .getMultimapKeysUserStateBuilder()
-        .setTransformId(pTransformId)
-        .setUserStateId(stateId)
-        .setKey(encodedKey)
-        .setWindow(encodedWindow);
-    keysStateRequest = keysStateRequestBuilder.build();
+    this.keysStateRequest =
+        
StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build();
+    this.persistedKeys =
+        StateFetchingIterators.readAllAndDecodeStartingFrom(
+            cache, beamFnStateClient, keysStateRequest, mapKeyCoder);
 
     StateRequest.Builder userStateRequestBuilder = StateRequest.newBuilder();
     userStateRequestBuilder
         .setInstructionId(instructionId)
         .getStateKeyBuilder()
         .getMultimapUserStateBuilder()
-        .setTransformId(pTransformId)
-        .setUserStateId(stateId)
-        .setWindow(encodedWindow)
-        .setKey(encodedKey);
-    userStateRequest = userStateRequestBuilder.build();
+        .setTransformId(stateKey.getMultimapKeysUserState().getTransformId())
+        .setUserStateId(stateKey.getMultimapKeysUserState().getUserStateId())
+        .setWindow(stateKey.getMultimapKeysUserState().getWindow())
+        .setKey(stateKey.getMultimapKeysUserState().getKey());
+    this.userStateRequest = userStateRequestBuilder.build();
   }
 
   public void clear() {
@@ -115,7 +115,6 @@ public class MultimapUserState<K, V> {
 
     isCleared = true;
     persistedValues = Maps.newHashMap();
-    persistedKeys = null;
     pendingRemoves = Maps.newHashMap();
     pendingAdds = Maps.newHashMap();
   }
@@ -142,8 +141,7 @@ public class MultimapUserState<K, V> {
       return pendingValues;
     }
 
-    PrefetchableIterable<V> persistedValues = getPersistedValues(key);
-    return PrefetchableIterables.concat(persistedValues, pendingValues);
+    return PrefetchableIterables.concat(getPersistedValues(structuralKey, 
key), pendingValues);
   }
 
   @SuppressWarnings({
@@ -165,15 +163,14 @@ public class MultimapUserState<K, V> {
       return PrefetchableIterables.concat(keys);
     }
 
-    PrefetchableIterable<K> persistedKeys = getPersistedKeys();
     Set<Object> pendingRemovesNow = new HashSet<>(pendingRemoves.keySet());
     Map<Object, K> pendingAddsNow = new HashMap<>();
     for (Map.Entry<Object, KV<K, List<V>>> entry : pendingAdds.entrySet()) {
       pendingAddsNow.put(entry.getKey(), entry.getValue().getKey());
     }
-    return new PrefetchableIterable<K>() {
+    return new PrefetchableIterables.Default<K>() {
       @Override
-      public PrefetchableIterator<K> iterator() {
+      public PrefetchableIterator<K> createIterator() {
         return new PrefetchableIterator<K>() {
           PrefetchableIterator<K> persistedKeysIterator = 
persistedKeys.iterator();
           Iterator<K> pendingAddsNowIterator;
@@ -277,37 +274,80 @@ public class MultimapUserState<K, V> {
         "Multimap user state is no longer usable because it is closed for %s",
         keysStateRequest.getStateKey());
     isClosed = true;
-    // Nothing to persist
+    // No mutations necessary
     if (!isCleared && pendingRemoves.isEmpty() && pendingAdds.isEmpty()) {
       return;
     }
 
+    startStateApiWrites();
+    updateCache();
+  }
+
+  @SuppressWarnings("FutureReturnValueIgnored")
+  private void startStateApiWrites() {
     // Clear currently persisted key-values
     if (isCleared) {
-      beamFnStateClient
-          
.handle(keysStateRequest.toBuilder().setClear(StateClearRequest.getDefaultInstance()))
-          .get();
+      beamFnStateClient.handle(
+          
keysStateRequest.toBuilder().setClear(StateClearRequest.getDefaultInstance()));
     } else if (!pendingRemoves.isEmpty()) {
       for (K key : pendingRemoves.values()) {
-        beamFnStateClient
-            .handle(
-                createUserStateRequest(key)
-                    .toBuilder()
-                    .setClear(StateClearRequest.getDefaultInstance()))
-            .get();
+        StateRequest request = createUserStateRequest(key);
+        beamFnStateClient.handle(
+            
request.toBuilder().setClear(StateClearRequest.getDefaultInstance()));
       }
     }
 
     // Persist pending key-values
     if (!pendingAdds.isEmpty()) {
       for (KV<K, List<V>> entry : pendingAdds.values()) {
-        beamFnStateClient
-            .handle(
-                createUserStateRequest(entry.getKey())
-                    .toBuilder()
-                    .setAppend(
-                        
StateAppendRequest.newBuilder().setData(encodeValues(entry.getValue()))))
-            .get();
+        StateRequest request = createUserStateRequest(entry.getKey());
+        beamFnStateClient.handle(
+            request
+                .toBuilder()
+                .setAppend(
+                    
StateAppendRequest.newBuilder().setData(encodeValues(entry.getValue()))));
+      }
+    }
+  }
+
+  private void updateCache() {
+    List<K> pendingAddsKeys = new ArrayList<>(pendingAdds.size());
+    for (KV<K, List<V>> entry : pendingAdds.values()) {
+      pendingAddsKeys.add(entry.getKey());
+    }
+
+    if (isCleared) {
+      // This will clear all keys and values since values is a sub-cache of 
keys.
+      persistedKeys.clearAndAppend(pendingAddsKeys);
+
+      // Since the map was cleared we can add all the values that are pending 
since we know
+      // that they must have been cleared.
+      for (Map.Entry<Object, KV<K, List<V>>> entry : pendingAdds.entrySet()) {
+        CachingStateIterable<V> iterable =
+            getPersistedValues(entry.getKey(), entry.getValue().getKey());
+        iterable.clearAndAppend(entry.getValue().getValue());
+      }
+    } else {
+      // The cast to Set<Object> is necessary since the checker framework 
would like to further
+      // limit the type to Set<@KeyFor("this.pendingRemoves") Object> which is 
incompatible with
+      // the API being remove(Set<Object>). We don't want to limit the API for 
remove either.
+      persistedKeys.remove((Set<Object>) pendingRemoves.keySet());
+      persistedKeys.append(pendingAddsKeys);
+
+      // For each removed key, we want to update the internal cache to clear 
its set of values
+      for (Map.Entry<Object, K> entry : pendingRemoves.entrySet()) {
+        CachingStateIterable<V> iterable = getPersistedValues(entry.getKey(), 
entry.getValue());
+        iterable.clearAndAppend(Collections.emptyList());
+      }
+
+      // For each added key, try to update the internal cache with the set of 
values.
+      for (Map.Entry<Object, KV<K, List<V>>> entry : pendingAdds.entrySet()) {
+        KV<K, CachingStateIterable<V>> value = 
persistedValues.get(entry.getKey());
+        // We don't do anything for keys that haven't been loaded since we 
have no knowledge whether
+        // the key is empty or not.
+        if (value != null) {
+          value.getValue().append(entry.getValue().getValue());
+        }
       }
     }
   }
@@ -321,7 +361,10 @@ public class MultimapUserState<K, V> {
       return output.toByteString();
     } catch (IOException e) {
       throw new IllegalStateException(
-          String.format("Failed to encode values for multimap user state id 
%s.", stateId), e);
+          String.format(
+              "Failed to encode values for multimap user state id %s.",
+              
keysStateRequest.getStateKey().getMultimapKeysUserState().getUserStateId()),
+          e);
     }
   }
 
@@ -334,27 +377,30 @@ public class MultimapUserState<K, V> {
       return request.build();
     } catch (IOException e) {
       throw new IllegalStateException(
-          String.format("Failed to encode key for multimap user state id %s.", 
stateId), e);
+          String.format(
+              "Failed to encode key for multimap user state id %s.",
+              
keysStateRequest.getStateKey().getMultimapKeysUserState().getUserStateId()),
+          e);
     }
   }
 
-  private PrefetchableIterable<V> getPersistedValues(K key) {
-    if (!persistedValues.containsKey(key)) {
-      PrefetchableIterable<V> values =
-          StateFetchingIterators.readAllAndDecodeStartingFrom(
-              beamFnStateClient, createUserStateRequest(key), valueCoder);
-      persistedValues.put(key, values);
-    }
-    return persistedValues.get(key);
-  }
-
-  private PrefetchableIterable<K> getPersistedKeys() {
-    checkState(!isCleared);
-    if (persistedKeys == null) {
-      persistedKeys =
-          StateFetchingIterators.readAllAndDecodeStartingFrom(
-              beamFnStateClient, keysStateRequest, mapKeyCoder);
-    }
-    return persistedKeys;
+  private CachingStateIterable<V> getPersistedValues(Object structuralKey, K 
key) {
+    return persistedValues
+        .computeIfAbsent(
+            structuralKey,
+            unused -> {
+              StateRequest request = createUserStateRequest(key);
+              return KV.of(
+                  key,
+                  StateFetchingIterators.readAllAndDecodeStartingFrom(
+                      Caches.subCache(
+                          cache,
+                          "ValuesForKey",
+                          
request.getStateKey().getMultimapUserState().getMapKey()),
+                      beamFnStateClient,
+                      request,
+                      valueCoder));
+            })
+        .getValue();
   }
 }
diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
index c6bbfed..3525df9 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
@@ -41,6 +41,7 @@ import 
org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.fn.stream.DataStreams.DataStreamDecoder;
 import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterables;
 import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
 import org.apache.beam.sdk.fn.stream.PrefetchableIterators;
 import org.apache.beam.sdk.util.Weighted;
@@ -133,7 +134,7 @@ public class StateFetchingIterators {
    * all the remaining pages.
    */
   @VisibleForTesting
-  static class FirstPageAndRemainder<T> implements PrefetchableIterable<T> {
+  static class FirstPageAndRemainder<T> extends 
PrefetchableIterables.Default<T> {
     private final BeamFnStateClient beamFnStateClient;
     private final StateRequest stateRequestForFirstChunk;
     private final Coder<T> valueCoder;
@@ -151,7 +152,7 @@ public class StateFetchingIterators {
     }
 
     @Override
-    public PrefetchableIterator<T> iterator() {
+    public PrefetchableIterator<T> createIterator() {
       return new PrefetchableIterator<T>() {
         PrefetchableIterator<T> delegate;
 
@@ -261,7 +262,7 @@ public class StateFetchingIterators {
   }
 
   /** A mutable iterable that supports prefetch and is backed by a cache. */
-  static class CachingStateIterable<T> implements PrefetchableIterable<T> {
+  static class CachingStateIterable<T> extends 
PrefetchableIterables.Default<T> {
 
     /** Represents a set of elements. */
     abstract static class Blocks<T> implements Weighted {
@@ -446,7 +447,7 @@ public class StateFetchingIterators {
     }
 
     @Override
-    public PrefetchableIterator<T> iterator() {
+    public PrefetchableIterator<T> createIterator() {
       return new CachingStateIterator();
     }
 
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BagUserStateTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BagUserStateTest.java
index ceeab22..e9c6f71 100644
--- 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BagUserStateTest.java
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BagUserStateTest.java
@@ -25,6 +25,8 @@ import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThrows;
 
 import java.io.IOException;
+import org.apache.beam.fn.harness.Cache;
+import org.apache.beam.fn.harness.Caches;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
@@ -44,13 +46,7 @@ public class BagUserStateTest {
             StringUtf8Coder.of(), ImmutableMap.of(key("A"), asList("A1", "A2", 
"A3")));
     BagUserState<String> userState =
         new BagUserState<>(
-            fakeClient,
-            "instructionId",
-            "ptransformId",
-            "stateId",
-            ByteString.copyFromUtf8("encodedWindow"),
-            encode("A"),
-            StringUtf8Coder.of());
+            Caches.noop(), fakeClient, "instructionId", key("A"), 
StringUtf8Coder.of());
     assertArrayEquals(
         new String[] {"A1", "A2", "A3"}, Iterables.toArray(userState.get(), 
String.class));
 
@@ -59,18 +55,45 @@ public class BagUserStateTest {
   }
 
   @Test
+  public void testGetCached() throws Exception {
+    FakeBeamFnStateClient fakeClient =
+        new FakeBeamFnStateClient(
+            StringUtf8Coder.of(), ImmutableMap.of(key("A"), asList("A1", "A2", 
"A3")));
+
+    Cache<?, ?> cache = Caches.eternal();
+    {
+      // First user state populates the cache.
+      BagUserState<String> userState =
+          new BagUserState<>(cache, fakeClient, "instructionId", key("A"), 
StringUtf8Coder.of());
+      assertArrayEquals(
+          new String[] {"A1", "A2", "A3"}, Iterables.toArray(userState.get(), 
String.class));
+      userState.asyncClose();
+    }
+
+    {
+      // The next user state will load all of its contents from the cache.
+      BagUserState<String> userState =
+          new BagUserState<>(
+              cache,
+              requestBuilder -> {
+                throw new IllegalStateException("Unexpected call for test.");
+              },
+              "instructionId",
+              key("A"),
+              StringUtf8Coder.of());
+      assertArrayEquals(
+          new String[] {"A1", "A2", "A3"}, Iterables.toArray(userState.get(), 
String.class));
+      userState.asyncClose();
+    }
+  }
+
+  @Test
   public void testAppend() throws Exception {
     FakeBeamFnStateClient fakeClient =
         new FakeBeamFnStateClient(StringUtf8Coder.of(), 
ImmutableMap.of(key("A"), asList("A1")));
     BagUserState<String> userState =
         new BagUserState<>(
-            fakeClient,
-            "instructionId",
-            "ptransformId",
-            "stateId",
-            ByteString.copyFromUtf8("encodedWindow"),
-            encode("A"),
-            StringUtf8Coder.of());
+            Caches.noop(), fakeClient, "instructionId", key("A"), 
StringUtf8Coder.of());
     userState.append("A2");
     Iterable<String> stateBeforeA3 = userState.get();
     assertArrayEquals(new String[] {"A1", "A2"}, 
Iterables.toArray(stateBeforeA3, String.class));
@@ -85,19 +108,62 @@ public class BagUserStateTest {
   }
 
   @Test
+  public void testAppendCached() throws Exception {
+    FakeBeamFnStateClient fakeClient =
+        new FakeBeamFnStateClient(StringUtf8Coder.of(), 
ImmutableMap.of(key("A"), asList("A1")));
+    Cache<?, ?> cache = Caches.eternal();
+    {
+      // First user state populates the cache.
+      BagUserState<String> userState =
+          new BagUserState<>(cache, fakeClient, "instructionId", key("A"), 
StringUtf8Coder.of());
+      userState.append("A2");
+      Iterable<String> stateBeforeA3 = userState.get();
+      assertArrayEquals(new String[] {"A1", "A2"}, 
Iterables.toArray(stateBeforeA3, String.class));
+      userState.append("A3");
+      assertArrayEquals(new String[] {"A1", "A2"}, 
Iterables.toArray(stateBeforeA3, String.class));
+      assertArrayEquals(
+          new String[] {"A1", "A2", "A3"}, Iterables.toArray(userState.get(), 
String.class));
+      userState.asyncClose();
+    }
+
+    {
+      // The next user state will load all of its contents from the cache 
including the appends
+      // persisted via asyncClose.
+      BagUserState<String> userState =
+          new BagUserState<>(
+              cache,
+              requestBuilder -> {
+                if (requestBuilder.hasGet()) {
+                  throw new IllegalStateException("Unexpected call for test.");
+                }
+                return fakeClient.handle(requestBuilder);
+              },
+              "instructionId",
+              key("A"),
+              StringUtf8Coder.of());
+      userState.append("A4");
+      Iterable<String> stateBeforeA5 = userState.get();
+      assertArrayEquals(
+          new String[] {"A1", "A2", "A3", "A4"}, 
Iterables.toArray(stateBeforeA5, String.class));
+      userState.append("A5");
+      assertArrayEquals(
+          new String[] {"A1", "A2", "A3", "A4"}, 
Iterables.toArray(stateBeforeA5, String.class));
+      assertArrayEquals(
+          new String[] {"A1", "A2", "A3", "A4", "A5"},
+          Iterables.toArray(userState.get(), String.class));
+      userState.asyncClose();
+    }
+    assertEquals(encode("A1", "A2", "A3", "A4", "A5"), 
fakeClient.getData().get(key("A")));
+  }
+
+  @Test
   public void testClear() throws Exception {
     FakeBeamFnStateClient fakeClient =
         new FakeBeamFnStateClient(
             StringUtf8Coder.of(), ImmutableMap.of(key("A"), asList("A1", "A2", 
"A3")));
     BagUserState<String> userState =
         new BagUserState<>(
-            fakeClient,
-            "instructionId",
-            "ptransformId",
-            "stateId",
-            ByteString.copyFromUtf8("encodedWindow"),
-            encode("A"),
-            StringUtf8Coder.of());
+            Caches.noop(), fakeClient, "instructionId", key("A"), 
StringUtf8Coder.of());
     assertArrayEquals(
         new String[] {"A1", "A2", "A3"}, Iterables.toArray(userState.get(), 
String.class));
     userState.clear();
@@ -112,6 +178,68 @@ public class BagUserStateTest {
     assertThrows(IllegalStateException.class, () -> userState.clear());
   }
 
+  @Test
+  public void testClearCached() throws Exception {
+    FakeBeamFnStateClient fakeClient =
+        new FakeBeamFnStateClient(
+            StringUtf8Coder.of(), ImmutableMap.of(key("A"), asList("A1", "A2", 
"A3")));
+
+    Cache<?, ?> cache = Caches.eternal();
+    {
+      // First user state populates the cache.
+      BagUserState<String> userState =
+          new BagUserState<>(cache, fakeClient, "instructionId", key("A"), 
StringUtf8Coder.of());
+      assertArrayEquals(
+          new String[] {"A1", "A2", "A3"}, Iterables.toArray(userState.get(), 
String.class));
+      userState.clear();
+      assertFalse(userState.get().iterator().hasNext());
+      userState.append("A4");
+      assertArrayEquals(new String[] {"A4"}, 
Iterables.toArray(userState.get(), String.class));
+      userState.asyncClose();
+    }
+
+    {
+      // The next user state will load all of its contents from the cache 
including the clear and
+      // append persisted via asyncClose.
+      BagUserState<String> userState =
+          new BagUserState<>(
+              cache,
+              requestBuilder -> {
+                if (requestBuilder.hasGet()) {
+                  throw new IllegalStateException("Unexpected call for test.");
+                }
+                return fakeClient.handle(requestBuilder);
+              },
+              "instructionId",
+              key("A"),
+              StringUtf8Coder.of());
+      assertArrayEquals(new String[] {"A4"}, 
Iterables.toArray(userState.get(), String.class));
+      userState.clear();
+      assertFalse(userState.get().iterator().hasNext());
+      userState.asyncClose();
+    }
+
+    {
+      // The next user state will load all of its contents from the cache 
including the clear
+      // persisted via asyncClose.
+      BagUserState<String> userState =
+          new BagUserState<>(
+              cache,
+              requestBuilder -> {
+                if (requestBuilder.hasGet()) {
+                  throw new IllegalStateException("Unexpected call for test.");
+                }
+                return fakeClient.handle(requestBuilder);
+              },
+              "instructionId",
+              key("A"),
+              StringUtf8Coder.of());
+      assertArrayEquals(new String[] {}, Iterables.toArray(userState.get(), 
String.class));
+      userState.asyncClose();
+    }
+    assertNull(fakeClient.getData().get(key("A")));
+  }
+
   private StateKey key(String id) throws IOException {
     return StateKey.newBuilder()
         .setBagUserState(
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/IterableSideInputTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/IterableSideInputTest.java
new file mode 100644
index 0000000..9e44295
--- /dev/null
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/IterableSideInputTest.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.state;
+
+import static java.util.Arrays.asList;
+import static org.junit.Assert.assertArrayEquals;
+
+import java.io.IOException;
+import org.apache.beam.fn.harness.Cache;
+import org.apache.beam.fn.harness.Caches;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class IterableSideInputTest {
+  @Test
+  public void testGet() throws Exception {
+    FakeBeamFnStateClient fakeBeamFnStateClient =
+        new FakeBeamFnStateClient(
+            StringUtf8Coder.of(),
+            ImmutableMap.of(key(), asList("A1", "A2", "A3", "A4", "A5", 
"A6")));
+
+    IterableSideInput<String> iterableSideInput =
+        new IterableSideInput<>(
+            Caches.noop(), fakeBeamFnStateClient, "instructionId", key(), 
StringUtf8Coder.of());
+    assertArrayEquals(
+        new String[] {"A1", "A2", "A3", "A4", "A5", "A6"},
+        Iterables.toArray(iterableSideInput.get(), String.class));
+  }
+
+  @Test
+  public void testGetCached() throws Exception {
+    FakeBeamFnStateClient fakeBeamFnStateClient =
+        new FakeBeamFnStateClient(
+            StringUtf8Coder.of(),
+            ImmutableMap.of(key(), asList("A1", "A2", "A3", "A4", "A5", 
"A6")));
+
+    Cache<?, ?> cache = Caches.eternal();
+    {
+      // The first side input reads from the fake client and fills the cache.
+      IterableSideInput<String> iterableSideInput =
+          new IterableSideInput<>(
+              cache, fakeBeamFnStateClient, "instructionId", key(), 
StringUtf8Coder.of());
+      assertArrayEquals(
+          new String[] {"A1", "A2", "A3", "A4", "A5", "A6"},
+          Iterables.toArray(iterableSideInput.get(), String.class));
+    }
+
+    {
+      // Any state request here throws, so the values must come from the cache.
+      IterableSideInput<String> iterableSideInput =
+          new IterableSideInput<>(
+              cache,
+              requestBuilder -> {
+                throw new IllegalStateException("Unexpected call for test.");
+              },
+              "instructionId",
+              key(),
+              StringUtf8Coder.of());
+      assertArrayEquals(
+          new String[] {"A1", "A2", "A3", "A4", "A5", "A6"},
+          Iterables.toArray(iterableSideInput.get(), String.class));
+    }
+  }
+
+  private StateKey key() throws IOException {
+    return StateKey.newBuilder()
+        .setIterableSideInput(
+            StateKey.IterableSideInput.newBuilder()
+                .setTransformId("ptransformId")
+                .setSideInputId("sideInputId")
+                .setWindow(ByteString.copyFromUtf8("encodedWindow")))
+        .build();
+  }
+}
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapSideInputTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapSideInputTest.java
index 9321ca6..81e9a45 100644
--- 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapSideInputTest.java
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapSideInputTest.java
@@ -22,8 +22,11 @@ import static org.junit.Assert.assertArrayEquals;
 
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
+import org.apache.beam.fn.harness.Cache;
+import org.apache.beam.fn.harness.Caches;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
 import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
@@ -33,7 +36,13 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
-/** Tests for {@link MultimapSideInput}. */
+/**
+ * Tests for {@link MultimapSideInput}.
+ *
+ * <p>It is important to use a key type where its coder is not {@link 
Coder#consistentWithEquals()}
+ * to ensure that comparisons are performed using structural values instead of 
object equality
+ * during testing.
+ */
 @RunWith(JUnit4.class)
 public class MultimapSideInputTest {
   private static final byte[] A = "A".getBytes(StandardCharsets.UTF_8);
@@ -51,11 +60,10 @@ public class MultimapSideInputTest {
 
     MultimapSideInput<byte[], String> multimapSideInput =
         new MultimapSideInput<>(
+            Caches.noop(),
             fakeBeamFnStateClient,
             "instructionId",
-            "ptransformId",
-            "sideInputId",
-            ByteString.copyFromUtf8("encodedWindow"),
+            stateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
     assertArrayEquals(
@@ -68,6 +76,61 @@ public class MultimapSideInputTest {
         new byte[][] {A, B}, Iterables.toArray(multimapSideInput.get(), 
byte[].class));
   }
 
+  @Test
+  public void testGetCached() throws Exception {
+    FakeBeamFnStateClient fakeBeamFnStateClient =
+        new FakeBeamFnStateClient(
+            ImmutableMap.of(
+                stateKey(), KV.of(ByteArrayCoder.of(), asList(A, B)),
+                key(A), KV.of(StringUtf8Coder.of(), asList("A1", "A2", "A3")),
+                key(B), KV.of(StringUtf8Coder.of(), asList("B1", "B2"))));
+
+    Cache<?, ?> cache = Caches.eternal();
+    {
+      // The first side input reads from the fake client and fills the cache.
+      MultimapSideInput<byte[], String> multimapSideInput =
+          new MultimapSideInput<>(
+              cache,
+              fakeBeamFnStateClient,
+              "instructionId",
+              stateKey(),
+              ByteArrayCoder.of(),
+              StringUtf8Coder.of());
+      assertArrayEquals(
+          new String[] {"A1", "A2", "A3"},
+          Iterables.toArray(multimapSideInput.get(A), String.class));
+      assertArrayEquals(
+          new String[] {"B1", "B2"}, 
Iterables.toArray(multimapSideInput.get(B), String.class));
+      assertArrayEquals(
+          new String[] {}, Iterables.toArray(multimapSideInput.get(UNKNOWN), 
String.class));
+      assertArrayEquals(
+          new byte[][] {A, B}, Iterables.toArray(multimapSideInput.get(), 
byte[].class));
+    }
+
+    {
+      // Any state request here throws, so all lookups must hit the cache.
+      MultimapSideInput<byte[], String> multimapSideInput =
+          new MultimapSideInput<>(
+              cache,
+              requestBuilder -> {
+                throw new IllegalStateException("Unexpected call for test.");
+              },
+              "instructionId",
+              stateKey(),
+              ByteArrayCoder.of(),
+              StringUtf8Coder.of());
+      assertArrayEquals(
+          new String[] {"A1", "A2", "A3"},
+          Iterables.toArray(multimapSideInput.get(A), String.class));
+      assertArrayEquals(
+          new String[] {"B1", "B2"}, 
Iterables.toArray(multimapSideInput.get(B), String.class));
+      assertArrayEquals(
+          new String[] {}, Iterables.toArray(multimapSideInput.get(UNKNOWN), 
String.class));
+      assertArrayEquals(
+          new byte[][] {A, B}, Iterables.toArray(multimapSideInput.get(), 
byte[].class));
+    }
+  }
+
   private StateKey stateKey() throws IOException {
     return StateKey.newBuilder()
         .setMultimapKeysSideInput(
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
index 4749c52..8eb709b 100644
--- 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
@@ -33,6 +33,8 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.Map;
+import org.apache.beam.fn.harness.Cache;
+import org.apache.beam.fn.harness.Caches;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
 import org.apache.beam.sdk.coders.ByteArrayCoder;
 import org.apache.beam.sdk.coders.Coder;
@@ -70,12 +72,10 @@ public class MultimapUserStateTest {
     FakeBeamFnStateClient fakeClient = new 
FakeBeamFnStateClient(Collections.emptyMap());
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
     assertThat(userState.keys(), is(emptyIterable()));
@@ -92,12 +92,10 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
 
@@ -122,12 +120,10 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
 
@@ -158,12 +154,10 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
 
@@ -192,12 +186,10 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
 
@@ -221,20 +213,17 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
     userState.remove(A0);
     userState.put(A0, "V2");
     assertArrayEquals(new String[] {"V2"}, 
Iterables.toArray(userState.get(A0), String.class));
     userState.asyncClose();
-    Map<StateKey, ByteString> data = fakeClient.getData();
-    assertEquals(encode("V2"), data.get(createMultimapValueStateKey(A0)));
+    assertEquals(encode("V2"), 
fakeClient.getData().get(createMultimapValueStateKey(A0)));
   }
 
   @Test
@@ -248,12 +237,10 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
     userState.clear();
@@ -272,12 +259,10 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
     userState.remove(A0);
@@ -292,12 +277,10 @@ public class MultimapUserStateTest {
     FakeBeamFnStateClient fakeClient = new 
FakeBeamFnStateClient(Collections.emptyMap());
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
     userState.put(A0, "V0");
@@ -315,12 +298,10 @@ public class MultimapUserStateTest {
     FakeBeamFnStateClient fakeClient = new 
FakeBeamFnStateClient(Collections.emptyMap());
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
     userState.put(A0, "V0");
@@ -346,12 +327,10 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
 
@@ -376,12 +355,10 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
     Iterable<byte[]> keys = userState.keys();
@@ -401,12 +378,10 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
     Iterable<String> values = userState.get(A1);
@@ -426,12 +401,10 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
     userState.clear();
@@ -452,12 +425,10 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
     userState.asyncClose();
@@ -478,12 +449,10 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
     userState.remove(A0);
@@ -509,12 +478,10 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             NullableCoder.of(ByteArrayCoder.of()),
             NullableCoder.of(StringUtf8Coder.of()));
     userState.put(null, null);
@@ -529,12 +496,10 @@ public class MultimapUserStateTest {
     FakeBeamFnStateClient fakeClient = new 
FakeBeamFnStateClient(Collections.emptyMap());
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.eternal(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
     assertArrayEquals(new String[] {}, Iterables.toArray(userState.get(A1), 
String.class));
@@ -553,18 +518,16 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.eternal(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
 
     PrefetchableIterable<String> values = userState.get(A1);
     assertEquals(0, fakeClient.getCallCount());
-    values.iterator().prefetch();
+    values.prefetch();
     assertEquals(1, fakeClient.getCallCount());
     assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(values, 
String.class));
     assertEquals(1, fakeClient.getCallCount());
@@ -581,18 +544,16 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.eternal(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
 
     PrefetchableIterable<byte[]> keys = userState.keys();
     assertEquals(0, fakeClient.getCallCount());
-    keys.iterator().prefetch();
+    keys.prefetch();
     assertEquals(1, fakeClient.getCallCount());
     assertArrayEquals(new byte[][] {A1}, Iterables.toArray(keys, 
byte[].class));
     assertEquals(1, fakeClient.getCallCount());
@@ -609,19 +570,17 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.eternal(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
 
     userState.put(A2, "V3");
     PrefetchableIterable<byte[]> keys = userState.keys();
     assertEquals(0, fakeClient.getCallCount());
-    keys.iterator().prefetch();
+    keys.prefetch();
     assertEquals(1, fakeClient.getCallCount());
     assertArrayEquals(new byte[][] {A1, A2}, Iterables.toArray(keys, 
byte[].class));
     assertEquals(1, fakeClient.getCallCount());
@@ -638,12 +597,10 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
 
@@ -651,7 +608,7 @@ public class MultimapUserStateTest {
     userState.put(A1, "V3");
     PrefetchableIterable<String> values = userState.get(A1);
     assertEquals(0, fakeClient.getCallCount());
-    values.iterator().prefetch();
+    values.prefetch();
     // Removed keys don't require accessing the underlying persisted state
     assertEquals(0, fakeClient.getCallCount());
     assertArrayEquals(new String[] {"V3"}, Iterables.toArray(values, 
String.class));
@@ -670,12 +627,10 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.noop(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
 
@@ -683,7 +638,7 @@ public class MultimapUserStateTest {
     userState.put(A2, "V3");
     PrefetchableIterable<byte[]> keys = userState.keys();
     assertEquals(0, fakeClient.getCallCount());
-    keys.iterator().prefetch();
+    keys.prefetch();
     // Cleared keys don't require accessing the underlying persisted state
     assertEquals(0, fakeClient.getCallCount());
     assertArrayEquals(new byte[][] {A2}, Iterables.toArray(keys, 
byte[].class));
@@ -702,24 +657,385 @@ public class MultimapUserStateTest {
                 KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
     MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
+            Caches.eternal(),
             fakeClient,
             "instructionId",
-            pTransformId,
-            stateId,
-            encode(encodedWindow),
-            encode(encodedKey),
+            createMultimapKeyStateKey(),
             ByteArrayCoder.of(),
             StringUtf8Coder.of());
 
     userState.put(A1, "V3");
     PrefetchableIterable<String> values = userState.get(A1);
     assertEquals(0, fakeClient.getCallCount());
-    values.iterator().prefetch();
+    values.prefetch();
     assertEquals(1, fakeClient.getCallCount());
     assertArrayEquals(new String[] {"V1", "V2", "V3"}, 
Iterables.toArray(values, String.class));
     assertEquals(1, fakeClient.getCallCount());
   }
 
+  @Test
+  public void testNoPersistedValuesCached() throws Exception {
+    FakeBeamFnStateClient fakeClient = new 
FakeBeamFnStateClient(Collections.emptyMap());
+    Cache<?, ?> cache = Caches.eternal();
+    {
+      // First user state reads the empty fake client and caches the result.
+      MultimapUserState<byte[], String> userState =
+          new MultimapUserState<>(
+              cache,
+              fakeClient,
+              "instructionId",
+              createMultimapKeyStateKey(),
+              ByteArrayCoder.of(),
+              StringUtf8Coder.of());
+      assertThat(userState.keys(), is(emptyIterable()));
+      assertThat(userState.get(A1), is(emptyIterable()));
+    }
+
+    {
+      // Any state request here throws, so the empty results must be cached.
+      MultimapUserState<byte[], String> userState =
+          new MultimapUserState<>(
+              cache,
+              requestBuilder -> {
+                throw new IllegalStateException("Unexpected call for test.");
+              },
+              "instructionId",
+              createMultimapKeyStateKey(),
+              ByteArrayCoder.of(),
+              StringUtf8Coder.of());
+      assertThat(userState.keys(), is(emptyIterable()));
+      assertThat(userState.get(A1), is(emptyIterable()));
+    }
+  }
+
+  @Test
+  public void testGetCached() throws Exception {
+    FakeBeamFnStateClient fakeClient =
+        new FakeBeamFnStateClient(
+            ImmutableMap.of(
+                createMultimapKeyStateKey(),
+                KV.of(ByteArrayCoder.of(), singletonList(A1)),
+                createMultimapValueStateKey(A1),
+                KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
+    Cache<?, ?> cache = Caches.eternal();
+    {
+      // First user state reads A1 from the fake client, filling the cache.
+      MultimapUserState<byte[], String> userState =
+          new MultimapUserState<>(
+              cache,
+              fakeClient,
+              "instructionId",
+              createMultimapKeyStateKey(),
+              ByteArrayCoder.of(),
+              StringUtf8Coder.of());
+
+      assertArrayEquals(
+          new String[] {"V1", "V2"}, Iterables.toArray(userState.get(A1), 
String.class));
+      userState.asyncClose();
+    }
+
+    {
+      // Any state request here throws, so A1 must be served from the cache.
+      MultimapUserState<byte[], String> userState =
+          new MultimapUserState<>(
+              cache,
+              requestBuilder -> {
+                throw new IllegalStateException("Unexpected call for test.");
+              },
+              "instructionId",
+              createMultimapKeyStateKey(),
+              ByteArrayCoder.of(),
+              StringUtf8Coder.of());
+
+      assertArrayEquals(
+          new String[] {"V1", "V2"}, Iterables.toArray(userState.get(A1), 
String.class));
+      userState.asyncClose();
+    }
+  }
+
+  @Test
+  public void testClearCached() throws Exception {
+    FakeBeamFnStateClient fakeClient =
+        new FakeBeamFnStateClient(
+            ImmutableMap.of(
+                createMultimapKeyStateKey(),
+                KV.of(ByteArrayCoder.of(), singletonList(A1)),
+                createMultimapValueStateKey(A1),
+                KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
+    Cache<?, ?> cache = Caches.eternal();
+    {
+      // First user state clears the multimap; the cache must observe that.
+      MultimapUserState<byte[], String> userState =
+          new MultimapUserState<>(
+              cache,
+              fakeClient,
+              "instructionId",
+              createMultimapKeyStateKey(),
+              ByteArrayCoder.of(),
+              StringUtf8Coder.of());
+
+      userState.clear();
+      assertThat(userState.keys(), is(emptyIterable()));
+      userState.asyncClose();
+    }
+
+    {
+      // The next user state will load all of its contents from the cache 
including the mutations
+      // persisted via asyncClose.
+      MultimapUserState<byte[], String> userState =
+          new MultimapUserState<>(
+              cache,
+              requestBuilder -> {
+                throw new IllegalStateException("Unexpected call for test.");
+              },
+              "instructionId",
+              createMultimapKeyStateKey(),
+              ByteArrayCoder.of(),
+              StringUtf8Coder.of());
+
+      assertThat(userState.keys(), is(emptyIterable()));
+      userState.asyncClose();
+    }
+  }
+
+  @Test
+  public void testKeysCached() throws Exception {
+    FakeBeamFnStateClient fakeClient =
+        new FakeBeamFnStateClient(
+            ImmutableMap.of(
+                createMultimapKeyStateKey(),
+                KV.of(ByteArrayCoder.of(), singletonList(A1)),
+                createMultimapValueStateKey(A1),
+                KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
+    Cache<?, ?> cache = Caches.eternal();
+    {
+      // First user state adds keys A2 and A3 next to the persisted key A1.
+      MultimapUserState<byte[], String> userState =
+          new MultimapUserState<>(
+              cache,
+              fakeClient,
+              "instructionId",
+              createMultimapKeyStateKey(),
+              ByteArrayCoder.of(),
+              StringUtf8Coder.of());
+
+      userState.put(A2, "V1");
+      userState.put(A3, "V1");
+      assertArrayEquals(
+          new byte[][] {A1, A2, A3}, Iterables.toArray(userState.keys(), 
byte[].class));
+      userState.asyncClose();
+    }
+
+    {
+      // The next user state will load all of its contents from the cache 
including the mutations
+      // persisted via asyncClose.
+      MultimapUserState<byte[], String> userState =
+          new MultimapUserState<>(
+              cache,
+              requestBuilder -> {
+                throw new IllegalStateException("Unexpected call for test.");
+              },
+              "instructionId",
+              createMultimapKeyStateKey(),
+              ByteArrayCoder.of(),
+              StringUtf8Coder.of());
+
+      assertArrayEquals(
+          new byte[][] {A1, A2, A3}, Iterables.toArray(userState.keys(), 
byte[].class));
+      userState.asyncClose();
+    }
+  }
+
+  @Test
+  public void testPutCached() throws Exception {
+    FakeBeamFnStateClient fakeClient =
+        new FakeBeamFnStateClient(
+            ImmutableMap.of(
+                createMultimapKeyStateKey(),
+                KV.of(ByteArrayCoder.of(), singletonList(A1)),
+                createMultimapValueStateKey(A1),
+                KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
+    Cache<?, ?> cache = Caches.eternal();
+    {
+      // First user state appends V3 to A1 and adds V1 under the new key A2.
+      MultimapUserState<byte[], String> userState =
+          new MultimapUserState<>(
+              cache,
+              fakeClient,
+              "instructionId",
+              createMultimapKeyStateKey(),
+              ByteArrayCoder.of(),
+              StringUtf8Coder.of());
+
+      userState.put(A1, "V3");
+      userState.put(A2, "V1");
+      assertArrayEquals(
+          new String[] {"V1", "V2", "V3"}, 
Iterables.toArray(userState.get(A1), String.class));
+      userState.asyncClose();
+    }
+
+    {
+      // The next user state will load all of its contents from the cache 
including the mutations
+      // persisted via asyncClose except for A2 since it was never loaded so 
the mutation is
+      // discarded.
+      int callCount = fakeClient.getCallCount();
+      MultimapUserState<byte[], String> userState =
+          new MultimapUserState<>(
+              cache,
+              fakeClient,
+              "instructionId",
+              createMultimapKeyStateKey(),
+              ByteArrayCoder.of(),
+              StringUtf8Coder.of());
+
+      assertArrayEquals(
+          new String[] {"V1", "V2", "V3"}, 
Iterables.toArray(userState.get(A1), String.class));
+      assertEquals(callCount, fakeClient.getCallCount());
+      // We expect one call when loading A2 since the append would have been 
discarded since the
+      // key was never fully loaded.
+      assertArrayEquals(new String[] {"V1"}, 
Iterables.toArray(userState.get(A2), String.class));
+      assertEquals(callCount + 1, fakeClient.getCallCount());
+      userState.asyncClose();
+    }
+  }
+
+  @Test
+  public void testPutAfterRemoveCached() throws Exception {
+    FakeBeamFnStateClient fakeClient =
+        new FakeBeamFnStateClient(
+            ImmutableMap.of(
+                createMultimapKeyStateKey(),
+                KV.of(ByteArrayCoder.of(), singletonList(A0)),
+                createMultimapValueStateKey(A0),
+                KV.of(StringUtf8Coder.of(), asList("V1"))));
+    Cache<?, ?> cache = Caches.eternal();
+    {
+      // First user state removes A0, then puts "V2" under it, populating the cache.
+      MultimapUserState<byte[], String> userState =
+          new MultimapUserState<>(
+              cache,
+              fakeClient,
+              "instructionId",
+              createMultimapKeyStateKey(),
+              ByteArrayCoder.of(),
+              StringUtf8Coder.of());
+
+      userState.remove(A0);
+      userState.put(A0, "V2");
+      userState.asyncClose();
+    }
+
+    {
+      // The next user state will load all of its contents from the cache 
including the mutations
+      // persisted via asyncClose; the client below throws on any state request.
+      MultimapUserState<byte[], String> userState =
+          new MultimapUserState<>(
+              cache,
+              requestBuilder -> {
+                throw new IllegalStateException("Unexpected call for test.");
+              },
+              "instructionId",
+              createMultimapKeyStateKey(),
+              ByteArrayCoder.of(),
+              StringUtf8Coder.of());
+      assertArrayEquals(new String[] {"V2"}, 
Iterables.toArray(userState.get(A0), String.class));
+      userState.asyncClose();
+    }
+  }
+
+  @Test
+  public void testPutAfterClearCached() throws Exception {
+    FakeBeamFnStateClient fakeClient =
+        new FakeBeamFnStateClient(
+            ImmutableMap.of(
+                createMultimapKeyStateKey(),
+                KV.of(ByteArrayCoder.of(), singletonList(A0)),
+                createMultimapValueStateKey(A0),
+                KV.of(StringUtf8Coder.of(), asList("V1"))));
+    Cache<?, ?> cache = Caches.eternal();
+    {
+      // First user state clears everything, then puts "V2" under A0, populating the cache.
+      MultimapUserState<byte[], String> userState =
+          new MultimapUserState<>(
+              cache,
+              fakeClient,
+              "instructionId",
+              createMultimapKeyStateKey(),
+              ByteArrayCoder.of(),
+              StringUtf8Coder.of());
+      userState.clear();
+      userState.put(A0, "V2");
+      userState.asyncClose();
+    }
+
+    {
+      // The next user state will load all of its contents from the cache 
including the mutations
+      // persisted via asyncClose; the client below throws on any state request.
+      MultimapUserState<byte[], String> userState =
+          new MultimapUserState<>(
+              cache,
+              requestBuilder -> {
+                throw new IllegalStateException("Unexpected call for test.");
+              },
+              "instructionId",
+              createMultimapKeyStateKey(),
+              ByteArrayCoder.of(),
+              StringUtf8Coder.of());
+      assertArrayEquals(new String[] {"V2"}, 
Iterables.toArray(userState.get(A0), String.class));
+      // Even though we never load the keys from the state client, keys() is answered locally.
+      assertArrayEquals(new byte[][] {A0}, Iterables.toArray(userState.keys(), 
byte[].class));
+      userState.asyncClose();
+    }
+  }
+
+  @Test
+  public void testRemoveCached() throws Exception {
+    FakeBeamFnStateClient fakeClient =
+        new FakeBeamFnStateClient(
+            ImmutableMap.of(
+                createMultimapKeyStateKey(),
+                KV.of(ByteArrayCoder.of(), singletonList(A1)),
+                createMultimapValueStateKey(A1),
+                KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
+    Cache<?, ?> cache = Caches.eternal();
+    {
+      // First user state reads A1, then removes A1 and the never-stored A2, populating the cache.
+      MultimapUserState<byte[], String> userState =
+          new MultimapUserState<>(
+              cache,
+              fakeClient,
+              "instructionId",
+              createMultimapKeyStateKey(),
+              ByteArrayCoder.of(),
+              StringUtf8Coder.of());
+      assertArrayEquals(
+          new String[] {"V1", "V2"}, Iterables.toArray(userState.get(A1), 
String.class));
+      userState.remove(A1);
+      userState.remove(A2);
+      assertThat(userState.keys(), is(emptyIterable()));
+      userState.asyncClose();
+    }
+
+    {
+      // The next user state will load all of its contents from the cache 
including the mutations
+      // persisted via asyncClose; the client below throws on any state request.
+      MultimapUserState<byte[], String> userState =
+          new MultimapUserState<>(
+              cache,
+              requestBuilder -> {
+                throw new IllegalStateException("Unexpected call for test.");
+              },
+              "instructionId",
+              createMultimapKeyStateKey(),
+              ByteArrayCoder.of(),
+              StringUtf8Coder.of());
+      assertThat(userState.get(A1), is(emptyIterable()));
+      assertThat(userState.get(A2), is(emptyIterable()));
+      assertThat(userState.keys(), is(emptyIterable()));
+      userState.asyncClose();
+    }
+  }
+
   private StateKey createMultimapKeyStateKey() throws IOException {
     return StateKey.newBuilder()
         .setMultimapKeysUserState(

Reply via email to