This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 29213ce  [BEAM-13354, BEAM-13015, BEAM-12802, BEAM-12588] Support prefetch for multimap and set state making loading keys and values truly lazy (#16092)
29213ce is described below

commit 29213ce366e7f7ef55ba2ed0b943532289421c19
Author: Lukasz Cwik <[email protected]>
AuthorDate: Fri Dec 3 11:23:28 2021 -0800

    [BEAM-13354, BEAM-13015, BEAM-12802, BEAM-12588] Support prefetch for multimap and set state making loading keys and values truly lazy (#16092)
    
    * [BEAM-13354, BEAM-13015, BEAM-12802, BEAM-12588] Support prefetch for multimap and set state making loading keys and values truly lazy.
    
    Also fix implementation to use structural values when comparing keys.
---
 .../beam/fn/harness/state/FnApiStateAccessor.java  |  19 +-
 .../beam/fn/harness/state/MultimapUserState.java   | 165 +++++---
 .../fn/harness/state/MultimapUserStateTest.java    | 449 +++++++++++++++------
 3 files changed, 450 insertions(+), 183 deletions(-)

diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
index 517be05..55052b4 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
@@ -352,7 +352,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
                       @Override
                       public ReadableState<Boolean> readLater() {
-                        // TODO: Support prefetching.
+                        impl.get(t).iterator().prefetch();
                         return this;
                       }
                     };
@@ -364,7 +364,6 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
                     if (isEmpty) {
                       impl.put(t, null);
                     }
-                    // TODO: Support prefetching.
                     return ReadableStates.immediate(isEmpty);
                   }
 
@@ -389,7 +388,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
                       @Override
                       public ReadableState<Boolean> readLater() {
-                        // TODO: Support prefetching.
+                        impl.keys().iterator().prefetch();
                         return this;
                       }
                     };
@@ -402,7 +401,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
                   @Override
                   public SetState<T> readLater() {
-                    // TODO: Support prefetching.
+                    impl.keys().iterator().prefetch();
                     return this;
                   }
                 };
@@ -469,7 +468,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
                       @Override
                       public ReadableState<ValueT> readLater() {
-                        // TODO: Support prefetching.
+                        impl.get(key).iterator().prefetch();
                         return this;
                       }
                     };
@@ -485,7 +484,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
                       @Override
                       public ReadableState<Iterable<KeyT>> readLater() {
-                        // TODO: Support prefetching.
+                        impl.keys().iterator().prefetch();
                         return this;
                       }
                     };
@@ -501,7 +500,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
                       @Override
                       public ReadableState<Iterable<ValueT>> readLater() {
-                        // TODO: Support prefetching.
+                        entries().readLater();
                         return this;
                       }
                     };
@@ -519,7 +518,9 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
                       @Override
                       public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>> 
readLater() {
-                        // TODO: Support prefetching.
+                        // Start prefetching the keys. We would need to block 
to start prefetching
+                        // the values.
+                        keys().readLater();
                         return this;
                       }
                     };
@@ -535,7 +536,7 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
                       @Override
                       public ReadableState<Boolean> readLater() {
-                        // TODO: Support prefetching.
+                        keys().readLater();
                         return this;
                       }
                     };
diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
index f679608..293d915 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
@@ -21,24 +21,23 @@ import static 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Prec
 
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.NoSuchElementException;
 import java.util.Set;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateAppendRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterables;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
-import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap;
-import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
-import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
 import org.checkerframework.checker.nullness.qual.Nullable;
 
 /**
@@ -65,13 +64,11 @@ public class MultimapUserState<K, V> {
   private boolean isClosed;
   private boolean isCleared;
   // Pending updates to persistent storage
-  private HashSet<K> pendingRemoves = Sets.newHashSet();
-  private HashMap<K, List<V>> pendingAdds = Maps.newHashMap();
-  // Map keys with no values in persistent storage
-  private HashSet<K> negativeCache = Sets.newHashSet();
+  private HashMap<Object, K> pendingRemoves = Maps.newHashMap();
+  private HashMap<Object, KV<K, List<V>>> pendingAdds = Maps.newHashMap();
   // Values retrieved from persistent storage
-  private Multimap<K, V> persistedValues = ArrayListMultimap.create();
-  private @Nullable Iterable<K> persistedKeys = null;
+  private HashMap<K, PrefetchableIterable<V>> persistedValues = 
Maps.newHashMap();
+  private @Nullable PrefetchableIterable<K> persistedKeys = null;
 
   public MultimapUserState(
       BeamFnStateClient beamFnStateClient,
@@ -117,32 +114,36 @@ public class MultimapUserState<K, V> {
         keysStateRequest.getStateKey());
 
     isCleared = true;
-    persistedValues = ArrayListMultimap.create();
+    persistedValues = Maps.newHashMap();
     persistedKeys = null;
-    pendingRemoves = Sets.newHashSet();
+    pendingRemoves = Maps.newHashMap();
     pendingAdds = Maps.newHashMap();
-    negativeCache = Sets.newHashSet();
   }
 
   /*
    * Returns an iterable of the values associated with key in this multimap, 
if any.
    * If there are no values, this returns an empty collection, not null.
    */
