This is an automated email from the ASF dual-hosted git repository.
lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 6c8d56c [BEAM-2926] Add support for side inputs to the runner harness.
6c8d56c is described below
commit 6c8d56cacfcf8a44f6c8c029706905eb2748b44a
Author: Luke Cwik <[email protected]>
AuthorDate: Wed Jan 31 09:58:57 2018 -0800
[BEAM-2926] Add support for side inputs to the runner harness.
---
.../construction/PCollectionViewTranslation.java | 12 +-
.../core/construction/ParDoTranslation.java | 4 +-
.../PCollectionViewTranslationTest.java | 74 +++
.../apache/beam/fn/harness/FnApiDoFnRunner.java | 527 +++++++++++++--------
.../apache/beam/fn/harness/state/BagUserState.java | 66 +--
.../state/LazyCachingIteratorToIterable.java | 17 +
.../beam/fn/harness/state/MultimapSideInput.java | 85 ++++
.../fn/harness/state/StateFetchingIterators.java | 28 +-
.../beam/fn/harness/FnApiDoFnRunnerTest.java | 273 ++++++++++-
.../beam/fn/harness/state/BagUserStateTest.java | 59 ++-
.../state/LazyCachingIteratorToIterableTest.java | 14 +
.../fn/harness/state/MultimapSideInputTest.java | 73 +++
.../harness/state/StateFetchingIteratorsTest.java | 2 +-
13 files changed, 957 insertions(+), 277 deletions(-)
diff --git
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PCollectionViewTranslation.java
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PCollectionViewTranslation.java
index 25361ed..ade7229 100644
---
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PCollectionViewTranslation.java
+++
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PCollectionViewTranslation.java
@@ -73,7 +73,11 @@ public class PCollectionViewTranslation {
return view;
}
- private static ViewFn<?, ?> viewFnFromProto(RunnerApi.SdkFunctionSpec viewFn)
+ /**
+ * Converts a {@link org.apache.beam.model.pipeline.v1.RunnerApi.SdkFunctionSpec} into
+ * a {@link ViewFn} using the URN.
+ */
+ public static ViewFn<?, ?> viewFnFromProto(RunnerApi.SdkFunctionSpec viewFn)
throws InvalidProtocolBufferException {
RunnerApi.FunctionSpec spec = viewFn.getSpec();
checkArgument(
@@ -86,7 +90,11 @@ public class PCollectionViewTranslation {
spec.getPayload().toByteArray(), "Custom ViewFn");
}
- private static WindowMappingFn<?> windowMappingFnFromProto(
+ /**
+ * Converts a {@link org.apache.beam.model.pipeline.v1.RunnerApi.SdkFunctionSpec} into
+ * a {@link WindowMappingFn} using the URN.
+ */
+ public static WindowMappingFn<?> windowMappingFnFromProto(
RunnerApi.SdkFunctionSpec windowMappingFn)
throws InvalidProtocolBufferException {
RunnerApi.FunctionSpec spec = windowMappingFn.getSpec();
diff --git
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
index a9b3c56..6365d77 100644
---
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
+++
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
@@ -498,7 +498,7 @@ public class ParDoTranslation {
return builder.build();
}
- private static SdkFunctionSpec translateViewFn(ViewFn<?, ?> viewFn,
SdkComponents components) {
+ public static SdkFunctionSpec translateViewFn(ViewFn<?, ?> viewFn,
SdkComponents components) {
return SdkFunctionSpec.newBuilder()
.setEnvironmentId(components.registerEnvironment(Environments.JAVA_SDK_HARNESS_ENVIRONMENT))
.setSpec(
@@ -526,7 +526,7 @@ public class ParDoTranslation {
return payload.getSplittable();
}
- private static SdkFunctionSpec translateWindowMappingFn(
+ public static SdkFunctionSpec translateWindowMappingFn(
WindowMappingFn<?> windowMappingFn, SdkComponents components) {
return SdkFunctionSpec.newBuilder()
.setEnvironmentId(components.registerEnvironment(Environments.JAVA_SDK_HARNESS_ENVIRONMENT))
diff --git
a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PCollectionViewTranslationTest.java
b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PCollectionViewTranslationTest.java
new file mode 100644
index 0000000..85156a9
--- /dev/null
+++
b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PCollectionViewTranslationTest.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core.construction;
+
+import static org.junit.Assert.assertEquals;
+
+import org.apache.beam.sdk.transforms.Materialization;
+import org.apache.beam.sdk.transforms.ViewFn;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link PCollectionViewTranslation}.
+ */
+@RunWith(JUnit4.class)
+public class PCollectionViewTranslationTest {
+ @Test
+ public void testViewFnTranslation() throws Exception {
+ assertEquals(new TestViewFn(),
+ PCollectionViewTranslation.viewFnFromProto(
+ ParDoTranslation.translateViewFn(new TestViewFn(),
+ SdkComponents.create())));
+ }
+
+ @Test
+ public void testWindowMappingFnTranslation() throws Exception {
+ assertEquals(new GlobalWindows().getDefaultWindowMappingFn(),
+ PCollectionViewTranslation.windowMappingFnFromProto(
+ ParDoTranslation.translateWindowMappingFn(
+ new GlobalWindows().getDefaultWindowMappingFn(),
+ SdkComponents.create())));
+ }
+
+ /** Test implementation to check for equality. */
+ private static class TestViewFn extends ViewFn<Object, Object> {
+ @Override
+ public Materialization<Object> getMaterialization() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Object apply(Object o) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ return obj instanceof TestViewFn;
+ }
+
+ @Override
+ public int hashCode() {
+ return 0;
+ }
+ }
+}
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index 721207a..cf3a227 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -17,17 +17,20 @@
*/
package org.apache.beam.fn.harness;
+import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import com.google.auto.service.AutoService;
-import com.google.common.base.Suppliers;
+import com.google.auto.value.AutoValue;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimap;
+import com.google.common.collect.Sets;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.IOException;
@@ -36,6 +39,7 @@ import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
+import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
@@ -43,14 +47,14 @@ import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.fn.ThrowingRunnable;
import org.apache.beam.fn.harness.state.BagUserState;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
+import org.apache.beam.fn.harness.state.MultimapSideInput;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest.Builder;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
import org.apache.beam.runners.core.DoFnRunner;
+import org.apache.beam.runners.core.construction.PCollectionViewTranslation;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.ParDoTranslation;
import org.apache.beam.runners.core.construction.RehydratedComponents;
@@ -77,6 +81,8 @@ import
org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFn.OnTimerContext;
import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
+import org.apache.beam.sdk.transforms.Materializations;
+import org.apache.beam.sdk.transforms.ViewFn;
import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
@@ -86,6 +92,7 @@ import
org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.transforms.windowing.WindowMappingFn;
import org.apache.beam.sdk.util.CombineFnUtil;
import org.apache.beam.sdk.util.DoFnInfo;
import org.apache.beam.sdk.util.SerializableUtils;
@@ -167,10 +174,16 @@ public class FnApiDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Outp
(Collection<FnDataReceiver<WindowedValue<OutputT>>>)
(Collection) tagToOutputMap.get(doFnInfo.getMainOutput()),
tagToOutputMap,
+ ImmutableMap.of(),
doFnInfo.getWindowingStrategy());
registerHandlers(
- runner, pTransform, addStartFunction, addFinishFunction,
pCollectionIdsToConsumers);
+ runner,
+ pTransform,
+ ImmutableSet.of(),
+ addStartFunction,
+ addFinishFunction,
+ pCollectionIdsToConsumers);
return runner;
}
}
@@ -198,23 +211,53 @@ public class FnApiDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Outp
Coder<InputT> inputCoder;
WindowingStrategy<InputT, ?> windowingStrategy;
+ ImmutableMap.Builder<TupleTag<?>, SideInputSpec> tagToSideInputSpecMap =
+ ImmutableMap.builder();
+ ParDoPayload parDoPayload;
try {
RehydratedComponents rehydratedComponents =
RehydratedComponents.forComponents(
RunnerApi.Components.newBuilder()
.putAllCoders(coders).putAllWindowingStrategies(windowingStrategies).build());
- ParDoPayload parDoPayload =
ParDoPayload.parseFrom(pTransform.getSpec().getPayload());
- if (parDoPayload.getSideInputsCount() != 0) {
- throw new UnsupportedOperationException("Side inputs not yet
supported.");
- }
+ parDoPayload =
ParDoPayload.parseFrom(pTransform.getSpec().getPayload());
doFn = (DoFn) ParDoTranslation.getDoFn(parDoPayload);
mainOutputTag = (TupleTag)
ParDoTranslation.getMainOutputTag(parDoPayload);
- // There will only be one due to the check above.
- RunnerApi.PCollection mainInput = pCollections.get(
- Iterables.getOnlyElement(pTransform.getInputsMap().values()));
- inputCoder = (Coder<InputT>)
rehydratedComponents.getCoder(mainInput.getCoderId());
- windowingStrategy =
- (WindowingStrategy)
-
rehydratedComponents.getWindowingStrategy(mainInput.getWindowingStrategyId());
+ String mainInputTag = Iterables.getOnlyElement(Sets.difference(
+ pTransform.getInputsMap().keySet(),
parDoPayload.getSideInputsMap().keySet()));
+ RunnerApi.PCollection mainInput =
+ pCollections.get(pTransform.getInputsOrThrow(mainInputTag));
+ inputCoder = (Coder<InputT>) rehydratedComponents.getCoder(
+ mainInput.getCoderId());
+ windowingStrategy = (WindowingStrategy)
rehydratedComponents.getWindowingStrategy(
+ mainInput.getWindowingStrategyId());
+
+ // Build the map from tag id to side input specification
+ for (Map.Entry<String, RunnerApi.SideInput> entry
+ : parDoPayload.getSideInputsMap().entrySet()) {
+ String sideInputTag = entry.getKey();
+ RunnerApi.SideInput sideInput = entry.getValue();
+ checkArgument(
+ Materializations.MULTIMAP_MATERIALIZATION_URN.equals(
+ sideInput.getAccessPattern().getUrn()),
+ "This SDK is only capable of dealing with %s materializations "
+ + "but was asked to handle %s for PCollectionView with tag
%s.",
+ Materializations.MULTIMAP_MATERIALIZATION_URN,
+ sideInput.getAccessPattern().getUrn(),
+ sideInputTag);
+
+ RunnerApi.PCollection sideInputPCollection =
+ pCollections.get(pTransform.getInputsOrThrow(sideInputTag));
+ WindowingStrategy sideInputWindowingStrategy =
+ rehydratedComponents.getWindowingStrategy(
+ sideInputPCollection.getWindowingStrategyId());
+ tagToSideInputSpecMap.put(
+ new TupleTag<>(entry.getKey()),
+ SideInputSpec.create(
+
rehydratedComponents.getCoder(sideInputPCollection.getCoderId()),
+ sideInputWindowingStrategy.getWindowFn().windowCoder(),
+
PCollectionViewTranslation.viewFnFromProto(entry.getValue().getViewFn()),
+ PCollectionViewTranslation.windowMappingFnFromProto(
+ entry.getValue().getWindowMappingFn())));
+ }
} catch (InvalidProtocolBufferException exn) {
throw new IllegalArgumentException("Malformed ParDoPayload", exn);
} catch (IOException exn) {
@@ -241,9 +284,15 @@ public class FnApiDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Outp
(Collection<FnDataReceiver<WindowedValue<OutputT>>>) (Collection)
tagToConsumer.get(mainOutputTag),
tagToConsumer,
+ tagToSideInputSpecMap.build(),
windowingStrategy);
registerHandlers(
- runner, pTransform, addStartFunction, addFinishFunction,
pCollectionIdsToConsumers);
+ runner,
+ pTransform,
+ parDoPayload.getSideInputsMap().keySet(),
+ addStartFunction,
+ addFinishFunction,
+ pCollectionIdsToConsumers);
return runner;
}
}
@@ -251,14 +300,16 @@ public class FnApiDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Outp
private static <InputT, OutputT> void registerHandlers(
DoFnRunner<InputT, OutputT> runner,
RunnerApi.PTransform pTransform,
+ Set<String> sideInputLocalNames,
Consumer<ThrowingRunnable> addStartFunction,
Consumer<ThrowingRunnable> addFinishFunction,
Multimap<String, FnDataReceiver<WindowedValue<?>>>
pCollectionIdsToConsumers) {
// Register the appropriate handlers.
addStartFunction.accept(runner::startBundle);
- for (String pcollectionId : pTransform.getInputsMap().values()) {
+ for (String localInputName
+ : Sets.difference(pTransform.getInputsMap().keySet(),
sideInputLocalNames)) {
pCollectionIdsToConsumers.put(
- pcollectionId,
+ pTransform.getInputsOrThrow(localInputName),
(FnDataReceiver) (FnDataReceiver<WindowedValue<InputT>>)
runner::processElement);
}
addFinishFunction.accept(runner::finishBundle);
@@ -274,6 +325,8 @@ public class FnApiDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Outp
private final Coder<InputT> inputCoder;
private final Collection<FnDataReceiver<WindowedValue<OutputT>>>
mainOutputConsumers;
private final Multimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>>
outputMap;
+ private final Map<TupleTag<?>, SideInputSpec> sideInputSpecMap;
+ private final Map<StateKey, Object> stateKeyObjectCache;
private final WindowingStrategy windowingStrategy;
private final DoFnSignature doFnSignature;
private final DoFnInvoker<InputT, OutputT> doFnInvoker;
@@ -296,12 +349,16 @@ public class FnApiDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Outp
private BoundedWindow currentWindow;
/**
- * This member should only be accessed indirectly by calling
- * {@link #createOrUseCachedBagUserStateKey} and is only valid during {@link
#processElement}
- * and is null otherwise.
+ * This member is only valid during {@link #processElement} while processing a
+ * {@link KV}; it is null otherwise.
*/
- private StateKey.BagUserState cachedPartialBagUserStateKey;
+ private ByteString encodedCurrentKey;
+ /**
+ * This member is only valid during {@link #processElement};
+ * it is null otherwise.
+ */
+ private ByteString encodedCurrentWindow;
FnApiDoFnRunner(
PipelineOptions pipelineOptions,
@@ -312,6 +369,7 @@ public class FnApiDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Outp
Coder<InputT> inputCoder,
Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers,
Multimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> outputMap,
+ Map<TupleTag<?>, SideInputSpec> sideInputSpecMap,
WindowingStrategy windowingStrategy) {
this.pipelineOptions = pipelineOptions;
this.beamFnStateClient = beamFnStateClient;
@@ -321,6 +379,8 @@ public class FnApiDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Outp
this.inputCoder = inputCoder;
this.mainOutputConsumers = mainOutputConsumers;
this.outputMap = outputMap;
+ this.sideInputSpecMap = sideInputSpecMap;
+ this.stateKeyObjectCache = new HashMap<>();
this.windowingStrategy = windowingStrategy;
this.doFnSignature = DoFnSignatures.signatureForDoFn(doFn);
this.doFnInvoker = DoFnInvokers.invokerFor(doFn);
@@ -349,7 +409,8 @@ public class FnApiDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Outp
} finally {
currentElement = null;
currentWindow = null;
- cachedPartialBagUserStateKey = null;
+ encodedCurrentKey = null;
+ encodedCurrentWindow = null;
}
}
@@ -377,6 +438,9 @@ public class FnApiDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Outp
} catch (Exception e) {
throw new IllegalStateException(e);
}
+
+ // TODO: Support caching state data across bundle boundaries.
+ stateKeyObjectCache.clear();
}
/**
@@ -592,7 +656,7 @@ public class FnApiDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Outp
@Override
public <T> T sideInput(PCollectionView<T> view) {
- throw new UnsupportedOperationException("TODO: Support side inputs");
+ return bindSideInputView(view.getTagInternal());
}
@Override
@@ -705,87 +769,85 @@ public class FnApiDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Outp
* {@link #bindWatermark} should never be implemented.
*/
private class BeamFnStateBinder implements StateBinder {
- private final Map<StateKey.BagUserState, Object> stateObjectCache = new
HashMap<>();
-
@Override
public <T> ValueState<T> bindValue(String id, StateSpec<ValueState<T>>
spec, Coder<T> coder) {
- return (ValueState<T>) stateObjectCache.computeIfAbsent(
- createOrUseCachedBagUserStateKey(id),
- new Function<StateKey.BagUserState, Object>() {
- @Override
- public Object apply(StateKey.BagUserState s) {
- return new ValueState<T>() {
- private final BagUserState<T> impl = createBagUserState(id, coder);
-
+ return (ValueState<T>) stateKeyObjectCache.computeIfAbsent(
+ createBagUserStateKey(id),
+ new Function<StateKey, Object>() {
@Override
- public void clear() {
- impl.clear();
- }
+ public Object apply(StateKey key) {
+ return new ValueState<T>() {
+ private final BagUserState<T> impl = createBagUserState(id,
coder);
- @Override
- public void write(T input) {
- impl.clear();
- impl.append(input);
- }
+ @Override
+ public void clear() {
+ impl.clear();
+ }
- @Override
- public T read() {
- Iterator<T> value = impl.get().iterator();
- if (value.hasNext()) {
- return value.next();
- } else {
- return null;
- }
- }
+ @Override
+ public void write(T input) {
+ impl.clear();
+ impl.append(input);
+ }
- @Override
- public ValueState<T> readLater() {
- // TODO: Support prefetching.
- return this;
+ @Override
+ public T read() {
+ Iterator<T> value = impl.get().iterator();
+ if (value.hasNext()) {
+ return value.next();
+ } else {
+ return null;
+ }
+ }
+
+ @Override
+ public ValueState<T> readLater() {
+ // TODO: Support prefetching.
+ return this;
+ }
+ };
}
- };
- }
- });
+ });
}
@Override
public <T> BagState<T> bindBag(String id, StateSpec<BagState<T>> spec,
Coder<T> elemCoder) {
- return (BagState<T>) stateObjectCache.computeIfAbsent(
- createOrUseCachedBagUserStateKey(id),
- new Function<StateKey.BagUserState, Object>() {
- @Override
- public Object apply(StateKey.BagUserState s) {
- return new BagState<T>() {
- private final BagUserState<T> impl = createBagUserState(id,
elemCoder);
-
+ return (BagState<T>) stateKeyObjectCache.computeIfAbsent(
+ createBagUserStateKey(id),
+ new Function<StateKey, Object>() {
@Override
- public void add(T value) {
- impl.append(value);
- }
+ public Object apply(StateKey key) {
+ return new BagState<T>() {
+ private final BagUserState<T> impl = createBagUserState(id,
elemCoder);
- @Override
- public ReadableState<Boolean> isEmpty() {
- return
ReadableStates.immediate(!impl.get().iterator().hasNext());
- }
+ @Override
+ public void add(T value) {
+ impl.append(value);
+ }
- @Override
- public Iterable<T> read() {
- return impl.get();
- }
+ @Override
+ public ReadableState<Boolean> isEmpty() {
+ return
ReadableStates.immediate(!impl.get().iterator().hasNext());
+ }
- @Override
- public BagState<T> readLater() {
- // TODO: Support prefetching.
- return this;
- }
+ @Override
+ public Iterable<T> read() {
+ return impl.get();
+ }
- @Override
- public void clear() {
- impl.clear();
+ @Override
+ public BagState<T> readLater() {
+ // TODO: Support prefetching.
+ return this;
+ }
+
+ @Override
+ public void clear() {
+ impl.clear();
+ }
+ };
}
- };
- }
- });
+ });
}
@Override
@@ -805,77 +867,77 @@ public class FnApiDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Outp
String id,
StateSpec<CombiningState<InputT, AccumT, OutputT>> spec, Coder<AccumT>
accumCoder,
CombineFn<InputT, AccumT, OutputT> combineFn) {
- return (CombiningState<InputT, AccumT, OutputT>)
stateObjectCache.computeIfAbsent(
- createOrUseCachedBagUserStateKey(id),
- new Function<StateKey.BagUserState, Object>() {
- @Override
- public Object apply(StateKey.BagUserState s) {
- // TODO: Support squashing accumulators depending on whether we know
of all
- // remote accumulators and local accumulators or just local
accumulators.
- return new CombiningState<InputT, AccumT, OutputT>() {
- private final BagUserState<AccumT> impl = createBagUserState(id,
accumCoder);
-
+ return (CombiningState<InputT, AccumT, OutputT>)
stateKeyObjectCache.computeIfAbsent(
+ createBagUserStateKey(id),
+ new Function<StateKey, Object>() {
@Override
- public AccumT getAccum() {
- Iterator<AccumT> iterator = impl.get().iterator();
- if (iterator.hasNext()) {
- return iterator.next();
- }
- return combineFn.createAccumulator();
- }
+ public Object apply(StateKey key) {
+ // TODO: Support squashing accumulators depending on whether we
know of all
+ // remote accumulators and local accumulators or just local
accumulators.
+ return new CombiningState<InputT, AccumT, OutputT>() {
+ private final BagUserState<AccumT> impl =
createBagUserState(id, accumCoder);
- @Override
- public void addAccum(AccumT accum) {
- Iterator<AccumT> iterator = impl.get().iterator();
+ @Override
+ public AccumT getAccum() {
+ Iterator<AccumT> iterator = impl.get().iterator();
+ if (iterator.hasNext()) {
+ return iterator.next();
+ }
+ return combineFn.createAccumulator();
+ }
- // Only merge if there was a prior value
- if (iterator.hasNext()) {
- accum =
combineFn.mergeAccumulators(ImmutableList.of(iterator.next(), accum));
- // Since there was a prior value, we need to clear.
- impl.clear();
- }
+ @Override
+ public void addAccum(AccumT accum) {
+ Iterator<AccumT> iterator = impl.get().iterator();
- impl.append(accum);
- }
+ // Only merge if there was a prior value
+ if (iterator.hasNext()) {
+ accum =
combineFn.mergeAccumulators(ImmutableList.of(iterator.next(), accum));
+ // Since there was a prior value, we need to clear.
+ impl.clear();
+ }
- @Override
- public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
- return combineFn.mergeAccumulators(accumulators);
- }
+ impl.append(accum);
+ }
- @Override
- public CombiningState<InputT, AccumT, OutputT> readLater() {
- return this;
- }
+ @Override
+ public AccumT mergeAccumulators(Iterable<AccumT> accumulators)
{
+ return combineFn.mergeAccumulators(accumulators);
+ }
- @Override
- public OutputT read() {
- Iterator<AccumT> iterator = impl.get().iterator();
- if (iterator.hasNext()) {
- return combineFn.extractOutput(iterator.next());
- }
- return combineFn.defaultValue();
- }
+ @Override
+ public CombiningState<InputT, AccumT, OutputT> readLater() {
+ return this;
+ }
- @Override
- public void add(InputT value) {
- AccumT newAccumulator = combineFn.addInput(getAccum(), value);
- impl.clear();
- impl.append(newAccumulator);
- }
+ @Override
+ public OutputT read() {
+ Iterator<AccumT> iterator = impl.get().iterator();
+ if (iterator.hasNext()) {
+ return combineFn.extractOutput(iterator.next());
+ }
+ return combineFn.defaultValue();
+ }
- @Override
- public ReadableState<Boolean> isEmpty() {
- return
ReadableStates.immediate(!impl.get().iterator().hasNext());
- }
+ @Override
+ public void add(InputT value) {
+ AccumT newAccumulator = combineFn.addInput(getAccum(),
value);
+ impl.clear();
+ impl.append(newAccumulator);
+ }
- @Override
- public void clear() {
- impl.clear();
+ @Override
+ public ReadableState<Boolean> isEmpty() {
+ return
ReadableStates.immediate(!impl.get().iterator().hasNext());
+ }
+
+ @Override
+ public void clear() {
+ impl.clear();
+ }
+ };
}
- };
- }
- });
+ });
}
@Override
@@ -885,32 +947,25 @@ public class FnApiDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Outp
StateSpec<CombiningState<InputT, AccumT, OutputT>> spec,
Coder<AccumT> accumCoder,
CombineFnWithContext<InputT, AccumT, OutputT> combineFn) {
- return (CombiningState<InputT, AccumT, OutputT>)
- stateObjectCache.computeIfAbsent(
- createOrUseCachedBagUserStateKey(id),
- s ->
- bindCombining(
- id,
- spec,
- accumCoder,
- CombineFnUtil.bindContext(
- combineFn,
- new StateContext<BoundedWindow>() {
- @Override
- public PipelineOptions getPipelineOptions() {
- return pipelineOptions;
- }
-
- @Override
- public <T> T sideInput(PCollectionView<T> view) {
- return processBundleContext.sideInput(view);
- }
-
- @Override
- public BoundedWindow window() {
- return currentWindow;
- }
- })));
+ return (CombiningState<InputT, AccumT, OutputT>)
stateKeyObjectCache.computeIfAbsent(
+ createBagUserStateKey(id),
+ key -> bindCombining(id, spec, accumCoder,
CombineFnUtil.bindContext(combineFn,
+ new StateContext<BoundedWindow>() {
+ @Override
+ public PipelineOptions getPipelineOptions() {
+ return pipelineOptions;
+ }
+
+ @Override
+ public <T> T sideInput(PCollectionView<T> view) {
+ return processBundleContext.sideInput(view);
+ }
+
+ @Override
+ public BoundedWindow window() {
+ return currentWindow;
+ }
+ })));
}
/**
@@ -924,37 +979,41 @@ public class FnApiDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Outp
throw new UnsupportedOperationException("WatermarkHoldState is
unsupported by the Fn API.");
}
- private <T> BagUserState<T> createBagUserState(String id, Coder<T> coder) {
- BagUserState rval =
- new BagUserState<>(
- beamFnStateClient,
- id,
- coder,
- new Supplier<StateRequest.Builder>() {
- /** Memoizes the partial state key for the lifetime of the
{@link BagUserState}. */
- private final Supplier<StateKey.BagUserState>
memoizingSupplier =
- Suppliers.memoize(() ->
createOrUseCachedBagUserStateKey(id))::get;
-
- @Override
- public Builder get() {
- return StateRequest.newBuilder()
-
.setInstructionReference(processBundleInstructionId.get())
-
.setStateKey(StateKey.newBuilder().setBagUserState(memoizingSupplier.get()));
- }
- });
+ private <T> BagUserState<T> createBagUserState(
+ String stateId, Coder<T> valueCoder) {
+ BagUserState rval = new BagUserState<T>(
+ beamFnStateClient,
+ processBundleInstructionId.get(),
+ ptransformId,
+ stateId,
+ encodedCurrentWindow,
+ encodedCurrentKey,
+ valueCoder);
stateFinalizers.add(rval::asyncClose);
return rval;
}
}
+ private StateKey createBagUserStateKey(String stateId) {
+ cacheEncodedKeyAndWindowForKeyedContext();
+ StateKey.Builder builder = StateKey.newBuilder();
+ builder.getBagUserStateBuilder()
+ .setWindow(encodedCurrentWindow)
+ .setKey(encodedCurrentKey)
+ .setPtransformId(ptransformId)
+ .setUserStateId(stateId);
+ return builder.build();
+ }
+
/**
- * Memoizes a partially built {@link StateKey} saving on the encoding cost
of the key and
- * window across multiple state cells for the lifetime of {@link
#processElement}.
+ * Memoizes the encoded key and window for the current element being processed, saving on the
+ * encoding cost of the key and window across multiple state cells for the lifetime of
+ * {@link #processElement}.
*
* <p>This should only be called during {@link #processElement}.
*/
- private <K> StateKey.BagUserState createOrUseCachedBagUserStateKey(String
id) {
- if (cachedPartialBagUserStateKey == null) {
+ private <K> void cacheEncodedKeyAndWindowForKeyedContext() {
+ if (encodedCurrentKey == null) {
checkState(currentElement.getValue() instanceof KV,
"Accessing state in unkeyed context. Current element is not a KV:
%s.",
currentElement);
@@ -976,19 +1035,85 @@ public class FnApiDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Outp
} catch (IOException e) {
throw new IllegalStateException(e);
}
+ encodedCurrentKey = encodedKeyOut.toByteString();
+ }
+ if (encodedCurrentWindow == null) {
ByteString.Output encodedWindowOut = ByteString.newOutput();
try {
windowingStrategy.getWindowFn().windowCoder().encode(currentWindow,
encodedWindowOut);
} catch (IOException e) {
throw new IllegalStateException(e);
}
+ encodedCurrentWindow = encodedWindowOut.toByteString();
+ }
+ }
- cachedPartialBagUserStateKey = StateKey.BagUserState.newBuilder()
- .setPtransformId(ptransformId)
- .setKey(encodedKeyOut.toByteString())
- .setWindow(encodedWindowOut.toByteString()).buildPartial();
+ /**
+ * A specification for side inputs containing a value {@link Coder},
+ * the window {@link Coder}, {@link ViewFn}, and the {@link WindowMappingFn}.
+ * @param <W> the type of {@link BoundedWindow} handled by the window coder and mapping function
+ */
+ @AutoValue
+ abstract static class SideInputSpec<W extends BoundedWindow> {
+ static <W extends BoundedWindow> SideInputSpec create(
+ Coder<?> coder,
+ Coder<W> windowCoder,
+ ViewFn<?, ?> viewFn,
+ WindowMappingFn<W> windowMappingFn) {
+ return new AutoValue_FnApiDoFnRunner_SideInputSpec<>(
+ coder, windowCoder, viewFn, windowMappingFn);
+ }
+
+ abstract Coder<?> getCoder();
+
+ abstract Coder<W> getWindowCoder();
+
+ abstract ViewFn<?, ?> getViewFn();
+
+ abstract WindowMappingFn<W> getWindowMappingFn();
+ }
+
+ private <T, K, V> T bindSideInputView(TupleTag<?> view) {
+ SideInputSpec sideInputSpec = sideInputSpecMap.get(view);
+ checkArgument(sideInputSpec != null,
+ "Attempting to access unknown side input %s.",
+ view);
+ KvCoder<K, V> kvCoder = (KvCoder) sideInputSpec.getCoder();
+
+ ByteString.Output encodedWindowOut = ByteString.newOutput();
+ try {
+ sideInputSpec.getWindowCoder().encode(
+
sideInputSpec.getWindowMappingFn().getSideInputWindow(currentWindow),
encodedWindowOut);
+ } catch (IOException e) {
+ throw new IllegalStateException(e);
}
- return cachedPartialBagUserStateKey.toBuilder().setUserStateId(id).build();
+ ByteString encodedWindow = encodedWindowOut.toByteString();
+
+ StateKey.Builder cacheKeyBuilder = StateKey.newBuilder();
+ cacheKeyBuilder.getMultimapSideInputBuilder()
+ .setPtransformId(ptransformId)
+ .setSideInputId(view.getId())
+ .setWindow(encodedWindow);
+ return (T) stateKeyObjectCache.computeIfAbsent(
+ cacheKeyBuilder.build(),
+ key -> sideInputSpec.getViewFn().apply(createMultimapSideInput(
+ view.getId(), encodedWindow, kvCoder.getKeyCoder(),
kvCoder.getValueCoder())));
+ }
+
+ private <K, V> MultimapSideInput<K, V> createMultimapSideInput(
+ String sideInputId,
+ ByteString encodedWindow,
+ Coder<K> keyCoder,
+ Coder<V> valueCoder) {
+
+ return new MultimapSideInput<>(
+ beamFnStateClient,
+ processBundleInstructionId.get(),
+ ptransformId,
+ sideInputId,
+ encodedWindow,
+ keyCoder,
+ valueCoder);
}
}
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
index f2e852c..1b08e58 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
@@ -23,12 +23,10 @@ import com.google.common.collect.Iterables;
import com.google.protobuf.ByteString;
import java.util.ArrayList;
import java.util.Collections;
-import java.util.List;
import java.util.concurrent.CompletableFuture;
-import java.util.function.Supplier;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateAppendRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest.Builder;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.fn.stream.DataStreams;
@@ -46,62 +44,76 @@ import org.apache.beam.sdk.fn.stream.DataStreams;
*/
public class BagUserState<T> {
private final BeamFnStateClient beamFnStateClient;
- private final String stateId;
- private final Coder<T> coder;
- private final Supplier<Builder> partialRequestSupplier;
+ private final StateRequest request;
+ private final Coder<T> valueCoder;
private Iterable<T> oldValues;
private ArrayList<T> newValues;
- private List<T> unmodifiableNewValues;
private boolean isClosed;
public BagUserState(
BeamFnStateClient beamFnStateClient,
+ String instructionId,
+ String ptransformId,
String stateId,
- Coder<T> coder,
- Supplier<Builder> partialRequestSupplier) {
+ ByteString encodedWindow,
+ ByteString encodedKey,
+ Coder<T> valueCoder) {
this.beamFnStateClient = beamFnStateClient;
- this.stateId = stateId;
- this.coder = coder;
- this.partialRequestSupplier = partialRequestSupplier;
+ this.valueCoder = valueCoder;
+
+ StateRequest.Builder requestBuilder = StateRequest.newBuilder();
+ requestBuilder
+ .setInstructionReference(instructionId)
+ .getStateKeyBuilder()
+ .getBagUserStateBuilder()
+ .setPtransformId(ptransformId)
+ .setUserStateId(stateId)
+ .setWindow(encodedWindow)
+ .setKey(encodedKey);
+ request = requestBuilder.build();
+
this.oldValues = new LazyCachingIteratorToIterable<>(
- new DataStreams.DataStreamDecoder(coder,
+ new DataStreams.DataStreamDecoder(valueCoder,
DataStreams.inbound(
- StateFetchingIterators.usingPartialRequestWithStateKey(
+ StateFetchingIterators.forFirstChunk(
beamFnStateClient,
- partialRequestSupplier))));
+ request))));
this.newValues = new ArrayList<>();
- this.unmodifiableNewValues = Collections.unmodifiableList(newValues);
}
public Iterable<T> get() {
checkState(!isClosed,
- "Bag user state is no longer usable because it is closed for %s",
stateId);
- // If we were cleared we should disregard old values.
+ "Bag user state is no longer usable because it is closed for %s",
request.getStateKey());
if (oldValues == null) {
- return unmodifiableNewValues;
+ // If we were cleared we should disregard old values.
+ return Iterables.limit(Collections.unmodifiableList(newValues),
newValues.size());
+ } else if (newValues.isEmpty()) {
+ // If we have no new values then just return the old values.
+ return oldValues;
}
- return Iterables.concat(oldValues, unmodifiableNewValues);
+ return Iterables.concat(oldValues,
+ Iterables.limit(Collections.unmodifiableList(newValues),
newValues.size()));
}
public void append(T t) {
checkState(!isClosed,
- "Bag user state is no longer usable because it is closed for %s",
stateId);
+ "Bag user state is no longer usable because it is closed for %s",
request.getStateKey());
newValues.add(t);
}
public void clear() {
checkState(!isClosed,
- "Bag user state is no longer usable because it is closed for %s",
stateId);
+ "Bag user state is no longer usable because it is closed for %s",
request.getStateKey());
oldValues = null;
- newValues.clear();
+ newValues = new ArrayList<>();
}
public void asyncClose() throws Exception {
checkState(!isClosed,
- "Bag user state is no longer usable because it is closed for %s",
stateId);
+ "Bag user state is no longer usable because it is closed for %s",
request.getStateKey());
if (oldValues == null) {
beamFnStateClient.handle(
- partialRequestSupplier.get()
+ request.toBuilder()
.setClear(StateClearRequest.getDefaultInstance()),
new CompletableFuture<>());
}
@@ -109,10 +121,10 @@ public class BagUserState<T> {
ByteString.Output out = ByteString.newOutput();
for (T newValue : newValues) {
// TODO: Replace with chunking output stream
- coder.encode(newValue, out);
+ valueCoder.encode(newValue, out);
}
beamFnStateClient.handle(
- partialRequestSupplier.get()
+ request.toBuilder()
.setAppend(StateAppendRequest.newBuilder().setData(out.toByteString())),
new CompletableFuture<>());
}
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
index 0a43317..0a6232c 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.fn.harness.state;
+import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
@@ -69,4 +70,20 @@ class LazyCachingIteratorToIterable<T> implements
Iterable<T> {
return rval;
}
}
+
+  /**
+   * Returns a hash derived from the first element, or a fixed constant when there are no
+   * elements.
+   *
+   * <p>The first element is read through {@link #iterator()} instead of calling {@code next()}
+   * on the {@code iterator} field. Advancing the field made every call observe a different
+   * element, so repeated calls could return different hashes, and the element it consumed may
+   * have been lost to later traversals.
+   *
+   * <p>Iterables that are element-wise equal share the same first element (or are both empty),
+   * so equal instances produce equal hashes.
+   */
+  @Override
+  public int hashCode() {
+    Iterator<T> elements = iterator();
+    if (!elements.hasNext()) {
+      return -1789023489;
+    }
+    T first = elements.next();
+    // A null first element previously caused a NullPointerException here.
+    return first == null ? 0 : first.hashCode();
+  }
+
+  /**
+   * Compares this iterable element by element, via {@link Iterables#elementsEqual}, against any
+   * other {@link Iterable}; anything that is not an {@link Iterable} is never equal.
+   */
+  @Override
+  public boolean equals(Object obj) {
+    if (!(obj instanceof Iterable)) {
+      return false;
+    }
+    Iterable<?> other = (Iterable<?>) obj;
+    return Iterables.elementsEqual(this, other);
+  }
+
+ /** Renders this iterable using {@link Iterables#toString(Iterable)}. */
+ @Override
+ public String toString() {
+ return Iterables.toString(this);
+ }
}
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java
new file mode 100644
index 0000000..874d0fc
--- /dev/null
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.state;
+
+import com.google.protobuf.ByteString;
+import java.io.IOException;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.stream.DataStreams;
+import org.apache.beam.sdk.transforms.Materializations.MultimapView;
+
+/**
+ * An implementation of a multimap side input that utilizes the Beam Fn State API to fetch values.
+ *
+ * <p>Every call to {@link #get(Object)} builds a new state request for the supplied key and
+ * returns a new iterable; nothing is shared between calls.
+ *
+ * <p>TODO: Support block level caching and prefetch.
+ */
+public class MultimapSideInput<K, V> implements MultimapView<K, V> {
+
+  private final BeamFnStateClient beamFnStateClient;
+  private final String instructionId;
+  private final String ptransformId;
+  private final String sideInputId;
+  private final ByteString encodedWindow;
+  private final Coder<K> keyCoder;
+  private final Coder<V> valueCoder;
+
+  /**
+   * @param beamFnStateClient client used to issue the state requests
+   * @param instructionId value set as the instruction reference of each request
+   * @param ptransformId ptransform id placed in the state key
+   * @param sideInputId side input id placed in the state key
+   * @param encodedWindow already encoded window placed in the state key
+   * @param keyCoder coder used to encode lookup keys
+   * @param valueCoder coder used to decode the returned values
+   */
+  public MultimapSideInput(
+      BeamFnStateClient beamFnStateClient,
+      String instructionId,
+      String ptransformId,
+      String sideInputId,
+      ByteString encodedWindow,
+      Coder<K> keyCoder,
+      Coder<V> valueCoder) {
+    this.beamFnStateClient = beamFnStateClient;
+    this.instructionId = instructionId;
+    this.ptransformId = ptransformId;
+    this.sideInputId = sideInputId;
+    this.encodedWindow = encodedWindow;
+    this.keyCoder = keyCoder;
+    this.valueCoder = valueCoder;
+  }
+
+  /**
+   * Returns the values stored under {@code k} for this side input and window.
+   *
+   * @param k the key to look up; it is encoded with the key coder to form the state key
+   * @return an iterable over the values decoded from the state stream
+   * @throws IllegalStateException if the key cannot be encoded
+   */
+  @Override
+  public Iterable<V> get(K k) {
+    ByteString.Output output = ByteString.newOutput();
+    try {
+      keyCoder.encode(k, output);
+    } catch (IOException e) {
+      throw new IllegalStateException(
+          String.format("Failed to encode key %s for side input id %s.", k, sideInputId),
+          e);
+    }
+    StateRequest.Builder requestBuilder = StateRequest.newBuilder();
+    requestBuilder
+        .setInstructionReference(instructionId)
+        .getStateKeyBuilder()
+        .getMultimapSideInputBuilder()
+        .setPtransformId(ptransformId)
+        .setSideInputId(sideInputId)
+        .setWindow(encodedWindow)
+        .setKey(output.toByteString());
+
+    // The diamond keeps the decoder typed by valueCoder instead of using a raw type.
+    return new LazyCachingIteratorToIterable<>(
+        new DataStreams.DataStreamDecoder<>(valueCoder,
+            DataStreams.inbound(
+                StateFetchingIterators.forFirstChunk(
+                    beamFnStateClient,
+                    requestBuilder.build()))));
+  }
+}
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
index b64c946..683314a 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
@@ -23,10 +23,8 @@ import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
-import java.util.function.Supplier;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateGetRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest.Builder;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
/**
@@ -40,18 +38,18 @@ public class StateFetchingIterators {
/**
* This adapter handles using the continuation token to provide iteration
over all the chunks
- * returned by the Beam Fn State API using the supplied state client and
partially filled
- * out state request containing a state key.
+ * returned by the Beam Fn State API using the supplied state client and
state request for
+ * the first chunk of the state stream.
*
* @param beamFnStateClient A client for handling state requests.
- * @param partialStateRequestBuilder A {@link StateRequest} with the
- * {@link StateRequest#getStateKey()} already set.
- * @return An {@code Iterator<ByteString>} representing all the requested
data.
+ * @param stateRequestForFirstChunk A fully populated state request for the
first (and possibly
+ * only) chunk of a state stream. This state request will be populated with
a continuation token
+ * to request further chunks of the stream if required.
*/
- public static Iterator<ByteString> usingPartialRequestWithStateKey(
+ public static Iterator<ByteString> forFirstChunk(
BeamFnStateClient beamFnStateClient,
- Supplier<StateRequest.Builder> partialStateRequestBuilder) {
- return new LazyBlockingStateFetchingIterator(beamFnStateClient,
partialStateRequestBuilder);
+ StateRequest stateRequestForFirstChunk) {
+ return new LazyBlockingStateFetchingIterator(beamFnStateClient,
stateRequestForFirstChunk);
}
/**
@@ -63,18 +61,17 @@ public class StateFetchingIterators {
static class LazyBlockingStateFetchingIterator implements
Iterator<ByteString> {
private enum State { READ_REQUIRED, HAS_NEXT, EOF };
private final BeamFnStateClient beamFnStateClient;
- /** Allows for the partially built state request to be memoized across
many requests. */
- private final Supplier<Builder> stateRequestSupplier;
+ private final StateRequest stateRequestForFirstChunk;
private State currentState;
private ByteString continuationToken;
private ByteString next;
LazyBlockingStateFetchingIterator(
BeamFnStateClient beamFnStateClient,
- Supplier<StateRequest.Builder> stateRequestSupplier) {
+ StateRequest stateRequestForFirstChunk) {
this.currentState = State.READ_REQUIRED;
this.beamFnStateClient = beamFnStateClient;
- this.stateRequestSupplier = stateRequestSupplier;
+ this.stateRequestForFirstChunk = stateRequestForFirstChunk;
this.continuationToken = ByteString.EMPTY;
}
@@ -86,7 +83,7 @@ public class StateFetchingIterators {
case READ_REQUIRED:
CompletableFuture<StateResponse> stateResponseFuture = new
CompletableFuture<>();
beamFnStateClient.handle(
- stateRequestSupplier.get().setGet(
+ stateRequestForFirstChunk.toBuilder().setGet(
StateGetRequest.newBuilder().setContinuationToken(continuationToken)),
stateResponseFuture);
StateResponse stateResponse;
@@ -122,5 +119,4 @@ public class StateFetchingIterators {
return next;
}
}
-
}
diff --git
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index 70aca2e..22bcebd 100644
---
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -23,6 +23,7 @@ import static
org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.hasSize;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
@@ -45,6 +46,9 @@ import org.apache.beam.fn.harness.state.FakeBeamFnStateClient;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.construction.ParDoTranslation;
+import org.apache.beam.runners.core.construction.PipelineTranslation;
+import org.apache.beam.runners.core.construction.SdkComponents;
+import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
@@ -57,17 +61,28 @@ import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
import org.apache.beam.sdk.transforms.CombineWithContext.Context;
+import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.DoFnInfo;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.hamcrest.collection.IsMapContaining;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -291,10 +306,10 @@ public class FnApiDoFnRunnerTest {
.build();
FakeBeamFnStateClient fakeClient = new
FakeBeamFnStateClient(ImmutableMap.of(
- key("value", "X"), encode("X0"),
- key("bag", "X"), encode("X0"),
- key("combine", "X"), encode("X0"),
- key("combineWithContext", "X"), encode("X0")
+ bagUserStateKey("value", "X"), encode("X0"),
+ bagUserStateKey("bag", "X"), encode("X0"),
+ bagUserStateKey("combine", "X"), encode("X0"),
+ bagUserStateKey("combineWithContext", "X"), encode("X0")
));
List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
@@ -355,21 +370,21 @@ public class FnApiDoFnRunnerTest {
assertEquals(
ImmutableMap.<StateKey, ByteString>builder()
- .put(key("value", "X"), encode("X2"))
- .put(key("bag", "X"), encode("X0", "X1", "X2"))
- .put(key("combine", "X"), encode("X0X1X2"))
- .put(key("combineWithContext", "X"), encode("X0X1X2"))
- .put(key("value", "Y"), encode("Y2"))
- .put(key("bag", "Y"), encode("Y1", "Y2"))
- .put(key("combine", "Y"), encode("Y1Y2"))
- .put(key("combineWithContext", "Y"), encode("Y1Y2"))
+ .put(bagUserStateKey("value", "X"), encode("X2"))
+ .put(bagUserStateKey("bag", "X"), encode("X0", "X1", "X2"))
+ .put(bagUserStateKey("combine", "X"), encode("X0X1X2"))
+ .put(bagUserStateKey("combineWithContext", "X"), encode("X0X1X2"))
+ .put(bagUserStateKey("value", "Y"), encode("Y2"))
+ .put(bagUserStateKey("bag", "Y"), encode("Y1", "Y2"))
+ .put(bagUserStateKey("combine", "Y"), encode("Y1Y2"))
+ .put(bagUserStateKey("combineWithContext", "Y"), encode("Y1Y2"))
.build(),
fakeClient.getData());
mainOutputValues.clear();
}
- /** Produces a {@link StateKey} for the test PTransform id in the Global
Window. */
- private StateKey key(String userStateId, String key) throws IOException {
+ /** Produces a bag user {@link StateKey} for the test PTransform id in the
global window. */
+ private StateKey bagUserStateKey(String userStateId, String key) throws
IOException {
return StateKey.newBuilder().setBagUserState(
StateKey.BagUserState.newBuilder()
.setPtransformId(TEST_PTRANSFORM_ID)
@@ -380,6 +395,236 @@ public class FnApiDoFnRunnerTest {
.build();
}
+  /**
+   * Emits, for each element, the element joined by ":" with the default singleton side input,
+   * then with the singleton side input, then with every value of the iterable side input.
+   */
+  private static class TestSideInputDoFn extends DoFn<String, String> {
+    private final PCollectionView<String> defaultSingletonSideInput;
+    private final PCollectionView<String> singletonSideInput;
+    private final PCollectionView<Iterable<String>> iterableSideInput;
+
+    private TestSideInputDoFn(
+        PCollectionView<String> defaultSingletonSideInput,
+        PCollectionView<String> singletonSideInput,
+        PCollectionView<Iterable<String>> iterableSideInput) {
+      this.defaultSingletonSideInput = defaultSingletonSideInput;
+      this.singletonSideInput = singletonSideInput;
+      this.iterableSideInput = iterableSideInput;
+    }
+
+    @ProcessElement
+    public void processElement(ProcessContext context) {
+      String prefix = context.element() + ":";
+      context.output(prefix + context.sideInput(defaultSingletonSideInput));
+      context.output(prefix + context.sideInput(singletonSideInput));
+      for (String sideInputValue : context.sideInput(iterableSideInput)) {
+        context.output(prefix + sideInputValue);
+      }
+    }
+  }
+
+ /**
+ * Runs a DoFn that reads a defaulted singleton, a singleton and an iterable side input in the
+ * global window and checks the emitted values and that the state data is left untouched.
+ */
+ @Test
+ public void testUsingSideInput() throws Exception {
+ Pipeline p = Pipeline.create();
+ PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+ PCollectionView<String> defaultSingletonSideInputView =
valuePCollection.apply(
+ View.<String>asSingleton().withDefaultValue("defaultSingletonValue"));
+ PCollectionView<String> singletonSideInputView =
valuePCollection.apply(View.asSingleton());
+ PCollectionView<Iterable<String>> iterableSideInputView =
+ valuePCollection.apply(View.asIterable());
+ PCollection<String> outputPCollection =
valuePCollection.apply(TEST_PTRANSFORM_ID, ParDo.of(
+ new TestSideInputDoFn(
+ defaultSingletonSideInputView,
+ singletonSideInputView,
+ iterableSideInputView))
+ .withSideInputs(
+ defaultSingletonSideInputView, singletonSideInputView,
iterableSideInputView));
+
+ SdkComponents sdkComponents = SdkComponents.create();
+ RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents);
+ String inputPCollectionId =
sdkComponents.registerPCollection(valuePCollection);
+ String outputPCollectionId =
sdkComponents.registerPCollection(outputPCollection);
+
+ RunnerApi.PTransform pTransform =
pProto.getComponents().getTransformsOrThrow(
+
pProto.getComponents().getTransformsOrThrow(TEST_PTRANSFORM_ID).getSubtransforms(0));
+
+ // Only the singleton and iterable views have state entries; the expected output below
+ // shows the defaulted singleton view yielding its default value.
+ ImmutableMap<StateKey, ByteString> stateData = ImmutableMap.of(
+ multimapSideInputKey(singletonSideInputView.getTagInternal().getId(),
ByteString.EMPTY),
+ encode("singletonValue"),
+ multimapSideInputKey(iterableSideInputView.getTagInternal().getId(),
ByteString.EMPTY),
+ encode("iterableValue1", "iterableValue2", "iterableValue3"));
+
+ FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
+
+ List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+ Multimap<String, FnDataReceiver<WindowedValue<?>>> consumers =
HashMultimap.create();
+
consumers.put(Iterables.getOnlyElement(pTransform.getOutputsMap().values()),
+ (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>)
mainOutputValues::add);
+ List<ThrowingRunnable> startFunctions = new ArrayList<>();
+ List<ThrowingRunnable> finishFunctions = new ArrayList<>();
+
+ new FnApiDoFnRunner.NewFactory<>().createRunnerForPTransform(
+ PipelineOptionsFactory.create(),
+ null /* beamFnDataClient */,
+ fakeClient,
+ TEST_PTRANSFORM_ID,
+ pTransform,
+ Suppliers.ofInstance("57L")::get,
+ pProto.getComponents().getPcollectionsMap(),
+ pProto.getComponents().getCodersMap(),
+ pProto.getComponents().getWindowingStrategiesMap(),
+ consumers,
+ startFunctions::add,
+ finishFunctions::add);
+
+ Iterables.getOnlyElement(startFunctions).run();
+ mainOutputValues.clear();
+
+ assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId,
outputPCollectionId));
+
+ // Each main input element should be emitted once per side input value, in the order the
+ // DoFn reads the defaulted singleton, the singleton and the iterable side inputs.
+ FnDataReceiver<WindowedValue<?>> mainInput =
+ Iterables.getOnlyElement(consumers.get(inputPCollectionId));
+ mainInput.accept(valueInGlobalWindow("X"));
+ mainInput.accept(valueInGlobalWindow("Y"));
+ assertThat(mainOutputValues, contains(
+ valueInGlobalWindow("X:defaultSingletonValue"),
+ valueInGlobalWindow("X:singletonValue"),
+ valueInGlobalWindow("X:iterableValue1"),
+ valueInGlobalWindow("X:iterableValue2"),
+ valueInGlobalWindow("X:iterableValue3"),
+ valueInGlobalWindow("Y:defaultSingletonValue"),
+ valueInGlobalWindow("Y:singletonValue"),
+ valueInGlobalWindow("Y:iterableValue1"),
+ valueInGlobalWindow("Y:iterableValue2"),
+ valueInGlobalWindow("Y:iterableValue3")));
+ mainOutputValues.clear();
+
+ Iterables.getOnlyElement(finishFunctions).run();
+ assertThat(mainOutputValues, empty());
+
+ // Assert that state data did not change
+ assertEquals(stateData, fakeClient.getData());
+ mainOutputValues.clear();
+ }
+
+ /** Outputs the iterable side input value itself for every element it processes. */
+ private static class TestSideInputIsAccessibleForDownstreamCallersDoFn
+ extends DoFn<String, Iterable<String>> {
+ private final PCollectionView<Iterable<String>> iterableSideInput;
+ private TestSideInputIsAccessibleForDownstreamCallersDoFn(
+ PCollectionView<Iterable<String>> iterableSideInput) {
+ this.iterableSideInput = iterableSideInput;
+ }
+
+ @ProcessElement
+ public void processElement(ProcessContext context) {
+ context.output(context.sideInput(iterableSideInput));
+ }
+ }
+
+ /**
+ * Checks that an iterable side input value emitted by the DoFn can still be iterated by the
+ * output consumer after the element was processed, and that elements in two different fixed
+ * windows each receive the contents stored for their own window.
+ */
+ @Test
+ public void testSideInputIsAccessibleForDownstreamCallers() throws Exception
{
+ FixedWindows windowFn = FixedWindows.of(Duration.millis(1L));
+ IntervalWindow windowA = windowFn.assignWindow(new Instant(1L));
+ IntervalWindow windowB = windowFn.assignWindow(new Instant(2L));
+ ByteString encodedWindowA =
+
ByteString.copyFrom(CoderUtils.encodeToByteArray(windowFn.windowCoder(),
windowA));
+ ByteString encodedWindowB =
+
ByteString.copyFrom(CoderUtils.encodeToByteArray(windowFn.windowCoder(),
windowB));
+
+ Pipeline p = Pipeline.create();
+ PCollection<String> valuePCollection = p.apply(Create.of("unused"))
+ .apply(Window.into(windowFn));
+ PCollectionView<Iterable<String>> iterableSideInputView =
+ valuePCollection.apply(View.asIterable());
+ PCollection<Iterable<String>> outputPCollection =
+ valuePCollection.apply(TEST_PTRANSFORM_ID, ParDo.of(
+ new
TestSideInputIsAccessibleForDownstreamCallersDoFn(iterableSideInputView))
+ .withSideInputs(iterableSideInputView));
+
+ SdkComponents sdkComponents = SdkComponents.create();
+ RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents);
+ String inputPCollectionId =
sdkComponents.registerPCollection(valuePCollection);
+ String outputPCollectionId =
sdkComponents.registerPCollection(outputPCollection);
+
+ RunnerApi.PTransform pTransform =
pProto.getComponents().getTransformsOrThrow(
+
pProto.getComponents().getTransformsOrThrow(TEST_PTRANSFORM_ID).getSubtransforms(0));
+
+ // Same side input id and empty key, but distinct contents per encoded window.
+ ImmutableMap<StateKey, ByteString> stateData = ImmutableMap.of(
+ multimapSideInputKey(
+ iterableSideInputView.getTagInternal().getId(), ByteString.EMPTY,
encodedWindowA),
+ encode("iterableValue1A", "iterableValue2A", "iterableValue3A"),
+ multimapSideInputKey(
+ iterableSideInputView.getTagInternal().getId(), ByteString.EMPTY,
encodedWindowB),
+ encode("iterableValue1B", "iterableValue2B", "iterableValue3B"));
+
+ FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
+
+ List<WindowedValue<Iterable<String>>> mainOutputValues = new ArrayList<>();
+ Multimap<String, FnDataReceiver<WindowedValue<?>>> consumers =
HashMultimap.create();
+
consumers.put(Iterables.getOnlyElement(pTransform.getOutputsMap().values()),
+ (FnDataReceiver) (FnDataReceiver<WindowedValue<Iterable<String>>>)
mainOutputValues::add);
+ List<ThrowingRunnable> startFunctions = new ArrayList<>();
+ List<ThrowingRunnable> finishFunctions = new ArrayList<>();
+
+ new FnApiDoFnRunner.NewFactory<>().createRunnerForPTransform(
+ PipelineOptionsFactory.create(),
+ null /* beamFnDataClient */,
+ fakeClient,
+ TEST_PTRANSFORM_ID,
+ pTransform,
+ Suppliers.ofInstance("57L")::get,
+ pProto.getComponents().getPcollectionsMap(),
+ pProto.getComponents().getCodersMap(),
+ pProto.getComponents().getWindowingStrategiesMap(),
+ consumers,
+ startFunctions::add,
+ finishFunctions::add);
+
+ Iterables.getOnlyElement(startFunctions).run();
+ mainOutputValues.clear();
+
+ assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId,
outputPCollectionId));
+
+ // Each element should yield the iterable side input contents stored for its own window,
+ // and those contents are read here, after the element has already been processed.
+ FnDataReceiver<WindowedValue<?>> mainInput =
+ Iterables.getOnlyElement(consumers.get(inputPCollectionId));
+ mainInput.accept(valueInWindow("X", windowA));
+ mainInput.accept(valueInWindow("Y", windowB));
+ assertThat(mainOutputValues, hasSize(2));
+ assertThat(mainOutputValues.get(0).getValue(), contains(
+ "iterableValue1A", "iterableValue2A", "iterableValue3A"));
+ assertThat(mainOutputValues.get(1).getValue(), contains(
+ "iterableValue1B", "iterableValue2B", "iterableValue3B"));
+
+ // Assert that state data did not change
+ assertEquals(stateData, fakeClient.getData());
+ }
+
+  /**
+   * Wraps {@code value} in {@code window}, timestamped at the window's maximum timestamp and
+   * carrying the {@code ON_TIME_AND_ONLY_FIRING} pane.
+   */
+  private <T> WindowedValue<T> valueInWindow(T value, BoundedWindow window) {
+    Instant timestamp = window.maxTimestamp();
+    PaneInfo pane = PaneInfo.ON_TIME_AND_ONLY_FIRING;
+    return WindowedValue.of(value, timestamp, window, pane);
+  }
+
+  /**
+   * Builds the multimap side input {@link StateKey} for the test PTransform id, using the
+   * encoded global window as the window.
+   */
+  private StateKey multimapSideInputKey(String sideInputId, ByteString key) throws IOException {
+    byte[] globalWindowBytes =
+        CoderUtils.encodeToByteArray(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE);
+    return multimapSideInputKey(sideInputId, key, ByteString.copyFrom(globalWindowBytes));
+  }
+
+  /**
+   * Builds the multimap side input {@link StateKey} for the test PTransform id, using the
+   * supplied encoded window.
+   */
+  private StateKey multimapSideInputKey(String sideInputId, ByteString key, ByteString windowKey) {
+    StateKey.MultimapSideInput.Builder sideInputKey = StateKey.MultimapSideInput.newBuilder();
+    sideInputKey.setPtransformId(TEST_PTRANSFORM_ID);
+    sideInputKey.setSideInputId(sideInputId);
+    sideInputKey.setKey(key);
+    sideInputKey.setWindow(windowKey);
+    return StateKey.newBuilder().setMultimapSideInput(sideInputKey).build();
+  }
+
private ByteString encode(String ... values) throws IOException {
ByteString.Output out = ByteString.newOutput();
for (String value : values) {
diff --git
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BagUserStateTest.java
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BagUserStateTest.java
index 6d3e078..29c4a8a 100644
---
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BagUserStateTest.java
+++
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BagUserStateTest.java
@@ -19,6 +19,7 @@ package org.apache.beam.fn.harness.state;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import com.google.common.collect.ImmutableMap;
@@ -26,7 +27,6 @@ import com.google.common.collect.Iterables;
import com.google.protobuf.ByteString;
import java.io.IOException;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.junit.Rule;
import org.junit.Test;
@@ -44,7 +44,14 @@ public class BagUserStateTest {
FakeBeamFnStateClient fakeClient = new
FakeBeamFnStateClient(ImmutableMap.of(
key("A"), encode("A1", "A2", "A3")));
BagUserState<String> userState =
- new BagUserState<>(fakeClient, "A", StringUtf8Coder.of(), () ->
requestForId("A"));
+ new BagUserState<>(
+ fakeClient,
+ "instructionId",
+ "ptransformId",
+ "stateId",
+ ByteString.copyFromUtf8("encodedWindow"),
+ encode("A"),
+ StringUtf8Coder.of());
assertArrayEquals(new String[]{ "A1", "A2", "A3" },
Iterables.toArray(userState.get(), String.class));
@@ -58,9 +65,23 @@ public class BagUserStateTest {
FakeBeamFnStateClient fakeClient = new
FakeBeamFnStateClient(ImmutableMap.of(
key("A"), encode("A1")));
BagUserState<String> userState =
- new BagUserState<>(fakeClient, "A", StringUtf8Coder.of(), () ->
requestForId("A"));
+ new BagUserState<>(
+ fakeClient,
+ "instructionId",
+ "ptransformId",
+ "stateId",
+ ByteString.copyFromUtf8("encodedWindow"),
+ encode("A"),
+ StringUtf8Coder.of());
userState.append("A2");
+ Iterable<String> stateBeforeA3 = userState.get();
+ assertArrayEquals(new String[]{ "A1", "A2" },
+ Iterables.toArray(stateBeforeA3, String.class));
userState.append("A3");
+ assertArrayEquals(new String[]{ "A1", "A2" },
+ Iterables.toArray(stateBeforeA3, String.class));
+ assertArrayEquals(new String[]{ "A1", "A2", "A3" },
+ Iterables.toArray(userState.get(), String.class));
userState.asyncClose();
assertEquals(encode("A1", "A2", "A3"), fakeClient.getData().get(key("A")));
@@ -73,11 +94,23 @@ public class BagUserStateTest {
FakeBeamFnStateClient fakeClient = new
FakeBeamFnStateClient(ImmutableMap.of(
key("A"), encode("A1", "A2", "A3")));
BagUserState<String> userState =
- new BagUserState<>(fakeClient, "A", StringUtf8Coder.of(), () ->
requestForId("A"));
-
+ new BagUserState<>(
+ fakeClient,
+ "instructionId",
+ "ptransformId",
+ "stateId",
+ ByteString.copyFromUtf8("encodedWindow"),
+ encode("A"),
+ StringUtf8Coder.of());
+ assertArrayEquals(new String[]{ "A1", "A2", "A3" },
+ Iterables.toArray(userState.get(), String.class));
userState.clear();
- userState.append("A1");
+ assertFalse(userState.get().iterator().hasNext());
+ userState.append("A4");
+ assertArrayEquals(new String[]{ "A4" },
+ Iterables.toArray(userState.get(), String.class));
userState.clear();
+ assertFalse(userState.get().iterator().hasNext());
userState.asyncClose();
assertNull(fakeClient.getData().get(key("A")));
@@ -85,15 +118,13 @@ public class BagUserStateTest {
userState.clear();
}
- private StateRequest.Builder requestForId(String id) {
- return StateRequest.newBuilder().setStateKey(
- StateKey.newBuilder().setBagUserState(
-
StateKey.BagUserState.newBuilder().setKey(ByteString.copyFromUtf8(id))));
- }
-
- private StateKey key(String id) {
+ private StateKey key(String id) throws IOException {
return StateKey.newBuilder().setBagUserState(
-
StateKey.BagUserState.newBuilder().setKey(ByteString.copyFromUtf8(id))).build();
+ StateKey.BagUserState.newBuilder()
+ .setPtransformId("ptransformId")
+ .setUserStateId("stateId")
+ .setWindow(ByteString.copyFromUtf8("encodedWindow"))
+ .setKey(encode(id))).build();
}
private ByteString encode(String ... values) throws IOException {
diff --git
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
index 53eefb4..1e44452 100644
---
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
+++
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
@@ -20,6 +20,7 @@ package org.apache.beam.fn.harness.state;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
import com.google.common.collect.Iterables;
@@ -73,4 +74,17 @@ public class LazyCachingIteratorToIterableTest {
thrown.expect(NoSuchElementException.class);
iterator1.next();
}
+
+ @Test
+ public void testEqualsAndHashCode() { // Checks equals()/hashCode() of wrappers built from separate iterators.
+ Iterable<String> iterA = new
LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C"));
+ Iterable<String> iterB = new
LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C")); // same elements as iterA, distinct iterator
+ Iterable<String> iterC = new
LazyCachingIteratorToIterable<>(Iterators.forArray());
+ Iterable<String> iterD = new
LazyCachingIteratorToIterable<>(Iterators.forArray()); // empty, like iterC
+ assertEquals(iterA, iterB); // same elements => equal
+ assertEquals(iterC, iterD); // both empty => equal
+ assertNotEquals(iterA, iterC); // non-empty vs. empty => not equal
+ assertEquals(iterA.hashCode(), iterB.hashCode()); // pairs asserted equal above must also agree on hashCode
+ assertEquals(iterC.hashCode(), iterD.hashCode());
+ }
}
diff --git
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapSideInputTest.java
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapSideInputTest.java
new file mode 100644
index 0000000..39c0cbd
--- /dev/null
+++
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapSideInputTest.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.state;
+
+import static org.junit.Assert.assertArrayEquals;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.protobuf.ByteString;
+import java.io.IOException;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link MultimapSideInput}, reading from a pre-populated {@link FakeBeamFnStateClient}. */
+@RunWith(JUnit4.class)
+public class MultimapSideInputTest {
+ @Test
+ public void testGet() throws Exception { // Seeds keys "A" and "B", then looks up both plus a key that was never seeded.
+ FakeBeamFnStateClient fakeBeamFnStateClient = new
FakeBeamFnStateClient(ImmutableMap.of(
+ key("A"), encode("A1", "A2", "A3"),
+ key("B"), encode("B1", "B2")));
+
+ MultimapSideInput<String, String> multimapSideInput = new
MultimapSideInput<>(
+ fakeBeamFnStateClient,
+ "instructionId",
+ "ptransformId", // same ptransform id that key(...) puts in the StateKey
+ "sideInputId", // same side input id that key(...) puts in the StateKey
+ ByteString.copyFromUtf8("encodedWindow"), // same window bytes that key(...) puts in the StateKey
+ StringUtf8Coder.of(),
+ StringUtf8Coder.of());
+ assertArrayEquals(new String[]{ "A1", "A2", "A3" },
+ Iterables.toArray(multimapSideInput.get("A"), String.class));
+ assertArrayEquals(new String[]{ "B1", "B2" },
+ Iterables.toArray(multimapSideInput.get("B"), String.class));
+ assertArrayEquals(new String[]{ },
+ Iterables.toArray(multimapSideInput.get("unknown"), String.class)); // key was never seeded => expects no values
+ }
+
+ private StateKey key(String id) throws IOException { // Builds the multimap side input StateKey used to seed the fake client.
+ return StateKey.newBuilder().setMultimapSideInput(
+ StateKey.MultimapSideInput.newBuilder()
+ .setPtransformId("ptransformId")
+ .setSideInputId("sideInputId")
+ .setWindow(ByteString.copyFromUtf8("encodedWindow"))
+ .setKey(encode(id))).build(); // key bytes come from encode(...), i.e. StringUtf8Coder, like the values
+ }
+
+ private ByteString encode(String ... values) throws IOException { // Encodes each value with StringUtf8Coder, in order, into one ByteString.
+ ByteString.Output out = ByteString.newOutput();
+ for (String value : values) {
+ StringUtf8Coder.of().encode(value, out);
+ }
+ return out.toByteString();
+ }
+}
diff --git
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
index 6ddec56..b4f37ab 100644
---
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
+++
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
@@ -91,7 +91,7 @@ public class StateFetchingIteratorsTest {
.build());
};
Iterator<ByteString> byteStrings =
- new LazyBlockingStateFetchingIterator(fakeStateClient,
StateRequest::newBuilder);
+ new LazyBlockingStateFetchingIterator(fakeStateClient,
StateRequest.getDefaultInstance());
assertArrayEquals(expected, Iterators.toArray(byteStrings,
Object.class));
}
}
--
To stop receiving notification emails like this one, please contact
[email protected].