This is an automated email from the ASF dual-hosted git repository.
lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 29213ce [BEAM-13354, BEAM-13015, BEAM-12802, BEAM-12588] Support
prefetch for multimap and set state making loading keys and values truly lazy
(#16092)
29213ce is described below
commit 29213ce366e7f7ef55ba2ed0b943532289421c19
Author: Lukasz Cwik <[email protected]>
AuthorDate: Fri Dec 3 11:23:28 2021 -0800
[BEAM-13354, BEAM-13015, BEAM-12802, BEAM-12588] Support prefetch for
multimap and set state making loading keys and values truly lazy (#16092)
* [BEAM-13354, BEAM-13015, BEAM-12802, BEAM-12588] Support prefetch for
multimap and set state making loading keys and values truly lazy.
Also fix implementation to use structural values when comparing keys.
---
.../beam/fn/harness/state/FnApiStateAccessor.java | 19 +-
.../beam/fn/harness/state/MultimapUserState.java | 165 +++++---
.../fn/harness/state/MultimapUserStateTest.java | 449 +++++++++++++++------
3 files changed, 450 insertions(+), 183 deletions(-)
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
index 517be05..55052b4 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
@@ -352,7 +352,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public ReadableState<Boolean> readLater() {
- // TODO: Support prefetching.
+ impl.get(t).iterator().prefetch();
return this;
}
};
@@ -364,7 +364,6 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
if (isEmpty) {
impl.put(t, null);
}
- // TODO: Support prefetching.
return ReadableStates.immediate(isEmpty);
}
@@ -389,7 +388,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public ReadableState<Boolean> readLater() {
- // TODO: Support prefetching.
+ impl.keys().iterator().prefetch();
return this;
}
};
@@ -402,7 +401,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public SetState<T> readLater() {
- // TODO: Support prefetching.
+ impl.keys().iterator().prefetch();
return this;
}
};
@@ -469,7 +468,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public ReadableState<ValueT> readLater() {
- // TODO: Support prefetching.
+ impl.get(key).iterator().prefetch();
return this;
}
};
@@ -485,7 +484,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public ReadableState<Iterable<KeyT>> readLater() {
- // TODO: Support prefetching.
+ impl.keys().iterator().prefetch();
return this;
}
};
@@ -501,7 +500,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public ReadableState<Iterable<ValueT>> readLater() {
- // TODO: Support prefetching.
+ entries().readLater();
return this;
}
};
@@ -519,7 +518,9 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>>
readLater() {
- // TODO: Support prefetching.
+ // Start prefetching the keys. We would need to block
to start prefetching
+ // the values.
+ keys().readLater();
return this;
}
};
@@ -535,7 +536,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public ReadableState<Boolean> readLater() {
- // TODO: Support prefetching.
+ keys().readLater();
return this;
}
};
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
index f679608..293d915 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
@@ -21,24 +21,23 @@ import static
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Prec
import java.io.IOException;
import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.NoSuchElementException;
import java.util.Set;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateAppendRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterables;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
+import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
-import
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap;
-import
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
-import
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
import org.checkerframework.checker.nullness.qual.Nullable;
/**
@@ -65,13 +64,11 @@ public class MultimapUserState<K, V> {
private boolean isClosed;
private boolean isCleared;
// Pending updates to persistent storage
- private HashSet<K> pendingRemoves = Sets.newHashSet();
- private HashMap<K, List<V>> pendingAdds = Maps.newHashMap();
- // Map keys with no values in persistent storage
- private HashSet<K> negativeCache = Sets.newHashSet();
+ private HashMap<Object, K> pendingRemoves = Maps.newHashMap();
+ private HashMap<Object, KV<K, List<V>>> pendingAdds = Maps.newHashMap();
// Values retrieved from persistent storage
- private Multimap<K, V> persistedValues = ArrayListMultimap.create();
- private @Nullable Iterable<K> persistedKeys = null;
+ private HashMap<K, PrefetchableIterable<V>> persistedValues =
Maps.newHashMap();
+ private @Nullable PrefetchableIterable<K> persistedKeys = null;
public MultimapUserState(
BeamFnStateClient beamFnStateClient,
@@ -117,32 +114,36 @@ public class MultimapUserState<K, V> {
keysStateRequest.getStateKey());
isCleared = true;
- persistedValues = ArrayListMultimap.create();
+ persistedValues = Maps.newHashMap();
persistedKeys = null;
- pendingRemoves = Sets.newHashSet();
+ pendingRemoves = Maps.newHashMap();
pendingAdds = Maps.newHashMap();
- negativeCache = Sets.newHashSet();
}
/*
* Returns an iterable of the values associated with key in this multimap,
if any.
* If there are no values, this returns an empty collection, not null.
*/
- public Iterable<V> get(K key) {
+ public PrefetchableIterable<V> get(K key) {
checkState(
!isClosed,
"Multimap user state is no longer usable because it is closed for %s",
keysStateRequest.getStateKey());
- List<V> pendingAddValues = pendingAdds.getOrDefault(key,
Collections.emptyList());
- Collection<V> pendingValues =
- Collections.unmodifiableCollection(pendingAddValues.subList(0,
pendingAddValues.size()));
- if (isCleared || pendingRemoves.contains(key)) {
+ Object structuralKey = mapKeyCoder.structuralValue(key);
+ KV<K, List<V>> pendingAddValues = pendingAdds.get(structuralKey);
+
+ PrefetchableIterable<V> pendingValues =
+ pendingAddValues == null
+ ? PrefetchableIterables.fromArray()
+ : PrefetchableIterables.limit(
+ pendingAddValues.getValue(),
pendingAddValues.getValue().size());
+ if (isCleared || pendingRemoves.containsKey(structuralKey)) {
return pendingValues;
}
- Iterable<V> persistedValues = getPersistedValues(key);
- return Iterables.concat(persistedValues, pendingValues);
+ PrefetchableIterable<V> persistedValues = getPersistedValues(key);
+ return PrefetchableIterables.concat(persistedValues, pendingValues);
}
@SuppressWarnings({
@@ -151,19 +152,89 @@ public class MultimapUserState<K, V> {
/*
* Returns an iterables containing all distinct keys in this multimap.
*/
- public Iterable<K> keys() {
+ public PrefetchableIterable<K> keys() {
checkState(
!isClosed,
"Multimap user state is no longer usable because it is closed for %s",
keysStateRequest.getStateKey());
if (isCleared) {
- return
Collections.unmodifiableCollection(Lists.newArrayList(pendingAdds.keySet()));
+ List<K> keys = new ArrayList<>(pendingAdds.size());
+ for (Map.Entry<?, KV<K, List<V>>> entry : pendingAdds.entrySet()) {
+ keys.add(entry.getValue().getKey());
+ }
+ return PrefetchableIterables.concat(keys);
+ }
+
+ PrefetchableIterable<K> persistedKeys = getPersistedKeys();
+ Set<Object> pendingRemovesNow = new HashSet<>(pendingRemoves.keySet());
+ Map<Object, K> pendingAddsNow = new HashMap<>();
+ for (Map.Entry<Object, KV<K, List<V>>> entry : pendingAdds.entrySet()) {
+ pendingAddsNow.put(entry.getKey(), entry.getValue().getKey());
}
+ return new PrefetchableIterable<K>() {
+ @Override
+ public PrefetchableIterator<K> iterator() {
+ return new PrefetchableIterator<K>() {
+ PrefetchableIterator<K> persistedKeysIterator =
persistedKeys.iterator();
+ Iterator<K> pendingAddsNowIterator;
+ boolean hasNext;
+ K nextKey;
+
+ @Override
+ public boolean isReady() {
+ return persistedKeysIterator.isReady();
+ }
+
+ @Override
+ public void prefetch() {
+ if (!isReady()) {
+ persistedKeysIterator.prefetch();
+ }
+ }
+
+ @Override
+ public boolean hasNext() {
+ if (hasNext) {
+ return true;
+ }
- Set<K> keys = Sets.newHashSet(getPersistedKeys());
- keys.removeAll(pendingRemoves);
- keys.addAll(pendingAdds.keySet());
- return Collections.unmodifiableCollection(keys);
+ while (persistedKeysIterator.hasNext()) {
+ nextKey = persistedKeysIterator.next();
+ Object nextKeyStructuralValue =
mapKeyCoder.structuralValue(nextKey);
+ if (!pendingRemovesNow.contains(nextKeyStructuralValue)) {
+ // Remove all keys that we will visit when passing over the
persistedKeysIterator
+ // so we do not revisit them when passing over the
pendingAddsNowIterator
+ if (pendingAddsNow.containsKey(nextKeyStructuralValue)) {
+ pendingAddsNow.remove(nextKeyStructuralValue);
+ }
+ hasNext = true;
+ return true;
+ }
+ }
+
+ if (pendingAddsNowIterator == null) {
+ pendingAddsNowIterator = pendingAddsNow.values().iterator();
+ }
+ while (pendingAddsNowIterator.hasNext()) {
+ nextKey = pendingAddsNowIterator.next();
+ hasNext = true;
+ return true;
+ }
+
+ return false;
+ }
+
+ @Override
+ public K next() {
+ if (!hasNext()) {
+ throw new NoSuchElementException();
+ }
+ hasNext = false;
+ return nextKey;
+ }
+ };
+ }
+ };
}
/*
@@ -175,8 +246,9 @@ public class MultimapUserState<K, V> {
!isClosed,
"Multimap user state is no longer usable because it is closed for %s",
keysStateRequest.getStateKey());
- pendingAdds.putIfAbsent(key, new ArrayList<>());
- pendingAdds.get(key).add(value);
+ Object keyStructuralValue = mapKeyCoder.structuralValue(key);
+ pendingAdds.putIfAbsent(keyStructuralValue, KV.of(key, new ArrayList<>()));
+ pendingAdds.get(keyStructuralValue).getValue().add(value);
}
/*
@@ -187,9 +259,10 @@ public class MultimapUserState<K, V> {
!isClosed,
"Multimap user state is no longer usable because it is closed for %s",
keysStateRequest.getStateKey());
- pendingAdds.remove(key);
+ Object keyStructuralValue = mapKeyCoder.structuralValue(key);
+ pendingAdds.remove(keyStructuralValue);
if (!isCleared) {
- pendingRemoves.add(key);
+ pendingRemoves.put(keyStructuralValue, key);
}
}
@@ -215,7 +288,7 @@ public class MultimapUserState<K, V> {
.handle(keysStateRequest.toBuilder().setClear(StateClearRequest.getDefaultInstance()))
.get();
} else if (!pendingRemoves.isEmpty()) {
- for (K key : pendingRemoves) {
+ for (K key : pendingRemoves.values()) {
beamFnStateClient
.handle(
createUserStateRequest(key)
@@ -227,7 +300,7 @@ public class MultimapUserState<K, V> {
// Persist pending key-values
if (!pendingAdds.isEmpty()) {
- for (Map.Entry<K, List<V>> entry : pendingAdds.entrySet()) {
+ for (KV<K, List<V>> entry : pendingAdds.values()) {
beamFnStateClient
.handle(
createUserStateRequest(entry.getKey())
@@ -265,30 +338,22 @@ public class MultimapUserState<K, V> {
}
}
- private Iterable<V> getPersistedValues(K key) {
- if (negativeCache.contains(key)) {
- return Collections.emptyList();
- }
-
- if (persistedValues.get(key).isEmpty()) {
- Iterable<V> values =
+ private PrefetchableIterable<V> getPersistedValues(K key) {
+ if (!persistedValues.containsKey(key)) {
+ PrefetchableIterable<V> values =
StateFetchingIterators.readAllAndDecodeStartingFrom(
beamFnStateClient, createUserStateRequest(key), valueCoder);
- if (Iterables.isEmpty(values)) {
- negativeCache.add(key);
- }
- persistedValues.putAll(key, values);
+ persistedValues.put(key, values);
}
- return Iterables.unmodifiableIterable(persistedValues.get(key));
+ return persistedValues.get(key);
}
- private Iterable<K> getPersistedKeys() {
+ private PrefetchableIterable<K> getPersistedKeys() {
checkState(!isCleared);
if (persistedKeys == null) {
- Iterable<K> keys =
+ persistedKeys =
StateFetchingIterators.readAllAndDecodeStartingFrom(
beamFnStateClient, keysStateRequest, mapKeyCoder);
- persistedKeys = Iterables.unmodifiableIterable(keys);
}
return persistedKeys;
}
diff --git
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
index 23f1be4..72509f9 100644
---
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
+++
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
@@ -26,12 +26,17 @@ import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import java.io.IOException;
+import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
+import java.util.Iterator;
import java.util.Map;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.NullableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
import
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
@@ -39,9 +44,19 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
+/**
+ * Tests for {@link MultimapUserState}.
+ *
+ * <p>It is important to use a key type where its coder is not {@link
Coder#consistentWithEquals()}
+ * to ensure that comparisons are performed using structural values instead of
object equality
+ * during testing.
+ */
@RunWith(JUnit4.class)
public class MultimapUserStateTest {
-
+ private static final byte[] A0 = "A0".getBytes(StandardCharsets.UTF_8);
+ private static final byte[] A1 = "A1".getBytes(StandardCharsets.UTF_8);
+ private static final byte[] A2 = "A2".getBytes(StandardCharsets.UTF_8);
+ private static final byte[] A3 = "A3".getBytes(StandardCharsets.UTF_8);
private final String pTransformId = "pTransformId";
private final String stateId = "stateId";
private final String encodedKey = "encodedKey";
@@ -50,7 +65,7 @@ public class MultimapUserStateTest {
@Test
public void testNoPersistedValues() throws Exception {
FakeBeamFnStateClient fakeClient = new
FakeBeamFnStateClient(Collections.emptyMap());
- MultimapUserState<String, String> userState =
+ MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
fakeClient,
"instructionId",
@@ -58,7 +73,7 @@ public class MultimapUserStateTest {
stateId,
encode(encodedWindow),
encode(encodedKey),
- StringUtf8Coder.of(),
+ ByteArrayCoder.of(),
StringUtf8Coder.of());
assertThat(userState.keys(), is(emptyIterable()));
}
@@ -69,10 +84,10 @@ public class MultimapUserStateTest {
new FakeBeamFnStateClient(
ImmutableMap.of(
createMultimapKeyStateKey(),
- encode("A1"),
- createMultimapValueStateKey("A1"),
+ encode(A1),
+ createMultimapValueStateKey(A1),
encode("V1", "V2")));
- MultimapUserState<String, String> userState =
+ MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
fakeClient,
"instructionId",
@@ -80,17 +95,17 @@ public class MultimapUserStateTest {
stateId,
encode(encodedWindow),
encode(encodedKey),
- StringUtf8Coder.of(),
+ ByteArrayCoder.of(),
StringUtf8Coder.of());
- Iterable<String> initValues = userState.get("A1");
- userState.put("A1", "V3");
+ Iterable<String> initValues = userState.get(A1);
+ userState.put(A1, "V3");
assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues,
String.class));
assertArrayEquals(
- new String[] {"V1", "V2", "V3"},
Iterables.toArray(userState.get("A1"), String.class));
- assertArrayEquals(new String[] {}, Iterables.toArray(userState.get("A2"),
String.class));
+ new String[] {"V1", "V2", "V3"}, Iterables.toArray(userState.get(A1),
String.class));
+ assertArrayEquals(new String[] {}, Iterables.toArray(userState.get(A2),
String.class));
userState.asyncClose();
- assertThrows(IllegalStateException.class, () -> userState.get("A1"));
+ assertThrows(IllegalStateException.class, () -> userState.get(A1));
}
@Test
@@ -99,10 +114,10 @@ public class MultimapUserStateTest {
new FakeBeamFnStateClient(
ImmutableMap.of(
createMultimapKeyStateKey(),
- encode("A1"),
- createMultimapValueStateKey("A1"),
+ encode(A1),
+ createMultimapValueStateKey(A1),
encode("V1", "V2")));
- MultimapUserState<String, String> userState =
+ MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
fakeClient,
"instructionId",
@@ -110,19 +125,19 @@ public class MultimapUserStateTest {
stateId,
encode(encodedWindow),
encode(encodedKey),
- StringUtf8Coder.of(),
+ ByteArrayCoder.of(),
StringUtf8Coder.of());
- Iterable<String> initValues = userState.get("A1");
+ Iterable<String> initValues = userState.get(A1);
userState.clear();
assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues,
String.class));
- assertThat(userState.get("A1"), is(emptyIterable()));
+ assertThat(userState.get(A1), is(emptyIterable()));
assertThat(userState.keys(), is(emptyIterable()));
- userState.put("A1", "V1");
+ userState.put(A1, "V1");
userState.clear();
assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues,
String.class));
- assertThat(userState.get("A1"), is(emptyIterable()));
+ assertThat(userState.get(A1), is(emptyIterable()));
assertThat(userState.keys(), is(emptyIterable()));
userState.asyncClose();
@@ -135,10 +150,10 @@ public class MultimapUserStateTest {
new FakeBeamFnStateClient(
ImmutableMap.of(
createMultimapKeyStateKey(),
- encode("A1"),
- createMultimapValueStateKey("A1"),
+ encode(A1),
+ createMultimapValueStateKey(A1),
encode("V1", "V2")));
- MultimapUserState<String, String> userState =
+ MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
fakeClient,
"instructionId",
@@ -146,20 +161,19 @@ public class MultimapUserStateTest {
stateId,
encode(encodedWindow),
encode(encodedKey),
- StringUtf8Coder.of(),
+ ByteArrayCoder.of(),
StringUtf8Coder.of());
- userState.put("A2", "V1");
- Iterable<String> initKeys = userState.keys();
- userState.put("A3", "V1");
- userState.put("A1", "V3");
- assertArrayEquals(new String[] {"A1", "A2"}, Iterables.toArray(initKeys,
String.class));
- assertArrayEquals(
- new String[] {"A1", "A2", "A3"}, Iterables.toArray(userState.keys(),
String.class));
+ userState.put(A2, "V1");
+ Iterable<byte[]> initKeys = userState.keys();
+ userState.put(A3, "V1");
+ userState.put(A1, "V3");
+ assertArrayEquals(new byte[][] {A1, A2}, Iterables.toArray(initKeys,
byte[].class));
+ assertArrayEquals(new byte[][] {A1, A2, A3},
Iterables.toArray(userState.keys(), byte[].class));
userState.clear();
- assertArrayEquals(new String[] {"A1", "A2"}, Iterables.toArray(initKeys,
String.class));
- assertArrayEquals(new String[] {}, Iterables.toArray(userState.keys(),
String.class));
+ assertArrayEquals(new byte[][] {A1, A2}, Iterables.toArray(initKeys,
byte[].class));
+ assertArrayEquals(new byte[][] {}, Iterables.toArray(userState.keys(),
byte[].class));
userState.asyncClose();
assertThrows(IllegalStateException.class, () -> userState.keys());
}
@@ -170,10 +184,10 @@ public class MultimapUserStateTest {
new FakeBeamFnStateClient(
ImmutableMap.of(
createMultimapKeyStateKey(),
- encode("A1"),
- createMultimapValueStateKey("A1"),
+ encode(A1),
+ createMultimapValueStateKey(A1),
encode("V1", "V2")));
- MultimapUserState<String, String> userState =
+ MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
fakeClient,
"instructionId",
@@ -181,16 +195,16 @@ public class MultimapUserStateTest {
stateId,
encode(encodedWindow),
encode(encodedKey),
- StringUtf8Coder.of(),
+ ByteArrayCoder.of(),
StringUtf8Coder.of());
- Iterable<String> initValues = userState.get("A1");
- userState.put("A1", "V3");
+ Iterable<String> initValues = userState.get(A1);
+ userState.put(A1, "V3");
assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues,
String.class));
assertArrayEquals(
- new String[] {"V1", "V2", "V3"},
Iterables.toArray(userState.get("A1"), String.class));
+ new String[] {"V1", "V2", "V3"}, Iterables.toArray(userState.get(A1),
String.class));
userState.asyncClose();
- assertThrows(IllegalStateException.class, () -> userState.put("A1", "V2"));
+ assertThrows(IllegalStateException.class, () -> userState.put(A1, "V2"));
}
@Test
@@ -199,10 +213,10 @@ public class MultimapUserStateTest {
new FakeBeamFnStateClient(
ImmutableMap.of(
createMultimapKeyStateKey(),
- encode("A0"),
- createMultimapValueStateKey("A0"),
+ encode(A0),
+ createMultimapValueStateKey(A0),
encode("V1")));
- MultimapUserState<String, String> userState =
+ MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
fakeClient,
"instructionId",
@@ -210,14 +224,14 @@ public class MultimapUserStateTest {
stateId,
encode(encodedWindow),
encode(encodedKey),
- StringUtf8Coder.of(),
+ ByteArrayCoder.of(),
StringUtf8Coder.of());
- userState.remove("A0");
- userState.put("A0", "V2");
- assertArrayEquals(new String[] {"V2"},
Iterables.toArray(userState.get("A0"), String.class));
+ userState.remove(A0);
+ userState.put(A0, "V2");
+ assertArrayEquals(new String[] {"V2"},
Iterables.toArray(userState.get(A0), String.class));
userState.asyncClose();
Map<StateKey, ByteString> data = fakeClient.getData();
- assertEquals(encode("V2"), data.get(createMultimapValueStateKey("A0")));
+ assertEquals(encode("V2"), data.get(createMultimapValueStateKey(A0)));
}
@Test
@@ -226,10 +240,10 @@ public class MultimapUserStateTest {
new FakeBeamFnStateClient(
ImmutableMap.of(
createMultimapKeyStateKey(),
- encode("A0"),
- createMultimapValueStateKey("A0"),
+ encode(A0),
+ createMultimapValueStateKey(A0),
encode("V1")));
- MultimapUserState<String, String> userState =
+ MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
fakeClient,
"instructionId",
@@ -237,11 +251,11 @@ public class MultimapUserStateTest {
stateId,
encode(encodedWindow),
encode(encodedKey),
- StringUtf8Coder.of(),
+ ByteArrayCoder.of(),
StringUtf8Coder.of());
userState.clear();
- userState.put("A0", "V2");
- assertArrayEquals(new String[] {"V2"},
Iterables.toArray(userState.get("A0"), String.class));
+ userState.put(A0, "V2");
+ assertArrayEquals(new String[] {"V2"},
Iterables.toArray(userState.get(A0), String.class));
}
@Test
@@ -250,10 +264,10 @@ public class MultimapUserStateTest {
new FakeBeamFnStateClient(
ImmutableMap.of(
createMultimapKeyStateKey(),
- encode("A0"),
- createMultimapValueStateKey("A0"),
+ encode(A0),
+ createMultimapValueStateKey(A0),
encode("V1")));
- MultimapUserState<String, String> userState =
+ MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
fakeClient,
"instructionId",
@@ -261,9 +275,9 @@ public class MultimapUserStateTest {
stateId,
encode(encodedWindow),
encode(encodedKey),
- StringUtf8Coder.of(),
+ ByteArrayCoder.of(),
StringUtf8Coder.of());
- userState.remove("A0");
+ userState.remove(A0);
userState.clear();
userState.asyncClose();
// Clear takes precedence over specific key remove
@@ -273,7 +287,7 @@ public class MultimapUserStateTest {
@Test
public void testPutBeforeClear() throws Exception {
FakeBeamFnStateClient fakeClient = new
FakeBeamFnStateClient(Collections.emptyMap());
- MultimapUserState<String, String> userState =
+ MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
fakeClient,
"instructionId",
@@ -281,11 +295,11 @@ public class MultimapUserStateTest {
stateId,
encode(encodedWindow),
encode(encodedKey),
- StringUtf8Coder.of(),
+ ByteArrayCoder.of(),
StringUtf8Coder.of());
- userState.put("A0", "V0");
- userState.put("A1", "V1");
- Iterable<String> values = userState.get("A1"); // fakeClient call = 1
+ userState.put(A0, "V0");
+ userState.put(A1, "V1");
+ Iterable<String> values = userState.get(A1); // fakeClient call = 1
userState.clear(); // fakeClient call = 2
assertArrayEquals(new String[] {"V1"}, Iterables.toArray(values,
String.class));
userState.asyncClose();
@@ -296,7 +310,7 @@ public class MultimapUserStateTest {
@Test
public void testPutBeforeRemove() throws Exception {
FakeBeamFnStateClient fakeClient = new
FakeBeamFnStateClient(Collections.emptyMap());
- MultimapUserState<String, String> userState =
+ MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
fakeClient,
"instructionId",
@@ -304,18 +318,18 @@ public class MultimapUserStateTest {
stateId,
encode(encodedWindow),
encode(encodedKey),
- StringUtf8Coder.of(),
+ ByteArrayCoder.of(),
StringUtf8Coder.of());
- userState.put("A0", "V0");
- userState.put("A1", "V1");
- Iterable<String> values = userState.get("A1"); // fakeClient call = 1
- userState.remove("A0"); // fakeClient call = 2
- userState.remove("A1"); // fakeClient call = 3
+ userState.put(A0, "V0");
+ userState.put(A1, "V1");
+ Iterable<String> values = userState.get(A1); // fakeClient call = 1
+ userState.remove(A0); // fakeClient call = 2
+ userState.remove(A1); // fakeClient call = 3
assertArrayEquals(new String[] {"V1"}, Iterables.toArray(values,
String.class));
userState.asyncClose();
assertThat(fakeClient.getCallCount(), is(3));
- assertNull(fakeClient.getData().get(createMultimapValueStateKey("A0")));
- assertNull(fakeClient.getData().get(createMultimapValueStateKey("A1")));
+ assertNull(fakeClient.getData().get(createMultimapValueStateKey(A0)));
+ assertNull(fakeClient.getData().get(createMultimapValueStateKey(A1)));
}
@Test
@@ -324,10 +338,10 @@ public class MultimapUserStateTest {
new FakeBeamFnStateClient(
ImmutableMap.of(
createMultimapKeyStateKey(),
- encode("A1"),
- createMultimapValueStateKey("A1"),
+ encode(A1),
+ createMultimapValueStateKey(A1),
encode("V1", "V2")));
- MultimapUserState<String, String> userState =
+ MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
fakeClient,
"instructionId",
@@ -335,17 +349,17 @@ public class MultimapUserStateTest {
stateId,
encode(encodedWindow),
encode(encodedKey),
- StringUtf8Coder.of(),
+ ByteArrayCoder.of(),
StringUtf8Coder.of());
- Iterable<String> initValues = userState.get("A1");
- userState.put("A1", "V3");
+ Iterable<String> initValues = userState.get(A1);
+ userState.put(A1, "V3");
- userState.remove("A1");
+ userState.remove(A1);
assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues,
String.class));
assertThat(userState.keys(), is(emptyIterable()));
userState.asyncClose();
- assertThrows(IllegalStateException.class, () -> userState.remove("A1"));
+ assertThrows(IllegalStateException.class, () -> userState.remove(A1));
}
@Test
@@ -354,10 +368,10 @@ public class MultimapUserStateTest {
new FakeBeamFnStateClient(
ImmutableMap.of(
createMultimapKeyStateKey(),
- encode("A1"),
- createMultimapValueStateKey("A1"),
+ encode(A1),
+ createMultimapValueStateKey(A1),
encode("V1", "V2")));
- MultimapUserState<String, String> userState =
+ MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
fakeClient,
"instructionId",
@@ -365,11 +379,12 @@ public class MultimapUserStateTest {
stateId,
encode(encodedWindow),
encode(encodedKey),
- StringUtf8Coder.of(),
+ ByteArrayCoder.of(),
StringUtf8Coder.of());
- Iterable<String> keys = userState.keys();
- assertThrows(
- UnsupportedOperationException.class, () -> Iterables.removeAll(keys,
Arrays.asList("A1")));
+ Iterable<byte[]> keys = userState.keys();
+ Iterator<byte[]> keysIterator = keys.iterator();
+ keysIterator.next();
+ assertThrows(UnsupportedOperationException.class, () ->
keysIterator.remove());
}
@Test
@@ -378,10 +393,10 @@ public class MultimapUserStateTest {
new FakeBeamFnStateClient(
ImmutableMap.of(
createMultimapKeyStateKey(),
- encode("A1"),
- createMultimapValueStateKey("A1"),
+ encode(A1),
+ createMultimapValueStateKey(A1),
encode("V1", "V2")));
- MultimapUserState<String, String> userState =
+ MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
fakeClient,
"instructionId",
@@ -389,9 +404,9 @@ public class MultimapUserStateTest {
stateId,
encode(encodedWindow),
encode(encodedKey),
- StringUtf8Coder.of(),
+ ByteArrayCoder.of(),
StringUtf8Coder.of());
- Iterable<String> values = userState.get("A1");
+ Iterable<String> values = userState.get(A1);
assertThrows(
UnsupportedOperationException.class,
() -> Iterables.removeAll(values, Arrays.asList("V1")));
@@ -403,10 +418,10 @@ public class MultimapUserStateTest {
new FakeBeamFnStateClient(
ImmutableMap.of(
createMultimapKeyStateKey(),
- encode("A1"),
- createMultimapValueStateKey("A1"),
+ encode(A1),
+ createMultimapValueStateKey(A1),
encode("V1", "V2")));
- MultimapUserState<String, String> userState =
+ MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
fakeClient,
"instructionId",
@@ -414,7 +429,7 @@ public class MultimapUserStateTest {
stateId,
encode(encodedWindow),
encode(encodedKey),
- StringUtf8Coder.of(),
+ ByteArrayCoder.of(),
StringUtf8Coder.of());
userState.clear();
userState.asyncClose();
@@ -429,10 +444,10 @@ public class MultimapUserStateTest {
new FakeBeamFnStateClient(
ImmutableMap.of(
createMultimapKeyStateKey(),
- encode("A1"),
- createMultimapValueStateKey("A1"),
+ encode(A1),
+ createMultimapValueStateKey(A1),
encode("V1", "V2")));
- MultimapUserState<String, String> userState =
+ MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
fakeClient,
"instructionId",
@@ -440,7 +455,7 @@ public class MultimapUserStateTest {
stateId,
encode(encodedWindow),
encode(encodedKey),
- StringUtf8Coder.of(),
+ ByteArrayCoder.of(),
StringUtf8Coder.of());
userState.asyncClose();
assertThrows(IllegalStateException.class, () -> userState.keys());
@@ -453,12 +468,12 @@ public class MultimapUserStateTest {
new FakeBeamFnStateClient(
ImmutableMap.of(
createMultimapKeyStateKey(),
- encode("A0", "A1"),
- createMultimapValueStateKey("A0"),
+ encode(A0, A1),
+ createMultimapValueStateKey(A0),
encode("V1"),
- createMultimapValueStateKey("A1"),
+ createMultimapValueStateKey(A1),
encode("V1", "V2")));
- MultimapUserState<String, String> userState =
+ MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
fakeClient,
"instructionId",
@@ -466,18 +481,18 @@ public class MultimapUserStateTest {
stateId,
encode(encodedWindow),
encode(encodedKey),
- StringUtf8Coder.of(),
+ ByteArrayCoder.of(),
StringUtf8Coder.of());
- userState.remove("A0");
- userState.put("A1", "V3");
- userState.put("A2", "V1");
- userState.put("A3", "V1");
- userState.remove("A3");
+ userState.remove(A0);
+ userState.put(A1, "V3");
+ userState.put(A2, "V1");
+ userState.put(A3, "V1");
+ userState.remove(A3);
userState.asyncClose();
Map<StateKey, ByteString> data = fakeClient.getData();
- assertNull(data.get(createMultimapValueStateKey("A0")));
- assertEquals(encode("V1", "V2", "V3"), data.get(createMultimapValueStateKey("A1")));
- assertEquals(encode("V1"), data.get(createMultimapValueStateKey("A2")));
+ assertNull(data.get(createMultimapValueStateKey(A0)));
+ assertEquals(encode("V1", "V2", "V3"), data.get(createMultimapValueStateKey(A1)));
+ assertEquals(encode("V1"), data.get(createMultimapValueStateKey(A2)));
}
@Test
@@ -486,10 +501,10 @@ public class MultimapUserStateTest {
new FakeBeamFnStateClient(
ImmutableMap.of(
createMultimapKeyStateKey(),
- encode("A1"),
- createMultimapValueStateKey("A1"),
+ encode(A1),
+ createMultimapValueStateKey(A1),
encode("V1", "V2")));
- MultimapUserState<String, String> userState =
+ MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
fakeClient,
"instructionId",
@@ -497,7 +512,7 @@ public class MultimapUserStateTest {
stateId,
encode(encodedWindow),
encode(encodedKey),
- NullableCoder.of(StringUtf8Coder.of()),
+ NullableCoder.of(ByteArrayCoder.of()),
NullableCoder.of(StringUtf8Coder.of()));
userState.put(null, null);
userState.put(null, null);
@@ -509,7 +524,7 @@ public class MultimapUserStateTest {
@Test
public void testNegativeCache() throws Exception {
FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(Collections.emptyMap());
- MultimapUserState<String, String> userState =
+ MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
fakeClient,
"instructionId",
@@ -517,13 +532,191 @@ public class MultimapUserStateTest {
stateId,
encode(encodedWindow),
encode(encodedKey),
- StringUtf8Coder.of(),
+ ByteArrayCoder.of(),
StringUtf8Coder.of());
- userState.get("A1");
- userState.get("A1");
+ assertArrayEquals(new String[] {}, Iterables.toArray(userState.get(A1), String.class));
+ assertArrayEquals(new String[] {}, Iterables.toArray(userState.get(A1), String.class));
assertThat(fakeClient.getCallCount(), is(1));
}
+ @Test
+ public void testGetValuesPrefetch() throws Exception {
+ FakeBeamFnStateClient fakeClient =
+ new FakeBeamFnStateClient(
+ ImmutableMap.of(
+ createMultimapKeyStateKey(),
+ encode(A1),
+ createMultimapValueStateKey(A1),
+ encode("V1", "V2")));
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ fakeClient,
+ "instructionId",
+ pTransformId,
+ stateId,
+ encode(encodedWindow),
+ encode(encodedKey),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+
+ PrefetchableIterable<String> values = userState.get(A1);
+ assertEquals(0, fakeClient.getCallCount());
+ values.iterator().prefetch();
+ assertEquals(1, fakeClient.getCallCount());
+ assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(values, String.class));
+ assertEquals(1, fakeClient.getCallCount());
+ }
+
+ @Test
+ public void testGetKeysPrefetch() throws Exception {
+ FakeBeamFnStateClient fakeClient =
+ new FakeBeamFnStateClient(
+ ImmutableMap.of(
+ createMultimapKeyStateKey(),
+ encode(A1),
+ createMultimapValueStateKey(A1),
+ encode("V1", "V2")));
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ fakeClient,
+ "instructionId",
+ pTransformId,
+ stateId,
+ encode(encodedWindow),
+ encode(encodedKey),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+
+ PrefetchableIterable<byte[]> keys = userState.keys();
+ assertEquals(0, fakeClient.getCallCount());
+ keys.iterator().prefetch();
+ assertEquals(1, fakeClient.getCallCount());
+ assertArrayEquals(new byte[][] {A1}, Iterables.toArray(keys, byte[].class));
+ assertEquals(1, fakeClient.getCallCount());
+ }
+
+ @Test
+ public void testPutKeysPrefetch() throws Exception {
+ FakeBeamFnStateClient fakeClient =
+ new FakeBeamFnStateClient(
+ ImmutableMap.of(
+ createMultimapKeyStateKey(),
+ encode(A1),
+ createMultimapValueStateKey(A1),
+ encode("V1", "V2")));
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ fakeClient,
+ "instructionId",
+ pTransformId,
+ stateId,
+ encode(encodedWindow),
+ encode(encodedKey),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+
+ userState.put(A2, "V3");
+ PrefetchableIterable<byte[]> keys = userState.keys();
+ assertEquals(0, fakeClient.getCallCount());
+ keys.iterator().prefetch();
+ assertEquals(1, fakeClient.getCallCount());
+ assertArrayEquals(new byte[][] {A1, A2}, Iterables.toArray(keys, byte[].class));
+ assertEquals(1, fakeClient.getCallCount());
+ }
+
+ @Test
+ public void testRemoveKeysPrefetch() throws Exception {
+ FakeBeamFnStateClient fakeClient =
+ new FakeBeamFnStateClient(
+ ImmutableMap.of(
+ createMultimapKeyStateKey(),
+ encode(A1),
+ createMultimapValueStateKey(A1),
+ encode("V1", "V2")));
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ fakeClient,
+ "instructionId",
+ pTransformId,
+ stateId,
+ encode(encodedWindow),
+ encode(encodedKey),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+
+ userState.remove(A1);
+ userState.put(A1, "V3");
+ PrefetchableIterable<String> values = userState.get(A1);
+ assertEquals(0, fakeClient.getCallCount());
+ values.iterator().prefetch();
+ // Removed keys don't require accessing the underlying persisted state
+ assertEquals(0, fakeClient.getCallCount());
+ assertArrayEquals(new String[] {"V3"}, Iterables.toArray(values, String.class));
+ // Removed keys don't require accessing the underlying persisted state
+ assertEquals(0, fakeClient.getCallCount());
+ }
+
+ @Test
+ public void testClearPrefetch() throws Exception {
+ FakeBeamFnStateClient fakeClient =
+ new FakeBeamFnStateClient(
+ ImmutableMap.of(
+ createMultimapKeyStateKey(),
+ encode(A1),
+ createMultimapValueStateKey(A1),
+ encode("V1", "V2")));
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ fakeClient,
+ "instructionId",
+ pTransformId,
+ stateId,
+ encode(encodedWindow),
+ encode(encodedKey),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+
+ userState.clear();
+ userState.put(A2, "V3");
+ PrefetchableIterable<byte[]> keys = userState.keys();
+ assertEquals(0, fakeClient.getCallCount());
+ keys.iterator().prefetch();
+ // Cleared keys don't require accessing the underlying persisted state
+ assertEquals(0, fakeClient.getCallCount());
+ assertArrayEquals(new byte[][] {A2}, Iterables.toArray(keys, byte[].class));
+ // Cleared keys don't require accessing the underlying persisted state
+ assertEquals(0, fakeClient.getCallCount());
+ }
+
+ @Test
+ public void testAppendValuesPrefetch() throws Exception {
+ FakeBeamFnStateClient fakeClient =
+ new FakeBeamFnStateClient(
+ ImmutableMap.of(
+ createMultimapKeyStateKey(),
+ encode(A1),
+ createMultimapValueStateKey(A1),
+ encode("V1", "V2")));
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ fakeClient,
+ "instructionId",
+ pTransformId,
+ stateId,
+ encode(encodedWindow),
+ encode(encodedKey),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+
+ userState.put(A1, "V3");
+ PrefetchableIterable<String> values = userState.get(A1);
+ assertEquals(0, fakeClient.getCallCount());
+ values.iterator().prefetch();
+ assertEquals(1, fakeClient.getCallCount());
+ assertArrayEquals(new String[] {"V1", "V2", "V3"}, Iterables.toArray(values, String.class));
+ assertEquals(1, fakeClient.getCallCount());
+ }
+
private StateKey createMultimapKeyStateKey() throws IOException {
return StateKey.newBuilder()
.setMultimapKeysUserState(
@@ -535,7 +728,7 @@ public class MultimapUserStateTest {
.build();
}
- private StateKey createMultimapValueStateKey(String key) throws IOException {
+ private StateKey createMultimapValueStateKey(byte[] key) throws IOException {
return StateKey.newBuilder()
.setMultimapUserState(
StateKey.MultimapUserState.newBuilder()
@@ -554,4 +747,12 @@ public class MultimapUserStateTest {
}
return out.toByteString();
}
+
+ private ByteString encode(byte[]... values) throws IOException {
+ ByteString.Output out = ByteString.newOutput();
+ for (byte[] value : values) {
+ ByteArrayCoder.of().encode(value, out);
+ }
+ return out.toByteString();
+ }
}