-  public Iterable<V> get(K key) {
+  public PrefetchableIterable<V> get(K key) {
     checkState(
         !isClosed,
         "Multimap user state is no longer usable because it is closed for %s",
         keysStateRequest.getStateKey());
 
-    List<V> pendingAddValues = pendingAdds.getOrDefault(key, 
Collections.emptyList());
-    Collection<V> pendingValues =
-        Collections.unmodifiableCollection(pendingAddValues.subList(0, 
pendingAddValues.size()));
-    if (isCleared || pendingRemoves.contains(key)) {
+    Object structuralKey = mapKeyCoder.structuralValue(key);
+    KV<K, List<V>> pendingAddValues = pendingAdds.get(structuralKey);
+
+    PrefetchableIterable<V> pendingValues =
+        pendingAddValues == null
+            ? PrefetchableIterables.fromArray()
+            : PrefetchableIterables.limit(
+                pendingAddValues.getValue(), 
pendingAddValues.getValue().size());
+    if (isCleared || pendingRemoves.containsKey(structuralKey)) {
       return pendingValues;
     }
 
-    Iterable<V> persistedValues = getPersistedValues(key);
-    return Iterables.concat(persistedValues, pendingValues);
+    PrefetchableIterable<V> persistedValues = getPersistedValues(key);
+    return PrefetchableIterables.concat(persistedValues, pendingValues);
   }
 
   @SuppressWarnings({
@@ -151,19 +152,89 @@ public class MultimapUserState<K, V> {
   /*
    * Returns an iterables containing all distinct keys in this multimap.
    */
-  public Iterable<K> keys() {
+  public PrefetchableIterable<K> keys() {
     checkState(
         !isClosed,
         "Multimap user state is no longer usable because it is closed for %s",
         keysStateRequest.getStateKey());
     if (isCleared) {
-      return 
Collections.unmodifiableCollection(Lists.newArrayList(pendingAdds.keySet()));
+      List<K> keys = new ArrayList<>(pendingAdds.size());
+      for (Map.Entry<?, KV<K, List<V>>> entry : pendingAdds.entrySet()) {
+        keys.add(entry.getValue().getKey());
+      }
+      return PrefetchableIterables.concat(keys);
+    }
+
+    PrefetchableIterable<K> persistedKeys = getPersistedKeys();
+    Set<Object> pendingRemovesNow = new HashSet<>(pendingRemoves.keySet());
+    Map<Object, K> pendingAddsNow = new HashMap<>();
+    for (Map.Entry<Object, KV<K, List<V>>> entry : pendingAdds.entrySet()) {
+      pendingAddsNow.put(entry.getKey(), entry.getValue().getKey());
     }
+    return new PrefetchableIterable<K>() {
+      @Override
+      public PrefetchableIterator<K> iterator() {
+        return new PrefetchableIterator<K>() {
+          PrefetchableIterator<K> persistedKeysIterator = 
persistedKeys.iterator();
+          Iterator<K> pendingAddsNowIterator;
+          boolean hasNext;
+          K nextKey;
+
+          @Override
+          public boolean isReady() {
+            return persistedKeysIterator.isReady();
+          }
+
+          @Override
+          public void prefetch() {
+            if (!isReady()) {
+              persistedKeysIterator.prefetch();
+            }
+          }
+
+          @Override
+          public boolean hasNext() {
+            if (hasNext) {
+              return true;
+            }
 
-    Set<K> keys = Sets.newHashSet(getPersistedKeys());
-    keys.removeAll(pendingRemoves);
-    keys.addAll(pendingAdds.keySet());
-    return Collections.unmodifiableCollection(keys);
+            while (persistedKeysIterator.hasNext()) {
+              nextKey = persistedKeysIterator.next();
+              Object nextKeyStructuralValue = 
mapKeyCoder.structuralValue(nextKey);
+              if (!pendingRemovesNow.contains(nextKeyStructuralValue)) {
+                // Remove all keys that we will visit when passing over the 
persistedKeysIterator
+                // so we do not revisit them when passing over the 
pendingAddsNowIterator
+                if (pendingAddsNow.containsKey(nextKeyStructuralValue)) {
+                  pendingAddsNow.remove(nextKeyStructuralValue);
+                }
+                hasNext = true;
+                return true;
+              }
+            }
+
+            if (pendingAddsNowIterator == null) {
+              pendingAddsNowIterator = pendingAddsNow.values().iterator();
+            }
+            while (pendingAddsNowIterator.hasNext()) {
+              nextKey = pendingAddsNowIterator.next();
+              hasNext = true;
+              return true;
+            }
+
+            return false;
+          }
+
+          @Override
+          public K next() {
+            if (!hasNext()) {
+              throw new NoSuchElementException();
+            }
+            hasNext = false;
+            return nextKey;
+          }
+        };
+      }
+    };
   }
 
   /*
@@ -175,8 +246,9 @@ public class MultimapUserState<K, V> {
         !isClosed,
         "Multimap user state is no longer usable because it is closed for %s",
         keysStateRequest.getStateKey());
-    pendingAdds.putIfAbsent(key, new ArrayList<>());
-    pendingAdds.get(key).add(value);
+    Object keyStructuralValue = mapKeyCoder.structuralValue(key);
+    pendingAdds.putIfAbsent(keyStructuralValue, KV.of(key, new ArrayList<>()));
+    pendingAdds.get(keyStructuralValue).getValue().add(value);
   }
 
   /*
@@ -187,9 +259,10 @@ public class MultimapUserState<K, V> {
         !isClosed,
         "Multimap user state is no longer usable because it is closed for %s",
         keysStateRequest.getStateKey());
-    pendingAdds.remove(key);
+    Object keyStructuralValue = mapKeyCoder.structuralValue(key);
+    pendingAdds.remove(keyStructuralValue);
     if (!isCleared) {
-      pendingRemoves.add(key);
+      pendingRemoves.put(keyStructuralValue, key);
     }
   }
 
@@ -215,7 +288,7 @@ public class MultimapUserState<K, V> {
           
.handle(keysStateRequest.toBuilder().setClear(StateClearRequest.getDefaultInstance()))
           .get();
     } else if (!pendingRemoves.isEmpty()) {
-      for (K key : pendingRemoves) {
+      for (K key : pendingRemoves.values()) {
         beamFnStateClient
             .handle(
                 createUserStateRequest(key)
@@ -227,7 +300,7 @@ public class MultimapUserState<K, V> {
 
     // Persist pending key-values
     if (!pendingAdds.isEmpty()) {
-      for (Map.Entry<K, List<V>> entry : pendingAdds.entrySet()) {
+      for (KV<K, List<V>> entry : pendingAdds.values()) {
         beamFnStateClient
             .handle(
                 createUserStateRequest(entry.getKey())
@@ -265,30 +338,22 @@ public class MultimapUserState<K, V> {
     }
   }
 
-  private Iterable<V> getPersistedValues(K key) {
-    if (negativeCache.contains(key)) {
-      return Collections.emptyList();
-    }
-
-    if (persistedValues.get(key).isEmpty()) {
-      Iterable<V> values =
+  private PrefetchableIterable<V> getPersistedValues(K key) {
+    if (!persistedValues.containsKey(key)) {
+      PrefetchableIterable<V> values =
           StateFetchingIterators.readAllAndDecodeStartingFrom(
               beamFnStateClient, createUserStateRequest(key), valueCoder);
-      if (Iterables.isEmpty(values)) {
-        negativeCache.add(key);
-      }
-      persistedValues.putAll(key, values);
+      persistedValues.put(key, values);
     }
-    return Iterables.unmodifiableIterable(persistedValues.get(key));
+    return persistedValues.get(key);
   }
 
-  private Iterable<K> getPersistedKeys() {
+  private PrefetchableIterable<K> getPersistedKeys() {
     checkState(!isCleared);
     if (persistedKeys == null) {
-      Iterable<K> keys =
+      persistedKeys =
           StateFetchingIterators.readAllAndDecodeStartingFrom(
               beamFnStateClient, keysStateRequest, mapKeyCoder);
-      persistedKeys = Iterables.unmodifiableIterable(keys);
     }
     return persistedKeys;
   }
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
index 23f1be4..72509f9 100644
--- 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
@@ -26,12 +26,17 @@ import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThrows;
 
 import java.io.IOException;
+import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.Iterator;
 import java.util.Map;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.NullableCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
 import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
@@ -39,9 +44,19 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
+/**
+ * Tests for {@link MultimapUserState}.
+ *
+ * <p>It is important to use a key type where its coder is not {@link 
Coder#consistentWithEquals()}
+ * to ensure that comparisons are performed using structural values instead of 
object equality
+ * during testing.
+ */
 @RunWith(JUnit4.class)
 public class MultimapUserStateTest {
-
+  private static final byte[] A0 = "A0".getBytes(StandardCharsets.UTF_8);
+  private static final byte[] A1 = "A1".getBytes(StandardCharsets.UTF_8);
+  private static final byte[] A2 = "A2".getBytes(StandardCharsets.UTF_8);
+  private static final byte[] A3 = "A3".getBytes(StandardCharsets.UTF_8);
   private final String pTransformId = "pTransformId";
   private final String stateId = "stateId";
   private final String encodedKey = "encodedKey";
@@ -50,7 +65,7 @@ public class MultimapUserStateTest {
   @Test
   public void testNoPersistedValues() throws Exception {
     FakeBeamFnStateClient fakeClient = new 
FakeBeamFnStateClient(Collections.emptyMap());
-    MultimapUserState<String, String> userState =
+    MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
             fakeClient,
             "instructionId",
@@ -58,7 +73,7 @@ public class MultimapUserStateTest {
             stateId,
             encode(encodedWindow),
             encode(encodedKey),
-            StringUtf8Coder.of(),
+            ByteArrayCoder.of(),
             StringUtf8Coder.of());
     assertThat(userState.keys(), is(emptyIterable()));
   }
@@ -69,10 +84,10 @@ public class MultimapUserStateTest {
         new FakeBeamFnStateClient(
             ImmutableMap.of(
                 createMultimapKeyStateKey(),
-                encode("A1"),
-                createMultimapValueStateKey("A1"),
+                encode(A1),
+                createMultimapValueStateKey(A1),
                 encode("V1", "V2")));
-    MultimapUserState<String, String> userState =
+    MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
             fakeClient,
             "instructionId",
@@ -80,17 +95,17 @@ public class MultimapUserStateTest {
             stateId,
             encode(encodedWindow),
             encode(encodedKey),
-            StringUtf8Coder.of(),
+            ByteArrayCoder.of(),
             StringUtf8Coder.of());
 
-    Iterable<String> initValues = userState.get("A1");
-    userState.put("A1", "V3");
+    Iterable<String> initValues = userState.get(A1);
+    userState.put(A1, "V3");
     assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues, 
String.class));
     assertArrayEquals(
-        new String[] {"V1", "V2", "V3"}, 
Iterables.toArray(userState.get("A1"), String.class));
-    assertArrayEquals(new String[] {}, Iterables.toArray(userState.get("A2"), 
String.class));
+        new String[] {"V1", "V2", "V3"}, Iterables.toArray(userState.get(A1), 
String.class));
+    assertArrayEquals(new String[] {}, Iterables.toArray(userState.get(A2), 
String.class));
     userState.asyncClose();
-    assertThrows(IllegalStateException.class, () -> userState.get("A1"));
+    assertThrows(IllegalStateException.class, () -> userState.get(A1));
   }
 
   @Test
@@ -99,10 +114,10 @@ public class MultimapUserStateTest {
         new FakeBeamFnStateClient(
             ImmutableMap.of(
                 createMultimapKeyStateKey(),
-                encode("A1"),
-                createMultimapValueStateKey("A1"),
+                encode(A1),
+                createMultimapValueStateKey(A1),
                 encode("V1", "V2")));
-    MultimapUserState<String, String> userState =
+    MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
             fakeClient,
             "instructionId",
@@ -110,19 +125,19 @@ public class MultimapUserStateTest {
             stateId,
             encode(encodedWindow),
             encode(encodedKey),
-            StringUtf8Coder.of(),
+            ByteArrayCoder.of(),
             StringUtf8Coder.of());
 
-    Iterable<String> initValues = userState.get("A1");
+    Iterable<String> initValues = userState.get(A1);
     userState.clear();
     assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues, 
String.class));
-    assertThat(userState.get("A1"), is(emptyIterable()));
+    assertThat(userState.get(A1), is(emptyIterable()));
     assertThat(userState.keys(), is(emptyIterable()));
 
-    userState.put("A1", "V1");
+    userState.put(A1, "V1");
     userState.clear();
     assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues, 
String.class));
-    assertThat(userState.get("A1"), is(emptyIterable()));
+    assertThat(userState.get(A1), is(emptyIterable()));
     assertThat(userState.keys(), is(emptyIterable()));
 
     userState.asyncClose();
@@ -135,10 +150,10 @@ public class MultimapUserStateTest {
         new FakeBeamFnStateClient(
             ImmutableMap.of(
                 createMultimapKeyStateKey(),
-                encode("A1"),
-                createMultimapValueStateKey("A1"),
+                encode(A1),
+                createMultimapValueStateKey(A1),
                 encode("V1", "V2")));
-    MultimapUserState<String, String> userState =
+    MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
             fakeClient,
             "instructionId",
@@ -146,20 +161,19 @@ public class MultimapUserStateTest {
             stateId,
             encode(encodedWindow),
             encode(encodedKey),
-            StringUtf8Coder.of(),
+            ByteArrayCoder.of(),
             StringUtf8Coder.of());
 
-    userState.put("A2", "V1");
-    Iterable<String> initKeys = userState.keys();
-    userState.put("A3", "V1");
-    userState.put("A1", "V3");
-    assertArrayEquals(new String[] {"A1", "A2"}, Iterables.toArray(initKeys, 
String.class));
-    assertArrayEquals(
-        new String[] {"A1", "A2", "A3"}, Iterables.toArray(userState.keys(), 
String.class));
+    userState.put(A2, "V1");
+    Iterable<byte[]> initKeys = userState.keys();
+    userState.put(A3, "V1");
+    userState.put(A1, "V3");
+    assertArrayEquals(new byte[][] {A1, A2}, Iterables.toArray(initKeys, 
byte[].class));
+    assertArrayEquals(new byte[][] {A1, A2, A3}, 
Iterables.toArray(userState.keys(), byte[].class));
 
     userState.clear();
-    assertArrayEquals(new String[] {"A1", "A2"}, Iterables.toArray(initKeys, 
String.class));
-    assertArrayEquals(new String[] {}, Iterables.toArray(userState.keys(), 
String.class));
+    assertArrayEquals(new byte[][] {A1, A2}, Iterables.toArray(initKeys, 
byte[].class));
+    assertArrayEquals(new byte[][] {}, Iterables.toArray(userState.keys(), 
byte[].class));
     userState.asyncClose();
     assertThrows(IllegalStateException.class, () -> userState.keys());
   }
@@ -170,10 +184,10 @@ public class MultimapUserStateTest {
         new FakeBeamFnStateClient(
             ImmutableMap.of(
                 createMultimapKeyStateKey(),
-                encode("A1"),
-                createMultimapValueStateKey("A1"),
+                encode(A1),
+                createMultimapValueStateKey(A1),
                 encode("V1", "V2")));
-    MultimapUserState<String, String> userState =
+    MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
             fakeClient,
             "instructionId",
@@ -181,16 +195,16 @@ public class MultimapUserStateTest {
             stateId,
             encode(encodedWindow),
             encode(encodedKey),
-            StringUtf8Coder.of(),
+            ByteArrayCoder.of(),
             StringUtf8Coder.of());
 
-    Iterable<String> initValues = userState.get("A1");
-    userState.put("A1", "V3");
+    Iterable<String> initValues = userState.get(A1);
+    userState.put(A1, "V3");
     assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues, 
String.class));
     assertArrayEquals(
-        new String[] {"V1", "V2", "V3"}, 
Iterables.toArray(userState.get("A1"), String.class));
+        new String[] {"V1", "V2", "V3"}, Iterables.toArray(userState.get(A1), 
String.class));
     userState.asyncClose();
