This is an automated email from the ASF dual-hosted git repository.
lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new ea65a05 [BEAM-13015] Migrate all user state and side input implementations
to support caching. (#16263)
ea65a05 is described below
commit ea65a054f2fcb6349478d19609a773f66bbfa20e
Author: Lukasz Cwik <[email protected]>
AuthorDate: Tue Jan 4 15:39:36 2022 -0800
[BEAM-13015] Migrate all user state and side input implementations to support
caching. (#16263)
This change also ensures that prefetch can be invoked on the iterable itself:
the prefetched iterator is retained and returned by the next call to iterator(),
so the prefetch is no longer lost when a temporary iterator is discarded.
See
https://s.apache.org/beam-fn-state-api-and-bundle-processing#heading=h.tms0ncgbzz6f
---
.../beam/sdk/fn/stream/PrefetchableIterable.java | 3 +
.../beam/sdk/fn/stream/PrefetchableIterables.java | 53 ++-
.../sdk/fn/stream/PrefetchableIterablesTest.java | 20 +
.../java/org/apache/beam/fn/harness/Caches.java | 1 +
.../apache/beam/fn/harness/state/BagUserState.java | 57 ++-
.../beam/fn/harness/state/FnApiStateAccessor.java | 135 +++---
.../beam/fn/harness/state/IterableSideInput.java | 47 +-
.../state/LazyCachingIteratorToIterable.java | 6 +-
.../beam/fn/harness/state/MultimapSideInput.java | 69 +--
.../beam/fn/harness/state/MultimapUserState.java | 184 +++++---
.../fn/harness/state/StateFetchingIterators.java | 9 +-
.../beam/fn/harness/state/BagUserStateTest.java | 170 ++++++-
.../fn/harness/state/IterableSideInputTest.java | 96 ++++
.../fn/harness/state/MultimapSideInputTest.java | 71 ++-
.../fn/harness/state/MultimapUserStateTest.java | 524 +++++++++++++++++----
15 files changed, 1086 insertions(+), 359 deletions(-)
diff --git
a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterable.java
b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterable.java
index 5700f6c..9a2fde8 100644
---
a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterable.java
+++
b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterable.java
@@ -20,6 +20,9 @@ package org.apache.beam.sdk.fn.stream;
/** An {@link Iterable} that returns {@link PrefetchableIterator}s. */
public interface PrefetchableIterable<T> extends Iterable<T> {
+ /** Ensures that the next iterator returned has been prefetched. */
+ void prefetch();
+
@Override
PrefetchableIterator<T> iterator();
}
diff --git
a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java
b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java
index e55ece3..d8696f0 100644
---
a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java
+++
b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java
@@ -18,6 +18,7 @@
package org.apache.beam.sdk.fn.stream;
import java.util.NoSuchElementException;
+import javax.annotation.Nullable;
import
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.FluentIterable;
/**
@@ -26,8 +27,42 @@ import
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.FluentIt
*/
public class PrefetchableIterables {
+ /**
+ * A default implementation that caches an iterator to be returned when
{@link #prefetch} is
+ * invoked.
+ */
+ public abstract static class Default<T> implements PrefetchableIterable<T> {
+ @Nullable private PrefetchableIterator<T> iterator = null;
+
+ @Override
+ public final void prefetch() {
+ if (iterator != null) {
+ return;
+ }
+ iterator = createIterator();
+ iterator.prefetch();
+ }
+
+ @Override
+ public final PrefetchableIterator<T> iterator() {
+ if (iterator == null) {
+ return createIterator();
+ }
+ PrefetchableIterator<T> rval = iterator;
+ iterator = null;
+ return rval;
+ }
+
+ protected abstract PrefetchableIterator<T> createIterator();
+ }
+
private static final PrefetchableIterable<Object> EMPTY_ITERABLE =
- PrefetchableIterators::emptyIterator;
+ new Default<Object>() {
+ @Override
+ protected PrefetchableIterator<Object> createIterator() {
+ return PrefetchableIterators.emptyIterator();
+ }
+ };
/** Returns an empty {@link PrefetchableIterable}. */
public static <T> PrefetchableIterable<T> emptyIterable() {
@@ -44,9 +79,9 @@ public class PrefetchableIterables {
if (values.length == 0) {
return emptyIterable();
}
- return new PrefetchableIterable<T>() {
+ return new Default<T>() {
@Override
- public PrefetchableIterator<T> iterator() {
+ public PrefetchableIterator<T> createIterator() {
return PrefetchableIterators.fromArray(values);
}
};
@@ -63,9 +98,9 @@ public class PrefetchableIterables {
if (iterable instanceof PrefetchableIterable) {
return (PrefetchableIterable<T>) iterable;
}
- return new PrefetchableIterable<T>() {
+ return new Default<T>() {
@Override
- public PrefetchableIterator<T> iterator() {
+ public PrefetchableIterator<T> createIterator() {
return PrefetchableIterators.maybePrefetchable(iterable.iterator());
}
};
@@ -87,10 +122,10 @@ public class PrefetchableIterables {
} else if (iterables.length == 1) {
return maybePrefetchable(iterables[0]);
}
- return new PrefetchableIterable<T>() {
+ return new Default<T>() {
@SuppressWarnings("methodref.receiver.invalid")
@Override
- public PrefetchableIterator<T> iterator() {
+ public PrefetchableIterator<T> createIterator() {
return PrefetchableIterators.concatIterators(
FluentIterable.from(iterables).transform(Iterable::iterator).iterator());
}
@@ -100,9 +135,9 @@ public class PrefetchableIterables {
/** Limits the {@link PrefetchableIterable} to the specified number of
elements. */
public static <T> PrefetchableIterable<T> limit(Iterable<T> iterable, int
limit) {
PrefetchableIterable<T> prefetchableIterable = maybePrefetchable(iterable);
- return new PrefetchableIterable<T>() {
+ return new Default<T>() {
@Override
- public PrefetchableIterator<T> iterator() {
+ public PrefetchableIterator<T> createIterator() {
return new PrefetchableIterator<T>() {
PrefetchableIterator<T> delegate = prefetchableIterable.iterator();
int currentPosition;
diff --git
a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIterablesTest.java
b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIterablesTest.java
index e4d1b0d..3de5b48 100644
---
a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIterablesTest.java
+++
b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIterablesTest.java
@@ -17,8 +17,12 @@
*/
package org.apache.beam.sdk.fn.stream;
+import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertTrue;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterables.Default;
+import
org.apache.beam.sdk.fn.stream.PrefetchableIteratorsTest.ReadyAfterPrefetchUntilNext;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -57,6 +61,22 @@ public class PrefetchableIterablesTest {
}
@Test
+ public void testDefaultPrefetch() {
+ PrefetchableIterable<String> iterable =
+ new Default<String>() {
+ @Override
+ protected PrefetchableIterator<String> createIterator() {
+ return new ReadyAfterPrefetchUntilNext<>(
+ PrefetchableIterators.fromArray("A", "B", "C"));
+ }
+ };
+
+ assertFalse(iterable.iterator().isReady());
+ iterable.prefetch();
+ assertTrue(iterable.iterator().isReady());
+ }
+
+ @Test
public void testConcat() {
verifyIterable(PrefetchableIterables.concat());
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java
index e5671aa..e53483e 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java
@@ -165,6 +165,7 @@ public final class Caches {
cache == null ? "null" : cache.getClass()));
}
+ @VisibleForTesting
static Cache<Object, Object> forCache(
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.Cache<CompositeKey,
Object>
cache) {
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
index 76664ba..5952dcf 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
@@ -17,12 +17,17 @@
*/
package org.apache.beam.fn.harness.state;
+import static
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import static
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.List;
+import org.apache.beam.fn.harness.Cache;
+import
org.apache.beam.fn.harness.state.StateFetchingIterators.CachingStateIterable;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateAppendRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
@@ -39,45 +44,39 @@ import
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterable
*
* <p>TODO: Move to an async persist model where persistence is signalled
based upon cache memory
* pressure and its need to flush.
- *
- * <p>TODO: Support block level caching and prefetch.
*/
@SuppressWarnings({
"rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
"nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
})
public class BagUserState<T> {
+ private final Cache<?, ?> cache;
private final BeamFnStateClient beamFnStateClient;
private final StateRequest request;
private final Coder<T> valueCoder;
- private PrefetchableIterable<T> oldValues;
- private ArrayList<T> newValues;
+ private final CachingStateIterable<T> oldValues;
+ private List<T> newValues;
+ private boolean isCleared;
private boolean isClosed;
+ /** The cache must be namespaced for this state object accordingly. */
public BagUserState(
+ Cache<?, ?> cache,
BeamFnStateClient beamFnStateClient,
String instructionId,
- String ptransformId,
- String stateId,
- ByteString encodedWindow,
- ByteString encodedKey,
+ StateKey stateKey,
Coder<T> valueCoder) {
+ checkArgument(
+ stateKey.hasBagUserState(), "Expected BagUserState StateKey but
received %s.", stateKey);
+ this.cache = cache;
this.beamFnStateClient = beamFnStateClient;
this.valueCoder = valueCoder;
-
- StateRequest.Builder requestBuilder = StateRequest.newBuilder();
- requestBuilder
- .setInstructionId(instructionId)
- .getStateKeyBuilder()
- .getBagUserStateBuilder()
- .setTransformId(ptransformId)
- .setUserStateId(stateId)
- .setWindow(encodedWindow)
- .setKey(encodedKey);
- request = requestBuilder.build();
+ this.request =
+
StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build();
this.oldValues =
- StateFetchingIterators.readAllAndDecodeStartingFrom(beamFnStateClient,
request, valueCoder);
+ StateFetchingIterators.readAllAndDecodeStartingFrom(
+ this.cache, beamFnStateClient, request, valueCoder);
this.newValues = new ArrayList<>();
}
@@ -86,7 +85,7 @@ public class BagUserState<T> {
!isClosed,
"Bag user state is no longer usable because it is closed for %s",
request.getStateKey());
- if (oldValues == null) {
+ if (isCleared) {
// If we were cleared we should disregard old values.
return
PrefetchableIterables.limit(Collections.unmodifiableList(newValues),
newValues.size());
} else if (newValues.isEmpty()) {
@@ -110,7 +109,7 @@ public class BagUserState<T> {
!isClosed,
"Bag user state is no longer usable because it is closed for %s",
request.getStateKey());
- oldValues = null;
+ isCleared = true;
newValues = new ArrayList<>();
}
@@ -120,7 +119,11 @@ public class BagUserState<T> {
!isClosed,
"Bag user state is no longer usable because it is closed for %s",
request.getStateKey());
- if (oldValues == null) {
+ isClosed = true;
+ if (!isCleared && newValues.isEmpty()) {
+ return;
+ }
+ if (isCleared) {
beamFnStateClient.handle(
request.toBuilder().setClear(StateClearRequest.getDefaultInstance()));
}
@@ -135,6 +138,12 @@ public class BagUserState<T> {
.toBuilder()
.setAppend(StateAppendRequest.newBuilder().setData(out.toByteString())));
}
- isClosed = true;
+
+ // Modify the underlying cached state depending on the mutations performed
+ if (isCleared) {
+ oldValues.clearAndAppend(newValues);
+ } else {
+ oldValues.append(newValues);
+ }
}
}
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
index 5a14e2f..682b63d 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
@@ -29,6 +29,7 @@ import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;
import org.apache.beam.fn.harness.Cache;
+import org.apache.beam.fn.harness.Caches;
import
org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest.CacheToken;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.runners.core.SideInputReader;
@@ -169,7 +170,6 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
}
ByteString encodedWindow = encodedWindowOut.toByteString();
StateKey.Builder cacheKeyBuilder = StateKey.newBuilder();
- Object sideInputAccessor;
switch (sideInputSpec.getAccessPattern()) {
case Materializations.ITERABLE_MATERIALIZATION_URN:
@@ -178,14 +178,6 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
.setTransformId(ptransformId)
.setSideInputId(tag.getId())
.setWindow(encodedWindow);
- sideInputAccessor =
- new IterableSideInput<>(
- beamFnStateClient,
- processBundleInstructionId.get(),
- ptransformId,
- tag.getId(),
- encodedWindow,
- sideInputSpec.getCoder());
break;
case Materializations.MULTIMAP_MATERIALIZATION_URN:
@@ -194,21 +186,11 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
"Expected %s but received %s.",
KvCoder.class,
sideInputSpec.getCoder().getClass());
- KvCoder<?, ?> kvCoder = (KvCoder) sideInputSpec.getCoder();
cacheKeyBuilder
- .getMultimapSideInputBuilder()
+ .getMultimapKeysSideInputBuilder()
.setTransformId(ptransformId)
.setSideInputId(tag.getId())
.setWindow(encodedWindow);
- sideInputAccessor =
- new MultimapSideInput<>(
- beamFnStateClient,
- processBundleInstructionId.get(),
- ptransformId,
- tag.getId(),
- encodedWindow,
- kvCoder.getKeyCoder(),
- kvCoder.getValueCoder());
break;
default:
@@ -222,10 +204,44 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
sideInputSpec.getAccessPattern(),
tag));
}
-
return (T)
stateKeyObjectCache.computeIfAbsent(
- cacheKeyBuilder.build(), key ->
sideInputSpec.getViewFn().apply(sideInputAccessor));
+ cacheKeyBuilder.build(),
+ key -> {
+ switch (sideInputSpec.getAccessPattern()) {
+ case Materializations.ITERABLE_MATERIALIZATION_URN:
+ return sideInputSpec
+ .getViewFn()
+ .apply(
+ new IterableSideInput<>(
+ getCacheFor(key),
+ beamFnStateClient,
+ processBundleInstructionId.get(),
+ key,
+ sideInputSpec.getCoder()));
+ case Materializations.MULTIMAP_MATERIALIZATION_URN:
+ return sideInputSpec
+ .getViewFn()
+ .apply(
+ new MultimapSideInput<>(
+ getCacheFor(key),
+ beamFnStateClient,
+ processBundleInstructionId.get(),
+ key,
+ ((KvCoder)
sideInputSpec.getCoder()).getKeyCoder(),
+ ((KvCoder)
sideInputSpec.getCoder()).getValueCoder()));
+ default:
+ throw new IllegalStateException(
+ String.format(
+ "This SDK is only capable of dealing with %s
materializations "
+ + "but was asked to handle %s for
PCollectionView with tag %s.",
+ ImmutableList.of(
+ Materializations.ITERABLE_MATERIALIZATION_URN,
+ Materializations.MULTIMAP_MATERIALIZATION_URN),
+ sideInputSpec.getAccessPattern(),
+ tag));
+ }
+ });
}
@Override
@@ -247,7 +263,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public Object apply(StateKey key) {
return new ValueState<T>() {
- private final BagUserState<T> impl = createBagUserState(id,
coder);
+ private final BagUserState<T> impl = createBagUserState(key,
coder);
@Override
public void clear() {
@@ -272,7 +288,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public ValueState<T> readLater() {
- impl.get().iterator().prefetch();
+ impl.get().prefetch();
return this;
}
};
@@ -289,7 +305,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public Object apply(StateKey key) {
return new BagState<T>() {
- private final BagUserState<T> impl = createBagUserState(id,
elemCoder);
+ private final BagUserState<T> impl = createBagUserState(key,
elemCoder);
@Override
public void add(T value) {
@@ -318,7 +334,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public BagState<T> readLater() {
- impl.get().iterator().prefetch();
+ impl.get().prefetch();
return this;
}
@@ -335,13 +351,13 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec,
Coder<T> elemCoder) {
return (SetState<T>)
stateKeyObjectCache.computeIfAbsent(
- createMultimapUserStateKey(id),
+ createMultimapKeysUserStateKey(id),
new Function<StateKey, Object>() {
@Override
public Object apply(StateKey key) {
return new SetState<T>() {
private final MultimapUserState<T, Void> impl =
- createMultimapUserState(id, elemCoder, VoidCoder.of());
+ createMultimapUserState(key, elemCoder, VoidCoder.of());
@Override
public void clear() {
@@ -358,7 +374,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public ReadableState<Boolean> readLater() {
- impl.get(t).iterator().prefetch();
+ impl.get(t).prefetch();
return this;
}
};
@@ -394,7 +410,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public ReadableState<Boolean> readLater() {
- impl.keys().iterator().prefetch();
+ impl.keys().prefetch();
return this;
}
};
@@ -407,7 +423,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public SetState<T> readLater() {
- impl.keys().iterator().prefetch();
+ impl.keys().prefetch();
return this;
}
};
@@ -423,13 +439,13 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
Coder<ValueT> mapValueCoder) {
return (MapState<KeyT, ValueT>)
stateKeyObjectCache.computeIfAbsent(
- createMultimapUserStateKey(id),
+ createMultimapKeysUserStateKey(id),
new Function<StateKey, Object>() {
@Override
public Object apply(StateKey key) {
return new MapState<KeyT, ValueT>() {
private final MultimapUserState<KeyT, ValueT> impl =
- createMultimapUserState(id, mapKeyCoder, mapValueCoder);
+ createMultimapUserState(key, mapKeyCoder, mapValueCoder);
@Override
public void clear() {
@@ -474,7 +490,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public ReadableState<ValueT> readLater() {
- impl.get(key).iterator().prefetch();
+ impl.get(key).prefetch();
return this;
}
};
@@ -490,7 +506,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public ReadableState<Iterable<KeyT>> readLater() {
- impl.keys().iterator().prefetch();
+ impl.keys().prefetch();
return this;
}
};
@@ -574,7 +590,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
// TODO: Support squashing accumulators depending on whether
we know of all
// remote accumulators and local accumulators or just local
accumulators.
return new CombiningState<ElementT, AccumT, ResultT>() {
- private final BagUserState<AccumT> impl =
createBagUserState(id, accumCoder);
+ private final BagUserState<AccumT> impl =
createBagUserState(key, accumCoder);
@Override
public AccumT getAccum() {
@@ -606,7 +622,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public CombiningState<ElementT, AccumT, ResultT> readLater()
{
- impl.get().iterator().prefetch();
+ impl.get().prefetch();
return this;
}
@@ -636,7 +652,7 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
@Override
public ReadableState<Boolean> readLater() {
- impl.get().iterator().prefetch();
+ impl.get().prefetch();
return this;
}
};
@@ -697,31 +713,20 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
throw new UnsupportedOperationException("WatermarkHoldState is unsupported
by the Fn API.");
}
- private <KeyT, ValueT> MultimapUserState<KeyT, ValueT>
createMultimapUserState(
- String stateId, Coder<KeyT> keyCoder, Coder<ValueT> valueCoder) {
- MultimapUserState<KeyT, ValueT> rval =
- new MultimapUserState(
- beamFnStateClient,
- processBundleInstructionId.get(),
- ptransformId,
- stateId,
- encodedCurrentWindowSupplier.get(),
- encodedCurrentKeySupplier.get(),
- keyCoder,
- valueCoder);
- stateFinalizers.add(rval::asyncClose);
- return rval;
+ private Cache<?, ?> getCacheFor(StateKey stateKey) {
+ switch (stateKey.getTypeCase()) {
+ default:
+ return Caches.noop();
+ }
}
- private <T> BagUserState<T> createBagUserState(String stateId, Coder<T>
valueCoder) {
+ private <T> BagUserState<T> createBagUserState(StateKey stateKey, Coder<T>
valueCoder) {
BagUserState<T> rval =
new BagUserState<>(
+ getCacheFor(stateKey),
beamFnStateClient,
processBundleInstructionId.get(),
- ptransformId,
- stateId,
- encodedCurrentWindowSupplier.get(),
- encodedCurrentKeySupplier.get(),
+ stateKey,
valueCoder);
stateFinalizers.add(rval::asyncClose);
return rval;
@@ -738,7 +743,21 @@ public class FnApiStateAccessor<K> implements
SideInputReader, StateBinder {
return builder.build();
}
- private StateKey createMultimapUserStateKey(String stateId) {
+ private <KeyT, ValueT> MultimapUserState<KeyT, ValueT>
createMultimapUserState(
+ StateKey stateKey, Coder<KeyT> keyCoder, Coder<ValueT> valueCoder) {
+ MultimapUserState<KeyT, ValueT> rval =
+ new MultimapUserState(
+ Caches.noop(),
+ beamFnStateClient,
+ processBundleInstructionId.get(),
+ stateKey,
+ keyCoder,
+ valueCoder);
+ stateFinalizers.add(rval::asyncClose);
+ return rval;
+ }
+
+ private StateKey createMultimapKeysUserStateKey(String stateId) {
StateKey.Builder builder = StateKey.newBuilder();
builder
.getMultimapKeysUserStateBuilder()
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/IterableSideInput.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/IterableSideInput.java
index 48c35de..d59742c 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/IterableSideInput.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/IterableSideInput.java
@@ -17,55 +17,44 @@
*/
package org.apache.beam.fn.harness.state;
+import static
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+
+import org.apache.beam.fn.harness.Cache;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.Materializations.IterableView;
-import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
/**
* An implementation of a iterable side input that utilizes the Beam Fn State
API to fetch values.
- *
- * <p>TODO: Support block level caching and prefetch.
*/
@SuppressWarnings({
"rawtypes" // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
})
public class IterableSideInput<T> implements IterableView<T> {
- private final BeamFnStateClient beamFnStateClient;
- private final String instructionId;
- private final String ptransformId;
- private final String sideInputId;
- private final ByteString encodedWindow;
- private final Coder<T> valueCoder;
+ private final Iterable<T> values;
public IterableSideInput(
+ Cache<?, ?> cache,
BeamFnStateClient beamFnStateClient,
String instructionId,
- String ptransformId,
- String sideInputId,
- ByteString encodedWindow,
+ StateKey stateKey,
Coder<T> valueCoder) {
- this.beamFnStateClient = beamFnStateClient;
- this.instructionId = instructionId;
- this.ptransformId = ptransformId;
- this.sideInputId = sideInputId;
- this.encodedWindow = encodedWindow;
- this.valueCoder = valueCoder;
+ checkArgument(
+ stateKey.hasIterableSideInput(),
+ "Expected IterableSideInput StateKey but received %s.",
+ stateKey);
+ this.values =
+ StateFetchingIterators.readAllAndDecodeStartingFrom(
+ cache,
+ beamFnStateClient,
+
StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build(),
+ valueCoder);
}
@Override
public Iterable<T> get() {
- StateRequest.Builder requestBuilder = StateRequest.newBuilder();
- requestBuilder
- .setInstructionId(instructionId)
- .getStateKeyBuilder()
- .getIterableSideInputBuilder()
- .setTransformId(ptransformId)
- .setSideInputId(sideInputId)
- .setWindow(encodedWindow);
-
- return StateFetchingIterators.readAllAndDecodeStartingFrom(
- beamFnStateClient, requestBuilder.build(), valueCoder);
+ return values;
}
}
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
index 7828f93..d7e76fa 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
@@ -22,7 +22,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
-import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterables;
import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
import
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.checkerframework.checker.nullness.qual.Nullable;
@@ -31,7 +31,7 @@ import org.checkerframework.checker.nullness.qual.Nullable;
* Converts an iterator to an iterable lazily loading values from the
underlying iterator and
* caching them to support reiteration.
*/
-class LazyCachingIteratorToIterable<T> implements PrefetchableIterable<T> {
+class LazyCachingIteratorToIterable<T> extends
PrefetchableIterables.Default<T> {
private final List<T> cachedElements;
private final PrefetchableIterator<T> iterator;
@@ -41,7 +41,7 @@ class LazyCachingIteratorToIterable<T> implements
PrefetchableIterable<T> {
}
@Override
- public PrefetchableIterator<T> iterator() {
+ public PrefetchableIterator<T> createIterator() {
return new CachingIterator();
}
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java
index 2fc72fe9..e754171 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java
@@ -17,7 +17,12 @@
*/
package org.apache.beam.fn.harness.state;
+import static
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+
import java.io.IOException;
+import org.apache.beam.fn.harness.Cache;
+import org.apache.beam.fn.harness.Caches;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.Materializations.MultimapView;
@@ -25,8 +30,6 @@ import
org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
/**
* An implementation of a multimap side input that utilizes the Beam Fn State
API to fetch values.
- *
- * <p>TODO: Support block level caching and prefetch.
*/
@SuppressWarnings({
"rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
@@ -34,44 +37,35 @@ import
org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
})
public class MultimapSideInput<K, V> implements MultimapView<K, V> {
+ private final Cache<?, ?> cache;
private final BeamFnStateClient beamFnStateClient;
- private final String instructionId;
- private final String ptransformId;
- private final String sideInputId;
- private final ByteString encodedWindow;
+ private final StateRequest keysRequest;
private final Coder<K> keyCoder;
private final Coder<V> valueCoder;
public MultimapSideInput(
+ Cache<?, ?> cache,
BeamFnStateClient beamFnStateClient,
String instructionId,
- String ptransformId,
- String sideInputId,
- ByteString encodedWindow,
+ StateKey stateKey,
Coder<K> keyCoder,
Coder<V> valueCoder) {
+ checkArgument(
+ stateKey.hasMultimapKeysSideInput(),
+ "Expected MultimapKeysSideInput StateKey but received %s.",
+ stateKey);
+ this.cache = cache;
this.beamFnStateClient = beamFnStateClient;
- this.instructionId = instructionId;
- this.ptransformId = ptransformId;
- this.sideInputId = sideInputId;
- this.encodedWindow = encodedWindow;
+ this.keysRequest =
+
StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build();
this.keyCoder = keyCoder;
this.valueCoder = valueCoder;
}
@Override
public Iterable<K> get() {
- StateRequest.Builder requestBuilder = StateRequest.newBuilder();
- requestBuilder
- .setInstructionId(instructionId)
- .getStateKeyBuilder()
- .getMultimapKeysSideInputBuilder()
- .setTransformId(ptransformId)
- .setSideInputId(sideInputId)
- .setWindow(encodedWindow);
-
return StateFetchingIterators.readAllAndDecodeStartingFrom(
- beamFnStateClient, requestBuilder.build(), keyCoder);
+ cache, beamFnStateClient, keysRequest, keyCoder);
}
@Override
@@ -81,19 +75,26 @@ public class MultimapSideInput<K, V> implements
MultimapView<K, V> {
keyCoder.encode(k, output);
} catch (IOException e) {
throw new IllegalStateException(
- String.format("Failed to encode key %s for side input id %s.", k,
sideInputId), e);
+ String.format(
+ "Failed to encode key %s for side input id %s.",
+ k,
keysRequest.getStateKey().getMultimapKeysSideInput().getSideInputId()),
+ e);
}
- StateRequest.Builder requestBuilder = StateRequest.newBuilder();
- requestBuilder
- .setInstructionId(instructionId)
- .getStateKeyBuilder()
- .getMultimapSideInputBuilder()
- .setTransformId(ptransformId)
- .setSideInputId(sideInputId)
- .setWindow(encodedWindow)
- .setKey(output.toByteString());
+ ByteString encodedKey = output.toByteString();
+ StateKey stateKey =
+ StateKey.newBuilder()
+ .setMultimapSideInput(
+ StateKey.MultimapSideInput.newBuilder()
+ .setTransformId(
+
keysRequest.getStateKey().getMultimapKeysSideInput().getTransformId())
+ .setSideInputId(
+
keysRequest.getStateKey().getMultimapKeysSideInput().getSideInputId())
+
.setWindow(keysRequest.getStateKey().getMultimapKeysSideInput().getWindow())
+ .setKey(encodedKey))
+ .build();
+ StateRequest request =
keysRequest.toBuilder().setStateKey(stateKey).build();
return StateFetchingIterators.readAllAndDecodeStartingFrom(
- beamFnStateClient, requestBuilder.build(), valueCoder);
+ Caches.subCache(cache, "ValuesForKey", encodedKey), beamFnStateClient,
request, valueCoder);
}
}
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
index 293d915..70dc2b1 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
@@ -17,10 +17,12 @@
*/
package org.apache.beam.fn.harness.state;
+import static
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import static
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
import java.io.IOException;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
@@ -28,8 +30,12 @@ import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
+import org.apache.beam.fn.harness.Cache;
+import org.apache.beam.fn.harness.Caches;
+import
org.apache.beam.fn.harness.state.StateFetchingIterators.CachingStateIterable;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateAppendRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
@@ -38,7 +44,6 @@ import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
-import org.checkerframework.checker.nullness.qual.Nullable;
/**
* An implementation of a multimap user state that utilizes the Beam Fn State
API to fetch, clear
@@ -49,17 +54,16 @@ import org.checkerframework.checker.nullness.qual.Nullable;
*
* <p>TODO: Move to an async persist model where persistence is signalled
based upon cache memory
* pressure and its need to flush.
- *
- * <p>TODO: Support block level caching and prefetch.
*/
public class MultimapUserState<K, V> {
+ private final Cache<?, ?> cache;
private final BeamFnStateClient beamFnStateClient;
private final Coder<K> mapKeyCoder;
private final Coder<V> valueCoder;
- private final String stateId;
private final StateRequest keysStateRequest;
private final StateRequest userStateRequest;
+ private CachingStateIterable<K> persistedKeys;
private boolean isClosed;
private boolean isCleared;
@@ -67,44 +71,40 @@ public class MultimapUserState<K, V> {
private HashMap<Object, K> pendingRemoves = Maps.newHashMap();
private HashMap<Object, KV<K, List<V>>> pendingAdds = Maps.newHashMap();
// Values retrieved from persistent storage
- private HashMap<K, PrefetchableIterable<V>> persistedValues =
Maps.newHashMap();
- private @Nullable PrefetchableIterable<K> persistedKeys = null;
+ private HashMap<Object, KV<K, CachingStateIterable<V>>> persistedValues =
Maps.newHashMap();
public MultimapUserState(
+ Cache<?, ?> cache,
BeamFnStateClient beamFnStateClient,
String instructionId,
- String pTransformId,
- String stateId,
- ByteString encodedWindow,
- ByteString encodedKey,
+ StateKey stateKey,
Coder<K> mapKeyCoder,
Coder<V> valueCoder) {
+ checkArgument(
+ stateKey.hasMultimapKeysUserState(),
+ "Expected MultimapKeysUserState StateKey but received %s.",
+ stateKey);
+ this.cache = cache;
this.beamFnStateClient = beamFnStateClient;
this.mapKeyCoder = mapKeyCoder;
this.valueCoder = valueCoder;
- this.stateId = stateId;
- StateRequest.Builder keysStateRequestBuilder = StateRequest.newBuilder();
- keysStateRequestBuilder
- .setInstructionId(instructionId)
- .getStateKeyBuilder()
- .getMultimapKeysUserStateBuilder()
- .setTransformId(pTransformId)
- .setUserStateId(stateId)
- .setKey(encodedKey)
- .setWindow(encodedWindow);
- keysStateRequest = keysStateRequestBuilder.build();
+ this.keysStateRequest =
+
StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build();
+ this.persistedKeys =
+ StateFetchingIterators.readAllAndDecodeStartingFrom(
+ cache, beamFnStateClient, keysStateRequest, mapKeyCoder);
StateRequest.Builder userStateRequestBuilder = StateRequest.newBuilder();
userStateRequestBuilder
.setInstructionId(instructionId)
.getStateKeyBuilder()
.getMultimapUserStateBuilder()
- .setTransformId(pTransformId)
- .setUserStateId(stateId)
- .setWindow(encodedWindow)
- .setKey(encodedKey);
- userStateRequest = userStateRequestBuilder.build();
+ .setTransformId(stateKey.getMultimapKeysUserState().getTransformId())
+ .setUserStateId(stateKey.getMultimapKeysUserState().getUserStateId())
+ .setWindow(stateKey.getMultimapKeysUserState().getWindow())
+ .setKey(stateKey.getMultimapKeysUserState().getKey());
+ this.userStateRequest = userStateRequestBuilder.build();
}
public void clear() {
@@ -115,7 +115,6 @@ public class MultimapUserState<K, V> {
isCleared = true;
persistedValues = Maps.newHashMap();
- persistedKeys = null;
pendingRemoves = Maps.newHashMap();
pendingAdds = Maps.newHashMap();
}
@@ -142,8 +141,7 @@ public class MultimapUserState<K, V> {
return pendingValues;
}
- PrefetchableIterable<V> persistedValues = getPersistedValues(key);
- return PrefetchableIterables.concat(persistedValues, pendingValues);
+ return PrefetchableIterables.concat(getPersistedValues(structuralKey,
key), pendingValues);
}
@SuppressWarnings({
@@ -165,15 +163,14 @@ public class MultimapUserState<K, V> {
return PrefetchableIterables.concat(keys);
}
- PrefetchableIterable<K> persistedKeys = getPersistedKeys();
Set<Object> pendingRemovesNow = new HashSet<>(pendingRemoves.keySet());
Map<Object, K> pendingAddsNow = new HashMap<>();
for (Map.Entry<Object, KV<K, List<V>>> entry : pendingAdds.entrySet()) {
pendingAddsNow.put(entry.getKey(), entry.getValue().getKey());
}
- return new PrefetchableIterable<K>() {
+ return new PrefetchableIterables.Default<K>() {
@Override
- public PrefetchableIterator<K> iterator() {
+ public PrefetchableIterator<K> createIterator() {
return new PrefetchableIterator<K>() {
PrefetchableIterator<K> persistedKeysIterator =
persistedKeys.iterator();
Iterator<K> pendingAddsNowIterator;
@@ -277,37 +274,80 @@ public class MultimapUserState<K, V> {
"Multimap user state is no longer usable because it is closed for %s",
keysStateRequest.getStateKey());
isClosed = true;
- // Nothing to persist
+ // No mutations necessary
if (!isCleared && pendingRemoves.isEmpty() && pendingAdds.isEmpty()) {
return;
}
+ startStateApiWrites();
+ updateCache();
+ }
+
+ @SuppressWarnings("FutureReturnValueIgnored")
+ private void startStateApiWrites() {
// Clear currently persisted key-values
if (isCleared) {
- beamFnStateClient
-
.handle(keysStateRequest.toBuilder().setClear(StateClearRequest.getDefaultInstance()))
- .get();
+ beamFnStateClient.handle(
+
keysStateRequest.toBuilder().setClear(StateClearRequest.getDefaultInstance()));
} else if (!pendingRemoves.isEmpty()) {
for (K key : pendingRemoves.values()) {
- beamFnStateClient
- .handle(
- createUserStateRequest(key)
- .toBuilder()
- .setClear(StateClearRequest.getDefaultInstance()))
- .get();
+ StateRequest request = createUserStateRequest(key);
+ beamFnStateClient.handle(
+
request.toBuilder().setClear(StateClearRequest.getDefaultInstance()));
}
}
// Persist pending key-values
if (!pendingAdds.isEmpty()) {
for (KV<K, List<V>> entry : pendingAdds.values()) {
- beamFnStateClient
- .handle(
- createUserStateRequest(entry.getKey())
- .toBuilder()
- .setAppend(
-
StateAppendRequest.newBuilder().setData(encodeValues(entry.getValue()))))
- .get();
+ StateRequest request = createUserStateRequest(entry.getKey());
+ beamFnStateClient.handle(
+ request
+ .toBuilder()
+ .setAppend(
+
StateAppendRequest.newBuilder().setData(encodeValues(entry.getValue()))));
+ }
+ }
+ }
+
+ private void updateCache() {
+ List<K> pendingAddsKeys = new ArrayList<>(pendingAdds.size());
+ for (KV<K, List<V>> entry : pendingAdds.values()) {
+ pendingAddsKeys.add(entry.getKey());
+ }
+
+ if (isCleared) {
+ // This will clear all keys and values since values is a sub-cache of
keys.
+ persistedKeys.clearAndAppend(pendingAddsKeys);
+
+ // Since the map was cleared we can add all the values that are pending
since we know
+ // that they must have been cleared.
+ for (Map.Entry<Object, KV<K, List<V>>> entry : pendingAdds.entrySet()) {
+ CachingStateIterable<V> iterable =
+ getPersistedValues(entry.getKey(), entry.getValue().getKey());
+ iterable.clearAndAppend(entry.getValue().getValue());
+ }
+ } else {
+ // The cast to Set<Object> is necessary since the checker framework
would like to further
+ // limit the type to Set<@KeyFor("this.pendingRemoves") Object> which is
incompatible with
+ // the API being remove(Set<Object>). We don't want to limit the API for
remove either.
+ persistedKeys.remove((Set<Object>) pendingRemoves.keySet());
+ persistedKeys.append(pendingAddsKeys);
+
+ // For each removed key, we want to update the internal cache to clear
its set of values
+ for (Map.Entry<Object, K> entry : pendingRemoves.entrySet()) {
+ CachingStateIterable<V> iterable = getPersistedValues(entry.getKey(),
entry.getValue());
+ iterable.clearAndAppend(Collections.emptyList());
+ }
+
+ // For each added key, try to update the internal cache with the set of
values.
+ for (Map.Entry<Object, KV<K, List<V>>> entry : pendingAdds.entrySet()) {
+ KV<K, CachingStateIterable<V>> value =
persistedValues.get(entry.getKey());
+ // We don't do anything for keys that haven't been loaded since we
have no knowledge whether
+ // the key is empty or not.
+ if (value != null) {
+ value.getValue().append(entry.getValue().getValue());
+ }
}
}
}
@@ -321,7 +361,10 @@ public class MultimapUserState<K, V> {
return output.toByteString();
} catch (IOException e) {
throw new IllegalStateException(
- String.format("Failed to encode values for multimap user state id
%s.", stateId), e);
+ String.format(
+ "Failed to encode values for multimap user state id %s.",
+
keysStateRequest.getStateKey().getMultimapKeysUserState().getUserStateId()),
+ e);
}
}
@@ -334,27 +377,30 @@ public class MultimapUserState<K, V> {
return request.build();
} catch (IOException e) {
throw new IllegalStateException(
- String.format("Failed to encode key for multimap user state id %s.",
stateId), e);
+ String.format(
+ "Failed to encode key for multimap user state id %s.",
+
keysStateRequest.getStateKey().getMultimapKeysUserState().getUserStateId()),
+ e);
}
}
- private PrefetchableIterable<V> getPersistedValues(K key) {
- if (!persistedValues.containsKey(key)) {
- PrefetchableIterable<V> values =
- StateFetchingIterators.readAllAndDecodeStartingFrom(
- beamFnStateClient, createUserStateRequest(key), valueCoder);
- persistedValues.put(key, values);
- }
- return persistedValues.get(key);
- }
-
- private PrefetchableIterable<K> getPersistedKeys() {
- checkState(!isCleared);
- if (persistedKeys == null) {
- persistedKeys =
- StateFetchingIterators.readAllAndDecodeStartingFrom(
- beamFnStateClient, keysStateRequest, mapKeyCoder);
- }
- return persistedKeys;
+ private CachingStateIterable<V> getPersistedValues(Object structuralKey, K
key) {
+ return persistedValues
+ .computeIfAbsent(
+ structuralKey,
+ unused -> {
+ StateRequest request = createUserStateRequest(key);
+ return KV.of(
+ key,
+ StateFetchingIterators.readAllAndDecodeStartingFrom(
+ Caches.subCache(
+ cache,
+ "ValuesForKey",
+
request.getStateKey().getMultimapUserState().getMapKey()),
+ beamFnStateClient,
+ request,
+ valueCoder));
+ })
+ .getValue();
}
}
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
index c6bbfed..3525df9 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
@@ -41,6 +41,7 @@ import
org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.fn.stream.DataStreams.DataStreamDecoder;
import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
+import org.apache.beam.sdk.fn.stream.PrefetchableIterables;
import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
import org.apache.beam.sdk.fn.stream.PrefetchableIterators;
import org.apache.beam.sdk.util.Weighted;
@@ -133,7 +134,7 @@ public class StateFetchingIterators {
* all the remaining pages.
*/
@VisibleForTesting
- static class FirstPageAndRemainder<T> implements PrefetchableIterable<T> {
+ static class FirstPageAndRemainder<T> extends
PrefetchableIterables.Default<T> {
private final BeamFnStateClient beamFnStateClient;
private final StateRequest stateRequestForFirstChunk;
private final Coder<T> valueCoder;
@@ -151,7 +152,7 @@ public class StateFetchingIterators {
}
@Override
- public PrefetchableIterator<T> iterator() {
+ public PrefetchableIterator<T> createIterator() {
return new PrefetchableIterator<T>() {
PrefetchableIterator<T> delegate;
@@ -261,7 +262,7 @@ public class StateFetchingIterators {
}
/** A mutable iterable that supports prefetch and is backed by a cache. */
- static class CachingStateIterable<T> implements PrefetchableIterable<T> {
+ static class CachingStateIterable<T> extends
PrefetchableIterables.Default<T> {
/** Represents a set of elements. */
abstract static class Blocks<T> implements Weighted {
@@ -446,7 +447,7 @@ public class StateFetchingIterators {
}
@Override
- public PrefetchableIterator<T> iterator() {
+ public PrefetchableIterator<T> createIterator() {
return new CachingStateIterator();
}
diff --git
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BagUserStateTest.java
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BagUserStateTest.java
index ceeab22..e9c6f71 100644
---
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BagUserStateTest.java
+++
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BagUserStateTest.java
@@ -25,6 +25,8 @@ import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import java.io.IOException;
+import org.apache.beam.fn.harness.Cache;
+import org.apache.beam.fn.harness.Caches;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
@@ -44,13 +46,7 @@ public class BagUserStateTest {
StringUtf8Coder.of(), ImmutableMap.of(key("A"), asList("A1", "A2",
"A3")));
BagUserState<String> userState =
new BagUserState<>(
- fakeClient,
- "instructionId",
- "ptransformId",
- "stateId",
- ByteString.copyFromUtf8("encodedWindow"),
- encode("A"),
- StringUtf8Coder.of());
+ Caches.noop(), fakeClient, "instructionId", key("A"),
StringUtf8Coder.of());
assertArrayEquals(
new String[] {"A1", "A2", "A3"}, Iterables.toArray(userState.get(),
String.class));
@@ -59,18 +55,45 @@ public class BagUserStateTest {
}
@Test
+ public void testGetCached() throws Exception {
+ FakeBeamFnStateClient fakeClient =
+ new FakeBeamFnStateClient(
+ StringUtf8Coder.of(), ImmutableMap.of(key("A"), asList("A1", "A2",
"A3")));
+
+ Cache<?, ?> cache = Caches.eternal();
+ {
+ // First user state populates the cache.
+ BagUserState<String> userState =
+ new BagUserState<>(cache, fakeClient, "instructionId", key("A"),
StringUtf8Coder.of());
+ assertArrayEquals(
+ new String[] {"A1", "A2", "A3"}, Iterables.toArray(userState.get(),
String.class));
+ userState.asyncClose();
+ }
+
+ {
+ // The next user state will load all of its contents from the cache.
+ BagUserState<String> userState =
+ new BagUserState<>(
+ cache,
+ requestBuilder -> {
+ throw new IllegalStateException("Unexpected call for test.");
+ },
+ "instructionId",
+ key("A"),
+ StringUtf8Coder.of());
+ assertArrayEquals(
+ new String[] {"A1", "A2", "A3"}, Iterables.toArray(userState.get(),
String.class));
+ userState.asyncClose();
+ }
+ }
+
+ @Test
public void testAppend() throws Exception {
FakeBeamFnStateClient fakeClient =
new FakeBeamFnStateClient(StringUtf8Coder.of(),
ImmutableMap.of(key("A"), asList("A1")));
BagUserState<String> userState =
new BagUserState<>(
- fakeClient,
- "instructionId",
- "ptransformId",
- "stateId",
- ByteString.copyFromUtf8("encodedWindow"),
- encode("A"),
- StringUtf8Coder.of());
+ Caches.noop(), fakeClient, "instructionId", key("A"),
StringUtf8Coder.of());
userState.append("A2");
Iterable<String> stateBeforeA3 = userState.get();
assertArrayEquals(new String[] {"A1", "A2"},
Iterables.toArray(stateBeforeA3, String.class));
@@ -85,19 +108,62 @@ public class BagUserStateTest {
}
@Test
+ public void testAppendCached() throws Exception {
+ FakeBeamFnStateClient fakeClient =
+ new FakeBeamFnStateClient(StringUtf8Coder.of(),
ImmutableMap.of(key("A"), asList("A1")));
+ Cache<?, ?> cache = Caches.eternal();
+ {
+ // First user state populates the cache.
+ BagUserState<String> userState =
+ new BagUserState<>(cache, fakeClient, "instructionId", key("A"),
StringUtf8Coder.of());
+ userState.append("A2");
+ Iterable<String> stateBeforeA3 = userState.get();
+ assertArrayEquals(new String[] {"A1", "A2"},
Iterables.toArray(stateBeforeA3, String.class));
+ userState.append("A3");
+ assertArrayEquals(new String[] {"A1", "A2"},
Iterables.toArray(stateBeforeA3, String.class));
+ assertArrayEquals(
+ new String[] {"A1", "A2", "A3"}, Iterables.toArray(userState.get(),
String.class));
+ userState.asyncClose();
+ }
+
+ {
+ // The next user state will load all of its contents from the cache
including the appends
+ // persisted via asyncClose.
+ BagUserState<String> userState =
+ new BagUserState<>(
+ cache,
+ requestBuilder -> {
+ if (requestBuilder.hasGet()) {
+ throw new IllegalStateException("Unexpected call for test.");
+ }
+ return fakeClient.handle(requestBuilder);
+ },
+ "instructionId",
+ key("A"),
+ StringUtf8Coder.of());
+ userState.append("A4");
+ Iterable<String> stateBeforeA5 = userState.get();
+ assertArrayEquals(
+ new String[] {"A1", "A2", "A3", "A4"},
Iterables.toArray(stateBeforeA5, String.class));
+ userState.append("A5");
+ assertArrayEquals(
+ new String[] {"A1", "A2", "A3", "A4"},
Iterables.toArray(stateBeforeA5, String.class));
+ assertArrayEquals(
+ new String[] {"A1", "A2", "A3", "A4", "A5"},
+ Iterables.toArray(userState.get(), String.class));
+ userState.asyncClose();
+ }
+ assertEquals(encode("A1", "A2", "A3", "A4", "A5"),
fakeClient.getData().get(key("A")));
+ }
+
+ @Test
public void testClear() throws Exception {
FakeBeamFnStateClient fakeClient =
new FakeBeamFnStateClient(
StringUtf8Coder.of(), ImmutableMap.of(key("A"), asList("A1", "A2",
"A3")));
BagUserState<String> userState =
new BagUserState<>(
- fakeClient,
- "instructionId",
- "ptransformId",
- "stateId",
- ByteString.copyFromUtf8("encodedWindow"),
- encode("A"),
- StringUtf8Coder.of());
+ Caches.noop(), fakeClient, "instructionId", key("A"),
StringUtf8Coder.of());
assertArrayEquals(
new String[] {"A1", "A2", "A3"}, Iterables.toArray(userState.get(),
String.class));
userState.clear();
@@ -112,6 +178,68 @@ public class BagUserStateTest {
assertThrows(IllegalStateException.class, () -> userState.clear());
}
+ @Test
+ public void testClearCached() throws Exception {
+ FakeBeamFnStateClient fakeClient =
+ new FakeBeamFnStateClient(
+ StringUtf8Coder.of(), ImmutableMap.of(key("A"), asList("A1", "A2",
"A3")));
+
+ Cache<?, ?> cache = Caches.eternal();
+ {
+ // First user state populates the cache.
+ BagUserState<String> userState =
+ new BagUserState<>(cache, fakeClient, "instructionId", key("A"),
StringUtf8Coder.of());
+ assertArrayEquals(
+ new String[] {"A1", "A2", "A3"}, Iterables.toArray(userState.get(),
String.class));
+ userState.clear();
+ assertFalse(userState.get().iterator().hasNext());
+ userState.append("A4");
+ assertArrayEquals(new String[] {"A4"},
Iterables.toArray(userState.get(), String.class));
+ userState.asyncClose();
+ }
+
+ {
+ // The next user state will load all of its contents from the cache
including the clear and
+ // append persisted via asyncClose.
+ BagUserState<String> userState =
+ new BagUserState<>(
+ cache,
+ requestBuilder -> {
+ if (requestBuilder.hasGet()) {
+ throw new IllegalStateException("Unexpected call for test.");
+ }
+ return fakeClient.handle(requestBuilder);
+ },
+ "instructionId",
+ key("A"),
+ StringUtf8Coder.of());
+ assertArrayEquals(new String[] {"A4"},
Iterables.toArray(userState.get(), String.class));
+ userState.clear();
+ assertFalse(userState.get().iterator().hasNext());
+ userState.asyncClose();
+ }
+
+ {
+ // The next user state will load all of its contents from the cache
including the clear
+ // persisted via asyncClose.
+ BagUserState<String> userState =
+ new BagUserState<>(
+ cache,
+ requestBuilder -> {
+ if (requestBuilder.hasGet()) {
+ throw new IllegalStateException("Unexpected call for test.");
+ }
+ return fakeClient.handle(requestBuilder);
+ },
+ "instructionId",
+ key("A"),
+ StringUtf8Coder.of());
+ assertArrayEquals(new String[] {}, Iterables.toArray(userState.get(),
String.class));
+ userState.asyncClose();
+ }
+ assertNull(fakeClient.getData().get(key("A")));
+ }
+
private StateKey key(String id) throws IOException {
return StateKey.newBuilder()
.setBagUserState(
diff --git
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/IterableSideInputTest.java
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/IterableSideInputTest.java
new file mode 100644
index 0000000..9e44295
--- /dev/null
+++
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/IterableSideInputTest.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.state;
+
+import static java.util.Arrays.asList;
+import static org.junit.Assert.assertArrayEquals;
+
+import java.io.IOException;
+import org.apache.beam.fn.harness.Cache;
+import org.apache.beam.fn.harness.Caches;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
+import
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class IterableSideInputTest {
+ @Test
+ public void testGet() throws Exception {
+ FakeBeamFnStateClient fakeBeamFnStateClient =
+ new FakeBeamFnStateClient(
+ StringUtf8Coder.of(),
+ ImmutableMap.of(key(), asList("A1", "A2", "A3", "A4", "A5",
"A6")));
+
+ IterableSideInput<String> iterableSideInput =
+ new IterableSideInput<>(
+ Caches.noop(), fakeBeamFnStateClient, "instructionId", key(),
StringUtf8Coder.of());
+ assertArrayEquals(
+ new String[] {"A1", "A2", "A3", "A4", "A5", "A6"},
+ Iterables.toArray(iterableSideInput.get(), String.class));
+ }
+
+ @Test
+ public void testGetCached() throws Exception {
+ FakeBeamFnStateClient fakeBeamFnStateClient =
+ new FakeBeamFnStateClient(
+ StringUtf8Coder.of(),
+ ImmutableMap.of(key(), asList("A1", "A2", "A3", "A4", "A5",
"A6")));
+
+ Cache<?, ?> cache = Caches.eternal();
+ {
+ // The first side input will populate the cache.
+ IterableSideInput<String> iterableSideInput =
+ new IterableSideInput<>(
+ cache, fakeBeamFnStateClient, "instructionId", key(),
StringUtf8Coder.of());
+ assertArrayEquals(
+ new String[] {"A1", "A2", "A3", "A4", "A5", "A6"},
+ Iterables.toArray(iterableSideInput.get(), String.class));
+ }
+
+ {
+ // The next side input will load all of its contents from the cache.
+ IterableSideInput<String> iterableSideInput =
+ new IterableSideInput<>(
+ cache,
+ requestBuilder -> {
+ throw new IllegalStateException("Unexpected call for test.");
+ },
+ "instructionId",
+ key(),
+ StringUtf8Coder.of());
+ assertArrayEquals(
+ new String[] {"A1", "A2", "A3", "A4", "A5", "A6"},
+ Iterables.toArray(iterableSideInput.get(), String.class));
+ }
+ }
+
+ private StateKey key() throws IOException {
+ return StateKey.newBuilder()
+ .setIterableSideInput(
+ StateKey.IterableSideInput.newBuilder()
+ .setTransformId("ptransformId")
+ .setSideInputId("sideInputId")
+ .setWindow(ByteString.copyFromUtf8("encodedWindow")))
+ .build();
+ }
+}
diff --git
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapSideInputTest.java
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapSideInputTest.java
index 9321ca6..81e9a45 100644
---
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapSideInputTest.java
+++
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapSideInputTest.java
@@ -22,8 +22,11 @@ import static org.junit.Assert.assertArrayEquals;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
+import org.apache.beam.fn.harness.Cache;
+import org.apache.beam.fn.harness.Caches;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
@@ -33,7 +36,13 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
-/** Tests for {@link MultimapSideInput}. */
+/**
+ * Tests for {@link MultimapSideInput}.
+ *
+ * <p>It is important to use a key type where its coder is not {@link
Coder#consistentWithEquals()}
+ * to ensure that comparisons are performed using structural values instead of
object equality
+ * during testing.
+ */
@RunWith(JUnit4.class)
public class MultimapSideInputTest {
private static final byte[] A = "A".getBytes(StandardCharsets.UTF_8);
@@ -51,11 +60,10 @@ public class MultimapSideInputTest {
MultimapSideInput<byte[], String> multimapSideInput =
new MultimapSideInput<>(
+ Caches.noop(),
fakeBeamFnStateClient,
"instructionId",
- "ptransformId",
- "sideInputId",
- ByteString.copyFromUtf8("encodedWindow"),
+ stateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
assertArrayEquals(
@@ -68,6 +76,61 @@ public class MultimapSideInputTest {
new byte[][] {A, B}, Iterables.toArray(multimapSideInput.get(),
byte[].class));
}
+ @Test
+ public void testGetCached() throws Exception {
+ FakeBeamFnStateClient fakeBeamFnStateClient =
+ new FakeBeamFnStateClient(
+ ImmutableMap.of(
+ stateKey(), KV.of(ByteArrayCoder.of(), asList(A, B)),
+ key(A), KV.of(StringUtf8Coder.of(), asList("A1", "A2", "A3")),
+ key(B), KV.of(StringUtf8Coder.of(), asList("B1", "B2"))));
+
+ Cache<?, ?> cache = Caches.eternal();
+ {
+ // The first side input will populate the cache.
+ MultimapSideInput<byte[], String> multimapSideInput =
+ new MultimapSideInput<>(
+ cache,
+ fakeBeamFnStateClient,
+ "instructionId",
+ stateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+ assertArrayEquals(
+ new String[] {"A1", "A2", "A3"},
+ Iterables.toArray(multimapSideInput.get(A), String.class));
+ assertArrayEquals(
+ new String[] {"B1", "B2"},
Iterables.toArray(multimapSideInput.get(B), String.class));
+ assertArrayEquals(
+ new String[] {}, Iterables.toArray(multimapSideInput.get(UNKNOWN),
String.class));
+ assertArrayEquals(
+ new byte[][] {A, B}, Iterables.toArray(multimapSideInput.get(),
byte[].class));
+ }
+
+ {
+ // The next side input will load all of its contents from the cache.
+ MultimapSideInput<byte[], String> multimapSideInput =
+ new MultimapSideInput<>(
+ cache,
+ requestBuilder -> {
+ throw new IllegalStateException("Unexpected call for test.");
+ },
+ "instructionId",
+ stateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+ assertArrayEquals(
+ new String[] {"A1", "A2", "A3"},
+ Iterables.toArray(multimapSideInput.get(A), String.class));
+ assertArrayEquals(
+ new String[] {"B1", "B2"},
Iterables.toArray(multimapSideInput.get(B), String.class));
+ assertArrayEquals(
+ new String[] {}, Iterables.toArray(multimapSideInput.get(UNKNOWN),
String.class));
+ assertArrayEquals(
+ new byte[][] {A, B}, Iterables.toArray(multimapSideInput.get(),
byte[].class));
+ }
+ }
+
private StateKey stateKey() throws IOException {
return StateKey.newBuilder()
.setMultimapKeysSideInput(
diff --git
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
index 4749c52..8eb709b 100644
---
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
+++
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
@@ -33,6 +33,8 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
+import org.apache.beam.fn.harness.Cache;
+import org.apache.beam.fn.harness.Caches;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
@@ -70,12 +72,10 @@ public class MultimapUserStateTest {
FakeBeamFnStateClient fakeClient = new
FakeBeamFnStateClient(Collections.emptyMap());
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
assertThat(userState.keys(), is(emptyIterable()));
@@ -92,12 +92,10 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
@@ -122,12 +120,10 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
@@ -158,12 +154,10 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
@@ -192,12 +186,10 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
@@ -221,20 +213,17 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
userState.remove(A0);
userState.put(A0, "V2");
assertArrayEquals(new String[] {"V2"},
Iterables.toArray(userState.get(A0), String.class));
userState.asyncClose();
- Map<StateKey, ByteString> data = fakeClient.getData();
- assertEquals(encode("V2"), data.get(createMultimapValueStateKey(A0)));
+ assertEquals(encode("V2"),
fakeClient.getData().get(createMultimapValueStateKey(A0)));
}
@Test
@@ -248,12 +237,10 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
userState.clear();
@@ -272,12 +259,10 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
userState.remove(A0);
@@ -292,12 +277,10 @@ public class MultimapUserStateTest {
FakeBeamFnStateClient fakeClient = new
FakeBeamFnStateClient(Collections.emptyMap());
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
userState.put(A0, "V0");
@@ -315,12 +298,10 @@ public class MultimapUserStateTest {
FakeBeamFnStateClient fakeClient = new
FakeBeamFnStateClient(Collections.emptyMap());
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
userState.put(A0, "V0");
@@ -346,12 +327,10 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
@@ -376,12 +355,10 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
Iterable<byte[]> keys = userState.keys();
@@ -401,12 +378,10 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
Iterable<String> values = userState.get(A1);
@@ -426,12 +401,10 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
userState.clear();
@@ -452,12 +425,10 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
userState.asyncClose();
@@ -478,12 +449,10 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
userState.remove(A0);
@@ -509,12 +478,10 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
NullableCoder.of(ByteArrayCoder.of()),
NullableCoder.of(StringUtf8Coder.of()));
userState.put(null, null);
@@ -529,12 +496,10 @@ public class MultimapUserStateTest {
FakeBeamFnStateClient fakeClient = new
FakeBeamFnStateClient(Collections.emptyMap());
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.eternal(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
assertArrayEquals(new String[] {}, Iterables.toArray(userState.get(A1),
String.class));
@@ -553,18 +518,16 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.eternal(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
PrefetchableIterable<String> values = userState.get(A1);
assertEquals(0, fakeClient.getCallCount());
- values.iterator().prefetch();
+ values.prefetch();
assertEquals(1, fakeClient.getCallCount());
assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(values,
String.class));
assertEquals(1, fakeClient.getCallCount());
@@ -581,18 +544,16 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.eternal(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
PrefetchableIterable<byte[]> keys = userState.keys();
assertEquals(0, fakeClient.getCallCount());
- keys.iterator().prefetch();
+ keys.prefetch();
assertEquals(1, fakeClient.getCallCount());
assertArrayEquals(new byte[][] {A1}, Iterables.toArray(keys,
byte[].class));
assertEquals(1, fakeClient.getCallCount());
@@ -609,19 +570,17 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.eternal(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
userState.put(A2, "V3");
PrefetchableIterable<byte[]> keys = userState.keys();
assertEquals(0, fakeClient.getCallCount());
- keys.iterator().prefetch();
+ keys.prefetch();
assertEquals(1, fakeClient.getCallCount());
assertArrayEquals(new byte[][] {A1, A2}, Iterables.toArray(keys,
byte[].class));
assertEquals(1, fakeClient.getCallCount());
@@ -638,12 +597,10 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
@@ -651,7 +608,7 @@ public class MultimapUserStateTest {
userState.put(A1, "V3");
PrefetchableIterable<String> values = userState.get(A1);
assertEquals(0, fakeClient.getCallCount());
- values.iterator().prefetch();
+ values.prefetch();
// Removed keys don't require accessing the underlying persisted state
assertEquals(0, fakeClient.getCallCount());
assertArrayEquals(new String[] {"V3"}, Iterables.toArray(values,
String.class));
@@ -670,12 +627,10 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.noop(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
@@ -683,7 +638,7 @@ public class MultimapUserStateTest {
userState.put(A2, "V3");
PrefetchableIterable<byte[]> keys = userState.keys();
assertEquals(0, fakeClient.getCallCount());
- keys.iterator().prefetch();
+ keys.prefetch();
// Cleared keys don't require accessing the underlying persisted state
assertEquals(0, fakeClient.getCallCount());
assertArrayEquals(new byte[][] {A2}, Iterables.toArray(keys,
byte[].class));
@@ -702,24 +657,385 @@ public class MultimapUserStateTest {
KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
MultimapUserState<byte[], String> userState =
new MultimapUserState<>(
+ Caches.eternal(),
fakeClient,
"instructionId",
- pTransformId,
- stateId,
- encode(encodedWindow),
- encode(encodedKey),
+ createMultimapKeyStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
userState.put(A1, "V3");
PrefetchableIterable<String> values = userState.get(A1);
assertEquals(0, fakeClient.getCallCount());
- values.iterator().prefetch();
+ values.prefetch();
assertEquals(1, fakeClient.getCallCount());
assertArrayEquals(new String[] {"V1", "V2", "V3"},
Iterables.toArray(values, String.class));
assertEquals(1, fakeClient.getCallCount());
}
+ @Test
+ public void testNoPersistedValuesCached() throws Exception {
+ FakeBeamFnStateClient fakeClient = new
FakeBeamFnStateClient(Collections.emptyMap());
+ Cache<?, ?> cache = Caches.eternal();
+ {
+ // First user state populates the cache.
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ cache,
+ fakeClient,
+ "instructionId",
+ createMultimapKeyStateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+ assertThat(userState.keys(), is(emptyIterable()));
+ assertThat(userState.get(A1), is(emptyIterable()));
+ }
+
+ {
+ // The next user state will load all of its contents from the cache.
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ cache,
+ requestBuilder -> {
+ throw new IllegalStateException("Unexpected call for test.");
+ },
+ "instructionId",
+ createMultimapKeyStateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+ assertThat(userState.keys(), is(emptyIterable()));
+ assertThat(userState.get(A1), is(emptyIterable()));
+ }
+ }
+
+ @Test
+ public void testGetCached() throws Exception {
+ FakeBeamFnStateClient fakeClient =
+ new FakeBeamFnStateClient(
+ ImmutableMap.of(
+ createMultimapKeyStateKey(),
+ KV.of(ByteArrayCoder.of(), singletonList(A1)),
+ createMultimapValueStateKey(A1),
+ KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
+ Cache<?, ?> cache = Caches.eternal();
+ {
+ // First user state populates the cache.
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ cache,
+ fakeClient,
+ "instructionId",
+ createMultimapKeyStateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+
+ assertArrayEquals(
+ new String[] {"V1", "V2"}, Iterables.toArray(userState.get(A1),
String.class));
+ userState.asyncClose();
+ }
+
+ {
+ // The next user state will load all of its contents from the cache.
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ cache,
+ requestBuilder -> {
+ throw new IllegalStateException("Unexpected call for test.");
+ },
+ "instructionId",
+ createMultimapKeyStateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+
+ assertArrayEquals(
+ new String[] {"V1", "V2"}, Iterables.toArray(userState.get(A1),
String.class));
+ userState.asyncClose();
+ }
+ }
+
+ @Test
+ public void testClearCached() throws Exception {
+ FakeBeamFnStateClient fakeClient =
+ new FakeBeamFnStateClient(
+ ImmutableMap.of(
+ createMultimapKeyStateKey(),
+ KV.of(ByteArrayCoder.of(), singletonList(A1)),
+ createMultimapValueStateKey(A1),
+ KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
+ Cache<?, ?> cache = Caches.eternal();
+ {
+ // First user state populates the cache.
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ cache,
+ fakeClient,
+ "instructionId",
+ createMultimapKeyStateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+
+ userState.clear();
+ assertThat(userState.keys(), is(emptyIterable()));
+ userState.asyncClose();
+ }
+
+ {
+ // The next user state will load all of its contents from the cache
including the mutations
+ // persisted via asyncClose.
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ cache,
+ requestBuilder -> {
+ throw new IllegalStateException("Unexpected call for test.");
+ },
+ "instructionId",
+ createMultimapKeyStateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+
+ assertThat(userState.keys(), is(emptyIterable()));
+ userState.asyncClose();
+ }
+ }
+
+ @Test
+ public void testKeysCached() throws Exception {
+ FakeBeamFnStateClient fakeClient =
+ new FakeBeamFnStateClient(
+ ImmutableMap.of(
+ createMultimapKeyStateKey(),
+ KV.of(ByteArrayCoder.of(), singletonList(A1)),
+ createMultimapValueStateKey(A1),
+ KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
+ Cache<?, ?> cache = Caches.eternal();
+ {
+ // First user state populates the cache.
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ cache,
+ fakeClient,
+ "instructionId",
+ createMultimapKeyStateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+
+ userState.put(A2, "V1");
+ userState.put(A3, "V1");
+ assertArrayEquals(
+ new byte[][] {A1, A2, A3}, Iterables.toArray(userState.keys(),
byte[].class));
+ userState.asyncClose();
+ }
+
+ {
+ // The next user state will load all of its contents from the cache
including the mutations
+ // persisted via asyncClose.
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ cache,
+ requestBuilder -> {
+ throw new IllegalStateException("Unexpected call for test.");
+ },
+ "instructionId",
+ createMultimapKeyStateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+
+ assertArrayEquals(
+ new byte[][] {A1, A2, A3}, Iterables.toArray(userState.keys(),
byte[].class));
+ userState.asyncClose();
+ }
+ }
+
+ @Test
+ public void testPutCached() throws Exception {
+ FakeBeamFnStateClient fakeClient =
+ new FakeBeamFnStateClient(
+ ImmutableMap.of(
+ createMultimapKeyStateKey(),
+ KV.of(ByteArrayCoder.of(), singletonList(A1)),
+ createMultimapValueStateKey(A1),
+ KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
+ Cache<?, ?> cache = Caches.eternal();
+ {
+ // First user state populates the cache.
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ cache,
+ fakeClient,
+ "instructionId",
+ createMultimapKeyStateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+
+ userState.put(A1, "V3");
+ userState.put(A2, "V1");
+ assertArrayEquals(
+ new String[] {"V1", "V2", "V3"},
Iterables.toArray(userState.get(A1), String.class));
+ userState.asyncClose();
+ }
+
+ {
+ // The next user state will load all of its contents from the cache
including the mutations
+ // persisted via asyncClose except for A2 since it was never loaded so
the mutation is
+ // discarded.
+ int callCount = fakeClient.getCallCount();
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ cache,
+ fakeClient,
+ "instructionId",
+ createMultimapKeyStateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+
+ assertArrayEquals(
+ new String[] {"V1", "V2", "V3"},
Iterables.toArray(userState.get(A1), String.class));
+ assertEquals(callCount, fakeClient.getCallCount());
+ // We expect one call when loading A2 since the append would have been
discarded since the
+ // key was never fully loaded.
+ assertArrayEquals(new String[] {"V1"},
Iterables.toArray(userState.get(A2), String.class));
+ assertEquals(callCount + 1, fakeClient.getCallCount());
+ userState.asyncClose();
+ }
+ }
+
+ @Test
+ public void testPutAfterRemoveCached() throws Exception {
+ FakeBeamFnStateClient fakeClient =
+ new FakeBeamFnStateClient(
+ ImmutableMap.of(
+ createMultimapKeyStateKey(),
+ KV.of(ByteArrayCoder.of(), singletonList(A0)),
+ createMultimapValueStateKey(A0),
+ KV.of(StringUtf8Coder.of(), asList("V1"))));
+ Cache<?, ?> cache = Caches.eternal();
+ {
+ // First user state populates the cache.
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ cache,
+ fakeClient,
+ "instructionId",
+ createMultimapKeyStateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+
+ userState.remove(A0);
+ userState.put(A0, "V2");
+ userState.asyncClose();
+ }
+
+ {
+ // The next user state will load all of its contents from the cache
including the mutations
+ // persisted via asyncClose.
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ cache,
+ requestBuilder -> {
+ throw new IllegalStateException("Unexpected call for test.");
+ },
+ "instructionId",
+ createMultimapKeyStateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+ assertArrayEquals(new String[] {"V2"},
Iterables.toArray(userState.get(A0), String.class));
+ userState.asyncClose();
+ }
+ }
+
+ @Test
+ public void testPutAfterClearCached() throws Exception {
+ FakeBeamFnStateClient fakeClient =
+ new FakeBeamFnStateClient(
+ ImmutableMap.of(
+ createMultimapKeyStateKey(),
+ KV.of(ByteArrayCoder.of(), singletonList(A0)),
+ createMultimapValueStateKey(A0),
+ KV.of(StringUtf8Coder.of(), asList("V1"))));
+ Cache<?, ?> cache = Caches.eternal();
+ {
+ // First user state populates the cache.
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ cache,
+ fakeClient,
+ "instructionId",
+ createMultimapKeyStateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+ userState.clear();
+ userState.put(A0, "V2");
+ userState.asyncClose();
+ }
+
+ {
+ // The next user state will load all of its contents from the cache
including the mutations
+ // persisted via asyncClose.
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ cache,
+ requestBuilder -> {
+ throw new IllegalStateException("Unexpected call for test.");
+ },
+ "instructionId",
+ createMultimapKeyStateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+ assertArrayEquals(new String[] {"V2"},
Iterables.toArray(userState.get(A0), String.class));
+ // Even though we never load
+ assertArrayEquals(new byte[][] {A0}, Iterables.toArray(userState.keys(),
byte[].class));
+ userState.asyncClose();
+ }
+ }
+
+ @Test
+ public void testRemoveCached() throws Exception {
+ FakeBeamFnStateClient fakeClient =
+ new FakeBeamFnStateClient(
+ ImmutableMap.of(
+ createMultimapKeyStateKey(),
+ KV.of(ByteArrayCoder.of(), singletonList(A1)),
+ createMultimapValueStateKey(A1),
+ KV.of(StringUtf8Coder.of(), asList("V1", "V2"))));
+ Cache<?, ?> cache = Caches.eternal();
+ {
+ // First user state populates the cache.
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ cache,
+ fakeClient,
+ "instructionId",
+ createMultimapKeyStateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+ assertArrayEquals(
+ new String[] {"V1", "V2"}, Iterables.toArray(userState.get(A1),
String.class));
+ userState.remove(A1);
+ userState.remove(A2);
+ assertThat(userState.keys(), is(emptyIterable()));
+ userState.asyncClose();
+ }
+
+ {
+ // The next user state will load all of its contents from the cache
including the mutations
+ // persisted via asyncClose.
+ MultimapUserState<byte[], String> userState =
+ new MultimapUserState<>(
+ cache,
+ requestBuilder -> {
+ throw new IllegalStateException("Unexpected call for test.");
+ },
+ "instructionId",
+ createMultimapKeyStateKey(),
+ ByteArrayCoder.of(),
+ StringUtf8Coder.of());
+ assertThat(userState.get(A1), is(emptyIterable()));
+ assertThat(userState.get(A2), is(emptyIterable()));
+ assertThat(userState.keys(), is(emptyIterable()));
+ userState.asyncClose();
+ }
+ }
+
private StateKey createMultimapKeyStateKey() throws IOException {
return StateKey.newBuilder()
.setMultimapKeysUserState(