Repository: beam Updated Branches: refs/heads/master f6c840533 -> 585440d22
[BEAM-1347] Create value state, combining state, and bag state views over the BagUserState. Also bind the state persistence to the end of finishBundle. Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/e0f628cc Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/e0f628cc Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/e0f628cc Branch: refs/heads/master Commit: e0f628cc7fbf6cbfb46825d6ee7bbc29e0bd66f5 Parents: f6c8405 Author: Luke Cwik <lc...@google.com> Authored: Tue Aug 29 10:45:04 2017 -0700 Committer: Luke Cwik <lc...@google.com> Committed: Wed Aug 30 14:30:27 2017 -0700 ---------------------------------------------------------------------- .../apache/beam/fn/harness/FnApiDoFnRunner.java | 380 ++++++++++++++++++- .../beam/fn/harness/FnApiDoFnRunnerTest.java | 229 +++++++++++ .../fn/harness/state/FakeBeamFnStateClient.java | 2 +- 3 files changed, 605 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/e0f628cc/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java index d325bb2..c361647 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java @@ -18,45 +18,77 @@ package org.apache.beam.fn.harness; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; import com.google.auto.service.AutoService; +import com.google.common.base.Suppliers; import com.google.common.collect.Collections2; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Multimap; import com.google.protobuf.ByteString; +import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.Map; import java.util.Objects; import java.util.function.Consumer; +import java.util.function.Function; import java.util.function.Supplier; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.fn.ThrowingConsumer; import org.apache.beam.fn.harness.fn.ThrowingRunnable; +import org.apache.beam.fn.harness.state.BagUserState; import org.apache.beam.fn.harness.state.BeamFnStateClient; +import org.apache.beam.fn.v1.BeamFnApi.StateKey; +import org.apache.beam.fn.v1.BeamFnApi.StateRequest; +import org.apache.beam.fn.v1.BeamFnApi.StateRequest.Builder; import org.apache.beam.runners.core.DoFnRunner; import org.apache.beam.runners.core.construction.ParDoTranslation; import org.apache.beam.runners.dataflow.util.DoFnInfo; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.CombiningState; +import org.apache.beam.sdk.state.MapState; +import org.apache.beam.sdk.state.ReadableState; +import org.apache.beam.sdk.state.ReadableStates; +import org.apache.beam.sdk.state.SetState; import org.apache.beam.sdk.state.State; +import org.apache.beam.sdk.state.StateBinder; +import org.apache.beam.sdk.state.StateContext; +import org.apache.beam.sdk.state.StateSpec; import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.WatermarkHoldState; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFn.OnTimerContext; import org.apache.beam.sdk.transforms.DoFn.ProcessContext; import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; import org.apache.beam.sdk.transforms.reflect.DoFnSignature; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.util.CombineFnUtil; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder; +import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; @@ -141,7 +173,13 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp @SuppressWarnings({"unchecked", "rawtypes"}) DoFnRunner<InputT, OutputT> runner = new FnApiDoFnRunner<>( pipelineOptions, + beamFnStateClient, + pTransformId, + processBundleInstructionId, doFnInfo.getDoFn(), + WindowedValue.getFullCoder( + doFnInfo.getInputCoder(), + doFnInfo.getWindowingStrategy().getWindowFn().windowCoder()), (Collection<ThrowingConsumer<WindowedValue<OutputT>>>) (Collection) tagToOutputMap.get(doFnInfo.getOutputMap().get(doFnInfo.getMainOutput())), tagToOutputMap, @@ -162,42 +200,68 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp ////////////////////////////////////////////////////////////////////////////////////////////////// private final PipelineOptions pipelineOptions; + private final BeamFnStateClient beamFnStateClient; + private final String ptransformId; + private final Supplier<String> processBundleInstructionId; private final DoFn<InputT, OutputT> doFn; + private final WindowedValueCoder<InputT> inputCoder; private final Collection<ThrowingConsumer<WindowedValue<OutputT>>> mainOutputConsumers; private final Multimap<TupleTag<?>, ThrowingConsumer<WindowedValue<?>>> outputMap; + private final WindowingStrategy windowingStrategy; + private final DoFnSignature doFnSignature; private final DoFnInvoker<InputT, OutputT> doFnInvoker; + private final StateBinder stateBinder; private final StartBundleContext startBundleContext; private final ProcessBundleContext processBundleContext; private final FinishBundleContext finishBundleContext; - private final WindowingStrategy windowingStrategy; - private final DoFnSignature doFnSignature; + private final Collection<ThrowingRunnable> stateFinalizers; /** - * The lifetime of this member is only valid during {@link #processElement(WindowedValue)}. + * The lifetime of this member is only valid during {@link #processElement} + * and is null otherwise. */ private WindowedValue<InputT> currentElement; /** - * The lifetime of this member is only valid during {@link #processElement(WindowedValue)}. + * The lifetime of this member is only valid during {@link #processElement} + * and is null otherwise. */ private BoundedWindow currentWindow; + /** + * This member should only be accessed indirectly by calling + * {@link #createOrUseCachedBagUserStateKey} and is only valid during {@link #processElement} + * and is null otherwise. + */ + private StateKey.BagUserState cachedPartialBagUserStateKey; + + FnApiDoFnRunner( PipelineOptions pipelineOptions, + BeamFnStateClient beamFnStateClient, + String ptransformId, + Supplier<String> processBundleInstructionId, DoFn<InputT, OutputT> doFn, + WindowedValueCoder<InputT> inputCoder, Collection<ThrowingConsumer<WindowedValue<OutputT>>> mainOutputConsumers, Multimap<TupleTag<?>, ThrowingConsumer<WindowedValue<?>>> outputMap, WindowingStrategy windowingStrategy) { this.pipelineOptions = pipelineOptions; + this.beamFnStateClient = beamFnStateClient; + this.ptransformId = ptransformId; + this.processBundleInstructionId = processBundleInstructionId; this.doFn = doFn; + this.inputCoder = inputCoder; this.mainOutputConsumers = mainOutputConsumers; this.outputMap = outputMap; this.windowingStrategy = windowingStrategy; this.doFnSignature = DoFnSignatures.signatureForDoFn(doFn); this.doFnInvoker = DoFnInvokers.invokerFor(doFn); + this.stateBinder = new BeamFnStateBinder(); this.startBundleContext = new StartBundleContext(); this.processBundleContext = new ProcessBundleContext(); this.finishBundleContext = new FinishBundleContext(); + this.stateFinalizers = new ArrayList<>(); } @Override @@ -218,6 +282,7 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp } finally { currentElement = null; currentWindow = null; + cachedPartialBagUserStateKey = null; } } @@ -233,6 +298,18 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp @Override public void finishBundle() { doFnInvoker.invokeFinishBundle(finishBundleContext); + + // Persist all dirty state cells + try { + for (ThrowingRunnable runnable : stateFinalizers) { + runnable.run(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IllegalStateException(e); + } catch (Exception e) { + throw new IllegalStateException(e); + } } /** @@ -367,7 +444,15 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp @Override public State state(String stateId) { - throw new UnsupportedOperationException("TODO: Add support for state"); + StateDeclaration stateDeclaration = doFnSignature.stateDeclarations().get(stateId); + checkNotNull(stateDeclaration, "No state declaration found for %s", stateId); + StateSpec<?> spec; + try { + spec = (StateSpec<?>) stateDeclaration.field().get(doFn); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + return spec.bind(stateId, stateBinder); } @Override @@ -545,4 +630,289 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); } } + + /** + * A {@link StateBinder} that uses the Beam Fn State API to read and write user state. + * + * <p>TODO: Add support for {@link #bindMap} and {@link #bindSet}. Note that + * {@link #bindWatermark} should never be implemented. + */ + private class BeamFnStateBinder implements StateBinder { + private final Map<StateKey.BagUserState, Object> stateObjectCache = new HashMap<>(); + + @Override + public <T> ValueState<T> bindValue(String id, StateSpec<ValueState<T>> spec, Coder<T> coder) { + return (ValueState<T>) stateObjectCache.computeIfAbsent( + createOrUseCachedBagUserStateKey(id), + new Function<StateKey.BagUserState, Object>() { + @Override + public Object apply(StateKey.BagUserState s) { + return new ValueState<T>() { + private final BagUserState<T> impl = createBagUserState(id, coder); + + @Override + public void clear() { + impl.clear(); + } + + @Override + public void write(T input) { + impl.clear(); + impl.append(input); + } + + @Override + public T read() { + Iterator<T> value = impl.get().iterator(); + if (value.hasNext()) { + return value.next(); + } else { + return null; + } + } + + @Override + public ValueState<T> readLater() { + // TODO: Support prefetching. + return this; + } + }; + } + }); + } + + @Override + public <T> BagState<T> bindBag(String id, StateSpec<BagState<T>> spec, Coder<T> elemCoder) { + return (BagState<T>) stateObjectCache.computeIfAbsent( + createOrUseCachedBagUserStateKey(id), + new Function<StateKey.BagUserState, Object>() { + @Override + public Object apply(StateKey.BagUserState s) { + return new BagState<T>() { + private final BagUserState<T> impl = createBagUserState(id, elemCoder); + + @Override + public void add(T value) { + impl.append(value); + } + + @Override + public ReadableState<Boolean> isEmpty() { + return ReadableStates.immediate(!impl.get().iterator().hasNext()); + } + + @Override + public Iterable<T> read() { + return impl.get(); + } + + @Override + public BagState<T> readLater() { + // TODO: Support prefetching. + return this; + } + + @Override + public void clear() { + impl.clear(); + } + }; + } + }); + } + + @Override + public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec, Coder<T> elemCoder) { + throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API."); + } + + @Override + public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(String id, + StateSpec<MapState<KeyT, ValueT>> spec, Coder<KeyT> mapKeyCoder, + Coder<ValueT> mapValueCoder) { + throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API."); + } + + @Override + public <InputT, AccumT, OutputT> CombiningState<InputT, AccumT, OutputT> bindCombining( + String id, + StateSpec<CombiningState<InputT, AccumT, OutputT>> spec, Coder<AccumT> accumCoder, + CombineFn<InputT, AccumT, OutputT> combineFn) { + return (CombiningState<InputT, AccumT, OutputT>) stateObjectCache.computeIfAbsent( + createOrUseCachedBagUserStateKey(id), + new Function<StateKey.BagUserState, Object>() { + @Override + public Object apply(StateKey.BagUserState s) { + // TODO: Support squashing accumulators depending on whether we know of all + // remote accumulators and local accumulators or just local accumulators. + return new CombiningState<InputT, AccumT, OutputT>() { + private final BagUserState<AccumT> impl = createBagUserState(id, accumCoder); + + @Override + public AccumT getAccum() { + Iterator<AccumT> iterator = impl.get().iterator(); + if (iterator.hasNext()) { + return iterator.next(); + } + return combineFn.createAccumulator(); + } + + @Override + public void addAccum(AccumT accum) { + Iterator<AccumT> iterator = impl.get().iterator(); + + // Only merge if there was a prior value + if (iterator.hasNext()) { + accum = combineFn.mergeAccumulators(ImmutableList.of(iterator.next(), accum)); + // Since there was a prior value, we need to clear. + impl.clear(); + } + + impl.append(accum); + } + + @Override + public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { + return combineFn.mergeAccumulators(accumulators); + } + + @Override + public CombiningState<InputT, AccumT, OutputT> readLater() { + return this; + } + + @Override + public OutputT read() { + Iterator<AccumT> iterator = impl.get().iterator(); + if (iterator.hasNext()) { + return combineFn.extractOutput(iterator.next()); + } + return combineFn.defaultValue(); + } + + @Override + public void add(InputT value) { + AccumT newAccumulator = combineFn.addInput(getAccum(), value); + impl.clear(); + impl.append(newAccumulator); + } + + @Override + public ReadableState<Boolean> isEmpty() { + return ReadableStates.immediate(!impl.get().iterator().hasNext()); + } + + @Override + public void clear() { + impl.clear(); + } + }; + } + }); + } + + @Override + public <InputT, AccumT, OutputT> CombiningState<InputT, AccumT, OutputT> + bindCombiningWithContext( + String id, + StateSpec<CombiningState<InputT, AccumT, OutputT>> spec, + Coder<AccumT> accumCoder, + CombineFnWithContext<InputT, AccumT, OutputT> combineFn) { + return (CombiningState<InputT, AccumT, OutputT>) stateObjectCache.computeIfAbsent( + createOrUseCachedBagUserStateKey(id), + new Function<StateKey.BagUserState, Object>() { + @Override + public Object apply(StateKey.BagUserState s) { + return bindCombining(id, spec, accumCoder, CombineFnUtil.bindContext(combineFn, + new StateContext<BoundedWindow>() { + @Override + public PipelineOptions getPipelineOptions() { + return pipelineOptions; + } + + @Override + public <T> T sideInput(PCollectionView<T> view) { + return processBundleContext.sideInput(view); + } + + @Override + public BoundedWindow window() { + return currentWindow; + } + })); + } + }); + } + + /** + * @deprecated The Fn API has no plans to implement WatermarkHoldState as of this writing + * and is waiting on resolution of BEAM-2535. + */ + @Override + @Deprecated + public WatermarkHoldState bindWatermark(String id, StateSpec<WatermarkHoldState> spec, + TimestampCombiner timestampCombiner) { + throw new UnsupportedOperationException("WatermarkHoldState is unsupported by the Fn API."); + } + + private <T> BagUserState<T> createBagUserState(String id, Coder<T> coder) { + BagUserState rval = new BagUserState<T>( + beamFnStateClient, + id, + coder, + new Supplier<StateRequest.Builder>() { + /** Memoizes the partial state key for the lifetime of the {@link BagUserState}. */ + private final Supplier<StateKey.BagUserState> memoizingSupplier = + Suppliers.memoize(() -> createOrUseCachedBagUserStateKey(id))::get; + + @Override + public Builder get() { + return StateRequest.newBuilder() + .setInstructionReference(processBundleInstructionId.get()) + .setStateKey(StateKey.newBuilder() + .setBagUserState(memoizingSupplier.get())); + } + }); + stateFinalizers.add(rval::asyncClose); + return rval; + } + } + + /** + * Memoizes a partially built {@link StateKey} saving on the encoding cost of the key and + * window across multiple state cells for the lifetime of {@link #processElement}. + * + * <p>This should only be called during {@link #processElement}. + */ + private <K> StateKey.BagUserState createOrUseCachedBagUserStateKey(String id) { + if (cachedPartialBagUserStateKey == null) { + checkState(currentElement.getValue() instanceof KV, + "Accessing state in unkeyed context. Current element is not a KV: %s.", + currentElement); + checkState(inputCoder.getCoderArguments().get(0) instanceof KvCoder, + "Accessing state in unkeyed context. No keyed coder found."); + + ByteString.Output encodedKeyOut = ByteString.newOutput(); + + Coder<K> keyCoder = ((KvCoder<K, ?>) inputCoder.getValueCoder()).getKeyCoder(); + try { + keyCoder.encode(((KV<K, ?>) currentElement.getValue()).getKey(), encodedKeyOut); + } catch (IOException e) { + throw new IllegalStateException(e); + } + + ByteString.Output encodedWindowOut = ByteString.newOutput(); + try { + windowingStrategy.getWindowFn().windowCoder().encode(currentWindow, encodedWindowOut); + } catch (IOException e) { + throw new IllegalStateException(e); + } + + cachedPartialBagUserStateKey = StateKey.BagUserState.newBuilder() + .setPtransformId(ptransformId) + .setKey(encodedKeyOut.toByteString()) + .setWindow(encodedWindowOut.toByteString()).buildPartial(); + } + return cachedPartialBagUserStateKey.toBuilder().setUserStateId(id).build(); + } } http://git-wip-us.apache.org/repos/asf/beam/blob/e0f628cc/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java index ebec608..4aa8080 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java @@ -22,6 +22,8 @@ import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWin import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; @@ -32,22 +34,36 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.collect.Multimap; import com.google.protobuf.ByteString; +import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.ServiceLoader; import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar; import org.apache.beam.fn.harness.fn.ThrowingConsumer; import org.apache.beam.fn.harness.fn.ThrowingRunnable; +import org.apache.beam.fn.harness.state.FakeBeamFnStateClient; +import org.apache.beam.fn.v1.BeamFnApi.StateKey; import org.apache.beam.runners.core.construction.ParDoTranslation; import org.apache.beam.runners.dataflow.util.DoFnInfo; +import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.CombiningState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext; +import org.apache.beam.sdk.transforms.CombineWithContext.Context; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; import org.hamcrest.collection.IsMapContaining; @@ -58,6 +74,9 @@ import org.junit.runners.JUnit4; /** Tests for {@link FnApiDoFnRunner}. */ @RunWith(JUnit4.class) public class FnApiDoFnRunnerTest { + + public static final String TEST_PTRANSFORM_ID = "pTransformId"; + private static class TestDoFn extends DoFn<String, String> { private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput"); private static final TupleTag<String> additionalOutput = new TupleTag<>("output"); @@ -164,6 +183,216 @@ public class FnApiDoFnRunnerTest { mainOutputValues.clear(); } + private static class ConcatCombineFn extends CombineFn<String, String, String> { + @Override + public String createAccumulator() { + return ""; + } + + @Override + public String addInput(String accumulator, String input) { + return accumulator.concat(input); + } + + @Override + public String mergeAccumulators(Iterable<String> accumulators) { + StringBuilder builder = new StringBuilder(); + for (String value : accumulators) { + builder.append(value); + } + return builder.toString(); + } + + @Override + public String extractOutput(String accumulator) { + return accumulator; + } + } + + private static class ConcatCombineFnWithContext + extends CombineFnWithContext<String, String, String> { + @Override + public String createAccumulator(Context c) { + return ""; + } + + @Override + public String addInput(String accumulator, String input, Context c) { + return accumulator.concat(input); + } + + @Override + public String mergeAccumulators(Iterable<String> accumulators, Context c) { + StringBuilder builder = new StringBuilder(); + for (String value : accumulators) { + builder.append(value); + } + return builder.toString(); + } + + @Override + public String extractOutput(String accumulator, Context c) { + return accumulator; + } + } + + private static class TestStatefulDoFn extends DoFn<KV<String, String>, String> { + private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput"); + private static final TupleTag<String> additionalOutput = new TupleTag<>("output"); + + @StateId("value") + private final StateSpec<ValueState<String>> valueStateSpec = + StateSpecs.value(StringUtf8Coder.of()); + @StateId("bag") + private final StateSpec<BagState<String>> bagStateSpec = + StateSpecs.bag(StringUtf8Coder.of()); + @StateId("combine") + private final StateSpec<CombiningState<String, String, String>> combiningStateSpec = + StateSpecs.combining(StringUtf8Coder.of(), new ConcatCombineFn()); + @StateId("combineWithContext") + private final StateSpec<CombiningState<String, String, String>> combiningWithContextStateSpec = + StateSpecs.combining(StringUtf8Coder.of(), new ConcatCombineFnWithContext()); + + @ProcessElement + public void processElement(ProcessContext context, + @StateId("value") ValueState<String> valueState, + @StateId("bag") BagState<String> bagState, + @StateId("combine") CombiningState<String, String, String> combiningState, + @StateId("combineWithContext") + CombiningState<String, String, String> combiningWithContextState) { + context.output("value:" + valueState.read()); + valueState.write(context.element().getValue()); + + context.output("bag:" + Iterables.toString(bagState.read())); + bagState.add(context.element().getValue()); + + context.output("combine:" + combiningState.read()); + combiningState.add(context.element().getValue()); + + context.output("combineWithContext:" + combiningWithContextState.read()); + combiningWithContextState.add(context.element().getValue()); + } + } + + @Test + public void testUsingUserState() throws Exception { + String mainOutputId = "101"; + + DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn( + new TestStatefulDoFn(), + WindowingStrategy.globalDefault(), + ImmutableList.of(), + KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()), + Long.parseLong(mainOutputId), + ImmutableMap.of(Long.parseLong(mainOutputId), new TupleTag<String>("mainOutput"))); + RunnerApi.FunctionSpec functionSpec = + RunnerApi.FunctionSpec.newBuilder() + .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN) + .setPayload(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo))) + .build(); + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putInputs("input", "inputTarget") + .putOutputs(mainOutputId, "mainOutputTarget") + .build(); + + FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of( + key("value", "X"), encode("X0"), + key("bag", "X"), encode("X0"), + key("combine", "X"), encode("X0"), + key("combineWithContext", "X"), encode("X0") + )); + + List<WindowedValue<String>> mainOutputValues = new ArrayList<>(); + Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create(); + consumers.put("mainOutputTarget", + (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) mainOutputValues::add); + List<ThrowingRunnable> startFunctions = new ArrayList<>(); + List<ThrowingRunnable> finishFunctions = new ArrayList<>(); + + new FnApiDoFnRunner.Factory<>().createRunnerForPTransform( + PipelineOptionsFactory.create(), + null /* beamFnDataClient */, + fakeClient, + TEST_PTRANSFORM_ID, + pTransform, + Suppliers.ofInstance("57L")::get, + ImmutableMap.of(), + ImmutableMap.of(), + consumers, + startFunctions::add, + finishFunctions::add); + + Iterables.getOnlyElement(startFunctions).run(); + mainOutputValues.clear(); + + assertThat(consumers.keySet(), containsInAnyOrder("inputTarget", "mainOutputTarget")); + + // Ensure that bag user state that is initially empty or populated works. + // Ensure that the key order does not matter when we traverse over KV pairs. + ThrowingConsumer<WindowedValue<?>> mainInput = + Iterables.getOnlyElement(consumers.get("inputTarget")); + mainInput.accept(valueInGlobalWindow(KV.of("X", "X1"))); + mainInput.accept(valueInGlobalWindow(KV.of("Y", "Y1"))); + mainInput.accept(valueInGlobalWindow(KV.of("X", "X2"))); + mainInput.accept(valueInGlobalWindow(KV.of("Y", "Y2"))); + assertThat(mainOutputValues, contains( + valueInGlobalWindow("value:X0"), + valueInGlobalWindow("bag:[X0]"), + valueInGlobalWindow("combine:X0"), + valueInGlobalWindow("combineWithContext:X0"), + valueInGlobalWindow("value:null"), + valueInGlobalWindow("bag:[]"), + valueInGlobalWindow("combine:"), + valueInGlobalWindow("combineWithContext:"), + valueInGlobalWindow("value:X1"), + valueInGlobalWindow("bag:[X0, X1]"), + valueInGlobalWindow("combine:X0X1"), + valueInGlobalWindow("combineWithContext:X0X1"), + valueInGlobalWindow("value:Y1"), + valueInGlobalWindow("bag:[Y1]"), + valueInGlobalWindow("combine:Y1"), + valueInGlobalWindow("combineWithContext:Y1"))); + mainOutputValues.clear(); + + Iterables.getOnlyElement(finishFunctions).run(); + assertThat(mainOutputValues, empty()); + + assertEquals( + ImmutableMap.<StateKey, ByteString>builder() + .put(key("value", "X"), encode("X2")) + .put(key("bag", "X"), encode("X0", "X1", "X2")) + .put(key("combine", "X"), encode("X0X1X2")) + .put(key("combineWithContext", "X"), encode("X0X1X2")) + .put(key("value", "Y"), encode("Y2")) + .put(key("bag", "Y"), encode("Y1", "Y2")) + .put(key("combine", "Y"), encode("Y1Y2")) + .put(key("combineWithContext", "Y"), encode("Y1Y2")) + .build(), + fakeClient.getData()); + mainOutputValues.clear(); + } + + /** Produces a {@link StateKey} for the test PTransform id in the Global Window. */ + private StateKey key(String userStateId, String key) throws IOException { + return StateKey.newBuilder().setBagUserState( + StateKey.BagUserState.newBuilder() + .setPtransformId(TEST_PTRANSFORM_ID) + .setUserStateId(userStateId) + .setKey(encode(key)) + .setWindow(ByteString.copyFrom( + CoderUtils.encodeToByteArray(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE)))) + .build(); + } + + private ByteString encode(String ... values) throws IOException { + ByteString.Output out = ByteString.newOutput(); + for (String value : values) { + StringUtf8Coder.of().encode(value, out); + } + return out.toByteString(); + } + @Test public void testRegistration() { for (Registrar registrar : http://git-wip-us.apache.org/repos/asf/beam/blob/e0f628cc/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/FakeBeamFnStateClient.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/FakeBeamFnStateClient.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/FakeBeamFnStateClient.java index d260207..60080e1 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/FakeBeamFnStateClient.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/FakeBeamFnStateClient.java @@ -69,7 +69,7 @@ public class FakeBeamFnStateClient implements BeamFnStateClient { switch (request.getRequestCase()) { case GET: // Chunk gets into 5 byte return blocks - ByteString byteString = data.get(request.getStateKey()); + ByteString byteString = data.getOrDefault(request.getStateKey(), ByteString.EMPTY); int block = 0; if (request.getGet().getContinuationToken().size() > 0) { block = Integer.parseInt(request.getGet().getContinuationToken().toStringUtf8());