-    assertThrows(IllegalStateException.class, () -> userState.put("A1", "V2"));
+    assertThrows(IllegalStateException.class, () -> userState.put(A1, "V2"));
   }
 
   @Test
@@ -199,10 +213,10 @@ public class MultimapUserStateTest {
         new FakeBeamFnStateClient(
             ImmutableMap.of(
                 createMultimapKeyStateKey(),
-                encode("A0"),
-                createMultimapValueStateKey("A0"),
+                encode(A0),
+                createMultimapValueStateKey(A0),
                 encode("V1")));
-    MultimapUserState<String, String> userState =
+    MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
             fakeClient,
             "instructionId",
@@ -210,14 +224,14 @@ public class MultimapUserStateTest {
             stateId,
             encode(encodedWindow),
             encode(encodedKey),
-            StringUtf8Coder.of(),
+            ByteArrayCoder.of(),
             StringUtf8Coder.of());
-    userState.remove("A0");
-    userState.put("A0", "V2");
-    assertArrayEquals(new String[] {"V2"}, 
Iterables.toArray(userState.get("A0"), String.class));
+    userState.remove(A0);
+    userState.put(A0, "V2");
+    assertArrayEquals(new String[] {"V2"}, 
Iterables.toArray(userState.get(A0), String.class));
     userState.asyncClose();
     Map<StateKey, ByteString> data = fakeClient.getData();
-    assertEquals(encode("V2"), data.get(createMultimapValueStateKey("A0")));
+    assertEquals(encode("V2"), data.get(createMultimapValueStateKey(A0)));
   }
 
   @Test
@@ -226,10 +240,10 @@ public class MultimapUserStateTest {
         new FakeBeamFnStateClient(
             ImmutableMap.of(
                 createMultimapKeyStateKey(),
-                encode("A0"),
-                createMultimapValueStateKey("A0"),
+                encode(A0),
+                createMultimapValueStateKey(A0),
                 encode("V1")));
-    MultimapUserState<String, String> userState =
+    MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
             fakeClient,
             "instructionId",
@@ -237,11 +251,11 @@ public class MultimapUserStateTest {
             stateId,
             encode(encodedWindow),
             encode(encodedKey),
-            StringUtf8Coder.of(),
+            ByteArrayCoder.of(),
             StringUtf8Coder.of());
     userState.clear();
-    userState.put("A0", "V2");
-    assertArrayEquals(new String[] {"V2"}, 
Iterables.toArray(userState.get("A0"), String.class));
+    userState.put(A0, "V2");
+    assertArrayEquals(new String[] {"V2"}, 
Iterables.toArray(userState.get(A0), String.class));
   }
 
   @Test
@@ -250,10 +264,10 @@ public class MultimapUserStateTest {
         new FakeBeamFnStateClient(
             ImmutableMap.of(
                 createMultimapKeyStateKey(),
-                encode("A0"),
-                createMultimapValueStateKey("A0"),
+                encode(A0),
+                createMultimapValueStateKey(A0),
                 encode("V1")));
-    MultimapUserState<String, String> userState =
+    MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
             fakeClient,
             "instructionId",
@@ -261,9 +275,9 @@ public class MultimapUserStateTest {
             stateId,
             encode(encodedWindow),
             encode(encodedKey),
-            StringUtf8Coder.of(),
+            ByteArrayCoder.of(),
             StringUtf8Coder.of());
-    userState.remove("A0");
+    userState.remove(A0);
     userState.clear();
     userState.asyncClose();
     // Clear takes precedence over specific key remove
@@ -273,7 +287,7 @@ public class MultimapUserStateTest {
   @Test
   public void testPutBeforeClear() throws Exception {
     FakeBeamFnStateClient fakeClient = new 
FakeBeamFnStateClient(Collections.emptyMap());
-    MultimapUserState<String, String> userState =
+    MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
             fakeClient,
             "instructionId",
@@ -281,11 +295,11 @@ public class MultimapUserStateTest {
             stateId,
             encode(encodedWindow),
             encode(encodedKey),
-            StringUtf8Coder.of(),
+            ByteArrayCoder.of(),
             StringUtf8Coder.of());
-    userState.put("A0", "V0");
-    userState.put("A1", "V1");
-    Iterable<String> values = userState.get("A1"); // fakeClient call = 1
+    userState.put(A0, "V0");
+    userState.put(A1, "V1");
+    Iterable<String> values = userState.get(A1); // fakeClient call = 1
     userState.clear(); // fakeClient call = 2
     assertArrayEquals(new String[] {"V1"}, Iterables.toArray(values, 
String.class));
     userState.asyncClose();
@@ -296,7 +310,7 @@ public class MultimapUserStateTest {
   @Test
   public void testPutBeforeRemove() throws Exception {
     FakeBeamFnStateClient fakeClient = new 
FakeBeamFnStateClient(Collections.emptyMap());
-    MultimapUserState<String, String> userState =
+    MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
             fakeClient,
             "instructionId",
@@ -304,18 +318,18 @@ public class MultimapUserStateTest {
             stateId,
             encode(encodedWindow),
             encode(encodedKey),
-            StringUtf8Coder.of(),
+            ByteArrayCoder.of(),
             StringUtf8Coder.of());
-    userState.put("A0", "V0");
-    userState.put("A1", "V1");
-    Iterable<String> values = userState.get("A1"); // fakeClient call = 1
-    userState.remove("A0"); // fakeClient call = 2
-    userState.remove("A1"); // fakeClient call = 3
+    userState.put(A0, "V0");
+    userState.put(A1, "V1");
+    Iterable<String> values = userState.get(A1); // fakeClient call = 1
+    userState.remove(A0); // fakeClient call = 2
+    userState.remove(A1); // fakeClient call = 3
     assertArrayEquals(new String[] {"V1"}, Iterables.toArray(values, 
String.class));
     userState.asyncClose();
     assertThat(fakeClient.getCallCount(), is(3));
-    assertNull(fakeClient.getData().get(createMultimapValueStateKey("A0")));
-    assertNull(fakeClient.getData().get(createMultimapValueStateKey("A1")));
+    assertNull(fakeClient.getData().get(createMultimapValueStateKey(A0)));
+    assertNull(fakeClient.getData().get(createMultimapValueStateKey(A1)));
   }
 
   @Test
@@ -324,10 +338,10 @@ public class MultimapUserStateTest {
         new FakeBeamFnStateClient(
             ImmutableMap.of(
                 createMultimapKeyStateKey(),
-                encode("A1"),
-                createMultimapValueStateKey("A1"),
+                encode(A1),
+                createMultimapValueStateKey(A1),
                 encode("V1", "V2")));
-    MultimapUserState<String, String> userState =
+    MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
             fakeClient,
             "instructionId",
@@ -335,17 +349,17 @@ public class MultimapUserStateTest {
             stateId,
             encode(encodedWindow),
             encode(encodedKey),
-            StringUtf8Coder.of(),
+            ByteArrayCoder.of(),
             StringUtf8Coder.of());
 
-    Iterable<String> initValues = userState.get("A1");
-    userState.put("A1", "V3");
+    Iterable<String> initValues = userState.get(A1);
+    userState.put(A1, "V3");
 
-    userState.remove("A1");
+    userState.remove(A1);
     assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues, String.class));
     assertThat(userState.keys(), is(emptyIterable()));
     userState.asyncClose();
-    assertThrows(IllegalStateException.class, () -> userState.remove("A1"));
+    assertThrows(IllegalStateException.class, () -> userState.remove(A1));
   }
 
   @Test
@@ -354,10 +368,10 @@ public class MultimapUserStateTest {
         new FakeBeamFnStateClient(
             ImmutableMap.of(
                 createMultimapKeyStateKey(),
-                encode("A1"),
-                createMultimapValueStateKey("A1"),
+                encode(A1),
+                createMultimapValueStateKey(A1),
                 encode("V1", "V2")));
-    MultimapUserState<String, String> userState =
+    MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
             fakeClient,
             "instructionId",
@@ -365,11 +379,12 @@ public class MultimapUserStateTest {
             stateId,
             encode(encodedWindow),
             encode(encodedKey),
-            StringUtf8Coder.of(),
+            ByteArrayCoder.of(),
             StringUtf8Coder.of());
-    Iterable<String> keys = userState.keys();
-    assertThrows(
-        UnsupportedOperationException.class, () -> Iterables.removeAll(keys, Arrays.asList("A1")));
+    Iterable<byte[]> keys = userState.keys();
+    Iterator<byte[]> keysIterator = keys.iterator();
+    keysIterator.next();
+    assertThrows(UnsupportedOperationException.class, () -> keysIterator.remove());
   }
 
   @Test
@@ -378,10 +393,10 @@ public class MultimapUserStateTest {
         new FakeBeamFnStateClient(
             ImmutableMap.of(
                 createMultimapKeyStateKey(),
-                encode("A1"),
-                createMultimapValueStateKey("A1"),
+                encode(A1),
+                createMultimapValueStateKey(A1),
                 encode("V1", "V2")));
-    MultimapUserState<String, String> userState =
+    MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
             fakeClient,
             "instructionId",
@@ -389,9 +404,9 @@ public class MultimapUserStateTest {
             stateId,
             encode(encodedWindow),
             encode(encodedKey),
-            StringUtf8Coder.of(),
+            ByteArrayCoder.of(),
             StringUtf8Coder.of());
-    Iterable<String> values = userState.get("A1");
+    Iterable<String> values = userState.get(A1);
     assertThrows(
         UnsupportedOperationException.class,
         () -> Iterables.removeAll(values, Arrays.asList("V1")));
@@ -403,10 +418,10 @@ public class MultimapUserStateTest {
         new FakeBeamFnStateClient(
             ImmutableMap.of(
                 createMultimapKeyStateKey(),
-                encode("A1"),
-                createMultimapValueStateKey("A1"),
+                encode(A1),
+                createMultimapValueStateKey(A1),
                 encode("V1", "V2")));
-    MultimapUserState<String, String> userState =
+    MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
             fakeClient,
             "instructionId",
@@ -414,7 +429,7 @@ public class MultimapUserStateTest {
             stateId,
             encode(encodedWindow),
             encode(encodedKey),
-            StringUtf8Coder.of(),
+            ByteArrayCoder.of(),
             StringUtf8Coder.of());
     userState.clear();
     userState.asyncClose();
@@ -429,10 +444,10 @@ public class MultimapUserStateTest {
         new FakeBeamFnStateClient(
             ImmutableMap.of(
                 createMultimapKeyStateKey(),
-                encode("A1"),
-                createMultimapValueStateKey("A1"),
+                encode(A1),
+                createMultimapValueStateKey(A1),
                 encode("V1", "V2")));
-    MultimapUserState<String, String> userState =
+    MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
             fakeClient,
             "instructionId",
@@ -440,7 +455,7 @@ public class MultimapUserStateTest {
             stateId,
             encode(encodedWindow),
             encode(encodedKey),
-            StringUtf8Coder.of(),
+            ByteArrayCoder.of(),
             StringUtf8Coder.of());
     userState.asyncClose();
     assertThrows(IllegalStateException.class, () -> userState.keys());
@@ -453,12 +468,12 @@ public class MultimapUserStateTest {
         new FakeBeamFnStateClient(
             ImmutableMap.of(
                 createMultimapKeyStateKey(),
-                encode("A0", "A1"),
-                createMultimapValueStateKey("A0"),
+                encode(A0, A1),
+                createMultimapValueStateKey(A0),
                 encode("V1"),
-                createMultimapValueStateKey("A1"),
+                createMultimapValueStateKey(A1),
                 encode("V1", "V2")));
-    MultimapUserState<String, String> userState =
+    MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
             fakeClient,
             "instructionId",
@@ -466,18 +481,18 @@ public class MultimapUserStateTest {
             stateId,
             encode(encodedWindow),
             encode(encodedKey),
-            StringUtf8Coder.of(),
+            ByteArrayCoder.of(),
             StringUtf8Coder.of());
-    userState.remove("A0");
-    userState.put("A1", "V3");
-    userState.put("A2", "V1");
-    userState.put("A3", "V1");
-    userState.remove("A3");
+    userState.remove(A0);
+    userState.put(A1, "V3");
+    userState.put(A2, "V1");
+    userState.put(A3, "V1");
+    userState.remove(A3);
     userState.asyncClose();
     Map<StateKey, ByteString> data = fakeClient.getData();
-    assertNull(data.get(createMultimapValueStateKey("A0")));
-    assertEquals(encode("V1", "V2", "V3"), data.get(createMultimapValueStateKey("A1")));
-    assertEquals(encode("V1"), data.get(createMultimapValueStateKey("A2")));
+    assertNull(data.get(createMultimapValueStateKey(A0)));
+    assertEquals(encode("V1", "V2", "V3"), data.get(createMultimapValueStateKey(A1)));
+    assertEquals(encode("V1"), data.get(createMultimapValueStateKey(A2)));
   }
 
   @Test
@@ -486,10 +501,10 @@ public class MultimapUserStateTest {
         new FakeBeamFnStateClient(
             ImmutableMap.of(
                 createMultimapKeyStateKey(),
-                encode("A1"),
-                createMultimapValueStateKey("A1"),
+                encode(A1),
+                createMultimapValueStateKey(A1),
                 encode("V1", "V2")));
-    MultimapUserState<String, String> userState =
+    MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
             fakeClient,
             "instructionId",
@@ -497,7 +512,7 @@ public class MultimapUserStateTest {
             stateId,
             encode(encodedWindow),
             encode(encodedKey),
-            NullableCoder.of(StringUtf8Coder.of()),
+            NullableCoder.of(ByteArrayCoder.of()),
             NullableCoder.of(StringUtf8Coder.of()));
     userState.put(null, null);
     userState.put(null, null);
@@ -509,7 +524,7 @@ public class MultimapUserStateTest {
   @Test
   public void testNegativeCache() throws Exception {
     FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(Collections.emptyMap());
-    MultimapUserState<String, String> userState =
+    MultimapUserState<byte[], String> userState =
         new MultimapUserState<>(
             fakeClient,
             "instructionId",
@@ -517,13 +532,191 @@ public class MultimapUserStateTest {
             stateId,
             encode(encodedWindow),
             encode(encodedKey),
-            StringUtf8Coder.of(),
+            ByteArrayCoder.of(),
             StringUtf8Coder.of());
-    userState.get("A1");
-    userState.get("A1");
+    assertArrayEquals(new String[] {}, Iterables.toArray(userState.get(A1), String.class));
+    assertArrayEquals(new String[] {}, Iterables.toArray(userState.get(A1), String.class));
     assertThat(fakeClient.getCallCount(), is(1));
   }
 
+  @Test
+  public void testGetValuesPrefetch() throws Exception {
+    FakeBeamFnStateClient fakeClient =
+        new FakeBeamFnStateClient(
+            ImmutableMap.of(
+                createMultimapKeyStateKey(),
+                encode(A1),
+                createMultimapValueStateKey(A1),
+                encode("V1", "V2")));
+    MultimapUserState<byte[], String> userState =
+        new MultimapUserState<>(
+            fakeClient,
+            "instructionId",
+            pTransformId,
+            stateId,
+            encode(encodedWindow),
+            encode(encodedKey),
+            ByteArrayCoder.of(),
+            StringUtf8Coder.of());
+
+    PrefetchableIterable<String> values = userState.get(A1);
+    assertEquals(0, fakeClient.getCallCount());
+    values.iterator().prefetch();
+    assertEquals(1, fakeClient.getCallCount());
+    assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(values, String.class));
+    assertEquals(1, fakeClient.getCallCount());
+  }
+
+  @Test
+  public void testGetKeysPrefetch() throws Exception {
+    FakeBeamFnStateClient fakeClient =
+        new FakeBeamFnStateClient(
+            ImmutableMap.of(
+                createMultimapKeyStateKey(),
+                encode(A1),
+                createMultimapValueStateKey(A1),
+                encode("V1", "V2")));
+    MultimapUserState<byte[], String> userState =
+        new MultimapUserState<>(
+            fakeClient,
+            "instructionId",
+            pTransformId,
+            stateId,
+            encode(encodedWindow),
+            encode(encodedKey),
+            ByteArrayCoder.of(),
+            StringUtf8Coder.of());
+
+    PrefetchableIterable<byte[]> keys = userState.keys();
+    assertEquals(0, fakeClient.getCallCount());
+    keys.iterator().prefetch();
+    assertEquals(1, fakeClient.getCallCount());
+    assertArrayEquals(new byte[][] {A1}, Iterables.toArray(keys, byte[].class));
+    assertEquals(1, fakeClient.getCallCount());
+  }
+
+  @Test
+  public void testPutKeysPrefetch() throws Exception {
+    FakeBeamFnStateClient fakeClient =
+        new FakeBeamFnStateClient(
+            ImmutableMap.of(
+                createMultimapKeyStateKey(),
+                encode(A1),
+                createMultimapValueStateKey(A1),
+                encode("V1", "V2")));
+    MultimapUserState<byte[], String> userState =
+        new MultimapUserState<>(
+            fakeClient,
+            "instructionId",
+            pTransformId,
+            stateId,
+            encode(encodedWindow),
+            encode(encodedKey),
+            ByteArrayCoder.of(),
+            StringUtf8Coder.of());
+
+    userState.put(A2, "V3");
+    PrefetchableIterable<byte[]> keys = userState.keys();
+    assertEquals(0, fakeClient.getCallCount());
+    keys.iterator().prefetch();
+    assertEquals(1, fakeClient.getCallCount());
+    assertArrayEquals(new byte[][] {A1, A2}, Iterables.toArray(keys, byte[].class));
+    assertEquals(1, fakeClient.getCallCount());
+  }
+
+  @Test
+  public void testRemoveKeysPrefetch() throws Exception {
+    FakeBeamFnStateClient fakeClient =
+        new FakeBeamFnStateClient(
+            ImmutableMap.of(
+                createMultimapKeyStateKey(),
+                encode(A1),
+                createMultimapValueStateKey(A1),
+                encode("V1", "V2")));
+    MultimapUserState<byte[], String> userState =
+        new MultimapUserState<>(
+            fakeClient,
+            "instructionId",
+            pTransformId,
+            stateId,
+            encode(encodedWindow),
+            encode(encodedKey),
+            ByteArrayCoder.of(),
+            StringUtf8Coder.of());
+
+    userState.remove(A1);
+    userState.put(A1, "V3");
+    PrefetchableIterable<String> values = userState.get(A1);
+    assertEquals(0, fakeClient.getCallCount());
+    values.iterator().prefetch();
+    // Removed keys don't require accessing the underlying persisted state
+    assertEquals(0, fakeClient.getCallCount());
+    assertArrayEquals(new String[] {"V3"}, Iterables.toArray(values, String.class));
+    // Removed keys don't require accessing the underlying persisted state
+    assertEquals(0, fakeClient.getCallCount());
+  }
+
+  @Test
+  public void testClearPrefetch() throws Exception {
+    FakeBeamFnStateClient fakeClient =
+        new FakeBeamFnStateClient(
+            ImmutableMap.of(
+                createMultimapKeyStateKey(),
+                encode(A1),
+                createMultimapValueStateKey(A1),
+                encode("V1", "V2")));
+    MultimapUserState<byte[], String> userState =
+        new MultimapUserState<>(
+            fakeClient,
+            "instructionId",
+            pTransformId,
+            stateId,
+            encode(encodedWindow),
+            encode(encodedKey),
+            ByteArrayCoder.of(),
+            StringUtf8Coder.of());
+
+    userState.clear();
+    userState.put(A2, "V3");
+    PrefetchableIterable<byte[]> keys = userState.keys();
+    assertEquals(0, fakeClient.getCallCount());
+    keys.iterator().prefetch();
+    // Cleared keys don't require accessing the underlying persisted state
+    assertEquals(0, fakeClient.getCallCount());
+    assertArrayEquals(new byte[][] {A2}, Iterables.toArray(keys, byte[].class));
+    // Cleared keys don't require accessing the underlying persisted state
+    assertEquals(0, fakeClient.getCallCount());
+  }
+
+  @Test
+  public void testAppendValuesPrefetch() throws Exception {
+    FakeBeamFnStateClient fakeClient =
+        new FakeBeamFnStateClient(
+            ImmutableMap.of(
+                createMultimapKeyStateKey(),
+                encode(A1),
+                createMultimapValueStateKey(A1),
+                encode("V1", "V2")));
+    MultimapUserState<byte[], String> userState =
+        new MultimapUserState<>(
+            fakeClient,
+            "instructionId",
+            pTransformId,
+            stateId,
+            encode(encodedWindow),
+            encode(encodedKey),
+            ByteArrayCoder.of(),
+            StringUtf8Coder.of());
+
+    userState.put(A1, "V3");
+    PrefetchableIterable<String> values = userState.get(A1);
+    assertEquals(0, fakeClient.getCallCount());
+    values.iterator().prefetch();
+    assertEquals(1, fakeClient.getCallCount());
+    assertArrayEquals(new String[] {"V1", "V2", "V3"}, Iterables.toArray(values, String.class));
+    assertEquals(1, fakeClient.getCallCount());
+  }
+
   private StateKey createMultimapKeyStateKey() throws IOException {
     return StateKey.newBuilder()
         .setMultimapKeysUserState(
@@ -535,7 +728,7 @@ public class MultimapUserStateTest {
         .build();
   }
 
-  private StateKey createMultimapValueStateKey(String key) throws IOException {
+  private StateKey createMultimapValueStateKey(byte[] key) throws IOException {
     return StateKey.newBuilder()
         .setMultimapUserState(
             StateKey.MultimapUserState.newBuilder()
@@ -554,4 +747,12 @@ public class MultimapUserStateTest {
     }
     return out.toByteString();
   }
+
+  private ByteString encode(byte[]... values) throws IOException {
+    ByteString.Output out = ByteString.newOutput();
+    for (byte[] value : values) {
+      ByteArrayCoder.of().encode(value, out);
+    }
+    return out.toByteString();
+  }
 }

Reply via email to